首页
学习
活动
专区
工具
TVP
发布
精选内容/技术社群/优惠产品,尽在小程序
立即前往

在 PyTorch 中使用梯度检查点在GPU 上训练更大的模型

作为机器学习从业者,我们经常会遇到这样的情况,想要训练一个比较大的模型,而 GPU 却因为内存不足而无法训练它。当我们在出于安全原因不允许在云计算的环境中工作时,这个问题经常会出现。在这样的环境中,我们无法足够快地扩展或切换到功能强大的硬件并训练模型。并且由于梯度下降算法的性质,通常较大的批次在大多数模型中会产生更好的结果,但在大多数情况下,由于内存限制,我们必须使用适应GPU显存的批次大小。

本文将介绍解梯度检查点(Gradient Checkpointing),这是一种可以让你以增加训练时间为代价在 GPU 中训练大模型的技术。 我们将在 PyTorch 中实现它并训练分类器模型。

梯度检查点

在反向传播算法中,梯度计算从损失函数开始,计算后更新模型权重。图中每一步计算的所有导数或梯度都会被存储,直到计算出最终的更新梯度。这样做会消耗大量 GPU 内存。梯度检查点通过在需要时重新计算这些值和丢弃在进一步计算中不需要的先前值来节省内存。

让我们用下面的虚拟图来解释。

上面是一个计算图,每个叶节点上的数字相加得到最终输出。假设这个图表示反向传播期间发生的计算,那么每个节点的值都会被存储,这使得执行求和所需的总内存为7,因为有7个节点。但是我们可以用更少的内存。假设我们将1和2相加,并在下一个节点中将它们的值存储为3,然后删除这两个值。我们可以对4和5做同样的操作,将9作为加法的结果存储。3和9也可以用同样的方式操作,存储结果后删除它们。通过执行这些操作,在计算过程中所需的内存从7减少到3。

在没有梯度检查点的情况下,使用PyTorch训练分类模型

我们将使用PyTorch构建一个分类模型,并在不使用梯度检查点的情况下训练它。记录模型的不同指标,如训练所用的时间、内存消耗、准确性等。

由于我们主要关注GPU的内存消耗,所以在训练时需要检测每批的内存消耗。这里使用nvidia-ml-py3库,该库使用nvidia-smi命令来获取内存信息。

pip install nvidia-ml-py3

为了简单起见,我们使用简单的狗和猫分类数据集的子集。

git clone https://github.com/laxmimerit/dog-cat-full-dataset.git

执行上述命令后会在dog-cat-full-dataset的文件夹中得到完整的数据集。

导入所需的包并初始化nvdia-smi

import torch

import torch.nn as nn

import torch.optim as optim

import numpy as np

from torchvision import datasets, models, transforms

import matplotlib.pyplot as plt

import time

import os

import cv2

import nvidia_smi

import copy

from PIL import Image

from torch.utils.data import Dataset,DataLoader

import torch.utils.checkpoint as checkpoint

from tqdm import tqdm

import shutil

from torch.utils.checkpoint import checkpoint_sequential

device="cuda" if torch.cuda.is_available() else "cpu"

%matplotlib inline

import random

nvidia_smi.nvmlInit()

导入训练和测试模型所需的所有包。我们还初始化nvidia-smi。

定义数据集和数据加载器

#Define the dataset and the dataloader.

train_dataset=datasets.ImageFolder(root="/content/dog-cat-full-dataset/data/train",

                           transform=transforms.Compose([

                               transforms.RandomRotation(30),

                               transforms.RandomHorizontalFlip(),

                               transforms.RandomResizedCrop(224, scale=(0.96, 1.0), ratio=(0.95, 1.05)),

                               transforms.ToTensor(),

                               transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

                          ]))

val_dataset=datasets.ImageFolder(root="/content/dog-cat-full-dataset/data/test",

                           transform=transforms.Compose([

                               transforms.Resize([224, 224]),

                               transforms.ToTensor(),

                               transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),

                          ]))

train_dataloader=DataLoader(train_dataset,

                           batch_size=64,

                           shuffle=True,

                           num_workers=2)

val_dataloader=DataLoader(val_dataset,

                           batch_size=64,

                           shuffle=True,

                           num_workers=2)

