专栏首页贾志刚-OpenCV学堂轻松学Pytorch – 构建生成对抗网络

轻松学Pytorch – 构建生成对抗网络

引言

又好久没有继续写了,这个是我写的第21篇文章,我还在继续坚持写下去,虽然经常各种拖延症,但是我还记得,一直没有敢忘记!今天给大家分享一下Pytorch生成对抗网络代码实现。

01

什么是生成对抗网络

Ian J. Goodfellow在2014年提出生成对抗网络,从此打开了深度学习中另外一个重要分支,让生成对抗网络(GAN)成为与卷积神经网络(CNN)、循环神经网络(RNN/LSTM)可以并驾齐驱的分支领域。今天GAN仍然是计算机视觉领域研究热点之一,每年还有大量相关的论文产生,GAN已经被用在视觉任务的很多方面,主要包括:

  • 图像合成与数据增广
  • 图像翻译与变换
  • 缺陷检测
  • 图像去噪与重建
  • 图像分割

但是GAN最基本的核心思想还是2014年Ian J. Goodfellow在论文中提到的两个基本的模型分别是:生成器与判别器

生成器(G):

根据输入噪声Z生成输出样本G(z)
目标:通过生成样本与目标样本分布一致,成功欺骗鉴别器

判别器(D):

根据输入样本数据来分辨真实样本概率
从数据中学习样本数据的差异性

从a到d,可以看到输入噪声的生成分布越来越接近真实分布X,最终达到一种平衡状态,这种稳定的平衡状态叫纳什均衡,还有一部电影跟这个有关系叫《美丽心灵》。

02

GAN代码实现

下面的代码实现了基于Mnist数据集实现判别器与生成器,最终通过生成器可以自动生成手写数字识别的图像,输入的z=100是随机噪声,输出的是784个数据表示28x28大小的手写数字样本,损失主要来自两个部分,生成器生成损失,判别器分别判别真实与虚构样本概率,基于反向传播训练两个网络,设置epoch=100,得到最终的生成器生成结果如下:

生成器与判别器代码实现如下

HARR特征级联分类器人脸检测来自VJ的2004论文中提出,其主要思想可以通过下面一张图像解释:

transform = tv.transforms.Compose([tv.transforms.ToTensor(),
                                   tv.transforms.Normalize((0.5,), (0.5,))])
train_ts = tv.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_ts = tv.datasets.MNIST(root='./data', train=False, download=True, transform=transform)
train_dl = DataLoader(train_ts, batch_size=128, shuffle=True, drop_last=False)
test_dl = DataLoader(test_ts, batch_size=128, shuffle=True, drop_last=False)


class Generator(t.nn.Module):
    def __init__(self, g_input_dim, g_output_dim):
        super(Generator, self).__init__()
        self.fc1 = t.nn.Linear(g_input_dim, 256)
        self.fc2 = t.nn.Linear(self.fc1.out_features, self.fc1.out_features * 2)
        self.fc3 = t.nn.Linear(self.fc2.out_features, self.fc2.out_features * 2)
        self.fc4 = t.nn.Linear(self.fc3.out_features, g_output_dim)

    # forward method
    def forward(self, x):
        x = F.leaky_relu(self.fc1(x), 0.2)
        x = F.leaky_relu(self.fc2(x), 0.2)
        x = F.leaky_relu(self.fc3(x), 0.2)
        return t.tanh(self.fc4(x))


