首页
学习
活动
专区
工具
TVP
发布
社区首页 >问答首页 >使用Numba加速以下代码,主要是循环部分很慢?

使用Numba加速以下代码,主要是循环部分很慢?

提问于 2022-10-18 14:26:50
回答 0关注 0查看 35
代码语言:javascript
复制
import laspy
from scipy import spatial
import time
from osgeo import ogr
from osgeo import osr
from osgeo import gdal
from numba import jit
import numpy as np
import open3d as o3d
import math
from osgeo import osr
import pandas as pd

las_file = laspy.file.File('H:\whu\IEEE_GRSS_Titan\plot3and5/plot3_space_feature - Cloud.las', mode='r')
x, y, z = las_file.x, las_file.y, las_file.z



img=gdal.Open('H:\whu\IEEE_GRSS_Titan\IEEE_Contest\plot3_HSI\plot3_warp1.dat')
icols=img.RasterXSize
irows=img.RasterYSize

geotransform = img.GetGeoTransform()
ix_min=geotransform[0]
iy_max=geotransform[3]
xlen=geotransform[1]
ylen=geotransform[5]
ix_max=ix_min+xlen*icols
iy_min=iy_max+ylen*irows
#读取48个波段
#读取成三维数组
bandNumber = img.RasterCount-2
v = [[[0] * bandNumber] * icols] * irows
v=np.zeros((bandNumber,irows,icols))#(维数,行数,列数)
#
for bandIndex in range(bandNumber):
    band = img.GetRasterBand(bandIndex+1)  # 取波段
    datas = band.ReadAsArray(0, 0, icols, irows)
    for row in range(irows):
        for col in range(icols):
            v[bandIndex][row][col] = datas[row][col]
v = v.astype(int)
### 空间分辨率设置
## get extent
x_max, y_max = las_file.header.max[0:2]
x_min, y_min = las_file.header.min[0:2]

pixelWidth, pixelHeight = 1, 1               # 分辨率 1m*1m
### calculate
xOrigin = x_min
yOrigin = y_min

cols = int(math.ceil((x_max - x_min) / pixelWidth))
rows = int(math.ceil((y_max - y_min) / abs(pixelHeight)))

##给每个点云赋予行列号
xOffset = (x - xOrigin) / pixelWidth
xOffset = xOffset.astype(int)
xOffset[np.where(xOffset == 595)]=594
yOffset = (y - yOrigin) / pixelHeight
yOffset = yOffset.astype(int)
yOffset[np.where(yOffset == 600)]=599



#对每个点云进行循环,通过行列索引取光谱值
s=time.time()
print("开始循环时间:",s)
a1 = np.zeros([bandNumber], dtype=np.int32)
#@jit
def traversal_point(a,x,v):

    for i in range(len(x)):
        b=v[:,yOffset[i],xOffset[i]]
        #b.reshape(1,bandNumber)
        a=np.vstack((a,b))

    a=a[1:,:]
    return a

aa=traversal_point(a1,x,v)
ss=time.time()
print("循环用时:",ss-s)
#
#合并xyz
x=x.reshape(-1,1)
y=y.reshape(-1,1)
z=z.reshape(-1,1)
xyz=np.hstack((x,y,z))

final=np.hstack((xyz,aa))
aaa=pd.DataFrame(data=final,columns=['x','y','z','374.4nm','388.7nm','403.1','417.4','431.7','446.1','460.4','474.7','489.0',
                                    '503.4','517.7','532.0','546.3','560.6','574.9','589.2','603.6','617.9','632.2','646.5',
                                    '660.8','675.1','689.4','703.7','718.0','732.3','746.6','760.9','775.2','789.5','803.8',
                                    '818.1','832.4','846.7','861.1','875.4','889.7','904.0','918.3','932.6','946.9','961.2',
                                    '975.5','989.8','1004.2','1018.5','1032.8','1047.1'])
aaa.to_csv("高光谱点云plot3.csv",index=False)

回答

和开发者交流更多问题细节吧,去 写回答
相关文章

相似问题

相关问答用户
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档