这里我们用torchvision数据集的ImageFolder类定义数据集。还在数据集上定义了某些转换,如RandomRotation, RandomHorizontalFlip等。最后对图片进行归一化,并且设置batch_size=64

定义训练和测试函数

def train_model(model,loss_func,optimizer,train_dataloader,val_dataloader,epochs=10):

   model.train()

   #Training loop.

   for epoch in range(epochs):

     model.train()

     for images, target in tqdm(train_dataloader):

         images, target = images.to(device), target.to(device)

         images.requires_grad=True

         optimizer.zero_grad()

         output = model(images)

         loss = loss_func(output, target)

         loss.backward()

         optimizer.step()

     if os.path.exists('grad_checkpoints/') is False:

       os.mkdir('grad_checkpoints')

     torch.save(model.state_dict(), 'grad_checkpoints/epoch_'+str(epoch)+'.pt')

     #Test the model on validation data.

     train_acc,train_loss=test_model(model,train_dataloader)

     val_acc,val_loss=test_model(model,val_dataloader)

     #Check memory usage.

     handle = nvidia_smi.nvmlDeviceGetHandleByIndex(0)

     info = nvidia_smi.nvmlDeviceGetMemoryInfo(handle)

     memory_used=info.used

     memory_used=(memory_used/1024)/1024

     print(f"Epoch={epoch} Train Accuracy={train_acc} Train loss={train_loss} Validation accuracy={val_acc} Validation loss={val_loss} Memory used={memory_used} MB")

def test_model(model,val_dataloader):

 model.eval()

 test_loss = 0

 correct = 0

 with torch.no_grad():

     for images, target in val_dataloader:

         images, target = images.to(device), target.to(device)

         output = model(images)

         test_loss += loss_func(output, target).data.item()

         _, predicted = torch.max(output, 1)

         correct += (predicted == target).sum().item()

 test_loss /= len(val_dataloader.dataset)

 return int(correct / len(val_dataloader.dataset) * 100),test_loss

上面创建了一个简单的训练和测试循环来训练模型。最后还通过调用nvidia-smi计算内存使用。

训练

torch.manual_seed(0)

#Learning rate.

lr = 0.003

#Defining the VGG16 sequential model.

vgg16=models.vgg16()

vgg_layers_list=list(vgg16.children())[:-1]

vgg_layers_list.append(nn.Flatten())

vgg_layers_list.append(nn.Linear(25088,4096))

vgg_layers_list.append(nn.ReLU())

vgg_layers_list.append(nn.Dropout(0.5,inplace=False))

vgg_layers_list.append(nn.Linear(4096,4096))

vgg_layers_list.append(nn.ReLU())

vgg_layers_list.append(nn.Dropout(0.5,inplace=False))

vgg_layers_list.append(nn.Linear(4096,2))

model = nn.Sequential(*vgg_layers_list)

model=model.to(device)

#Num of epochs to train

num_epochs=10

#Loss

loss_func = nn.CrossEntropyLoss()

# Optimizer

# optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=1e-5)

optimizer = optim.SGD(params=model.parameters(), lr=0.001, momentum=0.9)

#Training the model.

model = train_model(model, loss_func, optimizer,

                      train_dataloader,val_dataloader,num_epochs)

我们使用VGG16模型进行分类。下面是模型的训练日志。

可以从上面的日志中看到,在没有检查点的情况下,训练64个批大小的模型大约需要5分钟,占用内存为14222.125 mb。

使用带有梯度检查点的PyTorch训练分类模型

为了用梯度检查点训练模型,只需要编辑train_model函数。

def train_with_grad_checkpointing(model,loss_func,optimizer,train_dataloader,val_dataloader,epochs=10):

   #Training loop.

   for epoch in range(epochs):

     model.train()

     for images, target in tqdm(train_dataloader):

         images, target = images.to(device), target.to(device)

         images.requires_grad=True

         optimizer.zero_grad()

         #Applying gradient checkpointing

         segments = 2

         # get the modules in the model. These modules should be in the order

         # the model should be executed

         modules = [module for k, module in model._modules.items()]

         # now call the checkpoint API and get the output

         output = checkpoint_sequential(modules, segments, images)

         loss = loss_func(output, target)

         loss.backward()

         optimizer.step()

     if os.path.exists('checkpoints/') is False:

       os.mkdir('checkpoints')

     torch.save(model.state_dict(), 'checkpoints/epoch_'+str(epoch)+'.pt')

     #Test the model on validation data.

     train_acc,train_loss=test_model(model,train_dataloader)

     val_acc,val_loss=test_model(model,val_dataloader)

     #Check memory.

     handle = nvidia_smi.nvmlDeviceGetHandleByIndex(0)

     info = nvidia_smi.nvmlDeviceGetMemoryInfo(handle)

     memory_used=info.used

     memory_used=(memory_used/1024)/1024

     print(f"Epoch={epoch} Train Accuracy={train_acc} Train loss={train_loss} Validation accuracy={val_acc} Validation loss={val_loss} Memory used={memory_used} MB")

