前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >sklearn 实现 logistic 算法

sklearn 实现 logistic 算法

作者头像
用户3147702
发布2022-06-27 13:07:13
6550
发布2022-06-27 13:07:13
举报
文章被收录于专栏:小脑斧科技博客

1. 引言

前面两篇文章中,我们介绍了 logistic 回归的原理和实现: Logistic 回归数学公式推导 梯度上升算法与随机梯度上升算法 本文,我们来看看如何使用 sklearn 来进行 logistic 回归呢。

2. Sklearn 的 logistic 回归

sklearn 通过 sklearn.linear_model.LogisticRegression 实现了逻辑斯蒂回归算法。 官方文档: https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.LogisticRegression.html。 下面的列表中,我们将样本数称为 ns,将分类数称为 nc,将特征数称为 nf。

3. LogisticRegression 的构造参数

sklearn.linear_model.LogisticRegression 具有以下构造参数:

DecisionTreeClassifier 类构造参数

3.1. penalty

penalty 参数是规范化方法,也称为正则化方法,主要是为了防止出现过拟合,具体的我们后面专门用一篇文章详细进行总结。 newton-cg、sag 和 lbfgs 算法只能使用 l2 正则化。 ‘elasticnet’ 只适用于 saga 算法。 具体算法由 solver 参数指定。

3.2. class_weight

class_weight 参数决定了样本的各分类类型权重,可以取值:

  • dict — 通过 dict 定义分类权重:{class_label: weight}
  • ‘balance’ — 使用 ns/(nc * np.bincount(y)) 作为分类权重
  • None — 默认值,不指定权重

3.3. solver

优化算法,有五个可选的参数:’newton-cg’, ’lbfgs’, ’liblinear’, ’sag’, ’saga’

  • liblinear — 开源的liblinear库实现,内部使用了坐标轴下降法来迭代优化损失函数,适用于小数据集
  • lbfgs — 拟牛顿法,利用损失函数二阶导数矩阵即海森矩阵来迭代优化损失函数,因此他不能用于没有连续导数的L1正则化,只能用于L2正则化
  • newton-cg — 牛顿迭代,也是利用损失函数二阶导数矩阵即海森矩阵来迭代优化损失函数,因此也只能使用 L2 正则化
  • sag — 随机平均梯度下降法,适用于样本量非常大的数据集,否则结果准确率会受到影响,不能用于L1正则化
  • saga — 快速梯度下降法,线性收敛的随机优化算法的的变种,适用于样本量非常大的数据集

3.4. multi_class

多分类问题处理方法,有三个参数可选:’ovr’, ’multinomial’, ’auto’ 既然是“多分类问题处理方法”,所以对于二分类问题,选择哪个的处理方法都是一样的。

  • ‘ovr’ — ’OvR’, 将多分类问题看成是二分类问题,每次只将一类样本与其他类样本组成的集合进行训练,进行 nc 次训练以后就可以完成多分类问题的处理了
  • ‘multinomial’ — ’MvM’,liblinear 不能选择该项,以单循环的方式进行分类,每次处理两个分类,保证样本中所有分类两两组合进行过一次训练,共需 nc*(nc-1)/2 次训练,分类速度慢,但分类结果更准确
  • ‘auto’ — 如果 resolver 是 liblinear 则选择 OvR,否则选择 MvM

4. LogisticRegression 的类属性

LogisticRegression 的类属性

5. 类方法

  • decision_function(X) — 预测样本的置信度分数
  • densify() — 密集化处理,将系数矩阵转换为密集阵列格式,默认格式
  • sparsify() — 稀疏化处理,将系数矩阵转换为稀疏格式
  • fit(X, y[, sample_weight]) — 训练逻辑斯蒂模型
  • get_params([deep]) — 获取参数
  • predict(X) — 使用已训练模型进行预测
  • predict_log_proba(X) — 获取预测样本 log 概率估计
  • predict_proba(X) — 获取预测样本概率估计
  • score(X, y[, sample_weight]) — 模型打分
  • set_params(**params) — 设置参数

6. 示例

代码语言:javascript
复制
print(__doc__)

# Code source: Gaël Varoquaux
# Modified for documentation by Jaques Grobler
# License: BSD 3 clause

import numpy as np
import matplotlib.pyplot as plt
from sklearn.linear_model import LogisticRegression
from sklearn import datasets

# import some data to play with
iris = datasets.load_iris()
X = iris.data[:, :2]  # we only take the first two features.
Y = iris.target

logreg = LogisticRegression(C=1e5, solver='lbfgs', multi_class='multinomial')

# Create an instance of Logistic Regression Classifier and fit the data.
logreg.fit(X, Y)

# Plot the decision boundary. For that, we will assign a color to each
# point in the mesh [x_min, x_max]x[y_min, y_max].
x_min, x_max = X[:, 0].min() - .5, X[:, 0].max() + .5
y_min, y_max = X[:, 1].min() - .5, X[:, 1].max() + .5
h = .02  # step size in the mesh
xx, yy = np.meshgrid(np.arange(x_min, x_max, h), np.arange(y_min, y_max, h))
Z = logreg.predict(np.c_[xx.ravel(), yy.ravel()])

# Put the result into a color plot
Z = Z.reshape(xx.shape)
plt.figure(1, figsize=(4, 3))
plt.pcolormesh(xx, yy, Z, cmap=plt.cm.Paired)

# Plot also the training points
plt.scatter(X[:, 0], X[:, 1], c=Y, edgecolors='k', cmap=plt.cm.Paired)
plt.xlabel('Sepal length')
plt.ylabel('Sepal width')

plt.xlim(xx.min(), xx.max())
plt.ylim(yy.min(), yy.max())
plt.xticks(())
plt.yticks(())

plt.show()

7. 参考资料

https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.LogisticRegression.html。 https://scikit-learn.org/stable/auto_examples/linear_model/plot_iris_logistic.html#sphx-glr-auto-examples-linear-model-plot-iris-logistic-py。

本文参与 腾讯云自媒体同步曝光计划,分享自微信公众号。
原始发表:2018-12-03,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 小脑斧科技博客 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 1. 引言
  • 2. Sklearn 的 logistic 回归
  • 3. LogisticRegression 的构造参数
    • 3.1. penalty
      • 3.2. class_weight
        • 3.3. solver
          • 3.4. multi_class
          • 4. LogisticRegression 的类属性
          • 5. 类方法
          • 6. 示例
          • 7. 参考资料
          领券
          问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档