首页
学习
活动
专区
工具
TVP
发布
社区首页 >问答首页 >Tensorflow:将值输入到现有占位符或将其删除

Tensorflow:将值输入到现有占位符或将其删除
EN

Stack Overflow用户
提问于 2018-06-03 19:26:58
回答 1查看 579关注 0票数 1

我正在实现一个自定义的keras回调,并且在同一模型上执行两个连续的训练阶段。

在回调中,我创建了几个占位符,以便在训练结束时为评估提供一些度量值。对于第一个训练阶段,这是很好的,因为占位符不存在,然而,在第二个训练阶段,这将导致错误,因为tensorflow将创建第二组占位符,但具有索引名称。

因此,我正在寻找一种解决方案,要么将第一个训练阶段的值输入到占位符中(可能类似于按名称查找占位符,然后将值输入其中),要么按名称删除某些占位符,以便我可以创建新的占位符

编辑:

来澄清我目前的处境。我实现了这个自定义的Keras回调函数(我将省略度量的计算):

代码语言:javascript
复制
class Metric(keras.callbacks.Callback):


def __init__(self):
    self.val_prec_ph = tf.placeholder(shape=(), dtype=tf.float64, name="prec")
    tf.summary.scalar("val_precision", self.val_prec_ph)

    self.merged = tf.summary.merge_all()
    self.writer = tf.summary.FileWriter(self.log_dir)

def on_train_begin(self, logs={}):

    self.precision = []

def on_train_end(self, logs={}):

   //do some calculations

    self.precision.append(calculation)

    summary = self.session.run(self.merged,
                               feed_dict={self.val_prec_ph: self.precision[-1]})

    self.writer.add_summary(summary)
    self.writer.flush()

这基本上就是我用来做占位符的框架。由于连续运行,tensorflow将执行以下操作:第一次训练将不会出现问题,并将占位符命名为"prec“。然而,在第二次运行中,tensorflow会将self.val_prec_ph占位符命名为类似于"prec_“的名称,这将导致错误,即"prec”占位符尚未被馈送,尽管它仍然在那里。

因此,我要么直接写入"prec“占位符,要么在第一次运行后将其删除,这样就不会有重复的内容。

这就是为什么我在训练结束的时候这样做的原因。是另一个故事,它有另一个问题。

EN

回答 1

Stack Overflow用户

发布于 2018-06-03 21:08:22

以下是您特定问题的可能解决方案,按名称在图形中搜索占位符(使用tf.Graph().get_tensor_by_name()),如果找不到则创建占位符:

代码语言:javascript
复制
class Metric(keras.callbacks.Callback):

    def __init__(self, ph_name="prec"):
        try:
            self.val_prec_ph = tf.get_default_graph().get_tensor_by_name(
                ph_name + ':0')
            # Check this solution by @rvinas to cover possible suffix/scope errors:
            # https://stackoverflow.com/a/38935343/624547
        except KeyError:
            self.val_prec_ph = tf.placeholder(shape=(), dtype=tf.float64, 
                                              name=ph_name)

        tf.summary.scalar("val_precision", self.val_prec_ph)

        self.merged = tf.summary.merge_all()
        self.writer = tf.summary.FileWriter(self.log_dir)

    # ...
票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/50665873

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档