
Cross validation with ShuffleSplit

Author: 到不了的都叫做远方
Last modified: 2020-05-06 11:46:34

ShuffleSplit is one of the simplest cross validation techniques: for a specified number of iterations, it simply draws a random sample of the data.
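To see what ShuffleSplit actually produces, here is a minimal sketch; the toy array, split count, test size, and `random_state` are illustrative choices, not part of the recipe:

```python
import numpy as np
from sklearn.model_selection import ShuffleSplit

# 3 independent random splits of a 10-point dataset, holding out 30% each time
X = np.arange(10)
ss = ShuffleSplit(n_splits=3, test_size=0.3, random_state=0)
for train_idx, test_idx in ss.split(X):
    print(len(train_idx), len(test_idx))  # each split: 7 train, 3 test
```

Unlike KFold, the splits are drawn independently, so a given point may appear in several test sets or in none.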

Getting ready

ShuffleSplit is another very simple cross validation technique. We'll specify the total number of elements in the dataset, and it will take care of the rest. We'll walk through an example of estimating the mean of a univariate dataset. This is somewhat similar to resampling, but it also illustrates one reason we want to use cross validation in the first place.

How to do it...

First, we need to create the dataset. We'll use NumPy to create a dataset, where we know the underlying mean. We'll sample half of the dataset to estimate the mean and see how close it is to the underlying mean:


import numpy as np
import matplotlib.pyplot as plt

# Draw N points from a normal distribution with a known mean (loc)
true_loc = 1000
true_scale = 10
N = 1000
dataset = np.random.normal(true_loc, true_scale, N)

f, ax = plt.subplots(figsize=(7, 5))
ax.hist(dataset, color='k', alpha=.65, histtype='stepfilled')
ax.set_title("Histogram of dataset")
f.savefig("978-1-78398-948-5_06_06.png")

NumPy will give the following output:

Now, let's hold out the first half of the data and estimate the mean from the remaining half:

# Hold out the first half; fit on the second half
holdout_set = dataset[:500]
fitting_set = dataset[500:]

# The "regular" estimate: the mean of the fitting set
estimate = fitting_set.mean()

f, ax = plt.subplots(figsize=(7, 5))
ax.set_title("True Mean vs Regular Estimate")
ax.vlines(true_loc, 0, 1, color='r', linestyles='-', lw=5, alpha=.65, label='true mean')
ax.vlines(estimate, 0, 1, color='g', linestyles='-', lw=5, alpha=.65, label='regular estimate')
ax.set_xlim(999, 1001)
ax.legend()
f.show()
f.savefig("978-1-78398-948-5_06_07.png")

We'll get the following output:
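As a rough sanity check (a back-of-the-envelope calculation, not part of the recipe), the standard error of the mean tells us how far the regular estimate should typically be from the true mean. For 500 points drawn from a distribution with scale 10, it is about 10/sqrt(500) ≈ 0.45, so the estimate should land well inside the plotted 999 to 1001 window:

```python
import numpy as np

# Standard error of the mean for a sample of 500 points with scale 10
se = 10 / np.sqrt(500)
print(round(se, 2))  # 0.45
```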

Now, we can use ShuffleSplit to fit the estimator on several smaller datasets:

from sklearn.model_selection import ShuffleSplit

# In the modern sklearn.model_selection API the first argument is the number
# of iterations (n_splits), not the number of samples as in the old module
shuffle_split = ShuffleSplit(n_splits=100, test_size=.5)
mean_p = []
for train, _ in shuffle_split.split(fitting_set):
    mean_p.append(fitting_set[train].mean())

# Average the per-split means once the loop is done
shuf_estimate = np.mean(mean_p)

f, ax = plt.subplots(figsize=(7, 5))
ax.vlines(true_loc, 0, 1, color='r', linestyles='-', lw=5, alpha=.65, label='true mean')
ax.vlines(estimate, 0, 1, color='g', linestyles='-', lw=5, alpha=.65, label='regular estimate')
ax.vlines(shuf_estimate, 0, 1, color='b', linestyles='-', lw=5, alpha=.65, label='shufflesplit estimate')
ax.set_title("All Estimates")
ax.set_xlim(999, 1001)
ax.legend(loc=3)

The output will be as follows:

As we can see, we got an estimate that was similar to what we expected, but we were able to take many samples to get that estimate.

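The point generalizes: averaging many per-split means gives an estimate with a much smaller spread than any single split's. A self-contained sketch of this idea (the seed, split count, and test size here are illustrative assumptions):

```python
import numpy as np
from sklearn.model_selection import ShuffleSplit

rng = np.random.RandomState(0)
data = rng.normal(1000, 10, 500)

# Each split's train set covers 90% of the data; averaging the per-split
# means pulls the combined estimate toward the full-sample mean
ss = ShuffleSplit(n_splits=100, test_size=0.1, random_state=0)
per_split_means = [data[train].mean() for train, _ in ss.split(data)]
shuffle_estimate = np.mean(per_split_means)
print(abs(shuffle_estimate - data.mean()) < 0.5)  # True
```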

This article is a translation of an external work.

If there is any infringement, please contact cloudcommunity@tencent.com for removal.
