前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >​Jax 生态再添新库:DeepMind 开源 Haiku、RLax

​Jax 生态再添新库:DeepMind 开源 Haiku、RLax

作者头像
机器之心
发布2020-02-25 10:55:25
1K0
发布2020-02-25 10:55:25
举报
文章被收录于专栏:机器之心机器之心

机器之心报道

参与:一鸣

Jax 是一个优秀的代码库,在进行科学计算的同时能够自动微分,还有 GPU、TPU 的性能加速加持。但是 Jax 的生态还不够完善,使用者相比 TF、PyTorch 少得多。近日,DeepMind 开源了两个基于 Jax 的新库,给这个生态注入了新的活力。

Jax 是谷歌开源的一个科学计算库,能对 Python 程序与 NumPy 运算执行自动微分,而且能够在 GPU 和 TPU 上运行,具有很高的性能。基于 Jax 已有很多优秀的开源项目,如 Trax 等。近日,DeepMind 开源了两个基于 Jax 的新机器学习库,分别是 Haiku 和 RLax,它们都有着各自的特色,对于丰富深度学习社区框架、提升研究者和开发者的使用体验有着不小的意义。

Haiku:https://github.com/deepmind/haiku

RLax:https://github.com/deepmind/rlax

Haiku:在 Jax 上进行面向对象开发

首先值得注意的是 Haiku,这是一个面向 Jax 的深度学习代码库,它是由 Sonnet 作者——一个谷歌的神经网络库团队开发的。

为什么要使用 Haiku?这是因为其支持的是 Jax,Jax 在灵活性和性能上具有相当的优势。但是另一方面,Jax 本身是函数式的,和面向对象的用户习惯有差别。因此,通过 Haiku,用户可以在 Jax 上进行面向对象开发了。

此外,Haiku 的 API 和编程模型都是基于 Sonnet,因此使用过 Sonnet 的用户可以快速上手。项目作者也表示,Sonnet 之于 TensorFlow 的提升就如同 Haiku 之于 Jax。

目前,Haiku 已公开了 Alpha 版本,已完全开源。项目作者欢迎使用者提出建议。

Haiku 怎么和 Jax 交互

Haiku 主要分为两个模块:hk.Modules和 hk.transform。下文将会分别介绍。

hk.Modules 是 Python 对象,保存着到参数、其他模块和方法的参照(references)。

hk.transform 则负责将面向对象的模块转换为纯粹的函数式代码,然后让 jax 中的 jax.jit, jax.grad, jax.pmap 等进行处理,从而实现和 Jax 组件的兼容。

Haiku 的功能

Haiku 能够做到很多机器学习需要完成的任务,相关功能和代码如下:

自定义你的模块

在 Haiku 中,类似于 TF2.0 和 PyTorch,你可以自定义模块,作为 hk.Module 的子类。例如,自定义一个线性层:

代码语言:javascript
复制
class MyLinear(hk.Module):

  def __init__(self, output_size, name=None):
    super(MyLinear, self).__init__(name=name)
    self.output_size = output_size

  def __call__(self, x):
    j, k = x.shape[-1], self.output_size
    w_init = hk.initializers.TruncatedNormal(1. / np.sqrt(j))
    w = hk.get_parameter("w", shape=[j, k], dtype=x.dtype, init=w_init)
    b = hk.get_parameter("b", shape=[k], dtype=x.dtype, init=jnp.zeros)
    return jnp.dot(x, w) + b

可以看出,Haiku 的代码和 TensorFlow 等非常相似,但是你可以看到包括 numpy 等的方法还可以定义在模块中。Haiku 的优势就在于,它不是一个封闭的框架,而是代码库,因此可以在定义模块的过程中调用其他的库和方法。

当定义好线性层后,我们想要试试自动微分的方法了:

代码语言:javascript
复制
def forward_fn(x):
  model = MyLinear(10)
  return model(x)
# Turn `forward_fn` into an object with `init` and `apply` methods.
forward = hk.transform(forward_fn)

x = jnp.ones([1, 1])
# When we run `forward.init`, Haiku will run `forward(x)` and collect initial# parameter values. Haiku requires you pass a RNG key to `init`, since parameters# are typically initialized randomly:
key = hk.PRNGSequence(42)
params = forward.init(next(key), x)
# When we run `forward.apply`, Haiku will run `forward(x)` and inject parameter# values from the `params` that are passed as the first argument. We do not require# an RNG key by default since models are deterministic. You can (of course!) change# this using `hk.transform(f, apply_rng=True)` if you prefer:
y = forward.apply(params, x)

这里可以看到,定义好模块和前向传播的函数后,使用 hk.transform(forward_fn) 可以将这种面向对象的方法转换成 Jax 底层的函数式代码进行处理,因此你不需要担心底层的计算问题。另外,这里的代码相比 TensorFlow 还要简洁。

非训练状态

有时候,我们想要在训练的过程中保持某些内部参数的状态,在 Haiku 上这也是非常容易实现的。

代码语言:javascript
复制
def forward(x, is_training):
  net = hk.nets.ResNet50(1000)
  return net(x, is_training)

