前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >cross_validate和KFold做Cross-validation的区别

cross_validate和KFold做Cross-validation的区别

作者头像
木羊
发布2022-04-11 18:03:07
7920
发布2022-04-11 18:03:07
举报
文章被收录于专栏:睡前机器学习

睡前絮语:

一年又快要过去了,这篇本来是要新年发的文章,还好赶在农历新年前发了。想想今年也写了不少的字,也看到了不少机器学习公号从干货到恰饭的转变,有过一些困惑,甚至到了质疑意义本身。

不过想想,我写不过只是我想写,如果居然有人愿意看,居然还感觉还有些帮助,那真是意外之喜。公号写着写着,不知不觉变成了经营。这篇也许是革新,也许只是回归初心。

新年,祝大家万事胜意,不忘初心!

以下正文

机器学习的模型选择一般通过cross-validation(交叉验证)来完成,很多人也简称为做CV。做CV的主要方法就几种,最常用的叫K折交叉验证,简单来说就是把数据集切成K份,然后做K次CV,每次分别取其中的K-1份作为训练集。这些随便找本讲机器学习的书都有,不展开了。

理解完原理就可以用sklearn(scikit-learn)来实际做做,但是一查文档傻眼了:sklearn有两个常用的API,一个叫cross_validate,直译过来就是“交叉验证”;另一个叫 KFold ,直译过来就是“K折”。

这就十分挠头了,这俩API各叫一半,那我们要做K折交叉验证该怎么选呢,岂不是要逼死强迫症?

别急,没什么是读一遍文档不能解决的,如果有,再看一眼源码。

先看文档。

对于cross_validate,文档如是说:

Evaluate metric(s) by cross-validation and also record fit/score times,翻译过来就是这个api用于计算交叉验证的值,同时还能还记录训练时间。简单来说,就是CV指标的计算工具。

对于KFold,文档的介绍要长一点:

K-Folds cross-validator.Provides train/test indices to split data in train/test sets. Split dataset into k consecutive folds (without shuffling by default).Each fold is then used once as a validation while the k - 1 remaining folds form the training set.

这段说明很有意思,反复说KFold是用来切(Split)数据的,粗看和书上对K折交叉验证的说明很像,让人容易混淆。但是另一份文档给这个api归了个类,归为Cross validation iterators,是“交叉验证迭代器”。

这一下就清晰了:cross_validate是直接算出CV的指标值,而KFold只负责将数据按K折要求切分数据,然后通过迭代器对外提供,至于你怎么用,是用来计算指标还是直接输出数据,KFold都甩手不管了。

再简单一点,你只要计算CV值,用cross_validate就行了,你想自己对K折数据进行一些处理,那就用KFold。可以说cross_validate输出的是成品,而KFold输出的只是半成品。

别看成品半成品,就觉得只要用cross_validate就行。从比赛来看,选手们用得更多的是KFold,原因有机会聊。现在我特别好奇另一个问题:如果是我,我会选择用KFold来实现cross_validate。那cross_validate有没有用到KFold呢?

扒了cross_validate代码的核心部分,如下:

代码语言:javascript
复制
    X, y, groups = indexable(X, y, groups)

    cv = check_cv(cv, y, classifier=is_classifier(estimator))

    if callable(scoring):
        scorers = scoring
    elif scoring is None or isinstance(scoring, str):
        scorers = check_scoring(estimator, scoring)
    else:
        scorers = _check_multimetric_scoring(estimator, scoring)

    # We clone the estimator to make sure that all the folds are
    # independent, and that it is pickle-able.
    parallel = Parallel(n_jobs=n_jobs, verbose=verbose, pre_dispatch=pre_dispatch)
    results = parallel(
        delayed(_fit_and_score)(
            clone(estimator),
            X,
            y,
            scorers,
            train,
            test,
            verbose,
            None,
            fit_params,
            return_train_score=return_train_score,
            return_times=True,
            return_estimator=return_estimator,
            error_score=error_score,
        )
        for train, test in cv.split(X, y, groups)
    )

就这么看好像没用到KFold。难道sklearn还要重复造轮子?别急,先找到“切数据”的部分:

代码语言:javascript
复制
for train, test in cv.split(X, y, groups)

用的是一个叫“cv”的对象的split方法。这个cv是通过check_cv函数得到的,核心代码如下:

代码语言:javascript
复制
   cv = 5 if cv is None else cv
    if isinstance(cv, numbers.Integral):
        if (
            classifier
            and (y is not None)
            and (type_of_target(y) in ("binary", "multiclass"))
        ):
            return StratifiedKFold(cv)
        else:
            return KFold(cv)

    if not hasattr(cv, "split") or isinstance(cv, str):
        if not isinstance(cv, Iterable) or isinstance(cv, str):
            raise ValueError(
                "Expected cv as an integer, cross-validation "
                "object (from sklearn.model_selection) "
                "or an iterable. Got %s." % cv
            )
        return _CVIterableWrapper(cv)

    return cv  # New style cv objects are passed without any modification

这里有两个KFold,一个叫StratifiedKFold,另一个就是我们要找的KFold,二者都是按K折且数据,为什么要分两个我们找机会另聊,不过,至此我们找到了KFold,也更清楚了KFold和cross_validate的关系和区别。

本文参与 腾讯云自媒体同步曝光计划,分享自微信公众号。
原始发表:2022-01-25,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 睡前机器学习 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档