训练GAN的16个trick

本文转载自:https://mp.weixin.qq.com/s/d_W0O7LNqlBuZV87Ou9uqw 新智元公众号 本文来自ICCV 2017的Talk:如何训练GAN,FAIR的研究员Soumith Chintala总结了训练GAN的16个技巧,例如输入的规范化,修改损失函数,生成器用Adam优化,使用Sofy和Noisy标签,等等。这是NIPS 2016的Soumith Chintala作的邀请演讲的修改版本,而2016年的这些tricks在github已经有2.4k星。

ICCV 2017 slides:https://github.com/soumith/talks/blob/master/2017-ICCV_Venice/How_To_Train_a_GAN.pdf

NIPS2016:https://github.com/soumith/ganhacks

训练GAN的16个trick

# 1:规范化输入

  • 将输入图像规范化为-1到1之间
  • 生成器最后一层的输出使用tanh函数(或其他bounds normalization)

#2:修改损失函数(经典GAN)

  • 在GAN论文里人们通常用 min (log 1-D) 这个损失函数来优化G,但在实际训练的时候可以用max log D

       -因为第一个公式早期有梯度消失的问题 

       - Goodfellow et. al (2014)

  • 在实践中:训练G时使用反转标签能工作得很好,即:real = fake, fake = real

一些GAN变体

【TensorFlow】https://github.com/hwalsuklee/tensorflow-generative-model-collections

【Pytorch】https://github.com/znxlwm/pytorch-generative-model-collections

#3:使用一个具有球形结构的噪声z

  • 在做插值(interpolation)时,在大圆(great circle)上进行
  • Tom White的论文“Sampling Generative Networks”

- https://arxiv.org/abs/1609.04468

#4: BatchNorm

  • 一个mini-batch里面必须保证只有Real样本或者Fake样本,不要把它们混起来训练
  • 如果不能用batchnorm,可以用instance norm

#5:避免稀疏梯度:ReLU, MaxPool

  • GAN的稳定性会因为引入了稀疏梯度受到影响
  • LeakyReLU很好(对于G和D)
  • 对于下采样,使用:Average Pooling,Conv2d + stride
  • 对于上采样,使用:PixelShuffle, ConvTranspose2d + stride

       -PixelShuffle 论文:https://arxiv.org/abs/1609.05158

#6:使用Soft和Noisy标签

  • Label平滑,也就是说,如果有两个目标label:Real=1 和 Fake=0,那么对于每个新样本,如果是real,那么把label替换为0.7~1.2之间的随机值;如果样本是fake,那么把label替换为0.0~0.3之间的随机值。
  • 训练D时,有时候可以使这些label是噪声:偶尔翻转label

       - Salimans et. al. 2016

#7:架构:DCGANs / Hybrids

  • 能用DCGAN就用DCGAN,
  • 如果用不了DCGAN而且没有稳定的模型,可以使用混合模型:KL + GAN 或 VAE + GAN
  • WGAN-gp的ResNet也很好(但非常慢)

       - https://github.com/igul222/improved_wgan_training

  • width比depth更重要

#8:借用RL的训练技巧

  • Experience replay
  • 对于deep deterministic policy gradients(DDPG)有效的技巧
  • 参考Pfau & Vinyals (2016)的论文

#9:优化器:ADAM

  • 优化器用Adam(Radford et. al. 2015)
  • 或者对D用SGD,G用Adam

#10:使用 Gradient Penalty

  • 使梯度的norm规范化
  • 对于为什么这一点有效,有多个理论(WGAN-GP, DRAGAN, 通过规范化使GAN稳定)

#11:不要通过loss statistics去balance G与D的训练过程(经典GAN)

#12:如果你有类别标签,请使用它们

  • 如果还有可用的类别标签,在训练D判别真伪的同时对样本进行分类

#13:给输入增加噪声,随时间衰减

  • 给D的输入增加一些人工噪声(Arjovsky et. al., Huszar, 2016)
  • 给G的每一层增加一些高斯噪声(Zhao et. al. EBGAN)

#14:多训练判别器D

  • 特别是在加噪声的时候

#15:避开离散空间

  • 将生成结果作为一个连续预测

#16:离散变量

  • 使用一个嵌入层
  • 给图像增加额外通道
  • 保持嵌入的维度低和上采样以匹配图像通道的大小

总结:

  • GAN模型的稳定性在提升
  • 理论研究有所进展
  • 技巧只是权宜之计

时间线——GAN模型的稳定性

PPT下载:https://github.com/soumith/talks/blob/master/2017-ICCV_Venice/How_To_Train_a_GAN.pdf

参考:https://github.com/soumith/ganhacks


本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

发表于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏AI研习社

看完立刻理解 GAN!初学者也没关系

前言 GAN 从 2014 年诞生以来发展的是相当火热,比较著名的 GAN 的应用有 Pix2Pix、CycleGAN 等。本篇文章主要是让初学者通过代码了...

34050
来自专栏编程

FPGA图像处理之rgbtogray算法的实现

FPGA图像处理之rgbtogray算法的实现 作者:lee神 1.背景知识 在正是入题之前先给大家讲解一下gray图像,YUV图像以及Ycbcr图像。 Gra...

43640
来自专栏贾志刚-OpenCV学堂

详解LBP特征与应用(人脸识别)

之前我已经写过一篇关于局部二值模式(LBP)文章,当时主要是介绍了一下局部二值模式的概念与其简单的尺度空间扩展,本文是上一篇文章基础上对局部二值模式的深化,涉及...

40180
来自专栏WOLFRAM

随机三维图像中可以找到多少动物和阿尔普物形?

34460
来自专栏新智元

不可错过的 GAN 资源:教程、视频、代码实现、89 篇论文下载

【新智元导读】这是一份生成对抗(神经)网络的重要论文以及其他资源的列表,由 Holger Caesar 整理,包括重要的 workshops,教程和博客,按主题...

921100
来自专栏机器之心

教程 | 详解如何使用Keras实现Wassertein GAN

选自Deeply Random 机器之心编译 参与:晏奇、李泽南 在阅读论文 Wassertein GAN 时,作者发现理解它最好的办法就是用代码来实现其内容。...

508100
来自专栏小小挖掘机

推荐系统遇上深度学习(二)--FFM模型理论和实践

推荐系统遇上深度学习系列: 推荐系统遇上深度学习(一)--FM模型理论和实践 1、FFM理论 在CTR预估中,经常会遇到one-hot类型的变量,one-ho...

1K40
来自专栏数据科学学习手札

(数据科学学习手札18)二次判别分析的原理简介&Python与R实现

上一篇我们介绍了Fisher线性判别分析的原理及实现,而在判别分析中还有一个很重要的分支叫做二次判别,本文就对二次判别进行介绍: 二次判别属于距离判别法中的内容...

39790
来自专栏AI科技评论

开发 | 看完立刻理解GAN!初学者也没关系

AI 科技评论按:本文原作者天雨粟,原文载于作者的知乎专栏——机器不学习,经授权发布。 前言 GAN 从 2014 年诞生以来发展的是相当火热,比较著名的 GA...

420130
来自专栏机器之心

教程 | 在Keras上实现GAN:构建消除图片模糊的应用

选自Sicara Blog 作者:Raphaël Meudec 机器之心编译 参与:陈韵竹、李泽南 2014 年,Ian Goodfellow 提出了生成对抗网...

46230

扫码关注云+社区

领取腾讯云代金券