元学习似乎一直比较「高级」,毕竟学习如何学习这个概念听起来就很难实现。在本文中,我们介绍了这两天新开源的元学习库 learn2learn,它是用 PyTorch 写的,只需要三四行代码就能构建元学习最为核心的部分。
learn2learn 是一个用于实现元学习的 Pytorch 库,我们只需要加几行高层 API,就能为一般的机器学习流程添加元学习能力。例如在元学习 MNIST 案例中,我们可以用 PyTorch 构建整个流程,但只要加上三行 L2L 代码就能打造元学习模型。这三行代码只干三件事:获取元数据集、生成元学习任务、定义元学习模型。
元学习的目标是让智能体学习如何学习,也就是说,我们希望智能体能够在解决更多问题的过程中成为更好的学习器。例如,下图展示的智能体正在学习如何跑步,尽管它只会更新一个参数。
L2L 有什么特性
L2L 是一个元学习库,可以为用户提供 3 个级别的功能。在最高级别上,它有很多使用元学习算法在大量数据集/环境上训练的示例。在中间级别上,它为若干流行的元学习算法提供了功能接口以及便于加载其他数据集的数据加载器。在最低级别上,它为模块提供了可扩展功能。
L2L 的一些特性包括:
最后,整个 L2L 库都是由 PyTorch 写的,因此它的源代码并不难理解,我们可以通过项目的源码学习怎样从底层实现元学习算法。
L2L 实现 MAML 元学习算法的局部源代码,它的源码拥有大量的注释,可以帮助理解实现过程。
示例代码
下面我们来看看 learn2learn 到底该如何学习一个能实现 MNIST 分类任务的模型,它使用非常高层的应用,因此理解起来很容易。
如下代码所示,总体而言,整个过程可以分为导入数据、定义元学习任务、定义元学习模型与最优化方法、在元学习任务内不同的学习器适配不同的数据,最后就是标准的损失计算与模型更新了。
import learn2learn as l2l
mnist = torchvision.datasets.MNIST(root="/tmp/mnist", train=True)
mnist = l2l.data.MetaDataset(mnist)
task_generator = l2l.data.TaskGenerator(mnist,
ways=3,
classes=[0, 1, 4, 6, 8, 9],
tasks=10)
model = Net()
maml = l2l.algorithms.MAML(model, lr=1e-3, first_order=False)
opt = optim.Adam(maml.parameters(), lr=4e-3)
for iteration in range(num_iterations):
learner = maml.clone() # Creates a clone of model
adaptation_task = task_generator.sample(shots=1)
# Fast adapt
for step in range(adaptation_steps):
error = compute_loss(adaptation_task)
learner.adapt(error)
# Compute evaluation loss
evaluation_task = task_generator.sample(shots=1,
task=adaptation_task.sampled_task)
evaluation_error = compute_loss(evaluation_task)
# Meta-update the model parameters
opt.zero_grad()
evaluation_error.backward()
opt.step()
整个 API 非常高层,只需要很少的代码量就能完成模型。但与此同时,L2L 库还提供了中层和底层方面的 API,它允许我们做更多定制化的修改。更多的例子读者可以在 GitHub 中查阅,其示例模型分为强化学习、文本处理和视觉模型三方面:
如果读者也想要试试这个库,那么直接在命令行中运行 pip install learn2learn 就行了,剩下的再看看文档和教程,就可以快速学会怎样使用元学习。