Coverage for node / src / stigmem_node / recall / vector_search.py: 66%
70 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-25 01:49 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-25 01:49 +0000
1"""Vector search primitive — Phase 9 (spec §20 / design memo §2–3).
3Public API::
5 vector_search(query_embedding, k=10, scope_filter=None, tenant_id="default", conn=None)
6 -> list[tuple[FactRecord, float]]
8 embed_and_store_fact(fact_id, entity, relation, value_type, value_v, conn, model)
10 backfill_missing_embeddings(conn, model, limit=500)
12 EmbeddingModelMismatch — raised when stored model_id / dimension differs
13 from configured model.
14"""
16from __future__ import annotations
18import logging
19import struct
20from datetime import UTC, datetime
21from typing import TYPE_CHECKING, Any
23if TYPE_CHECKING:
24 from ..embedding.base import EmbeddingModel, Vector
25 from ..models.facts import FactRecord
27logger = logging.getLogger("stigmem.vector_search")
class EmbeddingModelMismatch(RuntimeError):
    """Signal that the configured embedding model disagrees with the stored one.

    ``check_or_register_model`` raises this when the model id or dimension
    recorded in ``embedding_meta`` differs from the values it was given.
    """
34# ---------------------------------------------------------------------------
35# vec_facts virtual table management
36# ---------------------------------------------------------------------------
def ensure_vec_table(conn: Any, dimension: int) -> None:
    """Make sure the ``vec_facts`` vec0 virtual table exists.

    The statement is ``CREATE VIRTUAL TABLE IF NOT EXISTS``, so invoking this
    each time a connection is opened is harmless. It is meant to run after
    the sqlite-vec extension has been loaded on *conn*.

    Args:
        conn: Connection-like object exposing ``execute``.
        dimension: Width of the ``embedding`` column; interpolated into the
            DDL text as ``FLOAT[dimension]``.

    Raises:
        RuntimeError: If building or executing the statement fails for any
            reason. The underlying exception is chained as the cause.
    """
    try:
        # Built inside the ``try`` so any failure surfaces as RuntimeError.
        ddl = (
            "CREATE VIRTUAL TABLE IF NOT EXISTS vec_facts "
            f"USING vec0(fact_id TEXT PRIMARY KEY, embedding FLOAT[{dimension}])"
        )
        conn.execute(ddl)
    except Exception as err:
        raise RuntimeError(
            f"Failed to create vec_facts virtual table (is sqlite-vec loaded?): {err}"
        ) from err
58# ---------------------------------------------------------------------------
59# Mixed-model safety
60# ---------------------------------------------------------------------------
63def check_or_register_model(conn: Any, model_id: str, dimension: int) -> None:
64 """Assert that the stored embedding_meta matches *model_id* / *dimension*.
66 On first use (empty table) registers the model. Raises
67 ``EmbeddingModelMismatch`` on conflicts. Callers must hold a transaction.
68 """
69 row = conn.execute("SELECT model_id, dimension FROM embedding_meta WHERE id = 1").fetchone()
70 if row is None:
71 now = datetime.now(UTC).isoformat()
72 conn.execute(
73 "INSERT INTO embedding_meta (id, model_id, dimension, created_at) VALUES (1, ?, ?, ?)",
74 (model_id, dimension, now),
75 )
76 return
78 stored_model = row["model_id"]
79 stored_dim = int(row["dimension"])
80 if stored_model != model_id or stored_dim != dimension:
81 raise EmbeddingModelMismatch(
82 f"Embedding model mismatch: stored=({stored_model!r}, dim={stored_dim}) "
83 f"!= configured=({model_id!r}, dim={dimension}). "
84 "Run `stigmem embed reindex` to re-embed all facts with the new model, "
85 "or restore the original model configuration."
86 )
89# ---------------------------------------------------------------------------
90# Encode / decode vectors for sqlite-vec BLOB storage
91# ---------------------------------------------------------------------------
def _encode_vector(vec: Vector) -> bytes:
    """Serialise *vec* as packed little-endian 32-bit floats.

    The resulting BLOB holds ``len(vec)`` IEEE 754 single-precision values
    back to back, which is the form handed to sqlite-vec by the callers in
    this module.
    """
    layout = struct.Struct(f"<{len(vec)}f")
    return layout.pack(*vec)
