Coverage for node / src / stigmem_node / routes / recall / orchestration.py: 94%

156 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-05-25 01:49 +0000

1"""Recall route orchestration.""" 

2 

3from __future__ import annotations 

4 

5import hashlib 

6import uuid 

7from typing import Annotated, Any 

8 

9from fastapi import Depends, Header, HTTPException, Query, Response, status 

10from fastapi.responses import JSONResponse 

11 

12from ...auth import Identity, resolve_identity 

13from ...card_materializer import CARD_MIN_CONFIDENCE, get_fresh_card 

14from ...db import db 

15from ...lifecycle.tombstone_cache import is_tombstoned as _is_tombstoned 

16from ...metrics import FACT_READ, RECALL_RANKER_DURATION, observe_duration 

17from ...models.constants import VALID_SCOPES 

18from ...models.facts import FactRecord, FactValue 

19from ...models.recall import ( 

20 FactChainProof, 

21 RecallRequest, 

22 RecallResponse, 

23 ScoreBreakdown, 

24 ScoredFact, 

25) 

26from ...plugins import get_registry 

27from ...recall.graph import MAX_DEPTH 

28from ...recall.recall_pipeline import apply_recall_pipeline 

29from ...session_graph import record_read_scopes 

30from ...tracing import start_span 

31from ..time_travel_gate import require_time_travel_enabled 

32from .as_of import _recall_as_of_impl 

33from .common import ( 

34 _estimate_tokens, 

35 _fetch_facts_by_ids, 

36 _now_iso, 

37 _write_recall_audit, 

38 logger, 

39 router, 

40) 

41from .graph import _MAX_SEED_ENTITIES, _graph_expand 

42from .lexical import _lexical_search 

43from .ranking import _greedy_pack, _score_candidates 

44from .vector import _semantic_search 

45 

# Upper bound on how many candidate fact ids a single recall hands to the
# fact fetch after direct matches and graph neighbours have been merged.
_MAX_CANDIDATES = 500

@router.post("", response_model=RecallResponse)
def recall(
    req: RecallRequest,
    identity: Annotated[Identity, Depends(resolve_identity)],
    response: Response,
    session_id: Annotated[str | None, Header(alias="Stigmem-Session")] = None,
    verify_mode: Annotated[str | None, Header(alias="Stigmem-Verify")] = None,
    legacy_format: Annotated[
        bool,
        Query(
            description=(
                "Return the temporary legacy recall response shape without "
                "`content` / `instructions` channel fields."
            )
        ),
    ] = False,
) -> RecallResponse | JSONResponse:
    """Hybrid recall — return the most salient facts for a query, within budget.

    Combines lexical (FTS5/BM25), dense-vector, and graph-traversal signals.
    Honors Spec-02-Scopes-and-ACL and Spec-05-Federation-Trust at every step.
    """
    # Attributes attached to the tracing span that wraps the whole recall.
    span_attributes = {
        "stigmem.tenant": identity.tenant_id,
        "stigmem.principal": identity.entity_uri,
        "stigmem.scope": req.scope,
    }
    # Chain verification is requested only by the exact header value "full".
    wants_full_verification = verify_mode == "full"

    with start_span("stigmem.recall", **span_attributes) as _span:
        result = _recall_impl(
            req,
            identity,
            _span,
            session_id=session_id,
            verify_full=wants_full_verification,
        )

    # The count header is emitted only when the implementation reported a
    # total; it is set on both the injected response and the legacy response.
    headers = {}
    total = result.total_scored
    if total is not None:
        total_text = str(total)
        headers["X-Total-Count"] = total_text
        response.headers["X-Total-Count"] = total_text

    if not legacy_format:
        return result
    return JSONResponse(content=_legacy_recall_payload(result), headers=headers)

91 

92 

def _legacy_recall_payload(result: RecallResponse) -> dict[str, Any]:
    """Serialise *result* in the pre-channel shape kept for one minor version.

    The payload is the JSON-mode model dump with the ``content`` and
    ``instructions`` channel fields left out.
    """
    channel_fields = {"content", "instructions"}
    return result.model_dump(mode="json", exclude=channel_fields)

96 

97 

