Added github integration
This commit is contained in:
35
buffteks/lib/python3.11/site-packages/limits/__init__.py
Normal file
35
buffteks/lib/python3.11/site-packages/limits/__init__.py
Normal file
@@ -0,0 +1,35 @@
|
||||
"""
|
||||
Rate limiting with commonly used storage backends
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from . import _version, aio, storage, strategies
|
||||
from .limits import (
|
||||
RateLimitItem,
|
||||
RateLimitItemPerDay,
|
||||
RateLimitItemPerHour,
|
||||
RateLimitItemPerMinute,
|
||||
RateLimitItemPerMonth,
|
||||
RateLimitItemPerSecond,
|
||||
RateLimitItemPerYear,
|
||||
)
|
||||
from .util import WindowStats, parse, parse_many
|
||||
|
||||
__all__ = [
|
||||
"RateLimitItem",
|
||||
"RateLimitItemPerDay",
|
||||
"RateLimitItemPerHour",
|
||||
"RateLimitItemPerMinute",
|
||||
"RateLimitItemPerMonth",
|
||||
"RateLimitItemPerSecond",
|
||||
"RateLimitItemPerYear",
|
||||
"WindowStats",
|
||||
"aio",
|
||||
"parse",
|
||||
"parse_many",
|
||||
"storage",
|
||||
"strategies",
|
||||
]
|
||||
|
||||
__version__ = _version.__version__
|
||||
34
buffteks/lib/python3.11/site-packages/limits/_version.py
Normal file
34
buffteks/lib/python3.11/site-packages/limits/_version.py
Normal file
@@ -0,0 +1,34 @@
|
||||
# file generated by setuptools-scm
|
||||
# don't change, don't track in version control
|
||||
|
||||
__all__ = [
|
||||
"__version__",
|
||||
"__version_tuple__",
|
||||
"version",
|
||||
"version_tuple",
|
||||
"__commit_id__",
|
||||
"commit_id",
|
||||
]
|
||||
|
||||
TYPE_CHECKING = False
|
||||
if TYPE_CHECKING:
|
||||
from typing import Tuple
|
||||
from typing import Union
|
||||
|
||||
VERSION_TUPLE = Tuple[Union[int, str], ...]
|
||||
COMMIT_ID = Union[str, None]
|
||||
else:
|
||||
VERSION_TUPLE = object
|
||||
COMMIT_ID = object
|
||||
|
||||
version: str
|
||||
__version__: str
|
||||
__version_tuple__: VERSION_TUPLE
|
||||
version_tuple: VERSION_TUPLE
|
||||
commit_id: COMMIT_ID
|
||||
__commit_id__: COMMIT_ID
|
||||
|
||||
__version__ = version = '5.6.0'
|
||||
__version_tuple__ = version_tuple = (5, 6, 0)
|
||||
|
||||
__commit_id__ = commit_id = None
|
||||
@@ -0,0 +1 @@
|
||||
__version__: str
|
||||
@@ -0,0 +1,8 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from . import storage, strategies
|
||||
|
||||
__all__ = [
|
||||
"storage",
|
||||
"strategies",
|
||||
]
|
||||
@@ -0,0 +1,24 @@
|
||||
"""
|
||||
Implementations of storage backends to be used with
|
||||
:class:`limits.aio.strategies.RateLimiter` strategies
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from .base import MovingWindowSupport, SlidingWindowCounterSupport, Storage
|
||||
from .memcached import MemcachedStorage
|
||||
from .memory import MemoryStorage
|
||||
from .mongodb import MongoDBStorage
|
||||
from .redis import RedisClusterStorage, RedisSentinelStorage, RedisStorage
|
||||
|
||||
__all__ = [
|
||||
"MemcachedStorage",
|
||||
"MemoryStorage",
|
||||
"MongoDBStorage",
|
||||
"MovingWindowSupport",
|
||||
"RedisClusterStorage",
|
||||
"RedisSentinelStorage",
|
||||
"RedisStorage",
|
||||
"SlidingWindowCounterSupport",
|
||||
"Storage",
|
||||
]
|
||||
234
buffteks/lib/python3.11/site-packages/limits/aio/storage/base.py
Normal file
234
buffteks/lib/python3.11/site-packages/limits/aio/storage/base.py
Normal file
@@ -0,0 +1,234 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import functools
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from deprecated.sphinx import versionadded
|
||||
|
||||
from limits import errors
|
||||
from limits.storage.registry import StorageRegistry
|
||||
from limits.typing import (
|
||||
Any,
|
||||
Awaitable,
|
||||
Callable,
|
||||
P,
|
||||
R,
|
||||
cast,
|
||||
)
|
||||
from limits.util import LazyDependency
|
||||
|
||||
|
||||
def _wrap_errors(
|
||||
fn: Callable[P, Awaitable[R]],
|
||||
) -> Callable[P, Awaitable[R]]:
|
||||
@functools.wraps(fn)
|
||||
async def inner(*args: P.args, **kwargs: P.kwargs) -> R: # type: ignore[misc]
|
||||
instance = cast(Storage, args[0])
|
||||
try:
|
||||
return await fn(*args, **kwargs)
|
||||
except instance.base_exceptions as exc:
|
||||
if instance.wrap_exceptions:
|
||||
raise errors.StorageError(exc) from exc
|
||||
raise
|
||||
|
||||
return inner
|
||||
|
||||
|
||||
@versionadded(version="2.1")
|
||||
class Storage(LazyDependency, metaclass=StorageRegistry):
|
||||
"""
|
||||
Base class to extend when implementing an async storage backend.
|
||||
"""
|
||||
|
||||
STORAGE_SCHEME: list[str] | None
|
||||
"""The storage schemes to register against this implementation"""
|
||||
|
||||
def __init_subclass__(cls, **kwargs: Any) -> None: # type:ignore[explicit-any]
|
||||
super().__init_subclass__(**kwargs)
|
||||
for method in {
|
||||
"incr",
|
||||
"get",
|
||||
"get_expiry",
|
||||
"check",
|
||||
"reset",
|
||||
"clear",
|
||||
}:
|
||||
setattr(cls, method, _wrap_errors(getattr(cls, method)))
|
||||
super().__init_subclass__(**kwargs)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
uri: str | None = None,
|
||||
wrap_exceptions: bool = False,
|
||||
**options: float | str | bool,
|
||||
) -> None:
|
||||
"""
|
||||
:param wrap_exceptions: Whether to wrap storage exceptions in
|
||||
:exc:`limits.errors.StorageError` before raising it.
|
||||
"""
|
||||
super().__init__()
|
||||
self.wrap_exceptions = wrap_exceptions
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def base_exceptions(self) -> type[Exception] | tuple[type[Exception], ...]:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def incr(self, key: str, expiry: int, amount: int = 1) -> int:
|
||||
"""
|
||||
increments the counter for a given rate limit key
|
||||
|
||||
:param key: the key to increment
|
||||
:param expiry: amount in seconds for the key to expire in
|
||||
:param amount: the number to increment by
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def get(self, key: str) -> int:
|
||||
"""
|
||||
:param key: the key to get the counter value for
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def get_expiry(self, key: str) -> float:
|
||||
"""
|
||||
:param key: the key to get the expiry for
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def check(self) -> bool:
|
||||
"""
|
||||
check if storage is healthy
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def reset(self) -> int | None:
|
||||
"""
|
||||
reset storage to clear limits
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def clear(self, key: str) -> None:
|
||||
"""
|
||||
resets the rate limit key
|
||||
|
||||
:param key: the key to clear rate limits for
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class MovingWindowSupport(ABC):
|
||||
"""
|
||||
Abstract base class for async storages that support
|
||||
the :ref:`strategies:moving window` strategy
|
||||
"""
|
||||
|
||||
def __init_subclass__(cls, **kwargs: Any) -> None: # type: ignore[explicit-any]
|
||||
for method in {
|
||||
"acquire_entry",
|
||||
"get_moving_window",
|
||||
}:
|
||||
setattr(
|
||||
cls,
|
||||
method,
|
||||
_wrap_errors(getattr(cls, method)),
|
||||
)
|
||||
super().__init_subclass__(**kwargs)
|
||||
|
||||
@abstractmethod
|
||||
async def acquire_entry(
|
||||
self, key: str, limit: int, expiry: int, amount: int = 1
|
||||
) -> bool:
|
||||
"""
|
||||
:param key: rate limit key to acquire an entry in
|
||||
:param limit: amount of entries allowed
|
||||
:param expiry: expiry of the entry
|
||||
:param amount: the number of entries to acquire
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def get_moving_window(
|
||||
self, key: str, limit: int, expiry: int
|
||||
) -> tuple[float, int]:
|
||||
"""
|
||||
returns the starting point and the number of entries in the moving
|
||||
window
|
||||
|
||||
:param key: rate limit key
|
||||
:param expiry: expiry of entry
|
||||
:return: (start of window, number of acquired entries)
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class SlidingWindowCounterSupport(ABC):
|
||||
"""
|
||||
Abstract base class for async storages that support
|
||||
the :ref:`strategies:sliding window counter` strategy
|
||||
"""
|
||||
|
||||
def __init_subclass__(cls, **kwargs: Any) -> None: # type: ignore[explicit-any]
|
||||
for method in {
|
||||
"acquire_sliding_window_entry",
|
||||
"get_sliding_window",
|
||||
"clear_sliding_window",
|
||||
}:
|
||||
setattr(
|
||||
cls,
|
||||
method,
|
||||
_wrap_errors(getattr(cls, method)),
|
||||
)
|
||||
super().__init_subclass__(**kwargs)
|
||||
|
||||
@abstractmethod
|
||||
async def acquire_sliding_window_entry(
|
||||
self,
|
||||
key: str,
|
||||
limit: int,
|
||||
expiry: int,
|
||||
amount: int = 1,
|
||||
) -> bool:
|
||||
"""
|
||||
Acquire an entry if the weighted count of the current and previous
|
||||
windows is less than or equal to the limit
|
||||
|
||||
:param key: rate limit key to acquire an entry in
|
||||
:param limit: amount of entries allowed
|
||||
:param expiry: expiry of the entry
|
||||
:param amount: the number of entries to acquire
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def get_sliding_window(
|
||||
self, key: str, expiry: int
|
||||
) -> tuple[int, float, int, float]:
|
||||
"""
|
||||
Return the previous and current window information.
|
||||
|
||||
:param key: the rate limit key
|
||||
:param expiry: the rate limit expiry, needed to compute the key in some implementations
|
||||
:return: a tuple of (int, float, int, float) with the following information:
|
||||
- previous window counter
|
||||
- previous window TTL
|
||||
- current window counter
|
||||
- current window TTL
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def clear_sliding_window(self, key: str, expiry: int) -> None:
|
||||
"""
|
||||
Resets the rate limit key(s) for the sliding window
|
||||
|
||||
:param key: the key to clear rate limits for
|
||||
:param expiry: the rate limit expiry, needed to compute the key in some implemenations
|
||||
"""
|
||||
...
|
||||
@@ -0,0 +1,190 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from math import floor
|
||||
|
||||
from deprecated.sphinx import versionadded, versionchanged
|
||||
from packaging.version import Version
|
||||
|
||||
from limits.aio.storage import SlidingWindowCounterSupport, Storage
|
||||
from limits.aio.storage.memcached.bridge import MemcachedBridge
|
||||
from limits.aio.storage.memcached.emcache import EmcacheBridge
|
||||
from limits.aio.storage.memcached.memcachio import MemcachioBridge
|
||||
from limits.storage.base import TimestampedSlidingWindow
|
||||
from limits.typing import Literal
|
||||
|
||||
|
||||
@versionadded(version="2.1")
|
||||
@versionchanged(
|
||||
version="5.0",
|
||||
reason="Switched default implementation to :pypi:`memcachio`",
|
||||
)
|
||||
class MemcachedStorage(Storage, SlidingWindowCounterSupport, TimestampedSlidingWindow):
|
||||
"""
|
||||
Rate limit storage with memcached as backend.
|
||||
|
||||
Depends on :pypi:`memcachio`
|
||||
"""
|
||||
|
||||
STORAGE_SCHEME = ["async+memcached"]
|
||||
"""The storage scheme for memcached to be used in an async context"""
|
||||
|
||||
DEPENDENCIES = {
|
||||
"memcachio": Version("0.3"),
|
||||
"emcache": Version("0.0"),
|
||||
}
|
||||
|
||||
bridge: MemcachedBridge
|
||||
storage_exceptions: tuple[Exception, ...]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
uri: str,
|
||||
wrap_exceptions: bool = False,
|
||||
implementation: Literal["memcachio", "emcache"] = "memcachio",
|
||||
**options: float | str | bool,
|
||||
) -> None:
|
||||
"""
|
||||
:param uri: memcached location of the form
|
||||
``async+memcached://host:port,host:port``
|
||||
:param wrap_exceptions: Whether to wrap storage exceptions in
|
||||
:exc:`limits.errors.StorageError` before raising it.
|
||||
:param implementation: Whether to use the client implementation from
|
||||
|
||||
- ``memcachio``: :class:`memcachio.Client`
|
||||
- ``emcache``: :class:`emcache.Client`
|
||||
:param options: all remaining keyword arguments are passed
|
||||
directly to the constructor of :class:`memcachio.Client`
|
||||
:raise ConfigurationError: when :pypi:`memcachio` is not available
|
||||
"""
|
||||
if implementation == "emcache":
|
||||
self.bridge = EmcacheBridge(
|
||||
uri, self.dependencies["emcache"].module, **options
|
||||
)
|
||||
else:
|
||||
self.bridge = MemcachioBridge(
|
||||
uri, self.dependencies["memcachio"].module, **options
|
||||
)
|
||||
super().__init__(uri, wrap_exceptions=wrap_exceptions, **options)
|
||||
|
||||
@property
|
||||
def base_exceptions(
|
||||
self,
|
||||
) -> type[Exception] | tuple[type[Exception], ...]: # pragma: no cover
|
||||
return self.bridge.base_exceptions
|
||||
|
||||
async def get(self, key: str) -> int:
|
||||
"""
|
||||
:param key: the key to get the counter value for
|
||||
"""
|
||||
return await self.bridge.get(key)
|
||||
|
||||
async def clear(self, key: str) -> None:
|
||||
"""
|
||||
:param key: the key to clear rate limits for
|
||||
"""
|
||||
await self.bridge.clear(key)
|
||||
|
||||
async def incr(
|
||||
self,
|
||||
key: str,
|
||||
expiry: float,
|
||||
amount: int = 1,
|
||||
set_expiration_key: bool = True,
|
||||
) -> int:
|
||||
"""
|
||||
increments the counter for a given rate limit key
|
||||
|
||||
:param key: the key to increment
|
||||
:param expiry: amount in seconds for the key to expire in
|
||||
window every hit.
|
||||
:param amount: the number to increment by
|
||||
:param set_expiration_key: if set to False, the expiration time won't be stored but the key will still expire
|
||||
"""
|
||||
return await self.bridge.incr(
|
||||
key, expiry, amount, set_expiration_key=set_expiration_key
|
||||
)
|
||||
|
||||
async def get_expiry(self, key: str) -> float:
|
||||
"""
|
||||
:param key: the key to get the expiry for
|
||||
"""
|
||||
return await self.bridge.get_expiry(key)
|
||||
|
||||
async def reset(self) -> int | None:
|
||||
raise NotImplementedError
|
||||
|
||||
async def check(self) -> bool:
|
||||
return await self.bridge.check()
|
||||
|
||||
async def acquire_sliding_window_entry(
|
||||
self,
|
||||
key: str,
|
||||
limit: int,
|
||||
expiry: int,
|
||||
amount: int = 1,
|
||||
) -> bool:
|
||||
if amount > limit:
|
||||
return False
|
||||
now = time.time()
|
||||
previous_key, current_key = self.sliding_window_keys(key, expiry, now)
|
||||
(
|
||||
previous_count,
|
||||
previous_ttl,
|
||||
current_count,
|
||||
_,
|
||||
) = await self._get_sliding_window_info(previous_key, current_key, expiry, now)
|
||||
t0 = time.time()
|
||||
weighted_count = previous_count * previous_ttl / expiry + current_count
|
||||
if floor(weighted_count) + amount > limit:
|
||||
return False
|
||||
else:
|
||||
# Hit, increase the current counter.
|
||||
# If the counter doesn't exist yet, set twice the theorical expiry.
|
||||
# We don't need the expiration key as it is estimated with the timestamps directly.
|
||||
current_count = await self.incr(
|
||||
current_key, 2 * expiry, amount=amount, set_expiration_key=False
|
||||
)
|
||||
t1 = time.time()
|
||||
actualised_previous_ttl = max(0, previous_ttl - (t1 - t0))
|
||||
weighted_count = (
|
||||
previous_count * actualised_previous_ttl / expiry + current_count
|
||||
)
|
||||
if floor(weighted_count) > limit:
|
||||
# Another hit won the race condition: revert the increment and refuse this hit
|
||||
# Limitation: during high concurrency at the end of the window,
|
||||
# the counter is shifted and cannot be decremented, so less requests than expected are allowed.
|
||||
await self.bridge.decr(current_key, amount, noreply=True)
|
||||
return False
|
||||
return True
|
||||
|
||||
async def get_sliding_window(
|
||||
self, key: str, expiry: int
|
||||
) -> tuple[int, float, int, float]:
|
||||
now = time.time()
|
||||
previous_key, current_key = self.sliding_window_keys(key, expiry, now)
|
||||
return await self._get_sliding_window_info(
|
||||
previous_key, current_key, expiry, now
|
||||
)
|
||||
|
||||
async def clear_sliding_window(self, key: str, expiry: int) -> None:
|
||||
now = time.time()
|
||||
previous_key, current_key = self.sliding_window_keys(key, expiry, now)
|
||||
await asyncio.gather(self.clear(previous_key), self.clear(current_key))
|
||||
|
||||
async def _get_sliding_window_info(
|
||||
self, previous_key: str, current_key: str, expiry: int, now: float
|
||||
) -> tuple[int, float, int, float]:
|
||||
result = await self.bridge.get_many([previous_key, current_key])
|
||||
|
||||
previous_count = result.get(previous_key.encode("utf-8"), 0)
|
||||
current_count = result.get(current_key.encode("utf-8"), 0)
|
||||
|
||||
if previous_count == 0:
|
||||
previous_ttl = float(0)
|
||||
else:
|
||||
previous_ttl = (1 - (((now - expiry) / expiry) % 1)) * expiry
|
||||
current_ttl = (1 - ((now / expiry) % 1)) * expiry + expiry
|
||||
|
||||
return previous_count, previous_ttl, current_count, current_ttl
|
||||
@@ -0,0 +1,73 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import urllib
|
||||
from abc import ABC, abstractmethod
|
||||
from types import ModuleType
|
||||
|
||||
from limits.typing import Iterable
|
||||
|
||||
|
||||
class MemcachedBridge(ABC):
|
||||
def __init__(
|
||||
self,
|
||||
uri: str,
|
||||
dependency: ModuleType,
|
||||
**options: float | str | bool,
|
||||
) -> None:
|
||||
self.uri = uri
|
||||
self.parsed_uri = urllib.parse.urlparse(self.uri)
|
||||
self.dependency = dependency
|
||||
self.hosts = []
|
||||
self.options = options
|
||||
|
||||
sep = self.parsed_uri.netloc.strip().find("@") + 1
|
||||
for loc in self.parsed_uri.netloc.strip()[sep:].split(","):
|
||||
host, port = loc.split(":")
|
||||
self.hosts.append((host, int(port)))
|
||||
|
||||
if self.parsed_uri.username:
|
||||
self.options["username"] = self.parsed_uri.username
|
||||
if self.parsed_uri.password:
|
||||
self.options["password"] = self.parsed_uri.password
|
||||
|
||||
def _expiration_key(self, key: str) -> str:
|
||||
"""
|
||||
Return the expiration key for the given counter key.
|
||||
|
||||
Memcached doesn't natively return the expiration time or TTL for a given key,
|
||||
so we implement the expiration time on a separate key.
|
||||
"""
|
||||
return key + "/expires"
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def base_exceptions(
|
||||
self,
|
||||
) -> type[Exception] | tuple[type[Exception], ...]: ...
|
||||
|
||||
@abstractmethod
|
||||
async def get(self, key: str) -> int: ...
|
||||
|
||||
@abstractmethod
|
||||
async def get_many(self, keys: Iterable[str]) -> dict[bytes, int]: ...
|
||||
|
||||
@abstractmethod
|
||||
async def clear(self, key: str) -> None: ...
|
||||
|
||||
@abstractmethod
|
||||
async def decr(self, key: str, amount: int = 1, noreply: bool = False) -> int: ...
|
||||
|
||||
@abstractmethod
|
||||
async def incr(
|
||||
self,
|
||||
key: str,
|
||||
expiry: float,
|
||||
amount: int = 1,
|
||||
set_expiration_key: bool = True,
|
||||
) -> int: ...
|
||||
|
||||
@abstractmethod
|
||||
async def get_expiry(self, key: str) -> float: ...
|
||||
|
||||
@abstractmethod
|
||||
async def check(self) -> bool: ...
|
||||
@@ -0,0 +1,112 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from math import ceil
|
||||
from types import ModuleType
|
||||
|
||||
from limits.typing import TYPE_CHECKING, Iterable
|
||||
|
||||
from .bridge import MemcachedBridge
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import emcache
|
||||
|
||||
|
||||
class EmcacheBridge(MemcachedBridge):
|
||||
def __init__(
|
||||
self,
|
||||
uri: str,
|
||||
dependency: ModuleType,
|
||||
**options: float | str | bool,
|
||||
) -> None:
|
||||
super().__init__(uri, dependency, **options)
|
||||
self._storage = None
|
||||
|
||||
async def get_storage(self) -> emcache.Client:
|
||||
if not self._storage:
|
||||
self._storage = await self.dependency.create_client(
|
||||
[self.dependency.MemcachedHostAddress(h, p) for h, p in self.hosts],
|
||||
**self.options,
|
||||
)
|
||||
assert self._storage
|
||||
return self._storage
|
||||
|
||||
async def get(self, key: str) -> int:
|
||||
item = await (await self.get_storage()).get(key.encode("utf-8"))
|
||||
return item and int(item.value) or 0
|
||||
|
||||
async def get_many(self, keys: Iterable[str]) -> dict[bytes, int]:
|
||||
results = await (await self.get_storage()).get_many(
|
||||
[k.encode("utf-8") for k in keys]
|
||||
)
|
||||
return {k: int(item.value) if item else 0 for k, item in results.items()}
|
||||
|
||||
async def clear(self, key: str) -> None:
|
||||
try:
|
||||
await (await self.get_storage()).delete(key.encode("utf-8"))
|
||||
except self.dependency.NotFoundCommandError:
|
||||
pass
|
||||
|
||||
async def decr(self, key: str, amount: int = 1, noreply: bool = False) -> int:
|
||||
storage = await self.get_storage()
|
||||
limit_key = key.encode("utf-8")
|
||||
try:
|
||||
value = await storage.decrement(limit_key, amount, noreply=noreply) or 0
|
||||
except self.dependency.NotFoundCommandError:
|
||||
value = 0
|
||||
return value
|
||||
|
||||
async def incr(
|
||||
self, key: str, expiry: float, amount: int = 1, set_expiration_key: bool = True
|
||||
) -> int:
|
||||
storage = await self.get_storage()
|
||||
limit_key = key.encode("utf-8")
|
||||
expire_key = self._expiration_key(key).encode()
|
||||
try:
|
||||
return await storage.increment(limit_key, amount) or amount
|
||||
except self.dependency.NotFoundCommandError:
|
||||
storage = await self.get_storage()
|
||||
try:
|
||||
await storage.add(limit_key, f"{amount}".encode(), exptime=ceil(expiry))
|
||||
if set_expiration_key:
|
||||
await storage.set(
|
||||
expire_key,
|
||||
str(expiry + time.time()).encode("utf-8"),
|
||||
exptime=ceil(expiry),
|
||||
noreply=False,
|
||||
)
|
||||
value = amount
|
||||
except self.dependency.NotStoredStorageCommandError:
|
||||
# Coult not add the key, probably because a concurrent call has added it
|
||||
storage = await self.get_storage()
|
||||
value = await storage.increment(limit_key, amount) or amount
|
||||
return value
|
||||
|
||||
async def get_expiry(self, key: str) -> float:
|
||||
storage = await self.get_storage()
|
||||
item = await storage.get(self._expiration_key(key).encode("utf-8"))
|
||||
|
||||
return item and float(item.value) or time.time()
|
||||
pass
|
||||
|
||||
@property
|
||||
def base_exceptions(
|
||||
self,
|
||||
) -> type[Exception] | tuple[type[Exception], ...]: # pragma: no cover
|
||||
return (
|
||||
self.dependency.ClusterNoAvailableNodes,
|
||||
self.dependency.CommandError,
|
||||
)
|
||||
|
||||
async def check(self) -> bool:
|
||||
"""
|
||||
Check if storage is healthy by calling the ``get`` command
|
||||
on the key ``limiter-check``
|
||||
"""
|
||||
try:
|
||||
storage = await self.get_storage()
|
||||
await storage.get(b"limiter-check")
|
||||
|
||||
return True
|
||||
except: # noqa
|
||||
return False
|
||||
@@ -0,0 +1,104 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from math import ceil
|
||||
from types import ModuleType
|
||||
from typing import TYPE_CHECKING, Iterable
|
||||
|
||||
from .bridge import MemcachedBridge
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import memcachio
|
||||
|
||||
|
||||
class MemcachioBridge(MemcachedBridge):
|
||||
def __init__(
|
||||
self,
|
||||
uri: str,
|
||||
dependency: ModuleType,
|
||||
**options: float | str | bool,
|
||||
) -> None:
|
||||
super().__init__(uri, dependency, **options)
|
||||
self._storage: memcachio.Client[bytes] | None = None
|
||||
|
||||
@property
|
||||
def base_exceptions(
|
||||
self,
|
||||
) -> type[Exception] | tuple[type[Exception], ...]:
|
||||
return (
|
||||
self.dependency.errors.NoAvailableNodes,
|
||||
self.dependency.errors.MemcachioConnectionError,
|
||||
)
|
||||
|
||||
async def get_storage(self) -> memcachio.Client[bytes]:
|
||||
if not self._storage:
|
||||
self._storage = self.dependency.Client(
|
||||
[(h, p) for h, p in self.hosts],
|
||||
**self.options,
|
||||
)
|
||||
assert self._storage
|
||||
return self._storage
|
||||
|
||||
async def get(self, key: str) -> int:
|
||||
return (await self.get_many([key])).get(key.encode("utf-8"), 0)
|
||||
|
||||
async def get_many(self, keys: Iterable[str]) -> dict[bytes, int]:
|
||||
"""
|
||||
Return multiple counters at once
|
||||
|
||||
:param keys: the keys to get the counter values for
|
||||
"""
|
||||
results = await (await self.get_storage()).get(
|
||||
*[k.encode("utf-8") for k in keys]
|
||||
)
|
||||
return {k: int(v.value) for k, v in results.items()}
|
||||
|
||||
async def clear(self, key: str) -> None:
|
||||
await (await self.get_storage()).delete(key.encode("utf-8"))
|
||||
|
||||
async def decr(self, key: str, amount: int = 1, noreply: bool = False) -> int:
|
||||
storage = await self.get_storage()
|
||||
limit_key = key.encode("utf-8")
|
||||
return await storage.decr(limit_key, amount, noreply=noreply) or 0
|
||||
|
||||
async def incr(
|
||||
self, key: str, expiry: float, amount: int = 1, set_expiration_key: bool = True
|
||||
) -> int:
|
||||
storage = await self.get_storage()
|
||||
limit_key = key.encode("utf-8")
|
||||
expire_key = self._expiration_key(key).encode()
|
||||
if (value := (await storage.incr(limit_key, amount))) is None:
|
||||
storage = await self.get_storage()
|
||||
if await storage.add(limit_key, f"{amount}".encode(), expiry=ceil(expiry)):
|
||||
if set_expiration_key:
|
||||
await storage.set(
|
||||
expire_key,
|
||||
str(expiry + time.time()).encode("utf-8"),
|
||||
expiry=ceil(expiry),
|
||||
noreply=False,
|
||||
)
|
||||
return amount
|
||||
else:
|
||||
storage = await self.get_storage()
|
||||
return await storage.incr(limit_key, amount) or amount
|
||||
return value
|
||||
|
||||
async def get_expiry(self, key: str) -> float:
|
||||
storage = await self.get_storage()
|
||||
expiration_key = self._expiration_key(key).encode("utf-8")
|
||||
item = (await storage.get(expiration_key)).get(expiration_key, None)
|
||||
|
||||
return item and float(item.value) or time.time()
|
||||
|
||||
async def check(self) -> bool:
|
||||
"""
|
||||
Check if storage is healthy by calling the ``get`` command
|
||||
on the key ``limiter-check``
|
||||
"""
|
||||
try:
|
||||
storage = await self.get_storage()
|
||||
await storage.get(b"limiter-check")
|
||||
|
||||
return True
|
||||
except: # noqa
|
||||
return False
|
||||
@@ -0,0 +1,287 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import bisect
|
||||
import time
|
||||
from collections import Counter, defaultdict
|
||||
from math import floor
|
||||
|
||||
from deprecated.sphinx import versionadded
|
||||
|
||||
import limits.typing
|
||||
from limits.aio.storage.base import (
|
||||
MovingWindowSupport,
|
||||
SlidingWindowCounterSupport,
|
||||
Storage,
|
||||
)
|
||||
from limits.storage.base import TimestampedSlidingWindow
|
||||
|
||||
|
||||
class Entry:
|
||||
def __init__(self, expiry: int) -> None:
|
||||
self.atime = time.time()
|
||||
self.expiry = self.atime + expiry
|
||||
|
||||
|
||||
@versionadded(version="2.1")
|
||||
class MemoryStorage(
|
||||
Storage, MovingWindowSupport, SlidingWindowCounterSupport, TimestampedSlidingWindow
|
||||
):
|
||||
"""
|
||||
rate limit storage using :class:`collections.Counter`
|
||||
as an in memory storage for fixed & sliding window strategies,
|
||||
and a simple list to implement moving window strategy.
|
||||
"""
|
||||
|
||||
STORAGE_SCHEME = ["async+memory"]
|
||||
"""
|
||||
The storage scheme for in process memory storage for use in an
|
||||
async context
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, uri: str | None = None, wrap_exceptions: bool = False, **_: str
|
||||
) -> None:
|
||||
self.storage: limits.typing.Counter[str] = Counter()
|
||||
self.locks: defaultdict[str, asyncio.Lock] = defaultdict(asyncio.Lock)
|
||||
self.expirations: dict[str, float] = {}
|
||||
self.events: dict[str, list[Entry]] = {}
|
||||
self.timer: asyncio.Task[None] | None = None
|
||||
super().__init__(uri, wrap_exceptions=wrap_exceptions, **_)
|
||||
|
||||
def __getstate__(self) -> dict[str, limits.typing.Any]: # type: ignore[explicit-any]
|
||||
state = self.__dict__.copy()
|
||||
del state["timer"]
|
||||
del state["locks"]
|
||||
return state
|
||||
|
||||
def __setstate__(self, state: dict[str, limits.typing.Any]) -> None: # type: ignore[explicit-any]
|
||||
self.__dict__.update(state)
|
||||
self.timer = None
|
||||
self.locks = defaultdict(asyncio.Lock)
|
||||
asyncio.ensure_future(self.__schedule_expiry())
|
||||
|
||||
async def __expire_events(self) -> None:
|
||||
try:
|
||||
now = time.time()
|
||||
for key in list(self.events.keys()):
|
||||
async with self.locks[key]:
|
||||
cutoff = await asyncio.to_thread(
|
||||
lambda evts: bisect.bisect_left(
|
||||
evts, -now, key=lambda event: -event.expiry
|
||||
),
|
||||
self.events[key],
|
||||
)
|
||||
if self.events.get(key, []):
|
||||
self.events[key] = self.events[key][:cutoff]
|
||||
if not self.events.get(key, None):
|
||||
self.events.pop(key, None)
|
||||
self.locks.pop(key, None)
|
||||
|
||||
for key in list(self.expirations.keys()):
|
||||
if self.expirations[key] <= time.time():
|
||||
self.storage.pop(key, None)
|
||||
self.expirations.pop(key, None)
|
||||
self.locks.pop(key, None)
|
||||
except asyncio.CancelledError:
|
||||
return
|
||||
|
||||
async def __schedule_expiry(self) -> None:
|
||||
if not self.timer or self.timer.done():
|
||||
self.timer = asyncio.create_task(self.__expire_events())
|
||||
|
||||
@property
|
||||
def base_exceptions(
|
||||
self,
|
||||
) -> type[Exception] | tuple[type[Exception], ...]: # pragma: no cover
|
||||
return ValueError
|
||||
|
||||
async def incr(self, key: str, expiry: float, amount: int = 1) -> int:
|
||||
"""
|
||||
increments the counter for a given rate limit key
|
||||
|
||||
:param key: the key to increment
|
||||
:param expiry: amount in seconds for the key to expire in
|
||||
:param amount: the number to increment by
|
||||
"""
|
||||
await self.get(key)
|
||||
await self.__schedule_expiry()
|
||||
async with self.locks[key]:
|
||||
self.storage[key] += amount
|
||||
if self.storage[key] == amount:
|
||||
self.expirations[key] = time.time() + expiry
|
||||
return self.storage.get(key, amount)
|
||||
|
||||
async def decr(self, key: str, amount: int = 1) -> int:
|
||||
"""
|
||||
decrements the counter for a given rate limit key. 0 is the minimum allowed value.
|
||||
|
||||
:param amount: the number to increment by
|
||||
"""
|
||||
await self.get(key)
|
||||
await self.__schedule_expiry()
|
||||
async with self.locks[key]:
|
||||
self.storage[key] = max(self.storage[key] - amount, 0)
|
||||
|
||||
return self.storage.get(key, amount)
|
||||
|
||||
async def get(self, key: str) -> int:
|
||||
"""
|
||||
:param key: the key to get the counter value for
|
||||
"""
|
||||
if self.expirations.get(key, 0) <= time.time():
|
||||
self.storage.pop(key, None)
|
||||
self.expirations.pop(key, None)
|
||||
self.locks.pop(key, None)
|
||||
|
||||
return self.storage.get(key, 0)
|
||||
|
||||
async def clear(self, key: str) -> None:
|
||||
"""
|
||||
:param key: the key to clear rate limits for
|
||||
"""
|
||||
self.storage.pop(key, None)
|
||||
self.expirations.pop(key, None)
|
||||
self.events.pop(key, None)
|
||||
self.locks.pop(key, None)
|
||||
|
||||
async def acquire_entry(
|
||||
self, key: str, limit: int, expiry: int, amount: int = 1
|
||||
) -> bool:
|
||||
"""
|
||||
:param key: rate limit key to acquire an entry in
|
||||
:param limit: amount of entries allowed
|
||||
:param expiry: expiry of the entry
|
||||
:param amount: the number of entries to acquire
|
||||
"""
|
||||
if amount > limit:
|
||||
return False
|
||||
|
||||
await self.__schedule_expiry()
|
||||
async with self.locks[key]:
|
||||
self.events.setdefault(key, [])
|
||||
timestamp = time.time()
|
||||
try:
|
||||
entry: Entry | None = self.events[key][limit - amount]
|
||||
except IndexError:
|
||||
entry = None
|
||||
|
||||
if entry and entry.atime >= timestamp - expiry:
|
||||
return False
|
||||
else:
|
||||
self.events[key][:0] = [Entry(expiry)] * amount
|
||||
return True
|
||||
|
||||
async def get_expiry(self, key: str) -> float:
|
||||
"""
|
||||
:param key: the key to get the expiry for
|
||||
"""
|
||||
|
||||
return self.expirations.get(key, time.time())
|
||||
|
||||
async def get_moving_window(
|
||||
self, key: str, limit: int, expiry: int
|
||||
) -> tuple[float, int]:
|
||||
"""
|
||||
returns the starting point and the number of entries in the moving
|
||||
window
|
||||
|
||||
:param key: rate limit key
|
||||
:param expiry: expiry of entry
|
||||
:return: (start of window, number of acquired entries)
|
||||
"""
|
||||
|
||||
timestamp = time.time()
|
||||
if events := self.events.get(key, []):
|
||||
oldest = bisect.bisect_left(
|
||||
events, -(timestamp - expiry), key=lambda entry: -entry.atime
|
||||
)
|
||||
return events[oldest - 1].atime, oldest
|
||||
return timestamp, 0
|
||||
|
||||
async def acquire_sliding_window_entry(
|
||||
self,
|
||||
key: str,
|
||||
limit: int,
|
||||
expiry: int,
|
||||
amount: int = 1,
|
||||
) -> bool:
|
||||
if amount > limit:
|
||||
return False
|
||||
now = time.time()
|
||||
previous_key, current_key = self.sliding_window_keys(key, expiry, now)
|
||||
(
|
||||
previous_count,
|
||||
previous_ttl,
|
||||
current_count,
|
||||
_,
|
||||
) = await self._get_sliding_window_info(previous_key, current_key, expiry, now)
|
||||
weighted_count = previous_count * previous_ttl / expiry + current_count
|
||||
if floor(weighted_count) + amount > limit:
|
||||
return False
|
||||
else:
|
||||
# Hit, increase the current counter.
|
||||
# If the counter doesn't exist yet, set twice the theorical expiry.
|
||||
current_count = await self.incr(current_key, 2 * expiry, amount=amount)
|
||||
weighted_count = previous_count * previous_ttl / expiry + current_count
|
||||
if floor(weighted_count) > limit:
|
||||
# Another hit won the race condition: revert the incrementation and refuse this hit
|
||||
# Limitation: during high concurrency at the end of the window,
|
||||
# the counter is shifted and cannot be decremented, so less requests than expected are allowed.
|
||||
await self.decr(current_key, amount)
|
||||
return False
|
||||
return True
|
||||
|
||||
async def get_sliding_window(
|
||||
self, key: str, expiry: int
|
||||
) -> tuple[int, float, int, float]:
|
||||
now = time.time()
|
||||
previous_key, current_key = self.sliding_window_keys(key, expiry, now)
|
||||
return await self._get_sliding_window_info(
|
||||
previous_key, current_key, expiry, now
|
||||
)
|
||||
|
||||
async def clear_sliding_window(self, key: str, expiry: int) -> None:
|
||||
now = time.time()
|
||||
previous_key, current_key = self.sliding_window_keys(key, expiry, now)
|
||||
await self.clear(current_key)
|
||||
await self.clear(previous_key)
|
||||
|
||||
async def _get_sliding_window_info(
|
||||
self,
|
||||
previous_key: str,
|
||||
current_key: str,
|
||||
expiry: int,
|
||||
now: float,
|
||||
) -> tuple[int, float, int, float]:
|
||||
previous_count = await self.get(previous_key)
|
||||
current_count = await self.get(current_key)
|
||||
if previous_count == 0:
|
||||
previous_ttl = float(0)
|
||||
else:
|
||||
previous_ttl = (1 - (((now - expiry) / expiry) % 1)) * expiry
|
||||
current_ttl = (1 - ((now / expiry) % 1)) * expiry + expiry
|
||||
return previous_count, previous_ttl, current_count, current_ttl
|
||||
|
||||
async def check(self) -> bool:
|
||||
"""
|
||||
check if storage is healthy
|
||||
"""
|
||||
|
||||
return True
|
||||
|
||||
async def reset(self) -> int | None:
|
||||
num_items = max(len(self.storage), len(self.events))
|
||||
self.storage.clear()
|
||||
self.expirations.clear()
|
||||
self.events.clear()
|
||||
self.locks.clear()
|
||||
|
||||
return num_items
|
||||
|
||||
def __del__(self) -> None:
|
||||
try:
|
||||
if self.timer and not self.timer.done():
|
||||
self.timer.cancel()
|
||||
except RuntimeError: # noqa
|
||||
pass
|
||||
@@ -0,0 +1,520 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import datetime
|
||||
import time
|
||||
|
||||
from deprecated.sphinx import versionadded, versionchanged
|
||||
|
||||
from limits.aio.storage.base import (
|
||||
MovingWindowSupport,
|
||||
SlidingWindowCounterSupport,
|
||||
Storage,
|
||||
)
|
||||
from limits.typing import (
|
||||
ParamSpec,
|
||||
TypeVar,
|
||||
cast,
|
||||
)
|
||||
from limits.util import get_dependency
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
|
||||
|
||||
@versionadded(version="2.1")
|
||||
@versionchanged(
|
||||
version="3.14.0",
|
||||
reason="Added option to select custom collection names for windows & counters",
|
||||
)
|
||||
class MongoDBStorage(Storage, MovingWindowSupport, SlidingWindowCounterSupport):
|
||||
"""
|
||||
Rate limit storage with MongoDB as backend.
|
||||
|
||||
Depends on :pypi:`motor`
|
||||
"""
|
||||
|
||||
STORAGE_SCHEME = ["async+mongodb", "async+mongodb+srv"]
|
||||
"""
|
||||
The storage scheme for MongoDB for use in an async context
|
||||
"""
|
||||
|
||||
DEPENDENCIES = ["motor.motor_asyncio", "pymongo"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
uri: str,
|
||||
database_name: str = "limits",
|
||||
counter_collection_name: str = "counters",
|
||||
window_collection_name: str = "windows",
|
||||
wrap_exceptions: bool = False,
|
||||
**options: float | str | bool,
|
||||
) -> None:
|
||||
"""
|
||||
:param uri: uri of the form ``async+mongodb://[user:password]@host:port?...``,
|
||||
This uri is passed directly to :class:`~motor.motor_asyncio.AsyncIOMotorClient`
|
||||
:param database_name: The database to use for storing the rate limit
|
||||
collections.
|
||||
:param counter_collection_name: The collection name to use for individual counters
|
||||
used in fixed window strategies
|
||||
:param window_collection_name: The collection name to use for sliding & moving window
|
||||
storage
|
||||
:param wrap_exceptions: Whether to wrap storage exceptions in
|
||||
:exc:`limits.errors.StorageError` before raising it.
|
||||
:param options: all remaining keyword arguments are passed
|
||||
to the constructor of :class:`~motor.motor_asyncio.AsyncIOMotorClient`
|
||||
:raise ConfigurationError: when the :pypi:`motor` or :pypi:`pymongo` are
|
||||
not available
|
||||
"""
|
||||
|
||||
uri = uri.replace("async+mongodb", "mongodb", 1)
|
||||
|
||||
super().__init__(uri, wrap_exceptions=wrap_exceptions, **options)
|
||||
|
||||
self.dependency = self.dependencies["motor.motor_asyncio"]
|
||||
self.proxy_dependency = self.dependencies["pymongo"]
|
||||
self.lib_errors, _ = get_dependency("pymongo.errors")
|
||||
|
||||
self.storage = self.dependency.module.AsyncIOMotorClient(uri, **options)
|
||||
# TODO: Fix this hack. It was noticed when running a benchmark
|
||||
# with FastAPI - however - doesn't appear in unit tests or in an isolated
|
||||
# use. Reference: https://jira.mongodb.org/browse/MOTOR-822
|
||||
self.storage.get_io_loop = asyncio.get_running_loop
|
||||
|
||||
self.__database_name = database_name
|
||||
self.__collection_mapping = {
|
||||
"counters": counter_collection_name,
|
||||
"windows": window_collection_name,
|
||||
}
|
||||
self.__indices_created = False
|
||||
|
||||
@property
|
||||
def base_exceptions(
|
||||
self,
|
||||
) -> type[Exception] | tuple[type[Exception], ...]: # pragma: no cover
|
||||
return self.lib_errors.PyMongoError # type: ignore
|
||||
|
||||
@property
|
||||
def database(self): # type: ignore
|
||||
return self.storage.get_database(self.__database_name)
|
||||
|
||||
async def create_indices(self) -> None:
|
||||
if not self.__indices_created:
|
||||
await asyncio.gather(
|
||||
self.database[self.__collection_mapping["counters"]].create_index(
|
||||
"expireAt", expireAfterSeconds=0
|
||||
),
|
||||
self.database[self.__collection_mapping["windows"]].create_index(
|
||||
"expireAt", expireAfterSeconds=0
|
||||
),
|
||||
)
|
||||
self.__indices_created = True
|
||||
|
||||
async def reset(self) -> int | None:
|
||||
"""
|
||||
Delete all rate limit keys in the rate limit collections (counters, windows)
|
||||
"""
|
||||
num_keys = sum(
|
||||
await asyncio.gather(
|
||||
self.database[self.__collection_mapping["counters"]].count_documents(
|
||||
{}
|
||||
),
|
||||
self.database[self.__collection_mapping["windows"]].count_documents({}),
|
||||
)
|
||||
)
|
||||
await asyncio.gather(
|
||||
self.database[self.__collection_mapping["counters"]].drop(),
|
||||
self.database[self.__collection_mapping["windows"]].drop(),
|
||||
)
|
||||
|
||||
return cast(int, num_keys)
|
||||
|
||||
async def clear(self, key: str) -> None:
|
||||
"""
|
||||
:param key: the key to clear rate limits for
|
||||
"""
|
||||
await asyncio.gather(
|
||||
self.database[self.__collection_mapping["counters"]].find_one_and_delete(
|
||||
{"_id": key}
|
||||
),
|
||||
self.database[self.__collection_mapping["windows"]].find_one_and_delete(
|
||||
{"_id": key}
|
||||
),
|
||||
)
|
||||
|
||||
async def get_expiry(self, key: str) -> float:
|
||||
"""
|
||||
:param key: the key to get the expiry for
|
||||
"""
|
||||
counter = await self.database[self.__collection_mapping["counters"]].find_one(
|
||||
{"_id": key}
|
||||
)
|
||||
return (
|
||||
(counter["expireAt"] if counter else datetime.datetime.now())
|
||||
.replace(tzinfo=datetime.timezone.utc)
|
||||
.timestamp()
|
||||
)
|
||||
|
||||
async def get(self, key: str) -> int:
|
||||
"""
|
||||
:param key: the key to get the counter value for
|
||||
"""
|
||||
counter = await self.database[self.__collection_mapping["counters"]].find_one(
|
||||
{
|
||||
"_id": key,
|
||||
"expireAt": {"$gte": datetime.datetime.now(datetime.timezone.utc)},
|
||||
},
|
||||
projection=["count"],
|
||||
)
|
||||
|
||||
return counter and counter["count"] or 0
|
||||
|
||||
async def incr(self, key: str, expiry: int, amount: int = 1) -> int:
|
||||
"""
|
||||
increments the counter for a given rate limit key
|
||||
|
||||
:param key: the key to increment
|
||||
:param expiry: amount in seconds for the key to expire in
|
||||
:param amount: the number to increment by
|
||||
"""
|
||||
await self.create_indices()
|
||||
|
||||
expiration = datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(
|
||||
seconds=expiry
|
||||
)
|
||||
|
||||
response = await self.database[
|
||||
self.__collection_mapping["counters"]
|
||||
].find_one_and_update(
|
||||
{"_id": key},
|
||||
[
|
||||
{
|
||||
"$set": {
|
||||
"count": {
|
||||
"$cond": {
|
||||
"if": {"$lt": ["$expireAt", "$$NOW"]},
|
||||
"then": amount,
|
||||
"else": {"$add": ["$count", amount]},
|
||||
}
|
||||
},
|
||||
"expireAt": {
|
||||
"$cond": {
|
||||
"if": {"$lt": ["$expireAt", "$$NOW"]},
|
||||
"then": expiration,
|
||||
"else": "$expireAt",
|
||||
}
|
||||
},
|
||||
}
|
||||
},
|
||||
],
|
||||
upsert=True,
|
||||
projection=["count"],
|
||||
return_document=self.proxy_dependency.module.ReturnDocument.AFTER,
|
||||
)
|
||||
|
||||
return int(response["count"])
|
||||
|
||||
async def check(self) -> bool:
|
||||
"""
|
||||
Check if storage is healthy by calling
|
||||
:meth:`motor.motor_asyncio.AsyncIOMotorClient.server_info`
|
||||
"""
|
||||
try:
|
||||
await self.storage.server_info()
|
||||
|
||||
return True
|
||||
except: # noqa: E722
|
||||
return False
|
||||
|
||||
async def get_moving_window(
|
||||
self, key: str, limit: int, expiry: int
|
||||
) -> tuple[float, int]:
|
||||
"""
|
||||
returns the starting point and the number of entries in the moving
|
||||
window
|
||||
|
||||
:param str key: rate limit key
|
||||
:param int expiry: expiry of entry
|
||||
:return: (start of window, number of acquired entries)
|
||||
"""
|
||||
|
||||
timestamp = time.time()
|
||||
if (
|
||||
result := await self.database[self.__collection_mapping["windows"]]
|
||||
.aggregate(
|
||||
[
|
||||
{"$match": {"_id": key}},
|
||||
{
|
||||
"$project": {
|
||||
"filteredEntries": {
|
||||
"$filter": {
|
||||
"input": "$entries",
|
||||
"as": "entry",
|
||||
"cond": {"$gte": ["$$entry", timestamp - expiry]},
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"$project": {
|
||||
"min": {"$min": "$filteredEntries"},
|
||||
"count": {"$size": "$filteredEntries"},
|
||||
}
|
||||
},
|
||||
]
|
||||
)
|
||||
.to_list(length=1)
|
||||
):
|
||||
return result[0]["min"], result[0]["count"]
|
||||
return timestamp, 0
|
||||
|
||||
async def acquire_entry(
|
||||
self, key: str, limit: int, expiry: int, amount: int = 1
|
||||
) -> bool:
|
||||
"""
|
||||
:param key: rate limit key to acquire an entry in
|
||||
:param limit: amount of entries allowed
|
||||
:param expiry: expiry of the entry
|
||||
:param amount: the number of entries to acquire
|
||||
"""
|
||||
await self.create_indices()
|
||||
|
||||
if amount > limit:
|
||||
return False
|
||||
|
||||
timestamp = time.time()
|
||||
try:
|
||||
updates: dict[
|
||||
str,
|
||||
dict[str, datetime.datetime | dict[str, list[float] | int]],
|
||||
] = {
|
||||
"$push": {
|
||||
"entries": {
|
||||
"$each": [timestamp] * amount,
|
||||
"$position": 0,
|
||||
"$slice": limit,
|
||||
}
|
||||
},
|
||||
"$set": {
|
||||
"expireAt": (
|
||||
datetime.datetime.now(datetime.timezone.utc)
|
||||
+ datetime.timedelta(seconds=expiry)
|
||||
)
|
||||
},
|
||||
}
|
||||
|
||||
await self.database[self.__collection_mapping["windows"]].update_one(
|
||||
{
|
||||
"_id": key,
|
||||
f"entries.{limit - amount}": {"$not": {"$gte": timestamp - expiry}},
|
||||
},
|
||||
updates,
|
||||
upsert=True,
|
||||
)
|
||||
|
||||
return True
|
||||
except self.proxy_dependency.module.errors.DuplicateKeyError:
|
||||
return False
|
||||
|
||||
async def acquire_sliding_window_entry(
|
||||
self, key: str, limit: int, expiry: int, amount: int = 1
|
||||
) -> bool:
|
||||
await self.create_indices()
|
||||
expiry_ms = expiry * 1000
|
||||
result = await self.database[
|
||||
self.__collection_mapping["windows"]
|
||||
].find_one_and_update(
|
||||
{"_id": key},
|
||||
[
|
||||
{
|
||||
"$set": {
|
||||
"previousCount": {
|
||||
"$cond": {
|
||||
"if": {
|
||||
"$lte": [
|
||||
{"$subtract": ["$expireAt", "$$NOW"]},
|
||||
expiry_ms,
|
||||
]
|
||||
},
|
||||
"then": {"$ifNull": ["$currentCount", 0]},
|
||||
"else": {"$ifNull": ["$previousCount", 0]},
|
||||
}
|
||||
},
|
||||
}
|
||||
},
|
||||
{
|
||||
"$set": {
|
||||
"currentCount": {
|
||||
"$cond": {
|
||||
"if": {
|
||||
"$lte": [
|
||||
{"$subtract": ["$expireAt", "$$NOW"]},
|
||||
expiry_ms,
|
||||
]
|
||||
},
|
||||
"then": 0,
|
||||
"else": {"$ifNull": ["$currentCount", 0]},
|
||||
}
|
||||
},
|
||||
"expireAt": {
|
||||
"$cond": {
|
||||
"if": {
|
||||
"$lte": [
|
||||
{"$subtract": ["$expireAt", "$$NOW"]},
|
||||
expiry_ms,
|
||||
]
|
||||
},
|
||||
"then": {
|
||||
"$cond": {
|
||||
"if": {"$gt": ["$expireAt", 0]},
|
||||
"then": {"$add": ["$expireAt", expiry_ms]},
|
||||
"else": {"$add": ["$$NOW", 2 * expiry_ms]},
|
||||
}
|
||||
},
|
||||
"else": "$expireAt",
|
||||
}
|
||||
},
|
||||
}
|
||||
},
|
||||
{
|
||||
"$set": {
|
||||
"curWeightedCount": {
|
||||
"$floor": {
|
||||
"$add": [
|
||||
{
|
||||
"$multiply": [
|
||||
"$previousCount",
|
||||
{
|
||||
"$divide": [
|
||||
{
|
||||
"$max": [
|
||||
0,
|
||||
{
|
||||
"$subtract": [
|
||||
"$expireAt",
|
||||
{
|
||||
"$add": [
|
||||
"$$NOW",
|
||||
expiry_ms,
|
||||
]
|
||||
},
|
||||
]
|
||||
},
|
||||
]
|
||||
},
|
||||
expiry_ms,
|
||||
]
|
||||
},
|
||||
]
|
||||
},
|
||||
"$currentCount",
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"$set": {
|
||||
"currentCount": {
|
||||
"$cond": {
|
||||
"if": {
|
||||
"$lte": [
|
||||
{"$add": ["$curWeightedCount", amount]},
|
||||
limit,
|
||||
]
|
||||
},
|
||||
"then": {"$add": ["$currentCount", amount]},
|
||||
"else": "$currentCount",
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"$set": {
|
||||
"_acquired": {
|
||||
"$lte": [{"$add": ["$curWeightedCount", amount]}, limit]
|
||||
}
|
||||
}
|
||||
},
|
||||
{"$unset": ["curWeightedCount"]},
|
||||
],
|
||||
return_document=self.proxy_dependency.module.ReturnDocument.AFTER,
|
||||
upsert=True,
|
||||
)
|
||||
|
||||
return cast(bool, result["_acquired"])
|
||||
|
||||
async def get_sliding_window(
|
||||
self, key: str, expiry: int
|
||||
) -> tuple[int, float, int, float]:
|
||||
expiry_ms = expiry * 1000
|
||||
if result := await self.database[
|
||||
self.__collection_mapping["windows"]
|
||||
].find_one_and_update(
|
||||
{"_id": key},
|
||||
[
|
||||
{
|
||||
"$set": {
|
||||
"previousCount": {
|
||||
"$cond": {
|
||||
"if": {
|
||||
"$lte": [
|
||||
{"$subtract": ["$expireAt", "$$NOW"]},
|
||||
expiry_ms,
|
||||
]
|
||||
},
|
||||
"then": {"$ifNull": ["$currentCount", 0]},
|
||||
"else": {"$ifNull": ["$previousCount", 0]},
|
||||
}
|
||||
},
|
||||
"currentCount": {
|
||||
"$cond": {
|
||||
"if": {
|
||||
"$lte": [
|
||||
{"$subtract": ["$expireAt", "$$NOW"]},
|
||||
expiry_ms,
|
||||
]
|
||||
},
|
||||
"then": 0,
|
||||
"else": {"$ifNull": ["$currentCount", 0]},
|
||||
}
|
||||
},
|
||||
"expireAt": {
|
||||
"$cond": {
|
||||
"if": {
|
||||
"$lte": [
|
||||
{"$subtract": ["$expireAt", "$$NOW"]},
|
||||
expiry_ms,
|
||||
]
|
||||
},
|
||||
"then": {"$add": ["$expireAt", expiry_ms]},
|
||||
"else": "$expireAt",
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
],
|
||||
return_document=self.proxy_dependency.module.ReturnDocument.AFTER,
|
||||
projection=["currentCount", "previousCount", "expireAt"],
|
||||
):
|
||||
expires_at = (
|
||||
(result["expireAt"].replace(tzinfo=datetime.timezone.utc).timestamp())
|
||||
if result.get("expireAt")
|
||||
else time.time()
|
||||
)
|
||||
current_ttl = max(0, expires_at - time.time())
|
||||
prev_ttl = max(0, current_ttl - expiry if result["previousCount"] else 0)
|
||||
|
||||
return (
|
||||
result["previousCount"],
|
||||
prev_ttl,
|
||||
result["currentCount"],
|
||||
current_ttl,
|
||||
)
|
||||
return 0, 0.0, 0, 0.0
|
||||
|
||||
async def clear_sliding_window(self, key: str, expiry: int) -> None:
|
||||
return await self.clear(key)
|
||||
|
||||
def __del__(self) -> None:
|
||||
self.storage and self.storage.close()
|
||||
@@ -0,0 +1,423 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
|
||||
from deprecated.sphinx import versionadded, versionchanged
|
||||
from packaging.version import Version
|
||||
|
||||
from limits.aio.storage import MovingWindowSupport, SlidingWindowCounterSupport, Storage
|
||||
from limits.aio.storage.redis.bridge import RedisBridge
|
||||
from limits.aio.storage.redis.coredis import CoredisBridge
|
||||
from limits.aio.storage.redis.redispy import RedispyBridge
|
||||
from limits.aio.storage.redis.valkey import ValkeyBridge
|
||||
from limits.typing import Literal
|
||||
|
||||
|
||||
@versionadded(version="2.1")
|
||||
@versionchanged(
|
||||
version="4.2",
|
||||
reason=(
|
||||
"Added support for using the asyncio redis client from :pypi:`redis`"
|
||||
" through :paramref:`implementation`"
|
||||
),
|
||||
)
|
||||
@versionchanged(
|
||||
version="4.3",
|
||||
reason=(
|
||||
"Added support for using the asyncio redis client from :pypi:`valkey`"
|
||||
" through :paramref:`implementation` or if :paramref:`uri` has the"
|
||||
" ``async+valkey`` schema"
|
||||
),
|
||||
)
|
||||
class RedisStorage(Storage, MovingWindowSupport, SlidingWindowCounterSupport):
|
||||
"""
|
||||
Rate limit storage with redis as backend.
|
||||
|
||||
Depends on :pypi:`coredis` or :pypi:`redis`
|
||||
"""
|
||||
|
||||
STORAGE_SCHEME = [
|
||||
"async+redis",
|
||||
"async+rediss",
|
||||
"async+redis+unix",
|
||||
"async+valkey",
|
||||
"async+valkeys",
|
||||
"async+valkey+unix",
|
||||
]
|
||||
"""
|
||||
The storage schemes for redis to be used in an async context
|
||||
"""
|
||||
DEPENDENCIES = {
|
||||
"redis": Version("5.2.0"),
|
||||
"coredis": Version("3.4.0"),
|
||||
"valkey": Version("6.0"),
|
||||
}
|
||||
MODE: Literal["BASIC", "CLUSTER", "SENTINEL"] = "BASIC"
|
||||
PREFIX = "LIMITS"
|
||||
|
||||
bridge: RedisBridge
|
||||
storage_exceptions: tuple[Exception, ...]
|
||||
target_server: Literal["redis", "valkey"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
uri: str,
|
||||
wrap_exceptions: bool = False,
|
||||
implementation: Literal["redispy", "coredis", "valkey"] = "coredis",
|
||||
key_prefix: str = PREFIX,
|
||||
**options: float | str | bool,
|
||||
) -> None:
|
||||
"""
|
||||
:param uri: uri of the form:
|
||||
|
||||
- ``async+redis://[:password]@host:port``
|
||||
- ``async+redis://[:password]@host:port/db``
|
||||
- ``async+rediss://[:password]@host:port``
|
||||
- ``async+redis+unix:///path/to/sock?db=0`` etc...
|
||||
|
||||
This uri is passed directly to :meth:`coredis.Redis.from_url` or
|
||||
:meth:`redis.asyncio.client.Redis.from_url` with the initial ``async`` removed,
|
||||
except for the case of ``async+redis+unix`` where it is replaced with ``unix``.
|
||||
|
||||
If the uri scheme is ``async+valkey`` the implementation used will be from
|
||||
:pypi:`valkey`.
|
||||
:param connection_pool: if provided, the redis client is initialized with
|
||||
the connection pool and any other params passed as :paramref:`options`
|
||||
:param wrap_exceptions: Whether to wrap storage exceptions in
|
||||
:exc:`limits.errors.StorageError` before raising it.
|
||||
:param implementation: Whether to use the client implementation from
|
||||
|
||||
- ``coredis``: :class:`coredis.Redis`
|
||||
- ``redispy``: :class:`redis.asyncio.client.Redis`
|
||||
- ``valkey``: :class:`valkey.asyncio.client.Valkey`
|
||||
|
||||
:param key_prefix: the prefix for each key created in redis
|
||||
:param options: all remaining keyword arguments are passed
|
||||
directly to the constructor of :class:`coredis.Redis` or :class:`redis.asyncio.client.Redis`
|
||||
:raise ConfigurationError: when the redis library is not available
|
||||
"""
|
||||
uri = uri.removeprefix("async+")
|
||||
self.target_server = "redis" if uri.startswith("redis") else "valkey"
|
||||
uri = uri.replace(f"{self.target_server}+unix", "unix")
|
||||
|
||||
super().__init__(uri, wrap_exceptions=wrap_exceptions)
|
||||
self.options = options
|
||||
if self.target_server == "valkey" or implementation == "valkey":
|
||||
self.bridge = ValkeyBridge(
|
||||
uri, self.dependencies["valkey"].module, key_prefix
|
||||
)
|
||||
else:
|
||||
if implementation == "redispy":
|
||||
self.bridge = RedispyBridge(
|
||||
uri, self.dependencies["redis"].module, key_prefix
|
||||
)
|
||||
else:
|
||||
self.bridge = CoredisBridge(
|
||||
uri, self.dependencies["coredis"].module, key_prefix
|
||||
)
|
||||
self.configure_bridge()
|
||||
self.bridge.register_scripts()
|
||||
|
||||
def _current_window_key(self, key: str) -> str:
|
||||
"""
|
||||
Return the current window's storage key (Sliding window strategy)
|
||||
|
||||
Contrary to other strategies that have one key per rate limit item,
|
||||
this strategy has two keys per rate limit item than must be on the same machine.
|
||||
To keep the current key and the previous key on the same Redis cluster node,
|
||||
curly braces are added.
|
||||
|
||||
Eg: "{constructed_key}"
|
||||
"""
|
||||
return f"{{{key}}}"
|
||||
|
||||
def _previous_window_key(self, key: str) -> str:
|
||||
"""
|
||||
Return the previous window's storage key (Sliding window strategy).
|
||||
|
||||
Curvy braces are added on the common pattern with the current window's key,
|
||||
so the current and the previous key are stored on the same Redis cluster node.
|
||||
|
||||
Eg: "{constructed_key}/-1"
|
||||
"""
|
||||
return f"{self._current_window_key(key)}/-1"
|
||||
|
||||
def configure_bridge(self) -> None:
|
||||
self.bridge.use_basic(**self.options)
|
||||
|
||||
@property
|
||||
def base_exceptions(
|
||||
self,
|
||||
) -> type[Exception] | tuple[type[Exception], ...]: # pragma: no cover
|
||||
return self.bridge.base_exceptions
|
||||
|
||||
async def incr(self, key: str, expiry: int, amount: int = 1) -> int:
|
||||
"""
|
||||
increments the counter for a given rate limit key
|
||||
|
||||
:param key: the key to increment
|
||||
:param expiry: amount in seconds for the key to expire in
|
||||
:param amount: the number to increment by
|
||||
"""
|
||||
|
||||
return await self.bridge.incr(key, expiry, amount)
|
||||
|
||||
async def get(self, key: str) -> int:
|
||||
"""
|
||||
:param key: the key to get the counter value for
|
||||
"""
|
||||
|
||||
return await self.bridge.get(key)
|
||||
|
||||
async def clear(self, key: str) -> None:
|
||||
"""
|
||||
:param key: the key to clear rate limits for
|
||||
"""
|
||||
|
||||
return await self.bridge.clear(key)
|
||||
|
||||
async def acquire_entry(
|
||||
self, key: str, limit: int, expiry: int, amount: int = 1
|
||||
) -> bool:
|
||||
"""
|
||||
:param key: rate limit key to acquire an entry in
|
||||
:param limit: amount of entries allowed
|
||||
:param expiry: expiry of the entry
|
||||
:param amount: the number of entries to acquire
|
||||
"""
|
||||
|
||||
return await self.bridge.acquire_entry(key, limit, expiry, amount)
|
||||
|
||||
async def get_moving_window(
|
||||
self, key: str, limit: int, expiry: int
|
||||
) -> tuple[float, int]:
|
||||
"""
|
||||
returns the starting point and the number of entries in the moving
|
||||
window
|
||||
|
||||
:param key: rate limit key
|
||||
:param expiry: expiry of entry
|
||||
:return: (previous count, previous TTL, current count, current TTL)
|
||||
"""
|
||||
return await self.bridge.get_moving_window(key, limit, expiry)
|
||||
|
||||
async def acquire_sliding_window_entry(
|
||||
self,
|
||||
key: str,
|
||||
limit: int,
|
||||
expiry: int,
|
||||
amount: int = 1,
|
||||
) -> bool:
|
||||
current_key = self._current_window_key(key)
|
||||
previous_key = self._previous_window_key(key)
|
||||
return await self.bridge.acquire_sliding_window_entry(
|
||||
previous_key, current_key, limit, expiry, amount
|
||||
)
|
||||
|
||||
async def get_sliding_window(
|
||||
self, key: str, expiry: int
|
||||
) -> tuple[int, float, int, float]:
|
||||
previous_key = self._previous_window_key(key)
|
||||
current_key = self._current_window_key(key)
|
||||
return await self.bridge.get_sliding_window(previous_key, current_key, expiry)
|
||||
|
||||
async def clear_sliding_window(self, key: str, expiry: int) -> None:
|
||||
previous_key = self._previous_window_key(key)
|
||||
current_key = self._current_window_key(key)
|
||||
await asyncio.gather(self.clear(previous_key), self.clear(current_key))
|
||||
|
||||
async def get_expiry(self, key: str) -> float:
|
||||
"""
|
||||
:param key: the key to get the expiry for
|
||||
"""
|
||||
|
||||
return await self.bridge.get_expiry(key)
|
||||
|
||||
async def check(self) -> bool:
|
||||
"""
|
||||
Check if storage is healthy by calling ``PING``
|
||||
"""
|
||||
|
||||
return await self.bridge.check()
|
||||
|
||||
async def reset(self) -> int | None:
|
||||
"""
|
||||
This function calls a Lua Script to delete keys prefixed with
|
||||
:paramref:`RedisStorage.key_prefix` in blocks of 5000.
|
||||
|
||||
.. warning:: This operation was designed to be fast, but was not tested
|
||||
on a large production based system. Be careful with its usage as it
|
||||
could be slow on very large data sets.
|
||||
"""
|
||||
|
||||
return await self.bridge.lua_reset()
|
||||
|
||||
|
||||
@versionadded(version="2.1")
|
||||
@versionchanged(
|
||||
version="4.2",
|
||||
reason="Added support for using the asyncio redis client from :pypi:`redis` ",
|
||||
)
|
||||
@versionchanged(
|
||||
version="4.3",
|
||||
reason=(
|
||||
"Added support for using the asyncio redis client from :pypi:`valkey`"
|
||||
" through :paramref:`implementation` or if :paramref:`uri` has the"
|
||||
" ``async+valkey+cluster`` schema"
|
||||
),
|
||||
)
|
||||
class RedisClusterStorage(RedisStorage):
|
||||
"""
|
||||
Rate limit storage with redis cluster as backend
|
||||
|
||||
Depends on :pypi:`coredis` or :pypi:`redis`
|
||||
"""
|
||||
|
||||
STORAGE_SCHEME = ["async+redis+cluster", "async+valkey+cluster"]
|
||||
"""
|
||||
The storage schemes for redis cluster to be used in an async context
|
||||
"""
|
||||
|
||||
MODE = "CLUSTER"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
uri: str,
|
||||
wrap_exceptions: bool = False,
|
||||
implementation: Literal["redispy", "coredis", "valkey"] = "coredis",
|
||||
key_prefix: str = RedisStorage.PREFIX,
|
||||
**options: float | str | bool,
|
||||
) -> None:
|
||||
"""
|
||||
:param uri: url of the form
|
||||
``async+redis+cluster://[:password]@host:port,host:port``
|
||||
|
||||
If the uri scheme is ``async+valkey+cluster`` the implementation used will be from
|
||||
:pypi:`valkey`.
|
||||
:param wrap_exceptions: Whether to wrap storage exceptions in
|
||||
:exc:`limits.errors.StorageError` before raising it.
|
||||
:param implementation: Whether to use the client implementation from
|
||||
|
||||
- ``coredis``: :class:`coredis.RedisCluster`
|
||||
- ``redispy``: :class:`redis.asyncio.cluster.RedisCluster`
|
||||
- ``valkey``: :class:`valkey.asyncio.cluster.ValkeyCluster`
|
||||
:param key_prefix: the prefix for each key created in redis
|
||||
:param options: all remaining keyword arguments are passed
|
||||
directly to the constructor of :class:`coredis.RedisCluster` or
|
||||
:class:`redis.asyncio.RedisCluster`
|
||||
:raise ConfigurationError: when the redis library is not
|
||||
available or if the redis host cannot be pinged.
|
||||
"""
|
||||
super().__init__(
|
||||
uri,
|
||||
wrap_exceptions=wrap_exceptions,
|
||||
implementation=implementation,
|
||||
key_prefix=key_prefix,
|
||||
**options,
|
||||
)
|
||||
|
||||
def configure_bridge(self) -> None:
|
||||
self.bridge.use_cluster(**self.options)
|
||||
|
||||
async def reset(self) -> int | None:
|
||||
"""
|
||||
Redis Clusters are sharded and deleting across shards
|
||||
can't be done atomically. Because of this, this reset loops over all
|
||||
keys that are prefixed with :paramref:`RedisClusterStorage.key_prefix`
|
||||
and calls delete on them one at a time.
|
||||
|
||||
.. warning:: This operation was not tested with extremely large data sets.
|
||||
On a large production based system, care should be taken with its
|
||||
usage as it could be slow on very large data sets
|
||||
"""
|
||||
|
||||
return await self.bridge.reset()
|
||||
|
||||
|
||||
@versionadded(version="2.1")
|
||||
@versionchanged(
|
||||
version="4.2",
|
||||
reason="Added support for using the asyncio redis client from :pypi:`redis` ",
|
||||
)
|
||||
@versionchanged(
|
||||
version="4.3",
|
||||
reason=(
|
||||
"Added support for using the asyncio redis client from :pypi:`valkey`"
|
||||
" through :paramref:`implementation` or if :paramref:`uri` has the"
|
||||
" ``async+valkey+sentinel`` schema"
|
||||
),
|
||||
)
|
||||
class RedisSentinelStorage(RedisStorage):
|
||||
"""
|
||||
Rate limit storage with redis sentinel as backend
|
||||
|
||||
Depends on :pypi:`coredis` or :pypi:`redis`
|
||||
"""
|
||||
|
||||
STORAGE_SCHEME = [
|
||||
"async+redis+sentinel",
|
||||
"async+valkey+sentinel",
|
||||
]
|
||||
"""The storage scheme for redis accessed via a redis sentinel installation"""
|
||||
|
||||
MODE = "SENTINEL"
|
||||
|
||||
DEPENDENCIES = {
|
||||
"redis": Version("5.2.0"),
|
||||
"coredis": Version("3.4.0"),
|
||||
"coredis.sentinel": Version("3.4.0"),
|
||||
"valkey": Version("6.0"),
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
uri: str,
|
||||
wrap_exceptions: bool = False,
|
||||
implementation: Literal["redispy", "coredis", "valkey"] = "coredis",
|
||||
key_prefix: str = RedisStorage.PREFIX,
|
||||
service_name: str | None = None,
|
||||
use_replicas: bool = True,
|
||||
sentinel_kwargs: dict[str, float | str | bool] | None = None,
|
||||
**options: float | str | bool,
|
||||
):
|
||||
"""
|
||||
:param uri: url of the form
|
||||
``async+redis+sentinel://host:port,host:port/service_name``
|
||||
|
||||
If the uri schema is ``async+valkey+sentinel`` the implementation used will be from
|
||||
:pypi:`valkey`.
|
||||
:param wrap_exceptions: Whether to wrap storage exceptions in
|
||||
:exc:`limits.errors.StorageError` before raising it.
|
||||
:param implementation: Whether to use the client implementation from
|
||||
|
||||
- ``coredis``: :class:`coredis.sentinel.Sentinel`
|
||||
- ``redispy``: :class:`redis.asyncio.sentinel.Sentinel`
|
||||
- ``valkey``: :class:`valkey.asyncio.sentinel.Sentinel`
|
||||
:param key_prefix: the prefix for each key created in redis
|
||||
:param service_name: sentinel service name (if not provided in `uri`)
|
||||
:param use_replicas: Whether to use replicas for read only operations
|
||||
:param sentinel_kwargs: optional arguments to pass as
|
||||
`sentinel_kwargs`` to :class:`coredis.sentinel.Sentinel` or
|
||||
:class:`redis.asyncio.Sentinel`
|
||||
:param options: all remaining keyword arguments are passed
|
||||
directly to the constructor of :class:`coredis.sentinel.Sentinel` or
|
||||
:class:`redis.asyncio.sentinel.Sentinel`
|
||||
:raise ConfigurationError: when the redis library is not available
|
||||
or if the redis primary host cannot be pinged.
|
||||
"""
|
||||
|
||||
self.service_name = service_name
|
||||
self.use_replicas = use_replicas
|
||||
self.sentinel_kwargs = sentinel_kwargs
|
||||
super().__init__(
|
||||
uri,
|
||||
wrap_exceptions=wrap_exceptions,
|
||||
implementation=implementation,
|
||||
key_prefix=key_prefix,
|
||||
**options,
|
||||
)
|
||||
|
||||
def configure_bridge(self) -> None:
|
||||
self.bridge.use_sentinel(
|
||||
self.service_name, self.use_replicas, self.sentinel_kwargs, **self.options
|
||||
)
|
||||
@@ -0,0 +1,120 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import urllib
|
||||
from abc import ABC, abstractmethod
|
||||
from types import ModuleType
|
||||
|
||||
from limits.util import get_package_data
|
||||
|
||||
|
||||
class RedisBridge(ABC):
|
||||
RES_DIR = "resources/redis/lua_scripts"
|
||||
|
||||
SCRIPT_MOVING_WINDOW = get_package_data(f"{RES_DIR}/moving_window.lua")
|
||||
SCRIPT_ACQUIRE_MOVING_WINDOW = get_package_data(
|
||||
f"{RES_DIR}/acquire_moving_window.lua"
|
||||
)
|
||||
SCRIPT_CLEAR_KEYS = get_package_data(f"{RES_DIR}/clear_keys.lua")
|
||||
SCRIPT_INCR_EXPIRE = get_package_data(f"{RES_DIR}/incr_expire.lua")
|
||||
SCRIPT_SLIDING_WINDOW = get_package_data(f"{RES_DIR}/sliding_window.lua")
|
||||
SCRIPT_ACQUIRE_SLIDING_WINDOW = get_package_data(
|
||||
f"{RES_DIR}/acquire_sliding_window.lua"
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
uri: str,
|
||||
dependency: ModuleType,
|
||||
key_prefix: str,
|
||||
) -> None:
|
||||
self.uri = uri
|
||||
self.parsed_uri = urllib.parse.urlparse(self.uri)
|
||||
self.dependency = dependency
|
||||
self.parsed_auth = {}
|
||||
self.key_prefix = key_prefix
|
||||
if self.parsed_uri.username:
|
||||
self.parsed_auth["username"] = self.parsed_uri.username
|
||||
if self.parsed_uri.password:
|
||||
self.parsed_auth["password"] = self.parsed_uri.password
|
||||
|
||||
def prefixed_key(self, key: str) -> str:
|
||||
return f"{self.key_prefix}:{key}"
|
||||
|
||||
@abstractmethod
|
||||
def register_scripts(self) -> None: ...
|
||||
|
||||
@abstractmethod
|
||||
def use_sentinel(
|
||||
self,
|
||||
service_name: str | None,
|
||||
use_replicas: bool,
|
||||
sentinel_kwargs: dict[str, str | float | bool] | None,
|
||||
**options: str | float | bool,
|
||||
) -> None: ...
|
||||
|
||||
@abstractmethod
|
||||
def use_basic(self, **options: str | float | bool) -> None: ...
|
||||
|
||||
@abstractmethod
|
||||
def use_cluster(self, **options: str | float | bool) -> None: ...
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def base_exceptions(
|
||||
self,
|
||||
) -> type[Exception] | tuple[type[Exception], ...]: ...
|
||||
|
||||
@abstractmethod
|
||||
async def incr(
|
||||
self,
|
||||
key: str,
|
||||
expiry: int,
|
||||
amount: int = 1,
|
||||
) -> int: ...
|
||||
|
||||
@abstractmethod
|
||||
async def get(self, key: str) -> int: ...
|
||||
|
||||
@abstractmethod
|
||||
async def clear(self, key: str) -> None: ...
|
||||
|
||||
@abstractmethod
|
||||
async def get_moving_window(
|
||||
self, key: str, limit: int, expiry: int
|
||||
) -> tuple[float, int]: ...
|
||||
|
||||
@abstractmethod
|
||||
async def get_sliding_window(
|
||||
self, previous_key: str, current_key: str, expiry: int
|
||||
) -> tuple[int, float, int, float]: ...
|
||||
|
||||
@abstractmethod
|
||||
async def acquire_entry(
|
||||
self,
|
||||
key: str,
|
||||
limit: int,
|
||||
expiry: int,
|
||||
amount: int = 1,
|
||||
) -> bool: ...
|
||||
|
||||
@abstractmethod
|
||||
async def acquire_sliding_window_entry(
|
||||
self,
|
||||
previous_key: str,
|
||||
current_key: str,
|
||||
limit: int,
|
||||
expiry: int,
|
||||
amount: int = 1,
|
||||
) -> bool: ...
|
||||
|
||||
@abstractmethod
|
||||
async def get_expiry(self, key: str) -> float: ...
|
||||
|
||||
@abstractmethod
|
||||
async def check(self) -> bool: ...
|
||||
|
||||
@abstractmethod
|
||||
async def reset(self) -> int | None: ...
|
||||
|
||||
@abstractmethod
|
||||
async def lua_reset(self) -> int | None: ...
|
||||
@@ -0,0 +1,205 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from typing import TYPE_CHECKING, cast
|
||||
|
||||
from limits.aio.storage.redis.bridge import RedisBridge
|
||||
from limits.errors import ConfigurationError
|
||||
from limits.typing import AsyncCoRedisClient, Callable
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import coredis
|
||||
|
||||
|
||||
class CoredisBridge(RedisBridge):
|
||||
DEFAULT_CLUSTER_OPTIONS: dict[str, float | str | bool] = {
|
||||
"max_connections": 1000,
|
||||
}
|
||||
"Default options passed to :class:`coredis.RedisCluster`"
|
||||
|
||||
@property
|
||||
def base_exceptions(self) -> type[Exception] | tuple[type[Exception], ...]:
|
||||
return (self.dependency.exceptions.RedisError,)
|
||||
|
||||
def use_sentinel(
|
||||
self,
|
||||
service_name: str | None,
|
||||
use_replicas: bool,
|
||||
sentinel_kwargs: dict[str, str | float | bool] | None,
|
||||
**options: str | float | bool,
|
||||
) -> None:
|
||||
sentinel_configuration = []
|
||||
connection_options = options.copy()
|
||||
|
||||
sep = self.parsed_uri.netloc.find("@") + 1
|
||||
|
||||
for loc in self.parsed_uri.netloc[sep:].split(","):
|
||||
host, port = loc.split(":")
|
||||
sentinel_configuration.append((host, int(port)))
|
||||
service_name = (
|
||||
self.parsed_uri.path.replace("/", "")
|
||||
if self.parsed_uri.path
|
||||
else service_name
|
||||
)
|
||||
|
||||
if service_name is None:
|
||||
raise ConfigurationError("'service_name' not provided")
|
||||
|
||||
self.sentinel = self.dependency.sentinel.Sentinel(
|
||||
sentinel_configuration,
|
||||
sentinel_kwargs={**self.parsed_auth, **(sentinel_kwargs or {})},
|
||||
**{**self.parsed_auth, **connection_options},
|
||||
)
|
||||
self.storage = self.sentinel.primary_for(service_name)
|
||||
self.storage_replica = self.sentinel.replica_for(service_name)
|
||||
self.connection_getter = lambda readonly: (
|
||||
self.storage_replica if readonly and use_replicas else self.storage
|
||||
)
|
||||
|
||||
def use_basic(self, **options: str | float | bool) -> None:
|
||||
if connection_pool := options.pop("connection_pool", None):
|
||||
self.storage = self.dependency.Redis(
|
||||
connection_pool=connection_pool, **options
|
||||
)
|
||||
else:
|
||||
self.storage = self.dependency.Redis.from_url(self.uri, **options)
|
||||
|
||||
self.connection_getter = lambda _: self.storage
|
||||
|
||||
def use_cluster(self, **options: str | float | bool) -> None:
|
||||
sep = self.parsed_uri.netloc.find("@") + 1
|
||||
cluster_hosts: list[dict[str, int | str]] = []
|
||||
cluster_hosts.extend(
|
||||
{"host": host, "port": int(port)}
|
||||
for loc in self.parsed_uri.netloc[sep:].split(",")
|
||||
if loc
|
||||
for host, port in [loc.split(":")]
|
||||
)
|
||||
self.storage = self.dependency.RedisCluster(
|
||||
startup_nodes=cluster_hosts,
|
||||
**{**self.DEFAULT_CLUSTER_OPTIONS, **self.parsed_auth, **options},
|
||||
)
|
||||
self.connection_getter = lambda _: self.storage
|
||||
|
||||
lua_moving_window: coredis.commands.Script[bytes]
|
||||
lua_acquire_moving_window: coredis.commands.Script[bytes]
|
||||
lua_sliding_window: coredis.commands.Script[bytes]
|
||||
lua_acquire_sliding_window: coredis.commands.Script[bytes]
|
||||
lua_clear_keys: coredis.commands.Script[bytes]
|
||||
lua_incr_expire: coredis.commands.Script[bytes]
|
||||
connection_getter: Callable[[bool], AsyncCoRedisClient]
|
||||
|
||||
def get_connection(self, readonly: bool = False) -> AsyncCoRedisClient:
|
||||
return self.connection_getter(readonly)
|
||||
|
||||
def register_scripts(self) -> None:
|
||||
self.lua_moving_window = self.get_connection().register_script(
|
||||
self.SCRIPT_MOVING_WINDOW
|
||||
)
|
||||
self.lua_acquire_moving_window = self.get_connection().register_script(
|
||||
self.SCRIPT_ACQUIRE_MOVING_WINDOW
|
||||
)
|
||||
self.lua_clear_keys = self.get_connection().register_script(
|
||||
self.SCRIPT_CLEAR_KEYS
|
||||
)
|
||||
self.lua_incr_expire = self.get_connection().register_script(
|
||||
self.SCRIPT_INCR_EXPIRE
|
||||
)
|
||||
self.lua_sliding_window = self.get_connection().register_script(
|
||||
self.SCRIPT_SLIDING_WINDOW
|
||||
)
|
||||
self.lua_acquire_sliding_window = self.get_connection().register_script(
|
||||
self.SCRIPT_ACQUIRE_SLIDING_WINDOW
|
||||
)
|
||||
|
||||
async def incr(self, key: str, expiry: int, amount: int = 1) -> int:
|
||||
key = self.prefixed_key(key)
|
||||
if (value := await self.get_connection().incrby(key, amount)) == amount:
|
||||
await self.get_connection().expire(key, expiry)
|
||||
return value
|
||||
|
||||
async def get(self, key: str) -> int:
|
||||
key = self.prefixed_key(key)
|
||||
return int(await self.get_connection(readonly=True).get(key) or 0)
|
||||
|
||||
async def clear(self, key: str) -> None:
|
||||
key = self.prefixed_key(key)
|
||||
await self.get_connection().delete([key])
|
||||
|
||||
async def lua_reset(self) -> int | None:
|
||||
return cast(int, await self.lua_clear_keys.execute([self.prefixed_key("*")]))
|
||||
|
||||
async def get_moving_window(
|
||||
self, key: str, limit: int, expiry: int
|
||||
) -> tuple[float, int]:
|
||||
key = self.prefixed_key(key)
|
||||
timestamp = time.time()
|
||||
window = await self.lua_moving_window.execute(
|
||||
[key], [timestamp - expiry, limit]
|
||||
)
|
||||
if window:
|
||||
return float(window[0]), window[1] # type: ignore
|
||||
return timestamp, 0
|
||||
|
||||
async def get_sliding_window(
|
||||
self, previous_key: str, current_key: str, expiry: int
|
||||
) -> tuple[int, float, int, float]:
|
||||
previous_key = self.prefixed_key(previous_key)
|
||||
current_key = self.prefixed_key(current_key)
|
||||
|
||||
if window := await self.lua_sliding_window.execute(
|
||||
[previous_key, current_key], [expiry]
|
||||
):
|
||||
return (
|
||||
int(window[0] or 0), # type: ignore
|
||||
max(0, float(window[1] or 0)) / 1000, # type: ignore
|
||||
int(window[2] or 0), # type: ignore
|
||||
max(0, float(window[3] or 0)) / 1000, # type: ignore
|
||||
)
|
||||
return 0, 0.0, 0, 0.0
|
||||
|
||||
async def acquire_entry(
|
||||
self, key: str, limit: int, expiry: int, amount: int = 1
|
||||
) -> bool:
|
||||
key = self.prefixed_key(key)
|
||||
timestamp = time.time()
|
||||
acquired = await self.lua_acquire_moving_window.execute(
|
||||
[key], [timestamp, limit, expiry, amount]
|
||||
)
|
||||
|
||||
return bool(acquired)
|
||||
|
||||
async def acquire_sliding_window_entry(
|
||||
self,
|
||||
previous_key: str,
|
||||
current_key: str,
|
||||
limit: int,
|
||||
expiry: int,
|
||||
amount: int = 1,
|
||||
) -> bool:
|
||||
previous_key = self.prefixed_key(previous_key)
|
||||
current_key = self.prefixed_key(current_key)
|
||||
acquired = await self.lua_acquire_sliding_window.execute(
|
||||
[previous_key, current_key], [limit, expiry, amount]
|
||||
)
|
||||
return bool(acquired)
|
||||
|
||||
async def get_expiry(self, key: str) -> float:
|
||||
key = self.prefixed_key(key)
|
||||
return max(await self.get_connection().ttl(key), 0) + time.time()
|
||||
|
||||
async def check(self) -> bool:
|
||||
try:
|
||||
await self.get_connection().ping()
|
||||
|
||||
return True
|
||||
except: # noqa
|
||||
return False
|
||||
|
||||
async def reset(self) -> int | None:
|
||||
prefix = self.prefixed_key("*")
|
||||
keys = await self.storage.keys(prefix)
|
||||
count = 0
|
||||
for key in keys:
|
||||
count += await self.storage.delete([key])
|
||||
return count
|
||||
@@ -0,0 +1,250 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from typing import TYPE_CHECKING, cast
|
||||
|
||||
from limits.aio.storage.redis.bridge import RedisBridge
|
||||
from limits.errors import ConfigurationError
|
||||
from limits.typing import AsyncRedisClient, Callable
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import redis.commands
|
||||
|
||||
|
||||
class RedispyBridge(RedisBridge):
|
||||
DEFAULT_CLUSTER_OPTIONS: dict[str, float | str | bool] = {
|
||||
"max_connections": 1000,
|
||||
}
|
||||
"Default options passed to :class:`redis.asyncio.RedisCluster`"
|
||||
|
||||
@property
|
||||
def base_exceptions(self) -> type[Exception] | tuple[type[Exception], ...]:
|
||||
return (self.dependency.RedisError,)
|
||||
|
||||
def use_sentinel(
|
||||
self,
|
||||
service_name: str | None,
|
||||
use_replicas: bool,
|
||||
sentinel_kwargs: dict[str, str | float | bool] | None,
|
||||
**options: str | float | bool,
|
||||
) -> None:
|
||||
sentinel_configuration = []
|
||||
|
||||
connection_options = options.copy()
|
||||
|
||||
sep = self.parsed_uri.netloc.find("@") + 1
|
||||
|
||||
for loc in self.parsed_uri.netloc[sep:].split(","):
|
||||
host, port = loc.split(":")
|
||||
sentinel_configuration.append((host, int(port)))
|
||||
service_name = (
|
||||
self.parsed_uri.path.replace("/", "")
|
||||
if self.parsed_uri.path
|
||||
else service_name
|
||||
)
|
||||
|
||||
if service_name is None:
|
||||
raise ConfigurationError("'service_name' not provided")
|
||||
|
||||
self.sentinel = self.dependency.asyncio.Sentinel(
|
||||
sentinel_configuration,
|
||||
sentinel_kwargs={**self.parsed_auth, **(sentinel_kwargs or {})},
|
||||
**{**self.parsed_auth, **connection_options},
|
||||
)
|
||||
self.storage = self.sentinel.master_for(service_name)
|
||||
self.storage_replica = self.sentinel.slave_for(service_name)
|
||||
self.connection_getter = lambda readonly: (
|
||||
self.storage_replica if readonly and use_replicas else self.storage
|
||||
)
|
||||
|
||||
def use_basic(self, **options: str | float | bool) -> None:
|
||||
if connection_pool := options.pop("connection_pool", None):
|
||||
self.storage = self.dependency.asyncio.Redis(
|
||||
connection_pool=connection_pool, **options
|
||||
)
|
||||
else:
|
||||
self.storage = self.dependency.asyncio.Redis.from_url(self.uri, **options)
|
||||
|
||||
self.connection_getter = lambda _: self.storage
|
||||
|
||||
def use_cluster(self, **options: str | float | bool) -> None:
|
||||
sep = self.parsed_uri.netloc.find("@") + 1
|
||||
cluster_hosts = []
|
||||
|
||||
for loc in self.parsed_uri.netloc[sep:].split(","):
|
||||
host, port = loc.split(":")
|
||||
cluster_hosts.append(
|
||||
self.dependency.asyncio.cluster.ClusterNode(host=host, port=int(port))
|
||||
)
|
||||
|
||||
self.storage = self.dependency.asyncio.RedisCluster(
|
||||
startup_nodes=cluster_hosts,
|
||||
**{**self.DEFAULT_CLUSTER_OPTIONS, **self.parsed_auth, **options},
|
||||
)
|
||||
self.connection_getter = lambda _: self.storage
|
||||
|
||||
lua_moving_window: redis.commands.core.Script
|
||||
lua_acquire_moving_window: redis.commands.core.Script
|
||||
lua_sliding_window: redis.commands.core.Script
|
||||
lua_acquire_sliding_window: redis.commands.core.Script
|
||||
lua_clear_keys: redis.commands.core.Script
|
||||
lua_incr_expire: redis.commands.core.Script
|
||||
connection_getter: Callable[[bool], AsyncRedisClient]
|
||||
|
||||
def get_connection(self, readonly: bool = False) -> AsyncRedisClient:
|
||||
return self.connection_getter(readonly)
|
||||
|
||||
def register_scripts(self) -> None:
|
||||
# Redis-py uses a slightly different script registration
|
||||
self.lua_moving_window = self.get_connection().register_script(
|
||||
self.SCRIPT_MOVING_WINDOW
|
||||
)
|
||||
self.lua_acquire_moving_window = self.get_connection().register_script(
|
||||
self.SCRIPT_ACQUIRE_MOVING_WINDOW
|
||||
)
|
||||
self.lua_clear_keys = self.get_connection().register_script(
|
||||
self.SCRIPT_CLEAR_KEYS
|
||||
)
|
||||
self.lua_incr_expire = self.get_connection().register_script(
|
||||
self.SCRIPT_INCR_EXPIRE
|
||||
)
|
||||
self.lua_sliding_window = self.get_connection().register_script(
|
||||
self.SCRIPT_SLIDING_WINDOW
|
||||
)
|
||||
self.lua_acquire_sliding_window = self.get_connection().register_script(
|
||||
self.SCRIPT_ACQUIRE_SLIDING_WINDOW
|
||||
)
|
||||
|
||||
async def incr(
|
||||
self,
|
||||
key: str,
|
||||
expiry: int,
|
||||
amount: int = 1,
|
||||
) -> int:
|
||||
"""
|
||||
increments the counter for a given rate limit key
|
||||
|
||||
|
||||
:param key: the key to increment
|
||||
:param expiry: amount in seconds for the key to expire in
|
||||
:param amount: the number to increment by
|
||||
"""
|
||||
key = self.prefixed_key(key)
|
||||
return cast(int, await self.lua_incr_expire([key], [expiry, amount]))
|
||||
|
||||
async def get(self, key: str) -> int:
|
||||
"""
|
||||
|
||||
:param key: the key to get the counter value for
|
||||
"""
|
||||
|
||||
key = self.prefixed_key(key)
|
||||
return int(await self.get_connection(readonly=True).get(key) or 0)
|
||||
|
||||
async def clear(self, key: str) -> None:
|
||||
"""
|
||||
:param key: the key to clear rate limits for
|
||||
|
||||
"""
|
||||
key = self.prefixed_key(key)
|
||||
await self.get_connection().delete(key)
|
||||
|
||||
async def lua_reset(self) -> int | None:
|
||||
return cast(int, await self.lua_clear_keys([self.prefixed_key("*")]))
|
||||
|
||||
async def get_moving_window(
|
||||
self, key: str, limit: int, expiry: int
|
||||
) -> tuple[float, int]:
|
||||
"""
|
||||
returns the starting point and the number of entries in the moving
|
||||
window
|
||||
|
||||
:param key: rate limit key
|
||||
:param expiry: expiry of entry
|
||||
:return: (previous count, previous TTL, current count, current TTL)
|
||||
"""
|
||||
key = self.prefixed_key(key)
|
||||
timestamp = time.time()
|
||||
window = await self.lua_moving_window([key], [timestamp - expiry, limit])
|
||||
if window:
|
||||
return float(window[0]), window[1]
|
||||
return timestamp, 0
|
||||
|
||||
async def get_sliding_window(
|
||||
self, previous_key: str, current_key: str, expiry: int
|
||||
) -> tuple[int, float, int, float]:
|
||||
if window := await self.lua_sliding_window(
|
||||
[self.prefixed_key(previous_key), self.prefixed_key(current_key)], [expiry]
|
||||
):
|
||||
return (
|
||||
int(window[0] or 0),
|
||||
max(0, float(window[1] or 0)) / 1000,
|
||||
int(window[2] or 0),
|
||||
max(0, float(window[3] or 0)) / 1000,
|
||||
)
|
||||
return 0, 0.0, 0, 0.0
|
||||
|
||||
async def acquire_entry(
|
||||
self,
|
||||
key: str,
|
||||
limit: int,
|
||||
expiry: int,
|
||||
amount: int = 1,
|
||||
) -> bool:
|
||||
"""
|
||||
:param key: rate limit key to acquire an entry in
|
||||
:param limit: amount of entries allowed
|
||||
:param expiry: expiry of the entry
|
||||
|
||||
"""
|
||||
key = self.prefixed_key(key)
|
||||
timestamp = time.time()
|
||||
acquired = await self.lua_acquire_moving_window(
|
||||
[key], [timestamp, limit, expiry, amount]
|
||||
)
|
||||
|
||||
return bool(acquired)
|
||||
|
||||
async def acquire_sliding_window_entry(
|
||||
self,
|
||||
previous_key: str,
|
||||
current_key: str,
|
||||
limit: int,
|
||||
expiry: int,
|
||||
amount: int = 1,
|
||||
) -> bool:
|
||||
previous_key = self.prefixed_key(previous_key)
|
||||
current_key = self.prefixed_key(current_key)
|
||||
acquired = await self.lua_acquire_sliding_window(
|
||||
[previous_key, current_key], [limit, expiry, amount]
|
||||
)
|
||||
return bool(acquired)
|
||||
|
||||
async def get_expiry(self, key: str) -> float:
|
||||
"""
|
||||
:param key: the key to get the expiry for
|
||||
"""
|
||||
|
||||
key = self.prefixed_key(key)
|
||||
return max(await self.get_connection().ttl(key), 0) + time.time()
|
||||
|
||||
async def check(self) -> bool:
|
||||
"""
|
||||
check if storage is healthy
|
||||
"""
|
||||
try:
|
||||
await self.get_connection().ping()
|
||||
|
||||
return True
|
||||
except: # noqa
|
||||
return False
|
||||
|
||||
async def reset(self) -> int | None:
|
||||
prefix = self.prefixed_key("*")
|
||||
keys = await self.storage.keys(
|
||||
prefix, target_nodes=self.dependency.asyncio.cluster.RedisCluster.ALL_NODES
|
||||
)
|
||||
count = 0
|
||||
for key in keys:
|
||||
count += await self.storage.delete(key)
|
||||
return count
|
||||
@@ -0,0 +1,9 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from .redispy import RedispyBridge
|
||||
|
||||
|
||||
class ValkeyBridge(RedispyBridge):
|
||||
@property
|
||||
def base_exceptions(self) -> type[Exception] | tuple[type[Exception], ...]:
|
||||
return (self.dependency.ValkeyError,)
|
||||
331
buffteks/lib/python3.11/site-packages/limits/aio/strategies.py
Normal file
331
buffteks/lib/python3.11/site-packages/limits/aio/strategies.py
Normal file
@@ -0,0 +1,331 @@
|
||||
"""
|
||||
Asynchronous rate limiting strategies
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
from math import floor, inf
|
||||
|
||||
from deprecated.sphinx import versionadded
|
||||
|
||||
from ..limits import RateLimitItem
|
||||
from ..storage import StorageTypes
|
||||
from ..typing import cast
|
||||
from ..util import WindowStats
|
||||
from .storage import MovingWindowSupport, Storage
|
||||
from .storage.base import SlidingWindowCounterSupport
|
||||
|
||||
|
||||
class RateLimiter(ABC):
|
||||
def __init__(self, storage: StorageTypes):
|
||||
assert isinstance(storage, Storage)
|
||||
self.storage: Storage = storage
|
||||
|
||||
@abstractmethod
|
||||
async def hit(self, item: RateLimitItem, *identifiers: str, cost: int = 1) -> bool:
|
||||
"""
|
||||
Consume the rate limit
|
||||
|
||||
:param item: the rate limit item
|
||||
:param identifiers: variable list of strings to uniquely identify the
|
||||
limit
|
||||
:param cost: The cost of this hit, default 1
|
||||
|
||||
:return: True if ``cost`` could be deducted from the rate limit without exceeding it
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def test(self, item: RateLimitItem, *identifiers: str, cost: int = 1) -> bool:
|
||||
"""
|
||||
Check if the rate limit can be consumed
|
||||
|
||||
:param item: the rate limit item
|
||||
:param identifiers: variable list of strings to uniquely identify the
|
||||
limit
|
||||
:param cost: The expected cost to be consumed, default 1
|
||||
|
||||
:return: True if the rate limit is not depleted
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def get_window_stats(
|
||||
self, item: RateLimitItem, *identifiers: str
|
||||
) -> WindowStats:
|
||||
"""
|
||||
Query the reset time and remaining amount for the limit
|
||||
|
||||
:param item: the rate limit item
|
||||
:param identifiers: variable list of strings to uniquely identify the
|
||||
limit
|
||||
:return: (reset time, remaining))
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
async def clear(self, item: RateLimitItem, *identifiers: str) -> None:
|
||||
return await self.storage.clear(item.key_for(*identifiers))
|
||||
|
||||
|
||||
class MovingWindowRateLimiter(RateLimiter):
|
||||
"""
|
||||
Reference: :ref:`strategies:moving window`
|
||||
"""
|
||||
|
||||
def __init__(self, storage: StorageTypes) -> None:
|
||||
if not (
|
||||
hasattr(storage, "acquire_entry") or hasattr(storage, "get_moving_window")
|
||||
):
|
||||
raise NotImplementedError(
|
||||
"MovingWindowRateLimiting is not implemented for storage "
|
||||
f"of type {storage.__class__}"
|
||||
)
|
||||
super().__init__(storage)
|
||||
|
||||
async def hit(self, item: RateLimitItem, *identifiers: str, cost: int = 1) -> bool:
|
||||
"""
|
||||
Consume the rate limit
|
||||
|
||||
:param item: the rate limit item
|
||||
:param identifiers: variable list of strings to uniquely identify the
|
||||
limit
|
||||
:param cost: The cost of this hit, default 1
|
||||
|
||||
:return: True if ``cost`` could be deducted from the rate limit without exceeding it
|
||||
"""
|
||||
|
||||
return await cast(MovingWindowSupport, self.storage).acquire_entry(
|
||||
item.key_for(*identifiers), item.amount, item.get_expiry(), amount=cost
|
||||
)
|
||||
|
||||
async def test(self, item: RateLimitItem, *identifiers: str, cost: int = 1) -> bool:
|
||||
"""
|
||||
Check if the rate limit can be consumed
|
||||
|
||||
:param item: the rate limit item
|
||||
:param identifiers: variable list of strings to uniquely identify the
|
||||
limit
|
||||
:param cost: The expected cost to be consumed, default 1
|
||||
|
||||
:return: True if the rate limit is not depleted
|
||||
"""
|
||||
res = await cast(MovingWindowSupport, self.storage).get_moving_window(
|
||||
item.key_for(*identifiers),
|
||||
item.amount,
|
||||
item.get_expiry(),
|
||||
)
|
||||
amount = res[1]
|
||||
|
||||
return amount <= item.amount - cost
|
||||
|
||||
async def get_window_stats(
|
||||
self, item: RateLimitItem, *identifiers: str
|
||||
) -> WindowStats:
|
||||
"""
|
||||
returns the number of requests remaining within this limit.
|
||||
|
||||
:param item: the rate limit item
|
||||
:param identifiers: variable list of strings to uniquely identify the
|
||||
limit
|
||||
:return: (reset time, remaining)
|
||||
"""
|
||||
window_start, window_items = await cast(
|
||||
MovingWindowSupport, self.storage
|
||||
).get_moving_window(item.key_for(*identifiers), item.amount, item.get_expiry())
|
||||
reset = window_start + item.get_expiry()
|
||||
|
||||
return WindowStats(reset, item.amount - window_items)
|
||||
|
||||
|
||||
class FixedWindowRateLimiter(RateLimiter):
|
||||
"""
|
||||
Reference: :ref:`strategies:fixed window`
|
||||
"""
|
||||
|
||||
async def hit(self, item: RateLimitItem, *identifiers: str, cost: int = 1) -> bool:
|
||||
"""
|
||||
Consume the rate limit
|
||||
|
||||
:param item: the rate limit item
|
||||
:param identifiers: variable list of strings to uniquely identify the
|
||||
limit
|
||||
:param cost: The cost of this hit, default 1
|
||||
|
||||
:return: True if ``cost`` could be deducted from the rate limit without exceeding it
|
||||
"""
|
||||
|
||||
return (
|
||||
await self.storage.incr(
|
||||
item.key_for(*identifiers),
|
||||
item.get_expiry(),
|
||||
amount=cost,
|
||||
)
|
||||
<= item.amount
|
||||
)
|
||||
|
||||
async def test(self, item: RateLimitItem, *identifiers: str, cost: int = 1) -> bool:
|
||||
"""
|
||||
Check if the rate limit can be consumed
|
||||
|
||||
:param item: the rate limit item
|
||||
:param identifiers: variable list of strings to uniquely identify the
|
||||
limit
|
||||
:param cost: The expected cost to be consumed, default 1
|
||||
|
||||
:return: True if the rate limit is not depleted
|
||||
"""
|
||||
|
||||
return (
|
||||
await self.storage.get(item.key_for(*identifiers)) < item.amount - cost + 1
|
||||
)
|
||||
|
||||
async def get_window_stats(
|
||||
self, item: RateLimitItem, *identifiers: str
|
||||
) -> WindowStats:
|
||||
"""
|
||||
Query the reset time and remaining amount for the limit
|
||||
|
||||
:param item: the rate limit item
|
||||
:param identifiers: variable list of strings to uniquely identify the
|
||||
limit
|
||||
:return: reset time, remaining
|
||||
"""
|
||||
remaining = max(
|
||||
0,
|
||||
item.amount - await self.storage.get(item.key_for(*identifiers)),
|
||||
)
|
||||
reset = await self.storage.get_expiry(item.key_for(*identifiers))
|
||||
|
||||
return WindowStats(reset, remaining)
|
||||
|
||||
|
||||
@versionadded(version="4.1")
|
||||
class SlidingWindowCounterRateLimiter(RateLimiter):
|
||||
"""
|
||||
Reference: :ref:`strategies:sliding window counter`
|
||||
"""
|
||||
|
||||
def __init__(self, storage: StorageTypes):
|
||||
if not hasattr(storage, "get_sliding_window") or not hasattr(
|
||||
storage, "acquire_sliding_window_entry"
|
||||
):
|
||||
raise NotImplementedError(
|
||||
"SlidingWindowCounterRateLimiting is not implemented for storage "
|
||||
f"of type {storage.__class__}"
|
||||
)
|
||||
super().__init__(storage)
|
||||
|
||||
def _weighted_count(
|
||||
self,
|
||||
item: RateLimitItem,
|
||||
previous_count: int,
|
||||
previous_expires_in: float,
|
||||
current_count: int,
|
||||
) -> float:
|
||||
"""
|
||||
Return the approximated by weighting the previous window count and adding the current window count.
|
||||
"""
|
||||
return previous_count * previous_expires_in / item.get_expiry() + current_count
|
||||
|
||||
async def hit(self, item: RateLimitItem, *identifiers: str, cost: int = 1) -> bool:
|
||||
"""
|
||||
Consume the rate limit
|
||||
|
||||
:param item: The rate limit item
|
||||
:param identifiers: variable list of strings to uniquely identify this
|
||||
instance of the limit
|
||||
:param cost: The cost of this hit, default 1
|
||||
|
||||
:return: True if ``cost`` could be deducted from the rate limit without exceeding it
|
||||
"""
|
||||
return await cast(
|
||||
SlidingWindowCounterSupport, self.storage
|
||||
).acquire_sliding_window_entry(
|
||||
item.key_for(*identifiers),
|
||||
item.amount,
|
||||
item.get_expiry(),
|
||||
cost,
|
||||
)
|
||||
|
||||
async def test(self, item: RateLimitItem, *identifiers: str, cost: int = 1) -> bool:
|
||||
"""
|
||||
Check if the rate limit can be consumed
|
||||
|
||||
:param item: The rate limit item
|
||||
:param identifiers: variable list of strings to uniquely identify this
|
||||
instance of the limit
|
||||
:param cost: The expected cost to be consumed, default 1
|
||||
|
||||
:return: True if the rate limit is not depleted
|
||||
"""
|
||||
|
||||
previous_count, previous_expires_in, current_count, _ = await cast(
|
||||
SlidingWindowCounterSupport, self.storage
|
||||
).get_sliding_window(item.key_for(*identifiers), item.get_expiry())
|
||||
|
||||
return (
|
||||
self._weighted_count(
|
||||
item, previous_count, previous_expires_in, current_count
|
||||
)
|
||||
< item.amount - cost + 1
|
||||
)
|
||||
|
||||
async def get_window_stats(
|
||||
self, item: RateLimitItem, *identifiers: str
|
||||
) -> WindowStats:
|
||||
"""
|
||||
Query the reset time and remaining amount for the limit.
|
||||
|
||||
:param item: The rate limit item
|
||||
:param identifiers: variable list of strings to uniquely identify this
|
||||
instance of the limit
|
||||
:return: (reset time, remaining)
|
||||
"""
|
||||
|
||||
(
|
||||
previous_count,
|
||||
previous_expires_in,
|
||||
current_count,
|
||||
current_expires_in,
|
||||
) = await cast(SlidingWindowCounterSupport, self.storage).get_sliding_window(
|
||||
item.key_for(*identifiers), item.get_expiry()
|
||||
)
|
||||
|
||||
remaining = max(
|
||||
0,
|
||||
item.amount
|
||||
- floor(
|
||||
self._weighted_count(
|
||||
item, previous_count, previous_expires_in, current_count
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
now = time.time()
|
||||
|
||||
if not (previous_count or current_count):
|
||||
return WindowStats(now, remaining)
|
||||
|
||||
expiry = item.get_expiry()
|
||||
|
||||
previous_reset_in, current_reset_in = inf, inf
|
||||
if previous_count:
|
||||
previous_reset_in = previous_expires_in % (expiry / previous_count)
|
||||
if current_count:
|
||||
current_reset_in = current_expires_in % expiry
|
||||
|
||||
return WindowStats(now + min(previous_reset_in, current_reset_in), remaining)
|
||||
|
||||
async def clear(self, item: RateLimitItem, *identifiers: str) -> None:
|
||||
return await cast(
|
||||
SlidingWindowCounterSupport, self.storage
|
||||
).clear_sliding_window(item.key_for(*identifiers), item.get_expiry())
|
||||
|
||||
|
||||
STRATEGIES = {
|
||||
"sliding-window-counter": SlidingWindowCounterRateLimiter,
|
||||
"fixed-window": FixedWindowRateLimiter,
|
||||
"moving-window": MovingWindowRateLimiter,
|
||||
}
|
||||
30
buffteks/lib/python3.11/site-packages/limits/errors.py
Normal file
30
buffteks/lib/python3.11/site-packages/limits/errors.py
Normal file
@@ -0,0 +1,30 @@
|
||||
"""
|
||||
errors and exceptions
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
class ConfigurationError(Exception):
|
||||
"""
|
||||
Error raised when a configuration problem is encountered
|
||||
"""
|
||||
|
||||
|
||||
class ConcurrentUpdateError(Exception):
|
||||
"""
|
||||
Error raised when an update to limit fails due to concurrent
|
||||
updates
|
||||
"""
|
||||
|
||||
def __init__(self, key: str, attempts: int) -> None:
|
||||
super().__init__(f"Unable to update {key} after {attempts} retries")
|
||||
|
||||
|
||||
class StorageError(Exception):
|
||||
"""
|
||||
Error raised when an error is encountered in a storage
|
||||
"""
|
||||
|
||||
def __init__(self, storage_error: Exception) -> None:
|
||||
self.storage_error = storage_error
|
||||
196
buffteks/lib/python3.11/site-packages/limits/limits.py
Normal file
196
buffteks/lib/python3.11/site-packages/limits/limits.py
Normal file
@@ -0,0 +1,196 @@
|
||||
""" """
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from functools import total_ordering
|
||||
|
||||
from limits.typing import ClassVar, NamedTuple, cast
|
||||
|
||||
|
||||
def safe_string(value: bytes | str | int | float) -> str:
|
||||
"""
|
||||
normalize a byte/str/int or float to a str
|
||||
"""
|
||||
|
||||
if isinstance(value, bytes):
|
||||
return value.decode()
|
||||
|
||||
return str(value)
|
||||
|
||||
|
||||
class Granularity(NamedTuple):
|
||||
seconds: int
|
||||
name: str
|
||||
|
||||
|
||||
TIME_TYPES = dict(
|
||||
day=Granularity(60 * 60 * 24, "day"),
|
||||
month=Granularity(60 * 60 * 24 * 30, "month"),
|
||||
year=Granularity(60 * 60 * 24 * 30 * 12, "year"),
|
||||
hour=Granularity(60 * 60, "hour"),
|
||||
minute=Granularity(60, "minute"),
|
||||
second=Granularity(1, "second"),
|
||||
)
|
||||
|
||||
GRANULARITIES: dict[str, type[RateLimitItem]] = {}
|
||||
|
||||
|
||||
class RateLimitItemMeta(type):
|
||||
def __new__(
|
||||
cls,
|
||||
name: str,
|
||||
parents: tuple[type, ...],
|
||||
dct: dict[str, Granularity | list[str]],
|
||||
) -> RateLimitItemMeta:
|
||||
if "__slots__" not in dct:
|
||||
dct["__slots__"] = []
|
||||
granularity = super().__new__(cls, name, parents, dct)
|
||||
|
||||
if "GRANULARITY" in dct:
|
||||
GRANULARITIES[dct["GRANULARITY"][1]] = cast(
|
||||
type[RateLimitItem], granularity
|
||||
)
|
||||
|
||||
return granularity
|
||||
|
||||
|
||||
# pylint: disable=no-member
|
||||
@total_ordering
|
||||
class RateLimitItem(metaclass=RateLimitItemMeta):
|
||||
"""
|
||||
defines a Rate limited resource which contains the characteristic
|
||||
namespace, amount and granularity multiples of the rate limiting window.
|
||||
|
||||
:param amount: the rate limit amount
|
||||
:param multiples: multiple of the 'per' :attr:`GRANULARITY`
|
||||
(e.g. 'n' per 'm' seconds)
|
||||
:param namespace: category for the specific rate limit
|
||||
"""
|
||||
|
||||
__slots__ = ["namespace", "amount", "multiples"]
|
||||
|
||||
GRANULARITY: ClassVar[Granularity]
|
||||
"""
|
||||
A tuple describing the granularity of this limit as
|
||||
(number of seconds, name)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, amount: int, multiples: int | None = 1, namespace: str = "LIMITER"
|
||||
):
|
||||
self.namespace = namespace
|
||||
self.amount = int(amount)
|
||||
self.multiples = int(multiples or 1)
|
||||
|
||||
@classmethod
|
||||
def check_granularity_string(cls, granularity_string: str) -> bool:
|
||||
"""
|
||||
Checks if this instance matches a *granularity_string*
|
||||
of type ``n per hour``, ``n per minute`` etc,
|
||||
by comparing with :attr:`GRANULARITY`
|
||||
|
||||
"""
|
||||
|
||||
return granularity_string.lower() in {
|
||||
cls.GRANULARITY.name,
|
||||
f"{cls.GRANULARITY.name}s", # allow plurals like days, hours etc.
|
||||
}
|
||||
|
||||
def get_expiry(self) -> int:
|
||||
"""
|
||||
:return: the duration the limit is enforced for in seconds.
|
||||
"""
|
||||
|
||||
return self.GRANULARITY.seconds * self.multiples
|
||||
|
||||
def key_for(self, *identifiers: bytes | str | int | float) -> str:
|
||||
"""
|
||||
Constructs a key for the current limit and any additional
|
||||
identifiers provided.
|
||||
|
||||
:param identifiers: a list of strings to append to the key
|
||||
:return: a string key identifying this resource with
|
||||
each identifier separated with a '/' delimiter.
|
||||
"""
|
||||
remainder = "/".join(
|
||||
[safe_string(k) for k in identifiers]
|
||||
+ [
|
||||
safe_string(self.amount),
|
||||
safe_string(self.multiples),
|
||||
self.GRANULARITY.name,
|
||||
]
|
||||
)
|
||||
|
||||
return f"{self.namespace}/{remainder}"
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
if isinstance(other, RateLimitItem):
|
||||
return (
|
||||
self.amount == other.amount
|
||||
and self.GRANULARITY == other.GRANULARITY
|
||||
and self.multiples == other.multiples
|
||||
)
|
||||
return False
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"{self.amount} per {self.multiples} {self.GRANULARITY.name}"
|
||||
|
||||
def __lt__(self, other: RateLimitItem) -> bool:
|
||||
return self.GRANULARITY.seconds < other.GRANULARITY.seconds
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return hash((self.namespace, self.amount, self.multiples, self.GRANULARITY))
|
||||
|
||||
|
||||
class RateLimitItemPerYear(RateLimitItem):
|
||||
"""
|
||||
per year rate limited resource.
|
||||
"""
|
||||
|
||||
GRANULARITY = TIME_TYPES["year"]
|
||||
"""A year"""
|
||||
|
||||
|
||||
class RateLimitItemPerMonth(RateLimitItem):
|
||||
"""
|
||||
per month rate limited resource.
|
||||
"""
|
||||
|
||||
GRANULARITY = TIME_TYPES["month"]
|
||||
"""A month"""
|
||||
|
||||
|
||||
class RateLimitItemPerDay(RateLimitItem):
|
||||
"""
|
||||
per day rate limited resource.
|
||||
"""
|
||||
|
||||
GRANULARITY = TIME_TYPES["day"]
|
||||
"""A day"""
|
||||
|
||||
|
||||
class RateLimitItemPerHour(RateLimitItem):
|
||||
"""
|
||||
per hour rate limited resource.
|
||||
"""
|
||||
|
||||
GRANULARITY = TIME_TYPES["hour"]
|
||||
"""An hour"""
|
||||
|
||||
|
||||
class RateLimitItemPerMinute(RateLimitItem):
|
||||
"""
|
||||
per minute rate limited resource.
|
||||
"""
|
||||
|
||||
GRANULARITY = TIME_TYPES["minute"]
|
||||
"""A minute"""
|
||||
|
||||
|
||||
class RateLimitItemPerSecond(RateLimitItem):
|
||||
"""
|
||||
per second rate limited resource.
|
||||
"""
|
||||
|
||||
GRANULARITY = TIME_TYPES["second"]
|
||||
"""A second"""
|
||||
@@ -0,0 +1,26 @@
|
||||
local timestamp = tonumber(ARGV[1])
|
||||
local limit = tonumber(ARGV[2])
|
||||
local expiry = tonumber(ARGV[3])
|
||||
local amount = tonumber(ARGV[4])
|
||||
|
||||
if amount > limit then
|
||||
return false
|
||||
end
|
||||
|
||||
local entry = redis.call('lindex', KEYS[1], limit - amount)
|
||||
|
||||
if entry and tonumber(entry) >= timestamp - expiry then
|
||||
return false
|
||||
end
|
||||
local entries = {}
|
||||
for i = 1, amount do
|
||||
entries[i] = timestamp
|
||||
end
|
||||
|
||||
for i=1,#entries,5000 do
|
||||
redis.call('lpush', KEYS[1], unpack(entries, i, math.min(i+4999, #entries)))
|
||||
end
|
||||
redis.call('ltrim', KEYS[1], 0, limit - 1)
|
||||
redis.call('expire', KEYS[1], expiry)
|
||||
|
||||
return true
|
||||
@@ -0,0 +1,45 @@
|
||||
-- Time is in milliseconds in this script: TTL, expiry...
|
||||
|
||||
local limit = tonumber(ARGV[1])
|
||||
local expiry = tonumber(ARGV[2]) * 1000
|
||||
local amount = tonumber(ARGV[3])
|
||||
|
||||
if amount > limit then
|
||||
return false
|
||||
end
|
||||
|
||||
local current_ttl = tonumber(redis.call('pttl', KEYS[2]))
|
||||
|
||||
if current_ttl > 0 and current_ttl < expiry then
|
||||
-- Current window expired, shift it to the previous window
|
||||
redis.call('rename', KEYS[2], KEYS[1])
|
||||
redis.call('set', KEYS[2], 0, 'PX', current_ttl + expiry)
|
||||
end
|
||||
|
||||
local previous_count = tonumber(redis.call('get', KEYS[1])) or 0
|
||||
local previous_ttl = tonumber(redis.call('pttl', KEYS[1])) or 0
|
||||
local current_count = tonumber(redis.call('get', KEYS[2])) or 0
|
||||
current_ttl = tonumber(redis.call('pttl', KEYS[2])) or 0
|
||||
|
||||
-- If the values don't exist yet, consider the TTL is 0
|
||||
if previous_ttl <= 0 then
|
||||
previous_ttl = 0
|
||||
end
|
||||
if current_ttl <= 0 then
|
||||
current_ttl = 0
|
||||
end
|
||||
local weighted_count = math.floor(previous_count * previous_ttl / expiry) + current_count
|
||||
|
||||
if (weighted_count + amount) > limit then
|
||||
return false
|
||||
end
|
||||
|
||||
-- If the current counter exists, increase its value
|
||||
if redis.call('exists', KEYS[2]) == 1 then
|
||||
redis.call('incrby', KEYS[2], amount)
|
||||
else
|
||||
-- Otherwise, set the value with twice the expiry time
|
||||
redis.call('set', KEYS[2], amount, 'PX', expiry * 2)
|
||||
end
|
||||
|
||||
return true
|
||||
@@ -0,0 +1,10 @@
|
||||
local keys = redis.call('keys', KEYS[1])
|
||||
local res = 0
|
||||
|
||||
for i=1,#keys,5000 do
|
||||
res = res + redis.call(
|
||||
'del', unpack(keys, i, math.min(i+4999, #keys))
|
||||
)
|
||||
end
|
||||
|
||||
return res
|
||||
@@ -0,0 +1,9 @@
|
||||
local current
|
||||
local amount = tonumber(ARGV[2])
|
||||
current = redis.call("incrby", KEYS[1], amount)
|
||||
|
||||
if tonumber(current) == amount then
|
||||
redis.call("expire", KEYS[1], ARGV[1])
|
||||
end
|
||||
|
||||
return current
|
||||
@@ -0,0 +1,30 @@
|
||||
local len = tonumber(ARGV[2])
|
||||
local expiry = tonumber(ARGV[1])
|
||||
|
||||
-- Binary search to find the oldest valid entry in the window
|
||||
local function oldest_entry(high, target)
|
||||
local low = 0
|
||||
local result = nil
|
||||
|
||||
while low <= high do
|
||||
local mid = math.floor((low + high) / 2)
|
||||
local val = tonumber(redis.call('lindex', KEYS[1], mid))
|
||||
|
||||
if val and val >= target then
|
||||
result = mid
|
||||
low = mid + 1
|
||||
else
|
||||
high = mid - 1
|
||||
end
|
||||
end
|
||||
|
||||
return result
|
||||
end
|
||||
|
||||
local index = oldest_entry(len - 1, expiry)
|
||||
|
||||
if index then
|
||||
local count = index + 1
|
||||
local oldest = tonumber(redis.call('lindex', KEYS[1], index))
|
||||
return {tostring(oldest), count}
|
||||
end
|
||||
@@ -0,0 +1,17 @@
|
||||
local expiry = tonumber(ARGV[1]) * 1000
|
||||
local previous_count = redis.call('get', KEYS[1])
|
||||
local previous_ttl = redis.call('pttl', KEYS[1])
|
||||
local current_count = redis.call('get', KEYS[2])
|
||||
local current_ttl = redis.call('pttl', KEYS[2])
|
||||
|
||||
if current_ttl > 0 and current_ttl < expiry then
|
||||
-- Current window expired, shift it to the previous window
|
||||
redis.call('rename', KEYS[2], KEYS[1])
|
||||
redis.call('set', KEYS[2], 0, 'PX', current_ttl + expiry)
|
||||
previous_count = redis.call('get', KEYS[1])
|
||||
previous_ttl = redis.call('pttl', KEYS[1])
|
||||
current_count = redis.call('get', KEYS[2])
|
||||
current_ttl = redis.call('pttl', KEYS[2])
|
||||
end
|
||||
|
||||
return {previous_count, previous_ttl, current_count, current_ttl}
|
||||
@@ -0,0 +1,80 @@
|
||||
"""
|
||||
Implementations of storage backends to be used with
|
||||
:class:`limits.strategies.RateLimiter` strategies
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import urllib
|
||||
|
||||
import limits # noqa
|
||||
|
||||
from ..errors import ConfigurationError
|
||||
from ..typing import TypeAlias, cast
|
||||
from .base import MovingWindowSupport, SlidingWindowCounterSupport, Storage
|
||||
from .memcached import MemcachedStorage
|
||||
from .memory import MemoryStorage
|
||||
from .mongodb import MongoDBStorage, MongoDBStorageBase
|
||||
from .redis import RedisStorage
|
||||
from .redis_cluster import RedisClusterStorage
|
||||
from .redis_sentinel import RedisSentinelStorage
|
||||
from .registry import SCHEMES
|
||||
|
||||
StorageTypes: TypeAlias = "Storage | limits.aio.storage.Storage"
|
||||
|
||||
|
||||
def storage_from_string(
|
||||
storage_string: str, **options: float | str | bool
|
||||
) -> StorageTypes:
|
||||
"""
|
||||
Factory function to get an instance of the storage class based
|
||||
on the uri of the storage. In most cases using it should be sufficient
|
||||
instead of directly instantiating the storage classes. for example::
|
||||
|
||||
from limits.storage import storage_from_string
|
||||
|
||||
memory = storage_from_string("memory://")
|
||||
memcached = storage_from_string("memcached://localhost:11211")
|
||||
redis = storage_from_string("redis://localhost:6379")
|
||||
|
||||
The same function can be used to construct the :ref:`storage:async storage`
|
||||
variants, for example::
|
||||
|
||||
from limits.storage import storage_from_string
|
||||
|
||||
memory = storage_from_string("async+memory://")
|
||||
memcached = storage_from_string("async+memcached://localhost:11211")
|
||||
redis = storage_from_string("async+redis://localhost:6379")
|
||||
|
||||
:param storage_string: a string of the form ``scheme://host:port``.
|
||||
More details about supported storage schemes can be found at
|
||||
:ref:`storage:storage scheme`
|
||||
:param options: all remaining keyword arguments are passed to the
|
||||
constructor matched by :paramref:`storage_string`.
|
||||
:raises ConfigurationError: when the :attr:`storage_string` cannot be
|
||||
mapped to a registered :class:`limits.storage.Storage`
|
||||
or :class:`limits.aio.storage.Storage` instance.
|
||||
|
||||
|
||||
"""
|
||||
scheme = urllib.parse.urlparse(storage_string).scheme
|
||||
|
||||
if scheme not in SCHEMES:
|
||||
raise ConfigurationError(f"unknown storage scheme : {storage_string}")
|
||||
|
||||
return cast(StorageTypes, SCHEMES[scheme](storage_string, **options))
|
||||
|
||||
|
||||
__all__ = [
|
||||
"MemcachedStorage",
|
||||
"MemoryStorage",
|
||||
"MongoDBStorage",
|
||||
"MongoDBStorageBase",
|
||||
"MovingWindowSupport",
|
||||
"RedisClusterStorage",
|
||||
"RedisSentinelStorage",
|
||||
"RedisStorage",
|
||||
"SlidingWindowCounterSupport",
|
||||
"Storage",
|
||||
"storage_from_string",
|
||||
]
|
||||
246
buffteks/lib/python3.11/site-packages/limits/storage/base.py
Normal file
246
buffteks/lib/python3.11/site-packages/limits/storage/base.py
Normal file
@@ -0,0 +1,246 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import functools
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from limits import errors
|
||||
from limits.storage.registry import StorageRegistry
|
||||
from limits.typing import (
|
||||
Any,
|
||||
Callable,
|
||||
P,
|
||||
R,
|
||||
cast,
|
||||
)
|
||||
from limits.util import LazyDependency
|
||||
|
||||
|
||||
def _wrap_errors(
|
||||
fn: Callable[P, R],
|
||||
) -> Callable[P, R]:
|
||||
@functools.wraps(fn)
|
||||
def inner(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
instance = cast(Storage, args[0])
|
||||
try:
|
||||
return fn(*args, **kwargs)
|
||||
except instance.base_exceptions as exc:
|
||||
if instance.wrap_exceptions:
|
||||
raise errors.StorageError(exc) from exc
|
||||
raise
|
||||
|
||||
return inner
|
||||
|
||||
|
||||
class Storage(LazyDependency, metaclass=StorageRegistry):
|
||||
"""
|
||||
Base class to extend when implementing a storage backend.
|
||||
"""
|
||||
|
||||
STORAGE_SCHEME: list[str] | None
|
||||
"""The storage schemes to register against this implementation"""
|
||||
|
||||
def __init_subclass__(cls, **kwargs: Any) -> None: # type: ignore[explicit-any]
|
||||
for method in {
|
||||
"incr",
|
||||
"get",
|
||||
"get_expiry",
|
||||
"check",
|
||||
"reset",
|
||||
"clear",
|
||||
}:
|
||||
setattr(cls, method, _wrap_errors(getattr(cls, method)))
|
||||
super().__init_subclass__(**kwargs)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
uri: str | None = None,
|
||||
wrap_exceptions: bool = False,
|
||||
**options: float | str | bool,
|
||||
):
|
||||
"""
|
||||
:param wrap_exceptions: Whether to wrap storage exceptions in
|
||||
:exc:`limits.errors.StorageError` before raising it.
|
||||
"""
|
||||
|
||||
super().__init__()
|
||||
self.wrap_exceptions = wrap_exceptions
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def base_exceptions(self) -> type[Exception] | tuple[type[Exception], ...]:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def incr(self, key: str, expiry: int, amount: int = 1) -> int:
|
||||
"""
|
||||
increments the counter for a given rate limit key
|
||||
|
||||
:param key: the key to increment
|
||||
:param expiry: amount in seconds for the key to expire in
|
||||
:param amount: the number to increment by
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def get(self, key: str) -> int:
|
||||
"""
|
||||
:param key: the key to get the counter value for
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def get_expiry(self, key: str) -> float:
|
||||
"""
|
||||
:param key: the key to get the expiry for
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def check(self) -> bool:
|
||||
"""
|
||||
check if storage is healthy
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def reset(self) -> int | None:
|
||||
"""
|
||||
reset storage to clear limits
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def clear(self, key: str) -> None:
|
||||
"""
|
||||
resets the rate limit key
|
||||
|
||||
:param key: the key to clear rate limits for
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class MovingWindowSupport(ABC):
|
||||
"""
|
||||
Abstract base class for storages that support
|
||||
the :ref:`strategies:moving window` strategy
|
||||
"""
|
||||
|
||||
def __init_subclass__(cls, **kwargs: Any) -> None: # type: ignore[explicit-any]
|
||||
for method in {
|
||||
"acquire_entry",
|
||||
"get_moving_window",
|
||||
}:
|
||||
setattr(
|
||||
cls,
|
||||
method,
|
||||
_wrap_errors(getattr(cls, method)),
|
||||
)
|
||||
super().__init_subclass__(**kwargs)
|
||||
|
||||
@abstractmethod
|
||||
def acquire_entry(self, key: str, limit: int, expiry: int, amount: int = 1) -> bool:
|
||||
"""
|
||||
:param key: rate limit key to acquire an entry in
|
||||
:param limit: amount of entries allowed
|
||||
:param expiry: expiry of the entry
|
||||
:param amount: the number of entries to acquire
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def get_moving_window(self, key: str, limit: int, expiry: int) -> tuple[float, int]:
|
||||
"""
|
||||
returns the starting point and the number of entries in the moving
|
||||
window
|
||||
|
||||
:param key: rate limit key
|
||||
:param expiry: expiry of entry
|
||||
:return: (start of window, number of acquired entries)
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class SlidingWindowCounterSupport(ABC):
|
||||
"""
|
||||
Abstract base class for storages that support
|
||||
the :ref:`strategies:sliding window counter` strategy.
|
||||
"""
|
||||
|
||||
def __init_subclass__(cls, **kwargs: Any) -> None: # type: ignore[explicit-any]
|
||||
for method in {
|
||||
"acquire_sliding_window_entry",
|
||||
"get_sliding_window",
|
||||
"clear_sliding_window",
|
||||
}:
|
||||
setattr(
|
||||
cls,
|
||||
method,
|
||||
_wrap_errors(getattr(cls, method)),
|
||||
)
|
||||
super().__init_subclass__(**kwargs)
|
||||
|
||||
@abstractmethod
|
||||
def acquire_sliding_window_entry(
|
||||
self, key: str, limit: int, expiry: int, amount: int = 1
|
||||
) -> bool:
|
||||
"""
|
||||
Acquire an entry if the weighted count of the current and previous
|
||||
windows is less than or equal to the limit
|
||||
|
||||
:param key: rate limit key to acquire an entry in
|
||||
:param limit: amount of entries allowed
|
||||
:param expiry: expiry of the entry
|
||||
:param amount: the number of entries to acquire
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def get_sliding_window(
|
||||
self, key: str, expiry: int
|
||||
) -> tuple[int, float, int, float]:
|
||||
"""
|
||||
Return the previous and current window information.
|
||||
|
||||
:param key: the rate limit key
|
||||
:param expiry: the rate limit expiry, needed to compute the key in some implementations
|
||||
:return: a tuple of (int, float, int, float) with the following information:
|
||||
- previous window counter
|
||||
- previous window TTL
|
||||
- current window counter
|
||||
- current window TTL
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def clear_sliding_window(self, key: str, expiry: int) -> None:
|
||||
"""
|
||||
Resets the rate limit key(s) for the sliding window
|
||||
|
||||
:param key: the key to clear rate limits for
|
||||
:param expiry: the rate limit expiry, needed to compute the key in some implemenations
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
class TimestampedSlidingWindow:
|
||||
"""Helper class for storage that support the sliding window counter, with timestamp based keys."""
|
||||
|
||||
@classmethod
|
||||
def sliding_window_keys(cls, key: str, expiry: int, at: float) -> tuple[str, str]:
|
||||
"""
|
||||
returns the previous and the current window's keys.
|
||||
|
||||
:param key: the key to get the window's keys from
|
||||
:param expiry: the expiry of the limit item, in seconds
|
||||
:param at: the timestamp to get the keys from. Default to now, ie ``time.time()``
|
||||
|
||||
Returns a tuple with the previous and the current key: (previous, current).
|
||||
|
||||
Example:
|
||||
- key = "mykey"
|
||||
- expiry = 60
|
||||
- at = 1738576292.6631825
|
||||
|
||||
The return value will be the tuple ``("mykey/28976271", "mykey/28976270")``.
|
||||
"""
|
||||
return f"{key}/{int((at - expiry) / expiry)}", f"{key}/{int(at / expiry)}"
|
||||
@@ -0,0 +1,305 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
import threading
|
||||
import time
|
||||
import urllib.parse
|
||||
from collections.abc import Iterable
|
||||
from math import ceil, floor
|
||||
from types import ModuleType
|
||||
|
||||
from limits.errors import ConfigurationError
|
||||
from limits.storage.base import (
|
||||
SlidingWindowCounterSupport,
|
||||
Storage,
|
||||
TimestampedSlidingWindow,
|
||||
)
|
||||
from limits.typing import (
|
||||
Any,
|
||||
Callable,
|
||||
MemcachedClientP,
|
||||
P,
|
||||
R,
|
||||
cast,
|
||||
)
|
||||
from limits.util import get_dependency
|
||||
|
||||
|
||||
class MemcachedStorage(Storage, SlidingWindowCounterSupport, TimestampedSlidingWindow):
|
||||
"""
|
||||
Rate limit storage with memcached as backend.
|
||||
|
||||
Depends on :pypi:`pymemcache`.
|
||||
"""
|
||||
|
||||
STORAGE_SCHEME = ["memcached"]
|
||||
"""The storage scheme for memcached"""
|
||||
DEPENDENCIES = ["pymemcache"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
uri: str,
|
||||
wrap_exceptions: bool = False,
|
||||
**options: str | Callable[[], MemcachedClientP],
|
||||
) -> None:
|
||||
"""
|
||||
:param uri: memcached location of the form
|
||||
``memcached://host:port,host:port``,
|
||||
``memcached:///var/tmp/path/to/sock``
|
||||
:param wrap_exceptions: Whether to wrap storage exceptions in
|
||||
:exc:`limits.errors.StorageError` before raising it.
|
||||
:param options: all remaining keyword arguments are passed
|
||||
directly to the constructor of :class:`pymemcache.client.base.PooledClient`
|
||||
or :class:`pymemcache.client.hash.HashClient` (if there are more than
|
||||
one hosts specified)
|
||||
:raise ConfigurationError: when :pypi:`pymemcache` is not available
|
||||
"""
|
||||
parsed = urllib.parse.urlparse(uri)
|
||||
self.hosts = []
|
||||
|
||||
for loc in parsed.netloc.strip().split(","):
|
||||
if not loc:
|
||||
continue
|
||||
host, port = loc.split(":")
|
||||
self.hosts.append((host, int(port)))
|
||||
else:
|
||||
# filesystem path to UDS
|
||||
|
||||
if parsed.path and not parsed.netloc and not parsed.port:
|
||||
self.hosts = [parsed.path] # type: ignore
|
||||
|
||||
self.dependency = self.dependencies["pymemcache"].module
|
||||
self.library = str(options.pop("library", "pymemcache.client"))
|
||||
self.cluster_library = str(
|
||||
options.pop("cluster_library", "pymemcache.client.hash")
|
||||
)
|
||||
self.client_getter = cast(
|
||||
Callable[[ModuleType, list[tuple[str, int]]], MemcachedClientP],
|
||||
options.pop("client_getter", self.get_client),
|
||||
)
|
||||
self.options = options
|
||||
|
||||
if not get_dependency(self.library):
|
||||
raise ConfigurationError(
|
||||
f"memcached prerequisite not available. please install {self.library}"
|
||||
) # pragma: no cover
|
||||
self.local_storage = threading.local()
|
||||
self.local_storage.storage = None
|
||||
super().__init__(uri, wrap_exceptions=wrap_exceptions)
|
||||
|
||||
@property
|
||||
def base_exceptions(
|
||||
self,
|
||||
) -> type[Exception] | tuple[type[Exception], ...]: # pragma: no cover
|
||||
return self.dependency.MemcacheError # type: ignore[no-any-return]
|
||||
|
||||
def get_client(
|
||||
self, module: ModuleType, hosts: list[tuple[str, int]], **kwargs: str
|
||||
) -> MemcachedClientP:
|
||||
"""
|
||||
returns a memcached client.
|
||||
|
||||
:param module: the memcached module
|
||||
:param hosts: list of memcached hosts
|
||||
"""
|
||||
|
||||
return cast(
|
||||
MemcachedClientP,
|
||||
(
|
||||
module.HashClient(hosts, **kwargs)
|
||||
if len(hosts) > 1
|
||||
else module.PooledClient(*hosts, **kwargs)
|
||||
),
|
||||
)
|
||||
|
||||
def call_memcached_func(
|
||||
self, func: Callable[P, R], *args: P.args, **kwargs: P.kwargs
|
||||
) -> R:
|
||||
if "noreply" in kwargs:
|
||||
argspec = inspect.getfullargspec(func)
|
||||
|
||||
if not ("noreply" in argspec.args or argspec.varkw):
|
||||
kwargs.pop("noreply")
|
||||
|
||||
return func(*args, **kwargs)
|
||||
|
||||
@property
|
||||
def storage(self) -> MemcachedClientP:
|
||||
"""
|
||||
lazily creates a memcached client instance using a thread local
|
||||
"""
|
||||
|
||||
if not (hasattr(self.local_storage, "storage") and self.local_storage.storage):
|
||||
dependency = get_dependency(
|
||||
self.cluster_library if len(self.hosts) > 1 else self.library
|
||||
)[0]
|
||||
|
||||
if not dependency:
|
||||
raise ConfigurationError(f"Unable to import {self.cluster_library}")
|
||||
self.local_storage.storage = self.client_getter(
|
||||
dependency, self.hosts, **self.options
|
||||
)
|
||||
|
||||
return cast(MemcachedClientP, self.local_storage.storage)
|
||||
|
||||
def get(self, key: str) -> int:
|
||||
"""
|
||||
:param key: the key to get the counter value for
|
||||
"""
|
||||
return int(self.storage.get(key, "0"))
|
||||
|
||||
def get_many(self, keys: Iterable[str]) -> dict[str, Any]: # type:ignore[explicit-any]
|
||||
"""
|
||||
Return multiple counters at once
|
||||
|
||||
:param keys: the keys to get the counter values for
|
||||
|
||||
:meta private:
|
||||
"""
|
||||
return self.storage.get_many(keys)
|
||||
|
||||
def clear(self, key: str) -> None:
|
||||
"""
|
||||
:param key: the key to clear rate limits for
|
||||
"""
|
||||
self.storage.delete(key)
|
||||
|
||||
def incr(
|
||||
self,
|
||||
key: str,
|
||||
expiry: float,
|
||||
amount: int = 1,
|
||||
set_expiration_key: bool = True,
|
||||
) -> int:
|
||||
"""
|
||||
increments the counter for a given rate limit key
|
||||
|
||||
:param key: the key to increment
|
||||
:param expiry: amount in seconds for the key to expire in
|
||||
window every hit.
|
||||
:param amount: the number to increment by
|
||||
:param set_expiration_key: set the expiration key with the expiration time if needed. If set to False, the key will still expire, but memcached cannot provide the expiration time.
|
||||
"""
|
||||
if (
|
||||
value := self.call_memcached_func(
|
||||
self.storage.incr, key, amount, noreply=False
|
||||
)
|
||||
) is not None:
|
||||
return value
|
||||
else:
|
||||
if not self.call_memcached_func(
|
||||
self.storage.add, key, amount, ceil(expiry), noreply=False
|
||||
):
|
||||
return self.storage.incr(key, amount) or amount
|
||||
else:
|
||||
if set_expiration_key:
|
||||
self.call_memcached_func(
|
||||
self.storage.set,
|
||||
self._expiration_key(key),
|
||||
expiry + time.time(),
|
||||
expire=ceil(expiry),
|
||||
noreply=False,
|
||||
)
|
||||
|
||||
return amount
|
||||
|
||||
def get_expiry(self, key: str) -> float:
|
||||
"""
|
||||
:param key: the key to get the expiry for
|
||||
"""
|
||||
|
||||
return float(self.storage.get(self._expiration_key(key)) or time.time())
|
||||
|
||||
def _expiration_key(self, key: str) -> str:
|
||||
"""
|
||||
Return the expiration key for the given counter key.
|
||||
|
||||
Memcached doesn't natively return the expiration time or TTL for a given key,
|
||||
so we implement the expiration time on a separate key.
|
||||
"""
|
||||
return key + "/expires"
|
||||
|
||||
def check(self) -> bool:
|
||||
"""
|
||||
Check if storage is healthy by calling the ``get`` command
|
||||
on the key ``limiter-check``
|
||||
"""
|
||||
try:
|
||||
self.call_memcached_func(self.storage.get, "limiter-check")
|
||||
|
||||
return True
|
||||
except: # noqa
|
||||
return False
|
||||
|
||||
def reset(self) -> int | None:
|
||||
raise NotImplementedError
|
||||
|
||||
def acquire_sliding_window_entry(
|
||||
self,
|
||||
key: str,
|
||||
limit: int,
|
||||
expiry: int,
|
||||
amount: int = 1,
|
||||
) -> bool:
|
||||
if amount > limit:
|
||||
return False
|
||||
now = time.time()
|
||||
previous_key, current_key = self.sliding_window_keys(key, expiry, now)
|
||||
previous_count, previous_ttl, current_count, _ = self._get_sliding_window_info(
|
||||
previous_key, current_key, expiry, now=now
|
||||
)
|
||||
weighted_count = previous_count * previous_ttl / expiry + current_count
|
||||
if floor(weighted_count) + amount > limit:
|
||||
return False
|
||||
else:
|
||||
# Hit, increase the current counter.
|
||||
# If the counter doesn't exist yet, set twice the theorical expiry.
|
||||
# We don't need the expiration key as it is estimated with the timestamps directly.
|
||||
current_count = self.incr(
|
||||
current_key, 2 * expiry, amount=amount, set_expiration_key=False
|
||||
)
|
||||
actualised_previous_ttl = min(0, previous_ttl - (time.time() - now))
|
||||
weighted_count = (
|
||||
previous_count * actualised_previous_ttl / expiry + current_count
|
||||
)
|
||||
if floor(weighted_count) > limit:
|
||||
# Another hit won the race condition: revert the incrementation and refuse this hit
|
||||
# Limitation: during high concurrency at the end of the window,
|
||||
# the counter is shifted and cannot be decremented, so less requests than expected are allowed.
|
||||
self.call_memcached_func(
|
||||
self.storage.decr,
|
||||
current_key,
|
||||
amount,
|
||||
noreply=True,
|
||||
)
|
||||
return False
|
||||
return True
|
||||
|
||||
def get_sliding_window(
|
||||
self, key: str, expiry: int
|
||||
) -> tuple[int, float, int, float]:
|
||||
now = time.time()
|
||||
previous_key, current_key = self.sliding_window_keys(key, expiry, now)
|
||||
return self._get_sliding_window_info(previous_key, current_key, expiry, now)
|
||||
|
||||
def clear_sliding_window(self, key: str, expiry: int) -> None:
|
||||
now = time.time()
|
||||
previous_key, current_key = self.sliding_window_keys(key, expiry, now)
|
||||
self.clear(previous_key)
|
||||
self.clear(current_key)
|
||||
|
||||
def _get_sliding_window_info(
|
||||
self, previous_key: str, current_key: str, expiry: int, now: float
|
||||
) -> tuple[int, float, int, float]:
|
||||
result = self.get_many([previous_key, current_key])
|
||||
previous_count, current_count = (
|
||||
int(result.get(previous_key, 0)),
|
||||
int(result.get(current_key, 0)),
|
||||
)
|
||||
|
||||
if previous_count == 0:
|
||||
previous_ttl = float(0)
|
||||
else:
|
||||
previous_ttl = (1 - (((now - expiry) / expiry) % 1)) * expiry
|
||||
current_ttl = (1 - ((now / expiry) % 1)) * expiry + expiry
|
||||
return previous_count, previous_ttl, current_count, current_ttl
|
||||
259
buffteks/lib/python3.11/site-packages/limits/storage/memory.py
Normal file
259
buffteks/lib/python3.11/site-packages/limits/storage/memory.py
Normal file
@@ -0,0 +1,259 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import bisect
|
||||
import threading
|
||||
import time
|
||||
from collections import Counter, defaultdict
|
||||
from math import floor
|
||||
|
||||
import limits.typing
|
||||
from limits.storage.base import (
|
||||
MovingWindowSupport,
|
||||
SlidingWindowCounterSupport,
|
||||
Storage,
|
||||
TimestampedSlidingWindow,
|
||||
)
|
||||
|
||||
|
||||
class Entry:
|
||||
def __init__(self, expiry: float) -> None:
|
||||
self.atime = time.time()
|
||||
self.expiry = self.atime + expiry
|
||||
|
||||
|
||||
class MemoryStorage(
|
||||
Storage, MovingWindowSupport, SlidingWindowCounterSupport, TimestampedSlidingWindow
|
||||
):
|
||||
"""
|
||||
rate limit storage using :class:`collections.Counter`
|
||||
as an in memory storage for fixed and sliding window strategies,
|
||||
and a simple list to implement moving window strategy.
|
||||
|
||||
"""
|
||||
|
||||
STORAGE_SCHEME = ["memory"]
|
||||
|
||||
def __init__(self, uri: str | None = None, wrap_exceptions: bool = False, **_: str):
|
||||
self.storage: limits.typing.Counter[str] = Counter()
|
||||
self.locks: defaultdict[str, threading.RLock] = defaultdict(threading.RLock)
|
||||
self.expirations: dict[str, float] = {}
|
||||
self.events: dict[str, list[Entry]] = {}
|
||||
self.timer: threading.Timer = threading.Timer(0.01, self.__expire_events)
|
||||
self.timer.start()
|
||||
super().__init__(uri, wrap_exceptions=wrap_exceptions, **_)
|
||||
|
||||
def __getstate__(self) -> dict[str, limits.typing.Any]: # type: ignore[explicit-any]
|
||||
state = self.__dict__.copy()
|
||||
del state["timer"]
|
||||
del state["locks"]
|
||||
return state
|
||||
|
||||
def __setstate__(self, state: dict[str, limits.typing.Any]) -> None: # type: ignore[explicit-any]
|
||||
self.__dict__.update(state)
|
||||
self.locks = defaultdict(threading.RLock)
|
||||
self.timer = threading.Timer(0.01, self.__expire_events)
|
||||
self.timer.start()
|
||||
|
||||
def __expire_events(self) -> None:
|
||||
for key in list(self.events.keys()):
|
||||
with self.locks[key]:
|
||||
if events := self.events.get(key, []):
|
||||
oldest = bisect.bisect_left(
|
||||
events, -time.time(), key=lambda event: -event.expiry
|
||||
)
|
||||
self.events[key] = self.events[key][:oldest]
|
||||
if not self.events.get(key, None):
|
||||
self.locks.pop(key, None)
|
||||
for key in list(self.expirations.keys()):
|
||||
if self.expirations[key] <= time.time():
|
||||
self.storage.pop(key, None)
|
||||
self.expirations.pop(key, None)
|
||||
self.locks.pop(key, None)
|
||||
|
||||
def __schedule_expiry(self) -> None:
|
||||
if not self.timer.is_alive():
|
||||
self.timer = threading.Timer(0.01, self.__expire_events)
|
||||
self.timer.start()
|
||||
|
||||
@property
|
||||
def base_exceptions(
|
||||
self,
|
||||
) -> type[Exception] | tuple[type[Exception], ...]: # pragma: no cover
|
||||
return ValueError
|
||||
|
||||
def incr(self, key: str, expiry: float, amount: int = 1) -> int:
|
||||
"""
|
||||
increments the counter for a given rate limit key
|
||||
|
||||
:param key: the key to increment
|
||||
:param expiry: amount in seconds for the key to expire in
|
||||
:param amount: the number to increment by
|
||||
"""
|
||||
self.get(key)
|
||||
self.__schedule_expiry()
|
||||
with self.locks[key]:
|
||||
self.storage[key] += amount
|
||||
if self.storage[key] == amount:
|
||||
self.expirations[key] = time.time() + expiry
|
||||
return self.storage.get(key, 0)
|
||||
|
||||
def decr(self, key: str, amount: int = 1) -> int:
|
||||
"""
|
||||
decrements the counter for a given rate limit key
|
||||
|
||||
:param key: the key to decrement
|
||||
:param amount: the number to decrement by
|
||||
"""
|
||||
self.get(key)
|
||||
self.__schedule_expiry()
|
||||
with self.locks[key]:
|
||||
self.storage[key] = max(self.storage[key] - amount, 0)
|
||||
|
||||
return self.storage.get(key, 0)
|
||||
|
||||
def get(self, key: str) -> int:
|
||||
"""
|
||||
:param key: the key to get the counter value for
|
||||
"""
|
||||
|
||||
if self.expirations.get(key, 0) <= time.time():
|
||||
self.storage.pop(key, None)
|
||||
self.expirations.pop(key, None)
|
||||
self.locks.pop(key, None)
|
||||
|
||||
return self.storage.get(key, 0)
|
||||
|
||||
def clear(self, key: str) -> None:
|
||||
"""
|
||||
:param key: the key to clear rate limits for
|
||||
"""
|
||||
self.storage.pop(key, None)
|
||||
self.expirations.pop(key, None)
|
||||
self.events.pop(key, None)
|
||||
self.locks.pop(key, None)
|
||||
|
||||
def acquire_entry(self, key: str, limit: int, expiry: int, amount: int = 1) -> bool:
|
||||
"""
|
||||
:param key: rate limit key to acquire an entry in
|
||||
:param limit: amount of entries allowed
|
||||
:param expiry: expiry of the entry
|
||||
:param amount: the number of entries to acquire
|
||||
"""
|
||||
if amount > limit:
|
||||
return False
|
||||
|
||||
self.__schedule_expiry()
|
||||
with self.locks[key]:
|
||||
self.events.setdefault(key, [])
|
||||
timestamp = time.time()
|
||||
try:
|
||||
entry = self.events[key][limit - amount]
|
||||
except IndexError:
|
||||
entry = None
|
||||
|
||||
if entry and entry.atime >= timestamp - expiry:
|
||||
return False
|
||||
else:
|
||||
self.events[key][:0] = [Entry(expiry)] * amount
|
||||
return True
|
||||
|
||||
def get_expiry(self, key: str) -> float:
|
||||
"""
|
||||
:param key: the key to get the expiry for
|
||||
"""
|
||||
|
||||
return self.expirations.get(key, time.time())
|
||||
|
||||
def get_moving_window(self, key: str, limit: int, expiry: int) -> tuple[float, int]:
|
||||
"""
|
||||
returns the starting point and the number of entries in the moving
|
||||
window
|
||||
|
||||
:param key: rate limit key
|
||||
:param expiry: expiry of entry
|
||||
:return: (start of window, number of acquired entries)
|
||||
"""
|
||||
timestamp = time.time()
|
||||
if events := self.events.get(key, []):
|
||||
oldest = bisect.bisect_left(
|
||||
events, -(timestamp - expiry), key=lambda entry: -entry.atime
|
||||
)
|
||||
return events[oldest - 1].atime, oldest
|
||||
return timestamp, 0
|
||||
|
||||
def acquire_sliding_window_entry(
|
||||
self,
|
||||
key: str,
|
||||
limit: int,
|
||||
expiry: int,
|
||||
amount: int = 1,
|
||||
) -> bool:
|
||||
if amount > limit:
|
||||
return False
|
||||
now = time.time()
|
||||
previous_key, current_key = self.sliding_window_keys(key, expiry, now)
|
||||
(
|
||||
previous_count,
|
||||
previous_ttl,
|
||||
current_count,
|
||||
_,
|
||||
) = self._get_sliding_window_info(previous_key, current_key, expiry, now)
|
||||
weighted_count = previous_count * previous_ttl / expiry + current_count
|
||||
if floor(weighted_count) + amount > limit:
|
||||
return False
|
||||
else:
|
||||
# Hit, increase the current counter.
|
||||
# If the counter doesn't exist yet, set twice the theorical expiry.
|
||||
current_count = self.incr(current_key, 2 * expiry, amount=amount)
|
||||
weighted_count = previous_count * previous_ttl / expiry + current_count
|
||||
if floor(weighted_count) > limit:
|
||||
# Another hit won the race condition: revert the incrementation and refuse this hit
|
||||
# Limitation: during high concurrency at the end of the window,
|
||||
# the counter is shifted and cannot be decremented, so less requests than expected are allowed.
|
||||
self.decr(current_key, amount)
|
||||
return False
|
||||
return True
|
||||
|
||||
def _get_sliding_window_info(
|
||||
self,
|
||||
previous_key: str,
|
||||
current_key: str,
|
||||
expiry: int,
|
||||
now: float,
|
||||
) -> tuple[int, float, int, float]:
|
||||
previous_count = self.get(previous_key)
|
||||
current_count = self.get(current_key)
|
||||
if previous_count == 0:
|
||||
previous_ttl = float(0)
|
||||
else:
|
||||
previous_ttl = (1 - (((now - expiry) / expiry) % 1)) * expiry
|
||||
current_ttl = (1 - ((now / expiry) % 1)) * expiry + expiry
|
||||
return previous_count, previous_ttl, current_count, current_ttl
|
||||
|
||||
def get_sliding_window(
|
||||
self, key: str, expiry: int
|
||||
) -> tuple[int, float, int, float]:
|
||||
now = time.time()
|
||||
previous_key, current_key = self.sliding_window_keys(key, expiry, now)
|
||||
return self._get_sliding_window_info(previous_key, current_key, expiry, now)
|
||||
|
||||
def clear_sliding_window(self, key: str, expiry: int) -> None:
|
||||
now = time.time()
|
||||
previous_key, current_key = self.sliding_window_keys(key, expiry, now)
|
||||
self.clear(previous_key)
|
||||
self.clear(current_key)
|
||||
|
||||
def check(self) -> bool:
|
||||
"""
|
||||
check if storage is healthy
|
||||
"""
|
||||
|
||||
return True
|
||||
|
||||
def reset(self) -> int | None:
|
||||
num_items = max(len(self.storage), len(self.events))
|
||||
self.storage.clear()
|
||||
self.expirations.clear()
|
||||
self.events.clear()
|
||||
self.locks.clear()
|
||||
return num_items
|
||||
492
buffteks/lib/python3.11/site-packages/limits/storage/mongodb.py
Normal file
492
buffteks/lib/python3.11/site-packages/limits/storage/mongodb.py
Normal file
@@ -0,0 +1,492 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import datetime
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from deprecated.sphinx import versionadded, versionchanged
|
||||
|
||||
from limits.typing import (
|
||||
MongoClient,
|
||||
MongoCollection,
|
||||
MongoDatabase,
|
||||
cast,
|
||||
)
|
||||
|
||||
from ..util import get_dependency
|
||||
from .base import MovingWindowSupport, SlidingWindowCounterSupport, Storage
|
||||
|
||||
|
||||
class MongoDBStorageBase(
|
||||
Storage, MovingWindowSupport, SlidingWindowCounterSupport, ABC
|
||||
):
|
||||
"""
|
||||
Rate limit storage with MongoDB as backend.
|
||||
|
||||
Depends on :pypi:`pymongo`.
|
||||
"""
|
||||
|
||||
DEPENDENCIES = ["pymongo"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
uri: str,
|
||||
database_name: str = "limits",
|
||||
counter_collection_name: str = "counters",
|
||||
window_collection_name: str = "windows",
|
||||
wrap_exceptions: bool = False,
|
||||
**options: int | str | bool,
|
||||
) -> None:
|
||||
"""
|
||||
:param uri: uri of the form ``mongodb://[user:password]@host:port?...``,
|
||||
This uri is passed directly to :class:`~pymongo.mongo_client.MongoClient`
|
||||
:param database_name: The database to use for storing the rate limit
|
||||
collections.
|
||||
:param counter_collection_name: The collection name to use for individual counters
|
||||
used in fixed window strategies
|
||||
:param window_collection_name: The collection name to use for sliding & moving window
|
||||
storage
|
||||
:param wrap_exceptions: Whether to wrap storage exceptions in
|
||||
:exc:`limits.errors.StorageError` before raising it.
|
||||
:param options: all remaining keyword arguments are passed to the
|
||||
constructor of :class:`~pymongo.mongo_client.MongoClient`
|
||||
:raise ConfigurationError: when the :pypi:`pymongo` library is not available
|
||||
"""
|
||||
|
||||
super().__init__(uri, wrap_exceptions=wrap_exceptions, **options)
|
||||
self._database_name = database_name
|
||||
self._collection_mapping = {
|
||||
"counters": counter_collection_name,
|
||||
"windows": window_collection_name,
|
||||
}
|
||||
self.lib = self.dependencies["pymongo"].module
|
||||
self.lib_errors, _ = get_dependency("pymongo.errors")
|
||||
self._storage_uri = uri
|
||||
self._storage_options = options
|
||||
self._storage: MongoClient | None = None
|
||||
|
||||
@property
|
||||
def storage(self) -> MongoClient:
|
||||
if self._storage is None:
|
||||
self._storage = self._init_mongo_client(
|
||||
self._storage_uri, **self._storage_options
|
||||
)
|
||||
self.__initialize_database()
|
||||
return self._storage
|
||||
|
||||
@property
|
||||
def _database(self) -> MongoDatabase:
|
||||
return self.storage[self._database_name]
|
||||
|
||||
@property
|
||||
def counters(self) -> MongoCollection:
|
||||
return self._database[self._collection_mapping["counters"]]
|
||||
|
||||
@property
|
||||
def windows(self) -> MongoCollection:
|
||||
return self._database[self._collection_mapping["windows"]]
|
||||
|
||||
@abstractmethod
|
||||
def _init_mongo_client(
|
||||
self, uri: str | None, **options: int | str | bool
|
||||
) -> MongoClient:
|
||||
raise NotImplementedError()
|
||||
|
||||
@property
|
||||
def base_exceptions(
|
||||
self,
|
||||
) -> type[Exception] | tuple[type[Exception], ...]: # pragma: no cover
|
||||
return self.lib_errors.PyMongoError # type: ignore
|
||||
|
||||
def __initialize_database(self) -> None:
|
||||
self.counters.create_index("expireAt", expireAfterSeconds=0)
|
||||
self.windows.create_index("expireAt", expireAfterSeconds=0)
|
||||
|
||||
def reset(self) -> int | None:
|
||||
"""
|
||||
Delete all rate limit keys in the rate limit collections (counters, windows)
|
||||
"""
|
||||
num_keys = self.counters.count_documents({}) + self.windows.count_documents({})
|
||||
self.counters.drop()
|
||||
self.windows.drop()
|
||||
|
||||
return int(num_keys)
|
||||
|
||||
def clear(self, key: str) -> None:
|
||||
"""
|
||||
:param key: the key to clear rate limits for
|
||||
"""
|
||||
self.counters.find_one_and_delete({"_id": key})
|
||||
self.windows.find_one_and_delete({"_id": key})
|
||||
|
||||
def get_expiry(self, key: str) -> float:
|
||||
"""
|
||||
:param key: the key to get the expiry for
|
||||
"""
|
||||
counter = self.counters.find_one({"_id": key})
|
||||
return (
|
||||
(counter["expireAt"] if counter else datetime.datetime.now())
|
||||
.replace(tzinfo=datetime.timezone.utc)
|
||||
.timestamp()
|
||||
)
|
||||
|
||||
def get(self, key: str) -> int:
|
||||
"""
|
||||
:param key: the key to get the counter value for
|
||||
"""
|
||||
counter = self.counters.find_one(
|
||||
{
|
||||
"_id": key,
|
||||
"expireAt": {"$gte": datetime.datetime.now(datetime.timezone.utc)},
|
||||
},
|
||||
projection=["count"],
|
||||
)
|
||||
|
||||
return counter and counter["count"] or 0
|
||||
|
||||
def incr(self, key: str, expiry: int, amount: int = 1) -> int:
|
||||
"""
|
||||
increments the counter for a given rate limit key
|
||||
|
||||
:param key: the key to increment
|
||||
:param expiry: amount in seconds for the key to expire in
|
||||
:param amount: the number to increment by
|
||||
"""
|
||||
expiration = datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(
|
||||
seconds=expiry
|
||||
)
|
||||
|
||||
return int(
|
||||
self.counters.find_one_and_update(
|
||||
{"_id": key},
|
||||
[
|
||||
{
|
||||
"$set": {
|
||||
"count": {
|
||||
"$cond": {
|
||||
"if": {"$lt": ["$expireAt", "$$NOW"]},
|
||||
"then": amount,
|
||||
"else": {"$add": ["$count", amount]},
|
||||
}
|
||||
},
|
||||
"expireAt": {
|
||||
"$cond": {
|
||||
"if": {"$lt": ["$expireAt", "$$NOW"]},
|
||||
"then": expiration,
|
||||
"else": "$expireAt",
|
||||
}
|
||||
},
|
||||
}
|
||||
},
|
||||
],
|
||||
upsert=True,
|
||||
projection=["count"],
|
||||
return_document=self.lib.ReturnDocument.AFTER,
|
||||
)["count"]
|
||||
)
|
||||
|
||||
def check(self) -> bool:
|
||||
"""
|
||||
Check if storage is healthy by calling :meth:`pymongo.mongo_client.MongoClient.server_info`
|
||||
"""
|
||||
try:
|
||||
self.storage.server_info()
|
||||
|
||||
return True
|
||||
except: # noqa: E722
|
||||
return False
|
||||
|
||||
def get_moving_window(self, key: str, limit: int, expiry: int) -> tuple[float, int]:
|
||||
"""
|
||||
returns the starting point and the number of entries in the moving
|
||||
window
|
||||
|
||||
:param key: rate limit key
|
||||
:param expiry: expiry of entry
|
||||
:return: (start of window, number of acquired entries)
|
||||
"""
|
||||
timestamp = time.time()
|
||||
if result := list(
|
||||
self.windows.aggregate(
|
||||
[
|
||||
{"$match": {"_id": key}},
|
||||
{
|
||||
"$project": {
|
||||
"filteredEntries": {
|
||||
"$filter": {
|
||||
"input": "$entries",
|
||||
"as": "entry",
|
||||
"cond": {"$gte": ["$$entry", timestamp - expiry]},
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"$project": {
|
||||
"min": {"$min": "$filteredEntries"},
|
||||
"count": {"$size": "$filteredEntries"},
|
||||
}
|
||||
},
|
||||
]
|
||||
)
|
||||
):
|
||||
return result[0]["min"], result[0]["count"]
|
||||
return timestamp, 0
|
||||
|
||||
def acquire_entry(self, key: str, limit: int, expiry: int, amount: int = 1) -> bool:
|
||||
"""
|
||||
:param key: rate limit key to acquire an entry in
|
||||
:param limit: amount of entries allowed
|
||||
:param expiry: expiry of the entry
|
||||
:param amount: the number of entries to acquire
|
||||
"""
|
||||
if amount > limit:
|
||||
return False
|
||||
|
||||
timestamp = time.time()
|
||||
try:
|
||||
updates: dict[
|
||||
str,
|
||||
dict[str, datetime.datetime | dict[str, list[float] | int]],
|
||||
] = {
|
||||
"$push": {
|
||||
"entries": {
|
||||
"$each": [timestamp] * amount,
|
||||
"$position": 0,
|
||||
"$slice": limit,
|
||||
}
|
||||
},
|
||||
"$set": {
|
||||
"expireAt": (
|
||||
datetime.datetime.now(datetime.timezone.utc)
|
||||
+ datetime.timedelta(seconds=expiry)
|
||||
)
|
||||
},
|
||||
}
|
||||
|
||||
self.windows.update_one(
|
||||
{
|
||||
"_id": key,
|
||||
f"entries.{limit - amount}": {"$not": {"$gte": timestamp - expiry}},
|
||||
},
|
||||
updates,
|
||||
upsert=True,
|
||||
)
|
||||
|
||||
return True
|
||||
except self.lib.errors.DuplicateKeyError:
|
||||
return False
|
||||
|
||||
def get_sliding_window(
|
||||
self, key: str, expiry: int
|
||||
) -> tuple[int, float, int, float]:
|
||||
expiry_ms = expiry * 1000
|
||||
if result := self.windows.find_one_and_update(
|
||||
{"_id": key},
|
||||
[
|
||||
{
|
||||
"$set": {
|
||||
"previousCount": {
|
||||
"$cond": {
|
||||
"if": {
|
||||
"$lte": [
|
||||
{"$subtract": ["$expireAt", "$$NOW"]},
|
||||
expiry_ms,
|
||||
]
|
||||
},
|
||||
"then": {"$ifNull": ["$currentCount", 0]},
|
||||
"else": {"$ifNull": ["$previousCount", 0]},
|
||||
}
|
||||
},
|
||||
"currentCount": {
|
||||
"$cond": {
|
||||
"if": {
|
||||
"$lte": [
|
||||
{"$subtract": ["$expireAt", "$$NOW"]},
|
||||
expiry_ms,
|
||||
]
|
||||
},
|
||||
"then": 0,
|
||||
"else": {"$ifNull": ["$currentCount", 0]},
|
||||
}
|
||||
},
|
||||
"expireAt": {
|
||||
"$cond": {
|
||||
"if": {
|
||||
"$lte": [
|
||||
{"$subtract": ["$expireAt", "$$NOW"]},
|
||||
expiry_ms,
|
||||
]
|
||||
},
|
||||
"then": {
|
||||
"$add": ["$expireAt", expiry_ms],
|
||||
},
|
||||
"else": "$expireAt",
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
],
|
||||
return_document=self.lib.ReturnDocument.AFTER,
|
||||
projection=["currentCount", "previousCount", "expireAt"],
|
||||
):
|
||||
expires_at = (
|
||||
(result["expireAt"].replace(tzinfo=datetime.timezone.utc).timestamp())
|
||||
if result.get("expireAt")
|
||||
else time.time()
|
||||
)
|
||||
current_ttl = max(0, expires_at - time.time())
|
||||
prev_ttl = max(0, current_ttl - expiry if result["previousCount"] else 0)
|
||||
|
||||
return (
|
||||
result["previousCount"],
|
||||
prev_ttl,
|
||||
result["currentCount"],
|
||||
current_ttl,
|
||||
)
|
||||
return 0, 0.0, 0, 0.0
|
||||
|
||||
def acquire_sliding_window_entry(
|
||||
self, key: str, limit: int, expiry: int, amount: int = 1
|
||||
) -> bool:
|
||||
expiry_ms = expiry * 1000
|
||||
result = self.windows.find_one_and_update(
|
||||
{"_id": key},
|
||||
[
|
||||
{
|
||||
"$set": {
|
||||
"previousCount": {
|
||||
"$cond": {
|
||||
"if": {
|
||||
"$lte": [
|
||||
{"$subtract": ["$expireAt", "$$NOW"]},
|
||||
expiry_ms,
|
||||
]
|
||||
},
|
||||
"then": {"$ifNull": ["$currentCount", 0]},
|
||||
"else": {"$ifNull": ["$previousCount", 0]},
|
||||
}
|
||||
},
|
||||
}
|
||||
},
|
||||
{
|
||||
"$set": {
|
||||
"currentCount": {
|
||||
"$cond": {
|
||||
"if": {
|
||||
"$lte": [
|
||||
{"$subtract": ["$expireAt", "$$NOW"]},
|
||||
expiry_ms,
|
||||
]
|
||||
},
|
||||
"then": 0,
|
||||
"else": {"$ifNull": ["$currentCount", 0]},
|
||||
}
|
||||
},
|
||||
"expireAt": {
|
||||
"$cond": {
|
||||
"if": {
|
||||
"$lte": [
|
||||
{"$subtract": ["$expireAt", "$$NOW"]},
|
||||
expiry_ms,
|
||||
]
|
||||
},
|
||||
"then": {
|
||||
"$cond": {
|
||||
"if": {"$gt": ["$expireAt", 0]},
|
||||
"then": {"$add": ["$expireAt", expiry_ms]},
|
||||
"else": {"$add": ["$$NOW", 2 * expiry_ms]},
|
||||
}
|
||||
},
|
||||
"else": "$expireAt",
|
||||
}
|
||||
},
|
||||
}
|
||||
},
|
||||
{
|
||||
"$set": {
|
||||
"curWeightedCount": {
|
||||
"$floor": {
|
||||
"$add": [
|
||||
{
|
||||
"$multiply": [
|
||||
"$previousCount",
|
||||
{
|
||||
"$divide": [
|
||||
{
|
||||
"$max": [
|
||||
0,
|
||||
{
|
||||
"$subtract": [
|
||||
"$expireAt",
|
||||
{
|
||||
"$add": [
|
||||
"$$NOW",
|
||||
expiry_ms,
|
||||
]
|
||||
},
|
||||
]
|
||||
},
|
||||
]
|
||||
},
|
||||
expiry_ms,
|
||||
]
|
||||
},
|
||||
]
|
||||
},
|
||||
"$currentCount",
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"$set": {
|
||||
"currentCount": {
|
||||
"$cond": {
|
||||
"if": {
|
||||
"$lte": [
|
||||
{"$add": ["$curWeightedCount", amount]},
|
||||
limit,
|
||||
]
|
||||
},
|
||||
"then": {"$add": ["$currentCount", amount]},
|
||||
"else": "$currentCount",
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"$set": {
|
||||
"_acquired": {
|
||||
"$lte": [{"$add": ["$curWeightedCount", amount]}, limit]
|
||||
}
|
||||
}
|
||||
},
|
||||
{"$unset": ["curWeightedCount"]},
|
||||
],
|
||||
return_document=self.lib.ReturnDocument.AFTER,
|
||||
upsert=True,
|
||||
)
|
||||
return cast(bool, result["_acquired"])
|
||||
|
||||
def clear_sliding_window(self, key: str, expiry: int) -> None:
|
||||
return self.clear(key)
|
||||
|
||||
def __del__(self) -> None:
|
||||
if self.storage:
|
||||
self.storage.close()
|
||||
|
||||
|
||||
@versionadded(version="2.1")
|
||||
@versionchanged(
|
||||
version="3.14.0",
|
||||
reason="Added option to select custom collection names for windows & counters",
|
||||
)
|
||||
class MongoDBStorage(MongoDBStorageBase):
|
||||
STORAGE_SCHEME = ["mongodb", "mongodb+srv"]
|
||||
|
||||
def _init_mongo_client(
|
||||
self, uri: str | None, **options: int | str | bool
|
||||
) -> MongoClient:
|
||||
return cast(MongoClient, self.lib.MongoClient(uri, **options))
|
||||
317
buffteks/lib/python3.11/site-packages/limits/storage/redis.py
Normal file
317
buffteks/lib/python3.11/site-packages/limits/storage/redis.py
Normal file
@@ -0,0 +1,317 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from typing import TYPE_CHECKING, cast
|
||||
|
||||
from deprecated.sphinx import versionchanged
|
||||
from packaging.version import Version
|
||||
|
||||
from limits.typing import Literal, RedisClient
|
||||
|
||||
from ..util import get_package_data
|
||||
from .base import MovingWindowSupport, SlidingWindowCounterSupport, Storage
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import redis
|
||||
|
||||
|
||||
@versionchanged(
|
||||
version="4.3",
|
||||
reason=(
|
||||
"Added support for using the redis client from :pypi:`valkey`"
|
||||
" if :paramref:`uri` has the ``valkey://`` schema"
|
||||
),
|
||||
)
|
||||
class RedisStorage(Storage, MovingWindowSupport, SlidingWindowCounterSupport):
|
||||
"""
|
||||
Rate limit storage with redis as backend.
|
||||
|
||||
Depends on :pypi:`redis` (or :pypi:`valkey` if :paramref:`uri` starts with
|
||||
``valkey://``)
|
||||
"""
|
||||
|
||||
STORAGE_SCHEME = [
|
||||
"redis",
|
||||
"rediss",
|
||||
"redis+unix",
|
||||
"valkey",
|
||||
"valkeys",
|
||||
"valkey+unix",
|
||||
]
|
||||
"""The storage scheme for redis"""
|
||||
|
||||
DEPENDENCIES = {"redis": Version("3.0"), "valkey": Version("6.0")}
|
||||
|
||||
RES_DIR = "resources/redis/lua_scripts"
|
||||
|
||||
SCRIPT_MOVING_WINDOW = get_package_data(f"{RES_DIR}/moving_window.lua")
|
||||
SCRIPT_ACQUIRE_MOVING_WINDOW = get_package_data(
|
||||
f"{RES_DIR}/acquire_moving_window.lua"
|
||||
)
|
||||
SCRIPT_CLEAR_KEYS = get_package_data(f"{RES_DIR}/clear_keys.lua")
|
||||
SCRIPT_INCR_EXPIRE = get_package_data(f"{RES_DIR}/incr_expire.lua")
|
||||
|
||||
SCRIPT_SLIDING_WINDOW = get_package_data(f"{RES_DIR}/sliding_window.lua")
|
||||
SCRIPT_ACQUIRE_SLIDING_WINDOW = get_package_data(
|
||||
f"{RES_DIR}/acquire_sliding_window.lua"
|
||||
)
|
||||
|
||||
lua_moving_window: redis.commands.core.Script
|
||||
lua_acquire_moving_window: redis.commands.core.Script
|
||||
lua_sliding_window: redis.commands.core.Script
|
||||
lua_acquire_sliding_window: redis.commands.core.Script
|
||||
|
||||
PREFIX = "LIMITS"
|
||||
target_server: Literal["redis", "valkey"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
uri: str,
|
||||
connection_pool: redis.connection.ConnectionPool | None = None,
|
||||
key_prefix: str = PREFIX,
|
||||
wrap_exceptions: bool = False,
|
||||
**options: float | str | bool,
|
||||
) -> None:
|
||||
"""
|
||||
:param uri: uri of the form ``redis://[:password]@host:port``,
|
||||
``redis://[:password]@host:port/db``,
|
||||
``rediss://[:password]@host:port``, ``redis+unix:///path/to/sock`` etc.
|
||||
This uri is passed directly to :func:`redis.from_url` except for the
|
||||
case of ``redis+unix://`` where it is replaced with ``unix://``.
|
||||
|
||||
If the uri scheme is ``valkey`` the implementation used will be from
|
||||
:pypi:`valkey`.
|
||||
:param connection_pool: if provided, the redis client is initialized with
|
||||
the connection pool and any other params passed as :paramref:`options`
|
||||
:param key_prefix: the prefix for each key created in redis
|
||||
:param wrap_exceptions: Whether to wrap storage exceptions in
|
||||
:exc:`limits.errors.StorageError` before raising it.
|
||||
:param options: all remaining keyword arguments are passed
|
||||
directly to the constructor of :class:`redis.Redis`
|
||||
:raise ConfigurationError: when the :pypi:`redis` library is not available
|
||||
"""
|
||||
super().__init__(uri, wrap_exceptions=wrap_exceptions, **options)
|
||||
self.key_prefix = key_prefix
|
||||
self.target_server = "valkey" if uri.startswith("valkey") else "redis"
|
||||
self.dependency = self.dependencies[self.target_server].module
|
||||
|
||||
uri = uri.replace(f"{self.target_server}+unix", "unix")
|
||||
|
||||
if not connection_pool:
|
||||
self.storage = self.dependency.from_url(uri, **options)
|
||||
else:
|
||||
if self.target_server == "redis":
|
||||
self.storage = self.dependency.Redis(
|
||||
connection_pool=connection_pool, **options
|
||||
)
|
||||
else:
|
||||
self.storage = self.dependency.Valkey(
|
||||
connection_pool=connection_pool, **options
|
||||
)
|
||||
self.initialize_storage(uri)
|
||||
|
||||
@property
|
||||
def base_exceptions(
|
||||
self,
|
||||
) -> type[Exception] | tuple[type[Exception], ...]: # pragma: no cover
|
||||
return ( # type: ignore[no-any-return]
|
||||
self.dependency.RedisError
|
||||
if self.target_server == "redis"
|
||||
else self.dependency.ValkeyError
|
||||
)
|
||||
|
||||
def initialize_storage(self, _uri: str) -> None:
|
||||
self.lua_moving_window = self.get_connection().register_script(
|
||||
self.SCRIPT_MOVING_WINDOW
|
||||
)
|
||||
self.lua_acquire_moving_window = self.get_connection().register_script(
|
||||
self.SCRIPT_ACQUIRE_MOVING_WINDOW
|
||||
)
|
||||
self.lua_clear_keys = self.get_connection().register_script(
|
||||
self.SCRIPT_CLEAR_KEYS
|
||||
)
|
||||
self.lua_incr_expire = self.get_connection().register_script(
|
||||
self.SCRIPT_INCR_EXPIRE
|
||||
)
|
||||
self.lua_sliding_window = self.get_connection().register_script(
|
||||
self.SCRIPT_SLIDING_WINDOW
|
||||
)
|
||||
self.lua_acquire_sliding_window = self.get_connection().register_script(
|
||||
self.SCRIPT_ACQUIRE_SLIDING_WINDOW
|
||||
)
|
||||
|
||||
def get_connection(self, readonly: bool = False) -> RedisClient:
|
||||
return cast(RedisClient, self.storage)
|
||||
|
||||
def _current_window_key(self, key: str) -> str:
|
||||
"""
|
||||
Return the current window's storage key (Sliding window strategy)
|
||||
|
||||
Contrary to other strategies that have one key per rate limit item,
|
||||
this strategy has two keys per rate limit item than must be on the same machine.
|
||||
To keep the current key and the previous key on the same Redis cluster node,
|
||||
curly braces are added.
|
||||
|
||||
Eg: "{constructed_key}"
|
||||
"""
|
||||
return f"{{{key}}}"
|
||||
|
||||
def _previous_window_key(self, key: str) -> str:
|
||||
"""
|
||||
Return the previous window's storage key (Sliding window strategy).
|
||||
|
||||
Curvy braces are added on the common pattern with the current window's key,
|
||||
so the current and the previous key are stored on the same Redis cluster node.
|
||||
|
||||
Eg: "{constructed_key}/-1"
|
||||
"""
|
||||
return f"{self._current_window_key(key)}/-1"
|
||||
|
||||
def prefixed_key(self, key: str) -> str:
|
||||
return f"{self.key_prefix}:{key}"
|
||||
|
||||
def get_moving_window(self, key: str, limit: int, expiry: int) -> tuple[float, int]:
|
||||
"""
|
||||
returns the starting point and the number of entries in the moving
|
||||
window
|
||||
|
||||
:param key: rate limit key
|
||||
:param expiry: expiry of entry
|
||||
:return: (start of window, number of acquired entries)
|
||||
"""
|
||||
key = self.prefixed_key(key)
|
||||
timestamp = time.time()
|
||||
if window := self.lua_moving_window([key], [timestamp - expiry, limit]):
|
||||
return float(window[0]), window[1]
|
||||
|
||||
return timestamp, 0
|
||||
|
||||
def get_sliding_window(
|
||||
self, key: str, expiry: int
|
||||
) -> tuple[int, float, int, float]:
|
||||
previous_key = self.prefixed_key(self._previous_window_key(key))
|
||||
current_key = self.prefixed_key(self._current_window_key(key))
|
||||
if window := self.lua_sliding_window([previous_key, current_key], [expiry]):
|
||||
return (
|
||||
int(window[0] or 0),
|
||||
max(0, float(window[1] or 0)) / 1000,
|
||||
int(window[2] or 0),
|
||||
max(0, float(window[3] or 0)) / 1000,
|
||||
)
|
||||
return 0, 0.0, 0, 0.0
|
||||
|
||||
def clear_sliding_window(self, key: str, expiry: int) -> None:
|
||||
previous_key = self._previous_window_key(key)
|
||||
current_key = self._current_window_key(key)
|
||||
self.clear(previous_key)
|
||||
self.clear(current_key)
|
||||
|
||||
def incr(
|
||||
self,
|
||||
key: str,
|
||||
expiry: int,
|
||||
amount: int = 1,
|
||||
) -> int:
|
||||
"""
|
||||
increments the counter for a given rate limit key
|
||||
|
||||
|
||||
:param key: the key to increment
|
||||
:param expiry: amount in seconds for the key to expire in
|
||||
:param amount: the number to increment by
|
||||
"""
|
||||
key = self.prefixed_key(key)
|
||||
return int(self.lua_incr_expire([key], [expiry, amount]))
|
||||
|
||||
def get(self, key: str) -> int:
|
||||
"""
|
||||
|
||||
:param key: the key to get the counter value for
|
||||
"""
|
||||
|
||||
key = self.prefixed_key(key)
|
||||
return int(self.get_connection(True).get(key) or 0)
|
||||
|
||||
def clear(self, key: str) -> None:
|
||||
"""
|
||||
:param key: the key to clear rate limits for
|
||||
"""
|
||||
key = self.prefixed_key(key)
|
||||
self.get_connection().delete(key)
|
||||
|
||||
def acquire_entry(
|
||||
self,
|
||||
key: str,
|
||||
limit: int,
|
||||
expiry: int,
|
||||
amount: int = 1,
|
||||
) -> bool:
|
||||
"""
|
||||
:param key: rate limit key to acquire an entry in
|
||||
:param limit: amount of entries allowed
|
||||
:param expiry: expiry of the entry
|
||||
|
||||
:param amount: the number of entries to acquire
|
||||
"""
|
||||
key = self.prefixed_key(key)
|
||||
timestamp = time.time()
|
||||
acquired = self.lua_acquire_moving_window(
|
||||
[key], [timestamp, limit, expiry, amount]
|
||||
)
|
||||
|
||||
return bool(acquired)
|
||||
|
||||
def acquire_sliding_window_entry(
|
||||
self,
|
||||
key: str,
|
||||
limit: int,
|
||||
expiry: int,
|
||||
amount: int = 1,
|
||||
) -> bool:
|
||||
"""
|
||||
Acquire an entry. Shift the current window to the previous window if it expired.
|
||||
|
||||
:param key: rate limit key to acquire an entry in
|
||||
:param limit: amount of entries allowed
|
||||
:param expiry: expiry of the entry
|
||||
:param amount: the number of entries to acquire
|
||||
"""
|
||||
previous_key = self.prefixed_key(self._previous_window_key(key))
|
||||
current_key = self.prefixed_key(self._current_window_key(key))
|
||||
acquired = self.lua_acquire_sliding_window(
|
||||
[previous_key, current_key], [limit, expiry, amount]
|
||||
)
|
||||
return bool(acquired)
|
||||
|
||||
def get_expiry(self, key: str) -> float:
|
||||
"""
|
||||
:param key: the key to get the expiry for
|
||||
|
||||
"""
|
||||
|
||||
key = self.prefixed_key(key)
|
||||
return max(self.get_connection(True).ttl(key), 0) + time.time()
|
||||
|
||||
def check(self) -> bool:
|
||||
"""
|
||||
check if storage is healthy
|
||||
"""
|
||||
try:
|
||||
return self.get_connection().ping()
|
||||
except: # noqa
|
||||
return False
|
||||
|
||||
def reset(self) -> int | None:
|
||||
"""
|
||||
This function calls a Lua Script to delete keys prefixed with
|
||||
:paramref:`RedisStorage.key_prefix` in blocks of 5000.
|
||||
|
||||
.. warning::
|
||||
This operation was designed to be fast, but was not tested
|
||||
on a large production based system. Be careful with its usage as it
|
||||
could be slow on very large data sets.
|
||||
|
||||
"""
|
||||
|
||||
prefix = self.prefixed_key("*")
|
||||
return int(self.lua_clear_keys([prefix]))
|
||||
@@ -0,0 +1,128 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import urllib
|
||||
|
||||
from deprecated.sphinx import versionchanged
|
||||
from packaging.version import Version
|
||||
|
||||
from limits.storage.redis import RedisStorage
|
||||
|
||||
|
||||
@versionchanged(
|
||||
version="3.14.0",
|
||||
reason="""
|
||||
Dropped support for the :pypi:`redis-py-cluster` library
|
||||
which has been abandoned/deprecated.
|
||||
""",
|
||||
)
|
||||
@versionchanged(
|
||||
version="2.5.0",
|
||||
reason="""
|
||||
Cluster support was provided by the :pypi:`redis-py-cluster` library
|
||||
which has been absorbed into the official :pypi:`redis` client. By
|
||||
default the :class:`redis.cluster.RedisCluster` client will be used
|
||||
however if the version of the package is lower than ``4.2.0`` the implementation
|
||||
will fallback to trying to use :class:`rediscluster.RedisCluster`.
|
||||
""",
|
||||
)
|
||||
@versionchanged(
|
||||
version="4.3",
|
||||
reason=(
|
||||
"Added support for using the redis client from :pypi:`valkey`"
|
||||
" if :paramref:`uri` has the ``valkey+cluster://`` schema"
|
||||
),
|
||||
)
|
||||
class RedisClusterStorage(RedisStorage):
|
||||
"""
|
||||
Rate limit storage with redis cluster as backend
|
||||
|
||||
Depends on :pypi:`redis` (or :pypi:`valkey` if :paramref:`uri`
|
||||
starts with ``valkey+cluster://``).
|
||||
"""
|
||||
|
||||
STORAGE_SCHEME = ["redis+cluster", "valkey+cluster"]
|
||||
"""The storage scheme for redis cluster"""
|
||||
|
||||
DEFAULT_OPTIONS: dict[str, float | str | bool] = {
|
||||
"max_connections": 1000,
|
||||
}
|
||||
"Default options passed to the :class:`~redis.cluster.RedisCluster`"
|
||||
|
||||
DEPENDENCIES = {
|
||||
"redis": Version("4.2.0"),
|
||||
"valkey": Version("6.0"),
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
uri: str,
|
||||
key_prefix: str = RedisStorage.PREFIX,
|
||||
wrap_exceptions: bool = False,
|
||||
**options: float | str | bool,
|
||||
) -> None:
|
||||
"""
|
||||
:param uri: url of the form
|
||||
``redis+cluster://[:password]@host:port,host:port``
|
||||
|
||||
If the uri scheme is ``valkey+cluster`` the implementation used will be from
|
||||
:pypi:`valkey`.
|
||||
:param key_prefix: the prefix for each key created in redis
|
||||
:param wrap_exceptions: Whether to wrap storage exceptions in
|
||||
:exc:`limits.errors.StorageError` before raising it.
|
||||
:param options: all remaining keyword arguments are passed
|
||||
directly to the constructor of :class:`redis.cluster.RedisCluster`
|
||||
:raise ConfigurationError: when the :pypi:`redis` library is not
|
||||
available or if the redis cluster cannot be reached.
|
||||
"""
|
||||
parsed = urllib.parse.urlparse(uri)
|
||||
parsed_auth: dict[str, float | str | bool] = {}
|
||||
|
||||
if parsed.username:
|
||||
parsed_auth["username"] = parsed.username
|
||||
if parsed.password:
|
||||
parsed_auth["password"] = parsed.password
|
||||
|
||||
sep = parsed.netloc.find("@") + 1
|
||||
cluster_hosts = []
|
||||
for loc in parsed.netloc[sep:].split(","):
|
||||
host, port = loc.split(":")
|
||||
cluster_hosts.append((host, int(port)))
|
||||
|
||||
self.key_prefix = key_prefix
|
||||
self.storage = None
|
||||
self.target_server = "valkey" if uri.startswith("valkey") else "redis"
|
||||
merged_options = {**self.DEFAULT_OPTIONS, **parsed_auth, **options}
|
||||
self.dependency = self.dependencies[self.target_server].module
|
||||
startup_nodes = [self.dependency.cluster.ClusterNode(*c) for c in cluster_hosts]
|
||||
if self.target_server == "redis":
|
||||
self.storage = self.dependency.cluster.RedisCluster(
|
||||
startup_nodes=startup_nodes, **merged_options
|
||||
)
|
||||
else:
|
||||
self.storage = self.dependency.cluster.ValkeyCluster(
|
||||
startup_nodes=startup_nodes, **merged_options
|
||||
)
|
||||
|
||||
assert self.storage
|
||||
self.initialize_storage(uri)
|
||||
super(RedisStorage, self).__init__(uri, wrap_exceptions, **options)
|
||||
|
||||
def reset(self) -> int | None:
|
||||
"""
|
||||
Redis Clusters are sharded and deleting across shards
|
||||
can't be done atomically. Because of this, this reset loops over all
|
||||
keys that are prefixed with :paramref:`RedisClusterStorage.prefix` and
|
||||
calls delete on them one at a time.
|
||||
|
||||
.. warning::
|
||||
This operation was not tested with extremely large data sets.
|
||||
On a large production based system, care should be taken with its
|
||||
usage as it could be slow on very large data sets"""
|
||||
|
||||
prefix = self.prefixed_key("*")
|
||||
count = 0
|
||||
for primary in self.storage.get_primaries():
|
||||
node = self.storage.get_redis_connection(primary)
|
||||
keys = node.keys(prefix)
|
||||
count += sum([node.delete(k.decode("utf-8")) for k in keys])
|
||||
return count
|
||||
@@ -0,0 +1,123 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import urllib.parse
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from deprecated.sphinx import versionchanged
|
||||
from packaging.version import Version
|
||||
|
||||
from limits.errors import ConfigurationError
|
||||
from limits.storage.redis import RedisStorage
|
||||
from limits.typing import RedisClient
|
||||
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
|
||||
|
||||
@versionchanged(
|
||||
version="4.3",
|
||||
reason=(
|
||||
"Added support for using the redis client from :pypi:`valkey`"
|
||||
" if :paramref:`uri` has the ``valkey+sentinel://`` schema"
|
||||
),
|
||||
)
|
||||
class RedisSentinelStorage(RedisStorage):
|
||||
"""
|
||||
Rate limit storage with redis sentinel as backend
|
||||
|
||||
Depends on :pypi:`redis` package (or :pypi:`valkey` if :paramref:`uri` starts with
|
||||
``valkey+sentinel://``)
|
||||
"""
|
||||
|
||||
STORAGE_SCHEME = ["redis+sentinel", "valkey+sentinel"]
|
||||
"""The storage scheme for redis accessed via a redis sentinel installation"""
|
||||
|
||||
DEPENDENCIES = {
|
||||
"redis": Version("3.0"),
|
||||
"redis.sentinel": Version("3.0"),
|
||||
"valkey": Version("6.0"),
|
||||
"valkey.sentinel": Version("6.0"),
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
uri: str,
|
||||
service_name: str | None = None,
|
||||
use_replicas: bool = True,
|
||||
sentinel_kwargs: dict[str, float | str | bool] | None = None,
|
||||
key_prefix: str = RedisStorage.PREFIX,
|
||||
wrap_exceptions: bool = False,
|
||||
**options: float | str | bool,
|
||||
) -> None:
|
||||
"""
|
||||
:param uri: url of the form
|
||||
``redis+sentinel://host:port,host:port/service_name``
|
||||
|
||||
If the uri scheme is ``valkey+sentinel`` the implementation used will be from
|
||||
:pypi:`valkey`.
|
||||
:param service_name: sentinel service name
|
||||
(if not provided in :attr:`uri`)
|
||||
:param use_replicas: Whether to use replicas for read only operations
|
||||
:param sentinel_kwargs: kwargs to pass as
|
||||
:attr:`sentinel_kwargs` to :class:`redis.sentinel.Sentinel`
|
||||
:param key_prefix: the prefix for each key created in redis
|
||||
:param wrap_exceptions: Whether to wrap storage exceptions in
|
||||
:exc:`limits.errors.StorageError` before raising it.
|
||||
:param options: all remaining keyword arguments are passed
|
||||
directly to the constructor of :class:`redis.sentinel.Sentinel`
|
||||
:raise ConfigurationError: when the redis library is not available
|
||||
or if the redis master host cannot be pinged.
|
||||
"""
|
||||
|
||||
super(RedisStorage, self).__init__(
|
||||
uri, wrap_exceptions=wrap_exceptions, **options
|
||||
)
|
||||
|
||||
parsed = urllib.parse.urlparse(uri)
|
||||
sentinel_configuration = []
|
||||
sentinel_options = sentinel_kwargs.copy() if sentinel_kwargs else {}
|
||||
|
||||
parsed_auth: dict[str, float | str | bool] = {}
|
||||
|
||||
if parsed.username:
|
||||
parsed_auth["username"] = parsed.username
|
||||
if parsed.password:
|
||||
parsed_auth["password"] = parsed.password
|
||||
|
||||
sep = parsed.netloc.find("@") + 1
|
||||
|
||||
for loc in parsed.netloc[sep:].split(","):
|
||||
host, port = loc.split(":")
|
||||
sentinel_configuration.append((host, int(port)))
|
||||
self.key_prefix = key_prefix
|
||||
self.service_name = (
|
||||
parsed.path.replace("/", "") if parsed.path else service_name
|
||||
)
|
||||
|
||||
if self.service_name is None:
|
||||
raise ConfigurationError("'service_name' not provided")
|
||||
|
||||
self.target_server = "valkey" if uri.startswith("valkey") else "redis"
|
||||
sentinel_dep = self.dependencies[f"{self.target_server}.sentinel"].module
|
||||
self.sentinel = sentinel_dep.Sentinel(
|
||||
sentinel_configuration,
|
||||
sentinel_kwargs={**parsed_auth, **sentinel_options},
|
||||
**{**parsed_auth, **options},
|
||||
)
|
||||
self.storage: RedisClient = self.sentinel.master_for(self.service_name)
|
||||
self.storage_slave: RedisClient = self.sentinel.slave_for(self.service_name)
|
||||
self.use_replicas = use_replicas
|
||||
self.initialize_storage(uri)
|
||||
|
||||
@property
|
||||
def base_exceptions(
|
||||
self,
|
||||
) -> type[Exception] | tuple[type[Exception], ...]: # pragma: no cover
|
||||
return ( # type: ignore[no-any-return]
|
||||
self.dependencies["redis"].module.RedisError
|
||||
if self.target_server == "redis"
|
||||
else self.dependencies["valkey"].module.ValkeyError
|
||||
)
|
||||
|
||||
def get_connection(self, readonly: bool = False) -> RedisClient:
|
||||
return self.storage_slave if (readonly and self.use_replicas) else self.storage
|
||||
@@ -0,0 +1,24 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABCMeta
|
||||
|
||||
SCHEMES: dict[str, StorageRegistry] = {}
|
||||
|
||||
|
||||
class StorageRegistry(ABCMeta):
|
||||
def __new__(
|
||||
mcs, name: str, bases: tuple[type, ...], dct: dict[str, str | list[str]]
|
||||
) -> StorageRegistry:
|
||||
storage_scheme = dct.get("STORAGE_SCHEME", None)
|
||||
cls = super().__new__(mcs, name, bases, dct)
|
||||
|
||||
if storage_scheme:
|
||||
if isinstance(storage_scheme, str): # noqa
|
||||
schemes = [storage_scheme]
|
||||
else:
|
||||
schemes = storage_scheme
|
||||
|
||||
for scheme in schemes:
|
||||
SCHEMES[scheme] = cls
|
||||
|
||||
return cls
|
||||
318
buffteks/lib/python3.11/site-packages/limits/strategies.py
Normal file
318
buffteks/lib/python3.11/site-packages/limits/strategies.py
Normal file
@@ -0,0 +1,318 @@
|
||||
"""
|
||||
Rate limiting strategies
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from abc import ABCMeta, abstractmethod
|
||||
from math import floor, inf
|
||||
|
||||
from deprecated.sphinx import versionadded
|
||||
|
||||
from limits.storage.base import SlidingWindowCounterSupport
|
||||
|
||||
from .limits import RateLimitItem
|
||||
from .storage import MovingWindowSupport, Storage, StorageTypes
|
||||
from .typing import cast
|
||||
from .util import WindowStats
|
||||
|
||||
|
||||
class RateLimiter(metaclass=ABCMeta):
|
||||
def __init__(self, storage: StorageTypes):
|
||||
assert isinstance(storage, Storage)
|
||||
self.storage: Storage = storage
|
||||
|
||||
@abstractmethod
|
||||
def hit(self, item: RateLimitItem, *identifiers: str, cost: int = 1) -> bool:
|
||||
"""
|
||||
Consume the rate limit
|
||||
|
||||
:param item: The rate limit item
|
||||
:param identifiers: variable list of strings to uniquely identify this
|
||||
instance of the limit
|
||||
:param cost: The cost of this hit, default 1
|
||||
|
||||
:return: True if ``cost`` could be deducted from the rate limit without exceeding it
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def test(self, item: RateLimitItem, *identifiers: str, cost: int = 1) -> bool:
|
||||
"""
|
||||
Check the rate limit without consuming from it.
|
||||
|
||||
:param item: The rate limit item
|
||||
:param identifiers: variable list of strings to uniquely identify this
|
||||
instance of the limit
|
||||
:param cost: The expected cost to be consumed, default 1
|
||||
|
||||
:return: True if the rate limit is not depleted
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def get_window_stats(self, item: RateLimitItem, *identifiers: str) -> WindowStats:
|
||||
"""
|
||||
Query the reset time and remaining amount for the limit
|
||||
|
||||
:param item: The rate limit item
|
||||
:param identifiers: variable list of strings to uniquely identify this
|
||||
instance of the limit
|
||||
:return: (reset time, remaining)
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def clear(self, item: RateLimitItem, *identifiers: str) -> None:
|
||||
return self.storage.clear(item.key_for(*identifiers))
|
||||
|
||||
|
||||
class MovingWindowRateLimiter(RateLimiter):
|
||||
"""
|
||||
Reference: :ref:`strategies:moving window`
|
||||
"""
|
||||
|
||||
def __init__(self, storage: StorageTypes):
|
||||
if not (
|
||||
hasattr(storage, "acquire_entry") or hasattr(storage, "get_moving_window")
|
||||
):
|
||||
raise NotImplementedError(
|
||||
"MovingWindowRateLimiting is not implemented for storage "
|
||||
f"of type {storage.__class__}"
|
||||
)
|
||||
super().__init__(storage)
|
||||
|
||||
def hit(self, item: RateLimitItem, *identifiers: str, cost: int = 1) -> bool:
|
||||
"""
|
||||
Consume the rate limit
|
||||
|
||||
:param item: The rate limit item
|
||||
:param identifiers: variable list of strings to uniquely identify this
|
||||
instance of the limit
|
||||
:param cost: The cost of this hit, default 1
|
||||
|
||||
:return: True if ``cost`` could be deducted from the rate limit without exceeding it
|
||||
"""
|
||||
|
||||
return cast(MovingWindowSupport, self.storage).acquire_entry(
|
||||
item.key_for(*identifiers), item.amount, item.get_expiry(), amount=cost
|
||||
)
|
||||
|
||||
def test(self, item: RateLimitItem, *identifiers: str, cost: int = 1) -> bool:
|
||||
"""
|
||||
Check if the rate limit can be consumed
|
||||
|
||||
:param item: The rate limit item
|
||||
:param identifiers: variable list of strings to uniquely identify this
|
||||
instance of the limit
|
||||
:param cost: The expected cost to be consumed, default 1
|
||||
|
||||
:return: True if the rate limit is not depleted
|
||||
"""
|
||||
|
||||
return (
|
||||
cast(MovingWindowSupport, self.storage).get_moving_window(
|
||||
item.key_for(*identifiers),
|
||||
item.amount,
|
||||
item.get_expiry(),
|
||||
)[1]
|
||||
<= item.amount - cost
|
||||
)
|
||||
|
||||
def get_window_stats(self, item: RateLimitItem, *identifiers: str) -> WindowStats:
|
||||
"""
|
||||
returns the number of requests remaining within this limit.
|
||||
|
||||
:param item: The rate limit item
|
||||
:param identifiers: variable list of strings to uniquely identify this
|
||||
instance of the limit
|
||||
:return: tuple (reset time, remaining)
|
||||
"""
|
||||
window_start, window_items = cast(
|
||||
MovingWindowSupport, self.storage
|
||||
).get_moving_window(item.key_for(*identifiers), item.amount, item.get_expiry())
|
||||
reset = window_start + item.get_expiry()
|
||||
|
||||
return WindowStats(reset, item.amount - window_items)
|
||||
|
||||
|
||||
class FixedWindowRateLimiter(RateLimiter):
|
||||
"""
|
||||
Reference: :ref:`strategies:fixed window`
|
||||
"""
|
||||
|
||||
def hit(self, item: RateLimitItem, *identifiers: str, cost: int = 1) -> bool:
|
||||
"""
|
||||
Consume the rate limit
|
||||
|
||||
:param item: The rate limit item
|
||||
:param identifiers: variable list of strings to uniquely identify this
|
||||
instance of the limit
|
||||
:param cost: The cost of this hit, default 1
|
||||
|
||||
:return: True if ``cost`` could be deducted from the rate limit without exceeding it
|
||||
"""
|
||||
|
||||
return (
|
||||
self.storage.incr(
|
||||
item.key_for(*identifiers),
|
||||
item.get_expiry(),
|
||||
amount=cost,
|
||||
)
|
||||
<= item.amount
|
||||
)
|
||||
|
||||
def test(self, item: RateLimitItem, *identifiers: str, cost: int = 1) -> bool:
|
||||
"""
|
||||
Check if the rate limit can be consumed
|
||||
|
||||
:param item: The rate limit item
|
||||
:param identifiers: variable list of strings to uniquely identify this
|
||||
instance of the limit
|
||||
:param cost: The expected cost to be consumed, default 1
|
||||
|
||||
:return: True if the rate limit is not depleted
|
||||
"""
|
||||
|
||||
return self.storage.get(item.key_for(*identifiers)) < item.amount - cost + 1
|
||||
|
||||
def get_window_stats(self, item: RateLimitItem, *identifiers: str) -> WindowStats:
|
||||
"""
|
||||
Query the reset time and remaining amount for the limit
|
||||
|
||||
:param item: The rate limit item
|
||||
:param identifiers: variable list of strings to uniquely identify this
|
||||
instance of the limit
|
||||
:return: (reset time, remaining)
|
||||
"""
|
||||
remaining = max(0, item.amount - self.storage.get(item.key_for(*identifiers)))
|
||||
reset = self.storage.get_expiry(item.key_for(*identifiers))
|
||||
|
||||
return WindowStats(reset, remaining)
|
||||
|
||||
|
||||
@versionadded(version="4.1")
|
||||
class SlidingWindowCounterRateLimiter(RateLimiter):
|
||||
"""
|
||||
Reference: :ref:`strategies:sliding window counter`
|
||||
"""
|
||||
|
||||
def __init__(self, storage: StorageTypes):
|
||||
if not hasattr(storage, "get_sliding_window") or not hasattr(
|
||||
storage, "acquire_sliding_window_entry"
|
||||
):
|
||||
raise NotImplementedError(
|
||||
"SlidingWindowCounterRateLimiting is not implemented for storage "
|
||||
f"of type {storage.__class__}"
|
||||
)
|
||||
super().__init__(storage)
|
||||
|
||||
def _weighted_count(
|
||||
self,
|
||||
item: RateLimitItem,
|
||||
previous_count: int,
|
||||
previous_expires_in: float,
|
||||
current_count: int,
|
||||
) -> float:
|
||||
"""
|
||||
Return the approximated by weighting the previous window count and adding the current window count.
|
||||
"""
|
||||
return previous_count * previous_expires_in / item.get_expiry() + current_count
|
||||
|
||||
def hit(self, item: RateLimitItem, *identifiers: str, cost: int = 1) -> bool:
|
||||
"""
|
||||
Consume the rate limit
|
||||
|
||||
:param item: The rate limit item
|
||||
:param identifiers: variable list of strings to uniquely identify this
|
||||
instance of the limit
|
||||
:param cost: The cost of this hit, default 1
|
||||
|
||||
:return: True if ``cost`` could be deducted from the rate limit without exceeding it
|
||||
"""
|
||||
return cast(
|
||||
SlidingWindowCounterSupport, self.storage
|
||||
).acquire_sliding_window_entry(
|
||||
item.key_for(*identifiers),
|
||||
item.amount,
|
||||
item.get_expiry(),
|
||||
cost,
|
||||
)
|
||||
|
||||
def test(self, item: RateLimitItem, *identifiers: str, cost: int = 1) -> bool:
|
||||
"""
|
||||
Check if the rate limit can be consumed
|
||||
|
||||
:param item: The rate limit item
|
||||
:param identifiers: variable list of strings to uniquely identify this
|
||||
instance of the limit
|
||||
:param cost: The expected cost to be consumed, default 1
|
||||
|
||||
:return: True if the rate limit is not depleted
|
||||
"""
|
||||
previous_count, previous_expires_in, current_count, _ = cast(
|
||||
SlidingWindowCounterSupport, self.storage
|
||||
).get_sliding_window(item.key_for(*identifiers), item.get_expiry())
|
||||
|
||||
return (
|
||||
self._weighted_count(
|
||||
item, previous_count, previous_expires_in, current_count
|
||||
)
|
||||
< item.amount - cost + 1
|
||||
)
|
||||
|
||||
def get_window_stats(self, item: RateLimitItem, *identifiers: str) -> WindowStats:
|
||||
"""
|
||||
Query the reset time and remaining amount for the limit.
|
||||
|
||||
:param item: The rate limit item
|
||||
:param identifiers: variable list of strings to uniquely identify this
|
||||
instance of the limit
|
||||
:return: WindowStats(reset time, remaining)
|
||||
"""
|
||||
previous_count, previous_expires_in, current_count, current_expires_in = cast(
|
||||
SlidingWindowCounterSupport, self.storage
|
||||
).get_sliding_window(item.key_for(*identifiers), item.get_expiry())
|
||||
|
||||
remaining = max(
|
||||
0,
|
||||
item.amount
|
||||
- floor(
|
||||
self._weighted_count(
|
||||
item, previous_count, previous_expires_in, current_count
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
now = time.time()
|
||||
|
||||
if not (previous_count or current_count):
|
||||
return WindowStats(now, remaining)
|
||||
|
||||
expiry = item.get_expiry()
|
||||
|
||||
previous_reset_in, current_reset_in = inf, inf
|
||||
if previous_count:
|
||||
previous_reset_in = previous_expires_in % (expiry / previous_count)
|
||||
if current_count:
|
||||
current_reset_in = current_expires_in % expiry
|
||||
|
||||
return WindowStats(now + min(previous_reset_in, current_reset_in), remaining)
|
||||
|
||||
def clear(self, item: RateLimitItem, *identifiers: str) -> None:
|
||||
return cast(SlidingWindowCounterSupport, self.storage).clear_sliding_window(
|
||||
item.key_for(*identifiers), item.get_expiry()
|
||||
)
|
||||
|
||||
|
||||
KnownStrategy = (
|
||||
type[SlidingWindowCounterRateLimiter]
|
||||
| type[FixedWindowRateLimiter]
|
||||
| type[MovingWindowRateLimiter]
|
||||
)
|
||||
|
||||
STRATEGIES: dict[str, KnownStrategy] = {
|
||||
"sliding-window-counter": SlidingWindowCounterRateLimiter,
|
||||
"fixed-window": FixedWindowRateLimiter,
|
||||
"moving-window": MovingWindowRateLimiter,
|
||||
}
|
||||
127
buffteks/lib/python3.11/site-packages/limits/typing.py
Normal file
127
buffteks/lib/python3.11/site-packages/limits/typing.py
Normal file
@@ -0,0 +1,127 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import Counter
|
||||
from collections.abc import Awaitable, Callable, Iterable
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
ClassVar,
|
||||
Literal,
|
||||
NamedTuple,
|
||||
ParamSpec,
|
||||
Protocol,
|
||||
TypeAlias,
|
||||
TypeVar,
|
||||
cast,
|
||||
)
|
||||
|
||||
Serializable = int | str | float
|
||||
|
||||
R = TypeVar("R")
|
||||
R_co = TypeVar("R_co", covariant=True)
|
||||
P = ParamSpec("P")
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import coredis
|
||||
import pymongo.collection
|
||||
import pymongo.database
|
||||
import pymongo.mongo_client
|
||||
import redis
|
||||
|
||||
|
||||
class MemcachedClientP(Protocol):
|
||||
def add(
|
||||
self,
|
||||
key: str,
|
||||
value: Serializable,
|
||||
expire: int | None = 0,
|
||||
noreply: bool | None = None,
|
||||
flags: int | None = None,
|
||||
) -> bool: ...
|
||||
|
||||
def get(self, key: str, default: str | None = None) -> bytes: ...
|
||||
|
||||
def get_many(self, keys: Iterable[str]) -> dict[str, Any]: ... # type:ignore[explicit-any]
|
||||
|
||||
def incr(
|
||||
self, key: str, value: int, noreply: bool | None = False
|
||||
) -> int | None: ...
|
||||
|
||||
def decr(
|
||||
self,
|
||||
key: str,
|
||||
value: int,
|
||||
noreply: bool | None = False,
|
||||
) -> int | None: ...
|
||||
|
||||
def delete(self, key: str, noreply: bool | None = None) -> bool | None: ...
|
||||
|
||||
def set(
|
||||
self,
|
||||
key: str,
|
||||
value: Serializable,
|
||||
expire: int = 0,
|
||||
noreply: bool | None = None,
|
||||
flags: int | None = None,
|
||||
) -> bool: ...
|
||||
|
||||
def touch(
|
||||
self, key: str, expire: int | None = 0, noreply: bool | None = None
|
||||
) -> bool: ...
|
||||
|
||||
|
||||
class RedisClientP(Protocol):
|
||||
def incrby(self, key: str, amount: int) -> int: ...
|
||||
def get(self, key: str) -> bytes | None: ...
|
||||
def delete(self, key: str) -> int: ...
|
||||
def ttl(self, key: str) -> int: ...
|
||||
def expire(self, key: str, seconds: int) -> bool: ...
|
||||
def ping(self) -> bool: ...
|
||||
def register_script(self, script: bytes) -> redis.commands.core.Script: ...
|
||||
|
||||
|
||||
class AsyncRedisClientP(Protocol):
|
||||
async def incrby(self, key: str, amount: int) -> int: ...
|
||||
async def get(self, key: str) -> bytes | None: ...
|
||||
async def delete(self, key: str) -> int: ...
|
||||
async def ttl(self, key: str) -> int: ...
|
||||
async def expire(self, key: str, seconds: int) -> bool: ...
|
||||
async def ping(self) -> bool: ...
|
||||
def register_script(self, script: bytes) -> redis.commands.core.Script: ...
|
||||
|
||||
|
||||
RedisClient: TypeAlias = RedisClientP
|
||||
AsyncRedisClient: TypeAlias = AsyncRedisClientP
|
||||
AsyncCoRedisClient: TypeAlias = "coredis.Redis[bytes] | coredis.RedisCluster[bytes]"
|
||||
|
||||
MongoClient: TypeAlias = "pymongo.mongo_client.MongoClient[dict[str, Any]]" # type:ignore[explicit-any]
|
||||
MongoDatabase: TypeAlias = "pymongo.database.Database[dict[str, Any]]" # type:ignore[explicit-any]
|
||||
MongoCollection: TypeAlias = "pymongo.collection.Collection[dict[str, Any]]" # type:ignore[explicit-any]
|
||||
|
||||
__all__ = [
|
||||
"TYPE_CHECKING",
|
||||
"Any",
|
||||
"AsyncRedisClient",
|
||||
"Awaitable",
|
||||
"Callable",
|
||||
"ClassVar",
|
||||
"Counter",
|
||||
"Iterable",
|
||||
"Literal",
|
||||
"MemcachedClientP",
|
||||
"MongoClient",
|
||||
"MongoCollection",
|
||||
"MongoDatabase",
|
||||
"NamedTuple",
|
||||
"P",
|
||||
"ParamSpec",
|
||||
"Protocol",
|
||||
"R",
|
||||
"R_co",
|
||||
"RedisClient",
|
||||
"Serializable",
|
||||
"TypeAlias",
|
||||
"TypeVar",
|
||||
"cast",
|
||||
]
|
||||
209
buffteks/lib/python3.11/site-packages/limits/util.py
Normal file
209
buffteks/lib/python3.11/site-packages/limits/util.py
Normal file
@@ -0,0 +1,209 @@
|
||||
""" """
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import dataclasses
|
||||
import importlib.resources
|
||||
import re
|
||||
import sys
|
||||
from collections import UserDict
|
||||
from types import ModuleType
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from packaging.version import Version
|
||||
|
||||
from limits.typing import NamedTuple
|
||||
|
||||
from .errors import ConfigurationError
|
||||
from .limits import GRANULARITIES, RateLimitItem
|
||||
|
||||
SEPARATORS = re.compile(r"[,;|]{1}")
|
||||
SINGLE_EXPR = re.compile(
|
||||
r"""
|
||||
\s*([0-9]+)
|
||||
\s*(/|\s*per\s*)
|
||||
\s*([0-9]+)?
|
||||
\s*([a-z]+)
|
||||
\s*
|
||||
""",
|
||||
re.IGNORECASE | re.VERBOSE,
|
||||
)
|
||||
EXPR = re.compile(
|
||||
rf"^{SINGLE_EXPR.pattern}(:?{SEPARATORS.pattern}{SINGLE_EXPR.pattern})*$",
|
||||
re.IGNORECASE | re.VERBOSE,
|
||||
)
|
||||
|
||||
|
||||
class WindowStats(NamedTuple):
|
||||
"""
|
||||
tuple to describe a rate limited window
|
||||
"""
|
||||
|
||||
#: Time as seconds since the Epoch when this window will be reset
|
||||
reset_time: float
|
||||
#: Quantity remaining in this window
|
||||
remaining: int
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class Dependency:
|
||||
name: str
|
||||
version_required: Version | None
|
||||
version_found: Version | None
|
||||
module: ModuleType
|
||||
|
||||
|
||||
MissingModule = ModuleType("Missing")
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
_UserDict = UserDict[str, Dependency]
|
||||
else:
|
||||
_UserDict = UserDict
|
||||
|
||||
|
||||
class DependencyDict(_UserDict):
|
||||
def __getitem__(self, key: str) -> Dependency:
|
||||
dependency = super().__getitem__(key)
|
||||
|
||||
if dependency.module is MissingModule:
|
||||
message = f"'{dependency.name}' prerequisite not available."
|
||||
if dependency.version_required:
|
||||
message += (
|
||||
f" A minimum version of {dependency.version_required} is required."
|
||||
if dependency.version_required
|
||||
else ""
|
||||
)
|
||||
message += (
|
||||
" See https://limits.readthedocs.io/en/stable/storage.html#supported-versions"
|
||||
" for more details."
|
||||
)
|
||||
raise ConfigurationError(message)
|
||||
elif dependency.version_required and (
|
||||
not dependency.version_found
|
||||
or dependency.version_found < dependency.version_required
|
||||
):
|
||||
raise ConfigurationError(
|
||||
f"The minimum version of {dependency.version_required}"
|
||||
f" for '{dependency.name}' could not be found. Found version: {dependency.version_found}"
|
||||
)
|
||||
|
||||
return dependency
|
||||
|
||||
|
||||
class LazyDependency:
|
||||
"""
|
||||
Simple utility that provides an :attr:`dependency`
|
||||
to the child class to fetch any dependencies
|
||||
without having to import them explicitly.
|
||||
"""
|
||||
|
||||
DEPENDENCIES: dict[str, Version | None] | list[str] = []
|
||||
"""
|
||||
The python modules this class has a dependency on.
|
||||
Used to lazily populate the :attr:`dependencies`
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._dependencies: DependencyDict = DependencyDict()
|
||||
|
||||
@property
|
||||
def dependencies(self) -> DependencyDict:
|
||||
"""
|
||||
Cached mapping of the modules this storage depends on.
|
||||
This is done so that the module is only imported lazily
|
||||
when the storage is instantiated.
|
||||
|
||||
:meta private:
|
||||
"""
|
||||
|
||||
if not getattr(self, "_dependencies", None):
|
||||
dependencies = DependencyDict()
|
||||
mapping: dict[str, Version | None]
|
||||
|
||||
if isinstance(self.DEPENDENCIES, list):
|
||||
mapping = {dependency: None for dependency in self.DEPENDENCIES}
|
||||
else:
|
||||
mapping = self.DEPENDENCIES
|
||||
|
||||
for name, minimum_version in mapping.items():
|
||||
dependency, version = get_dependency(name)
|
||||
|
||||
dependencies[name] = Dependency(
|
||||
name, minimum_version, version, dependency
|
||||
)
|
||||
self._dependencies = dependencies
|
||||
|
||||
return self._dependencies
|
||||
|
||||
|
||||
def get_dependency(module_path: str) -> tuple[ModuleType, Version | None]:
|
||||
"""
|
||||
safe function to import a module at runtime
|
||||
"""
|
||||
try:
|
||||
if module_path not in sys.modules:
|
||||
__import__(module_path)
|
||||
root = module_path.split(".")[0]
|
||||
version = getattr(sys.modules[root], "__version__", "0.0.0")
|
||||
|
||||
return sys.modules[module_path], Version(version)
|
||||
except ImportError: # pragma: no cover
|
||||
return MissingModule, None
|
||||
|
||||
|
||||
def get_package_data(path: str) -> bytes:
|
||||
return importlib.resources.files("limits").joinpath(path).read_bytes()
|
||||
|
||||
|
||||
def parse_many(limit_string: str) -> list[RateLimitItem]:
|
||||
"""
|
||||
parses rate limits in string notation containing multiple rate limits
|
||||
(e.g. ``1/second; 5/minute``)
|
||||
|
||||
:param limit_string: rate limit string using :ref:`ratelimit-string`
|
||||
:raise ValueError: if the string notation is invalid.
|
||||
|
||||
"""
|
||||
|
||||
if not (isinstance(limit_string, str) and EXPR.match(limit_string)):
|
||||
raise ValueError(f"couldn't parse rate limit string '{limit_string}'")
|
||||
limits = []
|
||||
|
||||
for limit in SEPARATORS.split(limit_string):
|
||||
match = SINGLE_EXPR.match(limit)
|
||||
|
||||
if match:
|
||||
amount, _, multiples, granularity_string = match.groups()
|
||||
granularity = granularity_from_string(granularity_string)
|
||||
limits.append(
|
||||
granularity(int(amount), multiples and int(multiples) or None)
|
||||
)
|
||||
|
||||
return limits
|
||||
|
||||
|
||||
def parse(limit_string: str) -> RateLimitItem:
|
||||
"""
|
||||
parses a single rate limit in string notation
|
||||
(e.g. ``1/second`` or ``1 per second``)
|
||||
|
||||
:param limit_string: rate limit string using :ref:`ratelimit-string`
|
||||
:raise ValueError: if the string notation is invalid.
|
||||
|
||||
"""
|
||||
|
||||
return list(parse_many(limit_string))[0]
|
||||
|
||||
|
||||
def granularity_from_string(granularity_string: str) -> type[RateLimitItem]:
|
||||
"""
|
||||
|
||||
:param granularity_string:
|
||||
:raise ValueError:
|
||||
"""
|
||||
|
||||
for granularity in GRANULARITIES.values():
|
||||
if granularity.check_granularity_string(granularity_string):
|
||||
return granularity
|
||||
raise ValueError(f"no granularity matched for {granularity_string}")
|
||||
Reference in New Issue
Block a user