首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >问答首页 >包含求和和指数的命令的优化

包含求和和指数的命令的优化
EN

Stack Overflow用户
提问于 2020-04-28 23:11:08
回答 1查看 57关注 0票数 1

我正在用python编写代码,使用numpy。我想优化一个这样的公式,我用了一张图片,以达到可再生的目的。

在这个例子中,时间t来自不同的列表,用上标表示。这里对应的向量是T_t,它是一个列表。这是我的原始代码:

代码语言:javascript
运行
复制
def first_version(m, n, k, T_t, BETA):
    if k == 1:
        return 0
    ans = 0
    for i in range(len(T_t[n])):
        if T_t[n][i] < T_t[m][k - 1]:
            ans += (T_t[m][k - 1] - T_t[n][i]) * np.exp(-BETA[m, n] * (T_t[m][k - 1] - T_t[n][i]))
        else:
            break
    return ans

最后的休息让我有了一些时间。我有一个很好的主意,用numpy库来提高性能:

代码语言:javascript
运行
复制
def second_version(m, n, k, T_t, BETA):
    if k == 1:
        return 0
    the_times = np.maximum( T_t[m][k - 1] - np.array(T_t[n]) , 0  )
    ans = sum(the_times * np.exp( -BETA[m, n] * the_times  ))
    return ans

为了便于比较,第二种算法运行速度快了100倍。有可能做得更好吗?特别是,我感到遗憾的是,numpy计算整个向量的最大值时,可能有一半是0。

你知道如何改进这些代码吗?

我在nr2代码中忘记了一个和,这减慢了代码的速度,使它只快了20倍。

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2020-04-29 00:44:48

我有两个主要建议:

使用second_version

  • Using
  • 而不是sum()的速度是原来的三倍,numba.jit的速度再次提高了大约8倍。(实际上,您可以jit编译任何版本,并以相同的速度结束)。

完整代码示例:

代码语言:javascript
运行
复制
import numpy as np
import numba
import timeit


def first_version(m, n, k, T_t, BETA):
    if k == 1:
        return 0
    ans = 0
    for i in range(len(T_t[n])):
        if T_t[n][i] < T_t[m][k - 1]:
            ans += (T_t[m][k - 1] - T_t[n][i]) * np.exp(-BETA[m, n] * (T_t[m][k - 1] - T_t[n][i]))
        else:
            break
    return ans


def second_version(m, n, k, T_t, BETA):
    if k == 1:
        return 0
    the_times = np.maximum( T_t[m][k - 1] - np.array(T_t[n]) , 0  )
    ans = np.sum(the_times * np.exp( -BETA[m, n] * the_times  ))
    return ans


def jit_version(m, n, k, T_t, BETA):
    # wrapper makes it to that numba doesn't have to deal with 
    # the list-of-arrays data type
    return jit_version_core(k, T_t[m], T_t[n], BETA[m, n])


@numba.jit(nopython=True)
def jit_version_core(k, t1, t2, b):
    if k == 1:
        return 0
    ans = 0
    for i in range(len(t2)):
        if t2[i] < t1[k - 1]:
            ans += (t1[k - 1] - t2[i]) * np.exp(-b * (t1[k - 1] - t2[i]))
        else:
            break
    return ans


N = 10000
t1 = np.cumsum(np.random.random(size=N))
t2 = np.cumsum(np.random.random(size=N))
beta = np.random.random(size=(2, 2))

for fn in ['first_version', 'second_version', 'jit_version']:
    print("------", fn)
    v = globals()[fn](0, 1, len(t1), [t1, t2], beta)
    t = timeit.timeit('%s(0, 1, len(t1), [t1, t2], beta)' % fn, number=100, globals=globals())
    print("output:", v, "time:", t)

以及产出:

代码语言:javascript
运行
复制
------ first_version
output: 3.302938986817431 time: 2.900316455983557
------ second_version
output: 3.3029389868174306 time: 0.12064526398899034
------ jit_version
output: 3.302938986817431 time: 0.013476221996825188
票数 2
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/61491515

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档