Update 2025-04-24_11:44:19
This commit is contained in:
3
venv/lib/python3.11/site-packages/slowapi/__init__.py
Normal file
3
venv/lib/python3.11/site-packages/slowapi/__init__.py
Normal file
@ -0,0 +1,3 @@
|
||||
from .extension import Limiter, _rate_limit_exceeded_handler
|
||||
|
||||
__all__ = ["Limiter", "_rate_limit_exceeded_handler"]
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
27
venv/lib/python3.11/site-packages/slowapi/errors.py
Normal file
27
venv/lib/python3.11/site-packages/slowapi/errors.py
Normal file
@ -0,0 +1,27 @@
|
||||
"""
|
||||
errors and exceptions
|
||||
"""
|
||||
|
||||
from starlette.exceptions import HTTPException
|
||||
|
||||
from .wrappers import Limit
|
||||
|
||||
|
||||
class RateLimitExceeded(HTTPException):
|
||||
"""
|
||||
exception raised when a rate limit is hit.
|
||||
"""
|
||||
|
||||
limit = None
|
||||
|
||||
def __init__(self, limit: Limit) -> None:
|
||||
self.limit = limit
|
||||
if limit.error_message:
|
||||
description: str = (
|
||||
limit.error_message
|
||||
if not callable(limit.error_message)
|
||||
else limit.error_message()
|
||||
)
|
||||
else:
|
||||
description = str(limit.limit)
|
||||
super(RateLimitExceeded, self).__init__(status_code=429, detail=description)
|
885
venv/lib/python3.11/site-packages/slowapi/extension.py
Normal file
885
venv/lib/python3.11/site-packages/slowapi/extension.py
Normal file
@ -0,0 +1,885 @@
|
||||
"""
|
||||
The starlette extension to rate-limit requests
|
||||
"""
|
||||
import asyncio
|
||||
import functools
|
||||
import inspect
|
||||
import itertools
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from datetime import datetime
|
||||
from email.utils import formatdate, parsedate_to_datetime
|
||||
from functools import wraps
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
List,
|
||||
Optional,
|
||||
Set,
|
||||
Tuple,
|
||||
TypeVar,
|
||||
Union,
|
||||
)
|
||||
|
||||
from limits import RateLimitItem # type: ignore
|
||||
from limits.errors import ConfigurationError # type: ignore
|
||||
from limits.storage import MemoryStorage, storage_from_string # type: ignore
|
||||
from limits.strategies import STRATEGIES, RateLimiter # type: ignore
|
||||
from starlette.config import Config
|
||||
from starlette.datastructures import MutableHeaders
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import JSONResponse, Response
|
||||
from typing_extensions import Literal
|
||||
|
||||
from .errors import RateLimitExceeded
|
||||
from .wrappers import Limit, LimitGroup
|
||||
|
||||
# used to annotate get_app_config method
|
||||
T = TypeVar("T")
|
||||
# Define an alias for the most commonly used type
|
||||
StrOrCallableStr = Union[str, Callable[..., str]]
|
||||
|
||||
|
||||
class C:
|
||||
ENABLED = "RATELIMIT_ENABLED"
|
||||
HEADERS_ENABLED = "RATELIMIT_HEADERS_ENABLED"
|
||||
STORAGE_URL = "RATELIMIT_STORAGE_URL"
|
||||
STORAGE_OPTIONS = "RATELIMIT_STORAGE_OPTIONS"
|
||||
STRATEGY = "RATELIMIT_STRATEGY"
|
||||
GLOBAL_LIMITS = "RATELIMIT_GLOBAL"
|
||||
DEFAULT_LIMITS = "RATELIMIT_DEFAULT"
|
||||
APPLICATION_LIMITS = "RATELIMIT_APPLICATION"
|
||||
HEADER_LIMIT = "RATELIMIT_HEADER_LIMIT"
|
||||
HEADER_REMAINING = "RATELIMIT_HEADER_REMAINING"
|
||||
HEADER_RESET = "RATELIMIT_HEADER_RESET"
|
||||
SWALLOW_ERRORS = "RATELIMIT_SWALLOW_ERRORS"
|
||||
IN_MEMORY_FALLBACK = "RATELIMIT_IN_MEMORY_FALLBACK"
|
||||
IN_MEMORY_FALLBACK_ENABLED = "RATELIMIT_IN_MEMORY_FALLBACK_ENABLED"
|
||||
HEADER_RETRY_AFTER = "RATELIMIT_HEADER_RETRY_AFTER"
|
||||
HEADER_RETRY_AFTER_VALUE = "RATELIMIT_HEADER_RETRY_AFTER_VALUE"
|
||||
KEY_PREFIX = "RATELIMIT_KEY_PREFIX"
|
||||
|
||||
|
||||
class HEADERS:
|
||||
RESET = 1
|
||||
REMAINING = 2
|
||||
LIMIT = 3
|
||||
RETRY_AFTER = 4
|
||||
|
||||
|
||||
MAX_BACKEND_CHECKS = 5
|
||||
|
||||
|
||||
def _rate_limit_exceeded_handler(request: Request, exc: RateLimitExceeded) -> Response:
|
||||
"""
|
||||
Build a simple JSON response that includes the details of the rate limit
|
||||
that was hit. If no limit is hit, the countdown is added to headers.
|
||||
"""
|
||||
response = JSONResponse(
|
||||
{"error": f"Rate limit exceeded: {exc.detail}"}, status_code=429
|
||||
)
|
||||
response = request.app.state.limiter._inject_headers(
|
||||
response, request.state.view_rate_limit
|
||||
)
|
||||
return response
|
||||
|
||||
|
||||
class Limiter:
|
||||
"""
|
||||
Initializes the slowapi rate limiter.
|
||||
|
||||
** parameter **
|
||||
|
||||
* **app**: `Starlette/FastAPI` instance to initialize the extension
|
||||
with.
|
||||
|
||||
* **default_limits**: a variable list of strings or callables returning strings denoting global
|
||||
limits to apply to all routes. `ratelimit-string` for more details.
|
||||
|
||||
* **application_limits**: a variable list of strings or callables returning strings for limits that
|
||||
are applied to the entire application (i.e a shared limit for all routes)
|
||||
|
||||
* **key_func**: a callable that returns the domain to rate limit by.
|
||||
|
||||
* **headers_enabled**: whether ``X-RateLimit`` response headers are written.
|
||||
|
||||
* **strategy:** the strategy to use. refer to `ratelimit-strategy`
|
||||
|
||||
* **storage_uri**: the storage location. refer to `ratelimit-conf`
|
||||
|
||||
* **storage_options**: kwargs to pass to the storage implementation upon
|
||||
instantiation.
|
||||
* **auto_check**: whether to automatically check the rate limit in the before_request
|
||||
chain of the application. default ``True``
|
||||
* **swallow_errors**: whether to swallow errors when hitting a rate limit.
|
||||
An exception will still be logged. default ``False``
|
||||
* **in_memory_fallback**: a variable list of strings or callables returning strings denoting fallback
|
||||
limits to apply when the storage is down.
|
||||
* **in_memory_fallback_enabled**: simply falls back to in memory storage
|
||||
when the main storage is down and inherits the original limits.
|
||||
* **key_prefix**: prefix prepended to rate limiter keys.
|
||||
* **enabled**: set to False to deactivate the limiter (default: True)
|
||||
* **config_filename**: name of the config file for Starlette from which to load settings
|
||||
for the rate limiter. Defaults to ".env".
|
||||
* **key_style**: set to "url" to use the url, "endpoint" to use the view_func
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
# app: Starlette = None,
|
||||
key_func: Callable[..., str],
|
||||
default_limits: List[StrOrCallableStr] = [],
|
||||
application_limits: List[StrOrCallableStr] = [],
|
||||
headers_enabled: bool = False,
|
||||
strategy: Optional[str] = None,
|
||||
storage_uri: Optional[str] = None,
|
||||
storage_options: Dict[str, str] = {},
|
||||
auto_check: bool = True,
|
||||
swallow_errors: bool = False,
|
||||
in_memory_fallback: List[StrOrCallableStr] = [],
|
||||
in_memory_fallback_enabled: bool = False,
|
||||
retry_after: Optional[str] = None,
|
||||
key_prefix: str = "",
|
||||
enabled: bool = True,
|
||||
config_filename: Optional[str] = None,
|
||||
key_style: Literal["endpoint", "url"] = "url",
|
||||
) -> None:
|
||||
"""
|
||||
Configure the rate limiter at app level
|
||||
"""
|
||||
# assert app is not None, "Passing the app instance to the limiter is required"
|
||||
# self.app = app
|
||||
# app.state.limiter = self
|
||||
|
||||
self.logger = logging.getLogger("slowapi")
|
||||
|
||||
dotenv_file_exists = os.path.isfile(".env")
|
||||
self.app_config = Config(
|
||||
".env"
|
||||
if dotenv_file_exists and config_filename is None
|
||||
else config_filename
|
||||
)
|
||||
|
||||
self.enabled = enabled
|
||||
self._default_limits = []
|
||||
self._application_limits = []
|
||||
self._in_memory_fallback: List[LimitGroup] = []
|
||||
self._in_memory_fallback_enabled = (
|
||||
in_memory_fallback_enabled or len(in_memory_fallback) > 0
|
||||
)
|
||||
self._exempt_routes: Set[str] = set()
|
||||
self._request_filters: List[Callable[..., bool]] = []
|
||||
self._headers_enabled = headers_enabled
|
||||
self._header_mapping: Dict[int, str] = {}
|
||||
self._retry_after: Optional[str] = retry_after
|
||||
self._strategy = strategy
|
||||
self._storage_uri = storage_uri
|
||||
self._storage_options = storage_options
|
||||
self._auto_check = auto_check
|
||||
self._swallow_errors = swallow_errors
|
||||
|
||||
self._key_func = key_func
|
||||
self._key_prefix = key_prefix
|
||||
self._key_style = key_style
|
||||
|
||||
for limit in set(default_limits):
|
||||
self._default_limits.extend(
|
||||
[
|
||||
LimitGroup(
|
||||
limit, self._key_func, None, False, None, None, None, 1, False
|
||||
)
|
||||
]
|
||||
)
|
||||
for limit in application_limits:
|
||||
self._application_limits.extend(
|
||||
[
|
||||
LimitGroup(
|
||||
limit,
|
||||
self._key_func,
|
||||
"global",
|
||||
False,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
1,
|
||||
False,
|
||||
)
|
||||
]
|
||||
)
|
||||
for limit in in_memory_fallback:
|
||||
self._in_memory_fallback.extend(
|
||||
[
|
||||
LimitGroup(
|
||||
limit, self._key_func, None, False, None, None, None, 1, False
|
||||
)
|
||||
]
|
||||
)
|
||||
self._route_limits: Dict[str, List[Limit]] = {}
|
||||
self._dynamic_route_limits: Dict[str, List[LimitGroup]] = {}
|
||||
# a flag to note if the storage backend is dead (not available)
|
||||
self._storage_dead: bool = False
|
||||
self._fallback_limiter = None
|
||||
self.__check_backend_count = 0
|
||||
self.__last_check_backend = time.time()
|
||||
self.__marked_for_limiting: Dict[str, List[Callable]] = {}
|
||||
|
||||
class BlackHoleHandler(logging.StreamHandler):
|
||||
def emit(*_):
|
||||
return
|
||||
|
||||
self.logger.addHandler(BlackHoleHandler())
|
||||
|
||||
self.enabled = self.get_app_config(C.ENABLED, self.enabled)
|
||||
self._swallow_errors = self.get_app_config(
|
||||
C.SWALLOW_ERRORS, self._swallow_errors
|
||||
)
|
||||
self._headers_enabled = self._headers_enabled or self.get_app_config(
|
||||
C.HEADERS_ENABLED, False
|
||||
)
|
||||
self._storage_options.update(self.get_app_config(C.STORAGE_OPTIONS, {}))
|
||||
self._storage = storage_from_string(
|
||||
self._storage_uri or self.get_app_config(C.STORAGE_URL, "memory://"),
|
||||
**self._storage_options,
|
||||
)
|
||||
strategy = self._strategy or self.get_app_config(C.STRATEGY, "fixed-window")
|
||||
if strategy not in STRATEGIES:
|
||||
raise ConfigurationError("Invalid rate limiting strategy %s" % strategy)
|
||||
self._limiter: RateLimiter = STRATEGIES[strategy](self._storage)
|
||||
self._header_mapping.update(
|
||||
{
|
||||
HEADERS.RESET: self._header_mapping.get(
|
||||
HEADERS.RESET,
|
||||
self.get_app_config(C.HEADER_RESET, "X-RateLimit-Reset"),
|
||||
),
|
||||
HEADERS.REMAINING: self._header_mapping.get(
|
||||
HEADERS.REMAINING,
|
||||
self.get_app_config(C.HEADER_REMAINING, "X-RateLimit-Remaining"),
|
||||
),
|
||||
HEADERS.LIMIT: self._header_mapping.get(
|
||||
HEADERS.LIMIT,
|
||||
self.get_app_config(C.HEADER_LIMIT, "X-RateLimit-Limit"),
|
||||
),
|
||||
HEADERS.RETRY_AFTER: self._header_mapping.get(
|
||||
HEADERS.RETRY_AFTER,
|
||||
self.get_app_config(C.HEADER_RETRY_AFTER, "Retry-After"),
|
||||
),
|
||||
}
|
||||
)
|
||||
self._retry_after = self._retry_after or self.get_app_config(
|
||||
C.HEADER_RETRY_AFTER_VALUE
|
||||
)
|
||||
self._key_prefix = self._key_prefix or self.get_app_config(C.KEY_PREFIX)
|
||||
app_limits: Optional[StrOrCallableStr] = self.get_app_config(
|
||||
C.APPLICATION_LIMITS, None
|
||||
)
|
||||
if not self._application_limits and app_limits:
|
||||
self._application_limits = [
|
||||
LimitGroup(
|
||||
app_limits,
|
||||
self._key_func,
|
||||
"global",
|
||||
False,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
1,
|
||||
False,
|
||||
)
|
||||
]
|
||||
|
||||
conf_limits: Optional[StrOrCallableStr] = self.get_app_config(
|
||||
C.DEFAULT_LIMITS, None
|
||||
)
|
||||
if not self._default_limits and conf_limits:
|
||||
self._default_limits = [
|
||||
LimitGroup(
|
||||
conf_limits, self._key_func, None, False, None, None, None, 1, False
|
||||
)
|
||||
]
|
||||
fallback_enabled = self.get_app_config(C.IN_MEMORY_FALLBACK_ENABLED, False)
|
||||
fallback_limits: Optional[StrOrCallableStr] = self.get_app_config(
|
||||
C.IN_MEMORY_FALLBACK, None
|
||||
)
|
||||
if not self._in_memory_fallback and fallback_limits:
|
||||
self._in_memory_fallback = [
|
||||
LimitGroup(
|
||||
fallback_limits,
|
||||
self._key_func,
|
||||
None,
|
||||
False,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
1,
|
||||
False,
|
||||
)
|
||||
]
|
||||
if not self._in_memory_fallback_enabled:
|
||||
self._in_memory_fallback_enabled = (
|
||||
fallback_enabled or len(self._in_memory_fallback) > 0
|
||||
)
|
||||
|
||||
if self._in_memory_fallback_enabled:
|
||||
self._fallback_storage = MemoryStorage()
|
||||
self._fallback_limiter = STRATEGIES[strategy](self._fallback_storage)
|
||||
|
||||
def slowapi_startup(self) -> None:
|
||||
"""
|
||||
Starlette startup event handler that links the app with the Limiter instance.
|
||||
"""
|
||||
app.state.limiter = self # type: ignore
|
||||
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler) # type: ignore
|
||||
|
||||
def get_app_config(self, key: str, default_value: T = None) -> T:
|
||||
"""
|
||||
Place holder until we find a better way to load config from app
|
||||
"""
|
||||
return (
|
||||
self.app_config(key, default=default_value, cast=type(default_value))
|
||||
if default_value
|
||||
else self.app_config(key, default=default_value)
|
||||
)
|
||||
|
||||
def __should_check_backend(self) -> bool:
|
||||
if self.__check_backend_count > MAX_BACKEND_CHECKS:
|
||||
self.__check_backend_count = 0
|
||||
if time.time() - self.__last_check_backend > pow(2, self.__check_backend_count):
|
||||
self.__last_check_backend = time.time()
|
||||
self.__check_backend_count += 1
|
||||
return True
|
||||
return False
|
||||
|
||||
def reset(self) -> None:
|
||||
"""
|
||||
resets the storage if it supports being reset
|
||||
"""
|
||||
try:
|
||||
self._storage.reset()
|
||||
self.logger.info("Storage has been reset and all limits cleared")
|
||||
except NotImplementedError:
|
||||
self.logger.warning("This storage type does not support being reset")
|
||||
|
||||
@property
|
||||
def limiter(self) -> RateLimiter:
|
||||
"""
|
||||
The backend that keeps track of consumption of endpoints vs limits
|
||||
"""
|
||||
if self._storage_dead and self._in_memory_fallback_enabled:
|
||||
assert (
|
||||
self._fallback_limiter
|
||||
), "Fallback limiter is needed when in memory fallback is enabled"
|
||||
return self._fallback_limiter
|
||||
else:
|
||||
return self._limiter
|
||||
|
||||
def _inject_headers(
|
||||
self, response: Response, current_limit: Tuple[RateLimitItem, List[str]]
|
||||
) -> Response:
|
||||
if self.enabled and self._headers_enabled and current_limit is not None:
|
||||
if not isinstance(response, Response):
|
||||
raise Exception(
|
||||
"parameter `response` must be an instance of starlette.responses.Response"
|
||||
)
|
||||
try:
|
||||
window_stats: Tuple[int, int] = self.limiter.get_window_stats(
|
||||
current_limit[0], *current_limit[1]
|
||||
)
|
||||
reset_in = 1 + window_stats[0]
|
||||
response.headers.append(
|
||||
self._header_mapping[HEADERS.LIMIT], str(current_limit[0].amount)
|
||||
)
|
||||
response.headers.append(
|
||||
self._header_mapping[HEADERS.REMAINING], str(window_stats[1])
|
||||
)
|
||||
response.headers.append(
|
||||
self._header_mapping[HEADERS.RESET], str(reset_in)
|
||||
)
|
||||
|
||||
# response may have an existing retry after
|
||||
existing_retry_after_header = response.headers.get("Retry-After")
|
||||
|
||||
if existing_retry_after_header is not None:
|
||||
reset_in = max(
|
||||
self._determine_retry_time(existing_retry_after_header),
|
||||
reset_in,
|
||||
)
|
||||
|
||||
response.headers[self._header_mapping[HEADERS.RETRY_AFTER]] = (
|
||||
formatdate(reset_in)
|
||||
if self._retry_after == "http-date"
|
||||
else str(int(reset_in - time.time()))
|
||||
)
|
||||
except:
|
||||
if self._in_memory_fallback and not self._storage_dead:
|
||||
self.logger.warning(
|
||||
"Rate limit storage unreachable - falling back to"
|
||||
" in-memory storage"
|
||||
)
|
||||
self._storage_dead = True
|
||||
response = self._inject_headers(response, current_limit)
|
||||
if self._swallow_errors:
|
||||
self.logger.exception(
|
||||
"Failed to update rate limit headers. Swallowing error"
|
||||
)
|
||||
else:
|
||||
raise
|
||||
return response
|
||||
|
||||
def _inject_asgi_headers(
|
||||
self, headers: MutableHeaders, current_limit: Tuple[RateLimitItem, List[str]]
|
||||
) -> MutableHeaders:
|
||||
"""
|
||||
Injects 'X-RateLimit-Reset', 'X-RateLimit-Remaining', 'X-RateLimit-Limit'
|
||||
and 'Retry-After' headers into :headers parameter if needed.
|
||||
|
||||
Basically the same as _inject_headers, but without access to the Response object.
|
||||
-> supports ASGI Middlewares.
|
||||
"""
|
||||
if self.enabled and self._headers_enabled and current_limit is not None:
|
||||
try:
|
||||
window_stats: Tuple[int, int] = self.limiter.get_window_stats(
|
||||
current_limit[0], *current_limit[1]
|
||||
)
|
||||
reset_in = 1 + window_stats[0]
|
||||
headers[self._header_mapping[HEADERS.LIMIT]] = str(
|
||||
current_limit[0].amount
|
||||
)
|
||||
headers[self._header_mapping[HEADERS.REMAINING]] = str(window_stats[1])
|
||||
headers[self._header_mapping[HEADERS.RESET]] = str(reset_in)
|
||||
|
||||
# response may have an existing retry after
|
||||
existing_retry_after_header = headers.get("Retry-After")
|
||||
|
||||
if existing_retry_after_header is not None:
|
||||
reset_in = max(
|
||||
self._determine_retry_time(existing_retry_after_header),
|
||||
reset_in,
|
||||
)
|
||||
|
||||
headers[self._header_mapping[HEADERS.RETRY_AFTER]] = (
|
||||
formatdate(reset_in)
|
||||
if self._retry_after == "http-date"
|
||||
else str(int(reset_in - time.time()))
|
||||
)
|
||||
except Exception:
|
||||
if self._in_memory_fallback and not self._storage_dead:
|
||||
self.logger.warning(
|
||||
"Rate limit storage unreachable - falling back to"
|
||||
" in-memory storage"
|
||||
)
|
||||
self._storage_dead = True
|
||||
headers = self._inject_asgi_headers(headers, current_limit)
|
||||
if self._swallow_errors:
|
||||
self.logger.exception(
|
||||
"Failed to update rate limit headers. Swallowing error"
|
||||
)
|
||||
else:
|
||||
raise
|
||||
return headers
|
||||
|
||||
def __evaluate_limits(
|
||||
self, request: Request, endpoint: str, limits: List[Limit]
|
||||
) -> None:
|
||||
failed_limit = None
|
||||
limit_for_header = None
|
||||
for lim in limits:
|
||||
limit_scope = lim.scope or endpoint
|
||||
if lim.is_exempt:
|
||||
continue
|
||||
if lim.methods is not None and request.method.lower() not in lim.methods:
|
||||
continue
|
||||
if lim.per_method:
|
||||
limit_scope += ":%s" % request.method
|
||||
|
||||
if "request" in inspect.signature(lim.key_func).parameters.keys():
|
||||
limit_key = lim.key_func(request)
|
||||
else:
|
||||
limit_key = lim.key_func()
|
||||
|
||||
args = [limit_key, limit_scope]
|
||||
if all(args):
|
||||
if self._key_prefix:
|
||||
args = [self._key_prefix] + args
|
||||
if not limit_for_header or lim.limit < limit_for_header[0]:
|
||||
limit_for_header = (lim.limit, args)
|
||||
|
||||
cost = lim.cost(request) if callable(lim.cost) else lim.cost
|
||||
if not self.limiter.hit(lim.limit, *args, cost=cost):
|
||||
self.logger.warning(
|
||||
"ratelimit %s (%s) exceeded at endpoint: %s",
|
||||
lim.limit,
|
||||
limit_key,
|
||||
limit_scope,
|
||||
)
|
||||
failed_limit = lim
|
||||
limit_for_header = (lim.limit, args)
|
||||
break
|
||||
else:
|
||||
self.logger.error(
|
||||
"Skipping limit: %s. Empty value found in parameters.", lim.limit
|
||||
)
|
||||
continue
|
||||
# keep track of which limit was hit, to be picked up for the response header
|
||||
request.state.view_rate_limit = limit_for_header
|
||||
|
||||
if failed_limit:
|
||||
raise RateLimitExceeded(failed_limit)
|
||||
|
||||
def _determine_retry_time(self, retry_header_value) -> int:
|
||||
try:
|
||||
retry_after_date: Optional[datetime] = parsedate_to_datetime(
|
||||
retry_header_value
|
||||
)
|
||||
except (TypeError, ValueError):
|
||||
retry_after_date = None
|
||||
|
||||
if retry_after_date is not None:
|
||||
return int(time.mktime(retry_after_date.timetuple()))
|
||||
|
||||
try:
|
||||
retry_after_int: int = int(retry_header_value)
|
||||
except TypeError:
|
||||
raise ValueError(
|
||||
"Retry-After Header does not meet RFC2616 - value is not of http-date or int type."
|
||||
)
|
||||
|
||||
return int(time.time() + retry_after_int)
|
||||
|
||||
def _check_request_limit(
|
||||
self,
|
||||
request: Request,
|
||||
endpoint_func: Optional[Callable[..., Any]],
|
||||
in_middleware: bool = True,
|
||||
) -> None:
|
||||
"""
|
||||
Determine if the request is within limits
|
||||
"""
|
||||
endpoint_url = request["path"] or ""
|
||||
view_func = endpoint_func
|
||||
|
||||
endpoint_func_name = (
|
||||
f"{view_func.__module__}.{view_func.__name__}" if view_func else ""
|
||||
)
|
||||
_endpoint_key = endpoint_url if self._key_style == "url" else endpoint_func_name
|
||||
# cases where we don't need to check the limits
|
||||
if (
|
||||
not _endpoint_key
|
||||
or not self.enabled
|
||||
# or we are sending a static file
|
||||
# or view_func == current_app.send_static_file
|
||||
or endpoint_func_name in self._exempt_routes
|
||||
or any(fn() for fn in self._request_filters)
|
||||
):
|
||||
return
|
||||
limits: List[Limit] = []
|
||||
dynamic_limits: List[Limit] = []
|
||||
|
||||
if not in_middleware:
|
||||
limits = (
|
||||
self._route_limits[endpoint_func_name]
|
||||
if endpoint_func_name in self._route_limits
|
||||
else []
|
||||
)
|
||||
dynamic_limits = []
|
||||
if endpoint_func_name in self._dynamic_route_limits:
|
||||
for lim in self._dynamic_route_limits[endpoint_func_name]:
|
||||
try:
|
||||
dynamic_limits.extend(list(lim.with_request(request)))
|
||||
except ValueError as e:
|
||||
self.logger.error(
|
||||
"failed to load ratelimit for view function %s (%s)",
|
||||
endpoint_func_name,
|
||||
e,
|
||||
)
|
||||
|
||||
try:
|
||||
all_limits: List[Limit] = []
|
||||
if self._storage_dead and self._fallback_limiter:
|
||||
if in_middleware and endpoint_func_name in self.__marked_for_limiting:
|
||||
pass
|
||||
else:
|
||||
if self.__should_check_backend() and self._storage.check():
|
||||
self.logger.info("Rate limit storage recovered")
|
||||
self._storage_dead = False
|
||||
self.__check_backend_count = 0
|
||||
else:
|
||||
all_limits = list(itertools.chain(*self._in_memory_fallback))
|
||||
if not all_limits:
|
||||
route_limits: List[Limit] = limits + dynamic_limits
|
||||
all_limits = (
|
||||
list(itertools.chain(*self._application_limits))
|
||||
if in_middleware
|
||||
else []
|
||||
)
|
||||
all_limits += route_limits
|
||||
combined_defaults = all(
|
||||
not limit.override_defaults for limit in route_limits
|
||||
)
|
||||
if (
|
||||
not route_limits
|
||||
and not (
|
||||
in_middleware
|
||||
and endpoint_func_name in self.__marked_for_limiting
|
||||
)
|
||||
or combined_defaults
|
||||
):
|
||||
all_limits += list(itertools.chain(*self._default_limits))
|
||||
# actually check the limits, so far we've only computed the list of limits to check
|
||||
self.__evaluate_limits(request, _endpoint_key, all_limits)
|
||||
except Exception as e: # no qa
|
||||
if isinstance(e, RateLimitExceeded):
|
||||
raise
|
||||
if self._in_memory_fallback_enabled and not self._storage_dead:
|
||||
self.logger.warn(
|
||||
"Rate limit storage unreachable - falling back to"
|
||||
" in-memory storage"
|
||||
)
|
||||
self._storage_dead = True
|
||||
self._check_request_limit(request, endpoint_func, in_middleware)
|
||||
else:
|
||||
if self._swallow_errors:
|
||||
self.logger.exception("Failed to rate limit. Swallowing error")
|
||||
else:
|
||||
raise
|
||||
|
||||
def __limit_decorator(
|
||||
self,
|
||||
limit_value: StrOrCallableStr,
|
||||
key_func: Optional[Callable[..., str]] = None,
|
||||
shared: bool = False,
|
||||
scope: Optional[StrOrCallableStr] = None,
|
||||
per_method: bool = False,
|
||||
methods: Optional[List[str]] = None,
|
||||
error_message: Optional[str] = None,
|
||||
exempt_when: Optional[Callable[..., bool]] = None,
|
||||
cost: Union[int, Callable[..., int]] = 1,
|
||||
override_defaults: bool = True,
|
||||
) -> Callable[..., Any]:
|
||||
_scope = scope if shared else None
|
||||
|
||||
def decorator(func: Callable[..., Response]):
|
||||
keyfunc = key_func or self._key_func
|
||||
name = f"{func.__module__}.{func.__name__}"
|
||||
dynamic_limit = None
|
||||
static_limits: List[Limit] = []
|
||||
if callable(limit_value):
|
||||
dynamic_limit = LimitGroup(
|
||||
limit_value,
|
||||
keyfunc,
|
||||
_scope,
|
||||
per_method,
|
||||
methods,
|
||||
error_message,
|
||||
exempt_when,
|
||||
cost,
|
||||
override_defaults,
|
||||
)
|
||||
else:
|
||||
try:
|
||||
static_limits = list(
|
||||
LimitGroup(
|
||||
limit_value,
|
||||
keyfunc,
|
||||
_scope,
|
||||
per_method,
|
||||
methods,
|
||||
error_message,
|
||||
exempt_when,
|
||||
cost,
|
||||
override_defaults,
|
||||
)
|
||||
)
|
||||
except ValueError as e:
|
||||
self.logger.error(
|
||||
"Failed to configure throttling for %s (%s)",
|
||||
name,
|
||||
e,
|
||||
)
|
||||
self.__marked_for_limiting.setdefault(name, []).append(func)
|
||||
if dynamic_limit:
|
||||
self._dynamic_route_limits.setdefault(name, []).append(dynamic_limit)
|
||||
else:
|
||||
self._route_limits.setdefault(name, []).extend(static_limits)
|
||||
|
||||
connection_type: Optional[str] = None
|
||||
sig = inspect.signature(func)
|
||||
for idx, parameter in enumerate(sig.parameters.values()):
|
||||
if parameter.name == "request" or parameter.name == "websocket":
|
||||
connection_type = parameter.name
|
||||
break
|
||||
else:
|
||||
raise Exception(
|
||||
f'No "request" or "websocket" argument on function "{func}"'
|
||||
)
|
||||
|
||||
if asyncio.iscoroutinefunction(func):
|
||||
# Handle async request/response functions.
|
||||
@functools.wraps(func)
|
||||
async def async_wrapper(*args: Any, **kwargs: Any) -> Response:
|
||||
# get the request object from the decorated endpoint function
|
||||
if self.enabled:
|
||||
request = kwargs.get("request", args[idx] if args else None)
|
||||
if not isinstance(request, Request):
|
||||
raise Exception(
|
||||
"parameter `request` must be an instance of starlette.requests.Request"
|
||||
)
|
||||
|
||||
if self._auto_check and not getattr(
|
||||
request.state, "_rate_limiting_complete", False
|
||||
):
|
||||
self._check_request_limit(request, func, False)
|
||||
request.state._rate_limiting_complete = True
|
||||
response = await func(*args, **kwargs) # type: ignore
|
||||
if self.enabled:
|
||||
if not isinstance(response, Response):
|
||||
# get the response object from the decorated endpoint function
|
||||
self._inject_headers(
|
||||
kwargs.get("response"), request.state.view_rate_limit # type: ignore
|
||||
)
|
||||
else:
|
||||
self._inject_headers(
|
||||
response, request.state.view_rate_limit
|
||||
)
|
||||
return response
|
||||
|
||||
return async_wrapper
|
||||
|
||||
else:
|
||||
# Handle sync request/response functions.
|
||||
@functools.wraps(func)
|
||||
def sync_wrapper(*args: Any, **kwargs: Any) -> Response:
|
||||
# get the request object from the decorated endpoint function
|
||||
if self.enabled:
|
||||
request = kwargs.get("request", args[idx] if args else None)
|
||||
if not isinstance(request, Request):
|
||||
raise Exception(
|
||||
"parameter `request` must be an instance of starlette.requests.Request"
|
||||
)
|
||||
|
||||
if self._auto_check and not getattr(
|
||||
request.state, "_rate_limiting_complete", False
|
||||
):
|
||||
self._check_request_limit(request, func, False)
|
||||
request.state._rate_limiting_complete = True
|
||||
response = func(*args, **kwargs)
|
||||
if self.enabled:
|
||||
if not isinstance(response, Response):
|
||||
# get the response object from the decorated endpoint function
|
||||
self._inject_headers(
|
||||
kwargs.get("response"), request.state.view_rate_limit # type: ignore
|
||||
)
|
||||
else:
|
||||
self._inject_headers(
|
||||
response, request.state.view_rate_limit
|
||||
)
|
||||
return response
|
||||
|
||||
return sync_wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
def limit(
|
||||
self,
|
||||
limit_value: StrOrCallableStr,
|
||||
key_func: Optional[Callable[..., str]] = None,
|
||||
per_method: bool = False,
|
||||
methods: Optional[List[str]] = None,
|
||||
error_message: Optional[str] = None,
|
||||
exempt_when: Optional[Callable[..., bool]] = None,
|
||||
cost: Union[int, Callable[..., int]] = 1,
|
||||
override_defaults: bool = True,
|
||||
) -> Callable:
|
||||
"""
|
||||
Decorator to be used for rate limiting individual routes.
|
||||
|
||||
* **limit_value**: rate limit string or a callable that returns a string.
|
||||
:ref:`ratelimit-string` for more details.
|
||||
* **key_func**: function/lambda to extract the unique identifier for
|
||||
the rate limit. defaults to remote address of the request.
|
||||
* **per_method**: whether the limit is sub categorized into the http
|
||||
method of the request.
|
||||
* **methods**: if specified, only the methods in this list will be rate
|
||||
limited (default: None).
|
||||
* **error_message**: string (or callable that returns one) to override the
|
||||
error message used in the response.
|
||||
* **exempt_when**: function returning a boolean indicating whether to exempt
|
||||
the route from the limit
|
||||
* **cost**: integer (or callable that returns one) which is the cost of a hit
|
||||
* **override_defaults**: whether to override the default limits (default: True)
|
||||
"""
|
||||
return self.__limit_decorator(
|
||||
limit_value,
|
||||
key_func,
|
||||
per_method=per_method,
|
||||
methods=methods,
|
||||
error_message=error_message,
|
||||
exempt_when=exempt_when,
|
||||
cost=cost,
|
||||
override_defaults=override_defaults,
|
||||
)
|
||||
|
||||
def shared_limit(
|
||||
self,
|
||||
limit_value: StrOrCallableStr,
|
||||
scope: StrOrCallableStr,
|
||||
key_func: Optional[Callable[..., str]] = None,
|
||||
error_message: Optional[str] = None,
|
||||
exempt_when: Optional[Callable[..., bool]] = None,
|
||||
cost: Union[int, Callable[..., int]] = 1,
|
||||
override_defaults: bool = True,
|
||||
) -> Callable:
|
||||
"""
|
||||
Decorator to be applied to multiple routes sharing the same rate limit.
|
||||
|
||||
* **limit_value**: rate limit string or a callable that returns a string.
|
||||
:ref:`ratelimit-string` for more details.
|
||||
* **scope**: a string or callable that returns a string
|
||||
for defining the rate limiting scope.
|
||||
* **key_func**: function/lambda to extract the unique identifier for
|
||||
the rate limit. defaults to remote address of the request.
|
||||
* **per_method**: whether the limit is sub categorized into the http
|
||||
method of the request.
|
||||
* **methods**: if specified, only the methods in this list will be rate
|
||||
limited (default: None).
|
||||
* **error_message**: string (or callable that returns one) to override the
|
||||
error message used in the response.
|
||||
* **exempt_when**: function returning a boolean indicating whether to exempt
|
||||
the route from the limit
|
||||
* **cost**: integer (or callable that returns one) which is the cost of a hit
|
||||
* **override_defaults**: whether to override the default limits (default: True)
|
||||
"""
|
||||
return self.__limit_decorator(
|
||||
limit_value,
|
||||
key_func,
|
||||
True,
|
||||
scope,
|
||||
error_message=error_message,
|
||||
exempt_when=exempt_when,
|
||||
cost=cost,
|
||||
override_defaults=override_defaults,
|
||||
)
|
||||
|
||||
def exempt(self, obj):
|
||||
"""
|
||||
Decorator to mark a view as exempt from rate limits.
|
||||
"""
|
||||
name = "%s.%s" % (obj.__module__, obj.__name__)
|
||||
|
||||
self._exempt_routes.add(name)
|
||||
|
||||
if asyncio.iscoroutinefunction(obj):
|
||||
|
||||
@wraps(obj)
|
||||
async def __async_inner(*a, **k):
|
||||
return await obj(*a, **k)
|
||||
|
||||
return __async_inner
|
||||
else:
|
||||
|
||||
@wraps(obj)
|
||||
def __inner(*a, **k):
|
||||
return obj(*a, **k)
|
||||
|
||||
return __inner
|
206
venv/lib/python3.11/site-packages/slowapi/middleware.py
Normal file
206
venv/lib/python3.11/site-packages/slowapi/middleware.py
Normal file
@ -0,0 +1,206 @@
|
||||
import inspect
|
||||
from typing import Callable, Iterable, Optional, Tuple
|
||||
|
||||
from starlette.applications import Starlette
|
||||
from starlette.datastructures import MutableHeaders
|
||||
from starlette.middleware.base import (
|
||||
BaseHTTPMiddleware,
|
||||
RequestResponseEndpoint,
|
||||
)
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import Response
|
||||
from starlette.routing import BaseRoute, Match
|
||||
from starlette.types import ASGIApp, Message, Scope, Receive, Send
|
||||
|
||||
from slowapi import Limiter, _rate_limit_exceeded_handler
|
||||
|
||||
|
||||
def _find_route_handler(
|
||||
routes: Iterable[BaseRoute], scope: Scope
|
||||
) -> Optional[Callable]:
|
||||
handler = None
|
||||
for route in routes:
|
||||
match, _ = route.matches(scope)
|
||||
if match == Match.FULL and hasattr(route, "endpoint"):
|
||||
handler = route.endpoint # type: ignore
|
||||
return handler
|
||||
|
||||
|
||||
def _get_route_name(handler: Callable):
|
||||
return f"{handler.__module__}.{handler.__name__}"
|
||||
|
||||
|
||||
def _check_limits(
|
||||
limiter: Limiter, request: Request, handler: Optional[Callable], app: Starlette
|
||||
) -> Tuple[Optional[Callable], bool, Optional[Exception]]:
|
||||
"""
|
||||
Utils to check (if needed) current requests limit.
|
||||
It returns a tuple of size 3:
|
||||
1. The exception handler to run, if needed
|
||||
2. a bool, True if we need to inject some headers, False otherwise
|
||||
3. the exception that happened, if any
|
||||
"""
|
||||
if limiter._auto_check and not getattr(
|
||||
request.state, "_rate_limiting_complete", False
|
||||
):
|
||||
try:
|
||||
limiter._check_request_limit(request, handler, True)
|
||||
except Exception as e:
|
||||
# handle the exception since the global exception handler won't pick it up if we call_next
|
||||
exception_handler = app.exception_handlers.get(
|
||||
type(e), _rate_limit_exceeded_handler
|
||||
)
|
||||
return exception_handler, False, e
|
||||
|
||||
return None, True, None
|
||||
return None, False, None
|
||||
|
||||
|
||||
def sync_check_limits(
|
||||
limiter: Limiter, request: Request, handler: Optional[Callable], app: Starlette
|
||||
) -> Tuple[Optional[Response], bool]:
|
||||
"""
|
||||
Returns a `Response` object if an error occurred, as well as a boolean to know
|
||||
whether we should inject headers or not.
|
||||
Used in our WSGI middleware, it only supports synchronous exception_handler.
|
||||
This will fallback on _rate_limit_exceeded_handler otherwise.
|
||||
"""
|
||||
exception_handler, _bool, exc = _check_limits(limiter, request, handler, app)
|
||||
if not exception_handler or not exc:
|
||||
return None, _bool
|
||||
|
||||
# cannot execute asynchronous code in a synchronous middleware,
|
||||
# -> fallback on default exception handler
|
||||
if inspect.iscoroutinefunction(exception_handler):
|
||||
exception_handler = _rate_limit_exceeded_handler
|
||||
|
||||
return exception_handler(request, exc), _bool # type: ignore
|
||||
|
||||
|
||||
async def async_check_limits(
|
||||
limiter: Limiter, request: Request, handler: Optional[Callable], app: Starlette
|
||||
) -> Tuple[Optional[Response], bool]:
|
||||
"""
|
||||
Returns a `Response` object if an error occurred, as well as a boolean to know
|
||||
whether we should inject headers or not.
|
||||
Used in our ASGI middleware, this support both synchronous or asynchronous exception handlers.
|
||||
"""
|
||||
exception_handler, _bool, exc = _check_limits(limiter, request, handler, app)
|
||||
if not exception_handler:
|
||||
return None, _bool
|
||||
|
||||
if inspect.iscoroutinefunction(exception_handler):
|
||||
return await exception_handler(request, exc), _bool
|
||||
else:
|
||||
return exception_handler(request, exc), _bool
|
||||
|
||||
|
||||
def _should_exempt(limiter: Limiter, handler: Optional[Callable]) -> bool:
|
||||
# if we can't find the route handler
|
||||
if handler is None:
|
||||
return True
|
||||
|
||||
name = _get_route_name(handler)
|
||||
|
||||
# if exempt no need to check
|
||||
if name in limiter._exempt_routes:
|
||||
return True
|
||||
|
||||
# there is a decorator for this route we let the decorator handle it
|
||||
if name in limiter._route_limits:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
class SlowAPIMiddleware(BaseHTTPMiddleware):
|
||||
async def dispatch(
|
||||
self, request: Request, call_next: RequestResponseEndpoint
|
||||
) -> Response:
|
||||
app: Starlette = request.app
|
||||
limiter: Limiter = app.state.limiter
|
||||
|
||||
if not limiter.enabled:
|
||||
return await call_next(request)
|
||||
|
||||
handler = _find_route_handler(app.routes, request.scope)
|
||||
if _should_exempt(limiter, handler):
|
||||
return await call_next(request)
|
||||
|
||||
error_response, should_inject_headers = sync_check_limits(
|
||||
limiter, request, handler, app
|
||||
)
|
||||
if error_response is not None:
|
||||
return error_response
|
||||
|
||||
response = await call_next(request)
|
||||
if should_inject_headers:
|
||||
response = limiter._inject_headers(response, request.state.view_rate_limit)
|
||||
return response
|
||||
|
||||
|
||||
class SlowAPIASGIMiddleware:
|
||||
def __init__(self, app: ASGIApp) -> None:
|
||||
self.app = app
|
||||
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
if scope["type"] != "http":
|
||||
return await self.app(scope, receive, send)
|
||||
|
||||
await _ASGIMiddlewareResponder(self.app)(scope, receive, send)
|
||||
|
||||
|
||||
class _ASGIMiddlewareResponder:
|
||||
def __init__(self, app: ASGIApp) -> None:
|
||||
self.app = app
|
||||
self.error_response: Optional[Response] = None
|
||||
self.initial_message: Message = {}
|
||||
self.inject_headers = False
|
||||
|
||||
async def send_wrapper(self, message: Message) -> None:
|
||||
if message["type"] == "http.response.start":
|
||||
# do not send the http.response.start message now, so that we can edit the headers
|
||||
# before sending it, based on what happens in the http.response.body message.
|
||||
self.initial_message = message
|
||||
|
||||
elif message["type"] == "http.response.body":
|
||||
if self.error_response:
|
||||
self.initial_message["status"] = self.error_response.status_code
|
||||
|
||||
if self.inject_headers:
|
||||
headers = MutableHeaders(raw=self.initial_message["headers"])
|
||||
headers = self.limiter._inject_asgi_headers(
|
||||
headers, self.request.state.view_rate_limit
|
||||
)
|
||||
|
||||
# send the http.response.start message just before the http.response.body one,
|
||||
# now that the headers are updated
|
||||
await self.send(self.initial_message)
|
||||
await self.send(message)
|
||||
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
self.send = send
|
||||
|
||||
_app: Starlette = scope["app"]
|
||||
limiter: Limiter = _app.state.limiter
|
||||
|
||||
if not limiter.enabled:
|
||||
return await self.app(scope, receive, self.send)
|
||||
|
||||
handler = _find_route_handler(_app.routes, scope)
|
||||
request = Request(scope, receive=receive, send=self.send)
|
||||
if _should_exempt(limiter, handler):
|
||||
return await self.app(scope, receive, self.send)
|
||||
|
||||
error_response, should_inject_headers = await async_check_limits(
|
||||
limiter, request, handler, _app
|
||||
)
|
||||
if error_response is not None:
|
||||
return await error_response(scope, receive, self.send_wrapper)
|
||||
|
||||
if should_inject_headers:
|
||||
self.inject_headers = True
|
||||
self.limiter = limiter
|
||||
self.request = request
|
||||
|
||||
return await self.app(scope, receive, self.send_wrapper)
|
0
venv/lib/python3.11/site-packages/slowapi/py.typed
Normal file
0
venv/lib/python3.11/site-packages/slowapi/py.typed
Normal file
27
venv/lib/python3.11/site-packages/slowapi/util.py
Normal file
27
venv/lib/python3.11/site-packages/slowapi/util.py
Normal file
@ -0,0 +1,27 @@
|
||||
from starlette.requests import Request
|
||||
|
||||
|
||||
def get_ipaddr(request: Request) -> str:
|
||||
"""
|
||||
Returns the ip address for the current request (or 127.0.0.1 if none found)
|
||||
based on the X-Forwarded-For headers.
|
||||
Note that a more robust method for determining IP address of the client is
|
||||
provided by uvicorn's ProxyHeadersMiddleware.
|
||||
"""
|
||||
if "X_FORWARDED_FOR" in request.headers:
|
||||
return request.headers["X_FORWARDED_FOR"]
|
||||
else:
|
||||
if not request.client or not request.client.host:
|
||||
return "127.0.0.1"
|
||||
|
||||
return request.client.host
|
||||
|
||||
|
||||
def get_remote_address(request: Request) -> str:
|
||||
"""
|
||||
Returns the ip address for the current request (or 127.0.0.1 if none found)
|
||||
"""
|
||||
if not request.client or not request.client.host:
|
||||
return "127.0.0.1"
|
||||
|
||||
return request.client.host
|
113
venv/lib/python3.11/site-packages/slowapi/wrappers.py
Normal file
113
venv/lib/python3.11/site-packages/slowapi/wrappers.py
Normal file
@ -0,0 +1,113 @@
|
||||
import inspect
|
||||
from typing import Callable, Iterator, List, Optional, Union
|
||||
|
||||
from limits import RateLimitItem, parse_many # type: ignore
|
||||
|
||||
|
||||
class Limit(object):
|
||||
"""
|
||||
simple wrapper to encapsulate limits and their context
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
limit: RateLimitItem,
|
||||
key_func: Callable[..., str],
|
||||
scope: Optional[Union[str, Callable[..., str]]],
|
||||
per_method: bool,
|
||||
methods: Optional[List[str]],
|
||||
error_message: Optional[Union[str, Callable[..., str]]],
|
||||
exempt_when: Optional[Callable[..., bool]],
|
||||
cost: Union[int, Callable[..., int]],
|
||||
override_defaults: bool,
|
||||
) -> None:
|
||||
self.limit = limit
|
||||
self.key_func = key_func
|
||||
self.__scope = scope
|
||||
self.per_method = per_method
|
||||
self.methods = methods
|
||||
self.error_message = error_message
|
||||
self.exempt_when = exempt_when
|
||||
self.cost = cost
|
||||
self.override_defaults = override_defaults
|
||||
|
||||
@property
|
||||
def is_exempt(self) -> bool:
|
||||
"""
|
||||
Check if the limit is exempt.
|
||||
Return True to exempt the route from the limit.
|
||||
"""
|
||||
return self.exempt_when() if self.exempt_when is not None else False
|
||||
|
||||
@property
|
||||
def scope(self) -> str:
|
||||
# flack.request.endpoint is the name of the function for the endpoint
|
||||
# FIXME: how to get the request here?
|
||||
if self.__scope is None:
|
||||
return ""
|
||||
else:
|
||||
return (
|
||||
self.__scope(request.endpoint) # type: ignore
|
||||
if callable(self.__scope)
|
||||
else self.__scope
|
||||
)
|
||||
|
||||
|
||||
class LimitGroup(object):
|
||||
"""
|
||||
represents a group of related limits either from a string or a callable that returns one
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
limit_provider: Union[str, Callable[..., str]],
|
||||
key_function: Callable[..., str],
|
||||
scope: Optional[Union[str, Callable[..., str]]],
|
||||
per_method: bool,
|
||||
methods: Optional[List[str]],
|
||||
error_message: Optional[Union[str, Callable[..., str]]],
|
||||
exempt_when: Optional[Callable[..., bool]],
|
||||
cost: Union[int, Callable[..., int]],
|
||||
override_defaults: bool,
|
||||
):
|
||||
self.__limit_provider = limit_provider
|
||||
self.__scope = scope
|
||||
self.key_function = key_function
|
||||
self.per_method = per_method
|
||||
self.methods = methods and [m.lower() for m in methods] or methods
|
||||
self.error_message = error_message
|
||||
self.exempt_when = exempt_when
|
||||
self.cost = cost
|
||||
self.override_defaults = override_defaults
|
||||
self.request = None
|
||||
|
||||
def __iter__(self) -> Iterator[Limit]:
|
||||
if callable(self.__limit_provider):
|
||||
if "key" in inspect.signature(self.__limit_provider).parameters.keys():
|
||||
assert (
|
||||
"request" in inspect.signature(self.key_function).parameters.keys()
|
||||
), f"Limit provider function {self.key_function.__name__} needs a `request` argument"
|
||||
if self.request is None:
|
||||
raise Exception("`request` object can't be None")
|
||||
limit_raw = self.__limit_provider(self.key_function(self.request))
|
||||
else:
|
||||
limit_raw = self.__limit_provider()
|
||||
else:
|
||||
limit_raw = self.__limit_provider
|
||||
limit_items: List[RateLimitItem] = parse_many(limit_raw)
|
||||
for limit in limit_items:
|
||||
yield Limit(
|
||||
limit,
|
||||
self.key_function,
|
||||
self.__scope,
|
||||
self.per_method,
|
||||
self.methods,
|
||||
self.error_message,
|
||||
self.exempt_when,
|
||||
self.cost,
|
||||
self.override_defaults,
|
||||
)
|
||||
|
||||
def with_request(self, request):
|
||||
self.request = request
|
||||
return self
|
Reference in New Issue
Block a user