|
| 1 | +from __future__ import annotations |
| 2 | + |
| 3 | +__all__ = ("create_sandbox", "download_sandbox") |
| 4 | + |
| 5 | +import hashlib |
| 6 | +import logging |
| 7 | +import os |
| 8 | +import tarfile |
| 9 | +import tempfile |
| 10 | +from pathlib import Path |
| 11 | + |
| 12 | +import httpx |
| 13 | + |
| 14 | +from diracx.client.aio import DiracClient |
| 15 | +from diracx.client.models import SandboxInfo |
| 16 | + |
| 17 | +logger = logging.getLogger(__name__) |
| 18 | + |
| 19 | +SANDBOX_CHECKSUM_ALGORITHM = "sha256" |
| 20 | +SANDBOX_COMPRESSION = "bz2" |
| 21 | + |
| 22 | + |
async def create_sandbox(client: DiracClient, paths: list[Path]) -> str:
    """Create a sandbox archive from the given paths and upload it to the storage backend.

    Any paths that are directories are added recursively, keeping only the
    final path component as the archive member name.

    :param client: Authenticated DiracX client used to negotiate the upload.
    :param paths: Files and/or directories to include in the sandbox.
    :return: The PFN of the sandbox in the storage backend; it can be used
        to submit jobs.
    :raises httpx.HTTPStatusError: If the upload to the storage backend fails.
    """
    with tempfile.TemporaryFile(mode="w+b") as tar_fh:
        # Build a tar.bz2 archive in an anonymous temporary file.
        with tarfile.open(fileobj=tar_fh, mode=f"w|{SANDBOX_COMPRESSION}") as tf:
            for path in paths:
                logger.debug("Adding %s to sandbox as %s", path.resolve(), path.name)
                tf.add(path.resolve(), path.name, recursive=True)
        tar_fh.seek(0)

        # Stream the archive through the hasher in 512 KiB chunks so large
        # sandboxes are never held in memory at once.
        hasher = hashlib.new(SANDBOX_CHECKSUM_ALGORITHM)
        while data := tar_fh.read(512 * 1024):
            hasher.update(data)
        checksum = hasher.hexdigest()
        tar_fh.seek(0)
        logger.debug("Sandbox checksum is %s", checksum)

        sandbox_info = SandboxInfo(
            checksum_algorithm=SANDBOX_CHECKSUM_ALGORITHM,
            checksum=checksum,
            size=os.stat(tar_fh.fileno()).st_size,
            format=f"tar.{SANDBOX_COMPRESSION}",
        )

        res = await client.jobs.initiate_sandbox_upload(sandbox_info)
        # The server omits the URL when a sandbox with this checksum is
        # already present, in which case there is nothing to upload.
        if res.url:
            logger.debug("Uploading sandbox for %s", res.pfn)
            files = {"file": ("file", tar_fh)}
            # Use an async client: the previous synchronous httpx.post()
            # blocked the event loop for the whole upload.
            async with httpx.AsyncClient() as http_client:
                response = await http_client.post(res.url, data=res.fields, files=files)
            # TODO: Handle this error better
            response.raise_for_status()
            logger.debug(
                "Sandbox uploaded for %s with status code %s",
                res.pfn,
                response.status_code,
            )
        else:
            logger.debug("%s already exists in storage backend", res.pfn)
        return res.pfn
| 66 | + |
| 67 | + |
async def download_sandbox(client: DiracClient, pfn: str, destination: Path):
    """Download a sandbox from the storage backend and extract it into *destination*.

    :param client: Authenticated DiracX client used to resolve the download URL.
    :param pfn: PFN of the sandbox, as returned by :func:`create_sandbox`.
    :param destination: Directory into which the archive contents are extracted.
    :raises httpx.HTTPStatusError: If the download from the storage backend fails.
    """
    res = await client.jobs.get_sandbox_file(pfn)
    logger.debug("Downloading sandbox for %s", pfn)
    with tempfile.TemporaryFile(mode="w+b") as fh:
        async with httpx.AsyncClient() as http_client:
            response = await http_client.get(res.url)
            # TODO: Handle this error better
            response.raise_for_status()
            async for chunk in response.aiter_bytes():
                fh.write(chunk)
        logger.debug("Sandbox downloaded for %s", pfn)

        # Rewind before reading: after the writes above the file position is
        # at EOF, and tarfile reads from the current position, so opening
        # without seeking back would raise tarfile.ReadError.
        fh.seek(0)
        with tarfile.open(fileobj=fh) as tf:
            # filter="data" rejects absolute paths, path traversal and
            # special members (CVE-2007-4559 mitigation).
            tf.extractall(path=destination, filter="data")
        logger.debug("Extracted %s to %s", pfn, destination)
0 commit comments