254 lines
		
	
	
		
			8.5 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			254 lines
		
	
	
		
			8.5 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| from __future__ import annotations
 | |
| 
 | |
| import bisect
 | |
| import threading
 | |
| import time
 | |
| from collections import Counter, defaultdict
 | |
| from math import floor
 | |
| 
 | |
| import limits.typing
 | |
| from limits.storage.base import (
 | |
|     MovingWindowSupport,
 | |
|     SlidingWindowCounterSupport,
 | |
|     Storage,
 | |
|     TimestampedSlidingWindow,
 | |
| )
 | |
| 
 | |
| 
 | |
| class Entry:
 | |
|     def __init__(self, expiry: float) -> None:
 | |
|         self.atime = time.time()
 | |
|         self.expiry = self.atime + expiry
 | |
| 
 | |
| 
 | |
| class MemoryStorage(
 | |
|     Storage, MovingWindowSupport, SlidingWindowCounterSupport, TimestampedSlidingWindow
 | |
| ):
 | |
|     """
 | |
|     rate limit storage using :class:`collections.Counter`
 | |
|     as an in memory storage for fixed and sliding window strategies,
 | |
|     and a simple list to implement moving window strategy.
 | |
| 
 | |
|     """
 | |
| 
 | |
|     STORAGE_SCHEME = ["memory"]
 | |
| 
 | |
|     def __init__(self, uri: str | None = None, wrap_exceptions: bool = False, **_: str):
 | |
|         self.storage: limits.typing.Counter[str] = Counter()
 | |
|         self.locks: defaultdict[str, threading.RLock] = defaultdict(threading.RLock)
 | |
|         self.expirations: dict[str, float] = {}
 | |
|         self.events: dict[str, list[Entry]] = {}
 | |
|         self.timer: threading.Timer = threading.Timer(0.01, self.__expire_events)
 | |
|         self.timer.start()
 | |
|         super().__init__(uri, wrap_exceptions=wrap_exceptions, **_)
 | |
| 
 | |
|     def __getstate__(self) -> dict[str, limits.typing.Any]:  # type: ignore[explicit-any]
 | |
|         state = self.__dict__.copy()
 | |
|         del state["timer"]
 | |
|         del state["locks"]
 | |
|         return state
 | |
| 
 | |
|     def __setstate__(self, state: dict[str, limits.typing.Any]) -> None:  # type: ignore[explicit-any]
 | |
|         self.__dict__.update(state)
 | |
|         self.locks = defaultdict(threading.RLock)
 | |
|         self.timer = threading.Timer(0.01, self.__expire_events)
 | |
|         self.timer.start()
 | |
| 
 | |
|     def __expire_events(self) -> None:
 | |
|         for key in list(self.events.keys()):
 | |
|             with self.locks[key]:
 | |
|                 if events := self.events.get(key, []):
 | |
|                     oldest = bisect.bisect_left(
 | |
|                         events, -time.time(), key=lambda event: -event.expiry
 | |
|                     )
 | |
|                     self.events[key] = self.events[key][:oldest]
 | |
|                 if not self.events.get(key, None):
 | |
|                     self.locks.pop(key, None)
 | |
|         for key in list(self.expirations.keys()):
 | |
|             if self.expirations[key] <= time.time():
 | |
|                 self.storage.pop(key, None)
 | |
|                 self.expirations.pop(key, None)
 | |
|                 self.locks.pop(key, None)
 | |
| 
 | |
|     def __schedule_expiry(self) -> None:
 | |
|         if not self.timer.is_alive():
 | |
|             self.timer = threading.Timer(0.01, self.__expire_events)
 | |
|             self.timer.start()
 | |
| 
 | |
|     @property
 | |
|     def base_exceptions(
 | |
|         self,
 | |
|     ) -> type[Exception] | tuple[type[Exception], ...]:  # pragma: no cover
 | |
|         return ValueError
 | |
| 
 | |
|     def incr(self, key: str, expiry: float, amount: int = 1) -> int:
 | |
|         """
 | |
|         increments the counter for a given rate limit key
 | |
| 
 | |
|         :param key: the key to increment
 | |
|         :param expiry: amount in seconds for the key to expire in
 | |
|         :param amount: the number to increment by
 | |
|         """
 | |
|         self.get(key)
 | |
|         self.__schedule_expiry()
 | |
|         with self.locks[key]:
 | |
|             self.storage[key] += amount
 | |
|             if self.storage[key] == amount:
 | |
|                 self.expirations[key] = time.time() + expiry
 | |
|         return self.storage.get(key, 0)
 | |
| 
 | |
|     def decr(self, key: str, amount: int = 1) -> int:
 | |
|         """
 | |
|         decrements the counter for a given rate limit key
 | |
| 
 | |
|         :param key: the key to decrement
 | |
|         :param amount: the number to decrement by
 | |
|         """
 | |
|         self.get(key)
 | |
|         self.__schedule_expiry()
 | |
|         with self.locks[key]:
 | |
|             self.storage[key] = max(self.storage[key] - amount, 0)
 | |
| 
 | |
|         return self.storage.get(key, 0)
 | |
| 
 | |
|     def get(self, key: str) -> int:
 | |
|         """
 | |
|         :param key: the key to get the counter value for
 | |
|         """
 | |
| 
 | |
|         if self.expirations.get(key, 0) <= time.time():
 | |
|             self.storage.pop(key, None)
 | |
|             self.expirations.pop(key, None)
 | |
|             self.locks.pop(key, None)
 | |
| 
 | |
