
Commit ee87cec

Improve typing of QueryBuilder.all() (#6966)
* Provide overloaded method definitions for the `flat=True` and `flat=False` arguments.
* Bump mypy to 1.17.
* Cast away the casts.
1 parent ec7bde8 commit ee87cec
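
The core of the change is the `typing.overload` pattern: pinning `flat` to `Literal[False]` or `Literal[True]` lets mypy pick the matching return type at each call site. A minimal, self-contained sketch of the pattern (the class and body here are stand-ins; only the signatures mirror the real `QueryBuilder.all()` change shown in the diff below):

```python
# Minimal sketch of the overload pattern this commit applies.
from __future__ import annotations

from typing import Any, Literal, overload


class QueryBuilderSketch:
    @overload
    def all(self, batch_size: int | None = None, flat: Literal[False] = False) -> list[list[Any]]: ...

    @overload
    def all(self, batch_size: int | None = None, flat: Literal[True] = True) -> list[Any]: ...

    def all(self, batch_size: int | None = None, flat: bool = False) -> list[list[Any]] | list[Any]:
        # Stand-in for the real backend query; batch_size is ignored here.
        rows: list[list[Any]] = [[1], [2]]
        return [value for row in rows for value in row] if flat else rows


qb = QueryBuilderSketch()
unflattened = qb.all()         # mypy infers list[list[Any]]
flattened = qb.all(flat=True)  # mypy infers list[Any]
```

Because a call without `flat` matches the first overload, default call sites keep the nested-list type, while `flat=True` call sites narrow to a flat list. That narrowing is what lets the downstream `cast(...)` calls and `# type: ignore` comments in the other files go.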

File tree

8 files changed: +83, -75 lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
```diff
@@ -231,7 +231,7 @@ notebook = [
 ]
 pre-commit = [
     'aiida-core[atomic_tools,rest,tests,tui]',
-    'mypy~=1.16.0',
+    'mypy~=1.17.0',
     'packaging~=23.0',
     'pre-commit~=3.5',
     'sqlalchemy[mypy]~=2.0',
```

src/aiida/orm/entities.py

Lines changed: 3 additions & 3 deletions
```diff
@@ -14,7 +14,7 @@
 import pathlib
 from enum import Enum
 from functools import lru_cache
-from typing import TYPE_CHECKING, Any, Generic, List, Optional, Type, TypeVar, Union, cast
+from typing import TYPE_CHECKING, Any, Generic, List, Optional, Type, TypeVar, Union
 
 from plumpy.base.utils import call_with_super_check, super_check
 from pydantic import BaseModel
@@ -156,14 +156,14 @@ def find(
         :return: a list of resulting matches
         """
         query = self.query(filters=filters, order_by=order_by, limit=limit)
-        return cast(List[EntityType], query.all(flat=True))
+        return query.all(flat=True)
 
     def all(self) -> List[EntityType]:
         """Get all entities in this collection.
 
         :return: A list of all entities
        """
-        return cast(List[EntityType], self.query().all(flat=True))
+        return self.query().all(flat=True)
 
     def count(self, filters: Optional['FilterType'] = None) -> int:
         """Count entities in this collection according to criteria.
```

src/aiida/orm/nodes/data/code/legacy.py

Lines changed: 3 additions & 3 deletions
```diff
@@ -220,7 +220,7 @@ def get_code_helper(cls, label, machinename=None, backend=None):
         elif query.count() > 1:
             codes = query.all(flat=True)
             retstr = f"There are multiple codes with label '{label}', having IDs: "
-            retstr += f"{', '.join(sorted([str(c.pk) for c in codes]))}.\n"  # type: ignore[union-attr]
+            retstr += f"{', '.join(sorted([str(c.pk) for c in codes]))}.\n"
             retstr += 'Relabel them (using their ID), or refer to them with their ID.'
             raise MultipleObjectsError(retstr)
         else:
@@ -320,9 +320,9 @@ def list_for_plugin(cls, plugin, labels=True, backend=None):
         valid_codes = query.all(flat=True)
 
         if labels:
-            return [c.label for c in valid_codes]  # type: ignore[union-attr]
+            return [c.label for c in valid_codes]
 
-        return [c.pk for c in valid_codes]  # type: ignore[union-attr]
+        return [c.pk for c in valid_codes]
 
     def _validate(self):
         super()._validate()
```
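
The dropped `union-attr` ignores fall out of the same narrowing. Under the old return type `Union[List[List[Any]], List[Any]]`, iterating the result yields elements typed `list[Any] | Any`, and mypy flags attribute access like `c.pk` on the `list[Any]` branch. A hypothetical reproduction of the old error (not aiida code; the mypy message is paraphrased from memory):

```python
# Hypothetical reproduction of the old union-attr error.
from typing import Any, List, Union


class Code:
    pk = 1


def all_old_signature() -> Union[List[List[Any]], List[Any]]:
    """Stand-in for the pre-commit QueryBuilder.all() return type."""
    return [Code()]


for c in all_old_signature():
    # mypy (old signature): Item "list[Any]" of "list[Any] | Any" has no
    # attribute "pk" [union-attr]. With flat=True pinned to list[Any],
    # each element is Any and the access type-checks.
    print(c.pk)
```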

src/aiida/orm/querybuilder.py

Lines changed: 10 additions & 4 deletions
```diff
@@ -1033,12 +1033,12 @@ def _get_aiida_entity_res(value) -> Any:
         return value
 
     @overload
-    def first(self, flat: Literal[False] = False) -> Optional[list[Any]]: ...
+    def first(self, flat: Literal[False] = False) -> list[Any] | None: ...
 
     @overload
-    def first(self, flat: Literal[True]) -> Optional[Any]: ...
+    def first(self, flat: Literal[True]) -> Any | None: ...
 
-    def first(self, flat: bool = False) -> Optional[list[Any] | Any]:
+    def first(self, flat: bool = False) -> list[Any] | Any | None:
         """Return the first result of the query.
 
         Calling ``first`` results in an execution of the underlying query.
@@ -1105,7 +1105,13 @@ def iterdict(self, batch_size: Optional[int] = 100) -> Iterable[Dict[str, Dict[str, Any]]]:
 
             yield item
 
-    def all(self, batch_size: Optional[int] = None, flat: bool = False) -> Union[List[List[Any]], List[Any]]:
+    @overload
+    def all(self, batch_size: int | None = None, flat: Literal[False] = False) -> list[list[Any]]: ...
+
+    @overload
+    def all(self, batch_size: int | None = None, flat: Literal[True] = True) -> list[Any]: ...
+
+    def all(self, batch_size: int | None = None, flat: bool = False) -> list[list[Any]] | list[Any]:
         """Executes the full query with the order of the rows as returned by the backend.
 
         The order inside each row is given by the order of the vertices in the path and the order of the projections for
```
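
At call sites the narrowing looks like this (a sketch, not from the diff; running it assumes a configured AiiDA profile):

```python
# Sketch of call-site inference under the new overloads; assumes a loaded
# AiiDA profile so the query can actually execute.
from aiida import load_profile, orm

load_profile()

qb = orm.QueryBuilder().append(orm.Node, project='id')
rows = qb.all()          # inferred as list[list[Any]]: one row per match
ids = qb.all(flat=True)  # inferred as list[Any]: projections flattened
```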

src/aiida/tools/_dumping/detect.py

Lines changed: 3 additions & 3 deletions
```diff
@@ -11,7 +11,7 @@
 from __future__ import annotations
 
 from datetime import datetime, timedelta
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Type, Union, cast
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union, cast
 
 from aiida import orm
 from aiida.common import AIIDA_LOGGER
@@ -171,7 +171,7 @@ def _query_single_type(
         if self.NODE_TAG not in qb._projections:
             qb.add_projection(self.NODE_TAG, '*')
 
-        return cast(list[orm.ProcessNode], qb.all(flat=True))
+        return qb.all(flat=True)
 
     def _exclude_tracked_nodes(self, nodes: list[orm.ProcessNode], store_type: str) -> list[orm.ProcessNode]:
         """Exclude nodes that are already tracked in the dump tracker.
@@ -289,7 +289,7 @@ def _detect_deleted_nodes(self) -> set[str]:
             qb = orm.QueryBuilder()
             orm_type = REGISTRY_TO_ORM_TYPE[registry_name]
             qb.append(orm_type, project=['uuid'])
-            all_db_uuids = cast(Set[str], set(qb.all(flat=True)))
+            all_db_uuids = set(qb.all(flat=True))
 
             # Find missing UUIDs
             missing_uuids = dumped_uuids - all_db_uuids
```

src/aiida/tools/_dumping/executors/profile.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -11,7 +11,7 @@
 from __future__ import annotations
 
 from pathlib import Path
-from typing import TYPE_CHECKING, List, cast
+from typing import TYPE_CHECKING, cast
 
 from aiida import orm
 from aiida.common import NotExistent
@@ -91,7 +91,7 @@ def _determine_groups_to_process(self) -> list[orm.Group]:
         """Determine which groups to process based on config."""
         if self.config.all_entries:
             qb_groups = orm.QueryBuilder().append(orm.Group)
-            return cast(List[orm.Group], qb_groups.all(flat=True))
+            return qb_groups.all(flat=True)
 
         if not self.config.groups:
             return []
```

src/aiida/tools/archive/create.py

Lines changed: 20 additions & 24 deletions
```diff
@@ -16,7 +16,7 @@
 import tempfile
 from datetime import datetime
 from pathlib import Path
-from typing import Callable, Dict, Iterable, List, Optional, Sequence, Set, Tuple, Union
+from typing import Callable, Iterable, Optional, Sequence, Union
 
 from tabulate import tabulate
 
@@ -192,7 +192,7 @@ def querybuilder():
     EXPORT_LOGGER.report(initial_summary)
 
     # Store starting UUIDs, to write to metadata
-    starting_uuids: Dict[EntityTypes, Set[str]] = {
+    starting_uuids: dict[EntityTypes, set[str]] = {
         EntityTypes.USER: set(),
         EntityTypes.COMPUTER: set(),
         EntityTypes.GROUP: set(),
@@ -201,7 +201,7 @@ def querybuilder():
 
     # Store all entity IDs to be written to the archive
     # Note, this is the order they will be written to the archive
-    entity_ids: Dict[EntityTypes, Set[int]] = {
+    entity_ids: dict[EntityTypes, set[int]] = {
         ent: set()
         for ent in [
             EntityTypes.USER,
@@ -376,12 +376,12 @@ def transform(d):
 
 def _collect_all_entities(
     querybuilder: QbType,
-    entity_ids: Dict[EntityTypes, Set[int]],
+    entity_ids: dict[EntityTypes, set[int]],
     include_authinfos: bool,
     include_comments: bool,
     include_logs: bool,
     batch_size: int,
-) -> Tuple[List[Tuple[int, int]], Set[LinkQuadruple]]:
+) -> tuple[list[tuple[int, int]], set[LinkQuadruple]]:
     """Collect all entities.
 
     :returns: (group_id_to_node_id, link_data) and updates entity_ids
@@ -393,11 +393,7 @@ def progress_str(name):
     with get_progress_reporter()(desc=progress_str(''), total=9) as progress:
         progress.set_description_str(progress_str('Nodes'))
         entity_ids[EntityTypes.NODE].update(
-            querybuilder()
-            .append(orm.Node, project='id')
-            .all(  # type: ignore[arg-type]
-                batch_size=batch_size, flat=True
-            )
+            querybuilder().append(orm.Node, project='id').all(batch_size=batch_size, flat=True)
         )
         progress.update()
 
@@ -417,7 +413,7 @@ def progress_str(name):
             querybuilder()
             .append(
                 orm.Group,
-                project='id',  # type: ignore[arg-type]
+                project='id',
             )
             .all(batch_size=batch_size, flat=True)
         )
@@ -429,15 +425,15 @@ def progress_str(name):
             .append(orm.Node, with_group='group', project='id')
             .distinct()
         )
-        group_nodes: List[Tuple[int, int]] = qbuilder.all(batch_size=batch_size)  # type: ignore[assignment]
+        group_nodes: list[tuple[int, int]] = qbuilder.all(batch_size=batch_size)  # type: ignore[assignment]
 
         progress.set_description_str(progress_str('Computers'))
         progress.update()
         entity_ids[EntityTypes.COMPUTER].update(
             querybuilder()
             .append(
                 orm.Computer,
-                project='id',  # type: ignore[arg-type]
+                project='id',
             )
             .all(batch_size=batch_size, flat=True)
         )
@@ -449,7 +445,7 @@ def progress_str(name):
             querybuilder()
             .append(
                 orm.AuthInfo,
-                project='id',  # type: ignore[arg-type]
+                project='id',
             )
             .all(batch_size=batch_size, flat=True)
         )
@@ -461,7 +457,7 @@ def progress_str(name):
             querybuilder()
             .append(
                 orm.Log,
-                project='id',  # type: ignore[arg-type]
+                project='id',
             )
             .all(batch_size=batch_size, flat=True)
         )
@@ -473,7 +469,7 @@ def progress_str(name):
             querybuilder()
             .append(
                 orm.Comment,
-                project='id',  # type: ignore[arg-type]
+                project='id',
             )
             .all(batch_size=batch_size, flat=True)
         )
@@ -484,7 +480,7 @@ def progress_str(name):
             querybuilder()
             .append(
                 orm.User,
-                project='id',  # type: ignore[arg-type]
+                project='id',
             )
             .all(batch_size=batch_size, flat=True)
         )
@@ -494,14 +490,14 @@ def progress_str(name):
 
 def _collect_required_entities(
     querybuilder: QbType,
-    entity_ids: Dict[EntityTypes, Set[int]],
-    traversal_rules: Dict[str, bool],
+    entity_ids: dict[EntityTypes, set[int]],
+    traversal_rules: dict[str, bool],
     include_authinfos: bool,
     include_comments: bool,
     include_logs: bool,
     backend: StorageBackend,
     batch_size: int,
-) -> Tuple[List[Tuple[int, int]], Set[LinkQuadruple]]:
+) -> tuple[list[tuple[int, int]], set[LinkQuadruple]]:
     """Collect required entities, given a set of starting entities and provenance graph traversal rules.
 
     :returns: (group_id_to_node_id, link_data) and updates entity_ids
@@ -513,7 +509,7 @@ def progress_str(name):
     with get_progress_reporter()(desc=progress_str(''), total=7) as progress:
         # get all nodes from groups
         progress.set_description_str(progress_str('Nodes (groups)'))
-        group_nodes: List[Tuple[int, int]] = []
+        group_nodes: list[tuple[int, int]] = []
        if entity_ids[EntityTypes.GROUP]:
             qbuilder = querybuilder()
             qbuilder.append(
@@ -632,7 +628,7 @@ def progress_str(name):
 
 
 def _stream_repo_files(
-    key_format: str, writer: ArchiveWriterAbstract, node_ids: Set[int], backend: StorageBackend, batch_size: int
+    key_format: str, writer: ArchiveWriterAbstract, node_ids: set[int], backend: StorageBackend, batch_size: int
 ) -> None:
     """Collect all repository object keys from the nodes, then stream the files to the archive."""
     keys = set(
@@ -652,7 +648,7 @@ def _stream_repo_files(
         progress.update()
 
 
-def _check_unsealed_nodes(querybuilder: QbType, node_ids: Set[int], batch_size: int) -> None:
+def _check_unsealed_nodes(querybuilder: QbType, node_ids: set[int], batch_size: int) -> None:
     """Check no process nodes are unsealed, i.e. all processes have completed."""
     qbuilder = (
         querybuilder()
@@ -678,7 +674,7 @@ def _check_unsealed_nodes(querybuilder: QbType, node_ids: Set[int], batch_size:
 
 def _check_node_licenses(
     querybuilder: QbType,
-    node_ids: Set[int],
+    node_ids: set[int],
     allowed_licenses: Union[None, Sequence[str], Callable],
     forbidden_licenses: Union[None, Sequence[str], Callable],
     batch_size: int,
```
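
Alongside the cast removals, this file's diff is mostly mechanical modernization from the `typing.Dict`/`Set`/`List`/`Tuple` aliases to the builtin generics. Note that one `# type: ignore[assignment]` survives in `_collect_all_entities`: without `flat=True`, `all()` is now typed `list[list[Any]]`, and a `list[Any]` row is still not a `tuple[int, int]`, so assigning the result to `group_nodes: list[tuple[int, int]]` still needs the ignore. A hypothetical minimal reproduction:

```python
# Hypothetical reproduction of why the remaining ignore[assignment] is needed.
from typing import Any

rows: list[list[Any]] = [[1, 2], [3, 4]]
# mypy: Incompatible types in assignment (expression has type "list[list[Any]]",
# variable has type "list[tuple[int, int]]")
pairs: list[tuple[int, int]] = rows  # type: ignore[assignment]
```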
