我正在写一个程序,其中在某些点上使用了Scipy CubicSpline例程,因为使用了Scipy例程,所以我不能在我的整个程序上使用Numba @jit。
我最近遇到了@overload功能,我想知道它是否可以这样使用,
from numba.extending import overload
from numba import jit
from scipy.interpolate import CubicSpline
import numpy as np
x = np.arange(10)
y = np.sin(x)
xs = np.arange(-0.5, 9.6, 0.1)
def Spline_interp(xs,x,y):
cs = CubicSpline(x, y)
ds = cs(xs)
return ds
@overload(Spline_interp)
def jit_Spline_interp(xs,x,y):
ds = Spline_interp(xs,x,y)
def jit_Spline_interp_impl(xs,x, y):
return ds
return jit_Spline_interp_impl
@jit(nopython=True)
def main():
# other codes compatible with @njit
ds = Spline_interp(xs,x,y)
# other codes compatible with @njit
return ds
print(main())
如果我对@overload特性的理解是错误的,请纠正我,以及在Numba中使用这样的Scipy库的可能解决方案是什么。
发布于 2019-12-30 11:34:53
您可能需要回退到object-mode (本地,就像建议的@max9111 ),或者自己在Numba中实现CubicSpline
函数。
据我所知,重载修饰器"only“使编译器知道,如果它遇到重载函数,它可以使用与Numba兼容的实现。它不会神奇地将函数转换为Numba兼容。
有一个包向Numba公开了一些Scipy功能,但这似乎还为时过早,到目前为止只包含一些scipy.special函数。
https://stackoverflow.com/questions/59476751
复制相似问题