我有一个多输出(200)二进制分类模型,它是我用keras编写的。
在这个模型中,我想添加额外的指标,比如ROC和AUC,但在我的知识内核中没有内置的ROC和AUC指标函数。
我试着从scikit-learn导入ROC,AUC函数
from sklearn.metrics import roc_curve, auc
from keras.models import Sequential
from keras.layers import Dense
.
.
.
model.add(Dense(200, activation='relu'))
model.add(Dense(300, activation='relu'))
model.add(Dense(400, activation='relu'))
model.add(Dense(300, activation='relu'))
model.add(Dense(200,init='normal', activation='softmax')) #outputlayer
model.compile(loss='categorical_crossentropy', optimizer='adam',metrics=['accuracy','roc_curve','auc'])
但是它给出了这个错误:
Exception: Invalid metric: roc_curve
如何将ROC,AUC添加到keras中?
发布于 2018-08-08 04:33:59
和你一样,我更喜欢使用scikit-learn的内置方法来评估AUROC。我发现在keras中实现这一点的最好、最简单的方法是创建一个自定义指标。如果tensorflow是您的后端,则只需几行代码即可实现:
import tensorflow as tf
from sklearn.metrics import roc_auc_score
def auroc(y_true, y_pred):
return tf.py_func(roc_auc_score, (y_true, y_pred), tf.double)
# Build Model...
model.compile(loss='categorical_crossentropy', optimizer='adam',metrics=['accuracy', auroc])
创建其他答案中提到的自定义回调不会对您的情况起作用,因为您的模型有多个输出,但这会起作用。此外,这种方法允许在训练和验证数据上评估指标,而keras回调不能访问训练数据,因此只能用于评估训练数据的性能。
发布于 2018-07-20 14:55:50
下面的解决方案对我很有效:
import tensorflow as tf
from keras import backend as K
def auc(y_true, y_pred):
auc = tf.metrics.auc(y_true, y_pred)[1]
K.get_session().run(tf.local_variables_initializer())
return auc
model.compile(loss="binary_crossentropy", optimizer='adam', metrics=[auc])
发布于 2020-01-09 03:42:03
您可以通过以下方式提供指标来监控训练期间的auc:
METRICS = [
keras.metrics.TruePositives(name='tp'),
keras.metrics.FalsePositives(name='fp'),
keras.metrics.TrueNegatives(name='tn'),
keras.metrics.FalseNegatives(name='fn'),
keras.metrics.BinaryAccuracy(name='accuracy'),
keras.metrics.Precision(name='precision'),
keras.metrics.Recall(name='recall'),
keras.metrics.AUC(name='auc'),
]
model = keras.Sequential([
keras.layers.Dense(16, activation='relu', input_shape=(train_features.shape[-1],)),
keras.layers.Dense(1, activation='sigmoid'),
])
model.compile(
optimizer=keras.optimizers.Adam(lr=1e-3)
loss=keras.losses.BinaryCrossentropy(),
metrics=METRICS)
有关更详细的教程,请参阅:
https://www.tensorflow.org/tutorials/structured_data/imbalanced_data
https://stackoverflow.com/questions/41032551
复制相似问题