
Commit 26421e8

Additional queries file and added to tool definition
1 parent 77c8cc4 commit 26421e8

6 files changed (+309, -91 lines)

evaluator/algorithms/tool_rag_algorithm.py

Lines changed: 33 additions & 32 deletions
@@ -216,7 +216,6 @@ def _compose_tool_text(self, tool: BaseTool) -> str:
         parts_to_include = self._settings["indexed_tool_def_parts"]
         if not parts_to_include:
             raise ValueError("indexed_tool_def_parts must be a non-empty list")
-
         segments = []
         for p in parts_to_include:
             if p.lower() == "name":
@@ -236,18 +235,15 @@ def _compose_tool_text(self, tool: BaseTool) -> str:
                 if tags:
                     segments.append(f"tags: {' '.join(tags)}")
             elif p.lower() == "additional_queries":
-                # Append example queries supplied via settings["additional_queries"][tool.name]
                 examples_map = self._settings.get("additional_queries") or {}
                 examples_list = examples_map.get(tool.name) or []
                 if examples_list:
                     rendered = self._render_examples(examples_list)
                     if rendered:
                         segments.append(f"ex: {rendered}")
-
         if not segments:
             raise ValueError(f"The following tool contains none of the fields listed in indexed_tool_def_parts:\n{tool}")
         text = " | ".join(segments)
-
         # one-pass preprocess + truncation
         text = self._preprocess_text(text)
         text = self._truncate(text)
@@ -260,7 +256,31 @@ def _create_docs_from_tools(self, tools: List[BaseTool]) -> List[Document]:
             documents.append(Document(page_content=page_content, metadata={"name": tool.name}))
         return documents
 
-    def _index_tools(self, tools: List[BaseTool], queries: List[QuerySpecification]) -> None:
+    def _collect_examples_from_tool_specs(self, tool_specs: Dict[str, Dict[str, Any]]) -> Dict[str, List[str]]:
+        """
+        Build {tool_name: [example1, example2, ...]} from a tools dict where each
+        value may contain an 'additional_queries' dict mapping query keys to strings.
+        """
+        examples: Dict[str, List[str]] = {}
+        for tool_name, spec in (tool_specs or {}).items():
+            if not isinstance(spec, dict):
+                continue
+            aq = spec.get("additional_queries")
+            if isinstance(aq, dict):
+                for _, qtext in aq.items():
+                    if isinstance(qtext, str) and qtext.strip():
+                        examples.setdefault(tool_name, []).append(qtext.strip())
+        # de-duplicate while preserving order
+        for k, v in list(examples.items()):
+            seen, out = set(), []
+            for s in v:
+                if s not in seen:
+                    seen.add(s)
+                    out.append(s)
+            examples[k] = out
+        return examples
+
+    def _index_tools(self, tools: List[BaseTool]) -> None:
         self.tool_name_to_base_tool = {tool.name: tool for tool in tools}
 
         self.embeddings = HuggingFaceEmbeddings(model_name=self._settings["embedding_model_id"])
@@ -319,7 +339,7 @@ def _index_tools(self, tools: List[BaseTool], queries: List[QuerySpecification])
             search_params=search_params,
         )
 
-    def set_up(self, model: BaseChatModel, tools: List[BaseTool], queries: List[QuerySpecification]) -> None:
+    def set_up(self, model: BaseChatModel, tools: List[BaseTool], tool_specs: Any) -> None:
         super().set_up(model, tools)
 
         if self._settings["cross_encoder_model_name"]:
@@ -331,34 +351,15 @@ def set_up(self, model: BaseChatModel, tools: List[BaseTool], queries: List[Quer
         if self._settings["enable_query_decomposition"] or self._settings["enable_query_rewriting"]:
             self.query_rewriting_model = self._get_llm(self._settings["query_rewriting_model_id"])
 
-        # Build additional_queries mapping from provided QuerySpecifications so YAML is not required.
+        # Build additional_queries mapping from provided specs (accept dict of tool specs or list of QuerySpecifications)
        try:
-            tool_examples: Dict[str, List[str]] = {}
-            for spec in (queries or []):
-                add_q = getattr(spec, "additional_queries", None) or {}
-                # Flatten wrapper {"additional_queries": {...}} if present
-                if isinstance(add_q, dict) and "additional_queries" in add_q and len(add_q) == 1:
-                    add_q = add_q["additional_queries"]
-                for tool_name, qmap in add_q.items():
-                    if isinstance(qmap, dict):
-                        for _, qtext in qmap.items():
-                            if isinstance(qtext, str) and qtext.strip():
-                                tool_examples.setdefault(tool_name, []).append(qtext.strip())
-            # Dedupe while preserving order
-            for k, v in list(tool_examples.items()):
-                seen = set()
-                deduped = []
-                for s in v:
-                    if s not in seen:
-                        seen.add(s)
-                        deduped.append(s)
-                tool_examples[k] = deduped
-            if tool_examples:
-                self._settings["additional_queries"] = tool_examples
+            examples_map: Dict[str, List[str]] = {}
+            if isinstance(tool_specs, dict):
+                examples_map = self._collect_examples_from_tool_specs(tool_specs)
+            self._settings["additional_queries"] = examples_map
         except Exception:
             pass
-
-        self._index_tools(tools, queries)
+        self._index_tools(tools)
 
     def _threshold_results(self, docs_and_scores: List[Tuple[Document, float]]) -> List[Document]:
         """
@@ -619,4 +620,4 @@ def _dedup_keep_order(xs: List[str]) -> List[str]:
 
     @staticmethod
     def _strip_numbering(s: str) -> str:
-        return re.sub(r"^\s*(?:[-*]|\d+[).:]?)\s*", "", s).strip().rstrip(".")
+        return re.sub(r"^\s*(?:[-*]|\d+[).:]?)\s*", "", s).strip().rstrip(".")
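For orientation, here is a minimal sketch of the data flow this file now implements, assuming the spec shape that _collect_examples_from_tool_specs and set_up expect; the tool name and query strings are invented for illustration.

    from typing import Any, Dict

    # Hypothetical tool_specs dict of the shape set_up() now accepts.
    tool_specs: Dict[str, Dict[str, Any]] = {
        "weather_lookup": {
            "description": "Returns current weather for a city.",
            "additional_queries": {
                "q1": "What's the weather like in Paris right now?",
                "q2": "Is it raining in Tokyo today?",
            },
        },
    }

    # _collect_examples_from_tool_specs(tool_specs) would return:
    #   {"weather_lookup": ["What's the weather like in Paris right now?",
    #                       "Is it raining in Tokyo today?"]}
    # set_up() stores that mapping under settings["additional_queries"], and
    # _compose_tool_text() renders it as an "ex: ..." segment whenever
    # "additional_queries" is listed in indexed_tool_def_parts.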

evaluator/components/data_provider.py

Lines changed: 53 additions & 16 deletions
@@ -315,7 +315,7 @@ def _load_queries_from_single_file(
        root_dataset_path: str or Path,
        experiment_environment: EnvironmentConfig,
        dataset_config: DatasetConfig,
-) -> Tuple[List[QuerySpecification], List[Dict[str, Any]]]:
+) -> List[QuerySpecification]:
     with open(query_file_path, 'r') as f:
         data = json.load(f)
 
@@ -334,13 +334,6 @@ def _load_queries_from_single_file(
             log(f"Invalid query spec, skipping this query.")
         else:
             query = raw_query_spec.get("query")
-            if raw_query_spec.get("additional_queries"):
-                additional_queries = raw_query_spec.get("additional_queries")
-                print(f"Additional queries provided: {additional_queries}")
-
-            else:
-                print(f"No additional queries provided")
-                additional_queries = None
             query_id = int(raw_query_spec.get("query_id"))
             golden_tools, additional_tools = (
                 _parse_raw_query_tool_definitions(raw_query_spec, experiment_environment, dataset_config))
@@ -354,8 +347,6 @@ def _load_queries_from_single_file(
                 QuerySpecification(
                     id=query_id,
                     query=query,
-                    path=str(query_file_path),
-                    additional_queries=additional_queries,
                     reference_answer=reference_answer,
                     golden_tools=golden_tools,
                     additional_tools=additional_tools or None
@@ -373,7 +364,7 @@ def get_queries(
        experiment_environment: EnvironmentConfig,
        dataset_config: DatasetConfig,
        fine_tuning_mode=False
-) -> Tuple[List[QuerySpecification], List[Dict[str, Any]]]:
+) -> List[QuerySpecification]:
     """Load queries from the dataset."""
     root_dataset_path = Path(os.getenv("ROOT_DATASET_PATH"))
     if not root_dataset_path:
@@ -390,14 +381,14 @@ def get_queries(
     queries_num = None if fine_tuning_mode else dataset_config.queries_num
     queries = []
     for path in local_paths:
-        print(f"\n\n")
-        print(f"--------------------------------")
-        print(f"Loading queries from file: {path}")
-        print(f"\n\n")
         remaining_queries_num = None if queries_num is None else queries_num - len(queries)
         if remaining_queries_num == 0:
             break
-        new_queries= _load_queries_from_single_file(path, remaining_queries_num, root_dataset_path, experiment_environment, dataset_config)
+        new_queries = _load_queries_from_single_file(path,
+                                                     remaining_queries_num,
+                                                     root_dataset_path,
+                                                     experiment_environment,
+                                                     dataset_config)
         queries.extend(new_queries)
 
     return queries
@@ -406,9 +397,55 @@ def get_queries(
 def get_tools_from_queries(queries: List[QuerySpecification]) -> ToolSet:
     tools = {}
 
+    # Base tools from the dataset
     for query_spec in queries:
         tools.update(query_spec.golden_tools)
         if query_spec.additional_tools:
             tools.update(query_spec.additional_tools)
 
+        # Merge per-query additional queries from centralized store under the correct tool entry
+        aq = get_additional_query(query_spec.id)
+        if isinstance(aq, dict):
+            golden_tools = query_spec.golden_tools
+            for tool in golden_tools:
+                additional_queries = aq.get(tool)
+                tools[tool]["additional_queries"] = additional_queries
+
     return tools
+
+
+def load_additional_queries_store(path: str | None = None) -> List[Dict[str, Any]]:
+    """
+    Load the centralized additional queries store.
+    Expected format: a JSON list of objects {"query_id": int, "additional_queries": {...}}.
+    Returns an empty list if the file doesn't exist or cannot be parsed.
+    """
+    try:
+        store_path = Path(path) if path else (Path("data") / "additional_queries.json")
+        if not store_path.exists():
+            return []
+        with store_path.open("r", encoding="utf-8") as f:
+            loaded = json.load(f)
+        return loaded if isinstance(loaded, list) else []
+    except Exception:
+        return []
+
+
+def get_additional_query(query_id: int) -> Dict[str, Any] | None:
+    """
+    Return the additional_queries dict for the given query_id from data/additional_queries.json,
+    or None if not found or invalid.
+    """
+    store = load_additional_queries_store()
+    for item in store:
+        if not isinstance(item, dict):
+            continue
+        if "query_id" not in item or "additional_queries" not in item:
+            continue
+        try:
+            qid = int(item["query_id"])
+        except Exception:
+            continue
+        if qid == query_id and isinstance(item["additional_queries"], dict):
+            return item["additional_queries"]
+    return None
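As a concrete reference, the snippet below writes a hypothetical data/additional_queries.json in the format the two loaders above expect; the query ID, tool name, and query text are invented.

    import json
    from pathlib import Path

    # One entry per query: {"query_id": int, "additional_queries": {...}}.
    store = [
        {
            "query_id": 42,
            "additional_queries": {
                "weather_lookup": {"q1": "Will it snow in Oslo tomorrow?"},
            },
        },
    ]
    Path("data").mkdir(exist_ok=True)
    Path("data/additional_queries.json").write_text(json.dumps(store, indent=2))

    # get_additional_query(42) then returns {"weather_lookup": {"q1": "..."}},
    # and get_tools_from_queries() attaches each per-tool sub-dict to the matching
    # tool entry, e.g. tools["weather_lookup"]["additional_queries"].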

evaluator/config/yaml/tool_rag_experiments.yaml

Lines changed: 1 addition & 0 deletions
@@ -12,6 +12,7 @@ data:
   reference_answers_path: "https://huggingface.co/datasets/stabletoolbench/baselines/resolve/main/data_baselines.zip"
   reference_model_id: "chatgpt_cot"
   queries_num: null
+  additional_queries_model_id: "Qwen/Qwen3-8B"
 
 models:
   - id: "Qwen/Qwen3-8B"

evaluator/evaluator.py

Lines changed: 22 additions & 8 deletions
@@ -1,5 +1,6 @@
 import asyncio
 import os
+from re import S
 import time
 import traceback
 from typing import List, Tuple
@@ -68,6 +69,18 @@ async def run(self) -> None:
         # Actually run the experiments
         metadata_columns = ["Experiment ID", "Algorithm ID", "Algorithm Details", "Environment", "Number of Queries"]
         with CSVLogger(metric_collectors, os.getenv("OUTPUT_DIR_PATH"), metadata_columns=metadata_columns) as logger:
+            # generate additional queries here (optional)
+            try:
+                log(f"Generating additional queries...")
+                environment = experiment_specs[0][1]
+                gen_model_id = self.config.data.additional_queries_model_id
+                llm = get_llm(model_id=gen_model_id, model_config=self.config.models)
+                queries = get_queries(environment, self.config.data)
+                generate_and_save_additional_queries(llm, queries)
+            except Exception as _:
+                log("Skipping additional query generation due to error.")
+
+            # generate queries here
             for i, spec in enumerate(experiment_specs):
                 algorithm, environment = spec
                 log(f"{'-' * 60}\nRunning Experiment {i+1} of {len(experiment_specs)}: {self._spec_to_str(spec)}...\n{'-' * 60}")
@@ -114,13 +127,15 @@ async def _run_experiment(self,
         Runs the specified experiment and returns the number of evaluated queries.
         """
         processed_queries_num = 0
+
         try:
             queries = await self._set_up_experiment(spec, metric_collectors, mcp_proxy_manager)
             algorithm, environment = spec
 
             try:
                 for i, query_spec in enumerate(queries):
                     log(f"Processing query #{query_spec.id} (Experiment {exp_index} of {total_exp_num}, query {i+1} of {len(queries)})...")
+
                     for mc in metric_collectors:
                         mc.prepare_for_measurement(query_spec)
 
@@ -195,29 +210,28 @@ async def _set_up_experiment(self,
                                  mcp_proxy_manager: MCPProxyManager,
                                  ) -> List[QuerySpecification]:
         algorithm, environment = spec
-
         log(f"Initializing LLM connection: {environment.model_id}")
-        llm = get_llm(model_id=environment.model_id, model_config=self.config.models)
         log("Connection established successfully.\n")
         log("Fetching queries for the current experiment...")
         queries = get_queries(environment, self.config.data)
         log(f"Successfully loaded {len(queries)} queries.\n")
         print_iterable_verbose("The following queries will be executed:\n", queries)
-        log(f"Generating additional queries.\n")
-        generate_and_save_additional_queries(llm, queries)
+        llm = get_llm(model_id=environment.model_id, model_config=self.config.models)
         queries = get_queries(environment, self.config.data)
         log("Retrieving tool definitions for the current experiment...")
         tool_specs = get_tools_from_queries(queries)
         tools = await mcp_proxy_manager.run_mcp_proxy(tool_specs, init_client=True).get_tools()
         print_iterable_verbose("The following tools will be available during evaluation:\n", tools)
         log(f"The experiment will proceed with {len(tools)} tool(s).\n")
-
         log("Setting up the algorithm and the metric collectors...")
-
-        algorithm.set_up(llm, tools, queries)
+        # Pass queries to algorithms that accept them; fall back for others
+        if algorithm.__module__ == "evaluator.algorithms.tool_rag_algorithm":
+            algorithm.set_up(llm, tools, tool_specs)
+        else:
+            algorithm.set_up(llm, tools)
         for mc in metric_collectors:
             mc.set_up()
-        log("All set!\n")
+        log("Setup complete!")
 
         return queries
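Net effect of the hunks above: additional-query generation now runs once in run(), using the model named by additional_queries_model_id, before any experiment starts, and _set_up_experiment only routes the merged tool_specs to algorithms that accept them. A stand-in sketch of that dispatch (the helper name dispatch_set_up is hypothetical; the module check mirrors the diff):

    def dispatch_set_up(algorithm, llm, tools, tool_specs):
        # Algorithms defined in evaluator/algorithms/tool_rag_algorithm.py take the
        # extra tool_specs argument; every other algorithm keeps the original call.
        if algorithm.__module__ == "evaluator.algorithms.tool_rag_algorithm":
            algorithm.set_up(llm, tools, tool_specs)
        else:
            algorithm.set_up(llm, tools)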
evaluator/metric_collectors/tool_selection_metric_collector.py

Lines changed: 5 additions & 8 deletions
@@ -16,10 +16,10 @@ class ToolSelectionMetricCollector(MetricCollector):
     def __init__(self, settings: Dict, model_config: List[ModelConfig]):
         super().__init__(settings, model_config)
 
-        self.total_queries = 0
-        self.exact_matches = 0
-        self.precision_sum = 0.0
-        self.recall_sum = 0.0
+        self.total_queries = None
+        self.exact_matches = None
+        self.precision_sum = None
+        self.recall_sum = None
 
     def get_collected_metrics_names(self) -> List[str]:
         return ["Exact Tool Selection Match Rate",
@@ -96,10 +96,7 @@ def report_results(self) -> Dict[str, Any] or None:
             raise RuntimeError("No measurements registered, cannot produce results.")
 
         results = {
-            "Exact Tool Selection Match Rate": (
-                (self.exact_matches or 0) / (self.total_queries or 1)
-                if self.total_queries else 0.0
-            ),
+            "Exact Tool Selection Match Rate": self.exact_matches / self.total_queries,
             "Tool Selection Precision": self.precision_sum / self.total_queries,
             "Tool Selection Recall": self.recall_sum / self.total_queries,
             "Spurious Tool Calling Rate": 1.0 - (self.precision_sum / self.total_queries),