From 8bac07d6ee30fe39eeba2a170f82d2f1e776ceb9 Mon Sep 17 00:00:00 2001 From: Adeel Hassan Date: Fri, 21 Jul 2023 12:27:05 -0400 Subject: [PATCH] disambiguate the two BoxList classes --- .../data/label/object_detection_labels.py | 10 +-- .../core/data/label/tfod_utils/np_box_list.py | 2 +- .../data/label/tfod_utils/np_box_list_ops.py | 71 ++++++++++--------- .../label/test_object_detection_labels.py | 4 +- 4 files changed, 45 insertions(+), 42 deletions(-) diff --git a/rastervision_core/rastervision/core/data/label/object_detection_labels.py b/rastervision_core/rastervision/core/data/label/object_detection_labels.py index 0093d6703..d3289a164 100644 --- a/rastervision_core/rastervision/core/data/label/object_detection_labels.py +++ b/rastervision_core/rastervision/core/data/label/object_detection_labels.py @@ -4,7 +4,7 @@ from rastervision.core.box import Box from rastervision.core.data.label.labels import Labels -from rastervision.core.data.label.tfod_utils.np_box_list import BoxList +from rastervision.core.data.label.tfod_utils.np_box_list import NpBoxList from rastervision.core.data.label.tfod_utils.np_box_list_ops import ( prune_non_overlapping_boxes, clip_to_window, concatenate, non_max_suppression) @@ -33,7 +33,7 @@ def __init__(self, class_ids: int numpy array of size n with class ids scores: float numpy array of size n """ - self.boxlist = BoxList(npboxes) + self.boxlist = NpBoxList(npboxes) # This field name actually needs to be 'classes' to be able to use # certain utility functions in the TF Object Detection API. self.boxlist.add_field('classes', class_ids) @@ -103,7 +103,7 @@ def make_empty(cls) -> 'ObjectDetectionLabels': return cls(npboxes, class_ids, scores) @staticmethod - def from_boxlist(boxlist: BoxList): + def from_boxlist(boxlist: NpBoxList): """Make ObjectDetectionLabels from BoxList object.""" scores = (boxlist.get_field('scores') if boxlist.has_field('scores') else None) @@ -166,7 +166,7 @@ def __len__(self) -> int: def __str__(self) -> str: return str(self.boxlist.get()) - def to_boxlist(self) -> BoxList: + def to_boxlist(self) -> NpBoxList: return self.boxlist def to_dict(self, round_boxes: bool = True) -> dict: @@ -245,7 +245,7 @@ def get_overlapping(labels: 'ObjectDetectionLabels', clip: If True, clip label boxes to the window. """ window_npbox = window.npbox_format() - window_boxlist = BoxList(np.expand_dims(window_npbox, axis=0)) + window_boxlist = NpBoxList(np.expand_dims(window_npbox, axis=0)) boxlist = prune_non_overlapping_boxes( labels.boxlist, window_boxlist, minoverlap=ioa_thresh) if clip: diff --git a/rastervision_core/rastervision/core/data/label/tfod_utils/np_box_list.py b/rastervision_core/rastervision/core/data/label/tfod_utils/np_box_list.py index cf89338cf..7e9a572db 100644 --- a/rastervision_core/rastervision/core/data/label/tfod_utils/np_box_list.py +++ b/rastervision_core/rastervision/core/data/label/tfod_utils/np_box_list.py @@ -19,7 +19,7 @@ import numpy as np -class BoxList(object): +class NpBoxList(object): """A list of bounding boxes as a [y_min, x_min, y_max, x_max] numpy array. It is assumed that all bounding boxes within a given list correspond to a diff --git a/rastervision_core/rastervision/core/data/label/tfod_utils/np_box_list_ops.py b/rastervision_core/rastervision/core/data/label/tfod_utils/np_box_list_ops.py index 5a39de9eb..a84e9fe64 100644 --- a/rastervision_core/rastervision/core/data/label/tfod_utils/np_box_list_ops.py +++ b/rastervision_core/rastervision/core/data/label/tfod_utils/np_box_list_ops.py @@ -18,7 +18,7 @@ from tqdm.auto import trange import numpy as np -from rastervision.core.data.label.tfod_utils.np_box_list import BoxList +from rastervision.core.data.label.tfod_utils.np_box_list import NpBoxList from rastervision.core.data.label.tfod_utils import np_box_ops @@ -33,7 +33,7 @@ class SortOrder(object): DESCEND = 2 -def area(boxlist: BoxList) -> np.ndarray: +def area(boxlist: NpBoxList) -> np.ndarray: """Computes area of boxes. Args: @@ -46,7 +46,7 @@ def area(boxlist: BoxList) -> np.ndarray: return (y_max - y_min) * (x_max - x_min) -def intersection(boxlist1: BoxList, boxlist2: BoxList) -> np.ndarray: +def intersection(boxlist1: NpBoxList, boxlist2: NpBoxList) -> np.ndarray: """Compute pairwise intersection areas between boxes. Args: @@ -60,7 +60,7 @@ def intersection(boxlist1: BoxList, boxlist2: BoxList) -> np.ndarray: return np_box_ops.intersection(boxlist1.get(), boxlist2.get()) -def iou(boxlist1: BoxList, boxlist2: BoxList) -> np.ndarray: +def iou(boxlist1: NpBoxList, boxlist2: NpBoxList) -> np.ndarray: """Computes pairwise intersection-over-union between box collections. Args: @@ -73,7 +73,7 @@ def iou(boxlist1: BoxList, boxlist2: BoxList) -> np.ndarray: return np_box_ops.iou(boxlist1.get(), boxlist2.get()) -def ioa(boxlist1: BoxList, boxlist2: BoxList) -> np.ndarray: +def ioa(boxlist1: NpBoxList, boxlist2: NpBoxList) -> np.ndarray: """Computes pairwise intersection-over-area between box collections. Intersection-over-area (ioa) between two boxes box1 and box2 is defined as @@ -91,9 +91,9 @@ def ioa(boxlist1: BoxList, boxlist2: BoxList) -> np.ndarray: return np_box_ops.ioa(boxlist1.get(), boxlist2.get()) -def gather(boxlist: BoxList, +def gather(boxlist: NpBoxList, indices: np.ndarray, - fields: Optional[List[str]] = None) -> BoxList: + fields: Optional[List[str]] = None) -> NpBoxList: """Gather boxes from BoxList according to indices and return new BoxList. By default, gather returns boxes corresponding to the input index list, as @@ -119,7 +119,7 @@ def gather(boxlist: BoxList, if indices.size: if np.amax(indices) >= boxlist.num_boxes() or np.amin(indices) < 0: raise ValueError('indices are out of valid range.') - subboxlist = BoxList(boxlist.get()[indices, :]) + subboxlist = NpBoxList(boxlist.get()[indices, :]) if fields is None: fields = boxlist.get_extra_fields() for field in fields: @@ -128,7 +128,7 @@ def gather(boxlist: BoxList, return subboxlist -def sort_by_field(boxlist: BoxList, +def sort_by_field(boxlist: NpBoxList, field: str, order: SortOrder = SortOrder.DESCEND): """Sort boxes and associated fields according to a scalar field. @@ -161,10 +161,10 @@ def sort_by_field(boxlist: BoxList, return gather(boxlist, sorted_indices) -def non_max_suppression(boxlist: BoxList, +def non_max_suppression(boxlist: NpBoxList, max_output_size: int = 10_000, iou_threshold: float = 1.0, - score_threshold: float = -10.0) -> BoxList: + score_threshold: float = -10.0) -> NpBoxList: """Non maximum suppression. This op greedily selects a subset of detection bounding boxes, pruning away boxes that have high IOU (intersection over union) overlap (> thresh) @@ -239,9 +239,9 @@ def non_max_suppression(boxlist: BoxList, return gather(boxlist, np.array(selected_indices)) -def multi_class_non_max_suppression(boxlist: BoxList, score_thresh: float, +def multi_class_non_max_suppression(boxlist: NpBoxList, score_thresh: float, iou_thresh: float, - max_output_size: int) -> BoxList: + max_output_size: int) -> NpBoxList: """Multi-class version of non maximum suppression. This op greedily selects a subset of detection bounding boxes, pruning @@ -277,7 +277,7 @@ def multi_class_non_max_suppression(boxlist: BoxList, score_thresh: float, """ if not 0 <= iou_thresh <= 1.0: raise ValueError('thresh must be between 0 and 1') - if not isinstance(boxlist, BoxList): + if not isinstance(boxlist, NpBoxList): raise ValueError('boxlist must be a BoxList') if not boxlist.has_field('scores'): raise ValueError('input boxlist must have \'scores\' field') @@ -300,7 +300,7 @@ def multi_class_non_max_suppression(boxlist: BoxList, score_thresh: float, selected_boxes_list = [] for class_idx in range(num_classes): - boxlist_and_class_scores = BoxList(boxlist.get()) + boxlist_and_class_scores = NpBoxList(boxlist.get()) class_scores = np.reshape(scores[0:num_scores, class_idx], [-1]) boxlist_and_class_scores.add_field('scores', class_scores) boxlist_filt = filter_scores_greater_than(boxlist_and_class_scores, @@ -319,7 +319,7 @@ def multi_class_non_max_suppression(boxlist: BoxList, score_thresh: float, return sorted_boxes -def scale(boxlist: BoxList, y_scale: float, x_scale: float) -> BoxList: +def scale(boxlist: NpBoxList, y_scale: float, x_scale: float) -> NpBoxList: """Scale box coordinates in x and y dimensions. Args: @@ -335,7 +335,7 @@ def scale(boxlist: BoxList, y_scale: float, x_scale: float) -> BoxList: y_max = y_scale * y_max x_min = x_scale * x_min x_max = x_scale * x_max - scaled_boxlist = BoxList(np.hstack([y_min, x_min, y_max, x_max])) + scaled_boxlist = NpBoxList(np.hstack([y_min, x_min, y_max, x_max])) fields = boxlist.get_extra_fields() for field in fields: @@ -345,7 +345,7 @@ def scale(boxlist: BoxList, y_scale: float, x_scale: float) -> BoxList: return scaled_boxlist -def clip_to_window(boxlist: BoxList, window: np.ndarray) -> BoxList: +def clip_to_window(boxlist: NpBoxList, window: np.ndarray) -> NpBoxList: """Clip bounding boxes to a window. This op clips input bounding boxes (represented by bounding box @@ -370,7 +370,7 @@ def clip_to_window(boxlist: BoxList, window: np.ndarray) -> BoxList: y_max_clipped = np.fmax(np.fmin(y_max, win_y_max), win_y_min) x_min_clipped = np.fmax(np.fmin(x_min, win_x_max), win_x_min) x_max_clipped = np.fmax(np.fmin(x_max, win_x_max), win_x_min) - clipped = BoxList( + clipped = NpBoxList( np.hstack([y_min_clipped, x_min_clipped, y_max_clipped, x_max_clipped])) clipped = _copy_extra_fields(clipped, boxlist) @@ -380,9 +380,9 @@ def clip_to_window(boxlist: BoxList, window: np.ndarray) -> BoxList: return gather(clipped, nonzero_area_indices) -def prune_non_overlapping_boxes(boxlist1: BoxList, - boxlist2: BoxList, - minoverlap: float = 0.0) -> BoxList: +def prune_non_overlapping_boxes(boxlist1: NpBoxList, + boxlist2: NpBoxList, + minoverlap: float = 0.0) -> NpBoxList: """Prunes boxes with insufficient overlap b/w boxlists. Prunes the boxes in boxlist1 that overlap less than thresh with boxlist2. @@ -407,8 +407,8 @@ def prune_non_overlapping_boxes(boxlist1: BoxList, return new_boxlist1 -def prune_outside_window(boxlist: BoxList, - window: np.ndarray) -> Tuple[BoxList, np.ndarray]: +def prune_outside_window(boxlist: NpBoxList, + window: np.ndarray) -> Tuple[NpBoxList, np.ndarray]: """Prunes bounding boxes that fall outside a given window. This function prunes bounding boxes that even partially fall outside the @@ -444,8 +444,8 @@ def prune_outside_window(boxlist: BoxList, return pruned_boxlist, valid_indices -def concatenate(boxlists: List[BoxList], - fields: Optional[List[str]] = None) -> BoxList: +def concatenate(boxlists: List[NpBoxList], + fields: Optional[List[str]] = None) -> NpBoxList: """Concatenate list of BoxLists. This op concatenates a list of input BoxLists into a larger BoxList. It also @@ -472,10 +472,11 @@ def concatenate(boxlists: List[BoxList], if not boxlists: raise ValueError('boxlists should have nonzero length') for boxlist in boxlists: - if not isinstance(boxlist, BoxList): + if not isinstance(boxlist, NpBoxList): raise ValueError( 'all elements of boxlists should be BoxList objects') - concatenated = BoxList(np.vstack([boxlist.get() for boxlist in boxlists])) + concatenated = NpBoxList( + np.vstack([boxlist.get() for boxlist in boxlists])) if fields is None: fields = boxlists[0].get_extra_fields() for field in fields: @@ -496,7 +497,7 @@ def concatenate(boxlists: List[BoxList], return concatenated -def filter_scores_greater_than(boxlist: BoxList, thresh: float) -> BoxList: +def filter_scores_greater_than(boxlist: NpBoxList, thresh: float) -> NpBoxList: """Filter to keep only boxes with score exceeding a given threshold. This op keeps the collection of boxes whose corresponding scores are @@ -514,7 +515,7 @@ def filter_scores_greater_than(boxlist: BoxList, thresh: float) -> BoxList: ValueError: If boxlist not a BoxList object or if it does not have a scores field. """ - if not isinstance(boxlist, BoxList): + if not isinstance(boxlist, NpBoxList): raise ValueError('boxlist must be a BoxList') if not boxlist.has_field('scores'): raise ValueError('input boxlist must have \'scores\' field') @@ -529,7 +530,8 @@ def filter_scores_greater_than(boxlist: BoxList, thresh: float) -> BoxList: return gather(boxlist, high_score_indices) -def change_coordinate_frame(boxlist: BoxList, window: np.ndarray) -> BoxList: +def change_coordinate_frame(boxlist: NpBoxList, + window: np.ndarray) -> NpBoxList: """Change coordinate frame of the boxlist to be relative to window's frame. Given a window of the form [ymin, xmin, ymax, xmax], @@ -550,15 +552,16 @@ def change_coordinate_frame(boxlist: BoxList, window: np.ndarray) -> BoxList: win_height = window[2] - window[0] win_width = window[3] - window[1] boxlist_new = scale( - BoxList(boxlist.get() - [window[0], window[1], window[0], window[1]]), + NpBoxList(boxlist.get() - + [window[0], window[1], window[0], window[1]]), 1.0 / win_height, 1.0 / win_width) _copy_extra_fields(boxlist_new, boxlist) return boxlist_new -def _copy_extra_fields(boxlist_to_copy_to: BoxList, - boxlist_to_copy_from: BoxList) -> BoxList: +def _copy_extra_fields(boxlist_to_copy_to: NpBoxList, + boxlist_to_copy_from: NpBoxList) -> NpBoxList: """Copies the extra fields of boxlist_to_copy_from to boxlist_to_copy_to. Args: diff --git a/tests/core/data/label/test_object_detection_labels.py b/tests/core/data/label/test_object_detection_labels.py index d9044a8e0..c1da2bde3 100644 --- a/tests/core/data/label/test_object_detection_labels.py +++ b/tests/core/data/label/test_object_detection_labels.py @@ -6,7 +6,7 @@ from rastervision.core.data.class_config import ClassConfig from rastervision.core.data.label.object_detection_labels import ( ObjectDetectionLabels) -from rastervision.core.data.label.tfod_utils.np_box_list import BoxList +from rastervision.core.data.label.tfod_utils.np_box_list import NpBoxList class TestObjectDetectionLabels(unittest.TestCase): @@ -22,7 +22,7 @@ def setUp(self): self.npboxes, self.class_ids, scores=self.scores) def test_from_boxlist(self): - boxlist = BoxList(self.npboxes) + boxlist = NpBoxList(self.npboxes) boxlist.add_field('classes', self.class_ids) boxlist.add_field('scores', self.scores) labels = ObjectDetectionLabels.from_boxlist(boxlist)