OpenAI推新程序包:GPU适应十倍大模型仅需增加20%训练时间

安妮 编译自 Medium 量子位 出品 | 公众号 QbitAI

GPU内存太小可能是神经网络训练过程中最大的拦路虎。

不怕,用这个OpenAI推出的gradient-checkpointing程序包,对于前馈模型来说,仅仅需要增加20%的计算时间,这个程序包,GPU就能适应十倍大的模型。

还有这种操作?

训练神经网络对内存的要求随着网络的深度和batch-size呈线性增长。在内存有限的情况下,如果想训练深层模型,并且增加batch-size,很多研究人员会采用KFAC这样的二阶方法。与小批量的SGD相比,这种方法发需要学习较少的样例。

重点来了。昨天,OpenAI的研究科学家Tim Salimans和前Google Brain工程师的数据科学家Yaroslav Bulatov两人发布了一个python/TensorFlow包,名为gradient-checkpointing。

这个程序包使用了“用亚线性的存储成本训练神经网络”的技术,为简单的前馈网络提供了等价的内存存储,同时能为一般的神经网络节省内存,比如多层架构。

将这个程序包应用到TensorFlow官方CIFAR10 ResNet示例中。在batch size=1280的情况下,将内存和执行时间情况如下图所示。

常规反向传播为线性扩展,但优化后的方法以深度的平方根方式扩展。当我们在更深层次的网络上尝试时,差异就更明显了。

用标准方法,运行这个迭代需要60GB的内存,但新方法只需6GB的RAM。

再来看看计算时间。在实验中,在GTX1080上的运行时间增加了20%,在V100 GPU上时间增加了30%。

如果想了解这个程序包是如何节约内存的,可以移步GitHub一探究竟:

https://github.com/openai/gradient-checkpointing

原文发布于微信公众号 - 量子位(QbitAI)

原文发表时间:2018-01-16

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

发表于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏吉浦迅科技

【在线视频】如何在GPU上进行混合精度训练

使用精度低于FP32的系统可以减少内存使用,允许部署更大的网络。数据传输需要更少的时间,而且计算性能会提高,尤其是在NVIDIA gpu上,它的Tensor C...

3171
来自专栏机器之心

业界 | 现代「罗塞塔石碑」:微软提出深度学习框架的通用语言

选自arXiv 作者:Ilia Karmanov等 机器之心编译 参与:路雪、刘晓坤、白妤昕 深度学习框架就像语言一样:很多人会说英语,但每种语言都有自己的特殊...

3444
来自专栏AI研习社

不可错过的TensorFlow工具包,内含8大算法,即去即用!

这是来自谷歌的工程师Ashish Agarwal2017 TensorFlow开发者峰会在的演讲,主题是《ML Toolkit》。他认为TensorFlow 是...

4293
来自专栏有趣的Python和你

sklearn调包侠之无敌小抄

1606
来自专栏AI研习社

TensorFlow实现神经网络入门篇

如果你一直关注数据科学/机器学习,你就不能错过深度学习和神经网络的热潮。互联网公司正在寻找这方面的人,而且从竞赛到开源项目,都有巨额奖金。 如果你对深度学习所提...

3734
来自专栏数据派THU

教你用TensorFlow实现神经网络(附代码)

? 来源:云栖社区 作者:Pavel Surmenok 本文长度为2600字,建议阅读5分钟 本文帮助你理解神经网络的应用,并使用TensorFlow解决现实...

2628
来自专栏AI科技评论

干货 | AI 从业者都应该知道的实验数据集

AI 科技评论按:数据集对于深度学习模型的重要性不言而喻,然而根据性质、类型、领域的不同,数据集往往散落在不同的资源平台里,急需人们做出整理。 fast.ai ...

1373
来自专栏技术随笔

[译] Introduction to debugging neural networks

3506
来自专栏AI研习社

2017 TensorFlow开发者峰会之ML工具包

这是来自谷歌的工程师Ashish Agarwal的演讲,主题是《ML Toolkit》。他认为TensorFlow 是一项很棒的技术,在谷歌,它已经在为很多系统...

3043
来自专栏自学笔记

Label Propagation

Label propagation是基于标传播的一种社区划分算法。Label Propagation Algorithm简称LPA算法,也可以是说是一种划分小团...

2254

扫码关注云+社区