前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >LSTM使用MNIST手写数字识别实战的代码和心得

LSTM使用MNIST手写数字识别实战的代码和心得

原创
作者头像
flykiss
修改2021-09-14 10:29:32
1.2K0
修改2021-09-14 10:29:32
举报
文章被收录于专栏:学习新东西学习新东西

RNN的架构除了RNN类中的模型不同,其他的构架与CNN类似,如果还没有阅读过CNN文章的可以点击下方链接进入: CNN使用MNIST手写数字识别实战的代码和心得

LSTM(Long Short-Term Memory长短时记忆网络)虽然在MNIST手写数字识别方面不擅长,但是也可以进行使用,效果比CNN略显逊色

对LSTM使用MNIST手写数字识别的思路图

undefined

LSTM是在RNN的主线基础上增加了支线,增加了三个门,输入门,输出门和忘记门。

避免了可能因为加权问题,使程序忘记之前的内容,梯度弥散或者梯度爆炸。

batch_size在这里选取的是100,选择了一个隐藏层和128的神经元,对LSTM结构进行部署,

MNIST长宽为28,选取一行28作为一份数据传入input_size,RNN是按照时间序列进行传值,batch_size为100,也就是在每次传入的数据为(128,28)

进入隐藏层后,out结果张量的shape为(100, 28, 128)

在out:, -1, :时间序列中取得最后一次的输出,得到(100, 128)

再进入全连接层后将hidden_size的128变为所需要的输出的10种图片的维度(100, 10)

对超参数的定义

代码语言:txt
复制
#定义超参数
代码语言:txt
复制
input_size = 28
代码语言:txt
复制
time_step = 28# 时间序列
代码语言:txt
复制
Layers = 1# 隐藏单元的个数
代码语言:txt
复制
hidden_size = 128# 每个隐藏单元中神经元个数
代码语言:txt
复制
classes = 10
代码语言:txt
复制
batch_size = 100
代码语言:txt
复制
EPOCHS = 10
代码语言:txt
复制
learning_rate = 0.01 #学习率

RNN对于数据的读取有别于CNN,按照时间来读取,在这里可以将input_size看作是图片的长,而time_step看作宽的长度。

代码语言:txt
复制
    #Long Short-Term Memory(长短时记忆网络)
    class RNN(nn.Module):
        def __init__(self, input_size, hidden_size, Layers, classes):
            super(RNN, self).__init__()
            self.Layers = Layers
            self.hidden_size = hidden_size
            self.lstm = nn.LSTM(input_size, hidden_size, Layers, batch_first=True)
            self.fc = nn.Linear(hidden_size, classes)
        def forward(self, x):
            # 设置初始隐藏状态和单元格状态
            h0 = torch.zeros(self.Layers, x.size(0), self.hidden_size).to(device)
            c0 = torch.zeros(self.Layers, x.size(0), self.hidden_size).to(device)
            # out张量的shape(batch_size, time_step, hidden_size)
            out, _ = self.lstm(x, (h0, c0))#torch.Size([100, 28, 128])
            #out[:, -1, :].shape torch.Size([100, 128])
            # 只得到时间顺序点的最后一步
            out = self.fc(out[:, -1, :])#torch.Size([100, 10])
            return out
            ```
运行结果:
```python
代码语言:txt
复制
RNN(
代码语言:txt
复制
  (lstm): LSTM(28, 128, batch_first=True)
代码语言:txt
复制
  (fc): Linear(in_features=128, out_features=10, bias=True)
代码语言:txt
复制
)
代码语言:txt
复制
Epoch [10/10],  Loss: 0.0115
代码语言:txt
复制
Test Accuracy to test: 98.07 %
代码语言:txt
复制

undefined

undefined

原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。

如有侵权,请联系 cloudcommunity@tencent.com 删除。

原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。

如有侵权,请联系 cloudcommunity@tencent.com 删除。

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
相关产品与服务
NLP 服务
NLP 服务(Natural Language Process,NLP)深度整合了腾讯内部的 NLP 技术,提供多项智能文本处理和文本生成能力,包括词法分析、相似词召回、词相似度、句子相似度、文本润色、句子纠错、文本补全、句子生成等。满足各行业的文本智能需求。
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档