首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >问答首页 >运行GRU模型时问题;缺少forward()的参数

运行GRU模型时问题;缺少forward()的参数
EN

Stack Overflow用户
提问于 2020-10-29 03:17:47
回答 1查看 106关注 0票数 0

我正在编写一个GRU,当我试图进行预测时,我会得到一个错误,表明我需要为forward()定义h。在搜索和搜索堆栈溢出数小时后,我尝试了几次,并失去了耐心。

这是一堂课:

代码语言:javascript
运行
复制
class GRUNet(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, n_layers, drop_prob = 0.2):
        super(GRUNet, self).__init__()
        self.hidden_dim = hidden_dim
        self.n_layers = n_layers
        
        self.gru = nn.GRU(input_dim, hidden_dim, n_layers, batch_first=True, dropout=drop_prob)
        self.fc = nn.Linear(hidden_dim, output_dim)
        self.relu = nn.ReLU()
    
    def forward(self, x, h):
        out, h = self.gru(x,h)
        out = self.fc(self.relu(out[:,-1]))
        return out, h
    
    def init_hidden(self, batch_size):
        weight = next(self.parameters()).data
        hidden = weight.new(self.n_layers, batch_size, self.hidden_dim).zero_().to(device)
        return hidden

然后,我在这里加载模型,并尝试进行预测。这两个都在同一个脚本中。

代码语言:javascript
运行
复制
inputs = np.load('.//Pred//input_list.npy')  
print(inputs.ndim, inputs.shape)
Gmodel = GRUNet(24,256,1,2)
Gmodel = torch.load('.//GRU//GRU_1028_48.pkl')
Gmodel.eval()
pred = Gmodel(inputs)

在不使用任何其他参数的情况下,我得到了以下内容:

代码语言:javascript
运行
复制
Traceback (most recent call last):
  File ".\grunet.py", line 136, in <module>
    pred = Gmodel(inputs)
  File "C:\Users\ryang\Anaconda-3\envs\tf-gpu\lib\site-packages\torch\nn\modules\module.py", line 547, in __call__
    result = self.forward(*input, **kwargs)
TypeError: forward() missing 1 required positional argument: 'h'
EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2020-10-29 08:58:07

您还需要提供隐藏状态,这通常是最初的所有零或简单的None

也就是说,您要么需要显式地提供这样的一个:

代码语言:javascript
运行
复制
hidden_state = torch.zeros(size=(num_layers*direction, batch_size, hidden_dim)).to(device)
pred = Gmodel(inputs, hidden_state)

或者简单地做:

代码语言:javascript
运行
复制
hidden_state = None 
pred = Gmodel(inputs, hidden_state)
票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/64584493

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档