Update 2025-04-24_11:44:19
This commit is contained in:
@ -0,0 +1,184 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from math import floor
|
||||
|
||||
from deprecated.sphinx import versionadded, versionchanged
|
||||
from packaging.version import Version
|
||||
|
||||
from limits.aio.storage import SlidingWindowCounterSupport, Storage
|
||||
from limits.aio.storage.memcached.bridge import MemcachedBridge
|
||||
from limits.aio.storage.memcached.emcache import EmcacheBridge
|
||||
from limits.aio.storage.memcached.memcachio import MemcachioBridge
|
||||
from limits.storage.base import TimestampedSlidingWindow
|
||||
from limits.typing import Literal
|
||||
|
||||
|
||||
@versionadded(version="2.1")
|
||||
@versionchanged(
|
||||
version="5.0",
|
||||
reason="Switched default implementation to :pypi:`memcachio`",
|
||||
)
|
||||
class MemcachedStorage(Storage, SlidingWindowCounterSupport, TimestampedSlidingWindow):
|
||||
"""
|
||||
Rate limit storage with memcached as backend.
|
||||
|
||||
Depends on :pypi:`memcachio`
|
||||
"""
|
||||
|
||||
STORAGE_SCHEME = ["async+memcached"]
|
||||
"""The storage scheme for memcached to be used in an async context"""
|
||||
|
||||
DEPENDENCIES = {
|
||||
"memcachio": Version("0.3"),
|
||||
"emcache": Version("0.0"),
|
||||
}
|
||||
|
||||
bridge: MemcachedBridge
|
||||
storage_exceptions: tuple[Exception, ...]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
uri: str,
|
||||
wrap_exceptions: bool = False,
|
||||
implementation: Literal["memcachio", "emcache"] = "memcachio",
|
||||
**options: float | str | bool,
|
||||
) -> None:
|
||||
"""
|
||||
:param uri: memcached location of the form
|
||||
``async+memcached://host:port,host:port``
|
||||
:param wrap_exceptions: Whether to wrap storage exceptions in
|
||||
:exc:`limits.errors.StorageError` before raising it.
|
||||
:param implementation: Whether to use the client implementation from
|
||||
|
||||
- ``memcachio``: :class:`memcachio.Client`
|
||||
- ``emcache``: :class:`emcache.Client`
|
||||
:param options: all remaining keyword arguments are passed
|
||||
directly to the constructor of :class:`memcachio.Client`
|
||||
:raise ConfigurationError: when :pypi:`memcachio` is not available
|
||||
"""
|
||||
if implementation == "emcache":
|
||||
self.bridge = EmcacheBridge(
|
||||
uri, self.dependencies["emcache"].module, **options
|
||||
)
|
||||
else:
|
||||
self.bridge = MemcachioBridge(
|
||||
uri, self.dependencies["memcachio"].module, **options
|
||||
)
|
||||
super().__init__(uri, wrap_exceptions=wrap_exceptions, **options)
|
||||
|
||||
@property
|
||||
def base_exceptions(
|
||||
self,
|
||||
) -> type[Exception] | tuple[type[Exception], ...]: # pragma: no cover
|
||||
return self.bridge.base_exceptions
|
||||
|
||||
async def get(self, key: str) -> int:
|
||||
"""
|
||||
:param key: the key to get the counter value for
|
||||
"""
|
||||
return await self.bridge.get(key)
|
||||
|
||||
async def clear(self, key: str) -> None:
|
||||
"""
|
||||
:param key: the key to clear rate limits for
|
||||
"""
|
||||
await self.bridge.clear(key)
|
||||
|
||||
async def incr(
|
||||
self,
|
||||
key: str,
|
||||
expiry: float,
|
||||
amount: int = 1,
|
||||
set_expiration_key: bool = True,
|
||||
) -> int:
|
||||
"""
|
||||
increments the counter for a given rate limit key
|
||||
|
||||
:param key: the key to increment
|
||||
:param expiry: amount in seconds for the key to expire in
|
||||
window every hit.
|
||||
:param amount: the number to increment by
|
||||
:param set_expiration_key: if set to False, the expiration time won't be stored but the key will still expire
|
||||
"""
|
||||
return await self.bridge.incr(
|
||||
key, expiry, amount, set_expiration_key=set_expiration_key
|
||||
)
|
||||
|
||||
async def get_expiry(self, key: str) -> float:
|
||||
"""
|
||||
:param key: the key to get the expiry for
|
||||
"""
|
||||
return await self.bridge.get_expiry(key)
|
||||
|
||||
async def reset(self) -> int | None:
|
||||
raise NotImplementedError
|
||||
|
||||
async def check(self) -> bool:
|
||||
return await self.bridge.check()
|
||||
|
||||
async def acquire_sliding_window_entry(
|
||||
self,
|
||||
key: str,
|
||||
limit: int,
|
||||
expiry: int,
|
||||
amount: int = 1,
|
||||
) -> bool:
|
||||
if amount > limit:
|
||||
return False
|
||||
now = time.time()
|
||||
previous_key, current_key = self.sliding_window_keys(key, expiry, now)
|
||||
(
|
||||
previous_count,
|
||||
previous_ttl,
|
||||
current_count,
|
||||
_,
|
||||
) = await self._get_sliding_window_info(previous_key, current_key, expiry, now)
|
||||
t0 = time.time()
|
||||
weighted_count = previous_count * previous_ttl / expiry + current_count
|
||||
if floor(weighted_count) + amount > limit:
|
||||
return False
|
||||
else:
|
||||
# Hit, increase the current counter.
|
||||
# If the counter doesn't exist yet, set twice the theorical expiry.
|
||||
# We don't need the expiration key as it is estimated with the timestamps directly.
|
||||
current_count = await self.incr(
|
||||
current_key, 2 * expiry, amount=amount, set_expiration_key=False
|
||||
)
|
||||
t1 = time.time()
|
||||
actualised_previous_ttl = max(0, previous_ttl - (t1 - t0))
|
||||
weighted_count = (
|
||||
previous_count * actualised_previous_ttl / expiry + current_count
|
||||
)
|
||||
if floor(weighted_count) > limit:
|
||||
# Another hit won the race condition: revert the increment and refuse this hit
|
||||
# Limitation: during high concurrency at the end of the window,
|
||||
# the counter is shifted and cannot be decremented, so less requests than expected are allowed.
|
||||
await self.bridge.decr(current_key, amount, noreply=True)
|
||||
return False
|
||||
return True
|
||||
|
||||
async def get_sliding_window(
|
||||
self, key: str, expiry: int
|
||||
) -> tuple[int, float, int, float]:
|
||||
now = time.time()
|
||||
previous_key, current_key = self.sliding_window_keys(key, expiry, now)
|
||||
return await self._get_sliding_window_info(
|
||||
previous_key, current_key, expiry, now
|
||||
)
|
||||
|
||||
async def _get_sliding_window_info(
|
||||
self, previous_key: str, current_key: str, expiry: int, now: float
|
||||
) -> tuple[int, float, int, float]:
|
||||
result = await self.bridge.get_many([previous_key, current_key])
|
||||
|
||||
previous_count = result.get(previous_key.encode("utf-8"), 0)
|
||||
current_count = result.get(current_key.encode("utf-8"), 0)
|
||||
|
||||
if previous_count == 0:
|
||||
previous_ttl = float(0)
|
||||
else:
|
||||
previous_ttl = (1 - (((now - expiry) / expiry) % 1)) * expiry
|
||||
current_ttl = (1 - ((now / expiry) % 1)) * expiry + expiry
|
||||
|
||||
return previous_count, previous_ttl, current_count, current_ttl
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@ -0,0 +1,73 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import urllib
|
||||
from abc import ABC, abstractmethod
|
||||
from types import ModuleType
|
||||
|
||||
from limits.typing import Iterable
|
||||
|
||||
|
||||
class MemcachedBridge(ABC):
|
||||
def __init__(
|
||||
self,
|
||||
uri: str,
|
||||
dependency: ModuleType,
|
||||
**options: float | str | bool,
|
||||
) -> None:
|
||||
self.uri = uri
|
||||
self.parsed_uri = urllib.parse.urlparse(self.uri)
|
||||
self.dependency = dependency
|
||||
self.hosts = []
|
||||
self.options = options
|
||||
|
||||
sep = self.parsed_uri.netloc.strip().find("@") + 1
|
||||
for loc in self.parsed_uri.netloc.strip()[sep:].split(","):
|
||||
host, port = loc.split(":")
|
||||
self.hosts.append((host, int(port)))
|
||||
|
||||
if self.parsed_uri.username:
|
||||
self.options["username"] = self.parsed_uri.username
|
||||
if self.parsed_uri.password:
|
||||
self.options["password"] = self.parsed_uri.password
|
||||
|
||||
def _expiration_key(self, key: str) -> str:
|
||||
"""
|
||||
Return the expiration key for the given counter key.
|
||||
|
||||
Memcached doesn't natively return the expiration time or TTL for a given key,
|
||||
so we implement the expiration time on a separate key.
|
||||
"""
|
||||
return key + "/expires"
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def base_exceptions(
|
||||
self,
|
||||
) -> type[Exception] | tuple[type[Exception], ...]: ...
|
||||
|
||||
@abstractmethod
|
||||
async def get(self, key: str) -> int: ...
|
||||
|
||||
@abstractmethod
|
||||
async def get_many(self, keys: Iterable[str]) -> dict[bytes, int]: ...
|
||||
|
||||
@abstractmethod
|
||||
async def clear(self, key: str) -> None: ...
|
||||
|
||||
@abstractmethod
|
||||
async def decr(self, key: str, amount: int = 1, noreply: bool = False) -> int: ...
|
||||
|
||||
@abstractmethod
|
||||
async def incr(
|
||||
self,
|
||||
key: str,
|
||||
expiry: float,
|
||||
amount: int = 1,
|
||||
set_expiration_key: bool = True,
|
||||
) -> int: ...
|
||||
|
||||
@abstractmethod
|
||||
async def get_expiry(self, key: str) -> float: ...
|
||||
|
||||
@abstractmethod
|
||||
async def check(self) -> bool: ...
|
@ -0,0 +1,112 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from math import ceil
|
||||
from types import ModuleType
|
||||
|
||||
from limits.typing import TYPE_CHECKING, Iterable
|
||||
|
||||
from .bridge import MemcachedBridge
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import emcache
|
||||
|
||||
|
||||
class EmcacheBridge(MemcachedBridge):
|
||||
def __init__(
|
||||
self,
|
||||
uri: str,
|
||||
dependency: ModuleType,
|
||||
**options: float | str | bool,
|
||||
) -> None:
|
||||
super().__init__(uri, dependency, **options)
|
||||
self._storage = None
|
||||
|
||||
async def get_storage(self) -> emcache.Client:
|
||||
if not self._storage:
|
||||
self._storage = await self.dependency.create_client(
|
||||
[self.dependency.MemcachedHostAddress(h, p) for h, p in self.hosts],
|
||||
**self.options,
|
||||
)
|
||||
assert self._storage
|
||||
return self._storage
|
||||
|
||||
async def get(self, key: str) -> int:
|
||||
item = await (await self.get_storage()).get(key.encode("utf-8"))
|
||||
return item and int(item.value) or 0
|
||||
|
||||
async def get_many(self, keys: Iterable[str]) -> dict[bytes, int]:
|
||||
results = await (await self.get_storage()).get_many(
|
||||
[k.encode("utf-8") for k in keys]
|
||||
)
|
||||
return {k: int(item.value) if item else 0 for k, item in results.items()}
|
||||
|
||||
async def clear(self, key: str) -> None:
|
||||
try:
|
||||
await (await self.get_storage()).delete(key.encode("utf-8"))
|
||||
except self.dependency.NotFoundCommandError:
|
||||
pass
|
||||
|
||||
async def decr(self, key: str, amount: int = 1, noreply: bool = False) -> int:
|
||||
storage = await self.get_storage()
|
||||
limit_key = key.encode("utf-8")
|
||||
try:
|
||||
value = await storage.decrement(limit_key, amount, noreply=noreply) or 0
|
||||
except self.dependency.NotFoundCommandError:
|
||||
value = 0
|
||||
return value
|
||||
|
||||
async def incr(
|
||||
self, key: str, expiry: float, amount: int = 1, set_expiration_key: bool = True
|
||||
) -> int:
|
||||
storage = await self.get_storage()
|
||||
limit_key = key.encode("utf-8")
|
||||
expire_key = self._expiration_key(key).encode()
|
||||
try:
|
||||
return await storage.increment(limit_key, amount) or amount
|
||||
except self.dependency.NotFoundCommandError:
|
||||
storage = await self.get_storage()
|
||||
try:
|
||||
await storage.add(limit_key, f"{amount}".encode(), exptime=ceil(expiry))
|
||||
if set_expiration_key:
|
||||
await storage.set(
|
||||
expire_key,
|
||||
str(expiry + time.time()).encode("utf-8"),
|
||||
exptime=ceil(expiry),
|
||||
noreply=False,
|
||||
)
|
||||
value = amount
|
||||
except self.dependency.NotStoredStorageCommandError:
|
||||
# Coult not add the key, probably because a concurrent call has added it
|
||||
storage = await self.get_storage()
|
||||
value = await storage.increment(limit_key, amount) or amount
|
||||
return value
|
||||
|
||||
async def get_expiry(self, key: str) -> float:
|
||||
storage = await self.get_storage()
|
||||
item = await storage.get(self._expiration_key(key).encode("utf-8"))
|
||||
|
||||
return item and float(item.value) or time.time()
|
||||
pass
|
||||
|
||||
@property
|
||||
def base_exceptions(
|
||||
self,
|
||||
) -> type[Exception] | tuple[type[Exception], ...]: # pragma: no cover
|
||||
return (
|
||||
self.dependency.ClusterNoAvailableNodes,
|
||||
self.dependency.CommandError,
|
||||
)
|
||||
|
||||
async def check(self) -> bool:
|
||||
"""
|
||||
Check if storage is healthy by calling the ``get`` command
|
||||
on the key ``limiter-check``
|
||||
"""
|
||||
try:
|
||||
storage = await self.get_storage()
|
||||
await storage.get(b"limiter-check")
|
||||
|
||||
return True
|
||||
except: # noqa
|
||||
return False
|
@ -0,0 +1,104 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from math import ceil
|
||||
from types import ModuleType
|
||||
from typing import TYPE_CHECKING, Iterable
|
||||
|
||||
from .bridge import MemcachedBridge
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import memcachio
|
||||
|
||||
|
||||
class MemcachioBridge(MemcachedBridge):
|
||||
def __init__(
|
||||
self,
|
||||
uri: str,
|
||||
dependency: ModuleType,
|
||||
**options: float | str | bool,
|
||||
) -> None:
|
||||
super().__init__(uri, dependency, **options)
|
||||
self._storage: memcachio.Client[bytes] | None = None
|
||||
|
||||
@property
|
||||
def base_exceptions(
|
||||
self,
|
||||
) -> type[Exception] | tuple[type[Exception], ...]:
|
||||
return (
|
||||
self.dependency.errors.NoAvailableNodes,
|
||||
self.dependency.errors.MemcachioConnectionError,
|
||||
)
|
||||
|
||||
async def get_storage(self) -> memcachio.Client[bytes]:
|
||||
if not self._storage:
|
||||
self._storage = self.dependency.Client(
|
||||
[(h, p) for h, p in self.hosts],
|
||||
**self.options,
|
||||
)
|
||||
assert self._storage
|
||||
return self._storage
|
||||
|
||||
async def get(self, key: str) -> int:
|
||||
return (await self.get_many([key])).get(key.encode("utf-8"), 0)
|
||||
|
||||
async def get_many(self, keys: Iterable[str]) -> dict[bytes, int]:
|
||||
"""
|
||||
Return multiple counters at once
|
||||
|
||||
:param keys: the keys to get the counter values for
|
||||
"""
|
||||
results = await (await self.get_storage()).get(
|
||||
*[k.encode("utf-8") for k in keys]
|
||||
)
|
||||
return {k: int(v.value) for k, v in results.items()}
|
||||
|
||||
async def clear(self, key: str) -> None:
|
||||
await (await self.get_storage()).delete(key.encode("utf-8"))
|
||||
|
||||
async def decr(self, key: str, amount: int = 1, noreply: bool = False) -> int:
|
||||
storage = await self.get_storage()
|
||||
limit_key = key.encode("utf-8")
|
||||
return await storage.decr(limit_key, amount, noreply=noreply) or 0
|
||||
|
||||
async def incr(
|
||||
self, key: str, expiry: float, amount: int = 1, set_expiration_key: bool = True
|
||||
) -> int:
|
||||
storage = await self.get_storage()
|
||||
limit_key = key.encode("utf-8")
|
||||
expire_key = self._expiration_key(key).encode()
|
||||
if (value := (await storage.incr(limit_key, amount))) is None:
|
||||
storage = await self.get_storage()
|
||||
if await storage.add(limit_key, f"{amount}".encode(), expiry=ceil(expiry)):
|
||||
if set_expiration_key:
|
||||
await storage.set(
|
||||
expire_key,
|
||||
str(expiry + time.time()).encode("utf-8"),
|
||||
expiry=ceil(expiry),
|
||||
noreply=False,
|
||||
)
|
||||
return amount
|
||||
else:
|
||||
storage = await self.get_storage()
|
||||
return await storage.incr(limit_key, amount) or amount
|
||||
return value
|
||||
|
||||
async def get_expiry(self, key: str) -> float:
|
||||
storage = await self.get_storage()
|
||||
expiration_key = self._expiration_key(key).encode("utf-8")
|
||||
item = (await storage.get(expiration_key)).get(expiration_key, None)
|
||||
|
||||
return item and float(item.value) or time.time()
|
||||
|
||||
async def check(self) -> bool:
|
||||
"""
|
||||
Check if storage is healthy by calling the ``get`` command
|
||||
on the key ``limiter-check``
|
||||
"""
|
||||
try:
|
||||
storage = await self.get_storage()
|
||||
await storage.get(b"limiter-check")
|
||||
|
||||
return True
|
||||
except: # noqa
|
||||
return False
|
Reference in New Issue
Block a user