class Discriminator(t.nn.Module):
    def __init__(self, d_input_dim):
        super(Discriminator, self).__init__()
        self.fc1 = t.nn.Linear(d_input_dim, 1024)
        self.fc2 = t.nn.Linear(self.fc1.out_features, self.fc1.out_features // 2)
        self.fc3 = t.nn.Linear(self.fc2.out_features, self.fc2.out_features // 2)
        self.fc4 = t.nn.Linear(self.fc3.out_features, 1)

    # forward method
    def forward(self, x):
        x = F.leaky_relu(self.fc1(x), 0.2)
        x = F.dropout(x, 0.3)
        x = F.leaky_relu(self.fc2(x), 0.2)
        x = F.dropout(x, 0.3)
        x = F.leaky_relu(self.fc3(x), 0.2)
        x = F.dropout(x, 0.3)
        return t.sigmoid(self.fc4(x))

损失与训练代码如下

分别定义生成网络训练与鉴别网络的训练方法,然后开始训练即可,代码实现如下:

# 生成者与判别者
bs = 128
z_dim = 100
mnist_dim = 784
# loss
criterion = t.nn.BCELoss()

# optimizer
device = "cuda"
gnet = Generator(g_input_dim = z_dim, g_output_dim = mnist_dim).to(device)
dnet = Discriminator(mnist_dim).to(device)
lr = 0.0002
G_optimizer = t.optim.Adam(gnet.parameters(), lr=lr)
D_optimizer = t.optim.Adam(dnet.parameters(), lr=lr)


def D_train(x):
    # =======================Train the discriminator=======================#
    dnet.zero_grad()

    # train discriminator on real
    x_real, y_real = x.view(-1, mnist_dim), t.ones(bs, 1)
    x_real, y_real = Variable(x_real.to(device)), Variable(y_real.to(device))

    D_output = dnet(x_real)
    D_real_loss = criterion(D_output, y_real)

    # train discriminator on facke
    z = Variable(t.randn(bs, z_dim).to(device))
    x_fake, y_fake = gnet(z), Variable(t.zeros(bs, 1).to(device))

    D_output = dnet(x_fake)
    D_fake_loss = criterion(D_output, y_fake)

    # gradient backprop & optimize ONLY D's parameters
    D_loss = D_real_loss + D_fake_loss
    D_loss.backward()
    D_optimizer.step()

    return D_loss.data.item()


def G_train(x):
    # =======================Train the generator=======================#
    gnet.zero_grad()

    z = Variable(t.randn(bs, z_dim).to(device))
    y = Variable(t.ones(bs, 1).to(device))

    G_output = gnet(z)
    D_output = dnet(G_output)
    G_loss = criterion(D_output, y)

    # gradient backprop & optimize ONLY G's parameters
    G_loss.backward()
    G_optimizer.step()

    return G_loss.data.item()


n_epoch = 100
for epoch in range(1, n_epoch+1):
    D_losses, G_losses = [], []
    for batch_idx, (x, _) in enumerate(train_dl):
        bs_, _,_,_ = x.size()
        bs = bs_
        D_losses.append(D_train(x))
        G_losses.append(G_train(x))

    print('[%d/%d]: loss_d: %.3f, loss_g: %.3f' % (
            (epoch), n_epoch, t.mean(t.FloatTensor(D_losses)), t.mean(t.FloatTensor(G_losses))))

本文分享自微信公众号 - OpenCV学堂(CVSCHOOL),作者:gloomyfish

原文出处及转载信息见文内详细说明,如有侵权,请联系 yunjia_community@tencent.com 删除。

原始发表时间:2020-08-27

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

我来说两句

0 条评论
登录 后参与评论

相关文章

  • 【项目实践】中英文文字检测与识别项目(CTPN+CRNN+CTC Loss原理讲解)

    文字识别也是图像领域一个常见问题。然而,对于自然场景图像,首先要定位图像中的文字位置,然后才能进行文字的识别。

    OpenCV学堂
  • 人脸识别损失函数的汇总 | Pytorch版本实现

    这篇文章的重点不在于讲解FR的各种Loss,因为知乎上已经有很多,搜一下就好,本文主要提供了各种Loss的Pytorch实现以及Mnist的可视化实验,一方面让...

    OpenCV学堂
  • 【项目实践】YOLO V4万字原理详细讲解并训练自己的数据集(pytorch完整项目打包下载)

    YOLOV4是YOLOV3的改进版,在YOLOV3的基础上结合了非常多的小Tricks。尽管没有目标检测上革命性的改变,但是YOLOV4依然很好...

    OpenCV学堂
  • python爬虫之验证码识别

    版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。 ...

    云雀叫了一整天
  • Python解析XML字符串

    py3study
  • Python导出Excel图表

    Python自动化办公的过程,部分涉及到导出Excel图表;本篇主要讲下使用python代码将excel中的图表导出为图片的开发过程;

    py3study
  • 验证码

    版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。 ...

    云雀叫了一整天
  • 教你制作可移动的导航栏

    Dwyane
  • 使用PyQt5实现图片查看器的示例代码

    在学习 PyQt5 的过程中我会不断地做一些小的 Demo,用于让自己能够更好地理解和学习,这次要做的就是一个图片查看器,主要功能包括打开图片、拖动图片、放大和...

    砸漏
  • 使用ReactiveCocoa开发RSS阅读器

    目前已经完成的功能有对RSS的解析和Atom解析,RSS内容本地数据库存储和读取,抓取中状态进度展示,标记阅读状态,标记全部已读等。这些功能里我对一些异步操作产...

    用户7451029

扫码关注云+社区

领取腾讯云代金券