Coverage for node / src / stigmem_node / federation / peer_auth.py: 95%

125 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-05-25 01:49 +0000

1"""Ed25519 federation keypair management and peer-token authentication. 

2 

3Spec §3.5 (peer tokens), §6.6 (security invariants). 

4 

5PeerToken JWT payload: 

6 iss: issuing node_id (URI) 

7 sub: target node_id (URI) 

8 iat: issued-at (milliseconds since epoch) 

9 exp: expiry (milliseconds since epoch; MUST be <= iat + 3_600_000) 

10 nonce: UUID for replay protection 

11 scopes: list[FactScope] 

12""" 

13 

14from __future__ import annotations 

15 

16import base64 

17import json 

18import time 

19import uuid 

20from datetime import UTC, datetime 

21from typing import Any 

22 

23import jwt 

24from cryptography.exceptions import InvalidSignature 

25from cryptography.hazmat.primitives.asymmetric.ed25519 import ( 

26 Ed25519PrivateKey, 

27 Ed25519PublicKey, 

28) 

29from cryptography.hazmat.primitives.serialization import ( 

30 Encoding, 

31 NoEncryption, 

32 PrivateFormat, 

33 PublicFormat, 

34 load_der_private_key, 

35) 

36from fastapi import HTTPException, status 

37 

38from ..db import db 

39from ..settings import settings 

40 

41# --------------------------------------------------------------------------- 

42# Base64url helpers 

43# --------------------------------------------------------------------------- 

44 

45 

def b64url_encode(data: bytes) -> str:
    """Encode *data* as base64url text with the trailing '=' padding removed."""
    encoded = base64.urlsafe_b64encode(data).decode()
    return encoded.rstrip("=")

48 

49 

def b64url_decode(s: str) -> bytes:
    """Decode base64url text whose trailing '=' padding may have been stripped."""
    # Number of '=' characters needed to bring the length up to a multiple of 4.
    missing = -len(s) % 4
    return base64.urlsafe_b64decode(s + "=" * missing)

53 

54 

55# --------------------------------------------------------------------------- 

56# Node federation keypair (stored in node_meta) 

57# --------------------------------------------------------------------------- 

58 

59 

def get_or_create_keypair() -> tuple[Ed25519PrivateKey, str]:
    """Return ``(private_key, base64url_pubkey)`` for this node's federation key.

    The private key is stored in ``node_meta`` under ``federation_privkey`` as
    base64url-encoded PKCS#8 DER. On the first call a fresh Ed25519 key is
    generated and persisted.

    Creation uses ``INSERT OR IGNORE`` followed by a re-read. If two callers
    race to create the key, both return the single key that was actually
    stored, rather than one caller returning a key that was silently
    overwritten.

    Returns:
        The Ed25519 private key and its raw 32-byte public key, base64url
        encoded without padding.

    Raises:
        ValueError: if the stored key is not an Ed25519 private key.
    """
    select_sql = "SELECT value FROM node_meta WHERE key='federation_privkey'"
    with db() as conn:
        row = conn.execute(select_sql).fetchone()
        if row is None:
            candidate = Ed25519PrivateKey.generate()
            candidate_der = candidate.private_bytes(
                Encoding.DER, PrivateFormat.PKCS8, NoEncryption()
            )
            # OR IGNORE: if another caller stored a key first, keep theirs.
            # NOTE(review): relies on node_meta.key being UNIQUE / PRIMARY KEY
            # — confirm against the schema.
            conn.execute(
                "INSERT OR IGNORE INTO node_meta (key, value) VALUES ('federation_privkey', ?)",
                (b64url_encode(candidate_der),),
            )
            # Re-read so every caller returns the key that is actually persisted.
            row = conn.execute(select_sql).fetchone()
        priv = load_der_private_key(b64url_decode(row["value"]), password=None)

    if not isinstance(priv, Ed25519PrivateKey):
        raise ValueError("stored federation_privkey is not an Ed25519 private key")
    pub_raw = priv.public_key().public_bytes(Encoding.Raw, PublicFormat.Raw)
    return priv, b64url_encode(pub_raw)

78 

79 

def get_federation_pubkey() -> str:
    """Return this node's Ed25519 federation public key, base64url-encoded.

    The keypair is created and persisted on first use if it does not exist yet.
    """
    return get_or_create_keypair()[1]

84 

85 

86# --------------------------------------------------------------------------- 

