Skip to content

Commit 13cb318

Browse files
authored
Add typing to aiida.tools.graph module (#7036)
1 parent cfbbd68 commit 13cb318

File tree

18 files changed

+302
-253
lines changed

18 files changed

+302
-253
lines changed

.pre-commit-config.yaml

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -141,10 +141,6 @@ repos:
141141
src/aiida/tools/data/orbital/orbital.py|
142142
src/aiida/tools/data/orbital/realhydrogen.py|
143143
src/aiida/tools/dbimporters/plugins/.*|
144-
src/aiida/tools/graph/age_entities.py|
145-
src/aiida/tools/graph/age_rules.py|
146-
src/aiida/tools/graph/deletions.py|
147-
src/aiida/tools/graph/graph_traversers.py|
148144
src/aiida/transports/cli.py|
149145
src/aiida/transports/plugins/local.py|
150146
src/aiida/transports/plugins/ssh.py|

docs/source/nitpick-exceptions

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,8 @@ py:class FC
123123
py:class P
124124
py:class N
125125
py:class T
126+
py:class _ContainerTypes
127+
py:class _NodeOrGroupCls
126128
py:class aiida.cmdline.params.types.choice.T
127129
py:class aiida.common.lang.T
128130
py:class aiida.engine.processes.functions.N
@@ -162,6 +164,7 @@ py:class requests.Response
162164
py:class concurrent.futures._base.TimeoutError
163165
py:class concurrent.futures._base.Future
164166

167+
py:class BackupManager
165168
py:class disk_objectstore.utils.LazyOpener
166169
py:class disk_objectstore.backup_utils.BackupManager
167170

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -353,6 +353,7 @@ disallow_subclassing_any = true
353353
disallow_untyped_calls = true
354354
disallow_untyped_defs = true
355355
module = [
356+
'aiida.tools.graph.*',
356357
'aiida.cmdline.params.*',
357358
'aiida.cmdline.groups.*',
358359
'aiida.tools.query.*'

src/aiida/cmdline/params/types/path.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,17 +16,11 @@
1616

1717
import click
1818

19-
if t.TYPE_CHECKING:
20-
try:
21-
from typing import TypeAlias
22-
except ImportError:
23-
from typing_extensions import TypeAlias
24-
2519
__all__ = ('AbsolutePathParamType', 'FileOrUrl', 'PathOrUrl')
2620

2721
URL_TIMEOUT_SECONDS = 10
2822

29-
PathType: TypeAlias = 'str | bytes | os.PathLike[str]'
23+
PathType = t.Union[str, bytes, os.PathLike[str]]
3024

3125

3226
def check_timeout_seconds(timeout_seconds: float) -> int:

src/aiida/common/lang.py

Lines changed: 16 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from typing import Any, Callable, Generic, TypeVar
1414

1515

16-
def isidentifier(identifier):
16+
def isidentifier(identifier: str) -> bool:
1717
"""Return whether the given string is a valid python identifier.
1818
1919
:return: boolean, True if identifier is valid, False otherwise
@@ -26,7 +26,7 @@ def isidentifier(identifier):
2626
T = TypeVar('T')
2727

2828

29-
def type_check(what: T, of_type: Any, msg: 'str | None' = None, allow_none: bool = False) -> 'T | None':
29+
def type_check(what: T, of_type: Any, msg: 'str | None' = None, allow_none: bool = False) -> T:
3030
"""Verify that object 'what' is of type 'of_type' and if not the case, raise a TypeError.
3131
3232
:param what: the object to check
@@ -37,7 +37,7 @@ def type_check(what: T, of_type: Any, msg: 'str | None' = None, allow_none: bool
3737
:return: `what` or `None`
3838
"""
3939
if allow_none and what is None:
40-
return None
40+
return what
4141

4242
if not isinstance(what, of_type):
4343
if msg is None:
@@ -50,7 +50,7 @@ def type_check(what: T, of_type: Any, msg: 'str | None' = None, allow_none: bool
5050
MethodType = TypeVar('MethodType', bound=Callable[..., Any])
5151

5252

53-
def override_decorator(check=False) -> Callable[[MethodType], MethodType]:
53+
def override_decorator(check: bool = False) -> Callable[[MethodType], MethodType]:
5454
"""Decorator to signal that a method from a base class is being overridden completely."""
5555

5656
def wrap(func: MethodType) -> MethodType:
@@ -63,18 +63,17 @@ def wrap(func: MethodType) -> MethodType:
6363
if not args:
6464
raise RuntimeError('Can only use the override decorator on member functions')
6565

66-
if check:
66+
if not check:
67+
return func
6768

68-
@functools.wraps(func)
69-
def wrapped_fn(self, *args, **kwargs):
70-
try:
71-
getattr(super(), func.__name__)
72-
except AttributeError:
73-
raise RuntimeError(f'Function {func} does not override a superclass method')
69+
@functools.wraps(func)
70+
def wrapped_fn(self, *args, **kwargs):
71+
try:
72+
getattr(super(), func.__name__)
73+
except AttributeError:
74+
raise RuntimeError(f'Function {func} does not override a superclass method')
7475

75-
return func(self, *args, **kwargs)
76-
else:
77-
wrapped_fn = func # type: ignore[assignment]
76+
return func(self, *args, **kwargs)
7877

7978
return wrapped_fn # type: ignore[return-value]
8079

@@ -84,7 +83,6 @@ def wrapped_fn(self, *args, **kwargs):
8483
override = override_decorator(check=False)
8584

8685
ReturnType = TypeVar('ReturnType')
87-
SelfType = TypeVar('SelfType')
8886

8987

9088
class classproperty(Generic[ReturnType]): # noqa: N801
@@ -95,8 +93,8 @@ class classproperty(Generic[ReturnType]): # noqa: N801
9593
instance as its first argument).
9694
"""
9795

98-
def __init__(self, getter: Callable[[SelfType], ReturnType]) -> None:
96+
def __init__(self, getter: Callable[[Any], ReturnType]) -> None:
9997
self.getter = getter
10098

101-
def __get__(self, instance: Any, owner: SelfType) -> ReturnType:
102-
return self.getter(owner) # type: ignore[arg-type]
99+
def __get__(self, instance: Any, owner: type) -> ReturnType:
100+
return self.getter(owner)

src/aiida/common/log.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
from __future__ import annotations
1212

13-
import collections
13+
import collections.abc
1414
import contextlib
1515
import enum
1616
import io

src/aiida/common/typing.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,19 +6,19 @@
66
# For further information on the license, see the LICENSE.txt file #
77
# For further information please visit http://www.aiida.net #
88
###########################################################################
9-
"""Module to define commonly used data structures."""
9+
"""Module to define commonly used types."""
1010

1111
from __future__ import annotations
1212

1313
import pathlib
14+
import sys
1415
from typing import Union
1516

16-
try:
17+
if sys.version_info >= (3, 11):
1718
from typing import Self
18-
except ImportError:
19+
else:
1920
from typing_extensions import Self
2021

2122
__all__ = ('FilePath', 'Self')
2223

23-
2424
FilePath = Union[str, pathlib.PurePath]

src/aiida/common/utils.py

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -18,20 +18,13 @@
1818
import sys
1919
from collections.abc import Iterable, Iterator
2020
from datetime import datetime, timedelta
21-
from typing import TYPE_CHECKING, Any, Callable, TypeVar, overload
21+
from typing import Any, Callable, TypeVar, overload
2222
from uuid import UUID
2323

2424
from aiida.common.typing import Self
2525

2626
from .lang import classproperty
2727

28-
if TYPE_CHECKING:
29-
# TypeAlias added in Python 3.10
30-
try:
31-
from typing import TypeAlias
32-
except ImportError:
33-
from typing_extensions import TypeAlias
34-
3528
T = TypeVar('T')
3629
R = TypeVar('R')
3730

@@ -421,7 +414,7 @@ def prettify(self, label: str) -> str:
421414
return self._prettifier_f(label)
422415

423416

424-
_Labels: TypeAlias = list[tuple[float, str]]
417+
_Labels = list[tuple[float, str]]
425418

426419

427420
def prettify_labels(labels: _Labels, format: str | None = None) -> _Labels:

src/aiida/orm/groups.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,12 @@
1212
import warnings
1313
from functools import cached_property
1414
from pathlib import Path
15-
from typing import TYPE_CHECKING, Any, ClassVar, Dict, Optional, Sequence, Tuple, Type, TypeVar, Union, cast
15+
from typing import TYPE_CHECKING, Any, ClassVar, Dict, Optional, Sequence, Tuple, Type, Union, cast
1616

1717
from aiida.common import exceptions
1818
from aiida.common.lang import classproperty, type_check
1919
from aiida.common.pydantic import MetadataField
20+
from aiida.common.typing import Self
2021
from aiida.common.warnings import warn_deprecation
2122
from aiida.manage import get_manager
2223

@@ -31,8 +32,6 @@
3132

3233
__all__ = ('AutoGroup', 'Group', 'ImportGroup', 'UpfFamily')
3334

34-
SelfType = TypeVar('SelfType', bound='Group')
35-
3635

3736
def load_group_class(type_string: str) -> Type['Group']:
3837
"""Load the sub class of `Group` that corresponds to the given `type_string`.
@@ -203,7 +202,7 @@ def __repr__(self) -> str:
203202
def __str__(self) -> str:
204203
return f'{self.__class__.__name__}<{self.label}>'
205204

206-
def store(self: SelfType) -> SelfType:
205+
def store(self) -> Self:
207206
"""Verify that the group is allowed to be stored, which is the case along as `type_string` is set."""
208207
if self._type_string is None:
209208
raise exceptions.StoringNotAllowed('`type_string` is `None` so the group cannot be stored.')

src/aiida/orm/implementation/storage_backend.py

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,12 @@
1111
from __future__ import annotations
1212

1313
import abc
14-
from typing import TYPE_CHECKING, Any, ContextManager, List, Optional, Sequence, TypeVar, Union
14+
from collections.abc import Iterable
15+
from typing import TYPE_CHECKING, Any, ContextManager, List, Optional, TypeVar, Union
1516

1617
if TYPE_CHECKING:
18+
from disk_objectstore.backup_utils import BackupManager
19+
1720
from aiida.manage.configuration.profile import Profile
1821
from aiida.orm.autogroup import AutogroupManager
1922
from aiida.orm.entities import EntityTypes
@@ -129,7 +132,7 @@ def version(self) -> str:
129132
return version
130133

131134
@abc.abstractmethod
132-
def close(self):
135+
def close(self) -> None:
133136
"""Close the storage access."""
134137

135138
@property
@@ -253,12 +256,12 @@ def delete(self) -> None:
253256
raise NotImplementedError()
254257

255258
@abc.abstractmethod
256-
def delete_nodes_and_connections(self, pks_to_delete: Sequence[int]):
259+
def delete_nodes_and_connections(self, pks_to_delete: Iterable[int]) -> None:
257260
"""Delete all nodes corresponding to pks in the input and any links to/from them.
258261
259262
This method is intended to be used within a transaction context.
260263
261-
:param pks_to_delete: a sequence of node pks to delete
264+
:param pks_to_delete: an iterable of node pks to delete
262265
263266
:raises: ``AssertionError`` if a transaction is not active
264267
"""
@@ -269,7 +272,7 @@ def get_repository(self) -> 'AbstractRepositoryBackend':
269272

270273
@abc.abstractmethod
271274
def set_global_variable(
272-
self, key: str, value: Union[None, str, int, float], description: Optional[str] = None, overwrite=True
275+
self, key: str, value: Union[None, str, int, float], description: Optional[str] = None, overwrite: bool = True
273276
) -> None:
274277
"""Set a global variable in the storage.
275278
@@ -291,7 +294,7 @@ def get_global_variable(self, key: str) -> Union[None, str, int, float]:
291294
"""
292295

293296
@abc.abstractmethod
294-
def maintain(self, full: bool = False, dry_run: bool = False, **kwargs) -> None:
297+
def maintain(self, full: bool = False, dry_run: bool = False, **kwargs: Any) -> None:
295298
"""Perform maintenance tasks on the storage.
296299
297300
If `full == True`, then this method may attempt to block the profile associated with the
@@ -309,10 +312,10 @@ def _backup(
309312
self,
310313
dest: str,
311314
keep: Optional[int] = None,
312-
):
315+
) -> None:
313316
raise NotImplementedError
314317

315-
def _write_backup_config(self, backup_manager):
318+
def _write_backup_config(self, backup_manager: BackupManager) -> None:
316319
import pathlib
317320
import tempfile
318321

@@ -338,7 +341,7 @@ def _write_backup_config(self, backup_manager):
338341
except (exceptions.MissingConfigurationError, exceptions.ConfigurationError) as exc:
339342
raise exceptions.StorageBackupError('AiiDA config.json not found!') from exc
340343

341-
def _validate_or_init_backup_folder(self, dest, keep):
344+
def _validate_or_init_backup_folder(self, dest: str, keep: int | None) -> BackupManager:
342345
import json
343346
import tempfile
344347

@@ -396,7 +399,7 @@ def backup(
396399
self,
397400
dest: str,
398401
keep: Optional[int] = None,
399-
):
402+
) -> None:
400403
"""Create a backup of the storage contents.
401404
402405
:param dest: The path to the destination folder.
@@ -440,15 +443,15 @@ def backup(
440443
STORAGE_LOGGER.report(f'Overwriting the `{DEFAULT_CONFIG_FILE_NAME} file.')
441444
self._write_backup_config(backup_manager)
442445

443-
def get_info(self, detailed: bool = False) -> dict:
446+
def get_info(self, detailed: bool = False) -> dict[str, Any]:
444447
"""Return general information on the storage.
445448
446449
:param detailed: flag to request more detailed information about the content of the storage.
447450
:returns: a nested dict with the relevant information.
448451
"""
449452
return {'entities': self.get_orm_entities(detailed=detailed)}
450453

451-
def get_orm_entities(self, detailed: bool = False) -> dict:
454+
def get_orm_entities(self, detailed: bool = False) -> dict[str, Any]:
452455
"""Return a mapping with an overview of the storage contents regarding ORM entities.
453456
454457
:param detailed: flag to request more detailed information about the content of the storage.

0 commit comments

Comments
 (0)