【干货】对抗自编码器PyTorch手把手实战系列——对抗自编码器学习笔迹风格

即使是非计算机行业, 大家也知道很多有名的神经网络结构, 比如CNN在处理图像上非常厉害, RNN能够建模序列数据. 然而CNN, RNN之类的神经网络结构本身, 并不能用于执行比如图像的内容和风格分离, 生成一个逼真的图片, 用少量的label信息来分类图像, 或者做数据压缩等任务. 因为上述几个任务, 都需要特殊的网络结构和训练算法 .

有没有一个网络结构, 能够把上述任务全搞定呢? 显然是有的, 那就是对抗自编码器Adversarial Autoencoder(AAE) . 在本文中, 我们将构建一个AAE, 从MNIST数据集中学习里面的笔迹, 然后给定任意的内容, 去生成这个字体的图像。

本系列文章, 专知小组成员Huaiwen一共分成四篇讲解,这是第三篇:

每一个人都有自己独特的笔迹风格(或者说字体), 我们写字时的力度, 笔锋, 甚至我们遣词造句的习惯都会反映在字体上. 因此伪造一个人的字体是一个很难的事情.

本文, 我们尝试从MNIST数据集中学习里面的笔迹, 然后给定任意的内容, 去生成这个字体的图像.

首先让我们先搞清楚一张笔迹的字体和内容分别是什么:

上图中的所有的笔迹, 内容都是Autoencoder, 但是字体多变, 比如: Segoe script, Courier New等等.

将字体和内容分割出来, 是表示学习的一个重要内容。

我们用的Adversarial Autoencoder一直都是以无监督的方式训练的. 在本文中, 为了让AAE专注于学习字体的表示, 而减轻内容的学习, 我们将图片中的标签加入进去.

我们设计了如下的AAE架构:

注意, 除了直接用隐层编码之外, 我们将标签信息 也加入了进来, 在这里是一个one-hot向量, 因为是MNIST数据集, 所以是十维的.

训练上述模型, 需要在优化重建误差的时候稍稍做一些修改:

• 将图像输入到Encoder, 得到隐变量

• 将隐变量和标签 组合成一个新变量, 然后输入到Decoder中

• 这样, Encoder专心于学习字体, Decoder会根据学习出的字体表示加上笔迹内容 去生成笔迹

其他部分跟上一篇文章一样. 唯一的不同是Decoder的输入变成下图了:

基于上一篇文章, 我们要改的地方有两个:

Decoder的输入维度扩充n_labels:

# p(x|z)
class P_net(nn.Module):
    def __init__(self,X_dim,N,z_dim,n_labels):
        super(P_net, self).__init__()
        self.lin1 = nn.Linear(z_dim + n_labels, N)
        self.lin2 = nn.Linear(N, N)
        self.lin3 = nn.Linear(N, X_dim)
    def forward(self, x):
        x = F.dropout(self.lin1(x), p=0.25, training=self.training)
        x = F.relu(x)
        x = F.dropout(self.lin2(x), p=0.25, training=self.training)
        x = self.lin3(x)
        return F.sigmoid(x)

训练过程要把labels加进去

EPS = 1e-15
# 学习率
gen_lr = 0.0001
reg_lr = 0.00005
# 隐变量的维度
z_red_dims = 15
# 类别数
n_labels = 10
# encoder
Q = Q_net(784, 1000, z_red_dims).cuda()
# decoder
P = P_net(784, 1000, z_red_dims, n_labels).cuda()
# discriminator
D_gauss = D_net_gauss(500, z_red_dims).cuda()

# encode/decode 优化器
optim_P = torch.optim.Adam(P.parameters(), lr=gen_lr)
optim_Q_enc = torch.optim.Adam(Q.parameters(), lr=gen_lr)
# GAN部分优化器
optim_Q_gen = torch.optim.Adam(Q.parameters(), lr=reg_lr)
optim_D = torch.optim.Adam(D_gauss.parameters(), lr=reg_lr)

# 数据迭代器
data_iter = iter(data_loader)
iter_per_epoch = len(data_loader)
total_step = 50000

