我正在写一个RESNET,但是我不能理解在哪里使用"call“函数。
也许这是由TensorFlow自动调用的,所以这意味着我们必须编写一个名为"call“的函数?如果是这样的话,这个"call“函数的确切要求是什么?谢谢你!!
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, Sequential
class BasicBlock(layers.Layer):
def __init__(self, filter_num, strides=1):
super(BasicBlock, self).__init__()
self.conv1 = layers.Conv2D(filter_num, (3, 3), strides=strides, padding="same")
self.bn1 = layers.BatchNormalization()
self.relu = layers.Activation('relu')
self.conv2 = layers.Conv2D(filter_num, (3, 3), strides=1, padding="same")
self.bn2 = layers.BatchNormalization()
if strides != 1:
self.downsample = layers.Conv2D(filter_num, (1, 1), strides=strides)
else:
self.downsample = lambda x:x
def call(self, inputs, training=None):
out = self.conv1(inputs)
out = self.bn1(out, training=training)
out = self.relu(out)
out = self.conv2(out, training=training)
out = self.bn2(out)
identity = self.downsample(inputs)
output = layers.add([out, identity])
output = tf.nn.relu(output)
return output
发布于 2021-07-22 08:25:15
调用函数的使用方式如下:
basic_block = BasicBlock()
basic_block(args)
因此,它不是来自于:
basic_block.call(args)
发布于 2021-07-22 15:16:50
定义自定义层时,您将扩展基类tensorflow.keras.layers.Layer
并按如下方式使用它:
import tensorflow as tf
class BasicBlock(tf.keras.layers.Layer):
...
basic_block = BasicBlock()
basic_block(inputs)
上面代码片段的最后一行将调用类中的魔术方法__call__
(如果你对A Guide to Python's Magic Methods感兴趣,可以在这里了解更多关于魔术方法的信息)
由于您没有在BasicBlock
中定义__call__
方法(您定义了不同的call
),因此将使用tensorflow.keras.layers.Layer
中的__call__
。
根据Tensorflow documentation,此方法包含以下文档
包装调用,应用预处理和后处理步骤。
粗略地说,您将拥有(如果您感兴趣,可以查看source code,但它要复杂得多):
class Layer(...):
....
def __call__(self, ...):
# preprocessing steps
self.call(...)
# post processing steps
如果您熟悉继承,那么在使用basic_block(inputs)
时应该猜到不同的步骤
BasicBlock
是否有名为__call__
=> NoLayer
是否有名为__call__
=>的方法是,使用它并进入此方法<代码>H124应用预处理步骤<代码>H225<代码>H126检查<代码>D27是否有名为<代码>D28 =>的方法是,使用它并将其应用于输入<代码>H229<代码>H130应用后处理步骤<代码>H231<代码>G232
关于实现最佳资源的call
方法的要求,可以在official Tensorflow documentation中获得关于输入数据结构、预期和关键字参数的所有解释
https://stackoverflow.com/questions/68481424
复制