前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >PyTorch & MMCV Dispatcher 机制解析

PyTorch & MMCV Dispatcher 机制解析

作者头像
OpenMMLab 官方账号
发布2022-01-18 09:56:07
1.1K0
发布2022-01-18 09:56:07
举报
文章被收录于专栏:OpenMMLab

最受欢迎的 MMCV 教程系列如约而至!

话不多说,进入正题~

本文内容

什么是 Dispatcher

为什么需要 Dispatcher

注册 + 分发

PyTorch Dispatcher

MMCV Dispatcher

1. 什么是 Dispatcher

假设一个团队有一个项目经理和三个程序员,甲方正在疯狂地提各种需求,然后项目经理要做的就是根据每位程序员的专长,将不同的需求分配给不同的程序员来做,但是项目经理自己不会去实现需求,此时我们可以说,项目经理就是一个 Dispatcher

Dispatcher,中文译为分派器,它的工作可以简单理解为只负责任务的分发,而不会参与到任务本身中来。但思考一下,如果这个团队的规模扩大,现在有三十位程序员,那这位项目经理可能就不得不加班工作了。这说明上述的 Dispatcher 方案不够优雅,那么有没有更优雅一点的方案呢?我们会在第三节注册 + 分发部分给大家介绍一种思路。

PyTorch 和 MMCV 的 Dispatcher 同样只负责任务的分发,他们将高层 API 分发到合适的底层实现。举个例子,相信很多人都会有疑问,在 PyTorch 中,当我们写下 torch.add(inputs) 时,PyTorch 是怎样让该函数根据 inputs 所在设备、数据布局、数据类型找到其实际应当执行的 kernel 的呢?我们会在第四节和第五节给大家简单解析 PyTorch 和 MMCV 的 Dispatcher 机制,并让大家了解 Dispatcher 在上述过程中发挥的重要作用。

2. 为什么需要 Dispatcher

从上面的描述来看,Dispathcer 只是一个美化的 if 语句:根据 inputs 的一些信息,决定应该调用哪一段代码,那么我们为什么需要 Dispatcher 呢?答案显而易见,是为了维持代码的可维护性和优雅性。

想想如果没有 Dispatcher,对于一个简单的 torch.add,我们就需要针对 CPU、GPU、TPU、FPGA 等不同设备,普通张量、稀疏张量(主要用于解决 3D 视觉、图神经网络等领域数据的稀疏性,以减少内存和计算量浪费 )等不同张量布局,float32、float16、int8 等不同数据类型,写下一大堆 if else,这实在是不可接受的。每一个 operator 我们都这样编写的话,将不可避免地造成代码的混乱、冗余、难以 Debug。因此,我们需要一个 Dispatcher,让它来统一管理分派工作。

3. 注册 + 分发

PyTorch 和 MMCV Dispatcher 都采用的是一种非常常见的架构设计方法——注册 + 分发,或者从设计模式的角度来讲可以叫注册器模式 + 工厂模式。使用这种模式可以实现强大的解耦性和扩展性,实际上 MMCV 的 Register 机制也可以看作这种模式的应用 MMCV 核心组件分析(五): Registry。笔者以为,除了解耦和扩展,这种模式还有一个重要的优点在于责任的下放。

《MMCV 核心组件分析:Registry 》全文可见 OpenMMLab 知乎:

https://zhuanlan.zhihu.com/p/355271993

我们举一个例子来解释什么是注册 + 分发以及什么是责任的下放:

已知复数有两种表示形式,一种是实部和虚部表示(rectangular),一种是模和幅角表示(polar)。那么如果我们要实现一个系统,同时支持这两种表示形式的加法运算,我们应该怎么做呢?我们会先把接口(add)和实现(rectangular-add、polar-add)分离,并给数据打上 tag,最终接口代码会是这样:

代码语言:javascript
复制
# 接口
add(z1, z2):
    if (z1.tag != z2.tag) 
        raise ValueError
    if(z1.tag == "rectangular"):
        return rectangular-add(z1, z2)
    else if(z1.tag == "polar"):
        return polar-add(z1, z2)
    else:
        raise ValueError

