首页
学习
活动
专区
工具
TVP
发布
社区首页 >问答首页 >Tensorflow Estimator API:摘要

Tensorflow Estimator API:摘要
EN

Stack Overflow用户
提问于 2017-02-11 01:00:55
回答 2查看 6.8K关注 0票数 10

我无法使用Tensorflow的Estimator API来实现摘要。

Estimator类非常有用,原因有很多:我已经实现了自己的类,这些类非常相似,但我正在尝试切换到这个类。

以下是代码示例:

代码语言:javascript
复制
import tensorflow as tf
import tensorflow.contrib.layers as layers
import tensorflow.contrib.learn as learn
import numpy as np

 # To reproduce the error: docker run --rm -w /algo -v $(pwd):/algo tensorflow/tensorflow bash -c "python sample.py"

def model_fn(x, y, mode):
    logits = layers.fully_connected(x, 12, scope="dense-1")
    logits = layers.fully_connected(logits, 56, scope="dense-2")
    logits = layers.fully_connected(logits, 4, scope="dense-3")

    loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=y), name="xentropy")

    return {"predictions":logits}, loss, tf.train.AdamOptimizer(0.001).minimize(loss)


def input_fun():
    """ To be completed for a 4 classes classification problem """

    feature = tf.constant(np.random.rand(100,10))
    labels = tf.constant(np.random.random_integers(0,3, size=(100,)))

    return feature, labels

estimator = learn.Estimator(model_fn=model_fn, )

trainingConfig = tf.contrib.learn.RunConfig(save_checkpoints_secs=60)

estimator = learn.Estimator(model_fn=model_fn, model_dir="./tmp", config=trainingConfig)

# Works
estimator.fit(input_fn=input_fun, steps=2)

# The following code does not work

# Can't initialize saver

# saver = tf.train.Saver(max_to_keep=10) # Error: No variables to save

# The following fails because I am missing a saver... :(

hooks=[
        tf.train.LoggingTensorHook(["xentropy"], every_n_iter=100),
        tf.train.CheckpointSaverHook("./tmp", save_steps=1000, checkpoint_basename='model.ckpt'),
        tf.train.StepCounterHook(every_n_steps=100, output_dir="./tmp"),
        tf.train.SummarySaverHook(save_steps=100, output_dir="./tmp"),
]

estimator.fit(input_fn=input_fun, steps=2, monitors=hooks)

如您所见,我可以创建一个Estimator并使用它,但我可以实现在拟合过程中添加钩子。

日志钩子工作得很好,但是其他钩子需要张量、保存器,这是我无法提供的。

张量是在模型函数中定义的,因此我不能将它们传递给SummaryHook,也无法初始化,因为没有要保存的张量……

我的问题有解决方案吗?(我猜是的,但是在tensorflow文档中缺少这一部分的文档)

  • 如何初始化我的保护程序?或者我应该使用其他对象,如Scaffold
  • How can I pass to SummaryHook,因为它们是在我的模型函数中定义的?

提前谢谢。

PS:我已经看过DNNClassifier应用程序接口,但我想使用卷积网络和其他的估计器应用程序接口。我需要为任何估计器创建摘要。

EN

回答 2

Stack Overflow用户

回答已采纳

发布于 2017-02-11 06:03:51

预期的用例是让Estimator为您保存摘要。RunConfig中有一些用于配置摘要编写的选项。如果为constructing the Estimator,则传递RunConfigs。

票数 10
EN

Stack Overflow用户

发布于 2018-06-19 22:26:51

只需在model_fn中使用tf.summary.scalar("loss", loss),然后在没有summary_hook的情况下运行代码。损失被记录下来并显示在拉力板上。

另请参阅:

票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/42164772

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档