diff --git a/README.md b/README.md index dd62315..21689e4 100644 --- a/README.md +++ b/README.md @@ -1,13 +1,34 @@ # model-ledger -**The model inventory your regulator actually wants. Auto-discovered, dependency-traced, audit-ready.** +**Know what models you have deployed, where they run, what they depend on, and what changed.** [![License](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](LICENSE) [![Python](https://img.shields.io/badge/python-3.10+-blue.svg)](https://python.org) +[![PyPI](https://img.shields.io/pypi/v/model-ledger)](https://pypi.org/project/model-ledger/) --- -model-ledger automatically discovers models, rules, pipelines, and queues across your systems — then builds the dependency graph between them. +model-ledger is a model inventory for any company with deployed models. It discovers models across your platforms, maps the dependency graph, and tracks every change as an immutable event. Unlike model registries tied to a single platform (MLflow, SageMaker, W&B), model-ledger discovers across *all* of them — as one connected graph. + +## Quick Start + +**Talk to your inventory** — point Claude (or any MCP-compatible agent) at it: + +```bash +pip install model-ledger[mcp] +claude mcp add model-ledger -- model-ledger mcp --demo +``` + +``` +You: "what models are in my inventory?" +Claude: "7 models across 5 platforms. fraud_scoring was retrained + and deployed this week. Want me to dig into anything?" + +You: "if we deprecate customer_features, what breaks?" +Claude: "3 models consume it directly, 2 more transitively." +``` + +**Or use the Python SDK:** ```python from model_ledger import Ledger, DataNode @@ -34,23 +55,41 @@ graph LR style C fill:#FF9800,color:#fff,stroke:#F57C00 ``` -Unlike model registries that track ML models only, model-ledger tracks the *entire model risk ecosystem* — ETL pipelines, heuristic rules, scoring jobs, alert queues, and ML models — as one connected graph with a full audit trail. 
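The Quick Start's impact question ("what breaks?") comes down to walking a dependency graph built by matching output ports to input ports. Here is a self-contained sketch of that mechanism in plain Python, with illustrative node names (this shows the underlying idea, not the model-ledger API):

```python
from collections import defaultdict

# Each node declares the tables/topics it reads (inputs) and writes (outputs).
nodes = {
    "customer_features": {"inputs": [], "outputs": ["features"]},
    "fraud_scoring": {"inputs": ["features"], "outputs": ["scores"]},
    "fraud_alerts": {"inputs": ["scores"], "outputs": []},
}

# Connect: an edge exists wherever an output port name matches an input port name.
writers = defaultdict(list)
for name, node in nodes.items():
    for port in node["outputs"]:
        writers[port].append(name)

# Map each node to the upstream nodes it depends on.
edges = {name: [w for port in node["inputs"] for w in writers[port]]
         for name, node in nodes.items()}

def downstream(target):
    """Everything that (transitively) depends on `target` — the blast radius."""
    hit = [n for n, deps in edges.items() if target in deps]
    for n in list(hit):
        hit += [d for d in downstream(n) if d not in hit]
    return hit

print(downstream("customer_features"))  # ['fraud_scoring', 'fraud_alerts']
```

The shipped `connect()` additionally handles case-insensitive identifiers and shared-table discriminators, as described later in this README.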
- ## Install ```bash -pip install model-ledger # Core + SQLite backend -pip install model-ledger[snowflake] # + Snowflake backend -pip install model-ledger[rest] # + REST API connector -pip install model-ledger[github] # + GitHub connector -pip install model-ledger[all] # Everything +pip install model-ledger # Core — SDK + tools + CLI +pip install model-ledger[mcp] # + MCP server (for Claude Code / AI agents) +pip install model-ledger[rest-api] # + REST API (for frontends / dashboards) +pip install model-ledger[snowflake] # + Snowflake backend +pip install model-ledger[mcp,rest-api,snowflake] # Everything ``` ## How It Works ```mermaid graph TB - subgraph discover ["1. Discover"] + subgraph consumers ["Consumers"] + direction LR + AGENT["Claude / AI Agents
MCP"] + FRONT["Frontends
REST API"] + SCRIPT["Scripts / Notebooks
Python SDK"] + CLI_C["CLI
model-ledger"] + end + + subgraph tools ["Agent Protocol — 6 Consolidated Tools"] + direction LR + DISC["discover"] ~~~ REC["record"] ~~~ INV["investigate"] + QRY["query"] ~~~ TRC["trace"] ~~~ CHG["changelog"] + end + + subgraph sdk ["Ledger SDK"] + direction LR + REG["register()"] ~~~ RECD["record()"] ~~~ GET["get() / list()"] + HIST["history()"] ~~~ TRAC["trace()"] ~~~ CONN["connect()"] + end + + subgraph discover ["Discovery Sources"] direction LR DB["SQL databases"] --> F["sql_connector()"] API["REST APIs"] --> G["rest_connector()"] @@ -58,34 +97,99 @@ graph TB CUSTOM["Your platform"] --> I["SourceConnector protocol"] end - subgraph ledger ["2. Build Graph"] + subgraph backends ["Storage — Pluggable Backends"] direction LR - ADD["ledger.add()"] --> CON["ledger.connect()"] - CON --> |"match output ports
to input ports"| GRAPH["Dependency graph"] + JSON["JSON files
default"] + SQLITE["SQLite"] + SNOW["Snowflake"] + PLUG["Plugin
Postgres, GitHub, ..."] end - subgraph query ["3. Query"] - direction LR - TRACE["trace()"] ~~~ UP["upstream()"] ~~~ DOWN["downstream()"] ~~~ INV["inventory_at()"] - end + consumers --> tools + tools --> sdk + sdk --> discover + sdk --> backends + + style consumers fill:#F3E5F5,stroke:#7B1FA2,color:#4A148C + style tools fill:#E3F2FD,stroke:#1565C0,color:#0D47A1 + style sdk fill:#E1F5FE,stroke:#0277BD,color:#01579B + style discover fill:#E8F5E9,stroke:#2E7D32,color:#1B5E20 + style backends fill:#FFF3E0,stroke:#E65100,color:#BF360C +``` + +Every model is a **DataNode** with typed input and output ports. When an output port name matches an input port name, `connect()` creates the dependency edge automatically. Every mutation is recorded as an immutable **Snapshot** — an append-only event log that gives you full history and point-in-time reconstruction. + +## Agent Protocol + +Six consolidated tools designed for AI agents ([Anthropic's tool design guidance](https://www.anthropic.com/engineering/writing-tools-for-agents)). Each is a plain Python function with Pydantic I/O — usable via MCP, REST, CLI, or direct import. 
+ +| Tool | What it does | Scale | +|------|-------------|-------| +| **discover** | Add models from any source — scan platforms, import files, inline data | Bulk | +| **record** | Register a model or record an event with arbitrary metadata | Single | +| **investigate** | Deep dive — identity, merged metadata, recent events, dependencies | Single | +| **query** | Search and filter the inventory with pagination | Multi | +| **trace** | Dependency graph — upstream, downstream, impact analysis | Graph | +| **changelog** | What changed across the inventory in a time range | Multi | - discover --> ledger --> query +### Using the tools directly - style discover fill:#E3F2FD,stroke:#1565C0,color:#0D47A1 - style ledger fill:#E8F5E9,stroke:#2E7D32,color:#1B5E20 - style query fill:#FFF3E0,stroke:#E65100,color:#BF360C +```python +from model_ledger import Ledger, record, investigate, query +from model_ledger.tools.schemas import RecordInput, InvestigateInput, QueryInput +from model_ledger.graph.models import DataNode + +ledger = Ledger.from_sqlite("./inventory.db") + +# Register a model +record(RecordInput( + model_name="fraud_scoring", event="registered", + owner="risk-team", model_type="ml_model", + purpose="Real-time fraud detection", +), ledger) + +# Record an event with schema-free payload +record(RecordInput( + model_name="fraud_scoring", event="retrained", + payload={"accuracy": 0.94, "features_added": ["velocity_24h"]}, + actor="ml-pipeline", +), ledger) + +# Deep dive +result = investigate(InvestigateInput(model_name="fraud_scoring"), ledger) +result.metadata # {"accuracy": 0.94, "features_added": ["velocity_24h"]} +result.total_events # 2 + +# Search +models = query(QueryInput(text="fraud", model_type="ml_model"), ledger) +models.total # 1 +``` + +### MCP server + +```bash +model-ledger mcp # empty inventory +model-ledger mcp --demo # sample data +model-ledger mcp --backend sqlite --path ./inventory.db # SQLite +model-ledger mcp --backend json --path ./my-inventory # 
JSON files + +# Connect to Claude Code (one time) +claude mcp add model-ledger -- model-ledger mcp ``` -Every model is a **DataNode** with typed input and output ports. When an output port name matches an input port name, `connect()` creates the dependency edge automatically. +### REST API -Every mutation is recorded as an immutable **Snapshot** — an append-only event log. Nothing is deleted. This gives you a complete audit trail and point-in-time inventory reconstruction for any date. +```bash +model-ledger serve # start on port 8000 +model-ledger serve --demo --port 3001 # with sample data +``` + +Auto-generated OpenAPI docs at `/docs`. Endpoints: `POST /record`, `POST /discover`, `GET /query`, `GET /investigate/{name}`, `GET /trace/{name}`, `GET /changelog`, `GET /overview`. ## Discover Models From Your Systems ### SQL databases -Most discovery is "query a table, map rows to models." The `sql_connector` factory handles this without writing classes: - ```python from model_ledger import Ledger, sql_connector @@ -146,7 +250,7 @@ pipelines = github_connector( ### Custom connectors -For anything the factories don't cover, implement the `SourceConnector` protocol: +Implement the `SourceConnector` protocol for anything the factories don't cover: ```python class SageMakerConnector: @@ -162,18 +266,44 @@ class SageMakerConnector: ] ``` -## Persistent Storage +## Storage + +Storage-agnostic. Default is JSON files — human-readable, git-friendly, zero config. Upgrade when you need scale. 
```python from model_ledger import Ledger +from model_ledger.backends.json_files import JsonFileLedgerBackend -ledger = Ledger.from_sqlite("./inventory.db") # SQLite — zero infrastructure -ledger = Ledger.from_snowflake(connection, schema="DB.MODEL_LEDGER") # Snowflake — production scale -ledger = Ledger() # In-memory — testing -ledger = Ledger(my_custom_backend) # Custom — LedgerBackend protocol +ledger = Ledger(JsonFileLedgerBackend("./my-inventory")) # JSON files — default +ledger = Ledger.from_sqlite("./inventory.db") # SQLite — zero infrastructure +ledger = Ledger.from_snowflake(connection, schema="DB.MODEL_LEDGER") # Snowflake — production +ledger = Ledger() # In-memory — testing ``` -## Key Capabilities +JSON file layout — inspect, diff, and version-control your inventory: + +``` +my-inventory/ +├── models/ +│ ├── fraud_scoring.json +│ └── churn_predictor.json +├── snapshots/ +│ ├── a1b2c3d4.json +│ └── e5f6g7h8.json +└── tags/ + └── {model_hash}/ + └── v1.json +``` + +Add community backends via entry points: + +```toml +# pyproject.toml +[project.entry-points."model_ledger.backends"] +postgres = "my_package:PostgresBackend" +``` + +## Additional Capabilities ### Dependency tracing @@ -181,74 +311,55 @@ ledger = Ledger(my_custom_backend) # Custom ledger.trace("fraud_alerts") # Full pipeline path ledger.upstream("fraud_alerts") # Everything that feeds this ledger.downstream("segmentation") # Everything that depends on this -ledger.dependencies("fraud_alerts", direction="upstream") # Detailed with relationship info ``` ### Shared table disambiguation -When multiple models write to the same table, `DataPort` handles precision matching: +When multiple models write to the same table, `DataPort` schema matching handles precision: ```python from model_ledger import DataPort, DataNode -# Two models write to the same alert table with different model_name values DataNode("check_rules", outputs=[DataPort("alerts", model_name="checks")]) DataNode("card_rules", 
outputs=[DataPort("alerts", model_name="cards")])
-
-# This reader only connects to check_rules — model_name must match
 DataNode("check_queue", inputs=[DataPort("alerts", model_name="checks")])
+# check_queue connects to check_rules only — model_name must match
 ```

 ### Point-in-time inventory

 ```python
 from datetime import datetime
 inventory = ledger.inventory_at(datetime(2025, 12, 31))
 # Every model that was active on that date
 ```

-### Compliance validation
-
-Built-in profiles for major model risk regulations:
+### Compliance validation (plugin)
-| Profile | Regulation | Checks |
-|---------|-----------|--------|
-| `sr_11_7` | US Federal Reserve SR 11-7 | Validator independence, governance docs, validation schedule |
-| `eu_ai_act` | EU AI Act (2024/1689) | Risk classification, data governance, human oversight |
-| `nist_ai_rmf` | NIST AI RMF 1.0 | GOVERN, MAP, MEASURE, MANAGE functions |
+Built-in profiles for SR 11-7, EU AI Act, and NIST AI RMF. Add custom profiles for your organization's policies. See [validation docs](docs/) for details.

 ### Model introspection

-Extract metadata from fitted ML models:
-
-```python
-from model_ledger import introspect
-
-result = introspect(fitted_model)
-result.algorithm        # "XGBClassifier"
-result.features         # [FeatureInfo(name="velocity_30d", ...), ...]
-result.hyperparameters  # {"n_estimators": 50, "max_depth": 4}
-```
-
-Ships with sklearn, XGBoost, and LightGBM support. Add your own via the `Introspector` protocol.
+Extract metadata from fitted sklearn, XGBoost, and LightGBM models. Add custom introspectors via the `Introspector` protocol. See [introspection docs](docs/) for details.

 ## Design Principles

+- **Agents are the primary interface** — the MCP server is the product. SDK and CLI are still first-class, but the agent experience is what we optimize for.
+- **Fundamental, not specialized** — model inventory for any company with deployed models. Not tied to a specific regulatory framework or industry. 
- **Everything is a DataNode** — ML models, heuristic rules, ETL pipelines, alert queues. One abstraction. - **The graph builds itself** — declare inputs and outputs. Dependencies follow from port matching. -- **Schema-agnostic metadata** — `Snapshot.payload` is `dict[str, Any]`. The framework stores whatever your connectors discover. -- **Append-only audit trail** — every change is an immutable Snapshot. Full history, point-in-time queries. -- **Factory for the 80%, protocol for the 20%** — config-driven factories for common patterns, open protocols for anything custom. -- **Batteries included** — persistence, discovery, graph building, and compliance with zero infrastructure. +- **Schema-free payloads** — record whatever metadata matters. No schema to maintain, no migrations. +- **Change tracking is central** — every mutation is an immutable Snapshot. The inventory is a living event log. +- **Storage-agnostic** — JSON files, SQLite, Snowflake, or bring your own via the `LedgerBackend` protocol. ## For Organizations -model-ledger is designed as a core framework with lightweight organization-specific extensions. The OSS core handles graph building, storage, compliance, and the connector factories. Your internal package provides: +The OSS core handles discovery, graph building, change tracking, storage, and the agent protocol. 
Your internal package provides: -- **Connector configs** — point `sql_connector()` at your tables, `rest_connector()` at your APIs +- **Connector configs** — point factories at your tables and APIs - **Custom connectors** — for internal platforms the factories don't cover -- **Authentication** — your database/API credentials and auth wrappers -- **Additional compliance profiles** — OSFI E-23, PRA SS1/23, MAS AIRG, or internal policies +- **Authentication** — your credentials and auth wrappers +- **Custom backends** — Postgres, GitHub repos, or any storage via `LedgerBackend` protocol +- **Compliance profiles** — SR 11-7, EU AI Act, or your own internal policies (plugin-based) Your internal repo should be thin config and credentials, not reimplemented logic. diff --git a/docs/plans/2026-04-01-v040-datanode-graph.md b/docs/plans/2026-04-01-v040-datanode-graph.md deleted file mode 100644 index 5c8531a..0000000 --- a/docs/plans/2026-04-01-v040-datanode-graph.md +++ /dev/null @@ -1,713 +0,0 @@ -# v0.4.0 DataNode Graph — Implementation Plan - -> **For agentic workers:** REQUIRED SUB-SKILL: Use superpowers:subagent-driven-development (recommended) or superpowers:executing-plans to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking. - -**Goal:** Add `DataNode`, `DataPort`, and graph methods (`add`, `connect`, `trace`, `upstream`, `downstream`) to the Ledger class, enabling auto-discovery of model dependency graphs from port matching. - -**Architecture:** `DataNode` and `DataPort` are new data models in `graph/`. `DataPort` is a smart string — acts like a string by default, carries optional schema for shared-table discriminators. The `Ledger` class gets 5 new methods that decompose into existing `register()`, `record()`, and `link_dependency()` calls. No new storage — everything persists through the existing Snapshot/ModelRef infrastructure. 
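The plan's DataPort matching semantics (identifiers compared case-insensitively; schema values must agree when both ports carry the same key, with SQL LIKE `%` wildcards; a key present on only one side matches anything) can be condensed into a standalone sketch. `port_matches` is a hypothetical helper for illustration, not part of the planned API:

```python
import re

def port_matches(id_a, schema_a, id_b, schema_b):
    """Sketch of the DataPort matching rules (illustrative, not the shipped class)."""
    if id_a.lower() != id_b.lower():
        return False
    # Only keys present on BOTH sides constrain the match; a missing key matches anything.
    for key in schema_a.keys() & schema_b.keys():
        a, b = schema_a[key], schema_b[key]
        if "%" in a or "%" in b:  # SQL LIKE wildcard on either side
            pat, val = (a, b) if "%" in a else (b, a)
            if not re.fullmatch(re.escape(pat).replace("%", ".*"), val, re.IGNORECASE):
                return False
        elif a.lower() != b.lower():
            return False
    return True

port_matches("SHARED", {"model_name": "tm-m2o"}, "shared", {})                        # True: no schema means match any
port_matches("shared", {"model_name": "tm-%"}, "shared", {"model_name": "tm-m2o"})    # True: LIKE pattern
port_matches("shared", {"model_name": "tm-m2o"}, "shared", {"model_name": "tm-o2m"})  # False: discriminators differ
```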
- -**Tech Stack:** Python 3.10+, Pydantic (existing), pytest - ---- - -## File Map - -| Action | Path | Responsibility | -|--------|------|----------------| -| Create | `src/model_ledger/graph/__init__.py` | Package init, re-exports | -| Create | `src/model_ledger/graph/models.py` | `DataNode`, `DataPort` | -| Create | `src/model_ledger/graph/protocol.py` | `SourceConnector` protocol | -| Modify | `src/model_ledger/sdk/ledger.py` | Add `add()`, `connect()`, `trace()`, `upstream()`, `downstream()` | -| Modify | `src/model_ledger/__init__.py` | Export `DataNode`, `DataPort` | -| Create | `tests/test_graph/__init__.py` | Test package | -| Create | `tests/test_graph/test_models.py` | DataNode and DataPort tests | -| Create | `tests/test_graph/test_ledger_graph.py` | Ledger graph method tests | - ---- - -### Task 1: DataPort - -**Files:** -- Create: `src/model_ledger/graph/__init__.py` -- Create: `src/model_ledger/graph/models.py` -- Create: `tests/test_graph/__init__.py` -- Create: `tests/test_graph/test_models.py` - -- [ ] **Step 1: Write the failing test** - -```python -# tests/test_graph/__init__.py -# (empty) -``` - -```python -# tests/test_graph/test_models.py -"""Tests for DataNode and DataPort.""" - -from model_ledger.graph.models import DataPort, DataNode - - -class TestDataPort: - def test_create_from_string_identifier(self): - p = DataPort("my_table") - assert p.identifier == "my_table" - assert p.schema == {} - - def test_lowercases_identifier(self): - p = DataPort("APP_COMPLIANCE.CASH.TABLE") - assert p.identifier == "app_compliance.cash.table" - - def test_equality_same_identifier(self): - assert DataPort("table_a") == DataPort("table_a") - - def test_equality_case_insensitive(self): - assert DataPort("TABLE_A") == DataPort("table_a") - - def test_inequality_different_identifier(self): - assert DataPort("table_a") != DataPort("table_b") - - def test_equality_with_matching_schema(self): - a = DataPort("shared_table", model_name="tm-m2o") - b = 
DataPort("shared_table", model_name="tm-m2o") - assert a == b - - def test_inequality_with_different_schema(self): - a = DataPort("shared_table", model_name="tm-m2o") - b = DataPort("shared_table", model_name="tm-o2m") - assert a != b - - def test_equality_one_has_no_schema(self): - a = DataPort("shared_table", model_name="tm-m2o") - b = DataPort("shared_table") - assert a == b # no schema on b means "match any" - - def test_schema_like_pattern(self): - a = DataPort("scores", model_name="tm-%") - b = DataPort("scores", model_name="tm-m2o") - assert a == b - - def test_schema_like_pattern_no_match(self): - a = DataPort("scores", model_name="tm-%") - b = DataPort("scores", model_name="uup-gambling") - assert a != b - - def test_hashable(self): - s = {DataPort("a"), DataPort("a"), DataPort("b")} - assert len(s) == 2 - - def test_repr(self): - p = DataPort("table", model_name="x") - assert "table" in repr(p) - assert "model_name" in repr(p) - - def test_repr_simple(self): - p = DataPort("table") - assert "table" in repr(p) - - -class TestDataNode: - def test_create_with_string_inputs(self): - node = DataNode("scorer", inputs=["features", "segments"], outputs=["scores"]) - assert len(node.inputs) == 2 - assert all(isinstance(p, DataPort) for p in node.inputs) - assert node.inputs[0].identifier == "features" - - def test_create_with_dataport_inputs(self): - node = DataNode("scorer", inputs=[DataPort("features")], outputs=["scores"]) - assert node.inputs[0].identifier == "features" - - def test_create_mixed_inputs(self): - node = DataNode("scorer", - inputs=["features", DataPort("scores", model_name="tm-m2o")], - outputs=["alerts"]) - assert node.inputs[0].identifier == "features" - assert node.inputs[1].schema == {"model_name": "tm-m2o"} - - def test_defaults(self): - node = DataNode("simple") - assert node.platform == "" - assert node.inputs == [] - assert node.outputs == [] - assert node.metadata == {} - - def test_metadata(self): - node = DataNode("scorer", 
platform="ml-serving", - metadata={"owner": "ml-team", "version": "v3"}) - assert node.metadata["owner"] == "ml-team" -``` - -- [ ] **Step 2: Run test to verify it fails** - -Run: `cd ~/Development/model-ledger && uv run pytest tests/test_graph/test_models.py -v` -Expected: FAIL — `ModuleNotFoundError` - -- [ ] **Step 3: Write implementation** - -```python -# src/model_ledger/graph/__init__.py -"""Graph-based model discovery — DataNode, DataPort, SourceConnector.""" - -from model_ledger.graph.models import DataNode, DataPort - -__all__ = ["DataNode", "DataPort"] -``` - -```python -# src/model_ledger/graph/models.py -"""DataNode and DataPort — the core graph primitives.""" - -from __future__ import annotations - -import re -from dataclasses import dataclass, field -from typing import Any - - -class DataPort: - """A connection point. Acts like a string, carries optional schema for precision. - - Simple case — just an identifier: - DataPort("transactions_table") - - With discriminator — for shared tables: - DataPort("batch_scores_archive", model_name="tm-m2o") - - Matching: - - Two ports match if identifiers are equal (case-insensitive) - - If both have a schema key, values must match (supports SQL LIKE %) - - If only one has a schema key, it matches anything - """ - - __slots__ = ("identifier", "schema") - - def __init__(self, identifier: str, **schema: str) -> None: - self.identifier = identifier.lower() - self.schema = schema - - def __eq__(self, other: object) -> bool: - if isinstance(other, str): - return self.identifier == other.lower() - if not isinstance(other, DataPort): - return NotImplemented - if self.identifier != other.identifier: - return False - for key in set(self.schema) & set(other.schema): - if not _value_matches(self.schema[key], other.schema[key]): - return False - return True - - def __hash__(self) -> int: - return hash(self.identifier) - - def __repr__(self) -> str: - if self.schema: - params = ", ".join(f"{k}={v!r}" for k, v in 
self.schema.items()) - return f"DataPort({self.identifier!r}, {params})" - return f"DataPort({self.identifier!r})" - - -def _value_matches(a: str, b: str) -> bool: - """Match values, supporting SQL LIKE patterns (% wildcard).""" - if "%" in a: - pattern = "^" + re.escape(a).replace("%", ".*") + "$" - return bool(re.match(pattern, b, re.IGNORECASE)) - if "%" in b: - pattern = "^" + re.escape(b).replace("%", ".*") + "$" - return bool(re.match(pattern, a, re.IGNORECASE)) - return a.lower() == b.lower() - - -@dataclass -class DataNode: - """A model, job, rule, or workflow that transforms data. - - Inputs and outputs can be strings (auto-wrapped as DataPort) - or DataPort objects (for shared-table discriminators). - - Example: - >>> node = DataNode("fraud_scorer", platform="ml-serving", - ... inputs=["features", "segments"], - ... outputs=["scores"]) - """ - - name: str - platform: str = "" - inputs: list[DataPort] = field(default_factory=list) - outputs: list[DataPort] = field(default_factory=list) - metadata: dict[str, Any] = field(default_factory=dict) - - def __post_init__(self) -> None: - self.inputs = [DataPort(x) if isinstance(x, str) else x for x in self.inputs] - self.outputs = [DataPort(x) if isinstance(x, str) else x for x in self.outputs] -``` - -- [ ] **Step 4: Run tests** - -Run: `cd ~/Development/model-ledger && uv run pytest tests/test_graph/test_models.py -v` -Expected: PASS (15 tests) - -- [ ] **Step 5: Commit** - -```bash -cd ~/Development/model-ledger -git add src/model_ledger/graph/ tests/test_graph/ -git commit -m "feat: add DataNode and DataPort graph primitives" -``` - ---- - -### Task 2: SourceConnector Protocol - -**Files:** -- Create: `src/model_ledger/graph/protocol.py` - -- [ ] **Step 1: Write implementation** - -```python -# src/model_ledger/graph/protocol.py -"""SourceConnector protocol — the extension point for platform discovery.""" - -from __future__ import annotations - -from typing import Protocol, runtime_checkable - -from 
model_ledger.graph.models import DataNode - - -@runtime_checkable -class SourceConnector(Protocol): - """Discovers DataNodes from a platform. - - Implement this to connect a new data source (Snowflake, SageMaker, - Airflow, etc.) to the model-ledger graph. - - Example: - class MyConnector: - name = "my_platform" - def discover(self) -> list[DataNode]: - return [DataNode("model_a", inputs=["table_x"], outputs=["scores"])] - """ - - name: str - - def discover(self) -> list[DataNode]: ... -``` - -- [ ] **Step 2: Update graph __init__.py** - -```python -# src/model_ledger/graph/__init__.py -"""Graph-based model discovery — DataNode, DataPort, SourceConnector.""" - -from model_ledger.graph.models import DataNode, DataPort -from model_ledger.graph.protocol import SourceConnector - -__all__ = ["DataNode", "DataPort", "SourceConnector"] -``` - -- [ ] **Step 3: Commit** - -```bash -cd ~/Development/model-ledger -git add src/model_ledger/graph/protocol.py src/model_ledger/graph/__init__.py -git commit -m "feat: add SourceConnector protocol" -``` - ---- - -### Task 3: Ledger Graph Methods - -**Files:** -- Modify: `src/model_ledger/sdk/ledger.py` -- Create: `tests/test_graph/test_ledger_graph.py` - -- [ ] **Step 1: Write the failing test** - -```python -# tests/test_graph/test_ledger_graph.py -"""Tests for Ledger graph methods — add, connect, trace, upstream, downstream.""" - -import pytest - -from model_ledger.graph.models import DataNode, DataPort -from model_ledger.sdk.ledger import Ledger - - -@pytest.fixture -def ledger(): - return Ledger() - - -class TestAdd: - def test_add_single_node(self, ledger): - node = DataNode("scorer", platform="ml-serving", - inputs=["features"], outputs=["scores"]) - ledger.add(node) - model = ledger.get("scorer") - assert model.name == "scorer" - - def test_add_list_of_nodes(self, ledger): - nodes = [ - DataNode("a", platform="x", inputs=[], outputs=["t1"]), - DataNode("b", platform="x", inputs=["t1"], outputs=[]), - ] - ledger.add(nodes) - 
assert len(ledger.list()) == 2 - - def test_add_creates_discovered_snapshot(self, ledger): - node = DataNode("scorer", platform="ml-serving", - inputs=["features"], outputs=["scores"], - metadata={"owner": "ml-team"}) - ledger.add(node) - snaps = ledger.history("scorer") - assert len(snaps) >= 1 - discovered = [s for s in snaps if s.event_type == "discovered"] - assert len(discovered) == 1 - assert discovered[0].payload["platform"] == "ml-serving" - assert discovered[0].payload["inputs"] == [{"identifier": "features"}] - - def test_add_idempotent(self, ledger): - node = DataNode("scorer", platform="ml-serving", outputs=["scores"]) - ledger.add(node) - ledger.add(node) - assert len(ledger.list()) == 1 - - -class TestConnect: - def test_connect_matching_ports(self, ledger): - ledger.add([ - DataNode("writer", outputs=["shared_table"]), - DataNode("reader", inputs=["shared_table"]), - ]) - result = ledger.connect() - assert result["links_created"] >= 1 - - deps = ledger.dependencies("reader", direction="upstream") - assert any(d["model"].name == "writer" for d in deps) - - def test_connect_no_match(self, ledger): - ledger.add([ - DataNode("a", outputs=["table_x"]), - DataNode("b", inputs=["table_y"]), - ]) - result = ledger.connect() - assert result["links_created"] == 0 - - def test_connect_skips_self_refs(self, ledger): - ledger.add(DataNode("a", inputs=["t"], outputs=["t"])) - result = ledger.connect() - assert result["links_created"] == 0 - - def test_connect_shared_table_with_discriminator(self, ledger): - ledger.add([ - DataNode("writer_a", outputs=[DataPort("shared", model_name="model_a")]), - DataNode("writer_b", outputs=[DataPort("shared", model_name="model_b")]), - DataNode("reader_a", inputs=[DataPort("shared", model_name="model_a")]), - ]) - result = ledger.connect() - deps = ledger.dependencies("reader_a", direction="upstream") - upstream_names = [d["model"].name for d in deps] - assert "writer_a" in upstream_names - assert "writer_b" not in 
upstream_names - - def test_connect_pipeline(self, ledger): - ledger.add([ - DataNode("segmentation", outputs=["segments"]), - DataNode("scoring", inputs=["segments"], outputs=["scores"]), - DataNode("alerting", inputs=["scores"], outputs=["alerts"]), - ]) - ledger.connect() - # alerting depends on scoring depends on segmentation - deps = ledger.dependencies("alerting", direction="upstream") - assert any(d["model"].name == "scoring" for d in deps) - deps2 = ledger.dependencies("scoring", direction="upstream") - assert any(d["model"].name == "segmentation" for d in deps2) - - -class TestTrace: - def test_trace_returns_ordered_pipeline(self, ledger): - ledger.add([ - DataNode("seg", outputs=["segments"]), - DataNode("score", inputs=["segments"], outputs=["scores"]), - DataNode("alert", inputs=["scores"]), - ]) - ledger.connect() - pipeline = ledger.trace("alert") - assert pipeline == ["seg", "score", "alert"] - - def test_trace_single_node(self, ledger): - ledger.add(DataNode("standalone")) - ledger.connect() - assert ledger.trace("standalone") == ["standalone"] - - def test_trace_not_found(self, ledger): - with pytest.raises(Exception): - ledger.trace("nonexistent") - - -class TestUpstreamDownstream: - def test_upstream(self, ledger): - ledger.add([ - DataNode("a", outputs=["t1"]), - DataNode("b", inputs=["t1"], outputs=["t2"]), - DataNode("c", inputs=["t2"]), - ]) - ledger.connect() - assert "a" in ledger.upstream("b") - assert "a" in ledger.upstream("c") - assert "b" in ledger.upstream("c") - - def test_downstream(self, ledger): - ledger.add([ - DataNode("a", outputs=["t1"]), - DataNode("b", inputs=["t1"], outputs=["t2"]), - DataNode("c", inputs=["t2"]), - ]) - ledger.connect() - assert "b" in ledger.downstream("a") - assert "c" in ledger.downstream("a") - - def test_upstream_empty(self, ledger): - ledger.add(DataNode("root", outputs=["t1"])) - ledger.connect() - assert ledger.upstream("root") == [] -``` - -- [ ] **Step 2: Run test to verify it fails** - -Run: `cd 
~/Development/model-ledger && uv run pytest tests/test_graph/test_ledger_graph.py -v` -Expected: FAIL — `AttributeError: 'Ledger' object has no attribute 'add'` - -- [ ] **Step 3: Write implementation** - -Add these methods to `src/model_ledger/sdk/ledger.py`, at the end of the `Ledger` class: - -```python - # --- Graph methods (v0.4.0) --- - - def add(self, nodes: DataNode | list[DataNode]) -> None: - """Register DataNodes. Each becomes a ModelRef + discovered Snapshot.""" - from model_ledger.graph.models import DataNode, DataPort - - if isinstance(nodes, DataNode): - nodes = [nodes] - for node in nodes: - self.register( - name=node.name, - owner=node.metadata.get("owner", "unknown"), - model_type=node.metadata.get("node_type", "unknown"), - tier=node.metadata.get("tier", "unclassified"), - purpose=node.metadata.get("purpose", ""), - model_origin=node.metadata.get("model_origin", "internal"), - actor=f"connector:{node.platform}" if node.platform else "system", - ) - self.record( - node.name, - event="discovered", - payload={ - "platform": node.platform, - "inputs": [ - {"identifier": p.identifier, **p.schema} - for p in node.inputs - ], - "outputs": [ - {"identifier": p.identifier, **p.schema} - for p in node.outputs - ], - **{k: v for k, v in node.metadata.items() - if k not in ("owner", "node_type", "tier", "purpose", "model_origin")}, - }, - actor=f"connector:{node.platform}" if node.platform else "system", - ) - - def connect(self) -> dict[str, int]: - """Match output ports to input ports. 
Write dependency links."""
-        from collections import defaultdict
-        from model_ledger.graph.models import DataNode, DataPort
-
-        nodes = self._load_discovered_nodes()
-        output_index: dict[str, list[tuple[DataNode, DataPort]]] = defaultdict(list)
-        for node in nodes:
-            for port in node.outputs:
-                output_index[port.identifier].append((node, port))
-
-        links_created = 0
-        seen: set[tuple[str, str]] = set()
-        for node in nodes:
-            for in_port in node.inputs:
-                for upstream_node, out_port in output_index.get(in_port.identifier, []):
-                    if upstream_node.name == node.name:
-                        continue
-                    if out_port != in_port:
-                        continue
-                    key = (upstream_node.name, node.name)
-                    if key in seen:
-                        continue
-                    seen.add(key)
-                    try:
-                        self.link_dependency(
-                            upstream=upstream_node.name,
-                            downstream=node.name,
-                            relationship="data_flow",
-                            actor="graph_builder",
-                            metadata={
-                                "via": in_port.identifier,
-                                "via_schema": in_port.schema or None,
-                            },
-                        )
-                        links_created += 1
-                    except ModelNotFoundError:
-                        continue
-        return {"links_created": links_created}
-
-    def trace(self, name: str) -> list[str]:
-        """Topological path from sources to this node."""
-        self._resolve_model(name)  # raises if not found
-        visited: set[str] = set()
-        order: list[str] = []
-
-        def _walk(n: str) -> None:
-            if n in visited:
-                return
-            visited.add(n)
-            for dep in self.dependencies(n, direction="upstream"):
-                _walk(dep["model"].name)
-            order.append(n)
-
-        _walk(name)
-        return order
-
-    def upstream(self, name: str) -> list[str]:
-        """All models this one depends on (transitive)."""
-        path = self.trace(name)
-        return [n for n in path if n != name]
-
-    def downstream(self, name: str) -> list[str]:
-        """All models that depend on this one (transitive)."""
-        self._resolve_model(name)
-        visited: set[str] = set()
-        result: list[str] = []
-
-        def _walk(n: str) -> None:
-            for dep in self.dependencies(n, direction="downstream"):
-                child = dep["model"].name
-                if child not in visited:
-                    visited.add(child)
-                    result.append(child)
-                    _walk(child)
-
-        _walk(name)
-        return result
-
-    def _load_discovered_nodes(self) -> list[DataNode]:
-        """Rebuild DataNodes from stored discovery snapshots."""
-        from model_ledger.graph.models import DataNode, DataPort
-
-        nodes = []
-        for model in self._backend.list_models():
-            snaps = self._backend.list_snapshots(model.model_hash)
-            discovered = [s for s in snaps if s.event_type == "discovered"]
-            if not discovered:
-                continue
-            latest = max(discovered, key=lambda s: s.timestamp)
-            payload = latest.payload
-            inputs = [
-                DataPort(p["identifier"], **{k: v for k, v in p.items() if k != "identifier"})
-                for p in payload.get("inputs", [])
-            ]
-            outputs = [
-                DataPort(p["identifier"], **{k: v for k, v in p.items() if k != "identifier"})
-                for p in payload.get("outputs", [])
-            ]
-            nodes.append(DataNode(
-                name=model.name,
-                platform=payload.get("platform", ""),
-                inputs=inputs,
-                outputs=outputs,
-                metadata={k: v for k, v in payload.items()
-                          if k not in ("platform", "inputs", "outputs")},
-            ))
-        return nodes
-```
-
-Also add the import at the top of the file (after existing imports):
-
-```python
-from __future__ import annotations
-
-from typing import TYPE_CHECKING
-# ... existing imports ...
- -if TYPE_CHECKING: - from model_ledger.graph.models import DataNode -``` - -- [ ] **Step 4: Run tests** - -Run: `cd ~/Development/model-ledger && uv run pytest tests/test_graph/ -v` -Expected: PASS (all tests) - -- [ ] **Step 5: Run existing tests to verify no regressions** - -Run: `cd ~/Development/model-ledger && uv run pytest tests/ -v` -Expected: ALL PASS - -- [ ] **Step 6: Commit** - -```bash -cd ~/Development/model-ledger -git add src/model_ledger/sdk/ledger.py tests/test_graph/test_ledger_graph.py -git commit -m "feat: add graph methods to Ledger (add, connect, trace, upstream, downstream)" -``` - ---- - -### Task 4: Public API Exports - -**Files:** -- Modify: `src/model_ledger/__init__.py` - -- [ ] **Step 1: Update exports** - -Add to the imports: -```python -from model_ledger.graph.models import DataNode, DataPort -from model_ledger.graph.protocol import SourceConnector -``` - -Add to `__all__`: -```python - # v0.4.0 — graph - "DataNode", - "DataPort", - "SourceConnector", -``` - -Update version: -```python -__version__ = "0.4.0" -``` - -- [ ] **Step 2: Verify imports work** - -Run: `cd ~/Development/model-ledger && uv run python -c "from model_ledger import Ledger, DataNode, DataPort, SourceConnector; print('OK')"` - -- [ ] **Step 3: Commit** - -```bash -cd ~/Development/model-ledger -git add src/model_ledger/__init__.py -git commit -m "feat: export DataNode, DataPort, SourceConnector — model-ledger v0.4.0" -``` - ---- - -## Summary - -| Task | Component | Tests | -|------|-----------|-------| -| 1 | DataPort + DataNode | 15 | -| 2 | SourceConnector protocol | 0 | -| 3 | Ledger graph methods | 14 | -| 4 | Public exports | 0 (verification) | -| **Total** | | **29 tests** | diff --git a/docs/technical-design.md b/docs/technical-design.md deleted file mode 100644 index 64e46b5..0000000 --- a/docs/technical-design.md +++ /dev/null @@ -1,179 +0,0 @@ -# model-ledger — Technical Design - -A typed, event-sourced model inventory with pluggable storage, 
multi-platform scanning, dependency tracking, and executable compliance profiles. - -**Author**: Vignesh Narayanaswamy -**Version**: 0.3.0 -**Prerequisite**: [What & Why](what-and-why.md) covers the motivation and strategic context. - ---- - -## Overview - -model-ledger has six layers: - -``` -┌──────────────────────────────────────────────────────┐ -│ Export audit packs, gap reports │ -├──────────────────────────────────────────────────────┤ -│ Validation SR 11-7, EU AI Act, NIST AI RMF │ -├──────────────────────────────────────────────────────┤ -│ Scanner Scanner protocol, InventoryScanner, │ -│ ScannerRegistry, DBConnection │ -├──────────────────────────────────────────────────────┤ -│ SDK Ledger (register, record, tag, │ -│ link_dependency, inventory_at) │ -├──────────────────────────────────────────────────────┤ -│ Core ModelRef, Snapshot, Tag, exceptions │ -├──────────────────────────────────────────────────────┤ -│ Storage LedgerBackend protocol, InMemory │ -└──────────────────────────────────────────────────────┘ -``` - -Organization-specific scanners and backends sit alongside, not above — they depend on `model-ledger`, never the reverse. - ---- - -## Core Data Model - -### ModelRef — Regulatory Identity - -The minimum a regulator needs to know about a model's existence. - -```python -class ModelRef(BaseModel): - model_hash: str # sha256(name + owner + created_at)[:32] - name: str # human-readable label - owner: str # accountable team or individual - model_type: str # "ml_model", "heuristic", "signal", "vendor", "llm" - model_origin: str # "internal", "vendor", "api", "open_source" - tier: str # risk tier - purpose: str # what the model does - status: str # "active", "retired", "draft" - created_at: datetime # when first registered -``` - -### Snapshot — Immutable Observation - -Content-addressed, timestamped record of something that happened to or was observed about a model. 
- -```python -class Snapshot(BaseModel): - snapshot_hash: str # sha256(model_hash + timestamp + payload)[:32] - model_hash: str # which model this is about - parent_hash: str | None # chain snapshots together - timestamp: datetime # when this was recorded - actor: str # who/what created this (human, scanner, CI) - event_type: str # "registered", "discovered", "scan_confirmed", - # "not_found", "enriched", "depends_on", etc. - source: str | None # which scanner/system provided this - payload: dict[str, Any] # schema-free — scanner metadata, enrichment, etc. - tags: dict[str, str] # arbitrary key-value metadata -``` - -### Tag — Mutable Pointer - -Like a git tag or branch — a named pointer to a specific Snapshot. - -```python -class Tag(BaseModel): - name: str # "latest", "v3", "prod" - model_hash: str - snapshot_hash: str - updated_at: datetime -``` - ---- - -## Ledger SDK - -Every method is tool-shaped: clear inputs, JSON-serializable outputs, no side effects beyond the ledger. - -| Method | Purpose | -|---|---| -| `register(name, owner, ...)` | Create a ModelRef + "registered" Snapshot | -| `record(model, event, payload, actor)` | Append an immutable Snapshot | -| `tag(model, name)` | Point a Tag at the latest Snapshot | -| `get(name_or_hash)` | Retrieve a ModelRef | -| `list(**filters)` | Filter models by any field | -| `history(model)` | All Snapshots, newest first | -| `latest(model, tag?)` | Most recent Snapshot (or tagged one) | -| `link_dependency(upstream, downstream)` | Bidirectional dependency Snapshots | -| `dependencies(model, direction)` | Query dependency graph | -| `inventory_at(date, platform?)` | Point-in-time reconstruction | - ---- - -## Scanner Architecture - -### Scanner Protocol - -```python -class Scanner(Protocol): - name: str - platform_type: str - def scan(self) -> list[ModelCandidate]: ... - def has_changed(self, last_scan: datetime) -> bool: ... 
- -class EnrichableScanner(Scanner, Protocol): - def enrich(self, candidate: ModelCandidate) -> dict: ... -``` - -### ModelCandidate - -```python -class ModelCandidate(BaseModel): - name: str - owner: str | None - model_type: str - platform: str - platform_id: str | None - parent_name: str | None # hierarchy support - external_ids: dict[str, str] # cross-platform dedup - metadata: dict[str, Any] -``` - -### InventoryScanner - -Orchestrates multiple scanners with: -- **filter_fn** — post-scan, pre-registration filtering -- **scan_run_id** — groups all snapshots from one scan run -- **not_found tracking** — records when models disappear from a platform -- **has_changed** — skips scan if platform hasn't changed -- **enrich** — calls EnrichableScanner.enrich() and records results - -### ScannerRegistry - -Discovers scanners via `importlib.metadata.entry_points(group="model_ledger.scanners")`. Install a scanner package, it auto-registers. - -### DBConnection Protocol - -```python -class DBConnection(Protocol): - def execute(self, query: str, params: dict | None = None) -> list[dict]: ... -``` - -Thin abstraction for SQL-based scanners. Any database client implements this. - ---- - -## Extension Points - -All extension points use `@runtime_checkable` Protocol — no abstract base classes. - -| Protocol | Purpose | Entry Point Group | -|---|---|---| -| `Scanner` | Discover models on a platform | `model_ledger.scanners` | -| `Introspector` | Extract metadata from fitted models | `model_ledger.introspectors` | -| `LedgerBackend` | Storage for ModelRef/Snapshot/Tag | — | -| `DBConnection` | SQL access for scanners | — | - ---- - -## Design Principles - -1. **The inventory is an event log, not a table.** Never mutate — always append. -2. **Schema is discovered, not declared.** Scanners observe and record. The ledger stores what was found. -3. **Agents are first-class consumers.** Every SDK function maps to a tool call. -4. **The stable core is tiny.** Identity + snapshot + tag. 
Everything else is a plugin.
-5. **Protocol-first.** No base classes. Implementations own all complexity.
diff --git a/docs/what-and-why.md b/docs/what-and-why.md
deleted file mode 100644
index 898f589..0000000
--- a/docs/what-and-why.md
+++ /dev/null
@@ -1,157 +0,0 @@
-# model-ledger
-
-An open-source model inventory and governance framework.
-
-**Author**: Vignesh Narayanaswamy
-**Date**: March 2026
-**License**: Apache-2.0
-
----
-
-## Overview
-
-model-ledger is a Python library that provides a typed, version-controlled, machine-readable inventory for model risk management. It implements the structural requirements of SR 11-7 and related regulatory frameworks as executable code — not as checklists, spreadsheets, or commercial platforms.
-
-The library is designed so that both humans and AI agents can consume, traverse, validate, and act on governance metadata.
-
-```
-pip install model-ledger
-```
-
----
-
-## Background
-
-### What a model inventory is
-
-Every regulated financial institution that uses models — credit risk, fraud detection, transaction monitoring, pricing — is required to maintain an inventory of those models. The Federal Reserve's SR 11-7 guidance states explicitly: *"Banks should maintain a comprehensive set of information for models implemented for use, under development for implementation, or recently retired."*
-
-This inventory must track model identity, ownership, purpose, risk tier, structural components (inputs, processing logic, outputs), governance documents, validation history, and findings. Examiners expect to see it. Internal audit expects to review it. Model validators need it to do their work.
-
-### How the industry does it today
-
-Spreadsheets. Across most of the financial industry, a model inventory is an Excel workbook or a SharePoint list maintained by the model risk team, tracking 20-50 models with columns for name, owner, tier, status, and last validation date.
- -This fails in predictable ways: - -- **Stale data.** The spreadsheet drifts from reality within weeks. Nobody's workflow includes "update the inventory spreadsheet." -- **No audit trail.** When did the tier change? Who approved it? The spreadsheet doesn't know. -- **Flat structure.** SR 11-7 defines a model as having input, processing, and output components. A spreadsheet row can't represent a hierarchical decomposition. -- **No machine consumption.** AI agents can't traverse a spreadsheet to understand model structure. -- **No validation.** There's no way to run compliance checks against a spreadsheet — someone eyeballs it. - -Commercial tools exist — hosted platforms with UIs, dashboards, and workflow engines. They're expensive, proprietary, not developer-friendly, and create vendor lock-in. None of them provide an open standard that the industry can build on. - -### What's different now - -AI agents are starting to do governance work — generating validation reports, checking compliance, assembling documentation. These tools need machine-readable governance data as input, not spreadsheets and PDFs. model-ledger provides that structured data layer. - ---- - -## What model-ledger Provides - -The library is a formal inventory that tracks four first-class entities: - -### Models - -A model is a versioned, hierarchical structure. Each version contains a component tree with three top-level branches — Inputs, Processing, and Outputs — per SR 11-7's three-component definition. This isn't just metadata; it's a structural decomposition that agents can traverse and validators can assess component by component. 
- -``` -Fraud Detection Model v2.0.0 -├── Inputs/ -│ ├── credit_features [FeatureSet, 150 features from Feature Store] -│ ├── behavioral_signals [FeatureSet, from Signal Pipeline] -│ ├── active_customers [Dataset, SQL query] -│ └── stationarity_assumption [Assumption, "risk patterns stable over 180 days"] -├── Processing/ -│ ├── fillna_imputation [Preprocessing, fillna_value=0] -│ ├── shap_feature_selection [FeatureSelection, 2-stage] -│ └── xgboost_classifier [Algorithm, XGBClassifier, 200 features] -└── Outputs/ - ├── risk_score [ProbabilityScore, 0-1] - ├── batch_score_table [Dataset, analytics.scoring.batch_results] - └── production_deployment [Deployment, daily batch via Orchestrator] -``` - -Each model also carries ownership, risk tier, intended purpose, regulatory jurisdiction, vendor information, and lifecycle status — all typed, all validated. - -### Governance Documents - -Linked evidence — model specifications, validation reports, conceptual soundness documents, approval records. Referenced by URI, not copied, so they stay current. - -### Observations - -Validation findings from any source: human reviewers, AI agents, automated testing tools, or manual entry. Each observation has a source tag identifying who generated it, and a full lifecycle: - -> **Created** → **Triaged** (kept / removed / modified, with reason and rationale) → **Issued** (published in a final validation report) or **Removed** (preserved in history) - -Observations can be grouped into validation runs. Multiple runs can exist for a single model version — the full history is preserved, but only one run is marked `final`, and only `issued` observations appear in the published report. - -### Feedback - -Structured records of what happened to each observation and why. Every triage decision — keep, remove, modify — is captured with a reason code (from a taxonomy like `refuted_by_code`, `justified_by_design`, `wrong_scope`, `consolidated`) and a free-text rationale. 
This accumulated feedback is queryable: any tool can check "have similar observations been removed before?" before generating new ones. - -### Validation Engine - -Executable compliance profiles — starting with SR 11-7 — that check models against regulatory requirements and return structured results with severity levels and remediation suggestions. Profiles are pluggable; adding `eu-ai-act` or `nist-ai-rmf` means implementing a new profile class, not changing the engine. - -### Export - -Audit packs (examiner-ready bundles), gap reports (missing fields with severity and remediation hints), and agent-consumable configs (structured inputs that AI validation tools can consume directly). - ---- - -## The Feedback Loop - -The most common question about AI-assisted governance: "How does it get better over time?" - -Today, validation observations get triaged in spreadsheets. A reviewer removes an observation because an AI agent cross-contaminated findings between two models, or because it flagged an intentional design choice. That correction is lost. The next validation cycle makes the same mistake. - -model-ledger captures these corrections as structured data — not as spreadsheet edits that nobody will ever read again. Over validation cycles and across models, this feedback accumulates into a dataset of governance judgment: what was flagged, what survived triage, what was removed and why. - -This is valuable for three reasons: - -1. **Agent improvement.** Any validation tool — AI or otherwise — can query the feedback corpus before generating observations. "Have observations like this been removed for `justified_by_design` on similar models?" This is pure computation over accumulated data, not new rules. -2. **Process visibility.** Leadership can see acceptance rates by observation type, model, and pillar. If the same removal reasons keep recurring, the tooling isn't learning. -3. 
**Regulatory defensibility.** The full triage history — including what was removed and why — is structured, immutable, and auditable. Examiners can see that the process is rigorous even when observations are removed. - -The core schema (I/P/O tree, regulatory fields, structural invariants) is fixed — the auditable floor that regulators expect. The feedback layer improves governance quality with each cycle. - ---- - -## Architecture - -### Core library (`model-ledger`, PyPI) - -The schema, SDK, validation engine, storage backends, feedback system, and export layer. Apache-2.0 licensed. - -### Adapters (organization-specific) - -model-ledger's `InventoryBackend` protocol and adapter pattern are designed so that any organization can write adapters to read from their existing systems of record and normalize data into model-ledger's schema. The core library never depends on any specific external system. - -### Schema Extension Points - -The core schema is designed for stability but not rigidity. An `extra_metadata` field on all major objects (Model, ModelVersion, ComponentNode, Observation) allows any tool to park discovered patterns. If a field consistently appears in `extra_metadata` across many models — meaning agents or users keep finding it useful — it can be promoted to a first-class field in a future schema version. Agents discover what's useful; humans decide when to formalize. - ---- - -## Relationship with Existing Tools - -model-ledger is not a replacement for commercial platforms — it's a different layer. - -**Commercial platforms** offer hosted model governance with dashboards, workflow engines, and compliance reporting. They are proprietary and expensive. model-ledger is not a hosted platform — it's a library. Organizations that need a UI can build one on top of model-ledger's schema and SDK. The value is in the open standard, not the hosting. - -**Existing inventory systems** can serve as data sources. 
model-ledger's adapter pattern lets you ingest from your current tools, adding the structural decomposition, validation engine, observation tracking, and agent-consumable exports they were not designed for. - -**AI validation agents** produce observations that model-ledger captures with full lifecycle tracking. model-ledger provides the structured model context these agents consume as input. - ---- - -## Roadmap - -| Phase | Scope | Timeline | -|-------|-------|----------| -| v0.1 | Core schema, SDK, SR 11-7 profile, storage backends, observation lifecycle, feedback corpus | Built (95 tests passing) | -| v0.2 | Export layer, CLI tooling, adapter examples | Q2 2026 | -| v0.3 | JSON-LD export, additional compliance profiles (EU AI Act, NIST AI RMF) | Q3 2026 | -| v0.4 | CycloneDX MBOM export, contributor ecosystem | Q4 2026 | diff --git a/pyproject.toml b/pyproject.toml index 37f15ea..f2cf490 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "model-ledger" -version = "0.5.0" +version = "0.6.0" description = "Developer-first model inventory and governance framework for SR 11-7, EU AI Act, and NIST AI RMF compliance" readme = "README.md" requires-python = ">=3.10" @@ -41,7 +41,9 @@ dev = [ "mypy>=1.0", ] snowflake = ["snowflake-connector-python>=3.0", "pandas>=2.0"] +mcp = ["mcp[cli]>=1.7.1,<2"] rest = ["httpx>=0.28"] +rest-api = ["fastapi>=0.115", "uvicorn>=0.30"] github = ["httpx>=0.28"] all = ["snowflake-connector-python>=3.0", "pandas>=2.0", "httpx>=0.28"] @@ -85,7 +87,7 @@ select = [ ] [tool.ruff.lint.per-file-ignores] -"tests/**" = ["S101"] +"tests/**" = ["S101", "E402"] [tool.ruff.format] quote-style = "double" diff --git a/src/model_ledger/__init__.py b/src/model_ledger/__init__.py index 27996fa..6f0a7aa 100644 --- a/src/model_ledger/__init__.py +++ b/src/model_ledger/__init__.py @@ -32,6 +32,17 @@ from model_ledger.sdk.ledger import Ledger from model_ledger.connectors import sql_connector, rest_connector, github_connector +# 
v0.6.0 — agent protocol tools +from model_ledger.tools import ( + changelog, discover, investigate, query, record, trace, +) +from model_ledger.tools.schemas import ( + ChangelogInput, ChangelogOutput, DiscoverInput, DiscoverOutput, + InvestigateInput, InvestigateOutput, ModelSummary, + QueryInput, QueryOutput, RecordInput, RecordOutput, + TraceInput, TraceOutput, +) + if TYPE_CHECKING: from model_ledger.introspect.models import IntrospectionResult from model_ledger.introspect.protocol import Introspector @@ -73,9 +84,15 @@ # Introspection "introspect", "register_introspector", + # v0.6.0 — agent tools + "changelog", "discover", "investigate", "query", "record", "trace", + "ChangelogInput", "ChangelogOutput", "DiscoverInput", "DiscoverOutput", + "InvestigateInput", "InvestigateOutput", "ModelSummary", + "QueryInput", "QueryOutput", "RecordInput", "RecordOutput", + "TraceInput", "TraceOutput", ] -__version__ = "0.5.0" +__version__ = "0.6.0" def introspect(obj: Any, *, introspector: str | None = None) -> IntrospectionResult: diff --git a/src/model_ledger/backends/json_files.py b/src/model_ledger/backends/json_files.py new file mode 100644 index 0000000..3746306 --- /dev/null +++ b/src/model_ledger/backends/json_files.py @@ -0,0 +1,167 @@ +"""JSON file LedgerBackend — human-readable, git-friendly persistence. 
+ +Stores each entity as an indented JSON file in a directory tree: + + root/ + ├── models/ # one file per ModelRef (filename: sanitized name.json) + ├── snapshots/ # one file per Snapshot (filename: snapshot_hash.json) + └── tags/ # organized by model_hash + └── {model_hash}/ + └── {tag_name}.json + + >>> from model_ledger.backends.json_files import JsonFileLedgerBackend + >>> backend = JsonFileLedgerBackend("./ledger-data") +""" + +from __future__ import annotations + +import re +from datetime import datetime +from pathlib import Path + +from model_ledger.core.ledger_models import ModelRef, Snapshot, Tag + +_SANITIZE_RE = re.compile(r"[/\\\s]") + + +def _sanitize(name: str) -> str: + """Replace /, \\, and whitespace with underscores for safe filenames.""" + return _SANITIZE_RE.sub("_", name) + + +class JsonFileLedgerBackend: + """LedgerBackend backed by a directory of JSON files.""" + + def __init__(self, root: str | Path) -> None: + self._root = Path(root) + self._models_dir = self._root / "models" + self._snapshots_dir = self._root / "snapshots" + self._tags_dir = self._root / "tags" + for d in (self._models_dir, self._snapshots_dir, self._tags_dir): + d.mkdir(parents=True, exist_ok=True) + + # ------------------------------------------------------------------ + # Models + # ------------------------------------------------------------------ + + def _model_path(self, name: str) -> Path: + return self._models_dir / f"{_sanitize(name)}.json" + + def save_model(self, model: ModelRef) -> None: + path = self._model_path(model.name) + path.write_text(model.model_dump_json(indent=2)) + + def get_model(self, model_hash: str) -> ModelRef | None: + for path in self._models_dir.iterdir(): + if path.suffix != ".json": + continue + m = ModelRef.model_validate_json(path.read_text()) + if m.model_hash == model_hash: + return m + return None + + def get_model_by_name(self, name: str) -> ModelRef | None: + path = self._model_path(name) + if path.exists(): + return 
ModelRef.model_validate_json(path.read_text()) + return None + + def list_models(self, **filters: str) -> list[ModelRef]: + results: list[ModelRef] = [] + for path in sorted(self._models_dir.iterdir()): + if path.suffix != ".json": + continue + m = ModelRef.model_validate_json(path.read_text()) + results.append(m) + for key, value in filters.items(): + results = [m for m in results if getattr(m, key, None) == value] + return results + + def update_model(self, model: ModelRef) -> None: + self.save_model(model) + + # ------------------------------------------------------------------ + # Snapshots + # ------------------------------------------------------------------ + + def _snapshot_path(self, snapshot_hash: str) -> Path: + return self._snapshots_dir / f"{snapshot_hash}.json" + + def append_snapshot(self, snapshot: Snapshot) -> None: + path = self._snapshot_path(snapshot.snapshot_hash) + path.write_text(snapshot.model_dump_json(indent=2)) + + def get_snapshot(self, snapshot_hash: str) -> Snapshot | None: + path = self._snapshot_path(snapshot_hash) + if path.exists(): + return Snapshot.model_validate_json(path.read_text()) + return None + + def _iter_snapshots(self, model_hash: str) -> list[Snapshot]: + results: list[Snapshot] = [] + for path in self._snapshots_dir.iterdir(): + if path.suffix != ".json": + continue + s = Snapshot.model_validate_json(path.read_text()) + if s.model_hash == model_hash: + results.append(s) + return results + + def list_snapshots(self, model_hash: str, **filters: str) -> list[Snapshot]: + results = self._iter_snapshots(model_hash) + for key, value in filters.items(): + results = [s for s in results if getattr(s, key, None) == value] + return sorted(results, key=lambda s: s.timestamp, reverse=True) + + def latest_snapshot(self, model_hash: str, tag: str | None = None) -> Snapshot | None: + if tag: + t = self.get_tag(model_hash, tag) + if t: + return self.get_snapshot(t.snapshot_hash) + return None + snaps = self.list_snapshots(model_hash) 
+ return snaps[0] if snaps else None + + def list_snapshots_before( + self, + model_hash: str, + before: datetime, + event_type: str | None = None, + ) -> list[Snapshot]: + results = [s for s in self._iter_snapshots(model_hash) if s.timestamp < before] + if event_type is not None: + results = [s for s in results if s.event_type == event_type] + return results + + # ------------------------------------------------------------------ + # Tags + # ------------------------------------------------------------------ + + def _tag_dir(self, model_hash: str) -> Path: + d = self._tags_dir / model_hash + d.mkdir(parents=True, exist_ok=True) + return d + + def _tag_path(self, model_hash: str, name: str) -> Path: + return self._tag_dir(model_hash) / f"{_sanitize(name)}.json" + + def set_tag(self, tag: Tag) -> None: + path = self._tag_path(tag.model_hash, tag.name) + path.write_text(tag.model_dump_json(indent=2)) + + def get_tag(self, model_hash: str, name: str) -> Tag | None: + path = self._tag_path(model_hash, name) + if path.exists(): + return Tag.model_validate_json(path.read_text()) + return None + + def list_tags(self, model_hash: str) -> list[Tag]: + tag_dir = self._tags_dir / model_hash + if not tag_dir.exists(): + return [] + results: list[Tag] = [] + for path in sorted(tag_dir.iterdir()): + if path.suffix != ".json": + continue + results.append(Tag.model_validate_json(path.read_text())) + return results diff --git a/src/model_ledger/cli/app.py b/src/model_ledger/cli/app.py index ae7e0c6..5e2a855 100644 --- a/src/model_ledger/cli/app.py +++ b/src/model_ledger/cli/app.py @@ -20,6 +20,24 @@ console = Console() +def _resolve_backend(backend: str, path: str | None): + """Resolve a backend name to a backend instance.""" + if backend == "sqlite" and path: + from model_ledger.backends.sqlite_ledger import SQLiteLedgerBackend + + return SQLiteLedgerBackend(path) + if backend == "json": + from model_ledger.backends.json_files import JsonFileLedgerBackend + + json_path = path or 
os.path.expanduser("~/.model-ledger") + return JsonFileLedgerBackend(json_path) + if backend == "memory": + from model_ledger.backends.ledger_memory import InMemoryLedgerBackend + + return InMemoryLedgerBackend() + return None + + def _default_db() -> str: return os.environ.get("MODEL_LEDGER_DB", "inventory.db") @@ -352,3 +370,39 @@ def introspect_cmd( ) except ModelNotFoundError: console.print(f"[yellow]Warning:[/yellow] Model '{model_name}' not found in inventory.") + + +@app.command(name="mcp") +def mcp_cmd( + backend: str = typer.Option("memory", help="Backend: memory, sqlite, json"), + path: str = typer.Option(None, help="Path for sqlite/json backend"), + demo: bool = typer.Option(False, help="Load demo inventory"), +) -> None: + """Start the MCP server for AI agent integration.""" + try: + from model_ledger.mcp.server import create_server + except ImportError: + typer.echo("MCP not installed. Run: pip install model-ledger[mcp]", err=True) + raise typer.Exit(1) + backend_obj = _resolve_backend(backend, path) + server = create_server(backend=backend_obj, demo=demo) + server.run() + + +@app.command(name="serve") +def serve_cmd( + backend: str = typer.Option("memory", help="Backend: memory, sqlite, json"), + path: str = typer.Option(None, help="Path for sqlite/json backend"), + demo: bool = typer.Option(False, help="Load demo inventory"), + port: int = typer.Option(8000, help="Port to serve on"), +) -> None: + """Start the REST API server.""" + try: + from model_ledger.rest.app import create_app + import uvicorn + except ImportError: + typer.echo("REST API not installed. 
Run: pip install model-ledger[rest-api]", err=True) + raise typer.Exit(1) + backend_obj = _resolve_backend(backend, path) + rest_app = create_app(backend=backend_obj, demo=demo) + uvicorn.run(rest_app, host="0.0.0.0", port=port) diff --git a/src/model_ledger/datasets/__init__.py b/src/model_ledger/datasets/__init__.py index 738db5d..44001f4 100644 --- a/src/model_ledger/datasets/__init__.py +++ b/src/model_ledger/datasets/__init__.py @@ -8,6 +8,7 @@ >>> print(len(inv.list_models())) # 3 models ready to use """ +from model_ledger.datasets.demo import load_demo_inventory from model_ledger.datasets.samples import ( load_sample_inventory, make_credit_model, @@ -16,6 +17,7 @@ ) __all__ = [ + "load_demo_inventory", "load_sample_inventory", "make_credit_model", "make_fraud_detector", diff --git a/src/model_ledger/datasets/demo.py b/src/model_ledger/datasets/demo.py new file mode 100644 index 0000000..fee0e6a --- /dev/null +++ b/src/model_ledger/datasets/demo.py @@ -0,0 +1,134 @@ +"""Demo inventory with sample models, events, and dependency connections. + +Provides a realistic multi-platform inventory for first-time users, +tutorials, and integration testing. Includes lifecycle events (retrained, +deployed, metadata_updated) and data-flow dependencies between nodes. + + >>> from model_ledger import Ledger + >>> from model_ledger.datasets.demo import load_demo_inventory + >>> ledger = Ledger() + >>> load_demo_inventory(ledger) + >>> print(len(ledger.list())) # 7 models +""" + +from __future__ import annotations + +from model_ledger.graph.models import DataNode +from model_ledger.sdk.ledger import Ledger + + +def load_demo_inventory(ledger: Ledger) -> None: + """Populate *ledger* with 7 sample models across different platforms. 
+ + Creates: + - 2 data nodes (feature store, ETL pipeline) + - 3 ML models (fraud scoring, churn predictor, credit risk) + - 1 alerting engine + - 1 rules engine + + Dependency edges are inferred via matching input/output ports, and + lifecycle events are recorded against selected models. + """ + ledger.add( + [ + DataNode( + "customer_features", + platform="database", + outputs=["customer_data"], + metadata={"owner": "data-team", "model_type": "data_source", "description": "Core customer feature store"}, + ), + DataNode( + "transaction_pipeline", + platform="etl", + inputs=["raw_transactions"], + outputs=["processed_transactions", "customer_data"], + metadata={"owner": "data-team", "model_type": "etl_pipeline", "schedule": "hourly"}, + ), + DataNode( + "fraud_scoring", + platform="ml", + inputs=["customer_data", "processed_transactions"], + outputs=["fraud_scores"], + metadata={ + "owner": "risk-team", + "model_type": "ml_model", + "algorithm": "gradient_boosted_trees", + }, + ), + DataNode( + "churn_predictor", + platform="ml", + inputs=["customer_data"], + outputs=["churn_probabilities"], + metadata={"owner": "growth-team", "model_type": "ml_model", "algorithm": "logistic_regression"}, + ), + DataNode( + "alert_engine", + platform="alerting", + inputs=["fraud_scores"], + metadata={"owner": "ops-team", "model_type": "alerting"}, + ), + DataNode( + "credit_risk", + platform="ml", + inputs=["customer_data", "processed_transactions"], + outputs=["credit_scores"], + metadata={"owner": "risk-team", "model_type": "ml_model", "algorithm": "neural_network"}, + ), + DataNode( + "pricing_rules", + platform="rules", + inputs=["credit_scores", "churn_probabilities"], + metadata={"owner": "pricing-team", "model_type": "heuristic"}, + ), + ] + ) + ledger.connect() + + # -- Lifecycle events -- + + fraud = ledger.get("fraud_scoring") + ledger.record( + fraud, + event="retrained", + payload={ + "accuracy": 0.94, + "features_added": ["device_fingerprint", "velocity_24h"], + 
"training_samples": 1_200_000, + }, + actor="ml-pipeline", + ) + ledger.record( + fraud, + event="deployed", + payload={"environment": "production", "version": "v3.2"}, + actor="ci-pipeline", + ) + ledger.record( + fraud, + event="metadata_updated", + payload={ + "model_card_url": "https://docs.example.com/fraud-scoring", + "training_data": "Customer transactions, 2024-2026", + }, + actor="data-scientist", + ) + + churn = ledger.get("churn_predictor") + ledger.record( + churn, + event="retrained", + payload={"accuracy": 0.87, "auc": 0.91}, + actor="ml-pipeline", + ) + + credit = ledger.get("credit_risk") + ledger.record( + credit, + event="metadata_updated", + payload={ + "regulatory_framework": "Basel III", + "last_validated": "2026-03-15", + }, + actor="compliance-team", + ) diff --git a/src/model_ledger/mcp/__init__.py b/src/model_ledger/mcp/__init__.py new file mode 100644 index 0000000..74a3788 --- /dev/null +++ b/src/model_ledger/mcp/__init__.py @@ -0,0 +1,21 @@ +"""MCP (Model Context Protocol) server for model-ledger. + +Exposes model-ledger tools and resources via the FastMCP framework, +allowing AI assistants to interact with the model inventory. + + >>> from model_ledger.mcp.server import create_server + >>> server = create_server() +""" + +from __future__ import annotations + +__all__ = ["create_server", "main"] + + +def __getattr__(name: str): # noqa: ANN001 + """Lazy import to avoid hard dep on mcp at package level.""" + if name in __all__: + from model_ledger.mcp.server import create_server, main + + return {"create_server": create_server, "main": main}[name] + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/src/model_ledger/mcp/server.py b/src/model_ledger/mcp/server.py new file mode 100644 index 0000000..5e95168 --- /dev/null +++ b/src/model_ledger/mcp/server.py @@ -0,0 +1,321 @@ +"""FastMCP server wrapping model-ledger's 6 tools and 3 resources. 
+ +Usage: + >>> from model_ledger.mcp.server import create_server + >>> server = create_server() + >>> server.run() # stdio transport + +CLI entry point: + $ model-ledger mcp --demo +""" + +from __future__ import annotations + +import json +import sys +from typing import Any + +from mcp.server.fastmcp import FastMCP + +from model_ledger.backends.ledger_memory import InMemoryLedgerBackend +from model_ledger.backends.ledger_protocol import LedgerBackend +from model_ledger.sdk.ledger import Ledger +from model_ledger.tools import schemas +from model_ledger.tools.changelog import changelog as _changelog +from model_ledger.tools.discover import discover as _discover +from model_ledger.tools.investigate import investigate as _investigate +from model_ledger.tools.query import query as _query +from model_ledger.tools.record import record as _record +from model_ledger.tools.trace import trace as _trace + + +def create_server( + backend: LedgerBackend | None = None, + demo: bool = False, +) -> FastMCP: + """Create a FastMCP server with model-ledger tools and resources. + + Args: + backend: Optional storage backend. Defaults to InMemoryLedgerBackend. + demo: If True, pre-populate with sample data (requires datasets.demo module). + + Returns: + A configured FastMCP server ready to ``run()``. 
+ """ + if backend is None: + backend = InMemoryLedgerBackend() + + ledger = Ledger(backend=backend) + + if demo: + try: + from model_ledger.datasets.demo import load_demo_inventory + + load_demo_inventory(ledger) + except ImportError: + pass # demo module not yet available (Task 11) + + mcp = FastMCP("model-ledger") + + # ------------------------------------------------------------------ + # Tools (6) — thin wrappers that convert primitives -> Pydantic -> tool fn + # ------------------------------------------------------------------ + + @mcp.tool() + def discover( + source_type: str, + models: list[dict] | None = None, + connector_name: str | None = None, + connector_config: dict | None = None, + file_path: str | None = None, + auto_connect: bool = True, + ) -> dict: + """Import models from external sources into the ledger. + + Supports inline model dicts, JSON files, or named connectors. + Returns counts of models added/skipped and links created. + """ + inp = schemas.DiscoverInput( + source_type=source_type, # type: ignore[arg-type] + models=models, + connector_name=connector_name, + connector_config=connector_config, + file_path=file_path, + auto_connect=auto_connect, + ) + return _discover(inp, ledger).model_dump(mode="json") + + @mcp.tool() + def record( + model_name: str, + event: str, + payload: dict | None = None, + actor: str = "user", + owner: str | None = None, + model_type: str | None = None, + purpose: str | None = None, + ) -> dict: + """Register a new model or record an event on an existing model. + + Use event='registered' to create a new model. Any other event + value appends to an existing model's history. 
+ """ + inp = schemas.RecordInput( + model_name=model_name, + event=event, + payload=payload or {}, + actor=actor, + owner=owner, + model_type=model_type, + purpose=purpose, + ) + return _record(inp, ledger).model_dump(mode="json") + + @mcp.tool() + def investigate( + model_name: str, + detail: str = "summary", + as_of: str | None = None, + ) -> dict: + """Deep-dive into a single model — history, metadata, dependencies. + + Returns owner, type, status, recent events, upstream/downstream + dependencies, and group memberships. + """ + from datetime import datetime, timezone + + as_of_dt = None + if as_of is not None: + as_of_dt = datetime.fromisoformat(as_of) + if as_of_dt.tzinfo is None: + as_of_dt = as_of_dt.replace(tzinfo=timezone.utc) + + inp = schemas.InvestigateInput( + model_name=model_name, + detail=detail, # type: ignore[arg-type] + as_of=as_of_dt, + ) + return _investigate(inp, ledger).model_dump(mode="json") + + @mcp.tool() + def query( + text: str | None = None, + platform: str | None = None, + model_type: str | None = None, + owner: str | None = None, + status: str | None = None, + limit: int = 50, + offset: int = 0, + ) -> dict: + """Search and filter the model inventory. + + Supports text search (fuzzy name/purpose match) and structured + filters (platform, model_type, owner, status) with pagination. + """ + inp = schemas.QueryInput( + text=text, + platform=platform, + model_type=model_type, + owner=owner, + status=status, + limit=limit, + offset=offset, + ) + return _query(inp, ledger).model_dump(mode="json") + + @mcp.tool() + def trace( + name: str, + direction: str = "both", + depth: int | None = None, + ) -> dict: + """Traverse a model's dependency graph. + + Walks upstream (models this one depends on) and/or downstream + (models that depend on this one). Returns dependency nodes with + depth and relationship metadata. 
+ """ + inp = schemas.TraceInput( + name=name, + direction=direction, # type: ignore[arg-type] + depth=depth, + ) + return _trace(inp, ledger).model_dump(mode="json") + + @mcp.tool() + def changelog( + since: str | None = None, + until: str | None = None, + model_name: str | None = None, + event_type: str | None = None, + limit: int = 100, + offset: int = 0, + ) -> dict: + """View cross-model event history with time range filtering. + + Returns events sorted newest-first with pagination. Defaults + to the last 7 days if no time range is specified. + """ + from datetime import datetime, timezone + + since_dt = None + if since is not None: + since_dt = datetime.fromisoformat(since) + if since_dt.tzinfo is None: + since_dt = since_dt.replace(tzinfo=timezone.utc) + + until_dt = None + if until is not None: + until_dt = datetime.fromisoformat(until) + if until_dt.tzinfo is None: + until_dt = until_dt.replace(tzinfo=timezone.utc) + + inp = schemas.ChangelogInput( + since=since_dt, + until=until_dt, + model_name=model_name, + event_type=event_type, + limit=limit, + offset=offset, + ) + return _changelog(inp, ledger).model_dump(mode="json") + + # ------------------------------------------------------------------ + # Resources (3) + # ------------------------------------------------------------------ + + @mcp.resource("ledger://overview") + def overview() -> str: + """Inventory statistics — model count, event count, type breakdown.""" + models = ledger.list() + type_counts: dict[str, int] = {} + total_events = 0 + for m in models: + mt = m.model_type or "unknown" + type_counts[mt] = type_counts.get(mt, 0) + 1 + total_events += len(ledger.history(m)) + + data: dict[str, Any] = { + "total_models": len(models), + "total_events": total_events, + "model_types": type_counts, + } + return json.dumps(data, indent=2) + + @mcp.resource("ledger://schema") + def schema_resource() -> str: + """JSON Schema definitions for all tool I/O models.""" + all_schemas: dict[str, Any] = {} + for 
cls in [ + schemas.DiscoverInput, + schemas.DiscoverOutput, + schemas.RecordInput, + schemas.RecordOutput, + schemas.QueryInput, + schemas.QueryOutput, + schemas.InvestigateInput, + schemas.InvestigateOutput, + schemas.TraceInput, + schemas.TraceOutput, + schemas.ChangelogInput, + schemas.ChangelogOutput, + ]: + all_schemas[cls.__name__] = cls.model_json_schema() + return json.dumps(all_schemas, indent=2) + + @mcp.resource("ledger://backends") + def backends_resource() -> str: + """Active backend configuration.""" + backend_type = type(backend).__name__ + data: dict[str, Any] = { + "backend": backend_type, + "demo": demo, + } + return json.dumps(data, indent=2) + + return mcp + + +def main() -> None: + """Entry point for ``model-ledger mcp`` command. + + Parses --backend, --path, and --demo arguments, creates the server, + and runs it on stdio transport. + """ + import argparse + + parser = argparse.ArgumentParser(description="model-ledger MCP server") + parser.add_argument( + "--backend", + choices=["memory", "sqlite", "json"], + default="memory", + help="Storage backend (default: memory)", + ) + parser.add_argument( + "--path", + default=None, + help="Path for sqlite or json backend", + ) + parser.add_argument( + "--demo", + action="store_true", + help="Pre-populate with sample data", + ) + + args = parser.parse_args(sys.argv[1:]) + + backend: LedgerBackend | None = None + if args.backend == "sqlite": + from model_ledger.backends.sqlite_ledger import SQLiteLedgerBackend + + path = args.path or "ledger.db" + backend = SQLiteLedgerBackend(path) + elif args.backend == "json": + from model_ledger.backends.json_files import JsonFileLedgerBackend + + path = args.path or "./ledger-data" + backend = JsonFileLedgerBackend(path) + # else: memory — use None (create_server default) + + server = create_server(backend=backend, demo=args.demo) + server.run() diff --git a/src/model_ledger/rest/__init__.py b/src/model_ledger/rest/__init__.py new file mode 100644 index 
0000000..c44f922 --- /dev/null +++ b/src/model_ledger/rest/__init__.py @@ -0,0 +1,12 @@ +"""FastAPI REST API for model-ledger. + +Wraps the 6 agent protocol tools as HTTP endpoints with +auto-generated OpenAPI docs at ``/docs``. + + >>> from model_ledger.rest.app import create_app + >>> app = create_app() +""" + +from model_ledger.rest.app import create_app + +__all__ = ["create_app"] diff --git a/src/model_ledger/rest/app.py b/src/model_ledger/rest/app.py new file mode 100644 index 0000000..ce610f0 --- /dev/null +++ b/src/model_ledger/rest/app.py @@ -0,0 +1,192 @@ +"""FastAPI application wrapping the 6 model-ledger tool functions. + +Usage:: + + from model_ledger.rest.app import create_app + + app = create_app() # in-memory backend + app = create_app(demo=True) # pre-loaded demo data (if available) + + # With a custom backend + from model_ledger.backends.sqlite_ledger import SQLiteLedgerBackend + app = create_app(backend=SQLiteLedgerBackend("inventory.db")) + +Run with uvicorn:: + + uvicorn model_ledger.rest.app:app --reload +""" + +from __future__ import annotations + +from datetime import datetime +from typing import Any + +from fastapi import FastAPI, HTTPException + +from model_ledger.backends.ledger_protocol import LedgerBackend +from model_ledger.core.exceptions import ModelNotFoundError +from model_ledger.sdk.ledger import Ledger +from model_ledger.tools.changelog import changelog as changelog_fn +from model_ledger.tools.discover import discover as discover_fn +from model_ledger.tools.investigate import investigate as investigate_fn +from model_ledger.tools.query import query as query_fn +from model_ledger.tools.record import record as record_fn +from model_ledger.tools.schemas import ( + ChangelogInput, + ChangelogOutput, + DiscoverInput, + DiscoverOutput, + InvestigateOutput, + QueryInput, + QueryOutput, + RecordInput, + RecordOutput, + TraceInput, + TraceOutput, +) +from model_ledger.tools.trace import trace as trace_fn + + +def create_app( + backend: 
LedgerBackend | None = None, + demo: bool = False, +) -> FastAPI: + """Create a FastAPI app wrapping the model-ledger tool functions. + + Args: + backend: Optional ledger backend. Defaults to in-memory. + demo: If True, pre-loads demo inventory data (requires Task 11). + + Returns: + A configured FastAPI application. + """ + ledger = Ledger(backend=backend) + + if demo: + try: + from model_ledger.datasets.demo import load_demo_inventory # type: ignore[import-not-found] # noqa: I001 + + load_demo_inventory(ledger) + except ImportError: + pass # demo dataset not yet available (Task 11) + + app = FastAPI( + title="Model Ledger API", + description="REST API for model inventory and governance", + version="0.5.0", + ) + + # ------------------------------------------------------------------ + # POST /record + # ------------------------------------------------------------------ + @app.post("/record", response_model=RecordOutput) + def record_endpoint(body: RecordInput) -> RecordOutput: + try: + return record_fn(body, ledger) + except ModelNotFoundError as exc: + raise HTTPException(status_code=404, detail=str(exc)) from exc + + # ------------------------------------------------------------------ + # POST /discover + # ------------------------------------------------------------------ + @app.post("/discover", response_model=DiscoverOutput) + def discover_endpoint(body: DiscoverInput) -> DiscoverOutput: + return discover_fn(body, ledger) + + # ------------------------------------------------------------------ + # GET /investigate/{model_name} + # ------------------------------------------------------------------ + @app.get("/investigate/{model_name}", response_model=InvestigateOutput) + def investigate_endpoint(model_name: str) -> InvestigateOutput: + from model_ledger.tools.schemas import InvestigateInput + + try: + return investigate_fn(InvestigateInput(model_name=model_name), ledger) + except ModelNotFoundError as exc: + raise HTTPException(status_code=404, 
detail=str(exc)) from exc + + # ------------------------------------------------------------------ + # GET /query + # ------------------------------------------------------------------ + @app.get("/query", response_model=QueryOutput) + def query_endpoint( + text: str | None = None, + platform: str | None = None, + model_type: str | None = None, + owner: str | None = None, + status: str | None = None, + limit: int = 50, + offset: int = 0, + ) -> QueryOutput: + inp = QueryInput( + text=text, + platform=platform, + model_type=model_type, + owner=owner, + status=status, + limit=limit, + offset=offset, + ) + return query_fn(inp, ledger) + + # ------------------------------------------------------------------ + # GET /trace/{name} + # ------------------------------------------------------------------ + @app.get("/trace/{name}", response_model=TraceOutput) + def trace_endpoint( + name: str, + direction: str = "both", + depth: int | None = None, + ) -> TraceOutput: + inp = TraceInput(name=name, direction=direction, depth=depth) # type: ignore[arg-type] + try: + return trace_fn(inp, ledger) + except ModelNotFoundError as exc: + raise HTTPException(status_code=404, detail=str(exc)) from exc + + # ------------------------------------------------------------------ + # GET /changelog + # ------------------------------------------------------------------ + @app.get("/changelog", response_model=ChangelogOutput) + def changelog_endpoint( + since: str | None = None, + until: str | None = None, + model_name: str | None = None, + event_type: str | None = None, + limit: int = 100, + offset: int = 0, + ) -> ChangelogOutput: + since_dt = datetime.fromisoformat(since) if since else None + until_dt = datetime.fromisoformat(until) if until else None + inp = ChangelogInput( + since=since_dt, + until=until_dt, + model_name=model_name, + event_type=event_type, + limit=limit, + offset=offset, + ) + try: + return changelog_fn(inp, ledger) + except ModelNotFoundError as exc: + raise 
HTTPException(status_code=404, detail=str(exc)) from exc + + # ------------------------------------------------------------------ + # GET /overview + # ------------------------------------------------------------------ + @app.get("/overview") + def overview_endpoint() -> dict[str, Any]: + models = ledger.list() + total_events = 0 + for model in models: + total_events += len(ledger.history(model)) + return { + "total_models": len(models), + "total_events": total_events, + } + + return app + + +# Module-level app for `uvicorn model_ledger.rest.app:app` +app = create_app() diff --git a/src/model_ledger/sdk/ledger.py b/src/model_ledger/sdk/ledger.py index e3eca48..0779b08 100644 --- a/src/model_ledger/sdk/ledger.py +++ b/src/model_ledger/sdk/ledger.py @@ -297,7 +297,7 @@ def add(self, nodes): ref = self.register( name=node.name, owner=node.metadata.get("owner") or "unknown", - model_type=node.metadata.get("node_type") or "unknown", + model_type=node.metadata.get("model_type") or node.metadata.get("node_type") or node.metadata.get("type") or "unknown", tier=node.metadata.get("tier") or "unclassified", purpose=node.metadata.get("purpose") or "", model_origin=node.metadata.get("model_origin") or "internal", diff --git a/src/model_ledger/tools/__init__.py b/src/model_ledger/tools/__init__.py new file mode 100644 index 0000000..d8b400c --- /dev/null +++ b/src/model_ledger/tools/__init__.py @@ -0,0 +1,62 @@ +"""Agent protocol tools — Pydantic I/O schemas and tool functions. + +Re-exports all schemas and all tool functions. 
+""" + +from model_ledger.tools.changelog import changelog +from model_ledger.tools.discover import discover +from model_ledger.tools.investigate import investigate +from model_ledger.tools.query import query +from model_ledger.tools.record import record +from model_ledger.tools.schemas import ( + ChangelogInput, + ChangelogOutput, + DependencyNode, + DiscoverInput, + DiscoverOutput, + EventDetail, + EventSummary, + InvestigateInput, + InvestigateOutput, + ModelSummary, + QueryInput, + QueryOutput, + RecordInput, + RecordOutput, + TraceInput, + TraceOutput, +) +from model_ledger.tools.trace import trace + +__all__ = [ + # Shared types + "ModelSummary", + "EventSummary", + "EventDetail", + "DependencyNode", + # Tool functions + "changelog", + "discover", + "investigate", + "query", + "record", + "trace", + # record + "RecordInput", + "RecordOutput", + # query + "QueryInput", + "QueryOutput", + # investigate + "InvestigateInput", + "InvestigateOutput", + # trace + "TraceInput", + "TraceOutput", + # changelog + "ChangelogInput", + "ChangelogOutput", + # discover + "DiscoverInput", + "DiscoverOutput", +] diff --git a/src/model_ledger/tools/changelog.py b/src/model_ledger/tools/changelog.py new file mode 100644 index 0000000..0203251 --- /dev/null +++ b/src/model_ledger/tools/changelog.py @@ -0,0 +1,108 @@ +"""Changelog tool — cross-model event timeline with time range filtering.""" + +from __future__ import annotations + +from datetime import datetime, timedelta, timezone + +from model_ledger.core.ledger_models import Snapshot +from model_ledger.sdk.ledger import Ledger +from model_ledger.tools.schemas import ( + ChangelogInput, + ChangelogOutput, + EventDetail, +) + + +def _ensure_utc(dt: datetime) -> datetime: + """Normalize a datetime to UTC. 
Treats naive datetimes as UTC.""" + if dt.tzinfo is None: + return dt.replace(tzinfo=timezone.utc) + return dt + + +def _build_period(since: datetime | None, until: datetime | None) -> str: + """Build a human-readable period string.""" + if since is not None and until is not None: + return f"{since.strftime('%Y-%m-%d')} to {until.strftime('%Y-%m-%d')}" + if since is not None: + now = datetime.now(timezone.utc) + days = max((now - _ensure_utc(since)).days, 0) + return f"last {days} days" + return "all time" + + +def _snapshot_to_event(snapshot: Snapshot, model_name: str) -> EventDetail: + """Convert a Snapshot into an EventDetail with model association.""" + return EventDetail( + model_name=model_name, + event_type=snapshot.event_type, + timestamp=snapshot.timestamp, + actor=snapshot.actor, + summary=snapshot.payload.get("summary"), + payload=snapshot.payload, + ) + + +def changelog(input: ChangelogInput, ledger: Ledger) -> ChangelogOutput: + """Cross-model event timeline with time range filtering. + + Iterates all models (or a single model if ``input.model_name`` is set), + collects snapshots within the time range, and returns them sorted + newest-first with pagination. 
+ """ + since = input.since + until = input.until + + # Default: if both since and until are None, set since = 7 days ago + if since is None and until is None: + since = datetime.now(timezone.utc) - timedelta(days=7) + + # Normalize to UTC for comparison + if since is not None: + since = _ensure_utc(since) + if until is not None: + until = _ensure_utc(until) + + # Get models to iterate + models = [ledger.get(input.model_name)] if input.model_name is not None else ledger.list() + + # Collect all matching events + all_events: list[EventDetail] = [] + for model in models: + snapshots = ledger.history(model) + for snap in snapshots: + ts = _ensure_utc(snap.timestamp) + + # Filter by time range + if since is not None and ts < since: + continue + if until is not None and ts > until: + continue + + # Filter by event_type + if input.event_type is not None and snap.event_type != input.event_type: + continue + + all_events.append(_snapshot_to_event(snap, model.name)) + + # Sort by timestamp descending (newest first) + _epoch = datetime.min.replace(tzinfo=timezone.utc) + all_events.sort( + key=lambda e: _ensure_utc(e.timestamp) if e.timestamp else _epoch, + reverse=True, + ) + + # Paginate + total = len(all_events) + page = all_events[input.offset : input.offset + input.limit] + has_more = (input.offset + input.limit) < total + + # Build period string + period = _build_period(since, until) + + return ChangelogOutput( + total=total, + events=page, + has_more=has_more, + period=period, + ) diff --git a/src/model_ledger/tools/discover.py b/src/model_ledger/tools/discover.py new file mode 100644 index 0000000..a8fe7f4 --- /dev/null +++ b/src/model_ledger/tools/discover.py @@ -0,0 +1,87 @@ +"""Discover tool — bulk ingestion from connectors, files, or inline data.""" + +from __future__ import annotations + +import json +from typing import Any + +from model_ledger.graph.models import DataNode +from model_ledger.sdk.ledger import Ledger +from model_ledger.tools.schemas import 
DiscoverInput, DiscoverOutput, ModelSummary + + +def _dict_to_datanode(d: dict[str, Any]) -> DataNode: + """Convert a raw dict to a DataNode.""" + return DataNode( + name=d["name"], + platform=d.get("platform", ""), + inputs=d.get("inputs", []), + outputs=d.get("outputs", []), + metadata={k: v for k, v in d.items() if k not in ("name", "platform", "inputs", "outputs")}, + ) + + +def discover(input: DiscoverInput, ledger: Ledger) -> DiscoverOutput: + """Import models from external sources into the ledger. + + Supports three source types: + + - **inline**: models passed directly as a list of dicts. + - **file**: models loaded from a JSON file on disk. + - **connector**: not yet supported — raises ``NotImplementedError``. + + When ``auto_connect`` is True and models were added, runs + ``ledger.connect()`` to auto-link dependencies based on matching + input/output ports. + """ + if input.source_type == "connector": + raise NotImplementedError( + "Connector execution via tool not yet supported. Use the Python SDK directly." 
+ ) + + if input.source_type == "file": + if input.file_path is None: + raise ValueError("file_path is required when source_type is 'file'") + with open(input.file_path) as f: + raw_models = json.load(f) + nodes = [_dict_to_datanode(d) for d in raw_models] + + else: # inline + if input.models is None: + raise ValueError("models is required when source_type is 'inline'") + nodes = [_dict_to_datanode(d) for d in input.models] + + # Add nodes to ledger (content-hash dedup) + add_result = ledger.add(nodes) + added = add_result["added"] + skipped = add_result["skipped"] + + # Auto-connect dependencies if requested and models were added + links_created = 0 + if input.auto_connect and added > 0: + connect_result = ledger.connect() + links_created = connect_result["links_created"] + + # Build summaries for added models + summaries: list[ModelSummary] = [] + for node in nodes: + try: + ref = ledger.get(node.name) + summaries.append( + ModelSummary( + name=ref.name, + owner=ref.owner, + model_type=ref.model_type, + platform=node.platform or None, + status=ref.status, + ) + ) + except Exception: + pass + + return DiscoverOutput( + models_added=added, + models_skipped=skipped, + links_created=links_created, + models=summaries, + ) diff --git a/src/model_ledger/tools/investigate.py b/src/model_ledger/tools/investigate.py new file mode 100644 index 0000000..316843e --- /dev/null +++ b/src/model_ledger/tools/investigate.py @@ -0,0 +1,123 @@ +"""Investigate tool — comprehensive deep-dive into a single model.""" + +from __future__ import annotations + +from datetime import datetime, timezone + +from model_ledger.core.ledger_models import Snapshot +from model_ledger.sdk.ledger import Ledger +from model_ledger.tools.schemas import ( + DependencyNode, + EventSummary, + InvestigateInput, + InvestigateOutput, +) + + +def _snapshot_to_event(snapshot: Snapshot) -> EventSummary: + """Convert a Snapshot to a compact EventSummary.""" + return EventSummary( + event_type=snapshot.event_type, + 
timestamp=snapshot.timestamp, + actor=snapshot.actor, + summary=snapshot.payload.get("summary"), + ) + + +def investigate(input: InvestigateInput, ledger: Ledger) -> InvestigateOutput: + """Deep-dive into a single model — history, metadata, dependencies. + + Retrieves the model identity, merges metadata from all snapshot + payloads (oldest-first so newest wins), builds an event timeline, + and resolves upstream/downstream dependencies from the graph. + + Raises: + ModelNotFoundError: If the model does not exist. + """ + # 1. Get the model — raises ModelNotFoundError if missing + model = ledger.get(input.model_name) + + # 2. Get all snapshots (newest first from ledger.history) + snapshots = ledger.history(model) + + # 3. Filter by as_of if set + if input.as_of is not None: + as_of = input.as_of + # Ensure as_of is timezone-aware for comparison + if as_of.tzinfo is None: + as_of = as_of.replace(tzinfo=timezone.utc) + snapshots = [s for s in snapshots if s.timestamp <= as_of] + + # 4. Merge metadata from user-facing snapshot payloads (oldest first, newest wins) + # Skip internal event types (graph wiring, registration identity) and internal keys + _INTERNAL_EVENTS = {"depends_on", "has_dependent", "registered"} + _INTERNAL_KEYS = {"_content_hash", "upstream", "downstream", "upstream_hash", + "downstream_hash", "relationship", "via", "via_schema", + "name", "owner", "tier", "purpose", "model_origin"} + metadata: dict = {} + for snap in reversed(snapshots): # reversed = oldest first + if snap.event_type in _INTERNAL_EVENTS: + continue + metadata.update({k: v for k, v in snap.payload.items() if k not in _INTERNAL_KEYS}) + + # 5. Build recent_events list + events = [_snapshot_to_event(s) for s in snapshots] + total_events = len(events) + recent_events = events[:10] if input.detail == "summary" else events + + # 6. 
Compute days_since_last_event + days_since_last_event: int | None = None + if snapshots: + latest_ts = snapshots[0].timestamp + if latest_ts.tzinfo is None: + latest_ts = latest_ts.replace(tzinfo=timezone.utc) + now = datetime.now(timezone.utc) + days_since_last_event = (now - latest_ts).days + + # 7. Get upstream/downstream (catch exceptions for models without graph nodes) + upstream_names: list[str] = [] + try: + upstream_names = ledger.upstream(input.model_name) + except (KeyError, ValueError, Exception): + upstream_names = [] + + downstream_names: list[str] = [] + try: + downstream_names = ledger.downstream(input.model_name) + except (KeyError, ValueError, Exception): + downstream_names = [] + + upstream_nodes = [DependencyNode(name=n) for n in upstream_names] + downstream_nodes = [DependencyNode(name=n) for n in downstream_names] + + # 8. Get groups and members (catch exceptions) + group_names: list[str] = [] + try: + group_refs = ledger.groups(model) + group_names = [g.name for g in group_refs] + except (KeyError, ValueError, Exception): + group_names = [] + + member_names: list[str] = [] + try: + member_refs = ledger.members(model) + member_names = [m.name for m in member_refs] + except (KeyError, ValueError, Exception): + member_names = [] + + return InvestigateOutput( + name=model.name, + owner=model.owner, + model_type=model.model_type, + purpose=model.purpose, + status=model.status, + created_at=model.created_at, + metadata=metadata, + recent_events=recent_events, + days_since_last_event=days_since_last_event, + total_events=total_events, + upstream=upstream_nodes, + downstream=downstream_nodes, + groups=group_names, + members=member_names, + ) diff --git a/src/model_ledger/tools/query.py b/src/model_ledger/tools/query.py new file mode 100644 index 0000000..ec9a0cc --- /dev/null +++ b/src/model_ledger/tools/query.py @@ -0,0 +1,80 @@ +"""Query tool — search and filter the model inventory with pagination.""" + +from __future__ import annotations + +from 
model_ledger.core.ledger_models import ModelRef +from model_ledger.sdk.ledger import Ledger +from model_ledger.tools.schemas import ModelSummary, QueryInput, QueryOutput + + +def _model_to_summary(model: ModelRef, ledger: Ledger) -> ModelSummary: + """Convert a ModelRef to a ModelSummary. + + Enriches the static ModelRef identity with dynamic event data: + - ``last_event``: timestamp of the most recent snapshot + - ``event_count``: total number of snapshots + - ``platform``: source field from the first snapshot that has one + """ + snapshots = ledger.history(model) + event_count = len(snapshots) + last_event = snapshots[0].timestamp if snapshots else None + + platform: str | None = None + for snap in snapshots: + # Prefer platform from discovered payload, fall back to source + if snap.event_type == "discovered" and snap.payload.get("platform"): + platform = snap.payload["platform"] + break + if snap.source and not platform: + platform = snap.source + + return ModelSummary( + name=model.name, + owner=model.owner, + model_type=model.model_type, + status=model.status, + platform=platform, + last_event=last_event, + event_count=event_count, + ) + + +def query(input: QueryInput, ledger: Ledger) -> QueryOutput: + """Search and filter the model inventory with pagination. + + Applies structured filters (model_type, owner, status) via the + ledger backend, then optionally fuzzy-filters on name and purpose + using case-insensitive substring matching. Results are paginated + via offset/limit. 
+ """ + # Build filter dict — only include non-None values + filters: dict[str, str] = {} + if input.model_type is not None: + filters["model_type"] = input.model_type + if input.owner is not None: + filters["owner"] = input.owner + if input.status is not None: + filters["status"] = input.status + + # Get all matching models from the backend + models = ledger.list(**filters) + + # Fuzzy-filter on name and purpose (case-insensitive contains) + if input.text: + text_lower = input.text.lower() + models = [ + m + for m in models + if text_lower in m.name.lower() or text_lower in (m.purpose or "").lower() + ] + + total = len(models) + + # Paginate + page = models[input.offset : input.offset + input.limit] + has_more = (input.offset + input.limit) < total + + # Convert each ModelRef to a ModelSummary + summaries = [_model_to_summary(m, ledger) for m in page] + + return QueryOutput(total=total, models=summaries, has_more=has_more) diff --git a/src/model_ledger/tools/record.py b/src/model_ledger/tools/record.py new file mode 100644 index 0000000..a8805ad --- /dev/null +++ b/src/model_ledger/tools/record.py @@ -0,0 +1,55 @@ +"""Record tool — register a new model or record an event on an existing model.""" + +from __future__ import annotations + +from model_ledger.sdk.ledger import Ledger +from model_ledger.tools.schemas import RecordInput, RecordOutput + + +def record(input: RecordInput, ledger: Ledger) -> RecordOutput: + """Register a new model or record an event on an existing model. + + When ``input.event == "registered"``, creates the model via + ``ledger.register()`` then logs the registration event. + Otherwise, looks up the existing model and appends the event. + + Raises: + ModelNotFoundError: If the model doesn't exist and the event + is not ``"registered"``. 
+ """ + if input.event == "registered": + model = ledger.register( + name=input.model_name, + owner=input.owner or "unknown", + model_type=input.model_type or "unknown", + tier="unclassified", + purpose=input.purpose or "", + actor=input.actor, + ) + snapshot = ledger.record( + model, + event="registered", + payload=input.payload, + actor=input.actor, + ) + return RecordOutput( + model_name=input.model_name, + event_id=snapshot.snapshot_hash, + timestamp=snapshot.timestamp, + is_new_model=True, + ) + + # Non-registration event: model must already exist + model = ledger.get(input.model_name) + snapshot = ledger.record( + model, + event=input.event, + payload=input.payload, + actor=input.actor, + ) + return RecordOutput( + model_name=input.model_name, + event_id=snapshot.snapshot_hash, + timestamp=snapshot.timestamp, + is_new_model=False, + ) diff --git a/src/model_ledger/tools/schemas.py b/src/model_ledger/tools/schemas.py new file mode 100644 index 0000000..3b19445 --- /dev/null +++ b/src/model_ledger/tools/schemas.py @@ -0,0 +1,209 @@ +"""Pydantic I/O schemas for the 6 agent protocol tools. + +These schemas are the single source of truth for the protocol contract. +They serialize to JSON Schema for MCP tools, OpenAPI docs, and any language. 
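+
+For example, any tool's input contract can be exported as JSON Schema via
+Pydantic v2 (sketch)::
+
+    import json
+    print(json.dumps(QueryInput.model_json_schema(), indent=2))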
+""" + +from __future__ import annotations + +from datetime import datetime +from typing import Any, Literal + +from pydantic import BaseModel, Field + +# --------------------------------------------------------------------------- +# Shared types +# --------------------------------------------------------------------------- + + +class ModelSummary(BaseModel): + """Compact model info returned by query and discover tools.""" + + name: str + owner: str | None = None + model_type: str | None = None + platform: str | None = None + status: str | None = None + last_event: datetime | None = None + event_count: int = 0 + + +class EventSummary(BaseModel): + """Compact event returned in investigation results.""" + + event_type: str + timestamp: datetime | None = None + actor: str | None = None + summary: str | None = None + + +class EventDetail(EventSummary): + """Full event with model association and payload.""" + + model_name: str | None = None + payload: dict[str, Any] = Field(default_factory=dict) + + +class DependencyNode(BaseModel): + """A node in a dependency graph.""" + + name: str + platform: str | None = None + depth: int = 0 + relationship: str | None = None + + +# --------------------------------------------------------------------------- +# record tool +# --------------------------------------------------------------------------- + + +class RecordInput(BaseModel): + """Input for the record tool — log an event against a model.""" + + model_name: str + event: str + payload: dict[str, Any] = Field(default_factory=dict) + actor: str = "user" + owner: str | None = None + model_type: str | None = None + purpose: str | None = None + + +class RecordOutput(BaseModel): + """Output from the record tool.""" + + model_name: str + event_id: str + timestamp: datetime + is_new_model: bool + + +# --------------------------------------------------------------------------- +# query tool +# --------------------------------------------------------------------------- + + +class 
QueryInput(BaseModel): + """Input for the query tool — search and filter models.""" + + text: str | None = None + platform: str | None = None + model_type: str | None = None + owner: str | None = None + status: str | None = None + limit: int = 50 + offset: int = 0 + + +class QueryOutput(BaseModel): + """Output from the query tool.""" + + total: int + models: list[ModelSummary] + has_more: bool + + +# --------------------------------------------------------------------------- +# investigate tool +# --------------------------------------------------------------------------- + + +class InvestigateInput(BaseModel): + """Input for the investigate tool — deep-dive into a single model.""" + + model_name: str + detail: Literal["summary", "full"] = "summary" + as_of: datetime | None = None + + +class InvestigateOutput(BaseModel): + """Output from the investigate tool.""" + + name: str + owner: str | None = None + model_type: str | None = None + purpose: str | None = None + status: str | None = None + created_at: datetime | None = None + metadata: dict[str, Any] = Field(default_factory=dict) + recent_events: list[EventSummary] = Field(default_factory=list) + days_since_last_event: int | None = None + total_events: int = 0 + upstream: list[DependencyNode] = Field(default_factory=list) + downstream: list[DependencyNode] = Field(default_factory=list) + groups: list[str] = Field(default_factory=list) + members: list[str] = Field(default_factory=list) + + +# --------------------------------------------------------------------------- +# trace tool +# --------------------------------------------------------------------------- + + +class TraceInput(BaseModel): + """Input for the trace tool — follow dependency chains.""" + + name: str + direction: Literal["upstream", "downstream", "both"] = "both" + depth: int | None = None + + +class TraceOutput(BaseModel): + """Output from the trace tool.""" + + root: str + upstream: list[DependencyNode] = Field(default_factory=list) + downstream: 
list[DependencyNode] = Field(default_factory=list) + total_nodes: int = 0 + + +# --------------------------------------------------------------------------- +# changelog tool +# --------------------------------------------------------------------------- + + +class ChangelogInput(BaseModel): + """Input for the changelog tool — view event history.""" + + since: datetime | None = None + until: datetime | None = None + model_name: str | None = None + event_type: str | None = None + limit: int = 100 + offset: int = 0 + + +class ChangelogOutput(BaseModel): + """Output from the changelog tool.""" + + total: int + events: list[EventDetail] + has_more: bool + period: str | None = None + + +# --------------------------------------------------------------------------- +# discover tool +# --------------------------------------------------------------------------- + + +class DiscoverInput(BaseModel): + """Input for the discover tool — import models from external sources.""" + + source_type: Literal["connector", "file", "inline"] + connector_name: str | None = None + connector_config: dict[str, Any] | None = None + file_path: str | None = None + models: list[dict[str, Any]] | None = None + auto_connect: bool = True + + +class DiscoverOutput(BaseModel): + """Output from the discover tool.""" + + models_added: int + models_skipped: int + links_created: int + models: list[ModelSummary] = Field(default_factory=list) + errors: list[str] = Field(default_factory=list) diff --git a/src/model_ledger/tools/trace.py b/src/model_ledger/tools/trace.py new file mode 100644 index 0000000..0d05e9e --- /dev/null +++ b/src/model_ledger/tools/trace.py @@ -0,0 +1,90 @@ +"""Trace tool — dependency graph traversal.""" + +from __future__ import annotations + +from model_ledger.sdk.ledger import Ledger +from model_ledger.tools.schemas import DependencyNode, TraceInput, TraceOutput + + +def _get_platform(name: str, ledger: Ledger) -> str | None: + """Try to extract platform from the model's snapshot 
history.""" + try: + model = ledger.get(name) + for snap in ledger.history(model): + if snap.source: + return snap.source + # Also check payload for platform from discovered events + if snap.payload.get("platform"): + return snap.payload["platform"] + except Exception: + pass + return None + + +def trace(input: TraceInput, ledger: Ledger) -> TraceOutput: + """Traverse a model's dependency graph. + + Walks upstream (models this one depends on) and/or downstream + (models that depend on this one), returning ``DependencyNode`` lists + with depth and relationship metadata. + + Raises: + ModelNotFoundError: If the target model does not exist. + """ + # 1. Verify model exists — raises ModelNotFoundError if missing + ledger.get(input.name) + + # 2. Build upstream list + upstream_nodes: list[DependencyNode] = [] + if input.direction in ("upstream", "both"): + try: + upstream_names = ledger.upstream(input.name) + except (KeyError, ValueError): + upstream_names = [] + + # upstream() returns topological order (sources first, nearest last), + # so reverse depth: nearest dependency = depth 1, furthest = len. + total_up = len(upstream_names) + for idx, name in enumerate(upstream_names): + platform = _get_platform(name, ledger) + upstream_nodes.append( + DependencyNode( + name=name, + platform=platform, + depth=total_up - idx, + relationship="depends_on", + ) + ) + + # 3. Build downstream list + downstream_nodes: list[DependencyNode] = [] + if input.direction in ("downstream", "both"): + try: + downstream_names = ledger.downstream(input.name) + except (KeyError, ValueError): + downstream_names = [] + + for idx, name in enumerate(downstream_names): + platform = _get_platform(name, ledger) + downstream_nodes.append( + DependencyNode( + name=name, + platform=platform, + depth=idx + 1, + relationship="feeds_into", + ) + ) + + # 4. 
Apply depth filter + if input.depth is not None: + upstream_nodes = [n for n in upstream_nodes if n.depth <= input.depth] + downstream_nodes = [n for n in downstream_nodes if n.depth <= input.depth] + + # 5. Return result + total = len(upstream_nodes) + len(downstream_nodes) + return TraceOutput( + root=input.name, + upstream=upstream_nodes, + downstream=downstream_nodes, + total_nodes=total, + ) diff --git a/tests/test_backends/test_json_files.py b/tests/test_backends/test_json_files.py new file mode 100644 index 0000000..63cdbaf --- /dev/null +++ b/tests/test_backends/test_json_files.py @@ -0,0 +1,334 @@ +"""Tests for JsonFileLedgerBackend.""" + +import json +import os +import tempfile +from datetime import datetime, timezone + +import pytest + +from model_ledger.backends.json_files import JsonFileLedgerBackend +from model_ledger.core.ledger_models import ModelRef, Snapshot, Tag + + +@pytest.fixture +def backend_dir(): + with tempfile.TemporaryDirectory() as d: + yield d + + +@pytest.fixture +def backend(backend_dir): + return JsonFileLedgerBackend(backend_dir) + + +def _make_model(name="test-model"): + return ModelRef( + name=name, + owner="alice", + model_type="ml_model", + tier="high", + purpose="testing", + status="active", + ) + + +def _make_snapshot(model_hash, event_type="discovered"): + return Snapshot( + model_hash=model_hash, + actor="test", + event_type=event_type, + payload={"key": "value"}, + ) + + +class TestModels: + def test_save_and_get(self, backend): + m = _make_model() + backend.save_model(m) + result = backend.get_model(m.model_hash) + assert result is not None + assert result.name == "test-model" + assert result.owner == "alice" + + def test_get_by_name(self, backend): + m = _make_model() + backend.save_model(m) + result = backend.get_model_by_name("test-model") + assert result is not None + assert result.model_hash == m.model_hash + + def test_get_missing(self, backend): + assert backend.get_model("nonexistent") is None + assert 
backend.get_model_by_name("nonexistent") is None + + def test_list_models(self, backend): + backend.save_model(_make_model("a")) + backend.save_model(_make_model("b")) + assert len(backend.list_models()) == 2 + + def test_list_models_with_filter(self, backend): + backend.save_model(_make_model("a")) + backend.save_model(_make_model("b")) + result = backend.list_models(name="a") + assert len(result) == 1 + assert result[0].name == "a" + + def test_update_model(self, backend): + m = _make_model() + backend.save_model(m) + m.status = "deprecated" + backend.update_model(m) + result = backend.get_model(m.model_hash) + assert result.status == "deprecated" + + +class TestSnapshots: + def test_append_and_get(self, backend): + m = _make_model() + backend.save_model(m) + s = _make_snapshot(m.model_hash) + backend.append_snapshot(s) + result = backend.get_snapshot(s.snapshot_hash) + assert result is not None + assert result.payload == {"key": "value"} + + def test_list_snapshots(self, backend): + m = _make_model() + backend.save_model(m) + backend.append_snapshot(_make_snapshot(m.model_hash, "registered")) + backend.append_snapshot(_make_snapshot(m.model_hash, "discovered")) + snaps = backend.list_snapshots(m.model_hash) + assert len(snaps) == 2 + + def test_list_snapshots_sorted_descending(self, backend): + """Snapshots are returned newest-first.""" + m = _make_model() + backend.save_model(m) + s1 = Snapshot( + model_hash=m.model_hash, + actor="test", + event_type="registered", + payload={"order": 1}, + timestamp=datetime(2025, 1, 1, tzinfo=timezone.utc), + ) + s2 = Snapshot( + model_hash=m.model_hash, + actor="test", + event_type="discovered", + payload={"order": 2}, + timestamp=datetime(2025, 6, 1, tzinfo=timezone.utc), + ) + backend.append_snapshot(s1) + backend.append_snapshot(s2) + snaps = backend.list_snapshots(m.model_hash) + assert snaps[0].timestamp > snaps[1].timestamp + + def test_latest_snapshot(self, backend): + m = _make_model() + backend.save_model(m) + s1 = 
Snapshot( + model_hash=m.model_hash, + actor="test", + event_type="registered", + payload={"v": 1}, + timestamp=datetime(2025, 1, 1, tzinfo=timezone.utc), + ) + s2 = Snapshot( + model_hash=m.model_hash, + actor="test", + event_type="discovered", + payload={"v": 2}, + timestamp=datetime(2025, 6, 1, tzinfo=timezone.utc), + ) + backend.append_snapshot(s1) + backend.append_snapshot(s2) + latest = backend.latest_snapshot(m.model_hash) + assert latest is not None + assert latest.payload["v"] == 2 + + def test_latest_snapshot_empty(self, backend): + assert backend.latest_snapshot("nonexistent") is None + + def test_list_snapshots_before(self, backend): + m = _make_model() + backend.save_model(m) + s1 = Snapshot( + model_hash=m.model_hash, + actor="test", + event_type="registered", + payload={}, + timestamp=datetime(2025, 1, 1, tzinfo=timezone.utc), + ) + s2 = Snapshot( + model_hash=m.model_hash, + actor="test", + event_type="discovered", + payload={}, + timestamp=datetime(2025, 6, 1, tzinfo=timezone.utc), + ) + backend.append_snapshot(s1) + backend.append_snapshot(s2) + before = datetime(2025, 3, 1, tzinfo=timezone.utc) + results = backend.list_snapshots_before(m.model_hash, before) + assert len(results) == 1 + assert results[0].snapshot_hash == s1.snapshot_hash + + def test_list_snapshots_before_with_event_type(self, backend): + m = _make_model() + backend.save_model(m) + s1 = Snapshot( + model_hash=m.model_hash, + actor="test", + event_type="registered", + payload={}, + timestamp=datetime(2025, 1, 1, tzinfo=timezone.utc), + ) + s2 = Snapshot( + model_hash=m.model_hash, + actor="test", + event_type="discovered", + payload={}, + timestamp=datetime(2025, 2, 1, tzinfo=timezone.utc), + ) + backend.append_snapshot(s1) + backend.append_snapshot(s2) + before = datetime(2025, 6, 1, tzinfo=timezone.utc) + results = backend.list_snapshots_before( + m.model_hash, + before, + event_type="registered", + ) + assert len(results) == 1 + assert results[0].event_type == "registered" + + 
+class TestTags:
+    def test_set_and_get(self, backend):
+        m = _make_model()
+        backend.save_model(m)
+        s = _make_snapshot(m.model_hash)
+        backend.append_snapshot(s)
+        tag = Tag(name="latest", model_hash=m.model_hash, snapshot_hash=s.snapshot_hash)
+        backend.set_tag(tag)
+        result = backend.get_tag(m.model_hash, "latest")
+        assert result is not None
+        assert result.snapshot_hash == s.snapshot_hash
+
+    def test_list_tags(self, backend):
+        m = _make_model()
+        backend.save_model(m)
+        s = _make_snapshot(m.model_hash)
+        backend.append_snapshot(s)
+        backend.set_tag(Tag(name="v1", model_hash=m.model_hash, snapshot_hash=s.snapshot_hash))
+        backend.set_tag(Tag(name="v2", model_hash=m.model_hash, snapshot_hash=s.snapshot_hash))
+        assert len(backend.list_tags(m.model_hash)) == 2
+
+    def test_latest_via_tag(self, backend):
+        """latest_snapshot with a tag resolves through the tag pointer."""
+        m = _make_model()
+        backend.save_model(m)
+        s1 = _make_snapshot(m.model_hash, "registered")
+        s2 = _make_snapshot(m.model_hash, "discovered")
+        backend.append_snapshot(s1)
+        backend.append_snapshot(s2)
+        backend.set_tag(Tag(name="pinned", model_hash=m.model_hash, snapshot_hash=s1.snapshot_hash))
+        result = backend.latest_snapshot(m.model_hash, tag="pinned")
+        assert result is not None
+        assert result.snapshot_hash == s1.snapshot_hash
+
+    def test_get_tag_missing(self, backend):
+        assert backend.get_tag("nonexistent", "v1") is None
+
+
+class TestPersistence:
+    def test_data_survives_reopen(self, backend_dir):
+        backend1 = JsonFileLedgerBackend(backend_dir)
+        m = _make_model()
+        backend1.save_model(m)
+        # Tag must reference the persisted snapshot, not a fresh orphan hash
+        s = _make_snapshot(m.model_hash)
+        backend1.append_snapshot(s)
+        backend1.set_tag(
+            Tag(name="v1", model_hash=m.model_hash, snapshot_hash=s.snapshot_hash)
+        )
+        del backend1
+
+        backend2 = JsonFileLedgerBackend(backend_dir)
+        assert backend2.get_model_by_name("test-model") is not None
+        assert len(backend2.list_snapshots(m.model_hash)) == 1
+        assert
len(backend2.list_tags(m.model_hash)) == 1 + + +class TestReadableJson: + def test_model_file_is_valid_json(self, backend, backend_dir): + m = _make_model() + backend.save_model(m) + # Find the model file and read it with json.load + models_dir = os.path.join(backend_dir, "models") + files = os.listdir(models_dir) + assert len(files) == 1 + with open(os.path.join(models_dir, files[0])) as f: + data = json.load(f) + assert data["name"] == "test-model" + assert data["owner"] == "alice" + assert data["model_hash"] == m.model_hash + + def test_snapshot_file_is_valid_json(self, backend, backend_dir): + m = _make_model() + backend.save_model(m) + s = _make_snapshot(m.model_hash) + backend.append_snapshot(s) + snap_dir = os.path.join(backend_dir, "snapshots") + files = os.listdir(snap_dir) + assert len(files) == 1 + with open(os.path.join(snap_dir, files[0])) as f: + data = json.load(f) + assert data["snapshot_hash"] == s.snapshot_hash + assert data["payload"]["key"] == "value" + + def test_tag_file_is_valid_json(self, backend, backend_dir): + m = _make_model() + backend.save_model(m) + s = _make_snapshot(m.model_hash) + backend.append_snapshot(s) + tag = Tag(name="v1", model_hash=m.model_hash, snapshot_hash=s.snapshot_hash) + backend.set_tag(tag) + tag_file = os.path.join(backend_dir, "tags", m.model_hash, "v1.json") + assert os.path.exists(tag_file) + with open(tag_file) as f: + data = json.load(f) + assert data["name"] == "v1" + assert data["snapshot_hash"] == s.snapshot_hash + + +class TestFilenamesSanitized: + def test_model_with_special_chars(self, backend): + m = _make_model("my model/v1\\test") + backend.save_model(m) + result = backend.get_model_by_name("my model/v1\\test") + assert result is not None + assert result.name == "my model/v1\\test" + + +class TestLedgerIntegration: + def test_full_workflow(self, backend_dir): + from model_ledger import Ledger + from model_ledger.graph.models import DataNode + + backend = JsonFileLedgerBackend(backend_dir) + ledger = 
Ledger(backend) + + ledger.add( + [ + DataNode("writer", outputs=["shared_table"]), + DataNode("reader", inputs=["shared_table"]), + ] + ) + ledger.connect() + + assert len(ledger.list()) == 2 + trace = ledger.trace("reader") + assert "writer" in trace diff --git a/tests/test_mcp/__init__.py b/tests/test_mcp/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_mcp/test_server.py b/tests/test_mcp/test_server.py new file mode 100644 index 0000000..475f2ab --- /dev/null +++ b/tests/test_mcp/test_server.py @@ -0,0 +1,153 @@ +"""Tests for the MCP server — tool and resource registration.""" + +from __future__ import annotations + +import asyncio + +import pytest + +try: + from mcp.server.fastmcp import FastMCP + + HAS_MCP = True +except ImportError: + HAS_MCP = False + +pytestmark = pytest.mark.skipif(not HAS_MCP, reason="mcp extra not installed") + + +@pytest.fixture +def server(): + from model_ledger.mcp.server import create_server + + return create_server(demo=False) + + +@pytest.fixture +def server_demo(): + from model_ledger.mcp.server import create_server + + return create_server(demo=True) + + +class TestCreateServer: + """create_server() returns a properly configured FastMCP instance.""" + + def test_returns_fastmcp_instance(self, server): + assert isinstance(server, FastMCP) + + def test_server_name(self, server): + assert server.name == "model-ledger" + + +class TestToolRegistration: + """All 6 tools are registered.""" + + EXPECTED_TOOLS = {"discover", "record", "investigate", "query", "trace", "changelog"} + + def test_all_tools_registered(self, server): + tools = asyncio.run(server.list_tools()) + tool_names = {t.name for t in tools} + assert tool_names >= self.EXPECTED_TOOLS + + def test_exactly_six_tools(self, server): + tools = asyncio.run(server.list_tools()) + assert len(tools) == 6 + + def test_each_tool_has_description(self, server): + tools = asyncio.run(server.list_tools()) + for tool in tools: + assert tool.description, 
f"Tool {tool.name} has no description" + + +class TestResourceRegistration: + """All 3 resources are registered.""" + + EXPECTED_URIS = { + "ledger://overview", + "ledger://schema", + "ledger://backends", + } + + def test_all_resources_registered(self, server): + resources = asyncio.run(server.list_resources()) + uris = {str(r.uri) for r in resources} + assert uris >= self.EXPECTED_URIS + + def test_exactly_three_resources(self, server): + resources = asyncio.run(server.list_resources()) + assert len(resources) == 3 + + +class TestToolExecution: + """Smoke-test that tools can be called (returns dict, not Pydantic).""" + + def test_record_tool_creates_model(self, server): + result = asyncio.run( + server.call_tool( + "record", + { + "model_name": "credit-scorecard", + "event": "registered", + "actor": "test", + "owner": "risk-team", + "model_type": "ml_model", + }, + ) + ) + # call_tool returns list of content items; check first one has text + assert len(result) > 0 + + def test_query_tool_returns_results(self, server): + # First register a model + asyncio.run( + server.call_tool( + "record", + { + "model_name": "fraud-detector", + "event": "registered", + "owner": "risk-team", + "model_type": "ml_model", + }, + ) + ) + result = asyncio.run(server.call_tool("query", {})) + assert len(result) > 0 + + +class TestResourceReading: + """Smoke-test that resources can be read.""" + + def test_overview_resource(self, server): + result = asyncio.run(server.read_resource("ledger://overview")) + # Returns bytes or str + assert result is not None + + def test_schema_resource(self, server): + result = asyncio.run(server.read_resource("ledger://schema")) + assert result is not None + + def test_backends_resource(self, server): + result = asyncio.run(server.read_resource("ledger://backends")) + assert result is not None + + +class TestCustomBackend: + """create_server accepts a custom backend.""" + + def test_with_explicit_backend(self): + from model_ledger.backends.ledger_memory 
import InMemoryLedgerBackend + from model_ledger.mcp.server import create_server + + backend = InMemoryLedgerBackend() + srv = create_server(backend=backend, demo=False) + assert isinstance(srv, FastMCP) + + +class TestMainEntryPoint: + """main() function exists and is importable.""" + + def test_main_is_callable(self): + from model_ledger.mcp.server import main + + assert callable(main) diff --git a/tests/test_rest/__init__.py b/tests/test_rest/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_rest/test_app.py b/tests/test_rest/test_app.py new file mode 100644 index 0000000..a3bafa2 --- /dev/null +++ b/tests/test_rest/test_app.py @@ -0,0 +1,180 @@ +# tests/test_rest/test_app.py +"""Tests for the FastAPI REST API wrapping tool functions.""" + +from __future__ import annotations + +import pytest + +fastapi = pytest.importorskip("fastapi") + +from fastapi.testclient import TestClient + +from model_ledger.rest.app import create_app + + +@pytest.fixture +def client(): + """TestClient backed by an in-memory ledger.""" + app = create_app() + return TestClient(app) + + +class TestOverview: + """GET /overview on an empty ledger.""" + + def test_overview_empty(self, client): + resp = client.get("/overview") + assert resp.status_code == 200 + data = resp.json() + assert data["total_models"] == 0 + assert data["total_events"] == 0 + + +class TestRecordEndpoint: + """POST /record — register a model.""" + + def test_register_model(self, client): + resp = client.post( + "/record", + json={ + "model_name": "credit-scorecard", + "event": "registered", + "actor": "alice", + "owner": "risk-team", + "model_type": "ml_model", + "purpose": "Credit risk scoring", + }, + ) + assert resp.status_code == 200 + data = resp.json() + assert data["model_name"] == "credit-scorecard" + assert data["is_new_model"] is True + assert data["event_id"] # non-empty + + +class TestQueryAfterRegister: + """GET /query — search after registering a model.""" + + def 
test_query_finds_registered_model(self, client): + # Register first + client.post( + "/record", + json={ + "model_name": "fraud-detector", + "event": "registered", + "actor": "bob", + "owner": "security-team", + "model_type": "ml_model", + "purpose": "Detect fraudulent transactions", + }, + ) + + resp = client.get("/query", params={"text": "fraud"}) + assert resp.status_code == 200 + data = resp.json() + assert data["total"] >= 1 + names = [m["name"] for m in data["models"]] + assert "fraud-detector" in names + + def test_query_empty_inventory(self, client): + resp = client.get("/query") + assert resp.status_code == 200 + data = resp.json() + assert data["total"] == 0 + assert data["models"] == [] + + +class TestInvestigateEndpoint: + """GET /investigate/{model_name} — deep-dive into a model.""" + + def test_investigate_registered_model(self, client): + # Register + client.post( + "/record", + json={ + "model_name": "credit-scorecard", + "event": "registered", + "actor": "alice", + "owner": "risk-team", + "model_type": "ml_model", + "purpose": "Credit risk scoring", + }, + ) + + resp = client.get("/investigate/credit-scorecard") + assert resp.status_code == 200 + data = resp.json() + assert data["name"] == "credit-scorecard" + assert data["owner"] == "risk-team" + assert data["total_events"] >= 1 + + def test_investigate_404(self, client): + resp = client.get("/investigate/nonexistent-model") + assert resp.status_code == 404 + assert "nonexistent-model" in resp.json()["detail"] + + +class TestTraceEndpoint: + """GET /trace/{name} — dependency tracing.""" + + def test_trace_registered_model(self, client): + client.post( + "/record", + json={ + "model_name": "scoring-model", + "event": "registered", + "actor": "alice", + "owner": "data-team", + "model_type": "ml_model", + }, + ) + + resp = client.get("/trace/scoring-model") + assert resp.status_code == 200 + data = resp.json() + assert data["root"] == "scoring-model" + + def test_trace_404(self, client): + resp = 
client.get("/trace/nonexistent") + assert resp.status_code == 404 + + +class TestChangelogEndpoint: + """GET /changelog — event timeline.""" + + def test_changelog_with_events(self, client): + client.post( + "/record", + json={ + "model_name": "scoring-model", + "event": "registered", + "actor": "alice", + "owner": "data-team", + "model_type": "ml_model", + }, + ) + + resp = client.get("/changelog") + assert resp.status_code == 200 + data = resp.json() + assert data["total"] >= 1 + assert len(data["events"]) >= 1 + + +class TestDiscoverEndpoint: + """POST /discover — bulk ingestion.""" + + def test_discover_inline(self, client): + resp = client.post( + "/discover", + json={ + "source_type": "inline", + "models": [ + {"name": "pipeline-a", "platform": "airflow"}, + {"name": "pipeline-b", "platform": "airflow"}, + ], + }, + ) + assert resp.status_code == 200 + data = resp.json() + assert data["models_added"] == 2 diff --git a/tests/test_tools/__init__.py b/tests/test_tools/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_tools/test_changelog.py b/tests/test_tools/test_changelog.py new file mode 100644 index 0000000..dfc3fc5 --- /dev/null +++ b/tests/test_tools/test_changelog.py @@ -0,0 +1,217 @@ +# tests/test_tools/test_changelog.py +"""Tests for the changelog tool — cross-model event timeline.""" + +from __future__ import annotations + +from datetime import datetime, timedelta, timezone + +import pytest + +from model_ledger.backends.ledger_memory import InMemoryLedgerBackend +from model_ledger.sdk.ledger import Ledger +from model_ledger.tools.changelog import changelog +from model_ledger.tools.record import record +from model_ledger.tools.schemas import ( + ChangelogInput, + ChangelogOutput, + EventDetail, + RecordInput, +) + + +@pytest.fixture +def ledger(): + return Ledger(backend=InMemoryLedgerBackend()) + + +def _seed(ledger): + """Create 3 models with events matching the task spec.""" + for name in ["model_a", "model_b", 
"model_c"]: + record( + RecordInput( + model_name=name, + event="registered", + owner="team", + model_type="ml_model", + purpose="test", + ), + ledger, + ) + record(RecordInput(model_name="model_a", event="retrained", actor="pipeline"), ledger) + record(RecordInput(model_name="model_b", event="deployed", actor="ci"), ledger) + + +class TestAllEvents: + """No filters — all events returned.""" + + def test_returns_all_events(self, ledger): + _seed(ledger) + + result = changelog(ChangelogInput(), ledger) + + assert isinstance(result, ChangelogOutput) + # 3 registrations + 1 retrained + 1 deployed = 5 tool-level events + # (register() internally also creates a snapshot, but the tool-level + # record() creates the one we see via ledger.history too — count all) + assert result.total >= 5 + assert len(result.events) >= 5 + for ev in result.events: + assert isinstance(ev, EventDetail) + assert ev.model_name is not None + assert ev.event_type + + def test_period_defaults_to_last_7_days(self, ledger): + _seed(ledger) + + result = changelog(ChangelogInput(), ledger) + + assert result.period is not None + assert "7 days" in result.period + + +class TestFilterByModelName: + """Filter by model_name — only that model's events.""" + + def test_filter_single_model(self, ledger): + _seed(ledger) + + result = changelog(ChangelogInput(model_name="model_a"), ledger) + + assert result.total >= 1 + for ev in result.events: + assert ev.model_name == "model_a" + + def test_filter_model_with_extra_events(self, ledger): + _seed(ledger) + + result = changelog(ChangelogInput(model_name="model_a"), ledger) + + event_types = [ev.event_type for ev in result.events] + assert "registered" in event_types + assert "retrained" in event_types + + +class TestFilterByEventType: + """Filter by event_type — only matching events.""" + + def test_filter_retrained(self, ledger): + _seed(ledger) + + result = changelog(ChangelogInput(event_type="retrained"), ledger) + + assert result.total >= 1 + for ev in 
result.events: + assert ev.event_type == "retrained" + model_names = [ev.model_name for ev in result.events] + assert "model_a" in model_names + + def test_filter_deployed(self, ledger): + _seed(ledger) + + result = changelog(ChangelogInput(event_type="deployed"), ledger) + + assert result.total >= 1 + for ev in result.events: + assert ev.event_type == "deployed" + + def test_filter_event_type_no_match(self, ledger): + _seed(ledger) + + result = changelog(ChangelogInput(event_type="nonexistent"), ledger) + + assert result.total == 0 + assert result.events == [] + + +class TestPagination: + """Pagination with limit/offset and has_more.""" + + def test_limit_returns_subset(self, ledger): + _seed(ledger) + + result = changelog(ChangelogInput(limit=2), ledger) + + assert len(result.events) == 2 + assert result.has_more is True + assert result.total >= 5 + + def test_offset_skips(self, ledger): + _seed(ledger) + + full = changelog(ChangelogInput(), ledger) + page2 = changelog(ChangelogInput(limit=2, offset=2), ledger) + + assert page2.total == full.total + # Events on page 2 should be different from page 1 + page1 = changelog(ChangelogInput(limit=2, offset=0), ledger) + page1_ids = [(e.model_name, e.timestamp) for e in page1.events] + page2_ids = [(e.model_name, e.timestamp) for e in page2.events] + assert page1_ids != page2_ids + + def test_offset_beyond_total(self, ledger): + _seed(ledger) + + result = changelog(ChangelogInput(limit=10, offset=1000), ledger) + + assert result.total >= 5 + assert len(result.events) == 0 + assert result.has_more is False + + +class TestNewestFirstOrdering: + """Events sorted by timestamp descending (newest first).""" + + def test_newest_first(self, ledger): + _seed(ledger) + + result = changelog(ChangelogInput(), ledger) + + timestamps = [ev.timestamp for ev in result.events] + for i in range(len(timestamps) - 1): + assert timestamps[i] >= timestamps[i + 1], ( + f"Event {i} ({timestamps[i]}) should be >= event {i + 1} ({timestamps[i + 
1]})" + ) + + +class TestEmptyInventory: + """Empty inventory returns total=0.""" + + def test_empty_returns_zero(self, ledger): + result = changelog(ChangelogInput(), ledger) + + assert result.total == 0 + assert result.events == [] + assert result.has_more is False + + +class TestTimeRangeFiltering: + """Time range filtering with since/until.""" + + def test_since_filters_old_events(self, ledger): + _seed(ledger) + + # All events were just created, so a future 'since' should exclude them + future = datetime.now(timezone.utc) + timedelta(hours=1) + result = changelog(ChangelogInput(since=future), ledger) + + assert result.total == 0 + + def test_until_filters_future_events(self, ledger): + _seed(ledger) + + # 'until' set to the past should exclude all just-created events + past = datetime.now(timezone.utc) - timedelta(hours=1) + result = changelog(ChangelogInput(until=past), ledger) + + assert result.total == 0 + + def test_both_since_and_until_period_string(self, ledger): + _seed(ledger) + + since = datetime(2026, 1, 1, tzinfo=timezone.utc) + until = datetime(2026, 12, 31, tzinfo=timezone.utc) + result = changelog(ChangelogInput(since=since, until=until), ledger) + + assert result.period is not None + assert "2026-01-01" in result.period + assert "2026-12-31" in result.period diff --git a/tests/test_tools/test_discover.py b/tests/test_tools/test_discover.py new file mode 100644 index 0000000..e1372bd --- /dev/null +++ b/tests/test_tools/test_discover.py @@ -0,0 +1,170 @@ +# tests/test_tools/test_discover.py +"""Tests for the discover tool — bulk ingestion from inline data, files, or connectors.""" + +from __future__ import annotations + +import json +import os +import tempfile + +import pytest + +from model_ledger.backends.ledger_memory import InMemoryLedgerBackend +from model_ledger.sdk.ledger import Ledger +from model_ledger.tools.discover import discover +from model_ledger.tools.schemas import DiscoverInput, DiscoverOutput + + +@pytest.fixture +def ledger(): + 
return Ledger(backend=InMemoryLedgerBackend()) + + +class TestDiscoverInline: + """Inline source_type — models passed directly as list of dicts.""" + + def test_inline_adds_models(self, ledger): + """Two inline models should yield models_added=2, models_skipped=0.""" + inp = DiscoverInput( + source_type="inline", + models=[ + {"name": "feature_pipeline", "platform": "airflow"}, + {"name": "scoring_model", "platform": "sagemaker"}, + ], + auto_connect=False, + ) + result = discover(inp, ledger) + + assert isinstance(result, DiscoverOutput) + assert result.models_added == 2 + assert result.models_skipped == 0 + assert len(result.models) == 2 + assert result.errors == [] + + def test_inline_auto_connect(self, ledger): + """Models with matching input/output ports should create links.""" + inp = DiscoverInput( + source_type="inline", + models=[ + { + "name": "feature_pipeline", + "platform": "airflow", + "outputs": ["feature_table"], + }, + { + "name": "scoring_model", + "platform": "sagemaker", + "inputs": ["feature_table"], + }, + ], + auto_connect=True, + ) + result = discover(inp, ledger) + + assert result.models_added == 2 + assert result.links_created >= 1 + + def test_inline_auto_connect_false(self, ledger): + """When auto_connect=False, links_created must be 0.""" + inp = DiscoverInput( + source_type="inline", + models=[ + { + "name": "feature_pipeline", + "platform": "airflow", + "outputs": ["feature_table"], + }, + { + "name": "scoring_model", + "platform": "sagemaker", + "inputs": ["feature_table"], + }, + ], + auto_connect=False, + ) + result = discover(inp, ledger) + + assert result.models_added == 2 + assert result.links_created == 0 + + def test_inline_dedup(self, ledger): + """Adding the same models twice should skip them on the second call.""" + models = [ + {"name": "feature_pipeline", "platform": "airflow"}, + {"name": "scoring_model", "platform": "sagemaker"}, + ] + inp1 = DiscoverInput(source_type="inline", models=models, auto_connect=False) + 
result1 = discover(inp1, ledger) + assert result1.models_added == 2 + assert result1.models_skipped == 0 + + inp2 = DiscoverInput(source_type="inline", models=models, auto_connect=False) + result2 = discover(inp2, ledger) + assert result2.models_added == 0 + assert result2.models_skipped == 2 + + def test_inline_empty_list(self, ledger): + """Empty model list yields models_added=0.""" + inp = DiscoverInput(source_type="inline", models=[], auto_connect=False) + result = discover(inp, ledger) + + assert result.models_added == 0 + assert result.models_skipped == 0 + assert result.links_created == 0 + assert result.models == [] + + def test_inline_none_models_raises(self, ledger): + """source_type='inline' with models=None should raise ValueError.""" + inp = DiscoverInput(source_type="inline", models=None) + with pytest.raises(ValueError, match="models"): + discover(inp, ledger) + + +class TestDiscoverFile: + """File source_type — models loaded from a JSON file.""" + + def test_file_loads_models(self, ledger): + """Loading models from a JSON file should add them.""" + models = [ + {"name": "etl_job", "platform": "spark"}, + {"name": "report_gen", "platform": "tableau"}, + ] + with tempfile.NamedTemporaryFile( + mode="w", + suffix=".json", + delete=False, + ) as f: + json.dump(models, f) + tmp_path = f.name + + try: + inp = DiscoverInput( + source_type="file", + file_path=tmp_path, + auto_connect=False, + ) + result = discover(inp, ledger) + + assert result.models_added == 2 + assert result.models_skipped == 0 + finally: + os.unlink(tmp_path) + + def test_file_none_path_raises(self, ledger): + """source_type='file' with file_path=None should raise ValueError.""" + inp = DiscoverInput(source_type="file", file_path=None) + with pytest.raises(ValueError, match="file_path"): + discover(inp, ledger) + + +class TestDiscoverConnector: + """Connector source_type — should raise NotImplementedError.""" + + def test_connector_raises(self, ledger): + inp = DiscoverInput( + 
source_type="connector", + connector_name="databricks", + connector_config={"workspace": "test"}, + ) + with pytest.raises(NotImplementedError, match="not yet supported"): + discover(inp, ledger) diff --git a/tests/test_tools/test_investigate.py b/tests/test_tools/test_investigate.py new file mode 100644 index 0000000..ca3fd6c --- /dev/null +++ b/tests/test_tools/test_investigate.py @@ -0,0 +1,335 @@ +# tests/test_tools/test_investigate.py +"""Tests for the investigate tool — deep-dive into a single model.""" + +from __future__ import annotations + +import pytest + +from model_ledger.backends.ledger_memory import InMemoryLedgerBackend +from model_ledger.core.exceptions import ModelNotFoundError +from model_ledger.graph.models import DataNode +from model_ledger.sdk.ledger import Ledger +from model_ledger.tools.investigate import investigate +from model_ledger.tools.record import record +from model_ledger.tools.schemas import ( + EventSummary, + InvestigateInput, + InvestigateOutput, + RecordInput, +) + + +@pytest.fixture +def ledger(): + return Ledger(backend=InMemoryLedgerBackend()) + + +class TestBasicInvestigation: + """Basic investigation returns name, owner, type, status.""" + + def test_basic_fields_present(self, ledger): + record( + RecordInput( + model_name="fraud_scoring", + event="registered", + owner="risk-team", + model_type="ml_model", + purpose="Fraud detection", + ), + ledger, + ) + + result = investigate(InvestigateInput(model_name="fraud_scoring"), ledger) + + assert isinstance(result, InvestigateOutput) + assert result.name == "fraud_scoring" + assert result.owner == "risk-team" + assert result.model_type == "ml_model" + assert result.status == "active" + assert result.purpose == "Fraud detection" + + def test_created_at_present(self, ledger): + record( + RecordInput( + model_name="fraud_scoring", + event="registered", + owner="risk-team", + model_type="ml_model", + purpose="Fraud detection", + ), + ledger, + ) + + result = 
investigate(InvestigateInput(model_name="fraud_scoring"), ledger) + + assert result.created_at is not None + + def test_total_events_counted(self, ledger): + record( + RecordInput( + model_name="fraud_scoring", + event="registered", + owner="risk-team", + model_type="ml_model", + purpose="Fraud detection", + ), + ledger, + ) + record( + RecordInput( + model_name="fraud_scoring", + event="metadata_updated", + payload={"accuracy": 0.92}, + actor="ds", + ), + ledger, + ) + + result = investigate(InvestigateInput(model_name="fraud_scoring"), ledger) + + # register creates 2 snapshots (internal register + record tool's record), + # plus the metadata_updated event + assert result.total_events >= 3 + + def test_days_since_last_event(self, ledger): + record( + RecordInput( + model_name="fraud_scoring", + event="registered", + owner="risk-team", + model_type="ml_model", + purpose="Fraud detection", + ), + ledger, + ) + + result = investigate(InvestigateInput(model_name="fraud_scoring"), ledger) + + # Just registered, should be 0 days + assert result.days_since_last_event is not None + assert result.days_since_last_event == 0 + + +class TestMetadataMerge: + """Metadata from snapshot payloads is merged oldest-first.""" + + def test_metadata_merged_from_payloads(self, ledger): + record( + RecordInput( + model_name="fraud_scoring", + event="registered", + owner="risk-team", + model_type="ml_model", + purpose="Fraud detection", + ), + ledger, + ) + record( + RecordInput( + model_name="fraud_scoring", + event="metadata_updated", + payload={"accuracy": 0.92}, + actor="ds", + ), + ledger, + ) + record( + RecordInput( + model_name="fraud_scoring", + event="metadata_updated", + payload={"environment": "production"}, + actor="ops", + ), + ledger, + ) + + result = investigate(InvestigateInput(model_name="fraud_scoring"), ledger) + + # Both payload keys should be present in merged metadata + assert result.metadata["accuracy"] == 0.92 + assert result.metadata["environment"] == 
"production" + + def test_newer_metadata_overwrites_older(self, ledger): + record( + RecordInput( + model_name="fraud_scoring", + event="registered", + owner="risk-team", + model_type="ml_model", + purpose="Fraud detection", + ), + ledger, + ) + record( + RecordInput( + model_name="fraud_scoring", + event="metadata_updated", + payload={"accuracy": 0.85}, + actor="ds", + ), + ledger, + ) + record( + RecordInput( + model_name="fraud_scoring", + event="metadata_updated", + payload={"accuracy": 0.92}, + actor="ds", + ), + ledger, + ) + + result = investigate(InvestigateInput(model_name="fraud_scoring"), ledger) + + # Newest value wins + assert result.metadata["accuracy"] == 0.92 + + +class TestRecentEvents: + """Recent events list with event_type.""" + + def test_recent_events_included(self, ledger): + record( + RecordInput( + model_name="fraud_scoring", + event="registered", + owner="risk-team", + model_type="ml_model", + purpose="Fraud detection", + ), + ledger, + ) + record( + RecordInput( + model_name="fraud_scoring", + event="metadata_updated", + payload={"accuracy": 0.92}, + actor="ds", + ), + ledger, + ) + + result = investigate(InvestigateInput(model_name="fraud_scoring"), ledger) + + assert len(result.recent_events) > 0 + for ev in result.recent_events: + assert isinstance(ev, EventSummary) + assert ev.event_type # non-empty + + # Should contain the event types we recorded + event_types = [e.event_type for e in result.recent_events] + assert "registered" in event_types + assert "metadata_updated" in event_types + + def test_summary_limits_to_10_events(self, ledger): + record( + RecordInput( + model_name="fraud_scoring", + event="registered", + owner="risk-team", + model_type="ml_model", + purpose="Fraud detection", + ), + ledger, + ) + # Create 15 extra events + for i in range(15): + record( + RecordInput( + model_name="fraud_scoring", + event="retrained", + payload={"iteration": i}, + actor="pipeline", + ), + ledger, + ) + + result = investigate( + 
InvestigateInput(model_name="fraud_scoring", detail="summary"), + ledger, + ) + + assert len(result.recent_events) == 10 + + def test_full_detail_returns_all_events(self, ledger): + record( + RecordInput( + model_name="fraud_scoring", + event="registered", + owner="risk-team", + model_type="ml_model", + purpose="Fraud detection", + ), + ledger, + ) + for i in range(15): + record( + RecordInput( + model_name="fraud_scoring", + event="retrained", + payload={"iteration": i}, + actor="pipeline", + ), + ledger, + ) + + result = investigate( + InvestigateInput(model_name="fraud_scoring", detail="full"), + ledger, + ) + + # Should have all events, not capped at 10 + assert len(result.recent_events) > 10 + + +class TestNonexistentModel: + """Nonexistent model raises ModelNotFoundError.""" + + def test_raises_model_not_found(self, ledger): + with pytest.raises(ModelNotFoundError): + investigate(InvestigateInput(model_name="does_not_exist"), ledger) + + +class TestDependencies: + """Upstream and downstream dependencies from the graph.""" + + def test_shows_upstream_downstream(self, ledger): + ledger.add( + [ + DataNode("feature_pipeline", platform="etl", outputs=["scores"]), + DataNode( + "fraud_scoring", + platform="ml", + inputs=["scores"], + outputs=["alerts"], + ), + DataNode("alert_queue", platform="alerting", inputs=["alerts"]), + ] + ) + ledger.connect() + + result = investigate(InvestigateInput(model_name="fraud_scoring"), ledger) + + upstream_names = [d.name for d in result.upstream] + downstream_names = [d.name for d in result.downstream] + + assert "feature_pipeline" in upstream_names + assert "alert_queue" in downstream_names + + def test_no_graph_returns_empty_lists(self, ledger): + """Model with no graph connections returns empty dependency lists.""" + record( + RecordInput( + model_name="standalone_model", + event="registered", + owner="risk-team", + model_type="ml_model", + purpose="Standalone model", + ), + ledger, + ) + + result = 
investigate(InvestigateInput(model_name="standalone_model"), ledger) + + assert result.upstream == [] + assert result.downstream == [] diff --git a/tests/test_tools/test_query.py b/tests/test_tools/test_query.py new file mode 100644 index 0000000..330c904 --- /dev/null +++ b/tests/test_tools/test_query.py @@ -0,0 +1,255 @@ +# tests/test_tools/test_query.py +"""Tests for the query tool — search and filter model inventory.""" + +from __future__ import annotations + +import pytest + +from model_ledger.backends.ledger_memory import InMemoryLedgerBackend +from model_ledger.sdk.ledger import Ledger +from model_ledger.tools.query import _model_to_summary, query +from model_ledger.tools.record import record +from model_ledger.tools.schemas import ( + ModelSummary, + QueryInput, + QueryOutput, + RecordInput, +) + + +@pytest.fixture +def ledger(): + return Ledger(backend=InMemoryLedgerBackend()) + + +def _register(ledger, name, owner="risk-team", model_type="ml_model", purpose=""): + """Helper to register a model with minimal boilerplate.""" + record( + RecordInput( + model_name=name, + event="registered", + owner=owner, + model_type=model_type, + purpose=purpose, + ), + ledger, + ) + + +@pytest.fixture +def populated_ledger(ledger): + """Ledger with 4 diverse models for filtering tests.""" + _register( + ledger, + "fraud-detector", + owner="risk-team", + model_type="ml_model", + purpose="Detect fraud", + ) + _register( + ledger, + "credit-scorecard", + owner="risk-team", + model_type="ml_model", + purpose="Credit scoring", + ) + _register( + ledger, + "pricing-rules", + owner="finance-team", + model_type="heuristic", + purpose="Pricing logic", + ) + _register( + ledger, + "churn-predictor", + owner="growth-team", + model_type="ml_model", + purpose="Predict churn", + ) + return ledger + + +class TestQueryListAll: + """No filters — list all models.""" + + def test_returns_all_models(self, populated_ledger): + result = query(QueryInput(), populated_ledger) + + assert 
isinstance(result, QueryOutput) + assert result.total == 4 + assert len(result.models) == 4 + assert result.has_more is False + + def test_each_model_is_model_summary(self, populated_ledger): + result = query(QueryInput(), populated_ledger) + + for m in result.models: + assert isinstance(m, ModelSummary) + assert m.name + assert m.event_count > 0 + + +class TestQueryPagination: + """Pagination via limit and offset.""" + + def test_limit_returns_subset(self, populated_ledger): + result = query(QueryInput(limit=2), populated_ledger) + + assert result.total == 4 + assert len(result.models) == 2 + assert result.has_more is True + + def test_second_page(self, populated_ledger): + result = query(QueryInput(limit=2, offset=2), populated_ledger) + + assert result.total == 4 + assert len(result.models) == 2 + assert result.has_more is False + + def test_offset_beyond_total(self, populated_ledger): + result = query(QueryInput(limit=10, offset=100), populated_ledger) + + assert result.total == 4 + assert len(result.models) == 0 + assert result.has_more is False + + +class TestQueryFilterByOwner: + """Filter by owner field.""" + + def test_filter_owner(self, populated_ledger): + result = query(QueryInput(owner="risk-team"), populated_ledger) + + assert result.total == 2 + assert all(m.owner == "risk-team" for m in result.models) + + def test_filter_owner_no_match(self, populated_ledger): + result = query(QueryInput(owner="nonexistent-team"), populated_ledger) + + assert result.total == 0 + assert result.models == [] + + +class TestQueryFilterByModelType: + """Filter by model_type field.""" + + def test_filter_model_type(self, populated_ledger): + result = query(QueryInput(model_type="heuristic"), populated_ledger) + + assert result.total == 1 + assert result.models[0].name == "pricing-rules" + assert result.models[0].model_type == "heuristic" + + def test_filter_model_type_ml(self, populated_ledger): + result = query(QueryInput(model_type="ml_model"), populated_ledger) + + 
assert result.total == 3 + + +class TestQueryTextSearch: + """Fuzzy text search on name and purpose.""" + + def test_text_matches_name(self, populated_ledger): + result = query(QueryInput(text="fraud"), populated_ledger) + + assert result.total == 1 + assert result.models[0].name == "fraud-detector" + + def test_text_matches_purpose(self, populated_ledger): + result = query(QueryInput(text="scoring"), populated_ledger) + + assert result.total == 1 + assert result.models[0].name == "credit-scorecard" + + def test_text_case_insensitive(self, populated_ledger): + result = query(QueryInput(text="FRAUD"), populated_ledger) + + assert result.total == 1 + assert result.models[0].name == "fraud-detector" + + def test_text_no_match(self, populated_ledger): + result = query(QueryInput(text="nonexistent"), populated_ledger) + + assert result.total == 0 + assert result.models == [] + + +class TestQueryEmptyInventory: + """Empty inventory returns empty results.""" + + def test_empty_inventory(self, ledger): + result = query(QueryInput(), ledger) + + assert result.total == 0 + assert result.models == [] + assert result.has_more is False + + +class TestModelToSummary: + """_model_to_summary helper builds ModelSummary from ModelRef.""" + + def test_summary_fields(self, ledger): + _register( + ledger, + "fraud-detector", + owner="risk-team", + model_type="ml_model", + purpose="Detect fraud", + ) + model = ledger.get("fraud-detector") + summary = _model_to_summary(model, ledger) + + assert isinstance(summary, ModelSummary) + assert summary.name == "fraud-detector" + assert summary.owner == "risk-team" + assert summary.model_type == "ml_model" + assert summary.status == "active" + assert summary.event_count >= 1 + assert summary.last_event is not None + + def test_summary_event_count_increases(self, ledger): + _register( + ledger, + "scoring-model", + owner="data-team", + model_type="ml_model", + purpose="Score", + ) + # Add more events + record( + 
RecordInput(model_name="scoring-model", event="retrained", actor="pipeline"), + ledger, + ) + record( + RecordInput(model_name="scoring-model", event="deployed", actor="deployer"), + ledger, + ) + + model = ledger.get("scoring-model") + summary = _model_to_summary(model, ledger) + + # register creates 2 snapshots (register call + record call in record tool), + # plus retrained + deployed + assert summary.event_count >= 3 + + def test_summary_platform_from_source(self, ledger): + _register( + ledger, + "platform-model", + owner="data-team", + model_type="ml_model", + purpose="Test", + ) + model = ledger.get("platform-model") + # Record an event with a source + ledger.record( + model, + event="discovered", + payload={"platform": "mlflow"}, + actor="connector", + source="mlflow", + ) + + summary = _model_to_summary(model, ledger) + assert summary.platform == "mlflow" diff --git a/tests/test_tools/test_record.py b/tests/test_tools/test_record.py new file mode 100644 index 0000000..e5376b7 --- /dev/null +++ b/tests/test_tools/test_record.py @@ -0,0 +1,139 @@ +# tests/test_tools/test_record.py +"""Tests for the record tool — register models and record events.""" +from __future__ import annotations + +import pytest + +from model_ledger.backends.ledger_memory import InMemoryLedgerBackend +from model_ledger.core.exceptions import ModelNotFoundError +from model_ledger.sdk.ledger import Ledger +from model_ledger.tools.record import record +from model_ledger.tools.schemas import RecordInput, RecordOutput + + +@pytest.fixture +def ledger(): + return Ledger(backend=InMemoryLedgerBackend()) + + +class TestRecordRegisterNewModel: + """event='registered' should create a new model and log the registration.""" + + def test_register_returns_new_model(self, ledger): + inp = RecordInput( + model_name="credit-scorecard", + event="registered", + actor="alice", + owner="risk-team", + model_type="ml_model", + purpose="Credit risk scoring", + ) + result = record(inp, ledger) + + assert 
isinstance(result, RecordOutput) + assert result.is_new_model is True + assert result.model_name == "credit-scorecard" + assert result.event_id # non-empty string + assert result.timestamp is not None + + def test_registered_model_is_retrievable(self, ledger): + inp = RecordInput( + model_name="fraud-detector", + event="registered", + actor="bob", + owner="security-team", + model_type="ml_model", + purpose="Detect fraudulent transactions", + ) + record(inp, ledger) + + model = ledger.get("fraud-detector") + assert model.name == "fraud-detector" + assert model.owner == "security-team" + assert model.model_type == "ml_model" + assert model.purpose == "Detect fraudulent transactions" + + +class TestRecordEventOnExistingModel: + """Non-registration events on existing models.""" + + def test_record_event_existing_model(self, ledger): + # First register the model + register_inp = RecordInput( + model_name="credit-scorecard", + event="registered", + actor="alice", + owner="risk-team", + model_type="ml_model", + purpose="Credit risk scoring", + ) + record(register_inp, ledger) + + # Now record a new event on it + event_inp = RecordInput( + model_name="credit-scorecard", + event="validated", + payload={"result": "passed", "validator": "mr-framework"}, + actor="validator-bot", + ) + result = record(event_inp, ledger) + + assert isinstance(result, RecordOutput) + assert result.is_new_model is False + assert result.model_name == "credit-scorecard" + assert result.event_id # non-empty + assert result.timestamp is not None + + def test_record_with_arbitrary_payload(self, ledger): + # Register the model + record( + RecordInput( + model_name="scoring-model", + event="registered", + actor="alice", + owner="data-team", + model_type="ml_model", + purpose="Score applicants", + ), + ledger, + ) + + # Record event with rich payload (docs, metrics, links) + payload = { + "metrics": {"auc": 0.92, "precision": 0.87, "recall": 0.91}, + "docs": {"validation_report": 
"https://docs.example.com/report-42"}, + "links": ["https://mlflow.example.com/runs/abc123"], + "tags": ["quarterly-review", "q1-2026"], + } + result = record( + RecordInput( + model_name="scoring-model", + event="performance-review", + payload=payload, + actor="monitoring-agent", + ), + ledger, + ) + + assert result.is_new_model is False + assert result.event_id + + # Verify the payload was persisted via ledger history + history = ledger.history("scoring-model") + perf_events = [s for s in history if s.event_type == "performance-review"] + assert len(perf_events) == 1 + assert perf_events[0].payload["metrics"]["auc"] == 0.92 + + +class TestRecordNonexistentModel: + """Non-registration events on models that don't exist should raise.""" + + def test_raises_model_not_found(self, ledger): + inp = RecordInput( + model_name="does-not-exist", + event="deployed", + payload={"version": "1.0"}, + actor="deployer", + ) + with pytest.raises(ModelNotFoundError): + record(inp, ledger) diff --git a/tests/test_tools/test_schemas.py b/tests/test_tools/test_schemas.py new file mode 100644 index 0000000..cb6f26a --- /dev/null +++ b/tests/test_tools/test_schemas.py @@ -0,0 +1,650 @@ +# tests/test_tools/test_schemas.py +"""Tests for agent protocol I/O schemas.""" +from __future__ import annotations + +from datetime import datetime, timezone + +import pytest + +from model_ledger.tools.schemas import ( + ChangelogInput, + ChangelogOutput, + DependencyNode, + DiscoverInput, + DiscoverOutput, + EventDetail, + EventSummary, + InvestigateInput, + InvestigateOutput, + ModelSummary, + QueryInput, + QueryOutput, + RecordInput, + RecordOutput, + TraceInput, + TraceOutput, +) + +# --------------------------------------------------------------------------- +# Shared types +# --------------------------------------------------------------------------- + + +class TestModelSummary: + def test_all_fields(self): + ts = datetime(2026, 1, 15, 12, 0, 0, tzinfo=timezone.utc) + m = ModelSummary( + 
name="credit-scorecard", + owner="risk-team", + model_type="ml_model", + platform="sagemaker", + status="active", + last_event=ts, + event_count=42, + ) + assert m.name == "credit-scorecard" + assert m.owner == "risk-team" + assert m.model_type == "ml_model" + assert m.platform == "sagemaker" + assert m.status == "active" + assert m.last_event == ts + assert m.event_count == 42 + + def test_optional_fields_none(self): + m = ModelSummary(name="basic-model") + assert m.owner is None + assert m.model_type is None + assert m.platform is None + assert m.status is None + assert m.last_event is None + assert m.event_count == 0 + + def test_json_roundtrip(self): + ts = datetime(2026, 1, 15, 12, 0, 0, tzinfo=timezone.utc) + m = ModelSummary( + name="credit-scorecard", + owner="risk-team", + model_type="ml_model", + platform="sagemaker", + status="active", + last_event=ts, + event_count=42, + ) + data = m.model_dump(mode="json") + reconstructed = ModelSummary(**data) + assert reconstructed == m + + def test_json_schema_export(self): + schema = ModelSummary.model_json_schema() + assert schema["type"] == "object" + assert "name" in schema["properties"] + assert "event_count" in schema["properties"] + + +class TestEventSummary: + def test_all_fields(self): + ts = datetime(2026, 3, 1, 9, 0, 0, tzinfo=timezone.utc) + e = EventSummary( + event_type="validation", + timestamp=ts, + actor="alice", + summary="Passed all checks", + ) + assert e.event_type == "validation" + assert e.timestamp == ts + assert e.actor == "alice" + assert e.summary == "Passed all checks" + + def test_optional_fields_none(self): + e = EventSummary(event_type="registered") + assert e.timestamp is None + assert e.actor is None + assert e.summary is None + + def test_json_roundtrip(self): + ts = datetime(2026, 3, 1, 9, 0, 0, tzinfo=timezone.utc) + e = EventSummary(event_type="validation", timestamp=ts, actor="alice") + data = e.model_dump(mode="json") + assert EventSummary(**data) == e + + +class 
TestEventDetail: + def test_inherits_event_summary(self): + ts = datetime(2026, 3, 1, 9, 0, 0, tzinfo=timezone.utc) + d = EventDetail( + event_type="validation", + timestamp=ts, + actor="alice", + summary="OK", + model_name="credit-scorecard", + payload={"score": 0.95}, + ) + assert isinstance(d, EventSummary) + assert d.model_name == "credit-scorecard" + assert d.payload == {"score": 0.95} + + def test_defaults(self): + d = EventDetail(event_type="registered") + assert d.model_name is None + assert d.payload == {} + + def test_json_roundtrip(self): + d = EventDetail( + event_type="deployed", + model_name="fraud-detector", + payload={"version": "2.1"}, + ) + data = d.model_dump(mode="json") + assert EventDetail(**data) == d + + def test_json_schema_has_parent_fields(self): + schema = EventDetail.model_json_schema() + props = schema["properties"] + assert "event_type" in props + assert "model_name" in props + assert "payload" in props + + +class TestDependencyNode: + def test_all_fields(self): + n = DependencyNode( + name="feature-pipeline", + platform="airflow", + depth=2, + relationship="upstream", + ) + assert n.name == "feature-pipeline" + assert n.platform == "airflow" + assert n.depth == 2 + assert n.relationship == "upstream" + + def test_defaults(self): + n = DependencyNode(name="scoring-model") + assert n.platform is None + assert n.depth == 0 + assert n.relationship is None + + def test_json_roundtrip(self): + n = DependencyNode(name="etl-job", platform="spark", depth=1, relationship="downstream") + data = n.model_dump(mode="json") + assert DependencyNode(**data) == n + + +# --------------------------------------------------------------------------- +# RecordInput / RecordOutput +# --------------------------------------------------------------------------- + + +class TestRecordInput: + def test_all_fields(self): + r = RecordInput( + model_name="credit-scorecard", + event="deployed", + payload={"version": "3.0"}, + actor="deployer", + owner="risk-team", + 
model_type="ml_model",
+            purpose="Credit risk scoring",
+        )
+        assert r.model_name == "credit-scorecard"
+        assert r.event == "deployed"
+        assert r.payload == {"version": "3.0"}
+        assert r.actor == "deployer"
+        assert r.owner == "risk-team"
+        assert r.model_type == "ml_model"
+        assert r.purpose == "Credit risk scoring"
+
+    def test_defaults(self):
+        r = RecordInput(model_name="test-model", event="registered")
+        assert r.payload == {}
+        assert r.actor == "user"
+        assert r.owner is None
+        assert r.model_type is None
+        assert r.purpose is None
+
+    def test_json_roundtrip(self):
+        r = RecordInput(model_name="test-model", event="registered")
+        data = r.model_dump(mode="json")
+        assert RecordInput(**data) == r
+
+    def test_json_schema_export(self):
+        schema = RecordInput.model_json_schema()
+        assert "model_name" in schema["properties"]
+        assert "event" in schema["properties"]
+        required = schema.get("required", [])
+        assert "model_name" in required
+        assert "event" in required
+
+
+class TestRecordOutput:
+    def test_all_fields(self):
+        ts = datetime(2026, 4, 1, 10, 0, 0, tzinfo=timezone.utc)
+        r = RecordOutput(
+            model_name="credit-scorecard",
+            event_id="abc123",
+            timestamp=ts,
+            is_new_model=True,
+        )
+        assert r.model_name == "credit-scorecard"
+        assert r.event_id == "abc123"
+        assert r.timestamp == ts
+        assert r.is_new_model is True
+
+    def test_json_roundtrip(self):
+        ts = datetime(2026, 4, 1, 10, 0, 0, tzinfo=timezone.utc)
+        r = RecordOutput(
+            model_name="test", event_id="x", timestamp=ts, is_new_model=False
+        )
+        data = r.model_dump(mode="json")
+        assert RecordOutput(**data) == r
+
+
+# ---------------------------------------------------------------------------
+# QueryInput / QueryOutput
+# ---------------------------------------------------------------------------
+
+
+class TestQueryInput:
+    def test_defaults(self):
+        q = QueryInput()
+        assert q.text is None
+        assert q.platform is None
+        assert q.model_type is None
+        assert q.owner is None
+        assert q.status is None
+        assert q.limit == 50
+        assert q.offset == 0
+
+    def test_all_fields(self):
+        q = QueryInput(
+            text="fraud",
+            platform="sagemaker",
+            model_type="ml_model",
+            owner="risk-team",
+            status="active",
+            limit=10,
+            offset=5,
+        )
+        assert q.text == "fraud"
+        assert q.limit == 10
+        assert q.offset == 5
+
+    def test_json_roundtrip(self):
+        q = QueryInput(text="scoring", limit=25)
+        data = q.model_dump(mode="json")
+        assert QueryInput(**data) == q
+
+
+class TestQueryOutput:
+    def test_all_fields(self):
+        m = ModelSummary(name="test-model", owner="alice")
+        q = QueryOutput(total=1, models=[m], has_more=False)
+        assert q.total == 1
+        assert len(q.models) == 1
+        assert q.has_more is False
+
+    def test_empty_results(self):
+        q = QueryOutput(total=0, models=[], has_more=False)
+        assert q.models == []
+
+    def test_json_roundtrip(self):
+        m = ModelSummary(name="test-model", event_count=5)
+        q = QueryOutput(total=100, models=[m], has_more=True)
+        data = q.model_dump(mode="json")
+        reconstructed = QueryOutput(**data)
+        assert reconstructed.total == 100
+        assert reconstructed.models[0].name == "test-model"
+        assert reconstructed.has_more is True
+
+
+# ---------------------------------------------------------------------------
+# InvestigateInput / InvestigateOutput
+# ---------------------------------------------------------------------------
+
+
+class TestInvestigateInput:
+    def test_defaults(self):
+        i = InvestigateInput(model_name="credit-scorecard")
+        assert i.model_name == "credit-scorecard"
+        assert i.detail == "summary"
+        assert i.as_of is None
+
+    def test_all_fields(self):
+        ts = datetime(2026, 1, 1, 0, 0, 0, tzinfo=timezone.utc)
+        i = InvestigateInput(model_name="test", detail="full", as_of=ts)
+        assert i.detail == "full"
+        assert i.as_of == ts
+
+    def test_json_roundtrip(self):
+        i = InvestigateInput(model_name="test", detail="full")
+        data = i.model_dump(mode="json")
+        assert InvestigateInput(**data) == i
+
+
+class TestInvestigateOutput:
+    def test_all_fields(self):
+        ts = datetime(2026, 1, 15, 12, 0, 0, tzinfo=timezone.utc)
+        event = EventSummary(event_type="deployed", timestamp=ts)
+        dep = DependencyNode(name="feature-pipeline", depth=1)
+        o = InvestigateOutput(
+            name="credit-scorecard",
+            owner="risk-team",
+            model_type="ml_model",
+            purpose="Credit risk scoring",
+            status="active",
+            created_at=ts,
+            metadata={"version": "3.0"},
+            recent_events=[event],
+            days_since_last_event=5,
+            total_events=42,
+            upstream=[dep],
+            downstream=[],
+            groups=["risk-models"],
+            members=[],
+        )
+        assert o.name == "credit-scorecard"
+        assert o.owner == "risk-team"
+        assert o.total_events == 42
+        assert len(o.upstream) == 1
+        assert o.groups == ["risk-models"]
+
+    def test_defaults(self):
+        o = InvestigateOutput(name="basic-model")
+        assert o.owner is None
+        assert o.model_type is None
+        assert o.purpose is None
+        assert o.status is None
+        assert o.created_at is None
+        assert o.metadata == {}
+        assert o.recent_events == []
+        assert o.days_since_last_event is None
+        assert o.total_events == 0
+        assert o.upstream == []
+        assert o.downstream == []
+        assert o.groups == []
+        assert o.members == []
+
+    def test_json_roundtrip(self):
+        ts = datetime(2026, 2, 1, 0, 0, 0, tzinfo=timezone.utc)
+        o = InvestigateOutput(
+            name="test",
+            owner="alice",
+            created_at=ts,
+            recent_events=[EventSummary(event_type="registered")],
+            upstream=[DependencyNode(name="dep-a")],
+        )
+        data = o.model_dump(mode="json")
+        reconstructed = InvestigateOutput(**data)
+        assert reconstructed.name == "test"
+        assert len(reconstructed.recent_events) == 1
+        assert len(reconstructed.upstream) == 1
+
+    def test_json_schema_export(self):
+        schema = InvestigateOutput.model_json_schema()
+        props = schema["properties"]
+        assert "name" in props
+        assert "recent_events" in props
+        assert "upstream" in props
+        assert "downstream" in props
+
+
+# ---------------------------------------------------------------------------
+# TraceInput / TraceOutput
+# ---------------------------------------------------------------------------
+
+
+class TestTraceInput:
+    def test_defaults(self):
+        t = TraceInput(name="credit-scorecard")
+        assert t.name == "credit-scorecard"
+        assert t.direction == "both"
+        assert t.depth is None
+
+    def test_all_fields(self):
+        t = TraceInput(name="test", direction="upstream", depth=3)
+        assert t.direction == "upstream"
+        assert t.depth == 3
+
+    def test_json_roundtrip(self):
+        t = TraceInput(name="test", direction="downstream", depth=2)
+        data = t.model_dump(mode="json")
+        assert TraceInput(**data) == t
+
+
+class TestTraceOutput:
+    def test_all_fields(self):
+        up = DependencyNode(name="data-source", depth=1, relationship="upstream")
+        down = DependencyNode(name="dashboard", depth=1, relationship="downstream")
+        t = TraceOutput(
+            root="credit-scorecard",
+            upstream=[up],
+            downstream=[down],
+            total_nodes=2,
+        )
+        assert t.root == "credit-scorecard"
+        assert len(t.upstream) == 1
+        assert len(t.downstream) == 1
+        assert t.total_nodes == 2
+
+    def test_defaults(self):
+        t = TraceOutput(root="test")
+        assert t.upstream == []
+        assert t.downstream == []
+        assert t.total_nodes == 0
+
+    def test_json_roundtrip(self):
+        t = TraceOutput(
+            root="test",
+            upstream=[DependencyNode(name="a", depth=1)],
+            downstream=[DependencyNode(name="b", depth=1)],
+            total_nodes=2,
+        )
+        data = t.model_dump(mode="json")
+        reconstructed = TraceOutput(**data)
+        assert reconstructed.root == "test"
+        assert len(reconstructed.upstream) == 1
+        assert len(reconstructed.downstream) == 1
+
+
+# ---------------------------------------------------------------------------
+# ChangelogInput / ChangelogOutput
+# ---------------------------------------------------------------------------
+
+
+class TestChangelogInput:
+    def test_defaults(self):
+        c = ChangelogInput()
+        assert c.since is None
+        assert c.until is None
+        assert c.model_name is None
+        assert c.event_type is None
+        assert c.limit == 100
+        assert c.offset == 0
+
+    def test_all_fields(self):
+        ts1 = datetime(2026, 1, 1, 0, 0, 0, tzinfo=timezone.utc)
+        ts2 = datetime(2026, 3, 1, 0, 0, 0, tzinfo=timezone.utc)
+        c = ChangelogInput(
+            since=ts1,
+            until=ts2,
+            model_name="credit-scorecard",
+            event_type="deployed",
+            limit=25,
+            offset=10,
+        )
+        assert c.since == ts1
+        assert c.until == ts2
+        assert c.model_name == "credit-scorecard"
+        assert c.event_type == "deployed"
+        assert c.limit == 25
+        assert c.offset == 10
+
+    def test_json_roundtrip(self):
+        c = ChangelogInput(model_name="test", limit=50)
+        data = c.model_dump(mode="json")
+        assert ChangelogInput(**data) == c
+
+
+class TestChangelogOutput:
+    def test_all_fields(self):
+        event = EventDetail(
+            event_type="deployed",
+            model_name="credit-scorecard",
+            payload={"version": "2.0"},
+        )
+        c = ChangelogOutput(
+            total=1,
+            events=[event],
+            has_more=False,
+            period="2026-01-01 to 2026-03-01",
+        )
+        assert c.total == 1
+        assert len(c.events) == 1
+        assert c.has_more is False
+        assert c.period == "2026-01-01 to 2026-03-01"
+
+    def test_defaults(self):
+        c = ChangelogOutput(total=0, events=[], has_more=False)
+        assert c.period is None
+
+    def test_json_roundtrip(self):
+        event = EventDetail(event_type="registered", model_name="test")
+        c = ChangelogOutput(total=1, events=[event], has_more=False, period="all time")
+        data = c.model_dump(mode="json")
+        reconstructed = ChangelogOutput(**data)
+        assert reconstructed.total == 1
+        assert reconstructed.events[0].model_name == "test"
+
+
+# ---------------------------------------------------------------------------
+# DiscoverInput / DiscoverOutput
+# ---------------------------------------------------------------------------
+
+
+class TestDiscoverInput:
+    def test_connector_source(self):
+        d = DiscoverInput(
+            source_type="connector",
+            connector_name="sql-registry",
+            connector_config={"query": "SELECT * FROM models"},
+        )
+        assert d.source_type == "connector"
+        assert d.connector_name == "sql-registry"
+        assert d.connector_config == {"query": "SELECT * FROM models"}
+        assert d.auto_connect is True
+
+    def test_file_source(self):
+        d = DiscoverInput(source_type="file", file_path="/data/models.csv")
+        assert d.source_type == "file"
+        assert d.file_path == "/data/models.csv"
+
+    def test_inline_source(self):
+        models = [{"name": "test-model", "owner": "alice"}]
+        d = DiscoverInput(source_type="inline", models=models)
+        assert d.source_type == "inline"
+        assert d.models == models
+
+    def test_defaults(self):
+        d = DiscoverInput(source_type="connector")
+        assert d.connector_name is None
+        assert d.connector_config is None
+        assert d.file_path is None
+        assert d.models is None
+        assert d.auto_connect is True
+
+    def test_auto_connect_false(self):
+        d = DiscoverInput(source_type="inline", auto_connect=False)
+        assert d.auto_connect is False
+
+    def test_json_roundtrip(self):
+        d = DiscoverInput(
+            source_type="connector",
+            connector_name="github",
+            connector_config={"org": "my-org"},
+            auto_connect=False,
+        )
+        data = d.model_dump(mode="json")
+        assert DiscoverInput(**data) == d
+
+    def test_json_schema_source_type_literal(self):
+        schema = DiscoverInput.model_json_schema()
+        source_prop = schema["properties"]["source_type"]
+        # Should be constrained to the three valid values
+        assert "enum" in source_prop or "anyOf" in source_prop or "const" in source_prop
+
+
+class TestDiscoverOutput:
+    def test_all_fields(self):
+        m = ModelSummary(name="discovered-model", platform="airflow")
+        d = DiscoverOutput(
+            models_added=3,
+            models_skipped=1,
+            links_created=5,
+            models=[m],
+            errors=["Failed to parse row 7"],
+        )
+        assert d.models_added == 3
+        assert d.models_skipped == 1
+        assert d.links_created == 5
+        assert len(d.models) == 1
+        assert d.errors == ["Failed to parse row 7"]
+
+    def test_defaults(self):
+        d = DiscoverOutput(models_added=0, models_skipped=0, links_created=0)
+        assert d.models == []
+        assert d.errors == []
+
+    def test_json_roundtrip(self):
+        m = ModelSummary(name="test", event_count=0)
+        d = DiscoverOutput(
+            models_added=1,
+            models_skipped=0,
+            links_created=0,
+            models=[m],
+            errors=[],
+        )
+        data = d.model_dump(mode="json")
+        reconstructed = DiscoverOutput(**data)
+        assert reconstructed.models_added == 1
+        assert reconstructed.models[0].name == "test"
+
+    def test_json_schema_export(self):
+        schema = DiscoverOutput.model_json_schema()
+        props = schema["properties"]
+        assert "models_added" in props
+        assert "models" in props
+        assert "errors" in props
+
+
+# ---------------------------------------------------------------------------
+# Cross-cutting: all schemas produce valid JSON Schema
+# ---------------------------------------------------------------------------
+
+
+ALL_SCHEMAS = [
+    ModelSummary,
+    EventSummary,
+    EventDetail,
+    DependencyNode,
+    RecordInput,
+    RecordOutput,
+    QueryInput,
+    QueryOutput,
+    InvestigateInput,
+    InvestigateOutput,
+    TraceInput,
+    TraceOutput,
+    ChangelogInput,
+    ChangelogOutput,
+    DiscoverInput,
+    DiscoverOutput,
+]
+
+
+@pytest.mark.parametrize("schema_cls", ALL_SCHEMAS, ids=lambda c: c.__name__)
+def test_all_schemas_export_json_schema(schema_cls):
+    """Every schema must produce a valid JSON Schema dict with type=object."""
+    schema = schema_cls.model_json_schema()
+    assert isinstance(schema, dict)
+    assert schema.get("type") == "object"
+    assert "properties" in schema
diff --git a/tests/test_tools/test_trace.py b/tests/test_tools/test_trace.py
new file mode 100644
index 0000000..e021ac1
--- /dev/null
+++ b/tests/test_tools/test_trace.py
@@ -0,0 +1,214 @@
+"""Tests for the trace tool — dependency graph traversal."""
+
+from __future__ import annotations
+
+import pytest
+
+from model_ledger.backends.ledger_memory import InMemoryLedgerBackend
+from model_ledger.core.exceptions import ModelNotFoundError
+from model_ledger.graph.models import DataNode
+from model_ledger.sdk.ledger import Ledger
+from model_ledger.tools.schemas import DependencyNode, TraceInput, TraceOutput
+from model_ledger.tools.trace import trace
+
+
+@pytest.fixture
+def ledger():
+    return Ledger(backend=InMemoryLedgerBackend())
+
+
+@pytest.fixture
+def graph_ledger(ledger):
+    """Ledger with a 4-node linear pipeline for graph traversal tests.
+
+    raw_data -> feature_pipeline -> scoring_model -> alert_engine
+    """
+    ledger.add(
+        [
+            DataNode("raw_data", platform="database", outputs=["customers"]),
+            DataNode(
+                "feature_pipeline",
+                platform="etl",
+                inputs=["customers"],
+                outputs=["features"],
+            ),
+            DataNode(
+                "scoring_model",
+                platform="ml",
+                inputs=["features"],
+                outputs=["scores"],
+            ),
+            DataNode("alert_engine", platform="alerting", inputs=["scores"]),
+        ]
+    )
+    ledger.connect()
+    return ledger
+
+
+class TestTraceBothDirections:
+    """Trace both upstream and downstream from a middle node."""
+
+    def test_returns_trace_output(self, graph_ledger):
+        result = trace(TraceInput(name="scoring_model"), graph_ledger)
+
+        assert isinstance(result, TraceOutput)
+
+    def test_root_is_target_model(self, graph_ledger):
+        result = trace(TraceInput(name="scoring_model"), graph_ledger)
+
+        assert result.root == "scoring_model"
+
+    def test_upstream_contains_dependencies(self, graph_ledger):
+        result = trace(TraceInput(name="scoring_model"), graph_ledger)
+
+        upstream_names = [n.name for n in result.upstream]
+        assert "feature_pipeline" in upstream_names
+        assert "raw_data" in upstream_names
+
+    def test_downstream_contains_dependents(self, graph_ledger):
+        result = trace(TraceInput(name="scoring_model"), graph_ledger)
+
+        downstream_names = [n.name for n in result.downstream]
+        assert "alert_engine" in downstream_names
+
+    def test_nodes_are_dependency_nodes(self, graph_ledger):
+        result = trace(TraceInput(name="scoring_model"), graph_ledger)
+
+        for node in result.upstream + result.downstream:
+            assert isinstance(node, DependencyNode)
+            assert node.depth >= 1
+            assert node.relationship in ("depends_on", "feeds_into")
+
+    def test_total_nodes_counts_all(self, graph_ledger):
+        result = trace(TraceInput(name="scoring_model"), graph_ledger)
+
+        assert result.total_nodes == len(result.upstream) + len(result.downstream)
+
+    def test_upstream_relationship_is_depends_on(self, graph_ledger):
+        result = trace(TraceInput(name="scoring_model"), graph_ledger)
+
+        for node in result.upstream:
+            assert node.relationship == "depends_on"
+
+    def test_downstream_relationship_is_feeds_into(self, graph_ledger):
+        result = trace(TraceInput(name="scoring_model"), graph_ledger)
+
+        for node in result.downstream:
+            assert node.relationship == "feeds_into"
+
+
+class TestTraceUpstreamOnly:
+    """Trace upstream only — downstream list should be empty."""
+
+    def test_upstream_populated(self, graph_ledger):
+        result = trace(
+            TraceInput(name="scoring_model", direction="upstream"),
+            graph_ledger,
+        )
+
+        upstream_names = [n.name for n in result.upstream]
+        assert "feature_pipeline" in upstream_names
+        assert "raw_data" in upstream_names
+
+    def test_downstream_empty(self, graph_ledger):
+        result = trace(
+            TraceInput(name="scoring_model", direction="upstream"),
+            graph_ledger,
+        )
+
+        assert result.downstream == []
+
+    def test_total_nodes_excludes_downstream(self, graph_ledger):
+        result = trace(
+            TraceInput(name="scoring_model", direction="upstream"),
+            graph_ledger,
+        )
+
+        assert result.total_nodes == len(result.upstream)
+
+
+class TestTraceDownstreamOnly:
+    """Trace downstream only — upstream list should be empty."""
+
+    def test_downstream_populated(self, graph_ledger):
+        result = trace(
+            TraceInput(name="scoring_model", direction="downstream"),
+            graph_ledger,
+        )
+
+        downstream_names = [n.name for n in result.downstream]
+        assert "alert_engine" in downstream_names
+
+    def test_upstream_empty(self, graph_ledger):
+        result = trace(
+            TraceInput(name="scoring_model", direction="downstream"),
+            graph_ledger,
+        )
+
+        assert result.upstream == []
+
+    def test_total_nodes_excludes_upstream(self, graph_ledger):
+        result = trace(
+            TraceInput(name="scoring_model", direction="downstream"),
+            graph_ledger,
+        )
+
+        assert result.total_nodes == len(result.downstream)
+
+
+class TestTraceLeafNode:
+    """Trace a leaf node — no downstream dependents."""
+
+    def test_leaf_has_no_downstream(self, graph_ledger):
+        result = trace(TraceInput(name="alert_engine"), graph_ledger)
+
+        assert result.downstream == []
+
+    def test_leaf_has_upstream(self, graph_ledger):
+        result = trace(TraceInput(name="alert_engine"), graph_ledger)
+
+        upstream_names = [n.name for n in result.upstream]
+        assert len(upstream_names) >= 1
+
+    def test_root_node_has_no_upstream(self, graph_ledger):
+        result = trace(TraceInput(name="raw_data"), graph_ledger)
+
+        assert result.upstream == []
+
+    def test_root_node_has_downstream(self, graph_ledger):
+        result = trace(TraceInput(name="raw_data"), graph_ledger)
+
+        downstream_names = [n.name for n in result.downstream]
+        assert len(downstream_names) >= 1
+
+
+class TestTraceDepthFilter:
+    """Depth filter limits how far the trace goes."""
+
+    def test_depth_1_limits_results(self, graph_ledger):
+        result = trace(
+            TraceInput(name="scoring_model", depth=1),
+            graph_ledger,
+        )
+
+        for node in result.upstream + result.downstream:
+            assert node.depth <= 1
+
+    def test_depth_1_excludes_transitive(self, graph_ledger):
+        result = trace(
+            TraceInput(name="scoring_model", depth=1),
+            graph_ledger,
+        )
+
+        upstream_names = [n.name for n in result.upstream]
+        # feature_pipeline is depth 1, raw_data is depth 2
+        assert "feature_pipeline" in upstream_names
+        assert "raw_data" not in upstream_names
+
+
+class TestTraceNonexistentModel:
+    """Tracing a model that doesn't exist should raise."""
+
+    def test_raises_model_not_found(self, ledger):
+        with pytest.raises(ModelNotFoundError):
+            trace(TraceInput(name="nonexistent_model"), ledger)