Coverage for node / src / stigmem_node / routes / recall / orchestration.py: 94%
156 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-25 01:49 +0000
1"""Recall route orchestration."""
3from __future__ import annotations
5import hashlib
6import uuid
7from typing import Annotated, Any
9from fastapi import Depends, Header, HTTPException, Query, Response, status
10from fastapi.responses import JSONResponse
12from ...auth import Identity, resolve_identity
13from ...card_materializer import CARD_MIN_CONFIDENCE, get_fresh_card
14from ...db import db
15from ...lifecycle.tombstone_cache import is_tombstoned as _is_tombstoned
16from ...metrics import FACT_READ, RECALL_RANKER_DURATION, observe_duration
17from ...models.constants import VALID_SCOPES
18from ...models.facts import FactRecord, FactValue
19from ...models.recall import (
20 FactChainProof,
21 RecallRequest,
22 RecallResponse,
23 ScoreBreakdown,
24 ScoredFact,
25)
26from ...plugins import get_registry
27from ...recall.graph import MAX_DEPTH
28from ...recall.recall_pipeline import apply_recall_pipeline
29from ...session_graph import record_read_scopes
30from ...tracing import start_span
31from ..time_travel_gate import require_time_travel_enabled
32from .as_of import _recall_as_of_impl
33from .common import (
34 _estimate_tokens,
35 _fetch_facts_by_ids,
36 _now_iso,
37 _write_recall_audit,
38 logger,
39 router,
40)
41from .graph import _MAX_SEED_ENTITIES, _graph_expand
42from .lexical import _lexical_search
43from .ranking import _greedy_pack, _score_candidates
44from .vector import _semantic_search
# Upper bound on how many candidate fact ids a single live (non-as_of) recall
# fetches from the database for scoring.
_MAX_CANDIDATES = 500
@router.post("", response_model=RecallResponse)
def recall(
    req: RecallRequest,
    identity: Annotated[Identity, Depends(resolve_identity)],
    response: Response,
    session_id: Annotated[str | None, Header(alias="Stigmem-Session")] = None,
    verify_mode: Annotated[str | None, Header(alias="Stigmem-Verify")] = None,
    legacy_format: Annotated[
        bool,
        Query(
            description=(
                "Return the temporary legacy recall response shape without "
                "`content` / `instructions` channel fields."
            )
        ),
    ] = False,
) -> RecallResponse | JSONResponse:
    """Hybrid recall — return the most salient facts for a query, within budget.

    Combines lexical (FTS5/BM25), dense-vector, and graph-traversal signals.
    Honors Spec-02-Scopes-and-ACL and Spec-05-Federation-Trust at every step.
    """
    # NOTE: the docstring above is reused by FastAPI as the endpoint description,
    # so it is kept verbatim; implementation notes live in these comments.
    span_attributes = {
        "stigmem.tenant": identity.tenant_id,
        "stigmem.principal": identity.entity_uri,
        "stigmem.scope": req.scope,
    }
    # Only the literal header value "full" requests a fact-chain proof.
    wants_chain_proof = verify_mode == "full"
    with start_span("stigmem.recall", **span_attributes) as span:
        outcome = _recall_impl(
            req,
            identity,
            span,
            session_id=session_id,
            verify_full=wants_chain_proof,
        )

    # The count is set both on the injected Response (used when the model is
    # returned) and in an explicit header dict (used when a JSONResponse is
    # returned directly, which bypasses the injected Response's headers).
    extra_headers = {}
    if outcome.total_scored is not None:
        total_count = str(outcome.total_scored)
        extra_headers["X-Total-Count"] = total_count
        response.headers["X-Total-Count"] = total_count

    if not legacy_format:
        return outcome
    return JSONResponse(content=_legacy_recall_payload(outcome), headers=extra_headers)
