我有一个张量,我只想复制其中的一些值(按列)。相同的值在另一个张量中,但顺序是随机的。我想要的是来自tensor1值的tensor2的列索引。下面是一个示例:
copy_ind = torch.tensor([0, 1, 3], dtype=torch.long)
tensor1 = torch.tensor([[4, 6, 5, 1, 8],[10, 0, 8, 2, 1]])
temp = torch.index_select(tensor1, 1, copy_ind) # values to copy
tensor2 = torch.tensor([[1, 4, 5, 6, 8],[2, 10, 8, 0, 1]], dtype=torch.long)
_, t_ind = torch.sort(temp[0], dim=0)
t2_ind = copy_ind[t_ind] # indices of tensor2输出应为:
t2_ind = [1, 3, 0]这是另一个例子,我想根据c1_new得到张量的值
c1 = torch.tensor([[6, 7, 7, 8, 6, 8, 9, 4, 7, 6, 1, 3],[5, 11, 5, 7, 2, 9, 5, 5, 7, 11, 10, 7]], dtype=torch.long)
copy_ind = torch.tensor([1, 2, 3, 5, 7, 8], dtype=torch.long)
c1_new = torch.index_select(c1, 1, copy_ind)
indices = torch.as_tensor([[1, 3, 4, 6, 6, 6, 7, 7, 7, 8, 8, 9], [10, 7, 5, 2, 5, 11, 5, 7, 11, 7, 9, 5]])
values = torch.randn(12)
tensor = torch.sparse.FloatTensor(indices, values, (12, 12))
_, t_ind = torch.sort(c1[0], dim=0)
ind = t_ind[copy_ind] # should be [8, 6, 9, 10, 2, 7]不幸的是,索引ind是不正确的。谁能帮帮我?
发布于 2021-01-24 05:26:09
如果你同意使用for循环,你可以这样做:对照tensor2的列检查临时张量的每一列:
编辑:在维度1中使用torch.prod以确保两行都匹配
[torch.prod((temp.T[i] == tesnor2.T), dim=1).nonzero()[0] for i in range(temp.size(1))]我的第一个示例的输出是[tensor(1), tensor(3), tensor(0)]
https://stackoverflow.com/questions/65860296
复制相似问题