这里的 add 接口就像是一个 manager,他负责任务的分发。此时如果有一个人叫小明,他说我有另一种复数表示形式叫 ming 表示法,那么我们就需要在 add 里再加一个判断语句,那么这个 manager 的责任就会越来越重。

而采用注册 + 分发的机制,实际上会存在一张支持 register 和 get 操作的虚拟表格,像是这样:

每次有新的方法时,通过 register 注册到表里,然后我们的接口通过 get 取到对应的实现函数,此时接口就不再是 manager 了,它只需要从表中取出对应的函数并执行,而向表中注册的事情,交由编写具体的新方法的人来负责,最终接口代码会是这样:

代码语言:javascript
复制
add(z1, z2):
    if (z1.tag != z2.tag) 
        raise ValueError
    method = get(z1.tag)
    return method(z1, z2)

上述的例子来自讲座 Lec4b:通用运算符,讲座的前半段重点讲述类似注册 + 分发的机制,后半段则更进一步,将这种机制与递归结合,展示了在复杂系统中的强大威力,推荐大家观看。

《Lec4b:通用运算符》b 站视频链接: https://www.bilibili.com/video/BV1Xx41117tr?p=8

4. PyTorch Dispatcher

设备、布局和类型

PyTorch 的核心数据结构是 Tensor,Tensor 可以简单看做由一些元数据描述的 n 维数据结构。其中最核心的三个元数据是 device、layout 和 dtype。

- device:表明 Tensor 的物理内存实际存储,CPU 或 GPU 或 TPU 或其他

- layout:决定如何在逻辑上解释这块 Tensor 的数据所在的这块连续的物理内存,普通张量或稀疏张量或其他

- dtype:表明数据的类型,float32 或者 float16 或者 int8 或者其他

Tensor 的数据本身和这三个元数据一起决定了这个张量在用户眼中到底是什么,而这三个元数据的笛卡尔积的大小也就决定了一个 torch operator 对应的 kernel 数量。

算子注册

PyTorch 的算子注册是通过一个二维表,竖轴是 PyTorch 中支持的每个运算符。横轴则是 PyTorch 支持的每个分派键(除了 opHandle),可以理解为不同硬件平台 + 不同张量布局 + 其他,一个分派键对应着一类 Tensor。算子注册就是填充单元格。

当我们在单个分派键上为单个运算符注册 kernel 时,我们填充单个单元格,比如:

代码语言:javascript
复制
TORCH_LIBRARY_IMPL(aten, CPU, m) {
  m.impl("mul", cpu_mul);
}

当我们在单个分派键上为所有运算符注册 kernel 时,我们填充该分派键的列,比如:

代码语言:javascript
复制
TORCH_LIBRARY_IMPL(_, AutocastCPU, m) {
  m.fallback(makeFallthrough());
}

其中,在单个分派键上为单个运算符注册的 kernel 具有更高的优先级。

PyTorch Dispatcher 实现

以下代码部分均基于 PyTorch 1.5,不过有些代码是编译时生成的,之后的 PyTorch 版本可能修改了许多源码,但核心思想应当是不变的。

当调用 torch.add 时,会发生两次分派。第一次分派针对 Tensor 的设备类型和布局,例如,它是 CPU Tensor 还是 CUDA Tensor,它是 Strided Tensor 还是 Sparse Tensor;第二次分派则是针对 Tensor 的类型。其中第二次分派只是用类似 AT_DISPATCH_ALL_TYPES 这样的宏包装了一个 switch 语句来完成分派,MMCV 中同样借用了 PyTorch 的这个轮子,比较简单,在此不再赘述。我们重点关注第一次分派。

当执行 torch.add() 时,通过 pybind11 (连接 Python 和 C++ 的桥梁) 来到 THPVariable_add 函数;然后经过多次跳转来到 at::add。

代码语言:javascript
复制
static inline Tensor add(const Tensor & self, const Tensor & other, Scalar alpha) {
    ...
    static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchemaOrThrow("aten::add", "Tensor");
    return op.callUnboxed<Tensor, const Tensor &, const Tensor &, Scalar>(self, other, alpha);
}

