我正在用python编写代码,使用numpy。我想优化一个这样的公式,我用了一张图片,以达到可再生的目的。

在这个例子中,时间t来自不同的列表,用上标表示。这里对应的向量是T_t,它是一个列表。这是我的原始代码:
def first_version(m, n, k, T_t, BETA):
if k == 1:
return 0
ans = 0
for i in range(len(T_t[n])):
if T_t[n][i] < T_t[m][k - 1]:
ans += (T_t[m][k - 1] - T_t[n][i]) * np.exp(-BETA[m, n] * (T_t[m][k - 1] - T_t[n][i]))
else:
break
return ans最后的休息让我有了一些时间。我有一个很好的主意,用numpy库来提高性能:
def second_version(m, n, k, T_t, BETA):
if k == 1:
return 0
the_times = np.maximum( T_t[m][k - 1] - np.array(T_t[n]) , 0 )
ans = sum(the_times * np.exp( -BETA[m, n] * the_times ))
return ans为了便于比较,第二种算法运行速度快了100倍。有可能做得更好吗?特别是,我感到遗憾的是,numpy计算整个向量的最大值时,可能有一半是0。
你知道如何改进这些代码吗?
我在nr2代码中忘记了一个和,这减慢了代码的速度,使它只快了20倍。
发布于 2020-04-29 00:44:48
我有两个主要建议:
使用second_version
sum()的速度是原来的三倍,numba.jit的速度再次提高了大约8倍。(实际上,您可以jit编译任何版本,并以相同的速度结束)。完整代码示例:
import numpy as np
import numba
import timeit
def first_version(m, n, k, T_t, BETA):
if k == 1:
return 0
ans = 0
for i in range(len(T_t[n])):
if T_t[n][i] < T_t[m][k - 1]:
ans += (T_t[m][k - 1] - T_t[n][i]) * np.exp(-BETA[m, n] * (T_t[m][k - 1] - T_t[n][i]))
else:
break
return ans
def second_version(m, n, k, T_t, BETA):
if k == 1:
return 0
the_times = np.maximum( T_t[m][k - 1] - np.array(T_t[n]) , 0 )
ans = np.sum(the_times * np.exp( -BETA[m, n] * the_times ))
return ans
def jit_version(m, n, k, T_t, BETA):
# wrapper makes it to that numba doesn't have to deal with
# the list-of-arrays data type
return jit_version_core(k, T_t[m], T_t[n], BETA[m, n])
@numba.jit(nopython=True)
def jit_version_core(k, t1, t2, b):
if k == 1:
return 0
ans = 0
for i in range(len(t2)):
if t2[i] < t1[k - 1]:
ans += (t1[k - 1] - t2[i]) * np.exp(-b * (t1[k - 1] - t2[i]))
else:
break
return ans
N = 10000
t1 = np.cumsum(np.random.random(size=N))
t2 = np.cumsum(np.random.random(size=N))
beta = np.random.random(size=(2, 2))
for fn in ['first_version', 'second_version', 'jit_version']:
print("------", fn)
v = globals()[fn](0, 1, len(t1), [t1, t2], beta)
t = timeit.timeit('%s(0, 1, len(t1), [t1, t2], beta)' % fn, number=100, globals=globals())
print("output:", v, "time:", t)以及产出:
------ first_version
output: 3.302938986817431 time: 2.900316455983557
------ second_version
output: 3.3029389868174306 time: 0.12064526398899034
------ jit_version
output: 3.302938986817431 time: 0.013476221996825188https://stackoverflow.com/questions/61491515
复制相似问题