|         return self.storage.get(key, 0)
 | |
| 
 | |
|     def clear(self, key: str) -> None:
 | |
|         """
 | |
|         :param key: the key to clear rate limits for
 | |
|         """
 | |
|         self.storage.pop(key, None)
 | |
|         self.expirations.pop(key, None)
 | |
|         self.events.pop(key, None)
 | |
|         self.locks.pop(key, None)
 | |
| 
 | |
|     def acquire_entry(self, key: str, limit: int, expiry: int, amount: int = 1) -> bool:
 | |
|         """
 | |
|         :param key: rate limit key to acquire an entry in
 | |
|         :param limit: amount of entries allowed
 | |
|         :param expiry: expiry of the entry
 | |
|         :param amount: the number of entries to acquire
 | |
|         """
 | |
|         if amount > limit:
 | |
|             return False
 | |
| 
 | |
|         self.__schedule_expiry()
 | |
|         with self.locks[key]:
 | |
|             self.events.setdefault(key, [])
 | |
|             timestamp = time.time()
 | |
|             try:
 | |
|                 entry = self.events[key][limit - amount]
 | |
|             except IndexError:
 | |
|                 entry = None
 | |
| 
 | |
|             if entry and entry.atime >= timestamp - expiry:
 | |
|                 return False
 | |
|             else:
 | |
|                 self.events[key][:0] = [Entry(expiry)] * amount
 | |
|                 return True
 | |
| 
 | |
|     def get_expiry(self, key: str) -> float:
 | |
|         """
 | |
|         :param key: the key to get the expiry for
 | |
|         """
 | |
| 
 | |
|         return self.expirations.get(key, time.time())
 | |
| 
 | |
|     def get_moving_window(self, key: str, limit: int, expiry: int) -> tuple[float, int]:
 | |
|         """
 | |
|         returns the starting point and the number of entries in the moving
 | |
|         window
 | |
| 
 | |
|         :param key: rate limit key
 | |
|         :param expiry: expiry of entry
 | |
|         :return: (start of window, number of acquired entries)
 | |
|         """
 | |
|         timestamp = time.time()
 | |
|         if events := self.events.get(key, []):
 | |
|             oldest = bisect.bisect_left(
 | |
|                 events, -(timestamp - expiry), key=lambda entry: -entry.atime
 | |
|             )
 | |
|             return events[oldest - 1].atime, oldest
 | |
|         return timestamp, 0
 | |
| 
 | |
|     def acquire_sliding_window_entry(
 | |
|         self,
 | |
|         key: str,
 | |
|         limit: int,
 | |
|         expiry: int,
 | |
|         amount: int = 1,
 | |
|     ) -> bool:
 | |
|         if amount > limit:
 | |
|             return False
 | |
|         now = time.time()
 | |
|         previous_key, current_key = self.sliding_window_keys(key, expiry, now)
 | |
|         (
 | |
|             previous_count,
 | |
|             previous_ttl,
 | |
|             current_count,
 | |
|             _,
 | |
|         ) = self._get_sliding_window_info(previous_key, current_key, expiry, now)
 | |
|         weighted_count = previous_count * previous_ttl / expiry + current_count
 | |
|         if floor(weighted_count) + amount > limit:
 | |
|             return False
 | |
|         else:
 | |
|             # Hit, increase the current counter.
 | |
|             # If the counter doesn't exist yet, set twice the theorical expiry.
 | |
|             current_count = self.incr(current_key, 2 * expiry, amount=amount)
 | |
|             weighted_count = previous_count * previous_ttl / expiry + current_count
 | |
|             if floor(weighted_count) > limit:
 | |
|                 # Another hit won the race condition: revert the incrementation and refuse this hit
 | |
|                 # Limitation: during high concurrency at the end of the window,
 | |
|                 # the counter is shifted and cannot be decremented, so less requests than expected are allowed.
 | |
|                 self.decr(current_key, amount)
 | |
|                 return False
 | |
|             return True
 | |
| 
 | |
|     def _get_sliding_window_info(
 | |
|         self,
 | |
|         previous_key: str,
 | |
|         current_key: str,
 | |
|         expiry: int,
 | |
|         now: float,
 | |
|     ) -> tuple[int, float, int, float]:
 | |
|         previous_count = self.get(previous_key)
 | |
|         current_count = self.get(current_key)
 | |
|         if previous_count == 0:
 | |
|             previous_ttl = float(0)
 | |
|         else:
 | |
|             previous_ttl = (1 - (((now - expiry) / expiry) % 1)) * expiry
 | |
|         current_ttl = (1 - ((now / expiry) % 1)) * expiry + expiry
 | |
|         return previous_count, previous_ttl, current_count, current_ttl
 | |
| 
 | |
|     def get_sliding_window(
 | |
|         self, key: str, expiry: int
 | |
|     ) -> tuple[int, float, int, float]:
 | |
|         now = time.time()
 | |
|         previous_key, current_key = self.sliding_window_keys(key, expiry, now)
 | |
|         return self._get_sliding_window_info(previous_key, current_key, expiry, now)
 | |
| 
 | |
|     def check(self) -> bool:
 | |
|         """
 | |
|         check if storage is healthy
 | |
|         """
 | |
| 
 | |
|         return True
 | |
| 
 | |
|     def reset(self) -> int | None:
 | |
|         num_items = max(len(self.storage), len(self.events))
 | |
|         self.storage.clear()
 | |
|         self.expirations.clear()
 | |
|         self.events.clear()
 | |
|         self.locks.clear()
 | |
|         return num_items
 | 
