首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >问答首页 >如何让这个PyTorch张量(B,C,H,W)平铺和混合代码更简单、更高效?

如何让这个PyTorch张量(B,C,H,W)平铺和混合代码更简单、更高效?
EN

Stack Overflow用户
提问于 2020-10-14 00:29:56
回答 1查看 337关注 0票数 1

所以,我在几个月前写了下面的代码,它运行得很好。虽然我正在努力研究如何简化它并使其更有效率。

下面的函数将图像张量(B,C,H,W)拆分成相等大小的瓦片(B,C,H,W),然后您可以单独对瓦片进行填充,以节省内存。然后,当从瓷砖重建张量时,它使用掩码来确保瓷砖无缝地混合在一起。掩码函数中的“特殊蒙版”处理最右侧列中的瓦片或底部行中的瓦片不能使用与其他瓦片相同的重叠。这意味着右边缘平铺和底部平铺有时可能几乎看不到它们的内容。这样做是为了确保平铺总是精确的指定大小,而不管原始图像/张量的大小(对于可视化/DeepDream、神经样式转移等很重要)。与边缘行/列相邻的行/列也具有特殊的掩码,用于它们与边缘行/列重叠的位置。

每个瓷砖有8个可能的蒙版,其中4个蒙版可以同时使用。4个可能的蒙版是左、右、上和下,每个蒙版都有一个特殊的版本。

代码语言:javascript
运行
复制
# Improved version of: https://github.com/ProGamerGov/neural-dream/blob/master/neural_dream/dream_tile.py
import torch


# Apply blend masks to tiles
def mask_tile(tile, overlap, side='bottom'):
    c, h, w = tile.size(1), tile.size(2), tile.size(3)
    top_overlap, bottom_overlap, right_overlap, left_overlap = overlap[0], overlap[1], overlap[2], overlap[3]

    base_mask = torch.ones_like(tile)

    if 'left' in side and 'left-special' not in side:
        lin_mask_left = torch.linspace(0,1,left_overlap, device=tile.device).repeat(h,1).repeat(c,1,1).unsqueeze(0)
        base_mask[:,:,:,:left_overlap] = base_mask[:,:,:,:left_overlap] * lin_mask_left
    if 'right' in side and 'right-special' not in side:
        lin_mask_right = torch.linspace(1,0,right_overlap, device=tile.device).repeat(h,1).repeat(c,1,1).unsqueeze(0)
        base_mask[:,:,:,w-right_overlap:] = base_mask[:,:,:,w-right_overlap:] * lin_mask_right
    if 'top' in side and 'top-special' not in side:
        lin_mask_top = torch.linspace(0,1,top_overlap, device=tile.device).repeat(w,1).rot90(3).repeat(c,1,1).unsqueeze(0)
        base_mask[:,:,:top_overlap,:] = base_mask[:,:,:top_overlap,:] * lin_mask_top
    if 'bottom' in side and 'bottom-special' not in side:
        lin_mask_bottom = torch.linspace(1,0,bottom_overlap, device=tile.device).repeat(w,1).rot90(3).repeat(c,1,1).unsqueeze(0)
        base_mask[:,:,h-bottom_overlap:,:] = base_mask[:,:,h-bottom_overlap:,:] * lin_mask_bottom

    if 'left-special' in side:
        lin_mask_left = torch.linspace(0,1,left_overlap, device=tile.device)
        zeros_mask = torch.zeros(w-(left_overlap*2), device=tile.device)
        ones_mask = torch.ones(left_overlap, device=tile.device)
        lin_mask_left = torch.cat([zeros_mask, lin_mask_left, ones_mask], 0).repeat(h,1).repeat(c,1,1).unsqueeze(0)
        base_mask = base_mask * lin_mask_left
    if 'right-special' in side:
        lin_mask_right = torch.linspace(1,0,right_overlap, device=tile.device)
        ones_mask = torch.ones(w-right_overlap, device=tile.device)
        lin_mask_right = torch.cat([ones_mask, lin_mask_right], 0).repeat(h,1).repeat(c,1,1).unsqueeze(0)
        base_mask = base_mask * lin_mask_right
    if 'top-special' in side:
        lin_mask_top = torch.linspace(0,1,top_overlap, device=tile.device)
        zeros_mask = torch.zeros(h-(top_overlap*2), device=tile.device)
        ones_mask = torch.ones(top_overlap, device=tile.device)
        lin_mask_top = torch.cat([zeros_mask, lin_mask_top, ones_mask], 0).repeat(w,1).rot90(3).repeat(c,1,1).unsqueeze(0)
        base_mask = base_mask * lin_mask_top
    if 'bottom-special' in side:
        lin_mask_bottom = torch.linspace(1,0,bottom_overlap, device=tile.device)
        ones_mask = torch.ones(h-bottom_overlap, device=tile.device)
        lin_mask_bottom = torch.cat([ones_mask, lin_mask_bottom], 0).repeat(w,1).rot90(3).repeat(c,1,1).unsqueeze(0)
        base_mask = base_mask * lin_mask_bottom
        
    # Apply mask to tile and return masked tile
    return tile * base_mask


