首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >自定义数据集上训练StyleGAN | 基于Python+OpenCV+colab实现

自定义数据集上训练StyleGAN | 基于Python+OpenCV+colab实现

作者头像
AI算法与图像处理
发布2021-04-21 15:14:36
2.8K0
发布2021-04-21 15:14:36
举报

重磅干货,第一时间送达

概要

  • 分享我的知识,使用带有示例代码片段的迁移学习逐步在Google colab中的自定义数据集上训练StyleGAN
  • 如何使用预训练的权重从自定义数据集中生成图像
  • 使用不同的种子值生成新图像

介绍

生成对抗网络(GAN) 是机器学习中的一项最新创新,由 Ian J. Goodfellow 及其同事于2014年首次提出。

它是一组神经网络,以两人零和博弈的形式相互对抗。博弈论(一个人的胜利就是另一个人的损失)。它是用于无监督学习的生成模型的一种形式。

这里有一个生成器(用于从潜在空间中的某个点在数据上生成新实例)和鉴别器(用于将生成器生成的数据与实际或真实数据值区分开)。

最初,生成器生成虚假或伪造的数据,鉴别器可以将其分类为伪造,但是随着训练的继续,生成器开始学习真实数据的分布并开始生成真实的数据。

这种情况一直持续到鉴别器无法将其分类为不真实的并且生成器输出的所有数据看起来都像真实数据。

因此,此处生成器的输出连接到鉴别器的输入,并根据鉴别器的输出(是实数还是非实数)计算损失,并通过反向传播,为后续训练(epoch)更新生成器的权重。

StyleGAN目前在市场上有多种GAN变体,但在本文中,我将重点介绍Nvidia在2018年12月推出的StyleGAN。StyleGAN的体系结构使用基线渐进式GAN。即,生成图像的大小从非常低的角度逐渐增加分辨率(4×4)到非常高的分辨率(1024×1024),并使用双线性采样代替基线渐进式GAN中使用的最近邻居上/下采样。

该博客的主要目的是解释如何使用迁移学习在自定义数据集上训练StyleGAN,因此,有关GAN架构的更多详细信息,请参见NVlabs / stylegan-官方TensorFlow GitHub链接

  • https://github.com/NVlabs/stylegan

迁移学习在另一个相似的数据集上使用已训练的模型权重并训练自定义数据集。

自定义数据集包含2500个来自时尚的纹理图像。下面几张示例纹理图像可供参考。此处你可以替换成自己的自定义数据集。

重点和前提条件:
  1. 必须使用GPU,StyleGAN无法在CPU环境中进行训练。为了演示,我已经使用google colab环境进行实验和学习。
  2. 确保选择Tensorflow版本1.15.2。StyleGAN仅适用于tf 1.x
  3. StyleGAN训练将花费大量时间(几天之内取决于服务器容量,例如1个GPU,2个GPU等)
  4. 如果你正在从事与GAN相关的任何实时项目,那么由于colab中的使用限制和超时,你可能想在 tesla P-80或 P-100专用服务器上训练GAN 。
  5. 如果你有google-pro(不是强制性的),则可以节省多达40-50%的本文训练时间 ,我对GAN进行了3500次迭代训练,因为训练整个GAN需要很长时间(要获取高分辨率图像),则需要至少运行25000次迭代(推荐)。另外,我的图像分辨率是64×64,但是styleGAN是在1024×1024分辨率图像上训练的。
  6. 我已使用以下预先训练的权重来训练我的自定义数据集(有关更多详细信息,请参见Tensorflow Github官方链接)
    • https://drive.google.com/uc?id=1MEGjdvVpUsu1jB4zrXZN7Y4kBBOzizDQ

使用迁移学习在Google Colab中的自定义数据集上训练style GAN

  1. 打开colab并打开一个新的botebook。确保在Runtime->Change Runtime type->Hardware accelerator下设置为GPU
  2. 验证你的帐户并装载G驱动器
