写在开头:这个示例有局限性,在我的工作站上无法“正常运行”——这里所说的无法正常运行,是指多进程版本的运行时间与单进程版本基本一致,没有得到预期的加速。另外,把进程数设为 2 时用时最短,原因目前还不清楚。
单进程
# -*- coding: utf-8 -*-
"""
Created on Wed Sep 11 15:02:37 2019
@author: Administrator
"""
from osgeo import gdal,ogr
import struct
import os
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from numba import jit
from pykrige.uk import UniversalKriging
from multiprocessing import Pool,Manager
import time
# Wall-clock start time; the elapsed seconds are printed at the end of the script.
start = time.time()
def kri(a, b, c, d, e):
    """Interpolate scattered samples onto a regular grid with universal kriging.

    The function is deliberately not decorated with ``numba.jit``: the body
    only calls into PyKrige, which numba cannot compile, so the decorator
    brings no speedup and fails outright on numba releases where nopython
    mode is the default.

    Args:
        a: x coordinates of the samples.
        b: y coordinates of the samples.
        c: sample values.
        d: x coordinates of the prediction grid.
        e: y coordinates of the prediction grid.

    Returns:
        The predicted field with one row per entry of ``e`` and one column
        per entry of ``d``, with the row order reversed. The kriging
        variance is computed by PyKrige but discarded.
    """
    UK = UniversalKriging(
        a, b, c,
        variogram_model='exponential',
        drift_terms=['regional_linear'],
    )
    pre, _variance = UK.execute('grid', d, e)
    # Size the field from the grid axes instead of a hard-coded 448 x 465.
    pr = pre.reshape(np.size(e), np.size(d))
    # Only the row order is reversed; flipping the columns twice cancels out.
    return pr[::-1]
# Dimensions (rows x columns) of the grid that kri() is run on.
N_ROWS = 448
N_COLS = 465
# Number of valid pixels drawn at random for each kriging run.
N_SAMPLES = 200

# Sample coordinates in feature order: x runs 1..N_COLS within every row,
# y runs N_ROWS..1 from the first row to the last.
xyarray = np.column_stack((
    np.tile(np.arange(1, N_COLS + 1), N_ROWS),
    np.repeat(np.arange(N_ROWS, 0, -1), N_COLS),
))

# NOTE(review): the samples above use 1-based coordinates while this
# prediction grid is 0-based, so the output looks shifted by one cell
# relative to the samples - confirm whether that is intended.
gridx = np.arange(0.0, float(N_COLS), 1)
gridy = np.arange(0.0, float(N_ROWS), 1)

# Reference raster: supplies size, geotransform and projection of the outputs.
ds = gdal.Open(r'D:/Thesis/ML/himawari/china/yw5kmraster.tif')
# ndarray.shape is (rows, columns); driver.Create() below takes (xsize, ysize).
ref_rows, ref_cols = ds.GetRasterBand(1).ReadAsArray().shape
driver = gdal.GetDriverByName("GTiff")

# The sampling points are the same for every raster, so read them only once.
shp_filename = r'D:/Thesis/ML/himawari/china/yw5kmwgs.shp'
ds1 = ogr.Open(shp_filename)  # keep the data source alive while the layer is used
lyr = ds1.GetLayer()
map_x = []
map_y = []
for feat in lyr:
    geom = feat.GetGeometryRef()
    map_x.append(geom.GetX())  # coordinates in map units
    map_y.append(geom.GetY())
map_x = np.array(map_x)
map_y = np.array(map_y)
if map_x.size != N_ROWS * N_COLS:
    raise ValueError(
        'expected %d sampling points, found %d'
        % (N_ROWS * N_COLS, map_x.size)
    )

path = 'D:/Thesis/ML/himawari/dailytif'
files = os.listdir(path)
for i in files[:30]:
    src_ds = gdal.Open(path + '/' + i)
    gt = src_ds.GetGeoTransform()
    rb = src_ds.GetRasterBand(1)

    # Map coordinates -> pixel indices, truncated toward zero like int().
    px = ((map_x - gt[0]) / gt[1]).astype(int)
    py = ((map_y - gt[3]) / gt[5]).astype(int)
    if (px.min() < 0 or py.min() < 0
            or px.max() >= rb.XSize or py.max() >= rb.YSize):
        raise ValueError('sampling point outside raster ' + str(i))

    # Read the whole band once instead of issuing one ReadRaster per point.
    # The buffer is requested as unsigned 16-bit and interpreted as signed
    # 16-bit, so values of 32768 and above come out negative and are dropped
    # by the "> 0" filter below.
    # NOTE(review): confirm that this is the intended nodata handling.
    raw = rb.ReadRaster(0, 0, rb.XSize, rb.YSize, buf_type=gdal.GDT_UInt16)
    band_values = np.frombuffer(raw, dtype=np.int16).reshape(rb.YSize, rb.XSize)
    hiday = band_values[py, px]

    # Keep only the points with a positive value: columns are x, y, value.
    valid = hiday > 0
    tridata = np.column_stack((xyarray[valid], hiday[valid]))

    # Random subset of the valid points used to fit the kriging model.
    row_rand_array = np.random.permutation(tridata.shape[0])
    tridata1 = tridata[row_rand_array[:N_SAMPLES]]
    pr = kri(tridata1[:, 0], tridata1[:, 1], tridata1[:, 2], gridx, gridy)

    outDataRaster = driver.Create(
        r'D:/Thesis/ML/himawari/dailykri5km/' + str(i) + '.tif',
        ref_cols, ref_rows, 1, gdal.GDT_Int16,
    )
    outDataRaster.SetGeoTransform(ds.GetGeoTransform())
    outDataRaster.SetProjection(ds.GetProjection())
    outDataRaster.GetRasterBand(1).WriteArray(pr)
    outDataRaster.FlushCache()
    outDataRaster = None  # drop the reference so GDAL closes the file

