首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >问答首页 >快速累积和幂算子

快速累积和幂算子
EN

Stack Overflow用户
提问于 2019-04-18 11:14:23
回答 2查看 1.3K关注 0票数 0

我有一种预测算法,它使用以下代码处理时间序列在给定范围内的趋势:

代码语言:javascript
运行
复制
import numpy as np
horizon = 91
phi = 0.2
trend = -0.004
trend_up_to_horizon = np.cumsum(phi ** np.arange(horizon) + 1) * self.trend

在本例中,前两个trend_up_horizon值是:

代码语言:javascript
运行
复制
array([-0.008 , -0.0128])

是否有一种计算速度更快的方法来实现这一点?目前,这需要很长时间,因为我猜使用np.cumsum方法和**运算符是很昂贵的。

谢谢你的帮助

EN

回答 2

Stack Overflow用户

回答已采纳

发布于 2019-04-18 14:05:34

你可以用Cython来使它更快一点,但它并不多。

在基本的%timeit上运行np.cumsum(phi ** np.arange(horizon) + 1) * trend表示,在我的笔记本电脑上需要17.5秒,这并不多

与此等效的Cython版本如下:

代码语言:javascript
运行
复制
import numpy as np
cimport numpy as np
cimport cython

@cython.boundscheck(False)
def do_cumsum(size_t horizon, double phi, double trend):
    cdef np.ndarray[double, ndim=1] out = np.empty(horizon, dtype=np.float)
    cdef double csum = 0
    cdef int i

    for i in range(horizon):
        csum += phi ** i + 1
        out[i] = csum * trend

    return out

这将do_cumsum(horizon, phi, trend)的时间减少到6.9秒,而如果切换到单精度/32位浮点数,则下降到4.5秒。

尽管如此,微秒并不多,你最好把精力集中在其他地方

票数 2
EN

Stack Overflow用户

发布于 2019-04-19 16:40:45

您可以更快地完成这个操作。正如您已经假定的,(不必要的)电源运算符是这里的主要问题。

此外,Numpy没有power的特殊实现(float64,int64),其中指数是一个小的正整数。相反,Numpy总是计算功率(float64,float64),这是一个更复杂的任务。

Numba为简单的case power(float64,int64)提供了一个特殊的实现,因此让我们在第一步中尝试这一点。

优先逼近

代码语言:javascript
运行
复制
import numpy as np
import numba as nb

horizon = 91
phi = 0.2
trend = -0.004

@nb.njit()
def pow_cumsum(horizon,phi,trend):
    out=np.empty(horizon)
    csum=0.
    for i in range(horizon):
        csum+=phi**i+1
        out[i]=csum*trend
    return out

如前所述,在直接计算功率之前,可以重写该算法,以完全避免这种情况。

第二次逼近

代码语言:javascript
运行
复制
@nb.njit()
def pow_cumsum_2(horizon,phi,trend):
    out=np.empty(horizon)

    out[0]=2.*trend
    TMP=2.
    val=phi
    for i in range(horizon-1):
        TMP=(val+1+TMP)
        out[i+1]=TMP*trend
        val*=phi
    return out

时间

代码语言:javascript
运行
复制
%timeit np.cumsum(phi ** np.arange(horizon) + 1) * trend
7.44 µs ± 89.1 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
%timeit pow_cumsum(horizon,phi,trend)
882 ns ± 4.91 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
%timeit pow_cumsum_2(horizon,phi,trend)
559 ns ± 3.8 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
票数 2
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/55744822

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档