首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >问答首页 >如何在Tensorflow中复制操作和占位符

如何在Tensorflow中复制操作和占位符
EN

Stack Overflow用户
提问于 2016-06-11 05:34:06
回答 1查看 1.3K关注 0票数 3

假设我定义了两个神经网络模型,每个模型都有一个输入占位符和一个输出张量。从这两个输出中,我需要3个单独的值。

代码语言:javascript
运行
复制
inputs: i1, i2, outputs: o1, o2
a = 1
b = 2

v1 = session.run(o1, feed_dict={i1: a})
v2 = session.run(o1, feed_dict={i1: b})
v3 = session.run(o2, feed_dict={i2: a})

问题是,我需要将这3个值输入一个损失函数,这样我就不能这样做了。我需要做的

代码语言:javascript
运行
复制
loss = session.run(L, feed_dict={i1: a, i1: b, i2:a })

我不认为我能做到这一点,但是即使我可以,在以后的操作中我仍然会有模糊性,因为带有输入i1的i1与o1与输入i2的用法不同。

我认为可以通过在第一个神经网络中有两个输入占位符和两个输出来解决这个问题。因此,既然我已经有了一个模型,那么是否有一种方法来重组输入和输出,使我能够适应这种情况?

在视觉上我想翻个身

代码语言:javascript
运行
复制
i1 ---- (model) ----- o1 

转到

代码语言:javascript
运行
复制
i1a                          o1a
  \                         /
   \                       /
    x ----- (model) ----- x        
   /                       \
  /                         \
i1b                          o1b
EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2016-06-11 07:05:46

您的直觉是正确的,您必须为您的网络1创建两个不同的占位符i1a和i1b,其中有两个输出o1a和o1b。你的视觉效果很好,下面是我的建议:

代码语言:javascript
运行
复制
i1a  ----- (model) -----  o1a
              |            
        shared weights                                  
              |            
i1b  ----- (model) -----  o1b

正确的方法是通过对每个带有tf.get_variable()的变量使用reuse=True来复制您的网络。

代码语言:javascript
运行
复制
def create_variables():
  with tf.variable_scope('model'):
    w1 = tf.get_variable('w1', [1, 2])
    b1 = tf.get_variable('b1', [2])

def inference(input):
  with tf.variable_scope('model', reuse=True):
    w1 = tf.get_variable('w1')
    b1 = tf.get_variable('b1')
    output = tf.matmul(input, w1) + b1
  return output

create_variables()

i1a = tf.placeholder(tf.float32, [3, 1])
o1a = inference(i1a)

i1b = tf.placeholder(tf.float32, [3, 1])
o1b = inference(i1b)

loss = tf.reduce_mean(o1a - o1b)


with tf.Session() as sess:
  sess.run(tf.initialize_all_variables())
  sess.run(loss, feed_dict={i1a: [[0.], [1.], [2.]], i1b: [[0.5], [1.5], [2.5]]})
票数 3
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/37760323

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档