首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >问答首页 >pytorch中的加权mse损失

pytorch中的加权mse损失
EN

Stack Overflow用户
提问于 2019-07-12 17:49:46
回答 1查看 2.4K关注 0票数 0
代码语言:javascript
运行
复制
def weighted_mse_loss(input_tensor, target_tensor, weight = 1):
    observation_dim = input_tensor.size()[-1]
    streched_tensor = ((input_tensor - target_tensor) ** 2).view(-1, observation_dim)
    entry_num = float(streched_tensor.size())[0]
    non_zero_entry_num = torch.sum(streched_tensor[:,0] != 0).float()
    weighted_tensor = torch.mm(
        ((input_tensor - target_tensor)**2).view(-1, observation_dim),
        (torch.diag(weight.float().view(-1)))
    )
    return torch.mean(weighted_tensor) * weight.nelement() * entry_num / non_zero_entry_num

我不明白代码是如何给出加权均方误差损失的。我知道observation_dim是最终的输出维度(我猜是类号),但在那一行之后,我就不明白了。有人能帮我弄清楚代码是如何计算损失的吗?

非常感谢。

EN

回答 1

Stack Overflow用户

发布于 2020-08-11 11:09:56

代码语言:javascript
运行
复制
    def weighted_mse_loss(input, target, weight):
        return (weight * (input - target) ** 2).mean()

试试这个,希望这个能有所帮助。所有的参数都需要张量。

票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/57004498

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档