diff --git a/.gitignore b/.gitignore
index 505a3b1..7a39998 100644
--- a/.gitignore
+++ b/.gitignore
@@ -8,3 +8,4 @@ wheels/
 
 # Virtual environments
 .venv
+.venv-test/
diff --git a/mcp_client_for_ollama/client.py b/mcp_client_for_ollama/client.py
index 15d9b89..64e0203 100644
--- a/mcp_client_for_ollama/client.py
+++ b/mcp_client_for_ollama/client.py
@@ -21,6 +21,7 @@
 from .models.manager import ModelManager
 from .models.config_manager import ModelConfigManager
 from .tools.manager import ToolManager
+from .tools.rag import ToolRAG
 from .utils.streaming import StreamingManager
 from .utils.tool_display import ToolDisplayManager
 from .utils.hil_manager import HumanInTheLoopManager
@@ -30,7 +31,8 @@ class MCPClient:
     """Main client class for interacting with Ollama and MCP servers"""
 
-    def __init__(self, model: str = DEFAULT_MODEL, host: str = DEFAULT_OLLAMA_HOST):
+    def __init__(self, model: str = DEFAULT_MODEL, host: str = DEFAULT_OLLAMA_HOST, enable_tool_rag: bool = False,
+                 tool_rag_threshold: float = 0.65, tool_rag_min_tools: int = 0, tool_rag_max_tools: int = 20):
         # Initialize session and client objects
         self.exit_stack = AsyncExitStack()
         self.ollama = ollama.AsyncClient(host=host)
@@ -48,6 +50,12 @@ def __init__(self, model: str = DEFAULT_MODEL, host: str = DEFAULT_OLLAMA_HOST):
         self.streaming_manager = StreamingManager(console=self.console)
         # Initialize the tool display manager
         self.tool_display_manager = ToolDisplayManager(console=self.console)
+        # Initialize Tool RAG if enabled
+        self.enable_tool_rag = enable_tool_rag
+        self.tool_rag_threshold = tool_rag_threshold
+        self.tool_rag_min_tools = tool_rag_min_tools
+        self.tool_rag_max_tools = tool_rag_max_tools
+        self.tool_rag: Optional[ToolRAG] = ToolRAG() if enable_tool_rag else None
         # Initialize the HIL manager
         self.hil_manager = HumanInTheLoopManager(console=self.console)
         # Store server and tool data
@@ -152,6 +160,12 @@ async def connect_to_servers(self, server_paths=None, server_urls=None, config_p
         # Set up the tool manager with the available tools and their enabled status
         self.tool_manager.set_available_tools(available_tools)
         self.tool_manager.set_enabled_tools(enabled_tools)
+
+        # Embed tools for RAG if enabled
+        if self.enable_tool_rag and self.tool_rag and available_tools:
+            self.console.print("[dim]Embedding tools for intelligent filtering...[/dim]")
+            self.tool_rag.embed_tools(available_tools)
+            self.console.print(f"[green]✓ Tool RAG enabled with {len(available_tools)} tools[/green]")
 
     def select_tools(self):
         """Let the user select which tools to enable using interactive prompts with server-based grouping"""
@@ -232,6 +246,22 @@ async def process_query(self, query: str) -> str:
 
         # Get enabled tools from the tool manager
         enabled_tool_objects = self.tool_manager.get_enabled_tool_objects()
+
+        # Apply Tool RAG filtering if enabled
+        if self.enable_tool_rag and self.tool_rag and enabled_tool_objects:
+            try:
+                enabled_tool_objects = self.tool_rag.retrieve_relevant_tools(
+                    query,
+                    threshold=self.tool_rag_threshold,
+                    min_tools=self.tool_rag_min_tools,
+                    max_tools=self.tool_rag_max_tools
+                )
+                # Filter to only enabled tools
+                enabled_tools_set = set(t.name for t in self.tool_manager.get_enabled_tool_objects())
+                enabled_tool_objects = [t for t in enabled_tool_objects if t.name in enabled_tools_set]
+            except Exception as e:
+                self.console.print(f"[yellow]Warning: Tool RAG filtering failed ({e}), using all enabled tools[/yellow]")
+                enabled_tool_objects = self.tool_manager.get_enabled_tool_objects()
 
         if not enabled_tool_objects:
             self.console.print("[yellow]Warning: No tools are enabled. Model will respond without tool access.[/yellow]")
@@ -1026,6 +1056,28 @@ def main(
         help="Ollama host URL",
         rich_help_panel="Ollama Configuration"
     ),
+
+    # Tool RAG Configuration
+    enable_tool_rag: bool = typer.Option(
+        False, "--enable-tool-rag",
+        help="Enable intelligent tool filtering using semantic search (recommended for 50+ tools)",
+        rich_help_panel="Tool RAG Configuration"
+    ),
+    tool_rag_threshold: float = typer.Option(
+        0.65, "--tool-rag-threshold",
+        help="Minimum similarity score (0-1) for a tool to be considered relevant",
+        rich_help_panel="Tool RAG Configuration"
+    ),
+    tool_rag_min_tools: int = typer.Option(
+        0, "--tool-rag-min-tools",
+        help="Minimum number of tools to send (fallback if none meet threshold)",
+        rich_help_panel="Tool RAG Configuration"
+    ),
+    tool_rag_max_tools: int = typer.Option(
+        20, "--tool-rag-max-tools",
+        help="Maximum number of tools to send (performance cap)",
+        rich_help_panel="Tool RAG Configuration"
+    ),
 
     # General Options
     version: Optional[bool] = typer.Option(
@@ -1044,15 +1096,19 @@ def main(
         auto_discovery = True
 
     # Run the async main function
-    asyncio.run(async_main(mcp_server, mcp_server_url, servers_json, auto_discovery, model, host))
+    asyncio.run(async_main(mcp_server, mcp_server_url, servers_json, auto_discovery, model, host,
+                           enable_tool_rag, tool_rag_threshold, tool_rag_min_tools, tool_rag_max_tools))
 
-async def async_main(mcp_server, mcp_server_url, servers_json, auto_discovery, model, host):
+async def async_main(mcp_server, mcp_server_url, servers_json, auto_discovery, model, host,
+                     enable_tool_rag, tool_rag_threshold, tool_rag_min_tools, tool_rag_max_tools):
     """Asynchronous main function to run the MCP Client for Ollama"""
     console = Console()
 
     # Create a temporary client to check if Ollama is running
-    client = MCPClient(model=model, host=host)
+    client = MCPClient(model=model, host=host, enable_tool_rag=enable_tool_rag,
+                       tool_rag_threshold=tool_rag_threshold, tool_rag_min_tools=tool_rag_min_tools,
+                       tool_rag_max_tools=tool_rag_max_tools)
     if not await client.model_manager.check_ollama_running():
         console.print(Panel(
             "[bold red]Error: Ollama is not running![/bold red]\n\n"
diff --git a/mcp_client_for_ollama/tools/rag.py b/mcp_client_for_ollama/tools/rag.py
new file mode 100644
index 0000000..2333dbb
--- /dev/null
+++ b/mcp_client_for_ollama/tools/rag.py
@@ -0,0 +1,205 @@
+"""Tool RAG (Retrieval-Augmented Generation) for intelligent tool filtering.
+
+This module implements semantic search over tool schemas to efficiently select
+relevant tools for a given query, enabling scalable tool management with large
+tool sets.
+"""
+
+import hashlib
+import json
+import pickle
+from pathlib import Path
+from typing import List, Optional
+
+import torch
+from mcp import Tool
+from sentence_transformers import SentenceTransformer, util
+
+
+class ToolRAG:
+    """Manages semantic search over tool schemas for intelligent tool filtering.
+
+    Uses sentence transformers to embed tool descriptions and perform vector
+    search to find the most relevant tools for a given query.
+    """
+
+    def __init__(
+        self,
+        model_name: str = "all-MiniLM-L6-v2",
+        cache_dir: Optional[Path] = None,
+    ):
+        """Initialize the ToolRAG system.
+
+        Args:
+            model_name: Name of the sentence-transformers model to use.
+                Default is 'all-MiniLM-L6-v2' (80MB, fast, good quality).
+            cache_dir: Directory to cache embeddings. If None, uses
+                ~/.cache/ollmcp/tool_embeddings/
+        """
+        self.model_name = model_name
+        self.model: Optional[SentenceTransformer] = None
+        self.cache_dir = cache_dir or Path.home() / ".cache" / "ollmcp" / "tool_embeddings"
+        self.cache_dir.mkdir(parents=True, exist_ok=True)
+
+        # Storage for tools and their embeddings
+        self.tools: List[Tool] = []
+        self.tool_texts: List[str] = []
+        self.embeddings: Optional[torch.Tensor] = None
+        self._cache_hash: Optional[str] = None
+
+    def _load_model(self) -> None:
+        """Lazy load the sentence transformer model."""
+        if self.model is None:
+            self.model = SentenceTransformer(self.model_name)
+
+    def _create_tool_text(self, tool: Tool) -> str:
+        """Create searchable text representation of a tool.
+
+        Args:
+            tool: Tool object to convert to text
+
+        Returns:
+            String representation combining name, description, and key parameters
+        """
+        parts = [
+            f"Tool: {tool.name}",
+            f"Description: {tool.description or 'No description'}",
+        ]
+
+        # Add parameter information if available
+        if hasattr(tool, 'inputSchema') and tool.inputSchema:
+            schema = tool.inputSchema
+            if isinstance(schema, dict) and 'properties' in schema:
+                param_names = list(schema['properties'].keys())
+                if param_names:
+                    parts.append(f"Parameters: {', '.join(param_names)}")
+
+        return " | ".join(parts)
+
+    def _compute_cache_hash(self, tools: List[Tool]) -> str:
+        """Compute a hash of the tool set for cache validation.
+
+        Args:
+            tools: List of tools to hash
+
+        Returns:
+            Hash string representing the tool set
+        """
+        # Create a stable representation of tools. hashlib is used rather than
+        # the built-in hash(), whose output changes between interpreter runs
+        # (PYTHONHASHSEED randomization) and would silently defeat the cache.
+        tool_repr = json.dumps(
+            [(t.name, t.description) for t in tools],
+            sort_keys=True
+        )
+        return hashlib.sha256(tool_repr.encode("utf-8")).hexdigest()
+
+    def _get_cache_path(self, cache_hash: str) -> Path:
+        """Get the cache file path for a given hash.
+
+        Args:
+            cache_hash: Hash of the tool set
+
+        Returns:
+            Path to the cache file
+        """
+        return self.cache_dir / f"{cache_hash}_{self.model_name.replace('/', '_')}.pkl"
+
+    def embed_tools(self, tools: List[Tool], use_cache: bool = True) -> None:
+        """Embed all tools for semantic search.
+
+        Args:
+            tools: List of Tool objects to embed
+            use_cache: Whether to use cached embeddings if available
+        """
+        self.tools = tools
+        self.tool_texts = [self._create_tool_text(tool) for tool in tools]
+        self._cache_hash = self._compute_cache_hash(tools)
+
+        cache_path = self._get_cache_path(self._cache_hash)
+
+        # Try to load from cache
+        if use_cache and cache_path.exists():
+            try:
+                with open(cache_path, 'rb') as f:
+                    cached_data = pickle.load(f)
+                self.embeddings = cached_data['embeddings']
+                return
+            except Exception:
+                # Cache load failed, will recompute
+                pass
+
+        # Compute embeddings
+        self._load_model()
+        self.embeddings = self.model.encode(
+            self.tool_texts,
+            convert_to_tensor=True,
+            show_progress_bar=False
+        )
+
+        # Save to cache
+        if use_cache:
+            try:
+                with open(cache_path, 'wb') as f:
+                    pickle.dump({
+                        'embeddings': self.embeddings,
+                        'model_name': self.model_name,
+                    }, f)
+            except Exception:
+                # Cache save failed, not critical
+                pass
+
+    def retrieve_relevant_tools(
+        self,
+        query: str,
+        threshold: float = 0.65,
+        min_tools: int = 0,
+        max_tools: int = 20
+    ) -> List[Tool]:
+        """Retrieve relevant tools for a given query using similarity threshold.
+
+        Args:
+            query: User query to find relevant tools for
+            threshold: Minimum similarity score (0-1) for a tool to be considered relevant
+            min_tools: Minimum number of tools to return (fallback if none meet threshold)
+            max_tools: Maximum number of tools to return (cap for performance)
+
+        Returns:
+            List of relevant Tool objects, filtered by threshold and bounded by min/max
+
+        Raises:
+            ValueError: If tools haven't been embedded yet
+        """
+        if self.embeddings is None or not self.tools:
+            raise ValueError("Tools must be embedded before retrieval. Call embed_tools() first.")
+
+        self._load_model()
+
+        # Encode the query
+        query_embedding = self.model.encode(
+            query,
+            convert_to_tensor=True,
+            show_progress_bar=False
+        )
+
+        # Compute similarity scores
+        similarity_scores = util.cos_sim(query_embedding, self.embeddings)[0]
+
+        # Get all tools with their scores, sorted by score
+        scored_tools = [(self.tools[i], float(similarity_scores[i])) for i in range(len(self.tools))]
+        scored_tools.sort(key=lambda x: x[1], reverse=True)
+
+        # Filter by threshold
+        relevant_tools = [tool for tool, score in scored_tools if score >= threshold]
+
+        # Apply minimum bound (fallback to top-k if none meet threshold)
+        if len(relevant_tools) < min_tools:
+            relevant_tools = [tool for tool, _ in scored_tools[:min_tools]]
+
+        # Apply maximum bound (cap for performance)
+        if len(relevant_tools) > max_tools:
+            relevant_tools = relevant_tools[:max_tools]
+
+        return relevant_tools
+
+    def clear_cache(self) -> None:
+        """Clear all cached embeddings."""
+        for cache_file in self.cache_dir.glob("*.pkl"):
+            cache_file.unlink()
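A standalone sketch of the embed/retrieve cycle the class above implements. The two tools are stand-ins built with mcp's Tool type (which the module itself imports); scores and rankings depend on the embedding model.

    # Minimal sketch of the embed -> retrieve cycle with stand-in tools.
    from mcp import Tool

    from mcp_client_for_ollama.tools.rag import ToolRAG

    tools = [
        Tool(name="github.list_issues",
             description="List issues from a GitHub repository",
             inputSchema={"type": "object", "properties": {"repo": {"type": "string"}}}),
        Tool(name="filesystem.read_file",
             description="Read contents of a file from the filesystem",
             inputSchema={"type": "object", "properties": {"path": {"type": "string"}}}),
    ]

    rag = ToolRAG()              # lazily downloads all-MiniLM-L6-v2 on first encode
    rag.embed_tools(tools)       # embeddings cached under ~/.cache/ollmcp/tool_embeddings/
    relevant = rag.retrieve_relevant_tools("show my open issues", threshold=0.3, max_tools=5)
    print([t.name for t in relevant])  # the GitHub tool should rank first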
diff --git a/pyproject.toml b/pyproject.toml
index 7004348..bdbf84c 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -13,6 +13,7 @@ dependencies = [
     "ollama~=0.6.0",
     "prompt-toolkit~=3.0.52",
     "rich~=14.2.0",
+    "sentence-transformers~=3.3.1",
     "typer~=0.20.0",
 ]
 
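Worth noting: sentence-transformers pulls in torch, so this is a heavyweight dependency, and the embedding model itself is fetched from Hugging Face on first use. If pre-fetching is wanted (e.g. in CI or a container build), something like this should work:

    # Pre-download the embedding model so the first query doesn't block on it.
    from sentence_transformers import SentenceTransformer

    SentenceTransformer("all-MiniLM-L6-v2")  # caches the model locally (HF cache dir by default)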
diff --git a/tests/test_tool_rag.py b/tests/test_tool_rag.py
new file mode 100644
index 0000000..5742e00
--- /dev/null
+++ b/tests/test_tool_rag.py
@@ -0,0 +1,185 @@
+"""Unit tests for ToolRAG functionality."""
+
+import tempfile
+from pathlib import Path
+from unittest.mock import Mock
+
+import pytest
+import torch
+
+from mcp_client_for_ollama.tools.rag import ToolRAG
+
+
+@pytest.fixture
+def mock_tools():
+    """Create mock Tool objects for testing."""
+    tools = []
+
+    # GitHub-related tools
+    tool1 = Mock()
+    tool1.name = "github.list_issues"
+    tool1.description = "List issues from a GitHub repository"
+    tool1.inputSchema = {"properties": {"repo": {}, "state": {}}}
+    tools.append(tool1)
+
+    tool2 = Mock()
+    tool2.name = "github.create_pr"
+    tool2.description = "Create a pull request on GitHub"
+    tool2.inputSchema = {"properties": {"title": {}, "body": {}, "base": {}}}
+    tools.append(tool2)
+
+    # Filesystem tools
+    tool3 = Mock()
+    tool3.name = "filesystem.read_file"
+    tool3.description = "Read contents of a file from the filesystem"
+    tool3.inputSchema = {"properties": {"path": {}}}
+    tools.append(tool3)
+
+    tool4 = Mock()
+    tool4.name = "filesystem.write_file"
+    tool4.description = "Write content to a file on the filesystem"
+    tool4.inputSchema = {"properties": {"path": {}, "content": {}}}
+    tools.append(tool4)
+
+    # AWS tools
+    tool5 = Mock()
+    tool5.name = "aws.list_buckets"
+    tool5.description = "List all S3 buckets in AWS account"
+    tool5.inputSchema = {"properties": {}}
+    tools.append(tool5)
+
+    return tools
+
+
+@pytest.fixture
+def tool_rag():
+    """Create a ToolRAG instance with temporary cache directory."""
+    with tempfile.TemporaryDirectory() as tmpdir:
+        rag = ToolRAG(cache_dir=Path(tmpdir))
+        yield rag
+
+
+def test_tool_rag_initialization(tool_rag):
+    """Test ToolRAG initializes correctly."""
+    assert tool_rag.model is None  # Lazy loading
+    assert tool_rag.embeddings is None
+    assert len(tool_rag.tools) == 0
+    assert tool_rag.cache_dir.exists()
+
+
+def test_create_tool_text(tool_rag, mock_tools):
+    """Test tool text representation creation."""
+    tool = mock_tools[0]
+    text = tool_rag._create_tool_text(tool)
+
+    assert "github.list_issues" in text
+    assert "List issues from a GitHub repository" in text
+    assert "repo" in text
+    assert "state" in text
+
+
+def test_embed_tools(tool_rag, mock_tools):
+    """Test embedding tools creates embeddings."""
+    tool_rag.embed_tools(mock_tools, use_cache=False)
+
+    assert tool_rag.embeddings is not None
+    assert isinstance(tool_rag.embeddings, torch.Tensor)
+    assert tool_rag.embeddings.shape[0] == len(mock_tools)
+    assert len(tool_rag.tools) == len(mock_tools)
+    assert len(tool_rag.tool_texts) == len(mock_tools)
+
+
+def test_embed_tools_with_cache(tool_rag, mock_tools):
+    """Test embedding caching works correctly."""
+    # First embedding - creates cache
+    tool_rag.embed_tools(mock_tools, use_cache=True)
+    first_embeddings = tool_rag.embeddings.clone()
+
+    # Create new instance with same cache dir
+    rag2 = ToolRAG(cache_dir=tool_rag.cache_dir)
+    rag2.embed_tools(mock_tools, use_cache=True)
+
+    # Should load from cache
+    assert torch.allclose(first_embeddings, rag2.embeddings)
+
+
+def test_retrieve_relevant_tools_github_query(tool_rag, mock_tools):
+    """Test retrieving tools for GitHub-related query."""
+    tool_rag.embed_tools(mock_tools, use_cache=False)
+
+    results = tool_rag.retrieve_relevant_tools("show me GitHub issues", threshold=0.3, max_tools=2)
+
+    assert len(results) <= 2
+    # GitHub tools should be top results
+    assert any("github" in tool.name for tool in results)
+
+
+def test_retrieve_relevant_tools_filesystem_query(tool_rag, mock_tools):
+    """Test retrieving tools for filesystem-related query."""
+    tool_rag.embed_tools(mock_tools, use_cache=False)
+
+    results = tool_rag.retrieve_relevant_tools("read a file from disk", threshold=0.3, max_tools=2)
+
+    assert len(results) <= 2
+    # Filesystem tools should be top results
+    assert any("filesystem" in tool.name for tool in results)
+
+
+def test_retrieve_relevant_tools_aws_query(tool_rag, mock_tools):
+    """Test retrieving tools for AWS-related query."""
+    tool_rag.embed_tools(mock_tools, use_cache=False)
+
+    results = tool_rag.retrieve_relevant_tools("list my S3 buckets", threshold=0.3, max_tools=2)
+
+    assert len(results) <= 2
+    # AWS tool should be in top results
+    assert any("aws" in tool.name for tool in results)
+
+
+def test_retrieve_before_embed_raises_error(tool_rag):
+    """Test that retrieving before embedding raises ValueError."""
+    with pytest.raises(ValueError, match="Tools must be embedded"):
+        tool_rag.retrieve_relevant_tools("test query")
+
+
+def test_retrieve_respects_max_tools(tool_rag, mock_tools):
+    """Test that max_tools parameter is respected."""
+    tool_rag.embed_tools(mock_tools, use_cache=False)
+
+    results = tool_rag.retrieve_relevant_tools("test query", threshold=0.0, max_tools=3)
+    assert len(results) <= 3
+
+    results = tool_rag.retrieve_relevant_tools("test query", threshold=0.0, max_tools=10)
+    assert len(results) <= len(mock_tools)  # Can't exceed available tools
+
+
+def test_clear_cache(tool_rag, mock_tools):
+    """Test cache clearing functionality."""
+    tool_rag.embed_tools(mock_tools, use_cache=True)
+
+    # Verify cache file exists
+    cache_files = list(tool_rag.cache_dir.glob("*.pkl"))
+    assert len(cache_files) > 0
+
+    # Clear cache
+    tool_rag.clear_cache()
+
+    # Verify cache is empty
+    cache_files = list(tool_rag.cache_dir.glob("*.pkl"))
+    assert len(cache_files) == 0
+
+
+def test_compute_cache_hash_consistency(tool_rag, mock_tools):
+    """Test that cache hash is consistent for same tools."""
+    hash1 = tool_rag._compute_cache_hash(mock_tools)
+    hash2 = tool_rag._compute_cache_hash(mock_tools)
+
+    assert hash1 == hash2
+
+
+def test_compute_cache_hash_changes_with_tools(tool_rag, mock_tools):
+    """Test that cache hash changes when tools change."""
+    hash1 = tool_rag._compute_cache_hash(mock_tools)
+    hash2 = tool_rag._compute_cache_hash(mock_tools[:3])
+
+    assert hash1 != hash2
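These tests exercise the real embedding model, so the first run needs network access to fetch all-MiniLM-L6-v2; later runs work offline from the local cache. A programmatic equivalent of running pytest -q tests/test_tool_rag.py, for reference:

    # Run just this module's tests; exits with pytest's return code.
    import sys

    import pytest

    sys.exit(pytest.main(["-q", "tests/test_tool_rag.py"]))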