Skip to content

Commit

Permalink
refactor example utils
Browse files Browse the repository at this point in the history
  • Loading branch information
AdeelH committed Aug 21, 2023
1 parent 48790ee commit 6fdf384
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 122 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -85,10 +85,12 @@ def make_scene(scene_info) -> SceneConfig:
raster_uri,
crop_uri,
label_uri=label_uri,
vector_labels=True,
label_crop_uri=label_crop_uri,
size=600,
min_features=20,
class_config=class_config)
class_config=class_config,
default_class_id=1)
raster_uri = crop_uri
label_uri = label_crop_uri

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,10 @@ def make_scene(id: str) -> SceneConfig:
raster_uri,
crop_uri,
label_uri=label_uri,
vector_labels=True,
size=2000,
min_features=5)
min_features=5,
default_class_id=0)
raster_uri = crop_uri

raster_source = RasterioSourceConfig(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,69 +1,34 @@
from typing import TYPE_CHECKING, Optional
import os
import csv
from io import StringIO
import os

import rasterio
from shapely.strtree import STRtree
from shapely.geometry import shape, mapping
from shapely.ops import transform

from rastervision.pipeline.file_system import get_tmp_dir
from rastervision.core import Box
from rastervision.core.data import (RasterioCRSTransformer,
GeoJSONVectorSourceConfig,
ClassInferenceTransformerConfig)
from rastervision.pipeline.file_system import (file_to_str, file_exists,
get_local_path, upload_or_copy,
make_dir, json_to_file)
from rastervision.core.data import (RasterioSource, GeoJSONVectorSource,
ClassInferenceTransformer)
from rastervision.core.data.utils import geoms_to_geojson, crop_geotiff
from rastervision.pipeline.file_system import (file_to_str, json_to_file)
from rastervision.aws_s3 import S3FileSystem

if TYPE_CHECKING:
from rastervision.core.data import ClassConfig

def str_to_bool(x):
if type(x) == str:
if x.lower() == 'true':
return True
elif x.lower() == 'false':
return False
else:
raise ValueError('{} is expected to be true or false'.format(x))
return x


def get_scene_info(csv_uri):
def get_scene_info(csv_uri: str) -> list: # pragma: no cover
csv_str = file_to_str(csv_uri)
reader = csv.reader(StringIO(csv_str), delimiter=',')
reader = csv.reader(StringIO(csv_str))
return list(reader)


def crop_image(image_uri: str, window: Box, crop_uri: str):
im_dataset = rasterio.open(image_uri)
rasterio_window = window.rasterio_format()
im = im_dataset.read(window=rasterio_window)

with get_tmp_dir() as tmp_dir:
crop_path = get_local_path(crop_uri, tmp_dir)
make_dir(crop_path, use_dirname=True)

meta = im_dataset.meta
meta['height'], meta['width'] = window.size
meta['transform'] = rasterio.windows.transform(rasterio_window,
im_dataset.transform)

with rasterio.open(crop_path, 'w', **meta) as dst:
dst.colorinterp = im_dataset.colorinterp
dst.write(im)

upload_or_copy(crop_path, crop_uri)


def save_image_crop(image_uri,
image_crop_uri,
label_uri=None,
label_crop_uri=None,
size=600,
min_features=10,
vector_labels=True,
class_config=None):
def save_image_crop(
image_uri: str,
image_crop_uri: str,
label_uri: Optional[str] = None,
label_crop_uri: Optional[str] = None,
size: int = 600,
min_features: int = 10,
vector_labels: bool = True,
default_class_id: int = 0,
class_config: Optional['ClassConfig'] = None): # pragma: no cover
"""Save a crop of an image to use for testing.
If label_uri is set, the crop needs to cover >= min_features.
Expand All @@ -74,73 +39,54 @@ def save_image_crop(image_uri,
label_uri: optional URI of label file
label_crop_uri: optional URI of cropped labels to save
size: height and width of crop
min_features: min number of label polygons that the crop should have
vector_labels: whether the labels are vector labels
default_class_id: default class ID to use to infer labels
class_config: ClassConfig to use to infer labels
Raises:
ValueError if cannot find a crop satisfying min_features constraint.
"""
if not file_exists(image_crop_uri):
print('Saving test crop to {}...'.format(image_crop_uri))
old_environ = os.environ.copy()
try:
request_payer = S3FileSystem.get_request_payer()
if request_payer == 'requester':
os.environ['AWS_REQUEST_PAYER'] = request_payer
im_dataset = rasterio.open(image_uri)
h, w = im_dataset.height, im_dataset.width

extent = Box(0, 0, h, w)
windows = extent.get_windows(size, size)
print(f'Saving test crop to {image_crop_uri}...')
old_environ = os.environ.copy()
try:
request_payer = S3FileSystem.get_request_payer()
if request_payer == 'requester':
os.environ['AWS_REQUEST_PAYER'] = request_payer
rs = RasterioSource(image_uri, allow_streaming=True)
if label_uri and vector_labels:
crs_tf = rs.crs_transformer
vs = GeoJSONVectorSource(
uris=label_uri,
crs_transformer=crs_tf,
ignore_crs_field=True,
vector_transformers=[
ClassInferenceTransformer(
default_class_id=default_class_id,
class_config=class_config)
])
labels_df = vs.get_dataframe()

windows = rs.extent.get_windows(size, size)
for w in windows:
use_window = True
if label_uri and vector_labels:
crs_transformer = RasterioCRSTransformer.from_dataset(
im_dataset)
geojson_vs_config = GeoJSONVectorSourceConfig(
uris=label_uri,
ignore_crs_field=True,
transformers=[
ClassInferenceTransformerConfig(default_class_id=0)
])
vs = geojson_vs_config.build(class_config, crs_transformer)
geojson = vs.get_geojson()
geoms = []
for f in geojson['features']:
g = shape(f['geometry'])
geoms.append(g)
tree = STRtree(geoms)

def p2m(x, y, z=None):
return crs_transformer.pixel_to_map((x, y))

for w in windows:
use_window = True
if label_uri and vector_labels:
w_polys = tree.query(w.to_shapely())
use_window = len(w_polys) >= min_features
if use_window and label_crop_uri is not None:
print('Saving test crop labels to {}...'.format(
label_crop_uri))

label_crop_features = [
mapping(transform(p2m, wp)) for wp in w_polys
]
label_crop_json = {
'type':
'FeatureCollection',
'features': [{
'geometry': f
} for f in label_crop_features]
}
json_to_file(label_crop_json, label_crop_uri)

if use_window:
crop_image(image_uri, w, image_crop_uri)

if not vector_labels and label_uri and label_crop_uri:
crop_image(label_uri, w, label_crop_uri)

break

if not use_window:
raise ValueError('Could not find a good crop.')
finally:
os.environ.clear()
os.environ.update(old_environ)
w_geom = w.to_shapely()
df_int = labels_df[labels_df.intersects(w_geom)]
w_polys = df_int.geometry
use_window = len(w_polys) >= min_features
if use_window and label_crop_uri is not None:
print(f'Saving test crop labels to {label_crop_uri}...')
w_polys_map = [crs_tf.pixel_to_map(wp) for wp in w_polys]
label_crop_json = geoms_to_geojson(w_polys_map)
json_to_file(label_crop_json, label_crop_uri)
if use_window:
crop_geotiff(image_uri, w, image_crop_uri)
if not vector_labels and label_uri and label_crop_uri:
crop_geotiff(label_uri, w, label_crop_uri)
break
if not use_window:
raise ValueError('Could not find a good crop.')
finally:
os.environ.clear()
os.environ.update(old_environ)

0 comments on commit 6fdf384

Please sign in to comment.