首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >如何在Hydra中使用OmegaConf自定义插值

如何在Hydra中使用OmegaConf自定义插值
EN

Stack Overflow用户
提问于 2020-05-10 23:53:10
回答 1查看 1.1K关注 0票数 2

如何在Hydra中使用OmegaConf custom interpolation

一些背景知识:可以为平方根定义自定义插值:

代码语言:javascript
复制
from omegaconf import OmegaConf
import math
OmegaConf.register_resolver("sqrt", lambda x: math.sqrt(float(x)))

并将其与以下config.yaml一起使用:

代码语言:javascript
复制
foo: ${sqrt:9}

加载和打印foo:

代码语言:javascript
复制
cfg = OmegaConf.load('config.yaml')
print(cfg.foo)

输出3.0

在使用Hydra尝试此操作时:

代码语言:javascript
复制
import hydra

@hydra.main(config_path="config.yaml")
def main(cfg):
  print(cfg.foo)

if __name__ == "__main__":
  main()

我得到以下错误:

代码语言:javascript
复制
Unsupported interpolation type sqrt
    full_key: foo
    reference_type=Optional[Dict[Any, Any]]
    object_type=dict

如何在使用Hydra时注册我的解析器?

EN

回答 1

Stack Overflow用户

发布于 2020-05-12 02:27:33

您可以提前注册您的自定义解析器:

config.yaml:

代码语言:javascript
复制
foo: ${sqrt:9}

main.py:

代码语言:javascript
复制
from omegaconf import OmegaConf
import math
import hydra

OmegaConf.register_new_resolver("sqrt", lambda x: math.sqrt(float(x)))

@hydra.main(config_path=".", config_name="config")
def main(cfg):
  print(cfg.foo)

if __name__ == "__main__":
  main()

这将打印3.0。

这种方法在Compose API上也能很好地工作。自定义解析器的评估发生在您访问节点时(懒惰地)。你只需要在访问解析器之前注册它。

票数 10
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/61714759

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档