HAWQ + MADlib 玩转数据挖掘之(十二)——模型评估之交叉验证

一、交叉验证概述

        机器学习技术在应用之前使用“训练+检验”的模式,通常被称作“交叉验证”,如图1所示。

图1

1. 预测模型的稳定性

        让我们通过以下几幅图来理解这个问题:

图2

        此处我们试图找到尺寸(size)和价格(price)的关系。三个模型各自做了如下工作:

  1. 第一个模型使用了线性等式。对于训练用的数据点,此模型有很大误差。这样的模型在初期排行榜和最终排行榜都会表现不好。这是“拟合不足”(“Under fitting”)的一个例子。此模型不足以发掘数据背后的趋势。
  2. 第二个模型发现了价格和尺寸的正确关系,此模型误差低/概括程度高。
  3. 第三个模型对于训练数据几乎是零误差。这是因为此关系模型把每个数据点的偏差(包括噪声)都纳入了考虑范围,也就是说,这个模型太过敏感,甚至会捕捉到只在当前数据训练集出现的一些随机模式。这是“过度拟合”(Over fitting)的一个例子。这个关系模型可能在初榜和终榜成绩变化很大。

        在应用中,一个常见的做法是对多个模型进行迭代,从中选择表现更好的。然而,最终的分数是否会有改善依然未知,因为我们不知道这个模型是更好的发掘潜在关系了,还是过度拟合了。为了解答这个难题,我们应该使用交叉验证(cross validation)技术。它能帮我们得到更有概括性的关系模型。

        实际上,机器学习关注的是通过训练集训练过后的模型对测试样本的分类效果,我们称之为泛化能力。左右两图的泛化能力就不好。在机器学习中,对偏差和方差的权衡是机器学习理论着重解决的问题。

2. 什么是交叉验证

        交叉验证意味着需要保留一个样本数据集,不用来训练模型。在最终完成模型前,用这个数据集验证模型。交叉验证包含以下步骤:

  1. 保留一个样本数据集,即测试集。
  2. 用剩余部分(训练集)训练模型。
  3. 用保留的数据集(测试集)验证模型。

        这样做有助于了解模型的有效性。如果当前的模型在此数据集也表现良好,说明模型的泛化能力较好,可以用来预测未知数据。 

3. 交叉验证的常用方法

        交叉验证有很多方法,下面介绍其中三种。

(1)“验证集”法

        保留 50% 的数据集用作验证,剩下 50% 训练模型。之后用验证集测试模型表现。这个方法的主要缺陷是,由于只使用了 50% 数据训练模型,原数据中一些重要的信息可能被忽略,也就是说,会有较大偏误。

(2)留一法交叉验证 ( LOOCV )

        这种方法只保留一个数据点用作验证,用剩余的数据集训练模型。然后对每个数据点重复这个过程。这个方法有利有弊:

  • 由于使用了所有数据点,所以偏差较低。
  • 验证过程重复了 n 次( n 为数据点个数),导致执行时间很长。
  • 由于只使用一个数据点验证,这个方法导致模型有效性的差异更大。得到的估计结果深受此点的影响。如果这是个离群点,会引起较大偏差。

(3)K折交叉验证 (K-fold cross validation)

        从以上两个验证方法中,我们知道:

  1. 应该使用较大比例的数据集来训练模型,否则会导致失败,最终得到偏误很大的模型。
  2. 验证用的数据点,其比例应该恰到好处。如果太少,会导致验证模型有效性时,得到的结果波动较大。
  3. 训练和验证过程应该重复多次(迭代)。训练集和验证集不能一成不变,这样有助于验证模型有效性。

        是否有一种方法可以兼顾这三个方面?答案是肯定的!这种方法就是“ K折交叉验证”。该方法的简要步骤如下:

  1.  把整个数据集随机分成 K“层”。
  2.  对于每一份数据来说: 1) 以该份作为测试集,其余作为训练集,也就是说用其中 K-1 层训练模型,然后用第K层验证。2) 在训练集上得到模型。3) 在测试集上得到生成误差。
  3. 重复这个过程,直到每“层”数据都作过验证集。这样对每一份数据都有一个预测结果,记录从每个预测结果获得的误差。
  4. 记录下的 k 个误差的平均值,被称为交叉验证误差(cross-validation error)。可以被用做衡量模型表现的标准。
  5. 取误差最小的那一个模型。

        此算法的缺点是计算量较大,当 k=10 时,k 层交叉验证示意图如下:

