连接图像和文本,更多的多模态文章可以看博主整理过的系列(跨界出圈 | 谈谈BERT跨模态预训练),本篇文章主要整理一下OpenAI发表的2篇文章。其中CLIP 能够完成图像与文本类别的匹配,DALL·E 则可以直接基于文本描述生成图像,且性能十分优异。
首先是CLIP,直接看模型吧,分为三步:Contrastive Pretraning,Create dataset classifier from label text和use for zero-shot prediction。
第一部分的整体架构如上图,是图文匹配的双流分支,一边是图像编码器(如resnet50或者ViT等),另一边是文本编码器(如Transformer)得到特征,然后对一个batch的文图pair对数据算内积得到匹配矩阵,这个矩阵行方向就是对图像的分类器,而从文本角度看列方向也是类似的分类器。最后使对角线的蓝色部分概率最大化就行(使匹配的pair内积相似度最大化),对比学习[1]博主已经整理过 不做赘述。
这一步是主要是利用大量的训练数据(直接从网上得到的句子-图像对)得到特征的表示。接下来的两步是测试过程,流程如下图:
和训练阶段类似,首先将需要分类的图像经过编码器得到特征,然后把目标任务数据集的每一个标签转为一个对应的文本(因为CLIP的Pretraning数据是句子,对于分类任务的单词label是不适用的),如上图中的 dog 这一label会改造成 "A photo of a dog",并且dog这个词被mask,尝试通过模型算内积相似度来预测出这个词,也就能做好分类了,由于是生成句子的感觉,所以其实它十分适合做zero-shot 的分类。
同时,基于 CLIP 还可以自由定义自己的分类器!也就是说可以很方便的利用CLIP和很多工作结合,比如等会要整理的 DALL-E 中就用到了 CLIP来提特征。
简单看看CLIP里面的逻辑流程
def forward(self, image, text):
image_features = self.encode_image(image) #编码image
text_features = self.encode_text(text) #编码text
# norm一下特征
image_features = image_features / image_features.norm(dim=-1, keepdim=True)
text_features = text_features / text_features.norm(dim=-1, keepdim=True)
# 计算内积相似度logits
logit_scale = self.logit_scale.exp()
logits_per_image = logit_scale * image_features @ text_features.t()
logits_per_text = logit_scale * text_features @ image_features.t()
# shape = [global_batch_size, global_batch_size]
return logits_per_image, logits_per_text
然后是DALL-E模型,CLIP主要可以做分类检索等任务,而它则可以直接根据文本生成效果非常好的图像。motivation是目标是训练一个transformer进行自动建模,即将文本以及图片的tokens转换为单一的数据流,所以主要是需要考虑如何对2D的图片也转为单数据流。
也是直接看模型,如上图可以分为三个阶段:dVAE,Transformer和CLIP。
值得注意的一些trick:
还有大佬复现code:https://github.com/lucidrains/DALLE-pytorch
这个复现的库可直接调用训练,似乎非常好用,如果你有足够的卡那么pip一下即可:
pip install dalle-pytorch
import torch
from dalle_pytorch import CLIP
clip = CLIP(
dim_text = 512,
dim_image = 512,
dim_latent = 512,
num_text_tokens = 10000,
text_enc_depth = 6,
text_seq_len = 256,
text_heads = 8,
num_visual_tokens = 512,
visual_enc_depth = 6,
visual_image_size = 256,
visual_patch_size = 32,
visual_heads = 8
) #设置CLIP的参数
text = torch.randint(0, 10000, (4, 256))
images = torch.randn(4, 3, 256, 256)
mask = torch.ones_like(text).bool()
loss = clip(text, images, text_mask = mask, return_loss = True) #直接训练CLIP
loss.backward()
[1]
对比学习: https://nakaizura.blog.csdn.net/article/details/108941999
[2]
Vision Transformer: https://nakaizura.blog.csdn.net/article/details/113095927
- END -