前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >理解Pytorch中LSTM的输入输出参数含义

理解Pytorch中LSTM的输入输出参数含义

作者头像
marsggbo
发布2020-06-12 10:01:32
5.2K0
发布2020-06-12 10:01:32
举报
  • 这个输出tensor包含了LSTM模型最后一层每个time step的输出特征,比如说LSTM有两层,那么最后输出的是
[h^1_0,h^1_1,...,h^1_l]

,表示第二层LSTM每个time step对应的输出。

代码语言:txt
复制
- 另外如果前面你对输入数据使用了`torch.nn.utils.rnn.PackedSequence`,那么输出也会做同样的操作编程packed sequence。
- 对于unpacked情况,我们可以对输出做如下处理来对方向作分离`output.view(seq_len, batch, num_directions, hidden_size)`, 其中前向和后向分别用0和1表示Similarly, the directions can be separated in the packed case.

3、 代码示例

代码语言:javascript
复制
rnn = nn.LSTM(10, 20, 2) # 一个单词向量长度为10,隐藏层节点数为20,LSTM有2层
input = torch.randn(5, 3, 10) # 输入数据由3个句子组成,每个句子由5个单词组成,单词向量长度为10
h0 = torch.randn(2, 3, 20) # 2:LSTM层数*方向 3:batch 20: 隐藏层节点数
c0 = torch.randn(2, 3, 20) # 同上
output, (hn, cn) = rnn(input, (h0, c0))

print(output.shape, hn.shape, cn.shape)

>>> torch.Size([5, 3, 20]) torch.Size([2, 3, 20]) torch.Size([2, 3, 20])

参考:

MARSGGBO♥原创

如有意合作或学术讨论欢迎私戳联系~

邮箱:marsggbo@foxmail.com

2019-12-31 10:41:09

本文参与 腾讯云自媒体分享计划,分享自作者个人站点/博客。
原始发表:2019-12-31 ,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 1、举个栗子
  • 2、Pytorch源代码参数理解
    • 2.1 LSTM模型参数含义
      • 2.2 输入数据
        • 2.3 输出数据
        • 3、 代码示例
        相关产品与服务
        批量计算
        批量计算(BatchCompute,Batch)是为有大数据计算业务的企业、科研单位等提供高性价比且易用的计算服务。批量计算 Batch 可以根据用户提供的批处理规模,智能地管理作业和调动其所需的最佳资源。有了 Batch 的帮助,您可以将精力集中在如何分析和处理数据结果上。
        领券
        问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档