Github 项目推荐 | 兼容 Scikit-Learn 的 PyTorch 神经网络库 —— skorch

Skorch 是一个兼容 Scikit-Learn 的 PyTorch 神经网络库。

资源

文档:

https://skorch.readthedocs.io/en/latest/?badge=latest

源代码

https://github.com/dnouri/skorch/

示例

更详细的例子,请查看此链接:

https://github.com/dnouri/skorch/tree/master/notebooks/README.md

import numpy as np
from sklearn.datasets import make_classification
import torch
from torch import nn
import torch.nn.functional as F

from skorch.net import NeuralNetClassifier


X, y = make_classification(1000, 20, n_informative=10, random_state=0)
X = X.astype(np.float32)
y = y.astype(np.int64)

class MyModule(nn.Module):
    def __init__(self, num_units=10, nonlin=F.relu):
        super(MyModule, self).__init__()

        self.dense0 = nn.Linear(20, num_units)
        self.nonlin = nonlin
        self.dropout = nn.Dropout(0.5)
        self.dense1 = nn.Linear(num_units, 10)
        self.output = nn.Linear(10, 2)

    def forward(self, X, **kwargs):
        X = self.nonlin(self.dense0(X))
        X = self.dropout(X)
        X = F.relu(self.dense1(X))
        X = F.softmax(self.output(X), dim=-1)
        return X


net = NeuralNetClassifier(
    MyModule,
    max_epochs=10,
    lr=0.1,
)

net.fit(X, y)
y_proba = net.predict_proba(X)

In an sklearn Pipeline:

from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler


pipe = Pipeline([
    ('scale', StandardScaler()),
    ('net', net),
])

pipe.fit(X, y)
y_proba = pipe.predict_proba(X)

With grid search

from sklearn.model_selection import GridSearchCV


params = {
    'lr': [0.01, 0.02],
    'max_epochs': [10, 20],
    'module__num_units': [10, 20],
}
gs = GridSearchCV(net, params, refit=False, cv=3, scoring='accuracy')

gs.fit(X, y)
print(gs.best_score_, gs.best_params_)

安装

pip 安装

pip install -U skorch

建议使用虚拟环境。

源代码安装

如果你想使用 skorch 最新的案例或者开发帮助,请使用源代码安装

用 conda

如果你需要一个工作conda安装, 从这里为的的系统获取正确的 miniconda:

https://conda.io/miniconda.html

如果你只是使用 skorch:

git clone https://github.com/dnouri/skorch.git
cd skorch
conda env create
source activate skorch
# install pytorch version for your system (see below)
python setup.py install

如果你只想帮助开发,运行:

git clone https://github.com/dnouri/skorch.git
cd skorch
conda env create
source activate skorch
# install pytorch version for your system (see below)
conda install --file requirements-dev.txt
python setup.py develop

py.test  # unit tests
pylint skorch  # static code checks

用 pip

如果你只是使用 skorch:

git clone https://github.com/dnouri/skorch.git
cd skorch
# create and activate a virtual environment
pip install -r requirements.txt
# install pytorch version for your system (see below)
python setup.py install

如果你想使用帮助开发:

git clone https://github.com/dnouri/skorch.git
cd skorch
# create and activate a virtual environment
pip install -r requirements.txt
# install pytorch version for your system (see below)
pip install -r requirements-dev.txt
python setup.py develop

py.test  # unit tests
pylint skorch  # static code checks

原文发布于微信公众号 - AI研习社(okweiwu)

原文发表时间:2018-05-09

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

发表于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏Kubernetes

如何在Kubernetes集群中利用GPU进行AI训练

Author: xidianwangtao@gmail.com 注意事项 截止Kubernetes 1.8版本: 对GPU的支持还只是实验阶段,仍停留在A...

7347
来自专栏王亚军的专栏

谷歌开源图片压缩算法 Guetzli 实测体验报告

谷歌大神又出开源新技术啦,这次是对JPEG格式的图片采用全新算法重新编码,输出的图片还是JPEG但是图片大小明显缩小,而质量不但没有损失,甚至还更加优化,速速来...

9.3K1
来自专栏公有云大数据平台弹性 MapReduce

ResourceManager中的Resource Estimator框架介绍与算法剖析

本文首先介绍了Hadoop中的ResourceManager中的estimator service的框架与运行流程,然后对其中用到的资源估算算法进行了原理剖析。

2.5K16
来自专栏张善友的专栏

在Expression Blend中使用XAML建立3D应用程序

参考微软<Creating 3D Content with WPF>文档翻译。 源文件下载http://www.wangpangzi.net/uploads/2...

2069
来自专栏Kubernetes

TensorFlow Serving在Kubernetes中的实践

xidianwangtao@gmail.com 关于TensorFlow Serving 下面是TensorFlow Serving的架构图: ? 关于T...

91412
来自专栏AI科技大本营的专栏

教程 | 如何在手机上使用TensorFlow

? 翻译 | AI科技大本营 参与 | zzq 审校 | reason_W 我们知道,TensorFlow是一个深度学习框架,它通常用来在服务器上训练需要大量...

7267
来自专栏程序员同行者

python3模块: uuid

1762
来自专栏用户2442861的专栏

Tesseract文字训练,以及样本生成

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/haluoluo211/article/details...

1881
来自专栏CreateAMind

InfoGAN修改训练人脸数据集celebA的过程记录

2051
来自专栏黑白安全

如果银行卡只能使用六位数的密码到底有多安全?

我们使用的银行卡密码为 6 位数字,在 ATM 机上使用时如果连续输错 3 次密码就会被吞卡。那么如果有人捡到一张银行卡,拿到 ATM 机上去试密码,他在 3 ...

3335

扫码关注云+社区

领取腾讯云代金券