首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >问答首页 >分割任务中图像对的Pytorch transforms.Compose用法

分割任务中图像对的Pytorch transforms.Compose用法
EN

Stack Overflow用户
提问于 2021-02-20 04:52:30
回答 2查看 1.1K关注 0票数 0

我正在尝试在我的分割任务中使用transforms.Compose()。但我不确定如何对图像和蒙版使用相同的(几乎)随机变换。

所以在我的分割任务中,我有原始的图片和相应的掩模,我想生成更多的随机变换图像对来训练popurse。这意味着如果我对我的原始图片做了一些转换,这个转换也应该发生在我的面具图片上,然后这两张图片就可以进入我的CNN了。我的转换器类似于:

代码语言:javascript
运行
复制
train_transform = transforms.Compose([
            transforms.Resize(512), # resize, the smaller edge will be matched.
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomVerticalFlip(p=0.5),
            transforms.RandomRotation(90),
            transforms.RandomResizedCrop(320,scale=(0.3, 1.0)),
            AddGaussianNoise(0., 1.),
            transforms.ToTensor(), # convert a PIL image or ndarray to tensor. 
            transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)) # normalize to Imagenet mean and std
])

mask_transform = transforms.Compose([
            transforms.Resize(512), # resize, the smaller edge will be matched.
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomVerticalFlip(p=0.5),
            transforms.RandomRotation(90),
            transforms.RandomResizedCrop(320,scale=(0.3, 1.0)),
            ##---------------------!------------------
            transforms.ToTensor(), # convert a PIL image or ndarray to tensor. 
            transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)) # normalize to Imagenet mean and std
])

注意,在代码块中,我添加了一个可以向原始图像转换添加随机噪声的类,这个类不在mask_transformation中,我希望我的蒙版图像遵循原始图像转换,但忽略随机噪声。那么,这两个转换如何成对发生(具有相同的随机行为)?

EN

回答 2

Stack Overflow用户

回答已采纳

发布于 2021-02-20 15:06:16

这里似乎有一个答案:How to apply same transform on a pair of picture

基本上,您可以使用torchvision函数API来获取随机转换(如RandomCrop )的随机生成参数的句柄。然后使用相同的参数值在两个镜像上调用torchvision.transforms.functional.crop()。这看起来有点冗长,但可以完成工作。您可以根据需要跳过对某些图像的一些转换。

我看到elsewhere的另一个选择是用相同的种子重新播种随机生成器,强制生成相同的随机变换两次。我认为这样的实现是老生常谈的,并且随着pytorch版本的不同而不断变化(例如,是否重新播种np.randomrandomtorch.manual_seed() ?)

票数 1
EN

Stack Overflow用户

发布于 2021-02-21 12:25:06

所以Sabyasachi的答案对我来说真的很有帮助,我能够使用PyTorch中的转换器来转换我的图像。torchvision.transformer的这种用法不是传输图像的最直接的方式。因此,我添加了我的解决方案,它有一个使用torchvision.transforms.functional的示例,但也使用了skimage.filters,并且这里提供了许多转换函数:https://scikit-image.org/docs/dev/api/skimage.filters.html#skimage.filters.unsharp_mask

代码语言:javascript
运行
复制
import torchvision.transforms.functional as TF
from skimage.filters import gaussian
from skimage.filters import unsharp_mask

def transformer(image, mask):
    # image and mask are PIL image object. 
    img_w, img_h = image.size
    
    # Random horizontal flipping
    if random.random() > 0.5:
        image = TF.hflip(image)
        mask = TF.hflip(mask)

    # Random vertical flipping
    if random.random() > 0.5:
        image = TF.vflip(image)
        mask = TF.vflip(mask)
  
    # Random affine
    affine_param = transforms.RandomAffine.get_params(
        degrees = [-180, 180], translate = [0.3,0.3],  
        img_size = [img_w, img_h], scale_ranges = [1, 1.3], 
        shears = [2,2])
    image = TF.affine(image, 
                      affine_param[0], affine_param[1],
                      affine_param[2], affine_param[3])
    mask = TF.affine(mask, 
                     affine_param[0], affine_param[1],
                     affine_param[2], affine_param[3])

    image = np.array(image)
    mask = np.array(mask)
    
    # Randome GaussianBlur -- only for images
    if random.random() < 0.25:
        sigma_param = random.uniform(0.01, 1)
        image = gaussian(image, sigma=sigma_param)
    
    # Randome Gaussian Noise -- only for images
    if random.random() < 0.25:
        factor_param = random.uniform(0.01, 0.5)
        image = image + factor_param * image.std() * np.random.randn(image.shape[0], image.shape[1])
    
    # Unsharp filter -- only for images
    if random.random() < 0.25:
        radius_param = random.uniform(0, 5)
        amount_param = random.uniform(0.5, 2)
        image = unsharp_mask(image, radius = radius_param, amount=amount_param)
    
    f, ax = plt.subplots(1, 2, figsize=(8, 8))
    ax[0].imshow(image)
    ax[1].imshow(mask)   

    return image, mask
票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/66284850

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档