Coverage for node / src / stigmem_node / federation / peer_token.py: 92%
101 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-25 01:49 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-25 01:49 +0000
1"""Ed25519 peer-token creation and verification (spec §3.5, §6.6).
3Peer tokens are short-lived Ed25519-signed JWTs used for machine-to-machine
4federation auth. They are distinct from long-lived API keys.
6exp/iat in the JWT payload are epoch_ms (not epoch_s) per spec §3.5.
7We skip PyJWT's built-in exp check and validate manually at ms resolution.
8"""
10from __future__ import annotations
12import base64
13import json
14import sqlite3
15import time
16import uuid
17from datetime import UTC, datetime
18from typing import Any
20import jwt
21from cryptography.hazmat.primitives.asymmetric.ed25519 import Ed25519PrivateKey, Ed25519PublicKey
23from ..db import db, get_or_create_federation_keypair, get_or_create_node_id
24from ..settings import settings
26_cached_pub: str | None = None
27_cached_priv: str | None = None
28_MAX_PEER_TOKEN_IAT_AGE_MS = 24 * 60 * 60 * 1000
31def _pad(s: str) -> str:
32 return s + "=" * (-len(s) % 4)
def init_federation_keys() -> tuple[str, str]:
    """Load or generate the node's Ed25519 keypair. Must be called after migrations.

    Keys configured in settings take precedence, but only when both halves are set;
    otherwise the keypair comes from get_or_create_federation_keypair(). Either way
    the module-level cache is refreshed and (pubkey, privkey) is returned.
    """
    global _cached_pub, _cached_priv
    if not (settings.federation_pubkey and settings.federation_privkey):
        keypair = get_or_create_federation_keypair()
    else:
        keypair = (settings.federation_pubkey, settings.federation_privkey)
    _cached_pub, _cached_priv = keypair
    return _cached_pub, _cached_priv
def get_local_pubkey() -> str:
    """Return this node's federation public key, loading the keypair on first use."""
    if not _cached_pub:
        return init_federation_keys()[0]
    return _cached_pub
def _get_privkey_obj() -> Ed25519PrivateKey:
    """Return this node's Ed25519 signing key as a cryptography key object.

    Reuses the cached private key once init_federation_keys() has run, mirroring
    get_local_pubkey(). This avoids re-running the settings / keypair lookup on every
    token mint and keeps the signing key in step with the advertised public key.
    """
    priv_b64 = _cached_priv
    if not priv_b64:
        # First use (or cache cleared): load/generate and populate both caches.
        _, priv_b64 = init_federation_keys()
    raw = base64.urlsafe_b64decode(_pad(priv_b64))
    return Ed25519PrivateKey.from_private_bytes(raw)
def _pubkey_obj_from_b64(b64: str) -> Ed25519PublicKey:
    """Decode a (possibly unpadded) URL-safe base64 Ed25519 public key into a key object."""
    padding = "=" * (-len(b64) % 4)
    return Ed25519PublicKey.from_public_bytes(base64.urlsafe_b64decode(b64 + padding))
def create_peer_token(
    target_node_id: str,
    scopes: list[str],
    ttl_ms: int = 3_600_000,
) -> str:
    """Mint a signed peer token addressed to target_node_id.

    The EdDSA-signed JWT carries iss (our node id), sub (target node), iat/exp as
    epoch milliseconds (exp = iat + ttl_ms, default one hour), a random UUID4 nonce
    and the requested scopes.
    """
    signing_key = _get_privkey_obj()
    issuer = get_or_create_node_id()
    issued_at_ms = int(time.time() * 1000)
    claims: dict[str, Any] = dict(
        iss=issuer,
        sub=target_node_id,
        iat=issued_at_ms,
        exp=issued_at_ms + ttl_ms,
        nonce=str(uuid.uuid4()),
        scopes=scopes,
    )
    return jwt.encode(claims, signing_key, algorithm="EdDSA")
class TokenError(Exception):
    """Raised when peer-token verification fails.

    ``kind`` is a short machine-readable reason code (for example
    ``"token_expired"``); it is also used as the exception message.
    """

    kind: str

    def __init__(self, kind: str) -> None:
        super().__init__(kind)
        self.kind = kind
def _get_peer_node_id_by_db_id(peer_db_id: str) -> str:
    """Look up the node_id registered for the peers row with primary key peer_db_id.

    Raises TokenError("peer_not_found") when no such row exists.
    """
    with db() as conn:
        found = conn.execute("SELECT node_id FROM peers WHERE id = ?", (peer_db_id,)).fetchone()
        if found is not None:
            return str(found["node_id"])
        raise TokenError("peer_not_found")
def _int_claim(payload: dict[str, Any], name: str, default: int) -> int:
    """Return JWT claim ``name`` as an int; malformed values raise TokenError("invalid_token").

    Decoded claims may be any JSON type (and Python's json module also accepts
    NaN/Infinity), so a bare int() could leak TypeError/ValueError/OverflowError to
    callers that only handle TokenError.
    """
    try:
        return int(payload.get(name, default))
    except (TypeError, ValueError, OverflowError) as exc:
        raise TokenError("invalid_token") from exc


