Skip to content

Commit

Permalink
[BUG] Blob uri isn't converted to str when source path is used…
Browse files Browse the repository at this point in the history
… as `uri` (#2881)

* fix: Convert blob uri to string type

Signed-off-by: JiaWei Jiang <[email protected]>

* fix: Make sure source path is of type str

Signed-off-by: JiaWei Jiang <[email protected]>

* Add test on source path with different types

Signed-off-by: JiaWei Jiang <[email protected]>

* Add test on source path with different types for FlyteDirectory

Signed-off-by: JiaWei Jiang <[email protected]>

* Remove main function

Signed-off-by: JiaWei Jiang <[email protected]>

* update by han-ru

Signed-off-by: Future-Outlier <[email protected]>

* lint

Signed-off-by: Future-Outlier <[email protected]>

* delete-files

Signed-off-by: Future-Outlier <[email protected]>

* Cleanup tmp files and add comments on fix

Signed-off-by: JiaWei Jiang <[email protected]>

---------

Signed-off-by: JiaWei Jiang <[email protected]>
Signed-off-by: Future-Outlier <[email protected]>
Co-authored-by: Future-Outlier <[email protected]>
  • Loading branch information
JiangJiaWei1103 and Future-Outlier authored Nov 4, 2024
1 parent 804dae1 commit 7c08d50
Show file tree
Hide file tree
Showing 3 changed files with 127 additions and 1 deletion.
4 changes: 3 additions & 1 deletion flytekit/types/file/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -489,7 +489,9 @@ async def async_to_literal(
meta = BlobMetadata(type=self._blob_type(format=FlyteFilePathTransformer.get_format(python_type)))

if isinstance(python_val, FlyteFile):
source_path = python_val.path
# Cast the source path to str to avoid the error raised when the source path is used as the blob uri.
# See https://github.com/flyteorg/flyte/issues/5872 for details.
source_path = str(python_val.path)
self.validate_file_type(python_type, source_path)

# If the object has a remote source, then we just convert it back. This means that if someone is just
Expand Down
52 changes: 52 additions & 0 deletions tests/flytekit/unit/types/directory/test_dir.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
from pathlib import Path
from typing import Optional

import flytekit
from flytekit import task, workflow
from flytekit.types.directory import FlyteDirectory


def test_src_path_with_different_types() -> None:
    """Check that FlyteDirectory works with both str and pathlib source paths.

    A task writes ``N_FILES`` small text files into a directory under the
    task's working directory and returns it as a ``FlyteDirectory``. The
    directory is built from either a ``str`` or a ``pathlib.Path`` source
    path, with and without an explicit ``remote_directory``, and the files
    are read back through the returned object each time.

    Explicit upload targets live inside a temporary directory so the test
    leaves nothing behind in the current working directory.
    """
    # Imported locally so this test does not rely on extra module-level imports.
    import tempfile

    N_FILES = 3

    @task
    def write_fidx_task(
        use_str_src_path: bool, remote_dir: Optional[str] = None
    ) -> FlyteDirectory:
        """Write file indices to text files in a source path."""
        txt_dir = Path(flytekit.current_context().working_directory) / "txt_files"
        txt_dir.mkdir(exist_ok=True)

        for file_idx in range(N_FILES):
            (txt_dir / f"{file_idx}.txt").write_text(str(file_idx))

        # Hand FlyteDirectory either a plain string or the Path object itself;
        # both forms must be accepted.
        source_path = str(txt_dir) if use_str_src_path else txt_dir
        return FlyteDirectory(path=source_path, remote_directory=remote_dir)

    @workflow
    def wf(use_str_src_path: bool, remote_dir: Optional[str] = None) -> FlyteDirectory:
        return write_fidx_task(use_str_src_path=use_str_src_path, remote_dir=remote_dir)

    def _verify_files(fd: FlyteDirectory) -> None:
        """Assert that every index file exists under ``fd`` with the expected content."""
        for file_idx in range(N_FILES):
            with open(fd / f"{file_idx}.txt", "r") as f:
                assert f.read() == str(file_idx)

    with tempfile.TemporaryDirectory() as tmp_root:
        # The targets do not exist yet; each upload creates its own directory,
        # and everything is removed together with tmp_root.
        cases = [
            # Source path is of type str
            (True, None),
            (True, str(Path(tmp_root) / "my_txt_files")),
            # Source path is of type pathlib.Path
            (False, None),
            (False, str(Path(tmp_root) / "my_txt_files2")),
        ]
        for use_str_src_path, remote_dir in cases:
            # Verify while tmp_root still exists: the returned directory may
            # point at the upload target.
            result_dir = wf(use_str_src_path=use_str_src_path, remote_dir=remote_dir)
            _verify_files(result_dir)
72 changes: 72 additions & 0 deletions tests/flytekit/unit/types/file/test_file.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
import tempfile
from pathlib import Path
from typing import Optional

import pytest
from flytekit import task, workflow
from flytekit.types.file import FlyteFile


@pytest.fixture
def local_tmp_txt_files():
    """Provide a populated source file and an empty destination file.

    Yields a ``(src_path, dst_path)`` tuple of file-system paths. The source
    file contains ``"Hello World!"``; the destination file is empty. Both
    files are deleted once the test using the fixture has finished.
    """
    # Source file: written and flushed before the handle is closed.
    src_handle = tempfile.NamedTemporaryFile(delete=False, mode="w+", suffix=".txt")
    with src_handle:
        src_handle.write("Hello World!")
        src_handle.flush()
        src_path = src_handle.name

    # Destination file: only needs to exist, so nothing is written to it.
    dst_handle = tempfile.NamedTemporaryFile(delete=False, suffix=".txt")
    with dst_handle:
        dst_path = dst_handle.name

    yield src_path, dst_path

    # Teardown: delete=False above means the files must be removed by hand.
    for leftover in (src_path, dst_path):
        Path(leftover).unlink(missing_ok=True)


def test_src_path_with_different_types(local_tmp_txt_files) -> None:
    """Check that FlyteFile works with both str and pathlib source paths.

    The source file supplied by the fixture is wrapped in a ``FlyteFile``
    built from either a ``str`` or a ``pathlib.Path``, with and without an
    explicit ``remote_path``, and its content is read back each time.
    """
    src, dst = local_tmp_txt_files

    @task
    def create_flytefile(
        source_path: str,
        use_pathlike_src_path: bool,
        remote_path: Optional[str] = None,
    ) -> FlyteFile:
        # Pass the path on as-is, or as a pathlib.Path when requested.
        path = Path(source_path) if use_pathlike_src_path else source_path
        return FlyteFile(path=path, remote_path=remote_path)

    @workflow
    def wf(
        source_path: str,
        use_pathlike_src_path: bool,
        remote_path: Optional[str] = None,
    ) -> FlyteFile:
        return create_flytefile(
            source_path=source_path,
            use_pathlike_src_path=use_pathlike_src_path,
            remote_path=remote_path,
        )

    def _verify_msg(ff: FlyteFile) -> None:
        """Assert that the file behind ``ff`` holds the fixture's message."""
        with open(ff, "r") as fh:
            content = fh.read()
        assert content == "Hello World!"

    # Outer loop: str source path first, then pathlib.Path.
    # Inner loop: no remote path first, then the fixture's destination file.
    for as_pathlike in (False, True):
        for target in (None, dst):
            result = wf(
                source_path=src,
                use_pathlike_src_path=as_pathlike,
                remote_path=target,
            )
            _verify_msg(result)

0 comments on commit 7c08d50

Please sign in to comment.