前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >pytorch使用DistributedDataParallel进行多卡加速训练

pytorch使用DistributedDataParallel进行多卡加速训练

原创
作者头像
languageX
发布2021-11-02 10:21:06
2.8K0
发布2021-11-02 10:21:06
举报
文章被收录于专栏:计算机视觉CV计算机视觉CV

上文我们介绍了如何使用多线程在数据模块中进行模型训练加速,本文我们主要介绍在pytorch中如何使用DistributedDataParallel,torch.multiprocessing等模块来进行多卡并行处理提升模块训练速度。

下面依次介绍下pytorch的数据并行处理和多卡多进程并行处理,以及代码上如何调整代码进行多卡并行计算。

DataParallel(DP)

DataParallel是将数据进行并行,使用比较简单:

代码语言:javascript
复制
model = nn.DataParallel(model,device_ids=gpu_ids)

但是在使用过程中会发现加速并不明显,而且会有严重的负载不均衡。这里主要原因是虽然模型在数据上进行了多卡并行处理,但是在计算loss时确是统一到第一块卡再计算处理的,所以第一块卡的负载要远大于其他卡。

”DataParallel是数据并行,但是梯度计算是汇总在第一块GPU相加计算,这就造成了第一块GPU的负载远远大于剩余其他的显卡。

在前向过程中,你的输入数据会被划分成多个子部分(以下称为副本)送到不同的device中进行计算,而你的模型module是在每个device上进行复制一份,也就是说,输入的batch是会被平均分到每个device中去,但是你的模型module是要拷贝到每个devide中去的,每个模型module只需要处理每个副本即可,当然你要保证你的batch size大于你的gpu个数。然后在反向传播过程中,每个副本的梯度被累加到原始模块中。概括来说就是:DataParallel 会自动帮我们将数据切分 load 到相应 GPU,将模型复制到相应 GPU,进行正向传播计算梯度并汇总。”

具体分析可以参考: https://zhuanlan.zhihu.com/p/102697821

DistributedDataParallel(DDP)

DP这种方式实际gpu负载不均衡,不能很好的利用多卡。官方目前也更推荐使用"torch.nn.parallel.DistributedDataParallel" DDP的并行方式。

不同于DP是单进程多线程方式,DDP是通过多进程实现的,在每个GPU上创建一个进程。参数更新方式上DDP也是各进程独立进行梯度计算后进行汇总平均,然后再传播到所有进程。而DP是梯度都汇总到GPU0,反向传播更新参数再广播参数到其他的GPU。所以在速度上DDP更快,而且避免了多卡负载不均衡问题。

DP和DDP的区别可参考:https://zhuanlan.zhihu.com/p/206467852

下面直接从代码角度分析如何从单卡训练调整为使用DDP的多卡训练。

单卡进行模型训练逻辑:

代码语言:javascript
复制
def train(args, gpu_id, is_dist=False):
    # 创建模型
    model_builder = ModelBuilder()
    models, optimizers= model_builder.build_net(args, is_dist)
    # 创建loss
    model_builder.build_loss()
    # 创建数据
    train_loader, test_loader = build_data(args, is_dist)
    
   for epoch in range(start_epoch, max_epoch):
        for x_input, x_gt in enumerate(train_loader):
            # forward
            model_builder.forward(x_input, x_gt)
            # build loss
            model_builder.get_loss()
            # compute loss
            model_builder.criterion(args)
            # backward
            model_builder.backward()
            steps += 1

多卡进行模型训练逻辑:

代码语言:javascript
复制
import torch
import torch.backends.cudnn as cudnn
import torch.distributed as dist
import torch.multiprocessing as mp

def train_worker(gpu_id, nprocs, cfg, is_dist):
    '''多卡分布式训练,独立进程运行
    '''
    os.environ['NCCL_BLOCKING_WAIT']="1"
    os.environ['NCCL_ASYNC_ERROR_HANDLING']='1'
    cudnn.deterministic = True
    # 提升速度,主要对input shape是固定时有效,如果是动态的,耗时反而慢
    torch.backends.cudnn.benchmark = True
    dist.init_process_group(backend='nccl',
    init_method='tcp://127.0.0.1:'+str(cfg['port']),
                            world_size=len(cfg['gpu_ids']),
                            rank=gpu_id)
    torch.cuda.set_device(gpu_id)
    # 按batch分割给各个GPU
    cfg['batch_size'] = int(cfg['batch_size'] / nprocs)
    train(cfg, gpu_id, is_dist)
    
def main():
    mp.spawn(train_worker, nprocs=gpu_nums, args=(gpu_nums, args, True))

其中build_net接口中,如果传入is_dist为True,需要设置DistributedDataParallel

代码语言:javascript
复制
if is_dist:
    d_net = DistributedDataParallel(
        net, device_ids=[gpu_id], find_unused_parameters=find_unused_parameters)

其中build_data接口中,如果is_dist为True,需要设置sampler

代码语言:javascript
复制
sampler = None
if dist:
    # sampler自动分配数据到各个gpu上
    sampler = DistributedSampler(dataset)
# pin_memory = True: 锁页内存,加快数据在内存上的传递速度。
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=num_workers,
    pin_memory=pin_memory,
    sampler=sampler,
    drop_last=True,  
)

总结主要需要修改逻辑:

  1. 使用 mp.spawn创建多进程
代码语言:javascript
复制
mp.spawn(train_worker)

2. 初始化进程配置

train_worker中进行GPU_ids以及进程配置dist.init_process_group

3. 修改模型

在模型创建时使用DistributedDataParallel

4. 修改数据

在dataloader构建中使用DistributedSampler

原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。

如有侵权,请联系 cloudcommunity@tencent.com 删除。

原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。

如有侵权,请联系 cloudcommunity@tencent.com 删除。

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • DataParallel(DP)
  • DistributedDataParallel(DDP)
相关产品与服务
批量计算
批量计算(BatchCompute,Batch)是为有大数据计算业务的企业、科研单位等提供高性价比且易用的计算服务。批量计算 Batch 可以根据用户提供的批处理规模,智能地管理作业和调动其所需的最佳资源。有了 Batch 的帮助,您可以将精力集中在如何分析和处理数据结果上。
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档