图3

        一个常见的问题是:如何确定合适的k值?K 值越小,偏误越大,所以越不推荐。另一方面,K 值太大,所得结果会变化多端。K 值小,则会变得像“验证集法”;K 值大,则会变得像“留一法”(LOOCV)。所以通常建议的值是 k=10 。

4. 如何衡量模型的偏误/变化程度

        K 层交叉检验之后,我们得到 K 个不同的模型误差估算值(e1, e2 …..ek)。理想的情况是,这些误差值相加得 0 。要计算模型的偏误,我们把所有这些误差值相加再取平均值,平均值越低,模型越好。

        模型表现变化程度的计算与之类似。取所有误差值的标准差,标准差越小说明模型随训练数据的变化越小。

        应该试图在偏误和变化程度间找到一种平衡。降低变化程度、控制偏误可以达到这个目的。这样会得到更好的预测模型。进行这个取舍,通常会得出复杂程度较低的预测模型。

二、Madlib的交叉验证

        在决策树的例子中,我们已经用到了交叉验证,只不过那是内嵌在决策树训练函数中的交叉验证。Madlib还提供了独立的交叉验证函数,可用于大部分Madlib的预测模型。

        如前所述,交叉验证可以估计一个预测模型在实践中的执行的精度,还可用于设置预测目标。Madlib提供的交叉验证函数非常灵活,不但可以选择交已经支持的叉验证算法,用户还能编写自己的验证算法。从交叉验证函数输入需要验证的训练、预测和误差估计函数规范。这些规范包括三部分:函数名称、传递给函数的参数数组、参数对应的数据类型数组。

  • 训练函数使用给定的自变量和因变量数据集产生模型,模型存储于输出表中。
  • 预测函数使用训练函数生成的模型,并接收不同于训练数据的自变量数据集,产生基于模型的对因变量的预测,并将预测结果存储在输出表中。预测函数的输入中应该包含一个表示唯一ID的列名,便于预测结果与验证值作比较。注意,有些Madlib的预测函数不将预测结果存储在输出表中,这种函数不适用于交叉验证。
  • 误差度量函数比较数据集中已知的因变量和预测结果,用特定的算法计算误差度量,并将结果存入一个表中。

其它输入包括输出表名,k折交叉验证的k值等。

三、交叉验证函数

1. 语法

cross_validation_general( modelling_func,
                          modelling_params,
                          modelling_params_type,
                          param_explored,
                          explore_values,
                          predict_func,
                          predict_params,
                          predict_params_type,
                          metric_func,
                          metric_params,
                          metric_params_type,
                          data_tbl,
                          data_id,
                          id_is_random,
                          validation_result,
                          data_cols,
                          fold_num
                        )

2. 参数

modelling_func:VARCHAR类型,模型训练函数名称。

modelling_params:VARCHAR[]类型,训练函数参数数组。

modelling_params_type:VARCHAR[]类型,训练函数参数对应的数据类型名称数组。

param_explored:VARCHAR类型,被寻找最佳值的参数名称,必须是modelling_params数组中的元素。

explore_values:VARCHAR类型,候选的参数值。

predict_func:VARCHAR类型,预测函数名称。

predict_params:VARCHAR[]类型,提供给预测函数的参数数组。

predict_params_type:VARCHAR[]类型,预测函数参数对应的数据类型名称数组。

metric_func:VARCHAR类型,误差度量函数名称。

metric_params:VARCHAR[]类型,提供给误差度量函数的参数数组。

metric_params_type:VARCHAR[]类型,误差度量函数参数对应的数据类型名称数组。

data_tbl:VARCHAR类型,包含原始输入数据表名,这些数据将被分成训练集和测试集。

data_id:VARCHAR类型,表示每一行唯一ID的列名,但可以为空。理想情况下,数据集中的每行数据都包含一个唯一ID,这样便于将数据集分成训练部分与验证部分。id_is_random参数值告诉交叉验证函数ID值是否是随机赋值。如果不是随机赋的ID值,验证函数为每行生成一个随机ID。

id_is_random:BOOLEAN类型,为TRUE时表示提供的ID是随机分配的。

