机器之心报道
机器之心编辑部
作为「史上最强 GAN 图像生成器」,BigGAN 自去年 9 月推出以来就成为了 AI 领域最热词。其生成图像的目标和背景都高度逼真、边界自然,简直可以说是在「创造新物种」。然而 BigGAN 训练时需要的超高算力(128-512 个谷歌 TPU v3 核心)却让很多想要参与制图狂欢的开发者望而却步。 今日,BigGAN 论文的第一作者、来自英国 Heriot-Watt 大学的 Andrew Brock 发布了 BigGAN 的 PyTorch 版实现。最令人高兴的是:这一次训练模型的算力要求被降低到 4 到 8 块 GPU 了!
项目链接:https://github.com/ajbrock/BigGAN-PyTorch 该项目一出即引发了人们的广泛关注,有的人表示不敢相信,也有人哭晕在 Colab。
Brock 本次放出的 BigGAN 实现包含训练、测试、采样脚本以及完整的预训练检查点(生成器、判别器和优化器),以便你可以在自己的数据上进行微调或者从零开始训练模型。 作者表示,这些代码制作时间很长,从一开始就被设计成可操控、可扩展的基础,以方便未来的研究。作者花了很多心思考虑在什么地方具体使用什么抽象,以确保它们有效但又易于理解或改变。 这一工作是 Andrew Brock 与 MIT 的 Alex Andonian 一起完成的。
BigGAN 的 PyTorch 实现
这是由论文原作者正式发布的「非官方」BigGAN PyTorch 实现。
该 repo 包含用 4-8 个 GPU 训练 BigGAN 的代码。 如何使用 你需要用到:
首先,你可以准备目标数据集的预处理 HDF5 版本,以便更快地输入/输出(可选)。在此之后(不管是否如此做了),你需要计算 FID 所需的 Inception moment。这些都可以通过修改并运行以下代码来完成:
sh scripts/utils/prepare_data.sh
默认情况下,假设你的 ImageNet 训练集已经下载至此目录的根文件夹 data
中,然后以 128x128 的像素分辨率准备缓存的 HDF5。
脚本文件夹中有多个 bash 脚本,此类脚本可以用不同的批量大小训练 BigGAN。这段代码假设你无法访问完整的 TPU pod,然后通过梯度累积(将多个小批量上的梯度平均化,然后仅在 N 次累积后采取优化步骤)表示相应的 mega-batches。默认情况下,launch_BigGAN_bs256x8.sh
脚本训练批量大小为 256 且具备 8 次梯度累积的完整 BigGAN 模型,其总的批量大小为 2048。在 8xV100 上进行全精度训练(无张量核),这个脚本需要 15天训练到 15 万次迭代。
你需要先确定你的设置能够支持的最大批量。这里提供的预训练模型是在 8xV100(每个有 16GB VRAM)上训练的,8xV100 能支持比默认使用的 BS256 略大的批量大小。一旦确定了这一点,你应该修改脚本,使批大小乘以梯度累积的数量等同于你期望的总批量大小(BigGAN 默认的总批量大小是 2048)。
注意,这个脚本使用参数 --load_in_mem
,该参数会将整个 I128.hdf5(约 64GB)文件加载至 RAM 中,以便更快地加载数据。如果你没有足够的 RAM 来支持这个(可能需要 96GB 以上),删除这个参数。
度量和采样
在训练过程中,该脚本将输出包含训练度量和测试度量的日志,并保存模型权重/优化器参数的多个副本(2 个最新的和 5 个得分最高的),还会在每次保存权重时产生样本和插值。日志文件夹包含处理这些日志及使用 MATLAB 绘制结果的脚本。
训练结束后,你可以使用 sample.py
生成额外的样本和插值,用不同的截断值、批大小、standing stat 累积次数等进行测试。示例参考 sample_BigGAN_bs256x8.sh
脚本。
默认情况下,所有内容都会保存至 weights/samples/logs/data 文件夹中,这些文件夹应与该 repo 在同一文件夹中。你可以使用 --base_root
参数将这些文件夹指向不同的根目录,或者使用对应的参数(如 --logs_root)为
每个文件夹选择特定的位置。
该 repo 还包含运行 BigGAN-deep 的脚本,但作者尚未使用它们来完整地训练模型,所以可将其视为未经测试。另外,该 repo 包含在 CIFAR 上运行模型的脚本,以及在 ImageNet 上运行 SA-GAN(带有EMA)和 SN-GAN 的脚本。SA-GAN 代码假设你有 4xTitanX(或具备同等 RAM 的 GPU),并使用 128 的批量大小和 2 个梯度累积来训练。
关于 Inception 度量的重要提示
该 repo 使用 PyTorch 内置 inception 网络来计算 IS 和 FID 分数。这些分数与使用官方 TF inception 代码得到的不同,且仅用于监控目的。使用 --sample_npz
参数在模型上运行 sample.py,然后运行 inception_tf13 来计算真实的 TensorFlow IS。注意:你需要安装 TensorFlow 1.3 或更早版本,因为 TF1.4+ 会破坏原始 IS 代码。
预训练模型
该 repo 包含两个预训练模型检查点(具备 G、D、G 的 EMA copy、优化器和 state dict):
使用 Places-365 数据集预训练模型也将很快开源。 该 repo 还包含将原始 TFHub BigGAN Generator 权重迁移到 PyTorch 的脚本。详见 TFHub 文件夹。
如果你想继续被中断的训练或者微调预训练模型,运行同样的启动脚本,不过这次需要添加 —resume 参数。实验名称是从配置中自动生成的,但是你可以使用 —experiment_name 参数对其进行重写(例如你想使用修改后的优化器设置来微调模型)。
要想使用自己的数据集,你需要将其添加到 datasets.py,并修改 utils.py 中的 convenience dicts (dset_dict, imsize_dict, root_dict, nclass_dict, classes_per_sheet_dict),从而为自己的数据集准备适合的元数据。在 prepare_data.sh 中重复该过程(可选择性地生成 HDF5 preprocessed copy,然后计算 FID 所需的 Inception moment。
默认情况下,该训练脚本将以 Inception Score 为衡量标准选出 top 5 最优检查点并保存。对于 ImageNet 以外的数据集,模型的 Inception Score 可能不是很好的质量度量标准,因此你可以使用 which_best FID 来代替 Inception Score。
要想使用自己的训练函数(如训练 BigVAE),你可以修改 train_fns.GAN_training_function,或者将新的训练函数添加到 if config['which_train_fn'] == 'GAN' 之后(train.py 中的行)。
亮点
这段代码与原始 BigGAN 的关键区别