Coverage for node / src / stigmem_node / plugin_migrations.py: 92%
92 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-25 01:49 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-25 01:49 +0000
1"""Plugin migration lifecycle ledger and application."""
3from __future__ import annotations
5import hashlib
6from collections.abc import Iterable
7from dataclasses import dataclass
8from datetime import UTC, datetime
9from typing import Any
11from packaging.version import InvalidVersion, Version
13from .plugins import Migration, PluginMigrationError
14from .storage import StorageBackend
17@dataclass(frozen=True, slots=True)
18class _MigrationRecord:
19 plugin_name: str
20 plugin_version: str
21 migration_id: int
22 backend: str
23 checksum: str
def apply_registered_plugin_migrations(
    backend: StorageBackend,
    migrations: Iterable[Migration],
    *,
    plugin_order: Iterable[str] = (),
    plugin_versions: dict[str, str] | None = None,
) -> None:
    """Apply plugin-declared migrations with checksum and downgrade checks.

    Only migrations whose ``backend`` equals ``backend.backend_name`` are
    considered. They are ordered, validated against the ``plugin_migrations``
    ledger, and every migration not already recorded there is executed and
    then recorded.

    Args:
        backend: Storage backend whose ``connection()`` context manager
            yields the connection used for all ledger reads/writes.
        migrations: Declared migrations from all plugins.
        plugin_order: Plugin names that should be migrated first, in order.
        plugin_versions: Optional plugin-name -> version mapping, used to
            stamp migrations that still carry the ``"0.0.0"`` placeholder.

    Raises:
        PluginMigrationError: if two declarations share a key but have
            different SQL, if a recorded migration's checksum no longer
            matches its declared SQL, or if the declared migrations/version
            are older than what the ledger records.
    """
    ordered = _ordered_migrations(
        migrations,
        backend_name=backend.backend_name,
        plugin_order=plugin_order,
        plugin_versions=plugin_versions or {},
    )
    with backend.connection() as conn:
        _ensure_plugin_migrations_table(conn)
        applied = _load_applied(conn, backend.backend_name)
        _validate_no_duplicate_declarations(ordered)
        _validate_no_downgrades(ordered, applied)
        for migration in ordered:
            key = (migration.plugin_name, migration.backend, migration.migration_id)
            checksum = _checksum(migration.sql)
            existing = applied.get(key)
            if existing is not None:
                # Already applied: the SQL must not have changed since then.
                if existing.checksum != checksum:
                    raise PluginMigrationError(
                        "plugin migration checksum mismatch for "
                        f"{migration.plugin_name}:{migration.migration_id}"
                    )
                continue
            _execute_migration_sql(conn, migration)
            conn.execute(
                "INSERT INTO plugin_migrations "
                "(plugin_name, plugin_version, migration_id, backend, checksum, "
                "description, applied_at) VALUES (?, ?, ?, ?, ?, ?, ?)",
                (
                    migration.plugin_name,
                    migration.plugin_version,
                    migration.migration_id,
                    migration.backend,
                    checksum,
                    migration.description,
                    datetime.now(UTC).isoformat(),
                ),
            )
            # Track the migration in memory as well. Identical duplicate
            # declarations are allowed by the duplicate check, and without
            # this a later copy would re-run the SQL and then violate the
            # ledger's UNIQUE(plugin_name, backend, migration_id) constraint.
            applied[key] = _MigrationRecord(
                plugin_name=migration.plugin_name,
                plugin_version=migration.plugin_version,
                migration_id=migration.migration_id,
                backend=migration.backend,
                checksum=checksum,
            )
74def _execute_migration_sql(conn: Any, migration: Migration) -> None:
75 if migration.backend == "sqlite" and hasattr(conn, "executescript"): 75 ↛ 78line 75 didn't jump to line 78 because the condition on line 75 was always true
76 conn.executescript(migration.sql)
77 return
78 conn.execute(migration.sql)
81def _ensure_plugin_migrations_table(conn: Any) -> None:
82 conn.execute(
83 """CREATE TABLE IF NOT EXISTS plugin_migrations (
84 id INTEGER PRIMARY KEY AUTOINCREMENT,
85 plugin_name TEXT NOT NULL,
86 plugin_version TEXT NOT NULL,
87 migration_id INTEGER NOT NULL,
88 backend TEXT NOT NULL,
89 checksum TEXT NOT NULL,
90 description TEXT NOT NULL DEFAULT '',
91 applied_at TEXT NOT NULL,
92 UNIQUE(plugin_name, backend, migration_id)
93 )"""
94 )
95 conn.execute(
96 "CREATE INDEX IF NOT EXISTS idx_plugin_migrations_plugin "
97 "ON plugin_migrations(plugin_name, backend)"
98 )
def _load_applied(
    conn: Any,
    backend_name: str,
) -> dict[tuple[str, str, int], _MigrationRecord]:
    """Read every ledger row recorded for *backend_name*.

    Rows are accessed by column name, so *conn* must yield mapping-style rows
    (presumably ``sqlite3.Row`` or similar; confirm against the backend).

    Returns:
        A mapping from ``(plugin_name, backend, migration_id)`` to the
        corresponding :class:`_MigrationRecord`.
    """
    cursor = conn.execute(
        "SELECT plugin_name, plugin_version, migration_id, backend, checksum "
        "FROM plugin_migrations WHERE backend = ?",
        (backend_name,),
    )
    records: dict[tuple[str, str, int], _MigrationRecord] = {}
    for row in cursor.fetchall():
        record = _MigrationRecord(
            plugin_name=row["plugin_name"],
            plugin_version=row["plugin_version"],
            migration_id=int(row["migration_id"]),
            backend=row["backend"],
            checksum=row["checksum"],
        )
        records[(record.plugin_name, record.backend, record.migration_id)] = record
    return records
