Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 13 additions & 16 deletions python/flink_agents/runtime/flink_runner_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,13 +219,23 @@ def create_flink_runner_context(
agent_plan_json: str,
executor: ThreadPoolExecutor,
j_resource_adapter: Any,
job_identifier: str,
key: int,
) -> FlinkRunnerContext:
"""Used to create a FlinkRunnerContext Python object in Pemja environment."""
ctx = FlinkRunnerContext(
return FlinkRunnerContext(
j_runner_context, agent_plan_json, executor, j_resource_adapter
)


def flink_runner_context_switch_action_context(
ctx: FlinkRunnerContext,
job_identifier: str,
key: int,
) -> None:
"""Switch the context of the flink runner context.

The ctx is reused across keyed partitions, the context related to
specific key should be switched when process new action.
"""
backend = ctx.config.get(LongTermMemoryOptions.BACKEND)
# use external vector store based long term memory
if backend == LongTermMemoryBackend.EXTERNAL_VECTOR_STORE:
Expand All @@ -240,19 +250,6 @@ def create_flink_runner_context(
key=str(key),
)
)
return ctx


def create_long_term_memory(
j_runner_context: Any,
agent_plan_json: str,
executor: ThreadPoolExecutor,
j_resource_adapter: Any,
) -> FlinkRunnerContext:
"""Used to create a FlinkRunnerContext Python object in Pemja environment."""
return FlinkRunnerContext(
j_runner_context, agent_plan_json, executor, j_resource_adapter
)


