专栏首页cnblogs-技术博客tensorflow2.0 评估函数

tensorflow2.0 评估函数

一,常用的内置评估指标

  • MeanSquaredError(平方差误差,用于回归,可以简写为MSE,函数形式为mse)
  • MeanAbsoluteError (绝对值误差,用于回归,可以简写为MAE,函数形式为mae)
  • MeanAbsolutePercentageError (平均百分比误差,用于回归,可以简写为MAPE,函数形式为mape)
  • RootMeanSquaredError (均方根误差,用于回归)
  • Accuracy (准确率,用于分类,可以用字符串"Accuracy"表示,Accuracy=(TP+TN)/(TP+TN+FP+FN),要求y_true和y_pred都为类别序号编码)
  • Precision (精确率,用于二分类,Precision = TP/(TP+FP))
  • Recall (召回率,用于二分类,Recall = TP/(TP+FN))
  • TruePositives (真正例,用于二分类)
  • TrueNegatives (真负例,用于二分类)
  • FalsePositives (假正例,用于二分类)
  • FalseNegatives (假负例,用于二分类)
  • AUC(ROC曲线(TPR vs FPR)下的面积,用于二分类,直观解释为随机抽取一个正样本和一个负样本,正样本的预测值大于负样本的概率)
  • CategoricalAccuracy(分类准确率,与Accuracy含义相同,要求y_true(label)为onehot编码形式)
  • SparseCategoricalAccuracy (稀疏分类准确率,与Accuracy含义相同,要求y_true(label)为序号编码形式)
  • MeanIoU (Intersection-Over-Union,常用于图像分割)
  • TopKCategoricalAccuracy (多分类TopK准确率,要求y_true(label)为onehot编码形式)
  • SparseTopKCategoricalAccuracy (稀疏多分类TopK准确率,要求y_true(label)为序号编码形式)
  • Mean (平均值)
  • Sum (求和)
  • https://tensorflow.google.cn/api_docs/python/tf/keras/metrics

二,自定义品函数及使用

import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow.keras import layers,models,losses,metrics
 
# 函数形式的自定义评估指标
@tf.function
def ks(y_true,y_pred):
    y_true = tf.reshape(y_true,(-1,))
    y_pred = tf.reshape(y_pred,(-1,))
    length = tf.shape(y_true)[0]
    t = tf.math.top_k(y_pred,k = length,sorted = False)
    y_pred_sorted = tf.gather(y_pred,t.indices)
    y_true_sorted = tf.gather(y_true,t.indices)
    cum_positive_ratio = tf.truediv(
        tf.cumsum(y_true_sorted),tf.reduce_sum(y_true_sorted))
    cum_negative_ratio = tf.truediv(
        tf.cumsum(1 - y_true_sorted),tf.reduce_sum(1 - y_true_sorted))
    ks_value = tf.reduce_max(tf.abs(cum_positive_ratio - cum_negative_ratio)) 
    return ks_value
y_true = tf.constant([[1],[1],[1],[0],[1],[1],[1],[0],[0],[0],[1],[0],[1],[0]])
y_pred = tf.constant([[0.6],[0.1],[0.4],[0.5],[0.7],[0.7],[0.7],
                      [0.4],[0.4],[0.5],[0.8],[0.3],[0.5],[0.3]])
tf.print(ks(y_true,y_pred))
model.compile(
    loss="categorical_crossentropy",
    optimizer=keras.optimizers.Adam(lr=0.001),
    metrics=[keras.metrics.MeanIoU(num_classes=2),ks]
    
)

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

我来说两句

0 条评论
登录 后参与评论

相关文章

  • jdbc

    jdbc.driver=com.mysql.jdbc.Driver jdbc.url=jdbc:mysql://localhost:3306/dangdang ...

    Dean0731
  • 深度学习_1_神经网络_3_验证码识别

    ​ N ------>[0.01,0.02,0.03.......] 概率 N------->[0,0,0,0,1.......] one-hot编码

    Dean0731
  • 深度学习_1_神经网络_2_深度神经网络

    ​ sigmod f(x)=1/(1+e^(-x)) 计算量大,反向传播时 容易出现梯度爆炸

    Dean0731
  • Github 项目 - YOLOV3 的 TensorFlow 复现

    原文:Github 项目 - YOLOV3 的 TensorFlow 复现 - AIUAI

    AIHGF
  • supervisor监控业务程序(2)

    不能用shutdown.sh 和startup.sh来进行启动控制,需要使用catalina.sh run这种方式来进行启动,配置完成后重启即可

    dogfei
  • 微信启动画面的是怎么拍出来的?

    微信的启动画面:一个站在巨大星球下的孤独小人的背影,深深传递着与人沟通的渴望。画面地球原图为阿波罗 17 号太空船船员所拍摄的著名地球照片《蓝色弹珠》,不是在月...

    智能算法
  • 听说这款 IonVR 头盔和 iPhone 更配哦!

    镁客网
  • 神经网络训练细节part1(上)

    听城
  • 王鹏威:投资人愿意为什么内容埋单

    2016年1月18日,第19期互联网前沿沙龙上,国润投资基金董事长兼主管合伙人王鹏威介绍了国润投资基金投资媒体的策略。 ?   王鹏威提出,新媒体投资主要...

    腾讯研究院
  • Log4j官方文档翻译(五、日志输出的方法)

    日志类提供了很多方法用于处理日志活动,它不允许我们自己实例化一个logger,但是提供给我们两种静态方法获得logger对象: public static Lo...

    用户1154259

扫码关注云+社区

领取腾讯云代金券