
Commit 01c51b9

Offload literals (#2872) (#2950)

* wip - Implement offloading of literals
* Fix use of metadata bucket prefix
* Fix repeated use of uri
* Add temporary representation for offloaded literal
* Add one unit test
* Add another test
* Stylistic changes to the two tests
* Add test for min offloading threshold set to 1MB
* Pick a unique engine-dir for tests
* s/new_outputs/literal_map_copy/
* Remove unused constant
* Use output_prefix in definition of offloaded literals
* Add initial version of pbhash.py
* Add tests to verify that overriding the hash is carried over to offloaded literals
* Add a few more tests
* Always import ParamSpec from `typing_extensions`
* Fix lint warnings
* Set inferred_type using the task type interface
* Add comment about offloaded literals files and how they are uploaded to the metadata bucket
* Add offloading_enabled
* Add more unit tests including a negative test
* Fix bad merge
* Incorporate feedback.
* Fix image name (unrelated to this PR - just a nice-to-have to decrease flakiness)
* Add `is_map_task` to `_dispatch_execute`

---------

Signed-off-by: Eduardo Apolinario <[email protected]>
Co-authored-by: Eduardo Apolinario <[email protected]>
1 parent 61c066c commit 01c51b9

File tree

9 files changed, +679 -18 lines changed


flytekit/bin/entrypoint.py

Lines changed: 63 additions & 3 deletions
@@ -11,7 +11,7 @@
 import traceback
 import warnings
 from sys import exit
-from typing import Callable, List, Optional
+from typing import Callable, Dict, List, Optional
 
 import click
 from flyteidl.core import literals_pb2 as _literals_pb2
@@ -48,6 +48,7 @@
 from flytekit.models.core import identifier as _identifier
 from flytekit.tools.fast_registration import download_distribution as _download_distribution
 from flytekit.tools.module_loader import load_object_from_module
+from flytekit.utils.pbhash import compute_hash_string
 
 
 def get_version_message():
@@ -93,6 +94,7 @@ def _dispatch_execute(
     load_task: Callable[[], PythonTask],
     inputs_path: str,
     output_prefix: str,
+    is_map_task: bool = False,
 ):
     """
     Dispatches execute to PythonTask
@@ -102,6 +104,12 @@ def _dispatch_execute(
     a: [Optional] Record outputs to output_prefix
     b: OR if IgnoreOutputs is raised, then ignore uploading outputs
     c: OR if an unhandled exception is retrieved - record it as an errors.pb
+
+    :param ctx: FlyteContext
+    :param load_task: Callable[[], PythonTask]
+    :param inputs_path: Where to read inputs
+    :param output_prefix: Where to write primitive outputs
+    :param is_map_task: Whether this task is executing as part of a map task
     """
     output_file_dict = {}
 
@@ -134,7 +142,59 @@ def _dispatch_execute(
             logger.warning("Task produces no outputs")
             output_file_dict = {_constants.OUTPUT_FILE_NAME: _literal_models.LiteralMap(literals={})}
         elif isinstance(outputs, _literal_models.LiteralMap):
-            output_file_dict = {_constants.OUTPUT_FILE_NAME: outputs}
+            # The keys in this map hold the filenames to the offloaded proto literals.
+            offloaded_literals: Dict[str, _literal_models.Literal] = {}
+            literal_map_copy = {}
+
+            offloading_enabled = os.environ.get("_F_L_MIN_SIZE_MB", None) is not None
+            min_offloaded_size = -1
+            max_offloaded_size = -1
+            if offloading_enabled:
+                min_offloaded_size = int(os.environ.get("_F_L_MIN_SIZE_MB", "10")) * 1024 * 1024
+                max_offloaded_size = int(os.environ.get("_F_L_MAX_SIZE_MB", "1000")) * 1024 * 1024
+
+            # Go over each output and create a separate offloaded literal in case its size is too large
+            for k, v in outputs.literals.items():
+                literal_map_copy[k] = v
+
+                if not offloading_enabled:
+                    continue
+
+                lit = v.to_flyte_idl()
+                if max_offloaded_size != -1 and lit.ByteSize() >= max_offloaded_size:
+                    raise ValueError(
+                        f"Literal {k} is too large to be offloaded. Max literal size is {max_offloaded_size} whereas the literal size is {lit.ByteSize()} bytes"
+                    )
+
+                if min_offloaded_size != -1 and lit.ByteSize() >= min_offloaded_size:
+                    logger.debug(f"Literal {k} is too large to be inlined, offloading to metadata bucket")
+                    inferred_type = task_def.interface.outputs[k].type
+
+                    # In the case of map tasks we need to use the element type of the collection as the inferred
+                    # type of the offloaded literal. This is done because the map task interface present in the
+                    # task template contains the (correct) type for the entire map task, not the single node
+                    # execution. For that reason we "unwrap" the collection type and use it as the inferred type.
+                    if is_map_task:
+                        inferred_type = inferred_type.collection_type
+
+                    # This file will hold the offloaded literal and will be written to the output prefix
+                    # alongside the regular outputs.pb, deck.pb, etc.
+                    # N.B.: by construction `offloaded_filename` is guaranteed to be unique
+                    offloaded_filename = f"{k}_offloaded_metadata.pb"
+                    offloaded_literal = _literal_models.Literal(
+                        offloaded_metadata=_literal_models.LiteralOffloadedMetadata(
+                            uri=f"{output_prefix}/{offloaded_filename}",
+                            size_bytes=lit.ByteSize(),
+                            # TODO: remove after https://github.com/flyteorg/flyte/pull/5909 is merged
+                            inferred_type=inferred_type,
+                        ),
+                        hash=v.hash if v.hash is not None else compute_hash_string(lit),
+                    )
+                    literal_map_copy[k] = offloaded_literal
+                    offloaded_literals[offloaded_filename] = v
+            outputs = _literal_models.LiteralMap(literals=literal_map_copy)
+
+            output_file_dict = {_constants.OUTPUT_FILE_NAME: outputs, **offloaded_literals}
         elif isinstance(outputs, _dynamic_job.DynamicJobSpec):
             output_file_dict = {_constants.FUTURES_FILE_NAME: outputs}
         else:
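For readers skimming the hunk above: a minimal, self-contained sketch of the size-gating decision it introduces. The helper name decide_placement and the example sizes are illustrative only (not part of the commit); the real logic lives inline in _dispatch_execute and reads _F_L_MIN_SIZE_MB (default "10") and _F_L_MAX_SIZE_MB (default "1000") from the environment, in megabytes.

import os

def decide_placement(serialized_size_bytes: int) -> str:
    """Illustrative only: mirrors the inline gating added to _dispatch_execute above."""
    # Offloading is enabled only when _F_L_MIN_SIZE_MB is set at all.
    if os.environ.get("_F_L_MIN_SIZE_MB") is None:
        return "inline"
    min_bytes = int(os.environ.get("_F_L_MIN_SIZE_MB", "10")) * 1024 * 1024
    max_bytes = int(os.environ.get("_F_L_MAX_SIZE_MB", "1000")) * 1024 * 1024
    if serialized_size_bytes >= max_bytes:
        return "error"    # the real code raises ValueError here
    if serialized_size_bytes >= min_bytes:
        return "offload"  # literal is replaced by a LiteralOffloadedMetadata pointer
    return "inline"

# Example: with _F_L_MIN_SIZE_MB=10 (and the default max of 1000), a 5 MiB output
# stays inline, a 50 MiB output is offloaded, and a 2 GiB output raises.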
@@ -500,7 +560,7 @@ def load_task():
         )
         return
 
-    _dispatch_execute(ctx, load_task, inputs, output_prefix)
+    _dispatch_execute(ctx, load_task, inputs, output_prefix, is_map_task=True)
 
 
 def normalize_inputs(
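A note on the new is_map_task flag: for a map task, the interface recorded in the task template is typed for the whole map (a collection), while a single node execution produces one element, so the code above unwraps the collection type before using it as the offloaded literal's inferred type. A small sketch with flytekit's type models (the variable names are illustrative):

from flytekit.models.types import LiteralType, SimpleType

# Interface type stored in the map task template, e.g. List[int]
map_interface_type = LiteralType(collection_type=LiteralType(simple=SimpleType.INTEGER))

# What one node execution actually produces, and what the offloaded literal
# should advertise as its inferred_type: the element type (int)
single_node_type = map_interface_type.collection_type
assert single_node_type.simple == SimpleType.INTEGER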

flytekit/core/task.py

Lines changed: 2 additions & 6 deletions
@@ -5,12 +5,7 @@
 from functools import update_wrapper
 from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Type, TypeVar, Union, overload
 
-from flytekit.core.utils import str2bool
-
-try:
-    from typing import ParamSpec
-except ImportError:
-    from typing_extensions import ParamSpec  # type: ignore
+from typing_extensions import ParamSpec  # type: ignore
 
 from flytekit.core import launch_plan as _annotated_launchplan
 from flytekit.core import workflow as _annotated_workflow
@@ -20,6 +15,7 @@
 from flytekit.core.python_function_task import PythonFunctionTask
 from flytekit.core.reference_entity import ReferenceEntity, TaskReference
 from flytekit.core.resources import Resources
+from flytekit.core.utils import str2bool
 from flytekit.deck import DeckField
 from flytekit.extras.accelerators import BaseAccelerator
 from flytekit.image_spec.image_spec import ImageSpec

flytekit/core/workflow.py

Lines changed: 1 addition & 5 deletions
@@ -8,13 +8,9 @@
 from functools import update_wrapper
 from typing import Any, Callable, Coroutine, Dict, List, Optional, Tuple, Type, Union, cast, overload
 
+from typing_extensions import ParamSpec  # type: ignore
 from typing_inspect import is_optional_type
 
-try:
-    from typing import ParamSpec
-except ImportError:
-    from typing_extensions import ParamSpec  # type: ignore
-
 from flytekit.core import constants as _common_constants
 from flytekit.core import launch_plan as _annotated_launch_plan
 from flytekit.core.base_task import PythonTask, Task
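The two import changes above (flytekit/core/task.py and flytekit/core/workflow.py) drop the try/except fallback and always import ParamSpec from typing_extensions, which re-exports the stdlib symbol on Python 3.10+ and provides the backport on older interpreters. A minimal sketch of the decorator-typing pattern the symbol supports (the function names here are illustrative, not flytekit API):

from typing import Callable, TypeVar

from typing_extensions import ParamSpec  # works on all supported Python versions

P = ParamSpec("P")
R = TypeVar("R")

def passthrough(fn: Callable[P, R]) -> Callable[P, R]:
    # Preserves the wrapped function's parameter types for type checkers,
    # which is what decorators like @task and @workflow rely on.
    def inner(*args: P.args, **kwargs: P.kwargs) -> R:
        return fn(*args, **kwargs)
    return inner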

flytekit/interaction/string_literals.py

Lines changed: 3 additions & 0 deletions
@@ -61,6 +61,9 @@ def literal_string_repr(lit: Literal) -> typing.Any:
         return [literal_string_repr(i) for i in lit.collection.literals]
     if lit.map:
         return {k: literal_string_repr(v) for k, v in lit.map.literals.items()}
+    if lit.offloaded_metadata:
+        # TODO: load literal from offloaded literal?
+        return f"Offloaded literal metadata: {lit.offloaded_metadata}"
     raise ValueError(f"Unknown literal type {lit}")
 
 
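A hedged usage sketch of the new branch in literal_string_repr; the URI and size are made up, and the constructor keywords follow the literal models touched elsewhere in this commit:

from flytekit.interaction.string_literals import literal_string_repr
from flytekit.models.literals import Literal, LiteralOffloadedMetadata
from flytekit.models.types import LiteralType, SimpleType

lit = Literal(
    offloaded_metadata=LiteralOffloadedMetadata(
        uri="s3://my-bucket/metadata/o0_offloaded_metadata.pb",  # illustrative URI
        size_bytes=25 * 1024 * 1024,
        inferred_type=LiteralType(simple=SimpleType.INTEGER),
    ),
)

# Until offloaded literals are loaded back (see the TODO above), the repr simply
# surfaces the metadata:
print(literal_string_repr(lit))  # "Offloaded literal metadata: ..."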

flytekit/models/literals.py

Lines changed: 1 addition & 1 deletion
@@ -991,7 +991,7 @@ def to_flyte_idl(self):
             map=self.map.to_flyte_idl() if self.map is not None else None,
             hash=self.hash,
             metadata=self.metadata,
-            offloaded_metadata=self.offloaded_metadata.to_flyte_idl() if self.offloaded_metadata else None,
+            offloaded_metadata=self.offloaded_metadata.to_flyte_idl() if self.offloaded_metadata is not None else None,
         )
 
     @classmethod
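The one-line change above replaces a truthiness test with an explicit `is not None`, matching the `self.map is not None` style on the neighbouring lines. As a general-Python aside (the class below is hypothetical and unrelated to the flytekit model), the two checks are not always interchangeable:

class Box:
    """Hypothetical container-like object; not a flytekit class."""
    def __init__(self, items):
        self.items = items

    def __len__(self):
        return len(self.items)

empty = Box([])
print(bool(empty))        # False: an "empty" object that defines __len__ is falsy
print(empty is not None)  # True: the object exists, it is merely empty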

flytekit/utils/pbhash.py

Lines changed: 39 additions & 0 deletions
@@ -0,0 +1,39 @@
+# This is a module that provides hashing utilities for Protobuf objects.
+import base64
+import hashlib
+import json
+
+from google.protobuf import json_format
+from google.protobuf.message import Message
+
+
+def compute_hash(pb: Message) -> bytes:
+    """
+    Computes a deterministic hash in bytes for the Protobuf object.
+    """
+    try:
+        pb_dict = json_format.MessageToDict(pb)
+        # json.dumps with sorted keys to ensure stability
+        stable_json_str = json.dumps(
+            pb_dict, sort_keys=True, separators=(",", ":")
+        )  # separators to ensure no extra spaces
+    except Exception as e:
+        raise ValueError(f"Failed to marshal Protobuf object {pb} to JSON with error: {e}")
+
+    try:
+        # Deterministically hash the JSON object to a byte array. Using SHA-256 for hashing here,
+        # assuming it provides a consistent hash output.
+        hash_obj = hashlib.sha256(stable_json_str.encode("utf-8"))
+    except Exception as e:
+        raise ValueError(f"Failed to hash JSON for Protobuf object {pb} with error: {e}")
+
+    # The digest is guaranteed to be 32 bytes long
+    return hash_obj.digest()
+
+
+def compute_hash_string(pb: Message) -> str:
+    """
+    Computes a deterministic hash in base64 encoded string for the Protobuf object
+    """
+    hash_bytes = compute_hash(pb)
+    return base64.b64encode(hash_bytes).decode("utf-8")
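A quick usage sketch of the new helper against a flyteidl protobuf (the value is illustrative). The point is determinism: the same message always hashes to the same base64 string, which is what lets an offloaded literal carry a stable hash for caching:

from flyteidl.core import literals_pb2

from flytekit.utils.pbhash import compute_hash_string

lit = literals_pb2.Literal(
    scalar=literals_pb2.Scalar(primitive=literals_pb2.Primitive(integer=42))
)

h1 = compute_hash_string(lit)
h2 = compute_hash_string(lit)
assert h1 == h2  # deterministic for the same message
print(h1)         # 44-character base64 encoding of the 32-byte SHA-256 digest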
