前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >[源码解析] 深度学习分布式训练框架 horovod (15) --- 广播 & 通知

[源码解析] 深度学习分布式训练框架 horovod (15) --- 广播 & 通知

作者头像
罗西的思考
发布2021-07-20 11:30:53
1K0
发布2021-07-20 11:30:53
举报
文章被收录于专栏:罗西的思考罗西的思考

源码解析 深度学习分布式训练框架 horovod (15) --- 广播 & 通知

目录

0x00 摘要

Horovod 是Uber于2017年发布的一个易于使用的高性能的分布式训练框架,在业界得到了广泛应用。

本系列将通过源码分析来带领大家了解 Horovod。本文是系列第十五篇,看看horovod 弹性训练如何广播和发送通知。

本系列其他文章链接如下:

[源码解析] 深度学习分布式训练框架 Horovod (1) --- 基础知识

[源码解析] 深度学习分布式训练框架 horovod (2) --- 从使用者角度切入

[源码解析] 深度学习分布式训练框架 horovod (3) --- Horovodrun背后做了什么

[源码解析] 深度学习分布式训练框架 horovod (4) --- 网络基础 & Driver

[源码解析] 深度学习分布式训练框架 horovod (5) --- 融合框架

[源码解析] 深度学习分布式训练框架 horovod (6) --- 后台线程架构

[源码解析] 深度学习分布式训练框架 horovod (7) --- DistributedOptimizer

[源码解析] 深度学习分布式训练框架 horovod (8) --- on spark

[源码解析] 深度学习分布式训练框架 horovod (9) --- 启动 on spark

[源码解析] 深度学习分布式训练框架 horovod (10) --- run on spark

[源码解析] 深度学习分布式训练框架 horovod (11) --- on spark --- GLOO 方案

[源码解析] 深度学习分布式训练框架 horovod (12) --- 弹性训练总体架构

[源码解析] 深度学习分布式训练框架 horovod (13) --- 弹性训练之 Driver

[源码解析] 深度学习分布式训练框架 horovod (14) --- 弹性训练发现节点 & State

0x01 问题

首先,我们提出一个问题:为什么弹性训练 需要有广播?

答案就是:因为捕获两种异常之后,需要广播到各个worker。

1.1 HorovodInternalError

关于 HorovodInternalError 异常处理,我们看看具体容错机制,就可以知道缘由:

  • hvd.elastic.run 装饰器捕获异常;
  • 如果是 HorovodInternalError,就恢复到最近一次提交的状态,此时因为是allreduce等异常,所以所有worker都处于停止状态;
  • driver 会根据当前正在运行的节点重新执行一个 rendezvous,以便重新初始化 Horovod context;
  • 当新的通信域构造成功后,rank = 0 的 worker 会将自身的模型广播给其他 worker;
  • 所有worker接着上次停止的迭代步数继续训练;

因为需要从 rank 0 广播变量给其他进程,所以必须有一个广播机制。

1.2 HostsUpdateInterrupt

关于 HostsUpdateInterrupt 异常处理,我们看看具体原因。

  • 当驱动进程通过节点发现脚本发现一个节点被标记为新增或者移除时,它将发送一个通知到所有workers,在下一次 state.commit() 或者更轻量的 state.check_host_updates() 被调用时,会抛出一个 HostsUpdateInterrupt 异常。这个异常类似于 HorovodInternalError 异常,但是参数状态等不会从最近一次commit中恢复,而是从当前实时的参数中恢复。
  • check_host_updates 方法 会从 _host_messages 中读取消息,积累更新,如其方法中注释所述,会在每个 worker 之间同步状态,目的是让这些 worker 同时抛出异常。
  • 具体同步使用 _bcast_object(然后内部调用到了 MPI)。

需要一个广播机制在每个 worker 之间同步状态(因为这些worker目前都是在正常训练,需要有一个东西统一打断他们的训练,从而重新组建一个通信环),目的是让这些 worker 同时抛出 HostsUpdateInterrupt 异常。

其次,我们需要回顾下上文的流程图,本文将对其部分内部流程进行细化。

0x02 广播机制

我们具体剖析广播机制如下,因为广播是和具体框架密切结合,所以我们以tensorflow为例,具体代码在horovod/tensorflow/elastic.py 之中。

2.1 广播实现

