3个关键点,把你的TensorFlow代码重构为分布式!

对于机器学习模型,分布式大致分两类:模型分布式和数据分布式:

模型分布式非常复杂和灵活, 它把整个机器学习模型分割,分散在多个节点上,在每个节点上计算模型的各个部分, 最后把结果拼接起来。如果你造了一个并行性很高的深度网络,比如这个,那就更棒了。你只要在每个节点上,计算不同的层,最后把各个层的异步结果通过较为精妙的方式汇总起来。

而我们今天要手把手教大家的是数据分布式。模型把数据拷贝到多个节点上, 每次算Epoch迭代的时候,每个节点对于一个batch的梯度都会有一个计算值,一个batch结束后,所有节点把梯度值汇总起来(ps参数服务器的任务就是汇总所有参数更新),从而进行更新。这就会导致每个batch的计算都比非分布式方法精准。相对非分布式,并行方法下,同样的迭代次数,收敛较快。

如何把自己的单机TensorFlow代码变为分布式的代码?

本文将手把手告诉大家3个关键点,重构自己的TensorFlow代码为分布式代码(开始前请大家前用1分钟了解文末的参考文献,了解基本知识):

关键点1: 定义FLAGS全局变量,获得ps参数服务器,worker工作服务器等分布式全局信息。

# Define parameters
FLAGS = tf.app.flags.FLAGS
tf.app.flags.DEFINE_float('learning_rate', 0.00003, 'Initial learning rate.')
tf.app.flags.DEFINE_integer('steps_to_validate', 1000,
                            'Steps to validate and print loss')

# For distributed
tf.app.flags.DEFINE_string("ps_hosts", "",
                           "Comma-separated list of hostname:port pairs")
tf.app.flags.DEFINE_string("worker_hosts", "",
                           "Comma-separated list of hostname:port pairs")
tf.app.flags.DEFINE_string("job_name", "", "One of 'ps', 'worker'")
tf.app.flags.DEFINE_integer("task_index", 0, "Index of task within the job")
tf.app.flags.DEFINE_integer("issync", 0, "issync mode")

以上代码是从命令行获得变量的简单方式。使用TensorFlow自带的FLAGS命令行工具。

