首页
学习
活动
专区
工具
TVP
发布
社区首页 >问答首页 >多标签计算类权重不可散列类型

多标签计算类权重不可散列类型
EN

Stack Overflow用户
提问于 2019-02-26 00:36:21
回答 2查看 3.2K关注 0票数 5

在我的神经网络中处理多标签分类问题,有13个可能的输出,使用Keras,sklearn等……

每个输出可以是像0,0,1,0,1,1,0,0,0,0,0,1,0这样的数组。

我有一个不平衡的数据集,并且我尝试应用compute_class_weight方法,例如:

class_weight = compute_class_weight('balanced', np.unique(Y_train), Y_train)

当我尝试运行我的代码时,我得到了不可散列类型:'numpy.ndarray':

Traceback (most recent call last):
  File "main.py", line 115, in <module>
    train(dataset, labels)
  File "main.py", line 66, in train
    class_weight = compute_class_weight('balanced', np.unique(Y_train), Y_train)
  File "/home/python-env/env/lib/python3.6/site-packages/sklearn/utils/class_weight.py", line 41, in compute_class_weight
if set(y) - set(classes):
  TypeError: unhashable type: 'numpy.ndarray'

我知道这是因为我使用数组,已经尝试添加了一些字典,

即:

class_weight_dict = dict(enumerate(np.unique(y_train), class_weight))

嗯,我不知道该怎么做,试过别人的策略,但没有成功...有什么想法吗?

提前感谢!

EN

回答 2

Stack Overflow用户

发布于 2020-12-23 16:32:55

我最近遇到了类似的问题,我正在分享我的思考过程。

如果你的“类不平衡”意味着一些标签组合比其他组合出现得更频繁,例如,有10个0,1,0,0,1,但只有1个0,1,0,0,0,你可以使用compute_sample_weight("balanced", Y_train)而不是compute_class_weight()。如果我是对的,这个函数会给训练数据集中的每个数据一个权重。返回的元组的长度是训练数据集的长度(即输入数据的数量)。此样本权重可以与X_train和y_train一起添加到您的训练集中,作为第三个参数。

如果你的“类不平衡”指的是预测标签中的负比正(大于0比1 ) --这种情况将在训练过程中给出不切实际的高准确度分数,我认为上面@Prateek给出的答案可以是一个解决方案,函数给出的权重是0和1。

有人在这里巧妙地构造了一段Multi-label classification with class weights in Keras代码来回答这个问题。

如果你的“类不平衡”是指某些类比其他类具有更多的外观,例如10个样本中有9个样本包含标签2,但只有1个样本包含标签3,我不知道如何使用class_weight或sample_weight来解决这个问题。也许你可以尝试硬编码并计算每个类的出现次数,然后通过以下公式计算每个类的权重:

# weight_of_class_1 = n_samples/n_class*n_freq_class_1

# n_sample: total number of data
# n_class: number of class
# n_freq_class_1: number of appearance of class 1 in all your labels.

这个公式在compute_class_weight中使用,但我不确定计算出的权重是否适合您的情况。

票数 3
EN

Stack Overflow用户

发布于 2019-05-14 03:28:50

这主要是因为你的Ytrain是一个2D数组而不是1D数组。尝试:

class_weights = class_weight.compute_class_weight('balanced',
                                                 np.unique(np.ravel(y_train,order='C')),
                                                 np.ravel(y_train,order='C'))
票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/54870799

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档