Loading [MathJax]/jax/output/CommonHTML/config.js
前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >专栏 >25 | 使用PyTorch完成医疗图像识别大项目:分割模型实现

25 | 使用PyTorch完成医疗图像识别大项目:分割模型实现

作者头像
机器学习之禅
发布于 2022-07-11 07:53:22
发布于 2022-07-11 07:53:22
78900
代码可运行
举报
文章被收录于专栏:机器学习之禅机器学习之禅
运行总次数:0
代码可运行

前面已经把分割模型的数据处理的差不多了,最后再加一点点关于数据增强的事情,我们就可以开始训练模型了。

常见的瓶颈

在搞机器学习项目的时候,总会有各种各样的瓶颈问题,比如IO问题,内存问题,GPU问题等等。因为我们的设备总会有一个短板的地方。

  • 1.数据加载环节,数据的大量IO(读写)可能会比较慢。
  • 2.使用CPU进行数据预处理环节可能出现瓶颈,通常来说是进行正则化和数据增强的时候。
  • 3.在模型训练的时候GPU可能是最大的瓶颈,如果说一定存在瓶颈那么我们希望是在GPU这块,因为GPU是最贵的。
  • 4.在CPU和GPU直接传输数据的带宽会影响GPU的运算。 这里我们要处理的就是在数据增强环节使用GPU。这里所做的数据增强方式跟之前一模一样,只不过这次通过类似于模型的方式实现,我们把这些步骤放在前向传播的方法里面,把它变成一个看起来像模型训练的过程,
代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
class SegmentationAugmentation(nn.Module):
    def __init__(
            self, flip=None, offset=None, scale=None, rotate=None, noise=None
    ):
        super().__init__()

        self.flip = flip        self.offset = offset        self.scale = scale        self.rotate = rotate        self.noise = noise    def forward(self, input_g, label_g):#这里是获取变换方法
        transform_t = self._build2dTransformMatrix()
        transform_t = transform_t.expand(input_g.shape[0], -1, -1)#因为GPU适合处理浮点数,这里传入GPU的同时转换成浮点数
        transform_t = transform_t.to(input_g.device, torch.float32)#affine_grid和grid_sample就是实现变换和重新采样(生成新图像)的方法
        affine_t = F.affine_grid(transform_t[:,:2],
                input_g.size(), align_corners=False)

        augmented_input_g = F.grid_sample(input_g,
                affine_t, padding_mode='border',
                align_corners=False)#这里同时在掩码操作
        augmented_label_g = F.grid_sample(label_g.to(torch.float32),
                affine_t, padding_mode='border',
                align_corners=False)#最后是增加噪声
        if self.noise:
            noise_t = torch.randn_like(augmented_input_g)
            noise_t *= self.noise

            augmented_input_g += noise_t        return augmented_input_g, augmented_label_g > 0.5

    def _build2dTransformMatrix(self):
        transform_t = torch.eye(3)

        for i in range(2):
            if self.flip:
                if random.random() > 0.5:
                    transform_t[i,i] *= -1

            if self.offset:
                offset_float = self.offset
                random_float = (random.random() * 2 - 1)
                transform_t[2,i] = offset_float * random_float            if self.scale:
                scale_float = self.scale
                random_float = (random.random() * 2 - 1)
                transform_t[i,i] *= 1.0 + scale_float * random_float        if self.rotate:
            angle_rad = random.random() * math.pi * 2
            s = math.sin(angle_rad)
            c = math.cos(angle_rad)

            rotation_t = torch.tensor([
                [c, -s, 0],
                [s, c, 0],
                [0, 0, 1]])

            transform_t @= rotation_t        return transform_t

接下来就是实现training环节。我们先把内部的一些方法写好。第一个是给模型进行初始化。

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
    def initModel(self):#使用我们封装的UNet模型
        segmentation_model = UNetWrapper(
            in_channels=7,
            n_classes=1,
            depth=3,
            wf=4,
            padding=True,
            batch_norm=True,
            up_mode='upconv',
        )#数据增强模型,实际上并不是一个真的模型
        augmentation_model = SegmentationAugmentation(**self.augmentation_dict)#设置使用GPU,甚至是GPU并行运算
        if self.use_cuda:
            log.info("Using CUDA; {} devices.".format(torch.cuda.device_count()))
            if torch.cuda.device_count() > 1:
                segmentation_model = nn.DataParallel(segmentation_model)
                augmentation_model = nn.DataParallel(augmentation_model)#把模型传入GPU
            segmentation_model = segmentation_model.to(self.device)
            augmentation_model = augmentation_model.to(self.device)#返回模型实例
        return segmentation_model, augmentation_model

