diff --git a/src/dvc_data/index/collect.py b/src/dvc_data/index/collect.py
new file mode 100644
index 00000000..5d54bea0
--- /dev/null
+++ b/src/dvc_data/index/collect.py
@@ -0,0 +1,157 @@
+import logging
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
+
+from dvc_objects.fs.callbacks import DEFAULT_CALLBACK
+
+from .index import (
+    DataIndex,
+    DataIndexEntry,
+    FileStorage,
+    ObjectStorage,
+    StorageInfo,
+)
+
+if TYPE_CHECKING:
+    from dvc_objects.fs.callbacks import Callback
+
+    from .index import Storage
+
+logger = logging.getLogger(__name__)
+
+
+def _collect_from_index(
+    cache,
+    cache_prefix,
+    index,
+    prefix,
+    storage,
+    callback: "Callback" = DEFAULT_CALLBACK,
+):
+    entries = {}
+
+    try:
+        for _, entry in index.iteritems(prefix):
+            callback.relative_update()
+            try:
+                storage_key = storage.get_key(entry)
+            except ValueError:
+                continue
+
+            loaded = False
+            if entry.meta and entry.meta.isdir:
+                # NOTE: at this point it might not be loaded yet, so we can't
+                # rely on entry.loaded
+                loaded = True
+
+            meta = entry.meta
+            hash_info = entry.hash_info
+            if (
+                isinstance(storage, FileStorage)
+                and storage.fs.version_aware
+                and entry.meta
+                and not entry.meta.isdir
+                and entry.meta.version_id is None
+            ):
+                meta.md5 = None
+                hash_info = None
+
+            # NOTE: avoiding modifying cache right away, because you might
+            # run into a locked database if idx and cache are using the same
+            # table.
+            entries[storage_key] = DataIndexEntry(
+                key=storage_key,
+                meta=meta,
+                hash_info=hash_info,
+                loaded=loaded,
+            )
+
+    except KeyError:
+        return
+
+    for key, entry in entries.items():
+        cache[(*cache_prefix, *key)] = entry
+
+
+def collect(  # noqa: C901
+    idxs,
+    storage,
+    callback: "Callback" = DEFAULT_CALLBACK,
+    cache_index=None,
+    cache_key=None,
+) -> List["DataIndex"]:
+    from fsspec.utils import tokenize
+
+    storage_by_fs: Dict[Tuple[str, str], StorageInfo] = {}
+    skip = set()
+
+    if cache_index is None:
+        cache_index = DataIndex()
+        cache_key = ()
+
+    for idx in idxs:
+        for prefix, storage_info in idx.storage_map.items():
+            data = getattr(storage_info, storage)
+            cache = storage_info.cache if storage != "cache" else None
+            remote = storage_info.remote if storage != "remote" else None
+
+            if not data:
+                continue
+
+            # FIXME should use fsid instead of protocol
+            key = (data.fs.protocol, tokenize(data.path))
+            if key not in storage_by_fs:
+                if cache_index.has_node((*cache_key, *key)):
+                    skip.add(key)
+
+            if key not in skip:
+                _collect_from_index(
+                    cache_index,
+                    (*cache_key, *key),
+                    idx,
+                    prefix,
+                    data,
+                    callback=callback,
+                )
+                cache_index.commit()
+
+            if key not in storage_by_fs:
+                fs_data: "Storage"
+                fs_cache: Optional["Storage"]
+                fs_remote: Optional["Storage"]
+
+                if isinstance(data, ObjectStorage):
+                    fs_data = ObjectStorage(key=(), odb=data.odb)
+                else:
+                    fs_data = FileStorage(key=(), fs=data.fs, path=data.path)
+
+                if not cache:
+                    fs_cache = None
+                elif isinstance(cache, ObjectStorage):
+                    fs_cache = ObjectStorage(key=(), odb=cache.odb)
+                else:
+                    fs_cache = FileStorage(
+                        key=(), fs=cache.fs, path=cache.path
+                    )
+
+                if not remote:
+                    fs_remote = None
+                elif isinstance(remote, ObjectStorage):
+                    fs_remote = ObjectStorage(key=(), odb=remote.odb)
+                else:
+                    fs_remote = FileStorage(
+                        key=(),
+                        fs=remote.fs,
+                        path=remote.path,
+                    )
+
+                storage_by_fs[key] = StorageInfo(
+                    data=fs_data, cache=fs_cache, remote=fs_remote
+                )
+
+    storage_indexes = []
+    for key, storage_info in storage_by_fs.items():
+        idx = cache_index.view((*cache_key, *key))
+        idx.storage_map[()] = storage_info
+        storage_indexes.append(idx)
+
+    return storage_indexes
diff --git a/src/dvc_data/index/fetch.py b/src/dvc_data/index/fetch.py
index 5dc0a089..2a1db089 100644
--- a/src/dvc_data/index/fetch.py
+++ b/src/dvc_data/index/fetch.py
@@ -1,5 +1,5 @@
 import logging
