前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >Pytorch常用函数功能使用(一)

Pytorch常用函数功能使用(一)

作者头像
大黄大黄大黄
发布2019-03-12 16:51:57
1.1K0
发布2019-03-12 16:51:57
举报
文章被收录于专栏:机器学习从入门到成神

1. view

代码语言:javascript
复制
import torch

number_1 = torch.randn(2, 3)

print(number_1)
print(number_1.shape)

print(number_1.view(1, -1))
print(number_1.view(3, -1))

输出:

代码语言:javascript
复制
tensor([[ 1.0506, -0.5875, -1.2477],
        [ 0.0635,  0.8997,  0.1551]])
        
torch.Size([2, 3])

tensor([[ 1.0506, -0.5875, -1.2477,  0.0635,  0.8997,  0.1551]])

tensor([[ 1.0506, -0.5875],
        [-1.2477,  0.0635],
        [ 0.8997,  0.1551]])

View(a,b)中第一个参数a代表目标张量的行数,b代表列数。为了简便起见,也可以只指定第一个参数a,b这个参数设置成-1,函数会自动计算对应的列数。

2. squeeze

代码语言:javascript
复制
number_2 = torch.randn(2, 1)

print(number_2)
print(torch.squeeze(number_2))
print(torch.squeeze(number_2, 0))
print(torch.squeeze(number_2, 1))

输出:

代码语言:javascript
复制
tensor([[ 0.5856],
        [-1.7095]])
tensor([ 0.5856, -1.7095])
tensor([[ 0.5856],
        [-1.7095]])
tensor([ 0.5856, -1.7095])

Squeeze的功能是进行维度缩减(维度为1的删除)。Squeeze(a,b)中第一个参数a代表传入的张量,b代表要缩减的维数。如果第二个参数没有指定,则默认删除所有维度为1的维度

代码语言:javascript
复制
number_3 = torch.randn(1, 2)

print(number_3)
print(torch.squeeze(number_3))
print(torch.squeeze(number_3, 0))
print(torch.squeeze(number_3, 1))

输出:

代码语言:javascript
复制
tensor([[ 0.1555, -0.4286]])
tensor([ 0.1555, -0.4286])
tensor([ 0.1555, -0.4286])
tensor([[ 0.1555, -0.4286]])

3. unsqueeze

代码语言:javascript
复制
number_4 = torch.randn(3, 2)

print(number_4)
print(torch.unsqueeze(number_4, 0))
print(torch.unsqueeze(number_4, 1))

输出:

代码语言:javascript
复制
tensor([[ 0.0358, -0.2769],
        [-0.3257,  0.1895],
        [ 1.9278, -0.9444]])
tensor([[[ 0.0358, -0.2769],
         [-0.3257,  0.1895],
         [ 1.9278, -0.9444]]])
tensor([[[ 0.0358, -0.2769]],

        [[-0.3257,  0.1895]],

        [[ 1.9278, -0.9444]]])

Unsqueeze的功能与squeeze相反,可以增加张量的维度。Unqueeze(a,b)中第一个参数a代表传入的张量,b代表要增加维度的维数。

4. max

代码语言:javascript
复制
number_5 = torch.randn(2, 3)
print(number_5)
print(torch.max(number_5, 0))
print(torch.max(number_5, 1))

输出:

代码语言:javascript
复制
tensor([[-0.4916,  1.3999,  1.0527],
        [ 1.0194, -2.4695, -0.2378]])
(tensor([1.0194, 1.3999, 1.0527]), tensor([1, 0, 0]))
(tensor([1.3999, 1.0194]), tensor([1, 0]))

Max的功能是返回对应维度最大的数以及对应的索引。Max(a,b)中第一个参数a代表传入的张量,b代表要对应的维数。0代表返回每一列的最大值,1代表返回每一行的最大值。

本文参与 腾讯云自媒体同步曝光计划,分享自作者个人站点/博客。
原始发表:2019年03月03日,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 1. view
  • 2. squeeze
  • 3. unsqueeze
  • 4. max
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档