首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >问答首页 >类标签不显示sklearn.ensemble.RandomForestClassifier用于class_weight

类标签不显示sklearn.ensemble.RandomForestClassifier用于class_weight
EN

Stack Overflow用户
提问于 2019-06-26 19:14:28
回答 1查看 4.1K关注 0票数 0

我正在使用来自sklearn.ensemblesklearn.ensemble。当我不使用class_weight时,它可以工作,但是当我添加class_weight时,它会产生这个错误。

代码语言:javascript
复制
lr = RandomForestClassifier(n_estimators = 22,criterion =
                           'entropy',max_depth=5,class_weight={'Sex':2.})

lr.fit(X_train.values[:,1:],Y_train)

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-248-411a1c135d08> in <module>
      1 print(X_train)
----> 2 lr.fit(X_train.values[:,1:],Y_train)

/opt/conda/lib/python3.6/site-packages/sklearn/ensemble/forest.py in fit(self, X, y, sample_weight)
    273         self.n_outputs_ = y.shape[1]
    274 
--> 275         y, expanded_class_weight = self._validate_y_class_weight(y)
    276 
    277         if getattr(y, "dtype", None) != DOUBLE or not y.flags.contiguous:

/opt/conda/lib/python3.6/site-packages/sklearn/ensemble/forest.py in _validate_y_class_weight(self, y)
    519                     class_weight = self.class_weight
    520                 expanded_class_weight = compute_sample_weight(class_weight,
--> 521                                                               y_original)
    522 
    523         return y, expanded_class_weight

/opt/conda/lib/python3.6/site-packages/sklearn/utils/class_weight.py in compute_sample_weight(class_weight, y, indices)
    161             weight_k = compute_class_weight(class_weight_k,
    162                                             classes_full,
--> 163                                             y_full)
    164 
    165         weight_k = weight_k[np.searchsorted(classes_full, y_full)]

/opt/conda/lib/python3.6/site-packages/sklearn/utils/class_weight.py in compute_class_weight(class_weight, classes, y)
     63             i = np.searchsorted(classes, c)
     64             if i >= len(classes) or classes[i] != c:
---> 65                 raise ValueError("Class label {} not present.".format(c))
     66             else:
     67                 weight[i] = class_weight[c]

ValueError: Class label Sex not present. 

这是我的X_train:

PassengerId Pclass Sex ... Ticket Fare Embarked

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2019-06-26 20:15:01

您在Y_train中有多少类?

class_weight涉及Y_train,即标签。

示例:

class_weight={0:1,1:2}

意味着重量1到0级,重量2到1级。

使用class_weight={'Sex':2.}是错误的,它指的是X_train

票数 2
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/56779657

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档