前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >【小白学习PyTorch教程】十一、基于MNIST数据集训练第一个生成性对抗网络

【小白学习PyTorch教程】十一、基于MNIST数据集训练第一个生成性对抗网络

作者头像
润森
发布2022-08-18 09:27:57
3070
发布2022-08-18 09:27:57
举报
文章被收录于专栏:毛利学Python毛利学Python

「@Author:Runsen」

GAN 是使用两个神经网络模型训练的生成模型。一种模型称为生成网络模型,它学习生成新的似是而非的样本。另一个模型被称为判别网络,它学习区分生成的例子和真实的例子。

生成性对抗网络

2014,蒙特利尔大学的Ian Goodfellow和他的朋友发明了生成性对抗网络(GAN)。自它出版以来,有许多它的变体和客观功能来解决它的问题

论文在这里找到.

论文提出了两种模型:生成模型和判别模型。两个模型竞争,以产生真实和假的样本。2016年,Yann LeCun将GANs描述为“过去二十年机器学习中最酷的想法”。

GAN 的大部分研究和应用都集中在计算机视觉领域。

其原因是卷积神经网络 (CNN) 等深度学习模型在过去 5 到 7 年中在计算机视觉领域取得了巨大成功,例如在具有挑战性的任务(如对象检测和人脸识别。

GAN 的典型例子是生成新的逼真的照片,最令人吃惊的是生成照片般逼真的人脸的例子。

在本教程中,我们将实现一个简单的GAN生成假的MNIST样本。

代码语言:javascript
复制
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torchvision.utils as utils

import numpy as np
import matplotlib.pyplot as plt
# CPU / GPU Setting
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)  #cuda

使用MNIST数据集,具有最小大小的数据集。

它由60000个训练图像和10000个测试图像组成,每个图像有28*28的大小和一个彩色通道。

代码语言:javascript
复制
# Define a transform 
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean = (0.5, ), std = (0.5, ))
])

# batch_size是一个前向和后向传播过程中的图像数。
batch_size = 100

mnist = datasets.MNIST('./data/MNIST', 
                       download = True, 
                       train = True, 
                       transform = transform)

mnist_loader = DataLoader(dataset = mnist, 
                          batch_size = batch_size, 
                          shuffle = True)
# CPU
def imshow(img, title):
    img = utils.make_grid(img.cpu().detach())
    img = (img+1)/2
    npimg = img.detach().numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.title(title)
    plt.show()
#GPU
def imshow(img, title):
    npimg = img.detach().numpy()
    fig = plt.figure(figsize = (10, 10))
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.title(title)
    plt.show()

images, labels = iter(mnist_loader).next()
imshow(images[0:16, :, :], "MNIST Images")

建立一个GANs模型。一个Generator和Discriminator

GANs由完全连接的层组成。它将从100维高斯分布采样的噪声转换为MNIST图像。鉴别器网络也由完全连接的层组成,用于区分输入数据是真是假。

代码语言:javascript
复制
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        
        latent_size = 100
        output = 28*28
        
        self.main = nn.Sequential(
            nn.Linear(latent_size, 128),
            nn.ReLU(inplace=True),
            
            nn.Linear(128, 256),
            nn.ReLU(inplace=True),
            
            nn.Linear(256, 512),
            nn.ReLU(inplace=True),
            
            nn.Linear(512, output),
            nn.Tanh()
        )
        
    def forward(self, x):
        out = self.main(x)
        out = out.view(-1, 1, 28, 28)
        return out


class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        
        n_features = 28 * 28
        n_out = 1
        
        self.main = nn.Sequential(
            nn.Linear(n_features, 512),
            nn.ReLU(inplace=True),
            
            nn.Linear(512, 256),
            nn.ReLU(inplace=True),
            
            nn.Linear(256, 128),
            nn.ReLU(inplace=True),
            
            nn.Linear(128, 64),
            nn.ReLU(inplace=True),
            
            nn.Linear(64, n_out),
            nn.Sigmoid()        
        )
        
    def forward(self, x):
        x = x.view(-1, 28*28)
        out = self.main(x)
        return out

G = Generator().to(device)
D = Discriminator().to(device)

生成性对抗网络训练过程的损失函数是二进制交叉熵损失,由torch.nn.BCELoss实现。

这两种模型都使用torch.optim.Adam作为优化工具,学习率设置为0.002。

