Skip to content

Commit af1042c

Browse files
committed
async_ssh: Use async.Semaphore instead of manual locking
1 parent d8567fc commit af1042c

File tree

1 file changed

+47
-52
lines changed

1 file changed

+47
-52
lines changed

src/aiida/transports/plugins/ssh_async.py

Lines changed: 47 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,7 @@ def __init__(self, *args, **kwargs):
133133
# a computer with core.ssh_async transport plugin should be configured before any instantiation.
134134
self.machine = kwargs.pop('host', kwargs.pop('machine'))
135135
self._max_io_allowed = kwargs.pop('max_io_allowed', self._DEFAULT_max_io_allowed)
136+
self._semaphore = asyncio.Semaphore(self._max_io_allowed)
136137
self.script_before = kwargs.pop('script_before', 'None')
137138

138139
if kwargs.get('backend') == 'openssh':
@@ -145,20 +146,10 @@ def __init__(self, *args, **kwargs):
145146

146147
self.async_backend = _AsyncSSH(self.machine, self.logger, self._bash_command_str) # type: ignore[assignment]
147148

148-
self._concurrent_io = 0
149-
150149
@property
151150
def max_io_allowed(self):
152151
return self._max_io_allowed
153152

154-
async def _lock(self, sleep_time=0.5):
155-
while self._concurrent_io >= self.max_io_allowed:
156-
await asyncio.sleep(sleep_time)
157-
self._concurrent_io += 1
158-
159-
async def _unlock(self):
160-
self._concurrent_io -= 1
161-
162153
async def open_async(self):
163154
"""Open the transport.
164155
This plugin supports running scripts before and during the connection.
@@ -316,14 +307,17 @@ async def getfile_async(
316307
if os.path.isfile(localpath) and not overwrite:
317308
raise OSError('Destination already exists: not overwriting it')
318309

319-
try:
320-
await self._lock()
321-
await self.async_backend.get(
322-
remotepath=remotepath, localpath=localpath, dereference=dereference, preserve=preserve, recursive=False
323-
)
324-
await self._unlock()
325-
except OSError as exc:
326-
raise OSError(f'Error while downloading file {remotepath}: {exc}')
310+
async with self._semaphore:
311+
try:
312+
await self.async_backend.get(
313+
remotepath=remotepath,
314+
localpath=localpath,
315+
dereference=dereference,
316+
preserve=preserve,
317+
recursive=False,
318+
)
319+
except OSError as exc:
320+
raise OSError(f'Error while downloading file {remotepath}: {exc}')
327321

328322
async def gettree_async(
329323
self,
@@ -383,19 +377,18 @@ async def gettree_async(
383377

384378
content_list = await self.listdir_async(remotepath)
385379
for content_ in content_list:
386-
try:
387-
await self._lock()
388-
parentpath = str(PurePath(remotepath) / content_)
389-
await self.async_backend.get(
390-
remotepath=parentpath,
391-
localpath=localpath,
392-
dereference=dereference,
393-
preserve=preserve,
394-
recursive=True,
395-
)
396-
await self._unlock()
397-
except OSError as exc:
398-
raise OSError(f'Error while downloading file {parentpath}: {exc}')
380+
parentpath = str(PurePath(remotepath) / content_)
381+
async with self._semaphore:
382+
try:
383+
await self.async_backend.get(
384+
remotepath=parentpath,
385+
localpath=localpath,
386+
dereference=dereference,
387+
preserve=preserve,
388+
recursive=True,
389+
)
390+
except OSError as exc:
391+
raise OSError(f'Error while downloading file {parentpath}: {exc}')
399392

400393
async def put_async(
401394
self,
@@ -528,14 +521,17 @@ async def putfile_async(
528521
if await self.isfile_async(remotepath) and not overwrite:
529522
raise OSError('Destination already exists: not overwriting it')
530523

531-
try:
532-
await self._lock()
533-
await self.async_backend.put(
534-
localpath=localpath, remotepath=remotepath, dereference=dereference, preserve=preserve, recursive=False
535-
)
536-
await self._unlock()
537-
except OSError as exc:
538-
raise OSError(f'Error while uploading file {localpath}: {exc}')
524+
async with self._semaphore:
525+
try:
526+
await self.async_backend.put(
527+
localpath=localpath,
528+
remotepath=remotepath,
529+
dereference=dereference,
530+
preserve=preserve,
531+
recursive=False,
532+
)
533+
except OSError as exc:
534+
raise OSError(f'Error while uploading file {localpath}: {exc}')
539535

540536
async def puttree_async(
541537
self,
@@ -598,19 +594,18 @@ async def puttree_async(
598594
# Or to put and rename the parent folder at the same time
599595
content_list = os.listdir(localpath)
600596
for content_ in content_list:
601-
try:
602-
await self._lock()
603-
parentpath = str(PurePath(localpath) / content_)
604-
await self.async_backend.put(
605-
localpath=parentpath,
606-
remotepath=remotepath,
607-
dereference=dereference,
608-
preserve=preserve,
609-
recursive=True,
610-
)
611-
await self._unlock()
612-
except OSError as exc:
613-
raise OSError(f'Error while uploading file {parentpath}: {exc}')
597+
parentpath = str(PurePath(localpath) / content_)
598+
async with self._semaphore:
599+
try:
600+
await self.async_backend.put(
601+
localpath=parentpath,
602+
remotepath=remotepath,
603+
dereference=dereference,
604+
preserve=preserve,
605+
recursive=True,
606+
)
607+
except OSError as exc:
608+
raise OSError(f'Error while uploading file {parentpath}: {exc}')
614609

615610
async def copy_async(
616611
self,

0 commit comments

Comments
 (0)