前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >Pytorch学习-GRU使用

Pytorch学习-GRU使用

作者头像
致Great
发布2021-12-24 19:15:52
6750
发布2021-12-24 19:15:52
举报
文章被收录于专栏:程序生活
代码语言:javascript
复制
import torch.nn as nn
import torch

# gru = nn.GRU(input_size=50, hidden_size=50, batch_first=True)
# embed = nn.Embedding(3, 50)
# x = torch.LongTensor([[0, 1, 2]])
# x_embed = embed(x)
# out, hidden = gru(x_embed)


gru = nn.GRU(input_size=5, hidden_size=6,
             num_layers=2,  # gru层数
             batch_first=False,  # 默认参数 True:(batch, seq, feature) False:True:( seq,batch, feature),
             bidirectional=False,  # 默认参数
             )

# N=batch size
# L=sequence length
# D=2 if bidirectional=True else 1
# Hin=input size
# Hout=outout size


input_ = torch.randn(1, 3, 5)  # (L,N,hin)(序列长度,batch size大小,输入维度大小)
h0 = torch.randn(2 * 1, 3, 6)  # (D∗num_layers,N,Hout)(是否双向乘以层数,batch size大小,输出维度大小)

output, hn = gru(input_, h0)
# output:[1, 3, 6] (L,N,D*Hout)=(1,3,1*6)
# hn:[2, 3, 6] (D*num_layers,N,Hout)(1*2,3,6)

print(output.shape, hn.shape)
# torch.Size([1, 3, 6]) torch.Size([2, 3, 6])
本文参与 腾讯云自媒体同步曝光计划,分享自作者个人站点/博客。
原始发表:2021/12/22下,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档