前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >使用 Ray 用 15 行 Python 代码实现一个参数服务器

使用 Ray 用 15 行 Python 代码实现一个参数服务器

作者头像
用户1107453
发布2018-09-29 15:51:57
1.7K0
发布2018-09-29 15:51:57
举报
文章被收录于专栏:UAI人工智能

使用 Ray 用 15 行 Python 代码实现一个参数服务器

参数服务器是很多机器学习应用的核心部分。其核心作用是存放机器学习模型的参数(如,神经网络的权重)和提供服务将参数传给客户端(客户端通常是处理数据和计算参数更新的 workers)

参数服务器(如同数据库)是正常构建并 shipped 像一个单一系统。这个文章讲解如何使用 Ray 来用几行代码实现参数服务器。

通过将参数服务器从一个“系统”调整为一个“应用”,这个方法将量级的 orders 变得更加简单来部署一个参数服务器应用。类似地,通过让应用和库实现自身的参数服务器,这个方法让参数服务器的行为更加可配置和灵活(因为这个应用可以轻松地修改实现)

什么是 Ray? Ray 是一个用于并行和分布式的通用框架。Ray 提供了一个统一的任务并行和actor抽象,并且通过共享内存、零复制序列化和分布式调度达到了高的性能。Ray 也包含了针对人工智能应用(如超参数调优和强化学习)的高性能库。

什么是一个参数服务器?

一个参数服务器是一个用来在集群上训练机器学习模型的键值对。其值(values)是机器学习模型的参数(如一个神经网络)。其键(keys)索引了模型参数。

例如,在一个电影的推荐系统中,可能会针对每个用户、每个电影都有相应的键。对每个用户和电影,有对应的以用户特属和以电影特属的参数。在语言建模的应用中,词可能会作为键而其嵌入则可能为值。在最简单的形式中,参数服务器可能会隐式地有一个单个键,允许你所有的参数被获取并一次性更新。我们展示了如何作为一个 Ray 的 actor 实现一个参数服务器。

代码语言:javascript
复制
import numpy as npimport ray@ray.remoteclass ParameterServer(object):
  def __init__(self, dim):
    # params 可以是一个将键映射到数组的字典
    self.params = np.zeros(dim)  def get_params(self):
    return self.params    
  def update_params(self, grad):
    self.params += grad

@ray.remote 装饰器定义了一个服务。以类 ParameterServer 为‘输入’并使之作为一个远程服务或者 actor 被实例化。

这里,我们假设更新是一个梯度,这个被加到参数的向量上。这仅仅是最简单可能例子,可以有很多不同的选择。

参数服务器一般作为远程进程或者服务存在 并通过远程过程调用来和客户端交互。为了实例化参数服务器为一个远程 actor,我们可以这样:

代码语言:javascript
复制
ray.init()

ps = ParameterServer.remote(10)

Actor 方法调用返回 futures。如果我们想要检索实际值,我们可以使用一个 blocking 的 ray.get 调用,如:

代码语言:javascript
复制
>>> ps = ParameterServer.remote(10)>>> params_id = ps.get_params.remote()>>> params_id
ObjectID(4e9c8ac9a6d3dbf20c625f8d36c93beb07ca45d0)>>> ray.get(params_id)
array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])

现在,假设我们想要启动某些 worker 任务连续地计算梯度和更新模型的参数。每个 worker 将会循环地执行下面任务:

  1. 获取最新的参数
  2. 计算对参数的一个更新
  3. 更新参数

作为一个 Ray 远程函数(尽管 worker 也可以被看做一个 actor),如下:

代码语言:javascript
复制
import time# 注意 worker 函数获取参数服务器作为参数,使得 worker 任务激活参数服务器 actor 的方法@ray.remotedef worker(ps):
  for _ in range(100):
    params_id = ps.get_params.remote() # 这个方法调用是非阻塞的,返回一个 future
    params = ray.get(params_id) # 这是一个阻塞的调用,等待任务完成,并获取结果
    
    # 计算梯度更新。这里我们仅做一个假的更新,但在实际环境中,这里会使用一个库,如 tensorflow,也会获取一个批量的数据为输入
    grad = np.ones(10)
    time.sleep(0.2) # 这个是一个伪造的作为计算的占位符
    
    # 更新参数
    ps.update_params.remote(grad)

然后我们可以启动几个 worker 任务:

代码语言:javascript
复制
for _ in range(2):
  worker.remote(ps)

接着我们可以从驱动进程中检索到参数,并看到他们由 workers 进行更新

代码语言:javascript
复制
>>> ray.get(ps.get_params.remote())
array([164., 164., 164., 164., 164., 164., 164., 164., 164., 164.])>>> ray.get(ps.get_params.remote())
array([198., 198., 198., 198., 198., 198., 198., 198., 198., 198.])>>> >>> ray.get(ps.get_params.remote())
array([200., 200., 200., 200., 200., 200., 200., 200., 200., 200.])>>> ray.get(ps.get_params.remote())
array([200., 200., 200., 200., 200., 200., 200., 200., 200., 200.])>>> ray.get(ps.get_params.remote())
array([200., 200., 200., 200., 200., 200., 200., 200., 200., 200.])

Ray 这里加上的值一部分原因是 Ray 让其变得简单来启动一个远程服务或者 actor 因为这是定义了一个 Python 类。actor 的 Handles 可以被传递给其他的 actors 和任务,来保证可以进行任意和直觉的消息传递和通信模式。目前的替代物更多。例如,考虑等价运行时刻服务创建和用 GRPC 来进行 handle 的传递。

扩展

这里我们给出一些设计上的重要变化。我们描述了额外的自然扩展。

