diff --git a/packages/examples/cvat/exchange-oracle/src/core/config.py b/packages/examples/cvat/exchange-oracle/src/core/config.py index 12bc15e704..20f5443761 100644 --- a/packages/examples/cvat/exchange-oracle/src/core/config.py +++ b/packages/examples/cvat/exchange-oracle/src/core/config.py @@ -283,6 +283,7 @@ def validate(cls) -> None: class Config: + debug = to_bool(os.environ.get("DEBUG", "false")) port = int(os.environ.get("PORT", 8000)) environment = os.environ.get("ENVIRONMENT", "development") workers_amount = int(os.environ.get("WORKERS_AMOUNT", 1)) diff --git a/packages/examples/cvat/exchange-oracle/src/crons/_cron_job.py b/packages/examples/cvat/exchange-oracle/src/crons/_cron_job.py index 951580fcdc..2ed2b83b2c 100644 --- a/packages/examples/cvat/exchange-oracle/src/crons/_cron_job.py +++ b/packages/examples/cvat/exchange-oracle/src/crons/_cron_job.py @@ -6,6 +6,7 @@ from sqlalchemy.orm import Session +from src import Config from src.db import SessionLocal from src.log import get_logger_name @@ -65,6 +66,8 @@ def wrapper(): return fn(logger, session) except Exception: logger.exception(f"Exception while running {cron_spec.repr} cron") + if Config.debug: + raise finally: logger.debug(f"Cron {cron_spec.repr} finished") diff --git a/packages/examples/cvat/exchange-oracle/src/crons/cvat/state_trackers.py b/packages/examples/cvat/exchange-oracle/src/crons/cvat/state_trackers.py index 6d235ac23a..a01b718861 100644 --- a/packages/examples/cvat/exchange-oracle/src/crons/cvat/state_trackers.py +++ b/packages/examples/cvat/exchange-oracle/src/crons/cvat/state_trackers.py @@ -32,14 +32,15 @@ def track_completed_tasks(logger: logging.Logger, session: Session) -> None: updated_tasks = cvat_service.complete_tasks_with_completed_jobs(session) if updated_tasks: - session.commit() cvat_service.touch( session, cvat_models.Task, - [t[0] for t in updated_tasks], + [t.id for t in updated_tasks], + ) + session.commit() + logger.info( + f"Found new completed tasks: {format_sequence([t.cvat_id for t in updated_tasks])}" ) - - logger.info(f"Found new completed projects: {format_sequence(updated_tasks)}") @cron_job diff --git a/packages/examples/cvat/exchange-oracle/src/services/cvat.py b/packages/examples/cvat/exchange-oracle/src/services/cvat.py index 834d0ddfd0..bca185286a 100644 --- a/packages/examples/cvat/exchange-oracle/src/services/cvat.py +++ b/packages/examples/cvat/exchange-oracle/src/services/cvat.py @@ -7,7 +7,7 @@ from collections.abc import Iterable, Sequence from datetime import datetime from itertools import islice -from typing import Any +from typing import Any, NamedTuple from sqlalchemy import case, delete, func, literal, select, update from sqlalchemy.dialects.postgresql import insert @@ -584,7 +584,12 @@ def finish_data_uploads(session: Session, uploads: list[DataUpload]) -> None: session.execute(statement) -def complete_tasks_with_completed_jobs(session: Session) -> list[tuple[str, int]]: +class TaskResult(NamedTuple): + id: str + cvat_id: int + + +def complete_tasks_with_completed_jobs(session: Session) -> list[TaskResult]: incomplete_jobs_exist = ( select(1) .where(Job.cvat_task_id == Task.cvat_id, Job.status != JobStatuses.completed) @@ -604,7 +609,7 @@ def complete_tasks_with_completed_jobs(session: Session) -> list[tuple[str, int] ) result = session.execute(stmt) - return [row.cvat_id for row in result.all()] + return [TaskResult(row.id, row.cvat_id) for row in result.all()] # Job diff --git a/packages/examples/cvat/exchange-oracle/tests/conftest.py b/packages/examples/cvat/exchange-oracle/tests/conftest.py index 306d4f3368..5e0a51bbac 100644 --- a/packages/examples/cvat/exchange-oracle/tests/conftest.py +++ b/packages/examples/cvat/exchange-oracle/tests/conftest.py @@ -1,5 +1,8 @@ +import os from collections.abc import Generator +os.environ["DEBUG"] = "1" + import pytest from fastapi.testclient import TestClient from sqlalchemy.orm import Session diff --git a/packages/examples/cvat/exchange-oracle/tests/integration/cron/state_trackers/test_track_completed_escrows.py b/packages/examples/cvat/exchange-oracle/tests/integration/cron/state_trackers/test_track_completed_escrows.py index eced5fa3a4..7ccbfa9dc0 100644 --- a/packages/examples/cvat/exchange-oracle/tests/integration/cron/state_trackers/test_track_completed_escrows.py +++ b/packages/examples/cvat/exchange-oracle/tests/integration/cron/state_trackers/test_track_completed_escrows.py @@ -10,6 +10,7 @@ from unittest.mock import Mock, patch import datumaro as dm +import pytest from sqlalchemy import select from src.core.types import ( @@ -32,6 +33,9 @@ from tests.utils.db_helper import create_project_task_and_job +class TestException(RuntimeError): ... + + class ServiceIntegrationTest(unittest.TestCase): def setUp(self): self.session = SessionLocal() @@ -237,7 +241,7 @@ def test_retrieve_annotations_error_getting_annotations(self): patch("src.handlers.completed_escrows.get_escrow_manifest") as mock_get_manifest, patch("src.handlers.completed_escrows.validate_escrow"), patch( - "src.handlers.completed_escrows.cvat_api.get_job_annotations" + "src.handlers.completed_escrows.cvat_api.request_job_annotations" ) as mock_annotations, patch("src.handlers.completed_escrows.cloud_service") as mock_cloud_service, ): @@ -249,9 +253,9 @@ def test_retrieve_annotations_error_getting_annotations(self): mock_storage_client.create_file = mock_create_file mock_cloud_service.make_client = Mock(return_value=mock_storage_client) - mock_annotations.side_effect = Exception("Connection error") - - track_escrow_validations() + mock_annotations.side_effect = TestException() + with pytest.raises(TestException): + track_escrow_validations() webhook = ( self.session.query(Webhook) @@ -321,6 +325,15 @@ def test_retrieve_annotations_error_uploading_files(self): cvat_job_id=cvat_job.cvat_id, expires_at=datetime.now() + timedelta(days=1), ) + project_images = ["sample1.jpg", "sample2.png"] + + for image_filename in project_images: + self.session.add( + Image( + id=str(uuid.uuid4()), cvat_project_id=cvat_project_id, filename=image_filename + ) + ) + self.session.add(assignment) self.session.commit() @@ -339,12 +352,34 @@ def test_retrieve_annotations_error_uploading_files(self): open("tests/utils/manifest.json") as data, patch("src.handlers.completed_escrows.get_escrow_manifest") as mock_get_manifest, patch("src.handlers.completed_escrows.validate_escrow"), - patch("src.handlers.completed_escrows.cvat_api"), + patch("src.handlers.completed_escrows.cvat_api") as mock_cvat_api, + patch("src.handlers.completed_escrows.cloud_service") as mock_cloud_service, + patch("src.services.cloud.make_client"), ): manifest = json.load(data) mock_get_manifest.return_value = manifest + dummy_zip_file = io.BytesIO() + with zipfile.ZipFile(dummy_zip_file, "w") as archive, TemporaryDirectory() as tempdir: + mock_dataset = dm.Dataset( + media_type=dm.Image, + categories={ + dm.AnnotationType.label: dm.LabelCategories.from_iterable(["cat", "dog"]) + }, + ) + for image_filename in project_images: + mock_dataset.put(dm.DatasetItem(id=os.path.splitext(image_filename)[0])) + mock_dataset.export(tempdir, format="coco_instances") - track_escrow_validations() + for filename in list(glob(os.path.join(tempdir, "**/*"), recursive=True)): + archive.write(filename, os.path.relpath(filename, tempdir)) + dummy_zip_file.seek(0) + + mock_cvat_api.get_job_annotations.return_value = dummy_zip_file + mock_cvat_api.get_project_annotations.return_value = dummy_zip_file + mock_cloud_service.make_client.return_value.create_file.side_effect = TestException() + + with pytest.raises(TestException): + track_escrow_validations() webhook = ( self.session.query(Webhook)