from google.colab import drive
drive.mount('/content/drive', force_remount=True)
  1. 确保选择了Tensorflow版本1.15.2。StyleGAN仅适用于tf1.x。
%tensorflow_version 1.x
import tensorflow
print(tensorflow.__version__)
  1. 从 https://github.com/NVlabs/stylegan 克隆stylegan.git
!git clone https://github.com/NVlabs/stylegan.git
!ls /content/stylegan/
You should see something like this
config.py              LICENSE.txt             run_metrics.py
dataset_tool.py        metrics                 stylegan-teaser.png
dnnlib                 pretrained_example.py   training
generate_figures.py    README.md               train.py

5. 将 stylegan文件夹添加到python,以导入dnnlib模块

import sys
sys.path.insert(0, "/content/stylegan")
import dnnlib

6. 将自定义数据集从G驱动器提取到你选择的colab服务器文件夹中

!unrar x "/content/drive/My Drive/CustomDataset.rar" "/content/CData/"

7. Stylegan要求图像必须是正方形,并且为获得很好的分辨率,图像必须为1024×1024。但是在本演示中,我将使用64×64的分辨率,下一步是将所有图像调整为该分辨率。

# resize all the images to same size
import os
from tqdm import tqdm
import cv2
from PIL import Image
from resizeimage import resizeimage
path = '/content/CData/'
for filename in tqdm(os.listdir(path),desc ='reading images ...'):
image = Image.open(path+filename)
image = image.resize((64,64))
image.save(path+filename, image.format)

8.将自定义数据集复制到colab并调整大小后,使用以下命令将自定义图像转换为tfrecords这是StyleGAN的要求,因此此步骤对于训练StyleGAN是必不可少的。

! python /content/stylegan/dataset_tool.py create_from_images /content/stylegan/datasets/custom-dataset /content/texture
replace your custom dataset path (instead of /content/texture)

9.一旦成功创建了tfrecords,你应该查看它们

/content/stylegan/datasets/custom-dataset/custom-dataset-r02.tfrecords - 22
/content/stylegan/datasets/custom-dataset/custom-dataset-r03.tfrecords - 23
/content/stylegan/datasets/custom-dataset/custom-dataset-r04.tfrecords -24
/content/stylegan/datasets/custom-dataset/custom-dataset-r05.tfrecords -25
/content/stylegan/datasets/custom-dataset/custom-dataset-r06.tfrecords -26
These tfrecords correspond to 4x4 , 8x8 ,16x16, 32x32 and 64x64 resolution images (baseline progressive) respectiviely

10.现在转到styleGAN文件夹并打开train.py文件并进行以下更改

Replace line no 37 below # Dataset. from
desc += '-ffhq'; dataset = EasyDict(tfrecord_dir='ffhq'); train.mirror_augment = True    TO
desc += '-PATH of YOUR CUSTOM DATASET'= EasyDict(tfrecord_dir='PATH of YOUR CUSTOM DATASET'); train.mirror_augment = True

uncomment line no 46 below # Number of GPUs. and comment line no 49
line number 52, train.total_kimg = 25000 is recommended for complete GAN training of 1024x1024 resolution image. I have set it to 3500. Training will stop after this much iterations

11.在我们开始GAN训练之前,还需要做几个更改。我们需要知道哪个经过pickle预训练的模型将用于训练我们自己的定制数据集。StyleGAN也使用 inception-v3,所以,我们需要得到 inception_v3_features.pkl

转到链接 https://drive.google.com/drive/folders/1MASQyN5m0voPcx7-9K0r5gObhvvPups7 Google 云端硬盘 https://drive.google.com/drive/folders/1MASQyN5m0voPcx7-9K0r5gObhvvPups7

