前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >PyTorch入门笔记-张量相乘matmul函数02

PyTorch入门笔记-张量相乘matmul函数02

作者头像
触摸壹缕阳光
发布2021-03-16 11:02:32
5.4K0
发布2021-03-16 11:02:32
举报

Matmul 函数

torch.matmul(input, other, out = None) 函数对 input 和 other 两个张量进行矩阵相乘。torch.matmul 函数根据传入参数的张量维度有很多重载函数。为了方便后续的介绍,将传入 input 参数中的张量命名为 a,而传入 other 参数的张量命名为 b。

  • 若 a 为 1D 张量,b 为 2D 张量,torch.matmul 函数:
    • 首先,在 1D 张量 a 的前面插入一个长度为 1 的新维度变成 2D 张量;
    • 然后,在满足第一个 2D 张量(矩阵)的列数(column)和第二个 2D 张量(矩阵)的行数(row)相同的条件下,两个 2D 张量矩阵乘积,否则会抛出错误;
    • 最后,将矩阵乘积结果中长度为 1 的维度(前面插入的长度为 1 的新维度)删除作为最终 torch.matmul 函数返回的结果;
代码语言:javascript
复制
import torch

# a为1D张量,b为2D张量
a = torch.tensor([1., 2.])
b = torch.tensor([[5., 6., 7.], [8., 9., 10.]])

result = torch.matmul(a, b)
print(result.size())
# torch.Size([3])

print(result)
# tensor([21., 24., 27.])

  • 若 a 为 2D 张量,b 为 1D 张量,torch.matmul 函数:
    • 首先,在 1D 张量 b 的后面插入一个长度为 1 的新维度变成 2D 张量;
    • 然后,在满足第一个 2D 张量(矩阵)的列数(column)和第二个 2D 张量(矩阵)的行数(row)相同的条件下,两个 2D 张量矩阵乘积,否则会抛出错误;
    • 最后,将矩阵乘积结果中长度为 1 的维度(后面插入的长度为 1 的新维度)删除作为最终 torch.matmul 函数返回的结果;
代码语言:javascript
复制
import torch

# a为2D张量,b为1D张量
a = torch.tensor([[1., 2., 3.], [4., 5., 6.]])
b = torch.tensor([7., 8., 9.])

result = torch.matmul(a, b)
print(result.size())
# torch.Size([2])

print(result)
# tensor([50., 122.])

具体细节和 a 为 1D 张量,b 为 2D 张量的情况差不多,只不过,一个在 1D 张量的前面插入长度为 1 的新维度(a 为 1D 张量,b 为 2D 张量),另一个是在 1D 张量的后面插入长度为 1 的新维度(a 为 2D 张量,b 为 1D 张量)。

本文参与 腾讯云自媒体分享计划,分享自微信公众号。
原始发表:2021-03-04,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 AI机器学习与深度学习算法 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • Matmul 函数
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档