Update 2025-04-24_11:44:19

This commit is contained in:
oib
2025-04-24 11:44:23 +02:00
commit e748c737f4
3408 changed files with 717481 additions and 0 deletions

View File

@ -0,0 +1,80 @@
"""
Implementations of storage backends to be used with
:class:`limits.strategies.RateLimiter` strategies
"""
from __future__ import annotations
import urllib
import limits # noqa
from ..errors import ConfigurationError
from ..typing import TypeAlias, cast
from .base import MovingWindowSupport, SlidingWindowCounterSupport, Storage
from .memcached import MemcachedStorage
from .memory import MemoryStorage
from .mongodb import MongoDBStorage, MongoDBStorageBase
from .redis import RedisStorage
from .redis_cluster import RedisClusterStorage
from .redis_sentinel import RedisSentinelStorage
from .registry import SCHEMES
StorageTypes: TypeAlias = "Storage | limits.aio.storage.Storage"
def storage_from_string(
storage_string: str, **options: float | str | bool
) -> StorageTypes:
"""
Factory function to get an instance of the storage class based
on the uri of the storage. In most cases using it should be sufficient
instead of directly instantiating the storage classes. for example::
from limits.storage import storage_from_string
memory = storage_from_string("memory://")
memcached = storage_from_string("memcached://localhost:11211")
redis = storage_from_string("redis://localhost:6379")
The same function can be used to construct the :ref:`storage:async storage`
variants, for example::
from limits.storage import storage_from_string
memory = storage_from_string("async+memory://")
memcached = storage_from_string("async+memcached://localhost:11211")
redis = storage_from_string("async+redis://localhost:6379")
:param storage_string: a string of the form ``scheme://host:port``.
More details about supported storage schemes can be found at
:ref:`storage:storage scheme`
:param options: all remaining keyword arguments are passed to the
constructor matched by :paramref:`storage_string`.
:raises ConfigurationError: when the :attr:`storage_string` cannot be
mapped to a registered :class:`limits.storage.Storage`
or :class:`limits.aio.storage.Storage` instance.
"""
scheme = urllib.parse.urlparse(storage_string).scheme
if scheme not in SCHEMES:
raise ConfigurationError(f"unknown storage scheme : {storage_string}")
return cast(StorageTypes, SCHEMES[scheme](storage_string, **options))
__all__ = [
"MemcachedStorage",
"MemoryStorage",
"MongoDBStorage",
"MongoDBStorageBase",
"MovingWindowSupport",
"RedisClusterStorage",
"RedisSentinelStorage",
"RedisStorage",
"SlidingWindowCounterSupport",
"Storage",
"storage_from_string",
]

View File

@ -0,0 +1,232 @@
from __future__ import annotations
import functools
from abc import ABC, abstractmethod
from limits import errors
from limits.storage.registry import StorageRegistry
from limits.typing import (
Any,
Callable,
P,
R,
cast,
)
from limits.util import LazyDependency
def _wrap_errors(
fn: Callable[P, R],
) -> Callable[P, R]:
@functools.wraps(fn)
def inner(*args: P.args, **kwargs: P.kwargs) -> R:
instance = cast(Storage, args[0])
try:
return fn(*args, **kwargs)
except instance.base_exceptions as exc:
if instance.wrap_exceptions:
raise errors.StorageError(exc) from exc
raise
return inner
class Storage(LazyDependency, metaclass=StorageRegistry):
"""
Base class to extend when implementing a storage backend.
"""
STORAGE_SCHEME: list[str] | None
"""The storage schemes to register against this implementation"""
def __init_subclass__(cls, **kwargs: Any) -> None: # type: ignore[explicit-any]
for method in {
"incr",
"get",
"get_expiry",
"check",
"reset",
"clear",
}:
setattr(cls, method, _wrap_errors(getattr(cls, method)))
super().__init_subclass__(**kwargs)
def __init__(
self,
uri: str | None = None,
wrap_exceptions: bool = False,
**options: float | str | bool,
):
"""
:param wrap_exceptions: Whether to wrap storage exceptions in
:exc:`limits.errors.StorageError` before raising it.
"""
super().__init__()
self.wrap_exceptions = wrap_exceptions
@property
@abstractmethod
def base_exceptions(self) -> type[Exception] | tuple[type[Exception], ...]:
raise NotImplementedError
@abstractmethod
def incr(self, key: str, expiry: int, amount: int = 1) -> int:
"""
increments the counter for a given rate limit key
:param key: the key to increment
:param expiry: amount in seconds for the key to expire in
:param amount: the number to increment by
"""
raise NotImplementedError
@abstractmethod
def get(self, key: str) -> int:
"""
:param key: the key to get the counter value for
"""
raise NotImplementedError
@abstractmethod
def get_expiry(self, key: str) -> float:
"""
:param key: the key to get the expiry for
"""
raise NotImplementedError
@abstractmethod
def check(self) -> bool:
"""
check if storage is healthy
"""
raise NotImplementedError
@abstractmethod
def reset(self) -> int | None:
"""
reset storage to clear limits
"""
raise NotImplementedError
@abstractmethod
def clear(self, key: str) -> None:
"""
resets the rate limit key
:param key: the key to clear rate limits for
"""
raise NotImplementedError
class MovingWindowSupport(ABC):
"""
Abstract base class for storages that support
the :ref:`strategies:moving window` strategy
"""
def __init_subclass__(cls, **kwargs: Any) -> None: # type: ignore[explicit-any]
for method in {
"acquire_entry",
"get_moving_window",
}:
setattr(
cls,
method,
_wrap_errors(getattr(cls, method)),
)
super().__init_subclass__(**kwargs)
@abstractmethod
def acquire_entry(self, key: str, limit: int, expiry: int, amount: int = 1) -> bool:
"""
:param key: rate limit key to acquire an entry in
:param limit: amount of entries allowed
:param expiry: expiry of the entry
:param amount: the number of entries to acquire
"""
raise NotImplementedError
@abstractmethod
def get_moving_window(self, key: str, limit: int, expiry: int) -> tuple[float, int]:
"""
returns the starting point and the number of entries in the moving
window
:param key: rate limit key
:param expiry: expiry of entry
:return: (start of window, number of acquired entries)
"""
raise NotImplementedError
class SlidingWindowCounterSupport(ABC):
"""
Abstract base class for storages that support
the :ref:`strategies:sliding window counter` strategy.
"""
def __init_subclass__(cls, **kwargs: Any) -> None: # type: ignore[explicit-any]
for method in {"acquire_sliding_window_entry", "get_sliding_window"}:
setattr(
cls,
method,
_wrap_errors(getattr(cls, method)),
)
super().__init_subclass__(**kwargs)
@abstractmethod
def acquire_sliding_window_entry(
self, key: str, limit: int, expiry: int, amount: int = 1
) -> bool:
"""
Acquire an entry if the weighted count of the current and previous
windows is less than or equal to the limit
:param key: rate limit key to acquire an entry in
:param limit: amount of entries allowed
:param expiry: expiry of the entry
:param amount: the number of entries to acquire
"""
raise NotImplementedError
@abstractmethod
def get_sliding_window(
self, key: str, expiry: int
) -> tuple[int, float, int, float]:
"""
Return the previous and current window information.
:param key: the rate limit key
:param expiry: the rate limit expiry, needed to compute the key in some implementations
:return: a tuple of (int, float, int, float) with the following information:
- previous window counter
- previous window TTL
- current window counter
- current window TTL
"""
raise NotImplementedError
class TimestampedSlidingWindow:
"""Helper class for storage that support the sliding window counter, with timestamp based keys."""
@classmethod
def sliding_window_keys(cls, key: str, expiry: int, at: float) -> tuple[str, str]:
"""
returns the previous and the current window's keys.
:param key: the key to get the window's keys from
:param expiry: the expiry of the limit item, in seconds
:param at: the timestamp to get the keys from. Default to now, ie ``time.time()``
Returns a tuple with the previous and the current key: (previous, current).
Example:
- key = "mykey"
- expiry = 60
- at = 1738576292.6631825
The return value will be the tuple ``("mykey/28976271", "mykey/28976270")``.
"""
return f"{key}/{int((at - expiry) / expiry)}", f"{key}/{int(at / expiry)}"

