-
Notifications
You must be signed in to change notification settings - Fork 3
Additional queries Embedding for Tool Rag #34
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
77c8cc4
26421e8
6d98d72
d309bf0
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -19,3 +19,6 @@ MILVUS_URL="" | |
|
|
||
| # miscellaneous | ||
| MCP_PROXY_LOCAL_PORT="" | ||
|
|
||
| # Additional examples | ||
| EXAMPLES_STORE_PATH="" | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -7,12 +7,15 @@ | |
| from typing import List, Dict, Any, Optional, Tuple | ||
| from pathlib import Path | ||
|
|
||
| from numpy import str_ | ||
| from pydantic import BaseModel, Field | ||
|
|
||
| from evaluator.components.llm_provider import query_llm, get_llm | ||
| from evaluator.config.schema import EnvironmentConfig, DatasetConfig | ||
| from evaluator.utils.file_downloader import fetch_remote_paths | ||
| from evaluator.utils.utils import log | ||
|
|
||
| from evaluator.config.config_io import load_config | ||
| from tqdm import tqdm | ||
| import re | ||
| ToolSet = Dict[str, Dict[str, Any]] | ||
|
|
||
|
|
||
|
|
@@ -27,6 +30,7 @@ class QuerySpecification(BaseModel): | |
| """ | ||
| id: int | ||
| query: str | ||
| examples: Optional[Dict[str, Any]] = None | ||
| reference_answer: Optional[str] = None | ||
| golden_tools: ToolSet = Field(default_factory=dict) | ||
| additional_tools: Optional[ToolSet] = None | ||
|
|
@@ -395,9 +399,303 @@ def get_queries( | |
| def get_tools_from_queries(queries: List[QuerySpecification]) -> ToolSet: | ||
| tools = {} | ||
|
|
||
| for query_spec in queries: | ||
| cfg_path = "evaluator/config/yaml/tool_rag_experiments.yaml" | ||
| cfg = load_config(cfg_path, use_defaults=True) | ||
| examples = cfg.data.generate_examples | ||
|
|
||
| model_id = cfg.data.additional_examples_model_id | ||
| # Base tools from the dataset | ||
| for query_spec in tqdm(queries, desc="Getting tools from queries"): | ||
| tools.update(query_spec.golden_tools) | ||
| if query_spec.additional_tools: | ||
| tools.update(query_spec.additional_tools) | ||
|
|
||
| #Getting or generating additional examples for tools that don't have them | ||
| if examples: | ||
| golden_tools = query_spec.golden_tools | ||
| for tool in golden_tools: | ||
| examples_exists = is_tool_in_additional_store(tool, query_spec.id) | ||
| if not examples_exists: | ||
| # TODO: get the model id from the config file, This doesnt work | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You're right - this doesn't work. Please fix by passing |
||
| llm = get_llm(model_id=model_id, model_config=cfg.models) | ||
| tools[tool]["examples"] = generate_and_save_examples(llm, tool, query_spec) | ||
| else: | ||
| aq = get_additional_query(query_spec.id) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The function should be renamed following the "additional_query" => "example" refactor |
||
| tools[tool]["examples"] = aq[tool] | ||
| return tools | ||
|
|
||
|
|
||
| def load_examples_store(path: str | None = None) -> List[Dict[str, Any]]: | ||
| """ | ||
| Load the centralized additional queries store. | ||
| Expected format: a JSON list of objects {"query_id": int, "examples": {...}}. | ||
| Returns an empty list if the file doesn't exist or cannot be parsed. | ||
| """ | ||
| try: | ||
| store_path = Path(path) if path else Path(os.getenv("EXAMPLES_STORE_PATH")) | ||
| if not store_path.exists(): | ||
| return [] | ||
| with store_path.open("r", encoding="utf-8") as f: | ||
| loaded = json.load(f) | ||
| return loaded if isinstance(loaded, list) else [] | ||
| except Exception: | ||
| return [] | ||
|
|
||
|
|
||
| def get_additional_query(query_id: int) -> Dict[str, Any] | None: | ||
ilya-kolchinsky marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| """ | ||
| Return ALL examples for the given query_id by merging entries | ||
| from data/examples.json (supports multiple records per query_id). | ||
| """ | ||
| store = load_examples_store() | ||
| if not store: | ||
| return None | ||
| for item in store: | ||
| next_query_id = item.get("query_id") | ||
| if next_query_id == query_id: | ||
| return item.get('examples') | ||
| return None | ||
|
|
||
|
|
||
| def get_examples_by_tool_name(tool_name: str) -> Dict[str, Any] | None: | ||
| """ | ||
| Look through data/examples.json (list of entries with an | ||
| "examples" mapping) and return the queries map for the specified | ||
| tool_name if present. Tries exact match first, then tolerant variants that | ||
| add/remove a trailing period. | ||
| """ | ||
| store = load_examples_store() | ||
|
|
||
| for item in store: | ||
| if not isinstance(item, dict): | ||
| continue | ||
| aq_map = item.get("examples") | ||
| if not isinstance(aq_map, dict): | ||
| continue | ||
| block = aq_map.get(tool_name) | ||
| if isinstance(block, dict): | ||
| return block | ||
| return None | ||
|
|
||
| def generate_and_save_examples(llm, tool_name, query_spec, store_path: Path | None = None): | ||
| """ | ||
| For each query in queries, use the provided LLM to generate examples if not present, | ||
| and save to the appropriate JSON file for that query (matching by query_id). | ||
| """ | ||
|
|
||
| # system_prompt = '''You create 5 additional queries for each tool and only return the additional queries information, given the query implemented, return in the following format as a JSON string: | ||
| # {tool_name: {"query1": "", "query2": "", "query3": "", "query4": "", "query5": ""}} ''' | ||
|
|
||
| system_prompt = '''You are a tool query generator. For each specified tool, create EXACTLY 5 concise, natural-language user queries suitable for invoking that tool. | ||
|
|
||
| Output requirements: | ||
| - Return ONLY a single JSON object (no explanations, no code fences). | ||
| - Shape: | ||
| { | ||
| "query1", | ||
| "query2", | ||
| "query3", | ||
| "query4", | ||
| "query5" | ||
| } | ||
| - Each query must be ≤ 20 words and phrased as a natural request (not SQL), ending with a question mark when appropriate. | ||
| - Preserve any placeholder tokens as given (e.g., {id}, tt1234567) without inventing new identifiers. | ||
| - Avoid near-duplicates; vary phrasing and sub-intents across the 5 queries. | ||
| - Use English. | ||
|
|
||
| Context you will receive: | ||
| - tool_name(s): a set of tool identifiers | ||
| - original user query: the initial task description | ||
| ''' | ||
|
|
||
| out_path = store_path or Path(os.getenv("EXAMPLES_STORE_PATH")) | ||
| out_path.parent.mkdir(parents=True, exist_ok=True) | ||
| if not out_path.exists(): | ||
| try: | ||
| with out_path.open('w', encoding='utf-8') as f: | ||
| json.dump([], f, indent=2, ensure_ascii=False) | ||
| except Exception as e: | ||
| log(f"error creating central_out_path: {e}") | ||
| pass | ||
|
|
||
| example = _generate_additional_query_for_tool( | ||
| llm, | ||
| system_prompt, | ||
| query_spec.query, | ||
| tool_name, | ||
| ) | ||
|
|
||
| examples = {} | ||
| examples[tool_name] = example | ||
| query_spec.examples = examples | ||
| # Save additional queries to centralized file | ||
|
|
||
| if examples is not None: | ||
| append_examples_entry(query_spec.id, query_spec.examples, out_path) | ||
| return examples | ||
|
|
||
| def _generate_additional_query_for_tool(llm, system_prompt: str, query_text: str, tool_name: str) -> Dict[str, Any] | None: | ||
| """ | ||
| Call the LLM to generate additional queries for a single tool, retrying until | ||
| a mapping with query1..query5 is produced or max attempts are reached. | ||
| Returns the parsed dict (queries map or tool->queries map) or None. | ||
| """ | ||
| correct_response = False | ||
| iteration = 0 | ||
| additional_query = None | ||
| while correct_response is False: | ||
| user_prompt = f"tool_name = {tool_name}, Query= {query_text}" | ||
| result = query_llm(llm, system_prompt, user_prompt) | ||
| model_id = str(getattr(llm, "model", "") or getattr(llm, "model_name", "") or "") | ||
| if "llama3.1:8b" in model_id: | ||
| additional_query = lama_model_parsing(result) | ||
| else: | ||
| additional_query = qwen_model_parsing(result) | ||
| correct_response = has_required_query_keys(additional_query) | ||
| iteration += 1 | ||
| if iteration > 10: | ||
| log(f"Failed to generate additional queries for tool {tool_name} after 5 iterations") | ||
| break | ||
| return additional_query | ||
|
|
||
| def lama_model_parsing(response: str): | ||
| """ | ||
| Parse the response from the Lama model and return the additional queries. | ||
| """ | ||
| if not response: | ||
| return None | ||
| text = response.strip() | ||
| quoted = re.findall(r'"([^"\\]*(?:\\.[^"\\]*)*)"', text) | ||
| if not quoted: | ||
| return None | ||
| return {f"query{i}": v for i, v in enumerate(quoted[:5], start=1)} | ||
|
|
||
| def qwen_model_parsing(response: str): | ||
| """ | ||
| Parse the response from the Qwen model and return the additional queries. | ||
| """ | ||
| # Remove markdown/code block wrappers if present | ||
| match = re.search(r"</think>\s*(.*)", response, re.DOTALL) | ||
| response_text = match.group(1).strip() if match else response | ||
| # Try to extract the 'examples' dict block | ||
| additional = None | ||
| response_text = response_text.strip() | ||
| try: | ||
| additional = json.loads(response_text) | ||
| except Exception as e: | ||
| additional = None | ||
| return additional | ||
|
|
||
|
|
||
| def has_required_query_keys(response: Any) -> bool: | ||
| """ | ||
| Return True iff the response contains all of the keys Query1..Query5 (case-insensitive) | ||
| under at least one mapping block. Accepts either a parsed dict or a JSON/string response | ||
| (optionally wrapped in ```json fences). | ||
| """ | ||
| required = {"query1", "query2", "query3", "query4", "query5"} | ||
|
|
||
| def _check_dict(d: Dict[str, Any]) -> bool: | ||
| if not isinstance(d, dict): | ||
| return False | ||
| # Case 1: top-level is the queries map | ||
| keys_lc = {k.lower() for k in d.keys() if isinstance(k, str)} | ||
| if required.issubset(keys_lc): | ||
| return True | ||
| # Case 2: top-level is tool_name -> queries_map | ||
| for _, v in d.items(): | ||
| if isinstance(v, dict): | ||
| inner_keys = {k.lower() for k in v.keys() if isinstance(k, str)} | ||
| if required.issubset(inner_keys): | ||
| return True | ||
| return False | ||
|
|
||
| # If it's already a dict, check directly | ||
| if isinstance(response, dict): | ||
| return _check_dict(response) | ||
|
|
||
| # If it's a string, strip fences and try JSON | ||
| if isinstance(response, str): | ||
| text = response.strip() | ||
| m = re.search(r"```(?:json)?\s*([\s\S]*?)```", text, re.IGNORECASE) | ||
| if m: | ||
| text = m.group(1).strip() | ||
| try: | ||
| obj = json.loads(text) | ||
| return _check_dict(obj) | ||
| except Exception: | ||
| # Heuristic fallback: ensure all tokens appear (weak check) | ||
| found = {tok for tok in required if re.search(fr"\b{tok}\b", text, flags=re.IGNORECASE)} | ||
| return required.issubset(found) | ||
|
|
||
| return False | ||
|
|
||
|
|
||
| def is_tool_in_additional_store(tool_name: str, query_id: int, store_path: Path | None = None) -> bool: | ||
| """ | ||
| Return True if any entry in the centralized store has examples containing this tool_name | ||
| (tolerates trailing-dot variants). | ||
| """ | ||
|
|
||
| try: | ||
| out_path = store_path or Path(os.getenv("EXAMPLES_STORE_PATH")) | ||
| if not out_path.exists(): | ||
| log(f"examples.json not found, creating empty file") | ||
| out_path.parent.mkdir(parents=True, exist_ok=True) | ||
| out_path.write_text("[]", encoding="utf-8") | ||
| return False | ||
|
|
||
| with out_path.open('r', encoding='utf-8') as f: | ||
| loaded = json.load(f) | ||
|
|
||
| for item in loaded: | ||
| next_query_id = item.get("query_id") | ||
| if next_query_id == query_id: | ||
| aq = item.get("examples") | ||
| if tool_name in aq: | ||
| return True | ||
| return False | ||
| except Exception: | ||
| return False | ||
|
|
||
|
|
||
| def append_examples_entry(query_id: int, examples: Dict[str, Any], store_path: Path | None = None) -> None: | ||
| """ | ||
| Append a new entry to data/examples.json in the list-of-dicts format: | ||
| { "query_id": <int>, "examples": <dict> } | ||
| """ | ||
| out_path = store_path or Path(os.getenv("EXAMPLES_STORE_PATH")) | ||
| store_list: List[Dict[str, Any]] = [] | ||
| try: | ||
| if out_path.exists(): | ||
| with out_path.open('r', encoding='utf-8') as f: | ||
| loaded = json.load(f) | ||
| if isinstance(loaded, list): | ||
| store_list = loaded | ||
| except Exception: | ||
| store_list = [] | ||
|
|
||
| # Upsert: if entry for query_id exists, merge/overwrite per tool; otherwise append new | ||
| idx = None | ||
| for i, item in enumerate(store_list): | ||
| if isinstance(item, dict) and item.get("query_id") == query_id: | ||
| idx = i | ||
| break | ||
|
|
||
| if idx is None: | ||
| store_list.append({ | ||
| "query_id": query_id, | ||
| "examples": examples or {}, | ||
| }) | ||
| else: | ||
| existing_block = store_list[idx].get("examples") | ||
| if not isinstance(existing_block, dict): | ||
| existing_block = {} | ||
| for tool_name, qmap in (examples or {}).items(): | ||
| if isinstance(qmap, dict): | ||
| existing_block[tool_name] = qmap | ||
| store_list[idx]["examples"] = existing_block | ||
| with out_path.open('w', encoding='utf-8') as f: | ||
| json.dump(store_list, f, indent=2, ensure_ascii=False) | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
As I asked before, please don't load the configuration here. Instead, pass
generate_examplesand the model ID as input toget_tools_from_queries