from __future__ import annotations import itertools import time from functools import partial from typing import Any from urllib.parse import urlparse import click from flask import Flask, current_app from flask.cli import with_appcontext from limits.strategies import RateLimiter from rich.console import Console, group from rich.live import Live from rich.pretty import Pretty from rich.prompt import Confirm from rich.table import Table from rich.theme import Theme from rich.tree import Tree from typing_extensions import TypedDict from werkzeug.exceptions import MethodNotAllowed, NotFound from werkzeug.routing import Rule from ._extension import Limiter from ._limits import RuntimeLimit from ._typing import Callable, Generator, cast from .constants import ConfigVars, ExemptionScope, HeaderNames from .util import get_qualified_name limiter_theme = Theme( { "success": "bold green", "danger": "bold red", "error": "bold red", "blueprint": "bold red", "default": "magenta", "callable": "cyan", "entity": "magenta", "exempt": "bold red", "route": "yellow", "http": "bold green", "option": "bold yellow", } ) def render_func(func: Any) -> str | Pretty: if callable(func): if func.__name__ == "": return f"[callable]({func.__module__})[/callable]" return f"[callable]{func.__module__}.{func.__name__}()[/callable]" return Pretty(func) def render_storage(ext: Limiter) -> Tree: render = Tree(ext._storage_uri or "N/A") if ext.storage: render.add(f"[entity]{ext.storage.__class__.__name__}[/entity]") render.add(f"[entity]{ext.storage.storage}[/entity]") # type: ignore render.add(Pretty(ext._storage_options or {})) health = ext.storage.check() if health: render.add("[success]OK[/success]") else: render.add("[error]Error[/error]") return render def render_strategy(strategy: RateLimiter) -> str: return f"[entity]{strategy.__class__.__name__}[/entity]" def render_limit_state( limiter: Limiter, endpoint: str, limit: RuntimeLimit, key: str, method: str ) -> str: args = [key, limit.scope_for(endpoint, method)] if not limiter.storage or (limiter.storage and not limiter.storage.check()): return ": [error]Storage not available[/error]" test = limiter.limiter.test(limit.limit, *args) stats = limiter.limiter.get_window_stats(limit.limit, *args) if not test: return f": [error]Fail[/error] ({stats[1]} out of {limit.limit.amount} remaining)" else: return f": [success]Pass[/success] ({stats[1]} out of {limit.limit.amount} remaining)" def render_limit(limit: RuntimeLimit, simple: bool = True) -> str: render = str(limit.limit) if simple: return render options = [] if limit.deduct_when: options.append(f"deduct_when: {render_func(limit.deduct_when)}") if limit.exempt_when: options.append(f"exempt_when: {render_func(limit.exempt_when)}") if options: render = f"{render} [option]{{{', '.join(options)}}}[/option]" return render def render_limits( app: Flask, limiter: Limiter, limits: tuple[list[RuntimeLimit], ...], endpoint: str | None = None, blueprint: str | None = None, rule: Rule | None = None, exemption_scope: ExemptionScope = ExemptionScope.NONE, test: str | None = None, method: str = "GET", label: str | None = "", ) -> Tree: _label = None if rule and endpoint: _label = f"{endpoint}: {rule}" label = _label or label or "" renderable = Tree(label) entries = [] for limit in limits[0] + limits[1]: if endpoint: view_func = app.view_functions.get(endpoint, None) source = ( "blueprint" if blueprint and limit in limiter.limit_manager.blueprint_limits(app, blueprint) else ( "route" if limit in limiter.limit_manager.decorated_limits( get_qualified_name(view_func) if view_func else "" ) else "default" ) ) else: source = "default" if limit.per_method and rule and rule.methods: for method in rule.methods: rendered = render_limit(limit, False) entry = f"[{source}]{rendered} [http]({method})[/http][/{source}]" if test: entry += render_limit_state(limiter, endpoint or "", limit, test, method) entries.append(entry) else: rendered = render_limit(limit, False) entry = f"[{source}]{rendered}[/{source}]" if test: entry += render_limit_state(limiter, endpoint or "", limit, test, method) entries.append(entry) if not entries and exemption_scope: renderable.add("[exempt]Exempt[/exempt]") else: [renderable.add(entry) for entry in entries] return renderable def get_filtered_endpoint( app: Flask, console: Console, endpoint: str | None, path: str | None, method: str | None = None, ) -> str | None: if not (endpoint or path): return None if endpoint: if endpoint in current_app.view_functions: return endpoint else: console.print(f"[red]Error: {endpoint} not found") elif path: adapter = app.url_map.bind("dev.null") parsed = urlparse(path) try: filter_endpoint, _ = adapter.match(parsed.path, method=method, query_args=parsed.query) return cast(str, filter_endpoint) except NotFound: console.print(f"[error]Error: {path} could not be matched to an endpoint[/error]") except MethodNotAllowed: assert method console.print( f"[error]Error: {method.upper()}: {path}" " could not be matched to an endpoint[/error]" ) raise SystemExit @click.group(help="Flask-Limiter maintenance & utility commmands") def cli() -> None: pass @cli.command(help="View the extension configuration") @with_appcontext def config() -> None: with current_app.test_request_context(): console = Console(theme=limiter_theme) limiters = list(current_app.extensions.get("limiter", set())) limiter = limiters and list(limiters)[0] if limiter: extension_details = Table(title="Flask-Limiter Config") extension_details.add_column("Notes") extension_details.add_column("Configuration") extension_details.add_column("Value") extension_details.add_row("Enabled", ConfigVars.ENABLED, Pretty(limiter.enabled)) extension_details.add_row( "Key Function", ConfigVars.KEY_FUNC, render_func(limiter._key_func) ) extension_details.add_row( "Key Prefix", ConfigVars.KEY_PREFIX, Pretty(limiter._key_prefix) ) limiter_config = Tree(ConfigVars.STRATEGY) limiter_config_values = Tree(render_strategy(limiter.limiter)) node = limiter_config.add(ConfigVars.STORAGE_URI) node.add("Instance") node.add("Backend") limiter_config.add(ConfigVars.STORAGE_OPTIONS) limiter_config.add("Status") limiter_config_values.add(render_storage(limiter)) extension_details.add_row("Rate Limiting Config", limiter_config, limiter_config_values) if limiter.limit_manager.application_limits: extension_details.add_row( "Application Limits", ConfigVars.APPLICATION_LIMITS, Pretty( [render_limit(limit) for limit in limiter.limit_manager.application_limits] ), ) extension_details.add_row( None, ConfigVars.APPLICATION_LIMITS_PER_METHOD, Pretty(limiter._application_limits_per_method), ) extension_details.add_row( None, ConfigVars.APPLICATION_LIMITS_EXEMPT_WHEN, render_func(limiter._application_limits_exempt_when), ) extension_details.add_row( None, ConfigVars.APPLICATION_LIMITS_DEDUCT_WHEN, render_func(limiter._application_limits_deduct_when), ) extension_details.add_row( None, ConfigVars.APPLICATION_LIMITS_COST, Pretty(limiter._application_limits_cost), ) else: extension_details.add_row( "ApplicationLimits Limits", ConfigVars.APPLICATION_LIMITS, Pretty([]), ) if limiter.limit_manager.default_limits: extension_details.add_row( "Default Limits", ConfigVars.DEFAULT_LIMITS, Pretty([render_limit(limit) for limit in limiter.limit_manager.default_limits]), ) extension_details.add_row( None, ConfigVars.DEFAULT_LIMITS_PER_METHOD, Pretty(limiter._default_limits_per_method), ) extension_details.add_row( None, ConfigVars.DEFAULT_LIMITS_EXEMPT_WHEN, render_func(limiter._default_limits_exempt_when), ) extension_details.add_row( None, ConfigVars.DEFAULT_LIMITS_DEDUCT_WHEN, render_func(limiter._default_limits_deduct_when), ) extension_details.add_row( None, ConfigVars.DEFAULT_LIMITS_COST, render_func(limiter._default_limits_cost), ) else: extension_details.add_row("Default Limits", ConfigVars.DEFAULT_LIMITS, Pretty([])) if limiter._meta_limits: extension_details.add_row( "Meta Limits", ConfigVars.META_LIMITS, Pretty( [render_limit(limit) for limit in itertools.chain(*limiter._meta_limits)] ), ) if limiter._headers_enabled: header_configs = Tree(ConfigVars.HEADERS_ENABLED) header_configs.add(ConfigVars.HEADER_RESET) header_configs.add(ConfigVars.HEADER_REMAINING) header_configs.add(ConfigVars.HEADER_RETRY_AFTER) header_configs.add(ConfigVars.HEADER_RETRY_AFTER_VALUE) header_values = Tree(Pretty(limiter._headers_enabled)) header_values.add(Pretty(limiter._header_mapping[HeaderNames.RESET])) header_values.add(Pretty(limiter._header_mapping[HeaderNames.REMAINING])) header_values.add(Pretty(limiter._header_mapping[HeaderNames.RETRY_AFTER])) header_values.add(Pretty(limiter._retry_after)) extension_details.add_row( "Header configuration", header_configs, header_values, ) else: extension_details.add_row( "Header configuration", ConfigVars.HEADERS_ENABLED, Pretty(False) ) extension_details.add_row( "Fail on first breach", ConfigVars.FAIL_ON_FIRST_BREACH, Pretty(limiter._fail_on_first_breach), ) extension_details.add_row( "On breach callback", ConfigVars.ON_BREACH, render_func(limiter._on_breach), ) console.print(extension_details) else: console.print( f"No Flask-Limiter extension installed on {current_app}", style="bold red", ) @cli.command(help="Enumerate details about all routes with rate limits") @click.option("--endpoint", default=None, help="Endpoint to filter by") @click.option("--path", default=None, help="Path to filter by") @click.option("--method", default=None, help="HTTP Method to filter by") @click.option("--key", default=None, help="Test the limit") @click.option("--watch/--no-watch", default=False, help="Create a live dashboard") @with_appcontext def limits( endpoint: str | None = None, path: str | None = None, method: str = "GET", key: str | None = None, watch: bool = False, ) -> None: with current_app.test_request_context(): limiters: set[Limiter] = current_app.extensions.get("limiter", set()) limiter: Limiter | None = list(limiters)[0] if limiters else None console = Console(theme=limiter_theme) if limiter: manager = limiter.limit_manager groups: dict[str, list[Callable[..., Tree]]] = {} filter_endpoint = get_filtered_endpoint(current_app, console, endpoint, path, method) for rule in sorted( current_app.url_map.iter_rules(filter_endpoint), key=lambda r: str(r) ): rule_endpoint = rule.endpoint if rule_endpoint == "static": continue if len(rule_endpoint.split(".")) > 1: bp_fullname = ".".join(rule_endpoint.split(".")[:-1]) groups.setdefault(bp_fullname, []).append( partial( render_limits, current_app, limiter, manager.resolve_limits(current_app, rule_endpoint, bp_fullname), rule_endpoint, bp_fullname, rule, exemption_scope=manager.exemption_scope( current_app, rule_endpoint, bp_fullname ), method=method, test=key, ) ) else: groups.setdefault("root", []).append( partial( render_limits, current_app, limiter, manager.resolve_limits(current_app, rule_endpoint, ""), rule_endpoint, None, rule, exemption_scope=manager.exemption_scope( current_app, rule_endpoint, None ), method=method, test=key, ) ) @group() def console_renderable() -> Generator: # type: ignore if limiter and limiter.limit_manager.application_limits and not (endpoint or path): yield render_limits( current_app, limiter, (list(itertools.chain(*limiter._meta_limits)), []), test=key, method=method, label="[gold3]Meta Limits[/gold3]", ) yield render_limits( current_app, limiter, (limiter.limit_manager.application_limits, []), test=key, method=method, label="[gold3]Application Limits[/gold3]", ) for name in groups: if name == "root": group_tree = Tree(f"[gold3]{current_app.name}[/gold3]") else: group_tree = Tree(f"[blue]{name}[/blue]") [group_tree.add(renderable()) for renderable in groups[name]] yield group_tree if not watch: console.print(console_renderable()) else: # noqa with Live( console_renderable(), console=console, refresh_per_second=0.4, screen=True, ) as live: while True: try: live.update(console_renderable()) time.sleep(0.4) except KeyboardInterrupt: break else: console.print( f"No Flask-Limiter extension installed on {current_app}", style="bold red", ) @cli.command(help="Clear limits for a specific key") @click.option("--endpoint", default=None, help="Endpoint to filter by") @click.option("--path", default=None, help="Path to filter by") @click.option("--method", default=None, help="HTTP Method to filter by") @click.option("--key", default=None, required=True, help="Key to reset the limits for") @click.option("-y", is_flag=True, help="Skip prompt for confirmation") @with_appcontext def clear( key: str, endpoint: str | None = None, path: str | None = None, method: str = "GET", y: bool = False, ) -> None: with current_app.test_request_context(): limiters = list(current_app.extensions.get("limiter", set())) limiter: Limiter | None = limiters[0] if limiters else None console = Console(theme=limiter_theme) if limiter: manager = limiter.limit_manager filter_endpoint = get_filtered_endpoint(current_app, console, endpoint, path, method) class Details(TypedDict): rule: Rule limits: tuple[list[RuntimeLimit], ...] rule_limits: dict[str, Details] = {} for rule in sorted( current_app.url_map.iter_rules(filter_endpoint), key=lambda r: str(r) ): rule_endpoint = rule.endpoint if rule_endpoint == "static": continue if len(rule_endpoint.split(".")) > 1: bp_fullname = ".".join(rule_endpoint.split(".")[:-1]) rule_limits[rule_endpoint] = Details( rule=rule, limits=manager.resolve_limits(current_app, rule_endpoint, bp_fullname), ) else: rule_limits[rule_endpoint] = Details( rule=rule, limits=manager.resolve_limits(current_app, rule_endpoint, ""), ) application_limits = None if not filter_endpoint: application_limits = limiter.limit_manager.application_limits if not y: # noqa if application_limits: console.print( render_limits( current_app, limiter, (application_limits, []), label="Application Limits", test=key, ) ) for endpoint, details in rule_limits.items(): if details["limits"]: console.print( render_limits( current_app, limiter, details["limits"], endpoint, rule=details["rule"], test=key, ) ) if y or Confirm.ask(f"Proceed with resetting limits for key: [danger]{key}[/danger]?"): if application_limits: node = Tree("Application Limits") for limit in application_limits: limiter.limiter.clear( limit.limit, key, limit.scope_for("", method), ) node.add(f"{render_limit(limit)}: [success]Cleared[/success]") console.print(node) for endpoint, details in rule_limits.items(): if details["limits"]: node = Tree(endpoint) default, decorated = details["limits"] for limit in default + decorated: if ( limit.per_method and details["rule"] and details["rule"].methods and not method ): for rule_method in details["rule"].methods: limiter.limiter.clear( limit.limit, key, limit.scope_for(endpoint, rule_method), ) else: limiter.limiter.clear( limit.limit, key, limit.scope_for(endpoint, method), ) node.add(f"{render_limit(limit)}: [success]Cleared[/success]") console.print(node) else: console.print( f"No Flask-Limiter extension installed on {current_app}", style="bold red", ) if __name__ == "__main__": # noqa cli()