
目录
我们在前文介绍过,微软 ZeRO 可以对一个万亿参数模型可以使用 8 路模型并行、64 路管道并行和 8 路数据并行在 4,096 个 NVIDIA A100 GPU 上进行扩展。而FSDP(Fully Sharded Data Parallel)是Facebook 深度借鉴微软ZeRO之后提出的PyTorch DDP升级版本,可以认为是对标微软 ZeRO,其本质是 parameter sharding。Parameter sharding 就是把模型参数等切分到各个GPU之上。我们会以 Google,微软和 Facebook 的论文,博客以及代码来进行学习分析。
之前文章之中我们谈到了FSDP支持混合精度训练,本篇来看看 Activation recomputation。
本系列其他文章如下:
[源码解析] PyTorch 分布式之 ZeroRedundancyOptimizer
[论文翻译] 分布式训练 Parameter sharding 之 ZeRO
[论文翻译] 分布式训练 Parameter Sharding 之 Google Weight Sharding
[源码分析] Facebook如何训练超大模型 --- (2)
[源码分析] Facebook如何训练超大模型 --- (3)
激活重新计算(Activation recomputation),也称“激活检查点(activation checkpointing)”或“梯度检查点(gradient checkpointing)”(Chen et al,2016 https://arvix.org/abs/1604.06174),其思路是用时间换空间,即,牺牲计算时间来换取内存空间。其减少了深度神经网络训练层的内存开销,代价是每个batch会消耗额外的前向传播计算。
比方说,该方法将m层网络平均划分为d个分区,只保存分区边界的激活,并在workers之间交换这些激活。因为后向传播之中依然需要分区内层间激活值(Intermediate activations at intra-partition layers)来计算梯度,所以在后向传播过程中会在分区内部重新计算激活。
下图为论文之中的示意图。

我们在之前文章之中介绍过重计算 [源码解析] 深度学习流水线并行 GPipe(3) ----重计算。本文会看看 FairScale 是如何对其进行进一步封装和改进。
在看思路之前,我们先来讲讲如何更好的分析一个开源框架或者说如何学习源码。个人的意见是按照:论文 --> 文档 --> 用户手册 --> 注释 --> 源码 这个顺序来学习。
为什么按照这个顺序?因为这个顺序是:
对于我们来说,应该首先寻求一种思维的改变,知识框架的更新与整理,然后才是用代码来分析验证(毕竟纸上得来终觉浅)。当然,很多时候我们只有源码,那么就只能从源码之中根据细节来探寻,重建作者的思路,提炼其精华,争取和作者达到一个跨越空间和时间的共鸣,共鸣越多,你就越接近作者了 _。
我们接下来就看看源码文档之中的思路介绍。
激活检查点是一种用于减少训练期间GPU内存使用的技术。具体做法是:
其结果是:以略有增加(约33%)的计算成本来减少了存储大型激活张量的必要,因此允许我们增加batch size,从而增加模型的净吞吐量。
激活检查点是通过重载  torch.autograd.Function 来完成的。
上述功能在torch.utils.checkpoint.checkpoint_wrapper  之中可以看到其具体实现,可以在前向传播之中使用这个API来对模块进行封装。FairScale中的包装器提供的功能比PyTorch API提供的功能更多,比如用户可以使用 fairscale.nn.checkpoint.checkpoint_wrapper 来包装一个 nn.Module,这样就可以在正向传递中处理kwargs,将中间激活卸载(offload)到CPU,并处理从前向函数返回的非张量输出。
我们接下来看看  fairscale.nn.checkpoint.checkpoint_wrapper 的最佳实践。
requires_grad 属性设置为True。通过将输入张量的requires_grad 属性设置为True,我们确保输入可以传播到输出,并触发 backward 函数。checkpoint_wrapper 是具体的wrapper,其内部就是调用了其他函数。但是我们发现其注释可以让我们进一步学习,所以翻译如下:
checkpoint_wrapper 是执行激活检查点的包装器,其比PyTorch版本更加用户友好,具备如下特点:
为了更好的了解checkpointing和"offload_to_cpu"带来的好处,我们将激活分为两种类型:
就GPU内存节约效果而言:
另外,第一层和最后一层不太可能受益于offload_to_cpu标志,因为:
def checkpoint_wrapper(
    module: nn.Module, offload_to_cpu: bool = False, maintain_forward_counter: bool = False
) -> nn.Module:
    """
    A friendlier wrapper for performing activation checkpointing.
    Compared to the PyTorch version, this version:
        - wraps an nn.Module, so that all subsequent calls will use checkpointing
        - handles keyword arguments in the forward
        - handles non-Tensor outputs from the forward
        - supports offloading activations to CPU
    Usage::
        checkpointed_module = checkpoint_wrapper(my_module, offload_to_cpu=True)
        a, b = checkpointed_module(x, y=3, z=torch.Tensor([1]))
    To understand the benefits of checkpointing and the `offload_to_cpu` flag,
    let's divide activations into 2 types: inner activations and outer
    activations w.r.t. the checkpointed modules. The inner ones are saved
    by activation checkpointing, the outer ones are saved by offload_to_cpu.
    In terms of GPU memory savings:
        - When inner ones are large in size and outer ones are small,
          checkpointing helps a lot, offload_to_cpu may help a little.
        - When inner ones are small and outer ones are large,
          checkpointing helps little, offload_to_cpu helps a lot.
        - When both inner and outer are large, both help and the
          benefit is additive.
    ..Note::
        The first and last layers are not likely to benefit from the `offload_to_cpu` flag
        because (1) there are typically other references to the first layer's input, so
        the GPU memory won't be freed; (2) the input to the last layer is immediately
        used by the backward pass and won't result in memory savings.
    Args:
        module (nn.Module):
            The module to be wrapped
        offload_to_cpu (bool):
            Whether to offload activations to CPU.
        maintain_forward_counter (bool):
            If True, maintain a forward counter per inner module. The counter will first
            increases in forward calls of outer forward pass and then decreases in the
            forward calls of outer backward pass. It is used by FullyShardedDataParallel.
    Returns:
        (nn.Module):
            Wrapped module
    """
    # Patch the batchnorm layers in case there are any in this module.
    patch_batchnorm(module)
    if maintain_forward_counter:
        init_counter(module)
    # The use of weakref here is to prevent creating a ref cycle: m -> m.forward -> m.
    # When such cycle exists, gc won't collect the module when the module is freed.
    # That causes GPU memory to be leaked. See the unit test for how we catch that.
    #
    # We prefer this over a class wrapper since the class wrapper would have to
    # proxy a lot of fields and methods.
    module.forward = functools.partial(  # type: ignore
        _checkpointed_forward, type(module).forward, weakref.ref(module), offload_to_cpu
    )
    return module # 包装一个nn.Module,以便所有后续调用都将使用checkpointing我们从源码之中找出一些代码,大家可以看看。
self.layers = nn.Sequential(
    nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 4), nn.Linear(4, 8)),
    nn.Sequential(nn.Linear(8, 4), nn.Linear(4, 4), nn.Linear(4, 4)),
    nn.Sequential(nn.Linear(4, 6), nn.Linear(6, 8), nn.Linear(8, 2)),
)
if enable_checkpoint:
    for i, layer in enumerate(self.layers):
        # Only middle layer needs to have offloading
        self.layers[i] = checkpoint_wrapper(layer, cpu_offload if i == 1 else False)前面提到对比PyTorch版本,FairScale有几点益处,此处就对应了以下有下划线的两点:
