Coverage for node / src / stigmem_node / recall / vector_search.py: 66%

70 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-05-25 01:49 +0000

1"""Vector search primitive — Phase 9 (spec §20 / design memo §2–3). 

2 

3Public API:: 

4 

5 vector_search(query_embedding, k=10, scope_filter=None, tenant_id="default", conn=None) 

6 -> list[tuple[FactRecord, float]] 

7 

8 embed_and_store_fact(fact_id, entity, relation, value_type, value_v, conn, model) 

9 

10 backfill_missing_embeddings(conn, model, limit=500) 

11 

12 EmbeddingModelMismatch — raised when stored model_id / dimension differs 

13 from configured model. 

14""" 

15 

16from __future__ import annotations 

17 

18import logging 

19import struct 

20from datetime import UTC, datetime 

21from typing import TYPE_CHECKING, Any 

22 

23if TYPE_CHECKING: 

24 from ..embedding.base import EmbeddingModel, Vector 

25 from ..models.facts import FactRecord 

26 

logger: logging.Logger = logging.getLogger("stigmem.vector_search")


class EmbeddingModelMismatch(RuntimeError):
    """Signals that the configured embedding model disagrees with ``embedding_meta``.

    Raised when the stored ``model_id`` or ``dimension`` differs from the model
    the node is currently configured to use.
    """

32 

33 

34# --------------------------------------------------------------------------- 

35# vec_facts virtual table management 

36# --------------------------------------------------------------------------- 

37 

38 

def ensure_vec_table(conn: Any, dimension: int) -> None:
    """Create the ``vec_facts`` virtual table if it does not exist.

    Intended to be called after the sqlite-vec extension is loaded. Safe to
    call on every connection open because the CREATE is guarded by
    IF NOT EXISTS.

    Args:
        conn: Open DB connection exposing ``execute``.
        dimension: Embedding width. Must be a positive integer. It is
            interpolated directly into the DDL, because SQLite cannot bind
            parameters inside a column type, so it is validated first.

    Raises:
        ValueError: if *dimension* is not a positive integer.
        RuntimeError: if the CREATE statement fails (typically because
            sqlite-vec is not loaded).
    """
    import operator

    # operator.index accepts int-like values (including numpy ints) but rejects
    # floats and strings, which would otherwise be pasted verbatim into the DDL.
    try:
        dim = operator.index(dimension)
    except TypeError:
        dim = None
    if dim is None or isinstance(dimension, bool) or dim <= 0:
        raise ValueError(
            f"vec_facts dimension must be a positive integer, got {dimension!r}"
        )

    ddl = (
        "CREATE VIRTUAL TABLE IF NOT EXISTS vec_facts "
        f"USING vec0(fact_id TEXT PRIMARY KEY, embedding FLOAT[{dim}])"
    )
    try:
        conn.execute(ddl)
    except Exception as exc:
        raise RuntimeError(
            f"Failed to create vec_facts virtual table (is sqlite-vec loaded?): {exc}"
        ) from exc

56 

57 

58# --------------------------------------------------------------------------- 

59# Mixed-model safety 

60# --------------------------------------------------------------------------- 

61 

62 

63def check_or_register_model(conn: Any, model_id: str, dimension: int) -> None: 

64 """Assert that the stored embedding_meta matches *model_id* / *dimension*. 

65 

66 On first use (empty table) registers the model. Raises 

67 ``EmbeddingModelMismatch`` on conflicts. Callers must hold a transaction. 

68 """ 

69 row = conn.execute("SELECT model_id, dimension FROM embedding_meta WHERE id = 1").fetchone() 

70 if row is None: 

71 now = datetime.now(UTC).isoformat() 

72 conn.execute( 

73 "INSERT INTO embedding_meta (id, model_id, dimension, created_at) VALUES (1, ?, ?, ?)", 

74 (model_id, dimension, now), 

75 ) 

76 return 

77 

78 stored_model = row["model_id"] 

79 stored_dim = int(row["dimension"]) 

80 if stored_model != model_id or stored_dim != dimension: 

