前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >PyTorch入门笔记-常见的矩阵乘法

PyTorch入门笔记-常见的矩阵乘法

作者头像
触摸壹缕阳光
发布2021-03-16 11:06:14
1.5K0
发布2021-03-16 11:06:14
举报

前言

前文介绍了根据传入参数的张量维度决定其实现功能的 torch.matmul 函数。torch.matmul 函数功能强大,虽然可以使用其重载的运算符 @,但是使用起来比较麻烦,并且在实际使用场景中,常用的矩阵乘积运算就那么几种。为了方便使用这些常用的矩阵乘积运算,PyTorch 提供了一些更为方便的函数。

二维矩阵乘法

神经网络中包含大量的 2D 张量矩阵乘法运算,而使用 torch.matmul 函数比较复杂,因此 PyTorch 提供了更为简单方便的 torch.mm(input, other, out = None) 函数。下表是 torch.matmul 函数和 torch.mm 函数的简单对比。

torch.matmul 函数支持广播,主要指的是当参与矩阵乘积运算的两个张量中其中有一个是 1D 张量,torch.matmul 函数会将其广播成 2D 张量参与运算,最后将广播添加的维度删除作为最终 torch.matmul 函数的返回结果。torch.mm 函数不支持广播,相对应的输入的两个张量必须为 2D。

代码语言:javascript
复制
import torch

input = torch.tensor([[1., 2.], [3., 4.]])
other = torch.tensor([[5., 6., 7.], [8., 9., 10.]])

result = torch.mm(input, other)
print(result)
# tensor([[21., 24., 27.],
#         [47., 54., 61.]])

批量矩阵乘法

同理,由于 torch.bmm 函数不支持广播,相对应的输入的两个张量必须为 3D。

代码语言:javascript
复制
import torch

input = torch.randn(10, 3, 4)
other = torch.randn(10, 4, 2)

result = torch.bmm(input, other)

print(result.size())
# torch.Size([10, 3, 2])
本文参与 腾讯云自媒体分享计划,分享自微信公众号。
原始发表:2021-03-07,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 AI机器学习与深度学习算法 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 二维矩阵乘法
  • 批量矩阵乘法
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档