Coverage for node / src / stigmem_node / federation / tls.py: 65%

59 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-05-25 01:49 +0000

1"""mTLS support for federation transport (spec §22.1). 

2 

3Builds and manages SSL contexts for server (uvicorn) and client (httpx) 

4federation connections. Hot-reload is achieved by calling load_cert_chain() 

5on the existing SSLContext reference — in-flight connections are unaffected 

6because TLS state is per-connection, not per-context. 

7""" 

8 

9from __future__ import annotations 

10 

11import asyncio 

12import logging 

13import ssl 

14from pathlib import Path 

15from typing import Any 

16 

17from ..settings import settings 

18 

logger = logging.getLogger("stigmem.tls")

# TLS 1.3 cipher suites mandated by spec §22.1.2.3.
#
# NOTE: these are TLS 1.3 suite names. OpenSSL's classic cipher-list API
# (ssl.SSLContext.set_ciphers — and presumably uvicorn's ``ssl_ciphers``
# option, which is handed to it; confirm against the uvicorn version in use)
# only governs TLS 1.2-and-below suites, so TLS 1.3 suites cannot be selected
# or disabled through it. TLS 1.3-only operation is enforced instead by
# setting ``minimum_version = ssl.TLSVersion.TLSv1_3`` on each SSLContext
# built in this module.
_TLS13_SUITE_NAMES = (
    "TLS_AES_256_GCM_SHA384",
    "TLS_AES_128_GCM_SHA256",
    "TLS_CHACHA20_POLY1305_SHA256",
)

# Colon-separated form of the suite names above, for APIs that accept an
# OpenSSL-style cipher string.
TLS13_CIPHERS = ":".join(_TLS13_SUITE_NAMES)

33 

34 

def build_server_ssl_context(
    cert_path: str,
    key_path: str,
    ca_bundle: str,
) -> ssl.SSLContext:
    """Return a TLS 1.3-only server context that demands a client certificate.

    Args:
        cert_path: PEM file holding this node's certificate (chain).
        key_path: Private key file matching *cert_path*.
        ca_bundle: CA file used to verify client certificates. When it is
            empty, no CA is loaded here. NOTE(review): a plain
            ``ssl.SSLContext`` starts without default trust anchors, so an
            empty bundle would make every client certificate fail
            verification. Confirm that callers always pass one.

    Hostname checking is switched off because the peer identity (URI SAN)
    is enforced separately by check_peer_san(). The returned object can be
    updated in place with load_cert_chain() for zero-downtime certificate
    rotation.

    Raises:
        OSError / ssl.SSLError if the certificate, key or CA file cannot be
        read or parsed.
    """
    context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
    context.minimum_version = ssl.TLSVersion.TLSv1_3
    context.verify_mode = ssl.CERT_REQUIRED
    # Identity is checked at the app layer (check_peer_san), not by hostname.
    context.check_hostname = False
    context.load_cert_chain(certfile=cert_path, keyfile=key_path)
    if not ca_bundle:
        return context
    context.load_verify_locations(cafile=ca_bundle)
    return context

53 

54 

def build_client_ssl_context(
    cert_path: str,
    key_path: str,
    ca_bundle: str,
) -> ssl.SSLContext:
    """Return a TLS 1.3-only client context that presents this node's cert.

    Hand the result to httpx as ``verify=ctx`` on an ``httpx.AsyncClient``.
    httpx then offers the loaded client certificate during the handshake.

    Args:
        cert_path: PEM file holding this node's certificate (chain).
        key_path: Private key file matching *cert_path*.
        ca_bundle: CA file used to verify the server's certificate. When it
            is empty, no CA is loaded here. NOTE(review): a plain
            ``ssl.SSLContext`` starts without default trust anchors, so
            server verification would then fail. Confirm that callers always
            pass one.

    Raises:
        OSError / ssl.SSLError if the certificate, key or CA file cannot be
        read or parsed.
    """
    context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
    context.minimum_version = ssl.TLSVersion.TLSv1_3
    # The peer's URI SAN is verified by check_peer_san() at the app layer,
    # so hostname matching is not used here.
    context.check_hostname = False
    context.verify_mode = ssl.CERT_REQUIRED
    context.load_cert_chain(certfile=cert_path, keyfile=key_path)
    if not ca_bundle:
        return context
    context.load_verify_locations(cafile=ca_bundle)
    return context

