python SVM 案例,sklearn.svm.SVC 参数说明

sklearn.svm.SVC 参数说明

经常用到sklearn中的SVC函数,这里把文档中的参数翻译了一些,以备不时之需。

本身这个函数也是基于libsvm实现的,所以在参数设置上有很多相似的地方。(PS: libsvm中的二次规划问题的解决算法是SMO)。

sklearn.svm.SVC(C=1.0,kernel='rbf', degree=3, gamma='auto',coef0=0.0,shrinking=True,probability=False,tol=0.001,cache_size=200, class_weight=None,verbose=False,max_iter=-1,decision_function_shape=None,random_state=None)

参数:

l C:C-SVC的惩罚参数C?默认值是1.0

C越大,相当于惩罚松弛变量,希望松弛变量接近0,即对误分类的惩罚增大,趋向于对训练集全分对的情况,这样对训练集测试时准确率很高,但泛化能力弱。C值小,对误分类的惩罚减小,允许容错,将他们当成噪声点,泛化能力较强。

l kernel :核函数,默认是rbf,可以是‘linear’, ‘poly’, ‘rbf’, ‘sigmoid’, ‘precomputed’

  0 – 线性:u’v

  1 – 多项式:(gamma*u’*v + coef0)^degree

  2 – RBF函数:exp(-gamma|u-v|^2)

  3 –sigmoid:tanh(gamma*u’*v + coef0)

l degree :多项式poly函数的维度,默认是3,选择其他核函数时会被忽略。

l gamma : ‘rbf’,‘poly’ 和‘sigmoid’的核函数参数。默认是’auto’,则会选择1/n_features

l coef0 :核函数的常数项。对于‘poly’和 ‘sigmoid’有用。

l probability :是否采用概率估计?.默认为False

l shrinking :是否采用shrinking heuristic方法,默认为true

l tol :停止训练的误差值大小,默认为1e-3

l cache_size :核函数cache缓存大小,默认为200

l class_weight :类别的权重,字典形式传递。设置第几类的参数C为weight*C(C-SVC中的C)

l verbose :允许冗余输出?

l max_iter :最大迭代次数。-1为无限制。

l decision_function_shape :‘ovo’, ‘ovr’ or None, default=None3

l random_state :数据洗牌时的种子值,int值

主要调节的参数有:C、kernel、degree、gamma、coef0。

案例代码:

#!/usr/bin/python
# -*- coding:utf-8 -*-

import numpy as np
from sklearn import svm
from scipy import stats
from sklearn.metrics import accuracy_score
import matplotlib as mpl
import matplotlib.pyplot as plt


def extend(a, b, r):
    x = a - b
    m = (a + b) / 2
    return m-r*x/2, m+r*x/2


if __name__ == "__main__":
    np.random.seed(0)
    N = 20
    x = np.empty((4*N, 2))
    print("{}\n{}".format(x.shape,x))
    means = [(-1, 1), (1, 1), (1, -1), (-1, -1)]
    print(means)
    sigmas = [np.eye(2), 2*np.eye(2), np.diag((1,2)), np.array(((2,1),(1,2)))]
    print(sigmas)
    for i in range(4):
        mn = stats.multivariate_normal(means[i], sigmas[i]*0.3)
        # print(mn)
        x[i*N:(i+1)*N, :] = mn.rvs(N)
        # print(mn.rvs(N))
    a = np.array((0,1,2,3)).reshape((-1, 1))
    print(a)
    y = np.tile(a, N).flatten()
    print(np.tile(a, N) )
    print(y)
    clf = svm.SVC(C=1, kernel='rbf', gamma=1, decision_function_shape='ovo')
    # clf = svm.SVC(C=1, kernel='linear', decision_function_shape='ovr')
    clf.fit(x, y)
    y_hat = clf.predict(x)
    acc = accuracy_score(y, y_hat)
    np.set_printoptions(suppress=True)
    print (u'预测正确的样本个数:%d,正确率:%.2f%%' % (round(acc*4*N), 100*acc))
    # decision_function
    print (clf.decision_function(x))
    print (y_hat)

    x1_min, x2_min = np.min(x, axis=0)
    x1_max, x2_max = np.max(x, axis=0)
    x1_min, x1_max = extend(x1_min, x1_max, 1.05)
    x2_min, x2_max = extend(x2_min, x2_max, 1.05)
    x1, x2 = np.mgrid[x1_min:x1_max:500j, x2_min:x2_max:500j]
    x_test = np.stack((x1.flat, x2.flat), axis=1)
    y_test = clf.predict(x_test)
    y_test = y_test.reshape(x1.shape)
    cm_light = mpl.colors.ListedColormap(['#FF8080', '#A0FFA0', '#6060FF', '#F080F0'])
    cm_dark = mpl.colors.ListedColormap(['r', 'g', 'b', 'm'])
    mpl.rcParams['font.sans-serif'] = [u'SimHei']
    mpl.rcParams['axes.unicode_minus'] = False
    plt.figure(facecolor='w')
    plt.pcolormesh(x1, x2, y_test, cmap=cm_light)
    plt.scatter(x[:, 0], x[:, 1], s=40, c=y, cmap=cm_dark, alpha=0.7)
    plt.xlim((x1_min, x1_max))
    plt.ylim((x2_min, x2_max))
    plt.grid(b=True)
    plt.tight_layout(pad=2.5)
    plt.title(u'SVM多分类方法:One/One or One/Other', fontsize=18)
    plt.show()