你会看到一个文件 karras2019stylegan-ffhq1024x1024.pkl。该预训练版本经过训练可生成高分辨率人脸。名人,猫,汽车等还有其他型号。你需要将此文件复制到G驱动器上,并从G驱动器中的文件获取URL链接。URL链接看起来是这样

https://drive.google.com/uc?id=1FtjSVZawl-e_LDmIH3lbB0h_8q2g51Xq

同样,我们需要将 inception_v3_features.pkl 复制到G盘并获取URL链接。现在转到styleGAN / metrics下的路径,然后打开python文件 frechet_inception_distance.py。 我们 需要在第29行做一些小的更改,如下所示

将以下代码替换为

inception = misc.load_pkl('https://drive.google.com/uc?id=1MzTY44rLToO5APn8TZmfR7_ENSe5aZUn') # inception_v3_features.pkl
inception = misc.load_pkl(''YOUR G-Drive inception-v3_features.pkl LINK url') # inception_v3_features.pkl

我们现在都准备去训练自己的styleGAN

  1. 运行以下命令开始训练
! python /content/stylegan/train.py (! nohup python /content/stylegan/train.py if you want it to run in the background and you do not wish to see the progress in your terminal directly. Do note this will take a lot of time depending on the configurations mentioned above) you should observe something like below
Training...
tick 1 kimg 140.3 lod 3.00 minibatch 128 time 4m 34s sec/tick 239.7 sec/kimg 1.71 maintenance 34.5 gpumem 3.6
network-snapshot-000140 time 6m 33s fid50k 331.8988
WARNING:tensorflow:From /content/stylegan/dnnlib/tflib/autosummary.py:137: The name tf.summary.scalar is deprecated. Please use tf.compat.v1.summary.scalar instead.
WARNING:tensorflow:From /content/stylegan/dnnlib/tflib/autosummary.py:182: The name tf.summary.merge_all is deprecated. Please use tf.compat.v1.summary.merge_all instead.

tick 2 kimg 280.6 lod 3.00 minibatch 128 time 15m 18s sec/tick 237.1 sec/kimg 1.69 maintenance 407.2 gpumem 3.6
tick 3 kimg 420.9 lod 3.00 minibatch 128 time 19m 16s sec/tick 237.3 sec/kimg 1.69 maintenance 0.7 gpumem 3.6
tick 4 kimg 561.2 lod 3.00 minibatch 128 time 23m 15s sec/tick 238.1 sec/kimg 1.70 maintenance 0.7 gpumem 3.6
tick 5 kimg 681.5 lod 2.87 minibatch 128 time 31m 54s sec/tick 518.6 sec/kimg 4.31 maintenance 0.7 gpumem 4.7
tick 6 kimg 801.8 lod 2.66 minibatch 128 time 42m 53s sec/tick 658.0 sec/kimg 5.47 maintenance 0.8 gpumem 4.7
tick 7 kimg 922.1 lod 2.46 minibatch 128 time 53m 52s sec/tick 657.7 sec/kimg 5.47 maintenance 0.9 gpumem 4.7
tick 8 kimg 1042.4 lod 2.26 minibatch 128 time 1h 04m 49s sec/tick 656.6 sec/kimg 5.46 maintenance 0.8 gpumem 4.7
tick 9 kimg 1162.8 lod 2.06 minibatch 128 time 1h 15m 49s sec/tick 658.5 sec/kimg 5.47 maintenance 0.8 gpumem 4.7
tick 10 kimg 1283.1 lod 2.00 minibatch 128 time 1h 26m 40s sec/tick 650.0 sec/kimg 5.40 maintenance 0.8 gpumem 4.7
network-snapshot-001283 time 6m 10s fid50k 238.2729
tick 11 kimg 1403.4 lod 2.00 minibatch 128 time 1h 43m 39s sec/tick 647.7 sec/kimg 5.38 maintenance 371.7 gpumem 4.7
tick 12 kimg 1523.7 lod 2.00 minibatch 128 time 1h 54m 27s sec/tick 647.5 sec/kimg 5.38 maintenance 0.8 gpumem 4.7
tick 13 kimg 1644.0 lod 2.00 minibatch 128 time 2h 05m 15s sec/tick 647.4 sec/kimg 5.38 maintenance 0.9 gpumem 4.7
tick 14 kimg 1764.4 lod 2.00 minibatch 128 time 2h 16m 04s sec/tick 647.3 sec/kimg 5.38 maintenance 0.8 gpumem 4.7
tick 15 kimg 1864.4 lod 1.89 minibatch 64 time 2h 41m 25s sec/tick 1520.8 sec/kimg 15.19 maintenance 0.8 gpumem 4.7
tick 16 kimg 1964.5 lod 1.73 minibatch 64 time 3h 15m 48s sec/tick 2060.2 sec/kimg 20.58 maintenance 2.9 gpumem 4.7
tick 17 kimg 2064.6 lod 1.56 minibatch 64 time 3h 50m 11s sec/tick 2060.1 sec/kimg 20.58 maintenance 3.1 gpumem 4.7
tick 18 kimg 2164.7 lod 1.39 minibatch 64 time 4h 24m 36s sec/tick 2061.2 sec/kimg 20.59 maintenance 3.1 gpumem 4.7
tick 19 kimg 2264.8 lod 1.23 minibatch 64 time 4h 59m 00s sec/tick 2061.1 sec/kimg 20.59 maintenance 3.0 gpumem 4.7
tick 20 kimg 2364.9 lod 1.06 minibatch 64 time 5h 33m 24s sec/tick 2061.1 sec/kimg 20.59 maintenance 2.9 gpumem 4.7
network-snapshot-002364 time 7m 46s fid50k 164.6632
tick 21 kimg 2465.0 lod 1.00 minibatch 64 time 6h 15m 16s sec/tick 2042.9 sec/kimg 20.41 maintenance 469.6 gpumem 4.7
tick 22 kimg 2565.1 lod 1.00 minibatch 64 time 6h 49m 11s sec/tick 2032.3 sec/kimg 20.30 maintenance 2.9 gpumem 4.7
tick 23 kimg 2665.2 lod 1.00 minibatch 64 time 7h 23m 07s sec/tick 2032.5 sec/kimg 20.31 maintenance 2.9 gpumem 4.7
tick 24 kimg 2765.3 lod 1.00 minibatch 64 time 7h 57m 03s sec/tick 2033.5 sec/kimg 20.32 maintenance 2.9 gpumem 4.7
tick 25 kimg 2865.4 lod 1.00 minibatch 64 time 8h 31m 00s sec/tick 2034.1 sec/kimg 20.32 maintenance 2.9 gpumem 4.7

一旦达到train.py文件中指定的train.total_kimg值,训练便会结束。

现在让我们看看由styleGAN在自定义数据上生成的图像

真实(原始)图像64 x 64分辨率

初始迭代后-S-GAN生成的伪造

经过1000次以上的训练

经过> 3500次训练后

我们可以看到,随着训练迭代的进行,模型已经开始生成真实图像。经过将近4000次迭代,我已经终止了训练,因为这只是一个实验和演示。

但是,随着我们对模型进行较长时间的训练,图像将越来越精细,经过9000或10000轮训练后,GAN将开始生成原始图片的死角。太神奇了!

现在让我们看看如何使用预训练的自定义权重来生成类似于我们的自定义数据集的图像

如何使用预训练的权重从自定义数据集中生成图像

训练结束后,将创建一个如下所示的目录

/ content / results / 00006-sgan- / content / stylegan / datasets / custom-dataset-1gpu /

在此之下,你可以看到创建了许多网络快照的pickle文件。我们需要获取最新的.pkl文件,并将该文件的权重用于预训练模型,如下面的代码片段所示

# 4.0 International License. To view a copy of this license, visit# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.

使用预训练的StyleGAN生成器生成图像的最小脚本。

导入操作系统

import pickle
import numpy as np

导入PIL.Image

import dnnlib
import dnnlib.tflib as tflib
import config
def main():

初始化TensorFlow。

tflib.init_tf()
url = '/content/network-snapshot-003685 .pkl'
with open(url,'rb') as f :

_G,_D,Gs = pickle.load(f)

# _G = Instantaneous snapshot of the generator. Mainly useful for resuming a previous training run.
# _D = Instantaneous snapshot of the discriminator. Mainly useful for resuming a previous training run.
# Gs = Long-term average of the generator. Yields higher-quality results than the instantaneous snapshot.

# Print network details.
Gs.print_layers()

# Pick latent vector.
rnd = np.random.RandomState()
latents = rnd.randn(1, Gs.input_shape[1])

# Generate image.
fmt = dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True)
images = Gs.run(latents, None, truncation_psi=0.7, randomize_noise=True, output_transform=fmt)

