专栏首页Python与算法之美使用pytorch-lightning漂亮地进行深度学习研究

使用pytorch-lightning漂亮地进行深度学习研究

pytorch-lightning 是建立在pytorch之上的高层次模型接口。

pytorch-lightning 之于 pytorch,就如同keras之于 tensorflow。

通过使用 pytorch-lightning,用户无需编写自定义训练循环就可以非常简洁地在CPU、单GPU、多GPU、乃至多TPU上训练模型。

无需考虑模型和数据在cpu,cuda之间的移动,并且可以通过回调函数实现CheckPoint参数保存,实现断点续训功能。

一般按照如下方式 安装和 引入 pytorch-lightning 库。

#安装
pip install pytorch-lightning
#引入
import pytorch_lightning as pl 

顾名思义,它可以帮助我们漂亮(pl)地进行深度学习研究。??

一,pytorch-lightning的设计哲学

pytorch-lightning 的核心设计哲学是将 深度学习项目中的 研究代码(定义模型) 和 工程代码 (训练模型) 相互分离。

用户只需专注于研究代码(pl.LightningModule)的实现,而工程代码借助训练工具类(pl.Trainer)统一实现。

更详细地说,深度学习项目代码可以分成如下4部分:

  • 研究代码 (Research code),用户继承LightningModule实现。
  • 工程代码 (Engineering code),用户无需关注通过调用Trainer实现。
  • 非必要代码 (Non-essential research code,logging, etc...),用户通过调用Callbacks实现。
  • 数据 (Data),用户通过torch.utils.data.DataLoader实现。

二,pytorch-lightning使用范例

下面我们使用minist图片分类问题为例,演示pytorch-lightning的最佳实践。

1,准备数据

import torch 
from torch import nn 

import torchvision 
from torchvision import transforms
transform = transforms.Compose([transforms.ToTensor()])

ds_train = torchvision.datasets.MNIST(root="./minist/",train=True,download=True,transform=transform)
ds_valid = torchvision.datasets.MNIST(root="./minist/",train=False,download=True,transform=transform)

dl_train =  torch.utils.data.DataLoader(ds_train, batch_size=128, shuffle=True, num_workers=4)
dl_valid =  torch.utils.data.DataLoader(ds_valid, batch_size=128, shuffle=False, num_workers=4)

print(len(ds_train))
print(len(ds_valid))
Done!
60000
10000

2,定义模型

import pytorch_lightning as pl 
import datetime

class Model(pl.LightningModule):
    
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList([
            nn.Conv2d(in_channels=1,out_channels=32,kernel_size = 3),
            nn.MaxPool2d(kernel_size = 2,stride = 2),
            nn.Conv2d(in_channels=32,out_channels=64,kernel_size = 5),
            nn.MaxPool2d(kernel_size = 2,stride = 2),
            nn.Dropout2d(p = 0.1),
            nn.AdaptiveMaxPool2d((1,1)),
            nn.Flatten(),
            nn.Linear(64,32),
            nn.ReLU(),
            nn.Linear(32,10)]
        )
        
    def forward(self,x):
        for layer in self.layers:
            x = layer(x)
        return x
    
    #定义loss,以及可选的各种metrics
    def training_step(self, batch, batch_idx):
        x, y = batch
        prediction = self(x)
        loss = nn.CrossEntropyLoss()(prediction,y)
        return loss
    
    #定义optimizer,以及可选的lr_scheduler
    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return {"optimizer":optimizer}
    
    def validation_step(self, batch, batch_idx):
        loss = self.training_step(batch,batch_idx)
        return {"val_loss":loss}
    
    def test_step(self, batch, batch_idx):
        loss = self.training_step(batch,batch_idx)
        return {"test_loss":loss}

3,训练模型

pl.seed_everything(1234)
model = Model() 


ckpt_callback = pl.callbacks.ModelCheckpoint(
    monitor='val_loss',
    save_top_k=1,
    mode='min'
)

