Update 2025-04-24_11:44:19

oib
2025-04-24 11:44:23 +02:00
commit e748c737f4
3408 changed files with 717481 additions and 0 deletions


@ -0,0 +1,35 @@
"""
Rate limiting with commonly used storage backends
"""
from __future__ import annotations
from . import _version, aio, storage, strategies
from .limits import (
RateLimitItem,
RateLimitItemPerDay,
RateLimitItemPerHour,
RateLimitItemPerMinute,
RateLimitItemPerMonth,
RateLimitItemPerSecond,
RateLimitItemPerYear,
)
from .util import WindowStats, parse, parse_many
__all__ = [
"RateLimitItem",
"RateLimitItemPerDay",
"RateLimitItemPerHour",
"RateLimitItemPerMinute",
"RateLimitItemPerMonth",
"RateLimitItemPerSecond",
"RateLimitItemPerYear",
"WindowStats",
"aio",
"parse",
"parse_many",
"storage",
"strategies",
]
__version__ = _version.get_versions()["version"]
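
A quick orientation, not part of the file above: a minimal usage sketch combining these exports with the synchronous in-memory backend (strategy and storage names as exported by this package).

# Hedged sketch: compose parse() with a storage backend and a strategy.
from limits import parse
from limits.storage import MemoryStorage
from limits.strategies import FixedWindowRateLimiter

limiter = FixedWindowRateLimiter(MemoryStorage())
ten_per_minute = parse("10/minute")  # -> RateLimitItemPerMinute(10)

for _ in range(10):
    assert limiter.hit(ten_per_minute, "my_namespace", "client_a")
# the eleventh hit inside the same minute is rejected
assert not limiter.hit(ten_per_minute, "my_namespace", "client_a")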


@ -0,0 +1,21 @@
# This file was generated by 'versioneer.py' (0.29) from
# revision-control system data, or from the parent directory name of an
# unpacked source archive. Distribution tarballs contain a pre-generated copy
# of this file.
import json
version_json = '''
{
"date": "2025-04-15T17:28:21-0700",
"dirty": false,
"error": null,
"full-revisionid": "eeb02fd85c146292dd70fb798ac90a486ba163bd",
"version": "5.0.0"
}
''' # END VERSION_JSON
def get_versions():
return json.loads(version_json)


@ -0,0 +1,8 @@
from __future__ import annotations
from . import storage, strategies
__all__ = [
"storage",
"strategies",
]


@ -0,0 +1,24 @@
"""
Implementations of storage backends to be used with
:class:`limits.aio.strategies.RateLimiter` strategies
"""
from __future__ import annotations
from .base import MovingWindowSupport, SlidingWindowCounterSupport, Storage
from .memcached import MemcachedStorage
from .memory import MemoryStorage
from .mongodb import MongoDBStorage
from .redis import RedisClusterStorage, RedisSentinelStorage, RedisStorage
__all__ = [
"MemcachedStorage",
"MemoryStorage",
"MongoDBStorage",
"MovingWindowSupport",
"RedisClusterStorage",
"RedisSentinelStorage",
"RedisStorage",
"SlidingWindowCounterSupport",
"Storage",
]
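
The same composition in an async context, as a minimal sketch; it assumes the strategy classes exposed by limits.aio.strategies and that MemoryStorage implements MovingWindowSupport (both per the modules in this commit).

import asyncio

from limits import parse
from limits.aio.storage import MemoryStorage
from limits.aio.strategies import MovingWindowRateLimiter

async def main() -> None:
    # In-memory backend: no external server needed for the moving window.
    limiter = MovingWindowRateLimiter(MemoryStorage())
    item = parse("5/second")
    print(await limiter.hit(item, "api", "client_a"))

asyncio.run(main())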


@ -0,0 +1,220 @@
from __future__ import annotations
import functools
from abc import ABC, abstractmethod
from deprecated.sphinx import versionadded
from limits import errors
from limits.storage.registry import StorageRegistry
from limits.typing import (
Any,
Awaitable,
Callable,
P,
R,
cast,
)
from limits.util import LazyDependency
def _wrap_errors(
fn: Callable[P, Awaitable[R]],
) -> Callable[P, Awaitable[R]]:
@functools.wraps(fn)
async def inner(*args: P.args, **kwargs: P.kwargs) -> R: # type: ignore[misc]
instance = cast(Storage, args[0])
try:
return await fn(*args, **kwargs)
except instance.base_exceptions as exc:
if instance.wrap_exceptions:
raise errors.StorageError(exc) from exc
raise
return inner
@versionadded(version="2.1")
class Storage(LazyDependency, metaclass=StorageRegistry):
"""
Base class to extend when implementing an async storage backend.
"""
STORAGE_SCHEME: list[str] | None
"""The storage schemes to register against this implementation"""
def __init_subclass__(cls, **kwargs: Any) -> None: # type:ignore[explicit-any]
for method in {
"incr",
"get",
"get_expiry",
"check",
"reset",
"clear",
}:
setattr(cls, method, _wrap_errors(getattr(cls, method)))
super().__init_subclass__(**kwargs)
def __init__(
self,
uri: str | None = None,
wrap_exceptions: bool = False,
**options: float | str | bool,
) -> None:
"""
:param wrap_exceptions: Whether to wrap storage exceptions in
:exc:`limits.errors.StorageError` before raising it.
"""
super().__init__()
self.wrap_exceptions = wrap_exceptions
@property
@abstractmethod
def base_exceptions(self) -> type[Exception] | tuple[type[Exception], ...]:
raise NotImplementedError
@abstractmethod
async def incr(self, key: str, expiry: int, amount: int = 1) -> int:
"""
increments the counter for a given rate limit key
:param key: the key to increment
:param expiry: amount in seconds for the key to expire in
:param amount: the number to increment by
"""
raise NotImplementedError
@abstractmethod
async def get(self, key: str) -> int:
"""
:param key: the key to get the counter value for
"""
raise NotImplementedError
@abstractmethod
async def get_expiry(self, key: str) -> float:
"""
:param key: the key to get the expiry for
"""
raise NotImplementedError
@abstractmethod
async def check(self) -> bool:
"""
check if storage is healthy
"""
raise NotImplementedError
@abstractmethod
async def reset(self) -> int | None:
"""
reset storage to clear limits
"""
raise NotImplementedError
@abstractmethod
async def clear(self, key: str) -> None:
"""
resets the rate limit key
:param key: the key to clear rate limits for
"""
raise NotImplementedError
class MovingWindowSupport(ABC):
"""
Abstract base class for async storages that support
the :ref:`strategies:moving window` strategy
"""
def __init_subclass__(cls, **kwargs: Any) -> None: # type: ignore[explicit-any]
for method in {
"acquire_entry",
"get_moving_window",
}:
setattr(
cls,
method,
_wrap_errors(getattr(cls, method)),
)
super().__init_subclass__(**kwargs)
@abstractmethod
async def acquire_entry(
self, key: str, limit: int, expiry: int, amount: int = 1
) -> bool:
"""
:param key: rate limit key to acquire an entry in
:param limit: amount of entries allowed
:param expiry: expiry of the entry
:param amount: the number of entries to acquire
"""
raise NotImplementedError
@abstractmethod
async def get_moving_window(
self, key: str, limit: int, expiry: int
) -> tuple[float, int]:
"""
returns the starting point and the number of entries in the moving
window
:param key: rate limit key
:param expiry: expiry of entry
:return: (start of window, number of acquired entries)
"""
raise NotImplementedError
class SlidingWindowCounterSupport(ABC):
"""
Abstract base class for async storages that support
the :ref:`strategies:sliding window counter` strategy
"""
def __init_subclass__(cls, **kwargs: Any) -> None: # type: ignore[explicit-any]
for method in {"acquire_sliding_window_entry", "get_sliding_window"}:
setattr(
cls,
method,
_wrap_errors(getattr(cls, method)),
)
super().__init_subclass__(**kwargs)
@abstractmethod
async def acquire_sliding_window_entry(
self,
key: str,
limit: int,
expiry: int,
amount: int = 1,
) -> bool:
"""
Acquire an entry if the weighted count of the current and previous
windows is less than or equal to the limit
:param key: rate limit key to acquire an entry in
:param limit: amount of entries allowed
:param expiry: expiry of the entry
:param amount: the number of entries to acquire
"""
raise NotImplementedError
@abstractmethod
async def get_sliding_window(
self, key: str, expiry: int
) -> tuple[int, float, int, float]:
"""
Return the previous and current window information.
:param key: the rate limit key
:param expiry: the rate limit expiry, needed to compute the key in some implementations
:return: a tuple of (int, float, int, float) with the following information:
- previous window counter
- previous window TTL
- current window counter
- current window TTL
"""
raise NotImplementedError
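
For illustration only, a dict-backed toy subclass that satisfies the Storage interface above; it performs no real expiry enforcement and is not part of this commit.

import time

from limits.aio.storage.base import Storage

class DictStorage(Storage):
    """Toy async storage backend: counters kept in a plain dict."""

    STORAGE_SCHEME = None  # not registered under any uri scheme

    def __init__(
        self,
        uri: str | None = None,
        wrap_exceptions: bool = False,
        **options: float | str | bool,
    ) -> None:
        super().__init__(uri, wrap_exceptions=wrap_exceptions, **options)
        self._counters: dict[str, int] = {}
        self._expiries: dict[str, float] = {}

    @property
    def base_exceptions(self) -> type[Exception]:
        return KeyError

    async def incr(self, key: str, expiry: int, amount: int = 1) -> int:
        self._counters[key] = self._counters.get(key, 0) + amount
        self._expiries.setdefault(key, time.time() + expiry)
        return self._counters[key]

    async def get(self, key: str) -> int:
        return self._counters.get(key, 0)

    async def get_expiry(self, key: str) -> float:
        return self._expiries.get(key, time.time())

    async def check(self) -> bool:
        return True

    async def reset(self) -> int | None:
        count = len(self._counters)
        self._counters.clear()
        self._expiries.clear()
        return count

    async def clear(self, key: str) -> None:
        self._counters.pop(key, None)
        self._expiries.pop(key, None)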


@ -0,0 +1,184 @@
from __future__ import annotations
import time
from math import floor
from deprecated.sphinx import versionadded, versionchanged
from packaging.version import Version
from limits.aio.storage import SlidingWindowCounterSupport, Storage
from limits.aio.storage.memcached.bridge import MemcachedBridge
from limits.aio.storage.memcached.emcache import EmcacheBridge
from limits.aio.storage.memcached.memcachio import MemcachioBridge
from limits.storage.base import TimestampedSlidingWindow
from limits.typing import Literal
@versionadded(version="2.1")
@versionchanged(
version="5.0",
reason="Switched default implementation to :pypi:`memcachio`",
)
class MemcachedStorage(Storage, SlidingWindowCounterSupport, TimestampedSlidingWindow):
"""
Rate limit storage with memcached as backend.
Depends on :pypi:`memcachio`
"""
STORAGE_SCHEME = ["async+memcached"]
"""The storage scheme for memcached to be used in an async context"""
DEPENDENCIES = {
"memcachio": Version("0.3"),
"emcache": Version("0.0"),
}
bridge: MemcachedBridge
storage_exceptions: tuple[Exception, ...]
def __init__(
self,
uri: str,
wrap_exceptions: bool = False,
implementation: Literal["memcachio", "emcache"] = "memcachio",
**options: float | str | bool,
) -> None:
"""
:param uri: memcached location of the form
``async+memcached://host:port,host:port``
:param wrap_exceptions: Whether to wrap storage exceptions in
:exc:`limits.errors.StorageError` before raising it.
:param implementation: Whether to use the client implementation from
- ``memcachio``: :class:`memcachio.Client`
- ``emcache``: :class:`emcache.Client`
:param options: all remaining keyword arguments are passed
directly to the constructor of :class:`memcachio.Client`
:raise ConfigurationError: when :pypi:`memcachio` is not available
"""
if implementation == "emcache":
self.bridge = EmcacheBridge(
uri, self.dependencies["emcache"].module, **options
)
else:
self.bridge = MemcachioBridge(
uri, self.dependencies["memcachio"].module, **options
)
super().__init__(uri, wrap_exceptions=wrap_exceptions, **options)
@property
def base_exceptions(
self,
) -> type[Exception] | tuple[type[Exception], ...]: # pragma: no cover
return self.bridge.base_exceptions
async def get(self, key: str) -> int:
"""
:param key: the key to get the counter value for
"""
return await self.bridge.get(key)
async def clear(self, key: str) -> None:
"""
:param key: the key to clear rate limits for
"""
await self.bridge.clear(key)
async def incr(
self,
key: str,
expiry: float,
amount: int = 1,
set_expiration_key: bool = True,
) -> int:
"""
increments the counter for a given rate limit key
:param key: the key to increment
:param expiry: amount in seconds for the key to expire in
:param amount: the number to increment by
:param set_expiration_key: if set to False, the expiration time won't be stored but the key will still expire
"""
return await self.bridge.incr(
key, expiry, amount, set_expiration_key=set_expiration_key
)
async def get_expiry(self, key: str) -> float:
"""
:param key: the key to get the expiry for
"""
return await self.bridge.get_expiry(key)
async def reset(self) -> int | None:
raise NotImplementedError
async def check(self) -> bool:
return await self.bridge.check()
async def acquire_sliding_window_entry(
self,
key: str,
limit: int,
expiry: int,
amount: int = 1,
) -> bool:
if amount > limit:
return False
now = time.time()
previous_key, current_key = self.sliding_window_keys(key, expiry, now)
(
previous_count,
previous_ttl,
current_count,
_,
) = await self._get_sliding_window_info(previous_key, current_key, expiry, now)
t0 = time.time()
weighted_count = previous_count * previous_ttl / expiry + current_count
if floor(weighted_count) + amount > limit:
return False
else:
# Hit, increase the current counter.
# If the counter doesn't exist yet, set twice the theoretical expiry.
# We don't need the expiration key as it is estimated with the timestamps directly.
current_count = await self.incr(
current_key, 2 * expiry, amount=amount, set_expiration_key=False
)
t1 = time.time()
actualised_previous_ttl = max(0, previous_ttl - (t1 - t0))
weighted_count = (
previous_count * actualised_previous_ttl / expiry + current_count
)
if floor(weighted_count) > limit:
# Another hit won the race condition: revert the increment and refuse this hit
# Limitation: during high concurrency at the end of the window,
# the counter is shifted and cannot be decremented, so fewer requests than expected are allowed.
await self.bridge.decr(current_key, amount, noreply=True)
return False
return True
async def get_sliding_window(
self, key: str, expiry: int
) -> tuple[int, float, int, float]:
now = time.time()
previous_key, current_key = self.sliding_window_keys(key, expiry, now)
return await self._get_sliding_window_info(
previous_key, current_key, expiry, now
)
async def _get_sliding_window_info(
self, previous_key: str, current_key: str, expiry: int, now: float
) -> tuple[int, float, int, float]:
result = await self.bridge.get_many([previous_key, current_key])
previous_count = result.get(previous_key.encode("utf-8"), 0)
current_count = result.get(current_key.encode("utf-8"), 0)
if previous_count == 0:
previous_ttl = float(0)
else:
previous_ttl = (1 - (((now - expiry) / expiry) % 1)) * expiry
current_ttl = (1 - ((now / expiry) % 1)) * expiry + expiry
return previous_count, previous_ttl, current_count, current_ttl
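
A worked instance of the weighted-count check implemented above, with invented numbers:

from math import floor

expiry, limit, amount = 60, 40, 1        # hypothetical limit: 40 hits per 60s window
previous_count, previous_ttl = 30, 45.0  # previous window still weighted in for 45s
current_count = 10

# 30 * 45/60 + 10 = 32.5; floor(32.5) + 1 = 33 <= 40, so the hit is allowed
weighted_count = previous_count * previous_ttl / expiry + current_count
assert floor(weighted_count) + amount <= limit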


@ -0,0 +1,73 @@
from __future__ import annotations
import urllib
from abc import ABC, abstractmethod
from types import ModuleType
from limits.typing import Iterable
class MemcachedBridge(ABC):
def __init__(
self,
uri: str,
dependency: ModuleType,
**options: float | str | bool,
) -> None:
self.uri = uri
self.parsed_uri = urllib.parse.urlparse(self.uri)
self.dependency = dependency
self.hosts = []
self.options = options
sep = self.parsed_uri.netloc.strip().find("@") + 1
for loc in self.parsed_uri.netloc.strip()[sep:].split(","):
host, port = loc.split(":")
self.hosts.append((host, int(port)))
if self.parsed_uri.username:
self.options["username"] = self.parsed_uri.username
if self.parsed_uri.password:
self.options["password"] = self.parsed_uri.password
def _expiration_key(self, key: str) -> str:
"""
Return the expiration key for the given counter key.
Memcached doesn't natively return the expiration time or TTL for a given key,
so we implement the expiration time on a separate key.
"""
return key + "/expires"
@property
@abstractmethod
def base_exceptions(
self,
) -> type[Exception] | tuple[type[Exception], ...]: ...
@abstractmethod
async def get(self, key: str) -> int: ...
@abstractmethod
async def get_many(self, keys: Iterable[str]) -> dict[bytes, int]: ...
@abstractmethod
async def clear(self, key: str) -> None: ...
@abstractmethod
async def decr(self, key: str, amount: int = 1, noreply: bool = False) -> int: ...
@abstractmethod
async def incr(
self,
key: str,
expiry: float,
amount: int = 1,
set_expiration_key: bool = True,
) -> int: ...
@abstractmethod
async def get_expiry(self, key: str) -> float: ...
@abstractmethod
async def check(self) -> bool: ...
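
A sketch of what the constructor's netloc parsing yields for a hypothetical multi-host URI with credentials:

import urllib.parse

parsed = urllib.parse.urlparse("async+memcached://user:secret@cache1:11211,cache2:11211")
sep = parsed.netloc.strip().find("@") + 1
hosts = [
    (host, int(port))
    for host, port in (loc.split(":") for loc in parsed.netloc.strip()[sep:].split(","))
]
print(hosts)                             # [('cache1', 11211), ('cache2', 11211)]
print(parsed.username, parsed.password)  # user secret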


@ -0,0 +1,112 @@
from __future__ import annotations
import time
from math import ceil
from types import ModuleType
from limits.typing import TYPE_CHECKING, Iterable
from .bridge import MemcachedBridge
if TYPE_CHECKING:
import emcache
class EmcacheBridge(MemcachedBridge):
def __init__(
self,
uri: str,
dependency: ModuleType,
**options: float | str | bool,
) -> None:
super().__init__(uri, dependency, **options)
self._storage = None
async def get_storage(self) -> emcache.Client:
if not self._storage:
self._storage = await self.dependency.create_client(
[self.dependency.MemcachedHostAddress(h, p) for h, p in self.hosts],
**self.options,
)
assert self._storage
return self._storage
async def get(self, key: str) -> int:
item = await (await self.get_storage()).get(key.encode("utf-8"))
return item and int(item.value) or 0
async def get_many(self, keys: Iterable[str]) -> dict[bytes, int]:
results = await (await self.get_storage()).get_many(
[k.encode("utf-8") for k in keys]
)
return {k: int(item.value) if item else 0 for k, item in results.items()}
async def clear(self, key: str) -> None:
try:
await (await self.get_storage()).delete(key.encode("utf-8"))
except self.dependency.NotFoundCommandError:
pass
async def decr(self, key: str, amount: int = 1, noreply: bool = False) -> int:
storage = await self.get_storage()
limit_key = key.encode("utf-8")
try:
value = await storage.decrement(limit_key, amount, noreply=noreply) or 0
except self.dependency.NotFoundCommandError:
value = 0
return value
async def incr(
self, key: str, expiry: float, amount: int = 1, set_expiration_key: bool = True
) -> int:
storage = await self.get_storage()
limit_key = key.encode("utf-8")
expire_key = self._expiration_key(key).encode()
try:
return await storage.increment(limit_key, amount) or amount
except self.dependency.NotFoundCommandError:
storage = await self.get_storage()
try:
await storage.add(limit_key, f"{amount}".encode(), exptime=ceil(expiry))
if set_expiration_key:
await storage.set(
expire_key,
str(expiry + time.time()).encode("utf-8"),
exptime=ceil(expiry),
noreply=False,
)
value = amount
except self.dependency.NotStoredStorageCommandError:
# Could not add the key, probably because a concurrent call has added it
storage = await self.get_storage()
value = await storage.increment(limit_key, amount) or amount
return value
async def get_expiry(self, key: str) -> float:
storage = await self.get_storage()
item = await storage.get(self._expiration_key(key).encode("utf-8"))
return item and float(item.value) or time.time()
@property
def base_exceptions(
self,
) -> type[Exception] | tuple[type[Exception], ...]: # pragma: no cover
return (
self.dependency.ClusterNoAvailableNodes,
self.dependency.CommandError,
)
async def check(self) -> bool:
"""
Check if storage is healthy by calling the ``get`` command
on the key ``limiter-check``
"""
try:
storage = await self.get_storage()
await storage.get(b"limiter-check")
return True
except: # noqa
return False
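
The incr above is an instance of memcached's initialize-or-increment race pattern, restated generically below. The client and exception names here are hypothetical stand-ins, not the emcache API.

async def incr_with_expiry(client, key: bytes, expiry: int, amount: int = 1) -> int:
    try:
        return await client.increment(key, amount)  # fails if the key is absent
    except KeyError:  # stand-in for NotFoundCommandError
        try:
            # first writer creates the key with its TTL
            await client.add(key, str(amount).encode(), exptime=expiry)
            return amount
        except ValueError:  # stand-in for NotStoredStorageCommandError
            # a concurrent call created the key between increment and add
            return await client.increment(key, amount)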


@ -0,0 +1,104 @@
from __future__ import annotations
import time
from math import ceil
from types import ModuleType
from typing import TYPE_CHECKING, Iterable
from .bridge import MemcachedBridge
if TYPE_CHECKING:
import memcachio
class MemcachioBridge(MemcachedBridge):
def __init__(
self,
uri: str,
dependency: ModuleType,
**options: float | str | bool,
) -> None:
super().__init__(uri, dependency, **options)
self._storage: memcachio.Client[bytes] | None = None
@property
def base_exceptions(
self,
) -> type[Exception] | tuple[type[Exception], ...]:
return (
self.dependency.errors.NoAvailableNodes,
self.dependency.errors.MemcachioConnectionError,
)
async def get_storage(self) -> memcachio.Client[bytes]:
if not self._storage:
self._storage = self.dependency.Client(
[(h, p) for h, p in self.hosts],
**self.options,
)
assert self._storage
return self._storage
async def get(self, key: str) -> int:
return (await self.get_many([key])).get(key.encode("utf-8"), 0)
async def get_many(self, keys: Iterable[str]) -> dict[bytes, int]:
"""
Return multiple counters at once
:param keys: the keys to get the counter values for
"""
results = await (await self.get_storage()).get(
*[k.encode("utf-8") for k in keys]
)
return {k: int(v.value) for k, v in results.items()}
async def clear(self, key: str) -> None:
await (await self.get_storage()).delete(key.encode("utf-8"))
async def decr(self, key: str, amount: int = 1, noreply: bool = False) -> int:
storage = await self.get_storage()
limit_key = key.encode("utf-8")
return await storage.decr(limit_key, amount, noreply=noreply) or 0
async def incr(
self, key: str, expiry: float, amount: int = 1, set_expiration_key: bool = True
) -> int:
storage = await self.get_storage()
limit_key = key.encode("utf-8")
expire_key = self._expiration_key(key).encode()
if (value := (await storage.incr(limit_key, amount))) is None:
storage = await self.get_storage()
if await storage.add(limit_key, f"{amount}".encode(), expiry=ceil(expiry)):
if set_expiration_key:
await storage.set(
expire_key,
str(expiry + time.time()).encode("utf-8"),
expiry=ceil(expiry),
noreply=False,
)
return amount
else:
storage = await self.get_storage()
return await storage.incr(limit_key, amount) or amount
return value
async def get_expiry(self, key: str) -> float:
storage = await self.get_storage()
expiration_key = self._expiration_key(key).encode("utf-8")
item = (await storage.get(expiration_key)).get(expiration_key, None)
return item and float(item.value) or time.time()
async def check(self) -> bool:
"""
Check if storage is healthy by calling the ``get`` command
on the key ``limiter-check``
"""
try:
storage = await self.get_storage()
await storage.get(b"limiter-check")
return True
except: # noqa
return False


@ -0,0 +1,281 @@
from __future__ import annotations
import asyncio
import bisect
import time
from collections import Counter, defaultdict
from math import floor
from deprecated.sphinx import versionadded
import limits.typing
from limits.aio.storage.base import (
MovingWindowSupport,
SlidingWindowCounterSupport,
Storage,
)
from limits.storage.base import TimestampedSlidingWindow
class Entry:
def __init__(self, expiry: int) -> None:
self.atime = time.time()
self.expiry = self.atime + expiry
@versionadded(version="2.1")
class MemoryStorage(
Storage, MovingWindowSupport, SlidingWindowCounterSupport, TimestampedSlidingWindow
):
"""
rate limit storage using :class:`collections.Counter`
as an in memory storage for fixed & sliding window strategies,
and a simple list to implement moving window strategy.
"""
STORAGE_SCHEME = ["async+memory"]
"""
The storage scheme for in process memory storage for use in an
async context
"""
def __init__(
self, uri: str | None = None, wrap_exceptions: bool = False, **_: str
) -> None:
self.storage: limits.typing.Counter[str] = Counter()
self.locks: defaultdict[str, asyncio.Lock] = defaultdict(asyncio.Lock)
self.expirations: dict[str, float] = {}
self.events: dict[str, list[Entry]] = {}
self.timer: asyncio.Task[None] | None = None
super().__init__(uri, wrap_exceptions=wrap_exceptions, **_)
def __getstate__(self) -> dict[str, limits.typing.Any]: # type: ignore[explicit-any]
state = self.__dict__.copy()
del state["timer"]
del state["locks"]
return state
def __setstate__(self, state: dict[str, limits.typing.Any]) -> None: # type: ignore[explicit-any]
self.__dict__.update(state)
self.timer = None
self.locks = defaultdict(asyncio.Lock)
asyncio.ensure_future(self.__schedule_expiry())
async def __expire_events(self) -> None:
try:
now = time.time()
for key in list(self.events.keys()):
cutoff = await asyncio.to_thread(
lambda evts: bisect.bisect_left(
evts, -now, key=lambda event: -event.expiry
),
self.events[key],
)
async with self.locks[key]:
if self.events.get(key, []):
self.events[key] = self.events[key][:cutoff]
if not self.events.get(key, None):
self.events.pop(key, None)
self.locks.pop(key, None)
for key in list(self.expirations.keys()):
if self.expirations[key] <= time.time():
self.storage.pop(key, None)
self.expirations.pop(key, None)
self.locks.pop(key, None)
except asyncio.CancelledError:
return
async def __schedule_expiry(self) -> None:
if not self.timer or self.timer.done():
self.timer = asyncio.create_task(self.__expire_events())
@property
def base_exceptions(
self,
) -> type[Exception] | tuple[type[Exception], ...]: # pragma: no cover
return ValueError
async def incr(self, key: str, expiry: float, amount: int = 1) -> int:
"""
increments the counter for a given rate limit key
:param key: the key to increment
:param expiry: amount in seconds for the key to expire in
:param amount: the number to increment by
"""
await self.get(key)
await self.__schedule_expiry()
async with self.locks[key]:
self.storage[key] += amount
if self.storage[key] == amount:
self.expirations[key] = time.time() + expiry
return self.storage.get(key, amount)
async def decr(self, key: str, amount: int = 1) -> int:
"""
decrements the counter for a given rate limit key. 0 is the minimum allowed value.
:param key: the key to decrement
:param amount: the number to decrement by
"""
await self.get(key)
await self.__schedule_expiry()
async with self.locks[key]:
self.storage[key] = max(self.storage[key] - amount, 0)
return self.storage.get(key, amount)
async def get(self, key: str) -> int:
"""
:param key: the key to get the counter value for
"""
if self.expirations.get(key, 0) <= time.time():
self.storage.pop(key, None)
self.expirations.pop(key, None)
self.locks.pop(key, None)
return self.storage.get(key, 0)
async def clear(self, key: str) -> None:
"""
:param key: the key to clear rate limits for
"""
self.storage.pop(key, None)
self.expirations.pop(key, None)
self.events.pop(key, None)
self.locks.pop(key, None)
async def acquire_entry(
self, key: str, limit: int, expiry: int, amount: int = 1
) -> bool:
"""
:param key: rate limit key to acquire an entry in
:param limit: amount of entries allowed
:param expiry: expiry of the entry
:param amount: the number of entries to acquire
"""
if amount > limit:
return False
await self.__schedule_expiry()
async with self.locks[key]:
self.events.setdefault(key, [])
timestamp = time.time()
try:
entry: Entry | None = self.events[key][limit - amount]
except IndexError:
entry = None
if entry and entry.atime >= timestamp - expiry:
return False
else:
self.events[key][:0] = [Entry(expiry)] * amount
return True
async def get_expiry(self, key: str) -> float:
"""
:param key: the key to get the expiry for
"""
return self.expirations.get(key, time.time())
async def get_moving_window(
self, key: str, limit: int, expiry: int
) -> tuple[float, int]:
"""
returns the starting point and the number of entries in the moving
window
:param key: rate limit key
:param expiry: expiry of entry
:return: (start of window, number of acquired entries)
"""
timestamp = time.time()
if events := self.events.get(key, []):
oldest = bisect.bisect_left(
events, -(timestamp - expiry), key=lambda entry: -entry.atime
)
return events[oldest - 1].atime, oldest
return timestamp, 0
async def acquire_sliding_window_entry(
self,
key: str,
limit: int,
expiry: int,
amount: int = 1,
) -> bool:
if amount > limit:
return False
now = time.time()
previous_key, current_key = self.sliding_window_keys(key, expiry, now)
(
previous_count,
previous_ttl,
current_count,
_,
) = await self._get_sliding_window_info(previous_key, current_key, expiry, now)
weighted_count = previous_count * previous_ttl / expiry + current_count
if floor(weighted_count) + amount > limit:
return False
else:
# Hit, increase the current counter.
# If the counter doesn't exist yet, set twice the theoretical expiry.
current_count = await self.incr(current_key, 2 * expiry, amount=amount)
weighted_count = previous_count * previous_ttl / expiry + current_count
if floor(weighted_count) > limit:
# Another hit won the race condition: revert the increment and refuse this hit
# Limitation: during high concurrency at the end of the window,
# the counter is shifted and cannot be decremented, so fewer requests than expected are allowed.
await self.decr(current_key, amount)
return False
return True
async def get_sliding_window(
self, key: str, expiry: int
) -> tuple[int, float, int, float]:
now = time.time()
previous_key, current_key = self.sliding_window_keys(key, expiry, now)
return await self._get_sliding_window_info(
previous_key, current_key, expiry, now
)
async def _get_sliding_window_info(
self,
previous_key: str,
current_key: str,
expiry: int,
now: float,
) -> tuple[int, float, int, float]:
previous_count = await self.get(previous_key)
current_count = await self.get(current_key)
if previous_count == 0:
previous_ttl = float(0)
else:
previous_ttl = (1 - (((now - expiry) / expiry) % 1)) * expiry
current_ttl = (1 - ((now / expiry) % 1)) * expiry + expiry
return previous_count, previous_ttl, current_count, current_ttl
async def check(self) -> bool:
"""
check if storage is healthy
"""
return True
async def reset(self) -> int | None:
num_items = max(len(self.storage), len(self.events))
self.storage.clear()
self.expirations.clear()
self.events.clear()
self.locks.clear()
return num_items
def __del__(self) -> None:
try:
if self.timer and not self.timer.done():
self.timer.cancel()
except RuntimeError: # noqa
pass
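
The newest-first events list is what lets acquire_entry and get_moving_window binary-search on negated timestamps; a standalone sketch of the trick (the key argument to bisect requires Python 3.10+):

import bisect
import time

class FakeEntry:
    def __init__(self, atime: float) -> None:
        self.atime = atime

now, expiry = time.time(), 60
# entries are kept newest-first, so -atime forms an increasing sequence
events = [FakeEntry(now - 5), FakeEntry(now - 30), FakeEntry(now - 90)]
oldest = bisect.bisect_left(events, -(now - expiry), key=lambda e: -e.atime)
print(oldest)                    # 2 entries fall inside the 60s window
print(events[oldest - 1].atime)  # atime of the oldest in-window entry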


@ -0,0 +1,517 @@
from __future__ import annotations
import asyncio
import datetime
import time
from deprecated.sphinx import versionadded, versionchanged
from limits.aio.storage.base import (
MovingWindowSupport,
SlidingWindowCounterSupport,
Storage,
)
from limits.typing import (
ParamSpec,
TypeVar,
cast,
)
from limits.util import get_dependency
P = ParamSpec("P")
R = TypeVar("R")
@versionadded(version="2.1")
@versionchanged(
version="3.14.0",
reason="Added option to select custom collection names for windows & counters",
)
class MongoDBStorage(Storage, MovingWindowSupport, SlidingWindowCounterSupport):
"""
Rate limit storage with MongoDB as backend.
Depends on :pypi:`motor`
"""
STORAGE_SCHEME = ["async+mongodb", "async+mongodb+srv"]
"""
The storage scheme for MongoDB for use in an async context
"""
DEPENDENCIES = ["motor.motor_asyncio", "pymongo"]
def __init__(
self,
uri: str,
database_name: str = "limits",
counter_collection_name: str = "counters",
window_collection_name: str = "windows",
wrap_exceptions: bool = False,
**options: float | str | bool,
) -> None:
"""
:param uri: uri of the form ``async+mongodb://[user:password]@host:port?...``,
This uri is passed directly to :class:`~motor.motor_asyncio.AsyncIOMotorClient`
:param database_name: The database to use for storing the rate limit
collections.
:param counter_collection_name: The collection name to use for individual counters
used in fixed window strategies
:param window_collection_name: The collection name to use for sliding & moving window
storage
:param wrap_exceptions: Whether to wrap storage exceptions in
:exc:`limits.errors.StorageError` before raising it.
:param options: all remaining keyword arguments are passed
to the constructor of :class:`~motor.motor_asyncio.AsyncIOMotorClient`
:raise ConfigurationError: when the :pypi:`motor` or :pypi:`pymongo` are
not available
"""
uri = uri.replace("async+mongodb", "mongodb", 1)
super().__init__(uri, wrap_exceptions=wrap_exceptions, **options)
self.dependency = self.dependencies["motor.motor_asyncio"]
self.proxy_dependency = self.dependencies["pymongo"]
self.lib_errors, _ = get_dependency("pymongo.errors")
self.storage = self.dependency.module.AsyncIOMotorClient(uri, **options)
# TODO: Fix this hack. It was noticed when running a benchmark
# with FastAPI - however - doesn't appear in unit tests or in an isolated
# use. Reference: https://jira.mongodb.org/browse/MOTOR-822
self.storage.get_io_loop = asyncio.get_running_loop
self.__database_name = database_name
self.__collection_mapping = {
"counters": counter_collection_name,
"windows": window_collection_name,
}
self.__indices_created = False
@property
def base_exceptions(
self,
) -> type[Exception] | tuple[type[Exception], ...]: # pragma: no cover
return self.lib_errors.PyMongoError # type: ignore
@property
def database(self): # type: ignore
return self.storage.get_database(self.__database_name)
async def create_indices(self) -> None:
if not self.__indices_created:
await asyncio.gather(
self.database[self.__collection_mapping["counters"]].create_index(
"expireAt", expireAfterSeconds=0
),
self.database[self.__collection_mapping["windows"]].create_index(
"expireAt", expireAfterSeconds=0
),
)
self.__indices_created = True
async def reset(self) -> int | None:
"""
Delete all rate limit keys in the rate limit collections (counters, windows)
"""
num_keys = sum(
await asyncio.gather(
self.database[self.__collection_mapping["counters"]].count_documents(
{}
),
self.database[self.__collection_mapping["windows"]].count_documents({}),
)
)
await asyncio.gather(
self.database[self.__collection_mapping["counters"]].drop(),
self.database[self.__collection_mapping["windows"]].drop(),
)
return cast(int, num_keys)
async def clear(self, key: str) -> None:
"""
:param key: the key to clear rate limits for
"""
await asyncio.gather(
self.database[self.__collection_mapping["counters"]].find_one_and_delete(
{"_id": key}
),
self.database[self.__collection_mapping["windows"]].find_one_and_delete(
{"_id": key}
),
)
async def get_expiry(self, key: str) -> float:
"""
:param key: the key to get the expiry for
"""
counter = await self.database[self.__collection_mapping["counters"]].find_one(
{"_id": key}
)
return (
(counter["expireAt"] if counter else datetime.datetime.now())
.replace(tzinfo=datetime.timezone.utc)
.timestamp()
)
async def get(self, key: str) -> int:
"""
:param key: the key to get the counter value for
"""
counter = await self.database[self.__collection_mapping["counters"]].find_one(
{
"_id": key,
"expireAt": {"$gte": datetime.datetime.now(datetime.timezone.utc)},
},
projection=["count"],
)
return counter and counter["count"] or 0
async def incr(self, key: str, expiry: int, amount: int = 1) -> int:
"""
increments the counter for a given rate limit key
:param key: the key to increment
:param expiry: amount in seconds for the key to expire in
:param amount: the number to increment by
"""
await self.create_indices()
expiration = datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(
seconds=expiry
)
response = await self.database[
self.__collection_mapping["counters"]
].find_one_and_update(
{"_id": key},
[
{
"$set": {
"count": {
"$cond": {
"if": {"$lt": ["$expireAt", "$$NOW"]},
"then": amount,
"else": {"$add": ["$count", amount]},
}
},
"expireAt": {
"$cond": {
"if": {"$lt": ["$expireAt", "$$NOW"]},
"then": expiration,
"else": "$expireAt",
}
},
}
},
],
upsert=True,
projection=["count"],
return_document=self.proxy_dependency.module.ReturnDocument.AFTER,
)
return int(response["count"])
async def check(self) -> bool:
"""
Check if storage is healthy by calling
:meth:`motor.motor_asyncio.AsyncIOMotorClient.server_info`
"""
try:
await self.storage.server_info()
return True
except: # noqa: E722
return False
async def get_moving_window(
self, key: str, limit: int, expiry: int
) -> tuple[float, int]:
"""
returns the starting point and the number of entries in the moving
window
:param key: rate limit key
:param expiry: expiry of entry
:return: (start of window, number of acquired entries)
"""
timestamp = time.time()
if (
result := await self.database[self.__collection_mapping["windows"]]
.aggregate(
[
{"$match": {"_id": key}},
{
"$project": {
"filteredEntries": {
"$filter": {
"input": "$entries",
"as": "entry",
"cond": {"$gte": ["$$entry", timestamp - expiry]},
}
}
}
},
{
"$project": {
"min": {"$min": "$filteredEntries"},
"count": {"$size": "$filteredEntries"},
}
},
]
)
.to_list(length=1)
):
return result[0]["min"], result[0]["count"]
return timestamp, 0
async def acquire_entry(
self, key: str, limit: int, expiry: int, amount: int = 1
) -> bool:
"""
:param key: rate limit key to acquire an entry in
:param limit: amount of entries allowed
:param expiry: expiry of the entry
:param amount: the number of entries to acquire
"""
await self.create_indices()
if amount > limit:
return False
timestamp = time.time()
try:
updates: dict[
str,
dict[str, datetime.datetime | dict[str, list[float] | int]],
] = {
"$push": {
"entries": {
"$each": [timestamp] * amount,
"$position": 0,
"$slice": limit,
}
},
"$set": {
"expireAt": (
datetime.datetime.now(datetime.timezone.utc)
+ datetime.timedelta(seconds=expiry)
)
},
}
await self.database[self.__collection_mapping["windows"]].update_one(
{
"_id": key,
f"entries.{limit - amount}": {"$not": {"$gte": timestamp - expiry}},
},
updates,
upsert=True,
)
return True
except self.proxy_dependency.module.errors.DuplicateKeyError:
return False
async def acquire_sliding_window_entry(
self, key: str, limit: int, expiry: int, amount: int = 1
) -> bool:
await self.create_indices()
expiry_ms = expiry * 1000
result = await self.database[
self.__collection_mapping["windows"]
].find_one_and_update(
{"_id": key},
[
{
"$set": {
"previousCount": {
"$cond": {
"if": {
"$lte": [
{"$subtract": ["$expireAt", "$$NOW"]},
expiry_ms,
]
},
"then": {"$ifNull": ["$currentCount", 0]},
"else": {"$ifNull": ["$previousCount", 0]},
}
},
}
},
{
"$set": {
"currentCount": {
"$cond": {
"if": {
"$lte": [
{"$subtract": ["$expireAt", "$$NOW"]},
expiry_ms,
]
},
"then": 0,
"else": {"$ifNull": ["$currentCount", 0]},
}
},
"expireAt": {
"$cond": {
"if": {
"$lte": [
{"$subtract": ["$expireAt", "$$NOW"]},
expiry_ms,
]
},
"then": {
"$cond": {
"if": {"$gt": ["$expireAt", 0]},
"then": {"$add": ["$expireAt", expiry_ms]},
"else": {"$add": ["$$NOW", 2 * expiry_ms]},
}
},
"else": "$expireAt",
}
},
}
},
{
"$set": {
"curWeightedCount": {
"$floor": {
"$add": [
{
"$multiply": [
"$previousCount",
{
"$divide": [
{
"$max": [
0,
{
"$subtract": [
"$expireAt",
{
"$add": [
"$$NOW",
expiry_ms,
]
},
]
},
]
},
expiry_ms,
]
},
]
},
"$currentCount",
]
}
}
}
},
{
"$set": {
"currentCount": {
"$cond": {
"if": {
"$lte": [
{"$add": ["$curWeightedCount", amount]},
limit,
]
},
"then": {"$add": ["$currentCount", amount]},
"else": "$currentCount",
}
}
}
},
{
"$set": {
"_acquired": {
"$lte": [{"$add": ["$curWeightedCount", amount]}, limit]
}
}
},
{"$unset": ["curWeightedCount"]},
],
return_document=self.proxy_dependency.module.ReturnDocument.AFTER,
upsert=True,
)
return cast(bool, result["_acquired"])
async def get_sliding_window(
self, key: str, expiry: int
) -> tuple[int, float, int, float]:
expiry_ms = expiry * 1000
if result := await self.database[
self.__collection_mapping["windows"]
].find_one_and_update(
{"_id": key},
[
{
"$set": {
"previousCount": {
"$cond": {
"if": {
"$lte": [
{"$subtract": ["$expireAt", "$$NOW"]},
expiry_ms,
]
},
"then": {"$ifNull": ["$currentCount", 0]},
"else": {"$ifNull": ["$previousCount", 0]},
}
},
"currentCount": {
"$cond": {
"if": {
"$lte": [
{"$subtract": ["$expireAt", "$$NOW"]},
expiry_ms,
]
},
"then": 0,
"else": {"$ifNull": ["$currentCount", 0]},
}
},
"expireAt": {
"$cond": {
"if": {
"$lte": [
{"$subtract": ["$expireAt", "$$NOW"]},
expiry_ms,
]
},
"then": {"$add": ["$expireAt", expiry_ms]},
"else": "$expireAt",
}
},
}
}
],
return_document=self.proxy_dependency.module.ReturnDocument.AFTER,
projection=["currentCount", "previousCount", "expireAt"],
):
expires_at = (
(result["expireAt"].replace(tzinfo=datetime.timezone.utc).timestamp())
if result.get("expireAt")
else time.time()
)
current_ttl = max(0, expires_at - time.time())
prev_ttl = max(0, current_ttl - expiry if result["previousCount"] else 0)
return (
result["previousCount"],
prev_ttl,
result["currentCount"],
current_ttl,
)
return 0, 0.0, 0, 0.0
def __del__(self) -> None:
self.storage and self.storage.close()
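
For readability, the window bookkeeping performed by the aggregation pipelines above, restated as plain Python; this is a sketch of the same arithmetic, with all times in milliseconds and expire_at = 0 standing in for a missing document.

from math import floor

def shift_window(expire_at, now, expiry_ms, current, previous):
    # Less than one full period left: the current window has aged out,
    # so roll it into the previous slot and extend (or initialise) expire_at.
    if expire_at - now <= expiry_ms:
        previous, current = current, 0
        expire_at = expire_at + expiry_ms if expire_at > 0 else now + 2 * expiry_ms
    return expire_at, current, previous

def weighted_count(expire_at, now, expiry_ms, current, previous):
    # the previous window's TTL is whatever of expire_at extends beyond now + one period
    previous_ttl = max(0, expire_at - (now + expiry_ms))
    return floor(previous * previous_ttl / expiry_ms + current)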


@ -0,0 +1,400 @@
from __future__ import annotations
from deprecated.sphinx import versionadded, versionchanged
from packaging.version import Version
from limits.aio.storage import MovingWindowSupport, SlidingWindowCounterSupport, Storage
from limits.aio.storage.redis.bridge import RedisBridge
from limits.aio.storage.redis.coredis import CoredisBridge
from limits.aio.storage.redis.redispy import RedispyBridge
from limits.aio.storage.redis.valkey import ValkeyBridge
from limits.typing import Literal
@versionadded(version="2.1")
@versionchanged(
version="4.2",
reason=(
"Added support for using the asyncio redis client from :pypi:`redis`"
" through :paramref:`implementation`"
),
)
@versionchanged(
version="4.3",
reason=(
"Added support for using the asyncio redis client from :pypi:`valkey`"
" through :paramref:`implementation` or if :paramref:`uri` has the"
" ``async+valkey`` schema"
),
)
class RedisStorage(Storage, MovingWindowSupport, SlidingWindowCounterSupport):
"""
Rate limit storage with redis as backend.
Depends on :pypi:`coredis` or :pypi:`redis`
"""
STORAGE_SCHEME = [
"async+redis",
"async+rediss",
"async+redis+unix",
"async+valkey",
"async+valkeys",
"async+valkey+unix",
]
"""
The storage schemes for redis to be used in an async context
"""
DEPENDENCIES = {
"redis": Version("5.2.0"),
"coredis": Version("3.4.0"),
"valkey": Version("6.0"),
}
MODE: Literal["BASIC", "CLUSTER", "SENTINEL"] = "BASIC"
bridge: RedisBridge
storage_exceptions: tuple[Exception, ...]
target_server: Literal["redis", "valkey"]
def __init__(
self,
uri: str,
wrap_exceptions: bool = False,
implementation: Literal["redispy", "coredis", "valkey"] = "coredis",
**options: float | str | bool,
) -> None:
"""
:param uri: uri of the form:
- ``async+redis://[:password]@host:port``
- ``async+redis://[:password]@host:port/db``
- ``async+rediss://[:password]@host:port``
- ``async+redis+unix:///path/to/sock?db=0`` etc...
This uri is passed directly to :meth:`coredis.Redis.from_url` or
:meth:`redis.asyncio.client.Redis.from_url` with the initial ``async`` removed,
except for the case of ``async+redis+unix`` where it is replaced with ``unix``.
If the uri scheme is ``async+valkey`` the implementation used will be from
:pypi:`valkey`.
:param connection_pool: if provided, the redis client is initialized with
the connection pool and any other params passed as :paramref:`options`
:param wrap_exceptions: Whether to wrap storage exceptions in
:exc:`limits.errors.StorageError` before raising it.
:param implementation: Whether to use the client implementation from
- ``coredis``: :class:`coredis.Redis`
- ``redispy``: :class:`redis.asyncio.client.Redis`
- ``valkey``: :class:`valkey.asyncio.client.Valkey`
:param options: all remaining keyword arguments are passed
directly to the constructor of :class:`coredis.Redis` or :class:`redis.asyncio.client.Redis`
:raise ConfigurationError: when the redis library is not available
"""
uri = uri.removeprefix("async+")
self.target_server = "redis" if uri.startswith("redis") else "valkey"
uri = uri.replace(f"{self.target_server}+unix", "unix")
super().__init__(uri, wrap_exceptions=wrap_exceptions)
self.options = options
if self.target_server == "valkey" or implementation == "valkey":
self.bridge = ValkeyBridge(uri, self.dependencies["valkey"].module)
else:
if implementation == "redispy":
self.bridge = RedispyBridge(uri, self.dependencies["redis"].module)
else:
self.bridge = CoredisBridge(uri, self.dependencies["coredis"].module)
self.configure_bridge()
self.bridge.register_scripts()
def _current_window_key(self, key: str) -> str:
"""
Return the current window's storage key (Sliding window strategy)
Contrary to other strategies that have one key per rate limit item,
this strategy has two keys per rate limit item that must be stored on the same machine.
To keep the current key and the previous key on the same Redis cluster node,
curly braces are added around the key.
e.g. "{constructed_key}"
"""
return f"{{{key}}}"
def _previous_window_key(self, key: str) -> str:
"""
Return the previous window's storage key (Sliding window strategy).
Curly braces are added around the pattern shared with the current window's key,
so the current and the previous key are stored on the same Redis cluster node.
e.g. "{constructed_key}/-1"
"""
return f"{self._current_window_key(key)}/-1"
def configure_bridge(self) -> None:
self.bridge.use_basic(**self.options)
@property
def base_exceptions(
self,
) -> type[Exception] | tuple[type[Exception], ...]: # pragma: no cover
return self.bridge.base_exceptions
async def incr(self, key: str, expiry: int, amount: int = 1) -> int:
"""
increments the counter for a given rate limit key
:param key: the key to increment
:param expiry: amount in seconds for the key to expire in
:param amount: the number to increment by
"""
return await self.bridge.incr(key, expiry, amount)
async def get(self, key: str) -> int:
"""
:param key: the key to get the counter value for
"""
return await self.bridge.get(key)
async def clear(self, key: str) -> None:
"""
:param key: the key to clear rate limits for
"""
return await self.bridge.clear(key)
async def acquire_entry(
self, key: str, limit: int, expiry: int, amount: int = 1
) -> bool:
"""
:param key: rate limit key to acquire an entry in
:param limit: amount of entries allowed
:param expiry: expiry of the entry
:param amount: the number of entries to acquire
"""
return await self.bridge.acquire_entry(key, limit, expiry, amount)
async def get_moving_window(
self, key: str, limit: int, expiry: int
) -> tuple[float, int]:
"""
returns the starting point and the number of entries in the moving
window
:param key: rate limit key
:param expiry: expiry of entry
:return: (start of window, number of acquired entries)
"""
return await self.bridge.get_moving_window(key, limit, expiry)
async def acquire_sliding_window_entry(
self,
key: str,
limit: int,
expiry: int,
amount: int = 1,
) -> bool:
current_key = self._current_window_key(key)
previous_key = self._previous_window_key(key)
return await self.bridge.acquire_sliding_window_entry(
previous_key, current_key, limit, expiry, amount
)
async def get_sliding_window(
self, key: str, expiry: int
) -> tuple[int, float, int, float]:
previous_key = self._previous_window_key(key)
current_key = self._current_window_key(key)
return await self.bridge.get_sliding_window(previous_key, current_key, expiry)
async def get_expiry(self, key: str) -> float:
"""
:param key: the key to get the expiry for
"""
return await self.bridge.get_expiry(key)
async def check(self) -> bool:
"""
Check if storage is healthy by calling ``PING``
"""
return await self.bridge.check()
async def reset(self) -> int | None:
"""
This function calls a Lua Script to delete keys prefixed with
``self.PREFIX`` in blocks of 5000.
.. warning:: This operation was designed to be fast, but has not been tested
on a large production system. Be careful with its usage as it
could be slow on very large data sets.
"""
return await self.bridge.lua_reset()
@versionadded(version="2.1")
@versionchanged(
version="4.2",
reason="Added support for using the asyncio redis client from :pypi:`redis` ",
)
@versionchanged(
version="4.3",
reason=(
"Added support for using the asyncio redis client from :pypi:`valkey`"
" through :paramref:`implementation` or if :paramref:`uri` has the"
" ``async+valkey+cluster`` schema"
),
)
class RedisClusterStorage(RedisStorage):
"""
Rate limit storage with redis cluster as backend
Depends on :pypi:`coredis` or :pypi:`redis`
"""
STORAGE_SCHEME = ["async+redis+cluster", "async+valkey+cluster"]
"""
The storage schemes for redis cluster to be used in an async context
"""
MODE = "CLUSTER"
def __init__(
self,
uri: str,
wrap_exceptions: bool = False,
implementation: Literal["redispy", "coredis", "valkey"] = "coredis",
**options: float | str | bool,
) -> None:
"""
:param uri: url of the form
``async+redis+cluster://[:password]@host:port,host:port``
If the uri scheme is ``async+valkey+cluster`` the implementation used will be from
:pypi:`valkey`.
:param wrap_exceptions: Whether to wrap storage exceptions in
:exc:`limits.errors.StorageError` before raising it.
:param implementation: Whether to use the client implementation from
- ``coredis``: :class:`coredis.RedisCluster`
- ``redispy``: :class:`redis.asyncio.cluster.RedisCluster`
- ``valkey``: :class:`valkey.asyncio.cluster.ValkeyCluster`
:param options: all remaining keyword arguments are passed
directly to the constructor of :class:`coredis.RedisCluster` or
:class:`redis.asyncio.RedisCluster`
:raise ConfigurationError: when the redis library is not
available or if the redis host cannot be pinged.
"""
super().__init__(
uri,
wrap_exceptions=wrap_exceptions,
implementation=implementation,
**options,
)
def configure_bridge(self) -> None:
self.bridge.use_cluster(**self.options)
async def reset(self) -> int | None:
"""
Redis Clusters are sharded and deleting across shards
can't be done atomically. Because of this, this reset loops over all
keys that are prefixed with ``self.PREFIX`` and calls delete on them,
one at a time.
.. warning:: This operation was not tested with extremely large data sets.
On a large production system, care should be taken with its
usage as it could be slow on very large data sets.
"""
return await self.bridge.reset()
@versionadded(version="2.1")
@versionchanged(
version="4.2",
reason="Added support for using the asyncio redis client from :pypi:`redis` ",
)
@versionchanged(
version="4.3",
reason=(
"Added support for using the asyncio redis client from :pypi:`valkey`"
" through :paramref:`implementation` or if :paramref:`uri` has the"
" ``async+valkey+sentinel`` schema"
),
)
class RedisSentinelStorage(RedisStorage):
"""
Rate limit storage with redis sentinel as backend
Depends on :pypi:`coredis` or :pypi:`redis`
"""
STORAGE_SCHEME = [
"async+redis+sentinel",
"async+valkey+sentinel",
]
"""The storage scheme for redis accessed via a redis sentinel installation"""
MODE = "SENTINEL"
DEPENDENCIES = {
"redis": Version("5.2.0"),
"coredis": Version("3.4.0"),
"coredis.sentinel": Version("3.4.0"),
"valkey": Version("6.0"),
}
def __init__(
self,
uri: str,
wrap_exceptions: bool = False,
implementation: Literal["redispy", "coredis", "valkey"] = "coredis",
service_name: str | None = None,
use_replicas: bool = True,
sentinel_kwargs: dict[str, float | str | bool] | None = None,
**options: float | str | bool,
):
"""
:param uri: url of the form
``async+redis+sentinel://host:port,host:port/service_name``
If the uri schema is ``async+valkey+sentinel`` the implementation used will be from
:pypi:`valkey`.
:param wrap_exceptions: Whether to wrap storage exceptions in
:exc:`limits.errors.StorageError` before raising it.
:param implementation: Whether to use the client implementation from
- ``coredis``: :class:`coredis.sentinel.Sentinel`
- ``redispy``: :class:`redis.asyncio.sentinel.Sentinel`
- ``valkey``: :class:`valkey.asyncio.sentinel.Sentinel`
:param service_name: sentinel service name (if not provided in `uri`)
:param use_replicas: Whether to use replicas for read only operations
:param sentinel_kwargs: optional arguments to pass as
``sentinel_kwargs`` to :class:`coredis.sentinel.Sentinel` or
:class:`redis.asyncio.Sentinel`
:param options: all remaining keyword arguments are passed
directly to the constructor of :class:`coredis.sentinel.Sentinel` or
:class:`redis.asyncio.sentinel.Sentinel`
:raise ConfigurationError: when the redis library is not available
or if the redis primary host cannot be pinged.
"""
self.service_name = service_name
self.use_replicas = use_replicas
self.sentinel_kwargs = sentinel_kwargs
super().__init__(
uri,
wrap_exceptions=wrap_exceptions,
implementation=implementation,
**options,
)
def configure_bridge(self) -> None:
self.bridge.use_sentinel(
self.service_name, self.use_replicas, self.sentinel_kwargs, **self.options
)
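
The hash-tag convention from _current_window_key/_previous_window_key in action, with a hypothetical key:

key = "my_limit/client_a"
current_key = f"{{{key}}}"          # "{my_limit/client_a}"
previous_key = f"{current_key}/-1"  # "{my_limit/client_a}/-1"
# Redis Cluster hashes only the text between the first "{" and "}", so both
# keys map to the same hash slot and a single Lua script can touch them atomically.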


@ -0,0 +1,119 @@
from __future__ import annotations
import urllib
from abc import ABC, abstractmethod
from types import ModuleType
from limits.util import get_package_data
class RedisBridge(ABC):
PREFIX = "LIMITS"
RES_DIR = "resources/redis/lua_scripts"
SCRIPT_MOVING_WINDOW = get_package_data(f"{RES_DIR}/moving_window.lua")
SCRIPT_ACQUIRE_MOVING_WINDOW = get_package_data(
f"{RES_DIR}/acquire_moving_window.lua"
)
SCRIPT_CLEAR_KEYS = get_package_data(f"{RES_DIR}/clear_keys.lua")
SCRIPT_INCR_EXPIRE = get_package_data(f"{RES_DIR}/incr_expire.lua")
SCRIPT_SLIDING_WINDOW = get_package_data(f"{RES_DIR}/sliding_window.lua")
SCRIPT_ACQUIRE_SLIDING_WINDOW = get_package_data(
f"{RES_DIR}/acquire_sliding_window.lua"
)
def __init__(
self,
uri: str,
dependency: ModuleType,
) -> None:
self.uri = uri
self.parsed_uri = urllib.parse.urlparse(self.uri)
self.dependency = dependency
self.parsed_auth = {}
if self.parsed_uri.username:
self.parsed_auth["username"] = self.parsed_uri.username
if self.parsed_uri.password:
self.parsed_auth["password"] = self.parsed_uri.password
def prefixed_key(self, key: str) -> str:
return f"{self.PREFIX}:{key}"
@abstractmethod
def register_scripts(self) -> None: ...
@abstractmethod
def use_sentinel(
self,
service_name: str | None,
use_replicas: bool,
sentinel_kwargs: dict[str, str | float | bool] | None,
**options: str | float | bool,
) -> None: ...
@abstractmethod
def use_basic(self, **options: str | float | bool) -> None: ...
@abstractmethod
def use_cluster(self, **options: str | float | bool) -> None: ...
@property
@abstractmethod
def base_exceptions(
self,
) -> type[Exception] | tuple[type[Exception], ...]: ...
@abstractmethod
async def incr(
self,
key: str,
expiry: int,
amount: int = 1,
) -> int: ...
@abstractmethod
async def get(self, key: str) -> int: ...
@abstractmethod
async def clear(self, key: str) -> None: ...
@abstractmethod
async def get_moving_window(
self, key: str, limit: int, expiry: int
) -> tuple[float, int]: ...
@abstractmethod
async def get_sliding_window(
self, previous_key: str, current_key: str, expiry: int
) -> tuple[int, float, int, float]: ...
@abstractmethod
async def acquire_entry(
self,
key: str,
limit: int,
expiry: int,
amount: int = 1,
) -> bool: ...
@abstractmethod
async def acquire_sliding_window_entry(
self,
previous_key: str,
current_key: str,
limit: int,
expiry: int,
amount: int = 1,
) -> bool: ...
@abstractmethod
async def get_expiry(self, key: str) -> float: ...
@abstractmethod
async def check(self) -> bool: ...
@abstractmethod
async def reset(self) -> int | None: ...
@abstractmethod
async def lua_reset(self) -> int | None: ...


@ -0,0 +1,205 @@
from __future__ import annotations
import time
from typing import TYPE_CHECKING, cast
from limits.aio.storage.redis.bridge import RedisBridge
from limits.errors import ConfigurationError
from limits.typing import AsyncCoRedisClient, Callable
if TYPE_CHECKING:
import coredis
class CoredisBridge(RedisBridge):
DEFAULT_CLUSTER_OPTIONS: dict[str, float | str | bool] = {
"max_connections": 1000,
}
"Default options passed to :class:`coredis.RedisCluster`"
@property
def base_exceptions(self) -> type[Exception] | tuple[type[Exception], ...]:
return (self.dependency.exceptions.RedisError,)
def use_sentinel(
self,
service_name: str | None,
use_replicas: bool,
sentinel_kwargs: dict[str, str | float | bool] | None,
**options: str | float | bool,
) -> None:
sentinel_configuration = []
connection_options = options.copy()
sep = self.parsed_uri.netloc.find("@") + 1
for loc in self.parsed_uri.netloc[sep:].split(","):
host, port = loc.split(":")
sentinel_configuration.append((host, int(port)))
service_name = (
self.parsed_uri.path.replace("/", "")
if self.parsed_uri.path
else service_name
)
if service_name is None:
raise ConfigurationError("'service_name' not provided")
self.sentinel = self.dependency.sentinel.Sentinel(
sentinel_configuration,
sentinel_kwargs={**self.parsed_auth, **(sentinel_kwargs or {})},
**{**self.parsed_auth, **connection_options},
)
self.storage = self.sentinel.primary_for(service_name)
self.storage_replica = self.sentinel.replica_for(service_name)
self.connection_getter = lambda readonly: (
self.storage_replica if readonly and use_replicas else self.storage
)
def use_basic(self, **options: str | float | bool) -> None:
if connection_pool := options.pop("connection_pool", None):
self.storage = self.dependency.Redis(
connection_pool=connection_pool, **options
)
else:
self.storage = self.dependency.Redis.from_url(self.uri, **options)
self.connection_getter = lambda _: self.storage
def use_cluster(self, **options: str | float | bool) -> None:
sep = self.parsed_uri.netloc.find("@") + 1
cluster_hosts: list[dict[str, int | str]] = []
cluster_hosts.extend(
{"host": host, "port": int(port)}
for loc in self.parsed_uri.netloc[sep:].split(",")
if loc
for host, port in [loc.split(":")]
)
self.storage = self.dependency.RedisCluster(
startup_nodes=cluster_hosts,
**{**self.DEFAULT_CLUSTER_OPTIONS, **self.parsed_auth, **options},
)
self.connection_getter = lambda _: self.storage
lua_moving_window: coredis.commands.Script[bytes]
lua_acquire_moving_window: coredis.commands.Script[bytes]
lua_sliding_window: coredis.commands.Script[bytes]
lua_acquire_sliding_window: coredis.commands.Script[bytes]
lua_clear_keys: coredis.commands.Script[bytes]
lua_incr_expire: coredis.commands.Script[bytes]
connection_getter: Callable[[bool], AsyncCoRedisClient]
def get_connection(self, readonly: bool = False) -> AsyncCoRedisClient:
return self.connection_getter(readonly)
def register_scripts(self) -> None:
self.lua_moving_window = self.get_connection().register_script(
self.SCRIPT_MOVING_WINDOW
)
self.lua_acquire_moving_window = self.get_connection().register_script(
self.SCRIPT_ACQUIRE_MOVING_WINDOW
)
self.lua_clear_keys = self.get_connection().register_script(
self.SCRIPT_CLEAR_KEYS
)
self.lua_incr_expire = self.get_connection().register_script(
self.SCRIPT_INCR_EXPIRE
)
self.lua_sliding_window = self.get_connection().register_script(
self.SCRIPT_SLIDING_WINDOW
)
self.lua_acquire_sliding_window = self.get_connection().register_script(
self.SCRIPT_ACQUIRE_SLIDING_WINDOW
)
async def incr(self, key: str, expiry: int, amount: int = 1) -> int:
key = self.prefixed_key(key)
if (value := await self.get_connection().incrby(key, amount)) == amount:
await self.get_connection().expire(key, expiry)
return value
async def get(self, key: str) -> int:
key = self.prefixed_key(key)
return int(await self.get_connection(readonly=True).get(key) or 0)
async def clear(self, key: str) -> None:
key = self.prefixed_key(key)
await self.get_connection().delete([key])
async def lua_reset(self) -> int | None:
return cast(int, await self.lua_clear_keys.execute([self.prefixed_key("*")]))
async def get_moving_window(
self, key: str, limit: int, expiry: int
) -> tuple[float, int]:
key = self.prefixed_key(key)
timestamp = time.time()
window = await self.lua_moving_window.execute(
[key], [timestamp - expiry, limit]
)
if window:
return float(window[0]), window[1] # type: ignore
return timestamp, 0
async def get_sliding_window(
self, previous_key: str, current_key: str, expiry: int
) -> tuple[int, float, int, float]:
previous_key = self.prefixed_key(previous_key)
current_key = self.prefixed_key(current_key)
if window := await self.lua_sliding_window.execute(
[previous_key, current_key], [expiry]
):
return (
int(window[0] or 0), # type: ignore
max(0, float(window[1] or 0)) / 1000, # type: ignore
int(window[2] or 0), # type: ignore
max(0, float(window[3] or 0)) / 1000, # type: ignore
)
return 0, 0.0, 0, 0.0
async def acquire_entry(
self, key: str, limit: int, expiry: int, amount: int = 1
) -> bool:
key = self.prefixed_key(key)
timestamp = time.time()
acquired = await self.lua_acquire_moving_window.execute(
[key], [timestamp, limit, expiry, amount]
)
return bool(acquired)
async def acquire_sliding_window_entry(
self,
previous_key: str,
current_key: str,
limit: int,
expiry: int,
amount: int = 1,
) -> bool:
previous_key = self.prefixed_key(previous_key)
current_key = self.prefixed_key(current_key)
acquired = await self.lua_acquire_sliding_window.execute(
[previous_key, current_key], [limit, expiry, amount]
)
return bool(acquired)
async def get_expiry(self, key: str) -> float:
key = self.prefixed_key(key)
return max(await self.get_connection().ttl(key), 0) + time.time()
async def check(self) -> bool:
try:
await self.get_connection().ping()
return True
except: # noqa
return False
async def reset(self) -> int | None:
prefix = self.prefixed_key("*")
keys = await self.storage.keys(prefix)
count = 0
for key in keys:
count += await self.storage.delete([key])
return count

View File

@ -0,0 +1,250 @@
from __future__ import annotations
import time
from typing import TYPE_CHECKING, cast
from limits.aio.storage.redis.bridge import RedisBridge
from limits.errors import ConfigurationError
from limits.typing import AsyncRedisClient, Callable
if TYPE_CHECKING:
import redis.commands
class RedispyBridge(RedisBridge):
DEFAULT_CLUSTER_OPTIONS: dict[str, float | str | bool] = {
"max_connections": 1000,
}
"Default options passed to :class:`redis.asyncio.RedisCluster`"
@property
def base_exceptions(self) -> type[Exception] | tuple[type[Exception], ...]:
return (self.dependency.RedisError,)
def use_sentinel(
self,
service_name: str | None,
use_replicas: bool,
sentinel_kwargs: dict[str, str | float | bool] | None,
**options: str | float | bool,
) -> None:
sentinel_configuration = []
connection_options = options.copy()
sep = self.parsed_uri.netloc.find("@") + 1
for loc in self.parsed_uri.netloc[sep:].split(","):
host, port = loc.split(":")
sentinel_configuration.append((host, int(port)))
service_name = (
self.parsed_uri.path.replace("/", "")
if self.parsed_uri.path
else service_name
)
if service_name is None:
raise ConfigurationError("'service_name' not provided")
self.sentinel = self.dependency.asyncio.Sentinel(
sentinel_configuration,
sentinel_kwargs={**self.parsed_auth, **(sentinel_kwargs or {})},
**{**self.parsed_auth, **connection_options},
)
self.storage = self.sentinel.master_for(service_name)
self.storage_replica = self.sentinel.slave_for(service_name)
self.connection_getter = lambda readonly: (
self.storage_replica if readonly and use_replicas else self.storage
)
def use_basic(self, **options: str | float | bool) -> None:
if connection_pool := options.pop("connection_pool", None):
self.storage = self.dependency.asyncio.Redis(
connection_pool=connection_pool, **options
)
else:
self.storage = self.dependency.asyncio.Redis.from_url(self.uri, **options)
self.connection_getter = lambda _: self.storage
def use_cluster(self, **options: str | float | bool) -> None:
sep = self.parsed_uri.netloc.find("@") + 1
cluster_hosts = []
for loc in self.parsed_uri.netloc[sep:].split(","):
host, port = loc.split(":")
cluster_hosts.append(
self.dependency.asyncio.cluster.ClusterNode(host=host, port=int(port))
)
self.storage = self.dependency.asyncio.RedisCluster(
startup_nodes=cluster_hosts,
**{**self.DEFAULT_CLUSTER_OPTIONS, **self.parsed_auth, **options},
)
self.connection_getter = lambda _: self.storage
lua_moving_window: redis.commands.core.Script
lua_acquire_moving_window: redis.commands.core.Script
lua_sliding_window: redis.commands.core.Script
lua_acquire_sliding_window: redis.commands.core.Script
lua_clear_keys: redis.commands.core.Script
lua_incr_expire: redis.commands.core.Script
connection_getter: Callable[[bool], AsyncRedisClient]
def get_connection(self, readonly: bool = False) -> AsyncRedisClient:
return self.connection_getter(readonly)
def register_scripts(self) -> None:
        # unlike coredis, redis-py Script objects are invoked directly as callables rather than via .execute()
self.lua_moving_window = self.get_connection().register_script(
self.SCRIPT_MOVING_WINDOW
)
self.lua_acquire_moving_window = self.get_connection().register_script(
self.SCRIPT_ACQUIRE_MOVING_WINDOW
)
self.lua_clear_keys = self.get_connection().register_script(
self.SCRIPT_CLEAR_KEYS
)
self.lua_incr_expire = self.get_connection().register_script(
self.SCRIPT_INCR_EXPIRE
)
self.lua_sliding_window = self.get_connection().register_script(
self.SCRIPT_SLIDING_WINDOW
)
self.lua_acquire_sliding_window = self.get_connection().register_script(
self.SCRIPT_ACQUIRE_SLIDING_WINDOW
)
async def incr(
self,
key: str,
expiry: int,
amount: int = 1,
) -> int:
"""
increments the counter for a given rate limit key
:param key: the key to increment
:param expiry: amount in seconds for the key to expire in
:param amount: the number to increment by
"""
key = self.prefixed_key(key)
return cast(int, await self.lua_incr_expire([key], [expiry, amount]))
async def get(self, key: str) -> int:
"""
:param key: the key to get the counter value for
"""
key = self.prefixed_key(key)
return int(await self.get_connection(readonly=True).get(key) or 0)
async def clear(self, key: str) -> None:
"""
:param key: the key to clear rate limits for
"""
key = self.prefixed_key(key)
await self.get_connection().delete(key)
async def lua_reset(self) -> int | None:
return cast(int, await self.lua_clear_keys([self.prefixed_key("*")]))
async def get_moving_window(
self, key: str, limit: int, expiry: int
) -> tuple[float, int]:
"""
returns the starting point and the number of entries in the moving
window
:param key: rate limit key
:param expiry: expiry of entry
        :return: (start of window, number of acquired entries)
"""
key = self.prefixed_key(key)
timestamp = time.time()
window = await self.lua_moving_window([key], [timestamp - expiry, limit])
if window:
return float(window[0]), window[1]
return timestamp, 0
async def get_sliding_window(
self, previous_key: str, current_key: str, expiry: int
) -> tuple[int, float, int, float]:
if window := await self.lua_sliding_window(
[self.prefixed_key(previous_key), self.prefixed_key(current_key)], [expiry]
):
return (
int(window[0] or 0),
max(0, float(window[1] or 0)) / 1000,
int(window[2] or 0),
max(0, float(window[3] or 0)) / 1000,
)
return 0, 0.0, 0, 0.0
async def acquire_entry(
self,
key: str,
limit: int,
expiry: int,
amount: int = 1,
) -> bool:
"""
:param key: rate limit key to acquire an entry in
:param limit: amount of entries allowed
:param expiry: expiry of the entry
"""
key = self.prefixed_key(key)
timestamp = time.time()
acquired = await self.lua_acquire_moving_window(
[key], [timestamp, limit, expiry, amount]
)
return bool(acquired)
async def acquire_sliding_window_entry(
self,
previous_key: str,
current_key: str,
limit: int,
expiry: int,
amount: int = 1,
) -> bool:
previous_key = self.prefixed_key(previous_key)
current_key = self.prefixed_key(current_key)
acquired = await self.lua_acquire_sliding_window(
[previous_key, current_key], [limit, expiry, amount]
)
return bool(acquired)
async def get_expiry(self, key: str) -> float:
"""
:param key: the key to get the expiry for
"""
key = self.prefixed_key(key)
return max(await self.get_connection().ttl(key), 0) + time.time()
async def check(self) -> bool:
"""
check if storage is healthy
"""
try:
await self.get_connection().ping()
return True
except: # noqa
return False
async def reset(self) -> int | None:
prefix = self.prefixed_key("*")
keys = await self.storage.keys(
prefix, target_nodes=self.dependency.asyncio.cluster.RedisCluster.ALL_NODES
)
count = 0
for key in keys:
count += await self.storage.delete(key)
return count

View File

@ -0,0 +1,9 @@
from __future__ import annotations
from .redispy import RedispyBridge
class ValkeyBridge(RedispyBridge):
@property
def base_exceptions(self) -> type[Exception] | tuple[type[Exception], ...]:
return (self.dependency.ValkeyError,)

View File

@ -0,0 +1,310 @@
"""
Asynchronous rate limiting strategies
"""
from __future__ import annotations
import time
from abc import ABC, abstractmethod
from math import floor, inf
from deprecated.sphinx import versionadded
from ..limits import RateLimitItem
from ..storage import StorageTypes
from ..typing import cast
from ..util import WindowStats
from .storage import MovingWindowSupport, Storage
from .storage.base import SlidingWindowCounterSupport
class RateLimiter(ABC):
def __init__(self, storage: StorageTypes):
assert isinstance(storage, Storage)
self.storage: Storage = storage
@abstractmethod
async def hit(self, item: RateLimitItem, *identifiers: str, cost: int = 1) -> bool:
"""
Consume the rate limit
:param item: the rate limit item
:param identifiers: variable list of strings to uniquely identify the
limit
:param cost: The cost of this hit, default 1
"""
raise NotImplementedError
@abstractmethod
async def test(self, item: RateLimitItem, *identifiers: str, cost: int = 1) -> bool:
"""
Check if the rate limit can be consumed
:param item: the rate limit item
:param identifiers: variable list of strings to uniquely identify the
limit
:param cost: The expected cost to be consumed, default 1
"""
raise NotImplementedError
@abstractmethod
async def get_window_stats(
self, item: RateLimitItem, *identifiers: str
) -> WindowStats:
"""
Query the reset time and remaining amount for the limit
:param item: the rate limit item
:param identifiers: variable list of strings to uniquely identify the
limit
        :return: (reset time, remaining)
"""
raise NotImplementedError
async def clear(self, item: RateLimitItem, *identifiers: str) -> None:
return await self.storage.clear(item.key_for(*identifiers))
class MovingWindowRateLimiter(RateLimiter):
"""
Reference: :ref:`strategies:moving window`
"""
def __init__(self, storage: StorageTypes) -> None:
if not (
hasattr(storage, "acquire_entry") or hasattr(storage, "get_moving_window")
):
raise NotImplementedError(
"MovingWindowRateLimiting is not implemented for storage "
f"of type {storage.__class__}"
)
super().__init__(storage)
async def hit(self, item: RateLimitItem, *identifiers: str, cost: int = 1) -> bool:
"""
Consume the rate limit
:param item: the rate limit item
:param identifiers: variable list of strings to uniquely identify the
limit
:param cost: The cost of this hit, default 1
"""
return await cast(MovingWindowSupport, self.storage).acquire_entry(
item.key_for(*identifiers), item.amount, item.get_expiry(), amount=cost
)
async def test(self, item: RateLimitItem, *identifiers: str, cost: int = 1) -> bool:
"""
Check if the rate limit can be consumed
:param item: the rate limit item
:param identifiers: variable list of strings to uniquely identify the
limit
:param cost: The expected cost to be consumed, default 1
"""
res = await cast(MovingWindowSupport, self.storage).get_moving_window(
item.key_for(*identifiers),
item.amount,
item.get_expiry(),
)
amount = res[1]
return amount <= item.amount - cost
async def get_window_stats(
self, item: RateLimitItem, *identifiers: str
) -> WindowStats:
"""
        Query the reset time and remaining amount for the limit
:param item: the rate limit item
:param identifiers: variable list of strings to uniquely identify the
limit
:return: (reset time, remaining)
"""
window_start, window_items = await cast(
MovingWindowSupport, self.storage
).get_moving_window(item.key_for(*identifiers), item.amount, item.get_expiry())
reset = window_start + item.get_expiry()
return WindowStats(reset, item.amount - window_items)
class FixedWindowRateLimiter(RateLimiter):
"""
Reference: :ref:`strategies:fixed window`
"""
async def hit(self, item: RateLimitItem, *identifiers: str, cost: int = 1) -> bool:
"""
Consume the rate limit
:param item: the rate limit item
:param identifiers: variable list of strings to uniquely identify the
limit
:param cost: The cost of this hit, default 1
"""
return (
await self.storage.incr(
item.key_for(*identifiers),
item.get_expiry(),
amount=cost,
)
<= item.amount
)
async def test(self, item: RateLimitItem, *identifiers: str, cost: int = 1) -> bool:
"""
Check if the rate limit can be consumed
:param item: the rate limit item
:param identifiers: variable list of strings to uniquely identify the
limit
:param cost: The expected cost to be consumed, default 1
"""
return (
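            # i.e. current count + cost <= item.amount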
await self.storage.get(item.key_for(*identifiers)) < item.amount - cost + 1
)
async def get_window_stats(
self, item: RateLimitItem, *identifiers: str
) -> WindowStats:
"""
Query the reset time and remaining amount for the limit
:param item: the rate limit item
:param identifiers: variable list of strings to uniquely identify the
limit
:return: reset time, remaining
"""
remaining = max(
0,
item.amount - await self.storage.get(item.key_for(*identifiers)),
)
reset = await self.storage.get_expiry(item.key_for(*identifiers))
return WindowStats(reset, remaining)
@versionadded(version="4.1")
class SlidingWindowCounterRateLimiter(RateLimiter):
"""
Reference: :ref:`strategies:sliding window counter`
"""
def __init__(self, storage: StorageTypes):
if not hasattr(storage, "get_sliding_window") or not hasattr(
storage, "acquire_sliding_window_entry"
):
raise NotImplementedError(
"SlidingWindowCounterRateLimiting is not implemented for storage "
f"of type {storage.__class__}"
)
super().__init__(storage)
def _weighted_count(
self,
item: RateLimitItem,
previous_count: int,
previous_expires_in: float,
current_count: int,
) -> float:
"""
        Return the approximate hit count, weighting the previous window's
        count by its remaining fraction of the expiry and adding the
        current window's count.
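        For example (a worked sketch): with a 60 second expiry, 10 hits in the
        previous window with 30 seconds remaining and 3 hits in the current
        window, the weighted count is 10 * 30 / 60 + 3 = 8.0.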
"""
return previous_count * previous_expires_in / item.get_expiry() + current_count
async def hit(self, item: RateLimitItem, *identifiers: str, cost: int = 1) -> bool:
"""
Consume the rate limit
:param item: The rate limit item
:param identifiers: variable list of strings to uniquely identify this
instance of the limit
:param cost: The cost of this hit, default 1
"""
return await cast(
SlidingWindowCounterSupport, self.storage
).acquire_sliding_window_entry(
item.key_for(*identifiers),
item.amount,
item.get_expiry(),
cost,
)
async def test(self, item: RateLimitItem, *identifiers: str, cost: int = 1) -> bool:
"""
Check if the rate limit can be consumed
:param item: The rate limit item
:param identifiers: variable list of strings to uniquely identify this
instance of the limit
:param cost: The expected cost to be consumed, default 1
"""
previous_count, previous_expires_in, current_count, _ = await cast(
SlidingWindowCounterSupport, self.storage
).get_sliding_window(item.key_for(*identifiers), item.get_expiry())
return (
self._weighted_count(
item, previous_count, previous_expires_in, current_count
)
< item.amount - cost + 1
)
async def get_window_stats(
self, item: RateLimitItem, *identifiers: str
) -> WindowStats:
"""
Query the reset time and remaining amount for the limit.
:param item: The rate limit item
:param identifiers: variable list of strings to uniquely identify this
instance of the limit
:return: (reset time, remaining)
"""
(
previous_count,
previous_expires_in,
current_count,
current_expires_in,
) = await cast(SlidingWindowCounterSupport, self.storage).get_sliding_window(
item.key_for(*identifiers), item.get_expiry()
)
remaining = max(
0,
item.amount
- floor(
self._weighted_count(
item, previous_count, previous_expires_in, current_count
)
),
)
now = time.time()
if not (previous_count or current_count):
return WindowStats(now, remaining)
expiry = item.get_expiry()
previous_reset_in, current_reset_in = inf, inf
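        # assuming evenly spaced hits, the weighted count next drops when the
        # oldest slice of the previous window expires (every
        # expiry / previous_count seconds) or when the current window rolls over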
if previous_count:
previous_reset_in = previous_expires_in % (expiry / previous_count)
if current_count:
current_reset_in = current_expires_in % expiry
return WindowStats(now + min(previous_reset_in, current_reset_in), remaining)
STRATEGIES = {
"sliding-window-counter": SlidingWindowCounterRateLimiter,
"fixed-window": FixedWindowRateLimiter,
"moving-window": MovingWindowRateLimiter,
}
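# A minimal usage sketch (illustrative only, not part of the module; assumes
# an event loop and the ``async+`` schemes documented in limits.storage):
#
#   from limits import parse
#   from limits.storage import storage_from_string
#
#   storage = storage_from_string("async+memory://")
#   limiter = STRATEGIES["fixed-window"](storage)
#   allowed = await limiter.hit(parse("10/minute"), "client-42")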

View File

@ -0,0 +1,30 @@
"""
errors and exceptions
"""
from __future__ import annotations
class ConfigurationError(Exception):
"""
Error raised when a configuration problem is encountered
"""
class ConcurrentUpdateError(Exception):
"""
    Error raised when an update to a limit fails due to concurrent
    updates
"""
def __init__(self, key: str, attempts: int) -> None:
super().__init__(f"Unable to update {key} after {attempts} retries")
class StorageError(Exception):
"""
Error raised when an error is encountered in a storage
"""
def __init__(self, storage_error: Exception) -> None:
self.storage_error = storage_error

View File

@ -0,0 +1,193 @@
""" """
from __future__ import annotations
from functools import total_ordering
from limits.typing import ClassVar, NamedTuple, cast
def safe_string(value: bytes | str | int | float) -> str:
"""
normalize a byte/str/int or float to a str
"""
if isinstance(value, bytes):
return value.decode()
return str(value)
class Granularity(NamedTuple):
seconds: int
name: str
TIME_TYPES = dict(
day=Granularity(60 * 60 * 24, "day"),
month=Granularity(60 * 60 * 24 * 30, "month"),
year=Granularity(60 * 60 * 24 * 30 * 12, "year"),
hour=Granularity(60 * 60, "hour"),
minute=Granularity(60, "minute"),
second=Granularity(1, "second"),
)
GRANULARITIES: dict[str, type[RateLimitItem]] = {}
class RateLimitItemMeta(type):
def __new__(
cls,
name: str,
parents: tuple[type, ...],
dct: dict[str, Granularity | list[str]],
) -> RateLimitItemMeta:
if "__slots__" not in dct:
dct["__slots__"] = []
granularity = super().__new__(cls, name, parents, dct)
if "GRANULARITY" in dct:
GRANULARITIES[dct["GRANULARITY"][1]] = cast(
type[RateLimitItem], granularity
)
return granularity
# pylint: disable=no-member
@total_ordering
class RateLimitItem(metaclass=RateLimitItemMeta):
"""
    defines a rate limited resource with a characteristic namespace, an
    amount and a granularity (multiples of the rate limiting window).
:param amount: the rate limit amount
:param multiples: multiple of the 'per' :attr:`GRANULARITY`
(e.g. 'n' per 'm' seconds)
:param namespace: category for the specific rate limit
"""
__slots__ = ["namespace", "amount", "multiples"]
GRANULARITY: ClassVar[Granularity]
"""
A tuple describing the granularity of this limit as
(number of seconds, name)
"""
def __init__(
self, amount: int, multiples: int | None = 1, namespace: str = "LIMITER"
):
self.namespace = namespace
self.amount = int(amount)
self.multiples = int(multiples or 1)
@classmethod
def check_granularity_string(cls, granularity_string: str) -> bool:
"""
Checks if this instance matches a *granularity_string*
    of type ``n per hour``, ``n per minute`` etc.,
by comparing with :attr:`GRANULARITY`
"""
return granularity_string.lower() in cls.GRANULARITY.name
def get_expiry(self) -> int:
"""
:return: the duration the limit is enforced for in seconds.
"""
return self.GRANULARITY.seconds * self.multiples
def key_for(self, *identifiers: bytes | str | int | float) -> str:
"""
Constructs a key for the current limit and any additional
identifiers provided.
:param identifiers: a list of strings to append to the key
:return: a string key identifying this resource with
each identifier separated with a '/' delimiter.
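        For example (a worked sketch)::
            RateLimitItemPerMinute(10).key_for("1.2.3.4")
            # -> "LIMITER/1.2.3.4/10/1/minute"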
"""
remainder = "/".join(
[safe_string(k) for k in identifiers]
+ [
safe_string(self.amount),
safe_string(self.multiples),
self.GRANULARITY.name,
]
)
return f"{self.namespace}/{remainder}"
def __eq__(self, other: object) -> bool:
if isinstance(other, RateLimitItem):
return (
self.amount == other.amount
and self.GRANULARITY == other.GRANULARITY
and self.multiples == other.multiples
)
return False
def __repr__(self) -> str:
return f"{self.amount} per {self.multiples} {self.GRANULARITY.name}"
def __lt__(self, other: RateLimitItem) -> bool:
return self.GRANULARITY.seconds < other.GRANULARITY.seconds
def __hash__(self) -> int:
return hash((self.namespace, self.amount, self.multiples, self.GRANULARITY))
class RateLimitItemPerYear(RateLimitItem):
"""
per year rate limited resource.
"""
GRANULARITY = TIME_TYPES["year"]
"""A year"""
class RateLimitItemPerMonth(RateLimitItem):
"""
per month rate limited resource.
"""
GRANULARITY = TIME_TYPES["month"]
"""A month"""
class RateLimitItemPerDay(RateLimitItem):
"""
per day rate limited resource.
"""
GRANULARITY = TIME_TYPES["day"]
"""A day"""
class RateLimitItemPerHour(RateLimitItem):
"""
per hour rate limited resource.
"""
GRANULARITY = TIME_TYPES["hour"]
"""An hour"""
class RateLimitItemPerMinute(RateLimitItem):
"""
per minute rate limited resource.
"""
GRANULARITY = TIME_TYPES["minute"]
"""A minute"""
class RateLimitItemPerSecond(RateLimitItem):
"""
per second rate limited resource.
"""
GRANULARITY = TIME_TYPES["second"]
"""A second"""

View File

@ -0,0 +1,26 @@
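-- KEYS[1]: the moving window list
-- ARGV[1]: current timestamp, ARGV[2]: limit, ARGV[3]: expiry (seconds), ARGV[4]: amount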
local timestamp = tonumber(ARGV[1])
local limit = tonumber(ARGV[2])
local expiry = tonumber(ARGV[3])
local amount = tonumber(ARGV[4])
if amount > limit then
return false
end
local entry = redis.call('lindex', KEYS[1], limit - amount)
if entry and tonumber(entry) >= timestamp - expiry then
return false
end
local entries = {}
for i = 1, amount do
entries[i] = timestamp
end
for i=1,#entries,5000 do
redis.call('lpush', KEYS[1], unpack(entries, i, math.min(i+4999, #entries)))
end
redis.call('ltrim', KEYS[1], 0, limit - 1)
redis.call('expire', KEYS[1], expiry)
return true

View File

@ -0,0 +1,45 @@
-- Time is in milliseconds in this script: TTL, expiry...
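-- KEYS[1]: previous window counter, KEYS[2]: current window counter
-- ARGV[1]: limit, ARGV[2]: expiry (seconds), ARGV[3]: amount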
local limit = tonumber(ARGV[1])
local expiry = tonumber(ARGV[2]) * 1000
local amount = tonumber(ARGV[3])
if amount > limit then
return false
end
local current_ttl = tonumber(redis.call('pttl', KEYS[2]))
if current_ttl > 0 and current_ttl < expiry then
-- Current window expired, shift it to the previous window
redis.call('rename', KEYS[2], KEYS[1])
redis.call('set', KEYS[2], 0, 'PX', current_ttl + expiry)
end
local previous_count = tonumber(redis.call('get', KEYS[1])) or 0
local previous_ttl = tonumber(redis.call('pttl', KEYS[1])) or 0
local current_count = tonumber(redis.call('get', KEYS[2])) or 0
current_ttl = tonumber(redis.call('pttl', KEYS[2])) or 0
-- If the values don't exist yet, consider the TTL to be 0
if previous_ttl <= 0 then
previous_ttl = 0
end
if current_ttl <= 0 then
current_ttl = 0
end
local weighted_count = math.floor(previous_count * previous_ttl / expiry) + current_count
if (weighted_count + amount) > limit then
return false
end
-- If the current counter exists, increase its value
if redis.call('exists', KEYS[2]) == 1 then
redis.call('incrby', KEYS[2], amount)
else
-- Otherwise, set the value with twice the expiry time
redis.call('set', KEYS[2], amount, 'PX', expiry * 2)
end
return true

View File

@ -0,0 +1,10 @@
local keys = redis.call('keys', KEYS[1])
local res = 0
for i=1,#keys,5000 do
res = res + redis.call(
'del', unpack(keys, i, math.min(i+4999, #keys))
)
end
return res

View File

@ -0,0 +1,9 @@
local current
local amount = tonumber(ARGV[2])
current = redis.call("incrby", KEYS[1], amount)
if tonumber(current) == amount then
redis.call("expire", KEYS[1], ARGV[1])
end
return current

View File

@ -0,0 +1,30 @@
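-- KEYS[1]: the moving window list
-- ARGV[1]: lower bound of the window (timestamp - expiry), ARGV[2]: limit
-- (the local named 'expiry' below holds the boundary timestamp, not a duration)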
local len = tonumber(ARGV[2])
local expiry = tonumber(ARGV[1])
-- Binary search to find the oldest valid entry in the window
local function oldest_entry(high, target)
local low = 0
local result = nil
while low <= high do
local mid = math.floor((low + high) / 2)
local val = tonumber(redis.call('lindex', KEYS[1], mid))
if val and val >= target then
result = mid
low = mid + 1
else
high = mid - 1
end
end
return result
end
local index = oldest_entry(len - 1, expiry)
if index then
local count = index + 1
local oldest = tonumber(redis.call('lindex', KEYS[1], index))
return {tostring(oldest), count}
end

View File

@ -0,0 +1,17 @@
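-- KEYS[1]: previous window counter, KEYS[2]: current window counter
-- ARGV[1]: expiry (seconds)
-- Returns {previous_count, previous_ttl_ms, current_count, current_ttl_ms}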
local expiry = tonumber(ARGV[1]) * 1000
local previous_count = redis.call('get', KEYS[1])
local previous_ttl = redis.call('pttl', KEYS[1])
local current_count = redis.call('get', KEYS[2])
local current_ttl = redis.call('pttl', KEYS[2])
if current_ttl > 0 and current_ttl < expiry then
-- Current window expired, shift it to the previous window
redis.call('rename', KEYS[2], KEYS[1])
redis.call('set', KEYS[2], 0, 'PX', current_ttl + expiry)
previous_count = redis.call('get', KEYS[1])
previous_ttl = redis.call('pttl', KEYS[1])
current_count = redis.call('get', KEYS[2])
current_ttl = redis.call('pttl', KEYS[2])
end
return {previous_count, previous_ttl, current_count, current_ttl}

View File

@ -0,0 +1,80 @@
"""
Implementations of storage backends to be used with
:class:`limits.strategies.RateLimiter` strategies
"""
from __future__ import annotations
import urllib
import limits # noqa
from ..errors import ConfigurationError
from ..typing import TypeAlias, cast
from .base import MovingWindowSupport, SlidingWindowCounterSupport, Storage
from .memcached import MemcachedStorage
from .memory import MemoryStorage
from .mongodb import MongoDBStorage, MongoDBStorageBase
from .redis import RedisStorage
from .redis_cluster import RedisClusterStorage
from .redis_sentinel import RedisSentinelStorage
from .registry import SCHEMES
StorageTypes: TypeAlias = "Storage | limits.aio.storage.Storage"
def storage_from_string(
storage_string: str, **options: float | str | bool
) -> StorageTypes:
"""
Factory function to get an instance of the storage class based
on the uri of the storage. In most cases using it should be sufficient
    instead of directly instantiating the storage classes. For example::
from limits.storage import storage_from_string
memory = storage_from_string("memory://")
memcached = storage_from_string("memcached://localhost:11211")
redis = storage_from_string("redis://localhost:6379")
The same function can be used to construct the :ref:`storage:async storage`
variants, for example::
from limits.storage import storage_from_string
memory = storage_from_string("async+memory://")
memcached = storage_from_string("async+memcached://localhost:11211")
redis = storage_from_string("async+redis://localhost:6379")
:param storage_string: a string of the form ``scheme://host:port``.
More details about supported storage schemes can be found at
:ref:`storage:storage scheme`
:param options: all remaining keyword arguments are passed to the
constructor matched by :paramref:`storage_string`.
:raises ConfigurationError: when the :attr:`storage_string` cannot be
mapped to a registered :class:`limits.storage.Storage`
or :class:`limits.aio.storage.Storage` instance.
"""
scheme = urllib.parse.urlparse(storage_string).scheme
if scheme not in SCHEMES:
raise ConfigurationError(f"unknown storage scheme : {storage_string}")
return cast(StorageTypes, SCHEMES[scheme](storage_string, **options))
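# An additional sketch (illustrative only): an unrecognized scheme raises
# ConfigurationError rather than returning a storage:
#
#   try:
#       storage_from_string("unknown://")
#   except ConfigurationError:
#       ...  # scheme not registered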
__all__ = [
"MemcachedStorage",
"MemoryStorage",
"MongoDBStorage",
"MongoDBStorageBase",
"MovingWindowSupport",
"RedisClusterStorage",
"RedisSentinelStorage",
"RedisStorage",
"SlidingWindowCounterSupport",
"Storage",
"storage_from_string",
]

View File

@ -0,0 +1,232 @@
from __future__ import annotations
import functools
from abc import ABC, abstractmethod
from limits import errors
from limits.storage.registry import StorageRegistry
from limits.typing import (
Any,
Callable,
P,
R,
cast,
)
from limits.util import LazyDependency
def _wrap_errors(
fn: Callable[P, R],
) -> Callable[P, R]:
@functools.wraps(fn)
def inner(*args: P.args, **kwargs: P.kwargs) -> R:
instance = cast(Storage, args[0])
try:
return fn(*args, **kwargs)
except instance.base_exceptions as exc:
if instance.wrap_exceptions:
raise errors.StorageError(exc) from exc
raise
return inner
class Storage(LazyDependency, metaclass=StorageRegistry):
"""
Base class to extend when implementing a storage backend.
"""
STORAGE_SCHEME: list[str] | None
"""The storage schemes to register against this implementation"""
def __init_subclass__(cls, **kwargs: Any) -> None: # type: ignore[explicit-any]
for method in {
"incr",
"get",
"get_expiry",
"check",
"reset",
"clear",
}:
setattr(cls, method, _wrap_errors(getattr(cls, method)))
super().__init_subclass__(**kwargs)
def __init__(
self,
uri: str | None = None,
wrap_exceptions: bool = False,
**options: float | str | bool,
):
"""
:param wrap_exceptions: Whether to wrap storage exceptions in
:exc:`limits.errors.StorageError` before raising it.
"""
super().__init__()
self.wrap_exceptions = wrap_exceptions
@property
@abstractmethod
def base_exceptions(self) -> type[Exception] | tuple[type[Exception], ...]:
raise NotImplementedError
@abstractmethod
def incr(self, key: str, expiry: int, amount: int = 1) -> int:
"""
increments the counter for a given rate limit key
:param key: the key to increment
:param expiry: amount in seconds for the key to expire in
:param amount: the number to increment by
"""
raise NotImplementedError
@abstractmethod
def get(self, key: str) -> int:
"""
:param key: the key to get the counter value for
"""
raise NotImplementedError
@abstractmethod
def get_expiry(self, key: str) -> float:
"""
:param key: the key to get the expiry for
"""
raise NotImplementedError
@abstractmethod
def check(self) -> bool:
"""
check if storage is healthy
"""
raise NotImplementedError
@abstractmethod
def reset(self) -> int | None:
"""
reset storage to clear limits
"""
raise NotImplementedError
@abstractmethod
def clear(self, key: str) -> None:
"""
resets the rate limit key
:param key: the key to clear rate limits for
"""
raise NotImplementedError
class MovingWindowSupport(ABC):
"""
Abstract base class for storages that support
the :ref:`strategies:moving window` strategy
"""
def __init_subclass__(cls, **kwargs: Any) -> None: # type: ignore[explicit-any]
for method in {
"acquire_entry",
"get_moving_window",
}:
setattr(
cls,
method,
_wrap_errors(getattr(cls, method)),
)
super().__init_subclass__(**kwargs)
@abstractmethod
def acquire_entry(self, key: str, limit: int, expiry: int, amount: int = 1) -> bool:
"""
:param key: rate limit key to acquire an entry in
:param limit: amount of entries allowed
:param expiry: expiry of the entry
:param amount: the number of entries to acquire
"""
raise NotImplementedError
@abstractmethod
def get_moving_window(self, key: str, limit: int, expiry: int) -> tuple[float, int]:
"""
returns the starting point and the number of entries in the moving
window
:param key: rate limit key
:param expiry: expiry of entry
:return: (start of window, number of acquired entries)
"""
raise NotImplementedError
class SlidingWindowCounterSupport(ABC):
"""
Abstract base class for storages that support
the :ref:`strategies:sliding window counter` strategy.
"""
def __init_subclass__(cls, **kwargs: Any) -> None: # type: ignore[explicit-any]
for method in {"acquire_sliding_window_entry", "get_sliding_window"}:
setattr(
cls,
method,
_wrap_errors(getattr(cls, method)),
)
super().__init_subclass__(**kwargs)
@abstractmethod
def acquire_sliding_window_entry(
self, key: str, limit: int, expiry: int, amount: int = 1
) -> bool:
"""
Acquire an entry if the weighted count of the current and previous
windows is less than or equal to the limit
:param key: rate limit key to acquire an entry in
:param limit: amount of entries allowed
:param expiry: expiry of the entry
:param amount: the number of entries to acquire
"""
raise NotImplementedError
@abstractmethod
def get_sliding_window(
self, key: str, expiry: int
) -> tuple[int, float, int, float]:
"""
Return the previous and current window information.
:param key: the rate limit key
:param expiry: the rate limit expiry, needed to compute the key in some implementations
:return: a tuple of (int, float, int, float) with the following information:
- previous window counter
- previous window TTL
- current window counter
- current window TTL
"""
raise NotImplementedError
class TimestampedSlidingWindow:
"""Helper class for storage that support the sliding window counter, with timestamp based keys."""
@classmethod
def sliding_window_keys(cls, key: str, expiry: int, at: float) -> tuple[str, str]:
"""
returns the previous and the current window's keys.
:param key: the key to get the window's keys from
:param expiry: the expiry of the limit item, in seconds
        :param at: the timestamp to compute the keys from, typically ``time.time()``
Returns a tuple with the previous and the current key: (previous, current).
Example:
- key = "mykey"
- expiry = 60
- at = 1738576292.6631825
        The return value will be the tuple ``("mykey/28976270", "mykey/28976271")``.
"""
return f"{key}/{int((at - expiry) / expiry)}", f"{key}/{int(at / expiry)}"

View File

@ -0,0 +1,299 @@
from __future__ import annotations
import inspect
import threading
import time
import urllib.parse
from collections.abc import Iterable
from math import ceil, floor
from types import ModuleType
from limits.errors import ConfigurationError
from limits.storage.base import (
SlidingWindowCounterSupport,
Storage,
TimestampedSlidingWindow,
)
from limits.typing import (
Any,
Callable,
MemcachedClientP,
P,
R,
cast,
)
from limits.util import get_dependency
class MemcachedStorage(Storage, SlidingWindowCounterSupport, TimestampedSlidingWindow):
"""
Rate limit storage with memcached as backend.
Depends on :pypi:`pymemcache`.
"""
STORAGE_SCHEME = ["memcached"]
"""The storage scheme for memcached"""
DEPENDENCIES = ["pymemcache"]
def __init__(
self,
uri: str,
wrap_exceptions: bool = False,
**options: str | Callable[[], MemcachedClientP],
) -> None:
"""
:param uri: memcached location of the form
``memcached://host:port,host:port``,
``memcached:///var/tmp/path/to/sock``
:param wrap_exceptions: Whether to wrap storage exceptions in
:exc:`limits.errors.StorageError` before raising it.
:param options: all remaining keyword arguments are passed
directly to the constructor of :class:`pymemcache.client.base.PooledClient`
        or :class:`pymemcache.client.hash.HashClient` (if more than
        one host is specified)
:raise ConfigurationError: when :pypi:`pymemcache` is not available
"""
parsed = urllib.parse.urlparse(uri)
self.hosts = []
for loc in parsed.netloc.strip().split(","):
if not loc:
continue
host, port = loc.split(":")
self.hosts.append((host, int(port)))
else:
# filesystem path to UDS
if parsed.path and not parsed.netloc and not parsed.port:
self.hosts = [parsed.path] # type: ignore
self.dependency = self.dependencies["pymemcache"].module
self.library = str(options.pop("library", "pymemcache.client"))
self.cluster_library = str(
options.pop("cluster_library", "pymemcache.client.hash")
)
self.client_getter = cast(
Callable[[ModuleType, list[tuple[str, int]]], MemcachedClientP],
options.pop("client_getter", self.get_client),
)
self.options = options
if not get_dependency(self.library):
raise ConfigurationError(
f"memcached prerequisite not available. please install {self.library}"
) # pragma: no cover
self.local_storage = threading.local()
self.local_storage.storage = None
super().__init__(uri, wrap_exceptions=wrap_exceptions)
@property
def base_exceptions(
self,
) -> type[Exception] | tuple[type[Exception], ...]: # pragma: no cover
return self.dependency.MemcacheError # type: ignore[no-any-return]
def get_client(
self, module: ModuleType, hosts: list[tuple[str, int]], **kwargs: str
) -> MemcachedClientP:
"""
returns a memcached client.
:param module: the memcached module
:param hosts: list of memcached hosts
"""
return cast(
MemcachedClientP,
(
module.HashClient(hosts, **kwargs)
if len(hosts) > 1
else module.PooledClient(*hosts, **kwargs)
),
)
def call_memcached_func(
self, func: Callable[P, R], *args: P.args, **kwargs: P.kwargs
) -> R:
if "noreply" in kwargs:
argspec = inspect.getfullargspec(func)
if not ("noreply" in argspec.args or argspec.varkw):
kwargs.pop("noreply")
return func(*args, **kwargs)
@property
def storage(self) -> MemcachedClientP:
"""
lazily creates a memcached client instance using a thread local
"""
if not (hasattr(self.local_storage, "storage") and self.local_storage.storage):
dependency = get_dependency(
self.cluster_library if len(self.hosts) > 1 else self.library
)[0]
if not dependency:
raise ConfigurationError(f"Unable to import {self.cluster_library}")
self.local_storage.storage = self.client_getter(
dependency, self.hosts, **self.options
)
return cast(MemcachedClientP, self.local_storage.storage)
def get(self, key: str) -> int:
"""
:param key: the key to get the counter value for
"""
return int(self.storage.get(key, "0"))
def get_many(self, keys: Iterable[str]) -> dict[str, Any]: # type:ignore[explicit-any]
"""
Return multiple counters at once
:param keys: the keys to get the counter values for
:meta private:
"""
return self.storage.get_many(keys)
def clear(self, key: str) -> None:
"""
:param key: the key to clear rate limits for
"""
self.storage.delete(key)
def incr(
self,
key: str,
expiry: float,
amount: int = 1,
set_expiration_key: bool = True,
) -> int:
"""
increments the counter for a given rate limit key
:param key: the key to increment
        :param expiry: amount in seconds for the key to expire in
:param amount: the number to increment by
:param set_expiration_key: set the expiration key with the expiration time if needed. If set to False, the key will still expire, but memcached cannot provide the expiration time.
"""
if (
value := self.call_memcached_func(
self.storage.incr, key, amount, noreply=False
)
) is not None:
return value
else:
if not self.call_memcached_func(
self.storage.add, key, amount, ceil(expiry), noreply=False
):
return self.storage.incr(key, amount) or amount
else:
if set_expiration_key:
self.call_memcached_func(
self.storage.set,
self._expiration_key(key),
expiry + time.time(),
expire=ceil(expiry),
noreply=False,
)
return amount
def get_expiry(self, key: str) -> float:
"""
:param key: the key to get the expiry for
"""
return float(self.storage.get(self._expiration_key(key)) or time.time())
def _expiration_key(self, key: str) -> str:
"""
Return the expiration key for the given counter key.
Memcached doesn't natively return the expiration time or TTL for a given key,
so we implement the expiration time on a separate key.
"""
return key + "/expires"
def check(self) -> bool:
"""
Check if storage is healthy by calling the ``get`` command
on the key ``limiter-check``
"""
try:
self.call_memcached_func(self.storage.get, "limiter-check")
return True
except: # noqa
return False
def reset(self) -> int | None:
raise NotImplementedError
def acquire_sliding_window_entry(
self,
key: str,
limit: int,
expiry: int,
amount: int = 1,
) -> bool:
if amount > limit:
return False
now = time.time()
previous_key, current_key = self.sliding_window_keys(key, expiry, now)
previous_count, previous_ttl, current_count, _ = self._get_sliding_window_info(
previous_key, current_key, expiry, now=now
)
weighted_count = previous_count * previous_ttl / expiry + current_count
if floor(weighted_count) + amount > limit:
return False
else:
# Hit, increase the current counter.
            # If the counter doesn't exist yet, set twice the theoretical expiry.
# We don't need the expiration key as it is estimated with the timestamps directly.
current_count = self.incr(
current_key, 2 * expiry, amount=amount, set_expiration_key=False
)
            actualised_previous_ttl = max(0, previous_ttl - (time.time() - now))
weighted_count = (
previous_count * actualised_previous_ttl / expiry + current_count
)
if floor(weighted_count) > limit:
# Another hit won the race condition: revert the incrementation and refuse this hit
# Limitation: during high concurrency at the end of the window,
                # the counter is shifted and cannot be decremented, so fewer requests than expected are allowed.
self.call_memcached_func(
self.storage.decr,
current_key,
amount,
noreply=True,
)
return False
return True
def get_sliding_window(
self, key: str, expiry: int
) -> tuple[int, float, int, float]:
now = time.time()
previous_key, current_key = self.sliding_window_keys(key, expiry, now)
return self._get_sliding_window_info(previous_key, current_key, expiry, now)
def _get_sliding_window_info(
self, previous_key: str, current_key: str, expiry: int, now: float
) -> tuple[int, float, int, float]:
result = self.get_many([previous_key, current_key])
previous_count, current_count = (
int(result.get(previous_key, 0)),
int(result.get(current_key, 0)),
)
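        # TTLs are estimated from the timestamped key names rather than queried
        # from memcached: the previous window's TTL is its remaining overlap
        # with the sliding window [now - expiry, now], and current keys live
        # for 2 * expiry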
if previous_count == 0:
previous_ttl = float(0)
else:
previous_ttl = (1 - (((now - expiry) / expiry) % 1)) * expiry
current_ttl = (1 - ((now / expiry) % 1)) * expiry + expiry
return previous_count, previous_ttl, current_count, current_ttl

View File

@ -0,0 +1,253 @@
from __future__ import annotations
import bisect
import threading
import time
from collections import Counter, defaultdict
from math import floor
import limits.typing
from limits.storage.base import (
MovingWindowSupport,
SlidingWindowCounterSupport,
Storage,
TimestampedSlidingWindow,
)
class Entry:
def __init__(self, expiry: float) -> None:
self.atime = time.time()
self.expiry = self.atime + expiry
class MemoryStorage(
Storage, MovingWindowSupport, SlidingWindowCounterSupport, TimestampedSlidingWindow
):
"""
rate limit storage using :class:`collections.Counter`
as an in memory storage for fixed and sliding window strategies,
and a simple list to implement moving window strategy.
"""
STORAGE_SCHEME = ["memory"]
def __init__(self, uri: str | None = None, wrap_exceptions: bool = False, **_: str):
self.storage: limits.typing.Counter[str] = Counter()
self.locks: defaultdict[str, threading.RLock] = defaultdict(threading.RLock)
self.expirations: dict[str, float] = {}
self.events: dict[str, list[Entry]] = {}
self.timer: threading.Timer = threading.Timer(0.01, self.__expire_events)
self.timer.start()
super().__init__(uri, wrap_exceptions=wrap_exceptions, **_)
def __getstate__(self) -> dict[str, limits.typing.Any]: # type: ignore[explicit-any]
state = self.__dict__.copy()
del state["timer"]
del state["locks"]
return state
def __setstate__(self, state: dict[str, limits.typing.Any]) -> None: # type: ignore[explicit-any]
self.__dict__.update(state)
self.locks = defaultdict(threading.RLock)
self.timer = threading.Timer(0.01, self.__expire_events)
self.timer.start()
def __expire_events(self) -> None:
for key in list(self.events.keys()):
with self.locks[key]:
if events := self.events.get(key, []):
oldest = bisect.bisect_left(
events, -time.time(), key=lambda event: -event.expiry
)
self.events[key] = self.events[key][:oldest]
if not self.events.get(key, None):
self.locks.pop(key, None)
for key in list(self.expirations.keys()):
if self.expirations[key] <= time.time():
self.storage.pop(key, None)
self.expirations.pop(key, None)
self.locks.pop(key, None)
def __schedule_expiry(self) -> None:
if not self.timer.is_alive():
self.timer = threading.Timer(0.01, self.__expire_events)
self.timer.start()
@property
def base_exceptions(
self,
) -> type[Exception] | tuple[type[Exception], ...]: # pragma: no cover
return ValueError
def incr(self, key: str, expiry: float, amount: int = 1) -> int:
"""
increments the counter for a given rate limit key
:param key: the key to increment
:param expiry: amount in seconds for the key to expire in
:param amount: the number to increment by
"""
self.get(key)
self.__schedule_expiry()
with self.locks[key]:
self.storage[key] += amount
if self.storage[key] == amount:
self.expirations[key] = time.time() + expiry
return self.storage.get(key, 0)
def decr(self, key: str, amount: int = 1) -> int:
"""
decrements the counter for a given rate limit key
:param key: the key to decrement
:param amount: the number to decrement by
"""
self.get(key)
self.__schedule_expiry()
with self.locks[key]:
self.storage[key] = max(self.storage[key] - amount, 0)
return self.storage.get(key, 0)
def get(self, key: str) -> int:
"""
:param key: the key to get the counter value for
"""
if self.expirations.get(key, 0) <= time.time():
self.storage.pop(key, None)
self.expirations.pop(key, None)
self.locks.pop(key, None)
return self.storage.get(key, 0)
def clear(self, key: str) -> None:
"""
:param key: the key to clear rate limits for
"""
self.storage.pop(key, None)
self.expirations.pop(key, None)
self.events.pop(key, None)
self.locks.pop(key, None)
def acquire_entry(self, key: str, limit: int, expiry: int, amount: int = 1) -> bool:
"""
:param key: rate limit key to acquire an entry in
:param limit: amount of entries allowed
:param expiry: expiry of the entry
:param amount: the number of entries to acquire
"""
if amount > limit:
return False
self.__schedule_expiry()
with self.locks[key]:
self.events.setdefault(key, [])
timestamp = time.time()
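            # if the (limit - amount + 1)-th newest entry is still inside the
            # window, admitting `amount` more entries would exceed the limit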
try:
entry = self.events[key][limit - amount]
except IndexError:
entry = None
if entry and entry.atime >= timestamp - expiry:
return False
else:
self.events[key][:0] = [Entry(expiry)] * amount
return True
def get_expiry(self, key: str) -> float:
"""
:param key: the key to get the expiry for
"""
return self.expirations.get(key, time.time())
def get_moving_window(self, key: str, limit: int, expiry: int) -> tuple[float, int]:
"""
returns the starting point and the number of entries in the moving
window
:param key: rate limit key
:param expiry: expiry of entry
:return: (start of window, number of acquired entries)
"""
timestamp = time.time()
if events := self.events.get(key, []):
oldest = bisect.bisect_left(
events, -(timestamp - expiry), key=lambda entry: -entry.atime
)
return events[oldest - 1].atime, oldest
return timestamp, 0
def acquire_sliding_window_entry(
self,
key: str,
limit: int,
expiry: int,
amount: int = 1,
) -> bool:
if amount > limit:
return False
now = time.time()
previous_key, current_key = self.sliding_window_keys(key, expiry, now)
(
previous_count,
previous_ttl,
current_count,
_,
) = self._get_sliding_window_info(previous_key, current_key, expiry, now)
weighted_count = previous_count * previous_ttl / expiry + current_count
if floor(weighted_count) + amount > limit:
return False
else:
# Hit, increase the current counter.
            # If the counter doesn't exist yet, set twice the theoretical expiry.
current_count = self.incr(current_key, 2 * expiry, amount=amount)
weighted_count = previous_count * previous_ttl / expiry + current_count
if floor(weighted_count) > limit:
# Another hit won the race condition: revert the incrementation and refuse this hit
# Limitation: during high concurrency at the end of the window,
                # the counter is shifted and cannot be decremented, so fewer requests than expected are allowed.
self.decr(current_key, amount)
return False
return True
def _get_sliding_window_info(
self,
previous_key: str,
current_key: str,
expiry: int,
now: float,
) -> tuple[int, float, int, float]:
previous_count = self.get(previous_key)
current_count = self.get(current_key)
if previous_count == 0:
previous_ttl = float(0)
else:
previous_ttl = (1 - (((now - expiry) / expiry) % 1)) * expiry
current_ttl = (1 - ((now / expiry) % 1)) * expiry + expiry
return previous_count, previous_ttl, current_count, current_ttl
def get_sliding_window(
self, key: str, expiry: int
) -> tuple[int, float, int, float]:
now = time.time()
previous_key, current_key = self.sliding_window_keys(key, expiry, now)
return self._get_sliding_window_info(previous_key, current_key, expiry, now)
def check(self) -> bool:
"""
check if storage is healthy
"""
return True
def reset(self) -> int | None:
num_items = max(len(self.storage), len(self.events))
self.storage.clear()
self.expirations.clear()
self.events.clear()
self.locks.clear()
return num_items

View File

@ -0,0 +1,489 @@
from __future__ import annotations
import datetime
import time
from abc import ABC, abstractmethod
from deprecated.sphinx import versionadded, versionchanged
from limits.typing import (
MongoClient,
MongoCollection,
MongoDatabase,
cast,
)
from ..util import get_dependency
from .base import MovingWindowSupport, SlidingWindowCounterSupport, Storage
class MongoDBStorageBase(
Storage, MovingWindowSupport, SlidingWindowCounterSupport, ABC
):
"""
Rate limit storage with MongoDB as backend.
Depends on :pypi:`pymongo`.
"""
DEPENDENCIES = ["pymongo"]
def __init__(
self,
uri: str,
database_name: str = "limits",
counter_collection_name: str = "counters",
window_collection_name: str = "windows",
wrap_exceptions: bool = False,
**options: int | str | bool,
) -> None:
"""
:param uri: uri of the form ``mongodb://[user:password]@host:port?...``,
This uri is passed directly to :class:`~pymongo.mongo_client.MongoClient`
:param database_name: The database to use for storing the rate limit
collections.
:param counter_collection_name: The collection name to use for individual counters
used in fixed window strategies
:param window_collection_name: The collection name to use for sliding & moving window
storage
:param wrap_exceptions: Whether to wrap storage exceptions in
:exc:`limits.errors.StorageError` before raising it.
:param options: all remaining keyword arguments are passed to the
constructor of :class:`~pymongo.mongo_client.MongoClient`
:raise ConfigurationError: when the :pypi:`pymongo` library is not available
"""
super().__init__(uri, wrap_exceptions=wrap_exceptions, **options)
self._database_name = database_name
self._collection_mapping = {
"counters": counter_collection_name,
"windows": window_collection_name,
}
self.lib = self.dependencies["pymongo"].module
self.lib_errors, _ = get_dependency("pymongo.errors")
self._storage_uri = uri
self._storage_options = options
self._storage: MongoClient | None = None
@property
def storage(self) -> MongoClient:
if self._storage is None:
self._storage = self._init_mongo_client(
self._storage_uri, **self._storage_options
)
self.__initialize_database()
return self._storage
@property
def _database(self) -> MongoDatabase:
return self.storage[self._database_name]
@property
def counters(self) -> MongoCollection:
return self._database[self._collection_mapping["counters"]]
@property
def windows(self) -> MongoCollection:
return self._database[self._collection_mapping["windows"]]
@abstractmethod
def _init_mongo_client(
self, uri: str | None, **options: int | str | bool
) -> MongoClient:
raise NotImplementedError()
@property
def base_exceptions(
self,
) -> type[Exception] | tuple[type[Exception], ...]: # pragma: no cover
return self.lib_errors.PyMongoError # type: ignore
def __initialize_database(self) -> None:
self.counters.create_index("expireAt", expireAfterSeconds=0)
self.windows.create_index("expireAt", expireAfterSeconds=0)
def reset(self) -> int | None:
"""
Delete all rate limit keys in the rate limit collections (counters, windows)
"""
num_keys = self.counters.count_documents({}) + self.windows.count_documents({})
self.counters.drop()
self.windows.drop()
return int(num_keys)
def clear(self, key: str) -> None:
"""
:param key: the key to clear rate limits for
"""
self.counters.find_one_and_delete({"_id": key})
self.windows.find_one_and_delete({"_id": key})
def get_expiry(self, key: str) -> float:
"""
:param key: the key to get the expiry for
"""
counter = self.counters.find_one({"_id": key})
return (
(counter["expireAt"] if counter else datetime.datetime.now())
.replace(tzinfo=datetime.timezone.utc)
.timestamp()
)
def get(self, key: str) -> int:
"""
:param key: the key to get the counter value for
"""
counter = self.counters.find_one(
{
"_id": key,
"expireAt": {"$gte": datetime.datetime.now(datetime.timezone.utc)},
},
projection=["count"],
)
        return counter["count"] if counter else 0
def incr(self, key: str, expiry: int, amount: int = 1) -> int:
"""
increments the counter for a given rate limit key
:param key: the key to increment
:param expiry: amount in seconds for the key to expire in
:param amount: the number to increment by
"""
expiration = datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(
seconds=expiry
)
return int(
self.counters.find_one_and_update(
{"_id": key},
[
{
"$set": {
"count": {
"$cond": {
"if": {"$lt": ["$expireAt", "$$NOW"]},
"then": amount,
"else": {"$add": ["$count", amount]},
}
},
"expireAt": {
"$cond": {
"if": {"$lt": ["$expireAt", "$$NOW"]},
"then": expiration,
"else": "$expireAt",
}
},
}
},
],
upsert=True,
projection=["count"],
return_document=self.lib.ReturnDocument.AFTER,
)["count"]
)
def check(self) -> bool:
"""
Check if storage is healthy by calling :meth:`pymongo.mongo_client.MongoClient.server_info`
"""
try:
self.storage.server_info()
return True
except: # noqa: E722
return False
def get_moving_window(self, key: str, limit: int, expiry: int) -> tuple[float, int]:
"""
returns the starting point and the number of entries in the moving
window
:param key: rate limit key
:param expiry: expiry of entry
:return: (start of window, number of acquired entries)
"""
timestamp = time.time()
if result := list(
self.windows.aggregate(
[
{"$match": {"_id": key}},
{
"$project": {
"filteredEntries": {
"$filter": {
"input": "$entries",
"as": "entry",
"cond": {"$gte": ["$$entry", timestamp - expiry]},
}
}
}
},
{
"$project": {
"min": {"$min": "$filteredEntries"},
"count": {"$size": "$filteredEntries"},
}
},
]
)
):
return result[0]["min"], result[0]["count"]
return timestamp, 0
def acquire_entry(self, key: str, limit: int, expiry: int, amount: int = 1) -> bool:
"""
:param key: rate limit key to acquire an entry in
:param limit: amount of entries allowed
:param expiry: expiry of the entry
:param amount: the number of entries to acquire
"""
if amount > limit:
return False
timestamp = time.time()
try:
updates: dict[
str,
dict[str, datetime.datetime | dict[str, list[float] | int]],
] = {
"$push": {
"entries": {
"$each": [timestamp] * amount,
"$position": 0,
"$slice": limit,
}
},
"$set": {
"expireAt": (
datetime.datetime.now(datetime.timezone.utc)
+ datetime.timedelta(seconds=expiry)
)
},
}
self.windows.update_one(
{
"_id": key,
f"entries.{limit - amount}": {"$not": {"$gte": timestamp - expiry}},
},
updates,
upsert=True,
)
return True
except self.lib.errors.DuplicateKeyError:
return False
def get_sliding_window(
self, key: str, expiry: int
) -> tuple[int, float, int, float]:
expiry_ms = expiry * 1000
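        # when less than one full window of TTL remains, the current window has
        # elapsed: currentCount is shifted into previousCount, reset to 0, and
        # expireAt is pushed forward by one window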
if result := self.windows.find_one_and_update(
{"_id": key},
[
{
"$set": {
"previousCount": {
"$cond": {
"if": {
"$lte": [
{"$subtract": ["$expireAt", "$$NOW"]},
expiry_ms,
]
},
"then": {"$ifNull": ["$currentCount", 0]},
"else": {"$ifNull": ["$previousCount", 0]},
}
},
"currentCount": {
"$cond": {
"if": {
"$lte": [
{"$subtract": ["$expireAt", "$$NOW"]},
expiry_ms,
]
},
"then": 0,
"else": {"$ifNull": ["$currentCount", 0]},
}
},
"expireAt": {
"$cond": {
"if": {
"$lte": [
{"$subtract": ["$expireAt", "$$NOW"]},
expiry_ms,
]
},
"then": {
"$add": ["$expireAt", expiry_ms],
},
"else": "$expireAt",
}
},
}
}
],
return_document=self.lib.ReturnDocument.AFTER,
projection=["currentCount", "previousCount", "expireAt"],
):
expires_at = (
(result["expireAt"].replace(tzinfo=datetime.timezone.utc).timestamp())
if result.get("expireAt")
else time.time()
)
current_ttl = max(0, expires_at - time.time())
prev_ttl = max(0, current_ttl - expiry if result["previousCount"] else 0)
return (
result["previousCount"],
prev_ttl,
result["currentCount"],
current_ttl,
)
return 0, 0.0, 0, 0.0
def acquire_sliding_window_entry(
self, key: str, limit: int, expiry: int, amount: int = 1
) -> bool:
expiry_ms = expiry * 1000
result = self.windows.find_one_and_update(
{"_id": key},
[
{
"$set": {
"previousCount": {
"$cond": {
"if": {
"$lte": [
{"$subtract": ["$expireAt", "$$NOW"]},
expiry_ms,
]
},
"then": {"$ifNull": ["$currentCount", 0]},
"else": {"$ifNull": ["$previousCount", 0]},
}
},
}
},
{
"$set": {
"currentCount": {
"$cond": {
"if": {
"$lte": [
{"$subtract": ["$expireAt", "$$NOW"]},
expiry_ms,
]
},
"then": 0,
"else": {"$ifNull": ["$currentCount", 0]},
}
},
"expireAt": {
"$cond": {
"if": {
"$lte": [
{"$subtract": ["$expireAt", "$$NOW"]},
expiry_ms,
]
},
"then": {
"$cond": {
"if": {"$gt": ["$expireAt", 0]},
"then": {"$add": ["$expireAt", expiry_ms]},
"else": {"$add": ["$$NOW", 2 * expiry_ms]},
}
},
"else": "$expireAt",
}
},
}
},
{
"$set": {
"curWeightedCount": {
"$floor": {
"$add": [
{
"$multiply": [
"$previousCount",
{
"$divide": [
{
"$max": [
0,
{
"$subtract": [
"$expireAt",
{
"$add": [
"$$NOW",
expiry_ms,
]
},
]
},
]
},
expiry_ms,
]
},
]
},
"$currentCount",
]
}
}
}
},
{
"$set": {
"currentCount": {
"$cond": {
"if": {
"$lte": [
{"$add": ["$curWeightedCount", amount]},
limit,
]
},
"then": {"$add": ["$currentCount", amount]},
"else": "$currentCount",
}
}
}
},
{
"$set": {
"_acquired": {
"$lte": [{"$add": ["$curWeightedCount", amount]}, limit]
}
}
},
{"$unset": ["curWeightedCount"]},
],
return_document=self.lib.ReturnDocument.AFTER,
upsert=True,
)
return cast(bool, result["_acquired"])
def __del__(self) -> None:
        # close only a client that was actually created; touching the lazy
        # `storage` property here would construct one just to close it
        if self._storage:
            self._storage.close()
@versionadded(version="2.1")
@versionchanged(
version="3.14.0",
reason="Added option to select custom collection names for windows & counters",
)
class MongoDBStorage(MongoDBStorageBase):
STORAGE_SCHEME = ["mongodb", "mongodb+srv"]
def _init_mongo_client(
self, uri: str | None, **options: int | str | bool
) -> MongoClient:
return cast(MongoClient, self.lib.MongoClient(uri, **options))

View File

@ -0,0 +1,308 @@
from __future__ import annotations
import time
from typing import TYPE_CHECKING, cast
from deprecated.sphinx import versionchanged
from packaging.version import Version
from limits.typing import Literal, RedisClient
from ..util import get_package_data
from .base import MovingWindowSupport, SlidingWindowCounterSupport, Storage
if TYPE_CHECKING:
import redis
@versionchanged(
version="4.3",
reason=(
"Added support for using the redis client from :pypi:`valkey`"
" if :paramref:`uri` has the ``valkey://`` schema"
),
)
class RedisStorage(Storage, MovingWindowSupport, SlidingWindowCounterSupport):
"""
Rate limit storage with redis as backend.
Depends on :pypi:`redis` (or :pypi:`valkey` if :paramref:`uri` starts with
``valkey://``)
"""
STORAGE_SCHEME = [
"redis",
"rediss",
"redis+unix",
"valkey",
"valkeys",
"valkey+unix",
]
"""The storage scheme for redis"""
DEPENDENCIES = {"redis": Version("3.0"), "valkey": Version("6.0")}
RES_DIR = "resources/redis/lua_scripts"
SCRIPT_MOVING_WINDOW = get_package_data(f"{RES_DIR}/moving_window.lua")
SCRIPT_ACQUIRE_MOVING_WINDOW = get_package_data(
f"{RES_DIR}/acquire_moving_window.lua"
)
SCRIPT_CLEAR_KEYS = get_package_data(f"{RES_DIR}/clear_keys.lua")
SCRIPT_INCR_EXPIRE = get_package_data(f"{RES_DIR}/incr_expire.lua")
SCRIPT_SLIDING_WINDOW = get_package_data(f"{RES_DIR}/sliding_window.lua")
SCRIPT_ACQUIRE_SLIDING_WINDOW = get_package_data(
f"{RES_DIR}/acquire_sliding_window.lua"
)
lua_moving_window: redis.commands.core.Script
lua_acquire_moving_window: redis.commands.core.Script
lua_sliding_window: redis.commands.core.Script
lua_acquire_sliding_window: redis.commands.core.Script
PREFIX = "LIMITS"
target_server: Literal["redis", "valkey"]
def __init__(
self,
uri: str,
connection_pool: redis.connection.ConnectionPool | None = None,
wrap_exceptions: bool = False,
**options: float | str | bool,
) -> None:
"""
:param uri: uri of the form ``redis://[:password]@host:port``,
``redis://[:password]@host:port/db``,
``rediss://[:password]@host:port``, ``redis+unix:///path/to/sock`` etc.
This uri is passed directly to :func:`redis.from_url` except for the
case of ``redis+unix://`` where it is replaced with ``unix://``.
If the uri scheme is ``valkey`` the implementation used will be from
:pypi:`valkey`.
:param connection_pool: if provided, the redis client is initialized with
the connection pool and any other params passed as :paramref:`options`
:param wrap_exceptions: Whether to wrap storage exceptions in
:exc:`limits.errors.StorageError` before raising it.
:param options: all remaining keyword arguments are passed
directly to the constructor of :class:`redis.Redis`
:raise ConfigurationError: when the :pypi:`redis` library is not available
"""
super().__init__(uri, wrap_exceptions=wrap_exceptions, **options)
self.target_server = "valkey" if uri.startswith("valkey") else "redis"
self.dependency = self.dependencies[self.target_server].module
uri = uri.replace(f"{self.target_server}+unix", "unix")
if not connection_pool:
self.storage = self.dependency.from_url(uri, **options)
else:
if self.target_server == "redis":
self.storage = self.dependency.Redis(
connection_pool=connection_pool, **options
)
else:
self.storage = self.dependency.Valkey(
connection_pool=connection_pool, **options
)
self.initialize_storage(uri)
@property
def base_exceptions(
self,
) -> type[Exception] | tuple[type[Exception], ...]: # pragma: no cover
return ( # type: ignore[no-any-return]
self.dependency.RedisError
if self.target_server == "redis"
else self.dependency.ValkeyError
)
def initialize_storage(self, _uri: str) -> None:
self.lua_moving_window = self.get_connection().register_script(
self.SCRIPT_MOVING_WINDOW
)
self.lua_acquire_moving_window = self.get_connection().register_script(
self.SCRIPT_ACQUIRE_MOVING_WINDOW
)
self.lua_clear_keys = self.get_connection().register_script(
self.SCRIPT_CLEAR_KEYS
)
self.lua_incr_expire = self.get_connection().register_script(
self.SCRIPT_INCR_EXPIRE
)
self.lua_sliding_window = self.get_connection().register_script(
self.SCRIPT_SLIDING_WINDOW
)
self.lua_acquire_sliding_window = self.get_connection().register_script(
self.SCRIPT_ACQUIRE_SLIDING_WINDOW
)
def get_connection(self, readonly: bool = False) -> RedisClient:
return cast(RedisClient, self.storage)
def _current_window_key(self, key: str) -> str:
"""
Return the current window's storage key (Sliding window strategy)
Contrary to other strategies that have one key per rate limit item,
        this strategy has two keys per rate limit item that must be on the same machine.
To keep the current key and the previous key on the same Redis cluster node,
curly braces are added.
Eg: "{constructed_key}"
"""
return f"{{{key}}}"
def _previous_window_key(self, key: str) -> str:
"""
Return the previous window's storage key (Sliding window strategy).
        Curly braces are added around the common pattern with the current window's key,
so the current and the previous key are stored on the same Redis cluster node.
Eg: "{constructed_key}/-1"
"""
return f"{self._current_window_key(key)}/-1"
def prefixed_key(self, key: str) -> str:
return f"{self.PREFIX}:{key}"
def get_moving_window(self, key: str, limit: int, expiry: int) -> tuple[float, int]:
"""
returns the starting point and the number of entries in the moving
window
        :param key: rate limit key
        :param limit: amount of entries allowed
        :param expiry: expiry of entry
:return: (start of window, number of acquired entries)
"""
key = self.prefixed_key(key)
timestamp = time.time()
if window := self.lua_moving_window([key], [timestamp - expiry, limit]):
return float(window[0]), window[1]
return timestamp, 0
def get_sliding_window(
self, key: str, expiry: int
) -> tuple[int, float, int, float]:
previous_key = self.prefixed_key(self._previous_window_key(key))
current_key = self.prefixed_key(self._current_window_key(key))
if window := self.lua_sliding_window([previous_key, current_key], [expiry]):
return (
int(window[0] or 0),
max(0, float(window[1] or 0)) / 1000,
int(window[2] or 0),
max(0, float(window[3] or 0)) / 1000,
)
return 0, 0.0, 0, 0.0
def incr(
self,
key: str,
expiry: int,
amount: int = 1,
) -> int:
"""
increments the counter for a given rate limit key
:param key: the key to increment
:param expiry: amount in seconds for the key to expire in
:param amount: the number to increment by
"""
key = self.prefixed_key(key)
return int(self.lua_incr_expire([key], [expiry, amount]))
def get(self, key: str) -> int:
"""
:param key: the key to get the counter value for
"""
key = self.prefixed_key(key)
return int(self.get_connection(True).get(key) or 0)
def clear(self, key: str) -> None:
"""
:param key: the key to clear rate limits for
"""
key = self.prefixed_key(key)
self.get_connection().delete(key)
def acquire_entry(
self,
key: str,
limit: int,
expiry: int,
amount: int = 1,
) -> bool:
"""
:param key: rate limit key to acquire an entry in
:param limit: amount of entries allowed
:param expiry: expiry of the entry
:param amount: the number of entries to acquire
"""
key = self.prefixed_key(key)
timestamp = time.time()
acquired = self.lua_acquire_moving_window(
[key], [timestamp, limit, expiry, amount]
)
return bool(acquired)
def acquire_sliding_window_entry(
self,
key: str,
limit: int,
expiry: int,
amount: int = 1,
) -> bool:
"""
Acquire an entry. Shift the current window to the previous window if it expired.
:param key: rate limit key to acquire an entry in
:param limit: amount of entries allowed
:param expiry: expiry of the entry
:param amount: the number of entries to acquire
"""
previous_key = self.prefixed_key(self._previous_window_key(key))
current_key = self.prefixed_key(self._current_window_key(key))
acquired = self.lua_acquire_sliding_window(
[previous_key, current_key], [limit, expiry, amount]
)
return bool(acquired)
def get_expiry(self, key: str) -> float:
"""
:param key: the key to get the expiry for
"""
key = self.prefixed_key(key)
return max(self.get_connection(True).ttl(key), 0) + time.time()
def check(self) -> bool:
"""
check if storage is healthy
"""
try:
return self.get_connection().ping()
except: # noqa
return False
def reset(self) -> int | None:
"""
This function calls a Lua Script to delete keys prefixed with
``self.PREFIX`` in blocks of 5000.
.. warning::
This operation was designed to be fast, but was not tested
            on a large production system. Be careful with its usage as it
could be slow on very large data sets.
"""
prefix = self.prefixed_key("*")
return int(self.lua_clear_keys([prefix]))
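# A minimal usage sketch (assumptions: a Redis server reachable at
# redis://localhost:6379/0; the namespace and identifier are illustrative).
from limits import parse
from limits.storage import storage_from_string
from limits.strategies import MovingWindowRateLimiter

backend = storage_from_string("redis://localhost:6379/0")
limiter = MovingWindowRateLimiter(backend)
two_per_minute = parse("2/minute")
allowed = limiter.hit(two_per_minute, "some_namespace", "client-1")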

View File

@ -0,0 +1,125 @@
from __future__ import annotations
import urllib
from deprecated.sphinx import versionchanged
from packaging.version import Version
from limits.storage.redis import RedisStorage
@versionchanged(
version="3.14.0",
reason="""
Dropped support for the :pypi:`redis-py-cluster` library
which has been abandoned/deprecated.
""",
)
@versionchanged(
version="2.5.0",
reason="""
Cluster support was provided by the :pypi:`redis-py-cluster` library
which has been absorbed into the official :pypi:`redis` client. By
default the :class:`redis.cluster.RedisCluster` client will be used
however if the version of the package is lower than ``4.2.0`` the implementation
    will fall back to trying to use :class:`rediscluster.RedisCluster`.
""",
)
@versionchanged(
version="4.3",
reason=(
"Added support for using the redis client from :pypi:`valkey`"
" if :paramref:`uri` has the ``valkey+cluster://`` schema"
),
)
class RedisClusterStorage(RedisStorage):
"""
Rate limit storage with redis cluster as backend
Depends on :pypi:`redis` (or :pypi:`valkey` if :paramref:`uri`
starts with ``valkey+cluster://``).
"""
STORAGE_SCHEME = ["redis+cluster", "valkey+cluster"]
"""The storage scheme for redis cluster"""
DEFAULT_OPTIONS: dict[str, float | str | bool] = {
"max_connections": 1000,
}
"Default options passed to the :class:`~redis.cluster.RedisCluster`"
DEPENDENCIES = {
"redis": Version("4.2.0"),
"valkey": Version("6.0"),
}
def __init__(
self,
uri: str,
wrap_exceptions: bool = False,
**options: float | str | bool,
) -> None:
"""
:param uri: url of the form
``redis+cluster://[:password]@host:port,host:port``
If the uri scheme is ``valkey+cluster`` the implementation used will be from
:pypi:`valkey`.
:param wrap_exceptions: Whether to wrap storage exceptions in
:exc:`limits.errors.StorageError` before raising it.
:param options: all remaining keyword arguments are passed
directly to the constructor of :class:`redis.cluster.RedisCluster`
:raise ConfigurationError: when the :pypi:`redis` library is not
available or if the redis cluster cannot be reached.
"""
parsed = urllib.parse.urlparse(uri)
parsed_auth: dict[str, float | str | bool] = {}
if parsed.username:
parsed_auth["username"] = parsed.username
if parsed.password:
parsed_auth["password"] = parsed.password
sep = parsed.netloc.find("@") + 1
cluster_hosts = []
for loc in parsed.netloc[sep:].split(","):
host, port = loc.split(":")
cluster_hosts.append((host, int(port)))
self.storage = None
self.target_server = "valkey" if uri.startswith("valkey") else "redis"
merged_options = {**self.DEFAULT_OPTIONS, **parsed_auth, **options}
self.dependency = self.dependencies[self.target_server].module
startup_nodes = [self.dependency.cluster.ClusterNode(*c) for c in cluster_hosts]
if self.target_server == "redis":
self.storage = self.dependency.cluster.RedisCluster(
startup_nodes=startup_nodes, **merged_options
)
else:
self.storage = self.dependency.cluster.ValkeyCluster(
startup_nodes=startup_nodes, **merged_options
)
assert self.storage
self.initialize_storage(uri)
super(RedisStorage, self).__init__(uri, wrap_exceptions, **options)
def reset(self) -> int | None:
"""
Redis Clusters are sharded and deleting across shards
can't be done atomically. Because of this, this reset loops over all
keys that are prefixed with ``self.PREFIX`` and calls delete on them,
one at a time.
.. warning::
This operation was not tested with extremely large data sets.
            On a large production system, care should be taken with its
            usage as it could be slow on very large data sets."""
prefix = self.prefixed_key("*")
count = 0
for primary in self.storage.get_primaries():
node = self.storage.get_redis_connection(primary)
keys = node.keys(prefix)
count += sum([node.delete(k.decode("utf-8")) for k in keys])
return count

View File

@ -0,0 +1,120 @@
from __future__ import annotations
import urllib.parse
from typing import TYPE_CHECKING
from deprecated.sphinx import versionchanged
from packaging.version import Version
from limits.errors import ConfigurationError
from limits.storage.redis import RedisStorage
from limits.typing import RedisClient
if TYPE_CHECKING:
pass
@versionchanged(
version="4.3",
reason=(
"Added support for using the redis client from :pypi:`valkey`"
" if :paramref:`uri` has the ``valkey+sentinel://`` schema"
),
)
class RedisSentinelStorage(RedisStorage):
"""
Rate limit storage with redis sentinel as backend
Depends on :pypi:`redis` package (or :pypi:`valkey` if :paramref:`uri` starts with
``valkey+sentinel://``)
"""
STORAGE_SCHEME = ["redis+sentinel", "valkey+sentinel"]
"""The storage scheme for redis accessed via a redis sentinel installation"""
DEPENDENCIES = {
"redis": Version("3.0"),
"redis.sentinel": Version("3.0"),
"valkey": Version("6.0"),
"valkey.sentinel": Version("6.0"),
}
def __init__(
self,
uri: str,
service_name: str | None = None,
use_replicas: bool = True,
sentinel_kwargs: dict[str, float | str | bool] | None = None,
wrap_exceptions: bool = False,
**options: float | str | bool,
) -> None:
"""
:param uri: url of the form
``redis+sentinel://host:port,host:port/service_name``
If the uri scheme is ``valkey+sentinel`` the implementation used will be from
:pypi:`valkey`.
:param service_name: sentinel service name
(if not provided in :attr:`uri`)
:param use_replicas: Whether to use replicas for read only operations
:param sentinel_kwargs: kwargs to pass as
:attr:`sentinel_kwargs` to :class:`redis.sentinel.Sentinel`
:param wrap_exceptions: Whether to wrap storage exceptions in
:exc:`limits.errors.StorageError` before raising it.
:param options: all remaining keyword arguments are passed
directly to the constructor of :class:`redis.sentinel.Sentinel`
:raise ConfigurationError: when the redis library is not available
or if the redis master host cannot be pinged.
"""
super(RedisStorage, self).__init__(
uri, wrap_exceptions=wrap_exceptions, **options
)
parsed = urllib.parse.urlparse(uri)
sentinel_configuration = []
sentinel_options = sentinel_kwargs.copy() if sentinel_kwargs else {}
parsed_auth: dict[str, float | str | bool] = {}
if parsed.username:
parsed_auth["username"] = parsed.username
if parsed.password:
parsed_auth["password"] = parsed.password
sep = parsed.netloc.find("@") + 1
for loc in parsed.netloc[sep:].split(","):
host, port = loc.split(":")
sentinel_configuration.append((host, int(port)))
self.service_name = (
parsed.path.replace("/", "") if parsed.path else service_name
)
if self.service_name is None:
raise ConfigurationError("'service_name' not provided")
self.target_server = "valkey" if uri.startswith("valkey") else "redis"
sentinel_dep = self.dependencies[f"{self.target_server}.sentinel"].module
self.sentinel = sentinel_dep.Sentinel(
sentinel_configuration,
sentinel_kwargs={**parsed_auth, **sentinel_options},
**{**parsed_auth, **options},
)
self.storage: RedisClient = self.sentinel.master_for(self.service_name)
self.storage_slave: RedisClient = self.sentinel.slave_for(self.service_name)
self.use_replicas = use_replicas
self.initialize_storage(uri)
@property
def base_exceptions(
self,
) -> type[Exception] | tuple[type[Exception], ...]: # pragma: no cover
return ( # type: ignore[no-any-return]
self.dependencies["redis"].module.RedisError
if self.target_server == "redis"
else self.dependencies["valkey"].module.ValkeyError
)
def get_connection(self, readonly: bool = False) -> RedisClient:
return self.storage_slave if (readonly and self.use_replicas) else self.storage
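# A minimal construction sketch (assumptions: sentinels listening on the hosts
# below, monitoring a service named "mymaster"; all values are illustrative).
from limits.storage import RedisSentinelStorage

storage = RedisSentinelStorage(
    "redis+sentinel://localhost:26379,localhost:26380/mymaster",
    use_replicas=True,  # read-only calls such as get()/ttl() use replicas
)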

View File

@ -0,0 +1,24 @@
from __future__ import annotations
from abc import ABCMeta
SCHEMES: dict[str, StorageRegistry] = {}
class StorageRegistry(ABCMeta):
def __new__(
mcs, name: str, bases: tuple[type, ...], dct: dict[str, str | list[str]]
) -> StorageRegistry:
storage_scheme = dct.get("STORAGE_SCHEME", None)
cls = super().__new__(mcs, name, bases, dct)
if storage_scheme:
if isinstance(storage_scheme, str): # noqa
schemes = [storage_scheme]
else:
schemes = storage_scheme
for scheme in schemes:
SCHEMES[scheme] = cls
return cls
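# A small sketch of the registration side effect (assumption: MemoryStorage
# declares STORAGE_SCHEME = ["memory"], as elsewhere in this package).
from limits.storage import MemoryStorage
from limits.storage.registry import SCHEMES

assert SCHEMES["memory"] is MemoryStorage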

View File

@ -0,0 +1,298 @@
"""
Rate limiting strategies
"""
from __future__ import annotations
import time
from abc import ABCMeta, abstractmethod
from math import floor, inf
from deprecated.sphinx import versionadded
from limits.storage.base import SlidingWindowCounterSupport
from .limits import RateLimitItem
from .storage import MovingWindowSupport, Storage, StorageTypes
from .typing import cast
from .util import WindowStats
class RateLimiter(metaclass=ABCMeta):
def __init__(self, storage: StorageTypes):
assert isinstance(storage, Storage)
self.storage: Storage = storage
@abstractmethod
def hit(self, item: RateLimitItem, *identifiers: str, cost: int = 1) -> bool:
"""
Consume the rate limit
:param item: The rate limit item
:param identifiers: variable list of strings to uniquely identify this
instance of the limit
:param cost: The cost of this hit, default 1
"""
raise NotImplementedError
@abstractmethod
def test(self, item: RateLimitItem, *identifiers: str, cost: int = 1) -> bool:
"""
Check the rate limit without consuming from it.
:param item: The rate limit item
:param identifiers: variable list of strings to uniquely identify this
instance of the limit
:param cost: The expected cost to be consumed, default 1
"""
raise NotImplementedError
@abstractmethod
def get_window_stats(self, item: RateLimitItem, *identifiers: str) -> WindowStats:
"""
Query the reset time and remaining amount for the limit
:param item: The rate limit item
:param identifiers: variable list of strings to uniquely identify this
instance of the limit
:return: (reset time, remaining)
"""
raise NotImplementedError
def clear(self, item: RateLimitItem, *identifiers: str) -> None:
return self.storage.clear(item.key_for(*identifiers))
class MovingWindowRateLimiter(RateLimiter):
"""
Reference: :ref:`strategies:moving window`
"""
def __init__(self, storage: StorageTypes):
if not (
hasattr(storage, "acquire_entry") or hasattr(storage, "get_moving_window")
):
raise NotImplementedError(
"MovingWindowRateLimiting is not implemented for storage "
f"of type {storage.__class__}"
)
super().__init__(storage)
def hit(self, item: RateLimitItem, *identifiers: str, cost: int = 1) -> bool:
"""
Consume the rate limit
:param item: The rate limit item
:param identifiers: variable list of strings to uniquely identify this
instance of the limit
:param cost: The cost of this hit, default 1
        :return: ``True`` if the entry was acquired within the limit
"""
return cast(MovingWindowSupport, self.storage).acquire_entry(
item.key_for(*identifiers), item.amount, item.get_expiry(), amount=cost
)
def test(self, item: RateLimitItem, *identifiers: str, cost: int = 1) -> bool:
"""
Check if the rate limit can be consumed
:param item: The rate limit item
:param identifiers: variable list of strings to uniquely identify this
instance of the limit
:param cost: The expected cost to be consumed, default 1
"""
return (
cast(MovingWindowSupport, self.storage).get_moving_window(
item.key_for(*identifiers),
item.amount,
item.get_expiry(),
)[1]
<= item.amount - cost
)
def get_window_stats(self, item: RateLimitItem, *identifiers: str) -> WindowStats:
"""
        Query the reset time and remaining amount for the limit.
:param item: The rate limit item
:param identifiers: variable list of strings to uniquely identify this
instance of the limit
:return: tuple (reset time, remaining)
"""
window_start, window_items = cast(
MovingWindowSupport, self.storage
).get_moving_window(item.key_for(*identifiers), item.amount, item.get_expiry())
reset = window_start + item.get_expiry()
return WindowStats(reset, item.amount - window_items)
class FixedWindowRateLimiter(RateLimiter):
"""
Reference: :ref:`strategies:fixed window`
"""
def hit(self, item: RateLimitItem, *identifiers: str, cost: int = 1) -> bool:
"""
Consume the rate limit
:param item: The rate limit item
:param identifiers: variable list of strings to uniquely identify this
instance of the limit
:param cost: The cost of this hit, default 1
"""
return (
self.storage.incr(
item.key_for(*identifiers),
item.get_expiry(),
amount=cost,
)
<= item.amount
)
def test(self, item: RateLimitItem, *identifiers: str, cost: int = 1) -> bool:
"""
Check if the rate limit can be consumed
:param item: The rate limit item
:param identifiers: variable list of strings to uniquely identify this
instance of the limit
:param cost: The expected cost to be consumed, default 1
"""
return self.storage.get(item.key_for(*identifiers)) < item.amount - cost + 1
def get_window_stats(self, item: RateLimitItem, *identifiers: str) -> WindowStats:
"""
Query the reset time and remaining amount for the limit
:param item: The rate limit item
:param identifiers: variable list of strings to uniquely identify this
instance of the limit
:return: (reset time, remaining)
"""
remaining = max(0, item.amount - self.storage.get(item.key_for(*identifiers)))
reset = self.storage.get_expiry(item.key_for(*identifiers))
return WindowStats(reset, remaining)
@versionadded(version="4.1")
class SlidingWindowCounterRateLimiter(RateLimiter):
"""
Reference: :ref:`strategies:sliding window counter`
"""
def __init__(self, storage: StorageTypes):
if not hasattr(storage, "get_sliding_window") or not hasattr(
storage, "acquire_sliding_window_entry"
):
raise NotImplementedError(
"SlidingWindowCounterRateLimiting is not implemented for storage "
f"of type {storage.__class__}"
)
super().__init__(storage)
def _weighted_count(
self,
item: RateLimitItem,
previous_count: int,
previous_expires_in: float,
current_count: int,
) -> float:
"""
        Return the approximate count, computed by weighting the previous window's
        count by its remaining time and adding the current window's count.
"""
return previous_count * previous_expires_in / item.get_expiry() + current_count
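    # Worked example: for a "10 per minute" limit with 6 hits in a previous
    # window expiring in 30s and 2 hits in the current window, the weighted
    # count is 6 * 30 / 60 + 2 = 5.0, leaving room for 5 more hits.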
def hit(self, item: RateLimitItem, *identifiers: str, cost: int = 1) -> bool:
"""
Consume the rate limit
:param item: The rate limit item
:param identifiers: variable list of strings to uniquely identify this
instance of the limit
:param cost: The cost of this hit, default 1
"""
return cast(
SlidingWindowCounterSupport, self.storage
).acquire_sliding_window_entry(
item.key_for(*identifiers),
item.amount,
item.get_expiry(),
cost,
)
def test(self, item: RateLimitItem, *identifiers: str, cost: int = 1) -> bool:
"""
Check if the rate limit can be consumed
:param item: The rate limit item
:param identifiers: variable list of strings to uniquely identify this
instance of the limit
:param cost: The expected cost to be consumed, default 1
"""
previous_count, previous_expires_in, current_count, _ = cast(
SlidingWindowCounterSupport, self.storage
).get_sliding_window(item.key_for(*identifiers), item.get_expiry())
return (
self._weighted_count(
item, previous_count, previous_expires_in, current_count
)
< item.amount - cost + 1
)
def get_window_stats(self, item: RateLimitItem, *identifiers: str) -> WindowStats:
"""
Query the reset time and remaining amount for the limit.
:param item: The rate limit item
:param identifiers: variable list of strings to uniquely identify this
instance of the limit
:return: WindowStats(reset time, remaining)
"""
previous_count, previous_expires_in, current_count, current_expires_in = cast(
SlidingWindowCounterSupport, self.storage
).get_sliding_window(item.key_for(*identifiers), item.get_expiry())
remaining = max(
0,
item.amount
- floor(
self._weighted_count(
item, previous_count, previous_expires_in, current_count
)
),
)
now = time.time()
if not (previous_count or current_count):
return WindowStats(now, remaining)
expiry = item.get_expiry()
previous_reset_in, current_reset_in = inf, inf
if previous_count:
previous_reset_in = previous_expires_in % (expiry / previous_count)
if current_count:
current_reset_in = current_expires_in % expiry
return WindowStats(now + min(previous_reset_in, current_reset_in), remaining)
KnownStrategy = (
type[SlidingWindowCounterRateLimiter]
| type[FixedWindowRateLimiter]
| type[MovingWindowRateLimiter]
)
STRATEGIES: dict[str, KnownStrategy] = {
"sliding-window-counter": SlidingWindowCounterRateLimiter,
"fixed-window": FixedWindowRateLimiter,
"moving-window": MovingWindowRateLimiter,
}
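# A minimal end-to-end sketch using the in-memory backend; the namespace and
# identifier strings are illustrative.
from limits import parse
from limits.storage import MemoryStorage
from limits.strategies import FixedWindowRateLimiter

limiter = FixedWindowRateLimiter(MemoryStorage())
one_per_second = parse("1/second")
assert limiter.hit(one_per_second, "api", "client-1")      # first hit passes
assert not limiter.hit(one_per_second, "api", "client-1")  # second is limited
reset_time, remaining = limiter.get_window_stats(one_per_second, "api", "client-1")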

View File

@ -0,0 +1,127 @@
from __future__ import annotations
from collections import Counter
from collections.abc import Awaitable, Callable, Iterable
from typing import (
TYPE_CHECKING,
Any,
ClassVar,
Literal,
NamedTuple,
ParamSpec,
Protocol,
TypeAlias,
TypeVar,
cast,
)
Serializable = int | str | float
R = TypeVar("R")
R_co = TypeVar("R_co", covariant=True)
P = ParamSpec("P")
if TYPE_CHECKING:
import coredis
import pymongo.collection
import pymongo.database
import pymongo.mongo_client
import redis
class MemcachedClientP(Protocol):
def add(
self,
key: str,
value: Serializable,
expire: int | None = 0,
noreply: bool | None = None,
flags: int | None = None,
) -> bool: ...
def get(self, key: str, default: str | None = None) -> bytes: ...
def get_many(self, keys: Iterable[str]) -> dict[str, Any]: ... # type:ignore[explicit-any]
def incr(
self, key: str, value: int, noreply: bool | None = False
) -> int | None: ...
def decr(
self,
key: str,
value: int,
noreply: bool | None = False,
) -> int | None: ...
def delete(self, key: str, noreply: bool | None = None) -> bool | None: ...
def set(
self,
key: str,
value: Serializable,
expire: int = 0,
noreply: bool | None = None,
flags: int | None = None,
) -> bool: ...
def touch(
self, key: str, expire: int | None = 0, noreply: bool | None = None
) -> bool: ...
class RedisClientP(Protocol):
def incrby(self, key: str, amount: int) -> int: ...
def get(self, key: str) -> bytes | None: ...
def delete(self, key: str) -> int: ...
def ttl(self, key: str) -> int: ...
def expire(self, key: str, seconds: int) -> bool: ...
def ping(self) -> bool: ...
def register_script(self, script: bytes) -> redis.commands.core.Script: ...
class AsyncRedisClientP(Protocol):
async def incrby(self, key: str, amount: int) -> int: ...
async def get(self, key: str) -> bytes | None: ...
async def delete(self, key: str) -> int: ...
async def ttl(self, key: str) -> int: ...
async def expire(self, key: str, seconds: int) -> bool: ...
async def ping(self) -> bool: ...
def register_script(self, script: bytes) -> redis.commands.core.Script: ...
RedisClient: TypeAlias = RedisClientP
AsyncRedisClient: TypeAlias = AsyncRedisClientP
AsyncCoRedisClient: TypeAlias = "coredis.Redis[bytes] | coredis.RedisCluster[bytes]"
MongoClient: TypeAlias = "pymongo.mongo_client.MongoClient[dict[str, Any]]" # type:ignore[explicit-any]
MongoDatabase: TypeAlias = "pymongo.database.Database[dict[str, Any]]" # type:ignore[explicit-any]
MongoCollection: TypeAlias = "pymongo.collection.Collection[dict[str, Any]]" # type:ignore[explicit-any]
__all__ = [
"TYPE_CHECKING",
"Any",
"AsyncRedisClient",
"Awaitable",
"Callable",
"ClassVar",
"Counter",
"Iterable",
"Literal",
"MemcachedClientP",
"MongoClient",
"MongoCollection",
"MongoDatabase",
"NamedTuple",
"P",
"ParamSpec",
"Protocol",
"R",
"R_co",
"RedisClient",
"Serializable",
"TypeAlias",
"TypeVar",
"cast",
]

View File

@ -0,0 +1,207 @@
""" """
from __future__ import annotations
import dataclasses
import importlib.resources
import re
import sys
from collections import UserDict
from types import ModuleType
from typing import TYPE_CHECKING
from packaging.version import Version
from limits.typing import NamedTuple
from .errors import ConfigurationError
from .limits import GRANULARITIES, RateLimitItem
SEPARATORS = re.compile(r"[,;|]{1}")
SINGLE_EXPR = re.compile(
r"""
\s*([0-9]+)
\s*(/|\s*per\s*)
\s*([0-9]+)
*\s*(hour|minute|second|day|month|year)s?\s*""",
re.IGNORECASE | re.VERBOSE,
)
EXPR = re.compile(
rf"^{SINGLE_EXPR.pattern}(:?{SEPARATORS.pattern}{SINGLE_EXPR.pattern})*$",
re.IGNORECASE | re.VERBOSE,
)
class WindowStats(NamedTuple):
"""
tuple to describe a rate limited window
"""
#: Time as seconds since the Epoch when this window will be reset
reset_time: float
#: Quantity remaining in this window
remaining: int
@dataclasses.dataclass
class Dependency:
name: str
version_required: Version | None
version_found: Version | None
module: ModuleType
MissingModule = ModuleType("Missing")
if TYPE_CHECKING:
_UserDict = UserDict[str, Dependency]
else:
_UserDict = UserDict
class DependencyDict(_UserDict):
def __getitem__(self, key: str) -> Dependency:
dependency = super().__getitem__(key)
if dependency.module is MissingModule:
message = f"'{dependency.name}' prerequisite not available."
            if dependency.version_required:
                message += (
                    f" A minimum version of {dependency.version_required}"
                    " is required."
                )
message += (
" See https://limits.readthedocs.io/en/stable/storage.html#supported-versions"
" for more details."
)
raise ConfigurationError(message)
elif dependency.version_required and (
not dependency.version_found
or dependency.version_found < dependency.version_required
):
            raise ConfigurationError(
                f"The installed version of '{dependency.name}' does not meet"
                f" the minimum version of {dependency.version_required}."
                f" Found version: {dependency.version_found}"
            )
return dependency
class LazyDependency:
"""
Simple utility that provides an :attr:`dependency`
to the child class to fetch any dependencies
without having to import them explicitly.
"""
DEPENDENCIES: dict[str, Version | None] | list[str] = []
"""
The python modules this class has a dependency on.
Used to lazily populate the :attr:`dependencies`
"""
def __init__(self) -> None:
self._dependencies: DependencyDict = DependencyDict()
@property
def dependencies(self) -> DependencyDict:
"""
Cached mapping of the modules this storage depends on.
This is done so that the module is only imported lazily
when the storage is instantiated.
:meta private:
"""
if not getattr(self, "_dependencies", None):
dependencies = DependencyDict()
mapping: dict[str, Version | None]
if isinstance(self.DEPENDENCIES, list):
mapping = {dependency: None for dependency in self.DEPENDENCIES}
else:
mapping = self.DEPENDENCIES
for name, minimum_version in mapping.items():
dependency, version = get_dependency(name)
dependencies[name] = Dependency(
name, minimum_version, version, dependency
)
self._dependencies = dependencies
return self._dependencies
def get_dependency(module_path: str) -> tuple[ModuleType, Version | None]:
"""
safe function to import a module at runtime
"""
try:
if module_path not in sys.modules:
__import__(module_path)
root = module_path.split(".")[0]
version = getattr(sys.modules[root], "__version__", "0.0.0")
return sys.modules[module_path], Version(version)
except ImportError: # pragma: no cover
return MissingModule, None
def get_package_data(path: str) -> bytes:
return importlib.resources.files("limits").joinpath(path).read_bytes()
def parse_many(limit_string: str) -> list[RateLimitItem]:
"""
parses rate limits in string notation containing multiple rate limits
(e.g. ``1/second; 5/minute``)
:param limit_string: rate limit string using :ref:`ratelimit-string`
:raise ValueError: if the string notation is invalid.
"""
if not (isinstance(limit_string, str) and EXPR.match(limit_string)):
raise ValueError(f"couldn't parse rate limit string '{limit_string}'")
limits = []
for limit in SEPARATORS.split(limit_string):
match = SINGLE_EXPR.match(limit)
if match:
amount, _, multiples, granularity_string = match.groups()
granularity = granularity_from_string(granularity_string)
limits.append(
granularity(int(amount), multiples and int(multiples) or None)
)
return limits
def parse(limit_string: str) -> RateLimitItem:
"""
parses a single rate limit in string notation
(e.g. ``1/second`` or ``1 per second``)
:param limit_string: rate limit string using :ref:`ratelimit-string`
:raise ValueError: if the string notation is invalid.
"""
    return parse_many(limit_string)[0]
def granularity_from_string(granularity_string: str) -> type[RateLimitItem]:
"""
:param granularity_string:
:raise ValueError:
"""
for granularity in GRANULARITIES.values():
if granularity.check_granularity_string(granularity_string):
return granularity
raise ValueError(f"no granularity matched for {granularity_string}")
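# Parsing sketch: separators (",", ";", "|") and the "per" form are both
# accepted; only the public parse/parse_many helpers are used here.
from limits import parse, parse_many

assert parse("10 per minute").amount == 10
assert len(parse_many("1/second; 5/minute")) == 2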

View File

@ -0,0 +1,3 @@
"""
empty file to be updated by versioneer
"""