保存图片。

os.makedirs(config.result_dir, exist_ok=True)
png_filename = os.path.join(config.result_dir, f’/content/example1.png’)
PIL.Image.fromarray(images[0], ‘RGB’).save(png_filename)
#if __name__ == "__main__":
main()

on running this code , output image example1.png will be created under /content 

The output quality will be based on the network_snapshot.pkl we use

使用不同的种子值生成新图像-潜在空间中的不同点

从不同种子值(潜伏空间中的不同点)生成的图像

代码片段

!python /content/stylegan2/run_generator.py generate-latent-walk --network=/content/results/00000-sgan-/content/stylegan/datasets/custom-dataset-1gpu/network-snapshot-003685.pkl --seeds=200,1000,2500,4000,200 --frames 10 --truncation-psi=0.8

上面的代码将生成10张图像。在这里,我使用了使用自定义模型的styleGAN训练的预训练权重,并使用run_generator.py(在styleGAN2中可用)生成了不同的图像。

我们可以从逻辑上选择看起来相似的种子(你需要尝试一些实验才能达到此目的),然后对它们进行插值以获取原始数据集中不存在的全新图像。

同样,输出的质量将取决于我们的模型完成训练的哪个阶段。就我而言,它大约在4000个epoch终止。

结论

在此博客中,我分享了我在Google colab服务器中进行 stylegan / stylegan2 实验时获得的知识。以下是一些混合的python程序示例,你可以参考

  • stylegan – pretrained_example.py
  • stylegan – generate_figure.py
  • stylegan2 – run_generator.py

