3 changes: 3 additions & 0 deletions .env.example
@@ -19,3 +19,6 @@ MILVUS_URL=""

# miscellaneous
MCP_PROXY_LOCAL_PORT=""

# Additional examples
EXAMPLES_STORE_PATH=""
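Note: data_provider.py reads this variable via os.getenv("EXAMPLES_STORE_PATH"), so leaving it unset yields None. A minimal sketch of a safer lookup (the data/examples.json fallback is an assumption, not part of this PR):

import os
from pathlib import Path

# Assumption: data/examples.json is a sensible default location for the store.
store_path = Path(os.getenv("EXAMPLES_STORE_PATH") or "data/examples.json")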
16 changes: 11 additions & 5 deletions evaluator/algorithms/tool_rag_algorithm.py
@@ -84,6 +84,8 @@ class ToolRagAlgorithm(Algorithm):
- max_document_size: the maximal size, in characters, of a single indexed document, or None to disable the size limit.
- indexed_tool_def_parts: the parts of the MCP tool definition to be used for index construction, such as 'name',
'description', 'args', etc.
You can also include 'examples' to append example queries for each tool, if provided
via the 'examples' setting (see defaults below).
- hybrid_mode: True to enable hybrid (sparse + dense) search and False to only enable dense search.
- analyzer_params: parameters for the Milvus BM25 analyzer.
- fusion_type: the algorithm for combining the dense and the sparse scores if hybrid mode is activated. Milvus only
@@ -130,6 +132,7 @@ def get_default_settings(self) -> Dict[str, Any]:
"index_type": "FLAT",
"indexed_tool_def_parts": ["name", "description"],


# preprocessing
"text_preprocessing_operations": None,
"max_document_size": None,
@@ -205,15 +208,14 @@ def _render_args_schema(schema: Dict[str, Any]) -> str:
return " ".join(parts)

@staticmethod
def _render_examples(examples: List[str], max_examples: int = 3) -> str:
def _render_examples(examples: List[str], max_examples: int = 5) -> str:
exs = (examples or [])[:max_examples]
return " || ".join(exs)

def _compose_tool_text(self, tool: BaseTool) -> str:
parts_to_include = self._settings["indexed_tool_def_parts"]
if not parts_to_include:
raise ValueError("indexed_tool_def_parts must be a non-empty list")

segments = []
for p in parts_to_include:
if p.lower() == "name":
@@ -232,11 +234,15 @@ def _compose_tool_text(self, tool: BaseTool) -> str:
tags = tool.tags or []
if tags:
segments.append(f"tags: {' '.join(tags)}")

elif p.lower() == "examples":
examples_list = list((tool.metadata or {}).get('examples', {}).values())
if examples_list:
rendered = self._render_examples(examples_list)
if rendered:
segments.append(f"ex: {rendered}")
if not segments:
raise ValueError(f"The following tool contains none of the fields listed in indexed_tool_def_parts:\n{tool}")
text = " | ".join(segments)

# one-pass preprocess + truncation
text = self._preprocess_text(text)
text = self._truncate(text)
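Note: illustrative only (made-up tool), the composed document with name, description, and examples enabled looks like:

# "get_weather | Returns the current weather for a city. | ex: What's the weather in Paris? || Will it rain tomorrow in Oslo?"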
@@ -581,4 +587,4 @@ def _dedup_keep_order(xs: List[str]) -> List[str]:

@staticmethod
def _strip_numbering(s: str) -> str:
return re.sub(r"^\s*(?:[-*]|\d+[).:]?)\s*", "", s).strip().rstrip(".")
304 changes: 301 additions & 3 deletions evaluator/components/data_provider.py
@@ -7,12 +7,15 @@
from typing import List, Dict, Any, Optional, Tuple
from pathlib import Path

from pydantic import BaseModel, Field

from evaluator.components.llm_provider import query_llm, get_llm
from evaluator.config.schema import EnvironmentConfig, DatasetConfig
from evaluator.utils.file_downloader import fetch_remote_paths
from evaluator.utils.utils import log

from evaluator.config.config_io import load_config
from tqdm import tqdm
import re

ToolSet = Dict[str, Dict[str, Any]]


@@ -27,6 +30,7 @@ class QuerySpecification(BaseModel):
"""
id: int
query: str
examples: Optional[Dict[str, Any]] = None
reference_answer: Optional[str] = None
golden_tools: ToolSet = Field(default_factory=dict)
additional_tools: Optional[ToolSet] = None
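Note: a minimal construction of the extended model, with illustrative values:

spec = QuerySpecification(
    id=42,
    query="Find flights from Berlin to Rome",
    examples={"search_flights": {"query1": "Book me a flight to Rome?"}},
)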
@@ -395,9 +399,303 @@ def get_queries(
def get_tools_from_queries(queries: List[QuerySpecification]) -> ToolSet:
tools = {}

cfg_path = "evaluator/config/yaml/tool_rag_experiments.yaml"
cfg = load_config(cfg_path, use_defaults=True)
examples = cfg.data.generate_examples
Contributor comment on lines +402 to +404:
As I asked before, please don't load the configuration here. Instead, pass generate_examples and the model ID as input to get_tools_from_queries.


model_id = cfg.data.additional_examples_model_id
# Base tools from the dataset
for query_spec in tqdm(queries, desc="Getting tools from queries"):
tools.update(query_spec.golden_tools)
if query_spec.additional_tools:
tools.update(query_spec.additional_tools)

# Get or generate additional examples for tools that don't have them
if examples:
golden_tools = query_spec.golden_tools
for tool in golden_tools:
examples_exists = is_tool_in_additional_store(tool, query_spec.id)
if not examples_exists:
# TODO: get the model ID from the config file; this doesn't work
Contributor comment:
You're right - this doesn't work. Please fix by passing model_config as another optional parameter to get_tools_from_queries.

llm = get_llm(model_id=model_id, model_config=cfg.models)
tools[tool]["examples"] = generate_and_save_examples(llm, tool, query_spec)
else:
aq = get_additional_query(query_spec.id)
Contributor comment:
The function should be renamed following the "additional_query" => "example" refactor

tools[tool]["examples"] = aq[tool]
return tools
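Note: following the review comments above, a sketch of the suggested signature change, passing generate_examples, the model ID, and the model config in as parameters instead of loading YAML here (parameter names are assumptions):

def get_tools_from_queries(
    queries: List[QuerySpecification],
    generate_examples: bool = False,
    model_id: Optional[str] = None,
    model_config: Optional[Any] = None,
) -> ToolSet:
    tools: ToolSet = {}
    for query_spec in tqdm(queries, desc="Getting tools from queries"):
        tools.update(query_spec.golden_tools)
        if query_spec.additional_tools:
            tools.update(query_spec.additional_tools)
        if generate_examples:
            for tool in query_spec.golden_tools:
                if not is_tool_in_additional_store(tool, query_spec.id):
                    llm = get_llm(model_id=model_id, model_config=model_config)
                    tools[tool]["examples"] = generate_and_save_examples(llm, tool, query_spec)
                else:
                    entry = get_additional_query(query_spec.id)
                    if entry and tool in entry:
                        tools[tool]["examples"] = entry[tool]
    return tools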


def load_examples_store(path: str | None = None) -> List[Dict[str, Any]]:
"""
Load the centralized additional queries store.
Expected format: a JSON list of objects {"query_id": int, "examples": {...}}.
Returns an empty list if the file doesn't exist or cannot be parsed.
"""
try:
store_path = Path(path) if path else Path(os.getenv("EXAMPLES_STORE_PATH"))
if not store_path.exists():
return []
with store_path.open("r", encoding="utf-8") as f:
loaded = json.load(f)
return loaded if isinstance(loaded, list) else []
except Exception:
return []
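Note: an illustrative store file in the documented format (values made up; a full entry carries query1..query5):

[
  {
    "query_id": 42,
    "examples": {
      "search_flights": {
        "query1": "Book me a flight to Rome?",
        "query2": "Which airlines fly Berlin to Rome nonstop?"
      }
    }
  }
]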


def get_additional_query(query_id: int) -> Dict[str, Any] | None:
"""
Return the examples recorded for the given query_id in data/examples.json,
or None if no entry matches. Note: only the first matching record is returned.
"""
store = load_examples_store()
if not store:
return None
for item in store:
next_query_id = item.get("query_id")
if next_query_id == query_id:
return item.get('examples')
return None


def get_examples_by_tool_name(tool_name: str) -> Dict[str, Any] | None:
"""
Look through data/examples.json (a list of entries with an
"examples" mapping) and return the queries map for the specified
tool_name if present (exact match on the tool name).
"""
store = load_examples_store()

for item in store:
if not isinstance(item, dict):
continue
aq_map = item.get("examples")
if not isinstance(aq_map, dict):
continue
block = aq_map.get(tool_name)
if isinstance(block, dict):
return block
return None

def generate_and_save_examples(llm, tool_name, query_spec, store_path: Path | None = None):
"""
Use the provided LLM to generate example queries for a single tool and
save them to the centralized JSON store, keyed by query_spec.id.
"""

# system_prompt = '''You create 5 additional queries for each tool and only return the additional queries information, given the query implemented, return in the following format as a JSON string:
# {tool_name: {"query1": "", "query2": "", "query3": "", "query4": "", "query5": ""}} '''

system_prompt = '''You are a tool query generator. For each specified tool, create EXACTLY 5 concise, natural-language user queries suitable for invoking that tool.

Output requirements:
- Return ONLY a single JSON object (no explanations, no code fences).
- Shape:
{
"query1": "...",
"query2": "...",
"query3": "...",
"query4": "...",
"query5": "..."
}
- Each query must be ≤ 20 words and phrased as a natural request (not SQL), ending with a question mark when appropriate.
- Preserve any placeholder tokens as given (e.g., {id}, tt1234567) without inventing new identifiers.
- Avoid near-duplicates; vary phrasing and sub-intents across the 5 queries.
- Use English.

Context you will receive:
- tool_name(s): a set of tool identifiers
- original user query: the initial task description
'''

out_path = store_path or Path(os.getenv("EXAMPLES_STORE_PATH"))
out_path.parent.mkdir(parents=True, exist_ok=True)
if not out_path.exists():
try:
with out_path.open('w', encoding='utf-8') as f:
json.dump([], f, indent=2, ensure_ascii=False)
except Exception as e:
log(f"error creating examples store at {out_path}: {e}")

example = _generate_additional_query_for_tool(
llm,
system_prompt,
query_spec.query,
tool_name,
)

examples = {}
examples[tool_name] = example
query_spec.examples = examples
# Save additional queries to centralized file

if examples is not None:
append_examples_entry(query_spec.id, query_spec.examples, out_path)
return examples

def _generate_additional_query_for_tool(llm, system_prompt: str, query_text: str, tool_name: str) -> Dict[str, Any] | None:
"""
Call the LLM to generate additional queries for a single tool, retrying until
a mapping with query1..query5 is produced or max attempts are reached.
Returns the parsed dict (queries map or tool->queries map) or None.
"""
correct_response = False
iteration = 0
additional_query = None
while not correct_response:
user_prompt = f"tool_name = {tool_name}, Query= {query_text}"
result = query_llm(llm, system_prompt, user_prompt)
model_id = str(getattr(llm, "model", "") or getattr(llm, "model_name", "") or "")
if "llama3.1:8b" in model_id:
additional_query = lama_model_parsing(result)
else:
additional_query = qwen_model_parsing(result)
correct_response = has_required_query_keys(additional_query)
iteration += 1
if iteration > 10:
log(f"Failed to generate additional queries for tool {tool_name} after 5 iterations")
break
return additional_query

def lama_model_parsing(response: str):
"""
Parse the response from the Llama model and return the additional queries.
"""
if not response:
return None
text = response.strip()
quoted = re.findall(r'"([^"\\]*(?:\\.[^"\\]*)*)"', text)
if not quoted:
return None
return {f"query{i}": v for i, v in enumerate(quoted[:5], start=1)}

def qwen_model_parsing(response: str):
"""
Parse the response from the Qwen model and return the additional queries.
"""
# Strip the model's <think>...</think> reasoning block if present
match = re.search(r"</think>\s*(.*)", response, re.DOTALL)
response_text = (match.group(1) if match else response).strip()
# Parse the remaining text as a JSON object
try:
additional = json.loads(response_text)
except Exception:
additional = None
return additional
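Note: has_required_query_keys below already tolerates ```json fences, but this parser does not; a minimal helper that strips them before json.loads (an assumption about Qwen's output format, mirroring the regex used below):

def _strip_json_fences(text: str) -> str:
    # Drop a ```json ... ``` (or bare ```) wrapper if the model added one.
    m = re.search(r"```(?:json)?\s*([\s\S]*?)```", text, re.IGNORECASE)
    return m.group(1).strip() if m else text.strip()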


def has_required_query_keys(response: Any) -> bool:
"""
Return True iff the response contains all of the keys Query1..Query5 (case-insensitive)
under at least one mapping block. Accepts either a parsed dict or a JSON/string response
(optionally wrapped in ```json fences).
"""
required = {"query1", "query2", "query3", "query4", "query5"}

def _check_dict(d: Dict[str, Any]) -> bool:
if not isinstance(d, dict):
return False
# Case 1: top-level is the queries map
keys_lc = {k.lower() for k in d.keys() if isinstance(k, str)}
if required.issubset(keys_lc):
return True
# Case 2: top-level is tool_name -> queries_map
for _, v in d.items():
if isinstance(v, dict):
inner_keys = {k.lower() for k in v.keys() if isinstance(k, str)}
if required.issubset(inner_keys):
return True
return False

# If it's already a dict, check directly
if isinstance(response, dict):
return _check_dict(response)

# If it's a string, strip fences and try JSON
if isinstance(response, str):
text = response.strip()
m = re.search(r"```(?:json)?\s*([\s\S]*?)```", text, re.IGNORECASE)
if m:
text = m.group(1).strip()
try:
obj = json.loads(text)
return _check_dict(obj)
except Exception:
# Heuristic fallback: ensure all tokens appear (weak check)
found = {tok for tok in required if re.search(fr"\b{tok}\b", text, flags=re.IGNORECASE)}
return required.issubset(found)

return False
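Note: both accepted shapes, for reference:

assert has_required_query_keys({"query1": "a", "query2": "b", "query3": "c", "query4": "d", "query5": "e"})
assert has_required_query_keys({"my_tool": {"Query1": "a", "Query2": "b", "Query3": "c", "Query4": "d", "Query5": "e"}})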


def is_tool_in_additional_store(tool_name: str, query_id: int, store_path: Path | None = None) -> bool:
"""
Return True if any entry in the centralized store has examples containing this tool_name
(tolerates trailing-dot variants).
"""

try:
out_path = store_path or Path(os.getenv("EXAMPLES_STORE_PATH"))
if not out_path.exists():
log(f"examples.json not found, creating empty file")
out_path.parent.mkdir(parents=True, exist_ok=True)
out_path.write_text("[]", encoding="utf-8")
return False

with out_path.open('r', encoding='utf-8') as f:
loaded = json.load(f)

for item in loaded:
next_query_id = item.get("query_id")
if next_query_id == query_id:
aq = item.get("examples")
if aq and tool_name in aq:
return True
return False
except Exception:
return False


def append_examples_entry(query_id: int, examples: Dict[str, Any], store_path: Path | None = None) -> None:
"""
Upsert an entry in data/examples.json in the list-of-dicts format
{ "query_id": <int>, "examples": <dict> }; existing entries are merged per tool.
"""
out_path = store_path or Path(os.getenv("EXAMPLES_STORE_PATH"))
store_list: List[Dict[str, Any]] = []
try:
if out_path.exists():
with out_path.open('r', encoding='utf-8') as f:
loaded = json.load(f)
if isinstance(loaded, list):
store_list = loaded
except Exception:
store_list = []

# Upsert: if entry for query_id exists, merge/overwrite per tool; otherwise append new
idx = None
for i, item in enumerate(store_list):
if isinstance(item, dict) and item.get("query_id") == query_id:
idx = i
break

if idx is None:
store_list.append({
"query_id": query_id,
"examples": examples or {},
})
else:
existing_block = store_list[idx].get("examples")
if not isinstance(existing_block, dict):
existing_block = {}
for tool_name, qmap in (examples or {}).items():
if isinstance(qmap, dict):
existing_block[tool_name] = qmap
store_list[idx]["examples"] = existing_block
with out_path.open('w', encoding='utf-8') as f:
json.dump(store_list, f, indent=2, ensure_ascii=False)
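Note: a minimal round-trip with an explicit store_path (illustrative values):

path = Path("data/examples.json")  # assumption: the same file EXAMPLES_STORE_PATH points to
append_examples_entry(7, {"get_weather": {"query1": "Weather in Paris?"}}, store_path=path)
assert is_tool_in_additional_store("get_weather", 7, store_path=path)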

