AI Engineering with Me: Implementing PyTorch Transforms Data Transformations in Native Python

Author: 小宋是呢
Published 2021-04-23 14:19:16
Included in the column: 深度应用

0x01: Introduction

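When deploying a model trained with PyTorch, the preprocessing applied at inference time must match the preprocessing used during training; otherwise the model's accuracy and stability suffer.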

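During training and testing we usually rely on the "torchvision.transforms" package for these data transformations. A typical pipeline resizes images to a uniform size (Resize), converts the data format (ToTensor), and normalizes the values (Normalize).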

The steps are:

  1. Use "torchvision.transforms" to define a data transformation: trans_f.
  2. Call trans_f to transform the data.

For example:

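import cv2
import torchvision
from PIL import Image

# Compose the three transforms into one callable: trans_f
trans_f = torchvision.transforms.Compose([
    torchvision.transforms.Resize((64, 128)),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

src_img = cv2.imread("demo.png")  # OpenCV loads images as H x W x C in BGR order
print("src img shape: ", src_img.shape)

# The mean/std above are the usual ImageNet values in RGB order, so convert BGR -> RGB first
rgb_img = cv2.cvtColor(src_img, cv2.COLOR_BGR2RGB)
# Resize here expects a PIL image; this conversion does not change the image size
pil_img = Image.fromarray(rgb_img)

trans_img = trans_f(pil_img)  # apply the transforms to the sample
print("dst img shape: ", trans_img.shape)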

Output:

src img shape:  (624, 1710, 3)
dst img shape:  torch.Size([3, 64, 128])

As the output shows, trans_f performs the data transformation: the 624 x 1710 image becomes a 3 x 64 x 128 tensor.

In real deployments, however, the fewer dependencies the better, so below I show how to implement these transformations in Python without the "torchvision.transforms" package.

0x02: Implementation

The example above uses three operations: Resize, ToTensor, and Normalize.

First we need to understand exactly what each operation does. The official documentation describes them: https://pytorch.org/vision/stable/transforms.html

Resize

CLASS torchvision.transforms.Resize(size, interpolation=<InterpolationMode.BILINEAR: 'bilinear'>)[SOURCE]

Resize the input image to the given size. If the image is torch Tensor, it is expected to have […, H, W] shape, where … means an arbitrary number of leading dimensions

Parameters:

  • size (sequence or int) – Desired output size. If size is a sequence like (h, w), output size will be matched to this. If size is an int, smaller edge of the image will be matched to this number. i.e, if height > width, then image will be rescaled to (size * height / width, size). In torchscript mode size as single int is not supported, use a sequence of length 1: [size, ].
  • interpolation (InterpolationMode) – Desired interpolation enum defined by torchvision.transforms.InterpolationMode. Default is InterpolationMode.BILINEAR. If input is Tensor, only InterpolationMode.NEAREST, InterpolationMode.BILINEAR and InterpolationMode.BICUBIC are supported. For backward compatibility integer values (e.g. PIL.Image.NEAREST) are still acceptable.

forward(img)[SOURCE]

Parameters:

img (PIL Image or Tensor) – Image to be scaled.

Returns:

Rescaled image.

Return type:

PIL Image or Tensor

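As the documentation shows, this simply resizes the image, with bilinear interpolation by default. It can be rewritten with OpenCV's cv2.resize(). Note that cv2.resize() takes the target size as (width, height), whereas torchvision takes (h, w). When downscaling, the pixel values may also differ slightly from PIL's bilinear resize, so it is worth comparing the two outputs on a sample image.

import cv2

# x: H x W x C image array; dsize is (width, height), so (128, 64) yields a 64 x 128 (h x w) image
y = cv2.resize(x, (128, 64), interpolation=cv2.INTER_LINEAR)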
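
The `size` rule quoted above is easy to get wrong when moving to OpenCV, because torchvision takes `(h, w)` while `cv2.resize` takes its target size as `(width, height)`. Below is a minimal sketch of a helper, hypothetically named `resize_target`, that turns a torchvision-style `size` into the `(width, height)` pair. Truncating with `int()` in the single-int case is an assumption; check it against the torchvision version used for training.

```python
def resize_target(size, height, width):
    """Return (new_width, new_height) for a torchvision-style `size`.

    `resize_target` is a hypothetical helper name, not a torchvision or OpenCV API.
    """
    if isinstance(size, int):
        # Single int: the smaller edge becomes `size`, the other edge is scaled
        # to keep the aspect ratio. int() truncation is an assumption here.
        if height > width:
            return size, int(size * height / width)
        return int(size * width / height), size
    # Sequence: torchvision order is (h, w); flip it for OpenCV.
    new_height, new_width = size
    return new_width, new_height


# The demo image in this article is 624 x 1710 (height x width).
print(resize_target((64, 128), 624, 1710))  # (128, 64)
print(resize_target(64, 624, 1710))         # (175, 64)
```

The result could then be passed on as `cv2.resize(x, resize_target((64, 128), *x.shape[:2]), interpolation=cv2.INTER_LINEAR)`.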

ToTensor

CLASS torchvision.transforms.ToTensor[SOURCE]

Convert a PIL Image or numpy.ndarray to tensor. This transform does not support torchscript.

Converts a PIL Image or numpy.ndarray (H x W x C) in the range [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0] if the PIL Image belongs to one of the modes (L, LA, P, I, F, RGB, YCbCr, RGBA, CMYK, 1) or if the numpy.ndarray has dtype = np.uint8

In the other cases, tensors are returned without scaling.

NOTE

Because the input image is scaled to [0.0, 1.0], this transformation should not be used when transforming target image masks. See the references for implementing the transforms for image masks.

Key point: Converts a PIL Image or numpy.ndarray (H x W x C) in the range [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0]

This conversion can be implemented with numpy:

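import numpy as np

# x: H x W x C uint8 image
x = np.transpose(x, [2, 0, 1])  # H x W x C -> C x H x W
y = x.astype(np.float32) / 255.0  # scale to [0.0, 1.0] as float32, matching torch.FloatTensor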
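
The quoted documentation also says that only uint8 input is rescaled and that other inputs are returned without scaling. A minimal sketch that covers those cases is shown below; `to_tensor_like` is a hypothetical helper name, and it returns a NumPy array rather than a torch tensor.

```python
import numpy as np


def to_tensor_like(img):
    """Mimic the documented ToTensor behavior on a NumPy array (hypothetical helper)."""
    if img.ndim == 2:
        # Grayscale H x W: add a channel axis so the layout becomes H x W x 1.
        img = img[:, :, None]
    # H x W x C -> C x H x W, stored contiguously in the new layout.
    chw = np.ascontiguousarray(np.transpose(img, (2, 0, 1)))
    if chw.dtype == np.uint8:
        # Only uint8 input is rescaled to [0.0, 1.0], as the docs describe.
        return chw.astype(np.float32) / 255.0
    # Other dtypes are returned without scaling.
    return chw


x = np.random.randint(0, 256, size=(64, 128, 3), dtype=np.uint8)
y = to_tensor_like(x)
print(y.shape, y.dtype)  # (3, 64, 128) float32
```

Casting to float32 before dividing keeps the result in the same precision as torch.FloatTensor; dividing a uint8 array directly would produce float64.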

Normalize

CLASS torchvision.transforms.Normalize(mean, std, inplace=False)[SOURCE]

Normalize a tensor image with mean and standard deviation. This transform does not support PIL Image. Given mean: (mean[1],...,mean[n]) and std: (std[1],..,std[n]) for n channels, this transform will normalize each channel of the input torch.*Tensor i.e., output[channel] = (input[channel] - mean[channel]) / std[channel]

NOTE

This transform acts out of place, i.e., it does not mutate the input tensor.

Parameters:

  • mean (sequence) – Sequence of means for each channel.
  • std (sequence) – Sequence of standard deviations for each channel.
  • inplace (bool,optional) – Bool to make this operation in-place.

forward(tensor: torch.Tensor) → torch.Tensor[SOURCE]

Parameters:

tensor (Tensor) – Tensor image to be normalized.

Returns:

Normalized Tensor image.

Return type:

Tensor

Key point: this transform will normalize each channel of the input torch.*Tensor i.e., output[channel] = (input[channel] - mean[channel]) / std[channel]

This operation can also be implemented with numpy.

It relies on numpy's broadcasting mechanism; for background, see: https://www.cnblogs.com/jiaxin359/p/9021726.html

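import numpy as np

# Reshape to (3, 1, 1) so each value broadcasts across its own channel of a C x H x W array
mean = np.array([0.485, 0.456, 0.406], dtype=np.float32).reshape([3, 1, 1])
std = np.array([0.229, 0.224, 0.225], dtype=np.float32).reshape([3, 1, 1])

# x: C x H x W float array in [0.0, 1.0]
y = (x - mean) / std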
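
To make the broadcasting step concrete, here is a small sketch that wraps the formula in a function and checks one pixel per channel. The name `normalize` and the constants `MEAN` and `STD` are illustrative choices for this example, not part of any library; the values are the ones used earlier in this article.

```python
import numpy as np

# Statistics used throughout this article, one value per channel.
MEAN = (0.485, 0.456, 0.406)
STD = (0.229, 0.224, 0.225)


def normalize(chw, mean=MEAN, std=STD):
    """Per-channel (x - mean) / std on a C x H x W array (hypothetical helper)."""
    mean = np.asarray(mean, dtype=np.float32).reshape(-1, 1, 1)
    std = np.asarray(std, dtype=np.float32).reshape(-1, 1, 1)
    if chw.shape[0] != mean.shape[0]:
        raise ValueError("channel count does not match mean/std length")
    # (C, 1, 1) broadcasts against (C, H, W): each channel gets its own mean/std.
    return (chw.astype(np.float32) - mean) / std


x = np.full((3, 64, 128), 0.5, dtype=np.float32)
y = normalize(x)
print(y.shape, y.dtype)  # (3, 64, 128) float32
print(y[:, 0, 0])        # one value per channel: (0.5 - mean) / std
```

The channel-count check is there because a mismatched shape such as (1, H, W) would otherwise broadcast silently to three channels instead of failing.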

0x03: Afterword

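If this post was useful to you, feel free to bookmark and share it, and a like would be much appreciated. Keeping this blog updated takes effort, so thank you in advance.

If you want to learn more development tips and AI algorithms, search for and follow my WeChat official account "简明AI" to learn and discuss with others who enjoy studying.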

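This article is part of the Tencent Cloud Self-Media Sharing Program and is shared from the author's personal site/blog.
Originally published: 2021-04-21. In case of infringement, contact cloudcommunity@tencent.com for removal.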
