-
Notifications
You must be signed in to change notification settings - Fork 386
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add some raster-related util functions
- Loading branch information
Showing
4 changed files
with
164 additions
and
15 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,90 @@ | ||
from typing import Optional | ||
|
||
import numpy as np | ||
import rasterio as rio | ||
import rasterio.windows as rio_windows | ||
from rasterio.transform import from_origin | ||
|
||
from rastervision.pipeline.file_system.utils import ( | ||
file_to_json, get_local_path, get_tmp_dir, make_dir, upload_or_copy) | ||
from rastervision.core.box import Box | ||
|
||
|
||
def write_to_rio_dataset(dataset: rio.DatasetReader, | ||
arr: np.ndarray, | ||
window: Optional[Box] = None) -> None: | ||
"""Write (H, W[, C]) array out to a rasterio dataset.""" | ||
if window is not None: | ||
window = window.rasterio_format() | ||
if len(arr.shape) == 2: | ||
dataset.write_band(1, arr, window=window) | ||
else: | ||
arr_chw = arr.transpose(2, 0, 1) | ||
for i, band in enumerate(arr_chw, start=1): | ||
dataset.write_band(i, band, window=window) | ||
|
||
|
||
def write_to_geotiff(path: str, arr: np.ndarray, bbox: Box, crs_wkt: str, | ||
**kwargs): | ||
if len(arr.shape) == 2: | ||
h_arr, w_arr = arr.shape | ||
num_channels = 1 | ||
else: | ||
h_arr, w_arr, num_channels = arr.shape | ||
h_bbox, w_bbox = bbox.size | ||
resolution = h_bbox / h_arr, w_bbox / w_arr | ||
transform = from_origin(bbox.xmin, bbox.ymax, *resolution) | ||
out_profile = dict( | ||
driver='GTiff', | ||
height=h_arr, | ||
width=w_arr, | ||
crs=crs_wkt, | ||
count=num_channels, | ||
dtype=arr.dtype, | ||
transform=transform, | ||
) | ||
out_profile.update(kwargs) | ||
with rio.open(path, 'w', **out_profile) as ds: | ||
write_to_rio_dataset(ds, arr) | ||
|
||
|
||
def write_geotiff_like_geojson(path: str, | ||
arr: np.ndarray, | ||
geojson_path: str, | ||
crs: Optional[str] = None, | ||
**kwargs): | ||
from rastervision.core.data.utils.geojson import geojson_to_geoms | ||
import pyproj | ||
from shapely.ops import unary_union | ||
|
||
geojson = file_to_json(geojson_path) | ||
if crs is None: | ||
try: | ||
crs = geojson['crs']['properties']['name'] | ||
except KeyError: | ||
crs = 'epsg:4326' | ||
crs_wkt = pyproj.CRS(crs).to_wkt() | ||
geoms = unary_union(list(geojson_to_geoms(geojson))) | ||
bbox = Box.from_shapely(geoms).normalize() | ||
return write_to_geotiff(path, arr, bbox=bbox, crs_wkt=crs_wkt, **kwargs) | ||
|
||
|
||
def crop_geotiff(image_uri: str, window: Box, crop_uri: str): | ||
rio_window = window.rasterio_format() | ||
|
||
with rio.open(image_uri) as src_ds, get_tmp_dir() as tmp_dir: | ||
crop_path = get_local_path(crop_uri, tmp_dir) | ||
make_dir(crop_path, use_dirname=True) | ||
|
||
meta = src_ds.meta | ||
colorinterp = src_ds.colorinterp | ||
img_cropped = src_ds.read(window=rio_window) | ||
|
||
meta['height'], meta['width'] = window.size | ||
meta['transform'] = rio_windows.transform(rio_window, src_ds.transform) | ||
|
||
with rio.open(crop_path, 'w', **meta) as dst_ds: | ||
dst_ds.colorinterp = colorinterp | ||
dst_ds.write(img_cropped) | ||
|
||
upload_or_copy(crop_path, crop_uri) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
import unittest | ||
from os.path import join | ||
|
||
import numpy as np | ||
import pyproj | ||
|
||
from rastervision.pipeline.file_system.utils import get_tmp_dir | ||
from rastervision.core.box import Box | ||
from rastervision.core.data.utils.raster import ( | ||
crop_geotiff, write_geotiff_like_geojson, write_to_geotiff) | ||
from rastervision.core.data import RasterioSource, GeoJSONVectorSource | ||
from tests import data_file_path | ||
|
||
|
||
class TestRasterUtils(unittest.TestCase): | ||
def test_write_to_geotiff(self): | ||
bbox = Box(ymin=48.815, xmin=2.224, ymax=48.902, xmax=2.469) | ||
crs_wkt = pyproj.CRS('epsg:4326').to_wkt() | ||
r = bbox.width / bbox.height | ||
arr1 = np.zeros((100, int(100 * r))) | ||
arr2 = np.zeros((100, int(100 * r), 4)) | ||
with get_tmp_dir() as tmp_dir: | ||
geotiff_path = join(tmp_dir, 'test.geotiff') | ||
write_to_geotiff(geotiff_path, arr1, bbox=bbox, crs_wkt=crs_wkt) | ||
rs = RasterioSource(geotiff_path) | ||
geotiff_bbox = rs.crs_transformer.pixel_to_map( | ||
rs.extent).normalize() | ||
np.testing.assert_array_almost_equal( | ||
np.array(list(geotiff_bbox)), np.array(list(bbox)), decimal=3) | ||
self.assertEqual(rs.shape, (*arr1.shape, 1)) | ||
|
||
write_to_geotiff(geotiff_path, arr2, bbox=bbox, crs_wkt=crs_wkt) | ||
rs = RasterioSource(geotiff_path) | ||
geotiff_bbox = rs.crs_transformer.pixel_to_map( | ||
rs.extent).normalize() | ||
np.testing.assert_array_almost_equal( | ||
np.array(list(geotiff_bbox)), np.array(list(bbox)), decimal=3) | ||
self.assertEqual(rs.shape, arr2.shape) | ||
|
||
def test_crop_geotiff(self): | ||
src_path = data_file_path('multi_raster_source/const_100_600x600.tiff') | ||
window = Box(0, 0, 10, 10) | ||
with get_tmp_dir() as tmp_dir: | ||
crop_path = join(tmp_dir, 'test.tiff') | ||
crop_geotiff(src_path, window, crop_path) | ||
rs = RasterioSource(crop_path) | ||
self.assertEqual(rs.extent, window) | ||
|
||
def test_write_geotiff_like_geojson(self): | ||
geojson_path = data_file_path('0-aoi.geojson') | ||
arr = np.zeros((10, 10)) | ||
with get_tmp_dir() as tmp_dir: | ||
geotiff_path = join(tmp_dir, 'test.tiff') | ||
write_geotiff_like_geojson( | ||
geotiff_path, arr, geojson_path, crs=None) | ||
rs = RasterioSource(geotiff_path) | ||
geotiff_bbox = rs.crs_transformer.pixel_to_map(rs.extent) | ||
vs = GeoJSONVectorSource( | ||
geojson_path, rs.crs_transformer, ignore_crs_field=True) | ||
geojson_bbox = rs.crs_transformer.pixel_to_map(vs.extent) | ||
np.testing.assert_array_almost_equal( | ||
np.array(list(geotiff_bbox)), | ||
np.array(list(geojson_bbox)), | ||
decimal=3) | ||
self.assertEqual(rs.shape, (10, 10, 1)) | ||
|
||
|
||
if __name__ == '__main__': | ||
unittest.main() |