首页
学习
活动
专区
工具
TVP
发布
精选内容/技术社群/优惠产品,尽在小程序
立即前往

双编码器的自然语言图像搜索

正文字数:5798  阅读时长:10 分钟

如何构建一个双编码器(也称为双塔)神经网络模型,以使用自然语言搜索图像。

作者 / Khalid Salama

1

介绍

该示例演示了如何构建一个双编码器(也称为双塔)神经网络模型,以使用自然语言搜索图像。该模型的灵感来自于Alec Radford等人提出的CLIP方法,其思想是联合训练一个视觉编码器和一个文本编码器,将图像及其标题的表示投射到同一个嵌入空间,从而使标题嵌入位于其描述的图像的嵌入附近。

这个例子需要TensorFlow 2.4或更高版本。此外,BERT模型需要TensorFlow Hub和TensorFlow Text,AdamW优化器需要TensorFlow Addons。这些库可以使用以下命令进行安装。

2

安装

3

准备数据

我们使用MS-COCO数据集来训练我们的双编码器模型。MS-COCO包含超过82,000张图片,每张图片至少有5个不同的标题注释。该数据集通常用image captioning任务,但我们可以重新利用图像标题对来训练双编码器模型进行图像搜索。

下载提取数据

首先,下载数据集,它由两个压缩文件夹组成:一个是图像,另一个是相关的图像标题。值得注意的是压缩后的图像文件夹大小为13GB。

处理并将数据保存到TFRecord文件中

你可以改变sample_size参数去控制将用于训练双编码器模型的多对图像-标题。在这个例子中,我们将training_size设置为30000张图像,约占数据集的35%。我们为每张图像使用2个标题,从而产生60000个图像-标题对。训练集的大小会影响生成编码器的质量,样本越多,训练时间越长。

创建用于训练和评估的 tf.data.Dataset

4

实时投影头

投影头用于将图像和文字嵌入到具有相同的维度的同一嵌入空间。

5

实现视觉编码器

在本例中,我们使用Keras Applications的Xception作为视觉编码器的基础。

6

实现文本编码器

我们使用TensorFlow Hub的BERT作为文本编码器

7

实现双编码器

为了计算loss,我们计算每个 caption_i和 images_j之间的对偶点积相似度作为预测值。caption_i和image_j之间的目标相似度计算为(caption_i和caption_j之间的点积相似度)和(image_i和image_j之间的点积相似度)的平均值。然后,我们使用交叉熵来计算目标和预测之间的损失。

8

训练双编码模型

在这个实验中,我们冻结了文字和图像的基础编码器,只让投影头进行训练。

值得注意的是使用 V100 GPU 加速器训练 60000 个图像标题对的模型,批量大小为 256 个,每个 epoch 需要 12 分钟左右。如果有2个GPU,则每个epoch需要8分钟左右。

训练损失的绘制:

9

使用自然语言查询搜索图像

我们可以通过以下步骤来检索对应自然语言查询的图像:

1. 将图像输入vision_encoder,生成图像的嵌入。

2. 将自然语言查询反馈给text_encoder,生成查询嵌入。

3. 计算查询嵌入与索引中的图像嵌入之间的相似度,以检索出最匹配的索引。

4. 查阅顶部匹配图片的路径,将其显示出来。

值得注意的是在训练完双编码器后,将只使用微调后的visual_encoder和text_encoder模型,而dual_encoder模型将被丢弃。

生成图像的嵌入

我们加载图像,并将其输入到vision_encoder中,以生成它们的嵌入。在大规模系统中,这一步是使用并行数据处理框架来执行的,比如Apache Spark或Apache Beam。生成图像嵌入可能需要几分钟时间。

检索相关图像

该例子中,我们通过计算输入的查询嵌入和图像嵌入之间的点积相似度来使用精确匹配,并检索前k个匹配。然而,在实时用例中,使用ScaNN、Annoy或Faiss等框架进行近似匹配是首选,以扩展大量图像。

将查询变量设置为你要搜索的图片类型。试试像 "一盘健康的食物", "一个戴着帽子的女人走在人行道上", "一只鸟坐在水边", 或 "野生动物站在田野里"。

评估检索质量

为了评估双编码器模型,我们使用标题作为查询。使用训练外样本图像和标题来评估检索质量,使用top k精度。如果对于一个给定的标题,其相关的图像在前k个匹配范围内被检索到,则算作一个真正的预测。

结束语

你可以通过增加训练样本的大小,训练更多的时期,探索其他图像和文本的基础编码器,设置基础编码器的可训练性,以及调整超参数,特别是softmax的temperature loss计算,获得更好的结果。

LiveVideoStackCon 2021 ShangHai

我们准备好全新的内容

在上海欢迎您的到来

LiveVideoStackCon 2021 上海站

北京时间:2021年4月16日-4月17日

  • 发表于:
  • 原文链接https://kuaibao.qq.com/s/20210226A01HIS00?refer=cp_1026
  • 腾讯「腾讯云开发者社区」是腾讯内容开放平台帐号(企鹅号)传播渠道之一,根据《腾讯内容开放平台服务协议》转载发布内容。
  • 如有侵权,请联系 cloudcommunity@tencent.com 删除。

扫码

添加站长 进交流群

领取专属 10元无门槛券

私享最新 技术干货

扫码加入开发者社群
领券