matlab svmtrain和svmclassify函数使用示例

监督式学习(Supervised Learning)常用算法包括:线性回归(Linear Regression)、逻辑回归(Logistic Regression)、神经网络(Neural Network)以及支持向量机(Support Vector Machine,SVM)等。支持向量机与逻辑回归算法类似,都是解决二分类或多分类问题,但是SVM在非线性回归预测方面具有更优秀的分类效果,所以SVM又被称为最大间距分类器。

本文不对支持向量机的原理进行详细解释,直接运用matlab自带的工具箱函数svmtrain、svmclassify解决实际的二分类问题。

导入数据:

clear; close all; clc;
%% ================ load fisheriris.mat ================
load fisheriris.mat

1、对于线性分类问题,我们选取线性核函数,原始数据包括训练数据和测试数据两部分。

data = meas(51:end,3:4);                % column 3,column 4作为特征值
group = species(51:end);                % 类别
idx = randperm(size(data,1));
N = length(idx);

% SVM train
T = floor(N*0.9);                       % 90组数据作为训练数据
xdata = data(idx(1:T),:);
xgroup = group(idx(1:T));
svmStr = svmtrain(xdata,xgroup,'Showplot',true);

训练过程得到结构体svmStr,对测试数据进行预测

% SVM predict
P = floor(N*0.1);                       % 10组预测数据
ydata = data(idx(T+1:end),:);
ygroup = group(idx(T+1:end));
pgroup = svmclassify(svmStr,ydata,'Showplot',true);     % svm预测
hold on;
plot(ydata(:,1),ydata(:,2),'bs','Markersize',12);
accuracy1 = sum(strcmp(pgroup,ygroup))/P*100;            % 预测准确性
hold off;

程序运行结果如下:

图中,方块*号表示测试数据的预测结果,accuracy1结果为90%(上下浮动)。

2、对于非线性分类问题,我们选取高斯核函数RBF,原始数据包括训练数据和测试数据两部分。

训练过程前,导入原始数据:

data = meas(51:end,1:2);                % column 1,column 2作为特征值
group = species(51:end);                % 类别
idx = randperm(size(data,1));
N = length(idx);

% SVM train
T = floor(N*0.9);                       % 90组数据作为训练数据
xdata = data(idx(1:T),:);
xgroup = group(idx(1:T));

对于高斯核函数,有两个参数对SVM的分类效果有着重要的影响:一个是sigma;另一个是C。

首先讨论sigma的影响,sigma反映了RBF函数从最大值点向周围函数值下降的速度,sigma越大,下降速度越慢,对应RBF函数越平缓;sigma越小,下降速度越快,对应RBF函数越陡峭。对于不同的sigma,程序代码:

% different sigma
figure;
sigma = 0.5;
svmStr = svmtrain(xdata,xgroup,'kernel_function','rbf','rbf_sigma',...
    sigma,'showplot',true);
title('sigma = 0.5');
figure;
sigma = 1;
svmStr = svmtrain(xdata,xgroup,'kernel_function','rbf','rbf_sigma',...
    sigma,'showplot',true);
title('sigma = 1');
figure;
sigma = 3;
svmStr = svmtrain(xdata,xgroup,'kernel_function','rbf','rbf_sigma',...
    sigma,'showplot',true);
title('sigma = 3');

分类平面分别如下:

从图中可以看出,sigma越小,分类曲线越复杂,事实也确实如此。因为sigma越小,RBF函数越陡峭,下降速度越大,预测过程容易发生过拟合问题,使分类模型对训练数据过分拟合,而对测试数据预测效果不佳。

然后讨论C的影响,程序代码如下:

% different C
figure;
C = 1;
svmStr = svmtrain(xdata,xgroup,'kernel_function','rbf','boxconstraint',...
    C,'showplot',true);
title('C = 0.1');
figure;
C = 8;
svmStr = svmtrain(xdata,xgroup,'kernel_function','rbf','boxconstraint',...
    C,'showplot',true);