99# ---------------------------------------------------------------------------
100# Write-path: store a single embedding
101# ---------------------------------------------------------------------------
def store_embedding(conn: Any, fact_id: str, vec: Vector) -> None:
    """Write *vec* into ``vec_facts`` for *fact_id* and clear its missing flag.

    The vector is encoded with ``_encode_vector`` and written with
    ``INSERT OR REPLACE``. Afterwards ``set_embedding_status`` is called with
    ``embedding_missing=False`` for the same fact.

    Args:
        conn: Connection-like object exposing ``execute``.
        fact_id: Identifier of the fact the vector belongs to.
        vec: Embedding vector to persist.
    """
    from ..lifecycle.immutability import set_embedding_status

    upsert_sql = "INSERT OR REPLACE INTO vec_facts (fact_id, embedding) VALUES (?, ?)"
    conn.execute(upsert_sql, (fact_id, _encode_vector(vec)))
    set_embedding_status(conn, fact_id=fact_id, embedding_missing=False)
def embed_and_store_fact(
    fact_id: str,
    entity: str,
    relation: str,
    value_type: str,
    value_v: str,
    conn: Any,
    model: EmbeddingModel,
) -> None:
    """Compute the embedding for a single fact and save it.

    The fact's fields are turned into text by ``compose_triple_text``, the
    text is embedded as a one-element batch, and the first returned vector is
    passed to ``store_embedding``. Transaction handling is left to the caller.

    Args:
        fact_id: Identifier of the fact being embedded.
        entity: Entity part of the fact.
        relation: Relation part of the fact.
        value_type: Type tag of the fact's value.
        value_v: The fact's value.
        conn: Connection-like object forwarded to ``store_embedding``.
        model: Embedding model exposing ``embed``.
    """
    from ..embedding.base import compose_triple_text

    triple_text = compose_triple_text(entity, relation, value_type, value_v)
    embedding = model.embed([triple_text])[0]
    store_embedding(conn, fact_id, embedding)
133# ---------------------------------------------------------------------------
134# Backfill job
135# ---------------------------------------------------------------------------
def backfill_missing_embeddings(
    conn: Any,
    model: EmbeddingModel,
    limit: int = 500,
) -> int:
    """Embed facts that are still flagged ``embedding_missing = 1``.

    Selects up to *limit* facts whose effective missing flag is 1 and whose
    confidence is above 0.1. The flag is taken from ``fact_embedding_status``
    when a row exists there, otherwise from the fact itself. Each fact is
    embedded and stored individually; a failure on one fact is logged as a
    warning and does not stop the remaining ones.

    Args:
        conn: Connection-like object whose rows support access by column name.
        model: Embedding model exposing ``embed``.
        limit: Maximum number of facts to select in this run.

    Returns:
        How many facts were embedded and stored without raising.
    """
    from ..embedding.base import compose_triple_text

    pending = conn.execute(
        """SELECT f.id, f.entity, f.relation, f.value_type, f.value_v
           FROM facts f
           LEFT JOIN fact_embedding_status fes ON fes.fact_id = f.id
           WHERE COALESCE(fes.embedding_missing, f.embedding_missing) = 1
             AND f.confidence > 0.1
           LIMIT ?""",
        (limit,),
    ).fetchall()

    embedded = 0
    for fact in pending:
        try:
            # A NULL value_v is embedded as the empty string.
            triple_text = compose_triple_text(
                fact["entity"],
                fact["relation"],
                fact["value_type"],
                fact["value_v"] or "",
            )
            vector = model.embed([triple_text])[0]
            store_embedding(conn, fact["id"], vector)
        except Exception as err:
            # Best effort: record the failure and move on to the next fact.
            logger.warning("Backfill embed failed for fact %s: %s", fact["id"], err)
        else:
            embedded += 1

    return embedded
