Update 2025-04-24_11:44:19

This commit is contained in:
oib
2025-04-24 11:44:23 +02:00
commit e748c737f4
3408 changed files with 717481 additions and 0 deletions

View File

@ -0,0 +1,184 @@
from __future__ import annotations
import time
from math import floor
from deprecated.sphinx import versionadded, versionchanged
from packaging.version import Version
from limits.aio.storage import SlidingWindowCounterSupport, Storage
from limits.aio.storage.memcached.bridge import MemcachedBridge
from limits.aio.storage.memcached.emcache import EmcacheBridge
from limits.aio.storage.memcached.memcachio import MemcachioBridge
from limits.storage.base import TimestampedSlidingWindow
from limits.typing import Literal
@versionadded(version="2.1")
@versionchanged(
version="5.0",
reason="Switched default implementation to :pypi:`memcachio`",
)
class MemcachedStorage(Storage, SlidingWindowCounterSupport, TimestampedSlidingWindow):
"""
Rate limit storage with memcached as backend.
Depends on :pypi:`memcachio`
"""
STORAGE_SCHEME = ["async+memcached"]
"""The storage scheme for memcached to be used in an async context"""
DEPENDENCIES = {
"memcachio": Version("0.3"),
"emcache": Version("0.0"),
}
bridge: MemcachedBridge
storage_exceptions: tuple[Exception, ...]
def __init__(
self,
uri: str,
wrap_exceptions: bool = False,
implementation: Literal["memcachio", "emcache"] = "memcachio",
**options: float | str | bool,
) -> None:
"""
:param uri: memcached location of the form
``async+memcached://host:port,host:port``
:param wrap_exceptions: Whether to wrap storage exceptions in
:exc:`limits.errors.StorageError` before raising it.
:param implementation: Whether to use the client implementation from
- ``memcachio``: :class:`memcachio.Client`
- ``emcache``: :class:`emcache.Client`
:param options: all remaining keyword arguments are passed
directly to the constructor of :class:`memcachio.Client`
:raise ConfigurationError: when :pypi:`memcachio` is not available
"""
if implementation == "emcache":
self.bridge = EmcacheBridge(
uri, self.dependencies["emcache"].module, **options
)
else:
self.bridge = MemcachioBridge(
uri, self.dependencies["memcachio"].module, **options
)
super().__init__(uri, wrap_exceptions=wrap_exceptions, **options)
@property
def base_exceptions(
self,
) -> type[Exception] | tuple[type[Exception], ...]: # pragma: no cover
return self.bridge.base_exceptions
async def get(self, key: str) -> int:
"""
:param key: the key to get the counter value for
"""
return await self.bridge.get(key)
async def clear(self, key: str) -> None:
"""
:param key: the key to clear rate limits for
"""
await self.bridge.clear(key)
async def incr(
self,
key: str,
expiry: float,
amount: int = 1,
set_expiration_key: bool = True,
) -> int:
"""
increments the counter for a given rate limit key
:param key: the key to increment
:param expiry: amount in seconds for the key to expire in
window every hit.
:param amount: the number to increment by
:param set_expiration_key: if set to False, the expiration time won't be stored but the key will still expire
"""
return await self.bridge.incr(
key, expiry, amount, set_expiration_key=set_expiration_key
)
async def get_expiry(self, key: str) -> float:
"""
:param key: the key to get the expiry for
"""
return await self.bridge.get_expiry(key)
async def reset(self) -> int | None:
raise NotImplementedError
async def check(self) -> bool:
return await self.bridge.check()
async def acquire_sliding_window_entry(
self,
key: str,
limit: int,
expiry: int,
amount: int = 1,
) -> bool:
if amount > limit:
return False
now = time.time()
previous_key, current_key = self.sliding_window_keys(key, expiry, now)
(
previous_count,
previous_ttl,
current_count,
_,
) = await self._get_sliding_window_info(previous_key, current_key, expiry, now)
t0 = time.time()
weighted_count = previous_count * previous_ttl / expiry + current_count
if floor(weighted_count) + amount > limit:
return False
else:
# Hit, increase the current counter.
# If the counter doesn't exist yet, set twice the theorical expiry.
# We don't need the expiration key as it is estimated with the timestamps directly.
current_count = await self.incr(
current_key, 2 * expiry, amount=amount, set_expiration_key=False
)
t1 = time.time()
actualised_previous_ttl = max(0, previous_ttl - (t1 - t0))
weighted_count = (
previous_count * actualised_previous_ttl / expiry + current_count
)
if floor(weighted_count) > limit:
# Another hit won the race condition: revert the increment and refuse this hit
# Limitation: during high concurrency at the end of the window,
# the counter is shifted and cannot be decremented, so less requests than expected are allowed.
await self.bridge.decr(current_key, amount, noreply=True)
return False
return True
async def get_sliding_window(
self, key: str, expiry: int
) -> tuple[int, float, int, float]:
now = time.time()
previous_key, current_key = self.sliding_window_keys(key, expiry, now)
return await self._get_sliding_window_info(
previous_key, current_key, expiry, now
)
async def _get_sliding_window_info(
self, previous_key: str, current_key: str, expiry: int, now: float
) -> tuple[int, float, int, float]:
result = await self.bridge.get_many([previous_key, current_key])
previous_count = result.get(previous_key.encode("utf-8"), 0)
current_count = result.get(current_key.encode("utf-8"), 0)
if previous_count == 0:
previous_ttl = float(0)
else:
previous_ttl = (1 - (((now - expiry) / expiry) % 1)) * expiry
current_ttl = (1 - ((now / expiry) % 1)) * expiry + expiry
return previous_count, previous_ttl, current_count, current_ttl

