Skip to content

Commit

Permalink
disambiguate the two BoxList classes
Browse files Browse the repository at this point in the history
  • Loading branch information
AdeelH committed Jul 21, 2023
1 parent 61ec67a commit 9b4a149
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 42 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from rastervision.core.box import Box
from rastervision.core.data.label.labels import Labels
from rastervision.core.data.label.tfod_utils.np_box_list import BoxList
from rastervision.core.data.label.tfod_utils.np_box_list import NpBoxList
from rastervision.core.data.label.tfod_utils.np_box_list_ops import (
prune_non_overlapping_boxes, clip_to_window, concatenate,
non_max_suppression)
Expand Down Expand Up @@ -33,7 +33,7 @@ def __init__(self,
class_ids: int numpy array of size n with class ids
scores: float numpy array of size n
"""
self.boxlist = BoxList(npboxes)
self.boxlist = NpBoxList(npboxes)
# This field name actually needs to be 'classes' to be able to use
# certain utility functions in the TF Object Detection API.
self.boxlist.add_field('classes', class_ids)
Expand Down Expand Up @@ -103,7 +103,7 @@ def make_empty(cls) -> 'ObjectDetectionLabels':
return cls(npboxes, class_ids, scores)

@staticmethod
def from_boxlist(boxlist: BoxList):
def from_boxlist(boxlist: NpBoxList):
"""Make ObjectDetectionLabels from BoxList object."""
scores = (boxlist.get_field('scores')
if boxlist.has_field('scores') else None)
Expand Down Expand Up @@ -166,7 +166,7 @@ def __len__(self) -> int:
def __str__(self) -> str:
return str(self.boxlist.get())

def to_boxlist(self) -> BoxList:
def to_boxlist(self) -> NpBoxList:
return self.boxlist

def to_dict(self, round_boxes: bool = True) -> dict:
Expand Down Expand Up @@ -245,7 +245,7 @@ def get_overlapping(labels: 'ObjectDetectionLabels',
clip: If True, clip label boxes to the window.
"""
window_npbox = window.npbox_format()
window_boxlist = BoxList(np.expand_dims(window_npbox, axis=0))
window_boxlist = NpBoxList(np.expand_dims(window_npbox, axis=0))
boxlist = prune_non_overlapping_boxes(
labels.boxlist, window_boxlist, minoverlap=ioa_thresh)
if clip:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import numpy as np


class BoxList(object):
class NpBoxList(object):
"""A list of bounding boxes as a [y_min, x_min, y_max, x_max] numpy array.
It is assumed that all bounding boxes within a given list correspond to a
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from tqdm.auto import trange
import numpy as np

from rastervision.core.data.label.tfod_utils.np_box_list import BoxList
from rastervision.core.data.label.tfod_utils.np_box_list import NpBoxList
from rastervision.core.data.label.tfod_utils import np_box_ops


Expand All @@ -33,7 +33,7 @@ class SortOrder(object):
DESCEND = 2


def area(boxlist: BoxList) -> np.ndarray:
def area(boxlist: NpBoxList) -> np.ndarray:
"""Computes area of boxes.
Args:
Expand All @@ -46,7 +46,7 @@ def area(boxlist: BoxList) -> np.ndarray:
return (y_max - y_min) * (x_max - x_min)


def intersection(boxlist1: BoxList, boxlist2: BoxList) -> np.ndarray:
def intersection(boxlist1: NpBoxList, boxlist2: NpBoxList) -> np.ndarray:
"""Compute pairwise intersection areas between boxes.
Args:
Expand All @@ -60,7 +60,7 @@ def intersection(boxlist1: BoxList, boxlist2: BoxList) -> np.ndarray:
return np_box_ops.intersection(boxlist1.get(), boxlist2.get())


def iou(boxlist1: BoxList, boxlist2: BoxList) -> np.ndarray:
def iou(boxlist1: NpBoxList, boxlist2: NpBoxList) -> np.ndarray:
"""Computes pairwise intersection-over-union between box collections.
Args:
Expand All @@ -73,7 +73,7 @@ def iou(boxlist1: BoxList, boxlist2: BoxList) -> np.ndarray:
return np_box_ops.iou(boxlist1.get(), boxlist2.get())


def ioa(boxlist1: BoxList, boxlist2: BoxList) -> np.ndarray:
def ioa(boxlist1: NpBoxList, boxlist2: NpBoxList) -> np.ndarray:
"""Computes pairwise intersection-over-area between box collections.
Intersection-over-area (ioa) between two boxes box1 and box2 is defined as
Expand All @@ -91,9 +91,9 @@ def ioa(boxlist1: BoxList, boxlist2: BoxList) -> np.ndarray:
return np_box_ops.ioa(boxlist1.get(), boxlist2.get())


def gather(boxlist: BoxList,
def gather(boxlist: NpBoxList,
indices: np.ndarray,
fields: Optional[List[str]] = None) -> BoxList:
fields: Optional[List[str]] = None) -> NpBoxList:
"""Gather boxes from BoxList according to indices and return new BoxList.
By default, gather returns boxes corresponding to the input index list, as
Expand All @@ -119,7 +119,7 @@ def gather(boxlist: BoxList,
if indices.size:
if np.amax(indices) >= boxlist.num_boxes() or np.amin(indices) < 0:
raise ValueError('indices are out of valid range.')
subboxlist = BoxList(boxlist.get()[indices, :])
subboxlist = NpBoxList(boxlist.get()[indices, :])
if fields is None:
fields = boxlist.get_extra_fields()
for field in fields:
Expand All @@ -128,7 +128,7 @@ def gather(boxlist: BoxList,
return subboxlist


def sort_by_field(boxlist: BoxList,
def sort_by_field(boxlist: NpBoxList,
field: str,
order: SortOrder = SortOrder.DESCEND):
"""Sort boxes and associated fields according to a scalar field.
Expand Down Expand Up @@ -161,10 +161,10 @@ def sort_by_field(boxlist: BoxList,
return gather(boxlist, sorted_indices)


def non_max_suppression(boxlist: BoxList,
def non_max_suppression(boxlist: NpBoxList,
max_output_size: int = 10_000,
iou_threshold: float = 1.0,
score_threshold: float = -10.0) -> BoxList:
score_threshold: float = -10.0) -> NpBoxList:
"""Non maximum suppression.
This op greedily selects a subset of detection bounding boxes, pruning
away boxes that have high IOU (intersection over union) overlap (> thresh)
Expand Down Expand Up @@ -239,9 +239,9 @@ def non_max_suppression(boxlist: BoxList,
return gather(boxlist, np.array(selected_indices))


def multi_class_non_max_suppression(boxlist: BoxList, score_thresh: float,
def multi_class_non_max_suppression(boxlist: NpBoxList, score_thresh: float,
iou_thresh: float,
max_output_size: int) -> BoxList:
max_output_size: int) -> NpBoxList:
"""Multi-class version of non maximum suppression.
This op greedily selects a subset of detection bounding boxes, pruning
Expand Down Expand Up @@ -277,7 +277,7 @@ def multi_class_non_max_suppression(boxlist: BoxList, score_thresh: float,
"""
if not 0 <= iou_thresh <= 1.0:
raise ValueError('thresh must be between 0 and 1')
if not isinstance(boxlist, BoxList):
if not isinstance(boxlist, NpBoxList):
raise ValueError('boxlist must be a BoxList')
if not boxlist.has_field('scores'):
raise ValueError('input boxlist must have \'scores\' field')
Expand All @@ -300,7 +300,7 @@ def multi_class_non_max_suppression(boxlist: BoxList, score_thresh: float,

selected_boxes_list = []
for class_idx in range(num_classes):
boxlist_and_class_scores = BoxList(boxlist.get())
boxlist_and_class_scores = NpBoxList(boxlist.get())
class_scores = np.reshape(scores[0:num_scores, class_idx], [-1])
boxlist_and_class_scores.add_field('scores', class_scores)
boxlist_filt = filter_scores_greater_than(boxlist_and_class_scores,
Expand All @@ -319,7 +319,7 @@ def multi_class_non_max_suppression(boxlist: BoxList, score_thresh: float,
return sorted_boxes


def scale(boxlist: BoxList, y_scale: float, x_scale: float) -> BoxList:
def scale(boxlist: NpBoxList, y_scale: float, x_scale: float) -> NpBoxList:
"""Scale box coordinates in x and y dimensions.
Args:
Expand All @@ -335,7 +335,7 @@ def scale(boxlist: BoxList, y_scale: float, x_scale: float) -> BoxList:
y_max = y_scale * y_max
x_min = x_scale * x_min
x_max = x_scale * x_max
scaled_boxlist = BoxList(np.hstack([y_min, x_min, y_max, x_max]))
scaled_boxlist = NpBoxList(np.hstack([y_min, x_min, y_max, x_max]))

fields = boxlist.get_extra_fields()
for field in fields:
Expand All @@ -345,7 +345,7 @@ def scale(boxlist: BoxList, y_scale: float, x_scale: float) -> BoxList:
return scaled_boxlist


def clip_to_window(boxlist: BoxList, window: np.ndarray) -> BoxList:
def clip_to_window(boxlist: NpBoxList, window: np.ndarray) -> NpBoxList:
"""Clip bounding boxes to a window.
This op clips input bounding boxes (represented by bounding box
Expand All @@ -370,7 +370,7 @@ def clip_to_window(boxlist: BoxList, window: np.ndarray) -> BoxList:
y_max_clipped = np.fmax(np.fmin(y_max, win_y_max), win_y_min)
x_min_clipped = np.fmax(np.fmin(x_min, win_x_max), win_x_min)
x_max_clipped = np.fmax(np.fmin(x_max, win_x_max), win_x_min)
clipped = BoxList(
clipped = NpBoxList(
np.hstack([y_min_clipped, x_min_clipped, y_max_clipped,
x_max_clipped]))
clipped = _copy_extra_fields(clipped, boxlist)
Expand All @@ -380,9 +380,9 @@ def clip_to_window(boxlist: BoxList, window: np.ndarray) -> BoxList:
return gather(clipped, nonzero_area_indices)


def prune_non_overlapping_boxes(boxlist1: BoxList,
boxlist2: BoxList,
minoverlap: float = 0.0) -> BoxList:
def prune_non_overlapping_boxes(boxlist1: NpBoxList,
boxlist2: NpBoxList,
minoverlap: float = 0.0) -> NpBoxList:
"""Prunes boxes with insufficient overlap b/w boxlists.
Prunes the boxes in boxlist1 that overlap less than thresh with boxlist2.
Expand All @@ -407,8 +407,8 @@ def prune_non_overlapping_boxes(boxlist1: BoxList,
return new_boxlist1


def prune_outside_window(boxlist: BoxList,
window: np.ndarray) -> Tuple[BoxList, np.ndarray]:
def prune_outside_window(boxlist: NpBoxList,
window: np.ndarray) -> Tuple[NpBoxList, np.ndarray]:
"""Prunes bounding boxes that fall outside a given window.
This function prunes bounding boxes that even partially fall outside the
Expand Down Expand Up @@ -444,8 +444,8 @@ def prune_outside_window(boxlist: BoxList,
return pruned_boxlist, valid_indices


def concatenate(boxlists: List[BoxList],
fields: Optional[List[str]] = None) -> BoxList:
def concatenate(boxlists: List[NpBoxList],
fields: Optional[List[str]] = None) -> NpBoxList:
"""Concatenate list of BoxLists.
This op concatenates a list of input BoxLists into a larger BoxList. It also
Expand All @@ -472,10 +472,11 @@ def concatenate(boxlists: List[BoxList],
if not boxlists:
raise ValueError('boxlists should have nonzero length')
for boxlist in boxlists:
if not isinstance(boxlist, BoxList):
if not isinstance(boxlist, NpBoxList):
raise ValueError(
'all elements of boxlists should be BoxList objects')
concatenated = BoxList(np.vstack([boxlist.get() for boxlist in boxlists]))
concatenated = NpBoxList(
np.vstack([boxlist.get() for boxlist in boxlists]))
if fields is None:
fields = boxlists[0].get_extra_fields()
for field in fields:
Expand All @@ -496,7 +497,7 @@ def concatenate(boxlists: List[BoxList],
return concatenated


def filter_scores_greater_than(boxlist: BoxList, thresh: float) -> BoxList:
def filter_scores_greater_than(boxlist: NpBoxList, thresh: float) -> NpBoxList:
"""Filter to keep only boxes with score exceeding a given threshold.
This op keeps the collection of boxes whose corresponding scores are
Expand All @@ -514,7 +515,7 @@ def filter_scores_greater_than(boxlist: BoxList, thresh: float) -> BoxList:
ValueError: If boxlist not a BoxList object or if it does not have a
scores field.
"""
if not isinstance(boxlist, BoxList):
if not isinstance(boxlist, NpBoxList):
raise ValueError('boxlist must be a BoxList')
if not boxlist.has_field('scores'):
raise ValueError('input boxlist must have \'scores\' field')
Expand All @@ -529,7 +530,8 @@ def filter_scores_greater_than(boxlist: BoxList, thresh: float) -> BoxList:
return gather(boxlist, high_score_indices)


def change_coordinate_frame(boxlist: BoxList, window: np.ndarray) -> BoxList:
def change_coordinate_frame(boxlist: NpBoxList,
window: np.ndarray) -> NpBoxList:
"""Change coordinate frame of the boxlist to be relative to window's frame.
Given a window of the form [ymin, xmin, ymax, xmax],
Expand All @@ -550,15 +552,16 @@ def change_coordinate_frame(boxlist: BoxList, window: np.ndarray) -> BoxList:
win_height = window[2] - window[0]
win_width = window[3] - window[1]
boxlist_new = scale(
BoxList(boxlist.get() - [window[0], window[1], window[0], window[1]]),
NpBoxList(boxlist.get() -
[window[0], window[1], window[0], window[1]]),
1.0 / win_height, 1.0 / win_width)
_copy_extra_fields(boxlist_new, boxlist)

return boxlist_new


def _copy_extra_fields(boxlist_to_copy_to: BoxList,
boxlist_to_copy_from: BoxList) -> BoxList:
def _copy_extra_fields(boxlist_to_copy_to: NpBoxList,
boxlist_to_copy_from: NpBoxList) -> NpBoxList:
"""Copies the extra fields of boxlist_to_copy_from to boxlist_to_copy_to.
Args:
Expand Down
4 changes: 2 additions & 2 deletions tests/core/data/label/test_object_detection_labels.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from rastervision.core.data.class_config import ClassConfig
from rastervision.core.data.label.object_detection_labels import (
ObjectDetectionLabels)
from rastervision.core.data.label.tfod_utils.np_box_list import BoxList
from rastervision.core.data.label.tfod_utils.np_box_list import NpBoxList


class TestObjectDetectionLabels(unittest.TestCase):
Expand All @@ -22,7 +22,7 @@ def setUp(self):
self.npboxes, self.class_ids, scores=self.scores)

def test_from_boxlist(self):
boxlist = BoxList(self.npboxes)
boxlist = NpBoxList(self.npboxes)
boxlist.add_field('classes', self.class_ids)
boxlist.add_field('scores', self.scores)
labels = ObjectDetectionLabels.from_boxlist(boxlist)
Expand Down

0 comments on commit 9b4a149

Please sign in to comment.