前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >如何构建PyTorch项目

如何构建PyTorch项目

作者头像
代码医生工作室
发布2019-10-28 18:31:40
1.8K0
发布2019-10-28 18:31:40
举报
文章被收录于专栏:相约机器人
作者 | Branislav Hollander 来源 | Medium

编辑 | 代码医生团队

自从开始训练深度神经网络以来,一直在想所有Python代码的结构是什么。理想情况下,良好的结构应支持对该模型进行广泛的试验,允许在一个紧凑的框架中实现各种不同的模型,并且每个阅读代码的人都容易理解。必须能够通过编码和重用各种数据加载器来使用来自不同数据源的数据。此外,如果模型支持在一个模型中组合多个网络(例如GAN或原始R-CNN的情况),那就太好了。该框架还应该具有足够的灵活性以允许进行复杂的可视化(这是在数据科学中的核心信念之一,即可视化使一切变得更加容易,尤其是在计算机视觉任务的情况下)。

深度学习框架的详细实现当然取决于正在使用的基础库,无论是TensorFlow,PyTorch还是CNTK。在这篇文章中,将介绍基于PyTorch的方法。但是,认为一般结构同样适用于使用的任何库。可以在链接找到整个存储库。

https://gitlab.com/branislav.hollander/pytorchprojectframework

总体结构

深度学习框架的项目结构

在上图(取自Python编辑器VS代码)上,可以看到为框架创建的常规文件夹结构。该框架由一些启动脚本(train.py,validate.py,hyperopt.py)以及隐藏在文件夹中的库组成。该数据集文件夹中包含加载各种类型的数据的类和方法的训练。损失的文件夹可以包含附加的功能损失或验证指标。如果项目不需要任何自定义损失函数,则可能不需要此文件夹。该机型的文件夹是最重要的:它包含实际的模型。该优化的文件夹包括自定义优化程序的代码。与losss文件夹一样,如果没有任何自定义优化器,则可以省略此文件夹。最后,utils文件夹包含整个框架使用的各种实用程序,最著名的是visualizer。还将注意到项目根文件夹中的config_segmentation.json文件。该文件包含训练所需的所有配置选项。

可能已经猜到了,可以通过调用train.py脚本来启动训练。使用适当的配置文件作为命令行参数调用此脚本。它负责所有高级训练内容,例如加载训练和验证数据集和模型,设置可视化,运行训练循环以及最后导出训练后的模型。

通过调用适当的脚本并将配置文件作为参数来传递,从而类似地使用验证。

数据集

以2D分割数据集为例,数据集文件夹中的文件。

在上图中,可以看到数据集文件夹的结构。它包含__init__.py模块,该模块包含一些用于查找和创建正确数据集的必要功能,以及一个自定义数据加载器,该数据加载器会将数据转发到训练管道(有关此的更多信息,请查看PyTorch API文档)。顾名思义,base_dataset.py在框架中定义的每个数据集定义了抽象基类。

https://pytorch.org/tutorials/beginner/data_loading_tutorial.html

对于定义的每个自定义数据集,都必须实现__getitem__和__len__方法,以便PyTorch可以对其进行迭代。不再需要处理DataLoader,因为它已经在datasets / __ init__.py中定义。还可以为每个时期之前和之后要调用的数据集定义自定义回调。如果要使用某种预热方法,该方法可以在前几个时期将不同的数据馈送到模型,然后再切换到更复杂的数据集,则这可能会很有用。

要实例化数据集,train.py脚本调用以下代码:

代码语言:javascript
复制
print(‘Initializing dataset…’)
train_dataset =
   create_dataset(configuration[‘train_dataset_params’])
train_dataset_size = len(train_dataset)
print(‘The number of training samples = {0}’.format(train_dataset_size))

这将调用create_dataset函数,该函数查看配置文件并根据其名称选择正确的数据集。命名数据集时,务必遵循惯例<datasetname> _dataset.py,因为这是脚本能够根据配置文件中的字符串查找数据集的方式。最后,以上脚本在数据集上调用len()函数,以告知您其大小。

