1
1
from typing import Any , AsyncContextManager , AsyncIterator , Dict , List , Optional , Set , Tuple , Type , Union
2
2
from types import TracebackType
3
3
4
+ import aiohttp
4
5
import abc
5
6
import re
6
7
import os
@@ -372,7 +373,13 @@ async def wrapped(self: 'AzureAsyncFS', url, *args, **kwargs):
372
373
class AzureAsyncFS (AsyncFS ):
373
374
PATH_REGEX = re .compile ('/(?P<container>[^/]+)(?P<name>.*)' )
374
375
375
- def __init__ (self , * , credential_file : Optional [str ] = None , credentials : Optional [AzureCredentials ] = None ):
376
+ def __init__ (
377
+ self ,
378
+ * ,
379
+ credential_file : Optional [str ] = None ,
380
+ credentials : Optional [AzureCredentials ] = None ,
381
+ timeout : Optional [Union [int , float , aiohttp .ClientTimeout ]] = None ,
382
+ ):
376
383
if credentials is None :
377
384
scopes = ['https://storage.azure.com/.default' ]
378
385
if credential_file is not None :
@@ -382,6 +389,16 @@ def __init__(self, *, credential_file: Optional[str] = None, credentials: Option
382
389
elif credential_file is not None :
383
390
raise ValueError ('credential and credential_file cannot both be defined' )
384
391
392
+ if isinstance (timeout , aiohttp .ClientTimeout ):
393
+ self .read_timeout = timeout .sock_read or timeout .total or 5
394
+ self .connection_timeout = timeout .sock_connect or timeout .connect or timeout .total or 5
395
+ elif isinstance (timeout , (int , float )):
396
+ self .read_timeout = timeout
397
+ self .connection_timeout = timeout
398
+ else :
399
+ self .read_timeout = 5
400
+ self .connection_timeout = 5
401
+
385
402
self ._credential = credentials .credential
386
403
self ._blob_service_clients : Dict [Tuple [str , str , Union [AzureCredentials , str , None ]], BlobServiceClient ] = {}
387
404
@@ -482,8 +499,8 @@ def get_blob_service_client(self, account: str, container: str, token: Optional[
482
499
self ._blob_service_clients [k ] = BlobServiceClient (
483
500
f'https://{ account } .blob.core.windows.net' ,
484
501
credential = credential , # type: ignore
485
- connection_timeout = 5 ,
486
- read_timeout = 5 ,
502
+ connection_timeout = self . connection_timeout ,
503
+ read_timeout = self . read_timeout ,
487
504
)
488
505
return self ._blob_service_clients [k ]
489
506
0 commit comments