前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >PyTorch的元学习库:Torchmeta

PyTorch的元学习库:Torchmeta

作者头像
代码医生工作室
发布2020-03-04 11:16:40
1.6K0
发布2020-03-04 11:16:40
举报
文章被收录于专栏:相约机器人

Torchmeta是扩展和数据加载器的集合,用于在PyTorch中进行少量学习和元学习。Torchmeta在2019年全球PyTorch夏季黑客马拉松上获得了最佳表演奖。该库是开源的,可以尝试使用pip install torchmeta。

https://github.com/tristandeleu/pytorch-meta

什么是元学习?

当无法访问大量数据时会发生什么?毕竟,与当前形式的深度学习不同,人类只有几次演示就能非常快速,高效地学习执行新任务。仅通过少数几个训练示例就可以获得性能良好的模型尤其具有挑战性,因此需要一定水平的先验知识才能解决任务,从而在学习新知识时有效地“领先”。例如,领域专家可以例如通过正则化或体系结构选择将这种先验知识明确引入模型中。

或者,可以从过去的经验中获得此先验知识;这是元学习中采用的方法。在元学习中,利用从许多不同的元训练任务中获得的经验,目的是提高在新的下游任务上的表现和学习效率(即,必要的训练示例的数量)。

创建元学习的“健身房”

创建Torchmeta的动机是为了促进对不同数据集上的元学习算法进行评估,并尽可能减少更改。它的设计灵感来自OpenAI Gym,它通过提供适用于多种环境的通用界面,使强化学习变得更加容易。Gym作为标准工具的采用,使大多数开源项目都可以不受环境选择的影响,并且可以无缝测试多个环境。

同样,Torchmeta在统一的界面下引入了数据加载器,以处理各种标准的几次镜头分类和回归问题。从1.3版开始,Torchmeta中提供以下数据集:

  • 快速回归

-正弦波(Finn等,2017)

-谐波函数(Lacoste等,2018)

-正弦和直线(Finn等,2018)

  • 为数不多的镜头分类(图片分类)

- Omniglot(湖等人,2015年,2019)

- (迷你ImageNet 。Vinyals等人在2016年,。拉维等人,2017年)

-分层-ImageNet(Ren等人,2018。)

-CIFAR-FS(Bertinetto等,2018)

-Fewshot-CIFAR100(Oreshkin等,2018)

-加州理工学院-UCSD鸟类(Hilliard等,2018,Chen等,2019)

-双重和三重MNIST (2019年,星期日)

在Omniglot(左),Mini-ImageNet(中)和Caltech-UCSD Birds(右)上进行5次5拍学习问题的任务示例。

所有这些数据加载器都与PyTorch生态系统完全兼容,包括PyTorch DataLoader和torchvision软件包。根据相应的数据集随机生成一批任务,每个任务包含一个培训和一个测试数据集-这是元学习中的常见做法。尽管可以完全控制数据加载器的定义方式,但Torchmeta还包括适用于最受欢迎基准的帮助程序功能,以及文献中有用的默认值。

https://pytorch.org/docs/master/torchvision/

代码语言:javascript
复制
from torchmeta.datasets.helpers import omniglot
from torchmeta.utils.data import BatchMetaDataLoader

dataset = omniglot("data", ways=5, shots=5, test_shots=15, meta_train=True, download=True)
dataloader = BatchMetaDataLoader(dataset, batch_size=16, num_workers=4)

for batch in dataloader:
    train_inputs, train_targets = batch["train"]
    print('Train inputs shape: {0}'.format(train_inputs.shape))    # (16, 25, 1, 28, 28)
    print('Train targets shape: {0}'.format(train_targets.shape))  # (16, 25)

    test_inputs, test_targets = batch["test"]
    print('Test inputs shape: {0}'.format(test_inputs.shape))      # (16, 75, 1, 28, 28)
    print('Test targets shape: {0}'.format(test_targets.shape))    # (16, 75)

一个为5次5通道Omniglot数据集创建数据加载器的最小示例。数据加载器加载一批随机生成的任务,并将所有样本串联到一个张量中。

元学习模块

除了数据加载器之外,Torchmeta还提供了PyTorch的扩展nn.Module,称为MetaModule,以简化某些元学习算法的实现。这些元模块使可以选择使用完整的计算图手动指定模块的参数。例如,这允许通过更新参数进行反向传播,这些参数是基于梯度的元学习方法的关键成分(Finn等人,2017 ; Finn,2018 ; Grant等人,2018 ; Lee等人等人,2019年 ; Raghu等人,2019年)和各种混合方法(Rusu等人,2018年 ; Zintgraf等人,2019年)。

MetaLinear元模块的插图,nn.Linear的扩展。左:MetaLinear元模块的实例化。中:默认行为,等效于nn.Linear。右:具有额外参数的行为(此处为一步式渐变更新,Finn等人,2017年)。渐变表示为虚线箭头。

默认情况下(即没有额外的参数),元模块的行为与其在PyTorch中对应的模块相同。因此,创建与这些元学习方法兼容的模型对于Torchmeta来说非常自然,并且只需对现有PyTorch模型进行最小的更改即可。还可以将元模块与标准nn.Module实例进行交错,以仅对模型的某些部分进行快速适应(Raghu等人,2019)。

代码语言:javascript
复制
import torch.nn as nn
from torchmeta.modules import (MetaModule, MetaSequential,
                               MetaConv2d, MetaLinear)
from torchmeta.modules.utils import get_subdict

class Model(MetaModule):
    def __init__(self, in_channels, num_classes):
        super(Model, self).__init__()
        self.features = MetaSequential(MetaConv2d(in_channels, 64, 3),
                                       nn.ReLU(),
                                       nn.MaxPool2d(2))
        self.classifier = MetaLinear(64, num_classes)

    def forward(self, inputs, params=None):
        features = self.features(inputs,
                                 params=get_subdict(params, 'features'))
        logits = self.classifier(features.view((inputs.size(0), -1)),
                                 params=get_subdict(params, 'classifier'))
        return logits

提高元学习研究的可重复性

由于缺乏文献中所使用的某些数据集的标准,因此元学习的可重复性可能非常具有挑战性,尤其是在数据加载方面。例如,虽然Vinyals等人介绍了Mini-ImageNet数据集。(2016),Ravi&Larochelle(2017)使用的分割现在已被社区广泛接受为官方数据集。到目前为止,对于某些数据集(例如CUB),这种情况仍然存在。很难跟踪应评估元学习算法的“正确”版本。

借助众多速记学习和元学习数据集以及标准拆分,Torchmeta的目标是提供所有必要的工具,以使元学习算法的开发和可重复性尽可能容易且可访问。

结论

PyTorch中元学习的未来是光明的,并且最近发布了许多伟大的开源项目。特别要提到两个,learn2learn提供了一些标准元学习算法的实现,而更高版本的则是一个库,可以对现有PyTorch模型进行高阶优化。Torchmeta通过为各种数据集提供统一的接口以及一组简化元学习算法开发的工具,很好地补充了这些其他库。这使得在不同基准上对这些方法的评估变得无缝,因此是在元学习中更好地再现研究的关键一步。

要了解有关Torchmeta的更多信息,请查看项目存储库中可用的示例以及MAML的此实现,以更详细地展示Torchmeta的所有功能。

https://github.com/tristandeleu/pytorch-meta/tree/master/examples

https://github.com/tristandeleu/pytorch-maml

本文参与 腾讯云自媒体同步曝光计划,分享自微信公众号。
原始发表:2020-02-24,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 相约机器人 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档