在 horovod/tensorflow/elastic.py 之中,就是针对 TF 做的特定实现。其中会依据 TF 的版本做不同处理。

2.1.1 TensorFlowKerasState

以 TensorFlowKerasState 为例,在初始化的时候,因为有广播对象的需要,比如在 TensorFlowKerasState 之中配置了_bcast_model用来广播模型,bcast_object用来广播对象,broadcast_variables用来广播变量。

而且提供了sync函数负责广播,可以看出来调用了_bcast_model

代码语言:javascript
复制
class TensorFlowKerasState(ObjectState):
    def __init__(self, model, optimizer=None, backend=None, **kwargs):

        if not backend or _executing_eagerly():
            # 这里设定了广播函数
            self._bcast_model = lambda: _broadcast_model(self.model, self.optimizer, backend=self.backend)
            bcast_object = broadcast_object
        else:
            # For TensorFlow v1, we need to reuse the broadcast op to prevent incrementing the uids
            # 这里设定了广播函数
            bcast_op = broadcast_variables(_global_variables(), root_rank=0)
            self._bcast_model = lambda: self.backend.get_session().run(bcast_op)
            bcast_object = broadcast_object_fn(session=self.backend.get_session())
        
    def sync(self):
        self._bcast_model() #广播模型
        self._save_model()
        super(TensorFlowKerasState, self).sync()
2.1.2 广播模型

_broadcast_model 函数会 广播 模型变量,optimizer变量。

代码语言:javascript
复制
def _broadcast_model(model, optimizer, backend):
    if _executing_eagerly():
        # TensorFlow 2.0 or TensorFlow eager
        broadcast_variables(model.variables, root_rank=0) # 广播模型变量
        broadcast_variables(optimizer.variables(), root_rank=0) # 广播优化器变量
    else:
        bcast_op = broadcast_variables(_global_variables(), root_rank=0)
        backend.get_session().run(bcast_op)
2.1.3 广播变量

广播变量的具体实现 在 horovod/tensorflow/functions.py 之中。broadcast_variables 的作用是从 root rank(即 rank 0)广播变量到其他的进程。

具体也根据 TF 版本做了区别。

代码语言:javascript
复制
def _make_subgraph(f):
    return tf.function(f)

@_cache
def _make_broadcast_group_fn():
    if _executing_eagerly():
        # Eager mode will parallelize independent control flow
        def broadcast_group(variables, root_rank): # 在这里定义
            for var in variables:
                var.assign(broadcast(var, root_rank)) # 调用MPI函数,这里都指定了是root_rank

        return _make_subgraph(broadcast_group)
    else:
        # Graph mode requires an Op
        def broadcast_group(variables, root_rank): # 在这里定义
            # tf.group()用于创造一个操作,可以将传入参数的所有操作组合,当这个操作完成后,所有 input 中的所有 ops 都已完成。tf.group()操作没有输出。
            return tf.group(*[var.assign(broadcast(var, root_rank)) # 这里调用MPI函数
                              for var in variables])

        return broadcast_group

