首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >问答首页 >对于神经网络函数,tf.GradientTape()不返回任何值

对于神经网络函数,tf.GradientTape()不返回任何值
EN

Stack Overflow用户
提问于 2021-07-25 00:10:10
回答 2查看 65关注 0票数 1

所以我创建了自己的神经网络,我想对它做一个关于输入变量的自动微分。我的神经网络代码是这样的

代码语言:javascript
运行
复制
n_input = 1     
n_hidden_1 = 50 
n_hidden_2 = 50 
n_output = 1 

weights = {
'h1': tf.Variable(tf.random.normal([n_input, n_hidden_1],0,0.5)),
'h2': tf.Variable(tf.random.normal([n_hidden_1, n_hidden_2],0,0.5)),
'out': tf.Variable(tf.random.normal([n_hidden_2, n_output],0,0.5))
}

biases = {
'b1': tf.Variable(tf.random.normal([n_hidden_1],0,0.5)),
'b2': tf.Variable(tf.random.normal([n_hidden_2],0,0.5)),
'out': tf.Variable(tf.random.normal([n_output],0,0.5))
}

def multilayer_perceptron(x):
    x = np.array([[[x]]],  dtype='float32')
    layer_1 = tf.add(tf.matmul(x, weights['h1']), biases['b1'])
    layer_1 = tf.nn.tanh(layer_1)
    layer_2 = tf.add(tf.matmul(layer_1, weights['h2']), biases['b2'])
    layer_2 = tf.nn.tanh(layer_2)
    output = tf.matmul(layer_2, weights['out']) + biases['out']
    return output

tf.GradientTape()中,我尝试用这个来区分神经网络

代码语言:javascript
运行
复制
x = tf.Variable(1.0)
with tf.GradientTape() as tape:
    y = multilayer_perceptron(x)
dNN1 = tape.gradient(y,x)
print(dNN1)

这就导致了None。我在这里做错了什么?

EN

回答 2

Stack Overflow用户

回答已采纳

发布于 2021-07-25 00:44:20

因为您通过np.arrayx转换为numpy数组,这是不可微的。

像这样修改你的代码:

代码语言:javascript
运行
复制
def multilayer_perceptron(x):
    #x = np.array([[[x]]],  dtype='float32') #comment
    layer_1 = tf.add(tf.matmul([[[x]]], weights['h1']), biases['b1']) #change x shape by adding []
    layer_1 = tf.nn.tanh(layer_1)
    layer_2 = tf.add(tf.matmul(layer_1, weights['h2']), biases['b2'])
    layer_2 = tf.nn.tanh(layer_2)
    output = tf.matmul(layer_2, weights['out']) + biases['out']
    return output
票数 1
EN

Stack Overflow用户

发布于 2021-07-25 01:12:00

为了更好地运行一些tensorflow操作,最好操作的所有元素都是tf.tensor类型,您必须使用

代码语言:javascript
运行
复制
def multilayer_perceptron(x):
 x =  tf.reshape(x , (1,1,1))
票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/68511708

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档