Skip to content

Commit

Permalink
clarify in func docstrings where **kwargs are going, where appropriate
Browse files Browse the repository at this point in the history
also
- general docstring fixes
- fix type-hints
- link to onnx docs
  • Loading branch information
AdeelH committed Jul 21, 2023
1 parent bff97f3 commit 474dbe6
Show file tree
Hide file tree
Showing 15 changed files with 181 additions and 42 deletions.
4 changes: 4 additions & 0 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,10 @@ def setup(app: 'Sphinx') -> None:
'https://pytorch.org/docs/stable/',
'https://pytorch.org/docs/stable/objects.inv',
),
'onnx': (
'https://onnx.ai/onnx/',
'https://onnx.ai/onnx/objects.inv',
),
}

#########################
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Optional
from typing import Any, Optional, Union
from pyproj import Transformer

import numpy as np
Expand Down Expand Up @@ -99,10 +99,16 @@ def _pixel_to_map(self, pixel_point):
return map_point

@classmethod
def from_dataset(cls,
dataset,
map_crs: Optional[str] = 'epsg:4326',
**kwargs) -> 'RasterioCRSTransformer':
def from_dataset(
cls, dataset: Any, map_crs: Optional[str] = 'epsg:4326', **kwargs
) -> Union[IdentityCRSTransformer, 'RasterioCRSTransformer']:
"""Build from rasterio dataset.
Args:
dataset (Any): Rasterio dataset.
map_crs (Optional[str]): Target map CRS. Defaults to 'epsg:4326'.
**kwargs: Extra args for :meth:`.__init__`.
"""
transform = dataset.transform
image_crs = None if dataset.crs is None else dataset.crs.wkt
map_crs = image_crs if map_crs is None else map_crs
Expand All @@ -118,7 +124,14 @@ def from_dataset(cls,
return cls(transform, image_crs, map_crs, **kwargs)

@classmethod
def from_uri(cls, uri: str, map_crs: Optional[str] = 'epsg:4326',
**kwargs) -> 'RasterioCRSTransformer':
def from_uri(cls, uri: str, map_crs: Optional[str] = 'epsg:4326', **kwargs
) -> Union[IdentityCRSTransformer, 'RasterioCRSTransformer']:
"""Build from raster URI.
Args:
uri (Any): Raster URI.
map_crs (Optional[str]): Target map CRS. Defaults to 'epsg:4326'.
**kwargs: Extra args for :meth:`.__init__`.
"""
with rio.open(uri) as ds:
return cls.from_dataset(ds, map_crs=map_crs, **kwargs)
14 changes: 8 additions & 6 deletions rastervision_core/rastervision/core/data/label/labels.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
"""Defines the abstract Labels class."""

from typing import TYPE_CHECKING, Any, Iterable
from typing import TYPE_CHECKING, Any, Iterable, List
from abc import (ABC, abstractclassmethod, abstractmethod)

if TYPE_CHECKING:
from shapely.geometry import Polygon
from rastervision.core.box import Box


Expand All @@ -15,24 +16,25 @@ class Labels(ABC):
"""

@abstractmethod
def __add__(self, other):
def __add__(self, other: 'Labels'):
"""Add labels to these labels.
Returns a concatenation of this and the other labels.
"""
pass

@abstractmethod
def filter_by_aoi(self, aoi_polygons):
"""Returns a copy of these labels filtered by a given set of AOI polygons
def filter_by_aoi(self, aoi_polygons: List['Polygon']) -> 'Labels':
"""Return a copy of these labels filtered by given AOI polygons.
Args:
aoi_polygons - A list of AOI polygons to filter by, in pixel coordinates.
aoi_polygons: List of AOI polygons to filter by, in pixel
coordinates.
"""
pass

@abstractmethod
def __eq__(self, other):
def __eq__(self, other: 'Labels'):
pass

@abstractmethod
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from rastervision.core.data.label.utils import discard_prediction_edges

if TYPE_CHECKING:
from shapely.geometry import Polygon
from rastervision.core.data import (ClassConfig, CRSTransformer,
VectorOutputConfig)

Expand Down Expand Up @@ -66,21 +67,34 @@ def get_label_arr(self, window: Box,
pass

def get_windows(self, **kwargs) -> List[Box]:
"""Generate sliding windows over the local extent. The keyword args
are passed to Box.get_windows() and can therefore be used to control
the specifications of the windows.
"""Generate sliding windows over the local extent.
The keyword args are passed to :meth:`.Box.get_windows` and can
therefore be used to control the specifications of the windows.
If the keyword args do not contain size, a list of length 1,
containing the full extent is returned.
Args:
**kwargs: Extra args for :meth:`.Box.get_windows`.
"""
size: Optional[int] = kwargs.pop('size', None)
if size is None:
return [self.extent]
return self.extent.get_windows(size, size, **kwargs)

def filter_by_aoi(self, aoi_polygons: list, null_class_id: int,
def filter_by_aoi(self, aoi_polygons: List['Polygon'], null_class_id: int,
**kwargs) -> 'SemanticSegmentationLabels':
"""Keep only the values that lie inside the AOI."""
"""Keep only the values that lie inside the AOI.
Args:
aoi_polygons (List[Polygon]): AOI polygons to filter by, in pixel
coordinates.
null_class_id (int): Class ID to assign to pixels falling outside
the AOI polygons.
**kwargs: Extra args for
:meth:`.SemanticSegmentationLabels.get_windows`.
"""
if not aoi_polygons:
return self
for window in self.get_windows(**kwargs):
Expand All @@ -95,7 +109,7 @@ def mask_fill(self, window: Box, mask: np.ndarray,
"""
pass

def _filter_window_by_aoi(self, window: Box, aoi_polygons: list,
def _filter_window_by_aoi(self, window: Box, aoi_polygons: List['Polygon'],
null_class_id: int) -> None:
window_geom = window.to_shapely()
label_arr = self[window]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -98,11 +98,11 @@ def from_raster_sources(cls,
raster_sources: List['RasterSource'],
sample_prob: Optional[float] = 0.1,
max_stds: float = 3.) -> 'StatsTransformer':
"""Create a StatsTransformer with stats from the given raster sources.
"""Build with stats from the given raster sources.
Args:
raster_sources (List['RasterSource']): List of raster
sources to compute stats from.
raster_sources (List[RasterSource]): List of raster sources to
compute stats from.
sample_prob (float, optional): Fraction of each raster to sample
for computing stats. For details see docs for
RasterStats.compute(). Defaults to 0.1.
Expand All @@ -120,18 +120,44 @@ def from_raster_sources(cls,

@classmethod
def from_stats_json(cls, uri: str, **kwargs) -> 'StatsTransformer':
"""Build with stats from a JSON file.
The file is expected to be in the same format as written by
:meth:`.RasterStats.save`.
Args:
uri (str): URI of the JSON file.
**kwargs: Extra args for :meth:`.__init__`.
Returns:
StatsTransformer: A StatsTransformer.
"""
stats = RasterStats.load(uri)
stats_transformer = StatsTransformer.from_raster_stats(stats, **kwargs)
return stats_transformer

@classmethod
def from_raster_stats(cls, stats: RasterStats,
**kwargs) -> 'StatsTransformer':
"""Build with stats from a :class:`.RasterStats` instance.
The file is expected to be in the same format as written by
:meth:`.RasterStats.save`.
Args:
stats (RasterStats): A :class:`.RasterStats` instance with
non-None stats.
**kwargs: Extra args for :meth:`.__init__`.
Returns:
StatsTransformer: A StatsTransformer.
"""
stats_transformer = StatsTransformer(stats.means, stats.stds, **kwargs)
return stats_transformer

@property
def stats(self):
def stats(self) -> RasterStats:
"""Current statistics as a :class:`.RasterStats` instance."""
return RasterStats(self.means, self.stds)

def __repr__(self) -> str:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,17 @@ def __call__(self,
geojson: dict,
crs_transformer: Optional['CRSTransformer'] = None,
**kwargs) -> dict:
"""Shortcut for :meth:`.transform`.
Args:
geojson (dict): A GeoJSON-like mapping of a FeatureCollection.
crs_transformer (Optional[CRSTransformer]): CRSTransformer.
Defaults to None.
**kwargs: Extra args for :meth:`.transform`.
Returns:
dict: Transformed GeoJSON.
"""
return self.transform(
geojson, crs_transformer=crs_transformer, **kwargs)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ def from_multiclass_conf_mat(cls, conf_mat: np.ndarray, class_id: int,
conf_mat (np.ndarray): A multi-class confusion matrix.
class_id (int): The ID of the target class.
class_name (str): The name of the target class.
**kwargs: Extra args for :meth:`.__init__`.
Returns:
ClassEvaluationItem: ClassEvaluationItem for target class.
Expand Down
1 change: 1 addition & 0 deletions rastervision_core/rastervision/core/raster_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,7 @@ def save(self, stats_uri: str) -> None:

@property
def vars(self) -> Optional[np.ndarray]:
"""Channel variances, if self.stds is set."""
if self.stds is None:
return None
return self.stds**2
Expand Down
10 changes: 7 additions & 3 deletions rastervision_core/rastervision/core/utils/stac.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,13 +128,17 @@ def parse_stac(stac_uri: str, item_limit: Optional[int] = None) -> List[dict]:

def read_stac(uri: str, extract_dir: Optional[str] = None,
**kwargs) -> List[dict]:
"""Parse the contents of a STAC catalog (downloading it first, if
remote). If the uri is a zip file, unzip it, find catalog.json inside it
and parse that.
"""Parse the contents of a STAC catalog.
The file is downloaded if needed. If it is a zip file, it is unzipped and
the catalog.json inside is parsed.
Args:
uri (str): Either a URI to a STAC catalog JSON file or a URI to a zip
file containing a STAC catalog JSON file.
extract_dir (Optional[str]): Dir to extract to, if URI is a zip file.
If None, a temporary dir will be used. Defaults to None.
**kwargs: Extra args for :func:`.parse_stac`.
Raises:
FileNotFoundError: If catalog.json is not found inside the zip file.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,19 @@ class TransformType(Enum):
semantic_segmentation = 'semantic_segmentation'


def apply_transform(transform: A.BasicTransform,
**kwargs) -> Callable[..., dict]:
def apply_transform(transform: A.BasicTransform, **kwargs) -> dict:
"""Apply Albumentations transform to possibly batched images.
In case of batched images, the same transform is applied to all of them.
This is useful for when the images represent a time-series.
Args:
transform (A.BasicTransform): An albumentations transform.
**kwargs: Extra args for ``transform``.
Returns:
dict: Output of ``transform``. If ndim == 4, the transformed image in
the dict is also 4-dimensional.
"""
img = kwargs['image']
if img.ndim == 3:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,13 @@ def get_batch(self, dataset: 'Dataset', batch_sz: int = 4,
This is a convenience method for generating a batch of data to plot.
Returns (x, y) tuple where x is images and y is labels
Args:
dataset (Dataset): A Pytorch Datset.
batch_sz (int): Batch size. Defaults to 4.
**kwargs: Extra args for :class:`~torch.utils.data.DataLoader`.
Returns:
Tuple[Tensor, Any]: (x, y) tuple where x is images and y is labels.
"""
collate_fn = self.get_collate_fn()
dl = DataLoader(dataset, batch_sz, collate_fn=collate_fn, **kwargs)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ def from_model_bundle(cls: Type,
inference rather than the PyTorch weights. Defaults to the
boolean environment variable RASTERVISION_USE_ONNX if set,
False otherwise.
**kwargs: See :meth:`.Learner.__init__`.
**kwargs: Extra args for :meth:`.__init__`.
Raises:
FileNotFoundError: If using custom Albumentations transforms and
Expand Down Expand Up @@ -1101,7 +1101,7 @@ def export_to_onnx(self,
sample_input: Optional[Tensor] = None,
validate_export: bool = True,
**kwargs) -> None:
"""Export model to ONNX format via torch.onnx.export.
"""Export model to ONNX format via :func:`torch.onnx.export`.
Args:
path (str): File path to save to.
Expand All @@ -1110,11 +1110,12 @@ def export_to_onnx(self,
sample_input (Optional[Tensor]): Sample input to the model. If
None, a single batch from any available DataLoader in this
Learner will be used. Defaults to None.
validate_export (bool): If True, use onnx.checker.check_model to
validate exported model. An exception is raised if the check
fails. Defaults to True.
**kwargs (dict): Keyword args to pass to torch.onnx.export. These
override the default values used in the function definition.
validate_export (bool): If True, use
:func:`onnx.checker.check_model` to validate exported model.
An exception is raised if the check fails. Defaults to True.
**kwargs (dict): Keyword args to pass to :func:`torch.onnx.export`.
These override the default values used in the function
definition.
Raises:
ValueError: If sample_input is None and the Learner has no valid
Expand Down Expand Up @@ -1290,7 +1291,12 @@ def load_init_weights(self,
self.load_weights(uri=uri, **args)

def load_weights(self, uri: str, **kwargs) -> None:
"""Load model weights from a file."""
"""Load model weights from a file.
Args:
uri (str): URI.
**kwargs: Extra args for :meth:`nn.Module.load_state_dict`.
"""
weights_path = download_if_needed(uri)
self.model.load_state_dict(
torch.load(weights_path, map_location=self.device), **kwargs)
Expand Down
Loading

0 comments on commit 474dbe6

Please sign in to comment.