View File

@ -0,0 +1,299 @@
from __future__ import annotations
import inspect
import threading
import time
import urllib.parse
from collections.abc import Iterable
from math import ceil, floor
from types import ModuleType
from limits.errors import ConfigurationError
from limits.storage.base import (
SlidingWindowCounterSupport,
Storage,
TimestampedSlidingWindow,
)
from limits.typing import (
Any,
Callable,
MemcachedClientP,
P,
R,
cast,
)
from limits.util import get_dependency
class MemcachedStorage(Storage, SlidingWindowCounterSupport, TimestampedSlidingWindow):
"""
Rate limit storage with memcached as backend.
Depends on :pypi:`pymemcache`.
"""
STORAGE_SCHEME = ["memcached"]
"""The storage scheme for memcached"""
DEPENDENCIES = ["pymemcache"]
def __init__(
self,
uri: str,
wrap_exceptions: bool = False,
**options: str | Callable[[], MemcachedClientP],
) -> None:
"""
:param uri: memcached location of the form
``memcached://host:port,host:port``,
``memcached:///var/tmp/path/to/sock``
:param wrap_exceptions: Whether to wrap storage exceptions in
:exc:`limits.errors.StorageError` before raising it.
:param options: all remaining keyword arguments are passed
directly to the constructor of :class:`pymemcache.client.base.PooledClient`
or :class:`pymemcache.client.hash.HashClient` (if there are more than
one hosts specified)
:raise ConfigurationError: when :pypi:`pymemcache` is not available
"""
parsed = urllib.parse.urlparse(uri)
self.hosts = []
for loc in parsed.netloc.strip().split(","):
if not loc:
continue
host, port = loc.split(":")
self.hosts.append((host, int(port)))
else:
# filesystem path to UDS
if parsed.path and not parsed.netloc and not parsed.port:
self.hosts = [parsed.path] # type: ignore
self.dependency = self.dependencies["pymemcache"].module
self.library = str(options.pop("library", "pymemcache.client"))
self.cluster_library = str(
options.pop("cluster_library", "pymemcache.client.hash")
)
self.client_getter = cast(
Callable[[ModuleType, list[tuple[str, int]]], MemcachedClientP],
options.pop("client_getter", self.get_client),
)
self.options = options
if not get_dependency(self.library):
raise ConfigurationError(
f"memcached prerequisite not available. please install {self.library}"
) # pragma: no cover
self.local_storage = threading.local()
self.local_storage.storage = None
super().__init__(uri, wrap_exceptions=wrap_exceptions)
@property
def base_exceptions(
self,
) -> type[Exception] | tuple[type[Exception], ...]: # pragma: no cover
return self.dependency.MemcacheError # type: ignore[no-any-return]
def get_client(
self, module: ModuleType, hosts: list[tuple[str, int]], **kwargs: str
) -> MemcachedClientP:
"""
returns a memcached client.
:param module: the memcached module
:param hosts: list of memcached hosts
"""
return cast(
MemcachedClientP,
(
module.HashClient(hosts, **kwargs)
if len(hosts) > 1
else module.PooledClient(*hosts, **kwargs)
),
)
def call_memcached_func(
self, func: Callable[P, R], *args: P.args, **kwargs: P.kwargs
) -> R:
if "noreply" in kwargs:
argspec = inspect.getfullargspec(func)
if not ("noreply" in argspec.args or argspec.varkw):
kwargs.pop("noreply")
return func(*args, **kwargs)
@property
def storage(self) -> MemcachedClientP:
"""
lazily creates a memcached client instance using a thread local
"""
if not (hasattr(self.local_storage, "storage") and self.local_storage.storage):
dependency = get_dependency(
self.cluster_library if len(self.hosts) > 1 else self.library
)[0]
if not dependency:
raise ConfigurationError(f"Unable to import {self.cluster_library}")
self.local_storage.storage = self.client_getter(
dependency, self.hosts, **self.options
)
return cast(MemcachedClientP, self.local_storage.storage)
def get(self, key: str) -> int:
"""
:param key: the key to get the counter value for
"""
return int(self.storage.get(key, "0"))
def get_many(self, keys: Iterable[str]) -> dict[str, Any]: # type:ignore[explicit-any]
"""
Return multiple counters at once
:param keys: the keys to get the counter values for
:meta private:
"""
return self.storage.get_many(keys)
def clear(self, key: str) -> None:
"""
:param key: the key to clear rate limits for
"""
self.storage.delete(key)
def incr(
self,
key: str,
expiry: float,
amount: int = 1,
set_expiration_key: bool = True,
) -> int:
"""
increments the counter for a given rate limit key
:param key: the key to increment
:param expiry: amount in seconds for the key to expire in
window every hit.
:param amount: the number to increment by
:param set_expiration_key: set the expiration key with the expiration time if needed. If set to False, the key will still expire, but memcached cannot provide the expiration time.
"""
if (
value := self.call_memcached_func(
self.storage.incr, key, amount, noreply=False
)
) is not None:
return value
else:
if not self.call_memcached_func(
self.storage.add, key, amount, ceil(expiry), noreply=False
):
return self.storage.incr(key, amount) or amount
else:
if set_expiration_key:
self.call_memcached_func(
self.storage.set,
self._expiration_key(key),
expiry + time.time(),
expire=ceil(expiry),
noreply=False,
)
return amount
def get_expiry(self, key: str) -> float:
"""
:param key: the key to get the expiry for
"""
return float(self.storage.get(self._expiration_key(key)) or time.time())
def _expiration_key(self, key: str) -> str:
"""
Return the expiration key for the given counter key.
Memcached doesn't natively return the expiration time or TTL for a given key,
so we implement the expiration time on a separate key.
"""
return key + "/expires"
def check(self) -> bool:
"""
Check if storage is healthy by calling the ``get`` command
on the key ``limiter-check``
"""
try:
self.call_memcached_func(self.storage.get, "limiter-check")
return True
except: # noqa
return False
def reset(self) -> int | None:
raise NotImplementedError
def acquire_sliding_window_entry(
self,
key: str,
limit: int,
expiry: int,
amount: int = 1,
) -> bool:
if amount > limit:
return False
now = time.time()
previous_key, current_key = self.sliding_window_keys(key, expiry, now)
previous_count, previous_ttl, current_count, _ = self._get_sliding_window_info(
previous_key, current_key, expiry, now=now
)
weighted_count = previous_count * previous_ttl / expiry + current_count
if floor(weighted_count) + amount > limit:
return False
else:
# Hit, increase the current counter.
# If the counter doesn't exist yet, set twice the theorical expiry.
# We don't need the expiration key as it is estimated with the timestamps directly.
current_count = self.incr(
current_key, 2 * expiry, amount=amount, set_expiration_key=False
)
actualised_previous_ttl = min(0, previous_ttl - (time.time() - now))
weighted_count = (
previous_count * actualised_previous_ttl / expiry + current_count
)
if floor(weighted_count) > limit:
# Another hit won the race condition: revert the incrementation and refuse this hit
# Limitation: during high concurrency at the end of the window,
# the counter is shifted and cannot be decremented, so less requests than expected are allowed.
self.call_memcached_func(
self.storage.decr,
current_key,
amount,
noreply=True,
)
return False
return True
def get_sliding_window(
self, key: str, expiry: int
) -> tuple[int, float, int, float]:
now = time.time()
previous_key, current_key = self.sliding_window_keys(key, expiry, now)
return self._get_sliding_window_info(previous_key, current_key, expiry, now)
def _get_sliding_window_info(
self, previous_key: str, current_key: str, expiry: int, now: float
) -> tuple[int, float, int, float]:
result = self.get_many([previous_key, current_key])
previous_count, current_count = (
int(result.get(previous_key, 0)),
int(result.get(current_key, 0)),
)
if previous_count == 0:
previous_ttl = float(0)
else:
previous_ttl = (1 - (((now - expiry) / expiry) % 1)) * expiry
current_ttl = (1 - ((now / expiry) % 1)) * expiry + expiry
return previous_count, previous_ttl, current_count, current_ttl

