前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >25 | 使用PyTorch完成医疗图像识别大项目:分割模型实现

25 | 使用PyTorch完成医疗图像识别大项目:分割模型实现

作者头像
机器学习之禅
发布2022-07-11 15:53:22
6860
发布2022-07-11 15:53:22
举报
文章被收录于专栏:机器学习之禅机器学习之禅

前面已经把分割模型的数据处理的差不多了,最后再加一点点关于数据增强的事情,我们就可以开始训练模型了。

常见的瓶颈

在搞机器学习项目的时候,总会有各种各样的瓶颈问题,比如IO问题,内存问题,GPU问题等等。因为我们的设备总会有一个短板的地方。

  • 1.数据加载环节,数据的大量IO(读写)可能会比较慢。
  • 2.使用CPU进行数据预处理环节可能出现瓶颈,通常来说是进行正则化和数据增强的时候。
  • 3.在模型训练的时候GPU可能是最大的瓶颈,如果说一定存在瓶颈那么我们希望是在GPU这块,因为GPU是最贵的。
  • 4.在CPU和GPU直接传输数据的带宽会影响GPU的运算。 这里我们要处理的就是在数据增强环节使用GPU。这里所做的数据增强方式跟之前一模一样,只不过这次通过类似于模型的方式实现,我们把这些步骤放在前向传播的方法里面,把它变成一个看起来像模型训练的过程,
代码语言:javascript
复制
class SegmentationAugmentation(nn.Module):
    def __init__(
            self, flip=None, offset=None, scale=None, rotate=None, noise=None
    ):
        super().__init__()

        self.flip = flip        self.offset = offset        self.scale = scale        self.rotate = rotate        self.noise = noise    def forward(self, input_g, label_g):#这里是获取变换方法
        transform_t = self._build2dTransformMatrix()
        transform_t = transform_t.expand(input_g.shape[0], -1, -1)#因为GPU适合处理浮点数,这里传入GPU的同时转换成浮点数
        transform_t = transform_t.to(input_g.device, torch.float32)#affine_grid和grid_sample就是实现变换和重新采样(生成新图像)的方法
        affine_t = F.affine_grid(transform_t[:,:2],
                input_g.size(), align_corners=False)

        augmented_input_g = F.grid_sample(input_g,
                affine_t, padding_mode='border',
                align_corners=False)#这里同时在掩码操作
        augmented_label_g = F.grid_sample(label_g.to(torch.float32),
                affine_t, padding_mode='border',
                align_corners=False)#最后是增加噪声
        if self.noise:
            noise_t = torch.randn_like(augmented_input_g)
            noise_t *= self.noise

            augmented_input_g += noise_t        return augmented_input_g, augmented_label_g > 0.5

    def _build2dTransformMatrix(self):
        transform_t = torch.eye(3)

        for i in range(2):
            if self.flip:
                if random.random() > 0.5:
                    transform_t[i,i] *= -1

            if self.offset:
                offset_float = self.offset
                random_float = (random.random() * 2 - 1)
                transform_t[2,i] = offset_float * random_float            if self.scale:
                scale_float = self.scale
                random_float = (random.random() * 2 - 1)
                transform_t[i,i] *= 1.0 + scale_float * random_float        if self.rotate:
            angle_rad = random.random() * math.pi * 2
            s = math.sin(angle_rad)
            c = math.cos(angle_rad)

            rotation_t = torch.tensor([
                [c, -s, 0],
                [s, c, 0],
                [0, 0, 1]])

            transform_t @= rotation_t        return transform_t

接下来就是实现training环节。我们先把内部的一些方法写好。第一个是给模型进行初始化。

代码语言:javascript
复制
    def initModel(self):#使用我们封装的UNet模型
        segmentation_model = UNetWrapper(
            in_channels=7,
            n_classes=1,
            depth=3,
            wf=4,
            padding=True,
            batch_norm=True,
            up_mode='upconv',
        )#数据增强模型,实际上并不是一个真的模型
        augmentation_model = SegmentationAugmentation(**self.augmentation_dict)#设置使用GPU,甚至是GPU并行运算
        if self.use_cuda:
            log.info("Using CUDA; {} devices.".format(torch.cuda.device_count()))
            if torch.cuda.device_count() > 1:
                segmentation_model = nn.DataParallel(segmentation_model)
                augmentation_model = nn.DataParallel(augmentation_model)#把模型传入GPU
            segmentation_model = segmentation_model.to(self.device)
            augmentation_model = augmentation_model.to(self.device)#返回模型实例
        return segmentation_model, augmentation_model

第二个要定义的是优化器。在这里使用Adam优化器。Adam有很多的优点,比如说不太需要我们去调整参数,它会为每个参数维护一个单独的学习率,并且可以根据训练的进行自动更新学习率。这个只需要一行调用就可以实现,如果你想了解Adam的细节,可以点进去研究一下它的源代码。

代码语言:javascript
复制
    def initOptimizer(self):
        return Adam(self.segmentation_model.parameters())