代码逻辑如下:
.forward() 。这样做还可以确保内部fwd counter在前向过程中不会增加,但是这在eval过程中会是一个问题,因为不会有相应的后向过程来减少fwd counter。具体代码如下:
def _checkpointed_forward(
    original_forward: Any, weak_self: Any, offload_to_cpu: bool, *args: Any, **kwargs: Any
) -> Any:
    module = weak_self()
    # If gradients are disabled, just use original `.forward()` method directly.
    # Doing so also ensures the internal fwd counter is not incremented in the forward pass,
    # which would be an issue during eval since there wouldn't be a corresponding backward pass
    # to decrement the fwd counter.
    # See https://github.com/facebookresearch/fairscale/pull/709.
    if not torch.is_grad_enabled():
        return original_forward(module, *args, **kwargs)
    # Autograd Functions in PyTorch work best with positional args, since
    # the backward must return gradients (or None) for every input argument.
    # We can flatten keyword arguments to make this easier.
    args = (module,) + args
    kwarg_keys, flat_args = pack_kwargs(*args, **kwargs) # 处理输入
    parent_ctx_dict: Dict[str, Any] = {
        "offload": offload_to_cpu,
    }
    # Dummy tensor with grad is used to ensure the backward pass is called. This is needed
    # when original_forward's input are non-tensor (i.e. a tuple). Using this dummy tensor
    # avoids requiring users to set their input tensors's requires_grad flag. In the case
    # of tuple type inputs, setting the flag won't even trigger the backward pass.
    output = CheckpointFunction.apply(
        torch.tensor([], requires_grad=True), original_forward, parent_ctx_dict, kwarg_keys, *flat_args
    )
    
    # 处理非张量输出
    if not isinstance(output, torch.Tensor):
        # parent_ctx_dict["packed_non_tensor_outputs"] 是 CheckpointFunction 返回的
        packed_non_tensor_outputs = parent_ctx_dict["packed_non_tensor_outputs"]
        if packed_non_tensor_outputs:
            # 统一处理成tuple
            output = unpack_non_tensors(output, packed_non_tensor_outputs) # 处理输出
    return output在处理前向过程中的关键字参数(keyword arguments)之中,使用了pack_kwargs,其作用就是把参数的key,value整理成为两个list,具体可以参见示例。