第二个要定义的是优化器。在这里使用Adam优化器。Adam有很多的优点,比如说不太需要我们去调整参数,它会为每个参数维护一个单独的学习率,并且可以根据训练的进行自动更新学习率。这个只需要一行调用就可以实现,如果你想了解Adam的细节,可以点进去研究一下它的源代码。

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
    def initOptimizer(self):
        return Adam(self.segmentation_model.parameters())

第三个是定义损失函数。这块我们又要换一个新的损失计算方法了。前面我们已经学过L1损失,L2损失,交叉熵损失,现在新加一个骰子损失(Dice Loss)。它的计算逻辑也不难理解,是按照实际的图像面积和预测出来的图像面积进行比较的,这是在图像分割领域常用的损失计算方法。看下面这张图,考虑实际的图像是圆圈内的图像,预测的图像是方框区域的图像,其中阴影部分就是预测命中的部分,而dice系数的计算就是阴影面积的二倍除方框加圆圈的面积。可以想象,当预测完全准确的时候这个系数计算出来是1.0,所以我们使用1-dice系数作为损失,因为我们期望预测越准确的时候损失越小。

image.png

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
    def diceLoss(self, prediction_g, label_g, epsilon=1):
        diceLabel_g = label_g.sum(dim=[1,2,3])
        dicePrediction_g = prediction_g.sum(dim=[1,2,3])
        diceCorrect_g = (prediction_g * label_g).sum(dim=[1,2,3])

        diceRatio_g = (2 * diceCorrect_g + epsilon) \            / (dicePrediction_g + diceLabel_g + epsilon)

        return 1 - diceRatio_g

计算批量损失。

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
    def computeBatchLoss(self, batch_ndx, batch_tup, batch_size, metrics_g,
                         classificationThreshold=0.5):
        input_t, label_t, series_list, _slice_ndx_list = batch_tup#数据传入GPU
        input_g = input_t.to(self.device, non_blocking=True)
        label_g = label_t.to(self.device, non_blocking=True)#判断是否需要增强数据,训练时候需要,验证时候不需要
        if self.segmentation_model.training and self.augmentation_dict:
            input_g, label_g = self.augmentation_model(input_g, label_g)#运行分割模型
        prediction_g = self.segmentation_model(input_g)#计算损失
        diceLoss_g = self.diceLoss(prediction_g, label_g)#这个fnLoss使用的是prediction_g * label_g输入,也就是只保留了预测正确的那一部分,用于后面我们对损失进行加权
        fnLoss_g = self.diceLoss(prediction_g * label_g, label_g)#结果指标存储,批数据的起始位置和终止位置
        start_ndx = batch_ndx * batch_size
        end_ndx = start_ndx + input_t.size(0)

        with torch.no_grad():
            predictionBool_g = (prediction_g[:, 0:1]
                                > classificationThreshold).to(torch.float32)#计算真阳性,假阴性,假阳性数目
            tp = (     predictionBool_g *  label_g).sum(dim=[1,2,3])
            fn = ((1 - predictionBool_g) *  label_g).sum(dim=[1,2,3])
            fp = (     predictionBool_g * (~label_g)).sum(dim=[1,2,3])

            metrics_g[METRICS_LOSS_NDX, start_ndx:end_ndx] = diceLoss_g
            metrics_g[METRICS_TP_NDX, start_ndx:end_ndx] = tp
            metrics_g[METRICS_FN_NDX, start_ndx:end_ndx] = fn
            metrics_g[METRICS_FP_NDX, start_ndx:end_ndx] = fp#这个地方进行了损失加权,这里×了8,表示正向的像素重要性比负向像素高8倍,用来增强我们把图像分割出结节的情况,因为我们希望能更多的找到结节,所以哪怕召回多一些也没关系,总比丢掉了一部分要好。
        return diceLoss_g.mean() + fnLoss_g.mean() * 8