View File

@ -0,0 +1,253 @@
from __future__ import annotations
import bisect
import threading
import time
from collections import Counter, defaultdict
from math import floor
import limits.typing
from limits.storage.base import (
MovingWindowSupport,
SlidingWindowCounterSupport,
Storage,
TimestampedSlidingWindow,
)
class Entry:
def __init__(self, expiry: float) -> None:
self.atime = time.time()
self.expiry = self.atime + expiry
class MemoryStorage(
Storage, MovingWindowSupport, SlidingWindowCounterSupport, TimestampedSlidingWindow
):
"""
rate limit storage using :class:`collections.Counter`
as an in memory storage for fixed and sliding window strategies,
and a simple list to implement moving window strategy.
"""
STORAGE_SCHEME = ["memory"]
def __init__(self, uri: str | None = None, wrap_exceptions: bool = False, **_: str):
self.storage: limits.typing.Counter[str] = Counter()
self.locks: defaultdict[str, threading.RLock] = defaultdict(threading.RLock)
self.expirations: dict[str, float] = {}
self.events: dict[str, list[Entry]] = {}
self.timer: threading.Timer = threading.Timer(0.01, self.__expire_events)
self.timer.start()
super().__init__(uri, wrap_exceptions=wrap_exceptions, **_)
def __getstate__(self) -> dict[str, limits.typing.Any]: # type: ignore[explicit-any]
state = self.__dict__.copy()
del state["timer"]
del state["locks"]
return state
def __setstate__(self, state: dict[str, limits.typing.Any]) -> None: # type: ignore[explicit-any]
self.__dict__.update(state)
self.locks = defaultdict(threading.RLock)
self.timer = threading.Timer(0.01, self.__expire_events)
self.timer.start()
def __expire_events(self) -> None:
for key in list(self.events.keys()):
with self.locks[key]:
if events := self.events.get(key, []):
oldest = bisect.bisect_left(
events, -time.time(), key=lambda event: -event.expiry
)
self.events[key] = self.events[key][:oldest]
if not self.events.get(key, None):
self.locks.pop(key, None)
for key in list(self.expirations.keys()):
if self.expirations[key] <= time.time():
self.storage.pop(key, None)
self.expirations.pop(key, None)
self.locks.pop(key, None)
def __schedule_expiry(self) -> None:
if not self.timer.is_alive():
self.timer = threading.Timer(0.01, self.__expire_events)
self.timer.start()
@property
def base_exceptions(
self,
) -> type[Exception] | tuple[type[Exception], ...]: # pragma: no cover
return ValueError
def incr(self, key: str, expiry: float, amount: int = 1) -> int:
"""
increments the counter for a given rate limit key
:param key: the key to increment
:param expiry: amount in seconds for the key to expire in
:param amount: the number to increment by
"""
self.get(key)
self.__schedule_expiry()
with self.locks[key]:
self.storage[key] += amount
if self.storage[key] == amount:
self.expirations[key] = time.time() + expiry
return self.storage.get(key, 0)
def decr(self, key: str, amount: int = 1) -> int:
"""
decrements the counter for a given rate limit key
:param key: the key to decrement
:param amount: the number to decrement by
"""
self.get(key)
self.__schedule_expiry()
with self.locks[key]:
self.storage[key] = max(self.storage[key] - amount, 0)
return self.storage.get(key, 0)
def get(self, key: str) -> int:
"""
:param key: the key to get the counter value for
"""
if self.expirations.get(key, 0) <= time.time():
self.storage.pop(key, None)
self.expirations.pop(key, None)
self.locks.pop(key, None)
return self.storage.get(key, 0)
def clear(self, key: str) -> None:
"""
:param key: the key to clear rate limits for
"""
self.storage.pop(key, None)
self.expirations.pop(key, None)
self.events.pop(key, None)
self.locks.pop(key, None)
def acquire_entry(self, key: str, limit: int, expiry: int, amount: int = 1) -> bool:
"""
:param key: rate limit key to acquire an entry in
:param limit: amount of entries allowed
:param expiry: expiry of the entry
:param amount: the number of entries to acquire
"""
if amount > limit:
return False
self.__schedule_expiry()
with self.locks[key]:
self.events.setdefault(key, [])
timestamp = time.time()
try:
entry = self.events[key][limit - amount]
except IndexError:
entry = None
if entry and entry.atime >= timestamp - expiry:
return False
else:
self.events[key][:0] = [Entry(expiry)] * amount
return True
def get_expiry(self, key: str) -> float:
"""
:param key: the key to get the expiry for
"""
return self.expirations.get(key, time.time())
def get_moving_window(self, key: str, limit: int, expiry: int) -> tuple[float, int]:
"""
returns the starting point and the number of entries in the moving
window
:param key: rate limit key
:param expiry: expiry of entry
:return: (start of window, number of acquired entries)
"""
timestamp = time.time()
if events := self.events.get(key, []):
oldest = bisect.bisect_left(
events, -(timestamp - expiry), key=lambda entry: -entry.atime
)
return events[oldest - 1].atime, oldest
return timestamp, 0
def acquire_sliding_window_entry(
self,
key: str,
limit: int,
expiry: int,
amount: int = 1,
) -> bool:
if amount > limit:
return False
now = time.time()
previous_key, current_key = self.sliding_window_keys(key, expiry, now)
(
previous_count,
previous_ttl,
current_count,
_,
) = self._get_sliding_window_info(previous_key, current_key, expiry, now)
weighted_count = previous_count * previous_ttl / expiry + current_count
if floor(weighted_count) + amount > limit:
return False
else:
# Hit, increase the current counter.
# If the counter doesn't exist yet, set twice the theorical expiry.
current_count = self.incr(current_key, 2 * expiry, amount=amount)
weighted_count = previous_count * previous_ttl / expiry + current_count
if floor(weighted_count) > limit:
# Another hit won the race condition: revert the incrementation and refuse this hit
# Limitation: during high concurrency at the end of the window,
# the counter is shifted and cannot be decremented, so less requests than expected are allowed.
self.decr(current_key, amount)
return False
return True
def _get_sliding_window_info(
self,
previous_key: str,
current_key: str,
expiry: int,
now: float,
) -> tuple[int, float, int, float]:
previous_count = self.get(previous_key)
current_count = self.get(current_key)
if previous_count == 0:
previous_ttl = float(0)
else:
previous_ttl = (1 - (((now - expiry) / expiry) % 1)) * expiry
current_ttl = (1 - ((now / expiry) % 1)) * expiry + expiry
return previous_count, previous_ttl, current_count, current_ttl
def get_sliding_window(
self, key: str, expiry: int
) -> tuple[int, float, int, float]:
now = time.time()
previous_key, current_key = self.sliding_window_keys(key, expiry, now)
return self._get_sliding_window_info(previous_key, current_key, expiry, now)
def check(self) -> bool:
"""
check if storage is healthy
"""
return True
def reset(self) -> int | None:
num_items = max(len(self.storage), len(self.events))
self.storage.clear()
self.expirations.clear()
self.events.clear()
self.locks.clear()
return num_items

