前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >Github 项目推荐 | 类 Keras 的 PyTorch 深度学习框架 —— PyToune

Github 项目推荐 | 类 Keras 的 PyTorch 深度学习框架 —— PyToune

作者头像
AI研习社
发布2018-03-28 10:04:28
9750
发布2018-03-28 10:04:28
举报
文章被收录于专栏:AI研习社AI研习社

PyToune 是一个类 Keras 的 Pytorch 深度学习框架,可用来处理训练神经网络所需的大部分模板代码。

用 PyToune 你可以:

  • 更容易地训练模型
  • 用回调来保存你最好的模型,执行 early stopping 方法等

Pytoune 官方页面:http://pytoune.org/

Pytoune Github 页面:https://github.com/GRAAL-Research/pytoune

Pytoune 兼容 PyTorch >= 0.3.0 版本和 Python >= 3.5 版本。

入门:快速上手 PyToune

PyToune 的核心数据结构是一种 Model,一种训练你的神经网络的方法。创建 PyToune 的方法和平常创建 PyTorch 模块(神经网络)的方式一样,但是你花时间去训练它,将其反馈到 PyToune 模型中,它会处理所有的步骤、统计数据、回调,就像 Keras 那样。

下面是个示例:

代码语言:javascript
复制
# Import the PyToune Model and define a toy dataset
from pytoune.framework import Model

num_train_samples = 800
train_x = torch.rand(num_train_samples, num_features)
train_y = torch.rand(num_train_samples, 1)

num_valid_samples = 200
valid_x = torch.rand(num_valid_samples, num_features)
valid_y = torch.rand(num_valid_samples, 1)

创建你自己的 PyTorch 神经网络,一个损失函数和优化器:

代码语言:javascript
复制
pytorch_module = torch.nn.Linear(num_features, 1)
loss_function = torch.nn.MSELoss()
optimizer = torch.optim.SGD(pytorch_module.parameters(), lr=1e-3)

你可以用 PyToune 非常容易地训练神经网络:

代码语言:javascript
复制
model = Model(pytorch_module, optimizer, loss_function)
model.fit(
    train_x, train_y,
    validation_x=valid_x,
    validation_y=valid_y,
    epochs=num_epochs,
    batch_size=batch_size
  )

这与 Keras 中的 model.compile 函数非常相似:

代码语言:javascript
复制
# Keras way to compile and train
model.compile(loss='categorical_crossentropy',
              optimizer='sgd',
              metrics=['accuracy'])
model.fit(x_train, y_train, epochs=5, batch_size=32)

你可以使用 PyToune 模型的评估方法评估你的网络的性能:

代码语言:javascript
复制
loss_and_metrics = model.evaluate(x_test, y_test)

或者只预测新数据:

代码语言:javascript
复制
predictions = model.predict(x_test)

正如你所见,PyToune 受到 Keras 很多启发,详细信息,请参阅 PyToune.org 上的 PyToune 文档。

安装

在使用 PyToune 之前,你应该先装上 PyTorch 0.3.0。

安装稳定的 PyToune 版本:

代码语言:javascript
复制
pip install pytoune

安装最新的 PyToune:

代码语言:javascript
复制
pip install -U git+https://github.com/GRAAL-Research/pytoune.git

为什么叫 PyToune

PyToune(或 Québécois 的 pitoune)曾指代的是河流里的原木,用河流运输原木是非常有效的一种运输方式。PyToune 的作者希望 PyToune 能够帮助开发者更加方便地训练神经网络模型,就像「pitoune」那样。

本文参与 腾讯云自媒体分享计划,分享自微信公众号。
原始发表:2018-03-13,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 AI研习社 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 入门:快速上手 PyToune
  • 安装
  • 为什么叫 PyToune
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档