diff --git a/flytekit/types/file/file.py b/flytekit/types/file/file.py index eb0aa5544d..ef76ee1642 100644 --- a/flytekit/types/file/file.py +++ b/flytekit/types/file/file.py @@ -489,7 +489,9 @@ async def async_to_literal( meta = BlobMetadata(type=self._blob_type(format=FlyteFilePathTransformer.get_format(python_type))) if isinstance(python_val, FlyteFile): - source_path = python_val.path + # Cast the source path to str type to avoid error raised when the source path is used as the blob uri, + # please refer to this issue: https://github.com/flyteorg/flyte/issues/5872. + source_path = str(python_val.path) self.validate_file_type(python_type, source_path) # If the object has a remote source, then we just convert it back. This means that if someone is just diff --git a/tests/flytekit/unit/types/directory/test_dir.py b/tests/flytekit/unit/types/directory/test_dir.py new file mode 100644 index 0000000000..285162d1e9 --- /dev/null +++ b/tests/flytekit/unit/types/directory/test_dir.py @@ -0,0 +1,52 @@ +from pathlib import Path +from typing import Optional + +import flytekit +from flytekit import task, workflow +from flytekit.types.directory import FlyteDirectory + + +def test_src_path_with_different_types() -> None: + N_FILES = 3 + + @task + def write_fidx_task( + use_str_src_path: bool, remote_dir: Optional[str] = None + ) -> FlyteDirectory: + """Write file indices to text files in a source path.""" + source_path = Path(flytekit.current_context().working_directory) / "txt_files" + source_path.mkdir(exist_ok=True) + + for file_idx in range(N_FILES): + file_path = source_path / f"{file_idx}.txt" + with file_path.open(mode="w") as f: + f.write(str(file_idx)) + + if use_str_src_path: + source_path = str(source_path) + fd = FlyteDirectory(path=source_path, remote_directory=remote_dir) + + return fd + + @workflow + def wf(use_str_src_path: bool, remote_dir: Optional[str] = None) -> FlyteDirectory: + return write_fidx_task(use_str_src_path=use_str_src_path, remote_dir=remote_dir) + + def _verify_files(fd: FlyteDirectory) -> None: + for file_idx in range(N_FILES): + with open(fd / f"{file_idx}.txt", "r") as f: + assert f.read() == str(file_idx) + + # Source path is of type str + ff_1 = wf(use_str_src_path=True, remote_dir=None) + _verify_files(ff_1) + + ff_2 = wf(use_str_src_path=True, remote_dir="./my_txt_files") + _verify_files(ff_2) + + # Source path is of type pathlib.PosixPath + ff_3 = wf(use_str_src_path=False, remote_dir=None) + _verify_files(ff_3) + + ff_4 = wf(use_str_src_path=False, remote_dir="./my_txt_files2") + _verify_files(ff_4) diff --git a/tests/flytekit/unit/types/file/test_file.py b/tests/flytekit/unit/types/file/test_file.py new file mode 100644 index 0000000000..5187aa061a --- /dev/null +++ b/tests/flytekit/unit/types/file/test_file.py @@ -0,0 +1,72 @@ +import tempfile +from pathlib import Path +from typing import Optional + +import pytest +from flytekit import task, workflow +from flytekit.types.file import FlyteFile + + +@pytest.fixture +def local_tmp_txt_files(): + # Create a source file + with tempfile.NamedTemporaryFile(delete=False, mode="w+", suffix=".txt") as src_file: + src_file.write("Hello World!") + src_file.flush() + src_path = src_file.name + + # Create an empty file as the destination + with tempfile.NamedTemporaryFile(delete=False, suffix=".txt") as dst_file: + dst_path = dst_file.name + + yield src_path, dst_path + + # Cleanup + Path(src_path).unlink(missing_ok=True) + Path(dst_path).unlink(missing_ok=True) + + +def test_src_path_with_different_types(local_tmp_txt_files) -> None: + + @task + def create_flytefile( + source_path: str, + use_pathlike_src_path: bool, + remote_path: Optional[str] = None + ) -> FlyteFile: + if use_pathlike_src_path: + source_path = Path(source_path) + ff = FlyteFile(path=source_path, remote_path=remote_path) + + return ff + + @workflow + def wf( + source_path: str, + use_pathlike_src_path: bool, + remote_path: Optional[str] = None + ) -> FlyteFile: + return create_flytefile( + source_path=source_path, use_pathlike_src_path=use_pathlike_src_path, remote_path=remote_path + ) + + def _verify_msg(ff: FlyteFile) -> None: + with open(ff, "r") as f: + assert f.read() == "Hello World!" + + + source_path, remote_path = local_tmp_txt_files + + # Source path is of type str + ff_1 = wf(source_path=source_path, use_pathlike_src_path=False, remote_path=None) + _verify_msg(ff_1) + + ff_2 = wf(source_path=source_path, use_pathlike_src_path=False, remote_path=remote_path) + _verify_msg(ff_2) + + # Source path is of type pathlib.PosixPath + ff_3 = wf(source_path=source_path, use_pathlike_src_path=True, remote_path=None) + _verify_msg(ff_3) + + ff_4 = wf(source_path=source_path, use_pathlike_src_path=True, remote_path=remote_path) + _verify_msg(ff_4)