首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >问答首页 >PyTorch闪电与亚马逊SageMaker

PyTorch闪电与亚马逊SageMaker
EN

Stack Overflow用户
提问于 2022-09-10 14:09:35
回答 3查看 199关注 0票数 0

目前,我们正在使用毕火炬闪电进行SageMaker以外的培训。希望利用SageMaker来利用分布式训练、检查点、模型训练优化(训练编译器)等来加速训练过程,节省成本。将他们的PyTorch闪电脚本迁移到SageMaker上的推荐方法是什么?

EN

回答 3

Stack Overflow用户

发布于 2022-09-14 23:58:29

在SageMaker上运行Py火炬闪电的最简单方法是使用SageMaker Pytorch估值器(示例)开始。理想情况下,您可以在源代码的同时添加一个requirement.txt,用于安装Pytors闪电。

关于分布式培训,亚马逊SageMaker最近推出了本机支持运行基于火炬闪电的分布式培训。请按照下面的链接设置您的培训代码

https://docs.aws.amazon.com/sagemaker/latest/dg/data-parallel-modify-sdp-pt-lightning.html

https://aws.amazon.com/blogs/machine-learning/run-pytorch-lightning-and-native-pytorch-ddp-on-amazon-sagemaker-training-featuring-amazon-search/

票数 0
EN

Stack Overflow用户

发布于 2022-09-15 03:12:24

由于您的问题是特定于将已经工作的代码迁移到Sagemaker,因此,以链接在这里为参考,我可以尝试将该过程分为3部分:

  1. 创建一个火炬估计器- estimator
代码语言:javascript
运行
复制
import sagemaker
sagemaker_session = sagemaker.Session()

pytorch_estimator = PyTorch(
     entry_point='my_model.py',
    instance_type='ml.g4dn.16xlarge',
    instance_count=1,
    framework_version='1.7',
    py_version='py3',
    output_path: << s3 bucket >>,
    source_dir = <<  path for my_model.py >> ,
    sagemaker_session=sagemaker_session)
  1. entry_point = "my_model.py" -这部分应该是你现有的毕火炬闪电脚本。在方法中,您可以得到如下内容:
代码语言:javascript
运行
复制
if __name__ ==  '__main__':
     import pytorch_lightning as pl
     trainer = pl.Trainer(
                         devices=-1, ## in order to utilize all GPUs
                         accelerator="gpu", 
                         strategy="ddp", 
                         enable_checkpointing=True, 
                         default_root_dir="/opt/ml/checkpoints",
                         )
  1. model=estimator.fit()

此外,这里的链接很好地解释了编码过程。giu2021-Introduction-PyTorch-Lightning.pdf

票数 0
EN

Stack Overflow用户

发布于 2022-09-17 07:40:59

在使用PyTorch闪电和普通PyTorch脚本运行SageMaker方面没有太大的区别。

但是,在使用DDPPlugin运行分布式培训作业时,要注意的一点是在脚本开头正确设置NODE_RANK环境变量,因为PyTorch闪电对SageMaker环境变量一无所知,并且依赖于通用集群变量:

代码语言:javascript
运行
复制
os.environ["NODE_RANK"] = str(int(os.environ.get("CURRENT_HOST", "algo-1")[5:]) - 1)

或(更有力):

代码语言:javascript
运行
复制
rc = json.loads(os.environ.get("SM_RESOURCE_CONFIG", "{}"))
os.environ["NODE_RANK"] = str(rc["hosts"].index(rc["current_host"]))
票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/73672457

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档