1 change: 1 addition & 0 deletions .gitignore
@@ -8,3 +8,4 @@ wheels/

# Virtual environments
.venv
.venv-test/
64 changes: 60 additions & 4 deletions mcp_client_for_ollama/client.py
@@ -21,6 +21,7 @@
from .models.manager import ModelManager
from .models.config_manager import ModelConfigManager
from .tools.manager import ToolManager
from .tools.rag import ToolRAG
from .utils.streaming import StreamingManager
from .utils.tool_display import ToolDisplayManager
from .utils.hil_manager import HumanInTheLoopManager
@@ -30,7 +31,8 @@
class MCPClient:
"""Main client class for interacting with Ollama and MCP servers"""

- def __init__(self, model: str = DEFAULT_MODEL, host: str = DEFAULT_OLLAMA_HOST):
+ def __init__(self, model: str = DEFAULT_MODEL, host: str = DEFAULT_OLLAMA_HOST, enable_tool_rag: bool = False,
+              tool_rag_threshold: float = 0.65, tool_rag_min_tools: int = 0, tool_rag_max_tools: int = 20):
# Initialize session and client objects
self.exit_stack = AsyncExitStack()
self.ollama = ollama.AsyncClient(host=host)
@@ -48,6 +50,12 @@ def __init__(self, model: str = DEFAULT_MODEL, host: str = DEFAULT_OLLAMA_HOST):
self.streaming_manager = StreamingManager(console=self.console)
# Initialize the tool display manager
self.tool_display_manager = ToolDisplayManager(console=self.console)
# Initialize Tool RAG if enabled
self.enable_tool_rag = enable_tool_rag
self.tool_rag_threshold = tool_rag_threshold
self.tool_rag_min_tools = tool_rag_min_tools
self.tool_rag_max_tools = tool_rag_max_tools
self.tool_rag: Optional[ToolRAG] = ToolRAG() if enable_tool_rag else None
# Initialize the HIL manager
self.hil_manager = HumanInTheLoopManager(console=self.console)
# Store server and tool data
@@ -152,6 +160,12 @@ async def connect_to_servers(self, server_paths=None, server_urls=None, config_p
# Set up the tool manager with the available tools and their enabled status
self.tool_manager.set_available_tools(available_tools)
self.tool_manager.set_enabled_tools(enabled_tools)

# Embed tools for RAG if enabled
if self.enable_tool_rag and self.tool_rag and available_tools:
self.console.print("[dim]Embedding tools for intelligent filtering...[/dim]")
self.tool_rag.embed_tools(available_tools)
self.console.print(f"[green]✓ Tool RAG enabled with {len(available_tools)} tools[/green]")

def select_tools(self):
"""Let the user select which tools to enable using interactive prompts with server-based grouping"""
@@ -232,6 +246,22 @@ async def process_query(self, query: str) -> str:

# Get enabled tools from the tool manager
enabled_tool_objects = self.tool_manager.get_enabled_tool_objects()

# Apply Tool RAG filtering if enabled
if self.enable_tool_rag and self.tool_rag and enabled_tool_objects:
try:
enabled_tool_objects = self.tool_rag.retrieve_relevant_tools(
query,
threshold=self.tool_rag_threshold,
min_tools=self.tool_rag_min_tools,
max_tools=self.tool_rag_max_tools
)
# Intersect with the user's enabled set: retrieval ranks every
# embedded tool, so RAG may narrow the selection but must never
# re-enable a tool the user disabled
enabled_tools_set = set(t.name for t in self.tool_manager.get_enabled_tool_objects())
enabled_tool_objects = [t for t in enabled_tool_objects if t.name in enabled_tools_set]
except Exception as e:
self.console.print(f"[yellow]Warning: Tool RAG filtering failed ({e}), using all enabled tools[/yellow]")
enabled_tool_objects = self.tool_manager.get_enabled_tool_objects()

if not enabled_tool_objects:
self.console.print("[yellow]Warning: No tools are enabled. Model will respond without tool access.[/yellow]")
@@ -1026,6 +1056,28 @@ def main(
help="Ollama host URL",
rich_help_panel="Ollama Configuration"
),

# Tool RAG Configuration
enable_tool_rag: bool = typer.Option(
False, "--enable-tool-rag",
help="Enable intelligent tool filtering using semantic search (recommended for 50+ tools)",
rich_help_panel="Tool RAG Configuration"
),
tool_rag_threshold: float = typer.Option(
0.65, "--tool-rag-threshold",
help="Minimum similarity score (0-1) for a tool to be considered relevant",
rich_help_panel="Tool RAG Configuration"
),
tool_rag_min_tools: int = typer.Option(
0, "--tool-rag-min-tools",
help="Minimum number of tools to send (fallback if none meet threshold)",
rich_help_panel="Tool RAG Configuration"
),
tool_rag_max_tools: int = typer.Option(
20, "--tool-rag-max-tools",
help="Maximum number of tools to send (performance cap)",
rich_help_panel="Tool RAG Configuration"
),

# General Options
version: Optional[bool] = typer.Option(
@@ -1044,15 +1096,19 @@
auto_discovery = True

# Run the async main function
- asyncio.run(async_main(mcp_server, mcp_server_url, servers_json, auto_discovery, model, host))
+ asyncio.run(async_main(mcp_server, mcp_server_url, servers_json, auto_discovery, model, host,
+                        enable_tool_rag, tool_rag_threshold, tool_rag_min_tools, tool_rag_max_tools))

- async def async_main(mcp_server, mcp_server_url, servers_json, auto_discovery, model, host):
+ async def async_main(mcp_server, mcp_server_url, servers_json, auto_discovery, model, host,
+                      enable_tool_rag, tool_rag_threshold, tool_rag_min_tools, tool_rag_max_tools):
"""Asynchronous main function to run the MCP Client for Ollama"""

console = Console()

# Create a temporary client to check if Ollama is running
- client = MCPClient(model=model, host=host)
+ client = MCPClient(model=model, host=host, enable_tool_rag=enable_tool_rag,
+                    tool_rag_threshold=tool_rag_threshold, tool_rag_min_tools=tool_rag_min_tools,
+                    tool_rag_max_tools=tool_rag_max_tools)
if not await client.model_manager.check_ollama_running():
console.print(Panel(
"[bold red]Error: Ollama is not running![/bold red]\n\n"
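Taken together, the constructor and CLI changes are additive and default-off. Below is a minimal sketch of enabling Tool RAG when constructing the client directly; only the keyword arguments come from this diff, while the model name and host value are illustrative placeholders:

# Sketch: enabling Tool RAG programmatically. The keyword arguments mirror
# the new __init__ signature above; model/host are placeholder values.
from mcp_client_for_ollama.client import MCPClient

client = MCPClient(
    model="llama3.1",                 # placeholder model name
    host="http://localhost:11434",    # standard Ollama host
    enable_tool_rag=True,             # embed tools at connect time
    tool_rag_threshold=0.65,          # cosine-similarity cutoff per query
    tool_rag_min_tools=3,             # top-3 fallback if nothing clears the cutoff
    tool_rag_max_tools=20,            # hard cap on tools sent to the model
)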
205 changes: 205 additions & 0 deletions mcp_client_for_ollama/tools/rag.py
@@ -0,0 +1,205 @@
"""Tool RAG (Retrieval-Augmented Generation) for intelligent tool filtering.

This module implements semantic search over tool schemas to efficiently select
relevant tools for a given query, enabling scalable tool management with large
tool sets.
"""

import hashlib
import json
import pickle
from pathlib import Path
from typing import List, Optional, Tuple

import torch
from mcp import Tool
from sentence_transformers import SentenceTransformer, util


class ToolRAG:
"""Manages semantic search over tool schemas for intelligent tool filtering.

Uses sentence transformers to embed tool descriptions and perform vector
search to find the most relevant tools for a given query.
"""

def __init__(
self,
model_name: str = "all-MiniLM-L6-v2",
cache_dir: Optional[Path] = None,
):
"""Initialize the ToolRAG system.

Args:
model_name: Name of the sentence-transformers model to use.
Default is 'all-MiniLM-L6-v2' (80MB, fast, good quality).
cache_dir: Directory to cache embeddings. If None, uses
~/.cache/ollmcp/tool_embeddings/
"""
self.model_name = model_name
self.model: Optional[SentenceTransformer] = None
self.cache_dir = cache_dir or Path.home() / ".cache" / "ollmcp" / "tool_embeddings"
self.cache_dir.mkdir(parents=True, exist_ok=True)

# Storage for tools and their embeddings
self.tools: List[Tool] = []
self.tool_texts: List[str] = []
self.embeddings: Optional[torch.Tensor] = None
self._cache_hash: Optional[str] = None

def _load_model(self) -> None:
"""Lazy load the sentence transformer model."""
if self.model is None:
self.model = SentenceTransformer(self.model_name)

def _create_tool_text(self, tool: Tool) -> str:
"""Create searchable text representation of a tool.

Args:
tool: Tool object to convert to text

Returns:
String representation combining name, description, and key parameters
"""
parts = [
f"Tool: {tool.name}",
f"Description: {tool.description or 'No description'}",
]

# Add parameter information if available
if hasattr(tool, 'inputSchema') and tool.inputSchema:
schema = tool.inputSchema
if isinstance(schema, dict) and 'properties' in schema:
param_names = list(schema['properties'].keys())
if param_names:
parts.append(f"Parameters: {', '.join(param_names)}")

return " | ".join(parts)

def _compute_cache_hash(self, tools: List[Tool]) -> str:
"""Compute a hash of the tool set for cache validation.

Args:
tools: List of tools to hash

Returns:
Hash string representing the tool set
"""
# Create a stable representation of tools
tool_repr = json.dumps(
[(t.name, t.description) for t in tools],
sort_keys=True
)
# Use a deterministic digest; Python's built-in hash() is salted
# per process, which would silently defeat the on-disk cache
return hashlib.sha256(tool_repr.encode("utf-8")).hexdigest()[:16]

def _get_cache_path(self, cache_hash: str) -> Path:
"""Get the cache file path for a given hash.

Args:
cache_hash: Hash of the tool set

Returns:
Path to the cache file
"""
return self.cache_dir / f"{cache_hash}_{self.model_name.replace('/', '_')}.pkl"

def embed_tools(self, tools: List[Tool], use_cache: bool = True) -> None:
"""Embed all tools for semantic search.

Args:
tools: List of Tool objects to embed
use_cache: Whether to use cached embeddings if available
"""
self.tools = tools
self.tool_texts = [self._create_tool_text(tool) for tool in tools]
self._cache_hash = self._compute_cache_hash(tools)

cache_path = self._get_cache_path(self._cache_hash)

# Try to load from cache
if use_cache and cache_path.exists():
try:
with open(cache_path, 'rb') as f:
cached_data = pickle.load(f)
self.embeddings = cached_data['embeddings']
return
except Exception:
# Cache load failed, will recompute
pass

# Compute embeddings
self._load_model()
self.embeddings = self.model.encode(
self.tool_texts,
convert_to_tensor=True,
show_progress_bar=False
)

# Save to cache
if use_cache:
try:
with open(cache_path, 'wb') as f:
pickle.dump({
'embeddings': self.embeddings,
'model_name': self.model_name,
}, f)
except Exception:
# Cache save failed, not critical
pass

def retrieve_relevant_tools(
self,
query: str,
threshold: float = 0.65,
min_tools: int = 0,
max_tools: int = 20
) -> List[Tool]:
"""Retrieve relevant tools for a given query using similarity threshold.

Args:
query: User query to find relevant tools for
threshold: Minimum similarity score (0-1) for a tool to be considered relevant
min_tools: Minimum number of tools to return (fallback if none meet threshold)
max_tools: Maximum number of tools to return (cap for performance)

Returns:
List of relevant Tool objects, filtered by threshold and bounded by min/max

Raises:
ValueError: If tools haven't been embedded yet
"""
if self.embeddings is None or not self.tools:
raise ValueError("Tools must be embedded before retrieval. Call embed_tools() first.")

self._load_model()

# Encode the query
query_embedding = self.model.encode(
query,
convert_to_tensor=True,
show_progress_bar=False
)

# Compute similarity scores
similarity_scores = util.cos_sim(query_embedding, self.embeddings)[0]

# Get all tools with their scores, sorted by score
scored_tools = [(self.tools[i], float(similarity_scores[i])) for i in range(len(self.tools))]
scored_tools.sort(key=lambda x: x[1], reverse=True)

# Filter by threshold
relevant_tools = [tool for tool, score in scored_tools if score >= threshold]

# Apply minimum bound (fallback to top-k if none meet threshold)
if len(relevant_tools) < min_tools:
relevant_tools = [tool for tool, _ in scored_tools[:min_tools]]

# Apply maximum bound (cap for performance)
if len(relevant_tools) > max_tools:
relevant_tools = relevant_tools[:max_tools]

return relevant_tools

def clear_cache(self) -> None:
"""Clear all cached embeddings."""
for cache_file in self.cache_dir.glob("*.pkl"):
cache_file.unlink()
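The new module is self-contained and can be exercised without the client. A minimal usage sketch follows, assuming the MCP SDK's Tool model accepts name, description, and inputSchema keywords; the query, threshold, and bounds here are illustrative:

# Sketch: standalone use of ToolRAG with two toy tools.
from mcp import Tool
from mcp_client_for_ollama.tools.rag import ToolRAG

tools = [
    Tool(name="read_file", description="Read a file from disk",
         inputSchema={"type": "object", "properties": {"path": {"type": "string"}}}),
    Tool(name="send_email", description="Send an email message",
         inputSchema={"type": "object", "properties": {"to": {"type": "string"}}}),
]

rag = ToolRAG()            # lazily downloads all-MiniLM-L6-v2 on first encode
rag.embed_tools(tools)     # embeddings cached under ~/.cache/ollmcp/tool_embeddings/
relevant = rag.retrieve_relevant_tools(
    "show me the contents of README.md",
    threshold=0.3,         # loose cutoff for a two-tool toy corpus
    min_tools=1,
    max_tools=5,
)
print([t.name for t in relevant])   # expected to rank read_file first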
1 change: 1 addition & 0 deletions pyproject.toml
@@ -13,6 +13,7 @@ dependencies = [
"ollama~=0.6.0",
"prompt-toolkit~=3.0.52",
"rich~=14.2.0",
"sentence-transformers~=3.3.1",
"typer~=0.20.0",
]
