首页
学习
活动
专区
工具
TVP
发布
社区首页 >问答首页 >pytorch的grid_sample返回不正确的值

pytorch的grid_sample返回不正确的值
EN

Stack Overflow用户
提问于 2021-06-25 20:54:12
回答 1查看 61关注 0票数 1

我有一个3D矩阵:img[i, j, k] = i+j+k

在我看来,如果我想要(123)的值,grid_sample应该返回6。但事实并非如此。

代码是:

代码语言:javascript
复制
import torch
from torch.nn import functional as F
import numpy as np
X, Y, Z = 10, 20, 30
img = np.zeros(shape=[X, Y, Z], dtype=np.float32)
for i in range(X):
    for j in range(Y):
        for k in range(Z):
            img[i,j,k] = i+j+k
inp = torch.from_numpy(img).unsqueeze(0).unsqueeze(0)
grid = torch.from_numpy(np.array([[1, 2, 3]], dtype=np.float32)).unsqueeze(1).unsqueeze(1).unsqueeze(1)
grid[..., 0] /= (X-1)
grid[..., 1] /= (Y-1)
grid[..., 2] /= (Z-1)
grid = 2*grid - 1
outp = F.grid_sample(inp, grid=grid, mode='bilinear', align_corners=True)
print(outp)

grid_sample返回6.15。我的代码有什么问题吗?

EN

回答 1

Stack Overflow用户

发布于 2021-06-27 09:48:36

最后,我找到了解决方案。上面代码返回不正确的值的原因是torch.grid_sample接受(z,y,x)点。

因此,正确的代码应该是:

代码语言:javascript
复制
import torch
from torch.nn import functional as F
import numpy as np
X, Y, Z = 10, 20, 30
img = np.zeros(shape=[X, Y, Z], dtype=np.float32)
for i in range(X):
    for j in range(Y):
        for k in range(Z):
            img[i,j,k] = i+j+k
inp = torch.from_numpy(img).unsqueeze(0).unsqueeze(0)
grid = torch.from_numpy(np.array([[1, 2, 3]], dtype=np.float32)).unsqueeze(1).unsqueeze(1).unsqueeze(1)
grid[..., 0] /= (X-1)
grid[..., 1] /= (Y-1)
grid[..., 2] /= (Z-1)

grid = 2*grid - 1

newgrid = grid.clone()
newgrid[..., 0] = grid[..., 2]
newgrid[..., 1] = grid[..., 1]
newgrid[..., 2] = grid[..., 0]

outp = F.grid_sample(inp, grid=newgrid, mode='bilinear', align_corners=True)
print(outp)
票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/68131325

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档