-
-
Notifications
You must be signed in to change notification settings - Fork 83
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #494 from c-bata/artifact-prefixer
Add AppendPrefix middleware for ArtifactBackend
- Loading branch information
Showing
3 changed files
with
97 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
from __future__ import annotations | ||
|
||
import logging | ||
from typing import TYPE_CHECKING | ||
|
||
|
||
_logger = logging.getLogger(__name__) | ||
|
||
|
||
if TYPE_CHECKING: | ||
from typing import BinaryIO | ||
|
||
from optuna_dashboard.artifact.protocol import ArtifactBackend | ||
|
||
|
||
class AppendPrefix: | ||
"""An artifact backend middleware that appends a prefix string to artifact ids. | ||
Example: | ||
.. code-block:: python | ||
import optuna | ||
from optuna_dashboard.artifact import upload_artifact | ||
from optuna_dashboard.artifact.boto3 import Boto3Backend | ||
from optuna_dashboard.artifact.prefix import AppendPrefix | ||
artifact_backend = AppendPrefix( | ||
Boto3Backend("my-bucket"), | ||
prefix="my-folder/" | ||
) | ||
def objective(trial: optuna.Trial) -> float: | ||
... = trial.suggest_float("x", -10, 10) | ||
file_path = generate_example_png(...) | ||
upload_artifact(artifact_backend, trial, file_path) | ||
return ... | ||
""" | ||
|
||
def __init__( | ||
self, | ||
backend: ArtifactBackend, | ||
prefix: str, | ||
) -> None: | ||
self._backend = backend | ||
self._prefix = prefix | ||
|
||
def _with_prefix(self, artifact_id: str) -> str: | ||
return self._prefix + artifact_id | ||
|
||
def open(self, artifact_id: str) -> BinaryIO: | ||
return self._backend.open(self._with_prefix(artifact_id)) | ||
|
||
def write(self, artifact_id: str, content_body: BinaryIO) -> None: | ||
return self._backend.write(self._with_prefix(artifact_id), content_body) | ||
|
||
def remove(self, artifact_id: str) -> None: | ||
return self._backend.remove(self._with_prefix(artifact_id)) | ||
|
||
|
||
if TYPE_CHECKING: | ||
# A mypy-runtime assertion to ensure that SCSBackend | ||
# implements all abstract methods in ArtifactBackendProtocol. | ||
from optuna_dashboard.artifact.file_system import FileSystemBackend | ||
|
||
_: ArtifactBackend = AppendPrefix(FileSystemBackend("."), "prefix-") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
import io | ||
import uuid | ||
|
||
from optuna_dashboard.artifact.prefix import AppendPrefix | ||
|
||
from .stubs import InMemoryBackend | ||
|
||
|
||
def test_read_and_write() -> None: | ||
artifact_id = str(uuid.uuid4()) | ||
dummy_content = b"Hello World" | ||
|
||
backend = AppendPrefix(backend=InMemoryBackend(), prefix="my-") | ||
backend.write(artifact_id, io.BytesIO(dummy_content)) | ||
with backend.open(artifact_id) as f: | ||
actual = f.read() | ||
assert actual == dummy_content | ||
backend.remove(artifact_id) | ||
|
||
|
||
def test_check_prefix() -> None: | ||
artifact_id = str(uuid.uuid4()) | ||
dummy_content = b"Hello World" | ||
|
||
in_memory = InMemoryBackend() | ||
backend = AppendPrefix(backend=in_memory, prefix="my-") | ||
backend.write(artifact_id, io.BytesIO(dummy_content)) | ||
|
||
assert len(in_memory._data) == 1 | ||
key = list(in_memory._data.keys())[0] | ||
assert key.startswith("my-") |