前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >[Hands On ML] 5. 支持向量机

[Hands On ML] 5. 支持向量机

作者头像
Michael阿明
发布2021-02-19 09:53:28
3240
发布2021-02-19 09:53:28
举报
文章被收录于专栏:Michael阿明学习之路

文章目录

      1. 线性支持向量机分类
      1. 非线性支持向量机分类
      • 2.1 多项式核
      • 2.2 高斯 RBF 核
      1. 支持向量机回归
      1. 原理

本文为《机器学习实战:基于Scikit-Learn和TensorFlow》的读书笔记。

中文翻译参考

SVM 特别适合应用于复杂但中小规模数据集的分类问题。

可参考:《统计学习方法》支持向量机(Support Vector Machines,SVM) 笔记

1. 线性支持向量机分类

  • 硬间隔最大化:数据必须线性可分,间隔内无数据点
  • 软件间隔最大化:允许部分点在间隔内,甚至越过分类线,使用超参数 c 控制 较小的 c:惩罚小,间隔更大,较大的 c:惩罚大,间隔小

如果 SVM 模型过拟合,可以尝试通过减小超参数C去调整 SVM 对特征缩放比较敏感

2. 非线性支持向量机分类

很多时候,数据是线性不可分的,我们可以增加特征,下图左侧数据线性不可分,增加 x2 项以后就可分了

代码语言:javascript
复制
import matplotlib.pyplot as plt
from sklearn.pipeline import Pipeline
from sklearn.svm import LinearSVC
from sklearn.datasets import make_moons
from sklearn.preprocessing import StandardScaler
from sklearn.preprocessing import PolynomialFeatures
import numpy as np

X, y = make_moons(n_samples=100, noise=0.15, random_state=1)
def plot_dataset(X, y, axes):
    plt.plot(X[:, 0][y==0], X[:, 1][y==0], "bs")
    plt.plot(X[:, 0][y==1], X[:, 1][y==1], "g^")
    plt.axis(axes)
    plt.grid(True, which='both')
    plt.xlabel(r"$x_1$", fontsize=20)
    plt.ylabel(r"$x_2$", fontsize=20, rotation=0)

plot_dataset(X, y, [-1.5, 2.5, -1, 1.5])
plt.show()
代码语言:javascript
复制
ploynomial_svm_clf = Pipeline((
    ("ploy_features",PolynomialFeatures(degree=3)),
    ("scaler",StandardScaler()),
    ("svm_clf",LinearSVC(C=10, loss="hinge"))
))

ploynomial_svm_clf.fit(X,y)
代码语言:javascript
复制
def plot_predictions(clf, axes):
    x0s = np.linspace(axes[0], axes[1], 100)
    x1s = np.linspace(axes[2], axes[3], 100)
    x0, x1 = np.meshgrid(x0s, x1s)
    X = np.c_[x0.ravel(), x1.ravel()]
    y_pred = clf.predict(X).reshape(x0.shape)
    y_decision = clf.decision_function(X).reshape(x0.shape) # 样本点到分割超平面的函数距离
    plt.contourf(x0, x1, y_pred, cmap=plt.cm.brg, alpha=0.2)
    plt.contourf(x0, x1, y_decision, cmap=plt.cm.brg, alpha=0.1)

plot_predictions(ploynomial_svm_clf, [-1.5, 2.5, -1, 1.5])
plot_dataset(X, y, [-1.5, 2.5, -1, 1.5])

# save_fig("moons_polynomial_svc_plot")
plt.show()

2.1 多项式核

  • 添加多项式特征,产生了大量的特征,使模型变慢
  • 使用核技巧,可以取得同等的效果,同时没有特征组合爆炸
代码语言:javascript
复制
from sklearn.svm import SVC

poly_kernel_svm_clf = Pipeline([
        ("scaler", StandardScaler()),
        ("svm_clf", SVC(kernel="poly", degree=3, coef0=1, C=5))
    ])
poly_kernel_svm_clf.fit(X, y)


poly100_kernel_svm_clf = Pipeline([
        ("scaler", StandardScaler()),
        ("svm_clf", SVC(kernel="poly", degree=10, coef0=100, C=5))
    ])
poly100_kernel_svm_clf.fit(X, y)

