首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
社区首页 >问答首页 >numba和闭包中定义的变量

numba和闭包中定义的变量
EN

Stack Overflow用户
提问于 2022-10-21 17:31:41
回答 1查看 43关注 0票数 0

我需要创建一些被字典参数化的numba函数。此字典位于工厂函数的命名空间中,我希望在实际函数中使用它。问题是我得到了一个NotImplemented错误,这个问题是否有一个解决方案,甚至只是一个解决办法?

我已经将我的代码简化为这个示例:

目标裁剪函数采取:

一个选择器,它决定字典中的哪个范围,它应该使用(series)

  • a值来比较字典中的范围(在实际的应用程序中,大约有十几个这样的范围)

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
from numba.core import types
from numba.typed import Dict

dict_ranges = Dict.empty(
    key_type=types.int64,
    value_type=types.Tuple((types.float64, types.float64))
    )

dict_ranges[3] = (1, 3)

def MB_cut_factory(dict_ranges):
    def cut(series, value):
        return dict_ranges[series][0] < value < dict_ranges[series][1]
    return cut

MB_cut_factory(dict_ranges)(3,2)
True

在纯Python中,它工作得很好。用numba

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
njit(MB_cut_factory(dict_ranges))(3,2)
---------------------------------------------------------------------------
NumbaNotImplementedError                  Traceback (most recent call last)
Cell In [107], line 1
----> 1 njit(MB_cut_factory(dict_ranges))(3,2)

File ~/micromamba/envs/root/lib/python3.8/site-packages/numba/core/dispatcher.py:468, in _DispatcherBase._compile_for_args(self, *args, **kws)
    464         msg = (f"{str(e).rstrip()} \n\nThis error may have been caused "
    465                f"by the following argument(s):\n{args_str}\n")
    466         e.patch_message(msg)
--> 468     error_rewrite(e, 'typing')
    469 except errors.UnsupportedError as e:
    470     # Something unsupported is present in the user code, add help info
    471     error_rewrite(e, 'unsupported_error')

File ~/micromamba/envs/root/lib/python3.8/site-packages/numba/core/dispatcher.py:409, in _DispatcherBase._compile_for_args.<locals>.error_rewrite(e, issue_type)
    407     raise e
    408 else:
--> 409     raise e.with_traceback(None)

NumbaNotImplementedError: Failed in nopython mode pipeline (step: native lowering)
<numba.core.base.OverloadSelector object at 0x7f8c054fefd0>, (DictType[int64,UniTuple(float64 x 2)]<iv=None>,)
During: lowering "$2load_deref.0 = freevar(dict_ranges: {3: (1.0, 3.0)})" at /tmp/ipykernel_2259/3022317309.py (3)

在参数是简单类型的简单情况下,这样做很好:

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
def MB_cut_factory(limit):
    def cut(value):
        return value < limit
    return cut

MB_cut_factory(4)(3)

njit(MB_cut_factory(4))(3)
EN

回答 1

Stack Overflow用户

发布于 2022-10-24 19:52:55

我已经找到了适用于我的情况的解决方案,但使用了exec

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
def MB_cut_factory(dict_ranges):
    exec("def cut(series, value):\n    dict_ranges=" +\
         dict_ranges.__str__() +\
        "\n    return dict_ranges[series][0] < value < dict_ranges[series][1]", globals())
    return cut

MB_cut_factory(dict_ranges)(3,2)
True
njit(MB_cut_factory(dict_ranges))(3,2)
True

如果有人有一个不那么尴尬的解决这个问题,请!

票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/74160505

复制
相关文章

相似问题

添加站长 进交流群

领取专属 10元无门槛券

AI混元助手 在线答疑

扫码加入开发者社群
关注 腾讯云开发者公众号

洞察 腾讯核心技术

剖析业界实践案例

扫码关注腾讯云开发者公众号
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档
查看详情【社区公告】 技术创作特训营有奖征文