Update 2025-04-24_11:44:19
This commit is contained in:
80
venv/lib/python3.11/site-packages/limits/storage/__init__.py
Normal file
80
venv/lib/python3.11/site-packages/limits/storage/__init__.py
Normal file
@ -0,0 +1,80 @@
|
||||
"""
|
||||
Implementations of storage backends to be used with
|
||||
:class:`limits.strategies.RateLimiter` strategies
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import urllib
|
||||
|
||||
import limits # noqa
|
||||
|
||||
from ..errors import ConfigurationError
|
||||
from ..typing import TypeAlias, cast
|
||||
from .base import MovingWindowSupport, SlidingWindowCounterSupport, Storage
|
||||
from .memcached import MemcachedStorage
|
||||
from .memory import MemoryStorage
|
||||
from .mongodb import MongoDBStorage, MongoDBStorageBase
|
||||
from .redis import RedisStorage
|
||||
from .redis_cluster import RedisClusterStorage
|
||||
from .redis_sentinel import RedisSentinelStorage
|
||||
from .registry import SCHEMES
|
||||
|
||||
StorageTypes: TypeAlias = "Storage | limits.aio.storage.Storage"
|
||||
|
||||
|
||||
def storage_from_string(
|
||||
storage_string: str, **options: float | str | bool
|
||||
) -> StorageTypes:
|
||||
"""
|
||||
Factory function to get an instance of the storage class based
|
||||
on the uri of the storage. In most cases using it should be sufficient
|
||||
instead of directly instantiating the storage classes. for example::
|
||||
|
||||
from limits.storage import storage_from_string
|
||||
|
||||
memory = storage_from_string("memory://")
|
||||
memcached = storage_from_string("memcached://localhost:11211")
|
||||
redis = storage_from_string("redis://localhost:6379")
|
||||
|
||||
The same function can be used to construct the :ref:`storage:async storage`
|
||||
variants, for example::
|
||||
|
||||
from limits.storage import storage_from_string
|
||||
|
||||
memory = storage_from_string("async+memory://")
|
||||
memcached = storage_from_string("async+memcached://localhost:11211")
|
||||
redis = storage_from_string("async+redis://localhost:6379")
|
||||
|
||||
:param storage_string: a string of the form ``scheme://host:port``.
|
||||
More details about supported storage schemes can be found at
|
||||
:ref:`storage:storage scheme`
|
||||
:param options: all remaining keyword arguments are passed to the
|
||||
constructor matched by :paramref:`storage_string`.
|
||||
:raises ConfigurationError: when the :attr:`storage_string` cannot be
|
||||
mapped to a registered :class:`limits.storage.Storage`
|
||||
or :class:`limits.aio.storage.Storage` instance.
|
||||
|
||||
|
||||
"""
|
||||
scheme = urllib.parse.urlparse(storage_string).scheme
|
||||
|
||||
if scheme not in SCHEMES:
|
||||
raise ConfigurationError(f"unknown storage scheme : {storage_string}")
|
||||
|
||||
return cast(StorageTypes, SCHEMES[scheme](storage_string, **options))
|
||||
|
||||
|
||||
__all__ = [
|
||||
"MemcachedStorage",
|
||||
"MemoryStorage",
|
||||
"MongoDBStorage",
|
||||
"MongoDBStorageBase",
|
||||
"MovingWindowSupport",
|
||||
"RedisClusterStorage",
|
||||
"RedisSentinelStorage",
|
||||
"RedisStorage",
|
||||
"SlidingWindowCounterSupport",
|
||||
"Storage",
|
||||
"storage_from_string",
|
||||
]
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
232
venv/lib/python3.11/site-packages/limits/storage/base.py
Normal file
232
venv/lib/python3.11/site-packages/limits/storage/base.py
Normal file
@ -0,0 +1,232 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import functools
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from limits import errors
|
||||
from limits.storage.registry import StorageRegistry
|
||||
from limits.typing import (
|
||||
Any,
|
||||
Callable,
|
||||
P,
|
||||
R,
|
||||
cast,
|
||||
)
|
||||
from limits.util import LazyDependency
|
||||
|
||||
|
||||
def _wrap_errors(
|
||||
fn: Callable[P, R],
|
||||
) -> Callable[P, R]:
|
||||
@functools.wraps(fn)
|
||||
def inner(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
instance = cast(Storage, args[0])
|
||||
try:
|
||||
return fn(*args, **kwargs)
|
||||
except instance.base_exceptions as exc:
|
||||
if instance.wrap_exceptions:
|
||||
raise errors.StorageError(exc) from exc
|
||||
raise
|
||||
|
||||
return inner
|
||||
|
||||
|
||||
class Storage(LazyDependency, metaclass=StorageRegistry):
|
||||
"""
|
||||
Base class to extend when implementing a storage backend.
|
||||
"""
|
||||
|
||||
STORAGE_SCHEME: list[str] | None
|
||||
"""The storage schemes to register against this implementation"""
|
||||
|
||||
def __init_subclass__(cls, **kwargs: Any) -> None: # type: ignore[explicit-any]
|
||||
for method in {
|
||||
"incr",
|
||||
"get",
|
||||
"get_expiry",
|
||||
"check",
|
||||
"reset",
|
||||
"clear",
|
||||
}:
|
||||
setattr(cls, method, _wrap_errors(getattr(cls, method)))
|
||||
super().__init_subclass__(**kwargs)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
uri: str | None = None,
|
||||
wrap_exceptions: bool = False,
|
||||
**options: float | str | bool,
|
||||
):
|
||||
"""
|
||||
:param wrap_exceptions: Whether to wrap storage exceptions in
|
||||
:exc:`limits.errors.StorageError` before raising it.
|
||||
"""
|
||||
|
||||
super().__init__()
|
||||
self.wrap_exceptions = wrap_exceptions
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def base_exceptions(self) -> type[Exception] | tuple[type[Exception], ...]:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def incr(self, key: str, expiry: int, amount: int = 1) -> int:
|
||||
"""
|
||||
increments the counter for a given rate limit key
|
||||
|
||||
:param key: the key to increment
|
||||
:param expiry: amount in seconds for the key to expire in
|
||||
:param amount: the number to increment by
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def get(self, key: str) -> int:
|
||||
"""
|
||||
:param key: the key to get the counter value for
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def get_expiry(self, key: str) -> float:
|
||||
"""
|
||||
:param key: the key to get the expiry for
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def check(self) -> bool:
|
||||
"""
|
||||
check if storage is healthy
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def reset(self) -> int | None:
|
||||
"""
|
||||
reset storage to clear limits
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def clear(self, key: str) -> None:
|
||||
"""
|
||||
resets the rate limit key
|
||||
|
||||
:param key: the key to clear rate limits for
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class MovingWindowSupport(ABC):
|
||||
"""
|
||||
Abstract base class for storages that support
|
||||
the :ref:`strategies:moving window` strategy
|
||||
"""
|
||||
|
||||
def __init_subclass__(cls, **kwargs: Any) -> None: # type: ignore[explicit-any]
|
||||
for method in {
|
||||
"acquire_entry",
|
||||
"get_moving_window",
|
||||
}:
|
||||
setattr(
|
||||
cls,
|
||||
method,
|
||||
_wrap_errors(getattr(cls, method)),
|
||||
)
|
||||
super().__init_subclass__(**kwargs)
|
||||
|
||||
@abstractmethod
|
||||
def acquire_entry(self, key: str, limit: int, expiry: int, amount: int = 1) -> bool:
|
||||
"""
|
||||
:param key: rate limit key to acquire an entry in
|
||||
:param limit: amount of entries allowed
|
||||
:param expiry: expiry of the entry
|
||||
:param amount: the number of entries to acquire
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def get_moving_window(self, key: str, limit: int, expiry: int) -> tuple[float, int]:
|
||||
"""
|
||||
returns the starting point and the number of entries in the moving
|
||||
window
|
||||
|
||||
:param key: rate limit key
|
||||
:param expiry: expiry of entry
|
||||
:return: (start of window, number of acquired entries)
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class SlidingWindowCounterSupport(ABC):
|
||||
"""
|
||||
Abstract base class for storages that support
|
||||
the :ref:`strategies:sliding window counter` strategy.
|
||||
"""
|
||||
|
||||
def __init_subclass__(cls, **kwargs: Any) -> None: # type: ignore[explicit-any]
|
||||
for method in {"acquire_sliding_window_entry", "get_sliding_window"}:
|
||||
setattr(
|
||||
cls,
|
||||
method,
|
||||
_wrap_errors(getattr(cls, method)),
|
||||
)
|
||||
super().__init_subclass__(**kwargs)
|
||||
|
||||
@abstractmethod
|
||||
def acquire_sliding_window_entry(
|
||||
self, key: str, limit: int, expiry: int, amount: int = 1
|
||||
) -> bool:
|
||||
"""
|
||||
Acquire an entry if the weighted count of the current and previous
|
||||
windows is less than or equal to the limit
|
||||
|
||||
:param key: rate limit key to acquire an entry in
|
||||
:param limit: amount of entries allowed
|
||||
:param expiry: expiry of the entry
|
||||
:param amount: the number of entries to acquire
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def get_sliding_window(
|
||||
self, key: str, expiry: int
|
||||
) -> tuple[int, float, int, float]:
|
||||
"""
|
||||
Return the previous and current window information.
|
||||
|
||||
:param key: the rate limit key
|
||||
:param expiry: the rate limit expiry, needed to compute the key in some implementations
|
||||
:return: a tuple of (int, float, int, float) with the following information:
|
||||
- previous window counter
|
||||
- previous window TTL
|
||||
- current window counter
|
||||
- current window TTL
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class TimestampedSlidingWindow:
|
||||
"""Helper class for storage that support the sliding window counter, with timestamp based keys."""
|
||||
|
||||
@classmethod
|
||||
def sliding_window_keys(cls, key: str, expiry: int, at: float) -> tuple[str, str]:
|
||||
"""
|
||||
returns the previous and the current window's keys.
|
||||
|
||||
:param key: the key to get the window's keys from
|
||||
:param expiry: the expiry of the limit item, in seconds
|
||||
:param at: the timestamp to get the keys from. Default to now, ie ``time.time()``
|
||||
|
||||
Returns a tuple with the previous and the current key: (previous, current).
|
||||
|
||||
Example:
|
||||
- key = "mykey"
|
||||
- expiry = 60
|
||||
- at = 1738576292.6631825
|
||||
|
||||
The return value will be the tuple ``("mykey/28976271", "mykey/28976270")``.
|
||||
"""
|
||||
return f"{key}/{int((at - expiry) / expiry)}", f"{key}/{int(at / expiry)}"
|
299
venv/lib/python3.11/site-packages/limits/storage/memcached.py
Normal file
299
venv/lib/python3.11/site-packages/limits/storage/memcached.py
Normal file
@ -0,0 +1,299 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
import threading
|
||||
import time
|
||||
import urllib.parse
|
||||
from collections.abc import Iterable
|
||||
from math import ceil, floor
|
||||
from types import ModuleType
|
||||
|
||||
from limits.errors import ConfigurationError
|
||||
from limits.storage.base import (
|
||||
SlidingWindowCounterSupport,
|
||||
Storage,
|
||||
TimestampedSlidingWindow,
|
||||
)
|
||||
from limits.typing import (
|
||||
Any,
|
||||
Callable,
|
||||
MemcachedClientP,
|
||||
P,
|
||||
R,
|
||||
cast,
|
||||
)
|
||||
from limits.util import get_dependency
|
||||
|
||||
|
||||
class MemcachedStorage(Storage, SlidingWindowCounterSupport, TimestampedSlidingWindow):
|
||||
"""
|
||||
Rate limit storage with memcached as backend.
|
||||
|
||||
Depends on :pypi:`pymemcache`.
|
||||
"""
|
||||
|
||||
STORAGE_SCHEME = ["memcached"]
|
||||
"""The storage scheme for memcached"""
|
||||
DEPENDENCIES = ["pymemcache"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
uri: str,
|
||||
wrap_exceptions: bool = False,
|
||||
**options: str | Callable[[], MemcachedClientP],
|
||||
) -> None:
|
||||
"""
|
||||
:param uri: memcached location of the form
|
||||
``memcached://host:port,host:port``,
|
||||
``memcached:///var/tmp/path/to/sock``
|
||||
:param wrap_exceptions: Whether to wrap storage exceptions in
|
||||
:exc:`limits.errors.StorageError` before raising it.
|
||||
:param options: all remaining keyword arguments are passed
|
||||
directly to the constructor of :class:`pymemcache.client.base.PooledClient`
|
||||
or :class:`pymemcache.client.hash.HashClient` (if there are more than
|
||||
one hosts specified)
|
||||
:raise ConfigurationError: when :pypi:`pymemcache` is not available
|
||||
"""
|
||||
parsed = urllib.parse.urlparse(uri)
|
||||
self.hosts = []
|
||||
|
||||
for loc in parsed.netloc.strip().split(","):
|
||||
if not loc:
|
||||
continue
|
||||
host, port = loc.split(":")
|
||||
self.hosts.append((host, int(port)))
|
||||
else:
|
||||
# filesystem path to UDS
|
||||
|
||||
if parsed.path and not parsed.netloc and not parsed.port:
|
||||
self.hosts = [parsed.path] # type: ignore
|
||||
|
||||
self.dependency = self.dependencies["pymemcache"].module
|
||||
self.library = str(options.pop("library", "pymemcache.client"))
|
||||
self.cluster_library = str(
|
||||
options.pop("cluster_library", "pymemcache.client.hash")
|
||||
)
|
||||
self.client_getter = cast(
|
||||
Callable[[ModuleType, list[tuple[str, int]]], MemcachedClientP],
|
||||
options.pop("client_getter", self.get_client),
|
||||
)
|
||||
self.options = options
|
||||
|
||||
if not get_dependency(self.library):
|
||||
raise ConfigurationError(
|
||||
f"memcached prerequisite not available. please install {self.library}"
|
||||
) # pragma: no cover
|
||||
self.local_storage = threading.local()
|
||||
self.local_storage.storage = None
|
||||
super().__init__(uri, wrap_exceptions=wrap_exceptions)
|
||||
|
||||
@property
|
||||
def base_exceptions(
|
||||
self,
|
||||
) -> type[Exception] | tuple[type[Exception], ...]: # pragma: no cover
|
||||
return self.dependency.MemcacheError # type: ignore[no-any-return]
|
||||
|
||||
def get_client(
|
||||
self, module: ModuleType, hosts: list[tuple[str, int]], **kwargs: str
|
||||
) -> MemcachedClientP:
|
||||
"""
|
||||
returns a memcached client.
|
||||
|
||||
:param module: the memcached module
|
||||
:param hosts: list of memcached hosts
|
||||
"""
|
||||
|
||||
return cast(
|
||||
MemcachedClientP,
|
||||
(
|
||||
module.HashClient(hosts, **kwargs)
|
||||
if len(hosts) > 1
|
||||
else module.PooledClient(*hosts, **kwargs)
|
||||
),
|
||||
)
|
||||
|
||||
def call_memcached_func(
|
||||
self, func: Callable[P, R], *args: P.args, **kwargs: P.kwargs
|
||||
) -> R:
|
||||
if "noreply" in kwargs:
|
||||
argspec = inspect.getfullargspec(func)
|
||||
|
||||
if not ("noreply" in argspec.args or argspec.varkw):
|
||||
kwargs.pop("noreply")
|
||||
|
||||
return func(*args, **kwargs)
|
||||
|
||||
@property
|
||||
def storage(self) -> MemcachedClientP:
|
||||
"""
|
||||
lazily creates a memcached client instance using a thread local
|
||||
"""
|
||||
|
||||
if not (hasattr(self.local_storage, "storage") and self.local_storage.storage):
|
||||
dependency = get_dependency(
|
||||
self.cluster_library if len(self.hosts) > 1 else self.library
|
||||
)[0]
|
||||
|
||||
if not dependency:
|
||||
raise ConfigurationError(f"Unable to import {self.cluster_library}")
|
||||
self.local_storage.storage = self.client_getter(
|
||||
dependency, self.hosts, **self.options
|
||||
)
|
||||
|
||||
return cast(MemcachedClientP, self.local_storage.storage)
|
||||
|
||||
def get(self, key: str) -> int:
|
||||
"""
|
||||
:param key: the key to get the counter value for
|
||||
"""
|
||||
return int(self.storage.get(key, "0"))
|
||||
|
||||
def get_many(self, keys: Iterable[str]) -> dict[str, Any]: # type:ignore[explicit-any]
|
||||
"""
|
||||
Return multiple counters at once
|
||||
|
||||
:param keys: the keys to get the counter values for
|
||||
|
||||
:meta private:
|
||||
"""
|
||||
return self.storage.get_many(keys)
|
||||
|
||||
def clear(self, key: str) -> None:
|
||||
"""
|
||||
:param key: the key to clear rate limits for
|
||||
"""
|
||||
self.storage.delete(key)
|
||||
|
||||
def incr(
|
||||
self,
|
||||
key: str,
|
||||
expiry: float,
|
||||
amount: int = 1,
|
||||
set_expiration_key: bool = True,
|
||||
) -> int:
|
||||
"""
|
||||
increments the counter for a given rate limit key
|
||||
|
||||
:param key: the key to increment
|
||||
:param expiry: amount in seconds for the key to expire in
|
||||
window every hit.
|
||||
:param amount: the number to increment by
|
||||
:param set_expiration_key: set the expiration key with the expiration time if needed. If set to False, the key will still expire, but memcached cannot provide the expiration time.
|
||||
"""
|
||||
if (
|
||||
value := self.call_memcached_func(
|
||||
self.storage.incr, key, amount, noreply=False
|
||||
)
|
||||
) is not None:
|
||||
return value
|
||||
else:
|
||||
if not self.call_memcached_func(
|
||||
self.storage.add, key, amount, ceil(expiry), noreply=False
|
||||
):
|
||||
return self.storage.incr(key, amount) or amount
|
||||
else:
|
||||
if set_expiration_key:
|
||||
self.call_memcached_func(
|
||||
self.storage.set,
|
||||
self._expiration_key(key),
|
||||
expiry + time.time(),
|
||||
expire=ceil(expiry),
|
||||
noreply=False,
|
||||
)
|
||||
|
||||
return amount
|
||||
|
||||
def get_expiry(self, key: str) -> float:
|
||||
"""
|
||||
:param key: the key to get the expiry for
|
||||
"""
|
||||
|
||||
return float(self.storage.get(self._expiration_key(key)) or time.time())
|
||||
|
||||
def _expiration_key(self, key: str) -> str:
|
||||
"""
|
||||
Return the expiration key for the given counter key.
|
||||
|
||||
Memcached doesn't natively return the expiration time or TTL for a given key,
|
||||
so we implement the expiration time on a separate key.
|
||||
"""
|
||||
return key + "/expires"
|
||||
|
||||
def check(self) -> bool:
|
||||
"""
|
||||
Check if storage is healthy by calling the ``get`` command
|
||||
on the key ``limiter-check``
|
||||
"""
|
||||
try:
|
||||
self.call_memcached_func(self.storage.get, "limiter-check")
|
||||
|
||||
return True
|
||||
except: # noqa
|
||||
return False
|
||||
|
||||
def reset(self) -> int | None:
|
||||
raise NotImplementedError
|
||||
|
||||
def acquire_sliding_window_entry(
|
||||
self,
|
||||
key: str,
|
||||
limit: int,
|
||||
expiry: int,
|
||||
amount: int = 1,
|
||||
) -> bool:
|
||||
if amount > limit:
|
||||
return False
|
||||
now = time.time()
|
||||
previous_key, current_key = self.sliding_window_keys(key, expiry, now)
|
||||
previous_count, previous_ttl, current_count, _ = self._get_sliding_window_info(
|
||||
previous_key, current_key, expiry, now=now
|
||||
)
|
||||
weighted_count = previous_count * previous_ttl / expiry + current_count
|
||||
if floor(weighted_count) + amount > limit:
|
||||
return False
|
||||
else:
|
||||
# Hit, increase the current counter.
|
||||
# If the counter doesn't exist yet, set twice the theorical expiry.
|
||||
# We don't need the expiration key as it is estimated with the timestamps directly.
|
||||
current_count = self.incr(
|
||||
current_key, 2 * expiry, amount=amount, set_expiration_key=False
|
||||
)
|
||||
actualised_previous_ttl = min(0, previous_ttl - (time.time() - now))
|
||||
weighted_count = (
|
||||
previous_count * actualised_previous_ttl / expiry + current_count
|
||||
)
|
||||
if floor(weighted_count) > limit:
|
||||
# Another hit won the race condition: revert the incrementation and refuse this hit
|
||||
# Limitation: during high concurrency at the end of the window,
|
||||
# the counter is shifted and cannot be decremented, so less requests than expected are allowed.
|
||||
self.call_memcached_func(
|
||||
self.storage.decr,
|
||||
current_key,
|
||||
amount,
|
||||
noreply=True,
|
||||
)
|
||||
return False
|
||||
return True
|
||||
|
||||
def get_sliding_window(
|
||||
self, key: str, expiry: int
|
||||
) -> tuple[int, float, int, float]:
|
||||
now = time.time()
|
||||
previous_key, current_key = self.sliding_window_keys(key, expiry, now)
|
||||
return self._get_sliding_window_info(previous_key, current_key, expiry, now)
|
||||
|
||||
def _get_sliding_window_info(
|
||||
self, previous_key: str, current_key: str, expiry: int, now: float
|
||||
) -> tuple[int, float, int, float]:
|
||||
result = self.get_many([previous_key, current_key])
|
||||
previous_count, current_count = (
|
||||
int(result.get(previous_key, 0)),
|
||||
int(result.get(current_key, 0)),
|
||||
)
|
||||
|
||||
if previous_count == 0:
|
||||
previous_ttl = float(0)
|
||||
else:
|
||||
previous_ttl = (1 - (((now - expiry) / expiry) % 1)) * expiry
|
||||
current_ttl = (1 - ((now / expiry) % 1)) * expiry + expiry
|
||||
return previous_count, previous_ttl, current_count, current_ttl
|
253
venv/lib/python3.11/site-packages/limits/storage/memory.py
Normal file
253
venv/lib/python3.11/site-packages/limits/storage/memory.py
Normal file
@ -0,0 +1,253 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import bisect
|
||||
import threading
|
||||
import time
|
||||
from collections import Counter, defaultdict
|
||||
from math import floor
|
||||
|
||||
import limits.typing
|
||||
from limits.storage.base import (
|
||||
MovingWindowSupport,
|
||||
SlidingWindowCounterSupport,
|
||||
Storage,
|
||||
TimestampedSlidingWindow,
|
||||
)
|
||||
|
||||
|
||||
class Entry:
|
||||
def __init__(self, expiry: float) -> None:
|
||||
self.atime = time.time()
|
||||
self.expiry = self.atime + expiry
|
||||
|
||||
|
||||
class MemoryStorage(
|
||||
Storage, MovingWindowSupport, SlidingWindowCounterSupport, TimestampedSlidingWindow
|
||||
):
|
||||
"""
|
||||
rate limit storage using :class:`collections.Counter`
|
||||
as an in memory storage for fixed and sliding window strategies,
|
||||
and a simple list to implement moving window strategy.
|
||||
|
||||
"""
|
||||
|
||||
STORAGE_SCHEME = ["memory"]
|
||||
|
||||
def __init__(self, uri: str | None = None, wrap_exceptions: bool = False, **_: str):
|
||||
self.storage: limits.typing.Counter[str] = Counter()
|
||||
self.locks: defaultdict[str, threading.RLock] = defaultdict(threading.RLock)
|
||||
self.expirations: dict[str, float] = {}
|
||||
self.events: dict[str, list[Entry]] = {}
|
||||
self.timer: threading.Timer = threading.Timer(0.01, self.__expire_events)
|
||||
self.timer.start()
|
||||
super().__init__(uri, wrap_exceptions=wrap_exceptions, **_)
|
||||
|
||||
def __getstate__(self) -> dict[str, limits.typing.Any]: # type: ignore[explicit-any]
|
||||
state = self.__dict__.copy()
|
||||
del state["timer"]
|
||||
del state["locks"]
|
||||
return state
|
||||
|
||||
def __setstate__(self, state: dict[str, limits.typing.Any]) -> None: # type: ignore[explicit-any]
|
||||
self.__dict__.update(state)
|
||||
self.locks = defaultdict(threading.RLock)
|
||||
self.timer = threading.Timer(0.01, self.__expire_events)
|
||||
self.timer.start()
|
||||
|
||||
def __expire_events(self) -> None:
|
||||
for key in list(self.events.keys()):
|
||||
with self.locks[key]:
|
||||
if events := self.events.get(key, []):
|
||||
oldest = bisect.bisect_left(
|
||||
events, -time.time(), key=lambda event: -event.expiry
|
||||
)
|
||||
self.events[key] = self.events[key][:oldest]
|
||||
if not self.events.get(key, None):
|
||||
self.locks.pop(key, None)
|
||||
for key in list(self.expirations.keys()):
|
||||
if self.expirations[key] <= time.time():
|
||||
self.storage.pop(key, None)
|
||||
self.expirations.pop(key, None)
|
||||
self.locks.pop(key, None)
|
||||
|
||||
def __schedule_expiry(self) -> None:
|
||||
if not self.timer.is_alive():
|
||||
self.timer = threading.Timer(0.01, self.__expire_events)
|
||||
self.timer.start()
|
||||
|
||||
@property
|
||||
def base_exceptions(
|
||||
self,
|
||||
) -> type[Exception] | tuple[type[Exception], ...]: # pragma: no cover
|
||||
return ValueError
|
||||
|
||||
def incr(self, key: str, expiry: float, amount: int = 1) -> int:
|
||||
"""
|
||||
increments the counter for a given rate limit key
|
||||
|
||||
:param key: the key to increment
|
||||
:param expiry: amount in seconds for the key to expire in
|
||||
:param amount: the number to increment by
|
||||
"""
|
||||
self.get(key)
|
||||
self.__schedule_expiry()
|
||||
with self.locks[key]:
|
||||
self.storage[key] += amount
|
||||
if self.storage[key] == amount:
|
||||
self.expirations[key] = time.time() + expiry
|
||||
return self.storage.get(key, 0)
|
||||
|
||||
def decr(self, key: str, amount: int = 1) -> int:
|
||||
"""
|
||||
decrements the counter for a given rate limit key
|
||||
|
||||
:param key: the key to decrement
|
||||
:param amount: the number to decrement by
|
||||
"""
|
||||
self.get(key)
|
||||
self.__schedule_expiry()
|
||||
with self.locks[key]:
|
||||
self.storage[key] = max(self.storage[key] - amount, 0)
|
||||
|
||||
return self.storage.get(key, 0)
|
||||
|
||||
def get(self, key: str) -> int:
|
||||
"""
|
||||
:param key: the key to get the counter value for
|
||||
"""
|
||||
|
||||
if self.expirations.get(key, 0) <= time.time():
|
||||
self.storage.pop(key, None)
|
||||
self.expirations.pop(key, None)
|
||||
self.locks.pop(key, None)
|
||||
|
||||
return self.storage.get(key, 0)
|
||||
|
||||
def clear(self, key: str) -> None:
|
||||
"""
|
||||
:param key: the key to clear rate limits for
|
||||
"""
|
||||
self.storage.pop(key, None)
|
||||
self.expirations.pop(key, None)
|
||||
self.events.pop(key, None)
|
||||
self.locks.pop(key, None)
|
||||
|
||||
def acquire_entry(self, key: str, limit: int, expiry: int, amount: int = 1) -> bool:
|
||||
"""
|
||||
:param key: rate limit key to acquire an entry in
|
||||
:param limit: amount of entries allowed
|
||||
:param expiry: expiry of the entry
|
||||
:param amount: the number of entries to acquire
|
||||
"""
|
||||
if amount > limit:
|
||||
return False
|
||||
|
||||
self.__schedule_expiry()
|
||||
with self.locks[key]:
|
||||
self.events.setdefault(key, [])
|
||||
timestamp = time.time()
|
||||
try:
|
||||
entry = self.events[key][limit - amount]
|
||||
except IndexError:
|
||||
entry = None
|
||||
|
||||
if entry and entry.atime >= timestamp - expiry:
|
||||
return False
|
||||
else:
|
||||
self.events[key][:0] = [Entry(expiry)] * amount
|
||||
return True
|
||||
|
||||
def get_expiry(self, key: str) -> float:
|
||||
"""
|
||||
:param key: the key to get the expiry for
|
||||
"""
|
||||
|
||||
return self.expirations.get(key, time.time())
|
||||
|
||||
def get_moving_window(self, key: str, limit: int, expiry: int) -> tuple[float, int]:
|
||||
"""
|
||||
returns the starting point and the number of entries in the moving
|
||||
window
|
||||
|
||||
:param key: rate limit key
|
||||
:param expiry: expiry of entry
|
||||
:return: (start of window, number of acquired entries)
|
||||
"""
|
||||
timestamp = time.time()
|
||||
if events := self.events.get(key, []):
|
||||
oldest = bisect.bisect_left(
|
||||
events, -(timestamp - expiry), key=lambda entry: -entry.atime
|
||||
)
|
||||
return events[oldest - 1].atime, oldest
|
||||
return timestamp, 0
|
||||
|
||||
def acquire_sliding_window_entry(
|
||||
self,
|
||||
key: str,
|
||||
limit: int,
|
||||
expiry: int,
|
||||
amount: int = 1,
|
||||
) -> bool:
|
||||
if amount > limit:
|
||||
return False
|
||||
now = time.time()
|
||||
previous_key, current_key = self.sliding_window_keys(key, expiry, now)
|
||||
(
|
||||
previous_count,
|
||||
previous_ttl,
|
||||
current_count,
|
||||
_,
|
||||
) = self._get_sliding_window_info(previous_key, current_key, expiry, now)
|
||||
weighted_count = previous_count * previous_ttl / expiry + current_count
|
||||
if floor(weighted_count) + amount > limit:
|
||||
return False
|
||||
else:
|
||||
# Hit, increase the current counter.
|
||||
# If the counter doesn't exist yet, set twice the theorical expiry.
|
||||
current_count = self.incr(current_key, 2 * expiry, amount=amount)
|
||||
weighted_count = previous_count * previous_ttl / expiry + current_count
|
||||
if floor(weighted_count) > limit:
|
||||
# Another hit won the race condition: revert the incrementation and refuse this hit
|
||||
# Limitation: during high concurrency at the end of the window,
|
||||
# the counter is shifted and cannot be decremented, so less requests than expected are allowed.
|
||||
self.decr(current_key, amount)
|
||||
return False
|
||||
return True
|
||||
|
||||
def _get_sliding_window_info(
|
||||
self,
|
||||
previous_key: str,
|
||||
current_key: str,
|
||||
expiry: int,
|
||||
now: float,
|
||||
) -> tuple[int, float, int, float]:
|
||||
previous_count = self.get(previous_key)
|
||||
current_count = self.get(current_key)
|
||||
if previous_count == 0:
|
||||
previous_ttl = float(0)
|
||||
else:
|
||||
previous_ttl = (1 - (((now - expiry) / expiry) % 1)) * expiry
|
||||
current_ttl = (1 - ((now / expiry) % 1)) * expiry + expiry
|
||||
return previous_count, previous_ttl, current_count, current_ttl
|
||||
|
||||
def get_sliding_window(
|
||||
self, key: str, expiry: int
|
||||
) -> tuple[int, float, int, float]:
|
||||
now = time.time()
|
||||
previous_key, current_key = self.sliding_window_keys(key, expiry, now)
|
||||
return self._get_sliding_window_info(previous_key, current_key, expiry, now)
|
||||
|
||||
def check(self) -> bool:
|
||||
"""
|
||||
check if storage is healthy
|
||||
"""
|
||||
|
||||
return True
|
||||
|
||||
def reset(self) -> int | None:
|
||||
num_items = max(len(self.storage), len(self.events))
|
||||
self.storage.clear()
|
||||
self.expirations.clear()
|
||||
self.events.clear()
|
||||
self.locks.clear()
|
||||
return num_items
|
489
venv/lib/python3.11/site-packages/limits/storage/mongodb.py
Normal file
489
venv/lib/python3.11/site-packages/limits/storage/mongodb.py
Normal file
@ -0,0 +1,489 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import datetime
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from deprecated.sphinx import versionadded, versionchanged
|
||||
|
||||
from limits.typing import (
|
||||
MongoClient,
|
||||
MongoCollection,
|
||||
MongoDatabase,
|
||||
cast,
|
||||
)
|
||||
|
||||
from ..util import get_dependency
|
||||
from .base import MovingWindowSupport, SlidingWindowCounterSupport, Storage
|
||||
|
||||
|
||||
class MongoDBStorageBase(
|
||||
Storage, MovingWindowSupport, SlidingWindowCounterSupport, ABC
|
||||
):
|
||||
"""
|
||||
Rate limit storage with MongoDB as backend.
|
||||
|
||||
Depends on :pypi:`pymongo`.
|
||||
"""
|
||||
|
||||
DEPENDENCIES = ["pymongo"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
uri: str,
|
||||
database_name: str = "limits",
|
||||
counter_collection_name: str = "counters",
|
||||
window_collection_name: str = "windows",
|
||||
wrap_exceptions: bool = False,
|
||||
**options: int | str | bool,
|
||||
) -> None:
|
||||
"""
|
||||
:param uri: uri of the form ``mongodb://[user:password]@host:port?...``,
|
||||
This uri is passed directly to :class:`~pymongo.mongo_client.MongoClient`
|
||||
:param database_name: The database to use for storing the rate limit
|
||||
collections.
|
||||
:param counter_collection_name: The collection name to use for individual counters
|
||||
used in fixed window strategies
|
||||
:param window_collection_name: The collection name to use for sliding & moving window
|
||||
storage
|
||||
:param wrap_exceptions: Whether to wrap storage exceptions in
|
||||
:exc:`limits.errors.StorageError` before raising it.
|
||||
:param options: all remaining keyword arguments are passed to the
|
||||
constructor of :class:`~pymongo.mongo_client.MongoClient`
|
||||
:raise ConfigurationError: when the :pypi:`pymongo` library is not available
|
||||
"""
|
||||
|
||||
super().__init__(uri, wrap_exceptions=wrap_exceptions, **options)
|
||||
self._database_name = database_name
|
||||
self._collection_mapping = {
|
||||
"counters": counter_collection_name,
|
||||
"windows": window_collection_name,
|
||||
}
|
||||
self.lib = self.dependencies["pymongo"].module
|
||||
self.lib_errors, _ = get_dependency("pymongo.errors")
|
||||
self._storage_uri = uri
|
||||
self._storage_options = options
|
||||
self._storage: MongoClient | None = None
|
||||
|
||||
@property
|
||||
def storage(self) -> MongoClient:
|
||||
if self._storage is None:
|
||||
self._storage = self._init_mongo_client(
|
||||
self._storage_uri, **self._storage_options
|
||||
)
|
||||
self.__initialize_database()
|
||||
return self._storage
|
||||
|
||||
@property
|
||||
def _database(self) -> MongoDatabase:
|
||||
return self.storage[self._database_name]
|
||||
|
||||
@property
|
||||
def counters(self) -> MongoCollection:
|
||||
return self._database[self._collection_mapping["counters"]]
|
||||
|
||||
@property
|
||||
def windows(self) -> MongoCollection:
|
||||
return self._database[self._collection_mapping["windows"]]
|
||||
|
||||
@abstractmethod
|
||||
def _init_mongo_client(
|
||||
self, uri: str | None, **options: int | str | bool
|
||||
) -> MongoClient:
|
||||
raise NotImplementedError()
|
||||
|
||||
@property
|
||||
def base_exceptions(
|
||||
self,
|
||||
) -> type[Exception] | tuple[type[Exception], ...]: # pragma: no cover
|
||||
return self.lib_errors.PyMongoError # type: ignore
|
||||
|
||||
def __initialize_database(self) -> None:
|
||||
self.counters.create_index("expireAt", expireAfterSeconds=0)
|
||||
self.windows.create_index("expireAt", expireAfterSeconds=0)
|
||||
|
||||
def reset(self) -> int | None:
|
||||
"""
|
||||
Delete all rate limit keys in the rate limit collections (counters, windows)
|
||||
"""
|
||||
num_keys = self.counters.count_documents({}) + self.windows.count_documents({})
|
||||
self.counters.drop()
|
||||
self.windows.drop()
|
||||
|
||||
return int(num_keys)
|
||||
|
||||
def clear(self, key: str) -> None:
|
||||
"""
|
||||
:param key: the key to clear rate limits for
|
||||
"""
|
||||
self.counters.find_one_and_delete({"_id": key})
|
||||
self.windows.find_one_and_delete({"_id": key})
|
||||
|
||||
def get_expiry(self, key: str) -> float:
|
||||
"""
|
||||
:param key: the key to get the expiry for
|
||||
"""
|
||||
counter = self.counters.find_one({"_id": key})
|
||||
return (
|
||||
(counter["expireAt"] if counter else datetime.datetime.now())
|
||||
.replace(tzinfo=datetime.timezone.utc)
|
||||
.timestamp()
|
||||
)
|
||||
|
||||
def get(self, key: str) -> int:
|
||||
"""
|
||||
:param key: the key to get the counter value for
|
||||
"""
|
||||
counter = self.counters.find_one(
|
||||
{
|
||||
"_id": key,
|
||||
"expireAt": {"$gte": datetime.datetime.now(datetime.timezone.utc)},
|
||||
},
|
||||
projection=["count"],
|
||||
)
|
||||
|
||||
return counter and counter["count"] or 0
|
||||
|
||||
def incr(self, key: str, expiry: int, amount: int = 1) -> int:
|
||||
"""
|
||||
increments the counter for a given rate limit key
|
||||
|
||||
:param key: the key to increment
|
||||
:param expiry: amount in seconds for the key to expire in
|
||||
:param amount: the number to increment by
|
||||
"""
|
||||
expiration = datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(
|
||||
seconds=expiry
|
||||
)
|
||||
|
||||
return int(
|
||||
self.counters.find_one_and_update(
|
||||
{"_id": key},
|
||||
[
|
||||
{
|
||||
"$set": {
|
||||
"count": {
|
||||
"$cond": {
|
||||
"if": {"$lt": ["$expireAt", "$$NOW"]},
|
||||
"then": amount,
|
||||
"else": {"$add": ["$count", amount]},
|
||||
}
|
||||
},
|
||||
"expireAt": {
|
||||
"$cond": {
|
||||
"if": {"$lt": ["$expireAt", "$$NOW"]},
|
||||
"then": expiration,
|
||||
"else": "$expireAt",
|
||||
}
|
||||
},
|
||||
}
|
||||
},
|
||||
],
|
||||
upsert=True,
|
||||
projection=["count"],
|
||||
return_document=self.lib.ReturnDocument.AFTER,
|
||||
)["count"]
|
||||
)
|
||||
|
||||
def check(self) -> bool:
|
||||
"""
|
||||
Check if storage is healthy by calling :meth:`pymongo.mongo_client.MongoClient.server_info`
|
||||
"""
|
||||
try:
|
||||
self.storage.server_info()
|
||||
|
||||
return True
|
||||
except: # noqa: E722
|
||||
return False
|
||||
|
||||
def get_moving_window(self, key: str, limit: int, expiry: int) -> tuple[float, int]:
|
||||
"""
|
||||
returns the starting point and the number of entries in the moving
|
||||
window
|
||||
|
||||
:param key: rate limit key
|
||||
:param expiry: expiry of entry
|
||||
:return: (start of window, number of acquired entries)
|
||||
"""
|
||||
timestamp = time.time()
|
||||
if result := list(
|
||||
self.windows.aggregate(
|
||||
[
|
||||
{"$match": {"_id": key}},
|
||||
{
|
||||
"$project": {
|
||||
"filteredEntries": {
|
||||
"$filter": {
|
||||
"input": "$entries",
|
||||
"as": "entry",
|
||||
"cond": {"$gte": ["$$entry", timestamp - expiry]},
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"$project": {
|
||||
"min": {"$min": "$filteredEntries"},
|
||||
"count": {"$size": "$filteredEntries"},
|
||||
}
|
||||
},
|
||||
]
|
||||
)
|
||||
):
|
||||
return result[0]["min"], result[0]["count"]
|
||||
return timestamp, 0
|
||||
|
||||
def acquire_entry(self, key: str, limit: int, expiry: int, amount: int = 1) -> bool:
|
||||
"""
|
||||
:param key: rate limit key to acquire an entry in
|
||||
:param limit: amount of entries allowed
|
||||
:param expiry: expiry of the entry
|
||||
:param amount: the number of entries to acquire
|
||||
"""
|
||||
if amount > limit:
|
||||
return False
|
||||
|
||||
timestamp = time.time()
|
||||
try:
|
||||
updates: dict[
|
||||
str,
|
||||
dict[str, datetime.datetime | dict[str, list[float] | int]],
|
||||
] = {
|
||||
"$push": {
|
||||
"entries": {
|
||||
"$each": [timestamp] * amount,
|
||||
"$position": 0,
|
||||
"$slice": limit,
|
||||
}
|
||||
},
|
||||
"$set": {
|
||||
"expireAt": (
|
||||
datetime.datetime.now(datetime.timezone.utc)
|
||||
+ datetime.timedelta(seconds=expiry)
|
||||
)
|
||||
},
|
||||
}
|
||||
|
||||
self.windows.update_one(
|
||||
{
|
||||
"_id": key,
|
||||
f"entries.{limit - amount}": {"$not": {"$gte": timestamp - expiry}},
|
||||
},
|
||||
updates,
|
||||
upsert=True,
|
||||
)
|
||||
|
||||
return True
|
||||
except self.lib.errors.DuplicateKeyError:
|
||||
return False
|
||||
|
||||
def get_sliding_window(
|
||||
self, key: str, expiry: int
|
||||
) -> tuple[int, float, int, float]:
|
||||
expiry_ms = expiry * 1000
|
||||
if result := self.windows.find_one_and_update(
|
||||
{"_id": key},
|
||||
[
|
||||
{
|
||||
"$set": {
|
||||
"previousCount": {
|
||||
"$cond": {
|
||||
"if": {
|
||||
"$lte": [
|
||||
{"$subtract": ["$expireAt", "$$NOW"]},
|
||||
expiry_ms,
|
||||
]
|
||||
},
|
||||
"then": {"$ifNull": ["$currentCount", 0]},
|
||||
"else": {"$ifNull": ["$previousCount", 0]},
|
||||
}
|
||||
},
|
||||
"currentCount": {
|
||||
"$cond": {
|
||||
"if": {
|
||||
"$lte": [
|
||||
{"$subtract": ["$expireAt", "$$NOW"]},
|
||||
expiry_ms,
|
||||
]
|
||||
},
|
||||
"then": 0,
|
||||
"else": {"$ifNull": ["$currentCount", 0]},
|
||||
}
|
||||
},
|
||||
"expireAt": {
|
||||
"$cond": {
|
||||
"if": {
|
||||
"$lte": [
|
||||
{"$subtract": ["$expireAt", "$$NOW"]},
|
||||
expiry_ms,
|
||||
]
|
||||
},
|
||||
"then": {
|
||||
"$add": ["$expireAt", expiry_ms],
|
||||
},
|
||||
"else": "$expireAt",
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
],
|
||||
return_document=self.lib.ReturnDocument.AFTER,
|
||||
projection=["currentCount", "previousCount", "expireAt"],
|
||||
):
|
||||
expires_at = (
|
||||
(result["expireAt"].replace(tzinfo=datetime.timezone.utc).timestamp())
|
||||
if result.get("expireAt")
|
||||
else time.time()
|
||||
)
|
||||
current_ttl = max(0, expires_at - time.time())
|
||||
prev_ttl = max(0, current_ttl - expiry if result["previousCount"] else 0)
|
||||
|
||||
return (
|
||||
result["previousCount"],
|
||||
prev_ttl,
|
||||
result["currentCount"],
|
||||
current_ttl,
|
||||
)
|
||||
return 0, 0.0, 0, 0.0
|
||||
|
||||
def acquire_sliding_window_entry(
|
||||
self, key: str, limit: int, expiry: int, amount: int = 1
|
||||
) -> bool:
|
||||
expiry_ms = expiry * 1000
|
||||
result = self.windows.find_one_and_update(
|
||||
{"_id": key},
|
||||
[
|
||||
{
|
||||
"$set": {
|
||||
"previousCount": {
|
||||
"$cond": {
|
||||
"if": {
|
||||
"$lte": [
|
||||
{"$subtract": ["$expireAt", "$$NOW"]},
|
||||
expiry_ms,
|
||||
]
|
||||
},
|
||||
"then": {"$ifNull": ["$currentCount", 0]},
|
||||
"else": {"$ifNull": ["$previousCount", 0]},
|
||||
}
|
||||
},
|
||||
}
|
||||
},
|
||||
{
|
||||
"$set": {
|
||||
"currentCount": {
|
||||
"$cond": {
|
||||
"if": {
|
||||
"$lte": [
|
||||
{"$subtract": ["$expireAt", "$$NOW"]},
|
||||
expiry_ms,
|
||||
]
|
||||
},
|
||||
"then": 0,
|
||||
"else": {"$ifNull": ["$currentCount", 0]},
|
||||
}
|
||||
},
|
||||
"expireAt": {
|
||||
"$cond": {
|
||||
"if": {
|
||||
"$lte": [
|
||||
{"$subtract": ["$expireAt", "$$NOW"]},
|
||||
expiry_ms,
|
||||
]
|
||||
},
|
||||
"then": {
|
||||
"$cond": {
|
||||
"if": {"$gt": ["$expireAt", 0]},
|
||||
"then": {"$add": ["$expireAt", expiry_ms]},
|
||||
"else": {"$add": ["$$NOW", 2 * expiry_ms]},
|
||||
}
|
||||
},
|
||||
"else": "$expireAt",
|
||||
}
|
||||
},
|
||||
}
|
||||
},
|
||||
{
|
||||
"$set": {
|
||||
"curWeightedCount": {
|
||||
"$floor": {
|
||||
"$add": [
|
||||
{
|
||||
"$multiply": [
|
||||
"$previousCount",
|
||||
{
|
||||
"$divide": [
|
||||
{
|
||||
"$max": [
|
||||
0,
|
||||
{
|
||||
"$subtract": [
|
||||
"$expireAt",
|
||||
{
|
||||
"$add": [
|
||||
"$$NOW",
|
||||
expiry_ms,
|
||||
]
|
||||
},
|
||||
]
|
||||
},
|
||||
]
|
||||
},
|
||||
expiry_ms,
|
||||
]
|
||||
},
|
||||
]
|
||||
},
|
||||
"$currentCount",
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"$set": {
|
||||
"currentCount": {
|
||||
"$cond": {
|
||||
"if": {
|
||||
"$lte": [
|
||||
{"$add": ["$curWeightedCount", amount]},
|
||||
limit,
|
||||
]
|
||||
},
|
||||
"then": {"$add": ["$currentCount", amount]},
|
||||
"else": "$currentCount",
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"$set": {
|
||||
"_acquired": {
|
||||
"$lte": [{"$add": ["$curWeightedCount", amount]}, limit]
|
||||
}
|
||||
}
|
||||
},
|
||||
{"$unset": ["curWeightedCount"]},
|
||||
],
|
||||
return_document=self.lib.ReturnDocument.AFTER,
|
||||
upsert=True,
|
||||
)
|
||||
return cast(bool, result["_acquired"])
|
||||
|
||||
def __del__(self) -> None:
|
||||
if self.storage:
|
||||
self.storage.close()
|
||||
|
||||
|
||||
@versionadded(version="2.1")
|
||||
@versionchanged(
|
||||
version="3.14.0",
|
||||
reason="Added option to select custom collection names for windows & counters",
|
||||
)
|
||||
class MongoDBStorage(MongoDBStorageBase):
|
||||
STORAGE_SCHEME = ["mongodb", "mongodb+srv"]
|
||||
|
||||
def _init_mongo_client(
|
||||
self, uri: str | None, **options: int | str | bool
|
||||
) -> MongoClient:
|
||||
return cast(MongoClient, self.lib.MongoClient(uri, **options))
|
308
venv/lib/python3.11/site-packages/limits/storage/redis.py
Normal file
308
venv/lib/python3.11/site-packages/limits/storage/redis.py
Normal file
@ -0,0 +1,308 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from typing import TYPE_CHECKING, cast
|
||||
|
||||
from deprecated.sphinx import versionchanged
|
||||
from packaging.version import Version
|
||||
|
||||
from limits.typing import Literal, RedisClient
|
||||
|
||||
from ..util import get_package_data
|
||||
from .base import MovingWindowSupport, SlidingWindowCounterSupport, Storage
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import redis
|
||||
|
||||
|
||||
@versionchanged(
|
||||
version="4.3",
|
||||
reason=(
|
||||
"Added support for using the redis client from :pypi:`valkey`"
|
||||
" if :paramref:`uri` has the ``valkey://`` schema"
|
||||
),
|
||||
)
|
||||
class RedisStorage(Storage, MovingWindowSupport, SlidingWindowCounterSupport):
|
||||
"""
|
||||
Rate limit storage with redis as backend.
|
||||
|
||||
Depends on :pypi:`redis` (or :pypi:`valkey` if :paramref:`uri` starts with
|
||||
``valkey://``)
|
||||
"""
|
||||
|
||||
STORAGE_SCHEME = [
|
||||
"redis",
|
||||
"rediss",
|
||||
"redis+unix",
|
||||
"valkey",
|
||||
"valkeys",
|
||||
"valkey+unix",
|
||||
]
|
||||
"""The storage scheme for redis"""
|
||||
|
||||
DEPENDENCIES = {"redis": Version("3.0"), "valkey": Version("6.0")}
|
||||
|
||||
RES_DIR = "resources/redis/lua_scripts"
|
||||
|
||||
SCRIPT_MOVING_WINDOW = get_package_data(f"{RES_DIR}/moving_window.lua")
|
||||
SCRIPT_ACQUIRE_MOVING_WINDOW = get_package_data(
|
||||
f"{RES_DIR}/acquire_moving_window.lua"
|
||||
)
|
||||
SCRIPT_CLEAR_KEYS = get_package_data(f"{RES_DIR}/clear_keys.lua")
|
||||
SCRIPT_INCR_EXPIRE = get_package_data(f"{RES_DIR}/incr_expire.lua")
|
||||
|
||||
SCRIPT_SLIDING_WINDOW = get_package_data(f"{RES_DIR}/sliding_window.lua")
|
||||
SCRIPT_ACQUIRE_SLIDING_WINDOW = get_package_data(
|
||||
f"{RES_DIR}/acquire_sliding_window.lua"
|
||||
)
|
||||
|
||||
lua_moving_window: redis.commands.core.Script
|
||||
lua_acquire_moving_window: redis.commands.core.Script
|
||||
lua_sliding_window: redis.commands.core.Script
|
||||
lua_acquire_sliding_window: redis.commands.core.Script
|
||||
|
||||
PREFIX = "LIMITS"
|
||||
target_server: Literal["redis", "valkey"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
uri: str,
|
||||
connection_pool: redis.connection.ConnectionPool | None = None,
|
||||
wrap_exceptions: bool = False,
|
||||
**options: float | str | bool,
|
||||
) -> None:
|
||||
"""
|
||||
:param uri: uri of the form ``redis://[:password]@host:port``,
|
||||
``redis://[:password]@host:port/db``,
|
||||
``rediss://[:password]@host:port``, ``redis+unix:///path/to/sock`` etc.
|
||||
This uri is passed directly to :func:`redis.from_url` except for the
|
||||
case of ``redis+unix://`` where it is replaced with ``unix://``.
|
||||
|
||||
If the uri scheme is ``valkey`` the implementation used will be from
|
||||
:pypi:`valkey`.
|
||||
:param connection_pool: if provided, the redis client is initialized with
|
||||
the connection pool and any other params passed as :paramref:`options`
|
||||
:param wrap_exceptions: Whether to wrap storage exceptions in
|
||||
:exc:`limits.errors.StorageError` before raising it.
|
||||
:param options: all remaining keyword arguments are passed
|
||||
directly to the constructor of :class:`redis.Redis`
|
||||
:raise ConfigurationError: when the :pypi:`redis` library is not available
|
||||
"""
|
||||
super().__init__(uri, wrap_exceptions=wrap_exceptions, **options)
|
||||
self.target_server = "valkey" if uri.startswith("valkey") else "redis"
|
||||
self.dependency = self.dependencies[self.target_server].module
|
||||
|
||||
uri = uri.replace(f"{self.target_server}+unix", "unix")
|
||||
|
||||
if not connection_pool:
|
||||
self.storage = self.dependency.from_url(uri, **options)
|
||||
else:
|
||||
if self.target_server == "redis":
|
||||
self.storage = self.dependency.Redis(
|
||||
connection_pool=connection_pool, **options
|
||||
)
|
||||
else:
|
||||
self.storage = self.dependency.Valkey(
|
||||
connection_pool=connection_pool, **options
|
||||
)
|
||||
self.initialize_storage(uri)
|
||||
|
||||
@property
|
||||
def base_exceptions(
|
||||
self,
|
||||
) -> type[Exception] | tuple[type[Exception], ...]: # pragma: no cover
|
||||
return ( # type: ignore[no-any-return]
|
||||
self.dependency.RedisError
|
||||
if self.target_server == "redis"
|
||||
else self.dependency.ValkeyError
|
||||
)
|
||||
|
||||
def initialize_storage(self, _uri: str) -> None:
|
||||
self.lua_moving_window = self.get_connection().register_script(
|
||||
self.SCRIPT_MOVING_WINDOW
|
||||
)
|
||||
self.lua_acquire_moving_window = self.get_connection().register_script(
|
||||
self.SCRIPT_ACQUIRE_MOVING_WINDOW
|
||||
)
|
||||
self.lua_clear_keys = self.get_connection().register_script(
|
||||
self.SCRIPT_CLEAR_KEYS
|
||||
)
|
||||
self.lua_incr_expire = self.get_connection().register_script(
|
||||
self.SCRIPT_INCR_EXPIRE
|
||||
)
|
||||
self.lua_sliding_window = self.get_connection().register_script(
|
||||
self.SCRIPT_SLIDING_WINDOW
|
||||
)
|
||||
self.lua_acquire_sliding_window = self.get_connection().register_script(
|
||||
self.SCRIPT_ACQUIRE_SLIDING_WINDOW
|
||||
)
|
||||
|
||||
def get_connection(self, readonly: bool = False) -> RedisClient:
|
||||
return cast(RedisClient, self.storage)
|
||||
|
||||
def _current_window_key(self, key: str) -> str:
|
||||
"""
|
||||
Return the current window's storage key (Sliding window strategy)
|
||||
|
||||
Contrary to other strategies that have one key per rate limit item,
|
||||
this strategy has two keys per rate limit item than must be on the same machine.
|
||||
To keep the current key and the previous key on the same Redis cluster node,
|
||||
curly braces are added.
|
||||
|
||||
Eg: "{constructed_key}"
|
||||
"""
|
||||
return f"{{{key}}}"
|
||||
|
||||
def _previous_window_key(self, key: str) -> str:
|
||||
"""
|
||||
Return the previous window's storage key (Sliding window strategy).
|
||||
|
||||
Curvy braces are added on the common pattern with the current window's key,
|
||||
so the current and the previous key are stored on the same Redis cluster node.
|
||||
|
||||
Eg: "{constructed_key}/-1"
|
||||
"""
|
||||
return f"{self._current_window_key(key)}/-1"
|
||||
|
||||
def prefixed_key(self, key: str) -> str:
|
||||
return f"{self.PREFIX}:{key}"
|
||||
|
||||
def get_moving_window(self, key: str, limit: int, expiry: int) -> tuple[float, int]:
|
||||
"""
|
||||
returns the starting point and the number of entries in the moving
|
||||
window
|
||||
|
||||
:param key: rate limit key
|
||||
:param expiry: expiry of entry
|
||||
:return: (start of window, number of acquired entries)
|
||||
"""
|
||||
key = self.prefixed_key(key)
|
||||
timestamp = time.time()
|
||||
if window := self.lua_moving_window([key], [timestamp - expiry, limit]):
|
||||
return float(window[0]), window[1]
|
||||
|
||||
return timestamp, 0
|
||||
|
||||
def get_sliding_window(
|
||||
self, key: str, expiry: int
|
||||
) -> tuple[int, float, int, float]:
|
||||
previous_key = self.prefixed_key(self._previous_window_key(key))
|
||||
current_key = self.prefixed_key(self._current_window_key(key))
|
||||
if window := self.lua_sliding_window([previous_key, current_key], [expiry]):
|
||||
return (
|
||||
int(window[0] or 0),
|
||||
max(0, float(window[1] or 0)) / 1000,
|
||||
int(window[2] or 0),
|
||||
max(0, float(window[3] or 0)) / 1000,
|
||||
)
|
||||
return 0, 0.0, 0, 0.0
|
||||
|
||||
def incr(
|
||||
self,
|
||||
key: str,
|
||||
expiry: int,
|
||||
amount: int = 1,
|
||||
) -> int:
|
||||
"""
|
||||
increments the counter for a given rate limit key
|
||||
|
||||
|
||||
:param key: the key to increment
|
||||
:param expiry: amount in seconds for the key to expire in
|
||||
:param amount: the number to increment by
|
||||
"""
|
||||
key = self.prefixed_key(key)
|
||||
return int(self.lua_incr_expire([key], [expiry, amount]))
|
||||
|
||||
def get(self, key: str) -> int:
|
||||
"""
|
||||
|
||||
:param key: the key to get the counter value for
|
||||
"""
|
||||
|
||||
key = self.prefixed_key(key)
|
||||
return int(self.get_connection(True).get(key) or 0)
|
||||
|
||||
def clear(self, key: str) -> None:
|
||||
"""
|
||||
:param key: the key to clear rate limits for
|
||||
"""
|
||||
key = self.prefixed_key(key)
|
||||
self.get_connection().delete(key)
|
||||
|
||||
def acquire_entry(
|
||||
self,
|
||||
key: str,
|
||||
limit: int,
|
||||
expiry: int,
|
||||
amount: int = 1,
|
||||
) -> bool:
|
||||
"""
|
||||
:param key: rate limit key to acquire an entry in
|
||||
:param limit: amount of entries allowed
|
||||
:param expiry: expiry of the entry
|
||||
|
||||
:param amount: the number of entries to acquire
|
||||
"""
|
||||
key = self.prefixed_key(key)
|
||||
timestamp = time.time()
|
||||
acquired = self.lua_acquire_moving_window(
|
||||
[key], [timestamp, limit, expiry, amount]
|
||||
)
|
||||
|
||||
return bool(acquired)
|
||||
|
||||
def acquire_sliding_window_entry(
|
||||
self,
|
||||
key: str,
|
||||
limit: int,
|
||||
expiry: int,
|
||||
amount: int = 1,
|
||||
) -> bool:
|
||||
"""
|
||||
Acquire an entry. Shift the current window to the previous window if it expired.
|
||||
|
||||
:param key: rate limit key to acquire an entry in
|
||||
:param limit: amount of entries allowed
|
||||
:param expiry: expiry of the entry
|
||||
:param amount: the number of entries to acquire
|
||||
"""
|
||||
previous_key = self.prefixed_key(self._previous_window_key(key))
|
||||
current_key = self.prefixed_key(self._current_window_key(key))
|
||||
acquired = self.lua_acquire_sliding_window(
|
||||
[previous_key, current_key], [limit, expiry, amount]
|
||||
)
|
||||
return bool(acquired)
|
||||
|
||||
def get_expiry(self, key: str) -> float:
|
||||
"""
|
||||
:param key: the key to get the expiry for
|
||||
|
||||
"""
|
||||
|
||||
key = self.prefixed_key(key)
|
||||
return max(self.get_connection(True).ttl(key), 0) + time.time()
|
||||
|
||||
def check(self) -> bool:
|
||||
"""
|
||||
check if storage is healthy
|
||||
"""
|
||||
try:
|
||||
return self.get_connection().ping()
|
||||
except: # noqa
|
||||
return False
|
||||
|
||||
def reset(self) -> int | None:
|
||||
"""
|
||||
This function calls a Lua Script to delete keys prefixed with
|
||||
``self.PREFIX`` in blocks of 5000.
|
||||
|
||||
.. warning::
|
||||
This operation was designed to be fast, but was not tested
|
||||
on a large production based system. Be careful with its usage as it
|
||||
could be slow on very large data sets.
|
||||
|
||||
"""
|
||||
|
||||
prefix = self.prefixed_key("*")
|
||||
return int(self.lua_clear_keys([prefix]))
|
@ -0,0 +1,125 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import urllib
|
||||
|
||||
from deprecated.sphinx import versionchanged
|
||||
from packaging.version import Version
|
||||
|
||||
from limits.storage.redis import RedisStorage
|
||||
|
||||
|
||||
@versionchanged(
|
||||
version="3.14.0",
|
||||
reason="""
|
||||
Dropped support for the :pypi:`redis-py-cluster` library
|
||||
which has been abandoned/deprecated.
|
||||
""",
|
||||
)
|
||||
@versionchanged(
|
||||
version="2.5.0",
|
||||
reason="""
|
||||
Cluster support was provided by the :pypi:`redis-py-cluster` library
|
||||
which has been absorbed into the official :pypi:`redis` client. By
|
||||
default the :class:`redis.cluster.RedisCluster` client will be used
|
||||
however if the version of the package is lower than ``4.2.0`` the implementation
|
||||
will fallback to trying to use :class:`rediscluster.RedisCluster`.
|
||||
""",
|
||||
)
|
||||
@versionchanged(
|
||||
version="4.3",
|
||||
reason=(
|
||||
"Added support for using the redis client from :pypi:`valkey`"
|
||||
" if :paramref:`uri` has the ``valkey+cluster://`` schema"
|
||||
),
|
||||
)
|
||||
class RedisClusterStorage(RedisStorage):
|
||||
"""
|
||||
Rate limit storage with redis cluster as backend
|
||||
|
||||
Depends on :pypi:`redis` (or :pypi:`valkey` if :paramref:`uri`
|
||||
starts with ``valkey+cluster://``).
|
||||
"""
|
||||
|
||||
STORAGE_SCHEME = ["redis+cluster", "valkey+cluster"]
|
||||
"""The storage scheme for redis cluster"""
|
||||
|
||||
DEFAULT_OPTIONS: dict[str, float | str | bool] = {
|
||||
"max_connections": 1000,
|
||||
}
|
||||
"Default options passed to the :class:`~redis.cluster.RedisCluster`"
|
||||
|
||||
DEPENDENCIES = {
|
||||
"redis": Version("4.2.0"),
|
||||
"valkey": Version("6.0"),
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
uri: str,
|
||||
wrap_exceptions: bool = False,
|
||||
**options: float | str | bool,
|
||||
) -> None:
|
||||
"""
|
||||
:param uri: url of the form
|
||||
``redis+cluster://[:password]@host:port,host:port``
|
||||
|
||||
If the uri scheme is ``valkey+cluster`` the implementation used will be from
|
||||
:pypi:`valkey`.
|
||||
:param wrap_exceptions: Whether to wrap storage exceptions in
|
||||
:exc:`limits.errors.StorageError` before raising it.
|
||||
:param options: all remaining keyword arguments are passed
|
||||
directly to the constructor of :class:`redis.cluster.RedisCluster`
|
||||
:raise ConfigurationError: when the :pypi:`redis` library is not
|
||||
available or if the redis cluster cannot be reached.
|
||||
"""
|
||||
parsed = urllib.parse.urlparse(uri)
|
||||
parsed_auth: dict[str, float | str | bool] = {}
|
||||
|
||||
if parsed.username:
|
||||
parsed_auth["username"] = parsed.username
|
||||
if parsed.password:
|
||||
parsed_auth["password"] = parsed.password
|
||||
|
||||
sep = parsed.netloc.find("@") + 1
|
||||
cluster_hosts = []
|
||||
for loc in parsed.netloc[sep:].split(","):
|
||||
host, port = loc.split(":")
|
||||
cluster_hosts.append((host, int(port)))
|
||||
|
||||
self.storage = None
|
||||
self.target_server = "valkey" if uri.startswith("valkey") else "redis"
|
||||
merged_options = {**self.DEFAULT_OPTIONS, **parsed_auth, **options}
|
||||
self.dependency = self.dependencies[self.target_server].module
|
||||
startup_nodes = [self.dependency.cluster.ClusterNode(*c) for c in cluster_hosts]
|
||||
if self.target_server == "redis":
|
||||
self.storage = self.dependency.cluster.RedisCluster(
|
||||
startup_nodes=startup_nodes, **merged_options
|
||||
)
|
||||
else:
|
||||
self.storage = self.dependency.cluster.ValkeyCluster(
|
||||
startup_nodes=startup_nodes, **merged_options
|
||||
)
|
||||
|
||||
assert self.storage
|
||||
self.initialize_storage(uri)
|
||||
super(RedisStorage, self).__init__(uri, wrap_exceptions, **options)
|
||||
|
||||
def reset(self) -> int | None:
|
||||
"""
|
||||
Redis Clusters are sharded and deleting across shards
|
||||
can't be done atomically. Because of this, this reset loops over all
|
||||
keys that are prefixed with ``self.PREFIX`` and calls delete on them,
|
||||
one at a time.
|
||||
|
||||
.. warning::
|
||||
This operation was not tested with extremely large data sets.
|
||||
On a large production based system, care should be taken with its
|
||||
usage as it could be slow on very large data sets"""
|
||||
|
||||
prefix = self.prefixed_key("*")
|
||||
count = 0
|
||||
for primary in self.storage.get_primaries():
|
||||
node = self.storage.get_redis_connection(primary)
|
||||
keys = node.keys(prefix)
|
||||
count += sum([node.delete(k.decode("utf-8")) for k in keys])
|
||||
return count
|
@ -0,0 +1,120 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import urllib.parse
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from deprecated.sphinx import versionchanged
|
||||
from packaging.version import Version
|
||||
|
||||
from limits.errors import ConfigurationError
|
||||
from limits.storage.redis import RedisStorage
|
||||
from limits.typing import RedisClient
|
||||
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
|
||||
|
||||
@versionchanged(
|
||||
version="4.3",
|
||||
reason=(
|
||||
"Added support for using the redis client from :pypi:`valkey`"
|
||||
" if :paramref:`uri` has the ``valkey+sentinel://`` schema"
|
||||
),
|
||||
)
|
||||
class RedisSentinelStorage(RedisStorage):
|
||||
"""
|
||||
Rate limit storage with redis sentinel as backend
|
||||
|
||||
Depends on :pypi:`redis` package (or :pypi:`valkey` if :paramref:`uri` starts with
|
||||
``valkey+sentinel://``)
|
||||
"""
|
||||
|
||||
STORAGE_SCHEME = ["redis+sentinel", "valkey+sentinel"]
|
||||
"""The storage scheme for redis accessed via a redis sentinel installation"""
|
||||
|
||||
DEPENDENCIES = {
|
||||
"redis": Version("3.0"),
|
||||
"redis.sentinel": Version("3.0"),
|
||||
"valkey": Version("6.0"),
|
||||
"valkey.sentinel": Version("6.0"),
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
uri: str,
|
||||
service_name: str | None = None,
|
||||
use_replicas: bool = True,
|
||||
sentinel_kwargs: dict[str, float | str | bool] | None = None,
|
||||
wrap_exceptions: bool = False,
|
||||
**options: float | str | bool,
|
||||
) -> None:
|
||||
"""
|
||||
:param uri: url of the form
|
||||
``redis+sentinel://host:port,host:port/service_name``
|
||||
|
||||
If the uri scheme is ``valkey+sentinel`` the implementation used will be from
|
||||
:pypi:`valkey`.
|
||||
:param service_name: sentinel service name
|
||||
(if not provided in :attr:`uri`)
|
||||
:param use_replicas: Whether to use replicas for read only operations
|
||||
:param sentinel_kwargs: kwargs to pass as
|
||||
:attr:`sentinel_kwargs` to :class:`redis.sentinel.Sentinel`
|
||||
:param wrap_exceptions: Whether to wrap storage exceptions in
|
||||
:exc:`limits.errors.StorageError` before raising it.
|
||||
:param options: all remaining keyword arguments are passed
|
||||
directly to the constructor of :class:`redis.sentinel.Sentinel`
|
||||
:raise ConfigurationError: when the redis library is not available
|
||||
or if the redis master host cannot be pinged.
|
||||
"""
|
||||
|
||||
super(RedisStorage, self).__init__(
|
||||
uri, wrap_exceptions=wrap_exceptions, **options
|
||||
)
|
||||
|
||||
parsed = urllib.parse.urlparse(uri)
|
||||
sentinel_configuration = []
|
||||
sentinel_options = sentinel_kwargs.copy() if sentinel_kwargs else {}
|
||||
|
||||
parsed_auth: dict[str, float | str | bool] = {}
|
||||
|
||||
if parsed.username:
|
||||
parsed_auth["username"] = parsed.username
|
||||
if parsed.password:
|
||||
parsed_auth["password"] = parsed.password
|
||||
|
||||
sep = parsed.netloc.find("@") + 1
|
||||
|
||||
for loc in parsed.netloc[sep:].split(","):
|
||||
host, port = loc.split(":")
|
||||
sentinel_configuration.append((host, int(port)))
|
||||
self.service_name = (
|
||||
parsed.path.replace("/", "") if parsed.path else service_name
|
||||
)
|
||||
|
||||
if self.service_name is None:
|
||||
raise ConfigurationError("'service_name' not provided")
|
||||
|
||||
self.target_server = "valkey" if uri.startswith("valkey") else "redis"
|
||||
sentinel_dep = self.dependencies[f"{self.target_server}.sentinel"].module
|
||||
self.sentinel = sentinel_dep.Sentinel(
|
||||
sentinel_configuration,
|
||||
sentinel_kwargs={**parsed_auth, **sentinel_options},
|
||||
**{**parsed_auth, **options},
|
||||
)
|
||||
self.storage: RedisClient = self.sentinel.master_for(self.service_name)
|
||||
self.storage_slave: RedisClient = self.sentinel.slave_for(self.service_name)
|
||||
self.use_replicas = use_replicas
|
||||
self.initialize_storage(uri)
|
||||
|
||||
@property
|
||||
def base_exceptions(
|
||||
self,
|
||||
) -> type[Exception] | tuple[type[Exception], ...]: # pragma: no cover
|
||||
return ( # type: ignore[no-any-return]
|
||||
self.dependencies["redis"].module.RedisError
|
||||
if self.target_server == "redis"
|
||||
else self.dependencies["valkey"].module.ValkeyError
|
||||
)
|
||||
|
||||
def get_connection(self, readonly: bool = False) -> RedisClient:
|
||||
return self.storage_slave if (readonly and self.use_replicas) else self.storage
|
24
venv/lib/python3.11/site-packages/limits/storage/registry.py
Normal file
24
venv/lib/python3.11/site-packages/limits/storage/registry.py
Normal file
@ -0,0 +1,24 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABCMeta
|
||||
|
||||
SCHEMES: dict[str, StorageRegistry] = {}
|
||||
|
||||
|
||||
class StorageRegistry(ABCMeta):
|
||||
def __new__(
|
||||
mcs, name: str, bases: tuple[type, ...], dct: dict[str, str | list[str]]
|
||||
) -> StorageRegistry:
|
||||
storage_scheme = dct.get("STORAGE_SCHEME", None)
|
||||
cls = super().__new__(mcs, name, bases, dct)
|
||||
|
||||
if storage_scheme:
|
||||
if isinstance(storage_scheme, str): # noqa
|
||||
schemes = [storage_scheme]
|
||||
else:
|
||||
schemes = storage_scheme
|
||||
|
||||
for scheme in schemes:
|
||||
SCHEMES[scheme] = cls
|
||||
|
||||
return cls
|
Reference in New Issue
Block a user