如何计算显存的占用,常常遇到out of memory?

如何计算显存的占用,预防out of memory?

最近一次组会上,师兄点评一篇文章显存占用过多,突然发现还不知道如何具体的计算显存,只好去学习一下。

显存类似于内存,可以存放模型数据,参数等等;显存越大,所能运行的网络也就越大

torch.FatalError: cuda runtime error (2) : out of memory at /opt/conda/.......

out of memory: 显存装不下你那么多的模型权重还有中间变量

GPU计算单元用来进行数值计算,衡量计算量的单位是flop,浮点数先乘后加算一个flop计算能力越强大,速度越快。衡量计算能力的单位是 flops: 每秒能执行的 flop数量。

2*2+2 :1个flop

2*2+3*3+4*4 : 3个flop

1、 存储指标

1 Byte = 8 bit
1 K = 1024 Byte
1 M = 1024 K
1 G = 1024 M

除此之外,

1 Byte = 8 bit
1 KB = 1000 Byte
1 MB = 1000 KB
1 GB = 1000 MB
1TB = 1000 GB

常用的数值类型:

若一张256*256的RGB图片存储在显存中占有显存为(float):

3*256*256*4=0.75M,若batchsize=100,也就占用75M,显存,显然,占用显存较大的不是输入图片数据,那会是什么呢?

什么占用了显存?

首先,了解神经网络的构成,我们当然知道神经网络只是一种类似神经的架构,主要由构成网络层的各种参数构成,以及神经网络的各种中间输出。

网络模型的参数:

看一个例子:

  • 模型权重:各种网络层的参数
  • 卷积层,通常的conv2d
  • 全连接层,也就是Linear层
  • BatchNorm层
  • Embedding层
  • 中间变量:各种网络层的输出

而不占用显存的则是:

  • 刚才说到的激活层Relu等
  • 池化层
  • Dropout层

具体计算方式:

  • Conv2d(Cin, Cout, K): 参数数目:Cin × Cout × K × K
  • Linear(M->N): 参数数目:M×N
  • BatchNorm(N): 参数数目: 2N
  • Embedding(N,W): 参数数目: N × W

参数占用显存:

参数占用显存 = 参数数目×n

n = 4 :float32
n = 2 : float16
n = 8 : double64

优化器的显存占用:

例如SGD优化器:

除了保存W之外还要保存参数对应的梯度,因此显存占用等于参数占用的显存的2倍。

Momentum-SGD:保存参数、梯度、动量------3倍

Adam:------------------------------------------4倍

输入输出的显存占用:

特点:

  • 需要计算每一层的feature map的形状(多维数组的形状)
  • 模型输出的显存占用与 batch size 成正比
  • 需要保存输出对应的梯度用以反向传播(链式法则)
  • 模型输出不需要存储相应的动量信息(因为不需要执行优化)

具体计算:

显存占用 = 模型显存占用 + batch_size × 每个样本的显存占用

注意 : 输入数据不用计算梯度;激活函数不用保存输入;

如何减小显存占用?出现out of memory如何处理?

  • 尽量不使用全连接层
  • 下采样
  • 减小batchsize

最简单处理方法,也是最常用的方法

减小batchsize

一般模型参数与batchsize成一定的不严格的正比关系。

参考资料:https://blog.csdn.net/liusandian/article/details/79069926

本文分享自微信公众号 - AI深度学习求索(AIDeepLearningQ)

原文出处及转载信息见文内详细说明,如有侵权,请联系 yunjia_community@tencent.com 删除。

原始发表时间:2018-10-29

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

发表于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏GreenLeaves

TFS2018环境搭建一硬件要求

TFS可以安装在Windows Server和Windows PC操作系统中,但是TFS2018和2018只支持64位操作系统中,早期的版本没有操作系统的位数限...

41030
来自专栏smy

一张图解释负载均衡

首先当大量用户访问时候,先请求到nignx服务器,因为nignx对于高并发支持较好,所以由nignx服务器将访问需求分配给不同的apache服务器,apache...

21330
来自专栏大数据文摘

迷人又诡异的辛普森悖论:同一个数据集是如何证明两个完全相反的观点的?

在辛普森悖论中,餐馆可以同时比竞争对手更好或更差,锻炼可以降低和增加疾病的风险,同样的数据集能够用于证明两个完全相反的论点。

15730
来自专栏编程坑太多

『高级篇』docker之Mesos集群架构图(23)

11840
来自专栏我是攻城师

理解BitMap算法的原理

位图:一种常用的数据结构,代表了有限域中的稠集(dense set),每一个元素至少出现一次,没有其他的数据和元素相关联。在索引,数据压缩,海量数据处理等方面有...

25030
来自专栏Python专栏

200行代码,一行行教你自制微信机器人

1) 用一个windows客户端工具运营公众号,真的很局限。虽然工具的功能很强大,能自动添加好友,自动拉好友入群,关键字回复等等,但是有一个绕不开的点,它是一款...

62120
来自专栏chenssy

多线程:为什么在while循环中加入System.out.println,线程可以停止

这个我们都知道,由于 stopReqested 的更新值在主内存中,而线程栈中的值不是最新的,所以会一直循环,线程并不能停止。加上 Volatile 关键字后,...

21440
来自专栏苦逼的码农

一些常用的算法技巧总结

数组的下标是一个隐含的很有用的数组,特别是在统计一些数字,或者判断一些整型数是否出现过的时候。例如,给你一串字母,让你判断这些字母出现的次数时,我们就可以把这些...

21530
来自专栏机器之心

Diss所有深度生成模型,DeepMind说它们真的不知道到底不知道什么

深度学习在应用层面获得了巨大成功,这些实际应用一般都希望利用判别模型构建条件分布 p(y|x),其中 y 是标签、x 是特征。但这些判别模型无法处理从其他分布中...

11610
来自专栏数据结构笔记

python基础类型(一):字符串和列表

注意到最后三个的单双引号是嵌套使用的,但是最后一个的使用方法是错误的,因为当我们混合使用两种引号时必须有一种用来划分字符串的边界,即在两边的引号不能出现在字符串...

13920

扫码关注云+社区

领取腾讯云代金券

年度创作总结 领取年终奖励