首页
学习
活动
专区
工具
TVP
发布
社区首页 >问答首页 >使用numba加速代码?

使用numba加速代码?

提问于 2022-10-18 14:35:26
回答 0关注 0查看 313
代码语言:javascript
复制
import laspy
from scipy import spatial
import time
from osgeo import ogr
from osgeo import osr
from osgeo import gdal
from numba import jit,uint
import numpy as np
import open3d as o3d
import math
from osgeo import osr
import pandas as pd

las_file = laspy.file.File('H:\whu\IEEE_GRSS_Titan\plot3and5/plot3_space_feature - Cloud.las', mode='r')
x, y, z = las_file.x, las_file.y, las_file.z



img=gdal.Open('H:\whu\IEEE_GRSS_Titan\IEEE_Contest\plot3_HSI\plot3_warp1.dat')
icols=img.RasterXSize
irows=img.RasterYSize

geotransform = img.GetGeoTransform()
ix_min=geotransform[0]
iy_max=geotransform[3]
xlen=geotransform[1]
ylen=geotransform[5]
ix_max=ix_min+xlen*icols
iy_min=iy_max+ylen*irows
#读取48个波段
#读取成三维数组
bandNumber = img.RasterCount-2
v = [[[0] * bandNumber] * icols] * irows
v=np.zeros((bandNumber,irows,icols))#(维数,行数,列数)
#
for bandIndex in range(bandNumber):
    band = img.GetRasterBand(bandIndex+1)  # 取波段
    datas = band.ReadAsArray(0, 0, icols, irows)
    for row in range(irows):
        for col in range(icols):
            v[bandIndex][row][col] = datas[row][col]
v = v.astype(int)
### 空间分辨率设置
## get extent
x_max, y_max = las_file.header.max[0:2]
x_min, y_min = las_file.header.min[0:2]

pixelWidth, pixelHeight = 1, 1               # 分辨率 1m*1m
### calculate
xOrigin = x_min
yOrigin = y_min

cols = int(math.ceil((x_max - x_min) / pixelWidth))
rows = int(math.ceil((y_max - y_min) / abs(pixelHeight)))

##给每个点云赋予行列号
xOffset = (x - xOrigin) / pixelWidth
xOffset = xOffset.astype(int)
xOffset[np.where(xOffset == 595)]=594
yOffset = (y - yOrigin) / pixelHeight
yOffset = yOffset.astype(int)
yOffset[np.where(yOffset == 600)]=599



#对每个点云进行循环,通过行列索引取光谱值
s=time.time()
print("开始循环时间:",s)
a1 = np.zeros([bandNumber], dtype=np.int32)
@jit(uint[:,:](uint[:,:],uint[:,:],uint[:,:,:]))
def traversal_point(a,x,v):

    for i in range(len(x)):
        b=v[:,yOffset[i],xOffset[i]]
        #b.reshape(1,bandNumber)
        a=np.vstack((a,b))

    a=a[1:,:]
    return a

aa=traversal_point(a1,x,v)
ss=time.time()
print("循环用时:",ss-s)
#
#合并xyz
x=x.reshape(-1,1)
y=y.reshape(-1,1)
z=z.reshape(-1,1)
xyz=np.hstack((x,y,z))

final=np.hstack((xyz,aa))
aaa=pd.DataFrame(data=final,columns=['x','y','z','374.4nm','388.7nm','403.1','417.4','431.7','446.1','460.4','474.7','489.0',
                                    '503.4','517.7','532.0','546.3','560.6','574.9','589.2','603.6','617.9','632.2','646.5',
                                    '660.8','675.1','689.4','703.7','718.0','732.3','746.6','760.9','775.2','789.5','803.8',
                                    '818.1','832.4','846.7','861.1','875.4','889.7','904.0','918.3','932.6','946.9','961.2',
                                    '975.5','989.8','1004.2','1018.5','1032.8','1047.1'])
aaa.to_csv("高光谱点云plot3.csv",index=False)

Compilation is falling back to object mode WITH looplifting enabled because Function "traversal_point" failed type inference due to: No implementation of function Function(<function vstack at 0x0000019FF74D28C8>) found for signature:

>>> vstack(Tuple(array(uint32, 2d, A), array(uint32, 1d, A)))

There are 2 candidate implementations:

