首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >问答首页 >Tensorflow中的平衡精度分数

Tensorflow中的平衡精度分数
EN

Stack Overflow用户
提问于 2019-12-15 05:59:50
回答 4查看 3.5K关注 0票数 6

我正在为一个高度不平衡的分类问题实现一个CNN,我想在tensorflow中实现客户指标,以使用Select Best Model回调。具体地说,我想实现平衡的准确率分数,这是每个类的召回率的平均值(参见sklearn implementation here),有人知道如何做到这一点吗?

EN

回答 4

Stack Overflow用户

发布于 2020-01-28 14:59:43

我也面临着同样的问题,所以我实现了一个基于SparseCategoricalAccuracy的自定义类

代码语言:javascript
运行
复制
class BalancedSparseCategoricalAccuracy(keras.metrics.SparseCategoricalAccuracy):
    def __init__(self, name='balanced_sparse_categorical_accuracy', dtype=None):
        super().__init__(name, dtype=dtype)

    def update_state(self, y_true, y_pred, sample_weight=None):
        y_flat = y_true
        if y_true.shape.ndims == y_pred.shape.ndims:
            y_flat = tf.squeeze(y_flat, axis=[-1])
        y_true_int = tf.cast(y_flat, tf.int32)

        cls_counts = tf.math.bincount(y_true_int)
        cls_counts = tf.math.reciprocal_no_nan(tf.cast(cls_counts, self.dtype))
        weight = tf.gather(cls_counts, y_true_int)
        return super().update_state(y_true, y_pred, sample_weight=weight)

其思想是将每个类的权重设置为与其大小成反比。

这段代码从Autograph中产生了一些警告,但我认为这些都是Autograph的错误,而且指标似乎工作得很好。

票数 5
EN

Stack Overflow用户

发布于 2019-12-18 13:58:56

我可以想到3种方法来解决这种情况:

1)随机欠采样-在此方法中,您可以从多数类中随机删除样本。

2)随机过采样-在此方法中,您可以通过复制样本来增加样本。

3)加权交叉熵-您也可以使用加权交叉熵,以便可以补偿少数类的损失值。See here

我亲自尝试过method2,它确实显著提高了我的准确率,但不同的数据集可能会有所不同

票数 3
EN

Stack Overflow用户

发布于 2020-03-17 04:44:16

我还没有测试过这段代码,,但是看看tensorflow==2.1.0source code,这可能适用于二进制分类的情况:

代码语言:javascript
运行
复制
from tensorflow.keras.metrics import Recall
from tensorflow.python.ops import math_ops


class BalancedBinaryAccuracy(Recall):
    def result(self):
        result = (math_ops.div_no_nan(self.true_positives, self.true_positives + self.false_negatives) +
                  math_ops.div_no_nan(self.true_negatives, self.true_negatives + self.false_positives)) / 2
        return result[0] if len(self.thresholds) == 1 else result
票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/59339531

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档