93def _legacy_recall_payload(result: RecallResponse) -> dict[str, Any]:
94 """Return the one-minor-version compatibility shape for pre-channel clients."""
95 return result.model_dump(mode="json", exclude={"content", "instructions"})
98def _split_interpretation_channels(
99 packed: list[ScoredFact],
100) -> tuple[list[ScoredFact], list[ScoredFact]]:
101 content = [scored for scored in packed if scored.fact.value.interpret_as != "instruction"]
102 instructions = [scored for scored in packed if scored.fact.value.interpret_as == "instruction"]
103 return content, instructions
def _validate_recall_request(req: RecallRequest, identity: Identity) -> None:
    """Reject recall requests that must not proceed; return None when valid.

    Checks run in this order:
      1. caller has read permission, else HTTP 403;
      2. (the read is then counted in the FACT_READ metric)
      3. ``req.scope`` is one of VALID_SCOPES, else HTTP 400;
      4. ``req.depth`` does not exceed MAX_DEPTH, else HTTP 400 with a
         ``graph_depth_exceeded`` code.

    Raises:
        HTTPException: on the first failed check.
    """
    if not identity.can_read():
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="read permission required",
        )

    # Counted after the permission check but before scope/depth validation.
    FACT_READ.labels(principal=identity.entity_uri, tenant=identity.tenant_id).inc()

    if req.scope not in VALID_SCOPES:
        allowed_scopes = sorted(VALID_SCOPES)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"invalid_scope: must be one of {allowed_scopes}",
        )

    if not req.depth > MAX_DEPTH:
        return
    depth_error = {"code": "graph_depth_exceeded", "message": f"depth must be ≤ {MAX_DEPTH}"}
    raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=depth_error)
def _chain_proof_or_409(conn: Any, tenant_id: str) -> FactChainProof:
    """Build the tenant's local fact-chain proof, mapping integrity failures to HTTP 409.

    Raises:
        HTTPException: 409 with code ``fact_chain_mismatch`` plus the mismatch reason,
            fact id and chain sequence reported by ``FactChainIntegrityError.result``.
    """
    # Imported lazily, as in the rest of this module's optional paths.
    from ...fact_chain import FactChainIntegrityError, build_fact_chain_proof

    try:
        proof_fields = build_fact_chain_proof(conn, tenant_id=tenant_id)
        return FactChainProof(**proof_fields)
    except FactChainIntegrityError as integrity_error:
        verification = integrity_error.result
        conflict_detail = {
            "code": "fact_chain_mismatch",
            "message": "local fact hash chain verification failed",
            "mismatch_reason": verification.mismatch_reason,
            "fact_id": verification.fact_id,
            "chain_seq": verification.chain_seq,
        }
        raise HTTPException(
            status_code=status.HTTP_409_CONFLICT,
            detail=conflict_detail,
        ) from integrity_error
def _handle_as_of_recall(
    req: RecallRequest,
    identity: Identity,
    *,
    verify_full: bool = False,
) -> RecallResponse:
    """Serve a §24 time-travel recall for a request whose ``as_of`` is set.

    Gates on the time-travel feature, validates ``as_of``, delegates candidate
    selection to ``_recall_as_of_impl`` and optionally attaches a fact-chain proof.
    The response always reports ``truncated=False`` and sums the packed facts'
    token estimates as ``tokens_used``.

    Raises:
        HTTPException: 409 from ``_chain_proof_or_409`` when ``verify_full`` is set
            and the chain does not verify. The feature gate and ``_validate_as_of``
            may raise as well (their error types are not visible here).
    """
    from ..facts import _validate_as_of

    require_time_travel_enabled(get_registry(), surface="recall")
    _validate_as_of(req.as_of)  # type: ignore[arg-type]  # caller guarantees as_of is not None

    recall_id = str(uuid.uuid4())
    query_hash = hashlib.sha256(req.query.encode()).hexdigest()

    with db() as conn:
        snapshot_args = dict(
            query=req.query,
            scope=req.scope,
            as_of=req.as_of,
            is_admin_caller=identity.is_admin(),
            tenant_id=identity.tenant_id,
            max_chunks=req.limit,
            include_graph=req.include_neighbors,
            identity=identity,
            weights=req.weights,
            depth=req.depth,
        )
        packed, notices, tombstone_filtered = _recall_as_of_impl(conn, **snapshot_args)
        chain_proof = None
        if verify_full:
            chain_proof = _chain_proof_or_409(conn, identity.tenant_id)

    content, instructions = _split_interpretation_channels(packed)
    # §23.3.3 r.3: suppress total_scored when tombstone filtering was applied
    if tombstone_filtered:
        reported_total = None
    else:
        reported_total = len(packed)
    return RecallResponse(
        recall_id=recall_id,
        query_hash=query_hash,
        facts=packed,
        content=content,
        instructions=instructions,
        total_scored=reported_total,
        token_budget=req.token_budget,
        tokens_used=sum(item.token_estimate for item in packed),
        truncated=False,
        tombstone_notices=notices,
        chain_proof=chain_proof,
    )
