机器学习中最流行的模型之一,用于分类的支持向量机的完整介绍

支持向量机(SVM)是一个非常强大而多变的机器学习模型,能够执行线性或非线性的分类,回归,甚至异常值检测。它是机器学习中最流行的模型之一,任何对机器学习感兴趣的人都应该学习和使用它。SVM特别适用于复杂的中小型数据集分类。在这篇文章中,我们将探讨用于分类的SVM模型。

线性SVM

假设我们有两类数据,我们要使用SVM进行分类,如图所示:

这两个类别可以用直线(线性可分)轻松分离。左图显示了2个可能的线性分类器的决策边界。所有的SVM模型都是为了生成将数据进行很好分类的正确分割线(称为在更高维度中称为超平面)。在左图中,尽管红线也对数据进行了分类,但在新的数据实例中可能无法很好地执行。我们可以画出许多对这些数据进行分类的线,但是在所有这些线中,蓝线分隔数据最好。同样的蓝线显示在右图。这条线(超平面)不仅分离了两个类,还尽可能远离最接近的训练实例。你可以将SVM分类器视为在类之间拟合尽可能宽的街道(由右侧图上的平行虚线表示)。这被称最大间隔分类器。

这个最好的决策边界是由位于街道边缘的实例决定(或者称为“支持”)。这些实例称为支持向量。街道边缘的间距称为间隔(margin)。

软间隔分类器

如果我们严格要求我们的实例不得在“街道”上,并且修正街道的边缘,这就是所谓的硬间隔分类。而硬边缘分类有2个问题:

1)只有数据线性分离才有效。

2)对异常值太敏感。

在上面的数据分类图中,有一个蓝色的异常值。如果我们对该数据集应用硬间隔分类,我们将得到左边图所示的小间距的决策边界。为了避免这些问题,最好使用更弹性的模型。目的是在保持街道尽可能大和不允许实例过界之间找到一个良好的平衡(就是说会出现实例最终可能在街道中间甚至道路另一面的情况)。这称为软间隔分类。如果我们对该数据集应用软间隔分类,则我们将得到比硬间隔分类更大决策边界。如右图中所示。

非线性SVM

虽然线性SVM分类器能力不错,在许多情况下都工作的很好,但是许多数据集无法进行线性分离。处理非线性数据集的一个简单方法是添加更多的特征,比如多项式特征,有时这可以获得线性可分的数据集。通过生成多项式特征,我们将得到新的特征矩阵,它由具有小于或等于指定度的所有多项式组合的特征。下图是SVM使用多项式特征的示例。

核技巧

核(Kernel)是在一些特征空间中计算两个向量x和y点积的一种方法(特征空间可能有非常高的维度),所以核函数有时被称为“广义点积(generalized dot product)”。

假设我们有一个映射 φ:Rn→Rm向量Rn到特征空间Rm的映射。x和y的内积空间是φ(x)Tφ(y)。核是一个对应于这个点积k的函数,也就是k(x,y)=φ(x)Tφ(y)。核提供了在一些特征空间内计算点积的方法,你甚至可以不知道这个空间和φ是什么。

多项式核函数

添加多项式特征非常简单。但是低度多项式无法处理复杂的数据集,并且具有较高的多项式度它会产生大量的特征,让模型变得特别慢。在这种情况下,我们可以使用多项式核函数解决这个问题。多项式核函数公式如下:

d是多项式的度。

高斯核函数

高斯RBF(径向基函数)是SVM模型中使用的另一种流行的核函数。它的公式如下:

如果我们有如下的数据集,则高斯核函数非常有用。

超参数

在SVM模型中有2个重要的超参数。

C参数

C参数决定SVM分类器的边距宽度。C的值越越大分类器越严格,因此边距宽度小。对于较大的C值,如果该超平面更好地将所有训练点归类正确,则该模型将选择较小边距的超平面。相反,C的非常小的值将导致模型寻找更大的边距分离超平面,即使这个超平面错误的分类了更多的点。如果C值非常小,你会得到错误分类的实例,即使你的训练数据是线性可分的。

γ参数

该γ参数定义了每个训练实例的影响能力。γ参数对scikit-learn中的线性核函数是无效的。

使用scikit-learn实现

在这部分中,我们将使用scikit-learn来实现SVM。

线性核函数

import numpy as np
import pandas as pd
from matplotlibimport style
from sklearn.svmimport SVC
import matplotlib.pyplot as plt
plt.rcParams['figure.figsize']= (12,6)
style.use('ggplot')

# Import Dataset
data= pd.read_csv('data.csv', header=None)
X= data.values[:, :2]
y= data.values[:,2]

