前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >HAWQ + MADlib 玩转数据挖掘之(十二)——模型评估之交叉验证

HAWQ + MADlib 玩转数据挖掘之(十二)——模型评估之交叉验证

作者头像
用户1148526
发布2018-01-03 17:11:04
2.5K0
发布2018-01-03 17:11:04
举报
文章被收录于专栏:Hadoop数据仓库Hadoop数据仓库

一、交叉验证概述

        机器学习技术在应用之前使用“训练+检验”的模式,通常被称作“交叉验证”,如图1所示。

图1

1. 预测模型的稳定性

        让我们通过以下几幅图来理解这个问题:

图2

        此处我们试图找到尺寸(size)和价格(price)的关系。三个模型各自做了如下工作:

  1. 第一个模型使用了线性等式。对于训练用的数据点,此模型有很大误差。这样的模型在初期排行榜和最终排行榜都会表现不好。这是“拟合不足”(“Under fitting”)的一个例子。此模型不足以发掘数据背后的趋势。
  2. 第二个模型发现了价格和尺寸的正确关系,此模型误差低/概括程度高。
  3. 第三个模型对于训练数据几乎是零误差。这是因为此关系模型把每个数据点的偏差(包括噪声)都纳入了考虑范围,也就是说,这个模型太过敏感,甚至会捕捉到只在当前数据训练集出现的一些随机模式。这是“过度拟合”(Over fitting)的一个例子。这个关系模型可能在初榜和终榜成绩变化很大。

        在应用中,一个常见的做法是对多个模型进行迭代,从中选择表现更好的。然而,最终的分数是否会有改善依然未知,因为我们不知道这个模型是更好的发掘潜在关系了,还是过度拟合了。为了解答这个难题,我们应该使用交叉验证(cross validation)技术。它能帮我们得到更有概括性的关系模型。

        实际上,机器学习关注的是通过训练集训练过后的模型对测试样本的分类效果,我们称之为泛化能力。左右两图的泛化能力就不好。在机器学习中,对偏差和方差的权衡是机器学习理论着重解决的问题。

2. 什么是交叉验证

        交叉验证意味着需要保留一个样本数据集,不用来训练模型。在最终完成模型前,用这个数据集验证模型。交叉验证包含以下步骤:

  1. 保留一个样本数据集,即测试集。
  2. 用剩余部分(训练集)训练模型。
  3. 用保留的数据集(测试集)验证模型。

        这样做有助于了解模型的有效性。如果当前的模型在此数据集也表现良好,说明模型的泛化能力较好,可以用来预测未知数据。 

3. 交叉验证的常用方法

        交叉验证有很多方法,下面介绍其中三种。

(1)“验证集”法

        保留 50% 的数据集用作验证,剩下 50% 训练模型。之后用验证集测试模型表现。这个方法的主要缺陷是,由于只使用了 50% 数据训练模型,原数据中一些重要的信息可能被忽略,也就是说,会有较大偏误。

(2)留一法交叉验证 ( LOOCV )

        这种方法只保留一个数据点用作验证,用剩余的数据集训练模型。然后对每个数据点重复这个过程。这个方法有利有弊:

  • 由于使用了所有数据点,所以偏差较低。
  • 验证过程重复了 n 次( n 为数据点个数),导致执行时间很长。
  • 由于只使用一个数据点验证,这个方法导致模型有效性的差异更大。得到的估计结果深受此点的影响。如果这是个离群点,会引起较大偏差。

