
pix2pix: Paper Notes and Code Walkthrough

Author: 薛同学要做好汉 | Published 2022-07-23

Introduction

Recently I have been working on image simplification, that is, turning a complex image A into a simpler image B. GANs naturally came to mind, and any GAN-based approach to this problem has to start with the pix2pix paper, so I am taking this opportunity to share my notes on it.

Results

pix2pix handles image-to-image mapping very well; its results are strikingly realistic.

Results from the paper


More results

Many community members have built impressive demos with pix2pix; you can browse them at https://phillipi.github.io/pix2pix/.


Background

What problem does pix2pix solve?

  1. Before pix2pix, pixel-to-pixel prediction tasks such as converting semantic masks to street scenes, masks to building facades, or black-and-white photos to color were each tackled with separate, one-off solutions; there was no general-purpose framework.
  2. CNNs trained to minimize the Euclidean distance between predicted and ground-truth pixels tend to produce blurry results: the network prefers to "play it safe" by averaging plausible outputs, losing the edges and fine detail present in real images.

How does pix2pix solve it?

  1. It uses a conditional GAN as a general-purpose framework for pixel-to-pixel prediction tasks.
  2. A GAN learns the loss itself: the discriminator is trained to classify outputs as real or fake while the generator is trained to minimize that learned loss, so blurry images, which are easy to spot as fake, get penalized.

Method

Model architecture

Generator

The generator is a U-Net. Its skip connections let low-level information be shared directly between input and output, so that this information is not lost during image translation.
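
A minimal sketch (my own illustration, not from the post) of what the skip connection buys: a block's input is concatenated with its output along the channel axis, which is also why the non-innermost up-convolutions in the repo code later take inner_nc * 2 input channels.

```python
import torch
import torch.nn as nn

# Toy U-Net skip: concatenate a block's input with the submodule's output.
x = torch.randn(1, 64, 128, 128)             # features entering a U-Net block
submodule = nn.Conv2d(64, 64, 3, padding=1)  # stand-in for the inner block
out = torch.cat([x, submodule(x)], dim=1)    # channels double: 64 -> 128
print(out.shape)                             # torch.Size([1, 128, 128, 128])
```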


Discriminator

The discriminator is a PatchGAN, essentially a small fully convolutional classifier (not a stack of fully connected layers). Conceptually it splits the generator's output into N×N patches (with N much smaller than the image size), classifies each patch as real or fake, and averages the patch scores into a score for the whole image. Working at the patch level gives PatchGAN fewer parameters, makes it run faster, and lets it be applied to arbitrarily large images. PatchGAN effectively models the image as a Markov random field, assuming that pixels in different patches (those separated by more than a patch diameter) are independent.
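
A quick sanity check (my own sketch, not from the post): the receptive field of the default 3-layer PatchGAN, computed below, comes out to 70×70, which is where the paper's "70×70 PatchGAN" name comes from.

```python
# Receptive field of the default PatchGAN ('basic', n_layers=3), matching
# NLayerDiscriminator below: three stride-2 convs, then two stride-1 convs,
# all with kernel size 4.
layers = [(4, 2), (4, 2), (4, 2), (4, 1), (4, 1)]  # (kernel, stride)

rf = 1
for k, s in reversed(layers):
    rf = (rf - 1) * s + k     # grow the receptive field back toward the input
print(rf)  # 70 -> each output score judges a 70x70 patch of the input
```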

Loss function

G^* = \arg \min_G \max_D \; \mathcal{L}_{cGAN}(G, D) + \lambda \mathcal{L}_{L1}(G)

where \mathcal{L}_{cGAN}(G, D) = \mathbb{E}_{x,y}[\log D(x, y)] + \mathbb{E}_{x,z}[\log(1 - D(x, G(x, z)))] and \mathcal{L}_{L1}(G) = \mathbb{E}_{x,y,z}[\lVert y - G(x, z) \rVert_1] is the L1 distance between the generated output and the ground truth.
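
A minimal sketch of how this objective is evaluated in practice (my own illustration; the actual repo code appears in the walkthrough below, and the paper uses λ = 100):

```python
import torch
import torch.nn.functional as F

# Dummy discriminator logits and images, just to mirror the formula.
d_real = torch.randn(1, 1, 30, 30)   # D(x, y), pre-sigmoid logits
d_fake = torch.randn(1, 1, 30, 30)   # D(x, G(x, z))
fake_B = torch.randn(1, 3, 256, 256)
real_B = torch.randn(1, 3, 256, 256)

lam = 100.0  # lambda from the paper
# Generator: fool D on fakes, plus a weighted L1 term to the ground truth.
loss_G = F.binary_cross_entropy_with_logits(d_fake, torch.ones_like(d_fake)) \
         + lam * F.l1_loss(fake_B, real_B)
# Discriminator: push reals toward 1 and fakes toward 0 (averaged, as in the repo).
loss_D = 0.5 * (F.binary_cross_entropy_with_logits(d_real, torch.ones_like(d_real))
                + F.binary_cross_entropy_with_logits(d_fake, torch.zeros_like(d_fake)))
```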

Evaluation metric

The Fréchet Inception Distance (FID) measures how closely the generated images match the real ones. Take the same number N of generated and real images, extract a 2048-dimensional feature vector for each with an Inception network, and compute the distance between the two resulting N×2048 feature sets:

FID(g, r) = \lVert \mu_g - \mu_r \rVert_2^2 + \mathrm{Tr}\bigl(\Sigma_g + \Sigma_r - 2(\Sigma_g \Sigma_r)^{1/2}\bigr)

where g and r denote the generated and real images, \mu_g and \mu_r are the means of their feature vectors, \Sigma_g and \Sigma_r are the covariance matrices of those features, and \mathrm{Tr} is the matrix trace (the sum of the main-diagonal entries). The closer the feature statistics, the better the generative model and the sharper the generated images.

Code walkthrough

The snippets in this section come from the official PyTorch implementation (reference 2 below), except for the FID code, whose source is linked inline.

Building the model

```python
if self.isTrain:
    self.model_names = ['G', 'D']
else:  # during test time, only load G
    self.model_names = ['G']
```

```python
# input_nc, num of input image channels: 3 for RGB and 1 for grayscale
# output_nc, num of output image channels: 3 for RGB and 1 for grayscale
# ngf, num of gen filters in the last conv layer
# netG, specify generator architecture [unet_256 | unet_128]
# norm, instance normalization or batch normalization [instance | batch | none]
# no_dropout, no dropout for the generator
# init_type, network initialization [normal | xavier | kaiming | orthogonal]
# init_gain, scaling factor for normal, xavier and orthogonal, default=0.02
self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, opt.norm, not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids)
```

```python
def define_G(input_nc, output_nc, ngf, netG, norm='batch', use_dropout=False, init_type='normal', init_gain=0.02, gpu_ids=[]):
    """Create a generator

    Parameters:
        input_nc (int) -- the number of channels in input images
        output_nc (int) -- the number of channels in output images
        ngf (int) -- the number of filters in the last conv layer
        netG (str) -- the architecture's name: resnet_9blocks | resnet_6blocks | unet_256 | unet_128
        norm (str) -- the name of normalization layers used in the network: batch | instance | none
        use_dropout (bool) -- if use dropout layers.
        init_type (str)    -- the name of our initialization method.
        init_gain (float)  -- scaling factor for normal, xavier and orthogonal.
        gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2

    Returns a generator
    """
    net = None
    norm_layer = get_norm_layer(norm_type=norm)

    if netG == 'resnet_9blocks':
        net = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=9)
    elif netG == 'resnet_6blocks':
        net = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=6)
    elif netG == 'unet_128':
        net = UnetGenerator(input_nc, output_nc, 7, ngf, norm_layer=norm_layer, use_dropout=use_dropout)
    elif netG == 'unet_256':
        net = UnetGenerator(input_nc, output_nc, 8, ngf, norm_layer=norm_layer, use_dropout=use_dropout)
    else:
        raise NotImplementedError('Generator model name [%s] is not recognized' % netG)
    return init_net(net, init_type, init_gain, gpu_ids)
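
# Usage sketch (my addition, not from the original post): with the pix2pix
# defaults, the generator maps an image to one at the same resolution, e.g.
#   netG = define_G(input_nc=3, output_nc=3, ngf=64, netG='unet_256')
#   out = netG(torch.randn(1, 3, 256, 256))   # out.shape == (1, 3, 256, 256)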
```

```python
class UnetGenerator(nn.Module):
    """Create a Unet-based generator"""

    def __init__(self, input_nc, output_nc, num_downs, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False):
        """Construct a Unet generator
        Parameters:
            input_nc (int)  -- the number of channels in input images
            output_nc (int) -- the number of channels in output images
            num_downs (int) -- the number of downsamplings in UNet. For example, if |num_downs| == 7,
                                an image of size 128x128 will become of size 1x1 at the bottleneck
            ngf (int)       -- the number of filters in the last conv layer
            norm_layer      -- normalization layer

        We construct the U-Net from the innermost layer to the outermost layer.
        It is a recursive process.
        """
        super(UnetGenerator, self).__init__()
        # construct unet structure
        unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True)  # add the innermost layer
        for i in range(num_downs - 5):          # add intermediate layers with ngf * 8 filters
            unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer, use_dropout=use_dropout)
        # gradually reduce the number of filters from ngf * 8 to ngf
        unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
        unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
        unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
        self.model = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, norm_layer=norm_layer)  # add the outermost layer

    def forward(self, input):
        """Standard forward"""
        return self.model(input)
```

```python
class UnetSkipConnectionBlock(nn.Module):
    """Defines the Unet submodule with skip connection.
        X -------------------identity----------------------
        |-- downsampling -- |submodule| -- upsampling --|
    """

    def __init__(self, outer_nc, inner_nc, input_nc=None,
                 submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False):
        """Construct a Unet submodule with skip connections.

        Parameters:
            outer_nc (int) -- the number of filters in the outer conv layer
            inner_nc (int) -- the number of filters in the inner conv layer
            input_nc (int) -- the number of channels in input images/features
            submodule (UnetSkipConnectionBlock) -- previously defined submodules
            outermost (bool)    -- if this module is the outermost module
            innermost (bool)    -- if this module is the innermost module
            norm_layer          -- normalization layer
            use_dropout (bool)  -- if use dropout layers.
        """
        super(UnetSkipConnectionBlock, self).__init__()
        self.outermost = outermost
        if type(norm_layer) == functools.partial:
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d
        if input_nc is None:
            input_nc = outer_nc
        downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4,
                             stride=2, padding=1, bias=use_bias)
        downrelu = nn.LeakyReLU(0.2, True)
        downnorm = norm_layer(inner_nc)
        uprelu = nn.ReLU(True)
        upnorm = norm_layer(outer_nc)

        if outermost:
            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
                                        kernel_size=4, stride=2,
                                        padding=1)
            down = [downconv]
            up = [uprelu, upconv, nn.Tanh()]
            model = down + [submodule] + up
        elif innermost:
            upconv = nn.ConvTranspose2d(inner_nc, outer_nc,
                                        kernel_size=4, stride=2,
                                        padding=1, bias=use_bias)
            down = [downrelu, downconv]
            up = [uprelu, upconv, upnorm]
            model = down + up
        else:
            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
                                        kernel_size=4, stride=2,
                                        padding=1, bias=use_bias)
            down = [downrelu, downconv, downnorm]
            up = [uprelu, upconv, upnorm]

            if use_dropout:
                model = down + [submodule] + up + [nn.Dropout(0.5)]
            else:
                model = down + [submodule] + up

        self.model = nn.Sequential(*model)

    def forward(self, x):
        if self.outermost:
            return self.model(x)
        else:   # add skip connections
            return torch.cat([x, self.model(x)], 1)
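
# Note (added for clarity): non-outermost blocks return torch.cat([x, self.model(x)], 1),
# so each block's output carries its input channels as well. This is why the
# non-innermost up-convolutions above take inner_nc * 2 input channels.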
```

```python
def define_D(input_nc, ndf, netD, n_layers_D=3, norm='batch', init_type='normal', init_gain=0.02, gpu_ids=[]):
    """Create a discriminator

    Parameters:
        input_nc (int)     -- the number of channels in input images
        ndf (int)          -- the number of filters in the first conv layer
        netD (str)         -- the architecture's name: basic | n_layers | pixel
        n_layers_D (int)   -- the number of conv layers in the discriminator; effective when netD=='n_layers'
        norm (str)         -- the type of normalization layers used in the network.
        init_type (str)    -- the name of the initialization method.
        init_gain (float)  -- scaling factor for normal, xavier and orthogonal.
        gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2

    Returns a discriminator

    Our current implementation provides three types of discriminators:
        [basic]: 'PatchGAN' classifier described in the original pix2pix paper.
        It can classify whether 70×70 overlapping patches are real or fake.
        Such a patch-level discriminator architecture has fewer parameters
        than a full-image discriminator and can work on arbitrarily-sized images
        in a fully convolutional fashion.

        [n_layers]: With this mode, you can specify the number of conv layers in the discriminator
        with the parameter <n_layers_D> (default=3 as used in [basic] (PatchGAN).)

        [pixel]: 1x1 PixelGAN discriminator can classify whether a pixel is real or not.
        It encourages greater color diversity but has no effect on spatial statistics.

    The discriminator has been initialized by <init_net>. It uses LeakyReLU for non-linearity.
    """
    net = None
    norm_layer = get_norm_layer(norm_type=norm)

    if netD == 'basic':  # default PatchGAN classifier
        net = NLayerDiscriminator(input_nc, ndf, n_layers=3, norm_layer=norm_layer)
    elif netD == 'n_layers':  # more options
        net = NLayerDiscriminator(input_nc, ndf, n_layers_D, norm_layer=norm_layer)
    elif netD == 'pixel':     # classify if each pixel is real or fake
        net = PixelDiscriminator(input_nc, ndf, norm_layer=norm_layer)
    else:
        raise NotImplementedError('Discriminator model name [%s] is not recognized' % netD)
    return init_net(net, init_type, init_gain, gpu_ids)
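
# Usage sketch (my addition): for pix2pix the discriminator is conditional, so
# it sees the input and output images concatenated along the channel axis
# (input_nc = opt.input_nc + opt.output_nc), e.g.
#   netD = define_D(input_nc=6, ndf=64, netD='basic')
#   score = netD(torch.randn(1, 6, 256, 256))  # score.shape == (1, 1, 30, 30)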
```

```python
class NLayerDiscriminator(nn.Module):
    """Defines a PatchGAN discriminator"""

    def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d):
        """Construct a PatchGAN discriminator

        Parameters:
            input_nc (int)  -- the number of channels in input images
            ndf (int)       -- the number of filters in the last conv layer
            n_layers (int)  -- the number of conv layers in the discriminator
            norm_layer      -- normalization layer
        """
        super(NLayerDiscriminator, self).__init__()
        if type(norm_layer) == functools.partial:  # no need to use bias as BatchNorm2d has affine parameters
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d

        kw = 4
        padw = 1
        sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
        nf_mult = 1
        nf_mult_prev = 1
        for n in range(1, n_layers):  # gradually increase the number of filters
            nf_mult_prev = nf_mult
            nf_mult = min(2 ** n, 8)
            sequence += [
                nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
                norm_layer(ndf * nf_mult),
                nn.LeakyReLU(0.2, True)
            ]

        nf_mult_prev = nf_mult
        nf_mult = min(2 ** n_layers, 8)
        sequence += [
            nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
            norm_layer(ndf * nf_mult),
            nn.LeakyReLU(0.2, True)
        ]

        sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)]  # output 1 channel prediction map
        self.model = nn.Sequential(*sequence)

    def forward(self, input):
        """Standard forward."""
        return self.model(input)
```

```python
class PixelDiscriminator(nn.Module):
    """Defines a 1x1 PatchGAN discriminator (pixelGAN)"""

    def __init__(self, input_nc, ndf=64, norm_layer=nn.BatchNorm2d):
        """Construct a 1x1 PatchGAN discriminator

        Parameters:
            input_nc (int)  -- the number of channels in input images
            ndf (int)       -- the number of filters in the last conv layer
            norm_layer      -- normalization layer
        """
        super(PixelDiscriminator, self).__init__()
        if type(norm_layer) == functools.partial:  # no need to use bias as BatchNorm2d has affine parameters
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d

        self.net = [
            nn.Conv2d(input_nc, ndf, kernel_size=1, stride=1, padding=0),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(ndf, ndf * 2, kernel_size=1, stride=1, padding=0, bias=use_bias),
            norm_layer(ndf * 2),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(ndf * 2, 1, kernel_size=1, stride=1, padding=0, bias=use_bias)]

        self.net = nn.Sequential(*self.net)

    def forward(self, input):
        """Standard forward.""
        return self.net(input)

Loss functions

```python
self.criterionGAN = networks.GANLoss(opt.gan_mode).to(self.device)
self.criterionL1 = torch.nn.L1Loss()


class GANLoss(nn.Module):
    """Define different GAN objectives.

    The GANLoss class abstracts away the need to create the target label tensor
    that has the same size as the input.
    """

    def __init__(self, gan_mode, target_real_label=1.0, target_fake_label=0.0):
        """ Initialize the GANLoss class.

        Parameters:
            gan_mode (str) - - the type of GAN objective. It currently supports vanilla, lsgan, and wgangp.
            target_real_label (bool) - - label for a real image
            target_fake_label (bool) - - label of a fake image

        Note: Do not use sigmoid as the last layer of Discriminator.
        LSGAN needs no sigmoid. vanilla GANs will handle it with BCEWithLogitsLoss.
        """
        super(GANLoss, self).__init__()
        self.register_buffer('real_label', torch.tensor(target_real_label))
        self.register_buffer('fake_label', torch.tensor(target_fake_label))
        self.gan_mode = gan_mode
        if gan_mode == 'lsgan':
            self.loss = nn.MSELoss()
        elif gan_mode == 'vanilla':
            self.loss = nn.BCEWithLogitsLoss()
        elif gan_mode in ['wgangp']:
            self.loss = None
        else:
            raise NotImplementedError('gan mode %s not implemented' % gan_mode)
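
    # From the same reference repo (added here for completeness): backward_D and
    # backward_G below call criterionGAN(prediction, target_is_real) directly,
    # which relies on these two methods.
    def get_target_tensor(self, prediction, target_is_real):
        """Create label tensors with the same size as the input."""
        if target_is_real:
            target_tensor = self.real_label
        else:
            target_tensor = self.fake_label
        return target_tensor.expand_as(prediction)

    def __call__(self, prediction, target_is_real):
        """Calculate loss given discriminator output and ground-truth labels."""
        if self.gan_mode in ['lsgan', 'vanilla']:
            target_tensor = self.get_target_tensor(prediction, target_is_real)
            loss = self.loss(prediction, target_tensor)
        elif self.gan_mode == 'wgangp':
            loss = -prediction.mean() if target_is_real else prediction.mean()
        return loss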

```

Updating the network parameters

```python
def optimize_parameters(self):
    self.forward()                            # compute fake images: G(A)
    # update D
    self.set_requires_grad(self.netD, True)   # enable backprop for D
    self.optimizer_D.zero_grad()              # set D's gradients to zero
    self.backward_D()                         # calculate gradients for D
    self.optimizer_D.step()                   # update D's weights
    # update G
    self.set_requires_grad(self.netD, False)  # D requires no gradients when optimizing G
    self.optimizer_G.zero_grad()              # set G's gradients to zero
    self.backward_G()                         # calculate gradients for G
    self.optimizer_G.step()                   # update G's weights


def forward(self):
    """Run forward pass; called by both functions <optimize_parameters> and <test>."""
    self.fake_B = self.netG(self.real_A)  # G(A)


def backward_D(self):
    """Calculate GAN loss for the discriminator"""
    # Fake; stop backprop to the generator by detaching fake_B
    fake_AB = torch.cat((self.real_A, self.fake_B), 1)  # we use conditional GANs; we need to feed both input and output to the discriminator
    pred_fake = self.netD(fake_AB.detach())
    self.loss_D_fake = self.criterionGAN(pred_fake, False)
    # Real
    real_AB = torch.cat((self.real_A, self.real_B), 1)
    pred_real = self.netD(real_AB)
    self.loss_D_real = self.criterionGAN(pred_real, True)
    # combine loss and calculate gradients
    self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5
    self.loss_D.backward()


def backward_G(self):
    """Calculate GAN and L1 loss for the generator"""
    # First, G(A) should fake the discriminator
    fake_AB = torch.cat((self.real_A, self.fake_B), 1)
    pred_fake = self.netD(fake_AB)
    self.loss_G_GAN = self.criterionGAN(pred_fake, True)
    # Second, G(A) = B
    self.loss_G_L1 = self.criterionL1(self.fake_B, self.real_B) * self.opt.lambda_L1
    # combine loss and calculate gradients
    self.loss_G = self.loss_G_GAN + self.loss_G_L1
    self.loss_G.backward()
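
# Note (added): in the reference repo, opt.lambda_L1 defaults to 100.0, matching
# the lambda = 100 weighting in the paper's objective.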

```

Computing FID

```python
# https://github.com/rosinality/stylegan2-pytorch/blob/master/fid.py
import numpy as np
from scipy import linalg


def calc_fid(sample_mean, sample_cov, real_mean, real_cov, eps=1e-6):
    """
    sample_mean = np.mean(features, 0)
    # Estimate a covariance matrix, given data and weights.
    sample_cov = np.cov(features, rowvar=False)

    with open(args.inception, "rb") as f:
        embeds = pickle.load(f)
        real_mean = embeds["mean"]
        real_cov = embeds["cov"]
    """
    # @ is Python's infix operator for matrix multiplication
    cov_sqrt, _ = linalg.sqrtm(sample_cov @ real_cov, disp=False)

    if not np.isfinite(cov_sqrt).all():
        print("product of cov matrices is singular")
        offset = np.eye(sample_cov.shape[0]) * eps
        cov_sqrt = linalg.sqrtm((sample_cov + offset) @ (real_cov + offset))

    if np.iscomplexobj(cov_sqrt):
        if not np.allclose(np.diagonal(cov_sqrt).imag, 0, atol=1e-3):
            m = np.max(np.abs(cov_sqrt.imag))
            raise ValueError(f"Imaginary component {m}")

        cov_sqrt = cov_sqrt.real

    mean_diff = sample_mean - real_mean
    mean_norm = mean_diff @ mean_diff

    trace = np.trace(sample_cov) + np.trace(real_cov) - 2 * np.trace(cov_sqrt)

    fid = mean_norm + trace

    return fid
```
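
As a quick sanity check (my own usage sketch, not from the post), using calc_fid from above: identical feature statistics should give an FID of essentially zero, and the distance grows as the means drift apart.

```python
import numpy as np

rng = np.random.default_rng(0)
feats = rng.normal(size=(1000, 64))          # stand-in for 2048-dim Inception features
mean, cov = np.mean(feats, 0), np.cov(feats, rowvar=False)

print(calc_fid(mean, cov, mean, cov))        # ~0: identical statistics
print(calc_fid(mean + 1.0, cov, mean, cov))  # ~64: mean shifted by 1 in each dim
```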

Conclusion

pix2pix is a conditional GAN that performs image-to-image translation very well, especially on tasks with highly structured graphical outputs.

References

  1. pix2pix paper: https://arxiv.org/abs/1611.07004
  2. PyTorch implementation: https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix
  3. Interactive demo: https://phillipi.github.io/pix2pix/

