前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >26 | 使用PyTorch完成医疗图像识别大项目:分割模型实训

26 | 使用PyTorch完成医疗图像识别大项目:分割模型实训

作者头像
机器学习之禅
发布2022-07-11 15:54:21
7950
发布2022-07-11 15:54:21
举报
文章被收录于专栏:机器学习之禅机器学习之禅

开始训练模型之前,我们需要先把之前的标注文件清理好。如下是原作给出的代码示例。

代码语言:javascript
复制
#在这个引入部分,有一个新的pylidc包需要安装,使用pip安装即可。这个包是专门用来处理LIDC数据集的,现在用的LUNA数据集就是在这个基础上加工的,关于这个包的说明很简单:A library for working with the LIDC dataset.import torchimport SimpleITK as sitkimport pandasimport glob, osimport numpyimport tqdmimport pylidc

安装完之后,首先读取原来的标注文件。这个文件里记录了1000多个结节的坐标和直径信息。

代码语言:javascript
复制
annotations = pandas.read_csv('D:/lunadata/annotations.csv')

然后对我们的数据进行扫描,记录恶性数据,是否有缺失数据等等

代码语言:javascript
复制
malignancy_data = []missing = []spacing_dict = {}scans = {s.series_instance_uid:s for s in pylidc.query(pylidc.Scan).all()}suids = annotations.seriesuid.unique()for suid in tqdm.tqdm(suids):
    fn = glob.glob('D:/lunadata/subset*/{}.mhd'.format(suid))
    if len(fn) == 0 or '*' in fn[0]:
        missing.append(suid)
        continue
    fn = fn[0]
    x = sitk.ReadImage(fn)
    spacing_dict[suid] = x.GetSpacing()
    s = scans[suid]
    for ann_cluster in s.cluster_annotations():
        is_malignant = len([a.malignancy for a in ann_cluster if a.malignancy >= 4])>=2
        centroid = numpy.mean([a.centroid for a in ann_cluster], 0)
        bbox = numpy.mean([a.bbox_matrix() for a in ann_cluster], 0).T
        coord = x.TransformIndexToPhysicalPoint([int(numpy.round(i)) for i in centroid[[1, 0, 2]]])
        bbox_low = x.TransformIndexToPhysicalPoint([int(numpy.round(i)) for i in bbox[0, [1, 0, 2]]])
        bbox_high = x.TransformIndexToPhysicalPoint([int(numpy.round(i)) for i in bbox[1, [1, 0, 2]]])
        malignancy_data.append((suid, coord[0], coord[1], coord[2], bbox_low[0], bbox_low[1], bbox_low[2], bbox_high[0], bbox_high[1], bbox_high[2], is_malignant, [a.malignancy for a in ann_cluster]))

这里能看到处理的进度条。

其中的miss用来记录是否有源文件(mhd文件损坏或者缺失),好在我这里是0缺失的。用原始数据信息去匹配我们读取的数据,

