8分钟
NumPy 转换
将一个Torch Tensor转换为NumPy数组是一件轻松的事,反之亦然。
Torch Tensor与NumPy数组共享底层内存地址,修改一个会导致另一个的变化。
将一个Torch Tensor转换为NumPy数组
In [16]:
a = torch.ones(5)
print(a)tensor([1., 1., 1., 1., 1.])
In [17]:
b = a.numpy()
print(b)[1. 1. 1. 1. 1.]
观察numpy数组的值是如何改变的。
In [18]:
a.add_(1)
print(a)
print(b)tensor([2., 2., 2., 2., 2.]) [2. 2. 2. 2. 2.]
NumPy Array 转化成 Torch Tensor
使用from_numpy自动转化
In [19]:
import numpy as np
a = np.ones(5)
b = torch.from_numpy(a)
np.add(a, 1, out=a)
print(a)
print(b)[2. 2. 2. 2. 2.] tensor([2., 2., 2., 2., 2.], dtype=torch.float64)
所有的 Tensor 类型默认都是基于CPU, CharTensor 类型不支持到 NumPy 的转换.
学员评价