diff --git a/sdk/python/agentfield/agent_server.py b/sdk/python/agentfield/agent_server.py index 9aa3c0f2..abef348c 100644 --- a/sdk/python/agentfield/agent_server.py +++ b/sdk/python/agentfield/agent_server.py @@ -1,3 +1,6 @@ +from contextlib import AsyncExitStack +from contextlib import asynccontextmanager +from fastapi import FastAPI import asyncio import importlib.util import os @@ -48,7 +51,10 @@ async def agentfield_process_logs(request: Request): if not node_logs.logs_enabled(): return JSONResponse( status_code=404, - content={"error": "logs_disabled", "message": "Process logs API is disabled"}, + content={ + "error": "logs_disabled", + "message": "Process logs API is disabled", + }, ) auth = request.headers.get("authorization") or request.headers.get( "Authorization" @@ -56,7 +62,10 @@ async def agentfield_process_logs(request: Request): if not node_logs.verify_internal_bearer(auth): return JSONResponse( status_code=401, - content={"error": "unauthorized", "message": "Valid Authorization Bearer required"}, + content={ + "error": "unauthorized", + "message": "Valid Authorization Bearer required", + }, ) qp = request.query_params try: @@ -107,7 +116,9 @@ async def debug_tasks(): name = t.get_name() except Exception: name = "?" - buf.write(f"=== Task {name} done={t.done()} cancelled={t.cancelled()} ===\n") + buf.write( + f"=== Task {name} done={t.done()} cancelled={t.cancelled()} ===\n" + ) try: coro = t.get_coro() buf.write(f"coro: {coro!r}\n") @@ -121,7 +132,9 @@ async def debug_tasks(): f" {frame.f_code.co_filename}:{frame.f_lineno} in {frame.f_code.co_name}\n" ) else: - buf.write(" \n") + buf.write( + " \n" + ) except Exception as e: buf.write(f" \n") out.append(buf.getvalue()) @@ -315,7 +328,10 @@ async def approval_webhook(request: Request): approval_request_id = body.get("approval_request_id", "") if not execution_id or not decision: - return {"error": "execution_id and decision are required", "status": 400} + return { + "error": "execution_id and decision are required", + "status": 400, + } # Parse the raw response field (may be a JSON string or dict) raw_response = None @@ -340,9 +356,13 @@ async def approval_webhook(request: Request): # Try to resolve by approval_request_id first, then by execution_id resolved = False if approval_request_id: - resolved = await self.agent._pause_manager.resolve(approval_request_id, result) + resolved = await self.agent._pause_manager.resolve( + approval_request_id, result + ) if not resolved and execution_id: - resolved = await self.agent._pause_manager.resolve_by_execution_id(execution_id, result) + resolved = await self.agent._pause_manager.resolve_by_execution_id( + execution_id, result + ) if self.agent.dev_mode: log_debug( @@ -766,8 +786,27 @@ def serve( # Setup fast lifecycle signal handlers self.agent.agentfield_handler.setup_fast_lifecycle_signal_handlers() - # Add startup event handler for resilient lifecycle - @self.agent.on_event("startup") + @asynccontextmanager + async def internal_lifespan(app: FastAPI): + # Add startup event handler for resilient lifecycle + await startup_resilient_lifecycle() + try: + yield + finally: + # Add shutdown event handler for cleanup + await shutdown_cleanup() + + existing_lifespan = self.agent.router.lifespan_context + + @asynccontextmanager + async def merged_lifespan(app: FastAPI): + async with AsyncExitStack() as stack: + await stack.enter_async_context(internal_lifespan(app)) + await stack.enter_async_context(existing_lifespan(app)) + yield + + self.agent.router.lifespan_context = merged_lifespan + async def startup_resilient_lifecycle(): """Resilient lifecycle startup: connection manager handles AgentField server connectivity""" @@ -848,8 +887,6 @@ def on_disconnected(): "Agent started in local mode - will connect to AgentField server when available" ) - # Add shutdown event handler for cleanup - @self.agent.on_event("shutdown") async def shutdown_cleanup(): """Cleanup all resources when FastAPI shuts down"""