forward = hk.transform_with_state(forward)
# The `init` function now returns parameters **and** state. State contains# anything that was created using `hk.set_state`. The structure is the same as# params (e.g. it is a per-module mapping of named values).
params, state = forward.init(rng, x, is_training=True)
# The apply function now takes both params **and** state. Additionally it will# return updated values for state. In the resnet example this will be the# updated values for moving averages used in the batch norm layers.
logits, state = forward.apply(params, state, rng, x, is_training=True)

如上所示,只需要两行代码进行设置。

和 jax.pmap 联合进行分布式训练

由于所有的代码都会被转换成 Jax 的函数,因此它们和 jax.pmap. 是完全兼容的。这说明,我们可以利用 jax.pmap 来进行分布式计算。

如下为进行数据分割的分布式加速代码,首先,我们先定义模型和训练步骤:

代码语言:javascript
复制
def loss_fn(inputs, labels):
  logits = hk.nets.MLP([8, 4, 2])(x)
  return jnp.mean(softmax_cross_entropy(logits, labels))

loss_obj = hk.transform(loss_fn)
# Initialize the model on a single device.
rng = jax.random.PRNGKey(428)
sample_image, sample_label = next(input_dataset)
params = loss_obj.init(rng, sample_image, sample_label)

然后设定将参数拷贝到所有的设备上:

代码语言:javascript
复制
# Replicate params onto all devices.
num_devices = jax.local_device_count()
params = jax.tree_util.tree_map(lambda x: np.stack([x] * num_devices), params)

定义数据分批的方法,以及参数更新的方法:

代码语言:javascript
复制
def make_superbatch():
  """Constructs a superbatch, i.e. one batch of data per device."""
  # Get N batches, then split into list-of-images and list-of-labels.
  superbatch = [next(input_dataset) for _ in range(num_devices)]
  superbatch_images, superbatch_labels = zip(*superbatch)
  # Stack the superbatches to be one array with a leading dimension, rather than
  # a python list. This is what `jax.pmap` expects as input.
  superbatch_images = np.stack(superbatch_images)
  superbatch_labels = np.stack(superbatch_labels)
  return superbatch_images, superbatch_labels

def update(params, inputs, labels, axis_name='i'):
  """Updates params based on performance on inputs and labels."""
  grads = jax.grad(loss_obj.apply)(params, inputs, labels)
  # Take the mean of the gradients across all data-parallel replicas.
  grads = jax.lax.pmean(grads, axis_name)
  # Update parameters using SGD or Adam or ...
  new_params = my_update_rule(params, grads)
  return new_params

最后开始分布式计算即可:

代码语言:javascript
复制
# Run several training updates.
for _ in range(10):
  superbatch_images, superbatch_labels = make_superbatch()
  params = jax.pmap(update, axis_name='i')(params, superbatch_images,
                                           superbatch_labels)

RLax:Jax 上也有强化学习库了

除了令人印象深刻的 Haiku 外,DeepMind 还开源了 RLax——这是一个基于 Jax 的强化学习库。

相比 Haiku,RLax 专门针对强化学习。项目作者认为,尽管强化学习中的算子和函数并不是完全的算法,但是,如果需要构建完全基于函数式的智能体,就需要特定的数学算子。

因此,函数式的 Jax 就成为了一个不错的选择。在 Jax 上进行一定的开发后,就可以有专用的强化学习库了。RLax 目前的资料还较少,但项目已提供了一个示例代码:使用 RLax 进行 Q-learning 模型的搭建和训练。

代码如下,首先,使用 Haiku 构建基本的强化学习模型:

代码语言:javascript
复制
def build_network(num_actions: int) -> hk.Transformed:

  def q(obs):
    flatten = lambda x: jnp.reshape(x, (-1,))
    network = hk.Sequential(
        [flatten, nets.MLP([FLAGS.hidden_units, num_actions])])
    return network(obs)

  return hk.transform(q)

设定训练的方法:

代码语言:javascript
复制
def main_loop(unused_arg):
  env = catch.Catch(seed=FLAGS.seed)
  rng = hk.PRNGSequence(jax.random.PRNGKey(FLAGS.seed))

  # Build and initialize Q-network.
  num_actions = env.action_spec().num_values
  network = build_network(num_actions)
  sample_input = env.observation_spec().generate_value()
  net_params = network.init(next(rng), sample_input)

  # Build and initialize optimizer.
  optimizer = optix.adam(FLAGS.learning_rate)
  opt_state = optimizer.init(net_params)

以下和 Jax 结合,定义策略、奖励等:

代码语言:javascript
复制
@jax.jitdef policy(net_params, key, obs):

可以看到,RLax 基于 jax.jit 的方法,在性能方面有不错的提升。更有趣的是,构建模型的过程中使用了前文提到的 Haiku,可见基于 Jax 生态的代码库之间都是可以兼容的。

从 DeepMind 近日开源的两个代码库可以看到,虽然现在深度学习框架依然在稳步发展,但是针对高性能的科学计算也渐渐变得更为重要了。而 Jax 这样的优秀开源项目,无疑也需要更多的生态支持。这次开源的 Haiku 和 RLax,无疑能够巩固 Jax 的地位,使其优秀的特性进一步得到发挥。

本文为机器之心报道,转载请联系本公众号获得授权。

本文参与 腾讯云自媒体分享计划,分享自微信公众号。
原始发表:2020-02-23,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 机器之心 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档