首页
学习
活动
专区
工具
TVP
发布

模型跨界成潮流?OpenAI用GPT-2做图像分类,实现SOTA性能

图像领域的 GPT 模型终于来了!OpenAI 推出了用于图像分类的模型 iGPT,该模型生成的特征在多个分类数据集上实现了当前 SOTA 性能,并且实现了良好的图像补全效果。

无监督和自监督学习,或者无人工标注数据的学习,这些都是机器学习领域长期存在的挑战。近期,BERT、GPT-2、RBERTa、T5 等 Transformer 模型以及其他变体在一系列语言任务上实现了最佳性能。然而,在生成用于图像分类的强特征方面却始终没有出现性能强大的模型。

这是因为,与 GPT-2 和 BERT 这类依赖于词预测的无监督学习算法相比,像素序列无法清楚地包含它们所属图像的标签。

近日,OpenAI 发布了一项新研究,旨在探索在图像上训练 GPT-2 的性能以及无监督准确率表现。研究者表示,BERT 和 GPT-2 等 Transformer 模型是域不可知的,这意味着它们可以直接应用于任何形式的 1D 序列。

OpenAI 研究者在图像上训练 GPT-2(这些图像被分解为长像素序列),他们称该模型称为 iGPT。结果发现这种模型似乎能够理解物体外观和类别等 2D 图像特征。iGPT 生成的各种一致性图像样本可以证明这一点,即使没有人为标签的指导。

论文地址:

https://cdn.openai.com/papers/Generative_Pretraining_from_Pixels_V2.pdf

GitHub 地址:

https://github.com/openai/image-gpt

项目主页:

https://openai.com/blog/image-gpt/

iGPT 缘何能够成功呢?这是因为,在下一像素预测(next pixel prediction)上训练的足够大的 transformer 模型最终可能学会生成具有清晰可识别物体的样本。一旦学会了生成此类样本,那么通过「合成分析」,iGPT 将知道目标类别。实验表明,iGPT 模型的特征在大量的分类数据集上实现了当前 SOTA 性能,以及在 ImageNet 数据集上实现了接近 SOTA 的无监督准确率。

我们先直观地看一下 iGPT 的效果。下图展示了,在 CIFAR-10、CIFAR-100、STL-10 和 ImageNet 数据集上,iGPT 与当前最佳非 iGPT 模型的性能对比情况:

为了突出生成序列建模作为通用无监督学习算法的潜力,该研究特意使用了与 GPT-2 相同的 Transformer 架构,因而该模型需要更多计算才能生成与顶级无监督卷积网络相当的特征。结果表明,当面对一个正确模型先验未知的新领域时,大型 GPT-2 模型可以学习优秀特征,并且不需要领域特定的架构设计选择。

iGPT 处理图像任务的效果如何

iGPT 可实现较好的图像补全效果。不管是动物、建筑物、风景、运动场面、艺术作品,甚至表情包,它都可以实现不错的补全效果。

下图展示了 iGPT 生成的图像样本。研究者在采样过程中使用 temperature 1,且未使用束搜索或核采样(nucleus sampling)等 trick。研究者表示,以下所有样本均未经过挑选。

我们可以看出,几乎所有生成图像均包含清晰可识别的物体。

方法

OpenAI 研究人员提出的方法包含两个阶段:预训练和微调。

在预训练阶段中,研究人员探索自回归目标和 BERT 目标,并使用序列 Transformer 架构来预测像素,而非语言 token。

如下图所示,该方法首先对原始图像进行预处理,将其调整为低分辨率和 1D 序列;然后在自回归下一像素预测或掩码像素预测这两个预训练目标中选择一个;最后,利用 linear probe 或微调,对这些目标学得的表征进行评估。

该研究在 ImageNet 上训练了三个 transformer 模型:iGPT-S、iGPT-M 和 iGPT-L,它们分别包含 76M、455M 和 14 亿参数。此外,研究者还基于 ImageNet 和网络图片的混合数据训练了 iGPT-XL,它包含 68 亿参数。

由于使用密集注意力建模长序列的计算成本较高,因此该研究选择使用较低的分辨率:32x32、48x48 和 64x64。

分辨率继续降低可以进一步减少计算成本,但是之前研究表明在这种情况下,人类的图像分类能力会急剧下降。因此,该研究受早期颜色显示调色板(color display palettes)的启发,创建了 9-bit 颜色调色板来表示像素。使用该调色板可以得到长度仅为标准 (R, G, B) 1/3 的输入序列,同时还能有效编码颜色。

实验结果

该研究使用两种方法来评估模型性能,二者均涉及下游分类任务。

方法 1:linear probe,即使用训练好的模型从下游数据集图像中提取特征,然后将 logistic 回归与标签进行拟合;

方法 2:微调,基于下游数据集微调整个模型。

在没有潜变量的生成模型中,哪种表征效果最好?

