首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >问答首页 >InvalidArgumentError:输出形状的内部尺寸必须与更新形状的内部尺寸相匹配

InvalidArgumentError:输出形状的内部尺寸必须与更新形状的内部尺寸相匹配
EN

Stack Overflow用户
提问于 2021-05-22 18:22:59
回答 1查看 741关注 0票数 1

我正试图在keras中实现SPL损失。我所需要做的就是非常简单,我将用numpy来解释我需要什么:

代码语言:javascript
运行
复制
def spl_loss(y_true, y_pred, lmda):
    # compute any arbitrary loss function
    L = categorical_cross_entropy(y_true, y_pred)
    # set to zero those values with an error greater than lambda
    L[L>lmda] = 0
    return L

我正在尝试实现它遵循本教程,但是我在将值设置为零所需的步骤上遇到了麻烦。

目前,我有以下代码:

代码语言:javascript
运行
复制
def spl_loss(lmda, loss_fn):
    def loss(y_true, y_pred):
         # compute an arbitrary loss function, L
        loss_value = loss_fn(y_true, y_pred) # tensor of shape (64,)
        # get the mask of L greater than lmda
        mask = tf.greater( loss_value, tf.constant( float(lmda) ) )    # tensor of shape (64,)
        # compute indexes for the mask
        indexes = tf.reshape(tf.where(mask), [-1])  # tensor of shape (n,); where n<=64
        # set to zero values on indexes
        spl_loss_value = tf.tensor_scatter_nd_update(loss_value, indexes, tf.zeros_like(loss_value, dtype=loss_value.dtype) )  # this line gives the error
        
        return spl_loss_value
    return loss

根据文档tensor_scatter_nd_update操作应该执行赋值操作,但如果出现以下错误,则会失败:

代码语言:javascript
运行
复制
    spl_loss_value = tf.tensor_scatter_nd_update(loss_value, indexes, tf.zeros_like(loss_value, dtype=loss_value.dtype) )
    /usr/local/lib/python3.7/dist-packages/tensorflow/python/util/dispatch.py:201 wrapper  **
        return target(*args, **kwargs)
    /usr/local/lib/python3.7/dist-packages/tensorflow/python/ops/array_ops.py:5512 tensor_scatter_nd_update
        tensor=tensor, indices=indices, updates=updates, name=name)
    /usr/local/lib/python3.7/dist-packages/tensorflow/python/ops/gen_array_ops.py:11236 tensor_scatter_update
        _ops.raise_from_not_ok_status(e, name)
    /usr/local/lib/python3.7/dist-packages/tensorflow/python/framework/ops.py:6862 raise_from_not_ok_status
        six.raise_from(core._status_to_exception(e.code, message), None)
    <string>:3 raise_from
        

    InvalidArgumentError: Inner dimensions of output shape must match inner dimensions of updates shape. Output: [64] updates: [64] [Op:TensorScatterUpdate]

我正在colab运行,这里你可以试试。

我尝试了几次重新塑造,因为我知道这是一个期待的形状与获得的问题,但我找不到方法。这里发生了什么事?

提前感谢

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2021-05-22 21:00:28

出现此错误的原因是,indices in tf.tensor_scatter_nd_update至少需要两轴,或者tf.rank(indices) > = 2需要完全填充。indices2D中(在标量更新中)的原因是保存两个信息,一个是更新 (num_updates)的长度和索引向量长度。有关此问题的详细概述,请参阅以下有关此问题的答案:最新消息?

以下是SPL丢失在tensorflow中的正确实现。

代码语言:javascript
运行
复制
def spl_loss(lmda):
    def loss(y_true, y_pred):
         # compute an arbitrary loss function, L
        loss_value = keras.losses.sparse_categorical_crossentropy(y_true, y_pred)

        # get the mask of L greater than lmda
        mask = tf.greater( loss_value, tf.constant(float(lmda) ) )    

        # compute indexes for the mask
        indexes = tf.where(mask) # tensor of shape (n,); where n<=64
        updates = tf.reshape(tf.zeros_like(indexes, dtype=tf.float32), [-1])

        # scaler update check
        num_updates, index_depth = indexes.shape.as_list()
        assert updates.shape == [num_updates]
        assert index_depth == tf.rank(loss_value)


        # print()
        # print('A', tf.reshape(tf.where(mask), [-1])[:10].numpy()) 
        # print('B', tf.where(mask).numpy()[:10]) 
        # print('Ranks: ', tf.rank(loss_value).numpy(), 
        #                  tf.rank(indices).numpy(), 
        #                   tf.rank(updates).numpy())
        # print('Shape: ', loss_value.shape, indexes.shape, updates.shape)

        # set to zero values on indexes
        spl_loss_value = tf.tensor_scatter_nd_update(loss_value, indexes, updates )

        return spl_loss_value
    return loss
代码语言:javascript
运行
复制
...
model.compile(optimizer="adam", loss=spl_loss(lmda=2.), run_eagerly=True)
...

参考文献:更新

票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/67652872

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档