资源 | NIPS 2017 Spotlight论文Bayesian GAN的TensorFlow实现

选自GitHub

作者:Andrew Gordon Wilson

机器之心编译

参与:路雪、刘晓坤

用生成模型学习高维自然信号(比如图像、视频和音频)长期以来一直是机器学习的重要发展方向之一。来自 Uber AI Lab 的 Yunus Saatchi 等人今年五月提出了 Bayesian GAN——利用一个简单的贝叶斯公式进行端到端无监督/半监督 GAN 学习。该研究的论文已被列入 NIPS 2017 大会 Spotlight。最近,这篇论文的另一作者 Andrew Gordon Wilson 在 GitHub 上发布了 Bayesian GAN 的 TensorFlow 实现。

项目链接:https://github.com/andrewgordonwilson/bayesgan/

论文:Bayesian GAN

论文链接:https://arxiv.org/abs/1705.09558

摘要:生成对抗网络(GAN)可以隐性地学习难以用显性似然(explicit likelihood)建模的图像、音频和数据的丰富分布。我们展示了一种实际的贝叶斯公式,用 GAN 进行无监督和半监督学习。在该框架下,我们使用随机梯度哈密尔顿蒙特卡罗(Hamiltonian Monte Carlo)来边缘化生成器和判别器的权重。得到的方法很直接,且可在没有标准干预(如特征匹配或小批量判别)的情况下达到不错的性能。通过探索生成器参数具有表达性的后验,贝叶斯 GAN 避免了模式崩溃(mode-collapse),输出可解释的多种候选样本,在 SVHN、CelebA 和 CIFAR-10 等多个基准数据集上取得了顶尖的半监督学习量化结果,优于 DCGAN、Wasserstein GAN 和 DCGAN。

介绍

在贝叶斯 GAN 中,我们提出了生成器和判别器权重的条件后验,通过随机梯度哈密尔顿蒙特卡罗边缘化这些后验。贝叶斯 GAN 的主要特性有:(1)在半监督学习问题上的准确预测;(2)对优秀性能的最小干预;(3)响应对抗反馈的推断的概率公式;(4)避免模式崩溃;(5)展示多个互补的生成和判别模型,形成一个概率集成(probabilistic ensemble)。

我们介绍了一个生成器参数的多模态后验。这些参数的每个设置对应数据的不同生成假设。这里我们将展示两种权重向量设置下生成的样本,不同的权重向量设置对应不同的写作风格。贝叶斯 GAN 保留该参数分布。相反,标准 GAN 用点估计(类似最大似然解决方案)来展示整个分布,降低了数据的可解释性。

环境需求

该代码有以下依赖项(版本号很关键)

  • python 2.7
  • tensorflow==1.0.0

在 Linux 上安装 TensorfFow 1.0.0,请按照 https://www.tensorflow.org/versions/r1.0/install/说明进行操作。

  • scikit-learn==0.17.1

你可以使用下列命令安装 scikit-learn 0.17.1:

pip install scikit-learn==0.17.1

或者,使用提供的 environment.yml 文件创建 conda 环境,并进行设置:

conda env create -f environment.yml -n bgan

然后,使用下列命令加载环境:

source activate bgan

训练选项

bayesian_gan_hmc.py 具备以下训练选项。

  • --out_dir:文件夹路径,用于存储输出
  • --n_save: 每 n_save 次迭代存储的样本和权重;默认值 100
  • --z_dim: 生成器 z 向量的维度;默认值 100
  • --data_path:数据路径;具体讨论详见 https://github.com/andrewgordonwilson/bayesgan/#data-preparation;该参数是必需的
  • --dataset:可以是 mnist、cifar、svhn 或 celeb;默认 mnist
  • --gen_observed: 生成器「观察到」的数据;影响噪声变量和先验的缩放;默认值 1000
  • --batch_size:训练的批量大小;默认值 64
  • --prior_std:权重先验分布的 std;默认值 1
  • --numz:和论文中的 J 一样; z 的样本数,实现整合;默认值 1
  • --num_mcmc: 和论文中的 M 一样;每个 z 的 MCMC NN 权重样本数;默认值 1
  • --lr: Adam 优化器使用的学习率;默认值 0.0002
  • --optimizer:使用的优化方法:adam (tf.train.AdamOptimizer) 或 sgd (tf.train.MomentumOptimizer);默认 adam
  • --semi_supervised:进行半监督学习
  • --N:半监督学习所需标注样本数量
  • --train_iter:训练迭代次数;默认值 50000
  • --save_samples:保存训练过程中生成的样本
  • --save_weights:训练过程中,保存权重
  • --random_seed:随机种子;如果使用 GPU,那么注意设置该种子不会引起 100% 的可复现结果

