Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
149 changes: 141 additions & 8 deletions src/api/routes/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,12 @@
from __future__ import annotations

import asyncio
import inspect
import logging
import threading
import time
from typing import Any, Dict, List
from collections import defaultdict, deque
from typing import Any, Callable, Dict, List

from fastapi import APIRouter, Depends, Request, UploadFile, File
from fastapi.responses import JSONResponse
Expand Down Expand Up @@ -58,6 +60,8 @@
# Module-level logger for the memory API routes.
logger = logging.getLogger("xmem.api.routes.memory")

# Semaphore with 5 slots; presumably caps concurrent ingest work — the
# acquire sites are not visible here (TODO confirm).
_ingest_semaphore = asyncio.Semaphore(5)
# Per-mode rolling window of latency samples: each deque keeps at most 200
# entries and discards the oldest one once it is full.
_latency_samples: dict[str, deque[float]] = defaultdict(lambda: deque(maxlen=200))
# Held while appending to or copying _latency_samples.
_latency_lock = threading.Lock()

router = APIRouter(
prefix="/v1/memory",
Expand Down Expand Up @@ -233,6 +237,53 @@ def _schedule_job(job: Dict[str, Any], handler) -> None:
asyncio.create_task(run_job(get_default_job_store(), job["job_id"], handler))


def _record_latency(mode: str, elapsed_ms: float) -> None:
    """Append one latency sample, in milliseconds, to the window for *mode*."""
    with _latency_lock:
        window = _latency_samples[mode]
        window.append(elapsed_ms)


def _percentile(sorted_values: List[float], percentile: float) -> float:
if not sorted_values:
return 0.0
index = min(len(sorted_values) - 1, int(round((len(sorted_values) - 1) * percentile)))
return round(sorted_values[index], 2)


def _latency_stats() -> Dict[str, Dict[str, float]]:
    """Summarize the recorded latency windows, one entry per mode.

    Returns:
        A mapping from mode to its sample ``count`` and the ``p50_ms``,
        ``p95_ms`` and ``p99_ms`` values computed by ``_percentile``.
    """
    # Copy every window while holding the lock; sorting runs after release.
    with _latency_lock:
        windows = [(mode, list(samples)) for mode, samples in _latency_samples.items()]

    summary: Dict[str, Dict[str, float]] = {}
    for mode, window in windows:
        window.sort()
        summary[mode] = {
            "count": len(window),
            "p50_ms": _percentile(window, 0.50),
            "p95_ms": _percentile(window, 0.95),
            "p99_ms": _percentile(window, 0.99),
        }
    return summary


async def _timed(
    mode: str,
    func: Callable[..., Any],
    *args: Any,
    threaded: bool = False,
    **kwargs: Any,
) -> tuple[Any, float]:
    """Run *func*, record its duration under *mode*, and return both.

    Args:
        mode: Key under which the latency sample is recorded.
        func: Callable to invoke with ``*args`` and ``**kwargs``.
        threaded: When true, run *func* via ``asyncio.to_thread``; otherwise
            call it directly and await the result if it is awaitable.

    Returns:
        ``(result, elapsed_ms)`` with the elapsed time rounded to two
        decimals. If *func* raises, the exception propagates and no sample
        is recorded.
    """
    started = time.perf_counter()
    if not threaded:
        outcome = func(*args, **kwargs)
        if inspect.isawaitable(outcome):
            outcome = await outcome
    else:
        outcome = await asyncio.to_thread(func, *args, **kwargs)
    duration_ms = round((time.perf_counter() - started) * 1000, 2)
    _record_latency(mode, duration_ms)
    return outcome, duration_ms


def _detect_chat_provider(*urls: str) -> str:
for url in urls:
lowered = (url or "").lower()
Expand Down Expand Up @@ -915,15 +966,69 @@ async def search_memory(req: SearchRequest, request: Request, user: dict = Depen

try:
all_results: List[SourceRecord] = []
latency_ms: Dict[str, float] = {}
plan = pipeline.raw_retrieval_plan(req.domains, answer=req.answer)
raw_tasks = []

if "profile" in plan:
raw_tasks.append((
"profile",
_timed("profile", _search_profile, pipeline, user_id, threaded=True),
))
if "temporal" in plan:
raw_tasks.append((
"temporal",
_timed("temporal", _search_temporal, pipeline, req.query, user_id, req.top_k, threaded=True),
))
if "summary" in plan:
raw_tasks.append((
"summary",
_timed("summary", _search_summary, pipeline, req.query, user_id, req.top_k),
))
if "snippet" in plan:
raw_tasks.append((
"snippet",
_timed("snippet", _search_snippet, pipeline, req.query, user_id, req.top_k),
))
if "code" in plan:
raw_tasks.append((
"code",
_timed("code", _search_code, pipeline, req.query, user_id, req.top_k),
))

if "profile" in req.domains:
all_results.extend(_search_profile(pipeline, user_id))
if "temporal" in req.domains:
all_results.extend(_search_temporal(pipeline, req.query, user_id, req.top_k))
if "summary" in req.domains:
all_results.extend(await _search_summary(pipeline, req.query, user_id, req.top_k))
if raw_tasks:
raw_results = await asyncio.gather(*(task for _, task in raw_tasks))
for (domain, _), (results, elapsed) in zip(raw_tasks, raw_results):
latency_ms[domain] = elapsed
all_results.extend(results)

all_results.sort(key=lambda record: record.score, reverse=True)

answer = None
answer_sources: List[SourceRecord] = []
confidence = 0.0
if req.answer:
answer_result, elapsed = await _timed("answer", pipeline.run, req.query, user_id, req.top_k)
latency_ms["answer"] = elapsed
answer = answer_result.answer
confidence = answer_result.confidence
answer_sources = [
SourceRecord(
domain=s.domain, content=s.content,
score=round(s.score, 3), metadata=s.metadata,
)
for s in answer_result.sources
]

data = SearchResponse(results=all_results, total=len(all_results))
data = SearchResponse(
results=all_results,
total=len(all_results),
answer=answer,
answer_sources=answer_sources,
confidence=confidence,
latency_ms=latency_ms,
latency_stats=_latency_stats(),
)
elapsed = round((time.perf_counter() - start) * 1000, 2)
return _wrap(request, data, elapsed)

Expand Down Expand Up @@ -988,6 +1093,34 @@ async def _search_summary(pipeline: RetrievalPipeline, query: str, user_id: str,
return []


async def _search_snippet(pipeline: RetrievalPipeline, query: str, user_id: str, top_k: int) -> List[SourceRecord]:
    """Run the pipeline's snippet search and convert hits to ``SourceRecord``.

    Scores are rounded to three decimals. Any exception is logged as a
    warning and an empty list is returned instead.
    """
    try:
        hits = await pipeline._search_snippet(query=query, user_id=user_id, top_k=top_k)
        return [
            SourceRecord(
                domain=hit.domain,
                content=hit.content,
                score=round(hit.score, 3),
                metadata=hit.metadata,
            )
            for hit in hits
        ]
    except Exception as exc:
        logger.warning("Snippet search error: %s", exc)
        return []


async def _search_code(pipeline: RetrievalPipeline, query: str, user_id: str, top_k: int) -> List[SourceRecord]:
    """Search the vector store for the user's ``code`` records.

    Each hit becomes a ``SourceRecord`` in the ``code`` domain, with its
    score rounded to three decimals and its id merged into the metadata
    (a metadata key named ``id`` takes precedence). Any exception is logged
    as a warning and an empty list is returned instead.
    """
    code_filters = {"user_id": user_id, "domain": "code"}
    try:
        hits = await pipeline.vector_store.search_by_text(
            query_text=query,
            top_k=top_k,
            filters=code_filters,
        )
        return [
            SourceRecord(
                domain="code",
                content=hit.content,
                score=round(hit.score, 3),
                metadata={"id": hit.id, **hit.metadata},
            )
            for hit in hits
        ]
    except Exception as exc:
        logger.warning("Code search error: %s", exc)
        return []


# POST /v1/memory/scrape
@scrape_router.post(
"/scrape",
Expand Down
23 changes: 19 additions & 4 deletions src/api/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

from __future__ import annotations

from datetime import datetime
from enum import Enum
from typing import Any, Dict, List, Optional

Expand Down Expand Up @@ -159,24 +158,40 @@ class SearchRequest(BaseModel):
..., min_length=1, max_length=256, pattern=r"^[\w.\-@]+$",
)
domains: List[str] = Field(
default=["profile", "temporal", "summary"],
default=["profile", "temporal", "summary", "snippet", "code"],
description="Which memory domains to search",
)
top_k: int = Field(default=10, ge=1, le=100)
answer: bool = Field(
default=False,
description="When true, also generate a synthesized answer after returning raw ranked hits.",
)

@field_validator("domains")
@classmethod
def validate_domains(cls, v: List[str]) -> List[str]:
allowed = {"profile", "temporal", "summary"}
allowed = {"profile", "temporal", "summary", "snippet", "code"}
for d in v:
if d not in allowed:
raise ValueError(f"Invalid domain '{d}'. Allowed: {allowed}")
return v
return list(dict.fromkeys(v))


# Latency summary for one search mode: sample count plus percentile values.
class SearchLatencySummary(BaseModel):
    # Number of samples the percentiles were computed from.
    count: int = 0
    # 50th, 95th and 99th percentile latency, in milliseconds.
    p50_ms: float = 0.0
    p95_ms: float = 0.0
    p99_ms: float = 0.0


# Response payload for a memory search.
class SearchResponse(BaseModel):
    # Raw hits collected from the searched domains.
    results: List[SourceRecord] = Field(default_factory=list)
    # Number of entries in `results`.
    total: int = 0
    # Synthesized answer; stays None unless one was generated.
    answer: Optional[str] = None
    # Sources backing the synthesized answer; empty when there is no answer.
    answer_sources: List[SourceRecord] = Field(default_factory=list)
    # Confidence reported with the synthesized answer; 0.0 by default.
    confidence: float = 0.0
    # Elapsed milliseconds for this request, keyed by search mode.
    latency_ms: Dict[str, float] = Field(default_factory=dict)
    # Latency summaries keyed by search mode.
    latency_stats: Dict[str, SearchLatencySummary] = Field(default_factory=dict)


# ── Scrape (extract from shared chat links) ────────────────────────────────
Expand Down
46 changes: 46 additions & 0 deletions src/pipelines/retrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@

import asyncio
import logging
import threading
import time
from typing import Any, Callable, Dict, List, Optional

from dotenv import load_dotenv
Expand Down Expand Up @@ -133,6 +135,12 @@ def __init__(

self.embed_fn = embed_fn
self._snippet_stores: Dict[str, BaseVectorStore] = {}
self._profile_catalog_cache: Dict[str, tuple[float, List[Dict[str, str]], list]] = {}
self._raw_retrieval_plan_cache: Dict[tuple[tuple[str, ...], bool], tuple[str, ...]] = {}
self._cache_ttl_seconds = 60.0
self._profile_catalog_cache_max_users = 256
self._profile_catalog_cache_lock = threading.Lock()
self._raw_retrieval_plan_cache_lock = threading.Lock()

logger.info("RetrievalPipeline initialized")

Expand Down Expand Up @@ -494,6 +502,17 @@ def _fetch_profile_catalog(self, user_id: str):
catalog — list of {topic, sub_topic} for the prompt
raw_results — the full SearchResult list, cached for _search_profile
"""
now = time.monotonic()
with self._profile_catalog_cache_lock:
self._prune_profile_catalog_cache(now)

cached = self._profile_catalog_cache.get(user_id)
if cached and now - cached[0] < self._cache_ttl_seconds:
catalog, results = cached[1], cached[2]
self._profile_catalog_cache.pop(user_id)
self._profile_catalog_cache[user_id] = (now, catalog, results)
return catalog, results

try:
results = self.vector_store.search_by_metadata(
filters={"user_id": user_id, "domain": "profile"},
Expand Down Expand Up @@ -524,8 +543,35 @@ def _fetch_profile_catalog(self, user_id: str):
"sub_topic": "",
})

with self._profile_catalog_cache_lock:
self._prune_profile_catalog_cache(now)
self._profile_catalog_cache[user_id] = (now, catalog, results)
Comment on lines +546 to +548
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Missing double-check before cache write

When two concurrent calls for the same `user_id` both miss the cache in the first `with` block, they both release the lock and both execute the expensive `search_by_metadata` query. When the second `with` block is reached, there's no re-check of whether another thread already populated the entry. The second writer overwrites the first with a stale `now` timestamp (captured at the very start of the function). A simple check like `if user_id not in self._profile_catalog_cache:` before assigning would avoid the redundant overwrite and protect against cache stampede on concurrent first-hit requests for the same user.

Fix in Cursor Fix in Codex Fix in Claude Code

return catalog, results

def _prune_profile_catalog_cache(self, now: float) -> None:
"""Bound profile catalog cache by TTL and number of cached users."""
expired_user_ids = [
cached_user_id
for cached_user_id, (cached_at, _, _) in self._profile_catalog_cache.items()
if now - cached_at >= self._cache_ttl_seconds
]
for cached_user_id in expired_user_ids:
self._profile_catalog_cache.pop(cached_user_id, None)

while len(self._profile_catalog_cache) >= self._profile_catalog_cache_max_users:
oldest_user_id = next(iter(self._profile_catalog_cache))
self._profile_catalog_cache.pop(oldest_user_id, None)

def raw_retrieval_plan(self, domains: List[str], answer: bool = False) -> tuple[str, ...]:
    """Return a cached deterministic raw-search plan for the requested domains.

    Args:
        domains: Requested domain names; unknown names and duplicates are
            ignored.
        answer: Part of the cache key only; it does not currently change
            the returned plan.

    Returns:
        The supported domains that were requested, always in the order
        profile, temporal, summary, snippet, code.
    """
    ordered_allowed = ("profile", "temporal", "summary", "snippet", "code")
    # Build the lookup set once: rebuilding it per candidate repeats the work
    # and drains a one-shot iterable after the first membership test.
    requested = set(domains)
    normalized = tuple(d for d in ordered_allowed if d in requested)
    key = (normalized, answer)
    with self._raw_retrieval_plan_cache_lock:
        return self._raw_retrieval_plan_cache.setdefault(key, normalized)
Comment on lines +565 to +573
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 `answer` is part of the cache key but is never used to compute the returned tuple — both `(normalized, True)` and `(normalized, False)` store and return the same `normalized` value. This silently doubles the number of entries in the plan cache for any domain set that is queried with both values, and the parameter looks meaningful when it isn't. Either use `answer` to actually vary the plan, or drop it from the key entirely.

Suggested change
def raw_retrieval_plan(self, domains: List[str], answer: bool = False) -> tuple[str, ...]:
"""Return a cached deterministic raw-search plan for the requested domains."""
ordered_allowed = ("profile", "temporal", "summary", "snippet", "code")
normalized = tuple(d for d in ordered_allowed if d in set(domains))
key = (normalized, answer)
with self._raw_retrieval_plan_cache_lock:
if key not in self._raw_retrieval_plan_cache:
self._raw_retrieval_plan_cache[key] = normalized
return self._raw_retrieval_plan_cache[key]
def raw_retrieval_plan(self, domains: List[str], answer: bool = False) -> tuple[str, ...]:
"""Return a cached deterministic raw-search plan for the requested domains."""
ordered_allowed = ("profile", "temporal", "summary", "snippet", "code")
normalized = tuple(d for d in ordered_allowed if d in set(domains))
with self._raw_retrieval_plan_cache_lock:
if normalized not in self._raw_retrieval_plan_cache:
self._raw_retrieval_plan_cache[normalized] = normalized
return self._raw_retrieval_plan_cache[normalized]

Fix in Cursor Fix in Codex Fix in Claude Code


def _format_catalog(self, catalog: List[Dict[str, str]]) -> str:
"""Format profile catalog for the system prompt."""
if not catalog:
Expand Down
Loading
Loading