在这里,我们通过 findSchemaOrThrow 得到了和 at::add 相对应的 OperatorHandle (即 3.3 节表中的 opHandle)。

每个 operator 和 OperatorHandle 是一一对应的,在每个 OperatorHandle 中都包含一张 "KF" 表,这张表相当于 3.3 节表中的一行,存储着分派键和对应的算子。然后我们会调用 callUnboxed,并跳转到 Dispatcher::callUnboxed。

代码语言:javascript
复制
inline Return Dispatcher::callUnboxedWithDispatchKey(const OperatorHandle& op, DispatchKey dispatchKey, Args... args) const {
    ...
    const auto& dispatchTable = op.operatorIterator_->op.dispatch_table();
    const KernelFunction& kernel = dispatch_(dispatchTable, dispatchKey);
    return kernel.template callUnboxed<Return, Args...>(op, std::forward<Args>(args)...);
}
inline Return Dispatcher::callUnboxed(const OperatorHandle& op, Args... args) const {
    ...
    const auto& dispatchTable = op.operatorIterator_->op.dispatch_table();
    auto dispatchKey = dispatchTable.dispatchKeyExtractor().getDispatchKeyUnboxed<Args...>(backendsWithoutFallthrough_, args...);
    return callUnboxedWithDispatchKey<Return, Args...>(op, dispatchKey, args...);
}

在 callUnboxed 中,我们通过 "KF" 表获取到对应的分派键,随后在 callUnboxedWithDispatchKey 中通过分派键索引到了相应的 KernelFunction ,之后 KernelFunction 内部会调用算子函数。

5. MMCV Dispatcher

算子扩展技术选择

MMCV 作为基于 PyTorch 和 Parrots(商汤自研框架)的框架,可以利用 PyTorch 提供的扩展方式,实现许多高质量的算子。而 PyTorch 主要提供了三种底层算子扩展的方式,native_functions 方式、Extension 方式和 OP Register 方式,那么为什么 MMCV 选择了 Extension 方式呢?下面我们分别分析一下 Pytorch 的三种扩展方式就能得到答案了。

native_functions

我们只需要按照 PyTorch 官方方式组织新算子的 kernel 实现,然后在 native_functions.yaml 和 derivatives.yaml 添加配置信息,就可以自动生成:torch.xxx()、torch.nn.functional.xxx()、tensor.xxx() 方法。

看起来很方便,但是该方法与 PyTorch 的耦合度过高,实际上是修改 PyTorch 源码,所以每次增加或者修改算子都需要重新编译 PyTorch,成本较高。而且 MMCV 的定位是基于 PyTorch,而不是 PyTorch 的魔改版,所以不采用该方案。

OP register

该方式与 PyTorch 源码解耦,增加和修改算子不需要重新编译 PyTorch 源码。我们只需要先编写新算子的 kernel 实现,然后通过 PyTorch 底层的注册接口torch::RegisterOperators 将该算子注册即可。如下所示:

代码语言:javascript
复制
namespace {
Tensor my_kernel_cpu(const Tensor& a, const Tensor& b) {...}
Tensor my_kernel_cuda(const Tensor& a, const Tensor& b) {...}
}

static auto registry = torch::RegisterOperators()
   .op("my_namespace::my_op",  torch::RegisterOperators::options()
       .kernel<decltype(my_kernel_cpu), &my_kernel_cpu>(CPU()))
   .op("my_namespace::my_op",  torch::RegisterOperators::options()
       .kernel<decltype(my_kernel_cuda), &my_kernel_cuda>(CUDA()));

但是该方法也有一个问题,就是如果要增加新硬件平台对应的算子,那么需要首先在 PyTorch 源码中增加对新硬件的支持,之后才能借助torch::RegisterOperators 实现算子注册。MMCV 正积极与各硬件厂商合作,已支持曙光 DCU,也正在与寒武纪 MLU 开展合作,被动地等待 Pytorch 官方加入这些新硬件显然不是一个好的策略,因此不采用该方案。

Extension