def _split_interpretation_channels(
    packed: list[ScoredFact],
) -> tuple[list[ScoredFact], list[ScoredFact]]:
    """Partition packed facts into ``(content, instructions)``.

    A fact goes to the instructions list when its value's ``interpret_as`` is
    ``"instruction"``; every other fact goes to the content list. Input order
    is preserved within each list.
    """
    content_channel: list[ScoredFact] = []
    instruction_channel: list[ScoredFact] = []
    for item in packed:
        if item.fact.value.interpret_as == "instruction":
            instruction_channel.append(item)
        else:
            content_channel.append(item)
    return content_channel, instruction_channel

104 

105 

def _validate_recall_request(req: RecallRequest, identity: Identity) -> None:
    """Check read permission, scope and graph depth for a recall request.

    The read counter is incremented once the permission check passes, before
    the scope and depth checks run.

    Raises:
        HTTPException: 403 when the identity cannot read; 400 when the scope is
            not in ``VALID_SCOPES`` or the depth is above ``MAX_DEPTH``.
    """
    if not identity.can_read():
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="read permission required",
        )

    read_counter = FACT_READ.labels(principal=identity.entity_uri, tenant=identity.tenant_id)
    read_counter.inc()

    if req.scope not in VALID_SCOPES:
        scope_detail = f"invalid_scope: must be one of {sorted(VALID_SCOPES)}"
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=scope_detail)

    if req.depth > MAX_DEPTH:
        depth_detail = {
            "code": "graph_depth_exceeded",
            "message": f"depth must be ≤ {MAX_DEPTH}",
        }
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=depth_detail)

123 

124 

def _chain_proof_or_409(conn: Any, tenant_id: str) -> FactChainProof:
    """Build the tenant's fact-chain proof, mapping integrity failures to HTTP 409.

    Raises:
        HTTPException: 409 with code ``fact_chain_mismatch`` when building the
            proof raises ``FactChainIntegrityError``; the detail carries the
            mismatch reason, fact id and chain sequence from the error result.
    """
    from ...fact_chain import FactChainIntegrityError, build_fact_chain_proof

    try:
        proof_fields = build_fact_chain_proof(conn, tenant_id=tenant_id)
        return FactChainProof(**proof_fields)
    except FactChainIntegrityError as exc:
        failure = exc.result
        mismatch_detail = {
            "code": "fact_chain_mismatch",
            "message": "local fact hash chain verification failed",
            "mismatch_reason": failure.mismatch_reason,
            "fact_id": failure.fact_id,
            "chain_seq": failure.chain_seq,
        }
        raise HTTPException(
            status_code=status.HTTP_409_CONFLICT,
            detail=mismatch_detail,
        ) from exc

142 

143 

def _handle_as_of_recall(
    req: RecallRequest,
    identity: Identity,
    *,
    verify_full: bool = False,
) -> RecallResponse:
    """Serve a recall pinned to ``req.as_of`` (§24 time-travel path).

    Checks that time travel is enabled for the recall surface, validates the
    ``as_of`` value, runs the as-of implementation and wraps its output in a
    ``RecallResponse``. When *verify_full* is set, a chain proof is attached.
    """
    from ..facts import _validate_as_of

    require_time_travel_enabled(get_registry(), surface="recall")
    _validate_as_of(req.as_of)  # type: ignore[arg-type]  # caller guarantees as_of is not None

    new_recall_id = str(uuid.uuid4())
    digest = hashlib.sha256(req.query.encode()).hexdigest()

    with db() as conn:
        packed, notices, tombstone_filtered = _recall_as_of_impl(
            conn,
            query=req.query,
            scope=req.scope,
            as_of=req.as_of,  # type: ignore[arg-type]
            is_admin_caller=identity.is_admin(),
            tenant_id=identity.tenant_id,
            max_chunks=req.limit,
            include_graph=req.include_neighbors,
            identity=identity,
            weights=req.weights,
            depth=req.depth,
        )
        proof = None
        if verify_full:
            proof = _chain_proof_or_409(conn, identity.tenant_id)

    spent_tokens = sum(entry.token_estimate for entry in packed)
    content, instructions = _split_interpretation_channels(packed)

    # §23.3.3 r.3: suppress total_scored when tombstone filtering was applied
    if tombstone_filtered:
        reported_total = None
    else:
        reported_total = len(packed)

    return RecallResponse(
        recall_id=new_recall_id,
        query_hash=digest,
        facts=packed,
        content=content,
        instructions=instructions,
        total_scored=reported_total,
        token_budget=req.token_budget,
        tokens_used=spent_tokens,
        truncated=False,
        tombstone_notices=notices,
        chain_proof=proof,
    )

188 

189 

