前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >从零使用Python 实现对抗神经网络GAN

从零使用Python 实现对抗神经网络GAN

作者头像
double
发布2024-03-18 16:32:13
880
发布2024-03-18 16:32:13
举报
文章被收录于专栏:算法channel算法channel

你好,我是郭震

这篇从零使用Python,实现生成对抗网络(GAN)的基本版本。

GAN使用两套网络,分别是判别器(D)网络和生成器(G)网络,最重要的是弄清楚每套网络的输入和输出分别是什么,两套网络如何结合在一起,及优化的目标即cost function如何定义。

通俗来讲,两套网络结合的方法,就是G会从D的判分中不断提升生成能力,要知道G最开始的输入全部是噪点,这个思想也是文生图,文生视频的基石

下面这段代码展示了使用PyTorch框架进行生成对抗网络(GAN)训练的基本流程。

下面这些解释非常重要:

对于判别器网络而言,它的目标是最大化表达式 log(D(x)) + log(1 - D(G(z))),其中:

  • D(x) 是判别器网络对真实图像 x 的输出,这个值代表判别器认为图像是真实的概率。
  • D(G(z)) 是判别器网络对生成图像 G(z) 的输出,这个值代表判别器认为通过生成器从噪声 z 生成的图像是真实的概率。
  • log(D(x)) 的目标是使得判别器能够尽可能地将真实图像分类为真实(即,使 D(x) 接近于1)。
  • log(1 - D(G(z))) 的目标是使得判别器能够将生成的图像分类为假(即,使 D(G(z)) 接近于0)。
代码语言:javascript
复制
# GAN 训练的基本代码
for epoch in range(num_epochs):
    for i, data in enumerate(dataloader, 0):
        # 更新判别器网络:maximize log(D(x)) + log(1 - D(G(z)))
        
        # 在真实图像上训练
        netD.zero_grad()
        real_cpu = data[0].to(device)
        batch_size = real_cpu.size(0)
        label = torch.full((batch_size,), 1, dtype=torch.float, device=device)
        output = netD(real_cpu).view(-1)
        errD_real = criterion(output, label)
        errD_real.backward()
        D_x = output.mean().item()

        # 在假图像上训练
        noise = torch.randn(batch_size, nz, 1, 1, device=device)
        fake = netG(noise)
        label.fill_(0)
        output = netD(fake.detach()).view(-1)
        errD_fake = criterion(output, label)
        errD_fake.backward()
        D_G_z1 = output.mean().item()
        errD = errD_real + errD_fake
        optimizerD.step()

        # 更新生成器网络:maximize log(D(G(z)))
        netG.zero_grad()
        label.fill_(1)  # 假图像的标签对于生成器来说是真的
        output = netD(fake).view(-1)
        errG = criterion(output, label)
        errG.backward()
        D_G_z2 = output.mean().item()
        optimizerG.step()

在训练过程中,这一目标通过以下步骤来实现:

  1. 对于真实图像
    • 判别器 D 接收一批真实图像 x
    • 计算 D(x),即这些真实图像被识别为真实的概率。
    • 使用 log(D(x)) 计算损失。这个损失会根据真实图像被正确识别的程度(即,D(x) 应该接近于1)来调整。
  2. 对于生成图像
    • 生成器 G 从随机噪声 z 生成一批假图像。
    • 判别器 D 接收这些生成的图像,并计算 D(G(z)),即这些假图像被识别为真实的概率。
    • 使用 log(1 - D(G(z))) 计算损失。这个损失会根据假图像被正确识别的程度(即,D(G(z)) 应该接近于0)来调整。

实现细节

  • 在PyTorch中,损失函数通常是要最小化的。因此,虽然理论目标是最大化 log(D(x)) + log(1 - D(G(z))),实际上我们通过最小化 -log(D(x)) - log(1 - D(G(z))) 来实现这一目标。
  • 使用二元交叉熵损失(Binary Cross-Entropy, BCE)来实现这一目标,因为它直接提供了所需的 -log(x)-log(1-x) 形式的损失。

这个基本训练循环是理解和实现GANs的关键,而且也是后续进行各种变体和改进的基础。

弄清楚GAN的训练过程后,其他代码就比较容易理解。

导入所需的库

代码语言:javascript
复制

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import torchvision.utils as vutils

定义生成器(Generator)

代码语言:javascript
复制

class Generator(nn.Module):
    def __init__(self, nz, ngf, nc):
        super(Generator, self).__init__()
        self.main = nn.Sequential(
            # 输入是 Z, 对此进行全连接
            nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            # 上一步的输出形状: (ngf*8) x 4 x 4
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            # 上一步的输出形状: (ngf*4) x 8 x 8
            nn.ConvTranspose2d( ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            # 上一步的输出形状: (ngf*2) x 16 x 16
            nn.ConvTranspose2d( ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            # 上一步的输出形状: (ngf) x 32 x 32
            nn.ConvTranspose2d( ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh()
            # 输出形状: (nc) x 64 x 64
        )

    def forward(self, input):
        return self.main(input)

定义判别器(Discriminator)

代码语言:javascript
复制

class Discriminator(nn.Module):
    def __init__(self, nc, ndf):
        super(Discriminator, self).__init__()
        self.main = nn.Sequential(
            # 输入形状: (nc) x 64 x 64
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # 输出形状: (ndf) x 32 x 32
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # 输出形状: (ndf*2) x 16 x 16
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # 输出形状: (ndf*4) x 8 x 8
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # 输出形状: (ndf*8) x 4 x 4
            nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
        )

    def forward(self, input):
        return self.main(input)

初始化模型、优化器和损失函数

代码语言:javascript
复制

# 初始化
nz = 100  # 隐藏向量的维度
ngf = 64  # 与生成器的特征图深度相关
ndf = 64  # 与判别器的特征图深度相关
nc = 1    # 输出图像的通道数

# 创建生成器和判别器
netG = Generator(nz, ngf, nc).to(device)
netD = Discriminator(nc, ndf).to(device)

# 初始化权重
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)

netG.apply(weights_init)
netD.apply(weights_init)

# 设置优化器
optimizerD = optim.Adam(netD.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=0.0002, betas=(0.5, 0.999))

# 设置损失函数
criterion = nn.BCELoss()

通过这种方式,判别器学习区分真实和生成的图像,同时生成器试图生成越来越难以被判别器区分的图像,从而实现了GAN的训练过程。

本文参与 腾讯云自媒体分享计划,分享自微信公众号。
原始发表:2024-03-07,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 程序员郭震zhenguo 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档