首页
学习
活动
专区
工具
TVP
发布
社区首页 >问答首页 >使用PyTorch初始化权重和偏差-如何校正尺寸?

使用PyTorch初始化权重和偏差-如何校正尺寸?
EN

Stack Overflow用户
提问于 2018-07-24 02:02:04
回答 1查看 1.4K关注 0票数 0

使用此模型,我尝试使用预定义的权重和偏差来初始化网络:

代码语言:javascript
复制
dimensions_input = 10
hidden_layer_nodes = 5
output_dimension = 10

class Model(torch.nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.linear = torch.nn.Linear(dimensions_input,hidden_layer_nodes)
        self.linear2 = torch.nn.Linear(hidden_layer_nodes,output_dimension)

        self.linear.weight = torch.nn.Parameter(torch.zeros(dimensions_input,hidden_layer_nodes))
        self.linear.bias = torch.nn.Parameter(torch.ones(hidden_layer_nodes))

        self.linear2.weight = torch.nn.Parameter(torch.zeros(dimensions_input,hidden_layer_nodes))
        self.linear2.bias = torch.nn.Parameter(torch.ones(hidden_layer_nodes))

    def forward(self, x):
        l_out1 = self.linear(x)
        y_pred = self.linear2(l_out1)
        return y_pred

model = Model()

criterion = torch.nn.MSELoss(size_average = False)
optim = torch.optim.SGD(model.parameters(), lr = 0.00001)

def train_model():
    y_data = x_data.clone()
    for i in range(10000):
        y_pred = model(x_data)
        loss = criterion(y_pred, y_data)

        if i % 5000 == 0:
            print(loss)
        optim.zero_grad()

        loss.backward()
        optim.step()

RuntimeError:

张量的扩展大小(10)必须与非单一维度1的现有大小(5)匹配

我的尺寸看起来是正确的,因为它们与相应的线性层匹配?

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2018-07-24 03:32:04

由于没有定义x_data,所以提供的代码不能运行,所以我不能确定这是不是问题所在,但是有一件事让我印象深刻,那就是您应该替换

代码语言:javascript
复制
self.linear2.weight = torch.nn.Parameter(torch.zeros(dimensions_input,hidden_layer_nodes))
self.linear2.bias = torch.nn.Parameter(torch.ones(hidden_layer_nodes))

使用

代码语言:javascript
复制
self.linear2.weight = torch.nn.Parameter(torch.zeros(hidden_layer_nodes, output_dimension))
self.linear2.bias = torch.nn.Parameter(torch.ones(output_dimension))
票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/51484793

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档