Skip to content

feat: Allow user to edit agent files/directories and apply changes without reloading anything #1647

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ dependencies = [
"typing-extensions>=4.5, <5",
"tzlocal>=5.3", # Time zone utilities
"uvicorn>=0.34.0", # ASGI server for FastAPI
"watchdog>=6.0.0", # For file change detection and hot reload
"websockets>=15.0.1", # For BaseLlmFlow
# go/keep-sorted end
]
Expand Down
20 changes: 20 additions & 0 deletions src/google/adk/cli/cli_tools_click.py
Original file line number Diff line number Diff line change
Expand Up @@ -600,6 +600,14 @@ def wrapper(*args, **kwargs):
default="127.0.0.1",
show_default=True,
)
@click.option(
"--watch_agents",
type=bool,
help="Optional. Whether to enable live reload for agents changes.",
is_flag=True,
default=False,
show_default=True,
)
@fast_api_common_options()
@adk_services_options()
@deprecated_adk_services_options()
Expand All @@ -625,6 +633,7 @@ def cli_web(
session_db_url: Optional[str] = None, # Deprecated
artifact_storage_uri: Optional[str] = None, # Deprecated
a2a: bool = False,
watch_agents: bool = False,
):
"""Starts a FastAPI server with Web UI for agents.

Expand Down Expand Up @@ -674,6 +683,7 @@ async def _lifespan(app: FastAPI):
a2a=a2a,
host=host,
port=port,
watch_agents=watch_agents,
)
config = uvicorn.Config(
app,
Expand All @@ -694,6 +704,14 @@ async def _lifespan(app: FastAPI):
default="127.0.0.1",
show_default=True,
)
@click.option(
"--watch_agents",
type=bool,
help="Optional. Whether to enable live reload for agents changes.",
is_flag=True,
default=False,
show_default=True,
)
@fast_api_common_options()
@adk_services_options()
@deprecated_adk_services_options()
Expand Down Expand Up @@ -721,6 +739,7 @@ def cli_api_server(
session_db_url: Optional[str] = None, # Deprecated
artifact_storage_uri: Optional[str] = None, # Deprecated
a2a: bool = False,
watch_agents: bool = False,
):
"""Starts a FastAPI server for agents.

Expand Down Expand Up @@ -748,6 +767,7 @@ def cli_api_server(
a2a=a2a,
host=host,
port=port,
watch_agents=watch_agents,
),
host=host,
port=port,
Expand Down
33 changes: 28 additions & 5 deletions src/google/adk/cli/fast_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,10 +84,20 @@
from .utils import evals
from .utils.agent_loader import AgentLoader

from watchdog.observers import Observer
from watchdog.events import FileSystemEventHandler

logger = logging.getLogger("google_adk." + __name__)

_EVAL_SET_FILE_EXTENSION = ".evalset.json"
_should_reload_agents = [False]

class AgentChangeEventHandler(FileSystemEventHandler):
def on_modified(self, event):
if not (event.src_path.endswith(".py") or event.src_path.endswith(".yaml")):
return
logger.info("Change detected in agents directory: %s", event.src_path)
_should_reload_agents[0] = True

class ApiServerSpanExporter(export.SpanExporter):

Expand Down Expand Up @@ -205,8 +215,16 @@ def get_fast_api_app(
host: str = "127.0.0.1",
port: int = 8000,
trace_to_cloud: bool = False,
watch_agents: bool = False,
lifespan: Optional[Lifespan[FastAPI]] = None,
) -> FastAPI:
# Set up a file system watcher to detect changes in the agents directory.
observer = Observer()
if watch_agents:
event_handler = AgentChangeEventHandler()
observer.schedule(event_handler, agents_dir, recursive=True)
observer.start()

# InMemory tracing dict.
trace_dict: dict[str, Any] = {}
session_trace_dict: dict[str, Any] = {}
Expand Down Expand Up @@ -235,14 +253,16 @@ def get_fast_api_app(

@asynccontextmanager
async def internal_lifespan(app: FastAPI):

try:
if lifespan:
async with lifespan(app) as lifespan_context:
yield lifespan_context
else:
yield
finally:
if watch_agents:
observer.stop()
observer.join()
# Create tasks for all runner closures to run concurrently
await cleanup.close_runners(list(runner_dict.values()))

Expand Down Expand Up @@ -503,7 +523,7 @@ async def add_session_to_eval_set(

# Populate the session with initial session state.
initial_session_state = create_empty_state(
agent_loader.load_agent(app_name)
agent_loader.load_agent(app_name, _should_reload_agents)
)

new_eval_case = EvalCase(
Expand Down Expand Up @@ -617,7 +637,7 @@ async def run_eval(
logger.info("Eval ids to run list is empty. We will run all eval cases.")
eval_set_to_evals = {eval_set_id: eval_set.eval_cases}

root_agent = agent_loader.load_agent(app_name)
root_agent = agent_loader.load_agent(app_name, _should_reload_agents)
run_eval_results = []
eval_case_results = []
try:
Expand Down Expand Up @@ -846,7 +866,7 @@ async def get_event_graph(

function_calls = event.get_function_calls()
function_responses = event.get_function_responses()
root_agent = agent_loader.load_agent(app_name)
root_agent = agent_loader.load_agent(app_name, _should_reload_agents)
dot_graph = None
if function_calls:
function_call_highlights = []
Expand Down Expand Up @@ -947,10 +967,13 @@ async def process_messages():

async def _get_runner_async(app_name: str) -> Runner:
"""Returns the runner for the given app."""
if _should_reload_agents[0]:
runner_dict.pop(app_name, None)

envs.load_dotenv_for_agent(os.path.basename(app_name), agents_dir)
if app_name in runner_dict:
return runner_dict[app_name]
root_agent = agent_loader.load_agent(app_name)
root_agent = agent_loader.load_agent(app_name, _should_reload_agents)
runner = Runner(
app_name=app_name,
agent=root_agent,
Expand Down
15 changes: 14 additions & 1 deletion src/google/adk/cli/utils/agent_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,8 +154,21 @@ def _perform_load(self, agent_name: str) -> BaseAgent:
" exposed."
)

def load_agent(self, agent_name: str) -> BaseAgent:
def load_agent(self, agent_name: str, should_reload_agents: Optional[list[bool]] = [False]) -> BaseAgent:
"""Load an agent module (with caching & .env) and return its root_agent."""
if should_reload_agents[0]:
# Clear module cache for the agent and its submodules
keys_to_delete = [
module_name
for module_name in sys.modules
if module_name == agent_name or module_name.startswith(f"{agent_name}.")
]
for key in keys_to_delete:
logger.debug("Deleting module %s", key)
del sys.modules[key]
self._agent_cache.clear()
should_reload_agents[0] = False

if agent_name in self._agent_cache:
logger.debug("Returning cached agent for %s (async)", agent_name)
return self._agent_cache[agent_name]
Expand Down
Loading