官方StyleGAN github链接

  • https://github.com/NVlabs/stylegan
  • https://github.com/NVlabs/stylegan2

☆ END ☆

个人微信(如果没有备注不拉群!)请注明:地区+学校/企业+研究方向+昵称


下载1:何恺明顶会分享
在「AI算法与图像处理」公众号后台回复:何恺明,即可下载。总共有6份PDF,涉及 ResNet、Mask RCNN等经典工作的总结分析
下载2:终身受益的编程指南:Google编程风格指南
在「AI算法与图像处理」公众号后台回复:c++,即可下载。历经十年考验,最权威的编程规范!
下载3 CVPR2021
在「AI算法与图像处理」公众号后台回复:CVPR,即可下载1467篇CVPR 2020论文 和 CVPR 2021 最新论文

点亮

,告诉大家你也在看

本文参与 腾讯云自媒体分享计划,分享自微信公众号。
原始发表:2021-04-07,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 AI算法与图像处理 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 介绍
    • 重点和前提条件:
    • 使用迁移学习在Google Colab中的自定义数据集上训练style GAN
      • 真实(原始)图像64 x 64分辨率
      • 如何使用预训练的权重从自定义数据集中生成图像
      • 使用不同的种子值生成新图像-潜在空间中的不同点
      • 结论
      领券
      问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档