490 lines
18 KiB
Python
490 lines
18 KiB
Python
from __future__ import annotations
|
|
|
|
import datetime
|
|
import time
|
|
from abc import ABC, abstractmethod
|
|
|
|
from deprecated.sphinx import versionadded, versionchanged
|
|
|
|
from limits.typing import (
|
|
MongoClient,
|
|
MongoCollection,
|
|
MongoDatabase,
|
|
cast,
|
|
)
|
|
|
|
from ..util import get_dependency
|
|
from .base import MovingWindowSupport, SlidingWindowCounterSupport, Storage
|
|
|
|
|
|
class MongoDBStorageBase(
    Storage, MovingWindowSupport, SlidingWindowCounterSupport, ABC
):
    """
    Rate limit storage with MongoDB as backend.

    Depends on :pypi:`pymongo`.

    The client is created lazily on first access of :attr:`storage`;
    subclasses provide the concrete client through
    :meth:`_init_mongo_client`.
    """

    DEPENDENCIES = ["pymongo"]

    def __init__(
        self,
        uri: str,
        database_name: str = "limits",
        counter_collection_name: str = "counters",
        window_collection_name: str = "windows",
        wrap_exceptions: bool = False,
        **options: int | str | bool,
    ) -> None:
        """
        :param uri: uri of the form ``mongodb://[user:password]@host:port?...``,
         This uri is passed directly to :class:`~pymongo.mongo_client.MongoClient`
        :param database_name: The database to use for storing the rate limit
         collections.
        :param counter_collection_name: The collection name to use for individual counters
         used in fixed window strategies
        :param window_collection_name: The collection name to use for sliding & moving window
         storage
        :param wrap_exceptions: Whether to wrap storage exceptions in
         :exc:`limits.errors.StorageError` before raising it.
        :param options: all remaining keyword arguments are passed to the
         constructor of :class:`~pymongo.mongo_client.MongoClient`
        :raise ConfigurationError: when the :pypi:`pymongo` library is not available
        """

        super().__init__(uri, wrap_exceptions=wrap_exceptions, **options)
        self._database_name = database_name
        self._collection_mapping = {
            "counters": counter_collection_name,
            "windows": window_collection_name,
        }
        self.lib = self.dependencies["pymongo"].module
        self.lib_errors, _ = get_dependency("pymongo.errors")
        self._storage_uri = uri
        self._storage_options = options
        # Created on first access of ``storage``; stays ``None`` until then.
        self._storage: MongoClient | None = None

    @property
    def storage(self) -> MongoClient:
        """
        The MongoDB client, created on first access.

        The first access also creates the ``expireAt`` indexes on both
        collections.
        """
        if self._storage is None:
            self._storage = self._init_mongo_client(
                self._storage_uri, **self._storage_options
            )
            self.__initialize_database()
        return self._storage

    @property
    def _database(self) -> MongoDatabase:
        """The database holding the rate limit collections."""
        return self.storage[self._database_name]

    @property
    def counters(self) -> MongoCollection:
        """The collection used by :meth:`incr`, :meth:`get` and :meth:`get_expiry`."""
        return self._database[self._collection_mapping["counters"]]

    @property
    def windows(self) -> MongoCollection:
        """The collection used by the moving and sliding window methods."""
        return self._database[self._collection_mapping["windows"]]

    @abstractmethod
    def _init_mongo_client(
        self, uri: str | None, **options: int | str | bool
    ) -> MongoClient:
        """
        Create the client used by :attr:`storage`.

        :param uri: the storage uri given to the constructor
        :param options: the extra keyword arguments given to the constructor
        """
        raise NotImplementedError()

    @property
    def base_exceptions(
        self,
    ) -> type[Exception] | tuple[type[Exception], ...]:  # pragma: no cover
        """The pymongo base exception class (``PyMongoError``)."""
        return self.lib_errors.PyMongoError  # type: ignore

    def __initialize_database(self) -> None:
        """Create the ``expireAt`` indexes (``expireAfterSeconds=0``) on both collections."""
        self.counters.create_index("expireAt", expireAfterSeconds=0)
        self.windows.create_index("expireAt", expireAfterSeconds=0)

    def reset(self) -> int | None:
        """
        Delete all rate limit keys in the rate limit collections (counters, windows)

        :return: the number of documents counted in both collections before
         they were dropped
        """
        num_keys = self.counters.count_documents({}) + self.windows.count_documents({})
        self.counters.drop()
        self.windows.drop()

        return int(num_keys)

    def clear(self, key: str) -> None:
        """
        :param key: the key to clear rate limits for
        """
        self.counters.find_one_and_delete({"_id": key})
        self.windows.find_one_and_delete({"_id": key})

    def get_expiry(self, key: str) -> float:
        """
        :param key: the key to get the expiry for
        :return: the expiry of the counter as a unix timestamp, or the current
         time when no counter exists for ``key``
        """
        counter = self.counters.find_one({"_id": key})
        if not counter:
            # A naive ``datetime.now()`` is local time; stamping it as UTC
            # would skew the result by the host's UTC offset, so use the
            # epoch clock directly.
            return time.time()

        # NOTE(review): assumes the stored ``expireAt`` denotes UTC (pymongo
        # returns naive UTC datetimes unless ``tz_aware`` is set) - confirm.
        return counter["expireAt"].replace(tzinfo=datetime.timezone.utc).timestamp()

    def get(self, key: str) -> int:
        """
        :param key: the key to get the counter value for
        :return: the counter value, or ``0`` when there is no counter whose
         ``expireAt`` is still in the future
        """
        counter = self.counters.find_one(
            {
                "_id": key,
                "expireAt": {"$gte": datetime.datetime.now(datetime.timezone.utc)},
            },
            projection=["count"],
        )
        if not counter:
            return 0

        return counter["count"] or 0

    def incr(self, key: str, expiry: int, amount: int = 1) -> int:
        """
        increments the counter for a given rate limit key

        :param key: the key to increment
        :param expiry: amount in seconds for the key to expire in
        :param amount: the number to increment by
        :return: the counter value after the update
        """
        expiration = datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(
            seconds=expiry
        )
        # Evaluated server side against the document as it was before the
        # update. NOTE(review): a freshly upserted document has no
        # ``expireAt`` and presumably also satisfies this - confirm.
        counter_expired = {"$lt": ["$expireAt", "$$NOW"]}

        updated = self.counters.find_one_and_update(
            {"_id": key},
            [
                {
                    "$set": {
                        # Restart at ``amount`` once expired, otherwise add.
                        "count": {
                            "$cond": {
                                "if": counter_expired,
                                "then": amount,
                                "else": {"$add": ["$count", amount]},
                            }
                        },
                        # The expiry only moves when the counter restarts.
                        "expireAt": {
                            "$cond": {
                                "if": counter_expired,
                                "then": expiration,
                                "else": "$expireAt",
                            }
                        },
                    },
                },
            ],
            upsert=True,
            projection=["count"],
            return_document=self.lib.ReturnDocument.AFTER,
        )

        return int(updated["count"])

    def check(self) -> bool:
        """
        Check if storage is healthy by calling :meth:`pymongo.mongo_client.MongoClient.server_info`

        :return: ``True`` when the call succeeds, ``False`` when it raises
        """
        try:
            self.storage.server_info()
        except Exception:
            # Any failure means "unhealthy". ``KeyboardInterrupt`` and
            # ``SystemExit`` are deliberately not swallowed.
            return False

        return True

    def get_moving_window(self, key: str, limit: int, expiry: int) -> tuple[float, int]:
        """
        returns the starting point and the number of entries in the moving
        window

        :param key: rate limit key
        :param limit: amount of entries allowed (not used by this query)
        :param expiry: expiry of entry
        :return: (start of window, number of acquired entries)
        """
        timestamp = time.time()
        result = list(
            self.windows.aggregate(
                [
                    {"$match": {"_id": key}},
                    {
                        # Keep only the entries that are still inside the window.
                        "$project": {
                            "filteredEntries": {
                                "$filter": {
                                    "input": "$entries",
                                    "as": "entry",
                                    "cond": {"$gte": ["$$entry", timestamp - expiry]},
                                }
                            }
                        }
                    },
                    {
                        "$project": {
                            "min": {"$min": "$filteredEntries"},
                            "count": {"$size": "$filteredEntries"},
                        }
                    },
                ]
            )
        )
        if not result:
            return timestamp, 0

        window_start = result[0]["min"]
        if window_start is None:
            # The document exists but no entry is inside the window, so
            # ``$min`` of the empty array is null. Report "now" as for a
            # missing document instead of leaking ``None`` to callers.
            window_start = timestamp

        return window_start, result[0]["count"]

    def acquire_entry(self, key: str, limit: int, expiry: int, amount: int = 1) -> bool:
        """
        :param key: rate limit key to acquire an entry in
        :param limit: amount of entries allowed
        :param expiry: expiry of the entry
        :param amount: the number of entries to acquire
        :return: ``True`` if the entries were recorded, ``False`` if
         ``amount`` exceeds ``limit`` or the window is full
        """
        if amount > limit:
            return False

        timestamp = time.time()
        try:
            updates: dict[
                str,
                dict[str, datetime.datetime | dict[str, list[float] | int]],
            ] = {
                # Newest entries go to the front; the array is capped at ``limit``.
                "$push": {
                    "entries": {
                        "$each": [timestamp] * amount,
                        "$position": 0,
                        "$slice": limit,
                    }
                },
                "$set": {
                    "expireAt": (
                        datetime.datetime.now(datetime.timezone.utc)
                        + datetime.timedelta(seconds=expiry)
                    )
                },
            }

            self.windows.update_one(
                {
                    "_id": key,
                    # Matches only while the entry at index ``limit - amount``
                    # is absent or older than the window, i.e. there is room
                    # for ``amount`` more entries.
                    f"entries.{limit - amount}": {"$not": {"$gte": timestamp - expiry}},
                },
                updates,
                upsert=True,
            )

            return True
        except self.lib.errors.DuplicateKeyError:
            # The filter did not match an existing document, so the upsert
            # tried to insert a second document with the same ``_id``.
            return False

    def get_sliding_window(
        self, key: str, expiry: int
    ) -> tuple[int, float, int, float]:
        """
        Roll the sliding window counters for ``key`` forward if needed and
        return them.

        :param key: rate limit key
        :param expiry: length of one window in seconds
        :return: (previous count, previous ttl, current count, current ttl);
         all zero when no document exists for ``key``
        """
        expiry_ms = expiry * 1000
        # True when at most one window length remains before ``expireAt``.
        # In that case the current count becomes the previous one and
        # ``expireAt`` is pushed out by one window.
        window_rolled_over = {
            "$lte": [
                {"$subtract": ["$expireAt", "$$NOW"]},
                expiry_ms,
            ]
        }

        result = self.windows.find_one_and_update(
            {"_id": key},
            [
                {
                    "$set": {
                        "previousCount": {
                            "$cond": {
                                "if": window_rolled_over,
                                "then": {"$ifNull": ["$currentCount", 0]},
                                "else": {"$ifNull": ["$previousCount", 0]},
                            }
                        },
                        "currentCount": {
                            "$cond": {
                                "if": window_rolled_over,
                                "then": 0,
                                "else": {"$ifNull": ["$currentCount", 0]},
                            }
                        },
                        "expireAt": {
                            "$cond": {
                                "if": window_rolled_over,
                                "then": {
                                    "$add": ["$expireAt", expiry_ms],
                                },
                                "else": "$expireAt",
                            }
                        },
                    }
                }
            ],
            return_document=self.lib.ReturnDocument.AFTER,
            projection=["currentCount", "previousCount", "expireAt"],
        )
        if not result:
            return 0, 0.0, 0, 0.0

        expires_at = (
            (result["expireAt"].replace(tzinfo=datetime.timezone.utc).timestamp())
            if result.get("expireAt")
            else time.time()
        )
        current_ttl = max(0, expires_at - time.time())
        # The previous window only has a ttl while it still holds a count.
        prev_ttl = max(0, current_ttl - expiry if result["previousCount"] else 0)

        return (
            result["previousCount"],
            prev_ttl,
            result["currentCount"],
            current_ttl,
        )

    def acquire_sliding_window_entry(
        self, key: str, limit: int, expiry: int, amount: int = 1
    ) -> bool:
        """
        Try to add ``amount`` to the current sliding window counter of ``key``.

        The counters are rolled forward first. The increment is applied only
        if the weighted count plus ``amount`` stays within ``limit``.

        :param key: rate limit key
        :param limit: amount of entries allowed
        :param expiry: length of one window in seconds
        :param amount: the number of entries to acquire
        :return: ``True`` if the entries were acquired
        """
        expiry_ms = expiry * 1000
        # True when at most one window length remains before ``expireAt``.
        window_rolled_over = {
            "$lte": [
                {"$subtract": ["$expireAt", "$$NOW"]},
                expiry_ms,
            ]
        }
        # True when the weighted count leaves room for ``amount`` more.
        within_limit = {"$lte": [{"$add": ["$curWeightedCount", amount]}, limit]}

        result = self.windows.find_one_and_update(
            {"_id": key},
            [
                {
                    "$set": {
                        "previousCount": {
                            "$cond": {
                                "if": window_rolled_over,
                                "then": {"$ifNull": ["$currentCount", 0]},
                                "else": {"$ifNull": ["$previousCount", 0]},
                            }
                        },
                    }
                },
                {
                    "$set": {
                        "currentCount": {
                            "$cond": {
                                "if": window_rolled_over,
                                "then": 0,
                                "else": {"$ifNull": ["$currentCount", 0]},
                            }
                        },
                        "expireAt": {
                            "$cond": {
                                "if": window_rolled_over,
                                "then": {
                                    # Extend an existing expiry by one window,
                                    # otherwise start at now + two windows.
                                    "$cond": {
                                        "if": {"$gt": ["$expireAt", 0]},
                                        "then": {"$add": ["$expireAt", expiry_ms]},
                                        "else": {"$add": ["$$NOW", 2 * expiry_ms]},
                                    }
                                },
                                "else": "$expireAt",
                            }
                        },
                    }
                },
                {
                    "$set": {
                        # floor(previousCount * remaining share of the
                        # previous window + currentCount)
                        "curWeightedCount": {
                            "$floor": {
                                "$add": [
                                    {
                                        "$multiply": [
                                            "$previousCount",
                                            {
                                                "$divide": [
                                                    {
                                                        "$max": [
                                                            0,
                                                            {
                                                                "$subtract": [
                                                                    "$expireAt",
                                                                    {
                                                                        "$add": [
                                                                            "$$NOW",
                                                                            expiry_ms,
                                                                        ]
                                                                    },
                                                                ]
                                                            },
                                                        ]
                                                    },
                                                    expiry_ms,
                                                ]
                                            },
                                        ]
                                    },
                                    "$currentCount",
                                ]
                            }
                        }
                    }
                },
                {
                    "$set": {
                        "currentCount": {
                            "$cond": {
                                "if": within_limit,
                                "then": {"$add": ["$currentCount", amount]},
                                "else": "$currentCount",
                            }
                        }
                    }
                },
                {"$set": {"_acquired": within_limit}},
                # The weighted count is a scratch value; do not persist it.
                {"$unset": ["curWeightedCount"]},
            ],
            return_document=self.lib.ReturnDocument.AFTER,
            upsert=True,
        )
        return cast(bool, result["_acquired"])

    def __del__(self) -> None:
        """Close the client if one was created."""
        # Read the backing attribute, not the ``storage`` property: the
        # property would open a connection just to close it, and
        # ``_storage`` is missing entirely if ``__init__`` raised early.
        storage = getattr(self, "_storage", None)
        if storage is not None:
            storage.close()
|
|
|
|
|
|
@versionadded(version="2.1")
@versionchanged(
    version="3.14.0",
    reason="Added option to select custom collection names for windows & counters",
)
class MongoDBStorage(MongoDBStorageBase):
    """
    MongoDB rate limit storage that builds its client with
    ``MongoClient`` from the loaded :pypi:`pymongo` module.

    Declares the ``mongodb`` and ``mongodb+srv`` storage schemes.
    """

    STORAGE_SCHEME = ["mongodb", "mongodb+srv"]

    def _init_mongo_client(
        self, uri: str | None, **options: int | str | bool
    ) -> MongoClient:
        """
        Build the client for this storage.

        :param uri: uri handed to the ``MongoClient`` constructor
        :param options: keyword arguments forwarded to the ``MongoClient``
         constructor
        """
        client = self.lib.MongoClient(uri, **options)
        return cast(MongoClient, client)
|