View File

@ -0,0 +1,73 @@
from __future__ import annotations
import urllib
from abc import ABC, abstractmethod
from types import ModuleType
from limits.typing import Iterable
class MemcachedBridge(ABC):
def __init__(
self,
uri: str,
dependency: ModuleType,
**options: float | str | bool,
) -> None:
self.uri = uri
self.parsed_uri = urllib.parse.urlparse(self.uri)
self.dependency = dependency
self.hosts = []
self.options = options
sep = self.parsed_uri.netloc.strip().find("@") + 1
for loc in self.parsed_uri.netloc.strip()[sep:].split(","):
host, port = loc.split(":")
self.hosts.append((host, int(port)))
if self.parsed_uri.username:
self.options["username"] = self.parsed_uri.username
if self.parsed_uri.password:
self.options["password"] = self.parsed_uri.password
def _expiration_key(self, key: str) -> str:
"""
Return the expiration key for the given counter key.
Memcached doesn't natively return the expiration time or TTL for a given key,
so we implement the expiration time on a separate key.
"""
return key + "/expires"
@property
@abstractmethod
def base_exceptions(
self,
) -> type[Exception] | tuple[type[Exception], ...]: ...
@abstractmethod
async def get(self, key: str) -> int: ...
@abstractmethod
async def get_many(self, keys: Iterable[str]) -> dict[bytes, int]: ...
@abstractmethod
async def clear(self, key: str) -> None: ...
@abstractmethod
async def decr(self, key: str, amount: int = 1, noreply: bool = False) -> int: ...
@abstractmethod
async def incr(
self,
key: str,
expiry: float,
amount: int = 1,
set_expiration_key: bool = True,
) -> int: ...
@abstractmethod
async def get_expiry(self, key: str) -> float: ...
@abstractmethod
async def check(self) -> bool: ...

View File

