【专知-PyTorch手把手深度学习教程04】GAN快速理解与PyTorch实现: 图文+代码

【导读】主题链路知识是我们专知的核心功能之一,为用户提供AI领域系统性的知识学习服务,一站式学习人工智能的知识,包含人工智能( 机器学习、自然语言处理、计算机视觉等)、大数据、编程语言、系统架构。使用请访问专知 进行主题搜索查看 - 桌面电脑访问www.zhuanzhi.ai, 手机端访问www.zhuanzhi.ai 或关注微信公众号后台回复" 专知"进入专知,搜索主题查看。值国庆佳节,专知特别推出独家特刊-来自中科院自动化所专知小组博士生huaiwen和Jin创作的-PyTorch教程学习系列, 今日带来第四篇-< 快速理解系列(三): 图文+代码, 让你快速理解GAN >

  1. < 一文带你入门优雅的Pytorch >
  2. < 快速理解系列(一): 图文+代码, 让你快速理解CNN>
  3. < 快速理解系列(二): 图文+代码, 让你快速理解LSTM>
  4. < 快速理解系列(三): 图文+代码, 让你快速理解GAN >
  5. < 快速理解系列(四): 图文+代码, 让你快速理解Dropout >
  6. < NLP系列(一) 用Pytorch 实现 Word Embedding >
  7. < NLP系列(二) 基于字符级RNN的姓名分类 >
  8. < NLP系列(三) 基于字符级RNN的姓名生成 >

生成对抗网络 GAN

生成模型通过训练大量数据, 学习自身模型, 最后通过自身模型产生逼近真实分布的模拟分布. 用这个宝贵的”分布”生成新的数据. 因此, 判别模型的目标是得到关于 y 的分布 P(y|X), 而生成模型的侧重是得到关于X分布 P(y, X) 或 P(x|y)P(y). 即, 判别模型的目标是给定一张图片, 请告诉我这是”长颈鹿”还是”斑马”, 而, 生成模型的目标是告诉你词语: “长颈鹿”, 请生成一张画有”长颈鹿”的图片吧~ 下面这张图片来自slideshare 可以说明问题:

来自: http://www.slideshare.net/shaochuan/spatially-coherent-latent-topic-model-for-concurrent-object

所以, 生成模型可以从大量数据中生成你从未见过的, 但是符合条件的样本.

难怪, 我们可以调教神经网络, 让他的画风和梵高一样. 最后输入一张图片, 它会输出模拟梵高画风的这张图片的油画.

言归正传, 为啥对抗网络在生成模型中受到追捧 ? 生成对抗网络最近为啥这么火 , 到底好在哪里?

那就必须谈到生成对抗网络和一般生成模型的区别了.

一般的生成模型, 必须先初始化一个“假设分布”,即后验分布, 通过各种抽样方法抽样这个后验分布,就能知道这个分布与真实分布之间究竟有多大差异。这里的差异就要通过构造损失函数(loss function)来估算。知道了这个差异后,就能不断调优一开始的“假设分布”,不断逼近真实分布。限制玻尔兹曼机(RBM)就是这种生成模型的一种.

正如”对抗样本与生成式对抗网络“一文所说的: 传统神经网络需要一个人类科学家精心打造的损失函数。但是,对于生成模型这样复杂的过程来说,构建一个好的损失函数绝非易事。这就是对抗网络的闪光之处。对抗网络可以学习自己的损失函数——自己那套复杂的对错规则——无须精心设计和建构一个损失函数:

来自:http://www.slideshare.net/xavigiro/deep-learning-for-computer-vision-generative-models-and-adversarial-training-upc-2016

生成对抗网络同时训练两个模型, 叫做生成器(Generator 图中蓝色框)和判断器(Discriminator 图中红色框). 生成器竭尽全力模仿真实分布生成数据; 判断器竭尽全力区分出真实样本和生成器生成的模仿样本. 直到判断器无法区分出真实样本和模仿样本为止.

