ViT

最近更新时间:2024-08-29 10:19:52

我的收藏
本指导适用于在 TencentOS Server 3上使用 DataParallel(DP)训练框架运行 ViT (Vision Transformer) 模型,以 Docker 方式启动。

前置环境条件

请确保已经按照 ResNet 文档内进行操作,运行模型之前的所有步骤已经完成,并已经准备好了运行模型的所有必要环境。

运行模型

DP 运行的文件为 dataparallel.py,其中有以下几处需要修改。
代码第118行,可以根据自己的 GPU 数量更改训练时使用的 GPU 数。例如我们这里拥有8块L40 GPU,我们将所有 GPU 用于训练模型。
#原有代码
gpus = [0, 1, 2, 3]

#更改为
gpus = [0, 1, 2, 3, 4, 5, 6, 7]
代码第375行会报错,将 view() 方法更改为 reshape() 方法即可。
#原有代码
correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)

#更改为
correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
以上修改完之后,即可运行模型。这里我们测试的是 vit_base_16 模型通过 ImageNet-1K 数据集训练。
python dataparallel.py --data datasets/imagenet --arch vit_b_16 --epochs 300 --batch-size 2048 --learning-rate 0.001
由于当前情况下 L40 batch_size 为原论文中4096时会显存溢出,所以这里采用2048的 batch_size 大小。
此时会使用 vit_base_16 模型架构,从0开始训练300个 epochs,数据集在 datasets/imagenet 文件夹下,batchsize 为2048,学习率lr为0.001,优化方法为SGD 使用 momentum=0.9,权重衰退 weight decay=1e-4。部分参数设置来源于ViT原论文:ViT,从而尽可能的复现出论文里的结果。
说明:
如需要更改 momentum 参数和权重衰减参数,只需加上对应参数--momentum--weight-decay并设置想要的值即可。
我们发现 batchsize 为2048时训练结果不一定特别理想,原文 batchsize 为4096,可以根据需要对相应参数进行修改以获得更好结果。
此时会开始训练模型,由于我们设置 DP 汇总梯度的 GPU 为 gpus[0],(代码第138行 model = nn.DataParallel(model, device_ids=gpus, output_device=gpus[0])),所以 GPU0 的显存使用量会明显高于其他7块 GPU。
运行一个 epoch 的时间大约为半个小时,300个 epoch 大约需要六天训练完成。训练花费的时间可以在 dataparallel.csv 文件里找到。运行模型过程中会保存当前 epoch 训练结束之后的模型权重以及到当前 epoch 为止在验证集上最好性能的 epoch 的模型权重。
训练时命令行会出现训练时每一个 iteration 和测试时每一个 iteration 的性能以及每一个 epoch 测试完成之后的总的性能(例如第30个 epoch 的参考如下):
...
Epoch: [30][480/626] Time 5.109 ( 3.039) Data 4.707 ( 2.446) Loss 4.4270e+00 (4.3804e+00) Acc@1 16.99 ( 17.58) Acc@5 34.42 ( 36.74)
Epoch: [30][490/626] Time 5.261 ( 3.041) Data 4.859 ( 2.446) Loss 4.3596e+00 (4.3801e+00) Acc@1 18.26 ( 17.58) Acc@5 38.23 ( 36.74)
Epoch: [30][500/626] Time 5.156 ( 3.041) Data 4.753 ( 2.446) Loss 4.4136e+00 (4.3801e+00) Acc@1 17.77 ( 17.58) Acc@5 36.23 ( 36.73)
Epoch: [30][510/626] Time 4.843 ( 3.040) Data 4.441 ( 2.445) Loss 4.3692e+00 (4.3798e+00) Acc@1 19.19 ( 17.59) Acc@5 36.91 ( 36.74)
Epoch: [30][520/626] Time 5.095 ( 3.040) Data 4.664 ( 2.444) Loss 4.3212e+00 (4.3796e+00) Acc@1 18.16 ( 17.59) Acc@5 38.09 ( 36.74)
Epoch: [30][530/626] Time 5.023 ( 3.040) Data 4.621 ( 2.443) Loss 4.3655e+00 (4.3794e+00) Acc@1 18.07 ( 17.60) Acc@5 37.30 ( 36.76)
Epoch: [30][540/626] Time 5.103 ( 3.040) Data 4.700 ( 2.442) Loss 4.4637e+00 (4.3803e+00) Acc@1 17.77 ( 17.59) Acc@5 35.60 ( 36.74)
Epoch: [30][550/626] Time 5.359 ( 3.041) Data 4.955 ( 2.443) Loss 4.3498e+00 (4.3800e+00) Acc@1 18.02 ( 17.59) Acc@5 36.28 ( 36.75)
Epoch: [30][560/626] Time 5.073 ( 3.041) Data 4.670 ( 2.442) Loss 4.4005e+00 (4.3800e+00) Acc@1 17.58 ( 17.59) Acc@5 36.77 ( 36.75)
Epoch: [30][570/626] Time 5.122 ( 3.042) Data 4.717 ( 2.442) Loss 4.4096e+00 (4.3798e+00) Acc@1 17.82 ( 17.59) Acc@5 35.94 ( 36.75)
Epoch: [30][580/626] Time 5.111 ( 3.042) Data 4.707 ( 2.442) Loss 4.3924e+00 (4.3795e+00) Acc@1 17.43 ( 17.59) Acc@5 35.55 ( 36.75)
Epoch: [30][590/626] Time 5.182 ( 3.043) Data 4.780 ( 2.442) Loss 4.3718e+00 (4.3795e+00) Acc@1 17.77 ( 17.59) Acc@5 36.62 ( 36.75)
Epoch: [30][600/626] Time 5.250 ( 3.043) Data 4.845 ( 2.442) Loss 4.3822e+00 (4.3792e+00) Acc@1 16.65 ( 17.60) Acc@5 37.16 ( 36.75)
Epoch: [30][610/626] Time 4.912 ( 3.043) Data 4.508 ( 2.441) Loss 4.2889e+00 (4.3789e+00) Acc@1 17.87 ( 17.60) Acc@5 38.92 ( 36.77)
Epoch: [30][620/626] Time 4.951 ( 3.043) Data 4.546 ( 2.441) Loss 4.3697e+00 (4.3789e+00) Acc@1 17.63 ( 17.60) Acc@5 36.33 ( 36.76)
Test: [ 0/25] Time 9.065 ( 9.065) Loss 3.7750e+00 (3.7750e+00) Acc@1 24.66 ( 24.66) Acc@5 49.32 ( 49.32)
Test: [10/25] Time 6.520 ( 4.251) Loss 4.6754e+00 (4.0579e+00) Acc@1 13.82 ( 20.16) Acc@5 32.23 ( 42.12)
Test: [20/25] Time 3.998 ( 3.994) Loss 4.8062e+00 (4.3301e+00) Acc@1 10.45 ( 17.44) Acc@5 25.49 ( 37.14)
* Acc@1 18.318 Acc@5 38.298
同时 dataparallel.csv 可以看到训练的时间(参考):
...
2024-08-15 14:03:14,1982.8778982162476
2024-08-15 14:36:18,2001.0861496925354
2024-08-15 15:09:41,1995.106449842453
2024-08-15 15:42:57,2005.7094027996063
2024-08-15 16:16:24,1995.0537593364716
2024-08-15 16:49:41,2051.8484938144684
2024-08-15 17:23:54,2018.2189569473267
2024-08-15 17:57:34,1993.9278218746185
2024-08-15 18:30:49,2039.668051958084
2024-08-15 19:04:51,2022.5196571350098
2024-08-15 19:38:35,2007.2742729187012
2024-08-15 20:12:04,1991.4569637775421
2024-08-15 20:45:17,1996.3973467350006
2024-08-15 21:18:35,2011.3201851844788
同时目录下会出现 model_best.pth.tar和checkpoint.pth.tar,记录模型权重。

参考文档

ViT