(3)K折交叉验证 (K-fold cross validation)

        从以上两个验证方法中,我们知道:

  1. 应该使用较大比例的数据集来训练模型,否则会导致失败,最终得到偏误很大的模型。
  2. 验证用的数据点,其比例应该恰到好处。如果太少,会导致验证模型有效性时,得到的结果波动较大。
  3. 训练和验证过程应该重复多次(迭代)。训练集和验证集不能一成不变,这样有助于验证模型有效性。

        是否有一种方法可以兼顾这三个方面?答案是肯定的!这种方法就是“ K折交叉验证”。该方法的简要步骤如下:

  1.  把整个数据集随机分成 K“层”。
  2.  对于每一份数据来说: 1) 以该份作为测试集,其余作为训练集,也就是说用其中 K-1 层训练模型,然后用第K层验证。2) 在训练集上得到模型。3) 在测试集上得到生成误差。
  3. 重复这个过程,直到每“层”数据都作过验证集。这样对每一份数据都有一个预测结果,记录从每个预测结果获得的误差。
  4. 记录下的 k 个误差的平均值,被称为交叉验证误差(cross-validation error)。可以被用做衡量模型表现的标准。
  5. 取误差最小的那一个模型。

        此算法的缺点是计算量较大,当 k=10 时,k 层交叉验证示意图如下:

图3

        一个常见的问题是:如何确定合适的k值?K 值越小,偏误越大,所以越不推荐。另一方面,K 值太大,所得结果会变化多端。K 值小,则会变得像“验证集法”;K 值大,则会变得像“留一法”(LOOCV)。所以通常建议的值是 k=10 。

4. 如何衡量模型的偏误/变化程度

        K 层交叉检验之后,我们得到 K 个不同的模型误差估算值(e1, e2 …..ek)。理想的情况是,这些误差值相加得 0 。要计算模型的偏误,我们把所有这些误差值相加再取平均值,平均值越低,模型越好。

        模型表现变化程度的计算与之类似。取所有误差值的标准差,标准差越小说明模型随训练数据的变化越小。

        应该试图在偏误和变化程度间找到一种平衡。降低变化程度、控制偏误可以达到这个目的。这样会得到更好的预测模型。进行这个取舍,通常会得出复杂程度较低的预测模型。

二、Madlib的交叉验证

        在决策树的例子中,我们已经用到了交叉验证,只不过那是内嵌在决策树训练函数中的交叉验证。Madlib还提供了独立的交叉验证函数,可用于大部分Madlib的预测模型。

        如前所述,交叉验证可以估计一个预测模型在实践中的执行的精度,还可用于设置预测目标。Madlib提供的交叉验证函数非常灵活,不但可以选择交已经支持的叉验证算法,用户还能编写自己的验证算法。从交叉验证函数输入需要验证的训练、预测和误差估计函数规范。这些规范包括三部分:函数名称、传递给函数的参数数组、参数对应的数据类型数组。

  • 训练函数使用给定的自变量和因变量数据集产生模型,模型存储于输出表中。
  • 预测函数使用训练函数生成的模型,并接收不同于训练数据的自变量数据集,产生基于模型的对因变量的预测,并将预测结果存储在输出表中。预测函数的输入中应该包含一个表示唯一ID的列名,便于预测结果与验证值作比较。注意,有些Madlib的预测函数不将预测结果存储在输出表中,这种函数不适用于交叉验证。
  • 误差度量函数比较数据集中已知的因变量和预测结果,用特定的算法计算误差度量,并将结果存入一个表中。

其它输入包括输出表名,k折交叉验证的k值等。

三、交叉验证函数

1. 语法

代码语言:javascript
复制
cross_validation_general( modelling_func,
                          modelling_params,
                          modelling_params_type,
                          param_explored,
                          explore_values,
                          predict_func,
                          predict_params,
                          predict_params_type,
                          metric_func,
                          metric_params,
                          metric_params_type,
                          data_tbl,
                          data_id,
                          id_is_random,
                          validation_result,
                          data_cols,
                          fold_num
                        )

2. 参数

modelling_func:VARCHAR类型,模型训练函数名称。

modelling_params:VARCHAR[]类型,训练函数参数数组。

modelling_params_type:VARCHAR[]类型,训练函数参数对应的数据类型名称数组。

param_explored:VARCHAR类型,被寻找最佳值的参数名称,必须是modelling_params数组中的元素。