def add_tiles(tiles, base_img, tile_coords, tile_size, overlap):

    # Check for any tiles that need different overlap values
    r, c = len(tile_coords[0]), len(tile_coords[1])
    f_ovlp = (tile_coords[0][r-1] - tile_coords[0][r-2], tile_coords[1][c-1] - tile_coords[1][c-2])

    h, w = tiles[0].size(2), tiles[0].size(3)
    t=0
    column, row, = 0, 0
    for y in tile_coords[0]:
        for x in tile_coords[1]:
            mask_sides=''
            c_overlap = overlap.copy()
            if row == 0:
                if row == len(tile_coords[0]) - 2:
                    mask_sides += 'bottom-special'
                    c_overlap[1] = f_ovlp[0] # Change bottom overlap
                else:
                    mask_sides += 'bottom'
            elif row > 0 and row < len(tile_coords[0]) -2:
                mask_sides += 'bottom,top'
            elif row == len(tile_coords[0]) - 2:
                if f_ovlp[0] > 0:
                    mask_sides += 'bottom-special,top'
                    c_overlap[1] = f_ovlp[0] # Change bottom overlap
                elif f_ovlp[0] <= 0:
                    mask_sides += 'bottom,top'
            elif row == len(tile_coords[0]) -1:
                if f_ovlp[0] > 0:
                    mask_sides += 'top-special'
                    c_overlap[0] = f_ovlp[0] # Change top overlap
                elif f_ovlp[0] <= 0:
                    mask_sides += 'top'

            if column == 0:
                if column == len(tile_coords[1]) -2:
                    mask_sides += ',right-special'
                    c_overlap[2] = f_ovlp[1] # Change right overlap
                else:
                    mask_sides += ',right'
            elif column > 0 and column < len(tile_coords[1]) -2:
                mask_sides += ',right,left'
            elif column == len(tile_coords[1]) -2:
                if f_ovlp[1] > 0:
                    mask_sides += ',right-special,left'
                    c_overlap[2] = f_ovlp[1] # Change right overlap
                elif f_ovlp[1] <= 0:
                    mask_sides += ',right,left'
            elif column == len(tile_coords[1]) -1:
                if f_ovlp[1] > 0:
                    mask_sides += ',left-special'
                    c_overlap[3] = f_ovlp[1] # Change left overlap
                elif f_ovlp[1] <= 0:
                    mask_sides += ',left'

            tile = mask_tile(tiles[t], c_overlap, side=mask_sides)
            base_img[:, :, y:y+tile_size[0], x:x+tile_size[1]] = base_img[:, :, y:y+tile_size[0], x:x+tile_size[1]] + tile
            t+=1
            column+=1
        row+=1
        column=0
    return base_img


# Calculate the coordinates for tiles
def get_tile_coords(d, tile_dim, overlap=0):
    move = int(tile_dim * (1-overlap))
    c, tile_start, coords = 1, 0, [0]
    while tile_start + tile_dim < d:
        tile_start = move * c
        if tile_start + tile_dim >= d:
            coords.append(d - tile_dim)
        else:
            coords.append(tile_start)
        c += 1
    return coords


# Calculates info required for tiling
def tile_setup(tile_size, overlap_percent, base_size):
    if type(tile_size) is not tuple and type(tile_size) is not list:
        tile_size = (tile_size, tile_size)
    if type(overlap_percent) is not tuple and type(overlap_percent) is not list:
        overlap_percent = (overlap_percent, overlap_percent)
    x_coords = get_tile_coords(base_size[1], tile_size[1], overlap_percent[1])
    y_coords = get_tile_coords(base_size[0], tile_size[0], overlap_percent[0])
    y_ovlp, x_ovlp = int(tile_size[0] * overlap_percent[0]), int(tile_size[1] * overlap_percent[1])
    return (y_coords, x_coords), tile_size, [y_ovlp, y_ovlp, x_ovlp, x_ovlp]


# Split tensor into tiles
def tile_image(img, tile_size, overlap_percent, info_only=False):
    tile_coords, tile_size, _ = tile_setup(tile_size, overlap_percent, (img.size(2), img.size(3)))

    # Cut out tiles
    tile_list = []
    for y in tile_coords[0]:
        for x in tile_coords[1]:
            tile = img[:, :, y:y + tile_size[0], x:x + tile_size[1]]
            tile_list.append(tile)
    return tile_list


# Put tiles back into the original tensor
def rebuild_image(tiles, image_size, tile_size, overlap_percent):
    base_img = torch.zeros(image_size, device=tiles[0].device)
    tile_coords, tile_size, overlap = tile_setup(tile_size, overlap_percent, (base_img.size(2), base_img.size(3)))
    return add_tiles(tiles, base_img, tile_coords, tile_size, overlap)

上面的代码可以用下面的代码测试:

代码语言:javascript
运行
复制
import torchvision.transforms as transforms
from PIL import Image
import random

# Load image
def preprocess_simple(image_name, image_size):
    Loader = transforms.Compose([transforms.Resize(image_size), transforms.ToTensor()])
    image = Image.open(image_name).convert('RGB')
    return Loader(image).unsqueeze(0)
    
# Save image   
def deprocess_simple(output_tensor, output_name):
    output_tensor.clamp_(0, 1)
    Image2PIL = transforms.ToPILImage()
    image = Image2PIL(output_tensor.squeeze(0))
    image.save(output_name)    

test_input = preprocess_simple('tubingen.jpg', (1024,1024))
tile_size=260
overlap_percent=0.5

img_tiles = tile_image(test_input, tile_size=tile_size, overlap_percent=overlap_percent)

random.shuffle(img_tiles) # Comment this out to not randomize tile positions

output_tensor = rebuild_image(img_tiles, test_input.size(), tile_size=tile_size, overlap_percent=overlap_percent)
deprocess_simple(output_tensor, 'tiled_image.jpg')

我在下面包含了一个示例(顶部是原始图像,底部是当我将瓷砖随机放回以展示混合系统时):

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2020-10-20 03:39:17

我在这里删除了所有的bug并简化了代码:https://github.com/ProGamerGov/dream-creator/blob/master/utils/tile_utils.py

特殊的口罩只在两种情况下才需要,它们是rebuild_tensor中的bug,我必须修复它们。重叠百分比应等于或小于50%。

票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/64339360

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档