@ -0,0 +1,112 @@
from __future__ import annotations
import time
from math import ceil
from types import ModuleType
from limits.typing import TYPE_CHECKING, Iterable
from .bridge import MemcachedBridge
if TYPE_CHECKING:
import emcache
class EmcacheBridge(MemcachedBridge):
def __init__(
self,
uri: str,
dependency: ModuleType,
**options: float | str | bool,
) -> None:
super().__init__(uri, dependency, **options)
self._storage = None
async def get_storage(self) -> emcache.Client:
if not self._storage:
self._storage = await self.dependency.create_client(
[self.dependency.MemcachedHostAddress(h, p) for h, p in self.hosts],
**self.options,
)
assert self._storage
return self._storage
async def get(self, key: str) -> int:
item = await (await self.get_storage()).get(key.encode("utf-8"))
return item and int(item.value) or 0
async def get_many(self, keys: Iterable[str]) -> dict[bytes, int]:
results = await (await self.get_storage()).get_many(
[k.encode("utf-8") for k in keys]
)
return {k: int(item.value) if item else 0 for k, item in results.items()}
async def clear(self, key: str) -> None:
try:
await (await self.get_storage()).delete(key.encode("utf-8"))
except self.dependency.NotFoundCommandError:
pass
async def decr(self, key: str, amount: int = 1, noreply: bool = False) -> int:
storage = await self.get_storage()
limit_key = key.encode("utf-8")
try:
value = await storage.decrement(limit_key, amount, noreply=noreply) or 0
except self.dependency.NotFoundCommandError:
value = 0
return value
async def incr(
self, key: str, expiry: float, amount: int = 1, set_expiration_key: bool = True
) -> int:
storage = await self.get_storage()
limit_key = key.encode("utf-8")
expire_key = self._expiration_key(key).encode()
try:
return await storage.increment(limit_key, amount) or amount
except self.dependency.NotFoundCommandError:
storage = await self.get_storage()
try:
await storage.add(limit_key, f"{amount}".encode(), exptime=ceil(expiry))
if set_expiration_key:
await storage.set(
expire_key,
str(expiry + time.time()).encode("utf-8"),
exptime=ceil(expiry),
noreply=False,
)
value = amount
except self.dependency.NotStoredStorageCommandError:
# Coult not add the key, probably because a concurrent call has added it
storage = await self.get_storage()
value = await storage.increment(limit_key, amount) or amount
return value
async def get_expiry(self, key: str) -> float:
storage = await self.get_storage()
item = await storage.get(self._expiration_key(key).encode("utf-8"))
return item and float(item.value) or time.time()
pass
@property
def base_exceptions(
self,
) -> type[Exception] | tuple[type[Exception], ...]: # pragma: no cover
return (
self.dependency.ClusterNoAvailableNodes,
self.dependency.CommandError,
)
async def check(self) -> bool:
"""
Check if storage is healthy by calling the ``get`` command
on the key ``limiter-check``
"""
try:
storage = await self.get_storage()
await storage.get(b"limiter-check")
return True
except: # noqa
return False

View File

@ -0,0 +1,104 @@
from __future__ import annotations
import time
from math import ceil
from types import ModuleType
from typing import TYPE_CHECKING, Iterable
from .bridge import MemcachedBridge
if TYPE_CHECKING:
import memcachio
class MemcachioBridge(MemcachedBridge):
def __init__(
self,
uri: str,
dependency: ModuleType,
**options: float | str | bool,
) -> None:
super().__init__(uri, dependency, **options)
self._storage: memcachio.Client[bytes] | None = None
@property
def base_exceptions(
self,
) -> type[Exception] | tuple[type[Exception], ...]:
return (
self.dependency.errors.NoAvailableNodes,
self.dependency.errors.MemcachioConnectionError,
)
async def get_storage(self) -> memcachio.Client[bytes]:
if not self._storage:
self._storage = self.dependency.Client(
[(h, p) for h, p in self.hosts],
**self.options,
)
assert self._storage
return self._storage
async def get(self, key: str) -> int:
return (await self.get_many([key])).get(key.encode("utf-8"), 0)
async def get_many(self, keys: Iterable[str]) -> dict[bytes, int]:
"""
Return multiple counters at once
:param keys: the keys to get the counter values for
"""
results = await (await self.get_storage()).get(
*[k.encode("utf-8") for k in keys]
)
return {k: int(v.value) for k, v in results.items()}
async def clear(self, key: str) -> None:
await (await self.get_storage()).delete(key.encode("utf-8"))
async def decr(self, key: str, amount: int = 1, noreply: bool = False) -> int:
storage = await self.get_storage()
limit_key = key.encode("utf-8")
return await storage.decr(limit_key, amount, noreply=noreply) or 0
async def incr(
self, key: str, expiry: float, amount: int = 1, set_expiration_key: bool = True
) -> int:
storage = await self.get_storage()
limit_key = key.encode("utf-8")
expire_key = self._expiration_key(key).encode()
if (value := (await storage.incr(limit_key, amount))) is None:
storage = await self.get_storage()
if await storage.add(limit_key, f"{amount}".encode(), expiry=ceil(expiry)):
if set_expiration_key:
await storage.set(
expire_key,
str(expiry + time.time()).encode("utf-8"),
expiry=ceil(expiry),
noreply=False,
)
return amount
else:
storage = await self.get_storage()
return await storage.incr(limit_key, amount) or amount
return value
async def get_expiry(self, key: str) -> float:
storage = await self.get_storage()
expiration_key = self._expiration_key(key).encode("utf-8")
item = (await storage.get(expiration_key)).get(expiration_key, None)
return item and float(item.value) or time.time()
async def check(self) -> bool:
"""
Check if storage is healthy by calling the ``get`` command
on the key ``limiter-check``
"""
try:
storage = await self.get_storage()
await storage.get(b"limiter-check")
return True
except: # noqa
return False