def _gather_direct_matches(
    conn: Any,
    req: RecallRequest,
    identity: Identity,
    now: str,
) -> tuple[dict[str, float], dict[str, float]]:
    """Collect direct-match scores as ``(lexical_scores, semantic_scores)``.

    Each search runs only when its weight in ``req.weights`` is positive;
    otherwise that side of the result is an empty dict.
    """
    lexical_hits: dict[str, float] = {}
    if req.weights.lexical > 0:
        lexical_hits = _lexical_search(
            conn,
            req.query,
            req.scope,
            identity.tenant_id,
            req.limit,
            req.min_confidence,
            now,
        )

    semantic_hits: dict[str, float] = {}
    if req.weights.semantic > 0:
        semantic_hits = _semantic_search(
            conn,
            req.query,
            req.scope,
            identity.tenant_id,
            req.limit,
        )

    return lexical_hits, semantic_hits

216 

217 

def _expand_graph_neighbours(
    conn: Any,
    req: RecallRequest,
    identity: Identity,
    direct_ids: set[str],
    lex_scores: dict[str, float],
    sem_scores: dict[str, float],
    now: str,
) -> dict[str, int]:
    """Expand the graph from the strongest direct matches.

    Returns ``{fact_id: hop_distance}`` from the graph expansion, or an empty
    dict when neighbours were not requested, the graph weight is not positive,
    or there are no direct matches to seed from. Seeds are the direct ids with
    the highest combined lexical + semantic score, capped at
    ``_MAX_SEED_ENTITIES``.
    """
    # Guards are checked in the same order as the original combined condition.
    if not req.include_neighbors:
        return {}
    if not (req.weights.graph > 0):
        return {}
    if not direct_ids:
        return {}

    def combined_score(fact_id: str) -> float:
        return lex_scores.get(fact_id, 0) + sem_scores.get(fact_id, 0)

    ranked = sorted(direct_ids, key=combined_score, reverse=True)
    seeds = ranked[:_MAX_SEED_ENTITIES]
    return _graph_expand(
        conn,
        seeds,
        req.depth,
        req.scope,
        identity.tenant_id,
        identity,
        req.limit,
        req.min_confidence,
        now,
    )

246 

247 

def _build_card_for_entity(
    entity_uri: str,
    entity_fact_ids: list[str],
    req: RecallRequest,
    identity: Identity,
    conn: Any,
    now: str,
) -> tuple[ScoredFact, list[str]] | None:
    """Turn an entity's fresh card into a synthetic ``ScoredFact``.

    Returns ``(scored_fact, owned_fact_ids)`` when the entity is not
    tombstoned and its card exists, is not stale, has no contradictions and
    meets ``CARD_MIN_CONFIDENCE``. Returns ``None`` otherwise. The owned ids
    are a copy of *entity_fact_ids*.
    """
    # §23.3.2 r.3: cards whose about_entity is tombstoned are fully excluded.
    if _is_tombstoned(entity_uri, identity.tenant_id):
        return None

    card = get_fresh_card(entity_uri, req.scope, identity.tenant_id, conn)
    if card is None:
        return None
    if card.is_stale or card.has_contradictions:
        return None
    if card.avg_confidence < CARD_MIN_CONFIDENCE:
        return None

    summary_record = FactRecord(
        id=f"card:{entity_uri}",
        entity=entity_uri,
        relation="stigmem:card:summary",
        value=FactValue(type="text", v=card.summary),
        source="system:stigmem:materializer",
        timestamp=card.refreshed_at or now,
        confidence=card.avg_confidence,
        scope=req.scope,
    )
    # The card's average confidence serves as score, trust and weighted total.
    breakdown = ScoreBreakdown(
        source_trust=round(card.avg_confidence, 4),
        weighted_total=round(card.avg_confidence, 6),
    )
    scored_card = ScoredFact(
        fact=summary_record,
        score=round(card.avg_confidence, 6),
        score_breakdown=breakdown,
        hop_distance=0,
        token_estimate=_estimate_tokens(summary_record),
        from_card=True,
    )
    return scored_card, list(entity_fact_ids)

293 

294 

