(Archiving) Modify according to documentation #2692

Merged · 13 commits · Nov 27, 2023
2 changes: 1 addition & 1 deletion cg/constants/archiving.py
@@ -2,6 +2,6 @@


class ArchiveLocations(StrEnum):
"""Demultiplexing related directories and files."""
"""Archive locations for the different customers' Spring files."""

KAROLINSKA_BUCKET: str = "karolinska_bucket"
77 changes: 25 additions & 52 deletions cg/meta/archive/ddn_dataflow.py
@@ -6,7 +6,7 @@
from urllib.parse import urljoin

from housekeeper.store.models import File
from pydantic import BaseModel
from pydantic import BaseModel, Field
from requests.models import Response

from cg.constants.constants import APIMethods
@@ -37,20 +37,24 @@ class DataflowEndpoints(StrEnum):
GET_AUTH_TOKEN = "auth/token"
REFRESH_AUTH_TOKEN = "auth/token/refresh"
RETRIEVE_FILES = "files/retrieve"
GET_JOB_STATUS = "getJobStatus"
GET_JOB_STATUS = "activity/jobs/"


class JobDescription(StrEnum):
class JobStatus(StrEnum):
"""Enum for the different job statuses which can be returned via Miria."""

CANCELED = "Canceled"
COMPLETED = "Completed"
CREATION = "Creation"
DENIED = "Denied"
CREATION_IN_PROGRESS = "Creation in progress"
IN_QUEUE = "In Queue"
INVALID_LICENSE = "Invalid license"
ON_VALIDATION = "On validation"
REFUSED = "Refused"
RUNNING = "Running"
SUSPENDED = "Suspended"
TERMINATED_ON_ERROR = "Terminated on Error"
TERMINATED_ON_ERROR = "Terminated on error"
TERMINATED_ON_WARNING = "Terminated on warning"


class MiriaObject(FileTransferData):
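The job-status lookup moves from a POST to `getJobStatus` to a GET against `activity/jobs/<job_id>`, and the renamed `JobStatus` enum now mirrors the status strings Miria returns. A minimal sketch of how the new endpoint URL is assembled, assuming a hypothetical base URL (in the client it comes from the configured Dataflow root):

```python
from urllib.parse import urljoin

# Hypothetical base URL; the real value is read from the Dataflow config.
base_url = "https://miria.example.com/api/"
job_id = 123

# GET_JOB_STATUS = "activity/jobs/", so the job ID is appended directly.
# Note the trailing slash: urljoin drops the last path segment of the
# base if it does not end in "/".
url = urljoin(base_url, "activity/jobs/" + str(job_id))
assert url == "https://miria.example.com/api/activity/jobs/123"
```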
@@ -100,6 +104,7 @@ class TransferPayload(BaseModel):
files_to_transfer: list[MiriaObject]
osType: str = OSTYPE
createFolder: bool = False
settings: list[dict] = []

def trim_paths(self, attribute_to_trim: str):
"""Trims the source path from its root directory for all objects in the transfer."""
@@ -128,7 +133,7 @@ def post_request(self, url: str, headers: dict) -> "TransferJob":
url: URL to which the POST goes to.
headers: Headers which are set in the request
Raises:
HTTPError if the response status is not okay.
HTTPError if the response status is not successful.
ValidationError if the response does not conform to the expected response structure.
Returns:
The job ID of the launched transfer task.
@@ -171,60 +176,28 @@ class TransferJob(BaseModel):
"""Model representing th response fields of an archive or retrieve reqeust to the Dataflow
API."""

job_id: int


class SubJob(BaseModel):
"""Model representing the response fields in a subjob returned in a get_job_status post."""

subjob_id: int
subjob_type: str
status: int
description: str
progress: float
total_rate: int
throughput: int
estimated_end: datetime
estimated_left: int
job_id: int = Field(alias="jobId")


class GetJobStatusResponse(BaseModel):
"""Model representing the response fields from a get_job_status post."""

request_date: datetime | None = None
operation: str | None = None
job_id: int
type: str | None = None
status: int | None = None
description: str
start_date: datetime | None = None
end_date: datetime | None = None
durationTime: int | None = None
priority: int | None = None
progress: float | None = None
subjobs: list[SubJob] | None = None
job_id: int = Field(alias="id")
status: str


class GetJobStatusPayload(BaseModel):
"""Model representing the payload for a get_job_status request."""

job_id: int
subjob_id: int | None = None
related_jobs: bool | None = None
main_subjob: bool | None = None
debug: bool | None = None
id: int

def post_request(self, url: str, headers: dict) -> GetJobStatusResponse:
"""Sends a request to the given url with the given headers, and its own content as
payload. Returns the job ID of the launched transfer task.
def get_job_status(self, url: str, headers: dict) -> GetJobStatusResponse:
"""Sends a get request to the given URL with the given headers.
Returns the parsed status response of the task specified in the URL.
Raises:
HTTPError if the response code is not ok.
"""
HTTPError if the response code is not successful."""
response: Response = APIRequest.api_request_from_content(
api_method=APIMethods.POST,
url=url,
headers=headers,
json=self.model_dump(),
api_method=APIMethods.GET, url=url, headers=headers, json={}
)
response.raise_for_status()
return GetJobStatusResponse.model_validate(response.json())
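The slimmed-down response models lean on pydantic's `Field(alias=...)` to map Miria's camelCase keys (`jobId`, `id`) onto snake_case attributes. A short sketch of the behaviour this relies on, using the `GetJobStatusResponse` shape from the diff:

```python
from pydantic import BaseModel, Field

class GetJobStatusResponse(BaseModel):
    job_id: int = Field(alias="id")
    status: str

# model_validate reads the aliased key from the raw JSON payload,
# so the API's "id" lands on the snake_case job_id attribute.
parsed = GetJobStatusResponse.model_validate({"id": 123, "status": "Completed"})
assert parsed.job_id == 123 and parsed.status == "Completed"
```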
@@ -355,15 +328,15 @@ def convert_into_transfer_data(
]

def is_job_done(self, job_id: int) -> bool:
get_job_status_payload = GetJobStatusPayload(job_id=job_id)
get_job_status_response: GetJobStatusResponse = get_job_status_payload.post_request(
url=urljoin(self.url, DataflowEndpoints.GET_JOB_STATUS),
get_job_status_payload = GetJobStatusPayload(id=job_id)
get_job_status_response: GetJobStatusResponse = get_job_status_payload.get_job_status(
url=urljoin(self.url, DataflowEndpoints.GET_JOB_STATUS + str(job_id)),
headers=dict(self.headers, **self.auth_header),
)
if get_job_status_response.description == JobDescription.COMPLETED:
if get_job_status_response.status == JobStatus.COMPLETED:
return True
LOG.info(
f"Job with id {job_id} has not been completed. "
f"Current job description is {get_job_status_response.description}"
f"Current job description is {get_job_status_response.status}"
)
return False
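`is_job_done` now matches the textual `status` field against `JobStatus.COMPLETED` instead of the removed `description`. A hypothetical polling loop on top of it (the interval and client handle are illustrative, not part of the PR):

```python
import time

def wait_until_archived(client: "DDNDataFlowClient", job_id: int, poll_seconds: int = 60) -> None:
    # Illustrative only: poll the Miria job until is_job_done reports completion.
    while not client.is_job_done(job_id=job_id):
        time.sleep(poll_seconds)
```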
21 changes: 7 additions & 14 deletions tests/meta/archive/conftest.py
@@ -44,14 +44,14 @@ def ddn_dataflow_config(


@pytest.fixture
def ok_ddn_response(ok_response: Response):
ok_response._content = b'{"job_id": "123"}'
def ok_miria_response(ok_response: Response):
ok_response._content = b'{"jobId": "123"}'
return ok_response


@pytest.fixture
def ok_ddn_job_status_response(ok_response: Response):
ok_response._content = b'{"job_id": "123", "description": "Completed"}'
def ok_miria_job_status_response(ok_response: Response):
ok_response._content = b'{"jobId": "123", "status": "Completed"}'
return ok_response


@@ -69,15 +69,7 @@ def archive_request_json(
}
],
"metadataList": [],
}


@pytest.fixture
def get_job_status_request_json(
remote_storage_repository: str, local_storage_repository: str, trimmed_local_path: str
) -> dict:
return {
"job_id": 123,
"settings": [],
}


@@ -97,6 +89,7 @@ def retrieve_request_json(
}
],
"metadataList": [],
"settings": [],
}


@@ -110,7 +103,7 @@ def header_with_test_auth_token() -> dict:


@pytest.fixture
def ddn_auth_token_response(ok_response: Response):
def miria_auth_token_response(ok_response: Response):
ok_response._content = b'{"access": "test_auth_token", "expire":15, "test_refresh_token":""}'
return ok_response

32 changes: 16 additions & 16 deletions tests/meta/archive/test_archive_api.py
@@ -19,7 +19,7 @@
DDNDataFlowClient,
GetJobStatusPayload,
GetJobStatusResponse,
JobDescription,
JobStatus,
MiriaObject,
)
from cg.meta.archive.models import ArchiveHandler, FileTransferData
@@ -179,7 +179,7 @@ def test_call_corresponding_archiving_method(spring_archive_api: SpringArchiveAP
def test_archive_all_non_archived_spring_files(
spring_archive_api: SpringArchiveAPI,
caplog,
ok_ddn_response,
ok_miria_response,
archive_request_json,
header_with_test_auth_token,
test_auth_token: AuthToken,
@@ -196,7 +196,7 @@ def test_archive_all_non_archived_spring_files(
), mock.patch.object(
APIRequest,
"api_request_from_content",
return_value=ok_ddn_response,
return_value=ok_miria_response,
) as mock_request_submitter:
spring_archive_api.archive_all_non_archived_spring_files()

@@ -222,17 +222,17 @@

@pytest.mark.parametrize(
"job_status, should_date_be_set",
[(JobDescription.COMPLETED, True), (JobDescription.RUNNING, False)],
[(JobStatus.COMPLETED, True), (JobStatus.RUNNING, False)],
)
def test_get_archival_status(
spring_archive_api: SpringArchiveAPI,
caplog,
ok_ddn_job_status_response,
ok_miria_job_status_response,
archive_request_json,
header_with_test_auth_token,
test_auth_token: AuthToken,
archival_job_id: int,
job_status: JobDescription,
job_status: JobStatus,
should_date_be_set: bool,
):
# GIVEN a file with an ongoing archival
@@ -247,11 +247,11 @@
), mock.patch.object(
APIRequest,
"api_request_from_content",
return_value=ok_ddn_job_status_response,
return_value=ok_miria_job_status_response,
), mock.patch.object(
GetJobStatusPayload,
"post_request",
return_value=GetJobStatusResponse(job_id=archival_job_id, description=job_status),
"get_job_status",
return_value=GetJobStatusResponse(job_id=archival_job_id, status=job_status),
):
spring_archive_api.update_ongoing_task(
task_id=archival_job_id,
@@ -265,12 +265,12 @@

@pytest.mark.parametrize(
"job_status, should_date_be_set",
[(JobDescription.COMPLETED, True), (JobDescription.RUNNING, False)],
[(JobStatus.COMPLETED, True), (JobStatus.RUNNING, False)],
)
def test_get_retrieval_status(
spring_archive_api: SpringArchiveAPI,
caplog,
ok_ddn_job_status_response,
ok_miria_job_status_response,
archive_request_json,
header_with_test_auth_token,
retrieval_job_id: int,
@@ -293,11 +293,11 @@
), mock.patch.object(
APIRequest,
"api_request_from_content",
return_value=ok_ddn_job_status_response,
return_value=ok_miria_job_status_response,
), mock.patch.object(
GetJobStatusPayload,
"post_request",
return_value=GetJobStatusResponse(job_id=retrieval_job_id, description=job_status),
"get_job_status",
return_value=GetJobStatusResponse(job_id=retrieval_job_id, status=job_status),
):
spring_archive_api.update_ongoing_task(
task_id=retrieval_job_id,
@@ -312,7 +312,7 @@
def test_retrieve_samples(
spring_archive_api: SpringArchiveAPI,
caplog,
ok_ddn_response,
ok_miria_response,
trimmed_local_path,
local_storage_repository,
retrieve_request_json,
@@ -342,7 +342,7 @@
), mock.patch.object(MiriaObject, "trim_path", return_value=True), mock.patch.object(
APIRequest,
"api_request_from_content",
return_value=ok_ddn_response,
return_value=ok_miria_response,
) as mock_request_submitter:
spring_archive_api.retrieve_samples([sample_with_spring_file])

10 changes: 6 additions & 4 deletions tests/meta/archive/test_archiving.py
@@ -272,7 +272,7 @@ def test_archive_folders(
local_storage_repository: str,
file_and_sample: FileAndSample,
trimmed_local_path: str,
ok_ddn_response: Response,
ok_miria_response,
):
"""Tests that the archiving function correctly formats the input and sends API request."""

@@ -281,7 +281,7 @@
with mock.patch.object(
APIRequest,
"api_request_from_content",
return_value=ok_ddn_response,
return_value=ok_miria_response,
) as mock_request_submitter:
job_id: int = ddn_dataflow_client.archive_files([file_and_sample])

@@ -303,6 +303,7 @@
"osType": OSTYPE,
"createFolder": False,
"metadataList": [],
"settings": [],
},
)

@@ -314,7 +315,7 @@ def test_retrieve_samples(
archive_store: Store,
trimmed_local_path: str,
sample_id: str,
ok_ddn_response: Response,
ok_miria_response,
):
"""Tests that the retrieve function correctly formats the input and sends API request."""

@@ -327,7 +328,7 @@

# WHEN running the retrieve method and providing a SampleAndDestination object
with mock.patch.object(
APIRequest, "api_request_from_content", return_value=ok_ddn_response
APIRequest, "api_request_from_content", return_value=ok_miria_response
) as mock_request_submitter:
job_id: int = ddn_dataflow_client.retrieve_samples([sample_and_destination])

@@ -349,6 +350,7 @@
"osType": OSTYPE,
"createFolder": False,
"metadataList": [],
"settings": [],
},
)
