Update 2025-04-24_11:44:19
This commit is contained in:
35
venv/lib/python3.11/site-packages/limits/__init__.py
Normal file
35
venv/lib/python3.11/site-packages/limits/__init__.py
Normal file
@ -0,0 +1,35 @@
|
||||
"""
|
||||
Rate limiting with commonly used storage backends
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from . import _version, aio, storage, strategies
|
||||
from .limits import (
|
||||
RateLimitItem,
|
||||
RateLimitItemPerDay,
|
||||
RateLimitItemPerHour,
|
||||
RateLimitItemPerMinute,
|
||||
RateLimitItemPerMonth,
|
||||
RateLimitItemPerSecond,
|
||||
RateLimitItemPerYear,
|
||||
)
|
||||
from .util import WindowStats, parse, parse_many
|
||||
|
||||
__all__ = [
|
||||
"RateLimitItem",
|
||||
"RateLimitItemPerDay",
|
||||
"RateLimitItemPerHour",
|
||||
"RateLimitItemPerMinute",
|
||||
"RateLimitItemPerMonth",
|
||||
"RateLimitItemPerSecond",
|
||||
"RateLimitItemPerYear",
|
||||
"WindowStats",
|
||||
"aio",
|
||||
"parse",
|
||||
"parse_many",
|
||||
"storage",
|
||||
"strategies",
|
||||
]
|
||||
|
||||
__version__ = _version.get_versions()["version"]
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
21
venv/lib/python3.11/site-packages/limits/_version.py
Normal file
21
venv/lib/python3.11/site-packages/limits/_version.py
Normal file
@ -0,0 +1,21 @@
|
||||
|
||||
# This file was generated by 'versioneer.py' (0.29) from
|
||||
# revision-control system data, or from the parent directory name of an
|
||||
# unpacked source archive. Distribution tarballs contain a pre-generated copy
|
||||
# of this file.
|
||||
|
||||
import json
|
||||
|
||||
version_json = '''
|
||||
{
|
||||
"date": "2025-04-15T17:28:21-0700",
|
||||
"dirty": false,
|
||||
"error": null,
|
||||
"full-revisionid": "eeb02fd85c146292dd70fb798ac90a486ba163bd",
|
||||
"version": "5.0.0"
|
||||
}
|
||||
''' # END VERSION_JSON
|
||||
|
||||
|
||||
def get_versions():
|
||||
return json.loads(version_json)
|
8
venv/lib/python3.11/site-packages/limits/aio/__init__.py
Normal file
8
venv/lib/python3.11/site-packages/limits/aio/__init__.py
Normal file
@ -0,0 +1,8 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from . import storage, strategies
|
||||
|
||||
__all__ = [
|
||||
"storage",
|
||||
"strategies",
|
||||
]
|
Binary file not shown.
Binary file not shown.
@ -0,0 +1,24 @@
|
||||
"""
|
||||
Implementations of storage backends to be used with
|
||||
:class:`limits.aio.strategies.RateLimiter` strategies
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from .base import MovingWindowSupport, SlidingWindowCounterSupport, Storage
|
||||
from .memcached import MemcachedStorage
|
||||
from .memory import MemoryStorage
|
||||
from .mongodb import MongoDBStorage
|
||||
from .redis import RedisClusterStorage, RedisSentinelStorage, RedisStorage
|
||||
|
||||
__all__ = [
|
||||
"MemcachedStorage",
|
||||
"MemoryStorage",
|
||||
"MongoDBStorage",
|
||||
"MovingWindowSupport",
|
||||
"RedisClusterStorage",
|
||||
"RedisSentinelStorage",
|
||||
"RedisStorage",
|
||||
"SlidingWindowCounterSupport",
|
||||
"Storage",
|
||||
]
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
220
venv/lib/python3.11/site-packages/limits/aio/storage/base.py
Normal file
220
venv/lib/python3.11/site-packages/limits/aio/storage/base.py
Normal file
@ -0,0 +1,220 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import functools
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from deprecated.sphinx import versionadded
|
||||
|
||||
from limits import errors
|
||||
from limits.storage.registry import StorageRegistry
|
||||
from limits.typing import (
|
||||
Any,
|
||||
Awaitable,
|
||||
Callable,
|
||||
P,
|
||||
R,
|
||||
cast,
|
||||
)
|
||||
from limits.util import LazyDependency
|
||||
|
||||
|
||||
def _wrap_errors(
|
||||
fn: Callable[P, Awaitable[R]],
|
||||
) -> Callable[P, Awaitable[R]]:
|
||||
@functools.wraps(fn)
|
||||
async def inner(*args: P.args, **kwargs: P.kwargs) -> R: # type: ignore[misc]
|
||||
instance = cast(Storage, args[0])
|
||||
try:
|
||||
return await fn(*args, **kwargs)
|
||||
except instance.base_exceptions as exc:
|
||||
if instance.wrap_exceptions:
|
||||
raise errors.StorageError(exc) from exc
|
||||
raise
|
||||
|
||||
return inner
|
||||
|
||||
|
||||
@versionadded(version="2.1")
|
||||
class Storage(LazyDependency, metaclass=StorageRegistry):
|
||||
"""
|
||||
Base class to extend when implementing an async storage backend.
|
||||
"""
|
||||
|
||||
STORAGE_SCHEME: list[str] | None
|
||||
"""The storage schemes to register against this implementation"""
|
||||
|
||||
def __init_subclass__(cls, **kwargs: Any) -> None: # type:ignore[explicit-any]
|
||||
super().__init_subclass__(**kwargs)
|
||||
for method in {
|
||||
"incr",
|
||||
"get",
|
||||
"get_expiry",
|
||||
"check",
|
||||
"reset",
|
||||
"clear",
|
||||
}:
|
||||
setattr(cls, method, _wrap_errors(getattr(cls, method)))
|
||||
super().__init_subclass__(**kwargs)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
uri: str | None = None,
|
||||
wrap_exceptions: bool = False,
|
||||
**options: float | str | bool,
|
||||
) -> None:
|
||||
"""
|
||||
:param wrap_exceptions: Whether to wrap storage exceptions in
|
||||
:exc:`limits.errors.StorageError` before raising it.
|
||||
"""
|
||||
super().__init__()
|
||||
self.wrap_exceptions = wrap_exceptions
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def base_exceptions(self) -> type[Exception] | tuple[type[Exception], ...]:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def incr(self, key: str, expiry: int, amount: int = 1) -> int:
|
||||
"""
|
||||
increments the counter for a given rate limit key
|
||||
|
||||
:param key: the key to increment
|
||||
:param expiry: amount in seconds for the key to expire in
|
||||
:param amount: the number to increment by
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def get(self, key: str) -> int:
|
||||
"""
|
||||
:param key: the key to get the counter value for
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def get_expiry(self, key: str) -> float:
|
||||
"""
|
||||
:param key: the key to get the expiry for
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def check(self) -> bool:
|
||||
"""
|
||||
check if storage is healthy
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def reset(self) -> int | None:
|
||||
"""
|
||||
reset storage to clear limits
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def clear(self, key: str) -> None:
|
||||
"""
|
||||
resets the rate limit key
|
||||
|
||||
:param key: the key to clear rate limits for
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class MovingWindowSupport(ABC):
|
||||
"""
|
||||
Abstract base class for async storages that support
|
||||
the :ref:`strategies:moving window` strategy
|
||||
"""
|
||||
|
||||
def __init_subclass__(cls, **kwargs: Any) -> None: # type: ignore[explicit-any]
|
||||
for method in {
|
||||
"acquire_entry",
|
||||
"get_moving_window",
|
||||
}:
|
||||
setattr(
|
||||
cls,
|
||||
method,
|
||||
_wrap_errors(getattr(cls, method)),
|
||||
)
|
||||
super().__init_subclass__(**kwargs)
|
||||
|
||||
@abstractmethod
|
||||
async def acquire_entry(
|
||||
self, key: str, limit: int, expiry: int, amount: int = 1
|
||||
) -> bool:
|
||||
"""
|
||||
:param key: rate limit key to acquire an entry in
|
||||
:param limit: amount of entries allowed
|
||||
:param expiry: expiry of the entry
|
||||
:param amount: the number of entries to acquire
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def get_moving_window(
|
||||
self, key: str, limit: int, expiry: int
|
||||
) -> tuple[float, int]:
|
||||
"""
|
||||
returns the starting point and the number of entries in the moving
|
||||
window
|
||||
|
||||
:param key: rate limit key
|
||||
:param expiry: expiry of entry
|
||||
:return: (start of window, number of acquired entries)
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class SlidingWindowCounterSupport(ABC):
|
||||
"""
|
||||
Abstract base class for async storages that support
|
||||
the :ref:`strategies:sliding window counter` strategy
|
||||
"""
|
||||
|
||||
def __init_subclass__(cls, **kwargs: Any) -> None: # type: ignore[explicit-any]
|
||||
for method in {"acquire_sliding_window_entry", "get_sliding_window"}:
|
||||
setattr(
|
||||
cls,
|
||||
method,
|
||||
_wrap_errors(getattr(cls, method)),
|
||||
)
|
||||
super().__init_subclass__(**kwargs)
|
||||
|
||||
@abstractmethod
|
||||
async def acquire_sliding_window_entry(
|
||||
self,
|
||||
key: str,
|
||||
limit: int,
|
||||
expiry: int,
|
||||
amount: int = 1,
|
||||
) -> bool:
|
||||
"""
|
||||
Acquire an entry if the weighted count of the current and previous
|
||||
windows is less than or equal to the limit
|
||||
|
||||
:param key: rate limit key to acquire an entry in
|
||||
:param limit: amount of entries allowed
|
||||
:param expiry: expiry of the entry
|
||||
:param amount: the number of entries to acquire
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def get_sliding_window(
|
||||
self, key: str, expiry: int
|
||||
) -> tuple[int, float, int, float]:
|
||||
"""
|
||||
Return the previous and current window information.
|
||||
|
||||
:param key: the rate limit key
|
||||
:param expiry: the rate limit expiry, needed to compute the key in some implementations
|
||||
:return: a tuple of (int, float, int, float) with the following information:
|
||||
- previous window counter
|
||||
- previous window TTL
|
||||
- current window counter
|
||||
- current window TTL
|
||||
"""
|
||||
raise NotImplementedError
|
@ -0,0 +1,184 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from math import floor
|
||||
|
||||
from deprecated.sphinx import versionadded, versionchanged
|
||||
from packaging.version import Version
|
||||
|
||||
from limits.aio.storage import SlidingWindowCounterSupport, Storage
|
||||
from limits.aio.storage.memcached.bridge import MemcachedBridge
|
||||
from limits.aio.storage.memcached.emcache import EmcacheBridge
|
||||
from limits.aio.storage.memcached.memcachio import MemcachioBridge
|
||||
from limits.storage.base import TimestampedSlidingWindow
|
||||
from limits.typing import Literal
|
||||
|
||||
|
||||
@versionadded(version="2.1")
|
||||
@versionchanged(
|
||||
version="5.0",
|
||||
reason="Switched default implementation to :pypi:`memcachio`",
|
||||
)
|
||||
class MemcachedStorage(Storage, SlidingWindowCounterSupport, TimestampedSlidingWindow):
|
||||
"""
|
||||
Rate limit storage with memcached as backend.
|
||||
|
||||
Depends on :pypi:`memcachio`
|
||||
"""
|
||||
|
||||
STORAGE_SCHEME = ["async+memcached"]
|
||||
"""The storage scheme for memcached to be used in an async context"""
|
||||
|
||||
DEPENDENCIES = {
|
||||
"memcachio": Version("0.3"),
|
||||
"emcache": Version("0.0"),
|
||||
}
|
||||
|
||||
bridge: MemcachedBridge
|
||||
storage_exceptions: tuple[Exception, ...]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
uri: str,
|
||||
wrap_exceptions: bool = False,
|
||||
implementation: Literal["memcachio", "emcache"] = "memcachio",
|
||||
**options: float | str | bool,
|
||||
) -> None:
|
||||
"""
|
||||
:param uri: memcached location of the form
|
||||
``async+memcached://host:port,host:port``
|
||||
:param wrap_exceptions: Whether to wrap storage exceptions in
|
||||
:exc:`limits.errors.StorageError` before raising it.
|
||||
:param implementation: Whether to use the client implementation from
|
||||
|
||||
- ``memcachio``: :class:`memcachio.Client`
|
||||
- ``emcache``: :class:`emcache.Client`
|
||||
:param options: all remaining keyword arguments are passed
|
||||
directly to the constructor of :class:`memcachio.Client`
|
||||
:raise ConfigurationError: when :pypi:`memcachio` is not available
|
||||
"""
|
||||
if implementation == "emcache":
|
||||
self.bridge = EmcacheBridge(
|
||||
uri, self.dependencies["emcache"].module, **options
|
||||
)
|
||||
else:
|
||||
self.bridge = MemcachioBridge(
|
||||
uri, self.dependencies["memcachio"].module, **options
|
||||
)
|
||||
super().__init__(uri, wrap_exceptions=wrap_exceptions, **options)
|
||||
|
||||
@property
|
||||
def base_exceptions(
|
||||
self,
|
||||
) -> type[Exception] | tuple[type[Exception], ...]: # pragma: no cover
|
||||
return self.bridge.base_exceptions
|
||||
|
||||
async def get(self, key: str) -> int:
|
||||
"""
|
||||
:param key: the key to get the counter value for
|
||||
"""
|
||||
return await self.bridge.get(key)
|
||||
|
||||
async def clear(self, key: str) -> None:
|
||||
"""
|
||||
:param key: the key to clear rate limits for
|
||||
"""
|
||||
await self.bridge.clear(key)
|
||||
|
||||
async def incr(
|
||||
self,
|
||||
key: str,
|
||||
expiry: float,
|
||||
amount: int = 1,
|
||||
set_expiration_key: bool = True,
|
||||
) -> int:
|
||||
"""
|
||||
increments the counter for a given rate limit key
|
||||
|
||||
:param key: the key to increment
|
||||
:param expiry: amount in seconds for the key to expire in
|
||||
window every hit.
|
||||
:param amount: the number to increment by
|
||||
:param set_expiration_key: if set to False, the expiration time won't be stored but the key will still expire
|
||||
"""
|
||||
return await self.bridge.incr(
|
||||
key, expiry, amount, set_expiration_key=set_expiration_key
|
||||
)
|
||||
|
||||
async def get_expiry(self, key: str) -> float:
|
||||
"""
|
||||
:param key: the key to get the expiry for
|
||||
"""
|
||||
return await self.bridge.get_expiry(key)
|
||||
|
||||
async def reset(self) -> int | None:
|
||||
raise NotImplementedError
|
||||
|
||||
async def check(self) -> bool:
|
||||
return await self.bridge.check()
|
||||
|
||||
async def acquire_sliding_window_entry(
|
||||
self,
|
||||
key: str,
|
||||
limit: int,
|
||||
expiry: int,
|
||||
amount: int = 1,
|
||||
) -> bool:
|
||||
if amount > limit:
|
||||
return False
|
||||
now = time.time()
|
||||
previous_key, current_key = self.sliding_window_keys(key, expiry, now)
|
||||
(
|
||||
previous_count,
|
||||
previous_ttl,
|
||||
current_count,
|
||||
_,
|
||||
) = await self._get_sliding_window_info(previous_key, current_key, expiry, now)
|
||||
t0 = time.time()
|
||||
weighted_count = previous_count * previous_ttl / expiry + current_count
|
||||
if floor(weighted_count) + amount > limit:
|
||||
return False
|
||||
else:
|
||||
# Hit, increase the current counter.
|
||||
# If the counter doesn't exist yet, set twice the theorical expiry.
|
||||
# We don't need the expiration key as it is estimated with the timestamps directly.
|
||||
current_count = await self.incr(
|
||||
current_key, 2 * expiry, amount=amount, set_expiration_key=False
|
||||
)
|
||||
t1 = time.time()
|
||||
actualised_previous_ttl = max(0, previous_ttl - (t1 - t0))
|
||||
weighted_count = (
|
||||
previous_count * actualised_previous_ttl / expiry + current_count
|
||||
)
|
||||
if floor(weighted_count) > limit:
|
||||
# Another hit won the race condition: revert the increment and refuse this hit
|
||||
# Limitation: during high concurrency at the end of the window,
|
||||
# the counter is shifted and cannot be decremented, so less requests than expected are allowed.
|
||||
await self.bridge.decr(current_key, amount, noreply=True)
|
||||
return False
|
||||
return True
|
||||
|
||||
async def get_sliding_window(
|
||||
self, key: str, expiry: int
|
||||
) -> tuple[int, float, int, float]:
|
||||
now = time.time()
|
||||
previous_key, current_key = self.sliding_window_keys(key, expiry, now)
|
||||
return await self._get_sliding_window_info(
|
||||
previous_key, current_key, expiry, now
|
||||
)
|
||||
|
||||
async def _get_sliding_window_info(
|
||||
self, previous_key: str, current_key: str, expiry: int, now: float
|
||||
) -> tuple[int, float, int, float]:
|
||||
result = await self.bridge.get_many([previous_key, current_key])
|
||||
|
||||
previous_count = result.get(previous_key.encode("utf-8"), 0)
|
||||
current_count = result.get(current_key.encode("utf-8"), 0)
|
||||
|
||||
if previous_count == 0:
|
||||
previous_ttl = float(0)
|
||||
else:
|
||||
previous_ttl = (1 - (((now - expiry) / expiry) % 1)) * expiry
|
||||
current_ttl = (1 - ((now / expiry) % 1)) * expiry + expiry
|
||||
|
||||
return previous_count, previous_ttl, current_count, current_ttl
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@ -0,0 +1,73 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import urllib
|
||||
from abc import ABC, abstractmethod
|
||||
from types import ModuleType
|
||||
|
||||
from limits.typing import Iterable
|
||||
|
||||
|
||||
class MemcachedBridge(ABC):
|
||||
def __init__(
|
||||
self,
|
||||
uri: str,
|
||||
dependency: ModuleType,
|
||||
**options: float | str | bool,
|
||||
) -> None:
|
||||
self.uri = uri
|
||||
self.parsed_uri = urllib.parse.urlparse(self.uri)
|
||||
self.dependency = dependency
|
||||
self.hosts = []
|
||||
self.options = options
|
||||
|
||||
sep = self.parsed_uri.netloc.strip().find("@") + 1
|
||||
for loc in self.parsed_uri.netloc.strip()[sep:].split(","):
|
||||
host, port = loc.split(":")
|
||||
self.hosts.append((host, int(port)))
|
||||
|
||||
if self.parsed_uri.username:
|
||||
self.options["username"] = self.parsed_uri.username
|
||||
if self.parsed_uri.password:
|
||||
self.options["password"] = self.parsed_uri.password
|
||||
|
||||
def _expiration_key(self, key: str) -> str:
|
||||
"""
|
||||
Return the expiration key for the given counter key.
|
||||
|
||||
Memcached doesn't natively return the expiration time or TTL for a given key,
|
||||
so we implement the expiration time on a separate key.
|
||||
"""
|
||||
return key + "/expires"
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def base_exceptions(
|
||||
self,
|
||||
) -> type[Exception] | tuple[type[Exception], ...]: ...
|
||||
|
||||
@abstractmethod
|
||||
async def get(self, key: str) -> int: ...
|
||||
|
||||
@abstractmethod
|
||||
async def get_many(self, keys: Iterable[str]) -> dict[bytes, int]: ...
|
||||
|
||||
@abstractmethod
|
||||
async def clear(self, key: str) -> None: ...
|
||||
|
||||
@abstractmethod
|
||||
async def decr(self, key: str, amount: int = 1, noreply: bool = False) -> int: ...
|
||||
|
||||
@abstractmethod
|
||||
async def incr(
|
||||
self,
|
||||
key: str,
|
||||
expiry: float,
|
||||
amount: int = 1,
|
||||
set_expiration_key: bool = True,
|
||||
) -> int: ...
|
||||
|
||||
@abstractmethod
|
||||
async def get_expiry(self, key: str) -> float: ...
|
||||
|
||||
@abstractmethod
|
||||
async def check(self) -> bool: ...
|
@ -0,0 +1,112 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from math import ceil
|
||||
from types import ModuleType
|
||||
|
||||
from limits.typing import TYPE_CHECKING, Iterable
|
||||
|
||||
from .bridge import MemcachedBridge
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import emcache
|
||||
|
||||
|
||||
class EmcacheBridge(MemcachedBridge):
|
||||
def __init__(
|
||||
self,
|
||||
uri: str,
|
||||
dependency: ModuleType,
|
||||
**options: float | str | bool,
|
||||
) -> None:
|
||||
super().__init__(uri, dependency, **options)
|
||||
self._storage = None
|
||||
|
||||
async def get_storage(self) -> emcache.Client:
|
||||
if not self._storage:
|
||||
self._storage = await self.dependency.create_client(
|
||||
[self.dependency.MemcachedHostAddress(h, p) for h, p in self.hosts],
|
||||
**self.options,
|
||||
)
|
||||
assert self._storage
|
||||
return self._storage
|
||||
|
||||
async def get(self, key: str) -> int:
|
||||
item = await (await self.get_storage()).get(key.encode("utf-8"))
|
||||
return item and int(item.value) or 0
|
||||
|
||||
async def get_many(self, keys: Iterable[str]) -> dict[bytes, int]:
|
||||
results = await (await self.get_storage()).get_many(
|
||||
[k.encode("utf-8") for k in keys]
|
||||
)
|
||||
return {k: int(item.value) if item else 0 for k, item in results.items()}
|
||||
|
||||
async def clear(self, key: str) -> None:
|
||||
try:
|
||||
await (await self.get_storage()).delete(key.encode("utf-8"))
|
||||
except self.dependency.NotFoundCommandError:
|
||||
pass
|
||||
|
||||
async def decr(self, key: str, amount: int = 1, noreply: bool = False) -> int:
|
||||
storage = await self.get_storage()
|
||||
limit_key = key.encode("utf-8")
|
||||
try:
|
||||
value = await storage.decrement(limit_key, amount, noreply=noreply) or 0
|
||||
except self.dependency.NotFoundCommandError:
|
||||
value = 0
|
||||
return value
|
||||
|
||||
async def incr(
|
||||
self, key: str, expiry: float, amount: int = 1, set_expiration_key: bool = True
|
||||
) -> int:
|
||||
storage = await self.get_storage()
|
||||
limit_key = key.encode("utf-8")
|
||||
expire_key = self._expiration_key(key).encode()
|
||||
try:
|
||||
return await storage.increment(limit_key, amount) or amount
|
||||
except self.dependency.NotFoundCommandError:
|
||||
storage = await self.get_storage()
|
||||
try:
|
||||
await storage.add(limit_key, f"{amount}".encode(), exptime=ceil(expiry))
|
||||
if set_expiration_key:
|
||||
await storage.set(
|
||||
expire_key,
|
||||
str(expiry + time.time()).encode("utf-8"),
|
||||
exptime=ceil(expiry),
|
||||
noreply=False,
|
||||
)
|
||||
value = amount
|
||||
except self.dependency.NotStoredStorageCommandError:
|
||||
# Coult not add the key, probably because a concurrent call has added it
|
||||
storage = await self.get_storage()
|
||||
value = await storage.increment(limit_key, amount) or amount
|
||||
return value
|
||||
|
||||
async def get_expiry(self, key: str) -> float:
|
||||
storage = await self.get_storage()
|
||||
item = await storage.get(self._expiration_key(key).encode("utf-8"))
|
||||
|
||||
return item and float(item.value) or time.time()
|
||||
pass
|
||||
|
||||
@property
|
||||
def base_exceptions(
|
||||
self,
|
||||
) -> type[Exception] | tuple[type[Exception], ...]: # pragma: no cover
|
||||
return (
|
||||
self.dependency.ClusterNoAvailableNodes,
|
||||
self.dependency.CommandError,
|
||||
)
|
||||
|
||||
async def check(self) -> bool:
|
||||
"""
|
||||
Check if storage is healthy by calling the ``get`` command
|
||||
on the key ``limiter-check``
|
||||
"""
|
||||
try:
|
||||
storage = await self.get_storage()
|
||||
await storage.get(b"limiter-check")
|
||||
|
||||
return True
|
||||
except: # noqa
|
||||
return False
|
@ -0,0 +1,104 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from math import ceil
|
||||
from types import ModuleType
|
||||
from typing import TYPE_CHECKING, Iterable
|
||||
|
||||
from .bridge import MemcachedBridge
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import memcachio
|
||||
|
||||
|
||||
class MemcachioBridge(MemcachedBridge):
|
||||
def __init__(
|
||||
self,
|
||||
uri: str,
|
||||
dependency: ModuleType,
|
||||
**options: float | str | bool,
|
||||
) -> None:
|
||||
super().__init__(uri, dependency, **options)
|
||||
self._storage: memcachio.Client[bytes] | None = None
|
||||
|
||||
@property
|
||||
def base_exceptions(
|
||||
self,
|
||||
) -> type[Exception] | tuple[type[Exception], ...]:
|
||||
return (
|
||||
self.dependency.errors.NoAvailableNodes,
|
||||
self.dependency.errors.MemcachioConnectionError,
|
||||
)
|
||||
|
||||
async def get_storage(self) -> memcachio.Client[bytes]:
|
||||
if not self._storage:
|
||||
self._storage = self.dependency.Client(
|
||||
[(h, p) for h, p in self.hosts],
|
||||
**self.options,
|
||||
)
|
||||
assert self._storage
|
||||
return self._storage
|
||||
|
||||
async def get(self, key: str) -> int:
|
||||
return (await self.get_many([key])).get(key.encode("utf-8"), 0)
|
||||
|
||||
async def get_many(self, keys: Iterable[str]) -> dict[bytes, int]:
|
||||
"""
|
||||
Return multiple counters at once
|
||||
|
||||
:param keys: the keys to get the counter values for
|
||||
"""
|
||||
results = await (await self.get_storage()).get(
|
||||
*[k.encode("utf-8") for k in keys]
|
||||
)
|
||||
return {k: int(v.value) for k, v in results.items()}
|
||||
|
||||
async def clear(self, key: str) -> None:
|
||||
await (await self.get_storage()).delete(key.encode("utf-8"))
|
||||
|
||||
async def decr(self, key: str, amount: int = 1, noreply: bool = False) -> int:
|
||||
storage = await self.get_storage()
|
||||
limit_key = key.encode("utf-8")
|
||||
return await storage.decr(limit_key, amount, noreply=noreply) or 0
|
||||
|
||||
async def incr(
|
||||
self, key: str, expiry: float, amount: int = 1, set_expiration_key: bool = True
|
||||
) -> int:
|
||||
storage = await self.get_storage()
|
||||
limit_key = key.encode("utf-8")
|
||||
expire_key = self._expiration_key(key).encode()
|
||||
if (value := (await storage.incr(limit_key, amount))) is None:
|
||||
storage = await self.get_storage()
|
||||
if await storage.add(limit_key, f"{amount}".encode(), expiry=ceil(expiry)):
|
||||
if set_expiration_key:
|
||||
await storage.set(
|
||||
expire_key,
|
||||
str(expiry + time.time()).encode("utf-8"),
|
||||
expiry=ceil(expiry),
|
||||
noreply=False,
|
||||
)
|
||||
return amount
|
||||
else:
|
||||
storage = await self.get_storage()
|
||||
return await storage.incr(limit_key, amount) or amount
|
||||
return value
|
||||
|
||||
async def get_expiry(self, key: str) -> float:
|
||||
storage = await self.get_storage()
|
||||
expiration_key = self._expiration_key(key).encode("utf-8")
|
||||
item = (await storage.get(expiration_key)).get(expiration_key, None)
|
||||
|
||||
return item and float(item.value) or time.time()
|
||||
|
||||
async def check(self) -> bool:
|
||||
"""
|
||||
Check if storage is healthy by calling the ``get`` command
|
||||
on the key ``limiter-check``
|
||||
"""
|
||||
try:
|
||||
storage = await self.get_storage()
|
||||
await storage.get(b"limiter-check")
|
||||
|
||||
return True
|
||||
except: # noqa
|
||||
return False
|
281
venv/lib/python3.11/site-packages/limits/aio/storage/memory.py
Normal file
281
venv/lib/python3.11/site-packages/limits/aio/storage/memory.py
Normal file
@ -0,0 +1,281 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import bisect
|
||||
import time
|
||||
from collections import Counter, defaultdict
|
||||
from math import floor
|
||||
|
||||
from deprecated.sphinx import versionadded
|
||||
|
||||
import limits.typing
|
||||
from limits.aio.storage.base import (
|
||||
MovingWindowSupport,
|
||||
SlidingWindowCounterSupport,
|
||||
Storage,
|
||||
)
|
||||
from limits.storage.base import TimestampedSlidingWindow
|
||||
|
||||
|
||||
class Entry:
|
||||
def __init__(self, expiry: int) -> None:
|
||||
self.atime = time.time()
|
||||
self.expiry = self.atime + expiry
|
||||
|
||||
|
||||
@versionadded(version="2.1")
|
||||
class MemoryStorage(
|
||||
Storage, MovingWindowSupport, SlidingWindowCounterSupport, TimestampedSlidingWindow
|
||||
):
|
||||
"""
|
||||
rate limit storage using :class:`collections.Counter`
|
||||
as an in memory storage for fixed & sliding window strategies,
|
||||
and a simple list to implement moving window strategy.
|
||||
"""
|
||||
|
||||
STORAGE_SCHEME = ["async+memory"]
|
||||
"""
|
||||
The storage scheme for in process memory storage for use in an
|
||||
async context
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, uri: str | None = None, wrap_exceptions: bool = False, **_: str
|
||||
) -> None:
|
||||
self.storage: limits.typing.Counter[str] = Counter()
|
||||
self.locks: defaultdict[str, asyncio.Lock] = defaultdict(asyncio.Lock)
|
||||
self.expirations: dict[str, float] = {}
|
||||
self.events: dict[str, list[Entry]] = {}
|
||||
self.timer: asyncio.Task[None] | None = None
|
||||
super().__init__(uri, wrap_exceptions=wrap_exceptions, **_)
|
||||
|
||||
def __getstate__(self) -> dict[str, limits.typing.Any]: # type: ignore[explicit-any]
|
||||
state = self.__dict__.copy()
|
||||
del state["timer"]
|
||||
del state["locks"]
|
||||
return state
|
||||
|
||||
def __setstate__(self, state: dict[str, limits.typing.Any]) -> None: # type: ignore[explicit-any]
|
||||
self.__dict__.update(state)
|
||||
self.timer = None
|
||||
self.locks = defaultdict(asyncio.Lock)
|
||||
asyncio.ensure_future(self.__schedule_expiry())
|
||||
|
||||
async def __expire_events(self) -> None:
|
||||
try:
|
||||
now = time.time()
|
||||
for key in list(self.events.keys()):
|
||||
cutoff = await asyncio.to_thread(
|
||||
lambda evts: bisect.bisect_left(
|
||||
evts, -now, key=lambda event: -event.expiry
|
||||
),
|
||||
self.events[key],
|
||||
)
|
||||
async with self.locks[key]:
|
||||
if self.events.get(key, []):
|
||||
self.events[key] = self.events[key][:cutoff]
|
||||
if not self.events.get(key, None):
|
||||
self.events.pop(key, None)
|
||||
self.locks.pop(key, None)
|
||||
|
||||
for key in list(self.expirations.keys()):
|
||||
if self.expirations[key] <= time.time():
|
||||
self.storage.pop(key, None)
|
||||
self.expirations.pop(key, None)
|
||||
self.locks.pop(key, None)
|
||||
except asyncio.CancelledError:
|
||||
return
|
||||
|
||||
async def __schedule_expiry(self) -> None:
|
||||
if not self.timer or self.timer.done():
|
||||
self.timer = asyncio.create_task(self.__expire_events())
|
||||
|
||||
@property
|
||||
def base_exceptions(
|
||||
self,
|
||||
) -> type[Exception] | tuple[type[Exception], ...]: # pragma: no cover
|
||||
return ValueError
|
||||
|
||||
async def incr(self, key: str, expiry: float, amount: int = 1) -> int:
|
||||
"""
|
||||
increments the counter for a given rate limit key
|
||||
|
||||
:param key: the key to increment
|
||||
:param expiry: amount in seconds for the key to expire in
|
||||
:param amount: the number to increment by
|
||||
"""
|
||||
await self.get(key)
|
||||
await self.__schedule_expiry()
|
||||
async with self.locks[key]:
|
||||
self.storage[key] += amount
|
||||
if self.storage[key] == amount:
|
||||
self.expirations[key] = time.time() + expiry
|
||||
return self.storage.get(key, amount)
|
||||
|
||||
async def decr(self, key: str, amount: int = 1) -> int:
|
||||
"""
|
||||
decrements the counter for a given rate limit key. 0 is the minimum allowed value.
|
||||
|
||||
:param amount: the number to increment by
|
||||
"""
|
||||
await self.get(key)
|
||||
await self.__schedule_expiry()
|
||||
async with self.locks[key]:
|
||||
self.storage[key] = max(self.storage[key] - amount, 0)
|
||||
|
||||
return self.storage.get(key, amount)
|
||||
|
||||
async def get(self, key: str) -> int:
|
||||
"""
|
||||
:param key: the key to get the counter value for
|
||||
"""
|
||||
if self.expirations.get(key, 0) <= time.time():
|
||||
self.storage.pop(key, None)
|
||||
self.expirations.pop(key, None)
|
||||
self.locks.pop(key, None)
|
||||
|
||||
return self.storage.get(key, 0)
|
||||
|
||||
async def clear(self, key: str) -> None:
|
||||
"""
|
||||
:param key: the key to clear rate limits for
|
||||
"""
|
||||
self.storage.pop(key, None)
|
||||
self.expirations.pop(key, None)
|
||||
self.events.pop(key, None)
|
||||
self.locks.pop(key, None)
|
||||
|
||||
async def acquire_entry(
|
||||
self, key: str, limit: int, expiry: int, amount: int = 1
|
||||
) -> bool:
|
||||
"""
|
||||
:param key: rate limit key to acquire an entry in
|
||||
:param limit: amount of entries allowed
|
||||
:param expiry: expiry of the entry
|
||||
:param amount: the number of entries to acquire
|
||||
"""
|
||||
if amount > limit:
|
||||
return False
|
||||
|
||||
await self.__schedule_expiry()
|
||||
async with self.locks[key]:
|
||||
self.events.setdefault(key, [])
|
||||
timestamp = time.time()
|
||||
try:
|
||||
entry: Entry | None = self.events[key][limit - amount]
|
||||
except IndexError:
|
||||
entry = None
|
||||
|
||||
if entry and entry.atime >= timestamp - expiry:
|
||||
return False
|
||||
else:
|
||||
self.events[key][:0] = [Entry(expiry)] * amount
|
||||
return True
|
||||
|
||||
async def get_expiry(self, key: str) -> float:
|
||||
"""
|
||||
:param key: the key to get the expiry for
|
||||
"""
|
||||
|
||||
return self.expirations.get(key, time.time())
|
||||
|
||||
async def get_moving_window(
|
||||
self, key: str, limit: int, expiry: int
|
||||
) -> tuple[float, int]:
|
||||
"""
|
||||
returns the starting point and the number of entries in the moving
|
||||
window
|
||||
|
||||
:param key: rate limit key
|
||||
:param expiry: expiry of entry
|
||||
:return: (start of window, number of acquired entries)
|
||||
"""
|
||||
|
||||
timestamp = time.time()
|
||||
if events := self.events.get(key, []):
|
||||
oldest = bisect.bisect_left(
|
||||
events, -(timestamp - expiry), key=lambda entry: -entry.atime
|
||||
)
|
||||
return events[oldest - 1].atime, oldest
|
||||
return timestamp, 0
|
||||
|
||||
async def acquire_sliding_window_entry(
|
||||
self,
|
||||
key: str,
|
||||
limit: int,
|
||||
expiry: int,
|
||||
amount: int = 1,
|
||||
) -> bool:
|
||||
if amount > limit:
|
||||
return False
|
||||
now = time.time()
|
||||
previous_key, current_key = self.sliding_window_keys(key, expiry, now)
|
||||
(
|
||||
previous_count,
|
||||
previous_ttl,
|
||||
current_count,
|
||||
_,
|
||||
) = await self._get_sliding_window_info(previous_key, current_key, expiry, now)
|
||||
weighted_count = previous_count * previous_ttl / expiry + current_count
|
||||
if floor(weighted_count) + amount > limit:
|
||||
return False
|
||||
else:
|
||||
# Hit, increase the current counter.
|
||||
# If the counter doesn't exist yet, set twice the theorical expiry.
|
||||
current_count = await self.incr(current_key, 2 * expiry, amount=amount)
|
||||
weighted_count = previous_count * previous_ttl / expiry + current_count
|
||||
if floor(weighted_count) > limit:
|
||||
# Another hit won the race condition: revert the incrementation and refuse this hit
|
||||
# Limitation: during high concurrency at the end of the window,
|
||||
# the counter is shifted and cannot be decremented, so less requests than expected are allowed.
|
||||
await self.decr(current_key, amount)
|
||||
return False
|
||||
return True
|
||||
|
||||
async def get_sliding_window(
|
||||
self, key: str, expiry: int
|
||||
) -> tuple[int, float, int, float]:
|
||||
now = time.time()
|
||||
previous_key, current_key = self.sliding_window_keys(key, expiry, now)
|
||||
return await self._get_sliding_window_info(
|
||||
previous_key, current_key, expiry, now
|
||||
)
|
||||
|
||||
async def _get_sliding_window_info(
|
||||
self,
|
||||
previous_key: str,
|
||||
current_key: str,
|
||||
expiry: int,
|
||||
now: float,
|
||||
) -> tuple[int, float, int, float]:
|
||||
previous_count = await self.get(previous_key)
|
||||
current_count = await self.get(current_key)
|
||||
if previous_count == 0:
|
||||
previous_ttl = float(0)
|
||||
else:
|
||||
previous_ttl = (1 - (((now - expiry) / expiry) % 1)) * expiry
|
||||
current_ttl = (1 - ((now / expiry) % 1)) * expiry + expiry
|
||||
return previous_count, previous_ttl, current_count, current_ttl
|
||||
|
||||
async def check(self) -> bool:
|
||||
"""
|
||||
check if storage is healthy
|
||||
"""
|
||||
|
||||
return True
|
||||
|
||||
async def reset(self) -> int | None:
|
||||
num_items = max(len(self.storage), len(self.events))
|
||||
self.storage.clear()
|
||||
self.expirations.clear()
|
||||
self.events.clear()
|
||||
self.locks.clear()
|
||||
|
||||
return num_items
|
||||
|
||||
def __del__(self) -> None:
|
||||
try:
|
||||
if self.timer and not self.timer.done():
|
||||
self.timer.cancel()
|
||||
except RuntimeError: # noqa
|
||||
pass
|
517
venv/lib/python3.11/site-packages/limits/aio/storage/mongodb.py
Normal file
517
venv/lib/python3.11/site-packages/limits/aio/storage/mongodb.py
Normal file
@ -0,0 +1,517 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import datetime
|
||||
import time
|
||||
|
||||
from deprecated.sphinx import versionadded, versionchanged
|
||||
|
||||
from limits.aio.storage.base import (
|
||||
MovingWindowSupport,
|
||||
SlidingWindowCounterSupport,
|
||||
Storage,
|
||||
)
|
||||
from limits.typing import (
|
||||
ParamSpec,
|
||||
TypeVar,
|
||||
cast,
|
||||
)
|
||||
from limits.util import get_dependency
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
|
||||
|
||||
@versionadded(version="2.1")
|
||||
@versionchanged(
|
||||
version="3.14.0",
|
||||
reason="Added option to select custom collection names for windows & counters",
|
||||
)
|
||||
class MongoDBStorage(Storage, MovingWindowSupport, SlidingWindowCounterSupport):
|
||||
"""
|
||||
Rate limit storage with MongoDB as backend.
|
||||
|
||||
Depends on :pypi:`motor`
|
||||
"""
|
||||
|
||||
STORAGE_SCHEME = ["async+mongodb", "async+mongodb+srv"]
|
||||
"""
|
||||
The storage scheme for MongoDB for use in an async context
|
||||
"""
|
||||
|
||||
DEPENDENCIES = ["motor.motor_asyncio", "pymongo"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
uri: str,
|
||||
database_name: str = "limits",
|
||||
counter_collection_name: str = "counters",
|
||||
window_collection_name: str = "windows",
|
||||
wrap_exceptions: bool = False,
|
||||
**options: float | str | bool,
|
||||
) -> None:
|
||||
"""
|
||||
:param uri: uri of the form ``async+mongodb://[user:password]@host:port?...``,
|
||||
This uri is passed directly to :class:`~motor.motor_asyncio.AsyncIOMotorClient`
|
||||
:param database_name: The database to use for storing the rate limit
|
||||
collections.
|
||||
:param counter_collection_name: The collection name to use for individual counters
|
||||
used in fixed window strategies
|
||||
:param window_collection_name: The collection name to use for sliding & moving window
|
||||
storage
|
||||
:param wrap_exceptions: Whether to wrap storage exceptions in
|
||||
:exc:`limits.errors.StorageError` before raising it.
|
||||
:param options: all remaining keyword arguments are passed
|
||||
to the constructor of :class:`~motor.motor_asyncio.AsyncIOMotorClient`
|
||||
:raise ConfigurationError: when the :pypi:`motor` or :pypi:`pymongo` are
|
||||
not available
|
||||
"""
|
||||
|
||||
uri = uri.replace("async+mongodb", "mongodb", 1)
|
||||
|
||||
super().__init__(uri, wrap_exceptions=wrap_exceptions, **options)
|
||||
|
||||
self.dependency = self.dependencies["motor.motor_asyncio"]
|
||||
self.proxy_dependency = self.dependencies["pymongo"]
|
||||
self.lib_errors, _ = get_dependency("pymongo.errors")
|
||||
|
||||
self.storage = self.dependency.module.AsyncIOMotorClient(uri, **options)
|
||||
# TODO: Fix this hack. It was noticed when running a benchmark
|
||||
# with FastAPI - however - doesn't appear in unit tests or in an isolated
|
||||
# use. Reference: https://jira.mongodb.org/browse/MOTOR-822
|
||||
self.storage.get_io_loop = asyncio.get_running_loop
|
||||
|
||||
self.__database_name = database_name
|
||||
self.__collection_mapping = {
|
||||
"counters": counter_collection_name,
|
||||
"windows": window_collection_name,
|
||||
}
|
||||
self.__indices_created = False
|
||||
|
||||
@property
|
||||
def base_exceptions(
|
||||
self,
|
||||
) -> type[Exception] | tuple[type[Exception], ...]: # pragma: no cover
|
||||
return self.lib_errors.PyMongoError # type: ignore
|
||||
|
||||
@property
|
||||
def database(self): # type: ignore
|
||||
return self.storage.get_database(self.__database_name)
|
||||
|
||||
async def create_indices(self) -> None:
|
||||
if not self.__indices_created:
|
||||
await asyncio.gather(
|
||||
self.database[self.__collection_mapping["counters"]].create_index(
|
||||
"expireAt", expireAfterSeconds=0
|
||||
),
|
||||
self.database[self.__collection_mapping["windows"]].create_index(
|
||||
"expireAt", expireAfterSeconds=0
|
||||
),
|
||||
)
|
||||
self.__indices_created = True
|
||||
|
||||
async def reset(self) -> int | None:
|
||||
"""
|
||||
Delete all rate limit keys in the rate limit collections (counters, windows)
|
||||
"""
|
||||
num_keys = sum(
|
||||
await asyncio.gather(
|
||||
self.database[self.__collection_mapping["counters"]].count_documents(
|
||||
{}
|
||||
),
|
||||
self.database[self.__collection_mapping["windows"]].count_documents({}),
|
||||
)
|
||||
)
|
||||
await asyncio.gather(
|
||||
self.database[self.__collection_mapping["counters"]].drop(),
|
||||
self.database[self.__collection_mapping["windows"]].drop(),
|
||||
)
|
||||
|
||||
return cast(int, num_keys)
|
||||
|
||||
async def clear(self, key: str) -> None:
|
||||
"""
|
||||
:param key: the key to clear rate limits for
|
||||
"""
|
||||
await asyncio.gather(
|
||||
self.database[self.__collection_mapping["counters"]].find_one_and_delete(
|
||||
{"_id": key}
|
||||
),
|
||||
self.database[self.__collection_mapping["windows"]].find_one_and_delete(
|
||||
{"_id": key}
|
||||
),
|
||||
)
|
||||
|
||||
async def get_expiry(self, key: str) -> float:
|
||||
"""
|
||||
:param key: the key to get the expiry for
|
||||
"""
|
||||
counter = await self.database[self.__collection_mapping["counters"]].find_one(
|
||||
{"_id": key}
|
||||
)
|
||||
return (
|
||||
(counter["expireAt"] if counter else datetime.datetime.now())
|
||||
.replace(tzinfo=datetime.timezone.utc)
|
||||
.timestamp()
|
||||
)
|
||||
|
||||
async def get(self, key: str) -> int:
|
||||
"""
|
||||
:param key: the key to get the counter value for
|
||||
"""
|
||||
counter = await self.database[self.__collection_mapping["counters"]].find_one(
|
||||
{
|
||||
"_id": key,
|
||||
"expireAt": {"$gte": datetime.datetime.now(datetime.timezone.utc)},
|
||||
},
|
||||
projection=["count"],
|
||||
)
|
||||
|
||||
return counter and counter["count"] or 0
|
||||
|
||||
async def incr(self, key: str, expiry: int, amount: int = 1) -> int:
|
||||
"""
|
||||
increments the counter for a given rate limit key
|
||||
|
||||
:param key: the key to increment
|
||||
:param expiry: amount in seconds for the key to expire in
|
||||
:param amount: the number to increment by
|
||||
"""
|
||||
await self.create_indices()
|
||||
|
||||
expiration = datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(
|
||||
seconds=expiry
|
||||
)
|
||||
|
||||
response = await self.database[
|
||||
self.__collection_mapping["counters"]
|
||||
].find_one_and_update(
|
||||
{"_id": key},
|
||||
[
|
||||
{
|
||||
"$set": {
|
||||
"count": {
|
||||
"$cond": {
|
||||
"if": {"$lt": ["$expireAt", "$$NOW"]},
|
||||
"then": amount,
|
||||
"else": {"$add": ["$count", amount]},
|
||||
}
|
||||
},
|
||||
"expireAt": {
|
||||
"$cond": {
|
||||
"if": {"$lt": ["$expireAt", "$$NOW"]},
|
||||
"then": expiration,
|
||||
"else": "$expireAt",
|
||||
}
|
||||
},
|
||||
}
|
||||
},
|
||||
],
|
||||
upsert=True,
|
||||
projection=["count"],
|
||||
return_document=self.proxy_dependency.module.ReturnDocument.AFTER,
|
||||
)
|
||||
|
||||
return int(response["count"])
|
||||
|
||||
async def check(self) -> bool:
|
||||
"""
|
||||
Check if storage is healthy by calling
|
||||
:meth:`motor.motor_asyncio.AsyncIOMotorClient.server_info`
|
||||
"""
|
||||
try:
|
||||
await self.storage.server_info()
|
||||
|
||||
return True
|
||||
except: # noqa: E722
|
||||
return False
|
||||
|
||||
async def get_moving_window(
|
||||
self, key: str, limit: int, expiry: int
|
||||
) -> tuple[float, int]:
|
||||
"""
|
||||
returns the starting point and the number of entries in the moving
|
||||
window
|
||||
|
||||
:param str key: rate limit key
|
||||
:param int expiry: expiry of entry
|
||||
:return: (start of window, number of acquired entries)
|
||||
"""
|
||||
|
||||
timestamp = time.time()
|
||||
if (
|
||||
result := await self.database[self.__collection_mapping["windows"]]
|
||||
.aggregate(
|
||||
[
|
||||
{"$match": {"_id": key}},
|
||||
{
|
||||
"$project": {
|
||||
"filteredEntries": {
|
||||
"$filter": {
|
||||
"input": "$entries",
|
||||
"as": "entry",
|
||||
"cond": {"$gte": ["$$entry", timestamp - expiry]},
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"$project": {
|
||||
"min": {"$min": "$filteredEntries"},
|
||||
"count": {"$size": "$filteredEntries"},
|
||||
}
|
||||
},
|
||||
]
|
||||
)
|
||||
.to_list(length=1)
|
||||
):
|
||||
return result[0]["min"], result[0]["count"]
|
||||
return timestamp, 0
|
||||
|
||||
async def acquire_entry(
|
||||
self, key: str, limit: int, expiry: int, amount: int = 1
|
||||
) -> bool:
|
||||
"""
|
||||
:param key: rate limit key to acquire an entry in
|
||||
:param limit: amount of entries allowed
|
||||
:param expiry: expiry of the entry
|
||||
:param amount: the number of entries to acquire
|
||||
"""
|
||||
await self.create_indices()
|
||||
|
||||
if amount > limit:
|
||||
return False
|
||||
|
||||
timestamp = time.time()
|
||||
try:
|
||||
updates: dict[
|
||||
str,
|
||||
dict[str, datetime.datetime | dict[str, list[float] | int]],
|
||||
] = {
|
||||
"$push": {
|
||||
"entries": {
|
||||
"$each": [timestamp] * amount,
|
||||
"$position": 0,
|
||||
"$slice": limit,
|
||||
}
|
||||
},
|
||||
"$set": {
|
||||
"expireAt": (
|
||||
datetime.datetime.now(datetime.timezone.utc)
|
||||
+ datetime.timedelta(seconds=expiry)
|
||||
)
|
||||
},
|
||||
}
|
||||
|
||||
await self.database[self.__collection_mapping["windows"]].update_one(
|
||||
{
|
||||
"_id": key,
|
||||
f"entries.{limit - amount}": {"$not": {"$gte": timestamp - expiry}},
|
||||
},
|
||||
updates,
|
||||
upsert=True,
|
||||
)
|
||||
|
||||
return True
|
||||
except self.proxy_dependency.module.errors.DuplicateKeyError:
|
||||
return False
|
||||
|
||||
async def acquire_sliding_window_entry(
|
||||
self, key: str, limit: int, expiry: int, amount: int = 1
|
||||
) -> bool:
|
||||
await self.create_indices()
|
||||
expiry_ms = expiry * 1000
|
||||
result = await self.database[
|
||||
self.__collection_mapping["windows"]
|
||||
].find_one_and_update(
|
||||
{"_id": key},
|
||||
[
|
||||
{
|
||||
"$set": {
|
||||
"previousCount": {
|
||||
"$cond": {
|
||||
"if": {
|
||||
"$lte": [
|
||||
{"$subtract": ["$expireAt", "$$NOW"]},
|
||||
expiry_ms,
|
||||
]
|
||||
},
|
||||
"then": {"$ifNull": ["$currentCount", 0]},
|
||||
"else": {"$ifNull": ["$previousCount", 0]},
|
||||
}
|
||||
},
|
||||
}
|
||||
},
|
||||
{
|
||||
"$set": {
|
||||
"currentCount": {
|
||||
"$cond": {
|
||||
"if": {
|
||||
"$lte": [
|
||||
{"$subtract": ["$expireAt", "$$NOW"]},
|
||||
expiry_ms,
|
||||
]
|
||||
},
|
||||
"then": 0,
|
||||
"else": {"$ifNull": ["$currentCount", 0]},
|
||||
}
|
||||
},
|
||||
"expireAt": {
|
||||
"$cond": {
|
||||
"if": {
|
||||
"$lte": [
|
||||
{"$subtract": ["$expireAt", "$$NOW"]},
|
||||
expiry_ms,
|
||||
]
|
||||
},
|
||||
"then": {
|
||||
"$cond": {
|
||||
"if": {"$gt": ["$expireAt", 0]},
|
||||
"then": {"$add": ["$expireAt", expiry_ms]},
|
||||
"else": {"$add": ["$$NOW", 2 * expiry_ms]},
|
||||
}
|
||||
},
|
||||
"else": "$expireAt",
|
||||
}
|
||||
},
|
||||
}
|
||||
},
|
||||
{
|
||||
"$set": {
|
||||
"curWeightedCount": {
|
||||
"$floor": {
|
||||
"$add": [
|
||||
{
|
||||
"$multiply": [
|
||||
"$previousCount",
|
||||
{
|
||||
"$divide": [
|
||||
{
|
||||
"$max": [
|
||||
0,
|
||||
{
|
||||
"$subtract": [
|
||||
"$expireAt",
|
||||
{
|
||||
"$add": [
|
||||
"$$NOW",
|
||||
expiry_ms,
|
||||
]
|
||||
},
|
||||
]
|
||||
},
|
||||
]
|
||||
},
|
||||
expiry_ms,
|
||||
]
|
||||
},
|
||||
]
|
||||
},
|
||||
"$currentCount",
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"$set": {
|
||||
"currentCount": {
|
||||
"$cond": {
|
||||
"if": {
|
||||
"$lte": [
|
||||
{"$add": ["$curWeightedCount", amount]},
|
||||
limit,
|
||||
]
|
||||
},
|
||||
"then": {"$add": ["$currentCount", amount]},
|
||||
"else": "$currentCount",
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"$set": {
|
||||
"_acquired": {
|
||||
"$lte": [{"$add": ["$curWeightedCount", amount]}, limit]
|
||||
}
|
||||
}
|
||||
},
|
||||
{"$unset": ["curWeightedCount"]},
|
||||
],
|
||||
return_document=self.proxy_dependency.module.ReturnDocument.AFTER,
|
||||
upsert=True,
|
||||
)
|
||||
|
||||
return cast(bool, result["_acquired"])
|
||||
|
||||
async def get_sliding_window(
|
||||
self, key: str, expiry: int
|
||||
) -> tuple[int, float, int, float]:
|
||||
expiry_ms = expiry * 1000
|
||||
if result := await self.database[
|
||||
self.__collection_mapping["windows"]
|
||||
].find_one_and_update(
|
||||
{"_id": key},
|
||||
[
|
||||
{
|
||||
"$set": {
|
||||
"previousCount": {
|
||||
"$cond": {
|
||||
"if": {
|
||||
"$lte": [
|
||||
{"$subtract": ["$expireAt", "$$NOW"]},
|
||||
expiry_ms,
|
||||
]
|
||||
},
|
||||
"then": {"$ifNull": ["$currentCount", 0]},
|
||||
"else": {"$ifNull": ["$previousCount", 0]},
|
||||
}
|
||||
},
|
||||
"currentCount": {
|
||||
"$cond": {
|
||||
"if": {
|
||||
"$lte": [
|
||||
{"$subtract": ["$expireAt", "$$NOW"]},
|
||||
expiry_ms,
|
||||
]
|
||||
},
|
||||
"then": 0,
|
||||
"else": {"$ifNull": ["$currentCount", 0]},
|
||||
}
|
||||
},
|
||||
"expireAt": {
|
||||
"$cond": {
|
||||
"if": {
|
||||
"$lte": [
|
||||
{"$subtract": ["$expireAt", "$$NOW"]},
|
||||
expiry_ms,
|
||||
]
|
||||
},
|
||||
"then": {"$add": ["$expireAt", expiry_ms]},
|
||||
"else": "$expireAt",
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
],
|
||||
return_document=self.proxy_dependency.module.ReturnDocument.AFTER,
|
||||
projection=["currentCount", "previousCount", "expireAt"],
|
||||
):
|
||||
expires_at = (
|
||||
(result["expireAt"].replace(tzinfo=datetime.timezone.utc).timestamp())
|
||||
if result.get("expireAt")
|
||||
else time.time()
|
||||
)
|
||||
current_ttl = max(0, expires_at - time.time())
|
||||
prev_ttl = max(0, current_ttl - expiry if result["previousCount"] else 0)
|
||||
|
||||
return (
|
||||
result["previousCount"],
|
||||
prev_ttl,
|
||||
result["currentCount"],
|
||||
current_ttl,
|
||||
)
|
||||
return 0, 0.0, 0, 0.0
|
||||
|
||||
def __del__(self) -> None:
|
||||
self.storage and self.storage.close()
|
@ -0,0 +1,400 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from deprecated.sphinx import versionadded, versionchanged
|
||||
from packaging.version import Version
|
||||
|
||||
from limits.aio.storage import MovingWindowSupport, SlidingWindowCounterSupport, Storage
|
||||
from limits.aio.storage.redis.bridge import RedisBridge
|
||||
from limits.aio.storage.redis.coredis import CoredisBridge
|
||||
from limits.aio.storage.redis.redispy import RedispyBridge
|
||||
from limits.aio.storage.redis.valkey import ValkeyBridge
|
||||
from limits.typing import Literal
|
||||
|
||||
|
||||
@versionadded(version="2.1")
|
||||
@versionchanged(
|
||||
version="4.2",
|
||||
reason=(
|
||||
"Added support for using the asyncio redis client from :pypi:`redis`"
|
||||
" through :paramref:`implementation`"
|
||||
),
|
||||
)
|
||||
@versionchanged(
|
||||
version="4.3",
|
||||
reason=(
|
||||
"Added support for using the asyncio redis client from :pypi:`valkey`"
|
||||
" through :paramref:`implementation` or if :paramref:`uri` has the"
|
||||
" ``async+valkey`` schema"
|
||||
),
|
||||
)
|
||||
class RedisStorage(Storage, MovingWindowSupport, SlidingWindowCounterSupport):
|
||||
"""
|
||||
Rate limit storage with redis as backend.
|
||||
|
||||
Depends on :pypi:`coredis` or :pypi:`redis`
|
||||
"""
|
||||
|
||||
STORAGE_SCHEME = [
|
||||
"async+redis",
|
||||
"async+rediss",
|
||||
"async+redis+unix",
|
||||
"async+valkey",
|
||||
"async+valkeys",
|
||||
"async+valkey+unix",
|
||||
]
|
||||
"""
|
||||
The storage schemes for redis to be used in an async context
|
||||
"""
|
||||
DEPENDENCIES = {
|
||||
"redis": Version("5.2.0"),
|
||||
"coredis": Version("3.4.0"),
|
||||
"valkey": Version("6.0"),
|
||||
}
|
||||
MODE: Literal["BASIC", "CLUSTER", "SENTINEL"] = "BASIC"
|
||||
bridge: RedisBridge
|
||||
storage_exceptions: tuple[Exception, ...]
|
||||
target_server: Literal["redis", "valkey"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
uri: str,
|
||||
wrap_exceptions: bool = False,
|
||||
implementation: Literal["redispy", "coredis", "valkey"] = "coredis",
|
||||
**options: float | str | bool,
|
||||
) -> None:
|
||||
"""
|
||||
:param uri: uri of the form:
|
||||
|
||||
- ``async+redis://[:password]@host:port``
|
||||
- ``async+redis://[:password]@host:port/db``
|
||||
- ``async+rediss://[:password]@host:port``
|
||||
- ``async+redis+unix:///path/to/sock?db=0`` etc...
|
||||
|
||||
This uri is passed directly to :meth:`coredis.Redis.from_url` or
|
||||
:meth:`redis.asyncio.client.Redis.from_url` with the initial ``async`` removed,
|
||||
except for the case of ``async+redis+unix`` where it is replaced with ``unix``.
|
||||
|
||||
If the uri scheme is ``async+valkey`` the implementation used will be from
|
||||
:pypi:`valkey`.
|
||||
:param connection_pool: if provided, the redis client is initialized with
|
||||
the connection pool and any other params passed as :paramref:`options`
|
||||
:param wrap_exceptions: Whether to wrap storage exceptions in
|
||||
:exc:`limits.errors.StorageError` before raising it.
|
||||
:param implementation: Whether to use the client implementation from
|
||||
|
||||
- ``coredis``: :class:`coredis.Redis`
|
||||
- ``redispy``: :class:`redis.asyncio.client.Redis`
|
||||
- ``valkey``: :class:`valkey.asyncio.client.Valkey`
|
||||
|
||||
:param options: all remaining keyword arguments are passed
|
||||
directly to the constructor of :class:`coredis.Redis` or :class:`redis.asyncio.client.Redis`
|
||||
:raise ConfigurationError: when the redis library is not available
|
||||
"""
|
||||
uri = uri.removeprefix("async+")
|
||||
self.target_server = "redis" if uri.startswith("redis") else "valkey"
|
||||
uri = uri.replace(f"{self.target_server}+unix", "unix")
|
||||
|
||||
super().__init__(uri, wrap_exceptions=wrap_exceptions)
|
||||
self.options = options
|
||||
if self.target_server == "valkey" or implementation == "valkey":
|
||||
self.bridge = ValkeyBridge(uri, self.dependencies["valkey"].module)
|
||||
else:
|
||||
if implementation == "redispy":
|
||||
self.bridge = RedispyBridge(uri, self.dependencies["redis"].module)
|
||||
else:
|
||||
self.bridge = CoredisBridge(uri, self.dependencies["coredis"].module)
|
||||
self.configure_bridge()
|
||||
self.bridge.register_scripts()
|
||||
|
||||
def _current_window_key(self, key: str) -> str:
|
||||
"""
|
||||
Return the current window's storage key (Sliding window strategy)
|
||||
|
||||
Contrary to other strategies that have one key per rate limit item,
|
||||
this strategy has two keys per rate limit item than must be on the same machine.
|
||||
To keep the current key and the previous key on the same Redis cluster node,
|
||||
curly braces are added.
|
||||
|
||||
Eg: "{constructed_key}"
|
||||
"""
|
||||
return f"{{{key}}}"
|
||||
|
||||
def _previous_window_key(self, key: str) -> str:
|
||||
"""
|
||||
Return the previous window's storage key (Sliding window strategy).
|
||||
|
||||
Curvy braces are added on the common pattern with the current window's key,
|
||||
so the current and the previous key are stored on the same Redis cluster node.
|
||||
|
||||
Eg: "{constructed_key}/-1"
|
||||
"""
|
||||
return f"{self._current_window_key(key)}/-1"
|
||||
|
||||
def configure_bridge(self) -> None:
|
||||
self.bridge.use_basic(**self.options)
|
||||
|
||||
@property
|
||||
def base_exceptions(
|
||||
self,
|
||||
) -> type[Exception] | tuple[type[Exception], ...]: # pragma: no cover
|
||||
return self.bridge.base_exceptions
|
||||
|
||||
async def incr(self, key: str, expiry: int, amount: int = 1) -> int:
|
||||
"""
|
||||
increments the counter for a given rate limit key
|
||||
|
||||
:param key: the key to increment
|
||||
:param expiry: amount in seconds for the key to expire in
|
||||
:param amount: the number to increment by
|
||||
"""
|
||||
|
||||
return await self.bridge.incr(key, expiry, amount)
|
||||
|
||||
async def get(self, key: str) -> int:
|
||||
"""
|
||||
:param key: the key to get the counter value for
|
||||
"""
|
||||
|
||||
return await self.bridge.get(key)
|
||||
|
||||
async def clear(self, key: str) -> None:
|
||||
"""
|
||||
:param key: the key to clear rate limits for
|
||||
"""
|
||||
|
||||
return await self.bridge.clear(key)
|
||||
|
||||
async def acquire_entry(
|
||||
self, key: str, limit: int, expiry: int, amount: int = 1
|
||||
) -> bool:
|
||||
"""
|
||||
:param key: rate limit key to acquire an entry in
|
||||
:param limit: amount of entries allowed
|
||||
:param expiry: expiry of the entry
|
||||
:param amount: the number of entries to acquire
|
||||
"""
|
||||
|
||||
return await self.bridge.acquire_entry(key, limit, expiry, amount)
|
||||
|
||||
async def get_moving_window(
|
||||
self, key: str, limit: int, expiry: int
|
||||
) -> tuple[float, int]:
|
||||
"""
|
||||
returns the starting point and the number of entries in the moving
|
||||
window
|
||||
|
||||
:param key: rate limit key
|
||||
:param expiry: expiry of entry
|
||||
:return: (previous count, previous TTL, current count, current TTL)
|
||||
"""
|
||||
return await self.bridge.get_moving_window(key, limit, expiry)
|
||||
|
||||
async def acquire_sliding_window_entry(
|
||||
self,
|
||||
key: str,
|
||||
limit: int,
|
||||
expiry: int,
|
||||
amount: int = 1,
|
||||
) -> bool:
|
||||
current_key = self._current_window_key(key)
|
||||
previous_key = self._previous_window_key(key)
|
||||
return await self.bridge.acquire_sliding_window_entry(
|
||||
previous_key, current_key, limit, expiry, amount
|
||||
)
|
||||
|
||||
async def get_sliding_window(
|
||||
self, key: str, expiry: int
|
||||
) -> tuple[int, float, int, float]:
|
||||
previous_key = self._previous_window_key(key)
|
||||
current_key = self._current_window_key(key)
|
||||
return await self.bridge.get_sliding_window(previous_key, current_key, expiry)
|
||||
|
||||
async def get_expiry(self, key: str) -> float:
|
||||
"""
|
||||
:param key: the key to get the expiry for
|
||||
"""
|
||||
|
||||
return await self.bridge.get_expiry(key)
|
||||
|
||||
async def check(self) -> bool:
|
||||
"""
|
||||
Check if storage is healthy by calling ``PING``
|
||||
"""
|
||||
|
||||
return await self.bridge.check()
|
||||
|
||||
async def reset(self) -> int | None:
|
||||
"""
|
||||
This function calls a Lua Script to delete keys prefixed with
|
||||
``self.PREFIX`` in blocks of 5000.
|
||||
|
||||
.. warning:: This operation was designed to be fast, but was not tested
|
||||
on a large production based system. Be careful with its usage as it
|
||||
could be slow on very large data sets.
|
||||
"""
|
||||
|
||||
return await self.bridge.lua_reset()
|
||||
|
||||
|
||||
@versionadded(version="2.1")
|
||||
@versionchanged(
|
||||
version="4.2",
|
||||
reason="Added support for using the asyncio redis client from :pypi:`redis` ",
|
||||
)
|
||||
@versionchanged(
|
||||
version="4.3",
|
||||
reason=(
|
||||
"Added support for using the asyncio redis client from :pypi:`valkey`"
|
||||
" through :paramref:`implementation` or if :paramref:`uri` has the"
|
||||
" ``async+valkey+cluster`` schema"
|
||||
),
|
||||
)
|
||||
class RedisClusterStorage(RedisStorage):
|
||||
"""
|
||||
Rate limit storage with redis cluster as backend
|
||||
|
||||
Depends on :pypi:`coredis` or :pypi:`redis`
|
||||
"""
|
||||
|
||||
STORAGE_SCHEME = ["async+redis+cluster", "async+valkey+cluster"]
|
||||
"""
|
||||
The storage schemes for redis cluster to be used in an async context
|
||||
"""
|
||||
|
||||
MODE = "CLUSTER"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
uri: str,
|
||||
wrap_exceptions: bool = False,
|
||||
implementation: Literal["redispy", "coredis", "valkey"] = "coredis",
|
||||
**options: float | str | bool,
|
||||
) -> None:
|
||||
"""
|
||||
:param uri: url of the form
|
||||
``async+redis+cluster://[:password]@host:port,host:port``
|
||||
|
||||
If the uri scheme is ``async+valkey+cluster`` the implementation used will be from
|
||||
:pypi:`valkey`.
|
||||
:param wrap_exceptions: Whether to wrap storage exceptions in
|
||||
:exc:`limits.errors.StorageError` before raising it.
|
||||
:param implementation: Whether to use the client implementation from
|
||||
|
||||
- ``coredis``: :class:`coredis.RedisCluster`
|
||||
- ``redispy``: :class:`redis.asyncio.cluster.RedisCluster`
|
||||
- ``valkey``: :class:`valkey.asyncio.cluster.ValkeyCluster`
|
||||
:param options: all remaining keyword arguments are passed
|
||||
directly to the constructor of :class:`coredis.RedisCluster` or
|
||||
:class:`redis.asyncio.RedisCluster`
|
||||
:raise ConfigurationError: when the redis library is not
|
||||
available or if the redis host cannot be pinged.
|
||||
"""
|
||||
super().__init__(
|
||||
uri,
|
||||
wrap_exceptions=wrap_exceptions,
|
||||
implementation=implementation,
|
||||
**options,
|
||||
)
|
||||
|
||||
def configure_bridge(self) -> None:
|
||||
self.bridge.use_cluster(**self.options)
|
||||
|
||||
async def reset(self) -> int | None:
|
||||
"""
|
||||
Redis Clusters are sharded and deleting across shards
|
||||
can't be done atomically. Because of this, this reset loops over all
|
||||
keys that are prefixed with ``self.PREFIX`` and calls delete on them,
|
||||
one at a time.
|
||||
|
||||
.. warning:: This operation was not tested with extremely large data sets.
|
||||
On a large production based system, care should be taken with its
|
||||
usage as it could be slow on very large data sets
|
||||
"""
|
||||
|
||||
return await self.bridge.reset()
|
||||
|
||||
|
||||
@versionadded(version="2.1")
|
||||
@versionchanged(
|
||||
version="4.2",
|
||||
reason="Added support for using the asyncio redis client from :pypi:`redis` ",
|
||||
)
|
||||
@versionchanged(
|
||||
version="4.3",
|
||||
reason=(
|
||||
"Added support for using the asyncio redis client from :pypi:`valkey`"
|
||||
" through :paramref:`implementation` or if :paramref:`uri` has the"
|
||||
" ``async+valkey+sentinel`` schema"
|
||||
),
|
||||
)
|
||||
class RedisSentinelStorage(RedisStorage):
|
||||
"""
|
||||
Rate limit storage with redis sentinel as backend
|
||||
|
||||
Depends on :pypi:`coredis` or :pypi:`redis`
|
||||
"""
|
||||
|
||||
STORAGE_SCHEME = [
|
||||
"async+redis+sentinel",
|
||||
"async+valkey+sentinel",
|
||||
]
|
||||
"""The storage scheme for redis accessed via a redis sentinel installation"""
|
||||
|
||||
MODE = "SENTINEL"
|
||||
|
||||
DEPENDENCIES = {
|
||||
"redis": Version("5.2.0"),
|
||||
"coredis": Version("3.4.0"),
|
||||
"coredis.sentinel": Version("3.4.0"),
|
||||
"valkey": Version("6.0"),
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
uri: str,
|
||||
wrap_exceptions: bool = False,
|
||||
implementation: Literal["redispy", "coredis", "valkey"] = "coredis",
|
||||
service_name: str | None = None,
|
||||
use_replicas: bool = True,
|
||||
sentinel_kwargs: dict[str, float | str | bool] | None = None,
|
||||
**options: float | str | bool,
|
||||
):
|
||||
"""
|
||||
:param uri: url of the form
|
||||
``async+redis+sentinel://host:port,host:port/service_name``
|
||||
|
||||
If the uri schema is ``async+valkey+sentinel`` the implementation used will be from
|
||||
:pypi:`valkey`.
|
||||
:param wrap_exceptions: Whether to wrap storage exceptions in
|
||||
:exc:`limits.errors.StorageError` before raising it.
|
||||
:param implementation: Whether to use the client implementation from
|
||||
|
||||
- ``coredis``: :class:`coredis.sentinel.Sentinel`
|
||||
- ``redispy``: :class:`redis.asyncio.sentinel.Sentinel`
|
||||
- ``valkey``: :class:`valkey.asyncio.sentinel.Sentinel`
|
||||
:param service_name: sentinel service name (if not provided in `uri`)
|
||||
:param use_replicas: Whether to use replicas for read only operations
|
||||
:param sentinel_kwargs: optional arguments to pass as
|
||||
`sentinel_kwargs`` to :class:`coredis.sentinel.Sentinel` or
|
||||
:class:`redis.asyncio.Sentinel`
|
||||
:param options: all remaining keyword arguments are passed
|
||||
directly to the constructor of :class:`coredis.sentinel.Sentinel` or
|
||||
:class:`redis.asyncio.sentinel.Sentinel`
|
||||
:raise ConfigurationError: when the redis library is not available
|
||||
or if the redis primary host cannot be pinged.
|
||||
"""
|
||||
|
||||
self.service_name = service_name
|
||||
self.use_replicas = use_replicas
|
||||
self.sentinel_kwargs = sentinel_kwargs
|
||||
super().__init__(
|
||||
uri,
|
||||
wrap_exceptions=wrap_exceptions,
|
||||
implementation=implementation,
|
||||
**options,
|
||||
)
|
||||
|
||||
def configure_bridge(self) -> None:
|
||||
self.bridge.use_sentinel(
|
||||
self.service_name, self.use_replicas, self.sentinel_kwargs, **self.options
|
||||
)
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@ -0,0 +1,119 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import urllib
|
||||
from abc import ABC, abstractmethod
|
||||
from types import ModuleType
|
||||
|
||||
from limits.util import get_package_data
|
||||
|
||||
|
||||
class RedisBridge(ABC):
|
||||
PREFIX = "LIMITS"
|
||||
RES_DIR = "resources/redis/lua_scripts"
|
||||
|
||||
SCRIPT_MOVING_WINDOW = get_package_data(f"{RES_DIR}/moving_window.lua")
|
||||
SCRIPT_ACQUIRE_MOVING_WINDOW = get_package_data(
|
||||
f"{RES_DIR}/acquire_moving_window.lua"
|
||||
)
|
||||
SCRIPT_CLEAR_KEYS = get_package_data(f"{RES_DIR}/clear_keys.lua")
|
||||
SCRIPT_INCR_EXPIRE = get_package_data(f"{RES_DIR}/incr_expire.lua")
|
||||
SCRIPT_SLIDING_WINDOW = get_package_data(f"{RES_DIR}/sliding_window.lua")
|
||||
SCRIPT_ACQUIRE_SLIDING_WINDOW = get_package_data(
|
||||
f"{RES_DIR}/acquire_sliding_window.lua"
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
uri: str,
|
||||
dependency: ModuleType,
|
||||
) -> None:
|
||||
self.uri = uri
|
||||
self.parsed_uri = urllib.parse.urlparse(self.uri)
|
||||
self.dependency = dependency
|
||||
self.parsed_auth = {}
|
||||
if self.parsed_uri.username:
|
||||
self.parsed_auth["username"] = self.parsed_uri.username
|
||||
if self.parsed_uri.password:
|
||||
self.parsed_auth["password"] = self.parsed_uri.password
|
||||
|
||||
def prefixed_key(self, key: str) -> str:
|
||||
return f"{self.PREFIX}:{key}"
|
||||
|
||||
@abstractmethod
|
||||
def register_scripts(self) -> None: ...
|
||||
|
||||
@abstractmethod
|
||||
def use_sentinel(
|
||||
self,
|
||||
service_name: str | None,
|
||||
use_replicas: bool,
|
||||
sentinel_kwargs: dict[str, str | float | bool] | None,
|
||||
**options: str | float | bool,
|
||||
) -> None: ...
|
||||
|
||||
@abstractmethod
|
||||
def use_basic(self, **options: str | float | bool) -> None: ...
|
||||
|
||||
@abstractmethod
|
||||
def use_cluster(self, **options: str | float | bool) -> None: ...
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def base_exceptions(
|
||||
self,
|
||||
) -> type[Exception] | tuple[type[Exception], ...]: ...
|
||||
|
||||
@abstractmethod
|
||||
async def incr(
|
||||
self,
|
||||
key: str,
|
||||
expiry: int,
|
||||
amount: int = 1,
|
||||
) -> int: ...
|
||||
|
||||
@abstractmethod
|
||||
async def get(self, key: str) -> int: ...
|
||||
|
||||
@abstractmethod
|
||||
async def clear(self, key: str) -> None: ...
|
||||
|
||||
@abstractmethod
|
||||
async def get_moving_window(
|
||||
self, key: str, limit: int, expiry: int
|
||||
) -> tuple[float, int]: ...
|
||||
|
||||
@abstractmethod
|
||||
async def get_sliding_window(
|
||||
self, previous_key: str, current_key: str, expiry: int
|
||||
) -> tuple[int, float, int, float]: ...
|
||||
|
||||
@abstractmethod
|
||||
async def acquire_entry(
|
||||
self,
|
||||
key: str,
|
||||
limit: int,
|
||||
expiry: int,
|
||||
amount: int = 1,
|
||||
) -> bool: ...
|
||||
|
||||
@abstractmethod
|
||||
async def acquire_sliding_window_entry(
|
||||
self,
|
||||
previous_key: str,
|
||||
current_key: str,
|
||||
limit: int,
|
||||
expiry: int,
|
||||
amount: int = 1,
|
||||
) -> bool: ...
|
||||
|
||||
@abstractmethod
|
||||
async def get_expiry(self, key: str) -> float: ...
|
||||
|
||||
@abstractmethod
|
||||
async def check(self) -> bool: ...
|
||||
|
||||
@abstractmethod
|
||||
async def reset(self) -> int | None: ...
|
||||
|
||||
@abstractmethod
|
||||
async def lua_reset(self) -> int | None: ...
|
@ -0,0 +1,205 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from typing import TYPE_CHECKING, cast
|
||||
|
||||
from limits.aio.storage.redis.bridge import RedisBridge
|
||||
from limits.errors import ConfigurationError
|
||||
from limits.typing import AsyncCoRedisClient, Callable
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import coredis
|
||||
|
||||
|
||||
class CoredisBridge(RedisBridge):
|
||||
DEFAULT_CLUSTER_OPTIONS: dict[str, float | str | bool] = {
|
||||
"max_connections": 1000,
|
||||
}
|
||||
"Default options passed to :class:`coredis.RedisCluster`"
|
||||
|
||||
@property
|
||||
def base_exceptions(self) -> type[Exception] | tuple[type[Exception], ...]:
|
||||
return (self.dependency.exceptions.RedisError,)
|
||||
|
||||
def use_sentinel(
|
||||
self,
|
||||
service_name: str | None,
|
||||
use_replicas: bool,
|
||||
sentinel_kwargs: dict[str, str | float | bool] | None,
|
||||
**options: str | float | bool,
|
||||
) -> None:
|
||||
sentinel_configuration = []
|
||||
connection_options = options.copy()
|
||||
|
||||
sep = self.parsed_uri.netloc.find("@") + 1
|
||||
|
||||
for loc in self.parsed_uri.netloc[sep:].split(","):
|
||||
host, port = loc.split(":")
|
||||
sentinel_configuration.append((host, int(port)))
|
||||
service_name = (
|
||||
self.parsed_uri.path.replace("/", "")
|
||||
if self.parsed_uri.path
|
||||
else service_name
|
||||
)
|
||||
|
||||
if service_name is None:
|
||||
raise ConfigurationError("'service_name' not provided")
|
||||
|
||||
self.sentinel = self.dependency.sentinel.Sentinel(
|
||||
sentinel_configuration,
|
||||
sentinel_kwargs={**self.parsed_auth, **(sentinel_kwargs or {})},
|
||||
**{**self.parsed_auth, **connection_options},
|
||||
)
|
||||
self.storage = self.sentinel.primary_for(service_name)
|
||||
self.storage_replica = self.sentinel.replica_for(service_name)
|
||||
self.connection_getter = lambda readonly: (
|
||||
self.storage_replica if readonly and use_replicas else self.storage
|
||||
)
|
||||
|
||||
def use_basic(self, **options: str | float | bool) -> None:
|
||||
if connection_pool := options.pop("connection_pool", None):
|
||||
self.storage = self.dependency.Redis(
|
||||
connection_pool=connection_pool, **options
|
||||
)
|
||||
else:
|
||||
self.storage = self.dependency.Redis.from_url(self.uri, **options)
|
||||
|
||||
self.connection_getter = lambda _: self.storage
|
||||
|
||||
def use_cluster(self, **options: str | float | bool) -> None:
|
||||
sep = self.parsed_uri.netloc.find("@") + 1
|
||||
cluster_hosts: list[dict[str, int | str]] = []
|
||||
cluster_hosts.extend(
|
||||
{"host": host, "port": int(port)}
|
||||
for loc in self.parsed_uri.netloc[sep:].split(",")
|
||||
if loc
|
||||
for host, port in [loc.split(":")]
|
||||
)
|
||||
self.storage = self.dependency.RedisCluster(
|
||||
startup_nodes=cluster_hosts,
|
||||
**{**self.DEFAULT_CLUSTER_OPTIONS, **self.parsed_auth, **options},
|
||||
)
|
||||
self.connection_getter = lambda _: self.storage
|
||||
|
||||
lua_moving_window: coredis.commands.Script[bytes]
|
||||
lua_acquire_moving_window: coredis.commands.Script[bytes]
|
||||
lua_sliding_window: coredis.commands.Script[bytes]
|
||||
lua_acquire_sliding_window: coredis.commands.Script[bytes]
|
||||
lua_clear_keys: coredis.commands.Script[bytes]
|
||||
lua_incr_expire: coredis.commands.Script[bytes]
|
||||
connection_getter: Callable[[bool], AsyncCoRedisClient]
|
||||
|
||||
def get_connection(self, readonly: bool = False) -> AsyncCoRedisClient:
|
||||
return self.connection_getter(readonly)
|
||||
|
||||
def register_scripts(self) -> None:
|
||||
self.lua_moving_window = self.get_connection().register_script(
|
||||
self.SCRIPT_MOVING_WINDOW
|
||||
)
|
||||
self.lua_acquire_moving_window = self.get_connection().register_script(
|
||||
self.SCRIPT_ACQUIRE_MOVING_WINDOW
|
||||
)
|
||||
self.lua_clear_keys = self.get_connection().register_script(
|
||||
self.SCRIPT_CLEAR_KEYS
|
||||
)
|
||||
self.lua_incr_expire = self.get_connection().register_script(
|
||||
self.SCRIPT_INCR_EXPIRE
|
||||
)
|
||||
self.lua_sliding_window = self.get_connection().register_script(
|
||||
self.SCRIPT_SLIDING_WINDOW
|
||||
)
|
||||
self.lua_acquire_sliding_window = self.get_connection().register_script(
|
||||
self.SCRIPT_ACQUIRE_SLIDING_WINDOW
|
||||
)
|
||||
|
||||
async def incr(self, key: str, expiry: int, amount: int = 1) -> int:
|
||||
key = self.prefixed_key(key)
|
||||
if (value := await self.get_connection().incrby(key, amount)) == amount:
|
||||
await self.get_connection().expire(key, expiry)
|
||||
return value
|
||||
|
||||
async def get(self, key: str) -> int:
|
||||
key = self.prefixed_key(key)
|
||||
return int(await self.get_connection(readonly=True).get(key) or 0)
|
||||
|
||||
async def clear(self, key: str) -> None:
|
||||
key = self.prefixed_key(key)
|
||||
await self.get_connection().delete([key])
|
||||
|
||||
async def lua_reset(self) -> int | None:
|
||||
return cast(int, await self.lua_clear_keys.execute([self.prefixed_key("*")]))
|
||||
|
||||
async def get_moving_window(
|
||||
self, key: str, limit: int, expiry: int
|
||||
) -> tuple[float, int]:
|
||||
key = self.prefixed_key(key)
|
||||
timestamp = time.time()
|
||||
window = await self.lua_moving_window.execute(
|
||||
[key], [timestamp - expiry, limit]
|
||||
)
|
||||
if window:
|
||||
return float(window[0]), window[1] # type: ignore
|
||||
return timestamp, 0
|
||||
|
||||
async def get_sliding_window(
|
||||
self, previous_key: str, current_key: str, expiry: int
|
||||
) -> tuple[int, float, int, float]:
|
||||
previous_key = self.prefixed_key(previous_key)
|
||||
current_key = self.prefixed_key(current_key)
|
||||
|
||||
if window := await self.lua_sliding_window.execute(
|
||||
[previous_key, current_key], [expiry]
|
||||
):
|
||||
return (
|
||||
int(window[0] or 0), # type: ignore
|
||||
max(0, float(window[1] or 0)) / 1000, # type: ignore
|
||||
int(window[2] or 0), # type: ignore
|
||||
max(0, float(window[3] or 0)) / 1000, # type: ignore
|
||||
)
|
||||
return 0, 0.0, 0, 0.0
|
||||
|
||||
async def acquire_entry(
|
||||
self, key: str, limit: int, expiry: int, amount: int = 1
|
||||
) -> bool:
|
||||
key = self.prefixed_key(key)
|
||||
timestamp = time.time()
|
||||
acquired = await self.lua_acquire_moving_window.execute(
|
||||
[key], [timestamp, limit, expiry, amount]
|
||||
)
|
||||
|
||||
return bool(acquired)
|
||||
|
||||
async def acquire_sliding_window_entry(
|
||||
self,
|
||||
previous_key: str,
|
||||
current_key: str,
|
||||
limit: int,
|
||||
expiry: int,
|
||||
amount: int = 1,
|
||||
) -> bool:
|
||||
previous_key = self.prefixed_key(previous_key)
|
||||
current_key = self.prefixed_key(current_key)
|
||||
acquired = await self.lua_acquire_sliding_window.execute(
|
||||
[previous_key, current_key], [limit, expiry, amount]
|
||||
)
|
||||
return bool(acquired)
|
||||
|
||||
async def get_expiry(self, key: str) -> float:
|
||||
key = self.prefixed_key(key)
|
||||
return max(await self.get_connection().ttl(key), 0) + time.time()
|
||||
|
||||
async def check(self) -> bool:
|
||||
try:
|
||||
await self.get_connection().ping()
|
||||
|
||||
return True
|
||||
except: # noqa
|
||||
return False
|
||||
|
||||
async def reset(self) -> int | None:
|
||||
prefix = self.prefixed_key("*")
|
||||
keys = await self.storage.keys(prefix)
|
||||
count = 0
|
||||
for key in keys:
|
||||
count += await self.storage.delete([key])
|
||||
return count
|
@ -0,0 +1,250 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from typing import TYPE_CHECKING, cast
|
||||
|
||||
from limits.aio.storage.redis.bridge import RedisBridge
|
||||
from limits.errors import ConfigurationError
|
||||
from limits.typing import AsyncRedisClient, Callable
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import redis.commands
|
||||
|
||||
|
||||
class RedispyBridge(RedisBridge):
|
||||
DEFAULT_CLUSTER_OPTIONS: dict[str, float | str | bool] = {
|
||||
"max_connections": 1000,
|
||||
}
|
||||
"Default options passed to :class:`redis.asyncio.RedisCluster`"
|
||||
|
||||
@property
|
||||
def base_exceptions(self) -> type[Exception] | tuple[type[Exception], ...]:
|
||||
return (self.dependency.RedisError,)
|
||||
|
||||
def use_sentinel(
|
||||
self,
|
||||
service_name: str | None,
|
||||
use_replicas: bool,
|
||||
sentinel_kwargs: dict[str, str | float | bool] | None,
|
||||
**options: str | float | bool,
|
||||
) -> None:
|
||||
sentinel_configuration = []
|
||||
|
||||
connection_options = options.copy()
|
||||
|
||||
sep = self.parsed_uri.netloc.find("@") + 1
|
||||
|
||||
for loc in self.parsed_uri.netloc[sep:].split(","):
|
||||
host, port = loc.split(":")
|
||||
sentinel_configuration.append((host, int(port)))
|
||||
service_name = (
|
||||
self.parsed_uri.path.replace("/", "")
|
||||
if self.parsed_uri.path
|
||||
else service_name
|
||||
)
|
||||
|
||||
if service_name is None:
|
||||
raise ConfigurationError("'service_name' not provided")
|
||||
|
||||
self.sentinel = self.dependency.asyncio.Sentinel(
|
||||
sentinel_configuration,
|
||||
sentinel_kwargs={**self.parsed_auth, **(sentinel_kwargs or {})},
|
||||
**{**self.parsed_auth, **connection_options},
|
||||
)
|
||||
self.storage = self.sentinel.master_for(service_name)
|
||||
self.storage_replica = self.sentinel.slave_for(service_name)
|
||||
self.connection_getter = lambda readonly: (
|
||||
self.storage_replica if readonly and use_replicas else self.storage
|
||||
)
|
||||
|
||||
def use_basic(self, **options: str | float | bool) -> None:
|
||||
if connection_pool := options.pop("connection_pool", None):
|
||||
self.storage = self.dependency.asyncio.Redis(
|
||||
connection_pool=connection_pool, **options
|
||||
)
|
||||
else:
|
||||
self.storage = self.dependency.asyncio.Redis.from_url(self.uri, **options)
|
||||
|
||||
self.connection_getter = lambda _: self.storage
|
||||
|
||||
def use_cluster(self, **options: str | float | bool) -> None:
|
||||
sep = self.parsed_uri.netloc.find("@") + 1
|
||||
cluster_hosts = []
|
||||
|
||||
for loc in self.parsed_uri.netloc[sep:].split(","):
|
||||
host, port = loc.split(":")
|
||||
cluster_hosts.append(
|
||||
self.dependency.asyncio.cluster.ClusterNode(host=host, port=int(port))
|
||||
)
|
||||
|
||||
self.storage = self.dependency.asyncio.RedisCluster(
|
||||
startup_nodes=cluster_hosts,
|
||||
**{**self.DEFAULT_CLUSTER_OPTIONS, **self.parsed_auth, **options},
|
||||
)
|
||||
self.connection_getter = lambda _: self.storage
|
||||
|
||||
lua_moving_window: redis.commands.core.Script
|
||||
lua_acquire_moving_window: redis.commands.core.Script
|
||||
lua_sliding_window: redis.commands.core.Script
|
||||
lua_acquire_sliding_window: redis.commands.core.Script
|
||||
lua_clear_keys: redis.commands.core.Script
|
||||
lua_incr_expire: redis.commands.core.Script
|
||||
connection_getter: Callable[[bool], AsyncRedisClient]
|
||||
|
||||
def get_connection(self, readonly: bool = False) -> AsyncRedisClient:
|
||||
return self.connection_getter(readonly)
|
||||
|
||||
def register_scripts(self) -> None:
|
||||
# Redis-py uses a slightly different script registration
|
||||
self.lua_moving_window = self.get_connection().register_script(
|
||||
self.SCRIPT_MOVING_WINDOW
|
||||
)
|
||||
self.lua_acquire_moving_window = self.get_connection().register_script(
|
||||
self.SCRIPT_ACQUIRE_MOVING_WINDOW
|
||||
)
|
||||
self.lua_clear_keys = self.get_connection().register_script(
|
||||
self.SCRIPT_CLEAR_KEYS
|
||||
)
|
||||
self.lua_incr_expire = self.get_connection().register_script(
|
||||
self.SCRIPT_INCR_EXPIRE
|
||||
)
|
||||
self.lua_sliding_window = self.get_connection().register_script(
|
||||
self.SCRIPT_SLIDING_WINDOW
|
||||
)
|
||||
self.lua_acquire_sliding_window = self.get_connection().register_script(
|
||||
self.SCRIPT_ACQUIRE_SLIDING_WINDOW
|
||||
)
|
||||
|
||||
async def incr(
|
||||
self,
|
||||
key: str,
|
||||
expiry: int,
|
||||
amount: int = 1,
|
||||
) -> int:
|
||||
"""
|
||||
increments the counter for a given rate limit key
|
||||
|
||||
|
||||
:param key: the key to increment
|
||||
:param expiry: amount in seconds for the key to expire in
|
||||
:param amount: the number to increment by
|
||||
"""
|
||||
key = self.prefixed_key(key)
|
||||
return cast(int, await self.lua_incr_expire([key], [expiry, amount]))
|
||||
|
||||
async def get(self, key: str) -> int:
|
||||
"""
|
||||
|
||||
:param key: the key to get the counter value for
|
||||
"""
|
||||
|
||||
key = self.prefixed_key(key)
|
||||
return int(await self.get_connection(readonly=True).get(key) or 0)
|
||||
|
||||
async def clear(self, key: str) -> None:
|
||||
"""
|
||||
:param key: the key to clear rate limits for
|
||||
|
||||
"""
|
||||
key = self.prefixed_key(key)
|
||||
await self.get_connection().delete(key)
|
||||
|
||||
async def lua_reset(self) -> int | None:
|
||||
return cast(int, await self.lua_clear_keys([self.prefixed_key("*")]))
|
||||
|
||||
async def get_moving_window(
|
||||
self, key: str, limit: int, expiry: int
|
||||
) -> tuple[float, int]:
|
||||
"""
|
||||
returns the starting point and the number of entries in the moving
|
||||
window
|
||||
|
||||
:param key: rate limit key
|
||||
:param expiry: expiry of entry
|
||||
:return: (previous count, previous TTL, current count, current TTL)
|
||||
"""
|
||||
key = self.prefixed_key(key)
|
||||
timestamp = time.time()
|
||||
window = await self.lua_moving_window([key], [timestamp - expiry, limit])
|
||||
if window:
|
||||
return float(window[0]), window[1]
|
||||
return timestamp, 0
|
||||
|
||||
async def get_sliding_window(
|
||||
self, previous_key: str, current_key: str, expiry: int
|
||||
) -> tuple[int, float, int, float]:
|
||||
if window := await self.lua_sliding_window(
|
||||
[self.prefixed_key(previous_key), self.prefixed_key(current_key)], [expiry]
|
||||
):
|
||||
return (
|
||||
int(window[0] or 0),
|
||||
max(0, float(window[1] or 0)) / 1000,
|
||||
int(window[2] or 0),
|
||||
max(0, float(window[3] or 0)) / 1000,
|
||||
)
|
||||
return 0, 0.0, 0, 0.0
|
||||
|
||||
async def acquire_entry(
|
||||
self,
|
||||
key: str,
|
||||
limit: int,
|
||||
expiry: int,
|
||||
amount: int = 1,
|
||||
) -> bool:
|
||||
"""
|
||||
:param key: rate limit key to acquire an entry in
|
||||
:param limit: amount of entries allowed
|
||||
:param expiry: expiry of the entry
|
||||
|
||||
"""
|
||||
key = self.prefixed_key(key)
|
||||
timestamp = time.time()
|
||||
acquired = await self.lua_acquire_moving_window(
|
||||
[key], [timestamp, limit, expiry, amount]
|
||||
)
|
||||
|
||||
return bool(acquired)
|
||||
|
||||
async def acquire_sliding_window_entry(
|
||||
self,
|
||||
previous_key: str,
|
||||
current_key: str,
|
||||
limit: int,
|
||||
expiry: int,
|
||||
amount: int = 1,
|
||||
) -> bool:
|
||||
previous_key = self.prefixed_key(previous_key)
|
||||
current_key = self.prefixed_key(current_key)
|
||||
acquired = await self.lua_acquire_sliding_window(
|
||||
[previous_key, current_key], [limit, expiry, amount]
|
||||
)
|
||||
return bool(acquired)
|
||||
|
||||
async def get_expiry(self, key: str) -> float:
|
||||
"""
|
||||
:param key: the key to get the expiry for
|
||||
"""
|
||||
|
||||
key = self.prefixed_key(key)
|
||||
return max(await self.get_connection().ttl(key), 0) + time.time()
|
||||
|
||||
async def check(self) -> bool:
|
||||
"""
|
||||
check if storage is healthy
|
||||
"""
|
||||
try:
|
||||
await self.get_connection().ping()
|
||||
|
||||
return True
|
||||
except: # noqa
|
||||
return False
|
||||
|
||||
async def reset(self) -> int | None:
|
||||
prefix = self.prefixed_key("*")
|
||||
keys = await self.storage.keys(
|
||||
prefix, target_nodes=self.dependency.asyncio.cluster.RedisCluster.ALL_NODES
|
||||
)
|
||||
count = 0
|
||||
for key in keys:
|
||||
count += await self.storage.delete(key)
|
||||
return count
|
@ -0,0 +1,9 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from .redispy import RedispyBridge
|
||||
|
||||
|
||||
class ValkeyBridge(RedispyBridge):
|
||||
@property
|
||||
def base_exceptions(self) -> type[Exception] | tuple[type[Exception], ...]:
|
||||
return (self.dependency.ValkeyError,)
|
310
venv/lib/python3.11/site-packages/limits/aio/strategies.py
Normal file
310
venv/lib/python3.11/site-packages/limits/aio/strategies.py
Normal file
@ -0,0 +1,310 @@
|
||||
"""
|
||||
Asynchronous rate limiting strategies
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
from math import floor, inf
|
||||
|
||||
from deprecated.sphinx import versionadded
|
||||
|
||||
from ..limits import RateLimitItem
|
||||
from ..storage import StorageTypes
|
||||
from ..typing import cast
|
||||
from ..util import WindowStats
|
||||
from .storage import MovingWindowSupport, Storage
|
||||
from .storage.base import SlidingWindowCounterSupport
|
||||
|
||||
|
||||
class RateLimiter(ABC):
|
||||
def __init__(self, storage: StorageTypes):
|
||||
assert isinstance(storage, Storage)
|
||||
self.storage: Storage = storage
|
||||
|
||||
@abstractmethod
|
||||
async def hit(self, item: RateLimitItem, *identifiers: str, cost: int = 1) -> bool:
|
||||
"""
|
||||
Consume the rate limit
|
||||
|
||||
:param item: the rate limit item
|
||||
:param identifiers: variable list of strings to uniquely identify the
|
||||
limit
|
||||
:param cost: The cost of this hit, default 1
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def test(self, item: RateLimitItem, *identifiers: str, cost: int = 1) -> bool:
|
||||
"""
|
||||
Check if the rate limit can be consumed
|
||||
|
||||
:param item: the rate limit item
|
||||
:param identifiers: variable list of strings to uniquely identify the
|
||||
limit
|
||||
:param cost: The expected cost to be consumed, default 1
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def get_window_stats(
|
||||
self, item: RateLimitItem, *identifiers: str
|
||||
) -> WindowStats:
|
||||
"""
|
||||
Query the reset time and remaining amount for the limit
|
||||
|
||||
:param item: the rate limit item
|
||||
:param identifiers: variable list of strings to uniquely identify the
|
||||
limit
|
||||
:return: (reset time, remaining))
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
async def clear(self, item: RateLimitItem, *identifiers: str) -> None:
|
||||
return await self.storage.clear(item.key_for(*identifiers))
|
||||
|
||||
|
||||
class MovingWindowRateLimiter(RateLimiter):
|
||||
"""
|
||||
Reference: :ref:`strategies:moving window`
|
||||
"""
|
||||
|
||||
def __init__(self, storage: StorageTypes) -> None:
|
||||
if not (
|
||||
hasattr(storage, "acquire_entry") or hasattr(storage, "get_moving_window")
|
||||
):
|
||||
raise NotImplementedError(
|
||||
"MovingWindowRateLimiting is not implemented for storage "
|
||||
f"of type {storage.__class__}"
|
||||
)
|
||||
super().__init__(storage)
|
||||
|
||||
async def hit(self, item: RateLimitItem, *identifiers: str, cost: int = 1) -> bool:
|
||||
"""
|
||||
Consume the rate limit
|
||||
|
||||
:param item: the rate limit item
|
||||
:param identifiers: variable list of strings to uniquely identify the
|
||||
limit
|
||||
:param cost: The cost of this hit, default 1
|
||||
"""
|
||||
|
||||
return await cast(MovingWindowSupport, self.storage).acquire_entry(
|
||||
item.key_for(*identifiers), item.amount, item.get_expiry(), amount=cost
|
||||
)
|
||||
|
||||
async def test(self, item: RateLimitItem, *identifiers: str, cost: int = 1) -> bool:
|
||||
"""
|
||||
Check if the rate limit can be consumed
|
||||
|
||||
:param item: the rate limit item
|
||||
:param identifiers: variable list of strings to uniquely identify the
|
||||
limit
|
||||
:param cost: The expected cost to be consumed, default 1
|
||||
"""
|
||||
res = await cast(MovingWindowSupport, self.storage).get_moving_window(
|
||||
item.key_for(*identifiers),
|
||||
item.amount,
|
||||
item.get_expiry(),
|
||||
)
|
||||
amount = res[1]
|
||||
|
||||
return amount <= item.amount - cost
|
||||
|
||||
async def get_window_stats(
|
||||
self, item: RateLimitItem, *identifiers: str
|
||||
) -> WindowStats:
|
||||
"""
|
||||
returns the number of requests remaining within this limit.
|
||||
|
||||
:param item: the rate limit item
|
||||
:param identifiers: variable list of strings to uniquely identify the
|
||||
limit
|
||||
:return: (reset time, remaining)
|
||||
"""
|
||||
window_start, window_items = await cast(
|
||||
MovingWindowSupport, self.storage
|
||||
).get_moving_window(item.key_for(*identifiers), item.amount, item.get_expiry())
|
||||
reset = window_start + item.get_expiry()
|
||||
|
||||
return WindowStats(reset, item.amount - window_items)
|
||||
|
||||
|
||||
class FixedWindowRateLimiter(RateLimiter):
|
||||
"""
|
||||
Reference: :ref:`strategies:fixed window`
|
||||
"""
|
||||
|
||||
async def hit(self, item: RateLimitItem, *identifiers: str, cost: int = 1) -> bool:
|
||||
"""
|
||||
Consume the rate limit
|
||||
|
||||
:param item: the rate limit item
|
||||
:param identifiers: variable list of strings to uniquely identify the
|
||||
limit
|
||||
:param cost: The cost of this hit, default 1
|
||||
"""
|
||||
|
||||
return (
|
||||
await self.storage.incr(
|
||||
item.key_for(*identifiers),
|
||||
item.get_expiry(),
|
||||
amount=cost,
|
||||
)
|
||||
<= item.amount
|
||||
)
|
||||
|
||||
async def test(self, item: RateLimitItem, *identifiers: str, cost: int = 1) -> bool:
|
||||
"""
|
||||
Check if the rate limit can be consumed
|
||||
|
||||
:param item: the rate limit item
|
||||
:param identifiers: variable list of strings to uniquely identify the
|
||||
limit
|
||||
:param cost: The expected cost to be consumed, default 1
|
||||
"""
|
||||
|
||||
return (
|
||||
await self.storage.get(item.key_for(*identifiers)) < item.amount - cost + 1
|
||||
)
|
||||
|
||||
async def get_window_stats(
|
||||
self, item: RateLimitItem, *identifiers: str
|
||||
) -> WindowStats:
|
||||
"""
|
||||
Query the reset time and remaining amount for the limit
|
||||
|
||||
:param item: the rate limit item
|
||||
:param identifiers: variable list of strings to uniquely identify the
|
||||
limit
|
||||
:return: reset time, remaining
|
||||
"""
|
||||
remaining = max(
|
||||
0,
|
||||
item.amount - await self.storage.get(item.key_for(*identifiers)),
|
||||
)
|
||||
reset = await self.storage.get_expiry(item.key_for(*identifiers))
|
||||
|
||||
return WindowStats(reset, remaining)
|
||||
|
||||
|
||||
@versionadded(version="4.1")
|
||||
class SlidingWindowCounterRateLimiter(RateLimiter):
|
||||
"""
|
||||
Reference: :ref:`strategies:sliding window counter`
|
||||
"""
|
||||
|
||||
def __init__(self, storage: StorageTypes):
|
||||
if not hasattr(storage, "get_sliding_window") or not hasattr(
|
||||
storage, "acquire_sliding_window_entry"
|
||||
):
|
||||
raise NotImplementedError(
|
||||
"SlidingWindowCounterRateLimiting is not implemented for storage "
|
||||
f"of type {storage.__class__}"
|
||||
)
|
||||
super().__init__(storage)
|
||||
|
||||
def _weighted_count(
|
||||
self,
|
||||
item: RateLimitItem,
|
||||
previous_count: int,
|
||||
previous_expires_in: float,
|
||||
current_count: int,
|
||||
) -> float:
|
||||
"""
|
||||
Return the approximated by weighting the previous window count and adding the current window count.
|
||||
"""
|
||||
return previous_count * previous_expires_in / item.get_expiry() + current_count
|
||||
|
||||
async def hit(self, item: RateLimitItem, *identifiers: str, cost: int = 1) -> bool:
|
||||
"""
|
||||
Consume the rate limit
|
||||
|
||||
:param item: The rate limit item
|
||||
:param identifiers: variable list of strings to uniquely identify this
|
||||
instance of the limit
|
||||
:param cost: The cost of this hit, default 1
|
||||
"""
|
||||
return await cast(
|
||||
SlidingWindowCounterSupport, self.storage
|
||||
).acquire_sliding_window_entry(
|
||||
item.key_for(*identifiers),
|
||||
item.amount,
|
||||
item.get_expiry(),
|
||||
cost,
|
||||
)
|
||||
|
||||
async def test(self, item: RateLimitItem, *identifiers: str, cost: int = 1) -> bool:
|
||||
"""
|
||||
Check if the rate limit can be consumed
|
||||
|
||||
:param item: The rate limit item
|
||||
:param identifiers: variable list of strings to uniquely identify this
|
||||
instance of the limit
|
||||
:param cost: The expected cost to be consumed, default 1
|
||||
"""
|
||||
|
||||
previous_count, previous_expires_in, current_count, _ = await cast(
|
||||
SlidingWindowCounterSupport, self.storage
|
||||
).get_sliding_window(item.key_for(*identifiers), item.get_expiry())
|
||||
|
||||
return (
|
||||
self._weighted_count(
|
||||
item, previous_count, previous_expires_in, current_count
|
||||
)
|
||||
< item.amount - cost + 1
|
||||
)
|
||||
|
||||
async def get_window_stats(
|
||||
self, item: RateLimitItem, *identifiers: str
|
||||
) -> WindowStats:
|
||||
"""
|
||||
Query the reset time and remaining amount for the limit.
|
||||
|
||||
:param item: The rate limit item
|
||||
:param identifiers: variable list of strings to uniquely identify this
|
||||
instance of the limit
|
||||
:return: (reset time, remaining)
|
||||
"""
|
||||
|
||||
(
|
||||
previous_count,
|
||||
previous_expires_in,
|
||||
current_count,
|
||||
current_expires_in,
|
||||
) = await cast(SlidingWindowCounterSupport, self.storage).get_sliding_window(
|
||||
item.key_for(*identifiers), item.get_expiry()
|
||||
)
|
||||
|
||||
remaining = max(
|
||||
0,
|
||||
item.amount
|
||||
- floor(
|
||||
self._weighted_count(
|
||||
item, previous_count, previous_expires_in, current_count
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
now = time.time()
|
||||
|
||||
if not (previous_count or current_count):
|
||||
return WindowStats(now, remaining)
|
||||
|
||||
expiry = item.get_expiry()
|
||||
|
||||
previous_reset_in, current_reset_in = inf, inf
|
||||
if previous_count:
|
||||
previous_reset_in = previous_expires_in % (expiry / previous_count)
|
||||
if current_count:
|
||||
current_reset_in = current_expires_in % expiry
|
||||
|
||||
return WindowStats(now + min(previous_reset_in, current_reset_in), remaining)
|
||||
|
||||
|
||||
STRATEGIES = {
|
||||
"sliding-window-counter": SlidingWindowCounterRateLimiter,
|
||||
"fixed-window": FixedWindowRateLimiter,
|
||||
"moving-window": MovingWindowRateLimiter,
|
||||
}
|
30
venv/lib/python3.11/site-packages/limits/errors.py
Normal file
30
venv/lib/python3.11/site-packages/limits/errors.py
Normal file
@ -0,0 +1,30 @@
|
||||
"""
|
||||
errors and exceptions
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
class ConfigurationError(Exception):
|
||||
"""
|
||||
Error raised when a configuration problem is encountered
|
||||
"""
|
||||
|
||||
|
||||
class ConcurrentUpdateError(Exception):
|
||||
"""
|
||||
Error raised when an update to limit fails due to concurrent
|
||||
updates
|
||||
"""
|
||||
|
||||
def __init__(self, key: str, attempts: int) -> None:
|
||||
super().__init__(f"Unable to update {key} after {attempts} retries")
|
||||
|
||||
|
||||
class StorageError(Exception):
|
||||
"""
|
||||
Error raised when an error is encountered in a storage
|
||||
"""
|
||||
|
||||
def __init__(self, storage_error: Exception) -> None:
|
||||
self.storage_error = storage_error
|
193
venv/lib/python3.11/site-packages/limits/limits.py
Normal file
193
venv/lib/python3.11/site-packages/limits/limits.py
Normal file
@ -0,0 +1,193 @@
|
||||
""" """
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from functools import total_ordering
|
||||
|
||||
from limits.typing import ClassVar, NamedTuple, cast
|
||||
|
||||
|
||||
def safe_string(value: bytes | str | int | float) -> str:
|
||||
"""
|
||||
normalize a byte/str/int or float to a str
|
||||
"""
|
||||
|
||||
if isinstance(value, bytes):
|
||||
return value.decode()
|
||||
|
||||
return str(value)
|
||||
|
||||
|
||||
class Granularity(NamedTuple):
|
||||
seconds: int
|
||||
name: str
|
||||
|
||||
|
||||
TIME_TYPES = dict(
|
||||
day=Granularity(60 * 60 * 24, "day"),
|
||||
month=Granularity(60 * 60 * 24 * 30, "month"),
|
||||
year=Granularity(60 * 60 * 24 * 30 * 12, "year"),
|
||||
hour=Granularity(60 * 60, "hour"),
|
||||
minute=Granularity(60, "minute"),
|
||||
second=Granularity(1, "second"),
|
||||
)
|
||||
|
||||
GRANULARITIES: dict[str, type[RateLimitItem]] = {}
|
||||
|
||||
|
||||
class RateLimitItemMeta(type):
|
||||
def __new__(
|
||||
cls,
|
||||
name: str,
|
||||
parents: tuple[type, ...],
|
||||
dct: dict[str, Granularity | list[str]],
|
||||
) -> RateLimitItemMeta:
|
||||
if "__slots__" not in dct:
|
||||
dct["__slots__"] = []
|
||||
granularity = super().__new__(cls, name, parents, dct)
|
||||
|
||||
if "GRANULARITY" in dct:
|
||||
GRANULARITIES[dct["GRANULARITY"][1]] = cast(
|
||||
type[RateLimitItem], granularity
|
||||
)
|
||||
|
||||
return granularity
|
||||
|
||||
|
||||
# pylint: disable=no-member
|
||||
@total_ordering
|
||||
class RateLimitItem(metaclass=RateLimitItemMeta):
|
||||
"""
|
||||
defines a Rate limited resource which contains the characteristic
|
||||
namespace, amount and granularity multiples of the rate limiting window.
|
||||
|
||||
:param amount: the rate limit amount
|
||||
:param multiples: multiple of the 'per' :attr:`GRANULARITY`
|
||||
(e.g. 'n' per 'm' seconds)
|
||||
:param namespace: category for the specific rate limit
|
||||
"""
|
||||
|
||||
__slots__ = ["namespace", "amount", "multiples"]
|
||||
|
||||
GRANULARITY: ClassVar[Granularity]
|
||||
"""
|
||||
A tuple describing the granularity of this limit as
|
||||
(number of seconds, name)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, amount: int, multiples: int | None = 1, namespace: str = "LIMITER"
|
||||
):
|
||||
self.namespace = namespace
|
||||
self.amount = int(amount)
|
||||
self.multiples = int(multiples or 1)
|
||||
|
||||
@classmethod
|
||||
def check_granularity_string(cls, granularity_string: str) -> bool:
|
||||
"""
|
||||
Checks if this instance matches a *granularity_string*
|
||||
of type ``n per hour``, ``n per minute`` etc,
|
||||
by comparing with :attr:`GRANULARITY`
|
||||
|
||||
"""
|
||||
|
||||
return granularity_string.lower() in cls.GRANULARITY.name
|
||||
|
||||
def get_expiry(self) -> int:
|
||||
"""
|
||||
:return: the duration the limit is enforced for in seconds.
|
||||
"""
|
||||
|
||||
return self.GRANULARITY.seconds * self.multiples
|
||||
|
||||
def key_for(self, *identifiers: bytes | str | int | float) -> str:
|
||||
"""
|
||||
Constructs a key for the current limit and any additional
|
||||
identifiers provided.
|
||||
|
||||
:param identifiers: a list of strings to append to the key
|
||||
:return: a string key identifying this resource with
|
||||
each identifier separated with a '/' delimiter.
|
||||
"""
|
||||
remainder = "/".join(
|
||||
[safe_string(k) for k in identifiers]
|
||||
+ [
|
||||
safe_string(self.amount),
|
||||
safe_string(self.multiples),
|
||||
self.GRANULARITY.name,
|
||||
]
|
||||
)
|
||||
|
||||
return f"{self.namespace}/{remainder}"
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
if isinstance(other, RateLimitItem):
|
||||
return (
|
||||
self.amount == other.amount
|
||||
and self.GRANULARITY == other.GRANULARITY
|
||||
and self.multiples == other.multiples
|
||||
)
|
||||
return False
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"{self.amount} per {self.multiples} {self.GRANULARITY.name}"
|
||||
|
||||
def __lt__(self, other: RateLimitItem) -> bool:
|
||||
return self.GRANULARITY.seconds < other.GRANULARITY.seconds
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return hash((self.namespace, self.amount, self.multiples, self.GRANULARITY))
|
||||
|
||||
|
||||
class RateLimitItemPerYear(RateLimitItem):
|
||||
"""
|
||||
per year rate limited resource.
|
||||
"""
|
||||
|
||||
GRANULARITY = TIME_TYPES["year"]
|
||||
"""A year"""
|
||||
|
||||
|
||||
class RateLimitItemPerMonth(RateLimitItem):
|
||||
"""
|
||||
per month rate limited resource.
|
||||
"""
|
||||
|
||||
GRANULARITY = TIME_TYPES["month"]
|
||||
"""A month"""
|
||||
|
||||
|
||||
class RateLimitItemPerDay(RateLimitItem):
|
||||
"""
|
||||
per day rate limited resource.
|
||||
"""
|
||||
|
||||
GRANULARITY = TIME_TYPES["day"]
|
||||
"""A day"""
|
||||
|
||||
|
||||
class RateLimitItemPerHour(RateLimitItem):
|
||||
"""
|
||||
per hour rate limited resource.
|
||||
"""
|
||||
|
||||
GRANULARITY = TIME_TYPES["hour"]
|
||||
"""An hour"""
|
||||
|
||||
|
||||
class RateLimitItemPerMinute(RateLimitItem):
|
||||
"""
|
||||
per minute rate limited resource.
|
||||
"""
|
||||
|
||||
GRANULARITY = TIME_TYPES["minute"]
|
||||
"""A minute"""
|
||||
|
||||
|
||||
class RateLimitItemPerSecond(RateLimitItem):
|
||||
"""
|
||||
per second rate limited resource.
|
||||
"""
|
||||
|
||||
GRANULARITY = TIME_TYPES["second"]
|
||||
"""A second"""
|
0
venv/lib/python3.11/site-packages/limits/py.typed
Normal file
0
venv/lib/python3.11/site-packages/limits/py.typed
Normal file
@ -0,0 +1,26 @@
|
||||
local timestamp = tonumber(ARGV[1])
|
||||
local limit = tonumber(ARGV[2])
|
||||
local expiry = tonumber(ARGV[3])
|
||||
local amount = tonumber(ARGV[4])
|
||||
|
||||
if amount > limit then
|
||||
return false
|
||||
end
|
||||
|
||||
local entry = redis.call('lindex', KEYS[1], limit - amount)
|
||||
|
||||
if entry and tonumber(entry) >= timestamp - expiry then
|
||||
return false
|
||||
end
|
||||
local entries = {}
|
||||
for i = 1, amount do
|
||||
entries[i] = timestamp
|
||||
end
|
||||
|
||||
for i=1,#entries,5000 do
|
||||
redis.call('lpush', KEYS[1], unpack(entries, i, math.min(i+4999, #entries)))
|
||||
end
|
||||
redis.call('ltrim', KEYS[1], 0, limit - 1)
|
||||
redis.call('expire', KEYS[1], expiry)
|
||||
|
||||
return true
|
@ -0,0 +1,45 @@
|
||||
-- Time is in milliseconds in this script: TTL, expiry...
|
||||
|
||||
local limit = tonumber(ARGV[1])
|
||||
local expiry = tonumber(ARGV[2]) * 1000
|
||||
local amount = tonumber(ARGV[3])
|
||||
|
||||
if amount > limit then
|
||||
return false
|
||||
end
|
||||
|
||||
local current_ttl = tonumber(redis.call('pttl', KEYS[2]))
|
||||
|
||||
if current_ttl > 0 and current_ttl < expiry then
|
||||
-- Current window expired, shift it to the previous window
|
||||
redis.call('rename', KEYS[2], KEYS[1])
|
||||
redis.call('set', KEYS[2], 0, 'PX', current_ttl + expiry)
|
||||
end
|
||||
|
||||
local previous_count = tonumber(redis.call('get', KEYS[1])) or 0
|
||||
local previous_ttl = tonumber(redis.call('pttl', KEYS[1])) or 0
|
||||
local current_count = tonumber(redis.call('get', KEYS[2])) or 0
|
||||
current_ttl = tonumber(redis.call('pttl', KEYS[2])) or 0
|
||||
|
||||
-- If the values don't exist yet, consider the TTL is 0
|
||||
if previous_ttl <= 0 then
|
||||
previous_ttl = 0
|
||||
end
|
||||
if current_ttl <= 0 then
|
||||
current_ttl = 0
|
||||
end
|
||||
local weighted_count = math.floor(previous_count * previous_ttl / expiry) + current_count
|
||||
|
||||
if (weighted_count + amount) > limit then
|
||||
return false
|
||||
end
|
||||
|
||||
-- If the current counter exists, increase its value
|
||||
if redis.call('exists', KEYS[2]) == 1 then
|
||||
redis.call('incrby', KEYS[2], amount)
|
||||
else
|
||||
-- Otherwise, set the value with twice the expiry time
|
||||
redis.call('set', KEYS[2], amount, 'PX', expiry * 2)
|
||||
end
|
||||
|
||||
return true
|
@ -0,0 +1,10 @@
|
||||
local keys = redis.call('keys', KEYS[1])
|
||||
local res = 0
|
||||
|
||||
for i=1,#keys,5000 do
|
||||
res = res + redis.call(
|
||||
'del', unpack(keys, i, math.min(i+4999, #keys))
|
||||
)
|
||||
end
|
||||
|
||||
return res
|
@ -0,0 +1,9 @@
|
||||
local current
|
||||
local amount = tonumber(ARGV[2])
|
||||
current = redis.call("incrby", KEYS[1], amount)
|
||||
|
||||
if tonumber(current) == amount then
|
||||
redis.call("expire", KEYS[1], ARGV[1])
|
||||
end
|
||||
|
||||
return current
|
@ -0,0 +1,30 @@
|
||||
local len = tonumber(ARGV[2])
|
||||
local expiry = tonumber(ARGV[1])
|
||||
|
||||
-- Binary search to find the oldest valid entry in the window
|
||||
local function oldest_entry(high, target)
|
||||
local low = 0
|
||||
local result = nil
|
||||
|
||||
while low <= high do
|
||||
local mid = math.floor((low + high) / 2)
|
||||
local val = tonumber(redis.call('lindex', KEYS[1], mid))
|
||||
|
||||
if val and val >= target then
|
||||
result = mid
|
||||
low = mid + 1
|
||||
else
|
||||
high = mid - 1
|
||||
end
|
||||
end
|
||||
|
||||
return result
|
||||
end
|
||||
|
||||
local index = oldest_entry(len - 1, expiry)
|
||||
|
||||
if index then
|
||||
local count = index + 1
|
||||
local oldest = tonumber(redis.call('lindex', KEYS[1], index))
|
||||
return {tostring(oldest), count}
|
||||
end
|
@ -0,0 +1,17 @@
|
||||
local expiry = tonumber(ARGV[1]) * 1000
|
||||
local previous_count = redis.call('get', KEYS[1])
|
||||
local previous_ttl = redis.call('pttl', KEYS[1])
|
||||
local current_count = redis.call('get', KEYS[2])
|
||||
local current_ttl = redis.call('pttl', KEYS[2])
|
||||
|
||||
if current_ttl > 0 and current_ttl < expiry then
|
||||
-- Current window expired, shift it to the previous window
|
||||
redis.call('rename', KEYS[2], KEYS[1])
|
||||
redis.call('set', KEYS[2], 0, 'PX', current_ttl + expiry)
|
||||
previous_count = redis.call('get', KEYS[1])
|
||||
previous_ttl = redis.call('pttl', KEYS[1])
|
||||
current_count = redis.call('get', KEYS[2])
|
||||
current_ttl = redis.call('pttl', KEYS[2])
|
||||
end
|
||||
|
||||
return {previous_count, previous_ttl, current_count, current_ttl}
|
80
venv/lib/python3.11/site-packages/limits/storage/__init__.py
Normal file
80
venv/lib/python3.11/site-packages/limits/storage/__init__.py
Normal file
@ -0,0 +1,80 @@
|
||||
"""
|
||||
Implementations of storage backends to be used with
|
||||
:class:`limits.strategies.RateLimiter` strategies
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import urllib
|
||||
|
||||
import limits # noqa
|
||||
|
||||
from ..errors import ConfigurationError
|
||||
from ..typing import TypeAlias, cast
|
||||
from .base import MovingWindowSupport, SlidingWindowCounterSupport, Storage
|
||||
from .memcached import MemcachedStorage
|
||||
from .memory import MemoryStorage
|
||||
from .mongodb import MongoDBStorage, MongoDBStorageBase
|
||||
from .redis import RedisStorage
|
||||
from .redis_cluster import RedisClusterStorage
|
||||
from .redis_sentinel import RedisSentinelStorage
|
||||
from .registry import SCHEMES
|
||||
|
||||
StorageTypes: TypeAlias = "Storage | limits.aio.storage.Storage"
|
||||
|
||||
|
||||
def storage_from_string(
|
||||
storage_string: str, **options: float | str | bool
|
||||
) -> StorageTypes:
|
||||
"""
|
||||
Factory function to get an instance of the storage class based
|
||||
on the uri of the storage. In most cases using it should be sufficient
|
||||
instead of directly instantiating the storage classes. for example::
|
||||
|
||||
from limits.storage import storage_from_string
|
||||
|
||||
memory = storage_from_string("memory://")
|
||||
memcached = storage_from_string("memcached://localhost:11211")
|
||||
redis = storage_from_string("redis://localhost:6379")
|
||||
|
||||
The same function can be used to construct the :ref:`storage:async storage`
|
||||
variants, for example::
|
||||
|
||||
from limits.storage import storage_from_string
|
||||
|
||||
memory = storage_from_string("async+memory://")
|
||||
memcached = storage_from_string("async+memcached://localhost:11211")
|
||||
redis = storage_from_string("async+redis://localhost:6379")
|
||||
|
||||
:param storage_string: a string of the form ``scheme://host:port``.
|
||||
More details about supported storage schemes can be found at
|
||||
:ref:`storage:storage scheme`
|
||||
:param options: all remaining keyword arguments are passed to the
|
||||
constructor matched by :paramref:`storage_string`.
|
||||
:raises ConfigurationError: when the :attr:`storage_string` cannot be
|
||||
mapped to a registered :class:`limits.storage.Storage`
|
||||
or :class:`limits.aio.storage.Storage` instance.
|
||||
|
||||
|
||||
"""
|
||||
scheme = urllib.parse.urlparse(storage_string).scheme
|
||||
|
||||
if scheme not in SCHEMES:
|
||||
raise ConfigurationError(f"unknown storage scheme : {storage_string}")
|
||||
|
||||
return cast(StorageTypes, SCHEMES[scheme](storage_string, **options))
|
||||
|
||||
|
||||
__all__ = [
|
||||
"MemcachedStorage",
|
||||
"MemoryStorage",
|
||||
"MongoDBStorage",
|
||||
"MongoDBStorageBase",
|
||||
"MovingWindowSupport",
|
||||
"RedisClusterStorage",
|
||||
"RedisSentinelStorage",
|
||||
"RedisStorage",
|
||||
"SlidingWindowCounterSupport",
|
||||
"Storage",
|
||||
"storage_from_string",
|
||||
]
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
232
venv/lib/python3.11/site-packages/limits/storage/base.py
Normal file
232
venv/lib/python3.11/site-packages/limits/storage/base.py
Normal file
@ -0,0 +1,232 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import functools
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from limits import errors
|
||||
from limits.storage.registry import StorageRegistry
|
||||
from limits.typing import (
|
||||
Any,
|
||||
Callable,
|
||||
P,
|
||||
R,
|
||||
cast,
|
||||
)
|
||||
from limits.util import LazyDependency
|
||||
|
||||
|
||||
def _wrap_errors(
|
||||
fn: Callable[P, R],
|
||||
) -> Callable[P, R]:
|
||||
@functools.wraps(fn)
|
||||
def inner(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
instance = cast(Storage, args[0])
|
||||
try:
|
||||
return fn(*args, **kwargs)
|
||||
except instance.base_exceptions as exc:
|
||||
if instance.wrap_exceptions:
|
||||
raise errors.StorageError(exc) from exc
|
||||
raise
|
||||
|
||||
return inner
|
||||
|
||||
|
||||
class Storage(LazyDependency, metaclass=StorageRegistry):
|
||||
"""
|
||||
Base class to extend when implementing a storage backend.
|
||||
"""
|
||||
|
||||
STORAGE_SCHEME: list[str] | None
|
||||
"""The storage schemes to register against this implementation"""
|
||||
|
||||
def __init_subclass__(cls, **kwargs: Any) -> None: # type: ignore[explicit-any]
|
||||
for method in {
|
||||
"incr",
|
||||
"get",
|
||||
"get_expiry",
|
||||
"check",
|
||||
"reset",
|
||||
"clear",
|
||||
}:
|
||||
setattr(cls, method, _wrap_errors(getattr(cls, method)))
|
||||
super().__init_subclass__(**kwargs)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
uri: str | None = None,
|
||||
wrap_exceptions: bool = False,
|
||||
**options: float | str | bool,
|
||||
):
|
||||
"""
|
||||
:param wrap_exceptions: Whether to wrap storage exceptions in
|
||||
:exc:`limits.errors.StorageError` before raising it.
|
||||
"""
|
||||
|
||||
super().__init__()
|
||||
self.wrap_exceptions = wrap_exceptions
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def base_exceptions(self) -> type[Exception] | tuple[type[Exception], ...]:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def incr(self, key: str, expiry: int, amount: int = 1) -> int:
|
||||
"""
|
||||
increments the counter for a given rate limit key
|
||||
|
||||
:param key: the key to increment
|
||||
:param expiry: amount in seconds for the key to expire in
|
||||
:param amount: the number to increment by
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def get(self, key: str) -> int:
|
||||
"""
|
||||
:param key: the key to get the counter value for
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def get_expiry(self, key: str) -> float:
|
||||
"""
|
||||
:param key: the key to get the expiry for
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def check(self) -> bool:
|
||||
"""
|
||||
check if storage is healthy
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def reset(self) -> int | None:
|
||||
"""
|
||||
reset storage to clear limits
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def clear(self, key: str) -> None:
|
||||
"""
|
||||
resets the rate limit key
|
||||
|
||||
:param key: the key to clear rate limits for
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class MovingWindowSupport(ABC):
|
||||
"""
|
||||
Abstract base class for storages that support
|
||||
the :ref:`strategies:moving window` strategy
|
||||
"""
|
||||
|
||||
def __init_subclass__(cls, **kwargs: Any) -> None: # type: ignore[explicit-any]
|
||||
for method in {
|
||||
"acquire_entry",
|
||||
"get_moving_window",
|
||||
}:
|
||||
setattr(
|
||||
cls,
|
||||
method,
|
||||
_wrap_errors(getattr(cls, method)),
|
||||
)
|
||||
super().__init_subclass__(**kwargs)
|
||||
|
||||
@abstractmethod
|
||||
def acquire_entry(self, key: str, limit: int, expiry: int, amount: int = 1) -> bool:
|
||||
"""
|
||||
:param key: rate limit key to acquire an entry in
|
||||
:param limit: amount of entries allowed
|
||||
:param expiry: expiry of the entry
|
||||
:param amount: the number of entries to acquire
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def get_moving_window(self, key: str, limit: int, expiry: int) -> tuple[float, int]:
|
||||
"""
|
||||
returns the starting point and the number of entries in the moving
|
||||
window
|
||||
|
||||
:param key: rate limit key
|
||||
:param expiry: expiry of entry
|
||||
:return: (start of window, number of acquired entries)
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class SlidingWindowCounterSupport(ABC):
|
||||
"""
|
||||
Abstract base class for storages that support
|
||||
the :ref:`strategies:sliding window counter` strategy.
|
||||
"""
|
||||
|
||||
def __init_subclass__(cls, **kwargs: Any) -> None: # type: ignore[explicit-any]
|
||||
for method in {"acquire_sliding_window_entry", "get_sliding_window"}:
|
||||
setattr(
|
||||
cls,
|
||||
method,
|
||||
_wrap_errors(getattr(cls, method)),
|
||||
)
|
||||
super().__init_subclass__(**kwargs)
|
||||
|
||||
@abstractmethod
|
||||
def acquire_sliding_window_entry(
|
||||
self, key: str, limit: int, expiry: int, amount: int = 1
|
||||
) -> bool:
|
||||
"""
|
||||
Acquire an entry if the weighted count of the current and previous
|
||||
windows is less than or equal to the limit
|
||||
|
||||
:param key: rate limit key to acquire an entry in
|
||||
:param limit: amount of entries allowed
|
||||
:param expiry: expiry of the entry
|
||||
:param amount: the number of entries to acquire
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def get_sliding_window(
|
||||
self, key: str, expiry: int
|
||||
) -> tuple[int, float, int, float]:
|
||||
"""
|
||||
Return the previous and current window information.
|
||||
|
||||
:param key: the rate limit key
|
||||
:param expiry: the rate limit expiry, needed to compute the key in some implementations
|
||||
:return: a tuple of (int, float, int, float) with the following information:
|
||||
- previous window counter
|
||||
- previous window TTL
|
||||
- current window counter
|
||||
- current window TTL
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class TimestampedSlidingWindow:
|
||||
"""Helper class for storage that support the sliding window counter, with timestamp based keys."""
|
||||
|
||||
@classmethod
|
||||
def sliding_window_keys(cls, key: str, expiry: int, at: float) -> tuple[str, str]:
|
||||
"""
|
||||
returns the previous and the current window's keys.
|
||||
|
||||
:param key: the key to get the window's keys from
|
||||
:param expiry: the expiry of the limit item, in seconds
|
||||
:param at: the timestamp to get the keys from. Default to now, ie ``time.time()``
|
||||
|
||||
Returns a tuple with the previous and the current key: (previous, current).
|
||||
|
||||
Example:
|
||||
- key = "mykey"
|
||||
- expiry = 60
|
||||
- at = 1738576292.6631825
|
||||
|
||||
The return value will be the tuple ``("mykey/28976271", "mykey/28976270")``.
|
||||
"""
|
||||
return f"{key}/{int((at - expiry) / expiry)}", f"{key}/{int(at / expiry)}"
|
299
venv/lib/python3.11/site-packages/limits/storage/memcached.py
Normal file
299
venv/lib/python3.11/site-packages/limits/storage/memcached.py
Normal file
@ -0,0 +1,299 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
import threading
|
||||
import time
|
||||
import urllib.parse
|
||||
from collections.abc import Iterable
|
||||
from math import ceil, floor
|
||||
from types import ModuleType
|
||||
|
||||
from limits.errors import ConfigurationError
|
||||
from limits.storage.base import (
|
||||
SlidingWindowCounterSupport,
|
||||
Storage,
|
||||
TimestampedSlidingWindow,
|
||||
)
|
||||
from limits.typing import (
|
||||
Any,
|
||||
Callable,
|
||||
MemcachedClientP,
|
||||
P,
|
||||
R,
|
||||
cast,
|
||||
)
|
||||
from limits.util import get_dependency
|
||||
|
||||
|
||||
class MemcachedStorage(Storage, SlidingWindowCounterSupport, TimestampedSlidingWindow):
|
||||
"""
|
||||
Rate limit storage with memcached as backend.
|
||||
|
||||
Depends on :pypi:`pymemcache`.
|
||||
"""
|
||||
|
||||
STORAGE_SCHEME = ["memcached"]
|
||||
"""The storage scheme for memcached"""
|
||||
DEPENDENCIES = ["pymemcache"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
uri: str,
|
||||
wrap_exceptions: bool = False,
|
||||
**options: str | Callable[[], MemcachedClientP],
|
||||
) -> None:
|
||||
"""
|
||||
:param uri: memcached location of the form
|
||||
``memcached://host:port,host:port``,
|
||||
``memcached:///var/tmp/path/to/sock``
|
||||
:param wrap_exceptions: Whether to wrap storage exceptions in
|
||||
:exc:`limits.errors.StorageError` before raising it.
|
||||
:param options: all remaining keyword arguments are passed
|
||||
directly to the constructor of :class:`pymemcache.client.base.PooledClient`
|
||||
or :class:`pymemcache.client.hash.HashClient` (if there are more than
|
||||
one hosts specified)
|
||||
:raise ConfigurationError: when :pypi:`pymemcache` is not available
|
||||
"""
|
||||
parsed = urllib.parse.urlparse(uri)
|
||||
self.hosts = []
|
||||
|
||||
for loc in parsed.netloc.strip().split(","):
|
||||
if not loc:
|
||||
continue
|
||||
host, port = loc.split(":")
|
||||
self.hosts.append((host, int(port)))
|
||||
else:
|
||||
# filesystem path to UDS
|
||||
|
||||
if parsed.path and not parsed.netloc and not parsed.port:
|
||||
self.hosts = [parsed.path] # type: ignore
|
||||
|
||||
self.dependency = self.dependencies["pymemcache"].module
|
||||
self.library = str(options.pop("library", "pymemcache.client"))
|
||||
self.cluster_library = str(
|
||||
options.pop("cluster_library", "pymemcache.client.hash")
|
||||
)
|
||||
self.client_getter = cast(
|
||||
Callable[[ModuleType, list[tuple[str, int]]], MemcachedClientP],
|
||||
options.pop("client_getter", self.get_client),
|
||||
)
|
||||
self.options = options
|
||||
|
||||
if not get_dependency(self.library):
|
||||
raise ConfigurationError(
|
||||
f"memcached prerequisite not available. please install {self.library}"
|
||||
) # pragma: no cover
|
||||
self.local_storage = threading.local()
|
||||
self.local_storage.storage = None
|
||||
super().__init__(uri, wrap_exceptions=wrap_exceptions)
|
||||
|
||||
@property
|
||||
def base_exceptions(
|
||||
self,
|
||||
) -> type[Exception] | tuple[type[Exception], ...]: # pragma: no cover
|
||||
return self.dependency.MemcacheError # type: ignore[no-any-return]
|
||||
|
||||
def get_client(
|
||||
self, module: ModuleType, hosts: list[tuple[str, int]], **kwargs: str
|
||||
) -> MemcachedClientP:
|
||||
"""
|
||||
returns a memcached client.
|
||||
|
||||
:param module: the memcached module
|
||||
:param hosts: list of memcached hosts
|
||||
"""
|
||||
|
||||
return cast(
|
||||
MemcachedClientP,
|
||||
(
|
||||
module.HashClient(hosts, **kwargs)
|
||||
if len(hosts) > 1
|
||||
else module.PooledClient(*hosts, **kwargs)
|
||||
),
|
||||
)
|
||||
|
||||
def call_memcached_func(
|
||||
self, func: Callable[P, R], *args: P.args, **kwargs: P.kwargs
|
||||
) -> R:
|
||||
if "noreply" in kwargs:
|
||||
argspec = inspect.getfullargspec(func)
|
||||
|
||||
if not ("noreply" in argspec.args or argspec.varkw):
|
||||
kwargs.pop("noreply")
|
||||
|
||||
return func(*args, **kwargs)
|
||||
|
||||
@property
|
||||
def storage(self) -> MemcachedClientP:
|
||||
"""
|
||||
lazily creates a memcached client instance using a thread local
|
||||
"""
|
||||
|
||||
if not (hasattr(self.local_storage, "storage") and self.local_storage.storage):
|
||||
dependency = get_dependency(
|
||||
self.cluster_library if len(self.hosts) > 1 else self.library
|
||||
)[0]
|
||||
|
||||
if not dependency:
|
||||
raise ConfigurationError(f"Unable to import {self.cluster_library}")
|
||||
self.local_storage.storage = self.client_getter(
|
||||
dependency, self.hosts, **self.options
|
||||
)
|
||||
|
||||
return cast(MemcachedClientP, self.local_storage.storage)
|
||||
|
||||
def get(self, key: str) -> int:
|
||||
"""
|
||||
:param key: the key to get the counter value for
|
||||
"""
|
||||
return int(self.storage.get(key, "0"))
|
||||
|
||||
def get_many(self, keys: Iterable[str]) -> dict[str, Any]: # type:ignore[explicit-any]
|
||||
"""
|
||||
Return multiple counters at once
|
||||
|
||||
:param keys: the keys to get the counter values for
|
||||
|
||||
:meta private:
|
||||
"""
|
||||
return self.storage.get_many(keys)
|
||||
|
||||
def clear(self, key: str) -> None:
|
||||
"""
|
||||
:param key: the key to clear rate limits for
|
||||
"""
|
||||
self.storage.delete(key)
|
||||
|
||||
def incr(
|
||||
self,
|
||||
key: str,
|
||||
expiry: float,
|
||||
amount: int = 1,
|
||||
set_expiration_key: bool = True,
|
||||
) -> int:
|
||||
"""
|
||||
increments the counter for a given rate limit key
|
||||
|
||||
:param key: the key to increment
|
||||
:param expiry: amount in seconds for the key to expire in
|
||||
window every hit.
|
||||
:param amount: the number to increment by
|
||||
:param set_expiration_key: set the expiration key with the expiration time if needed. If set to False, the key will still expire, but memcached cannot provide the expiration time.
|
||||
"""
|
||||
if (
|
||||
value := self.call_memcached_func(
|
||||
self.storage.incr, key, amount, noreply=False
|
||||
)
|
||||
) is not None:
|
||||
return value
|
||||
else:
|
||||
if not self.call_memcached_func(
|
||||
self.storage.add, key, amount, ceil(expiry), noreply=False
|
||||
):
|
||||
return self.storage.incr(key, amount) or amount
|
||||
else:
|
||||
if set_expiration_key:
|
||||
self.call_memcached_func(
|
||||
self.storage.set,
|
||||
self._expiration_key(key),
|
||||
expiry + time.time(),
|
||||
expire=ceil(expiry),
|
||||
noreply=False,
|
||||
)
|
||||
|
||||
return amount
|
||||
|
||||
def get_expiry(self, key: str) -> float:
|
||||
"""
|
||||
:param key: the key to get the expiry for
|
||||
"""
|
||||
|
||||
return float(self.storage.get(self._expiration_key(key)) or time.time())
|
||||
|
||||
def _expiration_key(self, key: str) -> str:
|
||||
"""
|
||||
Return the expiration key for the given counter key.
|
||||
|
||||
Memcached doesn't natively return the expiration time or TTL for a given key,
|
||||
so we implement the expiration time on a separate key.
|
||||
"""
|
||||
return key + "/expires"
|
||||
|
||||
def check(self) -> bool:
|
||||
"""
|
||||
Check if storage is healthy by calling the ``get`` command
|
||||
on the key ``limiter-check``
|
||||
"""
|
||||
try:
|
||||
self.call_memcached_func(self.storage.get, "limiter-check")
|
||||
|
||||
return True
|
||||
except: # noqa
|
||||
return False
|
||||
|
||||
def reset(self) -> int | None:
|
||||
raise NotImplementedError
|
||||
|
||||
def acquire_sliding_window_entry(
|
||||
self,
|
||||
key: str,
|
||||
limit: int,
|
||||
expiry: int,
|
||||
amount: int = 1,
|
||||
) -> bool:
|
||||
if amount > limit:
|
||||
return False
|
||||
now = time.time()
|
||||
previous_key, current_key = self.sliding_window_keys(key, expiry, now)
|
||||
previous_count, previous_ttl, current_count, _ = self._get_sliding_window_info(
|
||||
previous_key, current_key, expiry, now=now
|
||||
)
|
||||
weighted_count = previous_count * previous_ttl / expiry + current_count
|
||||
if floor(weighted_count) + amount > limit:
|
||||
return False
|
||||
else:
|
||||
# Hit, increase the current counter.
|
||||
# If the counter doesn't exist yet, set twice the theorical expiry.
|
||||
# We don't need the expiration key as it is estimated with the timestamps directly.
|
||||
current_count = self.incr(
|
||||
current_key, 2 * expiry, amount=amount, set_expiration_key=False
|
||||
)
|
||||
actualised_previous_ttl = min(0, previous_ttl - (time.time() - now))
|
||||
weighted_count = (
|
||||
previous_count * actualised_previous_ttl / expiry + current_count
|
||||
)
|
||||
if floor(weighted_count) > limit:
|
||||
# Another hit won the race condition: revert the incrementation and refuse this hit
|
||||
# Limitation: during high concurrency at the end of the window,
|
||||
# the counter is shifted and cannot be decremented, so less requests than expected are allowed.
|
||||
self.call_memcached_func(
|
||||
self.storage.decr,
|
||||
current_key,
|
||||
amount,
|
||||
noreply=True,
|
||||
)
|
||||
return False
|
||||
return True
|
||||
|
||||
def get_sliding_window(
|
||||
self, key: str, expiry: int
|
||||
) -> tuple[int, float, int, float]:
|
||||
now = time.time()
|
||||
previous_key, current_key = self.sliding_window_keys(key, expiry, now)
|
||||
return self._get_sliding_window_info(previous_key, current_key, expiry, now)
|
||||
|
||||
def _get_sliding_window_info(
|
||||
self, previous_key: str, current_key: str, expiry: int, now: float
|
||||
) -> tuple[int, float, int, float]:
|
||||
result = self.get_many([previous_key, current_key])
|
||||
previous_count, current_count = (
|
||||
int(result.get(previous_key, 0)),
|
||||
int(result.get(current_key, 0)),
|
||||
)
|
||||
|
||||
if previous_count == 0:
|
||||
previous_ttl = float(0)
|
||||
else:
|
||||
previous_ttl = (1 - (((now - expiry) / expiry) % 1)) * expiry
|
||||
current_ttl = (1 - ((now / expiry) % 1)) * expiry + expiry
|
||||
return previous_count, previous_ttl, current_count, current_ttl
|
253
venv/lib/python3.11/site-packages/limits/storage/memory.py
Normal file
253
venv/lib/python3.11/site-packages/limits/storage/memory.py
Normal file
@ -0,0 +1,253 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import bisect
|
||||
import threading
|
||||
import time
|
||||
from collections import Counter, defaultdict
|
||||
from math import floor
|
||||
|
||||
import limits.typing
|
||||
from limits.storage.base import (
|
||||
MovingWindowSupport,
|
||||
SlidingWindowCounterSupport,
|
||||
Storage,
|
||||
TimestampedSlidingWindow,
|
||||
)
|
||||
|
||||
|
||||
class Entry:
|
||||
def __init__(self, expiry: float) -> None:
|
||||
self.atime = time.time()
|
||||
self.expiry = self.atime + expiry
|
||||
|
||||
|
||||
class MemoryStorage(
|
||||
Storage, MovingWindowSupport, SlidingWindowCounterSupport, TimestampedSlidingWindow
|
||||
):
|
||||
"""
|
||||
rate limit storage using :class:`collections.Counter`
|
||||
as an in memory storage for fixed and sliding window strategies,
|
||||
and a simple list to implement moving window strategy.
|
||||
|
||||
"""
|
||||
|
||||
STORAGE_SCHEME = ["memory"]
|
||||
|
||||
def __init__(self, uri: str | None = None, wrap_exceptions: bool = False, **_: str):
|
||||
self.storage: limits.typing.Counter[str] = Counter()
|
||||
self.locks: defaultdict[str, threading.RLock] = defaultdict(threading.RLock)
|
||||
self.expirations: dict[str, float] = {}
|
||||
self.events: dict[str, list[Entry]] = {}
|
||||
self.timer: threading.Timer = threading.Timer(0.01, self.__expire_events)
|
||||
self.timer.start()
|
||||
super().__init__(uri, wrap_exceptions=wrap_exceptions, **_)
|
||||
|
||||
def __getstate__(self) -> dict[str, limits.typing.Any]: # type: ignore[explicit-any]
|
||||
state = self.__dict__.copy()
|
||||
del state["timer"]
|
||||
del state["locks"]
|
||||
return state
|
||||
|
||||
def __setstate__(self, state: dict[str, limits.typing.Any]) -> None: # type: ignore[explicit-any]
|
||||
self.__dict__.update(state)
|
||||
self.locks = defaultdict(threading.RLock)
|
||||
self.timer = threading.Timer(0.01, self.__expire_events)
|
||||
self.timer.start()
|
||||
|
||||
def __expire_events(self) -> None:
|
||||
for key in list(self.events.keys()):
|
||||
with self.locks[key]:
|
||||
if events := self.events.get(key, []):
|
||||
oldest = bisect.bisect_left(
|
||||
events, -time.time(), key=lambda event: -event.expiry
|
||||
)
|
||||
self.events[key] = self.events[key][:oldest]
|
||||
if not self.events.get(key, None):
|
||||
self.locks.pop(key, None)
|
||||
for key in list(self.expirations.keys()):
|
||||
if self.expirations[key] <= time.time():
|
||||
self.storage.pop(key, None)
|
||||
self.expirations.pop(key, None)
|
||||
self.locks.pop(key, None)
|
||||
|
||||
def __schedule_expiry(self) -> None:
|
||||
if not self.timer.is_alive():
|
||||
self.timer = threading.Timer(0.01, self.__expire_events)
|
||||
self.timer.start()
|
||||
|
||||
@property
|
||||
def base_exceptions(
|
||||
self,
|
||||
) -> type[Exception] | tuple[type[Exception], ...]: # pragma: no cover
|
||||
return ValueError
|
||||
|
||||
def incr(self, key: str, expiry: float, amount: int = 1) -> int:
|
||||
"""
|
||||
increments the counter for a given rate limit key
|
||||
|
||||
:param key: the key to increment
|
||||
:param expiry: amount in seconds for the key to expire in
|
||||
:param amount: the number to increment by
|
||||
"""
|
||||
self.get(key)
|
||||
self.__schedule_expiry()
|
||||
with self.locks[key]:
|
||||
self.storage[key] += amount
|
||||
if self.storage[key] == amount:
|
||||
self.expirations[key] = time.time() + expiry
|
||||
return self.storage.get(key, 0)
|
||||
|
||||
def decr(self, key: str, amount: int = 1) -> int:
|
||||
"""
|
||||
decrements the counter for a given rate limit key
|
||||
|
||||
:param key: the key to decrement
|
||||
:param amount: the number to decrement by
|
||||
"""
|
||||
self.get(key)
|
||||
self.__schedule_expiry()
|
||||
with self.locks[key]:
|
||||
self.storage[key] = max(self.storage[key] - amount, 0)
|
||||
|
||||
return self.storage.get(key, 0)
|
||||
|
||||
def get(self, key: str) -> int:
|
||||
"""
|
||||
:param key: the key to get the counter value for
|
||||
"""
|
||||
|
||||
if self.expirations.get(key, 0) <= time.time():
|
||||
self.storage.pop(key, None)
|
||||
self.expirations.pop(key, None)
|
||||
self.locks.pop(key, None)
|
||||
|
||||
return self.storage.get(key, 0)
|
||||
|
||||
def clear(self, key: str) -> None:
|
||||
"""
|
||||
:param key: the key to clear rate limits for
|
||||
"""
|
||||
self.storage.pop(key, None)
|
||||
self.expirations.pop(key, None)
|
||||
self.events.pop(key, None)
|
||||
self.locks.pop(key, None)
|
||||
|
||||
def acquire_entry(self, key: str, limit: int, expiry: int, amount: int = 1) -> bool:
|
||||
"""
|
||||
:param key: rate limit key to acquire an entry in
|
||||
:param limit: amount of entries allowed
|
||||
:param expiry: expiry of the entry
|
||||
:param amount: the number of entries to acquire
|
||||
"""
|
||||
if amount > limit:
|
||||
return False
|
||||
|
||||
self.__schedule_expiry()
|
||||
with self.locks[key]:
|
||||
self.events.setdefault(key, [])
|
||||
timestamp = time.time()
|
||||
try:
|
||||
entry = self.events[key][limit - amount]
|
||||
except IndexError:
|
||||
entry = None
|
||||
|
||||
if entry and entry.atime >= timestamp - expiry:
|
||||
return False
|
||||
else:
|
||||
self.events[key][:0] = [Entry(expiry)] * amount
|
||||
return True
|
||||
|
||||
def get_expiry(self, key: str) -> float:
|
||||
"""
|
||||
:param key: the key to get the expiry for
|
||||
"""
|
||||
|
||||
return self.expirations.get(key, time.time())
|
||||
|
||||
def get_moving_window(self, key: str, limit: int, expiry: int) -> tuple[float, int]:
|
||||
"""
|
||||
returns the starting point and the number of entries in the moving
|
||||
window
|
||||
|
||||
:param key: rate limit key
|
||||
:param expiry: expiry of entry
|
||||
:return: (start of window, number of acquired entries)
|
||||
"""
|
||||
timestamp = time.time()
|
||||
if events := self.events.get(key, []):
|
||||
oldest = bisect.bisect_left(
|
||||
events, -(timestamp - expiry), key=lambda entry: -entry.atime
|
||||
)
|
||||
return events[oldest - 1].atime, oldest
|
||||
return timestamp, 0
|
||||
|
||||
def acquire_sliding_window_entry(
|
||||
self,
|
||||
key: str,
|
||||
limit: int,
|
||||
expiry: int,
|
||||
amount: int = 1,
|
||||
) -> bool:
|
||||
if amount > limit:
|
||||
return False
|
||||
now = time.time()
|
||||
previous_key, current_key = self.sliding_window_keys(key, expiry, now)
|
||||
(
|
||||
previous_count,
|
||||
previous_ttl,
|
||||
current_count,
|
||||
_,
|
||||
) = self._get_sliding_window_info(previous_key, current_key, expiry, now)
|
||||
weighted_count = previous_count * previous_ttl / expiry + current_count
|
||||
if floor(weighted_count) + amount > limit:
|
||||
return False
|
||||
else:
|
||||
# Hit, increase the current counter.
|
||||
# If the counter doesn't exist yet, set twice the theorical expiry.
|
||||
current_count = self.incr(current_key, 2 * expiry, amount=amount)
|
||||
weighted_count = previous_count * previous_ttl / expiry + current_count
|
||||
if floor(weighted_count) > limit:
|
||||
# Another hit won the race condition: revert the incrementation and refuse this hit
|
||||
# Limitation: during high concurrency at the end of the window,
|
||||
# the counter is shifted and cannot be decremented, so less requests than expected are allowed.
|
||||
self.decr(current_key, amount)
|
||||
return False
|
||||
return True
|
||||
|
||||
def _get_sliding_window_info(
|
||||
self,
|
||||
previous_key: str,
|
||||
current_key: str,
|
||||
expiry: int,
|
||||
now: float,
|
||||
) -> tuple[int, float, int, float]:
|
||||
previous_count = self.get(previous_key)
|
||||
current_count = self.get(current_key)
|
||||
if previous_count == 0:
|
||||
previous_ttl = float(0)
|
||||
else:
|
||||
previous_ttl = (1 - (((now - expiry) / expiry) % 1)) * expiry
|
||||
current_ttl = (1 - ((now / expiry) % 1)) * expiry + expiry
|
||||
return previous_count, previous_ttl, current_count, current_ttl
|
||||
|
||||
def get_sliding_window(
|
||||
self, key: str, expiry: int
|
||||
) -> tuple[int, float, int, float]:
|
||||
now = time.time()
|
||||
previous_key, current_key = self.sliding_window_keys(key, expiry, now)
|
||||
return self._get_sliding_window_info(previous_key, current_key, expiry, now)
|
||||
|
||||
def check(self) -> bool:
|
||||
"""
|
||||
check if storage is healthy
|
||||
"""
|
||||
|
||||
return True
|
||||
|
||||
def reset(self) -> int | None:
|
||||
num_items = max(len(self.storage), len(self.events))
|
||||
self.storage.clear()
|
||||
self.expirations.clear()
|
||||
self.events.clear()
|
||||
self.locks.clear()
|
||||
return num_items
|
489
venv/lib/python3.11/site-packages/limits/storage/mongodb.py
Normal file
489
venv/lib/python3.11/site-packages/limits/storage/mongodb.py
Normal file
@ -0,0 +1,489 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import datetime
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from deprecated.sphinx import versionadded, versionchanged
|
||||
|
||||
from limits.typing import (
|
||||
MongoClient,
|
||||
MongoCollection,
|
||||
MongoDatabase,
|
||||
cast,
|
||||
)
|
||||
|
||||
from ..util import get_dependency
|
||||
from .base import MovingWindowSupport, SlidingWindowCounterSupport, Storage
|
||||
|
||||
|
||||
class MongoDBStorageBase(
|
||||
Storage, MovingWindowSupport, SlidingWindowCounterSupport, ABC
|
||||
):
|
||||
"""
|
||||
Rate limit storage with MongoDB as backend.
|
||||
|
||||
Depends on :pypi:`pymongo`.
|
||||
"""
|
||||
|
||||
DEPENDENCIES = ["pymongo"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
uri: str,
|
||||
database_name: str = "limits",
|
||||
counter_collection_name: str = "counters",
|
||||
window_collection_name: str = "windows",
|
||||
wrap_exceptions: bool = False,
|
||||
**options: int | str | bool,
|
||||
) -> None:
|
||||
"""
|
||||
:param uri: uri of the form ``mongodb://[user:password]@host:port?...``,
|
||||
This uri is passed directly to :class:`~pymongo.mongo_client.MongoClient`
|
||||
:param database_name: The database to use for storing the rate limit
|
||||
collections.
|
||||
:param counter_collection_name: The collection name to use for individual counters
|
||||
used in fixed window strategies
|
||||
:param window_collection_name: The collection name to use for sliding & moving window
|
||||
storage
|
||||
:param wrap_exceptions: Whether to wrap storage exceptions in
|
||||
:exc:`limits.errors.StorageError` before raising it.
|
||||
:param options: all remaining keyword arguments are passed to the
|
||||
constructor of :class:`~pymongo.mongo_client.MongoClient`
|
||||
:raise ConfigurationError: when the :pypi:`pymongo` library is not available
|
||||
"""
|
||||
|
||||
super().__init__(uri, wrap_exceptions=wrap_exceptions, **options)
|
||||
self._database_name = database_name
|
||||
self._collection_mapping = {
|
||||
"counters": counter_collection_name,
|
||||
"windows": window_collection_name,
|
||||
}
|
||||
self.lib = self.dependencies["pymongo"].module
|
||||
self.lib_errors, _ = get_dependency("pymongo.errors")
|
||||
self._storage_uri = uri
|
||||
self._storage_options = options
|
||||
self._storage: MongoClient | None = None
|
||||
|
||||
@property
|
||||
def storage(self) -> MongoClient:
|
||||
if self._storage is None:
|
||||
self._storage = self._init_mongo_client(
|
||||
self._storage_uri, **self._storage_options
|
||||
)
|
||||
self.__initialize_database()
|
||||
return self._storage
|
||||
|
||||
@property
|
||||
def _database(self) -> MongoDatabase:
|
||||
return self.storage[self._database_name]
|
||||
|
||||
@property
|
||||
def counters(self) -> MongoCollection:
|
||||
return self._database[self._collection_mapping["counters"]]
|
||||
|
||||
@property
|
||||
def windows(self) -> MongoCollection:
|
||||
return self._database[self._collection_mapping["windows"]]
|
||||
|
||||
@abstractmethod
|
||||
def _init_mongo_client(
|
||||
self, uri: str | None, **options: int | str | bool
|
||||
) -> MongoClient:
|
||||
raise NotImplementedError()
|
||||
|
||||
@property
|
||||
def base_exceptions(
|
||||
self,
|
||||
) -> type[Exception] | tuple[type[Exception], ...]: # pragma: no cover
|
||||
return self.lib_errors.PyMongoError # type: ignore
|
||||
|
||||
def __initialize_database(self) -> None:
|
||||
self.counters.create_index("expireAt", expireAfterSeconds=0)
|
||||
self.windows.create_index("expireAt", expireAfterSeconds=0)
|
||||
|
||||
def reset(self) -> int | None:
|
||||
"""
|
||||
Delete all rate limit keys in the rate limit collections (counters, windows)
|
||||
"""
|
||||
num_keys = self.counters.count_documents({}) + self.windows.count_documents({})
|
||||
self.counters.drop()
|
||||
self.windows.drop()
|
||||
|
||||
return int(num_keys)
|
||||
|
||||
def clear(self, key: str) -> None:
|
||||
"""
|
||||
:param key: the key to clear rate limits for
|
||||
"""
|
||||
self.counters.find_one_and_delete({"_id": key})
|
||||
self.windows.find_one_and_delete({"_id": key})
|
||||
|
||||
def get_expiry(self, key: str) -> float:
|
||||
"""
|
||||
:param key: the key to get the expiry for
|
||||
"""
|
||||
counter = self.counters.find_one({"_id": key})
|
||||
return (
|
||||
(counter["expireAt"] if counter else datetime.datetime.now())
|
||||
.replace(tzinfo=datetime.timezone.utc)
|
||||
.timestamp()
|
||||
)
|
||||
|
||||
def get(self, key: str) -> int:
|
||||
"""
|
||||
:param key: the key to get the counter value for
|
||||
"""
|
||||
counter = self.counters.find_one(
|
||||
{
|
||||
"_id": key,
|
||||
"expireAt": {"$gte": datetime.datetime.now(datetime.timezone.utc)},
|
||||
},
|
||||
projection=["count"],
|
||||
)
|
||||
|
||||
return counter and counter["count"] or 0
|
||||
|
||||
def incr(self, key: str, expiry: int, amount: int = 1) -> int:
|
||||
"""
|
||||
increments the counter for a given rate limit key
|
||||
|
||||
:param key: the key to increment
|
||||
:param expiry: amount in seconds for the key to expire in
|
||||
:param amount: the number to increment by
|
||||
"""
|
||||
expiration = datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(
|
||||
seconds=expiry
|
||||
)
|
||||
|
||||
return int(
|
||||
self.counters.find_one_and_update(
|
||||
{"_id": key},
|
||||
[
|
||||
{
|
||||
"$set": {
|
||||
"count": {
|
||||
"$cond": {
|
||||
"if": {"$lt": ["$expireAt", "$$NOW"]},
|
||||
"then": amount,
|
||||
"else": {"$add": ["$count", amount]},
|
||||
}
|
||||
},
|
||||
"expireAt": {
|
||||
"$cond": {
|
||||
"if": {"$lt": ["$expireAt", "$$NOW"]},
|
||||
"then": expiration,
|
||||
"else": "$expireAt",
|
||||
}
|
||||
},
|
||||
}
|
||||
},
|
||||
],
|
||||
upsert=True,
|
||||
projection=["count"],
|
||||
return_document=self.lib.ReturnDocument.AFTER,
|
||||
)["count"]
|
||||
)
|
||||
|
||||
def check(self) -> bool:
|
||||
"""
|
||||
Check if storage is healthy by calling :meth:`pymongo.mongo_client.MongoClient.server_info`
|
||||
"""
|
||||
try:
|
||||
self.storage.server_info()
|
||||
|
||||
return True
|
||||
except: # noqa: E722
|
||||
return False
|
||||
|
||||
def get_moving_window(self, key: str, limit: int, expiry: int) -> tuple[float, int]:
|
||||
"""
|
||||
returns the starting point and the number of entries in the moving
|
||||
window
|
||||
|
||||
:param key: rate limit key
|
||||
:param expiry: expiry of entry
|
||||
:return: (start of window, number of acquired entries)
|
||||
"""
|
||||
timestamp = time.time()
|
||||
if result := list(
|
||||
self.windows.aggregate(
|
||||
[
|
||||
{"$match": {"_id": key}},
|
||||
{
|
||||
"$project": {
|
||||
"filteredEntries": {
|
||||
"$filter": {
|
||||
"input": "$entries",
|
||||
"as": "entry",
|
||||
"cond": {"$gte": ["$$entry", timestamp - expiry]},
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"$project": {
|
||||
"min": {"$min": "$filteredEntries"},
|
||||
"count": {"$size": "$filteredEntries"},
|
||||
}
|
||||
},
|
||||
]
|
||||
)
|
||||
):
|
||||
return result[0]["min"], result[0]["count"]
|
||||
return timestamp, 0
|
||||
|
||||
def acquire_entry(self, key: str, limit: int, expiry: int, amount: int = 1) -> bool:
|
||||
"""
|
||||
:param key: rate limit key to acquire an entry in
|
||||
:param limit: amount of entries allowed
|
||||
:param expiry: expiry of the entry
|
||||
:param amount: the number of entries to acquire
|
||||
"""
|
||||
if amount > limit:
|
||||
return False
|
||||
|
||||
timestamp = time.time()
|
||||
try:
|
||||
updates: dict[
|
||||
str,
|
||||
dict[str, datetime.datetime | dict[str, list[float] | int]],
|
||||
] = {
|
||||
"$push": {
|
||||
"entries": {
|
||||
"$each": [timestamp] * amount,
|
||||
"$position": 0,
|
||||
"$slice": limit,
|
||||
}
|
||||
},
|
||||
"$set": {
|
||||
"expireAt": (
|
||||
datetime.datetime.now(datetime.timezone.utc)
|
||||
+ datetime.timedelta(seconds=expiry)
|
||||
)
|
||||
},
|
||||
}
|
||||
|
||||
self.windows.update_one(
|
||||
{
|
||||
"_id": key,
|
||||
f"entries.{limit - amount}": {"$not": {"$gte": timestamp - expiry}},
|
||||
},
|
||||
updates,
|
||||
upsert=True,
|
||||
)
|
||||
|
||||
return True
|
||||
except self.lib.errors.DuplicateKeyError:
|
||||
return False
|
||||
|
||||
def get_sliding_window(
|
||||
self, key: str, expiry: int
|
||||
) -> tuple[int, float, int, float]:
|
||||
expiry_ms = expiry * 1000
|
||||
if result := self.windows.find_one_and_update(
|
||||
{"_id": key},
|
||||
[
|
||||
{
|
||||
"$set": {
|
||||
"previousCount": {
|
||||
"$cond": {
|
||||
"if": {
|
||||
"$lte": [
|
||||
{"$subtract": ["$expireAt", "$$NOW"]},
|
||||
expiry_ms,
|
||||
]
|
||||
},
|
||||
"then": {"$ifNull": ["$currentCount", 0]},
|
||||
"else": {"$ifNull": ["$previousCount", 0]},
|
||||
}
|
||||
},
|
||||
"currentCount": {
|
||||
"$cond": {
|
||||
"if": {
|
||||
"$lte": [
|
||||
{"$subtract": ["$expireAt", "$$NOW"]},
|
||||
expiry_ms,
|
||||
]
|
||||
},
|
||||
"then": 0,
|
||||
"else": {"$ifNull": ["$currentCount", 0]},
|
||||
}
|
||||
},
|
||||
"expireAt": {
|
||||
"$cond": {
|
||||
"if": {
|
||||
"$lte": [
|
||||
{"$subtract": ["$expireAt", "$$NOW"]},
|
||||
expiry_ms,
|
||||
]
|
||||
},
|
||||
"then": {
|
||||
"$add": ["$expireAt", expiry_ms],
|
||||
},
|
||||
"else": "$expireAt",
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
],
|
||||
return_document=self.lib.ReturnDocument.AFTER,
|
||||
projection=["currentCount", "previousCount", "expireAt"],
|
||||
):
|
||||
expires_at = (
|
||||
(result["expireAt"].replace(tzinfo=datetime.timezone.utc).timestamp())
|
||||
if result.get("expireAt")
|
||||
else time.time()
|
||||
)
|
||||
current_ttl = max(0, expires_at - time.time())
|
||||
prev_ttl = max(0, current_ttl - expiry if result["previousCount"] else 0)
|
||||
|
||||
return (
|
||||
result["previousCount"],
|
||||
prev_ttl,
|
||||
result["currentCount"],
|
||||
current_ttl,
|
||||
)
|
||||
return 0, 0.0, 0, 0.0
|
||||
|
||||
def acquire_sliding_window_entry(
|
||||
self, key: str, limit: int, expiry: int, amount: int = 1
|
||||
) -> bool:
|
||||
expiry_ms = expiry * 1000
|
||||
result = self.windows.find_one_and_update(
|
||||
{"_id": key},
|
||||
[
|
||||
{
|
||||
"$set": {
|
||||
"previousCount": {
|
||||
"$cond": {
|
||||
"if": {
|
||||
"$lte": [
|
||||
{"$subtract": ["$expireAt", "$$NOW"]},
|
||||
expiry_ms,
|
||||
]
|
||||
},
|
||||
"then": {"$ifNull": ["$currentCount", 0]},
|
||||
"else": {"$ifNull": ["$previousCount", 0]},
|
||||
}
|
||||
},
|
||||
}
|
||||
},
|
||||
{
|
||||
"$set": {
|
||||
"currentCount": {
|
||||
"$cond": {
|
||||
"if": {
|
||||
"$lte": [
|
||||
{"$subtract": ["$expireAt", "$$NOW"]},
|
||||
expiry_ms,
|
||||
]
|
||||
},
|
||||
"then": 0,
|
||||
"else": {"$ifNull": ["$currentCount", 0]},
|
||||
}
|
||||
},
|
||||
"expireAt": {
|
||||
"$cond": {
|
||||
"if": {
|
||||
"$lte": [
|
||||
{"$subtract": ["$expireAt", "$$NOW"]},
|
||||
expiry_ms,
|
||||
]
|
||||
},
|
||||
"then": {
|
||||
"$cond": {
|
||||
"if": {"$gt": ["$expireAt", 0]},
|
||||
"then": {"$add": ["$expireAt", expiry_ms]},
|
||||
"else": {"$add": ["$$NOW", 2 * expiry_ms]},
|
||||
}
|
||||
},
|
||||
"else": "$expireAt",
|
||||
}
|
||||
},
|
||||
}
|
||||
},
|
||||
{
|
||||
"$set": {
|
||||
"curWeightedCount": {
|
||||
"$floor": {
|
||||
"$add": [
|
||||
{
|
||||
"$multiply": [
|
||||
"$previousCount",
|
||||
{
|
||||
"$divide": [
|
||||
{
|
||||
"$max": [
|
||||
0,
|
||||
{
|
||||
"$subtract": [
|
||||
"$expireAt",
|
||||
{
|
||||
"$add": [
|
||||
"$$NOW",
|
||||
expiry_ms,
|
||||
]
|
||||
},
|
||||
]
|
||||
},
|
||||
]
|
||||
},
|
||||
expiry_ms,
|
||||
]
|
||||
},
|
||||
]
|
||||
},
|
||||
"$currentCount",
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"$set": {
|
||||
"currentCount": {
|
||||
"$cond": {
|
||||
"if": {
|
||||
"$lte": [
|
||||
{"$add": ["$curWeightedCount", amount]},
|
||||
limit,
|
||||
]
|
||||
},
|
||||
"then": {"$add": ["$currentCount", amount]},
|
||||
"else": "$currentCount",
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"$set": {
|
||||
"_acquired": {
|
||||
"$lte": [{"$add": ["$curWeightedCount", amount]}, limit]
|
||||
}
|
||||
}
|
||||
},
|
||||
{"$unset": ["curWeightedCount"]},
|
||||
],
|
||||
return_document=self.lib.ReturnDocument.AFTER,
|
||||
upsert=True,
|
||||
)
|
||||
return cast(bool, result["_acquired"])
|
||||
|
||||
def __del__(self) -> None:
|
||||
if self.storage:
|
||||
self.storage.close()
|
||||
|
||||
|
||||
@versionadded(version="2.1")
|
||||
@versionchanged(
|
||||
version="3.14.0",
|
||||
reason="Added option to select custom collection names for windows & counters",
|
||||
)
|
||||
class MongoDBStorage(MongoDBStorageBase):
|
||||
STORAGE_SCHEME = ["mongodb", "mongodb+srv"]
|
||||
|
||||
def _init_mongo_client(
|
||||
self, uri: str | None, **options: int | str | bool
|
||||
) -> MongoClient:
|
||||
return cast(MongoClient, self.lib.MongoClient(uri, **options))
|
308
venv/lib/python3.11/site-packages/limits/storage/redis.py
Normal file
308
venv/lib/python3.11/site-packages/limits/storage/redis.py
Normal file
@ -0,0 +1,308 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from typing import TYPE_CHECKING, cast
|
||||
|
||||
from deprecated.sphinx import versionchanged
|
||||
from packaging.version import Version
|
||||
|
||||
from limits.typing import Literal, RedisClient
|
||||
|
||||
from ..util import get_package_data
|
||||
from .base import MovingWindowSupport, SlidingWindowCounterSupport, Storage
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import redis
|
||||
|
||||
|
||||
@versionchanged(
|
||||
version="4.3",
|
||||
reason=(
|
||||
"Added support for using the redis client from :pypi:`valkey`"
|
||||
" if :paramref:`uri` has the ``valkey://`` schema"
|
||||
),
|
||||
)
|
||||
class RedisStorage(Storage, MovingWindowSupport, SlidingWindowCounterSupport):
|
||||
"""
|
||||
Rate limit storage with redis as backend.
|
||||
|
||||
Depends on :pypi:`redis` (or :pypi:`valkey` if :paramref:`uri` starts with
|
||||
``valkey://``)
|
||||
"""
|
||||
|
||||
STORAGE_SCHEME = [
|
||||
"redis",
|
||||
"rediss",
|
||||
"redis+unix",
|
||||
"valkey",
|
||||
"valkeys",
|
||||
"valkey+unix",
|
||||
]
|
||||
"""The storage scheme for redis"""
|
||||
|
||||
DEPENDENCIES = {"redis": Version("3.0"), "valkey": Version("6.0")}
|
||||
|
||||
RES_DIR = "resources/redis/lua_scripts"
|
||||
|
||||
SCRIPT_MOVING_WINDOW = get_package_data(f"{RES_DIR}/moving_window.lua")
|
||||
SCRIPT_ACQUIRE_MOVING_WINDOW = get_package_data(
|
||||
f"{RES_DIR}/acquire_moving_window.lua"
|
||||
)
|
||||
SCRIPT_CLEAR_KEYS = get_package_data(f"{RES_DIR}/clear_keys.lua")
|
||||
SCRIPT_INCR_EXPIRE = get_package_data(f"{RES_DIR}/incr_expire.lua")
|
||||
|
||||
SCRIPT_SLIDING_WINDOW = get_package_data(f"{RES_DIR}/sliding_window.lua")
|
||||
SCRIPT_ACQUIRE_SLIDING_WINDOW = get_package_data(
|
||||
f"{RES_DIR}/acquire_sliding_window.lua"
|
||||
)
|
||||
|
||||
lua_moving_window: redis.commands.core.Script
|
||||
lua_acquire_moving_window: redis.commands.core.Script
|
||||
lua_sliding_window: redis.commands.core.Script
|
||||
lua_acquire_sliding_window: redis.commands.core.Script
|
||||
|
||||
PREFIX = "LIMITS"
|
||||
target_server: Literal["redis", "valkey"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
uri: str,
|
||||
connection_pool: redis.connection.ConnectionPool | None = None,
|
||||
wrap_exceptions: bool = False,
|
||||
**options: float | str | bool,
|
||||
) -> None:
|
||||
"""
|
||||
:param uri: uri of the form ``redis://[:password]@host:port``,
|
||||
``redis://[:password]@host:port/db``,
|
||||
``rediss://[:password]@host:port``, ``redis+unix:///path/to/sock`` etc.
|
||||
This uri is passed directly to :func:`redis.from_url` except for the
|
||||
case of ``redis+unix://`` where it is replaced with ``unix://``.
|
||||
|
||||
If the uri scheme is ``valkey`` the implementation used will be from
|
||||
:pypi:`valkey`.
|
||||
:param connection_pool: if provided, the redis client is initialized with
|
||||
the connection pool and any other params passed as :paramref:`options`
|
||||
:param wrap_exceptions: Whether to wrap storage exceptions in
|
||||
:exc:`limits.errors.StorageError` before raising it.
|
||||
:param options: all remaining keyword arguments are passed
|
||||
directly to the constructor of :class:`redis.Redis`
|
||||
:raise ConfigurationError: when the :pypi:`redis` library is not available
|
||||
"""
|
||||
super().__init__(uri, wrap_exceptions=wrap_exceptions, **options)
|
||||
self.target_server = "valkey" if uri.startswith("valkey") else "redis"
|
||||
self.dependency = self.dependencies[self.target_server].module
|
||||
|
||||
uri = uri.replace(f"{self.target_server}+unix", "unix")
|
||||
|
||||
if not connection_pool:
|
||||
self.storage = self.dependency.from_url(uri, **options)
|
||||
else:
|
||||
if self.target_server == "redis":
|
||||
self.storage = self.dependency.Redis(
|
||||
connection_pool=connection_pool, **options
|
||||
)
|
||||
else:
|
||||
self.storage = self.dependency.Valkey(
|
||||
connection_pool=connection_pool, **options
|
||||
)
|
||||
self.initialize_storage(uri)
|
||||
|
||||
@property
|
||||
def base_exceptions(
|
||||
self,
|
||||
) -> type[Exception] | tuple[type[Exception], ...]: # pragma: no cover
|
||||
return ( # type: ignore[no-any-return]
|
||||
self.dependency.RedisError
|
||||
if self.target_server == "redis"
|
||||
else self.dependency.ValkeyError
|
||||
)
|
||||
|
||||
def initialize_storage(self, _uri: str) -> None:
|
||||
self.lua_moving_window = self.get_connection().register_script(
|
||||
self.SCRIPT_MOVING_WINDOW
|
||||
)
|
||||
self.lua_acquire_moving_window = self.get_connection().register_script(
|
||||
self.SCRIPT_ACQUIRE_MOVING_WINDOW
|
||||
)
|
||||
self.lua_clear_keys = self.get_connection().register_script(
|
||||
self.SCRIPT_CLEAR_KEYS
|
||||
)
|
||||
self.lua_incr_expire = self.get_connection().register_script(
|
||||
self.SCRIPT_INCR_EXPIRE
|
||||
)
|
||||
self.lua_sliding_window = self.get_connection().register_script(
|
||||
self.SCRIPT_SLIDING_WINDOW
|
||||
)
|
||||
self.lua_acquire_sliding_window = self.get_connection().register_script(
|
||||
self.SCRIPT_ACQUIRE_SLIDING_WINDOW
|
||||
)
|
||||
|
||||
def get_connection(self, readonly: bool = False) -> RedisClient:
|
||||
return cast(RedisClient, self.storage)
|
||||
|
||||
def _current_window_key(self, key: str) -> str:
|
||||
"""
|
||||
Return the current window's storage key (Sliding window strategy)
|
||||
|
||||
Contrary to other strategies that have one key per rate limit item,
|
||||
this strategy has two keys per rate limit item than must be on the same machine.
|
||||
To keep the current key and the previous key on the same Redis cluster node,
|
||||
curly braces are added.
|
||||
|
||||
Eg: "{constructed_key}"
|
||||
"""
|
||||
return f"{{{key}}}"
|
||||
|
||||
def _previous_window_key(self, key: str) -> str:
|
||||
"""
|
||||
Return the previous window's storage key (Sliding window strategy).
|
||||
|
||||
Curvy braces are added on the common pattern with the current window's key,
|
||||
so the current and the previous key are stored on the same Redis cluster node.
|
||||
|
||||
Eg: "{constructed_key}/-1"
|
||||
"""
|
||||
return f"{self._current_window_key(key)}/-1"
|
||||
|
||||
def prefixed_key(self, key: str) -> str:
|
||||
return f"{self.PREFIX}:{key}"
|
||||
|
||||
def get_moving_window(self, key: str, limit: int, expiry: int) -> tuple[float, int]:
|
||||
"""
|
||||
returns the starting point and the number of entries in the moving
|
||||
window
|
||||
|
||||
:param key: rate limit key
|
||||
:param expiry: expiry of entry
|
||||
:return: (start of window, number of acquired entries)
|
||||
"""
|
||||
key = self.prefixed_key(key)
|
||||
timestamp = time.time()
|
||||
if window := self.lua_moving_window([key], [timestamp - expiry, limit]):
|
||||
return float(window[0]), window[1]
|
||||
|
||||
return timestamp, 0
|
||||
|
||||
def get_sliding_window(
|
||||
self, key: str, expiry: int
|
||||
) -> tuple[int, float, int, float]:
|
||||
previous_key = self.prefixed_key(self._previous_window_key(key))
|
||||
current_key = self.prefixed_key(self._current_window_key(key))
|
||||
if window := self.lua_sliding_window([previous_key, current_key], [expiry]):
|
||||
return (
|
||||
int(window[0] or 0),
|
||||
max(0, float(window[1] or 0)) / 1000,
|
||||
int(window[2] or 0),
|
||||
max(0, float(window[3] or 0)) / 1000,
|
||||
)
|
||||
return 0, 0.0, 0, 0.0
|
||||
|
||||
def incr(
|
||||
self,
|
||||
key: str,
|
||||
expiry: int,
|
||||
amount: int = 1,
|
||||
) -> int:
|
||||
"""
|
||||
increments the counter for a given rate limit key
|
||||
|
||||
|
||||
:param key: the key to increment
|
||||
:param expiry: amount in seconds for the key to expire in
|
||||
:param amount: the number to increment by
|
||||
"""
|
||||
key = self.prefixed_key(key)
|
||||
return int(self.lua_incr_expire([key], [expiry, amount]))
|
||||
|
||||
def get(self, key: str) -> int:
|
||||
"""
|
||||
|
||||
:param key: the key to get the counter value for
|
||||
"""
|
||||
|
||||
key = self.prefixed_key(key)
|
||||
return int(self.get_connection(True).get(key) or 0)
|
||||
|
||||
def clear(self, key: str) -> None:
|
||||
"""
|
||||
:param key: the key to clear rate limits for
|
||||
"""
|
||||
key = self.prefixed_key(key)
|
||||
self.get_connection().delete(key)
|
||||
|
||||
def acquire_entry(
|
||||
self,
|
||||
key: str,
|
||||
limit: int,
|
||||
expiry: int,
|
||||
amount: int = 1,
|
||||
) -> bool:
|
||||
"""
|
||||
:param key: rate limit key to acquire an entry in
|
||||
:param limit: amount of entries allowed
|
||||
:param expiry: expiry of the entry
|
||||
|
||||
:param amount: the number of entries to acquire
|
||||
"""
|
||||
key = self.prefixed_key(key)
|
||||
timestamp = time.time()
|
||||
acquired = self.lua_acquire_moving_window(
|
||||
[key], [timestamp, limit, expiry, amount]
|
||||
)
|
||||
|
||||
return bool(acquired)
|
||||
|
||||
def acquire_sliding_window_entry(
|
||||
self,
|
||||
key: str,
|
||||
limit: int,
|
||||
expiry: int,
|
||||
amount: int = 1,
|
||||
) -> bool:
|
||||
"""
|
||||
Acquire an entry. Shift the current window to the previous window if it expired.
|
||||
|
||||
:param key: rate limit key to acquire an entry in
|
||||
:param limit: amount of entries allowed
|
||||
:param expiry: expiry of the entry
|
||||
:param amount: the number of entries to acquire
|
||||
"""
|
||||
previous_key = self.prefixed_key(self._previous_window_key(key))
|
||||
current_key = self.prefixed_key(self._current_window_key(key))
|
||||
acquired = self.lua_acquire_sliding_window(
|
||||
[previous_key, current_key], [limit, expiry, amount]
|
||||
)
|
||||
return bool(acquired)
|
||||
|
||||
def get_expiry(self, key: str) -> float:
|
||||
"""
|
||||
:param key: the key to get the expiry for
|
||||
|
||||
"""
|
||||
|
||||
key = self.prefixed_key(key)
|
||||
return max(self.get_connection(True).ttl(key), 0) + time.time()
|
||||
|
||||
def check(self) -> bool:
|
||||
"""
|
||||
check if storage is healthy
|
||||
"""
|
||||
try:
|
||||
return self.get_connection().ping()
|
||||
except: # noqa
|
||||
return False
|
||||
|
||||
def reset(self) -> int | None:
|
||||
"""
|
||||
This function calls a Lua Script to delete keys prefixed with
|
||||
``self.PREFIX`` in blocks of 5000.
|
||||
|
||||
.. warning::
|
||||
This operation was designed to be fast, but was not tested
|
||||
on a large production based system. Be careful with its usage as it
|
||||
could be slow on very large data sets.
|
||||
|
||||
"""
|
||||
|
||||
prefix = self.prefixed_key("*")
|
||||
return int(self.lua_clear_keys([prefix]))
|
@ -0,0 +1,125 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import urllib
|
||||
|
||||
from deprecated.sphinx import versionchanged
|
||||
from packaging.version import Version
|
||||
|
||||
from limits.storage.redis import RedisStorage
|
||||
|
||||
|
||||
@versionchanged(
|
||||
version="3.14.0",
|
||||
reason="""
|
||||
Dropped support for the :pypi:`redis-py-cluster` library
|
||||
which has been abandoned/deprecated.
|
||||
""",
|
||||
)
|
||||
@versionchanged(
|
||||
version="2.5.0",
|
||||
reason="""
|
||||
Cluster support was provided by the :pypi:`redis-py-cluster` library
|
||||
which has been absorbed into the official :pypi:`redis` client. By
|
||||
default the :class:`redis.cluster.RedisCluster` client will be used
|
||||
however if the version of the package is lower than ``4.2.0`` the implementation
|
||||
will fallback to trying to use :class:`rediscluster.RedisCluster`.
|
||||
""",
|
||||
)
|
||||
@versionchanged(
|
||||
version="4.3",
|
||||
reason=(
|
||||
"Added support for using the redis client from :pypi:`valkey`"
|
||||
" if :paramref:`uri` has the ``valkey+cluster://`` schema"
|
||||
),
|
||||
)
|
||||
class RedisClusterStorage(RedisStorage):
|
||||
"""
|
||||
Rate limit storage with redis cluster as backend
|
||||
|
||||
Depends on :pypi:`redis` (or :pypi:`valkey` if :paramref:`uri`
|
||||
starts with ``valkey+cluster://``).
|
||||
"""
|
||||
|
||||
STORAGE_SCHEME = ["redis+cluster", "valkey+cluster"]
|
||||
"""The storage scheme for redis cluster"""
|
||||
|
||||
DEFAULT_OPTIONS: dict[str, float | str | bool] = {
|
||||
"max_connections": 1000,
|
||||
}
|
||||
"Default options passed to the :class:`~redis.cluster.RedisCluster`"
|
||||
|
||||
DEPENDENCIES = {
|
||||
"redis": Version("4.2.0"),
|
||||
"valkey": Version("6.0"),
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
uri: str,
|
||||
wrap_exceptions: bool = False,
|
||||
**options: float | str | bool,
|
||||
) -> None:
|
||||
"""
|
||||
:param uri: url of the form
|
||||
``redis+cluster://[:password]@host:port,host:port``
|
||||
|
||||
If the uri scheme is ``valkey+cluster`` the implementation used will be from
|
||||
:pypi:`valkey`.
|
||||
:param wrap_exceptions: Whether to wrap storage exceptions in
|
||||
:exc:`limits.errors.StorageError` before raising it.
|
||||
:param options: all remaining keyword arguments are passed
|
||||
directly to the constructor of :class:`redis.cluster.RedisCluster`
|
||||
:raise ConfigurationError: when the :pypi:`redis` library is not
|
||||
available or if the redis cluster cannot be reached.
|
||||
"""
|
||||
parsed = urllib.parse.urlparse(uri)
|
||||
parsed_auth: dict[str, float | str | bool] = {}
|
||||
|
||||
if parsed.username:
|
||||
parsed_auth["username"] = parsed.username
|
||||
if parsed.password:
|
||||
parsed_auth["password"] = parsed.password
|
||||
|
||||
sep = parsed.netloc.find("@") + 1
|
||||
cluster_hosts = []
|
||||
for loc in parsed.netloc[sep:].split(","):
|
||||
host, port = loc.split(":")
|
||||
cluster_hosts.append((host, int(port)))
|
||||
|
||||
self.storage = None
|
||||
self.target_server = "valkey" if uri.startswith("valkey") else "redis"
|
||||
merged_options = {**self.DEFAULT_OPTIONS, **parsed_auth, **options}
|
||||
self.dependency = self.dependencies[self.target_server].module
|
||||
startup_nodes = [self.dependency.cluster.ClusterNode(*c) for c in cluster_hosts]
|
||||
if self.target_server == "redis":
|
||||
self.storage = self.dependency.cluster.RedisCluster(
|
||||
startup_nodes=startup_nodes, **merged_options
|
||||
)
|
||||
else:
|
||||
self.storage = self.dependency.cluster.ValkeyCluster(
|
||||
startup_nodes=startup_nodes, **merged_options
|
||||
)
|
||||
|
||||
assert self.storage
|
||||
self.initialize_storage(uri)
|
||||
super(RedisStorage, self).__init__(uri, wrap_exceptions, **options)
|
||||
|
||||
def reset(self) -> int | None:
|
||||
"""
|
||||
Redis Clusters are sharded and deleting across shards
|
||||
can't be done atomically. Because of this, this reset loops over all
|
||||
keys that are prefixed with ``self.PREFIX`` and calls delete on them,
|
||||
one at a time.
|
||||
|
||||
.. warning::
|
||||
This operation was not tested with extremely large data sets.
|
||||
On a large production based system, care should be taken with its
|
||||
usage as it could be slow on very large data sets"""
|
||||
|
||||
prefix = self.prefixed_key("*")
|
||||
count = 0
|
||||
for primary in self.storage.get_primaries():
|
||||
node = self.storage.get_redis_connection(primary)
|
||||
keys = node.keys(prefix)
|
||||
count += sum([node.delete(k.decode("utf-8")) for k in keys])
|
||||
return count
|
@ -0,0 +1,120 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import urllib.parse
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from deprecated.sphinx import versionchanged
|
||||
from packaging.version import Version
|
||||
|
||||
from limits.errors import ConfigurationError
|
||||
from limits.storage.redis import RedisStorage
|
||||
from limits.typing import RedisClient
|
||||
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
|
||||
|
||||
@versionchanged(
|
||||
version="4.3",
|
||||
reason=(
|
||||
"Added support for using the redis client from :pypi:`valkey`"
|
||||
" if :paramref:`uri` has the ``valkey+sentinel://`` schema"
|
||||
),
|
||||
)
|
||||
class RedisSentinelStorage(RedisStorage):
|
||||
"""
|
||||
Rate limit storage with redis sentinel as backend
|
||||
|
||||
Depends on :pypi:`redis` package (or :pypi:`valkey` if :paramref:`uri` starts with
|
||||
``valkey+sentinel://``)
|
||||
"""
|
||||
|
||||
STORAGE_SCHEME = ["redis+sentinel", "valkey+sentinel"]
|
||||
"""The storage scheme for redis accessed via a redis sentinel installation"""
|
||||
|
||||
DEPENDENCIES = {
|
||||
"redis": Version("3.0"),
|
||||
"redis.sentinel": Version("3.0"),
|
||||
"valkey": Version("6.0"),
|
||||
"valkey.sentinel": Version("6.0"),
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
uri: str,
|
||||
service_name: str | None = None,
|
||||
use_replicas: bool = True,
|
||||
sentinel_kwargs: dict[str, float | str | bool] | None = None,
|
||||
wrap_exceptions: bool = False,
|
||||
**options: float | str | bool,
|
||||
) -> None:
|
||||
"""
|
||||
:param uri: url of the form
|
||||
``redis+sentinel://host:port,host:port/service_name``
|
||||
|
||||
If the uri scheme is ``valkey+sentinel`` the implementation used will be from
|
||||
:pypi:`valkey`.
|
||||
:param service_name: sentinel service name
|
||||
(if not provided in :attr:`uri`)
|
||||
:param use_replicas: Whether to use replicas for read only operations
|
||||
:param sentinel_kwargs: kwargs to pass as
|
||||
:attr:`sentinel_kwargs` to :class:`redis.sentinel.Sentinel`
|
||||
:param wrap_exceptions: Whether to wrap storage exceptions in
|
||||
:exc:`limits.errors.StorageError` before raising it.
|
||||
:param options: all remaining keyword arguments are passed
|
||||
directly to the constructor of :class:`redis.sentinel.Sentinel`
|
||||
:raise ConfigurationError: when the redis library is not available
|
||||
or if the redis master host cannot be pinged.
|
||||
"""
|
||||
|
||||
super(RedisStorage, self).__init__(
|
||||
uri, wrap_exceptions=wrap_exceptions, **options
|
||||
)
|
||||
|
||||
parsed = urllib.parse.urlparse(uri)
|
||||
sentinel_configuration = []
|
||||
sentinel_options = sentinel_kwargs.copy() if sentinel_kwargs else {}
|
||||
|
||||
parsed_auth: dict[str, float | str | bool] = {}
|
||||
|
||||
if parsed.username:
|
||||
parsed_auth["username"] = parsed.username
|
||||
if parsed.password:
|
||||
parsed_auth["password"] = parsed.password
|
||||
|
||||
sep = parsed.netloc.find("@") + 1
|
||||
|
||||
for loc in parsed.netloc[sep:].split(","):
|
||||
host, port = loc.split(":")
|
||||
sentinel_configuration.append((host, int(port)))
|
||||
self.service_name = (
|
||||
parsed.path.replace("/", "") if parsed.path else service_name
|
||||
)
|
||||
|
||||
if self.service_name is None:
|
||||
raise ConfigurationError("'service_name' not provided")
|
||||
|
||||
self.target_server = "valkey" if uri.startswith("valkey") else "redis"
|
||||
sentinel_dep = self.dependencies[f"{self.target_server}.sentinel"].module
|
||||
self.sentinel = sentinel_dep.Sentinel(
|
||||
sentinel_configuration,
|
||||
sentinel_kwargs={**parsed_auth, **sentinel_options},
|
||||
**{**parsed_auth, **options},
|
||||
)
|
||||
self.storage: RedisClient = self.sentinel.master_for(self.service_name)
|
||||
self.storage_slave: RedisClient = self.sentinel.slave_for(self.service_name)
|
||||
self.use_replicas = use_replicas
|
||||
self.initialize_storage(uri)
|
||||
|
||||
@property
|
||||
def base_exceptions(
|
||||
self,
|
||||
) -> type[Exception] | tuple[type[Exception], ...]: # pragma: no cover
|
||||
return ( # type: ignore[no-any-return]
|
||||
self.dependencies["redis"].module.RedisError
|
||||
if self.target_server == "redis"
|
||||
else self.dependencies["valkey"].module.ValkeyError
|
||||
)
|
||||
|
||||
def get_connection(self, readonly: bool = False) -> RedisClient:
|
||||
return self.storage_slave if (readonly and self.use_replicas) else self.storage
|
24
venv/lib/python3.11/site-packages/limits/storage/registry.py
Normal file
24
venv/lib/python3.11/site-packages/limits/storage/registry.py
Normal file
@ -0,0 +1,24 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABCMeta
|
||||
|
||||
SCHEMES: dict[str, StorageRegistry] = {}
|
||||
|
||||
|
||||
class StorageRegistry(ABCMeta):
|
||||
def __new__(
|
||||
mcs, name: str, bases: tuple[type, ...], dct: dict[str, str | list[str]]
|
||||
) -> StorageRegistry:
|
||||
storage_scheme = dct.get("STORAGE_SCHEME", None)
|
||||
cls = super().__new__(mcs, name, bases, dct)
|
||||
|
||||
if storage_scheme:
|
||||
if isinstance(storage_scheme, str): # noqa
|
||||
schemes = [storage_scheme]
|
||||
else:
|
||||
schemes = storage_scheme
|
||||
|
||||
for scheme in schemes:
|
||||
SCHEMES[scheme] = cls
|
||||
|
||||
return cls
|
298
venv/lib/python3.11/site-packages/limits/strategies.py
Normal file
298
venv/lib/python3.11/site-packages/limits/strategies.py
Normal file
@ -0,0 +1,298 @@
|
||||
"""
|
||||
Rate limiting strategies
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from abc import ABCMeta, abstractmethod
|
||||
from math import floor, inf
|
||||
|
||||
from deprecated.sphinx import versionadded
|
||||
|
||||
from limits.storage.base import SlidingWindowCounterSupport
|
||||
|
||||
from .limits import RateLimitItem
|
||||
from .storage import MovingWindowSupport, Storage, StorageTypes
|
||||
from .typing import cast
|
||||
from .util import WindowStats
|
||||
|
||||
|
||||
class RateLimiter(metaclass=ABCMeta):
|
||||
def __init__(self, storage: StorageTypes):
|
||||
assert isinstance(storage, Storage)
|
||||
self.storage: Storage = storage
|
||||
|
||||
@abstractmethod
|
||||
def hit(self, item: RateLimitItem, *identifiers: str, cost: int = 1) -> bool:
|
||||
"""
|
||||
Consume the rate limit
|
||||
|
||||
:param item: The rate limit item
|
||||
:param identifiers: variable list of strings to uniquely identify this
|
||||
instance of the limit
|
||||
:param cost: The cost of this hit, default 1
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def test(self, item: RateLimitItem, *identifiers: str, cost: int = 1) -> bool:
|
||||
"""
|
||||
Check the rate limit without consuming from it.
|
||||
|
||||
:param item: The rate limit item
|
||||
:param identifiers: variable list of strings to uniquely identify this
|
||||
instance of the limit
|
||||
:param cost: The expected cost to be consumed, default 1
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def get_window_stats(self, item: RateLimitItem, *identifiers: str) -> WindowStats:
|
||||
"""
|
||||
Query the reset time and remaining amount for the limit
|
||||
|
||||
:param item: The rate limit item
|
||||
:param identifiers: variable list of strings to uniquely identify this
|
||||
instance of the limit
|
||||
:return: (reset time, remaining)
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def clear(self, item: RateLimitItem, *identifiers: str) -> None:
|
||||
return self.storage.clear(item.key_for(*identifiers))
|
||||
|
||||
|
||||
class MovingWindowRateLimiter(RateLimiter):
|
||||
"""
|
||||
Reference: :ref:`strategies:moving window`
|
||||
"""
|
||||
|
||||
def __init__(self, storage: StorageTypes):
|
||||
if not (
|
||||
hasattr(storage, "acquire_entry") or hasattr(storage, "get_moving_window")
|
||||
):
|
||||
raise NotImplementedError(
|
||||
"MovingWindowRateLimiting is not implemented for storage "
|
||||
f"of type {storage.__class__}"
|
||||
)
|
||||
super().__init__(storage)
|
||||
|
||||
def hit(self, item: RateLimitItem, *identifiers: str, cost: int = 1) -> bool:
|
||||
"""
|
||||
Consume the rate limit
|
||||
|
||||
:param item: The rate limit item
|
||||
:param identifiers: variable list of strings to uniquely identify this
|
||||
instance of the limit
|
||||
:param cost: The cost of this hit, default 1
|
||||
:return: (reset time, remaining)
|
||||
"""
|
||||
|
||||
return cast(MovingWindowSupport, self.storage).acquire_entry(
|
||||
item.key_for(*identifiers), item.amount, item.get_expiry(), amount=cost
|
||||
)
|
||||
|
||||
def test(self, item: RateLimitItem, *identifiers: str, cost: int = 1) -> bool:
|
||||
"""
|
||||
Check if the rate limit can be consumed
|
||||
|
||||
:param item: The rate limit item
|
||||
:param identifiers: variable list of strings to uniquely identify this
|
||||
instance of the limit
|
||||
:param cost: The expected cost to be consumed, default 1
|
||||
"""
|
||||
|
||||
return (
|
||||
cast(MovingWindowSupport, self.storage).get_moving_window(
|
||||
item.key_for(*identifiers),
|
||||
item.amount,
|
||||
item.get_expiry(),
|
||||
)[1]
|
||||
<= item.amount - cost
|
||||
)
|
||||
|
||||
def get_window_stats(self, item: RateLimitItem, *identifiers: str) -> WindowStats:
|
||||
"""
|
||||
returns the number of requests remaining within this limit.
|
||||
|
||||
:param item: The rate limit item
|
||||
:param identifiers: variable list of strings to uniquely identify this
|
||||
instance of the limit
|
||||
:return: tuple (reset time, remaining)
|
||||
"""
|
||||
window_start, window_items = cast(
|
||||
MovingWindowSupport, self.storage
|
||||
).get_moving_window(item.key_for(*identifiers), item.amount, item.get_expiry())
|
||||
reset = window_start + item.get_expiry()
|
||||
|
||||
return WindowStats(reset, item.amount - window_items)
|
||||
|
||||
|
||||
class FixedWindowRateLimiter(RateLimiter):
|
||||
"""
|
||||
Reference: :ref:`strategies:fixed window`
|
||||
"""
|
||||
|
||||
def hit(self, item: RateLimitItem, *identifiers: str, cost: int = 1) -> bool:
|
||||
"""
|
||||
Consume the rate limit
|
||||
|
||||
:param item: The rate limit item
|
||||
:param identifiers: variable list of strings to uniquely identify this
|
||||
instance of the limit
|
||||
:param cost: The cost of this hit, default 1
|
||||
"""
|
||||
|
||||
return (
|
||||
self.storage.incr(
|
||||
item.key_for(*identifiers),
|
||||
item.get_expiry(),
|
||||
amount=cost,
|
||||
)
|
||||
<= item.amount
|
||||
)
|
||||
|
||||
def test(self, item: RateLimitItem, *identifiers: str, cost: int = 1) -> bool:
|
||||
"""
|
||||
Check if the rate limit can be consumed
|
||||
|
||||
:param item: The rate limit item
|
||||
:param identifiers: variable list of strings to uniquely identify this
|
||||
instance of the limit
|
||||
:param cost: The expected cost to be consumed, default 1
|
||||
"""
|
||||
|
||||
return self.storage.get(item.key_for(*identifiers)) < item.amount - cost + 1
|
||||
|
||||
def get_window_stats(self, item: RateLimitItem, *identifiers: str) -> WindowStats:
|
||||
"""
|
||||
Query the reset time and remaining amount for the limit
|
||||
|
||||
:param item: The rate limit item
|
||||
:param identifiers: variable list of strings to uniquely identify this
|
||||
instance of the limit
|
||||
:return: (reset time, remaining)
|
||||
"""
|
||||
remaining = max(0, item.amount - self.storage.get(item.key_for(*identifiers)))
|
||||
reset = self.storage.get_expiry(item.key_for(*identifiers))
|
||||
|
||||
return WindowStats(reset, remaining)
|
||||
|
||||
|
||||
@versionadded(version="4.1")
|
||||
class SlidingWindowCounterRateLimiter(RateLimiter):
|
||||
"""
|
||||
Reference: :ref:`strategies:sliding window counter`
|
||||
"""
|
||||
|
||||
def __init__(self, storage: StorageTypes):
|
||||
if not hasattr(storage, "get_sliding_window") or not hasattr(
|
||||
storage, "acquire_sliding_window_entry"
|
||||
):
|
||||
raise NotImplementedError(
|
||||
"SlidingWindowCounterRateLimiting is not implemented for storage "
|
||||
f"of type {storage.__class__}"
|
||||
)
|
||||
super().__init__(storage)
|
||||
|
||||
def _weighted_count(
|
||||
self,
|
||||
item: RateLimitItem,
|
||||
previous_count: int,
|
||||
previous_expires_in: float,
|
||||
current_count: int,
|
||||
) -> float:
|
||||
"""
|
||||
Return the approximated by weighting the previous window count and adding the current window count.
|
||||
"""
|
||||
return previous_count * previous_expires_in / item.get_expiry() + current_count
|
||||
|
||||
def hit(self, item: RateLimitItem, *identifiers: str, cost: int = 1) -> bool:
|
||||
"""
|
||||
Consume the rate limit
|
||||
|
||||
:param item: The rate limit item
|
||||
:param identifiers: variable list of strings to uniquely identify this
|
||||
instance of the limit
|
||||
:param cost: The cost of this hit, default 1
|
||||
"""
|
||||
return cast(
|
||||
SlidingWindowCounterSupport, self.storage
|
||||
).acquire_sliding_window_entry(
|
||||
item.key_for(*identifiers),
|
||||
item.amount,
|
||||
item.get_expiry(),
|
||||
cost,
|
||||
)
|
||||
|
||||
def test(self, item: RateLimitItem, *identifiers: str, cost: int = 1) -> bool:
|
||||
"""
|
||||
Check if the rate limit can be consumed
|
||||
|
||||
:param item: The rate limit item
|
||||
:param identifiers: variable list of strings to uniquely identify this
|
||||
instance of the limit
|
||||
:param cost: The expected cost to be consumed, default 1
|
||||
"""
|
||||
previous_count, previous_expires_in, current_count, _ = cast(
|
||||
SlidingWindowCounterSupport, self.storage
|
||||
).get_sliding_window(item.key_for(*identifiers), item.get_expiry())
|
||||
|
||||
return (
|
||||
self._weighted_count(
|
||||
item, previous_count, previous_expires_in, current_count
|
||||
)
|
||||
< item.amount - cost + 1
|
||||
)
|
||||
|
||||
def get_window_stats(self, item: RateLimitItem, *identifiers: str) -> WindowStats:
|
||||
"""
|
||||
Query the reset time and remaining amount for the limit.
|
||||
|
||||
:param item: The rate limit item
|
||||
:param identifiers: variable list of strings to uniquely identify this
|
||||
instance of the limit
|
||||
:return: WindowStats(reset time, remaining)
|
||||
"""
|
||||
previous_count, previous_expires_in, current_count, current_expires_in = cast(
|
||||
SlidingWindowCounterSupport, self.storage
|
||||
).get_sliding_window(item.key_for(*identifiers), item.get_expiry())
|
||||
|
||||
remaining = max(
|
||||
0,
|
||||
item.amount
|
||||
- floor(
|
||||
self._weighted_count(
|
||||
item, previous_count, previous_expires_in, current_count
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
now = time.time()
|
||||
|
||||
if not (previous_count or current_count):
|
||||
return WindowStats(now, remaining)
|
||||
|
||||
expiry = item.get_expiry()
|
||||
|
||||
previous_reset_in, current_reset_in = inf, inf
|
||||
if previous_count:
|
||||
previous_reset_in = previous_expires_in % (expiry / previous_count)
|
||||
if current_count:
|
||||
current_reset_in = current_expires_in % expiry
|
||||
|
||||
return WindowStats(now + min(previous_reset_in, current_reset_in), remaining)
|
||||
|
||||
|
||||
KnownStrategy = (
|
||||
type[SlidingWindowCounterRateLimiter]
|
||||
| type[FixedWindowRateLimiter]
|
||||
| type[MovingWindowRateLimiter]
|
||||
)
|
||||
|
||||
STRATEGIES: dict[str, KnownStrategy] = {
|
||||
"sliding-window-counter": SlidingWindowCounterRateLimiter,
|
||||
"fixed-window": FixedWindowRateLimiter,
|
||||
"moving-window": MovingWindowRateLimiter,
|
||||
}
|
127
venv/lib/python3.11/site-packages/limits/typing.py
Normal file
127
venv/lib/python3.11/site-packages/limits/typing.py
Normal file
@ -0,0 +1,127 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import Counter
|
||||
from collections.abc import Awaitable, Callable, Iterable
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
ClassVar,
|
||||
Literal,
|
||||
NamedTuple,
|
||||
ParamSpec,
|
||||
Protocol,
|
||||
TypeAlias,
|
||||
TypeVar,
|
||||
cast,
|
||||
)
|
||||
|
||||
Serializable = int | str | float
|
||||
|
||||
R = TypeVar("R")
|
||||
R_co = TypeVar("R_co", covariant=True)
|
||||
P = ParamSpec("P")
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import coredis
|
||||
import pymongo.collection
|
||||
import pymongo.database
|
||||
import pymongo.mongo_client
|
||||
import redis
|
||||
|
||||
|
||||
class MemcachedClientP(Protocol):
|
||||
def add(
|
||||
self,
|
||||
key: str,
|
||||
value: Serializable,
|
||||
expire: int | None = 0,
|
||||
noreply: bool | None = None,
|
||||
flags: int | None = None,
|
||||
) -> bool: ...
|
||||
|
||||
def get(self, key: str, default: str | None = None) -> bytes: ...
|
||||
|
||||
def get_many(self, keys: Iterable[str]) -> dict[str, Any]: ... # type:ignore[explicit-any]
|
||||
|
||||
def incr(
|
||||
self, key: str, value: int, noreply: bool | None = False
|
||||
) -> int | None: ...
|
||||
|
||||
def decr(
|
||||
self,
|
||||
key: str,
|
||||
value: int,
|
||||
noreply: bool | None = False,
|
||||
) -> int | None: ...
|
||||
|
||||
def delete(self, key: str, noreply: bool | None = None) -> bool | None: ...
|
||||
|
||||
def set(
|
||||
self,
|
||||
key: str,
|
||||
value: Serializable,
|
||||
expire: int = 0,
|
||||
noreply: bool | None = None,
|
||||
flags: int | None = None,
|
||||
) -> bool: ...
|
||||
|
||||
def touch(
|
||||
self, key: str, expire: int | None = 0, noreply: bool | None = None
|
||||
) -> bool: ...
|
||||
|
||||
|
||||
class RedisClientP(Protocol):
|
||||
def incrby(self, key: str, amount: int) -> int: ...
|
||||
def get(self, key: str) -> bytes | None: ...
|
||||
def delete(self, key: str) -> int: ...
|
||||
def ttl(self, key: str) -> int: ...
|
||||
def expire(self, key: str, seconds: int) -> bool: ...
|
||||
def ping(self) -> bool: ...
|
||||
def register_script(self, script: bytes) -> redis.commands.core.Script: ...
|
||||
|
||||
|
||||
class AsyncRedisClientP(Protocol):
|
||||
async def incrby(self, key: str, amount: int) -> int: ...
|
||||
async def get(self, key: str) -> bytes | None: ...
|
||||
async def delete(self, key: str) -> int: ...
|
||||
async def ttl(self, key: str) -> int: ...
|
||||
async def expire(self, key: str, seconds: int) -> bool: ...
|
||||
async def ping(self) -> bool: ...
|
||||
def register_script(self, script: bytes) -> redis.commands.core.Script: ...
|
||||
|
||||
|
||||
RedisClient: TypeAlias = RedisClientP
|
||||
AsyncRedisClient: TypeAlias = AsyncRedisClientP
|
||||
AsyncCoRedisClient: TypeAlias = "coredis.Redis[bytes] | coredis.RedisCluster[bytes]"
|
||||
|
||||
MongoClient: TypeAlias = "pymongo.mongo_client.MongoClient[dict[str, Any]]" # type:ignore[explicit-any]
|
||||
MongoDatabase: TypeAlias = "pymongo.database.Database[dict[str, Any]]" # type:ignore[explicit-any]
|
||||
MongoCollection: TypeAlias = "pymongo.collection.Collection[dict[str, Any]]" # type:ignore[explicit-any]
|
||||
|
||||
__all__ = [
|
||||
"TYPE_CHECKING",
|
||||
"Any",
|
||||
"AsyncRedisClient",
|
||||
"Awaitable",
|
||||
"Callable",
|
||||
"ClassVar",
|
||||
"Counter",
|
||||
"Iterable",
|
||||
"Literal",
|
||||
"MemcachedClientP",
|
||||
"MongoClient",
|
||||
"MongoCollection",
|
||||
"MongoDatabase",
|
||||
"NamedTuple",
|
||||
"P",
|
||||
"ParamSpec",
|
||||
"Protocol",
|
||||
"R",
|
||||
"R_co",
|
||||
"RedisClient",
|
||||
"Serializable",
|
||||
"TypeAlias",
|
||||
"TypeVar",
|
||||
"cast",
|
||||
]
|
207
venv/lib/python3.11/site-packages/limits/util.py
Normal file
207
venv/lib/python3.11/site-packages/limits/util.py
Normal file
@ -0,0 +1,207 @@
|
||||
""" """
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import dataclasses
|
||||
import importlib.resources
|
||||
import re
|
||||
import sys
|
||||
from collections import UserDict
|
||||
from types import ModuleType
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from packaging.version import Version
|
||||
|
||||
from limits.typing import NamedTuple
|
||||
|
||||
from .errors import ConfigurationError
|
||||
from .limits import GRANULARITIES, RateLimitItem
|
||||
|
||||
SEPARATORS = re.compile(r"[,;|]{1}")
|
||||
SINGLE_EXPR = re.compile(
|
||||
r"""
|
||||
\s*([0-9]+)
|
||||
\s*(/|\s*per\s*)
|
||||
\s*([0-9]+)
|
||||
*\s*(hour|minute|second|day|month|year)s?\s*""",
|
||||
re.IGNORECASE | re.VERBOSE,
|
||||
)
|
||||
EXPR = re.compile(
|
||||
rf"^{SINGLE_EXPR.pattern}(:?{SEPARATORS.pattern}{SINGLE_EXPR.pattern})*$",
|
||||
re.IGNORECASE | re.VERBOSE,
|
||||
)
|
||||
|
||||
|
||||
class WindowStats(NamedTuple):
|
||||
"""
|
||||
tuple to describe a rate limited window
|
||||
"""
|
||||
|
||||
#: Time as seconds since the Epoch when this window will be reset
|
||||
reset_time: float
|
||||
#: Quantity remaining in this window
|
||||
remaining: int
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class Dependency:
|
||||
name: str
|
||||
version_required: Version | None
|
||||
version_found: Version | None
|
||||
module: ModuleType
|
||||
|
||||
|
||||
MissingModule = ModuleType("Missing")
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
_UserDict = UserDict[str, Dependency]
|
||||
else:
|
||||
_UserDict = UserDict
|
||||
|
||||
|
||||
class DependencyDict(_UserDict):
|
||||
def __getitem__(self, key: str) -> Dependency:
|
||||
dependency = super().__getitem__(key)
|
||||
|
||||
if dependency.module is MissingModule:
|
||||
message = f"'{dependency.name}' prerequisite not available."
|
||||
if dependency.version_required:
|
||||
message += (
|
||||
f" A minimum version of {dependency.version_required} is required."
|
||||
if dependency.version_required
|
||||
else ""
|
||||
)
|
||||
message += (
|
||||
" See https://limits.readthedocs.io/en/stable/storage.html#supported-versions"
|
||||
" for more details."
|
||||
)
|
||||
raise ConfigurationError(message)
|
||||
elif dependency.version_required and (
|
||||
not dependency.version_found
|
||||
or dependency.version_found < dependency.version_required
|
||||
):
|
||||
raise ConfigurationError(
|
||||
f"The minimum version of {dependency.version_required}"
|
||||
f" for '{dependency.name}' could not be found. Found version: {dependency.version_found}"
|
||||
)
|
||||
|
||||
return dependency
|
||||
|
||||
|
||||
class LazyDependency:
|
||||
"""
|
||||
Simple utility that provides an :attr:`dependency`
|
||||
to the child class to fetch any dependencies
|
||||
without having to import them explicitly.
|
||||
"""
|
||||
|
||||
DEPENDENCIES: dict[str, Version | None] | list[str] = []
|
||||
"""
|
||||
The python modules this class has a dependency on.
|
||||
Used to lazily populate the :attr:`dependencies`
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._dependencies: DependencyDict = DependencyDict()
|
||||
|
||||
@property
|
||||
def dependencies(self) -> DependencyDict:
|
||||
"""
|
||||
Cached mapping of the modules this storage depends on.
|
||||
This is done so that the module is only imported lazily
|
||||
when the storage is instantiated.
|
||||
|
||||
:meta private:
|
||||
"""
|
||||
|
||||
if not getattr(self, "_dependencies", None):
|
||||
dependencies = DependencyDict()
|
||||
mapping: dict[str, Version | None]
|
||||
|
||||
if isinstance(self.DEPENDENCIES, list):
|
||||
mapping = {dependency: None for dependency in self.DEPENDENCIES}
|
||||
else:
|
||||
mapping = self.DEPENDENCIES
|
||||
|
||||
for name, minimum_version in mapping.items():
|
||||
dependency, version = get_dependency(name)
|
||||
|
||||
dependencies[name] = Dependency(
|
||||
name, minimum_version, version, dependency
|
||||
)
|
||||
self._dependencies = dependencies
|
||||
|
||||
return self._dependencies
|
||||
|
||||
|
||||
def get_dependency(module_path: str) -> tuple[ModuleType, Version | None]:
|
||||
"""
|
||||
safe function to import a module at runtime
|
||||
"""
|
||||
try:
|
||||
if module_path not in sys.modules:
|
||||
__import__(module_path)
|
||||
root = module_path.split(".")[0]
|
||||
version = getattr(sys.modules[root], "__version__", "0.0.0")
|
||||
|
||||
return sys.modules[module_path], Version(version)
|
||||
except ImportError: # pragma: no cover
|
||||
return MissingModule, None
|
||||
|
||||
|
||||
def get_package_data(path: str) -> bytes:
|
||||
return importlib.resources.files("limits").joinpath(path).read_bytes()
|
||||
|
||||
|
||||
def parse_many(limit_string: str) -> list[RateLimitItem]:
|
||||
"""
|
||||
parses rate limits in string notation containing multiple rate limits
|
||||
(e.g. ``1/second; 5/minute``)
|
||||
|
||||
:param limit_string: rate limit string using :ref:`ratelimit-string`
|
||||
:raise ValueError: if the string notation is invalid.
|
||||
|
||||
"""
|
||||
|
||||
if not (isinstance(limit_string, str) and EXPR.match(limit_string)):
|
||||
raise ValueError(f"couldn't parse rate limit string '{limit_string}'")
|
||||
limits = []
|
||||
|
||||
for limit in SEPARATORS.split(limit_string):
|
||||
match = SINGLE_EXPR.match(limit)
|
||||
|
||||
if match:
|
||||
amount, _, multiples, granularity_string = match.groups()
|
||||
granularity = granularity_from_string(granularity_string)
|
||||
limits.append(
|
||||
granularity(int(amount), multiples and int(multiples) or None)
|
||||
)
|
||||
|
||||
return limits
|
||||
|
||||
|
||||
def parse(limit_string: str) -> RateLimitItem:
|
||||
"""
|
||||
parses a single rate limit in string notation
|
||||
(e.g. ``1/second`` or ``1 per second``)
|
||||
|
||||
:param limit_string: rate limit string using :ref:`ratelimit-string`
|
||||
:raise ValueError: if the string notation is invalid.
|
||||
|
||||
"""
|
||||
|
||||
return list(parse_many(limit_string))[0]
|
||||
|
||||
|
||||
def granularity_from_string(granularity_string: str) -> type[RateLimitItem]:
|
||||
"""
|
||||
|
||||
:param granularity_string:
|
||||
:raise ValueError:
|
||||
"""
|
||||
|
||||
for granularity in GRANULARITIES.values():
|
||||
if granularity.check_granularity_string(granularity_string):
|
||||
return granularity
|
||||
raise ValueError(f"no granularity matched for {granularity_string}")
|
3
venv/lib/python3.11/site-packages/limits/version.py
Normal file
3
venv/lib/python3.11/site-packages/limits/version.py
Normal file
@ -0,0 +1,3 @@
|
||||
"""
|
||||
empty file to be updated by versioneer
|
||||
"""
|
Reference in New Issue
Block a user