def _try_card_fast_path(
    all_facts_raw: dict[str, FactRecord],
    req: RecallRequest,
    identity: Identity,
    conn: Any,
    now: str,
) -> tuple[list[ScoredFact], set[str]]:
    """Replace candidate entities' raw facts with card summaries where possible.

    Groups the candidate facts by entity and tries to build a card for each
    group. Returns ``(card_facts, card_owned_fact_ids)``. Any exception is
    logged as a warning and yields ``([], set())`` so the caller can fall
    through to raw-fact scoring.
    """
    scored_cards: list[ScoredFact] = []
    owned_ids: set[str] = set()
    try:
        # entity URI -> ids of the candidate facts about that entity
        facts_by_entity: dict[str, list[str]] = {}
        for fact_id, fact in all_facts_raw.items():
            if fact.entity not in facts_by_entity:
                facts_by_entity[fact.entity] = []
            facts_by_entity[fact.entity].append(fact_id)

        for uri, fact_ids in facts_by_entity.items():
            outcome = _build_card_for_entity(uri, fact_ids, req, identity, conn, now)
            if outcome is None:
                continue
            scored_card, covered = outcome
            scored_cards.append(scored_card)
            owned_ids.update(covered)
    except Exception as _card_exc:
        logger.warning("card fast-path error (falling through to raw facts): %s", _card_exc)
        return [], set()
    return scored_cards, owned_ids

323 

324 

def _exclude_card_owned(
    card_entity_ids: set[str],
    all_facts_raw: dict[str, FactRecord],
    lex_scores: dict[str, float],
    sem_scores: dict[str, float],
    graph_hops: dict[str, int],
) -> tuple[dict[str, FactRecord], dict[str, float], dict[str, float], dict[str, int]]:
    """Remove card-owned fact ids from the four scoring inputs.

    Returns the inputs in the same order they were passed. When
    ``card_entity_ids`` is empty the original dict objects are returned
    untouched; otherwise each dict is replaced by a filtered copy.
    """
    if not card_entity_ids:
        return all_facts_raw, lex_scores, sem_scores, graph_hops

    def without_owned(mapping: dict) -> dict:
        return {
            fact_id: payload
            for fact_id, payload in mapping.items()
            if fact_id not in card_entity_ids
        }

    remaining_facts = without_owned(all_facts_raw)
    remaining_lex = without_owned(lex_scores)
    remaining_sem = without_owned(sem_scores)
    remaining_hops = without_owned(graph_hops)
    return remaining_facts, remaining_lex, remaining_sem, remaining_hops

344 

345 

def _set_recall_span_attrs(
    span: object,
    recall_id: str,
    total_scored: int,
    tokens_used: int,
    truncated: bool,
) -> None:
    """Record the recall outcome on the tracing span, best-effort.

    Attributes are set in a fixed order; the first failure stops the remaining
    ones and is logged at debug level instead of being raised.
    """
    outcome = (
        ("stigmem.recall_id", recall_id),
        ("stigmem.total_scored", total_scored),
        ("stigmem.tokens_used", tokens_used),
        ("stigmem.truncated", truncated),
    )
    try:
        for attr_name, attr_value in outcome:
            span.set_attribute(attr_name, attr_value)  # type: ignore[attr-defined]
    except Exception as exc:  # noqa: BLE001  # nosec B110 — span attrs best-effort
        logger.debug("recall span attribute set failed: %s", exc)

361 

362 

