diff --git a/unstructured_inference/models/yolox.py b/unstructured_inference/models/yolox.py
index 932242ec..a17fd218 100644
--- a/unstructured_inference/models/yolox.py
+++ b/unstructured_inference/models/yolox.py
@@ -20,6 +20,10 @@
     download_if_needed_and_get_local_path,
 )
 
+# Cache of meshgrid index arrays keyed by (hsize, wsize), so repeated calls with the
+# same sizes reuse one grid. Entries are marked read-only and must not be mutated.
+_GRID_CACHE: dict = {}
+
 YOLOX_LABEL_MAP = {
     0: ElementType.CAPTION,
     1: ElementType.FOOTNOTE,
@@ -177,17 +181,30 @@ def demo_postprocess(outputs, img_size, p6=False):
     hsizes = [img_size[0] // stride for stride in strides]
     wsizes = [img_size[1] // stride for stride in strides]
 
+    # Decode `outputs` in place, one stride at a time. Rows along the second-to-last
+    # axis are consumed in stride order (hsize * wsize rows per stride), so no
+    # concatenated grid or per-row stride array has to be materialised.
+    start = 0
     for hsize, wsize, stride in zip(hsizes, wsizes, strides):
-        xv, yv = np.meshgrid(np.arange(wsize), np.arange(hsize))
-        grid = np.stack((xv, yv), 2).reshape(1, -1, 2)
-        grids.append(grid)
-        shape = grid.shape[:2]
-        expanded_strides.append(np.full((*shape, 1), stride))
+        num = hsize * wsize
+        if num == 0:
+            # A zero-sized grid owns no rows of `outputs`; nothing to decode.
+            continue
+
+        # Integer (x, y) indices for this stride, shape (1, num, 2).
+        grid = _get_grid(hsize, wsize)
+
+        end = start + num
 
-    grids = np.concatenate(grids, 1)
-    expanded_strides = np.concatenate(expanded_strides, 1)
-    outputs[..., :2] = (outputs[..., :2] + grids) * expanded_strides
-    outputs[..., 2:4] = np.exp(outputs[..., 2:4]) * expanded_strides
+        # Columns 0-1: add the grid index, then scale by the stride.
+        slice_xy = outputs[..., start:end, :2]
+        outputs[..., start:end, :2] = (slice_xy + grid) * stride
+
+        # Columns 2-3: exponentiate, then scale by the stride.
+        slice_wh = outputs[..., start:end, 2:4]
+        outputs[..., start:end, 2:4] = np.exp(slice_wh) * stride
+
+        start = end
 
     return outputs
 
@@ -247,3 +264,22 @@ def nms(boxes, scores, nms_thr):
         order = order[inds + 1]
 
     return keep
+
+
+def _get_grid(hsize: int, wsize: int) -> np.ndarray:
+    """Return the cached (x, y) index grid for the given height and width.
+
+    The grid has shape ``(1, hsize * wsize, 2)`` and the default integer dtype of
+    ``np.arange``. It is built on first use, then shared between calls, so it is
+    marked read-only; copy it before modifying.
+    """
+    key = (hsize, wsize)
+    grid = _GRID_CACHE.get(key)
+    if grid is None:
+        xv, yv = np.meshgrid(np.arange(wsize), np.arange(hsize))
+        grid = np.stack((xv, yv), 2).reshape(1, -1, 2)
+        # The array is shared by every caller; forbid in-place writes so a stray
+        # mutation cannot silently corrupt later results.
+        grid.setflags(write=False)
+        _GRID_CACHE[key] = grid
+    return grid