一步一步带你实现PyTorch手写字符识别

sefd

手写字符识别是图像处理的一个经典入门案例,今天我们将用pytorch来编程实现,这里我们用的是一种非常常用的统计方法Logistic回归,它能帮助我们从一组自变量中预测二进制输出。 现在,我们将了解如何在PyTorch中实现这一点,PyTorch是一个非常流行的深度学习库,由Facebook开发。

现在,我们将看到如何使用PyTorch中的Logistic回归对MNIST数据集中的手写数字进行分类。 首先,您需要将PyTorch安装到Python环境中。 最简单的方法是使用pip或conda工具。 访问pytorch.org并安装您想要使用的Python解释器版本和包管理器。

实现步骤

一、安装PyTorch后,导入重新安装的库函数和对象。

这里,torch.nn模块包含模型所需的代码,torchvision.datasets包含MNIST数据集。 它包含我们将在这里使用的手写数字的数据集。 torchvision.transforms模块包含将对象转换为其他对象的各种方法。 在这里,我们将使用它从图像转换为PyTorch张量。 此外,torch.autograd模块包含Variable类以及其他类,我们将在定义我们的张量时使用它。

二、下载并导入数据集

三、参数定义

在我们的数据集中,图像大小为28 * 28。 因此,我们的输入大小是784.此外,这里有10位数字,因此,我们可以有10个不同的输出。 因此,我们将num_classes设置为10.此外,我们将在整个数据集上训练五次。 最后,我们将分别训练小批量的100张图像,以防止因内存溢出而导致程序崩溃。

四、定义我们的模型如下

我们将我们的模型初始化为torch.nn.Module的子类,然后定义前向传递。 在我们编写的代码中,softmax在每次正向传递期间内部计算,因此我们不需要在forward()函数内指定它。

五、初始化

六、设置了损失函数和优化器

我们将使用交叉熵损失,对于优化器,我们将使用随机梯度下降算法,其学习率为0.001,如上面的超参数中所定义。

七、现在,我们将开始训练。

将所有梯度重置为0

向前传播计算输出值

计算损失

执行反向传播

更新所有重量

八、最后我们将使用以下代码测试模型

假设您正确执行了所有步骤,您将获得82%的准确度,这与当今最先进的模型相差甚远,后者使用了一种特殊类型的神经网络架构。

完整代码:

  • 发表于:
  • 原文链接https://kuaibao.qq.com/s/20181220A06PH400?refer=cp_1026
  • 腾讯「云+社区」是腾讯内容开放平台帐号(企鹅号)传播渠道之一,根据《腾讯内容开放平台服务协议》转载发布内容。

扫码关注云+社区

领取腾讯云代金券