diff --git a/.gitignore b/.gitignore
index ac31eb41a..b9fe4ae08 100644
--- a/.gitignore
+++ b/.gitignore
@@ -11,7 +11,7 @@ tmp/
 **settings.json**
 evaluation/*tmp/
 evaluation/results
-evaluation/.env
+.env
 !evaluation/configs-example/*.json
 evaluation/configs/*
 **tree_textual_memory_locomo**
diff --git a/evaluation/data/personamem/.gitkeep b/evaluation/data/personamem/.gitkeep
new file mode 100644
index 000000000..e69de29bb
diff --git a/evaluation/scripts/hotpot/data_loader.py b/evaluation/scripts/hotpot/data_loader.py
new file mode 100644
index 000000000..871981036
--- /dev/null
+++ b/evaluation/scripts/hotpot/data_loader.py
@@ -0,0 +1,78 @@
+import json
+
+from pathlib import Path
+
+from datasets import load_dataset
+
+
+def load_hotpot_data(data_dir: Path | str) -> list[dict]:
+    """
+    Load HotpotQA dataset.
+    If dev_distractor_gold.json exists in data_dir, load it.
+    Otherwise, download from Hugging Face, convert to standard format, save, and load.
+    """
+    data_dir = Path(data_dir)
+    data_dir.mkdir(parents=True, exist_ok=True)
+    file_path = data_dir / "dev_distractor_gold.json"
+
+    if file_path.exists():
+        print(f"Loading local dataset from {file_path}")
+        try:
+            with open(file_path, encoding="utf-8") as f:
+                return json.load(f)
+        except Exception as e:
+            print(f"Failed to load local file: {e}. Re-downloading...")
+
+    print("Downloading HotpotQA dataset from Hugging Face...")
+    try:
+        dataset = load_dataset(
+            "hotpotqa/hotpot_qa", "distractor", split="validation", trust_remote_code=True
+        )
+    except Exception as e:
+        print(f"Failed to download dataset: {e}")
+        raise
+
+    print(f"Processing and saving dataset to {file_path}...")
+    items = []
+    for item in dataset:
+        # Convert HF format to Standard format
+        # ID
+        qid = item.get("id") or item.get("_id")
+
+        # Supporting Facts
+        sp = item.get("supporting_facts")
+        if isinstance(sp, dict):
+            sp_titles = sp.get("title", [])
+            sp_sent_ids = sp.get("sent_id", [])
+            sp_list = list(zip(sp_titles, sp_sent_ids, strict=False))
+        else:
+            sp_list = sp or []
+
+        # Context
+        ctx = item.get("context")
+        if isinstance(ctx, dict):
+            ctx_titles = ctx.get("title", [])
+            ctx_sentences = ctx.get("sentences", [])
+            ctx_list = list(zip(ctx_titles, ctx_sentences, strict=False))
+        else:
+            ctx_list = ctx or []
+
+        new_item = {
+            "_id": qid,
+            "question": item.get("question"),
+            "answer": item.get("answer"),
+            "supporting_facts": sp_list,
+            "context": ctx_list,
+            "type": item.get("type"),
+            "level": item.get("level"),
+        }
+        items.append(new_item)
+
+    try:
+        with open(file_path, "w", encoding="utf-8") as f:
+            json.dump(items, f, ensure_ascii=False, indent=2)
+        print(f"Saved {len(items)} items to {file_path}")
+    except Exception as e:
+        print(f"Failed to save dataset: {e}")
+
+    return items
diff --git a/evaluation/scripts/hotpot/hotpot_check_files.py b/evaluation/scripts/hotpot/hotpot_check_files.py
new file mode 100644
index 000000000..7d66957fb
--- /dev/null
+++ b/evaluation/scripts/hotpot/hotpot_check_files.py
@@ -0,0 +1,245 @@
+import argparse
+import json
+import os
+
+from pathlib import Path
+
+from dotenv import load_dotenv
+from tqdm import tqdm
+
+from evaluation.scripts.utils.client import MemosApiOnlineClient
+
+
+load_dotenv()
+memos_knowledgebase_id = os.getenv("MEMOS_KNOWLEDGEBASE_ID_HOTPOT")
+
+
+# Load user_id -> file_id mapping from added_records.json
+def _load_added_ids(records_path: Path) -> dict[str, str | None]:
+    if not records_path.exists():
+        return {}
+
+    try:
+        obj = json.loads(records_path.read_text(encoding="utf-8"))
+        added = obj.get("added") if isinstance(obj, dict) else None
+        if isinstance(added, dict):
+            return {str(k): (str(v) if v is not None else None) for k, v in added.items()}
+    except Exception:
+        return {}
+
+    return {}
+
+
+def _check_file_status(
+    client: MemosApiOnlineClient, file_ids: list[str], batch_size: int
+) -> dict[str, dict[str, str | None]]:
+    """
+    Phase 1: Query file processing status for given file_ids in batches.
+    Returns a dict: file_id -> {name, size, status}.
+    """
+    file_status: dict[str, dict[str, str | None]] = {}
+
+    for i in tqdm(range(0, len(file_ids), batch_size), desc="Checking files"):
+        batch = file_ids[i : i + batch_size]
+        try:
+            resp = client.check_file(batch)
+        except Exception as e:
+            print(f"[Check] error for batch starting at {i}: {e}")
+            continue
+
+        if not isinstance(resp, dict):
+            continue
+
+        data = resp.get("data") or {}
+        details = data.get("file_detail_list") or []
+
+        for item in details:
+            if not isinstance(item, dict):
+                continue
+            fid = item.get("id")
+            if not fid:
+                continue
+            file_status[str(fid)] = {
+                "name": item.get("name"),
+                "size": item.get("size"),
+                "status": item.get("status"),
+            }
+
+    return file_status
+
+
+def _reupload_failed_files(
+    client: MemosApiOnlineClient,
+    file_status: dict[str, dict[str, str | None]],
+    added_ids: dict[str, str | None],
+    url_prefix: str,
+) -> list[dict[str, str | None]]:
+    """
+    Phase 2: Re-upload files whose status is PROCESSING_FAILED.
+    The file URL is built from url_prefix and user_id as {url_prefix}/{user_id}_context.txt.
+    Returns a list of per-file results for auditing.
+    """
+    fid_to_user: dict[str, str] = {}
+    for uid, fid in added_ids.items():
+        if fid:
+            fid_to_user[str(fid)] = str(uid)
+
+    reupload_results: list[dict[str, str | None]] = []
+    failed_ids = [
+        fid for fid, info in file_status.items() if (info.get("status") == "PROCESSING_FAILED")
+    ]
+
+    for fid in tqdm(failed_ids, desc="Reuploading failed files"):
+        uid = fid_to_user.get(fid)
+        if not uid:
+            reupload_results.append(
+                {
+                    "old_file_id": fid,
+                    "user_id": None,
+                    "new_file_id": None,
+                    "ok": "false",
+                    "error": "user_id_not_found",
+                }
+            )
+            continue
+
+        file_url = f"{url_prefix.rstrip('/')}/{uid}_context.txt"
+        try:
+            resp = client.upload_file(memos_knowledgebase_id or "", file_url)
+            new_id = None
+            if isinstance(resp, dict):
+                data = resp.get("data") or {}
+                if isinstance(data, list) and data:
+                    first = data[0] if isinstance(data[0], dict) else {}
+                    new_id = str(first.get("id")) if first.get("id") else None
+            reupload_results.append(
+                {
+                    "old_file_id": fid,
+                    "user_id": uid,
+                    "new_file_id": new_id,
+                    "ok": "true",
+                    "error": None,
+                }
+            )
+        except Exception as e:
+            reupload_results.append(
+                {
+                    "old_file_id": fid,
+                    "user_id": uid,
+                    "new_file_id": None,
+                    "ok": "false",
+                    "error": str(e),
+                }
+            )
+
+    return reupload_results
+
+
+def main(argv: list[str] | None = None) -> None:
+    parser = argparse.ArgumentParser(description="Check HotpotQA memos-online file status.")
+    parser.add_argument(
+        "--lib",
+        type=str,
+        default="memos-online",
+    )
+    parser.add_argument("--version-dir", "-v", default=None, help="Version directory name")
+    parser.add_argument("--batch-size", type=int, default=50)
+    parser.add_argument("--output", type=str, default=None, help="Output JSON file path")
+    parser.add_argument(
+        "--url-prefix",
+        "-u",
+        default="https://memos-knowledge-base-file-pre.oss-cn-shanghai.aliyuncs.com/hotpot_text_files/",
+        help="File URL prefix",
+    )
+
+    args = parser.parse_args(argv)
+
+    if args.lib != "memos-online":
+        print(f"Only memos-online is supported, got 
lib={args.lib}") + return + + output_dir = Path("evaluation/data/hotpot") + if args.version_dir: + output_dir = output_dir / args.version_dir + output_dir.mkdir(parents=True, exist_ok=True) + + records_path = output_dir / f"{args.lib}_added_records.json" + print(f"[Check] loading records from {records_path}") + + added_ids = _load_added_ids(records_path) + file_ids = sorted({fid for fid in added_ids.values() if fid}) + print(f"[Check] total file ids: {len(file_ids)}") + + if not file_ids: + return + + client = MemosApiOnlineClient() + batch_size = max(1, args.batch_size) + + # Phase 1: Query file processing status + file_status = _check_file_status(client, file_ids, batch_size) + + # Phase 2: Re-upload files which failed processing + reupload_results = _reupload_failed_files( + client=client, + file_status=file_status, + added_ids=added_ids, + url_prefix=args.url_prefix, + ) + + # Persist: update original added_records.json with new file ids + if reupload_results: + try: + # Load original JSON object to preserve additional fields (e.g., perf) + obj: dict = {} + if records_path.exists(): + txt = records_path.read_text(encoding="utf-8") + if txt: + parsed = json.loads(txt) + if isinstance(parsed, dict): + obj = parsed + # Build updated 'added' dict + added_obj: dict[str, str | None] = {} + if isinstance(obj.get("added"), dict): + added_obj = { + str(k): (str(v) if v is not None else None) for k, v in obj["added"].items() + } + else: + added_obj = dict(added_ids) + # Apply new file ids from successful reuploads + for item in reupload_results: + if item.get("ok") == "true" and item.get("user_id") and item.get("new_file_id"): + uid = str(item["user_id"]) + added_obj[uid] = str(item["new_file_id"]) + obj["added"] = dict(sorted(added_obj.items())) + # Atomic write back to records_path + tmp_r = records_path.with_suffix(records_path.suffix + ".tmp") + tmp_r.write_text(json.dumps(obj, ensure_ascii=False, indent=2), encoding="utf-8") + os.replace(tmp_r, records_path) + print(f"[Update] updated added_records with new file ids -> {records_path}") + except Exception as e: + print(f"[Update] failed to update added_records: {e}") + + if args.output: + output_path = Path(args.output) + output_path.parent.mkdir(parents=True, exist_ok=True) + else: + output_path = output_dir / f"{args.lib}_file_status.json" + + result_obj = { + "lib": args.lib, + "version_dir": args.version_dir, + "total": len(file_ids), + "file_detail_list": [{"id": fid, **(file_status.get(fid) or {})} for fid in file_ids], + "reupload_results": reupload_results, + } + + tmp = output_path.with_suffix(output_path.suffix + ".tmp") + tmp.write_text(json.dumps(result_obj, ensure_ascii=False, indent=2), encoding="utf-8") + os.replace(tmp, output_path) + + print(f"[Check] saved file status for {len(file_status)} files to {output_path}") + + +if __name__ == "__main__": + main() diff --git a/evaluation/scripts/hotpot/hotpot_eval.py b/evaluation/scripts/hotpot/hotpot_eval.py new file mode 100644 index 000000000..a4f06e2d2 --- /dev/null +++ b/evaluation/scripts/hotpot/hotpot_eval.py @@ -0,0 +1,329 @@ +import argparse +import importlib.util +import json +import os +import time +from concurrent.futures import ThreadPoolExecutor, as_completed +from pathlib import Path +from typing import Any + +import pandas as pd +from dotenv import load_dotenv +from openai import OpenAI +from tqdm import tqdm + +from evaluation.scripts.hotpot.data_loader import load_hotpot_data +from evaluation.scripts.utils.extract_answer import ( + extract_answer, + 
parse_extracted_answer, +) +from evaluation.scripts.utils.metrics import Metrics +from evaluation.scripts.utils.prompts import HOTPOT_ANSWER_PROMPT + +load_dotenv() + +HOT_POT_DIR = Path("evaluation/data/hotpot") + + +def llm_response( + oai_client: OpenAI, + chat_model: str, + context: str, + question: str, +) -> str: + prompt = HOTPOT_ANSWER_PROMPT.format( + question=question, + context=context, + ) + resp = oai_client.chat.completions.create( + model=chat_model, + messages=[{"role": "system", "content": prompt}], + temperature=0, + ) + return resp.choices[0].message.content or "" + + +def _load_json_list(path: Path) -> list[dict[str, Any]]: + data = json.loads(path.read_text(encoding="utf-8")) + if isinstance(data, list): + return data + if isinstance(data, dict) and isinstance(data.get("results"), list): + return data["results"] + raise ValueError(f"Invalid json format: {path}") + + +def _save_pred( + pred_path: Path, + pred_answers: dict[str, str], + pred_sp: dict[str, list[Any]], + perf: dict[str, Any] | None = None, +) -> None: + pred_path.parent.mkdir(parents=True, exist_ok=True) + tmp_path = pred_path.with_suffix(pred_path.suffix + ".tmp") + + safe_pred_answers = { + k: v if isinstance(v, str) else "" if v is None else str(v) for k, v in pred_answers.items() + } + + obj: dict[str, Any] = { + "answer": safe_pred_answers, + "sp": pred_sp, + } + if perf is not None: + obj["perf"] = perf + + tmp_path.write_text( + json.dumps(obj, ensure_ascii=False, indent=2), + encoding="utf-8", + ) + os.replace(tmp_path, pred_path) + + +def run_eval(pred_path: Path, gold_path: Path) -> None: + spec = importlib.util.spec_from_file_location( + "hotpot_eval_v1", + "evaluation/scripts/hotpot/hotpot_evaluate_v1.py", + ) + if spec is None or spec.loader is None: + raise ImportError("Failed to load hotpot_evaluate_v1") + + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + + metrics: dict[str, Any] = module.eval( + str(pred_path), + str(gold_path), + ) + + # Save metrics back into pred json + try: + if pred_path.exists(): + current_data = json.loads( + pred_path.read_text(encoding="utf-8"), + ) + else: + current_data = {} + + if isinstance(current_data, list): + new_data: Any = [metrics, *current_data] + elif isinstance(current_data, dict): + new_data = metrics.copy() + for key, value in current_data.items(): + if key not in new_data: + new_data[key] = value + else: + new_data = metrics + + pred_path.write_text( + json.dumps(new_data, ensure_ascii=False, indent=2), + encoding="utf-8", + ) + except Exception as exc: + print(f"[Eval] Failed to save metrics to {pred_path}: {exc}") + + # Save metrics to xlsx + try: + xlsx_path = pred_path.with_name( + f"{pred_path.stem}_metrics.xlsx", + ) + + rows: list[dict[str, Any]] = [] + row = { + "category": "overall", + "question_number": metrics.get("count"), + "em": metrics.get("em"), + "f1": metrics.get("f1"), + "sp_em": metrics.get("sp_em"), + "sp_f1": metrics.get("sp_f1"), + "joint_em": metrics.get("joint_em"), + "joint_f1": metrics.get("joint_f1"), + } + + for key, value in metrics.items(): + if key not in row and key != "count": + row[key] = value + + rows.append(row) + + df = pd.DataFrame(rows) + preferred_cols = [ + "category", + "question_number", + "em", + "f1", + "sp_em", + "sp_f1", + "joint_em", + "joint_f1", + ] + remaining_cols = [c for c in df.columns if c not in preferred_cols] + df = df[preferred_cols + remaining_cols] + + df.to_excel(xlsx_path, index=False) + print(f"[Eval] Metrics xlsx saved to: {xlsx_path}") + except 
Exception as exc: + print(f"[Eval] Failed to save metrics xlsx: {exc}") + + +def evaluate_one( + oai_client: OpenAI, + row: dict[str, Any], + chat_model: str, +) -> tuple[str, str, list[Any]]: + qid = str(row.get("_id")) + question = row.get("question") or "" + context = row.get("context") or "" + sp_list = row.get("sp") or [] + + raw_answer = llm_response( + oai_client, + chat_model, + context=context, + question=question, + ) + extracted = extract_answer(question, raw_answer) + answer = parse_extracted_answer(extracted, raw_answer) + + return qid, answer, sp_list + + +def main(argv: list[str] | None = None) -> None: + parser = argparse.ArgumentParser( + description="HotpotQA evaluation (OpenAI only).", + ) + parser.add_argument( + "--lib", + default="memos", + choices=["memos", "mem0", "supermemory"], + ) + parser.add_argument("--workers", type=int, default=8) + parser.add_argument("--max_samples", type=int) + parser.add_argument("--version-dir", "-v") + parser.add_argument("--chat-model", required=True) + parser.add_argument("--search-mode", default="fine") + + args = parser.parse_args(argv) + + output_dir = HOT_POT_DIR / str(args.version_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + if args.lib == "memos": + search_path = output_dir / f"{args.lib}_{args.search_mode}_search_results.json" + pred_path = output_dir / f"{args.lib}_{args.search_mode}_search_eval_results.json" + else: + search_path = output_dir / f"{args.lib}_search_results.json" + pred_path = output_dir / f"{args.lib}_eval_results.json" + + gold_path = HOT_POT_DIR / "dev_distractor_gold.json" + + if not search_path.exists(): + raise FileNotFoundError(f"Search results not found: {search_path}") + + if not gold_path.exists(): + load_hotpot_data(str(HOT_POT_DIR)) + + pred_answers: dict[str, str] = {} + pred_sp: dict[str, list[Any]] = {} + + if pred_path.exists(): + try: + prev = json.loads(pred_path.read_text(encoding="utf-8")) + if isinstance(prev, dict): + pred_answers.update(prev.get("answer", {})) + pred_sp.update(prev.get("sp", {})) + except Exception as exc: + print(f"[Eval] Failed to load existing pred: {exc}") + + rows = _load_json_list(search_path) + if args.max_samples is not None: + rows = rows[: args.max_samples] + + pending = [row for row in rows if str(row.get("_id")) not in pred_answers] + + print( + f"[Eval] lib={args.lib} total={len(rows)} pending={len(pending)} workers={args.workers}", + ) + + if not pending: + run_eval(pred_path, gold_path) + return + + oai_client = OpenAI( + api_key=os.getenv("CHAT_MODEL_API_KEY"), + base_url=os.getenv("CHAT_MODEL_BASE_URL"), + ) + + metrics = Metrics() + start_time = time.time() + + with ThreadPoolExecutor(max_workers=args.workers) as executor: + + def do_eval(row: dict[str, Any]) -> tuple[str, str, list[Any]]: + start = time.perf_counter() + try: + result = evaluate_one( + oai_client, + row, + args.chat_model, + ) + metrics.record(time.perf_counter() - start, True) + return result + except Exception as exc: + metrics.record( + time.perf_counter() - start, + False, + str(exc), + ) + raise + + futures = [executor.submit(do_eval, row) for row in pending] + + for idx, future in enumerate( + tqdm(as_completed(futures), total=len(futures), desc="Evaluating"), + start=1, + ): + try: + qid, answer, sp_list = future.result() + pred_answers[qid] = answer + pred_sp[qid] = sp_list + if idx % 20 == 0: + _save_pred(pred_path, pred_answers, pred_sp) + except Exception as exc: + print(f"[Eval] Error: {exc}") + + total_duration = time.time() - start_time + summary = 
metrics.summary() + + perf_obj = { + "summary": summary, + "total_duration": total_duration, + "config": { + "workers": args.workers, + "chat_model": args.chat_model, + "lib": args.lib, + }, + } + + _save_pred(pred_path, pred_answers, pred_sp, perf=perf_obj) + run_eval(pred_path, gold_path) + + print("\n" + "=" * 60) + print("Evaluation finished!") + print("=" * 60) + print(f"Total duration: {total_duration:.2f}s") + print( + f"Success: {summary['counts']['success']} / Failed: {summary['counts']['failed']}", + ) + + if summary["errors"]: + print("\nTop errors:") + for error, count in sorted( + summary["errors"].items(), + key=lambda x: x[1], + reverse=True, + )[:5]: + print(f" [{count} times] {error[:100]}...") + + +if __name__ == "__main__": + main() diff --git a/evaluation/scripts/hotpot/hotpot_evaluate_v1.py b/evaluation/scripts/hotpot/hotpot_evaluate_v1.py new file mode 100644 index 000000000..7718f3b19 --- /dev/null +++ b/evaluation/scripts/hotpot/hotpot_evaluate_v1.py @@ -0,0 +1,157 @@ +import re +import string +import sys + +from collections import Counter + +import ujson as json + + +def normalize_answer(s): + def remove_articles(text): + return re.sub(r"\b(a|an|the)\b", " ", text) + + def white_space_fix(text): + return " ".join(text.split()) + + def remove_punc(text): + exclude = set(string.punctuation) + return "".join(ch for ch in text if ch not in exclude) + + def lower(text): + return text.lower() + + return white_space_fix(remove_articles(remove_punc(lower(s)))) + + +def f1_score(prediction, ground_truth): + normalized_prediction = normalize_answer(prediction) + normalized_ground_truth = normalize_answer(ground_truth) + + zero_metric = (0, 0, 0) + + if ( + normalized_prediction in ["yes", "no", "noanswer"] + and normalized_prediction != normalized_ground_truth + ): + return zero_metric + if ( + normalized_ground_truth in ["yes", "no", "noanswer"] + and normalized_prediction != normalized_ground_truth + ): + return zero_metric + + prediction_tokens = normalized_prediction.split() + ground_truth_tokens = normalized_ground_truth.split() + common = Counter(prediction_tokens) & Counter(ground_truth_tokens) + num_same = sum(common.values()) + if num_same == 0: + return zero_metric + precision = 1.0 * num_same / len(prediction_tokens) + recall = 1.0 * num_same / len(ground_truth_tokens) + f1 = (2 * precision * recall) / (precision + recall) + return f1, precision, recall + + +def exact_match_score(prediction, ground_truth): + return normalize_answer(prediction) == normalize_answer(ground_truth) + + +def update_answer(metrics, prediction, gold): + em = exact_match_score(prediction, gold) + f1, prec, recall = f1_score(prediction, gold) + metrics["em"] += float(em) + metrics["f1"] += f1 + metrics["prec"] += prec + metrics["recall"] += recall + return em, prec, recall + + +def update_sp(metrics, prediction, gold): + cur_sp_pred = set(map(tuple, prediction)) + gold_sp_pred = set(map(tuple, gold)) + tp, fp, fn = 0, 0, 0 + for e in cur_sp_pred: + if e in gold_sp_pred: + tp += 1 + else: + fp += 1 + for e in gold_sp_pred: + if e not in cur_sp_pred: + fn += 1 + prec = 1.0 * tp / (tp + fp) if tp + fp > 0 else 0.0 + recall = 1.0 * tp / (tp + fn) if tp + fn > 0 else 0.0 + f1 = 2 * prec * recall / (prec + recall) if prec + recall > 0 else 0.0 + em = 1.0 if fp + fn == 0 else 0.0 + metrics["sp_em"] += em + metrics["sp_f1"] += f1 + metrics["sp_prec"] += prec + metrics["sp_recall"] += recall + return em, prec, recall + + +def eval(prediction_file, gold_file): + with open(prediction_file) as f: + 
prediction = json.load(f) + with open(gold_file) as f: + gold = json.load(f) + + evaluated_ids = set((prediction.get("answer") or {}).keys()) + gold = [dp for dp in gold if (dp.get("_id") or dp.get("id")) in evaluated_ids] + + metrics = { + "em": 0, + "f1": 0, + "prec": 0, + "recall": 0, + "sp_em": 0, + "sp_f1": 0, + "sp_prec": 0, + "sp_recall": 0, + "joint_em": 0, + "joint_f1": 0, + "joint_prec": 0, + "joint_recall": 0, + } + for dp in gold: + cur_id = dp["_id"] + can_eval_joint = True + if cur_id not in prediction["answer"]: + can_eval_joint = False + else: + em, prec, recall = update_answer(metrics, prediction["answer"][cur_id], dp["answer"]) + if cur_id not in prediction["sp"]: + can_eval_joint = False + else: + sp_em, sp_prec, sp_recall = update_sp( + metrics, prediction["sp"][cur_id], dp["supporting_facts"] + ) + + if can_eval_joint: + joint_prec = prec * sp_prec + joint_recall = recall * sp_recall + if joint_prec + joint_recall > 0: + joint_f1 = 2 * joint_prec * joint_recall / (joint_prec + joint_recall) + else: + joint_f1 = 0.0 + joint_em = em * sp_em + + metrics["joint_em"] += joint_em + metrics["joint_f1"] += joint_f1 + metrics["joint_prec"] += joint_prec + metrics["joint_recall"] += joint_recall + + print("=========Eval Results===========") + n = len(gold) + if n > 0: + for k in metrics: + metrics[k] /= n + print(metrics) + else: + print(metrics) + metrics["count"] = n + return metrics + + +if __name__ == "__main__": + eval(sys.argv[1], sys.argv[2]) diff --git a/evaluation/scripts/hotpot/hotpot_ingestion.py b/evaluation/scripts/hotpot/hotpot_ingestion.py new file mode 100644 index 000000000..425fd6b30 --- /dev/null +++ b/evaluation/scripts/hotpot/hotpot_ingestion.py @@ -0,0 +1,299 @@ +import argparse +import json +import os +import time + +from concurrent.futures import ThreadPoolExecutor, as_completed +from pathlib import Path + +from dotenv import load_dotenv +from tqdm import tqdm + +from evaluation.scripts.hotpot.data_loader import load_hotpot_data +from evaluation.scripts.utils.metrics import Metrics + + +load_dotenv() +memos_knowledgebase_id = os.getenv("MEMOS_KNOWLEDGEBASE_ID_HOTPOT") + + +def retry_operation(func, *args, retries=5, delay=2, **kwargs): + for attempt in range(retries): + try: + return func(*args, **kwargs) + except Exception as e: + if attempt < retries - 1: + func_name = getattr(func, "__name__", "Operation") + print(f"[Retry] {func_name} failed: {e}. 
Retrying in {delay}s...") + time.sleep(delay) + delay *= 2 + else: + raise e + + +def _get_lib_client(lib: str): + if lib == "mem0": + from evaluation.scripts.utils.client import Mem0Client + + return Mem0Client(enable_graph=False) + if lib == "supermemory": + from evaluation.scripts.utils.client import SupermemoryClient + + return SupermemoryClient() + if lib == "memos-online": + from evaluation.scripts.utils.client import MemosApiOnlineClient + + return MemosApiOnlineClient() + from evaluation.scripts.utils.client import MemosApiClient + + return MemosApiClient() + + +def _load_added_ids(records_path: Path) -> dict[str, str | None]: + if not records_path.exists(): + return {} + + try: + obj = json.loads(records_path.read_text(encoding="utf-8")) + added = obj.get("added") if isinstance(obj, dict) else None + if isinstance(added, dict): + return {str(k): (str(v) if v is not None else None) for k, v in added.items()} + except Exception: + return {} + + return {} + + +def _save_added_ids( + records_path: Path, + added: dict[str, str | None], + perf: dict | None = None, +) -> None: + records_path.parent.mkdir(parents=True, exist_ok=True) + tmp = records_path.with_suffix(records_path.suffix + ".tmp") + + obj = {"added": dict(sorted(added.items()))} + if perf is not None: + obj["perf"] = perf + + tmp.write_text( + json.dumps(obj, ensure_ascii=False, indent=2), + encoding="utf-8", + ) + os.replace(tmp, records_path) + + +def _build_tasks(ctx: dict | list | None) -> list[str]: + tasks: list[str] = [] + for item in ctx: + if not isinstance(item, list | tuple) or len(item) != 2: + continue + title, sentences = item + if not isinstance(sentences, list): + continue + for idx, sentence in enumerate(sentences): + tasks.append( + json.dumps({"idx": idx, "title": title, "sentence": sentence}, ensure_ascii=False) + ) + return tasks + + +def _build_memory_texts(ctx: dict | list | None) -> list[str]: + texts: list[str] = [] + + if not ctx: + return texts + + for item in ctx: + if not isinstance(item, list | tuple) or len(item) != 2: + continue + + title, sentences = item + if not isinstance(sentences, list): + continue + + for sentence in sentences: + texts.append(f"{title}: {sentence}") + + return texts + + +def add_context_memories( + client, + lib: str, + user_id: str, + ctx: dict | list | None, + url_prefix: str, + mode: str = "fine", + async_mode: str = "sync", +) -> str | None: + tasks = _build_tasks(ctx) + if not tasks: + return None + + file_id = None + + if lib == "memos-online": + file_url = f"{url_prefix.rstrip('/')}/{user_id}_context.txt" + result = retry_operation( + client.upload_file, + memos_knowledgebase_id, + file_url, + ) + file_id = result["data"][0]["id"] + + if lib == "memos": + messages = [{"type": "text", "text": content} for content in tasks] + writable_cube_ids = [user_id] + retry_operation( + client.add, + messages=messages, + user_id=user_id, + writable_cube_ids=writable_cube_ids, + source_type="batch_import", + mode=mode, + async_mode=async_mode, + ) + + if lib == "mem0": + ts = int(time.time()) + messages = [{"role": "user", "content": content} for content in tasks] + retry_operation(client.add, messages=messages, user_id=user_id, timestamp=ts, batch_size=10) + + if lib == "supermemory": + for content in tasks: + retry_operation(client.add, content=content, user_id=user_id) + + return file_id + + +def main(argv: list[str] | None = None) -> None: + parser = argparse.ArgumentParser(description="HotpotQA ingestion (add only).") + parser.add_argument( + "--lib", + type=str, + 
default="memos", + ) + parser.add_argument("--workers", type=int, default=8) + parser.add_argument("--limit", type=int, default=None) + parser.add_argument("--version-dir", "-v", default=None, help="Version directory name") + parser.add_argument( + "--mode", default="fine", choices=["fine", "fast"], help="Processing mode (default: fine)" + ) + parser.add_argument( + "--async-mode", default="sync", choices=["sync", "async"], help="Async mode (default: sync)" + ) + parser.add_argument( + "--url-prefix", + "-u", + default="https://memos-knowledge-base-file-pre.oss-cn-shanghai.aliyuncs.com/hotpot_text_files/", + help="URL prefix to be prepended to filenames", + ) + + args = parser.parse_args(argv) + + print("=" * 60) + print("hotpotQA Product Add Concurrent Tool") + print("=" * 60) + + output_dir = Path("evaluation/data/hotpot") + if args.version_dir: + output_dir = output_dir / args.version_dir + output_dir.mkdir(parents=True, exist_ok=True) + + items_list = load_hotpot_data("evaluation/data/hotpot") + if args.limit is not None: + items_list = items_list[: args.limit] + + records_path = output_dir / f"{args.lib}_added_records.json" + added_ids: dict[str, str | None] = _load_added_ids(records_path) + + pending_items = [] + for it in items_list: + qid = it.get("_id") or it.get("id") + if str(qid) not in added_ids: + pending_items.append(it) + + print(f"[Add] lib={args.lib} total={len(items_list)} pending={len(pending_items)}") + if not pending_items: + return + + client = _get_lib_client(args.lib) + metrics = Metrics() + + def do_ingest(item): + start_time = time.perf_counter() + try: + qid = item.get("_id") or item.get("id") + ctx = item.get("context") + user_id = str(qid) + file_id = add_context_memories(client, args.lib, user_id, ctx, args.url_prefix) + + duration = time.perf_counter() - start_time + metrics.record(duration, True) + return str(qid), file_id + except Exception as e: + duration = time.perf_counter() - start_time + metrics.record(duration, False, str(e)) + raise e + + start_time = time.time() + with ThreadPoolExecutor(max_workers=args.workers) as executor: + futures = [executor.submit(do_ingest, it) for it in pending_items] + for _idx, f in enumerate(tqdm(as_completed(futures), total=len(futures), desc="Adding"), 1): + try: + sid, fid = f.result() + if sid: + if args.lib == "memos-online": + added_ids[sid] = str(fid) if fid else None + else: + added_ids.setdefault(sid, None) + + if len(added_ids) % 20 == 0: + _save_added_ids(records_path, added_ids) + + except Exception as e: + print(f"[Add] Error: {e}") + + _save_added_ids(records_path, added_ids) + print(f"[Add] saved records to {records_path}") + + total_duration = time.time() - start_time + summary = metrics.summary() + + _save_added_ids( + records_path, + added_ids, + perf={ + "summary": summary, + "total_duration": total_duration, + "config": { + "workers": args.workers, + "mode": args.mode, + "async_mode": args.async_mode, + "lib": args.lib, + }, + }, + ) + + print("\n" + "=" * 60) + print("Ingestion finished! 
Statistics:") + print("=" * 60) + print(f"Total duration: {total_duration:.2f}s") + print(f"Success: {summary['counts']['success']} / Failed: {summary['counts']['failed']}") + + if summary["stats"]: + stats = summary["stats"] + qps = stats["count"] / total_duration if total_duration > 0 else 0 + print(f"QPS: {qps:.2f}") + print("Latency stats (ms):") + print(f" Mean: {stats['mean']:.2f}") + print(f" Median: {stats['median']:.2f}") + print(f" Min: {stats['min']:.2f}") + print(f" Max: {stats['max']:.2f}") + print(f" P95: {stats['p95']:.2f}") + print(f" P99: {stats['p99']:.2f}") + + +if __name__ == "__main__": + main() diff --git a/evaluation/scripts/hotpot/hotpot_search.py b/evaluation/scripts/hotpot/hotpot_search.py new file mode 100644 index 000000000..6542d4539 --- /dev/null +++ b/evaluation/scripts/hotpot/hotpot_search.py @@ -0,0 +1,297 @@ +import argparse +import json +import os +import time +import traceback + +from concurrent.futures import ThreadPoolExecutor, as_completed +from pathlib import Path + +from tqdm import tqdm + +from evaluation.scripts.hotpot.data_loader import load_hotpot_data +from evaluation.scripts.utils.metrics import Metrics + + +def retry_operation(func, *args, retries=5, delay=2, **kwargs): + for attempt in range(retries): + try: + result = func(*args, **kwargs) + if isinstance(result, dict) and "data" in result: + return result["data"] + return result + except Exception as e: + if attempt < retries - 1: + func_name = getattr(func, "__name__", "Operation") + print(f"[Retry] {func_name} failed: {e}. Retrying in {delay}s...") + time.sleep(delay) + delay *= 2 + else: + raise e + + +def _get_lib_client(lib: str): + if lib == "mem0": + from evaluation.scripts.utils.client import Mem0Client + + return Mem0Client(enable_graph=False) + if lib == "supermemory": + from evaluation.scripts.utils.client import SupermemoryClient + + return SupermemoryClient() + from evaluation.scripts.utils.client import MemosApiClient + + return MemosApiClient() + + +def _load_existing_results(output_path: Path) -> tuple[list[dict], set[str]]: + if not output_path.exists(): + return [], set() + try: + data = json.loads(output_path.read_text(encoding="utf-8")) + if isinstance(data, list): + ids = {str(r.get("_id")) for r in data if r.get("_id")} + return data, ids + if isinstance(data, dict) and isinstance(data.get("results"), list): + rows = data.get("results") or [] + ids = {str(r.get("_id")) for r in rows if r.get("_id")} + return rows, ids + except Exception: + return [], set() + return [], set() + + +def _save_json_list(path: Path, rows: list[dict]) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + tmp = path.with_suffix(path.suffix + ".tmp") + tmp.write_text(json.dumps({"results": rows}, ensure_ascii=False, indent=2), encoding="utf-8") + os.replace(tmp, path) + + +def get_sources_info(sources): + seen = set() + dedup_sp = [] + mem_texts = [] + + for source in sources: + if isinstance(source, str): + try: + obj = json.loads(source) + except json.JSONDecodeError: + continue + + title = obj.get("title") + idx = obj.get("idx") + sentence = obj.get("sentence") + + if title is None or idx is None: + continue + + key = (title, idx) + if key not in seen: + seen.add(key) + dedup_sp.append([title, idx]) + mem_texts.append(sentence) + + return mem_texts, dedup_sp + + +def memos_search( + client, user_id: str, query: str, top_k: int, search_mode: str +) -> tuple[str, list[list[str | int]]]: + readable_cube_ids = [user_id] + results = retry_operation( + client.search, + query=query, + 
user_id=user_id,
+        readable_cube_ids=readable_cube_ids,
+        top_k=top_k,
+        mode=search_mode,
+    )
+    memories = results["text_mem"][0]["memories"]
+    mem_texts = [i["memory"] for i in memories]
+
+    sources = []
+    for m in memories:
+        source = (m.get("metadata", {}) or {}).get("sources") or []
+        for s in source:
+            source_txt = json.loads(s["content"])
+            sources.append(json.loads(source_txt)["content"])
+        sources.extend(source)
+
+    _, dedup_sp = get_sources_info(sources)
+    return mem_texts, dedup_sp
+
+
+def mem0_search(client, user_id: str, query: str, top_k: int) -> tuple[str, list[list[str | int]]]:
+    res = retry_operation(client.search, query, user_id, top_k)
+    sources = [m.get("memory", "") for m in res.get("results", []) if m.get("memory")]
+    mem_texts, dedup_sp = get_sources_info(sources)
+    return mem_texts, dedup_sp
+
+
+def supermemory_search(
+    client, user_id: str, query: str, top_k: int
+) -> tuple[str, list[list[str | int]]]:
+    sources = retry_operation(client.search, query, user_id, top_k)
+    mem_texts, dedup_sp = get_sources_info(sources)
+    return mem_texts, dedup_sp
+
+
+def search_one(
+    client, lib: str, item: dict, top_k: int, version_dir: str, search_mode: str
+) -> dict:
+    qid = item.get("_id") or item.get("id")
+    question = item.get("question") or ""
+    user_id = version_dir + "_" + str(qid)
+
+    if lib == "memos":
+        memories, sp_list = memos_search(client, user_id, str(question), top_k, search_mode)
+    elif lib == "mem0":
+        memories, sp_list = mem0_search(client, user_id, str(question), top_k)
+    elif lib == "supermemory":
+        memories, sp_list = supermemory_search(client, user_id, str(question), top_k)
+    else:
+        memories, sp_list = [], []
+
+    return {
+        "_id": str(qid),
+        "question": question,
+        "answer": item.get("answer"),
+        "memories": memories,
+        "sp": sp_list,
+    }
+
+
+def main(argv: list[str] | None = None) -> None:
+    parser = argparse.ArgumentParser(description="HotpotQA search (search only).")
+    parser.add_argument(
+        "--lib",
+        type=str,
+        default="memos",
+        choices=["memos", "mem0", "supermemory"],
+    )
+    parser.add_argument("--workers", type=int, default=8)
+    parser.add_argument("--top-k", type=int, default=7)
+    parser.add_argument(
+        "--limit", type=int, default=None, help="Limit number of samples (was max_samples)"
+    )
+    parser.add_argument("--version-dir", "-v", default=None, help="Version directory name")
+    parser.add_argument("--search-mode", default="fine", help="Search mode")
+
+    args = parser.parse_args(argv)
+
+    # --max_samples has been renamed to --limit
+    limit = args.limit
+
+    items_list = load_hotpot_data("evaluation/data/hotpot")
+    if limit is not None:
+        items_list = items_list[:limit]
+
+    output_dir = Path(f"evaluation/data/hotpot/{args.version_dir}")
+    output_dir.mkdir(parents=True, exist_ok=True)
+
+    if args.lib == "memos":
+        output_path = output_dir / f"{args.lib}_{args.search_mode}_search_results.json"
+    else:
+        output_path = output_dir / f"{args.lib}_search_results.json"
+
+    output_path.parent.mkdir(parents=True, exist_ok=True)
+
+    results, processed_ids = _load_existing_results(output_path)
+    pending_items = []
+    for it in items_list:
+        qid = it.get("_id") or it.get("id")
+        if str(qid) not in processed_ids:
+            pending_items.append(it)
+
+    print(
+        f"[Search] lib={args.lib} total={len(items_list)} pending={len(pending_items)} top_k={args.top_k}"
+    )
+    if not pending_items:
+        return
+
+    client = _get_lib_client(args.lib)
+    metrics = Metrics()
+    start_time = time.time()
+
+    with 
ThreadPoolExecutor(max_workers=args.workers) as executor: + + def do_search(item): + st = time.perf_counter() + try: + r = search_one( + client, args.lib, item, args.top_k, args.version_dir, args.search_mode + ) + dur = time.perf_counter() - st + metrics.record(dur, True) + return r + except Exception as e: + dur = time.perf_counter() - st + metrics.record(dur, False, str(e)) + raise e + + futures = [executor.submit(do_search, it) for it in pending_items] + for idx, f in enumerate( + tqdm(as_completed(futures), total=len(futures), desc="Searching"), 1 + ): + try: + r = f.result() + results.append(r) + if idx % 20 == 0: + _save_json_list(output_path, results) + except Exception as e: + print(f"[Search] Error: {e}") + traceback.print_exc() + + _save_json_list(output_path, results) + print(f"[Search] saved {len(results)} rows to {output_path}") + + # Save performance metrics + total_duration = time.time() - start_time + summary = metrics.summary() + # Merge perf into results json file + combined_obj = { + "results": results, + "perf": { + "summary": summary, + "total_duration": total_duration, + "config": { + "workers": args.workers, + "top_k": args.top_k, + "limit": limit, + "search_mode": args.search_mode, + "lib": args.lib, + }, + }, + } + tmp = output_path.with_suffix(output_path.suffix + ".tmp") + tmp.write_text(json.dumps(combined_obj, ensure_ascii=False, indent=2), encoding="utf-8") + os.replace(tmp, output_path) + + print("\n" + "=" * 60) + print("Search finished! Statistics:") + print("=" * 60) + print(f"Total duration: {total_duration:.2f}s") + print(f"Success: {summary['counts']['success']} / Failed: {summary['counts']['failed']}") + + if summary["stats"]: + stats = summary["stats"] + qps = stats["count"] / total_duration if total_duration > 0 else 0 + print(f"QPS: {qps:.2f}") + print("Latency stats (ms):") + print(f" Mean: {stats['mean']:.2f}") + print(f" Median: {stats['median']:.2f}") + print(f" Min: {stats['min']:.2f}") + print(f" Max: {stats['max']:.2f}") + print(f" P95: {stats['p95']:.2f}") + print(f" P99: {stats['p99']:.2f}") + + if summary["errors"]: + print("\nError stats:") + for error, count in sorted(summary["errors"].items(), key=lambda x: x[1], reverse=True)[:5]: + print(f" [{count} times] {error[:100]}...") + + +if __name__ == "__main__": + main() diff --git a/evaluation/scripts/hotpot/hotpot_update_records.py b/evaluation/scripts/hotpot/hotpot_update_records.py new file mode 100644 index 000000000..19d3b0491 --- /dev/null +++ b/evaluation/scripts/hotpot/hotpot_update_records.py @@ -0,0 +1,80 @@ +import argparse +import json +import os + +from pathlib import Path + + +def _read_json(path: Path) -> dict: + if not path.exists(): + return {} + try: + txt = path.read_text(encoding="utf-8") + if not txt: + return {} + obj = json.loads(txt) + return obj if isinstance(obj, dict) else {} + except Exception: + return {} + + +def _write_json_atomic(path: Path, obj: dict) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + tmp = path.with_suffix(path.suffix + ".tmp") + tmp.write_text(json.dumps(obj, ensure_ascii=False, indent=2), encoding="utf-8") + os.replace(tmp, path) + + +def main(argv: list[str] | None = None) -> None: + parser = argparse.ArgumentParser( + description="Update added_records.json from status JSON reupload results." 
+ ) + parser.add_argument( + "--status-json", + type=str, + default="evaluation/data/hotpot/test_0113_memos/memos-online_file_status.json", + ) + parser.add_argument( + "--records-json", + type=str, + default="evaluation/data/hotpot/test_0113_memos/memos-online_added_records.json", + ) + args = parser.parse_args(argv) + + status_path = Path(args.status_json) + records_path = Path(args.records_json) + + status_obj = _read_json(status_path) + records_obj = _read_json(records_path) + + added = {} + if isinstance(records_obj.get("added"), dict): + added = { + str(k): (str(v) if v is not None else None) for k, v in records_obj["added"].items() + } + + reupload_results = status_obj.get("reupload_results") or [] + updated_count = 0 + for item in reupload_results: + if not isinstance(item, dict): + continue + if item.get("ok") != "true": + continue + uid = item.get("user_id") + new_id = item.get("new_file_id") + if not uid or not new_id: + continue + uid = str(uid) + new_id = str(new_id) + if added.get(uid) != new_id: + added[uid] = new_id + updated_count += 1 + + records_obj["added"] = dict(sorted(added.items())) + _write_json_atomic(records_path, records_obj) + + print(f"Updated {updated_count} entries in {records_path}") + + +if __name__ == "__main__": + main() diff --git a/evaluation/scripts/long_bench-v2/__init__.py b/evaluation/scripts/long_bench-v2/__init__.py deleted file mode 100644 index 786c0ce03..000000000 --- a/evaluation/scripts/long_bench-v2/__init__.py +++ /dev/null @@ -1 +0,0 @@ -# LongBench v2 evaluation scripts diff --git a/evaluation/scripts/long_bench-v2/longbench_v2_ingestion.py b/evaluation/scripts/long_bench-v2/longbench_v2_ingestion.py deleted file mode 100644 index 5a5c11968..000000000 --- a/evaluation/scripts/long_bench-v2/longbench_v2_ingestion.py +++ /dev/null @@ -1,199 +0,0 @@ -import argparse -import json -import os -import sys -import threading - -from concurrent.futures import ThreadPoolExecutor, as_completed - -from dotenv import load_dotenv -from tqdm import tqdm - - -ROOT_DIR = os.path.dirname( - os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) -) -EVAL_SCRIPTS_DIR = os.path.join(ROOT_DIR, "evaluation", "scripts") - -sys.path.insert(0, ROOT_DIR) -sys.path.insert(0, EVAL_SCRIPTS_DIR) - - -def ingest_sample( - client, sample, sample_idx, frame, version, success_records, record_file, file_lock -): - """Ingest a single LongBench v2 sample as memories.""" - # Skip if already processed - if str(sample_idx) in success_records: - return True - - user_id = f"longbench_v2_{sample_idx}_{version}" - conv_id = f"longbench_v2_{sample_idx}_{version}" - - # Get context and convert to messages - context = sample.get("context", "") - - # For memos, we ingest the context as a raw document content - messages = [ - { - "type": "file", - "file": { - "file_data": context, - "file_id": str(sample_idx), - }, - } - ] - - if "memos-api" in frame: - try: - client.add(messages=messages, user_id=user_id, conv_id=conv_id, batch_size=1) - print(f"✅ [{frame}] Ingested sample {sample_idx}") - # Record successful ingestion (thread-safe) - with file_lock, open(record_file, "a") as f: - f.write(f"{sample_idx}\n") - f.flush() - return True - except Exception as e: - print(f"❌ [{frame}] Error ingesting sample {sample_idx}: {e}") - return False - - return False - - -def load_dataset_from_local(): - """Load LongBench v2 dataset from local JSON file.""" - data_dir = os.path.join( - os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))), - "data", - 
"long_bench_v2", - ) - - filepath = os.path.join(data_dir, "data.json") - - if not os.path.exists(filepath): - raise FileNotFoundError(f"Dataset file not found: {filepath}") - - # Load JSON file - with open(filepath, encoding="utf-8") as f: - samples = json.load(f) - - return samples - - -def main(frame, version="default", num_workers=10, max_samples=None): - """Main ingestion function.""" - load_dotenv() - - print("\n" + "=" * 80) - print(f"🚀 LONGBENCH V2 INGESTION - {frame.upper()} v{version}".center(80)) - print("=" * 80 + "\n") - - # Load dataset from local file - try: - dataset = load_dataset_from_local() - print(f"Loaded {len(dataset)} samples from LongBench v2") - except FileNotFoundError as e: - print(f"❌ Error loading dataset: {e}") - return - except Exception as e: - print(f"❌ Error loading dataset: {e}") - return - - # Limit samples if specified - if max_samples: - dataset = dataset[:max_samples] - print(f"Limited to {len(dataset)} samples") - - # Initialize checkpoint file for resume functionality - checkpoint_dir = os.path.join( - ROOT_DIR, "evaluation", "results", "long_bench_v2", f"{frame}-{version}" - ) - os.makedirs(checkpoint_dir, exist_ok=True) - record_file = os.path.join(checkpoint_dir, "success_records.txt") - - # Load existing success records for resume - success_records = set() - if os.path.exists(record_file): - with open(record_file) as f: - for line in f: - line = line.strip() - if line: - success_records.add(line) - print(f"📋 Found {len(success_records)} already processed samples (resume mode)") - else: - print("📋 Starting fresh ingestion (no checkpoint found)") - - # Initialize client - client = None - if frame == "memos-api": - from utils.client import MemosApiClient - - client = MemosApiClient() - else: - print(f"❌ Unsupported frame: {frame}") - return - - # Ingest samples - success_count = len(success_records) # Start with already processed count - file_lock = threading.Lock() # Lock for thread-safe file writing - with ThreadPoolExecutor(max_workers=num_workers) as executor: - futures = [] - for idx, sample in enumerate(dataset): - future = executor.submit( - ingest_sample, - client, - sample, - idx, - frame, - version, - success_records, - record_file, - file_lock, - ) - futures.append(future) - - for future in tqdm( - as_completed(futures), - total=len(futures), - desc="Ingesting LongBench v2", - ): - try: - if future.result(): - success_count += 1 - except Exception as e: - print(f"Error processing sample: {e}") - - print(f"\n{'=' * 80}") - print(f"✅ INGESTION COMPLETE: {success_count}/{len(dataset)} samples ingested".center(80)) - print(f"{'=' * 80}\n") - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument( - "--lib", - type=str, - choices=["memos-api", "memos-api-online"], - default="memos-api", - ) - parser.add_argument( - "--version", - type=str, - default="default", - help="Version identifier for saving results", - ) - parser.add_argument( - "--workers", - type=int, - default=2, - help="Number of parallel workers", - ) - parser.add_argument( - "--max_samples", - type=int, - default=None, - help="Maximum number of samples to process (default: all)", - ) - args = parser.parse_args() - - main(args.lib, args.version, args.workers, args.max_samples) diff --git a/evaluation/scripts/long_bench-v2/longbench_v2_metric.py b/evaluation/scripts/long_bench-v2/longbench_v2_metric.py deleted file mode 100644 index af324c9c7..000000000 --- a/evaluation/scripts/long_bench-v2/longbench_v2_metric.py +++ /dev/null @@ -1,176 +0,0 @@ 
-import argparse -import json -import os - - -def calculate_accuracy(responses): - """Calculate accuracy metrics for LongBench v2. - - Logic is aligned with longbench_stx.print_metrics, but returns a dict - and additionally computes by_domain statistics. - """ - total = len(responses) - if total == 0: - return {} - - # Counters (aligned with longbench_stx.print_metrics) - easy = hard = short = medium = long = 0 - easy_acc = hard_acc = short_acc = medium_acc = long_acc = 0 - total_prompt_tokens = 0 - - for pred in responses: - acc = int(pred.get("judge", False)) - diff = pred.get("difficulty", "easy") - length = pred.get("length", "short") - - pt = pred.get("prompt_tokens") - if isinstance(pt, int | float): - total_prompt_tokens += int(pt) - - if diff == "easy": - easy += 1 - easy_acc += acc - else: - hard += 1 - hard_acc += acc - - if length == "short": - short += 1 - short_acc += acc - elif length == "medium": - medium += 1 - medium_acc += acc - else: - long += 1 - long_acc += acc - - o_acc = round(100 * (easy_acc + hard_acc) / total, 2) - e_acc = round(100 * easy_acc / easy, 2) if easy > 0 else 0.0 - h_acc = round(100 * hard_acc / hard, 2) if hard > 0 else 0.0 - s_acc = round(100 * short_acc / short, 2) if short > 0 else 0.0 - m_acc = round(100 * medium_acc / medium, 2) if medium > 0 else 0.0 - l_acc = round(100 * long_acc / long, 2) if long > 0 else 0.0 - - # Additional by-domain stats (extra vs. stx) - domain_stats = {} - for r in responses: - domain = r.get("domain", "Unknown") - if domain not in domain_stats: - domain_stats[domain] = {"total": 0, "correct": 0} - domain_stats[domain]["total"] += 1 - if r.get("judge", False): - domain_stats[domain]["correct"] += 1 - - domain_acc = { - domain: round(100 * stats["correct"] / stats["total"], 2) - for domain, stats in domain_stats.items() - } - - return { - "overall": o_acc, - "easy": e_acc, - "hard": h_acc, - "short": s_acc, - "medium": m_acc, - "long": l_acc, - "by_domain": domain_acc, - "total_samples": total, - "correct_samples": easy_acc + hard_acc, - "total_prompt_tokens": total_prompt_tokens, - "avg_prompt_tokens": round(total_prompt_tokens / total, 2) if total > 0 else 0.0, - } - - -def main(frame, version="default"): - """Main metric calculation function.""" - print("\n" + "=" * 80) - print(f"📊 LONGBENCH V2 METRICS CALCULATION - {frame.upper()} v{version}".center(80)) - print("=" * 80 + "\n") - - # Load responses - responses_path = f"results/long_bench_v2/{frame}-{version}/{frame}_longbench_v2_responses.json" - if not os.path.exists(responses_path): - print(f"❌ Responses not found: {responses_path}") - print("Please run longbench_v2_responses.py first") - return - - with open(responses_path, encoding="utf-8") as f: - responses = json.load(f) - - # Only keep entries that actually have search results: - # - For new pipeline: non-empty memories_used list - # - For older runs: non-empty search_context string - def _has_search_results(r: dict) -> bool: - mems = r.get("memories_used") - if isinstance(mems, list) and any(str(m).strip() for m in mems): - return True - ctx = str(r.get("search_context", "")).strip() - return ctx != "" - - filtered = [r for r in responses if _has_search_results(r)] - - # Calculate metrics (handle case where no samples have search results) - if not filtered: - print("⚠️ No responses with valid search results were found. 
Metrics will be zeroed.") - metrics = { - "overall": 0.0, - "easy": 0.0, - "hard": 0.0, - "short": 0.0, - "medium": 0.0, - "long": 0.0, - "by_domain": {}, - "total_samples": 0, - "correct_samples": 0, - "total_prompt_tokens": 0, - "avg_prompt_tokens": 0.0, - } - else: - metrics = calculate_accuracy(filtered) - - # Save metrics - output_path = f"results/long_bench_v2/{frame}-{version}/{frame}_longbench_v2_metrics.json" - os.makedirs(os.path.dirname(output_path), exist_ok=True) - - with open(output_path, "w", encoding="utf-8") as f: - json.dump(metrics, f, ensure_ascii=False, indent=4) - - print(f"\n{'=' * 80}") - print(f"✅ METRICS CALCULATION COMPLETE: Results saved to {output_path}".center(80)) - print(f"{'=' * 80}\n") - - # Print summary table - print("\n📊 Summary of Results:") - print("-" * 80) - print(f"{'Overall Accuracy':<30s}: {metrics['overall']:.2f}%") - print(f"{'Easy':<30s}: {metrics['easy']:.2f}%") - print(f"{'Hard':<30s}: {metrics['hard']:.2f}%") - print(f"{'Short':<30s}: {metrics['short']:.2f}%") - print(f"{'Medium':<30s}: {metrics['medium']:.2f}%") - print(f"{'Long':<30s}: {metrics['long']:.2f}%") - print(f"{'Avg Prompt Tokens':<30s}: {metrics.get('avg_prompt_tokens', 0.0):.2f}") - print("\nBy Domain:") - for domain, acc in metrics["by_domain"].items(): - print(f" {domain:<28s}: {acc:.1f}%") - print(f"\nTotal Samples: {metrics['total_samples']}") - print(f"Correct: {metrics['correct_samples']}") - print("-" * 80) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument( - "--lib", - type=str, - choices=["memos-api", "memos-api-online"], - default="memos-api", - ) - parser.add_argument( - "--version", - type=str, - default="default", - help="Version identifier for loading results", - ) - args = parser.parse_args() - - main(args.lib, args.version) diff --git a/evaluation/scripts/long_bench-v2/longbench_v2_responses.py b/evaluation/scripts/long_bench-v2/longbench_v2_responses.py deleted file mode 100644 index 686062c5f..000000000 --- a/evaluation/scripts/long_bench-v2/longbench_v2_responses.py +++ /dev/null @@ -1,319 +0,0 @@ -import argparse -import json -import os -import re -import sys -import threading - -from concurrent.futures import ThreadPoolExecutor, as_completed -from time import time - -from dotenv import load_dotenv -from openai import OpenAI -from tqdm import tqdm - - -ROOT_DIR = os.path.dirname( - os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) -) -EVAL_SCRIPTS_DIR = os.path.join(ROOT_DIR, "evaluation", "scripts") - -sys.path.insert(0, ROOT_DIR) -sys.path.insert(0, EVAL_SCRIPTS_DIR) - - -# RAG-style prompt template aligned with longbench_stx.TEMPLATE_RAG -TEMPLATE_RAG = """Please read the following retrieved text chunks and answer the question below. - - -$DOC$ - - -What is the correct answer to this question: $Q$ -Choices: -(A) $C_A$ -(B) $C_B$ -(C) $C_C$ -(D) $C_D$ - -Format your response as follows: "The correct answer is (insert answer here)".""" - - -def extract_answer(response): - """Extract answer from response (A, B, C, or D). - - Logic is kept consistent with longbench_stx.extract_answer. 
- """ - response = response.replace("*", "") - # Try to find "The correct answer is (X)" pattern - match = re.search(r"The correct answer is \(([A-D])\)", response) - if match: - return match.group(1) - else: - match = re.search(r"The correct answer is ([A-D])", response) - if match: - return match.group(1) - return None - - -def llm_answer(llm_client, memories, question, choices): - """Generate response using RAG-style prompt, aligned with longbench_stx.llm_answer. - - Returns: - tuple[str, int | None]: (response_text, prompt_tokens) - """ - # Join memories to form the retrieved context document - doc_content = "\n\n".join([f"Retrieved chunk {idx + 1}: {m}" for idx, m in enumerate(memories)]) - - prompt = ( - TEMPLATE_RAG.replace("$DOC$", doc_content) - .replace("$Q$", question) - .replace("$C_A$", choices.get("A", "")) - .replace("$C_B$", choices.get("B", "")) - .replace("$C_C$", choices.get("C", "")) - .replace("$C_D$", choices.get("D", "")) - ) - - try: - response = llm_client.chat.completions.create( - model=os.getenv("CHAT_MODEL"), - messages=[{"role": "user", "content": prompt}], - temperature=0.1, - max_tokens=12800, - ) - text = response.choices[0].message.content or "" - prompt_tokens = None - usage = getattr(response, "usage", None) - if usage is not None: - # openai>=1.x style: usage.prompt_tokens - pt = getattr(usage, "prompt_tokens", None) - if isinstance(pt, int): - prompt_tokens = pt - else: - # fallback for dict-like usage - try: - prompt_tokens = int(usage.get("prompt_tokens")) # type: ignore[call-arg] - except Exception: - prompt_tokens = None - return text, prompt_tokens - except Exception as e: - print(f"Error generating response: {e}") - return "", None - - -def process_sample(search_result, llm_client, success_records, record_file, file_lock): - """Process a single sample: generate answer. - - This mirrors longbench_stx.evaluate_sample but consumes precomputed search results - produced by longbench_v2_search.py. - """ - # Use sample_idx when available, otherwise fall back to _id so that - # we can work with stx-style search results that only have _id. - sample_idx = search_result.get("sample_idx") - sample_key = str(sample_idx) if sample_idx is not None else str(search_result.get("_id", "")) - - # Skip if already processed - if sample_key and sample_key in success_records: - return None - - start = time() - - question = search_result.get("question", "") - choices = { - "A": search_result.get("choice_A", "") or "", - "B": search_result.get("choice_B", "") or "", - "C": search_result.get("choice_C", "") or "", - "D": search_result.get("choice_D", "") or "", - } - - # Prefer memories saved by longbench_v2_search; fall back to reconstructing - # from raw search_results if needed (for old search jsons). 
- memories = search_result.get("memories_used") - if memories is None: - raw = search_result.get("search_results") or {} - memories = [] - if isinstance(raw, dict) and raw.get("text_mem"): - text_mem = raw["text_mem"] - if text_mem and text_mem[0].get("memories"): - memories = [ - m.get("memory", "") for m in text_mem[0]["memories"] if isinstance(m, dict) - ] - - # Ensure we have a list, even if empty - memories = memories or [] - - # Skip if no retrieved memories and no question - if not question: - return None - if not memories: - return None - - # Generate answer - response, prompt_tokens = llm_answer(llm_client, memories, str(question), choices) - - # Extract answer (A, B, C, or D) - pred = extract_answer(response) - - response_duration_ms = (time() - start) * 1000 - - result = { - # Preserve sample_idx if present for backward compatibility - "sample_idx": search_result.get("sample_idx"), - "_id": search_result.get("_id"), - "domain": search_result.get("domain"), - "sub_domain": search_result.get("sub_domain"), - "difficulty": search_result.get("difficulty"), - "length": search_result.get("length"), - "question": question, - "choice_A": choices["A"], - "choice_B": choices["B"], - "choice_C": choices["C"], - "choice_D": choices["D"], - "answer": search_result.get("answer"), - "pred": pred, - "response": response, - "judge": pred == search_result.get("answer") if pred else False, - "prompt_tokens": prompt_tokens, - # Keep full retrieved memories list for inspection / debugging - "memories_used": memories, - # Preserve full search results payload (e.g., list of memories) - "search_results": search_result.get("search_results"), - "response_duration_ms": response_duration_ms, - "search_duration_ms": search_result.get("search_duration_ms", 0), - } - - # Record successful processing (thread-safe) - if sample_key: - with file_lock, open(record_file, "a") as f: - f.write(f"{sample_key}\n") - f.flush() - - return result - - -def main(frame, version="default", num_workers=10): - """Main response generation function.""" - load_dotenv() - - print("\n" + "=" * 80) - print(f"🚀 LONGBENCH V2 RESPONSE GENERATION - {frame.upper()} v{version}".center(80)) - print("=" * 80 + "\n") - - # Initialize checkpoint file for resume functionality - checkpoint_dir = os.path.join( - ROOT_DIR, "evaluation", "results", "long_bench_v2", f"{frame}-{version}" - ) - os.makedirs(checkpoint_dir, exist_ok=True) - record_file = os.path.join(checkpoint_dir, "response_success_records.txt") - search_path = os.path.join(checkpoint_dir, f"{frame}_longbench_v2_search_results.json") - output_path = os.path.join(checkpoint_dir, f"{frame}_longbench_v2_responses.json") - - # Load search results - if not os.path.exists(search_path): - print(f"❌ Search results not found: {search_path}") - print("Please run longbench_v2_search.py first") - return - - with open(search_path, encoding="utf-8") as f: - search_results = json.load(f) - - # Load existing results and success records for resume - existing_results: dict[str, dict] = {} - success_records: set[str] = set() - if os.path.exists(output_path): - with open(output_path, encoding="utf-8") as f: - existing_results_list = json.load(f) - for result in existing_results_list: - # Use sample_idx if present, otherwise _id as the unique key - sample_idx = result.get("sample_idx") - key = str(sample_idx) if sample_idx is not None else str(result.get("_id", "")) - if key: - existing_results[key] = result - success_records.add(key) - print(f"📋 Found {len(existing_results)} existing responses (resume 
mode)") - else: - print("📋 Starting fresh response generation (no checkpoint found)") - - # Load additional success records from checkpoint file - if os.path.exists(record_file): - with open(record_file) as f: - for line in f: - line = line.strip() - if line and line not in success_records: - success_records.add(line) - print(f"📋 Total {len(success_records)} samples already processed") - - # Initialize LLM client - llm_client = OpenAI( - api_key=os.getenv("CHAT_MODEL_API_KEY"), - base_url=os.getenv("CHAT_MODEL_BASE_URL"), - ) - print(f"🔌 Using OpenAI client with model: {os.getenv('CHAT_MODEL')}") - - # Process all samples concurrently using ThreadPoolExecutor - new_results = [] - file_lock = threading.Lock() # Lock for thread-safe file writing - with ThreadPoolExecutor(max_workers=num_workers) as executor: - futures = [ - executor.submit( - process_sample, sample, llm_client, success_records, record_file, file_lock - ) - for sample in search_results - ] - - for future in tqdm( - as_completed(futures), - total=len(futures), - desc="Generating responses", - ): - result = future.result() - if result: - new_results.append(result) - # Update existing results with new result (keyed by sample_idx or _id) - sample_idx = result.get("sample_idx") - key = str(sample_idx) if sample_idx is not None else str(result.get("_id", "")) - if key: - existing_results[key] = result - - # Merge and save all results - all_responses = list(existing_results.values()) - - # Sort by sample_idx when available, otherwise by _id for stability - def _sort_key(x: dict): - if x.get("sample_idx") is not None: - return ("0", int(x.get("sample_idx"))) - return ("1", str(x.get("_id", ""))) - - all_responses.sort(key=_sort_key) - - with open(output_path, "w", encoding="utf-8") as f: - json.dump(all_responses, f, ensure_ascii=False, indent=2) - - print(f"\n{'=' * 80}") - print(f"✅ RESPONSE GENERATION COMPLETE: Results saved to {output_path}".center(80)) - print(f"{'=' * 80}\n") - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument( - "--lib", - type=str, - choices=["memos-api", "memos-api-online"], - default="memos-api", - ) - parser.add_argument( - "--version", - type=str, - default="default", - help="Version identifier for loading results", - ) - parser.add_argument( - "--workers", - type=int, - default=10, - help="Number of parallel workers", - ) - args = parser.parse_args() - - main(args.lib, args.version, args.workers) diff --git a/evaluation/scripts/long_bench-v2/longbench_v2_search.py b/evaluation/scripts/long_bench-v2/longbench_v2_search.py deleted file mode 100644 index 2347e5d66..000000000 --- a/evaluation/scripts/long_bench-v2/longbench_v2_search.py +++ /dev/null @@ -1,273 +0,0 @@ -import argparse -import json -import os -import sys -import threading - -from concurrent.futures import ThreadPoolExecutor, as_completed -from time import time - -from dotenv import load_dotenv -from tqdm import tqdm - - -ROOT_DIR = os.path.dirname( - os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) -) -EVAL_SCRIPTS_DIR = os.path.join(ROOT_DIR, "evaluation", "scripts") - -sys.path.insert(0, ROOT_DIR) -sys.path.insert(0, EVAL_SCRIPTS_DIR) - - -def memos_api_search(client, query, user_id, top_k, frame): - """Search using memos API.""" - start = time() - search_results = client.search(query=query, user_id=user_id, top_k=top_k) - - # Extract raw memory texts in the same way as longbench_stx.memos_search - memories_texts: list[str] = [] - if ( - (frame == "memos-api" or frame == 
"memos-api-online") - and isinstance(search_results, dict) - and "text_mem" in search_results - ): - text_mem = search_results.get("text_mem") or [] - if text_mem and text_mem[0].get("memories"): - memories = text_mem[0]["memories"] - for m in memories: - if not isinstance(m, dict): - continue - # tags may be at top-level or inside metadata - tags = m.get("tags") or m.get("metadata", {}).get("tags") or [] - # Skip fast-mode memories - if any(isinstance(t, str) and "mode:fast" in t for t in tags): - continue - mem_text = m.get("memory", "") - if str(mem_text).strip(): - memories_texts.append(mem_text) - - duration_ms = (time() - start) * 1000 - return memories_texts, duration_ms, search_results - - -def process_sample( - client, sample, sample_idx, frame, version, top_k, success_records, record_file, file_lock -): - """Process a single sample: search for relevant memories.""" - # Skip if already processed - if str(sample_idx) in success_records: - return None - - user_id = f"longbench_v2_{sample_idx}_{version}" - query = sample.get("question", "") - - if not query: - return None - - memories_used, duration_ms, search_results = memos_api_search( - client, query, user_id, top_k, frame - ) - - if not (isinstance(memories_used, list) and any(str(m).strip() for m in memories_used)): - return None - - result = { - "sample_idx": sample_idx, - "_id": sample.get("_id"), - "domain": sample.get("domain"), - "sub_domain": sample.get("sub_domain"), - "difficulty": sample.get("difficulty"), - "length": sample.get("length"), - "question": query, - "choice_A": sample.get("choice_A"), - "choice_B": sample.get("choice_B"), - "choice_C": sample.get("choice_C"), - "choice_D": sample.get("choice_D"), - "answer": sample.get("answer"), - # Raw memories used for RAG answering (aligned with longbench_stx) - "memories_used": memories_used, - # Preserve full search results payload for debugging / analysis - "search_results": search_results, - "search_duration_ms": duration_ms, - } - - # Record successful processing (thread-safe) - with file_lock, open(record_file, "a") as f: - f.write(f"{sample_idx}\n") - f.flush() - - return result - - -def load_dataset_from_local(): - """Load LongBench v2 dataset from local JSON file.""" - data_dir = os.path.join( - os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))), - "data", - "long_bench_v2", - ) - - filepath = os.path.join(data_dir, "data.json") - - if not os.path.exists(filepath): - raise FileNotFoundError(f"Dataset file not found: {filepath}") - - # Load JSON file - with open(filepath, encoding="utf-8") as f: - samples = json.load(f) - - return samples - - -def main(frame, version="default", num_workers=10, top_k=20, max_samples=None): - """Main search function.""" - load_dotenv() - - print("\n" + "=" * 80) - print(f"🚀 LONGBENCH V2 SEARCH - {frame.upper()} v{version}".center(80)) - print("=" * 80 + "\n") - - # Load dataset from local file - try: - dataset = load_dataset_from_local() - print(f"Loaded {len(dataset)} samples from LongBench v2") - except FileNotFoundError as e: - print(f"❌ Error loading dataset: {e}") - return - except Exception as e: - print(f"❌ Error loading dataset: {e}") - return - - # Limit samples if specified - if max_samples: - dataset = dataset[:max_samples] - print(f"Limited to {len(dataset)} samples") - - # Initialize checkpoint file for resume functionality - checkpoint_dir = os.path.join( - ROOT_DIR, "evaluation", "results", "long_bench_v2", f"{frame}-{version}" - ) - os.makedirs(checkpoint_dir, exist_ok=True) - record_file 
= os.path.join(checkpoint_dir, "search_success_records.txt") - output_path = os.path.join(checkpoint_dir, f"{frame}_longbench_v2_search_results.json") - - # Load existing results and success records for resume - existing_results = {} - success_records = set() - if os.path.exists(output_path): - with open(output_path, encoding="utf-8") as f: - existing_results_list = json.load(f) - for result in existing_results_list: - sample_idx = result.get("sample_idx") - if sample_idx is not None: - existing_results[sample_idx] = result - success_records.add(str(sample_idx)) - print(f"📋 Found {len(existing_results)} existing search results (resume mode)") - else: - print("📋 Starting fresh search (no checkpoint found)") - - # Load additional success records from checkpoint file - if os.path.exists(record_file): - with open(record_file) as f: - for line in f: - line = line.strip() - if line and line not in success_records: - success_records.add(line) - print(f"📋 Total {len(success_records)} samples already processed") - - # Initialize client - client = None - if frame == "memos-api": - from utils.client import MemosApiClient - - client = MemosApiClient() - elif frame == "memos-api-online": - from utils.client import MemosApiOnlineClient - - client = MemosApiOnlineClient() - else: - print(f"❌ Unsupported frame: {frame}") - return - - # Process samples - new_results = [] - file_lock = threading.Lock() # Lock for thread-safe file writing - with ThreadPoolExecutor(max_workers=num_workers) as executor: - futures = [] - for idx, sample in enumerate(dataset): - future = executor.submit( - process_sample, - client, - sample, - idx, - frame, - version, - top_k, - success_records, - record_file, - file_lock, - ) - futures.append(future) - - for future in tqdm( - as_completed(futures), - total=len(futures), - desc="Searching LongBench v2", - ): - result = future.result() - if result: - new_results.append(result) - # Update existing results with new result - sample_idx = result.get("sample_idx") - if sample_idx is not None: - existing_results[sample_idx] = result - - # Merge and save all results - search_results = list(existing_results.values()) - # Sort by sample_idx to maintain order - search_results.sort(key=lambda x: x.get("sample_idx", 0)) - - with open(output_path, "w", encoding="utf-8") as f: - json.dump(search_results, f, ensure_ascii=False, indent=2) - - print(f"\n{'=' * 80}") - print(f"✅ SEARCH COMPLETE: Results saved to {output_path}".center(80)) - print(f"{'=' * 80}\n") - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument( - "--lib", - type=str, - choices=["memos-api", "memos-api-online"], - default="memos-api", - ) - parser.add_argument( - "--version", - type=str, - default="default", - help="Version identifier for saving results", - ) - parser.add_argument( - "--workers", - type=int, - default=1, - help="Number of parallel workers", - ) - parser.add_argument( - "--top_k", - type=int, - default=20, - help="Number of results to retrieve in search queries", - ) - parser.add_argument( - "--max_samples", - type=int, - default=None, - help="Maximum number of samples to process (default: all)", - ) - args = parser.parse_args() - - main(args.lib, args.version, args.workers, args.top_k, args.max_samples) diff --git a/evaluation/scripts/long_bench-v2/wait_scheduler.py b/evaluation/scripts/long_bench-v2/wait_scheduler.py deleted file mode 100644 index 716869a11..000000000 --- a/evaluation/scripts/long_bench-v2/wait_scheduler.py +++ /dev/null @@ -1,67 +0,0 @@ -import os 
-import time - -import requests - -from dotenv import load_dotenv - - -def wait_until_completed(params: dict, interval: float = 2.0, timeout: float = 600.0): - """ - Keep polling /product/scheduler/status until status == 'completed' (or terminal). - - params: dict passed as query params, e.g. {"user_id": "xxx"} or {"user_id": "xxx", "task_id": "..."} - interval: seconds between polls - timeout: max seconds to wait before raising TimeoutError - """ - load_dotenv() - base_url = os.getenv("MEMOS_URL") - if not base_url: - raise RuntimeError("MEMOS_URL not set in environment") - - url = f"{base_url}/product/scheduler/status" - start = time.time() - active_states = {"waiting", "pending", "in_progress"} - - while True: - resp = requests.get(url, params=params, timeout=10) - resp.raise_for_status() - data = resp.json() - - items = data.get("data", []) if isinstance(data, dict) else [] - statuses = [item.get("status") for item in items if isinstance(item, dict)] - status_set = set(statuses) - - # Print current status snapshot - print(f"Current status: {status_set or 'empty'}") - - # Completed if no active states remain - if not status_set or status_set.isdisjoint(active_states): - print("Task completed!") - return data - - if (time.time() - start) > timeout: - raise TimeoutError(f"Timeout after {timeout}s; last statuses={status_set or 'empty'}") - - time.sleep(interval) - - -if __name__ == "__main__": - import argparse - import json - - parser = argparse.ArgumentParser() - parser.add_argument( - "--user_id", default="longbench_v2_0_long-bench-v2-1208-2119-async", help="User ID to query" - ) - parser.add_argument("--task_id", help="Optional task_id to query") - parser.add_argument("--interval", type=float, default=2.0, help="Poll interval seconds") - parser.add_argument("--timeout", type=float, default=600.0, help="Timeout seconds") - args = parser.parse_args() - - params = {"user_id": args.user_id} - if args.task_id: - params["task_id"] = args.task_id - - result = wait_until_completed(params, interval=args.interval, timeout=args.timeout) - print(json.dumps(result, indent=2, ensure_ascii=False)) diff --git a/evaluation/scripts/longbench_v2/longbench_v2_check_files.py b/evaluation/scripts/longbench_v2/longbench_v2_check_files.py new file mode 100644 index 000000000..7cf209cce --- /dev/null +++ b/evaluation/scripts/longbench_v2/longbench_v2_check_files.py @@ -0,0 +1,214 @@ +import argparse +import json +import os + +from pathlib import Path + +from dotenv import load_dotenv +from tqdm import tqdm + +from evaluation.scripts.utils.client import MemosApiOnlineClient + + +load_dotenv() +# Knowledgebase ID used when re-uploading LongBench-v2 files +memos_knowledgebase_id = os.getenv("MEMOS_KNOWLEDGEBASE_ID_LONGBENCH_V2") + + +def _load_added_ids(records_path: Path) -> dict[str, str | None]: + """ + Load mapping from sample_id (version-prefixed user id) to file_id from add_results.json. + """ + if not records_path.exists(): + return {} + try: + obj = json.loads(records_path.read_text(encoding="utf-8")) + added = obj.get("added") if isinstance(obj, dict) else None + if isinstance(added, dict): + return {str(k): (str(v) if v is not None else None) for k, v in added.items()} + except Exception: + return {} + return {} + + +def _check_file_status( + client: MemosApiOnlineClient, file_ids: list[str], batch_size: int +) -> dict[str, dict[str, str | None]]: + """ + Phase 1: Query file processing status for all given file_ids in batches. + Returns file_id -> {name, size, status}. 
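+
+    Illustrative shape of the returned mapping (the ids, names, sizes and
+    statuses below are made up; real values come from the memos-online
+    check_file response, and "PROCESSING_FAILED" is the only status this
+    script acts on):
+
+        {
+            "file-id-1": {"name": "0a1b2c.txt", "size": "1024", "status": "SUCCESS"},
+            "file-id-2": {"name": "3d4e5f.txt", "size": "2048", "status": "PROCESSING_FAILED"},
+        }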
+ """ + file_status: dict[str, dict[str, str | None]] = {} + for i in tqdm(range(0, len(file_ids), batch_size), desc="Checking files"): + batch = file_ids[i : i + batch_size] + try: + resp = client.check_file(batch) + except Exception as e: + print(f"[Check] error for batch starting at {i}: {e}") + continue + if not isinstance(resp, dict): + continue + data = resp.get("data") or {} + details = data.get("file_detail_list") or [] + for item in details: + if not isinstance(item, dict): + continue + fid = item.get("id") + if not fid: + continue + file_status[str(fid)] = { + "name": item.get("name"), + "size": item.get("size"), + "status": item.get("status"), + } + return file_status + + +def _reupload_failed_files( + client: MemosApiOnlineClient, + file_status: dict[str, dict[str, str | None]], + added_ids: dict[str, str | None], + url_prefix: str, +) -> list[dict[str, str | None]]: + """ + Phase 2: Re-upload files whose status == PROCESSING_FAILED. + LongBench-v2 user_id is '_', so the file URL uses the last segment. + Returns a list of per-file reupload results for auditing. + """ + fid_to_user: dict[str, str] = {} + for uid, fid in added_ids.items(): + if fid: + fid_to_user[str(fid)] = str(uid) + reupload_results: list[dict[str, str | None]] = [] + failed_ids = [ + fid for fid, info in file_status.items() if (info.get("status") == "PROCESSING_FAILED") + ] + for fid in tqdm(failed_ids, desc="Reuploading failed files"): + uid = fid_to_user.get(fid) + if not uid: + reupload_results.append( + { + "old_file_id": fid, + "user_id": None, + "new_file_id": None, + "ok": "false", + "error": "user_id_not_found", + } + ) + continue + file_url = f"{url_prefix.rstrip('/')}/{uid.split('_')[-1]}.txt" + try: + resp = client.upload_file(memos_knowledgebase_id or "", file_url) + new_id = None + if isinstance(resp, dict): + data = resp.get("data") or {} + if isinstance(data, list) and data: + first = data[0] if isinstance(data[0], dict) else {} + new_id = str(first.get("id")) if first.get("id") else None + reupload_results.append( + { + "old_file_id": fid, + "user_id": uid, + "new_file_id": new_id, + "ok": "true", + "error": None, + } + ) + except Exception as e: + reupload_results.append( + { + "old_file_id": fid, + "user_id": uid, + "new_file_id": None, + "ok": "false", + "error": str(e), + } + ) + return reupload_results + + +def main(argv: list[str] | None = None) -> None: + """ + Orchestrate file status checking and failed-file reupload for LongBench-v2 memos-online runs. + """ + parser = argparse.ArgumentParser( + description="Check LongBench-v2 memos-online file status and reupload failed." 
+ ) + parser.add_argument("--lib", type=str, default="memos-online") + parser.add_argument("--version-dir", "-v", default=None) + parser.add_argument("--batch-size", type=int, default=50) + parser.add_argument( + "--url-prefix", + "-u", + default="https://memos-knowledge-base-file-pre.oss-cn-shanghai.aliyuncs.com/longbench_v2_text_files/", + ) + args = parser.parse_args(argv) + + if args.lib != "memos-online": + print(f"Only memos-online is supported, got lib={args.lib}") + return + + output_dir = Path("evaluation/data/longbench_v2") + if args.version_dir: + output_dir = output_dir / args.version_dir + output_dir.mkdir(parents=True, exist_ok=True) + + records_path = output_dir / f"{args.lib}_add_results.json" + print(f"[Check] loading records from {records_path}") + + added_ids = _load_added_ids(records_path) + file_ids = sorted({fid for fid in added_ids.values() if fid}) + print(f"[Check] total file ids: {len(file_ids)}") + if not file_ids: + return + + client = MemosApiOnlineClient() + batch_size = max(1, args.batch_size) + file_status = _check_file_status(client, file_ids, batch_size) + reupload_results = _reupload_failed_files(client, file_status, added_ids, args.url_prefix) + + # Update added records with new file ids + if reupload_results: + try: + obj: dict = {} + if records_path.exists(): + txt = records_path.read_text(encoding="utf-8") + if txt: + parsed = json.loads(txt) + if isinstance(parsed, dict): + obj = parsed + added_obj: dict[str, str | None] = {} + if isinstance(obj.get("added"), dict): + added_obj = { + str(k): (str(v) if v is not None else None) for k, v in obj["added"].items() + } + else: + added_obj = dict(added_ids) + for item in reupload_results: + if item.get("ok") == "true" and item.get("user_id") and item.get("new_file_id"): + added_obj[str(item["user_id"])] = str(item["new_file_id"]) + obj["added"] = dict(sorted(added_obj.items())) + tmp_r = records_path.with_suffix(records_path.suffix + ".tmp") + tmp_r.write_text(json.dumps(obj, ensure_ascii=False, indent=2), encoding="utf-8") + os.replace(tmp_r, records_path) + print(f"[Update] updated add_results with new file ids -> {records_path}") + except Exception as e: + print(f"[Update] failed to update add_results: {e}") + + output_path = output_dir / f"{args.lib}_file_status.json" + + result_obj = { + "lib": args.lib, + "version_dir": args.version_dir, + "total": len(file_ids), + "file_detail_list": [{"id": fid, **(file_status.get(fid) or {})} for fid in file_ids], + "reupload_results": reupload_results, + } + tmp = output_path.with_suffix(output_path.suffix + ".tmp") + tmp.write_text(json.dumps(result_obj, ensure_ascii=False, indent=2), encoding="utf-8") + os.replace(tmp, output_path) + print(f"[Check] saved file status for {len(file_status)} files to {output_path}") + + +if __name__ == "__main__": + main() diff --git a/evaluation/scripts/longbench_v2/longbench_v2_eval.py b/evaluation/scripts/longbench_v2/longbench_v2_eval.py new file mode 100644 index 000000000..808558e41 --- /dev/null +++ b/evaluation/scripts/longbench_v2/longbench_v2_eval.py @@ -0,0 +1,341 @@ +import argparse +import json +import os +import re +import time +import traceback + +from concurrent.futures import ThreadPoolExecutor, as_completed +from pathlib import Path + +import pandas as pd + +from dotenv import load_dotenv +from openai import OpenAI +from tqdm import tqdm + +from evaluation.scripts.utils.prompts import LONGBENCH_V2_ANSWER_PROMPT + + +load_dotenv() + + +def retry_operation(func, *args, retries=5, delay=2, **kwargs): + for attempt 
in range(retries): + try: + return func(*args, **kwargs) + except Exception as e: + traceback.print_exc() + if attempt < retries - 1: + func_name = getattr(func, "__name__", "Operation") + print(f"[Retry] {func_name} failed: {e}. Retrying in {delay}s...") + time.sleep(delay) + delay *= 2 + else: + raise e + + +def extract_answer(response: str) -> str | None: + response = response.replace("*", "") + match = re.search(r"The correct answer is \(([A-D])\)", response) + if match: + return match.group(1) + match = re.search(r"The correct answer is ([A-D])", response) + if match: + return match.group(1) + return None + + +def llm_answer( + oai_client, model_name, memories: list[str], question: str, choices: dict +) -> tuple[str, int]: + doc_content = "\n\n".join([f"Retrieved chunk {idx + 1}: {m}" for idx, m in enumerate(memories)]) + prompt = ( + LONGBENCH_V2_ANSWER_PROMPT.replace("$DOC$", doc_content) + .replace("$Q$", question) + .replace("$C_A$", choices.get("A", "")) + .replace("$C_B$", choices.get("B", "")) + .replace("$C_C$", choices.get("C", "")) + .replace("$C_D$", choices.get("D", "")) + ) + messages = [{"role": "user", "content": prompt}] + resp = retry_operation( + oai_client.chat.completions.create, + model=model_name, + messages=messages, + temperature=0.1, + max_tokens=12800, + ) + return resp.choices[0].message.content or "", resp.usage.prompt_tokens + + +def print_metrics(results: list[dict], duration: float) -> dict: + easy, hard, short, medium, long = 0, 0, 0, 0, 0 + easy_acc, hard_acc, short_acc, medium_acc, long_acc = 0, 0, 0, 0, 0 + total_tokens = 0 + + for pred in results: + acc = int(pred.get("judge", False)) + diff = pred.get("difficulty", "easy") + length = pred.get("length", "short") + tokens = pred.get("prompt_tokens", 0) + total_tokens += tokens + + if diff == "easy": + easy += 1 + easy_acc += acc + else: + hard += 1 + hard_acc += acc + + if length == "short": + short += 1 + short_acc += acc + elif length == "medium": + medium += 1 + medium_acc += acc + else: + long += 1 + long_acc += acc + + total = len(results) + if total == 0: + print("No results to calculate metrics.") + return { + "count": 0, + "overall_acc": 0, + "by_difficulty": {"easy": {"count": 0, "acc": 0}, "hard": {"count": 0, "acc": 0}}, + "by_length": { + "short": {"count": 0, "acc": 0}, + "medium": {"count": 0, "acc": 0}, + "long": {"count": 0, "acc": 0}, + }, + "avg_prompt_tokens": 0, + "total_duration_sec": round(duration, 2), + } + + o_acc = round(100 * (easy_acc + hard_acc) / total, 2) + e_acc = round(100 * easy_acc / easy, 2) if easy > 0 else 0 + h_acc = round(100 * hard_acc / hard, 2) if hard > 0 else 0 + s_acc = round(100 * short_acc / short, 2) if short > 0 else 0 + m_acc = round(100 * medium_acc / medium, 2) if medium > 0 else 0 + l_acc = round(100 * long_acc / long, 2) if long > 0 else 0 + avg_tokens = round(total_tokens / total, 2) + + print("\n" + "=" * 60) + print(f"{'Metric':<15} | {'Count':<10} | {'Accuracy (%)':<10}") + print("-" * 60) + print(f"{'Overall':<15} | {total:<10} | {o_acc:<10}") + print(f"{'Easy':<15} | {easy:<10} | {e_acc:<10}") + print(f"{'Hard':<15} | {hard:<10} | {h_acc:<10}") + print(f"{'Short':<15} | {short:<10} | {s_acc:<10}") + print(f"{'Medium':<15} | {medium:<10} | {m_acc:<10}") + print(f"{'Long':<15} | {long:<10} | {l_acc:<10}") + print("-" * 60) + print(f"{'Avg Tokens':<15} | {total:<10} | {avg_tokens:<10}") + print(f"Total Duration: {duration:.2f} seconds") + print("=" * 60 + "\n") + return { + "count": total, + "overall_acc": o_acc, + "by_difficulty": { + "easy": 
{"count": easy, "acc": e_acc}, + "hard": {"count": hard, "acc": h_acc}, + }, + "by_length": { + "short": {"count": short, "acc": s_acc}, + "medium": {"count": medium, "acc": m_acc}, + "long": {"count": long, "acc": l_acc}, + }, + "avg_prompt_tokens": avg_tokens, + "total_duration_sec": round(duration, 2), + } + + +def _load_json_list(path: Path) -> list[dict]: + data = json.loads(path.read_text(encoding="utf-8")) + if isinstance(data, list): + return data + if isinstance(data, dict) and isinstance(data.get("results"), list): + return data["results"] + raise ValueError(f"Invalid json format: {path}") + + +def _save_json_list(path: Path, rows: list[dict]) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + tmp = path.with_suffix(path.suffix + ".tmp") + tmp.write_text(json.dumps({"results": rows}, ensure_ascii=False, indent=2), encoding="utf-8") + os.replace(tmp, path) + + +def _save_metrics(path: Path, metrics: dict) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + obj = {"results": []} + if path.exists(): + try: + current = json.loads(path.read_text(encoding="utf-8")) + if isinstance(current, dict) and isinstance(current.get("results"), list): + obj["results"] = current["results"] + elif isinstance(current, list): + obj["results"] = current + except Exception: + pass + obj = {"metrics": metrics, **obj} + tmp = path.with_suffix(path.suffix + ".tmp") + tmp.write_text(json.dumps(obj, ensure_ascii=False, indent=2), encoding="utf-8") + os.replace(tmp, path) + + # Also save metrics to xlsx (rows=category, columns=metric) + xlsx_path = path.with_suffix(".xlsx") + rows: list[dict] = [] + + # Overall row + rows.append( + { + "category": "overall", + "question_number": metrics.get("count", 0), + "accuracy": metrics.get("overall_acc", 0), + "avg_prompt_tokens": metrics.get("avg_prompt_tokens", 0), + "total_duration_sec": metrics.get("total_duration_sec", 0), + } + ) + + # By difficulty + by_diff = metrics.get("by_difficulty") or {} + for name in ("easy", "hard"): + info = by_diff.get(name) or {} + rows.append( + { + "category": f"difficulty_{name}", + "question_number": info.get("count", 0), + "accuracy": info.get("acc", 0), + "avg_prompt_tokens": None, + "total_duration_sec": None, + } + ) + + # By length + by_len = metrics.get("by_length") or {} + for name in ("short", "medium", "long"): + info = by_len.get(name) or {} + rows.append( + { + "category": f"length_{name}", + "question_number": info.get("count", 0), + "accuracy": info.get("acc", 0), + "avg_prompt_tokens": None, + "total_duration_sec": None, + } + ) + + df = pd.DataFrame(rows) + # Reorder columns + cols = ["category", "question_number", "accuracy", "avg_prompt_tokens", "total_duration_sec"] + remaining = [c for c in df.columns if c not in cols] + df = df[cols + remaining] + + df.to_excel(xlsx_path, index=False) + + +def evaluate_one(oai_client, model_name, row: dict) -> dict: + question = row.get("question") or "" + choices = row.get("choices") or {} + memories = row.get("memories_used") or [] + response, prompt_tokens = llm_answer( + oai_client, model_name, list(memories), str(question), dict(choices) + ) + pred = extract_answer(response) + judge = pred == row.get("answer") + out = dict(row) + out["response"] = response + out["pred"] = pred + out["judge"] = judge + out["prompt_tokens"] = prompt_tokens + out.pop("memories_used") + return out + + +def main() -> None: + parser = argparse.ArgumentParser(description="LongBench-v2 eval Tool") + parser.add_argument( + "--lib", + "-b", + required=True, + help="Product name to 
evaluate", + ) + parser.add_argument("--workers", "-w", type=int, default=20, help="Number of parallel threads") + parser.add_argument( + "--top-k", "-k", type=int, default=20, help="Top k results to use (default: 20)" + ) + parser.add_argument("--version-dir", "-v", default=None, help="Version directory name") + parser.add_argument("--search_results_path", type=str, default=None) + parser.add_argument("--output_path", type=str, default=None) + parser.add_argument("--chat-model", type=str, default=None, help="Chat model for evaluation") + args = parser.parse_args() + + print("=" * 60) + print("LongBench-v2 Product Eval Tool") + print("=" * 60) + + start_time = time.time() + + output_dir = os.path.join("evaluation/data/longbench_v2", args.version_dir) + search_filename = f"{args.lib}_search_results.json" + search_path = Path(os.path.join(output_dir, search_filename)) + + if not search_path.exists(): + raise FileNotFoundError(f"Search results not found: {search_path}") + + search_rows = _load_json_list(search_path) + output_filename = f"{args.lib}_eval_results.json" + output_path = Path(os.path.join(output_dir, output_filename)) + + results: list[dict] = [] + processed_ids: set[str] = set() + + # Resume from checkpoint + if output_path.exists(): + try: + existing = _load_json_list(output_path) + results = existing + processed_ids = {str(r.get("_id")) for r in results if r.get("_id")} + print(f"Loaded {len(results)} existing results from checkpoint.") + except Exception as e: + print(f"Error loading checkpoint: {e}") + + pending = [r for r in search_rows if str(r.get("_id")) not in processed_ids] + print(f"[Eval] total={len(search_rows)} pending={len(pending)} workers={args.workers}") + if not pending: + metrics = print_metrics(results, time.time() - start_time) + _save_metrics(output_path, metrics) + return + + print("[Response model]: ", args.chat_model) + oai_client = OpenAI( + api_key=os.getenv("CHAT_MODEL_API_KEY"), base_url=os.getenv("CHAT_MODEL_BASE_URL") + ) + + with ThreadPoolExecutor(max_workers=args.workers) as executor: + futures = [ + executor.submit(evaluate_one, oai_client, args.chat_model, row) for row in pending + ] + for idx, f in enumerate( + tqdm(as_completed(futures), total=len(futures), desc="Evaluating"), start=1 + ): + try: + res = f.result() + results.append(res) + if idx % 10 == 0: + _save_json_list(output_path, results) + except Exception as e: + print(f"Evaluation Error: {e}") + traceback.print_exc() + + _save_json_list(output_path, results) + print(f"Saved {len(results)} results to {output_path}") + metrics = print_metrics(results, time.time() - start_time) + _save_metrics(output_path, metrics) + + +if __name__ == "__main__": + main() diff --git a/evaluation/scripts/longbench_v2/longbench_v2_ingestion.py b/evaluation/scripts/longbench_v2/longbench_v2_ingestion.py new file mode 100644 index 000000000..84d370077 --- /dev/null +++ b/evaluation/scripts/longbench_v2/longbench_v2_ingestion.py @@ -0,0 +1,338 @@ +import argparse +import json +import os +import time +import traceback + +from concurrent.futures import ThreadPoolExecutor, as_completed +from pathlib import Path + +from datasets import load_dataset +from dotenv import load_dotenv +from langchain_text_splitters import Language, RecursiveCharacterTextSplitter +from tqdm import tqdm + +from evaluation.scripts.utils.metrics import Metrics + + +load_dotenv() +fastgpt_dataset_id = os.getenv("FASTGPT_DATASET_ID_LONGBENCH_V2") +memos_knowledgebase_id = os.getenv("MEMOS_KNOWLEDGEBASE_ID_LONGBENCH_V2") + + +def 
retry_operation(func, *args, retries=5, delay=2, **kwargs): + for attempt in range(retries): + try: + return func(*args, **kwargs) + except Exception as e: + if attempt < retries - 1: + traceback.print_exc() + func_name = getattr(func, "__name__", "Operation") + print(f"[Retry] {func_name} failed: {e}. Retrying in {delay}s...") + time.sleep(delay) + delay *= 2 + else: + raise e + + +def _get_lib_client(lib: str): + if lib == "mem0": + from evaluation.scripts.utils.client import Mem0Client + + return Mem0Client(enable_graph=False) + if lib == "supermemory": + from evaluation.scripts.utils.client import SupermemoryClient + + return SupermemoryClient() + if lib == "fastgpt": + from evaluation.scripts.utils.client import FastGPTClient + + return FastGPTClient() + if lib == "memos-online": + from evaluation.scripts.utils.client import MemosApiOnlineClient + + return MemosApiOnlineClient() + if lib == "memos": + from evaluation.scripts.utils.client import MemosApiClient + + return MemosApiClient() + + +def _load_dataset_jsonl(dataset_path: Path) -> list[dict]: + if not dataset_path.exists(): + dataset = load_dataset("zai-org/LongBench-v2", split="train") + dataset_path.parent.mkdir(parents=True, exist_ok=True) + with open(dataset_path, "w", encoding="utf-8") as f: + for i in range(len(dataset)): + s = dataset[i] + row = { + "_id": s.get("_id") or s.get("id") or str(i), + "domain": s.get("domain"), + "sub_domain": s.get("sub_domain"), + "difficulty": s.get("difficulty"), + "length": s.get("length"), + "question": s.get("question"), + "choice_A": s.get("choice_A"), + "choice_B": s.get("choice_B"), + "choice_C": s.get("choice_C"), + "choice_D": s.get("choice_D"), + "answer": s.get("answer"), + "context": s.get("context") or s.get("document") or s.get("documents"), + } + f.write(json.dumps(row, ensure_ascii=False) + "\n") + print(f"Successfully saved dataset to {dataset_path}") + + samples: list[dict] = [] + with open(dataset_path, encoding="utf-8") as f: + for line in f: + line = line.strip() + if not line: + continue + samples.append(json.loads(line)) + return samples + + +def _load_added_ids(records_path: Path) -> dict[str, str | None]: + """ + Load added records as a mapping: sample_id -> file_id (or None). 
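+
+    Illustrative layout of the add_results.json this reads (sample ids and
+    file ids below are made up; the optional "perf" block written by
+    _save_added_ids is simply ignored when loading):
+
+        {
+            "added": {"66f1a2b3": "file-abc", "66f1a2b4": null},
+            "perf": {"summary": {}, "total_duration": 123.4, "config": {}}
+        }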
+    """
+    if not records_path.exists():
+        return {}
+
+    try:
+        obj = json.loads(records_path.read_text(encoding="utf-8"))
+        added = obj.get("added") if isinstance(obj, dict) else None
+        if isinstance(added, dict):
+            return {str(k): (str(v) if v is not None else None) for k, v in added.items()}
+    except Exception:
+        return {}
+
+    return {}
+
+
+def _save_added_ids(
+    records_path: Path,
+    added: dict[str, str | None],
+    perf: dict | None = None,
+) -> None:
+    records_path.parent.mkdir(parents=True, exist_ok=True)
+    tmp = records_path.with_suffix(records_path.suffix + ".tmp")
+
+    obj = {
+        "added": dict(sorted(added.items())),
+    }
+    if perf is not None:
+        obj["perf"] = perf
+
+    tmp.write_text(
+        json.dumps(obj, ensure_ascii=False, indent=2),
+        encoding="utf-8",
+    )
+    os.replace(tmp, records_path)
+
+
+def ingest_context(
+    client,
+    sample: dict,
+    lib: str,
+    url_prefix: str,
+    mode: str = "fine",
+    async_mode: str = "sync",
+    version_dir: str | None = None,
+) -> tuple[str, str]:
+    sample_id = str(sample.get("_id"))
+    user_id = version_dir + "_" + sample_id
+    context = sample.get("context") or ""
+    ts = int(time.time())
+    file_url = f"{url_prefix.rstrip('/')}/{sample_id}.txt"  # URL prefix + filename
+
+    file_id = ""
+    if lib == "memos" or lib == "memos-online":
+        result = retry_operation(client.upload_file, memos_knowledgebase_id, file_url)
+        file_id = result["data"][0]["id"]
+    if lib == "fastgpt":
+        result = retry_operation(
+            client.upload_file, datasetId=fastgpt_dataset_id, file_url=file_url
+        )
+        file_id = result["data"]["collectionId"]
+    if lib == "mem0":
+        chunker = RecursiveCharacterTextSplitter.from_language(
+            language=Language.PYTHON, chunk_size=2048, chunk_overlap=128
+        )
+        chunks = [p for p in chunker.split_text(context or "") if p.strip()]
+
+        messages = [{"role": "user", "content": p} for p in chunks]
+        retry_operation(client.add, messages=messages, user_id=user_id, timestamp=ts, batch_size=10)
+
+    if lib == "supermemory":
+        retry_operation(client.add, content=context, user_id=user_id)
+
+    return sample_id, file_id
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(
+        description="LongBench-v2 Product Add Concurrent Script",
+        formatter_class=argparse.RawDescriptionHelpFormatter,
+    )
+    parser.add_argument("--lib", "-b", required=True, help="Product name to evaluate")
+
+    parser.add_argument(
+        "--api-url",
+        default="http://127.0.0.1:8001",
+        help="MemOS API URL (default: http://127.0.0.1:8001)",
+    )
+
+    parser.add_argument("--workers", "-w", type=int, default=5, help="Concurrency (default: 5)")
+
+    parser.add_argument(
+        "--timeout", type=float, default=1200, help="Request timeout in seconds (default: 1200)"
+    )
+
+    parser.add_argument(
+        "--mode", default="fine", choices=["fine", "fast"], help="Processing mode (default: fine)"
+    )
+
+    parser.add_argument(
+        "--async-mode", default="sync", choices=["sync", "async"], help="Async mode (default: sync)"
+    )
+
+    parser.add_argument("--version-dir", "-v", default=None, help="Version directory name")
+
+    parser.add_argument(
+        "--limit",
+        "-l",
+        type=int,
+        default=None,
+        help="Limit number of samples to process (for testing, default all)",
+    )
+
+    parser.add_argument(
+        "--url-prefix",
+        "-u",
+        default="https://memos-knowledge-base-file-pre.oss-cn-shanghai.aliyuncs.com/longbench_v2_text_files/",
+        help="URL prefix to be prepended to filenames",
+    )
+
+    parser.add_argument(
+        "--dataset_path",
+        "-p",
+        default="evaluation/data/longbench_v2/longbenchv2_train.json",
+        help="Dataset path",
+    )
+
+    return parser.parse_args()
+
+
+def 
main() -> None: + args = parse_args() + print("=" * 60) + print("LongBench-v2 Product Add Concurrent Tool") + print("=" * 60) + + dataset_path = Path(args.dataset_path) + dataset = _load_dataset_jsonl(dataset_path) + if args.limit is not None: + dataset = dataset[: args.limit] + + version_output_dir = os.path.join("evaluation/data/longbench_v2", args.version_dir) + os.makedirs(version_output_dir, exist_ok=True) + output_path = os.path.join(version_output_dir, f"{args.lib}_add_results.json") + output_path = Path(output_path) + + added_ids: dict[str, str | None] = _load_added_ids(output_path) + pending = [s for s in dataset if str(s.get("_id")) not in added_ids] + print( + f"[Add] lib={args.lib} total={len(dataset)} pending={len(pending)} workers={args.workers}" + ) + if not pending: + return + + client = _get_lib_client(args.lib) + metrics = Metrics() + + def do_ingest(sample): + start_time = time.perf_counter() + try: + sample_id, file_id = ingest_context( + client, + sample, + args.lib, + args.url_prefix, + args.mode, + args.async_mode, + args.version_dir, + ) + duration = time.perf_counter() - start_time + metrics.record(duration, True) + return sample_id, file_id + except Exception as e: + traceback.print_exc() + duration = time.perf_counter() - start_time + metrics.record(duration, False, str(e)) + raise e + + start_time = time.time() + with ThreadPoolExecutor(max_workers=args.workers) as executor: + futures = [executor.submit(do_ingest, sample) for sample in pending] + for f in tqdm(as_completed(futures), total=len(futures), desc="Adding"): + try: + sid, fid = f.result() + if sid: + sid = str(sid) + added_ids[sid] = str(fid) if fid else None + if len(added_ids) % 10 == 0: + _save_added_ids(output_path, added_ids) + + except Exception as e: + print(f"[Add] Error: {e}") + traceback.print_exc() + + _save_added_ids(output_path, added_ids) + print(f"[Add] saved records to {output_path}") + + total_duration = time.time() - start_time + + summary = metrics.summary() + + _save_added_ids( + output_path, + added_ids, + perf={ + "summary": summary, + "total_duration": total_duration, + "config": { + "workers": args.workers, + "mode": args.mode, + "async_mode": args.async_mode, + "dataset_path": args.dataset_path, + }, + }, + ) + + print("\n" + "=" * 60) + print("Ingestion finished! 
Statistics:") + print("=" * 60) + print(f"Total duration: {total_duration:.2f}s") + print(f"Success: {summary['counts']['success']} / Failed: {summary['counts']['failed']}") + + if summary["stats"]: + stats = summary["stats"] + qps = stats["count"] / total_duration if total_duration > 0 else 0 + print(f"QPS: {qps:.2f}") + print("Latency stats (ms):") + print(f" Mean: {stats['mean']:.2f}") + print(f" Median: {stats['median']:.2f}") + print(f" Min: {stats['min']:.2f}") + print(f" Max: {stats['max']:.2f}") + print(f" P95: {stats['p95']:.2f}") + print(f" P99: {stats['p99']:.2f}") + + if summary["errors"]: + print("\nError stats:") + for error, count in sorted(summary["errors"].items(), key=lambda x: x[1], reverse=True)[:5]: + print(f" [{count} times] {error[:100]}...") + + +if __name__ == "__main__": + main() diff --git a/evaluation/scripts/longbench_v2/longbench_v2_search.py b/evaluation/scripts/longbench_v2/longbench_v2_search.py new file mode 100644 index 000000000..23a2a1fc9 --- /dev/null +++ b/evaluation/scripts/longbench_v2/longbench_v2_search.py @@ -0,0 +1,331 @@ +import argparse +import json +import os +import time +import traceback + +from concurrent.futures import ThreadPoolExecutor, as_completed +from pathlib import Path + +from dotenv import load_dotenv +from tqdm import tqdm + +from evaluation.scripts.utils.metrics import Metrics + + +load_dotenv() +fastgpt_dataset_id = os.getenv("FASTGPT_DATASET_ID_LONGBENCH_V2") +memos_knowledgebase_id = os.getenv("MEMOS_KNOWLEDGEBASE_ID_LONGBENCH_V2") + + +def retry_operation(func, *args, retries=5, delay=2, **kwargs): + for attempt in range(retries): + try: + result = func(*args, **kwargs) + if isinstance(result, dict) and "data" in result: + return result["data"] + return result + except Exception as e: + if attempt < retries - 1: + func_name = getattr(func, "__name__", "Operation") + print(f"[Retry] {func_name} failed: {e}. 
Retrying in {delay}s...") + time.sleep(delay) + delay *= 2 + else: + raise e + + +def _get_lib_client(lib: str): + if lib == "mem0": + from evaluation.scripts.utils.client import Mem0Client + + return Mem0Client(enable_graph=False) + if lib == "supermemory": + from evaluation.scripts.utils.client import SupermemoryClient + + return SupermemoryClient() + if lib == "fastgpt": + from evaluation.scripts.utils.client import FastGPTClient + + return FastGPTClient() + if lib == "memos-online": + from evaluation.scripts.utils.client import MemosApiOnlineClient + + return MemosApiOnlineClient() + + +def _load_dataset_jsonl(dataset_path: Path) -> list[dict]: + samples: list[dict] = [] + with open(dataset_path, encoding="utf-8") as f: + for line in f: + line = line.strip() + if not line: + continue + samples.append(json.loads(line)) + return samples + + +def memos_search(client, user_id: str, query: str, top_k: int, search_mode: str) -> list[str]: + readable_cube_ids = [user_id] + results = retry_operation( + client.search, + query=query, + user_id=user_id, + top_k=top_k, + readable_cube_ids=readable_cube_ids, + mode=search_mode, + ) + memories = results["text_mem"][0]["memories"] + return [m["memory"] for m in memories] + + +def memos_online_search(client, user_id: str, query: str, top_k: int, mode: str) -> list[str]: + results = client.search( + query=query, + user_id=user_id, + top_k=top_k, + mode=mode, + knowledgebase_ids=[memos_knowledgebase_id], + ) + if "memory_detail_list" in results["data"] and results["data"]["memory_detail_list"]: + memories = results["data"]["memory_detail_list"] + return [m.get("memory_value", "") for m in memories] + return [] + + +def mem0_search(client, user_id: str, query: str, top_k: int) -> list[str]: + res = retry_operation(client.search, query, user_id, top_k) + results = res.get("results", []) + return [m.get("memory", "") for m in results if m.get("memory")] + + +def supermemory_search(client, user_id: str, query: str, top_k: int) -> list[str]: + return retry_operation(client.search, query, user_id, top_k) + + +def fastgpt_search(client, query: str, top_k: int) -> list[str]: + return retry_operation(client.search, datasetId=fastgpt_dataset_id, query=query, top_k=top_k) + + +def _load_existing_results(output_path: Path) -> tuple[list[dict], set[str]]: + if not output_path.exists(): + return [], set() + try: + data = json.loads(output_path.read_text(encoding="utf-8")) + if isinstance(data, list): + ids = {str(r.get("_id")) for r in data if r.get("_id")} + return data, ids + if isinstance(data, dict) and isinstance(data.get("results"), list): + rows = data.get("results") or [] + ids = {str(r.get("_id")) for r in rows if r.get("_id")} + return rows, ids + except Exception: + return [], set() + return [], set() + + +def _save_json_list(path: Path, rows: list[dict]) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + tmp = path.with_suffix(path.suffix + ".tmp") + tmp.write_text(json.dumps({"results": rows}, ensure_ascii=False, indent=2), encoding="utf-8") + os.replace(tmp, path) + + +def search_one(sample: dict, lib: str, top_k: int, version_dir: str, search_mode: str) -> dict: + sample_id = str(sample.get("_id")) + user_id = version_dir + "_" + sample_id + question = sample.get("question") or "" + choices = { + "A": sample.get("choice_A") or "", + "B": sample.get("choice_B") or "", + "C": sample.get("choice_C") or "", + "D": sample.get("choice_D") or "", + } + + client = _get_lib_client(lib) + if lib == "memos": + memories = memos_search( + client, user_id, 
str(question), top_k=top_k, search_mode=search_mode + ) + elif lib == "memos-online": + memories = memos_online_search( + client=client, + query=str(question), + user_id=user_id, + top_k=top_k, + mode=search_mode, + ) + elif lib == "mem0": + memories = mem0_search(client, user_id, str(question), top_k=top_k) + elif lib == "supermemory": + memories = supermemory_search(client, user_id, str(question), top_k=top_k) + elif lib == "fastgpt": + memories = fastgpt_search(client, str(question), top_k=top_k) + else: + memories = [] + print(f"[{lib} Search] sample_id: {sample_id} search memories: {len(memories)}") + + return { + "_id": sample_id, + "domain": sample.get("domain"), + "sub_domain": sample.get("sub_domain"), + "difficulty": sample.get("difficulty"), + "length": sample.get("length"), + "question": question, + "choices": choices, + "answer": sample.get("answer"), + "memories_used": memories, + } + + +def parse_args(): + parser = argparse.ArgumentParser( + description="Longbench-v2 Product Search Concurrent Script", + formatter_class=argparse.RawDescriptionHelpFormatter, + ) + + parser.add_argument("--lib", "-b", required=True, help="Product name to evaluate") + + parser.add_argument( + "--dataset-path", + "-s", + default="evaluation/data/longbench_v2/longbenchv2_train.json", + help="Path to JSON file containing samples", + ) + + parser.add_argument( + "--api-url", + default="http://127.0.0.1:8001", + help="API service address (default: http://127.0.0.1:8001)", + ) + + parser.add_argument("--workers", "-c", type=int, default=5, help="Concurrency (default: 5)") + + parser.add_argument( + "--timeout", type=float, default=120.0, help="Request timeout in seconds (default: 120)" + ) + + parser.add_argument( + "--top-k", + "-k", + type=int, + default=20, + help="Number of results to return per search (default: 20)", + ) + + parser.add_argument("--version-dir", "-v", default=None, help="Version directory name") + + parser.add_argument( + "--limit", + "-l", + type=int, + default=None, + help="Limit number of samples to process (for testing, default all)", + ) + + parser.add_argument( + "--mode", "-m", type=str, default="fast", help="Search mode (default: fast)" + ) + + return parser.parse_args() + + +def main() -> None: + args = parse_args() + + print("=" * 60) + print("Longbench-v2 Product Search Concurrent Tool") + print("=" * 60) + + dataset_path = Path(args.dataset_path) + if not dataset_path.exists(): + raise FileNotFoundError(f"Dataset file not found: {dataset_path}") + dataset = _load_dataset_jsonl(dataset_path) + if args.limit is not None: + dataset = dataset[: args.limit] + + output_dir = os.path.join("evaluation/data/longbench_v2", args.version_dir) + os.makedirs(output_dir, exist_ok=True) + output_filename = f"{args.lib}_search_results.json" + output_path = Path(os.path.join(output_dir, output_filename)) + + results, processed_ids = _load_existing_results(output_path) + pending = [s for s in dataset if str(s.get("_id")) not in processed_ids] + if not pending: + return + metrics = Metrics() + start_time = time.time() + + with ThreadPoolExecutor(max_workers=args.workers) as executor: + + def do_search(sample: dict) -> dict: + st = time.perf_counter() + r = search_one(sample, args.lib, args.top_k, args.version_dir, args.mode) + dur = time.perf_counter() - st + r["duration_ms"] = dur * 1000 + metrics.record(dur, True) + return r + + futures = [executor.submit(do_search, sample) for sample in pending] + for idx, f in enumerate( + tqdm(as_completed(futures), total=len(futures), 
desc="Searching"), start=1 + ): + try: + r = f.result() + results.append(r) + if idx % 20 == 0: + _save_json_list(output_path, results) + except Exception as e: + metrics.record(0.0, False, str(e)) + print(f"[Search] Error: {e}") + traceback.print_exc() + + _save_json_list(output_path, results) + print(f"[Search] saved {len(results)} rows to {output_path}") + + total_duration = time.time() - start_time + summary = metrics.summary() + combined_obj = { + "perf": { + "summary": summary, + "total_duration": total_duration, + "config": { + "workers": args.workers, + "top_k": args.top_k, + "dataset_path": str(dataset_path), + "limit": args.limit, + "mode": args.mode, + }, + }, + "results": results, + } + tmp = output_path.with_suffix(output_path.suffix + ".tmp") + tmp.write_text(json.dumps(combined_obj, ensure_ascii=False, indent=2), encoding="utf-8") + os.replace(tmp, output_path) + + print("\n" + "=" * 60) + print("Search finished! Statistics:") + print("=" * 60) + print(f"Total duration: {total_duration:.2f}s") + print(f"Success: {summary['counts']['success']} / Failed: {summary['counts']['failed']}") + + if summary["stats"]: + stats = summary["stats"] + qps = stats["count"] / total_duration if total_duration > 0 else 0 + print(f"QPS: {qps:.2f}") + print("Latency stats (ms):") + print(f" Mean: {stats['mean']:.2f}") + print(f" Median: {stats['median']:.2f}") + print(f" Min: {stats['min']:.2f}") + print(f" Max: {stats['max']:.2f}") + print(f" P95: {stats['p95']:.2f}") + print(f" P99: {stats['p99']:.2f}") + + if summary["errors"]: + print("\nError stats:") + for error, count in sorted(summary["errors"].items(), key=lambda x: x[1], reverse=True)[:5]: + print(f" [{count} times] {error[:100]}...") + + +if __name__ == "__main__": + main() diff --git a/evaluation/scripts/mmlongbench/eval_utils/__init__.py b/evaluation/scripts/mmlongbench/eval_utils/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/evaluation/scripts/mmlongbench/eval_utils/eval_score.py b/evaluation/scripts/mmlongbench/eval_utils/eval_score.py new file mode 100644 index 000000000..02ef6eb53 --- /dev/null +++ b/evaluation/scripts/mmlongbench/eval_utils/eval_score.py @@ -0,0 +1,246 @@ +import re + +from collections import defaultdict +from math import isclose + + +def levenshtein_distance(s1, s2): + if len(s1) > len(s2): + s1, s2 = s2, s1 + + distances = range(len(s1) + 1) + for i2, c2 in enumerate(s2): + distances_ = [i2 + 1] + for i1, c1 in enumerate(s1): + if c1 == c2: + distances_.append(distances[i1]) + else: + distances_.append(1 + min((distances[i1], distances[i1 + 1], distances_[-1]))) + distances = distances_ + return distances[-1] + + +def anls_compute(groundtruth, prediction, threshold=0.5): + dist = levenshtein_distance(groundtruth, prediction) + length = max(len(groundtruth.upper()), len(prediction.upper())) + value = 0.0 if length == 0 else float(dist) / float(length) + anls = 1.0 - value + if anls <= threshold: + anls = 0.0 + return anls + + +def is_float_equal( + reference, prediction, include_percentage: bool = False, is_close: float = False +) -> bool: + def get_precision(gt_ans: float) -> int: + precision = 3 + if "." 
in str(gt_ans): + precision = len(str(gt_ans).split(".")[-1]) + return precision + + reference = float(str(reference).strip().rstrip("%").strip()) + try: + prediction = float(str(prediction).strip().rstrip("%").strip()) + except Exception: + return False + + gt_result = [reference / 100, reference, reference * 100] if include_percentage else [reference] + for item in gt_result: + try: + if is_close and isclose(item, prediction, rel_tol=0.01): + return True + precision = max(min(get_precision(prediction), get_precision(item)), 2) + if round(prediction, precision) == round(item, precision): + return True + except Exception: + continue + return False + + +def get_clean_string(s): + s = str(s).lower().strip() + + for suffix in ["mile", "miles", "million"]: + if s.endswith(suffix): + s = s[: -len(suffix)].strip() + + s = re.sub(r"\s*\([^)]*\)", "", s).strip() + s = re.sub(r"^['\"]|['\"]$", "", s).strip() + s = s.lstrip("$").rstrip("%").strip() + + return s + + +def is_exact_match(s): + flag = False + # Website + if "https://" in s: + flag = True + # code file + if s.endswith((".py", ".ipynb")) or s.startswith("page"): + flag = True + # telephone number + if re.fullmatch(r"\b\d+(-\d+|\s\d+)?\b", s): + flag = True + # time + if "a.m." in s or "p.m." in s: + flag = True + # YYYY-MM-DD + if re.fullmatch(r"\b\d{4}[-\s]\d{2}[-\s]\d{2}\b", s): + flag = True + # YYYY-MM + if re.fullmatch(r"\b\d{4}[-\s]\d{2}\b", s): + flag = True + # Email address + if re.fullmatch(r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}", s): + flag = True + return flag + + +def isfloat(num): + try: + float(num) + return True + except ValueError: + return False + + +def eval_score(gt, pred, answer_type): + if answer_type == "Int": + try: + gt, pred = int(gt), int(float(pred)) + except Exception: + pred = "" + score = gt == pred + elif answer_type == "Float": + try: + gt = float(get_clean_string(str(gt))) + pred = float(get_clean_string(str(pred))) + except Exception: + pred = "" + score = is_float_equal(gt, pred, include_percentage=True, is_close=True) + elif answer_type in ["Str", "None"]: + gt = get_clean_string(gt) + pred = get_clean_string(pred) + score = gt == pred if is_exact_match(gt) else anls_compute(gt, pred) + else: + if isinstance(gt, str) and gt.startswith("["): + gt = eval(gt) + if not isinstance(gt, list): + gt = [gt] + if isinstance(pred, str) and pred.startswith("["): + pred = eval(pred) + if not isinstance(pred, list): + pred = [pred] + print(len(gt), len(pred)) + if len(gt) != len(pred): + score = 0.0 + else: + gt = sorted([get_clean_string(a) for a in gt]) + pred = sorted([get_clean_string(a) for a in pred]) + print(gt, pred) + if isfloat(gt[0]) or is_exact_match(gt[0]): + score = "-".join(gt) == "-".join(pred) + else: + score = min( + [anls_compute(gt_v, pred_v) for gt_v, pred_v in zip(gt, pred, strict=False)] + ) + + return float(score) + + +def eval_acc_and_f1(samples): + evaluated_samples = [sample for sample in samples if "score" in sample] + if not evaluated_samples: + return 0.0, 0.0 + + acc = sum([sample["score"] for sample in evaluated_samples]) / len(evaluated_samples) + try: + recall = sum( + [ + sample["score"] + for sample in evaluated_samples + if sample["answer"] != "Not answerable" + ] + ) / len([sample for sample in evaluated_samples if sample["answer"] != "Not answerable"]) + precision = sum( + [ + sample["score"] + for sample in evaluated_samples + if sample["answer"] != "Not answerable" + ] + ) / len([sample for sample in evaluated_samples if sample["pred"] != "Not answerable"]) + f1 = 2 
* recall * precision / (recall + precision) if (recall + precision) > 0.0 else 0.0 + except Exception: + f1 = 0.0 + + return acc, f1 + + +def show_results(samples, show_path=None): + for sample in samples: + sample["evidence_pages"] = eval(sample["evidence_pages"]) + sample["evidence_sources"] = eval(sample["evidence_sources"]) + + with open(show_path, "w") as f: + acc, f1 = eval_acc_and_f1(samples) + f.write(f"Overall Acc: {acc} | Question Number: {len(samples)}\n") + f.write(f"Overall F1-score: {f1} | Question Number: {len(samples)}\n") + f.write("-----------------------\n") + + acc_single_page, _ = eval_acc_and_f1( + [sample for sample in samples if len(sample["evidence_pages"]) == 1] + ) + acc_multi_page, _ = eval_acc_and_f1( + [ + sample + for sample in samples + if len(sample["evidence_pages"]) != 1 and sample["answer"] != "Not answerable" + ] + ) + acc_neg, _ = eval_acc_and_f1( + [sample for sample in samples if sample["answer"] == "Not answerable"] + ) + + f.write( + "Single-page | Accuracy: {} | Question Number: {}\n".format( + acc_single_page, + len([sample for sample in samples if len(sample["evidence_pages"]) == 1]), + ) + ) + f.write( + "Cross-page | Accuracy: {} | Question Number: {}\n".format( + acc_multi_page, + len( + [ + sample + for sample in samples + if len(sample["evidence_pages"]) != 1 + and sample["answer"] != "Not answerable" + ] + ), + ) + ) + f.write( + "Unanswerable | Accuracy: {} | Question Number: {}\n".format( + acc_neg, len([sample for sample in samples if sample["answer"] == "Not answerable"]) + ) + ) + f.write("-----------------------\n") + + source_sample_dict, document_type_dict = defaultdict(list), defaultdict(list) + for sample in samples: + for answer_source in sample["evidence_sources"]: + source_sample_dict[answer_source].append(sample) + document_type_dict[sample["doc_type"]].append(sample) + for type, sub_samples in source_sample_dict.items(): + f.write( + f"Evidence Sources: {type} | Accuracy: {eval_acc_and_f1(sub_samples)[0]} | Question Number: {len(sub_samples)}\n" + ) + + f.write("-----------------------\n") + for type, sub_samples in document_type_dict.items(): + f.write( + f"Document Type: {type} | Accuracy: {eval_acc_and_f1(sub_samples)[0]} | Question Number: {len(sub_samples)}\n" + ) diff --git a/evaluation/scripts/mmlongbench/mmlongbench_check_files.py b/evaluation/scripts/mmlongbench/mmlongbench_check_files.py new file mode 100644 index 000000000..195976760 --- /dev/null +++ b/evaluation/scripts/mmlongbench/mmlongbench_check_files.py @@ -0,0 +1,255 @@ +import argparse +import json +import os +from pathlib import Path +from typing import Any + +from dotenv import load_dotenv +from tqdm import tqdm + +from evaluation.scripts.utils.client import MemosApiOnlineClient + +load_dotenv() + +MEMOS_KNOWLEDGEBASE_ID = os.getenv("MEMOS_KNOWLEDGEBASE_ID_MM_LONGBENCH") + + +def _load_added_ids(records_path: Path) -> dict[str, str | None]: + if not records_path.exists(): + return {} + + try: + obj = json.loads(records_path.read_text(encoding="utf-8")) + except Exception: + return {} + + if not isinstance(obj, dict): + return {} + + added = obj.get("added") + if not isinstance(added, dict): + return {} + + return {str(key): str(value) if value is not None else None for key, value in added.items()} + + +def _check_file_status( + client: MemosApiOnlineClient, + file_ids: list[str], + batch_size: int, +) -> dict[str, dict[str, str | None]]: + file_status: dict[str, dict[str, str | None]] = {} + + for start in tqdm( + range(0, len(file_ids), batch_size), + 
desc="Checking files", + ): + batch = file_ids[start : start + batch_size] + try: + resp = client.check_file(batch) + except Exception as exc: + print(f"[Check] error for batch starting at {start}: {exc}") + continue + + if not isinstance(resp, dict): + continue + + data = resp.get("data") + if not isinstance(data, dict): + continue + + details = data.get("file_detail_list") + if not isinstance(details, list): + continue + + for item in details: + if not isinstance(item, dict): + continue + + fid = item.get("id") + if not fid: + continue + + file_status[str(fid)] = { + "name": item.get("name"), + "size": item.get("size"), + "status": item.get("status"), + } + + return file_status + + +def _reupload_failed_files( + client: MemosApiOnlineClient, + file_status: dict[str, dict[str, str | None]], + added_ids: dict[str, str | None], + url_prefix: str, +) -> list[dict[str, str | None]]: + fid_to_filename: dict[str, str] = { + str(fid): str(filename) for filename, fid in added_ids.items() if fid + } + + failed_ids = [ + fid for fid, info in file_status.items() if info.get("status") == "PROCESSING_FAILED" + ] + + reupload_results: list[dict[str, str | None]] = [] + + for fid in tqdm(failed_ids, desc="Reuploading failed files"): + filename = fid_to_filename.get(fid) + if not filename: + reupload_results.append( + { + "old_file_id": fid, + "filename": None, + "new_file_id": None, + "ok": "false", + "error": "filename_not_found", + } + ) + continue + + file_url = f"{url_prefix.rstrip('/')}/{filename}" + + try: + resp = client.upload_file( + MEMOS_KNOWLEDGEBASE_ID or "", + file_url, + ) + new_id: str | None = None + + if isinstance(resp, dict): + data = resp.get("data") + if isinstance(data, list) and data: + first = data[0] + if isinstance(first, dict) and first.get("id"): + new_id = str(first["id"]) + + reupload_results.append( + { + "old_file_id": fid, + "filename": filename, + "new_file_id": new_id, + "ok": "true", + "error": None, + } + ) + except Exception as exc: + reupload_results.append( + { + "old_file_id": fid, + "filename": filename, + "new_file_id": None, + "ok": "false", + "error": str(exc), + } + ) + + return reupload_results + + +def main(argv: list[str] | None = None) -> None: + parser = argparse.ArgumentParser( + description="Check MMLongBench memos-online file status and reupload failed files.", + ) + parser.add_argument("--lib", default="memos-online") + parser.add_argument("--version-dir", "-v") + parser.add_argument("--batch-size", type=int, default=50) + parser.add_argument( + "--url-prefix", + "-u", + default=( + "https://memos-knowledge-base-file-pre.oss-cn-shanghai.aliyuncs.com/" + "mmlongbench_pdf_files/" + ), + ) + args = parser.parse_args(argv) + + if args.lib != "memos-online": + print(f"Only memos-online is supported, got lib={args.lib}") + return + + output_dir = Path("evaluation/data/mmlongbench") + if args.version_dir: + output_dir /= args.version_dir + output_dir.mkdir(parents=True, exist_ok=True) + + records_path = output_dir / f"{args.lib}_add_results.json" + print(f"[Check] loading records from {records_path}") + + added_ids = _load_added_ids(records_path) + file_ids = sorted({fid for fid in added_ids.values() if fid}) + + print(f"[Check] total file ids: {len(file_ids)}") + if not file_ids: + return + + client = MemosApiOnlineClient() + batch_size = max(1, args.batch_size) + + file_status = _check_file_status(client, file_ids, batch_size) + reupload_results = _reupload_failed_files( + client, + file_status, + added_ids, + args.url_prefix, + ) + + if 
reupload_results: + try: + obj: dict[str, Any] = {} + if records_path.exists(): + txt = records_path.read_text(encoding="utf-8") + if txt: + parsed = json.loads(txt) + if isinstance(parsed, dict): + obj = parsed + + added_obj: dict[str, str | None] + if isinstance(obj.get("added"), dict): + added_obj = { + str(k): str(v) if v is not None else None for k, v in obj["added"].items() + } + else: + added_obj = dict(added_ids) + + for item in reupload_results: + if item.get("ok") == "true" and item.get("filename") and item.get("new_file_id"): + added_obj[str(item["filename"])] = str(item["new_file_id"]) + + obj["added"] = dict(sorted(added_obj.items())) + + tmp_path = records_path.with_suffix(records_path.suffix + ".tmp") + tmp_path.write_text( + json.dumps(obj, ensure_ascii=False, indent=2), + encoding="utf-8", + ) + os.replace(tmp_path, records_path) + + print(f"[Update] updated add_results -> {records_path}") + except Exception as exc: + print(f"[Update] failed to update add_results: {exc}") + + output_path = output_dir / f"{args.lib}_file_status.json" + result_obj = { + "lib": args.lib, + "version_dir": args.version_dir, + "total": len(file_ids), + "file_detail_list": [{"id": fid, **(file_status.get(fid) or {})} for fid in file_ids], + "reupload_results": reupload_results, + } + + tmp_output = output_path.with_suffix(output_path.suffix + ".tmp") + tmp_output.write_text( + json.dumps(result_obj, ensure_ascii=False, indent=2), + encoding="utf-8", + ) + os.replace(tmp_output, output_path) + + print( + f"[Check] saved file status for {len(file_status)} files to {output_path}", + ) + + +if __name__ == "__main__": + main() diff --git a/evaluation/scripts/mmlongbench/mmlongbench_eval.py b/evaluation/scripts/mmlongbench/mmlongbench_eval.py new file mode 100644 index 000000000..6071e812c --- /dev/null +++ b/evaluation/scripts/mmlongbench/mmlongbench_eval.py @@ -0,0 +1,285 @@ +import base64 +import json +import mimetypes +import os +import re +import sys +import time +import traceback +from collections import defaultdict +from concurrent.futures import ThreadPoolExecutor, as_completed +from datetime import datetime +from pathlib import Path +from typing import Any + +import openai +import pandas as pd +from dotenv import load_dotenv +from tqdm import tqdm + +from evaluation.scripts.utils.eval_score import ( + eval_acc_and_f1, + eval_score, + show_results, +) +from evaluation.scripts.utils.extract_answer import extract_answer +from evaluation.scripts.utils.prompts import MMLONGBENCH_ANSWER_PROMPT + + +load_dotenv() + + +def create_openai_client() -> openai.Client: + return openai.Client( + api_key=os.getenv("OPENAI_API_KEY"), + base_url=os.getenv("OPENAI_API_BASE", "https://api.openai.com/v1"), + ) + + +def _encode_image_to_data_url(image_path: str) -> str | None: + try: + mime, _ = mimetypes.guess_type(image_path) + mime = mime or "image/jpeg" + with open(image_path, "rb") as f: + b64 = base64.b64encode(f.read()).decode("ascii") + return f"data:{mime};base64,{b64}" + except Exception as exc: + print("Failed to encode image '%s' to data URL: %s", image_path, exc) + return None + + +def build_images_index(base_dir: Path) -> dict[str, str]: + index: dict[str, str] = {} + if not base_dir.exists(): + return index + + for images_dir in base_dir.rglob("auto/images"): + if images_dir.is_dir(): + for img_file in images_dir.iterdir(): + if img_file.is_file(): + index[img_file.name] = str(img_file.resolve()) + return index + + +def get_images( + sources: list[Any], + image_index: dict[str, str], +) -> list[str]: 
+ if not sources: + return [] + + found: list[str] = [] + + md_img_pattern = re.compile(r"\[Image: images/\s*(.+?)\s*-") + images_substr_pattern = re.compile( + r"images/[^\s)]+\.(?:png|jpg|jpeg|webp)", + re.IGNORECASE, + ) + + for src in sources: + if not isinstance(src, str): + continue + + for candidate in md_img_pattern.findall(src) + images_substr_pattern.findall(src): + basename = os.path.basename(candidate) + if basename in image_index: + found.append(image_index[basename]) + + # deduplicate + seen: set[str] = set() + return [p for p in found if not (p in seen or seen.add(p))] + + +def add_images_context( + messages: list[dict[str, Any]], + image_paths: list[str], +) -> list[dict[str, Any]]: + if not image_paths: + return messages + + user_idx = next( + i for i in range(len(messages) - 1, -1, -1) if messages[i].get("role") == "user" + ) + user_msg = messages[user_idx] + content = user_msg.get("content", "") + + parts: list[dict[str, Any]] = ( + content if isinstance(content, list) else [{"type": "text", "text": str(content)}] + ) + + for img_path in image_paths[:6]: + data_url = _encode_image_to_data_url(img_path) + if data_url: + parts.append( + {"type": "image_url", "image_url": {"url": data_url}}, + ) + + user_msg["content"] = parts + messages[user_idx] = user_msg + return messages + + +def multimodal_answer( + client: openai.Client, + chat_model: str, + memories: list[str], + question: str, + sources: list[Any], + image_index: dict[str, str], +) -> tuple[str, int | None]: + image_paths = get_images(sources, image_index) + + system_prompt = MMLONGBENCH_ANSWER_PROMPT.format( + memories="\n\n".join(memories), + question=question, + ) + + messages = [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": question}, + ] + + messages = add_images_context(messages, image_paths) + + resp = client.chat.completions.create( + model=chat_model, + messages=messages, + temperature=0, + ) + return resp.choices[0].message.content or "", resp.usage.prompt_tokens + + +def process_single_item( + client: openai.Client, + chat_model: str, + image_index: dict[str, str], + item: dict[str, Any], + index: int, +) -> dict[str, Any]: + try: + response, prompt_tokens = multimodal_answer( + client, + chat_model, + item["memories"], + item["question"], + item.get("sources", []), + image_index, + ) + + extracted = extract_answer(item["question"], response) + pred = extracted or response.strip() + score = eval_score(item["answer"], pred, item.get("answer_format", "Str")) + + return { + "index": index, + "result": { + "response": response, + "pred": pred, + "score": score, + "prompt_tokens": prompt_tokens, + "eval_success": True, + "eval_error": None, + }, + } + + except Exception as exc: + traceback.print_exc() + return { + "index": index, + "result": { + "response": None, + "pred": None, + "score": 0, + "eval_success": False, + "eval_error": str(exc), + }, + } + + +def run_eval( + questions_file: Path, + output_file: Path, + version_dir: Path, + chat_model: str, + max_workers: int, + limit: int | None, +) -> None: + client = create_openai_client() + image_index = build_images_index( + Path("/Users/tianxingshi/Desktop/lcy/ppt_test_result"), + ) + + data = json.loads(questions_file.read_text(encoding="utf-8")) + items = data["results"][:limit] if limit else data["results"] + + results: dict[int, dict[str, Any]] = {} + + with ThreadPoolExecutor(max_workers=max_workers) as executor: + futures = { + executor.submit( + process_single_item, + client, + chat_model, + image_index, + item, + i, 
+ ): i + for i, item in enumerate(items) + } + + for future in tqdm(as_completed(futures), total=len(futures)): + res = future.result() + results[res["index"]] = res["result"] + + for i, item in enumerate(items): + if i in results: + item.update(results[i]) + + acc, f1 = eval_acc_and_f1(items) + data["eval_summary"] = { + "accuracy": acc, + "f1_score": f1, + "eval_timestamp": datetime.now().isoformat(), + } + + output_file.write_text( + json.dumps(data, ensure_ascii=False, indent=2), + encoding="utf-8", + ) + + report_path = version_dir / "eval_results.txt" + show_results(items, show_path=str(report_path)) + + +def main() -> None: + import argparse + + parser = argparse.ArgumentParser("MMLongBench Eval") + parser.add_argument("--lib", required=True) + parser.add_argument("--workers", type=int, default=20) + parser.add_argument("--version-dir", required=True) + parser.add_argument("--chat-model", required=True) + parser.add_argument("--limit", type=int) + + args = parser.parse_args() + + base_dir = Path("evaluation/data/mmlongbench") + version_dir = base_dir / args.version_dir + input_path = version_dir / f"{args.lib}_search_results.json" + + if not input_path.exists(): + print(f"Input not found: {input_path}") + sys.exit(1) + + run_eval( + questions_file=input_path, + output_file=input_path, + version_dir=version_dir, + chat_model=args.chat_model, + max_workers=args.workers, + limit=args.limit, + ) + + +if __name__ == "__main__": + main() diff --git a/evaluation/scripts/mmlongbench/mmlongbench_ingestion.py b/evaluation/scripts/mmlongbench/mmlongbench_ingestion.py new file mode 100644 index 000000000..bee792d73 --- /dev/null +++ b/evaluation/scripts/mmlongbench/mmlongbench_ingestion.py @@ -0,0 +1,296 @@ +#!/usr/bin/env python3 + +from __future__ import annotations + +import argparse +import json +import os +import threading +import time +import traceback +from concurrent.futures import ThreadPoolExecutor +from pathlib import Path +from typing import Any + +from dotenv import load_dotenv +from langchain_text_splitters import Language, RecursiveCharacterTextSplitter + +from evaluation.scripts.utils.metrics import Metrics + +load_dotenv() + +FASTGPT_DATASET_ID = os.getenv("FASTGPT_DATASET_ID_MM_LONGBENCH") +MEMOS_KNOWLEDGEBASE_ID = os.getenv("MEMOS_KNOWLEDGEBASE_ID_MM_LONGBENCH") + + +def retry_operation( + func, + *args, + retries: int = 5, + delay: int = 2, + **kwargs, +): + """Retry wrapper with exponential backoff.""" + for attempt in range(retries): + try: + return func(*args, **kwargs) + except Exception as exc: + if attempt >= retries - 1: + raise + traceback.print_exc() + func_name = getattr(func, "__name__", "operation") + print(f"[Retry] {func_name} failed: {exc}. 
Retrying in {delay}s...") + time.sleep(delay) + delay *= 2 + return None + + +def read_filenames(filepath: str | Path) -> list[str]: + """Read filenames from text file, one per line.""" + path = Path(filepath) + filenames: list[str] = [] + with path.open(encoding="utf-8") as f: + for line in f: + name = line.strip() + if name: + filenames.append(name) + return filenames + + +def _get_lib_client(lib: str): + if lib == "memos": + from evaluation.scripts.utils.client import MemosApiClient + + return MemosApiClient() + if lib == "mem0": + from evaluation.scripts.utils.client import Mem0Client + + return Mem0Client(enable_graph=False) + if lib == "supermemory": + from evaluation.scripts.utils.client import SupermemoryClient + + return SupermemoryClient() + if lib == "fastgpt": + from evaluation.scripts.utils.client import FastGPTClient + + return FastGPTClient() + if lib == "memos-online": + from evaluation.scripts.utils.client import MemosApiOnlineClient + + return MemosApiOnlineClient() + + msg = f"Unknown lib type: {lib}" + raise ValueError(msg) + + +def run_concurrent_add( + *, + lib: str, + filenames: list[str], + url_prefix: str, + user_prefix: str, + workers: int, + mode: str = "fine", + async_mode: str = "sync", +) -> dict[str, Any]: + """Run concurrent ingestion.""" + + client = _get_lib_client(lib) + metrics = Metrics() + + total_files = len(filenames) + completed = 0 + completed_lock = threading.Lock() + + added_ids: dict[str, str] = {} + + base_dir = Path("ppt_test_result") + all_md_files = list(base_dir.rglob("*.md")) + + def add_single_file(filename: str, doc_id: str) -> tuple[bool, Any]: + nonlocal completed + + file_url = f"{url_prefix.rstrip('/')}/{filename}" + stem = Path(filename).stem.lower() + name = filename.lower() + + md_path: Path | None = None + for md_file in all_md_files: + pstr = str(md_file).lower() + if (stem and stem in pstr) or (name and name in pstr): + md_path = md_file + break + + if md_path is None: + raise FileNotFoundError(f"No markdown found for {filename}") + + text = md_path.read_text(encoding="utf-8", errors="ignore") + start_time = time.perf_counter() + user_id = f"{user_prefix}_{doc_id}" + + try: + result = None + + if lib == "memos-online": + result = retry_operation( + client.upload_file, + MEMOS_KNOWLEDGEBASE_ID, + file_url, + ) + file_id = None + if isinstance(result, dict): + data = result.get("data") or [] + if isinstance(data, list) and data: + file_id = data[0].get("id") + if file_id: + added_ids[filename] = str(file_id) + + elif lib == "fastgpt": + result = retry_operation( + client.upload_file, + datasetId=FASTGPT_DATASET_ID, + file_url=file_url, + ) + + elif lib == "supermemory": + result = client.add(content=text, user_id=user_id) + + elif lib == "mem0": + splitter = RecursiveCharacterTextSplitter.from_language( + language=Language.PYTHON, + chunk_size=5120, + chunk_overlap=128, + ) + chunks = [c for c in splitter.split_text(text) if c.strip()] + messages = [{"role": "user", "content": c} for c in chunks] + result = client.add( + messages=messages, + user_id=doc_id, + timestamp=int(time.time()), + batch_size=10, + ) + + duration = time.perf_counter() - start_time + metrics.record(duration, True) + + with completed_lock: + completed += 1 + print( + f"[{completed}/{total_files}] ✓ Success: {filename} ({duration * 1000:.0f}ms)" + ) + + return True, result + + except Exception as exc: + duration = time.perf_counter() - start_time + metrics.record(duration, False, str(exc)) + + with completed_lock: + completed += 1 + 
print(f"[{completed}/{total_files}] ✗ Failed: {filename} - {exc!s:.100}") + + return False, str(exc) + + print(f"\nStarting concurrent add for {total_files} files...") + print(f"Concurrency: {workers}") + print(f"Version Dir: {user_prefix}") + print(f"URL prefix: {url_prefix}") + print("-" * 60) + + start_time = time.time() + + results: list[dict[str, Any]] = [] + with ThreadPoolExecutor(max_workers=workers) as executor: + futures = [executor.submit(add_single_file, fn, fn[:-3] + ".pdf") for fn in filenames] + + for filename, future in zip(filenames, futures, strict=True): + try: + success, result = future.result() + results.append({"filename": filename, "success": success, "result": result}) + except Exception as exc: + results.append({"filename": filename, "success": False, "result": str(exc)}) + + total_duration = time.time() - start_time + summary = metrics.summary() + + print("\n" + "=" * 60) + print("Ingestion finished! Statistics:") + print("=" * 60) + print(f"Total duration: {total_duration:.2f}s") + print(f"Success: {summary['counts']['success']} / Failed: {summary['counts']['failed']}") + + return { + "summary": summary, + "total_duration": total_duration, + "results": results, + "added": added_ids, + } + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description="MMLongbench-doc Product Add Concurrent Script", + ) + parser.add_argument("--lib", "-b", required=True) + parser.add_argument( + "--filenames-file", + "-f", + default="evaluation/data/mmlongbench/pdf_file_list.txt", + ) + parser.add_argument( + "--url-prefix", + "-u", + default="https://memos-knowledge-base-file-pre.oss-cn-shanghai.aliyuncs.com/mmlongbench_pdf_files/", + ) + parser.add_argument("--workers", "-w", type=int, default=5) + parser.add_argument("--limit", "-l", type=int) + parser.add_argument("--version-dir", "-v", required=True) + return parser.parse_args() + + +def main() -> None: + args = parse_args() + + filenames = read_filenames(args.filenames_file) + if args.limit: + filenames = filenames[: args.limit] + + version_dir = Path("evaluation/data/mmlongbench") / args.version_dir + version_dir.mkdir(parents=True, exist_ok=True) + + output_path = version_dir / f"{args.lib}_add_results.json" + + existing_added: dict[str, str] = {} + if output_path.exists(): + with output_path.open(encoding="utf-8") as f: + obj = json.load(f) + existing_added = obj.get("added", {}) if isinstance(obj, dict) else {} + + filenames = [f for f in filenames if f not in existing_added] + + if not filenames: + print("[Add] no pending files.") + return + + result = run_concurrent_add( + lib=args.lib, + filenames=filenames, + url_prefix=args.url_prefix, + user_prefix=args.version_dir, + workers=args.workers, + ) + + output = { + "summary": result["summary"], + "total_duration": result["total_duration"], + "added": {**existing_added, **result.get("added", {})}, + } + + with output_path.open("w", encoding="utf-8") as f: + json.dump(output, f, ensure_ascii=False, indent=2) + + print(f"\nResults saved to: {output_path}") + + +if __name__ == "__main__": + main() diff --git a/evaluation/scripts/mmlongbench/mmlongbench_search.py b/evaluation/scripts/mmlongbench/mmlongbench_search.py new file mode 100644 index 000000000..2676ce288 --- /dev/null +++ b/evaluation/scripts/mmlongbench/mmlongbench_search.py @@ -0,0 +1,439 @@ +#!/usr/bin/env python3 + +import argparse +import json +import os +import threading +import time + +from concurrent.futures import ThreadPoolExecutor, as_completed +from pathlib import Path + 
+from dotenv import load_dotenv + +from evaluation.scripts.utils.metrics import Metrics + + +load_dotenv() + +fastgpt_dataset_id = os.getenv("FASTGPT_DATASET_ID_MM_LONGBENCH") +memos_knowledgebase_id = os.getenv("MEMOS_KNOWLEDGEBASE_ID_MM_LONGBENCH") + + +def retry_operation(func, *args, retries=5, delay=2, **kwargs): + for attempt in range(retries): + try: + return func(*args, **kwargs) + except Exception as e: + if attempt < retries - 1: + func_name = getattr(func, "__name__", "Operation") + print(f"[Retry] {func_name} failed: {e}. Retrying in {delay}s...") + time.sleep(delay) + delay *= 2 + else: + raise e + + +def load_samples(filepath: str) -> list[dict]: + """ + Read sample list from JSON file + """ + with open(filepath, encoding="utf-8") as f: + samples = json.load(f) + return samples + + +def memos_search(client, user_id: str, query: str, top_k: int, mode: str) -> list[str]: + results = retry_operation( + client.search, + query=query, + user_id=user_id, + top_k=top_k, + mode=mode, + knowledgebase_ids=[memos_knowledgebase_id], + ) + if "memory_detail_list" in results["data"] and results["data"]["memory_detail_list"]: + memories = results["data"]["memory_detail_list"] + return [m.get("memory_value", "") for m in memories] + return [] + + +def mem0_search(client, user_id: str, query: str, top_k: int) -> tuple[list[str], list[str]]: + res = retry_operation(client.search, query, user_id, top_k) + results = res.get("results", []) + mem_texts = [m.get("memory", "") for m in results if m.get("memory")] + return mem_texts, mem_texts + + +def supermemory_search(client, user_id: str, query: str, top_k: int) -> tuple[list[str], list[str]]: + chunk_list = retry_operation(client.search, query, user_id, top_k) + return chunk_list, chunk_list + + +def fastgpt_search(client, query: str, top_k: int) -> list[str]: + result = retry_operation(client.search, datasetId=fastgpt_dataset_id, query=query, top_k=top_k) + return [item["q"] for item in result[:top_k]] + + +def _load_existing_results(path: str | os.PathLike[str]) -> tuple[list[dict], set[str]]: + p = Path(path) + if not p.exists(): + return [], set() + try: + data = json.loads(p.read_text(encoding="utf-8")) + rows: list[dict] = [] + if isinstance(data, dict) and isinstance(data.get("results"), list): + rows = data.get("results") or [] + elif isinstance(data, list): + rows = data + success_rows = [r for r in rows if r.get("success") is True] + ids = {str(r.get("doc_id")) for r in success_rows if r.get("doc_id")} + return success_rows, ids + except Exception: + return [], set() + + +def _get_lib_client(lib: str): + if lib == "memos": + from evaluation.scripts.utils.client import MemosApiClient + + return MemosApiClient() + if lib == "mem0": + from evaluation.scripts.utils.client import Mem0Client + + return Mem0Client(enable_graph=False) + if lib == "supermemory": + from evaluation.scripts.utils.client import SupermemoryClient + + return SupermemoryClient() + if lib == "fastgpt": + from evaluation.scripts.utils.client import FastGPTClient + + return FastGPTClient() + if lib == "memos-online": + from evaluation.scripts.utils.client import MemosApiOnlineClient + + return MemosApiOnlineClient() + + +def run_concurrent_search( + lib: str, samples: list[dict], user_prefix: str, concurrency: int, top_k: int, mode: str +) -> dict: + """ + Execute concurrent search operations + + Args: + lib: Client name + samples: Sample list, each containing doc_id and question + user_prefix: User ID prefix + concurrency: Concurrency + top_k: Number of results to return 
+ mode: Query mode ['fast', 'fine'] + + Returns: + Search results + """ + + client = _get_lib_client(lib) + metrics = Metrics() + total_samples = len(samples) + completed = 0 + completed_lock = threading.Lock() + + # 用于存储所有搜索结果 + all_results = [] + results_lock = threading.Lock() + + user_id = user_prefix + + def search_single(sample: dict, index: int): + nonlocal completed + + doc_id = sample.get("doc_id", "") + question = sample.get("question", "") + + user_id = doc_id[:20] + start_time = time.perf_counter() + try: + memories, sources = [], [] + if lib == "memos" or lib == "memos-online": + memories = memos_search( + client=client, + query=question, + user_id=user_id, + top_k=top_k, + mode=mode, + ) + elif lib == "mem0": + memories, sources = mem0_search(client, user_id, question, top_k=top_k) + elif lib == "supermemory": + memories, sources = supermemory_search(client, user_id, question, top_k=top_k) + elif lib == "fastgpt": + memories = fastgpt_search(client, question, top_k=top_k) + + duration = time.perf_counter() - start_time + metrics.record(duration, True) + + result = { + "index": index, + "doc_id": doc_id, + "question": question, + "answer": sample.get("answer", ""), + "evidence_pages": sample.get("evidence_pages", ""), + "evidence_sources": sample.get("evidence_sources", ""), + "answer_format": sample.get("answer_format", ""), + "doc_type": sample.get("doc_type", ""), + "memories": memories, + "memory_count": len(memories), + "success": True, + "duration_ms": duration * 1000, + "mode": mode, + } + + with results_lock: + all_results.append(result) + + with completed_lock: + completed += 1 + print( + f"[{completed}/{total_samples}] ✓ Success: {doc_id[:30]}... ({duration * 1000:.0f}ms, {len(memories)} memories)" + ) + + return True, result + + except Exception as e: + duration = time.perf_counter() - start_time + error_msg = str(e) + metrics.record(duration, False, error_msg) + + result = { + "index": index, + "doc_id": doc_id, + "question": question, + "answer": sample.get("answer", ""), + "evidence_pages": sample.get("evidence_pages", ""), + "evidence_sources": sample.get("evidence_sources", ""), + "answer_format": sample.get("answer_format", ""), + "doc_type": sample.get("doc_type", ""), + "memories": [], + "memory_count": 0, + "success": False, + "error": error_msg, + "duration_ms": duration * 1000, + } + + with results_lock: + all_results.append(result) + + with completed_lock: + completed += 1 + print( + f"[{completed}/{total_samples}] ✗ Failed: {doc_id[:30]}... - {error_msg[:80]}" + ) + + return False, result + + print(f"\nStarting concurrent search for {total_samples} questions...") + print(f"Concurrency: {concurrency}") + print(f"User ID: {user_id}") + print(f"Top-K: {top_k}") + print("-" * 60) + + start_time = time.time() + + with ThreadPoolExecutor(max_workers=concurrency) as executor: + futures = [] + for i, sample in enumerate(samples): + future = executor.submit(search_single, sample, i) + futures.append(future) + + # Wait for all tasks to complete + for future in as_completed(futures): + try: + future.result() + except Exception as e: + print(f"Task execution exception: {e}") + + end_time = time.time() + total_duration = end_time - start_time + + # Sort results by original index + all_results.sort(key=lambda x: x["index"]) + + # Print statistics + summary = metrics.summary() + + print("\n" + "=" * 60) + print("Search finished! 
Statistics:") + print("=" * 60) + print(f"Total duration: {total_duration:.2f}s") + print(f"Success: {summary['counts']['success']} / Failed: {summary['counts']['failed']}") + + if summary["stats"]: + stats = summary["stats"] + qps = stats["count"] / total_duration if total_duration > 0 else 0 + print(f"QPS: {qps:.2f}") + print("Latency stats (ms):") + print(f" Mean: {stats['mean']:.2f}") + print(f" Median: {stats['median']:.2f}") + print(f" Min: {stats['min']:.2f}") + print(f" Max: {stats['max']:.2f}") + print(f" P95: {stats['p95']:.2f}") + print(f" P99: {stats['p99']:.2f}") + + if summary["errors"]: + print("\nError statistics:") + for error, count in sorted(summary["errors"].items(), key=lambda x: x[1], reverse=True)[:5]: + print(f" [{count} times] {error[:100]}...") + + return {"summary": summary, "total_duration": total_duration, "results": all_results} + + +def parse_args(): + parser = argparse.ArgumentParser( + description="MMLongbench-doc Product Search Concurrent Script", + formatter_class=argparse.RawDescriptionHelpFormatter, + ) + + parser.add_argument("--lib", "-b", required=True, help="Product name to evaluate") + + parser.add_argument( + "--samples-file", + "-s", + default="evaluation/data/mmlongbench/samples.json", + help="Path to JSON file containing samples", + ) + + parser.add_argument( + "--api-url", + default="http://127.0.0.1:8001", + help="API service address (default: http://127.0.0.1:8001)", + ) + + parser.add_argument("--api-key", default="", help="API key (optional)") + + parser.add_argument("--workers", "-c", type=int, default=5, help="Concurrency (default: 5)") + + parser.add_argument( + "--timeout", type=float, default=120.0, help="Request timeout in seconds (default: 120)" + ) + + parser.add_argument( + "--top-k", + "-k", + type=int, + default=15, + help="Number of results to return per search (default: 20)", + ) + + parser.add_argument("--version-dir", "-v", default=None, help="Version directory name") + + parser.add_argument( + "--limit", + "-l", + type=int, + default=None, + help="Limit number of samples to process (for testing, default all)", + ) + + parser.add_argument( + "--mode", "-m", type=str, default="fast", help="Search mode (default: fast)" + ) + + return parser.parse_args() + + +def main(): + args = parse_args() + + print("=" * 60) + print("MMLongbench-doc Product Search Concurrent Tool") + print("=" * 60) + + # Read sample data + samples_path = "evaluation/data/mmlongbench/samples.json" + print(f"\nReading sample file: {samples_path}") + try: + samples = load_samples(samples_path) + print(f"Total {len(samples)} samples read") + + # Limit number of samples + if args.limit and args.limit > 0: + samples = samples[: args.limit] + print(f"Limiting to first {len(samples)} samples") + + if len(samples) == 0: + print("Error: Sample list is empty!") + return + + except FileNotFoundError: + print(f"Error: File not found {args.samples_file}") + return + except json.JSONDecodeError as e: + print(f"Error: JSON parse failed - {e}") + return + except Exception as e: + print(f"Error: Failed to read file - {e}") + return + + # Determine output file path + import os + + output_dir = os.path.join("evaluation/data/mmlongbench", args.version_dir) + os.makedirs(output_dir, exist_ok=True) + output_filename = f"{args.lib}_search_results.json" + output_path = os.path.join(output_dir, output_filename) + + existing_results, processed_ids = _load_existing_results(output_path) + if processed_ids: + before = len(samples) + samples = [s for s in samples if str(s.get("doc_id", 
"")) not in processed_ids] + print( + f"[Resume] found {len(processed_ids)} successful samples in checkpoint, " + f"skip {before - len(samples)} samples, pending={len(samples)}" + ) + + if not samples: + print("[Search] no pending samples, nothing to do.") + return + + # Execute concurrent search only on pending samples + result = run_concurrent_search( + lib=args.lib, + samples=samples, + user_prefix=args.version_dir, + concurrency=args.workers, + top_k=args.top_k, + mode=args.mode, + ) + + new_results = [r for r in result["results"] if r.get("success")] + all_results = existing_results + new_results + + # Save results + output_data = { + "summary": result["summary"], + "total_duration": result["total_duration"], + "config": { + "samples_file": args.samples_file, + "api_url": args.api_url, + "workers": args.workers, + "top_k": args.top_k, + }, + "results": all_results, + } + + with open(output_path, "w", encoding="utf-8") as f: + json.dump(output_data, f, ensure_ascii=False, indent=2) + + print(f"\nResults saved to: {output_path}") + + # Calculate valid results + success_results = all_results + total_memories = sum(r["memory_count"] for r in success_results) + avg_memories = total_memories / len(success_results) if success_results else 0 + print(f"Average {avg_memories:.1f} memories returned per question") + + +if __name__ == "__main__": + main() diff --git a/evaluation/scripts/run_hotpot_eval.sh b/evaluation/scripts/run_hotpot_eval.sh new file mode 100755 index 000000000..e3e38c6bc --- /dev/null +++ b/evaluation/scripts/run_hotpot_eval.sh @@ -0,0 +1,54 @@ +#!/bin/bash +set -e + +ROOT_DIR=$(cd "$(dirname "$0")/../.." && pwd) +cd "$ROOT_DIR" +export PYTHONPATH="$ROOT_DIR" + +# Common parameters +LIB="fastgpt" +WORKERS=20 +TOPK=7 +ADD_MODE="fine" +SEARCH_MODE="fine" +VERSION_DIR="hotpot_fastgpt_0114" +ASYNC_MODE="sync" +CHAT_MODEL="gpt-4o-mini" +LIMIT=1000 + +# Add / Ingestion +echo "Running hotpot_ingestion.py..." +python -m evaluation.scripts.hotpot.hotpot_ingestion \ + --lib "$LIB" \ + --workers "$WORKERS" \ + --version-dir "$VERSION_DIR" \ + --mode "$ADD_MODE" \ + --async-mode "$ASYNC_MODE" \ + --limit "$LIMIT" + +# #check +# echo "Running hotpot_check_files.py..." +# python -m evaluation.scripts.hotpot.hotpot_check_files \ +# --lib "$LIB" \ +# --version-dir "$VERSION_DIR" \ + +# Search +#echo "Running hotpot_search.py..." +#python -m evaluation.scripts.hotpot.hotpot_search \ +# --lib "$LIB" \ +# --workers "$WORKERS" \ +# --version-dir "$VERSION_DIR" \ +# --top-k "$TOPK" \ +# --search-mode "$SEARCH_MODE" \ +# --limit "$LIMIT" + +# Eval +#echo "Running hotpot_eval.py..." +#python -m evaluation.scripts.hotpot.hotpot_eval \ +# --lib "$LIB" \ +# --version-dir "$VERSION_DIR" \ +# --workers "$WORKERS" \ +# --search-mode "$SEARCH_MODE" \ +# --chat-model "$CHAT_MODEL" + +echo "All scripts completed successfully!" 
diff --git a/evaluation/scripts/run_longbench_v2_eval.sh b/evaluation/scripts/run_longbench_v2_eval.sh index 917c57bfb..b13f87585 100755 --- a/evaluation/scripts/run_longbench_v2_eval.sh +++ b/evaluation/scripts/run_longbench_v2_eval.sh @@ -1,110 +1,54 @@ #!/bin/bash - -# Common parameters for all scripts -LIB="memos-api" -VERSION="long-bench-v2-1208-1556-async" -WORKERS=10 -TOPK=20 -MAX_SAMPLES="" # Empty means all samples -WAIT_INTERVAL=2 # seconds between polls -WAIT_TIMEOUT=900 # seconds per user - -# Parse command line arguments -while [[ $# -gt 0 ]]; do - case $1 in - --lib) - LIB="$2" - shift 2 - ;; - --version) - VERSION="$2" - shift 2 - ;; - --workers) - WORKERS="$2" - shift 2 - ;; - --top_k) - TOPK="$2" - shift 2 - ;; - --max_samples) - MAX_SAMPLES="$2" - shift 2 - ;; - *) - echo "Unknown option: $1" - exit 1 - ;; - esac -done - -# Build max_samples argument -MAX_SAMPLES_ARG="" -if [ -n "$MAX_SAMPLES" ]; then - MAX_SAMPLES_ARG="--max_samples $MAX_SAMPLES" -fi - -echo "Running LongBench v2 evaluation with:" -echo " LIB: $LIB" -echo " VERSION: $VERSION" -echo " WORKERS: $WORKERS" -echo " TOPK: $TOPK" -echo " MAX_SAMPLES: ${MAX_SAMPLES:-all}" -echo "" - -# Step 2: Search -echo "" -echo "==========================================" -echo "Step 2: Running longbench_v2_search.py..." -echo "==========================================" -python scripts/long_bench-v2/longbench_v2_search.py \ - --lib $LIB \ - --version $VERSION \ - --top_k $TOPK \ - --workers $WORKERS \ - $MAX_SAMPLES_ARG - -if [ $? -ne 0 ]; then - echo "Error running longbench_v2_search.py" - exit 1 -fi - -# Step 3: Response Generation -echo "" -echo "==========================================" -echo "Step 3: Running longbench_v2_responses.py..." -echo "==========================================" -python scripts/long_bench-v2/longbench_v2_responses.py \ - --lib $LIB \ - --version $VERSION \ - --workers $WORKERS - -if [ $? -ne 0 ]; then - echo "Error running longbench_v2_responses.py" - exit 1 -fi - -# Step 4: Metrics Calculation -echo "" -echo "==========================================" -echo "Step 4: Running longbench_v2_metric.py..." -echo "==========================================" -python scripts/long_bench-v2/longbench_v2_metric.py \ - --lib $LIB \ - --version $VERSION - -if [ $? -ne 0 ]; then - echo "Error running longbench_v2_metric.py" - exit 1 -fi - -echo "" -echo "==========================================" -echo "All steps completed successfully!" -echo "==========================================" -echo "" -echo "Results are saved in: results/long_bench-v2/$LIB-$VERSION/" -echo " - Search results: ${LIB}_longbench_v2_search_results.json" -echo " - Responses: ${LIB}_longbench_v2_responses.json" -echo " - Metrics: ${LIB}_longbench_v2_metrics.json" +set -e + +ROOT_DIR=$(cd "$(dirname "$0")/../.." && pwd) +cd "$ROOT_DIR" +export PYTHONPATH="$ROOT_DIR" + +# Common parameters +LIB="fastgpt" +WORKERS=5 +TOPK=30 +ADD_MODE="fine" +SEARCH_MODE="fast" +VERSION_DIR="longbench_v2_fastgpt_0114" +ASYNC_MODE="sync" +CHAT_MODEL="gpt-4o-mini" +#CHAT_MODEL="o4-mini" +LIMIT=200 + +# Add / Ingestion +echo "Running longbench_v2_ingestion.py..." +python -m evaluation.scripts.longbench_v2.longbench_v2_ingestion \ + --lib "$LIB" \ + --workers "$WORKERS" \ + --version-dir "$VERSION_DIR" \ + --mode "$ADD_MODE" \ + --async-mode "$ASYNC_MODE" \ + --limit "$LIMIT" + +# #check +# echo "Running longbench_v2_check_files.py..." 
+# python -m evaluation.scripts.longbench_v2.longbench_v2_check_files \ +# --lib "$LIB" \ +# --version-dir "$VERSION_DIR" \ + +# # Search +# echo "Running longbench_v2_search.py..." +# python -m evaluation.scripts.longbench_v2.longbench_v2_search \ +# --lib "$LIB" \ +# --workers "$WORKERS" \ +# --version-dir "$VERSION_DIR" \ +# --top-k "$TOPK" \ +# --mode "$SEARCH_MODE" \ +# --limit "$LIMIT" + +# Eval +# echo "Running longbench_v2_eval.py..." +# python -m evaluation.scripts.longbench_v2.longbench_v2_eval \ +# --lib "$LIB" \ +# --version-dir "$VERSION_DIR" \ +# --workers "$WORKERS" \ +# --chat-model "$CHAT_MODEL" + +#echo "All scripts completed successfully!" diff --git a/evaluation/scripts/run_mmlongbench_eval.sh b/evaluation/scripts/run_mmlongbench_eval.sh new file mode 100755 index 000000000..0eda6565b --- /dev/null +++ b/evaluation/scripts/run_mmlongbench_eval.sh @@ -0,0 +1,53 @@ +#!/bin/bash +set -e + +ROOT_DIR=$(cd "$(dirname "$0")/../.." && pwd) +cd "$ROOT_DIR" +export PYTHONPATH="$ROOT_DIR" + +# Common parameters +LIB="memos-online" +WORKERS=5 +TOPK=15 +ADD_MODE="fine" +SEARCH_MODE="fine" +VERSION_DIR="mmlongbench_fastgpt_0114_03" +ASYNC_MODE="sync" +CHAT_MODEL="gpt-4o-mini" +LIMIT=10 + +# Add / Ingestion +# echo "Running mmlongbench_ingestion.py..." +# python -m evaluation.scripts.mmlongbench.mmlongbench_ingestion \ +# --lib "$LIB" \ +# --workers "$WORKERS" \ +# --version-dir "$VERSION_DIR" \ +# --mode "$ADD_MODE" \ +# --async-mode "$ASYNC_MODE" \ +# # --limit "$LIMIT" + +# #check +# echo "Running mmllongbench_check_files.py..." +# python -m evaluation.scripts.mmlongbench.mmllongbench_check_files \ +# --lib "$LIB" \ +# --version-dir "$VERSION_DIR" \ + +# # Search +# echo "Running mmlongbench_search.py..." +# python -m evaluation.scripts.mmlongbench.mmlongbench_search \ +# --lib "$LIB" \ +# --workers "$WORKERS" \ +# --version-dir "$VERSION_DIR" \ +# --top-k "$TOPK" \ +# --mode "$SEARCH_MODE" +# # --limit "$LIMIT" + +# Eval +echo "Running mmlongbench_eval.py..." +python -m evaluation.scripts.mmlongbench.mmlongbench_eval \ + --lib "$LIB" \ + --version-dir "$VERSION_DIR" \ + --workers "$WORKERS" \ + --chat-model "$CHAT_MODEL" \ + +# echo "All scripts completed successfully!" 
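Editor's sketch (not part of the patch): the *_check_files scripts above combine two MemosApiOnlineClient calls, check_file to read each uploaded file's processing status and upload_file to retry anything reported as PROCESSING_FAILED from its original URL. The sketch below assumes the environment variables the client reads (MEMOS_ONLINE_URL, MEMOS_API_KEY) are configured; the knowledgebase id, URL prefix, and filename-to-file_id map are placeholders for the values normally taken from .env and <lib>_add_results.json.

from evaluation.scripts.utils.client import MemosApiOnlineClient

# Placeholders for MEMOS_KNOWLEDGEBASE_ID_*, the OSS --url-prefix,
# and the filename -> file_id map recorded at ingestion time.
KB_ID = "kb_demo"
URL_PREFIX = "https://example.com/mmlongbench_pdf_files"
ADDED = {"report.pdf": "file_123"}

client = MemosApiOnlineClient()

# Phase 1: query processing status for the known file ids in one batch.
resp = client.check_file(list(ADDED.values()))
data = resp.get("data") if isinstance(resp, dict) else {}
details = data.get("file_detail_list") if isinstance(data, dict) else []
status = {
    str(d["id"]): d.get("status")
    for d in details or []
    if isinstance(d, dict) and d.get("id")
}

# Phase 2: rebuild the original file URL and re-upload anything that failed.
for filename, fid in ADDED.items():
    if status.get(fid) != "PROCESSING_FAILED":
        continue
    up = client.upload_file(KB_ID, f"{URL_PREFIX.rstrip('/')}/{filename}")
    up_data = up.get("data") or []
    first = up_data[0] if up_data and isinstance(up_data[0], dict) else {}
    print(f"reuploaded {filename}: {fid} -> {first.get('id')}")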
diff --git a/evaluation/scripts/utils/client.py b/evaluation/scripts/utils/client.py index 157c3f8ea..673f407da 100644 --- a/evaluation/scripts/utils/client.py +++ b/evaluation/scripts/utils/client.py @@ -1,5 +1,6 @@ import json import os +import re import sys import time import uuid @@ -56,30 +57,22 @@ def __init__(self, enable_graph=False): self.enable_graph = enable_graph def add(self, messages, user_id, timestamp, batch_size=2): - max_retries = 5 for i in range(0, len(messages), batch_size): batch_messages = messages[i : i + batch_size] - for attempt in range(max_retries): - try: - if self.enable_graph: - self.client.add( - messages=batch_messages, - timestamp=timestamp, - user_id=user_id, - enable_graph=True, - ) - else: - self.client.add( - messages=batch_messages, - timestamp=timestamp, - user_id=user_id, - ) - break - except Exception as e: - if attempt < max_retries - 1: - time.sleep(2**attempt) - else: - raise e + if self.enable_graph: + self.client.add( + messages=batch_messages, + timestamp=timestamp, + user_id=user_id, + enable_graph=True, + ) + else: + self.client.add( + messages=batch_messages, + timestamp=timestamp, + user_id=user_id, + infer=False, + ) def search(self, query, user_id, top_k): res = self.client.search( @@ -143,135 +136,170 @@ def string_to_uuid(self, s: str, salt="memobase_client"): class MemosApiClient: - def __init__(self): - self.memos_url = os.getenv("MEMOS_URL") - self.headers = {"Content-Type": "application/json", "Authorization": os.getenv("MEMOS_KEY")} + """Product Add API 封装""" + + def __init__(self, timeout: float = 600.0): + self.base_url = os.getenv("MEMOS_URL") + self.headers = {"Content-Type": "application/json"} + self.timeout = timeout + + def add( + self, + messages, + user_id, + writable_cube_ids: list[str], + source_type: str, + mode: str, + async_mode: str, + ): + """ + 调用 /product/add 接口 + + Args: + messages: 添加记忆信息 + user_id: 用户ID + writable_cube_ids: 可写cube ID列表 + source_type: 来源类型 + mode: 模式 (fine/coarse) + async_mode: 异步模式 (sync/async) + """ + url = f"{self.base_url}/product/add" + + payload = { + "user_id": user_id, + "writable_cube_ids": writable_cube_ids, + "messages": messages, + "info": {"source_type": source_type}, + "mode": mode, + "async_mode": async_mode, + } + + response = requests.post( + url, + data=json.dumps(payload, ensure_ascii=False).encode("utf-8"), + headers=self.headers, + timeout=self.timeout, + ) + + if response.status_code != 200: + raise RuntimeError(f"HTTP {response.status_code}: {response.text}") - def add(self, messages, user_id, conv_id, batch_size: int = 9999): + body = response.json() + if body.get("code") is not None and body.get("code") != 200: + raise RuntimeError(f"BUSINESS ERROR {body.get('code')}: {response.text}") + + return body + + def search(self, query, user_id, readable_cube_ids: list[str], top_k: str, mode: str): """ - messages = [{"role": "assistant", "content": data, "chat_time": date_str}] + 调用 /product/search 接口 + + Args: + query: 搜索查询 + user_id: 用户ID + readable_cube_ids: 可读cube ID列表, 默认为[user_id] + top_k: 返回结果数量 """ - url = f"{self.memos_url}/product/add" - added_memories = [] - for i in range(0, len(messages), batch_size): - batch_messages = messages[i : i + batch_size] - payload = json.dumps( - { - "messages": batch_messages, - "user_id": user_id, - "mem_cube_id": user_id, - "conversation_id": conv_id, - } - ) - response = requests.request("POST", url, data=payload, headers=self.headers) - assert response.status_code == 200, response.text - assert json.loads(response.text)["message"] 
== "Memory added successfully", ( - response.text - ) - added_memories += json.loads(response.text)["data"] - return added_memories - def search(self, query, user_id, top_k): - """Search memories.""" - url = f"{self.memos_url}/product/search" - payload = json.dumps( - { - "query": query, - "user_id": user_id, - "mem_cube_id": user_id, - "conversation_id": "", - "top_k": top_k, - "mode": os.getenv("SEARCH_MODE", "fast"), - "include_preference": True, - "pref_top_k": 6, - }, - ensure_ascii=False, - ) - response = requests.request("POST", url, data=payload, headers=self.headers) - assert response.status_code == 200, response.text - assert json.loads(response.text)["message"] == "Search completed successfully", ( - response.text + url = f"{self.base_url}/product/search" + + if readable_cube_ids is None: + readable_cube_ids = [user_id] + + payload = { + "query": query, + "user_id": user_id, + "readable_cube_ids": readable_cube_ids, + "top_k": top_k, + "mode": mode, + } + + response = requests.post( + url, + data=json.dumps(payload, ensure_ascii=False).encode("utf-8"), + headers=self.headers, + timeout=self.timeout, ) - return json.loads(response.text)["data"] + + if response.status_code != 200: + raise RuntimeError(f"HTTP {response.status_code}: {response.text}") + + return response.json() class MemosApiOnlineClient: def __init__(self): self.memos_url = os.getenv("MEMOS_ONLINE_URL") - self.headers = {"Content-Type": "application/json", "Authorization": os.getenv("MEMOS_KEY")} - - def add(self, messages, user_id, conv_id=None, batch_size: int = 9999): + self.headers = { + "Content-Type": "application/json", + "Authorization": f"Token {os.environ['MEMOS_API_KEY']}", + } + + def add( + self, + messages, + user_id, + writable_cube_ids: list[str], + source_type: str, + mode: str, + async_mode: str, + ): url = f"{self.memos_url}/add/message" - for i in range(0, len(messages), batch_size): - batch_messages = messages[i : i + batch_size] - payload = json.dumps( - { - "messages": batch_messages, - "user_id": user_id, - "conversation_id": conv_id, - } - ) - - max_retries = 5 - for attempt in range(max_retries): - try: - response = requests.request("POST", url, data=payload, headers=self.headers) - assert response.status_code == 200, response.text - assert json.loads(response.text)["message"] == "ok", response.text - break - except Exception as e: - if attempt < max_retries - 1: - time.sleep(2**attempt) - else: - raise e - - def search(self, query, user_id, top_k): - """Search memories.""" - url = f"{self.memos_url}/search/memory" payload = json.dumps( { - "query": query, "user_id": user_id, - "memory_limit_number": top_k, - "mode": os.getenv("SEARCH_MODE", "fast"), - "include_preference": True, - "pref_top_k": 6, + "conversation_id": user_id, + "messages": messages, + "writable_cube_ids": writable_cube_ids, + "info": {"source_type": source_type}, + "mode": mode, + "async_mode": async_mode, } ) - max_retries = 5 - for attempt in range(max_retries): - try: - response = requests.request("POST", url, data=payload, headers=self.headers) - assert response.status_code == 200, response.text - assert json.loads(response.text)["message"] == "ok", response.text - text_mem_res = json.loads(response.text)["data"]["memory_detail_list"] - pref_mem_res = json.loads(response.text)["data"]["preference_detail_list"] - preference_note = json.loads(response.text)["data"]["preference_note"] - for i in text_mem_res: - i.update({"memory": i.pop("memory_value")}) - explicit_pref_string = "Explicit Preference:" - 
implicit_pref_string = "\n\nImplicit Preference:" - explicit_idx = 0 - implicit_idx = 0 - for pref in pref_mem_res: - if pref["preference_type"] == "explicit_preference": - explicit_pref_string += f"\n{explicit_idx + 1}. {pref['preference']}" - explicit_idx += 1 - if pref["preference_type"] == "implicit_preference": - implicit_pref_string += f"\n{implicit_idx + 1}. {pref['preference']}" - implicit_idx += 1 - - return { - "text_mem": [{"memories": text_mem_res}], - "pref_string": explicit_pref_string + implicit_pref_string + preference_note, + response = requests.request("POST", url, data=payload, headers=self.headers) + assert response.status_code == 200, response.text + assert json.loads(response.text)["message"] == "ok", response.text + return response.json() + + def search(self, query: str, user_id: str, top_k: int, mode: str, knowledgebase_ids: list[str]): + """Search memories.""" + url = f"{self.memos_url}/search/memory" + data = { + "query": query, + "user_id": user_id, + "memory_limit_number": top_k, + "knowledgebase_ids": knowledgebase_ids, + "mode": mode, + } + + resp = requests.post(url, headers=self.headers, json=data, timeout=60) + resp.raise_for_status() + return resp.json() + + def upload_file(self, knowledgebase_id: str, file_url: str): + """Upload file.""" + url = f"{self.memos_url}/add/knowledgebase-file" + data = { + "knowledgebase_id": knowledgebase_id, + "file": [ + { + "content": file_url, } + ], + } + + resp = requests.post(url, headers=self.headers, json=data, timeout=60) + resp.raise_for_status() + return resp.json() - except Exception as e: - if attempt < max_retries - 1: - time.sleep(2**attempt) - else: - raise e + def check_file(self, file_ids: list[str]): + """Check file state.""" + url = f"{self.memos_url}/get/knowledgebase-file" + data = {"file_ids": file_ids} + resp = requests.post(url, headers=self.headers, json=data, timeout=60) + resp.raise_for_status() + return resp.json() class SupermemoryClient: @@ -280,40 +308,85 @@ def __init__(self): self.client = Supermemory(api_key=os.getenv("SUPERMEMORY_API_KEY")) - def add(self, messages, user_id): - content = "\n".join( - [f"{msg['chat_time']} {msg['role']}: {msg['content']}" for msg in messages] - ) - max_retries = 5 - for attempt in range(max_retries): - try: - self.client.memories.add(content=content, container_tag=user_id) - break - except Exception as e: - if attempt < max_retries - 1: - time.sleep(2**attempt) - else: - raise e + self.api_key = os.getenv("SUPERMEMORY_API_KEY") + if not self.api_key: + raise ValueError( + "SUPERMEMORY_API_KEY environment variable is not set. Please set it in your .env file or environment." 
+ ) + self.add_url = "https://api.supermemory.ai/v3/documents" + self.search_url = "https://api.supermemory.ai/v3/search" + + def _sanitize_tag(self, s: str) -> str: + t = str(s).strip() + t = os.path.splitext(t)[0] + t = t.replace(" ", "_") + t = re.sub(r"[^A-Za-z0-9_-]", "_", t) + t = re.sub(r"[_-]+", "_", t) + t = t.strip("_") + t = t.lower() + if not re.match(r"^[a-z0-9]", t or ""): + t = f"tag_{t}" if t else "tag_default" + return t + + def add( + self, content: str | None = None, user_id: str | None = None, messages: list | None = None + ): + if messages: + content = "\n".join( + f"{msg.get('chat_time', '')} {msg.get('role', '')}: {msg.get('content', '')}" + for msg in messages + ) - def search(self, query, user_id, top_k): - max_retries = 10 - for attempt in range(max_retries): - try: - results = self.client.search.memories( - q=query, - container_tag=user_id, - threshold=0, - rerank=True, - rewrite_query=True, - limit=top_k, - ) - context = "\n\n".join([r.memory for r in results.results]) - return context - except Exception as e: - if attempt < max_retries - 1: - time.sleep(2**attempt) - else: - raise e + max_retries = 5 + for attempt in range(max_retries): + try: + self.client.memories.add(content=content, container_tag=user_id) + break + except Exception as e: + if attempt < max_retries - 1: + time.sleep(2**attempt) + else: + raise e + return + + payload = { + "content": content, + "raw": content, + "containerTag": self._sanitize_tag(user_id), + } + + headers = { + "Authorization": f"Bearer {self.api_key}", + "Content-Type": "application/json", + } + + resp = requests.post(self.add_url, json=payload, headers=headers) + resp.raise_for_status() + return resp.json() + + def search(self, query: str, user_id: str, top_k: int): + payload = { + "q": query, + "limit": top_k, + "containerTags": [self._sanitize_tag(user_id)], + "rerank": True, + } + + headers = { + "Authorization": f"Bearer {self.api_key}", + "Content-Type": "application/json", + } + resp = requests.post(self.search_url, json=payload, headers=headers) + resp.raise_for_status() + data = resp.json() + + chunk_list = [] + res = [entry.get("chunks") for entry in data.get("results", [])] + for chunks in res: + for chunk in chunks: + chunk_list.append(chunk["content"]) + + return chunk_list class MemuClient: @@ -354,12 +427,110 @@ def wait_for_completion(self, task_id): time.sleep(2) +class FastGPTClient: + def __init__(self): + self.base_url = os.getenv("FASTGPT_BASE_URL") + self.api_key = os.getenv("FASTGPT_API_KEY") + + def create_dataset(self, dataset_name: str): + url = f"{self.base_url}/core/dataset/create" + headers = { + "Authorization": f"Bearer {self.api_key}", + "Content-Type": "application/json", + } + data = { + "name": dataset_name, + } + resp = requests.post(url, headers=headers, json=data, timeout=30) + resp.raise_for_status() + dataset_id = resp.json()["data"] + return dataset_id + + def delete_dataset(self, dataset_id: str): + url = f"{self.base_url}/core/dataset/delete?id={dataset_id}" + headers = {"Authorization": f"Bearer {self.api_key}"} + resp = requests.delete(url, headers=headers, timeout=30) + resp.raise_for_status() + return resp.json() + + def add_content(self, dataset_id: str, content: str, collection_name: str): + url = f"{self.base_url}/core/dataset/collection/create/text" + headers = { + "Authorization": f"Bearer {self.api_key}", + "Content-Type": "application/json", + } + data = { + "text": content, + "datasetId": dataset_id, + "name": collection_name, + "trainingType": "chunk", + 
"chunkSettingMode": "auto", + } + resp = requests.post(url, headers=headers, json=data, timeout=60) + resp.raise_for_status() + return resp.json() + + def upload_file(self, dataset_id: str, file_url: str): + url = f"{self.base_url}/proApi/core/dataset/collection/create/externalFileUrl" + headers = { + "Authorization": f"Bearer {self.api_key}", + "Content-Type": "application/json", + } + data = { + "externalFileUrl": file_url, + "externalFileId": file_url, + "datasetId": dataset_id, + "trainingType": "chunk", + "chunkSize": 512, + } + resp = requests.post(url, headers=headers, json=data, timeout=60) + resp.raise_for_status() + return resp.json() + + def batch_add_content(self, collection_id: str, data: list[str]): + url = f"{self.base_url}/core/dataset/data/pushData" + headers = { + "Authorization": f"Bearer {self.api_key}", + "Content-Type": "application/json", + } + data = {"collectionId": collection_id, "data": [{"q": d} for d in data]} + resp = requests.post(url, headers=headers, json=data, timeout=30) + resp.raise_for_status() + return resp.json() + + def search(self, dataset_id: str, query: str, top_k: int): + url = f"{self.base_url}/core/dataset/searchTest" + headers = { + "Authorization": f"Bearer {self.api_key}", + "Content-Type": "application/json", + } + data = {"datasetId": dataset_id, "text": query, "searchMode": "embedding"} + resp = requests.post(url, headers=headers, json=data, timeout=30) + resp.raise_for_status() + + result = resp.json() + data_list = result["data"]["list"] + return data_list + + def create_collection(self, dataset_id: str, collection_name: str): + url = f"{self.base_url}/core/dataset/collection/create" + headers = { + "Authorization": f"Bearer {self.api_key}", + "Content-Type": "application/json", + } + data = {"datasetId": dataset_id, "name": collection_name, "type": "virtual"} + resp = requests.post(url, headers=headers, json=data, timeout=30) + resp.raise_for_status() + collection_id = resp.json()["data"] + return collection_id + + if __name__ == "__main__": messages = [ {"role": "user", "content": "杭州西湖有什么好玩的"}, {"role": "assistant", "content": "杭州西湖有好多松鼠,还有断桥"}, ] - user_id = "test_user" + user_id = "lme_exper_user_default_499" iso_date = "2023-05-01T00:00:00.000Z" timestamp = 1682899200 query = "杭州西湖有什么" @@ -369,6 +540,6 @@ def wait_for_completion(self, task_id): client = MemosApiClient() for m in messages: m["created_at"] = iso_date - client.add(messages, user_id, user_id) - memories = client.search(query, user_id, top_k) + client.add(messages, user_id, [user_id], "extreme_multimodal", "fine", "async") + memories = client.search(query, user_id, [user_id], top_k, "fast") print(memories) diff --git a/evaluation/scripts/utils/eval_score.py b/evaluation/scripts/utils/eval_score.py new file mode 100644 index 000000000..02ef6eb53 --- /dev/null +++ b/evaluation/scripts/utils/eval_score.py @@ -0,0 +1,246 @@ +import re + +from collections import defaultdict +from math import isclose + + +def levenshtein_distance(s1, s2): + if len(s1) > len(s2): + s1, s2 = s2, s1 + + distances = range(len(s1) + 1) + for i2, c2 in enumerate(s2): + distances_ = [i2 + 1] + for i1, c1 in enumerate(s1): + if c1 == c2: + distances_.append(distances[i1]) + else: + distances_.append(1 + min((distances[i1], distances[i1 + 1], distances_[-1]))) + distances = distances_ + return distances[-1] + + +def anls_compute(groundtruth, prediction, threshold=0.5): + dist = levenshtein_distance(groundtruth, prediction) + length = max(len(groundtruth.upper()), len(prediction.upper())) + value = 
+def anls_compute(groundtruth, prediction, threshold=0.5):
+    dist = levenshtein_distance(groundtruth, prediction)
+    length = max(len(groundtruth.upper()), len(prediction.upper()))
+    value = 0.0 if length == 0 else float(dist) / float(length)
+    anls = 1.0 - value
+    if anls <= threshold:
+        anls = 0.0
+    return anls
+
+
+def is_float_equal(
+    reference, prediction, include_percentage: bool = False, is_close: bool = False
+) -> bool:
+    def get_precision(gt_ans: float) -> int:
+        precision = 3
+        if "." in str(gt_ans):
+            precision = len(str(gt_ans).split(".")[-1])
+        return precision
+
+    reference = float(str(reference).strip().rstrip("%").strip())
+    try:
+        prediction = float(str(prediction).strip().rstrip("%").strip())
+    except Exception:
+        return False
+
+    gt_result = [reference / 100, reference, reference * 100] if include_percentage else [reference]
+    for item in gt_result:
+        try:
+            if is_close and isclose(item, prediction, rel_tol=0.01):
+                return True
+            precision = max(min(get_precision(prediction), get_precision(item)), 2)
+            if round(prediction, precision) == round(item, precision):
+                return True
+        except Exception:
+            continue
+    return False
+
+
+def get_clean_string(s):
+    s = str(s).lower().strip()
+
+    for suffix in ["mile", "miles", "million"]:
+        if s.endswith(suffix):
+            s = s[: -len(suffix)].strip()
+
+    s = re.sub(r"\s*\([^)]*\)", "", s).strip()
+    s = re.sub(r"^['\"]|['\"]$", "", s).strip()
+    s = s.lstrip("$").rstrip("%").strip()
+
+    return s
+
+
+def is_exact_match(s):
+    flag = False
+    # Website
+    if "https://" in s:
+        flag = True
+    # code file
+    if s.endswith((".py", ".ipynb")) or s.startswith("page"):
+        flag = True
+    # telephone number
+    if re.fullmatch(r"\b\d+(-\d+|\s\d+)?\b", s):
+        flag = True
+    # time
+    if "a.m." in s or "p.m." in s:
+        flag = True
+    # YYYY-MM-DD
+    if re.fullmatch(r"\b\d{4}[-\s]\d{2}[-\s]\d{2}\b", s):
+        flag = True
+    # YYYY-MM
+    if re.fullmatch(r"\b\d{4}[-\s]\d{2}\b", s):
+        flag = True
+    # Email address
+    if re.fullmatch(r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}", s):
+        flag = True
+    return flag
+
+
+def isfloat(num):
+    try:
+        float(num)
+        return True
+    except ValueError:
+        return False
+
+
+def eval_score(gt, pred, answer_type):
+    if answer_type == "Int":
+        try:
+            gt, pred = int(gt), int(float(pred))
+        except Exception:
+            pred = ""
+        score = gt == pred
+    elif answer_type == "Float":
+        try:
+            gt = float(get_clean_string(str(gt)))
+            pred = float(get_clean_string(str(pred)))
+        except Exception:
+            pred = ""
+        score = is_float_equal(gt, pred, include_percentage=True, is_close=True)
+    elif answer_type in ["Str", "None"]:
+        gt = get_clean_string(gt)
+        pred = get_clean_string(pred)
+        score = gt == pred if is_exact_match(gt) else anls_compute(gt, pred)
+    else:
+        if isinstance(gt, str) and gt.startswith("["):
+            gt = eval(gt)
+        if not isinstance(gt, list):
+            gt = [gt]
+        if isinstance(pred, str) and pred.startswith("["):
+            pred = eval(pred)
+        if not isinstance(pred, list):
+            pred = [pred]
+        print(len(gt), len(pred))
+        if len(gt) != len(pred):
+            score = 0.0
+        else:
+            gt = sorted([get_clean_string(a) for a in gt])
+            pred = sorted([get_clean_string(a) for a in pred])
+            print(gt, pred)
+            if isfloat(gt[0]) or is_exact_match(gt[0]):
+                score = "-".join(gt) == "-".join(pred)
+            else:
+                score = min(
+                    [anls_compute(gt_v, pred_v) for gt_v, pred_v in zip(gt, pred, strict=False)]
+                )
+
+    return float(score)
+
+
+def eval_acc_and_f1(samples):
+    evaluated_samples = [sample for sample in samples if "score" in sample]
+    if not evaluated_samples:
+        return 0.0, 0.0
+
+    acc = sum([sample["score"] for sample in evaluated_samples]) / len(evaluated_samples)
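+    # Recall averages scores over questions whose gold answer is answerable; precision
+    # divides the same sum by the number of predictions that are not "Not answerable".
+    # F1 is their harmonic mean.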
+    try:
+        recall = sum(
+            [
+                sample["score"]
+                for sample in evaluated_samples
+                if sample["answer"] != "Not answerable"
+            ]
+        ) / len([sample for sample in evaluated_samples if sample["answer"] != "Not answerable"])
+        precision = sum(
+            [
+                sample["score"]
+                for sample in evaluated_samples
+                if sample["answer"] != "Not answerable"
+            ]
+        ) / len([sample for sample in evaluated_samples if sample["pred"] != "Not answerable"])
+        f1 = 2 * recall * precision / (recall + precision) if (recall + precision) > 0.0 else 0.0
+    except Exception:
+        f1 = 0.0
+
+    return acc, f1
+
+
+def show_results(samples, show_path=None):
+    for sample in samples:
+        sample["evidence_pages"] = eval(sample["evidence_pages"])
+        sample["evidence_sources"] = eval(sample["evidence_sources"])
+
+    with open(show_path, "w") as f:
+        acc, f1 = eval_acc_and_f1(samples)
+        f.write(f"Overall Acc: {acc} | Question Number: {len(samples)}\n")
+        f.write(f"Overall F1-score: {f1} | Question Number: {len(samples)}\n")
+        f.write("-----------------------\n")
+
+        acc_single_page, _ = eval_acc_and_f1(
+            [sample for sample in samples if len(sample["evidence_pages"]) == 1]
+        )
+        acc_multi_page, _ = eval_acc_and_f1(
+            [
+                sample
+                for sample in samples
+                if len(sample["evidence_pages"]) != 1 and sample["answer"] != "Not answerable"
+            ]
+        )
+        acc_neg, _ = eval_acc_and_f1(
+            [sample for sample in samples if sample["answer"] == "Not answerable"]
+        )
+
+        f.write(
+            "Single-page | Accuracy: {} | Question Number: {}\n".format(
+                acc_single_page,
+                len([sample for sample in samples if len(sample["evidence_pages"]) == 1]),
+            )
+        )
+        f.write(
+            "Cross-page | Accuracy: {} | Question Number: {}\n".format(
+                acc_multi_page,
+                len(
+                    [
+                        sample
+                        for sample in samples
+                        if len(sample["evidence_pages"]) != 1
+                        and sample["answer"] != "Not answerable"
+                    ]
+                ),
+            )
+        )
+        f.write(
+            "Unanswerable | Accuracy: {} | Question Number: {}\n".format(
+                acc_neg, len([sample for sample in samples if sample["answer"] == "Not answerable"])
+            )
+        )
+        f.write("-----------------------\n")
+
+        source_sample_dict, document_type_dict = defaultdict(list), defaultdict(list)
+        for sample in samples:
+            for answer_source in sample["evidence_sources"]:
+                source_sample_dict[answer_source].append(sample)
+            document_type_dict[sample["doc_type"]].append(sample)
+        for type, sub_samples in source_sample_dict.items():
+            f.write(
+                f"Evidence Sources: {type} | Accuracy: {eval_acc_and_f1(sub_samples)[0]} | Question Number: {len(sub_samples)}\n"
+            )
+
+        f.write("-----------------------\n")
+        for type, sub_samples in document_type_dict.items():
+            f.write(
+                f"Document Type: {type} | Accuracy: {eval_acc_and_f1(sub_samples)[0]} | Question Number: {len(sub_samples)}\n"
+            )
diff --git a/evaluation/scripts/utils/extract_answer.py b/evaluation/scripts/utils/extract_answer.py
new file mode 100644
index 000000000..e527d4b97
--- /dev/null
+++ b/evaluation/scripts/utils/extract_answer.py
@@ -0,0 +1,58 @@
+import os
+
+from pathlib import Path
+
+import openai
+
+from dotenv import load_dotenv
+
+
+load_dotenv()
+
+client = openai.Client(
+    api_key=os.getenv("OPENAI_API_KEY", "sk-xxxxx"),
+    base_url=os.getenv("OPENAI_API_BASE", "https://api.openai.com/v1"),
+)
+
+PROMPT_PATH = Path("evaluation/scripts/utils/prompt_for_answer_extraction.md")
+with open(PROMPT_PATH, encoding="utf-8") as f:
+    EXTRACTION_PROMPT = f.read()
+
+
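+# Asks the model to pull a short, typed answer out of a free-form analysis using the
+# extraction prompt above; decoding is deterministic (temperature 0, 256 max tokens).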
+def extract_answer(question: str, output: str, model_name: str = "gpt-4o-mini") -> str:
+    resp = client.chat.completions.create(
+        model=model_name,
+        messages=[
+            {"role": "user", "content": EXTRACTION_PROMPT},
+            {"role": "assistant", "content": f"\n\nQuestion:{question}\nAnalysis:{output}\n"},
+        ],
+        temperature=0.0,
+        max_tokens=256,
+        top_p=1,
+        frequency_penalty=0,
+        presence_penalty=0,
+    )
+    content = resp.choices[0].message.content or ""
+    return content
+
+
+def parse_extracted_answer(extracted_res: str, fallback_output: str) -> str:
+    try:
+        head = extracted_res.split("Answer format:")[0]
+        ans = head.split("Extracted answer:")[1].strip()
+        if ans:
+            return ans
+    except Exception:
+        pass
+    text = (fallback_output or "").strip()
+    low = text.lower()
+    if " yes" in low or low.startswith("yes"):
+        return "yes"
+    if " no" in low or low.startswith("no"):
+        return "no"
+    for sep in ["\n", ". ", ".", "?", "!"]:
+        if sep in text:
+            cand = text.split(sep)[0].strip()
+            if cand:
+                return cand
+    return text
diff --git a/evaluation/scripts/utils/metrics.py b/evaluation/scripts/utils/metrics.py
new file mode 100644
index 000000000..135a60cec
--- /dev/null
+++ b/evaluation/scripts/utils/metrics.py
@@ -0,0 +1,56 @@
+import threading
+
+
+class Metrics:
+    def __init__(self):
+        self.times_ms: list[float] = []
+        self.success_count = 0
+        self.fail_count = 0
+        self.errors = {}
+        self.lock = threading.Lock()
+
+    def record(self, duration_s: float, success: bool, error_msg: str | None = None):
+        ms = duration_s * 1000.0
+        with self.lock:
+            if success:
+                self.times_ms.append(ms)
+                self.success_count += 1
+            else:
+                self.fail_count += 1
+                if error_msg:
+                    short_err = error_msg[:200] if len(error_msg) > 200 else error_msg
+                    self.errors[short_err] = self.errors.get(short_err, 0) + 1
+
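+    # Percentiles are nearest-rank over the sorted latencies and "std" is the sample
+    # standard deviation; counts and error tallies are returned even with no timings.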
+    def summary(self) -> dict:
+        with self.lock:
+            if not self.times_ms:
+                return {
+                    "stats": {},
+                    "counts": {"success": self.success_count, "failed": self.fail_count},
+                    "errors": dict(self.errors),
+                }
+            sorted_times = sorted(self.times_ms)
+            n = len(sorted_times)
+
+            def percentile(p: int):
+                if n == 1:
+                    return sorted_times[0]
+                k = max(0, min(n - 1, round((p / 100) * (n - 1))))
+                return sorted_times[k]
+
+            mean = sum(sorted_times) / n
+            variance = sum((x - mean) ** 2 for x in sorted_times) / (n - 1) if n > 1 else 0.0
+            return {
+                "stats": {
+                    "count": n,
+                    "mean": mean,
+                    "median": percentile(50),
+                    "min": sorted_times[0],
+                    "max": sorted_times[-1],
+                    "p95": percentile(95),
+                    "p99": percentile(99),
+                    "std": variance**0.5,
+                },
+                "counts": {"success": self.success_count, "failed": self.fail_count},
+                "errors": dict(self.errors),
+            }
diff --git a/evaluation/scripts/utils/prompt_for_answer_extraction.md b/evaluation/scripts/utils/prompt_for_answer_extraction.md
new file mode 100644
index 000000000..a309c0935
--- /dev/null
+++ b/evaluation/scripts/utils/prompt_for_answer_extraction.md
@@ -0,0 +1,35 @@
+Given the question and analysis, you are tasked to extract answers with required formats from the free-form analysis.
+- Your extracted answers should be one of the following formats: (1) Integer, (2) Float, (3) String and (4) List. If you find from the analysis that the question cannot be answered from the given documents, type "Not answerable". Exception: If the analysis only tells you that it cannot read/understand the images or documents, type "Fail to answer".
+- Please make your response as concise as possible. Also note that your response should be formatted as below:
+```
+Extracted answer: [answer]
+Answer format: [answer format]
+```
+
+Please read the following example, then extract the answer from the model response and type it at the end of the prompt.
+
+---
+Question: List the primary questions asked about the services in this report.
+Analysis: The primary questions asked about the services in the report for The Limes Residential Home are:\n\n1. Is the service safe?\n2. Is the service effective?\n3. Is the service caring?\n4. Is the service responsive?\n5. Is the service well-led?
+Extracted answer: ['Is the service safe?', 'Is the service effective?', 'Is the service caring?', 'Is the service responsive?', 'Is the service well-led?']
+Answer format: List
+
+---
+Question: How many regulations of the HSCA 2008 are breached in all according to this report?
+Analysis: According to the report, the provider breached 10 Health and Social Care Act 2008 (Regulated Activities) Regulations in total. Here are the specifics:\n\n1. Regulation 13: Safeguarding service users from abuse and improper treatment\n2. Regulation 12: Safe care and treatment\n3. Regulation 18: Staffing\n4. Regulation 11: Need for consent\n5. Regulation 10: Dignity and respect\n6. Regulation 9: Person-centred care\n7. Regulation 17: Good governance\n8. Regulation 18 (CQC Registration Regulations 2009): Notification of other incidents\n9. Regulation 18: Failure to maintain an accurate and up-to-date care plan\n10. Regulation 11: Failure to implement the Mental Capacity Act 2005 code of practice effectively\n\nThese breaches involve issues concerning staffing, safeguarding, medicines management, dignity and respect, consent, care planning, governance, and failure to notify the CQC of incidents.
+Extracted answer: 10
+Answer format: Integer
+
+---
+Question: According to the survey, what is the percentage of Chinese who are paying more or about the same attention to politics after Trump's election?
+Analysis: The survey provided does not specify the percentage of Chinese individuals specifically who are paying more or about the same attention to politics after Trump's election. The report focuses primarily on American demographics and does not include specific details about the Chinese population in relation to this question. If you need information about a different demographic or a summary of the findings from the American demographic, I can certainly help with that!
+Extracted answer: Not answerable
+Answer format: String
+
+---
+Question: How many quotations from male respondents over 50 years old are included in this report?
+Analysis: The image you've provided appears to be a screenshot of a document with multiple charts. However, the text is too small and blurry to read accurately. If you can provide a clearer image or more context, I might be able to help you with your question.
+Extracted answer: Fail to answer
+Answer format: String
+
+---
diff --git a/evaluation/scripts/utils/prompts.py b/evaluation/scripts/utils/prompts.py
index 32e6d6729..ba5f5db8f 100644
--- a/evaluation/scripts/utils/prompts.py
+++ b/evaluation/scripts/utils/prompts.py
@@ -65,6 +65,55 @@
 {context}
 """
+MMLONGBENCH_ANSWER_PROMPT = """
+    You are a helpful assistant that can answer questions based on the provided memories and images.
+
+    {memories}
+
+    Read the above memories and answer this question.
+    Please make your answer as concise as possible.
+"""
+
+LONGBENCH_V2_ANSWER_PROMPT = """
+Please read the following retrieved text chunks and answer the question below.
+
+
+$DOC$
+
+
+What is the correct answer to this question: $Q$
+Choices:
+(A) $C_A$
+(B) $C_B$
+(C) $C_C$
+(D) $C_D$
+
+Format your response as follows: "The correct answer is (insert answer here)".
+"""
+
+
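+# Intended to be filled with the retrieved evidence and question (e.g. via str.format
+# on {question} and {context}); the model should return only the short final answer.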
+HOTPOT_ANSWER_PROMPT = """
+You are answering a question from the HotpotQA dataset.
+
+The question may require multi-hop reasoning across multiple supporting facts.
+Carefully read the provided context and identify the relevant evidence.
+Reason step by step to connect the facts and determine the correct answer.
+
+Important instructions:
+- Use only the information provided in the context.
+- Perform multi-step reasoning internally if needed.
+- The final answer must be a short factual answer (e.g., a name, place, date, or entity).
+- Do NOT include explanations, reasoning steps, or citations in the final output.
+
+Question:
+{question}
+
+Context:
+{context}
+
+Final Answer:
+
+"""
 
 ZEP_CONTEXT_TEMPLATE = """
 FACTS and ENTITIES represent relevant context to the current conversation.