18 | 18 | import sys |
19 | 19 | from collections.abc import Iterable, Iterator |
20 | 20 | from datetime import datetime, timedelta |
21 | | -from typing import TYPE_CHECKING, Any, Callable |
| 21 | +from typing import TYPE_CHECKING, Any, Callable, TypeVar, overload |
22 | 22 | from uuid import UUID |
23 | 23 | |
24 | 24 | from aiida.common.typing import Self |
32 | 32 | except ImportError: |
33 | 33 | from typing_extensions import TypeAlias |
34 | 34 | |
| 35 | +T = TypeVar('T') |
| 36 | +R = TypeVar('R') |
| 37 | + |
35 | 38 | |
36 | 39 | def get_new_uuid() -> str: |
37 | 40 | """Return a new UUID (typically to be used for new nodes).""" |
@@ -608,3 +611,55 @@ def format_directory_size(size_in_bytes: int) -> str: |
608 | 611 | |
609 | 612 | # Format the size to two decimal places |
610 | 613 | return f'{converted_size:.2f} {prefixes[index]}' |
| 614 | + |
| 615 | + |
| 616 | +@overload |
| 617 | +def batch_iter(iterable: Iterable[T], size: int, transform: None = None) -> Iterable[tuple[int, list[T]]]: ... |
| 618 | + |
| 619 | + |
| 620 | +@overload |
| 621 | +def batch_iter(iterable: Iterable[T], size: int, transform: Callable[[T], R]) -> Iterable[tuple[int, list[R]]]: ... |
| 622 | + |
| 623 | + |
| 624 | +def batch_iter( |
| 625 | + iterable: Iterable[T], size: int, transform: Callable[[T], Any] | None = None |
| 626 | +) -> Iterable[tuple[int, list[Any]]]: |
| 627 | + """Yield an iterable in batches of a set number of items. |
| 628 | +
|
| 629 | + Note, the final yield may be less than this size. |
| 630 | +
|
| 631 | + :param transform: a transform to apply to each item |
| 632 | + :returns: (number of items, list of items) |
| 633 | + """ |
| 634 | + transform = transform or (lambda x: x) |
| 635 | + current = [] |
| 636 | + length = 0 |
| 637 | + for item in iterable: |
| 638 | + current.append(transform(item)) |
| 639 | + length += 1 |
| 640 | + if length >= size: |
| 641 | + yield length, current |
| 642 | + current = [] |
| 643 | + length = 0 |
| 644 | + if current: |
| 645 | + yield length, current |
| 646 | + |
| 647 | + |
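For illustration, a minimal usage sketch of `batch_iter` (not part of the diff; the import path is assumed to be this module, `aiida.common.utils`):

```python
from aiida.common.utils import batch_iter  # assumed location of this module

items = list(range(10))

# Plain batching: yields (3, [0, 1, 2]), (3, [3, 4, 5]), (3, [6, 7, 8]), (1, [9]).
for count, batch in batch_iter(items, 3):
    print(count, batch)

# With a transform applied to each item before it is collected into a batch.
for count, batch in batch_iter(items, 4, transform=lambda x: x * x):
    print(count, batch)
```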
| 648 | +# NOTE: `sqlite` has an `SQLITE_MAX_VARIABLE_NUMBER` compile-time flag. |
| 649 | +# On older `sqlite` versions, this was set to 999 by default, |
| 650 | +# while for newer versions it is generally higher, see: |
| 651 | +# https://www.sqlite.org/limits.html |
| 652 | +# If `DEFAULT_FILTER_SIZE` is set too high, the limit can be hit when large `IN` queries are |
| 653 | +# constructed through AiiDA, leading to SQLAlchemy `OperationalError`s. |
| 654 | +# On modern systems the limit may be in the hundreds of thousands; however, since it is OS- |
| 655 | +# and/or Python-version dependent and we do not know its size in advance, we set the value to 999 for safety. |
| 656 | +# Manual benchmarking suggests that this value also gives reasonable batching performance. |
| 657 | +DEFAULT_FILTER_SIZE: int = 999 |
| 658 | + |
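As a hypothetical sketch of how the cap would be used (the helper `run_in_query` is a stand-in, not AiiDA API): splitting the values of a large `IN` filter into chunks of at most `DEFAULT_FILTER_SIZE` keeps every statement under sqlite's parameter limit.

```python
# Hypothetical sketch: `run_in_query` is a stand-in for whatever executes a
# single `... WHERE id IN (...)` statement; it is not an AiiDA function.
def query_in_batches(ids: list[int]) -> list[Any]:
    results: list[Any] = []
    # Each chunk binds at most DEFAULT_FILTER_SIZE (999) parameters, staying
    # under sqlite's SQLITE_MAX_VARIABLE_NUMBER limit.
    for _, chunk in batch_iter(ids, DEFAULT_FILTER_SIZE):
        results.extend(run_in_query(chunk))
    return results
```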
| 659 | +# NOTE: `DEFAULT_BATCH_SIZE` controls how many database rows are fetched and processed at once during |
| 660 | +# streaming operations (e.g., `QueryBuilder.iterall()`, `QueryBuilder.iterdict()`). This prevents |
| 661 | +# loading entire large result sets into memory at once, which could cause memory exhaustion when |
| 662 | +# working with datasets containing thousands or millions of records. The value of 1000 provides a |
| 663 | +# balance between memory efficiency and database round-trip overhead. Setting it too low increases |
| 664 | +# the number of database queries needed, while setting it too high increases memory consumption. |
| 665 | +DEFAULT_BATCH_SIZE: int = 1000 |
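And a sketch of where `DEFAULT_BATCH_SIZE` applies, assuming `QueryBuilder.iterall` accepts a `batch_size` argument as in current AiiDA releases (`process` is a placeholder for user code):

```python
from aiida.orm import Node, QueryBuilder

qb = QueryBuilder().append(Node, project=['id'])
# Rows are fetched DEFAULT_BATCH_SIZE (1000) at a time instead of loading the
# whole result set into memory; each row is a list of the projected values.
for row in qb.iterall(batch_size=DEFAULT_BATCH_SIZE):
    process(row)  # placeholder for per-row processing
```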