# 【干货】对抗自编码器PyTorch手把手实战系列——对抗自编码器学习笔迹风格

• 将图像输入到Encoder, 得到隐变量

• 将隐变量和标签 组合成一个新变量, 然后输入到Decoder中

• 这样, Encoder专心于学习字体, Decoder会根据学习出的字体表示加上笔迹内容 去生成笔迹

Decoder的输入维度扩充n_labels:

```# p(x|z)
class P_net(nn.Module):
def __init__(self,X_dim,N,z_dim,n_labels):
super(P_net, self).__init__()
self.lin1 = nn.Linear(z_dim + n_labels, N)
self.lin2 = nn.Linear(N, N)
self.lin3 = nn.Linear(N, X_dim)
def forward(self, x):
x = F.dropout(self.lin1(x), p=0.25, training=self.training)
x = F.relu(x)
x = F.dropout(self.lin2(x), p=0.25, training=self.training)
x = self.lin3(x)
return F.sigmoid(x)```

```EPS = 1e-15
# 学习率
gen_lr = 0.0001
reg_lr = 0.00005
# 隐变量的维度
z_red_dims = 15
# 类别数
n_labels = 10
# encoder
Q = Q_net(784, 1000, z_red_dims).cuda()
# decoder
P = P_net(784, 1000, z_red_dims, n_labels).cuda()
# discriminator
D_gauss = D_net_gauss(500, z_red_dims).cuda()

# encode/decode 优化器
# GAN部分优化器

# 数据迭代器
total_step = 50000

for step in range(total_step):

if (step + 1) % iter_per_epoch == 0:

# 从MNSIT数据集中拿样本
images, labels = next(data_iter)
images = to_var(images.view(images.size(0), -1))
y = torch.zeros(batch_size, n_labels).scatter_(1, labels.unsqueeze(1), 1)
y = to_var(y)

# 把这三个模型的累积梯度清空
################ Autoencoder部分 ######################
# encoder 编码x, 生成z
z_sample = Q(images)
# decoder 解码z和y, 生成x'
X_sample = P(torch.cat((z_sample, y), dim=1))
# 这里计算下autoencoder 的重建误差|x' - x|
recon_loss = F.binary_cross_entropy(X_sample + EPS, images + EPS)

# 优化autoencoder
recon_loss.backward()
optim_P.step()
optim_Q_enc.step()

################ GAN 部分 #############################

# 从正太分布中, 采样real gauss(真-高斯分布样本点)
z_real_gauss = V(randn(images.size()[0], z_red_dims) * 5.).cuda()
# 判别器判别一下真的样本, 得到loss
D_real_gauss = D_gauss(z_real_gauss)

# 用encoder 生成假样本
Q.eval()  # 切到测试形态, 这时候, Q(即encoder)不参与优化
z_fake_gauss = Q(images)
# 用判别器判别假样本, 得到loss
D_fake_gauss = D_gauss(z_fake_gauss)

# 判别器总误差
D_loss = -mean(log(D_real_gauss + EPS) + log(1 - D_fake_gauss + EPS))

# 优化判别器
D_loss.backward()
optim_D.step()

# encoder充当生成器
Q.train()  # 切换训练形态, Q(即encoder)参与优化
z_fake_gauss = Q(images)
D_fake_gauss = D_gauss(z_fake_gauss)

G_loss = -mean(log(D_fake_gauss + EPS))

G_loss.backward()
# 仅优化Q
optim_Q_gen.step()

# 训练结束后, 存一下encoder的参数
torch.save(Q.state_dict(), 'Q_encoder_weights.pt')```

-END-

0 条评论

## 相关文章

30930

20420

### 使用Python进行描述性统计

【目录】 1 描述性统计是什么？ 2 使用NumPy和SciPy进行数值分析 　　2.1 基本概念 　　2.2 中心位置（均值、中位数、众数） 　　2.3 发散...

60570

### 【SIGGRAPH Asia 2017 论文选读】基于图片风格特征的画家代表作选取

【导读】第十届ACM SIGGRAPH Asia亚洲电脑图形及互动技术展览会将于今年11月27日至30日，在泰国的首都－曼谷隆重举行。本篇选取文章来自我们课题组...

30940

41560

8220

45430

### 深入浅出LSTM神经网络

【编者按】使用前馈卷积神经网络（convnets）来解决计算机视觉问题，是深度学习最广为人知的成果，但少数公众的注意力已经投入到使用递归神经网络来对时间关系进行...

24090

23320

7.7K40