Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 46 additions & 9 deletions unstructured_inference/models/yolox.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@
download_if_needed_and_get_local_path,
)

# Module-level cache of coordinate grids, keyed by (hsize, wsize), so that repeated calls
# with the same feature-map sizes do not rebuild the meshgrid every time.
# The cached arrays keep the default integer dtype produced by np.arange()/np.meshgrid().
# NOTE: nothing in the visible code evicts entries, so the cache holds one array per
# distinct (hsize, wsize) pair for the lifetime of the process.
_GRID_CACHE: dict = {}

YOLOX_LABEL_MAP = {
0: ElementType.CAPTION,
1: ElementType.FOOTNOTE,
Expand Down Expand Up @@ -177,17 +181,36 @@ def demo_postprocess(outputs, img_size, p6=False):
hsizes = [img_size[0] // stride for stride in strides]
wsizes = [img_size[1] // stride for stride in strides]


# Instead of building full concatenated grids and expanded_strides arrays,
# operate on corresponding slices of `outputs` per stride. This avoids
# large intermediate allocations while preserving the original computation
# order and dtypes.
start = 0
for hsize, wsize, stride in zip(hsizes, wsizes, strides):
xv, yv = np.meshgrid(np.arange(wsize), np.arange(hsize))
grid = np.stack((xv, yv), 2).reshape(1, -1, 2)
grids.append(grid)
shape = grid.shape[:2]
expanded_strides.append(np.full((*shape, 1), stride))
num = hsize * wsize
if num == 0:
# If either dimension is zero, skip this stride (keeps behavior safe).
continue

# Retrieve cached integer grid of shape (1, num, 2)
grid = _get_grid(hsize, wsize)

end = start + num

grids = np.concatenate(grids, 1)
expanded_strides = np.concatenate(expanded_strides, 1)
outputs[..., :2] = (outputs[..., :2] + grids) * expanded_strides
outputs[..., 2:4] = np.exp(outputs[..., 2:4]) * expanded_strides
# Replicate original expression and ordering:
# outputs[..., :2] = (outputs[..., :2] + grids) * expanded_strides
# Using per-slice computation to avoid creating the big concatenated arrays.
slice_xy = outputs[..., start:end, :2]
# Compute (slice_xy + grid) * stride with the same expression order as before.
outputs[..., start:end, :2] = (slice_xy + grid) * stride

# Replicate original expression:
# outputs[..., 2:4] = np.exp(outputs[..., 2:4]) * expanded_strides
slice_wh = outputs[..., start:end, 2:4]
outputs[..., start:end, 2:4] = np.exp(slice_wh) * stride

start = end

return outputs

Expand Down Expand Up @@ -247,3 +270,17 @@ def nms(boxes, scores, nms_thr):
order = order[inds + 1]

return keep


def _get_grid(hsize: int, wsize: int) -> np.ndarray:
    """Return the cached coordinate grid for an ``hsize`` x ``wsize`` feature map.

    The grid is built once per ``(hsize, wsize)`` pair with
    ``np.meshgrid`` + ``np.stack`` + ``reshape`` and then served from
    ``_GRID_CACHE``. It has shape ``(1, hsize * wsize, 2)``; the last axis
    holds ``(x, y)`` and ``x`` varies fastest along the flattened axis. The
    dtype is the default integer dtype produced by ``np.arange``.

    The returned array is the cached object itself and is therefore shared by
    every caller. It is marked read-only so that an accidental in-place write
    raises ``ValueError`` instead of silently corrupting the grid for all
    subsequent calls. Callers that need to modify the grid must copy it first.

    Args:
        hsize: Number of rows of the grid.
        wsize: Number of columns of the grid.

    Returns:
        A read-only integer array of shape ``(1, hsize * wsize, 2)``.
    """
    key = (hsize, wsize)
    grid = _GRID_CACHE.get(key)
    if grid is None:
        xv, yv = np.meshgrid(np.arange(wsize), np.arange(hsize))
        grid = np.stack((xv, yv), 2).reshape(1, -1, 2)
        # The same object is handed to every caller; freeze it so that the
        # cache cannot be mutated through a returned reference.
        grid.setflags(write=False)
        _GRID_CACHE[key] = grid
    return grid