for step in range(total_step):

    if (step + 1) % iter_per_epoch == 0:
        data_iter = iter(data_loader)

    # 从MNSIT数据集中拿样本
    images, labels = next(data_iter)
    images = to_var(images.view(images.size(0), -1))
    y = torch.zeros(batch_size, n_labels).scatter_(1, labels.unsqueeze(1), 1)
    y = to_var(y)

    # 把这三个模型的累积梯度清空
    P.zero_grad()
    Q.zero_grad()
    D_gauss.zero_grad()
    ################ Autoencoder部分 ######################
    # encoder 编码x, 生成z
    z_sample = Q(images)
    # decoder 解码z和y, 生成x'
    X_sample = P(torch.cat((z_sample, y), dim=1))
    # 这里计算下autoencoder 的重建误差|x' - x|
    recon_loss = F.binary_cross_entropy(X_sample + EPS, images + EPS)

    # 优化autoencoder
    recon_loss.backward()
    optim_P.step()
    optim_Q_enc.step()

    ################ GAN 部分 #############################

    # 从正太分布中, 采样real gauss(真-高斯分布样本点)
    z_real_gauss = V(randn(images.size()[0], z_red_dims) * 5.).cuda()
    # 判别器判别一下真的样本, 得到loss
    D_real_gauss = D_gauss(z_real_gauss)

    # 用encoder 生成假样本
    Q.eval()  # 切到测试形态, 这时候, Q(即encoder)不参与优化
    z_fake_gauss = Q(images)
    # 用判别器判别假样本, 得到loss
    D_fake_gauss = D_gauss(z_fake_gauss)

    # 判别器总误差
    D_loss = -mean(log(D_real_gauss + EPS) + log(1 - D_fake_gauss + EPS))

    # 优化判别器
    D_loss.backward()
    optim_D.step()

    # encoder充当生成器
    Q.train()  # 切换训练形态, Q(即encoder)参与优化
    z_fake_gauss = Q(images)
    D_fake_gauss = D_gauss(z_fake_gauss)

    G_loss = -mean(log(D_fake_gauss + EPS))

    G_loss.backward()
    # 仅优化Q
    optim_Q_gen.step()

# 训练结束后, 存一下encoder的参数
torch.save(Q.state_dict(), 'Q_encoder_weights.pt')

训练完之后, 我们可以随机一个, 让后让 从0到9变化, 可以生成如下的图:

横轴是 从0到9, 纵轴是一个内容的不同字体, 是不是很有意思, 自己动手试一下吧.

-END-

原文发布于微信公众号 - 专知(Quan_Zhuanzhi)

原文发表时间:2018-04-17

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

发表于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏人工智能头条

只知道GAN你就OUT了——VAE背后的哲学思想及数学原理

30930
来自专栏木东居士的专栏

【Andrew Ng 机器学习】Week 1(一):基本概念和线代基础

20420
来自专栏PPV课数据科学社区

使用Python进行描述性统计

【目录】 1 描述性统计是什么? 2 使用NumPy和SciPy进行数值分析   2.1 基本概念   2.2 中心位置(均值、中位数、众数)   2.3 发散...

60570
来自专栏专知

【SIGGRAPH Asia 2017 论文选读】基于图片风格特征的画家代表作选取

【导读】第十届ACM SIGGRAPH Asia亚洲电脑图形及互动技术展览会将于今年11月27日至30日,在泰国的首都-曼谷隆重举行。本篇选取文章来自我们课题组...

30940
来自专栏Vamei实验室

概率论12 矩与矩生成函数

作者:Vamei 出处:http://www.cnblogs.com/vamei 欢迎转载,也请保留这段声明。谢谢! 我们重新回到对单随机变量分布的研究。描述...

41560
来自专栏自然语言处理

程序员眼中的统计学1

在大学时期我打了三年的暑假工,我的工作就是检测数码产品的屏幕的油墨厚度,每天都有好几万的数据报表,面对着这大量的待处理数据,我们就利用了图表来进行统计。

8220
来自专栏CVer

【CVPR 2018 行人重识别】相关论文介绍(下)

上次Amusi很荣幸邀请linolzhang大佬分享关于深度学习的行人重识别方向相关论文整理的文章。由于篇幅有限,上次只介绍了部分论文内容和文献,本文将继续介绍...

45430
来自专栏CSDN技术头条

深入浅出LSTM神经网络

【编者按】使用前馈卷积神经网络(convnets)来解决计算机视觉问题,是深度学习最广为人知的成果,但少数公众的注意力已经投入到使用递归神经网络来对时间关系进行...

24090
来自专栏Java与Android技术栈

基于边缘保留滤波实现人脸磨皮的算法

快速边缘保留滤波是通过积分图像实现局部均方差的边缘保留模糊算法,计算简单而且可以做到计算量跟半径无关。 首先局部均方差滤波中计算局部均值的公式如下:

23320
来自专栏腾讯音视频实验室

带宽节省利器——帧率上采样

目前大多数人都关注点都在超分辨率技术上,为何不Pick一下帧率上采样呢?

7.7K40

扫码关注云+社区

领取腾讯云代金券