diff --git a/.agents/skills/tool-microservice/SKILL.md b/.agents/skills/modai-tool/SKILL.md similarity index 52% rename from .agents/skills/tool-microservice/SKILL.md rename to .agents/skills/modai-tool/SKILL.md index bf729af..96f474b 100644 --- a/.agents/skills/tool-microservice/SKILL.md +++ b/.agents/skills/modai-tool/SKILL.md @@ -1,5 +1,5 @@ --- -name: tool-microservice +name: modai-tool description: How to create a new tool microservice for modAI-chat. Tools are independent HTTP microservices that expose an OpenAPI spec and a trigger endpoint. They are registered in modAI's tool registry via config.yaml. --- @@ -53,7 +53,7 @@ Start the service and check that `/openapi.json` contains: - `operationId` — unique name for the tool (e.g. `"roll_dice"`) - `summary` or `description` — what the tool does (shown to the LLM) -- `requestBody.content.application/json.schema` — input parameters +- `requestBody.content.application/json.schema` — input parameters (optional if all inputs come from path/header) ```bash curl http://localhost:8001/openapi.json | jq '.paths' @@ -82,14 +82,71 @@ The dice roller produces this structure: } ``` +#### Path Parameters + +If the trigger URL contains path variables (e.g. `/users/{user_id}/orders/{order_id}`), declare them as `"in": "path"` in the `parameters` array. The registry merges them into the tool's parameter schema so the LLM knows to supply them. At invocation time, modAI substitutes their values directly into the URL — they are **not** sent in the request body. 
+ +```json +{ + "/users/{user_id}/orders/{order_id}": { + "get": { + "summary": "Get a user order", + "operationId": "get_user_order", + "parameters": [ + { + "name": "user_id", + "in": "path", + "required": true, + "description": "The user's ID", + "schema": { "type": "string" } + }, + { + "name": "order_id", + "in": "path", + "required": true, + "description": "The order's ID", + "schema": { "type": "integer" } + } + ] + } + } +} +``` + +#### Header Parameters + +Parameters your tool expects as **HTTP request headers** (e.g. `X-Session-Id`) must be declared as `"in": "header"` in the `parameters` array. The registry includes them in the tool's parameter schema; at invocation time modAI forwards their values as HTTP headers — they are **not** sent in the request body. + +```json +{ + "/data": { + "get": { + "summary": "Fetch session data", + "operationId": "fetch_data", + "parameters": [ + { + "name": "X-Session-Id", + "in": "header", + "required": true, + "description": "Active session identifier", + "schema": { "type": "string" } + } + ] + } + } +} +``` + ### 3. Register in modAI config.yaml -Add the tool to the `tool_registry` module's `tools` list in `config.yaml` (and `default_config.yaml` if it should be a default): +Add the tool to the `openapi_tool_registry` module's `tools` list in `config.yaml` (and `default_config.yaml` if it should be a default): ```yaml modules: - tool_registry: - class: modai.modules.tools.tool_registry.HttpToolRegistryModule + openapi_tool_registry: + class: modai.modules.tools.tool_registry_openapi.OpenAPIToolRegistryModule + module_dependencies: + http_client: "http_client" config: tools: - url: http://localhost:8001/roll @@ -97,11 +154,31 @@ modules: ``` Each entry has: -- **`url`**: The full trigger endpoint URL (not the base URL) +- **`url`**: The full trigger endpoint URL (including any path-parameter placeholders, e.g. 
`http://svc:8000/users/{user_id}/orders/{order_id}`) - **`method`**: The HTTP method to invoke the tool (POST, PUT, GET, etc.) The registry derives the base URL from `url` (strips the path) and appends `/openapi.json` to fetch the spec. +#### Hiding known variables with PredefinedVariablesToolRegistryModule + +The default `tool_registry` in `config.yaml` is a `PredefinedVariablesToolRegistryModule` that wraps the OpenAPI registry. When the caller already has a value for a tool parameter (e.g. a session ID that comes from the auth headers), that parameter can be hidden from the LLM's tool definition so the LLM is never asked to supply it. + +- **Direct match**: if a tool has a body/path parameter named `session_id` and the predefined params dict contains `_session_id`, `session_id` is stripped automatically. +- **Configured mapping**: if a tool uses a header parameter named `X-Session-Id` (which differs from the predefined variable name `session_id`), add a `variable_mappings` entry: + +```yaml +modules: + tool_registry: + class: modai.modules.tools.tool_registry_predefined_vars.PredefinedVariablesToolRegistryModule + module_dependencies: + delegate_registry: "openapi_tool_registry" + config: + variable_mappings: + X-Session-Id: session_id # _session_id predefined value → X-Session-Id header +``` + +At invocation time, modAI translates `_session_id` back to `X-Session-Id` and forwards it as an HTTP header — the LLM never sees it. + ### 4. Test the Integration 1. 
Start the tool microservice @@ -133,7 +210,9 @@ Expected: | OpenAPI spec location | `/openapi.json` at service root | | Tool name | `operationId` from the OpenAPI spec | | Tool description | `summary` (preferred) or `description` from the operation | -| Parameters | `requestBody.content.application/json.schema` | +| Body parameters | `requestBody.content.application/json.schema` | +| Path parameters | `"in": "path"` in `parameters` array — substituted into the URL at invocation | +| Header parameters | `"in": "header"` in `parameters` array — forwarded as HTTP headers at invocation | | HTTP method | Choose what's idiomatic (POST for actions, GET for queries, etc.) | | Error handling | Return appropriate HTTP status codes; modAI logs warnings for unreachable tools | @@ -141,6 +220,9 @@ Expected: - **Missing `operationId`**: The tool will be silently skipped. Always set `operationId` on your trigger operation. - **Wrong URL in config**: The `url` must be the full trigger endpoint (e.g. `http://localhost:8001/roll`), not just the base URL. The registry strips the path to derive the base for fetching `/openapi.json`. +- **Path variables in URL but not in spec**: If the configured `url` contains `{param}` placeholders, the corresponding `"in": "path"` parameters must be declared in the spec. Otherwise the LLM won't know to supply them and the URL won't be substituted correctly. +- **Header params missing from `parameters` array**: Header parameters must be declared with `"in": "header"` in the spec — they are not inferred from the request body schema. Undeclared header params will never be forwarded. +- **Header param name mismatch with predefined variables**: If your header param is named `X-Session-Id` but the predefined variable is `_session_id`, the value won't be injected automatically. Add a `variable_mappings` entry in the `tool_registry` config to bridge the naming difference. 
- **Multiple operations**: The registry uses the **first** operation with an `operationId` it finds. Keep one trigger operation per tool service. - **Non-JSON responses**: The LLM expects JSON results. Always return `application/json`. diff --git a/.agents/skills/tool-microservice/references/README.md b/.agents/skills/modai-tool/references/README.md similarity index 100% rename from .agents/skills/tool-microservice/references/README.md rename to .agents/skills/modai-tool/references/README.md diff --git a/.agents/skills/tool-microservice/references/main.py b/.agents/skills/modai-tool/references/main.py similarity index 100% rename from .agents/skills/tool-microservice/references/main.py rename to .agents/skills/modai-tool/references/main.py diff --git a/.agents/skills/tool-microservice/references/pyproject.toml b/.agents/skills/modai-tool/references/pyproject.toml similarity index 100% rename from .agents/skills/tool-microservice/references/pyproject.toml rename to .agents/skills/modai-tool/references/pyproject.toml diff --git a/.github/dependabot.yml b/.github/dependabot.yml index f969913..5b60428 100644 --- a/.github/dependabot.yml +++ b/.github/dependabot.yml @@ -2,13 +2,36 @@ version: 2 updates: - package-ecosystem: uv - directory: '/' + directories: + - '/backend/omni' + - '/backend/tools/dice-roller' + schedule: + interval: weekly + target-branch: 'main' + groups: + all-uv: + patterns: + - "*" + + - package-ecosystem: npm + directories: + - '/frontend_omni' + - '/e2e_tests/tests_omni_full' + - '/e2e_tests/tests_omni_light' schedule: interval: weekly target-branch: 'main' + groups: + all-pnpm: + patterns: + - "*" - package-ecosystem: 'github-actions' directory: '/' schedule: interval: weekly target-branch: 'main' + groups: + all-gha: + patterns: + - "*" diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..78b9a4e --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,3 @@ +# Workspace root – prek discovers nested 
project configs automatically. +# See https://prek.j178.dev/workspace/ +repos: [] diff --git a/AGENTS.md b/AGENTS.md index f7429e4..000da65 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -188,6 +188,36 @@ pnpm check # Run linter For comprehensive e2e testing best practices and patterns, refer to `e2e_tests/BEST_PRACTICES.md`. +## Git Hooks (prek) + +Pre-commit hooks use prek's **workspace mode**. Each sub-project has its own +`.pre-commit-config.yaml` (with `orphan: true`), discovered automatically from +the root `.pre-commit-config.yaml`. + +Layout: +- `.pre-commit-config.yaml` — workspace root (empty, enables discovery) +- `backend/omni/.pre-commit-config.yaml` — ruff format + ruff check +- `backend/tools/dice-roller/.pre-commit-config.yaml` — ruff format + ruff check +- `frontend_omni/.pre-commit-config.yaml` — biome check +- `e2e_tests/tests_omni_full/.pre-commit-config.yaml` — biome check +- `e2e_tests/tests_omni_light/.pre-commit-config.yaml` — biome check + +Hooks check (but do not auto-fix) on every commit and fail if issues remain. + +**One-time setup** after cloning: +```bash +uv tool install prek # Install prek binary (skip if already installed) +prek install # Wire hooks into .git/hooks/pre-commit +``` + +**Manual run:** +```bash +prek run # Run on staged files only +prek run --all-files # Run on all files +prek run backend/omni/ # Run hooks for a specific project +prek run ruff-check # Run a single hook by id +``` + ## Development Workflow 1. 
**Read Architecture**: Always read relevant architecture docs first diff --git a/backend/omni/.pre-commit-config.yaml b/backend/omni/.pre-commit-config.yaml new file mode 100644 index 0000000..aaa6699 --- /dev/null +++ b/backend/omni/.pre-commit-config.yaml @@ -0,0 +1,18 @@ +orphan: true + +repos: + - repo: local + hooks: + - id: ruff-format + name: ruff format check + language: system + entry: uv run ruff format --check src + always_run: true + pass_filenames: false + + - id: ruff-check + name: ruff check + language: system + entry: uv run ruff check src + always_run: true + pass_filenames: false diff --git a/backend/omni/config.yaml b/backend/omni/config.yaml index f8173e1..4a091aa 100644 --- a/backend/omni/config.yaml +++ b/backend/omni/config.yaml @@ -62,8 +62,13 @@ modules: session: "session" user_settings_store: "user_settings_store" - tool_registry: - class: modai.modules.tools.tool_registry.HttpToolRegistryModule + http_client: + class: modai.modules.http_client.httpx_http_client_module.HttpxHttpClientModule + + openapi_tool_registry: + class: modai.modules.tools.tool_registry_openapi.OpenAPIToolRegistryModule + module_dependencies: + http_client: "http_client" config: tools: - url: http://localhost:8001/roll @@ -75,6 +80,17 @@ modules: # - url: http://web-search-service:8000/search # method: PUT + tool_registry: + class: modai.modules.tools.tool_registry_predefined_vars.PredefinedVariablesToolRegistryModule + module_dependencies: + delegate_registry: "openapi_tool_registry" + config: + # Map tool parameter names that differ from the predefined variable name. + # Format: <tool_param_name>: <predefined_variable_name> + # Example: the predefined _session_id fills the tool's X-Session-Id header param. 
+ # variable_mappings: + # X-Session-Id: session_id + tools_web: class: modai.modules.tools.tools_web_module.OpenAIToolsWebModule module_dependencies: diff --git a/backend/omni/docs/architecture/tools.md b/backend/omni/docs/architecture/tools.md index 1439041..4052369 100644 --- a/backend/omni/docs/architecture/tools.md +++ b/backend/omni/docs/architecture/tools.md @@ -1,13 +1,13 @@ # Tools Architecture ## 1. Overview -- **Architecture Style**: Microservice-based tool system with a registry and a web layer that serves tools in OpenAI format +- **Architecture Style**: Microservice-based tool system with a generic tool abstraction and a web layer that serves tools in OpenAI format - **Design Principles**: - - Tools are independent microservices — each tool is a standalone service with its own OpenAPI spec - - OpenAPI as contract — the tool's definition (parameters, description, endpoints) is read from its OpenAPI spec - - Registry as aggregator — the Tool Registry fetches and holds OpenAPI specs + invocation metadata (url, method) without transformation - - Web layer transforms — the Tools Web Module converts OpenAPI specs into OpenAI function-calling format for the frontend - - Chat Agent resolves via registry — when the LLM emits a tool call, the Chat Agent looks up the tool's url and method from the registry to make the HTTP call + - Tools are LLM-agnostic — a tool has a definition (name, description, parameters) and a run capability; neither is tied to any LLM API + - OpenAPI as one registry implementation — the OpenAPI registry fetches specs from microservices and creates tool instances that handle HTTP invocation internally + - Registry encapsulates invocation — callers do not need to know about URLs, methods, or HTTP; they just run a tool with parameters + - Web layer transforms definitions — the Tools Web Module converts tool definitions to OpenAI function-calling format + - Extensible — new registry implementations can plug in any tool backend (HTTP, gRPC, 
in-process functions, etc.) without changing callers - **Quality Attributes**: Decoupled, language-agnostic, independently deployable, discoverable ## 2. Tool Microservice Convention @@ -16,7 +16,8 @@ Each tool is a standalone microservice that follows these conventions: 1. **HTTP endpoint**: The tool is triggered via an HTTP request. Each tool chooses the HTTP method (PUT, POST, GET, etc.) that is most idiomatic for its use case. The method is configured in the tool registry. 2. **OpenAPI spec**: The microservice exposes its OpenAPI specification (typically at `/openapi.json`). This spec documents all endpoints, including the trigger endpoint with its parameters, description, and response schema. -3. **Independence**: Tools have no dependency on modAI. They are plain HTTP microservices that can be developed, deployed, and tested independently in any language/framework. +3. **Path parameters**: If the trigger URL contains path parameters (e.g. `/users/{user_id}/orders/{order_id}`), they must be declared in the OpenAPI spec as `parameters` with `"in": "path"`. The registry includes them in the tool definition so the LLM knows to supply them; at invocation time they are substituted into the URL and are not forwarded in the request body. +4. **Independence**: Tools have no dependency on modAI. They are plain HTTP microservices that can be developed, deployed, and tested independently in any language/framework. 
### Example Tool Microservice (OpenAPI spec) ```json @@ -70,6 +71,49 @@ Each tool is a standalone microservice that follows these conventions: } ``` +### Example Tool Microservice with Path Parameters (OpenAPI spec) +```json +{ + "openapi": "3.1.0", + "info": { + "title": "Order Tool", + "version": "1.0.0", + "description": "Retrieve a specific user order" + }, + "paths": { + "/users/{user_id}/orders/{order_id}": { + "get": { + "summary": "Get a user order", + "operationId": "get_user_order", + "parameters": [ + { + "name": "user_id", + "in": "path", + "required": true, + "description": "The user's ID", + "schema": { "type": "string" } + }, + { + "name": "order_id", + "in": "path", + "required": true, + "description": "The order's ID", + "schema": { "type": "integer" } + } + ], + "responses": { + "200": { + "description": "Order details" + } + } + } + } + } +} +``` + +The registry will build a tool definition that includes `user_id` and `order_id` as required parameters. When the LLM calls the tool with `{"user_id": "alice", "order_id": 42}`, the registry substitutes those values into the URL (`/users/alice/orders/42`) and sends an empty JSON body. + ## 3. System Context ```mermaid @@ -80,54 +124,96 @@ flowchart TD TR -->|GET /openapi.json| TS2[Tool Service B] FE -->|POST /api/responses with tool names| CR[Chat Router] CR --> CA[Chat Agent Module] - CA -->|lookup tool by name| TR - CA -->|HTTP trigger| TS1 - CA -->|HTTP trigger| TS2 + CA -->|get_tool_by_name| TR + CA -->|tool.run params | TS1 + CA -->|tool.run params | TS2 ``` **Flow**: 1. Frontend calls `GET /api/tools` to discover all available tools -2. Tools Web Module asks the Tool Registry for all tools (OpenAPI specs + url + method) -3. Tools Web Module transforms the OpenAPI specs into **OpenAI function-calling format** and returns them to the frontend -4. User selects which tools to enable for a chat session -5. Frontend sends `POST /api/responses` with tool names (as received from `GET /api/tools`) -6. 
When the LLM emits a `tool_call` with a function name, the Chat Agent **looks up** that name in the Tool Registry to get the tool's url and method -7. The Chat Agent sends an HTTP request to the tool's microservice endpoint and returns the result to the LLM +2. Tools Web Module asks the Tool Registry for all tools and converts their definitions to OpenAI format +3. User selects which tools to enable for a chat session +4. Frontend sends `POST /api/responses` with tool names (as received from `GET /api/tools`) +5. When the LLM emits a `tool_call`, the Chat Agent looks up the tool by name in the registry +6. The Chat Agent runs the tool with the LLM-supplied parameters — the tool directly invokes the microservice; no registry involvement at invocation time ## 4. Module Architecture -### 4.1 Tool Registry Module (Plain Module) +### 4.1 Core Abstractions + +**Tool Definition** — a value object with three fields: +- `name` — unique identifier derived from the OpenAPI `operationId` +- `description` — human-readable text describing what the tool does +- `parameters` — a fully-resolved JSON Schema (all `$ref` pointers inlined) describing the input + +A tool definition contains enough information to construct an LLM tool call but is not tied to any specific LLM API. + +**Tool** — pairs a definition with execution capability: +- Exposes its `definition` (read-only) +- Provides a `run(params)` operation that executes the tool with the given parameters and returns the result + +#### Reserved `_`-prefixed keys in `params` + +Callers may inject caller-supplied metadata into the `params` dict using keys prefixed with `_`. These keys are **never** forwarded to the tool microservice's JSON body — tool implementations must extract and consume them before sending the request. + +Currently defined reserved keys: -**Purpose**: Aggregates OpenAPI specs from all configured tool microservices and provides tool lookup for invocation. 
+| Key | Type | Description | +|---|---|---| +| `_bearer_token` | `str \| None` | Forwarded as `Authorization: Bearer <token>` HTTP header | + +This convention keeps the `Tool.run` interface stable while allowing callers to pass through transport-level concerns (auth, tracing, etc.) without requiring interface changes. + +### 4.2 Tool Registry Module (Plain Module) + +**Purpose**: Aggregates tools from all configured sources and provides lookup by name. **Responsibilities**: -- Maintain a list of configured tool microservice URLs and their HTTP methods -- Fetch the OpenAPI spec from each tool microservice -- Return all tools with their OpenAPI specs, urls, and methods (unmodified) -- Provide lookup by tool name — given a function name (derived from `operationId`), return the tool's url, method, and parameters -- Handle unavailable tool services gracefully (skip with warning, don't fail the whole request) +- Return all available tools via `get_tools` +- Look up a tool by name via `get_tool_by_name` +- Handle unavailable tool services gracefully (skip with warning, don't fail) + +**No module dependencies**: The registry does not depend on other modAI modules. -**No module dependencies**: The registry does not depend on other modAI modules. Tool microservices are external HTTP services configured via the module's config. +### 4.3 OpenAPI Tool Registry (concrete implementation) -### 4.2 Tools Web Module (Web Module) +**Purpose**: Concrete registry implementation that harvests OpenAPI specs from configured HTTP microservices. 
+ +**How it works**: +- On each call to `get_tools`, fetches `/openapi.json` from each configured service +- Extracts the tool definition from the spec: + - `operationId` → name + - `summary`/`description` → description + - Request body schema → parameters (all `$ref` resolved inline) + - Path parameters (`in: path`) from the `parameters` array are merged into the schema's `properties` and `required` lists so the LLM is told to supply them +- Each resulting tool's `run` operation: + 1. Resolves `{param_name}` placeholders in the configured URL by substituting values from the supplied `params` dict + 2. Sends the remaining parameters as the JSON request body + 3. Makes an HTTP call to the resolved URL using the configured method + +**Configuration** — each tool entry specifies: +- `url`: The full trigger endpoint URL of the tool microservice +- `method`: The HTTP method to use when invoking the tool (e.g. POST, PUT, GET) -**Purpose**: Exposes `GET /api/tools` endpoint. Transforms tool definitions from OpenAPI format into OpenAI function-calling format so the frontend can use them directly. +### 4.4 Tools Web Module (Web Module) -**Dependencies**: Tool Registry Module (injected via `module_dependencies`) +**Purpose**: Exposes `GET /api/tools` endpoint. Transforms tool definitions into OpenAI function-calling format. + +**Dependencies**: Tool Registry Module **Responsibilities**: - Expose `GET /api/tools` endpoint -- Call the Tool Registry to get all available tools with their OpenAPI specs -- Transform each tool's OpenAPI spec into OpenAI function-calling format (see section 5.1) -- Return the transformed tools to the frontend +- Call the Tool Registry to get all available tools +- Convert each tool definition to OpenAI function-calling format +- Return the transformed tool definitions to the frontend -### 4.3 Chat Agent Module (existing, updated dependency) +### 4.5 Chat Agent Module -The Chat Agent Module receives a `tool_registry` dependency. 
When the LLM emits a `tool_call`: +The Chat Agent Module receives a tool registry dependency. When the LLM emits a `tool_call`: 1. Extract the function name from the tool call -2. Look up the function name in the Tool Registry to get url + method -3. Send the HTTP request with the tool call arguments to the tool's endpoint -4. Return the response to the LLM +2. Look up the tool by name in the registry +3. Run the tool with the LLM-supplied parameters — no HTTP knowledge needed in the chat module +4. Return the result to the LLM ## 5. API Endpoints @@ -137,14 +223,12 @@ The Chat Agent Module receives a `tool_registry` dependency. When the LLM emits **Endpoint**: `GET /api/tools` -**Purpose**: Returns all available tools in OpenAI function-calling format. The frontend can pass these tool definitions directly when calling `/api/responses`. - -The Tools Web Module fetches tool data from the registry (OpenAPI specs + metadata) and transforms each tool into OpenAI format. +**Purpose**: Returns all available tools in OpenAI function-calling format. -**OpenAPI → OpenAI Transformation**: -- `operationId` → `function.name` -- `summary` (or `description`) → `function.description` -- Request body `schema` → `function.parameters` +**Tool Definition → OpenAI Transformation**: +- `name` → `function.name` +- `description` → `function.description` +- `parameters` → `function.parameters` (already resolved, no `$ref`) **Response Format (200 OK)**: ```json @@ -166,73 +250,19 @@ The Tools Web Module fetches tool data from the registry (OpenAPI specs + metada "required": ["expression"] } } - }, - { - "type": "function", - "function": { - "name": "web_search", - "description": "Search the web for current information", - "parameters": { - "type": "object", - "properties": { - "query": { - "type": "string", - "description": "Search query" - } - }, - "required": ["query"] - } - } } ] } ``` -If a tool service is unreachable, it is omitted from the response and a warning is logged. 
The endpoint never fails due to a single unavailable tool. +If a tool service is unreachable, it is omitted from the response and a warning is logged. ## 6. Configuration -```yaml -modules: - tool_registry: - class: modai.modules.tools.tool_registry.HttpToolRegistryModule - config: - tools: - - url: http://calculator-service:8000/calculate - method: POST - - url: http://web-search-service:8000/search - method: PUT - - tools_web: - class: modai.modules.tools.tools_web_module.ToolsWebModule - module_dependencies: - tool_registry: tool_registry - - chat_openai: - class: modai.modules.chat.openai_agent_chat.StrandsAgentChatModule - module_dependencies: - llm_provider_module: openai_model_provider - tool_registry: tool_registry -``` - -Each entry in `tools` (on the registry) has: +The tool registry is configured with a list of tool microservice endpoints. Each entry has: - `url`: The full trigger endpoint URL of the tool microservice - `method`: The HTTP method used to invoke the tool (e.g. PUT, POST, GET) -The registry derives the base URL from `url` to fetch the OpenAPI spec (appending `/openapi.json` to the base). - -## 7. Design Decisions - -- **Decision 1**: Tools are independent microservices, not modAI modules. - - **Rationale**: Maximum decoupling — tools can be written in any language, deployed independently, and reused across systems. - - **Trade-off**: Network overhead for spec fetching and tool invocation vs. in-process calls. - -- **Decision 2**: OpenAPI spec is fetched at request time, not cached. - - **Rationale**: Simplicity — no cache invalidation needed. Tool services can update their specs and changes are immediately visible. - - **Trade-off**: Higher latency on `GET /api/tools`. Can be optimized with caching later if needed. - -- **Decision 3**: The Tool Registry stores OpenAPI specs unmodified. The Tools Web Module transforms them. 
- **Rationale**: Separation of concerns — the registry is a pure aggregator, the web module handles format conversion. This keeps each module focused on one job. +The registry derives the base URL from `url` (strips the path) and appends `/openapi.json` to fetch the spec. -- **Decision 4**: Tool name for lookup is derived from `operationId` in the OpenAPI spec. - - **Rationale**: `operationId` is a standard OpenAPI field designed to uniquely identify an operation, making it a natural tool name. +See `config.yaml` and `default_config.yaml` for concrete configuration examples. diff --git a/backend/omni/docs/learnings/INSTRUCTION_UPDATES.md b/backend/omni/docs/learnings/INSTRUCTION_UPDATES.md index 5ffb87a..64415e3 100644 --- a/backend/omni/docs/learnings/INSTRUCTION_UPDATES.md +++ b/backend/omni/docs/learnings/INSTRUCTION_UPDATES.md @@ -30,7 +30,12 @@ This file tracks corrections provided by the user to improve future performance. - **Correction**: All tests must exercise only the public interface. If internal logic needs coverage, improve public-API tests, not private-function tests. - **New Rule**: NO WHITEBOX TESTING. Never test `_prefixed` functions or assert on private object state. A test that does so is incorrect by definition and must be rewritten to go through the public API. Updated `AGENTS.md`. -### 2026-03-04 - llmock v2: no `/v1` path prefix, use trailing slash in base_url +### 2026-03-13 - Injected metadata in tool params uses `_` prefix +- **Convention**: When the caller needs to pass transport-level metadata (e.g. bearer token) into `Tool.run`, inject it as a `_`-prefixed key in the `params` dict (e.g. `_bearer_token`). The implementation pops those keys before building the request body; they are never forwarded as JSON payload. +- **New Rule**: Any caller-injected, non-payload property passed via `params` MUST use a `_`-prefixed key. Document new keys in `docs/architecture/tools.md` under "Reserved `_`-prefixed keys". + +### 2026-03-04 - llmock v2: no `/v1` path prefix, use trailing slash in base_url 
+ - **Mistake**: Passed `base_url = f"{root_url}/v1"` — the updated llmock no longer mounts routes under `/v1`. - **Correction**: All llmock endpoints are now at the root (`/chat/completions`, `/models`, `/health`). Pass `base_url = f"{root_url}/"` (trailing slash) so the OpenAI SDK does not append `/v1`. - **New Rule**: Always use `base_url = "http://<host>:<port>/"` (trailing slash) when connecting to llmock. The SKILL.md has been updated. diff --git a/backend/omni/src/modai/__tests__/test_module_loader.py b/backend/omni/src/modai/__tests__/test_module_loader.py index 9a4543e..4869cbd 100644 --- a/backend/omni/src/modai/__tests__/test_module_loader.py +++ b/backend/omni/src/modai/__tests__/test_module_loader.py @@ -58,11 +58,28 @@ def test_load_module_disabled(caplog): def test_load_module_none_config(): - """Test loading a module with None config.""" - startup_config = {"modules": {}} + """Test loading a module whose config key is None (e.g. YAML 'config:' with only comments). + + When a YAML config block has all its children commented out, the key is + present in the parsed dict but its value is None - not missing. The + module loader must treat None the same as an empty dict so that modules + that call config.get(...) in their __init__ don't raise AttributeError. 
+ """ + startup_config = { + "modules": { + "foo": { + "class": "modai.__tests__.test_module_loader.DummyModule", + "config": None, # YAML: config: (with only commented-out children) + } + } + } loader = ModuleLoader(startup_config) + loader.load_modules() - assert loader.loaded_modules == {} + dummy_module = loader.get_module("foo") + assert dummy_module is not None + assert isinstance(dummy_module, DummyModule) + assert dummy_module.config == {} def test_load_module_import_error(): diff --git a/backend/omni/src/modai/default_config.yaml b/backend/omni/src/modai/default_config.yaml index b4e066a..70ab966 100644 --- a/backend/omni/src/modai/default_config.yaml +++ b/backend/omni/src/modai/default_config.yaml @@ -59,11 +59,21 @@ modules: session: "session" user_settings_store: "user_settings_store" - tool_registry: - class: modai.modules.tools.tool_registry.HttpToolRegistryModule + http_client: + class: modai.modules.http_client.httpx_http_client_module.HttpxHttpClientModule + + openapi_tool_registry: + class: modai.modules.tools.tool_registry_openapi.OpenAPIToolRegistryModule + module_dependencies: + http_client: "http_client" config: tools: [] + tool_registry: + class: modai.modules.tools.tool_registry_predefined_vars.PredefinedVariablesToolRegistryModule + module_dependencies: + delegate_registry: "openapi_tool_registry" + tools_web: class: modai.modules.tools.tools_web_module.OpenAIToolsWebModule module_dependencies: diff --git a/backend/omni/src/modai/module_loader.py b/backend/omni/src/modai/module_loader.py index 8ef6ccd..94d2797 100644 --- a/backend/omni/src/modai/module_loader.py +++ b/backend/omni/src/modai/module_loader.py @@ -69,7 +69,7 @@ def _load_modules_with_dependencies( continue module_class_path = full_module_config.get("class") - nested_config = full_module_config.get("config", {}) + nested_config = full_module_config.get("config") or {} self._load_module( module_name, module_class_path, module_dependencies, nested_config ) diff --git 
a/backend/omni/src/modai/modules/chat/__tests__/test_strands_agent_chat.py b/backend/omni/src/modai/modules/chat/__tests__/test_strands_agent_chat.py index 41daa43..a6fd82b 100644 --- a/backend/omni/src/modai/modules/chat/__tests__/test_strands_agent_chat.py +++ b/backend/omni/src/modai/modules/chat/__tests__/test_strands_agent_chat.py @@ -32,12 +32,43 @@ ModelProviderResponse, ModelProvidersListResponse, ) -from modai.modules.tools.module import ToolDefinition +from modai.modules.tools.module import Tool, ToolDefinition working_dir = Path.cwd() load_dotenv(find_dotenv(str(working_dir / ".env"))) +def _make_tool( + definition: ToolDefinition, run_url: str = "", run_method: str = "POST" +) -> Tool: + """Create a Tool stub for testing. + + If run_url is provided the tool will make a real HTTP call to that URL + when run() is called; otherwise run() returns an empty string. + """ + url = run_url + method = run_method + + class _TestTool(Tool): + @property + def definition(self) -> ToolDefinition: + return definition + + async def run(self, params: dict[str, Any]) -> Any: + if url: + import httpx + + async with httpx.AsyncClient(timeout=30.0) as client: + response = await client.request( + method=method.upper(), url=url, json=params + ) + response.raise_for_status() + return response.text + return "" + + return _TestTool() + + # --------------------------------------------------------------------------- # llmock container # --------------------------------------------------------------------------- @@ -295,10 +326,10 @@ async def test_stream_completed_response_is_valid(self, llmock_base_url): class TestToolCallingHappyPath: """Tools are resolved from the registry and forwarded to the agent.""" - def _make_tool_registry(self, tool_def: ToolDefinition | None = None) -> Mock: + def _make_tool_registry(self, tool: Tool | None = None) -> Mock: registry = Mock() - if tool_def: - registry.get_tool_by_name = AsyncMock(return_value=tool_def) + if tool: + 
registry.get_tool_by_name = AsyncMock(return_value=tool) else: registry.get_tool_by_name = AsyncMock(return_value=None) return registry @@ -332,12 +363,19 @@ def _capture(request): "/calculate", method="POST" ).respond_with_handler(_capture) - tool_def = ToolDefinition( - url=httpserver.url_for("/calculate"), - method="POST", - openapi_spec=SAMPLE_TOOL_OPENAPI_SPEC, + definition = ToolDefinition( + name="calculate", + description="Evaluate a math expression", + parameters={ + "type": "object", + "properties": {"expression": {"type": "string"}}, + "required": ["expression"], + }, ) - registry = self._make_tool_registry(tool_def) + tool = _make_tool( + definition, run_url=httpserver.url_for("/calculate"), run_method="POST" + ) + registry = self._make_tool_registry(tool) module = _llmock_module(llmock_base_url, tool_registry=registry) body = { @@ -378,12 +416,19 @@ def _capture(request): "/calculate", method="POST" ).respond_with_handler(_capture) - tool_def = ToolDefinition( - url=httpserver.url_for("/calculate"), - method="POST", - openapi_spec=SAMPLE_TOOL_OPENAPI_SPEC, + definition = ToolDefinition( + name="calculate", + description="Evaluate a math expression", + parameters={ + "type": "object", + "properties": {"expression": {"type": "string"}}, + "required": ["expression"], + }, ) - registry = self._make_tool_registry(tool_def) + tool = _make_tool( + definition, run_url=httpserver.url_for("/calculate"), run_method="POST" + ) + registry = self._make_tool_registry(tool) module = _llmock_module(llmock_base_url, tool_registry=registry) body = { @@ -589,15 +634,23 @@ async def test_unknown_tool_is_silently_skipped(self, llmock_base_url): assert result.status == "completed" @pytest.mark.asyncio - async def test_tool_with_invalid_openapi_spec_is_skipped(self, llmock_base_url): - """A tool whose OpenAPI spec has no valid operation is skipped.""" - bad_tool_def = ToolDefinition( - url="http://broken:8000/noop", - method="POST", - openapi_spec={"paths": {}}, # no 
operations - ) + async def test_tool_run_error_is_handled_gracefully(self, llmock_base_url): + """A tool whose run() raises an error does not crash the agent.""" + + class _FailingTool(Tool): + @property + def definition(self) -> ToolDefinition: + return ToolDefinition( + name="broken_tool", + description="Broken", + parameters={"type": "object", "properties": {}}, + ) + + async def run(self, params: dict[str, Any]) -> Any: + raise RuntimeError("tool exploded") + registry = Mock() - registry.get_tool_by_name = AsyncMock(return_value=bad_tool_def) + registry.get_tool_by_name = AsyncMock(return_value=_FailingTool()) module = _llmock_module(llmock_base_url, tool_registry=registry) body = { @@ -613,6 +666,84 @@ async def test_tool_with_invalid_openapi_spec_is_skipped(self, llmock_base_url): assert isinstance(result, openai.types.responses.Response) assert result.status == "completed" + @pytest.mark.asyncio + async def test_additional_tool_properties_from_request_are_injected_into_params( + self, llmock_base_url + ): + """Properties extracted from the request are injected as _-prefixed keys in tool params.""" + captured_params: list[dict] = [] + + class _CapturingTool(Tool): + @property + def definition(self) -> ToolDefinition: + return ToolDefinition( + name="calculate", + description="Evaluate a math expression", + parameters={ + "type": "object", + "properties": {"expression": {"type": "string"}}, + }, + ) + + async def run(self, params: dict[str, Any]) -> Any: + captured_params.append(dict(params)) + return "42" + + registry = Mock() + registry.get_tool_by_name = AsyncMock(return_value=_CapturingTool()) + + module = _llmock_module(llmock_base_url, tool_registry=registry) + body = { + "model": "myprovider/gpt-4o", + "input": "call tool 'calculate' with '{\"expression\": \"1+1\"}'", + "tools": [{"type": "function", "function": {"name": "calculate"}}], + } + + await module.generate_response( + _make_request(authorization="Bearer mytoken"), body + ) + + assert 
len(captured_params) >= 1 + assert captured_params[0].get("_bearer_token") == "mytoken" + + @pytest.mark.asyncio + async def test_no_additional_tool_properties_injected_when_absent_in_request( + self, llmock_base_url + ): + """When the request carries no extractable properties, no _-prefixed keys are injected.""" + captured_params: list[dict] = [] + + class _CapturingTool(Tool): + @property + def definition(self) -> ToolDefinition: + return ToolDefinition( + name="calculate", + description="Evaluate a math expression", + parameters={ + "type": "object", + "properties": {"expression": {"type": "string"}}, + }, + ) + + async def run(self, params: dict[str, Any]) -> Any: + captured_params.append(dict(params)) + return "42" + + registry = Mock() + registry.get_tool_by_name = AsyncMock(return_value=_CapturingTool()) + + module = _llmock_module(llmock_base_url, tool_registry=registry) + body = { + "model": "myprovider/gpt-4o", + "input": "call tool 'calculate' with '{\"expression\": \"1+1\"}'", + "tools": [{"type": "function", "function": {"name": "calculate"}}], + } + + await module.generate_response(_make_request(), body) + + assert len(captured_params) >= 1 + assert not any(k.startswith("_") for k in captured_params[0]) + @pytest.mark.asyncio async def test_tool_registry_error_propagates(self): """If the tool registry raises, the error propagates.""" @@ -645,13 +776,20 @@ async def test_tool_invocation_http_error_agent_handles_gracefully( message is a user message). On the next turn the last message is the tool result, so MirrorStrategy takes over and the agent completes. 
""" - tool_def = ToolDefinition( - url="http://localhost:1/calculate", # unreachable - method="POST", - openapi_spec=SAMPLE_TOOL_OPENAPI_SPEC, + definition = ToolDefinition( + name="calculate", + description="Evaluate a math expression", + parameters={ + "type": "object", + "properties": {"expression": {"type": "string"}}, + }, + ) + # run_url points to an unreachable port — tool.run() will raise + tool = _make_tool( + definition, run_url="http://localhost:1/calculate", run_method="POST" ) registry = Mock() - registry.get_tool_by_name = AsyncMock(return_value=tool_def) + registry.get_tool_by_name = AsyncMock(return_value=tool) module = _llmock_module(llmock_base_url, tool_registry=registry) body = { @@ -676,13 +814,19 @@ async def test_tool_invocation_success_request_sent_to_tool( """ httpserver.expect_oneshot_request("/calculate").respond_with_json({"result": 4}) - tool_def = ToolDefinition( - url=httpserver.url_for("/calculate"), - method="POST", - openapi_spec=SAMPLE_TOOL_OPENAPI_SPEC, + definition = ToolDefinition( + name="calculate", + description="Evaluate a math expression", + parameters={ + "type": "object", + "properties": {"expression": {"type": "string"}}, + }, + ) + tool = _make_tool( + definition, run_url=httpserver.url_for("/calculate"), run_method="POST" ) registry = Mock() - registry.get_tool_by_name = AsyncMock(return_value=tool_def) + registry.get_tool_by_name = AsyncMock(return_value=tool) module = _llmock_module(llmock_base_url, tool_registry=registry) body = { @@ -698,14 +842,18 @@ async def test_tool_invocation_success_request_sent_to_tool( @pytest.mark.asyncio async def test_partial_tools_resolved_when_some_missing(self, llmock_base_url): """When some tools are found and others not, only found tools are used.""" - calc_def = ToolDefinition( - url="http://calc:8000/calculate", - method="POST", - openapi_spec=SAMPLE_TOOL_OPENAPI_SPEC, + calc_definition = ToolDefinition( + name="calculate", + description="Evaluate a math expression", + 
parameters={ + "type": "object", + "properties": {"expression": {"type": "string"}}, + }, ) + calc_tool = _make_tool(calc_definition) registry = Mock() registry.get_tool_by_name = AsyncMock( - side_effect=lambda name: calc_def if name == "calculate" else None + side_effect=lambda name, **_: calc_tool if name == "calculate" else None ) module = _llmock_module(llmock_base_url, tool_registry=registry) @@ -818,8 +966,12 @@ def _make_dependencies( return ModuleDependencies(modules) -def _make_request() -> Request: - return Mock(spec=Request) +def _make_request(authorization: str | None = None) -> Request: + mock = Mock(spec=Request) + mock.headers.get.side_effect = lambda name, default="": ( + authorization if name == "Authorization" and authorization else default + ) + return mock def _llmock_module( diff --git a/backend/omni/src/modai/modules/chat/openai_agent_chat.py b/backend/omni/src/modai/modules/chat/openai_agent_chat.py index 5653e5d..6650bdf 100644 --- a/backend/omni/src/modai/modules/chat/openai_agent_chat.py +++ b/backend/omni/src/modai/modules/chat/openai_agent_chat.py @@ -13,7 +13,6 @@ from datetime import datetime, timezone from typing import Any, AsyncGenerator -import httpx from fastapi import Request from openai.types.responses import ( Response as OpenAIResponse, @@ -37,12 +36,11 @@ ModelProviderModule, ModelProviderResponse, ) -from modai.modules.tools.module import ToolDefinition, ToolRegistryModule +from modai.modules.tools.module import Tool, ToolRegistryModule logger = logging.getLogger(__name__) DEFAULT_SYSTEM_PROMPT = "You are a helpful assistant." 
-TOOL_HTTP_TIMEOUT_SECONDS = 30.0 class StrandsAgentChatModule(ChatLLMModule): @@ -73,7 +71,12 @@ async def generate_response( ) -> OpenAIResponse | AsyncGenerator[OpenAIResponseStreamEvent, None]: provider_name, actual_model = _parse_model(body_json.get("model", "")) provider = await self._resolve_provider(request, provider_name) - tools = await _resolve_request_tools(body_json, self.tool_registry) + additional_tool_properties = _extract_additional_tool_properties(request) + tools = await _resolve_request_tools( + body_json, + self.tool_registry, + additional_tool_properties=additional_tool_properties, + ) agent = _create_agent(provider, actual_model, body_json, tools) user_message = _extract_last_user_message(body_json) @@ -112,6 +115,23 @@ def _parse_model(model: str) -> tuple[str, str]: return parts[0], parts[1] +def _extract_additional_tool_properties(request: Request) -> dict[str, Any]: + """Extract caller-supplied metadata from the request to inject into tool calls. + + Returns a dict of ``_``-prefixed keys that are merged into every tool + invocation's ``params`` dict. Tool implementations consume these reserved + keys (e.g. for HTTP headers) without forwarding them to the payload. + + Currently extracted properties: + - ``_bearer_token``: raw token from the ``Authorization: Bearer`` header. + """ + properties: dict[str, Any] = {} + auth_header = request.headers.get("Authorization", "") + if auth_header.startswith("Bearer "): + properties["_bearer_token"] = auth_header[len("Bearer ") :] + return properties + + def _create_agent( provider: ModelProviderResponse, model_id: str, @@ -215,6 +235,7 @@ def _extract_tool_names(body_json: OpenAICreateResponse) -> list[str]: async def _resolve_request_tools( body_json: OpenAICreateResponse, tool_registry: ToolRegistryModule | None, + additional_tool_properties: dict[str, Any] | None = None, ) -> list[PythonAgentTool]: """Resolve requested tools from the request body into Strands agent tools. 
@@ -222,6 +243,10 @@ async def _resolve_request_tools( is looked up in the registry and wrapped as a ``PythonAgentTool`` that invokes the tool microservice over HTTP. + ``additional_tool_properties`` is a dict of ``_``-prefixed keys extracted + from the request (see ``_extract_additional_tool_properties``) that are + merged into every tool invocation's params dict. + Returns an empty list when no registry is configured or no tools are requested. """ @@ -234,66 +259,56 @@ async def _resolve_request_tools( strands_tools: list[PythonAgentTool] = [] for name in tool_names: - tool_def = await tool_registry.get_tool_by_name(name) - if tool_def is None: + tool = await tool_registry.get_tool_by_name( + name, predefined_params=additional_tool_properties + ) + if tool is None: logger.warning("Tool '%s' not found in registry, skipping", name) continue - strands_tool = _create_http_tool(tool_def) - if strands_tool: - strands_tools.append(strands_tool) + strands_tools.append( + _create_strands_tool( + tool, additional_tool_properties=additional_tool_properties + ) + ) return strands_tools -def _create_http_tool(tool_def: ToolDefinition) -> PythonAgentTool | None: - """Create a Strands ``PythonAgentTool`` that invokes a tool via HTTP. +def _create_strands_tool( + tool: Tool, additional_tool_properties: dict[str, Any] | None = None +) -> PythonAgentTool: + """Wrap a Tool as a Strands ``PythonAgentTool``. - The tool spec (name, description, input schema) is derived from the - tool's OpenAPI spec. The handler makes an HTTP request to the tool's - endpoint and returns the response body to the LLM. 
- """ - operation = _extract_operation(tool_def.openapi_spec) - if not operation: - logger.warning( - "No operation found in OpenAPI spec for tool at %s", tool_def.url - ) - return None - - operation_id = operation.get("operationId", "") - description = operation.get("summary") or operation.get("description", "") + The tool spec (name, description, input schema) comes from the tool's + definition. The handler delegates execution to ``tool.run``. - request_body = operation.get("requestBody", {}) - content = request_body.get("content", {}) - json_content = content.get("application/json", {}) - parameters_schema = json_content.get("schema", {"type": "object", "properties": {}}) + ``additional_tool_properties`` (a dict of ``_``-prefixed keys) is merged + into every invocation's params dict so that tool implementations can pick + up transport-level concerns (auth, tracing, etc.) without the interface + carrying extra args. + """ + definition = tool.definition tool_spec: ToolSpec = { - "name": operation_id, - "description": description, - "inputSchema": {"json": parameters_schema}, + "name": definition.name, + "description": definition.description, + "inputSchema": {"json": definition.parameters}, } - url = tool_def.url - method = tool_def.method - - def _handler(tool_use: ToolUse, **kwargs: Any) -> ToolResult: # noqa: ARG001 - """Invoke the tool microservice over HTTP.""" - params = tool_use["input"] + async def _handler(tool_use: ToolUse, **kwargs: Any) -> ToolResult: # noqa: ARG001 + """Invoke the tool and wrap the result for Strands.""" + params: dict[str, Any] = dict(tool_use["input"]) + if additional_tool_properties: + params.update(additional_tool_properties) try: - with httpx.Client(timeout=TOOL_HTTP_TIMEOUT_SECONDS) as client: - response = client.request( - method=method.upper(), - url=url, - json=params, - ) - response.raise_for_status() - return { - "toolUseId": tool_use["toolUseId"], - "status": "success", - "content": [{"text": response.text}], - } + result 
= await tool.run(params) + return { + "toolUseId": tool_use["toolUseId"], + "status": "success", + "content": [{"text": str(result)}], + } except Exception as exc: - logger.error("Tool '%s' invocation failed: %s", operation_id, exc) + logger.error("Tool '%s' invocation failed: %s", definition.name, exc) return { "toolUseId": tool_use["toolUseId"], "status": "error", @@ -301,22 +316,12 @@ def _handler(tool_use: ToolUse, **kwargs: Any) -> ToolResult: # noqa: ARG001 } return PythonAgentTool( - tool_name=operation_id, + tool_name=definition.name, tool_spec=tool_spec, tool_func=_handler, ) -def _extract_operation(spec: dict[str, Any]) -> dict[str, Any] | None: - """Extract the first operation from an OpenAPI spec.""" - paths = spec.get("paths", {}) - for _path, methods in paths.items(): - for _method, operation in methods.items(): - if isinstance(operation, dict) and "operationId" in operation: - return operation - return None - - # --------------------------------------------------------------------------- # Response builders # --------------------------------------------------------------------------- diff --git a/backend/omni/src/modai/modules/http_client/__init__.py b/backend/omni/src/modai/modules/http_client/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/backend/omni/src/modai/modules/http_client/httpx_http_client_module.py b/backend/omni/src/modai/modules/http_client/httpx_http_client_module.py new file mode 100644 index 0000000..aba5d48 --- /dev/null +++ b/backend/omni/src/modai/modules/http_client/httpx_http_client_module.py @@ -0,0 +1,24 @@ +"""httpx-backed implementation of :class:`HttpClientModule`.""" + +from contextlib import asynccontextmanager +from typing import Any, AsyncContextManager + +import httpx + +from modai.module import ModuleDependencies +from modai.modules.http_client.module import HttpClientModule + + +class HttpxHttpClientModule(HttpClientModule): + """HTTP client factory backed by :class:`httpx.AsyncClient`.""" + + 
def __init__(self, dependencies: ModuleDependencies, config: dict[str, Any]): + super().__init__(dependencies, config) + + def new(self, timeout: float) -> AsyncContextManager[httpx.AsyncClient]: + @asynccontextmanager + async def _create(): + async with httpx.AsyncClient(timeout=timeout) as client: + yield client + + return _create() diff --git a/backend/omni/src/modai/modules/http_client/module.py b/backend/omni/src/modai/modules/http_client/module.py new file mode 100644 index 0000000..ad4e707 --- /dev/null +++ b/backend/omni/src/modai/modules/http_client/module.py @@ -0,0 +1,35 @@ +"""HTTP client module interfaces. + +Modules that need to make outbound HTTP requests should declare a dependency +on :class:`HttpClientModule` and use the factory to obtain a scoped client: + + async with self._http_client.new(timeout=10.0) as client: + response = await client.request("GET", url) + +This keeps HTTP-client instantiation out of business code and makes +units trivially testable without patching ``httpx.AsyncClient``. +""" + +from abc import ABC, abstractmethod +from typing import Any, AsyncContextManager + +import httpx + +from modai.module import ModaiModule, ModuleDependencies + + +class HttpClientModule(ModaiModule, ABC): + """Factory module for scoped HTTP clients. 
+ + Usage:: + + async with self._http_client.new(timeout=10.0) as client: + response = await client.request("POST", url, json=payload) + """ + + def __init__(self, dependencies: ModuleDependencies, config: dict[str, Any]): + super().__init__(dependencies, config) + + @abstractmethod + def new(self, timeout: float) -> AsyncContextManager[httpx.AsyncClient]: + """Return an async context manager that yields a connected :class:`httpx.AsyncClient`.""" diff --git a/backend/omni/src/modai/modules/tools/__tests__/test_tool_registry.py b/backend/omni/src/modai/modules/tools/__tests__/test_tool_registry.py deleted file mode 100644 index 5348fe2..0000000 --- a/backend/omni/src/modai/modules/tools/__tests__/test_tool_registry.py +++ /dev/null @@ -1,352 +0,0 @@ -from unittest.mock import AsyncMock, MagicMock, patch - -import httpx -import pytest - -from modai.module import ModuleDependencies -from modai.modules.tools.module import ToolDefinition -from modai.modules.tools.tool_registry import ( - HttpToolRegistryModule, - _derive_base_url, - _extract_operation_id, - _fetch_openapi_spec, -) - - -SAMPLE_OPENAPI_SPEC = { - "openapi": "3.1.0", - "info": {"title": "Calculator Tool", "version": "1.0.0"}, - "paths": { - "/calculate": { - "post": { - "summary": "Evaluate a math expression", - "operationId": "calculate", - "requestBody": { - "required": True, - "content": { - "application/json": { - "schema": { - "type": "object", - "properties": {"expression": {"type": "string"}}, - "required": ["expression"], - } - } - }, - }, - } - } - }, -} - - -class TestHttpToolRegistryModule: - def _make_module(self, tools: list[dict]) -> HttpToolRegistryModule: - deps = ModuleDependencies() - config = {"tools": tools} - return HttpToolRegistryModule(deps, config) - - @pytest.mark.asyncio - async def test_get_tools_empty_config(self): - module = self._make_module([]) - result = await module.get_tools() - assert result == [] - - @pytest.mark.asyncio - async def 
test_get_tools_returns_specs_from_all_services(self): - module = self._make_module( - [ - {"url": "http://calc:8000/calculate", "method": "POST"}, - {"url": "http://search:8000/search", "method": "PUT"}, - ] - ) - - spec_a = {**SAMPLE_OPENAPI_SPEC, "info": {"title": "Calc", "version": "1.0.0"}} - spec_b = { - **SAMPLE_OPENAPI_SPEC, - "info": {"title": "Search", "version": "1.0.0"}, - } - - mock_response_a = MagicMock() - mock_response_a.status_code = 200 - mock_response_a.raise_for_status = lambda: None - mock_response_a.json.return_value = spec_a - - mock_response_b = MagicMock() - mock_response_b.status_code = 200 - mock_response_b.raise_for_status = lambda: None - mock_response_b.json.return_value = spec_b - - async def mock_get(url, **kwargs): - if "calc" in url: - return mock_response_a - return mock_response_b - - with patch( - "modai.modules.tools.tool_registry.httpx.AsyncClient" - ) as mock_client_cls: - mock_client = AsyncMock() - mock_client.get = mock_get - mock_client.__aenter__ = AsyncMock(return_value=mock_client) - mock_client.__aexit__ = AsyncMock(return_value=False) - mock_client_cls.return_value = mock_client - - result = await module.get_tools() - - assert len(result) == 2 - assert result[0].url == "http://calc:8000/calculate" - assert result[0].method == "POST" - assert result[0].openapi_spec["info"]["title"] == "Calc" - assert result[1].url == "http://search:8000/search" - assert result[1].method == "PUT" - assert result[1].openapi_spec["info"]["title"] == "Search" - - @pytest.mark.asyncio - async def test_get_tools_skips_unavailable_service(self): - module = self._make_module( - [ - {"url": "http://good:8000/run", "method": "POST"}, - {"url": "http://bad:8000/run", "method": "POST"}, - ] - ) - - mock_response_good = MagicMock() - mock_response_good.status_code = 200 - mock_response_good.raise_for_status = lambda: None - mock_response_good.json.return_value = SAMPLE_OPENAPI_SPEC - - async def mock_get(url, **kwargs): - if "bad" in url: - raise 
httpx.ConnectError("Connection refused") - return mock_response_good - - with patch( - "modai.modules.tools.tool_registry.httpx.AsyncClient" - ) as mock_client_cls: - mock_client = AsyncMock() - mock_client.get = mock_get - mock_client.__aenter__ = AsyncMock(return_value=mock_client) - mock_client.__aexit__ = AsyncMock(return_value=False) - mock_client_cls.return_value = mock_client - - result = await module.get_tools() - - assert len(result) == 1 - assert result[0].url == "http://good:8000/run" - - @pytest.mark.asyncio - async def test_specs_are_returned_unmodified(self): - module = self._make_module([{"url": "http://tool:8000/run", "method": "PUT"}]) - - mock_response = MagicMock() - mock_response.status_code = 200 - mock_response.raise_for_status = lambda: None - mock_response.json.return_value = SAMPLE_OPENAPI_SPEC - - async def mock_get(url, **kwargs): - return mock_response - - with patch( - "modai.modules.tools.tool_registry.httpx.AsyncClient" - ) as mock_client_cls: - mock_client = AsyncMock() - mock_client.get = mock_get - mock_client.__aenter__ = AsyncMock(return_value=mock_client) - mock_client.__aexit__ = AsyncMock(return_value=False) - mock_client_cls.return_value = mock_client - - result = await module.get_tools() - - assert result[0].openapi_spec == SAMPLE_OPENAPI_SPEC - - def test_has_no_router(self): - module = self._make_module([]) - assert not hasattr(module, "router") - - def test_stores_tool_services_from_config(self): - tools = [ - {"url": "http://a:8000/run", "method": "POST"}, - {"url": "http://b:9000/exec", "method": "PUT"}, - ] - module = self._make_module(tools) - assert module.tool_services == tools - - def test_defaults_to_empty_tools_list(self): - deps = ModuleDependencies() - module = HttpToolRegistryModule(deps, {}) - assert module.tool_services == [] - - -class TestFetchOpenapiSpec: - @pytest.mark.asyncio - async def test_success(self): - mock_response = MagicMock() - mock_response.status_code = 200 - mock_response.raise_for_status 
= lambda: None - mock_response.json.return_value = SAMPLE_OPENAPI_SPEC - - client = AsyncMock() - client.get = AsyncMock(return_value=mock_response) - - result = await _fetch_openapi_spec(client, "http://tool:8000") - assert result == SAMPLE_OPENAPI_SPEC - client.get.assert_called_once_with("http://tool:8000/openapi.json") - - @pytest.mark.asyncio - async def test_strips_trailing_slash(self): - mock_response = MagicMock() - mock_response.status_code = 200 - mock_response.raise_for_status = lambda: None - mock_response.json.return_value = SAMPLE_OPENAPI_SPEC - - client = AsyncMock() - client.get = AsyncMock(return_value=mock_response) - - await _fetch_openapi_spec(client, "http://tool:8000/") - client.get.assert_called_once_with("http://tool:8000/openapi.json") - - @pytest.mark.asyncio - async def test_http_error_returns_none(self): - mock_response = MagicMock() - mock_response.status_code = 500 - mock_response.raise_for_status.side_effect = httpx.HTTPStatusError( - "Server Error", - request=httpx.Request("GET", "http://tool:8000/openapi.json"), - response=mock_response, - ) - - client = AsyncMock() - client.get = AsyncMock(return_value=mock_response) - - result = await _fetch_openapi_spec(client, "http://tool:8000") - assert result is None - - @pytest.mark.asyncio - async def test_connection_error_returns_none(self): - client = AsyncMock() - client.get = AsyncMock(side_effect=httpx.ConnectError("Connection refused")) - - result = await _fetch_openapi_spec(client, "http://tool:8000") - assert result is None - - @pytest.mark.asyncio - async def test_unexpected_error_returns_none(self): - client = AsyncMock() - client.get = AsyncMock(side_effect=RuntimeError("something went wrong")) - - result = await _fetch_openapi_spec(client, "http://tool:8000") - assert result is None - - -class TestDeriveBaseUrl: - def test_strips_path(self): - assert _derive_base_url("http://calc:8000/calculate") == "http://calc:8000" - - def test_strips_nested_path(self): - assert 
_derive_base_url("http://host:9000/api/v1/run") == "http://host:9000" - - def test_no_path(self): - assert _derive_base_url("http://tool:8000") == "http://tool:8000" - - def test_trailing_slash(self): - assert _derive_base_url("http://tool:8000/") == "http://tool:8000" - - -class TestExtractOperationId: - def test_extracts_from_valid_spec(self): - assert _extract_operation_id(SAMPLE_OPENAPI_SPEC) == "calculate" - - def test_returns_none_for_empty_paths(self): - assert _extract_operation_id({"paths": {}}) is None - - def test_returns_none_for_missing_paths(self): - assert _extract_operation_id({}) is None - - def test_returns_none_for_no_operation_id(self): - spec = {"paths": {"/run": {"post": {"summary": "No operationId here"}}}} - assert _extract_operation_id(spec) is None - - def test_skips_non_dict_entries(self): - spec = { - "paths": { - "/run": { - "parameters": [{"name": "x"}], - "post": {"operationId": "run_it", "summary": "Run"}, - } - } - } - assert _extract_operation_id(spec) == "run_it" - - -class TestGetToolByName: - def _make_module(self, tools: list[dict]) -> HttpToolRegistryModule: - deps = ModuleDependencies() - config = {"tools": tools} - return HttpToolRegistryModule(deps, config) - - def _mock_httpx(self, spec_map: dict[str, dict]): - """Return a context manager that patches httpx.AsyncClient. 
- - spec_map: domain substring -> openapi spec to return - """ - mock_responses = {} - for key, spec in spec_map.items(): - resp = MagicMock() - resp.status_code = 200 - resp.raise_for_status = lambda: None - resp.json.return_value = spec - mock_responses[key] = resp - - async def mock_get(url, **kwargs): - for key, resp in mock_responses.items(): - if key in url: - return resp - raise httpx.ConnectError("No mock for " + url) - - mock_client_cls = patch("modai.modules.tools.tool_registry.httpx.AsyncClient") - return mock_client_cls, mock_get - - @pytest.mark.asyncio - async def test_finds_tool_by_operation_id(self): - module = self._make_module( - [{"url": "http://calc:8000/calculate", "method": "POST"}] - ) - - mock_client_cls, mock_get = self._mock_httpx({"calc": SAMPLE_OPENAPI_SPEC}) - with mock_client_cls as cls: - mock_client = AsyncMock() - mock_client.get = mock_get - mock_client.__aenter__ = AsyncMock(return_value=mock_client) - mock_client.__aexit__ = AsyncMock(return_value=False) - cls.return_value = mock_client - - result = await module.get_tool_by_name("calculate") - - assert result == ToolDefinition( - url="http://calc:8000/calculate", - method="POST", - openapi_spec=SAMPLE_OPENAPI_SPEC, - ) - - @pytest.mark.asyncio - async def test_returns_none_for_unknown_name(self): - module = self._make_module( - [{"url": "http://calc:8000/calculate", "method": "POST"}] - ) - - mock_client_cls, mock_get = self._mock_httpx({"calc": SAMPLE_OPENAPI_SPEC}) - with mock_client_cls as cls: - mock_client = AsyncMock() - mock_client.get = mock_get - mock_client.__aenter__ = AsyncMock(return_value=mock_client) - mock_client.__aexit__ = AsyncMock(return_value=False) - cls.return_value = mock_client - - result = await module.get_tool_by_name("nonexistent") - - assert result is None - - @pytest.mark.asyncio - async def test_returns_none_for_empty_registry(self): - module = self._make_module([]) - result = await module.get_tool_by_name("calculate") - assert result is None diff 
--git a/backend/omni/src/modai/modules/tools/__tests__/test_tool_registry_openapi.py b/backend/omni/src/modai/modules/tools/__tests__/test_tool_registry_openapi.py new file mode 100644 index 0000000..ba0fa29 --- /dev/null +++ b/backend/omni/src/modai/modules/tools/__tests__/test_tool_registry_openapi.py @@ -0,0 +1,789 @@ +from contextlib import asynccontextmanager +from typing import Any +from unittest.mock import AsyncMock, MagicMock + +import httpx +import pytest + +from modai.module import ModuleDependencies +from modai.modules.http_client.module import HttpClientModule +from modai.modules.tools.module import Tool, ToolDefinition +from modai.modules.tools.tool_registry_openapi import ( + OpenAPIToolRegistryModule, + _build_tool_definition, + _derive_base_url, + _fetch_openapi_spec, +) + + +class _StubHttpClientFactory(HttpClientModule): + """Test factory that yields clients in sequence; reuses the last one when exhausted.""" + + def __init__(self, *clients: httpx.AsyncClient): + super().__init__(ModuleDependencies(), {}) + self._clients = list(clients) + self._index = 0 + + def new(self, timeout: float) -> Any: + @asynccontextmanager + async def _ctx(): + idx = min(self._index, len(self._clients) - 1) + self._index += 1 + yield self._clients[idx] + + return _ctx() + + +def _mock_response(spec: dict | None = None, text: str = "") -> MagicMock: + """Build a minimal mock httpx response.""" + resp = MagicMock() + resp.raise_for_status = MagicMock() + if spec is not None: + resp.json.return_value = spec + resp.text = text + return resp + + +SAMPLE_OPENAPI_SPEC = { + "openapi": "3.1.0", + "info": {"title": "Calculator Tool", "version": "1.0.0"}, + "paths": { + "/calculate": { + "post": { + "summary": "Evaluate a math expression", + "operationId": "calculate", + "requestBody": { + "required": True, + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": {"expression": {"type": "string"}}, + "required": ["expression"], + } + } + }, + }, 
+ } + } + }, +} + +PATH_PARAMS_SPEC = { + "openapi": "3.1.0", + "info": {"title": "User Tool", "version": "1.0.0"}, + "paths": { + "/users/{user_id}/orders/{order_id}": { + "get": { + "summary": "Get a specific user order", + "operationId": "get_user_order", + "parameters": [ + { + "name": "user_id", + "in": "path", + "required": True, + "description": "The user's ID", + "schema": {"type": "string"}, + }, + { + "name": "order_id", + "in": "path", + "required": True, + "description": "The order's ID", + "schema": {"type": "integer"}, + }, + ], + } + } + }, +} + +HEADER_PARAMS_SPEC = { + "openapi": "3.1.0", + "info": {"title": "Session Tool", "version": "1.0.0"}, + "paths": { + "/data": { + "get": { + "summary": "Fetch session data", + "operationId": "fetch_data", + "parameters": [ + { + "name": "X-Session-Id", + "in": "header", + "required": True, + "description": "Active session identifier", + "schema": {"type": "string"}, + }, + { + "name": "X-Tenant", + "in": "header", + "required": False, + "description": "Optional tenant override", + "schema": {"type": "string"}, + }, + ], + } + } + }, +} + +HEADER_AND_BODY_SPEC = { + "openapi": "3.1.0", + "info": {"title": "Submit Tool", "version": "1.0.0"}, + "paths": { + "/submit": { + "post": { + "summary": "Submit a payload", + "operationId": "submit", + "parameters": [ + { + "name": "X-Request-Id", + "in": "header", + "required": True, + "description": "Idempotency key", + "schema": {"type": "string"}, + } + ], + "requestBody": { + "required": True, + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": {"payload": {"type": "string"}}, + "required": ["payload"], + } + } + }, + }, + } + } + }, +} + +PATH_PARAMS_WITH_BODY_SPEC = { + "openapi": "3.1.0", + "info": {"title": "Update Tool", "version": "1.0.0"}, + "paths": { + "/items/{item_id}": { + "put": { + "summary": "Update an item", + "operationId": "update_item", + "parameters": [ + { + "name": "item_id", + "in": "path", + "required": 
True, + "description": "The item's ID", + "schema": {"type": "integer"}, + }, + ], + "requestBody": { + "required": True, + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": {"name": {"type": "string"}}, + "required": ["name"], + } + } + }, + }, + } + } + }, +} + +DICE_ROLLER_SPEC = { + "openapi": "3.1.0", + "info": {"title": "Dice Roller Tool", "version": "1.0.0"}, + "paths": { + "/roll": { + "post": { + "summary": "Roll dice and return the results", + "operationId": "roll_dice", + "requestBody": { + "required": True, + "content": { + "application/json": { + "schema": {"$ref": "#/components/schemas/DiceRequest"} + } + }, + }, + } + } + }, + "components": { + "schemas": { + "DiceRequest": { + "type": "object", + "properties": { + "count": { + "type": "integer", + "default": 1, + "description": "Number of dice to roll", + }, + "sides": { + "type": "integer", + "default": 6, + "description": "Number of sides per die", + }, + }, + } + } + }, +} + + +class TestBuildToolDefinition: + def test_openapi_with_inline_schema(self): + definition, header_names = _build_tool_definition(SAMPLE_OPENAPI_SPEC) + assert definition == ToolDefinition( + name="calculate", + description="Evaluate a math expression", + parameters={ + "type": "object", + "properties": {"expression": {"type": "string"}}, + "required": ["expression"], + }, + ) + assert header_names == frozenset() + + def test_openapi_with_ref_schema(self): + definition, header_names = _build_tool_definition(DICE_ROLLER_SPEC) + assert definition == ToolDefinition( + name="roll_dice", + description="Roll dice and return the results", + parameters={ + "type": "object", + "properties": { + "count": { + "type": "integer", + "default": 1, + "description": "Number of dice to roll", + }, + "sides": { + "type": "integer", + "default": 6, + "description": "Number of sides per die", + }, + }, + }, + ) + assert header_names == frozenset() + + def test_path_parameters_only(self): + definition, 
header_names = _build_tool_definition(PATH_PARAMS_SPEC) + assert definition is not None + assert definition.name == "get_user_order" + assert definition.description == "Get a specific user order" + params = definition.parameters + assert params["type"] == "object" + assert "user_id" in params["properties"] + assert "order_id" in params["properties"] + assert params["properties"]["user_id"]["type"] == "string" + assert params["properties"]["user_id"]["description"] == "The user's ID" + assert params["properties"]["order_id"]["type"] == "integer" + assert set(params["required"]) == {"user_id", "order_id"} + assert header_names == frozenset() + + def test_path_parameters_merged_with_request_body(self): + definition, header_names = _build_tool_definition(PATH_PARAMS_WITH_BODY_SPEC) + assert definition is not None + params = definition.parameters + assert "item_id" in params["properties"] + assert "name" in params["properties"] + assert "item_id" in params["required"] + assert "name" in params["required"] + assert header_names == frozenset() + + def test_header_parameters_in_definition(self): + definition, header_names = _build_tool_definition(HEADER_PARAMS_SPEC) + assert definition is not None + params = definition.parameters + assert "X-Session-Id" in params["properties"] + assert "X-Tenant" in params["properties"] + assert params["properties"]["X-Session-Id"]["type"] == "string" + assert ( + params["properties"]["X-Session-Id"]["description"] + == "Active session identifier" + ) + assert "X-Session-Id" in params["required"] + assert "X-Tenant" not in params.get("required", []) + assert header_names == {"X-Session-Id", "X-Tenant"} + + def test_header_parameters_merged_with_request_body(self): + definition, header_names = _build_tool_definition(HEADER_AND_BODY_SPEC) + assert definition is not None + params = definition.parameters + assert "X-Request-Id" in params["properties"] + assert "payload" in params["properties"] + assert "X-Request-Id" in params["required"] + 
assert "payload" in params["required"] + assert header_names == {"X-Request-Id"} + + def test_no_operation_id_returns_none(self): + spec = {"paths": {"/run": {"post": {"summary": "no id"}}}} + definition, header_names = _build_tool_definition(spec) + assert definition is None + assert header_names == frozenset() + + def _make_module( + self, tools: list[dict], factory=None + ) -> OpenAPIToolRegistryModule: + if factory is None: + # Provide a factory that yields a no-op async client by default + factory = _StubHttpClientFactory(AsyncMock()) + deps = ModuleDependencies({"http_client": factory}) + return OpenAPIToolRegistryModule(deps, {"tools": tools}) + + @pytest.mark.asyncio + async def test_get_tools_empty_config(self): + module = self._make_module([]) + result = await module.get_tools() + assert result == [] + + @pytest.mark.asyncio + async def test_get_tools_returns_tools_from_all_services(self): + search_spec = { + **SAMPLE_OPENAPI_SPEC, + "paths": { + "/search": { + "put": { + "summary": "Search the web", + "operationId": "web_search", + "requestBody": { + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": {"query": {"type": "string"}}, + } + } + } + }, + } + } + }, + } + + async def mock_request(method, url, **kwargs): + if "calc" in url: + return _mock_response(spec=SAMPLE_OPENAPI_SPEC) + return _mock_response(spec=search_spec) + + mock_client = AsyncMock() + mock_client.request = mock_request + module = self._make_module( + [ + {"url": "http://calc:8000/calculate", "method": "POST"}, + {"url": "http://search:8000/search", "method": "PUT"}, + ], + factory=_StubHttpClientFactory(mock_client), + ) + + result = await module.get_tools() + + assert len(result) == 2 + assert isinstance(result[0], Tool) + assert isinstance(result[1], Tool) + names = {tool.definition.name for tool in result} + assert names == {"calculate", "web_search"} + + @pytest.mark.asyncio + async def test_tool_definition_extracted_from_spec(self): + 
mock_client = AsyncMock() + mock_client.request = AsyncMock( + return_value=_mock_response(spec=SAMPLE_OPENAPI_SPEC) + ) + module = self._make_module( + [{"url": "http://calc:8000/calculate", "method": "POST"}], + factory=_StubHttpClientFactory(mock_client), + ) + + result = await module.get_tools() + + assert len(result) == 1 + definition = result[0].definition + assert definition.name == "calculate" + assert definition.description == "Evaluate a math expression" + assert "expression" in definition.parameters["properties"] + + @pytest.mark.asyncio + async def test_get_tools_skips_unavailable_service(self): + async def mock_request(method, url, **kwargs): + if "bad" in url: + raise httpx.ConnectError("Connection refused") + return _mock_response(spec=SAMPLE_OPENAPI_SPEC) + + mock_client = AsyncMock() + mock_client.request = mock_request + module = self._make_module( + [ + {"url": "http://good:8000/run", "method": "POST"}, + {"url": "http://bad:8000/run", "method": "POST"}, + ], + factory=_StubHttpClientFactory(mock_client), + ) + + result = await module.get_tools() + + assert len(result) == 1 + assert result[0].definition.name == "calculate" + + @pytest.mark.asyncio + async def test_get_tools_skips_spec_without_operation_id(self): + no_op_spec = {"paths": {"/run": {"post": {"summary": "No operationId"}}}} + mock_client = AsyncMock() + mock_client.request = AsyncMock(return_value=_mock_response(spec=no_op_spec)) + module = self._make_module( + [{"url": "http://tool:8000/run", "method": "POST"}], + factory=_StubHttpClientFactory(mock_client), + ) + + result = await module.get_tools() + + assert result == [] + + def test_has_no_router(self): + module = self._make_module([]) + assert not hasattr(module, "router") + + def test_stores_tool_services_from_config(self): + tools = [ + {"url": "http://a:8000/run", "method": "POST"}, + {"url": "http://b:9000/exec", "method": "PUT"}, + ] + module = self._make_module(tools) + assert module.tool_services == tools + + def 
test_defaults_to_empty_tools_list(self): + deps = ModuleDependencies() + module = OpenAPIToolRegistryModule(deps, {}) + assert module.tool_services == [] + + +class TestToolRun: + """Tool.run invokes the tool microservice over HTTP.""" + + def _make_module( + self, tools: list[dict], factory=None + ) -> OpenAPIToolRegistryModule: + if factory is None: + factory = _StubHttpClientFactory(AsyncMock()) + deps = ModuleDependencies({"http_client": factory}) + return OpenAPIToolRegistryModule(deps, {"tools": tools}) + + @pytest.mark.asyncio + async def test_run_makes_http_request_to_tool_endpoint(self): + spec_client = AsyncMock() + spec_client.request = AsyncMock( + return_value=_mock_response(spec=SAMPLE_OPENAPI_SPEC) + ) + + run_response = _mock_response(text='{"result": 42}') + run_client = AsyncMock() + run_client.request = AsyncMock(return_value=run_response) + + # factory yields spec_client on first new() call, run_client on second + module = self._make_module( + [{"url": "http://calc:8000/calculate", "method": "POST"}], + factory=_StubHttpClientFactory(spec_client, run_client), + ) + + tools = await module.get_tools() + assert len(tools) == 1 + + result = await tools[0].run({"expression": "6*7"}) + + run_client.request.assert_called_once_with( + method="POST", + url="http://calc:8000/calculate", + json={"expression": "6*7"}, + headers={}, + ) + assert result == '{"result": 42}' + + @pytest.mark.asyncio + async def test_run_forwards_bearer_token_as_authorization_header(self): + """When _bearer_token is in params it becomes Authorization: Bearer .""" + spec_client = AsyncMock() + spec_client.request = AsyncMock( + return_value=_mock_response(spec=SAMPLE_OPENAPI_SPEC) + ) + + run_response = _mock_response(text='{"result": 42}') + run_client = AsyncMock() + run_client.request = AsyncMock(return_value=run_response) + + module = self._make_module( + [{"url": "http://calc:8000/calculate", "method": "POST"}], + factory=_StubHttpClientFactory(spec_client, run_client), + ) 
+ + tools = await module.get_tools() + await tools[0].run({"expression": "2+2", "_bearer_token": "secret"}) + + run_client.request.assert_called_once_with( + method="POST", + url="http://calc:8000/calculate", + json={"expression": "2+2"}, + headers={"Authorization": "Bearer secret"}, + ) + + @pytest.mark.asyncio + async def test_run_substitutes_path_parameters_into_url(self): + """Path parameters are substituted into the URL template, not sent in the body.""" + spec_client = AsyncMock() + spec_client.request = AsyncMock( + return_value=_mock_response(spec=PATH_PARAMS_SPEC) + ) + + run_response = _mock_response(text='{"order": "details"}') + run_client = AsyncMock() + run_client.request = AsyncMock(return_value=run_response) + + module = self._make_module( + [ + { + "url": "http://users:8000/users/{user_id}/orders/{order_id}", + "method": "GET", + } + ], + factory=_StubHttpClientFactory(spec_client, run_client), + ) + + tools = await module.get_tools() + assert len(tools) == 1 + + result = await tools[0].run({"user_id": "alice", "order_id": 42}) + + run_client.request.assert_called_once_with( + method="GET", + url="http://users:8000/users/alice/orders/42", + json={}, + headers={}, + ) + assert result == '{"order": "details"}' + + @pytest.mark.asyncio + async def test_run_substitutes_path_parameters_leaving_body_params(self): + """Path params are substituted into URL; remaining params go in the request body.""" + spec_client = AsyncMock() + spec_client.request = AsyncMock( + return_value=_mock_response(spec=PATH_PARAMS_WITH_BODY_SPEC) + ) + + run_response = _mock_response(text='{"updated": true}') + run_client = AsyncMock() + run_client.request = AsyncMock(return_value=run_response) + + module = self._make_module( + [ + { + "url": "http://items:8000/items/{item_id}", + "method": "PUT", + } + ], + factory=_StubHttpClientFactory(spec_client, run_client), + ) + + tools = await module.get_tools() + assert len(tools) == 1 + + await tools[0].run({"item_id": 7, "name": 
"Widget"}) + + run_client.request.assert_called_once_with( + method="PUT", + url="http://items:8000/items/7", + json={"name": "Widget"}, + headers={}, + ) + + @pytest.mark.asyncio + async def test_run_forwards_header_parameters_as_http_headers(self): + """Header parameters declared in the spec are forwarded as HTTP headers, not in the body.""" + spec_client = AsyncMock() + spec_client.request = AsyncMock( + return_value=_mock_response(spec=HEADER_AND_BODY_SPEC) + ) + + run_response = _mock_response(text='{"ok": true}') + run_client = AsyncMock() + run_client.request = AsyncMock(return_value=run_response) + + module = self._make_module( + [{"url": "http://submit:8000/submit", "method": "POST"}], + factory=_StubHttpClientFactory(spec_client, run_client), + ) + + tools = await module.get_tools() + assert len(tools) == 1 + + await tools[0].run({"payload": "hello", "X-Request-Id": "req-abc"}) + + run_client.request.assert_called_once_with( + method="POST", + url="http://submit:8000/submit", + json={"payload": "hello"}, + headers={"X-Request-Id": "req-abc"}, + ) + + @pytest.mark.asyncio + async def test_run_combines_bearer_token_and_header_parameters(self): + """Both _bearer_token and header params end up in the headers dict.""" + spec_client = AsyncMock() + spec_client.request = AsyncMock( + return_value=_mock_response(spec=HEADER_AND_BODY_SPEC) + ) + + run_response = _mock_response(text='{"ok": true}') + run_client = AsyncMock() + run_client.request = AsyncMock(return_value=run_response) + + module = self._make_module( + [{"url": "http://submit:8000/submit", "method": "POST"}], + factory=_StubHttpClientFactory(spec_client, run_client), + ) + + tools = await module.get_tools() + await tools[0].run( + {"payload": "hello", "X-Request-Id": "req-abc", "_bearer_token": "tok"} + ) + + run_client.request.assert_called_once_with( + method="POST", + url="http://submit:8000/submit", + json={"payload": "hello"}, + headers={"Authorization": "Bearer tok", "X-Request-Id": "req-abc"}, 
+ ) + + +class TestFetchOpenapiSpec: + @pytest.mark.asyncio + async def test_success(self): + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.raise_for_status = lambda: None + mock_response.json.return_value = SAMPLE_OPENAPI_SPEC + + client = AsyncMock() + client.request = AsyncMock(return_value=mock_response) + + result = await _fetch_openapi_spec(client, "http://tool:8000") + assert result == SAMPLE_OPENAPI_SPEC + client.request.assert_called_once_with("GET", "http://tool:8000/openapi.json") + + @pytest.mark.asyncio + async def test_strips_trailing_slash(self): + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.raise_for_status = lambda: None + mock_response.json.return_value = SAMPLE_OPENAPI_SPEC + + client = AsyncMock() + client.request = AsyncMock(return_value=mock_response) + + await _fetch_openapi_spec(client, "http://tool:8000/") + client.request.assert_called_once_with("GET", "http://tool:8000/openapi.json") + + @pytest.mark.asyncio + async def test_http_error_returns_none(self): + mock_response = MagicMock() + mock_response.status_code = 500 + mock_response.raise_for_status.side_effect = httpx.HTTPStatusError( + "Server Error", + request=httpx.Request("GET", "http://tool:8000/openapi.json"), + response=mock_response, + ) + + client = AsyncMock() + client.request = AsyncMock(return_value=mock_response) + + result = await _fetch_openapi_spec(client, "http://tool:8000") + assert result is None + + @pytest.mark.asyncio + async def test_connection_error_returns_none(self): + client = AsyncMock() + client.request = AsyncMock(side_effect=httpx.ConnectError("Connection refused")) + + result = await _fetch_openapi_spec(client, "http://tool:8000") + assert result is None + + @pytest.mark.asyncio + async def test_unexpected_error_returns_none(self): + client = AsyncMock() + client.request = AsyncMock(side_effect=RuntimeError("something went wrong")) + + result = await _fetch_openapi_spec(client, 
"http://tool:8000") + assert result is None + + +class TestDeriveBaseUrl: + def test_strips_path(self): + assert _derive_base_url("http://calc:8000/calculate") == "http://calc:8000" + + def test_strips_nested_path(self): + assert _derive_base_url("http://host:9000/api/v1/run") == "http://host:9000" + + def test_no_path(self): + assert _derive_base_url("http://tool:8000") == "http://tool:8000" + + def test_trailing_slash(self): + assert _derive_base_url("http://tool:8000/") == "http://tool:8000" + + +class TestGetToolByName: + def _make_module( + self, tools: list[dict], factory=None + ) -> OpenAPIToolRegistryModule: + if factory is None: + factory = _StubHttpClientFactory(AsyncMock()) + deps = ModuleDependencies({"http_client": factory}) + return OpenAPIToolRegistryModule(deps, {"tools": tools}) + + def _make_spec_factory(self, spec_map: dict[str, dict]): + """Build an HttpClientFactory whose client dispatches by URL key.""" + + async def mock_request(method, url, **kwargs): + for key, spec in spec_map.items(): + if key in url: + return _mock_response(spec=spec) + raise httpx.ConnectError("No mock for " + url) + + mock_client = AsyncMock() + mock_client.request = mock_request + return _StubHttpClientFactory(mock_client) + + @pytest.mark.asyncio + async def test_finds_tool_by_name(self): + module = self._make_module( + [{"url": "http://calc:8000/calculate", "method": "POST"}], + factory=self._make_spec_factory({"calc": SAMPLE_OPENAPI_SPEC}), + ) + + result = await module.get_tool_by_name("calculate") + + assert result is not None + assert isinstance(result, Tool) + assert result.definition.name == "calculate" + assert result.definition.description == "Evaluate a math expression" + + @pytest.mark.asyncio + async def test_returns_none_for_unknown_name(self): + module = self._make_module( + [{"url": "http://calc:8000/calculate", "method": "POST"}], + factory=self._make_spec_factory({"calc": SAMPLE_OPENAPI_SPEC}), + ) + + result = await 
module.get_tool_by_name("nonexistent") + + assert result is None + + @pytest.mark.asyncio + async def test_returns_none_for_empty_registry(self): + module = self._make_module([]) + result = await module.get_tool_by_name("calculate") + assert result is None diff --git a/backend/omni/src/modai/modules/tools/__tests__/test_tool_registry_predefined_vars.py b/backend/omni/src/modai/modules/tools/__tests__/test_tool_registry_predefined_vars.py new file mode 100644 index 0000000..7c16dc7 --- /dev/null +++ b/backend/omni/src/modai/modules/tools/__tests__/test_tool_registry_predefined_vars.py @@ -0,0 +1,465 @@ +from typing import Any +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from modai.module import ModuleDependencies +from modai.modules.tools.module import Tool, ToolDefinition, ToolRegistryModule +from modai.modules.tools.tool_registry_predefined_vars import ( + PredefinedVariablesToolRegistryModule, +) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_tool(definition: ToolDefinition) -> Tool: + class _StubTool(Tool): + @property + def definition(self) -> ToolDefinition: + return definition + + async def run(self, params: dict[str, Any]) -> Any: + return params + + return _StubTool() + + +def _make_capturing_tool(definition: ToolDefinition) -> tuple[Tool, list[dict]]: + """Return a tool and a list that receives each run() params call.""" + calls: list[dict] = [] + + class _CapturingTool(Tool): + @property + def definition(self) -> ToolDefinition: + return definition + + async def run(self, params: dict[str, Any]) -> Any: + calls.append(dict(params)) + return "ok" + + return _CapturingTool(), calls + + +def _stub_registry(*tools: Tool) -> ToolRegistryModule: + """Build a mock ToolRegistryModule that returns the given tools.""" + registry = MagicMock(spec=ToolRegistryModule) + registry.get_tools = 
AsyncMock(return_value=list(tools)) + + async def _get_by_name(name: str, predefined_params=None) -> Tool | None: + return next((t for t in tools if t.definition.name == name), None) + + registry.get_tool_by_name = _get_by_name + return registry + + +def _make_module( + inner: ToolRegistryModule, + variable_mappings: dict[str, str] | None = None, +) -> PredefinedVariablesToolRegistryModule: + deps = ModuleDependencies({"delegate_registry": inner}) + config: dict = {} + if variable_mappings: + config["variable_mappings"] = variable_mappings + return PredefinedVariablesToolRegistryModule(deps, config) + + +FULL_DEFINITION = ToolDefinition( + name="get_user_order", + description="Retrieve an order", + parameters={ + "type": "object", + "properties": { + "user_id": {"type": "string", "description": "The user's ID"}, + "order_id": {"type": "integer", "description": "The order's ID"}, + "session_id": {"type": "string", "description": "Active session"}, + }, + "required": ["user_id", "order_id", "session_id"], + }, +) + +SIMPLE_DEFINITION = ToolDefinition( + name="calculate", + description="Evaluate a math expression", + parameters={ + "type": "object", + "properties": {"expression": {"type": "string"}}, + "required": ["expression"], + }, +) + +HEADER_DEFINITION = ToolDefinition( + name="fetch_data", + description="Fetch session data", + parameters={ + "type": "object", + "properties": { + "X-Session-Id": {"type": "string", "description": "Active session"}, + "filter": {"type": "string", "description": "Optional filter"}, + }, + "required": ["X-Session-Id"], + }, +) + + +# --------------------------------------------------------------------------- +# get_tools: definition filtering +# --------------------------------------------------------------------------- + + +class TestGetToolsDefinitionFiltering: + @pytest.mark.asyncio + async def test_no_predefined_params_returns_full_definition(self): + tool = _make_tool(FULL_DEFINITION) + module = 
_make_module(_stub_registry(tool)) + + result = await module.get_tools() + + assert len(result) == 1 + assert result[0].definition == FULL_DEFINITION + + @pytest.mark.asyncio + async def test_predefined_param_stripped_from_properties(self): + tool = _make_tool(FULL_DEFINITION) + module = _make_module(_stub_registry(tool)) + + result = await module.get_tools(predefined_params={"_session_id": "abc"}) + + assert len(result) == 1 + params = result[0].definition.parameters + assert "session_id" not in params["properties"] + assert "user_id" in params["properties"] + assert "order_id" in params["properties"] + + @pytest.mark.asyncio + async def test_predefined_param_stripped_from_required(self): + tool = _make_tool(FULL_DEFINITION) + module = _make_module(_stub_registry(tool)) + + result = await module.get_tools(predefined_params={"_session_id": "abc"}) + + required = result[0].definition.parameters["required"] + assert "session_id" not in required + assert "user_id" in required + assert "order_id" in required + + @pytest.mark.asyncio + async def test_multiple_predefined_params_stripped(self): + tool = _make_tool(FULL_DEFINITION) + module = _make_module(_stub_registry(tool)) + + result = await module.get_tools( + predefined_params={"_session_id": "s1", "_user_id": "u1"} + ) + + params = result[0].definition.parameters + assert "session_id" not in params["properties"] + assert "user_id" not in params["properties"] + assert "order_id" in params["properties"] + + @pytest.mark.asyncio + async def test_predefined_param_not_in_schema_leaves_definition_unchanged(self): + tool = _make_tool(SIMPLE_DEFINITION) + module = _make_module(_stub_registry(tool)) + + result = await module.get_tools(predefined_params={"_session_id": "abc"}) + + # session_id doesn't exist in the schema → tool is returned as-is (no wrapper) + assert result[0].definition == SIMPLE_DEFINITION + + @pytest.mark.asyncio + async def test_non_prefixed_predefined_key_is_ignored(self): + """Keys without a leading _ 
are not treated as predefined variables.""" + tool = _make_tool(FULL_DEFINITION) + module = _make_module(_stub_registry(tool)) + + result = await module.get_tools(predefined_params={"session_id": "abc"}) + + params = result[0].definition.parameters + assert "session_id" in params["properties"] + + @pytest.mark.asyncio + async def test_empty_predefined_params_returns_full_definition(self): + tool = _make_tool(FULL_DEFINITION) + module = _make_module(_stub_registry(tool)) + + result = await module.get_tools(predefined_params={}) + + assert result[0].definition == FULL_DEFINITION + + @pytest.mark.asyncio + async def test_multiple_tools_each_filtered(self): + tool_a = _make_tool(FULL_DEFINITION) + tool_b = _make_tool(SIMPLE_DEFINITION) + module = _make_module(_stub_registry(tool_a, tool_b)) + + result = await module.get_tools(predefined_params={"_session_id": "s1"}) + + assert len(result) == 2 + # tool_a had session_id → stripped + assert "session_id" not in result[0].definition.parameters["properties"] + # tool_b had no session_id → unchanged + assert result[1].definition == SIMPLE_DEFINITION + + +# --------------------------------------------------------------------------- +# get_tool_by_name +# --------------------------------------------------------------------------- + + +class TestGetToolByName: + @pytest.mark.asyncio + async def test_returns_filtered_tool_when_found(self): + tool = _make_tool(FULL_DEFINITION) + module = _make_module(_stub_registry(tool)) + + result = await module.get_tool_by_name( + "get_user_order", predefined_params={"_session_id": "s1"} + ) + + assert result is not None + assert "session_id" not in result.definition.parameters["properties"] + + @pytest.mark.asyncio + async def test_returns_none_when_not_found(self): + module = _make_module(_stub_registry()) + + result = await module.get_tool_by_name( + "nonexistent", predefined_params={"_session_id": "s1"} + ) + + assert result is None + + @pytest.mark.asyncio + async def 
test_no_predefined_params_returns_full_definition(self): + tool = _make_tool(FULL_DEFINITION) + module = _make_module(_stub_registry(tool)) + + result = await module.get_tool_by_name("get_user_order") + + assert result is not None + assert result.definition == FULL_DEFINITION + + +# --------------------------------------------------------------------------- +# run() — predefined variable translation +# --------------------------------------------------------------------------- + + +class TestRunTranslation: + @pytest.mark.asyncio + async def test_predefined_key_translated_to_unprefixed_before_inner_run(self): + inner_tool, calls = _make_capturing_tool(FULL_DEFINITION) + module = _make_module(_stub_registry(inner_tool)) + + wrapped = await module.get_tool_by_name( + "get_user_order", predefined_params={"_session_id": "session-xyz"} + ) + assert wrapped is not None + + await wrapped.run( + {"user_id": "alice", "order_id": 7, "_session_id": "session-xyz"} + ) + + assert len(calls) == 1 + assert calls[0]["session_id"] == "session-xyz" + assert "_session_id" not in calls[0] + + @pytest.mark.asyncio + async def test_non_predefined_params_passed_through_unchanged(self): + inner_tool, calls = _make_capturing_tool(FULL_DEFINITION) + module = _make_module(_stub_registry(inner_tool)) + + wrapped = await module.get_tool_by_name( + "get_user_order", predefined_params={"_session_id": "s1"} + ) + assert wrapped is not None + + await wrapped.run({"user_id": "alice", "order_id": 7, "_session_id": "s1"}) + + assert calls[0]["user_id"] == "alice" + assert calls[0]["order_id"] == 7 + + @pytest.mark.asyncio + async def test_bearer_token_not_in_schema_stays_prefixed(self): + """_bearer_token is a reserved key not found in the schema — it must + remain as _bearer_token so the inner tool (e.g. 
_OpenAPITool) can + handle it for the Authorization header.""" + inner_tool, calls = _make_capturing_tool(FULL_DEFINITION) + module = _make_module(_stub_registry(inner_tool)) + + # _bearer_token is NOT in the schema, so it is NOT a hidden property + wrapped = await module.get_tool_by_name( + "get_user_order", + predefined_params={"_session_id": "s1", "_bearer_token": "tok"}, + ) + assert wrapped is not None + + await wrapped.run( + { + "user_id": "alice", + "order_id": 7, + "_session_id": "s1", + "_bearer_token": "tok", + } + ) + + # _bearer_token was not in schema so it is not translated + assert "_bearer_token" in calls[0] + assert "bearer_token" not in calls[0] + + @pytest.mark.asyncio + async def test_tool_not_requiring_wrapping_is_returned_directly(self): + """When predefined params have no overlap with the schema, the original + tool object is returned without a wrapper.""" + tool = _make_tool(SIMPLE_DEFINITION) + module = _make_module(_stub_registry(tool)) + + result = await module.get_tool_by_name( + "calculate", predefined_params={"_session_id": "s1"} + ) + + # Same object — no wrapper was needed + assert result is tool + + +# --------------------------------------------------------------------------- +# variable_mappings config +# --------------------------------------------------------------------------- + + +class TestVariableMappings: + @pytest.mark.asyncio + async def test_mapped_tool_param_hidden_from_definition(self): + """X-Session-Id is stripped when _session_id is predefined and mapping is configured.""" + tool = _make_tool(HEADER_DEFINITION) + module = _make_module( + _stub_registry(tool), + variable_mappings={"X-Session-Id": "session_id"}, + ) + + result = await module.get_tools(predefined_params={"_session_id": "sess-abc"}) + + assert len(result) == 1 + params = result[0].definition.parameters + assert "X-Session-Id" not in params["properties"] + assert "filter" in params["properties"] + assert "X-Session-Id" not in params.get("required", []) 
+ + @pytest.mark.asyncio + async def test_mapped_param_not_hidden_when_predefined_var_absent(self): + """If _session_id is not in predefined_params, X-Session-Id stays visible.""" + tool = _make_tool(HEADER_DEFINITION) + module = _make_module( + _stub_registry(tool), + variable_mappings={"X-Session-Id": "session_id"}, + ) + + result = await module.get_tools(predefined_params={}) + + assert result[0].definition == HEADER_DEFINITION + + @pytest.mark.asyncio + async def test_run_translates_predefined_key_to_mapped_tool_param(self): + """_session_id is translated to X-Session-Id (not session_id) per the mapping.""" + inner_tool, calls = _make_capturing_tool(HEADER_DEFINITION) + module = _make_module( + _stub_registry(inner_tool), + variable_mappings={"X-Session-Id": "session_id"}, + ) + + wrapped = await module.get_tool_by_name( + "fetch_data", predefined_params={"_session_id": "sess-xyz"} + ) + assert wrapped is not None + + await wrapped.run({"filter": "recent", "_session_id": "sess-xyz"}) + + assert len(calls) == 1 + assert calls[0]["X-Session-Id"] == "sess-xyz" + assert "session_id" not in calls[0] + assert "_session_id" not in calls[0] + + @pytest.mark.asyncio + async def test_direct_and_configured_mappings_coexist(self): + """A direct-mapped param (session_id) and a configured mapping (X-Session-Id) + for different predefined vars can both be active at the same time.""" + definition = ToolDefinition( + name="multi_param_tool", + description="Tool with both direct and mapped params", + parameters={ + "type": "object", + "properties": { + "session_id": {"type": "string"}, + "X-Tenant": {"type": "string"}, + "query": {"type": "string"}, + }, + "required": ["session_id", "X-Tenant", "query"], + }, + ) + inner_tool, calls = _make_capturing_tool(definition) + module = _make_module( + _stub_registry(inner_tool), + variable_mappings={"X-Tenant": "tenant_id"}, + ) + + wrapped = await module.get_tool_by_name( + "multi_param_tool", + predefined_params={"_session_id": "s1", 
"_tenant_id": "acme"}, + ) + assert wrapped is not None + + # Both session_id and X-Tenant should be hidden from the definition + params = wrapped.definition.parameters + assert "session_id" not in params["properties"] + assert "X-Tenant" not in params["properties"] + assert "query" in params["properties"] + + await wrapped.run({"query": "hello", "_session_id": "s1", "_tenant_id": "acme"}) + + assert calls[0]["session_id"] == "s1" # direct mapping + assert calls[0]["X-Tenant"] == "acme" # configured mapping + assert "_session_id" not in calls[0] + assert "_tenant_id" not in calls[0] + + @pytest.mark.asyncio + async def test_configured_mapping_overrides_direct_for_same_var(self): + """When a mapping routes _session_id to X-Session-Id, the default + session_id → _session_id direct mapping must NOT also be applied.""" + definition = ToolDefinition( + name="override_tool", + description="Test override", + parameters={ + "type": "object", + "properties": { + "session_id": {"type": "string"}, + "X-Session-Id": {"type": "string"}, + }, + "required": ["session_id", "X-Session-Id"], + }, + ) + inner_tool, calls = _make_capturing_tool(definition) + # Map _session_id to X-Session-Id only — session_id in schema remains unaffected + module = _make_module( + _stub_registry(inner_tool), + variable_mappings={"X-Session-Id": "session_id"}, + ) + + wrapped = await module.get_tool_by_name( + "override_tool", + predefined_params={"_session_id": "s1"}, + ) + assert wrapped is not None + + # Only X-Session-Id should be hidden; session_id (different schema prop) stays + params = wrapped.definition.parameters + assert "X-Session-Id" not in params["properties"] + assert "session_id" in params["properties"] + + await wrapped.run({"session_id": "manual", "_session_id": "s1"}) + + assert calls[0]["X-Session-Id"] == "s1" + assert calls[0]["session_id"] == "manual" + assert "_session_id" not in calls[0] diff --git a/backend/omni/src/modai/modules/tools/__tests__/test_tools_web_module.py 
b/backend/omni/src/modai/modules/tools/__tests__/test_tools_web_module.py index 9b58d3d..26aeb89 100644 --- a/backend/omni/src/modai/modules/tools/__tests__/test_tools_web_module.py +++ b/backend/omni/src/modai/modules/tools/__tests__/test_tools_web_module.py @@ -1,14 +1,13 @@ +from typing import Any from unittest.mock import AsyncMock import pytest from modai.module import ModuleDependencies -from modai.modules.tools.module import ToolDefinition +from modai.modules.tools.module import Tool, ToolDefinition from modai.modules.tools.tools_web_module import ( OpenAIToolsWebModule, - _extract_parameters, - _resolve_refs, - _transform_openapi_to_openai, + _to_openai_format, ) @@ -42,171 +41,82 @@ }, } +SAMPLE_DEFINITION = ToolDefinition( + name="calculate", + description="Evaluate a math expression", + parameters={ + "type": "object", + "properties": { + "expression": { + "type": "string", + "description": "Math expression to evaluate", + } + }, + "required": ["expression"], + }, +) + + +def _make_tool(definition: ToolDefinition) -> Tool: + """Create a minimal Tool stub for testing.""" -class TestTransformOpenapiToOpenai: - def test_transforms_valid_spec(self): - result = _transform_openapi_to_openai(SAMPLE_OPENAPI_SPEC) + class _StubTool(Tool): + @property + def definition(self) -> ToolDefinition: + return definition + + async def run(self, params: dict[str, Any]) -> Any: + return "" + + return _StubTool() + + +class TestToOpenAIFormat: + def test_formats_valid_definition(self): + result = _to_openai_format(SAMPLE_DEFINITION) assert result == { "type": "function", "function": { "name": "calculate", "description": "Evaluate a math expression", - "parameters": { - "type": "object", - "properties": { - "expression": { - "type": "string", - "description": "Math expression to evaluate", - } - }, - "required": ["expression"], - }, + "parameters": SAMPLE_DEFINITION.parameters, "strict": True, }, } - def test_uses_description_when_no_summary(self): - spec = { - "paths": { - 
"/run": { - "post": { - "description": "Runs something", - "operationId": "run_task", - "requestBody": { - "content": { - "application/json": { - "schema": {"type": "object", "properties": {}} - } - } - }, - } - } - } - } - result = _transform_openapi_to_openai(spec) + def test_uses_provided_description(self): + definition = ToolDefinition( + name="run_task", + description="Runs something", + parameters={"type": "object", "properties": {}}, + ) + result = _to_openai_format(definition) assert result["function"]["description"] == "Runs something" - def test_empty_description_when_none_provided(self): - spec = { - "paths": { - "/run": { - "post": { - "operationId": "run_task", - } - } - } - } - result = _transform_openapi_to_openai(spec) + def test_empty_description_is_preserved(self): + definition = ToolDefinition( + name="run_task", + description="", + parameters={"type": "object", "properties": {}}, + ) + result = _to_openai_format(definition) assert result["function"]["description"] == "" - def test_returns_none_for_no_operation_id(self): - spec = {"paths": {"/run": {"post": {"summary": "No operationId"}}}} - result = _transform_openapi_to_openai(spec) - assert result is None - - def test_returns_none_for_empty_paths(self): - spec = {"paths": {}} - result = _transform_openapi_to_openai(spec) - assert result is None - - def test_returns_none_for_missing_paths(self): - result = _transform_openapi_to_openai({}) - assert result is None + def test_parameters_are_passed_through(self): + custom_params = {"type": "object", "properties": {"x": {"type": "integer"}}} + definition = ToolDefinition( + name="calc", description="desc", parameters=custom_params + ) + result = _to_openai_format(definition) + assert result["function"]["parameters"] == custom_params - def test_default_parameters_when_no_request_body(self): - spec = { - "paths": { - "/status": { - "get": { - "operationId": "get_status", - "summary": "Get status", - } - } - } - } - result = 
_transform_openapi_to_openai(spec) - assert result["function"]["parameters"] == { - "type": "object", - "properties": {}, - } - - -class TestExtractParameters: - def test_extracts_json_schema(self): - operation = { - "requestBody": { - "content": { - "application/json": { - "schema": { - "type": "object", - "properties": {"x": {"type": "integer"}}, - } - } - } - } - } - result = _extract_parameters(operation, {}) - assert result == { - "type": "object", - "properties": {"x": {"type": "integer"}}, - } - - def test_returns_default_when_no_request_body(self): - result = _extract_parameters({}, {}) - assert result == {"type": "object", "properties": {}} - - def test_returns_default_when_no_json_content(self): - operation = { - "requestBody": {"content": {"text/plain": {"schema": {"type": "string"}}}} - } - result = _extract_parameters(operation, {}) - assert result == {"type": "object", "properties": {}} - - def test_resolves_ref_in_schema(self): - spec = { - "components": { - "schemas": { - "DiceRequest": { - "type": "object", - "properties": { - "count": { - "type": "integer", - "description": "Number of dice", - }, - "sides": { - "type": "integer", - "description": "Sides per die", - }, - }, - "required": ["count", "sides"], - } - } - } - } - operation = { - "requestBody": { - "content": { - "application/json": { - "schema": {"$ref": "#/components/schemas/DiceRequest"} - } - } - } - } - result = _extract_parameters(operation, spec) - assert result == { - "type": "object", - "properties": { - "count": {"type": "integer", "description": "Number of dice"}, - "sides": {"type": "integer", "description": "Sides per die"}, - }, - "required": ["count", "sides"], - } + def test_strict_is_always_true(self): + result = _to_openai_format(SAMPLE_DEFINITION) + assert result["function"]["strict"] is True class TestToolsWebModule: - def _make_module( - self, registry_tools: list[ToolDefinition] - ) -> OpenAIToolsWebModule: + def _make_module(self, registry_tools: list[Tool]) -> 
OpenAIToolsWebModule: mock_registry = AsyncMock() mock_registry.get_tools = AsyncMock(return_value=registry_tools) deps = ModuleDependencies(modules={"tool_registry": mock_registry}) @@ -226,14 +136,21 @@ async def test_returns_empty_tools_when_registry_empty(self): @pytest.mark.asyncio async def test_transforms_registry_tools_to_openai_format(self): - registry_tools = [ - ToolDefinition( - url="http://calc:8000/calculate", - method="POST", - openapi_spec=SAMPLE_OPENAPI_SPEC, - ) - ] - module = self._make_module(registry_tools) + definition = ToolDefinition( + name="calculate", + description="Evaluate a math expression", + parameters={ + "type": "object", + "properties": { + "expression": { + "type": "string", + "description": "Math expression to evaluate", + } + }, + "required": ["expression"], + }, + ) + module = self._make_module([_make_tool(definition)]) result = await module.get_tools() assert len(result["tools"]) == 1 @@ -244,174 +161,22 @@ async def test_transforms_registry_tools_to_openai_format(self): assert "expression" in tool["function"]["parameters"]["properties"] @pytest.mark.asyncio - async def test_skips_tools_without_operation_id(self): - bad_spec = {"paths": {"/run": {"post": {"summary": "No operationId"}}}} - registry_tools = [ - ToolDefinition( - url="http://calc:8000/calculate", - method="POST", - openapi_spec=SAMPLE_OPENAPI_SPEC, - ), - ToolDefinition( - url="http://bad:8000/run", - method="POST", - openapi_spec=bad_spec, - ), - ] - module = self._make_module(registry_tools) - result = await module.get_tools() - - assert len(result["tools"]) == 1 - assert result["tools"][0]["function"]["name"] == "calculate" - - @pytest.mark.asyncio - async def test_multiple_tools_transformed(self): - search_spec = { - "openapi": "3.1.0", - "info": {"title": "Search", "version": "1.0.0"}, - "paths": { - "/search": { - "put": { - "summary": "Search the web", - "operationId": "web_search", - "requestBody": { - "content": { - "application/json": { - "schema": { - 
"type": "object", - "properties": {"query": {"type": "string"}}, - "required": ["query"], - } - } - } - }, - } - } + async def test_multiple_tools_returned(self): + search_def = ToolDefinition( + name="web_search", + description="Search the web", + parameters={ + "type": "object", + "properties": {"query": {"type": "string"}}, + "required": ["query"], }, - } - registry_tools = [ - ToolDefinition( - url="http://calc:8000/calculate", - method="POST", - openapi_spec=SAMPLE_OPENAPI_SPEC, - ), - ToolDefinition( - url="http://search:8000/search", - method="PUT", - openapi_spec=search_spec, - ), - ] - module = self._make_module(registry_tools) + ) + module = self._make_module( + [_make_tool(SAMPLE_DEFINITION), _make_tool(search_def)] + ) result = await module.get_tools() assert len(result["tools"]) == 2 names = [t["function"]["name"] for t in result["tools"]] assert "calculate" in names assert "web_search" in names - - -class TestResolveRefs: - def test_returns_primitive_as_is(self): - assert _resolve_refs("hello", {}) == "hello" - assert _resolve_refs(42, {}) == 42 - assert _resolve_refs(None, {}) is None - - def test_returns_dict_without_refs_unchanged(self): - node = {"type": "string", "description": "test"} - assert _resolve_refs(node, {}) == node - - def test_resolves_top_level_ref(self): - spec = { - "components": {"schemas": {"Foo": {"type": "object", "properties": {}}}} - } - node = {"$ref": "#/components/schemas/Foo"} - assert _resolve_refs(node, spec) == {"type": "object", "properties": {}} - - def test_resolves_nested_ref(self): - spec = { - "components": { - "schemas": { - "Bar": {"type": "string", "description": "A bar"}, - } - } - } - node = { - "type": "object", - "properties": { - "bar": {"$ref": "#/components/schemas/Bar"}, - }, - } - result = _resolve_refs(node, spec) - assert result == { - "type": "object", - "properties": { - "bar": {"type": "string", "description": "A bar"}, - }, - } - - def test_resolves_refs_in_list(self): - spec = {"components": 
{"schemas": {"X": {"type": "integer"}}}} - node = [{"$ref": "#/components/schemas/X"}, {"type": "string"}] - result = _resolve_refs(node, spec) - assert result == [{"type": "integer"}, {"type": "string"}] - - def test_returns_empty_dict_for_unresolvable_ref(self): - result = _resolve_refs({"$ref": "#/components/schemas/Missing"}, {}) - assert result == {} - - def test_returns_empty_dict_for_non_local_ref(self): - result = _resolve_refs({"$ref": "https://example.com/schema.json"}, {}) - assert result == {} - - -class TestTransformWithRefs: - """Integration test: full OpenAPI spec with $ref (like FastAPI generates).""" - - DICE_ROLLER_SPEC = { - "openapi": "3.1.0", - "info": {"title": "Dice Roller Tool", "version": "1.0.0"}, - "paths": { - "/roll": { - "post": { - "summary": "Roll dice and return the results", - "operationId": "roll_dice", - "requestBody": { - "required": True, - "content": { - "application/json": { - "schema": {"$ref": "#/components/schemas/DiceRequest"} - } - }, - }, - } - } - }, - "components": { - "schemas": { - "DiceRequest": { - "type": "object", - "properties": { - "count": { - "type": "integer", - "default": 1, - "description": "Number of dice to roll", - }, - "sides": { - "type": "integer", - "default": 6, - "description": "Number of sides per die", - }, - }, - } - } - }, - } - - def test_transform_resolves_refs(self): - result = _transform_openapi_to_openai(self.DICE_ROLLER_SPEC) - assert result is not None - params = result["function"]["parameters"] - assert params["type"] == "object" - assert "count" in params["properties"] - assert "sides" in params["properties"] - assert "$ref" not in str(params) diff --git a/backend/omni/src/modai/modules/tools/module.py b/backend/omni/src/modai/modules/tools/module.py index bfdfec5..204de90 100644 --- a/backend/omni/src/modai/modules/tools/module.py +++ b/backend/omni/src/modai/modules/tools/module.py @@ -9,32 +9,61 @@ @dataclass(frozen=True) class ToolDefinition: - """A tool's metadata as returned by 
the Tool Registry.""" + """LLM-agnostic description of a tool. - url: str - method: str - openapi_spec: dict[str, Any] + Contains enough information to construct LLM tool calls but is not tied + to any specific LLM API format. Parameters are fully resolved (no $ref) + so they can be passed directly to any LLM. + """ + name: str + description: str + parameters: dict[str, Any] -class ToolRegistryModule(ModaiModule, ABC): + +class Tool(ABC): + """A tool with its LLM-agnostic definition and run capability. + + Implementations provide both the definition (used by LLMs to understand + and invoke the tool) and the ability to execute the tool with parameters + returned by the LLM. """ - Module Declaration for: Tool Registry (Plain Module) - Aggregates OpenAPI specs from all configured tools. + @property + @abstractmethod + def definition(self) -> ToolDefinition: + """The tool's LLM-agnostic definition (name, description, parameters).""" + pass + + @abstractmethod + async def run(self, params: dict[str, Any]) -> Any: + """Execute the tool with the given parameters. + + Args: + params: Parameters to pass to the tool, typically the arguments + returned by an LLM tool call. Callers may inject + additional transport-level properties using ``_``-prefixed + keys (e.g. ``_bearer_token``). These reserved keys must + be extracted and consumed by the implementation before + building the request payload — they are never forwarded + to the tool microservice as part of the JSON body. + + Returns: + The tool's result (implementation-specific). + """ + pass + - Each tool is an independent microservice that: - - Exposes an HTTP endpoint to trigger the tool (method chosen by the tool) - - Provides an OpenAPI spec describing all its endpoints and parameters +class ToolRegistryModule(ModaiModule, ABC): + """ + Module Declaration for: Tool Registry (Plain Module) - The registry fetches each tool's OpenAPI spec and returns them grouped - together (unmodified). 
+ Aggregates tools from all configured sources and provides lookup by name. Configuration: tools: list of dicts, each with: - "url": the full trigger endpoint URL of the tool microservice - "method": the HTTP method to invoke the tool (e.g. PUT, POST, GET) - The registry derives the base URL from "url" and appends - "/openapi.json" to fetch the spec. Example config: tools: @@ -48,25 +77,37 @@ def __init__(self, dependencies: ModuleDependencies, config: dict[str, Any]): super().__init__(dependencies, config) @abstractmethod - async def get_tools(self) -> list[ToolDefinition]: + async def get_tools( + self, predefined_params: dict[str, Any] | None = None + ) -> list[Tool]: """ - Returns all configured tool definitions. - - Each ToolDefinition contains the tool's trigger url, HTTP method, - and its full OpenAPI spec (unmodified). - - Unavailable tool services are omitted from the result with a - warning logged. + Returns all configured tools. + + Each Tool provides its definition and run capability. + Unavailable tool services are omitted with a warning logged. + + Args: + predefined_params: Optional dict of ``_``-prefixed keys whose + values are already known by the caller (e.g. + ``{"_session_id": "abc", "_bearer_token": "xyz"}``). + Implementations may use these to strip the corresponding + properties from tool definitions so the LLM is not asked to + supply values that are already available. """ pass @abstractmethod - async def get_tool_by_name(self, name: str) -> ToolDefinition | None: + async def get_tool_by_name( + self, name: str, predefined_params: dict[str, Any] | None = None + ) -> Tool | None: """ - Look up a tool by its function name (derived from operationId). + Look up a tool by its name. + + Returns the matching Tool if found, or None if not found. - Returns the matching ToolDefinition if found, - or None if the tool name is not found. + Args: + name: The tool's unique name (derived from OpenAPI ``operationId``). 
+ predefined_params: Same semantics as in :meth:`get_tools`. """ pass @@ -75,8 +116,8 @@ class ToolsWebModule(ModaiModule, ABC): """ Module Declaration for: Tools Web Module (Web Module) - Exposes GET /api/tools. Retrieves tool definitions from the Tool Registry - and returns them in a format suitable for the consumer (e.g. frontend, chat agent). + Exposes GET /api/tools. Retrieves tools from the Tool Registry and returns + their definitions in a format suitable for the consumer. """ def __init__(self, dependencies: ModuleDependencies, config: dict[str, Any]): @@ -87,7 +128,7 @@ def __init__(self, dependencies: ModuleDependencies, config: dict[str, Any]): @abstractmethod async def get_tools(self) -> dict[str, Any]: """ - Returns all available tools in a consumer-specific format. + Returns all available tool definitions in a consumer-specific format. The response must contain a "tools" key with a list of tool definitions. The exact structure of each tool definition is determined by the diff --git a/backend/omni/src/modai/modules/tools/tool_registry.py b/backend/omni/src/modai/modules/tools/tool_registry.py deleted file mode 100644 index ae388c2..0000000 --- a/backend/omni/src/modai/modules/tools/tool_registry.py +++ /dev/null @@ -1,92 +0,0 @@ -import logging -from typing import Any -from urllib.parse import urlparse - -import httpx - -from modai.module import ModuleDependencies -from modai.modules.tools.module import ( - ToolDefinition, - ToolRegistryModule, -) - -logger = logging.getLogger(__name__) - -HTTP_TIMEOUT_SECONDS = 10.0 - - -class HttpToolRegistryModule(ToolRegistryModule): - """ - Tool Registry implementation that fetches OpenAPI specs from - configured tool microservices over HTTP. 
- """ - - def __init__(self, dependencies: ModuleDependencies, config: dict[str, Any]): - super().__init__(dependencies, config) - self.tool_services: list[dict[str, str]] = config.get("tools", []) - - async def get_tools(self) -> list[ToolDefinition]: - tools: list[ToolDefinition] = [] - - async with httpx.AsyncClient(timeout=HTTP_TIMEOUT_SECONDS) as client: - for service in self.tool_services: - url = service["url"] - method = service["method"] - base_url = _derive_base_url(url) - spec = await _fetch_openapi_spec(client, base_url) - if spec is not None: - tools.append( - ToolDefinition(url=url, method=method, openapi_spec=spec) - ) - - return tools - - async def get_tool_by_name(self, name: str) -> ToolDefinition | None: - tools = await self.get_tools() - for tool in tools: - operation_id = _extract_operation_id(tool.openapi_spec) - if operation_id == name: - return tool - return None - - -def _extract_operation_id(spec: dict[str, Any]) -> str | None: - """Extract the operationId from the first operation in an OpenAPI spec.""" - paths = spec.get("paths", {}) - for _path, methods in paths.items(): - for _method, operation in methods.items(): - if isinstance(operation, dict) and "operationId" in operation: - return operation["operationId"] - return None - - -def _derive_base_url(trigger_url: str) -> str: - """Derive the service base URL from a full trigger endpoint URL. - - E.g. 
'http://calc:8000/calculate' -> 'http://calc:8000' - """ - parsed = urlparse(trigger_url) - return f"{parsed.scheme}://{parsed.netloc}" - - -async def _fetch_openapi_spec( - client: httpx.AsyncClient, base_url: str -) -> dict[str, Any] | None: - openapi_url = f"{base_url.rstrip('/')}/openapi.json" - try: - response = await client.get(openapi_url) - response.raise_for_status() - return response.json() - except httpx.HTTPStatusError as e: - logger.warning( - "Tool service %s returned HTTP %s", base_url, e.response.status_code - ) - return None - except httpx.RequestError as e: - logger.warning("Failed to reach tool service %s: %s", base_url, e) - return None - except Exception: - logger.warning( - "Unexpected error fetching spec from %s", base_url, exc_info=True - ) - return None diff --git a/backend/omni/src/modai/modules/tools/tool_registry_openapi.py b/backend/omni/src/modai/modules/tools/tool_registry_openapi.py new file mode 100644 index 0000000..ecad84d --- /dev/null +++ b/backend/omni/src/modai/modules/tools/tool_registry_openapi.py @@ -0,0 +1,316 @@ +import logging +import re +from typing import Any +from urllib.parse import urlparse + +import httpx + +from modai.module import ModuleDependencies +from modai.modules.http_client.module import HttpClientModule +from modai.modules.tools.module import Tool, ToolDefinition, ToolRegistryModule + +logger = logging.getLogger(__name__) + +HTTP_TIMEOUT_SECONDS = 10.0 +TOOL_HTTP_TIMEOUT_SECONDS = 30.0 + + +class _OpenAPITool(Tool): + """Tool backed by an OpenAPI microservice endpoint. + + Holds the tool's pre-built definition and invokes the microservice + over HTTP when ``run`` is called. 
+ """ + + def __init__( + self, + url: str, + method: str, + definition: ToolDefinition, + http_client_factory: HttpClientModule, + header_param_names: frozenset[str] = frozenset(), + ) -> None: + self._url = url + self._method = method + self._definition_val = definition + self._http_client_factory = http_client_factory + self._header_param_names = header_param_names + + @property + def definition(self) -> ToolDefinition: + return self._definition_val + + async def run(self, params: dict[str, Any]) -> Any: + """Invoke the tool microservice over HTTP with the given parameters. + + Extracts reserved metadata keys from ``params`` before sending the + request. Currently recognised keys: + + * ``_bearer_token`` — forwarded as the ``Authorization: Bearer`` + header; never included in the JSON request body. + + Path parameters present in the URL template (e.g. ``{user_id}``) are + substituted directly into the URL and removed from the request body. + + Header parameters declared in the OpenAPI spec (``in: header``) are + forwarded as-is HTTP headers and removed from the request body. 
+ """ + body = dict(params) + bearer_token = body.pop("_bearer_token", None) + headers: dict[str, str] = {} + if bearer_token: + headers["Authorization"] = f"Bearer {bearer_token}" + + for header_name in self._header_param_names: + if header_name in body: + headers[header_name] = str(body.pop(header_name)) + + url = self._url + for name in re.findall(r"\{(\w+)\}", url): + if name in body: + url = url.replace(f"{{{name}}}", str(body.pop(name))) + + async with self._http_client_factory.new( + timeout=TOOL_HTTP_TIMEOUT_SECONDS + ) as client: + response = await client.request( + method=self._method.upper(), + url=url, + json=body, + headers=headers, + ) + response.raise_for_status() + return response.text + + +class OpenAPIToolRegistryModule(ToolRegistryModule): + """ + Tool Registry that fetches OpenAPI specs from configured microservices + and creates Tool instances with HTTP-based run capability. + + Each Tool's definition (name, description, parameters) is extracted from + the service's OpenAPI spec. The run method makes an HTTP request to the + configured trigger endpoint. + + Configuration: + tools: list of dicts, each with: + - "url": the full trigger endpoint URL of the tool microservice + - "method": the HTTP method to invoke the tool (e.g. PUT, POST, GET) + The registry derives the base URL from "url" and appends + "/openapi.json" to fetch the spec. + + Module Dependencies: + http_client: an HttpClientModule used for all outbound HTTP requests. 
+ + Example config: + tools: + - url: http://calculator-service:8000/calculate + method: POST + - url: http://web-search-service:8000/search + method: PUT + """ + + def __init__(self, dependencies: ModuleDependencies, config: dict[str, Any]): + super().__init__(dependencies, config) + self.tool_services: list[dict[str, str]] = config.get("tools", []) + self._http_client: HttpClientModule = dependencies.get_module("http_client") # type: ignore[assignment] + + async def get_tools( + self, + predefined_params: dict[str, Any] | None = None, # noqa: ARG002 + ) -> list[Tool]: + tools: list[Tool] = [] + + async with self._http_client.new(timeout=HTTP_TIMEOUT_SECONDS) as client: + for service in self.tool_services: + url = service["url"] + method = service["method"] + base_url = _derive_base_url(url) + spec = await _fetch_openapi_spec(client, base_url) + if spec is None: + continue + definition, header_param_names = _build_tool_definition(spec) + if definition is None: + logger.warning( + "No valid operation with operationId found in spec from %s, skipping", + base_url, + ) + continue + tools.append( + _OpenAPITool( + url=url, + method=method, + definition=definition, + http_client_factory=self._http_client, + header_param_names=header_param_names, + ) + ) + + return tools + + async def get_tool_by_name( + self, + name: str, + predefined_params: dict[str, Any] | None = None, # noqa: ARG002 + ) -> Tool | None: + tools = await self.get_tools() + for tool in tools: + if tool.definition.name == name: + return tool + return None + + +def _build_tool_definition( + spec: dict[str, Any], +) -> tuple[ToolDefinition, frozenset[str]] | tuple[None, frozenset[str]]: + """Build a ToolDefinition from an OpenAPI spec. + + Extracts name (operationId), description (summary/description), and + parameters (request body schema with $ref fully resolved, path params and + header params merged in). 
+ + Returns a ``(ToolDefinition, header_param_names)`` tuple so callers know + which parameter names map to HTTP headers at invocation time. Returns + ``(None, frozenset())`` if no operation with an operationId is found. + """ + paths = spec.get("paths", {}) + for _path, methods in paths.items(): + for _method, operation in methods.items(): + if not isinstance(operation, dict) or "operationId" not in operation: + continue + name = operation["operationId"] + description = operation.get("summary") or operation.get("description", "") + parameters = _extract_parameters(operation, spec) + header_params = _extract_header_parameters(operation, spec) + header_param_names = frozenset(p["name"] for p in header_params) + return ToolDefinition( + name=name, description=description, parameters=parameters + ), header_param_names + return None, frozenset() + + +def _extract_parameters( + operation: dict[str, Any], spec: dict[str, Any] +) -> dict[str, Any]: + """Extract and resolve the parameter schema from an OpenAPI operation. + + Combines: + - Request body schema (``application/json``), with ``$ref`` pointers fully + resolved. + - Path parameters (``in: path``) merged as additional properties so the + LLM knows to supply them for URL substitution. + - Header parameters (``in: header``) merged as additional properties so + the LLM knows to supply them; at invocation time they are forwarded as + HTTP headers rather than in the request body. 
+ """ + request_body = operation.get("requestBody", {}) + content = request_body.get("content", {}) + json_content = content.get("application/json", {}) + schema = json_content.get("schema", {"type": "object", "properties": {}}) + resolved = _resolve_refs(schema, spec) + + extra_params = [ + *_extract_path_parameters(operation, spec), + *_extract_header_parameters(operation, spec), + ] + if not extra_params: + return resolved + + properties = dict(resolved.get("properties", {})) + required: list[str] = list(resolved.get("required", [])) + + for param in extra_params: + name = param["name"] + param_schema: dict[str, Any] = dict(param.get("schema", {"type": "string"})) + if param.get("description"): + param_schema["description"] = param["description"] + properties[name] = param_schema + if param.get("required", True) and name not in required: + required.append(name) + + result: dict[str, Any] = {**resolved, "properties": properties} + if required: + result["required"] = required + return result + + +def _extract_path_parameters( + operation: dict[str, Any], spec: dict[str, Any] +) -> list[dict[str, Any]]: + """Return resolved path parameters (``in: path``) from an OpenAPI operation.""" + return [ + _resolve_refs(param, spec) + for param in operation.get("parameters", []) + if _resolve_refs(param, spec).get("in") == "path" + ] + + +def _extract_header_parameters( + operation: dict[str, Any], spec: dict[str, Any] +) -> list[dict[str, Any]]: + """Return resolved header parameters (``in: header``) from an OpenAPI operation.""" + return [ + _resolve_refs(param, spec) + for param in operation.get("parameters", []) + if _resolve_refs(param, spec).get("in") == "header" + ] + + +def _resolve_refs(node: Any, spec: dict[str, Any]) -> Any: + """Recursively resolve all $ref pointers in a JSON Schema against the OpenAPI spec.""" + if isinstance(node, dict): + if "$ref" in node: + resolved = _follow_ref(node["$ref"], spec) + return _resolve_refs(resolved, spec) + return {key: 
_resolve_refs(value, spec) for key, value in node.items()} + if isinstance(node, list): + return [_resolve_refs(item, spec) for item in node] + return node + + +def _follow_ref(ref: str, spec: dict[str, Any]) -> dict[str, Any]: + """Follow a JSON Pointer reference like '#/components/schemas/Foo'.""" + if not ref.startswith("#/"): + logger.warning("Unsupported $ref format: %s", ref) + return {} + parts = ref.lstrip("#/").split("/") + current: Any = spec + for part in parts: + if isinstance(current, dict): + current = current.get(part) + else: + logger.warning("Could not resolve $ref path: %s", ref) + return {} + return current if isinstance(current, dict) else {} + + +def _derive_base_url(trigger_url: str) -> str: + """Derive the service base URL from a full trigger endpoint URL. + + E.g. 'http://calc:8000/calculate' -> 'http://calc:8000' + """ + parsed = urlparse(trigger_url) + return f"{parsed.scheme}://{parsed.netloc}" + + +async def _fetch_openapi_spec( + client: httpx.AsyncClient, base_url: str +) -> dict[str, Any] | None: + openapi_url = f"{base_url.rstrip('/')}/openapi.json" + try: + response = await client.request("GET", openapi_url) + response.raise_for_status() + return response.json() + except httpx.HTTPStatusError as e: + logger.warning( + "Tool service %s returned HTTP %s", base_url, e.response.status_code + ) + return None + except httpx.RequestError as e: + logger.warning("Failed to reach tool service %s: %s", base_url, e) + return None + except Exception: + logger.warning( + "Unexpected error fetching spec from %s", base_url, exc_info=True + ) + return None diff --git a/backend/omni/src/modai/modules/tools/tool_registry_predefined_vars.py b/backend/omni/src/modai/modules/tools/tool_registry_predefined_vars.py new file mode 100644 index 0000000..2eeb367 --- /dev/null +++ b/backend/omni/src/modai/modules/tools/tool_registry_predefined_vars.py @@ -0,0 +1,216 @@ +"""Composing tool registry that hides pre-known variables from tool definitions. 
+ +A caller of the tool registry may already have values for certain tool +parameters (e.g. ``_session_id``, ``_bearer_token``). By passing these as +``predefined_params`` to :meth:`get_tools` / :meth:`get_tool_by_name`, the +``PredefinedVariablesToolRegistryModule`` removes the corresponding properties +from each tool's definition so the LLM is never asked to supply them. + +By default a predefined key ``_session_id`` maps to the tool parameter named +``session_id`` (underscore stripped). A ``variable_mappings`` config section +lets you override this for any tool parameter whose name differs from the +predefined variable name — for example mapping ``X-Session-Id`` to +``session_id``. + +At invocation time predefined values are re-injected, translating each +prefixed predefined key to its tool parameter name, before delegating to the +inner tool so URL-path substitution and body serialisation work as normal. + +This module is a pure decorator/composite: it has no knowledge of URLs, +HTTP, or OpenAPI — all actual tool work is performed by the inner registry +supplied via the ``delegate_registry`` module dependency. +""" + +import logging +from typing import Any + +from modai.module import ModuleDependencies +from modai.modules.tools.module import Tool, ToolDefinition, ToolRegistryModule + +logger = logging.getLogger(__name__) + + +class _PredefinedVariablesTool(Tool): + """Wraps an inner Tool, hiding known variables from its public definition. + + ``translations`` is a mapping of ``tool_param_name → prefixed_predefined_key`` + (e.g. ``{"X-Session-Id": "_session_id"}``). In :meth:`run` each prefixed + key is popped from ``params`` and re-injected under its tool parameter name + before the call is forwarded to the inner tool. 
+ """ + + def __init__( + self, + inner: Tool, + translations: dict[str, str], + filtered_definition: ToolDefinition, + ) -> None: + self._inner = inner + self._translations = translations + self._filtered_definition = filtered_definition + + @property + def definition(self) -> ToolDefinition: + return self._filtered_definition + + async def run(self, params: dict[str, Any]) -> Any: + """Forward to inner tool, substituting predefined values into their tool param names.""" + translated = dict(params) + for tool_param, prefixed_key in self._translations.items(): + if prefixed_key in translated: + translated[tool_param] = translated.pop(prefixed_key) + return await self._inner.run(translated) + + +class PredefinedVariablesToolRegistryModule(ToolRegistryModule): + """Tool registry decorator that strips pre-supplied variables from definitions. + + Wraps another :class:`~modai.modules.tools.module.ToolRegistryModule` and + filters its tool definitions based on ``predefined_params`` passed by the + caller. By default a predefined key like ``_session_id`` hides the tool + parameter ``session_id`` (leading ``_`` stripped). The optional + ``variable_mappings`` config allows overriding this for tool parameters + whose names differ from the predefined variable name. + + At run time each hidden parameter is supplied from its predefined value + before delegating to the inner tool. + + Configuration: + variable_mappings: optional dict mapping tool parameter names to + predefined variable names (without the leading ``_``). Use this + when a tool parameter name differs from the predefined variable + name. + + Module Dependencies: + delegate_registry: the concrete :class:`ToolRegistryModule` that does + the actual spec-fetching and HTTP invocation. + + Example config.yaml:: + + modules: + openapi_registry: + type: OpenAPIToolRegistryModule + ... 
+ + tool_registry: + type: PredefinedVariablesToolRegistryModule + module_dependencies: + delegate_registry: openapi_registry + config: + variable_mappings: + X-Session-Id: session_id # _session_id fills X-Session-Id header + """ + + def __init__(self, dependencies: ModuleDependencies, config: dict[str, Any]): + super().__init__(dependencies, config) + self._inner_registry: ToolRegistryModule = dependencies.get_module( + "delegate_registry" + ) # type: ignore[assignment] + self._variable_mappings: dict[str, str] = config.get("variable_mappings", {}) + + async def get_tools( + self, predefined_params: dict[str, Any] | None = None + ) -> list[Tool]: + tools = await self._inner_registry.get_tools() + return [ + _wrap_tool(tool, predefined_params, self._variable_mappings) + for tool in tools + ] + + async def get_tool_by_name( + self, name: str, predefined_params: dict[str, Any] | None = None + ) -> Tool | None: + tool = await self._inner_registry.get_tool_by_name(name) + if tool is None: + return None + return _wrap_tool(tool, predefined_params, self._variable_mappings) + + +# --------------------------------------------------------------------------- +# Pure helper functions +# --------------------------------------------------------------------------- + + +def _build_translations( + definition: ToolDefinition, + predefined_params: dict[str, Any] | None, + variable_mappings: dict[str, str], +) -> dict[str, str]: + """Build a ``tool_param → prefixed_predefined_key`` map for *definition*. + + Only includes entries where: + - the prefixed predefined key is present in ``predefined_params``, AND + - the target tool parameter exists in the definition's schema properties. + + Direct mappings (``_session_id`` → ``session_id``) are derived + automatically from ``predefined_params``. ``variable_mappings`` entries + (``X-Session-Id: session_id``) override the direct mapping for the same + predefined variable so the value is routed to the correct tool parameter. 
+ """ + if not predefined_params: + return {} + + schema_properties = set(definition.parameters.get("properties", {}).keys()) + translations: dict[str, str] = {} + + # Direct: _session_id → session_id (when session_id is in the schema) + for prefixed_key in predefined_params: + if not prefixed_key.startswith("_"): + continue + var_name = prefixed_key[1:] + if var_name in schema_properties: + translations[var_name] = prefixed_key + + # Configured: X-Session-Id ← _session_id (overrides the direct mapping) + for tool_param, var_name in variable_mappings.items(): + prefixed_key = f"_{var_name}" + if prefixed_key not in predefined_params: + continue + if tool_param not in schema_properties: + continue + # Remove default direct mapping for var_name if it was added above + translations.pop(var_name, None) + translations[tool_param] = prefixed_key + + return translations + + +def _wrap_tool( + tool: Tool, + predefined_params: dict[str, Any] | None, + variable_mappings: dict[str, str], +) -> Tool: + """Return a filtered wrapper around *tool*, or *tool* itself if nothing to hide.""" + translations = _build_translations( + tool.definition, predefined_params, variable_mappings + ) + if not translations: + return tool + hidden = set(translations.keys()) + filtered_definition = _filter_definition(tool.definition, hidden) + return _PredefinedVariablesTool( + inner=tool, + translations=translations, + filtered_definition=filtered_definition, + ) + + +def _filter_definition( + definition: ToolDefinition, hidden_properties: set[str] +) -> ToolDefinition: + """Return a new :class:`ToolDefinition` with *hidden_properties* removed.""" + params = definition.parameters + new_properties = { + k: v + for k, v in params.get("properties", {}).items() + if k not in hidden_properties + } + new_required = [r for r in params.get("required", []) if r not in hidden_properties] + new_params: dict[str, Any] = {**params, "properties": new_properties} + if "required" in params: + 
new_params["required"] = new_required + return ToolDefinition( + name=definition.name, + description=definition.description, + parameters=new_params, + ) diff --git a/backend/omni/src/modai/modules/tools/tools_web_module.py b/backend/omni/src/modai/modules/tools/tools_web_module.py index ef42537..1b8bee2 100644 --- a/backend/omni/src/modai/modules/tools/tools_web_module.py +++ b/backend/omni/src/modai/modules/tools/tools_web_module.py @@ -2,7 +2,11 @@ from typing import Any from modai.module import ModuleDependencies -from modai.modules.tools.module import ToolRegistryModule, ToolsWebModule +from modai.modules.tools.module import ( + ToolDefinition, + ToolRegistryModule, + ToolsWebModule, +) logger = logging.getLogger(__name__) @@ -12,14 +16,14 @@ class OpenAIToolsWebModule(ToolsWebModule): ToolsWebModule implementation that returns tools in OpenAI function-calling format. - Transforms each tool's OpenAPI spec into the format expected by + Converts each tool's ToolDefinition into the format expected by the OpenAI Chat Completions API: { "type": "function", "function": { - "name": "", - "description": "", - "parameters": { }, + "name": "", + "description": "", + "parameters": { }, "strict": true } } @@ -27,95 +31,24 @@ class OpenAIToolsWebModule(ToolsWebModule): def __init__(self, dependencies: ModuleDependencies, config: dict[str, Any]): super().__init__(dependencies, config) - self.tool_registry: ToolRegistryModule = dependencies.get_module( "tool_registry" ) async def get_tools(self) -> dict[str, Any]: tools = await self.tool_registry.get_tools() - openai_tools = [] - for tool in tools: - openai_tool = _transform_openapi_to_openai(tool.openapi_spec) - if openai_tool is not None: - openai_tools.append(openai_tool) + openai_tools = [_to_openai_format(tool.definition) for tool in tools] return {"tools": openai_tools} -def _transform_openapi_to_openai(spec: dict[str, Any]) -> dict[str, Any] | None: - """ - Transform an OpenAPI spec into OpenAI function-calling 
format. - - Transformation rules: - - operationId → function.name - - summary (or description) → function.description - - Request body schema → function.parameters - - strict = True → function.strict (enables structured outputs) - """ - paths = spec.get("paths", {}) - for _path, methods in paths.items(): - for _method, operation in methods.items(): - if not isinstance(operation, dict) or "operationId" not in operation: - continue - - name = operation["operationId"] - description = operation.get("summary") or operation.get("description", "") - parameters = _extract_parameters(operation, spec) - - return { - "type": "function", - "function": { - "name": name, - "description": description, - "parameters": parameters, - "strict": True, - }, - } - - logger.warning( - "No operation with operationId found in spec: %s", spec.get("info", {}) - ) - return None - - -def _extract_parameters( - operation: dict[str, Any], spec: dict[str, Any] -) -> dict[str, Any]: - """Extract parameter schema from an OpenAPI operation's request body. - - Resolves any $ref references against the full OpenAPI spec so the - returned schema is fully inlined (OpenAI does not support $ref). 
- """ - request_body = operation.get("requestBody", {}) - content = request_body.get("content", {}) - json_content = content.get("application/json", {}) - schema = json_content.get("schema", {"type": "object", "properties": {}}) - return _resolve_refs(schema, spec) - - -def _resolve_refs(node: Any, spec: dict[str, Any]) -> Any: - """Recursively resolve all $ref pointers in a JSON Schema against the OpenAPI spec.""" - if isinstance(node, dict): - if "$ref" in node: - resolved = _follow_ref(node["$ref"], spec) - return _resolve_refs(resolved, spec) - return {key: _resolve_refs(value, spec) for key, value in node.items()} - if isinstance(node, list): - return [_resolve_refs(item, spec) for item in node] - return node - - -def _follow_ref(ref: str, spec: dict[str, Any]) -> dict[str, Any]: - """Follow a JSON Pointer reference like '#/components/schemas/Foo'.""" - if not ref.startswith("#/"): - logger.warning("Unsupported $ref format: %s", ref) - return {} - parts = ref.lstrip("#/").split("/") - current: Any = spec - for part in parts: - if isinstance(current, dict): - current = current.get(part) - else: - logger.warning("Could not resolve $ref path: %s", ref) - return {} - return current if isinstance(current, dict) else {} +def _to_openai_format(definition: ToolDefinition) -> dict[str, Any]: + """Convert a ToolDefinition to OpenAI function-calling format.""" + return { + "type": "function", + "function": { + "name": definition.name, + "description": definition.description, + "parameters": definition.parameters, + "strict": True, + }, + } diff --git a/backend/tools/dice-roller/.pre-commit-config.yaml b/backend/tools/dice-roller/.pre-commit-config.yaml new file mode 100644 index 0000000..a98440f --- /dev/null +++ b/backend/tools/dice-roller/.pre-commit-config.yaml @@ -0,0 +1,18 @@ +orphan: true + +repos: + - repo: local + hooks: + - id: ruff-format + name: ruff format check + language: system + entry: uv run ruff format --check . 
+ always_run: true + pass_filenames: false + + - id: ruff-check + name: ruff check + language: system + entry: uv run ruff check . + always_run: true + pass_filenames: false diff --git a/backend/tools/dice-roller/pyproject.toml b/backend/tools/dice-roller/pyproject.toml index cf2ab3c..53f2562 100644 --- a/backend/tools/dice-roller/pyproject.toml +++ b/backend/tools/dice-roller/pyproject.toml @@ -4,3 +4,8 @@ version = "1.0.0" description = "A showcase tool that rolls dice" requires-python = ">=3.12" dependencies = ["fastapi", "uvicorn"] + +[dependency-groups] +dev = [ + "ruff", +] diff --git a/e2e_tests/tests_omni_full/.pre-commit-config.yaml b/e2e_tests/tests_omni_full/.pre-commit-config.yaml new file mode 100644 index 0000000..9ef903d --- /dev/null +++ b/e2e_tests/tests_omni_full/.pre-commit-config.yaml @@ -0,0 +1,11 @@ +orphan: true + +repos: + - repo: local + hooks: + - id: biome-check + name: biome check + language: system + entry: pnpm check + always_run: true + pass_filenames: false diff --git a/e2e_tests/tests_omni_light/.pre-commit-config.yaml b/e2e_tests/tests_omni_light/.pre-commit-config.yaml new file mode 100644 index 0000000..9ef903d --- /dev/null +++ b/e2e_tests/tests_omni_light/.pre-commit-config.yaml @@ -0,0 +1,11 @@ +orphan: true + +repos: + - repo: local + hooks: + - id: biome-check + name: biome check + language: system + entry: pnpm check + always_run: true + pass_filenames: false diff --git a/frontend_omni/.pre-commit-config.yaml b/frontend_omni/.pre-commit-config.yaml new file mode 100644 index 0000000..9ef903d --- /dev/null +++ b/frontend_omni/.pre-commit-config.yaml @@ -0,0 +1,11 @@ +orphan: true + +repos: + - repo: local + hooks: + - id: biome-check + name: biome check + language: system + entry: pnpm check + always_run: true + pass_filenames: false