Files
2025-04-24 11:44:23 +02:00

221 lines
6.1 KiB
Python

from __future__ import annotations
import functools
from abc import ABC, abstractmethod
from deprecated.sphinx import versionadded
from limits import errors
from limits.storage.registry import StorageRegistry
from limits.typing import (
Any,
Awaitable,
Callable,
P,
R,
cast,
)
from limits.util import LazyDependency
def _wrap_errors(
fn: Callable[P, Awaitable[R]],
) -> Callable[P, Awaitable[R]]:
@functools.wraps(fn)
async def inner(*args: P.args, **kwargs: P.kwargs) -> R: # type: ignore[misc]
instance = cast(Storage, args[0])
try:
return await fn(*args, **kwargs)
except instance.base_exceptions as exc:
if instance.wrap_exceptions:
raise errors.StorageError(exc) from exc
raise
return inner
@versionadded(version="2.1")
class Storage(LazyDependency, metaclass=StorageRegistry):
"""
Base class to extend when implementing an async storage backend.
"""
STORAGE_SCHEME: list[str] | None
"""The storage schemes to register against this implementation"""
def __init_subclass__(cls, **kwargs: Any) -> None: # type:ignore[explicit-any]
super().__init_subclass__(**kwargs)
for method in {
"incr",
"get",
"get_expiry",
"check",
"reset",
"clear",
}:
setattr(cls, method, _wrap_errors(getattr(cls, method)))
super().__init_subclass__(**kwargs)
def __init__(
self,
uri: str | None = None,
wrap_exceptions: bool = False,
**options: float | str | bool,
) -> None:
"""
:param wrap_exceptions: Whether to wrap storage exceptions in
:exc:`limits.errors.StorageError` before raising it.
"""
super().__init__()
self.wrap_exceptions = wrap_exceptions
@property
@abstractmethod
def base_exceptions(self) -> type[Exception] | tuple[type[Exception], ...]:
raise NotImplementedError
@abstractmethod
async def incr(self, key: str, expiry: int, amount: int = 1) -> int:
"""
increments the counter for a given rate limit key
:param key: the key to increment
:param expiry: amount in seconds for the key to expire in
:param amount: the number to increment by
"""
raise NotImplementedError
@abstractmethod
async def get(self, key: str) -> int:
"""
:param key: the key to get the counter value for
"""
raise NotImplementedError
@abstractmethod
async def get_expiry(self, key: str) -> float:
"""
:param key: the key to get the expiry for
"""
raise NotImplementedError
@abstractmethod
async def check(self) -> bool:
"""
check if storage is healthy
"""
raise NotImplementedError
@abstractmethod
async def reset(self) -> int | None:
"""
reset storage to clear limits
"""
raise NotImplementedError
@abstractmethod
async def clear(self, key: str) -> None:
"""
resets the rate limit key
:param key: the key to clear rate limits for
"""
raise NotImplementedError
class MovingWindowSupport(ABC):
"""
Abstract base class for async storages that support
the :ref:`strategies:moving window` strategy
"""
def __init_subclass__(cls, **kwargs: Any) -> None: # type: ignore[explicit-any]
for method in {
"acquire_entry",
"get_moving_window",
}:
setattr(
cls,
method,
_wrap_errors(getattr(cls, method)),
)
super().__init_subclass__(**kwargs)
@abstractmethod
async def acquire_entry(
self, key: str, limit: int, expiry: int, amount: int = 1
) -> bool:
"""
:param key: rate limit key to acquire an entry in
:param limit: amount of entries allowed
:param expiry: expiry of the entry
:param amount: the number of entries to acquire
"""
raise NotImplementedError
@abstractmethod
async def get_moving_window(
self, key: str, limit: int, expiry: int
) -> tuple[float, int]:
"""
returns the starting point and the number of entries in the moving
window
:param key: rate limit key
:param expiry: expiry of entry
:return: (start of window, number of acquired entries)
"""
raise NotImplementedError
class SlidingWindowCounterSupport(ABC):
"""
Abstract base class for async storages that support
the :ref:`strategies:sliding window counter` strategy
"""
def __init_subclass__(cls, **kwargs: Any) -> None: # type: ignore[explicit-any]
for method in {"acquire_sliding_window_entry", "get_sliding_window"}:
setattr(
cls,
method,
_wrap_errors(getattr(cls, method)),
)
super().__init_subclass__(**kwargs)
@abstractmethod
async def acquire_sliding_window_entry(
self,
key: str,
limit: int,
expiry: int,
amount: int = 1,
) -> bool:
"""
Acquire an entry if the weighted count of the current and previous
windows is less than or equal to the limit
:param key: rate limit key to acquire an entry in
:param limit: amount of entries allowed
:param expiry: expiry of the entry
:param amount: the number of entries to acquire
"""
raise NotImplementedError
@abstractmethod
async def get_sliding_window(
self, key: str, expiry: int
) -> tuple[int, float, int, float]:
"""
Return the previous and current window information.
:param key: the rate limit key
:param expiry: the rate limit expiry, needed to compute the key in some implementations
:return: a tuple of (int, float, int, float) with the following information:
- previous window counter
- previous window TTL
- current window counter
- current window TTL
"""
raise NotImplementedError