前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >更简单实用的pytorch——pytorch_lighting

更简单实用的pytorch——pytorch_lighting

原创
作者头像
勤劳小王
发布2024-07-12 09:07:05
2310
发布2024-07-12 09:07:05

更简单实用的pytorch——pytorch_lighting

介绍

PyTorch Lightning的优点

PyTorch Lightning 是一个“batteries included”的深度学习框架,适合需要最大灵活性同时大规模增强性能的专业人工智能研究人员和机器学习工程师。

Lightning 组织 PyTorch 代码以删除样板文件并释放可扩展性。

具体而言,Lightning 把深度学习中网络定义、前向传播、优化器、训练方式、训练输出都进行了高级封装,可以使得代码更加简洁易写,同时也能根据用户需求进行灵活调整。

如何安装?

pip install lightning

conda install lightning -c conda-forge

PyTorch LightningPyTorch 本身不会直接冲突,因为 PyTorch Lightning 是建立在 PyTorch 基础之上的高级封装,旨在简化深度学习模型的训练过程。然而,如果两者的版本不兼容,或者在同一个环境中安装了相互冲突的依赖包,可能会出现问题。为了避免这些问题,应该确保按照官方文档推荐的版本兼容性矩阵来安装相应版本的 PyTorchPyTorch Lightning

PyTorch LightningPyTorch版本对应关系

使用教程

定义LightningModule

LightningModule 使您的 PyTorch nn.Module 能够在训练步骤(还有可选的验证步骤和测试步骤)内以复杂的方式一起运行。

代码语言:python
代码运行次数:0
复制
import os
from torch import optim, nn, utils, Tensor
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
import lightning as L

# define any number of nn.Modules (or use your current ones)
encoder = nn.Sequential(nn.Linear(28 * 28, 64), nn.ReLU(), nn.Linear(64, 3))
decoder = nn.Sequential(nn.Linear(3, 64), nn.ReLU(), nn.Linear(64, 28 * 28))


# define the LightningModule
class LitAutoEncoder(L.LightningModule):
    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def training_step(self, batch, batch_idx):
        # training_step defines the train loop.
        # it is independent of forward
        x, _ = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = nn.functional.mse_loss(x_hat, x)
        # Logging to TensorBoard (if installed) by default
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        optimizer = optim.Adam(self.parameters(), lr=1e-3)
        return optimizer


# init the autoencoder
autoencoder = LitAutoEncoder(encoder, decoder)```

定义数据集

Lightning 支持任何可迭代( DataLoader 、 numpy 等)用于训练/验证/测试/预测分割。

代码语言:python
代码运行次数:0
复制
# setup data
dataset = MNIST(os.getcwd(), download=True, transform=ToTensor())
train_loader = utils.data.DataLoader(dataset)```

训练模型

Lightning Trainer 将任何 LightningModule 与任何数据集“混合”,并抽象出扩展所需的所有工程复杂性。

代码语言:python
代码运行次数:0
复制
# train the model (hint: here are some helpful Trainer arguments for rapid idea iteration)
trainer = L.Trainer(limit_train_batches=100, max_epochs=1)
trainer.fit(model=autoencoder, train_dataloaders=train_loader)```

The Lightning Trainer automates 40+ tricks including:

Lightning Trainer 可自动执行 40 多个技巧,包括:

使用模型

训练完模型后,您可以导出到 onnx、torchscript 并将其投入生产,或者只是加载权重并运行预测。

代码语言:python
代码运行次数:0
复制
# load checkpoint
checkpoint = "./lightning_logs/version_0/checkpoints/epoch=0-step=100.ckpt"
autoencoder = LitAutoEncoder.load_from_checkpoint(checkpoint, encoder=encoder, decoder=decoder)

# choose your trained nn.Module
encoder = autoencoder.encoder
encoder.eval()

# embed 4 fake images!
fake_image_batch = torch.rand(4, 28 * 28, device=autoencoder.device)
embeddings = encoder(fake_image_batch)
print("⚡" * 20, "\nPredictions (4 image embeddings):\n", embeddings, "\n", "⚡" * 20)```

可视化训练

如果您安装了tensorboard ,则可以使用它来可视化实验。

在命令行上运行此命令并打开浏览器访问 http://localhost:6006/

代码语言:python
代码运行次数:0
复制
tensorboard --logdir .

Supercharge training

使用 Trainer 参数启用高级训练功能。这些是最先进的技术,可以自动集成到您的训练循环中,而无需更改您的代码。

代码语言:python
代码运行次数:0
复制
# train on 4 GPUs
trainer = L.Trainer(
    devices=4,
    accelerator="gpu",
 )

# train 1TB+ parameter models with Deepspeed/fsdp
trainer = L.Trainer(
    devices=4,
    accelerator="gpu",
    strategy="deepspeed_stage_2",
    precision=16
 )

# 20+ helpful flags for rapid idea iteration
trainer = L.Trainer(
    max_epochs=10,
    min_epochs=5,
    overfit_batches=1
 )

# access the latest state of the art techniques
trainer = L.Trainer(callbacks=[StochasticWeightAveraging(...)])

最大限度地提高灵活性

Lightning 的核心指导原则是始终提供最大的灵活性,而不隐藏任何 PyTorch。

根据项目的复杂性,Lightning 提供 5 种额外的灵活性。

自定义训练循环

使用 LightningModule 中提供的 20 多种方法(Hook)中的任何一个,在训练循环中的任何位置注入自定义代码。

代码语言:python
代码运行次数:0
复制
class LitAutoEncoder(L.LightningModule):
    def backward(self, loss):
        loss.backward()
扩展训练器

讲解视频

如果您有多行具有类似功能的代码,则可以使用回调将它们轻松分组在一起,并同时打开或关闭所有这些行。

代码语言:python
代码运行次数:0
复制
trainer = Trainer(callbacks=[AWSCheckpoints()])
使用raw PyTorch loop

对于某些类型的前沿研究工作,Lightning 为专家提供了以各种方式完全控制优化或训练循环的能力。

拥有你的循环(高级) — PyTorch Lightning 2.3.3 文档 --- Own your loop (advanced) — PyTorch Lightning 2.3.3 documentation

拓展阅读

[第 2 级:添加验证和测试集 — PyTorch Lightning 2.3.3 文档 --- Level 2: Add a validation and test set — PyTorch Lightning 2.3.3 documentation](https://lightning.ai/docs/pytorch/stable/levels/basic_level_2.html)

PyTorch Lightning 教程 — PyTorch Lightning 2.3.3 文档 --- PyTorch Lightning Tutorials — PyTorch Lightning 2.3.3 documentation

将模型部署到生产中 — PyTorch Lightning 2.3.3 文档 --- Deploy models into production — PyTorch Lightning 2.3.3 documentation

原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。

如有侵权,请联系 cloudcommunity@tencent.com 删除。

原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。

如有侵权,请联系 cloudcommunity@tencent.com 删除。

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 更简单实用的pytorch——pytorch_lighting
    • 介绍
      • PyTorch Lightning的优点
      • 如何安装?
    • 使用教程
      • 定义LightningModule
      • 定义数据集
      • 训练模型
      • 使用模型
      • 可视化训练
      • Supercharge training
      • 最大限度地提高灵活性
    • 拓展阅读
    领券
    问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档