# A function to draw hyperplane and the margin of SVM classifier
def draw_svm(X, y, C=1.0):
    # Plotting the Points
    plt.scatter(X[:,0], X[:,1], c=y)

    # The SVM Model with given C parameter
    clf= SVC(kernel='linear', C=C)
    clf_fit= clf.fit(X, y)

    # Limit of the axes
    ax= plt.gca()
    xlim= ax.get_xlim()
    ylim= ax.get_ylim()

    # Creating the meshgrid
    xx= np.linspace(xlim[0], xlim[1],200)
    yy= np.linspace(ylim[0], ylim[1],200)
    YY, XX= np.meshgrid(yy, xx)
    xy= np.vstack([XX.ravel(), YY.ravel()]).T
    Z= clf.decision_function(xy).reshape(XX.shape)

    # Plotting the boundary
    ax.contour(XX, YY, Z, colors='k', levels=[-1,0,1],
                        alpha=0.5, linestyles=['--','-','--'])
    ax.scatter(clf.support_vectors_[:,0],
                clf.support_vectors_[:,1],
                s=100, linewidth=1, facecolors='none')
    plt.show()
    # Returns the classifier
    return clf_fit

clf_arr= []
clf_arr.append(draw_svm(X, y,0.0001))
clf_arr.append(draw_svm(X, y,0.001))
clf_arr.append(draw_svm(X, y,1))
clf_arr.append(draw_svm(X, y,10))

for i, clfin enumerate(clf_arr):
    # Accuracy Score
    print(clf.score(X, y))
    pred= clf.predict([(12,32), (-250,32), (120,43)])
    print(pred)
0.992907801418
[1 0 1]
0.992907801418
[1 0 1]
1.0
[1 0 1]
1.0
[1 0 1]

你能看到相同超平面具有不同边距宽度。这取决于C超参数。

多项式核函数

import numpy as np
import pandas as pd
from matplotlibimport style
from sklearn.svmimport SVC
import matplotlib.pyplot as plt
plt.rcParams['figure.figsize']= (12,6)
style.use('ggplot')

data= pd.read_csv('polydata2.csv', header=None)
X= data.values[:, :2]
y= data.values[:,2]

def draw_svm(X, y, C=1.0):
    plt.scatter(X[:,0], X[:,1], c=y)
    clf= SVC(kernel='poly', C=C)
    clf_fit= clf.fit(X, y)

    ax= plt.gca()
    xlim= ax.get_xlim()
    ylim= ax.get_ylim()

    xx= np.linspace(xlim[0], xlim[1],200)
    yy= np.linspace(ylim[0], ylim[1],200)
    YY, XX= np.meshgrid(yy, xx)
    xy= np.vstack([XX.ravel(), YY.ravel()]).T
    Z= clf.decision_function(xy).reshape(XX.shape)

    ax.contour(XX, YY, Z, colors='k', levels=[-1,0,1],
                        alpha=0.5, linestyles=['--','-','--'])
    ax.scatter(clf.support_vectors_[:,0],
                clf.support_vectors_[:,1],
                s=100, linewidth=1, facecolors='none')
    plt.show()
    return clf_fit

clf= draw_svm(X, y)
score= clf.score(X, y)
pred= clf.predict([(-130,110), (-170,-160), (80,90), (-280,20)])
print(score)
print(pred)
1.0
[0 1 0 1]

高斯核函数

import numpy as np
import pandas as pd
from matplotlibimport style
from sklearn.svmimport SVC
from sklearn.datasetsimport make_classification, make_blobs, make_moons
import matplotlib.pyplot as plt
plt.rcParams['figure.figsize']= (12,6)
style.use('ggplot')

X, y= make_moons(n_samples=200)

# Auto gamma equals 1/n_features
def draw_svm(X, y, C=1.0, gamma='auto'):
    plt.scatter(X[:,0], X[:,1], c=y)
    clf= SVC(kernel='rbf', C=C, gamma=gamma)
    clf_fit= clf.fit(X, y)

    ax= plt.gca()
    xlim= ax.get_xlim()
    ylim= ax.get_ylim()

    xx= np.linspace(xlim[0], xlim[1],200)
    yy= np.linspace(ylim[0], ylim[1],200)
    YY, XX= np.meshgrid(yy, xx)
    xy= np.vstack([XX.ravel(), YY.ravel()]).T
    Z= clf.decision_function(xy).reshape(XX.shape)

    ax.contour(XX, YY, Z, colors='k', levels=[-1,0,1],
                        alpha=0.5, linestyles=['--','-','--'])
    ax.scatter(clf.support_vectors_[:,0],
                clf.support_vectors_[:,1],
                s=100, linewidth=1, facecolors='none')
    plt.show()
    return clf_fit

