
How to extract overlapping patches from a 3D volume and recreate the input shape from these patches?

Stack Overflow user
Asked on 2021-06-27 10:37:26
2 answers · 712 views · 0 following · 3 votes

Pytorch offers torch.Tensor.unfold, which can be chained over arbitrarily many dimensions to extract overlapping patches. How can the patch extraction be reversed so that the patches are combined back into the input shape?

The focus is on 3D volumetric images with 1 channel (biomedical). Extraction is possible with unfold; how do we combine these patches if they overlap?
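
For reference, a minimal sketch of the chained unfold extraction the question refers to (the volume size, kernel and stride below are only illustrative):

import torch

volume = torch.randn(1, 1, 32, 32, 32)   # a single-channel volume, shape (B, C, D, H, W)
k, s = 8, 4                               # patch size and stride; patches overlap when s < k

patches = (volume
           .unfold(2, k, s)   # slide over D
           .unfold(3, k, s)   # slide over H
           .unfold(4, k, s))  # slide over W
# patches: (B, C, nD, nH, nW, k, k, k) -- the question is how to invert this step
print(patches.shape)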


2 Answers

Stack Overflow user

Accepted answer

Posted on 2021-06-27 10:37:26

To extract (overlapping) patches and to reconstruct the input shape, we can use torch.nn.functional.unfold and its inverse operation torch.nn.functional.fold. These methods only process 4D tensors, i.e. 2D images, but you can use them to process one dimension at a time.

A few things to note:

  1. This approach needs the fold/unfold methods from pytorch; unfortunately I have not yet found a similar method in TF.
  2. We go from 2D to 3D to 4D to show the incremental differences; you can extend this to arbitrarily many dimensions (probably by writing a loop instead of hard-coding each dimension as I did).
  3. We can extract the patches in two ways with identical output. The methods are called extract_patches_Xd and extract_patches_Xds, where X is the number of dimensions. The latter uses torch.Tensor.unfold() and needs fewer lines of code. (The output is the same, but dilation cannot be used.)
  4. Each extract_patches_Xd method has a matching combine_patches_Xd method; the combiner reverses the steps of the extractor one by one.
  5. The lines of code are followed by comments stating the dimensionality, such as (B, C, T, D, H, W). The following names are used:
1. `B`: Batch size
2. `C`: Channels
3. `T`: Time Dimension
4. `D`: Depth Dimension
5. `H`: Height Dimension
6. `W`: Width Dimension
7. `x_dim_in`: In the extraction method, this is the number of input pixels in dimension `x`. In the combining method, this is the number of sliding windows in dimension `x`.
8. `x_dim_out`: In the extraction method, this is the number of sliding windows in dimension `x`. In the combining method, this is the number of output pixels in dimension `x`.
  6. I have a public notebook for testing the code.
  7. I have tested it on basic 2D, 3D and 4D tensors as shown below. However, my code is not infallible, and I would appreciate feedback when it is tried on other inputs.
  8. The get_dim_blocks() method is the function given on the pytorch documentation site for computing the output shape of a convolutional layer.
  9. Note that if you have overlapping patches and combine them, the overlapping elements will be summed. If you want to recover the initial input, there is a way (sketched after the list below):
1. Create a tensor of ones the same size as the patches with `torch.ones_like(patches_tensor)`.
2. Combine the ones patches into a full image with the same output shape (this creates a counter for the overlapping elements).
3. Divide the combined image by the combined ones; this should reverse any double summation of elements.
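
A minimal sketch of that averaging trick, using the extract_patches_3d / combine_patches_3d functions defined further down (the tensor a, kernel_size, padding and stride are placeholders, and every voxel is assumed to be covered by at least one patch so the counter contains no zeros):

patches = extract_patches_3d(a, kernel_size, padding=padding, stride=stride)
ones = torch.ones_like(patches)    # step 1: same-sized patches of ones
summed = combine_patches_3d(patches, kernel_size, a.shape, padding=padding, stride=stride)
counts = combine_patches_3d(ones, kernel_size, a.shape, padding=padding, stride=stride)   # step 2: counter image
recovered = summed / counts        # step 3: undo the summation of overlapping elements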

First (2D):

The torch.nn.functional.fold and torch.nn.functional.unfold methods can be used directly:

import torch
def extract_patches_2ds(x, kernel_size, padding=0, stride=1):
    if isinstance(kernel_size, int):
        kernel_size = (kernel_size, kernel_size)
    if isinstance(padding, int):
        padding = (padding, padding, padding, padding)
    if isinstance(stride, int):
        stride = (stride, stride)

    channels = x.shape[1]

    x = torch.nn.functional.pad(x, padding)
    # (B, C, H, W)
    x = x.unfold(2, kernel_size[0], stride[0]).unfold(3, kernel_size[1], stride[1])
    # (B, C, h_dim_out, w_dim_out, kernel_size[0], kernel_size[1])
    x = x.contiguous().view(-1, channels, kernel_size[0], kernel_size[1])
    # (B * h_dim_out * w_dim_out, C, kernel_size[0], kernel_size[1])
    return x

def extract_patches_2d(x, kernel_size, padding=0, stride=1, dilation=1):
    if isinstance(kernel_size, int):
        kernel_size = (kernel_size, kernel_size)
    if isinstance(padding, int):
        padding = (padding, padding)
    if isinstance(stride, int):
        stride = (stride, stride)
    if isinstance(dilation, int):
        dilation = (dilation, dilation)

    def get_dim_blocks(dim_in, dim_kernel_size, dim_padding = 0, dim_stride = 1, dim_dilation = 1):
        dim_out = (dim_in + 2 * dim_padding - dim_dilation * (dim_kernel_size - 1) - 1) // dim_stride + 1
        return dim_out
        
    channels = x.shape[1]
    h_dim_in = x.shape[2]
    w_dim_in = x.shape[3]
    h_dim_out = get_dim_blocks(h_dim_in, kernel_size[0], padding[0], stride[0], dilation[0])
    w_dim_out = get_dim_blocks(w_dim_in, kernel_size[1], padding[1], stride[1], dilation[1])

    # (B, C, H, W)
    x = torch.nn.functional.unfold(x, kernel_size, padding=padding, stride=stride, dilation=dilation)
    # (B, C * kernel_size[0] * kernel_size[1], h_dim_out * w_dim_out)
    x = x.view(-1, channels, kernel_size[0], kernel_size[1], h_dim_out, w_dim_out)
    # (B, C, kernel_size[0], kernel_size[1], h_dim_out, w_dim_out)
    x = x.permute(0,1,4,5,2,3)
    # (B, C, h_dim_out, w_dim_out, kernel_size[0], kernel_size[1])
    x = x.contiguous().view(-1, channels, kernel_size[0], kernel_size[1])
    # (B * h_dim_out * w_dim_out, C, kernel_size[0], kernel_size[1])
    return x


def combine_patches_2d(x, kernel_size, output_shape, padding=0, stride=1, dilation=1):
    if isinstance(kernel_size, int):
        kernel_size = (kernel_size, kernel_size)
    if isinstance(padding, int):
        padding = (padding, padding)
    if isinstance(stride, int):
        stride = (stride, stride)
    if isinstance(dilation, int):
        dilation = (dilation, dilation)

    def get_dim_blocks(dim_in, dim_kernel_size, dim_padding = 0, dim_stride = 1, dim_dilation = 1):
        dim_out = (dim_in + 2 * dim_padding - dim_dilation * (dim_kernel_size - 1) - 1) // dim_stride + 1
        return dim_out

    channels = x.shape[1]
    h_dim_out, w_dim_out = output_shape[2:]
    h_dim_in = get_dim_blocks(h_dim_out, kernel_size[0], padding[0], stride[0], dilation[0])
    w_dim_in = get_dim_blocks(w_dim_out, kernel_size[1], padding[1], stride[1], dilation[1])

    # (B * h_dim_in * w_dim_in, C, kernel_size[0], kernel_size[1])
    x = x.view(-1, channels, h_dim_in, w_dim_in, kernel_size[0], kernel_size[1])
    # (B, C, h_dim_in, w_dim_in, kernel_size[0], kernel_size[1])
    x = x.permute(0,1,4,5,2,3)
    # (B, C, kernel_size[0], kernel_size[1], h_dim_in, w_dim_in)
    x = x.contiguous().view(-1, channels * kernel_size[0] * kernel_size[1], h_dim_in * w_dim_in)
    # (B, C * kernel_size[0] * kernel_size[1], h_dim_in * w_dim_in)
    x = torch.nn.functional.fold(x, (h_dim_out, w_dim_out), kernel_size=(kernel_size[0], kernel_size[1]), padding=padding, stride=stride, dilation=dilation)
    # (B, C, H, W)
    return x



a = torch.arange(1, 65, dtype=torch.float).view(2,2,4,4)
print(a.shape)
print(a)
b = extract_patches_2d(a, 2, padding=1, stride=2, dilation=1)
# b = extract_patches_2ds(a, 2, padding=1, stride=2)
print(b.shape)
print(b)
c = combine_patches_2d(b, 2, (2,2,4,4), padding=1, stride=2, dilation=1)
print(c.shape)
print(c)
print(torch.all(a==c))

Output (2D)

torch.Size([2, 2, 4, 4])
tensor([[[[ 1.,  2.,  3.,  4.],
          [ 5.,  6.,  7.,  8.],
          [ 9., 10., 11., 12.],
          [13., 14., 15., 16.]],

         [[17., 18., 19., 20.],
          [21., 22., 23., 24.],
          [25., 26., 27., 28.],
          [29., 30., 31., 32.]]],


        [[[33., 34., 35., 36.],
          [37., 38., 39., 40.],
          [41., 42., 43., 44.],
          [45., 46., 47., 48.]],

         [[49., 50., 51., 52.],
          [53., 54., 55., 56.],
          [57., 58., 59., 60.],
          [61., 62., 63., 64.]]]])
torch.Size([18, 2, 2, 2])
tensor([[[[ 0.,  0.],
          [ 0.,  1.]],

         [[ 0.,  0.],
          [ 2.,  3.]]],


        [[[ 0.,  0.],
          [ 4.,  0.]],

         [[ 0.,  5.],
          [ 0.,  9.]]],


        [[[ 6.,  7.],
          [10., 11.]],

         [[ 8.,  0.],
          [12.,  0.]]],


        [[[ 0., 13.],
          [ 0.,  0.]],

         [[14., 15.],
          [ 0.,  0.]]],


        [[[16.,  0.],
          [ 0.,  0.]],

         [[ 0.,  0.],
          [ 0., 17.]]],


        [[[ 0.,  0.],
          [18., 19.]],

         [[ 0.,  0.],
          [20.,  0.]]],


        [[[ 0., 21.],
          [ 0., 25.]],

         [[22., 23.],
          [26., 27.]]],


        [[[24.,  0.],
          [28.,  0.]],

         [[ 0., 29.],
          [ 0.,  0.]]],


        [[[30., 31.],
          [ 0.,  0.]],

         [[32.,  0.],
          [ 0.,  0.]]],


        [[[ 0.,  0.],
          [ 0., 33.]],

         [[ 0.,  0.],
          [34., 35.]]],


        [[[ 0.,  0.],
          [36.,  0.]],

         [[ 0., 37.],
          [ 0., 41.]]],


        [[[38., 39.],
          [42., 43.]],

         [[40.,  0.],
          [44.,  0.]]],


        [[[ 0., 45.],
          [ 0.,  0.]],

         [[46., 47.],
          [ 0.,  0.]]],


        [[[48.,  0.],
          [ 0.,  0.]],

         [[ 0.,  0.],
          [ 0., 49.]]],


        [[[ 0.,  0.],
          [50., 51.]],

         [[ 0.,  0.],
          [52.,  0.]]],


        [[[ 0., 53.],
          [ 0., 57.]],

         [[54., 55.],
          [58., 59.]]],


        [[[56.,  0.],
          [60.,  0.]],

         [[ 0., 61.],
          [ 0.,  0.]]],


        [[[62., 63.],
          [ 0.,  0.]],

         [[64.,  0.],
          [ 0.,  0.]]]])
torch.Size([2, 2, 4, 4])
tensor([[[[ 1.,  2.,  3.,  4.],
          [ 5.,  6.,  7.,  8.],
          [ 9., 10., 11., 12.],
          [13., 14., 15., 16.]],

         [[17., 18., 19., 20.],
          [21., 22., 23., 24.],
          [25., 26., 27., 28.],
          [29., 30., 31., 32.]]],


        [[[33., 34., 35., 36.],
          [37., 38., 39., 40.],
          [41., 42., 43., 44.],
          [45., 46., 47., 48.]],

         [[49., 50., 51., 52.],
          [53., 54., 55., 56.],
          [57., 58., 59., 60.],
          [61., 62., 63., 64.]]]])
tensor(True)

Second (3D):

Now it gets interesting: we need two unfold/fold passes. We first apply unfold to the D dimension only, keeping H and W untouched by setting the kernel to 1, the padding to 0, the stride to 1 and the dilation to 1 for them. After reshaping the tensor, we unfold over the H and W dimensions. Folding back happens in reverse order: first H and W, then D.

def extract_patches_3ds(x, kernel_size, padding=0, stride=1):
    if isinstance(kernel_size, int):
        kernel_size = (kernel_size, kernel_size, kernel_size)
    if isinstance(padding, int):
        padding = (padding, padding, padding, padding, padding, padding)
    if isinstance(stride, int):
        stride = (stride, stride, stride)

    channels = x.shape[1]

    x = torch.nn.functional.pad(x, padding)
    # (B, C, D, H, W)
    x = x.unfold(2, kernel_size[0], stride[0]).unfold(3, kernel_size[1], stride[1]).unfold(4, kernel_size[2], stride[2])
    # (B, C, d_dim_out, h_dim_out, w_dim_out, kernel_size[0], kernel_size[1], kernel_size[2])
    x = x.contiguous().view(-1, channels, kernel_size[0], kernel_size[1], kernel_size[2])
    # (B * d_dim_out * h_dim_out * w_dim_out, C, kernel_size[0], kernel_size[1], kernel_size[2])
    return x

def extract_patches_3d(x, kernel_size, padding=0, stride=1, dilation=1):
    if isinstance(kernel_size, int):
        kernel_size = (kernel_size, kernel_size, kernel_size)
    if isinstance(padding, int):
        padding = (padding, padding, padding)
    if isinstance(stride, int):
        stride = (stride, stride, stride)
    if isinstance(dilation, int):
        dilation = (dilation, dilation, dilation)

    def get_dim_blocks(dim_in, dim_kernel_size, dim_padding = 0, dim_stride = 1, dim_dilation = 1):
        dim_out = (dim_in + 2 * dim_padding - dim_dilation * (dim_kernel_size - 1) - 1) // dim_stride + 1
        return dim_out

    channels = x.shape[1]

    d_dim_in = x.shape[2]
    h_dim_in = x.shape[3]
    w_dim_in = x.shape[4]
    d_dim_out = get_dim_blocks(d_dim_in, kernel_size[0], padding[0], stride[0], dilation[0])
    h_dim_out = get_dim_blocks(h_dim_in, kernel_size[1], padding[1], stride[1], dilation[1])
    w_dim_out = get_dim_blocks(w_dim_in, kernel_size[2], padding[2], stride[2], dilation[2])
    # print(d_dim_in, h_dim_in, w_dim_in, d_dim_out, h_dim_out, w_dim_out)
    
    # (B, C, D, H, W)
    x = x.view(-1, channels, d_dim_in, h_dim_in * w_dim_in)                                                     
    # (B, C, D, H * W)

    x = torch.nn.functional.unfold(x, kernel_size=(kernel_size[0], 1), padding=(padding[0], 0), stride=(stride[0], 1), dilation=(dilation[0], 1))                   
    # (B, C * kernel_size[0], d_dim_out * H * W)

    x = x.view(-1, channels * kernel_size[0] * d_dim_out, h_dim_in, w_dim_in)                                   
    # (B, C * kernel_size[0] * d_dim_out, H, W)

    x = torch.nn.functional.unfold(x, kernel_size=(kernel_size[1], kernel_size[2]), padding=(padding[1], padding[2]), stride=(stride[1], stride[2]), dilation=(dilation[1], dilation[2]))        
    # (B, C * kernel_size[0] * d_dim_out * kernel_size[1] * kernel_size[2], h_dim_out, w_dim_out)

    x = x.view(-1, channels, kernel_size[0], d_dim_out, kernel_size[1], kernel_size[2], h_dim_out, w_dim_out)  
    # (B, C, kernel_size[0], d_dim_out, kernel_size[1], kernel_size[2], h_dim_out, w_dim_out)  

    x = x.permute(0, 1, 3, 6, 7, 2, 4, 5)
    # (B, C, d_dim_out, h_dim_out, w_dim_out, kernel_size[0], kernel_size[1], kernel_size[2])

    x = x.contiguous().view(-1, channels, kernel_size[0], kernel_size[1], kernel_size[2])
    # (B * d_dim_out * h_dim_out * w_dim_out, C, kernel_size[0], kernel_size[1], kernel_size[2])

    return x



def combine_patches_3d(x, kernel_size, output_shape, padding=0, stride=1, dilation=1):
    if isinstance(kernel_size, int):
        kernel_size = (kernel_size, kernel_size, kernel_size)
    if isinstance(padding, int):
        padding = (padding, padding, padding)
    if isinstance(stride, int):
        stride = (stride, stride, stride)
    if isinstance(dilation, int):
        dilation = (dilation, dilation, dilation)

    def get_dim_blocks(dim_in, dim_kernel_size, dim_padding = 0, dim_stride = 1, dim_dilation = 1):
        dim_out = (dim_in + 2 * dim_padding - dim_dilation * (dim_kernel_size - 1) - 1) // dim_stride + 1
        return dim_out

    channels = x.shape[1]
    d_dim_out, h_dim_out, w_dim_out = output_shape[2:]
    d_dim_in = get_dim_blocks(d_dim_out, kernel_size[0], padding[0], stride[0], dilation[0])
    h_dim_in = get_dim_blocks(h_dim_out, kernel_size[1], padding[1], stride[1], dilation[1])
    w_dim_in = get_dim_blocks(w_dim_out, kernel_size[2], padding[2], stride[2], dilation[2])
    # print(d_dim_in, h_dim_in, w_dim_in, d_dim_out, h_dim_out, w_dim_out)

    x = x.view(-1, channels, d_dim_in, h_dim_in, w_dim_in, kernel_size[0], kernel_size[1], kernel_size[2])
    # (B, C, d_dim_in, h_dim_in, w_dim_in, kernel_size[0], kernel_size[1], kernel_size[2])

    x = x.permute(0, 1, 5, 2, 6, 7, 3, 4)
    # (B, C, kernel_size[0], d_dim_in, kernel_size[1], kernel_size[2], h_dim_in, w_dim_in)

    x = x.contiguous().view(-1, channels * kernel_size[0] * d_dim_in * kernel_size[1] * kernel_size[2], h_dim_in * w_dim_in)
    # (B, C * kernel_size[0] * d_dim_in * kernel_size[1] * kernel_size[2], h_dim_in * w_dim_in)

    x = torch.nn.functional.fold(x, output_size=(h_dim_out, w_dim_out), kernel_size=(kernel_size[1], kernel_size[2]), padding=(padding[1], padding[2]), stride=(stride[1], stride[2]), dilation=(dilation[1], dilation[2]))
    # (B, C * kernel_size[0] * d_dim_in, H, W)

    x = x.view(-1, channels * kernel_size[0], d_dim_in * h_dim_out * w_dim_out)
    # (B, C * kernel_size[0], d_dim_in * H * W)

    x = torch.nn.functional.fold(x, output_size=(d_dim_out, h_dim_out * w_dim_out), kernel_size=(kernel_size[0], 1), padding=(padding[0], 0), stride=(stride[0], 1), dilation=(dilation[0], 1))
    # (B, C, D, H * W)
    
    x = x.view(-1, channels, d_dim_out, h_dim_out, w_dim_out)
    # (B, C, D, H, W)

    return x

a = torch.arange(1, 129, dtype=torch.float).view(2,2,2,4,4)
print(a.shape)
print(a)
# b = extract_patches_3d(a, 2, padding=1, stride=2)
b = extract_patches_3ds(a, 2, padding=1, stride=2)
print(b.shape)
print(b)
c = combine_patches_3d(b, 2, (2,2,2,4,4), padding=1, stride=2)
print(c.shape)
print(c)
print(torch.all(a==c))

Output (3D)

(I had to limit the number of characters; please see the notebook.)

Third (4D)

We add a time dimension to the 3D volume. We start by unfolding only the T dimension, keeping D, H and W untouched, analogous to the 3D version. Then we unfold D, leaving H and W untouched, and finally we unfold H and W together. Combining happens in reverse order. Hopefully you notice the pattern by now: you can add arbitrarily many dimensions and unfold them one by one, then fold them back in reverse.

def extract_patches_4ds(x, kernel_size, padding=0, stride=1):
    if isinstance(kernel_size, int):
        kernel_size = (kernel_size, kernel_size, kernel_size, kernel_size)
    if isinstance(padding, int):
        padding = (padding, padding, padding, padding, padding, padding, padding, padding)
    if isinstance(stride, int):
        stride = (stride, stride, stride, stride)

    channels = x.shape[1]

    x = torch.nn.functional.pad(x, padding)
    # (B, C, T, D, H, W)
    x = x.unfold(2, kernel_size[0], stride[0]).unfold(3, kernel_size[1], stride[1]).unfold(4, kernel_size[2], stride[2]).unfold(5, kernel_size[3], stride[3])
    # (B, C, t_dim_out, d_dim_out, h_dim_out, w_dim_out, kernel_size[0], kernel_size[1], kernel_size[2], kernel_size[3])
    x = x.contiguous().view(-1, channels, kernel_size[0], kernel_size[1], kernel_size[2], kernel_size[3])
    # (B * t_dim_out * d_dim_out * h_dim_out * w_dim_out, C, kernel_size[0], kernel_size[1], kernel_size[2], kernel_size[3])
    return x

def extract_patches_4d(x, kernel_size, padding=0, stride=1, dilation=1):
    if isinstance(kernel_size, int):
        kernel_size = (kernel_size, kernel_size, kernel_size, kernel_size)
    if isinstance(padding, int):
        padding = (padding, padding, padding, padding)
    if isinstance(stride, int):
        stride = (stride, stride, stride, stride)
    if isinstance(dilation, int):
        dilation = (dilation, dilation, dilation, dilation)

    def get_dim_blocks(dim_in, dim_kernel_size, dim_padding = 0, dim_stride = 1, dim_dilation = 1):
        dim_out = (dim_in + 2 * dim_padding - dim_dilation * (dim_kernel_size - 1) - 1) // dim_stride + 1
        return dim_out

    channels = x.shape[1]

    t_dim_in = x.shape[2]
    d_dim_in = x.shape[3]
    h_dim_in = x.shape[4]
    w_dim_in = x.shape[5]
    t_dim_out = get_dim_blocks(t_dim_in, kernel_size[0], padding[0], stride[0], dilation[0])
    d_dim_out = get_dim_blocks(d_dim_in, kernel_size[1], padding[1], stride[1], dilation[1])
    h_dim_out = get_dim_blocks(h_dim_in, kernel_size[2], padding[2], stride[2], dilation[2])
    w_dim_out = get_dim_blocks(w_dim_in, kernel_size[3], padding[3], stride[3], dilation[3])
    # print(t_dim_in, d_dim_in, h_dim_in, w_dim_in, t_dim_out, d_dim_out, h_dim_out, w_dim_out)
    
    # (B, C, T, D, H, W)
    x = x.view(-1, channels, t_dim_in, d_dim_in * h_dim_in * w_dim_in)                                                     
    # (B, C, T, D * H * W)

    x = torch.nn.functional.unfold(x, kernel_size=(kernel_size[0], 1), padding=(padding[0], 0), stride=(stride[0], 1), dilation=(dilation[0], 1))
    # (B, C * kernel_size[0], t_dim_out * D * H * W)

    x = x.view(-1, channels * kernel_size[0] * t_dim_out, d_dim_in, h_dim_in * w_dim_in)                                   
    # (B, C * kernel_size[0] * t_dim_out, D, H * W)

    x = torch.nn.functional.unfold(x, kernel_size=(kernel_size[1], 1), padding=(padding[1], 0), stride=(stride[1], 1), dilation=(dilation[1], 1))
    # (B, C * kernel_size[0] * t_dim_out * kernel_size[1], d_dim_out * H * W)

    x = x.view(-1, channels * kernel_size[0] * t_dim_out * kernel_size[1] * d_dim_out, h_dim_in, w_dim_in)
    # (B, C * kernel_size[0] * t_dim_out * kernel_size[1] * d_dim_out, H, W)

    x = torch.nn.functional.unfold(x, kernel_size=(kernel_size[2], kernel_size[3]), padding=(padding[2], padding[3]), stride=(stride[2], stride[3]), dilation=(dilation[2], dilation[3]))        
    # (B, C * kernel_size[0] * t_dim_out * kernel_size[1] * d_dim_out * kernel_size[2] * kernel_size[3], h_dim_out * w_dim_out)

    x = x.view(-1, channels, kernel_size[0], t_dim_out, kernel_size[1], d_dim_out, kernel_size[2], kernel_size[3], h_dim_out, w_dim_out)
    # (B, C, kernel_size[0], t_dim_out, kernel_size[1], d_dim_out, kernel_size[2], kernel_size[3], h_dim_out, w_dim_out)

    x = x.permute(0, 1, 3, 5, 8, 9, 2, 4, 6, 7)
    # (B, C, t_dim_out, d_dim_out, h_dim_out, w_dim_out, kernel_size[0], kernel_size[1], kernel_size[2], kernel_size[3])

    x = x.contiguous().view(-1, channels, kernel_size[0], kernel_size[1], kernel_size[2], kernel_size[3])
    # (B * t_dim_out * d_dim_out * h_dim_out * w_dim_out, C, kernel_size[0], kernel_size[1], kernel_size[2], kernel_size[3])

    return x



def combine_patches_4d(x, kernel_size, output_shape, padding=0, stride=1, dilation=1):
    if isinstance(kernel_size, int):
        kernel_size = (kernel_size, kernel_size, kernel_size, kernel_size)
    if isinstance(padding, int):
        padding = (padding, padding, padding, padding)
    if isinstance(stride, int):
        stride = (stride, stride, stride, stride)
    if isinstance(dilation, int):
        dilation = (dilation, dilation, dilation, dilation)

    def get_dim_blocks(dim_in, dim_kernel_size, dim_padding = 0, dim_stride = 1, dim_dilation = 1):
        dim_out = (dim_in + 2 * dim_padding - dim_dilation * (dim_kernel_size - 1) - 1) // dim_stride + 1
        return dim_out

    channels = x.shape[1]
    t_dim_out, d_dim_out, h_dim_out, w_dim_out = output_shape[2:]
    t_dim_in = get_dim_blocks(t_dim_out, kernel_size[0], padding[0], stride[0], dilation[0])
    d_dim_in = get_dim_blocks(d_dim_out, kernel_size[1], padding[1], stride[1], dilation[1])
    h_dim_in = get_dim_blocks(h_dim_out, kernel_size[2], padding[2], stride[2], dilation[2])
    w_dim_in = get_dim_blocks(w_dim_out, kernel_size[3], padding[3], stride[3], dilation[3])
    # print(t_dim_in, d_dim_in, h_dim_in, w_dim_in, t_dim_out, d_dim_out, h_dim_out, w_dim_out)

    x = x.view(-1, channels, t_dim_in, d_dim_in, h_dim_in, w_dim_in, kernel_size[0], kernel_size[1], kernel_size[2], kernel_size[3])
    # (B, C, t_dim_in, d_dim_in, h_dim_in, w_dim_in, kernel_size[0], kernel_size[1], kernel_size[2], kernel_size[3])

    x = x.permute(0, 1, 6, 2, 7, 3, 8, 9, 4, 5)
    # (B, C, kernel_size[0], t_dim_in, kernel_size[1], d_dim_in, kernel_size[2], kernel_size[3], h_dim_in, w_dim_in)

    x = x.contiguous().view(-1, channels * kernel_size[0] * t_dim_in * kernel_size[1] * d_dim_in * kernel_size[2] * kernel_size[3], h_dim_in * w_dim_in)
    # (B, C * kernel_size[0] * t_dim_in * kernel_size[1] * d_dim_in * kernel_size[2] * kernel_size[3], h_dim_in, w_dim_in)

    x = torch.nn.functional.fold(x, output_size=(h_dim_out, w_dim_out), kernel_size=(kernel_size[2], kernel_size[3]), padding=(padding[2], padding[3]), stride=(stride[2], stride[3]), dilation=(dilation[2], dilation[3]))
    # (B, C * kernel_size[0] * t_dim_in * kernel_size[1] * d_dim_in, H, W)

    x = x.view(-1, channels * kernel_size[0] * t_dim_in * kernel_size[1], d_dim_in * h_dim_out * w_dim_out)
    # (B, C * kernel_size[0] * t_dim_in * kernel_size[1], d_dim_in * H * W)

    x = torch.nn.functional.fold(x, output_size=(d_dim_out, h_dim_out * w_dim_out), kernel_size=(kernel_size[1], 1), padding=(padding[1], 0), stride=(stride[1], 1), dilation=(dilation[1], 1))
    # (B, C * kernel_size[0] * t_dim_in, D, H * W)

    x = x.view(-1, channels * kernel_size[0], t_dim_in * d_dim_out * h_dim_out * w_dim_out)
    # (B, C * kernel_size[0], t_dim_in * D * H * W)

    x = torch.nn.functional.fold(x, output_size=(t_dim_out, d_dim_out * h_dim_out * w_dim_out), kernel_size=(kernel_size[0], 1), padding=(padding[0], 0), stride=(stride[0], 1), dilation=(dilation[0], 1))
    # (B, C, T, D * H * W)
    
    x = x.view(-1, channels, t_dim_out, d_dim_out, h_dim_out, w_dim_out)
    # (B, C, T, D, H, W)

    return x

a = torch.arange(1, 129, dtype=torch.float).view(2,2,2,2,4,2)
print(a.shape)
print(a)
# b = extract_patches_4d(a, 2, padding=1, stride=2)
b = extract_patches_4ds(a, 2, padding=1, stride=2)
print(b.shape)
print(b)
c = combine_patches_4d(b, 2, (2,2,2,2,4,2), padding=1, stride=2)
print(c.shape)
print(c)
print(torch.all(a==c))

Output (4D)

(I had to limit the number of characters; please see the notebook.)

2 votes

Stack Overflow user

Posted on 2021-07-13 09:42:27

The solution above copies data in memory, because it keeps the patches contiguous. This leads to memory problems for volumes with many overlapping voxels. To extract the patches without copying in memory, we can do the following in pytorch:

def get_dim_blocks(dim_in, kernel_size, padding=0, stride=1, dilation=1):
    return (dim_in + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1

def extract_patches_3d(x, kernel_size, stride=1, dilation=1):
    if isinstance(kernel_size, int):
        kernel_size = (kernel_size, kernel_size, kernel_size)
    if isinstance(stride, int):
        stride = (stride, stride, stride)
    if isinstance(dilation, int):
        dilation = (dilation, dilation, dilation)
    x = x.contiguous()  # as_strided below assumes the standard contiguous layout

    channels, depth, height, width = x.shape[-4:]
    d_blocks = get_dim_blocks(depth, kernel_size=kernel_size[0], stride=stride[0], dilation=dilation[0])
    h_blocks = get_dim_blocks(height, kernel_size=kernel_size[1], stride=stride[1], dilation=dilation[1])
    w_blocks = get_dim_blocks(width, kernel_size=kernel_size[2], stride=stride[2], dilation=dilation[2])
    # view shape: (C, d_blocks, h_blocks, w_blocks, kD, kH, kW)
    shape = (channels, d_blocks, h_blocks, w_blocks, kernel_size[0], kernel_size[1], kernel_size[2])
    # element strides into the original storage: stepping to the next block moves by `stride`,
    # stepping inside a block moves by `dilation`
    strides = (width*height*depth,
               stride[0]*width*height, 
               stride[1]*width, 
               stride[2], 
               dilation[0]*width*height, 
               dilation[1]*width,
               dilation[2])

    x = x.as_strided(shape, strides)   # a view into the same storage, no data is copied
    x = x.permute(1,2,3,0,4,5,6)       # (d_blocks, h_blocks, w_blocks, C, kD, kH, kW)
    return x

The method expects a tensor of shape `(B, C, D, H, W)`. It is based on this answer, which explains the role of memory strides in more detail. The output is not contiguous; its first three dimensions are the number of blocks or sliding windows in the D, H and W dimensions. Merging them into one dimension is not possible, because that would require a copy into contiguous memory.
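
To illustrate that last point, a minimal sketch (sizes are only illustrative): flattening the sliding-window dimensions into one forces a copy via .contiguous(), which is exactly what this method avoids.

vol = torch.arange(81, dtype=torch.float32).view(1, 3, 3, 3, 3)
p = extract_patches_3d(vol, kernel_size=2, stride=2)   # non-contiguous view into vol's storage
flat = p.contiguous().view(-1, *p.shape[-4:])          # copy into contiguous memory, then flatten the block dims
print(flat.data_ptr() == vol.data_ptr())               # False: new memory was allocated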

Strided test

a = torch.arange(81, dtype=torch.float32).view(1,3,3,3,3)
print(a)
b = extract_patches_3d(a, kernel_size=2, stride=2)
print(b.shape)
print(b.storage())
print(a.data_ptr() == b.data_ptr())
print(b)

Output

tensor([[[[[ 0.,  1.,  2.],
           [ 3.,  4.,  5.],
           [ 6.,  7.,  8.]],

          [[ 9., 10., 11.],
           [12., 13., 14.],
           [15., 16., 17.]],

          [[18., 19., 20.],
           [21., 22., 23.],
           [24., 25., 26.]]],


         [[[27., 28., 29.],
           [30., 31., 32.],
           [33., 34., 35.]],

          [[36., 37., 38.],
           [39., 40., 41.],
           [42., 43., 44.]],

          [[45., 46., 47.],
           [48., 49., 50.],
           [51., 52., 53.]]],


         [[[54., 55., 56.],
           [57., 58., 59.],
           [60., 61., 62.]],

          [[63., 64., 65.],
           [66., 67., 68.],
           [69., 70., 71.]],

          [[72., 73., 74.],
           [75., 76., 77.],
           [78., 79., 80.]]]]])
torch.Size([1, 1, 1, 3, 2, 2, 2])
 0.0
 1.0
 2.0
 3.0
 4.0
 5.0
 6.0
 7.0
 8.0
 9.0
 10.0
 11.0
 12.0
 13.0
 14.0
 15.0
 16.0
 17.0
 18.0
 19.0
 20.0
 21.0
 22.0
 23.0
 24.0
 25.0
 26.0
 27.0
 28.0
 29.0
 30.0
 31.0
 32.0
 33.0
 34.0
 35.0
 36.0
 37.0
 38.0
 39.0
 40.0
 41.0
 42.0
 43.0
 44.0
 45.0
 46.0
 47.0
 48.0
 49.0
 50.0
 51.0
 52.0
 53.0
 54.0
 55.0
 56.0
 57.0
 58.0
 59.0
 60.0
 61.0
 62.0
 63.0
 64.0
 65.0
 66.0
 67.0
 68.0
 69.0
 70.0
 71.0
 72.0
 73.0
 74.0
 75.0
 76.0
 77.0
 78.0
 79.0
 80.0
[torch.FloatStorage of size 81]
True
tensor([[[[[[[ 0.,  1.],
             [ 3.,  4.]],

            [[ 9., 10.],
             [12., 13.]]],


           [[[27., 28.],
             [30., 31.]],

            [[36., 37.],
             [39., 40.]]],


           [[[54., 55.],
             [57., 58.]],

            [[63., 64.],
             [66., 67.]]]]]]])

Assuming the tensor is contiguous (as it would be after processing in a NN), summing the overlapping voxels with memory strides alone is not possible. However, you can sum them manually as explained above (a sketch follows), or do it with slicing as explained here.
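
A minimal sketch of the manual summation route, assuming patches shaped (d_blocks, h_blocks, w_blocks, C, kD, kH, kW) as produced by the extract_patches_3d above, dilation 1 and the batch dimension already dropped; overlapping voxels are summed exactly as fold would sum them:

def sum_patches_3d(patches, output_shape, stride=1):
    # patches: (d_blocks, h_blocks, w_blocks, C, kD, kH, kW); output_shape: (C, D, H, W)
    if isinstance(stride, int):
        stride = (stride, stride, stride)
    out = torch.zeros(output_shape, dtype=patches.dtype, device=patches.device)
    d_blocks, h_blocks, w_blocks, _, kd, kh, kw = patches.shape
    for i in range(d_blocks):
        for j in range(h_blocks):
            for k in range(w_blocks):
                d0, h0, w0 = i * stride[0], j * stride[1], k * stride[2]
                out[:, d0:d0 + kd, h0:h0 + kh, w0:w0 + kw] += patches[i, j, k]
    return out

# e.g. with the strided test above: sum_patches_3d(b, output_shape=(3, 3, 3, 3), stride=2)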

1 vote
The original content of this page is provided by Stack Overflow.
Original link:

https://stackoverflow.com/questions/68150248
