前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >Tensorflow入门教程(三十四)——常用两类图像分割损失函数

Tensorflow入门教程(三十四)——常用两类图像分割损失函数

作者头像
医学处理分析专家
发布2020-06-29 15:38:48
1.9K0
发布2020-06-29 15:38:48
举报

常用的两类图像分割损失函数有二值交叉熵,dice系数,tversky,FocalLoss等。今天我将在TensorFlow下复现上述损失函数,并进行结果对比。

1、Cross Entropy

交叉熵损失函数是逐像素将类预测值与目标值进行比较,然后再对所有像素求平均值。公式如下所示,其中p是真实类别值,p’是预测属于类别1的概率值。

该函数对每个类别都有相同权重,所以极易受到类别不平衡的影响。

复现代码如下:

代码语言:javascript
复制
def binary_crossentropy(Y_pred, Y_gt):
    epsilon = 1.e-5
    Y_pred = tf.clip_by_value(Y_pred, epsilon, 1. - epsilon)
    logits = tf.log(Y_pred / (1 - Y_pred))
    loss = tf.nn.sigmoid_cross_entropy_with_logits(labels=Y_gt, logits=logits)
    loss = tf.reduce_mean(loss)
    return loss

2、Dice loss

Dice loss 是在V-net模型中使用的,一般感兴趣的解剖结构区域占据相对较小的区域,因此加大前景区域的权重,可减少类别不平衡的影响。公式如下所示,其中TP,FP,FN分别是真阳性、假阳性、假阴性的个数。

有的论文中Dice loss计算公式也可以是这样的,如下所示,其中p是真实类别值(0或1),p’是预测类别的概率值(0~1)。

复现代码如下:

代码语言:javascript
复制
def binary_dice(Y_pred, Y_gt):
    smooth = 1.e-5
    smooth_tf = tf.constant(smooth, tf.float32)
    pred_flat = tf.cast(Y_pred, tf.float32)
    true_flat = tf.cast(Y_gt, tf.float32)
    Z, H, W, C = Y_gt.get_shape().as_list()[1:]
    pred_flat = tf.reshape(pred_flat, [-1, H * W * C * Z])
    true_flat = tf.reshape(true_flat, [-1, H * W * C * Z])
    intersection = 2 * tf.reduce_sum(pred_flat * true_flat, axis=-1) + smooth_tf
    denominator = tf.reduce_sum(pred_flat, axis=-1) + tf.reduce_sum(true_flat, axis=-1) + smooth_tf
    loss = -tf.reduce_mean(intersection / denominator)
    return loss

3、Tversky loss

Tversky loss是Dice loss的一般表达式,Tversky loss在假阳性、假阴性区域增加了权重因子。公式如下所示,其中p是真实类别值(0或1),p’是预测类别的概率值(0~1)。可以发现,当beta值为0.5时,Tversky loss就是Dice loss了。

复现代码如下:

代码语言:javascript
复制
def binary_tversky(Y_pred, Y_gt, beta):
    smooth = 1.e-5
    smooth_tf = tf.constant(smooth, tf.float32)
    pred_flat = tf.cast(Y_pred, tf.float32)
    true_flat = tf.cast(Y_gt, tf.float32)
    Z, H, W, C = Y_gt.get_shape().as_list()[1:]
    pred_flat = tf.reshape(pred_flat, [-1, H * W * C * Z])
    true_flat = tf.reshape(true_flat, [-1, H * W * C * Z])
    intersection = tf.reduce_sum(pred_flat * true_flat, axis=-1)
    denominator = intersection + tf.reduce_sum(beta * pred_flat * (1 - true_flat), axis=-1) + tf.reduce_sum(
        (1 - beta) * true_flat * (1 - pred_flat), axis=-1)
    loss = -tf.reduce_mean((intersection + smooth_tf) / (denominator + smooth_tf))
    return loss

4、Focal loss

Focal loss是对Cross Entropy函数的改进,该函数减少了简单样本损失权重,从而使网络更专注于困难样本损失。公式如下所示,其中p是真实类别值,p’是预测属于类别1的概率值。

复现代码如下:

代码语言:javascript
复制
def binary_focalloss(Y_pred, Y_gt, alpha=0.25, gamma=2.):
    epsilon = 1.e-5
    pt_1 = tf.where(tf.equal(Y_gt, 1), Y_pred, tf.ones_like(Y_pred))
    pt_0 = tf.where(tf.equal(Y_gt, 0), Y_pred, tf.zeros_like(Y_pred))
    # clip to prevent NaN's and Inf's
    pt_1 = tf.clip_by_value(pt_1, epsilon, 1. - epsilon)
    pt_0 = tf.clip_by_value(pt_0, epsilon, 1. - epsilon)
    loss_1 = alpha * tf.pow(1. - pt_1, gamma) * tf.log(pt_1)
    loss_0 = (1 - alpha) * tf.pow(pt_0, gamma) * tf.log(1. - pt_0)
    loss = -tf.reduce_sum(loss_1 + loss_0)
    loss = tf.reduce_mean(loss)
    return loss

5、Cross Entropy+Dice loss

有些文章里结合不同的损失函数来训练网络,腾讯医疗AI实验室发表的论文《AnatomyNet: Deep Learning for Fast and Fully Automated Whole-volume Segmentation of Head and Neck Anatomy》中提出了Dice loss + Focal loss来处理小器官的分割问题。这里复现了Cross Entropy+Dice loss的函数实现,复现代码如下:

代码语言:javascript
复制
def binary_dicePcrossentroy(Y_pred, Y_gt):
    # step 1,calculate binary crossentroy
    epsilon = 1.e-5
    Y_pred = tf.clip_by_value(Y_pred, epsilon, 1. - epsilon)
    logits = tf.log(Y_pred / (1 - Y_pred))
    loss1 = tf.nn.sigmoid_cross_entropy_with_logits(labels=Y_gt, logits=logits)
    loss1 = tf.reduce_mean(loss1)
    # step 2,calculate binary dice
    smooth_tf = tf.constant(epsilon, tf.float32)
    pred_flat = tf.cast(Y_pred, tf.float32)
    true_flat = tf.cast(Y_gt, tf.float32)
    Z, H, W, C = Y_gt.get_shape().as_list()[1:]
    pred_flat = tf.reshape(pred_flat, [-1, H * W * C * Z])
    true_flat = tf.reshape(true_flat, [-1, H * W * C * Z])
    intersection = 2 * tf.reduce_sum(pred_flat * true_flat, axis=-1) + smooth_tf
    denominator = tf.reduce_sum(pred_flat, axis=-1) + tf.reduce_sum(true_flat, axis=-1) + smooth_tf
    loss2 = tf.reduce_mean(1 - intersection / denominator)
    # step 3,calculate all loss
    loss = loss1 + tf.log1p(loss2)
    return loss

对上面五个损失函数进行训练,并在10例测试数据上预测并计算dice值,结果如下。

为了大家更好的学习,我把整个项目代码分享到github上:

https://github.com/junqiangchen/Image-Segmentation-Loss-Functions

如果大家觉得这个项目还不错,希望大家给个Star并Fork,可以让更多的人学习。如果碰到任何问题,随时留言,我会尽量去回答的。

本文参与 腾讯云自媒体分享计划,分享自微信公众号。
原始发表:2019-08-16,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 最新医学影像技术 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档