diff --git a/api/src/main/java/org/apache/flink/agents/api/configuration/AgentConfigOptions.java b/api/src/main/java/org/apache/flink/agents/api/configuration/AgentConfigOptions.java index 5efe2b85..e4871c66 100644 --- a/api/src/main/java/org/apache/flink/agents/api/configuration/AgentConfigOptions.java +++ b/api/src/main/java/org/apache/flink/agents/api/configuration/AgentConfigOptions.java @@ -43,4 +43,8 @@ public class AgentConfigOptions { /** The config parameter specifies the replication factor for the Kafka action state topic. */ public static final ConfigOption<Integer> KAFKA_ACTION_STATE_TOPIC_REPLICATION_FACTOR = new ConfigOption<>("kafkaActionStateTopicReplicationFactor", Integer.class, 1); + + /** The config parameter specifies the unique identifier of the job. */ + public static final ConfigOption<String> JOB_IDENTIFIER = + new ConfigOption<>("job-identifier", String.class, null); } diff --git a/python/flink_agents/api/core_options.py b/python/flink_agents/api/core_options.py index 5599e24c..d9ee4561 100644 --- a/python/flink_agents/api/core_options.py +++ b/python/flink_agents/api/core_options.py @@ -49,7 +49,10 @@ def covert_j_option_to_python_option(j_option: Any) -> ConfigOption: class AgentConfigOptionsMeta(type): """Metaclass for FlinkAgentsCoreOptions.""" - def __init__(cls, name: str, bases: tuple[type, ...], attrs: dict[str, Any]) -> None: + + def __init__( + cls, name: str, bases: tuple[type, ...], attrs: dict[str, Any] + ) -> None: """Initialize the metaclass for FlinkAgentsCoreOptions.""" super().__init__(name, bases, attrs) @@ -68,3 +71,9 @@ def __getattr__(cls, item: str) -> ConfigOption: class AgentConfigOptions(metaclass=AgentConfigOptionsMeta): """CoreOptions to manage core configuration parameters for Flink Agents.""" + + JOB_IDENTIFIER = ConfigOption( + key="job-identifier", + config_type=str, + default=None, + ) diff --git a/python/flink_agents/api/memory/long_term_memory.py b/python/flink_agents/api/memory/long_term_memory.py index 6373499d..7d611822 100644 --- a/python/flink_agents/api/memory/long_term_memory.py +++ b/python/flink_agents/api/memory/long_term_memory.py @@ -30,6 +30,7 @@ from typing_extensions import override from flink_agents.api.chat_message import ChatMessage +from flink_agents.api.configuration import ConfigOption from flink_agents.api.prompts.prompt import Prompt ItemType = str | ChatMessage @@ -76,6 +77,28 @@ class LongTermMemoryBackend(Enum): EXTERNAL_VECTOR_STORE = "external_vector_store" +class LongTermMemoryOptions: + """Config options for long-term memory.""" + + BACKEND = ConfigOption( + key="long-term-memory.backend", + config_type=LongTermMemoryBackend, + default=None, + ) + + EXTERNAL_VECTOR_STORE_NAME = ConfigOption( + key="long-term-memory.external-vector-store-name", + config_type=str, + default=None, + ) + + ASYNC_COMPACTION = ConfigOption( + key="long-term-memory.async-compaction", + config_type=bool, + default=False, + ) + + class DatetimeRange(BaseModel): """Represents a datetime range.""" @@ -159,7 +182,7 @@ def size(self) -> int: def add( self, items: ItemType | List[ItemType], ids: str | List[str] | None = None - ) -> None: + ) -> List[str]: """Add a memory item to the set, currently only support item with type str or ChatMessage. Args: items: The items to be inserted to this set. ids: The ids of the items to be inserted. Optional. + + Returns: + The IDs of the items added. 
""" - self.ltm.add(memory_set=self, memory_items=items, ids=ids) + return self.ltm.add(memory_set=self, memory_items=items, ids=ids) def get( self, ids: str | List[str] | None = None @@ -203,7 +229,7 @@ class BaseLongTermMemory(ABC, BaseModel): def get_or_create_memory_set( self, name: str, - item_type: str | Type[ChatMessage], + item_type: type[str] | Type[ChatMessage], capacity: int, compaction_strategy: CompactionStrategy, ) -> MemorySet: @@ -257,7 +283,7 @@ def add( memory_items: ItemType | List[ItemType], ids: str | List[str] | None = None, metadatas: Dict[str, Any] | List[Dict[str, Any]] | None = None, - ) -> None: + ) -> List[str]: """Add items to the memory set, currently only support items with type str or ChatMessage. @@ -269,6 +295,9 @@ def add( ids: The IDs of items. Will be automatically generated if not provided. Optional. metadatas: The metadata for items. Optional. + + Returns: + The IDs of added items. """ @abstractmethod diff --git a/python/flink_agents/api/runner_context.py b/python/flink_agents/api/runner_context.py index 8f3f7a75..4def0ac7 100644 --- a/python/flink_agents/api/runner_context.py +++ b/python/flink_agents/api/runner_context.py @@ -16,10 +16,11 @@ # limitations under the License. ################################################################################# from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Any, Callable, Dict, Tuple +from typing import TYPE_CHECKING, Any, Callable, Dict from flink_agents.api.configuration import ReadableConfiguration from flink_agents.api.events.event import Event +from flink_agents.api.memory.long_term_memory import BaseLongTermMemory from flink_agents.api.metric_group import MetricGroup from flink_agents.api.resource import Resource, ResourceType @@ -107,6 +108,17 @@ def short_term_memory(self) -> "MemoryObject": The root object of the short-term memory. """ + @property + @abstractmethod + def long_term_memory(self) -> BaseLongTermMemory: + """Get the long-term memory. + + Returns: + ------- + BaseLongTermMemory + The long-term memory instance. + """ + @property @abstractmethod def agent_metric_group(self) -> MetricGroup: @@ -133,8 +145,8 @@ def action_metric_group(self) -> MetricGroup: def execute_async( self, func: Callable[[Any], Any], - *args: Tuple[Any, ...], - **kwargs: Dict[str, Any], + *args: Any, + **kwargs: Any, ) -> Any: """Asynchronously execute the provided function. Access to memory is prohibited within the function. @@ -143,9 +155,9 @@ def execute_async( ---------- func : Callable The function need to be asynchronously processing. - *args : tuple + *args : Any Positional arguments to pass to the function. - **kwargs : dict + **kwargs : Any Keyword arguments to pass to the function. Returns: diff --git a/python/flink_agents/e2e_tests/long_term_memory_test.py b/python/flink_agents/e2e_tests/long_term_memory_test.py new file mode 100644 index 00000000..5a24181b --- /dev/null +++ b/python/flink_agents/e2e_tests/long_term_memory_test.py @@ -0,0 +1,302 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################# +import os +import sysconfig +import tempfile +from datetime import datetime, timezone +from pathlib import Path +from typing import Any, List + +from pydantic import BaseModel +from pyflink.common import Encoder, Types, WatermarkStrategy +from pyflink.datastream import ( + KeySelector, + RuntimeExecutionMode, + StreamExecutionEnvironment, +) +from pyflink.datastream.connectors.file_system import ( + FileSource, + StreamFormat, + StreamingFileSink, +) + +from flink_agents.api.agent import Agent +from flink_agents.api.core_options import AgentConfigOptions +from flink_agents.api.decorators import ( + action, + chat_model_connection, + chat_model_setup, + embedding_model_connection, + embedding_model_setup, + vector_store, +) +from flink_agents.api.events.event import Event, InputEvent, OutputEvent +from flink_agents.api.execution_environment import AgentsExecutionEnvironment +from flink_agents.api.memory.long_term_memory import ( + LongTermMemoryBackend, + LongTermMemoryOptions, + MemorySetItem, + SummarizationStrategy, +) +from flink_agents.api.resource import ResourceDescriptor +from flink_agents.api.runner_context import RunnerContext +from flink_agents.e2e_tests.test_utils import pull_model +from flink_agents.integrations.chat_models.ollama_chat_model import ( + OllamaChatModelConnection, + OllamaChatModelSetup, +) +from flink_agents.integrations.embedding_models.local.ollama_embedding_model import ( + OllamaEmbeddingModelConnection, + OllamaEmbeddingModelSetup, +) +from flink_agents.integrations.vector_stores.chroma.chroma_vector_store import ( + ChromaVectorStore, +) + +current_dir = Path(__file__).parent + +os.environ["PYTHONPATH"] = sysconfig.get_paths()["purelib"] + +chromadb_path = tempfile.mkdtemp() + +OLLAMA_CHAT_MODEL = "qwen3:8b" +OLLAMA_EMBEDDING_MODEL = "nomic-embed-text" +pull_model(OLLAMA_CHAT_MODEL) +pull_model(OLLAMA_EMBEDDING_MODEL) + + +class ItemData(BaseModel): + """Data model for storing item information. 
+ + Attributes: + ---------- + id : int + Unique identifier of the item + review : str + The user review of the item + review_score: float + The review_score of the item + """ + + id: int + review: str + review_score: float + memory_info: dict | None = None + + +class Record(BaseModel): # noqa: D101 + id: int + count: int + timestamp_before_add: str + timestamp_after_add: str + timestamp_second_action: str | None = None + items: List[MemorySetItem] | None = None + + +class MyEvent(Event): # noqa D101 + value: Any + + +class MyKeySelector(KeySelector): + """KeySelector for extracting key.""" + + def get_key(self, value: ItemData) -> int: + """Extract key from ItemData.""" + return value.id + + +class LongTermMemoryAgent(Agent): + """Agent used for testing long term memory .""" + + @chat_model_connection + @staticmethod + def ollama_connection() -> ResourceDescriptor: + """ChatModelConnection responsible for ollama model service connection.""" + return ResourceDescriptor( + clazz=OllamaChatModelConnection, request_timeout=240.0 + ) + + @chat_model_setup + @staticmethod + def ollama_qwen3() -> ResourceDescriptor: + """ChatModel which focus on math, and reuse ChatModelConnection.""" + return ResourceDescriptor( + clazz=OllamaChatModelSetup, + connection="ollama_connection", + model=OLLAMA_CHAT_MODEL, + extract_reasoning=True, + ) + + @embedding_model_connection + @staticmethod + def ollama_embedding_connection() -> ResourceDescriptor: # noqa D102 + return ResourceDescriptor( + clazz=OllamaEmbeddingModelConnection, request_timeout=240.0 + ) + + @embedding_model_setup + @staticmethod + def ollama_nomic_embed_text() -> ResourceDescriptor: # noqa D102 + return ResourceDescriptor( + clazz=OllamaEmbeddingModelSetup, + connection="ollama_embedding_connection", + model=OLLAMA_EMBEDDING_MODEL, + ) + + @vector_store + @staticmethod + def chroma_vector_store() -> ResourceDescriptor: + """Vector store setup for knowledge base.""" + return ResourceDescriptor( + clazz=ChromaVectorStore, + embedding_model="ollama_nomic_embed_text", + persist_directory=chromadb_path, + ) + + @action(InputEvent) + @staticmethod + def add_items(event: Event, ctx: RunnerContext): # noqa D102 + input_data = event.input + ltm = ctx.long_term_memory + + timestamp_before_add = datetime.now(timezone.utc).isoformat() + memory_set = ltm.get_or_create_memory_set( + name="test_ltm", + item_type=str, + capacity=5, + compaction_strategy=SummarizationStrategy(model="ollama_qwen3"), + ) + yield from ctx.execute_async(memory_set.add, items=input_data.review) + timestamp_after_add = datetime.now(timezone.utc).isoformat() + + stm = ctx.short_term_memory + count = stm.get("count") or 1 + stm.set("count", count + 1) + + ctx.send_event( + MyEvent( + value=Record( + id=input_data.id, + count=count, + timestamp_before_add=timestamp_before_add, + timestamp_after_add=timestamp_after_add, + ) + ) + ) + + @action(MyEvent) + @staticmethod + def retrieve_items(event: Event, ctx: RunnerContext): # noqa D102 + record: Record = event.value + record.timestamp_second_action = datetime.now(timezone.utc).isoformat() + memory_set = ctx.long_term_memory.get_memory_set(name="test_ltm") + items = yield from ctx.execute_async(memory_set.get) + if ( + (record.id == 1 and record.count == 3) + or (record.id == 2 and record.count == 5) + or (record.id == 3 and record.count == 2) + ): + record.items = items + ctx.send_event(OutputEvent(output=record)) + + +def test_long_term_memory_async_execution_in_action(tmp_path: Path) -> None: # noqa: D103 + env = 
StreamExecutionEnvironment.get_execution_environment() + env.set_runtime_mode(RuntimeExecutionMode.STREAMING) + env.set_parallelism(1) + + # currently, bounded source is not supported due to runtime implementation, so + # we use continuous file source here. + input_datastream = env.from_source( + source=FileSource.for_record_stream_format( + StreamFormat.text_line_format(), f"file:///{current_dir}/resources/input" + ).build(), + watermark_strategy=WatermarkStrategy.no_watermarks(), + source_name="streaming_agent_example", + ) + + deserialize_datastream = input_datastream.map( + lambda x: ItemData.model_validate_json(x) + ) + + agents_env = AgentsExecutionEnvironment.get_execution_environment(env=env) + agents_config = agents_env.get_config() + agents_config.set(AgentConfigOptions.JOB_IDENTIFIER, "LTM_TEST_JOB") + agents_config.set( + LongTermMemoryOptions.BACKEND, LongTermMemoryBackend.EXTERNAL_VECTOR_STORE + ) + agents_config.set( + LongTermMemoryOptions.EXTERNAL_VECTOR_STORE_NAME, "chroma_vector_store" + ) + agents_config.set(LongTermMemoryOptions.ASYNC_COMPACTION, True) + + output_datastream = ( + agents_env.from_datastream( + input=deserialize_datastream, key_selector=MyKeySelector() + ) + .apply(LongTermMemoryAgent()) + .to_datastream() + ) + + result_dir = tmp_path / "results" + result_dir.mkdir(parents=True, exist_ok=True) + + output_datastream.map(lambda x: x.model_dump_json(), Types.STRING()).add_sink( + StreamingFileSink.for_row_format( + base_path=str(result_dir.absolute()), + encoder=Encoder.simple_string_encoder(), + ).build() + ) + + agents_env.execute() + + check_result(result_dir=result_dir) + + +def check_result(*, result_dir: Path) -> None: # noqa: D103 + actual_result = [] + for file in result_dir.iterdir(): + if file.is_dir(): + for child in file.iterdir(): + with child.open() as f: + actual_result.extend( + [Record.model_validate_json(line) for line in f] + ) + + records = {} + for record in actual_result: + records[f"{record.id}.{record.count}"] = record + + # verify async add doesn't block process other key + assert datetime.fromisoformat( + records["2.1"].timestamp_before_add + ) < datetime.fromisoformat(records["1.1"].timestamp_after_add) + assert datetime.fromisoformat( + records["3.1"].timestamp_before_add + ) < datetime.fromisoformat(records["1.1"].timestamp_after_add) + + # verify async compaction doesn't block any operation + assert not records["2.5"].items[0].compacted + store = ChromaVectorStore( + persist_directory=chromadb_path, embedding_model="ollama_nomic_embed_text" + ) + doc = store.get(collection_name="LTM_TEST_JOB--89360337-test_ltm") + print(f"Retrieved items: {doc}") + assert len(doc) == 1 + doc = doc[0] + assert doc.metadata.get("compacted") diff --git a/python/flink_agents/runtime/flink_runner_context.py b/python/flink_agents/runtime/flink_runner_context.py index 69334cc8..43190743 100644 --- a/python/flink_agents/runtime/flink_runner_context.py +++ b/python/flink_agents/runtime/flink_runner_context.py @@ -17,19 +17,27 @@ ################################################################################# import os from concurrent.futures import ThreadPoolExecutor -from typing import Any, Callable, Dict, Tuple +from typing import Any, Callable, Dict import cloudpickle from typing_extensions import override from flink_agents.api.configuration import ReadableConfiguration from flink_agents.api.events.event import Event +from flink_agents.api.memory.long_term_memory import ( + BaseLongTermMemory, + LongTermMemoryBackend, + LongTermMemoryOptions, +) 
from flink_agents.api.memory_object import MemoryType from flink_agents.api.resource import Resource, ResourceType from flink_agents.api.runner_context import RunnerContext from flink_agents.plan.agent_plan import AgentPlan from flink_agents.runtime.flink_memory_object import FlinkMemoryObject from flink_agents.runtime.flink_metric_group import FlinkMetricGroup +from flink_agents.runtime.memory.vector_store_long_term_memory import ( + VectorStoreLongTermMemory, +) class FlinkRunnerContext(RunnerContext): @@ -39,9 +47,14 @@ class FlinkRunnerContext(RunnerContext): """ __agent_plan: AgentPlan + __ltm: BaseLongTermMemory = None def __init__( - self, j_runner_context: Any, agent_plan_json: str, executor: ThreadPoolExecutor, j_resource_adapter: Any + self, + j_runner_context: Any, + agent_plan_json: str, + executor: ThreadPoolExecutor, + j_resource_adapter: Any, ) -> None: """Initialize a flink runner context with the given java runner context. @@ -55,6 +68,16 @@ def __init__( self.__agent_plan.set_java_resource_adapter(j_resource_adapter) self.executor = executor + def set_long_term_memory(self, ltm: BaseLongTermMemory) -> None: + """Set long term memory instance to this context. + + Parameters + ---------- + ltm : BaseLongTermMemory + The long term memory to keep. + """ + self.__ltm = ltm + @override def send_event(self, event: Event) -> None: """Send an event to the agent for processing. @@ -104,7 +127,9 @@ def sensory_memory(self) -> FlinkMemoryObject: temporary state data. """ try: - return FlinkMemoryObject(MemoryType.SENSORY, self._j_runner_context.getSensoryMemory()) + return FlinkMemoryObject( + MemoryType.SENSORY, self._j_runner_context.getSensoryMemory() + ) except Exception as e: err_msg = "Failed to get sensory memory of runner context" raise RuntimeError(err_msg) from e @@ -121,11 +146,18 @@ def short_term_memory(self) -> FlinkMemoryObject: temporary state data. """ try: - return FlinkMemoryObject(MemoryType.SHORT_TERM, self._j_runner_context.getShortTermMemory()) + return FlinkMemoryObject( + MemoryType.SHORT_TERM, self._j_runner_context.getShortTermMemory() + ) except Exception as e: err_msg = "Failed to get short-term memory of runner context" raise RuntimeError(err_msg) from e + @property + @override + def long_term_memory(self) -> BaseLongTermMemory: + return self.__ltm + @property @override def agent_metric_group(self) -> FlinkMetricGroup: @@ -154,8 +186,8 @@ def action_metric_group(self) -> FlinkMetricGroup: def execute_async( self, func: Callable[[Any], Any], - *args: Tuple[Any, ...], - **kwargs: Dict[str, Any], + *args: Any, + **kwargs: Any, ) -> Any: """Asynchronously execute the provided function. Access to memory is prohibited within the function. 
@@ -183,10 +215,44 @@ def config(self) -> ReadableConfiguration: def create_flink_runner_context( - j_runner_context: Any, agent_plan_json: str, executor: ThreadPoolExecutor, j_resource_adapter: Any + j_runner_context: Any, + agent_plan_json: str, + executor: ThreadPoolExecutor, + j_resource_adapter: Any, + job_identifier: str, + key: int, +) -> FlinkRunnerContext: + """Used to create a FlinkRunnerContext Python object in Pemja environment.""" + ctx = FlinkRunnerContext( + j_runner_context, agent_plan_json, executor, j_resource_adapter + ) + backend = ctx.config.get(LongTermMemoryOptions.BACKEND) + # use external vector store based long term memory + if backend == LongTermMemoryBackend.EXTERNAL_VECTOR_STORE: + vector_store_name = ctx.config.get( + LongTermMemoryOptions.EXTERNAL_VECTOR_STORE_NAME + ) + ctx.set_long_term_memory( + VectorStoreLongTermMemory( + ctx=ctx, + vector_store=vector_store_name, + job_id=job_identifier, + key=str(key), + ) + ) + return ctx + + +def create_long_term_memory( + j_runner_context: Any, + agent_plan_json: str, + executor: ThreadPoolExecutor, + j_resource_adapter: Any, ) -> FlinkRunnerContext: """Used to create a FlinkRunnerContext Python object in Pemja environment.""" - return FlinkRunnerContext(j_runner_context, agent_plan_json, executor, j_resource_adapter) + return FlinkRunnerContext( + j_runner_context, agent_plan_json, executor, j_resource_adapter + ) def create_async_thread_pool() -> ThreadPoolExecutor: diff --git a/python/flink_agents/runtime/local_runner.py b/python/flink_agents/runtime/local_runner.py index 6b5f50bd..b54eb962 100644 --- a/python/flink_agents/runtime/local_runner.py +++ b/python/flink_agents/runtime/local_runner.py @@ -18,12 +18,13 @@ import logging import uuid from collections import deque -from typing import Any, Callable, Dict, Generator, List, Tuple +from typing import Any, Callable, Dict, Generator, List from typing_extensions import override from flink_agents.api.agent import Agent from flink_agents.api.events.event import Event, InputEvent, OutputEvent +from flink_agents.api.memory.long_term_memory import BaseLongTermMemory from flink_agents.api.memory_object import MemoryObject, MemoryType from flink_agents.api.metric_group import MetricGroup from flink_agents.api.resource import Resource, ResourceType @@ -64,7 +65,9 @@ class LocalRunnerContext(RunnerContext): _short_term_memory: MemoryObject _config: AgentConfiguration - def __init__(self, agent_plan: AgentPlan, key: Any, config: AgentConfiguration) -> None: + def __init__( + self, agent_plan: AgentPlan, key: Any, config: AgentConfiguration + ) -> None: """Initialize a new context with the given agent and key. Parameters @@ -84,7 +87,9 @@ def __init__(self, agent_plan: AgentPlan, key: Any, config: AgentConfiguration) MemoryType.SENSORY, self._sensory_mem_store, LocalMemoryObject.ROOT_KEY ) self._short_term_memory = LocalMemoryObject( - MemoryType.SHORT_TERM, self._short_term_mem_store, LocalMemoryObject.ROOT_KEY + MemoryType.SHORT_TERM, + self._short_term_mem_store, + LocalMemoryObject.ROOT_KEY, ) self._config = config @@ -152,6 +157,12 @@ def short_term_memory(self) -> MemoryObject: """ return self._short_term_memory + @property + @override + def long_term_memory(self) -> BaseLongTermMemory: + err_msg = "Long-Term Memory is not supported for local agent execution yet." 
+ raise NotImplementedError(err_msg) + @property @override def agent_metric_group(self) -> MetricGroup: @@ -169,8 +180,8 @@ def action_metric_group(self) -> MetricGroup: def execute_async( self, func: Callable[[Any], Any], - *args: Tuple[Any, ...], - **kwargs: Dict[str, Any], + *args: Any, + **kwargs: Any, ) -> Any: """Asynchronously execute the provided function. Access to memory is prohibited within the function. @@ -248,7 +259,9 @@ def run(self, **data: Dict[str, Any]) -> Any: key = uuid.uuid4() if key not in self.__keyed_contexts: - self.__keyed_contexts[key] = LocalRunnerContext(self.__agent_plan, key, self.__config) + self.__keyed_contexts[key] = LocalRunnerContext( + self.__agent_plan, key, self.__config + ) context = self.__keyed_contexts[key] context.clear_sensory_memory() diff --git a/python/flink_agents/runtime/memory/compaction_functions.py b/python/flink_agents/runtime/memory/compaction_functions.py index e3a9416a..e895900d 100644 --- a/python/flink_agents/runtime/memory/compaction_functions.py +++ b/python/flink_agents/runtime/memory/compaction_functions.py @@ -16,6 +16,7 @@ # limitations under the License. ################################################################################# import json +import logging from typing import TYPE_CHECKING, List, Type, cast from flink_agents.api.chat_message import ChatMessage, MessageRole @@ -42,8 +43,9 @@ -You're nearing the total number of input tokens you can accept, so you need compact the context. To achieve this objective, you must extract important topics first. The extracted topics -must no more than {limit}. Afterwards, you should generate summarization for each topic, and and record which messages the summary was derived from. +You're nearing the total number of input tokens you can accept, so you need to compact the context. To achieve this objective, you should extract important topics. Note: +**the number of topics must be no more than {limit}**. Afterwards, you should generate a summarization for each topic, and record which messages the summary was derived from. +Message indices start from 0. @@ -85,6 +87,8 @@ def summarize( items, memory_set.item_type, strategy, ctx ) + logging.debug(f"Items to be summarized: {items}\nSummarization: {response.content}") + for topic in cast("dict", json.loads(response.content)).values(): summarization = topic["summarization"] indices = topic["messages"] @@ -126,7 +130,7 @@ def summarize( ) -#TODO: Currently, we feed all items to the LLM at once, which may exceed the LLM's +# TODO: Currently, we feed all items to the LLM at once, which may exceed the LLM's # context window. We need to support batched summary generation. def _generate_summarization( memory_set_items: List[MemorySetItem], diff --git a/python/flink_agents/runtime/memory/vector_store_long_term_memory.py b/python/flink_agents/runtime/memory/vector_store_long_term_memory.py index e4aabda0..a4cf203d 100644 --- a/python/flink_agents/runtime/memory/vector_store_long_term_memory.py +++ b/python/flink_agents/runtime/memory/vector_store_long_term_memory.py @@ -15,7 +15,9 @@ # See the License for the specific language governing permissions and # limitations under the License. 
################################################################################# +import functools import uuid +from concurrent.futures import Future from datetime import datetime, timezone from typing import Any, Dict, List, Type, cast @@ -29,6 +31,7 @@ CompactionStrategyType, DatetimeRange, ItemType, + LongTermMemoryOptions, MemorySet, MemorySetItem, ) @@ -61,6 +64,10 @@ class VectorStoreLongTermMemory(BaseLongTermMemory): key: str = Field(description="Unique identifier for the keyed partition.") + async_compaction: bool = Field( + default=False, description="Whether to execute compact asynchronously." + ) + def __init__( self, *, @@ -76,6 +83,7 @@ def __init__( vector_store=vector_store, job_id=job_id, key=key, + async_compaction=ctx.config.get(LongTermMemoryOptions.ASYNC_COMPACTION), **kwargs, ) @@ -137,7 +145,7 @@ def add( memory_items: ItemType | List[ItemType], ids: str | List[str] | None = None, metadatas: Dict[str, Any] | List[Dict[str, Any]] | None = None, - ) -> None: + ) -> List[str]: memory_items = _maybe_cast_to_list(memory_items) ids = _maybe_cast_to_list(ids) metadatas = _maybe_cast_to_list(metadatas) @@ -171,13 +179,23 @@ def add( ) ) - self.store.add( + ids = self.store.add( documents=documents, collection_name=self._name_mangling(memory_set.name) ) if memory_set.size >= memory_set.capacity: # trigger compaction - self._compact(memory_set) + if self.async_compaction: + future = self.ctx.executor.submit(self._compact, memory_set=memory_set) + future.add_done_callback( + functools.partial( + self._handle_exception, self.job_id, self.key, memory_set + ) + ) + else: + self._compact(memory_set=memory_set) + + return ids @override def get( @@ -220,6 +238,17 @@ def _compact(self, memory_set: MemorySet) -> None: msg = f"Unknown compaction strategy: {compaction_strategy.type}" raise RuntimeError(msg) + @staticmethod + def _handle_exception( + job_id: str, key: str, memory_set: MemorySet, future: Future + ) -> None: + exception = future.exception() + if exception is not None: + err_msg = f"Compaction for {job_id}-{key}-{memory_set.name} failed." + # TODO: Currently, this exception will appear in the log of TaskManager, + # but will not cause the Flink job to fail. + raise RuntimeError(err_msg) from exception + @staticmethod def _convert_to_items( memory_set: MemorySet, diff --git a/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperator.java b/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperator.java index 5f664408..98ff1446 100644 --- a/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperator.java +++ b/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperator.java @@ -90,6 +90,7 @@ import java.util.Optional; import static org.apache.flink.agents.api.configuration.AgentConfigOptions.ACTION_STATE_STORE_BACKEND; +import static org.apache.flink.agents.api.configuration.AgentConfigOptions.JOB_IDENTIFIER; import static org.apache.flink.agents.runtime.actionstate.ActionStateStore.BackendType.KAFKA; import static org.apache.flink.agents.runtime.utils.StateUtil.*; import static org.apache.flink.util.Preconditions.checkState; @@ -175,6 +176,13 @@ public class ActionExecutionOperator extends AbstractStreamOperator actionTaskRunnerContexts; + // Each job can only have one identifier and this identifier must be consistent across restarts. 
+ // We cannot use the job id as the identifier here because the user may change the job id by + // creating a savepoint, stopping the job, and then resuming from the savepoint. + // We use this identifier to control visibility for long-term memory. + // Inspired by Apache Paimon. + private transient String jobIdentifier; + public ActionExecutionOperator( AgentPlan agentPlan, Boolean inputIsJava, @@ -559,7 +567,8 @@ private void initPythonActionExecutor() throws Exception { new PythonActionExecutor( pythonInterpreter, new ObjectMapper().writeValueAsString(agentPlan), - javaResourceAdapter); + javaResourceAdapter, + jobIdentifier); pythonActionExecutor.open(); } @@ -635,6 +644,16 @@ public void initializeState(StateInitializationContext context) throws Exception } actionStateStore.rebuildState(markers); } + + // Get the job identifier from the user configuration. + // If not configured, get it from state. + jobIdentifier = agentPlan.getConfig().get(JOB_IDENTIFIER); + if (jobIdentifier == null) { + String initialJobIdentifier = getRuntimeContext().getJobInfo().getJobId().toString(); + jobIdentifier = + StateUtils.getSingleValueFromState( + context, "identifier_state", String.class, initialJobIdentifier); + } } @Override diff --git a/runtime/src/main/java/org/apache/flink/agents/runtime/operator/StateUtils.java b/runtime/src/main/java/org/apache/flink/agents/runtime/operator/StateUtils.java new file mode 100644 index 00000000..55a9d701 --- /dev/null +++ b/runtime/src/main/java/org/apache/flink/agents/runtime/operator/StateUtils.java @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.agents.runtime.operator; + +import org.apache.flink.api.common.state.ListState; +import org.apache.flink.api.common.state.ListStateDescriptor; +import org.apache.flink.runtime.state.StateInitializationContext; +import org.apache.flink.shaded.guava31.com.google.common.collect.Lists; +import org.apache.flink.util.Preconditions; +import javax.annotation.Nullable; +import java.util.ArrayList; +import java.util.List; + +/** Utility class for job identifier state manipulation. Copied from Apache Paimon. */ +public class StateUtils { + + public static <T> @Nullable T getSingleValueFromState( + StateInitializationContext context, + String stateName, + Class<T> valueClass, + T defaultValue) + throws Exception { + ListState<T> state = + context.getOperatorStateStore() + .getUnionListState(new ListStateDescriptor<>(stateName, valueClass)); + + List<T> values = new ArrayList<>(); + state.get().forEach(values::add); + + if (context.isRestored()) { + // For union list state, on restore/redistribution, each operator gets the complete + // list of state elements. As we're storing the same value for each subtask, we hereby + // check if all elements are equal. 
+ for (int i = 1; i < values.size(); i++) { + Preconditions.checkState( + values.get(i).equals(values.get(i - 1)), + "Values in list state are not the same. This is unexpected."); + } + } else { + Preconditions.checkState( + values.isEmpty(), + "Expecting 0 value for a fresh state but found " + + values.size() + + ". This is unexpected."); + } + + if (values.isEmpty()) { + values.add(defaultValue); + } + + T value = values.get(0); + state.update(Lists.newArrayList(value)); + + return value; + } +} diff --git a/runtime/src/main/java/org/apache/flink/agents/runtime/python/operator/PythonActionTask.java b/runtime/src/main/java/org/apache/flink/agents/runtime/python/operator/PythonActionTask.java index 89b39c15..a03c4c80 100644 --- a/runtime/src/main/java/org/apache/flink/agents/runtime/python/operator/PythonActionTask.java +++ b/runtime/src/main/java/org/apache/flink/agents/runtime/python/operator/PythonActionTask.java @@ -54,7 +54,10 @@ public ActionTaskResult invoke(ClassLoader userCodeClassLoader) throws Exception PythonActionExecutor pythonActionExecutor = getPythonActionExecutor(); String pythonGeneratorRef = pythonActionExecutor.executePythonFunction( - (PythonFunction) action.getExec(), (PythonEvent) event, runnerContext); + (PythonFunction) action.getExec(), + (PythonEvent) event, + runnerContext, + key.hashCode()); // If a user-defined action uses an interface to submit asynchronous tasks, it will return a // Python generator object instance upon its first execution. Otherwise, it means that no // asynchronous tasks were submitted and the action has already completed. diff --git a/runtime/src/main/java/org/apache/flink/agents/runtime/python/utils/PythonActionExecutor.java b/runtime/src/main/java/org/apache/flink/agents/runtime/python/utils/PythonActionExecutor.java index 2cef2886..4c087891 100644 --- a/runtime/src/main/java/org/apache/flink/agents/runtime/python/utils/PythonActionExecutor.java +++ b/runtime/src/main/java/org/apache/flink/agents/runtime/python/utils/PythonActionExecutor.java @@ -60,15 +60,18 @@ public class PythonActionExecutor { private final PythonInterpreter interpreter; private final String agentPlanJson; private final JavaResourceAdapter javaResourceAdapter; + private final String jobIdentifier; private Object pythonAsyncThreadPool; public PythonActionExecutor( PythonInterpreter interpreter, String agentPlanJson, - JavaResourceAdapter javaResourceAdapter) { + JavaResourceAdapter javaResourceAdapter, + String jobIdentifier) { this.interpreter = interpreter; this.agentPlanJson = agentPlanJson; this.javaResourceAdapter = javaResourceAdapter; + this.jobIdentifier = jobIdentifier; } public void open() throws Exception { @@ -88,7 +91,10 @@ public void open() throws Exception { * not return a generator. */ public String executePythonFunction( - PythonFunction function, PythonEvent event, RunnerContextImpl runnerContext) + PythonFunction function, + PythonEvent event, + RunnerContextImpl runnerContext, + int hashOfKey) throws Exception { runnerContext.checkNoPendingEvents(); function.setInterpreter(interpreter); @@ -99,7 +105,9 @@ public String executePythonFunction( runnerContext, agentPlanJson, pythonAsyncThreadPool, - javaResourceAdapter); + javaResourceAdapter, + jobIdentifier, + hashOfKey); Object pythonEventObject = interpreter.invoke(CONVERT_TO_PYTHON_OBJECT, event.getEvent());