Coverage for node / src / stigmem_node / federation / peer_auth.py: 95%
125 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-25 01:49 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-25 01:49 +0000
1"""Ed25519 federation keypair management and peer-token authentication.
3Spec §3.5 (peer tokens), §6.6 (security invariants).
5PeerToken JWT payload:
6 iss: issuing node_id (URI)
7 sub: target node_id (URI)
8 iat: issued-at (milliseconds since epoch)
9 exp: expiry (milliseconds since epoch; MUST be <= iat + 3_600_000)
10 nonce: UUID for replay protection
11 scopes: list[FactScope]
12"""
14from __future__ import annotations
16import base64
17import json
18import time
19import uuid
20from datetime import UTC, datetime
21from typing import Any
23import jwt
24from cryptography.exceptions import InvalidSignature
25from cryptography.hazmat.primitives.asymmetric.ed25519 import (
26 Ed25519PrivateKey,
27 Ed25519PublicKey,
28)
29from cryptography.hazmat.primitives.serialization import (
30 Encoding,
31 NoEncryption,
32 PrivateFormat,
33 PublicFormat,
34 load_der_private_key,
35)
36from fastapi import HTTPException, status
38from ..db import db
39from ..settings import settings
41# ---------------------------------------------------------------------------
42# Base64url helpers
43# ---------------------------------------------------------------------------
def b64url_encode(data: bytes) -> str:
    """Encode *data* as URL-safe base64 text with the trailing ``=`` padding removed."""
    padded = base64.urlsafe_b64encode(data).decode("ascii")
    return padded.rstrip("=")
def b64url_decode(s: str) -> bytes:
    """Decode unpadded URL-safe base64 text, restoring the ``=`` padding first."""
    # -len % 4 is the number of characters needed to reach a multiple of four.
    missing = -len(s) % 4
    padded = s + "=" * missing
    return base64.urlsafe_b64decode(padded)
55# ---------------------------------------------------------------------------
56# Node federation keypair (stored in node_meta)
57# ---------------------------------------------------------------------------
def get_or_create_keypair() -> tuple[Ed25519PrivateKey, str]:
    """Load the node's Ed25519 federation key, generating and storing it if absent.

    The private key lives in ``node_meta`` under ``federation_privkey`` as
    base64url-encoded PKCS8 DER.

    Returns:
        ``(private_key, pubkey)`` where ``pubkey`` is the base64url encoding of
        the raw public key bytes.
    """
    key: Ed25519PrivateKey
    with db() as conn:
        stored = conn.execute(
            "SELECT value FROM node_meta WHERE key='federation_privkey'"
        ).fetchone()
        if not stored:
            # First call on this node: mint a key and persist it.
            key = Ed25519PrivateKey.generate()
            encoded_der = b64url_encode(
                key.private_bytes(Encoding.DER, PrivateFormat.PKCS8, NoEncryption())
            )
            conn.execute(
                "INSERT OR REPLACE INTO node_meta (key, value) VALUES ('federation_privkey', ?)",
                (encoded_der,),
            )
        else:
            key = load_der_private_key(  # type: ignore[assignment]
                b64url_decode(stored["value"]), password=None
            )
        public_raw = key.public_key().public_bytes(Encoding.Raw, PublicFormat.Raw)
        return key, b64url_encode(public_raw)
def get_federation_pubkey() -> str:
    """Return this node's Ed25519 public key as base64url text."""
    # Only the second element of the (private_key, pubkey) pair is needed.
    return get_or_create_keypair()[1]
86# ---------------------------------------------------------------------------
87# Peer token minting (subscriber → publisher)
88# ---------------------------------------------------------------------------
def mint_peer_token(
    local_node_id: str,
    target_node_id: str,
    scopes: list[str],
    ttl_ms: int = 3_600_000,
) -> str:
    """Mint an EdDSA-signed JWT peer token for pull replication.

    Args:
        local_node_id: Placed in the ``iss`` claim.
        target_node_id: Placed in the ``sub`` claim.
        scopes: Placed verbatim in the ``scopes`` claim.
        ttl_ms: Requested lifetime in milliseconds; values above
            3_600_000 are clamped to 3_600_000.

    Returns:
        The encoded JWT, signed with the key from ``get_or_create_keypair()``.
    """
    signing_key = get_or_create_keypair()[0]
    issued_at_ms = int(time.time() * 1000)
    lifetime_ms = min(ttl_ms, 3_600_000)
    # Claim order is kept stable so the serialised payload does not change.
    claims: dict[str, Any] = dict(
        iss=local_node_id,
        sub=target_node_id,
        iat=issued_at_ms,
        exp=issued_at_ms + lifetime_ms,
        nonce=str(uuid.uuid4()),
        scopes=scopes,
    )
    return jwt.encode(claims, signing_key, algorithm="EdDSA")
