前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >Alink漫谈(二十一) :回归评估之源码分析

Alink漫谈(二十一) :回归评估之源码分析

作者头像
罗西的思考
发布2020-09-28 11:41:59
4000
发布2020-09-28 11:41:59
举报
文章被收录于专栏:罗西的思考罗西的思考

Alink漫谈(二十一) :回归评估之源码分析

0x00 摘要

Alink 是阿里巴巴基于实时计算引擎 Flink 研发的新一代机器学习算法平台,是业界首个同时支持批式算法、流式算法的机器学习平台。本文和将带领大家来分析Alink中 回归评估 的实现。

这是剖析Alink以来,最轻松的一次了。因为这里的概念和实现逻辑都非常清晰。

0x01 背景概念

1.1 功能介绍

回归评估是对回归算法的预测结果进行效果评估,支持下列评估指标。这些指标基本都是统计领域概念。

1.2 具体指标

Alink 提供如下指标:

count 行数

SST 总平方和(Sum of Squared for Total),度量了Y在样本中的分散程度。

\[SST=\sum_{i=1}^{N}(y_i-\bar{y})^2 \]

SSE 误差平方和(Sum of Squares for Error),度量了总样本变异。

\[SSE=\sum_{i=1}^{N}(y_i-f_i)^2" \]

SSR 回归平方和(Sum of Squares for Regression),度量了残差的样本变异。

\[SSR=\sum_{i=1}^{N}(f_i-\bar{y})^2 \]

R^2 判定系数(Coefficient of Determination),用于估计回归方程是否很好的拟合了样本的数据,判定系数为估计的回归方程提供了一个拟合优度的度量。

\[R^2=1-\dfrac{SSE}{SST} \]

R 多重相关系数(Multiple Correlation Coeffient),指一个随机变量与某一组随机变量间线性相依性的度量。

\[R=\sqrt{R^2} \]

MSE 均方误差(Mean Squared Error),均方差(标准差)、方差都是用来描述数据集的离散程度。

均方误差是衡量“平均误差”的一种较方便的方法,可以评价数据的变化程度。从类别来看属于预测评价与预测组合;从字面上看来,“均”指的是平均,即求其平均值,“方差”即是在概率论中用来衡量随机变量和其估计值(其平均值)之间的偏离程度的度量值,“误”可以理解为测定值与真实值之间的误差。

\[MSE=\dfrac{1}{N}\sum_{i=1}^{N}(f_i-y_i)^2 \]

RMSE 均方根误差(Root Mean Squared Error)

\[RMSE=\sqrt{MSE} \]

SAE/SAD 绝对误差(Sum of Absolute Error/Difference)

\[SAE=\sum_{i=1}^{N}|f_i-y_i| \]

MAE/MAD 平均绝对误差(Mean Absolute Error/Difference)

\[MAE=\dfrac{1}{N}\sum_{i=1}^{N}|f_i-y_i| \]

MAPE 平均绝对百分误差(Mean Absolute Percentage Error)

\[MAPE=\dfrac{100}{N}\sum_{i=1}^{N}|\dfrac{f_i-y_i}{y_i}| \]

explained variance 解释方差

\[explained Variance=\dfrac{SSR}{N} \]

0x02 示例代码

直接拿出来Alink的示例代码。

代码语言:javascript
复制
public class EvalRegressionBatchOpExp {
    
    public static void main(String[] args) throws Exception {
        Row[] data =
                new Row[] {
                        Row.of(0.4, 0.5),
                        Row.of(0.3, 0.5),
                        Row.of(0.2, 0.6),
                        Row.of(0.6, 0.7),
                        Row.of(0.1, 0.5)
                };

        MemSourceBatchOp input = new MemSourceBatchOp(data, new String[] {"label", "pred"});

        RegressionMetrics metrics = new EvalRegressionBatchOp()
                .setLabelCol("label")
                .setPredictionCol("pred")
                .linkFrom(input)
                .collectMetrics();

        System.out.println(metrics.getRmse());
        System.out.println(metrics.getR2());
        System.out.println(metrics.getSse());
        System.out.println(metrics.getMape());
        System.out.println(metrics.getMae());
        System.out.println(metrics.getSsr());
        System.out.println(metrics.getSst());
    }
}

输出为:

代码语言:javascript
复制
0.27568097504180444
-1.5675675675675653
0.38
141.66666666666669
0.24
0.31999999999999973
0.14800000000000013

0x03 总体逻辑

总体逻辑是:

  • 调用 CalcLocal 进行分区计算各种统计数值;
  • reduce 调用 ReduceBaseMetrics 进行归并各种统计数值;
  • 调用 SaveDataAsParams 存储;

getLabelCol 就是 y,getPredictionCol 就是 y_hat。

代码语言:javascript
复制
public EvalRegressionBatchOp linkFrom(BatchOperator<?>... inputs) {
    BatchOperator in = checkAndGetFirst(inputs);

    // 这里就是找到y, y_hat
    TableUtil.findColIndexWithAssertAndHint(in.getColNames(), this.getLabelCol());
    TableUtil.findColIndexWithAssertAndHint(in.getColNames(), this.getPredictionCol());
	
  	// 利用y, y_hat来构建Metrics
    TableUtil.assertNumericalCols(in.getSchema(), this.getLabelCol(), this.getPredictionCol());
    DataSet<Row> out = in.select(new String[] {this.getLabelCol(), this.getPredictionCol()})
        .getDataSet()
        .rebalance()
        .mapPartition(new CalcLocal())
        .reduce(new EvaluationUtil.ReduceBaseMetrics())
        .flatMap(new EvaluationUtil.SaveDataAsParams());

    this.setOutputTable(DataSetConversionUtil.toTable(getMLEnvironmentId(),
        out, new TableSchema(new String[] {"regression_eval_result"}, new TypeInformation[] {Types.STRING})
    ));
    return this;
}

0x04 分区计算统计数值

调用 CalcLocal 进行分区计算各种统计数值,间接调用getRegressionStatistics。

代码语言:javascript
复制
/**
 * Get the label sum, predResult sum, SSE, MAE, MAPE of one partition.
 */
public static class CalcLocal implements MapPartitionFunction<Row, BaseMetricsSummary> {
    @Override
    public void mapPartition(Iterable<Row> rows, Collector<BaseMetricsSummary> collector)
        throws Exception {
        collector.collect(getRegressionStatistics(rows));
    }
}

getRegressionStatistics作用是遍历输入数据,在本Partition内部计算各种累积数值,为后续做准备。

代码语言:javascript
复制
/**
 * Calculate the RegressionMetrics from local data.
 *
 * @param rows Input rows, the first field is label value, the second field is prediction value.
 * @return RegressionMetricsSummary.
 */
public static RegressionMetricsSummary getRegressionStatistics(Iterable<Row> rows) {
    RegressionMetricsSummary regressionSummary = new RegressionMetricsSummary();
    for (Row row : rows) {
        if (checkRowFieldNotNull(row)) {
            double yVal = ((Number)row.getField(0)).doubleValue();
            double predictVal = ((Number)row.getField(1)).doubleValue();
            double diff = Math.abs(yVal - predictVal);
            regressionSummary.ySumLocal += yVal;
            regressionSummary.ySum2Local += yVal * yVal;
            regressionSummary.predSumLocal += predictVal;
            regressionSummary.predSum2Local += predictVal * predictVal;
            regressionSummary.maeLocal += diff;
            regressionSummary.sseLocal += diff * diff;
            regressionSummary.mapeLocal += Math.abs(diff / yVal);
            regressionSummary.total++;
        }
    }
    return regressionSummary.total == 0 ? null : regressionSummary;
}

0x05 归并统计数值

reduce 调用 ReduceBaseMetrics 进行归并各种统计数值:

代码语言:javascript
复制
/**
 * Merge the BaseMetrics calculated locally.
 */
public static class ReduceBaseMetrics implements ReduceFunction<BaseMetricsSummary> {
    @Override
    public BaseMetricsSummary reduce(BaseMetricsSummary t1, BaseMetricsSummary t2) throws Exception {
        return null == t1 ? t2 : t1.merge(t2);
    }
}

0x06 存储模型

这里调用SaveDataAsParams来存储模型。

代码语言:javascript
复制
/**
 * After merging all the BaseMetrics, we get the total BaseMetrics. Calculate the indexes and save them into params.
 */
public static class SaveDataAsParams implements FlatMapFunction<BaseMetricsSummary, Row> {
    @Override
    public void flatMap(BaseMetricsSummary t, Collector<Row> collector) throws Exception {
        collector.collect(t.toMetrics().serialize());
    }
}

0x07 toMetrics

最后呈现出统计指标。

代码语言:javascript
复制
public RegressionMetrics toMetrics() {
    Params params = new Params();
    params.set(RegressionMetrics.SST, ySum2Local - ySumLocal * ySumLocal / total);
    params.set(RegressionMetrics.SSE, sseLocal);
    params.set(RegressionMetrics.SSR,
        predSum2Local - 2 * ySumLocal * predSumLocal / total + ySumLocal * ySumLocal / total);
    params.set(RegressionMetrics.R2, 1 - params.get(RegressionMetrics.SSE) / params.get(RegressionMetrics.SST));
    params.set(RegressionMetrics.R, Math.sqrt(params.get(RegressionMetrics.R2)));
    params.set(RegressionMetrics.MSE, params.get(RegressionMetrics.SSE) / total);
    params.set(RegressionMetrics.RMSE, Math.sqrt(params.get(RegressionMetrics.MSE)));
    params.set(RegressionMetrics.SAE, maeLocal);
    params.set(RegressionMetrics.MAE, params.get(RegressionMetrics.SAE) / total);
    params.set(RegressionMetrics.COUNT, (double)total);
    params.set(RegressionMetrics.MAPE, mapeLocal * 100 / total);
    params.set(RegressionMetrics.Y_MEAN, ySumLocal / total);
    params.set(RegressionMetrics.PREDICTION_MEAN, predSumLocal / total);
    params.set(RegressionMetrics.EXPLAINED_VARIANCE, params.get(RegressionMetrics.SSR) / total);

    return new RegressionMetrics(params);
}

最后得到结果

代码语言:javascript
复制
params = {Params@9098} "Params {R2=-1.5675675675675693, predictionMean=0.5599999999999999, SSE=0.38, count=5.0, MAPE=141.66666666666666, RMSE=0.27568097504180444, MAE=0.24, R=NaN, SSR=0.3200000000000002, yMean=0.32, SST=0.1479999999999999, SAE=1.2, Explained Variance=0.06400000000000003, MSE=0.076}"
 params = {HashMap@9101}  size = 14
  "R2" -> "-1.5675675675675693"
  "predictionMean" -> "0.5599999999999999"
  "SSE" -> "0.38"
  "count" -> "5.0"
  "MAPE" -> "141.66666666666666"
  "RMSE" -> "0.27568097504180444"
  "MAE" -> "0.24"
  "R" -> "NaN"
  "SSR" -> "0.3200000000000002"
  "yMean" -> "0.32"
  "SST" -> "0.1479999999999999"
  "SAE" -> "1.2"
  "Explained Variance" -> "0.06400000000000003"
  "MSE" -> "0.076"

0xFF 参考

均方误差

本文参与 腾讯云自媒体分享计划,分享自作者个人站点/博客。
原始发表:2020-09-25 ,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • Alink漫谈(二十一) :回归评估之源码分析
    • 0x00 摘要
      • 0x01 背景概念
        • 1.1 功能介绍
        • 1.2 具体指标
      • 0x02 示例代码
        • 0x03 总体逻辑
          • 0x04 分区计算统计数值
            • 0x05 归并统计数值
              • 0x06 存储模型
                • 0x07 toMetrics
                  • 0xFF 参考
                  相关产品与服务
                  流计算 Oceanus
                  流计算 Oceanus 是大数据产品生态体系的实时化分析利器,是基于 Apache Flink 构建的企业级实时大数据分析平台,具备一站开发、无缝连接、亚秒延时、低廉成本、安全稳定等特点。流计算 Oceanus 以实现企业数据价值最大化为目标,加速企业实时化数字化的建设进程。
                  领券
                  问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档