PyTorch provides the torch.Tensor.unfold
operation, which can be chained over arbitrarily many dimensions to extract overlapping patches. How can the patch extraction be reversed, so that the patches are combined back into the input shape?
The focus is on 3D volumetric images with 1 channel (biomedical). Extraction is possible with unfold
; if the patches overlap, how can we combine them?
Posted on 2021-06-27 10:37:26
To extract (overlapping) patches and to reconstruct the input shape, we can use torch.nn.functional.unfold
and the inverse operation torch.nn.functional.fold
. These methods only handle 4D tensors (2D images), but you can use them to process one dimension at a time.
A few notes:
There are two families of methods below: extract_patches_Xd
and extract_patches_Xds
, where X is the number of dimensions. The latter uses torch.Tensor.unfold() and has fewer lines of code. (The output is the same, but dilation cannot be used.) extract_patches_Xd
and combine_patches_Xd
are inverse methods; the combiner reverses the steps of the extractor one by one. Naming conventions: 1. `B`: Batch size
2. `C`: Channels
3. `T`: Time Dimension
4. `D`: Depth Dimension
5. `H`: Height Dimension
6. `W`: Width Dimension
7. `x_dim_in`: In the extraction method, this is the number of input pixels in dimension `x`. In the combining method, this is the number of sliding windows in dimension `x`.
8. `x_dim_out`: In the extraction method, this is the number of sliding windows in dimension `x`. In the combining method, this is the number of output pixels in dimension `x`.
The get_dim_blocks()
method is the function given on the PyTorch documentation site to compute the output shape of a convolutional layer.
Note that if the patches overlap, fold sums the overlapping elements. To recover the original values, do the following (see the sketch after this list):
1. Create a tensor of ones with the same shape as the patches, using `torch.ones_like(patches_tensor)`.
2. Combine those ones into a full image with the same output shape (this creates a counter for the overlapping elements).
3. Divide the combined image by the combined ones; this reverses any double summation of elements.
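A minimal sketch of that normalisation (my own example, assuming the extract_patches_2d and combine_patches_2d functions defined below, with overlapping windows because stride < kernel_size):

a = torch.arange(1, 17, dtype=torch.float).view(1, 1, 4, 4)
patches = extract_patches_2d(a, kernel_size=2, padding=0, stride=1)     # overlapping 2x2 windows
ones = torch.ones_like(patches)
summed = combine_patches_2d(patches, 2, a.shape, padding=0, stride=1)   # overlaps are summed
counts = combine_patches_2d(ones, 2, a.shape, padding=0, stride=1)      # how often each pixel was covered
recovered = summed / counts
print(torch.allclose(a, recovered))   # True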
First (2D):
The torch.nn.functional.fold
and torch.nn.functional.unfold
methods can be used directly:
import torch

def extract_patches_2ds(x, kernel_size, padding=0, stride=1):
    if isinstance(kernel_size, int):
        kernel_size = (kernel_size, kernel_size)
    if isinstance(padding, int):
        padding = (padding, padding, padding, padding)
    if isinstance(stride, int):
        stride = (stride, stride)
    channels = x.shape[1]
    x = torch.nn.functional.pad(x, padding)
    # (B, C, H, W)
    x = x.unfold(2, kernel_size[0], stride[0]).unfold(3, kernel_size[1], stride[1])
    # (B, C, h_dim_out, w_dim_out, kernel_size[0], kernel_size[1])
    x = x.contiguous().view(-1, channels, kernel_size[0], kernel_size[1])
    # (B * h_dim_out * w_dim_out, C, kernel_size[0], kernel_size[1])
    return x
def extract_patches_2d(x, kernel_size, padding=0, stride=1, dilation=1):
    if isinstance(kernel_size, int):
        kernel_size = (kernel_size, kernel_size)
    if isinstance(padding, int):
        padding = (padding, padding)
    if isinstance(stride, int):
        stride = (stride, stride)
    if isinstance(dilation, int):
        dilation = (dilation, dilation)

    def get_dim_blocks(dim_in, dim_kernel_size, dim_padding=0, dim_stride=1, dim_dilation=1):
        dim_out = (dim_in + 2 * dim_padding - dim_dilation * (dim_kernel_size - 1) - 1) // dim_stride + 1
        return dim_out

    channels = x.shape[1]
    h_dim_in = x.shape[2]
    w_dim_in = x.shape[3]
    h_dim_out = get_dim_blocks(h_dim_in, kernel_size[0], padding[0], stride[0], dilation[0])
    w_dim_out = get_dim_blocks(w_dim_in, kernel_size[1], padding[1], stride[1], dilation[1])

    # (B, C, H, W)
    x = torch.nn.functional.unfold(x, kernel_size, padding=padding, stride=stride, dilation=dilation)
    # (B, C * kernel_size[0] * kernel_size[1], h_dim_out * w_dim_out)
    x = x.view(-1, channels, kernel_size[0], kernel_size[1], h_dim_out, w_dim_out)
    # (B, C, kernel_size[0], kernel_size[1], h_dim_out, w_dim_out)
    x = x.permute(0, 1, 4, 5, 2, 3)
    # (B, C, h_dim_out, w_dim_out, kernel_size[0], kernel_size[1])
    x = x.contiguous().view(-1, channels, kernel_size[0], kernel_size[1])
    # (B * h_dim_out * w_dim_out, C, kernel_size[0], kernel_size[1])
    return x
def combine_patches_2d(x, kernel_size, output_shape, padding=0, stride=1, dilation=1):
    if isinstance(kernel_size, int):
        kernel_size = (kernel_size, kernel_size)
    if isinstance(padding, int):
        padding = (padding, padding)
    if isinstance(stride, int):
        stride = (stride, stride)
    if isinstance(dilation, int):
        dilation = (dilation, dilation)

    def get_dim_blocks(dim_in, dim_kernel_size, dim_padding=0, dim_stride=1, dim_dilation=1):
        dim_out = (dim_in + 2 * dim_padding - dim_dilation * (dim_kernel_size - 1) - 1) // dim_stride + 1
        return dim_out

    channels = x.shape[1]
    h_dim_out, w_dim_out = output_shape[2:]
    h_dim_in = get_dim_blocks(h_dim_out, kernel_size[0], padding[0], stride[0], dilation[0])
    w_dim_in = get_dim_blocks(w_dim_out, kernel_size[1], padding[1], stride[1], dilation[1])

    # (B * h_dim_in * w_dim_in, C, kernel_size[0], kernel_size[1])
    x = x.view(-1, channels, h_dim_in, w_dim_in, kernel_size[0], kernel_size[1])
    # (B, C, h_dim_in, w_dim_in, kernel_size[0], kernel_size[1])
    x = x.permute(0, 1, 4, 5, 2, 3)
    # (B, C, kernel_size[0], kernel_size[1], h_dim_in, w_dim_in)
    x = x.contiguous().view(-1, channels * kernel_size[0] * kernel_size[1], h_dim_in * w_dim_in)
    # (B, C * kernel_size[0] * kernel_size[1], h_dim_in * w_dim_in)
    x = torch.nn.functional.fold(x, (h_dim_out, w_dim_out), kernel_size=(kernel_size[0], kernel_size[1]), padding=padding, stride=stride, dilation=dilation)
    # (B, C, H, W)
    return x
a = torch.arange(1, 65, dtype=torch.float).view(2,2,4,4)
print(a.shape)
print(a)
b = extract_patches_2d(a, 2, padding=1, stride=2, dilation=1)
# b = extract_patches_2ds(a, 2, padding=1, stride=2)
print(b.shape)
print(b)
c = combine_patches_2d(b, 2, (2,2,4,4), padding=1, stride=2, dilation=1)
print(c.shape)
print(c)
print(torch.all(a==c))
Output (2D):
torch.Size([2, 2, 4, 4])
tensor([[[[ 1., 2., 3., 4.],
[ 5., 6., 7., 8.],
[ 9., 10., 11., 12.],
[13., 14., 15., 16.]],
[[17., 18., 19., 20.],
[21., 22., 23., 24.],
[25., 26., 27., 28.],
[29., 30., 31., 32.]]],
[[[33., 34., 35., 36.],
[37., 38., 39., 40.],
[41., 42., 43., 44.],
[45., 46., 47., 48.]],
[[49., 50., 51., 52.],
[53., 54., 55., 56.],
[57., 58., 59., 60.],
[61., 62., 63., 64.]]]])
torch.Size([18, 2, 2, 2])
tensor([[[[ 0., 0.],
[ 0., 1.]],
[[ 0., 0.],
[ 2., 3.]]],
[[[ 0., 0.],
[ 4., 0.]],
[[ 0., 5.],
[ 0., 9.]]],
[[[ 6., 7.],
[10., 11.]],
[[ 8., 0.],
[12., 0.]]],
[[[ 0., 13.],
[ 0., 0.]],
[[14., 15.],
[ 0., 0.]]],
[[[16., 0.],
[ 0., 0.]],
[[ 0., 0.],
[ 0., 17.]]],
[[[ 0., 0.],
[18., 19.]],
[[ 0., 0.],
[20., 0.]]],
[[[ 0., 21.],
[ 0., 25.]],
[[22., 23.],
[26., 27.]]],
[[[24., 0.],
[28., 0.]],
[[ 0., 29.],
[ 0., 0.]]],
[[[30., 31.],
[ 0., 0.]],
[[32., 0.],
[ 0., 0.]]],
[[[ 0., 0.],
[ 0., 33.]],
[[ 0., 0.],
[34., 35.]]],
[[[ 0., 0.],
[36., 0.]],
[[ 0., 37.],
[ 0., 41.]]],
[[[38., 39.],
[42., 43.]],
[[40., 0.],
[44., 0.]]],
[[[ 0., 45.],
[ 0., 0.]],
[[46., 47.],
[ 0., 0.]]],
[[[48., 0.],
[ 0., 0.]],
[[ 0., 0.],
[ 0., 49.]]],
[[[ 0., 0.],
[50., 51.]],
[[ 0., 0.],
[52., 0.]]],
[[[ 0., 53.],
[ 0., 57.]],
[[54., 55.],
[58., 59.]]],
[[[56., 0.],
[60., 0.]],
[[ 0., 61.],
[ 0., 0.]]],
[[[62., 63.],
[ 0., 0.]],
[[64., 0.],
[ 0., 0.]]]])
torch.Size([2, 2, 4, 4])
tensor([[[[ 1., 2., 3., 4.],
[ 5., 6., 7., 8.],
[ 9., 10., 11., 12.],
[13., 14., 15., 16.]],
[[17., 18., 19., 20.],
[21., 22., 23., 24.],
[25., 26., 27., 28.],
[29., 30., 31., 32.]]],
[[[33., 34., 35., 36.],
[37., 38., 39., 40.],
[41., 42., 43., 44.],
[45., 46., 47., 48.]],
[[49., 50., 51., 52.],
[53., 54., 55., 56.],
[57., 58., 59., 60.],
[61., 62., 63., 64.]]]])
tensor(True)
Second (3D):
Now it gets interesting: we need to use fold
and unfold
twice each. We first apply unfold
to the D
dimension and leave H
and W
untouched by setting the kernel to 1, padding to 0, stride to 1 and dilation to 1 for the second unfold dimension. After reshaping the tensor, we unfold the H
and W
dimensions. Folding happens in reverse order: first H
and W
, then D
.
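As a quick illustration of that trick (my own sketch, not part of the answer above): with a kernel of (k, 1), torch.nn.functional.unfold only slides along the first of the two spatial dimensions and leaves the second one intact.

import torch
import torch.nn.functional as F

x = torch.arange(12, dtype=torch.float).view(1, 1, 3, 4)   # (B, C, "D", flattened "H*W")
cols = F.unfold(x, kernel_size=(2, 1))                      # slides only over the size-3 dimension
print(cols.shape)                                           # torch.Size([1, 2, 8]) = (B, C * 2 * 1, 2 * 4)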
def extract_patches_3ds(x, kernel_size, padding=0, stride=1):
    if isinstance(kernel_size, int):
        kernel_size = (kernel_size, kernel_size, kernel_size)
    if isinstance(padding, int):
        padding = (padding, padding, padding, padding, padding, padding)
    if isinstance(stride, int):
        stride = (stride, stride, stride)
    channels = x.shape[1]
    x = torch.nn.functional.pad(x, padding)
    # (B, C, D, H, W)
    x = x.unfold(2, kernel_size[0], stride[0]).unfold(3, kernel_size[1], stride[1]).unfold(4, kernel_size[2], stride[2])
    # (B, C, d_dim_out, h_dim_out, w_dim_out, kernel_size[0], kernel_size[1], kernel_size[2])
    x = x.contiguous().view(-1, channels, kernel_size[0], kernel_size[1], kernel_size[2])
    # (B * d_dim_out * h_dim_out * w_dim_out, C, kernel_size[0], kernel_size[1], kernel_size[2])
    return x
def extract_patches_3d(x, kernel_size, padding=0, stride=1, dilation=1):
    if isinstance(kernel_size, int):
        kernel_size = (kernel_size, kernel_size, kernel_size)
    if isinstance(padding, int):
        padding = (padding, padding, padding)
    if isinstance(stride, int):
        stride = (stride, stride, stride)
    if isinstance(dilation, int):
        dilation = (dilation, dilation, dilation)

    def get_dim_blocks(dim_in, dim_kernel_size, dim_padding=0, dim_stride=1, dim_dilation=1):
        dim_out = (dim_in + 2 * dim_padding - dim_dilation * (dim_kernel_size - 1) - 1) // dim_stride + 1
        return dim_out

    channels = x.shape[1]
    d_dim_in = x.shape[2]
    h_dim_in = x.shape[3]
    w_dim_in = x.shape[4]
    d_dim_out = get_dim_blocks(d_dim_in, kernel_size[0], padding[0], stride[0], dilation[0])
    h_dim_out = get_dim_blocks(h_dim_in, kernel_size[1], padding[1], stride[1], dilation[1])
    w_dim_out = get_dim_blocks(w_dim_in, kernel_size[2], padding[2], stride[2], dilation[2])
    # print(d_dim_in, h_dim_in, w_dim_in, d_dim_out, h_dim_out, w_dim_out)

    # (B, C, D, H, W)
    x = x.view(-1, channels, d_dim_in, h_dim_in * w_dim_in)
    # (B, C, D, H * W)
    x = torch.nn.functional.unfold(x, kernel_size=(kernel_size[0], 1), padding=(padding[0], 0), stride=(stride[0], 1), dilation=(dilation[0], 1))
    # (B, C * kernel_size[0], d_dim_out * H * W)
    x = x.view(-1, channels * kernel_size[0] * d_dim_out, h_dim_in, w_dim_in)
    # (B, C * kernel_size[0] * d_dim_out, H, W)
    x = torch.nn.functional.unfold(x, kernel_size=(kernel_size[1], kernel_size[2]), padding=(padding[1], padding[2]), stride=(stride[1], stride[2]), dilation=(dilation[1], dilation[2]))
    # (B, C * kernel_size[0] * d_dim_out * kernel_size[1] * kernel_size[2], h_dim_out * w_dim_out)
    x = x.view(-1, channels, kernel_size[0], d_dim_out, kernel_size[1], kernel_size[2], h_dim_out, w_dim_out)
    # (B, C, kernel_size[0], d_dim_out, kernel_size[1], kernel_size[2], h_dim_out, w_dim_out)
    x = x.permute(0, 1, 3, 6, 7, 2, 4, 5)
    # (B, C, d_dim_out, h_dim_out, w_dim_out, kernel_size[0], kernel_size[1], kernel_size[2])
    x = x.contiguous().view(-1, channels, kernel_size[0], kernel_size[1], kernel_size[2])
    # (B * d_dim_out * h_dim_out * w_dim_out, C, kernel_size[0], kernel_size[1], kernel_size[2])
    return x
def combine_patches_3d(x, kernel_size, output_shape, padding=0, stride=1, dilation=1):
    if isinstance(kernel_size, int):
        kernel_size = (kernel_size, kernel_size, kernel_size)
    if isinstance(padding, int):
        padding = (padding, padding, padding)
    if isinstance(stride, int):
        stride = (stride, stride, stride)
    if isinstance(dilation, int):
        dilation = (dilation, dilation, dilation)

    def get_dim_blocks(dim_in, dim_kernel_size, dim_padding=0, dim_stride=1, dim_dilation=1):
        dim_out = (dim_in + 2 * dim_padding - dim_dilation * (dim_kernel_size - 1) - 1) // dim_stride + 1
        return dim_out

    channels = x.shape[1]
    d_dim_out, h_dim_out, w_dim_out = output_shape[2:]
    d_dim_in = get_dim_blocks(d_dim_out, kernel_size[0], padding[0], stride[0], dilation[0])
    h_dim_in = get_dim_blocks(h_dim_out, kernel_size[1], padding[1], stride[1], dilation[1])
    w_dim_in = get_dim_blocks(w_dim_out, kernel_size[2], padding[2], stride[2], dilation[2])
    # print(d_dim_in, h_dim_in, w_dim_in, d_dim_out, h_dim_out, w_dim_out)

    x = x.view(-1, channels, d_dim_in, h_dim_in, w_dim_in, kernel_size[0], kernel_size[1], kernel_size[2])
    # (B, C, d_dim_in, h_dim_in, w_dim_in, kernel_size[0], kernel_size[1], kernel_size[2])
    x = x.permute(0, 1, 5, 2, 6, 7, 3, 4)
    # (B, C, kernel_size[0], d_dim_in, kernel_size[1], kernel_size[2], h_dim_in, w_dim_in)
    x = x.contiguous().view(-1, channels * kernel_size[0] * d_dim_in * kernel_size[1] * kernel_size[2], h_dim_in * w_dim_in)
    # (B, C * kernel_size[0] * d_dim_in * kernel_size[1] * kernel_size[2], h_dim_in * w_dim_in)
    x = torch.nn.functional.fold(x, output_size=(h_dim_out, w_dim_out), kernel_size=(kernel_size[1], kernel_size[2]), padding=(padding[1], padding[2]), stride=(stride[1], stride[2]), dilation=(dilation[1], dilation[2]))
    # (B, C * kernel_size[0] * d_dim_in, H, W)
    x = x.view(-1, channels * kernel_size[0], d_dim_in * h_dim_out * w_dim_out)
    # (B, C * kernel_size[0], d_dim_in * H * W)
    x = torch.nn.functional.fold(x, output_size=(d_dim_out, h_dim_out * w_dim_out), kernel_size=(kernel_size[0], 1), padding=(padding[0], 0), stride=(stride[0], 1), dilation=(dilation[0], 1))
    # (B, C, D, H * W)
    x = x.view(-1, channels, d_dim_out, h_dim_out, w_dim_out)
    # (B, C, D, H, W)
    return x
a = torch.arange(1, 129, dtype=torch.float).view(2,2,2,4,4)
print(a.shape)
print(a)
# b = extract_patches_3d(a, 2, padding=1, stride=2)
b = extract_patches_3ds(a, 2, padding=1, stride=2)
print(b.shape)
print(b)
c = combine_patches_3d(b, 2, (2,2,2,4,4), padding=1, stride=2)
print(c.shape)
print(c)
print(torch.all(a==c))
Output (3D):
(I had to limit the characters; see the notepad.)
Third (4D):
We add a time dimension to the 3D volume. We start by unfolding the T
dimension while D
, H
and W
are handled analogously to the 3D version: then we unfold D
, leaving H
and W
untouched, and finally we unfold H
and W
. Folding again happens in reverse order. Hopefully by now you notice the pattern: you can add arbitrarily many dimensions and unfold them one by one, and the folding happens in reverse.
def extract_patches_4ds(x, kernel_size, padding=0, stride=1):
    if isinstance(kernel_size, int):
        kernel_size = (kernel_size, kernel_size, kernel_size, kernel_size)
    if isinstance(padding, int):
        padding = (padding, padding, padding, padding, padding, padding, padding, padding)
    if isinstance(stride, int):
        stride = (stride, stride, stride, stride)
    channels = x.shape[1]
    x = torch.nn.functional.pad(x, padding)
    # (B, C, T, D, H, W)
    x = x.unfold(2, kernel_size[0], stride[0]).unfold(3, kernel_size[1], stride[1]).unfold(4, kernel_size[2], stride[2]).unfold(5, kernel_size[3], stride[3])
    # (B, C, t_dim_out, d_dim_out, h_dim_out, w_dim_out, kernel_size[0], kernel_size[1], kernel_size[2], kernel_size[3])
    x = x.contiguous().view(-1, channels, kernel_size[0], kernel_size[1], kernel_size[2], kernel_size[3])
    # (B * t_dim_out * d_dim_out * h_dim_out * w_dim_out, C, kernel_size[0], kernel_size[1], kernel_size[2], kernel_size[3])
    return x
def extract_patches_4d(x, kernel_size, padding=0, stride=1, dilation=1):
    if isinstance(kernel_size, int):
        kernel_size = (kernel_size, kernel_size, kernel_size, kernel_size)
    if isinstance(padding, int):
        padding = (padding, padding, padding, padding)
    if isinstance(stride, int):
        stride = (stride, stride, stride, stride)
    if isinstance(dilation, int):
        dilation = (dilation, dilation, dilation, dilation)

    def get_dim_blocks(dim_in, dim_kernel_size, dim_padding=0, dim_stride=1, dim_dilation=1):
        dim_out = (dim_in + 2 * dim_padding - dim_dilation * (dim_kernel_size - 1) - 1) // dim_stride + 1
        return dim_out

    channels = x.shape[1]
    t_dim_in = x.shape[2]
    d_dim_in = x.shape[3]
    h_dim_in = x.shape[4]
    w_dim_in = x.shape[5]
    t_dim_out = get_dim_blocks(t_dim_in, kernel_size[0], padding[0], stride[0], dilation[0])
    d_dim_out = get_dim_blocks(d_dim_in, kernel_size[1], padding[1], stride[1], dilation[1])
    h_dim_out = get_dim_blocks(h_dim_in, kernel_size[2], padding[2], stride[2], dilation[2])
    w_dim_out = get_dim_blocks(w_dim_in, kernel_size[3], padding[3], stride[3], dilation[3])
    # print(t_dim_in, d_dim_in, h_dim_in, w_dim_in, t_dim_out, d_dim_out, h_dim_out, w_dim_out)

    # (B, C, T, D, H, W)
    x = x.view(-1, channels, t_dim_in, d_dim_in * h_dim_in * w_dim_in)
    # (B, C, T, D * H * W)
    x = torch.nn.functional.unfold(x, kernel_size=(kernel_size[0], 1), padding=(padding[0], 0), stride=(stride[0], 1), dilation=(dilation[0], 1))
    # (B, C * kernel_size[0], t_dim_out * D * H * W)
    x = x.view(-1, channels * kernel_size[0] * t_dim_out, d_dim_in, h_dim_in * w_dim_in)
    # (B, C * kernel_size[0] * t_dim_out, D, H * W)
    x = torch.nn.functional.unfold(x, kernel_size=(kernel_size[1], 1), padding=(padding[1], 0), stride=(stride[1], 1), dilation=(dilation[1], 1))
    # (B, C * kernel_size[0] * t_dim_out * kernel_size[1], d_dim_out * H * W)
    x = x.view(-1, channels * kernel_size[0] * t_dim_out * kernel_size[1] * d_dim_out, h_dim_in, w_dim_in)
    # (B, C * kernel_size[0] * t_dim_out * kernel_size[1] * d_dim_out, H, W)
    x = torch.nn.functional.unfold(x, kernel_size=(kernel_size[2], kernel_size[3]), padding=(padding[2], padding[3]), stride=(stride[2], stride[3]), dilation=(dilation[2], dilation[3]))
    # (B, C * kernel_size[0] * t_dim_out * kernel_size[1] * d_dim_out * kernel_size[2] * kernel_size[3], h_dim_out * w_dim_out)
    x = x.view(-1, channels, kernel_size[0], t_dim_out, kernel_size[1], d_dim_out, kernel_size[2], kernel_size[3], h_dim_out, w_dim_out)
    # (B, C, kernel_size[0], t_dim_out, kernel_size[1], d_dim_out, kernel_size[2], kernel_size[3], h_dim_out, w_dim_out)
    x = x.permute(0, 1, 3, 5, 8, 9, 2, 4, 6, 7)
    # (B, C, t_dim_out, d_dim_out, h_dim_out, w_dim_out, kernel_size[0], kernel_size[1], kernel_size[2], kernel_size[3])
    x = x.contiguous().view(-1, channels, kernel_size[0], kernel_size[1], kernel_size[2], kernel_size[3])
    # (B * t_dim_out * d_dim_out * h_dim_out * w_dim_out, C, kernel_size[0], kernel_size[1], kernel_size[2], kernel_size[3])
    return x
def combine_patches_4d(x, kernel_size, output_shape, padding=0, stride=1, dilation=1):
    if isinstance(kernel_size, int):
        kernel_size = (kernel_size, kernel_size, kernel_size, kernel_size)
    if isinstance(padding, int):
        padding = (padding, padding, padding, padding)
    if isinstance(stride, int):
        stride = (stride, stride, stride, stride)
    if isinstance(dilation, int):
        dilation = (dilation, dilation, dilation, dilation)

    def get_dim_blocks(dim_in, dim_kernel_size, dim_padding=0, dim_stride=1, dim_dilation=1):
        dim_out = (dim_in + 2 * dim_padding - dim_dilation * (dim_kernel_size - 1) - 1) // dim_stride + 1
        return dim_out

    channels = x.shape[1]
    t_dim_out, d_dim_out, h_dim_out, w_dim_out = output_shape[2:]
    t_dim_in = get_dim_blocks(t_dim_out, kernel_size[0], padding[0], stride[0], dilation[0])
    d_dim_in = get_dim_blocks(d_dim_out, kernel_size[1], padding[1], stride[1], dilation[1])
    h_dim_in = get_dim_blocks(h_dim_out, kernel_size[2], padding[2], stride[2], dilation[2])
    w_dim_in = get_dim_blocks(w_dim_out, kernel_size[3], padding[3], stride[3], dilation[3])
    # print(t_dim_in, d_dim_in, h_dim_in, w_dim_in, t_dim_out, d_dim_out, h_dim_out, w_dim_out)

    x = x.view(-1, channels, t_dim_in, d_dim_in, h_dim_in, w_dim_in, kernel_size[0], kernel_size[1], kernel_size[2], kernel_size[3])
    # (B, C, t_dim_in, d_dim_in, h_dim_in, w_dim_in, kernel_size[0], kernel_size[1], kernel_size[2], kernel_size[3])
    x = x.permute(0, 1, 6, 2, 7, 3, 8, 9, 4, 5)
    # (B, C, kernel_size[0], t_dim_in, kernel_size[1], d_dim_in, kernel_size[2], kernel_size[3], h_dim_in, w_dim_in)
    x = x.contiguous().view(-1, channels * kernel_size[0] * t_dim_in * kernel_size[1] * d_dim_in * kernel_size[2] * kernel_size[3], h_dim_in * w_dim_in)
    # (B, C * kernel_size[0] * t_dim_in * kernel_size[1] * d_dim_in * kernel_size[2] * kernel_size[3], h_dim_in * w_dim_in)
    x = torch.nn.functional.fold(x, output_size=(h_dim_out, w_dim_out), kernel_size=(kernel_size[2], kernel_size[3]), padding=(padding[2], padding[3]), stride=(stride[2], stride[3]), dilation=(dilation[2], dilation[3]))
    # (B, C * kernel_size[0] * t_dim_in * kernel_size[1] * d_dim_in, H, W)
    x = x.view(-1, channels * kernel_size[0] * t_dim_in * kernel_size[1], d_dim_in * h_dim_out * w_dim_out)
    # (B, C * kernel_size[0] * t_dim_in * kernel_size[1], d_dim_in * H * W)
    x = torch.nn.functional.fold(x, output_size=(d_dim_out, h_dim_out * w_dim_out), kernel_size=(kernel_size[1], 1), padding=(padding[1], 0), stride=(stride[1], 1), dilation=(dilation[1], 1))
    # (B, C * kernel_size[0] * t_dim_in, D, H * W)
    x = x.view(-1, channels * kernel_size[0], t_dim_in * d_dim_out * h_dim_out * w_dim_out)
    # (B, C * kernel_size[0], t_dim_in * D * H * W)
    x = torch.nn.functional.fold(x, output_size=(t_dim_out, d_dim_out * h_dim_out * w_dim_out), kernel_size=(kernel_size[0], 1), padding=(padding[0], 0), stride=(stride[0], 1), dilation=(dilation[0], 1))
    # (B, C, T, D * H * W)
    x = x.view(-1, channels, t_dim_out, d_dim_out, h_dim_out, w_dim_out)
    # (B, C, T, D, H, W)
    return x
a = torch.arange(1, 129, dtype=torch.float).view(2,2,2,2,4,2)
print(a.shape)
print(a)
# b = extract_patches_4d(a, 2, padding=1, stride=2)
b = extract_patches_4ds(a, 2, padding=1, stride=2)
print(b.shape)
print(b)
c = combine_patches_4d(b, 2, (2,2,2,2,4,2), padding=1, stride=2)
print(c.shape)
print(c)
print(torch.all(a==c))
Output (4D):
(I had to limit the characters; see the notepad.)
Posted on 2021-07-13 09:42:27
The solution above duplicates data in memory because it makes the patches contiguous. With many overlapping voxels this quickly leads to memory problems. To extract patches without duplicating memory, we can do the following in PyTorch:
def get_dim_blocks(dim_in, kernel_size, padding=0, stride=1, dilation=1):
    return (dim_in + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1

def extract_patches_3d(x, kernel_size, stride=1, dilation=1):
    if isinstance(kernel_size, int):
        kernel_size = (kernel_size, kernel_size, kernel_size)
    if isinstance(stride, int):
        stride = (stride, stride, stride)
    if isinstance(dilation, int):
        dilation = (dilation, dilation, dilation)

    x = x.contiguous()
    channels, depth, height, width = x.shape[-4:]
    d_blocks = get_dim_blocks(depth, kernel_size=kernel_size[0], stride=stride[0], dilation=dilation[0])
    h_blocks = get_dim_blocks(height, kernel_size=kernel_size[1], stride=stride[1], dilation=dilation[1])
    w_blocks = get_dim_blocks(width, kernel_size=kernel_size[2], stride=stride[2], dilation=dilation[2])
    shape = (channels, d_blocks, h_blocks, w_blocks, kernel_size[0], kernel_size[1], kernel_size[2])
    strides = (width * height * depth,
               stride[0] * width * height,
               stride[1] * width,
               stride[2],
               dilation[0] * width * height,
               dilation[1] * width,
               dilation[2])
    x = x.as_strided(shape, strides)
    x = x.permute(1, 2, 3, 0, 4, 5, 6)
    return x
The method expects a tensor of shape `(B, C, D, H, W)`. It is based on this and this answer, which explain in more detail what memory strides do. The output is not contiguous, and its first 3 dimensions are the number of blocks (sliding windows) in the D, H and W dimensions. Merging them into one dimension is not possible, because that would require a copy into contiguous memory.
Testing the strides
a = torch.arange(81, dtype=torch.float32).view(1,3,3,3,3)
print(a)
b = extract_patches_3d(a, kernel_size=2, stride=2)
print(b.shape)
print(b.storage())
print(a.data_ptr() == b.data_ptr())
print(b)
Output
tensor([[[[[ 0., 1., 2.],
[ 3., 4., 5.],
[ 6., 7., 8.]],
[[ 9., 10., 11.],
[12., 13., 14.],
[15., 16., 17.]],
[[18., 19., 20.],
[21., 22., 23.],
[24., 25., 26.]]],
[[[27., 28., 29.],
[30., 31., 32.],
[33., 34., 35.]],
[[36., 37., 38.],
[39., 40., 41.],
[42., 43., 44.]],
[[45., 46., 47.],
[48., 49., 50.],
[51., 52., 53.]]],
[[[54., 55., 56.],
[57., 58., 59.],
[60., 61., 62.]],
[[63., 64., 65.],
[66., 67., 68.],
[69., 70., 71.]],
[[72., 73., 74.],
[75., 76., 77.],
[78., 79., 80.]]]]])
torch.Size([1, 1, 1, 3, 2, 2, 2])
0.0
1.0
2.0
3.0
4.0
5.0
6.0
7.0
8.0
9.0
10.0
11.0
12.0
13.0
14.0
15.0
16.0
17.0
18.0
19.0
20.0
21.0
22.0
23.0
24.0
25.0
26.0
27.0
28.0
29.0
30.0
31.0
32.0
33.0
34.0
35.0
36.0
37.0
38.0
39.0
40.0
41.0
42.0
43.0
44.0
45.0
46.0
47.0
48.0
49.0
50.0
51.0
52.0
53.0
54.0
55.0
56.0
57.0
58.0
59.0
60.0
61.0
62.0
63.0
64.0
65.0
66.0
67.0
68.0
69.0
70.0
71.0
72.0
73.0
74.0
75.0
76.0
77.0
78.0
79.0
80.0
[torch.FloatStorage of size 81]
True
tensor([[[[[[[ 0., 1.],
[ 3., 4.]],
[[ 9., 10.],
[12., 13.]]],
[[[27., 28.],
[30., 31.]],
[[36., 37.],
[39., 40.]]],
[[[54., 55.],
[57., 58.]],
[[63., 64.],
[66., 67.]]]]]]])
Summing the overlapping voxels with memory strides is not possible when the tensor is contiguous (as it will be after processing in a NN). However, you can sum them manually as explained above, or by slicing as explained here.
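A minimal sketch of the slicing approach (my own example, assuming the strided extract_patches_3d above and a single-sample input with B == 1, so that patches has shape (d_blocks, h_blocks, w_blocks, C, k, k, k)):

import torch

def combine_patches_3d_by_slicing(patches, output_shape, kernel_size, stride):
    out = torch.zeros(output_shape, dtype=patches.dtype)   # (C, D, H, W)
    d_blocks, h_blocks, w_blocks = patches.shape[:3]
    for d in range(d_blocks):
        for h in range(h_blocks):
            for w in range(w_blocks):
                out[:,
                    d * stride:d * stride + kernel_size,
                    h * stride:h * stride + kernel_size,
                    w * stride:w * stride + kernel_size] += patches[d, h, w]
    return out   # overlapping voxels are summed; divide by a counts volume to average

a = torch.arange(81, dtype=torch.float32).view(1, 3, 3, 3, 3)
b = extract_patches_3d(a, kernel_size=2, stride=2)
c = combine_patches_3d_by_slicing(b, (3, 3, 3, 3), kernel_size=2, stride=2)
print(torch.all(c[:, :2, :2, :2] == a[0, :, :2, :2, :2]))   # True (only the covered region is filled)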
https://stackoverflow.com/questions/68150248