Skip to content

Commit

Permalink
Mod: Update list_dict_to_dict_list function.
Browse files Browse the repository at this point in the history
  • Loading branch information
Labbeti committed Apr 17, 2024
1 parent 3e5e9bf commit 561c96d
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 21 deletions.
8 changes: 5 additions & 3 deletions src/aac_datasets/utils/collate.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from typing import Any, Dict, List, Union

import torch

from torch import Tensor
from torch.nn import functional as F

Expand All @@ -18,7 +17,7 @@ class BasicCollate:
"""

def __call__(self, batch_lst: List[Dict[str, Any]]) -> Dict[str, List[Any]]:
return list_dict_to_dict_list(batch_lst)
return list_dict_to_dict_list(batch_lst, key_mode="intersect")


class AdvancedCollate:
Expand All @@ -42,7 +41,10 @@ def __init__(self, fill_values: Dict[str, Union[float, int]]) -> None:
self.fill_values = fill_values

def __call__(self, batch_lst: List[Dict[str, Any]]) -> Dict[str, Any]:
batch_dic: Dict[str, Any] = list_dict_to_dict_list(batch_lst)
batch_dic: Dict[str, Any] = list_dict_to_dict_list(
batch_lst,
key_mode="intersect",
)
keys = list(batch_dic.keys())

for key in keys:
Expand Down
60 changes: 47 additions & 13 deletions src/aac_datasets/utils/collections.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,59 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-

from typing import Any, Dict, Iterable, List, Mapping, Sequence, TypeVar
from typing import (
Dict,
Iterable,
List,
Literal,
Mapping,
Sequence,
TypeVar,
Union,
overload,
)

K = TypeVar("K")
T = TypeVar("T")
V = TypeVar("V")
W = TypeVar("W")

KEY_MODES = ("same", "intersect", "union")
KeyMode = Literal["intersect", "same", "union"]


@overload
def list_dict_to_dict_list(
lst: Sequence[Mapping[K, V]],
key_mode: Literal["intersect", "same"],
default_val: W = None,
) -> Dict[K, List[V]]:
...


@overload
def list_dict_to_dict_list(
lst: Sequence[Mapping[K, V]],
key_mode: Literal["union"] = "union",
default_val: W = None,
) -> Dict[K, List[Union[V, W]]]:
...


def list_dict_to_dict_list(
lst: Sequence[Mapping[str, T]],
key_mode: str = "intersect",
default: Any = None,
) -> Dict[str, List[T]]:
lst: Sequence[Mapping[K, V]],
key_mode: KeyMode = "union",
default_val: W = None,
) -> Dict[K, List[Union[V, W]]]:
"""Convert list of dicts to dict of lists.
:param lst: The list of dict to merge.
:param key_mode: Can be "same" or "intersect".
If "same", all the dictionaries must contains the same keys otherwise a ValueError will be raised.
If "intersect", only the intersection of all keys will be used in output.
If "union", the output dict will contains the union of all keys, and the missing value will use the argument default.
:returns: The dictionary of lists.
Args:
lst: The list of dict to merge.
key_mode: Can be "same" or "intersect".
If "same", all the dictionaries must contains the same keys otherwise a ValueError will be raised.
If "intersect", only the intersection of all keys will be used in output.
If "union", the output dict will contains the union of all keys, and the missing value will use the argument default_val.
default_val: Default value of an element when key_mode is "union". defaults to None.
"""
if len(lst) <= 0:
return {}
Expand All @@ -32,12 +67,11 @@ def list_dict_to_dict_list(
elif key_mode == "union":
keys = union_lists([item.keys() for item in lst])
else:
KEY_MODES = ("same", "intersect", "union")
raise ValueError(
f"Invalid argument key_mode={key_mode}. (expected one of {KEY_MODES})"
)

return {key: [item.get(key, default) for item in lst] for key in keys}
return {key: [item.get(key, default_val) for item in lst] for key in keys}


def intersect_lists(lst_of_lst: Sequence[Iterable[T]]) -> List[T]:
Expand Down
7 changes: 2 additions & 5 deletions src/aac_datasets/utils/download.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,12 @@
import hashlib
import os
import os.path as osp

from pathlib import Path
from typing import List, Union
from typing import List, Literal, Union

import tqdm

from torch.hub import download_url_to_file


HASH_TYPES = ("sha256", "md5")
DEFAULT_CHUNK_SIZE = 256 * 1024**2 # 256 MiB

Expand Down Expand Up @@ -70,7 +67,7 @@ def safe_rmdir(

def hash_file(
fpath: Union[str, Path],
hash_type: str,
hash_type: Literal["sha256", "md5"],
chunk_size: int = DEFAULT_CHUNK_SIZE,
) -> str:
"""Return the hash value for a file.
Expand Down

0 comments on commit 561c96d

Please sign in to comment.