From adc4e08d78fb9198eed12c70e67ddf2417869f38 Mon Sep 17 00:00:00 2001 From: Adeel Hassan Date: Fri, 18 Aug 2023 18:08:23 -0400 Subject: [PATCH] add some raster-related util functions --- .../semantic_segmentation_label_store.py | 19 +--- .../rastervision/core/data/utils/__init__.py | 1 + .../rastervision/core/data/utils/raster.py | 90 +++++++++++++++++++ tests/core/data/utils/test_raster.py | 69 ++++++++++++++ 4 files changed, 164 insertions(+), 15 deletions(-) create mode 100644 rastervision_core/rastervision/core/data/utils/raster.py create mode 100644 tests/core/data/utils/test_raster.py diff --git a/rastervision_core/rastervision/core/data/label_store/semantic_segmentation_label_store.py b/rastervision_core/rastervision/core/data/label_store/semantic_segmentation_label_store.py index ff95411011..4e9343a718 100644 --- a/rastervision_core/rastervision/core/data/label_store/semantic_segmentation_label_store.py +++ b/rastervision_core/rastervision/core/data/label_store/semantic_segmentation_label_store.py @@ -17,6 +17,7 @@ from rastervision.core.data.label_source import SemanticSegmentationLabelSource from rastervision.core.data.raster_transformer import RGBClassTransformer from rastervision.core.data.raster_source import RasterioSource +from rastervision.core.data.utils import write_to_rio_dataset if TYPE_CHECKING: from rastervision.core.data import (VectorOutputConfig, @@ -266,7 +267,8 @@ def write_smooth_raster_output( score_arr = self._scores_to_uint8(score_arr) else: score_arr = score_arr.astype(dtype) - self._write_array(ds, window, score_arr) + score_arr = score_arr.transpose(1, 2, 0) + write_to_rio_dataset(ds, score_arr, window) # save pixel hits too np.save(hits_path, labels.pixel_hits) @@ -291,8 +293,7 @@ def write_discrete_raster_output( if self.class_transformer is not None: label_arr = self.class_transformer.class_to_rgb( label_arr) - label_arr = label_arr.transpose(2, 0, 1) - self._write_array(ds, window, label_arr) + write_to_rio_dataset(ds, label_arr, window) def write_vector_outputs(self, labels: SemanticSegmentationLabels, vector_output_dir: str) -> None: @@ -314,18 +315,6 @@ def write_vector_outputs(self, labels: SemanticSegmentationLabels, out_uri = vo.get_uri(vector_output_dir, self.class_config) json_to_file(geojson, out_uri) - def _write_array(self, dataset: rio.DatasetReader, window: Box, - arr: np.ndarray) -> None: - """Write array out to a rasterio dataset. Array must be of shape - (C, H, W). - """ - rio_window = window.rasterio_format() - if len(arr.shape) == 2: - dataset.write_band(1, arr, window=rio_window) - else: - for i, band in enumerate(arr, start=1): - dataset.write_band(i, band, window=rio_window) - def _clip_to_extent(self, extent: Box, window: Box, diff --git a/rastervision_core/rastervision/core/data/utils/__init__.py b/rastervision_core/rastervision/core/data/utils/__init__.py index a8e64fa17a..c3951b12da 100644 --- a/rastervision_core/rastervision/core/data/utils/__init__.py +++ b/rastervision_core/rastervision/core/data/utils/__init__.py @@ -3,4 +3,5 @@ from rastervision.core.data.utils.misc import * from rastervision.core.data.utils.geojson import * from rastervision.core.data.utils.factory import * +from rastervision.core.data.utils.raster import * from rastervision.core.data.utils.vectorization import * diff --git a/rastervision_core/rastervision/core/data/utils/raster.py b/rastervision_core/rastervision/core/data/utils/raster.py new file mode 100644 index 0000000000..397a2f95c7 --- /dev/null +++ b/rastervision_core/rastervision/core/data/utils/raster.py @@ -0,0 +1,90 @@ +from typing import Optional + +import numpy as np +import rasterio as rio +import rasterio.windows as rio_windows +from rasterio.transform import from_origin + +from rastervision.pipeline.file_system.utils import ( + file_to_json, get_local_path, get_tmp_dir, make_dir, upload_or_copy) +from rastervision.core.box import Box + + +def write_to_rio_dataset(dataset: rio.DatasetReader, + arr: np.ndarray, + window: Optional[Box] = None) -> None: + """Write (H, W[, C]) array out to a rasterio dataset.""" + if window is not None: + window = window.rasterio_format() + if len(arr.shape) == 2: + dataset.write_band(1, arr, window=window) + else: + arr_chw = arr.transpose(2, 0, 1) + for i, band in enumerate(arr_chw, start=1): + dataset.write_band(i, band, window=window) + + +def write_to_geotiff(path: str, arr: np.ndarray, bbox: Box, crs_wkt: str, + **kwargs): + if len(arr.shape) == 2: + h_arr, w_arr = arr.shape + num_channels = 1 + else: + h_arr, w_arr, num_channels = arr.shape + h_bbox, w_bbox = bbox.size + resolution = h_bbox / h_arr, w_bbox / w_arr + transform = from_origin(bbox.xmin, bbox.ymax, *resolution) + out_profile = dict( + driver='GTiff', + height=h_arr, + width=w_arr, + crs=crs_wkt, + count=num_channels, + dtype=arr.dtype, + transform=transform, + ) + out_profile.update(kwargs) + with rio.open(path, 'w', **out_profile) as ds: + write_to_rio_dataset(ds, arr) + + +def write_geotiff_like_geojson(path: str, + arr: np.ndarray, + geojson_path: str, + crs: Optional[str] = None, + **kwargs): + from rastervision.core.data.utils.geojson import geojson_to_geoms + import pyproj + from shapely.ops import unary_union + + geojson = file_to_json(geojson_path) + if crs is None: + try: + crs = geojson['crs']['properties']['name'] + except KeyError: + crs = 'epsg:4326' + crs_wkt = pyproj.CRS(crs).to_wkt() + geoms = unary_union(list(geojson_to_geoms(geojson))) + bbox = Box.from_shapely(geoms).normalize() + return write_to_geotiff(path, arr, bbox=bbox, crs_wkt=crs_wkt, **kwargs) + + +def crop_geotiff(image_uri: str, window: Box, crop_uri: str): + rio_window = window.rasterio_format() + + with rio.open(image_uri) as src_ds, get_tmp_dir() as tmp_dir: + crop_path = get_local_path(crop_uri, tmp_dir) + make_dir(crop_path, use_dirname=True) + + meta = src_ds.meta + colorinterp = src_ds.colorinterp + img_cropped = src_ds.read(window=rio_window) + + meta['height'], meta['width'] = window.size + meta['transform'] = rio_windows.transform(rio_window, src_ds.transform) + + with rio.open(crop_path, 'w', **meta) as dst_ds: + dst_ds.colorinterp = colorinterp + dst_ds.write(img_cropped) + + upload_or_copy(crop_path, crop_uri) diff --git a/tests/core/data/utils/test_raster.py b/tests/core/data/utils/test_raster.py new file mode 100644 index 0000000000..613b445f38 --- /dev/null +++ b/tests/core/data/utils/test_raster.py @@ -0,0 +1,69 @@ +import unittest +from os.path import join + +import numpy as np +import pyproj + +from rastervision.pipeline.file_system.utils import get_tmp_dir +from rastervision.core.box import Box +from rastervision.core.data.utils.raster import ( + crop_geotiff, write_geotiff_like_geojson, write_to_geotiff) +from rastervision.core.data import RasterioSource, GeoJSONVectorSource +from tests import data_file_path + + +class TestRasterUtils(unittest.TestCase): + def test_write_to_geotiff(self): + bbox = Box(ymin=48.815, xmin=2.224, ymax=48.902, xmax=2.469) + crs_wkt = pyproj.CRS('epsg:4326').to_wkt() + r = bbox.width / bbox.height + arr1 = np.zeros((100, int(100 * r))) + arr2 = np.zeros((100, int(100 * r), 4)) + with get_tmp_dir() as tmp_dir: + geotiff_path = join(tmp_dir, 'test.geotiff') + write_to_geotiff(geotiff_path, arr1, bbox=bbox, crs_wkt=crs_wkt) + rs = RasterioSource(geotiff_path) + geotiff_bbox = rs.crs_transformer.pixel_to_map( + rs.extent).normalize() + np.testing.assert_array_almost_equal( + np.array(list(geotiff_bbox)), np.array(list(bbox)), decimal=3) + self.assertEqual(rs.shape, (*arr1.shape, 1)) + + write_to_geotiff(geotiff_path, arr2, bbox=bbox, crs_wkt=crs_wkt) + rs = RasterioSource(geotiff_path) + geotiff_bbox = rs.crs_transformer.pixel_to_map( + rs.extent).normalize() + np.testing.assert_array_almost_equal( + np.array(list(geotiff_bbox)), np.array(list(bbox)), decimal=3) + self.assertEqual(rs.shape, arr2.shape) + + def test_crop_geotiff(self): + src_path = data_file_path('multi_raster_source/const_100_600x600.tiff') + window = Box(0, 0, 10, 10) + with get_tmp_dir() as tmp_dir: + crop_path = join(tmp_dir, 'test.tiff') + crop_geotiff(src_path, window, crop_path) + rs = RasterioSource(crop_path) + self.assertEqual(rs.extent, window) + + def test_write_geotiff_like_geojson(self): + geojson_path = data_file_path('0-aoi.geojson') + arr = np.zeros((10, 10)) + with get_tmp_dir() as tmp_dir: + geotiff_path = join(tmp_dir, 'test.tiff') + write_geotiff_like_geojson( + geotiff_path, arr, geojson_path, crs=None) + rs = RasterioSource(geotiff_path) + geotiff_bbox = rs.crs_transformer.pixel_to_map(rs.extent) + vs = GeoJSONVectorSource( + geojson_path, rs.crs_transformer, ignore_crs_field=True) + geojson_bbox = rs.crs_transformer.pixel_to_map(vs.extent) + np.testing.assert_array_almost_equal( + np.array(list(geotiff_bbox)), + np.array(list(geojson_bbox)), + decimal=3) + self.assertEqual(rs.shape, (10, 10, 1)) + + +if __name__ == '__main__': + unittest.main()