Use one instantiation (#2850) (patch)
### Fixed

- The archive handlers are instantiated once per batch instead of once per file.
islean authored Jan 18, 2024
1 parent 214bdee commit b013069
Showing 4 changed files with 46 additions and 38 deletions.
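At its core the fix is the usual hoist-the-invariant-out-of-the-loop refactor: one `ArchiveHandler` is now built per batch (and per status-update pass) rather than per file. Below is a minimal, self-contained sketch of the before/after pattern; `make_handler` and the callable it returns are hypothetical stand-ins for `ARCHIVE_HANDLERS[archive_location](data_flow_config)` and `archive_handler.archive_file(...)`, not the repository's API.

```python
from typing import Callable


def make_handler(location: str, config: dict) -> Callable[[str], int]:
    """Hypothetical stand-in for ARCHIVE_HANDLERS[archive_location](data_flow_config)."""
    print(f"Instantiating handler for {location}")  # the costly step the commit avoids repeating
    return lambda file: hash((location, file)) % 1000  # pretend archival job id


def archive_batch_before(files: list[str], location: str, config: dict) -> list[int]:
    # Before: a new handler was built for every single file.
    return [make_handler(location, config)(file) for file in files]


def archive_batch_after(files: list[str], location: str, config: dict) -> list[int]:
    # After: one handler per batch, reused for each file.
    handler = make_handler(location, config)
    return [handler(file) for file in files]


if __name__ == "__main__":
    spring_files = ["sample_1.spring", "sample_2.spring", "sample_3.spring"]
    archive_batch_before(spring_files, "karolinska_bucket", {})  # prints three times
    archive_batch_after(spring_files, "karolinska_bucket", {})  # prints once
```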
56 changes: 29 additions & 27 deletions cg/meta/archive/archive.py
```diff
@@ -41,9 +41,8 @@ def __init__(
         self.data_flow_config: DataFlowConfig = data_flow_config
 
     def archive_file_to_location(
-        self, file_and_sample: FileAndSample, archive_location: ArchiveLocations
+        self, file_and_sample: FileAndSample, archive_handler: ArchiveHandler
     ) -> int:
-        archive_handler: ArchiveHandler = ARCHIVE_HANDLERS[archive_location](self.data_flow_config)
         return archive_handler.archive_file(file_and_sample=file_and_sample)
 
     def archive_spring_files_and_add_archives_to_housekeeper(
@@ -67,21 +66,17 @@ def archive_files_to_location(self, archive_location: str, file_limit: int | Non
         )
         if files_to_archive:
             files_and_samples_for_location = self.add_samples_to_files(files_to_archive)
+            archive_handler: ArchiveHandler = ARCHIVE_HANDLERS[archive_location](
+                self.data_flow_config
+            )
             for file_and_sample in files_and_samples_for_location:
-                self.archive_file(
-                    file_and_sample=file_and_sample, archive_location=archive_location
-                )
+                self.archive_file(file_and_sample=file_and_sample, archive_handler=archive_handler)
 
         else:
             LOG.info(f"No files to archive for location {archive_location}.")
 
-    def archive_file(
-        self, file_and_sample: FileAndSample, archive_location: ArchiveLocations
-    ) -> None:
-        job_id: int = self.archive_file_to_location(
-            file_and_sample=file_and_sample, archive_location=archive_location
-        )
-        LOG.info(f"File submitted to {archive_location} with archival task id {job_id}.")
+    def archive_file(self, file_and_sample: FileAndSample, archive_handler: ArchiveHandler) -> None:
+        job_id: int = archive_handler.archive_file(file_and_sample)
         self.housekeeper_api.add_archives(
             files=[file_and_sample.file],
             archive_task_id=job_id,
@@ -174,45 +169,52 @@ def update_ongoing_archivals(self) -> None:
             ArchiveLocations, list[int]
         ] = self.sort_archival_ids_on_archive_location(ongoing_archivals)
         for archive_location in ArchiveLocations:
-            self.update_archival_jobs_for_archive_location(
-                archive_location=archive_location,
-                job_ids=archival_ids_per_location.get(archive_location),
-            )
+            if archival_ids := archival_ids_per_location.get(archive_location):
+                archive_handler: ArchiveHandler = ARCHIVE_HANDLERS[archive_location](
+                    self.data_flow_config
+                )
+                self.update_archival_jobs_for_archive_location(
+                    archive_handler=archive_handler,
+                    job_ids=archival_ids,
+                )
 
     def update_ongoing_retrievals(self) -> None:
         ongoing_retrievals: list[Archive] = self.housekeeper_api.get_ongoing_retrievals()
         retrieval_ids_per_location: dict[
             ArchiveLocations, list[int]
         ] = self.sort_retrieval_ids_on_archive_location(ongoing_retrievals)
         for archive_location in ArchiveLocations:
-            self.update_retrieval_jobs_for_archive_location(
-                archive_location=archive_location,
-                job_ids=retrieval_ids_per_location.get(archive_location),
-            )
+            if retrieval_ids := retrieval_ids_per_location.get(archive_location):
+                archive_handler: ArchiveHandler = ARCHIVE_HANDLERS[archive_location](
+                    self.data_flow_config
+                )
+                self.update_retrieval_jobs_for_archive_location(
+                    archive_handler=archive_handler,
+                    job_ids=retrieval_ids,
+                )
 
     def update_archival_jobs_for_archive_location(
-        self, archive_location: ArchiveLocations, job_ids: list[int]
+        self, archive_handler: ArchiveHandler, job_ids: list[int]
     ) -> None:
         for job_id in job_ids:
             self.update_ongoing_task(
-                task_id=job_id, archive_location=archive_location, is_archival=True
+                task_id=job_id, archive_handler=archive_handler, is_archival=True
             )
 
     def update_retrieval_jobs_for_archive_location(
-        self, archive_location: ArchiveLocations, job_ids: list[int]
+        self, archive_handler: ArchiveHandler, job_ids: list[int]
     ) -> None:
         for job_id in job_ids:
             self.update_ongoing_task(
-                task_id=job_id, archive_location=archive_location, is_archival=False
+                task_id=job_id, archive_handler=archive_handler, is_archival=False
             )
 
     def update_ongoing_task(
-        self, task_id: int, archive_location: ArchiveLocations, is_archival: bool
+        self, task_id: int, archive_handler: ArchiveHandler, is_archival: bool
     ) -> None:
         """Fetches info on an ongoing job and updates the Archive entry in Housekeeper."""
-        archive_handler: ArchiveHandler = ARCHIVE_HANDLERS[archive_location](self.data_flow_config)
         try:
-            LOG.info(f"Fetching status for job with id {task_id} from {archive_location}")
+            LOG.info(f"Fetching status for job with id {task_id}")
             is_job_done: bool = archive_handler.is_job_done(task_id)
             if is_job_done:
                 LOG.info(f"Job with id {task_id} has finished, updating Archive entries.")
```
4 changes: 3 additions & 1 deletion cg/meta/archive/ddn/ddn_data_flow_client.py
```diff
@@ -111,10 +111,12 @@ def archive_file(self, file_and_sample: FileAndSample) -> int:
         archival_request: TransferPayload = self.create_transfer_request(
             miria_file_data=miria_file_data, is_archiving_request=True, metadata=metadata
         )
-        return archival_request.post_request(
+        job_id: int = archival_request.post_request(
             headers=dict(self.headers, **self.auth_header),
             url=urljoin(base=self.url, url=DataflowEndpoints.ARCHIVE_FILES),
         ).job_id
+        LOG.info(f"File submitted to Miria with archival task id {job_id}.")
+        return job_id
 
     def retrieve_files(self, files_and_samples: list[FileAndSample]) -> int:
         """Retrieves the provided files and stores them in the corresponding sample bundle in
```
10 changes: 5 additions & 5 deletions cg/meta/archive/ddn/models.py
```diff
@@ -14,8 +14,8 @@
 LOG = logging.getLogger(__name__)
 
 
-def get_request_log(headers: dict, body: dict):
-    return "Sending request with headers: \n" + f"{headers} \n" + "and body: \n" + f"{body}"
+def get_request_log(body: dict):
+    return "Sending request with body: \n" + f"{body}"
 
 
 class MiriaObject(FileTransferData):
@@ -91,7 +91,7 @@ def post_request(self, url: str, headers: dict) -> "TransferResponse":
             The job ID of the launched transfer task.
         """
 
-        LOG.info(get_request_log(headers=headers, body=self.model_dump()))
+        LOG.info(get_request_log(self.model_dump()))
 
         response: Response = APIRequest.api_request_from_content(
             api_method=APIMethods.POST,
@@ -153,7 +153,7 @@ def get_job_status(self, url: str, headers: dict) -> GetJobStatusResponse:
             HTTPError if the response code is not ok.
         """
 
-        LOG.info(get_request_log(headers=headers, body=self.model_dump()))
+        LOG.info(get_request_log(self.model_dump()))
 
         response: Response = APIRequest.api_request_from_content(
             api_method=APIMethods.GET,
@@ -179,7 +179,7 @@ def delete_file(self, url: str, headers: dict) -> DeleteFileResponse:
         Raises:
             HTTPError if the response code is not ok.
         """
-        LOG.info(get_request_log(headers=headers, body=self.model_dump()))
+        LOG.info(get_request_log(self.model_dump()))
 
         response: Response = APIRequest.api_request_from_content(
             api_method=APIMethods.POST,
```
14 changes: 9 additions & 5 deletions tests/meta/archive/test_archive_api.py
```diff
@@ -128,7 +128,9 @@ def test_convert_into_transfer_data(
     assert isinstance(transferdata[0], MiriaObject)
 
 
-def test_call_corresponding_archiving_method(spring_archive_api: SpringArchiveAPI, sample_id: str):
+def test_call_corresponding_archiving_method(
+    spring_archive_api: SpringArchiveAPI, sample_id: str, ddn_dataflow_client: DDNDataFlowClient
+):
     """Tests so that the correct archiving function is used when providing a Karolinska customer."""
     # GIVEN a file to be transferred
     # GIVEN a spring_archive_api with a mocked archive function
@@ -148,7 +150,7 @@ def test_call_corresponding_archiving_method(spring_archive_api: SpringArchiveAP
     ) as mock_request_submitter:
         # WHEN calling the corresponding archive method
         spring_archive_api.archive_file_to_location(
-            file_and_sample=file_and_sample, archive_location=ArchiveLocations.KAROLINSKA_BUCKET
+            file_and_sample=file_and_sample, archive_handler=ddn_dataflow_client
         )
 
     # THEN the correct archive function should have been called once
@@ -199,7 +201,7 @@ def test_archive_all_non_archived_spring_files(
 
     # THEN all spring files for Karolinska should have an entry in the Archive table in Housekeeper while no other
     # files should have an entry
-    files: list[File] = spring_archive_api.housekeeper_api.files()
+    files: list[File] = spring_archive_api.housekeeper_api.files().all()
     for file in files:
         if SequencingFileTag.SPRING in [tag.name for tag in file.tags]:
             sample: Sample = spring_archive_api.status_db.get_sample_by_internal_id(
@@ -221,6 +223,7 @@
 )
 def test_get_archival_status(
     spring_archive_api: SpringArchiveAPI,
+    ddn_dataflow_client: DDNDataFlowClient,
     caplog,
     ok_miria_job_status_response,
     archive_request_json,
@@ -250,7 +253,7 @@
     ):
         spring_archive_api.update_ongoing_task(
             task_id=archival_job_id,
-            archive_location=ArchiveLocations.KAROLINSKA_BUCKET,
+            archive_handler=ddn_dataflow_client,
             is_archival=True,
         )
@@ -271,6 +274,7 @@
 )
 def test_get_retrieval_status(
     spring_archive_api: SpringArchiveAPI,
+    ddn_dataflow_client: DDNDataFlowClient,
     caplog,
     ok_miria_job_status_response,
     archive_request_json,
@@ -307,7 +311,7 @@
     ):
         spring_archive_api.update_ongoing_task(
             task_id=retrieval_job_id,
-            archive_location=ArchiveLocations.KAROLINSKA_BUCKET,
+            archive_handler=ddn_dataflow_client,
             is_archival=False,
         )
```