# gpus=0 则使用cpu训练,gpus=1则使用1个gpu训练,gpus=2则使用2个gpu训练,gpus=-1则使用所有gpu训练,
# gpus=[0,1]则指定使用0号和1号gpu训练, gpus="0,1,2,3"则使用0,1,2,3号gpu训练
# tpus=1 则使用1个tpu训练

trainer = pl.Trainer(max_epochs=5,gpus=0,callbacks = [ckpt_callback]) 

#断点续训
#trainer = pl.Trainer(resume_from_checkpoint='./lightning_logs/version_31/checkpoints/epoch=02-val_loss=0.05.ckpt')

trainer.fit(model,dl_train,dl_valid)
Global seed set to 1234
GPU available: False, used: False
TPU available: None, using: 0 TPU cores

  | Name   | Type       | Params
--------------------------------------
0 | layers | ModuleList | 54.0 K
--------------------------------------
54.0 K    Trainable params
0         Non-trainable params
54.0 K    Total params
Epoch 4: 100% >>>>>>>>>>>>>>>>>>>>>>>>>>>> 158/158 [00:19<00:00, 8.08it/s, loss=0.138, v_num=34]

4,评估模型

result = trainer.test(model, test_dataloaders=dl_valid)
print(result)
--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'test_loss': tensor(0.0047)}
--------------------------------------------------------------------------------
[{'test_loss': 0.004680501762777567}]

5,使用模型

data,label = next(iter(dl_valid))
model.eval()
prediction = model(data)
print(prediction)
tensor([[ -5.1149,  -6.1142,   2.0591,  ...,   7.0609,  -5.4144,   0.5222],
        [ -2.2989,  -5.6076,   3.7343,  ...,  -1.8391,  -6.4941,  -3.4076],
        [  0.9215,   6.9357,  -1.9887,  ...,  -2.2996,  -0.8034,  -3.2993],
        ...,
        [ -4.5674,  -6.0223,  -0.9309,  ...,  -3.5468,   0.3367,   4.5473],
        [  4.3023,  -4.1629,  -1.2742,  ...,  -4.2527,  -2.3449,  -2.5585],
        [ -3.8913, -10.3790,  -1.7804,  ...,  -4.6757,  -0.7428,   1.0305]],
       grad_fn=<AddmmBackward>)

6,保存模型

最优模型默认保存在 trainer.checkpoint_callback.best_model_path 的目录下,可以直接加载。

print(trainer.checkpoint_callback.best_model_path)
print(trainer.checkpoint_callback.best_model_score)
/Users/liangyun/CodeFiles/PythonAiRoad/lightning_logs/version_34/checkpoints/epoch=04-val_loss=0.00.ckpt
tensor(0.0047)
model_clone = Model.load_from_checkpoint(trainer.checkpoint_callback.best_model_path)
trainer_clone = pl.Trainer(max_epochs=3) 
result = trainer_clone.test(model_clone,dl_valid)
print(result)
--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'test_loss': tensor(0.0047)}
--------------------------------------------------------------------------------
[{'test_loss': 0.004680501762777567}]

如果对本文内容理解上有需要进一步和作者交流的地方,欢迎在公众号"算法美食屋"下留言。作者时间和精力有限,会酌情予以回复。

本文分享自微信公众号 - Python与算法之美(Python_Ai_Road),作者:梁云1991

原文出处及转载信息见文内详细说明,如有侵权,请联系 yunjia_community@tencent.com 删除。

原始发表时间:2021-01-17

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

我来说两句

0 条评论
登录 后参与评论

