Make datachain queries atomic when exception occurs (#494)
* Make datachain queries atomic when exception occurs

With this change, whenever an error or exception is raised while running
the script, all dataset versions and datasets created during the script
are reverted.

Studio PR: iterative/studio#10740
Studio Issue: iterative/studio#9875

* Use a UUID so that it also works with Postgres

* Check that only the current dataset versions are created
amritghimire authored Oct 7, 2024
1 parent 644dc9a commit 414872b
Showing 5 changed files with 152 additions and 1 deletion.
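
Before the per-file diffs, here is a minimal sketch of the behavior this commit adds. It is illustrative only: the bucket path and the DataChain calls are taken from the test script added below, the dataset name is hypothetical, and the rollback itself is performed by the session machinery in src/datachain/query/session.py.

# Sketch, not code from this commit: any dataset version saved while the
# script runs is tagged with the session's job_id; if the script later
# raises, those versions are rolled back.
from datachain.lib.dc import DataChain

ds = DataChain.from_storage("gs://dvcx-datalakes/dogs-and-cats/")
ds.save("cats_snapshot")  # hypothetical name; records a version under the current job

raise RuntimeError("boom")
# The uncaught exception reaches Session.except_hook (or Session.__exit__),
# which looks up every version created with this job_id and removes it,
# so "cats_snapshot" is not left behind in the catalog.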
8 changes: 8 additions & 0 deletions src/datachain/catalog/catalog.py
@@ -988,6 +988,14 @@ def create_new_dataset_version(
schema = {
c.name: c.type.to_dict() for c in columns if isinstance(c.type, SQLType)
}

job_id = job_id or os.getenv("DATACHAIN_JOB_ID")
if not job_id:
from datachain.query.session import Session

session = Session.get(catalog=self)
job_id = session.job_id

dataset = self.metastore.create_dataset_version(
dataset,
version,
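The lookup above resolves the job id with a simple precedence: an explicit job_id argument, then the DATACHAIN_JOB_ID environment variable, then the id generated by the current Session. A hedged restatement of that order (the helper name and values are illustrative, not part of the diff):

# Sketch of the resolution order used in create_new_dataset_version above.
import os

def resolve_job_id(job_id=None, session_job_id="uuid-generated-by-session"):
    # 1. an explicit job_id argument wins
    # 2. otherwise DATACHAIN_JOB_ID, if the environment running the script sets it
    # 3. otherwise the uuid the Session created for this run
    return job_id or os.getenv("DATACHAIN_JOB_ID") or session_job_id
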
21 changes: 20 additions & 1 deletion src/datachain/data_storage/metastore.py
@@ -50,7 +50,6 @@
from datachain.data_storage import AbstractIDGenerator, schema
from datachain.data_storage.db_engine import DatabaseEngine


logger = logging.getLogger("datachain")


@@ -384,6 +383,11 @@ def set_job_and_dataset_status(
) -> None:
"""Set the status of the given job and dataset."""

@abstractmethod
def get_job_dataset_versions(self, job_id: str) -> list[tuple[str, int]]:
"""Returns dataset names and versions for the job."""
raise NotImplementedError


class AbstractDBMetastore(AbstractMetastore):
"""
@@ -1519,3 +1523,18 @@ def set_job_and_dataset_status(
.values(status=dataset_status)
)
self.db.execute(query, conn=conn) # type: ignore[attr-defined]

def get_job_dataset_versions(self, job_id: str) -> list[tuple[str, int]]:
"""Returns dataset names and versions for the job."""
dv = self._datasets_versions
ds = self._datasets

join_condition = dv.c.dataset_id == ds.c.id

query = (
self._datasets_versions_select(ds.c.name, dv.c.version)
.select_from(dv.join(ds, join_condition))
.where(dv.c.job_id == job_id)
)

return list(self.db.execute(query))
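
The new helper returns plain (dataset name, version) tuples, which is the shape the session-level cleanup in the next file iterates over. A small usage sketch (the metastore instance and the returned values are made up for illustration):

# Illustrative: the shape of the data returned by the new metastore method.
versions = metastore.get_job_dataset_versions(job_id)
# e.g. [("cats_snapshot", 1), ("cats_snapshot", 2), ("dogs", 3)]
for dataset_name, version in versions:
    print(f"created by this job: {dataset_name}@v{version}")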
42 changes: 42 additions & 0 deletions src/datachain/query/session.py
@@ -1,5 +1,8 @@
import atexit
import logging
import os
import re
import sys
from typing import TYPE_CHECKING, Optional
from uuid import uuid4

@@ -9,6 +12,8 @@
if TYPE_CHECKING:
from datachain.catalog import Catalog

logger = logging.getLogger("datachain")


class Session:
"""
@@ -35,6 +40,7 @@ class Session:

GLOBAL_SESSION_CTX: Optional["Session"] = None
GLOBAL_SESSION: Optional["Session"] = None
ORIGINAL_EXCEPT_HOOK = None

DATASET_PREFIX = "session_"
GLOBAL_SESSION_NAME = "global"
@@ -58,6 +64,7 @@ def __init__(

session_uuid = uuid4().hex[: self.SESSION_UUID_LEN]
self.name = f"{name}_{session_uuid}"
self.job_id = os.getenv("DATACHAIN_JOB_ID") or str(uuid4())
self.is_new_catalog = not catalog
self.catalog = catalog or get_catalog(
client_config=client_config, in_memory=in_memory
@@ -67,6 +74,9 @@ def __enter__(self):
return self

def __exit__(self, exc_type, exc_val, exc_tb):
if exc_type:
self._cleanup_created_versions(self.name)

self._cleanup_temp_datasets()
if self.is_new_catalog:
self.catalog.metastore.close_on_exit()
@@ -88,6 +98,21 @@ def _cleanup_temp_datasets(self) -> None:
except TableMissingError:
pass

def _cleanup_created_versions(self, job_id: str) -> None:
versions = self.catalog.metastore.get_job_dataset_versions(job_id)
if not versions:
return

datasets = {}
for dataset_name, version in versions:
if dataset_name not in datasets:
datasets[dataset_name] = self.catalog.get_dataset(dataset_name)
dataset = datasets[dataset_name]
logger.info(
"Removing dataset version %s@%s due to exception", dataset_name, version
)
self.catalog.remove_dataset_version(dataset, version)

@classmethod
def get(
cls,
@@ -114,9 +139,23 @@ def get(
in_memory=in_memory,
)
cls.GLOBAL_SESSION = cls.GLOBAL_SESSION_CTX.__enter__()

atexit.register(cls._global_cleanup)
cls.ORIGINAL_EXCEPT_HOOK = sys.excepthook
sys.excepthook = cls.except_hook

return cls.GLOBAL_SESSION

@staticmethod
def except_hook(exc_type, exc_value, exc_traceback):
Session._global_cleanup()
if Session.GLOBAL_SESSION_CTX is not None:
job_id = Session.GLOBAL_SESSION_CTX.job_id
Session.GLOBAL_SESSION_CTX._cleanup_created_versions(job_id)

if Session.ORIGINAL_EXCEPT_HOOK:
Session.ORIGINAL_EXCEPT_HOOK(exc_type, exc_value, exc_traceback)

@classmethod
def cleanup_for_tests(cls):
if cls.GLOBAL_SESSION_CTX is not None:
@@ -125,6 +164,9 @@ def cleanup_for_tests(cls):
cls.GLOBAL_SESSION_CTX = None
atexit.unregister(cls._global_cleanup)

if cls.ORIGINAL_EXCEPT_HOOK:
sys.excepthook = cls.ORIGINAL_EXCEPT_HOOK

@staticmethod
def _global_cleanup():
if Session.GLOBAL_SESSION_CTX is not None:
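The session installs a custom sys.excepthook so that an uncaught exception in a plain script, not just code inside a with Session(...) block, still triggers the rollback, and then delegates to the original hook so the traceback is printed as usual. A standalone sketch of that chaining pattern (not DataChain code):

# Minimal illustration of excepthook chaining.
import sys

_original_hook = sys.excepthook

def _cleanup_then_report(exc_type, exc_value, exc_tb):
    # run the rollback/cleanup first (placeholder action here)
    print("removing dataset versions created by this job...")
    # then fall back to the previous hook so the traceback is still shown
    _original_hook(exc_type, exc_value, exc_tb)

sys.excepthook = _cleanup_then_report

raise ValueError("uncaught on purpose")  # reaching top level invokes the hook
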
24 changes: 24 additions & 0 deletions tests/scripts/feature_class_exception.py
@@ -0,0 +1,24 @@
# Set logger to INFO level
import logging

from pydantic import BaseModel

from datachain.lib.dc import C, DataChain

logging.basicConfig(level=logging.INFO)


class Embedding(BaseModel):
value: float


ds_name = "feature_class_error"
ds = (
DataChain.from_storage("gs://dvcx-datalakes/dogs-and-cats/")
.filter(C("file.path").glob("*cat*.jpg"))
.limit(5)
.map(emd=lambda file: Embedding(value=512), output=Embedding)
)
ds.select("file.path", "emd.value").show(limit=5, flatten=True)
ds.save(ds_name)
raise Exception("This is a test exception")
58 changes: 58 additions & 0 deletions tests/test_atomicity.py
@@ -0,0 +1,58 @@
import os
import subprocess
import sys

import pytest
import sqlalchemy as sa

from datachain.sql.types import Float32

tests_dir = os.path.dirname(os.path.abspath(__file__))

python_exc = sys.executable or "python3"

E2E_STEP_TIMEOUT_SEC = 90


@pytest.mark.e2e
@pytest.mark.xdist_group(name="tmpfile")
def test_atomicity_feature_file(tmp_dir, catalog_tmpfile):
command = (
python_exc,
os.path.join(tests_dir, "scripts", "feature_class_exception.py"),
)
if sys.platform == "win32":
# Windows has a different mechanism of creating a process group.
popen_args = {"creationflags": subprocess.CREATE_NEW_PROCESS_GROUP}
else:
popen_args = {"start_new_session": True}

existing_dataset = catalog_tmpfile.create_dataset(
"existing_dataset",
query_script="script",
columns=[sa.Column("similarity", Float32)],
create_rows=True,
)

process = subprocess.Popen( # noqa: S603
command,
shell=False,
encoding="utf-8",
env={
**os.environ,
"DATACHAIN__ID_GENERATOR": catalog_tmpfile.id_generator.serialize(),
"DATACHAIN__METASTORE": catalog_tmpfile.metastore.serialize(),
"DATACHAIN__WAREHOUSE": catalog_tmpfile.warehouse.serialize(),
},
**popen_args,
)

process.communicate(timeout=E2E_STEP_TIMEOUT_SEC)

assert process.returncode == 1

# No new datasets should be created in the catalog, but the pre-existing one should not be removed.
dataset_versions = list(catalog_tmpfile.list_datasets_versions())
assert len(dataset_versions) == 1
assert dataset_versions[0][0].name == existing_dataset.name