def pack_kwargs(*args: Any, **kwargs: Any) -> Tuple[Tuple[str, ...], Tuple[Any, ...]]:
    """
    Turn argument list into separate key list and value list (unpack_kwargs does the opposite)
    Usage::
        kwarg_keys, flat_args = pack_kwargs(1, 2, a=3, b=4)
        assert kwarg_keys == ("a", "b")
        assert flat_args == (1, 2, 3, 4)
        args, kwargs = unpack_kwargs(kwarg_keys, flat_args)
        assert args == (1, 2)
        assert kwargs == {"a": 3, "b": 4}
    """
    kwarg_keys: List[str] = []
    flat_args: List[Any] = list(args)
    for k, v in kwargs.items():
        kwarg_keys.append(k)
        flat_args.append(v)
    return tuple(kwarg_keys), tuple(flat_args)把一个tuple分割为一个张量列表和后续重建所需要的信息。
def split_non_tensors(
    mixed: Union[torch.Tensor, Tuple[Any, ...]]
) -> Tuple[Tuple[torch.Tensor, ...], Optional[Dict[str, List[Any]]]]:
    """
    Split a tuple into a list of tensors and the rest with information
    for later reconstruction.
    Usage::
        x = torch.Tensor([1])
        y = torch.Tensor([2])
        tensors, packed_non_tensors = split_non_tensors((x, y, None, 3))
        assert tensors == (x, y)
        assert packed_non_tensors == {
            "is_tensor": [True, True, False, False],
            "objects": [None, 3],
        }
        recon = unpack_non_tensors(tensors, packed_non_tensors)
        assert recon == (x, y, None, 3)
    """
    if isinstance(mixed, torch.Tensor):
        return (mixed,), None
    tensors: List[torch.Tensor] = []
    packed_non_tensors: Dict[str, List[Any]] = {"is_tensor": [], "objects": []}
    for o in mixed:
        if isinstance(o, torch.Tensor):
            packed_non_tensors["is_tensor"].append(True)
            tensors.append(o)
        else:
            packed_non_tensors["is_tensor"].append(False)
            packed_non_tensors["objects"].append(o)
    return tuple(tensors), packed_non_tensorsunpack_non_tensors 用来把非张量列表恢复成tuple。