相关文章

  • 像 Keras 一样优雅地使用 pytorch-lightning

    你好,我是云哥。本篇文章为大家介绍一个可以帮助大家优雅地进行深度学习研究的工具:pytorch-lightning。

    lyhue1991
  • 从PyTorch到PyTorch Lightning —简要介绍

    PyTorch非常易于使用,可以构建复杂的AI模型。但是一旦研究变得复杂,并且将诸如多GPU训练,16位精度和TPU训练之类的东西混在一起,用户很可能会引入错误...

    代码医生工作室
  • 谁是2020年最强Python库?年度Top10出炉,看你用过几个!

    2020年已经过去了,国外的一家专门提供Python服务的网站Troy Labs,盘点出了2020年发布的Python库Top10。

    统计学家
  • 点亮BERT:3个步骤进行NLP迁移学习

    BERT可能是最流行的NLP迁移学习方法。Huggingface的实现提供了许多不错的功能,并在漂亮的API之后抽象了细节。

    代码医生工作室
  • 千呼万唤始出来!PyTorch Lightning 1.0版发布,终于可以抛弃Keras

    现阶段人工智能的发展速度远远超过任何单一框架所能跟上的速度,深度学习的领域更是在不断发展的,主要体现在在复杂性和规模上。

    新智元
  • PyTorch Lightning团队推出了Flash Lightning,可让用户推断,微调和训练数据模型

    Flash 是基于PyTorch Lightning的快速原型任务,基线和微调可扩展深度学习模型的集合。它提供从基准实验到最新研究的无缝体验。它使用户能够构建模...

    代码医生工作室
  • PyTorch Lightning 1.0 正式发布:从0到1

    PyTorch Lightning是基于PyTorch的高级框架,在云上大规模部署很有优势。

    McGL
  • PyTorch Lightning:专门为机器学习研究者开发的PyTorch轻量 wrapper

    专门为机器学习研究者开发的PyTorch轻量包装器(wrapper)。缩放您的模型。写更少的模板代码。

    磐创AI
  • Pytorch Lightning vs PyTorch Ignite vs Fast.ai

    PyTorch-lightning是最近发布的库,它是PyTorch的一个类似Kera的ML库。它将核心训练和验证逻辑留给您,并自动完成其余的工作。(顺便说一句...

    磐创AI
  • 使用PyTorch Lightning自动训练你的深度神经网络

    对使用PyTorch Lightning的训练代码和原始的PyTorch代码进行了对比,展示了其简单,干净,灵活的优点,相信你会喜欢的。

    Sam Gor
  • 基于PyTorch的「Keras」:除了核心逻辑通通都封装

    PyTorch Lightning 地址:https://github.com/williamFalcon/pytorch-lightning

    机器之心
  • Github年度最强的10个Python库

    链接:https://tryolabs.com/blog/2020/12/21/top-10-python-libraries-of-2020/

    CDA数据分析师
  • 在PyTorch上用"Keras",分布式训练开箱即用,告别没完没了的Debug

    在开始一个新的机器学习项目时,难免要重新编写训练循环,加载模型,分布式训练……然后在Debug的深渊里看着时间哗哗流逝,而自己离项目核心还有十万八千里。

    量子位
  • 使用Pytorch训练解决神经网络的技巧(附代码)

    事实上,你的模型可能还停留在石器时代的水平。估计你还在用32位精度或*GASP(一般活动仿真语言)*训练,甚至可能只在单GPU上训练。如果市面上有99个加速指南...

    数据派THU
  • 9个技巧让你的PyTorch模型训练变得飞快!

    让我们面对现实吧,你的模型可能还停留在石器时代。我敢打赌你仍然使用32位精度或GASP甚至只在一个GPU上训练。

    深度学习技术前沿公众号博主
  • 原创 | 深度学习框架比较,我该选择哪一个?

    近年来,深度学习在很多机器学习领域都有着非常出色的表现,在图像识别、语音识别、自然语言处理、机器人、网络广告投放、医学自动诊断和金融等领域有着广泛应用。面对繁多...

    数据派THU
  • 让PyTorch更轻便,这款深度学习框架你值得拥有!在GitHub上斩获6.6k星

    首先,它把研究代码与工程代码相分离,还将PyTorch代码结构化,更加直观的展现数据操作过程。

    量子位
  • 一行代码安装,TPU也能运行PyTorch,修改少量代码即可快速移植

    对于PyTorch开发者来说,Google Colab上的TPU资源不能用,恐怕是最遗憾的事情了。Google的云计算资源眼看着不能用,转TensorFlow又...

    量子位
  • TorchMetrics:PyTorch的指标度量库

    非常简单实用的PyTorch模型的分布式指标度量库,配合PyTorch Lighting实用更加方便。

    AI算法与图像处理

扫码关注云+社区

领取腾讯云代金券