Update 2025-04-24_11:44:19

venv/lib/python3.11/site-packages/limits/aio/__init__.py (new file, 8 lines)
@@ -0,0 +1,8 @@
from __future__ import annotations

from . import storage, strategies

__all__ = [
    "storage",
    "strategies",
]
venv/lib/python3.11/site-packages/limits/aio/storage/__init__.py (new file, 24 lines)
@@ -0,0 +1,24 @@
"""
Implementations of storage backends to be used with
:class:`limits.aio.strategies.RateLimiter` strategies
"""

from __future__ import annotations

from .base import MovingWindowSupport, SlidingWindowCounterSupport, Storage
from .memcached import MemcachedStorage
from .memory import MemoryStorage
from .mongodb import MongoDBStorage
from .redis import RedisClusterStorage, RedisSentinelStorage, RedisStorage

__all__ = [
    "MemcachedStorage",
    "MemoryStorage",
    "MongoDBStorage",
    "MovingWindowSupport",
    "RedisClusterStorage",
    "RedisSentinelStorage",
    "RedisStorage",
    "SlidingWindowCounterSupport",
    "Storage",
]
venv/lib/python3.11/site-packages/limits/aio/storage/base.py (new file, 220 lines)
@@ -0,0 +1,220 @@
from __future__ import annotations

import functools
from abc import ABC, abstractmethod

from deprecated.sphinx import versionadded

from limits import errors
from limits.storage.registry import StorageRegistry
from limits.typing import (
    Any,
    Awaitable,
    Callable,
    P,
    R,
    cast,
)
from limits.util import LazyDependency


def _wrap_errors(
    fn: Callable[P, Awaitable[R]],
) -> Callable[P, Awaitable[R]]:
    @functools.wraps(fn)
    async def inner(*args: P.args, **kwargs: P.kwargs) -> R:  # type: ignore[misc]
        instance = cast(Storage, args[0])
        try:
            return await fn(*args, **kwargs)
        except instance.base_exceptions as exc:
            if instance.wrap_exceptions:
                raise errors.StorageError(exc) from exc
            raise

    return inner


@versionadded(version="2.1")
class Storage(LazyDependency, metaclass=StorageRegistry):
    """
    Base class to extend when implementing an async storage backend.
    """

    STORAGE_SCHEME: list[str] | None
    """The storage schemes to register against this implementation"""

    def __init_subclass__(cls, **kwargs: Any) -> None:  # type:ignore[explicit-any]
        super().__init_subclass__(**kwargs)
        for method in {
            "incr",
            "get",
            "get_expiry",
            "check",
            "reset",
            "clear",
        }:
            setattr(cls, method, _wrap_errors(getattr(cls, method)))
        super().__init_subclass__(**kwargs)

    def __init__(
        self,
        uri: str | None = None,
        wrap_exceptions: bool = False,
        **options: float | str | bool,
    ) -> None:
        """
        :param wrap_exceptions: Whether to wrap storage exceptions in
         :exc:`limits.errors.StorageError` before raising it.
        """
        super().__init__()
        self.wrap_exceptions = wrap_exceptions

    @property
    @abstractmethod
    def base_exceptions(self) -> type[Exception] | tuple[type[Exception], ...]:
        raise NotImplementedError

    @abstractmethod
    async def incr(self, key: str, expiry: int, amount: int = 1) -> int:
        """
        increments the counter for a given rate limit key

        :param key: the key to increment
        :param expiry: amount in seconds for the key to expire in
        :param amount: the number to increment by
        """
        raise NotImplementedError

    @abstractmethod
    async def get(self, key: str) -> int:
        """
        :param key: the key to get the counter value for
        """
        raise NotImplementedError

    @abstractmethod
    async def get_expiry(self, key: str) -> float:
        """
        :param key: the key to get the expiry for
        """
        raise NotImplementedError

    @abstractmethod
    async def check(self) -> bool:
        """
        check if storage is healthy
        """
        raise NotImplementedError

    @abstractmethod
    async def reset(self) -> int | None:
        """
        reset storage to clear limits
        """
        raise NotImplementedError

    @abstractmethod
    async def clear(self, key: str) -> None:
        """
        resets the rate limit key

        :param key: the key to clear rate limits for
        """
        raise NotImplementedError


class MovingWindowSupport(ABC):
    """
    Abstract base class for async storages that support
    the :ref:`strategies:moving window` strategy
    """

    def __init_subclass__(cls, **kwargs: Any) -> None:  # type: ignore[explicit-any]
        for method in {
            "acquire_entry",
            "get_moving_window",
        }:
            setattr(
                cls,
                method,
                _wrap_errors(getattr(cls, method)),
            )
        super().__init_subclass__(**kwargs)

    @abstractmethod
    async def acquire_entry(
        self, key: str, limit: int, expiry: int, amount: int = 1
    ) -> bool:
        """
        :param key: rate limit key to acquire an entry in
        :param limit: amount of entries allowed
        :param expiry: expiry of the entry
        :param amount: the number of entries to acquire
        """
        raise NotImplementedError

    @abstractmethod
    async def get_moving_window(
        self, key: str, limit: int, expiry: int
    ) -> tuple[float, int]:
        """
        returns the starting point and the number of entries in the moving
        window

        :param key: rate limit key
        :param expiry: expiry of entry
        :return: (start of window, number of acquired entries)
        """
        raise NotImplementedError


class SlidingWindowCounterSupport(ABC):
    """
    Abstract base class for async storages that support
    the :ref:`strategies:sliding window counter` strategy
    """

    def __init_subclass__(cls, **kwargs: Any) -> None:  # type: ignore[explicit-any]
        for method in {"acquire_sliding_window_entry", "get_sliding_window"}:
            setattr(
                cls,
                method,
                _wrap_errors(getattr(cls, method)),
            )
        super().__init_subclass__(**kwargs)

    @abstractmethod
    async def acquire_sliding_window_entry(
        self,
        key: str,
        limit: int,
        expiry: int,
        amount: int = 1,
    ) -> bool:
        """
        Acquire an entry if the weighted count of the current and previous
        windows is less than or equal to the limit

        :param key: rate limit key to acquire an entry in
        :param limit: amount of entries allowed
        :param expiry: expiry of the entry
        :param amount: the number of entries to acquire
        """
        raise NotImplementedError

    @abstractmethod
    async def get_sliding_window(
        self, key: str, expiry: int
    ) -> tuple[int, float, int, float]:
        """
        Return the previous and current window information.

        :param key: the rate limit key
        :param expiry: the rate limit expiry, needed to compute the key in some implementations
        :return: a tuple of (int, float, int, float) with the following information:
            - previous window counter
            - previous window TTL
            - current window counter
            - current window TTL
        """
        raise NotImplementedError
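Note (added for illustration, not part of the committed file): a minimal sketch of how a custom async backend could satisfy the Storage contract declared above. The DummyStorage class and the "async+dummy" scheme are hypothetical names introduced only for this example.

# Illustration only: a bare-bones subclass implementing every abstract method.
import time

from limits.aio.storage.base import Storage


class DummyStorage(Storage):
    STORAGE_SCHEME = ["async+dummy"]  # hypothetical scheme, example only

    def __init__(self, uri: str | None = None, wrap_exceptions: bool = False, **options: float | str | bool) -> None:
        self._counters: dict[str, int] = {}
        self._expiries: dict[str, float] = {}
        super().__init__(uri, wrap_exceptions=wrap_exceptions, **options)

    @property
    def base_exceptions(self) -> type[Exception]:
        return ValueError

    async def incr(self, key: str, expiry: int, amount: int = 1) -> int:
        # start a fresh window if the previous one has expired
        if self._expiries.get(key, 0) <= time.time():
            self._counters.pop(key, None)
            self._expiries[key] = time.time() + expiry
        self._counters[key] = self._counters.get(key, 0) + amount
        return self._counters[key]

    async def get(self, key: str) -> int:
        return self._counters.get(key, 0)

    async def get_expiry(self, key: str) -> float:
        return self._expiries.get(key, time.time())

    async def check(self) -> bool:
        return True

    async def reset(self) -> int | None:
        count = len(self._counters)
        self._counters.clear()
        self._expiries.clear()
        return count

    async def clear(self, key: str) -> None:
        self._counters.pop(key, None)
        self._expiries.pop(key, None)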
venv/lib/python3.11/site-packages/limits/aio/storage/memcached/__init__.py (new file, 184 lines)
@@ -0,0 +1,184 @@
from __future__ import annotations

import time
from math import floor

from deprecated.sphinx import versionadded, versionchanged
from packaging.version import Version

from limits.aio.storage import SlidingWindowCounterSupport, Storage
from limits.aio.storage.memcached.bridge import MemcachedBridge
from limits.aio.storage.memcached.emcache import EmcacheBridge
from limits.aio.storage.memcached.memcachio import MemcachioBridge
from limits.storage.base import TimestampedSlidingWindow
from limits.typing import Literal


@versionadded(version="2.1")
@versionchanged(
    version="5.0",
    reason="Switched default implementation to :pypi:`memcachio`",
)
class MemcachedStorage(Storage, SlidingWindowCounterSupport, TimestampedSlidingWindow):
    """
    Rate limit storage with memcached as backend.

    Depends on :pypi:`memcachio`
    """

    STORAGE_SCHEME = ["async+memcached"]
    """The storage scheme for memcached to be used in an async context"""

    DEPENDENCIES = {
        "memcachio": Version("0.3"),
        "emcache": Version("0.0"),
    }

    bridge: MemcachedBridge
    storage_exceptions: tuple[Exception, ...]

    def __init__(
        self,
        uri: str,
        wrap_exceptions: bool = False,
        implementation: Literal["memcachio", "emcache"] = "memcachio",
        **options: float | str | bool,
    ) -> None:
        """
        :param uri: memcached location of the form
         ``async+memcached://host:port,host:port``
        :param wrap_exceptions: Whether to wrap storage exceptions in
         :exc:`limits.errors.StorageError` before raising it.
        :param implementation: Whether to use the client implementation from

         - ``memcachio``: :class:`memcachio.Client`
         - ``emcache``: :class:`emcache.Client`
        :param options: all remaining keyword arguments are passed
         directly to the constructor of :class:`memcachio.Client`
        :raise ConfigurationError: when :pypi:`memcachio` is not available
        """
        if implementation == "emcache":
            self.bridge = EmcacheBridge(
                uri, self.dependencies["emcache"].module, **options
            )
        else:
            self.bridge = MemcachioBridge(
                uri, self.dependencies["memcachio"].module, **options
            )
        super().__init__(uri, wrap_exceptions=wrap_exceptions, **options)

    @property
    def base_exceptions(
        self,
    ) -> type[Exception] | tuple[type[Exception], ...]:  # pragma: no cover
        return self.bridge.base_exceptions

    async def get(self, key: str) -> int:
        """
        :param key: the key to get the counter value for
        """
        return await self.bridge.get(key)

    async def clear(self, key: str) -> None:
        """
        :param key: the key to clear rate limits for
        """
        await self.bridge.clear(key)

    async def incr(
        self,
        key: str,
        expiry: float,
        amount: int = 1,
        set_expiration_key: bool = True,
    ) -> int:
        """
        increments the counter for a given rate limit key

        :param key: the key to increment
        :param expiry: amount in seconds for the key to expire in
         window every hit.
        :param amount: the number to increment by
        :param set_expiration_key: if set to False, the expiration time won't be stored but the key will still expire
        """
        return await self.bridge.incr(
            key, expiry, amount, set_expiration_key=set_expiration_key
        )

    async def get_expiry(self, key: str) -> float:
        """
        :param key: the key to get the expiry for
        """
        return await self.bridge.get_expiry(key)

    async def reset(self) -> int | None:
        raise NotImplementedError

    async def check(self) -> bool:
        return await self.bridge.check()

    async def acquire_sliding_window_entry(
        self,
        key: str,
        limit: int,
        expiry: int,
        amount: int = 1,
    ) -> bool:
        if amount > limit:
            return False
        now = time.time()
        previous_key, current_key = self.sliding_window_keys(key, expiry, now)
        (
            previous_count,
            previous_ttl,
            current_count,
            _,
        ) = await self._get_sliding_window_info(previous_key, current_key, expiry, now)
        t0 = time.time()
        weighted_count = previous_count * previous_ttl / expiry + current_count
        if floor(weighted_count) + amount > limit:
            return False
        else:
            # Hit, increase the current counter.
            # If the counter doesn't exist yet, set twice the theoretical expiry.
            # We don't need the expiration key as it is estimated with the timestamps directly.
            current_count = await self.incr(
                current_key, 2 * expiry, amount=amount, set_expiration_key=False
            )
            t1 = time.time()
            actualised_previous_ttl = max(0, previous_ttl - (t1 - t0))
            weighted_count = (
                previous_count * actualised_previous_ttl / expiry + current_count
            )
            if floor(weighted_count) > limit:
                # Another hit won the race condition: revert the increment and refuse this hit
                # Limitation: during high concurrency at the end of the window,
                # the counter is shifted and cannot be decremented, so fewer requests than expected are allowed.
                await self.bridge.decr(current_key, amount, noreply=True)
                return False
            return True

    async def get_sliding_window(
        self, key: str, expiry: int
    ) -> tuple[int, float, int, float]:
        now = time.time()
        previous_key, current_key = self.sliding_window_keys(key, expiry, now)
        return await self._get_sliding_window_info(
            previous_key, current_key, expiry, now
        )

    async def _get_sliding_window_info(
        self, previous_key: str, current_key: str, expiry: int, now: float
    ) -> tuple[int, float, int, float]:
        result = await self.bridge.get_many([previous_key, current_key])

        previous_count = result.get(previous_key.encode("utf-8"), 0)
        current_count = result.get(current_key.encode("utf-8"), 0)

        if previous_count == 0:
            previous_ttl = float(0)
        else:
            previous_ttl = (1 - (((now - expiry) / expiry) % 1)) * expiry
        current_ttl = (1 - ((now / expiry) % 1)) * expiry + expiry

        return previous_count, previous_ttl, current_count, current_ttl
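Note (added for illustration, not part of the committed file): a minimal usage sketch of the class above, assuming a memcached server reachable on localhost:11211 and the memcachio dependency installed; the key name is arbitrary.

import asyncio

from limits.aio.storage import MemcachedStorage


async def main() -> None:
    storage = MemcachedStorage("async+memcached://localhost:11211")
    # incr/get/get_expiry are the Storage API methods shown in base.py
    await storage.incr("user:42", expiry=60, amount=1)
    print(await storage.get("user:42"), await storage.get_expiry("user:42"))


asyncio.run(main())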
venv/lib/python3.11/site-packages/limits/aio/storage/memcached/bridge.py (new file, 73 lines)
@@ -0,0 +1,73 @@
from __future__ import annotations

import urllib
from abc import ABC, abstractmethod
from types import ModuleType

from limits.typing import Iterable


class MemcachedBridge(ABC):
    def __init__(
        self,
        uri: str,
        dependency: ModuleType,
        **options: float | str | bool,
    ) -> None:
        self.uri = uri
        self.parsed_uri = urllib.parse.urlparse(self.uri)
        self.dependency = dependency
        self.hosts = []
        self.options = options

        sep = self.parsed_uri.netloc.strip().find("@") + 1
        for loc in self.parsed_uri.netloc.strip()[sep:].split(","):
            host, port = loc.split(":")
            self.hosts.append((host, int(port)))

        if self.parsed_uri.username:
            self.options["username"] = self.parsed_uri.username
        if self.parsed_uri.password:
            self.options["password"] = self.parsed_uri.password

    def _expiration_key(self, key: str) -> str:
        """
        Return the expiration key for the given counter key.

        Memcached doesn't natively return the expiration time or TTL for a given key,
        so we implement the expiration time on a separate key.
        """
        return key + "/expires"

    @property
    @abstractmethod
    def base_exceptions(
        self,
    ) -> type[Exception] | tuple[type[Exception], ...]: ...

    @abstractmethod
    async def get(self, key: str) -> int: ...

    @abstractmethod
    async def get_many(self, keys: Iterable[str]) -> dict[bytes, int]: ...

    @abstractmethod
    async def clear(self, key: str) -> None: ...

    @abstractmethod
    async def decr(self, key: str, amount: int = 1, noreply: bool = False) -> int: ...

    @abstractmethod
    async def incr(
        self,
        key: str,
        expiry: float,
        amount: int = 1,
        set_expiration_key: bool = True,
    ) -> int: ...

    @abstractmethod
    async def get_expiry(self, key: str) -> float: ...

    @abstractmethod
    async def check(self) -> bool: ...
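Note (added for illustration, not part of the committed file): how the constructor above splits a multi-host memcached URI into (host, port) pairs; the hostnames and credentials are made up.

import urllib.parse

uri = "async+memcached://user:secret@cache1:11211,cache2:11212"
parsed = urllib.parse.urlparse(uri)
netloc = parsed.netloc.strip()
sep = netloc.find("@") + 1  # skip past any credentials
hosts = [(h, int(p)) for h, p in (loc.split(":") for loc in netloc[sep:].split(","))]
print(hosts)  # [('cache1', 11211), ('cache2', 11212)]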
venv/lib/python3.11/site-packages/limits/aio/storage/memcached/emcache.py (new file, 112 lines)
@@ -0,0 +1,112 @@
from __future__ import annotations

import time
from math import ceil
from types import ModuleType

from limits.typing import TYPE_CHECKING, Iterable

from .bridge import MemcachedBridge

if TYPE_CHECKING:
    import emcache


class EmcacheBridge(MemcachedBridge):
    def __init__(
        self,
        uri: str,
        dependency: ModuleType,
        **options: float | str | bool,
    ) -> None:
        super().__init__(uri, dependency, **options)
        self._storage = None

    async def get_storage(self) -> emcache.Client:
        if not self._storage:
            self._storage = await self.dependency.create_client(
                [self.dependency.MemcachedHostAddress(h, p) for h, p in self.hosts],
                **self.options,
            )
        assert self._storage
        return self._storage

    async def get(self, key: str) -> int:
        item = await (await self.get_storage()).get(key.encode("utf-8"))
        return item and int(item.value) or 0

    async def get_many(self, keys: Iterable[str]) -> dict[bytes, int]:
        results = await (await self.get_storage()).get_many(
            [k.encode("utf-8") for k in keys]
        )
        return {k: int(item.value) if item else 0 for k, item in results.items()}

    async def clear(self, key: str) -> None:
        try:
            await (await self.get_storage()).delete(key.encode("utf-8"))
        except self.dependency.NotFoundCommandError:
            pass

    async def decr(self, key: str, amount: int = 1, noreply: bool = False) -> int:
        storage = await self.get_storage()
        limit_key = key.encode("utf-8")
        try:
            value = await storage.decrement(limit_key, amount, noreply=noreply) or 0
        except self.dependency.NotFoundCommandError:
            value = 0
        return value

    async def incr(
        self, key: str, expiry: float, amount: int = 1, set_expiration_key: bool = True
    ) -> int:
        storage = await self.get_storage()
        limit_key = key.encode("utf-8")
        expire_key = self._expiration_key(key).encode()
        try:
            return await storage.increment(limit_key, amount) or amount
        except self.dependency.NotFoundCommandError:
            storage = await self.get_storage()
            try:
                await storage.add(limit_key, f"{amount}".encode(), exptime=ceil(expiry))
                if set_expiration_key:
                    await storage.set(
                        expire_key,
                        str(expiry + time.time()).encode("utf-8"),
                        exptime=ceil(expiry),
                        noreply=False,
                    )
                value = amount
            except self.dependency.NotStoredStorageCommandError:
                # Could not add the key, probably because a concurrent call has added it
                storage = await self.get_storage()
                value = await storage.increment(limit_key, amount) or amount
            return value

    async def get_expiry(self, key: str) -> float:
        storage = await self.get_storage()
        item = await storage.get(self._expiration_key(key).encode("utf-8"))

        return item and float(item.value) or time.time()

    @property
    def base_exceptions(
        self,
    ) -> type[Exception] | tuple[type[Exception], ...]:  # pragma: no cover
        return (
            self.dependency.ClusterNoAvailableNodes,
            self.dependency.CommandError,
        )

    async def check(self) -> bool:
        """
        Check if storage is healthy by calling the ``get`` command
        on the key ``limiter-check``
        """
        try:
            storage = await self.get_storage()
            await storage.get(b"limiter-check")

            return True
        except:  # noqa
            return False
venv/lib/python3.11/site-packages/limits/aio/storage/memcached/memcachio.py (new file, 104 lines)
@@ -0,0 +1,104 @@
from __future__ import annotations

import time
from math import ceil
from types import ModuleType
from typing import TYPE_CHECKING, Iterable

from .bridge import MemcachedBridge

if TYPE_CHECKING:
    import memcachio


class MemcachioBridge(MemcachedBridge):
    def __init__(
        self,
        uri: str,
        dependency: ModuleType,
        **options: float | str | bool,
    ) -> None:
        super().__init__(uri, dependency, **options)
        self._storage: memcachio.Client[bytes] | None = None

    @property
    def base_exceptions(
        self,
    ) -> type[Exception] | tuple[type[Exception], ...]:
        return (
            self.dependency.errors.NoAvailableNodes,
            self.dependency.errors.MemcachioConnectionError,
        )

    async def get_storage(self) -> memcachio.Client[bytes]:
        if not self._storage:
            self._storage = self.dependency.Client(
                [(h, p) for h, p in self.hosts],
                **self.options,
            )
        assert self._storage
        return self._storage

    async def get(self, key: str) -> int:
        return (await self.get_many([key])).get(key.encode("utf-8"), 0)

    async def get_many(self, keys: Iterable[str]) -> dict[bytes, int]:
        """
        Return multiple counters at once

        :param keys: the keys to get the counter values for
        """
        results = await (await self.get_storage()).get(
            *[k.encode("utf-8") for k in keys]
        )
        return {k: int(v.value) for k, v in results.items()}

    async def clear(self, key: str) -> None:
        await (await self.get_storage()).delete(key.encode("utf-8"))

    async def decr(self, key: str, amount: int = 1, noreply: bool = False) -> int:
        storage = await self.get_storage()
        limit_key = key.encode("utf-8")
        return await storage.decr(limit_key, amount, noreply=noreply) or 0

    async def incr(
        self, key: str, expiry: float, amount: int = 1, set_expiration_key: bool = True
    ) -> int:
        storage = await self.get_storage()
        limit_key = key.encode("utf-8")
        expire_key = self._expiration_key(key).encode()
        if (value := (await storage.incr(limit_key, amount))) is None:
            storage = await self.get_storage()
            if await storage.add(limit_key, f"{amount}".encode(), expiry=ceil(expiry)):
                if set_expiration_key:
                    await storage.set(
                        expire_key,
                        str(expiry + time.time()).encode("utf-8"),
                        expiry=ceil(expiry),
                        noreply=False,
                    )
                return amount
            else:
                storage = await self.get_storage()
                return await storage.incr(limit_key, amount) or amount
        return value

    async def get_expiry(self, key: str) -> float:
        storage = await self.get_storage()
        expiration_key = self._expiration_key(key).encode("utf-8")
        item = (await storage.get(expiration_key)).get(expiration_key, None)

        return item and float(item.value) or time.time()

    async def check(self) -> bool:
        """
        Check if storage is healthy by calling the ``get`` command
        on the key ``limiter-check``
        """
        try:
            storage = await self.get_storage()
            await storage.get(b"limiter-check")

            return True
        except:  # noqa
            return False
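Note (added for illustration, not part of the committed file): both bridges above rely on memcached's create-or-increment semantics; the sketch below shows that pattern against a hypothetical in-memory stand-in rather than a real client.

class FakeClient:
    """Stand-in with memcached-like incr/add semantics (illustration only)."""

    def __init__(self) -> None:
        self.data: dict[bytes, int] = {}

    async def incr(self, key: bytes, amount: int) -> int | None:
        if key not in self.data:
            return None  # memcached-style: incr fails on a missing key
        self.data[key] += amount
        return self.data[key]

    async def add(self, key: bytes, value: bytes) -> bool:
        if key in self.data:
            return False  # add only succeeds when the key does not already exist
        self.data[key] = int(value)
        return True


async def incr_or_create(client: FakeClient, key: bytes, amount: int = 1) -> int:
    # Try to increment; if the key is missing, add it; if another caller won
    # the add race, fall back to incrementing the value they created.
    if (value := await client.incr(key, amount)) is None:
        if await client.add(key, str(amount).encode()):
            return amount
        return await client.incr(key, amount) or amount
    return value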
venv/lib/python3.11/site-packages/limits/aio/storage/memory.py (new file, 281 lines)
@@ -0,0 +1,281 @@
from __future__ import annotations

import asyncio
import bisect
import time
from collections import Counter, defaultdict
from math import floor

from deprecated.sphinx import versionadded

import limits.typing
from limits.aio.storage.base import (
    MovingWindowSupport,
    SlidingWindowCounterSupport,
    Storage,
)
from limits.storage.base import TimestampedSlidingWindow


class Entry:
    def __init__(self, expiry: int) -> None:
        self.atime = time.time()
        self.expiry = self.atime + expiry


@versionadded(version="2.1")
class MemoryStorage(
    Storage, MovingWindowSupport, SlidingWindowCounterSupport, TimestampedSlidingWindow
):
    """
    rate limit storage using :class:`collections.Counter`
    as an in memory storage for fixed & sliding window strategies,
    and a simple list to implement moving window strategy.
    """

    STORAGE_SCHEME = ["async+memory"]
    """
    The storage scheme for in process memory storage for use in an
    async context
    """

    def __init__(
        self, uri: str | None = None, wrap_exceptions: bool = False, **_: str
    ) -> None:
        self.storage: limits.typing.Counter[str] = Counter()
        self.locks: defaultdict[str, asyncio.Lock] = defaultdict(asyncio.Lock)
        self.expirations: dict[str, float] = {}
        self.events: dict[str, list[Entry]] = {}
        self.timer: asyncio.Task[None] | None = None
        super().__init__(uri, wrap_exceptions=wrap_exceptions, **_)

    def __getstate__(self) -> dict[str, limits.typing.Any]:  # type: ignore[explicit-any]
        state = self.__dict__.copy()
        del state["timer"]
        del state["locks"]
        return state

    def __setstate__(self, state: dict[str, limits.typing.Any]) -> None:  # type: ignore[explicit-any]
        self.__dict__.update(state)
        self.timer = None
        self.locks = defaultdict(asyncio.Lock)
        asyncio.ensure_future(self.__schedule_expiry())

    async def __expire_events(self) -> None:
        try:
            now = time.time()
            for key in list(self.events.keys()):
                cutoff = await asyncio.to_thread(
                    lambda evts: bisect.bisect_left(
                        evts, -now, key=lambda event: -event.expiry
                    ),
                    self.events[key],
                )
                async with self.locks[key]:
                    if self.events.get(key, []):
                        self.events[key] = self.events[key][:cutoff]
                    if not self.events.get(key, None):
                        self.events.pop(key, None)
                        self.locks.pop(key, None)

            for key in list(self.expirations.keys()):
                if self.expirations[key] <= time.time():
                    self.storage.pop(key, None)
                    self.expirations.pop(key, None)
                    self.locks.pop(key, None)
        except asyncio.CancelledError:
            return

    async def __schedule_expiry(self) -> None:
        if not self.timer or self.timer.done():
            self.timer = asyncio.create_task(self.__expire_events())

    @property
    def base_exceptions(
        self,
    ) -> type[Exception] | tuple[type[Exception], ...]:  # pragma: no cover
        return ValueError

    async def incr(self, key: str, expiry: float, amount: int = 1) -> int:
        """
        increments the counter for a given rate limit key

        :param key: the key to increment
        :param expiry: amount in seconds for the key to expire in
        :param amount: the number to increment by
        """
        await self.get(key)
        await self.__schedule_expiry()
        async with self.locks[key]:
            self.storage[key] += amount
            if self.storage[key] == amount:
                self.expirations[key] = time.time() + expiry
            return self.storage.get(key, amount)

    async def decr(self, key: str, amount: int = 1) -> int:
        """
        decrements the counter for a given rate limit key. 0 is the minimum allowed value.

        :param amount: the number to increment by
        """
        await self.get(key)
        await self.__schedule_expiry()
        async with self.locks[key]:
            self.storage[key] = max(self.storage[key] - amount, 0)

            return self.storage.get(key, amount)

    async def get(self, key: str) -> int:
        """
        :param key: the key to get the counter value for
        """
        if self.expirations.get(key, 0) <= time.time():
            self.storage.pop(key, None)
            self.expirations.pop(key, None)
            self.locks.pop(key, None)

        return self.storage.get(key, 0)

    async def clear(self, key: str) -> None:
        """
        :param key: the key to clear rate limits for
        """
        self.storage.pop(key, None)
        self.expirations.pop(key, None)
        self.events.pop(key, None)
        self.locks.pop(key, None)

    async def acquire_entry(
        self, key: str, limit: int, expiry: int, amount: int = 1
    ) -> bool:
        """
        :param key: rate limit key to acquire an entry in
        :param limit: amount of entries allowed
        :param expiry: expiry of the entry
        :param amount: the number of entries to acquire
        """
        if amount > limit:
            return False

        await self.__schedule_expiry()
        async with self.locks[key]:
            self.events.setdefault(key, [])
            timestamp = time.time()
            try:
                entry: Entry | None = self.events[key][limit - amount]
            except IndexError:
                entry = None

            if entry and entry.atime >= timestamp - expiry:
                return False
            else:
                self.events[key][:0] = [Entry(expiry)] * amount
                return True

    async def get_expiry(self, key: str) -> float:
        """
        :param key: the key to get the expiry for
        """

        return self.expirations.get(key, time.time())

    async def get_moving_window(
        self, key: str, limit: int, expiry: int
    ) -> tuple[float, int]:
        """
        returns the starting point and the number of entries in the moving
        window

        :param key: rate limit key
        :param expiry: expiry of entry
        :return: (start of window, number of acquired entries)
        """

        timestamp = time.time()
        if events := self.events.get(key, []):
            oldest = bisect.bisect_left(
                events, -(timestamp - expiry), key=lambda entry: -entry.atime
            )
            return events[oldest - 1].atime, oldest
        return timestamp, 0

    async def acquire_sliding_window_entry(
        self,
        key: str,
        limit: int,
        expiry: int,
        amount: int = 1,
    ) -> bool:
        if amount > limit:
            return False
        now = time.time()
        previous_key, current_key = self.sliding_window_keys(key, expiry, now)
        (
            previous_count,
            previous_ttl,
            current_count,
            _,
        ) = await self._get_sliding_window_info(previous_key, current_key, expiry, now)
        weighted_count = previous_count * previous_ttl / expiry + current_count
        if floor(weighted_count) + amount > limit:
            return False
        else:
            # Hit, increase the current counter.
            # If the counter doesn't exist yet, set twice the theoretical expiry.
            current_count = await self.incr(current_key, 2 * expiry, amount=amount)
            weighted_count = previous_count * previous_ttl / expiry + current_count
            if floor(weighted_count) > limit:
                # Another hit won the race condition: revert the increment and refuse this hit
                # Limitation: during high concurrency at the end of the window,
                # the counter is shifted and cannot be decremented, so fewer requests than expected are allowed.
                await self.decr(current_key, amount)
                return False
            return True

    async def get_sliding_window(
        self, key: str, expiry: int
    ) -> tuple[int, float, int, float]:
        now = time.time()
        previous_key, current_key = self.sliding_window_keys(key, expiry, now)
        return await self._get_sliding_window_info(
            previous_key, current_key, expiry, now
        )

    async def _get_sliding_window_info(
        self,
        previous_key: str,
        current_key: str,
        expiry: int,
        now: float,
    ) -> tuple[int, float, int, float]:
        previous_count = await self.get(previous_key)
        current_count = await self.get(current_key)
        if previous_count == 0:
            previous_ttl = float(0)
        else:
            previous_ttl = (1 - (((now - expiry) / expiry) % 1)) * expiry
        current_ttl = (1 - ((now / expiry) % 1)) * expiry + expiry
        return previous_count, previous_ttl, current_count, current_ttl

    async def check(self) -> bool:
        """
        check if storage is healthy
        """

        return True

    async def reset(self) -> int | None:
        num_items = max(len(self.storage), len(self.events))
        self.storage.clear()
        self.expirations.clear()
        self.events.clear()
        self.locks.clear()

        return num_items

    def __del__(self) -> None:
        try:
            if self.timer and not self.timer.done():
                self.timer.cancel()
        except RuntimeError:  # noqa
            pass
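Note (added for illustration, not part of the committed file): a sketch exercising the in-memory backend's sliding window counter API defined above; the key name and limits are arbitrary.

import asyncio

from limits.aio.storage import MemoryStorage


async def main() -> None:
    storage = MemoryStorage()
    allowed = [
        await storage.acquire_sliding_window_entry("api", limit=3, expiry=10)
        for _ in range(5)
    ]
    print(allowed)  # expected: the first 3 hits succeed, the rest are refused
    print(await storage.get_sliding_window("api", expiry=10))


asyncio.run(main())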
venv/lib/python3.11/site-packages/limits/aio/storage/mongodb.py (new file, 517 lines)
@@ -0,0 +1,517 @@
from __future__ import annotations

import asyncio
import datetime
import time

from deprecated.sphinx import versionadded, versionchanged

from limits.aio.storage.base import (
    MovingWindowSupport,
    SlidingWindowCounterSupport,
    Storage,
)
from limits.typing import (
    ParamSpec,
    TypeVar,
    cast,
)
from limits.util import get_dependency

P = ParamSpec("P")
R = TypeVar("R")


@versionadded(version="2.1")
@versionchanged(
    version="3.14.0",
    reason="Added option to select custom collection names for windows & counters",
)
class MongoDBStorage(Storage, MovingWindowSupport, SlidingWindowCounterSupport):
    """
    Rate limit storage with MongoDB as backend.

    Depends on :pypi:`motor`
    """

    STORAGE_SCHEME = ["async+mongodb", "async+mongodb+srv"]
    """
    The storage scheme for MongoDB for use in an async context
    """

    DEPENDENCIES = ["motor.motor_asyncio", "pymongo"]

    def __init__(
        self,
        uri: str,
        database_name: str = "limits",
        counter_collection_name: str = "counters",
        window_collection_name: str = "windows",
        wrap_exceptions: bool = False,
        **options: float | str | bool,
    ) -> None:
        """
        :param uri: uri of the form ``async+mongodb://[user:password]@host:port?...``,
         This uri is passed directly to :class:`~motor.motor_asyncio.AsyncIOMotorClient`
        :param database_name: The database to use for storing the rate limit
         collections.
        :param counter_collection_name: The collection name to use for individual counters
         used in fixed window strategies
        :param window_collection_name: The collection name to use for sliding & moving window
         storage
        :param wrap_exceptions: Whether to wrap storage exceptions in
         :exc:`limits.errors.StorageError` before raising it.
        :param options: all remaining keyword arguments are passed
         to the constructor of :class:`~motor.motor_asyncio.AsyncIOMotorClient`
        :raise ConfigurationError: when the :pypi:`motor` or :pypi:`pymongo` are
         not available
        """

        uri = uri.replace("async+mongodb", "mongodb", 1)

        super().__init__(uri, wrap_exceptions=wrap_exceptions, **options)

        self.dependency = self.dependencies["motor.motor_asyncio"]
        self.proxy_dependency = self.dependencies["pymongo"]
        self.lib_errors, _ = get_dependency("pymongo.errors")

        self.storage = self.dependency.module.AsyncIOMotorClient(uri, **options)
        # TODO: Fix this hack. It was noticed when running a benchmark
        # with FastAPI - however - doesn't appear in unit tests or in an isolated
        # use. Reference: https://jira.mongodb.org/browse/MOTOR-822
        self.storage.get_io_loop = asyncio.get_running_loop

        self.__database_name = database_name
        self.__collection_mapping = {
            "counters": counter_collection_name,
            "windows": window_collection_name,
        }
        self.__indices_created = False

    @property
    def base_exceptions(
        self,
    ) -> type[Exception] | tuple[type[Exception], ...]:  # pragma: no cover
        return self.lib_errors.PyMongoError  # type: ignore

    @property
    def database(self):  # type: ignore
        return self.storage.get_database(self.__database_name)

    async def create_indices(self) -> None:
        if not self.__indices_created:
            await asyncio.gather(
                self.database[self.__collection_mapping["counters"]].create_index(
                    "expireAt", expireAfterSeconds=0
                ),
                self.database[self.__collection_mapping["windows"]].create_index(
                    "expireAt", expireAfterSeconds=0
                ),
            )
            self.__indices_created = True

    async def reset(self) -> int | None:
        """
        Delete all rate limit keys in the rate limit collections (counters, windows)
        """
        num_keys = sum(
            await asyncio.gather(
                self.database[self.__collection_mapping["counters"]].count_documents(
                    {}
                ),
                self.database[self.__collection_mapping["windows"]].count_documents({}),
            )
        )
        await asyncio.gather(
            self.database[self.__collection_mapping["counters"]].drop(),
            self.database[self.__collection_mapping["windows"]].drop(),
        )

        return cast(int, num_keys)

    async def clear(self, key: str) -> None:
        """
        :param key: the key to clear rate limits for
        """
        await asyncio.gather(
            self.database[self.__collection_mapping["counters"]].find_one_and_delete(
                {"_id": key}
            ),
            self.database[self.__collection_mapping["windows"]].find_one_and_delete(
                {"_id": key}
            ),
        )

    async def get_expiry(self, key: str) -> float:
        """
        :param key: the key to get the expiry for
        """
        counter = await self.database[self.__collection_mapping["counters"]].find_one(
            {"_id": key}
        )
        return (
            (counter["expireAt"] if counter else datetime.datetime.now())
            .replace(tzinfo=datetime.timezone.utc)
            .timestamp()
        )

    async def get(self, key: str) -> int:
        """
        :param key: the key to get the counter value for
        """
        counter = await self.database[self.__collection_mapping["counters"]].find_one(
            {
                "_id": key,
                "expireAt": {"$gte": datetime.datetime.now(datetime.timezone.utc)},
            },
            projection=["count"],
        )

        return counter and counter["count"] or 0

    async def incr(self, key: str, expiry: int, amount: int = 1) -> int:
        """
        increments the counter for a given rate limit key

        :param key: the key to increment
        :param expiry: amount in seconds for the key to expire in
        :param amount: the number to increment by
        """
        await self.create_indices()

        expiration = datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(
            seconds=expiry
        )

        response = await self.database[
            self.__collection_mapping["counters"]
        ].find_one_and_update(
            {"_id": key},
            [
                {
                    "$set": {
                        "count": {
                            "$cond": {
                                "if": {"$lt": ["$expireAt", "$$NOW"]},
                                "then": amount,
                                "else": {"$add": ["$count", amount]},
                            }
                        },
                        "expireAt": {
                            "$cond": {
                                "if": {"$lt": ["$expireAt", "$$NOW"]},
                                "then": expiration,
                                "else": "$expireAt",
                            }
                        },
                    }
                },
            ],
            upsert=True,
            projection=["count"],
            return_document=self.proxy_dependency.module.ReturnDocument.AFTER,
        )

        return int(response["count"])

    async def check(self) -> bool:
        """
        Check if storage is healthy by calling
        :meth:`motor.motor_asyncio.AsyncIOMotorClient.server_info`
        """
        try:
            await self.storage.server_info()

            return True
        except:  # noqa: E722
            return False

    async def get_moving_window(
        self, key: str, limit: int, expiry: int
    ) -> tuple[float, int]:
        """
        returns the starting point and the number of entries in the moving
        window

        :param str key: rate limit key
        :param int expiry: expiry of entry
        :return: (start of window, number of acquired entries)
        """

        timestamp = time.time()
        if (
            result := await self.database[self.__collection_mapping["windows"]]
            .aggregate(
                [
                    {"$match": {"_id": key}},
                    {
                        "$project": {
                            "filteredEntries": {
                                "$filter": {
                                    "input": "$entries",
                                    "as": "entry",
                                    "cond": {"$gte": ["$$entry", timestamp - expiry]},
                                }
                            }
                        }
                    },
                    {
                        "$project": {
                            "min": {"$min": "$filteredEntries"},
                            "count": {"$size": "$filteredEntries"},
                        }
                    },
                ]
            )
            .to_list(length=1)
        ):
            return result[0]["min"], result[0]["count"]
        return timestamp, 0

    async def acquire_entry(
        self, key: str, limit: int, expiry: int, amount: int = 1
    ) -> bool:
        """
        :param key: rate limit key to acquire an entry in
        :param limit: amount of entries allowed
        :param expiry: expiry of the entry
        :param amount: the number of entries to acquire
        """
        await self.create_indices()

        if amount > limit:
            return False

        timestamp = time.time()
        try:
            updates: dict[
                str,
                dict[str, datetime.datetime | dict[str, list[float] | int]],
            ] = {
                "$push": {
                    "entries": {
                        "$each": [timestamp] * amount,
                        "$position": 0,
                        "$slice": limit,
                    }
                },
                "$set": {
                    "expireAt": (
                        datetime.datetime.now(datetime.timezone.utc)
                        + datetime.timedelta(seconds=expiry)
                    )
                },
            }

            await self.database[self.__collection_mapping["windows"]].update_one(
                {
                    "_id": key,
                    f"entries.{limit - amount}": {"$not": {"$gte": timestamp - expiry}},
                },
                updates,
                upsert=True,
            )

            return True
        except self.proxy_dependency.module.errors.DuplicateKeyError:
            return False

    async def acquire_sliding_window_entry(
        self, key: str, limit: int, expiry: int, amount: int = 1
    ) -> bool:
        await self.create_indices()
        expiry_ms = expiry * 1000
        result = await self.database[
            self.__collection_mapping["windows"]
        ].find_one_and_update(
            {"_id": key},
            [
                {
                    "$set": {
                        "previousCount": {
                            "$cond": {
                                "if": {
                                    "$lte": [
                                        {"$subtract": ["$expireAt", "$$NOW"]},
                                        expiry_ms,
                                    ]
                                },
                                "then": {"$ifNull": ["$currentCount", 0]},
                                "else": {"$ifNull": ["$previousCount", 0]},
                            }
                        },
                    }
                },
                {
                    "$set": {
                        "currentCount": {
                            "$cond": {
                                "if": {
                                    "$lte": [
                                        {"$subtract": ["$expireAt", "$$NOW"]},
                                        expiry_ms,
                                    ]
                                },
                                "then": 0,
                                "else": {"$ifNull": ["$currentCount", 0]},
                            }
                        },
                        "expireAt": {
                            "$cond": {
                                "if": {
                                    "$lte": [
                                        {"$subtract": ["$expireAt", "$$NOW"]},
                                        expiry_ms,
                                    ]
                                },
                                "then": {
                                    "$cond": {
                                        "if": {"$gt": ["$expireAt", 0]},
                                        "then": {"$add": ["$expireAt", expiry_ms]},
                                        "else": {"$add": ["$$NOW", 2 * expiry_ms]},
                                    }
                                },
                                "else": "$expireAt",
                            }
                        },
                    }
                },
                {
                    "$set": {
                        "curWeightedCount": {
                            "$floor": {
                                "$add": [
                                    {
                                        "$multiply": [
                                            "$previousCount",
                                            {
                                                "$divide": [
                                                    {
                                                        "$max": [
                                                            0,
                                                            {
                                                                "$subtract": [
                                                                    "$expireAt",
                                                                    {
                                                                        "$add": [
                                                                            "$$NOW",
                                                                            expiry_ms,
                                                                        ]
                                                                    },
                                                                ]
                                                            },
                                                        ]
                                                    },
                                                    expiry_ms,
                                                ]
                                            },
                                        ]
                                    },
                                    "$currentCount",
                                ]
                            }
                        }
                    }
                },
                {
                    "$set": {
                        "currentCount": {
                            "$cond": {
                                "if": {
                                    "$lte": [
                                        {"$add": ["$curWeightedCount", amount]},
                                        limit,
                                    ]
                                },
                                "then": {"$add": ["$currentCount", amount]},
                                "else": "$currentCount",
                            }
                        }
                    }
                },
                {
                    "$set": {
                        "_acquired": {
                            "$lte": [{"$add": ["$curWeightedCount", amount]}, limit]
                        }
                    }
                },
                {"$unset": ["curWeightedCount"]},
            ],
            return_document=self.proxy_dependency.module.ReturnDocument.AFTER,
            upsert=True,
        )

        return cast(bool, result["_acquired"])

    async def get_sliding_window(
        self, key: str, expiry: int
    ) -> tuple[int, float, int, float]:
        expiry_ms = expiry * 1000
        if result := await self.database[
            self.__collection_mapping["windows"]
        ].find_one_and_update(
            {"_id": key},
            [
                {
                    "$set": {
                        "previousCount": {
                            "$cond": {
                                "if": {
                                    "$lte": [
                                        {"$subtract": ["$expireAt", "$$NOW"]},
                                        expiry_ms,
                                    ]
                                },
                                "then": {"$ifNull": ["$currentCount", 0]},
                                "else": {"$ifNull": ["$previousCount", 0]},
                            }
                        },
                        "currentCount": {
                            "$cond": {
                                "if": {
                                    "$lte": [
                                        {"$subtract": ["$expireAt", "$$NOW"]},
                                        expiry_ms,
                                    ]
                                },
                                "then": 0,
                                "else": {"$ifNull": ["$currentCount", 0]},
                            }
                        },
                        "expireAt": {
                            "$cond": {
                                "if": {
                                    "$lte": [
                                        {"$subtract": ["$expireAt", "$$NOW"]},
                                        expiry_ms,
                                    ]
                                },
                                "then": {"$add": ["$expireAt", expiry_ms]},
                                "else": "$expireAt",
                            }
                        },
                    }
                }
            ],
            return_document=self.proxy_dependency.module.ReturnDocument.AFTER,
            projection=["currentCount", "previousCount", "expireAt"],
        ):
            expires_at = (
                (result["expireAt"].replace(tzinfo=datetime.timezone.utc).timestamp())
                if result.get("expireAt")
                else time.time()
            )
            current_ttl = max(0, expires_at - time.time())
            prev_ttl = max(0, current_ttl - expiry if result["previousCount"] else 0)

            return (
                result["previousCount"],
                prev_ttl,
                result["currentCount"],
                current_ttl,
            )
        return 0, 0.0, 0, 0.0

    def __del__(self) -> None:
        self.storage and self.storage.close()
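Note (added for illustration, not part of the committed file): the MongoDB pipeline above and the memcached/memory backends all gate a hit on the same weighted count; the sketch below restates that check in plain Python using the values get_sliding_window returns.

from math import floor


def sliding_window_allows(
    previous_count: int,
    previous_ttl: float,
    current_count: int,
    expiry: int,
    limit: int,
    amount: int = 1,
) -> bool:
    # The previous window contributes proportionally to how much of it still
    # overlaps the sliding window (previous_ttl / expiry).
    weighted = previous_count * previous_ttl / expiry + current_count
    return floor(weighted) + amount <= limit


# e.g. 4 hits in the previous window with half of it still relevant, plus
# 2 hits in the current window -> floor(4 * 0.5 + 2) = 4
assert sliding_window_allows(4, 5.0, 2, expiry=10, limit=5)
assert not sliding_window_allows(4, 5.0, 2, expiry=10, limit=4)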
@ -0,0 +1,400 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from deprecated.sphinx import versionadded, versionchanged
|
||||
from packaging.version import Version
|
||||
|
||||
from limits.aio.storage import MovingWindowSupport, SlidingWindowCounterSupport, Storage
|
||||
from limits.aio.storage.redis.bridge import RedisBridge
|
||||
from limits.aio.storage.redis.coredis import CoredisBridge
|
||||
from limits.aio.storage.redis.redispy import RedispyBridge
|
||||
from limits.aio.storage.redis.valkey import ValkeyBridge
|
||||
from limits.typing import Literal
|
||||
|
||||
|
||||
@versionadded(version="2.1")
|
||||
@versionchanged(
|
||||
version="4.2",
|
||||
reason=(
|
||||
"Added support for using the asyncio redis client from :pypi:`redis`"
|
||||
" through :paramref:`implementation`"
|
||||
),
|
||||
)
|
||||
@versionchanged(
|
||||
version="4.3",
|
||||
reason=(
|
||||
"Added support for using the asyncio redis client from :pypi:`valkey`"
|
||||
" through :paramref:`implementation` or if :paramref:`uri` has the"
|
||||
" ``async+valkey`` schema"
|
||||
),
|
||||
)
|
||||
class RedisStorage(Storage, MovingWindowSupport, SlidingWindowCounterSupport):
|
||||
"""
|
||||
Rate limit storage with redis as backend.
|
||||
|
||||
Depends on :pypi:`coredis` or :pypi:`redis`
|
||||
"""
|
||||
|
||||
STORAGE_SCHEME = [
|
||||
"async+redis",
|
||||
"async+rediss",
|
||||
"async+redis+unix",
|
||||
"async+valkey",
|
||||
"async+valkeys",
|
||||
"async+valkey+unix",
|
||||
]
|
||||
"""
|
||||
The storage schemes for redis to be used in an async context
|
||||
"""
|
||||
DEPENDENCIES = {
|
||||
"redis": Version("5.2.0"),
|
||||
"coredis": Version("3.4.0"),
|
||||
"valkey": Version("6.0"),
|
||||
}
|
||||
MODE: Literal["BASIC", "CLUSTER", "SENTINEL"] = "BASIC"
|
||||
bridge: RedisBridge
|
||||
storage_exceptions: tuple[Exception, ...]
|
||||
target_server: Literal["redis", "valkey"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
uri: str,
|
||||
wrap_exceptions: bool = False,
|
||||
implementation: Literal["redispy", "coredis", "valkey"] = "coredis",
|
||||
**options: float | str | bool,
|
||||
) -> None:
|
||||
"""
|
||||
:param uri: uri of the form:
|
||||
|
||||
- ``async+redis://[:password]@host:port``
|
||||
- ``async+redis://[:password]@host:port/db``
|
||||
- ``async+rediss://[:password]@host:port``
|
||||
- ``async+redis+unix:///path/to/sock?db=0`` etc...
|
||||
|
||||
This uri is passed directly to :meth:`coredis.Redis.from_url` or
|
||||
:meth:`redis.asyncio.client.Redis.from_url` with the initial ``async`` removed,
|
||||
except for the case of ``async+redis+unix`` where it is replaced with ``unix``.
|
||||
|
||||
If the uri scheme is ``async+valkey`` the implementation used will be from
|
||||
:pypi:`valkey`.
|
||||
:param connection_pool: if provided, the redis client is initialized with
|
||||
the connection pool and any other params passed as :paramref:`options`
|
||||
:param wrap_exceptions: Whether to wrap storage exceptions in
|
||||
:exc:`limits.errors.StorageError` before raising it.
|
||||
:param implementation: Whether to use the client implementation from
|
||||
|
||||
- ``coredis``: :class:`coredis.Redis`
|
||||
- ``redispy``: :class:`redis.asyncio.client.Redis`
|
||||
- ``valkey``: :class:`valkey.asyncio.client.Valkey`
|
||||
|
||||
:param options: all remaining keyword arguments are passed
|
||||
directly to the constructor of :class:`coredis.Redis` or :class:`redis.asyncio.client.Redis`
|
||||
:raise ConfigurationError: when the redis library is not available
|
||||
"""
|
||||
uri = uri.removeprefix("async+")
|
||||
self.target_server = "redis" if uri.startswith("redis") else "valkey"
|
||||
uri = uri.replace(f"{self.target_server}+unix", "unix")
|
||||
|
||||
super().__init__(uri, wrap_exceptions=wrap_exceptions)
|
||||
self.options = options
|
||||
if self.target_server == "valkey" or implementation == "valkey":
|
||||
self.bridge = ValkeyBridge(uri, self.dependencies["valkey"].module)
|
||||
else:
|
||||
if implementation == "redispy":
|
||||
self.bridge = RedispyBridge(uri, self.dependencies["redis"].module)
|
||||
else:
|
||||
self.bridge = CoredisBridge(uri, self.dependencies["coredis"].module)
|
||||
self.configure_bridge()
|
||||
self.bridge.register_scripts()
|
||||
|
||||
def _current_window_key(self, key: str) -> str:
|
||||
"""
|
||||
Return the current window's storage key (Sliding window strategy)
|
||||
|
||||
Contrary to other strategies that have one key per rate limit item,
|
||||
this strategy has two keys per rate limit item than must be on the same machine.
|
||||
To keep the current key and the previous key on the same Redis cluster node,
|
||||
curly braces are added.
|
||||
|
||||
Eg: "{constructed_key}"
|
||||
"""
|
||||
return f"{{{key}}}"
|
||||
|
||||
def _previous_window_key(self, key: str) -> str:
|
||||
"""
|
||||
Return the previous window's storage key (Sliding window strategy).
|
||||
|
||||
Curvy braces are added on the common pattern with the current window's key,
|
||||
so the current and the previous key are stored on the same Redis cluster node.
|
||||
|
||||
Eg: "{constructed_key}/-1"
|
||||
"""
|
||||
return f"{self._current_window_key(key)}/-1"
|
||||
|
||||
def configure_bridge(self) -> None:
|
||||
self.bridge.use_basic(**self.options)
|
||||
|
||||
@property
|
||||
def base_exceptions(
|
||||
self,
|
||||
) -> type[Exception] | tuple[type[Exception], ...]: # pragma: no cover
|
||||
return self.bridge.base_exceptions
|
||||
|
||||
async def incr(self, key: str, expiry: int, amount: int = 1) -> int:
|
||||
"""
|
||||
increments the counter for a given rate limit key
|
||||
|
||||
:param key: the key to increment
|
||||
:param expiry: amount in seconds for the key to expire in
|
||||
:param amount: the number to increment by
|
||||
"""
|
||||
|
||||
return await self.bridge.incr(key, expiry, amount)
|
||||
|
||||
async def get(self, key: str) -> int:
|
||||
"""
|
||||
:param key: the key to get the counter value for
|
||||
"""
|
||||
|
||||
return await self.bridge.get(key)
|
||||
|
||||
async def clear(self, key: str) -> None:
|
||||
"""
|
||||
:param key: the key to clear rate limits for
|
||||
"""
|
||||
|
||||
return await self.bridge.clear(key)
|
||||
|
||||
async def acquire_entry(
|
||||
self, key: str, limit: int, expiry: int, amount: int = 1
|
||||
) -> bool:
|
||||
"""
|
||||
:param key: rate limit key to acquire an entry in
|
||||
:param limit: amount of entries allowed
|
||||
:param expiry: expiry of the entry
|
||||
:param amount: the number of entries to acquire
|
||||
"""
|
||||
|
||||
return await self.bridge.acquire_entry(key, limit, expiry, amount)
|
||||
|
||||
async def get_moving_window(
|
||||
self, key: str, limit: int, expiry: int
|
||||
) -> tuple[float, int]:
|
||||
"""
|
||||
returns the starting point and the number of entries in the moving
|
||||
window
|
||||
|
||||
:param key: rate limit key
|
||||
:param expiry: expiry of entry
|
||||
:return: (previous count, previous TTL, current count, current TTL)
|
||||
"""
|
||||
return await self.bridge.get_moving_window(key, limit, expiry)
|
||||
|
||||
async def acquire_sliding_window_entry(
|
||||
self,
|
||||
key: str,
|
||||
limit: int,
|
||||
expiry: int,
|
||||
amount: int = 1,
|
||||
) -> bool:
|
||||
current_key = self._current_window_key(key)
|
||||
previous_key = self._previous_window_key(key)
|
||||
return await self.bridge.acquire_sliding_window_entry(
|
||||
previous_key, current_key, limit, expiry, amount
|
||||
)
|
||||
|
||||
async def get_sliding_window(
|
||||
self, key: str, expiry: int
|
||||
) -> tuple[int, float, int, float]:
|
||||
previous_key = self._previous_window_key(key)
|
||||
current_key = self._current_window_key(key)
|
||||
return await self.bridge.get_sliding_window(previous_key, current_key, expiry)
|
||||
|
||||
async def get_expiry(self, key: str) -> float:
|
||||
"""
|
||||
:param key: the key to get the expiry for
|
||||
"""
|
||||
|
||||
return await self.bridge.get_expiry(key)
|
||||
|
||||
async def check(self) -> bool:
|
||||
"""
|
||||
Check if storage is healthy by calling ``PING``
|
||||
"""
|
||||
|
||||
return await self.bridge.check()
|
||||
|
||||
async def reset(self) -> int | None:
|
||||
"""
|
||||
This function calls a Lua Script to delete keys prefixed with
|
||||
``self.PREFIX`` in blocks of 5000.
|
||||
|
||||
.. warning:: This operation was designed to be fast, but was not tested
|
||||
on a large production based system. Be careful with its usage as it
|
||||
could be slow on very large data sets.
|
||||
"""
|
||||
|
||||
return await self.bridge.lua_reset()
|
||||
|
||||
|
||||
@versionadded(version="2.1")
@versionchanged(
    version="4.2",
    reason="Added support for using the asyncio redis client from :pypi:`redis`",
)
@versionchanged(
    version="4.3",
    reason=(
        "Added support for using the asyncio redis client from :pypi:`valkey`"
        " through :paramref:`implementation` or if :paramref:`uri` has the"
        " ``async+valkey+cluster`` scheme"
    ),
)
class RedisClusterStorage(RedisStorage):
    """
    Rate limit storage with redis cluster as backend

    Depends on :pypi:`coredis` or :pypi:`redis`
    """

    STORAGE_SCHEME = ["async+redis+cluster", "async+valkey+cluster"]
    """
    The storage schemes for redis cluster to be used in an async context
    """

    MODE = "CLUSTER"

    def __init__(
        self,
        uri: str,
        wrap_exceptions: bool = False,
        implementation: Literal["redispy", "coredis", "valkey"] = "coredis",
        **options: float | str | bool,
    ) -> None:
        """
        :param uri: url of the form
         ``async+redis+cluster://[:password]@host:port,host:port``

         If the uri scheme is ``async+valkey+cluster`` the implementation
         used will be from :pypi:`valkey`.
        :param wrap_exceptions: Whether to wrap storage exceptions in
         :exc:`limits.errors.StorageError` before raising it.
        :param implementation: Whether to use the client implementation from

         - ``coredis``: :class:`coredis.RedisCluster`
         - ``redispy``: :class:`redis.asyncio.cluster.RedisCluster`
         - ``valkey``: :class:`valkey.asyncio.cluster.ValkeyCluster`
        :param options: all remaining keyword arguments are passed
         directly to the constructor of :class:`coredis.RedisCluster` or
         :class:`redis.asyncio.RedisCluster`
        :raise ConfigurationError: when the redis library is not
         available or if the redis host cannot be pinged.
        """
        super().__init__(
            uri,
            wrap_exceptions=wrap_exceptions,
            implementation=implementation,
            **options,
        )

    def configure_bridge(self) -> None:
        self.bridge.use_cluster(**self.options)

    async def reset(self) -> int | None:
        """
        Redis Clusters are sharded and deleting across shards
        can't be done atomically. Because of this, this reset loops over all
        keys that are prefixed with ``self.PREFIX`` and calls delete on them,
        one at a time.

        .. warning:: This operation was not tested with extremely large data
           sets. On a large production system, care should be taken with its
           usage as it could be slow on very large data sets.
        """

        return await self.bridge.reset()
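

# Illustrative sketch (not part of the library module): constructing the
# cluster storage from a URI as described in the docstring above. The node
# addresses are placeholders; ``implementation`` may be ``coredis`` (the
# default), ``redispy`` or ``valkey``, provided that client library is
# installed and a cluster is reachable.
async def _example_cluster_storage() -> bool:
    cluster = RedisClusterStorage(
        "async+redis+cluster://localhost:7000,localhost:7001,localhost:7002",
        implementation="redispy",
    )
    # ``check`` pings the cluster and reports whether it is reachable.
    return await cluster.check()
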
@versionadded(version="2.1")
@versionchanged(
    version="4.2",
    reason="Added support for using the asyncio redis client from :pypi:`redis`",
)
@versionchanged(
    version="4.3",
    reason=(
        "Added support for using the asyncio redis client from :pypi:`valkey`"
        " through :paramref:`implementation` or if :paramref:`uri` has the"
        " ``async+valkey+sentinel`` scheme"
    ),
)
class RedisSentinelStorage(RedisStorage):
    """
    Rate limit storage with redis sentinel as backend

    Depends on :pypi:`coredis` or :pypi:`redis`
    """

    STORAGE_SCHEME = [
        "async+redis+sentinel",
        "async+valkey+sentinel",
    ]
    """The storage scheme for redis accessed via a redis sentinel installation"""

    MODE = "SENTINEL"

    DEPENDENCIES = {
        "redis": Version("5.2.0"),
        "coredis": Version("3.4.0"),
        "coredis.sentinel": Version("3.4.0"),
        "valkey": Version("6.0"),
    }

    def __init__(
        self,
        uri: str,
        wrap_exceptions: bool = False,
        implementation: Literal["redispy", "coredis", "valkey"] = "coredis",
        service_name: str | None = None,
        use_replicas: bool = True,
        sentinel_kwargs: dict[str, float | str | bool] | None = None,
        **options: float | str | bool,
    ):
        """
        :param uri: url of the form
         ``async+redis+sentinel://host:port,host:port/service_name``

         If the uri scheme is ``async+valkey+sentinel`` the implementation
         used will be from :pypi:`valkey`.
        :param wrap_exceptions: Whether to wrap storage exceptions in
         :exc:`limits.errors.StorageError` before raising it.
        :param implementation: Whether to use the client implementation from

         - ``coredis``: :class:`coredis.sentinel.Sentinel`
         - ``redispy``: :class:`redis.asyncio.sentinel.Sentinel`
         - ``valkey``: :class:`valkey.asyncio.sentinel.Sentinel`
        :param service_name: sentinel service name (if not provided in ``uri``)
        :param use_replicas: Whether to use replicas for read only operations
        :param sentinel_kwargs: optional arguments to pass as
         ``sentinel_kwargs`` to :class:`coredis.sentinel.Sentinel` or
         :class:`redis.asyncio.Sentinel`
        :param options: all remaining keyword arguments are passed
         directly to the constructor of :class:`coredis.sentinel.Sentinel` or
         :class:`redis.asyncio.sentinel.Sentinel`
        :raise ConfigurationError: when the redis library is not available
         or if the redis primary host cannot be pinged.
        """

        self.service_name = service_name
        self.use_replicas = use_replicas
        self.sentinel_kwargs = sentinel_kwargs
        super().__init__(
            uri,
            wrap_exceptions=wrap_exceptions,
            implementation=implementation,
            **options,
        )

    def configure_bridge(self) -> None:
        self.bridge.use_sentinel(
            self.service_name, self.use_replicas, self.sentinel_kwargs, **self.options
        )
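

# Illustrative sketch (not part of the library module): constructing the
# sentinel-backed storage. The sentinel addresses and service name are
# placeholders; the service name can be given either as the URI path (as
# below) or via the ``service_name`` keyword argument.
async def _example_sentinel_storage() -> bool:
    sentinel_storage = RedisSentinelStorage(
        "async+redis+sentinel://localhost:26379,localhost:26380/mymaster",
        use_replicas=True,
    )
    return await sentinel_storage.check()
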
@ -0,0 +1,119 @@

from __future__ import annotations

import urllib.parse
from abc import ABC, abstractmethod
from types import ModuleType

from limits.util import get_package_data


class RedisBridge(ABC):
    PREFIX = "LIMITS"
    RES_DIR = "resources/redis/lua_scripts"

    SCRIPT_MOVING_WINDOW = get_package_data(f"{RES_DIR}/moving_window.lua")
    SCRIPT_ACQUIRE_MOVING_WINDOW = get_package_data(
        f"{RES_DIR}/acquire_moving_window.lua"
    )
    SCRIPT_CLEAR_KEYS = get_package_data(f"{RES_DIR}/clear_keys.lua")
    SCRIPT_INCR_EXPIRE = get_package_data(f"{RES_DIR}/incr_expire.lua")
    SCRIPT_SLIDING_WINDOW = get_package_data(f"{RES_DIR}/sliding_window.lua")
    SCRIPT_ACQUIRE_SLIDING_WINDOW = get_package_data(
        f"{RES_DIR}/acquire_sliding_window.lua"
    )

    def __init__(
        self,
        uri: str,
        dependency: ModuleType,
    ) -> None:
        self.uri = uri
        self.parsed_uri = urllib.parse.urlparse(self.uri)
        self.dependency = dependency
        self.parsed_auth = {}
        if self.parsed_uri.username:
            self.parsed_auth["username"] = self.parsed_uri.username
        if self.parsed_uri.password:
            self.parsed_auth["password"] = self.parsed_uri.password

    def prefixed_key(self, key: str) -> str:
        return f"{self.PREFIX}:{key}"

    @abstractmethod
    def register_scripts(self) -> None: ...

    @abstractmethod
    def use_sentinel(
        self,
        service_name: str | None,
        use_replicas: bool,
        sentinel_kwargs: dict[str, str | float | bool] | None,
        **options: str | float | bool,
    ) -> None: ...

    @abstractmethod
    def use_basic(self, **options: str | float | bool) -> None: ...

    @abstractmethod
    def use_cluster(self, **options: str | float | bool) -> None: ...

    @property
    @abstractmethod
    def base_exceptions(
        self,
    ) -> type[Exception] | tuple[type[Exception], ...]: ...

    @abstractmethod
    async def incr(
        self,
        key: str,
        expiry: int,
        amount: int = 1,
    ) -> int: ...

    @abstractmethod
    async def get(self, key: str) -> int: ...

    @abstractmethod
    async def clear(self, key: str) -> None: ...

    @abstractmethod
    async def get_moving_window(
        self, key: str, limit: int, expiry: int
    ) -> tuple[float, int]: ...

    @abstractmethod
    async def get_sliding_window(
        self, previous_key: str, current_key: str, expiry: int
    ) -> tuple[int, float, int, float]: ...

    @abstractmethod
    async def acquire_entry(
        self,
        key: str,
        limit: int,
        expiry: int,
        amount: int = 1,
    ) -> bool: ...

    @abstractmethod
    async def acquire_sliding_window_entry(
        self,
        previous_key: str,
        current_key: str,
        limit: int,
        expiry: int,
        amount: int = 1,
    ) -> bool: ...

    @abstractmethod
    async def get_expiry(self, key: str) -> float: ...

    @abstractmethod
    async def check(self) -> bool: ...

    @abstractmethod
    async def reset(self) -> int | None: ...

    @abstractmethod
    async def lua_reset(self) -> int | None: ...
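

# Illustrative sketch (not part of the library module): mirrors what
# ``RedisBridge.__init__`` and ``prefixed_key`` do with a URI, using only the
# standard library so it can be checked without a Redis client installed.
# The URI and the rate limit key below are placeholders.
def _example_bridge_parsing() -> None:
    parsed = urllib.parse.urlparse("async+redis://user:secret@localhost:6379/0")
    parsed_auth = {}
    if parsed.username:
        parsed_auth["username"] = parsed.username
    if parsed.password:
        parsed_auth["password"] = parsed.password
    assert parsed_auth == {"username": "user", "password": "secret"}

    # Keys handed to a bridge are namespaced under the ``LIMITS`` prefix.
    assert f"{RedisBridge.PREFIX}:some/rate/limit/key" == "LIMITS:some/rate/limit/key"
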
@ -0,0 +1,205 @@

from __future__ import annotations

import time
from typing import TYPE_CHECKING, cast

from limits.aio.storage.redis.bridge import RedisBridge
from limits.errors import ConfigurationError
from limits.typing import AsyncCoRedisClient, Callable

if TYPE_CHECKING:
    import coredis


class CoredisBridge(RedisBridge):
    DEFAULT_CLUSTER_OPTIONS: dict[str, float | str | bool] = {
        "max_connections": 1000,
    }
    "Default options passed to :class:`coredis.RedisCluster`"

    @property
    def base_exceptions(self) -> type[Exception] | tuple[type[Exception], ...]:
        return (self.dependency.exceptions.RedisError,)

    def use_sentinel(
        self,
        service_name: str | None,
        use_replicas: bool,
        sentinel_kwargs: dict[str, str | float | bool] | None,
        **options: str | float | bool,
    ) -> None:
        sentinel_configuration = []
        connection_options = options.copy()

        sep = self.parsed_uri.netloc.find("@") + 1

        for loc in self.parsed_uri.netloc[sep:].split(","):
            host, port = loc.split(":")
            sentinel_configuration.append((host, int(port)))
        service_name = (
            self.parsed_uri.path.replace("/", "")
            if self.parsed_uri.path
            else service_name
        )

        if service_name is None:
            raise ConfigurationError("'service_name' not provided")

        self.sentinel = self.dependency.sentinel.Sentinel(
            sentinel_configuration,
            sentinel_kwargs={**self.parsed_auth, **(sentinel_kwargs or {})},
            **{**self.parsed_auth, **connection_options},
        )
        self.storage = self.sentinel.primary_for(service_name)
        self.storage_replica = self.sentinel.replica_for(service_name)
        self.connection_getter = lambda readonly: (
            self.storage_replica if readonly and use_replicas else self.storage
        )

    def use_basic(self, **options: str | float | bool) -> None:
        if connection_pool := options.pop("connection_pool", None):
            self.storage = self.dependency.Redis(
                connection_pool=connection_pool, **options
            )
        else:
            self.storage = self.dependency.Redis.from_url(self.uri, **options)

        self.connection_getter = lambda _: self.storage

    def use_cluster(self, **options: str | float | bool) -> None:
        sep = self.parsed_uri.netloc.find("@") + 1
        cluster_hosts: list[dict[str, int | str]] = []
        cluster_hosts.extend(
            {"host": host, "port": int(port)}
            for loc in self.parsed_uri.netloc[sep:].split(",")
            if loc
            for host, port in [loc.split(":")]
        )
        self.storage = self.dependency.RedisCluster(
            startup_nodes=cluster_hosts,
            **{**self.DEFAULT_CLUSTER_OPTIONS, **self.parsed_auth, **options},
        )
        self.connection_getter = lambda _: self.storage

    lua_moving_window: coredis.commands.Script[bytes]
    lua_acquire_moving_window: coredis.commands.Script[bytes]
    lua_sliding_window: coredis.commands.Script[bytes]
    lua_acquire_sliding_window: coredis.commands.Script[bytes]
    lua_clear_keys: coredis.commands.Script[bytes]
    lua_incr_expire: coredis.commands.Script[bytes]
    connection_getter: Callable[[bool], AsyncCoRedisClient]

    def get_connection(self, readonly: bool = False) -> AsyncCoRedisClient:
        return self.connection_getter(readonly)

    def register_scripts(self) -> None:
        self.lua_moving_window = self.get_connection().register_script(
            self.SCRIPT_MOVING_WINDOW
        )
        self.lua_acquire_moving_window = self.get_connection().register_script(
            self.SCRIPT_ACQUIRE_MOVING_WINDOW
        )
        self.lua_clear_keys = self.get_connection().register_script(
            self.SCRIPT_CLEAR_KEYS
        )
        self.lua_incr_expire = self.get_connection().register_script(
            self.SCRIPT_INCR_EXPIRE
        )
        self.lua_sliding_window = self.get_connection().register_script(
            self.SCRIPT_SLIDING_WINDOW
        )
        self.lua_acquire_sliding_window = self.get_connection().register_script(
            self.SCRIPT_ACQUIRE_SLIDING_WINDOW
        )

    async def incr(self, key: str, expiry: int, amount: int = 1) -> int:
        key = self.prefixed_key(key)
        if (value := await self.get_connection().incrby(key, amount)) == amount:
            await self.get_connection().expire(key, expiry)
        return value

    async def get(self, key: str) -> int:
        key = self.prefixed_key(key)
        return int(await self.get_connection(readonly=True).get(key) or 0)

    async def clear(self, key: str) -> None:
        key = self.prefixed_key(key)
        await self.get_connection().delete([key])

    async def lua_reset(self) -> int | None:
        return cast(int, await self.lua_clear_keys.execute([self.prefixed_key("*")]))

    async def get_moving_window(
        self, key: str, limit: int, expiry: int
    ) -> tuple[float, int]:
        key = self.prefixed_key(key)
        timestamp = time.time()
        window = await self.lua_moving_window.execute(
            [key], [timestamp - expiry, limit]
        )
        if window:
            return float(window[0]), window[1]  # type: ignore
        return timestamp, 0

    async def get_sliding_window(
        self, previous_key: str, current_key: str, expiry: int
    ) -> tuple[int, float, int, float]:
        previous_key = self.prefixed_key(previous_key)
        current_key = self.prefixed_key(current_key)

        if window := await self.lua_sliding_window.execute(
            [previous_key, current_key], [expiry]
        ):
            return (
                int(window[0] or 0),  # type: ignore
                max(0, float(window[1] or 0)) / 1000,  # type: ignore
                int(window[2] or 0),  # type: ignore
                max(0, float(window[3] or 0)) / 1000,  # type: ignore
            )
        return 0, 0.0, 0, 0.0

    async def acquire_entry(
        self, key: str, limit: int, expiry: int, amount: int = 1
    ) -> bool:
        key = self.prefixed_key(key)
        timestamp = time.time()
        acquired = await self.lua_acquire_moving_window.execute(
            [key], [timestamp, limit, expiry, amount]
        )

        return bool(acquired)

    async def acquire_sliding_window_entry(
        self,
        previous_key: str,
        current_key: str,
        limit: int,
        expiry: int,
        amount: int = 1,
    ) -> bool:
        previous_key = self.prefixed_key(previous_key)
        current_key = self.prefixed_key(current_key)
        acquired = await self.lua_acquire_sliding_window.execute(
            [previous_key, current_key], [limit, expiry, amount]
        )
        return bool(acquired)

    async def get_expiry(self, key: str) -> float:
        key = self.prefixed_key(key)
        return max(await self.get_connection().ttl(key), 0) + time.time()

    async def check(self) -> bool:
        try:
            await self.get_connection().ping()

            return True
        except:  # noqa
            return False

    async def reset(self) -> int | None:
        prefix = self.prefixed_key("*")
        keys = await self.storage.keys(prefix)
        count = 0
        for key in keys:
            count += await self.storage.delete([key])
        return count
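

# Illustrative sketch (not part of the library module): mirrors the netloc
# parsing performed in ``use_cluster`` above for a placeholder cluster URI,
# without needing coredis installed.
def _example_cluster_host_parsing() -> None:
    import urllib.parse

    parsed = urllib.parse.urlparse(
        "async+redis+cluster://:secret@localhost:7000,localhost:7001"
    )
    sep = parsed.netloc.find("@") + 1
    cluster_hosts = [
        {"host": host, "port": int(port)}
        for loc in parsed.netloc[sep:].split(",")
        if loc
        for host, port in [loc.split(":")]
    ]
    assert cluster_hosts == [
        {"host": "localhost", "port": 7000},
        {"host": "localhost", "port": 7001},
    ]
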
@ -0,0 +1,250 @@

from __future__ import annotations

import time
from typing import TYPE_CHECKING, cast

from limits.aio.storage.redis.bridge import RedisBridge
from limits.errors import ConfigurationError
from limits.typing import AsyncRedisClient, Callable

if TYPE_CHECKING:
    import redis.commands


class RedispyBridge(RedisBridge):
    DEFAULT_CLUSTER_OPTIONS: dict[str, float | str | bool] = {
        "max_connections": 1000,
    }
    "Default options passed to :class:`redis.asyncio.RedisCluster`"

    @property
    def base_exceptions(self) -> type[Exception] | tuple[type[Exception], ...]:
        return (self.dependency.RedisError,)

    def use_sentinel(
        self,
        service_name: str | None,
        use_replicas: bool,
        sentinel_kwargs: dict[str, str | float | bool] | None,
        **options: str | float | bool,
    ) -> None:
        sentinel_configuration = []

        connection_options = options.copy()

        sep = self.parsed_uri.netloc.find("@") + 1

        for loc in self.parsed_uri.netloc[sep:].split(","):
            host, port = loc.split(":")
            sentinel_configuration.append((host, int(port)))
        service_name = (
            self.parsed_uri.path.replace("/", "")
            if self.parsed_uri.path
            else service_name
        )

        if service_name is None:
            raise ConfigurationError("'service_name' not provided")

        self.sentinel = self.dependency.asyncio.Sentinel(
            sentinel_configuration,
            sentinel_kwargs={**self.parsed_auth, **(sentinel_kwargs or {})},
            **{**self.parsed_auth, **connection_options},
        )
        self.storage = self.sentinel.master_for(service_name)
        self.storage_replica = self.sentinel.slave_for(service_name)
        self.connection_getter = lambda readonly: (
            self.storage_replica if readonly and use_replicas else self.storage
        )

    def use_basic(self, **options: str | float | bool) -> None:
        if connection_pool := options.pop("connection_pool", None):
            self.storage = self.dependency.asyncio.Redis(
                connection_pool=connection_pool, **options
            )
        else:
            self.storage = self.dependency.asyncio.Redis.from_url(self.uri, **options)

        self.connection_getter = lambda _: self.storage

    def use_cluster(self, **options: str | float | bool) -> None:
        sep = self.parsed_uri.netloc.find("@") + 1
        cluster_hosts = []

        for loc in self.parsed_uri.netloc[sep:].split(","):
            host, port = loc.split(":")
            cluster_hosts.append(
                self.dependency.asyncio.cluster.ClusterNode(host=host, port=int(port))
            )

        self.storage = self.dependency.asyncio.RedisCluster(
            startup_nodes=cluster_hosts,
            **{**self.DEFAULT_CLUSTER_OPTIONS, **self.parsed_auth, **options},
        )
        self.connection_getter = lambda _: self.storage

    lua_moving_window: redis.commands.core.Script
    lua_acquire_moving_window: redis.commands.core.Script
    lua_sliding_window: redis.commands.core.Script
    lua_acquire_sliding_window: redis.commands.core.Script
    lua_clear_keys: redis.commands.core.Script
    lua_incr_expire: redis.commands.core.Script
    connection_getter: Callable[[bool], AsyncRedisClient]

    def get_connection(self, readonly: bool = False) -> AsyncRedisClient:
        return self.connection_getter(readonly)

    def register_scripts(self) -> None:
        # Redis-py uses a slightly different script registration
        self.lua_moving_window = self.get_connection().register_script(
            self.SCRIPT_MOVING_WINDOW
        )
        self.lua_acquire_moving_window = self.get_connection().register_script(
            self.SCRIPT_ACQUIRE_MOVING_WINDOW
        )
        self.lua_clear_keys = self.get_connection().register_script(
            self.SCRIPT_CLEAR_KEYS
        )
        self.lua_incr_expire = self.get_connection().register_script(
            self.SCRIPT_INCR_EXPIRE
        )
        self.lua_sliding_window = self.get_connection().register_script(
            self.SCRIPT_SLIDING_WINDOW
        )
        self.lua_acquire_sliding_window = self.get_connection().register_script(
            self.SCRIPT_ACQUIRE_SLIDING_WINDOW
        )

    async def incr(
        self,
        key: str,
        expiry: int,
        amount: int = 1,
    ) -> int:
        """
        increments the counter for a given rate limit key

        :param key: the key to increment
        :param expiry: amount in seconds for the key to expire in
        :param amount: the number to increment by
        """
        key = self.prefixed_key(key)
        return cast(int, await self.lua_incr_expire([key], [expiry, amount]))

    async def get(self, key: str) -> int:
        """
        :param key: the key to get the counter value for
        """

        key = self.prefixed_key(key)
        return int(await self.get_connection(readonly=True).get(key) or 0)

    async def clear(self, key: str) -> None:
        """
        :param key: the key to clear rate limits for
        """
        key = self.prefixed_key(key)
        await self.get_connection().delete(key)

    async def lua_reset(self) -> int | None:
        return cast(int, await self.lua_clear_keys([self.prefixed_key("*")]))

    async def get_moving_window(
        self, key: str, limit: int, expiry: int
    ) -> tuple[float, int]:
        """
        returns the starting point and the number of entries in the moving
        window

        :param key: rate limit key
        :param expiry: expiry of entry
        :return: (start of window, number of acquired entries)
        """
        key = self.prefixed_key(key)
        timestamp = time.time()
        window = await self.lua_moving_window([key], [timestamp - expiry, limit])
        if window:
            return float(window[0]), window[1]
        return timestamp, 0

    async def get_sliding_window(
        self, previous_key: str, current_key: str, expiry: int
    ) -> tuple[int, float, int, float]:
        if window := await self.lua_sliding_window(
            [self.prefixed_key(previous_key), self.prefixed_key(current_key)], [expiry]
        ):
            return (
                int(window[0] or 0),
                max(0, float(window[1] or 0)) / 1000,
                int(window[2] or 0),
                max(0, float(window[3] or 0)) / 1000,
            )
        return 0, 0.0, 0, 0.0

    async def acquire_entry(
        self,
        key: str,
        limit: int,
        expiry: int,
        amount: int = 1,
    ) -> bool:
        """
        :param key: rate limit key to acquire an entry in
        :param limit: amount of entries allowed
        :param expiry: expiry of the entry
        :param amount: the number of entries to acquire
        """
        key = self.prefixed_key(key)
        timestamp = time.time()
        acquired = await self.lua_acquire_moving_window(
            [key], [timestamp, limit, expiry, amount]
        )

        return bool(acquired)

    async def acquire_sliding_window_entry(
        self,
        previous_key: str,
        current_key: str,
        limit: int,
        expiry: int,
        amount: int = 1,
    ) -> bool:
        previous_key = self.prefixed_key(previous_key)
        current_key = self.prefixed_key(current_key)
        acquired = await self.lua_acquire_sliding_window(
            [previous_key, current_key], [limit, expiry, amount]
        )
        return bool(acquired)

    async def get_expiry(self, key: str) -> float:
        """
        :param key: the key to get the expiry for
        """

        key = self.prefixed_key(key)
        return max(await self.get_connection().ttl(key), 0) + time.time()

    async def check(self) -> bool:
        """
        check if storage is healthy
        """
        try:
            await self.get_connection().ping()

            return True
        except:  # noqa
            return False

    async def reset(self) -> int | None:
        prefix = self.prefixed_key("*")
        keys = await self.storage.keys(
            prefix, target_nodes=self.dependency.asyncio.cluster.RedisCluster.ALL_NODES
        )
        count = 0
        for key in keys:
            count += await self.storage.delete(key)
        return count
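

# Illustrative sketch (not part of the library module): the sliding-window Lua
# script returns ``[previous_count, previous_ttl, current_count, current_ttl]``
# with the TTLs in milliseconds; ``get_sliding_window`` above converts them to
# seconds. A worked example with placeholder values:
def _example_sliding_window_conversion() -> None:
    window = [4, 15000, 2, 45000]  # raw script reply (counts, TTLs in ms)
    converted = (
        int(window[0] or 0),
        max(0, float(window[1] or 0)) / 1000,
        int(window[2] or 0),
        max(0, float(window[3] or 0)) / 1000,
    )
    assert converted == (4, 15.0, 2, 45.0)
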
@ -0,0 +1,9 @@

from __future__ import annotations

from .redispy import RedispyBridge


class ValkeyBridge(RedispyBridge):
    @property
    def base_exceptions(self) -> type[Exception] | tuple[type[Exception], ...]:
        return (self.dependency.ValkeyError,)

310
venv/lib/python3.11/site-packages/limits/aio/strategies.py
Normal file
@ -0,0 +1,310 @@

"""
Asynchronous rate limiting strategies
"""

from __future__ import annotations

import time
from abc import ABC, abstractmethod
from math import floor, inf

from deprecated.sphinx import versionadded

from ..limits import RateLimitItem
from ..storage import StorageTypes
from ..typing import cast
from ..util import WindowStats
from .storage import MovingWindowSupport, Storage
from .storage.base import SlidingWindowCounterSupport


class RateLimiter(ABC):
    def __init__(self, storage: StorageTypes):
        assert isinstance(storage, Storage)
        self.storage: Storage = storage

    @abstractmethod
    async def hit(self, item: RateLimitItem, *identifiers: str, cost: int = 1) -> bool:
        """
        Consume the rate limit

        :param item: the rate limit item
        :param identifiers: variable list of strings to uniquely identify the
         limit
        :param cost: The cost of this hit, default 1
        """
        raise NotImplementedError

    @abstractmethod
    async def test(self, item: RateLimitItem, *identifiers: str, cost: int = 1) -> bool:
        """
        Check if the rate limit can be consumed

        :param item: the rate limit item
        :param identifiers: variable list of strings to uniquely identify the
         limit
        :param cost: The expected cost to be consumed, default 1
        """
        raise NotImplementedError

    @abstractmethod
    async def get_window_stats(
        self, item: RateLimitItem, *identifiers: str
    ) -> WindowStats:
        """
        Query the reset time and remaining amount for the limit

        :param item: the rate limit item
        :param identifiers: variable list of strings to uniquely identify the
         limit
        :return: (reset time, remaining)
        """
        raise NotImplementedError

    async def clear(self, item: RateLimitItem, *identifiers: str) -> None:
        return await self.storage.clear(item.key_for(*identifiers))
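

# Illustrative sketch (not part of the library module): ``get_window_stats``
# returns a ``WindowStats`` value that unpacks as ``(reset time, remaining)``,
# as documented above. Assuming ``WindowStats`` is the two-field tuple its use
# in this module suggests, with placeholder numbers:
def _example_window_stats() -> None:
    stats = WindowStats(1_700_000_000.0, 3)  # placeholder reset time, 3 remaining
    reset_time, remaining = stats
    assert reset_time == 1_700_000_000.0
    assert remaining == 3
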
class MovingWindowRateLimiter(RateLimiter):
    """
    Reference: :ref:`strategies:moving window`
    """

    def __init__(self, storage: StorageTypes) -> None:
        if not (
            hasattr(storage, "acquire_entry") or hasattr(storage, "get_moving_window")
        ):
            raise NotImplementedError(
                "MovingWindowRateLimiting is not implemented for storage "
                f"of type {storage.__class__}"
            )
        super().__init__(storage)

    async def hit(self, item: RateLimitItem, *identifiers: str, cost: int = 1) -> bool:
        """
        Consume the rate limit

        :param item: the rate limit item
        :param identifiers: variable list of strings to uniquely identify the
         limit
        :param cost: The cost of this hit, default 1
        """

        return await cast(MovingWindowSupport, self.storage).acquire_entry(
            item.key_for(*identifiers), item.amount, item.get_expiry(), amount=cost
        )

    async def test(self, item: RateLimitItem, *identifiers: str, cost: int = 1) -> bool:
        """
        Check if the rate limit can be consumed

        :param item: the rate limit item
        :param identifiers: variable list of strings to uniquely identify the
         limit
        :param cost: The expected cost to be consumed, default 1
        """
        res = await cast(MovingWindowSupport, self.storage).get_moving_window(
            item.key_for(*identifiers),
            item.amount,
            item.get_expiry(),
        )
        amount = res[1]

        return amount <= item.amount - cost

    async def get_window_stats(
        self, item: RateLimitItem, *identifiers: str
    ) -> WindowStats:
        """
        returns the number of requests remaining within this limit.

        :param item: the rate limit item
        :param identifiers: variable list of strings to uniquely identify the
         limit
        :return: (reset time, remaining)
        """
        window_start, window_items = await cast(
            MovingWindowSupport, self.storage
        ).get_moving_window(item.key_for(*identifiers), item.amount, item.get_expiry())
        reset = window_start + item.get_expiry()

        return WindowStats(reset, item.amount - window_items)
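

# Illustrative sketch (not part of the library module): ``test`` above allows
# a hit of ``cost`` entries only while the moving window holds at most
# ``item.amount - cost`` entries. A worked example with placeholder numbers:
def _example_moving_window_test_check() -> None:
    amount, cost = 10, 2  # limit of 10 entries per window, hit costs 2
    assert 8 <= amount - cost  # 8 entries already in the window: hit can proceed
    assert not (9 <= amount - cost)  # 9 entries: a cost-2 hit would exceed 10
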
class FixedWindowRateLimiter(RateLimiter):
    """
    Reference: :ref:`strategies:fixed window`
    """

    async def hit(self, item: RateLimitItem, *identifiers: str, cost: int = 1) -> bool:
        """
        Consume the rate limit

        :param item: the rate limit item
        :param identifiers: variable list of strings to uniquely identify the
         limit
        :param cost: The cost of this hit, default 1
        """

        return (
            await self.storage.incr(
                item.key_for(*identifiers),
                item.get_expiry(),
                amount=cost,
            )
            <= item.amount
        )

    async def test(self, item: RateLimitItem, *identifiers: str, cost: int = 1) -> bool:
        """
        Check if the rate limit can be consumed

        :param item: the rate limit item
        :param identifiers: variable list of strings to uniquely identify the
         limit
        :param cost: The expected cost to be consumed, default 1
        """

        return (
            await self.storage.get(item.key_for(*identifiers)) < item.amount - cost + 1
        )

    async def get_window_stats(
        self, item: RateLimitItem, *identifiers: str
    ) -> WindowStats:
        """
        Query the reset time and remaining amount for the limit

        :param item: the rate limit item
        :param identifiers: variable list of strings to uniquely identify the
         limit
        :return: (reset time, remaining)
        """
        remaining = max(
            0,
            item.amount - await self.storage.get(item.key_for(*identifiers)),
        )
        reset = await self.storage.get_expiry(item.key_for(*identifiers))

        return WindowStats(reset, remaining)
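

# Illustrative sketch (not part of the library module): ``hit`` above succeeds
# while the incremented counter stays within ``item.amount``, and ``test``
# predicts that outcome from the stored counter (``counter < amount - cost + 1``).
# A worked example with placeholder numbers:
def _example_fixed_window_check() -> None:
    amount = 5
    # Counter value after the hit: 5 -> still allowed; 6 -> rejected.
    assert 5 <= amount
    assert not (6 <= amount)
    # test() with cost=2 allows a hit only while the stored counter is below 4.
    assert 3 < amount - 2 + 1
    assert not (4 < amount - 2 + 1)
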
@versionadded(version="4.1")
class SlidingWindowCounterRateLimiter(RateLimiter):
    """
    Reference: :ref:`strategies:sliding window counter`
    """

    def __init__(self, storage: StorageTypes):
        if not hasattr(storage, "get_sliding_window") or not hasattr(
            storage, "acquire_sliding_window_entry"
        ):
            raise NotImplementedError(
                "SlidingWindowCounterRateLimiting is not implemented for storage "
                f"of type {storage.__class__}"
            )
        super().__init__(storage)

    def _weighted_count(
        self,
        item: RateLimitItem,
        previous_count: int,
        previous_expires_in: float,
        current_count: int,
    ) -> float:
        """
        Return the approximated count, obtained by weighting the previous
        window count and adding the current window count.
        """
        return previous_count * previous_expires_in / item.get_expiry() + current_count

    async def hit(self, item: RateLimitItem, *identifiers: str, cost: int = 1) -> bool:
        """
        Consume the rate limit

        :param item: The rate limit item
        :param identifiers: variable list of strings to uniquely identify this
         instance of the limit
        :param cost: The cost of this hit, default 1
        """
        return await cast(
            SlidingWindowCounterSupport, self.storage
        ).acquire_sliding_window_entry(
            item.key_for(*identifiers),
            item.amount,
            item.get_expiry(),
            cost,
        )

    async def test(self, item: RateLimitItem, *identifiers: str, cost: int = 1) -> bool:
        """
        Check if the rate limit can be consumed

        :param item: The rate limit item
        :param identifiers: variable list of strings to uniquely identify this
         instance of the limit
        :param cost: The expected cost to be consumed, default 1
        """

        previous_count, previous_expires_in, current_count, _ = await cast(
            SlidingWindowCounterSupport, self.storage
        ).get_sliding_window(item.key_for(*identifiers), item.get_expiry())

        return (
            self._weighted_count(
                item, previous_count, previous_expires_in, current_count
            )
            < item.amount - cost + 1
        )

    async def get_window_stats(
        self, item: RateLimitItem, *identifiers: str
    ) -> WindowStats:
        """
        Query the reset time and remaining amount for the limit.

        :param item: The rate limit item
        :param identifiers: variable list of strings to uniquely identify this
         instance of the limit
        :return: (reset time, remaining)
        """

        (
            previous_count,
            previous_expires_in,
            current_count,
            current_expires_in,
        ) = await cast(SlidingWindowCounterSupport, self.storage).get_sliding_window(
            item.key_for(*identifiers), item.get_expiry()
        )

        remaining = max(
            0,
            item.amount
            - floor(
                self._weighted_count(
                    item, previous_count, previous_expires_in, current_count
                )
            ),
        )

        now = time.time()

        if not (previous_count or current_count):
            return WindowStats(now, remaining)

        expiry = item.get_expiry()

        previous_reset_in, current_reset_in = inf, inf
        if previous_count:
            previous_reset_in = previous_expires_in % (expiry / previous_count)
        if current_count:
            current_reset_in = current_expires_in % expiry

        return WindowStats(now + min(previous_reset_in, current_reset_in), remaining)
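

# Illustrative sketch (not part of the library module): ``_weighted_count``
# above approximates usage over the sliding window by weighting the previous
# window's count by the fraction of that window still in scope. With an expiry
# of 60s, 6 hits in the previous window that expires in 20s, and 3 hits in the
# current window:
def _example_weighted_count() -> None:
    previous_count, previous_expires_in, current_count, expiry = 6, 20.0, 3, 60
    weighted = previous_count * previous_expires_in / expiry + current_count
    assert weighted == 5.0  # 6 * 20/60 + 3
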
STRATEGIES = {
    "sliding-window-counter": SlidingWindowCounterRateLimiter,
    "fixed-window": FixedWindowRateLimiter,
    "moving-window": MovingWindowRateLimiter,
}
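

# Usage sketch (not part of the library module): end-to-end use of a strategy
# from the ``STRATEGIES`` registry with the async in-memory storage.
# Assumptions: ``limits.parse`` and ``limits.aio.storage.MemoryStorage`` are
# importable in this version of the library; the limit string and identifier
# are placeholders. Run with ``asyncio.run(_example_strategy_usage())``.
async def _example_strategy_usage() -> None:
    from limits import parse
    from limits.aio.storage import MemoryStorage

    limiter = STRATEGIES["fixed-window"](MemoryStorage())
    one_per_minute = parse("1/minute")

    assert await limiter.hit(one_per_minute, "client-42")
    # The single slot is used up, so further checks and hits are rejected.
    assert not await limiter.test(one_per_minute, "client-42")
    assert not await limiter.hit(one_per_minute, "client-42")

    _reset_time, remaining = await limiter.get_window_stats(
        one_per_minute, "client-42"
    )
    assert remaining == 0
    await limiter.clear(one_per_minute, "client-42")
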