View File

@ -0,0 +1,489 @@
from __future__ import annotations
import datetime
import time
from abc import ABC, abstractmethod
from deprecated.sphinx import versionadded, versionchanged
from limits.typing import (
MongoClient,
MongoCollection,
MongoDatabase,
cast,
)
from ..util import get_dependency
from .base import MovingWindowSupport, SlidingWindowCounterSupport, Storage
class MongoDBStorageBase(
Storage, MovingWindowSupport, SlidingWindowCounterSupport, ABC
):
"""
Rate limit storage with MongoDB as backend.
Depends on :pypi:`pymongo`.
"""
DEPENDENCIES = ["pymongo"]
def __init__(
self,
uri: str,
database_name: str = "limits",
counter_collection_name: str = "counters",
window_collection_name: str = "windows",
wrap_exceptions: bool = False,
**options: int | str | bool,
) -> None:
"""
:param uri: uri of the form ``mongodb://[user:password]@host:port?...``,
This uri is passed directly to :class:`~pymongo.mongo_client.MongoClient`
:param database_name: The database to use for storing the rate limit
collections.
:param counter_collection_name: The collection name to use for individual counters
used in fixed window strategies
:param window_collection_name: The collection name to use for sliding & moving window
storage
:param wrap_exceptions: Whether to wrap storage exceptions in
:exc:`limits.errors.StorageError` before raising it.
:param options: all remaining keyword arguments are passed to the
constructor of :class:`~pymongo.mongo_client.MongoClient`
:raise ConfigurationError: when the :pypi:`pymongo` library is not available
"""
super().__init__(uri, wrap_exceptions=wrap_exceptions, **options)
self._database_name = database_name
self._collection_mapping = {
"counters": counter_collection_name,
"windows": window_collection_name,
}
self.lib = self.dependencies["pymongo"].module
self.lib_errors, _ = get_dependency("pymongo.errors")
self._storage_uri = uri
self._storage_options = options
self._storage: MongoClient | None = None
@property
def storage(self) -> MongoClient:
if self._storage is None:
self._storage = self._init_mongo_client(
self._storage_uri, **self._storage_options
)
self.__initialize_database()
return self._storage
@property
def _database(self) -> MongoDatabase:
return self.storage[self._database_name]
@property
def counters(self) -> MongoCollection:
return self._database[self._collection_mapping["counters"]]
@property
def windows(self) -> MongoCollection:
return self._database[self._collection_mapping["windows"]]
@abstractmethod
def _init_mongo_client(
self, uri: str | None, **options: int | str | bool
) -> MongoClient:
raise NotImplementedError()
@property
def base_exceptions(
self,
) -> type[Exception] | tuple[type[Exception], ...]: # pragma: no cover
return self.lib_errors.PyMongoError # type: ignore
def __initialize_database(self) -> None:
self.counters.create_index("expireAt", expireAfterSeconds=0)
self.windows.create_index("expireAt", expireAfterSeconds=0)
def reset(self) -> int | None:
"""
Delete all rate limit keys in the rate limit collections (counters, windows)
"""
num_keys = self.counters.count_documents({}) + self.windows.count_documents({})
self.counters.drop()
self.windows.drop()
return int(num_keys)
def clear(self, key: str) -> None:
"""
:param key: the key to clear rate limits for
"""
self.counters.find_one_and_delete({"_id": key})
self.windows.find_one_and_delete({"_id": key})
def get_expiry(self, key: str) -> float:
"""
:param key: the key to get the expiry for
"""
counter = self.counters.find_one({"_id": key})
return (
(counter["expireAt"] if counter else datetime.datetime.now())
.replace(tzinfo=datetime.timezone.utc)
.timestamp()
)
def get(self, key: str) -> int:
"""
:param key: the key to get the counter value for
"""
counter = self.counters.find_one(
{
"_id": key,
"expireAt": {"$gte": datetime.datetime.now(datetime.timezone.utc)},
},
projection=["count"],
)
return counter and counter["count"] or 0
def incr(self, key: str, expiry: int, amount: int = 1) -> int:
"""
increments the counter for a given rate limit key
:param key: the key to increment
:param expiry: amount in seconds for the key to expire in
:param amount: the number to increment by
"""
expiration = datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(
seconds=expiry
)
return int(
self.counters.find_one_and_update(
{"_id": key},
[
{
"$set": {
"count": {
"$cond": {
"if": {"$lt": ["$expireAt", "$$NOW"]},
"then": amount,
"else": {"$add": ["$count", amount]},
}
},
"expireAt": {
"$cond": {
"if": {"$lt": ["$expireAt", "$$NOW"]},
"then": expiration,
"else": "$expireAt",
}
},
}
},
],
upsert=True,
projection=["count"],
return_document=self.lib.ReturnDocument.AFTER,
)["count"]
)
def check(self) -> bool:
"""
Check if storage is healthy by calling :meth:`pymongo.mongo_client.MongoClient.server_info`
"""
try:
self.storage.server_info()
return True
except: # noqa: E722
return False
def get_moving_window(self, key: str, limit: int, expiry: int) -> tuple[float, int]:
"""
returns the starting point and the number of entries in the moving
window
:param key: rate limit key
:param expiry: expiry of entry
:return: (start of window, number of acquired entries)
"""
timestamp = time.time()
if result := list(
self.windows.aggregate(
[
{"$match": {"_id": key}},
{
"$project": {
"filteredEntries": {
"$filter": {
"input": "$entries",
"as": "entry",
"cond": {"$gte": ["$$entry", timestamp - expiry]},
}
}
}
},
{
"$project": {
"min": {"$min": "$filteredEntries"},
"count": {"$size": "$filteredEntries"},
}
},
]
)
):
return result[0]["min"], result[0]["count"]
return timestamp, 0
def acquire_entry(self, key: str, limit: int, expiry: int, amount: int = 1) -> bool:
"""
:param key: rate limit key to acquire an entry in
:param limit: amount of entries allowed
:param expiry: expiry of the entry
:param amount: the number of entries to acquire
"""
if amount > limit:
return False
timestamp = time.time()
try:
updates: dict[
str,
dict[str, datetime.datetime | dict[str, list[float] | int]],
] = {
"$push": {
"entries": {
"$each": [timestamp] * amount,
"$position": 0,
"$slice": limit,
}
},
"$set": {
"expireAt": (
datetime.datetime.now(datetime.timezone.utc)
+ datetime.timedelta(seconds=expiry)
)
},
}
self.windows.update_one(
{
"_id": key,
f"entries.{limit - amount}": {"$not": {"$gte": timestamp - expiry}},
},
updates,
upsert=True,
)
return True
except self.lib.errors.DuplicateKeyError:
return False
def get_sliding_window(
self, key: str, expiry: int
) -> tuple[int, float, int, float]:
expiry_ms = expiry * 1000
if result := self.windows.find_one_and_update(
{"_id": key},
[
{
"$set": {
"previousCount": {
"$cond": {
"if": {
"$lte": [
{"$subtract": ["$expireAt", "$$NOW"]},
expiry_ms,
]
},
"then": {"$ifNull": ["$currentCount", 0]},
"else": {"$ifNull": ["$previousCount", 0]},
}
},
"currentCount": {
"$cond": {
"if": {
"$lte": [
{"$subtract": ["$expireAt", "$$NOW"]},
expiry_ms,
]
},
"then": 0,
"else": {"$ifNull": ["$currentCount", 0]},
}
},
"expireAt": {
"$cond": {
"if": {
"$lte": [
{"$subtract": ["$expireAt", "$$NOW"]},
expiry_ms,
]
},
"then": {
"$add": ["$expireAt", expiry_ms],
},
"else": "$expireAt",
}
},
}
}
],
return_document=self.lib.ReturnDocument.AFTER,
projection=["currentCount", "previousCount", "expireAt"],
):
expires_at = (
(result["expireAt"].replace(tzinfo=datetime.timezone.utc).timestamp())
if result.get("expireAt")
else time.time()
)
current_ttl = max(0, expires_at - time.time())
prev_ttl = max(0, current_ttl - expiry if result["previousCount"] else 0)
return (
result["previousCount"],
prev_ttl,
result["currentCount"],
current_ttl,
)
return 0, 0.0, 0, 0.0
def acquire_sliding_window_entry(
self, key: str, limit: int, expiry: int, amount: int = 1
) -> bool:
expiry_ms = expiry * 1000
result = self.windows.find_one_and_update(
{"_id": key},
[
{
"$set": {
"previousCount": {
"$cond": {
"if": {
"$lte": [
{"$subtract": ["$expireAt", "$$NOW"]},
expiry_ms,
]
},
"then": {"$ifNull": ["$currentCount", 0]},
"else": {"$ifNull": ["$previousCount", 0]},
}
},
}
},
{
"$set": {
"currentCount": {
"$cond": {
"if": {
"$lte": [
{"$subtract": ["$expireAt", "$$NOW"]},
expiry_ms,
]
},
"then": 0,
"else": {"$ifNull": ["$currentCount", 0]},
}
},
"expireAt": {
"$cond": {
"if": {
"$lte": [
{"$subtract": ["$expireAt", "$$NOW"]},
expiry_ms,
]
},
"then": {
"$cond": {
"if": {"$gt": ["$expireAt", 0]},
"then": {"$add": ["$expireAt", expiry_ms]},
"else": {"$add": ["$$NOW", 2 * expiry_ms]},
}
},
"else": "$expireAt",
}
},
}
},
{
"$set": {
"curWeightedCount": {
"$floor": {
"$add": [
{
"$multiply": [
"$previousCount",
{
"$divide": [
{
"$max": [
0,
{
"$subtract": [
"$expireAt",
{
"$add": [
"$$NOW",
expiry_ms,
]
},
]
},
]
},
expiry_ms,
]
},
]
},
"$currentCount",
]
}
}
}
},
{
"$set": {
"currentCount": {
"$cond": {
"if": {
"$lte": [
{"$add": ["$curWeightedCount", amount]},
limit,
]
},
"then": {"$add": ["$currentCount", amount]},
"else": "$currentCount",
}
}
}
},
{
"$set": {
"_acquired": {
"$lte": [{"$add": ["$curWeightedCount", amount]}, limit]
}
}
},
{"$unset": ["curWeightedCount"]},
],
return_document=self.lib.ReturnDocument.AFTER,
upsert=True,
)
return cast(bool, result["_acquired"])
def __del__(self) -> None:
if self.storage:
self.storage.close()
@versionadded(version="2.1")
@versionchanged(
version="3.14.0",
reason="Added option to select custom collection names for windows & counters",
)
class MongoDBStorage(MongoDBStorageBase):
STORAGE_SCHEME = ["mongodb", "mongodb+srv"]
def _init_mongo_client(
self, uri: str | None, **options: int | str | bool
) -> MongoClient:
return cast(MongoClient, self.lib.MongoClient(uri, **options))