def verify_peer_token(
    raw_token: str,
    peer_pubkey_b64: str,
    peer_db_id: str,
) -> dict[str, Any]:
    """Verify Ed25519 peer JWT against peer's stored pubkey.

    Checks (spec §6.6), in this order:
      1. Signature (EdDSA) against peer_pubkey_b64
      2. exp not passed (epoch_ms, allowing settings.peer_token_leeway_s)
      3. sub == our node_id
      4. iss == the node_id registered for peer_db_id
      5. iat not in the future (beyond leeway) and not older than 24 h
      6. nbf, if present, not in the future (beyond leeway)
      7. nonce present, a string, and not seen before

    Numeric claims that cannot be coerced to int are rejected as "invalid_token".

    Returns payload dict on success. Raises TokenError on any failure (including
    "peer_not_found" when peer_db_id has no peers row).
    Writes nonce to cache on success (replay protection). The cache entry is kept
    at least until the token can no longer pass the exp check.
    """
    our_node_id = get_or_create_node_id()
    try:
        public_key = _pubkey_obj_from_b64(peer_pubkey_b64)
        payload: dict[str, Any] = jwt.decode(
            raw_token,
            public_key,
            algorithms=["EdDSA"],
            options={
                # exp/iat are epoch_ms per spec §3.5 — disable library checks, validate manually
                "verify_exp": False,
                "verify_nbf": False,
                "verify_iat": False,
                "verify_aud": False,
                "verify_iss": False,
            },
        )
    except jwt.exceptions.InvalidSignatureError as exc:
        raise TokenError("invalid_signature") from exc
    except jwt.exceptions.PyJWTError as exc:
        raise TokenError("invalid_token") from exc

    now_ms = int(time.time() * 1000)
    leeway_ms = settings.peer_token_leeway_s * 1000
    exp = _int_claim(payload, "exp", 0)
    if now_ms > exp + leeway_ms:
        raise TokenError("token_expired")

    if payload.get("sub") != our_node_id:
        raise TokenError("invalid_sub")

    expected_iss = _get_peer_node_id_by_db_id(peer_db_id)
    if payload.get("iss") != expected_iss:
        raise TokenError("invalid_iss")

    iat = _int_claim(payload, "iat", 0)
    if iat > now_ms + leeway_ms:
        raise TokenError("iat_in_future")
    if iat < now_ms - _MAX_PEER_TOKEN_IAT_AGE_MS:
        raise TokenError("iat_too_old")

    if payload.get("nbf") is not None and _int_claim(payload, "nbf", 0) > now_ms + leeway_ms:
        raise TokenError("nbf_in_future")

    nonce = payload.get("nonce")
    if not nonce:
        raise TokenError("missing_nonce")
    if not isinstance(nonce, str):
        # create_peer_token always emits a string nonce; other JSON types (lists/dicts
        # cannot even be bound by sqlite3) are rejected instead of leaking a DB error.
        raise TokenError("invalid_token")

    # Atomically insert nonce (UNIQUE constraint rejects replays)
    window_ms = settings.federation_nonce_window_s * 1000
    # Retain the nonce for as long as the token can still pass the exp check
    # (exp + leeway). Retaining it only until exp let a prune running between exp
    # and exp + leeway delete the row, which re-opened a replay window.
    nonce_expires_ms = max(exp + leeway_ms, now_ms + window_ms)
    try:
        expires_at = datetime.fromtimestamp(nonce_expires_ms / 1000, tz=UTC).isoformat()
    except (OverflowError, ValueError, OSError) as exc:
        # exp so far in the future that it cannot be represented as a datetime.
        raise TokenError("invalid_token") from exc
    try:
        with db() as conn:
            conn.execute(
                "INSERT INTO nonce_cache (nonce, peer_id, expires_at) VALUES (?, ?, ?)",
                (nonce, peer_db_id, expires_at),
            )
            # Opportunistic prune of expired nonces
            conn.execute(
                "DELETE FROM nonce_cache WHERE expires_at < ?",
                (datetime.now(UTC).isoformat(),),
            )
    except sqlite3.IntegrityError as exc:
        raise TokenError("nonce_already_seen") from exc

    return payload
def verify_declaration_sig(decl_fields: dict[str, Any], sig_b64: str, pubkey_b64: str) -> bool:
    """Verify the PeerDeclaration signature (spec §6.1).

    decl_fields must contain all signed fields (everything except declaration_sig itself).
    They are serialized as canonical JSON (keys sorted, compact separators, UTF-8),
    so the caller's dict ordering does not matter.

    Returns True only when sig_b64 is a valid Ed25519 signature over those canonical
    bytes under pubkey_b64. Malformed base64 in either argument, a wrong-length key
    or a non-matching signature all yield False.
    Raises TypeError if decl_fields contains values json.dumps cannot serialize
    (that is a caller bug, not a bad signature).
    """
    from cryptography.exceptions import InvalidSignature

    canonical = json.dumps(decl_fields, sort_keys=True, separators=(",", ":")).encode("utf-8")
    try:
        # Decoding lives inside the try: malformed base64 raises binascii.Error
        # (a ValueError subclass) and must be reported as "not verified".
        sig_bytes = base64.urlsafe_b64decode(_pad(sig_b64))
        pub = _pubkey_obj_from_b64(pubkey_b64)
        pub.verify(sig_bytes, canonical)
    except (InvalidSignature, ValueError, TypeError):
        return False
    return True