175# ---------------------------------------------------------------------------
176# Query-path: vector_search
177# ---------------------------------------------------------------------------
def vector_search(
    query_embedding: Vector,
    k: int = 10,
    scope_filter: str | None = None,
    tenant_id: str = "default",
    conn: Any = None,
) -> list[tuple[FactRecord, float]]:
    """Find the facts nearest to *query_embedding* and score them.

    *query_embedding* MUST be L2-normalised; stored vectors are expected to
    be normalised as well (design memo §2). Each hit's score is
    ``max(0.0, 1.0 - distance)``, where ``distance`` is the value reported by
    ``vec_facts`` for that row.

    NOTE(review): ``ensure_vec_table`` declares ``vec_facts`` without an
    explicit distance metric. Confirm that ``distance`` is a cosine distance;
    otherwise ``1 - distance`` is not a cosine similarity.

    NOTE(review): the tenant, scope and confidence predicates sit in the same
    query as the ``k = ?`` KNN constraint. Presumably the KNN limit is applied
    first, so fewer than *k* rows may be returned. Confirm against sqlite-vec.

    Args:
        query_embedding: Unit-length query vector.
        k: Value bound to the KNN ``k`` constraint (maximum number of hits).
        scope_filter: When truthy, only facts with this scope are returned.
        tenant_id: Tenant whose facts are searched.
        conn: Existing DB connection to use; when ``None`` one is opened via
            ``db()`` for the duration of the query.

    Returns:
        ``(FactRecord, similarity)`` pairs in ascending distance order, i.e.
        highest similarity first.
    """
    from ..db import db
    from ..models.facts import row_to_record

    query_blob = _encode_vector(query_embedding)

    knn_sql = """
        SELECT f.*, v.distance,
               COALESCE(fvo.valid_until, f.valid_until) AS projected_valid_until,
               COALESCE(fvo.confidence, f.confidence) AS projected_confidence,
               COALESCE(fgm.garden_id, f.garden_id) AS projected_garden_id,
               COALESCE(fqs.quarantine_status, f.quarantine_status)
                   AS projected_quarantine_status,
               COALESCE(fqs.quarantine_garden_id, f.quarantine_garden_id)
                   AS projected_quarantine_garden_id,
               COALESCE(f.cid, (
                   SELECT fca.cid FROM fact_cid_aliases fca
                   WHERE fca.fact_id = f.id ORDER BY fca.cid LIMIT 1
               )) AS projected_cid
        FROM vec_facts v
        JOIN facts f ON f.id = v.fact_id
        LEFT JOIN fact_validity_overrides fvo ON fvo.fact_id = f.id
        LEFT JOIN fact_garden_membership fgm ON fgm.fact_id = f.id
        LEFT JOIN fact_quarantine_status fqs ON fqs.fact_id = f.id
        WHERE v.embedding MATCH ?
          AND k = ?
          AND COALESCE(fvo.confidence, f.confidence) > 0.1
          AND f.tenant_id = ?
    """
    bindings: tuple[Any, ...] = (query_blob, k, tenant_id)
    if scope_filter:
        # Optional scope restriction is appended as one more bound predicate.
        knn_sql += " AND f.scope = ?"
        bindings = (*bindings, scope_filter)
    knn_sql += " ORDER BY v.distance"

    def _to_hit(hit: Any) -> tuple[FactRecord, float]:
        # Clamp at zero so a distance above 1 never yields a negative score.
        score = max(0.0, 1.0 - float(hit["distance"]))
        return row_to_record(hit), score

    def _search(handle: Any) -> list[tuple[FactRecord, float]]:
        return [_to_hit(hit) for hit in handle.execute(knn_sql, bindings).fetchall()]

    if conn is None:
        # No connection supplied: borrow one just for this query.
        with db() as owned:
            return _search(owned)
    return _search(conn)