首页
学习
活动
专区
工具
TVP
发布
精选内容/技术社群/优惠产品,尽在小程序
立即前往

如何在Pytorch中精细化利用显存

前言

在上篇文章《浅谈深度学习:如何计算模型以及中间变量的显存占用大小》中我们对如何计算各种变量所占显存大小进行了一些探索。而这篇文章我们着重讲解如何利用Pytorch深度学习框架的一些特性,去查看我们当前使用的变量所占用的显存大小,以及一些优化工作。以下代码所使用的平台框架为Pytorch。

优化显存

在Pytorch中优化显存是我们处理大量数据时必要的做法,因为我们并不可能拥有无限的显存。显存是有限的,而数据是无限的,我们只有优化显存的使用量才能够最大化地利用我们的数据,实现多种多样的算法。

估测模型所占的内存

上篇文章中说过,一个模型所占的显存无非是这两种:

模型权重参数

模型所储存的中间变量

其实权重参数一般来说并不会占用很多的显存空间,主要占用显存空间的还是计算时产生的中间变量,当我们定义了一个model之后,我们可以通过以下代码简单计算出这个模型权重参数所占用的数据量:

假设我们有这样一个model:

然后我们得到的是,但是我们计算出来的仅仅是权重参数的“数量”,单位是B,我们需要转化一下:

这样就可以打印出:

但是我们之前说过一个神经网络的模型,不仅仅有权重参数还要计算中间变量的大小。怎么去计算,我们可以假设一个,然后将这个输入变量投入这个模型中,然后我们主动提取这些计算出来的中间变量:

上面得到的值是模型在运行时候产生所有的中间变量的“数量”,当然我们需要换算一下:

因为在的时候所有的中间变量需要保存下来再来进行计算,所以我们在计算的时候,计算出来的中间变量需要乘个2。

然后我们得出,上面这个模型的中间变量需要的占用的显存,很显然,中间变量占用的值比模型本身的权重值多多了。如果进行一次backward那么需要的就更多。

我们总结一下之前的代码:

当然我们计算出来的占用显存值仅仅是做参考作用,因为Pytorch在运行的时候需要额外的显存值开销,所以实际的显存会比我们计算的稍微大一些。

.关于`inplace=False`

我们都知道激活函数有一个默认参数,默认设置为False,当设置为True时,我们在通过relu()计算时的得到的新值不会占用新的空间而是直接覆盖原来的值,这也就是为什么当inplace参数设置为True时可以节省一部分内存的缘故。

relu

牺牲计算速度减少显存使用量

在出来了一个新的功能,可以将一个计算过程分成两半,也就是如果一个模型需要占用的显存太大了,我们就可以先计算一半,保存后一半需要的中间结果,然后再计算后一半。

也就是说,新的允许您、我们只存储反向传播所需要的部分内容。如果当中缺少一个输出(为了节省内存而导致的),将会从最近的检查点重新计算中间输出,以便减少内存使用(当然计算时间增加了):

上面的模型需要占用很多的内存,因为计算中会产生很多的中间变量。为此就可以帮助我们来节省内存的占用了。

对于Sequential-model来说,因为中可以包含很多的block,所以官方提供了另一个功能包:

跟踪显存使用情况

显存的使用情况,在编写程序中我们可能无法精确计算,但是我们可以通过pynvml这个Nvidia的Python环境库和Python的垃圾回收工具,可以实时地打印我们使用的显存以及哪些Tensor使用了我们的显存。

类似于下面的报告:

以下是相关的代码,目前代码依然有些地方需要修改,等修改完善好我会将完整代码以及使用说明放到github上:https://github.com/Oldpan/Pytorch-Memory-Utils

请大家多多留意。

需要注意的是,linecache中的getlines只能读取缓冲过的文件,如果这个文件没有运行过则返回无效值。Python 的垃圾收集机制会在变量没有应引用的时候立马进行回收,但是为什么模型中计算的中间变量在执行结束后还会存在呢。既然都没有引用了为什么还会占用空间?

一种可能的情况是这些引用不在Python代码中,而是在神经网络层的运行中为了backward被保存为gradient,这些引用都在计算图中,我们在程序中是无法看到的:

TIM截图20180608173020后记

实际中我们会有些只使用一次的模型,为了节省显存,我们需要一边计算一遍清除中间变量,使用进行操作。限于篇幅这里不进行讲解,下一篇会进行说明。

关注Oldpan博客,同步更新博客最新消息,持续酝酿深度学习质量文。

  • 发表于:
  • 原文链接https://kuaibao.qq.com/s/20180623G1QM0300?refer=cp_1026
  • 腾讯「腾讯云开发者社区」是腾讯内容开放平台帐号(企鹅号)传播渠道之一,根据《腾讯内容开放平台服务协议》转载发布内容。
  • 如有侵权,请联系 cloudcommunity@tencent.com 删除。

扫码

添加站长 进交流群

领取专属 10元无门槛券

私享最新 技术干货

扫码加入开发者社群
领券