diff --git a/src/aac_datasets/utils/collate.py b/src/aac_datasets/utils/collate.py index 0adbe50..c50f42a 100644 --- a/src/aac_datasets/utils/collate.py +++ b/src/aac_datasets/utils/collate.py @@ -4,7 +4,6 @@ from typing import Any, Dict, List, Union import torch - from torch import Tensor from torch.nn import functional as F @@ -18,7 +17,7 @@ class BasicCollate: """ def __call__(self, batch_lst: List[Dict[str, Any]]) -> Dict[str, List[Any]]: - return list_dict_to_dict_list(batch_lst) + return list_dict_to_dict_list(batch_lst, key_mode="intersect") class AdvancedCollate: @@ -42,7 +41,10 @@ def __init__(self, fill_values: Dict[str, Union[float, int]]) -> None: self.fill_values = fill_values def __call__(self, batch_lst: List[Dict[str, Any]]) -> Dict[str, Any]: - batch_dic: Dict[str, Any] = list_dict_to_dict_list(batch_lst) + batch_dic: Dict[str, Any] = list_dict_to_dict_list( + batch_lst, + key_mode="intersect", + ) keys = list(batch_dic.keys()) for key in keys: diff --git a/src/aac_datasets/utils/collections.py b/src/aac_datasets/utils/collections.py index fd8300b..4a0dcaf 100644 --- a/src/aac_datasets/utils/collections.py +++ b/src/aac_datasets/utils/collections.py @@ -1,24 +1,59 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- -from typing import Any, Dict, Iterable, List, Mapping, Sequence, TypeVar +from typing import ( + Dict, + Iterable, + List, + Literal, + Mapping, + Sequence, + TypeVar, + Union, + overload, +) +K = TypeVar("K") T = TypeVar("T") +V = TypeVar("V") +W = TypeVar("W") + +KEY_MODES = ("same", "intersect", "union") +KeyMode = Literal["intersect", "same", "union"] + + +@overload +def list_dict_to_dict_list( + lst: Sequence[Mapping[K, V]], + key_mode: Literal["intersect", "same"], + default_val: W = None, +) -> Dict[K, List[V]]: + ... + + +@overload +def list_dict_to_dict_list( + lst: Sequence[Mapping[K, V]], + key_mode: Literal["union"] = "union", + default_val: W = None, +) -> Dict[K, List[Union[V, W]]]: + ... 
def list_dict_to_dict_list( - lst: Sequence[Mapping[str, T]], - key_mode: str = "intersect", - default: Any = None, -) -> Dict[str, List[T]]: + lst: Sequence[Mapping[K, V]], + key_mode: KeyMode = "union", + default_val: W = None, +) -> Dict[K, List[Union[V, W]]]: """Convert list of dicts to dict of lists. - :param lst: The list of dict to merge. - :param key_mode: Can be "same" or "intersect". - If "same", all the dictionaries must contains the same keys otherwise a ValueError will be raised. - If "intersect", only the intersection of all keys will be used in output. - If "union", the output dict will contains the union of all keys, and the missing value will use the argument default. - :returns: The dictionary of lists. + Args: + lst: The list of dict to merge. + key_mode: Can be "same", "intersect" or "union". + If "same", all the dictionaries must contain the same keys, otherwise a ValueError will be raised. + If "intersect", only the intersection of all keys will be used in output. + If "union", the output dict will contain the union of all keys, and missing values will be filled with the argument default_val. + default_val: Default value of an element when key_mode is "union". Defaults to None. """ if len(lst) <= 0: return {} @@ -32,12 +67,11 @@ def list_dict_to_dict_list( elif key_mode == "union": keys = union_lists([item.keys() for item in lst]) else: - KEY_MODES = ("same", "intersect", "union") raise ValueError( f"Invalid argument key_mode={key_mode}. 
(expected one of {KEY_MODES})" ) - return {key: [item.get(key, default) for item in lst] for key in keys} + return {key: [item.get(key, default_val) for item in lst] for key in keys} def intersect_lists(lst_of_lst: Sequence[Iterable[T]]) -> List[T]: diff --git a/src/aac_datasets/utils/download.py b/src/aac_datasets/utils/download.py index 9458a57..536bab8 100644 --- a/src/aac_datasets/utils/download.py +++ b/src/aac_datasets/utils/download.py @@ -4,15 +4,12 @@ import hashlib import os import os.path as osp - from pathlib import Path -from typing import List, Union +from typing import List, Literal, Union import tqdm - from torch.hub import download_url_to_file - HASH_TYPES = ("sha256", "md5") DEFAULT_CHUNK_SIZE = 256 * 1024**2 # 256 MiB @@ -70,7 +67,7 @@ def safe_rmdir( def hash_file( fpath: Union[str, Path], - hash_type: str, + hash_type: Literal["sha256", "md5"], chunk_size: int = DEFAULT_CHUNK_SIZE, ) -> str: """Return the hash value for a file.