通过这种方式, 损失函数被蕴含在判断器中了. 我们不再需要思考损失函数应该如何设定, 只要关注判断器输出损失就可以了.

论文”Generative Adversarial Nets”中的训练过程, 生成器和判别器的各自表现

上图是生成对抗网络的训练过程, 可以看到生成器和判别器的各自表现. 其中, 黑色虚线的分布是真实分布, 绿色线的是生成器的分布, 蓝色虚线是判别器的判定分布. 两条水平线代表了两个分布的样本空间的映射.

(a)图中真实分布和生成器的分布比较接近, 但是判定器很容易区分出二者生成的样本. (b)图中判定器又经过训练加强判断, 注意判定分布. (c)图是生成器调整分布, 更好地欺骗判定器. (d)图是不断优化, 直到生成器非常逼近真实分布, 而且判定器无法区分.

下图是Ian J. Goodfellow等人论文中在MNIST和TFD数据上训练出的对抗模型生成的样本:

最右边一列是真实数据集中最接近的邻居样本, 证明生成模型的有效性. 生成右边导数第二列和真实样本非常接近, 但是确是对抗网络随机生成的图片. 可见, 对抗网络对于随机生成一些图片干扰很在行, 这些干扰并不影响人造样本和真实样本的相似性.

下面我们看看如何用Pytorch实现GAN生成MNIST:

import torch
import torch.nn as nn
from torchvision import datasets
from torchvision import transforms
from torchvision.utils import save_image
from torch.autograd import Variable

def get_variable(x):
    x = Variable(x)
    return x.cuda() if torch.cuda.is_available() else x

def denorm(x):
    out = (x + 1) / 2
    return out.clamp(0, 1)

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5),
                         std=(0.5, 0.5, 0.5))])
mnist = datasets.MNIST(root='./mnist/',
                       train=True,
                       transform=transform,
                       download=True)
data_loader = torch.utils.data.DataLoader(dataset=mnist,
                                          batch_size=100,
                                          shuffle=True)
# 判别器
D = nn.Sequential(
    nn.Linear(784, 256),
    nn.LeakyReLU(0.2),
    nn.Linear(256, 256),
    nn.LeakyReLU(0.2),
    nn.Linear(256, 1),
    nn.Sigmoid())

# 生成器
G = nn.Sequential(
    nn.Linear(64, 256),
    nn.LeakyReLU(0.2),
    nn.Linear(256, 256),
    nn.LeakyReLU(0.2),
    nn.Linear(256, 784),
    nn.Tanh())

if torch.cuda.is_available():
    D.cuda()
    G.cuda()

loss_func = nn.BCELoss()
d_optimizer = torch.optim.Adam(D.parameters(), lr=0.0003)
g_optimizer = torch.optim.Adam(G.parameters(), lr=0.0003)

for epoch in range(200):
    for i, (images, _) in enumerate(data_loader):

        batch_size = images.size(0)
        # reshape 成 (batch_size, 28*28)
        images = get_variable(images.view(batch_size, -1))
        real_labels = get_variable(torch.ones(batch_size))  # 真实数据 label 为1
        fake_labels = get_variable(torch.zeros(batch_size))  # 假数据 label 为0

        # ============= Train the discriminator =============#
        # 判别真实数据,计算损失
        outputs = D(images)
        d_loss_real = loss_func(outputs, real_labels)
        real_score = outputs

        # 生成假数据
        z = get_variable(torch.randn(batch_size, 64))
        fake_images = G(z)

        # 判别生成的数据,计算损失
        outputs = D(fake_images)
        d_loss_fake = loss_func(outputs, fake_labels)
        fake_score = outputs

        # 优化判别器
        d_loss = d_loss_real + d_loss_fake
        D.zero_grad()
        d_loss.backward()
        d_optimizer.step()

        # =============== Train the generator ===============#
        # 生成假数据
        z = get_variable(torch.randn(batch_size, 64))
        fake_images = G(z)
        # 用判别器计算损失
        outputs = D(fake_images)
        g_loss = loss_func(outputs, real_labels)

        # 优化生成器
        D.zero_grad()
        G.zero_grad()
        g_loss.backward()
        g_optimizer.step()

        if (i + 1) % 300 == 0:
            print('Epoch [%d/%d], Step[%d/%d], d_loss: %.4f, '
                  'g_loss: %.4f, 真实数据平均得分: %.2f, 假数据平均得分: %.2f'
                  % (epoch, 200, i + 1, 600, d_loss.data[0], g_loss.data[0],
                     real_score.data.mean(), fake_score.data.mean()))

    # 保存一下真实数据
    if (epoch + 1) == 1:
        images = images.view(images.size(0), 1, 28, 28)
        save_image(denorm(images.data), './mnist/real_images.png')

    # 保存生成数据
    fake_images = fake_images.view(fake_images.size(0), 1, 28, 28)
    save_image(denorm(fake_images.data), './mnist/fake_images-%d.png' % (epoch + 1))

