PyTorch Handbook

224课时
3.1K学过
8分

10. CNN:MNIST数据集手写数字识别

11. RNN实例:通过Sin预测Cos

课程评价 (0)

请对课程作出评价:
0/300

学员评价

暂无精选评价
8分钟

NumPy 转换

将一个Torch Tensor转换为NumPy数组是一件轻松的事,反之亦然。

Torch Tensor与NumPy数组共享底层内存地址,修改一个会导致另一个的变化。

将一个Torch Tensor转换为NumPy数组

In [16]:

a = torch.ones(5)
print(a)

tensor([1., 1., 1., 1., 1.])

In [17]:

b = a.numpy()
print(b)

[1. 1. 1. 1. 1.]

观察numpy数组的值是如何改变的。

In [18]:

a.add_(1)
print(a)
print(b)

tensor([2., 2., 2., 2., 2.]) [2. 2. 2. 2. 2.]

NumPy Array 转化成 Torch Tensor

使用from_numpy自动转化

In [19]:

import numpy as np
a = np.ones(5)
b = torch.from_numpy(a)
np.add(a, 1, out=a)
print(a)
print(b)

[2. 2. 2. 2. 2.] tensor([2., 2., 2., 2., 2.], dtype=torch.float64)

所有的 Tensor 类型默认都是基于CPU, CharTensor 类型不支持到 NumPy 的转换.