Skip to content

Commit

Permalink
move fn to a method on client cfg. add a few more tests
Browse files Browse the repository at this point in the history
  • Loading branch information
haakonvt committed Sep 25, 2024
1 parent be922c9 commit 19200c3
Show file tree
Hide file tree
Showing 5 changed files with 39 additions and 20 deletions.
4 changes: 2 additions & 2 deletions cognite/client/_api/assets.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@
)
from cognite.client.data_classes.filters import _BASIC_FILTERS, Filter, _validate_filter
from cognite.client.exceptions import CogniteAPIError
from cognite.client.utils._auxiliary import get_cdf_cluster, split_into_chunks, split_into_n_parts
from cognite.client.utils._auxiliary import split_into_chunks, split_into_n_parts
from cognite.client.utils._concurrency import ConcurrencySettings, classify_error, execute_tasks
from cognite.client.utils._identifier import IdentifierSequence
from cognite.client.utils._importing import import_as_completed
Expand Down Expand Up @@ -1483,7 +1483,7 @@ def _raise_latest_exception(self, exceptions: list[Exception], successful: list[
unknown=AssetList(self.unknown),
failed=AssetList(self.failed),
unwrap_fn=op.attrgetter("external_id"),
cluster=get_cdf_cluster(self.assets_api._config),
cluster=self.assets_api._config.cdf_cluster,
)
err_message = "One or more errors happened during asset creation. Latest error:"
if isinstance(latest_exception, CogniteAPIError):
Expand Down
5 changes: 2 additions & 3 deletions cognite/client/_api_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@
from cognite.client.exceptions import CogniteAPIError, CogniteNotFoundError
from cognite.client.utils import _json
from cognite.client.utils._auxiliary import (
get_cdf_cluster,
get_current_sdk_version,
get_user_agent,
interpolate_and_url_encode,
Expand Down Expand Up @@ -1153,7 +1152,7 @@ def _upsert_multiple(
successful=successful,
failed=failed,
unknown=unknown,
cluster=get_cdf_cluster(self._config),
cluster=self._config.cdf_cluster,
)
# Need to retrieve the successful updated items from the first call.
successful_resources: T_CogniteResourceList | None = None
Expand Down Expand Up @@ -1325,7 +1324,7 @@ def _raise_api_error(self, res: Response, payload: dict) -> NoReturn:
missing=missing,
duplicated=duplicated,
extra=extra,
cluster=get_cdf_cluster(self._config),
cluster=self._config.cdf_cluster,
)

def _log_request(self, res: Response, **kwargs: Any) -> None:
Expand Down
14 changes: 12 additions & 2 deletions cognite/client/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import getpass
import pprint
import re
import warnings
from contextlib import suppress
from typing import Any
Expand Down Expand Up @@ -184,6 +185,8 @@ def _validate_config(self) -> None:
raise ValueError(f"Invalid value for ClientConfig.project: <{self.project}>")
if not self.base_url:
raise ValueError(f"Invalid value for ClientConfig.base_url: <{self.base_url}>")
elif self.cdf_cluster is None:
warnings.warn(f"Given base URL may be invalid, please double-check: {self.base_url!r}", UserWarning)

def __str__(self) -> str:
return pprint.pformat({"max_workers": self.max_workers, **self.__dict__}, indent=4)
Expand Down Expand Up @@ -211,7 +214,7 @@ def default(
client_name=client_name or getpass.getuser(),
project=project,
credentials=credentials,
base_url=f"https://{cdf_cluster}.cognitedata.com/",
base_url=f"https://{cdf_cluster}.cognitedata.com",
)

@classmethod
Expand All @@ -233,7 +236,7 @@ def load(cls, config: dict[str, Any] | str) -> ClientConfig:
>>> config = {
... "client_name": "abcd",
... "project": "cdf-project",
... "base_url": "https://api.cognitedata.com/",
... "base_url": "https://api.cognitedata.com",
... "credentials": {
... "client_credentials": {
... "client_id": "abcd",
Expand Down Expand Up @@ -264,3 +267,10 @@ def load(cls, config: dict[str, Any] | str) -> ClientConfig:
file_transfer_timeout=loaded.get("file_transfer_timeout"),
debug=loaded.get("debug", False),
)

@property
def cdf_cluster(self) -> str | None:
# A best effort attempt to extract the cluster from the base url
if match := re.match(r"https?://([^/\.\s]+)\.cognitedata\.com(?::\d+)?(?:/|$)", self.base_url):
return match.group(1)
return None
9 changes: 0 additions & 9 deletions cognite/client/utils/_auxiliary.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import math
import platform
import warnings
from contextlib import suppress
from threading import Thread
from typing import (
TYPE_CHECKING,
Expand Down Expand Up @@ -32,7 +31,6 @@

if TYPE_CHECKING:
from cognite.client import CogniteClient
from cognite.client.config import ClientConfig
from cognite.client.data_classes._base import T_CogniteObject, T_CogniteResource

T = TypeVar("T")
Expand Down Expand Up @@ -145,13 +143,6 @@ def get_current_sdk_version() -> str:
return __version__


def get_cdf_cluster(config: ClientConfig) -> str | None:
# A best effort attempt to extract the cluster from the base url
with suppress(Exception):
return config.base_url.split("//")[1].split(".cognitedata")[0]
return None


@functools.lru_cache(maxsize=1)
def get_user_agent() -> str:
sdk_version = f"CognitePythonSDK/{get_current_sdk_version()}"
Expand Down
27 changes: 23 additions & 4 deletions tests/tests_unit/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,15 +68,20 @@ def test_load_non_existent_attr(self):
assert global_config.max_workers != 0


class TestClientConfig:
def test_default(self):
config = {
@pytest.fixture
def client_config():
return ClientConfig.default(
**{
"project": "test-project",
"cdf_cluster": "test-cluster",
"credentials": Token("abc"),
"client_name": "test-client",
}
client_config = ClientConfig.default(**config)
)


class TestClientConfig:
def test_default(self, client_config):
assert client_config.project == "test-project"
assert client_config.base_url == "https://test-cluster.cognitedata.com"
assert isinstance(client_config.credentials, Token)
Expand All @@ -100,3 +105,17 @@ def test_load(self, credentials):
assert isinstance(client_config.credentials, Token)
assert "Authorization", "Bearer abc" == client_config.credentials.authorization_header()
assert client_config.client_name == "test-client"

@pytest.mark.parametrize("protocol", ("http", "https"))
@pytest.mark.parametrize("end", ("", "/", ":8080", "/api/v1/", ":8080/api/v1/"))
def test_extract_cdf_cluster(self, client_config, protocol, end):
for valid in ("3D", "my_clus-ter", "jazz-testing-asia-northeast1-1", "trial-00ed82e12d9cbadfe28e4"):
client_config.base_url = f"{protocol}://{valid}.cognitedata.com{end}"
assert client_config.cdf_cluster == valid

for invalid in ("", ".", "..", "huh.my_cluster."):
client_config.base_url = f"{protocol}://{valid}cognitedata.com{end}"
assert client_config.cdf_cluster is None

client_config.base_url = "invalid"
assert client_config.cdf_cluster is None

0 comments on commit 19200c3

Please sign in to comment.