2 changes: 1 addition & 1 deletion .gitignore
@@ -11,7 +11,7 @@ tmp/
**settings.json**
evaluation/*tmp/
evaluation/results
evaluation/.env
.env
!evaluation/configs-example/*.json
evaluation/configs/*
**tree_textual_memory_locomo**
Empty file.
78 changes: 78 additions & 0 deletions evaluation/scripts/hotpot/data_loader.py
@@ -0,0 +1,78 @@
import json

from pathlib import Path

from datasets import load_dataset


def load_hotpot_data(data_dir: Path | str) -> list[dict]:
"""
Load HotpotQA dataset.
If dev_distractor_gold.json exists in data_dir, load it.
Otherwise, download from Hugging Face, convert to standard format, save, and load.
"""
data_dir = Path(data_dir)
data_dir.mkdir(parents=True, exist_ok=True)
file_path = data_dir / "dev_distractor_gold.json"

if file_path.exists():
print(f"Loading local dataset from {file_path}")
try:
with open(file_path, encoding="utf-8") as f:
return json.load(f)
except Exception as e:
print(f"Failed to load local file: {e}. Re-downloading...")

print("Downloading HotpotQA dataset from Hugging Face...")
try:
dataset = load_dataset(
"hotpotqa/hotpot_qa", "distractor", split="validation", trust_remote_code=True
)
except Exception as e:
print(f"Failed to download dataset: {e}")
raise

print(f"Processing and saving dataset to {file_path}...")
items = []
for item in dataset:
        # Convert HF format to the standard HotpotQA format
# ID
qid = item.get("id") or item.get("_id")

# Supporting Facts
sp = item.get("supporting_facts")
if isinstance(sp, dict):
sp_titles = sp.get("title", [])
sp_sent_ids = sp.get("sent_id", [])
sp_list = list(zip(sp_titles, sp_sent_ids, strict=False))
else:
sp_list = sp or []

# Context
ctx = item.get("context")
if isinstance(ctx, dict):
ctx_titles = ctx.get("title", [])
ctx_sentences = ctx.get("sentences", [])
ctx_list = list(zip(ctx_titles, ctx_sentences, strict=False))
else:
ctx_list = ctx or []

new_item = {
"_id": qid,
"question": item.get("question"),
"answer": item.get("answer"),
"supporting_facts": sp_list,
"context": ctx_list,
"type": item.get("type"),
"level": item.get("level"),
}
items.append(new_item)

try:
with open(file_path, "w", encoding="utf-8") as f:
json.dump(items, f, ensure_ascii=False, indent=2)
print(f"Saved {len(items)} items to {file_path}")
except Exception as e:
print(f"Failed to save dataset: {e}")

return items
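Reviewer note: the snippet below is a minimal, untested sketch of how the loader can be exercised from the repo root; the data directory and module path follow the conventions visible elsewhere in this PR and are assumptions, not part of the change.

# Minimal usage sketch for load_hotpot_data (paths are illustrative assumptions)
from pathlib import Path

from evaluation.scripts.hotpot.data_loader import load_hotpot_data

items = load_hotpot_data(Path("evaluation/data/hotpot"))
print(f"Loaded {len(items)} questions")
print(items[0]["question"], "->", items[0]["answer"])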
245 changes: 245 additions & 0 deletions evaluation/scripts/hotpot/hotpot_check_files.py
@@ -0,0 +1,245 @@
import argparse
import json
import os

from pathlib import Path

from dotenv import load_dotenv
from tqdm import tqdm

from evaluation.scripts.utils.client import MemosApiOnlineClient


load_dotenv()
memos_knowledgebase_id = os.getenv("MEMOS_KNOWLEDGEBASE_ID_HOTPOT")


# Load user_id -> file_id mapping from added_records.json
def _load_added_ids(records_path: Path) -> dict[str, str | None]:
if not records_path.exists():
return {}

try:
obj = json.loads(records_path.read_text(encoding="utf-8"))
added = obj.get("added") if isinstance(obj, dict) else None
if isinstance(added, dict):
return {str(k): (str(v) if v is not None else None) for k, v in added.items()}
except Exception:
return {}

return {}
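# Reviewer note (inferred, not a documented schema): added_records.json is expected to look
# roughly like {"added": {"<user_id>": "<file_id>" or null, ...}, "perf": {...}}; only the
# "added" mapping is read here, and other top-level keys are preserved when main() rewrites
# the file after re-uploads.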


def _check_file_status(
client: MemosApiOnlineClient, file_ids: list[str], batch_size: int
) -> dict[str, dict[str, str | None]]:
"""
Phase 1: Query file processing status for given file_ids in batches.
Returns a dict: file_id -> {name, size, status}.
"""
file_status: dict[str, dict[str, str | None]] = {}

for i in tqdm(range(0, len(file_ids), batch_size), desc="Checking files"):
batch = file_ids[i : i + batch_size]
try:
resp = client.check_file(batch)
except Exception as e:
print(f"[Check] error for batch starting at {i}: {e}")
continue

if not isinstance(resp, dict):
continue

data = resp.get("data") or {}
details = data.get("file_detail_list") or []

for item in details:
if not isinstance(item, dict):
continue
fid = item.get("id")
if not fid:
continue
file_status[str(fid)] = {
"name": item.get("name"),
"size": item.get("size"),
"status": item.get("status"),
}

return file_status


def _reupload_failed_files(
client: MemosApiOnlineClient,
file_status: dict[str, dict[str, str | None]],
added_ids: dict[str, str | None],
url_prefix: str,
) -> list[dict[str, str | None]]:
"""
    Phase 2: Re-upload files whose status == PROCESSING_FAILED.
The file URL is built using user_id and url_prefix: <prefix>/<user_id>_context.txt.
Returns a list of per-file results for auditing.
"""
fid_to_user: dict[str, str] = {}
for uid, fid in added_ids.items():
if fid:
fid_to_user[str(fid)] = str(uid)

reupload_results: list[dict[str, str | None]] = []
failed_ids = [
fid for fid, info in file_status.items() if (info.get("status") == "PROCESSING_FAILED")
]

for fid in tqdm(failed_ids, desc="Reuploading failed files"):
uid = fid_to_user.get(fid)
if not uid:
reupload_results.append(
{
"old_file_id": fid,
"user_id": None,
"new_file_id": None,
"ok": "false",
"error": "user_id_not_found",
}
)
continue

file_url = f"{url_prefix.rstrip('/')}/{uid}_context.txt"
try:
resp = client.upload_file(memos_knowledgebase_id or "", file_url)
new_id = None
if isinstance(resp, dict):
data = resp.get("data") or {}
if isinstance(data, list) and data:
first = data[0] if isinstance(data[0], dict) else {}
new_id = str(first.get("id")) if first.get("id") else None
reupload_results.append(
{
"old_file_id": fid,
"user_id": uid,
"new_file_id": new_id,
"ok": "true",
"error": None,
}
)
except Exception as e:
reupload_results.append(
{
"old_file_id": fid,
"user_id": uid,
"new_file_id": None,
"ok": "false",
"error": str(e),
}
)

return reupload_results


def main(argv: list[str] | None = None) -> None:
parser = argparse.ArgumentParser(description="Check HotpotQA memos-online file status.")
parser.add_argument(
"--lib",
type=str,
default="memos-online",
)
parser.add_argument("--version-dir", "-v", default=None, help="Version directory name")
parser.add_argument("--batch-size", type=int, default=50)
parser.add_argument("--output", type=str, default=None, help="输出JSON文件路径")
parser.add_argument(
"--url-prefix",
"-u",
default="https://memos-knowledge-base-file-pre.oss-cn-shanghai.aliyuncs.com/hotpot_text_files/",
help="文件URL前缀",
)

args = parser.parse_args(argv)

if args.lib != "memos-online":
print(f"Only memos-online is supported, got lib={args.lib}")
return

output_dir = Path("evaluation/data/hotpot")
if args.version_dir:
output_dir = output_dir / args.version_dir
output_dir.mkdir(parents=True, exist_ok=True)

records_path = output_dir / f"{args.lib}_added_records.json"
print(f"[Check] loading records from {records_path}")

added_ids = _load_added_ids(records_path)
file_ids = sorted({fid for fid in added_ids.values() if fid})
print(f"[Check] total file ids: {len(file_ids)}")

if not file_ids:
return

client = MemosApiOnlineClient()
batch_size = max(1, args.batch_size)

# Phase 1: Query file processing status
file_status = _check_file_status(client, file_ids, batch_size)

    # Phase 2: Re-upload files that failed processing
reupload_results = _reupload_failed_files(
client=client,
file_status=file_status,
added_ids=added_ids,
url_prefix=args.url_prefix,
)

# Persist: update original added_records.json with new file ids
if reupload_results:
try:
# Load original JSON object to preserve additional fields (e.g., perf)
obj: dict = {}
if records_path.exists():
txt = records_path.read_text(encoding="utf-8")
if txt:
parsed = json.loads(txt)
if isinstance(parsed, dict):
obj = parsed
# Build updated 'added' dict
added_obj: dict[str, str | None] = {}
if isinstance(obj.get("added"), dict):
added_obj = {
str(k): (str(v) if v is not None else None) for k, v in obj["added"].items()
}
else:
added_obj = dict(added_ids)
# Apply new file ids from successful reuploads
for item in reupload_results:
if item.get("ok") == "true" and item.get("user_id") and item.get("new_file_id"):
uid = str(item["user_id"])
added_obj[uid] = str(item["new_file_id"])
obj["added"] = dict(sorted(added_obj.items()))
# Atomic write back to records_path
tmp_r = records_path.with_suffix(records_path.suffix + ".tmp")
tmp_r.write_text(json.dumps(obj, ensure_ascii=False, indent=2), encoding="utf-8")
os.replace(tmp_r, records_path)
print(f"[Update] updated added_records with new file ids -> {records_path}")
except Exception as e:
print(f"[Update] failed to update added_records: {e}")

if args.output:
output_path = Path(args.output)
output_path.parent.mkdir(parents=True, exist_ok=True)
else:
output_path = output_dir / f"{args.lib}_file_status.json"

result_obj = {
"lib": args.lib,
"version_dir": args.version_dir,
"total": len(file_ids),
"file_detail_list": [{"id": fid, **(file_status.get(fid) or {})} for fid in file_ids],
"reupload_results": reupload_results,
}

tmp = output_path.with_suffix(output_path.suffix + ".tmp")
tmp.write_text(json.dumps(result_obj, ensure_ascii=False, indent=2), encoding="utf-8")
os.replace(tmp, output_path)

print(f"[Check] saved file status for {len(file_status)} files to {output_path}")


if __name__ == "__main__":
main()
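Reviewer note: a hedged invocation sketch follows; the version directory, batch size, and module path are illustrative assumptions, and running it requires a configured .env (MEMOS_KNOWLEDGEBASE_ID_HOTPOT plus the MemosApiOnlineClient credentials).

# Invocation sketch, roughly equivalent to:
#   python -m evaluation.scripts.hotpot.hotpot_check_files -v v1 --batch-size 20
from evaluation.scripts.hotpot.hotpot_check_files import main

main([
    "--lib", "memos-online",
    "--version-dir", "v1",   # reads evaluation/data/hotpot/v1/memos-online_added_records.json
    "--batch-size", "20",
])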