前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >PyTorch入门笔记-张量的运算和类型陷阱

PyTorch入门笔记-张量的运算和类型陷阱

作者头像
触摸壹缕阳光
发布2021-03-16 11:00:03
1.8K0
发布2021-03-16 11:00:03
举报

加、减、乘、除

加、减、乘、除是最基本的数学运算,分别通过 torch.addtorch.subtorch.multorch.div 函数实现,Pytorch 已经重载了 +、-、* 和 / 运算符。

代码语言:javascript
复制
import torch

x = torch.ones(2, 2)
y = torch.arange(4).reshape(2, 2)

# add加法
print(torch.add(x, y))
# tensor([[1., 2.],
#         [3., 4.]])
print(x + y)
# tensor([[1., 2.],
#         [3., 4.]])

# subtraction减法
print(torch.sub(x, y))
# tensor([[ 1.,  0.],
#         [-1., -2.]])
print(x - y)
# tensor([[ 1.,  0.],
#         [-1., -2.]])

# multiplication乘法
print(torch.mul(x, y))
# tensor([[0., 1.],
#         [2., 3.]])
print(x * y)
# tensor([[0., 1.],
#         [2., 3.]])

# division除法
print(torch.div(x, y))
# tensor([[   inf, 1.0000],
#         [0.5000, 0.3333]])
print(x / y)
# tensor([[   inf, 1.0000],
#         [0.5000, 0.3333]])

这里需要注意,张量 y 的第一个元素为 0,而在 x 和 y 进行除法运算时,y 中的 0 作为了除数。在 PyTorch 中,除数为 0 时程序并不会报错,而是的等于 inf。

这些加、减、乘、除基本的数学运算在 PyTorch 中的实现都比较简单,但是在使用过程中还是需要注意以下几点(下面都以乘法为例,其余三种运算同理):

  1. 参与基本数学运算的张量必须形状一致,或者可以通过广播机制扩展到相同的形状;
代码语言:javascript
复制
import torch

x = torch.ones(1, 2)
y = torch.arange(4).reshape(2, 2)

# 此时的x通过广播机制形状变成(2, 2)
print(x * y)
# tensor([[0., 1.],
#         [2., 3.]])

# 此时将张量y的形状变成(1, 4)
y = y.reshape(1, 4)
# 此时x和y不满足广播机制
print(x * y)
'''
Traceback (most recent call last):
  File "/home/chenkc/code/pytorch/test01.py", line 224, in <module>
    print(x * y)
RuntimeError: The size of tensor a (2) must match the size of tensor b (4) at non-singleton dimension 1
'''
  1. 基本的数学运算与 NumPy 一样,都是 Element-Wise(逐元素运算),因此 torch.mul 实现的并不是张量乘法(两个张量相乘后的张量形状遵循:中间相等取两头的规则),而是相乘张量中对应位置的元素相乘;
代码语言:javascript
复制
import torch

x = torch.ones(2, 2)
y = torch.arange(4).reshape(2, 2)

# 逐元素相乘
print(x * y)
# tensor([[0., 1.],
#         [2., 3.]])

# 矩阵乘法
# 矩阵相乘需要保证张量中元素一致
y = y.float()
print(torch.matmul(x, y))
# tensor([[2., 4.],
#         [2., 4.]])
  1. 基本的数学运算支持两种接口,换句话说,可以使用 tensor.addtensor.subtensor.multensor.div
代码语言:javascript
复制
import torch

x = torch.ones(2, 2)
y = torch.arange(4).reshape(2, 2)

# 逐元素相乘
print(x * y)
# tensor([[0., 1.],
#         [2., 3.]])

print(y.mul(x))
# tensor([[0., 1.],
#         [2., 3.]])
  1. 基本的数学运算也支持原地操作(in-place operation);
代码语言:javascript
复制
import torch

x = torch.ones(2, 2)
y = torch.arange(4).reshape(2, 2)

# 逐元素相乘
print(y.mul(x))
# tensor([[0., 1.],
#         [2., 3.]])

print(y) # 张量y没有改变
# tensor([[0, 1],
#         [2, 3]])

y = y.float()
print(y.mul_(x))
# tensor([[0., 1.],
#         [2., 3.]])

print(y) # 张量y = x * y
# tensor([[0., 1.],
#         [2., 3.]])

类型陷阱

本小节我们一共使用了 2 次 y = y.float,第一次在第 2 点演示矩阵乘法(torch.matmul(x, y))之前,第二次在第 4 点演示原地操作(y.mul_(x))之前。这是因为生成张量 x 和 y 的类型不一致,当然本小节使用的都是 torch.arange 函数生成张量 y,这也是为了说明类型陷阱的问题。

代码语言:javascript
复制
import torch

x = torch.ones(2, 2)
y = torch.arange(4).reshape(2, 2)

print(x.dtype)
# torch.float32

print(y.dtype)
# torch.int64

虽然加减乘除基本运算对张量的类型没有要求,但是有一些运算操作对运算的张量类型还是比较敏感的。

  • 矩阵乘法要求相乘的张量类型一致;
  • 原地操作由于将运算后的张量赋值给原始张量,但是如果运算后的张量和原始张量的类型不一样,也会抛出错误。比如张量 y 为 torch.int64x * y 后的张量为 torch.float32 类型,将 torch.float32 类型的张量赋值给 torch.int64 的张量 y,程序会抛出错误;
本文参与 腾讯云自媒体分享计划,分享自微信公众号。
原始发表:2021-02-25,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 AI机器学习与深度学习算法 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 类型陷阱
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档