Skip to content

Commit

Permalink
Restrict Python Version Mismatch between Pickled Object and Remote En…
Browse files Browse the repository at this point in the history
…vrionment (#2848)

* Restrict version mismatch

Signed-off-by: Mecoli1219 <[email protected]>

* Update unit test

Signed-off-by: Mecoli1219 <[email protected]>

* revert formatter change

Signed-off-by: Mecoli1219 <[email protected]>

* revert formatter change - test_remote

Signed-off-by: Mecoli1219 <[email protected]>

* Create dataclass definition for pickled object

Signed-off-by: Mecoli1219 <[email protected]>

* Update error message

Signed-off-by: Mecoli1219 <[email protected]>

---------

Signed-off-by: Mecoli1219 <[email protected]>
  • Loading branch information
Mecoli1219 authored Oct 30, 2024
1 parent bc0c162 commit 5bf0acd
Show file tree
Hide file tree
Showing 5 changed files with 120 additions and 19 deletions.
62 changes: 59 additions & 3 deletions flytekit/core/python_auto_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import importlib
import re
from abc import ABC
from dataclasses import dataclass
from typing import Callable, Dict, List, Optional, TypeVar, Union

from flyteidl.core import tasks_pb2
Expand Down Expand Up @@ -282,6 +283,32 @@ def get_all_tasks(self) -> List[PythonAutoContainerTask]: # type: ignore
default_task_resolver = DefaultTaskResolver()


@dataclass
class PickledEntityMetadata:
"""
Metadata for a pickled entity containing version information.
Attributes:
python_version: The Python version string (e.g. "3.12.0") used to create the pickle
"""

python_version: str


@dataclass
class PickledEntity:
"""
Represents the structure of the pickled object stored in the .pkl file for interactive mode.
Attributes:
metadata: Metadata about the pickled entities including Python version
entities: Dictionary mapping entity names to their PythonAutoContainerTask instances
"""

metadata: PickledEntityMetadata
entities: Dict[str, PythonAutoContainerTask]


class DefaultNotebookTaskResolver(TrackedInstance, TaskResolverMixin):
"""
This resolved is used when the task is defined in a notebook. It is used to load the task from the notebook.
Expand All @@ -294,12 +321,41 @@ def name(self) -> str:
def load_task(self, loader_args: List[str]) -> PythonAutoContainerTask:
_, entity_name, *_ = loader_args
import gzip
import sys

import cloudpickle

with gzip.open(PICKLE_FILE_PATH, "r") as f:
entity_dict = cloudpickle.load(f)
return entity_dict[entity_name]
try:
with gzip.open(PICKLE_FILE_PATH, "r") as f:
loaded_data = cloudpickle.load(f)
except TypeError:
raise RuntimeError(
"The Python version is different from the version used to create the pickle file. "
f"Current Python version: {sys.version_info.major}.{sys.version_info.minor}. "
"Please try using the same Python version to create the pickle file or use another "
"container image with a matching version."
)

# verify the loaded_data is of the correct type
if not isinstance(loaded_data, PickledEntity):
raise RuntimeError(
"The loaded data is not of the correct type. Please ensure that the pickle file is not corrupted."
)
pickled_object: PickledEntity = loaded_data

pickled_version = pickled_object.metadata.python_version.split(".")
if sys.version_info.major != int(pickled_version[0]) or sys.version_info.minor != int(pickled_version[1]):
raise RuntimeError(
"The Python version used to create the pickle file is different from the current Python version. "
f"Current Python version: {sys.version_info.major}.{sys.version_info.minor}. "
f"Python version used to create the pickle file: {pickled_object.metadata.python_version}. "
"Please try using the same Python version to create the pickle file or use another "
"container image with a matching version."
)

if entity_name not in pickled_object.entities:
raise ValueError(f"Entity {entity_name} not found in the pickled object")
return pickled_object.entities[entity_name]

def loader_args(self, settings: SerializationSettings, task: PythonAutoContainerTask) -> List[str]: # type:ignore
n, _, _, _ = extract_task_module(task)
Expand Down
5 changes: 4 additions & 1 deletion flytekit/remote/executions.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,10 @@ def outputs(self) -> Optional[LiteralsResolver]:
"Please wait until the execution has completed before requesting the outputs."
)
if self.error:
raise user_exceptions.FlyteAssertion("Outputs could not be found because the execution ended in failure.")
raise user_exceptions.FlyteAssertion(
"Outputs could not be found because the execution ended in failure. Error message: "
f"{self.error.message}"
)

return self._outputs

Expand Down
19 changes: 14 additions & 5 deletions flytekit/remote/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@
from flytekit.core.node import Node as CoreNode
from flytekit.core.python_auto_container import (
PICKLE_FILE_PATH,
PickledEntity,
PickledEntityMetadata,
PythonAutoContainerTask,
default_notebook_task_resolver,
)
Expand Down Expand Up @@ -202,14 +204,21 @@ def _get_git_repo_url(source_path: str):

def _get_pickled_target_dict(
root_entity: typing.Union[WorkflowBase, PythonTask],
) -> typing.Tuple[bytes, typing.Dict[str, PythonAutoContainerTask]]:
) -> typing.Tuple[bytes, PickledEntity]:
"""
Get the pickled target dictionary for the entity.
:param root_entity: The entity to get the pickled target for.
:return: hashed bytes and the pickled target dictionary.
"""
import sys

queue: typing.List[typing.Union[WorkflowBase, PythonTask, CoreNode]] = [root_entity]
pickled_target_dict = {}
pickled_target_dict = PickledEntity(
metadata=PickledEntityMetadata(
python_version=f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}"
),
entities={},
)
while queue:
entity = queue.pop()
if isinstance(entity, PythonFunctionTask):
Expand All @@ -226,10 +235,10 @@ def _get_pickled_target_dict(
if isinstance(entity, (PythonAutoContainerTask, ArrayNodeMapTask)):
if isinstance(entity, ArrayNodeMapTask):
entity._run_task.set_resolver(default_notebook_task_resolver)
pickled_target_dict[entity._run_task.name] = entity._run_task
pickled_target_dict.entities[entity._run_task.name] = entity._run_task
else:
entity.set_resolver(default_notebook_task_resolver)
pickled_target_dict[entity.name] = entity
pickled_target_dict.entities[entity.name] = entity
elif isinstance(entity, WorkflowBase):
for task in entity.nodes:
queue.append(task)
Expand Down Expand Up @@ -2655,7 +2664,7 @@ def download(
def _pickle_and_upload_entity(
self,
entity: typing.Union[PythonTask, WorkflowBase],
pickled_dict: typing.Optional[typing.Dict[str, PythonAutoContainerTask]] = None,
pickled_dict: typing.Optional[PickledEntity] = None,
) -> FastSerializationSettings:
"""
Pickle the entity to the specified location. This is useful for debugging and for sharing entities across
Expand Down
28 changes: 26 additions & 2 deletions tests/flytekit/unit/core/test_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,13 @@
import cloudpickle
import mock
import pytest
import sys

import flytekit.configuration
from flytekit.configuration import Image, ImageConfig
from flytekit.core.base_task import TaskResolverMixin
from flytekit.core.class_based_resolver import ClassStorageTaskResolver
from flytekit.core.python_auto_container import default_task_resolver, default_notebook_task_resolver, PICKLE_FILE_PATH
from flytekit.core.python_auto_container import default_task_resolver, default_notebook_task_resolver, PickledEntity, PickledEntityMetadata
from flytekit.core.task import task
from flytekit.core.workflow import workflow
from flytekit.tools.translator import get_serializable
Expand Down Expand Up @@ -123,10 +124,33 @@ def t1(a: str, b: str) -> str:

assert c.loader_args(None, t1) == ["entity-name", "tests.flytekit.unit.core.test_resolver.t1"]

pickled_dict = {"tests.flytekit.unit.core.test_resolver.t1": t1}
pickled_dict = PickledEntity(
metadata=PickledEntityMetadata(
python_version=f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}"
),
entities={
"tests.flytekit.unit.core.test_resolver.t1": t1,
},
)