再往下,我们研究把图像导入TensorBoard,以便我们能够在TensorBoard上显性地观察模型效果。做图像任务的好处就是比较容易观察中间结果,其实我自己做NLP比较多,中间结果输出出来也看不出什么效果。

图像记录方法。

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
    def logImages(self, epoch_ndx, mode_str, dl):#把模型设置为eval模式
        self.segmentation_model.eval()#获取12CT
        images = sorted(dl.dataset.series_list)[:12]
        for series_ndx, series_uid in enumerate(images):
            ct = getCt(series_uid)#取出6个切片
            for slice_ndx in range(6):
                ct_ndx = slice_ndx * (ct.hu_a.shape[0] - 1) // 5
                sample_tup = dl.dataset.getitem_fullSlice(series_uid, ct_ndx)

                ct_t, label_t, series_uid, ct_ndx = sample_tup

                input_g = ct_t.to(self.device).unsqueeze(0)
                label_g = pos_g = label_t.to(self.device).unsqueeze(0)

                prediction_g = self.segmentation_model(input_g)[0]
                prediction_a = prediction_g.to('cpu').detach().numpy()[0] > 0.5
                label_a = label_g.cpu().numpy()[0][0] > 0.5

                ct_t[:-1,:,:] /= 2000
                ct_t[:-1,:,:] += 0.5

                ctSlice_a = ct_t[dl.dataset.contextSlices_count].numpy()

                image_a = np.zeros((512, 512, 3), dtype=np.float32)
                image_a[:,:,:] = ctSlice_a.reshape((512,512,1))
                image_a[:,:,0] += prediction_a & (1 - label_a) #把假阳性区域标记成红色
                image_a[:,:,0] += (1 - prediction_a) & label_a #假阴性标记为橙色
                image_a[:,:,1] += ((1 - prediction_a) & label_a) * 0.5 

                image_a[:,:,1] += prediction_a & label_a  #真阳性标记为绿色
                image_a *= 0.5
                image_a.clip(0, 1, image_a)

                writer = getattr(self, mode_str + '_writer')
                writer.add_image(
                    f'{mode_str}/{series_ndx}_prediction_{slice_ndx}',
                    image_a,
                    self.totalTrainingSamples_count,
                    dataformats='HWC',
                )

                if epoch_ndx == 1:
                    image_a = np.zeros((512, 512, 3), dtype=np.float32)
                    image_a[:,:,:] = ctSlice_a.reshape((512,512,1))
                    # image_a[:,:,0] += (1 - label_a) & lung_a # Red
                    image_a[:,:,1] += label_a  # Green
                    # image_a[:,:,2] += neg_a  # Blue

                    image_a *= 0.5
                    image_a[image_a < 0] = 0
                    image_a[image_a > 1] = 1
                    writer.add_image(
                        '{}/{}_label_{}'.format(
                            mode_str,
                            series_ndx,
                            slice_ndx,
                        ),
                        image_a,
                        self.totalTrainingSamples_count,
                        dataformats='HWC',
                    )
                # This flush prevents TB from getting confused about which
                # data item belongs where.
                writer.flush()

然后在训练的main方法里面调用它。

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
def main(self):……self.validation_cadence = 5
        for epoch_ndx in range(1, self.cli_args.epochs + 1):
        ……
            trnMetrics_t = self.doTraining(epoch_ndx, train_dl)
            self.logMetrics(epoch_ndx, 'trn', trnMetrics_t)#记录第一个epoch或者每隔几个周期的时候记录图像信息
            if epoch_ndx == 1 or epoch_ndx % self.validation_cadence == 0:
                # if validation is wanted
                valMetrics_t = self.doValidation(epoch_ndx, val_dl)
                score = self.logMetrics(epoch_ndx, 'val', valMetrics_t)
                best_score = max(score, best_score)

                self.saveModel('seg', epoch_ndx, score == best_score)

                self.logImages(epoch_ndx, 'trn', train_dl)
                self.logImages(epoch_ndx, 'val', val_dl)

下图是书上给出的样例图,每个图上还有滚动条,通过滚动可以查看在不同迭代周期的图像。

image.png

