专栏首页Datawhale专栏PyTorch常见的坑汇总

PyTorch常见的坑汇总

最近刚开始用pytorch不久,陆陆续续踩了不少坑,记录一下,个人感觉应该都是一些很容易遇到的一些坑,也在此比较感谢帮我排坑的小伙伴,持续更新,也祝愿自己遇到的坑越来越少。

首先作为tensorflow的骨灰级玩家+轻微强迫症患者,一路打怪升级,从0.6版本用到1.2,再用到1.10,经历了tensorfow数个版本更迭,这里不得不说一下tf.data.dataset+tfrecord使用起来效率远比dataloader高的多。

tensorflow有一个比较好用的队列机制,tf.inputproducer + tfrecord, 但是inputproducer有一个bug,就是无法对每个epoch单独shuffle,它只能整体shuffle,也就意味着我们无法进行正常的训练流程(train几个epoch,在validation上测一个epoch,最终选一个validation上的最好的结果,进行test)。后来我当时给官方提了一个issue,官方当时的回答是,这个bug目前无法解决,但是他们在即将到来的tf1.2版本中, 推出的新型数据处理API tf.contrib.data.dataset(tf1.3版本将其合并到了tf.data.dataset)可以完美解决这个bug,并且将于tf2.0摒弃tf.input_producer。然后tf1.2版本刚出来以后,我就立马升级并且开始tf.data.dataset踩坑,踩了大概2周多的坑,(这个新版的API其实功能并不是非常强大,有不少局限性,在此就不展开)。

好像扯远了,回归pytorch,首先让我比较尴尬的是pytorch并没有一套属于自己的数据结构以及数据读取算法,dataloader个人感觉其实就是类似于tf中的feed,并没有任何速度以及性能上的提升。

先总结一下遇到的坑:

1. 没有比较高效的数据存储,tensorflow有tfrecord, caffe有lmdb,cv.imread在网络训练过程中实属浪费时间。这里感谢一下小智大神@智天成

解决方案:

当时看到了一个还不错的github链接:

https://github.com/Lyken17/Efficient-PyTorch

主要是讲如何使用lmdb,h5py,pth,lmdb,n5等数据存储方式皆可以。

个人的感受是,h5在数据调用上比较快,但是如果要使用多线程读写,就尽量不要使用h5,因为h5的多线程读写好像比较麻烦。

http://docs.h5py.org/en/stable/mpi.html

这里贴一下h5数据的读写代码(主要需要注意的是字符串的读写需要encode,decode,最好用create_dataset,直接写的话读的时候会报错):

写:
    imagenametotal_.append(os.path.join('images', imagenametotal).encode())
    with h5py.File(outfile) as f:
        f.create_dataset('imagename', data=imagenametotal_)
        f['part'] = parts_
        f['S'] = Ss_
        f['image'] = cvimgs
读:
with h5py.File(outfile) as f:
    imagename = [x.decode() for x in f['imagename']]
    kp2ds = np.array(f['part'])
    kp3ds = np.array(f['S'])
    cvimgs = np.array(f['image'])

2. gpu imbalance,这里感谢一下张航学长@张航

老生常谈的问题,第一个GPU显存占用多一点。

张航学长提了一个开源的gpu balance的工具—PyTorch-Encoding。

https://github.com/zhanghang1989/PyTorch-Encoding

使用方法还是比较便捷的,如下所示:

from balanced_parallel import DataParallelModel, DataParallelCriterion
model = DataParallelModel(model, device_ids=gpus).cuda()
criterion = loss_fn().cuda()

这里其实有2个注意点,第一,测试的时候需要手动将gpu合并,代码如下:

from torch.nn.parallel.scatter_gather import gather
preds = gather(preds, 0)

第二,当loss函数有多个组成的时候,比如 loss = loss1 + loss2 + loss3

那么需要把这三个loss写到一个class中,然后再forward里面将其加起来。

其次,我们还可以用另外一个函数distributedDataParallel来解决gpu imbalance的问题。

使用方法如下:(注:此方法好像无法和h5数据同时使用)

from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel

torch.distributed.init_process_group(backend="nccl")
# 配置每个进程的gpu
local_rank = torch.distributed.get_rank()
torch.cuda.set_device(local_rank)
device = torch.device("cuda", local_rank)

#封装之前要把模型移到对应的gpu
model.to(device)
model = torch.nn.parallel.DistributedDataParallel(model,device_ids=[local_rank],
                                                   output_device=local_rank)

