专栏首页相约机器人三四行代码打造元学习核心,PyTorch元学习库L2L现已开源

三四行代码打造元学习核心,PyTorch元学习库L2L现已开源

元学习似乎一直比较「高级」,毕竟学习如何学习这个概念听起来就很难实现。在本文中,我们介绍了这两天新开源的元学习库 learn2learn,它是用 PyTorch 写的,只需要三四行代码就能构建元学习最为核心的部分。

learn2learn 是一个用于实现元学习的 Pytorch 库,我们只需要加几行高层 API,就能为一般的机器学习流程添加元学习能力。例如在元学习 MNIST 案例中,我们可以用 PyTorch 构建整个流程,但只要加上三行 L2L 代码就能打造元学习模型。这三行代码只干三件事:获取元数据集、生成元学习任务、定义元学习模型。

  • 项目地址:https://github.com/learnables/learn2learn

元学习的目标是让智能体学习如何学习,也就是说,我们希望智能体能够在解决更多问题的过程中成为更好的学习器。例如,下图展示的智能体正在学习如何跑步,尽管它只会更新一个参数。

L2L 有什么特性

L2L 是一个元学习库,可以为用户提供 3 个级别的功能。在最高级别上,它有很多使用元学习算法在大量数据集/环境上训练的示例。在中间级别上,它为若干流行的元学习算法提供了功能接口以及便于加载其他数据集的数据加载器。在最低级别上,它为模块提供了可扩展功能。

L2L 的一些特性包括:

  • 模块化 API:使用这个库中的底层工具实现你自己的训练循环;
  • 提供多个元学习算法(如 MAML、FOMAML、MetaSGD、ProtoNets、DiCE);
  • 具有统一 API 的任务生成器,兼容 torchvision、torchtext、torchaudio 和 cherry;
  • 提供标准化的视觉(Omniglot、mini-ImageNet)、强化学习(Particles、Mujoco)甚至文本(新闻分类)元学习任务;
  • 100% 兼容 PyTorch——使用你自己的模块、数据集或库。

最后,整个 L2L 库都是由 PyTorch 写的,因此它的源代码并不难理解,我们可以通过项目的源码学习怎样从底层实现元学习算法。

L2L 实现 MAML 元学习算法的局部源代码,它的源码拥有大量的注释,可以帮助理解实现过程。

示例代码

下面我们来看看 learn2learn 到底该如何学习一个能实现 MNIST 分类任务的模型,它使用非常高层的应用,因此理解起来很容易。

如下代码所示,总体而言,整个过程可以分为导入数据、定义元学习任务、定义元学习模型与最优化方法、在元学习任务内不同的学习器适配不同的数据,最后就是标准的损失计算与模型更新了。

import learn2learn as l2l

mnist = torchvision.datasets.MNIST(root="/tmp/mnist", train=True)

mnist = l2l.data.MetaDataset(mnist)
task_generator = l2l.data.TaskGenerator(mnist,
                                        ways=3,
                                        classes=[0, 1, 4, 6, 8, 9],
                                        tasks=10)
model = Net()
maml = l2l.algorithms.MAML(model, lr=1e-3, first_order=False)
opt = optim.Adam(maml.parameters(), lr=4e-3)

for iteration in range(num_iterations):
    learner = maml.clone()  # Creates a clone of model
    adaptation_task = task_generator.sample(shots=1)

    # Fast adapt
    for step in range(adaptation_steps):
        error = compute_loss(adaptation_task)
        learner.adapt(error)

    # Compute evaluation loss
    evaluation_task = task_generator.sample(shots=1,
                                            task=adaptation_task.sampled_task)
    evaluation_error = compute_loss(evaluation_task)

    # Meta-update the model parameters
    opt.zero_grad()
    evaluation_error.backward()
    opt.step()

整个 API 非常高层,只需要很少的代码量就能完成模型。但与此同时,L2L 库还提供了中层和底层方面的 API,它允许我们做更多定制化的修改。更多的例子读者可以在 GitHub 中查阅,其示例模型分为强化学习、文本处理和视觉模型三方面:

如果读者也想要试试这个库,那么直接在命令行中运行 pip install learn2learn 就行了,剩下的再看看文档和教程,就可以快速学会怎样使用元学习。

  • 文档地址:http://learn2learn.net/docs/learn2learn/
  • 教程地址:http://learn2learn.net/tutorials/getting_started/

本文分享自微信公众号 - 相约机器人(xiangyuejiqiren)

原文出处及转载信息见文内详细说明,如有侵权,请联系 yunjia_community@tencent.com 删除。

原始发表时间:2019-09-14

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

我来说两句

0 条评论
登录 后参与评论

相关文章

  • 替代Flume——Kafka Connect简介

    我们知道过去对于Kafka的定义是分布式,分区化的,带备份机制的日志提交服务。也就是一个分布式的消息队列,这也是他最常见的用法。但是Kafka不止于此,打开最新...

    用户6070864
  • JAVA-线程安全与锁机制详解

    JAVA中操作共享数据按照线程安全程度大致分为5类: 不可变,绝对线程安全,相对线程安全,线程兼容和线程对立

    yingzi_code
  • Swagger笔记(二)springboot集成和ApiModel使用不当的一个小问题

    版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。

    yingzi_code
  • Python 工匠: 异常处理的三个好习惯

    “ 如果你用 Python 编程,那么你就无法避开异常,因为异常在这门语言里无处不在。打个比方,当你在脚本执行时按 ctrl+c 退出,解释器就会产生一个 K...

    腾讯NEXT学位
  • TCGA数据挖掘(二):数据下载与整理

    管于TCGA数据库中的数据下载,我们之前有介绍过R语言下载包:R语言TCGA-Assembler包下载TCGA数据,同时在介绍数据库的使用教程中也介绍了在线下载...

    DoubleHelix
  • 聊聊dubbo的ConnectionOrderedDispatcher

    本文主要研究一下dubbo的ConnectionOrderedDispatcher

    codecraft
  • 做「容量预估」可没有true和false

    虽然如此,但是那些体量达到亿级或者是千万级的产品也只是少数公司的专属。对于整个行业里百万+的程序员群体来说,估计也就只有10%人有机会接触到这些“大系统”。

    Zachary_ZF
  • 新手学习FFmpeg - 调用API完成录屏

    如果使用FFmpeg提供的-list_devices 命令可以查询到当前支持的设备,其中分为两类:

    随机来个数
  • SpringDataJPA笔记(1)-基础概念和注解

    JPA是Java Persistence API的简称,中文名Java持久层API,是JDK 5.0注解或XML描述对象-关系表的映射关系,并将运行期的实体对象...

    yingzi_code
  • 聊聊dubbo的AllDispatcher

    dubbo-2.7.3/dubbo-remoting/dubbo-remoting-api/src/main/java/org/apache/dubbo/rem...

    codecraft

扫码关注云+社区

领取腾讯云代金券