除了记录图像,我们还得把迭代的指标信息也记录下来,这部分跟前面基本一样,就不再过多解释了。 最后,如果我们的模型效果还不错,我们要把它存下来,实际上我们存储的是模型训练好的参数信息。

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
    def saveModel(self, type_str, epoch_ndx, isBest=False):#存储文件路径信息
        file_path = os.path.join(
            'data-unversioned',
            'part2',
            'models',
            self.cli_args.tb_prefix,
            '{}_{}_{}.{}.state'.format(
                type_str,
                self.time_str,
                self.cli_args.comment,
                self.totalTrainingSamples_count,
            )
        )#创建目录
        os.makedirs(os.path.dirname(file_path), mode=0o755, exist_ok=True)#获取模型
        model = self.segmentation_model        if isinstance(model, torch.nn.DataParallel):
            model = model.module#需要存储的状态信息
        state = {
            'sys_argv': sys.argv,  #系统参数
            'time': str(datetime.datetime.now()), #时间信息
            'model_state': model.state_dict(), #模型状态
            'model_name': type(model).__name__, #模型名称
            'optimizer_state' : self.optimizer.state_dict(), #优化器状态
            'optimizer_name': type(self.optimizer).__name__, #优化器名称
            'epoch': epoch_ndx, #迭代周期
            'totalTrainingSamples_count': self.totalTrainingSamples_count, #训练样本数量
        }#存储,通过存储模型,我们可以在下次接着训练
        torch.save(state, file_path)

        log.info("Saved model params to {}".format(file_path))#这里做一个备份,如果这是效果最好的一版模型,就再存一次,记得多做这种操作,并且文件命名一定要好,具体为什么你自己考虑,说多了都是泪。
        if isBest:
            best_path = os.path.join(
                'data-unversioned', 'part2', 'models',
                self.cli_args.tb_prefix,
                f'{type_str}_{self.time_str}_{self.cli_args.comment}.best.state')
            shutil.copyfile(file_path, best_path)

            log.info("Saved model params to {}".format(best_path))#最后这个hash是用于校验文件的
        with open(file_path, 'rb') as f:
            log.info("SHA1: " + hashlib.sha1(f.read()).hexdigest())

接下来就是训练模型然后看结果了。

