Skip to content

Commit

Permalink
add some raster-related util functions
Browse files Browse the repository at this point in the history
  • Loading branch information
AdeelH committed Aug 18, 2023
1 parent 143f8c4 commit adc4e08
Show file tree
Hide file tree
Showing 4 changed files with 164 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from rastervision.core.data.label_source import SemanticSegmentationLabelSource
from rastervision.core.data.raster_transformer import RGBClassTransformer
from rastervision.core.data.raster_source import RasterioSource
from rastervision.core.data.utils import write_to_rio_dataset

if TYPE_CHECKING:
from rastervision.core.data import (VectorOutputConfig,
Expand Down Expand Up @@ -266,7 +267,8 @@ def write_smooth_raster_output(
score_arr = self._scores_to_uint8(score_arr)
else:
score_arr = score_arr.astype(dtype)
self._write_array(ds, window, score_arr)
score_arr = score_arr.transpose(1, 2, 0)
write_to_rio_dataset(ds, score_arr, window)

# save pixel hits too
np.save(hits_path, labels.pixel_hits)
Expand All @@ -291,8 +293,7 @@ def write_discrete_raster_output(
if self.class_transformer is not None:
label_arr = self.class_transformer.class_to_rgb(
label_arr)
label_arr = label_arr.transpose(2, 0, 1)
self._write_array(ds, window, label_arr)
write_to_rio_dataset(ds, label_arr, window)

def write_vector_outputs(self, labels: SemanticSegmentationLabels,
vector_output_dir: str) -> None:
Expand All @@ -314,18 +315,6 @@ def write_vector_outputs(self, labels: SemanticSegmentationLabels,
out_uri = vo.get_uri(vector_output_dir, self.class_config)
json_to_file(geojson, out_uri)

def _write_array(self, dataset: rio.DatasetReader, window: Box,
                 arr: np.ndarray) -> None:
    """Write an array into a window of a rasterio dataset, band by band.

    A 2D array is written to band 1. Any other array is iterated along its
    first axis, i.e. it is treated as (C, H, W), and slice ``k`` (0-based)
    is written to band ``k + 1``.

    Args:
        dataset: Open rasterio dataset to write into.
        window: Window to write to; converted with
            ``Box.rasterio_format()`` before being handed to rasterio.
        arr: Array to write.
    """
    rio_window = window.rasterio_format()

    if arr.ndim == 2:
        dataset.write_band(1, arr, window=rio_window)
        return

    # rasterio band indices are 1-based.
    for band_idx, band_arr in enumerate(arr, start=1):
        dataset.write_band(band_idx, band_arr, window=rio_window)

def _clip_to_extent(self,
extent: Box,
window: Box,
Expand Down
1 change: 1 addition & 0 deletions rastervision_core/rastervision/core/data/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,5 @@
from rastervision.core.data.utils.misc import *
from rastervision.core.data.utils.geojson import *
from rastervision.core.data.utils.factory import *
from rastervision.core.data.utils.raster import *
from rastervision.core.data.utils.vectorization import *
90 changes: 90 additions & 0 deletions rastervision_core/rastervision/core/data/utils/raster.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
from typing import Optional

import numpy as np
import rasterio as rio
import rasterio.windows as rio_windows
from rasterio.transform import from_origin

from rastervision.pipeline.file_system.utils import (
file_to_json, get_local_path, get_tmp_dir, make_dir, upload_or_copy)
from rastervision.core.box import Box


def write_to_rio_dataset(dataset: rio.DatasetReader,
                         arr: np.ndarray,
                         window: Optional[Box] = None) -> None:
    """Write an array to an open rasterio dataset, one band at a time.

    Args:
        dataset: Open rasterio dataset to write into.
        arr: Array of shape (H, W) or (H, W, C). A 2D array is written to
            band 1; for a 3D array, channel ``k`` (0-based) is written to
            band ``k + 1``.
        window: Optional window to write to. If given, it is converted with
            ``Box.rasterio_format()`` before being handed to rasterio;
            otherwise ``window=None`` is passed through.
    """
    rio_window = None if window is None else window.rasterio_format()

    if arr.ndim == 2:
        dataset.write_band(1, arr, window=rio_window)
        return

    # Move channels to the front so that iterating yields one (H, W) band
    # per step. rasterio band indices are 1-based.
    bands = arr.transpose(2, 0, 1)
    for band_idx, band_arr in enumerate(bands, start=1):
        dataset.write_band(band_idx, band_arr, window=rio_window)


def write_to_geotiff(path: str, arr: np.ndarray, bbox: Box, crs_wkt: str,
                     **kwargs):
    """Write an array to a new GeoTIFF that spans a given bounding box.

    The geotransform is anchored at ``(bbox.xmin, bbox.ymax)`` and the pixel
    size along each axis is the bbox extent along that axis divided by the
    number of array pixels along it.

    Args:
        path: Path of the GeoTIFF to create.
        arr: Array of shape (H, W) or (H, W, C).
        bbox: Bounding box that the raster should cover.
        crs_wkt: CRS (WKT string) passed as ``crs`` to ``rasterio.open()``.
        **kwargs: Extra profile entries for ``rasterio.open()``. These
            override the profile values computed here.
    """
    if arr.ndim == 2:
        h_arr, w_arr = arr.shape
        num_channels = 1
    else:
        h_arr, w_arr, num_channels = arr.shape

    # Box.size is (height, width).
    h_bbox, w_bbox = bbox.size
    # from_origin() expects (west, north, xsize, ysize), i.e. the pixel
    # *width* comes before the pixel *height*. Passing them the other way
    # round only happens to work when the pixels are square.
    x_res = w_bbox / w_arr
    y_res = h_bbox / h_arr
    transform = from_origin(bbox.xmin, bbox.ymax, x_res, y_res)

    out_profile = dict(
        driver='GTiff',
        height=h_arr,
        width=w_arr,
        crs=crs_wkt,
        count=num_channels,
        dtype=arr.dtype,
        transform=transform,
    )
    # Caller-supplied options take precedence over the defaults above.
    out_profile.update(kwargs)
    with rio.open(path, 'w', **out_profile) as ds:
        write_to_rio_dataset(ds, arr)


def write_geotiff_like_geojson(path: str,
                               arr: np.ndarray,
                               geojson_path: str,
                               crs: Optional[str] = None,
                               **kwargs):
    """Write an array to a GeoTIFF covering the extent of a GeoJSON file.

    The bounding box is taken from the union of all geometries read from
    the GeoJSON.

    Args:
        path: Path of the GeoTIFF to create.
        arr: Array to write; forwarded to ``write_to_geotiff()``.
        geojson_path: Location of the GeoJSON file, read with
            ``file_to_json()``.
        crs: CRS string accepted by ``pyproj.CRS``. If None, the value at
            ``geojson['crs']['properties']['name']`` is used, falling back
            to ``'epsg:4326'`` when that key path is missing.
        **kwargs: Forwarded to ``write_to_geotiff()``.

    Returns:
        The return value of ``write_to_geotiff()``.
    """
    from rastervision.core.data.utils.geojson import geojson_to_geoms
    import pyproj
    from shapely.ops import unary_union

    geojson = file_to_json(geojson_path)

    crs_name = crs
    if crs_name is None:
        try:
            crs_name = geojson['crs']['properties']['name']
        except KeyError:
            crs_name = 'epsg:4326'
    wkt = pyproj.CRS(crs_name).to_wkt()

    # Merge every geometry into one so that a single bbox covers them all.
    geom_list = list(geojson_to_geoms(geojson))
    merged_geom = unary_union(geom_list)
    extent_box = Box.from_shapely(merged_geom).normalize()

    return write_to_geotiff(
        path, arr, bbox=extent_box, crs_wkt=wkt, **kwargs)


def crop_geotiff(image_uri: str, window: Box, crop_uri: str):
    """Crop a window out of a GeoTIFF and save the crop as a new GeoTIFF.

    The crop is first written to a local path (resolved with
    ``get_local_path()`` against a temporary directory) and then passed to
    ``upload_or_copy()`` to reach ``crop_uri``.

    Args:
        image_uri: Location of the source raster, opened with
            ``rasterio.open()``.
        window: Window to read from the source raster.
        crop_uri: Destination of the cropped GeoTIFF.
    """
    rio_window = window.rasterio_format()

    with rio.open(image_uri) as src_ds:
        with get_tmp_dir() as tmp_dir:
            crop_path = get_local_path(crop_uri, tmp_dir)
            make_dir(crop_path, use_dirname=True)

            # Start from the source metadata and override only what the
            # crop changes: raster size and geotransform.
            meta = src_ds.meta
            colorinterp = src_ds.colorinterp
            img_cropped = src_ds.read(window=rio_window)

            crop_height, crop_width = window.size
            crop_transform = rio_windows.transform(rio_window,
                                                   src_ds.transform)
            meta.update(
                height=crop_height,
                width=crop_width,
                transform=crop_transform)

            with rio.open(crop_path, 'w', **meta) as dst_ds:
                dst_ds.colorinterp = colorinterp
                dst_ds.write(img_cropped)

            # The writer is closed at this point, and the temporary
            # directory still exists, so the file is complete and readable.
            upload_or_copy(crop_path, crop_uri)
69 changes: 69 additions & 0 deletions tests/core/data/utils/test_raster.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
import unittest
from os.path import join

import numpy as np
import pyproj

from rastervision.pipeline.file_system.utils import get_tmp_dir
from rastervision.core.box import Box
from rastervision.core.data.utils.raster import (
crop_geotiff, write_geotiff_like_geojson, write_to_geotiff)
from rastervision.core.data import RasterioSource, GeoJSONVectorSource
from tests import data_file_path


class TestRasterUtils(unittest.TestCase):
    """Tests for the raster utility functions."""

    def test_write_to_geotiff(self):
        bbox = Box(ymin=48.815, xmin=2.224, ymax=48.902, xmax=2.469)
        crs_wkt = pyproj.CRS('epsg:4326').to_wkt()
        # Size the arrays so that their aspect ratio matches the bbox's.
        aspect_ratio = bbox.width / bbox.height
        num_cols = int(100 * aspect_ratio)
        arr_2d = np.zeros((100, num_cols))
        arr_3d = np.zeros((100, num_cols, 4))
        # A 2D array is expected to be read back with one channel.
        cases = [
            (arr_2d, (*arr_2d.shape, 1)),
            (arr_3d, arr_3d.shape),
        ]
        with get_tmp_dir() as tmp_dir:
            geotiff_path = join(tmp_dir, 'test.geotiff')
            for arr, expected_shape in cases:
                write_to_geotiff(
                    geotiff_path, arr, bbox=bbox, crs_wkt=crs_wkt)
                rs = RasterioSource(geotiff_path)
                pixel_extent = rs.extent
                geotiff_bbox = rs.crs_transformer.pixel_to_map(
                    pixel_extent).normalize()
                actual = np.array(list(geotiff_bbox))
                expected = np.array(list(bbox))
                np.testing.assert_array_almost_equal(
                    actual, expected, decimal=3)
                self.assertEqual(rs.shape, expected_shape)

    def test_crop_geotiff(self):
        src_path = data_file_path('multi_raster_source/const_100_600x600.tiff')
        crop_window = Box(0, 0, 10, 10)
        with get_tmp_dir() as tmp_dir:
            crop_path = join(tmp_dir, 'test.tiff')
            crop_geotiff(src_path, crop_window, crop_path)
            cropped_rs = RasterioSource(crop_path)
            self.assertEqual(cropped_rs.extent, crop_window)

    def test_write_geotiff_like_geojson(self):
        geojson_path = data_file_path('0-aoi.geojson')
        arr = np.zeros((10, 10))
        with get_tmp_dir() as tmp_dir:
            geotiff_path = join(tmp_dir, 'test.tiff')
            write_geotiff_like_geojson(
                geotiff_path, arr, geojson_path, crs=None)
            rs = RasterioSource(geotiff_path)
            crs_tf = rs.crs_transformer
            geotiff_bbox = crs_tf.pixel_to_map(rs.extent)
            vs = GeoJSONVectorSource(
                geojson_path, crs_tf, ignore_crs_field=True)
            geojson_bbox = crs_tf.pixel_to_map(vs.extent)
            # The raster's map extent should match the GeoJSON's.
            actual = np.array(list(geotiff_bbox))
            expected = np.array(list(geojson_bbox))
            np.testing.assert_array_almost_equal(actual, expected, decimal=3)
            self.assertEqual(rs.shape, (10, 10, 1))


# Run the tests when this module is executed directly as a script.
if __name__ == '__main__':
    unittest.main()

0 comments on commit adc4e08

Please sign in to comment.