Github 项目推荐 | 类 Keras 的 PyTorch 深度学习框架 —— PyToune

PyToune 是一个类 Keras 的 Pytorch 深度学习框架,可用来处理训练神经网络所需的大部分模板代码。

用 PyToune 你可以:

  • 更容易地训练模型
  • 用回调来保存你最好的模型,执行 early stopping 方法等

Pytoune 官方页面:http://pytoune.org/

Pytoune Github 页面:https://github.com/GRAAL-Research/pytoune

Pytoune 兼容 PyTorch >= 0.3.0 版本和 Python >= 3.5 版本。

入门:快速上手 PyToune

PyToune 的核心数据结构是一种 Model,一种训练你的神经网络的方法。创建 PyToune 的方法和平常创建 PyTorch 模块(神经网络)的方式一样,但是你花时间去训练它,将其反馈到 PyToune 模型中,它会处理所有的步骤、统计数据、回调,就像 Keras 那样。

下面是个示例:

# Import the PyToune Model and define a toy dataset
from pytoune.framework import Model

num_train_samples = 800
train_x = torch.rand(num_train_samples, num_features)
train_y = torch.rand(num_train_samples, 1)

num_valid_samples = 200
valid_x = torch.rand(num_valid_samples, num_features)
valid_y = torch.rand(num_valid_samples, 1)

创建你自己的 PyTorch 神经网络,一个损失函数和优化器:

pytorch_module = torch.nn.Linear(num_features, 1)
loss_function = torch.nn.MSELoss()
optimizer = torch.optim.SGD(pytorch_module.parameters(), lr=1e-3)

你可以用 PyToune 非常容易地训练神经网络:

model = Model(pytorch_module, optimizer, loss_function)
model.fit(
    train_x, train_y,
    validation_x=valid_x,
    validation_y=valid_y,
    epochs=num_epochs,
    batch_size=batch_size
  )

这与 Keras 中的 model.compile 函数非常相似:

# Keras way to compile and train
model.compile(loss='categorical_crossentropy',
              optimizer='sgd',
              metrics=['accuracy'])
model.fit(x_train, y_train, epochs=5, batch_size=32)

你可以使用 PyToune 模型的评估方法评估你的网络的性能:

loss_and_metrics = model.evaluate(x_test, y_test)

或者只预测新数据:

predictions = model.predict(x_test)

正如你所见,PyToune 受到 Keras 很多启发,详细信息,请参阅 PyToune.org 上的 PyToune 文档。

安装

在使用 PyToune 之前,你应该先装上 PyTorch 0.3.0。

安装稳定的 PyToune 版本:

pip install pytoune

安装最新的 PyToune:

pip install -U git+https://github.com/GRAAL-Research/pytoune.git

为什么叫 PyToune

PyToune(或 Québécois 的 pitoune)曾指代的是河流里的原木,用河流运输原木是非常有效的一种运输方式。PyToune 的作者希望 PyToune 能够帮助开发者更加方便地训练神经网络模型,就像「pitoune」那样。

原文发布于微信公众号 - AI研习社(okweiwu)

原文发表时间:2018-03-13

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

发表于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏老秦求学

数据增强利器--Augmentor

Augmentor是一个Python包,旨在帮助机器学习任务的图像数据人工生成和数据增强。它主要是一种数据增强工具,但也将包含基本的图像预处理功能。

1613
来自专栏CreateAMind

openAi HER 算法运行流程学习

1213
来自专栏素质云笔记

SSD+caffe︱Single Shot MultiBox Detector 目标检测+fine-tuning(二)

承接上一篇SSD介绍:SSD+caffe︱Single Shot MultiBox Detector 目标检测(一) 如果自己要训练SSD模型呢,关键...

1.1K10
来自专栏ATYUN订阅号

浣熊检测器实例, 如何用TensorFlow的Object Detector API来训练你的物体检测器

这篇文章是“用Tensorflow和OpenCV构建实时对象识别应用”的后续文章。具体来说,我在自己收集和标记的数据集上训练了我的浣熊检测器。完整的数据集可以在...

5967
来自专栏ATYUN订阅号

【教程】利用Tensorflow目标检测API确定图像中目标的位置

深度学习提供了另一种解决“Wally在哪儿”(美国漫画)问题的方法。与传统的图像处理计算机视觉方法不同的是,它只使用了少量的标记出Wally位置的示例。 在我的...

7186
来自专栏专知

【干货】手把手教你用苹果Core ML和Swift开发人脸目标识别APP

【导读】CoreML是2017年苹果WWDC发布的最令人兴奋的功能之一。它可用于将机器学习整合到应用程序中,并且全部脱机。CoreML提供的机器学习 API,包...

3256
来自专栏AI研习社

Github 项目推荐 | GAN 非平稳纹理合成

该库是论文「Non-stationary texture synthesis using adversarial expansions.」的官方代码。

1163
来自专栏专知

【下载】PyTorch 实现的YOLO v2目标检测算法

【导读】目标检测是计算机视觉的重要组成部分,其目的是实现图像中目标的检测。YOLO是基于深度学习方法的端到端实时目标检测系统(YOLO:实时快速目标检测)。YO...

5136
来自专栏人人都是极客

5.训练模型之利用训练的模型识别物体

接下来我们开始训练,这里要做三件事: 将训练数据上传到训练服务器,开始训练。 将训练过程可视化。 导出训练结果导出为可用作推导的模型文件。 配置 Pipelin...

4094
来自专栏Petrichor的专栏

tensorflow: 畅玩tensorboard图表(SCALARS)

这篇博客建立在你已经会使用tensorboard的基础上。如果你还不会记录数据并使用tensorboard,请移步我之前的另一篇博客:tensorflow: t...

3393

扫码关注云+社区

领取腾讯云代金券