190def _gather_direct_matches(
191 conn: Any,
192 req: RecallRequest,
193 identity: Identity,
194 now: str,
195) -> tuple[dict[str, float], dict[str, float]]:
196 """Run the lexical + semantic searches that produce direct-match scores."""
197 lex_scores = (
198 _lexical_search(
199 conn,
200 req.query,
201 req.scope,
202 identity.tenant_id,
203 req.limit,
204 req.min_confidence,
205 now,
206 )
207 if req.weights.lexical > 0
208 else {}
209 )
210 sem_scores = (
211 _semantic_search(conn, req.query, req.scope, identity.tenant_id, req.limit)
212 if req.weights.semantic > 0
213 else {}
214 )
215 return lex_scores, sem_scores
218def _expand_graph_neighbours(
219 conn: Any,
220 req: RecallRequest,
221 identity: Identity,
222 direct_ids: set[str],
223 lex_scores: dict[str, float],
224 sem_scores: dict[str, float],
225 now: str,
226) -> dict[str, int]:
227 """Optionally BFS from the top direct-match seeds. Returns {fact_id: hop_distance}."""
228 if not (req.include_neighbors and req.weights.graph > 0 and direct_ids):
229 return {}
230 top_seeds = sorted(
231 direct_ids,
232 key=lambda fid: lex_scores.get(fid, 0) + sem_scores.get(fid, 0),
233 reverse=True,
234 )[:_MAX_SEED_ENTITIES]
235 return _graph_expand(
236 conn,
237 top_seeds,
238 req.depth,
239 req.scope,
240 identity.tenant_id,
241 identity,
242 req.limit,
243 req.min_confidence,
244 now,
245 )
def _build_card_for_entity(
    entity_uri: str,
    entity_fact_ids: list[str],
    req: RecallRequest,
    identity: Identity,
    conn: Any,
    now: str,
) -> tuple[ScoredFact, list[str]] | None:
    """Turn a fresh, trustworthy entity card into a synthetic ScoredFact.

    Returns ``(scored_fact, owned_fact_ids)`` where ``owned_fact_ids`` is a copy of
    ``entity_fact_ids``. Returns None when the entity is tombstoned, no fresh card
    exists, or the card is stale, has contradictions, or averages below
    CARD_MIN_CONFIDENCE.
    """
    # §23.3.2 r.3: cards whose about_entity is tombstoned are fully excluded.
    if _is_tombstoned(entity_uri, identity.tenant_id):
        return None

    card = get_fresh_card(entity_uri, req.scope, identity.tenant_id, conn)
    card_usable = (
        card is not None
        and not card.is_stale
        and not card.has_contradictions
        and not card.avg_confidence < CARD_MIN_CONFIDENCE
    )
    if not card_usable:
        return None

    confidence = card.avg_confidence
    summary_record = FactRecord(
        id=f"card:{entity_uri}",
        entity=entity_uri,
        relation="stigmem:card:summary",
        value=FactValue(type="text", v=card.summary),
        source="system:stigmem:materializer",
        timestamp=card.refreshed_at or now,
        confidence=confidence,
        scope=req.scope,
    )
    # The card's average confidence serves as both its trust and its final score.
    breakdown = ScoreBreakdown(
        source_trust=round(confidence, 4),
        weighted_total=round(confidence, 6),
    )
    synthetic = ScoredFact(
        fact=summary_record,
        score=round(confidence, 6),
        score_breakdown=breakdown,
        hop_distance=0,
        token_estimate=_estimate_tokens(summary_record),
        from_card=True,
    )
    return synthetic, list(entity_fact_ids)
