Skip to content

Commit

Permalink
misc doc and type hint fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
AdeelH committed Aug 23, 2023
1 parent 41c41cf commit c3ab8e4
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 31 deletions.
2 changes: 1 addition & 1 deletion docs/usage/tutorials/pred_and_eval_ss.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
},
"source": [
"Load a :class:`Learner` with a trained model from bundle -- :meth:`.Learner.from_model_bundle`\n",
"---------------------------------------------------------------------------------------------"
"----------------------------------------------------------------------------------------------"
]
},
{
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from typing import (TYPE_CHECKING, Any, Iterable, List, Optional, Sequence,
Union)
from typing import (TYPE_CHECKING, Any, Iterable, List, Optional, Sequence)
from abc import abstractmethod

import numpy as np
Expand All @@ -23,8 +22,8 @@ def __init__(self, extent: Box, num_classes: int, dtype: np.dtype):
"""Constructor.
Args:
extent (Box): The extent of the region to which
the labels belong, in global coordinates.
extent (Box): The extent of the region to which the labels belong,
in global coordinates.
num_classes (int): Number of classes.
"""
self.extent = extent
Expand Down Expand Up @@ -143,9 +142,8 @@ def transform_shape(x, y, z=None):
del self[window]

@classmethod
def make_empty(cls, extent: Box, num_classes: int, smooth: bool = False
) -> Union['SemanticSegmentationDiscreteLabels',
'SemanticSegmentationSmoothLabels']:
def make_empty(cls, extent: Box, num_classes: int,
smooth: bool = False) -> 'SemanticSegmentationLabels':
"""Instantiate an empty instance.
Args:
Expand All @@ -157,8 +155,7 @@ def make_empty(cls, extent: Box, num_classes: int, smooth: bool = False
SemanticSegmentationDiscreteLabels object. Defaults to False.
Returns:
Union[SemanticSegmentationDiscreteLabels,
SemanticSegmentationSmoothLabels]: If smooth=True, returns a
SemanticSegmentationLabels: If smooth=True, returns a
SemanticSegmentationSmoothLabels. Otherwise, a
SemanticSegmentationDiscreteLabels.
Expand All @@ -174,15 +171,14 @@ def make_empty(cls, extent: Box, num_classes: int, smooth: bool = False
extent=extent, num_classes=num_classes)

@classmethod
def from_predictions(cls,
windows: Iterable['Box'],
predictions: Iterable[Any],
extent: Box,
num_classes: int,
smooth: bool = False,
crop_sz: Optional[int] = None
) -> Union['SemanticSegmentationDiscreteLabels',
'SemanticSegmentationSmoothLabels']:
def from_predictions(
cls,
windows: Iterable['Box'],
predictions: Iterable[Any],
extent: Box,
num_classes: int,
smooth: bool = False,
crop_sz: Optional[int] = None) -> 'SemanticSegmentationLabels':
"""Instantiate from windows and their corresponding predictions.
Args:
Expand All @@ -202,8 +198,7 @@ def from_predictions(cls,
windows. Defaults to None.
Returns:
Union[SemanticSegmentationDiscreteLabels,
SemanticSegmentationSmoothLabels]: If smooth=True, returns a
SemanticSegmentationLabels: If smooth=True, returns a
SemanticSegmentationSmoothLabels. Otherwise, a
SemanticSegmentationDiscreteLabels.
"""
Expand Down Expand Up @@ -349,8 +344,7 @@ def from_predictions(cls,
extent: Box,
num_classes: int,
crop_sz: Optional[int] = None
) -> Union['SemanticSegmentationDiscreteLabels',
'SemanticSegmentationSmoothLabels']:
) -> 'SemanticSegmentationDiscreteLabels':
labels = cls.make_empty(extent, num_classes)
labels.add_predictions(windows, predictions, crop_sz=crop_sz)
return labels
Expand Down Expand Up @@ -522,8 +516,7 @@ def from_predictions(cls,
extent: Box,
num_classes: int,
crop_sz: Optional[int] = None
) -> Union['SemanticSegmentationDiscreteLabels',
'SemanticSegmentationSmoothLabels']:
) -> 'SemanticSegmentationSmoothLabels':
labels = cls.make_empty(extent, num_classes)
labels.add_predictions(windows, predictions, crop_sz=crop_sz)
return labels
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1015,12 +1015,12 @@ def get_data_dirs(self, uri: Union[str, List[str]],
Args:
uri (Union[str, List[str]]): a URI or a list of URIs of one of the
following:
(1) a URI of a directory containing "train", "valid", and
(optinally) "test" subdirectories
(2) a URI of a zip file containing (1)
(3) a list of (2)
(4) a URI of a directory containing zip files
containing (1)
(1) a URI of a directory containing "train", "valid", and
(optinally) "test" subdirectories
(2) a URI of a zip file containing (1)
(3) a list of (2)
(4) a URI of a directory containing zip files containing (1)
Returns:
paths to directories that each contain contents of one zip file
Expand Down

0 comments on commit c3ab8e4

Please sign in to comment.