深入学习Apache Spark和TensorFlow

神经网络在过去的几年中取得了惊人的进展,现在它们已经成为图像识别和自动翻译领域的领先技术。TensorFlow是Google发布的用于数值计算和神经网络的新框架。在这篇博文中,我们将演示如何使用TensorFlow和Spark一起来训练和应用深度学习模型。

您可能想知道:当大多数高性能深度学习是单节点实现时,Apache Spark在这里使用的是什么?为了回答这个问题,我们介绍两个用例,并解释如何使用Spark和一组机器来改进使用TensorFlow的深度学习管道:

  1. 超参数调整:使用Spark来寻找神经网络训练的最佳超参数集,从而使训练时间减少10倍,错误率降低34%。
  2. 大规模部署模型:使用Spark将经过训练的神经网络模型应用于大量数据。

超参数调整

深度学习机器学习(ML)技术的一个例子是人工神经网络。他们需要一个复杂的输入,如图像或录音,然后对这些信号应用复杂的数学变换。这个变换的输出是一个更容易被其他ML算法操纵的数字向量。人造神经网络通过模仿人类大脑的视觉皮层中的神经元(以非常简化的形式)来执行这种转变。

就像人类学会解释他们看到的一样,人工神经网络需要被训练来识别“有趣”的特定模式。例如,这些可以是简单的模式,例如边界,圆形,但是它们可能更复杂。在这里,我们将使用NIST的经典数据集,并训练一个神经网络来识别这些数字:

TensorFlow库自动创建各种形状和大小的神经网络的训练算法。然而,构建神经网络的实际过程比在数据集上运行某个函数要复杂得多。通常有许多非常重要的超参数(非专业人员的配置参数)来设置,这会影响模型的训练。选择正确的参数会导致高性能,而错误的参数会导致长时间的训练和糟糕的性能。在实践中,机器学习从业者用不同的超参数重复运行相同的模型,以找到最佳组合。这是一种称为超参数调整的经典技术。

在建立神经网络时,有许多重要的超参数要慎重选择。例如:

  • 每层神经元的数量:神经元数量太少会降低网络的表达能力,但太多神经元会大幅增加运行时间并返回噪声估计值。
  • 学习率:如果它太高,神经网络将只关注最后看到的几个样本,而不考虑以前积累的所有经验。如果太低,达到一个好的状态将需要很长的时间。

这里有趣的是,即使TensorFlow本身不是分布式的,超参数调优过程也是“令人尴尬的并行”,可以使用Spark进行分发。在这种情况下,我们可以使用Spark来广播数据和模型描述等通用元素,然后以容错的方式在一组机器上安排单个重复计算。

如何使用Spark提高准确性?默认超参数组的准确度是99.2%。超参数调优的最佳结果在测试集上的准确率为99.47%,测试误差减少34%。将计算的线性分布与添加到集群中的节点的数量进行比例分配:使用13节点的集群,我们能够并行训练13个模型,相比于在一台机器上一次一个地训练模型,这转化为7倍的加速。以下是关于群集中计算机数量的计算时间(以秒为单位)的图形:

更重要的是,我们深入了解培训程序对各种超参数培训的敏感性。例如,对于不同数量的神经元,我们绘制关于学习速率的最终测试性能:

这显示了神经网络的典型权衡曲线:

  • 学习率是至关重要的:如果它太低,神经网络不会学到任何东西(高测试错误)。如果太高,则训练过程可能会随机摆动,甚至在某些配置上发散。
  • 神经元的数量对于获得良好的表现并不重要,而且具有许多神经元的网络对学习速率更为敏感。这是Occam的剃刀原理:对于大多数目的,简单的模型往往是“足够好”的。如果您有足够的时间和资源去处理错过1%的测试错误,那么您必须愿意投入大量的资源来进行培训,并找到适当的超参数,这些参数会有所作为。

通过使用参数的稀疏样本,我们可以将最有希望的一组参数归零。

我如何使用它?

由于TensorFlow可以使用每个工作人员的所有内核,因此我们只能在每个工作人员上同时运行一个任务,并将他们一起批处理以限制争用。按照TensorFlow网站上的说明, TensorFlow库可以作为常规Python库安装在Spark集群上。下面的笔记本展示了如何安装TensorFlow并让用户重新运行这篇博文的实验:

使用TensorFlow分布式处理图像

使用TensorFlow测试图像的分布处理

按比例部署模型

TensorFlow模型可以直接嵌入管道中,以便对数据集执行复杂的识别任务。作为一个例子,我们展示了如何从一个已经被训练的股票神经网络模型标记一组图像。

该模型首先使用Spark内置的广播机制分发给集群的工作人员:

用gfile 。FastGFile (' classify_image_graph_def 。PB ',' RB ')作为˚F :
  model_data = ˚F 。read ()
model_data_bc = sc 。广播(model_data )   

然后将这个模型加载到每个节点上并应用于图像。这是在每个节点上运行的代码草图:

def apply_batch (image_url ):#创建一个新的TensorFlow计算图并用tf 导入模型。Graph ()。as_default ()为g :
    graph_def = tf 。GraphDef ()
    graph_def 。ParseFromString (model_data_bc 。值)
    TF 。import_graph_def (graph_def ,name =“)
  
   
    #从URL加载图像数据:
    image_data = urllib 。请求。urlopen (img_url ,timeout = 1.0 )。read ()
    #运行用tf 加载的张量流会话。Session ()作为sess :
      softmax_tensor = sess 。图表。get_tensor_by_name (' softmax :0 ')
      预测= sess 。运行(softmax_tensor ,{' DecodeJpeg / contents :0 ':image_data })返回预测

这些代码可以通过将图像拼凑在一起来提高效率。

这里是一个图像的例子:

这是根据神经网络对这个图像的解释,这是相当准确的:

(' 珊瑚礁',0.88503921 ),(' 潜水员',0.025853464 ),(' 脑珊瑚',0.0090828091 ),(' 呼吸管',0.0036010914 ),(' 海角,岬角,头部,前陆',0.0022605944 ) 

期待

我们已经展示了如何结合Spark和TensorFlow在手写数字识别和图像标签上训练和部署神经网络。尽管我们使用的神经网络框架只能在单节点中工作,但我们可以使用Spark来分配超参数调整过程和模型部署。这不仅减少了训练时间,而且提高了准确性,使我们更好地理解各种超参数的敏感性。

虽然这种支持仅适用于Python,但我们期望在TensorFlow和Spark框架的其他部分之间提供更深入的整合。

本文的版权归 KX_WEN 所有,如需转载请联系作者。

发表于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏数据派THU

教你在Python中构建物体检测系统(附代码、学习资料)

本文介绍物体检测技术以及解决此领域问题的几种不同方法,带你深入研究在Python中如何构建我们自己的对象检测系统。

2473
来自专栏IT技术精选文摘

机器学习在启动耗时测试中的应用及模型调优(一)

启动耗时自动化方案在关键帧识别时,常规的图像对比准确率很低。本文详细介绍了采用scikit-learn图片分类算法在启动耗时应用下的模型调优过程。在之后的续篇中...

1694
来自专栏最新技术

深入学习Apache Spark和TensorFlow

想要了解更多关于Apache Spark的信息,请在2016年2月在纽约出席Spark东部峰会。

3088
来自专栏机器学习算法工程师

Isolation Forest算法原理详解

作者:章华燕 编辑:栾志勇 前言 随着机器学习近年来的流行,尤其是深度学习的火热。机器学习算法在很多领域的应用越来越普遍。最近,作者在一家广告公司做广告...

7358
来自专栏机器之心

学界 | Ian Goodfellow等人提出对抗重编程,让神经网络执行其他任务

作者:Gamaleldin F. Elsayed、Ian Goodfellow、Jascha Sohl-Dickstein

1233
来自专栏大数据文摘

深度 | 你的神经网络不work? 这37个原因总有一款适合你!

1413
来自专栏AI2ML人工智能to机器学习

Hopfield网络及其收敛性

在上一次的神经网络之双向关联记忆网络(BAM)中我们介绍了神经网络中能量的概念。在BAM的基础上稍加改变就可以得到著名的Hopfield网络。

1013
来自专栏机器之心

学界 | 深度神经网络的分布式训练概述:常用方法和技巧全面总结

深度学习已经为人工智能领域带来了巨大的发展进步。但是,必须说明训练深度学习模型需要显著大量的计算。在一台具有一个现代 GPU 的单台机器上完成一次基于 Imag...

1952
来自专栏机器之心

专栏 | 手机端运行卷积神经网络实践:基于TensorFlow和OpenCV实现文档检测功能

机器之心投稿 作者:腾讯 iOS 客户端高级工程师冯牮 本文作者通过一个真实的产品案例,展示了在手机客户端上运行一个神经网络的关键技术点。 前言 本文不是神经网...

4255
来自专栏AI科技评论

内部分享:这篇文章教你如何用神经网络破Flappy Bird记录

AI科技评论按:本文作者杨浩,原文载于作者个人博客。 以下内容来源于一次部门内部的分享,主要针对 AI 初学者,介绍包括 CNN、Deep Q Network...

3827

扫码关注云+社区

领取腾讯云代金券