validation_result:VARCHAR类型,存储交叉验证函数输出结果的表名,具有以下列:

                param_explored被寻找最佳值的参数名称。与cross_validation_general()函数的param_explored入参相同。

                average error误差度量函数计算出的平均误差。

                standard deviation of error标准差。

data_cols:逗号分隔的用于计算的数据列名。为NULL时,函数自动计算数据表中的所有列。只有当data_id参数为NULL时才会用到此参数,否则忽略。如果数据集没有唯一ID,交叉验证函数为每行生成一个随机ID,并将带有随机ID的数据集复制到一个临时表。设置此参数为自变量和因变量列表,通过只复制计算需要的数据,最小化复制工作量。计算完成后临时表被自动删除。

fold_num:INTEGER类型,k值,缺省值为10,指定验证轮数,每轮验证使用1/fold_num数据做验证。

        训练、预测和误差度量函数的参数数组中可以包含以下特殊关键字:

  • %data% – 代表训练/验证数据。
  • %model% – 代表训练函数的输出,即预测函数的输入。
  • %id% – 代表唯一ID列(用户提供的或函数生成的)。
  • %prediction% – 代表预测函数的输出,即误差度量函数的输入。
  • %error% – 代表误差度量函数的输出。

        注意,如果explore_values参数值为NULL,那么只运行一轮交叉验证。

四、示例

        我们将调用交叉验证函数,量化弹性网络正则化回归模型的准确性,并找出最佳的正则化参数。关于弹性网络正则化的说明可以参考Elastic net regularization

1. 准备输入数据

drop table if exists houses;
-- 房屋价格表
create table houses (
    id serial not null,   -- 自增序列
    tax integer,          -- 税金
    bedroom real,         -- 卧室数
    bath real,            -- 卫生间数
    price integer,        -- 价格
    size integer,         -- 使用面积
    lot integer           -- 占地面积
);

insert into houses(tax, bedroom, bath, price, size, lot) values
( 590, 2, 1,    50000,  770, 22100),
(1050, 3, 2,    85000, 1410, 12000),
(  20, 3, 1,    22500, 1060, 3500 ),
( 870, 2, 2,    90000, 1300, 17500),
(1320, 3, 2,   133000, 1500, 30000),
(1350, 2, 1,    90500,  820, 25700),
(2790, 3, 2.5, 260000, 2130, 25000),
( 680, 2, 1,   142500, 1170, 22000),
(1840, 3, 2,   160000, 1500, 19000),
(3680, 4, 2,   240000, 2790, 20000),
(1660, 3, 1,    87000, 1030, 17500),
(1620, 3, 2,   118600, 1250, 20000),
(3100, 3, 2,   140000, 1760, 38000),
(2070, 2, 3,   148000, 1550, 14000),
( 650, 3, 1.5,  65000, 1450, 12000);

2. 创建函数执行交叉验证

create or replace function check_cv()
returns void as $$
begin
    execute 'drop table if exists valid_rst_houses';
    perform madlib.cross_validation_general(
        'madlib.elastic_net_train',   -- 训练函数
    '{%data%, %model%, (price>100000), "array[tax, bath, size, lot]", binomial, 1, lambda, true, null, fista, "{eta = 2, max_stepsize = 2, use_active_set = t}", null, 2000, 1e-6}'::varchar[],  -- 训练函数参数
    '{varchar, varchar, varchar, varchar, varchar, double precision, double precision, boolean, varchar, varchar, varchar, varchar, integer, double precision}'::varchar[],   -- 训练函数参数数据类型
    'lambda',   -- 被考察参数
    '{0.04, 0.08, 0.12, 0.16, 0.20, 0.24, 0.28, 0.32, 0.36}'::varchar[], -- 被考察参数值
    'madlib.elastic_net_predict',   -- 预测函数
    '{%model%, %data%, %id%, %prediction%}'::varchar[],   -- 预测函数参数
    '{text, text, text, text}'::varchar[],   -- 预测函数参数数据类型
    'madlib.misclassification_avg', -- 误差度量函数
    '{%prediction%, %data%, %id%, (price>100000), %error%}'::varchar[],   -- 误差度量函数参数
    '{varchar, varchar, varchar, varchar, varchar}'::varchar[],   -- 误差度量函数参数数据类型
    'houses',   -- 数据表
    'id',   -- ID列
    false,   -- id是否随机
    'valid_rst_houses', -- 验证结果表
    '{tax,bath,size,lot, price}'::varchar[],   -- 数据列
    3  -- 折数
    );
