Skip to content

Commit 22f2002

Browse files
committed
fix: fix agent short_term_memory parameter passthrough issue
1 parent 3af70dc commit 22f2002

File tree

2 files changed

+48
-43
lines changed

2 files changed

+48
-43
lines changed

.gitignore

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -199,4 +199,6 @@ cython_debug/
199199
**./output
200200

201201
*.mp3
202-
*.pcm
202+
*.pcm
203+
204+
.trae

veadk/runner.py

Lines changed: 45 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -122,20 +122,20 @@ def intercept_new_message(process_func):
122122
def decorator(func):
123123
@functools.wraps(func)
124124
async def wrapper(
125-
self,
126-
*,
127-
user_id: str,
128-
session_id: str,
129-
new_message: types.Content,
130-
**kwargs,
125+
self,
126+
*,
127+
user_id: str,
128+
session_id: str,
129+
new_message: types.Content,
130+
**kwargs,
131131
):
132132
await pre_run_process(self, process_func, new_message, user_id, session_id)
133133

134134
async for event in func(
135-
user_id=user_id,
136-
session_id=session_id,
137-
new_message=new_message,
138-
**kwargs,
135+
user_id=user_id,
136+
session_id=session_id,
137+
new_message=new_message,
138+
**kwargs,
139139
):
140140
if event is None:
141141
logger.error(f"Event is None with new_message: {new_message}")
@@ -172,10 +172,10 @@ async def wrapper(
172172

173173

174174
def _convert_messages(
175-
messages: RunnerMessage,
176-
app_name: str,
177-
user_id: str,
178-
session_id: str,
175+
messages: RunnerMessage,
176+
app_name: str,
177+
user_id: str,
178+
session_id: str,
179179
) -> list:
180180
"""Convert a VeADK ``RunnerMessage`` into a list of Google ADK messages.
181181
@@ -251,7 +251,7 @@ def _convert_messages(
251251

252252

253253
async def _upload_image_to_tos(
254-
part: genai.types.Part, app_name: str, user_id: str, session_id: str
254+
part: genai.types.Part, app_name: str, user_id: str, session_id: str
255255
) -> None:
256256
"""Upload inline media data in a message part to TOS and rewrite its URL.
257257
@@ -316,15 +316,15 @@ class Runner(ADKRunner):
316316
"""
317317

318318
def __init__(
319-
self,
320-
agent: BaseAgent | Agent | None = None,
321-
short_term_memory: ShortTermMemory | None = None,
322-
app_name: str | None = None,
323-
user_id: str = "veadk_default_user",
324-
upload_inline_data_to_tos: bool = False,
325-
run_processor: "BaseRunProcessor | None" = None,
326-
*args,
327-
**kwargs,
319+
self,
320+
agent: BaseAgent | Agent | None = None,
321+
short_term_memory: ShortTermMemory | None = None,
322+
app_name: str | None = None,
323+
user_id: str = "veadk_default_user",
324+
upload_inline_data_to_tos: bool = False,
325+
run_processor: "BaseRunProcessor | None" = None,
326+
*args,
327+
**kwargs,
328328
) -> None:
329329
"""Initialize a Runner instance.
330330
@@ -355,13 +355,16 @@ def __init__(
355355
Raises:
356356
None
357357
"""
358+
358359
self.user_id = user_id
359360
self.long_term_memory = None
360-
self.short_term_memory = short_term_memory
361361
self.upload_inline_data_to_tos = upload_inline_data_to_tos
362362
credential_service = kwargs.pop("credential_service", None)
363363
session_service = kwargs.pop("session_service", None)
364364
memory_service = kwargs.pop("memory_service", None)
365+
if not short_term_memory:
366+
short_term_memory = agent.short_term_memory
367+
self.short_term_memory = short_term_memory
365368

366369
# Handle run_processor: priority is runner arg > agent.run_processor > NoOpRunProcessor
367370
if run_processor is not None:
@@ -426,14 +429,14 @@ def __init__(
426429
)
427430

428431
async def run(
429-
self,
430-
messages: RunnerMessage,
431-
user_id: str = "",
432-
session_id: str = f"tmp-session-{formatted_timestamp()}",
433-
run_config: RunConfig | None = None,
434-
save_tracing_data: bool = False,
435-
upload_inline_data_to_tos: bool = False,
436-
run_processor: "BaseRunProcessor | None" = None,
432+
self,
433+
messages: RunnerMessage,
434+
user_id: str = "",
435+
session_id: str = f"tmp-session-{formatted_timestamp()}",
436+
run_config: RunConfig | None = None,
437+
save_tracing_data: bool = False,
438+
upload_inline_data_to_tos: bool = False,
439+
run_processor: "BaseRunProcessor | None" = None,
437440
):
438441
"""Run a conversation with multi-turn text and multimodal inputs.
439442
@@ -503,20 +506,20 @@ async def run(
503506
)
504507
async def event_generator():
505508
async for event in self.run_async(
506-
user_id=user_id,
507-
session_id=session_id,
508-
new_message=converted_message,
509-
run_config=run_config,
509+
user_id=user_id,
510+
session_id=session_id,
511+
new_message=converted_message,
512+
run_config=run_config,
510513
):
511514
yield event
512515

513516
async for event in event_generator():
514517
if event.content is not None and event.content.parts:
515518
for part in event.content.parts:
516519
if (
517-
not part.thought
518-
and part.text
519-
and len(part.text.strip()) > 0
520+
not part.thought
521+
and part.text
522+
and len(part.text.strip()) > 0
520523
):
521524
final_output = part.text
522525
break
@@ -630,7 +633,7 @@ def save_tracing_file(self, session_id: str) -> str:
630633
```
631634
"""
632635
if not isinstance(
633-
self.agent, (Agent, SequentialAgent, ParallelAgent, LoopAgent)
636+
self.agent, (Agent, SequentialAgent, ParallelAgent, LoopAgent)
634637
):
635638
logger.warning(
636639
(
@@ -689,7 +692,7 @@ async def save_eval_set(self, session_id: str, eval_set_id: str = "default") ->
689692
return eval_set_path
690693

691694
async def save_session_to_long_term_memory(
692-
self, session_id: str, user_id: str = "", app_name: str = "", **kwargs
695+
self, session_id: str, user_id: str = "", app_name: str = "", **kwargs
693696
) -> None:
694697
"""Save the specified session to long-term memory.
695698

0 commit comments

Comments
 (0)