Coverage for sdks / stigmem-py / src / stigmem / client.py: 59%

231 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-05-25 01:49 +0000

1"""Stigmem Python client SDK — spec v0.4/v0.5.""" 

2 

3from __future__ import annotations 

4 

5import asyncio 

6import logging 

7from collections.abc import AsyncGenerator, Callable 

8from typing import Any 

9 

10import httpx 

11 

12from .exceptions import ( 

13 StigmemAuthError, 

14 StigmemConflictError, 

15 StigmemHTTPError, 

16 StigmemNotFoundError, 

17) 

18from .models import ( 

19 AssertRequest, 

20 ConflictPage, 

21 ConflictResolution, 

22 Fact, 

23 FactPage, 

24 FactScope, 

25 FactValue, 

26 MemoryCard, 

27 NodeInfo, 

28 Peer, 

29 PeerPage, 

30 RecallRequest, 

31 RecallResponse, 

32 RecallWeights, 

33 ResolveRequest, 

34) 

35 

36logger = logging.getLogger("stigmem") 

37SESSION_HEADER = "Stigmem-Session" 

38VERIFY_HEADER = "Stigmem-Verify" 

39 

40 

41def _recall_headers(session_id: str | None, verify_full: bool = False) -> dict[str, str] | None: 

42 headers = _session_headers(session_id) or {} 

43 if verify_full: 

44 headers[VERIFY_HEADER] = "full" 

45 return headers or None 

46 

47 

48def _session_headers(session_id: str | None) -> dict[str, str] | None: 

49 if session_id is None: 

50 return None 

51 normalized = session_id.strip() 

52 if not normalized: 52 ↛ 53line 52 didn't jump to line 53 because the condition on line 52 was never true

53 return None 

54 return {SESSION_HEADER: normalized} 

55 

56 

57def _raise_for_status(resp: httpx.Response) -> None: 

58 if resp.is_success: 

59 return 

60 try: 

61 detail = resp.json().get("detail", resp.text) 

62 except ValueError: 

63 detail = resp.text 

64 if resp.status_code in (401, 403): 

65 raise StigmemAuthError(resp.status_code, detail) 

66 if resp.status_code == 404: 66 ↛ 68line 66 didn't jump to line 68 because the condition on line 66 was always true

67 raise StigmemNotFoundError(resp.status_code, detail) 

68 if resp.status_code == 409: 

69 raise StigmemConflictError(resp.status_code, detail) 

70 raise StigmemHTTPError(resp.status_code, detail) 

71 

72 

