手把手教你做交叉验证

是不是常在文献中看到“交叉验证(Cross Validation, CV)”这个词?却常常感叹为啥自己不会做?看看本文能否帮到你!

方法简介

交叉验证法可以用来估计一种指定的统计学习方法的测试误差,从而来评价这种方法的表现(模型评价),目前已成为业界评估模型性能的标准;或者为这种方法选择合适的光滑度(模型选择)。

K折交叉验证法(k-fold CV)是将观测集随机地分为k个大小基本一致的组,或者说折(fold),第一折作为验证集,然后在剩下的k-1个折上拟合模型,均方误差MSE1(响应变量Y为定性变量时则为错误率)由保留折的观测计算得出。重复这个步骤k次,每一次把不同折作为验证集,整个过程会得到k个测试误差的估计MSE1,MSE2,…, MSEk。K折CV估计由这些值求平均计算得到,如下图所示的10折CV:

最常见的是k=5或k=10的情形,当k=n时便是我们说的留一交叉验证法(leave-one-out cross validation, LOOCV),也即LOOCV是k折CV的特例。

R语言实现

啰嗦完毕,下面上干货!R语言如何实现k折交叉验证呢?很多完备的添加包自身即包含交叉验证函数,查阅PDF帮助文档即可发现。如果你使用的添加包里确实没有这个功能而你又想做交叉验证怎么办呢?

caret添加包提供了createFolds( )函数来创建交叉验证的数据集,如果响应变量Y是定性变量,该函数会尝试在每一折中维持与原始数据类似的各类别的比例。下面我们以UCI机器学习数据仓库的“威斯康星乳腺癌诊断”数据集(数据集包括569例细胞活检案例,第一列为病人ID编号,第二列为癌症诊断结果,其他30个特征是数值型的实验室测量结果。癌症诊断结果用M表示恶性,B表示良性。)为例,演示10折CV评价KNN的分类准确度:

#设置工作目录

setwd("D:\\学海拾贝之统计\\数据源")

#读取数据

data=read.delim("wisc_bc_data.txt",sep=",",header=F)

#查看数据条目

str(data)

#重命名数据集

names(data)=c("id","diagnosis",paste0("v",seq(1:30)))

#查看良恶性案例的大致分布

table(data$diagnosis)

#将诊断结果转化为因子并为其赋予标签

data$diagnosis=factor(data$diagnosis,levels=c("B","M"),labels=c("Benign","Malignant"))

#将30个特征数据标准化

library(chipPCR)

#可将“minm”换成“zscore”

standardize=function(x)

data_n=as.data.frame(lapply(data[,c(3:32)],standardize))

#将正确的分类标签另存为labels

data_n$labels=data[,2]

library(class)

library(caret)

#创建10折的列表,设置种子值以使其可重复

set.seed(20180717)

folds=createFolds(data_n$labels,k=10)

#查看每一折中选取的观测

str(folds)

#10折CV

cv_results=lapply(folds,function(x){

train=data_n[-x,]

test=data_n[x,]

pred=knn(train=train[,1:30],test=test[,1:30],cl=train$labels,k=24)

actual=test$labels

accuracy =length(pred[which(pred==actual)])/length(actual)

return(accuracy)

})

#查看每一折作为验证集求得的准确率

str(cv_results)

#求10折CV的平均准确率

mean(unlist(cv_results))

学会了么?学不会不是你的问题,一定是小编没讲清楚

欢迎关注公众号!每周三给你好看!

  • 发表于:
  • 原文链接https://kuaibao.qq.com/s/20180718G0FYV800?refer=cp_1026
  • 腾讯「云+社区」是腾讯内容开放平台帐号(企鹅号)传播渠道之一,根据《腾讯内容开放平台服务协议》转载发布内容。
  • 如有侵权,请联系 yunjia_community@tencent.com 删除。

扫码关注云+社区

领取腾讯云代金券