我定制了一个层,合并了batch_size和第一个维度,其他维度保持不变,但compute_output_shape似乎没有影响,导致后续层无法获得准确的形状信息,导致错误。我如何使compute_output_shape工作?
import keras
from keras import backend as K
class BatchMergeReshape(keras.layers.Layer):
def __init__(self, **kwargs):
super(BatchMergeReshape, self).__init__(**kwargs)
def build(self, input_shape):
super(BatchMergeReshape, self).build(input_shape)
def call(self, x):
input_shape = K.shape(x)
batch_size, seq_len = input_shape[0], input_shape[1]
r = K.reshape(x, (batch_size*seq_len,)+input_shape[2:])
print("call_shape:",r.shape)
return r
def compute_output_shape(self, input_shape):
if input_shape[0] is None:
r = (None,)+input_shape[2:]
print("compute_output_shape:",r)
return r
else:
r = (input_shape[0]*input_shape[1],)+input_shape[2:]
return r
a = keras.layers.Input(shape=(3,4,5))
b = BatchMergeReshape()(a)
print(b.shape)
# call_shape: (?, ?)
# compute_output_shape: (None, 4, 5)
# (?, ?)我需要得到(无,4,5),但得到(无,无),为什么compute_output_shape没有工作。我的keras版本是2.2.4
发布于 2019-09-24 03:58:52
问题可能是K.shape返回一个张量,而不是元组。你不能做(batch_size*seq_len,) + input_shape[2:]。这是混合了很多东西,张量和元组,结果肯定是错误的。
好的是,如果您知道其他维度,而不是批处理大小,您只需要这一层:
Lambda(lambda x: K.reshape(x, (-1,) + other_dimensions_tuple))如果你不这样做的话:
input_shape = K.shape(x)
new_batch_size = input_shape[0:1] * input_shape[1:2] #needs to keep a shape of an array
#new_batch_size.shape = (1,)
new_shape = K.concatenate([new_batch_size, input_shape[2:]]) #this is a tensor
#result of concatenating 2 tensors
r = K.reshape(x, new_shape)请注意,这在Tensorflow中有效,但在Theano中可能不起作用。
还请注意,Keras将要求模型输出的批处理大小等于模型输入的批处理大小。这意味着您需要在模型结束之前恢复原来的批处理大小。
https://stackoverflow.com/questions/58072362
复制相似问题