本文参与 腾讯云自媒体同步曝光计划,分享自微信公众号。
原始发表:2022-06-26,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 机器学习之禅 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
暂无评论
推荐阅读
编辑精选文章
换一批
05-PyTorch自定义数据集Datasets、Loader和tranform
对于机器学习中的许多不同问题,我们采取的步骤都是相似的。PyTorch 有许多内置数据集,用于大量机器学习基准测试。除此之外也可以自定义数据集,本问将使用我们自己的披萨、牛排和寿司图像数据集,而不是使用内置的 PyTorch 数据集。具体来说,我们将使用 torchvision.datasets 以及我们自己的自定义 Dataset 类来加载食物图像,然后我们将构建一个 PyTorch 计算机视觉模型,希望对三种物体进行分类。
renhai
2023/11/24
1.1K0
05-PyTorch自定义数据集Datasets、Loader和tranform
20 | 使用PyTorch完成医疗图像识别大项目:编写训练模型代码
在之前的环节,我们已经能够读取数据,并且构建了我们的Dataset类,处理了数据中各种异常情况,并把数据转换成PyTorch可以处理的样子。一般来说,到了这一步就开始训练模型了。先不要考虑模型的效果,也不用做什么优化,先把模型训练跑通,看一下我们的效果,这样这个结果就可以作为baseline,然后再考虑优化的事情,每进行一步优化,就可以看到它对比基线有没有效果上的提升。话不多说,我们这就来搞一个模型。
机器学习之禅
2022/07/11
1.1K0
20 | 使用PyTorch完成医疗图像识别大项目:编写训练模型代码
27 | 使用PyTorch完成医疗图像识别大项目:实现端到端模型方案
接下来需要再做一些工作,并把我们前面搞好的模型串起来,形成一个端到端的解决方案。这个方案如下,首先是从原始的CT数据出发进行图像分割,识别可能是结节的体素,并对这些体素区域进行分组,然后用这些分割出的候选结节信息进行分类,首先是区分这是否是一个结节,针对是结节的,再区分这是否是一个恶性结节,这样就完成了整个模型框架。
机器学习之禅
2022/07/11
1.6K2
27 | 使用PyTorch完成医疗图像识别大项目:实现端到端模型方案
23 | 使用PyTorch完成医疗图像识别大项目:优化数据
上一小节修改了我们的评估指标,然而效果并没有什么变化,甚至连指标都不能正常的输出出来。我们期望的是下面这种样子,安全事件都聚集在左边,危险事件都聚集在右边,中间只有少量的难以判断的事件,这样我们的模型很容易分出来,错误率也会比较低。
机器学习之禅
2022/07/11
8310
23 | 使用PyTorch完成医疗图像识别大项目:优化数据
21 | 使用PyTorch完成医疗图像识别大项目:训练模型
昨天我们已经完成了训练和验证模型的主体代码,在进行训练之前,我们还需要处理一下输出信息。前面我们已经记录了一部分信息到trnMetrics_g和valMetrics_g中,每迭代一个周期,就会输出一次结果方便我们查看。如果发现模型的结果很差,比如说出现了无法收敛的情况,我们就可以中止模型训练,不用再浪费更多时间,因为一个深度模型训练需要花费很长的时间。
机器学习之禅
2022/07/11
7211
21 | 使用PyTorch完成医疗图像识别大项目:训练模型
PyTorch 深度学习(GPT 重译)(五)
上一章的结束让我们陷入了困境。虽然我们能够将深度学习项目的机制放置好,但实际上没有任何结果是有用的;网络只是将一切都分类为非结节!更糟糕的是,结果表面看起来很好,因为我们正在查看训练和验证集中被正确分类的整体百分比。由于我们的数据严重倾向于负样本,盲目地将一切都视为负面是我们的模型快速得分的一种简单而快速的方法。太糟糕了,这样做基本上使模型无用!
ApacheCN_飞龙
2024/03/21
1560
PyTorch 深度学习(GPT 重译)(五)
24 | 使用PyTorch完成医疗图像识别大项目:图像分割数据准备
本周有点丧,前面几天不是忙于面试就是忙于塞尔达炸鱼,一直没更新,好在这周把这本书读完了,今天再更一篇,终于快要结束了。
机器学习之禅
2022/07/11
1.6K0
24 | 使用PyTorch完成医疗图像识别大项目:图像分割数据准备
26 | 使用PyTorch完成医疗图像识别大项目:分割模型实训
安装完之后,首先读取原来的标注文件。这个文件里记录了1000多个结节的坐标和直径信息。
机器学习之禅
2022/07/11
9250
26 | 使用PyTorch完成医疗图像识别大项目:分割模型实训
22 | 使用PyTorch完成医疗图像识别大项目:模型指标
今天又是相对轻松的一节。今天我们来研究一下评估模型的指标问题。前两节我们已经把模型训练完了,并且能够在TensorBoard上面查看我们的迭代效果。但是模型的效果实在是不如人意,哪怕我已经把全部的数据都加进去了,但是模型也只能学会把类别都归为非节点。
机器学习之禅
2022/07/11
8890
22 | 使用PyTorch完成医疗图像识别大项目:模型指标
手把手教你训练自己的Mask R-CNN图像实例分割模型(PyTorch官方教程)
关于Mask R-CNN的详细理论说明,可以参见原作论文https://arxiv.org/abs/1703.06870,网上也有大量解读的文章。本篇博客主要是参考了PyTorch官方给出的训练教程,将如何在自己的数据集上训练Mask R-CNN模型的过程记录下来,希望能为感兴趣的读者提供一些帮助。
全栈程序员站长
2022/09/23
4K1
手把手教你训练自己的Mask R-CNN图像实例分割模型(PyTorch官方教程)
收藏 | PyTorch Cookbook:常用代码段集锦
链接 | https://zhuanlan.zhihu.com/p/59205847
AI算法修炼营
2020/06/03
7300
收藏 | PyTorch Cookbook:常用代码段集锦
语义分割:最简单的代码实现!
分割对于图像解释任务至关重要,那就不要落后于流行趋势,让我们来实施它,我们很快就会成为专业人士!
小白学视觉
2022/02/14
1.2K0
语义分割:最简单的代码实现!
在PyTorch中使用DeepLabv3进行语义分割的迁移学习
当我在使用深度学习进行图像语义分割并想使用PyTorch在DeepLabv3[1]上运行一些实验时,我找不到任何在线教程。并且torchvision不仅没有提供分割数据集,而且也没有关于DeepLabv3类内部结构的详细解释。然而,我是通过自己的研究进行了现有模型的迁移学习,我想分享这个过程,这样可能会对你们有帮助。
deephub
2020/12/24
1.5K0
Transformers 4.37 中文文档(四)
www.youtube-nocookie.com/embed/KWwzcmG98Ds
ApacheCN_飞龙
2024/06/26
4040
Transformers 4.37 中文文档(四)
深度学习黑客竞赛神器:基于PyTorch图像特征工程的深度学习图像增强
在深度学习黑客竞赛中表现出色的技巧(或者坦率地说,是任何数据科学黑客竞赛) 通常归结为特征工程。当您获得的数据不足以建立一个成功的深度学习模型时,你能发挥多少创造力?
磐创AI
2020/06/05
9780
18 | 使用PyTorch完成医疗图像识别大项目:理解数据
上一节我们理解了业务,也就是我们这个项目到底要做什么事情,并定好了一个方案。这一节我们就开始动手了,动手第一步就是把数据搞清楚,把原始数据搞成我们可以用PyTorch处理的样子。这个数据不同于我们之前用的图片数据,像之前那种RGB图像拿过来做一些简单的预处理就可以放进tensor中,这里的医学影像数据预处理部分就要复杂的多。比如说怎么去把影像数据导入进来,怎么转换成我们能处理的形式;数据可能存在错误,给定的结节位置和实际的坐标位置有偏差;数据量太大我们不能一次性加载怎么处理等等。今天理解数据这部分处理的就是之前整个项目框架图的第一步,关于数据加载的问题。
机器学习之禅
2022/07/11
1.8K1
18 | 使用PyTorch完成医疗图像识别大项目:理解数据
Unet网络实现叶子病虫害图像分割
智能化农业作为人工智能应用的重要领域,对较高的图像处理能力要求较高,其中图像分割作为图像处理方法在其中起着重要作用。图像分割是图像分析的关键步骤, 在复杂的自然背景下进行图像分割, 难度较大。
AI科技大本营
2021/09/03
2K0
PyTorch 人工智能研讨会:6~7
本章扩展了循环神经网络的概念。 您将了解循环神经网络(RNN)的学习过程以及它们如何存储内存。 本章将介绍长短期记忆(LSTM)网络架构,该架构使用短期和长期存储器来解决数据序列中的数据问题。 在本章的最后,您将牢固地掌握 RNN 以及如何解决自然语言处理(NLP)数据问题。
ApacheCN_飞龙
2023/04/27
1.6K0
轻松学Pytorch –Mask-RCNN图像实例分割
前面介绍了torchvison框架下Faster-RCNN对象检测模型使用与自定义对象检测的数据集制作与训练。在计算机视觉所要面对的任务中,最常见的就是对象检测、图像语义分割跟实例分割,torchvision支持Mask-RCNN模型的调用与自定义数据训练,可以同时实现对象检测与实例分割任务。本文主要跟大家分享一下如何使用mask-rcnn网络实现对象检测与实例分割,下一篇将会介绍如何制作数据集训练Mask-RCNN网络。
OpenCV学堂
2020/08/17
2.5K0
PyTorch 深度学习(GPT 重译)(四)
第 2 部分的结构与第 1 部分不同;它几乎是一本书中的一本书。我们将以几章的篇幅深入探讨一个单一用例,从第 1 部分学到的基本构建模块开始,构建一个比我们迄今为止看到的更完整的项目。我们的第一次尝试将是不完整和不准确的,我们将探讨如何诊断这些问题,然后修复它们。我们还将确定我们解决方案的各种其他改进措施,实施它们,并衡量它们的影响。为了训练第 2 部分中将开发的模型,您将需要访问至少 8 GB RAM 的 GPU,以及数百 GB 的可用磁盘空间来存储训练数据。
ApacheCN_飞龙
2024/03/21
3390
PyTorch 深度学习(GPT 重译)(四)
推荐阅读
相关推荐
05-PyTorch自定义数据集Datasets、Loader和tranform
更多 >
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档
本文部分代码块支持一键运行,欢迎体验
本文部分代码块支持一键运行,欢迎体验