前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >Pytorch打怪路(三)Pytorch创建自己的数据集2

Pytorch打怪路(三)Pytorch创建自己的数据集2

作者头像
TeeyoHuang
发布2019-05-25 22:41:36
9230
发布2019-05-25 22:41:36
举报

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://cloud.tencent.com/developer/article/1433783

前面一篇写创建数据集的博文--- Pytorch创建自己的数据集1 是介绍的应用于图像分类任务的数据集,即输入为一个图像和它的类别数字标签本篇介绍输入的标签label亦为图像的数据集,并包含一些常用的处理手段。比如做图像语义分割时就会用到这种数据输入方式。

1、数据集简介

以VOC2012数据集为例,图像是RGB3通道的,label是1通道的,(其实label原来是几通道的无所谓,只要读取的时候转化成灰度图就行)。

训练数据:

语义label:

这里我们看到label图片都是黑色的,只有白色的轮廓而已

其实是因为label图片里的像素值取值范围是0 ~ 20,即像素点可能的类别共有21类(对此数据集来说),详情如下:

所以对于灰度值0---20来说,我们肉眼看上去就确实都是黑色的,因为灰度值太低了,而白色的轮廓的灰度值是255!

但是这些边界在计算损失值的时候是不作为有效值的,也就是对于灰度值=255的点是忽略的。

如果想看的话,可以用一些色彩变换,对0--20这每一个数字对应一个色彩,就能看出来了,示例如下

这不是重点,只是给大家看一下方便理解而已

2、文本信息

同样有一个文本来指导我对数据的读取,我的信息如下

这其实就是一个记载了图像ID的文本文档,连后缀都没有,但我们依然可以根据这个去数据集中读取相应的image和label

3、代码示例

这个代码是我自己在利用deeplabV2 跑semantic segmentation 任务时写的一个,也许写的并不优美,但反正是可以用的,

可以做个抛砖引玉的目的,对于才入门的朋友,理解这个思路就可,不必照搬我的代码风格……

代码语言:javascript
复制
import os
import numpy as np
import random
import matplotlib.pyplot as plt
import collections
import torch
import torchvision
import cv2
from PIL import Image
import torchvision.transforms as transforms
from torch.utils import data

class VOCDataSet(data.Dataset):
    def __init__(self, root, list_path,  crop_size=(321, 321), mean=(104.008, 116.669, 122.675), mirror=True, scale=True, ignore_label=255):
        super(VOCDataSet,self).__init__()
        self.root = root
        self.list_path = list_path
        self.crop_h, self.crop_w = crop_size
        self.ignore_label = ignore_label
        self.mean = np.asarray(mean, np.float32)
        self.is_mirror = mirror
        self.is_scale = scale

        self.img_ids = [i_id.strip() for i_id in open(list_path)]

        self.files = []
        for name in self.img_ids:
            img_file = os.path.join(self.root, "JPEGImages/%s.jpg" % name)
            label_file = os.path.join(self.root, "SegmentationClassAug/%s.png" % name)
            self.files.append({
                "img": img_file,
                "label": label_file,
                "name": name
            })

    def __len__(self):
        return len(self.files)


    def __getitem__(self, index):
        datafiles = self.files[index]

        '''load the datas'''
        name = datafiles["name"]
        image = Image.open(datafiles["img"]).convert('RGB')
        label = Image.open(datafiles["label"]).convert('L')
        size_origin = image.size # W * H

        '''random scale the images and labels'''
        if self.is_scale: #如果我在定义dataset时选择了scale=True,就执行本语句对尺度进行随机变换
            ratio = 0.5 + random.randint(0, 11) // 10.0 #0.5~1.5
            out_h, out_w = int(size_origin[1]*ratio), int(size_origin[0]*ratio)
            # (H,W)for Resize
            image = transforms.Resize((out_h, out_w), Image.LANCZOS)(image)
            label = transforms.Resize((out_h, out_w), Image.NEAREST)(label)

        '''pad the inputs if their size is smaller than the crop_size'''
        pad_w = max(self.crop_w - out_w, 0)
        pad_h = max(self.crop_h - out_h, 0)
        img_pad = transforms.Pad( padding=(0,0,pad_w,pad_h), fill=0, padding_mode='constant')(image)
        label_pad = transforms.Pad( padding=(0,0,pad_w,pad_h), fill=self.ignore_label, padding_mode='constant')(label)
        out_size = img_pad.size

        '''random crop the inputs'''
        if (self.crop_h != 0 or self.crop_w != 0):
            #select a random start-point for croping operation
            h_off = random.randint(0, out_size[1] - self.crop_h)
            w_off = random.randint(0, out_size[0] - self.crop_w)
            #crop the image and the label
            image = img_pad.crop((w_off,h_off, w_off+self.crop_w, h_off+self.crop_h))
            label = label_pad.crop((w_off,h_off, w_off+self.crop_w, h_off+self.crop_h))

        '''mirror operation'''
        if self.is_mirror:
            if np.random.random() < 0.5:
                #0:FLIP_LEFT_RIGHT, 1:FLIP_TOP_BOTTOM, 2:ROTATE_90, 3:ROTATE_180, 4:or ROTATE_270.
                image = image.transpose(0)
                label = label.transpose(0)

        '''convert PIL Image to numpy array'''
        I = np.asarray(image,np.float32) - self.mean
        I = I.transpose((2,0,1))#transpose the  H*W*C to C*H*W
        L = np.asarray(np.array(label), np.int64)
        #print(I.shape,L.shape)
        return I.copy(), L.copy(), np.array(size_origin), name

#这是一个测试函数,也即我的代码写好后,如果直接python运行当前py文件,就会执行以下代码的内容,以检测我上面的代码是否有问题,这其实就是方便我们调试,而不是每次都去run整个网络再看哪里报错
if __name__ == '__main__':
    DATA_DIRECTORY = '/home/teeyo/STA/Data/voc_aug/'
    DATA_LIST_PATH = '../dataset/list/val.txt'
    Batch_size = 4
    MEAN = (104.008, 116.669, 122.675)
    dst = VOCDataSet(DATA_DIRECTORY,DATA_LIST_PATH, mean=(0,0,0))
    # just for test,  so the mean is (0,0,0) to show the original images.
    # But when we are training a model, the mean should have another value
    trainloader = data.DataLoader(dst, batch_size = Batch_size)
    plt.ion()
    for i, data in enumerate(trainloader):
        imgs, labels,_,_= data
        if i%1 == 0:
            img = torchvision.utils.make_grid(imgs).numpy()
            img = img.astype(np.uint8) #change the dtype from float32 to uint8, because the plt.imshow() need the uint8
            img = np.transpose(img, (1, 2, 0))#transpose the Channels*H*W to  H*W*Channels
            #img = img[:, :, ::-1]
            plt.imshow(img)
            plt.show()
            plt.pause(0.5)

            #input()

我个人觉得我应该注释的地方都有相应的注释,虽然有点长, 因为实现了crop和翻转以及scale等功能,但是大家可以下去慢慢揣摩,理解其中的主要思路,与我前一篇的博文Pytorch创建自己的数据集1做对比,那篇博文相当于是提供了最基本的骨架,而这篇就在骨架上长肉生发而已,有疑问的欢迎评论探讨~~

本文参与 腾讯云自媒体分享计划,分享自作者个人站点/博客。
原始发表:2018年08月27日,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档