第三个是定义损失函数。这块我们又要换一个新的损失计算方法了。前面我们已经学过L1损失,L2损失,交叉熵损失,现在新加一个骰子损失(Dice Loss)。它的计算逻辑也不难理解,是按照实际的图像面积和预测出来的图像面积进行比较的,这是在图像分割领域常用的损失计算方法。看下面这张图,考虑实际的图像是圆圈内的图像,预测的图像是方框区域的图像,其中阴影部分就是预测命中的部分,而dice系数的计算就是阴影面积的二倍除方框加圆圈的面积。可以想象,当预测完全准确的时候这个系数计算出来是1.0,所以我们使用1-dice系数作为损失,因为我们期望预测越准确的时候损失越小。

image.png

代码语言:javascript
复制
    def diceLoss(self, prediction_g, label_g, epsilon=1):
        diceLabel_g = label_g.sum(dim=[1,2,3])
        dicePrediction_g = prediction_g.sum(dim=[1,2,3])
        diceCorrect_g = (prediction_g * label_g).sum(dim=[1,2,3])

        diceRatio_g = (2 * diceCorrect_g + epsilon) \            / (dicePrediction_g + diceLabel_g + epsilon)

        return 1 - diceRatio_g

计算批量损失。

代码语言:javascript
复制
    def computeBatchLoss(self, batch_ndx, batch_tup, batch_size, metrics_g,
                         classificationThreshold=0.5):
        input_t, label_t, series_list, _slice_ndx_list = batch_tup#数据传入GPU
        input_g = input_t.to(self.device, non_blocking=True)
        label_g = label_t.to(self.device, non_blocking=True)#判断是否需要增强数据,训练时候需要,验证时候不需要
        if self.segmentation_model.training and self.augmentation_dict:
            input_g, label_g = self.augmentation_model(input_g, label_g)#运行分割模型
        prediction_g = self.segmentation_model(input_g)#计算损失
        diceLoss_g = self.diceLoss(prediction_g, label_g)#这个fnLoss使用的是prediction_g * label_g输入,也就是只保留了预测正确的那一部分,用于后面我们对损失进行加权
        fnLoss_g = self.diceLoss(prediction_g * label_g, label_g)#结果指标存储,批数据的起始位置和终止位置
        start_ndx = batch_ndx * batch_size
        end_ndx = start_ndx + input_t.size(0)

        with torch.no_grad():
            predictionBool_g = (prediction_g[:, 0:1]
                                > classificationThreshold).to(torch.float32)#计算真阳性,假阴性,假阳性数目
            tp = (     predictionBool_g *  label_g).sum(dim=[1,2,3])
            fn = ((1 - predictionBool_g) *  label_g).sum(dim=[1,2,3])
            fp = (     predictionBool_g * (~label_g)).sum(dim=[1,2,3])

            metrics_g[METRICS_LOSS_NDX, start_ndx:end_ndx] = diceLoss_g
            metrics_g[METRICS_TP_NDX, start_ndx:end_ndx] = tp
            metrics_g[METRICS_FN_NDX, start_ndx:end_ndx] = fn
            metrics_g[METRICS_FP_NDX, start_ndx:end_ndx] = fp#这个地方进行了损失加权,这里×了8,表示正向的像素重要性比负向像素高8倍,用来增强我们把图像分割出结节的情况,因为我们希望能更多的找到结节,所以哪怕召回多一些也没关系,总比丢掉了一部分要好。
        return diceLoss_g.mean() + fnLoss_g.mean() * 8

再往下,我们研究把图像导入TensorBoard,以便我们能够在TensorBoard上显性地观察模型效果。做图像任务的好处就是比较容易观察中间结果,其实我自己做NLP比较多,中间结果输出出来也看不出什么效果。

图像记录方法。

