我如何从100x100 pytorch张量获得一个10x10的面片,并增加一个约束,如果一个面片将超出数组的边界,那么它将围绕边缘(就像数组是一个环面,顶部连接到底部,左侧连接到右侧)?
我写了这段代码来完成这项工作,我正在寻找更优雅,更高效,更清晰的代码:
def shift_matrix(a, distances) -> Tensor:
x, y = distances
a = torch.cat((a[x:], a[0:x]), dim=0)
a = torch.cat((a[:, y:], a[:, :y]), dim=1)
return a
def randomly_shift_matrix(a) -> Tensor:
return shift_matrix(a, np.random.randint(low = 0, high = a.size()))
def random_patch(a, size) -> Tensor:
full_shifted_matrix = randomly_shift_matrix(a)
return full_shifted_matrix[0:size[0], 0:size[1]]我觉得带负指数切片的东西应该行得通。不过我还没找到。
发布于 2021-02-02 17:44:56
您正在寻找torch.roll
def random_patch(a, size) -> Tensor:
shifts = np.random.randint(low = 0, high = a.size())
return torch.roll(a, shifts=shifts, dims=(0, 1))[:size[0], :size[1]]https://stackoverflow.com/questions/66006277
复制相似问题