def unpack_non_tensors(
    tensors: Tuple[torch.Tensor, ...], packed_non_tensors: Optional[Dict[str, List[Any]]]
) -> Tuple[Any, ...]:
    """See split_non_tensors."""
    if packed_non_tensors is None:
        return tensors
    assert isinstance(packed_non_tensors, dict), type(packed_non_tensors)
    mixed: List[Any] = []
    is_tensor_list = packed_non_tensors["is_tensor"]
    objects = packed_non_tensors["objects"]
    obj_i = tnsr_i = 0
    for is_tensor in is_tensor_list:
        if is_tensor:
            mixed.append(tensors[tnsr_i])
            tnsr_i += 1
        else:
            mixed.append(objects[obj_i])
            obj_i += 1
    return tuple(mixed)我们接下来分析 CheckpointFunction,就是具体 activation checkpointing 的业务函数。关于 PyTorch 的 CheckpointFunction 版本,可以参见 [源码解析] 深度学习流水线并行 GPipe(3) ----重计算。
这里对应了优点之中的:支持将激活卸载到CPU。
其前向传播的逻辑如下:
class CheckpointFunction(torch.autograd.Function):
    """Similar to the torch version, but support non-Tensor outputs.
    The caller is expected to provide a dict (*parent_ctx_dict*) that will hold
    the non-Tensor outputs. These should be combined with the Tensor *outputs*
    by calling :func:`unpack_non_tensors`.
    """
    @staticmethod
    def forward(  # type: ignore
        ctx: Any,
        dummy_tensor_requires_grad: torch.Tensor,
        run_function: Any,
        parent_ctx_dict: Dict[str, Any],
        kwarg_keys: Tuple[str, ...],
        *args: Any,
        **kwargs: Any
    ) -> Any:
        torch_checkpoint.check_backward_validity(args)
        ctx.run_function = run_function # 在上下文之中存储前向传播函数
        ctx.kwarg_keys = kwarg_keys
        ctx.fwd_rng_state = get_rng_state() # 在上下文之中存储前向传播状态
        ctx.had_autocast_in_fwd = is_autocast_enabled()
        # 分割非张量参数列表,得到张量输入和非张量输入
        tensor_inputs, packed_non_tensor_inputs = split_non_tensors(args) 
        if parent_ctx_dict["offload"]:
            # 在上下文记录设备,梯度需求情况,并且把输入张量放到cpu上
            ctx.fwd_device = tuple(x.device for x in tensor_inputs) # 在上下文存储前向传播设备
            ctx.grad_requirements = tuple(x.requires_grad for x in tensor_inputs)
            tensor_inputs = tuple(x.to("cpu", non_blocking=True) for x in tensor_inputs)
        else:
            ctx.fwd_device, ctx.grad_requirements = None, None
        # 为后向传播保存输入
        ctx.save_for_backward(*tensor_inputs)
        ctx.packed_non_tensor_inputs = packed_non_tensor_inputs
        with torch.no_grad(), enable_checkpointing(): # 如果设置了activation checkpointing
            unpacked_args, unpacked_kwargs = unpack_kwargs(kwarg_keys, args) # 处理参数
            outputs = run_function(*unpacked_args, **unpacked_kwargs) # 前向计算
            the_module = unpacked_args[0]
            inc_counter(the_module)
        if not isinstance(outputs, torch.Tensor): # 如果输出不是张量
            # Autograd Functions don't like non-Tensor outputs. We can split the
            # non-Tensor and Tensor outputs, returning the former by reference
            # through *parent_ctx_dict* and returning the latter directly.
            # Autograd Functions不喜欢非张量输出。我们可以拆分为非张量和张量输出,
            # 通过parent_ctx_dict引用返回前者,然后直接返回后者。
            outputs, packed_non_tensor_outputs = split_non_tensors(outputs)
            parent_ctx_dict["packed_non_tensor_outputs"] = packed_non_tensor_outputs
        return outputs后向传播逻辑如下:
