文档中心 智能钛机器学习平台 最佳实践 用 PyTorch 实现图像识别

用 PyTorch 实现图像识别

最近更新时间:2019-11-22 15:30:10

案例背景

PyTorch 是一个开源的深度学习框架,对深度学习算法进行训练和优化,它被广泛应用在人工智能领域。更多详细介绍可参考 PyTorch 官网

本案例使用的场景是计算机视觉领域基本任务之一:手写数字识别,输入一张手写数字的图像,然后识别图像中手写的是哪个数字(0 - 9)。

本文通过智能钛机器学习平台提供的 PyTorch 框架搭建一个简单的神经网络模型实现 MNIST 手写数字识别。通过本文的学习,您可以掌握如下操作:

  • 如何在智能钛机器学习平台使用 PyTorch 框架。
  • 如何上传数据集至 COS 并通过代码访问。
  • 如何上传自定义代码。
  • 如何在工作流页面向自定义代码传参。
  • 如何查看代码日志/报错信息。

数据集介绍

本案例使用的 MNIST 数据集可参考 MNIST官网。该数据集由来自 250 个不同人手写的数字构成,共包含 60,000 个训练数据,10,000 个测试数据,每个数据都是一张 28 像素 * 28 像素大小的灰度图像。

部分手写数字图像示例如下:

整体流程

利用智能钛机器学习平台完成手写数字识别任务,我们需要完成以下几个步骤:

  1. 准备案例所需数据集。
  2. 用户本地准备实现手写数字识别任务的自定义代码。
  3. 利用智能钛机器学习平台提供的 PyTorch 框架运行自定义代码。
  4. 查看工作流运行状态和结果。

整体工作流示例如下:

详细流程

一、数据集和自定义代码准备

1. 数据集准备
为方便用户操作,我们将本案例所需数据集 MNIST 上传到公共访问路径下,用户在代码中可直接通过公共路径访问该数据集。

注意:

上传到公共存储桶中文件的访问路径格式为/cos_public/XXXX,上传到个人存储桶中文件的访问路径为cos_person/XXXX

2. 代码准备
本案例使用 PyTorch 框架搭建简单的卷积神经网络来完成手写数字图像识别的任务。
为方便用户直接进行后续工作流的搭建,本文提供案例源代码 mnist.py 供用户直接下载体验。
在 mnist.py 源代码中,访问公共 COS 下 MNIST 数据集的示例代码片段如下,访问路径为:/cos_public/mnist

二、利用 PyTorch 框架运行自定义代码

  1. 在智能钛控制台的左侧导航栏中,选择【框架】>【深度学习】>【PyTorch】,并拖入画布中。
  2. 右键【PyTorch】,选择【重命名】,输入新名称:手写数字识别,单击【确定】。
  3. 单击【手写数字识别】,在右侧弹出的配置栏中配置框架参数。
    • 单击【程序脚本】,在自动弹出的【资源列表】中上传本地的用户自定义代码(此处您可直接下载本文提供的源代码 mnist.py 到本地,然后上传到工作流进行案例学习)。
    • 依赖包文件:若使用本文提供的自定义代码,则此处无需填写。此处用于上传在“程序脚本”代码中需依赖包的压缩文件,注意此处上传的压缩包为一级目录下的压缩,否则可能导致无法读取依赖包的报错。
    • 程序依赖:若使用本文提供的自定义代码,则此处无需填写。
    • 程序参数:若使用本文提供的自定义代码,则可直接拷贝以下参数信息。此处提供给用户来指定自定义代码中所需参数的取值,格式为:--参数名[空格]取值。
      --batch-size 64
      --test-batch-size 1000
      --epochs 10
      --lr 0.01
      --momentum 0.5

以上列表中的参数对应源代码 mnist.py 中以下部分:(用户可参考此处格式配置自定义代码中的参数)

注意:

若自定义代码中,用户给未给参数命名,则可在代码中可通过默认参数 args[0] 读取用户填写的第一个取值,args[1] 读取第二个取值,以此类推。

  1. 配置资源参数,用户可直接选择平台提供的默认值,也可根据自身代码调整资源分配。
  2. 运行工作流
    单击画布左上角的【运行】,即可开始运行工作流,待运行成功(运行大概需要 4 min)。

三、查看工作流运行状态和结果

  1. 右键【手写数字识别】,单击【PyTorch 控制台】可查看该工作流运行相关日志。
  2. 在弹框中, 选择单击 【stdout.log】即可在日志中查看手写数字识别任务的训练过程和测试结果。

    本案例实验结果展示了:手写数字识别任务一共训练了10个 Epoch,并详细输出了每个 Epoch 过程中各 batch_size 数据下的损失值 Loss 变化过程。在第10个 Epoch 训练后,模型在测试数据集上取得最佳准确率 98%。