专栏首页机器之心JAXnet:一行代码定义计算图,兼容三大主流框架,可GPU加速

JAXnet:一行代码定义计算图,兼容三大主流框架,可GPU加速

一行代码定义计算图,So Easy,妈妈再也不用担心我的机器学习。

项目地址:https://github.com/JuliusKunze/jaxnet

JAXnet 是一个基于 JAX 的深度学习库,它的 API 提供了便利的模型搭建体验。相比 TensorFlow 2.0 或 PyTorch 等主流框架,JAXnet 拥有独特的优势。举个栗子,不论是 Keras 还是 PyTorch,它们建模就像搭积木一样。

然而,还有一种比搭积木更简单的方法,这就是 JAXnet 的模块化:

from jaxnet import *

net = Sequential(Dense(1024), relu, Dense(1024), relu, Dense(4), logsoftmax)

creates a neural net model from predefined modules.

创建一个全连接网络可以直接用预定义的模块,可以说 JAXnet 定义计算图,只需一行代码就可以了。写一个神经网络,原来 So easy。

总体来说,JAXnet 主要关注的是模块化、可扩展性和易用性等几个方面:

  • 采用了不可变权重,而不是全局计算图,从而获得更强的稳健性;
  • 用于构建神经网络、训练循环、预处理、后处理等过程的 NumPy 代码经过 GPU 编译;
  • 任意模块或整个网络的正则化、重参数化都只需要一行代码;
  • 取消了全局随机状态,采用了更便捷的 Key 控制。

可扩展性

你可以使用 @parametrized 定义自己的模块,并复用其它的模块:

from jax import numpy as np

@parametrizeddef 
loss(inputs, targets):
   return -np.mean(net(inputs) * targets)

所有的模块都是用这样的方法组合在一起的。jax.numpy (https://github.com/google/jax#whats-supported) 是 numpy 的镜像。只要你知道怎么使用 numpy,那么你就可以知道 JAXnet 大部分的用法了。

以下是 TensorFlow2/Keras 的代码,JAXnet 相比之下更为简洁:

import tensorflow as tf
from tensorflow.keras import Sequential
from tensorflow.keras.layers import Dense, Lambda

net = Sequential([Dense(1024, 'relu'), Dense(1024, 'relu'), Dense(4), Lambda(tf.nn.log_softmax)])

def loss(inputs, targets):
    return -tf.reduce_mean(net(inputs) * targets)

需要注意的是,Lambda 函数在 JAXnet 中不是必要的。而 relu 和 logsoftmax 函数都是 Python 写的函数。

非可变权重

和 TensorFlow 或者 Keras 不同,JAXnet 没有全局计算图。net 和 loss 这样的模块不保存可变权重。权重则是保存在分开的不可变类中。这些权重由 init_parameters 函数初始化,用于提供随机的键和样本输入:

from jax.random import PRNGKey
def next_batch(): 
    return np.zeros((3, 784)), np.zeros((3, 4))

params = loss.init_parameters(PRNGKey(0), *next_batch())
print(params.sequential.dense2.bias) # [0.00376661 0.01038619 0.00920947 0.00792002]

目标函数不会在线变更权重,而是不断更新权重的下一个版本。它们会以新的优化状态返回,并由 get_parameters 取回。

opt = optimizers.Adam()
state = opt.init_state(params)for _ in range(10):
    state = opt.optimize(loss.apply, state, *next_batch()) # accelerate with jit=True
trained_params = opt.get_parameters(state)

当需要对网络进行评价时:

test_loss = loss.apply(trained_params, *test_batch) # accelerate with jit=True

JAXnet 的正则化也十分简单:

loss = L2egularized(loss, scale = .1)

其他特性

除了简洁的代码,JAXnet 还支持在 GPU 上进行计算。而且还可以用 jit 进行编译,摆脱 Python 运行缓慢的问题。同时,JAXnet 是单步调试的,和 Python 代码一样。

安装也十分简单,使用 pip 安装即可。如果需要使用 GPU,则需要先安装 jaxlib。

其他具体的 API 可参考:https://github.com/JuliusKunze/jaxnet/blob/master/API.md

本文分享自微信公众号 - 机器之心(almosthuman2014)

原文出处及转载信息见文内详细说明,如有侵权,请联系 yunjia_community@tencent.com 删除。

原始发表时间:2019-09-02

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

我来说两句

0 条评论
登录 后参与评论

相关文章

  • 这是一份你们需要的Windows版深度学习软件安装指南

    选自Github 机器之心编译 参与:蒋思源、刘晓坤 本文从最基本的依赖项开始,依次配置了 VS 2015、Anaconda 4.4.0、CUDA 8.0.61...

    机器之心
  • 教程 | 在Keras上实现GAN:构建消除图片模糊的应用

    选自Sicara Blog 作者:Raphaël Meudec 机器之心编译 参与:陈韵竹、李泽南 2014 年,Ian Goodfellow 提出了生成对抗网...

    机器之心
  • 学界 | 最小二乘GAN:比常规GAN更稳定,比WGAN收敛更迅速

    选自Github 机器之心编译 参与:蒋思源 近来 GAN 证明是十分强大的。因为当真实数据的概率分布不可算时,传统生成模型无法直接应用,而 GAN 能以对抗...

    机器之心
  • pytorch进行CIFAR-10分类(4)训练

    经过前面的数据加载和网络定义后,就可以开始训练了,这里会看到前面遇到的一些东西究竟在后面会有什么用,所以这一步希望各位也能仔细研究一下

    TeeyoHuang
  • 深度学习中常用的损失函数loss有哪些?

    这是专栏《AI初识境》的第11篇文章。所谓初识,就是对相关技术有基本了解,掌握了基本的使用方法。

    小草AI
  • 【AI初识境】深度学习中常用的损失函数有哪些?

    这是专栏《AI初识境》的第11篇文章。所谓初识,就是对相关技术有基本了解,掌握了基本的使用方法。

    用户1508658
  • GAN对抗网络入门教程

    译:A Beginner's Guide to Generative Adversarial Networks (GANs) https://skymind.a...

    致Great
  • CenterNet之loss计算代码解析

    本文主要讲解CenterNet的loss,由偏置部分(reg loss)、热图部分(heatmap loss)、宽高(wh loss)部分三部分loss组成,附...

    BBuf
  • 二分类语义分割损失函数

    1 - softmax 交叉熵损失函数(softmax loss,softmax with cross entroy loss)

    AIHGF
  • 一种Dynamic ReLU:自适应参数化ReLU激活函数(调参记录1)

    自适应参数化ReLU是一种动态ReLU(Dynamic ReLU)激活函数,在2019年5月3日投稿至IEEE Transactions on Industri...

    用户7368967

扫码关注云+社区

领取腾讯云代金券