Commit efa543a: fetch: tweak collect

Collect functionality is going to be useful for push/gc/etc in the next PRs.

efiop committed Jul 9, 2023 (1 parent: f684b56)

Showing 3 changed files with 172 additions and 147 deletions.
157 changes: 157 additions & 0 deletions src/dvc_data/index/collect.py
@@ -0,0 +1,157 @@
import logging
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple

from dvc_objects.fs.callbacks import DEFAULT_CALLBACK

from .index import (
    DataIndex,
    DataIndexEntry,
    FileStorage,
    ObjectStorage,
    StorageInfo,
)

if TYPE_CHECKING:
    from dvc_objects.fs.callbacks import Callback

    from .index import Storage

logger = logging.getLogger(__name__)


def _collect_from_index(
    cache,
    cache_prefix,
    index,
    prefix,
    storage,
    callback: "Callback" = DEFAULT_CALLBACK,
):
    entries = {}

    try:
        for _, entry in index.iteritems(prefix):
            callback.relative_update()
            try:
                storage_key = storage.get_key(entry)
            except ValueError:
                continue

            loaded = False
            if entry.meta and entry.meta.isdir:
                # NOTE: at this point it might not be loaded yet, so we can't
                # rely on entry.loaded
                loaded = True

            meta = entry.meta
            hash_info = entry.hash_info
            if (
                isinstance(storage, FileStorage)
                and storage.fs.version_aware
                and entry.meta
                and not entry.meta.isdir
                and entry.meta.version_id is None
            ):
                meta.md5 = None
                hash_info = None

            # NOTE: avoiding modifying cache right away, because you might
            # run into a locked database if idx and cache are using the same
            # table.
            entries[storage_key] = DataIndexEntry(
                key=storage_key,
                meta=meta,
                hash_info=hash_info,
                loaded=loaded,
            )

    except KeyError:
        return

    for key, entry in entries.items():
        cache[(*cache_prefix, *key)] = entry


def collect(  # noqa: C901
    idxs,
    storage,
    callback: "Callback" = DEFAULT_CALLBACK,
    cache_index=None,
    cache_key=None,
) -> List["DataIndex"]:
    from fsspec.utils import tokenize

    storage_by_fs: Dict[Tuple[str, str], StorageInfo] = {}
    skip = set()

    if cache_index is None:
        cache_index = DataIndex()
        cache_key = ()

    for idx in idxs:
        for prefix, storage_info in idx.storage_map.items():
            data = getattr(storage_info, storage)
            cache = storage_info.cache if storage != "cache" else None
            remote = storage_info.remote if storage != "remote" else None

            if not data:
                continue

            # FIXME should use fsid instead of protocol
            key = (data.fs.protocol, tokenize(data.path))
            if key not in storage_by_fs:
                if cache_index.has_node((*cache_key, *key)):
                    skip.add(key)

            if key not in skip:
                _collect_from_index(
                    cache_index,
                    (*cache_key, *key),
                    idx,
                    prefix,
                    data,
                    callback=callback,
                )
                cache_index.commit()

            if key not in storage_by_fs:
                fs_data: "Storage"
                fs_cache: Optional["Storage"]
                fs_remote: Optional["Storage"]

                if isinstance(data, ObjectStorage):
                    fs_data = ObjectStorage(key=(), odb=data.odb)
                else:
                    fs_data = FileStorage(key=(), fs=data.fs, path=data.path)

                if not cache:
                    fs_cache = None
                elif isinstance(cache, ObjectStorage):
                    fs_cache = ObjectStorage(key=(), odb=cache.odb)
                else:
                    fs_cache = FileStorage(
                        key=(), fs=cache.fs, path=cache.path
                    )

                if not remote:
                    fs_remote = None
                elif isinstance(remote, ObjectStorage):
                    fs_remote = ObjectStorage(key=(), odb=remote.odb)
                else:
                    fs_remote = FileStorage(
                        key=(),
                        fs=remote.fs,
                        path=remote.path,
                    )

                storage_by_fs[key] = StorageInfo(
                    data=fs_data, cache=fs_cache, remote=fs_remote
                )

    storage_indexes = []
    for key, storage_info in storage_by_fs.items():
        idx = cache_index.view((*cache_key, *key))
        idx.storage_map[()] = storage_info
        storage_indexes.append(idx)

    return storage_indexes
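
For orientation, here is a minimal sketch of how the new entry point is meant to be called, mirroring the updated test at the bottom of this commit. The `index` and `odb` placeholders are hypothetical (the real setup lives in the test fixtures); import paths follow the modules in this diff:

    from dvc_data.index.collect import collect
    from dvc_data.index.fetch import fetch
    from dvc_data.index.index import DataIndex, ObjectStorage

    # Hypothetical setup: a populated index and an object database.
    index: DataIndex = ...
    odb = ...
    index.storage_map.add_remote(ObjectStorage((), odb))

    # New in this commit: collect() takes the storage name to collect from
    # ("data", "cache" or "remote") and returns a list of DataIndex views,
    # one per distinct (filesystem protocol, path) pair.
    idxs = collect([index], "remote")

    # fetch() now consumes that list directly instead of a dict.
    fetch(idxs)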
160 changes: 14 additions & 146 deletions src/dvc_data/index/fetch.py
@@ -1,5 +1,5 @@
 import logging
-from typing import TYPE_CHECKING, Dict, Optional, Tuple
+from typing import TYPE_CHECKING, Optional

 from dvc_objects.fs.callbacks import DEFAULT_CALLBACK

@@ -8,22 +8,15 @@
 from .build import build
 from .checkout import apply, compare
-from .index import (
-    DataIndex,
-    DataIndexEntry,
-    FileStorage,
-    ObjectStorage,
-    StorageInfo,
-)
+from .collect import collect  # noqa: F401, pylint: disable=unused-import
+from .index import ObjectStorage
 from .save import md5, save

 if TYPE_CHECKING:
     from dvc_objects.fs.callbacks import Callback

     from dvc_data.hashfile.status import CompareStatusResult

-    from .index import Storage
-
 logger = logging.getLogger(__name__)

@@ -41,174 +34,49 @@ def _log_missing(status: "CompareStatusResult"):
         )


-def _collect_from_index(
-    cache,
-    cache_prefix,
-    index,
-    prefix,
-    remote,
-    callback: "Callback" = DEFAULT_CALLBACK,
-):
-    entries = {}
-
-    try:
-        for _, entry in index.iteritems(prefix):
-            callback.relative_update()
-            try:
-                storage_key = remote.get_key(entry)
-            except ValueError:
-                continue
-
-            loaded = False
-            if entry.meta and entry.meta.isdir:
-                # NOTE: at this point it might not be loaded yet, so we can't
-                # rely on entry.loaded
-                loaded = True
-
-            meta = entry.meta
-            hash_info = entry.hash_info
-            if (
-                isinstance(remote, FileStorage)
-                and remote.fs.version_aware
-                and entry.meta
-                and not entry.meta.isdir
-                and entry.meta.version_id is None
-            ):
-                meta.md5 = None
-                hash_info = None
-            # NOTE: avoiding modifying cache right away, because you might
-            # run into a locked database if idx and cache are using the same
-            # table.
-            entries[storage_key] = DataIndexEntry(
-                key=storage_key,
-                meta=meta,
-                hash_info=hash_info,
-                loaded=loaded,
-            )
-
-    except KeyError:
-        return
-
-    for key, entry in entries.items():
-        cache[(*cache_prefix, *key)] = entry
-
-
-def collect(  # noqa: C901
-    idxs,
-    callback: "Callback" = DEFAULT_CALLBACK,
-    cache_index=None,
-    cache_key=None,
-):
-    from fsspec.utils import tokenize
-
-    storage_by_fs: Dict[Tuple[str, str], StorageInfo] = {}
-    skip = set()
-
-    if cache_index is None:
-        cache_index = DataIndex()
-        cache_key = ()
-
-    for idx in idxs:
-        for prefix, storage_info in idx.storage_map.items():
-            remote = storage_info.remote
-            cache = storage_info.cache
-            if not remote or not cache:
-                continue
-
-            # FIXME should use fsid instead of protocol
-            key = (remote.fs.protocol, tokenize(remote.path))
-            if key not in storage_by_fs:
-                if cache_index.has_node((*cache_key, *key)):
-                    skip.add(key)
-
-            if key not in skip:
-                _collect_from_index(
-                    cache_index,
-                    (*cache_key, *key),
-                    idx,
-                    prefix,
-                    remote,
-                    callback=callback,
-                )
-                cache_index.commit()
-
-            if key not in storage_by_fs:
-                fs_cache: "Storage"
-                fs_remote: "Storage"
-
-                if isinstance(cache, ObjectStorage):
-                    fs_cache = ObjectStorage(key=(), odb=cache.odb)
-                else:
-                    fs_cache = FileStorage(
-                        key=(), fs=cache.fs, path=cache.path
-                    )
-
-                if isinstance(remote, ObjectStorage):
-                    fs_remote = ObjectStorage(key=(), odb=remote.odb)
-                else:
-                    fs_remote = FileStorage(
-                        key=(),
-                        fs=remote.fs,
-                        path=remote.path,
-                    )
-
-                storage_by_fs[key] = StorageInfo(
-                    cache=fs_cache, remote=fs_remote
-                )
-
-    by_fs: Dict[Tuple[str, str], DataIndex] = {}
-    for key, storage in storage_by_fs.items():
-        by_fs[key] = cache_index.view((*cache_key, *key))
-        by_fs[key].storage_map[()] = storage
-
-    return by_fs
-
-
 def fetch(
-    data,
+    idxs,
     callback: "Callback" = DEFAULT_CALLBACK,
     jobs: Optional[int] = None,
 ):
     fetched, failed = 0, 0
-    for (fs_protocol, _), fs_index in data.items():
+    for fs_index in idxs:
+        data = fs_index.storage_map[()].data
         cache = fs_index.storage_map[()].cache
         remote = fs_index.storage_map[()].remote

         if callback != DEFAULT_CALLBACK:
             cb = callback.as_tqdm_callback(
                 unit="file",
                 total=len(fs_index),
-                desc=f"Fetching from {fs_protocol}",
+                desc=f"Fetching from {data.fs.protocol}",
             )
         else:
             cb = callback

         with cb:
             if isinstance(cache, ObjectStorage) and isinstance(
-                remote, ObjectStorage
+                data, ObjectStorage
             ):
                 result = transfer(
-                    remote.odb,
+                    data.odb,
                     cache.odb,
                     [
                         entry.hash_info
                         for _, entry in fs_index.iteritems()
                         if entry.hash_info
                     ],
                     jobs=jobs,
-                    src_index=get_index(remote.odb),
+                    src_index=get_index(data.odb),
                     cache_odb=cache.odb,
-                    verify=remote.odb.verify,
+                    verify=data.odb.verify,
                     validate_status=_log_missing,
                     callback=cb,
                 )
                 fetched += len(result.transferred)
                 failed += len(result.failed)
             elif isinstance(cache, ObjectStorage):
-                md5(fs_index, storage="remote", check_meta=False)
-                fetched += save(
-                    fs_index, storage="remote", jobs=jobs, callback=cb
-                )
+                md5(fs_index, check_meta=False)
+                fetched += save(fs_index, jobs=jobs, callback=cb)
             else:
                 old = build(cache.path, cache.fs)
                 diff = compare(old, fs_index)
@@ -220,7 +88,7 @@ def fetch(
                     cache.path,
                     cache.fs,
                     update_meta=False,
-                    storage="remote",
+                    storage="data",
                     jobs=jobs,
                     callback=cb,
                 )
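
Note the contract between the two functions after this refactor: collect() stores a StorageInfo at the root key of every view it returns, and fetch() reads it back via storage_map[()]. A small sketch of that lookup, assuming idxs came from collect() as above:

    for fs_index in idxs:
        info = fs_index.storage_map[()]  # StorageInfo set by collect()
        # data is the storage that was collected from; cache and remote
        # may be None depending on which storage name was requested.
        data, cache, remote = info.data, info.cache, info.remote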
2 changes: 1 addition & 1 deletion tests/index/test_index.py
@@ -210,7 +210,7 @@ def test_fetch(tmp_upath, make_odb, odb, as_filesystem):
     index.storage_map.add_remote(ObjectStorage((), odb))

     (tmp_upath / "fetched").mkdir()
-    data = collect([index])
+    data = collect([index], "remote")
     fetch(data)
     diff = checkout.compare(None, index)
     checkout.apply(
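
Per the commit message, the generalized collect() is expected to back push/gc in follow-up PRs. A speculative sketch of that kind of reuse; nothing below is part of this commit, it only exercises the new storage argument:

    # Speculative: enumerate entries as seen by the cache storage, grouped
    # per filesystem, the sort of reuse gc might make of collect().
    for view in collect([index], "cache"):
        for key, entry in view.iteritems():
            print(key, entry.hash_info)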