111# ---------------------------------------------------------------------------
112# Peer token verification (inbound — called by federation/facts endpoint)
113# ---------------------------------------------------------------------------
class PeerTokenClaims:
    """Plain holder for the claims extracted from a peer token."""

    def __init__(self, iss: str, sub: str, scopes: list[str], nonce: str) -> None:
        # iss / sub: issuer and subject claims; scopes: the token's scopes
        # claim; nonce: the token's nonce claim.
        self.iss, self.sub = iss, sub
        self.scopes, self.nonce = scopes, nonce
124def _claim_epoch_ms(payload: dict[str, Any], claim: str) -> int:
125 value = payload.get(claim)
126 if value is None:
127 raise HTTPException(
128 status_code=status.HTTP_401_UNAUTHORIZED,
129 detail=f"invalid_{claim}",
130 )
131 try:
132 return int(value)
133 except (TypeError, ValueError, OverflowError) as exc:
134 raise HTTPException(
135 status_code=status.HTTP_401_UNAUTHORIZED,
136 detail=f"invalid_{claim}",
137 ) from exc
def verify_peer_token(
    raw_token: str,
    local_node_id: str,
    audit_writer: Any | None = None,
) -> PeerTokenClaims:
    """Verify a peer JWT and return its claims.

    Checks run in this order: token shape, expiry, issuer is a known active
    peer, EdDSA signature, ``sub`` equals the local node, ``iat``/``nbf`` are
    not in the future (beyond the configured leeway), token lifetime does not
    exceed one hour, nonce is present, nonce has not been seen before.

    Args:
        raw_token: The encoded JWT presented by the peer.
        local_node_id: This node's id; the token's ``sub`` must equal it.
        audit_writer: Optional callable invoked as
            ``audit_writer(peer_id, event_type, detail)`` for expired tokens,
            signature failures and replay attempts.

    Returns:
        The verified claims. Non-string entries in ``scopes`` are dropped, and
        a non-list ``scopes`` claim yields an empty list.

    Raises:
        HTTPException: 401 on any verification failure.

    Nonce is NOT consumed here; caller must call consume_nonce() after successful auth.
    """
    # Lifetime cap from the PeerToken payload contract:
    # exp MUST be <= iat + 3_600_000 (milliseconds).
    max_lifetime_ms = 3_600_000

    # Step 1 — decode header/payload without verification to extract iss
    try:
        unverified = jwt.decode(
            raw_token,
            options={"verify_signature": False, "verify_exp": False},
            algorithms=["EdDSA"],
        )
    except jwt.exceptions.DecodeError as e:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED, detail=f"malformed_token: {e}"
        ) from e

    iss = unverified.get("iss", "")
    if not isinstance(iss, str) or not iss:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="invalid_iss")
    exp = _claim_epoch_ms(unverified, "exp")

    # Step 2 — expiry check (before touching DB)
    now_ms = int(time.time() * 1000)
    if exp <= now_ms:
        _write_audit(audit_writer, iss, "rejected_token", {"reason": "token_expired", "exp": exp})
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="token_expired")

    # Step 3 — look up peer by iss
    # NOTE(review): allowed_scopes is selected but never consulted in this
    # function — confirm that the caller enforces it against claims.scopes.
    with db() as conn:
        peer_row = conn.execute(
            "SELECT id, federation_pubkey, status, allowed_scopes FROM peers WHERE node_id = ?",
            (iss,),
        ).fetchone()

    if peer_row is None or peer_row["status"] != "active":
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED, detail="unknown_or_inactive_peer"
        )

    # Step 4 — verify signature
    # Time claims are in milliseconds, so PyJWT's built-in (seconds-based)
    # exp/iat/nbf validation is disabled and the checks are done by hand.
    try:
        pub_bytes = b64url_decode(peer_row["federation_pubkey"])
        public_key: Ed25519PublicKey = Ed25519PublicKey.from_public_bytes(pub_bytes)
        verified = jwt.decode(
            raw_token,
            public_key,
            algorithms=["EdDSA"],
            options={
                "verify_exp": False,
                "verify_iat": False,
                "verify_nbf": False,
            },
        )
    except (InvalidSignature, jwt.exceptions.PyJWTError) as e:
        # PyJWTError already covers jwt's InvalidSignatureError.
        _write_audit(audit_writer, iss, "rejected_token", {"reason": "invalid_signature"})
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED, detail="invalid_signature"
        ) from e

    # Step 5 — sub must match local node
    sub = verified.get("sub")
    if not isinstance(sub, str) or sub != local_node_id:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="sub_mismatch")

    iat = _claim_epoch_ms(verified, "iat")
    nbf_raw = verified.get("nbf")
    leeway_ms = settings.peer_token_leeway_s * 1000
    if iat > now_ms + leeway_ms or (
        nbf_raw is not None and _claim_epoch_ms(verified, "nbf") > now_ms + leeway_ms
    ):
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="token_not_yet_valid")

    # Enforce the lifetime cap. The exp parsed in step 1 comes from the same
    # token bytes whose signature has just been verified, so it is now trusted.
    # Without this check a peer could present tokens (and create nonce-cache
    # rows) that stay valid for an arbitrarily long time.
    if exp > iat + max_lifetime_ms:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="invalid_exp")

    nonce = verified.get("nonce")
    if not isinstance(nonce, str) or not nonce:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="missing_nonce")

    raw_scopes = verified.get("scopes", [])
    scopes = (
        [scope for scope in raw_scopes if isinstance(scope, str)]
        if isinstance(raw_scopes, list)
        else []
    )

    # Step 6 — replay check (nonce must not be in nonce_cache)
    now_iso = datetime.now(UTC).isoformat()
    with db() as conn:
        # Prune expired nonces opportunistically
        conn.execute("DELETE FROM nonce_cache WHERE expires_at < ?", (now_iso,))
        existing = conn.execute(
            "SELECT nonce FROM nonce_cache WHERE nonce = ?", (nonce,)
        ).fetchone()

    if existing is not None:
        _write_audit(
            audit_writer,
            peer_row["id"],
            "replay_attempt",
            {"nonce": nonce, "reason": "nonce_already_seen"},
        )
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="nonce_already_seen")

    return PeerTokenClaims(
        iss=iss,
        sub=sub,
        scopes=scopes,
        nonce=nonce,
    )
