Coverage for node / src / stigmem_node / rate_limit.py: 92%
107 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-25 01:49 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-25 01:49 +0000
1"""Per-principal token-bucket quota middleware — spec §22.4.
3Replaces the legacy sliding-window implementation. Each principal
4(entity_uri resolved from the Bearer token) gets one token-bucket per quota
5dimension. Buckets are stored in the ``quota_buckets`` table and refilled
6lazily on each request.
8Dimension → endpoint mapping (spec §22.4.1):
9 fact_write POST /v1/facts, DELETE /v1/facts/*
10 fact_read GET /v1/facts/*, GET /v1/recall*
11 token_issue POST /v1/federation/capability-tokens
12 admin_action /v1/admin/*
13 audit_export GET /v1/admin/audit*
15Exemptions (same as legacy):
16 /v1/federation/ peer requests (use peer-token auth, not user API keys)
17 Requests without a Bearer token
19Backward-compat settings bridges:
20 STIGMEM_RATE_LIMIT_WRITE_PER_HOUR → fact_write burst capacity (default 100)
21 STIGMEM_RATE_LIMIT_READ_PER_HOUR → fact_read burst capacity (default 500)
22 Setting BOTH to 0 disables rate limiting entirely; a single 0 makes that dimension fall back to its spec-default capacity.
23"""
25from __future__ import annotations
27import math
28import time
29from typing import Any
31from starlette.middleware.base import BaseHTTPMiddleware
32from starlette.requests import Request
33from starlette.responses import JSONResponse
35from .db import db
36from .settings import settings
38# ---------------------------------------------------------------------------
39# Default quota ceilings (spec §22.4.2)
40# ---------------------------------------------------------------------------
42_SPEC_DEFAULTS: dict[str, tuple[float, float]] = {
43 # dimension: (capacity, rate_per_second)
44 "fact_write": (100.0, 10.0),
45 "fact_read": (500.0, 50.0),
46 "token_issue": (20.0, 1 / 3),
47 "federation_pull": (30.0, 0.5),
48 "admin_action": (10.0, 1 / 6),
49 "subscription_event": (200.0, 20.0),
50 "audit_export": (10_000.0, 167.0),
51}
def _capacity_for(dimension: str) -> float:
    """Return the burst capacity (maximum tokens) for *dimension*.

    ``fact_write`` and ``fact_read`` honor the legacy per-hour settings: a
    positive setting becomes the capacity, anything else falls back to the
    spec default for that dimension. Every other dimension uses its spec
    default; a dimension missing from the table gets 100.0.
    """
    if dimension == "fact_write":
        configured = settings.rate_limit_write_per_hour
    elif dimension == "fact_read":
        configured = settings.rate_limit_read_per_hour
    else:
        return _SPEC_DEFAULTS.get(dimension, (100.0, 1.0))[0]

    if configured > 0:
        return float(configured)
    return _SPEC_DEFAULTS[dimension][0]
def _rate_for(dimension: str) -> float:
    """Return the refill rate for *dimension* in tokens per second.

    For ``fact_write`` and ``fact_read`` the bucket refills its full capacity
    over one hour, so the rate is capacity / 3600 (the per-second rates listed
    in the spec table are not used for these two). Every other dimension uses
    the spec-table rate; a dimension missing from the table gets 1.0.
    """
    if dimension in ("fact_write", "fact_read"):
        # Same capacity fallback as the bucket size, spread across an hour.
        return _capacity_for(dimension) / 3600.0

    _, per_second = _SPEC_DEFAULTS.get(dimension, (100.0, 1.0))
    return per_second
78# ---------------------------------------------------------------------------
79# Endpoint → dimension routing
80# ---------------------------------------------------------------------------
83def _dimension(path: str, method: str) -> str | None:
84 """Return the quota dimension for this request, or None to skip quota."""
85 m = method.upper()
86 if path.startswith("/v1/admin/audit"): 86 ↛ 87line 86 didn't jump to line 87 because the condition on line 86 was never true
87 return "audit_export" if m == "GET" else "admin_action"
88 if path.startswith("/v1/admin/"): 88 ↛ 89line 88 didn't jump to line 89 because the condition on line 88 was never true
89 return "admin_action"
90 if path.startswith("/v1/federation/capability-tokens") and m == "POST": 90 ↛ 91line 90 didn't jump to line 91 because the condition on line 90 was never true
91 return "token_issue"
92 if path.startswith("/v1/recall") or (path.startswith("/v1/facts") and m == "GET"):
93 return "fact_read"
94 if path.startswith("/v1/facts") and m in {"POST", "PUT", "PATCH", "DELETE"}:
95 return "fact_write"
96 return None
99# ---------------------------------------------------------------------------
100# Token-bucket check (SQLite upsert for atomic read-modify-write)
101# ---------------------------------------------------------------------------
def _check_and_consume(
    entity_uri: str,
    tenant_id: str,
    dimension: str,
) -> tuple[bool, float]:
    """Refill the caller's bucket lazily and try to take one token from it.

    The bucket is identified by (entity_uri, tenant_id, dimension). A bucket
    with no stored row starts full. The refilled state is always written back,
    whether or not a token was taken.

    Returns:
        ``(allowed, retry_after_seconds)``. ``retry_after_seconds`` is 0.0 when
        the request is allowed; otherwise it is the time until one whole token
        is available (1.0 if the refill rate is not positive).
    """
    now = time.time()
    capacity = _capacity_for(dimension)
    rate = _rate_for(dimension)
    bucket_key = (entity_uri, tenant_id, dimension)

    upsert_sql = """INSERT INTO quota_buckets (entity_uri, tenant_id, dimension, tokens, last_refill)
               VALUES (?,?,?,?,?)
               ON CONFLICT(entity_uri, tenant_id, dimension)
               DO UPDATE SET tokens=excluded.tokens, last_refill=excluded.last_refill"""

    with db() as conn:
        # Open the transaction before reading so the read-modify-write below
        # runs as one unit.
        conn.execute("BEGIN IMMEDIATE")
        stored = conn.execute(
            "SELECT tokens, last_refill FROM quota_buckets"
            " WHERE entity_uri=? AND tenant_id=? AND dimension=?",
            bucket_key,
        ).fetchone()

        if stored is None:
            available = capacity
        else:
            # Clamp at zero so a clock that moved backwards never drains tokens.
            idle_seconds = max(0.0, now - stored["last_refill"])
            available = min(capacity, stored["tokens"] + idle_seconds * rate)

        allowed = available >= 1.0
        remaining = available - 1.0 if allowed else available
        conn.execute(upsert_sql, (*bucket_key, remaining, now))

        if allowed:
            return True, 0.0

        # Denied: report how long until one whole token has accumulated.
        if rate > 0:
            wait_seconds = (1.0 - available) / rate
        else:
            wait_seconds = 1.0
        return False, wait_seconds
