From 9c72517483405e6442ad1f255156dab95cd2d17e Mon Sep 17 00:00:00 2001 From: Philipp Allgeuer Date: Mon, 30 May 2022 11:36:50 +0200 Subject: [PATCH 1/2] Fix warning with torch.meshgrid --- mmtrack/core/motion/flow.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mmtrack/core/motion/flow.py b/mmtrack/core/motion/flow.py index 1fd25d735..ac4371896 100644 --- a/mmtrack/core/motion/flow.py +++ b/mmtrack/core/motion/flow.py @@ -22,7 +22,8 @@ def flow_warp_feats(x, flow): # 2. compute the flow_field (grid in the code) used to warp features. H, W = x.shape[-2:] - h_grid, w_grid = torch.meshgrid(torch.arange(H), torch.arange(W)) + h_grid, w_grid = torch.meshgrid( + torch.arange(H), torch.arange(W), indexing='ij') # [1, 1, H, W] h_grid = h_grid.to(flow)[None, None, ...] # [1, 1, H, W] From 9c907bab84e2a0e25005c7317b18fff16fd0e27d Mon Sep 17 00:00:00 2001 From: Philipp Allgeuer Date: Tue, 7 Jun 2022 11:51:25 +0200 Subject: [PATCH 2/2] Add torch_meshgrid_ij wrapper --- mmtrack/core/motion/flow.py | 5 +++-- mmtrack/core/utils/misc.py | 13 +++++++++++++ 2 files changed, 16 insertions(+), 2 deletions(-) diff --git a/mmtrack/core/motion/flow.py b/mmtrack/core/motion/flow.py index ac4371896..817a4da9c 100644 --- a/mmtrack/core/motion/flow.py +++ b/mmtrack/core/motion/flow.py @@ -1,6 +1,8 @@ # Copyright (c) OpenMMLab. All rights reserved. import torch +from ..utils.misc import torch_meshgrid_ij + def flow_warp_feats(x, flow): """Use flow to warp feature map. @@ -22,8 +24,7 @@ def flow_warp_feats(x, flow): # 2. compute the flow_field (grid in the code) used to warp features. H, W = x.shape[-2:] - h_grid, w_grid = torch.meshgrid( - torch.arange(H), torch.arange(W), indexing='ij') + h_grid, w_grid = torch_meshgrid_ij(torch.arange(H), torch.arange(W)) # [1, 1, H, W] h_grid = h_grid.to(flow)[None, None, ...] # [1, 1, H, W] diff --git a/mmtrack/core/utils/misc.py b/mmtrack/core/utils/misc.py index 8cf7b325d..49f479dd8 100644 --- a/mmtrack/core/utils/misc.py +++ b/mmtrack/core/utils/misc.py @@ -5,6 +5,8 @@ import warnings import cv2 +import torch +from packaging import version def setup_multi_processes(cfg): @@ -37,3 +39,14 @@ def setup_multi_processes(cfg): f'overloaded, please further tune the variable for optimal ' f'performance in your application as needed.') os.environ['MKL_NUM_THREADS'] = str(mkl_num_threads) + + +_torch_version_meshgrid_indexing = version.parse( + torch.__version__) >= version.parse('1.10.0a0') + + +def torch_meshgrid_ij(*tensors): + if _torch_version_meshgrid_indexing: + return torch.meshgrid(*tensors, indexing='ij') + else: + return torch.meshgrid(*tensors) # Uses indexing='ij' by default