前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >手写数字识别基本思路

手写数字识别基本思路

作者头像
算法与编程之美
发布2023-08-22 13:11:09
2690
发布2023-08-22 13:11:09
举报
文章被收录于专栏:算法与编程之美

问题

什么是MNIST?如何使用Pytorch实现手写数字识别?如何进行手写数字对模型进行检验?

方法

mnist数据集

MNIST数据集是美国国家标准与技术研究院收集整理的大型手写数字数据集,包含了60,000个样本的训练集以及10,000个样本的测试集。

使用Pytorch实现手写数字识别

1.进行数据预处理对于MNIST数据集,可以通过torchvision中的datasets进行下载。

root (string):表示数据集的根目录,其中根目录存在MNIST/processed/training.pt和MNIST/processed/test.pt的子目录。

train (bool, optional):如果为True,则从training.pt创建数据集,否则从test.pt创建数据集。

download (bool, optional):如果为True,则从internet下载数据集并将其放入根目录。如果数据集已下载,则不会再次下载。

transform (callable, optional):接收PIL图片并返回转换后版本图片的转换函数。

bat = 128transform = transforms.Compose([ transforms.ToTensor(), transforms.Normalize(0.1307, 0.3081) # (均值,方差)]) # Compoes 两个操作合为一个train_ds = datasets.MNIST(root='data', download=False, train=True, transform=transform)train_ds, val_ds = torch.utils.data.random_split(train_ds, [50000, 10000])test_ds = datasets.MNIST(root='data', download=True, train=False, transform=transform)train_loader = DataLoader(dataset=train_ds, batch_size=bat, shuffle=True)val_loader = DataLoader(dataset=val_ds, batch_size=bat)test_loader = DataLoader(dataset=test_ds, batch_size=bat)

2.构建模型

class MyNet(nn.Module): def __init__(self) -> None: super().__init__() self.flatten = nn.Flatten() # 将28*28的图像拉伸为784维向量 # 第一个全连接层Full Connection(FC) self.fc1 = nn.Linear(in_features=784, out_features=256) self.fc2 = nn.Linear(in_features=256, out_features=128) self.fc3 = nn.Linear(in_features=128, out_features=10) def forward(self, x): x = self.flatten(x) x = torch.relu(self.fc1(x)) x = torch.relu(self.fc2(x)) out = torch.relu(self.fc3(x)) return out

构建一个三层的神经网络MNIST数据集中的图片都是28×28大小的,而且是灰度图。而全连接神经网络的输入要是一个行向量,所以我们要把28×28的矩阵转换成28×28=764的行向量,作为神经网络的输入

3.优化器的选择,参数设置

使用优化器和损失函数。优化器选择SGD,SGD随机梯度下降,lr学习率取值0.2最优,momentum用于加速SGD在某一方向上的搜索以及抑制震荡的发生。

optimizer=torch.optim.SGD(net.parameters(),lr=0.2)#lr学习率,momentum用于加速SGD在某一方向上的搜索以及抑制震荡的发生#损失函数#衡量y与y_hat之间的差异loss_fn=nn.CrossEntropyLoss()

4.对模型进行训练测试,网络的输入,输入尺寸B*C*H*W B是batch,一个batch一个batch交给网络处理,x=torch.rand(size=(128,1,28,28)),基于loss信息利用优化器从后向前更新网络全部参数。

def train(dataloader, net, loss_fn, optimizer, epoch): size = len(dataloader.dataset) corrent = 0 epoch_loss = 0.0 batch_num = len(dataloader) net.train() # 一个batch一个batch的训练网络 for batch_idx, (X, y) in enumerate(dataloader): pred = net(X) # 衡量y与y_hat之间的loss # y:128, pred:128x10 CrossEntropyloss loss = loss_fn(pred, y) # 基于loss信息利用优化器从后向前更新网络全部参数 <--- optimizer.zero_grad() loss.backward() optimizer.step() epoch_loss += loss.item() corrent += (pred.argmax(1) == y).type(torch.float).sum().item() if batch_idx % 100 == 0: # f-string print(f'[{batch_idx + 1:>5d}/{batch_num + 1:>5d}],loss:{loss.item()}') avg_loss = epoch_loss / batch_num avg_accuracy = corrent / size # loss_list.append(avg_loss) return avg_accuracy, avg_lossdef test(dataloader, net, loss_fn): size = len(dataloader.dataset) batch_num = len(dataloader) corrent = 0 losses = 0 net.eval() with torch.no_grad(): for X, y in test_loader: pred = net(X) correct = (pred.argmax(1) == y).type(torch.int).sum().item() # print(y.size(0)) # print(correct) corrent += correct accuracy = corrent / size avg_loss = losses / batch_num return accuracy, avg_loss

5.保存最优的模型

net.load_state_dict(torch.load('model_best.pth')) test(test_loader,net,loss_fn)

6.读入自己的写入数字,进行识别

model = MyNet()model.load_state_dict(torch.load('model_best.pth'))img = Image.open("7.png").convert("L") # 转为灰度图像img = transform(img)# img = np.array(img)# print(img)result = model(img)_, predict = torch.max(result.data, dim=1)print(result)print("the result is:",predict.item())

结语

minist是一个28*28的图像,所以输入就是28*28=784的维度,输出为10,0-9十个数字。手写数字识别首先需要初始化全局变量,构建数据集。然后构建模型,构建迭代器与损失函数,进行训练测试。最后可以将训练的模型进行保存,通过读取自己写的数字进行识别验证,完成一个简单的深度学习。

本文参与 腾讯云自媒体同步曝光计划,分享自微信公众号。
原始发表:2023-05-02,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 算法与编程之美 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档