87# Peer token minting (subscriber → publisher) 

88# --------------------------------------------------------------------------- 

89 

90 

def mint_peer_token(
    local_node_id: str,
    target_node_id: str,
    scopes: list[str],
    ttl_ms: int = 3_600_000,
) -> str:
    """Create an EdDSA-signed JWT this node presents to *target_node_id* to pull facts.

    Args:
        local_node_id: This node's node_id, written to ``iss``.
        target_node_id: The publisher's node_id, written to ``sub``.
        scopes: Requested scopes, copied into the ``scopes`` claim.
        ttl_ms: Requested lifetime in milliseconds. It is capped at
            3_600_000 so that ``exp <= iat + 3_600_000`` always holds.

    Returns:
        The compact-serialized JWT signed with the node's federation key.
        ``iat`` and ``exp`` are milliseconds since the epoch, and ``nonce``
        is a fresh UUID4 string.
    """
    signing_key = get_or_create_keypair()[0]
    issued_at = int(time.time() * 1000)
    payload: dict[str, Any] = dict(
        iss=local_node_id,
        sub=target_node_id,
        iat=issued_at,
        exp=issued_at + min(ttl_ms, 3_600_000),
        nonce=str(uuid.uuid4()),
        scopes=scopes,
    )
    return jwt.encode(payload, signing_key, algorithm="EdDSA")

109 

110 

111# --------------------------------------------------------------------------- 

112# Peer token verification (inbound — called by federation/facts endpoint) 

113# --------------------------------------------------------------------------- 

114 

115 

class PeerTokenClaims:
    """Claims extracted from a peer token that passed ``verify_peer_token``."""

    def __init__(self, iss: str, sub: str, scopes: list[str], nonce: str) -> None:
        self.iss = iss  # issuing node_id
        self.sub = sub  # target node_id
        self.scopes = scopes  # scope strings carried by the token
        self.nonce = nonce  # replay-protection nonce; caller passes it to consume_nonce()

    def __repr__(self) -> str:
        return (
            f"{type(self).__name__}(iss={self.iss!r}, sub={self.sub!r}, "
            f"scopes={self.scopes!r}, nonce={self.nonce!r})"
        )

122 

123 

124def _claim_epoch_ms(payload: dict[str, Any], claim: str) -> int: 

125 value = payload.get(claim) 

126 if value is None: 

127 raise HTTPException( 

128 status_code=status.HTTP_401_UNAUTHORIZED, 

129 detail=f"invalid_{claim}", 

130 ) 

131 try: 

132 return int(value) 

133 except (TypeError, ValueError, OverflowError) as exc: 

134 raise HTTPException( 

135 status_code=status.HTTP_401_UNAUTHORIZED, 

136 detail=f"invalid_{claim}", 

137 ) from exc 

138 

139 

