Files
2025-04-24 11:44:23 +02:00

207 lines
7.2 KiB
Python

import inspect
from typing import Callable, Iterable, Optional, Tuple
from starlette.applications import Starlette
from starlette.datastructures import MutableHeaders
from starlette.middleware.base import (
BaseHTTPMiddleware,
RequestResponseEndpoint,
)
from starlette.requests import Request
from starlette.responses import Response
from starlette.routing import BaseRoute, Match
from starlette.types import ASGIApp, Message, Scope, Receive, Send
from slowapi import Limiter, _rate_limit_exceeded_handler
def _find_route_handler(
    routes: Iterable[BaseRoute], scope: Scope
) -> Optional[Callable]:
    """
    Return the endpoint of the route that `scope` is dispatched to.

    Starlette's router dispatches a request to the *first* route that fully
    matches it. So the search stops at the first full match that has an
    ``endpoint``, instead of letting later matching routes overwrite it.
    Routes without an ``endpoint`` attribute are skipped.

    Returns None when no fully matching route with an endpoint exists.
    """
    for route in routes:
        match, _ = route.matches(scope)
        if match == Match.FULL and hasattr(route, "endpoint"):
            return route.endpoint  # type: ignore
    return None
def _get_route_name(handler: Callable):
return f"{handler.__module__}.{handler.__name__}"
def _check_limits(
    limiter: Limiter, request: Request, handler: Optional[Callable], app: Starlette
) -> Tuple[Optional[Callable], bool, Optional[Exception]]:
    """
    Run the limiter's automatic request check, when it applies.

    The check is skipped, giving ``(None, False, None)``, when
    ``limiter._auto_check`` is falsy or ``request.state._rate_limiting_complete``
    is already truthy.

    Returns a 3-tuple:
    1. the exception handler to run, or None when the check raised nothing;
    2. True when the check passed and rate-limit headers should be injected,
       False otherwise;
    3. the exception raised by the check, if any.
    """
    if not limiter._auto_check or getattr(
        request.state, "_rate_limiting_complete", False
    ):
        return None, False, None

    try:
        limiter._check_request_limit(request, handler, True)
    except Exception as raised:
        # The app's global exception handlers won't see this exception when the
        # middleware goes on to call_next, so resolve a handler here: an exact-type
        # match from the app, else the default rate-limit handler.
        # NOTE(review): every exception (not only rate-limit ones) ends up here
        # and may get the default handler; confirm that handler copes with it.
        chosen = app.exception_handlers.get(type(raised), _rate_limit_exceeded_handler)
        return chosen, False, raised

    return None, True, None
def sync_check_limits(
    limiter: Limiter, request: Request, handler: Optional[Callable], app: Starlette
) -> Tuple[Optional[Response], bool]:
    """
    Check the rate limit without awaiting anything.

    Returns ``(response, inject_headers)``:
    - ``response`` is the error `Response` built by the exception handler, or
      None when the check raised nothing;
    - ``inject_headers`` tells whether rate-limit headers should be added to the
      eventual response.

    Only synchronous exception handlers can be run here. An asynchronous handler
    (an ``async def`` function, or an object with an ``async def __call__``) is
    replaced by ``_rate_limit_exceeded_handler``.
    """
    exception_handler, inject_headers, exc = _check_limits(
        limiter, request, handler, app
    )
    if not exception_handler or exc is None:
        return None, inject_headers

    # Asynchronous code cannot run in this synchronous path, so fall back on the
    # default handler. Checking __call__ as well catches async callable instances,
    # which inspect.iscoroutinefunction alone reports as synchronous; calling one
    # would return an un-awaited coroutine instead of a Response.
    is_async_handler = inspect.iscoroutinefunction(
        exception_handler
    ) or inspect.iscoroutinefunction(getattr(exception_handler, "__call__", None))
    if is_async_handler:
        exception_handler = _rate_limit_exceeded_handler
    return exception_handler(request, exc), inject_headers  # type: ignore
async def async_check_limits(
    limiter: Limiter, request: Request, handler: Optional[Callable], app: Starlette
) -> Tuple[Optional[Response], bool]:
    """
    Check the rate limit, supporting both sync and async exception handlers.

    Returns ``(response, inject_headers)``:
    - ``response`` is the error `Response` built by the exception handler, or
      None when the check raised nothing;
    - ``inject_headers`` tells whether rate-limit headers should be added to the
      eventual response.
    """
    exception_handler, inject_headers, exc = _check_limits(
        limiter, request, handler, app
    )
    if exception_handler is None or exc is None:
        return None, inject_headers

    response = exception_handler(request, exc)
    # Await whatever awaitable the handler hands back. This covers `async def`
    # handlers, but also async callable objects, which inspect.iscoroutinefunction
    # does not recognise; plain synchronous handlers return a Response directly.
    if inspect.isawaitable(response):
        response = await response
    return response, inject_headers