295def _try_card_fast_path(
296 all_facts_raw: dict[str, FactRecord],
297 req: RecallRequest,
298 identity: Identity,
299 conn: Any,
300 now: str,
301) -> tuple[list[ScoredFact], set[str]]:
302 """Build card-derived ScoredFacts for any candidate entity that has a fresh card.
304 Returns (card_facts, card_owned_fact_ids). Logs and returns ([], set()) on any error
305 so the recall path can fall through to raw-fact scoring.
306 """
307 card_facts: list[ScoredFact] = []
308 card_entity_ids: set[str] = set()
309 try:
310 candidate_entities: dict[str, list[str]] = {}
311 for fid, record in all_facts_raw.items():
312 candidate_entities.setdefault(record.entity, []).append(fid)
313 for entity_uri, entity_fact_ids in candidate_entities.items():
314 built = _build_card_for_entity(entity_uri, entity_fact_ids, req, identity, conn, now)
315 if built is not None:
316 sf, owned = built
317 card_facts.append(sf)
318 card_entity_ids.update(owned)
319 except Exception as _card_exc:
320 logger.warning("card fast-path error (falling through to raw facts): %s", _card_exc)
321 return [], set()
322 return card_facts, card_entity_ids
325def _exclude_card_owned(
326 card_entity_ids: set[str],
327 all_facts_raw: dict[str, FactRecord],
328 lex_scores: dict[str, float],
329 sem_scores: dict[str, float],
330 graph_hops: dict[str, int],
331) -> tuple[dict[str, FactRecord], dict[str, float], dict[str, float], dict[str, int]]:
332 """Drop any fact_id owned by a card-served entity from all four scoring inputs.
334 No-op when ``card_entity_ids`` is empty.
335 """
336 if not card_entity_ids:
337 return all_facts_raw, lex_scores, sem_scores, graph_hops
338 return (
339 {k: v for k, v in all_facts_raw.items() if k not in card_entity_ids},
340 {k: v for k, v in lex_scores.items() if k not in card_entity_ids},
341 {k: v for k, v in sem_scores.items() if k not in card_entity_ids},
342 {k: v for k, v in graph_hops.items() if k not in card_entity_ids},
343 )
346def _set_recall_span_attrs(
347 span: object,
348 recall_id: str,
349 total_scored: int,
350 tokens_used: int,
351 truncated: bool,
352) -> None:
353 """Best-effort: attach recall outcome attributes to the OTel span."""
354 try:
355 span.set_attribute("stigmem.recall_id", recall_id) # type: ignore[attr-defined]
356 span.set_attribute("stigmem.total_scored", total_scored) # type: ignore[attr-defined]
357 span.set_attribute("stigmem.tokens_used", tokens_used) # type: ignore[attr-defined]
358 span.set_attribute("stigmem.truncated", truncated) # type: ignore[attr-defined]
359 except Exception as exc: # noqa: BLE001 # nosec B110 — span attrs best-effort
360 logger.debug("recall span attribute set failed: %s", exc)
def _select_candidate_ids(
    graph_hops: dict[str, int],
    lex_scores: dict[str, float],
    sem_scores: dict[str, float],
    limit: int,
) -> list[str]:
    """Return at most ``limit`` candidate fact ids, best candidates first.

    ``graph_hops`` must already contain every direct match (at hop 0). Ordering is:
    1. hop distance, ascending (direct matches first);
    2. combined lexical + semantic score, descending;
    3. fact id, as a deterministic tie-break.
    The cut at ``limit`` therefore keeps the strongest candidates and is identical
    across processes.
    """

    def _rank(fid: str) -> tuple[int, float, str]:
        direct_score = lex_scores.get(fid, 0) + sem_scores.get(fid, 0)
        return graph_hops[fid], -direct_score, fid

    return sorted(graph_hops, key=_rank)[:limit]


def _drop_tombstoned_entities(
    facts: dict[str, FactRecord],
    tenant_id: str,
) -> tuple[dict[str, FactRecord], bool]:
    """Remove facts whose entity has an active tombstone (§23.3.2 r.3).

    Uses the in-process tombstone cache (§23.3.3 r.4), so no per-fact DB read is
    needed. Returns ``(kept_facts, any_removed)``.
    """
    kept = {
        fid: record
        for fid, record in facts.items()
        if not _is_tombstoned(record.entity, tenant_id)
    }
    return kept, len(kept) < len(facts)


