Skip to content

Chore: Requests handling enhancements #245

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 10 commits into
base: develop
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions automated_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,20 @@
# Fake modules to avoid import errors

requests = type(sys)("requests")
requests_adapters = type(sys)("requests.adapters")
requests.__dict__["Response"] = type(
"Response", (), {"__module__": "requests"}
)
requests.__dict__["adapters"] = requests_adapters
requests_adapters.__dict__["HTTPAdapter"] = type(
"HTTPAdapter", (), {"__module__": "requests.adapters"}
)
requests_adapters.__dict__["Retry"] = type(
"Retry", (), {"__module__": "requests.adapters"}
)

sys.modules["requests"] = requests
sys.modules["requests.adapters"] = requests_adapters
sys.modules["unidecode"] = type(sys)("unidecode")

import ayon_api # noqa: E402
Expand Down
238 changes: 133 additions & 105 deletions ayon_api/server_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
HTTPStatus = None

import requests
from requests.adapters import HTTPAdapter, Retry
try:
# This should be used if 'requests' have it available
from requests.exceptions import JSONDecodeError as RequestsJSONDecodeError
Expand Down Expand Up @@ -476,6 +477,9 @@ def __init__(
if not base_url:
raise ValueError(f"Invalid server URL {str(base_url)}")

self._session = None
self._session_handlers = {}

base_url = base_url.rstrip("/")
self._base_url: str = base_url
self._rest_url: str = f"{base_url}/api"
Expand Down Expand Up @@ -522,17 +526,6 @@ def __init__(

self._graphql_allows_data_in_query = None

self._session = None

self._base_functions_mapping = {
RequestTypes.get: requests.get,
RequestTypes.post: requests.post,
RequestTypes.put: requests.put,
RequestTypes.patch: requests.patch,
RequestTypes.delete: requests.delete
}
self._session_functions_mapping = {}

# Attributes cache
self._attributes_schema = None
self._entity_type_attributes_cache = {}
Expand Down Expand Up @@ -674,7 +667,14 @@ def set_max_retries(self, max_retries: Optional[int]):
"""
if max_retries is None:
max_retries = self.get_default_max_retries()
self._max_retries = int(max_retries)
max_retries = int(max_retries)
if max_retries < 0:
max_retries = 0
if max_retries == self._max_retries:
return
self._max_retries = max_retries
for handler in self._session_handlers.values():
handler.max_retries = Retry.from_int(max_retries)

timeout = property(get_timeout, set_timeout)
max_retries = property(get_max_retries, set_max_retries)
Expand Down Expand Up @@ -996,27 +996,18 @@ def create_session(
# Validate token before session creation
self.validate_token()

session = requests.Session()
session.cert = self._cert
session.verify = self._ssl_verify
session.headers.update(self.get_headers())

self._session_functions_mapping = {
RequestTypes.get: session.get,
RequestTypes.post: session.post,
RequestTypes.put: session.put,
RequestTypes.patch: session.patch,
RequestTypes.delete: session.delete
}
session, handlers = self._create_new_session()

self._session = session
self._session_handlers = handlers

def close_session(self):
if self._session is None:
return

session = self._session
self._session = None
self._session_functions_mapping = {}
self._session_handlers = {}
session.close()

def _update_session_headers(self):
Expand Down Expand Up @@ -1340,12 +1331,35 @@ def logout(self, soft: bool = False):
def _logout(self):
logout_from_server(self._base_url, self._access_token)

def _do_rest_request(self, function, url, **kwargs):
def _do_rest_request(self, request_type, url, **kwargs):
"""

Args:
request_type (RequestType): Request type.
url (str): Request url.
max_retries (int): Does affect only connection issues or
when session is not created.
**kwargs:

Returns:
RestApiResponse: Response.

Raises:
ConnectionRefusedError: When connection is refused.
requests.exceptions.Timeout: When connection timed out.
requests.exceptions.ConnectionError: When connection error
happens.

"""
kwargs.setdefault("timeout", self.timeout)
max_retries = kwargs.get("max_retries", self.max_retries)
if max_retries < 1:
max_retries = 1
if self._session is None:

close_session = False
session = self._session
max_retries = kwargs.get("max_retries")
if max_retries is None:
max_retries = self.max_retries

if session is None:
# Validate token if was not yet validated
# - ignore validation if we're in middle of
# validation
Expand All @@ -1355,65 +1369,49 @@ def _do_rest_request(self, function, url, **kwargs):
):
self.validate_token()

if "headers" not in kwargs:
kwargs["headers"] = self.get_headers()

if isinstance(function, RequestType):
function = self._base_functions_mapping[function]

elif isinstance(function, RequestType):
function = self._session_functions_mapping[function]
headers = kwargs.get("headers")
close_session = True
session, _ = self._create_new_session(
max_retries=max_retries, headers=headers
)

response = None
new_response = None
for retry_idx in reversed(range(max_retries)):
try:
response = function(url, **kwargs)
break

except ConnectionRefusedError:
if retry_idx == 0:
self.log.warning(
"Connection error happened.", exc_info=True
)

# Server may be restarting
new_response = RestApiResponse(
None,
{
"detail": (
"Unable to connect the server. Connection refused"
)
}
)

except requests.exceptions.Timeout:
# Connection timed out
new_response = RestApiResponse(
None,
{"detail": "Connection timed out."}
)
if max_retries < 1:
max_retries = 1

except requests.exceptions.ConnectionError:
# Log warning only on last attempt
if retry_idx == 0:
self.log.warning(
"Connection error happened.", exc_info=True
try:
for retry_idx in reversed(range(max_retries)):
try:
response = session.request(
request_type.name, url, **kwargs
)

new_response = RestApiResponse(
None,
{
"detail": (
"Unable to connect the server. Connection error"
break

except (
# These are 'ConnectionError' but it doesn't make sense
# to retry
requests.exceptions.ProxyError,
requests.exceptions.SSLError,
):
raise

except (
ConnectionRefusedError,
requests.exceptions.ConnectionError
):
# Log warning only on last attempt
if retry_idx == 0:
self.log.warning(
"Connection error happened.", exc_info=True
)
}
)
raise

time.sleep(0.1)
time.sleep(0.1)

if new_response is not None:
return new_response
finally:
if close_session:
session.close()

content_type = response.headers.get("Content-Type")
if content_type == "application/json":
Expand All @@ -1434,6 +1432,28 @@ def _do_rest_request(self, function, url, **kwargs):
self.log.debug(f"Response {str(new_response)}")
return new_response

def _create_new_session(self, max_retries=None, headers=None):
if max_retries is None:
max_retries = self.max_retries
if max_retries < 0:
max_retries = 0

if headers is None:
headers = self.get_headers()

session = requests.Session()
session.cert = self._cert
session.verify = self._ssl_verify
session.headers.update(headers)
handlers = {
"http://": HTTPAdapter(max_retries=max_retries),
"https://": HTTPAdapter(max_retries=max_retries),
}
for prefix, adapter in handlers.items():
session.mount(prefix, adapter)

return session, handlers

def raw_post(self, entrypoint: str, **kwargs):
url = self._endpoint_to_url(entrypoint)
self.log.debug(f"Executing [POST] {url}")
Expand Down Expand Up @@ -2142,18 +2162,22 @@ def _download_file_to_stream(
self, url: str, stream, chunk_size, progress
):
kwargs = {"stream": True}
if self._session is None:
kwargs["headers"] = self.get_headers()
get_func = self._base_functions_mapping[RequestTypes.get]
else:
get_func = self._session_functions_mapping[RequestTypes.get]
session = self._session
close_session = False
if session is None:
close_session = True
session, _ = self._create_new_session()

with get_func(url, **kwargs) as response:
response.raise_for_status()
progress.set_content_size(response.headers["Content-length"])
for chunk in response.iter_content(chunk_size=chunk_size):
stream.write(chunk)
progress.add_transferred_chunk(len(chunk))
try:
with session.request("GET", url, **kwargs) as response:
response.raise_for_status()
progress.set_content_size(response.headers["Content-length"])
for chunk in response.iter_content(chunk_size=chunk_size):
stream.write(chunk)
progress.add_transferred_chunk(len(chunk))
finally:
if close_session:
session.close()

def download_file_to_stream(
self,
Expand Down Expand Up @@ -2317,25 +2341,29 @@ def _upload_file(

"""
if request_type is None:
request_type = RequestTypes.put
request_type = "PUT"
elif isinstance(request_type, RequestType):
request_type = request_type.name

session = self._session
close_session = False
if self._session is None:
headers = kwargs.setdefault("headers", {})
for key, value in self.get_headers().items():
if key not in headers:
headers[key] = value
post_func = self._base_functions_mapping[request_type]
else:
post_func = self._session_functions_mapping[request_type]
close_session = True
session, _ = self._create_new_session()

if not chunk_size:
chunk_size = self.default_upload_chunk_size

response = post_func(
url,
data=self._upload_chunks_iter(stream, progress, chunk_size),
**kwargs
)
try:
response = session.request(
request_type,
url,
data=self._upload_chunks_iter(stream, progress, chunk_size),
**kwargs
)
finally:
if close_session:
session.close()

response.raise_for_status()
return response
Expand Down