end;
$$ language plpgsql volatile;

3. 执行函数并查询结果

select check_cv();
select * from valid_rst_houses order by lambda;

        结果:

 lambda |     error_rate_avg     |             error_rate_stddev              
--------+------------------------+--------------------------------------------
   0.04 | 0.26666666666666666667 | 0.1154700538379251529018297561003914911294
   0.08 | 0.33333333333333333333 | 0.1154700538379251529018297561003914911294
   0.12 | 0.33333333333333333333 | 0.1154700538379251529018297561003914911294
   0.16 | 0.53333333333333333333 | 0.2309401076758503058036595122007829822590
    0.2 | 0.60000000000000000000 | 0.2000000000000000000000000000000000000000
   0.24 | 0.60000000000000000000 | 0.2000000000000000000000000000000000000000
   0.28 | 0.66666666666666666667 | 0.2309401076758503058036595122007829822590
   0.32 | 0.66666666666666666667 | 0.2309401076758503058036595122007829822590
   0.36 | 0.73333333333333333333 | 0.1154700538379251529018297561003914911294
(9 rows)

        上面的查询结果表示,随着正则化参数不断加大,平均误差也会增加,而且当正则化参数较小时标准差也较小。因此得出结论,用0.04作为正则化参数,将得到较好的预测模型。

参考文献:

  1. Cross Validation:Madlib官方文档对交叉验证的说明。
  2. 用交叉验证改善模型的预测表现-着重k重交叉验证:对k折交叉验证简明扼要的阐述

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

发表于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏企鹅号快讯

Python 机器学习算法实践:树回归

前言 最近由于开始要把精力集中在课题的应用上面了,这篇总结之后算法原理的学习先告一段落。本文主要介绍决策树用于回归问题的相关算法实现,其中包括回归树(regre...

1929
来自专栏人工智能LeadAI

线性回归与最小二乘法 | 机器学习笔记

这篇笔记会将几本的线性回归概念和最小二乘法。 在机器学习中,一个重要而且常见的问题就是学习和预测特征变量(自变量)与响应的响应变量(应变量)之间的函数关系 ...

2947
来自专栏Python中文社区

机器学习算法实践:树回归

專 欄 ❈PytLab,Python 中文社区专栏作者。主要从事科学计算与高性能计算领域的应用,主要语言为Python,C,C++。熟悉数值算法(最优化方法,...

2549
来自专栏SIGAI学习与实践平台

理解凸优化

凸优化(convex optimization)是最优化问题中非常重要的一类,也是被研究的很透彻的一类。对于机器学习来说,如果要优化的问题被证明是凸优化问题,则...

732
来自专栏编程

我的R语言小白之梯度上升和逐步回归的结合使用

我的R语言小白之梯度上升和逐步回归的结合使用 今天是圣诞节,祝你圣诞节快乐啦,虽然我没有过圣诞节的习惯,昨天平安夜,也是看朋友圈才知道,原来是平安夜了,但是我昨...

2046
来自专栏深度学习思考者

卷积神经网络源码——最终输出部分的理解

  针对matlab版本的卷积神经网络的最终分类器(输出部分)的理解:   部分代码: '''cnnff''' net.fv = []; %...

1766
来自专栏智能算法

分类回归树算法---CART

一、算法介绍 分类回归树算法:CART(Classification And Regression Tree)算法也属于一种决策树,和之前介绍了C4.5算法相...

4818
来自专栏大数据挖掘DT机器学习

利用GBDT模型构造新特征具体方法

实际问题中,可直接用于机器学习模型的特征往往并不多。能否从“混乱”的原始log中挖掘到有用的特征,将会决定机器学习模型效果的好坏。引用下面一句流行的话: 特征决...

3607
来自专栏大数据挖掘DT机器学习

机器学习模型的特性

机器学习模型中有许多种不同方法可以用来解决分类和回归问题。对同一个问题来说,这些不同模型都可以被当成解决问题的黑箱来看待。然而,每种模型都源自于不同的...

33311
来自专栏H2Cloud

多元线性回归公式推导及R语言实现

实际中有很多问题是一个因变量与多个自变量成线性相关,我们可以用一个多元线性回归方程来表示。

641

扫码关注云+社区