前言
强化学习算法的并行化可以有效提高算法的效率。并行化可以使单机多cpu的资源得到充分利用,并行化也可以将算法中各个部分独立运行,从而提高运行效率,如将环境交互部分和训练网络部分分开。我们这里介绍如何使用分布式框架Ray以最简单的方式实现算法的并行化。
本文章分为三节:
Ray是一个实现分布式python程序的通用框架。Ray提供了统一的任务并行和actor抽象,并通过共享内存、零拷贝序列化和分布式调度实现了高性能。
Ray里面还有用来调超参数的库Tune和可扩展规模的强化学习库Rllib。
ray的必备知识:
ray.remote
]ray.put
, ray.get
, ray.wait
]ray.remote
]使用Ray,可以使你的代码从单机运行轻松地扩展到大集群上运行。
下面主要介绍ray的基本用法,并行运算为单机并行。
使用该命令安装Ray:pip install -U ray
开始使用ray,导入ray,然后初始化。
import ray
ray.init()
ray.remote
]将python函数转换为远程函数的标准方法是在函数上面添加一个@ray.remote
装饰器。下面看一个例子。
# 一个常规 Python 函数
def regular_function():
return 1
# 一个 Ray 远程函数
@ray.remote
def remote_function():
return 1
下面是调用时的不同。
assert regular_function() == 1
object_id = remote_function.remote()
assert ray.get(object_id) == 1
在调用的时候,普通函数将串行运行。
# These happen serially.
for _ in range(4):
regular_function()
调用远程函数时,程序将并行运行。
# These happen in parallel.
for _ in range(4):
remote_function.remote()
小提示:为了保证ray并行的性能,远程任务应该花费至少几毫秒的时间。
运行ray.init()
后,ray将自动检查可用的GPU和CPU。我们也可以传入参数设置特定的资源请求量。
ray.init(num_cpus=8, num_gpus=4)
远程函数/类也可以设置资源请求量,像这样@ray.remote(num_cpus=2, num_gpus=1)
如果没有设置,默认设置为1个CPU。
ray.put
, ray.get
, ray.wait
]远程函数执行后并不会直接返回结果,而是会立即返回一个object ID。远程函数会在后台并行处理,等执行得到最终结果后,可以通过返回的object ID取得这个结果。
ray.put(*value*)
也会返回object ID
put操作将对象存入object store里,然后返回它的object ID。
y = 1
object_id = ray.put(y)
小提示:当需要重复向不同远程任务传入相同对象时,Ray会每次先将对象put进object store。我们可以先用ray.put()把类存入object store,然后传入它的object id,以提高速度。
ray.get(obj_id)
从object store获取远程对象或者一个列表的远程对象。
需要注意的是,使用get方法时会锁,直到要取得的对象在本地的object store里可用。
调用remote操作是异步的,他们会返回object IDs而不是结果。想要得到真的的结果我们需要使用ray.get()。
我们之前写的这段语句,实际上results是一个由object IDs组成的列表。
results = [do_some_work.remote(x) for x in range(4)]
如果改为下面,ray.get()将通过object ID取得真实的结果。
results = [ray.get(do_some_work.remote(x)) for x in range(4)]
但是,这样写会有一个问题。ray.get()会锁进程,这意味着,ray.get()会一直等到do_some_work这个函数执行完返回结果后才执行结束然后进入下一个循环。这样的话,4次调用do_some_work函数就不再是并行运行的了。
为了可以并行运算,我们需要在调用完所有的任务后再调用ray.get()。像下面这样。
results = ray.get([do_some_work.remote(x) for x in range(4)])
所以,需要小心使用ray.get()。因为它是一个锁进程的操作。如果太频繁调用ray.get(),将会影响并行性能。同时,尽可能的晚些使用ray.get()以防止不必要的等待。
ray.remote
]远程类和远程函数类似。我们在类的定义上面加上修饰器ray.remote。这个类的实例就会是一个Ray的actor。每一个actor运行在自己的python进程上。
@ray.remote
class Counter(object):
def __init__(self):
self.value = 0
def increment(self):
self.value += 1
return self.value
同样可以给actor设置资源请求量。
@ray.remote(num_cpus=2, num_gpus=0.5)
class Actor(object):
pass
在调用类的方法时加上.remote
,然后使用ray.get
获取实际的值。
obj_id = a1.increment.remote()
ray.get(obj_id) == 1
通过远程类,我们可以实现一个共享的参数服务器。Actor可以作为参数传给别的任务,下面的例子就是实现一个参数服务器。不同的参数就可以共用一个参数服务器了。
先定义一个ParameterServer
的类,上面写上ray的修饰器。
import numpy as np
import ray
@ray.remote
class ParameterServer(object):
def __init__(self, dim):
# Alternatively, params could be a dictionary mapping keys to arrays.
self.params = np.zeros(dim)
def get_params(self):
return self.params
def update_params(self, grad):
self.params += grad
主程序启动ray,创建ParameterServer
实例。
# We need to start Ray first.
ray.init()
# Create a parameter server process.
ps = ParameterServer.remote(10)
在worker中,可以通过传入的ps实例,调用ps的方法。
import time
# Note that the worker function takes a handle to the parameter server as an
# argument, which allows the worker task to invoke methods on the parameter
# server actor.
@ray.remote
def worker(ps):
for _ in range(100):
# Get the latest parameters.
params_id = ps.get_params.remote() # This method call is non-blocking
# and returns a future.
params = ray.get(params_id) # This is a blocking call which waits for
# the task to finish and gets the results.
# Compute a gradient update. Here we just make a fake update, but in
# practice this would use a library like TensorFlow and would also take
# in a batch of data.
grad = np.ones(10)
time.sleep(0.2) # This is a fake placeholder for some computation.
# Update the parameters.
ps.update_params.remote(grad)
创建多个worker并传入ps实例。
# Start 2 workers.
for _ in range(2):
worker.remote(ps)
当你的参数特别大,而且你的集群也很大,一个parameter server可能就不够了。特别是有很多worker的时候,因为很多worker向一个parameter server的数据传输就会成为瓶颈。
简单的解决办法就是把参数分散在多个parameter server上。可以通过创建多个actor来实现。
本节完。
参考资料:
https://ray.readthedocs.io/en/latest/walkthrough.html
https://rise.cs.berkeley.edu/blog/ray-tips-for-first-time-users/
https://ray-project.github.io/2018/07/15/parameter-server-in-fifteen-lines.html