代码语言:javascript
复制
df_mal = pandas.DataFrame(malignancy_data, columns=['seriesuid', 'coordX', 'coordY', 'coordZ', 'bboxLowX', 'bboxLowY', 'bboxLowZ', 'bboxHighX', 'bboxHighY', 'bboxHighZ', 'mal_bool', 'mal_details'])processed_annot = []annotations['mal_bool'] = float('nan')annotations['mal_details'] = [[] for _ in annotations.iterrows()]bbox_keys = ['bboxLowX', 'bboxLowY', 'bboxLowZ', 'bboxHighX', 'bboxHighY', 'bboxHighZ']for k in bbox_keys:
    annotations[k] = float('nan')for series_id in tqdm.tqdm(annotations.seriesuid.unique()):
    # series_id = '1.3.6.1.4.1.14519.5.2.1.6279.6001.100225287222365663678666836860'
    c = candidates[candidates.seriesuid == series_id]
    a = annotations[annotations.seriesuid == series_id]
    m = df_mal[df_mal.seriesuid == series_id]
    if len(m) > 0:
        m_ctrs = m[['coordX', 'coordY', 'coordZ']].values
        a_ctrs = a[['coordX', 'coordY', 'coordZ']].values
        #print(m_ctrs.shape, a_ctrs.shape)
        matches = (numpy.linalg.norm(a_ctrs[:, None] - m_ctrs[None], ord=2, axis=-1) / a.diameter_mm.values[:, None] < 0.5)
        has_match = matches.max(-1)
        match_idx = matches.argmax(-1)[has_match]
        a_matched = a[has_match].copy()
        # c_matched['diameter_mm'] = a.diameter_mm.values[match_idx]
        a_matched['mal_bool'] = m.mal_bool.values[match_idx]
        a_matched['mal_details'] = m.mal_details.values[match_idx]
        for k in bbox_keys:
            a_matched[k] = m[k].values[match_idx]
        processed_annot.append(a_matched)
        processed_annot.append(a[~has_match])
    else:
        processed_annot.append(c)processed_annot = pandas.concat(processed_annot)processed_annot.sort_values('mal_bool', ascending=False, inplace=True)processed_annot['len_mal_details'] = processed_annot.mal_details.apply(len)

我这块的输出显示没有需要丢掉的数据,那看起来LUNA数据集里提供的数据已经更新了。

最后把这个文件保存下来。

代码语言:javascript
复制
df_nona = processed_annot.dropna()df_nona.to_csv('./data/part2/luna/annotations_with_malignancy.csv', index=False)

已经生成新的标注数据。

下面开始执行训练代码。首先还是创建缓存,结果这里遇到一个问题,代码接收的参数有问题, 在13章dset.py的49行,isMal_bool = {'False': False, 'True': True}[row[5]] 但实际上我们的文件里这一列存的是0.0和1.0,导致读取异常,把这里改成如下就能正常运行了。 isMal_bool = {'0.0': False, '1.0': True}[row[5]],接着启动缓存建设。

代码语言:javascript
复制
run('test13ch.prepcache.LunaPrepCacheApp')

在shell里面运行训练环节。

代码语言:javascript
复制
> python -m test13ch.training --epoch 20 --augmented final_seg

结果训练了一个epoch就内存溢出了。无奈,把batch size调小一点,从16改成了8个,这次就没问题了,我的设备还是不太行,真想买一台双3090Ti卡的机器。

image.png

这次看看效果。这里列出了第1,5,10,15,20个epoch的结果,可看到第1个epoch不管在训练集还是验证集的精确度很低,召回率还可以,在验证集上的fp(假阳性)达到了2442.7%,这主要是因为训练集使用的是裁剪后的小图片,而验证集使用的是完整的CT切片数据,所以假阳性很高也正常,多给出一些结果再让医生去看总比漏掉要好的多。

到了第5个epoch,精确度有所提升,训练集的f1达到0.71了

到了第10个epoch,又提升了一点点,但是验证集上给出的tp有些下降。

到15个epoch,在训练集上的效果持续提升,但是在验证集上的效果下降明显,tp值以及到了79%,说明这个时候已经出现了过拟合现象。

下面去TensorBoard上去看看效果。蓝色是训练集,红色是验证集,首先是损失情况,在训练集上前期损失下降比较快,后面就比较平缓,在验证集上的损失变化不大。

然后是fn,fp,tp指

最后是f1 score,精确度,召回率,可以看到在经过了几个epoch之后验证集的召回率开始下降,出现了过拟合现象。

最后看一看导入TensorBoard的图像效果。带有label_x的表示这是一个标注图像,上面没有颜色的表名这个图像上都是无标注的,在对应的预测结果上,有一些橙色结果是假阳性预测,对于下面带绿色就是阳性标注及阳性预测结果。

看起来效果还不错,我们的这个模型就先训到这里,基本上可以满足我们的需求了。

本文参与 腾讯云自媒体同步曝光计划,分享自微信公众号。
原始发表:2022-06-27,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 机器学习之禅 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档