81 raise EmbeddingModelMismatch( 

82 f"Embedding model mismatch: stored=({stored_model!r}, dim={stored_dim}) " 

83 f"!= configured=({model_id!r}, dim={dimension}). " 

84 "Run `stigmem embed reindex` to re-embed all facts with the new model, " 

85 "or restore the original model configuration." 

86 ) 

87 

88 

89# --------------------------------------------------------------------------- 

90# Encode / decode vectors for sqlite-vec BLOB storage 

91# --------------------------------------------------------------------------- 

92 

93 

94def _encode_vector(vec: Vector) -> bytes: 

95 """Encode a float list as a little-endian IEEE 754 BLOB for sqlite-vec.""" 

96 return struct.pack(f"<{len(vec)}f", *vec) 

97 

98 

99# --------------------------------------------------------------------------- 

100# Write-path: store a single embedding 

101# --------------------------------------------------------------------------- 

102 

103 

def store_embedding(conn: Any, fact_id: str, vec: Vector) -> None:
    """Write *vec* for *fact_id* into ``vec_facts`` and clear its missing flag.

    Uses INSERT OR REPLACE, so re-embedding a fact overwrites its previous
    vector. No commit is issued here; the caller controls the transaction.
    """
    from ..lifecycle.immutability import set_embedding_status

    conn.execute(
        "INSERT OR REPLACE INTO vec_facts (fact_id, embedding) VALUES (?, ?)",
        (fact_id, _encode_vector(vec)),
    )
    set_embedding_status(conn, fact_id=fact_id, embedding_missing=False)

114 

115 

def embed_and_store_fact(
    fact_id: str,
    entity: str,
    relation: str,
    value_type: str,
    value_v: str | None,
    conn: Any,
    model: EmbeddingModel,
) -> None:
    """Embed one fact triple and persist its vector.

    Args:
        fact_id: Id of the fact whose vector is written to ``vec_facts``.
        entity: Fact subject.
        relation: Fact predicate.
        value_type: Fact value type tag.
        value_v: Fact value. ``None`` is treated as ``""``.
        conn: Open DB connection; caller owns the transaction.
        model: Embedding model. Its ``embed`` is called with a single-item list.
    """
    from ..embedding.base import compose_triple_text

    # Normalise a NULL value exactly as backfill_missing_embeddings does, so the
    # write path and the backfill path embed identical text for the same fact.
    text = compose_triple_text(entity, relation, value_type, value_v or "")
    vecs = model.embed([text])
    store_embedding(conn, fact_id, vecs[0])

131 

132 

133# --------------------------------------------------------------------------- 

134# Backfill job 

135# --------------------------------------------------------------------------- 

136 

137 

def backfill_missing_embeddings(
    conn: Any,
    model: EmbeddingModel,
    limit: int = 500,
) -> int:
    """Embed up to *limit* facts that are still flagged ``embedding_missing``.

    A fact is selected when both of these hold:

    * its effective ``embedding_missing`` flag is 1 (the ``fact_embedding_status``
      row when present, else the ``facts`` column);
    * its effective confidence is above 0.1 (the ``fact_validity_overrides``
      value when present, else the ``facts`` column).

    The confidence test uses the same projection ``vector_search`` filters on.
    A fact whose override raised its confidence above the threshold is
    therefore backfilled rather than staying permanently unsearchable.

    *limit* is passed straight to SQLite's LIMIT clause; a negative value there
    means "no limit". Per-fact failures are logged and skipped (graceful
    degradation).

    Returns:
        Number of facts embedded and stored successfully.
    """
    from ..embedding.base import compose_triple_text

    # NOTE(review): assumes at most one fact_validity_overrides row per fact,
    # as vector_search's identical LEFT JOIN does — confirm against the schema.
    rows = conn.execute(
        """SELECT f.id, f.entity, f.relation, f.value_type, f.value_v
           FROM facts f
           LEFT JOIN fact_embedding_status fes ON fes.fact_id = f.id
           LEFT JOIN fact_validity_overrides fvo ON fvo.fact_id = f.id
           WHERE COALESCE(fes.embedding_missing, f.embedding_missing) = 1
             AND COALESCE(fvo.confidence, f.confidence) > 0.1
           LIMIT ?""",
        (limit,),
    ).fetchall()

    count = 0
    for row in rows:
        try:
            text = compose_triple_text(
                row["entity"], row["relation"], row["value_type"], row["value_v"] or ""
            )
            vecs = model.embed([text])
            store_embedding(conn, row["id"], vecs[0])
        except Exception as exc:
            # Best-effort: one failing fact must not abort the rest of the batch.
            logger.warning("Backfill embed failed for fact %s: %s", row["id"], exc)
            continue
        count += 1

    return count

