Skip to content

Commit acf5c3f

Browse files
committed
Add API for uploading/downloading sandboxes
1 parent 903ced6 commit acf5c3f

File tree

3 files changed

+89
-1
lines changed

3 files changed

+89
-1
lines changed

src/diracx/api/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
from __future__ import annotations
2+
3+
__all__ = ("jobs",)
4+
5+
from . import jobs

src/diracx/api/jobs.py

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
from __future__ import annotations
2+
3+
__all__ = ("create_sandbox", "download_sandbox")
4+
5+
import hashlib
6+
import logging
7+
import os
8+
import tarfile
9+
import tempfile
10+
from pathlib import Path
11+
12+
import httpx
13+
14+
from diracx.client.aio import DiracClient
15+
from diracx.client.models import SandboxInfo
16+
17+
logger = logging.getLogger(__name__)
18+
19+
SANDBOX_CHECKSUM_ALGORITHM = "sha256"
20+
SANDBOX_COMPRESSION = "bz2"
21+
22+
23+
async def create_sandbox(client: DiracClient, paths: list[Path]) -> str:
24+
"""Create a sandbox from the given paths and upload it to the storage backend.
25+
26+
Any paths that are directories will be added recursively.
27+
The returned value is the PFN of the sandbox in the storage backend and can
28+
be used to submit jobs.
29+
"""
30+
with tempfile.TemporaryFile(mode="w+b") as tar_fh:
31+
with tarfile.open(fileobj=tar_fh, mode=f"w|{SANDBOX_COMPRESSION}") as tf:
32+
for path in paths:
33+
logger.debug("Adding %s to sandbox as %s", path.resolve(), path.name)
34+
tf.add(path.resolve(), path.name, recursive=True)
35+
tar_fh.seek(0)
36+
37+
hasher = getattr(hashlib, SANDBOX_CHECKSUM_ALGORITHM)()
38+
while data := tar_fh.read(512 * 1024):
39+
hasher.update(data)
40+
checksum = hasher.hexdigest()
41+
tar_fh.seek(0)
42+
logger.debug("Sandbox checksum is %s", checksum)
43+
44+
sandbox_info = SandboxInfo(
45+
checksum_algorithm=SANDBOX_CHECKSUM_ALGORITHM,
46+
checksum=checksum,
47+
size=os.stat(tar_fh.fileno()).st_size,
48+
format=f"tar.{SANDBOX_COMPRESSION}",
49+
)
50+
51+
res = await client.jobs.initiate_sandbox_upload(sandbox_info)
52+
if res.url:
53+
logger.debug("Uploading sandbox for %s", res.pfn)
54+
files = {"file": ("file", tar_fh)}
55+
response = httpx.post(res.url, data=res.fields, files=files)
56+
# TODO: Handle this error better
57+
response.raise_for_status()
58+
logger.debug(
59+
"Sandbox uploaded for %s with status code %s",
60+
res.pfn,
61+
response.status_code,
62+
)
63+
else:
64+
logger.debug("%s already exists in storage backend", res.pfn)
65+
return res.pfn
66+
67+
68+
async def download_sandbox(client: DiracClient, pfn: str, destination: Path):
69+
"""Download a sandbox from the storage backend to the given destination."""
70+
res = await client.jobs.get_sandbox_file(pfn)
71+
logger.debug("Downloading sandbox for %s", pfn)
72+
with tempfile.TemporaryFile(mode="w+b") as fh:
73+
async with httpx.AsyncClient() as http_client:
74+
response = await http_client.get(res.url)
75+
# TODO: Handle this error better
76+
response.raise_for_status()
77+
async for chunk in response.aiter_bytes():
78+
fh.write(chunk)
79+
logger.debug("Sandbox downloaded for %s", pfn)
80+
81+
with tarfile.open(fileobj=fh) as tf:
82+
tf.extractall(path=destination, filter="data")
83+
logger.debug("Extracted %s to %s", pfn, destination)

src/diracx/core/models.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ class SandboxChecksum(StrEnum):
119119

120120

121121
class SandboxFormat(StrEnum):
122-
TAR_GZ = "tar.gz"
122+
TAR_BZ2 = "tar.bz2"
123123

124124

125125
class SandboxInfo(BaseModel):

0 commit comments

Comments
 (0)