plt.figure(figsize=(11, 4))
plt.subplot(121)
plot_predictions(poly_kernel_svm_clf, [-1.5, 2.5, -1, 1.5])
plot_dataset(X, y, [-1.5, 2.5, -1, 1.5])
plt.title(r"$d=3, r=1, C=5$", fontsize=18)
plt.subplot(122)
plot_predictions(poly100_kernel_svm_clf, [-1.5, 2.5, -1, 1.5])
plot_dataset(X, y, [-1.5, 2.5, -1, 1.5])
plt.title(r"$d=10, r=100, C=5$", fontsize=18)

# save_fig("moons_kernelized_polynomial_svc_plot")
plt.show()
  • 如果模型过拟合,可以减小多项式核的阶数,欠拟合则增大阶数
  • 超参数 coef0 控制高阶多项式与低阶多项式对模型的影响

2.2 高斯 RBF 核

代码语言:javascript
复制
gamma1, gamma2 = 0.1, 5
C1, C2 = 0.001, 1000
hyperparams = (gamma1, C1), (gamma1, C2), (gamma2, C1), (gamma2, C2)

svm_clfs = []
for gamma, C in hyperparams:
    rbf_kernel_svm_clf = Pipeline([
            ("scaler", StandardScaler()),
            ("svm_clf", SVC(kernel="rbf", gamma=gamma, C=C))
        ])
    rbf_kernel_svm_clf.fit(X, y)
    svm_clfs.append(rbf_kernel_svm_clf)

plt.figure(figsize=(11, 7))

for i, svm_clf in enumerate(svm_clfs):
    plt.subplot(221 + i)
    plot_predictions(svm_clf, [-1.5, 2.5, -1, 1.5])
    plot_dataset(X, y, [-1.5, 2.5, -1, 1.5])
    gamma, C = hyperparams[i]
    plt.title(r"$\gamma = {}, C = {}$".format(gamma, C), fontsize=16)

plt.show()
  • 增大 γ 使钟型曲线更窄,在单个样本周围环绕
  • 较小 γ 使钟型曲线更宽,样本有更大的影响范围,判定边界最终则更加平滑
  • 如果模型过拟合,减小γ值,若欠拟合,增大γ

计算复杂度

3. 支持向量机回归

回归:在间隔内放置尽可能多的样本点

代码语言:javascript
复制
from sklearn.svm import LinearSVR
svm_reg = LinearSVR(epsilon=1.5, random_state=1)

间隔大小由 ϵ \epsilon ϵ 控制

代码语言:javascript
复制
from sklearn.svm import SVR
svm_poly_reg1 = SVR(kernel="poly", degree=2, C=100, epsilon=0.1, gamma="auto")
svm_poly_reg2 = SVR(kernel="poly", degree=2, C=0.01, epsilon=0.1, gamma="auto")

多项式核化的非线性SVM

4. 原理

代码语言:txt
复制
                                               y                               ^                                      =                                       {                                                                                                     0                                               if                                                                w                                                 T                                                              x                                              +                                              b                                              <                                              0                                                                                                                                                  1                                               if                                                                w                                                 T                                                              x                                              +                                              b                                              ≥                                              0                                                                                                             \hat{y}=\left\{\begin{array}{l}0 \text { if } \mathbf{w}^{T} \mathbf{x}+b<0 \\ 1 \text { if } \mathbf{w}^{T} \mathbf{x}+b \geq 0\end{array}\right.                     y^​={0 if wTx+b<01 if wTx+b≥0​
  • 两个平面的交界线,决策边界
  • 虚线表示:决策函数等于 1 或 -1 的点:它们平行,且到决策边界的距离相等,形成一个间隔
  • 训练线性 SVM 分类器意味着找到w值和b值使得这一个间隔尽可能大,同时避免间隔违规(硬间隔)或限制它们(软间隔)
本文参与 腾讯云自媒体同步曝光计划,分享自作者个人站点/博客。
原始发表:2020/07/16 ,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 文章目录
  • 1. 线性支持向量机分类
  • 2. 非线性支持向量机分类
    • 2.1 多项式核
      • 2.2 高斯 RBF 核
      • 3. 支持向量机回归
      • 4. 原理
      领券
      问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档