221 lines
		
	
	
		
			6.1 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			221 lines
		
	
	
		
			6.1 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| from __future__ import annotations
 | |
| 
 | |
| import functools
 | |
| from abc import ABC, abstractmethod
 | |
| 
 | |
| from deprecated.sphinx import versionadded
 | |
| 
 | |
| from limits import errors
 | |
| from limits.storage.registry import StorageRegistry
 | |
| from limits.typing import (
 | |
|     Any,
 | |
|     Awaitable,
 | |
|     Callable,
 | |
|     P,
 | |
|     R,
 | |
|     cast,
 | |
| )
 | |
| from limits.util import LazyDependency
 | |
| 
 | |
| 
 | |
def _wrap_errors(
    fn: Callable[P, Awaitable[R]],
) -> Callable[P, Awaitable[R]]:
    """
    Decorate an async storage method so that backend failures can be
    surfaced as :exc:`limits.errors.StorageError`.

    The first positional argument of each call is treated as the storage
    instance. Its ``base_exceptions`` attribute selects which exceptions
    are intercepted, and its ``wrap_exceptions`` flag decides whether an
    intercepted exception is wrapped or re-raised unchanged.

    :param fn: the coroutine function to decorate
    :return: a coroutine function with the same signature as ``fn``
    """

    @functools.wraps(fn)
    async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:  # type: ignore[misc]
        # Bound methods always receive the storage instance first.
        storage = cast(Storage, args[0])
        try:
            result = await fn(*args, **kwargs)
        except storage.base_exceptions as error:
            if not storage.wrap_exceptions:
                # Wrapping is opt-in; otherwise propagate the original error.
                raise
            raise errors.StorageError(error) from error
        return result

    return wrapper
 | |
| 
 | |
| 
 | |
@versionadded(version="2.1")
class Storage(LazyDependency, metaclass=StorageRegistry):
    """
    Base class to extend when implementing an async storage backend.
    """

    STORAGE_SCHEME: list[str] | None
    """The storage schemes to register against this implementation"""

    def __init_subclass__(cls, **kwargs: Any) -> None:  # type:ignore[explicit-any]
        """
        Wrap the core storage methods of every subclass with
        :func:`_wrap_errors` so that the exceptions listed in
        ``base_exceptions`` can be re-raised as
        :exc:`limits.errors.StorageError` when ``wrap_exceptions`` is set.
        """
        for method in {
            "incr",
            "get",
            "get_expiry",
            "check",
            "reset",
            "clear",
        }:
            setattr(cls, method, _wrap_errors(getattr(cls, method)))
        # Delegate up the MRO exactly once. Calling this twice would run any
        # other cooperative ``__init_subclass__`` hook further along the MRO
        # twice for the same subclass.
        super().__init_subclass__(**kwargs)

    def __init__(
        self,
        uri: str | None = None,
        wrap_exceptions: bool = False,
        **options: float | str | bool,
    ) -> None:
        """
        :param uri: accepted for the benefit of subclasses; not used by
         this base initializer.
        :param wrap_exceptions: Whether to wrap storage exceptions in
         :exc:`limits.errors.StorageError` before raising it.
        :param options: additional keyword options; not used by this base
         initializer.
        """
        super().__init__()
        self.wrap_exceptions = wrap_exceptions

    @property
    @abstractmethod
    def base_exceptions(self) -> type[Exception] | tuple[type[Exception], ...]:
        """
        The exception type(s) intercepted by :func:`_wrap_errors` for this
        storage. They are re-raised as :exc:`limits.errors.StorageError`
        when ``wrap_exceptions`` is enabled and propagated as-is otherwise.
        """
        raise NotImplementedError

    @abstractmethod
    async def incr(self, key: str, expiry: int, amount: int = 1) -> int:
        """
        increments the counter for a given rate limit key

        :param key: the key to increment
        :param expiry: amount in seconds for the key to expire in
        :param amount: the number to increment by
        """
        raise NotImplementedError

    @abstractmethod
    async def get(self, key: str) -> int:
        """
        :param key: the key to get the counter value for
        """
        raise NotImplementedError

    @abstractmethod
    async def get_expiry(self, key: str) -> float:
        """
        :param key: the key to get the expiry for
        """
        raise NotImplementedError

    @abstractmethod
    async def check(self) -> bool:
        """
        check if storage is healthy
        """
        raise NotImplementedError

    @abstractmethod
    async def reset(self) -> int | None:
        """
        reset storage to clear limits
        """
        raise NotImplementedError

    @abstractmethod
    async def clear(self, key: str) -> None:
        """
        resets the rate limit key

        :param key: the key to clear rate limits for
        """
        raise NotImplementedError
 | |
| 
 | |
| 
 | |
class MovingWindowSupport(ABC):
    """
    Interface that async storages implement to support
    the :ref:`strategies:moving window` strategy
    """

    def __init_subclass__(cls, **kwargs: Any) -> None:  # type: ignore[explicit-any]
        """
        Replace the moving window methods of each subclass with versions
        decorated by :func:`_wrap_errors`, then continue up the MRO.
        """
        for name in ("acquire_entry", "get_moving_window"):
            undecorated = getattr(cls, name)
            setattr(cls, name, _wrap_errors(undecorated))
        super().__init_subclass__(**kwargs)

    @abstractmethod
    async def acquire_entry(
        self, key: str, limit: int, expiry: int, amount: int = 1
    ) -> bool:
        """
        :param key: rate limit key to acquire an entry in
        :param limit: amount of entries allowed
        :param expiry: expiry of the entry
        :param amount: the number of entries to acquire
        """
        raise NotImplementedError

    @abstractmethod
    async def get_moving_window(
        self, key: str, limit: int, expiry: int
    ) -> tuple[float, int]:
        """
        returns the starting point and the number of entries in the moving
        window

        :param key: rate limit key
        :param limit: amount of entries allowed
        :param expiry: expiry of entry
        :return: (start of window, number of acquired entries)
        """
        raise NotImplementedError
 | |
| 
 | |
| 
 | |
class SlidingWindowCounterSupport(ABC):
    """
    Interface that async storages implement to support
    the :ref:`strategies:sliding window counter` strategy
    """

    def __init_subclass__(cls, **kwargs: Any) -> None:  # type: ignore[explicit-any]
        """
        Replace the sliding window methods of each subclass with versions
        decorated by :func:`_wrap_errors`, then continue up the MRO.
        """
        for name in ("acquire_sliding_window_entry", "get_sliding_window"):
            undecorated = getattr(cls, name)
            setattr(cls, name, _wrap_errors(undecorated))
        super().__init_subclass__(**kwargs)

    @abstractmethod
    async def acquire_sliding_window_entry(
        self,
        key: str,
        limit: int,
        expiry: int,
        amount: int = 1,
    ) -> bool:
        """
        Acquire an entry if the weighted count of the current and previous
        windows is less than or equal to the limit

        :param key: rate limit key to acquire an entry in
        :param limit: amount of entries allowed
        :param expiry: expiry of the entry
        :param amount: the number of entries to acquire
        """
        raise NotImplementedError

    @abstractmethod
    async def get_sliding_window(
        self, key: str, expiry: int
    ) -> tuple[int, float, int, float]:
        """
        Return the previous and current window information.

        :param key: the rate limit key
        :param expiry: the rate limit expiry, needed to compute the key in some implementations
        :return: a tuple of (int, float, int, float) with the following information:
          - previous window counter
          - previous window TTL
          - current window counter
          - current window TTL
        """
        raise NotImplementedError
 | 
