前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >强化学习异步分布式训练实现

强化学习异步分布式训练实现

作者头像
CreateAMind
发布2019-08-09 13:49:47
1.7K1
发布2019-08-09 13:49:47
举报
文章被收录于专栏:CreateAMind

本文介绍基于Tensorflow的强化学习off policy算法的分布式实现,包括多机共享replay buffer。分布式 TensorFlow 允许我们在多台机器上运行一个模型,所以训练速度或加速效果能显著地提升。

首先定义集群信息,我们将启动一个parameter server (PS),和多个Worker在localhost:2222localhost:2223等,在本机运行这些进程。

代码语言:javascript
复制
# 在配置文件中
...
self.parameter_servers = ["localhost:2222"]
self.workers = []
for i in range(workers_num):    self.workers.append("localhost:"+str(2223+i))

下面代码都是在主程序中。

定义集群信息,并启动server,指定每个server对应为集群定义中的哪个server。立即启动各server,监听集群设置中指定的端口。

代码语言:javascript
复制
cluster = tf.train.ClusterSpec({"ps": opt.parameter_servers, "worker": opt.workers})
server = tf.train.Server(cluster, job_name=FLAGS.job_name, task_index=FLAGS.task_index)

所有的server上运行的同一个计算图,其中的变量都将保存在PS上,在所有server上共享。

代码语言:javascript
复制
        if FLAGS.job_name == "ps":
            server.join()
        elif FLAGS.job_name == "worker":
            # Variable is placed in the parameter server by the replica_device_setter
            with tf.device(tf.train.replica_device_setter(worker_device="/job:worker/task:%d" % FLAGS.task_index, cluster=cluster)):
                ...
                这部分定义计算图
                # count the number of updates
                global_epoch = tf.get_variable('global_epoch', [], initializer=tf.constant_initializer(0), trainable=False)
                epoch_op = global_epoch.assign(global_epoch+1
                ...
                with tf.Session(server.target) as sess:
                    sess.run(tf.global_variables_initializer())    
                    sess.run(target_init)
                    ...
                    执行程序部分
                    ...

PS (parameter server)保存参数,接收worker发来的梯度,并应用梯度。

Worker从PS获取参数,将批次数据传入模型,计算向前传播和向后传播,计算梯度,最后将梯度发送给PS。

循环过程如下:

  1. Workers并行从PS中获取模型参数。
  2. Workers在本机根据批次数据运行模型计算梯度。
  3. Workers将梯度发送给PS。PS通过优化器用梯度分别更新每个参数。

在worker从ps上读取参数的时候。如果worker在权重更新到一半的时候读取了参数(如:一半参数是更新过的,另一半还没有更新),那这个一半更新一半未更新的参数就被读取和使用了。这样做运行速度会比较快。


分布式共享replay buffer

上面的代码中,我们通过replica_device_setter来共享模型参数。那么在off policy算法中,如何共享replay buffer呢?由于我们的程序要实现分布式的功能,不仅仅可以在单机上多进程训练,而且可以在多机上进行分布式训练。所以我们通过ray的分布式机制来实现多机共享replay buffer类的对象。

导入必要的包

代码语言:javascript
复制
import ray
from ray.utils import hex_to_binary

在某一台机器上启动ray服务。

代码语言:javascript
复制
ray start --head --redis-port=6379

在需要共享的ReplayBuffer类上面加一行@ray.remote,表明该类可以被ray远程操作。

代码语言:javascript
复制
@ray.remote
class ReplayBuffer:
    def __init__(self, obs_dim, act_dim, size):
        ...

初始化ray,并定义共享buffer的ray的object id。我们通过一个唯一的id来实现多个进程共享同一个buffer。ray的object id是一个专门的类对象。我们不方便保存这个对象,所以保存这个对象对应的字符串,可以通过这个字符串生成这个对象。定义一个Variable变量来保存这个字符串,Variable变量是保存在ps上的。

代码语言:javascript
复制
# [ip:port]为启动ray服务的电脑IP和端口 例如:192.168.123.123:6379
# ray的初始化,每个进程都连接到ray服务器上。
ray.init(redis_address="[ip:port]")
buffer_id_str = tf.get_variable('buffer_id_str', [], dtype=tf.string)
...

在某一个进程上创建共享buffer,并将该buffer的object id赋给之前定义的buffer_id_str。创建的共享bufferput到ray中,即将这个对象保存在ray的服务器中,并返回一个object id。通过这个object id我们直接对这个对象进行操作,也就是说每个进程通过这个唯一的object id直接对保存在ray server上的这个对象操作。

代码语言:javascript
复制
...
if is_chief == 0:
    replay_buffer = ReplayBuffer.remote(obs_dim=obs_dim, act_dim=act_dim, size=opt.replay_size)    
    buffer_id = ray.put(replay_buffer)    
    buffer_id_op = buffer_id_str.assign(str(buffer_id)[9:-1])
    sess.run(buffer_id_op)
...

worker需要使用共享buffer时,通过字符串得到共享buffer。

代码语言:javascript
复制
...
buffer_id = ray.ObjectID(hex_to_binary(sess.run(buffer_id_str)))
replay_buffer = ray.get(buffer_id)
代码语言:javascript
复制
# 将经验存入共享buffer中
replay_buffer.store.remote(o, a, r, o2, d)
代码语言:javascript
复制
# 从共享buffer中采样
batch = ray.get(replay_buffer.sample_batch.remote(opt.batch_size))
...

完整代码参考:

https://github.com/LiuShuai26/DRL/blob/master/DSAC1/dsac1.py

参考:

https://github.com/tensorflow/examples/blob/master/community/en/docs/deploy/distributed.md

https://stackoverflow.com/questions/43147435/how-does-asynchronous-training-work-in-distributed-tensorflow

https://stackoverflow.com/questions/41600321/distributed-tensorflow-the-difference-between-in-graph-replication-and-between

https://zhuanlan.zhihu.com/p/60474307

本文参与 腾讯云自媒体同步曝光计划,分享自微信公众号。
原始发表:2019-08-06,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 CreateAMind 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 分布式共享replay buffer
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档