def _recall_impl(
    req: RecallRequest,
    identity: Identity,
    _span: object,
    *,
    session_id: str | None = None,
    verify_full: bool = False,
) -> RecallResponse:
    """Run one recall request end to end and build its response.

    After validation, a request carrying ``as_of`` is delegated to the
    time-travel path. Otherwise the hybrid path runs: direct (lexical +
    semantic) matches, optional graph expansion, candidate fetch, tombstone
    filtering, card fast-path, the recall pipeline, scoring, token-budget
    packing, audit and read-scope recording.

    Args:
        req: The recall request.
        identity: The resolved caller identity.
        _span: Tracing span that receives the outcome attributes (best-effort).
        session_id: Optional session identifier passed to read-scope recording.
        verify_full: When true, attach a fact-chain proof to the response.

    Returns:
        The populated ``RecallResponse``. ``total_scored`` is ``None`` when
        tombstone filtering removed at least one candidate.

    Raises:
        HTTPException: Propagated from request validation and, when
            *verify_full* is set, from chain-proof verification.
    """
    _validate_recall_request(req, identity)

    if req.as_of is not None:
        result = _handle_as_of_recall(req, identity, verify_full=verify_full)
        with db() as conn:
            record_read_scopes(
                conn,
                identity=identity,
                session_id=session_id,
                scopes={scored.fact.scope for scored in result.facts},
            )
        return result

    recall_id = str(uuid.uuid4())
    query_hash = hashlib.sha256(req.query.encode()).hexdigest()
    now = _now_iso()

    logger.info(
        "recall id=%s entity=%s query_hash=%s scope=%s budget=%d",
        recall_id,
        identity.entity_uri,
        query_hash[:12],
        req.scope,
        req.token_budget,
    )

    with db() as conn:
        lex_scores, sem_scores = _gather_direct_matches(conn, req, identity, now)
        direct_ids = set(lex_scores) | set(sem_scores)

        graph_hops = _expand_graph_neighbours(
            conn,
            req,
            identity,
            direct_ids,
            lex_scores,
            sem_scores,
            now,
        )
        # Direct matches have hop_distance=0 unless graph expansion already
        # assigned them a distance.
        for fid in direct_ids:
            graph_hops.setdefault(fid, 0)

        # --- Fetch all candidate facts ---
        # graph_hops now holds every candidate id. Order the ids before
        # applying the cap so the strongest candidates survive truncation:
        # nearest hop first, then highest combined direct score, then id as a
        # deterministic tie-break. Slicing an unordered set would keep an
        # arbitrary subset that can change from process to process.
        ranked_candidate_ids = sorted(
            graph_hops,
            key=lambda fid: (
                graph_hops[fid],
                -(lex_scores.get(fid, 0.0) + sem_scores.get(fid, 0.0)),
                fid,
            ),
        )
        all_candidate_ids = ranked_candidate_ids[:_MAX_CANDIDATES]
        all_facts_raw = _fetch_facts_by_ids(conn, all_candidate_ids)

        # §23.3.2 r.3: exclude facts whose entity has an active tombstone (about_entity).
        # Uses in-process cache (§23.3.3 r.4) — no per-fact DB read required.
        pre_tombstone_count = len(all_facts_raw)
        all_facts_raw = {
            k: v
            for k, v in all_facts_raw.items()
            if not _is_tombstoned(v.entity, identity.tenant_id)
        }
        tombstone_filtered = len(all_facts_raw) < pre_tombstone_count

        # --- Card fast-path (§20) ---
        card_facts, card_entity_ids = _try_card_fast_path(
            all_facts_raw,
            req,
            identity,
            conn,
            now,
        )
        all_facts_raw, lex_scores, sem_scores, graph_hops = _exclude_card_owned(
            card_entity_ids,
            all_facts_raw,
            lex_scores,
            sem_scores,
            graph_hops,
        )

        # Apply §19 recall pipeline (source-trust multiplier + content sanitiser)
        all_facts = {r.id: r for r in apply_recall_pipeline(list(all_facts_raw.values()), identity)}

        # --- Score (timed for ranker histogram) ---
        with observe_duration(RECALL_RANKER_DURATION, {"tenant": identity.tenant_id}):
            candidates = _score_candidates(
                all_facts,
                lex_scores,
                sem_scores,
                graph_hops,
                req.weights,
                identity,
                req.depth,
            )
        candidates.extend(card_facts)
        total_scored = len(candidates)

        # --- Token-budget packing ---
        packed, tokens_used, truncated = _greedy_pack(candidates, req.token_budget)

        # --- Audit ---
        _write_recall_audit(
            conn,
            recall_id,
            identity,
            query_hash,
            req.scope,
            req.token_budget,
            len(packed),
            tokens_used,
            truncated,
        )
        record_read_scopes(
            conn,
            identity=identity,
            session_id=session_id,
            scopes={scored.fact.scope for scored in packed},
        )
        chain_proof = _chain_proof_or_409(conn, identity.tenant_id) if verify_full else None

    logger.info(
        "recall id=%s scored=%d packed=%d tokens=%d truncated=%s",
        recall_id,
        total_scored,
        len(packed),
        tokens_used,
        truncated,
    )

    _set_recall_span_attrs(_span, recall_id, total_scored, tokens_used, truncated)

    content, instructions = _split_interpretation_channels(packed)
    return RecallResponse(
        recall_id=recall_id,
        query_hash=query_hash,
        facts=packed,
        content=content,
        instructions=instructions,
        # §23.3.3 r.3: suppress total_scored when tombstone filtering was applied
        total_scored=None if tombstone_filtered else total_scored,
        token_budget=req.token_budget,
        tokens_used=tokens_used,
        truncated=truncated,
        chain_proof=chain_proof,
    )