ps_hosts = FLAGS.ps_hosts.split(",")
worker_hosts = FLAGS.worker_hosts.split(",")
cluster = tf.train.ClusterSpec({"ps": ps_hosts, "worker": worker_hosts})
server = tf.train.Server(
    cluster,
    job_name=FLAGS.job_name,

上述代码教你如何获得命令行变量到python变量。ps_hosts代表所有参数服务器,work_hosts是所有工作服务器。cluster组装一个分布式集群定义。server代表本地为任务分配的服务器。

关键点2: 在流图Graph定义阶段, 加入“参数服务器”和“工作服务器”的判断,重构Graph定义代码。

if FLAGS.job_name == "ps":
    server.join()
elif FLAGS.job_name == "worker":
    with tf.device(tf.train.replica_device_setter(
                    worker_device="/job:worker/task:%d" % FLAGS.task_index,
                    cluster=cluster)):
        # 这里是各个worker工作服务器下的graph定义。

如果当前服务器是ps参数服务器,当前服务器就要执行join方法汇总更新的参数。

如果当前是工作服务器,构建deVice设备上下文,复制数据到各个设备,并且知道任务号,之后再定义原先的Graph。

关键点3: 最后,重构你原来的graph定义和TensorFlow Session训练的方式细节。

    grads_and_vars = optimizer.compute_gradients(cost)
    correct_prediction = tf.equal(y_pred_cls, y_true_cls)
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

    rep_op = tf.train.SyncReplicasOptimizer(optimizer,
                                            replicas_to_aggregate=len(
                                            worker_hosts),
                                            #replica_id=FLAGS.task_index,
                                            total_num_replicas=len(
                                            worker_hosts),
                                            use_locking=True)
    train_op = rep_op.apply_gradients(grads_and_vars,
                                global_step=global_step)
    init_token_op = rep_op.get_init_tokens_op()
    chief_queue_runner = rep_op.get_chief_queue_runner()

    init_op = tf.initialize_all_variables()
    saver = tf.train.Saver()
    train_batch_size = batch_size
    tf.summary.scalar('cost', cost)
    tf.summary.scalar('accuracy', accuracy)
    summary_op = tf.summary.merge_all()
    summary_writer = tf.summary.FileWriter('./summary_log/train')
    summary_writer_test = tf.summary.FileWriter('./summary_log/test')

sv = tf.train.Supervisor(is_chief=(FLAGS.task_index == 0),
                            logdir="./checkpoint/",
                            init_op=init_op,
                            summary_op=None,
                            saver=saver,
                            global_step=global_step,
                            save_model_secs=60)
session = sv.prepare_or_wait_for_session(server.target)
sv.start_queue_runners(session, [chief_queue_runner])
session.run(init_token_op)

训练中稍有不同的是上面这段代码,graph定义完毕后,我们要用optimizer.compute_gradients方法计算梯度得到grads_and_vars对象。通过SyncReplicasOptimizer这个特殊的优化器,进行梯度的计算,即rep_op.apply_gradients(grads_and_vars, global_step=global_step)方法。

计算完毕得到的train_op对象就能在未来想用session.run()的地方使用了:

session.run([train_op, cost, global_step], feed_dict=feed_dict_train)

注意以上三个关键点, 你离TensorFlow并行化已经八九不离十了。

实际重构的例子,请看我github上识别猫狗的基本程序:

分布式版:

https://github.com/yanchao727/tensorflow_kaggle_cat_dog/blob/master/cnn.distributed.py

单机版:

https://github.com/yanchao727/tensorflow_kaggle_cat_dog/blob/master/cnn.py

参考文献

  1. https://github.com/thewintersun/distributeTensorflowExample
  2. https://www.tensorflow.org/deploy/distributed
  3. http://blog.csdn.net/luodongri/article/details/52596780
  4. https://www.slideshare.net/stanleywanguni/distributed-machine-learning
  5. https://www.quora.com/How-is-parallel-computing-used-in-machine-learning

本文转载自:David 9的博客 — 不怕"过拟合"

原文发布于微信公众号 - 进击的Coder(FightingCoder)

原文发表时间:2018-04-20

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

发表于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏CVer

资源 | GitHub超过2600星的TensorFlow教程,简洁清晰还不太难

最近,弗吉尼亚理工博士Amirsina Torfi在GitHub上贡献了一个新的教程,教程清晰简单,喜提2600颗星~

602
来自专栏机器学习算法全栈工程师

分布式TensorFlow入门教程

深度学习在各个领域实现突破的一部分原因是我们使用了更多的数据(大数据)来训练更复杂的模型(深度神经网络),并且可以利用一些高性能并行计算设备如GPU和FPGA来...

1163
来自专栏AI科技大本营的专栏

重磅消息 | 深度学习框架竞争激烈 TensorFlow也支持动态计算图

今晨 Google 官方发布消息,称 TensorFlow 支持动态计算图。 原文如下: 在大部分的机器学习中,用来训练和分析的数据需要经过一个预处理过程,输入...

2615
来自专栏编程

Python那些事——15分钟用Python破解验证码系统!

让我们一起攻破世界上最流行的WordPress的验证码插件 每个人都讨厌验证码——在你被允许访问一个网站之前,你总被要求输入那些烦人的图像中所包含的文本。 验证...

2479
来自专栏北京马哥教育

Python破解验证码,只要15分钟就够了!

让我们一起攻破世界上最流行的WordPress的验证码插件 每个人都讨厌验证码——在你被允许访问一个网站之前,你总被要求输入那些烦人的图像中所包含的文本。 验...

3666
来自专栏新智元

10 亿图片仅需 17.7微秒:Facebook AI 实验室开源图像搜索工具Faiss

【新智元导读】Facebook的 FAIR 最新开源了一个用于有效的相似性搜索和稠密矢量聚类的库,名为 Faiss,在10亿图像数据集上的一次查询仅需17.7 ...

3395
来自专栏数说工作室

文本相似比较

大家好,我是数说君,这篇文章是想跟大家讨教一下。 如果有两段简单文本,如何比较它们的相似度?这里我们就假设是英文,不存在中文的分词问题,文本就类似于: text...

34714
来自专栏大数据挖掘DT机器学习

Python 自然语言处理(NLP)工具库汇总

最近正在用nltk 对中文网络商品评论进行褒贬情感分类,计算评论的信息熵(entropy)、互信息(point mutual information)和困惑值(...

38112
来自专栏SDNLAB

使用机器学习算法对流量分类的尝试——基于样本分类

导言 机器学习方法目前可以分为5个流派,分别是符号主义,联结主义,进化主义,贝叶斯和Analogzier。具体到实例有联结主义的神经网络,进化主义的遗传算法,贝...

46812
来自专栏FreeBuf

用机器学习玩转恶意URL检测

前段时间漏洞之王Struts2日常新爆了一批漏洞,安全厂商们忙着配合甲方公司做资产扫描,漏洞排查,规则大牛迅速的给出”专杀”规则强化自家产品的规则库。这种基于规...

5859

扫码关注云+社区