专栏首页相约机器人如何构建PyTorch项目

如何构建PyTorch项目

作者 | Branislav Hollander 来源 | Medium

编辑 | 代码医生团队

自从开始训练深度神经网络以来,一直在想所有Python代码的结构是什么。理想情况下,良好的结构应支持对该模型进行广泛的试验,允许在一个紧凑的框架中实现各种不同的模型,并且每个阅读代码的人都容易理解。必须能够通过编码和重用各种数据加载器来使用来自不同数据源的数据。此外,如果模型支持在一个模型中组合多个网络(例如GAN或原始R-CNN的情况),那就太好了。该框架还应该具有足够的灵活性以允许进行复杂的可视化(这是在数据科学中的核心信念之一,即可视化使一切变得更加容易,尤其是在计算机视觉任务的情况下)。

深度学习框架的详细实现当然取决于正在使用的基础库,无论是TensorFlow,PyTorch还是CNTK。在这篇文章中,将介绍基于PyTorch的方法。但是,认为一般结构同样适用于使用的任何库。可以在链接找到整个存储库。

https://gitlab.com/branislav.hollander/pytorchprojectframework

总体结构

深度学习框架的项目结构

在上图(取自Python编辑器VS代码)上,可以看到为框架创建的常规文件夹结构。该框架由一些启动脚本(train.py,validate.py,hyperopt.py)以及隐藏在文件夹中的库组成。该数据集文件夹中包含加载各种类型的数据的类和方法的训练。损失的文件夹可以包含附加的功能损失或验证指标。如果项目不需要任何自定义损失函数,则可能不需要此文件夹。该机型的文件夹是最重要的:它包含实际的模型。该优化的文件夹包括自定义优化程序的代码。与losss文件夹一样,如果没有任何自定义优化器,则可以省略此文件夹。最后,utils文件夹包含整个框架使用的各种实用程序,最著名的是visualizer。还将注意到项目根文件夹中的config_segmentation.json文件。该文件包含训练所需的所有配置选项。

可能已经猜到了,可以通过调用train.py脚本来启动训练。使用适当的配置文件作为命令行参数调用此脚本。它负责所有高级训练内容,例如加载训练和验证数据集和模型,设置可视化,运行训练循环以及最后导出训练后的模型。

通过调用适当的脚本并将配置文件作为参数来传递,从而类似地使用验证。

数据集

以2D分割数据集为例,数据集文件夹中的文件。

在上图中,可以看到数据集文件夹的结构。它包含__init__.py模块,该模块包含一些用于查找和创建正确数据集的必要功能,以及一个自定义数据加载器,该数据加载器会将数据转发到训练管道(有关此的更多信息,请查看PyTorch API文档)。顾名思义,base_dataset.py在框架中定义的每个数据集定义了抽象基类。

https://pytorch.org/tutorials/beginner/data_loading_tutorial.html

对于定义的每个自定义数据集,都必须实现__getitem__和__len__方法,以便PyTorch可以对其进行迭代。不再需要处理DataLoader,因为它已经在datasets / __ init__.py中定义。还可以为每个时期之前和之后要调用的数据集定义自定义回调。如果要使用某种预热方法,该方法可以在前几个时期将不同的数据馈送到模型,然后再切换到更复杂的数据集,则这可能会很有用。

要实例化数据集,train.py脚本调用以下代码:

print(‘Initializing dataset…’)
train_dataset =
   create_dataset(configuration[‘train_dataset_params’])
train_dataset_size = len(train_dataset)
print(‘The number of training samples = {0}’.format(train_dataset_size))

这将调用create_dataset函数,该函数查看配置文件并根据其名称选择正确的数据集。命名数据集时,务必遵循惯例<datasetname> _dataset.py,因为这是脚本能够根据配置文件中的字符串查找数据集的方式。最后,以上脚本在数据集上调用len()函数,以告知您其大小。

模型

以细分模型为例,models文件夹中的文件。

框架中的模型与数据集的工作方式相同:__init__.py模块包含用于根据其模块名称和配置文件中定义的字符串查找和创建正确模型的函数。模型类本身继承自抽象BaseModel类,并且必须实现两个方法:

  • 向前(自己)运行前向预测。
  • 训练通过后,optimize_parameters(self)可以修改网络的权重。

所有其他方法都可能被覆盖,或者可以使用默认的BaseClass实现。可能要覆盖的功能包括pre_epoch_callback和post_epoch_callback(在每个时期之前和之后调用)或测试(在验证期间调用)。

为了正确使用框架,了解如何使用网络,优化器和模型中的损失非常重要。由于模型中可能有多个使用不同优化器的网络以及多个不同的损失(例如,可能希望显示语义本地化模型的边界框分类和回归损失),因此界面要涉及更多一点。具体来说,需要提供损失名称和网络名称以及BaseModel类的优化程序,以了解如何训练模型。在提供的代码中,包括2D细分模型的示例以及示例数据集,以供了解应如何使用框架。

看一下提供的2D分割模型的__init __()函数:

class Segmentation2DModel(BaseModel):
   def __init__(self, configuration):
      super().__init__(configuration)
      self.loss_names = [‘segmentation’]
      self.network_names = [‘unet’]
      self.netunet = UNet(1, 2)
      self.netunet = self.netunet.to(self.device)
      if self.is_train: # only defined during training time
         self.criterion_loss = torch.nn.CrossEntropyLoss()
         self.optimizer = torch.optim.Adam(self.netunet.parameters(), lr=configuration[‘lr’])
         self.optimizers = [self.optimizer]

这就是这里发生的事情:首先,阅读模型配置。然后,定义“分段”损失并将其放入self.loss_names列表中。损失的名称很重要,因为将变量self.loss_segmentation用于损失。通过知道名称,BaseModel可以查找丢失并将其打印在控制台中或将其可视化。同样,定义网络的名称。这可以确保BaseModel知道如何训练模型而无需明确定义它。接下来,初始化网络(在本例中为U-Net)并将其移至GPU。如果处于训练模式,还将定义损失标准并实例化优化器(在本例中为Adam)。最后,将优化器放入self.optimizers列表中。此列表再次在BaseModel类中使用,以更新学习率或从给定的检查点恢复训练。

看一下forward()和optimize_parameters()函数:

def forward(self):
   self.output = self.netunet(self.input)
def backward(self):
   self.loss_segmentation = self.criterion_loss(self.output, self.label)
def optimize_parameters(self):
   self.loss_segmentation.backward() # calculate gradients
   self.optimizer.step()
   self.optimizer.zero_grad()

如您所见,这是标准的PyTorch代码:它唯一的职责是在网络本身上调用forward(),在计算出梯度后将优化器步进,并将其再次置零。为自己的模型实施此操作应该很容易。

可视化

utils文件夹中的文件

可视化可以在Visualizer类中找到。此类负责将损失信息打印到终端,并使用visdom库可视化各种结果。它在训练脚本的开头进行初始化(将加载visdom服务器)。训练脚本还调用其plot_current_losses()和print_current_losses()函数以可视化并写出训练损失。它还包含诸如plot_current_validation_metrics(),plot_roc_curve()和show_validation_images()之类的函数,这些函数不会自动调用,但可以从post_epoch_callback()模型中调用在验证时进行一些有用的可视化。试图使可视化工具保持一般性。当然,可以自己扩展可视化器的功能,使其对您更有用。

https://github.com/facebookresearch/visdom

结论

提出了一种编写通用的深度学习框架的方法,该框架可用于深度学习的所有领域。通过使用此结构,将获得清晰灵活的代码库,以进行进一步的开发。当然,有许多替代方法可以解决该问题。

本文分享自微信公众号 - 相约机器人(xiangyuejiqiren),作者:代码医生

原文出处及转载信息见文内详细说明,如有侵权,请联系 yunjia_community@tencent.com 删除。

原始发表时间:2019-10-23

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

我来说两句

0 条评论
登录 后参与评论

相关文章

  • 使用PyTorch进行表格数据的深度学习

    使用表格数据进行深度学习的最简单方法是通过fast-ai库,它可以提供非常好的结果,但是对于试图了解幕后实际情况的人来说,它可能有点抽象。因此在本文中,介绍了如...

    代码医生工作室
  • 在点云上进行深度学习:在Google Colab中实现PointNet

    3D数据对于自动驾驶汽车,自动驾驶机器人,虚拟现实和增强现实至关重要。与以像素阵列表示的2D图像不同,它可以表示为多边形网格,体积像素网格,点云等。

    代码医生工作室
  • 结合知识图谱实现基于电影的推荐系统

    知识图谱(Knowledge Graph,KG)可以理解成一个知识库,用来存储实体与实体之间的关系。知识图谱可以为机器学习算法提供更多的信息,帮助模型更好地完成...

    代码医生工作室
  • 论文式编程

    文学编程(Literate programming)的一些概念,上个世纪 70 年代就有人提出来了。

    py3study
  • 小甲鱼《零基础学习Python》课后笔记(三十七):类和对象——面向对象编程

    1.当程序员不想把同一段代码写几次,他们发明了函数解决了这种情况。当程序员已经有了一个类,而又想建立一个非常接近的新类,他们会怎么做呢? 定义一个新类继承已有...

    小火柴棒
  • Simplex 单纯形算法的python

    算法可以在给定一个包含线性规划问题的标准形式的描述下,求解该线性规划问题。 例如某一个 pro.txt 文件内容如下:

    py3study
  • Leetcode 684. 冗余连接(dsu,氵)

    输入一个图,该图由一个有着N个节点 (节点值不重复1, 2, ..., N) 的树及一条附加的边构成。附加的边的两个顶点包含在1到N中间,这条附加的边不属于树中...

    glm233
  • Python面向对象编程

    而解决这一问题的比较有效的方法之一就是数据隐藏,即编码过程中尽可能的隐藏内部的实现细节。

    讲编程的高老师
  • Python开发植物大战僵尸游戏

    ------------------- End -------------------

    python学习教程
  • 【Code】关于 GCN,我有三种写法

    本篇文章主要基于 DGL 框架用三种不同的方式来实现图卷积神经网络。手机看可能不太方便,可以点击阅读原文,移步到知乎上看(但是我忘了加 = =)。

    阿泽 Crz

扫码关注云+社区

领取腾讯云代金券