首页
学习
活动
专区
工具
TVP
发布
精选内容/技术社群/优惠产品,尽在小程序
立即前往

使用Cleanlab、PCA和Procrustes可视化ViT微调

与传统的卷积神经网络不同,vit使用最初设计用于自然语言处理任务的Transformers 架构来处理图像。微调这些模型以获得最佳性能可能是一个复杂的过程。

下面是使用动画演示了在微调过程中嵌入的变化。这是通过对嵌入执行主成分分析(PCA)来实现的。这些嵌入是从处于不同微调阶段的模型及其相应的检查点生成的。

在本文中,我们将介绍如何创建这样一个动画,主要包括:微调、创建嵌入、异常值检测、PCA、Procrustes、创建动画。

微调

第一步是对预训练好的ViT模型进行微调。为了简单起见我们使用了CIFAR-10数据集,其中包含6万张图像,分为10个不同的类别

微调代码很简单,我们这里主要就是在微调时增加日志记录

from transformers import TrainerCallback

class PrinterCallback(TrainerCallback):

  def on_log(self, args, state, control, logs=None, **kwargs):

      _ = logs.pop("total_flos", None)

      if state.is_local_process_zero:

          if len(logs) == 3: # skip last row

              with open("log.csv", "a") as f:

                  f.write(",".join(map(str, logs.values())) + "\n")

通过在TrainingArguments中设置save_strategy="step"和一个较低的save_step值来增加检查点的保存间隔是很重要的,这样可以确保动画有足够的检查点。动画中的每一帧对应一个检查点。在训练期间为每个检查点和CSV文件创建一个文件夹

创建嵌入

我们使用Transformers库中的AutoFeatureExtractor和autommodel来使用不同的模型检查点中生成嵌入。

每个嵌入是一个768维向量,测试图像总计有10,000个。生成的这些嵌入与检查点存储在同一个文件夹中

提取离群值

我们可以使用Cleanlab库提供的OutOfDistribution类,根据每个检查点的嵌入来识别离群值,可以识别出动画的前10个离群值。这些值也就是我们所说的分类错误的特征,对我们研究模型是非常有用的

from cleanlab.outlier import OutOfDistribution

def get_ood(sorted_checkpoint_folder, df):

...

ood = OutOfDistribution()

ood_train_feature_scores = ood.fit_score(features=embedding_np)

df["scores"] = ood_train_feature_scoresPCA和Procrustes

使用scikit-learn包的主成分分析(PCA),我们通过将768维向量减少到2维来可视化二维空间中的嵌入。当为每个时间步重新计算PCA时,由于轴翻转或旋转,可能会出现动画中的大的条约,这样显示效果很不好。所以为了解决这个问题,我们还从SciPy包中应用了一个额外的Procrustes Analysis,以几何方式将每一帧转换为最后一帧,这只涉及平移、旋转和均匀缩放。这使得动画中的过渡更加平滑。

from sklearn.decomposition import PCA

from scipy.spatial import procrustes

def make_pca(sorted_checkpoint_folder, pca_np):

...

embedding_np_flat = embedding_np.reshape(-1, 768)

pca = PCA(n_components=2)

pca_np_new = pca.fit_transform(embedding_np_flat)

_, pca_np_new, disparity = procrustes(pca_np, pca_np_new)使用Spotlight进行检查

在完成整个动画之前,可以在Spotlight中进行最后的检查。我们用第一个和最后一个检查点来执行嵌入生成、PCA和异常值检测。在Spotlight中加载结果DataFrame如下:

创建动画

通过使用make_pca(…)和get_ood(…)函数对每个模型的检查点创建一个图表,它们分别生成代表嵌入的2D点并提取前8个异常值。2D点用对应于它们各自类别的颜色绘制。异常值是根据他们的分数排序的,最后的训练损失从CSV文件加载并绘制的线形图。

最后,图像使用imageio或类似的库编译成GIF。

总结

本文介绍了如何创建视ViT模型的微调过程可视化。我们通过生成和分析嵌入、可视化结果以及创建将这些元素结合在一起的动画的步骤。

创建这样的动画不仅有助于理解微调ViT模型的复杂过程,而且还可以作为向他人传达这些概念的强大工具。

本文的源代码:

https://github.com/Renumics/spotlight/blob/main/playbook/stories/making_of_embeddings_animation.ipynb

作者:Markus Stoll

  • 发表于:
  • 原文链接https://page.om.qq.com/page/ODRlzQ7HxhkhmxlImCzjG6Sw0
  • 腾讯「腾讯云开发者社区」是腾讯内容开放平台帐号(企鹅号)传播渠道之一,根据《腾讯内容开放平台服务协议》转载发布内容。
  • 如有侵权,请联系 cloudcommunity@tencent.com 删除。

扫码

添加站长 进交流群

领取专属 10元无门槛券

私享最新 技术干货

扫码加入开发者社群
领券