该方案同样与 PyTorch 源码解耦,增加和修改算子不需要重新编译 PyTorch 源码。它的原理是通过 pybind11,将 C++(CUDA) 编译为 PyTorch 的一个模块,更多内容可见 揭秘 C++/CUDA 算子实现和调用全流程。同时硬件厂商也可以通过和 MMCV 合作,提供自家的 Extension(如寒武纪的 MLUExtension)模块,较为简单地实现在 MMCV 中支持新硬件。

《揭秘 C++/CUDA 算子实现和调用全流程》全文可见:

https://zhuanlan.zhihu.com/p/348555597

MMCV Dispatcher 实现

为什么要讲 MMCV 算子扩展技术选择呢?主要还是为了让大家理解 MMCV 的特点——基于 PyTorch 和 Parrots、提供算子扩展、支持多后端设备。MMCV 的特点决定了我们会尽量利用 PyTorch 的设计思想和"轮子",但不会使用 PyTorch 那样复杂的 Dispatcher 机制,力求简单、灵活。MMCV 和 PyTorch 一样会发生两次分派,我们也是主要关注第一次分派。

MMCV Dispathcer 的核心就是 pytorch_device_registry.hpp 这个文件,核心内容有三部分:

第一部分是 DeviceRegistry 类,主要包括一个单例模式的方法 instance()、实现设备和算子函数绑定的 Register()、根据设备查找对应算子的Find()。

代码语言:javascript
复制
// Registry
template <typename F, F f>
class DeviceRegistry;

template <typename Ret, typename... Args, Ret (*f)(Args...)>
class DeviceRegistry<Ret (*)(Args...), f> {
 public:
  using FunctionType = Ret (*)(Args...);
  static const int MAX_DEVICE_TYPES =
      int8_t(at::DeviceType::COMPILE_TIME_MAX_DEVICE_TYPES);

  void Register(at::DeviceType device, FunctionType function) {
    funcs_[int8_t(device)] = function;
  }

  FunctionType Find(at::DeviceType device) const {
    return funcs_[int8_t(device)];
  }

  static DeviceRegistry& instance() {
    static DeviceRegistry inst;
    return inst;
  }

 private:
  DeviceRegistry() {
    for (size_t i = 0; i < MAX_DEVICE_TYPES; ++i) {
      funcs_[i] = nullptr;
    }
  };
  FunctionType funcs_[MAX_DEVICE_TYPES];
};

第二部分是 Dispatch 函数,主要逻辑为首先获取第一个 Tensor 的设备,然后检查全部 Tensor 的设备一致性,之后根据设备找到对应的函数(指针),最后执行函数,中间会通过 TORCH_CHECK 做检查工作:

代码语言:javascript
复制
// dispatch
template <typename R, typename... Args>
auto Dispatch(const R& registry, const char* name, Args&&... args) {
  auto device = GetFirstTensorDevice(std::forward<Args>(args)...);
  auto inconsist =
      CheckDeviceConsistency(device, 0, std::forward<Args>(args)...);
  TORCH_CHECK(inconsist.first >= int(sizeof...(Args)), name, ": at param ",
              inconsist.first,
              ", inconsistent device: ", GetDeviceStr(inconsist.second).c_str(),
              " vs ", GetDeviceStr(device).c_str(), "\n")
  auto f_ptr = registry.Find(device.type());
  TORCH_CHECK(f_ptr != nullptr, name, ": implementation for device ",
              GetDeviceStr(device).c_str(), " not found.\n")
  return f_ptr(std::forward<Args>(args)...);
}

第三部分主要是一些宏,根据前两部分的代码就能轻松理解这一部分,在此不再赘述。

代码语言:javascript
复制
// helper macro
#define DEVICE_REGISTRY(key) DeviceRegistry<decltype(&(key)), key>::instance()

#define REGISTER_DEVICE_IMPL(key, device, value)           \
  struct key##_##device##_registerer {                     \
    key##_##device##_registerer() {                        \
      DEVICE_REGISTRY(key).Register(at::k##device, value); \
    }                                                      \
  };                                                       \
  static key##_##device##_registerer _##key##_##device##_registerer;

