首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >问答首页 >如果你有多个神经网络,PyTorch如何知道训练损失将被传播回哪一个神经网络?

如果你有多个神经网络,PyTorch如何知道训练损失将被传播回哪一个神经网络?
EN

Stack Overflow用户
提问于 2022-01-25 14:53:05
回答 1查看 200关注 0票数 0

我想在另外两个神经网络的帮助下训练一个神经网络,这些神经网络已经经过了训练和测试。我要训练的网络的输入同时输入到第一个静态网络。我想训练的网络的输出输入到第二个静态网络。损失应计算在静态网络的输出上,并传播回列车网络。

代码语言:javascript
运行
复制
# Initialization
var_model_statemapper = NeuralNetwork(9, [('linear', 9), ('relu', None), ('dropout', 0.2), ('linear', 8)])

var_model_panda = NeuralNetwork(9, [('linear', 9), ('relu', None), ('dropout', 0.2), ('linear', 27)])
var_model_panda.load_state_dict(torch.load("panda.pth"))

var_model_ur5 = NeuralNetwork(8, [('linear', 8), ('relu', None), ('dropout', 0.2), ('linear', 24)])
var_model_ur5.load_state_dict(torch.load("ur5.pth"))

var_loss_function = torch.nn.MSELoss()
var_optimizer = torch.optim.Adam(var_model_statemapper.parameters(), lr=0.001)

# Forward Propagation
var_panda_output = var_model_panda(var_statemapper_input)
var_ur5_output = var_model_ur5(var_statemapper_output)
var_train_loss = var_loss_function(var_panda_output, var_ur5_output)

# Backward Propagation
var_optimizer.zero_grad()
var_train_loss.backward()
var_optimizer.step()

你可以看到,"var_model_statemapper“是需要训练的网络。网络"var_model_panda“和"var_model_ur5”被初始化,它们的state_dicts被从相应的".pth“文件中读取,因此这些网络需要是静态的。我的主要问题是,哪一个网络是在反向传播中更新的?仅仅是"var_model_statemapper“还是所有的网络?如果"var_model_statemapper“没有更新,我该如何实现?PyTorch知道仅仅从优化器的初始化中更新哪个网络吗?

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2022-01-25 15:25:28

将管道正规化以便对设置有一个好的了解:

代码语言:javascript
运行
复制
x --- | state_mapper | --> y --- | ur5 | --> ur5_out
 \                                              |
  \                                             ↓
   \--- | panda | --> panda_out ----------- | loss_fn | --> loss

下面是您提供的行所发生的情况:

代码语言:javascript
运行
复制
var_optimizer.zero_grad()  # 0.
var_train_loss.backward()  # 1.
var_optimizer.step()       # 2.

  1. 在优化器上调用zero_grad将清除该优化器中包含的所有参数梯度的缓存。在您的示例中,您已经使用来自var_optimizer (您想要优化的模型)的参数注册了

当您通过调用推断损失并在其上反向传播时,梯度将通过所有三个模型的参数传播。

然后,

  1. 调用优化器上的step将更新在您调用它的优化器中注册的参数。在您的示例中,这意味着var_train_loss).

将单独使用步骤1中计算的梯度更新模型var_model_statemapper的所有参数(即在上使用backward调用)。

总之,当前的方法只更新var_model_statemapper的参数。理想情况下,可以通过将模型var_model_pandavar_model_ur5的参数的requires_grad标志设置为False来冻结它们。这将节省推理和训练的速度,因为它们的梯度在反向传播过程中不会被计算和存储。

票数 2
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/70850782

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档