explore_values:VARCHAR类型,候选的参数值。

predict_func:VARCHAR类型,预测函数名称。

predict_params:VARCHAR[]类型,提供给预测函数的参数数组。

predict_params_type:VARCHAR[]类型,预测函数参数对应的数据类型名称数组。

metric_func:VARCHAR类型,误差度量函数名称。

metric_params:VARCHAR[]类型,提供给误差度量函数的参数数组。

metric_params_type:VARCHAR[]类型,误差度量函数参数对应的数据类型名称数组。

data_tbl:VARCHAR类型,包含原始输入数据表名,这些数据将被分成训练集和测试集。

data_id:VARCHAR类型,表示每一行唯一ID的列名,但可以为空。理想情况下,数据集中的每行数据都包含一个唯一ID,这样便于将数据集分成训练部分与验证部分。id_is_random参数值告诉交叉验证函数ID值是否是随机赋值。如果不是随机赋的ID值,验证函数为每行生成一个随机ID。

id_is_random:BOOLEAN类型,为TRUE时表示提供的ID是随机分配的。

validation_result:VARCHAR类型,存储交叉验证函数输出结果的表名,具有以下列:

                param_explored被寻找最佳值的参数名称。与cross_validation_general()函数的param_explored入参相同。

                average error误差度量函数计算出的平均误差。

                standard deviation of error标准差。

data_cols:逗号分隔的用于计算的数据列名。为NULL时,函数自动计算数据表中的所有列。只有当data_id参数为NULL时才会用到此参数,否则忽略。如果数据集没有唯一ID,交叉验证函数为每行生成一个随机ID,并将带有随机ID的数据集复制到一个临时表。设置此参数为自变量和因变量列表,通过只复制计算需要的数据,最小化复制工作量。计算完成后临时表被自动删除。

fold_num:INTEGER类型,k值,缺省值为10,指定验证轮数,每轮验证使用1/fold_num数据做验证。

        训练、预测和误差度量函数的参数数组中可以包含以下特殊关键字:

  • %data% – 代表训练/验证数据。
  • %model% – 代表训练函数的输出,即预测函数的输入。
  • %id% – 代表唯一ID列(用户提供的或函数生成的)。
  • %prediction% – 代表预测函数的输出,即误差度量函数的输入。
  • %error% – 代表误差度量函数的输出。

        注意,如果explore_values参数值为NULL,那么只运行一轮交叉验证。

四、示例

        我们将调用交叉验证函数,量化弹性网络正则化回归模型的准确性,并找出最佳的正则化参数。关于弹性网络正则化的说明可以参考Elastic net regularization

1. 准备输入数据

代码语言:javascript
复制
drop table if exists houses;
-- 房屋价格表
create table houses (
    id serial not null,   -- 自增序列
    tax integer,          -- 税金
    bedroom real,         -- 卧室数
    bath real,            -- 卫生间数
    price integer,        -- 价格
    size integer,         -- 使用面积
    lot integer           -- 占地面积
);

insert into houses(tax, bedroom, bath, price, size, lot) values
( 590, 2, 1,    50000,  770, 22100),
(1050, 3, 2,    85000, 1410, 12000),
(  20, 3, 1,    22500, 1060, 3500 ),
( 870, 2, 2,    90000, 1300, 17500),
(1320, 3, 2,   133000, 1500, 30000),
(1350, 2, 1,    90500,  820, 25700),
(2790, 3, 2.5, 260000, 2130, 25000),
( 680, 2, 1,   142500, 1170, 22000),
(1840, 3, 2,   160000, 1500, 19000),
(3680, 4, 2,   240000, 2790, 20000),
(1660, 3, 1,    87000, 1030, 17500),
(1620, 3, 2,   118600, 1250, 20000),
(3100, 3, 2,   140000, 1760, 38000),
(2070, 2, 3,   148000, 1550, 14000),
( 650, 3, 1.5,  65000, 1450, 12000);