def test_model(model,val_dataloader):

 model.eval()

 test_loss = 0

 correct = 0

 with torch.no_grad():

     for images, target in val_dataloader:

         images, target = images.to(device), target.to(device)

         output = model(images)

         test_loss += loss_func(output, target).data.item()

         _, predicted = torch.max(output, 1)

         correct += (predicted == target).sum().item()

 test_loss /= len(val_dataloader.dataset)

 return int(correct / len(val_dataloader.dataset) * 100),test_lossdef test_model(model,val_dataloader)

我们将函数名修改为train_with_grad_checkpointing。也就是不通过模型(图)运行训练,而是使用checkpoint_sequential函数进行训练,该函数有三个输入:modules, segments, input。modules是神经网络层的列表,按它们执行的顺序排列。segments是在序列中创建的段的个数,使用梯度检查点进行训练以段为单位将输出用于重新计算反向传播期间的梯度。本文设置segments=2。input是模型的输入,在我们的例子中是图像。这里的checkpoint_sequential仅用于顺序模型,对于其他一些模型将产生错误。

使用梯度检查点进行训练,如果你在notebook上执行所有的代码。建议重新启动,因为nvidia-smi可能会获得以前代码中的内存消耗。

torch.manual_seed(0)

lr = 0.003

# model = models.resnet50()

# model=model.to(device)

vgg16=models.vgg16()

vgg_layers_list=list(vgg16.children())[:-1]

vgg_layers_list.append(nn.Flatten())

vgg_layers_list.append(nn.Linear(25088,4096))

vgg_layers_list.append(nn.ReLU())

vgg_layers_list.append(nn.Dropout(0.5,inplace=False))

vgg_layers_list.append(nn.Linear(4096,4096))

vgg_layers_list.append(nn.ReLU())

vgg_layers_list.append(nn.Dropout(0.5,inplace=False))

vgg_layers_list.append(nn.Linear(4096,2))

model = nn.Sequential(*vgg_layers_list)

model=model.to(device)

num_epochs=10

#Loss

loss_func = nn.CrossEntropyLoss()

# Optimizer

# optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=1e-5)

optimizer = optim.SGD(params=model.parameters(), lr=0.001, momentum=0.9)

#Fitting the model.

model = train_with_grad_checkpointing(model, loss_func, optimizer,

                      train_dataloader,val_dataloader,num_epochs)

输出如下:

从上面的输出可以看到,每个epoch的训练大约需要6分45秒。但只需要10550.125 mb的内存,也就是说我们用时间换取了空间,并且这两种情况下的精度都是79,因为在梯度检查点的情况下模型的精度没有损失。

总结

梯度检查点是一个非常好的技术,它可以帮助在小显存的情况下完整模型的训练。经过我们的测试,一般情况下梯度检查点会将训练时间延长20%左右,但是时间长点总比不能用要好,对吧。

本文的源代码:

https://medium.com/geekculture/training-larger-models-over-your-average-gpu-with-gradient-checkpointing-in-pytorch-571b4b5c2068

作者:Vikas Kumar Ojha

  • 发表于:
  • 原文链接https://kuaibao.qq.com/s/20230201A01XVP00?refer=cp_1026
  • 腾讯「腾讯云开发者社区」是腾讯内容开放平台帐号(企鹅号)传播渠道之一,根据《腾讯内容开放平台服务协议》转载发布内容。
  • 如有侵权,请联系 cloudcommunity@tencent.com 删除。

扫码

添加站长 进交流群

领取专属 10元无门槛券

私享最新 技术干货

扫码加入开发者社群
领券