代码语言:javascript
复制
# Objective Function
criterion = nn.BCELoss()

# Optimizer
G_optimizer = optim.Adam(G.parameters(), lr = 0.0002)
D_optimizer = optim.Adam(D.parameters(), lr = 0.0002)

# Constants
noise_dim = 100
num_epochs = 50
total_batch = len(mnist_loader)

# Lists
G_losses = []
D_losses = []

# Noise
sample_size = 16
fixed_noise = torch.randn(sample_size, noise_dim).to(device)

# Train
for epoch in range(num_epochs):
    for i, (images, labels) in enumerate(mnist_loader):
        
        # Images #
        images = images.reshape(batch_size, -1).float().to(device)
        
        # Labels #
        ones = torch.ones(batch_size, 1).to(device)
        zeros = torch.zeros(batch_size, 1).to(device)
        
        # Noise #
        noise = torch.randn(batch_size, noise_dim).to(device)
        
        # Initialize Optimizers
        D_optimizer.zero_grad()
        G_optimizer.zero_grad()
        
        #######################
        # Train Discriminator #
        #######################
        
        # Forward Images #
        prob_real = D(images)
        D_real_loss = criterion(prob_real, ones)
        
        # Generate Samples #
        fake_images = G(noise)
        prob_fake = D(fake_images)
        
        # Forward Fake Samples and Calculate Discriminator Loss #
        D_fake_loss = criterion(prob_fake, zeros)
        D_loss = (D_real_loss + D_fake_loss).mean()
        
        # Back Propagation and Update
        D_loss.backward()
        D_optimizer.step()
        
        ###################
        # Train Generator #
        ###################
        
        fake_images = G(noise)
        prob_fake = D(fake_images)
        
        # According to the section 3 in paper,
        # early in learning, when G is very poor, D can reject samples from G.
        # In this case, log(1-D(G(z))) saturates. 
        # thus, train G to maximiaze log(D(G(z))) instead of minimizing log(1-D(G(z)))
        G_loss = criterion(prob_fake, ones)
        
        # Back Propagation and Update
        G_loss.backward()
        G_optimizer.step()
        
        # Save Losses for Plotting Later
        G_losses.append(G_loss.item())
        D_losses.append(D_loss.item())
        
        # Print Statistics #
        if (i + 1) % 100 == 0:
            print("Epoch [%d/%d] Iter [%d/%d], D_Loss: %.4f G_Loss: %.4f"
                  %(epoch+1, num_epochs, i+1, total_batch, D_loss.item(), G_loss.item()))
    
    # Generate Samples #
    if epoch % 1 == 0:
        fake_samples = G(fixed_noise)
        imshow(fake_samples, "Generated MNIST Images")
    
# Save Model Weights for Digit Generation
torch.save(G.state_dict(), './data/GAN.pkl')
代码语言:javascript
复制
plt.figure(figsize = (8, 6))
plt.title("Generator and Discriminator Loss During Training")
plt.plot(G_losses, label="Generator")
plt.plot(D_losses, label="Discriminator")
plt.xlabel("Iterations")
plt.ylabel("Losses")
plt.legend()
plt.show()
代码语言:javascript
复制
sample_size = 64
noise_dim = 100

noise = torch.randn(sample_size, noise_dim).to(device)

G.load_state_dict(torch.load('GAN.pkl'))
fake_samples = G(fixed_noise)
imshow(fake_samples, "Generated MNIST Images")

GAN生成性对抗网络的运用

  • 将语义图像翻译成城市景观和建筑物的照片。
  • 将卫星照片翻译成地图。
  • 从白天到晚上的照片翻译。
  • 将黑白照片翻译成彩色。

- 论文在这里找到:https://papers.nips.cc/paper/5423-generative-adversarial-nets.pdf

- 上述代码的论文:https://arxiv.org/abs/1511.06434

- 上述代码:https://github.com/yihui-he/GAN-MNIST

本文参与 腾讯云自媒体分享计划,分享自微信公众号。
原始发表:2021-07-27,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 小刘IT教程 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 生成性对抗网络
相关产品与服务
AI 应用产品
文字识别(Optical Character Recognition,OCR)基于腾讯优图实验室的深度学习技术,将图片上的文字内容,智能识别成为可编辑的文本。OCR 支持身份证、名片等卡证类和票据类的印刷体识别,也支持运单等手写体识别,支持提供定制化服务,可以有效地代替人工录入信息。
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档