11# Copyright (c) OpenMMLab. All rights reserved.
22import os
33import os .path as osp
4- import sys
54import tempfile
65from contextlib import contextmanager
76from copy import deepcopy
@@ -49,62 +48,75 @@ def build_temporary_directory():
4948 yield tmp_dir
5049
5150
51+ # Check if petrel_client is available
52+ PETREL_CLIENT_AVAILABLE = False
5253try :
53- # Other unit tests may mock these modules so we need to pop them first.
54- sys .modules .pop ('petrel_client' , None )
55- sys .modules .pop ('petrel_client.client' , None )
56-
57- # If petrel_client is imported successfully, we can test PetrelBackend
58- # without mock.
5954 import petrel_client # noqa: F401
60- except ImportError :
61- sys .modules ['petrel_client' ] = MagicMock ()
62- sys .modules ['petrel_client.client' ] = MagicMock ()
63-
64- class MockPetrelClient :
65-
66- def __init__ (self ,
67- enable_mc = True ,
68- enable_multi_cluster = False ,
69- conf_path = None ):
70- self .enable_mc = enable_mc
71- self .enable_multi_cluster = enable_multi_cluster
72- self .conf_path = conf_path
73-
74- def Get (self , filepath ):
75- with open (filepath , 'rb' ) as f :
76- content = f .read ()
77- return content
78-
79- def put (self ):
80- pass
55+ PETREL_CLIENT_AVAILABLE = True
56+ except (ImportError , ModuleNotFoundError ):
57+ PETREL_CLIENT_AVAILABLE = False
8158
82- def delete (self ):
83- pass
8459
85- def contains ( self ) :
86- pass
60+ class MockPetrelClient :
61+ """Mock PetrelClient for testing when petrel_client is not available."""
8762
88- def isdir (self ):
89- pass
63+ def __init__ (self ,
64+ enable_mc = True ,
65+ enable_multi_cluster = False ,
66+ conf_path = None ):
67+ self .enable_mc = enable_mc
68+ self .enable_multi_cluster = enable_multi_cluster
69+ self .conf_path = conf_path
70+
71+ def Get (self , filepath ):
72+ with open (filepath , 'rb' ) as f :
73+ content = f .read ()
74+ return content
75+
76+ def put (self , filepath , content ):
77+ pass
78+
79+ def delete (self , filepath ):
80+ pass
81+
82+ def contains (self , filepath ):
83+ pass
84+
85+ def isdir (self , filepath ):
86+ pass
87+
88+ def list (self , dir_path ):
89+ for entry in os .scandir (dir_path ):
90+ if entry .name .startswith ('.' ):
91+ continue
92+ if entry .is_file ():
93+ yield entry .name
94+ elif entry .is_dir ():
95+ yield entry .name + '/'
9096
91- def list (self , dir_path ):
92- for entry in os .scandir (dir_path ):
93- if not entry .name .startswith ('.' ) and entry .is_file ():
94- yield entry .name
95- elif osp .isdir (entry .path ):
96- yield entry .name + '/'
9797
98- @contextmanager
99- def delete_and_reset_method (obj , method ):
98+ @contextmanager
99+ def delete_and_reset_method (obj , method ):
100+ if hasattr (obj , '_mock_methods' ) or str (type (obj ).__name__ ) == 'MagicMock' :
101+ method_obj = deepcopy (getattr (obj , method ))
102+ try :
103+ delattr (obj , method )
104+ yield
105+ finally :
106+ setattr (obj , method , method_obj )
107+ else :
100108 method_obj = deepcopy (getattr (type (obj ), method ))
101109 try :
102110 delattr (type (obj ), method )
103111 yield
104112 finally :
105113 setattr (type (obj ), method , method_obj )
106114
107- @patch ('petrel_client.client.Client' , MockPetrelClient )
115+
116+ if not PETREL_CLIENT_AVAILABLE :
117+ # Define the test class that uses mocking when
118+ # petrel_client is not available
119+
108120 class TestPetrelBackend (TestCase ):
109121
110122 @classmethod
@@ -118,6 +130,24 @@ def setUpClass(cls):
118130 cls .expected_dir = 's3://user/data'
119131 cls .expected_path = f'{ cls .expected_dir } /test.jpg'
120132
133+ def setUp (self ):
134+ # Mock petrel_client for each test
135+ self .mock_petrel_client = MagicMock ()
136+ self .mock_client_module = MagicMock ()
137+ self .mock_client_module .Client = MockPetrelClient
138+ self .mock_petrel_client .client = self .mock_client_module
139+
140+ self .patcher_petrel = patch .dict (
141+ 'sys.modules' , {
142+ 'petrel_client' : self .mock_petrel_client ,
143+ 'petrel_client.client' : self .mock_client_module
144+ })
145+ self .patcher_petrel .start ()
146+
147+ def tearDown (self ):
148+ # Clean up the mock
149+ self .patcher_petrel .stop ()
150+
121151 def test_name (self ):
122152 backend = PetrelBackend ()
123153 self .assertEqual (backend .name , 'PetrelBackend' )
@@ -563,6 +593,7 @@ def test_generate_presigned_url(self):
563593 pass
564594
565595else :
596+ # Define the test class that uses real petrel_client when available
566597
567598 class TestPetrelBackend (TestCase ): # type: ignore
568599
0 commit comments