Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Misc. refactoring and fixes #1838

Merged
merged 5 commits into from
Jul 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 6 additions & 7 deletions rastervision_core/rastervision/core/box.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
from shapely.geometry import Polygon
from rasterio.windows import Window as RioWindow

from rastervision.pipeline.utils import repr_with_args

NonNegInt = conint(ge=0)

if TYPE_CHECKING:
Expand Down Expand Up @@ -118,11 +120,7 @@ def __getitem__(self, i):
return self.tuple_format()[i]

def __repr__(self) -> str:
arg_keys = ['ymin', 'xmin', 'ymax', 'xmax']
arg_vals = [getattr(self, k) for k in arg_keys]
arg_strs = [f'{k}={v}' for k, v in zip(arg_keys, arg_vals)]
arg_str = ', '.join(arg_strs)
return f'{type(self).__name__}({arg_str})'
return repr_with_args(self, **self.to_dict())

def __hash__(self) -> int:
return hash(self.tuple_format())
Expand Down Expand Up @@ -444,11 +442,12 @@ def get_windows(self,
return windows

def to_dict(self) -> Dict[str, int]:
"""Convert to a dict with keys: ymin, xmin, ymax, xmax."""
return {
'xmin': self.xmin,
'ymin': self.ymin,
'xmin': self.xmin,
'ymax': self.ymax,
'xmax': self.xmax,
'ymax': self.ymax
}

@classmethod
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from rastervision.core.data.raster_transformer.raster_transformer \
import RasterTransformer
from rastervision.pipeline.utils import repr_with_args

import numpy as np

Expand All @@ -18,7 +19,7 @@ def __init__(self, to_dtype: str):
self.to_dtype = np.dtype(to_dtype)

def __repr__(self):
return f'CastTransformer(to_dtype="{self.to_dtype}")'
return repr_with_args(self, to_dtype=str(self.to_dtype))

def transform(self, chip: np.ndarray,
channel_order: Optional[list] = None) -> np.ndarray:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import numpy as np

from rastervision.core.data.raster_transformer import RasterTransformer
from rastervision.core.raster_stats import RasterStats
from rastervision.pipeline.utils import repr_with_args

if TYPE_CHECKING:
from rastervision.core.data import RasterSource
Expand Down Expand Up @@ -94,7 +96,7 @@ def transform(self,
@classmethod
def from_raster_sources(cls,
raster_sources: List['RasterSource'],
sample_prob: float = 0.1,
sample_prob: Optional[float] = 0.1,
max_stds: float = 3.) -> 'StatsTransformer':
"""Create a StatsTransformer with stats from the given raster sources.

Expand All @@ -110,9 +112,28 @@ def from_raster_sources(cls,
Returns:
StatsTransformer: A StatsTransformer.
"""
from rastervision.core.raster_stats import RasterStats
stats = RasterStats()
stats.compute(raster_sources=raster_sources, sample_prob=sample_prob)
stats_transformer = StatsTransformer(
means=stats.means, stds=stats.stds, max_stds=max_stds)
stats_transformer = StatsTransformer.from_raster_stats(
stats, max_stds=max_stds)
return stats_transformer

@classmethod
def from_stats_json(cls, uri: str, **kwargs) -> 'StatsTransformer':
stats = RasterStats.load(uri)
stats_transformer = StatsTransformer.from_raster_stats(stats, **kwargs)
return stats_transformer

@classmethod
def from_raster_stats(cls, stats: RasterStats,
**kwargs) -> 'StatsTransformer':
stats_transformer = StatsTransformer(stats.means, stats.stds, **kwargs)
return stats_transformer

@property
def stats(self):
return RasterStats(self.means, self.stds)

def __repr__(self) -> str:
return repr_with_args(
self, means=self.means, std=self.stds, max_stds=self.max_stds)
21 changes: 21 additions & 0 deletions rastervision_core/rastervision/core/data/utils/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,3 +219,24 @@ def parse_array_slices_Nd(key: Union[tuple, slice],
dim_slices[w_dim] = w_slice

return window, dim_slices


def ensure_json_serializable(obj: Any) -> dict:
"""Convert numpy types to JSON serializable equivalents."""
if obj is None or isinstance(obj, (str, int, bool)):
return obj
if isinstance(obj, dict):
return {k: ensure_json_serializable(v) for k, v in obj.items()}
if isinstance(obj, (list, tuple)):
return [ensure_json_serializable(o) for o in obj]
if isinstance(obj, np.ndarray):
return ensure_json_serializable(obj.tolist())
if isinstance(obj, np.integer):
return int(obj)
if isinstance(obj, (float, np.floating)):
if np.isnan(obj):
return None
return float(obj)
if isinstance(obj, Box):
return obj.to_dict()
return obj
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import numpy as np

from rastervision.pipeline.file_system import str_to_file
from rastervision.core.data.utils import ensure_json_serializable

if TYPE_CHECKING:
from rastervision.core.evaluation import ClassEvaluationItem
Expand Down Expand Up @@ -151,24 +152,3 @@ def compute(self, ground_truth_labels, prediction_labels):
prediction_labels: The predicted labels to evaluate.
"""
pass


def ensure_json_serializable(obj: Any) -> dict:
"""Convert numpy types to JSON serializable equivalents."""
if obj is None or isinstance(obj, (str, int, bool)):
return obj
if isinstance(obj, dict):
return {k: ensure_json_serializable(v) for k, v in obj.items()}
if isinstance(obj, (list, tuple)):
return [ensure_json_serializable(o) for o in obj]
if isinstance(obj, np.ndarray):
return ensure_json_serializable(obj.tolist())
if isinstance(obj, np.integer):
return int(obj)
if isinstance(obj, float):
if np.isnan(obj):
return None
return float(obj)
if isinstance(obj, np.floating):
return float(obj)
return obj
Loading
Loading