-from typing import TYPE_CHECKING, Dict, Optional, Tuple
+from typing import TYPE_CHECKING, Optional
 
 from dvc_objects.fs.callbacks import DEFAULT_CALLBACK
 
@@ -8,13 +8,8 @@
 
 from .build import build
 from .checkout import apply, compare
-from .index import (
-    DataIndex,
-    DataIndexEntry,
-    FileStorage,
-    ObjectStorage,
-    StorageInfo,
-)
+from .collect import collect  # noqa: F401, pylint: disable=unused-import
+from .index import ObjectStorage
 from .save import md5, save
 
 if TYPE_CHECKING:
@@ -22,8 +17,6 @@
 
     from dvc_data.hashfile.status import CompareStatusResult
 
-    from .index import Storage
-
 logger = logging.getLogger(__name__)
 
 
@@ -41,154 +34,31 @@ def _log_missing(status: "CompareStatusResult"):
         )
 
 
-def _collect_from_index(
-    cache,
-    cache_prefix,
-    index,
-    prefix,
-    remote,
-    callback: "Callback" = DEFAULT_CALLBACK,
-):
-    entries = {}
-
-    try:
-        for _, entry in index.iteritems(prefix):
-            callback.relative_update()
-            try:
-                storage_key = remote.get_key(entry)
-            except ValueError:
-                continue
-
-            loaded = False
-            if entry.meta and entry.meta.isdir:
-                # NOTE: at this point it might not be loaded yet, so we can't
-                # rely on entry.loaded
-                loaded = True
-
-            meta = entry.meta
-            hash_info = entry.hash_info
-            if (
-                isinstance(remote, FileStorage)
-                and remote.fs.version_aware
-                and entry.meta
-                and not entry.meta.isdir
-                and entry.meta.version_id is None
-            ):
-                meta.md5 = None
-                hash_info = None
-            # NOTE: avoiding modifying cache right away, because you might
-            # run into a locked database if idx and cache are using the same
-            # table.
-            entries[storage_key] = DataIndexEntry(
-                key=storage_key,
-                meta=meta,
-                hash_info=hash_info,
-                loaded=loaded,
-            )
-
-    except KeyError:
-        return
-
-    for key, entry in entries.items():
-        cache[(*cache_prefix, *key)] = entry
-
-
-def collect(  # noqa: C901
-    idxs,
-    callback: "Callback" = DEFAULT_CALLBACK,
-    cache_index=None,
-    cache_key=None,
-):
-    from fsspec.utils import tokenize
-
-    storage_by_fs: Dict[Tuple[str, str], StorageInfo] = {}
-    skip = set()
-
-    if cache_index is None:
-        cache_index = DataIndex()
-        cache_key = ()
-
-    for idx in idxs:
-        for prefix, storage_info in idx.storage_map.items():
-            remote = storage_info.remote
-            cache = storage_info.cache
-            if not remote or not cache:
-                continue
-
-            # FIXME should use fsid instead of protocol
-            key = (remote.fs.protocol, tokenize(remote.path))
-            if key not in storage_by_fs:
-                if cache_index.has_node((*cache_key, *key)):
-                    skip.add(key)
-
-            if key not in skip:
-                _collect_from_index(
-                    cache_index,
-                    (*cache_key, *key),
-                    idx,
-                    prefix,
-                    remote,
-                    callback=callback,
-                )
-                cache_index.commit()
-
-            if key not in storage_by_fs:
-                fs_cache: "Storage"
-                fs_remote: "Storage"
-
-                if isinstance(cache, ObjectStorage):
-                    fs_cache = ObjectStorage(key=(), odb=cache.odb)
-                else:
-                    fs_cache = FileStorage(
-                        key=(), fs=cache.fs, path=cache.path
-                    )
-
-                if isinstance(remote, ObjectStorage):
-                    fs_remote = ObjectStorage(key=(), odb=remote.odb)
-                else:
-                    fs_remote = FileStorage(
-                        key=(),
-                        fs=remote.fs,
-                        path=remote.path,
-                    )
-
-                storage_by_fs[key] = StorageInfo(
-                    cache=fs_cache, remote=fs_remote
-                )
-
-    by_fs: Dict[Tuple[str, str], DataIndex] = {}
-    for key, storage in storage_by_fs.items():
-        by_fs[key] = cache_index.view((*cache_key, *key))
-        by_fs[key].storage_map[()] = storage
-
-    return by_fs
-
-
 def fetch(
-    data,
+    idxs,
     callback: "Callback" = DEFAULT_CALLBACK,
     jobs: Optional[int] = None,
 ):
     fetched, failed = 0, 0
