专栏首页AI研习社Github 项目推荐 | 在 Spark 上实现 TensorFlow 的库 —— Sparkflow

Github 项目推荐 | 在 Spark 上实现 TensorFlow 的库 —— Sparkflow

该库是 TensorFlow 在 Spark 上的实现,旨在 Spark 上使用 TensorFlow 提供一个简单的、易于理解的接口。借助 SparkFlow,开发者可以轻松地将深度学习模型与 ML Spark Pipeline 相集成。SparkFlow 使用参数服务器以分布式方式训练 Tensorflow 网络,通过 API,用户可以指定训练风格,无论是 Hogwild 还是异步锁定。

为什么要使用 SparkFlow

虽然有很多的库都能在 Apache Spark 上实现 TensorFlow,但 SparkFlow 的目标是使用 ML Pipelines,为训练 Tensorflow 图提供一个简单的界面,并为快速开发提供基本抽象。关于训练,SparkFlow 使用一个参数服务器,它位于驱动程序上并允许异步培训。此工具在训练大数据时提供更快的训练时间。

Github:

https://github.com/lifeomic/sparkflow

安装

通过 pip 安装:pip install sparkflow

安装需求:Apache Spark 版本 >= 2.0,同时安装好 TensorFlow

示例

简单的 MNIST 深度学习例子:

from sparkflow.graph_utils import build_graph
from sparkflow.tensorflow_async import SparkAsyncDL
import tensorflow as tf
from pyspark.ml.feature import VectorAssembler, OneHotEncoder
from pyspark.ml.pipeline import Pipeline
    
#simple tensorflow network
def small_model():
    x = tf.placeholder(tf.float32, shape=[None, 784], name='x')
    y = tf.placeholder(tf.float32, shape=[None, 10], name='y')
    layer1 = tf.layers.dense(x, 256, activation=tf.nn.relu)
    layer2 = tf.layers.dense(layer1, 256, activation=tf.nn.relu)
    out = tf.layers.dense(layer2, 10)
    z = tf.argmax(out, 1, name='out')
    loss = tf.losses.softmax_cross_entropy(y, out)
    return loss
    
df = spark.read.option("inferSchema", "true").csv('mnist_train.csv')
mg = build_graph(small_model)
#Assemble and one hot encode
va = VectorAssembler(inputCols=df.columns[1:785], outputCol='features')
encoded = OneHotEncoder(inputCol='_c0', outputCol='labels', dropLast=False)

spark_model = SparkAsyncDL(
    inputCol='features',
    tensorflowGraph=mg,
    tfInput='x:0',
    tfLabel='y:0',
    tfOutput='out:0',
    tfLearningRate=.001,
    iters=1,
    predictionCol='predicted',
    labelCol='labels',
    verbose=1
)

p = Pipeline(stages=[va, encoded, spark_model]).fit(df)
p.write().overwrite().save("location")

本文分享自微信公众号 - AI研习社(okweiwu),作者:AI 研习君

原文出处及转载信息见文内详细说明,如有侵权,请联系 yunjia_community@tencent.com 删除。

原始发表时间:2018-04-19

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

我来说两句

0 条评论
登录 后参与评论

相关文章

  • 如何利用微信监管你的TF训练?

    之前回答问题【在机器学习模型的训练期间,大概几十分钟到几小时不等,大家都会在等实验的时候做什么?(http://t.cn/Rl8119m)】的时候,说到可以用微...

    AI研习社
  • 手把手教你如何用 TensorFlow 实现基于 DNN 的文本分类

    编者按:本文源自一位数据科学开发者的个人博客,主要面对初学者用户,AI 研习社编译。另外,关于 TensorFlow 和 DNN 的更多深度内容,欢迎大家在文末...

    AI研习社
  • 代码+实战:TensorFlow Estimator of Deep CTR —— DeepFM/NFM/AFM/FNN/PNN

    深度学习在 ctr 预估领域的应用越来越多,新的模型不断冒出。从 ctr 预估问题看看 f(x) 设计—DNN 篇(https://zhuanlan.zhihu...

    AI研习社
  • 深度学习之 TensorFlow(五):mnist 的 Alexnet 实现

    希希里之海
  • TensorFlow API 树 (Python)

    JNingWei
  • tf25: 使用深度学习做阅读理解+完形填空

    记的在学生时代,英语考试有这么一种类型的题,叫:阅读理解。首先让你读一段洋文材料,然后回答一些基于这个洋文材料提的问题。 我先给你出一道阅读理解 Big ...

    MachineLP
  • tensorflow零起点快速入门(5) --强化学习摘录截图

    嘘、小点声
  • tensorflow 常用API

    注意tensorflow会检查类型,不指定类型时按照默认类型,如1认为是int32, 1.0认为是float32

    羽翰尘
  • TensorFlow基础:常量

    例如 tf.zeros,tf.ones,tf.zeros_like,tf.diag ...

    lyhue1991
  • TensorFlow-实战Google深度学习框架 笔记(上)

    TensorFlow 是一种采用数据流图(data flow graphs),用于数值计算的开源软件库。在 Tensorflow 中,所有不同的变量和运算都是储...

    范中豪

扫码关注云+社区

领取腾讯云代金券