end = time.time()
print(end - start)
多进程
# -*- coding: utf-8 -*-
"""
Created on Tue Sep 10 14:35:04 2019
@author: Administrator
"""
###krichahzi 葵花日
from osgeo import gdal,ogr
import struct
import os
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from numba import jit
from pykrige.uk import UniversalKriging
from multiprocessing import Pool,Manager
import time
# Wall-clock start time; the elapsed seconds are printed at the end of the script.
start = time.time()
def kri(a, b, c, d, e):
    """Interpolate scattered samples onto a regular grid with universal kriging.

    The function is a plain module-level function (no ``numba.jit``): the
    body only calls into PyKrige, which numba cannot compile, and a plain
    function can be sent to ``multiprocessing`` workers.

    Args:
        a: x coordinates of the samples.
        b: y coordinates of the samples.
        c: sample values.
        d: x coordinates of the prediction grid.
        e: y coordinates of the prediction grid.

    Returns:
        The predicted field with one row per entry of ``e`` and one column
        per entry of ``d``, with the row order reversed. The kriging
        variance is computed by PyKrige but discarded.
    """
    UK = UniversalKriging(
        a, b, c,
        variogram_model='exponential',
        drift_terms=['regional_linear'],
    )
    pre, _variance = UK.execute('grid', d, e)
    # Size the field from the grid axes instead of a hard-coded 448 x 465.
    pr = pre.reshape(np.size(e), np.size(d))
    # Only the row order is reversed; flipping the columns twice cancels out.
    return pr[::-1]
if __name__ == "__main__":
    # Dimensions (rows x columns) of the grid that kri() is run on.
    N_ROWS = 448
    N_COLS = 465
    # Number of valid pixels drawn at random for each kriging run.
    N_SAMPLES = 200

    # Sample coordinates in feature order: x runs 1..N_COLS within every
    # row, y runs N_ROWS..1 from the first row to the last.
    xyarray = np.column_stack((
        np.tile(np.arange(1, N_COLS + 1), N_ROWS),
        np.repeat(np.arange(N_ROWS, 0, -1), N_COLS),
    ))

    # NOTE(review): the samples above use 1-based coordinates while this
    # prediction grid is 0-based, so the output looks shifted by one cell
    # relative to the samples - confirm whether that is intended.
    gridx = np.arange(0.0, float(N_COLS), 1)
    gridy = np.arange(0.0, float(N_ROWS), 1)

    # Reference raster: supplies size, geotransform and projection of the
    # outputs.
    ds = gdal.Open(r'D:/Thesis/ML/himawari/china/yw5kmraster.tif')
    # ndarray.shape is (rows, columns); driver.Create() takes (xsize, ysize).
    ref_rows, ref_cols = ds.GetRasterBand(1).ReadAsArray().shape
    driver = gdal.GetDriverByName("GTiff")

    # The sampling points are the same for every raster, so read them once.
    shp_filename = r'D:/Thesis/ML/himawari/china/yw5kmwgs.shp'
    ds1 = ogr.Open(shp_filename)  # keep the data source alive while the layer is used
    lyr = ds1.GetLayer()
    map_x = []
    map_y = []
    for feat in lyr:
        geom = feat.GetGeometryRef()
        map_x.append(geom.GetX())  # coordinates in map units
        map_y.append(geom.GetY())
    map_x = np.array(map_x)
    map_y = np.array(map_y)
    if map_x.size != N_ROWS * N_COLS:
        raise ValueError(
            'expected %d sampling points, found %d'
            % (N_ROWS * N_COLS, map_x.size)
        )

    path = 'D:/Thesis/ML/himawari/dailytif'
    files = os.listdir(path)

    # Only the kriging itself runs in the worker processes. The sampling in
    # the loop below stays in the parent process, so it is done with a single
    # band read and array indexing to keep the serial part short.
    # The with-block also releases the workers if an error is raised.
    with Pool(5) as pool:
        prlist = []
        for i in files[:30]:
            src_ds = gdal.Open(path + '/' + i)
            gt = src_ds.GetGeoTransform()
            rb = src_ds.GetRasterBand(1)

            # Map coordinates -> pixel indices, truncated toward zero like int().
            px = ((map_x - gt[0]) / gt[1]).astype(int)
            py = ((map_y - gt[3]) / gt[5]).astype(int)
            if (px.min() < 0 or py.min() < 0
                    or px.max() >= rb.XSize or py.max() >= rb.YSize):
                raise ValueError('sampling point outside raster ' + str(i))

            # The buffer is requested as unsigned 16-bit and interpreted as
            # signed 16-bit, so values of 32768 and above come out negative
            # and are dropped by the "> 0" filter below.
            # NOTE(review): confirm that this is the intended nodata handling.
            raw = rb.ReadRaster(0, 0, rb.XSize, rb.YSize,
                                buf_type=gdal.GDT_UInt16)
            band_values = np.frombuffer(raw, dtype=np.int16).reshape(
                rb.YSize, rb.XSize)
            hiday = band_values[py, px]

            # Keep only the points with a positive value: x, y, value.
            valid = hiday > 0
            tridata = np.column_stack((xyarray[valid], hiday[valid]))

            # Random subset of the valid points used to fit the model.
            row_rand_array = np.random.permutation(tridata.shape[0])
            tridata1 = tridata[row_rand_array[:N_SAMPLES]]
            prlist.append(pool.apply_async(
                kri,
                (tridata1[:, 0], tridata1[:, 1], tridata1[:, 2], gridx, gridy),
            ))

        # Outputs are numbered in submission order. get() blocks until the
        # task has finished and re-raises any error from the worker.
        for index, job in enumerate(prlist):
            # The field is written as returned; an intermediate float16 cast
            # would lose integer precision above 2048 before the Int16 write.
            a = np.array(job.get())
            outDataRaster = driver.Create(
                r'D:/Thesis/ML/himawari/dailykri5km/' + str(index) + '.tif',
                ref_cols, ref_rows, 1, gdal.GDT_Int16,
            )
            outDataRaster.SetGeoTransform(ds.GetGeoTransform())
            outDataRaster.SetProjection(ds.GetProjection())
            outDataRaster.GetRasterBand(1).WriteArray(a)
            outDataRaster.FlushCache()
            outDataRaster = None  # drop the reference so GDAL closes the file

    end = time.time()
    print(end - start)
