diff --git a/src/dvc_data/index/index.py b/src/dvc_data/index/index.py
index 283066ad..f207ef9f 100644
--- a/src/dvc_data/index/index.py
+++ b/src/dvc_data/index/index.py
@@ -2,10 +2,13 @@ import logging
 import os
 from abc import ABC, abstractmethod
+from collections import defaultdict
 from collections.abc import Iterator, MutableMapping
 from typing import TYPE_CHECKING, Any, Callable, Optional, cast
 
 import attrs
+from fsspec import Callback
+from fsspec.callbacks import DEFAULT_CALLBACK
 from sqltrie import JSONTrie, PyGTrie, ShortKeyError, SQLiteTrie
 
 from dvc_data.compat import cached_property
@@ -156,6 +159,20 @@ def exists(self, entry: "DataIndexEntry") -> bool:
         fs, path = self.get(entry)
         return fs.exists(path)
 
+    def bulk_exists(
+        self,
+        entries: list["DataIndexEntry"],
+        refresh: bool = False,
+        max_workers: int | None = None,
+        callback: "Callback" = DEFAULT_CALLBACK,
+        cached_info: dict[str, Any] | None = None,
+    ) -> dict["DataIndexEntry", bool]:
+        results = {}
+        for entry in callback.wrap(entries):
+            results[entry] = self.exists(entry)
+
+        return results
+
 
 class ObjectStorage(Storage):
     def __init__(
@@ -224,6 +241,79 @@ def exists(self, entry: "DataIndexEntry", refresh: bool = False) -> bool:
         finally:
             self.index.commit()
 
+    def bulk_exists(
+        self,
+        entries: list["DataIndexEntry"],
+        refresh: bool = False,
+        max_workers: int | None = None,
+        callback: "Callback" = DEFAULT_CALLBACK,
+        cached_info: dict[str, Any] | None = None,
+    ) -> dict["DataIndexEntry", bool]:
+        results = {}
+
+        if not entries:
+            return results
+
+        entries_with_hash = [e for e in entries if e.hash_info]
+        entries_without_hash = [e for e in entries if not e.hash_info]
+
+        for entry in callback.wrap(entries_without_hash):
+            results[entry] = False
+
+        if self.index is None:
+            for entry in callback.wrap(entries_with_hash):
+                assert entry.hash_info
+                value = cast("str", entry.hash_info.value)
+                results[entry] = self.odb.exists(value)
+            return results
+
+        if not refresh:
+            for entry in callback.wrap(entries_with_hash):
+                assert entry.hash_info
+                value = cast("str", entry.hash_info.value)
+                key = self.odb._oid_parts(value)
+                results[entry] = key in self.index
+            return results
+
+        entry_map: dict[str, DataIndexEntry] = {
+            self.get(entry)[1]: entry for entry in entries_with_hash
+        }
+        if cached_info is not None:
+            # Instead of doing the network call, we use the pre-computed info.
+            info_results = [
+                cached_info.get(path) for path in callback.wrap(entry_map.keys())
+            ]
+        else:
+            info_results = self.fs.info(
+                list(entry_map.keys()),
+                batch_size=max_workers,
+                return_exceptions=True,
+                callback=callback,
+            )
+
+        # `results` already holds False for entries without hash_info; keep them.
+        for (path, entry), info in zip(entry_map.items(), info_results):
+            assert entry.hash_info  # built from entries_with_hash
+            value = cast("str", entry.hash_info.value)
+            key = self.odb._oid_parts(value)
+
+            if isinstance(info, FileNotFoundError) or info is None:
+                self.index.pop(key, None)
+                results[entry] = False
+            elif isinstance(info, Exception):
+                raise info
+            else:
+                from .build import build_entry
+
+                built_entry = build_entry(path, self.fs, info=info)
+                self.index[key] = built_entry
+                results[entry] = True
+
+        if self.index is not None:
+            self.index.commit()
+
+        return results
+
 
 class FileStorage(Storage):
     def __init__(
@@ -442,6 +532,102 @@ def remote_exists(self, entry: "DataIndexEntry", **kwargs) -> bool:
         return storage.remote.exists(entry, **kwargs)
 
+    def _bulk_storage_exists(
+        self,
+        entries: list[DataIndexEntry],
+        storage_selector: Callable[["StorageInfo"], Optional["Storage"]],
+        callback: Callback = DEFAULT_CALLBACK,
+        **kwargs,
+    ) -> dict[DataIndexEntry, bool]:
+        by_storage: dict[Optional[Storage], list[DataIndexEntry]] = defaultdict(list)
+        for entry in entries:
+            storage_info = self[entry.key]
+            storage = storage_selector(storage_info) if storage_info else None
+            by_storage[storage].append(entry)
+
+        results = {}
+
+        # Unify batches per actual underlying ODB path.
+        # Maps from (storage_type, odb_path) to [(StorageInstance, entries)]
+        odb_batches: dict[
+            tuple[type, str | None], list[tuple[ObjectStorage, list[DataIndexEntry]]]
+        ] = defaultdict(list)
+
+        for storage, storage_entries in by_storage.items():
+            if storage is None:
+                # by_storage lists are never empty, so report the first entry
+                # that has no configured storage.
+                raise StorageKeyError(storage_entries[0].key)
+
+            if not isinstance(storage, ObjectStorage):
+                # Non-ObjectStorage backends are not batched; run them through the generic path.
+                storage_results = storage.bulk_exists(
+                    storage_entries, callback=callback, **kwargs
+                )
+                results.update(storage_results)
+                continue
+
+            key = (type(storage), storage.path)
+            odb_batches[key].append((storage, storage_entries))
+
+        # Actually process batches
+        for storage_groups in odb_batches.values():
+            all_paths = [
+                storage.get(entry)[1]
+                for storage, entries in storage_groups
+                for entry in entries
+            ]
+
+            # Any storage is representative for this batch
+            batch_info = storage_groups[0][0].fs.info(
+                all_paths,
+                return_exceptions=True,
+                callback=callback,
+            )
+
+            # Maps from path to info
+            cached_info: dict[str, Any] = {
+                p: info if not isinstance(info, Exception) else None
+                for p, info in zip(all_paths, batch_info)
+            }
+
+            # Finally, distribute results back to original storages
+            for storage, storage_entries in storage_groups:
+                storage_results = storage.bulk_exists(
+                    storage_entries,
+                    cached_info=cached_info,
+                    **kwargs,
+                )
+                results.update(storage_results)
+
+        return results
+
+    def bulk_cache_exists(
+        self,
+        entries: list[DataIndexEntry],
+        callback: Callback = DEFAULT_CALLBACK,
+        **kwargs,
+    ) -> dict[DataIndexEntry, bool]:
+        return self._bulk_storage_exists(
+            entries,
+            lambda info: info.cache,
+            callback=callback,
+            **kwargs,
+        )
+
+    def bulk_remote_exists(
+        self,
+        entries: list[DataIndexEntry],
+        callback: Callback = DEFAULT_CALLBACK,
+        **kwargs,
+    ) -> dict[DataIndexEntry, bool]:
+        return self._bulk_storage_exists(
+            entries,
+            lambda info: info.remote,
+            callback=callback,
+            **kwargs,
+        )
+
 
 class BaseDataIndex(ABC, MutableMapping[DataIndexKey, DataIndexEntry]):
     storage_map: StorageMapping
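For reviewers, a minimal sketch of how calling code might drive the new bulk API. It is not part of the diff: the `missing_on_remote` helper, the `entries` selection, and the `refresh=True` choice are illustrative assumptions about a caller, assuming an already-built `DataIndex` whose `storage_map` has a remote registered.

```python
from dvc_data.index import DataIndex, DataIndexEntry


def missing_on_remote(
    index: DataIndex, entries: list[DataIndexEntry]
) -> list[DataIndexEntry]:
    """Return the subset of `entries` that is absent from the configured remote."""
    # One batched fs.info() per underlying ODB instead of one exists() call per
    # entry; refresh=True forces the remote lookup so the answer is not served
    # from the locally cached index.
    status = index.storage_map.bulk_remote_exists(entries, refresh=True)
    return [entry for entry, exists in status.items() if not exists]
```

With `refresh=True`, the grouped `fs.info()` results are fed back to each `ObjectStorage.bulk_exists` through `cached_info`, so every underlying ODB is queried once per batch rather than once per entry.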