首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >问答首页 >将批规范的is_training (TensorFlow)转换为False

将批规范的is_training (TensorFlow)转换为False
EN

Stack Overflow用户
提问于 2017-09-30 00:14:49
回答 1查看 3.1K关注 0票数 3

我想在训练后将模型的is_training状态转换为False,我怎么能做到这一点?

代码语言:javascript
复制
net = tf.layers.conv2d(inputs = features, filters = 64, kernel_size = [3, 3], strides = (2, 2), padding = 'same')
net = tf.contrib.layers.batch_norm(net, is_training = True)
net = tf.nn.relu(net)
net = tf.reshape(net, [-1, 64 * 7 * 7]) #
net = tf.layers.dense(inputs = net, units = class_num, kernel_initializer = tf.contrib.layers.xavier_initializer(), name = 'regression_output')

#......
#after training

saver = tf.train.Saver()
saver.save(sess, 'reshape_final.ckpt')
tf.train.write_graph(sess.graph.as_graph_def(), "", 'graph_final.pb')

保存后,如何将批处理规范的is_training转换为False

我试过像tensorflow batchnorm这样的关键词训练,tensorflow变化状态,但找不出如何去做。

编辑1:

由于@Maxim解决方案,它可以工作,但当我试图冻结图形时,还会出现另一个问题。

指挥:

代码语言:javascript
复制
python3 ~/.keras2/lib/python3.5/site-packages/tensorflow/python/tools/freeze_graph.py --input_graph=graph_final.pb --input_checkpoint=reshape_final.ckpt --output_graph=frozen_graph.pb --output_node_names=regression_output/BiasAdd

python3 ~/.keras2/lib/python3.5/site-packages/tensorflow/python/tools/optimize_for_inference.py --input frozen_graph.pb --output opt_graph.pb --frozen_graph True --input_names input --output_names regression_output/BiasAdd

~/Qt/3rdLibs/tensorflow/bazel-bin/tensorflow/tools/graph_transforms/transform_graph --in_graph=opt_graph.pb --out_graph=fused_graph.pb --inputs=input --outputs=regression_output/BiasAdd --transforms="fold_constants sort_by_execution_order fold_batch_norms fold_old_batch_norms"

执行transform_graph后,会弹出错误消息。

“您必须使用dtype bool为占位符张量‘训练’提供一个值。”

我通过以下代码保存图表:

代码语言:javascript
复制
sess.run(loss, feed_dict={features : train_imgs, x : real_delta, training : False})
saver = tf.train.Saver()
saver.save(sess, 'reshape_final.ckpt')
tf.train.write_graph(sess.graph.as_graph_def(), "", 'graph_final.pb')

编辑2:

将占位符更改为变量有效,但转换后的图形不能由opencv加载。

变化

代码语言:javascript
复制
training = tf.placeholder(tf.bool, name='training')

代码语言:javascript
复制
training = tf.Variable(False, name='training', trainable=False)
EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2017-09-30 11:13:55

您应该为模式定义一个placeholder变量(可以是布尔值,也可以是字符串),并在培训和测试期间将不同的值传递给session.run。样本代码:

代码语言:javascript
复制
x = tf.placeholder('float32', (None, 784), name='x')
y = tf.placeholder('float32', (None, 10), name='y')
phase = tf.placeholder(tf.bool, name='phase')
...

# training (phase = 1)
sess.run([loss, accuracy], 
         feed_dict={'x:0': mnist.train.images,
                    'y:0': mnist.train.labels,
                    'phase:0': 1})
...

# testing (phase = 0)
sess.run([loss, accuracy],
         feed_dict={'x:0': mnist.test.images,
                    'y:0': mnist.test.labels,
                    'phase:0': 0})

您可以在这个职位中找到完整的代码。

票数 3
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/46498332

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档