class StigmemClient:
    """Synchronous Stigmem client built on :class:`httpx.Client`.

    Every request method checks the response with ``_raise_for_status``, so a
    non-success status surfaces as ``StigmemAuthError``,
    ``StigmemNotFoundError``, ``StigmemConflictError`` or ``StigmemHTTPError``.

    Usage::

        client = StigmemClient(url="http://localhost:8765", api_key="sk-...")
        fact = client.assert_fact(
            entity="user:alice",
            relation="memory:role",
            value=string_value("CEO"),
            source="agent:cto",
        )

    The client owns an ``httpx.Client``; call :meth:`close` or use it as a
    context manager (``with StigmemClient(...) as client:``) to release it.
    """

    def __init__(
        self,
        url: str,
        api_key: str | None = None,
        timeout: float = 10.0,
    ) -> None:
        """Create a client for the Stigmem node at *url*.

        Args:
            url: Base URL of the node; trailing slashes are removed.
            api_key: When non-empty, sent as ``Authorization: Bearer <api_key>``.
            timeout: Timeout (seconds) forwarded to :class:`httpx.Client`.
        """
        self._url = url.rstrip("/")
        headers: dict[str, str] = {"Accept": "application/json"}
        if api_key:
            headers["Authorization"] = f"Bearer {api_key}"
        self._http = httpx.Client(base_url=self._url, headers=headers, timeout=timeout)

    def close(self) -> None:
        """Close the underlying HTTP client."""
        self._http.close()

    def __enter__(self) -> StigmemClient:
        return self

    def __exit__(self, *args: Any) -> None:
        self.close()

    @staticmethod
    def _path_param(segment: str) -> str:
        """Percent-encode *segment* for interpolation into a URL path.

        Unescaped ``?``, ``#``, ``%`` or spaces would be parsed as URL syntax
        (query start, fragment, escape sequence) and silently change which
        resource is requested. ``:``, ``/`` and ``@`` are kept as-is, so ids
        such as ``user:alice`` are sent unchanged and ``/`` behaves as before.
        """
        from urllib.parse import quote

        return quote(str(segment), safe=":/@")

    # ------------------------------------------------------------------
    # Node / metadata
    # ------------------------------------------------------------------

    def node_info(self) -> NodeInfo:
        """Fetch node metadata from ``/.well-known/stigmem``."""
        resp = self._http.get("/.well-known/stigmem")
        _raise_for_status(resp)
        return NodeInfo.model_validate(resp.json())

    # ------------------------------------------------------------------
    # Facts
    # ------------------------------------------------------------------

    def assert_fact(
        self,
        entity: str,
        relation: str,
        value: FactValue,
        source: str,
        *,
        confidence: float = 1.0,
        scope: FactScope = "company",
        valid_until: str | None = None,
        write_mode: str = "assert",
        derived_from: list[dict[str, Any]] | None = None,
        session_id: str | None = None,
    ) -> Fact:
        """Write a fact via ``POST /v1/facts`` and return the stored :class:`Fact`.

        Args:
            entity: Subject of the fact, e.g. ``"user:alice"``.
            relation: Relation name, e.g. ``"memory:role"``.
            value: Typed fact value.
            source: Identity of the asserting agent.
            confidence: Confidence score (``retract`` sends ``0.0``).
            scope: Fact scope; defaults to ``"company"``.
            valid_until: Optional expiry, forwarded as-is (presumably an
                ISO-8601 timestamp — confirm against the server spec).
            write_mode: Forwarded write mode; defaults to ``"assert"``.
            derived_from: Provenance entries; ``None`` is sent as ``[]``.
            session_id: Optional session id sent in the ``Stigmem-Session`` header.
        """
        req = AssertRequest(
            entity=entity,
            relation=relation,
            value=value,
            source=source,
            confidence=confidence,
            scope=scope,
            valid_until=valid_until,
            write_mode=write_mode,
            derived_from=derived_from or [],
        )
        body = req.model_dump(exclude_none=True)
        # Re-dump the value without exclude_none so its own None fields are
        # kept in the payload (the outer dump drops them).
        body["value"] = value.model_dump()
        resp = self._http.post("/v1/facts", json=body, headers=_session_headers(session_id))
        _raise_for_status(resp)
        return Fact.model_validate(resp.json())

    def retract(
        self,
        entity: str,
        relation: str,
        scope: FactScope,
        source: str,
        *,
        value: FactValue | None = None,
        session_id: str | None = None,
    ) -> Fact:
        """Assert a retraction (confidence=0.0) for the given triple.

        Args:
            entity: Subject of the retracted fact.
            relation: Relation of the retracted fact.
            scope: Scope the retraction is written to.
            source: Identity of the retracting agent.
            value: Value to record; defaults to ``string_value("retracted")``.
            session_id: Optional session id, forwarded to :meth:`assert_fact`.
        """
        if value is None:
            from .models import string_value as _sv

            value = _sv("retracted")
        return self.assert_fact(
            entity=entity,
            relation=relation,
            value=value,
            source=source,
            confidence=0.0,
            scope=scope,
            session_id=session_id,
        )

    def get(self, fact_id: str, *, session_id: str | None = None) -> Fact:
        """Fetch a single fact by id (``GET /v1/facts/{fact_id}``)."""
        resp = self._http.get(
            f"/v1/facts/{self._path_param(fact_id)}", headers=_session_headers(session_id)
        )
        _raise_for_status(resp)
        return Fact.model_validate(resp.json())

    def query(
        self,
        *,
        entity: str | None = None,
        relation: str | None = None,
        source: str | None = None,
        scope: FactScope | None = None,
        min_confidence: float | None = None,
        include_contradicted: bool = False,
        include_expired: bool = False,
        cursor: str | None = None,
        limit: int = 50,
        after: str | None = None,
        session_id: str | None = None,
    ) -> FactPage:
        """List facts matching the given filters (``GET /v1/facts``).

        Filters left empty (``None``, ``""`` or ``False``) are not sent.
        ``min_confidence`` is sent whenever it is not ``None``, so ``0.0`` is a
        valid filter. ``limit`` is always sent.

        Args:
            cursor: Pagination cursor taken from a previous :class:`FactPage`.
            after: Forwarded as the ``after`` query parameter.
            session_id: Optional session id sent in the ``Stigmem-Session`` header.
        """
        params: dict[str, Any] = {"limit": limit}
        if entity:
            params["entity"] = entity
        if relation:
            params["relation"] = relation
        if source:
            params["source"] = source
        if scope:
            params["scope"] = scope
        if min_confidence is not None:
            params["min_confidence"] = min_confidence
        if include_contradicted:
            params["include_contradicted"] = "true"
        if include_expired:
            params["include_expired"] = "true"
        if cursor:
            params["cursor"] = cursor
        if after:
            params["after"] = after
        resp = self._http.get("/v1/facts", params=params, headers=_session_headers(session_id))
        _raise_for_status(resp)
        return FactPage.model_validate(resp.json())

    # ------------------------------------------------------------------
    # Conflicts
    # ------------------------------------------------------------------

    def list_conflicts(
        self,
        *,
        status: str | None = "unresolved",
        cursor: str | None = None,
        limit: int = 50,
    ) -> ConflictPage:
        """List conflicts (``GET /v1/conflicts``).

        *status* defaults to ``"unresolved"``; pass ``None`` (or ``""``) to omit
        the status filter.
        """
        params: dict[str, Any] = {"limit": limit}
        if status:
            params["status"] = status
        if cursor:
            params["cursor"] = cursor
        resp = self._http.get("/v1/conflicts", params=params)
        _raise_for_status(resp)
        return ConflictPage.model_validate(resp.json())

    def resolve_conflict(
        self,
        conflict_id: str,
        *,
        winning_fact_id: str | None = None,
        resolution_note: str = "",
        new_value: FactValue | None = None,
    ) -> ConflictResolution:
        """Resolve a conflict via ``POST /v1/conflicts/{conflict_id}/resolve``."""
        req = ResolveRequest(
            winning_fact_id=winning_fact_id,
            resolution_note=resolution_note,
            new_value=new_value,
        )
        resp = self._http.post(
            f"/v1/conflicts/{self._path_param(conflict_id)}/resolve",
            json=req.model_dump_api(),
        )
        _raise_for_status(resp)
        return ConflictResolution.model_validate(resp.json())

    # ------------------------------------------------------------------
    # Federation
    # ------------------------------------------------------------------

    def federation_status(self) -> list[Peer]:
        """Return the peers listed by ``GET /v1/federation/peers``."""
        resp = self._http.get("/v1/federation/peers")
        _raise_for_status(resp)
        return PeerPage.model_validate(resp.json()).peers

    # ------------------------------------------------------------------
    # Subscribe (polling)
    # ------------------------------------------------------------------

    def subscribe_scope(
        self,
        scope: FactScope,
        callback: Callable[[list[Fact]], None],
        *,
        interval_s: float = 30.0,
        stop_event: asyncio.Event | None = None,
    ) -> None:
        """Poll for new facts in *scope* and call *callback* with each batch.

        Blocking — runs until *stop_event* is set or KeyboardInterrupt.
        For async use, see AsyncStigmemClient.subscribe_scope().

        Each poll requests up to 100 facts starting from the cursor returned
        by the previous poll, then sleeps *interval_s* seconds. *stop_event*
        is only checked before each poll, so stopping can lag by up to one
        interval. Only ``stop_event.is_set()`` is called, so any object with
        that method (e.g. ``threading.Event``) works at runtime.
        """
        import time

        cursor: str | None = None
        while True:
            if stop_event and stop_event.is_set():
                break
            page = self.query(scope=scope, cursor=cursor, limit=100)
            if page.facts:
                callback(page.facts)
            # NOTE(review): if the server returns no cursor, the next poll is
            # sent without one — confirm this cannot re-deliver old facts.
            cursor = page.cursor
            time.sleep(interval_s)

    # ------------------------------------------------------------------
    # Recall (Phase 9 — spec §20)
    # ------------------------------------------------------------------

    def recall(
        self,
        query: str,
        *,
        scope: FactScope = "local",
        token_budget: int = 4000,
        depth: int = 2,
        weights: RecallWeights | None = None,
        min_confidence: float = 0.1,
        include_neighbors: bool = True,
        limit: int = 100,
        legacy_format: bool = False,
        session_id: str | None = None,
        verify_full: bool = False,
    ) -> RecallResponse:
        """Hybrid recall — return the most salient facts for *query* within *token_budget*.

        Combines lexical (BM25/FTS5), dense-vector, and graph-traversal signals.

        Args:
            query: Natural-language or keyword query.
            scope: Fact scope to search within.
            token_budget: Maximum token budget for the response.
            depth: Graph traversal depth (1–3).
            weights: Signal weights; ``RecallWeights()`` is sent when None.
            min_confidence: Minimum fact confidence to include.
            include_neighbors: Whether to expand via graph traversal.
            limit: Maximum candidate facts before token-budget packing.
            legacy_format: Request the temporary pre-channel response shape.
            session_id: Optional session id sent in the ``Stigmem-Session`` header.
            verify_full: Request full server-side integrity proof metadata.

        Returns:
            RecallResponse with scored + packed facts and score breakdowns.
        """
        req = RecallRequest(
            query=query,
            scope=scope,
            token_budget=token_budget,
            depth=depth,
            weights=weights or RecallWeights(),
            min_confidence=min_confidence,
            include_neighbors=include_neighbors,
            limit=limit,
        )
        params = {"legacy_format": "true"} if legacy_format else None
        resp = self._http.post(
            "/v1/recall",
            json=req.model_dump(),
            params=params,
            headers=_recall_headers(session_id, verify_full),
        )
        _raise_for_status(resp)
        return RecallResponse.model_validate(resp.json())

    # ------------------------------------------------------------------
    # Memory cards (Phase 9 — spec §20)
    # ------------------------------------------------------------------

    def get_card(
        self,
        entity_uri: str,
        *,
        scope: FactScope = "local",
        refresh: bool = False,
    ) -> MemoryCard:
        """Fetch the synthesized memory card for *entity_uri*.

        Args:
            entity_uri: The entity to fetch the card for. It is percent-encoded
                before being placed in the URL path.
            scope: Fact scope the card was materialised from.
            refresh: Force a server-side refresh even if the card is fresh.

        Returns:
            MemoryCard with summary, contributing fact hashes, and confidence.

        Raises:
            StigmemNotFoundError: When the entity has no live facts.
        """
        params: dict[str, Any] = {"scope": scope}
        if refresh:
            params["refresh"] = "true"
        resp = self._http.get(f"/v1/cards/{self._path_param(entity_uri)}", params=params)
        _raise_for_status(resp)
        return MemoryCard.model_validate(resp.json())

