机器学习技术在应用之前使用“训练+检验”的模式,通常被称作“交叉验证”,如图1所示。
图1
让我们通过以下几幅图来理解这个问题:
图2
此处我们试图找到尺寸(size)和价格(price)的关系。三个模型各自做了如下工作:
在应用中,一个常见的做法是对多个模型进行迭代,从中选择表现更好的。然而,最终的分数是否会有改善依然未知,因为我们不知道这个模型是更好的发掘潜在关系了,还是过度拟合了。为了解答这个难题,我们应该使用交叉验证(cross validation)技术。它能帮我们得到更有概括性的关系模型。
实际上,机器学习关注的是通过训练集训练过后的模型对测试样本的分类效果,我们称之为泛化能力。左右两图的泛化能力就不好。在机器学习中,对偏差和方差的权衡是机器学习理论着重解决的问题。
交叉验证意味着需要保留一个样本数据集,不用来训练模型。在最终完成模型前,用这个数据集验证模型。交叉验证包含以下步骤:
这样做有助于了解模型的有效性。如果当前的模型在此数据集也表现良好,说明模型的泛化能力较好,可以用来预测未知数据。
交叉验证有很多方法,下面介绍其中三种。
(1)“验证集”法
保留 50% 的数据集用作验证,剩下 50% 训练模型。之后用验证集测试模型表现。这个方法的主要缺陷是,由于只使用了 50% 数据训练模型,原数据中一些重要的信息可能被忽略,也就是说,会有较大偏误。
(2)留一法交叉验证 ( LOOCV )
这种方法只保留一个数据点用作验证,用剩余的数据集训练模型。然后对每个数据点重复这个过程。这个方法有利有弊:
(3)K折交叉验证 (K-fold cross validation)
从以上两个验证方法中,我们知道:
是否有一种方法可以兼顾这三个方面?答案是肯定的!这种方法就是“ K折交叉验证”。该方法的简要步骤如下:
此算法的缺点是计算量较大,当 k=10 时,k 层交叉验证示意图如下:
图3
一个常见的问题是:如何确定合适的k值?K 值越小,偏误越大,所以越不推荐。另一方面,K 值太大,所得结果会变化多端。K 值小,则会变得像“验证集法”;K 值大,则会变得像“留一法”(LOOCV)。所以通常建议的值是 k=10 。
K 层交叉检验之后,我们得到 K 个不同的模型误差估算值(e1, e2 …..ek)。理想的情况是,这些误差值相加得 0 。要计算模型的偏误,我们把所有这些误差值相加再取平均值,平均值越低,模型越好。
模型表现变化程度的计算与之类似。取所有误差值的标准差,标准差越小说明模型随训练数据的变化越小。
应该试图在偏误和变化程度间找到一种平衡。降低变化程度、控制偏误可以达到这个目的。这样会得到更好的预测模型。进行这个取舍,通常会得出复杂程度较低的预测模型。
在决策树的例子中,我们已经用到了交叉验证,只不过那是内嵌在决策树训练函数中的交叉验证。Madlib还提供了独立的交叉验证函数,可用于大部分Madlib的预测模型。
如前所述,交叉验证可以估计一个预测模型在实践中的执行的精度,还可用于设置预测目标。Madlib提供的交叉验证函数非常灵活,不但可以选择交已经支持的叉验证算法,用户还能编写自己的验证算法。从交叉验证函数输入需要验证的训练、预测和误差估计函数规范。这些规范包括三部分:函数名称、传递给函数的参数数组、参数对应的数据类型数组。
其它输入包括输出表名,k折交叉验证的k值等。
cross_validation_general( modelling_func,
modelling_params,
modelling_params_type,
param_explored,
explore_values,
predict_func,
predict_params,
predict_params_type,
metric_func,
metric_params,
metric_params_type,
data_tbl,
data_id,
id_is_random,
validation_result,
data_cols,
fold_num
)
modelling_func:VARCHAR类型,模型训练函数名称。
modelling_params:VARCHAR[]类型,训练函数参数数组。
modelling_params_type:VARCHAR[]类型,训练函数参数对应的数据类型名称数组。
param_explored:VARCHAR类型,被寻找最佳值的参数名称,必须是modelling_params数组中的元素。
explore_values:VARCHAR类型,候选的参数值。
predict_func:VARCHAR类型,预测函数名称。
predict_params:VARCHAR[]类型,提供给预测函数的参数数组。
predict_params_type:VARCHAR[]类型,预测函数参数对应的数据类型名称数组。
metric_func:VARCHAR类型,误差度量函数名称。
metric_params:VARCHAR[]类型,提供给误差度量函数的参数数组。
metric_params_type:VARCHAR[]类型,误差度量函数参数对应的数据类型名称数组。
data_tbl:VARCHAR类型,包含原始输入数据表名,这些数据将被分成训练集和测试集。
data_id:VARCHAR类型,表示每一行唯一ID的列名,但可以为空。理想情况下,数据集中的每行数据都包含一个唯一ID,这样便于将数据集分成训练部分与验证部分。id_is_random参数值告诉交叉验证函数ID值是否是随机赋值。如果不是随机赋的ID值,验证函数为每行生成一个随机ID。
id_is_random:BOOLEAN类型,为TRUE时表示提供的ID是随机分配的。
validation_result:VARCHAR类型,存储交叉验证函数输出结果的表名,具有以下列:
param_explored被寻找最佳值的参数名称。与cross_validation_general()函数的param_explored入参相同。
average error误差度量函数计算出的平均误差。
standard deviation of error标准差。
data_cols:逗号分隔的用于计算的数据列名。为NULL时,函数自动计算数据表中的所有列。只有当data_id参数为NULL时才会用到此参数,否则忽略。如果数据集没有唯一ID,交叉验证函数为每行生成一个随机ID,并将带有随机ID的数据集复制到一个临时表。设置此参数为自变量和因变量列表,通过只复制计算需要的数据,最小化复制工作量。计算完成后临时表被自动删除。
fold_num:INTEGER类型,k值,缺省值为10,指定验证轮数,每轮验证使用1/fold_num数据做验证。
训练、预测和误差度量函数的参数数组中可以包含以下特殊关键字:
注意,如果explore_values参数值为NULL,那么只运行一轮交叉验证。
我们将调用交叉验证函数,量化弹性网络正则化回归模型的准确性,并找出最佳的正则化参数。关于弹性网络正则化的说明可以参考Elastic net regularization。
drop table if exists houses;
-- 房屋价格表
create table houses (
id serial not null, -- 自增序列
tax integer, -- 税金
bedroom real, -- 卧室数
bath real, -- 卫生间数
price integer, -- 价格
size integer, -- 使用面积
lot integer -- 占地面积
);
insert into houses(tax, bedroom, bath, price, size, lot) values
( 590, 2, 1, 50000, 770, 22100),
(1050, 3, 2, 85000, 1410, 12000),
( 20, 3, 1, 22500, 1060, 3500 ),
( 870, 2, 2, 90000, 1300, 17500),
(1320, 3, 2, 133000, 1500, 30000),
(1350, 2, 1, 90500, 820, 25700),
(2790, 3, 2.5, 260000, 2130, 25000),
( 680, 2, 1, 142500, 1170, 22000),
(1840, 3, 2, 160000, 1500, 19000),
(3680, 4, 2, 240000, 2790, 20000),
(1660, 3, 1, 87000, 1030, 17500),
(1620, 3, 2, 118600, 1250, 20000),
(3100, 3, 2, 140000, 1760, 38000),
(2070, 2, 3, 148000, 1550, 14000),
( 650, 3, 1.5, 65000, 1450, 12000);
create or replace function check_cv()
returns void as $$
begin
execute 'drop table if exists valid_rst_houses';
perform madlib.cross_validation_general(
'madlib.elastic_net_train', -- 训练函数
'{%data%, %model%, (price>100000), "array[tax, bath, size, lot]", binomial, 1, lambda, true, null, fista, "{eta = 2, max_stepsize = 2, use_active_set = t}", null, 2000, 1e-6}'::varchar[], -- 训练函数参数
'{varchar, varchar, varchar, varchar, varchar, double precision, double precision, boolean, varchar, varchar, varchar, varchar, integer, double precision}'::varchar[], -- 训练函数参数数据类型
'lambda', -- 被考察参数
'{0.04, 0.08, 0.12, 0.16, 0.20, 0.24, 0.28, 0.32, 0.36}'::varchar[], -- 被考察参数值
'madlib.elastic_net_predict', -- 预测函数
'{%model%, %data%, %id%, %prediction%}'::varchar[], -- 预测函数参数
'{text, text, text, text}'::varchar[], -- 预测函数参数数据类型
'madlib.misclassification_avg', -- 误差度量函数
'{%prediction%, %data%, %id%, (price>100000), %error%}'::varchar[], -- 误差度量函数参数
'{varchar, varchar, varchar, varchar, varchar}'::varchar[], -- 误差度量函数参数数据类型
'houses', -- 数据表
'id', -- ID列
false, -- id是否随机
'valid_rst_houses', -- 验证结果表
'{tax,bath,size,lot, price}'::varchar[], -- 数据列
3 -- 折数
);
end;
$$ language plpgsql volatile;
select check_cv();
select * from valid_rst_houses order by lambda;
结果:
lambda | error_rate_avg | error_rate_stddev
--------+------------------------+--------------------------------------------
0.04 | 0.26666666666666666667 | 0.1154700538379251529018297561003914911294
0.08 | 0.33333333333333333333 | 0.1154700538379251529018297561003914911294
0.12 | 0.33333333333333333333 | 0.1154700538379251529018297561003914911294
0.16 | 0.53333333333333333333 | 0.2309401076758503058036595122007829822590
0.2 | 0.60000000000000000000 | 0.2000000000000000000000000000000000000000
0.24 | 0.60000000000000000000 | 0.2000000000000000000000000000000000000000
0.28 | 0.66666666666666666667 | 0.2309401076758503058036595122007829822590
0.32 | 0.66666666666666666667 | 0.2309401076758503058036595122007829822590
0.36 | 0.73333333333333333333 | 0.1154700538379251529018297561003914911294
(9 rows)
上面的查询结果表示,随着正则化参数不断加大,平均误差也会增加,而且当正则化参数较小时标准差也较小。因此得出结论,用0.04作为正则化参数,将得到较好的预测模型。