154# ---------------------------------------------------------------------------
155# Identity lookup (lightweight — only entity_uri + tenant_id needed)
156# ---------------------------------------------------------------------------
158_HASH_CACHE: dict[
159 str, tuple[tuple[str, str, str | None], float]
160] = {} # raw-key fingerprint → (result, cached_at)
161_CACHE_TTL = 60.0
def _lookup_principal(raw_key: str) -> tuple[str, str, str | None] | None:
    """Resolve a raw Bearer token to ``(entity_uri, tenant_id, oidc_sub)``.

    Successful lookups are cached for ``_CACHE_TTL`` seconds under the SHA-256
    digest of the token. Returns ``None`` when the auth layer raises
    ``HTTPException`` or returns no principal; failures are not cached.
    """
    import hashlib

    digest = hashlib.sha256(raw_key.encode()).hexdigest()
    entry = _HASH_CACHE.get(digest)
    if entry is not None:
        cached_principal, stored_at = entry
        if time.time() - stored_at < _CACHE_TTL:
            return cached_principal
        # Stale entry: drop it and fall through to a fresh lookup.
        del _HASH_CACHE[digest]

    # Imported only on a cache miss, inside the function rather than at module load.
    from fastapi import HTTPException

    from .auth import lookup_principal

    try:
        resolved = lookup_principal(raw_key)
    except HTTPException:
        return None

    if resolved is not None:
        _HASH_CACHE[digest] = (resolved, time.time())
    return resolved
189# ---------------------------------------------------------------------------
190# Middleware
191# ---------------------------------------------------------------------------
class RateLimitMiddleware(BaseHTTPMiddleware):
    """Per-principal token-bucket rate limiting (spec §22.4).

    A request passes straight through, without touching any bucket, when:

    * it is an ``OPTIONS`` request;
    * its path/method maps to no quota dimension;
    * it targets ``/v1/federation/`` and is not a ``token_issue`` request
      (peer traffic uses peer-token auth, not user API keys);
    * it carries no Bearer token;
    * both ``rate_limit_write_per_hour`` and ``rate_limit_read_per_hour`` are 0
      (dev/test kill-switch);
    * the Bearer token does not resolve to a principal (the auth layer is left
      to reject it).

    Otherwise one token is consumed from the principal's bucket for the
    dimension. When the bucket is empty, a ``quota_breach`` audit event is
    emitted and a 429 JSON response with a ``Retry-After`` header is returned.
    """

    async def dispatch(self, request: Request, call_next: Any) -> Any:
        """Enforce the quota for one request, or forward it untouched."""
        if request.method == "OPTIONS":
            return await call_next(request)

        path = request.url.path
        # Resolve the dimension first: it is cheap, and requests with no quota
        # then skip the principal lookup entirely.
        dimension = _dimension(path, request.method)
        if dimension is None:
            return await call_next(request)

        # Federation traffic is exempt, except capability-token issuance. A
        # blanket prefix exemption would make the token_issue quota unreachable.
        if path.startswith("/v1/federation/") and dimension != "token_issue":
            return await call_next(request)

        auth_header = request.headers.get("authorization", "")
        if not auth_header.lower().startswith("bearer "):
            return await call_next(request)

        # Global kill-switch: both limits=0 disables enforcement.
        if settings.rate_limit_write_per_hour == 0 and settings.rate_limit_read_per_hour == 0:
            return await call_next(request)

        raw_key = auth_header[len("bearer "):]
        principal = _lookup_principal(raw_key)
        if principal is None:
            # Unknown/expired key — let auth middleware reject it properly.
            return await call_next(request)

        entity_uri, tenant_id, oidc_sub = principal
        allowed, retry_after = _check_and_consume(entity_uri, tenant_id, dimension)
        if allowed:
            return await call_next(request)

        # Write-ahead: emit quota_breach audit event before returning 429.
        from .observability.audit_event import emit_nofail

        emit_nofail(
            "quota_breach",
            entity_uri=entity_uri,
            tenant_id=tenant_id,
            oidc_sub=oidc_sub,
            detail={
                "dimension": dimension,
                "path": path,
                "method": request.method,
                "retry_after": retry_after,
            },
        )
        # Retry-After is sent as whole seconds; never advertise less than 1.
        retry_ceil = math.ceil(retry_after)
        return JSONResponse(
            status_code=429,
            content={
                "error": "quota_exceeded",
                "dimension": dimension,
                "principal": entity_uri,
                "retry_after": retry_after,
            },
            headers={"Retry-After": str(max(retry_ceil, 1))},
        )