def verify_peer_token(
    raw_token: str,
    local_node_id: str,
    audit_writer: Any | None = None,
) -> PeerTokenClaims:
    """Verify an inbound peer JWT and return its claims.

    Verification order (all timestamps are milliseconds since the epoch):

    1. Decode without verification to read ``iss`` and ``exp``.
    2. Reject expired tokens before touching the database.
    3. Look up ``iss`` in ``peers``; it must exist with status ``active``.
    4. Verify the EdDSA signature with that peer's stored ``federation_pubkey``.
    5. Require all of the following:
       - ``sub == local_node_id``;
       - ``iat``/``nbf`` not in the future beyond
         ``settings.peer_token_leeway_s``;
       - ``exp - iat <= 3_600_000`` (spec §3.5);
       - a non-empty ``nonce``.
    6. Reject nonces already present in ``nonce_cache``. Expired cache
       entries are pruned first.

    Args:
        raw_token: Compact-serialized JWT presented by the peer.
        local_node_id: This node's node_id; the token's ``sub`` must equal it.
        audit_writer: Optional ``(peer_id, event_type, detail)`` callable.
            It is invoked for expired, badly-signed, over-long and replayed
            tokens.

    Returns:
        PeerTokenClaims with ``iss``, ``sub``, ``nonce`` and the string
        entries of ``scopes``. Non-string entries are dropped, and a
        non-list ``scopes`` becomes ``[]``.

    Raises:
        HTTPException: 401 on any verification failure.

    The nonce is NOT consumed here; the caller must call ``consume_nonce()``
    after successful auth.
    """
    # Spec §3.5: exp MUST be <= iat + 3_600_000.
    max_lifetime_ms = 3_600_000

    # Step 1 — decode without verification; only used to pick the peer key.
    try:
        unverified = jwt.decode(
            raw_token,
            options={"verify_signature": False, "verify_exp": False},
            algorithms=["EdDSA"],
        )
    except jwt.exceptions.PyJWTError as e:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED, detail=f"malformed_token: {e}"
        ) from e

    iss = unverified.get("iss", "")
    if not isinstance(iss, str) or not iss:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="invalid_iss")
    exp = _claim_epoch_ms(unverified, "exp")

    # Step 2 — expiry check (before touching DB)
    now_ms = int(time.time() * 1000)
    if exp <= now_ms:
        _write_audit(audit_writer, iss, "rejected_token", {"reason": "token_expired", "exp": exp})
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="token_expired")

    # Step 3 — look up peer by iss
    with db() as conn:
        peer_row = conn.execute(
            "SELECT id, federation_pubkey, status, allowed_scopes FROM peers WHERE node_id = ?",
            (iss,),
        ).fetchone()

    if peer_row is None or peer_row["status"] != "active":
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED, detail="unknown_or_inactive_peer"
        )
    # NOTE(review): allowed_scopes is selected but not enforced in this function;
    # presumably the caller intersects the token's scopes with it — confirm.

    # Step 4a — load the peer's stored public key. A NULL or corrupt value
    # (bad base64, wrong key length) raises TypeError/ValueError; treat it as
    # an auth failure rather than letting it escape as a server error.
    try:
        public_key = Ed25519PublicKey.from_public_bytes(
            b64url_decode(peer_row["federation_pubkey"])
        )
    except (ValueError, TypeError) as e:
        _write_audit(audit_writer, iss, "rejected_token", {"reason": "invalid_peer_pubkey"})
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED, detail="invalid_signature"
        ) from e

    # Step 4b — verify the signature. exp/iat/nbf are milliseconds, whereas
    # PyJWT's time checks assume seconds, so those checks are disabled here
    # and the claims are validated manually.
    try:
        verified = jwt.decode(
            raw_token,
            public_key,
            algorithms=["EdDSA"],
            options={
                "verify_exp": False,
                "verify_iat": False,
                "verify_nbf": False,
            },
        )
    except (jwt.exceptions.PyJWTError, InvalidSignature) as e:
        _write_audit(audit_writer, iss, "rejected_token", {"reason": "invalid_signature"})
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED, detail="invalid_signature"
        ) from e

    # Step 5 — sub must match local node
    sub = verified.get("sub")
    if not isinstance(sub, str) or sub != local_node_id:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="sub_mismatch")

    iat = _claim_epoch_ms(verified, "iat")
    nbf_raw = verified.get("nbf")
    leeway_ms = settings.peer_token_leeway_s * 1000
    if iat > now_ms + leeway_ms or (
        nbf_raw is not None and _claim_epoch_ms(verified, "nbf") > now_ms + leeway_ms
    ):
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="token_not_yet_valid")

    # Enforce the spec's maximum token lifetime; without this a peer could
    # mint a token that stays valid for an arbitrarily long time.
    if exp - iat > max_lifetime_ms:
        _write_audit(
            audit_writer,
            iss,
            "rejected_token",
            {"reason": "token_lifetime_exceeded", "exp": exp, "iat": iat},
        )
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED, detail="token_lifetime_exceeded"
        )

    nonce = verified.get("nonce")
    if not isinstance(nonce, str) or not nonce:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="missing_nonce")

    raw_scopes = verified.get("scopes", [])
    scopes = (
        [scope for scope in raw_scopes if isinstance(scope, str)]
        if isinstance(raw_scopes, list)
        else []
    )

    # Step 6 — replay check (nonce must not be in nonce_cache).
    # NOTE(review): the check here and consume_nonce() in the caller are
    # separate steps, so two concurrent requests carrying the same token can
    # both pass. Consider consuming atomically if that window matters.
    now_iso = datetime.now(UTC).isoformat()
    with db() as conn:
        # Prune expired nonces opportunistically
        conn.execute("DELETE FROM nonce_cache WHERE expires_at < ?", (now_iso,))
        existing = conn.execute(
            "SELECT nonce FROM nonce_cache WHERE nonce = ?", (nonce,)
        ).fetchone()

    if existing is not None:
        _write_audit(
            audit_writer,
            peer_row["id"],
            "replay_attempt",
            {"nonce": nonce, "reason": "nonce_already_seen"},
        )
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="nonce_already_seen")

    return PeerTokenClaims(
        iss=iss,
        sub=sub,
        scopes=scopes,
        nonce=nonce,
    )

