
sklearn train_test_split() stratify does not work with 2D labels

Stack Overflow user
Asked on 2022-03-06 09:52:25
Answers: 1 | Views: 287 | Followers: 0 | Score: 0

I am training a seq2seq model with sparse 2D labels: a class index is defined separately for each time step. It is a multi-class, single-label task (softmax).
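
For illustration only, here is a minimal sketch of the two label layouts using made-up toy data (the window count, step count and class count are hypothetical). Sparse 2D labels hold one integer class index per window and time step, while the one-hot form adds a class axis, which is what a per-step softmax output is compared against.

```python
import numpy as np

# Hypothetical toy data: 2 windows, 4 time steps, 3 classes.
n_classes = 3
sparse_labels = np.array([[0, 0, 2, 1],
                          [1, 1, 1, 0]])  # shape (2, 4): one class index per step

# One-hot equivalent, shape (2, 4, 3): a 0/1 target vector per step.
one_hot = np.eye(n_classes, dtype=int)[sparse_labels]

# argmax over the last (class) axis recovers the sparse indices.
recovered = one_hot.argmax(axis=-1)
print(sparse_labels.shape, one_hot.shape)  # (2, 4) (2, 4, 3)
```

Note that argmax must run over the class axis; on sparse labels there is no class axis to take it over.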

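Stratifying the split would be very useful here: it would keep the labels in each newly created split balanced, matching the label distribution of the original dataset.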
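
As a point of reference, here is a minimal sketch (with hypothetical toy data) of what `stratify` does when it receives a 1D label array: both splits keep the class proportions of the full dataset.

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Hypothetical 1D labels: 80 samples of class 0, 20 samples of class 1.
y = np.array([0] * 80 + [1] * 20)
X = np.arange(len(y)).reshape(-1, 1)

X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=0, stratify=y
)

# With stratify=y, both splits keep the 80/20 class ratio.
print(np.bincount(y_train).tolist(), np.bincount(y_test).tolist())  # [64, 16] [16, 4]
```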

import numpy as np
from sklearn.model_selection import train_test_split

# load dataset
f = np.load('./new_dataset.npz')
signals = f['signals']
labels = f['labels']

# cut into 6 s windows (600 samples at 100 Hz), then downsample to 50 Hz
win = 600
# trim the tail so the length is a whole number of windows (reshape needs this)
n_keep = (signals.shape[0] // win) * win
signals = signals[:n_keep]
labels = labels[:n_keep]

signals = np.reshape(signals, (-1, win, signals.shape[-1]))
labels = np.reshape(labels, (-1, win))

# keep every 2nd sample: 600 -> 300 samples per window
signals = signals[:, ::2]
labels = labels[:, ::2]

print(f"signals: {signals.shape}")
print(f"labels: {labels.shape}")

# split to train-test
# labels is 2D (windows x time steps); this is the call that fails in the result below
X_train, X_test, y_train, y_test = train_test_split(
    signals, labels, test_size=0.15, random_state=9, stratify=labels
)
X_train, X_val, y_train, y_val = train_test_split(
    X_train, y_train, test_size=0.15, random_state=9, stratify=y_train
)
print(X_train.shape, y_train.shape)
print(X_test.shape, y_test.shape)
print(X_val.shape, y_val.shape)

Result:

signals: (41564, 300, 6)
labels: (41564, 300)
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
/var/folders/7v/fqqcktvs23qc8fwgftjpz_gh0000gn/T/ipykernel_15879/1199612105.py in <module>
     42 
     43 # split to train-test
---> 44 X_train, X_test, y_train, y_test = train_test_split(
     45     signals, labels, test_size=0.15, random_state=9, stratify=labels
     46 )

~/miniforge3/lib/python3.9/site-packages/sklearn/model_selection/_split.py in train_test_split(test_size, train_size, random_state, shuffle, stratify, *arrays)
   2439         cv = CVClass(test_size=n_test, train_size=n_train, random_state=random_state)
   2440 
-> 2441         train, test = next(cv.split(X=arrays[0], y=stratify))
   2442 
   2443     return list(

~/miniforge3/lib/python3.9/site-packages/sklearn/model_selection/_split.py in split(self, X, y, groups)
   1598         """
   1599         X, y, groups = indexable(X, y, groups)
-> 1600         for train, test in self._iter_indices(X, y, groups):
   1601             yield train, test
   1602 

~/miniforge3/lib/python3.9/site-packages/sklearn/model_selection/_split.py in _iter_indices(self, X, y, groups)
   1938         class_counts = np.bincount(y_indices)
   1939         if np.min(class_counts) < 2:
-> 1940             raise ValueError(
   1941                 "The least populated class in y has only 1"
   1942                 " member, which is too few. The minimum"

ValueError: The least populated class in y has only 1 member, which is too few. The minimum number of groups for any class cannot be less than 2.
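
The traceback ends in sklearn's per-class count check. A minimal sketch on random toy data (the sizes and seed are hypothetical) of why a 2D `stratify` trips it: in the sklearn versions this traceback comes from, `StratifiedShuffleSplit` turns each row of a 2D `y` into a single string key and treats every distinct row as one class. Full 300-step label sequences are almost all unique, so most of these "classes" have exactly one member.

```python
import numpy as np
from sklearn.model_selection import train_test_split

rng = np.random.default_rng(9)
# Hypothetical toy arrays shaped like the question's: 50 windows x 300 steps, 5 classes.
labels = rng.integers(0, 5, size=(50, 300))
signals = rng.normal(size=(50, 300, 6))

# Mimic sklearn's handling of a 2D `stratify`: each whole row becomes one "class" key.
row_keys = np.array([" ".join(row.astype(str)) for row in labels])
_, row_counts = np.unique(row_keys, return_counts=True)
print(len(row_counts), "distinct rows; smallest group:", row_counts.min())

# Consequently the question's split call fails on this toy data as well.
error_message = None
try:
    train_test_split(signals, labels, test_size=0.15, random_state=9, stratify=labels)
except ValueError as err:
    error_message = str(err)
print("ValueError:", error_message)
```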

1 Answer

Stack Overflow user

Answered on 2022-03-07 05:06:36

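The error says that one of your classes has only one member, while the minimum is 2. Consider removing that class or adding more samples to it: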

ValueError: The least populated class in y has only 1 member, which is too few. The minimum number of groups for any class cannot be less than 2.
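
If the goal is only to keep classes roughly balanced across splits, one possible workaround (an assumption, not something the answer proposes) is to stratify on a 1D key derived from each window while still splitting the full 2D labels. The key used here is the majority class per window, computed by a hypothetical `majority_class` helper; the toy data is random and stands in for the question's arrays.

```python
import numpy as np
from sklearn.model_selection import train_test_split

rng = np.random.default_rng(0)
n_classes = 5
# Hypothetical toy data shaped like the question's arrays (windows, steps[, channels]).
labels = rng.integers(0, n_classes, size=(200, 300))
signals = rng.normal(size=(200, 300, 6))

def majority_class(y_2d, n_classes):
    """Return the most frequent class index in each row (ties go to the lowest index)."""
    per_class = (y_2d[..., None] == np.arange(n_classes)).sum(axis=1)  # (windows, classes)
    return per_class.argmax(axis=1)

# One 1D stratification key per window; the 2D labels themselves are split unchanged.
strat_key = majority_class(labels, n_classes)

X_train, X_test, y_train, y_test = train_test_split(
    signals, labels, test_size=0.15, random_state=9, stratify=strat_key
)
print(X_train.shape, y_train.shape, X_test.shape, y_test.shape)
```

With real data, a key value that occurs in only one window would raise the same error, so rare keys may need to be merged into a shared bucket. The second (validation) split would also need its own key computed from `y_train`.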

Before reshaping the arrays, you can check how many samples each class has like this:

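import collections
import numpy as np
# labels hold sparse class indices, so count them directly (argmax assumes one-hot)
print(collections.Counter(labels.ravel().tolist()))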
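
A complementary check, assuming the reshaped 2D labels from the question: a 2D `stratify` groups samples by entire label rows, so counting distinct rows shows directly how many of these row "classes" have a single member. The toy array below is hypothetical; substitute the real `labels`.

```python
import numpy as np

# Hypothetical toy labels (4 windows x 3 steps); substitute the reshaped `labels`.
labels = np.array([[0, 0, 1],
                   [0, 0, 1],
                   [2, 2, 2],
                   [1, 0, 1]])

# Count each distinct label row, the unit a 2D `stratify` groups by.
rows, row_counts = np.unique(labels, axis=0, return_counts=True)
n_singletons = int((row_counts < 2).sum())
print(row_counts.tolist())                             # [2, 1, 1]
print(n_singletons, "row pattern(s) occur only once")  # 2 row pattern(s) occur only once
```

Any nonzero singleton count is enough to trigger the `ValueError` above.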
Score: 0
Original page content provided by Stack Overflow; translation provided by Tencent Cloud Xiaowei's IT-domain engine.
Original link:

https://stackoverflow.com/questions/71369221