73 

74 

def reload_tls_cert(
    ctx: ssl.SSLContext,
    cert_path: str | None = None,
    key_path: str | None = None,
) -> None:
    """Replace the certificate/key loaded on *ctx* in place.

    Any falsy *cert_path* / *key_path* falls back to ``settings.tls_cert_path``
    / ``settings.tls_key_path`` respectively (each independently). Handshakes
    that are already established keep their state; new handshakes use the
    freshly loaded pair.

    Raises:
        ssl.SSLError: if the new cert/key pair is invalid or mismatched.
        OSError: if either file cannot be opened (e.g. FileNotFoundError).
    """
    certfile = cert_path if cert_path else settings.tls_cert_path
    keyfile = key_path if key_path else settings.tls_key_path
    ctx.load_cert_chain(certfile=certfile, keyfile=keyfile)
    logger.info("TLS certificate reloaded from %s", certfile)

89 

90 

def check_peer_san(peer_cert: dict[str, Any] | None, expected_entity_uri: str) -> bool:
    """Return True iff the peer's certificate contains entity_uri as a URI SAN.

    peer_cert is the dict returned by ssl.SSLSocket.getpeercert(). Called
    after a successful TLS handshake to enforce §22.1.2.4.

    Fails closed in these cases:
      * *peer_cert* is None or empty. ``getpeercert()`` returns None when the
        peer sent no certificate and an empty dict when it was not validated.
      * *expected_entity_uri* is empty, so an empty URI SAN can never
        satisfy the check.

    Only ``("URI", value)`` entries are considered; a DNS or other SAN with
    the same text does not match.
    """
    if not peer_cert or not expected_entity_uri:
        return False
    return any(
        kind == "URI" and value == expected_entity_uri
        for kind, value in peer_cert.get("subjectAltName", ())
    )

101 

102 

def _stat_fingerprint(path: str | None) -> tuple[int, int, int] | None:
    """Return ``(st_mtime_ns, st_ino, st_size)`` for *path*, or None.

    None is returned when *path* is empty/None, or when it cannot be stat()ed
    (for example while it is briefly missing mid-rotation). The empty check
    matters because ``Path("")`` would silently stat the current directory.
    Inode and size are included alongside the nanosecond mtime so that an
    atomic rename or symlink swap is still detected when the replacement
    happens to carry the same coarse mtime.
    """
    if not path:
        return None
    try:
        st = Path(path).stat()
    except OSError:
        return None
    return (st.st_mtime_ns, st.st_ino, st.st_size)


async def cert_watcher_task(ctx: ssl.SSLContext, poll_interval: float = 5.0) -> None:
    """Async task: watch the cert/key files and hot-reload *ctx* when they change.

    Intended to run as an asyncio task alongside the uvicorn server. It loops
    forever and cancels cleanly when the parent lifespan shuts down
    (CancelledError is re-raised).

    Every *poll_interval* seconds the task fingerprints both
    ``settings.tls_cert_path`` and ``settings.tls_key_path``. Watching the key
    as well means a key-only rotation is picked up.

    If a reload fails, the old certificate stays active and the failing file
    state is remembered. The error is then logged once rather than on every
    poll. A retry happens as soon as either file changes again, e.g. when
    the key is written after the cert during a rotation.
    """

    def snapshot() -> tuple[tuple[int, int, int] | None, tuple[int, int, int] | None]:
        # Paths are read on every poll, matching what reload_tls_cert() loads.
        return (
            _stat_fingerprint(settings.tls_cert_path),
            _stat_fingerprint(settings.tls_key_path),
        )

    last_loaded = snapshot()
    last_failed = None

    while True:
        await asyncio.sleep(poll_interval)
        try:
            current = snapshot()
            if current == last_loaded or current == last_failed:
                continue
            logger.info("TLS cert file changed — reloading")
            try:
                reload_tls_cert(ctx)
            except OSError:
                # Covers ssl.SSLError (a subclass of OSError) for a bad or
                # mismatched pair, and FileNotFoundError for a missing file.
                last_failed = current
                logger.exception("TLS cert reload failed — keeping old cert")
            else:
                last_loaded = current
                last_failed = None
        except asyncio.CancelledError:
            raise
        except Exception:
            logger.exception("Unexpected error in cert watcher")