Skip to content

fix: allow for self-referencing pydantic schema #156

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 14 additions & 4 deletions fastapi_mcp/openapi/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Dict
from typing import Any, Dict, Optional, Set


def get_single_param_type_from_schema(param_schema: Dict[str, Any]) -> str:
Expand All @@ -16,7 +16,11 @@ def get_single_param_type_from_schema(param_schema: Dict[str, Any]) -> str:
return param_schema.get("type", "string")


def resolve_schema_references(schema_part: Dict[str, Any], reference_schema: Dict[str, Any]) -> Dict[str, Any]:
def resolve_schema_references(
schema_part: Dict[str, Any],
reference_schema: Dict[str, Any],
seen: Optional[Set[str]] = None,
) -> Dict[str, Any]:
"""
Resolve schema references in OpenAPI schemas.

Expand All @@ -27,6 +31,8 @@ def resolve_schema_references(schema_part: Dict[str, Any], reference_schema: Dic
Returns:
The schema with references resolved
"""
seen = seen or set()

# Make a copy to avoid modifying the input schema
schema_part = schema_part.copy()

Expand All @@ -35,6 +41,9 @@ def resolve_schema_references(schema_part: Dict[str, Any], reference_schema: Dic
ref_path = schema_part["$ref"]
# Standard OpenAPI references are in the format "#/components/schemas/ModelName"
if ref_path.startswith("#/components/schemas/"):
if ref_path in seen:
return {"$ref": ref_path}
seen.add(ref_path)
model_name = ref_path.split("/")[-1]
if "components" in reference_schema and "schemas" in reference_schema["components"]:
if model_name in reference_schema["components"]["schemas"]:
Expand All @@ -47,11 +56,12 @@ def resolve_schema_references(schema_part: Dict[str, Any], reference_schema: Dic
# Recursively resolve references in all dictionary values
for key, value in schema_part.items():
if isinstance(value, dict):
schema_part[key] = resolve_schema_references(value, reference_schema)
schema_part[key] = resolve_schema_references(value, reference_schema, seen)
elif isinstance(value, list):
# Only process list items that are dictionaries since only they can contain refs
schema_part[key] = [
resolve_schema_references(item, reference_schema) if isinstance(item, dict) else item for item in value
resolve_schema_references(item, reference_schema, seen) if isinstance(item, dict) else item
for item in value
]

return schema_part
Expand Down
2 changes: 2 additions & 0 deletions tests/fixtures/types.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from __future__ import annotations
from typing import Optional, List, Dict, Any
from datetime import datetime, date
from enum import Enum
Expand Down Expand Up @@ -95,6 +96,7 @@ class Product(BaseModel):
updated_at: Optional[datetime] = None
is_available: bool = True
metadata: Dict[str, Any] = {}
related_products: Optional[List[Product]] = None


class OrderItem(BaseModel):
Expand Down