diff --git a/mmtrack/core/motion/flow.py b/mmtrack/core/motion/flow.py index 1fd25d735..817a4da9c 100644 --- a/mmtrack/core/motion/flow.py +++ b/mmtrack/core/motion/flow.py @@ -1,6 +1,8 @@ # Copyright (c) OpenMMLab. All rights reserved. import torch +from ..utils.misc import torch_meshgrid_ij + def flow_warp_feats(x, flow): """Use flow to warp feature map. @@ -22,7 +24,7 @@ def flow_warp_feats(x, flow): # 2. compute the flow_field (grid in the code) used to warp features. H, W = x.shape[-2:] - h_grid, w_grid = torch.meshgrid(torch.arange(H), torch.arange(W)) + h_grid, w_grid = torch_meshgrid_ij(torch.arange(H), torch.arange(W)) # [1, 1, H, W] h_grid = h_grid.to(flow)[None, None, ...] # [1, 1, H, W] diff --git a/mmtrack/core/utils/misc.py b/mmtrack/core/utils/misc.py index 8cf7b325d..49f479dd8 100644 --- a/mmtrack/core/utils/misc.py +++ b/mmtrack/core/utils/misc.py @@ -5,6 +5,8 @@ import warnings import cv2 +import torch +from packaging import version def setup_multi_processes(cfg): @@ -37,3 +39,14 @@ def setup_multi_processes(cfg): f'overloaded, please further tune the variable for optimal ' f'performance in your application as needed.') os.environ['MKL_NUM_THREADS'] = str(mkl_num_threads) + + +_torch_version_meshgrid_indexing = version.parse( + torch.__version__) >= version.parse('1.10.0a0') + + +def torch_meshgrid_ij(*tensors): + if _torch_version_meshgrid_indexing: + return torch.meshgrid(*tensors, indexing='ij') + else: + return torch.meshgrid(*tensors) # Uses indexing='ij' by default