Coverage for node / src / stigmem_node / federation / tls.py: 65%
59 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-25 01:49 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-25 01:49 +0000
1"""mTLS support for federation transport (spec §22.1).
3Builds and manages SSL contexts for server (uvicorn) and client (httpx)
4federation connections. Hot-reload is achieved by calling load_cert_chain()
5on the existing SSLContext reference — in-flight connections are unaffected
6because TLS state is per-connection, not per-context.
7"""
9from __future__ import annotations
11import asyncio
12import logging
13import ssl
14from pathlib import Path
15from typing import Any
17from ..settings import settings
logger = logging.getLogger("stigmem.tls")

# Cipher suites that spec §22.1.2.3 mandates for TLS 1.3.
# NOTE(review): per the original comment these names are handed to uvicorn's
# ssl_ciphers option (the TLS 1.2 cipher list) — confirm against the caller.
# The TLS 1.3-only policy itself is enforced in this module by setting
# minimum_version on each SSLContext after it is created.
_TLS13_SUITE_NAMES: tuple[str, ...] = (
    "TLS_AES_256_GCM_SHA384",
    "TLS_AES_128_GCM_SHA256",
    "TLS_CHACHA20_POLY1305_SHA256",
)

# The same suites as one colon-separated string, the form OpenSSL cipher-list
# APIs expect.
TLS13_CIPHERS: str = ":".join(_TLS13_SUITE_NAMES)
def build_server_ssl_context(
    cert_path: str,
    key_path: str,
    ca_bundle: str,
) -> ssl.SSLContext:
    """Build the server-side SSLContext: TLS 1.3 only, client cert required.

    The caller may keep the returned object and later call load_cert_chain()
    on it again to rotate the certificate in place.

    Args:
        cert_path: Path to the server certificate (chain) file.
        key_path: Path to the matching private key file.
        ca_bundle: Path to the CA bundle used to verify client certificates.
            When empty, no verify locations are loaded by this function.

    Returns:
        The configured server context.
    """
    server_ctx = ssl.SSLContext(protocol=ssl.PROTOCOL_TLS_SERVER)
    server_ctx.minimum_version = ssl.TLSVersion.TLSv1_3
    server_ctx.verify_mode = ssl.CERT_REQUIRED
    # SAN verified by check_peer_san() at app layer
    server_ctx.check_hostname = False
    server_ctx.load_cert_chain(certfile=cert_path, keyfile=key_path)
    if not ca_bundle:
        return server_ctx
    server_ctx.load_verify_locations(cafile=ca_bundle)
    return server_ctx
def build_client_ssl_context(
    cert_path: str,
    key_path: str,
    ca_bundle: str,
) -> ssl.SSLContext:
    """Build the client-side SSLContext: TLS 1.3 only, node cert loaded.

    Intended to be handed to httpx.AsyncClient as ``verify=ctx`` so the
    loaded certificate is presented during the handshake.

    Args:
        cert_path: Path to this node's certificate (chain) file.
        key_path: Path to the matching private key file.
        ca_bundle: Path to the CA bundle used to verify the server
            certificate. When empty, no verify locations are loaded by this
            function.

    Returns:
        The configured client context.
    """
    client_ctx = ssl.SSLContext(protocol=ssl.PROTOCOL_TLS_CLIENT)
    client_ctx.minimum_version = ssl.TLSVersion.TLSv1_3
    # SAN verified by check_peer_san() at app layer. Hostname checking is
    # switched off before verify_mode is assigned, as in the original order.
    client_ctx.check_hostname = False
    client_ctx.verify_mode = ssl.CERT_REQUIRED
    client_ctx.load_cert_chain(certfile=cert_path, keyfile=key_path)
    if not ca_bundle:
        return client_ctx
    client_ctx.load_verify_locations(cafile=ca_bundle)
    return client_ctx
def reload_tls_cert(
    ctx: ssl.SSLContext,
    cert_path: str | None = None,
    key_path: str | None = None,
) -> None:
    """Hot-reload the certificate on *ctx* without restarting the server.

    The new cert/key pair is first loaded into a throwaway context. Only if
    that succeeds is load_cert_chain() called on *ctx*, so an unreadable file
    or a mismatched key is rejected before the live context is modified.

    Args:
        ctx: The live context whose certificate should be replaced.
        cert_path: Certificate file to load; falls back to
            ``settings.tls_cert_path`` when omitted or empty.
        key_path: Private key file to load; falls back to
            ``settings.tls_key_path`` when omitted or empty.

    Raises:
        ssl.SSLError: If the new cert/key pair is invalid.
        OSError: If a file cannot be read (e.g. it does not exist).
    """
    cert = cert_path or settings.tls_cert_path
    key = key_path or settings.tls_key_path

    # Validate on a scratch context first: a load_cert_chain() call that fails
    # part-way on the live context may leave it holding the new certificate
    # without a usable private key, which would break new handshakes.
    ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER).load_cert_chain(cert, key)

    ctx.load_cert_chain(cert, key)
    logger.info("TLS certificate reloaded from %s", cert)
def check_peer_san(peer_cert: dict[str, Any] | None, expected_entity_uri: str) -> bool:
    """Return True iff the peer's certificate contains entity_uri as a URI SAN.

    Called after a successful TLS handshake to enforce §22.1.2.4.

    Args:
        peer_cert: The value returned by ssl.SSLSocket.getpeercert(). That
            call returns None when the peer sent no certificate and an empty
            dict when the certificate was not validated; both are treated as
            "no match".
        expected_entity_uri: The URI that must appear, exactly, among the
            certificate's ``subjectAltName`` entries of kind ``"URI"``.

    Returns:
        True when a URI SAN equals *expected_entity_uri*, otherwise False.
    """
    # Fail closed instead of raising AttributeError on a missing certificate.
    if not peer_cert:
        return False
    return any(
        kind == "URI" and value == expected_entity_uri
        for kind, value in peer_cert.get("subjectAltName", ())
    )
async def cert_watcher_task(ctx: ssl.SSLContext, poll_interval: float = 5.0) -> None:
    """Poll the certificate file and hot-reload *ctx* when its mtime changes.

    Meant to run as an asyncio task; it loops until cancelled, and
    cancellation is re-raised rather than swallowed.

    Args:
        ctx: The live context passed to reload_tls_cert() on each change.
        poll_interval: Seconds to sleep between mtime checks.
    """
    cert_file = Path(settings.tls_cert_path)
    try:
        seen_mtime = cert_file.stat().st_mtime
    except OSError:
        # File not readable yet; any real mtime will differ from 0.0.
        seen_mtime = 0.0

    while True:
        await asyncio.sleep(poll_interval)
        try:
            current_mtime = cert_file.stat().st_mtime
            if current_mtime == seen_mtime:
                continue
            logger.info("TLS cert file changed — reloading")
            try:
                reload_tls_cert(ctx)
            except ssl.SSLError:
                # seen_mtime is left unchanged, so the reload is retried on
                # the next poll.
                logger.exception("TLS cert reload failed — keeping old cert")
            else:
                seen_mtime = current_mtime
        except asyncio.CancelledError:
            raise
        except Exception:
            logger.exception("Unexpected error in cert watcher")