2. 创建函数执行交叉验证

代码语言:javascript
复制
create or replace function check_cv()
returns void as $$
begin
    execute 'drop table if exists valid_rst_houses';
    perform madlib.cross_validation_general(
        'madlib.elastic_net_train',   -- 训练函数
    '{%data%, %model%, (price>100000), "array[tax, bath, size, lot]", binomial, 1, lambda, true, null, fista, "{eta = 2, max_stepsize = 2, use_active_set = t}", null, 2000, 1e-6}'::varchar[],  -- 训练函数参数
    '{varchar, varchar, varchar, varchar, varchar, double precision, double precision, boolean, varchar, varchar, varchar, varchar, integer, double precision}'::varchar[],   -- 训练函数参数数据类型
    'lambda',   -- 被考察参数
    '{0.04, 0.08, 0.12, 0.16, 0.20, 0.24, 0.28, 0.32, 0.36}'::varchar[], -- 被考察参数值
    'madlib.elastic_net_predict',   -- 预测函数
    '{%model%, %data%, %id%, %prediction%}'::varchar[],   -- 预测函数参数
    '{text, text, text, text}'::varchar[],   -- 预测函数参数数据类型
    'madlib.misclassification_avg', -- 误差度量函数
    '{%prediction%, %data%, %id%, (price>100000), %error%}'::varchar[],   -- 误差度量函数参数
    '{varchar, varchar, varchar, varchar, varchar}'::varchar[],   -- 误差度量函数参数数据类型
    'houses',   -- 数据表
    'id',   -- ID列
    false,   -- id是否随机
    'valid_rst_houses', -- 验证结果表
    '{tax,bath,size,lot, price}'::varchar[],   -- 数据列
    3  -- 折数
    );
end;
$$ language plpgsql volatile;

3. 执行函数并查询结果

代码语言:javascript
复制
select check_cv();
select * from valid_rst_houses order by lambda;

        结果:

代码语言:javascript
复制
 lambda |     error_rate_avg     |             error_rate_stddev              
--------+------------------------+--------------------------------------------
   0.04 | 0.26666666666666666667 | 0.1154700538379251529018297561003914911294
   0.08 | 0.33333333333333333333 | 0.1154700538379251529018297561003914911294
   0.12 | 0.33333333333333333333 | 0.1154700538379251529018297561003914911294
   0.16 | 0.53333333333333333333 | 0.2309401076758503058036595122007829822590
    0.2 | 0.60000000000000000000 | 0.2000000000000000000000000000000000000000
   0.24 | 0.60000000000000000000 | 0.2000000000000000000000000000000000000000
   0.28 | 0.66666666666666666667 | 0.2309401076758503058036595122007829822590
   0.32 | 0.66666666666666666667 | 0.2309401076758503058036595122007829822590
   0.36 | 0.73333333333333333333 | 0.1154700538379251529018297561003914911294
(9 rows)

        上面的查询结果表示,随着正则化参数不断加大,平均误差也会增加,而且当正则化参数较小时标准差也较小。因此得出结论,用0.04作为正则化参数,将得到较好的预测模型。

参考文献:

  1. Cross Validation:Madlib官方文档对交叉验证的说明。
  2. 用交叉验证改善模型的预测表现-着重k重交叉验证:对k折交叉验证简明扼要的阐述
本文参与 腾讯云自媒体分享计划,分享自作者个人站点/博客。
原始发表:2017年08月29日,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 一、交叉验证概述
    • 1. 预测模型的稳定性
      • 2. 什么是交叉验证
        • 3. 交叉验证的常用方法
          • 4. 如何衡量模型的偏误/变化程度
          • 二、Madlib的交叉验证
          • 三、交叉验证函数
            • 1. 语法
              • 2. 参数
              • 四、示例
                • 1. 准备输入数据
                  • 2. 创建函数执行交叉验证
                    • 3. 执行函数并查询结果
                    • 参考文献:
                    领券
                    问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档