各位数据大咖,还记得自己的跑模型的心路历程么?我想大家都在经历着下面的一个或多个阶段:
数据分析师在进阶,工具在进阶,但数据也在进阶!
传统机器学习方法,需要把训练数据集中于某一台机器或是单个数据中心里,为了满足逐渐增加的数据量级,还要不断加机器、不断建设基础设施。
现在,谷歌研发出一种训练 AI 的新模式,可以直接在用户的手机上训练并改进 AI 算法,数据都保存在终端手机里。更神奇的是,多台手机之间还能进行协作训练,共享预测模型。
它有一个很霸气的名字——Federated Learning,联盟学习!
工作原理
Federated Learning 的工作流程如下:
(模型训练循环图,图来源于googleblog)
模型优点
这样一种「联盟学习」的模式,有何优点?数说君根据谷歌的官方文章,给出如下总结:
模型应用
目前,谷歌已经在谷歌输入法 Gboard 上测试该模型。
有个背景数说君要先介绍一下,Gboard 不仅是一个简单的输入法,它还在键盘上集成了 Google 搜索,在输入文字的同时拥有了强大的第二大脑。
当使用 Gboard 集成的 Google搜索 功能时,Google搜索 会显示推荐搜索项,此时手机会在将搜索内容储存在本地。Federated Learning 会对本地的这些数据进行处理训练, 以用来改进 Gboard 检索推荐模型。
挑战与解决
然而问题还是有的,谷歌承认实现 Federated Learning 还有一些技术上的挑战:
在典型的机器学习系统中,超大型数据集会被平均分割到云端的多个服务器上。像随机梯度下降(SGD)这样的优化算法很适合在此上面运行。因为这些反复迭代的算法,需要与训练数据集之间有低延迟、高流量的连接。 但在 Federated Learning 系统中,数据以非常不平均的方式分布在数百万的移动设备上。而且,智能手机的延迟更高、吞吐的流量更低,并且仅可在保证用户日常使用的前提下,断断续续地进行训练。
为解决这些问题,谷歌专门开发出了一套名为 Federated Averageing 的算法(见参考资料(3)),相比于原生 SGD 算法,该算法在训练深度神经网络时,只需要10%~1%的网络通信要求。
由于上传速度一般都会比下载速度慢很多,为把上传速度再提升,谷歌为此还通过使用 random rotation 和 quantization 来压缩更新,把上传速度再减少100倍(见参考资料(4))。
另外,谷歌还专门设计了一个针对高维稀疏 convex 模型的算法 Federate Optimization,该算法特别擅长解决点击率预测等问题(见参考资料(5))。
未来,谷歌会不断拓展 Federated Learning 的功能,并希望能根据手机输入习惯改进语言模型;以及根据图片浏览数据改进图片排列等。
参考资料:
(1) Federated Learning: Collaborative Machine Learning without Centralized Training Data,https://research.googleblog.com/2017/04/federated-learning-collaborative.html
(2) Practical Secure Aggregation for Privacy Preserving Machine Learning,http://eprint.iacr.org/2017/281
(3) Communication-Efficient Learning of Deep Networks from Decentralized Data,https://arxiv.org/abs/1602.05629
(4) Federated Learning: Strategies for Improving Communication Efficiency, https://arxiv.org/abs/1610.05492
(5) Federated Optimization: Distributed Machine Learning for On-Device Intelligence,https://arxiv.org/abs/1610.02527
- END -