# 【干货】用极少量样本有效的训练分类器-对抗自编码器PyTorch手把手实战系列

### 具体代码

Encoder 现在既要生成 y（10 维的 one-hot），又要生成 z（维度为 z_dim）：

```# Encoder
class Q_net(nn.Module):
def __init__(self):
super(Q_net, self).__init__()
self.lin1 = nn.Linear(X_dim, N)
self.lin2 = nn.Linear(N, N)
# 隐层变量 z
self.lin3gauss = nn.Linear(N, z_dim)
# 类别label y
self.lin3cat = nn.Linear(N, n_classes)

def forward(self, x):
x = F.dropout(self.lin1(x), p=0.25, training=self.training)
x = F.relu(x)
x = F.dropout(self.lin2(x), p=0.25, training=self.training)
x = F.relu(x)
xgauss = self.lin3gauss(x)
xcat = F.softmax(self.lin3cat(x))

return xcat, xgauss```

```# Decoder
class P_net(nn.Module):
def __init__(self):
super(P_net, self).__init__()
self.lin1 = nn.Linear(z_dim + n_classes, N)
self.lin2 = nn.Linear(N, N)
self.lin3 = nn.Linear(N, X_dim)

def forward(self, x):
x = self.lin1(x)
x = F.dropout(x, p=0.25, training=self.training)
x = F.relu(x)
x = self.lin2(x)
x = F.dropout(x, p=0.25, training=self.training)
x = self.lin3(x)
return F.sigmoid(x)```

```# 类别y 的判别器
class D_net_cat(nn.Module):
def __init__(self):
super(D_net_cat, self).__init__()
self.lin1 = nn.Linear(n_classes, N)
self.lin2 = nn.Linear(N, N)
self.lin3 = nn.Linear(N, 1)

def forward(self, x):
x = self.lin1(x)
x = F.relu(x)
x = F.dropout(x, p=0.2, training=self.training)
x = self.lin2(x)
x = F.relu(x)
x = self.lin3(x)
return F.sigmoid(x)

# 隐含向量z 的判别器
class D_net_gauss(nn.Module):
def __init__(self):
super(D_net_gauss, self).__init__()
self.lin1 = nn.Linear(z_dim, N)
self.lin2 = nn.Linear(N, N)
self.lin3 = nn.Linear(N, 1)

def forward(self, x):
x = F.dropout(self.lin1(x), p=0.2, training=self.training)
x = F.relu(x)
x = F.dropout(self.lin2(x), p=0.2, training=self.training)
x = F.relu(x)

return F.sigmoid(self.lin3(x))```

```if not labeled:
z_sample = torch.cat(Q(X), 1)
X_sample = P(z_sample)

#######################
# autoencoder重建
#######################
recon_loss = F.binary_cross_entropy(X_sample + TINY,
X.resize(train_batch_size, X_dim) + TINY)
recon_loss = recon_loss
recon_loss.backward()
P_decoder.step()
Q_encoder.step()

recon_loss = recon_loss
#######################
# y， z的生成
#######################
Q.eval()
z_real_cat = sample_categorical(train_batch_size,
n_classes=n_classes)
z_real_gauss = Variable(torch.randn(train_batch_size, z_dim))
if cuda:
z_real_cat = z_real_cat.cuda()
z_real_gauss = z_real_gauss.cuda()

z_fake_cat, z_fake_gauss = Q(X)

D_real_cat = D_cat(z_real_cat)
D_real_gauss = D_gauss(z_real_gauss)
D_fake_cat = D_cat(z_fake_cat)
D_fake_gauss = D_gauss(z_fake_gauss)

D_loss_cat = -torch.mean(torch.log(D_real_cat + TINY) +
torch.log(1 - D_fake_cat + TINY))
D_loss_gauss = -torch.mean(torch.log(D_real_gauss + TINY) +
torch.log(1 - D_fake_gauss + TINY))

D_loss = D_loss_cat + D_loss_gauss
D_loss = D_loss

D_loss.backward()
D_cat_solver.step()
D_gauss_solver.step()

# Generator
Q.train()
z_fake_cat, z_fake_gauss = Q(X)

D_fake_cat = D_cat(z_fake_cat)
D_fake_gauss = D_gauss(z_fake_gauss)

G_loss = - torch.mean(torch.log(D_fake_cat + TINY)) -
torch.mean(torch.log(D_fake_gauss + TINY))
G_loss = G_loss
G_loss.backward()
Q_generator.step()

#######################
# 半监督部分，分类器的训练
#######################
if labeled:
pred, _ = Q(X)
class_loss = F.cross_entropy(pred, target)
class_loss.backward()
Q_semi_supervised.step()

完整代码下载：[code.zip](http://p57mvlyrw.bkt.clouddn.com/code.zip)

-END-

0 条评论

• ### 6000星人气深度学习资源！架构模型技巧全都有，图灵奖得主LeCun推荐

有这么一份干货，汇集了机器学习架构和模型的经典知识点，还有各种TensorFlow和PyTorch的Jupyter Notebook笔记资源，地址都在，无需等待...

• ### 6000星人气深度学习资源！架构模型技巧全都有，图灵奖得主LeCun推荐

有这么一份干货，汇集了机器学习架构和模型的经典知识点，还有各种TensorFlow和PyTorch的Jupyter Notebook笔记资源，地址都在，无需等待...

• ### 6000星人气深度学习资源！架构模型技巧全都有，图灵奖得主LeCun推荐

有这么一份干货，汇集了机器学习架构和模型的经典知识点，还有各种TensorFlow和PyTorch的Jupyter Notebook笔记资源，地址都在，无需等待...

• ### 这些资源你肯定需要！超全的GAN PyTorch+Keras实现集合

生成对抗网络及其变体的实现分为基于 Keras 和基于 PyTorch 两个版本。它们都是按照原论文实现的，但模型架构并不一定完全和原论文相同，作者关注于实现这...

• ### 这些资源你肯定需要！超全的GAN PyTorch+Keras实现集合

选自GitHub 作者：eriklindernoren 机器之心编译 参与：刘晓坤、思源、李泽南 生成对抗网络一直是非常美妙且高效的方法，自 14 年 Ian...

• ### 【专知-PyTorch手把手深度学习教程04】GAN快速理解与PyTorch实现: 图文+代码

【导读】主题链路知识是我们专知的核心功能之一，为用户提供AI领域系统性的知识学习服务，一站式学习人工智能的知识，包含人工智能（ 机器学习、自然语言处理、计算机视...

• ### Texar-PyTorch：在PyTorch中集成TensorFlow的最佳特性

Texar-PyTorch 对各类不同的机器学习任务有着广泛的支持，尤其是自然语言处理（NLP）和文本生成任务。