def _should_exempt(limiter: Limiter, handler: Optional[Callable]) -> bool:
    """
    Tell whether the middleware must skip the rate-limit check for `handler`.

    The check is skipped when:
    - no route handler could be found;
    - the route is listed in ``limiter._exempt_routes``;
    - the route has its own entry in ``limiter._route_limits``, in which case its
      decorator takes care of limiting.
    """
    if handler is None:
        return True
    route_name = _get_route_name(handler)
    return route_name in limiter._exempt_routes or route_name in limiter._route_limits
class SlowAPIMiddleware(BaseHTTPMiddleware):
    """
    `BaseHTTPMiddleware` that checks the app's limiter on each request.

    Requests whose route is exempt, has its own decorator limits, or cannot be
    resolved are passed through unchecked.
    """

    async def dispatch(
        self, request: Request, call_next: RequestResponseEndpoint
    ) -> Response:
        """
        Check the limit, then either return the error response or call the app.

        When the check passed and headers must be injected, the limiter adds its
        rate-limit headers to the downstream response.
        """
        app: Starlette = request.app
        limiter: Limiter = app.state.limiter

        if not limiter.enabled:
            return await call_next(request)

        handler = _find_route_handler(app.routes, request.scope)
        if _should_exempt(limiter, handler):
            return await call_next(request)

        # dispatch is a coroutine, so use the async variant: an async exception
        # handler registered on the app is awaited instead of being replaced by
        # the default rate-limit handler.
        error_response, should_inject_headers = await async_check_limits(
            limiter, request, handler, app
        )
        if error_response is not None:
            return error_response

        response = await call_next(request)
        if should_inject_headers:
            response = limiter._inject_headers(response, request.state.view_rate_limit)
        return response
class SlowAPIASGIMiddleware:
    """
    Pure ASGI rate-limiting middleware.

    Scopes whose type is not ``"http"`` go straight to the wrapped app. Each HTTP
    request is handled by a fresh `_ASGIMiddlewareResponder`, which keeps that
    request's state.
    """

    def __init__(self, app: ASGIApp) -> None:
        self.app = app

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        if scope["type"] == "http":
            responder = _ASGIMiddlewareResponder(self.app)
            await responder(scope, receive, send)
            return None
        return await self.app(scope, receive, send)
class _ASGIMiddlewareResponder:
    """
    Per-request helper used by `SlowAPIASGIMiddleware`.

    It runs the rate-limit check for one HTTP request. It wraps the downstream
    `send` so the ``http.response.start`` message can be edited (status, rate-limit
    headers) before it is forwarded.
    """

    def __init__(self, app: ASGIApp) -> None:
        self.app = app
        # Only consulted by send_wrapper; it overrides the start status when set.
        self.error_response: Optional[Response] = None
        # Buffered http.response.start message, empty until the app sends one.
        self.initial_message: Message = {}
        self.inject_headers = False
        # True once the buffered start message has been forwarded. A response has
        # exactly one start message but may have several body messages (streaming).
        self._start_sent = False

    async def send_wrapper(self, message: Message) -> None:
        """
        Forward `message` downstream, holding back ``http.response.start``.

        The start message is buffered so it can still be edited, then sent exactly
        once, just before the first message that follows it (normally the first
        ``http.response.body``). Every later message, including further body chunks
        and other message types, is passed straight through.
        """
        if message["type"] == "http.response.start":
            self.initial_message = message
            return
        if self.initial_message and not self._start_sent:
            await self._send_start()
        await self.send(message)

    async def _send_start(self) -> None:
        """Apply pending edits to the buffered start message and send it once."""
        if self.error_response:
            self.initial_message["status"] = self.error_response.status_code
        if self.inject_headers:
            # "headers" is optional in an ASGI start message. Edits made through
            # MutableHeaders(scope=...) land in the message's own header list.
            self.initial_message.setdefault("headers", [])
            headers = MutableHeaders(scope=self.initial_message)
            self.limiter._inject_asgi_headers(
                headers, self.request.state.view_rate_limit
            )
        self._start_sent = True
        await self.send(self.initial_message)

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """Run the rate-limit check for this request, then reply or call the app."""
        self.send = send
        _app: Starlette = scope["app"]
        limiter: Limiter = _app.state.limiter

        if not limiter.enabled:
            return await self.app(scope, receive, self.send)

        handler = _find_route_handler(_app.routes, scope)
        if _should_exempt(limiter, handler):
            return await self.app(scope, receive, self.send)

        request = Request(scope, receive=receive, send=self.send)
        error_response, should_inject_headers = await async_check_limits(
            limiter, request, handler, _app
        )
        if error_response is not None:
            return await error_response(scope, receive, self.send_wrapper)

        if should_inject_headers:
            # State read by _send_start when it adds the rate-limit headers.
            self.inject_headers = True
            self.limiter = limiter
            self.request = request
        return await self.app(scope, receive, self.send_wrapper)