Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add callback to global_config to allow tracking of one's own SDK usage #1469

Draft
wants to merge 3 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 32 additions & 23 deletions cognite/client/_http_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,15 @@
import socket
import time
from http import cookiejar
from typing import Any, Callable, Literal, MutableMapping
from typing import Any, Callable, ClassVar, Literal, MutableMapping

import requests
import requests.adapters
import urllib3

from cognite.client.config import global_config
from cognite.client.exceptions import CogniteConnectionError, CogniteConnectionRefused, CogniteReadTimeout
from cognite.client.utils._api_usage import RequestDetails


class BlockAll(cookiejar.CookiePolicy):
Expand Down Expand Up @@ -95,6 +96,18 @@ def should_retry(self, status_code: int | None) -> bool:


class HTTPClient:
_TIMEOUT_EXCEPTIONS: ClassVar[tuple[type[Exception], ...]] = (
socket.timeout,
urllib3.exceptions.ReadTimeoutError,
requests.exceptions.ReadTimeout,
)
_CONNECTION_EXCEPTIONS: ClassVar[tuple[type[Exception], ...]] = (
ConnectionError,
urllib3.exceptions.ConnectionError,
urllib3.exceptions.ConnectTimeoutError,
requests.exceptions.ConnectionError,
)

def __init__(
self,
config: HTTPClientConfig,
Expand All @@ -113,60 +126,56 @@ def request(self, method: str, url: str, **kwargs: Any) -> requests.Response:
last_status = None
while True:
try:
res = self._do_request(method=method, url=url, **kwargs)
res = self._make_request(method=method, url=url, **kwargs)
if global_config.usage_tracking_callback:
# We use the RequestDetails as an indirection to avoid the user mutating the request:
global_config.usage_tracking_callback(RequestDetails.from_response(res))

last_status = res.status_code
retry_tracker.status += 1
if not retry_tracker.should_retry(status_code=last_status):
# Cache .json() return value in order to avoid redecoding JSON if called multiple times
res.json = functools.lru_cache(maxsize=1)(res.json) # type: ignore[assignment]
return res
except CogniteReadTimeout as e:

except CogniteReadTimeout:
retry_tracker.read += 1
if not retry_tracker.should_retry(status_code=last_status):
raise e
except CogniteConnectionError as e:
raise

except CogniteConnectionError:
retry_tracker.connect += 1
if not retry_tracker.should_retry(status_code=last_status):
raise e
raise

# During a backoff loop, our credentials might expire, so we check and maybe refresh:
time.sleep(retry_tracker.get_backoff_time())
if headers is not None:
# TODO: Refactoring needed to make this "prettier"
self.refresh_auth_header(headers)

def _do_request(self, method: str, url: str, **kwargs: Any) -> requests.Response:
def _make_request(self, method: str, url: str, **kwargs: Any) -> requests.Response:
"""requests/urllib3 adds 2 or 3 layers of exceptions on top of built-in networking exceptions.

Sometimes the appropriate built-in networking exception is not in the context, sometimes the requests
exception is not in the context, so we need to check for the appropriate built-in exceptions,
urllib3 exceptions, and requests exceptions.
"""
try:
res = self.session.request(method=method, url=url, **kwargs)
return res
return self.session.request(method=method, url=url, **kwargs)
except Exception as e:
if self._any_exception_in_context_isinstance(
e, (socket.timeout, urllib3.exceptions.ReadTimeoutError, requests.exceptions.ReadTimeout)
):
if self._any_exception_in_context_isinstance(e, self._TIMEOUT_EXCEPTIONS):
raise CogniteReadTimeout from e
if self._any_exception_in_context_isinstance(
e,
(
ConnectionError,
urllib3.exceptions.ConnectionError,
urllib3.exceptions.ConnectTimeoutError,
requests.exceptions.ConnectionError,
),
):

if self._any_exception_in_context_isinstance(e, self._CONNECTION_EXCEPTIONS):
if self._any_exception_in_context_isinstance(e, ConnectionRefusedError):
raise CogniteConnectionRefused from e
raise CogniteConnectionError from e
raise e
raise

@classmethod
def _any_exception_in_context_isinstance(
cls, exc: BaseException, exc_types: tuple[type[BaseException], ...] | type[BaseException]
cls, exc: BaseException, exc_types: tuple[type[Exception], ...] | type[Exception]
) -> bool:
"""requests does not use the "raise ... from ..." syntax, so we need to access the underlying exceptions using
the __context__ attribute.
Expand Down
5 changes: 5 additions & 0 deletions cognite/client/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,14 @@
import getpass
import pprint
from contextlib import suppress
from typing import TYPE_CHECKING, Callable

from cognite.client._version import __api_subversion__
from cognite.client.credentials import CredentialProvider

if TYPE_CHECKING:
from cognite.client.utils._api_usage import RequestDetails


class GlobalConfig:
"""Global configuration object
Expand Down Expand Up @@ -39,6 +43,7 @@ def __init__(self) -> None:
self.max_connection_pool_size: int = 50
self.disable_ssl: bool = False
self.proxies: dict[str, str] | None = {}
self.usage_tracking_callback: Callable[[RequestDetails], None] | None = None


global_config = GlobalConfig()
Expand Down
70 changes: 70 additions & 0 deletions cognite/client/utils/_api_usage.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
from __future__ import annotations

from dataclasses import dataclass
from datetime import timedelta
from typing import TYPE_CHECKING

from typing_extensions import Self

if TYPE_CHECKING:
from requests import Response


@dataclass
class RequestDetails:
"""
SDK users wanting to track their own API usage (with the SDK) - for metrics or surveilance, may set
a callback on the global_config object that will then receive instances of this class, one per
actual request.

Note that due to concurrency, the sum of time_elapsed is (much) greater than the actual wall clock
waiting time.

Args:
url (str): The API endpoint that was called.
status_code (int): The status code of the API response.
content_length (int | None): The size of the response if available.
time_elapsed (timedelta): The amount of time elapsed between sending the request and the arrival of the response.

Example:

Store info on the last 1000 requests made:

>>> from cognite.client.config import global_config
>>> from collections import deque
>>> usage_info = deque(maxlen=1000)
>>> global_config.usage_tracking_callback = usage_info.append

Store the time elapsed per request, grouped per API endpoint, for all requests:

>>> from collections import defaultdict
>>> usage_info = defaultdict(list)
>>> def callback(details):
... usage_info[details.url].append(details.time_elapsed)
>>> global_config.usage_tracking_callback = callback

Tip:
Ensure the provided callback is fast to execute, or it might negatively impact the overall performance.

Warning:
Your provided callback function will be called from several different threads and thus any operation
executed must be thread-safe (or while holding a thread lock, not recommended). Best practise is to dump
the required details to a container like in the examples above, then inspect those separately in your code.
"""

url: str
status_code: int
content_length: int | None
time_elapsed: timedelta

@classmethod
def from_response(cls, resp: Response) -> Self:
# If header not set, we don't report the size. We could do len(resp.content), but
# for streaming requests this would fetch everything into memory...
content_length = int(resp.headers.get("Content-length", 0)) or None
return cls(
url=resp.url,
status_code=resp.status_code,
content_length=content_length,
time_elapsed=resp.elapsed,
)
Loading