# 轻松学Pytorch – 构建生成对抗网络

01

Ian J. Goodfellow在2014年提出生成对抗网络，从此打开了深度学习中另外一个重要分支，让生成对抗网络（GAN）成为与卷积神经网络（CNN）、循环神经网络（RNN/LSTM）可以并驾齐驱的分支领域。今天GAN仍然是计算机视觉领域研究热点之一，每年还有大量相关的论文产生，GAN已经被用在视觉任务的很多方面，主要包括：

• 图像合成与数据增广
• 图像翻译与变换
• 缺陷检测
• 图像去噪与重建
• 图像分割

```根据输入噪声Z生成输出样本G(z)

```根据输入样本数据来分辨真实样本概率

02

GAN代码实现

HARR特征级联分类器人脸检测来自VJ的2004论文中提出，其主要思想可以通过下面一张图像解释：

```transform = tv.transforms.Compose([tv.transforms.ToTensor(),
tv.transforms.Normalize((0.5,), (0.5,))])
train_dl = DataLoader(train_ts, batch_size=128, shuffle=True, drop_last=False)
test_dl = DataLoader(test_ts, batch_size=128, shuffle=True, drop_last=False)

class Generator(t.nn.Module):
def __init__(self, g_input_dim, g_output_dim):
super(Generator, self).__init__()
self.fc1 = t.nn.Linear(g_input_dim, 256)
self.fc2 = t.nn.Linear(self.fc1.out_features, self.fc1.out_features * 2)
self.fc3 = t.nn.Linear(self.fc2.out_features, self.fc2.out_features * 2)
self.fc4 = t.nn.Linear(self.fc3.out_features, g_output_dim)

# forward method
def forward(self, x):
x = F.leaky_relu(self.fc1(x), 0.2)
x = F.leaky_relu(self.fc2(x), 0.2)
x = F.leaky_relu(self.fc3(x), 0.2)
return t.tanh(self.fc4(x))

class Discriminator(t.nn.Module):
def __init__(self, d_input_dim):
super(Discriminator, self).__init__()
self.fc1 = t.nn.Linear(d_input_dim, 1024)
self.fc2 = t.nn.Linear(self.fc1.out_features, self.fc1.out_features // 2)
self.fc3 = t.nn.Linear(self.fc2.out_features, self.fc2.out_features // 2)
self.fc4 = t.nn.Linear(self.fc3.out_features, 1)

# forward method
def forward(self, x):
x = F.leaky_relu(self.fc1(x), 0.2)
x = F.dropout(x, 0.3)
x = F.leaky_relu(self.fc2(x), 0.2)
x = F.dropout(x, 0.3)
x = F.leaky_relu(self.fc3(x), 0.2)
x = F.dropout(x, 0.3)
return t.sigmoid(self.fc4(x))```

```# 生成者与判别者
bs = 128
z_dim = 100
mnist_dim = 784
# loss
criterion = t.nn.BCELoss()

# optimizer
device = "cuda"
gnet = Generator(g_input_dim = z_dim, g_output_dim = mnist_dim).to(device)
dnet = Discriminator(mnist_dim).to(device)
lr = 0.0002

def D_train(x):
# =======================Train the discriminator=======================#

# train discriminator on real
x_real, y_real = x.view(-1, mnist_dim), t.ones(bs, 1)
x_real, y_real = Variable(x_real.to(device)), Variable(y_real.to(device))

D_output = dnet(x_real)
D_real_loss = criterion(D_output, y_real)

# train discriminator on facke
z = Variable(t.randn(bs, z_dim).to(device))
x_fake, y_fake = gnet(z), Variable(t.zeros(bs, 1).to(device))

D_output = dnet(x_fake)
D_fake_loss = criterion(D_output, y_fake)

# gradient backprop & optimize ONLY D's parameters
D_loss = D_real_loss + D_fake_loss
D_loss.backward()
D_optimizer.step()

return D_loss.data.item()

def G_train(x):
# =======================Train the generator=======================#

z = Variable(t.randn(bs, z_dim).to(device))
y = Variable(t.ones(bs, 1).to(device))

G_output = gnet(z)
D_output = dnet(G_output)
G_loss = criterion(D_output, y)

# gradient backprop & optimize ONLY G's parameters
G_loss.backward()
G_optimizer.step()

return G_loss.data.item()

n_epoch = 100
for epoch in range(1, n_epoch+1):
D_losses, G_losses = [], []
for batch_idx, (x, _) in enumerate(train_dl):
bs_, _,_,_ = x.size()
bs = bs_
D_losses.append(D_train(x))
G_losses.append(G_train(x))

print('[%d/%d]: loss_d: %.3f, loss_g: %.3f' % (
(epoch), n_epoch, t.mean(t.FloatTensor(D_losses)), t.mean(t.FloatTensor(G_losses))))```

0 条评论

• ### 【项目实践】中英文文字检测与识别项目（CTPN+CRNN+CTC Loss原理讲解）

文字识别也是图像领域一个常见问题。然而，对于自然场景图像，首先要定位图像中的文字位置，然后才能进行文字的识别。

• ### 人脸识别损失函数的汇总 | Pytorch版本实现

这篇文章的重点不在于讲解FR的各种Loss，因为知乎上已经有很多，搜一下就好，本文主要提供了各种Loss的Pytorch实现以及Mnist的可视化实验，一方面让...

• ### 【项目实践】YOLO V4万字原理详细讲解并训练自己的数据集（pytorch完整项目打包下载）

YOLOV4是YOLOV3的改进版，在YOLOV3的基础上结合了非常多的小Tricks。尽管没有目标检测上革命性的改变，但是YOLOV4依然很好...

• ### python爬虫之验证码识别

版权声明：本文为博主原创文章，遵循 CC 4.0 BY-SA 版权协议，转载请附上原文出处链接和本声明。 ...

• ### Python导出Excel图表

Python自动化办公的过程，部分涉及到导出Excel图表；本篇主要讲下使用python代码将excel中的图表导出为图片的开发过程；

• ### 验证码

版权声明：本文为博主原创文章，遵循 CC 4.0 BY-SA 版权协议，转载请附上原文出处链接和本声明。 ...

• ### 使用PyQt5实现图片查看器的示例代码

在学习 PyQt5 的过程中我会不断地做一些小的 Demo，用于让自己能够更好地理解和学习，这次要做的就是一个图片查看器，主要功能包括打开图片、拖动图片、放大和...

### OpenCV学堂

OpenCV开发专家计算机视觉开发