-    for (fs_protocol, _), fs_index in data.items():
+    for fs_index in idxs:
+        data = fs_index.storage_map[()].data
         cache = fs_index.storage_map[()].cache
-        remote = fs_index.storage_map[()].remote
 
         if callback != DEFAULT_CALLBACK:
             cb = callback.as_tqdm_callback(
                 unit="file",
                 total=len(fs_index),
-                desc=f"Fetching from {fs_protocol}",
+                desc=f"Fetching from {data.fs.protocol}",
             )
         else:
             cb = callback
 
         with cb:
             if isinstance(cache, ObjectStorage) and isinstance(
-                remote, ObjectStorage
+                data, ObjectStorage
             ):
                 result = transfer(
-                    remote.odb,
+                    data.odb,
                     cache.odb,
                     [
                         entry.hash_info
@@ -196,19 +66,17 @@ def fetch(
                         if entry.hash_info
                     ],
                     jobs=jobs,
-                    src_index=get_index(remote.odb),
+                    src_index=get_index(data.odb),
                     cache_odb=cache.odb,
-                    verify=remote.odb.verify,
+                    verify=data.odb.verify,
                     validate_status=_log_missing,
                     callback=cb,
                 )
                 fetched += len(result.transferred)
                 failed += len(result.failed)
             elif isinstance(cache, ObjectStorage):
-                md5(fs_index, storage="remote", check_meta=False)
-                fetched += save(
-                    fs_index, storage="remote", jobs=jobs, callback=cb
-                )
+                md5(fs_index, check_meta=False)
+                fetched += save(fs_index, jobs=jobs, callback=cb)
             else:
                 old = build(cache.path, cache.fs)
                 diff = compare(old, fs_index)
@@ -220,7 +88,7 @@ def fetch(
                     cache.path,
                     cache.fs,
                     update_meta=False,
-                    storage="remote",
+                    storage="data",
                     jobs=jobs,
                     callback=cb,
                 )
diff --git a/tests/index/test_index.py b/tests/index/test_index.py
index f6a4f604..75950175 100644
--- a/tests/index/test_index.py
+++ b/tests/index/test_index.py
@@ -210,7 +210,7 @@ def test_fetch(tmp_upath, make_odb, odb, as_filesystem):
     index.storage_map.add_remote(ObjectStorage((), odb))
     (tmp_upath / "fetched").mkdir()
-    data = collect([index])
+    data = collect([index], "remote")
     fetch(data)
 
     diff = checkout.compare(None, index)
     checkout.apply(