sklearn中Logistics Regression的coef_和intercept_的具体意义

使用sklearn库可以很方便的实现各种基本的机器学习算法,例如今天说的逻辑斯谛回归(Logistic Regression),我在实现完之后,可能陷入代码太久,忘记基本的算法原理了,突然想不到 coef_intercept_ 具体是代表什么意思了,就是具体到公式中的哪个字母,虽然总体知道代表的是模型参数。

好尴尬,折腾了一会,终于弄明白了,记录下来,以说明自己too young。

正文

我们使用sklearn官方的一个例子来作为说明,源码可以从这里下载,下面我截取其中一小段并做了一些修改:

import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import make_blobs
from sklearn.linear_model import LogisticRegression

# 构造一些数据点
centers = [[-5, 0], [0, 1.5], [5, -1]]
X, y = make_blobs(n_samples=1000, centers=centers, random_state=40)
transformation = [[0.4, 0.2], [-0.4, 1.2]]
X = np.dot(X, transformation)

clf = LogisticRegression(solver='sag', max_iter=100, random_state=42, multi_class=multi_class).fit(X, y)

print clf.coef_ 
print clf.intercept_

输出如图:

可以看到 clf.coef_ 是一个3×2(n_class, n_features)的矩阵,clf.intercept_是一个1×3的矩阵(向量),那么这些到底是什么意思呢?

我们来回顾一下Logistic回归的模型:

hθ(x)=11+e(−θTx)

h_\theta(x) = \frac{1}{1+e^{(-\theta^Tx)}} 其中 θ\theta 是模型参数,其实 θTx\theta^Tx 就是一个线性表达式,将这个表达式的结果再一次利用Logistic函数映射到0~1之间。

知道了这个,也就可以搞清楚那个 clf.coef_clf.intercept_ 了: clf.coef_clf.intercept_ 就是 θ\theta ,下面我们来验证一下:

i = 100
print 1 / (1 + np.exp(-(np.dot(X[i].reshape(1, -1), cc.T) + clf.intercept_)))
# 正确的类别
print y[i]
print clf.predict_proba(X[i].reshape(1, -1))
print clf.predict_log_proba(X[i].reshape(1, -1))

输出结果:

可以看到结果是吻合的,说明我们的猜想是正确的。

END

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

发表于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏企鹅号快讯

从零开始用Python构造决策树

来源:Python中文社区 作者:weapon 本文长度为700字,建议阅读5分钟 本文介绍如何不利用第三方库,仅用python自带的标准库来构造一个决策树。 ...

1937
来自专栏落影的专栏

Metal入门教程(二)三维变换

上一篇的教程介绍了如何绘制一张图片,这次的目标是把图片显示到3D物体上,并进行三维变换。

1285
来自专栏PaddlePaddle

【图像分类】如何转化模型文件

场景文字识别 图像相比文字能够提供更加生动、容易理解及更具艺术感的信息,是人们转递与交换信息的重要来源。图像分类是根据图像的语义信息对不同类别图像进行区分,是计...

2975
来自专栏十月梦想

css3动画变换transform用法

刚才说到transition动画执行,接下来看下动画变换(transform),transform属性的取值4个

564
来自专栏mathor

“达观杯”文本智能处理挑战赛

 由于提供的数据集较大,一般运行时间再10到15分钟之间,基础电脑配置在4核8G的样子(越消耗内存在6.2G),因此,一般可能会遇到内存溢出的错误

502
来自专栏Pulsar-V

Arduino 基于陀螺仪的定位

姿态解算代码 #include "Wire.h" #include "I2Cdev.h" unsigned long now, lastTime = 0; ...

2816
来自专栏月色的自留地

从锅炉工到AI专家(8)

18213
来自专栏PPV课数据科学社区

数据挖掘知识脉络与资源整理(七)–饼图

? ? 简介 饼图英文学名为Sector Graph, 有名Pie Graph。常用于统计学模块。2D饼图为圆形,手画时,常用圆规作图。 仅排列在工作表的一...

2667
来自专栏LIN_ZONE

PHP计算两个经纬度地点之间的距离

function getdistance($lng1, $lat1, $lng2, $lat2) {

443
来自专栏大数据挖掘DT机器学习

写一只具有识别能力的图片爬虫

在网上看到python做图像识别的相关文章后,真心感觉python的功能实在太强大,因此将这些文章总结一下,建立一下自己的知识体系。 当然了,图像识别这个话题...

3435

扫码关注云+社区