Coverage for node / src / stigmem_node / storage / postgres_backend.py: 24%
239 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-25 01:49 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-25 01:49 +0000
1"""PostgreSQL implementation of StorageBackend — feature-flagged (Phase 11).
3Uses psycopg2 with a SQLite-API-compatible wrapper so all existing SQL
4(written for ``?`` placeholders and ``row["col"]`` access) works without
5modification. SQL is translated on the fly:
7 * ``?`` → ``%s`` (psycopg2 placeholder style)
8 * Literal ``%`` in SQL literals → ``%%`` (psycopg2 escaping)
9 * ``INSERT OR IGNORE`` → ``INSERT … ON CONFLICT DO NOTHING``
10 * ``INSERT OR REPLACE`` → ``INSERT … ON CONFLICT (pk) DO UPDATE SET …``
11 * ``AUTOINCREMENT`` → ``SERIAL`` in migration DDL
13The backend applies migrations from a ``migrations_pg/`` sibling directory
14when a per-version override exists there, falling back to the standard file.
15Overrides handle SQLite-specific DDL: PRAGMA, FTS5 virtual tables, GLOB
16patterns, and table-rebuild workarounds for CHECK constraint changes.
18Install before use::
20 pip install 'stigmem-node[postgres]'
22Environment variables::
24 STIGMEM_BACKEND=postgres
25 DATABASE_URL=postgresql://user:pass@host:5432/dbname
26 # or equivalently STIGMEM_DATABASE_URL=...
28Per-test schema isolation::
30 PostgresBackend(dsn=..., schema="test_mytest_abc123")
32The backend creates the schema on first ``apply_migrations()`` call and the
33test harness can DROP it when the test finishes.
35Note on asyncpg
36---------------
37The issue spec originally mentioned asyncpg for connection pooling, but asyncpg
38is async-only and incompatible with the synchronous ``StorageBackend`` interface.
39psycopg2 with ``ThreadedConnectionPool`` provides equivalent pooling for the
40sync path. A future async backend can wrap asyncpg independently.
41"""
43from __future__ import annotations
45import logging
46import re
47from collections.abc import Generator
48from contextlib import contextmanager
49from datetime import UTC, datetime
50from pathlib import Path
51from typing import Any
53from .base import StorageBackend
55logger = logging.getLogger("stigmem.storage.postgres")
57_SCHEMA_NAME_RE = re.compile(r"^[a-z][a-z0-9_]{0,62}$")
60def _validate_schema_name(schema: str) -> str:
61 """Validate a Postgres schema identifier accepted by this backend."""
62 if not _SCHEMA_NAME_RE.fullmatch(schema):
63 raise ValueError(
64 "PostgresBackend schema name must match [a-z][a-z0-9_]{0,62}; "
65 f"got {schema!r}"
66 )
67 return schema
70# ---------------------------------------------------------------------------
71# Primary-key map for INSERT OR REPLACE rewriting
72# ---------------------------------------------------------------------------
74# Maps (lowercase) table name → list of primary key columns.
75# Add an entry here when a new table uses INSERT OR REPLACE.
76_TABLE_PK: dict[str, list[str]] = {
77 "node_meta": ["key"],
78 "entity_aliases": ["raw_uri"],
79 "vec_facts": ["fact_id"],
80 "boot_stubs": ["agent_id", "adapter_profile"],
81 "schema_migrations": ["version"],
82}
84# ---------------------------------------------------------------------------
85# SQL transpilation helpers
86# ---------------------------------------------------------------------------
88_OR_IGNORE_RE = re.compile(r"\bINSERT\s+OR\s+IGNORE\b", re.IGNORECASE)
89_OR_REPLACE_RE = re.compile(
90 r"\bINSERT\s+OR\s+REPLACE\s+INTO\s+(\w+)\s*\(([^)]+)\)",
91 re.IGNORECASE,
92)
93# SQLite strftime('%s', col) → EXTRACT(EPOCH FROM col::timestamptz)
94# Must be translated before the % → %% escaping step.
95# Bounded quantifiers — defends against the ``py/polynomial-redos`` heuristic
96# (CodeQL #21). Inputs are developer-authored migration SQL in practice, but
97# bounding ``\s{0,16}`` and ``[^)]{1,256}?`` removes any theoretical
98# superlinear-backtracking case and quiets the analyzer permanently.
99_STRFTIME_EPOCH_RE = re.compile(r"strftime\('%s',\s{0,16}([^)]{1,256}?)\)", re.IGNORECASE)
102def _rewrite_or_ignore(sql: str) -> str:
103 """INSERT OR IGNORE … → INSERT … ON CONFLICT DO NOTHING."""
104 if not _OR_IGNORE_RE.search(sql):
105 return sql
106 sql = _OR_IGNORE_RE.sub("INSERT", sql)
107 return sql.rstrip().rstrip(";") + " ON CONFLICT DO NOTHING"
110def _rewrite_or_replace(sql: str) -> str:
111 """INSERT OR REPLACE INTO table (cols) VALUES (…) → Postgres upsert."""
112 m = _OR_REPLACE_RE.search(sql)
113 if not m:
114 return sql
116 table = m.group(1).lower()
117 cols = [c.strip() for c in m.group(2).split(",")]
118 pk_cols = _TABLE_PK.get(table, [])
119 pk_set = set(pk_cols)
120 update_cols = [c for c in cols if c not in pk_set]
122 # Strip 'OR REPLACE' and rebuild
123 sql = _OR_REPLACE_RE.sub(
124 lambda mx: f"INSERT INTO {mx.group(1)} ({mx.group(2)})",
125 sql,
126 )
127 sql = sql.rstrip().rstrip(";")
129 if pk_cols:
130 conflict_target = ", ".join(pk_cols)
131 if update_cols:
132 set_clauses = ", ".join(f"{c} = EXCLUDED.{c}" for c in update_cols)
133 sql += f"\nON CONFLICT ({conflict_target}) DO UPDATE SET {set_clauses}" # nosec B608 — conflict_target and update_cols are column names from _TABLE_PK (hardcoded schema dict), not user input
134 else:
135 sql += f"\nON CONFLICT ({conflict_target}) DO NOTHING"
136 else:
137 logger.warning(
138 "INSERT OR REPLACE for unknown table %r — add it to _TABLE_PK; "
139 "falling back to ON CONFLICT DO NOTHING",
140 table,
141 )
142 sql += " ON CONFLICT DO NOTHING"
144 return sql
147def _pg_translate(sql: str) -> str:
148 """Translate a SQLite DML/DDL string to psycopg2/Postgres format.
150 Applied in order:
151 1. Rewrite ``INSERT OR IGNORE`` and ``INSERT OR REPLACE``.
152 2. Translate ``strftime('%s', col)`` → ``EXTRACT(EPOCH FROM col::timestamptz)``.
153 Must happen before step 3 so the ``%s`` inside strftime is not mangled.
154 3. Escape literal ``%`` → ``%%`` (psycopg2 treats bare ``%`` as special).
155 4. Replace ``?`` parameter placeholders with ``%s``.
156 5. Translate ``INTEGER PRIMARY KEY AUTOINCREMENT`` → ``SERIAL PRIMARY KEY``.
157 """
158 if _OR_IGNORE_RE.search(sql):
159 sql = _rewrite_or_ignore(sql)
160 elif _OR_REPLACE_RE.search(sql):
161 sql = _rewrite_or_replace(sql)
163 sql = _STRFTIME_EPOCH_RE.sub(r"EXTRACT(EPOCH FROM \1::timestamptz)", sql)
164 sql = sql.replace("%", "%%")
165 sql = sql.replace("?", "%s")
166 sql = re.sub(
167 r"INTEGER\s+PRIMARY\s+KEY\s+AUTOINCREMENT",
168 "SERIAL PRIMARY KEY",
169 sql,
170 flags=re.IGNORECASE,
171 )
172 return sql
175def _pg_split_migration(sql: str) -> list[str]:
176 """Split a migration script into Postgres-executable statements.
178 Strips comments, then splits on ``;``. Filters out SQLite-specific
179 statements (PRAGMA, CREATE VIRTUAL TABLE, fts5 triggers, bare
180 transaction keywords left by trigger-body splits). Remaining
181 statements are passed through ``_pg_translate()``.
183 When a ``migrations_pg/`` override file is used, this function is still
184 called on the override SQL — the override files contain clean Postgres DDL
185 so filtering is effectively a no-op for them.
186 """
187 sql = re.sub(r"--[^\n]*", "", sql)
188 sql = re.sub(r"/\*.*?\*/", "", sql, flags=re.DOTALL)
190 statements: list[str] = []
191 for raw in sql.split(";"):
192 stmt = raw.strip()
193 if not stmt:
194 continue
195 upper = stmt.upper()
196 # Bare transaction / PL body delimiters left from trigger splits
197 if upper in ("BEGIN", "END", "COMMIT", "ROLLBACK"):
198 continue
199 # SQLite PRAGMA — no Postgres equivalent
200 if re.match(r"\s*PRAGMA\b", stmt, re.IGNORECASE):
201 continue
202 # FTS5 virtual table
203 if re.search(r"USING\s+fts5", stmt, re.IGNORECASE):
204 continue
205 # Any statement touching facts_fts (SQLite FTS5 triggers / inserts)
206 if re.search(r"\bfacts_fts\b", stmt, re.IGNORECASE):
207 continue
208 # Generic CREATE VIRTUAL TABLE guard
209 if re.search(r"CREATE\s+VIRTUAL\s+TABLE", stmt, re.IGNORECASE):
210 continue
211 statements.append(_pg_translate(stmt))
213 return statements
216# ---------------------------------------------------------------------------
217# SQLite-API-compatible row wrapper
218# ---------------------------------------------------------------------------
221class _PGRow:
222 """Dict-like wrapper around a psycopg2 ``RealDictRow``.
224 Supports ``row["col"]``, ``row[i]`` (by position), ``row.keys()``, and
225 ``row.get(key, default)`` — the same contract as ``sqlite3.Row``.
226 """
228 __slots__ = ("_d", "_vals")
230 def __init__(self, d: dict[str, Any], vals: tuple[Any, ...]) -> None:
231 self._d = d
232 self._vals = vals
234 def __getitem__(self, key: str | int) -> Any:
235 if isinstance(key, int):
236 return self._vals[key]
237 return self._d[key]
239 def __iter__(self) -> Any:
240 return iter(self._vals)
242 def keys(self) -> list[str]:
243 return list(self._d.keys())
245 def get(self, key: str, default: Any = None) -> Any:
246 return self._d.get(key, default)
249# ---------------------------------------------------------------------------
250# Cursor wrapper
251# ---------------------------------------------------------------------------
254class _PGCursor:
255 """Wraps a psycopg2 ``RealDictCursor`` to match the sqlite3 cursor API."""
257 def __init__(self, cur: Any) -> None:
258 self._cur = cur
260 def fetchall(self) -> list[_PGRow]:
261 rows = self._cur.fetchall()
262 return [_PGRow(dict(r), tuple(r.values())) for r in rows]
264 def fetchone(self) -> _PGRow | None:
265 r = self._cur.fetchone()
266 if r is None:
267 return None
268 return _PGRow(dict(r), tuple(r.values()))
270 def __iter__(self) -> Any:
271 for r in self._cur:
272 yield _PGRow(dict(r), tuple(r.values()))
274 @property
275 def rowcount(self) -> int:
276 return self._cur.rowcount # type: ignore[no-any-return]
279# ---------------------------------------------------------------------------
280# Connection wrapper
281# ---------------------------------------------------------------------------
284class _PGConn:
285 """SQLite-API-compatible wrapper around a psycopg2 connection.
287 Creates a fresh ``RealDictCursor`` per ``execute()`` call so multiple
288 cursors can be open concurrently, matching sqlite3 semantics.
289 """
291 def __init__(self, pg_conn: Any) -> None:
292 self._conn = pg_conn
294 def execute(self, sql: str, params: Any = ()) -> _PGCursor:
295 import psycopg2.extras
297 translated = _pg_translate(sql)
298 cur = self._conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor)
299 cur.execute(translated, params or ())
300 return _PGCursor(cur)
302 def executemany(self, sql: str, seq: Any) -> None:
303 translated = _pg_translate(sql)
304 cur = self._conn.cursor()
305 cur.executemany(translated, seq)
307 def executescript(self, sql: str) -> None:
308 """Execute a SQL script (multiple statements separated by ';')."""
309 for stmt in _pg_split_migration(sql):
310 cur = self._conn.cursor()
311 cur.execute(stmt)
313 def commit(self) -> None:
314 self._conn.commit()
316 def rollback(self) -> None:
317 self._conn.rollback()
319 def close(self) -> None:
320 # Pool connection — returned to pool by the context manager; do nothing.
321 pass
324# ---------------------------------------------------------------------------
325# PostgresBackend
326# ---------------------------------------------------------------------------
class PostgresBackend(StorageBackend):
    """PostgreSQL backend using psycopg2 with a ``ThreadedConnectionPool``.

    Args:
        dsn: libpq connection string, e.g.
            ``"postgresql://user:pw@localhost/stigmem"``.
        schema: Postgres schema for all tables (default ``"public"``).
            Use a unique value per test run for schema-level isolation.
        pool_min: Minimum pool size (default 2).
        pool_max: Maximum pool size (default 10).
        embed_enabled: When True, creates a pgvector ``vec_facts`` table.
        embed_dimension: Vector dimension (must match embedding model).

    Raises:
        ValueError: If *schema* is not an accepted schema identifier.
    """

    def __init__(
        self,
        dsn: str,
        schema: str = "public",
        pool_min: int = 2,
        pool_max: int = 10,
        embed_enabled: bool = False,
        embed_dimension: int = 768,
    ) -> None:
        self._dsn = dsn
        self._schema = _validate_schema_name(schema)
        self._pool_min = pool_min
        self._pool_max = pool_max
        self._embed_enabled = embed_enabled
        self._embed_dimension = embed_dimension
        # Created lazily by _get_pool() so constructing a backend opens no
        # connections.
        self._pool: Any = None

    @property
    def backend_name(self) -> str:
        return "postgres"

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _get_pool(self) -> Any:
        """Lazily create and return the thread-safe psycopg2 connection pool.

        Raises:
            RuntimeError: If psycopg2 is not installed.
        """
        if self._pool is not None:
            return self._pool
        try:
            import psycopg2.pool
        except ImportError as exc:
            raise RuntimeError(
                "psycopg2 is required for the PostgreSQL backend. "
                "Install it with: pip install 'stigmem-node[postgres]'"
            ) from exc
        self._pool = psycopg2.pool.ThreadedConnectionPool(
            self._pool_min,
            self._pool_max,
            self._dsn,
        )
        return self._pool

    def _open_raw_conn(self) -> Any:
        """Open a direct psycopg2 connection (used by apply_migrations).

        Raises:
            RuntimeError: If psycopg2 is not installed.
        """
        try:
            import psycopg2
        except ImportError as exc:
            raise RuntimeError(
                "psycopg2 is required for the PostgreSQL backend. "
                "Install it with: pip install 'stigmem-node[postgres]'"
            ) from exc
        return psycopg2.connect(self._dsn)

    def _set_search_path(self, conn: Any) -> None:
        """Set the active schema using identifier quoting."""
        from psycopg2 import sql

        with conn.cursor() as cur:
            cur.execute(
                sql.SQL("SET search_path TO {}").format(sql.Identifier(self._schema))
            )

    def _pg_migrations(self, migrations_dir: Path) -> list[Path]:
        """Ordered migration files, preferring Postgres-specific overrides.

        Looks for a ``migrations_pg/`` sibling to *migrations_dir*. For each
        version present there, the pg-specific file takes precedence. Only
        versions that exist in *migrations_dir* are returned; the order is the
        sorted order of the files in *migrations_dir*.
        """
        pg_dir = migrations_dir.parent / "migrations_pg"
        overrides: dict[str, Path] = {}
        if pg_dir.is_dir():
            overrides = {f.stem: f for f in pg_dir.glob("*.sql")}

        return [overrides.get(f.stem, f) for f in sorted(migrations_dir.glob("*.sql"))]

    # ------------------------------------------------------------------
    # StorageBackend interface
    # ------------------------------------------------------------------

    @contextmanager
    def connection(self) -> Generator[_PGConn, None, None]:
        """Yield a pooled connection wrapped in the SQLite-compatible API.

        The search path is set to this backend's schema before the connection
        is handed out. The transaction is committed when the ``with`` body
        finishes normally and rolled back if it raises; the connection is
        always returned to the pool.
        """
        pool = self._get_pool()
        pg_conn = pool.getconn()
        wrapped = _PGConn(pg_conn)
        try:
            self._set_search_path(pg_conn)
            yield wrapped
            pg_conn.commit()
        except Exception:
            pg_conn.rollback()
            raise
        finally:
            pool.putconn(pg_conn)

    def apply_migrations(self, migrations_dir: Path) -> None:
        """Create the schema if needed and apply every pending migration.

        Each migration runs in its own transaction together with its
        ``schema_migrations`` bookkeeping row; a failure rolls that migration
        back and re-raises. When embeddings are enabled the pgvector table is
        ensured afterwards.

        Args:
            migrations_dir: Directory holding the standard ``*.sql`` files.
        """
        conn = self._open_raw_conn()
        try:
            # Ensure the target schema exists (for per-test schema isolation).
            from psycopg2 import sql

            with conn.cursor() as cur:
                cur.execute(
                    sql.SQL("CREATE SCHEMA IF NOT EXISTS {}").format(
                        sql.Identifier(self._schema)
                    )
                )
            self._set_search_path(conn)
            conn.commit()

            # Bootstrap schema_migrations table.
            with conn.cursor() as cur:
                cur.execute(
                    """
                    CREATE TABLE IF NOT EXISTS schema_migrations (
                        id SERIAL PRIMARY KEY,
                        version TEXT NOT NULL UNIQUE,
                        applied_at TEXT NOT NULL
                    )
                    """
                )
            conn.commit()

            import psycopg2.extras

            with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cur:
                cur.execute("SELECT version FROM schema_migrations")
                applied: set[str] = {r["version"] for r in cur.fetchall()}

            for f in self._pg_migrations(migrations_dir):
                version = f.stem
                if version in applied:
                    continue

                logger.info("Applying migration %s (%s)", version, f.name)
                # Read as UTF-8 explicitly so the result does not depend on
                # the platform's default encoding.
                stmts = _pg_split_migration(f.read_text(encoding="utf-8"))
                try:
                    with conn.cursor() as cur:
                        for stmt in stmts:
                            # _pg_split_migration() doubles every ``%``; the
                            # empty parameter tuple makes psycopg2 collapse
                            # ``%%`` back to ``%`` before sending the
                            # statement. Without it the doubled form would
                            # reach the server verbatim.
                            cur.execute(stmt, ())
                    with conn.cursor() as cur:
                        cur.execute(
                            "INSERT INTO schema_migrations (version, applied_at) VALUES (%s, %s)",
                            (version, datetime.now(UTC).isoformat()),
                        )
                    conn.commit()
                    logger.info("Migration %s applied", version)
                except Exception:
                    conn.rollback()
                    raise

            if self._embed_enabled:
                self._ensure_vec_table(conn)
        finally:
            conn.close()

    def _ensure_vec_table(self, conn: Any) -> None:
        """Create the pgvector ``vec_facts`` table and index (idempotent).

        The table is committed before the index is attempted. Index creation
        is best-effort: a failure is logged and rolled back, and must not take
        the table down with it.

        Raises:
            RuntimeError: If the ``pgvector`` Python package is not installed.
        """
        try:
            from pgvector.psycopg2 import register_vector
        except ImportError as exc:
            raise RuntimeError(
                "pgvector is required for Postgres vector search. "
                "Install it with: pip install 'stigmem-node[postgres]'"
            ) from exc

        register_vector(conn)
        # Coerce to int so only a number can ever be interpolated into the DDL.
        dim = int(self._embed_dimension)
        with conn.cursor() as cur:
            cur.execute(
                f"""
                CREATE TABLE IF NOT EXISTS vec_facts (
                    fact_id TEXT PRIMARY KEY,
                    embedding vector({dim})
                )
                """
            )
        # Commit the table on its own: the rollback in the index error path
        # below would otherwise undo the CREATE TABLE as well.
        conn.commit()

        try:
            with conn.cursor() as cur:
                cur.execute(
                    """
                    CREATE INDEX IF NOT EXISTS vec_facts_embedding_idx
                    ON vec_facts USING ivfflat (embedding vector_cosine_ops)
                    WITH (lists = 100)
                    """
                )
        except Exception as exc:  # noqa: BLE001
            logger.warning("Could not create ivfflat index on vec_facts: %s", exc)
            conn.rollback()
        else:
            conn.commit()

    # ------------------------------------------------------------------
    # Snapshot hooks — Postgres backup is an operator concern
    # ------------------------------------------------------------------

    def export_snapshot(self, dest: Path) -> None:
        """Not supported; always raises ``NotImplementedError``."""
        raise NotImplementedError(
            "PostgresBackend does not support snapshot export. "
            "Use pg_dump or your cloud provider's managed backup tooling."
        )

    def import_snapshot(self, src: Path) -> None:
        """Not supported; always raises ``NotImplementedError``."""
        raise NotImplementedError(
            "PostgresBackend does not support snapshot import. "
            "Use pg_restore or your cloud provider's managed restore tooling."
        )