Skip to content

Commit

Permalink
Fix mypy issues
Browse files Browse the repository at this point in the history
  • Loading branch information
erlendvollset committed Sep 30, 2024
1 parent ec243b8 commit bb4aca6
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 21 deletions.
18 changes: 8 additions & 10 deletions cognite/client/_api_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@
split_into_chunks,
unpack_items_in_payload,
)
from cognite.client.utils._concurrency import execute_tasks
from cognite.client.utils._concurrency import TaskExecutor, execute_tasks
from cognite.client.utils._identifier import (
Identifier,
IdentifierCore,
Expand All @@ -69,8 +69,6 @@
from cognite.client.utils.useful_types import SequenceNotStr

if TYPE_CHECKING:
from concurrent.futures import ThreadPoolExecutor

from cognite.client import CogniteClient
from cognite.client.config import ClientConfig

Expand Down Expand Up @@ -327,7 +325,7 @@ def _retrieve_multiple(
headers: dict[str, Any] | None = None,
other_params: dict[str, Any] | None = None,
params: dict[str, Any] | None = None,
executor: ThreadPoolExecutor | None = None,
executor: TaskExecutor | None = None,
api_subversion: str | None = None,
settings_forcing_raw_response_loading: list[str] | None = None,
) -> T_CogniteResource | None: ...
Expand All @@ -343,7 +341,7 @@ def _retrieve_multiple(
headers: dict[str, Any] | None = None,
other_params: dict[str, Any] | None = None,
params: dict[str, Any] | None = None,
executor: ThreadPoolExecutor | None = None,
executor: TaskExecutor | None = None,
api_subversion: str | None = None,
settings_forcing_raw_response_loading: list[str] | None = None,
) -> T_CogniteResourceList: ...
Expand All @@ -358,7 +356,7 @@ def _retrieve_multiple(
headers: dict[str, Any] | None = None,
other_params: dict[str, Any] | None = None,
params: dict[str, Any] | None = None,
executor: ThreadPoolExecutor | None = None,
executor: TaskExecutor | None = None,
api_subversion: str | None = None,
settings_forcing_raw_response_loading: list[str] | None = None,
) -> T_CogniteResourceList | T_CogniteResource | None:
Expand Down Expand Up @@ -865,7 +863,7 @@ def _create_multiple(
extra_body_fields: dict[str, Any] | None = None,
limit: int | None = None,
input_resource_cls: type[CogniteResource] | None = None,
executor: ThreadPoolExecutor | None = None,
executor: TaskExecutor | None = None,
api_subversion: str | None = None,
) -> T_CogniteResourceList: ...

Expand All @@ -881,7 +879,7 @@ def _create_multiple(
extra_body_fields: dict[str, Any] | None = None,
limit: int | None = None,
input_resource_cls: type[CogniteResource] | None = None,
executor: ThreadPoolExecutor | None = None,
executor: TaskExecutor | None = None,
api_subversion: str | None = None,
) -> T_WritableCogniteResource: ...

Expand All @@ -899,7 +897,7 @@ def _create_multiple(
extra_body_fields: dict[str, Any] | None = None,
limit: int | None = None,
input_resource_cls: type[CogniteResource] | None = None,
executor: ThreadPoolExecutor | None = None,
executor: TaskExecutor | None = None,
api_subversion: str | None = None,
) -> T_CogniteResourceList | T_WritableCogniteResource:
resource_path = resource_path or self._RESOURCE_PATH
Expand Down Expand Up @@ -964,7 +962,7 @@ def _delete_multiple(
headers: dict[str, Any] | None = None,
extra_body_fields: dict[str, Any] | None = None,
returns_items: bool = False,
executor: ThreadPoolExecutor | None = None,
executor: TaskExecutor | None = None,
) -> list | None:
resource_path = resource_path or self._RESOURCE_PATH
tasks = [
Expand Down
2 changes: 1 addition & 1 deletion cognite/client/data_classes/assets.py
Original file line number Diff line number Diff line change
Expand Up @@ -890,7 +890,7 @@ def _count_subtree(xid: str, count: int = 0) -> int:
counts.sort(key=lambda args: -args[-1])
# The count for the fictitious "root of roots" is just len(assets), so we remove it:
(count_dct := dict(counts)).pop(None, None)
return count_dct
return cast(dict[str, int], count_dct)

def _on_error(self, on_error: Literal["ignore", "warn", "raise"], message: str) -> None:
if on_error == "warn":
Expand Down
16 changes: 10 additions & 6 deletions cognite/client/utils/_concurrency.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ def _raise_duplicated_error(self, unwrap_fn: Callable, **task_lists: list) -> No


class TaskExecutor(Protocol):
def submit(self, fn: Callable[..., T_Result], *args: Any, **kwargs: Any) -> TaskFuture[T_Result]: ...
def submit(self, fn: Callable[..., T_Result], /, *args: Any, **kwargs: Any) -> TaskFuture[T_Result]: ...


class TaskFuture(Protocol[T_Result]):
Expand Down Expand Up @@ -160,7 +160,7 @@ def empty(self) -> Literal[True]:

self._work_queue = AlwaysEmpty()

def submit(self, fn: Callable[..., T_Result], *args: Any, **kwargs: Any) -> SyncFuture:
def submit(self, fn: Callable[..., T_Result], /, *args: Any, **kwargs: Any) -> SyncFuture:
return SyncFuture(fn, *args, **kwargs)


Expand Down Expand Up @@ -219,16 +219,16 @@ def get_thread_pool_executor_or_raise(cls, max_workers: int) -> ThreadPoolExecut
)

@classmethod
def get_data_modeling_executor(cls) -> ThreadPoolExecutor:
def get_data_modeling_executor(cls) -> TaskExecutor:
"""
The data modeling backend has different concurrency limits compared with the rest of CDF.
Thus, we use a dedicated executor for these endpoints to match the backend requirements.
Returns:
ThreadPoolExecutor: The data modeling executor.
TaskExecutor: The data modeling executor.
"""
if cls.uses_mainthread():
return cls.get_mainthread_executor() # type: ignore [return-value]
return cls.get_mainthread_executor()

global _DATA_MODELING_THREAD_POOL_EXECUTOR_SINGLETON
try:
Expand Down Expand Up @@ -276,7 +276,7 @@ def execute_tasks(
tasks: Sequence[tuple | dict],
max_workers: int,
fail_fast: bool = False,
executor: ThreadPoolExecutor | None = None,
executor: TaskExecutor | None = None,
) -> TasksSummary:
"""
Will use a default executor if one is not passed explicitly. The default executor type uses a thread pool but can
Expand All @@ -286,6 +286,10 @@ def execute_tasks(
"""
if ConcurrencySettings.uses_mainthread() or isinstance(executor, MainThreadExecutor):
return execute_tasks_serially(func, tasks, fail_fast)
elif isinstance(executor, ThreadPoolExecutor):
pass
else:
raise TypeError("executor must be a ThreadPoolExecutor or MainThreadExecutor")

executor = executor or ConcurrencySettings.get_thread_pool_executor(max_workers)
task_order = [id(task) for task in tasks]
Expand Down
8 changes: 4 additions & 4 deletions cognite/client/utils/_time.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@

@functools.lru_cache(1)
def get_zoneinfo_utc() -> ZoneInfo:
return ZoneInfo("UTC") # type: ignore [abstract]
return ZoneInfo("UTC")


def parse_str_timezone_offset(tz: str) -> timezone:
Expand All @@ -96,7 +96,7 @@ def parse_str_timezone_offset(tz: str) -> timezone:

def parse_str_timezone(tz: str) -> timezone | ZoneInfo:
try:
return ZoneInfo(tz) # type: ignore [abstract]
return ZoneInfo(tz)
except ZoneInfoNotFoundError:
try:
return parse_str_timezone_offset(tz)
Expand Down Expand Up @@ -697,7 +697,7 @@ def _timezones_are_equal(start_tz: tzinfo, end_tz: tzinfo) -> bool:
return True
with suppress(ValueError, ZoneInfoNotFoundError):
# ValueError is raised for non-conforming keys (ZoneInfoNotFoundError is self-explanatory)
if ZoneInfo(str(start_tz)) is ZoneInfo(str(end_tz)): # type: ignore [abstract]
if ZoneInfo(str(start_tz)) is ZoneInfo(str(end_tz)):
return True
return False

Expand All @@ -717,7 +717,7 @@ def validate_timezone(start: datetime, end: datetime) -> ZoneInfo:

pd = local_import("pandas")
if isinstance(start, pd.Timestamp):
return ZoneInfo(str(start_tz)) # type: ignore [abstract]
return ZoneInfo(str(start_tz))

raise ValueError("Only tz-aware pandas.Timestamp and datetime (must be using ZoneInfo) are supported.")

Expand Down

0 comments on commit bb4aca6

Please sign in to comment.