diff --git a/responder/ext/ratelimit.py b/responder/ext/ratelimit.py index c846ea1..5cc868c 100644 --- a/responder/ext/ratelimit.py +++ b/responder/ext/ratelimit.py @@ -1,5 +1,6 @@ """Simple in-memory rate limiter for Responder.""" +import threading import time from collections import defaultdict @@ -28,6 +29,7 @@ class RateLimiter: self.max_requests = requests self.period = period self._buckets: dict[str, list[float]] = defaultdict(list) + self._lock = threading.Lock() def _client_key(self, req): client = req.client @@ -45,16 +47,19 @@ class RateLimiter: def check(self, req, resp): """Check rate limit. Sets 429 status if exceeded.""" key = self._client_key(req) - self._cleanup(key) - if len(self._buckets[key]) >= self.max_requests: - resp.status_code = 429 - resp.media = {"error": "rate limit exceeded"} - resp.headers["Retry-After"] = str(self.period) - return False + with self._lock: + self._cleanup(key) + + if len(self._buckets[key]) >= self.max_requests: + resp.status_code = 429 + resp.media = {"error": "rate limit exceeded"} + resp.headers["Retry-After"] = str(self.period) + return False + + self._buckets[key].append(time.time()) + remaining = self.max_requests - len(self._buckets[key]) - self._buckets[key].append(time.time()) - remaining = self.max_requests - len(self._buckets[key]) resp.headers["X-RateLimit-Limit"] = str(self.max_requests) resp.headers["X-RateLimit-Remaining"] = str(remaining) return True