前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >如何估算transformer模型的显存大小

如何估算transformer模型的显存大小

作者头像
deephub
发布2022-11-11 17:14:53
1.9K0
发布2022-11-11 17:14:53
举报
文章被收录于专栏:DeepHub IMBA

点击上方“Deephub Imba”,关注公众号,好文章不错过 !

在微调GPT/BERT模型时,会经常遇到“ cuda out of memory”的情况。这是因为transformer是内存密集型的模型,并且内存要求也随序列长度而增加。所以如果能对模型的内存要求进行粗略的估计将有助于估计任务所需的资源。

如果你想直接看结果,可以跳到本文最后。不过在阅读本文前请记住所有神经网络都是通过反向传播的方法进行训练的, 这一点对于我们计算内存的占用十分重要。

代码语言:javascript
复制
 total_memory = memory_modal + memory_activations + memory_gradients

这里的memory_modal是指存储模型所有参数所需的内存。memory_activations是计算并存储在正向传播中的中间变量,在计算梯度时需要使用这些变量。因为模型中梯度的数量通常等于中间变量的数量,所以memory_activations= memory_gradients。因此可以写成:

代码语言:javascript
复制
 total_memory = memory_modal + 2 * memory_activations

所以我们计算总体内存的要求时只需要找到memory_modal和memory_activations就可以了。

估算模型的内存

下面我们以GPT为例。GPT由许多transformer块组成(后面我用n_tr_blocks表示其数量)。每个transformer块都包含以下结构:

代码语言:javascript
复制
 multi_headed_attention --> layer_normalization --> MLP -->layer_normalization

每个multi_headed_attention元素都由键,值和查询组成。其中包括n_head个注意力头和dim个维度。MLP是包含有n_head * dim的尺寸。这些权重都是要占用内存的,那么

代码语言:javascript
复制
 memory_modal = memory of multi_headed_attention + memory of MLP
              = memory of value  + memory of key + memory of query + memory of MLP
              = square_of(n_head * dim) + square_of(n_head * dim) + square_of(n_head * dim) + square_of(n_head * dim)
              = 4*square_of(n_head * dim)

因为我们的模型包含了n个单元。所以最后内存就变为:

代码语言:javascript
复制
 memory_modal = 4*n_tr_blocks*square_of(n_head * dim)

上面的估算没有考虑到偏差所需的内存,因为这大部分是静态的,不依赖于批大小、输入序列等。

估算中间变量的内存

多头注意力通常使用softmax,可以写成:

代码语言:javascript
复制
 multi_headed_attention = softmax(query * key * sequence_length) * value

k,q,v的维度是:

代码语言:javascript
复制
 [batch_size, n_head, sequence_length, dim]

multi_headed_attention操作会得出如下形状:

代码语言:javascript
复制
 [batch_size, n_head, sequence_length, sequence_length]

所以最终得内存为:

代码语言:javascript
复制
 memory_softmax  = batch_size * n_head * square_of(sequence_length)

q* k * sequence_length操作乘以value的形状为[batch_size, n_head, sequence_length, dim]。MLP也有相同的维度:

代码语言:javascript
复制
 memory of MLP  = batch_size * n_head * sequence_length * dim
 memory of value = batch_size * n_head * sequence_length * dim

我们把上面的整合在一起,单个transformer的中间变量为:

代码语言:javascript
复制
 memory_activations = memory_softmax + memory_value + memory_MLP
         = batch_size * n_head * square_of(sequence_length)
           + batch_size * n_head * sequence_length * dim
           + batch_size * n_head * sequence_length * dim
         = batch_size * n_head * sequence_length * (sequence_length + 2*dim)

再乘以块的数量,模型所有的memory_activations就是:

代码语言:javascript
复制
 n_tr_blocks * (batch_size * n_head * sequence_length * (sequence_length + 2*dim))

整合在一起

我们把上面两个公式进行归纳总结,想看结果的话直接看这里就行了。transformer模型所需的总内存为:

代码语言:javascript
复制
 total_memory = memory_modal + 2 * memory_activations

模型参数的内存:

代码语言:javascript
复制
 4*n_tr_blocks*square_of(n_head * dim)

中间变量内存:

代码语言:javascript
复制
 n_tr_blocks * (batch_size * n_head * sequence_length * (sequence_length + 2*dim))

我们使用下面的符号可以更简洁地写出这些公式。

代码语言:javascript
复制
 R = n_tr_blocks = transformer层堆叠的数量
 N = n_head = 注意力头数量
 D = dim = 注意力头的维度
 B = batch_size = 批大小
 S = sequence_length =输入序列的长度
 
 memory modal = 4 * R * N^2 * D^2
 
 memory activations = RBNS(S + 2D)

所以在训练模型时总的内存占用为:

代码语言:javascript
复制
 M = (4 * R * N^2 * D^2) + RBNS(S + 2D)

因为内存的占用和序列长度又很大的关系,如果有一个很长的序列长度S >> D S + 2D <——> S,这时可以将计算变为:

代码语言:javascript
复制
 M = (4 * R * N^2 * D^2) + RBNS(S) = 4*R*N^2*D^2 + RBNS^2

可以看到对于较大的序列,M与输入序列长度的平方成正比,与批大小成线性比例,这也就证明了序列长度和内存占用有很大的关系。

所以最终的内存占用的评估为:

代码语言:javascript
复制
 总内存 = ((4 * R * N^2 * D^2) + RBNS(S + 2D)) * float64(以字节为单位)

作者:Schartz Rehan


MORE

kaggle比赛交流和组队

加我的微信,邀你进群

喜欢就关注一下吧:

点个 在看 你最好看!

本文参与 腾讯云自媒体同步曝光计划,分享自微信公众号。
原始发表:2022-08-30,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 DeepHub IMBA 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 估算模型的内存
  • 估算中间变量的内存
  • 整合在一起
相关产品与服务
对象存储
对象存储(Cloud Object Storage,COS)是由腾讯云推出的无目录层次结构、无数据格式限制,可容纳海量数据且支持 HTTP/HTTPS 协议访问的分布式存储服务。腾讯云 COS 的存储桶空间无容量上限,无需分区管理,适用于 CDN 数据分发、数据万象处理或大数据计算与分析的数据湖等多种场景。
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档