前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >使用pytorch-lightning漂亮地进行深度学习研究

使用pytorch-lightning漂亮地进行深度学习研究

作者头像
lyhue1991
发布2021-01-26 22:05:10
2.9K0
发布2021-01-26 22:05:10
举报

pytorch-lightning 是建立在pytorch之上的高层次模型接口。

pytorch-lightning 之于 pytorch,就如同keras之于 tensorflow。

通过使用 pytorch-lightning,用户无需编写自定义训练循环就可以非常简洁地在CPU、单GPU、多GPU、乃至多TPU上训练模型。

无需考虑模型和数据在cpu,cuda之间的移动,并且可以通过回调函数实现CheckPoint参数保存,实现断点续训功能。

一般按照如下方式 安装和 引入 pytorch-lightning 库。

代码语言:javascript
复制
#安装
pip install pytorch-lightning
代码语言:javascript
复制
#引入
import pytorch_lightning as pl 

顾名思义,它可以帮助我们漂亮(pl)地进行深度学习研究。??

一,pytorch-lightning的设计哲学

pytorch-lightning 的核心设计哲学是将 深度学习项目中的 研究代码(定义模型) 和 工程代码 (训练模型) 相互分离。

用户只需专注于研究代码(pl.LightningModule)的实现,而工程代码借助训练工具类(pl.Trainer)统一实现。

更详细地说,深度学习项目代码可以分成如下4部分:

  • 研究代码 (Research code),用户继承LightningModule实现。
  • 工程代码 (Engineering code),用户无需关注通过调用Trainer实现。
  • 非必要代码 (Non-essential research code,logging, etc...),用户通过调用Callbacks实现。
  • 数据 (Data),用户通过torch.utils.data.DataLoader实现。

二,pytorch-lightning使用范例

下面我们使用minist图片分类问题为例,演示pytorch-lightning的最佳实践。

1,准备数据

代码语言:javascript
复制
import torch 
from torch import nn 

import torchvision 
from torchvision import transforms
代码语言:javascript
复制
transform = transforms.Compose([transforms.ToTensor()])

ds_train = torchvision.datasets.MNIST(root="./minist/",train=True,download=True,transform=transform)
ds_valid = torchvision.datasets.MNIST(root="./minist/",train=False,download=True,transform=transform)

dl_train =  torch.utils.data.DataLoader(ds_train, batch_size=128, shuffle=True, num_workers=4)
dl_valid =  torch.utils.data.DataLoader(ds_valid, batch_size=128, shuffle=False, num_workers=4)

print(len(ds_train))
print(len(ds_valid))
代码语言:javascript
复制
Done!
60000
10000

2,定义模型

代码语言:javascript
复制
import pytorch_lightning as pl 
import datetime

class Model(pl.LightningModule):
    
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList([
            nn.Conv2d(in_channels=1,out_channels=32,kernel_size = 3),
            nn.MaxPool2d(kernel_size = 2,stride = 2),
            nn.Conv2d(in_channels=32,out_channels=64,kernel_size = 5),
            nn.MaxPool2d(kernel_size = 2,stride = 2),
            nn.Dropout2d(p = 0.1),
            nn.AdaptiveMaxPool2d((1,1)),
            nn.Flatten(),
            nn.Linear(64,32),
            nn.ReLU(),
            nn.Linear(32,10)]
        )
        
    def forward(self,x):
        for layer in self.layers:
            x = layer(x)
        return x
    
    #定义loss,以及可选的各种metrics
    def training_step(self, batch, batch_idx):
        x, y = batch
        prediction = self(x)
        loss = nn.CrossEntropyLoss()(prediction,y)
        return loss
    
    #定义optimizer,以及可选的lr_scheduler
    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return {"optimizer":optimizer}
    
    def validation_step(self, batch, batch_idx):
        loss = self.training_step(batch,batch_idx)
        return {"val_loss":loss}
    
    def test_step(self, batch, batch_idx):
        loss = self.training_step(batch,batch_idx)
        return {"test_loss":loss}