View File

@ -0,0 +1,308 @@
from __future__ import annotations
import time
from typing import TYPE_CHECKING, cast
from deprecated.sphinx import versionchanged
from packaging.version import Version
from limits.typing import Literal, RedisClient
from ..util import get_package_data
from .base import MovingWindowSupport, SlidingWindowCounterSupport, Storage
if TYPE_CHECKING:
import redis
@versionchanged(
version="4.3",
reason=(
"Added support for using the redis client from :pypi:`valkey`"
" if :paramref:`uri` has the ``valkey://`` schema"
),
)
class RedisStorage(Storage, MovingWindowSupport, SlidingWindowCounterSupport):
"""
Rate limit storage with redis as backend.
Depends on :pypi:`redis` (or :pypi:`valkey` if :paramref:`uri` starts with
``valkey://``)
"""
STORAGE_SCHEME = [
"redis",
"rediss",
"redis+unix",
"valkey",
"valkeys",
"valkey+unix",
]
"""The storage scheme for redis"""
DEPENDENCIES = {"redis": Version("3.0"), "valkey": Version("6.0")}
RES_DIR = "resources/redis/lua_scripts"
SCRIPT_MOVING_WINDOW = get_package_data(f"{RES_DIR}/moving_window.lua")
SCRIPT_ACQUIRE_MOVING_WINDOW = get_package_data(
f"{RES_DIR}/acquire_moving_window.lua"
)
SCRIPT_CLEAR_KEYS = get_package_data(f"{RES_DIR}/clear_keys.lua")
SCRIPT_INCR_EXPIRE = get_package_data(f"{RES_DIR}/incr_expire.lua")
SCRIPT_SLIDING_WINDOW = get_package_data(f"{RES_DIR}/sliding_window.lua")
SCRIPT_ACQUIRE_SLIDING_WINDOW = get_package_data(
f"{RES_DIR}/acquire_sliding_window.lua"
)
lua_moving_window: redis.commands.core.Script
lua_acquire_moving_window: redis.commands.core.Script
lua_sliding_window: redis.commands.core.Script
lua_acquire_sliding_window: redis.commands.core.Script
PREFIX = "LIMITS"
target_server: Literal["redis", "valkey"]
def __init__(
self,
uri: str,
connection_pool: redis.connection.ConnectionPool | None = None,
wrap_exceptions: bool = False,
**options: float | str | bool,
) -> None:
"""
:param uri: uri of the form ``redis://[:password]@host:port``,
``redis://[:password]@host:port/db``,
``rediss://[:password]@host:port``, ``redis+unix:///path/to/sock`` etc.
This uri is passed directly to :func:`redis.from_url` except for the
case of ``redis+unix://`` where it is replaced with ``unix://``.
If the uri scheme is ``valkey`` the implementation used will be from
:pypi:`valkey`.
:param connection_pool: if provided, the redis client is initialized with
the connection pool and any other params passed as :paramref:`options`
:param wrap_exceptions: Whether to wrap storage exceptions in
:exc:`limits.errors.StorageError` before raising it.
:param options: all remaining keyword arguments are passed
directly to the constructor of :class:`redis.Redis`
:raise ConfigurationError: when the :pypi:`redis` library is not available
"""
super().__init__(uri, wrap_exceptions=wrap_exceptions, **options)
self.target_server = "valkey" if uri.startswith("valkey") else "redis"
self.dependency = self.dependencies[self.target_server].module
uri = uri.replace(f"{self.target_server}+unix", "unix")
if not connection_pool:
self.storage = self.dependency.from_url(uri, **options)
else:
if self.target_server == "redis":
self.storage = self.dependency.Redis(
connection_pool=connection_pool, **options
)
else:
self.storage = self.dependency.Valkey(
connection_pool=connection_pool, **options
)
self.initialize_storage(uri)
@property
def base_exceptions(
self,
) -> type[Exception] | tuple[type[Exception], ...]: # pragma: no cover
return ( # type: ignore[no-any-return]
self.dependency.RedisError
if self.target_server == "redis"
else self.dependency.ValkeyError
)
def initialize_storage(self, _uri: str) -> None:
self.lua_moving_window = self.get_connection().register_script(
self.SCRIPT_MOVING_WINDOW
)
self.lua_acquire_moving_window = self.get_connection().register_script(
self.SCRIPT_ACQUIRE_MOVING_WINDOW
)
self.lua_clear_keys = self.get_connection().register_script(
self.SCRIPT_CLEAR_KEYS
)
self.lua_incr_expire = self.get_connection().register_script(
self.SCRIPT_INCR_EXPIRE
)
self.lua_sliding_window = self.get_connection().register_script(
self.SCRIPT_SLIDING_WINDOW
)
self.lua_acquire_sliding_window = self.get_connection().register_script(
self.SCRIPT_ACQUIRE_SLIDING_WINDOW
)
def get_connection(self, readonly: bool = False) -> RedisClient:
return cast(RedisClient, self.storage)
def _current_window_key(self, key: str) -> str:
"""
Return the current window's storage key (Sliding window strategy)
Contrary to other strategies that have one key per rate limit item,
this strategy has two keys per rate limit item than must be on the same machine.
To keep the current key and the previous key on the same Redis cluster node,
curly braces are added.
Eg: "{constructed_key}"
"""
return f"{{{key}}}"
def _previous_window_key(self, key: str) -> str:
"""
Return the previous window's storage key (Sliding window strategy).
Curvy braces are added on the common pattern with the current window's key,
so the current and the previous key are stored on the same Redis cluster node.
Eg: "{constructed_key}/-1"
"""
return f"{self._current_window_key(key)}/-1"
def prefixed_key(self, key: str) -> str:
return f"{self.PREFIX}:{key}"
def get_moving_window(self, key: str, limit: int, expiry: int) -> tuple[float, int]:
"""
returns the starting point and the number of entries in the moving
window
:param key: rate limit key
:param expiry: expiry of entry
:return: (start of window, number of acquired entries)
"""
key = self.prefixed_key(key)
timestamp = time.time()
if window := self.lua_moving_window([key], [timestamp - expiry, limit]):
return float(window[0]), window[1]
return timestamp, 0
def get_sliding_window(
self, key: str, expiry: int
) -> tuple[int, float, int, float]:
previous_key = self.prefixed_key(self._previous_window_key(key))
current_key = self.prefixed_key(self._current_window_key(key))
if window := self.lua_sliding_window([previous_key, current_key], [expiry]):
return (
int(window[0] or 0),
max(0, float(window[1] or 0)) / 1000,
int(window[2] or 0),
max(0, float(window[3] or 0)) / 1000,
)
return 0, 0.0, 0, 0.0
def incr(
self,
key: str,
expiry: int,
amount: int = 1,
) -> int:
"""
increments the counter for a given rate limit key
:param key: the key to increment
:param expiry: amount in seconds for the key to expire in
:param amount: the number to increment by
"""
key = self.prefixed_key(key)
return int(self.lua_incr_expire([key], [expiry, amount]))
def get(self, key: str) -> int:
"""
:param key: the key to get the counter value for
"""
key = self.prefixed_key(key)
return int(self.get_connection(True).get(key) or 0)
def clear(self, key: str) -> None:
"""
:param key: the key to clear rate limits for
"""
key = self.prefixed_key(key)
self.get_connection().delete(key)
def acquire_entry(
self,
key: str,
limit: int,
expiry: int,
amount: int = 1,
) -> bool:
"""
:param key: rate limit key to acquire an entry in
:param limit: amount of entries allowed
:param expiry: expiry of the entry
:param amount: the number of entries to acquire
"""
key = self.prefixed_key(key)
timestamp = time.time()
acquired = self.lua_acquire_moving_window(
[key], [timestamp, limit, expiry, amount]
)
return bool(acquired)
def acquire_sliding_window_entry(
self,
key: str,
limit: int,
expiry: int,
amount: int = 1,
) -> bool:
"""
Acquire an entry. Shift the current window to the previous window if it expired.
:param key: rate limit key to acquire an entry in
:param limit: amount of entries allowed
:param expiry: expiry of the entry
:param amount: the number of entries to acquire
"""
previous_key = self.prefixed_key(self._previous_window_key(key))
current_key = self.prefixed_key(self._current_window_key(key))
acquired = self.lua_acquire_sliding_window(
[previous_key, current_key], [limit, expiry, amount]
)
return bool(acquired)
def get_expiry(self, key: str) -> float:
"""
:param key: the key to get the expiry for
"""
key = self.prefixed_key(key)
return max(self.get_connection(True).ttl(key), 0) + time.time()
def check(self) -> bool:
"""
check if storage is healthy
"""
try:
return self.get_connection().ping()
except: # noqa
return False
def reset(self) -> int | None:
"""
This function calls a Lua Script to delete keys prefixed with
``self.PREFIX`` in blocks of 5000.
.. warning::
This operation was designed to be fast, but was not tested
on a large production based system. Be careful with its usage as it
could be slow on very large data sets.
"""
prefix = self.prefixed_key("*")
return int(self.lua_clear_keys([prefix]))