# 保存模型参数
torch.save(G.state_dict(), './generator.pkl')
torch.save(D.state_dict(), './discriminator.pkl')

Reference:

#9-生成对抗网络101-终极入门-通俗解析 http://nooverfit.com/wp/9-生成对抗网络101-终极入门-通俗解析 作者: david 9

明天继续推出:专知PyTorch深度学习教程系列-< 快速理解系列(四): 图文+代码, 让你快速理解Dropout >,敬请关注。

原文发布于微信公众号 - 专知(Quan_Zhuanzhi)

原文发表时间:2017-10-05

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

发表于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏Echo is learning

machine learning 之 Neural Network 3

1385
来自专栏AI研习社

深度神经网络中的对抗样本与学习

概述 最近大半年,人工智能领域成为科技领域提到的最多的名词之一。在kdnuggets此前发布的文章 (Deep Learning’s Deep Flaws)...

38512
来自专栏数据派THU

卷积神经网络失陷,CoordConv来填坑(附代码&视频)

本文讲述了卷积神经网络在涉及坐标建模等方面的缺陷,但是提出了CoordConv 作为解决方案。

862
来自专栏机器学习算法工程师

不想累死就来看看 : 特征工程之特征选择

地址:https://www.cnblogs.com/pinard/p/9032759.html

1152
来自专栏专知

【干货】计算机视觉实战系列07——用Python做图像处理

这一次继续为大家详细讲解SciPy库的使用以及图像导数实战。

6069
来自专栏机器学习算法与Python学习

深度卷积生成对抗网络的无监督学习,补全人脸合成图像匹敌真实照

关键字全网搜索最新排名 【机器学习算法】:排名第一 【机器学习】:排名第二 【Python】:排名第三 【算法】:排名第四 我们会再接再厉 成为全网优质的技术类...

4606
来自专栏人工智能LeadAI

二、机器学习面试之有必要手推SVM吗?

上篇文章中,我们介绍了SVM的基本思想,并将其推导成了一个数学问题,今天的任务,就是解决如何求解这个数学问题,同时,回答上篇文章中提出的第二个问题: 如果将正负...

4066
来自专栏新智元

【重磅】计算机视觉和 CNN 发展十一座里程碑(附论文下载)

【新智元导读】从AlexNet到ResNet,计算机视觉领域和卷积神经网络(CNN)每一次发展,都伴随着代表性架构取得历史性的成绩。作者回顾计算机视觉和CNN过...

4474
来自专栏智能算法

长文干货!走近人脸检测:从VJ到深度学习(下)

深度学习给目标检测带来的变革 人脸检测作为一种特定类型目标的检测任务,一方面具有其自己鲜明的特点,需要考虑人脸这一目标的特殊性,另一方面其也和其它类型...

3968
来自专栏机器之心

卷积神经网络「失陷」,CoordConv来填坑

Uber 利用卷积神经网络解决许多方面的问题,其中就包括坐标变换。从自动驾驶汽车的设计到路标的自动检测、再到地图的绘制,Uber 都需要使用卷积网络完成坐标变换...

1393

扫码关注云+社区

领取腾讯云代金券