207 lines
7.2 KiB
Python
207 lines
7.2 KiB
Python
import inspect
|
|
from typing import Callable, Iterable, Optional, Tuple
|
|
|
|
from starlette.applications import Starlette
|
|
from starlette.datastructures import MutableHeaders
|
|
from starlette.middleware.base import (
|
|
BaseHTTPMiddleware,
|
|
RequestResponseEndpoint,
|
|
)
|
|
from starlette.requests import Request
|
|
from starlette.responses import Response
|
|
from starlette.routing import BaseRoute, Match
|
|
from starlette.types import ASGIApp, Message, Scope, Receive, Send
|
|
|
|
from slowapi import Limiter, _rate_limit_exceeded_handler
|
|
|
|
|
|
def _find_route_handler(
|
|
routes: Iterable[BaseRoute], scope: Scope
|
|
) -> Optional[Callable]:
|
|
handler = None
|
|
for route in routes:
|
|
match, _ = route.matches(scope)
|
|
if match == Match.FULL and hasattr(route, "endpoint"):
|
|
handler = route.endpoint # type: ignore
|
|
return handler
|
|
|
|
|
|
def _get_route_name(handler: Callable):
|
|
return f"{handler.__module__}.{handler.__name__}"
|
|
|
|
|
|
def _check_limits(
|
|
limiter: Limiter, request: Request, handler: Optional[Callable], app: Starlette
|
|
) -> Tuple[Optional[Callable], bool, Optional[Exception]]:
|
|
"""
|
|
Utils to check (if needed) current requests limit.
|
|
It returns a tuple of size 3:
|
|
1. The exception handler to run, if needed
|
|
2. a bool, True if we need to inject some headers, False otherwise
|
|
3. the exception that happened, if any
|
|
"""
|
|
if limiter._auto_check and not getattr(
|
|
request.state, "_rate_limiting_complete", False
|
|
):
|
|
try:
|
|
limiter._check_request_limit(request, handler, True)
|
|
except Exception as e:
|
|
# handle the exception since the global exception handler won't pick it up if we call_next
|
|
exception_handler = app.exception_handlers.get(
|
|
type(e), _rate_limit_exceeded_handler
|
|
)
|
|
return exception_handler, False, e
|
|
|
|
return None, True, None
|
|
return None, False, None
|
|
|
|
|
|
def sync_check_limits(
    limiter: Limiter, request: Request, handler: Optional[Callable], app: Starlette
) -> Tuple[Optional[Response], bool]:
    """
    Synchronous front-end to `_check_limits`.

    Returns a `Response` object if the limit check failed (None otherwise),
    together with a boolean telling whether headers should be injected.

    Called from `SlowAPIMiddleware.dispatch`. Only synchronous exception
    handlers can be run from here; when the selected handler is a coroutine
    function, `_rate_limit_exceeded_handler` is used in its place.
    """
    exception_handler, inject_headers, error = _check_limits(
        limiter, request, handler, app
    )

    if exception_handler and error:
        # asynchronous code cannot be executed from this synchronous helper,
        # -> swap in the default exception handler
        if inspect.iscoroutinefunction(exception_handler):
            exception_handler = _rate_limit_exceeded_handler
        return exception_handler(request, error), inject_headers  # type: ignore

    return None, inject_headers
|
|
|
|
|
|
async def async_check_limits(
    limiter: Limiter, request: Request, handler: Optional[Callable], app: Starlette
) -> Tuple[Optional[Response], bool]:
    """
    Asynchronous front-end to `_check_limits`.

    Returns a `Response` object if the limit check failed (None otherwise),
    together with a boolean telling whether headers should be injected.

    Used by the ASGI middleware; this supports both synchronous and
    asynchronous exception handlers.
    """
    exception_handler, inject_headers, error = _check_limits(
        limiter, request, handler, app
    )
    if not exception_handler:
        return None, inject_headers

    outcome = exception_handler(request, error)
    if inspect.iscoroutinefunction(exception_handler):
        # coroutine handlers hand back an awaitable that still has to be run
        outcome = await outcome
    return outcome, inject_headers
