import numpy as np
import pykrige.kriging_tools as kt
from pykrige.ok import OrdinaryKriging
import matplotlib.pyplot as plt
import geopandas as gpd
from osgeo import gdal, ogr, osr
def write_img(filename, description, im_data):
    """Write a 2-D or 3-D array to a GeoTIFF file with GDAL.

    Args:
        filename: Path of the GeoTIFF to create.
        description: ``(upper_left_x, upper_left_y, pixel_width,
            pixel_height)``. The values are placed unchanged into the GDAL
            geotransform ``[x, pixel_width, 0, y, 0, pixel_height]``, so the
            sign of ``pixel_height`` decides the row direction (GDAL's
            convention for a north-up raster is a negative value).
        im_data: Array of shape ``(height, width)`` or
            ``(bands, height, width)``.

    Raises:
        ValueError: If ``im_data`` is neither 2-D nor 3-D.
        RuntimeError: If GDAL cannot create the output dataset.
    """
    upper_left_x, upper_left_y, pixel_width, pixel_height = description

    # Match the dtype name exactly: a substring test such as
    # `'int16' in name` also matches 'uint16' and would store signed 16-bit
    # data as unsigned. Every dtype not listed here is written as Float32.
    gdal_types = {
        'int8': gdal.GDT_Byte,
        'uint8': gdal.GDT_Byte,
        'int16': gdal.GDT_Int16,
        'uint16': gdal.GDT_UInt16,
    }
    datatype = gdal_types.get(im_data.dtype.name, gdal.GDT_Float32)

    n_dims = len(im_data.shape)
    if n_dims == 3:
        im_bands, im_height, im_width = im_data.shape
        band_arrays = im_data
    elif n_dims == 2:
        im_bands = 1
        im_height, im_width = im_data.shape
        band_arrays = [im_data]
    else:
        raise ValueError(
            f"im_data must be 2-D or 3-D, got shape {im_data.shape}")

    driver = gdal.GetDriverByName("GTiff")
    dataset = driver.Create(filename, im_width, im_height, im_bands, datatype)
    if dataset is None:
        # Without gdal.UseExceptions(), Create signals failure by returning
        # None rather than raising.
        raise RuntimeError(f"GDAL could not create raster: {filename}")

    try:
        dataset.SetGeoTransform(
            [upper_left_x, pixel_width, 0, upper_left_y, 0, pixel_height])

        # SetProjection takes a CRS definition string (WKT); the bare ESRI
        # name "GCS_WGS_1984" is not one, so build the WGS 84 WKT explicitly.
        srs = osr.SpatialReference()
        srs.ImportFromEPSG(4326)
        dataset.SetProjection(srs.ExportToWkt())

        # GDAL band indices are 1-based.
        for band_index, band_data in enumerate(band_arrays, start=1):
            dataset.GetRasterBand(band_index).WriteArray(band_data)
        dataset.FlushCache()
    finally:
        # Dropping the last reference closes the dataset and finalizes the
        # file on disk.
        dataset = None
if __name__ == "__main__":
    # Command-line entry point: read point data from a shapefile, interpolate
    # it onto a regular grid with ordinary kriging, and write the grid as a
    # GeoTIFF.
    import argparse

    parser = argparse.ArgumentParser(
        description='Interpolate point values from a shapefile onto a regular '
                    'grid with ordinary kriging and write the result as a '
                    'GeoTIFF.')
    parser.add_argument('--input_shp_path', type=str, default='./data/test.shp',
                        help='path to the input shapefile')
    parser.add_argument('--output_raster_path', type=str, default='mask.tif',
                        help='path of the GeoTIFF to write')
    parser.add_argument('--pixel_size', type=float, default=5e-5,
                        help='grid cell size, in the coordinate units of the '
                             'shapefile')
    args = parser.parse_args()
    if not args.pixel_size > 0:
        parser.error('--pixel_size must be a positive number')

    # NOTE(review): the original comment here said LINESTRING, but the code
    # below reads .x/.y from each geometry, which are point accessors --
    # confirm the input layer holds points.
    gdf = gpd.read_file(args.input_shp_path)
    min_x, min_y, max_x, max_y = gdf.total_bounds

    # Number of grid nodes along x (columns) and along y (rows).
    n_cols = int((max_x - min_x) / args.pixel_size)
    n_rows = int((max_y - min_y) / args.pixel_size)
    print(f"{n_cols}, {n_rows}")
    if n_cols < 1 or n_rows < 1:
        parser.error('--pixel_size is larger than the data extent; '
                     'the output grid would be empty')

    # NOTE(review): linspace includes both endpoints, so the node spacing is
    # extent / (n - 1), slightly larger than pixel_size -- confirm whether
    # this offset matters for the georeferencing written below.
    grid_lon = np.linspace(min_x, max_x, n_cols)
    grid_lat = np.linspace(min_y, max_y, n_rows)

    # Columns are addressed by position: index 1 is taken as the value to
    # interpolate and index 2 as the geometry -- TODO confirm against the
    # shapefile schema.
    data = gdf.iloc[:, 1].to_numpy()
    points = gdf.iloc[:, 2]
    lons = [pt.x for pt in points]
    lats = [pt.y for pt in points]

    OK = OrdinaryKriging(
        lons,
        lats,
        data,
        variogram_model="linear",
        verbose=False,
        enable_plotting=False,
    )
    # The second return value is not used.
    z, _ = OK.execute("grid", grid_lon, grid_lat)

    # The raster origin is (min x, min y) with a positive pixel height;
    # presumably this matches z's rows running from min_y upward -- verify
    # against PyKrige's output ordering.
    description = (min_x, min_y, args.pixel_size, args.pixel_size)
    write_img(args.output_raster_path, description, z)