class CheckpointFunction(torch.autograd.Function):
    """Similar to the torch version, but support non-Tensor outputs.
    The caller is expected to provide a dict (*parent_ctx_dict*) that will hold
    the non-Tensor outputs. These should be combined with the Tensor *outputs*
    by calling :func:`unpack_non_tensors`.
    """
    @staticmethod
    def backward(ctx: Any, *args: Any) -> Tuple[Optional[Tensor], ...]:
        if not torch.autograd._is_checkpoint_valid():
            raise RuntimeError("Checkpointing is not compatible with .grad(), please use .backward() if possible")
        # 拿到存储在上下文的张量输入
        tensor_inputs: Tuple = ctx.saved_tensors
        tensor_inputs = torch_checkpoint.detach_variable(tensor_inputs)
        if ctx.fwd_device is not None: # 如果设置了在设备上计算
            # 把 offload 的张量再移到 GPU之上
            tensor_inputs = tuple(t.to(ctx.fwd_device[i], non_blocking=True) for i, t in enumerate(tensor_inputs))
            for i, need_grad in enumerate(ctx.grad_requirements): # 找到需要计算的梯度
                tensor_inputs[i].requires_grad = need_grad
        # 处理非张量输入,最终和张量输入组合在一起        
        inputs = unpack_non_tensors(tensor_inputs, ctx.packed_non_tensor_inputs)
        # Store the current states.
        bwd_rng_state = get_rng_state() # 拿到之前保存的当前状态
        # Set the states to what it used to be before the forward pass.
        set_rng_state(ctx.fwd_rng_state) # 从上下文加载前向传播时候的状态
        with torch.enable_grad(), enable_recomputing(), autocast(ctx.had_autocast_in_fwd):
            unpacked_args, unpacked_kwargs = unpack_kwargs(ctx.kwarg_keys, inputs)
            outputs = ctx.run_function(*unpacked_args, **unpacked_kwargs) # 重新做前向传播
            tensor_outputs, _ = split_non_tensors(outputs) # 处理前向传播输出
            the_module = unpacked_args[0]
            dec_counter(the_module)
        # Set the states back to what it was at the start of this function.
        set_rng_state(bwd_rng_state) # 恢复后向传播的状态
        # Run backward() with only Tensors that require grad
        outputs_with_grad = [] 
        args_with_grad = []
        # 从前向传播输出找到需要梯度的张量
        for i in range(len(tensor_outputs)):
            if tensor_outputs[i].requires_grad:
                outputs_with_grad.append(tensor_outputs[i])
                args_with_grad.append(args[i]) # 在后向传播的输入之中找到对应的张量
        if len(outputs_with_grad) == 0:
            raise RuntimeError("None of the outputs have requires_grad=True, " "this checkpoint() is not necessary")
        # 进行后向传播     
        torch.autograd.backward(outputs_with_grad, args_with_grad)
        # 从inputs里面得到梯度
        grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else None for inp in inputs)
        return (None, None, None, None) + grads # 返回梯度我们整理逻辑如下:

