首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >问答首页 >Tensorflow's tf.keras.layers.Dense和PyTorch's torch.nn.Linear的区别?

Tensorflow's tf.keras.layers.Dense和PyTorch's torch.nn.Linear的区别?
EN

Stack Overflow用户
提问于 2021-03-14 16:11:59
回答 2查看 19K关注 0票数 11

关于Tensorflow是如何定义其线性层的,我有一个快速的(也可能是愚蠢的)问题。在PyTorch中,线性(或稠密)层定义为,y=x^T+b,其中A和b是线性层的权重矩阵和偏置向量(见这里)。

然而,我无法精确地为Tensorflow找到一个等价的方程!它和PyTorch是一样的还是只是y=x+b?

提前谢谢你!

EN

回答 2

Stack Overflow用户

回答已采纳

发布于 2021-03-14 16:28:06

tf.keras.layers.Dense是在tensorflow源代码中定义的:

https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/keras/layers/core.py#L1081

如果您遵循它的call函数中的引用,它将引出这里使用的操作的定义,它实际上是输入和权重的矩阵乘法加上预期的偏置向量:

https://github.com/tensorflow/tensorflow/blob/a68c6117a1a53431e739752bd2ab8654dbe2534a/tensorflow/python/keras/layers/ops/core.py#L74

代码语言:javascript
运行
复制
outputs = gen_math_ops.MatMul(a=inputs, b=kernel)
...
outputs = nn_ops.bias_add(outputs, bias)
票数 6
EN

Stack Overflow用户

发布于 2021-03-14 16:43:22

如果我们在None API中的密集层中设置keras激活,那么它们在技术上是等价的。

丹索尔·弗洛

代码语言:javascript
运行
复制
tf.keras.layers.Dense(..., activation=None) 

根据文档,更多的研究这里

激活:要使用的激活功能。如果不指定任何内容,则不应用任何激活(即。“线性”激活: a(x) =x。

在PyTorch的src里。

代码语言:javascript
运行
复制
torch.nn.Linear

现在他们在这一点上是平等的。对传入数据的线性转换:y = x*W^T + b。请参阅下面这两种方法的更具体的等效实现。在PyTorch中,我们有

代码语言:javascript
运行
复制
class Network(torch.nn.Module):
    def __init__(self):
        super(Network, self).__init__()
        self.fc1 = torch.nn.Linear(5, 30)
    def forward(self, state):
        return self.fc1(state)

或,

代码语言:javascript
运行
复制
trd = torch.nn.Linear(in_features = 3, out_features = 30)
y = trd(torch.ones(5, 3))
print(y.size())
# torch.Size([5, 30])

它的等效tf实现应该是

代码语言:javascript
运行
复制
model = tf.keras.models.Sequential()
model.add(tf.keras.layers.Dense(30, input_shape=(5,), activation=None)) 

或,

代码语言:javascript
运行
复制
tfd = tf.keras.layers.Dense(30, input_shape=(3,), activation=None)
x = tfd(tf.ones(shape=(5, 3)))
print(x.shape)
# (5, 30)
票数 15
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/66626700

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档