常用的两类图像分割损失函数有二值交叉熵,dice系数,tversky,FocalLoss等。今天我将在TensorFlow下复现上述损失函数,并进行结果对比。
1、Cross Entropy
交叉熵损失函数是逐像素将类预测值与目标值进行比较,然后再对所有像素求平均值。公式如下所示,其中p是真实类别值,p’是预测属于类别1的概率值。
该函数对每个类别都有相同权重,所以极易受到类别不平衡的影响。
复现代码如下:
def binary_crossentropy(Y_pred, Y_gt):
epsilon = 1.e-5
Y_pred = tf.clip_by_value(Y_pred, epsilon, 1. - epsilon)
logits = tf.log(Y_pred / (1 - Y_pred))
loss = tf.nn.sigmoid_cross_entropy_with_logits(labels=Y_gt, logits=logits)
loss = tf.reduce_mean(loss)
return loss
2、Dice loss
Dice loss 是在V-net模型中使用的,一般感兴趣的解剖结构区域占据相对较小的区域,因此加大前景区域的权重,可减少类别不平衡的影响。公式如下所示,其中TP,FP,FN分别是真阳性、假阳性、假阴性的个数。
有的论文中Dice loss计算公式也可以是这样的,如下所示,其中p是真实类别值(0或1),p’是预测类别的概率值(0~1)。
复现代码如下:
def binary_dice(Y_pred, Y_gt):
smooth = 1.e-5
smooth_tf = tf.constant(smooth, tf.float32)
pred_flat = tf.cast(Y_pred, tf.float32)
true_flat = tf.cast(Y_gt, tf.float32)
Z, H, W, C = Y_gt.get_shape().as_list()[1:]
pred_flat = tf.reshape(pred_flat, [-1, H * W * C * Z])
true_flat = tf.reshape(true_flat, [-1, H * W * C * Z])
intersection = 2 * tf.reduce_sum(pred_flat * true_flat, axis=-1) + smooth_tf
denominator = tf.reduce_sum(pred_flat, axis=-1) + tf.reduce_sum(true_flat, axis=-1) + smooth_tf
loss = -tf.reduce_mean(intersection / denominator)
return loss
3、Tversky loss
Tversky loss是Dice loss的一般表达式,Tversky loss在假阳性、假阴性区域增加了权重因子。公式如下所示,其中p是真实类别值(0或1),p’是预测类别的概率值(0~1)。可以发现,当beta值为0.5时,Tversky loss就是Dice loss了。
复现代码如下:
def binary_tversky(Y_pred, Y_gt, beta):
smooth = 1.e-5
smooth_tf = tf.constant(smooth, tf.float32)
pred_flat = tf.cast(Y_pred, tf.float32)
true_flat = tf.cast(Y_gt, tf.float32)
Z, H, W, C = Y_gt.get_shape().as_list()[1:]
pred_flat = tf.reshape(pred_flat, [-1, H * W * C * Z])
true_flat = tf.reshape(true_flat, [-1, H * W * C * Z])
intersection = tf.reduce_sum(pred_flat * true_flat, axis=-1)
denominator = intersection + tf.reduce_sum(beta * pred_flat * (1 - true_flat), axis=-1) + tf.reduce_sum(
(1 - beta) * true_flat * (1 - pred_flat), axis=-1)
loss = -tf.reduce_mean((intersection + smooth_tf) / (denominator + smooth_tf))
return loss
4、Focal loss
Focal loss是对Cross Entropy函数的改进,该函数减少了简单样本损失权重,从而使网络更专注于困难样本损失。公式如下所示,其中p是真实类别值,p’是预测属于类别1的概率值。
复现代码如下:
def binary_focalloss(Y_pred, Y_gt, alpha=0.25, gamma=2.):
epsilon = 1.e-5
pt_1 = tf.where(tf.equal(Y_gt, 1), Y_pred, tf.ones_like(Y_pred))
pt_0 = tf.where(tf.equal(Y_gt, 0), Y_pred, tf.zeros_like(Y_pred))
# clip to prevent NaN's and Inf's
pt_1 = tf.clip_by_value(pt_1, epsilon, 1. - epsilon)
pt_0 = tf.clip_by_value(pt_0, epsilon, 1. - epsilon)
loss_1 = alpha * tf.pow(1. - pt_1, gamma) * tf.log(pt_1)
loss_0 = (1 - alpha) * tf.pow(pt_0, gamma) * tf.log(1. - pt_0)
loss = -tf.reduce_sum(loss_1 + loss_0)
loss = tf.reduce_mean(loss)
return loss
5、Cross Entropy+Dice loss
有些文章里结合不同的损失函数来训练网络,腾讯医疗AI实验室发表的论文《AnatomyNet: Deep Learning for Fast and Fully Automated Whole-volume Segmentation of Head and Neck Anatomy》中提出了Dice loss + Focal loss来处理小器官的分割问题。这里复现了Cross Entropy+Dice loss的函数实现,复现代码如下:
def binary_dicePcrossentroy(Y_pred, Y_gt):
# step 1,calculate binary crossentroy
epsilon = 1.e-5
Y_pred = tf.clip_by_value(Y_pred, epsilon, 1. - epsilon)
logits = tf.log(Y_pred / (1 - Y_pred))
loss1 = tf.nn.sigmoid_cross_entropy_with_logits(labels=Y_gt, logits=logits)
loss1 = tf.reduce_mean(loss1)
# step 2,calculate binary dice
smooth_tf = tf.constant(epsilon, tf.float32)
pred_flat = tf.cast(Y_pred, tf.float32)
true_flat = tf.cast(Y_gt, tf.float32)
Z, H, W, C = Y_gt.get_shape().as_list()[1:]
pred_flat = tf.reshape(pred_flat, [-1, H * W * C * Z])
true_flat = tf.reshape(true_flat, [-1, H * W * C * Z])
intersection = 2 * tf.reduce_sum(pred_flat * true_flat, axis=-1) + smooth_tf
denominator = tf.reduce_sum(pred_flat, axis=-1) + tf.reduce_sum(true_flat, axis=-1) + smooth_tf
loss2 = tf.reduce_mean(1 - intersection / denominator)
# step 3,calculate all loss
loss = loss1 + tf.log1p(loss2)
return loss
对上面五个损失函数进行训练,并在10例测试数据上预测并计算dice值,结果如下。
为了大家更好的学习,我把整个项目代码分享到github上:
https://github.com/junqiangchen/Image-Segmentation-Loss-Functions
如果大家觉得这个项目还不错,希望大家给个Star并Fork,可以让更多的人学习。如果碰到任何问题,随时留言,我会尽量去回答的。