首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >问答首页 >从100x100 pytorch张量获得一个10x10的补丁,该张量具有环状样式包裹在边界上。

从100x100 pytorch张量获得一个10x10的补丁,该张量具有环状样式包裹在边界上。
EN

Stack Overflow用户
提问于 2021-02-02 16:47:29
回答 1查看 61关注 0票数 4

我如何从100x100 pytorch张量获得一个10x10的面片,并增加一个约束,如果一个面片将超出数组的边界,那么它将围绕边缘(就像数组是一个环面,顶部连接到底部,左侧连接到右侧)?

我写了这段代码来完成这项工作,我正在寻找更优雅,更高效,更清晰的代码:

代码语言:javascript
运行
复制
def shift_matrix(a, distances) -> Tensor:
  x, y = distances
  a = torch.cat((a[x:], a[0:x]), dim=0)
  a = torch.cat((a[:, y:], a[:, :y]), dim=1)
  return a

def randomly_shift_matrix(a) -> Tensor:
  return shift_matrix(a, np.random.randint(low = 0, high = a.size()))

def random_patch(a, size) -> Tensor:
  full_shifted_matrix = randomly_shift_matrix(a)
  return full_shifted_matrix[0:size[0], 0:size[1]]

我觉得带负指数切片的东西应该行得通。不过我还没找到。

你可以使用see the code in google colab here

EN

Stack Overflow用户

回答已采纳

发布于 2021-02-02 17:44:56

您正在寻找torch.roll

代码语言:javascript
运行
复制
def random_patch(a, size) -> Tensor:
  shifts = np.random.randint(low = 0, high = a.size())
  return torch.roll(a, shifts=shifts, dims=(0, 1))[:size[0], :size[1]]
票数 4
EN
查看全部 1 条回答
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/66006277

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档