def create_async_thread_pool() -> ThreadPoolExecutor:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,34 +42,61 @@
* actions.
*/
public class RunnerContextImpl implements RunnerContext {
public static class MemoryContext {
private final CachedMemoryStore sensoryMemStore;
private final CachedMemoryStore shortTermMemStore;
private final List<MemoryUpdate> sensoryMemoryUpdates;
private final List<MemoryUpdate> shortTermMemoryUpdates;

public MemoryContext(
CachedMemoryStore sensoryMemStore, CachedMemoryStore shortTermMemStore) {
this.sensoryMemStore = sensoryMemStore;
this.shortTermMemStore = shortTermMemStore;
this.sensoryMemoryUpdates = new LinkedList<>();
this.shortTermMemoryUpdates = new LinkedList<>();
}

public List<MemoryUpdate> getShortTermMemoryUpdates() {
return shortTermMemoryUpdates;
}

public List<MemoryUpdate> getSensoryMemoryUpdates() {
return sensoryMemoryUpdates;
}

public CachedMemoryStore getShortTermMemStore() {
return shortTermMemStore;
}

public CachedMemoryStore getSensoryMemStore() {
return sensoryMemStore;
}
}

protected final List<Event> pendingEvents = new ArrayList<>();
protected final CachedMemoryStore sensoryMemStore;
protected final CachedMemoryStore shortTermMemStore;
protected final FlinkAgentsMetricGroupImpl agentMetricGroup;
protected final Runnable mailboxThreadChecker;
protected final AgentPlan agentPlan;
protected final List<MemoryUpdate> sensoryMemoryUpdates;
protected final List<MemoryUpdate> shortTermMemoryUpdates;

protected MemoryContext memoryContext;
protected String actionName;

public RunnerContextImpl(
CachedMemoryStore sensoryMemStore,
CachedMemoryStore shortTermMemStore,
FlinkAgentsMetricGroupImpl agentMetricGroup,
Runnable mailboxThreadChecker,
AgentPlan agentPlan) {
this.sensoryMemStore = sensoryMemStore;
this.shortTermMemStore = shortTermMemStore;
this.agentMetricGroup = agentMetricGroup;
this.mailboxThreadChecker = mailboxThreadChecker;
this.agentPlan = agentPlan;
this.sensoryMemoryUpdates = new LinkedList<>();
this.shortTermMemoryUpdates = new LinkedList<>();
}

public void setActionName(String actionName) {
public void switchActionContext(String actionName, MemoryContext memoryContext) {
this.actionName = actionName;
this.memoryContext = memoryContext;
}

public MemoryContext getMemoryContext() {
return memoryContext;
}

@Override
Expand Down Expand Up @@ -112,7 +139,7 @@ public void checkNoPendingEvents() {

public List<MemoryUpdate> getSensoryMemoryUpdates() {
mailboxThreadChecker.run();
return List.copyOf(sensoryMemoryUpdates);
return List.copyOf(memoryContext.getSensoryMemoryUpdates());
}

/**
Expand All @@ -124,29 +151,29 @@ public List<MemoryUpdate> getSensoryMemoryUpdates() {
*/
public List<MemoryUpdate> getShortTermMemoryUpdates() {
mailboxThreadChecker.run();
return List.copyOf(shortTermMemoryUpdates);
return List.copyOf(memoryContext.getShortTermMemoryUpdates());
}

@Override
public MemoryObject getSensoryMemory() throws Exception {
mailboxThreadChecker.run();
return new MemoryObjectImpl(
MemoryObject.MemoryType.SENSORY,
sensoryMemStore,
memoryContext.getSensoryMemStore(),
MemoryObjectImpl.ROOT_KEY,
mailboxThreadChecker,
sensoryMemoryUpdates);
memoryContext.getSensoryMemoryUpdates());
}

@Override
public MemoryObject getShortTermMemory() throws Exception {
mailboxThreadChecker.run();
return new MemoryObjectImpl(
MemoryObject.MemoryType.SHORT_TERM,
shortTermMemStore,
memoryContext.getShortTermMemStore(),
MemoryObjectImpl.ROOT_KEY,
mailboxThreadChecker,
shortTermMemoryUpdates);
memoryContext.getShortTermMemoryUpdates());
}

@Override
Expand Down Expand Up @@ -177,11 +204,11 @@ public String getActionName() {
}

public void persistMemory() throws Exception {
sensoryMemStore.persistCache();
shortTermMemStore.persistCache();
memoryContext.getSensoryMemStore().persistCache();
memoryContext.getShortTermMemStore().persistCache();
}

public void clearSensoryMemory() throws Exception {
sensoryMemStore.clear();
memoryContext.getSensoryMemStore().clear();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,9 @@ public class ActionExecutionOperator<IN, OUT> extends AbstractStreamOperator<OUT
// PythonActionExecutor for Python actions
private transient PythonActionExecutor pythonActionExecutor;

// RunnerContext for Python actions
private transient PythonRunnerContextImpl pythonRunnerContext;

// PythonResourceAdapter for Python resources in Java actions
private transient PythonResourceAdapterImpl pythonResourceAdapter;

Expand All @@ -144,6 +147,9 @@ public class ActionExecutionOperator<IN, OUT> extends AbstractStreamOperator<OUT

private final transient MailboxExecutor mailboxExecutor;

// RunnerContext for Java Actions
private transient RunnerContextImpl runnerContext;

// We need to check whether the current thread is the mailbox thread using the mailbox
// processor.
// TODO: This is a temporary workaround. In the future, we should add an interface in
Expand Down Expand Up @@ -174,7 +180,8 @@ public class ActionExecutionOperator<IN, OUT> extends AbstractStreamOperator<OUT

// This in memory map keep track of the runner context for the async action task that having
// been finished
private final transient Map<ActionTask, RunnerContextImpl> actionTaskRunnerContexts;
private final transient Map<ActionTask, RunnerContextImpl.MemoryContext>
actionTaskMemoryContexts;

// Each job can only have one identifier and this identifier must be consistent across restarts.
// We cannot use job id as the identifier here because user may change job id by
Expand All @@ -198,7 +205,7 @@ public ActionExecutionOperator(
this.eventListeners = new ArrayList<>();
this.actionStateStore = actionStateStore;
this.checkpointIdToSeqNums = new HashMap<>();
this.actionTaskRunnerContexts = new HashMap<>();
this.actionTaskMemoryContexts = new HashMap<>();
}

@Override
Expand Down Expand Up @@ -443,12 +450,14 @@ private void processActionTaskForKey(Object key) throws Exception {
} else {
maybeInitActionState(key, sequenceNumber, actionTask.action, actionTask.event);
ActionTask.ActionTaskResult actionTaskResult =
actionTask.invoke(getRuntimeContext().getUserCodeClassLoader());
actionTask.invoke(
getRuntimeContext().getUserCodeClassLoader(),
this.pythonActionExecutor);

// We remove the RunnerContext of the action task from the map after it is finished. The
// RunnerContext will be added later if the action task has a generated action task,
// meaning it is not finished.
actionTaskRunnerContexts.remove(actionTask);
actionTaskMemoryContexts.remove(actionTask);
maybePersistTaskResult(
key,
sequenceNumber,
Expand Down Expand Up @@ -483,7 +492,8 @@ private void processActionTaskForKey(Object key) throws Exception {

// If the action task is not finished, we keep the runner context in the memory for the
// next generated ActionTask to be invoked.
actionTaskRunnerContexts.put(generatedActionTask, actionTask.getRunnerContext());
actionTaskMemoryContexts.put(
generatedActionTask, actionTask.getRunnerContext().getMemoryContext());

actionTasksKState.add(generatedActionTask);
}
Expand Down Expand Up @@ -552,6 +562,9 @@ private void initPythonEnvironment() throws Exception {
pythonEnvironmentManager.open();
EmbeddedPythonEnvironment env = pythonEnvironmentManager.createEnvironment();
pythonInterpreter = env.getInterpreter();
pythonRunnerContext =
new PythonRunnerContextImpl(
this.metricGroup, this::checkMailboxThread, this.agentPlan);
if (containPythonAction) {
initPythonActionExecutor();
} else {
Expand All @@ -568,6 +581,7 @@ private void initPythonActionExecutor() throws Exception {
pythonInterpreter,
new ObjectMapper().writeValueAsString(agentPlan),
javaResourceAdapter,
pythonRunnerContext,
jobIdentifier);
pythonActionExecutor.open();
}
Expand Down Expand Up @@ -752,31 +766,28 @@ private void createAndSetRunnerContext(ActionTask actionTask) {
}

RunnerContextImpl runnerContext;
if (actionTaskRunnerContexts.containsKey(actionTask)) {
runnerContext = actionTaskRunnerContexts.get(actionTask);
} else if (actionTask.action.getExec() instanceof JavaFunction) {
runnerContext =
new RunnerContextImpl(
new CachedMemoryStore(sensoryMemState),
new CachedMemoryStore(shortTermMemState),
metricGroup,
this::checkMailboxThread,
agentPlan);
if (actionTask.action.getExec() instanceof JavaFunction) {
runnerContext = createOrGetRunnerContext(true);
} else if (actionTask.action.getExec() instanceof PythonFunction) {
runnerContext =
new PythonRunnerContextImpl(
new CachedMemoryStore(sensoryMemState),
new CachedMemoryStore(shortTermMemState),
metricGroup,
this::checkMailboxThread,
agentPlan,
pythonActionExecutor);
runnerContext = createOrGetRunnerContext(false);
} else {
throw new IllegalStateException(
"Unsupported action type: " + actionTask.action.getExec().getClass());
}

runnerContext.setActionName(actionTask.action.getName());
RunnerContextImpl.MemoryContext memoryContext;
if (actionTaskMemoryContexts.containsKey(actionTask)) {
// action task for async execution action, should retrieve intermediate results from
// map.
memoryContext = actionTaskMemoryContexts.get(actionTask);
} else {
memoryContext =
new RunnerContextImpl.MemoryContext(
new CachedMemoryStore(sensoryMemState),
new CachedMemoryStore(shortTermMemState));
}

runnerContext.switchActionContext(actionTask.action.getName(), memoryContext);
actionTask.setRunnerContext(runnerContext);
}

Expand Down Expand Up @@ -883,6 +894,24 @@ private void processEligibleWatermarks() throws Exception {
}
}

private RunnerContextImpl createOrGetRunnerContext(Boolean isJava) {
if (isJava) {
if (runnerContext == null) {
runnerContext =
new RunnerContextImpl(
this.metricGroup, this::checkMailboxThread, this.agentPlan);
}
return runnerContext;
} else {
if (pythonRunnerContext == null) {
pythonRunnerContext =
new PythonRunnerContextImpl(
this.metricGroup, this::checkMailboxThread, this.agentPlan);
}
return pythonRunnerContext;
}
}

/** Failed to execute Action task. */
public static class ActionTaskExecutionException extends Exception {
public ActionTaskExecutionException(String message, Throwable cause) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import org.apache.flink.agents.api.Event;
import org.apache.flink.agents.plan.actions.Action;
import org.apache.flink.agents.runtime.context.RunnerContextImpl;
import org.apache.flink.agents.runtime.python.utils.PythonActionExecutor;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

Expand Down Expand Up @@ -87,7 +88,8 @@ public int hashCode() {
}

/** Invokes the action task. */
public abstract ActionTaskResult invoke(ClassLoader userCodeClassLoader) throws Exception;
public abstract ActionTaskResult invoke(
ClassLoader userCodeClassLoader, PythonActionExecutor executor) throws Exception;

public class ActionTaskResult {
private final boolean finished;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import org.apache.flink.agents.api.Event;
import org.apache.flink.agents.plan.JavaFunction;
import org.apache.flink.agents.plan.actions.Action;
import org.apache.flink.agents.runtime.python.utils.PythonActionExecutor;

import static org.apache.flink.util.Preconditions.checkState;

Expand All @@ -36,7 +37,8 @@ public JavaActionTask(Object key, Event event, Action action) {
}

@Override
public ActionTaskResult invoke(ClassLoader userCodeClassLoader) throws Exception {
public ActionTaskResult invoke(ClassLoader userCodeClassLoader, PythonActionExecutor executor)
throws Exception {
LOG.debug(
"Try execute java action {} for event {} with key {}.",
action.getName(),
Expand Down
Loading
Loading