|
|
|
|
|
|
def _should_exempt(limiter: Limiter, handler: Optional[Callable]) -> bool:
|
|
# if we can't find the route handler
|
|
if handler is None:
|
|
return True
|
|
|
|
name = _get_route_name(handler)
|
|
|
|
# if exempt no need to check
|
|
if name in limiter._exempt_routes:
|
|
return True
|
|
|
|
# there is a decorator for this route we let the decorator handle it
|
|
if name in limiter._route_limits:
|
|
return True
|
|
|
|
return False
|
|
|
|
|
|
class SlowAPIMiddleware(BaseHTTPMiddleware):
    """
    `BaseHTTPMiddleware` flavour of the rate-limiting middleware.

    The limiter is read from `request.app.state.limiter`.
    """

    async def dispatch(
        self, request: Request, call_next: RequestResponseEndpoint
    ) -> Response:
        """
        Check the limits for `request` before passing it down the stack.

        Returns the error response built by the exception handler when the
        check fails; otherwise returns the downstream response, with the
        rate-limit headers injected when the check asked for it.
        """
        app: Starlette = request.app
        limiter: Limiter = app.state.limiter

        endpoint = None
        needs_check = False
        if limiter.enabled:
            endpoint = _find_route_handler(app.routes, request.scope)
            needs_check = not _should_exempt(limiter, endpoint)

        # limiter disabled or route exempt: plain pass-through
        if not needs_check:
            return await call_next(request)

        rejection, add_headers = sync_check_limits(limiter, request, endpoint, app)
        if rejection is not None:
            return rejection

        response = await call_next(request)
        if not add_headers:
            return response
        return limiter._inject_headers(response, request.state.view_rate_limit)
|
|
|
|
|
|
class SlowAPIASGIMiddleware:
    """
    Pure ASGI flavour of the rate-limiting middleware.

    HTTP requests are delegated to a fresh `_ASGIMiddlewareResponder`; every
    other scope type goes straight to the wrapped application.
    """

    def __init__(self, app: ASGIApp) -> None:
        self.app = app

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        if scope["type"] == "http":
            # one responder per request, since it keeps per-request state
            responder = _ASGIMiddlewareResponder(self.app)
            await responder(scope, receive, send)
            return None

        return await self.app(scope, receive, send)
|
|
|
|
|
|
class _ASGIMiddlewareResponder:
|
|
def __init__(self, app: ASGIApp) -> None:
|
|
self.app = app
|
|
self.error_response: Optional[Response] = None
|
|
self.initial_message: Message = {}
|
|
self.inject_headers = False
|
|
|
|
async def send_wrapper(self, message: Message) -> None:
|
|
if message["type"] == "http.response.start":
|
|
# do not send the http.response.start message now, so that we can edit the headers
|
|
# before sending it, based on what happens in the http.response.body message.
|
|
self.initial_message = message
|
|
|
|
elif message["type"] == "http.response.body":
|
|
if self.error_response:
|
|
self.initial_message["status"] = self.error_response.status_code
|
|
|
|
if self.inject_headers:
|
|
headers = MutableHeaders(raw=self.initial_message["headers"])
|
|
headers = self.limiter._inject_asgi_headers(
|
|
headers, self.request.state.view_rate_limit
|
|
)
|
|
|
|
# send the http.response.start message just before the http.response.body one,
|
|
# now that the headers are updated
|
|
await self.send(self.initial_message)
|
|
await self.send(message)
|
|
|
|
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
|
self.send = send
|
|
|
|
_app: Starlette = scope["app"]
|
|
limiter: Limiter = _app.state.limiter
|
|
|
|
if not limiter.enabled:
|
|
return await self.app(scope, receive, self.send)
|
|
|
|
handler = _find_route_handler(_app.routes, scope)
|
|
request = Request(scope, receive=receive, send=self.send)
|
|
if _should_exempt(limiter, handler):
|
|
return await self.app(scope, receive, self.send)
|
|
|
|
error_response, should_inject_headers = await async_check_limits(
|
|
limiter, request, handler, _app
|
|
)
|
|
if error_response is not None:
|
|
return await error_response(scope, receive, self.send_wrapper)
|
|
|
|
if should_inject_headers:
|
|
self.inject_headers = True
|
|
self.limiter = limiter
|
|
self.request = request
|
|
|
|
return await self.app(scope, receive, self.send_wrapper)
|