代码语言:javascript
复制
    def logImages(self, epoch_ndx, mode_str, dl):#把模型设置为eval模式
        self.segmentation_model.eval()#获取12个CT
        images = sorted(dl.dataset.series_list)[:12]
        for series_ndx, series_uid in enumerate(images):
            ct = getCt(series_uid)#取出6个切片
            for slice_ndx in range(6):
                ct_ndx = slice_ndx * (ct.hu_a.shape[0] - 1) // 5
                sample_tup = dl.dataset.getitem_fullSlice(series_uid, ct_ndx)

                ct_t, label_t, series_uid, ct_ndx = sample_tup

                input_g = ct_t.to(self.device).unsqueeze(0)
                label_g = pos_g = label_t.to(self.device).unsqueeze(0)

                prediction_g = self.segmentation_model(input_g)[0]
                prediction_a = prediction_g.to('cpu').detach().numpy()[0] > 0.5
                label_a = label_g.cpu().numpy()[0][0] > 0.5

                ct_t[:-1,:,:] /= 2000
                ct_t[:-1,:,:] += 0.5

                ctSlice_a = ct_t[dl.dataset.contextSlices_count].numpy()

                image_a = np.zeros((512, 512, 3), dtype=np.float32)
                image_a[:,:,:] = ctSlice_a.reshape((512,512,1))
                image_a[:,:,0] += prediction_a & (1 - label_a) #把假阳性区域标记成红色
                image_a[:,:,0] += (1 - prediction_a) & label_a #假阴性标记为橙色
                image_a[:,:,1] += ((1 - prediction_a) & label_a) * 0.5 

                image_a[:,:,1] += prediction_a & label_a  #真阳性标记为绿色
                image_a *= 0.5
                image_a.clip(0, 1, image_a)

                writer = getattr(self, mode_str + '_writer')
                writer.add_image(
                    f'{mode_str}/{series_ndx}_prediction_{slice_ndx}',
                    image_a,
                    self.totalTrainingSamples_count,
                    dataformats='HWC',
                )

                if epoch_ndx == 1:
                    image_a = np.zeros((512, 512, 3), dtype=np.float32)
                    image_a[:,:,:] = ctSlice_a.reshape((512,512,1))
                    # image_a[:,:,0] += (1 - label_a) & lung_a # Red
                    image_a[:,:,1] += label_a  # Green
                    # image_a[:,:,2] += neg_a  # Blue

                    image_a *= 0.5
                    image_a[image_a < 0] = 0
                    image_a[image_a > 1] = 1
                    writer.add_image(
                        '{}/{}_label_{}'.format(
                            mode_str,
                            series_ndx,
                            slice_ndx,
                        ),
                        image_a,
                        self.totalTrainingSamples_count,
                        dataformats='HWC',
                    )
                # This flush prevents TB from getting confused about which
                # data item belongs where.
                writer.flush()

然后在训练的main方法里面调用它。

代码语言:javascript
复制
def main(self):……self.validation_cadence = 5
        for epoch_ndx in range(1, self.cli_args.epochs + 1):
        ……
            trnMetrics_t = self.doTraining(epoch_ndx, train_dl)
            self.logMetrics(epoch_ndx, 'trn', trnMetrics_t)#记录第一个epoch或者每隔几个周期的时候记录图像信息
            if epoch_ndx == 1 or epoch_ndx % self.validation_cadence == 0:
                # if validation is wanted
                valMetrics_t = self.doValidation(epoch_ndx, val_dl)
                score = self.logMetrics(epoch_ndx, 'val', valMetrics_t)
                best_score = max(score, best_score)

                self.saveModel('seg', epoch_ndx, score == best_score)

                self.logImages(epoch_ndx, 'trn', train_dl)
                self.logImages(epoch_ndx, 'val', val_dl)

下图是书上给出的样例图,每个图上还有滚动条,通过滚动可以查看在不同迭代周期的图像。

image.png

除了记录图像,我们还得把迭代的指标信息也记录下来,这部分跟前面基本一样,就不再过多解释了。 最后,如果我们的模型效果还不错,我们要把它存下来,实际上我们存储的是模型训练好的参数信息。

代码语言:javascript
复制
    def saveModel(self, type_str, epoch_ndx, isBest=False):#存储文件路径信息
        file_path = os.path.join(
            'data-unversioned',
            'part2',
            'models',
            self.cli_args.tb_prefix,
            '{}_{}_{}.{}.state'.format(
                type_str,
                self.time_str,
                self.cli_args.comment,
                self.totalTrainingSamples_count,
            )
        )#创建目录
        os.makedirs(os.path.dirname(file_path), mode=0o755, exist_ok=True)#获取模型
        model = self.segmentation_model        if isinstance(model, torch.nn.DataParallel):
            model = model.module#需要存储的状态信息
        state = {
            'sys_argv': sys.argv,  #系统参数
            'time': str(datetime.datetime.now()), #时间信息
            'model_state': model.state_dict(), #模型状态
            'model_name': type(model).__name__, #模型名称
            'optimizer_state' : self.optimizer.state_dict(), #优化器状态
            'optimizer_name': type(self.optimizer).__name__, #优化器名称
            'epoch': epoch_ndx, #迭代周期
            'totalTrainingSamples_count': self.totalTrainingSamples_count, #训练样本数量
        }#存储,通过存储模型,我们可以在下次接着训练
        torch.save(state, file_path)

        log.info("Saved model params to {}".format(file_path))#这里做一个备份,如果这是效果最好的一版模型,就再存一次,记得多做这种操作,并且文件命名一定要好,具体为什么你自己考虑,说多了都是泪。
        if isBest:
            best_path = os.path.join(
                'data-unversioned', 'part2', 'models',
                self.cli_args.tb_prefix,
                f'{type_str}_{self.time_str}_{self.cli_args.comment}.best.state')
            shutil.copyfile(file_path, best_path)

            log.info("Saved model params to {}".format(best_path))#最后这个hash是用于校验文件的
        with open(file_path, 'rb') as f:
            log.info("SHA1: " + hashlib.sha1(f.read()).hexdigest())

接下来就是训练模型然后看结果了。

本文参与 腾讯云自媒体同步曝光计划,分享自微信公众号。
原始发表:2022-06-26,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 机器学习之禅 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 常见的瓶颈
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档