首页
学习
活动
专区
工具
TVP
发布
精选内容/技术社群/优惠产品,尽在小程序
立即前往

PyTorch深度学习开发医学影像端到端判别项目(内附资料)

PyTorch深度学习开发医学影像端到端判别项目

download:https://www.sisuoit.com/3888.html

模型输入是 128×128 的图像,练习集大概有 122k 张图片,校验集大概有 22k 张图片。

经历1:对 Loss 的处理

通常,在练习过程中,我们都是将 loss 的增加到一个 list 里保存。记住在保存前,先 detach,然后仅运用其数值。否则,你增加的就不仅仅是 loss,而是整个核算图。

正确用法:

123loss = F.mse_loss(prd, true)epoch_loss += loss.detach().item()training_log.append(epoch_loss)

过错用法:

123loss = F.mse_loss(prd, true)epoch_loss += losstraining_log.append(epoch_loss)

经历2:将模型、输入、输出加载到 CUDA

避免机器内存的暴升,记得把模型和从 dataloader 读取的输入数据放到 CUDA 里再运用。

正确用法:

123456model = MyModel()model = model.to(device)for batch_idx, (x,y) in enumerate(train_loader): x = x.to(device) y = y.to(device) prd = model(x)

过错用法:

123model = MyModel()for batch_idx, (x,y) in enumerate(train_loader): prd = model(x)

经历3:运用废物收回

Python 在内存废物收回方面做的可能不太好,不必的变量往往不会被当即收回。要做到当即收回,最好在每个练习循环里加入下面的代码:

12import gcgc.collect()

这个带来的效果可能微乎其微,但能够保证高效的废物收回。

经历4:DataLoader 的 worker 数量不是越多越好

如果你运用了多个 worker 读取数据,记住这个数并不是越多越好。很多的 worker 可能会由于进程协作的问题或者 IO 的问题而拖慢速度。

正确用法:

1train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_worker = [一个合理的数字])

过错用法:

1train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_worker = [一个特别大的数,比如 50,100 等])

关于多少 worker 才是最合适的,能够参考官方论坛的评论帖:地址。

经历5:把数据保存在 Numpy Array 里,而不是 List 里

这个问题的官方评论在这里:链接,关于解决方案引证如下:

Python lists store only references to the objects. The objects are kept separately in memory. Every object has a refcount, therefore every item in the list has a refcount.

Numpy arrays (of standard np types) are stored as continuous blocks in memory and are only ONE object with one refcount.

This changes if you make the NumPy array explicitly of type object, which makes it start behaving like a regular Python list (only storing references to (string) objects). The same “problems” with memory consumption now appear.”

所以,如果的在 DataLoader 中的数据是保存在 list 里的,记得用 np.array(x) 转换成 Numpy Array。

  • 发表于:
  • 原文链接https://page.om.qq.com/page/OGdyuwm0ulvuSNnJ34rYLTow0
  • 腾讯「腾讯云开发者社区」是腾讯内容开放平台帐号(企鹅号)传播渠道之一,根据《腾讯内容开放平台服务协议》转载发布内容。
  • 如有侵权,请联系 cloudcommunity@tencent.com 删除。

扫码

添加站长 进交流群

领取专属 10元无门槛券

私享最新 技术干货

扫码加入开发者社群
领券