def _recall_impl(
    req: RecallRequest,
    identity: Identity,
    _span: object,
    *,
    session_id: str | None = None,
    verify_full: bool = False,
) -> RecallResponse:
    """Execute one recall request and build its RecallResponse.

    Validates the request. When ``req.as_of`` is set, it delegates to the §24
    time-travel path and records the read scopes. Otherwise it runs the live
    pipeline:
    1. lexical + semantic direct matches;
    2. optional graph expansion;
    3. a best-first candidate fetch capped at ``_MAX_CANDIDATES``;
    4. tombstone filtering;
    5. the card fast-path;
    6. the §19 recall pipeline and scoring;
    7. token-budget packing, audit and read-scope recording.

    Args:
        req: the recall request body.
        identity: the resolved caller identity.
        _span: tracing span; outcome attributes are attached best-effort.
        session_id: session identifier passed through to ``record_read_scopes``.
        verify_full: when true, attach a fact-chain proof to the response.

    Raises:
        HTTPException: from request validation, or 409 on fact-chain mismatch.
    """
    _validate_recall_request(req, identity)

    if req.as_of is not None:
        result = _handle_as_of_recall(req, identity, verify_full=verify_full)
        with db() as conn:
            record_read_scopes(
                conn,
                identity=identity,
                session_id=session_id,
                scopes={scored.fact.scope for scored in result.facts},
            )
        return result

    recall_id = str(uuid.uuid4())
    query_hash = hashlib.sha256(req.query.encode()).hexdigest()
    now = _now_iso()

    logger.info(
        "recall id=%s entity=%s query_hash=%s scope=%s budget=%d",
        recall_id,
        identity.entity_uri,
        query_hash[:12],
        req.scope,
        req.token_budget,
    )

    with db() as conn:
        lex_scores, sem_scores = _gather_direct_matches(conn, req, identity, now)
        direct_ids = set(lex_scores) | set(sem_scores)

        graph_hops = _expand_graph_neighbours(
            conn,
            req,
            identity,
            direct_ids,
            lex_scores,
            sem_scores,
            now,
        )
        # Direct matches have hop_distance=0 by definition. Overwrite even when
        # graph expansion also reached the fact from another seed: keeping that
        # larger hop would report and score a direct match as a neighbour.
        for fid in direct_ids:
            graph_hops[fid] = 0

        # --- Fetch all candidate facts (best-first, capped) ---
        # graph_hops now covers direct_ids, so it is the full candidate set.
        all_candidate_ids = _select_candidate_ids(
            graph_hops, lex_scores, sem_scores, _MAX_CANDIDATES
        )
        all_facts_raw = _fetch_facts_by_ids(conn, all_candidate_ids)

        # §23.3.2 r.3: exclude facts whose entity has an active tombstone (about_entity).
        all_facts_raw, tombstone_filtered = _drop_tombstoned_entities(
            all_facts_raw, identity.tenant_id
        )

        # --- Card fast-path (§20) ---
        card_facts, card_entity_ids = _try_card_fast_path(
            all_facts_raw,
            req,
            identity,
            conn,
            now,
        )
        all_facts_raw, lex_scores, sem_scores, graph_hops = _exclude_card_owned(
            card_entity_ids,
            all_facts_raw,
            lex_scores,
            sem_scores,
            graph_hops,
        )

        # Apply §19 recall pipeline (source-trust multiplier + content sanitiser)
        all_facts = {r.id: r for r in apply_recall_pipeline(list(all_facts_raw.values()), identity)}

        # --- Score (timed for ranker histogram) ---
        with observe_duration(RECALL_RANKER_DURATION, {"tenant": identity.tenant_id}):
            candidates = _score_candidates(
                all_facts,
                lex_scores,
                sem_scores,
                graph_hops,
                req.weights,
                identity,
                req.depth,
            )
        # NOTE(review): card facts are appended after apply_recall_pipeline, so
        # card summaries are not passed through it here — confirm they are
        # sanitised when the card is materialised.
        candidates.extend(card_facts)
        total_scored = len(candidates)

        # --- Token-budget packing ---
        packed, tokens_used, truncated = _greedy_pack(candidates, req.token_budget)

        # --- Audit ---
        _write_recall_audit(
            conn,
            recall_id,
            identity,
            query_hash,
            req.scope,
            req.token_budget,
            len(packed),
            tokens_used,
            truncated,
        )
        record_read_scopes(
            conn,
            identity=identity,
            session_id=session_id,
            scopes={scored.fact.scope for scored in packed},
        )
        chain_proof = _chain_proof_or_409(conn, identity.tenant_id) if verify_full else None

    logger.info(
        "recall id=%s scored=%d packed=%d tokens=%d truncated=%s",
        recall_id,
        total_scored,
        len(packed),
        tokens_used,
        truncated,
    )

    _set_recall_span_attrs(_span, recall_id, total_scored, tokens_used, truncated)

    # §23.3.3 r.3: suppress total_scored when tombstone filtering was applied
    content, instructions = _split_interpretation_channels(packed)
    return RecallResponse(
        recall_id=recall_id,
        query_hash=query_hash,
        facts=packed,
        content=content,
        instructions=instructions,
        total_scored=None if tombstone_filtered else total_scored,
        token_budget=req.token_budget,
        tokens_used=tokens_used,
        truncated=truncated,
        chain_proof=chain_proof,
    )