Skip to content

Commit

Permalink
[CVAT][Exchange Oracle] Fix track completed tasks (#2645)
Browse files Browse the repository at this point in the history
* Fix complete_tasks_with_completed_jobs return type

* Prevent swallowing exceptions in state trackers while testing

* Fix track completed escrows tests
  • Loading branch information
Bobronium authored Oct 16, 2024
1 parent 4b5a2e1 commit 259115d
Show file tree
Hide file tree
Showing 6 changed files with 61 additions and 13 deletions.
1 change: 1 addition & 0 deletions packages/examples/cvat/exchange-oracle/src/core/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,7 @@ def validate(cls) -> None:


class Config:
debug = to_bool(os.environ.get("DEBUG", "false"))
port = int(os.environ.get("PORT", 8000))
environment = os.environ.get("ENVIRONMENT", "development")
workers_amount = int(os.environ.get("WORKERS_AMOUNT", 1))
Expand Down
3 changes: 3 additions & 0 deletions packages/examples/cvat/exchange-oracle/src/crons/_cron_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from sqlalchemy.orm import Session

from src import Config
from src.db import SessionLocal
from src.log import get_logger_name

Expand Down Expand Up @@ -65,6 +66,8 @@ def wrapper():
return fn(logger, session)
except Exception:
logger.exception(f"Exception while running {cron_spec.repr} cron")
if Config.debug:
raise
finally:
logger.debug(f"Cron {cron_spec.repr} finished")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,15 @@ def track_completed_tasks(logger: logging.Logger, session: Session) -> None:
updated_tasks = cvat_service.complete_tasks_with_completed_jobs(session)

if updated_tasks:
session.commit()
cvat_service.touch(
session,
cvat_models.Task,
[t[0] for t in updated_tasks],
[t.id for t in updated_tasks],
)
session.commit()
logger.info(
f"Found new completed tasks: {format_sequence([t.cvat_id for t in updated_tasks])}"
)

logger.info(f"Found new completed projects: {format_sequence(updated_tasks)}")


@cron_job
Expand Down
11 changes: 8 additions & 3 deletions packages/examples/cvat/exchange-oracle/src/services/cvat.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from collections.abc import Iterable, Sequence
from datetime import datetime
from itertools import islice
from typing import Any
from typing import Any, NamedTuple

from sqlalchemy import case, delete, func, literal, select, update
from sqlalchemy.dialects.postgresql import insert
Expand Down Expand Up @@ -584,7 +584,12 @@ def finish_data_uploads(session: Session, uploads: list[DataUpload]) -> None:
session.execute(statement)


def complete_tasks_with_completed_jobs(session: Session) -> list[tuple[str, int]]:
class TaskResult(NamedTuple):
    """Identifier pair for one task returned by ``complete_tasks_with_completed_jobs``.

    Being a ``NamedTuple``, instances still unpack and index like a plain
    ``(id, cvat_id)`` tuple while also exposing the values by attribute name.
    """

    # Value of the ``id`` column of the result row.
    id: str
    # Value of the ``cvat_id`` column of the same row.
    cvat_id: int


def complete_tasks_with_completed_jobs(session: Session) -> list[TaskResult]:
incomplete_jobs_exist = (
select(1)
.where(Job.cvat_task_id == Task.cvat_id, Job.status != JobStatuses.completed)
Expand All @@ -604,7 +609,7 @@ def complete_tasks_with_completed_jobs(session: Session) -> list[tuple[str, int]
)

result = session.execute(stmt)
return [row.cvat_id for row in result.all()]
return [TaskResult(row.id, row.cvat_id) for row in result.all()]


# Job
Expand Down
3 changes: 3 additions & 0 deletions packages/examples/cvat/exchange-oracle/tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import os
from collections.abc import Generator

os.environ["DEBUG"] = "1"

import pytest
from fastapi.testclient import TestClient
from sqlalchemy.orm import Session
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from unittest.mock import Mock, patch

import datumaro as dm
import pytest
from sqlalchemy import select

from src.core.types import (
Expand All @@ -32,6 +33,9 @@
from tests.utils.db_helper import create_project_task_and_job


class TestException(RuntimeError):
    """Sentinel error injected through mocks so tests can assert on it with ``pytest.raises``."""

    # The "Test" prefix matches pytest's default class-collection pattern; opt out
    # explicitly so pytest does not try (and warn about failing) to collect it.
    __test__ = False


class ServiceIntegrationTest(unittest.TestCase):
def setUp(self):
self.session = SessionLocal()
Expand Down Expand Up @@ -237,7 +241,7 @@ def test_retrieve_annotations_error_getting_annotations(self):
patch("src.handlers.completed_escrows.get_escrow_manifest") as mock_get_manifest,
patch("src.handlers.completed_escrows.validate_escrow"),
patch(
"src.handlers.completed_escrows.cvat_api.get_job_annotations"
"src.handlers.completed_escrows.cvat_api.request_job_annotations"
) as mock_annotations,
patch("src.handlers.completed_escrows.cloud_service") as mock_cloud_service,
):
Expand All @@ -249,9 +253,9 @@ def test_retrieve_annotations_error_getting_annotations(self):
mock_storage_client.create_file = mock_create_file
mock_cloud_service.make_client = Mock(return_value=mock_storage_client)

mock_annotations.side_effect = Exception("Connection error")

track_escrow_validations()
mock_annotations.side_effect = TestException()
with pytest.raises(TestException):
track_escrow_validations()

webhook = (
self.session.query(Webhook)
Expand Down Expand Up @@ -321,6 +325,15 @@ def test_retrieve_annotations_error_uploading_files(self):
cvat_job_id=cvat_job.cvat_id,
expires_at=datetime.now() + timedelta(days=1),
)
project_images = ["sample1.jpg", "sample2.png"]

for image_filename in project_images:
self.session.add(
Image(
id=str(uuid.uuid4()), cvat_project_id=cvat_project_id, filename=image_filename
)
)

self.session.add(assignment)
self.session.commit()

Expand All @@ -339,12 +352,34 @@ def test_retrieve_annotations_error_uploading_files(self):
open("tests/utils/manifest.json") as data,
patch("src.handlers.completed_escrows.get_escrow_manifest") as mock_get_manifest,
patch("src.handlers.completed_escrows.validate_escrow"),
patch("src.handlers.completed_escrows.cvat_api"),
patch("src.handlers.completed_escrows.cvat_api") as mock_cvat_api,
patch("src.handlers.completed_escrows.cloud_service") as mock_cloud_service,
patch("src.services.cloud.make_client"),
):
manifest = json.load(data)
mock_get_manifest.return_value = manifest
dummy_zip_file = io.BytesIO()
with zipfile.ZipFile(dummy_zip_file, "w") as archive, TemporaryDirectory() as tempdir:
mock_dataset = dm.Dataset(
media_type=dm.Image,
categories={
dm.AnnotationType.label: dm.LabelCategories.from_iterable(["cat", "dog"])
},
)
for image_filename in project_images:
mock_dataset.put(dm.DatasetItem(id=os.path.splitext(image_filename)[0]))
mock_dataset.export(tempdir, format="coco_instances")

track_escrow_validations()
for filename in list(glob(os.path.join(tempdir, "**/*"), recursive=True)):
archive.write(filename, os.path.relpath(filename, tempdir))
dummy_zip_file.seek(0)

mock_cvat_api.get_job_annotations.return_value = dummy_zip_file
mock_cvat_api.get_project_annotations.return_value = dummy_zip_file
mock_cloud_service.make_client.return_value.create_file.side_effect = TestException()

with pytest.raises(TestException):
track_escrow_validations()

webhook = (
self.session.query(Webhook)
Expand Down

0 comments on commit 259115d

Please sign in to comment.