首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
社区首页 >问答首页 >如何将火炬损耗与模型参数连接起来?

如何将火炬损耗与模型参数连接起来?
EN

Stack Overflow用户
提问于 2022-08-20 02:26:43
回答 1查看 108关注 0票数 0

我知道在PyTorch中,优化器是通过

代码语言:javascript
代码运行次数:0
运行
复制
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

在训练循环中,我们必须向后执行,并通过执行这两行来更新梯度。

代码语言:javascript
代码运行次数:0
运行
复制
loss.backward()
optimizer.step()

但是,损失是如何与模型参数相关联的呢?因为我们只定义优化器和模型之间的连接,而从不定义损失和模型之间的关联。

当我们执行loss.backward()时,PyTorch如何知道我们将为我们的model做反向传播

我把完整的代码放在这里作为上下文

代码语言:javascript
代码运行次数:0
运行
复制
import torch
import torch.nn as nn

X = torch.tensor([[1], [2], [3], [4]], dtype=torch.float32)
Y = torch.tensor([[2], [4], [6], [8]], dtype=torch.float32)
X_test = torch.tensor([[5]], dtype=torch.float32)

n_sample, n_feature = X.shape
input_size = n_feature
output_size = n_feature

model = nn.Linear(input_size, output_size)

# Training
learning_rate = 0.01
n_iters = 100

loss = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

# print(model(X_test))
print(f"Prediction before training f(5) = {model(X_test).item():.3f}")

for epoch in range(n_iters):
  y_pred = model(X)

  # compute loss
  l = loss(Y, y_pred)

  # gradient
  l.backward()

  # update gradient
  optimizer.step()

  # zero gradient
  optimizer.zero_grad()

  if epoch % 10 == 0:
    w, b = model.parameters()
    # print(model.parameters())
    print(f"Epoch {epoch + 1}, w = {w[0][0].item():.3f}, loss = {l:.5f}")

print(f"Prediction after training f(5) = {model(X_test).item():.3f}")
EN

回答 1

Stack Overflow用户

发布于 2022-08-20 06:01:40

问:当我们执行loss.backward()时,PyTorch如何知道我们将为我们的模型做反向传播?

l = loss(Y, y_pred)行中,预测用于计算损失。这就有效地将模型参数与损失连接起来,使得loss.backward()可以对网络进行反向传播,计算参数梯度。注意,model()中的张量具有requires_grad=True,而不需要渐变的标签则不是这样。通过l.backward(),每个进入损失计算并需要一个梯度(在我们的例子中是模型参数)的张量值都被分配一个梯度。有关文档属性,请参阅grad

问:但是损失是如何与模型参数相关联的呢?

语句optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)连接优化器和模型参数。由于通过loss.backward()计算的梯度成为模型参数的属性,所以优化器可以访问它们。

票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/73423703

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档