From 5bf0acd22623f02293a75d41952dd4ec26e025ae Mon Sep 17 00:00:00 2001 From: Chun-Mao Lai <72752478+Mecoli1219@users.noreply.github.com> Date: Wed, 30 Oct 2024 12:51:40 -0700 Subject: [PATCH] Restrict Python Version Mismatch between Pickled Object and Remote Environment (#2848) * Restrict version mismatch Signed-off-by: Mecoli1219 * Update unit test Signed-off-by: Mecoli1219 * revert formatter change Signed-off-by: Mecoli1219 * revert formatter change - test_remote Signed-off-by: Mecoli1219 * Create dataclass definition for pickled object Signed-off-by: Mecoli1219 * Update error message Signed-off-by: Mecoli1219 --------- Signed-off-by: Mecoli1219 --- flytekit/core/python_auto_container.py | 62 +++++++++++++++++++++-- flytekit/remote/executions.py | 5 +- flytekit/remote/remote.py | 19 +++++-- tests/flytekit/unit/core/test_resolver.py | 28 +++++++++- tests/flytekit/unit/remote/test_remote.py | 25 ++++++--- 5 files changed, 120 insertions(+), 19 deletions(-) diff --git a/flytekit/core/python_auto_container.py b/flytekit/core/python_auto_container.py index 1466c351ac..f0163cdbf1 100644 --- a/flytekit/core/python_auto_container.py +++ b/flytekit/core/python_auto_container.py @@ -3,6 +3,7 @@ import importlib import re from abc import ABC +from dataclasses import dataclass from typing import Callable, Dict, List, Optional, TypeVar, Union from flyteidl.core import tasks_pb2 @@ -282,6 +283,32 @@ def get_all_tasks(self) -> List[PythonAutoContainerTask]: # type: ignore default_task_resolver = DefaultTaskResolver() +@dataclass +class PickledEntityMetadata: + """ + Metadata for a pickled entity containing version information. + + Attributes: + python_version: The Python version string (e.g. "3.12.0") used to create the pickle + """ + + python_version: str + + +@dataclass +class PickledEntity: + """ + Represents the structure of the pickled object stored in the .pkl file for interactive mode.
+ + Attributes: + metadata: Metadata about the pickled entities including Python version + entities: Dictionary mapping entity names to their PythonAutoContainerTask instances + """ + + metadata: PickledEntityMetadata + entities: Dict[str, PythonAutoContainerTask] + + class DefaultNotebookTaskResolver(TrackedInstance, TaskResolverMixin): """ This resolved is used when the task is defined in a notebook. It is used to load the task from the notebook. @@ -294,12 +321,41 @@ def name(self) -> str: def load_task(self, loader_args: List[str]) -> PythonAutoContainerTask: _, entity_name, *_ = loader_args import gzip + import sys import cloudpickle - with gzip.open(PICKLE_FILE_PATH, "r") as f: - entity_dict = cloudpickle.load(f) - return entity_dict[entity_name] + try: + with gzip.open(PICKLE_FILE_PATH, "r") as f: + loaded_data = cloudpickle.load(f) + except TypeError: + raise RuntimeError( + "The Python version is different from the version used to create the pickle file. " + f"Current Python version: {sys.version_info.major}.{sys.version_info.minor}. " + "Please try using the same Python version to create the pickle file or use another " + "container image with a matching version." + ) + + # verify the loaded_data is of the correct type + if not isinstance(loaded_data, PickledEntity): + raise RuntimeError( + "The loaded data is not of the correct type. Please ensure that the pickle file is not corrupted." + ) + pickled_object: PickledEntity = loaded_data + + pickled_version = pickled_object.metadata.python_version.split(".") + if sys.version_info.major != int(pickled_version[0]) or sys.version_info.minor != int(pickled_version[1]): + raise RuntimeError( + "The Python version used to create the pickle file is different from the current Python version. " + f"Current Python version: {sys.version_info.major}.{sys.version_info.minor}. " + f"Python version used to create the pickle file: {pickled_object.metadata.python_version}. 
" + "Please try using the same Python version to create the pickle file or use another " + "container image with a matching version." + ) + + if entity_name not in pickled_object.entities: + raise ValueError(f"Entity {entity_name} not found in the pickled object") + return pickled_object.entities[entity_name] def loader_args(self, settings: SerializationSettings, task: PythonAutoContainerTask) -> List[str]: # type:ignore n, _, _, _ = extract_task_module(task) diff --git a/flytekit/remote/executions.py b/flytekit/remote/executions.py index 530582975a..5095504784 100644 --- a/flytekit/remote/executions.py +++ b/flytekit/remote/executions.py @@ -44,7 +44,10 @@ def outputs(self) -> Optional[LiteralsResolver]: "Please wait until the execution has completed before requesting the outputs." ) if self.error: - raise user_exceptions.FlyteAssertion("Outputs could not be found because the execution ended in failure.") + raise user_exceptions.FlyteAssertion( + "Outputs could not be found because the execution ended in failure. Error message: " + f"{self.error.message}" + ) return self._outputs diff --git a/flytekit/remote/remote.py b/flytekit/remote/remote.py index 33771eb578..7eda76ddfa 100644 --- a/flytekit/remote/remote.py +++ b/flytekit/remote/remote.py @@ -48,6 +48,8 @@ from flytekit.core.node import Node as CoreNode from flytekit.core.python_auto_container import ( PICKLE_FILE_PATH, + PickledEntity, + PickledEntityMetadata, PythonAutoContainerTask, default_notebook_task_resolver, ) @@ -202,14 +204,21 @@ def _get_git_repo_url(source_path: str): def _get_pickled_target_dict( root_entity: typing.Union[WorkflowBase, PythonTask], -) -> typing.Tuple[bytes, typing.Dict[str, PythonAutoContainerTask]]: +) -> typing.Tuple[bytes, PickledEntity]: """ Get the pickled target dictionary for the entity. :param root_entity: The entity to get the pickled target for. :return: hashed bytes and the pickled target dictionary. 
""" + import sys + queue: typing.List[typing.Union[WorkflowBase, PythonTask, CoreNode]] = [root_entity] - pickled_target_dict = {} + pickled_target_dict = PickledEntity( + metadata=PickledEntityMetadata( + python_version=f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}" + ), + entities={}, + ) while queue: entity = queue.pop() if isinstance(entity, PythonFunctionTask): @@ -226,10 +235,10 @@ def _get_pickled_target_dict( if isinstance(entity, (PythonAutoContainerTask, ArrayNodeMapTask)): if isinstance(entity, ArrayNodeMapTask): entity._run_task.set_resolver(default_notebook_task_resolver) - pickled_target_dict[entity._run_task.name] = entity._run_task + pickled_target_dict.entities[entity._run_task.name] = entity._run_task else: entity.set_resolver(default_notebook_task_resolver) - pickled_target_dict[entity.name] = entity + pickled_target_dict.entities[entity.name] = entity elif isinstance(entity, WorkflowBase): for task in entity.nodes: queue.append(task) @@ -2655,7 +2664,7 @@ def download( def _pickle_and_upload_entity( self, entity: typing.Union[PythonTask, WorkflowBase], - pickled_dict: typing.Optional[typing.Dict[str, PythonAutoContainerTask]] = None, + pickled_dict: typing.Optional[PickledEntity] = None, ) -> FastSerializationSettings: """ Pickle the entity to the specified location. 
This is useful for debugging and for sharing entities across diff --git a/tests/flytekit/unit/core/test_resolver.py b/tests/flytekit/unit/core/test_resolver.py index 116b1251ae..af04db2de7 100644 --- a/tests/flytekit/unit/core/test_resolver.py +++ b/tests/flytekit/unit/core/test_resolver.py @@ -4,12 +4,13 @@ import cloudpickle import mock import pytest +import sys import flytekit.configuration from flytekit.configuration import Image, ImageConfig from flytekit.core.base_task import TaskResolverMixin from flytekit.core.class_based_resolver import ClassStorageTaskResolver -from flytekit.core.python_auto_container import default_task_resolver, default_notebook_task_resolver, PICKLE_FILE_PATH +from flytekit.core.python_auto_container import default_task_resolver, default_notebook_task_resolver, PickledEntity, PickledEntityMetadata from flytekit.core.task import task from flytekit.core.workflow import workflow from flytekit.tools.translator import get_serializable @@ -123,10 +124,33 @@ def t1(a: str, b: str) -> str: assert c.loader_args(None, t1) == ["entity-name", "tests.flytekit.unit.core.test_resolver.t1"] - pickled_dict = {"tests.flytekit.unit.core.test_resolver.t1": t1} + pickled_dict = PickledEntity( + metadata=PickledEntityMetadata( + python_version=f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}" + ), + entities={ + "tests.flytekit.unit.core.test_resolver.t1": t1, + }, + ) + custom_pickled_object = cloudpickle.dumps(pickled_dict) mock_gzip_open.return_value.read.return_value = custom_pickled_object mock_cloudpickle.return_value = pickled_dict t = c.load_task(["entity-name", "tests.flytekit.unit.core.test_resolver.t1"]) assert t == t1 + + mismatched_pickled_dict = PickledEntity( + metadata=PickledEntityMetadata( + python_version=f"{sys.version_info.major}.{sys.version_info.minor - 1}.{sys.version_info.micro}" + ), + entities={ + "tests.flytekit.unit.core.test_resolver.t1": t1, + }, + ) + mismatched_custom_pickled_object = 
cloudpickle.dumps(mismatched_pickled_dict) + mock_gzip_open.return_value.read.return_value = mismatched_custom_pickled_object + mock_cloudpickle.return_value = mismatched_pickled_dict + + with pytest.raises(RuntimeError): + c.load_task(["entity-name", "tests.flytekit.unit.core.test_resolver.t1"]) diff --git a/tests/flytekit/unit/remote/test_remote.py b/tests/flytekit/unit/remote/test_remote.py index abb49f3317..bfc925aff8 100644 --- a/tests/flytekit/unit/remote/test_remote.py +++ b/tests/flytekit/unit/remote/test_remote.py @@ -2,6 +2,7 @@ import pathlib import shutil import subprocess +import sys import tempfile import typing import uuid @@ -708,11 +709,15 @@ def w() -> int: return t2(a=t1()) _, target_dict = _get_pickled_target_dict(w) - assert len(target_dict) == 2 - assert t1.name in target_dict - assert t2.name in target_dict - assert target_dict[t1.name] == t1 - assert target_dict[t2.name] == t2 + assert ( + target_dict.metadata.python_version + == f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}" + ) + assert len(target_dict.entities) == 2 + assert t1.name in target_dict.entities + assert t2.name in target_dict.entities + assert target_dict.entities[t1.name] == t1 + assert target_dict.entities[t2.name] == t2 def test_get_pickled_target_dict_with_map_task(): @task @@ -724,9 +729,13 @@ def w() -> int: return map_task(partial(t1, y=2))(x=[1, 2, 3]) _, target_dict = _get_pickled_target_dict(w) - assert len(target_dict) == 1 - assert t1.name in target_dict - assert target_dict[t1.name] == t1 + assert ( + target_dict.metadata.python_version + == f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}" + ) + assert len(target_dict.entities) == 1 + assert t1.name in target_dict.entities + assert target_dict.entities[t1.name] == t1 def test_get_pickled_target_dict_with_dynamic(): @task