多参数服务器的分片 sharding 当你的参数很大和集群很大时,单个参数武器可能不能满足要求,因为应用会被网络带宽限制,进入和流出参数服务器所在的机器(特别是有很多的 workers 时候)

一个自然的解决方法是对多参数服务器上的参数进行分片。这个可以被简单地开启多个参数服务器 actors 达成。例如我们底下给出的代码那样。

控制 actor 放置 特定 actors 和任务在不同机器上的放置可以使用 Ray 对任意的资源需求支持指定。例如,如果 worker 需要一个 GPU,那么它的远程装饰器可以被声明为 @ray.remote(num_gpus=1)。任意定制资源可以同样定义。

统一任务和 actors

Ray 支持参数服务器应用高效大部分原因是其统一的任务并行和 actor 抽象。

流行的数据处理系统如 Apache Spark,可以有无状态的任务(没有 side effects 的函数)在不可变动的数据上操作。这个假设简化了整体系统的设计,让验证正确性变得简单。

但是,可变状态在很多的任务中存在,机器学习领域中反复出现。状态可能是一个神经网络的权重,第三方模拟器的状态,或者物理世界的交互的封装。

为了支持这些类型的应用,Ray 引入了 actor 抽象。一个 actor 会序列化地执行方法(使得没有并发的问题),每个任务可以任意地改变 actor 的内部状态。方法可以有其他的 actors 和任务激活(甚至由在同样的集群上的其他应用)

让 Ray 变得很强大的一点是它统一了 actor 抽象和任务并行抽像,继承了两者的优点。Ray 使用了底层的动态任务图在同样的框架中来实现 actors 和无状态任务。所以,这两个抽象其实完全整合在一起。任务和 actors 可以从其他任务和 actors 中进行创建。两者返回的future可以被传递给其他的任务或者 actor 方法来引入调度和数据依赖。所以,Ray 应用进程了这两个的好的特性。

底层基础

动态任务图 在底层,远程函数激活和 actor 方法激活创建了任务被加入到一个动态增长的任务图上。Ray 的后端管理调度和在集群上执行这些任务(或者在一个单机多核机器上)。任务可以被 driver 应用或者其他任务创建。

数据 Ray 使用 Apache Arrow data layout 来高效地序列化数据。对象在 workers 和 actors 之间通过共享内存在同样的机器上进行共享,这就避免了复制和去序列化的需要。这样的优化绝对是达到好的性能的关键。

调度 Ray 使用了一个分布式调度方法。每个机器有其自身的调度器,这个东西管理这台机器上的 workers 和 actors。任务被应用和 workers 提交给同一机器上的调度器。这让 Ray 达成比一个中心化的调度器达到的更高的任务吞吐量,这对机器学习应用非常重要。

总结

参数服务器通常是做一个单一系统实现和 shipped。让这个方法很强大的是我们能够用少量代码实现参数服务器为一个应用。这个方法让部署使用参数服务器的应用和修改参数服务器的行为更加简单。例如,如果我们希望对参数服务器进行分片,改变更新规则,在同步和异步更新之间切换,或略 straggler workers,或者任何其他的定制,我们可以用少量的代码达成。

这个文章描述了如何使用 Ray 的 actors 来实现参数服务器。然而,actors 是更加通用的概念,可以用来进行很多包含状态计算的应用。logging,streaming,simulation,model serving, graph processing,和其他应用。

运行代码

为了运行完整的应用,首先安装 Ray pip install ray。然后能运行下面的代码,这段代码实现了一个共享的参数服务器。

代码语言:javascript
复制
import numpy as npimport rayimport time# Start Ray.ray.init()@ray.remoteclass ParameterServer(object):
    def __init__(self, dim):
        # Alternatively, params could be a dictionary mapping keys to arrays.
        self.params = np.zeros(dim)    def get_params(self):
        return self.params    def update_params(self, grad):
        self.params += grad@ray.remotedef worker(*parameter_servers):
    for _ in range(100):        # Get the latest parameters.
        parameter_shards = ray.get(
          [ps.get_params.remote() for ps in parameter_servers])
        params = np.concatenate(parameter_shards)        # Compute a gradient update. Here we just make a fake
        # update, but in practice this would use a library like
        # TensorFlow and would also take in a batch of data.
        grad = np.ones(10)
        time.sleep(0.2)  # This is a fake placeholder for some computation.
        grad_shards = np.split(grad, len(parameter_servers))        # Send the gradient updates to the parameter servers.
        for ps, grad in zip(parameter_servers, grad_shards):
            ps.update_params.remote(grad)# Start two parameter servers, each with half of the parameters.parameter_servers = [ParameterServer.remote(5) for _ in range(2)]# Start 2 workers.workers = [worker.remote(*parameter_servers) for _ in range(2)]# Inspect the parameters at regular intervals.for _ in range(5):
    time.sleep(1)
    print(ray.get([ps.get_params.remote() for ps in parameter_servers]))
本文参与 腾讯云自媒体同步曝光计划,分享自微信公众号。
原始发表:2018-08-19,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 UAI人工智能 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 使用 Ray 用 15 行 Python 代码实现一个参数服务器
    • 什么是一个参数服务器?
      • 扩展
        • 统一任务和 actors
          • 底层基础
            • 总结
              • 运行代码
              相关产品与服务
              数据库
              云数据库为企业提供了完善的关系型数据库、非关系型数据库、分析型数据库和数据库生态工具。您可以通过产品选择和组合搭建,轻松实现高可靠、高可用性、高性能等数据库需求。云数据库服务也可大幅减少您的运维工作量,更专注于业务发展,让企业一站式享受数据上云及分布式架构的技术红利!
              领券
              问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档