模型

以细分模型为例,models文件夹中的文件。

框架中的模型与数据集的工作方式相同:__init__.py模块包含用于根据其模块名称和配置文件中定义的字符串查找和创建正确模型的函数。模型类本身继承自抽象BaseModel类,并且必须实现两个方法:

  • 向前(自己)运行前向预测。
  • 训练通过后,optimize_parameters(self)可以修改网络的权重。

所有其他方法都可能被覆盖,或者可以使用默认的BaseClass实现。可能要覆盖的功能包括pre_epoch_callback和post_epoch_callback(在每个时期之前和之后调用)或测试(在验证期间调用)。

为了正确使用框架,了解如何使用网络,优化器和模型中的损失非常重要。由于模型中可能有多个使用不同优化器的网络以及多个不同的损失(例如,可能希望显示语义本地化模型的边界框分类和回归损失),因此界面要涉及更多一点。具体来说,需要提供损失名称和网络名称以及BaseModel类的优化程序,以了解如何训练模型。在提供的代码中,包括2D细分模型的示例以及示例数据集,以供了解应如何使用框架。

看一下提供的2D分割模型的__init __()函数:

代码语言:javascript
复制
class Segmentation2DModel(BaseModel):
   def __init__(self, configuration):
      super().__init__(configuration)
      self.loss_names = [‘segmentation’]
      self.network_names = [‘unet’]
      self.netunet = UNet(1, 2)
      self.netunet = self.netunet.to(self.device)
      if self.is_train: # only defined during training time
         self.criterion_loss = torch.nn.CrossEntropyLoss()
         self.optimizer = torch.optim.Adam(self.netunet.parameters(), lr=configuration[‘lr’])
         self.optimizers = [self.optimizer]

这就是这里发生的事情:首先,阅读模型配置。然后,定义“分段”损失并将其放入self.loss_names列表中。损失的名称很重要,因为将变量self.loss_segmentation用于损失。通过知道名称,BaseModel可以查找丢失并将其打印在控制台中或将其可视化。同样,定义网络的名称。这可以确保BaseModel知道如何训练模型而无需明确定义它。接下来,初始化网络(在本例中为U-Net)并将其移至GPU。如果处于训练模式,还将定义损失标准并实例化优化器(在本例中为Adam)。最后,将优化器放入self.optimizers列表中。此列表再次在BaseModel类中使用,以更新学习率或从给定的检查点恢复训练。

看一下forward()和optimize_parameters()函数:

代码语言:javascript
复制
def forward(self):
   self.output = self.netunet(self.input)
def backward(self):
   self.loss_segmentation = self.criterion_loss(self.output, self.label)
def optimize_parameters(self):
   self.loss_segmentation.backward() # calculate gradients
   self.optimizer.step()
   self.optimizer.zero_grad()

如您所见,这是标准的PyTorch代码:它唯一的职责是在网络本身上调用forward(),在计算出梯度后将优化器步进,并将其再次置零。为自己的模型实施此操作应该很容易。

可视化

utils文件夹中的文件

可视化可以在Visualizer类中找到。此类负责将损失信息打印到终端,并使用visdom库可视化各种结果。它在训练脚本的开头进行初始化(将加载visdom服务器)。训练脚本还调用其plot_current_losses()和print_current_losses()函数以可视化并写出训练损失。它还包含诸如plot_current_validation_metrics(),plot_roc_curve()和show_validation_images()之类的函数,这些函数不会自动调用,但可以从post_epoch_callback()模型中调用在验证时进行一些有用的可视化。试图使可视化工具保持一般性。当然,可以自己扩展可视化器的功能,使其对您更有用。

https://github.com/facebookresearch/visdom

结论

提出了一种编写通用的深度学习框架的方法,该框架可用于深度学习的所有领域。通过使用此结构,将获得清晰灵活的代码库,以进行进一步的开发。当然,有许多替代方法可以解决该问题。

本文参与 腾讯云自媒体同步曝光计划,分享自微信公众号。
原始发表:2019-10-23,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 相约机器人 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档