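Earlier posts covered loss functions for two-class (binary) image segmentation. This post covers the loss functions commonly used for multi-class image segmentation: multi-class cross entropy, weighted multi-class cross entropy, multi-class Dice, and multi-class focal loss, together with TensorFlow code that reproduces each of them.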
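1. Cross Entropy
The cross entropy loss compares the predicted class value with the target value at every pixel and then averages over all pixels. The formula is given below, where p is the true class value and p' is the predicted class probability.
$$L_{CE} = -\sum_{c} p_c \log p'_c$$
This loss gives every class the same weight, so it is highly sensitive to class imbalance.
The implementation is as follows: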
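import tensorflow as tf


def categorical_crossentropy(Y_pred, Y_gt):
    """
    Categorical crossentropy between an output and a target
    loss=-y*log(y')
    :param Y_pred: A tensor resulting from a softmax, shape (N, z, h, w, numclass)
    :param Y_gt: A one-hot tensor of the same shape as `Y_pred`
    :return: categorical_crossentropy loss
    """
    epsilon = 1.e-5
    # scale preds so that the class probas of each sample sum to 1
    # (keepdims replaces the deprecated keep_dims argument)
    output = Y_pred / tf.reduce_sum(Y_pred, axis=-1, keepdims=True)
    # clip to avoid log(0), then compute the crossentropy manually
    output = tf.clip_by_value(output, epsilon, 1. - epsilon)
    # tf.math.log replaces tf.log, which is not available in TensorFlow 2.x
    loss = -Y_gt * tf.math.log(output)
    # sum over the spatial axes (z, h, w) -> shape (N, numclass);
    # use tf.reduce_mean here instead for a per-pixel average
    loss = tf.reduce_sum(loss, axis=(1, 2, 3))
    # average over the batch -> shape (numclass,)
    loss = tf.reduce_mean(loss, axis=0)
    # average over the classes -> scalar
    loss = tf.reduce_mean(loss)
    return loss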
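To make the class-imbalance sensitivity concrete, here is a minimal pure-Python sketch (no TensorFlow needed). The helper `mean_cross_entropy` and the toy data are assumptions made up for this illustration, not part of the project code. It scores a model that predicts "background" everywhere on an imbalanced image and on a balanced one.

```python
import math


def mean_cross_entropy(preds, targets, eps=1e-5):
    """Mean over pixels of -sum_c target[c] * log(pred[c])."""
    total = 0.0
    for pred, target in zip(preds, targets):
        for p, t in zip(pred, target):
            p = min(max(p, eps), 1.0 - eps)  # clip to avoid log(0)
            total += -t * math.log(p)
    return total / len(preds)


# Hypothetical model: predicts background (class 0) with probability 0.9 at every pixel.
always_background = [0.9, 0.1]

# Imbalanced image: 99 background pixels, 1 foreground pixel.
imbalanced_targets = [[1, 0]] * 99 + [[0, 1]] * 1
imbalanced_loss = mean_cross_entropy([always_background] * 100, imbalanced_targets)

# Balanced image: 50 background pixels, 50 foreground pixels.
balanced_targets = [[1, 0]] * 50 + [[0, 1]] * 50
balanced_loss = mean_cross_entropy([always_background] * 100, balanced_targets)

print(f"imbalanced: {imbalanced_loss:.4f}, balanced: {balanced_loss:.4f}")
# prints: imbalanced: 0.1273, balanced: 1.2040
```

The model misses the foreground completely in both cases, yet on the imbalanced image its loss is roughly ten times lower, so plain cross entropy gives little pressure to learn the rare class.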
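2. Weighted Cross Entropy
The weighted cross entropy loss adds a weight factor for each class to the cross entropy loss, which helps mitigate the effect of class imbalance. The formula is given below, where p is the true class value, p' is the predicted class probability, and W is a 1-D vector whose length equals the number of classes.
$$L_{WCE} = -\sum_{c} W_c \, p_c \log p'_c$$
The implementation is as follows: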
import numpy as np


def weighted_categorical_crossentropy(Y_pred, Y_gt, weights):
    """
    weighted_categorical_crossentropy between an output and a target
    loss=-weight*y*log(y')
    :param Y_pred: A tensor resulting from a softmax, shape (N, z, h, w, numclass)
    :param Y_gt: A one-hot tensor of the same shape as `Y_pred`
    :param weights: numpy array of shape (C,) where C is the number of classes
    :return: weighted categorical_crossentropy loss
    Usage:
    weights = np.array([0.5,2,10]) # Class one at 0.5, class 2 twice the normal weights, class 3 10x.
    """
    # float32 so that the weights match the dtype of the loss tensor
    weights = np.asarray(weights, dtype=np.float32)
    epsilon = 1.e-5
    # scale preds so that the class probas of each sample sum to 1
    output = Y_pred / tf.reduce_sum(Y_pred, axis=-1, keepdims=True)
    # clip to avoid log(0), then compute the crossentropy manually
    output = tf.clip_by_value(output, epsilon, 1. - epsilon)
    loss = -Y_gt * tf.math.log(output)
    # sum over the spatial axes (z, h, w) -> shape (N, numclass)
    loss = tf.reduce_sum(loss, axis=(1, 2, 3))
    # average over the batch -> shape (numclass,)
    loss = tf.reduce_mean(loss, axis=0)
    # apply the per-class weights, then average over the classes
    loss = tf.reduce_mean(weights * loss)
    return loss
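The following sketch shows how per-class weights shift the balance between classes. It is pure Python with made-up toy data. The inverse-frequency rule in `inverse_frequency_weights` is only one common heuristic chosen for this illustration; the post itself does not prescribe how to pick the weights.

```python
import math


def per_class_cross_entropy(preds, targets, num_classes, eps=1e-5):
    """Cross entropy split by class and averaged over pixels: one value per class."""
    losses = [0.0] * num_classes
    for pred, target in zip(preds, targets):
        for c in range(num_classes):
            p = min(max(pred[c], eps), 1.0 - eps)  # clip to avoid log(0)
            losses[c] += -target[c] * math.log(p)
    return [loss / len(preds) for loss in losses]


def inverse_frequency_weights(targets, num_classes):
    """Hypothetical heuristic: weight_c = N / (C * count_c).

    Assumes every class appears at least once in `targets`.
    """
    counts = [sum(t[c] for t in targets) for c in range(num_classes)]
    return [len(targets) / (num_classes * count) for count in counts]


# Toy data: 99 background pixels, 1 foreground pixel, background predicted everywhere.
targets = [[1, 0]] * 99 + [[0, 1]] * 1
preds = [[0.9, 0.1]] * 100

class_losses = per_class_cross_entropy(preds, targets, 2)
weights = inverse_frequency_weights(targets, 2)
weighted_losses = [w * l for w, l in zip(weights, class_losses)]

print("unweighted per-class loss:", class_losses)
print("weights:", weights)
print("weighted per-class loss:", weighted_losses)
```

Without weights the background term is the larger of the two even though the foreground pixel is the one that is wrong; with the weights applied, the foreground term dominates.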
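3. Dice loss
Dice loss was used in the V-Net model. The anatomical structures of interest usually occupy a relatively small part of the image, so putting more weight on the foreground region reduces the effect of class imbalance. The Dice coefficient that the loss is built on is given below, where TP, FP, and FN are the numbers of true positives, false positives, and false negatives.
$$Dice = \frac{2TP}{2TP + FP + FN}$$
The implementation is as follows: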
def categorical_dice(Y_pred, Y_gt, weight_loss):
    """
    multi label dice loss with weighted
    loss=-mean(w*dice), dice=(2*sum(r*p)+smooth)/(sum(r+p)+smooth), w=array of shape (C,)
    :param Y_pred: [None, self.image_depth, self.image_height, self.image_width,
                    self.numclass],Y_pred is softmax result
    :param Y_gt: [None, self.image_depth, self.image_height, self.image_width,
                  self.numclass],Y_gt is one hot result
    :param weight_loss: numpy array of shape (C,) where C is the number of classes
    :return: negative weighted mean dice coefficient (minimizing it maximizes dice)
    """
    weight_loss = np.asarray(weight_loss, dtype=np.float32)
    smooth = 1.e-5
    smooth_tf = tf.constant(smooth, tf.float32)
    Y_pred = tf.cast(Y_pred, tf.float32)
    Y_gt = tf.cast(Y_gt, tf.float32)
    # Compute gen dice coef:
    # per-class overlap, summed over the spatial axes -> shape (N, numclass)
    numerator = Y_gt * Y_pred
    numerator = tf.reduce_sum(numerator, axis=(1, 2, 3))
    denominator = Y_gt + Y_pred
    denominator = tf.reduce_sum(denominator, axis=(1, 2, 3))
    # smooth is added outside the factor 2 so that an empty class scores 1, not 2;
    # averaging over the batch gives shape (numclass,)
    gen_dice_coef = tf.reduce_mean((2. * numerator + smooth_tf) / (denominator + smooth_tf), axis=0)
    loss = -tf.reduce_mean(weight_loss * gen_dice_coef)
    return loss
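As a small worked check, the sketch below computes Dice for one class twice: once from TP/FP/FN counts and once from the overlap form `2 * sum(g * p) / (sum(g) + sum(p))` that the TensorFlow code uses. The masks and the helper names are invented for this illustration, and the placement of `smooth` is one common convention rather than the only option.

```python
def dice_from_counts(tp, fp, fn):
    """Dice coefficient from confusion counts: 2TP / (2TP + FP + FN)."""
    return 2.0 * tp / (2.0 * tp + fp + fn)


def soft_dice(pred, gt, smooth=1e-5):
    """Overlap form of Dice; also works for soft (probability) predictions."""
    intersection = sum(p * g for p, g in zip(pred, gt))
    return (2.0 * intersection + smooth) / (sum(pred) + sum(gt) + smooth)


# Toy binary masks for a single class (1 = foreground).
gt = [1, 1, 1, 0, 0, 0, 0, 0]
pred = [1, 1, 0, 1, 0, 0, 0, 0]

tp = sum(1 for p, g in zip(pred, gt) if p == 1 and g == 1)
fp = sum(1 for p, g in zip(pred, gt) if p == 1 and g == 0)
fn = sum(1 for p, g in zip(pred, gt) if p == 0 and g == 1)

count_dice = dice_from_counts(tp, fp, fn)
overlap_dice = soft_dice(pred, gt)

# Padding both masks with 1000 correctly predicted background pixels:
# true negatives do not appear in the formula, so the score does not move.
padded_dice = soft_dice(pred + [0] * 1000, gt + [0] * 1000)

print(count_dice, overlap_dice, padded_dice)
```

Both forms agree up to the smoothing term, and the score is unchanged by extra background, which is why Dice is less affected by a large background than a per-pixel average.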
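4. Focal loss
Focal loss is an improvement on cross entropy: it reduces the weight of the loss from easy samples so that the network concentrates on the loss from hard samples. The formula is given below, where p is the true class value, p' is the predicted class probability, a is the class weight (`alpha` in the code), and r is the factor that controls how strongly easy samples are down-weighted (`gamma` in the code).
$$L_{FL} = -\sum_{c} a_c \, (1 - p'_c)^{r} \, p_c \log p'_c$$
The implementation is as follows: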
def categorical_focal_loss(Y_pred, Y_gt, gamma, alpha):
    """
    Categorical focal_loss between an output and a target
    :param Y_pred: A tensor resulting from a softmax(-1,z,h,w,numclass)
    :param Y_gt: A one-hot tensor of the same shape as `Y_pred`
    :param alpha: Sample category weight,which is shape (C,) where C is the number of classes
    :param gamma: Difficult sample weight
    :return: categorical focal loss
    """
    weight_loss = np.asarray(alpha, dtype=np.float32)
    epsilon = 1.e-5
    # Scale predictions so that the class probas of each sample sum to 1
    output = Y_pred / tf.reduce_sum(Y_pred, axis=-1, keepdims=True)
    # Clip the prediction value to prevent NaN's and Inf's
    output = tf.clip_by_value(output, epsilon, 1. - epsilon)
    # Calculate Cross Entropy (tf.math.log replaces tf.log, which is not available in TensorFlow 2.x)
    cross_entropy = -Y_gt * tf.math.log(output)
    # Calculate Focal Loss: down-weight confident (easy) predictions by (1 - p')^gamma
    loss = tf.pow(1 - output, gamma) * cross_entropy
    # Sum over the spatial axes (z, h, w) -> shape (N, numclass)
    loss = tf.reduce_sum(loss, axis=(1, 2, 3))
    # Average over the batch -> shape (numclass,)
    loss = tf.reduce_mean(loss, axis=0)
    # Apply the per-class weights, then average over the classes
    loss = tf.reduce_mean(weight_loss * loss)
    return loss
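To see what the factor `(1 - p')^gamma` does, this short sketch compares the cross entropy term and the focal term for one easy and one hard pixel. The probabilities 0.9 and 0.1 and the helper names are illustrative assumptions; `gamma=2` is simply the value used in the example.

```python
import math


def cross_entropy_term(p):
    """Cross entropy for a pixel whose true class gets probability p."""
    return -math.log(p)


def focal_term(p, gamma=2.0):
    """Focal term for the same pixel: (1 - p)^gamma * cross entropy."""
    return (1.0 - p) ** gamma * cross_entropy_term(p)


easy_p = 0.9  # true class predicted with high confidence
hard_p = 0.1  # true class predicted with low confidence

for name, p in (("easy", easy_p), ("hard", hard_p)):
    print(f"{name}: ce={cross_entropy_term(p):.4f}, focal={focal_term(p):.4f}")

# How much more the hard pixel counts than the easy one under each loss.
ce_ratio = cross_entropy_term(hard_p) / cross_entropy_term(easy_p)
focal_ratio = focal_term(hard_p) / focal_term(easy_p)
print(f"hard/easy ratio: ce={ce_ratio:.1f}, focal={focal_ratio:.1f}")
```

With `gamma=2` the easy pixel's loss is scaled by 0.01 and the hard pixel's by 0.81, so the hard pixel's share of the total grows by a factor of 81. Setting `gamma=0` recovers plain cross entropy.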
5. Cross Entropy + Dice loss
Some papers combine different loss functions to train the network. This section reproduces a Cross Entropy + Dice loss; the implementation is as follows:
def categorical_dicePcrossentroy(Y_pred, Y_gt, weight, lamda=0.5):
    """
    hybrid loss function from dice loss and crossentropy
    loss=(1-lamda)*Ldice+lamda*Lcrossentropy
    :param Y_pred: A tensor resulting from a softmax(-1,z,h,w,numclass)
    :param Y_gt: A one-hot tensor of the same shape as `Y_pred`
    :param weight: Sample category weight,which is shape (C,) where C is the number of classes
    :param lamda: trade-off between dice loss and crossentropy,can set 0.1,0.5,1
    :return: dice plus crossentropy loss
    """
    weight_loss = np.asarray(weight, dtype=np.float32)
    smooth = 1.e-5
    smooth_tf = tf.constant(smooth, tf.float32)
    Y_pred = tf.cast(Y_pred, tf.float32)
    Y_gt = tf.cast(Y_gt, tf.float32)
    # Compute gen dice coef:
    numerator = Y_gt * Y_pred
    numerator = tf.reduce_sum(numerator, axis=(1, 2, 3))
    denominator = Y_gt + Y_pred
    denominator = tf.reduce_sum(denominator, axis=(1, 2, 3))
    # average over the batch (a sum here would make 1 - loss1 depend on the batch size)
    gen_dice_coef = tf.reduce_mean((2. * numerator + smooth_tf) / (denominator + smooth_tf), axis=0)
    loss1 = tf.reduce_mean(weight_loss * gen_dice_coef)
    epsilon = 1.e-5
    # scale preds so that the class probas of each sample sum to 1
    output = Y_pred / tf.reduce_sum(Y_pred, axis=-1, keepdims=True)
    # clip to avoid log(0), then compute the crossentropy manually
    output = tf.clip_by_value(output, epsilon, 1. - epsilon)
    loss = -Y_gt * tf.math.log(output)
    # average over the spatial axes, then over the batch -> shape (numclass,)
    loss = tf.reduce_mean(loss, axis=(1, 2, 3))
    loss = tf.reduce_mean(loss, axis=0)
    loss2 = tf.reduce_mean(weight_loss * loss)
    # loss1 is a dice coefficient, so 1 - loss1 is the dice loss
    total_loss = (1 - lamda) * (1 - loss1) + lamda * loss2
    return total_loss
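The last line of the function blends the two terms with `lamda`. The sketch below isolates that step on plain numbers; `hybrid_loss` and the values 0.8 and 0.3 are hypothetical and chosen only to show how the trade-off behaves.

```python
def hybrid_loss(dice_coef, ce_loss, lamda=0.5):
    """Blend a Dice term and a cross entropy term.

    dice_coef is a Dice coefficient (higher is better), so 1 - dice_coef
    is the part that gets minimized.
    """
    return (1.0 - lamda) * (1.0 - dice_coef) + lamda * ce_loss


# Hypothetical values for one training step.
dice_coef = 0.8
ce_loss = 0.3

blended = {lamda: hybrid_loss(dice_coef, ce_loss, lamda) for lamda in (0.0, 0.1, 0.5, 1.0)}
for lamda, value in blended.items():
    print(f"lamda={lamda}: total={value:.4f}")
```

At `lamda=0` the result is the pure Dice loss and at `lamda=1` it is the pure cross entropy; values in between interpolate linearly.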
6. Dice loss + Focal loss
The paper "AnatomyNet: Deep Learning for Fast and Fully Automated Whole-volume Segmentation of Head and Neck Anatomy", published by Tencent Medical AI Lab, proposes Dice loss + Focal loss to handle the segmentation of small organs. The implementation is as follows:
def categorical_dicePfocalloss(Y_pred, Y_gt, alpha, lamda=0.5, gamma=2.):
    """
    hybrid loss function from dice loss and focalloss
    loss=(1-lamda)*Ldice+lamda*Lfocalloss
    :param Y_pred: A tensor resulting from a softmax(-1,z,h,w,numclass)
    :param Y_gt: A one-hot tensor of the same shape as `Y_pred`
    :param gamma: Difficult sample weight
    :param alpha: Sample category weight,which is shape (C,) where C is the number of classes
    :param lamda: trade-off between dice loss and focal loss,can set 0.1,0.5,1
    :return: dicePfocalloss
    """
    weight_loss = np.asarray(alpha, dtype=np.float32)
    smooth = 1.e-5
    smooth_tf = tf.constant(smooth, tf.float32)
    Y_pred = tf.cast(Y_pred, tf.float32)
    Y_gt = tf.cast(Y_gt, tf.float32)
    # Compute gen dice coef:
    numerator = Y_gt * Y_pred
    numerator = tf.reduce_sum(numerator, axis=(1, 2, 3))
    denominator = Y_gt + Y_pred
    denominator = tf.reduce_sum(denominator, axis=(1, 2, 3))
    # Average over the batch (a sum here would make 1 - loss1 depend on the batch size)
    gen_dice_coef = tf.reduce_mean((2. * numerator + smooth_tf) / (denominator + smooth_tf), axis=0)
    loss1 = tf.reduce_mean(weight_loss * gen_dice_coef)
    epsilon = 1.e-5
    # Scale predictions so that the class probas of each sample sum to 1
    output = Y_pred / tf.reduce_sum(Y_pred, axis=-1, keepdims=True)
    # Clip the prediction value to prevent NaN's and Inf's
    output = tf.clip_by_value(output, epsilon, 1. - epsilon)
    # Calculate Cross Entropy
    cross_entropy = -Y_gt * tf.math.log(output)
    # Calculate Focal Loss
    loss = tf.pow(1 - output, gamma) * cross_entropy
    # Average over the spatial axes, then over the batch -> shape (numclass,)
    loss = tf.reduce_mean(loss, axis=(1, 2, 3))
    loss = tf.reduce_mean(loss, axis=0)
    loss2 = tf.reduce_mean(weight_loss * loss)
    # loss1 is a dice coefficient, so 1 - loss1 is the dice loss
    total_loss = (1 - lamda) * (1 - loss1) + lamda * loss2
    return total_loss
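For readers who want to trace the arithmetic without TensorFlow, here is a pure-Python sketch of the same idea on a flat list of voxels. It follows the structure of the function (per-class Dice, per-class focal term, `alpha` weighting, `lamda` blending) but it is a simplified stand-in with invented toy data, not the project's implementation.

```python
import math


def per_class_dice(preds, targets, num_classes, smooth=1e-5):
    """Soft Dice coefficient per class over a flat list of voxels."""
    dice = []
    for c in range(num_classes):
        inter = sum(p[c] * t[c] for p, t in zip(preds, targets))
        total = sum(p[c] + t[c] for p, t in zip(preds, targets))
        dice.append((2.0 * inter + smooth) / (total + smooth))
    return dice


def per_class_focal(preds, targets, num_classes, gamma=2.0, eps=1e-5):
    """Mean focal term per class: -(1 - p)^gamma * t * log(p)."""
    focal = []
    for c in range(num_classes):
        s = 0.0
        for p, t in zip(preds, targets):
            pc = min(max(p[c], eps), 1.0 - eps)  # clip to avoid log(0)
            s += -t[c] * (1.0 - pc) ** gamma * math.log(pc)
        focal.append(s / len(preds))
    return focal


def dice_plus_focal(preds, targets, alpha, lamda=0.5, gamma=2.0):
    """(1 - lamda) * (1 - weighted mean Dice) + lamda * weighted mean focal."""
    k = len(alpha)
    dice = per_class_dice(preds, targets, k)
    focal = per_class_focal(preds, targets, k, gamma)
    dice_term = sum(a * d for a, d in zip(alpha, dice)) / k
    focal_term = sum(a * f for a, f in zip(alpha, focal)) / k
    return (1.0 - lamda) * (1.0 - dice_term) + lamda * focal_term


# Toy volume: 6 background voxels and a small organ of 2 voxels.
targets = [[1, 0]] * 6 + [[0, 1]] * 2
missed = [[0.9, 0.1]] * 8                      # organ predicted as background
found = [[0.9, 0.1]] * 6 + [[0.1, 0.9]] * 2    # organ predicted correctly
alpha = [1.0, 1.0]

loss_missed = dice_plus_focal(missed, targets, alpha)
loss_found = dice_plus_focal(found, targets, alpha)
print(f"missed organ: {loss_missed:.4f}, found organ: {loss_found:.4f}")
```

Missing the small organ lowers its Dice score sharply and leaves two hard voxels with a large focal term, so the combined loss is clearly higher than when the organ is found.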
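To make this easier to study, I have shared the full project code on GitHub:
https://github.com/junqiangchen/Image-Segmentation-Loss-Functions
If you find the project useful, please give it a Star and Fork it so that more people can learn from it. If you run into any problems, leave a comment and I will do my best to answer.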