tntorch - Tensor Network Learning with PyTorch
by rballester
https://tntorch.readthedocs.io/
Github项目地址:
https://github.com/rballester/tntorch
New:我们的 Read the Docs 网站已经发布!
欢迎使用tntorch,一个使用张量网络的PyTorch驱动的建模和学习库。 这种网络的独特之处在于它们使用多线性神经单元(而不是非线性激活单元)。 功能包括:
可用的张量格式包括:
例如,以下网络都代表TT和TT-Tucker格式的4D张量(即可以采用 I1 x I2 x I3 x I4可能值的实数函数):
在tntorch 中,所有张量分解共享相同的接口。 你可以用容易理解的形式处理它们,就像它们是纯NumPy数组或PyTorch张量一样:
> import tntorch as tn> t = tn.randn(32, 32, 32, 32, ranks_tt=5) # Random 4D TT tensor of shape 32 x 32 x 32 x 32 and TT-rank 5> print(t)
4D TT tensor:
32 32 32 32 | | | | (0) (1) (2) (3) / \ / \ / \ / \1 5 5 5 1
> print(tn.mean(t))
tensor(8.0388)
> print(tn.norm(t))
tensor(9632.3726)
解压缩张量很容易:
> print(t.torch().shape)torch.Size([32, 32, 32, 32])
由于PyTorch的自动微分,你可以很容易地定义张量上的各种损失函数:
def loss(t): return torch.norm(t[:, 0, 10:, [3, 4]].torch()) # NumPy-like "fancy indexing" for arrays
最重要的是,损失函数也可以在压缩张量上定义:
def loss(t): return tn.norm(t[:3, :3, :3, :3] - t[-3:, -3:, -3:, -3:])
查看 introductory notebook ,了解有关基础知识的所有详细信息。
主要依赖项是 NumPy 和 PyTorch。 要下载并安装 tntorch ,请输入:
git clone https://github.com/rballester/tntorch.gitcd tntorchpip install .
我们使用 pytest 进行测试。 简单地运行以下命令即可:
cd tests/pytest