173 

174 

175# --------------------------------------------------------------------------- 

176# Query-path: vector_search 

177# --------------------------------------------------------------------------- 

178 

179 

def vector_search(
    query_embedding: Vector,
    k: int = 10,
    scope_filter: str | None = None,
    tenant_id: str = "default",
    conn: Any = None,
) -> list[tuple[FactRecord, float]]:
    """Return the top-k facts closest to *query_embedding*, scored by cosine similarity.

    *query_embedding* MUST be L2-normalised, and stored vectors are normalised
    too. ``vec_facts`` is created by ``ensure_vec_table`` without a
    ``distance_metric`` option, so sqlite-vec reports Euclidean (L2) distance
    ``d``. For unit vectors ``d**2 == 2 - 2*cos``, so cosine similarity is
    recovered as ``1 - d**2 / 2``. The result is clamped at 0, keeping the
    non-negative range this function has always returned.

    Args:
        query_embedding: Unit-length query vector.
        k: Maximum results to return. ``k <= 0`` returns an empty list without
            touching the database.
        scope_filter: When set (non-empty), restrict to facts with this scope.
        tenant_id: Tenant to search within.
        conn: Optional existing DB connection; if None, opens one via ``db()``.

    Returns:
        List of (FactRecord, similarity) tuples, sorted descending by
        similarity. Each similarity lies in [0, 1].
    """
    from ..db import db
    from ..models.facts import row_to_record

    if k <= 0:
        return []

    blob = _encode_vector(query_embedding)

    def _run(c: Any) -> list[tuple[FactRecord, float]]:
        # NOTE(review): sqlite-vec's `k = ?` constraint presumably selects the k
        # nearest vectors before the JOIN/WHERE filters below are applied, so
        # fewer than k rows may come back when other tenants/scopes or
        # low-confidence facts crowd the neighbourhood. Confirm; consider
        # over-fetching candidates and truncating to k.
        base_sql = """
            SELECT f.*, v.distance,
                   COALESCE(fvo.valid_until, f.valid_until) AS projected_valid_until,
                   COALESCE(fvo.confidence, f.confidence) AS projected_confidence,
                   COALESCE(fgm.garden_id, f.garden_id) AS projected_garden_id,
                   COALESCE(fqs.quarantine_status, f.quarantine_status)
                       AS projected_quarantine_status,
                   COALESCE(fqs.quarantine_garden_id, f.quarantine_garden_id)
                       AS projected_quarantine_garden_id,
                   COALESCE(f.cid, (
                       SELECT fca.cid FROM fact_cid_aliases fca
                       WHERE fca.fact_id = f.id ORDER BY fca.cid LIMIT 1
                   )) AS projected_cid
            FROM vec_facts v
            JOIN facts f ON f.id = v.fact_id
            LEFT JOIN fact_validity_overrides fvo ON fvo.fact_id = f.id
            LEFT JOIN fact_garden_membership fgm ON fgm.fact_id = f.id
            LEFT JOIN fact_quarantine_status fqs ON fqs.fact_id = f.id
            WHERE v.embedding MATCH ?
              AND k = ?
              AND COALESCE(fvo.confidence, f.confidence) > 0.1
              AND f.tenant_id = ?
        """
        sql = base_sql
        params: list[Any] = [blob, k, tenant_id]
        if scope_filter:
            sql += " AND f.scope = ?"
            params.append(scope_filter)
        sql += " ORDER BY v.distance"

        results: list[tuple[FactRecord, float]] = []
        for row in c.execute(sql, tuple(params)).fetchall():
            distance = float(row["distance"])
            # L2 distance between unit vectors -> cosine similarity.
            similarity = max(0.0, 1.0 - (distance * distance) / 2.0)
            results.append((row_to_record(row), similarity))
        return results

    if conn is not None:
        return _run(conn)

    with db() as c:
        return _run(c)