前文在 OffloadModel 的 forward 方法之中,如果设置了 _checkpoint_activation,则调用 OffloadFunction 把激活检查点卸载到CPU之上,直接返回。我们接下来看看 OffloadFunction 如何实现与activation相关的操作。
此函数通过覆盖nn.Module的向前和向后传播,在分片边界启用中间激活的检查点。这样只保存分区边界的激活,并在workers之间交换这些激活。
本节与上节的主要区别是:
在FW过程中,它遍历每一个分区,针对每一个分区,删除前一个分片中的参数,并加载下一个分片的参数,然后进行这个分区的前向计算。FW过程中未构造任何计算图。这使我们能够卸载分片边界上的中间激活。
这里有几点说明:
具体代码如下:
class OffloadFunction(torch.autograd.Function):
    """
     This Function enables checkpointing of intermediate activations at
     shard boundaries by overriding the forward and backward pass of the nn.Module.
     - In the FW pass, it drops parameters in the previous shard and
     loads parameters for the next shard. No graph is constructed in the FW pass.
     This enables us to offload intermediate activations present at the shard
     boundaries.
     - In the BW pass, it does the reverse. We run the forward pass using the
     saved intermediate activations and calculate gradients as needed.
     The trade-off is latency vs memory when using activation checkpointing.
     - Follows heavily from https://pytorch.org/docs/stable/_modules/torch/utils/checkpoint.html#checkpoint.
     NOTE: see https://pytorch.org/docs/stable/autograd.html#torch.autograd.Function
     """
    @staticmethod
    @_conditional_amp_fwd_decorator  # type: ignore
    def forward(ctx: Any, inputs: Any, dummy_input: Any, model_instance: Any) -> Any:
        inputs = inputs if isinstance(inputs, tuple) else (inputs,)
        # 把后向传播所需要的信息存储在上下文。
        ctx.inputs = inputs
        ctx.model_instance = model_instance
        # TODO(anj-s): We might need to store this for each boundary activation.
        # Currently we assume all boundary activation inputs require
        ctx.grad_requirements = tuple(x.requires_grad for x in inputs)
        ctx.fwd_rng_state = torch.get_rng_state()
        # List of input activations starting with the given input.
        model_instance._activations = [inputs]
        # Enumerate through layer shards and apply activations from the previous shard.
        for index, layer_shard in enumerate(model_instance.model_slices): # 遍历模型的分区
            with torch.autograd.profiler.record_function("fairscale.experimental.nn.offload:forward_load"):
                # Bring in the current activations onto the device.
                # 把当前激活拷贝到设备之上
                model_instance._activations[index] = tuple([a.cuda() for a in list(model_instance._activations[index])])
                # Bring in the current layer shard onto the device.
                # 把当前层加载到设备之上
                layer_shard.forward_load()
            # Apply the FP and store the activations on the CPU.
            inputs = model_instance._activations[index]
            with torch.autograd.profiler.record_function("fairscale.experimental.nn.offload:no_grad_forward_pass"):
                with torch.no_grad(): # 不会跟踪下面的梯度,只是计算激活
                    output_list: List[Any] = []
                    for given_input in inputs:
                        given_input_list = torch.chunk(given_input, model_instance._num_microbatches)
                        given_output_list = []
                        for inputs in given_input_list:
                            output = layer_shard(inputs) # 前向操作
                            given_output_list.append(output)
                        given_output = torch.cat(given_output_list).squeeze(-1)
                        output_list.append(given_output)
                    output = tuple(output_list) # 得到输出
            output = output if isinstance(output, tuple) else (output,)
            with torch.autograd.profiler.record_function("fairscale.experimental.nn.offload:forward_drop"):
                # Move the activation used back for the curent shard back to the CPU.
                # 把激活移动到CPU
                model_instance._activations[index] = tuple([a.cpu() for a in list(model_instance._activations[index])])
                # The newly computed activations remain on the GPU ready for the next shard computation.
                model_instance._activations.append(output)
                # Move the layer shard back to the CPU.
                layer_shard.forward_drop() # 把层移动到CPU
        # The last instance will lose the gradient function if we move it to the CPU.
        # This is because all grad function are present on the device that ran the FW pass.
        # The last activation remains on the GPU and is the return value of this function.
        # Note that this assumes that the target is also on the GPU which is required for calculating
        # the loss.
        
        result = model_instance._activations[-1] # 最后一层的激活
        result = [r.cuda() for r in result] # 把最后一层的激活移动到设备上,其余的已经移动到CPU之上
        for r in result:
            r.requires_grad = True
        return result[0] if len(result) == 1 else result在BW过程中,它执行相反的操作。我们使用保存的中间激活运行前向传播,并根据需要计算梯度。在使用激活检查点时,需要权衡延迟和内存。因为这里会用到几个PyTorch的内置方法,所以我们需要首先来看看其用法和原理。
torch.no_grad() 是一个上下文管理器,被 no_grad 包括起来的代码不会跟踪其梯度。我们用一个例子来看看。
import torch
x = torch.tensor([2.2], requires_grad=True)
y = x * 3
print(y)
y.add_(2)
print(y)
with torch.no_grad():
    y.div_(3)
    print(y)输出为:
