前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >自定义PyTorch中的Sampler

自定义PyTorch中的Sampler

作者头像
带萝卜
发布2020-10-23 14:51:18
3.8K0
发布2020-10-23 14:51:18
举报
文章被收录于专栏:我的机器学习之路

本文使用 Zhihu On VSCode 创作并发布

在训练GAN的过程中,一次只训练一个类别据说有助于模型收敛,但是PyTorch里面没有预设这种数据加载方式,要这样训练的话,需要自己定义Sampler,即自定义数据采样方式。下面是自定义的方法:

首先,我们虚构一个Dataset类,用于测试。

这个类中的label标签是混乱的,无法通过控制index范围来实现单类别训练。

代码语言:javascript
复制
class Data(Dataset):
    def __init__(self):
        self.img = torch.cat([torch.ones(2, 2) for i in range(100)], dim=0)
        self.num_classes = 2
        self.label = torch.tensor(
            [random.randint(0, self.num_classes - 1) for i in range(100)]
        )

    def __getitem__(self, index):
        return self.img[index], self.label[index]

    def __len__(self):
        return len(self.label)

然后,自定义一个Sampler类,这个类的作用是生成一个index列表,可以理解为重排data中的index。

代码语言:javascript
复制
class CustomSampler(Sampler):
    def __init__(self, data):
        self.data = data

    def __iter__(self):
        indices = []
        for n in range(self.data.num_classes):
            index = torch.where(self.data.label == n)[0]
            indices.append(index)
        indices = torch.cat(indices, dim=0)
        return iter(indices)

    def __len__(self):
        return len(self.data)

定义好了之后可以封装成DataLoader并查看运行结果:

代码语言:javascript
复制
d = Data()
s = CustomSampler(d)
dl = DataLoader(d, 8, sampler=s)
for img, label in dl:
    print(label)

结果

代码语言:javascript
复制
tensor([0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 1, 1, 1, 1, 1, 1])
tensor([1, 1, 1, 1, 1, 1, 1, 1])
tensor([1, 1, 1, 1, 1, 1, 1, 1])
tensor([1, 1, 1, 1, 1, 1, 1, 1])
tensor([1, 1, 1, 1, 1, 1, 1, 1])
tensor([1, 1, 1, 1, 1, 1, 1, 1])
tensor([1, 1, 1, 1])

显然,这样的结果并不能让人满意,有一个batch中还是包含了两种不同类型的标签,为了达到目的,我们还需要再定义一个BatchSampler类

代码语言:javascript
复制
class CustomBatchSampler:
    def __init__(self, sampler, batch_size, drop_last):
        self.sampler = sampler
        self.batch_size = batch_size
        self.drop_last = drop_last

    def __iter__(self):
        batch = []
        i = 0
        sampler_list = list(self.sampler)
        for idx in sampler_list:
            batch.append(idx)
            if len(batch) == self.batch_size:
                yield batch
                batch = []

            if (
                i < len(sampler_list) - 1
                and self.sampler.data.label[idx]
                != self.sampler.data.label[sampler_list[i + 1]]
            ):
                if len(batch) > 0 and not self.drop_last:
                    yield batch
                    batch = []
                else:
                    batch = []
            i += 1
        if len(batch) > 0 and not self.drop_last:
            yield batch

    def __len__(self):
        if self.drop_last:
            return len(self.sampler) // self.batch_size
        else:
            return (len(self.sampler) + self.batch_size - 1) // self.batch_size

虽然PyTorch要求Sampler需要定义成一个迭代器,但是如果你自己定义BatchSampler的话,Sampler的形式可以自己定,就算写成一个普通的列表也没关系。

再次封装成DataLoader并查看运行结果:

代码语言:javascript
复制
d = Data()
s = CustomSampler(d)
bs = CustomBatchSampler(s, 8, False)
dl = DataLoader(d, batch_sampler=bs)
for img, label in dl:
    print(label)

drop_last = False 的结果:

代码语言:javascript
复制
tensor([0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0])
tensor([1, 1, 1, 1, 1, 1, 1, 1])
tensor([1, 1, 1, 1, 1, 1, 1, 1])
tensor([1, 1, 1, 1, 1, 1, 1, 1])
tensor([1, 1, 1, 1, 1, 1, 1, 1])
tensor([1, 1, 1, 1, 1, 1, 1, 1])
tensor([1, 1, 1, 1, 1, 1, 1, 1])
tensor([1, 1, 1, 1, 1, 1, 1])

drop_last = True 的结果:

代码语言:javascript
复制
tensor([0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0, 0, 0, 0])
tensor([1, 1, 1, 1, 1, 1, 1, 1])
tensor([1, 1, 1, 1, 1, 1, 1, 1])
tensor([1, 1, 1, 1, 1, 1, 1, 1])
tensor([1, 1, 1, 1, 1, 1, 1, 1])
tensor([1, 1, 1, 1, 1, 1, 1, 1])

以上就是自定义Sampler的步骤了。

本文参与 腾讯云自媒体同步曝光计划,分享自作者个人站点/博客。
如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
相关产品与服务
批量计算
批量计算(BatchCompute,Batch)是为有大数据计算业务的企业、科研单位等提供高性价比且易用的计算服务。批量计算 Batch 可以根据用户提供的批处理规模,智能地管理作业和调动其所需的最佳资源。有了 Batch 的帮助,您可以将精力集中在如何分析和处理数据结果上。
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档