首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >Pytorch RuntimeError 解决办法

Pytorch RuntimeError 解决办法

作者头像
花猪
发布2023-03-01 10:47:46
1.8K0
发布2023-03-01 10:47:46
举报

问题描述

在Pytorch训练自定义数据集中发生如下错误:

RuntimeError: result type Float can't be cast to the desired output type Long

RuntimeError:结果类型 Float 无法转换为所需的输出类型 Long

loss_fn = torch.nn.BCEWithLogitsLoss(pos_weight=torch.tensor([class_weights]))

问题解决

BCEWithLogitsLoss 要求它的目标是一个float 张量,而不是long。所以应该通过dtype=torch.float32指定张量的类型。

将上述代码修改如下:

loss_fn = torch.nn.BCEWithLogitsLoss(pos_weight=torch.tensor([class_weights], dtype=torch.float32))

参考文章:Pytorch 抛出错误 RuntimeError: result type Float can’t be cast to the desired output type Long答案 - 爱码网 (likecs.com)

本文参与 腾讯云自媒体分享计划,分享自作者个人站点/博客。
原始发表:2023-02-28,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 问题描述
  • 问题解决
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档