257 

258 

def consume_nonce(peer_id: str, nonce: str, exp: int) -> None:
    """Record *nonce* in ``nonce_cache`` so a replayed token is rejected.

    Call this after successful auth. *exp* is the token expiry in milliseconds
    since the epoch. It is stored as a UTC ISO-8601 ``expires_at`` so the
    entry can be pruned once the token could no longer be valid. Recording an
    already-known nonce is a no-op (``INSERT OR IGNORE``).
    """
    expires_at = datetime.fromtimestamp(exp / 1000, UTC).isoformat()
    params = (nonce, peer_id, expires_at)
    with db() as conn:
        conn.execute(
            "INSERT OR IGNORE INTO nonce_cache (nonce, peer_id, expires_at) VALUES (?,?,?)",
            params,
        )

267 

268 

269# --------------------------------------------------------------------------- 

270# Declaration signature helpers (§5.6, §6.1) 

271# --------------------------------------------------------------------------- 

272 

273 

def canonical_declaration_json(
    node_url: str,
    node_id: str,
    federation_pubkey: str,
    allowed_scopes: list[str],
    signed_at: str,
) -> bytes:
    """Serialize a peer declaration to the canonical byte form that gets signed.

    Keys are emitted in lexicographic order with compact separators and no
    whitespace. ``allowed_scopes`` keeps the order supplied by the caller.
    """
    fields = dict(
        node_url=node_url,
        node_id=node_id,
        federation_pubkey=federation_pubkey,
        allowed_scopes=allowed_scopes,
        signed_at=signed_at,
    )
    # sort_keys fixes the key order regardless of how ``fields`` was built.
    canonical = json.dumps(fields, sort_keys=True, separators=(",", ":"))
    return canonical.encode()

290 

291 

def sign_declaration(
    node_url: str,
    node_id: str,
    allowed_scopes: list[str],
) -> tuple[str, str, str]:
    """Sign this node's peer declaration with its federation key.

    The signed message is the canonical declaration JSON built from the
    arguments, this node's public key, and the current UTC time.

    Returns:
        ``(federation_pubkey, declaration_sig, signed_at)``: the base64url
        public key, the base64url Ed25519 signature, and the ISO-8601
        timestamp that was signed.
    """
    signing_key, public_b64 = get_or_create_keypair()
    timestamp = datetime.now(UTC).isoformat()
    signature = signing_key.sign(
        canonical_declaration_json(node_url, node_id, public_b64, allowed_scopes, timestamp)
    )
    return public_b64, b64url_encode(signature), timestamp

303 

304 

def verify_declaration_sig(
    node_url: str,
    node_id: str,
    federation_pubkey: str,
    allowed_scopes: list[str],
    signed_at: str,
    declaration_sig: str,
) -> bool:
    """Return True if *declaration_sig* is a valid Ed25519 signature over the declaration.

    The message is the canonical declaration JSON rebuilt from the arguments.
    The key is *federation_pubkey* itself (base64url, raw 32 bytes).

    Returns False for a bad signature and for malformed inputs: invalid
    base64 or a wrong-length key raises ValueError, wrong argument types raise
    TypeError, and non-JSON-serializable scopes raise TypeError or
    ValueError. Unexpected errors propagate instead of being silently
    reported as "invalid".
    """
    try:
        public_key = Ed25519PublicKey.from_public_bytes(b64url_decode(federation_pubkey))
        message = canonical_declaration_json(
            node_url, node_id, federation_pubkey, allowed_scopes, signed_at
        )
        sig_bytes = b64url_decode(declaration_sig)
        public_key.verify(sig_bytes, message)
    except (InvalidSignature, ValueError, TypeError):
        return False
    return True

325 

326 

327# --------------------------------------------------------------------------- 

328# Audit helper 

329# --------------------------------------------------------------------------- 

330 

331 

332def _write_audit( 

333 audit_writer: Any | None, 

334 peer_id: str, 

335 event_type: str, 

336 detail: dict[str, Any], 

337) -> None: 

338 if audit_writer is not None: 

339 audit_writer(peer_id, event_type, detail)