作者 | sharmistha chatterjee
来源 | Medium
编辑 | 代码医生团队
介绍
元学习研究和开放源代码库提供了一种通过标准化基准和各种可用数据集对不同算法进行详细比较的方法,从而可以完全控制此评估的复杂性。但是,大多数在线可用的代码都有以下限制:
为了解决这个限制,Google AI引入了Torchmeta,这是一个基于PyTorch深度学习框架构建的库,可以对多个数据集的元学习算法进行无缝且一致的评估。为了解释Torchmeta,使用了一些初步的概念,例如DataLoader和BatchLoader,可以解释为:
DataLoader是一种通用实用程序,可用作应用程序数据获取层的一部分,以通过批处理和缓存在各种远程数据源(例如数据库或Web服务)上提供简化且一致的API。
批处理是DataLoader的主要功能。批处理加载函数接受键列表,并返回一个Promise,该Promise解析为值列表DataLoader合并在单个执行框架内发生的所有单个加载(一旦解决了包装承诺,即执行),然后是具有全部功能的批处理函数要求的钥匙。
数次学习的数据加载器
快速学习很少能具有使用先验知识快速推广具有有限监督经验的新任务的能力。快速学习分为三类:
Torchmeta在其库中具有以下内容。
为了平衡几次学习中固有的数据缺乏,元学习算法从称为元训练集的数据集D-meta = {D1,…,Dn}中获取一些先验知识。在几次学习中,每个元素Di仅包含几个输入/输出对(x,y),其中y取决于问题的性质。由于这些数据集可以包含过去执行的不同任务的示例。Torchmeta提供了一种解决方案,可以使用最少的问题特定组件来自动创建每个数据集Di。
极少回归
少有的回归问题中的大多数是通过不同功能的输入和输出之间的简单回归问题,其中每个功能对应一个任务。这些功能被参数化以允许任务之间的可变性,同时在各个任务之间保持不变的“主题”。例如,这些函数可以是形式为fi(x)= ai sin(x + bi)的正弦波,其中a和b在某些范围内变化。
在Torchmeta中,元训练集继承自名为MetaDataset的对象,每个数据集Di(i = 1,...,n,用户定义n)对应于该函数的特定参数选择,所有在元训练集创建时采样一次的参数。一旦知道了函数的参数,我们就可以通过在给定范围内对输入进行采样并将其提供给函数来创建数据集。
少拍分类
对于少有的分类问题,数据集Di的创建通常遵循两个步骤:
下图展示了元学习器的作用,在元测试中,另一个不相交的任务集Tt〜p(T)(p(T)->任务T的分布)用于测试元学习者。每个Tt都作用于N个数据集,其中数据集= {D train Tt,D test Tt}。学习者从训练集D train Tt和测试集D test Tt上学习。Tt的平均损耗被视为元学习测试误差。
训练和测试数据集拆分
为了实例化基于Mini Imagenet的5向1发分类问题的元训练集,使用:
数据集= torchmeta.datasets.MiniImagenet(“数据”,num_classes_per_task = 5,meta_train = True,下载= True)
数据集= torchmeta.transforms.ClassSplitter(数据集,num_train_per_class = 1,num_test_per_class = 15,shuffle = True)
除了元训练集之外,大多数基准测试还提供了元测试集,用于对元学习算法的总体评估(以及可能的元验证集)。创建MetaDataset对象时,可以使用meta_test = True(或meta_val = True)而不是meta_train = True来选择这些不同的元数据集。
元数据加载器
可以迭代一些镜头分类和回归问题中的元训练集对象,以生成PyTorch数据集对象,该对象包含在任何标准数据管道(与DataLoader组合)中。
元学习算法在批次任务上运行效果更好。与在PyTorch中将示例与DataLoader一起批处理的方式类似,Torchmeta公开了一个MetaDataLoader,该对象可以在迭代时产生大量任务。这样的元数据加载器能够输出一个大张量,其中包含批处理中来自不同任务的所有示例,如下所示:
数据集= torchmeta.datasets.helpers.miniimagenet(“数据”,镜头= 1,方式= 5,meta_train = True,下载= True)
数据加载器= torchmeta.utils.data.BatchMetaDataLoader(数据集,batch_size = 16)
元学习模块
下图显示了使用学习者的损失和错误信号进行元学习的顺序步骤。
元学习者的学习步骤:来源:
https : //arxiv.org/pdf/1904.05046.pdf
在元学习中,PyTorch中的模型是由称为模块的基本组件创建的,该基本组件等效于神经网络中包含该层的计算图及其参数的一层。这些模块将其参数视为其计算图的组成部分,足以训练带有反向传播的模型。
但是,一些元学习算法需要通过参数更新(例如梯度更新)进行反向传播,以进行元优化(或“外环”),因此涉及高阶微分。
因此,适应PyTorch中的现有模块至关重要,以便它们可以处理任意计算图来替代这些参数。因此,Torchmeta扩展了现有模块,并保留了提供新参数作为附加输入的选项。这些新对象称为MetaModule,它们的默认行为(即,未指定任何其他参数)等同于它们的PyTorch对应对象。否则,如果指定了额外的参数(例如,梯度下降的一步的结果),则MetaModule会将它们视为计算图的一部分,并且反向传播将按预期进行。
下面的代码演示了如何从Torchmeta的现有数据集中生成训练,验证和测试元数据集。
from torchmeta.datasets import Omniglot, MiniImagenet, CIFARFS, FC100, TieredImagenet, TCGA
from torchmeta.transforms import Categorical, ClassSplitter, Rotation
from torchvision.transforms import Compose, Resize, ToTensor
from torchmeta.utils.data import BatchMetaDataLoader
dataset = Omniglot("data",
# Number of ways
num_classes_per_task=5,
# Resize the images to 28x28 and converts them to PyTorch tensors (from Torchvision)
transform=Compose([Resize(28), ToTensor()]),
# Transform the labels to integers (e.g. ("Glagolitic/character01", "Sanskrit/character14", ...) to (0, 1, ...))
target_transform=Categorical(num_classes=5),
# Creates new virtual classes with rotated versions of the images (from Santoro et al., 2016)
class_augmentations=[Rotation([90, 180, 270])],
meta_train=True,
download=True)
dataset = ClassSplitter(dataset, shuffle=True, num_train_per_class=5, num_test_per_class=15)
dataloader = BatchMetaDataLoader(dataset, batch_size=16, num_workers=4)
for batch in dataloader:
train_inputs, train_targets = batch["train"]
print('Train inputs shape: {0}'.format(train_inputs.shape)) # (16, 25, 1, 28, 28)
print('Train targets shape: {0}'.format(train_targets.shape)) # (16, 25)
test_inputs, test_targets = batch["test"]
print('Test inputs shape: {0}'.format(test_inputs.shape)) # (16, 75, 1, 28, 28)
print('Test targets shape: {0}'.format(test_targets.shape)) # (16, 75)
下图显示了下载后从Omnichlot和MiniImagenet从Torchmeta的数据集中生成的元学习数据集。
此处Omniglot数据集包含50个字母。将其分为30个字母的背景集和20个字母的评估集。在将背景大小调整为28x28张量后,应该使用背景集学习有关字符的一般知识(例如,特征学习,元学习)。此外,将标签传送到整数Glagolitic / character01”,“ Sanskrit / character14”,……)到(0,1,..,n)。
MiniImageNet包含60,000个84x84 RGB图像,每个类别600个图像。使用Torchmeta,可以生成HDF5格式的元学习数据集。
Torchmeta具有以HDF5格式下载数据集的功能,该功能允许:
用于定义Torchmeta数据集(例如Omniglot)的元学习参数的TieredImagenetClassDataset包含来自34个类别的图像。元训练/验证/测试拆分超过20/6/8个类别。每个类别包含10到30个类别。按类别划分(而不是按类别划分)可确保所有训练课程与测试课程完全不同(不同于Mini-Imagenet)。它带有以下一组参数,这些参数定义了训练,验证和测试数据集的划分以及应用于它们的转换和增强技术
num_classes_per_task(int):每个任务的类数,对应于“ N向”分类中的“ N”。
meta_train:bool(`False`):使用数据集的元火车拆分。如果设置为True,则必须将参数meta_val和meta_test设置为False。这三个参数中的一个必须正确设置为“ True”。
meta_val:bool(`False`):使用数据集的元验证拆分。如果设置为True,则参数meta_train和metatest必须设置为False。这三个参数中只有一个必须设置为“ True”。
meta_test:bool(`False`):使用数据集的元测试拆分。如果设置为True,则参数meta_train和meta_val必须设置为False。这三个参数中只有一个必须设置为“ True”。
meta_split:{'train','val','test'}中的字符串,可选要使用的拆分名称,如果所有三个都设置为False,则覆盖参数meta_train,metaval和metatest。
transform:可调用的,可选的:获取“ PIL”图像并返回转换后版本的函数/转换。
target_transform:可调用,可选:接受目标并返回转换版本的函数/转换。
dataset_transform:可调用,可选:函数/转换,它接受数据集(即任务),并返回其转换后的版本。-> torchmeta.transforms.ClassSplitter()。
class_augmentations:可调用的,可选的列表:使用新类扩展数据集的函数列表。这些类是现有类的转换。
download:bool(默认值:False)如果为True,则下载pickle文件并处理根目录(位于tieredimagenet文件夹下)中的数据集。如果数据集已经可用,则不会再次下载/处理数据集。
结论
在此博客中,了解了Google AI最新发布的库Torchmeta,它提供了哪些功能以及可以解决什么样的元学习问题。可以浏览其他PyTorch元学习库,例如元Agonistic机器学习,以学习可以快速适应新任务的网络初始化。
https://github.com/dragen1860/MAML-Pytorch
如下图所示,在Torchmeta中很少有镜头学习可用于图像分类。
参考
https://github.com/markdtw/meta-learning-lstm-pytorch
https://arxiv.org/abs/1909.06576
https://docs.graphene-python.org/en/latest/execution/dataloader/