380 

381 

class AsyncStigmemClient:
    """Async Stigmem client (httpx.AsyncClient).

    Mirrors :class:`StigmemClient` with awaitable methods. Every request
    method checks the response with ``_raise_for_status``. Call :meth:`aclose`
    or use ``async with AsyncStigmemClient(...) as client:`` to release the
    underlying ``httpx.AsyncClient``.
    """

    def __init__(
        self,
        url: str,
        api_key: str | None = None,
        timeout: float = 10.0,
    ) -> None:
        """Create an async client for the Stigmem node at *url*.

        Args:
            url: Base URL of the node; trailing slashes are removed.
            api_key: When non-empty, sent as ``Authorization: Bearer <api_key>``.
            timeout: Timeout (seconds) forwarded to :class:`httpx.AsyncClient`.
        """
        self._url = url.rstrip("/")
        headers: dict[str, str] = {"Accept": "application/json"}
        if api_key:
            headers["Authorization"] = f"Bearer {api_key}"
        self._http = httpx.AsyncClient(base_url=self._url, headers=headers, timeout=timeout)

    async def aclose(self) -> None:
        """Close the underlying async HTTP client."""
        await self._http.aclose()

    async def __aenter__(self) -> AsyncStigmemClient:
        return self

    async def __aexit__(self, *args: Any) -> None:
        await self.aclose()

    @staticmethod
    def _path_param(segment: str) -> str:
        """Percent-encode *segment* for interpolation into a URL path.

        Unescaped ``?``, ``#``, ``%`` or spaces would be parsed as URL syntax
        (query start, fragment, escape sequence) and silently change which
        resource is requested. ``:``, ``/`` and ``@`` are kept as-is, so ids
        such as ``user:alice`` are sent unchanged and ``/`` behaves as before.
        """
        from urllib.parse import quote

        return quote(str(segment), safe=":/@")

    async def node_info(self) -> NodeInfo:
        """Fetch node metadata from ``/.well-known/stigmem``."""
        resp = await self._http.get("/.well-known/stigmem")
        _raise_for_status(resp)
        return NodeInfo.model_validate(resp.json())

    async def assert_fact(
        self,
        entity: str,
        relation: str,
        value: FactValue,
        source: str,
        *,
        confidence: float = 1.0,
        scope: FactScope = "company",
        valid_until: str | None = None,
        write_mode: str = "assert",
        derived_from: list[dict[str, Any]] | None = None,
        session_id: str | None = None,
    ) -> Fact:
        """Write a fact via ``POST /v1/facts`` and return the stored :class:`Fact`.

        Arguments match :meth:`StigmemClient.assert_fact`. ``derived_from=None``
        is sent as ``[]``, and *session_id* (if any) is sent in the
        ``Stigmem-Session`` header.
        """
        req = AssertRequest(
            entity=entity,
            relation=relation,
            value=value,
            source=source,
            confidence=confidence,
            scope=scope,
            valid_until=valid_until,
            write_mode=write_mode,
            derived_from=derived_from or [],
        )
        body = req.model_dump(exclude_none=True)
        # Re-dump the value without exclude_none so its own None fields are
        # kept in the payload (the outer dump drops them).
        body["value"] = value.model_dump()
        resp = await self._http.post(
            "/v1/facts", json=body, headers=_session_headers(session_id)
        )
        _raise_for_status(resp)
        return Fact.model_validate(resp.json())

    async def retract(
        self,
        entity: str,
        relation: str,
        scope: FactScope,
        source: str,
        *,
        value: FactValue | None = None,
        session_id: str | None = None,
    ) -> Fact:
        """Assert a retraction (confidence=0.0) for the given triple.

        Args:
            entity: Subject of the retracted fact.
            relation: Relation of the retracted fact.
            scope: Scope the retraction is written to.
            source: Identity of the retracting agent.
            value: Value to record; defaults to ``string_value("retracted")``.
            session_id: Optional session id, forwarded to :meth:`assert_fact`.
        """
        if value is None:
            from .models import string_value as _sv

            value = _sv("retracted")
        return await self.assert_fact(
            entity=entity,
            relation=relation,
            value=value,
            source=source,
            confidence=0.0,
            scope=scope,
            session_id=session_id,
        )

    async def get(self, fact_id: str, *, session_id: str | None = None) -> Fact:
        """Fetch a single fact by id (``GET /v1/facts/{fact_id}``)."""
        resp = await self._http.get(
            f"/v1/facts/{self._path_param(fact_id)}", headers=_session_headers(session_id)
        )
        _raise_for_status(resp)
        return Fact.model_validate(resp.json())

    async def query(
        self,
        *,
        entity: str | None = None,
        relation: str | None = None,
        source: str | None = None,
        scope: FactScope | None = None,
        min_confidence: float | None = None,
        include_contradicted: bool = False,
        include_expired: bool = False,
        cursor: str | None = None,
        limit: int = 50,
        after: str | None = None,
        session_id: str | None = None,
    ) -> FactPage:
        """List facts matching the given filters (``GET /v1/facts``).

        Filters left empty (``None``, ``""`` or ``False``) are not sent.
        ``min_confidence`` is sent whenever it is not ``None``; ``limit`` is
        always sent. *session_id* goes in the ``Stigmem-Session`` header.
        """
        params: dict[str, Any] = {"limit": limit}
        if entity:
            params["entity"] = entity
        if relation:
            params["relation"] = relation
        if source:
            params["source"] = source
        if scope:
            params["scope"] = scope
        if min_confidence is not None:
            params["min_confidence"] = min_confidence
        if include_contradicted:
            params["include_contradicted"] = "true"
        if include_expired:
            params["include_expired"] = "true"
        if cursor:
            params["cursor"] = cursor
        if after:
            params["after"] = after
        resp = await self._http.get(
            "/v1/facts", params=params, headers=_session_headers(session_id)
        )
        _raise_for_status(resp)
        return FactPage.model_validate(resp.json())

    async def list_conflicts(
        self,
        *,
        status: str | None = "unresolved",
        cursor: str | None = None,
        limit: int = 50,
    ) -> ConflictPage:
        """List conflicts (``GET /v1/conflicts``).

        *status* defaults to ``"unresolved"``; pass ``None`` (or ``""``) to omit
        the status filter.
        """
        params: dict[str, Any] = {"limit": limit}
        if status:
            params["status"] = status
        if cursor:
            params["cursor"] = cursor
        resp = await self._http.get("/v1/conflicts", params=params)
        _raise_for_status(resp)
        return ConflictPage.model_validate(resp.json())

    async def resolve_conflict(
        self,
        conflict_id: str,
        *,
        winning_fact_id: str | None = None,
        resolution_note: str = "",
        new_value: FactValue | None = None,
    ) -> ConflictResolution:
        """Resolve a conflict via ``POST /v1/conflicts/{conflict_id}/resolve``."""
        req = ResolveRequest(
            winning_fact_id=winning_fact_id,
            resolution_note=resolution_note,
            new_value=new_value,
        )
        resp = await self._http.post(
            f"/v1/conflicts/{self._path_param(conflict_id)}/resolve",
            json=req.model_dump_api(),
        )
        _raise_for_status(resp)
        return ConflictResolution.model_validate(resp.json())

    async def federation_status(self) -> list[Peer]:
        """Return the peers listed by ``GET /v1/federation/peers``."""
        resp = await self._http.get("/v1/federation/peers")
        _raise_for_status(resp)
        return PeerPage.model_validate(resp.json()).peers

    async def subscribe_scope(
        self,
        scope: FactScope,
        callback: Callable[[list[Fact]], None],
        *,
        interval_s: float = 30.0,
        stop_event: asyncio.Event | None = None,
    ) -> AsyncGenerator[list[Fact], None]:
        """Async generator that yields batches of new facts in *scope*.

        Each poll requests up to 100 facts from the cursor returned by the
        previous poll. A non-empty batch is passed to *callback* and then
        yielded. Between polls the generator waits *interval_s* seconds. When
        *stop_event* is an :class:`asyncio.Event`, that wait ends as soon as
        the event is set, and the generator then returns. Other objects only
        need ``is_set()``; they are checked before each poll.
        """
        cursor: str | None = None
        while True:
            if stop_event and stop_event.is_set():
                return
            page = await self.query(scope=scope, cursor=cursor, limit=100)
            if page.facts:
                callback(page.facts)
                yield page.facts
            # NOTE(review): if the server returns no cursor, the next poll is
            # sent without one — confirm this cannot re-deliver old facts.
            cursor = page.cursor
            if isinstance(stop_event, asyncio.Event):
                try:
                    # Wake early when the caller sets stop_event mid-interval
                    # instead of always sleeping the full interval.
                    await asyncio.wait_for(stop_event.wait(), timeout=interval_s)
                except asyncio.TimeoutError:
                    pass
            else:
                await asyncio.sleep(interval_s)

    # ------------------------------------------------------------------
    # Recall (Phase 9 — spec §20)
    # ------------------------------------------------------------------

    async def recall(
        self,
        query: str,
        *,
        scope: FactScope = "local",
        token_budget: int = 4000,
        depth: int = 2,
        weights: RecallWeights | None = None,
        min_confidence: float = 0.1,
        include_neighbors: bool = True,
        limit: int = 100,
        legacy_format: bool = False,
        session_id: str | None = None,
        verify_full: bool = False,
    ) -> RecallResponse:
        """Async hybrid recall — return the most salient facts for *query* within *token_budget*.

        Arguments match :meth:`StigmemClient.recall`. ``weights=None`` sends
        ``RecallWeights()``. ``legacy_format`` adds ``legacy_format=true`` to
        the query string. *session_id* / *verify_full* control the
        ``Stigmem-Session`` / ``Stigmem-Verify`` headers.
        """
        req = RecallRequest(
            query=query,
            scope=scope,
            token_budget=token_budget,
            depth=depth,
            weights=weights or RecallWeights(),
            min_confidence=min_confidence,
            include_neighbors=include_neighbors,
            limit=limit,
        )
        params = {"legacy_format": "true"} if legacy_format else None
        resp = await self._http.post(
            "/v1/recall",
            json=req.model_dump(),
            params=params,
            headers=_recall_headers(session_id, verify_full),
        )
        _raise_for_status(resp)
        return RecallResponse.model_validate(resp.json())

    # ------------------------------------------------------------------
    # Memory cards (Phase 9 — spec §20)
    # ------------------------------------------------------------------

    async def get_card(
        self,
        entity_uri: str,
        *,
        scope: FactScope = "local",
        refresh: bool = False,
    ) -> MemoryCard:
        """Async fetch of the synthesized memory card for *entity_uri*.

        *entity_uri* is percent-encoded before being placed in the URL path.

        Raises StigmemNotFoundError when the entity has no live facts.
        """
        params: dict[str, Any] = {"scope": scope}
        if refresh:
            params["refresh"] = "true"
        resp = await self._http.get(
            f"/v1/cards/{self._path_param(entity_uri)}", params=params
        )
        _raise_for_status(resp)
        return MemoryCard.model_validate(resp.json())