3,训练模型

代码语言:javascript
复制
pl.seed_everything(1234)
model = Model() 


ckpt_callback = pl.callbacks.ModelCheckpoint(
    monitor='val_loss',
    save_top_k=1,
    mode='min'
)

# gpus=0 则使用cpu训练,gpus=1则使用1个gpu训练,gpus=2则使用2个gpu训练,gpus=-1则使用所有gpu训练,
# gpus=[0,1]则指定使用0号和1号gpu训练, gpus="0,1,2,3"则使用0,1,2,3号gpu训练
# tpus=1 则使用1个tpu训练

trainer = pl.Trainer(max_epochs=5,gpus=0,callbacks = [ckpt_callback]) 

#断点续训
#trainer = pl.Trainer(resume_from_checkpoint='./lightning_logs/version_31/checkpoints/epoch=02-val_loss=0.05.ckpt')

trainer.fit(model,dl_train,dl_valid)
代码语言:javascript
复制
Global seed set to 1234
GPU available: False, used: False
TPU available: None, using: 0 TPU cores

  | Name   | Type       | Params
--------------------------------------
0 | layers | ModuleList | 54.0 K
--------------------------------------
54.0 K    Trainable params
0         Non-trainable params
54.0 K    Total params
Epoch 4: 100% >>>>>>>>>>>>>>>>>>>>>>>>>>>> 158/158 [00:19<00:00, 8.08it/s, loss=0.138, v_num=34]

4,评估模型

代码语言:javascript
复制
result = trainer.test(model, test_dataloaders=dl_valid)
print(result)
代码语言:javascript
复制
--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'test_loss': tensor(0.0047)}
--------------------------------------------------------------------------------
[{'test_loss': 0.004680501762777567}]

5,使用模型

代码语言:javascript
复制
data,label = next(iter(dl_valid))
model.eval()
prediction = model(data)
print(prediction)
代码语言:javascript
复制
tensor([[ -5.1149,  -6.1142,   2.0591,  ...,   7.0609,  -5.4144,   0.5222],
        [ -2.2989,  -5.6076,   3.7343,  ...,  -1.8391,  -6.4941,  -3.4076],
        [  0.9215,   6.9357,  -1.9887,  ...,  -2.2996,  -0.8034,  -3.2993],
        ...,
        [ -4.5674,  -6.0223,  -0.9309,  ...,  -3.5468,   0.3367,   4.5473],
        [  4.3023,  -4.1629,  -1.2742,  ...,  -4.2527,  -2.3449,  -2.5585],
        [ -3.8913, -10.3790,  -1.7804,  ...,  -4.6757,  -0.7428,   1.0305]],
       grad_fn=<AddmmBackward>)

6,保存模型

最优模型默认保存在 trainer.checkpoint_callback.best_model_path 的目录下,可以直接加载。

代码语言:javascript
复制
print(trainer.checkpoint_callback.best_model_path)
print(trainer.checkpoint_callback.best_model_score)
代码语言:javascript
复制
/Users/liangyun/CodeFiles/PythonAiRoad/lightning_logs/version_34/checkpoints/epoch=04-val_loss=0.00.ckpt
tensor(0.0047)
代码语言:javascript
复制
model_clone = Model.load_from_checkpoint(trainer.checkpoint_callback.best_model_path)
trainer_clone = pl.Trainer(max_epochs=3) 
result = trainer_clone.test(model_clone,dl_valid)
print(result)
代码语言:javascript
复制
--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'test_loss': tensor(0.0047)}
--------------------------------------------------------------------------------
[{'test_loss': 0.004680501762777567}]

如果对本文内容理解上有需要进一步和作者交流的地方,欢迎在公众号"算法美食屋"下留言。作者时间和精力有限,会酌情予以回复。

本文参与 腾讯云自媒体分享计划,分享自微信公众号。
原始发表:2021-01-17,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 算法美食屋 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 一,pytorch-lightning的设计哲学
  • 二,pytorch-lightning使用范例
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档