tensorflow学习笔记(二十九):merge_all引发的血案

merge_all引发的血案

  1. 在训练深度神经网络的时候,我们经常会使用Dropout,然而在test的时候,需要把dropout撤掉.为了应对这种问题,我们通常要建立两个模型,让他们共享变量。详情.
  2. 为了使用Tensorboard来可视化我们的数据,我们会经常使用Summary,最终都会用一个简单的merge_all函数来管理我们的Summary

错误示例

当这两种情况相遇时,bug就产生了,看代码:

import tensorflow as tf
import numpy as np
class Model(object):
    def __init__(self):
        self.graph()
        self.merged_summary = tf.summary.merge_all()# 引起血案的地方
    def graph(self):
        self.x = tf.placeholder(dtype=tf.float32,shape=[None,1])
        self.label = tf.placeholder(dtype=tf.float32, shape=[None,1])
        w = tf.get_variable("w",shape=[1,1])
        self.predict = tf.matmul(self.x,w)
        self.loss = tf.reduce_mean(tf.reduce_sum(tf.square(self.label-self.predict),axis=1))
        self.train_op = tf.train.GradientDescentOptimizer(0.01).minimize(self.loss)
        tf.summary.scalar("loss",self.loss)
def run_epoch(session, model):
    x = np.random.rand(1000).reshape(-1,1)
    label = x*3
    feed_dic = {model.x.name:x, model.label:label}
    su = session.run([model.merged_summary], feed_dic)
def main():
    with tf.Graph().as_default():
        with tf.name_scope("train"):
            with tf.variable_scope("var1",dtype=tf.float32):
                model1 = Model()
        with tf.name_scope("test"):
            with tf.variable_scope("var1",reuse=True,dtype=tf.float32):
                model2 = Model()
        with tf.Session() as sess:
            tf.global_variables_initializer().run()
            run_epoch(sess,model1)
            run_epoch(sess,model2)
if __name__ == "__main__":
    main()

运行情况是这样的: 执行run_epoch(sess,model1)时候,程序并不会报错,一旦执行到run_epoch(sess,model1),就会报错(错误信息见文章最后)。

错误原因

看代码片段:

class Model(object):
    def __init__(self):
        self.graph()
        self.merged_summary = tf.summary.merge_all()# 引起血案的地方
...
with tf.name_scope("train"):
    with tf.variable_scope("var1",dtype=tf.float32):
        model1 = Model() # 这里的merge_all只是管理了自己的summary
with tf.name_scope("test"):
    with tf.variable_scope("var1",reuse=True,dtype=tf.float32):
        model2 = Model()# 这里的merge_all管理了自己的summary和上边模型的Summary

由于Summary的计算是需要feed数据的,所以会报错。

解决方法

我们只需要替换掉merge_all就可以解决这个问题。看代码

class Model(object):
    def __init__(self,scope):
        self.graph()
        self.merged_summary = tf.summary.merge(
        tf.get_collection(tf.GraphKeys.SUMMARIES,scope)
        )
...
with tf.Graph().as_default():
    with tf.name_scope("train") as train_scope:
        with tf.variable_scope("var1",dtype=tf.float32):
            model1 = Model(train_scope)
    with tf.name_scope("test") as test_scope:
        with tf.variable_scope("var1",reuse=True,dtype=tf.float32):
            model2 = Model(test_scope)

关于tf.get_collection地址

当有多个模型时,出现类似错误,应该考虑使用的方法是不是涉及到了其他的模型

error

tensorflow.python.framework.errors_impl.InvalidArgumentError: You must feed a value for placeholder tensor ‘train/var1/Placeholder’ with dtype float [Node: train/var1/Placeholder = Placeholder[dtype=DT_FLOAT, shape=[], _device=”/job:localhost/replica:0/task:0/gpu:0”]]

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

发表于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏和蔼的张星的图像处理专栏

9.SSD目标检测之三:训练失败记录(我为什么有脸写这个……)

这个大概折腾了三四天,反正我能想到改的地方都改了,笔记本上试过了,宿舍的电脑上也试过了,反正就是不行,我也没什么办法了,后面就转向YoloV3了。尽管失败了,还...

692
来自专栏移动端周边技术扩展

coremltools安装

1443
来自专栏程序员互动联盟

【C语言练手】C语言画太极图

呵呵昨天花了一个圆,今天想画个太极图,我知道没啥技术含量,但是挺有意思的,希望各位看官不要鄙视我不务正业,画完此图,不再做这些事情。 先展示下画出来的图像的情况...

3515
来自专栏机器之心

教程 | 如何在浏览器使用synaptic.js训练简单的神经网络推荐系统

3554
来自专栏TensorFlow从0到N

TensorFlow从0到1 - 17 - Step By Step上手TensorBoard

上一篇16 L2正则化对抗“过拟合”提到,为了检测训练过程中发生的过拟合,需要记录每次迭代(甚至每次step)模型在训练集和验证集上的识别精度。其实,为了能更...

3538
来自专栏我是攻城师

如何使用opencv和matplotlib把多个图片显示在一个窗体内

在使用opencv处理一些计算机视觉方面的一些东西时,经常会遇到把多张图片放在一个窗体内对比展示,而不是同时打开多个窗体,opencv作为一个专业的科学计算库,...

692
来自专栏简书专栏

基于xgboost的波士顿房价预测kaggle实战

2018年8月24日笔记 这是作者在波士顿房价预测项目的第3篇文章,在查看此篇文章之前,请确保已经阅读前2篇文章。 第2篇文章链接:https://www....

3902
来自专栏MixLab科技+设计实验室

自己动手做一个识别手写数字的web应用02

继续上文。 自己动手做一个识别手写数字的web应用01 01 再次进入docker容器 接着上一篇文章,我们继续使用上次新建好的容器,可以终端输入 : d...

3457
来自专栏AI研习社

Github 项目推荐 | Basel Face Model 2017 完全参数化人脸

本软件可以从 Basel Face Model 2017 里生成完全参数化的人脸,论文链接: https://arxiv.org/abs/1712.01619 ...

3997
来自专栏海天一树

LDA文档主题生成模型入门

LDA(Latent Dirichlet Allocation)是一种文档主题生成模型,也称为一个三层贝叶斯概率模型,包含词、主题和文档三层结构。所谓生成模型,...

1142

扫码关注云+社区