tensor([6.6000], grad_fn=<MulBackward0>) # 这里记录了梯度操作
tensor([8.6000], grad_fn=<AddBackward0>) # add操作被跟踪
tensor([2.8667], grad_fn=<AddBackward0>) # 用了no_grad,所以div没有被跟踪torch.chunk(tensor, chunk_num, dim) 将张量按dimension(行或列)分割得到 chunk_num 个张量块,此函数将返回一个元组,比如下面例子。
x = torch.Tensor([[1,2,3]])
y = torch.Tensor([[4,5,6], [7,8,9], [10,11,12]])
z = torch.cat((x,y), dim=0)
print(z)
print(z.size())
c = torch.chunk(z,4,dim=0)
print(c)
print(len(c))输出为:
# cat之后的输出
tensor([[ 1.,  2.,  3.],
        [ 4.,  5.,  6.],
        [ 7.,  8.,  9.],
        [10., 11., 12.]])
torch.Size([4, 3])
# chunk之后的输出
(tensor([[1., 2., 3.]]), tensor([[4., 5., 6.]]), tensor([[7., 8., 9.]]), tensor([[10., 11., 12.]]))
4OffloadFunction 的反向传播如下,这里有个reverse操作需要注意。
具体代码如下:
class OffloadFunction(torch.autograd.Function):
    # Ignore the following function for code coverage since the backward pass
    # is triggered by C++ code and cannot be calculated when overriding
    # autograd.Function
    @staticmethod
    @_conditional_amp_bwd_decorator
    def backward(ctx, *grad_outputs):  # type: ignore # pragma: no cover
        inputs = ctx.inputs
        model_instance = ctx.model_instance
        # 遍历上下文存储的信息,给输入设定是否需要梯度
        for i, need_grad in enumerate(ctx.grad_requirements):
            inputs[i].requires_grad = need_grad
        # 得到反向传播的输入
        all_grads = [grad_outputs]
        # 把模型分片和激活进行reverse(注意,没有把原始分配和激活进行reverse,这里是reverse之后的结果返回,不影响原始数据),因为计算梯度是从后向前,所以把-1放到第一个位置,依次类推,这样可以方便使用backward_load和backward_drop。
        
        # 然后遍历模型分片,针对每一个分片进行处理
        for model_shard, activation in zip(
            reversed(model_instance.model_slices), reversed(model_instance._activations[:-1])
        ):
            with torch.autograd.profiler.record_function("fairscale.experimental.nn.offload:backward_load"):
                # Move the activation to the GPU.
                # 把当前分片的激活移动到GPU
                activation = tuple([a.cuda() for a in list(activation)])
                # 把当前分片的模型移动到GPU
                # Move the model shard to the GPU.
                model_shard.backward_load()
            # Store the BW pass state.
            # 暂存反向传播状态
            bwd_rng_state = torch.get_rng_state()
            # TODO(anj-s): Why detach inputs?
            activation = torch.utils.checkpoint.detach_variable(activation)
            # Get the last gradient calculation.
            final_grads = all_grads[-1] # 这将会是最终生成的梯度
            if isinstance(activation, torch.Tensor):
                activation = (activation,)
            if isinstance(final_grads, torch.Tensor):
                final_grads = (final_grads,)
            # Iterate through all the inputs/outputs of a shard (there could be multiple).
            chunked_grad_list: List[Any] = []
            # Chunk the activation and grad based on the number of microbatches that are set.
            # 因为可能有多个微批次,所以需要把梯度和激活分别做chunk操作
            for chunked_activation, chunked_grad in zip(
                torch.chunk(*activation, model_instance._num_microbatches),  # type: ignore
                torch.chunk(*final_grads, model_instance._num_microbatches),  # type: ignore
            ):
                # Set the states to what it used to be before the forward pass.
                torch.set_rng_state(ctx.fwd_rng_state) # 暂时使用前向传播状态
                # 构建为list
                if isinstance(chunked_activation, torch.Tensor):
                    chunked_activation = (chunked_activation,)  # type: ignore
                if isinstance(chunked_grad, torch.Tensor):
                    chunked_grad = (chunked_grad,)  # type: ignore
                # Since we need a grad value of a non leaf element we need to set these properties.
                for a in chunked_activation:
                    if a.dtype == torch.long:
                        continue
                    a.requires_grad = True # 因为需要计算非叶子结点,所以将其设置为需要梯度
                    a.retain_grad()
                with torch.autograd.profiler.record_function(
                    "fairscale.experimental.nn.offload:forward_pass_with_enable_grad"
                ):
                    with torch.enable_grad():
                        # calculate the output of the last shard wrt to the stored activation at the slice boundary.
                        outputs = model_shard(*chunked_activation) # 前向传播
                # Set the states back to what it was at the start of this function.
                torch.set_rng_state(bwd_rng_state) # 恢复状态
                
                with torch.autograd.profiler.record_function("fairscale.experimental.nn.offload:backward_pass"):
                    torch.autograd.backward(outputs, chunked_grad) # 反向传播
                    
                intermediate_grads = []
                for a in chunked_activation:
                    if a.grad is not None:
                        intermediate_grads.append(a.grad)
                if None not in intermediate_grads:
                    chunked_grad_list += intermediate_grads
             
            # 把梯度列表添加到all_grads之上
            if chunked_grad_list:
                # Append the list of grads to the all_grads list and this should be on the GPU.
                all_grads.append(torch.cat(chunked_grad_list).squeeze(-1))  # type: ignore
                
            with torch.autograd.profiler.record_function("fairscale.experimental.nn.offload:backward_drop"):
                # Move the shard back to the CPU. This should move all the grad tensors to CPU as well.
                # We don't need to move activations since we are using a copy of the tensors on the GPU.
                model_shard.backward_drop() # 分区移动到CPU
           
        # 之前的reverse没有对 model_instance._activations 做修改
        detached_inputs = model_instance._activations[0]
        # 从输入之中拿到其梯度
        grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else inp for inp in detached_inputs)
        return (None, None) + grads # 返回梯度逻辑拓展如下:

至此,FSDP 分析完毕,我们下一个系列将会通过 NVIDIA Megatron 来介绍模型并行,敬请期待。
https://arxiv.org/pdf/2101.06840.pdf
https://www.deepspeed.ai/tutorials/zero-offload/
DeepSpeed: Extreme-scale model training for everyone
1 Li et al. “PyTorch Distributed: Experiences on Accelerating Data Parallel Training” VLDB 2020.
2 Cui et al. “GeePS: Scalable deep learning on distributed GPUs with a GPU-specialized parameter server” EuroSys 2016
3 Shoeybi et al. “Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism.” arXiv preprint arXiv:1909.08053 (2019).
4 Narayanan et al. “Efficient Large-Scale Language Model Training on GPU Clusters Using Megatron-LM.” arXiv preprint arXiv:2104.04473 (2021).
5 Huang et al. “GPipe: Efficient Training of Giant Neural Networks using Pipeline Parallelism.” arXiv preprint arXiv:1811.06965 (2018).
6 Narayanan et al. “PipeDream: Generalized Pipeline Parallelism for DNN Training.” SOSP 2019.
7 Narayanan et al. “Memory-Efficient Pipeline-Parallel DNN Training.” ICML 2021.
8 Shazeer et al. “The Sparsely-Gated Mixture-of-Experts Layer Noam.” arXiv preprint arXiv:1701.06538 (2017).
9 Lepikhin et al. “GShard: Scaling Giant Models with Conditional Computation and Automatic Sharding.” arXiv preprint arXiv:2006.16668 (2020).
10 Fedus et al. “Switch Transformers: Scaling to Trillion Parameter Models with Simple and Efficient Sparsity.” arXiv preprint arXiv:2101.03961 (2021).
11 Narang & Micikevicius, et al. “Mixed precision training.” ICLR 2018.
12 Chen et al. 2016 “Training Deep Nets with Sublinear Memory Cost.” arXiv preprint arXiv:1604.06174 (2016).
13 Jain et al. “Gist: Efficient data encoding for deep neural network training.” ISCA 2018.
14 Shazeer & Stern. “Adafactor: Adaptive learning rates with sublinear memory cost.” arXiv preprint arXiv:1804.04235 (2018).
15 Anil et al. “Memory-Efficient Adaptive Optimization.” arXiv preprint arXiv:1901.11150 (2019).
16 Rajbhandari et al. “ZeRO: Memory Optimization Towards Training A Trillion Parameter Models Samyam.” arXiv preprint arXiv:1910.02054 (2019).