View File

@ -0,0 +1,125 @@
from __future__ import annotations
import urllib
from deprecated.sphinx import versionchanged
from packaging.version import Version
from limits.storage.redis import RedisStorage
@versionchanged(
version="3.14.0",
reason="""
Dropped support for the :pypi:`redis-py-cluster` library
which has been abandoned/deprecated.
""",
)
@versionchanged(
version="2.5.0",
reason="""
Cluster support was provided by the :pypi:`redis-py-cluster` library
which has been absorbed into the official :pypi:`redis` client. By
default the :class:`redis.cluster.RedisCluster` client will be used
however if the version of the package is lower than ``4.2.0`` the implementation
will fallback to trying to use :class:`rediscluster.RedisCluster`.
""",
)
@versionchanged(
version="4.3",
reason=(
"Added support for using the redis client from :pypi:`valkey`"
" if :paramref:`uri` has the ``valkey+cluster://`` schema"
),
)
class RedisClusterStorage(RedisStorage):
"""
Rate limit storage with redis cluster as backend
Depends on :pypi:`redis` (or :pypi:`valkey` if :paramref:`uri`
starts with ``valkey+cluster://``).
"""
STORAGE_SCHEME = ["redis+cluster", "valkey+cluster"]
"""The storage scheme for redis cluster"""
DEFAULT_OPTIONS: dict[str, float | str | bool] = {
"max_connections": 1000,
}
"Default options passed to the :class:`~redis.cluster.RedisCluster`"
DEPENDENCIES = {
"redis": Version("4.2.0"),
"valkey": Version("6.0"),
}
def __init__(
self,
uri: str,
wrap_exceptions: bool = False,
**options: float | str | bool,
) -> None:
"""
:param uri: url of the form
``redis+cluster://[:password]@host:port,host:port``
If the uri scheme is ``valkey+cluster`` the implementation used will be from
:pypi:`valkey`.
:param wrap_exceptions: Whether to wrap storage exceptions in
:exc:`limits.errors.StorageError` before raising it.
:param options: all remaining keyword arguments are passed
directly to the constructor of :class:`redis.cluster.RedisCluster`
:raise ConfigurationError: when the :pypi:`redis` library is not
available or if the redis cluster cannot be reached.
"""
parsed = urllib.parse.urlparse(uri)
parsed_auth: dict[str, float | str | bool] = {}
if parsed.username:
parsed_auth["username"] = parsed.username
if parsed.password:
parsed_auth["password"] = parsed.password
sep = parsed.netloc.find("@") + 1
cluster_hosts = []
for loc in parsed.netloc[sep:].split(","):
host, port = loc.split(":")
cluster_hosts.append((host, int(port)))
self.storage = None
self.target_server = "valkey" if uri.startswith("valkey") else "redis"
merged_options = {**self.DEFAULT_OPTIONS, **parsed_auth, **options}
self.dependency = self.dependencies[self.target_server].module
startup_nodes = [self.dependency.cluster.ClusterNode(*c) for c in cluster_hosts]
if self.target_server == "redis":
self.storage = self.dependency.cluster.RedisCluster(
startup_nodes=startup_nodes, **merged_options
)
else:
self.storage = self.dependency.cluster.ValkeyCluster(
startup_nodes=startup_nodes, **merged_options
)
assert self.storage
self.initialize_storage(uri)
super(RedisStorage, self).__init__(uri, wrap_exceptions, **options)
def reset(self) -> int | None:
"""
Redis Clusters are sharded and deleting across shards
can't be done atomically. Because of this, this reset loops over all
keys that are prefixed with ``self.PREFIX`` and calls delete on them,
one at a time.
.. warning::
This operation was not tested with extremely large data sets.
On a large production based system, care should be taken with its
usage as it could be slow on very large data sets"""
prefix = self.prefixed_key("*")
count = 0
for primary in self.storage.get_primaries():
node = self.storage.get_redis_connection(primary)
keys = node.keys(prefix)
count += sum([node.delete(k.decode("utf-8")) for k in keys])
return count

