本文介绍基于Tensorflow的强化学习off policy算法的分布式实现,包括多机共享replay buffer。分布式 TensorFlow 允许我们在多台机器上运行一个模型,所以训练速度或加速效果能显著地提升。
首先定义集群信息,我们将启动一个parameter server (PS),和多个Worker在localhost:2222
和localhost:2223
等,在本机运行这些进程。
# 在配置文件中
...
self.parameter_servers = ["localhost:2222"]
self.workers = []
for i in range(workers_num): self.workers.append("localhost:"+str(2223+i))
下面代码都是在主程序中。
定义集群信息,并启动server,指定每个server对应为集群定义中的哪个server。立即启动各server,监听集群设置中指定的端口。
cluster = tf.train.ClusterSpec({"ps": opt.parameter_servers, "worker": opt.workers})
server = tf.train.Server(cluster, job_name=FLAGS.job_name, task_index=FLAGS.task_index)
所有的server上运行的同一个计算图,其中的变量都将保存在PS上,在所有server上共享。
if FLAGS.job_name == "ps":
server.join()
elif FLAGS.job_name == "worker":
# Variable is placed in the parameter server by the replica_device_setter
with tf.device(tf.train.replica_device_setter(worker_device="/job:worker/task:%d" % FLAGS.task_index, cluster=cluster)):
...
这部分定义计算图
# count the number of updates
global_epoch = tf.get_variable('global_epoch', [], initializer=tf.constant_initializer(0), trainable=False)
epoch_op = global_epoch.assign(global_epoch+1
...
with tf.Session(server.target) as sess:
sess.run(tf.global_variables_initializer())
sess.run(target_init)
...
执行程序部分
...
PS (parameter server)保存参数,接收worker发来的梯度,并应用梯度。
Worker从PS获取参数,将批次数据传入模型,计算向前传播和向后传播,计算梯度,最后将梯度发送给PS。
循环过程如下:
在worker从ps上读取参数的时候。如果worker在权重更新到一半的时候读取了参数(如:一半参数是更新过的,另一半还没有更新),那这个一半更新一半未更新的参数就被读取和使用了。这样做运行速度会比较快。
上面的代码中,我们通过replica_device_setter
来共享模型参数。那么在off policy算法中,如何共享replay buffer呢?由于我们的程序要实现分布式的功能,不仅仅可以在单机上多进程训练,而且可以在多机上进行分布式训练。所以我们通过ray的分布式机制来实现多机共享replay buffer类的对象。
导入必要的包
import ray
from ray.utils import hex_to_binary
在某一台机器上启动ray服务。
ray start --head --redis-port=6379
在需要共享的ReplayBuffer类上面加一行@ray.remote
,表明该类可以被ray远程操作。
@ray.remote
class ReplayBuffer:
def __init__(self, obs_dim, act_dim, size):
...
初始化ray,并定义共享buffer的ray的object id。我们通过一个唯一的id来实现多个进程共享同一个buffer。ray的object id是一个专门的类对象。我们不方便保存这个对象,所以保存这个对象对应的字符串,可以通过这个字符串生成这个对象。定义一个Variable
变量来保存这个字符串,Variable
变量是保存在ps上的。
# [ip:port]为启动ray服务的电脑IP和端口 例如:192.168.123.123:6379
# ray的初始化,每个进程都连接到ray服务器上。
ray.init(redis_address="[ip:port]")
buffer_id_str = tf.get_variable('buffer_id_str', [], dtype=tf.string)
...
在某一个进程上创建共享buffer,并将该buffer的object id赋给之前定义的buffer_id_str。创建的共享bufferput
到ray中,即将这个对象保存在ray的服务器中,并返回一个object id。通过这个object id我们直接对这个对象进行操作,也就是说每个进程通过这个唯一的object id直接对保存在ray server上的这个对象操作。
...
if is_chief == 0:
replay_buffer = ReplayBuffer.remote(obs_dim=obs_dim, act_dim=act_dim, size=opt.replay_size)
buffer_id = ray.put(replay_buffer)
buffer_id_op = buffer_id_str.assign(str(buffer_id)[9:-1])
sess.run(buffer_id_op)
...
worker需要使用共享buffer时,通过字符串得到共享buffer。
...
buffer_id = ray.ObjectID(hex_to_binary(sess.run(buffer_id_str)))
replay_buffer = ray.get(buffer_id)
# 将经验存入共享buffer中
replay_buffer.store.remote(o, a, r, o2, d)
# 从共享buffer中采样
batch = ray.get(replay_buffer.sample_batch.remote(opt.batch_size))
...
完整代码参考:
https://github.com/LiuShuai26/DRL/blob/master/DSAC1/dsac1.py
参考:
https://github.com/tensorflow/examples/blob/master/community/en/docs/deploy/distributed.md
https://stackoverflow.com/questions/43147435/how-does-asynchronous-training-work-in-distributed-tensorflow
https://stackoverflow.com/questions/41600321/distributed-tensorflow-the-difference-between-in-graph-replication-and-between
https://zhuanlan.zhihu.com/p/60474307