首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >问答首页 >沿对角线如何替换PyTorch张量中的特定值?

沿对角线如何替换PyTorch张量中的特定值?
EN

Stack Overflow用户
提问于 2021-03-23 03:16:53
回答 2查看 1.3K关注 0票数 1

例如,有一个PyTorch矩阵A

代码语言:javascript
运行
复制
A = tensor([[3,2,1],[1,0,2],[2,2,0]])

我需要在对角线上用1替换0,所以结果应该是:

代码语言:javascript
运行
复制
tensor([[3,2,1],[1,1,2],[2,2,1]])
EN

回答 2

Stack Overflow用户

回答已采纳

发布于 2021-03-23 09:42:19

您可以使用torch内置的对角线函数替换如下所示的对角线元素:

代码语言:javascript
运行
复制
mask = A.diagonal() == 0
A += torch.diag(mask)
代码语言:javascript
运行
复制
>>> A
tensor([[3, 2, 1],
        [1, 1, 2],
        [2, 2, 1]])

如果要用另一个值替换0,请将mask更改为mask * replace_value

票数 2
EN

Stack Overflow用户

发布于 2021-03-23 03:46:32

您可以使用向量索引提取对角线,对其进行处理,然后将其放入原始矩阵中:

代码语言:javascript
运行
复制
N=10
a = torch.randint(0,N,[N,N])

#tensor([[0, 9, 6, 6, 9, 9, 3, 1, 8, 4],
#    [8, 1, 6, 8, 5, 8, 7, 8, 1, 4],
#    [1, 9, 8, 4, 7, 0, 2, 9, 6, 2],
#    [9, 5, 9, 6, 7, 1, 4, 0, 2, 6],
#    [1, 2, 8, 0, 9, 0, 4, 3, 9, 9],
#    [1, 4, 6, 9, 6, 5, 1, 2, 0, 7],
#    [4, 8, 1, 3, 1, 6, 1, 3, 5, 6],
#    [3, 8, 9, 9, 1, 3, 0, 9, 6, 6],
#    [7, 4, 3, 0, 3, 5, 6, 6, 9, 2],
#    [3, 1, 0, 8, 3, 5, 6, 6, 5, 5]])

diag = a[range(N),range(N)] #index (1,1), (2,2), ... etc
diag[diag==0] = 1 # set according to your condition
a[range(N),range(N)] = diag #return the diagonal to its place

#tensor([[1, 9, 6, 6, 9, 9, 3, 1, 8, 4],
#    [8, 1, 6, 8, 5, 8, 7, 8, 1, 4],
#    [1, 9, 8, 4, 7, 0, 2, 9, 6, 2],
#    [9, 5, 9, 6, 7, 1, 4, 0, 2, 6],
#    [1, 2, 8, 0, 9, 0, 4, 3, 9, 9],
#    [1, 4, 6, 9, 6, 5, 1, 2, 0, 7],
#    [4, 8, 1, 3, 1, 6, 1, 3, 5, 6],
#    [3, 8, 9, 9, 1, 3, 0, 9, 6, 6],
#    [7, 4, 3, 0, 3, 5, 6, 6, 9, 2],
#    [3, 1, 0, 8, 3, 5, 6, 6, 5, 5]])
票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/66756790

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档