风格迁移三部曲(一)之普通风格迁移

前言

风格迁移(Style Transfer)是一个很有意思的任务,通过风格迁移可以使一张图片保持本身内容大致不变的情况下呈现出另外一张图片的风格。风格迁移三步曲将绍以下三种风格迁移方式以及对应的代码实现

  1. 固定风格固定内容的普通风格迁移(A Neural Algorithm of Artistic Style)
  2. 固定风格任意内容的快速风格迁移(Perceptual Losses for Real-Time Style Transfer and Super-Resolution)
  3. 任意风格任意内容的极速风格迁移(Meta Networks for Neural Style Transfer)

本文会介绍固定风格固定内容的普通风格迁移(A Neural Algorithm of Artistic Style)方式以及对应的代码实现。本文所使用的环境是 pytorch 0.4.0,如果你使用了其他的版本,稍作修改即可正确运行。

1 固定风格固定内容的普通风格迁移

最早的风格迁移就是在固定风格、固定内容的情况下做的风格迁移,这是最慢的方法,也是最经典的方法。

最原始的风格迁移的思路很简单,把图片当做可以训练的变量,通过优化图片来降低与内容图片的内容差异以及降低与风格图片的风格差异,迭代训练多次以后,生成的图片就会与内容图片的内容一致,同时也会与风格图片的风格一致。

VGG16

VGG16 是一个很经典的模型,它通过堆叠 3x3 的卷积层和池化层,在 ImageNet 上获得了不错的成绩。我们使用在 ImageNet 上经过预训练的 VGG16 模型可以对图像提取出有用的特征,这些特征可以帮助我们去衡量两个图像的内容差异和风格差异。

在进行风格迁移任务时,我们只需要提取其中几个比较重要的层,所以我们对 pytorch 自带的预训练 VGG16 模型稍作了一些修改:

 1class VGG(nn.Module):
 2
 3    def __init__(self, features):
 4        super(VGG, self).__init__()
 5        self.features = features
 6        self.layer_name_mapping = {
 7            '3': "relu1_2",
 8            '8': "relu2_2",
 9            '15': "relu3_3",
10            '22': "relu4_3"
11        }
12        for p in self.parameters():
13            p.requires_grad = False
14
15    def forward(self, x):
16        outs = []
17        for name, module in self.features._modules.items():
18            x = module(x)
19            if name in self.layer_name_mapping:
20                outs.append(x)
21        return outs
22
23vgg16 = models.vgg16(pretrained=True)
24vgg16 = VGG(vgg16.features[:23]).to(device).eval()

经过修改的 VGG16 可以输出 relu1_2、relu2_2、relu3_3、relu4_3 这几个特定层的特征图。下面这两句代码就是它的用法:

1features = vgg16(input_img)
2content_features = vgg16(content_img)

举个例子,当我们使用 vgg16 对 input_img 计算特征时,它会返回四个矩阵给 features,假设 input_img 的尺寸是 [1, 3, 512, 512](四个维度分别代表 batch, channels, height, width),那么它返回的四个矩阵的尺寸就是这样的:

  • relu1_2 [1, 64, 512, 512]
  • relu2_2 [1, 128, 256, 256]
  • relu3_3 [1, 256, 128, 128]
  • relu4_3 [1, 512, 64, 64]

内容

我们进行风格迁移的时候,必须保证生成的图像与内容图像的内容一致性,不然风格迁移就变成艺术创作了。那么如何衡量两张图片的内容差异呢?很简单,通过 VGG16 输出的特征图来衡量图片的内容差异。

提示:在本方法中没有 Image Transform Net,为了表述方便,我们使用了第二篇论文中的图。

这里使用的损失函数是:

其中:

根据生成图像和内容图像在 relu3_3 输出的特征图的均方误差(MeanSquaredError)来优化生成的图像与内容图像之间的内容一致性。

那么写成代码就是这样的:

1content_loss = F.mse_loss(features[2], content_features[2]) * content_weight

因为我们这里使用的是经过在 ImageNet 预训练过的 VGG16 提取的特征图,所以它能提取出图像的高级特征,通过优化生成图像和内容图像特征图的 mse,可以迫使生成图像的内容与内容图像在 VGG16 的 relu3_3 上输出相似的结果,因此生成图像和内容图像在内容上是一致的。

风格

Gram 矩阵

那么如何衡量输入图像与风格图像之间的内容差异呢?这里就需要提出一个新的公式,Gram 矩阵:

其中:

具体到代码,我们可以写出下面的函数:

1def gram_matrix(y):
2    (b, ch, h, w) = y.size()
3    features = y.view(b, ch, w * h)
4    features_t = features.transpose(1, 2)
5    gram = features.bmm(features_t) / (ch * h * w)
6    return gram

参考链接:

https://github.com/pytorch/examples/blob/0.4/fast_neural_style/neural_style/utils.py#L21-L26

假设我们输入了一个 [1, 3, 512, 512] 的图像,下面就是各个矩阵的尺寸:

  • relu1_2 [1, 64, 512, 512],gram [1, 64, 64]
  • relu2_2 [1, 128, 256, 256],gram [1, 128, 128]
  • relu3_3 [1, 256, 128, 128],gram [1, 256, 256]
  • relu4_3 [1, 512, 64, 64],gram [1, 512, 512]