def broadcast_variables(variables, root_rank):
    """Broadcasts variables from root rank to all other processes.
    """
    broadcast_group = _make_broadcast_group_fn()
    return broadcast_group(variables, root_rank # 在上面定义
2.1.4 广播对象

广播对象 的作用是从 root rank(即 rank 0)广播对象到其他的进程。 广播对象和广播变量的区别是:对象需要序列化和反序列化。

代码语言:javascript
复制
def broadcast_object(obj, root_rank=0, session=None, name=None):
    """
    Serializes and broadcasts an object from root rank to all other processes.

    Arguments:
        obj: An object capable of being serialized without losing any context.
        root_rank: The rank of the process from which parameters will be
                   broadcasted to all other processes.
        session: Session for TensorFlow v1 compatibility.
        name: Optional name to use during broadcast, will default to the class
              type.
    Returns:
        The object that was broadcast from the `root_rank`.
    """
    if name is None:
        name = type(obj).__name__

    def to_numpy(v): # 依据tf版本不同做不同处理
        if not _executing_eagerly():
            sess = session or ops.get_default_session()
            return sess.run(v)
        else:
            return v.numpy()

    if rank() == root_rank:
        b = io.BytesIO() # BytesIO实现了在内存中读写bytes
        cloudpickle.dump(obj, b) # 序列化,编码成一个二进制文件
        t = tf.convert_to_tensor(bytearray(b.getvalue()), dtype=tf.uint8)
        sz = tf.convert_to_tensor([t.shape[0]], dtype=tf.int32) # 张量对应维度的数值
        to_numpy(broadcast(sz, root_rank, name + '.sz')) # 广播维度
    else:
        sz = tf.convert_to_tensor([0], dtype=tf.int32)
        sz = to_numpy(broadcast(sz, root_rank, name + '.sz')) # 接受维度
        t = tf.zeros(sz.tolist()[0], dtype=tf.uint8)

    t = to_numpy(broadcast(t, root_rank, name + '.t')) # 广播对象内容

    if rank() != root_rank:
        buf = io.BytesIO(t.tobytes())
        obj = cloudpickle.load(buf) # 反序列化,解码成原本的对象

    return obj
2.1.5 HVD C++

底层会调用到 MPI 函数完成广播功能。

代码语言:javascript
复制
def broadcast(tensor, root_rank, name=None, ignore_name_scope=False):
    """An op which broadcasts the input tensor on root rank to the same input tensor
    on all other Horovod processes.

    The broadcast operation is keyed by the name of the op. The tensor type and
    shape must be the same on all Horovod processes for a given name. The broadcast
    will not start until all processes are ready to send and receive the tensor.

    Returns:
      A tensor of the same shape and type as `tensor`, with the value broadcasted
      from root rank.
    """
    if name is None and not _executing_eagerly():
        name = 'HorovodBroadcast_%s' % _normalize_name(tensor.name)
    return MPI_LIB.horovod_broadcast(tensor, name=name, root_rank=root_rank,
                                     ignore_name_scope=ignore_name_scope)
2.1.6 MPI

MPI_BCAST的作用是:从一个序列号为root的进程将一条消息广播发送到组内的所有进程, 包括它本身在内。

因为之前指定了root_rank,所以即使所有worker虽然都调用了同样代码,也只是会把 root_rank 通信消息缓冲区中的消息拷贝到其他所有进程中去。

代码语言:javascript
复制
void MPIController::Bcast(void* buffer, size_t size, int root_rank,
                          Communicator communicator) {
  MPI_Comm comm = mpi_ctx_.GetMPICommunicator(communicator);
  int ret_code = MPI_Bcast(buffer, size, MPI_BYTE, root_rank, comm);
  if (ret_code != MPI_SUCCESS) {
    throw std::runtime_error(
        "MPI_Broadcast failed, see MPI output for details.");
  }
}
2.1.7 小结

我们总结一下各个函数:

  • _bcast_model用来广播模型;
  • bcast_object用来广播对象;
  • broadcast_variables用来广播变量;
  • 广播对象和广播变量的区别是:对象需要序列化和反序列化。
  • _broadcast_model 就是调用了broadcast_variables完成对模型参数的广播;
  • broadcast_variables 中调用了broadcast_groupbroadcast_group主要就是利用tf.group()把广播操作组合起来;

2.2 使用

2.2.1 HorovodInternalError

当捕获HorovodInternalError时候,会进行广播同步,目的是当新的通信域构造成功后,rank = 0 的 worker 会将自身的模型广播给其他 worker。

代码语言:javascript
复制
def run_fn(func, reset):
    @functools.wraps(func)
    def wrapper(state, *args, **kwargs):
        notification_manager.init()
        notification_manager.register_listener(state)
        skip_sync = False

        try:
            while True:
                if not skip_sync:
                    state.sync() # 这里会进行广播同步,就是TensorFlowKerasState.sync

                try:
                    return func(state, *args, **kwargs)
                except HorovodInternalError:
                    state.restore() # 捕获一场,然后继续while循环
                    skip_sync = False
                except HostsUpdatedInterrupt as e:
                    skip_sync = e.skip_sync

                reset()
                state.on_reset()
        finally:
            notification_manager.remove_listener(state)
    return wrapper

具体如下:

代码语言:javascript
复制
  Worker rank 0                               Worker rank n
        +                                         +
        |                                         |
        |                                         |
        |                                         |
        v                                         |
 Catch HorovodInternalError                       |
        +                                         |
        |                                         |
        |                                         |
        |                                         |
       sync                                       |
        |                                         |
        |                                         |    
        v                                         |
_broadcast_model(model)                           |
        +                                         |
        |                                         |
        |                                         |
        |                                         |
        v                                         |
 broadcast_variables(model.variables)             |
                                                  |
 broadcast_variables(optimizer.variables)         |
                                                  |
        +                                         |
        |                                         |
        |                                         |
        |                                         |
        v                                         |
  broadcast_group                                 |
        +                                         |
        |                                         |
        |                                         |
        |                                         |
        v                                         |
 MPI_LIB.horovod_broadcast  +-------------------> |
        +                                         |
        |                                         |
        |                                         |
        v                                         v
2.2.2 HostsUpdateInterrupt

广播对象作用是 在每个 worker 之间同步状态,目的是让这些 worker 同时抛出 HostsUpdateInterrupt 异常。

具体如何使用?

WorkerNotificationService . _handle 方法之中,调用了 self._manager.handle_hosts_updated(req.timestamp, req.res) 进行通知更新。

WorkerNotificationManager. handle_hosts_updated 方法之中,会调用注册的 state,逐一通知更新。

代码语言:javascript
复制
def handle_hosts_updated(self, timestamp, update_res):
    for listener in self._listeners:
        listener.on_hosts_updated(timestamp, update_res)

是在 State 的几个方法中可以看到。

  • on_hosts_updated :当有 host 变化时候调用,即 向 _host_messages 这个 queue 放入一个消息;
  • commit :用户会定期调用此函数,会存储状态,检查 host 更改;
  • check_host_updates : 会从 _host_messages 中读取消息,积累更新,如方法中注释所述,会在每个 worker 之间同步状态,目的是让这些 worker 同时抛出异常。具体同步使用 _bcast_object

check_host_updates代码如下:

代码语言:javascript
复制
def check_host_updates(self):
    """Checks that a notification has been sent indicating that hosts can be added or will be removed.

    Raises a `HostsUpdatedInterrupt` if such a notification has been received.
    """
    # Iterate through the update messages sent from the server. If the update timestamp
    # is greater than the last update timestamp, then trigger a HostsUpdatedException.
    last_updated_timestamp = prev_timestamp = self._last_updated_timestamp
    all_update = HostUpdateResult.no_update
    while not self._host_messages.empty():
        timestamp, update = self._host_messages.get()
        if timestamp > last_updated_timestamp:
            last_updated_timestamp = timestamp
            all_update |= update

    # In order to ensure all workers raise the exception at the same time, we need to sync
    # the updated state across all the workers.
    # TODO(travis): this should be a max allreduce to account for changes in rank 0
    # 这里会广播
    prev_timestamp, self._last_updated_timestamp, all_update = \
        self._bcast_object((prev_timestamp, last_updated_timestamp, all_update))

    # At this point, updated state is globally consistent across all ranks.
    if self._last_updated_timestamp > prev_timestamp:
        raise HostsUpdatedInterrupt(all_update == HostUpdateResult.removed)

具体如下:

代码语言:javascript
复制
+---------------------------+      +--------------+            +-------------+
|Catch HostsUpdatedInterrupt|      | Worker rank 1|            |Worker rank n|
+---------+-----------------+      +-------+------+            +----+--------+
          |                                |                        |
          |                                |                        |
          |                                |                        |
          v                                |                        |
                                           |                        |
 WorkerNotificationService                 |                        |
          +                                |                        |
          |                                |                        |
          |                                |                        |
          |                                |                        |
          v                                |                        |
                                           |                        |
manager.handle_hosts_updated+------------> |                        |
                                           |                        |
                                           |                        |
                                           v                        |
                                                                    |
                                   on_hosts_updated                 |
                                           +                        |
                                           |                        |
                                           |                        |
                                           |                        |
                                   check_host_updates               |
                                           |                        |
                                           |                        |
                                           |                        |
                                           |                        |
                                           v                        |
                                                                    |
                                   broadcast_object                 |
                                           +                        |
                                           |                        |
                                           |                        |
                                           |                        |
                                           |                        |
                                           v                        |
                                                                    |
                                   MPI_LIB.horovod_broadcast +----> |
                                           +                        |
                                           |                        |
                                           |                        |
                                           v                        v

0x03 通知机制

上图中用到 manager.handle_hosts_updated,manager就是 WorkerNotificationManager。

所以我们顺着讨论下 WorkerNotificationManager,这是 Hovorod 的通知机制。

3.1 WorkerNotificationManager 生成

每个host 只有一个 WorkerNotificationManager,也只有一个 WorkerNotificationService。

注意:是 ElasticDriver 会作为client,给这些 WorkerNotificationService 发消息,从而引起 WorkerNotificationManager 的对应操作。

horovod/common/elastic.py 有如下代码完成了实例生成。

代码语言:javascript
复制
notification_manager = WorkerNotificationManager()

WorkerNotificationManager 定义如下:

代码语言:javascript
复制
class WorkerNotificationManager(object):
    def __init__(self):
        self._lock = threading.Lock()
        self._service = WorkerNotificationService(secret_key, nic, self)
        self._listeners = set()

3.2 初始化

在用户代码启动之前,会先初始化 WorkerNotificationManager。

代码语言:javascript
复制
 def run_fn(func, reset):
    @functools.wraps(func)
    def wrapper(state, *args, **kwargs):
        # 初始化 WorkerNotificationManager
        notification_manager.init()
        # 把自己对应的 state 注册到 notification_manager
        notification_manager.register_listener(state)        

WorkerNotificationManager初始化代码如下,其逻辑是:

  • 如果 _service 已经生成,则直接返回,这就保证了每个host之中只有一个 WorkerNotificationService。
  • 从系统变量中得到 rendezvous 的各种信息,比如地址,端口,key 等等;
  • 生成 WorkerNotificationService,赋值给 _service;
  • 使用 put_data_into_kvstore 把 本 worker 的地址 和 给其在逻辑通信环分配的序号 rank 发送给 rendezvous(这个为了后续生成 WorkerNotificationClient 使用)。
  • 备注:这个 rendezvous 会存储每个 worker 的地址和给其在逻辑通信环分配的序号 rank。worker 进程可以通过这个 rendezvous 来构造新的通信域。
代码语言:javascript
复制
def init(self, rendezvous_addr=None, rendezvous_port=None,
         nic=None, hostname=None, local_rank=None):
    with self._lock:
        if self._service:
            return

        # 从系统变量中得到 rendezvous 的各种信息,比如地址,端口,key 等等  
        rendezvous_addr = rendezvous_addr or os.environ.get(HOROVOD_GLOO_RENDEZVOUS_ADDR)
        rendezvous_port = rendezvous_port if rendezvous_port is not None else \
            int(os.environ.get(HOROVOD_GLOO_RENDEZVOUS_PORT))
        nic = nic or os.environ.get(HOROVOD_GLOO_IFACE)
        hostname = hostname or os.environ.get(HOROVOD_HOSTNAME)
        local_rank = local_rank if local_rank is not None else \
            int(os.environ.get(HOROVOD_LOCAL_RANK))

        secret_key = secret.make_secret_key()
        self._service = WorkerNotificationService(secret_key, nic, self)

        value = (self._service.addresses(), secret_key)
        # 把本worker的地址 和 给其在逻辑通信环分配的序号 rank 发送给 rendezvous
        put_data_into_kvstore(rendezvous_addr,
                              rendezvous_port,
                              PUT_WORKER_ADDRESSES,
                              self._create_id(hostname, local_rank),
                              value)

具体 put_data_into_kvstore 如下。

代码语言:javascript
复制
def put_data_into_kvstore(addr, port, scope, key, value):
    try:
        url = "http://{addr}:{port}/{scope}/{key}".format(
            addr=addr, port=str(port), scope=scope, key=key
        )
        req = Request(url, data=codec.dumps_base64(value, to_ascii=False))
        req.get_method = lambda: "PUT"  # for urllib2 compatibility
        urlopen(req)
    except (HTTPError, URLError) as e:
        raise RuntimeError("Put data input KVStore server failed.", e)

3.3 注册State

用户代码启动之前,还会把自己对应的 state 注册到 notification_manager。

代码语言:javascript
复制
def run_fn(func, reset):
    @functools.wraps(func)
    def wrapper(state, *args, **kwargs):
        # 初始化 WorkerNotificationManager
        notification_manager.init()
        # 把自己对应的 state 注册到 notification_manager
        notification_manager.register_listener(state)

具体代码如下:

代码语言:javascript
复制
def register_listener(self, listener):
    self._listeners.add(listener)

def remove_listener(self, listener):
    self._listeners.remove(listener)

3.4 WorkerNotificationService

WorkerNotificationService 在每个host之中也只有一个,用来接受其 client 发来的 HostsUpdatedRequest 消息,进行处理。可以看到,其继承了 network.BasicService,这意味着 WorkerNotificationService 本身是一个http server,可以和其client交互,大家可以想想之前介绍的各种 driver / client,就可以理解其机制了。

代码语言:javascript
复制
class WorkerNotificationService(network.BasicService):
    NAME = 'worker notification service'

    def __init__(self, key, nic, manager):
        super(WorkerNotificationService, self).__init__(WorkerNotificationService.NAME,
                                                        key,
                                                        nic)
        self._manager = manager

    def _handle(self, req, client_address):
        if isinstance(req, HostsUpdatedRequest):
            self._manager.handle_hosts_updated(req.timestamp, req.res)
            return network.AckResponse()

        return super(WorkerNotificationService, self)._handle(req, client_address)

逻辑如下:

代码语言:javascript
复制
 +-------------------------------+                          +---------------------------+
 | WorkerNotificationManager     |                          | rendezvous                |
 |                               +------------------------> |                           |
 |                               |  put_data_into_kvstore   |                           |
 |                               |                          |                           |
 |                               |                          +---------------------------+
 | _listeners                    |
 |      +                        |                          +---------------------------+
 |      |         _service  +-----------------------------> | WorkerNotificationService |
 |      |                        |                          |                           |
 +-----------------------+-------+                          |                           |
        |                ^                                  |                           |
        |                |                                  |                           |
        |                |                                  |                           |
        |                +----------------------------------------+ _manager            |
        |                                                   |                           |
        v                                                   |                           |
                                                            +---------------------------+
[State 1, State 2, ......, State n]

3.5 WorkerNotificationClient

WorkerNotificationClient 就是用来给 WorkerNotificationService 发送消息的接口。

ElasticDriver 中,会针对每个 worker 生成一个对应的 WorkerNotificationClient,用来进行通知。

代码语言:javascript
复制
class WorkerNotificationClient(network.BasicClient):
    def __init__(self, addresses, key, verbose, match_intf=False):
        super(WorkerNotificationClient, self).__init__(WorkerNotificationService.NAME,
                                                       addresses,
                                                       key,
                                                       verbose,
                                                       match_intf=match_intf)

    def notify_hosts_updated(self, timestamp, update_res):
        self._send(HostsUpdatedRequest(timestamp, update_res))

3.6 生成 Client

3.6.1 注册时机

回顾一下,在 WorkerNotificationManager 的初始化函数 init 中,会给 rendezvous 发送put 请求,进行注册。

注册信息就是为了 生成client。

代码语言:javascript
复制
put_data_into_kvstore(rendezvous_addr,
                      rendezvous_port,
                      PUT_WORKER_ADDRESSES,
                      self._create_id(hostname, local_rank),
                      value)
3.6.2 注册 worker

在 ElasticRendezvousHandler 中有 _put_value,用来处理 PUT_WORKER_ADDRESSES。调用 driver 处理。

代码语言:javascript
复制
# 注意,这里在 Rendezvous Server 之内
def _put_value(self, scope, key, value):
    if scope == PUT_WORKER_ADDRESSES:
        host, local_rank = key.split(':')
        addresses, secret_key = codec.loads_base64(value)
        self._put_worker_addresses(host, int(local_rank), addresses, secret_key)

    super(RendezvousHandler, self)._put_value(scope, key, value)

def _put_worker_addresses(self, host, local_rank, addresses, secret_key):
    # 这里调用driver进行处理
    driver.register_worker_server(host, local_rank, addresses, secret_key)
3.6.3 生成 WorkerNotificationClient

ElasticDriver 中,会针对每个 worker 生成一个对应的 WorkerNotificationClient,用来进行通知。

这里需要注意:ElasticDriver 就是 WorkerNotificationClient 的使用者,需要通知各个worker时候,就调用这些WorkerNotificationClient ,给对应host上的 WorkerNotificationService发消息,从而引起WorkerNotificationManager做相应处理。

代码语言:javascript
复制
# 这里是 ElasticDriver 之中
def register_worker_server(self, host, slot, addresses, secret_key):
    self._worker_clients[(host, slot)] = WorkerNotificationClient(
        addresses, secret_key, self._verbose)

逻辑如下:

代码语言:javascript
复制
 +-------------------------------+
 | WorkerNotificationManager     |                          +---------------------------+      +----------------------------+
 |                               |                          | rendezvous                |      | ElasticRendezvousHandler   |
 |                 init  +--------------------------------> |                        +-------> |                            |
 |                               |  1 put_data_into_kvstore |                           |      |                            |
 |                               |                          |                           |      |                            |
 |                               |                          +---------------------------+      +------------------+---------+
 | _listeners                    |                                                                                |
 |      +                        |                          +---------------------------+                         |
 |      |         _service  +-----------------------------> | WorkerNotificationService |                         |
 |      |                        |                          |                           |                         |
 +-----------------------+-------+                          |                           |                         |
        |                ^                                  |                           |                         |
        |                |                                  |                           |                         |
        |                |                                  |                           |                         |
        |                +----------------------------------------+ _manager            |                         |
        |                                                   |                           |                         |
        v                                                   |                           |                         |
                                                            +---------------------------+                         |
[State 1, State 2, ......, State n]                                                                               |
                                                                                                                  |
                      +-------------------------------------------------------------------------------------------+
                      |                             2 register_worker_server
                      |
                      |
                      v
                                                        3 new instance
 +-------------------------------+
 |ElasticDriver                  |             +----------------------------+     +---------------------------+
 |                               |             | WorkerNotificationClient 1 |     |WorkerNotificationClient n |
 |                               |             |                            |     |                           |
 |                               |             |                            |     |                           |
 |         _worker_clients  +--------------->  |     (host 1, slot 1)       | ... |     (host n, slot n)      |
 |                               |             |      For worker 1          |     |        For worker n       |
 |                               |             |                            |     |                           |
 +-------------------------------+             +----------------------------+     +---------------------------+

手机如图:

3.7 使用

3.7.1 发现更新

ElasticDriver._discovery_thread 之中 如果发现有 host 变化,则调用 self._notify_workers_host_changes 来通知。

代码语言:javascript
复制
def _notify_workers_host_changes(self, current_hosts, update_res):
    next_host_assignments = {}
    if current_hosts.count_available_slots() >= self._min_np:
        # Assignments are required to be stable via contract
        next_host_assignments, _ = self._get_host_assignments(current_hosts)

    if next_host_assignments == self.host_assignments:
        # Skip notifying workers when host changes would not result in changes of host assignments
        return

    coordinator_slot_info = self.get_coordinator_info()
    coordinator_client = self.get_worker_client(coordinator_slot_info)

    timestamp = _epoch_time_s()
    coordinator_client.notify_hosts_updated(timestamp, update_res)
3.7.2 获取 client

get_worker_client 函数就是获取 WorkerNotificationClient。就是依据 host,slot 信息来找到某一个 worker 对应的 client。

代码语言:javascript
复制
def get_worker_client(self, slot_info):
    return self._worker_clients.get((slot_info.hostname, slot_info.local_rank))
3.7.3 发送HostsUpdatedRequest

notify_hosts_updated 的作用是发送HostsUpdatedRequest

代码语言:javascript
复制
class WorkerNotificationClient(network.BasicClient):
    def __init__(self, addresses, key, verbose, match_intf=False):
        super(WorkerNotificationClient, self).__init__(WorkerNotificationService.NAME,
                                                       addresses,
                                                       key,
                                                       verbose,
                                                       match_intf=match_intf)

    def notify_hosts_updated(self, timestamp, update_res):
        self._send(HostsUpdatedRequest(timestamp, update_res))
3.7.4 处理 HostsUpdatedRequest

WorkerNotificationService 之中会处理HostsUpdatedRequest,调用 WorkerNotificationManager处理。

代码语言:javascript
复制
class WorkerNotificationService(network.BasicService):
    NAME = 'worker notification service'

    def __init__(self, key, nic, manager):
        super(WorkerNotificationService, self).__init__(WorkerNotificationService.NAME,
                                                        key,
                                                        nic)
        self._manager = manager

    def _handle(self, req, client_address):
        if isinstance(req, HostsUpdatedRequest):
            self._manager.handle_hosts_updated(req.timestamp, req.res)
            return network.AckResponse()

        return super(WorkerNotificationService, self)._handle(req, client_address)
3.7.5 WorkerNotificationManager

所以,当有host 更新时候,WorkerNotificationManager 中的 handle_hosts_updated 如下,最终调用到 state 的 on_hosts_updated。

代码语言:javascript
复制
def handle_hosts_updated(self, timestamp, update_res):
    for listener in self._listeners: # 遍历state
        listener.on_hosts_updated(timestamp, update_res)

State 的实现如下:

代码语言:javascript
复制
def on_hosts_updated(self, timestamp, update_res):
    self._host_messages.put((timestamp, update_res))

逻辑如下图:

代码语言:javascript
复制
                                                         +-----------------------------v
                                                         ^        thread loop          |
                                                         |                             |
                                        +----------------+----------------------+      |
                                        |  ElasticDriver._discovery_thread      |      |
       1 _notify_workers_host_changes   |                                       |      |
                                        |                                       |      |
                     +------------------+                                       |      |
                     |                  |                                       |      |
                     |                  |   HostManager.update_available_hosts  |      |
                     |                  |                                       |      |
                     |                  +-----------------+---------------------+      |
                     |                                    ^                            |
                     |                                    |                            |
                     |                                    |                            |
                     |                                    +----------<---------------+ v
                     v

+---------------------------+ 2 HostsUpdatedRequest  +----------------------------+ handle_hosts_updated +----------------------------+
|                           |                        |                            |                      |                            |
| WorkerNotificationClient  +----------------------> |  WorkerNotificationService | +------------------> |  WorkerNotificationManager |
|                           |                        |                            |                      |                            |
+---------------------------+                        +----------------------------+                      +------+---------------------+
                                                                                                                |
                                                                                                                |
                                                                                                                | on_hosts_updated
                                                                                                                |
                                                                                                                v
                                                                                                  +-----------------------+
                                                                                                  |  State      |         |
                                                                                                  |             | put     |
                                                                                                  |             v         |
                                                                                                  |     _host_messages    |
                                                                                                  +-----------------------+

手机如下:

3.7.6 处理更新

在用户调用 commit 的时候,才会调用 check_host_updates 检查更新。

代码语言:javascript
复制
def commit(self):
    self.save()
    self.check_host_updates()

检查更新就是看看 _host_messages 有没有新的消息,如果发现 host 有变化,就会产生一个 HostsUpdatedInterrupt 异常。

代码语言:javascript
复制
def check_host_updates(self):
    # Iterate through the update messages sent from the server. If the update timestamp
    # is greater than the last update timestamp, then trigger a HostsUpdatedException.
    last_updated_timestamp = prev_timestamp = self._last_updated_timestamp
    all_update = HostUpdateResult.no_update
    while not self._host_messages.empty():
        timestamp, update = self._host_messages.get()
        if timestamp > last_updated_timestamp:
            last_updated_timestamp = timestamp
            all_update |= update

    # In order to ensure all workers raise the exception at the same time, we need to sync
    # the updated state across all the workers.
    # TODO(travis): this should be a max allreduce to account for changes in rank 0
    prev_timestamp, self._last_updated_timestamp, all_update = \
        self._bcast_object((prev_timestamp, last_updated_timestamp, all_update))

    # At this point, updated state is globally consistent across all ranks.
    if self._last_updated_timestamp > prev_timestamp:
        raise HostsUpdatedInterrupt(all_update == HostUpdateResult.removed)

在worker进程出现 HorvodInternalError 错误或者 HostsUpdatedInterrupt 节点增删时,会捕获这两个错误,调用 reset 来进行容错处理。于是就把流程前后串联了起来。

具体如下:

至此,广播通知机制我们整理完毕,下一篇介绍 worker 如何运作。

0xEE 个人信息

本文参与 腾讯云自媒体分享计划,分享自作者个人站点/博客。
原始发表:2021-07-16 ,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 源码解析 深度学习分布式训练框架 horovod (15) --- 广播 & 通知
    • 0x00 摘要
      • 0x01 问题
        • 1.1 HorovodInternalError
        • 1.2 HostsUpdateInterrupt
      • 0x02 广播机制
        • 2.1 广播实现
        • 2.2 使用
      • 0x03 通知机制
        • 3.1 WorkerNotificationManager 生成
        • 3.2 初始化
        • 3.3 注册State
        • 3.4 WorkerNotificationService
        • 3.5 WorkerNotificationClient
        • 3.6 生成 Client
        • 3.7 使用
      • 0xEE 个人信息
      领券
      问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档