#define DISPATCH_DEVICE_IMPL(key, ...) \
  Dispatch(DEVICE_REGISTRY(key), #key, __VA_ARGS__)

通过使用 Dispatcher,MMCV 重构了 CSRC 目录的代码,详情见 Refactor csrc with device dispatcher

Refactor csrc with device dispatcher 链接:

https://github.com/open-mmlab/mmcv/pull/1463

重构前:

代码语言:javascript
复制
void roi_align_forward(Tensor input, Tensor rois, Tensor output,
                       Tensor argmax_y, Tensor argmax_x, int aligned_height,
                       int aligned_width, float spatial_scale,
                       int sampling_ratio, int pool_mode, bool aligned) {
  if (input.device().is_cuda()) {
#ifdef MMCV_WITH_CUDA
    CHECK_CUDA_INPUT(input);
    CHECK_CUDA_INPUT(rois);
    CHECK_CUDA_INPUT(output);
    CHECK_CUDA_INPUT(argmax_y);
    CHECK_CUDA_INPUT(argmax_x);

    roi_align_forward_cuda(input, rois, output, argmax_y, argmax_x,
                           aligned_height, aligned_width, spatial_scale,
                           sampling_ratio, pool_mode, aligned);
#else
    AT_ERROR("RoIAlign is not compiled with GPU support");
#endif
  } else {
    CHECK_CPU_INPUT(input);
    CHECK_CPU_INPUT(rois);
    CHECK_CPU_INPUT(output);
    CHECK_CPU_INPUT(argmax_y);
    CHECK_CPU_INPUT(argmax_x);
    roi_align_forward_cpu(input, rois, output, argmax_y, argmax_x,
                          aligned_height, aligned_width, spatial_scale,
                          sampling_ratio, pool_mode, aligned);
  }
}

重构后:

代码语言:javascript
复制
// 注册算子的cuda实现
REGISTER_DEVICE_IMPL(roi_align_forward_impl, CUDA, roi_align_forward_cuda);

// roi_align.cpp
// 使用dispatcher根据参数中的Tensor device类型对实现进行分发
void roi_align_forward_impl(Tensor input, Tensor rois, Tensor output,
                            Tensor argmax_y, Tensor argmax_x,
                            int aligned_height, int aligned_width,
                            float spatial_scale, int sampling_ratio,
                            int pool_mode, bool aligned) {
  DISPATCH_DEVICE_IMPL(roi_align_forward_impl, input, rois, output, argmax_y,
                       argmax_x, aligned_height, aligned_width, spatial_scale,
                       sampling_ratio, pool_mode, aligned);
}

参考资料

[1] REGISTERING A DISPATCHED OPERATOR IN C++

https://pytorch.org/tutorials/advanced/dispatcher.html

[2] PyTorch internals

http://blog.ezyang.com/2019/05/pytorch-internals/

[3] Let’s talk about the PyTorch dispatcher

http://blog.ezyang.com/2020/09/lets-talk-about-the-pytorch-dispatcher/

[4] Lec4b:通用运算符

https://www.bilibili.com/video/BV1Xx41117tr?p=8

[5] MMCV 核心组件分析(五): Registry

https://zhuanlan.zhihu.com/p/355271993

[6] 揭秘 C++/CUDA 算子实现和调用全流程

https://zhuanlan.zhihu.com/p/348555597

[7] MMCV Compatibility Docs

https://github.com/open-mmlab/mmcv/blob/master/docs/zh_cn/compatibility.md

本文参与 腾讯云自媒体同步曝光计划,分享自微信公众号。
原始发表:2021-12-30,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 OpenMMLab 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
相关产品与服务
GPU 云服务器
GPU 云服务器(Cloud GPU Service,GPU)是提供 GPU 算力的弹性计算服务,具有超强的并行计算能力,作为 IaaS 层的尖兵利器,服务于生成式AI,自动驾驶,深度学习训练、科学计算、图形图像处理、视频编解码等场景。腾讯云随时提供触手可得的算力,有效缓解您的计算压力,提升业务效率与竞争力。
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档