前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >torch 的 dim 和 numpy 的axis 表示方向不同

torch 的 dim 和 numpy 的axis 表示方向不同

作者头像
用户2965768
发布2021-12-06 11:22:39
5260
发布2021-12-06 11:22:39
举报
文章被收录于专栏:wym

1. torch中以index_select为例子

torch.index_select(input, dim, index, out=None) - 功能:在维度dim上,按index索引数据 - 返回值:依index索引数据拼接的张量 - index:要索引的张量 - dim:要索引的维度 - index:要索引数据的序号

代码语言:javascript
复制
x = torch.randn(3, 4)
print(x)
indices = torch.tensor([0, 2])
torch.index_select(x, 1, indices)



#把1改为0
y = torch.randn(3, 4)
print(y)
indices = torch.tensor([0, 2])
torch.index_select(y, 0, indices)

输出如下,可以看出,dim=1时按照列索引;dim=0时,按照行索引

代码语言:javascript
复制
tensor([[ 1.9626,  0.1007, -1.2005,  1.2650],
        [ 0.3603,  0.6343, -0.6197,  0.5740],
        [-0.0798,  0.9674, -0.7761,  0.5552]])
tensor([[ 1.9626, -1.2005],
        [ 0.3603, -0.6197],
        [-0.0798, -0.7761]])


tensor([[ 0.2274, -2.1934, -0.3129,  0.3869],
        [ 0.3831, -0.7156, -1.0765, -2.1098],
        [-0.8007, -0.0095,  0.8703, -0.8797]])
tensor([[ 0.2274, -2.1934, -0.3129,  0.3869],
        [-0.8007, -0.0095,  0.8703, -0.8797]])

2.numpy 中 以mean为例

代码语言:javascript
复制
x = numpy.random.randint(1,10,(3,4))
print(x)
print(x.mean(0))


y = numpy.random.randint(1,10,(3,4))
print(y)
print(y.mean(1))

输出如下,axis = 0时,按照竖直方向从上往下计算均值,输出4个数;axis=1时,按照水平方向从左往右计算均值,输出三个数。

代码语言:javascript
复制
[[6 8 4 9]
 [7 5 9 3]
 [1 7 6 1]]
[4.66666667 6.66666667 6.33333333 4.33333333]


[[3 3 6 5]
 [4 3 1 5]
 [7 2 2 5]]
[4.25 3.25 4.  ]
本文参与 腾讯云自媒体同步曝光计划,分享自作者个人站点/博客。
原始发表:2021/07/22 ,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档