#ds = gdal.Open(r'D:/Thesis/ML/himawari/china/yw5kmraster.tif')
#bandg = ds.GetRasterBand(1)
#elevationg = bandg.ReadAsArray()
#[cols, rows] = elevationg.shape
#format = "GTiff"#5
#driver = gdal.GetDriverByName(format)#6
#
#ylist=[]
#xlist=[]
#for i in range(448):
# for j in range(465):
# ylist.append(i+1)
# xlist.append(j+1)
#np.array(ylist.reverse())
#xyarray=np.column_stack((xlist,ylist))
#gridx = np.arange(0.0, 465.0, 1)
#gridy = np.arange(0.0, 448.0, 1)
#
#hiday=[]
#src_ds=gdal.Open('D:/Thesis/ML/himawari/dailytif/H08_20180101_0000_1DARP030_FLDK.02401_02401.nc.tif')
#gt=src_ds.GetGeoTransform()
#rb=src_ds.GetRasterBand(1)
#shp_filename = r'D:/Thesis/ML/himawari/china/yw5kmwgs.shp'
#ds=ogr.Open(shp_filename)
#lyr=ds.GetLayer()
#for feat in lyr:
# geom = feat.GetGeometryRef()
# mx,my=geom.GetX(), geom.GetY() #coord in map units
# px = int((mx - gt[0]) / gt[1]) #x pixel
# py = int((my - gt[3]) / gt[5]) #y pixel
# structval=rb.ReadRaster(px,py,1,1,buf_type=gdal.GDT_UInt16) #Assumes 16 bit int aka 'short'
# intval = struct.unpack('h' , structval) #use the 'short' format code (2 bytes) not int (4 bytes)
# hiday.append(intval[0])
##arr365=np.tile(xyarray,(365,1))
#hidatapre=np.column_stack((xyarray,hiday))
##
##for ij in range(365):
## hidatapre=hidatapre[ij*208320:(ij+1)*208320]
#pretemp=[]
#for ii in range(208320):
# if hidatapre[ii,2]>0:
# pretemp.append(hidatapre[ii])
#tridata=np.array(pretemp).reshape((len(pretemp),3))
## hixy=tridata[:,:2]
## hiz=tridata[:,2]
#pr=kri(tridata1[:,0],tridata1[:,1],tridata1[:,2])
#outDataRaster = driver.Create(r'D:/Thesis/ML/himawari/dailykri5km/'+str(i)+'.tif', rows, cols, 1, gdal.GDT_Int16)
#outDataRaster.SetGeoTransform(ds.GetGeoTransform())
#outDataRaster.SetProjection(ds.GetProjection())
#outDataRaster.GetRasterBand(1).WriteArray(pr)
#outDataRaster.FlushCache()
#del outDataRaster
运行时间
上为单进程的运行时间,下为多进程的运行时间。我暂时还没弄清楚到底是哪里存在不足,代码仍需要优化。