首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >问答首页 >我的自定义Keras图层中的add_update()不会更新权重

我的自定义Keras图层中的add_update()不会更新权重
EN

Stack Overflow用户
提问于 2020-04-07 10:53:09
回答 2查看 683关注 0票数 2

所以我实现了中心损失:https://ydwen.github.io/papers/WenECCV16.pdf,我在更新我的层中的权重时遇到了问题,这意味着在中心损失中更新中心。当我像这样打印我的class_centers时,它们永远不会改变。当我打印其他变量时,它们看起来没有问题,所以我能想到的唯一问题是add_update()没有做它应该做的事情。

自定义层:

代码语言:javascript
运行
复制
class CenterLossLayer(Layer):
    def __init__(self, alpha=0.5, **kwargs):
        self.alpha = alpha  
        super(CenterLossLayer, self).__init__(**kwargs)

    def build(self, input_shape):
        print('Center loss input 1 (feature_size): ', input_shape[0][1])
        print('Center loss input 2 (num_classes): ', input_shape[1][1])
        self.class_centers = self.add_weight(name='class_centers',
                                       shape=(input_shape[1][1], input_shape[0][1]),
                                       initializer='uniform',
                                       trainable=False)
        super(CenterLossLayer, self).build(input_shape)

    def call(self, x, mask=None):
        embeddings, one_hots = x
        tf.print(self.class_centers, summarize=-1, output_stream='file:///tensors.txt')

        batch_centers = K.dot(one_hots, self.class_centers)
        batch_delta = batch_centers - embeddings

        class_delta = K.dot(K.transpose(one_hots), batch_delta)
        counts = K.sum(K.transpose(one_hots), axis=1, keepdims=True) + 1
        class_delta = class_delta / counts
        class_delta = K.in_train_phase(self.alpha * class_delta, 0 * class_delta)

        updated_class_centers = self.class_centers - class_delta
        self.add_update((self.class_centers, updated_class_centers), x[0])

        losses = K.sum(K.square(embeddings - batch_centers), axis=1, keepdims=True)

        return losses

    def compute_output_shape(self, input_shape):
        return (input_shape[1][0], )

最后的损失是:

代码语言:javascript
运行
复制
def batch_mean_loss(y_true, y_pred):
    return K.mean(y_pred, axis=0)

其中y_pred是来自CenterLossLayer的losses

奇怪的是,即使中心没有更新,中心损失也会随着每个时代的下降而下降,最终的模型比只使用Softmax损失训练的模型更好。

EN

回答 2

Stack Overflow用户

回答已采纳

发布于 2020-04-09 03:34:03

所以我检查了add_update()是如何在BatchNormalization层中使用的:

代码语言:javascript
运行
复制
self.add_update([K.moving_average_update(self.moving_mean,
                                                 mean,
                                                 self.momentum),
                         K.moving_average_update(self.moving_variance,
                                                 variance,
                                                 self.momentum)],
                        inputs)

问题是,方法add_update()的第一个参数是"updates: Update op“,并且moving_average_update()返回”更新变量的操作。“所以我猜add_update()需要某种操作,而moving_average_update()会返回这些操作。我不知道如何创建这个操作,所以我做了:

代码语言:javascript
运行
复制
self.add_update(K.moving_average_update(self.class_centers, updated_class_centers, 0.0), x)

因此,它的功能就像用updated_class_centers替换self.class_centers一样,而且它是有效的。

即使它工作正常,如果有人知道如何正确地做这件事,我将不胜感激。

票数 1
EN

Stack Overflow用户

发布于 2020-12-09 19:01:43

看起来你应该这样做:

代码语言:javascript
运行
复制
class ComputeSum(keras.layers.Layer):
  def __init__(self, input_dim):
    super(ComputeSum, self).__init__()
    self.total = tf.Variable(initial_value=tf.zeros((input_dim,)), trainable=False)

  def call(self, inputs):
    self.total.assign_add(tf.reduce_sum(inputs, axis=0))
    return self.total

https://keras.io/guides/making_new_layers_and_models_via_subclassing/#layers-can-have-nontrainable-weighto获取的代码片段

票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/61072079

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档