首页
学习
活动
专区
圈层
工具
发布
精选内容/技术社群/优惠产品,尽在小程序
立即前往

PyTorch Lightning,一个轻量级的PyTorch框架!

说起PyTorch,相信不少小伙伴都用过。它确实是个不错的深度学习框架,但有时候用起来还是觉得麻烦。好在有了PyTorch Lightning这个“神器”,它就像是给PyTorch装上了涡轮增压器,让咱们的开发工作事半功倍!

1.

啥是PyTorch Lightning?

PyTorch Lightning 是基于PyTorch的一个轻量级框架,它的目标就是让你的深度学习代码更整洁、更规范、更高效。说白了,就是帮你省事儿。

它主要解决了啥问题呢?比如说,你在用PyTorch时可能要写一大堆琐碎的代码来处理训练循环、验证、测试这些流程。用了Lightning,这些烦人的活儿它都帮你包圆了。你只需要专注于模型的核心逻辑就行,剩下的交给Lightning去操心。

2.

Lightning的主要特性

1. 代码更简洁了

用Lightning写代码,感觉就像是在写诗一样优雅。它把模型定义、训练逻辑啥的都整合到一个LightningModule类里了。看看这个例子:

import pytorch_lightning as pl

class MyAwesomeModel(pl.LightningModule):

  def __init__(self):

      super().__init__()

      self.layer = nn.Linear(28 * 28, 10)

  def forward(self, x):

      return self.layer(x.view(x.size(0), -1))

  def training_step(self, batch, batch_idx):

      x, y = batch

      y_hat = self(x)

      loss = F.cross_entropy(y_hat, y)

      return loss

  def configure_optimizers(self):

      return torch.optim.Adam(self.parameters(), lr=0.02)

瞧瞧,是不是感觉清爽多了?

2. 自动化训练流程

Lightning帮你搞定了训练、验证、测试的整个流程。你只要定义好training_step、validation_step这些方法,剩下的它都会自动处理。省心!

trainer = pl.Trainer(max_epochs=10, gpus=1)

model = MyAwesomeModel()

trainer.fit(model, train_dataloader, val_dataloader)

就这么几行代码,训练就开始了,美滋滋~

3. 多GPU训练?小菜一碟!

在PyTorch中搞多GPU训练有点麻烦,但Lightning让这事儿变得超简单。只要改一下Trainer的参数就行:

trainer = pl.Trainer(max_epochs=10, gpus=4, strategy=“ddp”)

这样就能在4个GPU上愉快地分布式训练了,不用自己写繁琐的并行代码。

3.

Lightning的实际应用

说了这么多,来点实际的。假设咱们要训练一个简单的图像分类模型,用Lightning怎么搞?

import torch

import torch.nn as nn

import torch.nn.functional as F

import pytorch_lightning as pl

from torchvision import transforms

from torchvision.datasets import MNIST

class MNISTClassifier(pl.LightningModule):

  def __init__(self):

      super().__init__()

      self.conv1 = nn.Conv2d(1, 32, 3, 1)

      self.conv2 = nn.Conv2d(32, 64, 3, 1)

      self.fc1 = nn.Linear(9216, 128)

      self.fc2 = nn.Linear(128, 10)

  def forward(self, x):

      x = F.relu(self.conv1(x))

      x = F.relu(self.conv2(x))

      x = F.max_pool2d(x, 2)

      x = torch.flatten(x, 1)

      x = F.relu(self.fc1(x))

      x = self.fc2(x)

      return x

  def training_step(self, batch, batch_idx):

      x, y = batch

      y_hat = self(x)

      loss = F.cross_entropy(y_hat, y)

      self.log('train_loss', loss)

      return loss

  def configure_optimizers(self):

      return torch.optim.Adam(self.parameters(), lr=0.001)

# 数据加载

transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])

mnist_train = MNIST('data', train=True, download=True, transform=transform)

train_loader = torch.utils.data.DataLoader(mnist_train, batch_size=32)

# 训练

model = MNISTClassifier()

trainer = pl.Trainer(max_epochs=5, gpus=1)

trainer.fit(model, train_loader)

就这么简单,一个MNIST分类器就训练好了。Lightning帮我们处理了数据加载、训练循环、损失计算等各种琐事,我们只需要关注模型结构和核心逻辑。

温馨提示:在实际项目中,别忘了加上验证集和测试集哦!这里为了简化例子,就只用了训练集。

4.

Lightning的一些小技巧

自动日志记录:Lightning默认会记录你在self.log()中传入的所有指标。想看训练过程中的各种指标变化,用TensorBoard就行了。

断点续训:训练到一半想停下来喝杯咖啡?没问题!Lightning支持断点续训练,只要在Trainer中设置resume_from_checkpoint参数就行。

超参数优化:Lightning还能和一些超参数优化库(比如Optuna)无缝集合,帮你自动找最佳超参数。

模型checkpointing:Lightning会自动保存训练过程中的最佳模型,你都不用操心。

5.

小结

PyTorch Lightning就像是给你的PyTorch代码装上了一对翅膀,让你的深度学习之旅飞得更高、更远。它简化了代码结构,自动化了训练流程,还提供了很多高级功能。用上它,你写代码的速度能提高不少,debugging也会轻松很多。

不过话说回来,Lightning再好用,也别忘了理解底层的PyTorch原理。毕竟工具只是辅助,真正的功力还是得靠平时的积累和思考。

好了,今天的分享就到这里。希望这篇文章能给你的学习带来一些启发。记住,不断练习才是提高编程水平的王道。加油!

  • 发表于:
  • 原文链接https://page.om.qq.com/page/OJcO2rt0mui536lSYk9YL-ww0
  • 腾讯「腾讯云开发者社区」是腾讯内容开放平台帐号(企鹅号)传播渠道之一,根据《腾讯内容开放平台服务协议》转载发布内容。
  • 如有侵权,请联系 cloudcommunity@tencent.com 删除。

扫码

添加站长 进交流群

领取专属 10元无门槛券

私享最新 技术干货

扫码加入开发者社群
领券