#原有的dataloader上面加一个数据sample
train_loader = torch.utils.data.DataLoader(
        train_dataset,
        sampler=DistributedSampler(train_dataset)
    )

3. gpu利用率不高+gpu现存占用浪费

常用配置:

(1)主函数前面加:(这个会牺牲一点点现存提高模型精度)

cudnn.benchmark = True
torch.backends.cudnn.deterministic = False
torch.backends.cudnn.enabled = True

(2)训练时,epoch前面加:(定期清空模型,效果感觉不明显)

torch.cuda.empty_cache()

(3)无用变量前面加:(同上,效果某些操作上还挺明显的)

del xxx(变量名)

(4)dataloader的长度_len_设置:(dataloader会间歇式出现卡顿,设置成这样会避免不少)

def __len__(self):
    return self.images.shape[0]

(5)dataloader的预加载设置:(会在模型训练的时候加载数据,提高一点点gpu利用率)

train_loader = torch.utils.data.DataLoader(
        train_dataset,
        pin_memory=True,
    )

(6)网络设计很重要,外加不要初始化任何用不到的变量,因为pyroch的初始化和forward是分开的,他不会因为你不去使用,而不去初始化。

(7)最后放一张目前依旧困扰我的图片:

可以看到,每个epoch刚开始训练数据的时候,第一个iteration时间会占用的非常多,pytorch这里就做的很糟糕,并不是一个动态分配的过程,我也看到了一个看上去比较靠谱的解决方案,解决方案如下:在深度学习中喂饱 GPU

但是我看了下代码,可能需要重构dataloader,看了评论好像还有问题,有点懒,目前还没有踩坑,准备后面有时间踩一下。

本文分享自微信公众号 - Datawhale(Datawhale),作者:郁振波

原文出处及转载信息见文内详细说明,如有侵权,请联系 yunjia_community@tencent.com 删除。

原始发表时间:2019-09-22

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

我来说两句

0 条评论
登录 后参与评论

相关文章

  • 从入门到冠军 中国移动人群画像赛TOP1经验分享

    郑雨轩,Datawhale数据科学团队负责人,长期活跃在国内外数据竞赛领域,拥有较为丰富的竞赛经验,最近比赛经历:

    Datawhale
  • pandas_profiling:一行代码生成你的数据分析报告

    笔者最近发现一款将pandas数据框快速转化为描述性数据分析报告的package——pandas_profiling。一行代码即可生成内容丰富的EDA内容,两行...

    Datawhale
  • 【NLP】 深入浅出解析BERT原理及其表征的内容

    了解自然语言处理的同学,应该没有不知道BERT的;甚至连很多自然语言处理圈外的开发的同学,都知道这个模型的厉害。不得不佩服谷歌的影响力和营销水平,当然,也确实是...

    Datawhale
  • 架构爬坑记——从0开始学习分布式锁

    1、定义锁的接口Lock 2、在AbstractLock模板锁里面实现getLock方法,实现通用的逻辑。 3、不能确实的步骤,作为虚拟方法,甩锅给子类实现...

    慕容千语
  • 架构爬坑记——从0开始学习分布式锁

    4、子类只需要聚焦自己的小步骤逻辑,实现tryLock,waitLock,unLock方法。

    Java知音
  • Android TV开发简介

    0.概述 智能电视近两年可谓发展迅猛,据艾瑞咨询数据显示,2015年国内智能电视终端销售量将达到3556.8万台,预计在2016年用户量突破一亿。不同于手机端群...

    QQ音乐技术团队
  • Jenkins 配置信息变更历史跟踪

    像 Jenkins 这样的系统,使用的过程就是配置文件变更的过程。如果能够对配置文件的变更进行跟踪管理,将极大的提高系统的可用性。Job Configurati...

    Criss@陈磊
  • Jenkins:配置信息变更历史

    作者:sparkdev 出处:http://www.cnblogs.com/sparkdev/ 像 Jenkins 这样的系统,使用的过程就是配置文件变更的过程...

    程序猿DD
  • 用JS编写一个Java虚拟机?谈谈哗众取宠的BicaVM

    今日目睹某网络新闻,开篇明义便包含如下几行文字 【程序员Artur Ventura,这位超级大牛,用JavaScript写了一个java虚拟机BicaVM】 ...

    三哥
  • 简述:Rust-1.38.0 RELEASE NOTE

    mac OS更新,如果使用brew安装的,那么恭喜你,现在brew上面只能更新到1.37.0:

    江湖安得便相忘

扫码关注云+社区

领取腾讯云代金券