nn.LSTM
PyTorch LSTM API文档
?...torch.randn(10, 3, 100) # 一个句子10个单词,送进去3条句子,每个单词用一个100维的vector表示
out, (h, c) = lstm(x)
print(out.shape...h.shape, c.shape)
# torch.Size([10, 3, 20]) torch.Size([4, 3, 20]) torch.Size([4, 3, 20])
nn.LSTMCell
PyTorch...和RNNCell类似,输入input_size的shape是[batch, input_size],输出$h_t$和$c_t$的shape是[batch, hidden_size]
看个一层的LSTM的例子...x:
h, c = cell(xt, [h, c])
print(h.shape, c.shape) # torch.Size([3, 20]) torch.Size([3, 20])
两层的LSTM