我试图用numba.guvectorize
编译函数,但遇到了索引超出范围的异常
def compute_kinetic_energy(velocity, mass, ke):
ke = 0.0
# Increase kinetic energy
for i in prange(velocity.shape[0]):
for x in range(2):
ke += velocity[i, x] * velocity[i, x] * mass[i]
在我尝试进行guvectorization之前,上面的函数运行得很完美:
compute_kinetic_energy_gu = guvectorize(['float32[:,:], float32[:], float32'],
'(nnodes,dim),(nnodes),()->()',
target='cpu',
nopython=True)(compute_kinetic_energy)
执行代码:
import numpy as np
from numba import guvectorize, prange
nnodes = 1000
mass = np.ones(nnodes, dtype=np.float32)
velocity = np.zeros(nnodes*3, dtype=np.float32).reshape(nnodes,3)
compute_kinetic_energy(velocity, mass, ke) # Works :)
compute_kinetic_energy_gu(velocity, mass, ke) # Do not work :(
发布于 2019-11-28 16:32:19
找到错误了。我有一个错误的函数签名。正确的判断是
compute_kinetic_energy_gu = guvectorize(['float32[:,:], float32[:], float32'],
'(nnodes,dim),(nnodes)->()',
target='cpu',
nopython=True)(compute_kinetic_energy)
https://stackoverflow.com/questions/59090806
复制相似问题