下一像素预测任务与图像分类并不明显相关,最后一层的特征可能无法最好地预测物体类别。

该研究的实验结果表明,特征质量先是显著提升,然后逐渐放缓。这表明 transformer 生成模型通过以下两个阶段运行:

第一阶段:每个位置从周围语境中收集信息,以构建语境化图像特征;

第二阶段:使用语境化特征解决下一像素预测任务。

下图表明,特征质量严重依赖于模型层数。并且,与监督模型相反,下图中这些生成模型的最优特征出现在网络中段。

更好的生成模型学到的表征也更好

该研究执行的另一项实验试图在生成性能和特征质量之间建立联系。实验结果表明,模型规模的扩大和训练迭代次数的增加会带来更好的生成性能,而这可以直接转换为更好的特征质量。

在 CIFAR 和 STL-10 上的 Linear Probe 效果

研究者评估了不同模型使用 linear probe 在 CIFAR-10、CIFAR-100 和 STL-10 数据集上的性能,发现该研究提出的方法优于其他监督和无监督迁移算法。甚至在完全微调的设置下,iGPT 的性能仍具备竞争力。

在 ImageNet 上的 Linear Probe 效果

该研究使用 linear probe 在 ImageNet 上进行性能评估。基于 48x48 图像训练的 iGPT-L,使用 1536 个特征得到的最优层 top-1 准确率达到 65.2%,超过了 AlexNet。

下表展示了 iGPT 和当前最优自监督模型的 linear probe 准确率对比情况:

BERT

由于像 BERT 这样的遮蔽语言模型(masked language models)在大多数语言任务上的性能都优于生成模型,因此该研究也在图像模型上评估了 BERT 的性能。

他们没有按照在之前所有像素的基础上预测下一个像素的方式来训练模型,而是遮蔽掉了 15% 的像素,然后训练模型基于未遮蔽的像素进行预测。研究者发现,尽管 BERT 模型的 linear probe 性能明显较差,但在微调方面 BERT 模型仍然表现亮眼:

在低数据 CIFAR-10 分类任务上的性能

如下表所示,研究者在这一子域的竞争性基准上对 iGPT-L 进行了评估,结果发现基于非增强图像特征的简单 linear probe 表现优于 Mean Teacher 和 MixMatch,但弱于 FixMatch。

iGPT 的局限性

尽管该研究表明 iGPT 能够学习强大的图像特征,但是该方法仍存在很大的局限性。

由于该研究采用的是用于语言任务的 GPT-2 的通用序列 Transformer,所以需要大量计算:iGPT-L 在 V100 上大约需要训练 2500 天,而性能类似的 MoCo 模型大约需要训练 70 天。

此外,该研究用 Transformer 对低分辨率输入建模,而大多数自监督结果采用的是基于卷积的编码器,这些编码器可以轻松地处理高分辨率输入。可能需要一种新的架构,例如与域无关的多尺度 Transformer,以实现进一步扩展。

考虑到这些局限性,该研究工作主要是概念证明,证明了基于 Transformer 的大型语言模型在一些新领域中可以学习优秀的无监督表征,而无需硬编码领域的知识。但是,训练这些模型需要大量的资源成本,而基于卷积神经网络的方法又具有更高的准确率,这让这些表征在视觉领域中无法实际应用。

最后,生成模型可能会呈现出偏差,这些偏差是模型训练所用的数据造成的。这些偏差中有许多都是有用的,例如假设棕色和绿色像素的部分代表叶子上覆盖的分支,然后用这种偏差继续生成图像。

但是,从公平性和代表性方面考虑,有些偏差是有害的。例如,如果模型建立了一种偏向男性的科学家视觉观念,那么它很有可能一直用男性而不是混合性别的人来完成科学家图像。研究者希望开发者能够更加关注系统的输入数据,并更好地了解输入数据与训练模型中偏差的关系。

结论

研究者表示,通过 2D 知识与 scale 之间的权衡,并从网络中部选择预测特征,序列 Transformer 的性能可以与顶级无监督图像分类卷积网络相媲美。

此外,值得注意的是,研究者将 GPT-2 语言模型直接用于图像生成,也得到了相应的实验结果。该结果表明,得益于简单性和通用性,基于足够计算量的序列 Transformer 可能是多个领域中学习优秀特征的有效方法。

  • 发表于:
  • 原文链接https://kuaibao.qq.com/s/20200618A0EC9B00?refer=cp_1026
  • 腾讯「腾讯云开发者社区」是腾讯内容开放平台帐号(企鹅号)传播渠道之一,根据《腾讯内容开放平台服务协议》转载发布内容。
  • 如有侵权,请联系 cloudcommunity@tencent.com 删除。

相关快讯

扫码

添加站长 进交流群

领取专属 10元无门槛券

私享最新 技术干货

扫码加入开发者社群
领券