StyleShot: 任意风格的快照
风格迁移是计算机视觉和图像处理领域中的一项重要任务,它涉及到将一张图像(参考图像)的风格应用到另一张图像(内容图像)上。这项技术在艺术创作、设计、娱乐和多个实际应用中都有广泛的应用。随着深度学习技术的发展,特别是生成对抗网络(GANs)和扩散模型的出现,风格迁移的研究得到了显著的推动。这些模型能够生成高质量和多样化的图像,为风格迁移任务提供了新的解决方案。
StyleShot是一种创新的图像风格迁移技术,它允许用户将任意图像的风格应用到另一张图像上。这项技术的背景建立在深度学习和生成模型的快速发展之上,尤其是在文本到图像生成领域。随着扩散模型的引入,图像生成的质量得到了显著提升。
本文通过复现并解读图像风格迁移领域最新的SOTA方法,来解读基于深度学习的图像风格迁移领域的最新研究进展。本文解读的论文是《StyleShot: A Snapshot on Any Style》,作者来自同济大学和上海人工智能实验室。
论文强调了良好的风格表示对于无需测试时调整的风格迁移至关重要且足够,通过构建一个风格感知编码器(style-aware encoder)和有序的风格数据集(StyleGallery),实现了风格迁移。StyleShot方法简单有效,能够模仿各种所需的风格,如3D、平面、抽象甚至细粒度风格。通过大量实验验证,StyleShot在多种风格上的性能优于现有方法。
图像风格迁移的目标是将参考图像的风格应用到内容图像上,使得生成的图像既保持内容的一致性又展现出目标风格。这一部分主要分为两个研究方向:
此外,还有一些方法通过调整模型权重或嵌入来实现风格迁移,这些方法在测试时需要对特定风格进行调整,这导致了计算和存储成本较高。
尽管现有的风格迁移技术取得了一定的进展,但仍面临一些挑战:
针对这些挑战,作者提出了StyleShot方法,旨在通过专门设计的风格感知编码器和内容融合编码器,以及一个风格平衡的数据集,来提高风格迁移的性能和泛化能力,同时避免了测试时调整的需要。
StyleShot的架构和关键组件包括风格感知编码器(Style-Aware Encoder)、内容融合编码器(Content-Fusion Encoder)以及风格平衡数据集StyleGallery的构建和去风格化(De-stylization)策略。
首先,论文简要回顾了Stable Diffusion模型的基本原理。Stable Diffusion由两个过程组成:一个扩散过程(前向过程),它通过马尔可夫链逐步向数据x0x0添加高斯噪声ϵϵ。此外,一个去噪过程从高斯噪声xT∼N(0,1)xT∼N(0,1)生成样本,使用一个可学习的去噪模型ϵθ(xt,t,c)ϵθ(xt,t,c),该模型由参数θθ参数化。这个去噪模型ϵθ(⋅)ϵθ(⋅)是用U-Net实现的,并通过一个简化的变分界限的均方误差损失进行训练:
L=Et,x0,ϵ[∥ϵ−ϵ^θ(xt,t,c)∥2],L=Et,x0,ϵ[∥ϵ−ϵ^θ(xt,t,c)∥2],
其中cc表示一个可选条件。在Stable Diffusion中,cc通常由使用CLIP从文本提示编码的文本嵌入ftft表示,并通过交叉注意力模块整合到Stable Diffusion中,其中潜在嵌入ff被投影到查询QQ上,文本嵌入ftft被映射到键KtKt和值VtVt上。该模块的输出定义如下:
Attention(Q,Kt,Vt)=softmax(QKtTd)⋅Vt,Attention(Q,Kt,Vt)=softmax(dQKtT)⋅Vt,
其中 Q=WQ⋅fQ=WQ⋅f, Kt=WKt⋅ftKt=WKt⋅ft, Vt=WVt⋅ftVt=WVt⋅ft,WQ,WKtWQ,WKt, WVtWVt是投影的可学习权重。在我们的模型中,风格嵌入被引入作为额外的条件,并与文本的注意力值合并。
在训练一个大规模数据集上的风格迁移模型时,每个图像都被视为一种独特的风格。先前的方法通常使用CLIP图像编码器来提取风格特征。
然而,CLIP更擅长于表示与图像的语言相关性,而不是模拟图像风格,这包括了像颜色、素描和布局这样的方面,这些风格特征难以通过语言表达,限制了CLIP编码器捕捉相关风格特征的能力。
因此,作者提出了一个风格感知编码器,专门设计用于提取丰富和富有表现力的风格嵌入。
在实际应用场景中,用户会提供文本提示或图像以及一个风格参考图像,分别用来控制生成的内容和风格。先前的方法通常通过操作内容图像特征来转移风格。
然而,内容特征与风格信息是耦合的,导致生成的图像保留了内容的原始风格。这一限制阻碍了这些方法在复杂风格迁移任务中的性能。与此不同,论文通过在原始图像空间中消除风格信息,预先解耦内容信息,然后引入一个专门设计用于内容和风格整合的内容融合编码器。
这种方法的核心在于,它允许模型在不依赖于内容原始风格特征的情况下,更灵活地应用和融合不同的风格特征。通过这种方式,可以更准确地控制生成图像的风格,同时保持内容的一致性和完整性。这种预解耦和融合策略使得StyleShot能够在各种风格迁移任务中实现更高质量的结果,无论是在文本驱动的风格迁移还是图像驱动的风格迁移中。
GPU 4090D Ubuntu 20.04 PyTorch 2.0.1 Python 3.8 Cuda 11.3
以基于风格图像驱动的图像风格迁移为例,部分关键代码实现如下:
<span style="color:#333333"><span style="background-color:#ffffff"><span style="background-color:#f6f8fa"><code><span style="color:#333333"><strong>import</strong></span> os
<span style="color:#333333"><strong>from</strong></span> types <span style="color:#333333"><strong>import</strong></span> MethodType
<span style="color:#999988"><em># 导入torch库,用于深度学习模型</em></span>
<span style="color:#333333"><strong>import</strong></span> torch
<span style="color:#999988"><em># 导入OpenCV库,用于图像处理</em></span>
<span style="color:#333333"><strong>import</strong></span> cv2
<span style="color:#999988"><em># 从annotator模块导入SOFT_HEDdetector,用于边缘检测</em></span>
<span style="color:#333333"><strong>from</strong></span> annotator.hed <span style="color:#333333"><strong>import</strong></span> SOFT_HEDdetector
<span style="color:#999988"><em># 从annotator.lineart导入LineartDetector,用于线性艺术风格检测</em></span>
<span style="color:#333333"><strong>from</strong></span> annotator.lineart <span style="color:#333333"><strong>import</strong></span> LineartDetector
<span style="color:#999988"><em># 从diffusers导入UNet2DConditionModel,用于条件UNet2D模型</em></span>
<span style="color:#333333"><strong>from</strong></span> diffusers <span style="color:#333333"><strong>import</strong></span> UNet2DConditionModel, ControlNetModel
<span style="color:#999988"><em># 从transformers库导入CLIPVisionModelWithProjection,用于视觉模型</em></span>
<span style="color:#333333"><strong>from</strong></span> transformers <span style="color:#333333"><strong>import</strong></span> CLIPVisionModelWithProjection
<span style="color:#999988"><em># 从PIL库导入Image,用于图像处理</em></span>
<span style="color:#333333"><strong>from</strong></span> PIL <span style="color:#333333"><strong>import</strong></span> Image
<span style="color:#999988"><em># 从huggingface_hub导入snapshot_download,用于下载预训练模型</em></span>
<span style="color:#333333"><strong>from</strong></span> huggingface_hub <span style="color:#333333"><strong>import</strong></span> snapshot_download
<span style="color:#999988"><em># 从ip_adapter导入StyleShot和StyleContentStableDiffusionControlNetPipeline,用于风格迁移</em></span>
<span style="color:#333333"><strong>from</strong></span> ip_adapter <span style="color:#333333"><strong>import</strong></span> StyleShot, StyleContentStableDiffusionControlNetPipeline
<span style="color:#999988"><em># 导入argparse库,用于解析命令行参数</em></span>
<span style="color:#333333"><strong>import</strong></span> argparse
<span style="color:#333333"><strong>def</strong></span> <span style="color:#990000"><strong>main</strong></span>(args):
<span style="color:#999988"><em># 设置基础模型路径和transformer块路径</em></span>
base_model_path = <span style="color:#dd1144">"runwayml/stable-diffusion-v1-5"</span>
transformer_block_path = <span style="color:#dd1144">"laion/CLIP-ViT-H-14-laion2B-s32B-b79K"</span>
<span style="color:#999988"><em># 设置设备为cuda,即GPU</em></span>
device = <span style="color:#dd1144">"cuda"</span>
<span style="color:#999988"><em># 根据命令行参数选择预处理器</em></span>
<span style="color:#333333"><strong>if</strong></span> args.preprocessor == <span style="color:#dd1144">"Lineart"</span>:
detector = LineartDetector()
styleshot_model_path = <span style="color:#dd1144">"Gaojunyao/StyleShot_lineart"</span>
<span style="color:#333333"><strong>elif</strong></span> args.preprocessor == <span style="color:#dd1144">"Contour"</span>:
detector = SOFT_HEDdetector()
styleshot_model_path = <span style="color:#dd1144">"Gaojunyao/StyleShot"</span>
<span style="color:#333333"><strong>else</strong></span>:
<span style="color:#333333"><strong>raise</strong></span> ValueError(<span style="color:#dd1144">"Invalid preprocessor"</span>)
<span style="color:#999988"><em># 如果模型路径不存在,则下载模型</em></span>
<span style="color:#333333"><strong>if</strong></span> <span style="color:#333333"><strong>not</strong></span> os.path.isdir(styleshot_model_path):
styleshot_model_path = snapshot_download(styleshot_model_path, local_dir=styleshot_model_path)
<span style="color:#0086b3">print</span>(<span style="color:#dd1144">f"Downloaded model to <span style="color:#333333">{styleshot_model_path}</span>"</span>)
<span style="color:#999988"><em># 下载基础模型和transformer块</em></span>
<span style="color:#999988"><em># weights for ip-adapter and our content-fusion encoder</em></span>
<span style="color:#333333"><strong>if</strong></span> <span style="color:#333333"><strong>not</strong></span> os.path.isdir(base_model_path):
base_model_path = snapshot_download(base_model_path, local_dir=base_model_path)
<span style="color:#0086b3">print</span>(<span style="color:#dd1144">f"Downloaded model to <span style="color:#333333">{base_model_path}</span>"</span>)
<span style="color:#333333"><strong>if</strong></span> <span style="color:#333333"><strong>not</strong></span> os.path.isdir(transformer_block_path):
transformer_block_path = snapshot_download(transformer_block_path, local_dir=transformer_block_path)
<span style="color:#0086b3">print</span>(<span style="color:#dd1144">f"Downloaded model to <span style="color:#333333">{transformer_block_path}</span>"</span>)
<span style="color:#999988"><em># 设置模型权重路径</em></span>
ip_ckpt = os.path.join(styleshot_model_path, <span style="color:#dd1144">"pretrained_weight/ip.bin"</span>)
style_aware_encoder_path = os.path.join(styleshot_model_path, <span style="color:#dd1144">"pretrained_weight/style_aware_encoder.bin"</span>)
<span style="color:#999988"><em># 初始化UNet2D模型和内容融合编码器</em></span>
unet = UNet2DConditionModel.from_pretrained(base_model_path, subfolder=<span style="color:#dd1144">"unet"</span>)
content_fusion_encoder = ControlNetModel.from_unet(unet)
<span style="color:#999988"><em># 从预训练模型创建管道</em></span>
pipe = StyleContentStableDiffusionControlNetPipeline.from_pretrained(base_model_path, controlnet=content_fusion_encoder)
styleshot = StyleShot(device, pipe, ip_ckpt, style_aware_encoder_path, transformer_block_path)
<span style="color:#999988"><em># 打开风格图像</em></span>
style_image = Image.<span style="color:#0086b3">open</span>(args.style)
<span style="color:#999988"><em># 处理内容图像</em></span>
content_image = cv2.imread(args.content)
content_image = cv2.cvtColor(content_image, cv2.COLOR_BGR2RGB)
content_image = detector(content_image)
content_image = Image.fromarray(content_image)
<span style="color:#999988"><em># 生成图像</em></span>
generation = styleshot.generate(style_image=style_image, prompt=[[args.prompt]], content_image=content_image)
<span style="color:#999988"><em># 保存生成的图像</em></span>
generation[<span style="color:teal">0</span>][<span style="color:teal">0</span>].save(args.output)
<span style="color:#333333"><strong>if</strong></span> __name__ == <span style="color:#dd1144">"__main__"</span>:
<span style="color:#999988"><em># 解析命令行参数</em></span>
parser = argparse.ArgumentParser()
parser.add_argument(<span style="color:#dd1144">"--style"</span>, <span style="color:#0086b3">type</span>=<span style="color:#0086b3">str</span>, default=<span style="color:#dd1144">"style.png"</span>)
parser.add_argument(<span style="color:#dd1144">"--content"</span>, <span style="color:#0086b3">type</span>=<span style="color:#0086b3">str</span>, default=<span style="color:#dd1144">"content.png"</span>)
parser.add_argument(<span style="color:#dd1144">"--preprocessor"</span>, <span style="color:#0086b3">type</span>=<span style="color:#0086b3">str</span>, default=<span style="color:#dd1144">"Contour"</span>, choices=[<span style="color:#dd1144">"Contour"</span>, <span style="color:#dd1144">"Lineart"</span>])
parser.add_argument(<span style="color:#dd1144">"--prompt"</span>, <span style="color:#0086b3">type</span>=<span style="color:#0086b3">str</span>, default=<span style="color:#dd1144">"text prompt"</span>)
parser.add_argument(<span style="color:#dd1144">"--output"</span>, <span style="color:#0086b3">type</span>=<span style="color:#0086b3">str</span>, default=<span style="color:#dd1144">"output.png"</span>)
args = parser.parse_args()
main(args)
</code></span></span></span>
StyleShot是目前图像风格迁移领域的SOTA方法。
下面我们以基于风格图像驱动的图像风格迁移为例,我们将下面的图片作为内容图像进行实验:
将下面三张图片作为风格图像与内容图像融合:
得到了下面的结果:
可以看到,模型取得了非常好的融合效果,既保留了内容图像的特征,又完美融合了风格图像的特点。