你还可以用--wasserstein 运行 WGAN,或用--ml_ensemble <num_dcgans> 训练 <num_dcgans> DCGAN 的集成。尤其是,你可以使用--ml_ensemble 1 训练一个 DCGAN。

使用

安装

1. 安装所需依赖项

2. 复制该 repository

合成数据

你可以使用 bgan_synth 脚本运行论文中的合成实验。例如,以下命令用于训练贝叶斯 GAN(D=100,d=10),进行 5000 次迭代,并把结果保存在<results_path>。

./bgan_synth.py --x_dim 100 --z_dim 10 --numz 10 --out <results_path>

运行以下命令使用相同的数据运行 ML GAN:

./bgan_synth.py --x_dim 100 --z_dim 10 --numz 1 --out <results_path>

bgan_synth 的参数有 --save_weights、--out_dir、--z_dim、--numz、--wasserstein、--train_iter 和 --x_dim。x_dim 控制观测数据(论文中的 x)的维度。通过这个链接查看其它参数的说明:https://github.com/andrewgordonwilson/bayesgan/#training-option。

运行了上面的两个命令之后,你可以在<results_path>里查看每 100 次迭代后的输出。例如,第 900 次迭代的贝叶斯 GAN 的输出结果如下:

相对地,标准 GAN(numz=1,强制执行 ML 评估)的输出结果如下:

可以清晰地看到在这个合成数据的例子中,标准 GAN 出现了模式崩溃的趋势,而贝叶斯 GAN 完全没有这样的问题。

你可以查看 synth.iptnb,进一步探索合成实验,并生成詹森-香农差异图。

MNIST、CIFAR10、CELEBA、SVHN

bayesian_gan_hmc 脚本允许在标准和自定义数据集上训练模型。下面,我们将介绍如何使用该脚本。

数据准备

