Skip to content

Commit

Permalink
Place input files in /opt/ml/input/data/{pk}-input and link to them…
Browse files Browse the repository at this point in the history
  • Loading branch information
jmsmkn authored Aug 12, 2024
1 parent fb08004 commit 837d45f
Show file tree
Hide file tree
Showing 4 changed files with 101 additions and 5 deletions.
54 changes: 49 additions & 5 deletions sagemaker_shim/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ def _get_group_id(id_or_name: str) -> int | None:

def clean_path(path: Path) -> None:
for f in path.glob("*"):
if f.is_file():
if f.is_symlink() or f.is_file():
f.chmod(0o700)
f.unlink()
elif f.is_dir():
Expand Down Expand Up @@ -493,6 +493,19 @@ def input_path(self) -> Path:
logger.debug(f"{input_path=}")
return input_path

@property
def linked_input_path(self) -> Path:
"""Local path where the input files will be placed and linked to"""
linked_input_parent = Path(
os.environ.get(
"GRAND_CHALLENGE_COMPONENT_LINKED_INPUT_PARENT",
"/opt/ml/input/data/",
)
)
linked_input_path = linked_input_parent / f"{self.pk}-input"
logger.debug(f"{linked_input_path=}")
return linked_input_path

@property
def output_path(self) -> Path:
"""Local path where the subprocess is expected to write its files"""
Expand Down Expand Up @@ -590,7 +603,7 @@ async def _invoke(self) -> InferenceResult:
logger.info(f"Invoking {self.pk=}")

try:
self.clean_io()
self.reset_io()

try:
self.download_input()
Expand All @@ -609,12 +622,43 @@ async def _invoke(self) -> InferenceResult:
pk=self.pk, return_code=return_code, outputs=outputs
)
finally:
self.clean_io()
self.reset_io()

def clean_io(self) -> None:
"""Clean all contents of input and output folders"""
def reset_io(self) -> None:
"""Resets the input and output directories"""
clean_path(path=self.input_path)
clean_path(path=self.output_path)
self.reset_linked_input()

def reset_linked_input(self) -> None:
"""Resets the symlink from the input to the linked directory"""
if (
os.environ.get(
"GRAND_CHALLENGE_COMPONENT_USE_LINKED_INPUT", "True"
).lower()
== "true"
):
logger.info(
f"Setting up linked input from {self.input_path} "
f"to {self.linked_input_path}"
)

if self.input_path.exists():
if self.input_path.is_symlink():
self.input_path.unlink()
elif self.input_path.is_dir():
self.input_path.rmdir()

if self.linked_input_path.exists():
self.linked_input_path.rmdir()

self.linked_input_path.mkdir(parents=True)
self.linked_input_path.chmod(0o755)

self.input_path.symlink_to(
self.linked_input_path, target_is_directory=True
)
self.input_path.chmod(0o755)

def download_input(self) -> None:
"""Download all the inputs to the input path"""
Expand Down
4 changes: 4 additions & 0 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ def test_inference_from_task_list(
encode_b64j(val=cmd),
)
monkeypatch.setenv("GRAND_CHALLENGE_COMPONENT_SET_EXTRA_GROUPS", "False")
monkeypatch.setenv("GRAND_CHALLENGE_COMPONENT_USE_LINKED_INPUT", "False")

runner = CliRunner()
runner.invoke(cli, ["invoke", "-t", json.dumps(tasks)])
Expand Down Expand Up @@ -144,6 +145,7 @@ def test_inference_from_s3_uri(minio, monkeypatch, cmd, expected_return_code):
encode_b64j(val=cmd),
)
monkeypatch.setenv("GRAND_CHALLENGE_COMPONENT_SET_EXTRA_GROUPS", "False")
monkeypatch.setenv("GRAND_CHALLENGE_COMPONENT_USE_LINKED_INPUT", "False")

definition_key = f"{uuid4()}/invocations.json"

Expand Down Expand Up @@ -186,6 +188,7 @@ def test_logging_setup(minio, monkeypatch):
encode_b64j(val=["echo", "hello"]),
)
monkeypatch.setenv("GRAND_CHALLENGE_COMPONENT_SET_EXTRA_GROUPS", "False")
monkeypatch.setenv("GRAND_CHALLENGE_COMPONENT_USE_LINKED_INPUT", "False")

runner = CliRunner()
result = runner.invoke(cli, ["invoke", "-t", json.dumps(tasks)])
Expand Down Expand Up @@ -216,6 +219,7 @@ def test_logging_stderr_setup(minio, monkeypatch):
),
)
monkeypatch.setenv("GRAND_CHALLENGE_COMPONENT_SET_EXTRA_GROUPS", "False")
monkeypatch.setenv("GRAND_CHALLENGE_COMPONENT_USE_LINKED_INPUT", "False")

runner = CliRunner()
result = runner.invoke(cli, ["invoke", "-t", json.dumps(tasks)])
Expand Down
1 change: 1 addition & 0 deletions tests/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,7 @@ async def test_inference_result_upload(
encode_b64j(val=cmd),
)
monkeypatch.setenv("GRAND_CHALLENGE_COMPONENT_SET_EXTRA_GROUPS", "False")
monkeypatch.setenv("GRAND_CHALLENGE_COMPONENT_USE_LINKED_INPUT", "False")

direct_invocation = await task.invoke()

Expand Down
47 changes: 47 additions & 0 deletions tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,3 +407,50 @@ def test_ensure_directories_are_writable(tmp_path, monkeypatch):
assert model.stat().st_mode == 0o40777
assert checkpoints.stat().st_mode == 0o40777
assert tmp.stat().st_mode == 0o40777


def test_linked_input_path_default():
t = InferenceTask(
pk="test", inputs=[], output_bucket_name="test", output_prefix="test"
)

assert t.linked_input_path == Path("/opt/ml/input/data/test-input")


def test_linked_input_path_setting(monkeypatch):
monkeypatch.setenv("GRAND_CHALLENGE_COMPONENT_LINKED_INPUT_PARENT", "/foo")

t = InferenceTask(
pk="test", inputs=[], output_bucket_name="test", output_prefix="test"
)

assert t.linked_input_path == Path("/foo/test-input")


def test_reset_linked_input(tmp_path, monkeypatch):
input_path = tmp_path / "input"
linked_input_parent = tmp_path / "linked-input"

monkeypatch.setenv(
"GRAND_CHALLENGE_COMPONENT_INPUT_PATH", input_path.absolute()
)
monkeypatch.setenv(
"GRAND_CHALLENGE_COMPONENT_LINKED_INPUT_PARENT", linked_input_parent
)

t = InferenceTask(
pk="test", inputs=[], output_bucket_name="test", output_prefix="test"
)
t.reset_io()

expected_input_directory = linked_input_parent / "test-input"

assert input_path.exists()
assert input_path.is_symlink()
assert expected_input_directory.exists()
assert expected_input_directory.is_dir()
assert input_path.resolve(strict=True) == expected_input_directory

# Ensure 0o755 permissions
assert os.stat(input_path).st_mode & 0o777 == 0o755
assert os.stat(expected_input_directory).st_mode & 0o777 == 0o755

0 comments on commit 837d45f

Please sign in to comment.