
The torch.nn.DataParallel class


class torch.nn.DataParallel(module, device_ids=None, output_device=None, dim=0)

Implements data parallelism at the module level.

This container parallelizes the application of the given module by splitting the input across the specified devices by chunking in the batch dimension (other objects will be copied once per device). In the forward pass, the module is replicated on each device, and each replica handles a portion of the input. During the backwards pass, gradients from each replica are summed into the original module.

The batch size should be larger than the number of GPUs used.

See also: Use nn.DataParallel instead of multiprocessing

Arbitrary positional and keyword inputs are allowed to be passed into DataParallel, but some types are handled specially. Tensors will be scattered along the specified dim (default 0). tuple, list, and dict types will be shallow-copied. All other types will be shared among the different threads and can be corrupted if written to in the model’s forward pass.

The parallelized module must have its parameters and buffers on device_ids[0] before running this DataParallel module.
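
A minimal sketch of this setup (assuming at least two CUDA devices; the layer sizes are made up for illustration): the module's parameters live on device_ids[0] before wrapping, an 8-sample batch is chunked along dim 0 into one sub-batch per device, and the output is gathered back onto device_ids[0]:

import torch
import torch.nn as nn

# The module's parameters and buffers must already be on device_ids[0].
model = nn.Linear(10, 5).to("cuda:0")
net = nn.DataParallel(model, device_ids=[0, 1])

# A batch of 8 samples is chunked along dim 0: each replica sees 4 samples.
# Non-tensor inputs would instead be shallow-copied to every replica.
inputs = torch.randn(8, 10)      # may live on any device, including CPU
outputs = net(inputs)            # gathered back onto device_ids[0]
print(outputs.shape)             # torch.Size([8, 5])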

Warning

In each forward, module is replicated on each device, so any updates to the running module in forward will be lost. For example, if module has a counter attribute that is incremented in each forward, it will always stay at the initial value, because the update is done on the replicas, which are destroyed after forward. However, DataParallel guarantees that the replica on device_ids[0] will have its parameters and buffers sharing storage with the base parallelized module, so in-place updates to the parameters or buffers on device_ids[0] will be recorded. E.g., BatchNorm2d and spectral_norm() rely on this behavior to update the buffers.
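
A hedged illustration of this warning (again assuming two CUDA devices; the Counted module is made up): a plain Python attribute incremented inside forward never changes on the wrapped module, because the increment happens on replicas that are discarded after each forward:

import torch
import torch.nn as nn

class Counted(nn.Module):          # hypothetical module for illustration
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)
        self.calls = 0             # plain Python attribute, not a buffer

    def forward(self, x):
        self.calls += 1            # runs on a replica, not on the base module
        return self.linear(x)

model = Counted().to("cuda:0")
net = nn.DataParallel(model, device_ids=[0, 1])
net(torch.randn(8, 4))
print(model.calls)                 # still 0: the replicas were discarded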

Warning

Forward and backward hooks defined on module and its submodules will be invoked len(device_ids) times, each with inputs located on a particular device. Particularly, the hooks are only guaranteed to be executed in correct order with respect to operations on corresponding devices. For example, it is not guaranteed that hooks set via register_forward_pre_hook() be executed before all len(device_ids) forward() calls, but that each such hook be executed before the corresponding forward() call of that device.
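
A small sketch of the hook behavior (two CUDA devices assumed; the hook and module here are illustrative, not from the original text): a forward pre-hook registered on the wrapped module fires once per replica, each time with inputs already scattered onto that replica's device:

import torch
import torch.nn as nn

devices_seen = []                  # shared list the hook appends to

def log_device(module, inputs):    # forward pre-hook signature: (module, inputs)
    devices_seen.append(inputs[0].device)

model = nn.Linear(4, 4).to("cuda:0")
model.register_forward_pre_hook(log_device)

net = nn.DataParallel(model, device_ids=[0, 1])
net(torch.randn(8, 4))
# Two entries, one per replica, e.g. cuda:0 and cuda:1 (order not guaranteed).
print(devices_seen)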

Warning

When module returns a scalar (i.e., a 0-dimensional tensor) in forward(), this wrapper will return a vector of length equal to the number of devices used in data parallelism, containing the result from each device.
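
A hedged example of the scalar-output case (two CUDA devices assumed; LossModule is a made-up module): the 0-dimensional loss returned by each replica is gathered into a 1-dimensional tensor, which is typically reduced with .mean() or .sum() before calling backward():

import torch
import torch.nn as nn

class LossModule(nn.Module):       # hypothetical module returning a 0-d loss
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 1)

    def forward(self, x, target):
        return nn.functional.mse_loss(self.linear(x), target)   # 0-d tensor

model = LossModule().to("cuda:0")
net = nn.DataParallel(model, device_ids=[0, 1])

x, y = torch.randn(8, 4), torch.randn(8, 1)
loss = net(x, y)
print(loss.shape)                  # torch.Size([2]) -- one scalar per device
loss.mean().backward()             # reduce before calling backward()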

Note

There is a subtlety in using the pack sequence -> recurrent network -> unpack sequence pattern in a Module wrapped in DataParallel. See My recurrent network doesn’t work with data parallelism section in FAQ for details.
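
For reference, a hedged sketch of the pattern that note refers to (two CUDA devices assumed; PackedLSTM is illustrative): passing total_length to pad_packed_sequence pads every replica's output back to the full sequence length so the per-device results can be gathered along the batch dimension, which is the approach the FAQ discusses:

import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

class PackedLSTM(nn.Module):       # illustrative pack -> LSTM -> unpack module
    def __init__(self):
        super().__init__()
        self.lstm = nn.LSTM(8, 16, batch_first=True)

    def forward(self, x, lengths):
        total_length = x.size(1)   # full padded length, identical on every replica
        packed = pack_padded_sequence(
            x, lengths.cpu(), batch_first=True, enforce_sorted=False)
        out, _ = self.lstm(packed)
        # Without total_length, each replica would pad only to its own longest
        # sequence and the per-device outputs could not be gathered.
        out, _ = pad_packed_sequence(
            out, batch_first=True, total_length=total_length)
        return out

model = PackedLSTM().to("cuda:0")
net = nn.DataParallel(model, device_ids=[0, 1])
x = torch.randn(8, 12, 8)                      # (batch, seq, feature)
lengths = torch.randint(1, 13, (8,))           # per-sample valid lengths
print(net(x, lengths).shape)                   # torch.Size([8, 12, 16])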

Parameters

  • module (Module) – module to be parallelized
  • device_ids (list of int or torch.device) – CUDA devices (default: all devices)
  • output_device (int or torch.device) – device location of output (default: device_ids[0])

Variables

DataParallel.module (Module) – the module to be parallelized

Example:

>>> net = torch.nn.DataParallel(model, device_ids=[0, 1, 2])
>>> output = net(input_var)  # input_var can be on any device, including CPU