为了重现在 MNIST、CIFAR10、CelebA 和 SVHN 数据集上的实验,你需要准备这些数据,并使用一个正确的——data_path。

  • 对于 MNIST,你不需要准备数据,并可以提供任意的——data_path;
  • 对于 CIFAR10,请从该地址(https://www.cs.toronto.edu/~kriz/cifar.html)下载和获取数据的 Python 版本;然后使用包含 cifar-10-batchs-py 的目录的路径作为——data_path;
  • 对于 SVHN,请从该地址(http://ufldl.stanford.edu/housenumbers/)下载 train_32x32.mat 和 test_32x32.mat 文件,并使用包含这些文件的目录的路径作为——data_path;
  • 对于 CelebA,你需要安装 OpenCV。数据下载地址:http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html。你需要创建 celebA 文件夹,该文件夹包含 Anno 和 img_align_celeba 子文件夹。其中 Anno 必须包含 list_attr_celeba.txt,img_align_celeba 必须包含.jpg 文件。你还需要通过在——data_path <path>(其中<path>是包含了 celebA 的文件夹的路径)中运行 datasets/crop_faces.py 脚本对图像进行剪裁。训练模型的时候,你需要在——data_path 中使用相同的<path>。

无监督学习

你可以在没有 -- semi 参数的情况下通过运行 bayesian_gan_hmc 脚本对模型进行无监督训练。例如,使用以下命令:

./bayesian_gan_hmc.py --data_path <data_path> --dataset svhn --numz 1 --num_mcmc 10 --out_dir 
<results_path> --train_iter 75000 --save_samples --n_save 100

在 SVHN 数据集上训练模型。该命令会使用这个方法执行 75000 次迭代,每 100 次迭代保存样本。这里<results_path>必须是保存结果的目录。可查看数据准备部分,了解如何设置<data_path>。可查看训练选项部分,了解其它训练选项。

半监督训练

你可以使用--semi 选项来运行 bayesian_gan_hmc,对模型进行半监督训练。使用-N 参数设置训练所用标注样本的数量。例如,使用

./bayesian_gan_hmc.py --data_path <data_path> --dataset cifar --numz 1 --num_mcmc 10
--out_dir <results_path> --train_iter 75000 --N 4000 --semi --lr 0.00005

在 CIFAR10 数据集上用 4000 个标注样本训练该模型。该命令使训练经历 75000 次迭代,输出结果储存在<results_path> 文件夹中。

要想在 MNIST 数据集上使用 200 个标注样本训练该模型,你需要使用以下命令:

./bayesian_gan_hmc.py --data_path <data_path>/ --dataset mnist --numz 5 --num_mcmc 5
--out_dir <results_path> --train_iter 30000 -N 200 --semi --lr 0.001

自定义数据

要想在自定义数据集上训练该模型,你需要用特定的接口定义类别。假设我们想在 digits 数据集上训练模型。该数据集包含 8x8 数字图像。假设数据的储存格式为 x_tr.npy、y_tr.npy、x_te.npy 和 y_te.npy。我们假设 x_tr.npy 和 x_te.npy 的形态为 (?, 8, 8, 1)。然后在 bgan_util.py 中定义该数据集对应的类别,如下所示:

class Digits:

    def __init__(self):
        self.imgs = np.load('x_tr.npy') 
        self.test_imgs = np.load('x_te.npy')
        self.labels = np.load('y_tr.npy')
        self.test_labels = np.load('y_te.npy')
        self.labels = one_hot_encoded(self.labels, 10)
        self.test_labels = one_hot_encoded(self.test_labels, 10) 
        self.x_dim = [8, 8, 1]
        self.num_classes = 10

    @staticmethod
    def get_batch(batch_size, x, y): 
        """Returns a batch from the given arrays.
        """
        idx = np.random.choice(range(x.shape[0]), size=(batch_size,), replace=False)
        return x[idx], y[idx]

    def next_batch(self, batch_size, class_id=None):
        return self.get_batch(batch_size, self.imgs, self.labels)

    def test_batch(self, batch_size):
        return self.get_batch(batch_size, self.test_imgs, self.test_labels)

该类别必须具备 next_batch 和 test_batch,以及 imgs、labels、test_imgs、test_labels、x_dim 和 num_classes。

现在,我们可以把 Digits 类输入到 bayesian_gan_hmc.py:

from bgan_util import Digits

将下列行添加至处理--dataset 参数的代码中:

if args.dataset == "digits":
    dataset = Digits()

准备过程完成后,我们可以使用以下命令训练模型:

./bayesian_gan_hmc.py --data_path <any_path> --dataset digits --numz 1 --num_mcmc 10 
--out_dir <results path> --train_iter 5000 --save_samples

本文为机器之心编译,转载请联系本公众号获得授权。

原文发布于微信公众号 - 机器之心(almosthuman2014)

原文发表时间:2017-11-11

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

发表于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏AI研习社

Tensorflow 中 learning rate decay 的奇技淫巧

深度学习中参数更新的方法想必大家都十分清楚了——sgd,adam 等等,孰优孰劣相关的讨论也十分广泛。可是,learning rate 的衰减策略大家有特别关注...

5074
来自专栏利炳根的专栏

学习笔记CB010:递归神经网络、LSTM、自动抓取字幕

递归神经网络(RNN),时间递归神经网络(recurrent neural network),结构递归神经网络(recursive neural network...

5864
来自专栏和蔼的张星的图像处理专栏

SAMF

论文:paper 结合了CN和KCF的多尺度扩展,看文章之前就听说很暴力,看了以后才发现原来这么暴力。 论文的前一半讲KCF,后一半讲做的实验,中间一点点大...

1542
来自专栏用户2442861的专栏

计算图像相似度——《Python也可以》之一

声明:本文最初发表于赖勇浩(恋花蝶)的博客http://blog.csdn.NET/lanphaday,如蒙转载,敬请确保全文完整,未经同意,不得用于商业用途。

9312
来自专栏企鹅号快讯

深入机器学习系列7-Random Forest

1 Bagging   采用自助采样法()采样数据。给定包含个样本的数据集,我们先随机取出一个样本放入采样集中,再把该样本放回初始数据集,使得下次采样时,样本仍...

4776
来自专栏Python中文社区

支持向量机原理推导(二)

專 欄 ❈ exploit,Python中文社区专栏作者。希望与作者交流或者对文章有任何疑问的可以与作者联系: Email: 15735640998@163....

1965

用Python的长短期记忆神经网络进行时间序列预测

长短期记忆递归神经网络具有学习长的观察序列的潜力。

2.7K8
来自专栏AILearning

卷积神经网络

注意:本教程面向TensorFlow 的高级用户,并承担机器学习方面的专业知识和经验。 概观 CIFAR-10分类是机器学习中常见的基准问题。问题是将R...

20810
来自专栏大数据挖掘DT机器学习

如何用TensorFlow和TF-Slim实现图像标注、分类与分割

本文github源码地址: 在公众号 datadw 里 回复 图像 即可获取。 笔者将和大家分享一个结合了TensorFlow和slim库的小应用,来实现...

5854
来自专栏AI研习社

YOLO 升级到 v3 版,速度相比 RetinaNet 快 3.8 倍

雷锋网 AI 研习社按,YOLO 是一种非常流行的目标检测算法,速度快且结构简单。日前,YOLO 作者推出 YOLOv3 版,在 Titan X 上训练时,在 ...

1313

扫码关注云+社区