View File

@ -0,0 +1,120 @@
from __future__ import annotations
import urllib.parse
from typing import TYPE_CHECKING
from deprecated.sphinx import versionchanged
from packaging.version import Version
from limits.errors import ConfigurationError
from limits.storage.redis import RedisStorage
from limits.typing import RedisClient
if TYPE_CHECKING:
pass
@versionchanged(
version="4.3",
reason=(
"Added support for using the redis client from :pypi:`valkey`"
" if :paramref:`uri` has the ``valkey+sentinel://`` schema"
),
)
class RedisSentinelStorage(RedisStorage):
"""
Rate limit storage with redis sentinel as backend
Depends on :pypi:`redis` package (or :pypi:`valkey` if :paramref:`uri` starts with
``valkey+sentinel://``)
"""
STORAGE_SCHEME = ["redis+sentinel", "valkey+sentinel"]
"""The storage scheme for redis accessed via a redis sentinel installation"""
DEPENDENCIES = {
"redis": Version("3.0"),
"redis.sentinel": Version("3.0"),
"valkey": Version("6.0"),
"valkey.sentinel": Version("6.0"),
}
def __init__(
self,
uri: str,
service_name: str | None = None,
use_replicas: bool = True,
sentinel_kwargs: dict[str, float | str | bool] | None = None,
wrap_exceptions: bool = False,
**options: float | str | bool,
) -> None:
"""
:param uri: url of the form
``redis+sentinel://host:port,host:port/service_name``
If the uri scheme is ``valkey+sentinel`` the implementation used will be from
:pypi:`valkey`.
:param service_name: sentinel service name
(if not provided in :attr:`uri`)
:param use_replicas: Whether to use replicas for read only operations
:param sentinel_kwargs: kwargs to pass as
:attr:`sentinel_kwargs` to :class:`redis.sentinel.Sentinel`
:param wrap_exceptions: Whether to wrap storage exceptions in
:exc:`limits.errors.StorageError` before raising it.
:param options: all remaining keyword arguments are passed
directly to the constructor of :class:`redis.sentinel.Sentinel`
:raise ConfigurationError: when the redis library is not available
or if the redis master host cannot be pinged.
"""
super(RedisStorage, self).__init__(
uri, wrap_exceptions=wrap_exceptions, **options
)
parsed = urllib.parse.urlparse(uri)
sentinel_configuration = []
sentinel_options = sentinel_kwargs.copy() if sentinel_kwargs else {}
parsed_auth: dict[str, float | str | bool] = {}
if parsed.username:
parsed_auth["username"] = parsed.username
if parsed.password:
parsed_auth["password"] = parsed.password
sep = parsed.netloc.find("@") + 1
for loc in parsed.netloc[sep:].split(","):
host, port = loc.split(":")
sentinel_configuration.append((host, int(port)))
self.service_name = (
parsed.path.replace("/", "") if parsed.path else service_name
)
if self.service_name is None:
raise ConfigurationError("'service_name' not provided")
self.target_server = "valkey" if uri.startswith("valkey") else "redis"
sentinel_dep = self.dependencies[f"{self.target_server}.sentinel"].module
self.sentinel = sentinel_dep.Sentinel(
sentinel_configuration,
sentinel_kwargs={**parsed_auth, **sentinel_options},
**{**parsed_auth, **options},
)
self.storage: RedisClient = self.sentinel.master_for(self.service_name)
self.storage_slave: RedisClient = self.sentinel.slave_for(self.service_name)
self.use_replicas = use_replicas
self.initialize_storage(uri)
@property
def base_exceptions(
self,
) -> type[Exception] | tuple[type[Exception], ...]: # pragma: no cover
return ( # type: ignore[no-any-return]
self.dependencies["redis"].module.RedisError
if self.target_server == "redis"
else self.dependencies["valkey"].module.ValkeyError
)
def get_connection(self, readonly: bool = False) -> RedisClient:
return self.storage_slave if (readonly and self.use_replicas) else self.storage

View File

@ -0,0 +1,24 @@
from __future__ import annotations
from abc import ABCMeta
SCHEMES: dict[str, StorageRegistry] = {}
class StorageRegistry(ABCMeta):
def __new__(
mcs, name: str, bases: tuple[type, ...], dct: dict[str, str | list[str]]
) -> StorageRegistry:
storage_scheme = dct.get("STORAGE_SCHEME", None)
cls = super().__new__(mcs, name, bases, dct)
if storage_scheme:
if isinstance(storage_scheme, str): # noqa
schemes = [storage_scheme]
else:
schemes = storage_scheme
for scheme in schemes:
SCHEMES[scheme] = cls
return cls