title('C = 1');
figure;
C = 64;
svmStr = svmtrain(xdata,xgroup,'kernel_function','rbf','boxconstraint',...
    C,'showplot',true);
title('C = 10');

分类平面如下:

从图中可以发现,C越大,分类曲线越复杂,也就是说越容易发生过拟合,因为C对应逻辑回归的lambda的倒数。

若令sigma=1,C=1,则对测试数据的预测程序:

% SVM predict
P = floor(N*0.1);                       % 10组预测数据
ydata = data(idx(T+1:end),:);
ygroup = group(idx(T+1:end));
% sigma = 1,C = 1,default
figure;
svmStr = svmtrain(xdata,xgroup,'kernel_function','rbf','showplot',true);
pgroup = svmclassify(svmStr,ydata,'Showplot',true);     % svm预测
hold on;
plot(ydata(:,1),ydata(:,2),'bs','Markersize',12);
accuracy2 = sum(strcmp(pgroup,ygroup))/P*100;            % 预测准确性
hold off;

程序运行结果如下:

图中,方块*号表示测试数据的预测结果,accuracy2结果为70%(上下浮动)。

分类效果不佳因为两个特征量的选择,可以选择更合适的特征量。

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

发表于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏Brian

机器学习性能度量

---- 概述 我们在进行机器学习时需要衡量机器学习的优劣和本身模型的准确程度,比如简单的衡量数据的准确率和错误率,但是我们更关心的是模型的泛化能力的指标,即基...

33811
来自专栏机器学习算法全栈工程师

如果Boosting 你懂、那 Adaboost你懂么?

作者:崔家华 编辑:王抒伟 转载请注明作者和出处: https://zhuanlan.zhihu.com/ml-jack 机器学习知乎专栏: h...

2765
来自专栏CSDN技术头条

深度卷积对抗生成网络(DCGAN)

卷积神经网络在有监督学习中的各项任务上都有很好的表现,但在无监督学习领域,却比较少。本文介绍的算法将有监督学习中的CNN和无监督学习中的GAN结合到了一起。 ...

2078
来自专栏机器学习算法全栈工程师

全面直观认识深度神经网络

作者:石文华 编辑:赵一帆 01 深度学习的精准定义 一类通过多层非线性变换对高复杂性数据建模算法的集合。它的两个非常重要的特征是多层性和非线性。俗称多层非线...

3378
来自专栏机器学习算法与Python学习

PCA在图像降维的应用

(自动编码器优化之主成分分析)从实例和数学背景的引导下详细的介绍了PCA的原理以及定义,并以旋转数据的角度论述其数据降维的实质,如何从降维之后的数据还原近似原始...

3539
来自专栏机器之心

学界 | 将未来信息作为正则项,Twin Networks加强RNN对长期依赖的建模能力

2868
来自专栏https://www.cnblogs.com/L

【神经网络篇】--RNN递归神经网络初始与详解

由图可知,比如第二个节点的输入不仅依赖于本身的输入U1,而且依赖上一个节点的输入W0,U0,同样第三个节点依赖于前两个节点的输入,

622
来自专栏marsggbo

Andrew Ng机器学习课程笔记--week1(机器学习介绍及线性回归)

之前看过一遍,但是总是模模糊糊的感觉,也刚入门,虽然现在也是入门,但是对于一些概念已经有了比较深的认识(相对于最开始学习机器学习的时候)。所以为了打好基础,决定...

2027
来自专栏机器之心

入门 | 神经网络训练中,Epoch、Batch Size和迭代傻傻分不清?

选自Medium 机器之心编译 参与:刘晓坤 你肯定经历过这样的时刻,看着电脑屏幕抓着头,困惑着:「为什么我会在代码中使用这三个术语,它们有什么区别吗?」因为它...

34511
来自专栏专知

春节充电系列:李宏毅2017机器学习课程学习笔记03之梯度下降

【导读】我们在上一节的内容中已经为大家介绍了台大李宏毅老师的机器学习课程的regression问题,其中简要提及了梯度下降(gradient descent),...

34514

扫码关注云+社区