custom_pickled_object = cloudpickle.dumps(pickled_dict)
mock_gzip_open.return_value.read.return_value = custom_pickled_object
mock_cloudpickle.return_value = pickled_dict

t = c.load_task(["entity-name", "tests.flytekit.unit.core.test_resolver.t1"])
assert t == t1

mismatched_pickled_dict = PickledEntity(
metadata=PickledEntityMetadata(
python_version=f"{sys.version_info.major}.{sys.version_info.minor - 1}.{sys.version_info.micro}"
),
entities={
"tests.flytekit.unit.core.test_resolver.t1": t1,
},
)
mismatched_custom_pickled_object = cloudpickle.dumps(mismatched_pickled_dict)
mock_gzip_open.return_value.read.return_value = mismatched_custom_pickled_object
mock_cloudpickle.return_value = mismatched_pickled_dict

with pytest.raises(RuntimeError):
c.load_task(["entity-name", "tests.flytekit.unit.core.test_resolver.t1"])
25 changes: 17 additions & 8 deletions tests/flytekit/unit/remote/test_remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import pathlib
import shutil
import subprocess
import sys
import tempfile
import typing
import uuid
Expand Down Expand Up @@ -708,11 +709,15 @@ def w() -> int:
return t2(a=t1())

_, target_dict = _get_pickled_target_dict(w)
assert len(target_dict) == 2
assert t1.name in target_dict
assert t2.name in target_dict
assert target_dict[t1.name] == t1
assert target_dict[t2.name] == t2
assert (
target_dict.metadata.python_version
== f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}"
)
assert len(target_dict.entities) == 2
assert t1.name in target_dict.entities
assert t2.name in target_dict.entities
assert target_dict.entities[t1.name] == t1
assert target_dict.entities[t2.name] == t2

def test_get_pickled_target_dict_with_map_task():
@task
Expand All @@ -724,9 +729,13 @@ def w() -> int:
return map_task(partial(t1, y=2))(x=[1, 2, 3])

_, target_dict = _get_pickled_target_dict(w)
assert len(target_dict) == 1
assert t1.name in target_dict
assert target_dict[t1.name] == t1
assert (
target_dict.metadata.python_version
== f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}"
)
assert len(target_dict.entities) == 1
assert t1.name in target_dict.entities
assert target_dict.entities[t1.name] == t1

def test_get_pickled_target_dict_with_dynamic():
@task
Expand Down

0 comments on commit 5bf0acd

Please sign in to comment.