def consume_nonce(peer_id: str, nonce: str, exp: int) -> None:
    """Record *nonce* in ``nonce_cache`` so it is rejected on replay.

    Intended to be called after a token has been verified successfully.

    Args:
        peer_id: Stored in the ``peer_id`` column.
        nonce: The token's nonce; an already-present nonce is left untouched.
        exp: Token expiry in milliseconds since epoch; stored as a UTC
            ISO-8601 string in ``expires_at``.
    """
    expiry = datetime.fromtimestamp(exp / 1000, UTC)
    record = (nonce, peer_id, expiry.isoformat())
    with db() as conn:
        conn.execute(
            "INSERT OR IGNORE INTO nonce_cache (nonce, peer_id, expires_at) VALUES (?,?,?)",
            record,
        )
269# ---------------------------------------------------------------------------
270# Declaration signature helpers (§5.6, §6.1)
271# ---------------------------------------------------------------------------
def canonical_declaration_json(
    node_url: str,
    node_id: str,
    federation_pubkey: str,
    allowed_scopes: list[str],
    signed_at: str,
) -> bytes:
    """Serialise a declaration to the exact bytes that get signed.

    Keys are emitted in lexicographic order with no whitespace between
    tokens; the order of ``allowed_scopes`` is preserved as given.
    """
    fields = dict(
        node_url=node_url,
        node_id=node_id,
        federation_pubkey=federation_pubkey,
        allowed_scopes=allowed_scopes,
        signed_at=signed_at,
    )
    # sort_keys gives the canonical ordering regardless of insertion order.
    text = json.dumps(fields, sort_keys=True, separators=(",", ":"))
    return text.encode()
def sign_declaration(
    node_url: str,
    node_id: str,
    allowed_scopes: list[str],
) -> tuple[str, str, str]:
    """Sign a peer declaration with the node's federation key.

    The signed message is the canonical JSON of the declaration, timestamped
    with the current UTC time.

    Returns:
        ``(federation_pubkey, declaration_sig, signed_at)`` — the pubkey and
        signature are base64url text, ``signed_at`` is an ISO-8601 string.
    """
    private_key, public_b64 = get_or_create_keypair()
    timestamp = datetime.now(UTC).isoformat()
    signature = private_key.sign(
        canonical_declaration_json(node_url, node_id, public_b64, allowed_scopes, timestamp)
    )
    return public_b64, b64url_encode(signature), timestamp
def verify_declaration_sig(
    node_url: str,
    node_id: str,
    federation_pubkey: str,
    allowed_scopes: list[str],
    signed_at: str,
    declaration_sig: str,
) -> bool:
    """Check a declaration signature against the declared public key.

    Returns:
        ``True`` when ``declaration_sig`` verifies over the canonical JSON of
        the other arguments; ``False`` on a bad signature or on any error
        raised while decoding or verifying.
    """
    try:
        key: Ed25519PublicKey = Ed25519PublicKey.from_public_bytes(
            b64url_decode(federation_pubkey)
        )
        signature = b64url_decode(declaration_sig)
        payload = canonical_declaration_json(
            node_url, node_id, federation_pubkey, allowed_scopes, signed_at
        )
        key.verify(signature, payload)
    except Exception:
        # Deliberate catch-all: any failure (InvalidSignature included) means
        # the declaration is not accepted.
        return False
    return True
327# ---------------------------------------------------------------------------
328# Audit helper
329# ---------------------------------------------------------------------------
332def _write_audit(
333 audit_writer: Any | None,
334 peer_id: str,
335 event_type: str,
336 detail: dict[str, Any],
337) -> None:
338 if audit_writer is not None:
339 audit_writer(peer_id, event_type, detail)