首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >问答首页 >add_loss方法的tf.keras.model文档

add_loss方法的tf.keras.model文档
EN

Stack Overflow用户
提问于 2020-05-11 11:47:55
回答 1查看 2K关注 0票数 2

每当我在过去使用过tensorflow的keras时,我就为model.compile模型指定了损失函数。我目前正在做一个回购,它使用'add_loss‘来指定model.call中的丢失函数。或者至少,这就是我所假设的情况,因为我找不到这个方法的任何文档(即docs/python/tf/keras/Model上没有),我也找不到任何使用该方法的教程。更重要的是,我甚至不知道它是在源代码中定义的。

代码语言:javascript
运行
复制
class TRPO(Model):
    def __init__(self, obs_dim, act_dim, hid1_mult, kl_targ, init_logvar, eta, **kwargs):
        super(TRPO, self).__init__(**kwargs)
        self.kl_targ = kl_targ
        self.eta = eta
        self.beta = self.add_weight('beta', initializer='zeros', trainable=False)
        self.policy = PolicyNN(obs_dim, act_dim, hid1_mult, init_logvar)
        self.logprob = LogProb()
        self.kl_entropy = KLEntropy()

    def call(self, inputs):
        obs, act, adv, old_means, old_logvars, old_logp = inputs
        new_means, new_logvars = self.policy(obs)
        new_logp = self.logprob([act, new_means, new_logvars])
        kl, entropy = self.kl_entropy([old_means, old_logvars,
                                       new_means, new_logvars])
        loss1 = -K.mean(adv * K.exp(new_logp - old_logp))
        loss2 = K.mean(self.beta * kl)
        # TODO - Take mean before or after hinge loss?
        loss3 = self.eta * K.square(K.maximum(0.0, K.mean(kl) - 2.0 * self.kl_targ))
        self.add_loss(loss1 + loss2 + loss3)

        return [kl, entropy]

任何人都有使用add_loss的经验,并能指出它是如何工作的?并解释为什么您不只是编写自己的损失函数并将其传递到model.compile中?

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2020-06-10 08:54:07

您可以找到add_loss 这里的正式文档。添加损耗张量,可能依赖于层输入。这种方法可以在子类层或模型的调用函数中使用,在这种情况下,损失应该是张量或张量列表。文档中几乎没有解释add_loss的例子。

您可以在add_losstf.keras.layers.Layer中找到源代码。这是所有层继承的类。点击“查看GitHub上的源代码”并搜索add_loss

来问你的问题-

任何人都有使用add_loss的经验,并能指出它是如何工作的?

您可以使用add_loss、这里这里找到一个很好的例子,并给出解释。

并解释为什么您不只是编写自己的损失函数并将其传递到model.compile中?

在Tensorflow中,model.compile()损耗函数通常包含两个参数y_truey_pred。使用model.add_loss()没有这样的限制,并且允许您编写依赖于许多其他张量的更复杂的损失,但是它有更多依赖于模型的不便,而标准损失函数只适用于任何模型。

希望这能回答你的问题。学习愉快。

票数 4
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/61729000

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档