pytorch: 常见bug与debug

本博文用来记录自己的 pytorch 踩坑过程,虽然 pytorch 命令式编程,声称容易 debug,可是 代码给出的错误提示可是令人相当头疼,所以在本文中记录一下在编写 pytorch 代码过程中碰到的 坑,和如何 填坑。

  • TypeError: ××× received an invalid combination of arguments 如果检查过了数据类型的正确性之后(float32, int) 。下一步要关心的就是 op 操作的两个 Variable/Tensor 是不是在同一个 设备上 ,如果一个在 cpu 上,一个在 gpu 上就有可能会报错
  • 注意 op 的参数要求,有些是 要求 Tensor 有些 是要求 Variable ,有些是 都可以。
  • 当需要 求梯度时,一个 op 的两个输入都必须是要 Variable:
# 这段代码,如果 requires_grad=False,  a 是 Tensor,则是没错的
# 但是 requires_grad=True, a 是 Tensor,则会报错
# 这时的报错信息是 
# save_for_backward can only save input or output tensors, but argument 0 doesn't satisfy this condition
# requires_grad=True,  a是 Variable, 则不会报错
import torch
from torch.autograd import Variable

v1 = Variable(torch.FloatTensor([1., 2., 3.]), requires_grad=True)
a = Variable(torch.FloatTensor([1., 0., 0.]).type(new_type=torch.ByteTensor))

res = torch.masked_select(v1, a)
res = 3 * res

res.backward(torch.FloatTensor([1.]))
print(v1.grad)
  • 卷积层 -> 全连接层,中间一定要 view 一下,否则会shape不匹配

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

发表于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏SimpleAI

令人困惑的TensorFlow【1】

我叫 Jacob,是 Google AI Resident 项目的研究学者。我是在 2017 年夏天加入该项目的,尽管已经拥有了丰富的编程经验,并且对机器学习的...

782
来自专栏数据结构与算法

01:谁考了第k名 个人博客:doubleq.win

个人博客:doubleq.win 01:谁考了第k名 查看 提交 统计 提问 总时间限制: 1000ms 内存限制: 65536kB描述 在一次考试中,每个学...

3155
来自专栏Rovo89

UML类图的学习笔记

903
来自专栏小詹同学

Leetcode打卡 | No.011 盛最多水的容器

欢迎和小詹一起定期刷leetcode,每周一和周五更新一题,每一题都吃透,欢迎一题多解,寻找最优解!这个记录帖哪怕只有一个读者,小詹也会坚持刷下去的!

1262
来自专栏mathor

LeetCode69. x 的平方根

 这道题直接一个return Math.sqrt就出来了,但是秉承着学习的心态,尝试着用二分法ac  首先要确定的就是左右区间,左区间是0无疑了,那么右...

1032
来自专栏机器之心

令人困惑的TensorFlow!

我叫 Jacob,是 Google AI Resident 项目的研究学者。我是在 2017 年夏天加入该项目的,尽管已经拥有了丰富的编程经验,并且对机器学习的...

1253
来自专栏Linyb极客之路

写出优质Java代码的4个技巧

如果现在要求对你写的Java代码进行优化,那你会怎么做呢?作者在本文介绍了可以提高系统性能以及代码可读性的四种方法,如果你对此感兴趣,就让我们一起来看看吧。

961
来自专栏灯塔大数据

每周学点大数据 | No.12数据流中的频繁元素

No.12期 数据流中的频繁元素 Mr. 王:我们再来讲一个例子,数据流中的频繁元素。我们先来说说大数据的数据流模型。 小可:数据流,是流动的数据的意思吗?和...

3087
来自专栏程序员叨叨叨

7.2 uniform

Cg 语言将输入数据流分为两类(参见文献[3]Program inputs and Outputs ):

584
来自专栏向治洪

java解决hash算法冲突

看了ConcurrentHashMap的实现, 使用的是拉链法. 虽然我们不希望发生冲突,但实际上发生冲突的可能性仍是存在的。当关键字值域远大于哈希表的长度...

2089

扫码关注云+社区