风格损失

根据生成图像和风格图像在relu1_2、relu2_2、relu3_3、relu4_3 输出的特征图的 Gram 矩阵之间的均方误差(MeanSquaredError)来优化生成的图像与风格图像之间的风格差异:

其中:

那么写成代码就是下面这样:

1style_grams = [gram_matrix(x) for x in style_features]
2
3style_loss = 0
4grams = [gram_matrix(x) for x in features]
5for a, b in zip(grams, style_grams):
6    style_loss += F.mse_loss(a, b) * style_weight

训练

那么风格迁移的目标就很简单了,直接将两个 loss 按权值加起来,然后对图片优化 loss,即可优化出既有内容图像的内容,也有风格图像的风格的图片。代码如下:

 1input_img = content_img.clone()
 2optimizer = optim.LBFGS([input_img.requires_grad_()])
 3style_weight = 1e6
 4content_weight = 1
 5
 6run = [0]
 7while run[0] <= 300:
 8    def f():
 9        optimizer.zero_grad()
10        features = vgg16(input_img)
11
12        content_loss = F.mse_loss(features[2], content_features[2]) * content_weight
13        style_loss = 0
14        grams = [gram_matrix(x) for x in features]
15        for a, b in zip(grams, style_grams):
16            style_loss += F.mse_loss(a, b) * style_weight
17
18        loss = style_loss + content_loss
19
20        if run[0] % 50 == 0:
21            print('Step {}: Style Loss: {:4f} Content Loss: {:4f}'.format(
22                run[0], style_loss.item(), content_loss.item()))
23        run[0] += 1
24
25        loss.backward()
26        return loss
27
28    optimizer.step(f)

此处使用了 LBFGS,所以 loss 需要包装在一个函数里,代码参考了:

https://pytorch.org/tutorials/advanced/neural_style_tutorial.html

效果

最终效果如图所示:

可以看到生成的图像既有风格图像的风格,也有内容图像的内容,很完美。不过生成一幅256x256 的图像在 1080ti 上需要18.6s,这个时间挺长的,谈不上实时性。

link: https://zhuanlan.zhihu.com/p/40322927

github: https://github.com/CortexFoundation/StyleTransferTrilogy

预告

下一篇风格迁移三部曲(二)会介绍固定风格任意内容的快速风格迁移(Perceptual Losses for Real-Time Style Transfer and Super-Resolution)方式以及对应的代码实现。敬请期待

原文发布于微信公众号 - CVer(CVerNews)

原文发表时间:2018-07-30

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

发表于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏AI科技大本营的专栏

别磨叽,学完这篇你也是图像识别专家了

图像识别是当今深度学习的主流应用,而Keras是入门最容易、使用最便捷的深度学习框架,所以搞图像识别,你也得强调速度,不能磨叽。本文让你在最短时间内突破五个流行...

77270
来自专栏深度学习自然语言处理

深度学习之激活函数详解

激活函数是什么 激活函数,即Activation Function,有时候也称作激励函数。它是为了解决线性不可分的问题引出的。但是也不是说线性可分就不能用激活函...

30670
来自专栏智能算法

SVM 的“核”武器

一、上一次我们讲到关于SVM通过拉格朗日乘子法去求解的部分,引入乘子 得到下面的式子: ? 我们令 ? 当所有的约束条件满足时,我们得到的 ? ,而之前的优...

34560
来自专栏marsggbo

DeepLearning.ai学习笔记(四)卷积神经网络 -- week3 目标检测

一、目标定位 这一小节视频主要介绍了我们在实现目标定位时标签该如何定义。 ? 上图左下角给出了损失函数的计算公式(这里使用的是平方差) 如图示,加入我们需...

23460
来自专栏ATYUN订阅号

一文带你认识深度学习中不同类型的卷积

卷积(convolution)现在可能是深度学习中最重要的概念。靠着卷积和卷积神经网络(CNN),深度学习超越了几乎其它所有的机器学习手段。 ? 这篇文章将简要...

47290
来自专栏书山有路勤为径

Recurrent Neural Networks (RNNs)

许多应用涉及时间依赖,或基于时间依赖。这表示我们当前输出不仅仅取决于当前输入,还依赖于过去的输入。 RNN存在一个关键缺陷,因为几乎不可能捕获超过8或10步的...

18230
来自专栏ml

深度学习之图像的数据增强

   在图像的深度学习中,为了丰富图像训练集,更好的提取图像特征,泛化模型(防止模型过拟合),一般都会对数据图像进行数据增强, 数据增强,常用的方式,就是旋转图...

94970
来自专栏人工智能LeadAI

黑猿大叔-译文 | TensorFlow实现Batch Normalization

原文:Implementing Batch Normalization in Tensorflow(https://r2rt.com/implementing-...

57680
来自专栏Ldpe2G的个人博客

Mxnet 实现图片快速风格化

论文链接:Perceptual Losses for Real-Time Style Transfer and Super-Resolution

18570
来自专栏人工智能

深度学习之激活函数详解

激活函数是什么 激活函数,即Activation Function,有时候也称作激励函数。它是为了解决线性不可分的问题引出的。但是也不是说线性可分就不能用激活函...

26280

扫码关注云+社区

领取腾讯云代金券