- Of which 2 did not match due to:

Overload in function 'vstack': File: numba\core\typing\npydecl.py: Line 836.

With argument(s): '(Tuple(array(uint32, 2d, A), array(uint32, 1d, A)))':

Rejected as the implementation raised a specific error:

TypeError: np.vstack(): all the input arrays must have same number of dimensions

raised from C:\Users\Wa Oh\AppData\Roaming\Python\Python36\site-packages\numba\core\typing\npydecl.py:760

During: resolving callee type: Function(<function vstack at 0x0000019FF74D28C8>)

During: typing of call at H:/whu/python_code/python_test/fusion_plot3.py (76)

File "fusion_plot3.py", line 76:

def traversal_point(a,x,v):

<source elided>

#b.reshape(1,bandNumber)

a=np.vstack((a,b))

^

@jit(uint[:,:](uint[:,:],uint[:,:],uint[:,:,:]))

H:/whu/python_code/python_test/fusion_plot3.py:70: NumbaWarning:

Compilation is falling back to object mode WITHOUT looplifting enabled because Function "traversal_point" failed type inference due to: Cannot determine Numba type of <class 'numba.core.dispatcher.LiftedLoop'>

File "fusion_plot3.py", line 73:

def traversal_point(a,x,v):

<source elided>

for i in range(len(x)):

^

@jit(uint[:,:](uint[:,:],uint[:,:],uint[:,:,:]))

C:\Users\Wa Oh\AppData\Roaming\Python\Python36\site-packages\numba\core\object_mode_passes.py:152: NumbaWarning: Function "traversal_point" was compiled in object mode without forceobj=True, but has lifted loops.

File "fusion_plot3.py", line 73:

def traversal_point(a,x,v):

<source elided>

for i in range(len(x)):

^

state.func_ir.loc))

C:\Users\Wa Oh\AppData\Roaming\Python\Python36\site-packages\numba\core\object_mode_passes.py:162: NumbaDeprecationWarning:

Fall-back from the nopython compilation path to the object mode compilation path has been detected, this is deprecated behaviour.

For more information visit https://numba.pydata.org/numba-doc/latest/reference/deprecation.html#deprecation-of-object-mode-fall-back-behaviour-when-using-jit

File "fusion_plot3.py", line 73:

def traversal_point(a,x,v):

<source elided>

for i in range(len(x)):

^

state.func_ir.loc))

H:/whu/python_code/python_test/fusion_plot3.py:70: NumbaWarning:

Compilation is falling back to object mode WITHOUT looplifting enabled because Function "traversal_point" failed type inference due to: Cannot unify array(int32, 1d, C) and array(int32, 2d, C) for 'a', defined at H:/whu/python_code/python_test/fusion_plot3.py (73)

File "fusion_plot3.py", line 73:

def traversal_point(a,x,v):

<source elided>

for i in range(len(x)):

^

During: typing of assignment at H:/whu/python_code/python_test/fusion_plot3.py (76)

File "fusion_plot3.py", line 76:

def traversal_point(a,x,v):

<source elided>

#b.reshape(1,bandNumber)

a=np.vstack((a,b))

^

@jit(uint[:,:](uint[:,:],uint[:,:],uint[:,:,:]))

C:\Users\Wa Oh\AppData\Roaming\Python\Python36\site-packages\numba\core\object_mode_passes.py:152: NumbaWarning: Function "traversal_point" was compiled in object mode without forceobj=True.

File "fusion_plot3.py", line 73:

def traversal_point(a,x,v):

<source elided>

for i in range(len(x)):

^

state.func_ir.loc))

C:\Users\Wa Oh\AppData\Roaming\Python\Python36\site-packages\numba\core\object_mode_passes.py:162: NumbaDeprecationWarning:

Fall-back from the nopython compilation path to the object mode compilation path has been detected, this is deprecated behaviour.

For more information visit https://numba.pydata.org/numba-doc/latest/reference/deprecation.html#deprecation-of-object-mode-fall-back-behaviour-when-using-jit

File "fusion_plot3.py", line 73:

def traversal_point(a,x,v):

<source elided>

for i in range(len(x)):

^

state.func_ir.loc))

回答

和开发者交流更多问题细节吧,去 写回答
相关文章

相似问题

相关问答用户
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档