clf_arr= []
clf_arr.append(draw_svm(X, y,0.01))
clf_arr.append(draw_svm(X, y,0.1))
clf_arr.append(draw_svm(X, y,1))
clf_arr.append(draw_svm(X, y,10))

for i, clfin enumerate(clf_arr):
    print(clf.score(X, y))
0.83
0.9
1.0
1.0
import numpy as np
import pandas as pd
from matplotlibimport style
from sklearn.svmimport SVC
from sklearn.datasetsimport make_gaussian_quantiles
import matplotlib.pyplot as plt
plt.rcParams['figure.figsize']= (12,6)
style.use('ggplot')

X, y= make_gaussian_quantiles(n_samples=200, n_features=2, n_classes=2, cov=3)

# Auto gamma equals 1/n_features
def draw_svm(X, y, C=1.0, gamma='auto'):
    plt.scatter(X[:,0], X[:,1], c=y)
    clf= SVC(kernel='rbf', C=C, gamma=gamma)
    clf_fit= clf.fit(X, y)

    ax= plt.gca()
    xlim= ax.get_xlim()
    ylim= ax.get_ylim()

    xx= np.linspace(xlim[0], xlim[1],200)
    yy= np.linspace(ylim[0], ylim[1],200)
    YY, XX= np.meshgrid(yy, xx)
    xy= np.vstack([XX.ravel(), YY.ravel()]).T
    Z= clf.decision_function(xy).reshape(XX.shape)

    ax.contour(XX, YY, Z, colors='k', levels=[-1,0,1],
                        alpha=0.5, linestyles=['--','-','--'])
    ax.scatter(clf.support_vectors_[:,0],
                clf.support_vectors_[:,1],
                s=100, linewidth=1, facecolors='none')
    plt.show()
    return clf_fit

clf_arr= []
clf_arr.append(draw_svm(X, y,0.1))
clf_arr.append(draw_svm(X, y,1))
clf_arr.append(draw_svm(X, y,10))
clf_arr.append(draw_svm(X, y,100))

for i, clfin enumerate(clf_arr):
    print(clf.score(X, y))
0.965
0.97
0.985
0.995

γ参数对RBF SVM模型非常重要。在第一个例子中,低的γ值导致了出现接近线性分类。

你可以在下面的链接中查看代码示例和数据集。

https://github.com/mubaris/studious-eureka

原文发布于微信公众号 - ATYUN订阅号(atyun_com)

原文发表时间:2017-10-21

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

发表于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏Phoenix的Android之旅

Java 集合 Vector

List有三种实现,ArrayList, LinkedList, Vector, 它们的区别在于, ArrayList是非线程安全的, Vector则是线程安全...

672
来自专栏开发与安全

算法:AOV网(Activity on Vextex Network)与拓扑排序

在一个表示工程的有向图中,用顶点表示活动,用弧表示活动之间的优先关系,这样的有向图为顶点表示活动的网,我们称之为AOV网(Activity on Vextex ...

2607
来自专栏计算机视觉与深度学习基础

Leetcode 114 Flatten Binary Tree to Linked List

Given a binary tree, flatten it to a linked list in-place. For example, Given...

1958
来自专栏刘君君

JDK8的HashMap源码学习笔记

3068
来自专栏拭心的安卓进阶之路

Java 集合深入理解(6):AbstractList

今天心情比天蓝,来学学 AbstractList 吧! ? 什么是 AbstractList ? AbstractList 继承自 AbstractCollec...

19210
来自专栏desperate633

LeetCode Invert Binary Tree题目分析

Invert a binary tree. 4 / \ 2 7 / \ / \1 3 6 9 to4 / \ 7 2 / \ / \9 6 3 1 Tri...

871
来自专栏alexqdjay

HashMap 多线程下死循环分析及JDK8修复

1K4
来自专栏MelonTeam专栏

ArrayList源码完全分析

导语: 这里分析的ArrayList是使用的JDK1.8里面的类,AndroidSDK里面的ArrayList基本和这个一样。 分析的方式是逐个API进行解析 ...

4519
来自专栏后端之路

LinkedList源码解读

List中除了ArrayList我们最常用的就是LinkedList了。 LInkedList与ArrayList的最大区别在于元素的插入效率和随机访问效率 ...

19710
来自专栏拭心的安卓进阶之路

Java 集合深入理解(12):古老的 Vector

今天刮台风,躲屋里看看 Vector ! 都说 Vector 是线程安全的 ArrayList,今天来根据源码看看是不是这么相...

2447

扫码关注云+社区