前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >Matlab-RBF对iris鸢尾花数据集进行分类

Matlab-RBF对iris鸢尾花数据集进行分类

作者头像
用户9925864
发布2022-07-27 09:34:30
1.9K0
发布2022-07-27 09:34:30
举报
文章被收录于专栏:算法工程师的学习日志

接着前面2期rbf相关的应用分享一下rbf在分类场景的应用,数据集采用iris

前期参考

Matlab-RBF神经网络拟合数据

Matlab RBF神经网络及其实例

一、数据集

iris以鸢尾花的特征作为数据来源,数据集包含150个数据集,分为3类(setosa,versicolor, virginica),每类50个数据,每个数据包含4个属性。每一个数据包含4个独立的属性,这些属性变量测量植物的花朵(比如萼片和花瓣的长度等)信息。要求以iris数据为对象,来进行不可测信息(样本类别)的估计。数据随机打乱,然后训练集:测试集=7:3进行训练,并和实际结果作比较

二、编程步骤、思路

(1)读取训练数据通过load函数读取训练数据,并对数据进行打乱,提取对应的数据分为训练和验证数据,训练集和验证集7:3

代码语言:javascript
复制
iris = load('iris.txt');
inputData = iris(:,1:4);
outputData = iris(:,5);

flag = length(outputData);
orderTrain = randperm(flag);
nbertrain = round(0.7*flag);% 提取训练和验证数据 70% 训练,30% 验证

XTrain = inputData(orderTrain(1:nbertrain),1:4)';
YTrain = outputData(orderTrain(1:nbertrain))';
XValidation = inputData(orderTrain(nbertrain+1:flag),1:4)';
YValidation = outputData(orderTrain(nbertrain+1:flag))';

(2)建立一个RBF网络使用matlab的newrb函数,设定误差均方根值目标-0.02;径向基层的分布常数-1;最大的神经元个数-25

代码语言:javascript
复制
eg = 0.02; % 误差均方根值目标
sc = 1;    % 径向基层的分布常数
mn = 25; % 最大的神经元个数

训练模型

代码语言:javascript
复制
net = newrb(XTrain,YTrain,eg,sc);
代码语言:javascript
复制
NEWRB, neurons = 0, MSE = 0.656327
预测准确率: 97.7778 % 

(3)使用新的数据集测试这个网络将待识别的样本数据(XValidation)放在net变量,然后运行即可,

代码语言:javascript
复制
Y = net(XValidation);

最后的结果进行归一化计算,得到对应的预测类别 输出仿真结果

代码语言:javascript
复制
output = zeros(1,length(Y));
for i = 1:length(Y)
    [m,n] = min(abs(Y(i)-[1 2 3]));
    output(i) = n ;
end

绘制结果成图

分析:从实验运行结果可以看出,本程序的识别率准确率为97.7778 % ,

思考:本次使用了RBF神经网络,RBF是一种前馈型的神经网络,它的激励函数一般是高斯函数,高斯函数是通过计算输入与函数中心点的距离来算权重的。BP神经网络学习速率是固定的,因此网络的收敛速度慢,需要较长的训练时间。对于一些复杂问题,BP算法需要的训练时间可能非常长,这主要是由于学习速率太小造成的。而RBF神经网络是种高效的前馈式网络,它具有其他前向网络所不具有的最佳逼近性能和全局最优特性,并且结构简单,训练速度快,所以它也比BP网络更优。

完整代码

代码语言:javascript
复制
clc
close all
clear

iris = load('iris.txt');
inputData = iris(:,1:4);
outputData = iris(:,5);

flag = length(outputData);
orderTrain = randperm(flag);
nbertrain = round(0.7*flag);% 提取训练和验证数据 70% 训练,30% 验证

XTrain = inputData(orderTrain(1:nbertrain),1:4)';
YTrain = outputData(orderTrain(1:nbertrain))';
XValidation = inputData(orderTrain(nbertrain+1:flag),1:4)';
YValidation = outputData(orderTrain(nbertrain+1:flag))';


% net = newrbe(XTrain,YTrain);
eg = 0.02; % 误差均方根值目标
sc = 1;    % 径向基层的分布常数
mn = 25; % 最大的神经元个数
net = newrb(XTrain,YTrain,eg,sc);
Y = net(XValidation);
output = zeros(1,length(Y));
for i = 1:length(Y)
    [m,n] = min(abs(Y(i)-[1 2 3]));
    output(i) = n ;
end
figure
plot(YValidation,'r+');
ylabel('label');
hold on;
plot(output,'b*');
hold off;
legend({'Target','Output'})
ylim([0 4])
figure
plot(YValidation,'r');
ylabel('label');
hold on;
plot(output,'b');
hold off;
legend({'Target','Output'})
ylim([0 4])

error = YValidation-output;
figure
bar(error)
title('error')
ylim([-1.2 1.2])
figure
plot(error)
title('error')
ylim([-1.2 1.2])
correcr_rate = length(error(error==0))/length(YValidation)*100;
disp(['预测准确率: ', num2str(correcr_rate), ' % '])
本文参与 腾讯云自媒体同步曝光计划,分享自微信公众号。
原始发表:2022-01-09,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 算法工程师的学习日志 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档