tensorflow学习笔记(二十九):merge_all引发的血案

merge_all引发的血案

  1. 在训练深度神经网络的时候,我们经常会使用Dropout,然而在test的时候,需要把dropout撤掉.为了应对这种问题,我们通常要建立两个模型,让他们共享变量。详情.
  2. 为了使用Tensorboard来可视化我们的数据,我们会经常使用Summary,最终都会用一个简单的merge_all函数来管理我们的Summary

错误示例

当这两种情况相遇时,bug就产生了,看代码:

import tensorflow as tf
import numpy as np
class Model(object):
    def __init__(self):
        self.graph()
        self.merged_summary = tf.summary.merge_all()# 引起血案的地方
    def graph(self):
        self.x = tf.placeholder(dtype=tf.float32,shape=[None,1])
        self.label = tf.placeholder(dtype=tf.float32, shape=[None,1])
        w = tf.get_variable("w",shape=[1,1])
        self.predict = tf.matmul(self.x,w)
        self.loss = tf.reduce_mean(tf.reduce_sum(tf.square(self.label-self.predict),axis=1))
        self.train_op = tf.train.GradientDescentOptimizer(0.01).minimize(self.loss)
        tf.summary.scalar("loss",self.loss)
def run_epoch(session, model):
    x = np.random.rand(1000).reshape(-1,1)
    label = x*3
    feed_dic = {model.x.name:x, model.label:label}
    su = session.run([model.merged_summary], feed_dic)
def main():
    with tf.Graph().as_default():
        with tf.name_scope("train"):
            with tf.variable_scope("var1",dtype=tf.float32):
                model1 = Model()
        with tf.name_scope("test"):
            with tf.variable_scope("var1",reuse=True,dtype=tf.float32):
                model2 = Model()
        with tf.Session() as sess:
            tf.global_variables_initializer().run()
            run_epoch(sess,model1)
            run_epoch(sess,model2)
if __name__ == "__main__":
    main()

运行情况是这样的: 执行run_epoch(sess,model1)时候,程序并不会报错,一旦执行到run_epoch(sess,model1),就会报错(错误信息见文章最后)。

错误原因

看代码片段:

class Model(object):
    def __init__(self):
        self.graph()
        self.merged_summary = tf.summary.merge_all()# 引起血案的地方
...
with tf.name_scope("train"):
    with tf.variable_scope("var1",dtype=tf.float32):
        model1 = Model() # 这里的merge_all只是管理了自己的summary
with tf.name_scope("test"):
    with tf.variable_scope("var1",reuse=True,dtype=tf.float32):
        model2 = Model()# 这里的merge_all管理了自己的summary和上边模型的Summary

由于Summary的计算是需要feed数据的,所以会报错。

解决方法

我们只需要替换掉merge_all就可以解决这个问题。看代码

class Model(object):
    def __init__(self,scope):
        self.graph()
        self.merged_summary = tf.summary.merge(
        tf.get_collection(tf.GraphKeys.SUMMARIES,scope)
        )
...
with tf.Graph().as_default():
    with tf.name_scope("train") as train_scope:
        with tf.variable_scope("var1",dtype=tf.float32):
            model1 = Model(train_scope)
    with tf.name_scope("test") as test_scope:
        with tf.variable_scope("var1",reuse=True,dtype=tf.float32):
            model2 = Model(test_scope)

关于tf.get_collection地址

当有多个模型时,出现类似错误,应该考虑使用的方法是不是涉及到了其他的模型

error

tensorflow.python.framework.errors_impl.InvalidArgumentError: You must feed a value for placeholder tensor ‘train/var1/Placeholder’ with dtype float [Node: train/var1/Placeholder = Placeholder[dtype=DT_FLOAT, shape=[], _device=”/job:localhost/replica:0/task:0/gpu:0”]]

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

发表于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏恰同学骚年

设计模式的征途—8.桥接(Bridge)模式

在现实生活中,我们常常会用到两种或多种类型的笔,比如毛笔和蜡笔。假设我们需要大、中、小三种类型的画笔来绘制12中不同的颜色,如果我们使用蜡笔,需要准备3*12=...

893
来自专栏Flutter&Dart

DartVM服务器开发(第十四天)--Jaguar_ORM增删查改

正确做法,是先通过bean.find(primaryKey)查询该数据是否已经存在,然后再进行添加

641
来自专栏软件开发 -- 分享 互助 成长

组合模式

一、简介 1、组合模式将对象组合成树形结构以表示‘部分和整体’的层次结构。组合模式使得用户对单个对象和组合对象的使用具有一致性。 2、模式中的几个重要的类 Co...

1807
来自专栏wOw的Android小站

[设计模式]之十:组合模式

将对象组合成树形结构以表示“部分-整体”的层次结构。组合模式使得用户对单个对象和组合对象的使用具有一致性。

161
来自专栏社区的朋友们

ServerFrame::HashMap VS stl::unordered_map-性能探究之旅

突然就对项目中的 HashMap 有了强烈的好奇心,这个 HashMap 的实现够高效吗,和 std::unordered_map 的效率比较性能如何? 他们的...

2290
来自专栏别先生

mysql的时间戳timestamp精确到小数点后六位

公司业务使用到Greenplun数据库,根据查询的时间戳来不断的将每个时间段之间的数据,进行数据交换,但是今天发现,mysql的时间戳没有小数点后6位,即精确度...

531
来自专栏IMWeb前端团队

Nodejs进阶:如何将图片转成datauri嵌入到网页中去

本文作者:IMWeb 陈映平 原文出处:IMWeb社区 未经同意,禁止转载 问题:将图片转成datauri 今天,在QQ群有个群友问了个问题:“nod...

1748
来自专栏about云

Apache Spark 2.2中基于成本的优化器(CBO)

问题导读 1.什么是CBO,RBO? 2.什么是执行计划? 3.什么是join,filter? 4.事实表和维度表的区别? Apache Spark 2.2最近...

3647
来自专栏java 成神之路

java.util.Random 实现原理

2955
来自专栏王二麻子IT技术交流园地

八、VueJs 填坑日记之参数传递及内容页面的开发

我们在上一篇博文中,渲染出来了一个列表,并在列表中使用了router-link标签,标签内的:to就是链接地址,昨天咱们是<router-link :to="'...

1857

扫码关注云+社区