Update 2025-04-24_11:44:19
@@ -0,0 +1,400 @@
from __future__ import annotations

from deprecated.sphinx import versionadded, versionchanged
from packaging.version import Version

from limits.aio.storage import MovingWindowSupport, SlidingWindowCounterSupport, Storage
from limits.aio.storage.redis.bridge import RedisBridge
from limits.aio.storage.redis.coredis import CoredisBridge
from limits.aio.storage.redis.redispy import RedispyBridge
from limits.aio.storage.redis.valkey import ValkeyBridge
from limits.typing import Literal


@versionadded(version="2.1")
@versionchanged(
    version="4.2",
    reason=(
        "Added support for using the asyncio redis client from :pypi:`redis`"
        " through :paramref:`implementation`"
    ),
)
@versionchanged(
    version="4.3",
    reason=(
        "Added support for using the asyncio redis client from :pypi:`valkey`"
        " through :paramref:`implementation` or if :paramref:`uri` has the"
        " ``async+valkey`` schema"
    ),
)
class RedisStorage(Storage, MovingWindowSupport, SlidingWindowCounterSupport):
    """
    Rate limit storage with redis as backend.

    Depends on :pypi:`coredis` or :pypi:`redis`
    """

    STORAGE_SCHEME = [
        "async+redis",
        "async+rediss",
        "async+redis+unix",
        "async+valkey",
        "async+valkeys",
        "async+valkey+unix",
    ]
    """
    The storage schemes for redis to be used in an async context
    """
    DEPENDENCIES = {
        "redis": Version("5.2.0"),
        "coredis": Version("3.4.0"),
        "valkey": Version("6.0"),
    }
    MODE: Literal["BASIC", "CLUSTER", "SENTINEL"] = "BASIC"
    bridge: RedisBridge
    storage_exceptions: tuple[Exception, ...]
    target_server: Literal["redis", "valkey"]

    def __init__(
        self,
        uri: str,
        wrap_exceptions: bool = False,
        implementation: Literal["redispy", "coredis", "valkey"] = "coredis",
        **options: float | str | bool,
    ) -> None:
        """
        :param uri: uri of the form:

         - ``async+redis://[:password]@host:port``
         - ``async+redis://[:password]@host:port/db``
         - ``async+rediss://[:password]@host:port``
         - ``async+redis+unix:///path/to/sock?db=0`` etc...

         This uri is passed directly to :meth:`coredis.Redis.from_url` or
         :meth:`redis.asyncio.client.Redis.from_url` with the initial ``async`` removed,
         except for the case of ``async+redis+unix`` where it is replaced with ``unix``.

         If the uri scheme is ``async+valkey`` the implementation used will be from
         :pypi:`valkey`.
        :param connection_pool: if provided, the redis client is initialized with
         the connection pool and any other params passed as :paramref:`options`
        :param wrap_exceptions: Whether to wrap storage exceptions in
         :exc:`limits.errors.StorageError` before raising it.
        :param implementation: Whether to use the client implementation from

         - ``coredis``: :class:`coredis.Redis`
         - ``redispy``: :class:`redis.asyncio.client.Redis`
         - ``valkey``: :class:`valkey.asyncio.client.Valkey`

        :param options: all remaining keyword arguments are passed
         directly to the constructor of :class:`coredis.Redis` or :class:`redis.asyncio.client.Redis`
        :raise ConfigurationError: when the redis library is not available
        """
        uri = uri.removeprefix("async+")
        self.target_server = "redis" if uri.startswith("redis") else "valkey"
        uri = uri.replace(f"{self.target_server}+unix", "unix")

        super().__init__(uri, wrap_exceptions=wrap_exceptions)
        self.options = options
        if self.target_server == "valkey" or implementation == "valkey":
            self.bridge = ValkeyBridge(uri, self.dependencies["valkey"].module)
        else:
            if implementation == "redispy":
                self.bridge = RedispyBridge(uri, self.dependencies["redis"].module)
            else:
                self.bridge = CoredisBridge(uri, self.dependencies["coredis"].module)
        self.configure_bridge()
        self.bridge.register_scripts()

    def _current_window_key(self, key: str) -> str:
        """
        Return the current window's storage key (Sliding window strategy)

        Contrary to other strategies that have one key per rate limit item,
        this strategy has two keys per rate limit item that must be on the same machine.
        To keep the current key and the previous key on the same Redis cluster node,
        curly braces are added.

        Eg: "{constructed_key}"
        """
        return f"{{{key}}}"

    def _previous_window_key(self, key: str) -> str:
        """
        Return the previous window's storage key (Sliding window strategy).

        Curly braces are added on the common pattern with the current window's key,
        so the current and the previous key are stored on the same Redis cluster node.

        Eg: "{constructed_key}/-1"
        """
        return f"{self._current_window_key(key)}/-1"

    def configure_bridge(self) -> None:
        self.bridge.use_basic(**self.options)

    @property
    def base_exceptions(
        self,
    ) -> type[Exception] | tuple[type[Exception], ...]:  # pragma: no cover
        return self.bridge.base_exceptions

    async def incr(self, key: str, expiry: int, amount: int = 1) -> int:
        """
        increments the counter for a given rate limit key

        :param key: the key to increment
        :param expiry: amount in seconds for the key to expire in
        :param amount: the number to increment by
        """

        return await self.bridge.incr(key, expiry, amount)

    async def get(self, key: str) -> int:
        """
        :param key: the key to get the counter value for
        """

        return await self.bridge.get(key)

    async def clear(self, key: str) -> None:
        """
        :param key: the key to clear rate limits for
        """

        return await self.bridge.clear(key)

    async def acquire_entry(
        self, key: str, limit: int, expiry: int, amount: int = 1
    ) -> bool:
        """
        :param key: rate limit key to acquire an entry in
        :param limit: amount of entries allowed
        :param expiry: expiry of the entry
        :param amount: the number of entries to acquire
        """

        return await self.bridge.acquire_entry(key, limit, expiry, amount)

    async def get_moving_window(
        self, key: str, limit: int, expiry: int
    ) -> tuple[float, int]:
        """
        returns the starting point and the number of entries in the moving
        window

        :param key: rate limit key
        :param expiry: expiry of entry
        :return: (start of window, number of acquired entries)
        """
        return await self.bridge.get_moving_window(key, limit, expiry)

    async def acquire_sliding_window_entry(
        self,
        key: str,
        limit: int,
        expiry: int,
        amount: int = 1,
    ) -> bool:
        current_key = self._current_window_key(key)
        previous_key = self._previous_window_key(key)
        return await self.bridge.acquire_sliding_window_entry(
            previous_key, current_key, limit, expiry, amount
        )

    async def get_sliding_window(
        self, key: str, expiry: int
    ) -> tuple[int, float, int, float]:
        previous_key = self._previous_window_key(key)
        current_key = self._current_window_key(key)
        return await self.bridge.get_sliding_window(previous_key, current_key, expiry)

    async def get_expiry(self, key: str) -> float:
        """
        :param key: the key to get the expiry for
        """

        return await self.bridge.get_expiry(key)

    async def check(self) -> bool:
        """
        Check if storage is healthy by calling ``PING``
        """

        return await self.bridge.check()

    async def reset(self) -> int | None:
        """
        This function calls a Lua script to delete keys prefixed with
        ``self.PREFIX`` in blocks of 5000.

        .. warning:: This operation was designed to be fast, but was not tested
           on a large production based system. Be careful with its usage as it
           could be slow on very large data sets.
        """

        return await self.bridge.lua_reset()
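
A minimal usage sketch for the class above, assuming the package is importable as ``limits.aio.storage.redis`` (the file paths are not shown in this commit view) and a Redis server is listening on ``localhost:6379``; the URI forms and the ``implementation`` switch come straight from the ``__init__`` docstring::

    import asyncio

    from limits.aio.storage.redis import RedisStorage

    async def main() -> None:
        # "coredis" is the default bridge; "redispy" or "valkey" pick the other clients
        storage = RedisStorage("async+redis://localhost:6379/0", implementation="coredis")
        assert await storage.check()                   # PING round trip
        await storage.incr("limiter/test", expiry=60)  # counter with a 60 second expiry
        print(await storage.get("limiter/test"))       # -> 1
        await storage.clear("limiter/test")

    asyncio.run(main())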


@versionadded(version="2.1")
@versionchanged(
    version="4.2",
    reason="Added support for using the asyncio redis client from :pypi:`redis` ",
)
@versionchanged(
    version="4.3",
    reason=(
        "Added support for using the asyncio redis client from :pypi:`valkey`"
        " through :paramref:`implementation` or if :paramref:`uri` has the"
        " ``async+valkey+cluster`` schema"
    ),
)
class RedisClusterStorage(RedisStorage):
    """
    Rate limit storage with redis cluster as backend

    Depends on :pypi:`coredis` or :pypi:`redis`
    """

    STORAGE_SCHEME = ["async+redis+cluster", "async+valkey+cluster"]
    """
    The storage schemes for redis cluster to be used in an async context
    """

    MODE = "CLUSTER"

    def __init__(
        self,
        uri: str,
        wrap_exceptions: bool = False,
        implementation: Literal["redispy", "coredis", "valkey"] = "coredis",
        **options: float | str | bool,
    ) -> None:
        """
        :param uri: url of the form
         ``async+redis+cluster://[:password]@host:port,host:port``

         If the uri scheme is ``async+valkey+cluster`` the implementation used will be from
         :pypi:`valkey`.
        :param wrap_exceptions: Whether to wrap storage exceptions in
         :exc:`limits.errors.StorageError` before raising it.
        :param implementation: Whether to use the client implementation from

         - ``coredis``: :class:`coredis.RedisCluster`
         - ``redispy``: :class:`redis.asyncio.cluster.RedisCluster`
         - ``valkey``: :class:`valkey.asyncio.cluster.ValkeyCluster`

        :param options: all remaining keyword arguments are passed
         directly to the constructor of :class:`coredis.RedisCluster` or
         :class:`redis.asyncio.RedisCluster`
        :raise ConfigurationError: when the redis library is not
         available or if the redis host cannot be pinged.
        """
        super().__init__(
            uri,
            wrap_exceptions=wrap_exceptions,
            implementation=implementation,
            **options,
        )

    def configure_bridge(self) -> None:
        self.bridge.use_cluster(**self.options)

    async def reset(self) -> int | None:
        """
        Redis Clusters are sharded and deleting across shards
        can't be done atomically. Because of this, this reset loops over all
        keys that are prefixed with ``self.PREFIX`` and calls delete on them,
        one at a time.

        .. warning:: This operation was not tested with extremely large data sets.
           On a large production based system, care should be taken with its
           usage as it could be slow on very large data sets
        """

        return await self.bridge.reset()
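
For the cluster variant only the URI shape changes: multiple ``host:port`` pairs are listed in the netloc. A sketch with illustrative local node addresses; extra keyword arguments are forwarded to the underlying cluster client via ``options``::

    from limits.aio.storage.redis import RedisClusterStorage

    storage = RedisClusterStorage(
        "async+redis+cluster://localhost:7000,localhost:7001,localhost:7002",
        max_connections=200,
    )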


@versionadded(version="2.1")
@versionchanged(
    version="4.2",
    reason="Added support for using the asyncio redis client from :pypi:`redis` ",
)
@versionchanged(
    version="4.3",
    reason=(
        "Added support for using the asyncio redis client from :pypi:`valkey`"
        " through :paramref:`implementation` or if :paramref:`uri` has the"
        " ``async+valkey+sentinel`` schema"
    ),
)
class RedisSentinelStorage(RedisStorage):
    """
    Rate limit storage with redis sentinel as backend

    Depends on :pypi:`coredis` or :pypi:`redis`
    """

    STORAGE_SCHEME = [
        "async+redis+sentinel",
        "async+valkey+sentinel",
    ]
    """The storage scheme for redis accessed via a redis sentinel installation"""

    MODE = "SENTINEL"

    DEPENDENCIES = {
        "redis": Version("5.2.0"),
        "coredis": Version("3.4.0"),
        "coredis.sentinel": Version("3.4.0"),
        "valkey": Version("6.0"),
    }

    def __init__(
        self,
        uri: str,
        wrap_exceptions: bool = False,
        implementation: Literal["redispy", "coredis", "valkey"] = "coredis",
        service_name: str | None = None,
        use_replicas: bool = True,
        sentinel_kwargs: dict[str, float | str | bool] | None = None,
        **options: float | str | bool,
    ):
        """
        :param uri: url of the form
         ``async+redis+sentinel://host:port,host:port/service_name``

         If the uri scheme is ``async+valkey+sentinel`` the implementation used will be from
         :pypi:`valkey`.
        :param wrap_exceptions: Whether to wrap storage exceptions in
         :exc:`limits.errors.StorageError` before raising it.
        :param implementation: Whether to use the client implementation from

         - ``coredis``: :class:`coredis.sentinel.Sentinel`
         - ``redispy``: :class:`redis.asyncio.sentinel.Sentinel`
         - ``valkey``: :class:`valkey.asyncio.sentinel.Sentinel`

        :param service_name: sentinel service name (if not provided in ``uri``)
        :param use_replicas: Whether to use replicas for read only operations
        :param sentinel_kwargs: optional arguments to pass as
         ``sentinel_kwargs`` to :class:`coredis.sentinel.Sentinel` or
         :class:`redis.asyncio.Sentinel`
        :param options: all remaining keyword arguments are passed
         directly to the constructor of :class:`coredis.sentinel.Sentinel` or
         :class:`redis.asyncio.sentinel.Sentinel`
        :raise ConfigurationError: when the redis library is not available
         or if the redis primary host cannot be pinged.
        """

        self.service_name = service_name
        self.use_replicas = use_replicas
        self.sentinel_kwargs = sentinel_kwargs
        super().__init__(
            uri,
            wrap_exceptions=wrap_exceptions,
            implementation=implementation,
            **options,
        )

    def configure_bridge(self) -> None:
        self.bridge.use_sentinel(
            self.service_name, self.use_replicas, self.sentinel_kwargs, **self.options
        )
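
A sentinel sketch; the service name can be given in the URI path or passed separately, and ``use_replicas=False`` routes reads to the primary (hostnames are illustrative)::

    from limits.aio.storage.redis import RedisSentinelStorage

    storage = RedisSentinelStorage(
        "async+redis+sentinel://sentinel-1:26379,sentinel-2:26379/mymaster",
        use_replicas=False,
        sentinel_kwargs={"socket_timeout": 0.2},
    )
    # equivalent, with the service name supplied explicitly
    storage = RedisSentinelStorage(
        "async+redis+sentinel://sentinel-1:26379,sentinel-2:26379",
        service_name="mymaster",
    )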
@@ -0,0 +1,119 @@
from __future__ import annotations

import urllib.parse
from abc import ABC, abstractmethod
from types import ModuleType

from limits.util import get_package_data


class RedisBridge(ABC):
    PREFIX = "LIMITS"
    RES_DIR = "resources/redis/lua_scripts"

    SCRIPT_MOVING_WINDOW = get_package_data(f"{RES_DIR}/moving_window.lua")
    SCRIPT_ACQUIRE_MOVING_WINDOW = get_package_data(
        f"{RES_DIR}/acquire_moving_window.lua"
    )
    SCRIPT_CLEAR_KEYS = get_package_data(f"{RES_DIR}/clear_keys.lua")
    SCRIPT_INCR_EXPIRE = get_package_data(f"{RES_DIR}/incr_expire.lua")
    SCRIPT_SLIDING_WINDOW = get_package_data(f"{RES_DIR}/sliding_window.lua")
    SCRIPT_ACQUIRE_SLIDING_WINDOW = get_package_data(
        f"{RES_DIR}/acquire_sliding_window.lua"
    )

    def __init__(
        self,
        uri: str,
        dependency: ModuleType,
    ) -> None:
        self.uri = uri
        self.parsed_uri = urllib.parse.urlparse(self.uri)
        self.dependency = dependency
        self.parsed_auth = {}
        if self.parsed_uri.username:
            self.parsed_auth["username"] = self.parsed_uri.username
        if self.parsed_uri.password:
            self.parsed_auth["password"] = self.parsed_uri.password

    def prefixed_key(self, key: str) -> str:
        return f"{self.PREFIX}:{key}"

    @abstractmethod
    def register_scripts(self) -> None: ...

    @abstractmethod
    def use_sentinel(
        self,
        service_name: str | None,
        use_replicas: bool,
        sentinel_kwargs: dict[str, str | float | bool] | None,
        **options: str | float | bool,
    ) -> None: ...

    @abstractmethod
    def use_basic(self, **options: str | float | bool) -> None: ...

    @abstractmethod
    def use_cluster(self, **options: str | float | bool) -> None: ...

    @property
    @abstractmethod
    def base_exceptions(
        self,
    ) -> type[Exception] | tuple[type[Exception], ...]: ...

    @abstractmethod
    async def incr(
        self,
        key: str,
        expiry: int,
        amount: int = 1,
    ) -> int: ...

    @abstractmethod
    async def get(self, key: str) -> int: ...

    @abstractmethod
    async def clear(self, key: str) -> None: ...

    @abstractmethod
    async def get_moving_window(
        self, key: str, limit: int, expiry: int
    ) -> tuple[float, int]: ...

    @abstractmethod
    async def get_sliding_window(
        self, previous_key: str, current_key: str, expiry: int
    ) -> tuple[int, float, int, float]: ...

    @abstractmethod
    async def acquire_entry(
        self,
        key: str,
        limit: int,
        expiry: int,
        amount: int = 1,
    ) -> bool: ...

    @abstractmethod
    async def acquire_sliding_window_entry(
        self,
        previous_key: str,
        current_key: str,
        limit: int,
        expiry: int,
        amount: int = 1,
    ) -> bool: ...

    @abstractmethod
    async def get_expiry(self, key: str) -> float: ...

    @abstractmethod
    async def check(self) -> bool: ...

    @abstractmethod
    async def reset(self) -> int | None: ...

    @abstractmethod
    async def lua_reset(self) -> int | None: ...
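
Note how the ``LIMITS`` prefix above combines with the curly-brace keys built by ``RedisStorage._current_window_key`` / ``_previous_window_key``: both sliding-window keys share the same ``{...}`` hash tag, so Redis Cluster assigns them to the same slot. A quick illustration of the values the code above produces::

    PREFIX = "LIMITS"

    def prefixed_key(key: str) -> str:
        return f"{PREFIX}:{key}"

    current = prefixed_key("{my_item}")      # 'LIMITS:{my_item}'
    previous = prefixed_key("{my_item}/-1")  # 'LIMITS:{my_item}/-1'
    # Redis Cluster hashes only the text between the first "{" and "}",
    # so both keys hash on "my_item" and land on the same node.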
@@ -0,0 +1,205 @@
from __future__ import annotations

import time
from typing import TYPE_CHECKING, cast

from limits.aio.storage.redis.bridge import RedisBridge
from limits.errors import ConfigurationError
from limits.typing import AsyncCoRedisClient, Callable

if TYPE_CHECKING:
    import coredis


class CoredisBridge(RedisBridge):
    DEFAULT_CLUSTER_OPTIONS: dict[str, float | str | bool] = {
        "max_connections": 1000,
    }
    "Default options passed to :class:`coredis.RedisCluster`"

    @property
    def base_exceptions(self) -> type[Exception] | tuple[type[Exception], ...]:
        return (self.dependency.exceptions.RedisError,)

    def use_sentinel(
        self,
        service_name: str | None,
        use_replicas: bool,
        sentinel_kwargs: dict[str, str | float | bool] | None,
        **options: str | float | bool,
    ) -> None:
        sentinel_configuration = []
        connection_options = options.copy()

        sep = self.parsed_uri.netloc.find("@") + 1

        for loc in self.parsed_uri.netloc[sep:].split(","):
            host, port = loc.split(":")
            sentinel_configuration.append((host, int(port)))
        service_name = (
            self.parsed_uri.path.replace("/", "")
            if self.parsed_uri.path
            else service_name
        )

        if service_name is None:
            raise ConfigurationError("'service_name' not provided")

        self.sentinel = self.dependency.sentinel.Sentinel(
            sentinel_configuration,
            sentinel_kwargs={**self.parsed_auth, **(sentinel_kwargs or {})},
            **{**self.parsed_auth, **connection_options},
        )
        self.storage = self.sentinel.primary_for(service_name)
        self.storage_replica = self.sentinel.replica_for(service_name)
        self.connection_getter = lambda readonly: (
            self.storage_replica if readonly and use_replicas else self.storage
        )

    def use_basic(self, **options: str | float | bool) -> None:
        if connection_pool := options.pop("connection_pool", None):
            self.storage = self.dependency.Redis(
                connection_pool=connection_pool, **options
            )
        else:
            self.storage = self.dependency.Redis.from_url(self.uri, **options)

        self.connection_getter = lambda _: self.storage

    def use_cluster(self, **options: str | float | bool) -> None:
        sep = self.parsed_uri.netloc.find("@") + 1
        cluster_hosts: list[dict[str, int | str]] = []
        cluster_hosts.extend(
            {"host": host, "port": int(port)}
            for loc in self.parsed_uri.netloc[sep:].split(",")
            if loc
            for host, port in [loc.split(":")]
        )
        self.storage = self.dependency.RedisCluster(
            startup_nodes=cluster_hosts,
            **{**self.DEFAULT_CLUSTER_OPTIONS, **self.parsed_auth, **options},
        )
        self.connection_getter = lambda _: self.storage

    lua_moving_window: coredis.commands.Script[bytes]
    lua_acquire_moving_window: coredis.commands.Script[bytes]
    lua_sliding_window: coredis.commands.Script[bytes]
    lua_acquire_sliding_window: coredis.commands.Script[bytes]
    lua_clear_keys: coredis.commands.Script[bytes]
    lua_incr_expire: coredis.commands.Script[bytes]
    connection_getter: Callable[[bool], AsyncCoRedisClient]

    def get_connection(self, readonly: bool = False) -> AsyncCoRedisClient:
        return self.connection_getter(readonly)

    def register_scripts(self) -> None:
        self.lua_moving_window = self.get_connection().register_script(
            self.SCRIPT_MOVING_WINDOW
        )
        self.lua_acquire_moving_window = self.get_connection().register_script(
            self.SCRIPT_ACQUIRE_MOVING_WINDOW
        )
        self.lua_clear_keys = self.get_connection().register_script(
            self.SCRIPT_CLEAR_KEYS
        )
        self.lua_incr_expire = self.get_connection().register_script(
            self.SCRIPT_INCR_EXPIRE
        )
        self.lua_sliding_window = self.get_connection().register_script(
            self.SCRIPT_SLIDING_WINDOW
        )
        self.lua_acquire_sliding_window = self.get_connection().register_script(
            self.SCRIPT_ACQUIRE_SLIDING_WINDOW
        )

    async def incr(self, key: str, expiry: int, amount: int = 1) -> int:
        key = self.prefixed_key(key)
        if (value := await self.get_connection().incrby(key, amount)) == amount:
            await self.get_connection().expire(key, expiry)
        return value

    async def get(self, key: str) -> int:
        key = self.prefixed_key(key)
        return int(await self.get_connection(readonly=True).get(key) or 0)

    async def clear(self, key: str) -> None:
        key = self.prefixed_key(key)
        await self.get_connection().delete([key])

    async def lua_reset(self) -> int | None:
        return cast(int, await self.lua_clear_keys.execute([self.prefixed_key("*")]))

    async def get_moving_window(
        self, key: str, limit: int, expiry: int
    ) -> tuple[float, int]:
        key = self.prefixed_key(key)
        timestamp = time.time()
        window = await self.lua_moving_window.execute(
            [key], [timestamp - expiry, limit]
        )
        if window:
            return float(window[0]), window[1]  # type: ignore
        return timestamp, 0

    async def get_sliding_window(
        self, previous_key: str, current_key: str, expiry: int
    ) -> tuple[int, float, int, float]:
        previous_key = self.prefixed_key(previous_key)
        current_key = self.prefixed_key(current_key)

        if window := await self.lua_sliding_window.execute(
            [previous_key, current_key], [expiry]
        ):
            return (
                int(window[0] or 0),  # type: ignore
                max(0, float(window[1] or 0)) / 1000,  # type: ignore
                int(window[2] or 0),  # type: ignore
                max(0, float(window[3] or 0)) / 1000,  # type: ignore
            )
        return 0, 0.0, 0, 0.0

    async def acquire_entry(
        self, key: str, limit: int, expiry: int, amount: int = 1
    ) -> bool:
        key = self.prefixed_key(key)
        timestamp = time.time()
        acquired = await self.lua_acquire_moving_window.execute(
            [key], [timestamp, limit, expiry, amount]
        )

        return bool(acquired)

    async def acquire_sliding_window_entry(
        self,
        previous_key: str,
        current_key: str,
        limit: int,
        expiry: int,
        amount: int = 1,
    ) -> bool:
        previous_key = self.prefixed_key(previous_key)
        current_key = self.prefixed_key(current_key)
        acquired = await self.lua_acquire_sliding_window.execute(
            [previous_key, current_key], [limit, expiry, amount]
        )
        return bool(acquired)

    async def get_expiry(self, key: str) -> float:
        key = self.prefixed_key(key)
        return max(await self.get_connection().ttl(key), 0) + time.time()

    async def check(self) -> bool:
        try:
            await self.get_connection().ping()

            return True
        except:  # noqa
            return False

    async def reset(self) -> int | None:
        prefix = self.prefixed_key("*")
        keys = await self.storage.keys(prefix)
        count = 0
        for key in keys:
            count += await self.storage.delete([key])
        return count
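
The ``use_sentinel`` parsing above takes the host list from the part of the netloc after any credentials and the service name from the URI path. A standalone sketch of the same logic (the storage layer has already stripped the ``async+`` prefix by this point)::

    import urllib.parse

    parsed = urllib.parse.urlparse(
        "redis+sentinel://:secret@sentinel-1:26379,sentinel-2:26379/mymaster"
    )
    sep = parsed.netloc.find("@") + 1
    hosts = [
        (host, int(port))
        for host, port in (loc.split(":") for loc in parsed.netloc[sep:].split(","))
    ]
    service_name = parsed.path.replace("/", "")
    print(hosts)         # [('sentinel-1', 26379), ('sentinel-2', 26379)]
    print(service_name)  # 'mymaster'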
@@ -0,0 +1,250 @@
from __future__ import annotations

import time
from typing import TYPE_CHECKING, cast

from limits.aio.storage.redis.bridge import RedisBridge
from limits.errors import ConfigurationError
from limits.typing import AsyncRedisClient, Callable

if TYPE_CHECKING:
    import redis.commands


class RedispyBridge(RedisBridge):
    DEFAULT_CLUSTER_OPTIONS: dict[str, float | str | bool] = {
        "max_connections": 1000,
    }
    "Default options passed to :class:`redis.asyncio.RedisCluster`"

    @property
    def base_exceptions(self) -> type[Exception] | tuple[type[Exception], ...]:
        return (self.dependency.RedisError,)

    def use_sentinel(
        self,
        service_name: str | None,
        use_replicas: bool,
        sentinel_kwargs: dict[str, str | float | bool] | None,
        **options: str | float | bool,
    ) -> None:
        sentinel_configuration = []

        connection_options = options.copy()

        sep = self.parsed_uri.netloc.find("@") + 1

        for loc in self.parsed_uri.netloc[sep:].split(","):
            host, port = loc.split(":")
            sentinel_configuration.append((host, int(port)))
        service_name = (
            self.parsed_uri.path.replace("/", "")
            if self.parsed_uri.path
            else service_name
        )

        if service_name is None:
            raise ConfigurationError("'service_name' not provided")

        self.sentinel = self.dependency.asyncio.Sentinel(
            sentinel_configuration,
            sentinel_kwargs={**self.parsed_auth, **(sentinel_kwargs or {})},
            **{**self.parsed_auth, **connection_options},
        )
        self.storage = self.sentinel.master_for(service_name)
        self.storage_replica = self.sentinel.slave_for(service_name)
        self.connection_getter = lambda readonly: (
            self.storage_replica if readonly and use_replicas else self.storage
        )

    def use_basic(self, **options: str | float | bool) -> None:
        if connection_pool := options.pop("connection_pool", None):
            self.storage = self.dependency.asyncio.Redis(
                connection_pool=connection_pool, **options
            )
        else:
            self.storage = self.dependency.asyncio.Redis.from_url(self.uri, **options)

        self.connection_getter = lambda _: self.storage

    def use_cluster(self, **options: str | float | bool) -> None:
        sep = self.parsed_uri.netloc.find("@") + 1
        cluster_hosts = []

        for loc in self.parsed_uri.netloc[sep:].split(","):
            host, port = loc.split(":")
            cluster_hosts.append(
                self.dependency.asyncio.cluster.ClusterNode(host=host, port=int(port))
            )

        self.storage = self.dependency.asyncio.RedisCluster(
            startup_nodes=cluster_hosts,
            **{**self.DEFAULT_CLUSTER_OPTIONS, **self.parsed_auth, **options},
        )
        self.connection_getter = lambda _: self.storage

    lua_moving_window: redis.commands.core.Script
    lua_acquire_moving_window: redis.commands.core.Script
    lua_sliding_window: redis.commands.core.Script
    lua_acquire_sliding_window: redis.commands.core.Script
    lua_clear_keys: redis.commands.core.Script
    lua_incr_expire: redis.commands.core.Script
    connection_getter: Callable[[bool], AsyncRedisClient]

    def get_connection(self, readonly: bool = False) -> AsyncRedisClient:
        return self.connection_getter(readonly)

    def register_scripts(self) -> None:
        # Redis-py uses a slightly different script registration
        self.lua_moving_window = self.get_connection().register_script(
            self.SCRIPT_MOVING_WINDOW
        )
        self.lua_acquire_moving_window = self.get_connection().register_script(
            self.SCRIPT_ACQUIRE_MOVING_WINDOW
        )
        self.lua_clear_keys = self.get_connection().register_script(
            self.SCRIPT_CLEAR_KEYS
        )
        self.lua_incr_expire = self.get_connection().register_script(
            self.SCRIPT_INCR_EXPIRE
        )
        self.lua_sliding_window = self.get_connection().register_script(
            self.SCRIPT_SLIDING_WINDOW
        )
        self.lua_acquire_sliding_window = self.get_connection().register_script(
            self.SCRIPT_ACQUIRE_SLIDING_WINDOW
        )

    async def incr(
        self,
        key: str,
        expiry: int,
        amount: int = 1,
    ) -> int:
        """
        increments the counter for a given rate limit key

        :param key: the key to increment
        :param expiry: amount in seconds for the key to expire in
        :param amount: the number to increment by
        """
        key = self.prefixed_key(key)
        return cast(int, await self.lua_incr_expire([key], [expiry, amount]))

    async def get(self, key: str) -> int:
        """
        :param key: the key to get the counter value for
        """

        key = self.prefixed_key(key)
        return int(await self.get_connection(readonly=True).get(key) or 0)

    async def clear(self, key: str) -> None:
        """
        :param key: the key to clear rate limits for
        """
        key = self.prefixed_key(key)
        await self.get_connection().delete(key)

    async def lua_reset(self) -> int | None:
        return cast(int, await self.lua_clear_keys([self.prefixed_key("*")]))

    async def get_moving_window(
        self, key: str, limit: int, expiry: int
    ) -> tuple[float, int]:
        """
        returns the starting point and the number of entries in the moving
        window

        :param key: rate limit key
        :param expiry: expiry of entry
        :return: (start of window, number of acquired entries)
        """
        key = self.prefixed_key(key)
        timestamp = time.time()
        window = await self.lua_moving_window([key], [timestamp - expiry, limit])
        if window:
            return float(window[0]), window[1]
        return timestamp, 0

    async def get_sliding_window(
        self, previous_key: str, current_key: str, expiry: int
    ) -> tuple[int, float, int, float]:
        if window := await self.lua_sliding_window(
            [self.prefixed_key(previous_key), self.prefixed_key(current_key)], [expiry]
        ):
            return (
                int(window[0] or 0),
                max(0, float(window[1] or 0)) / 1000,
                int(window[2] or 0),
                max(0, float(window[3] or 0)) / 1000,
            )
        return 0, 0.0, 0, 0.0

    async def acquire_entry(
        self,
        key: str,
        limit: int,
        expiry: int,
        amount: int = 1,
    ) -> bool:
        """
        :param key: rate limit key to acquire an entry in
        :param limit: amount of entries allowed
        :param expiry: expiry of the entry
        :param amount: the number of entries to acquire
        """
        key = self.prefixed_key(key)
        timestamp = time.time()
        acquired = await self.lua_acquire_moving_window(
            [key], [timestamp, limit, expiry, amount]
        )

        return bool(acquired)

    async def acquire_sliding_window_entry(
        self,
        previous_key: str,
        current_key: str,
        limit: int,
        expiry: int,
        amount: int = 1,
    ) -> bool:
        previous_key = self.prefixed_key(previous_key)
        current_key = self.prefixed_key(current_key)
        acquired = await self.lua_acquire_sliding_window(
            [previous_key, current_key], [limit, expiry, amount]
        )
        return bool(acquired)

    async def get_expiry(self, key: str) -> float:
        """
        :param key: the key to get the expiry for
        """

        key = self.prefixed_key(key)
        return max(await self.get_connection().ttl(key), 0) + time.time()

    async def check(self) -> bool:
        """
        check if storage is healthy
        """
        try:
            await self.get_connection().ping()

            return True
        except:  # noqa
            return False

    async def reset(self) -> int | None:
        prefix = self.prefixed_key("*")
        keys = await self.storage.keys(
            prefix, target_nodes=self.dependency.asyncio.cluster.RedisCluster.ALL_NODES
        )
        count = 0
        for key in keys:
            count += await self.storage.delete(key)
        return count
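
The sliding-window Lua script appears to report TTLs in milliseconds, which is why ``get_sliding_window`` above divides by 1000 and clamps negative values (for example the ``-2`` a missing key would report) to ``0``. A small illustration of that post-processing, assuming a raw reply of the shape the script returns::

    window = [3, 41872, 1, 58231]  # previous count, previous TTL (ms), current count, current TTL (ms)
    result = (
        int(window[0] or 0),
        max(0, float(window[1] or 0)) / 1000,
        int(window[2] or 0),
        max(0, float(window[3] or 0)) / 1000,
    )
    print(result)  # (3, 41.872, 1, 58.231)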
@@ -0,0 +1,9 @@
from __future__ import annotations

from .redispy import RedispyBridge


class ValkeyBridge(RedispyBridge):
    @property
    def base_exceptions(self) -> type[Exception] | tuple[type[Exception], ...]:
        return (self.dependency.ValkeyError,)
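
This bridge only swaps the exception base class; everything else is inherited from ``RedispyBridge``. Per the ``RedisStorage`` constructor earlier in this commit it is selected either by an ``async+valkey`` URI scheme or by ``implementation="valkey"``; a sketch with an illustrative host::

    from limits.aio.storage.redis import RedisStorage

    # either of these ends up on ValkeyBridge
    storage = RedisStorage("async+valkey://localhost:6379/0")
    storage = RedisStorage("async+redis://localhost:6379/0", implementation="valkey")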