def _ordered_migrations(
    migrations: Iterable[Migration],
    *,
    backend_name: str,
    plugin_order: Iterable[str],
    plugin_versions: dict[str, str],
) -> list[Migration]:
    """Select the migrations for *backend_name* and sort them for application.

    Each kept migration is passed through ``_with_plugin_version`` with
    *plugin_versions*. Sorting puts plugins listed in *plugin_order* first, in
    that order. All other plugins follow, ordered by plugin name. Within a
    plugin, migrations are sorted by ascending ``migration_id``.
    """
    # A name repeated in plugin_order keeps its last position (dict overwrite).
    order_index = {name: idx for idx, name in enumerate(plugin_order)}
    # Unlisted plugins must rank after every listed one. len(order_index)
    # collides with a real index when plugin_order has duplicates, so rank
    # them one past the largest assigned index instead.
    unlisted_rank = max(order_index.values(), default=-1) + 1
    filtered = [
        _with_plugin_version(migration, plugin_versions)
        for migration in migrations
        if migration.backend == backend_name
    ]
    return sorted(
        filtered,
        key=lambda migration: (
            order_index.get(migration.plugin_name, unlisted_rank),
            migration.plugin_name,
            migration.migration_id,
        ),
    )
145def _with_plugin_version(
146 migration: Migration,
147 plugin_versions: dict[str, str],
148) -> Migration:
149 if migration.plugin_version != "0.0.0":
150 return migration
151 version = plugin_versions.get(migration.plugin_name)
152 if version is None: 152 ↛ 153line 152 didn't jump to line 153 because the condition on line 152 was never true
153 return migration
154 return Migration(
155 plugin_name=migration.plugin_name,
156 migration_id=migration.migration_id,
157 sql=migration.sql,
158 description=migration.description,
159 plugin_version=version,
160 backend=migration.backend,
161 )
164def _validate_no_duplicate_declarations(migrations: list[Migration]) -> None:
165 seen: dict[tuple[str, str, int], str] = {}
166 for migration in migrations:
167 key = (migration.plugin_name, migration.backend, migration.migration_id)
168 checksum = _checksum(migration.sql)
169 previous = seen.get(key)
170 if previous is not None and previous != checksum: 170 ↛ 171line 170 didn't jump to line 171 because the condition on line 170 was never true
171 raise PluginMigrationError(
172 "duplicate plugin migration declaration with different SQL for "
173 f"{migration.plugin_name}:{migration.migration_id}"
174 )
175 seen[key] = checksum
178def _validate_no_downgrades(
179 migrations: list[Migration],
180 applied: dict[tuple[str, str, int], _MigrationRecord],
181) -> None:
182 latest_applied_id: dict[tuple[str, str], int] = {}
183 latest_applied_version: dict[tuple[str, str], Version] = {}
184 for record in applied.values():
185 plugin_key = (record.plugin_name, record.backend)
186 latest_applied_id[plugin_key] = max(
187 latest_applied_id.get(plugin_key, record.migration_id),
188 record.migration_id,
189 )
190 latest_applied_version[plugin_key] = max(
191 latest_applied_version.get(plugin_key, _parse_version(record.plugin_version)),
192 _parse_version(record.plugin_version),
193 )
195 declared_ids: dict[tuple[str, str], set[int]] = {}
196 declared_versions: dict[tuple[str, str], Version] = {}
197 for migration in migrations:
198 plugin_key = (migration.plugin_name, migration.backend)
199 declared_ids.setdefault(plugin_key, set()).add(migration.migration_id)
200 declared_versions[plugin_key] = max(
201 declared_versions.get(plugin_key, _parse_version(migration.plugin_version)),
202 _parse_version(migration.plugin_version),
203 )
205 for plugin_key, applied_id in latest_applied_id.items():
206 declared = declared_ids.get(plugin_key)
207 if declared and max(declared) < applied_id:
208 raise PluginMigrationError(
209 f"plugin migration downgrade refused for {plugin_key[0]}: "
210 f"applied migration {applied_id} is newer than declared {max(declared)}"
211 )
213 for plugin_key, applied_version in latest_applied_version.items():
214 declared_version = declared_versions.get(plugin_key)
215 if declared_version is not None and declared_version < applied_version: 215 ↛ 216line 215 didn't jump to line 216 because the condition on line 215 was never true
216 raise PluginMigrationError(
217 f"plugin version downgrade refused for {plugin_key[0]}: "
218 f"applied {applied_version} is newer than declared {declared_version}"
219 )
222def _checksum(sql: str) -> str:
223 return hashlib.sha256(sql.encode("utf-8")).hexdigest()
def _parse_version(raw: str) -> Version:
    """Parse *raw* as a PEP 440 version.

    Raises:
        PluginMigrationError: if *raw* is not a valid version string; the
            original ``InvalidVersion`` is chained as the cause.
    """
    try:
        parsed = Version(raw)
    except InvalidVersion as exc:
        raise PluginMigrationError(f"invalid plugin migration version {raw!r}") from exc
    return parsed