分类结果:

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

发表于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏腾讯Bugly的专栏

手机端运行卷积神经网络实现文档检测功能(二) -- 从 VGG 到 MobileNetV2 知识梳理

1K40
来自专栏文武兼修ing——机器学习与IC设计

关于蘑菇数据集的探索分析数据集描述读取数据集直观分析——颜色鲜艳的蘑菇都有毒?相关性分析——判断各指标与毒性相关性模型训练——使用决策树模型

数据集描述 来源于kaggle的蘑菇数据集,包括毒性,大小,表面,颜色等,所有数据均为字符串类型,分析毒性与其他属性的关系 读取数据集 dataset = pd...

54960
来自专栏数据科学与人工智能

【Python环境】使用 scikit-learn 进行机器学习的简介

概要: 该章节,我们将介绍贯穿scikit-learn使用中的“机器学习(Machine Learning)”这个词汇,并给出一些简单的学习示例。 一、机器学习...

301100
来自专栏杂七杂八

numpy中的乘法(*,dot)

numpy中数据表示有数组和矩阵两种数据类型,他们的乘法计算也是多种形式,下面我们主要来说一下numpy中的乘法计算 numpy.ndarray 运算符 ...

29460
来自专栏利炳根的专栏

学习笔记CB014:TensorFlow seq2seq模型步步进阶

神经网络。《Make Your Own Neural Network》,用非常通俗易懂描述讲解人工神经网络原理用代码实现,试验效果非常好。

60630
来自专栏CreateAMind

keras doc 10终结篇 激活函数 回调函数 正则项 约束项 预训练模型

激活函数可以通过设置单独的激活层实现,也可以在构造层对象时通过传递activation参数实现。

56730
来自专栏漫漫深度学习路

tensorflow学习笔记(四十五):sess.run(tf.global_variables_initializer()) 做了什么?

当我们训练自己的神经网络的时候,无一例外的就是都会加上一句 sess.run(tf.global_variables_initializer()) ,这行代码的...

30360
来自专栏书山有路勤为径

Generative Adversarial Network

这里我们将建立 一个对抗生成网络 (GAN)训练MNIST,并在最后生成新的手写数字。

9920
来自专栏书山有路勤为径

Batch Normalization怎么加入batch normalization

Batch Normalization 会使你的参数搜索问题变得很容易,使神经网络对超参数的选择更加稳定,超参数的范围会更加庞大,工作效果也很好,也会使你的训练...

9820
来自专栏AILearning

sklearn 快速入门 - 0.18 中文翻译

对机器学习问题的简要介绍,以及如何使用scikit-learn来解决这些问题。介绍基本概念和惯例。 原文链接 : http://scikit-learn....

210100

扫码关注云+社区

领取腾讯云代金券