首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >问答首页 >RLLib模型中传递自定义模型参数的正确方法?

RLLib模型中传递自定义模型参数的正确方法?
EN

Stack Overflow用户
提问于 2020-07-13 16:23:00
回答 2查看 2.5K关注 0票数 1

我有一个基本的自定义模型,它本质上只是默认的RLLib完全连接模型(https://github.com/ray-project/ray/blob/master/rllib/models/tf/fcnet.py)的复制粘贴,并且我通过一个配置文件传递自定义模型参数,其中包含一个"custom_model_config": {}字典。此配置文件如下所示:

代码语言:javascript
运行
复制
# Custom RLLib model
custom_model: test_model

# Custom options
custom_model_config:
  ## Default fully connected network settings
  # Nonlinearity for fully connected net (tanh, relu)
  "fcnet_activation": "tanh"
  # Number of hidden layers for fully connected net
  "fcnet_hiddens": [256, 256]
  # For DiagGaussian action distributions, make the second half of the model
  # outputs floating bias variables instead of state-dependent. This only
  # has an effect is using the default fully connected net.
  "free_log_std": False
  # Whether to skip the final linear layer used to resize the hidden layer
  # outputs to size `num_outputs`. If True, then the last hidden layer
  # should already match num_outputs.
  "no_final_linear": False
  # Whether layers should be shared for the value function.
  "vf_share_layers": True

  ## Additional settings
  # L2 regularization value for fully connected layers
  "l2_reg_value": 0.1

当我使用此设置启动培训过程时,RLLib向我发出以下警告:

自定义ModelV2应该将所有自定义选项接受为**kwargs,而不是在config‘Custom _config’中期待它们!

我理解**kwargs的功能,但我不知道如何使用自定义的RLLib模型来修复这个警告。有什么想法吗?

EN

回答 2

Stack Overflow用户

发布于 2021-01-19 16:00:01

TL;DR:将**customized_model_kwargs添加到您的网络__init__中,然后从该配置中获得自定义配置。

我会解释你该怎么做才能避免这个警告。

当您使用自定义网络时,您肯定是在使用以下内容:

代码语言:javascript
运行
复制
policy.target_q_model = ModelCatalog.get_model_v2(
        obs_space=obs_space,
        action_space=action_space,
        num_outputs=1,
        model_config=config["model"],
        framework="torch",
        name=Q_TARGET_SCOPE)

这个模型是由Ray实例化的(参见ModelCatalog 模块/ray/rllib/model/Catal.html):

代码语言:javascript
运行
复制
instance = model_cls(obs_space, action_space, num_outputs,
                                         model_config, name,
                                         **customized_model_kwargs)

因此,您应该像这样声明您的网络:

代码语言:javascript
运行
复制
  def __init__(self, obs_space: gym.spaces.Space,
               action_space: gym.spaces.Space, num_outputs: int,
               model_config: ModelConfigDict, name: str, **customized_model_kwargs):
    TorchModelV2.__init__(self, obs_space, action_space, num_outputs,
                          model_config, name)
    nn.Module.__init__(self)

注意customized_model_kwargs arg。

然后,您可以使用customized_model_kwargs["your_key"]访问您的自定义信任。

注: TF的情况类似。

票数 1
EN

Stack Overflow用户

发布于 2021-11-15 21:12:02

您可以通过设置"custom_model_config"传递自定义模型参数,这是模型配置的一部分。默认情况下它是空的。

来自文档

代码语言:javascript
运行
复制
# Name of a custom model to use
"custom_model": None,
# Extra options to pass to the custom classes. These will be available to
# the Model's constructor in the model_config field. Also, they will be
# attempted to be passed as **kwargs to ModelV2 models. For an example,
# see rllib/models/[tf|torch]/attention_net.py.
"custom_model_config": {},

自定义模型的构造函数中有一个model_config参数。您可以通过model_config["custom_model_config"]访问模型参数。

示例:

代码语言:javascript
运行
复制
# setting custom params
config = ppo.DEFAULT_CONFIG.copy()
config["model"] = {
  "custom_model": MyModel,
  "custom_model_config": {
    "my_param": 42
  }
}
...
trainer = ppo.PPOTrainer(config=config, env=MyEnv)

内部MyModel

代码语言:javascript
运行
复制
class MyModel(TFModelV2):
  def __init__(self, obs_space, action_space, num_outputs, model_config, name, **kwargs):
    self.my_param = model_config["custom_model_config"]["my_param"]
票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/62880095

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档