假设我有一个Custom Layer
:
class Custom_Layer(keras.layers.Layer):
def __init__(self, **kwargs):
self.w_0 = tf.Variable(tf.random_uniform_initializer(),trainable=True)
self.b_0 = tf.Variable(tf.zeros_initializer(),trainable=True)
....
def call(self, inputs):
output = A_Method(self, inputs)
return output
def A_Method(self, TensorA):
....
return something
如果我想将@tf.function(with input_signature)
装饰为A_Method
以控制跟踪
@tf.function(input_signature=[???, tf.TensorSpec(shape=None)])
def A_Method(self, TensorA):
....
return something
我应该为self
制定什么规范?我试着把tf.TensorSpec
放上去,但是它引起了一个错误
___Updated the question___ :
我对tensorflow非常陌生,抱歉,如果代码是奇怪的或没有意义的。我这么做的原因是我发现RNN花了很长时间才开始第一个阶段,我不知道这个自定义层是否可以做一些类似的事情,但花费的时间更少。但最终我相信初始化时间的缓慢是因为tensorflow retracing repeatedly even on same input_spec - input_shape
。我反复使用这一层,
input_layer = Input(shape=( X_.shape[1],X_.shape[2]), name='input')
for loop :
Hard_Code_RNN_Layer(input_layer[:,:, slicing])
然后我运行了.experimental_get_tracing_count()
计数是300,这实际上不应该超过10,这就是为什么我想把这个方法从def Mimic_RNN(self, step_input, step_state)
remove it from the class
中取出来,并尝试给它一个input_signature。请见下文:
def Initialize_Variable(input_dim, units):
w_init = tf.random_normal_initializer()
b_init = tf.zeros_initializer()
w_0 = tf.Variable(initial_value=w_init(shape=(input_dim, units)))
b_0 = tf.Variable(initial_value=b_init(shape=(units)))
return w_0, b_0
def Initialize_One_Variable(input_dim, units):
w_init = tf.random_uniform_initializer()
R_kernal = tf.Variable(initial_value=w_init(shape=(input_dim, units)))
return R_kernal
class Hard_Code_RNN_Layer(keras.layers.Layer):
def __init__(self, input_tuple, Sequencee=True, **kwargs):
super(Hard_Code_RNN_Layer, self).__init__()
input_shape, units = input_tuple
self.Hidden_Size = (int)(input_shape * 0.85)
self.inputshape = input_shape
self.units = units
self.thiseq = Sequencee
self.Uz = Initialize_One_Variable(self.Hidden_Size, self.Hidden_Size)
self.Ur = Initialize_One_Variable(self.Hidden_Size, self.Hidden_Size)
self.w_hz, self.b_hz = Initialize_Variable(self.units, self.Hidden_Size)
self.w_out, self.b_out = Initialize_Variable(self.Hidden_Size,self.units)
self.w_0, self.b_0 = Initialize_Variable(self.inputshape,self.units)
def get_config(self):
cfg = super().get_config()
return cfg
def Layer_Method(inputs, w_h, b_h):
return tf.matmul(inputs, w_h) + b_h
def Mimic_RNN(self, step_input, step_state): <-----------input_signature_this
x__j = self.Layer_Method(step_input, self.w_0, self.b_0)
r = tf.sigmoid(tf.matmul(step_state, self.Ur))
z = tf.sigmoid(tf.matmul(step_state, self.Uz))
h__ = tf.nn.relu(tf.matmul(x__j, self.w_hz) + tf.multiply(r, step_state) + self.b_hz)
h = (1-z) * h__ + z * step_state
output__ = tf.nn.relu(tf.matmul(h, self.w_out) + self.b_out)
return output__, h
def call(self, inputs):
unstack = tf.unstack(inputs, axis=1)
out1, hiddd = self.Mimic_RNN(step_input=unstack[0], step_state=tf.zeros_like(unstack[0][:,0:self.Hidden_Size]))
out2, hiddd = self.Mimic_RNN(step_input=unstack[1], step_state=hiddd)
out3, hiddd = self.Mimic_RNN(step_input=unstack[2], step_state=hiddd)
if(self.thiseq):
return tf.stack([out1, out2, out3], axis =1 )
else:
return out3
发布于 2021-05-03 07:10:09
如果指定了输入签名,那么python函数的所有输入都必须转换为Tensor
。在这种情况下,self
保存对调用方法的实例的引用,并且不能作为张量进行转换。您只是不能在您的input_signature
A_method
函数中指定.。
但是,仍然可以从类中修饰方法,因为TensorFlow将检测要修饰的函数是否是方法,如果是这样的话,将自动删除self
参数。您可以检查源代码
if self._is_method:
# Remove `self`: default arguments shouldn't be matched to it.
# TODO(b/127938157): Should this error out if there is no arg to
# be removed?
args = fullargspec.args[1:]
值得注意的是,如果在类之外定义了一个方法,则此检查将失败。(检查依赖于标准库ismethod
模块的inspect
函数)。由于self
不能转换为张量,因此修饰方法在调用时会抛出一个错误。
在类定义之外定义一个方法并不是最佳实践:它使代码更难阅读,更难使用。有关更多细节,您可以查看这个问题:在类定义之外定义一个方法?。类之间重用逻辑的python方法要么使用继承,要么定义一个不依赖于对象属性的函数(或者将这些属性作为参数传递给函数的位置)。
发布于 2022-05-05 09:13:42
实际上,可以将与类方法一起使用input_signature
。在输入签名规范中,只需忽略初始的self
参数,因此您只需为其他参数提供tf.TensorSpec
。
例如:
import tensorflow as tf
class MyClass:
@tf.function(input_signature=(tf.TensorSpec([None], tf.float32),
tf.TensorSpec([None], tf.float32)))
def my_method(self, a, b):
return a + b
tf.print(MyClass().my_method([1, 2, 3], [4]))
# 5, 6, 7
https://stackoverflow.com/questions/67368816
复制