"""Abuse guard for the support endpoints.

Every guarded request is hashed (IP, e-mail, ticket id), counted against a
sliding time window and written to an audit trail.  The audit trail lives
either in the ``support_request_audit`` table (default) or in a
process-local list when ``SUPPORT_GUARD_BACKEND=memory``.
"""

from __future__ import annotations

import hashlib
import hmac
import logging
import os
import threading
from collections import deque
from datetime import datetime, timedelta, timezone
from typing import Deque

from fastapi import Request

from app.services.db import db_connection

logger = logging.getLogger(__name__)

# Process-local audit trail used when the backend mode is "memory".
# All access goes through _MEMORY_LOCK.
_MEMORY_LOCK = threading.Lock()
_MEMORY_EVENTS: list[dict] = []


class SupportGuardRejected(Exception):
    """Raised when a support request is refused by the guard.

    Attributes:
        status_code: HTTP status the caller should respond with.
        detail: Human-readable rejection message.
    """

    def __init__(self, status_code: int, detail: str):
        super().__init__(detail)
        self.status_code = status_code
        self.detail = detail


def _now_utc() -> datetime:
    """Return the current time as a timezone-aware UTC datetime."""
    return datetime.now(timezone.utc)


def _sha256(value: str | None) -> str | None:
    """Return the hex SHA-256 of *value* after trimming and lower-casing.

    Returns ``None`` for ``None``, empty, or whitespace-only input so that
    blank identifiers never collapse into one shared hash.
    """
    if not value:
        return None
    normalized = value.strip().lower()
    if not normalized:
        return None
    return hashlib.sha256(normalized.encode("utf-8")).hexdigest()


def _env_int(name: str, default: int) -> int:
    """Read an integer environment variable, falling back to *default*.

    A missing, blank, or non-integer value yields *default* (the latter
    with a warning) instead of raising on every guarded request.
    """
    raw = (os.getenv(name) or "").strip()
    if not raw:
        return default
    try:
        return int(raw)
    except ValueError:
        logger.warning(
            "Ignoring non-integer value for %s; using default %d", name, default
        )
        return default


def _backend_mode() -> str:
    """Return the configured backend name, lower-cased; defaults to "db"."""
    return (os.getenv("SUPPORT_GUARD_BACKEND") or "db").strip().lower()


def _window() -> timedelta:
    """Return the sliding-window length (default 900 seconds)."""
    return timedelta(seconds=_env_int("SUPPORT_GUARD_WINDOW_SECONDS", 900))


def _create_limit() -> int:
    """Return the per-IP limit for the "ticket_create" endpoint (default 5)."""
    return _env_int("SUPPORT_CREATE_LIMIT", 5)


def _status_limit() -> int:
    """Return the per-IP limit for the "ticket_status" endpoint (default 15)."""
    return _env_int("SUPPORT_STATUS_LIMIT", 15)


def _ticket_probe_limit() -> int:
    """Return the per-ticket limit for "ticket_status" lookups (default 10)."""
    return _env_int("SUPPORT_STATUS_TICKET_LIMIT", 10)


def _captcha_secret() -> str | None:
    """Return the trimmed captcha secret, or ``None`` when unset or blank."""
    return (os.getenv("SUPPORT_CAPTCHA_SECRET") or "").strip() or None


def _request_ip(request: Request) -> str:
    """Return the client IP for *request*, or "unknown" if unavailable.

    The first entry of ``X-Forwarded-For`` wins when present; otherwise the
    socket peer address is used.
    """
    # NOTE(review): X-Forwarded-For is client-supplied unless a trusted proxy
    # overwrites it; if this app is reachable directly, callers can rotate the
    # header to dodge per-IP limits. Confirm the deployment topology.
    forwarded = request.headers.get("x-forwarded-for")
    if forwarded:
        first = forwarded.split(",")[0].strip()
        if first:
            return first
    return request.client.host if request.client else "unknown"


def _validate_captcha(captcha_token: str | None) -> None:
    """Reject the request when a captcha secret is set and the token differs.

    Raises:
        SupportGuardRejected: with status 403 on mismatch or missing token.
    """
    secret = _captcha_secret()
    if not secret:
        return
    supplied = (captcha_token or "").encode("utf-8")
    # Constant-time comparison so response timing does not leak how much of
    # the secret a guessed token got right.
    if not hmac.compare_digest(supplied, secret.encode("utf-8")):
        raise SupportGuardRejected(403, "Support verification failed")


def _record_memory_event(record: dict) -> None:
    """Append *record* to the in-memory trail, dropping expired entries."""
    cutoff = _now_utc() - _window()
    with _MEMORY_LOCK:
        # Prune in place so the module-level list object stays the same.
        _MEMORY_EVENTS[:] = [
            entry for entry in _MEMORY_EVENTS if entry["created_at"] >= cutoff
        ]
        _MEMORY_EVENTS.append(record)


def _memory_count(
    endpoint: str,
    ip_hash: str | None,
    ticket_hash: str | None,
    cutoff: datetime,
) -> tuple[int, int]:
    """Count in-memory events for *endpoint* recorded at or after *cutoff*.

    Returns:
        ``(ip_count, ticket_count)``; ``ticket_count`` is 0 when
        *ticket_hash* is falsy.
    """
    ip_count = 0
    ticket_count = 0
    with _MEMORY_LOCK:
        for entry in _MEMORY_EVENTS:
            if entry["endpoint"] != endpoint or entry["created_at"] < cutoff:
                continue
            if entry["ip_hash"] == ip_hash:
                ip_count += 1
            if ticket_hash and entry["ticket_hash"] == ticket_hash:
                ticket_count += 1
    return ip_count, ticket_count


def _record_db_event(record: dict) -> None:
    """Insert *record* into the ``support_request_audit`` table."""
    with db_connection() as conn:
        # NOTE(review): the nested ``with conn`` presumably scopes the
        # transaction (commit on success) - confirm against the driver in use.
        with conn:
            with conn.cursor() as cur:
                cur.execute(
                    """
                    INSERT INTO support_request_audit (
                        endpoint, ip_hash, email_hash, ticket_hash,
                        blocked, reason, created_at
                    )
                    VALUES (%s, %s, %s, %s, %s, %s, %s)
                    """,
                    (
                        record["endpoint"],
                        record["ip_hash"],
                        record["email_hash"],
                        record["ticket_hash"],
                        record["blocked"],
                        record["reason"],
                        record["created_at"],
                    ),
                )


def _db_count(
    endpoint: str,
    ip_hash: str | None,
    ticket_hash: str | None,
    cutoff: datetime,
) -> tuple[int, int]:
    """Count audited events for *endpoint* recorded at or after *cutoff*.

    Returns:
        ``(ip_count, ticket_count)``; the ticket query is skipped and
        ``ticket_count`` is 0 when *ticket_hash* is falsy.
    """
    with db_connection() as conn:
        with conn.cursor() as cur:
            # IS NOT DISTINCT FROM makes a NULL ip_hash match NULL rows.
            cur.execute(
                """
                SELECT COUNT(*)
                FROM support_request_audit
                WHERE endpoint = %s
                  AND ip_hash IS NOT DISTINCT FROM %s
                  AND created_at >= %s
                """,
                (endpoint, ip_hash, cutoff),
            )
            ip_count = cur.fetchone()[0] or 0

            ticket_count = 0
            if ticket_hash:
                cur.execute(
                    """
                    SELECT COUNT(*)
                    FROM support_request_audit
                    WHERE endpoint = %s
                      AND ticket_hash = %s
                      AND created_at >= %s
                    """,
                    (endpoint, ticket_hash, cutoff),
                )
                ticket_count = cur.fetchone()[0] or 0
    return ip_count, ticket_count


def _determine_limits(endpoint: str, ip_count: int, ticket_count: int) -> str | None:
    """Return the rate-limit reason that applies, or ``None`` if allowed.

    The per-IP limits are checked before the per-ticket probe limit.
    """
    if endpoint == "ticket_create" and ip_count >= _create_limit():
        return "create_rate_limited"
    if endpoint == "ticket_status" and ip_count >= _status_limit():
        return "status_rate_limited"
    if endpoint == "ticket_status" and ticket_count >= _ticket_probe_limit():
        return "ticket_probe_limited"
    return None


def _audit_attempt(record: dict) -> None:
    """Persist *record* to whichever backend is configured."""
    if _backend_mode() == "memory":
        _record_memory_event(record)
        return
    _record_db_event(record)


def _count_recent(
    endpoint: str,
    ip_hash: str | None,
    ticket_hash: str | None,
    cutoff: datetime,
) -> tuple[int, int]:
    """Return ``(ip_count, ticket_count)`` from the configured backend."""
    if _backend_mode() == "memory":
        return _memory_count(endpoint, ip_hash, ticket_hash, cutoff)
    return _db_count(endpoint, ip_hash, ticket_hash, cutoff)


def enforce_support_guard(
    *,
    request: Request,
    endpoint: str,
    email: str | None = None,
    ticket_id: str | None = None,
    captcha_token: str | None = None,
) -> None:
    """Validate, rate-limit and audit one support request.

    Args:
        request: Incoming request; used only to derive the client IP.
        endpoint: Logical endpoint name; "ticket_create" and "ticket_status"
            are the values the limit rules recognise.
        email: Optional requester e-mail; stored only as a hash.
        ticket_id: Optional ticket identifier; stored only as a hash and
            used for the per-ticket probe limit.
        captcha_token: Token checked against the configured captcha secret.

    Raises:
        SupportGuardRejected: 403 when captcha validation fails, 429 when a
            rate limit has been reached.
    """
    _validate_captcha(captcha_token)

    now = _now_utc()
    cutoff = now - _window()
    ip_hash = _sha256(_request_ip(request))
    email_hash = _sha256(email)
    ticket_hash = _sha256(ticket_id)

    ip_count, ticket_count = _count_recent(endpoint, ip_hash, ticket_hash, cutoff)
    reason = _determine_limits(endpoint, ip_count, ticket_count)

    # Every attempt is audited, including blocked ones; the counters do not
    # filter on "blocked", so rejected attempts also count toward the window.
    record = {
        "endpoint": endpoint,
        "ip_hash": ip_hash,
        "email_hash": email_hash,
        "ticket_hash": ticket_hash,
        "blocked": reason is not None,
        "reason": reason,
        "created_at": now,
    }
    _audit_attempt(record)

    if reason is not None:
        logger.warning(
            "Support request blocked",
            extra={
                "endpoint": endpoint,
                "reason": reason,
                "ip_hash": ip_hash,
                "ticket_hash": ticket_hash,
            },
        )
        raise SupportGuardRejected(
            429, "Too many support requests. Please try again later."
        )


def reset_memory_support_guard_state() -> None:
    """Clear the in-memory audit trail."""
    with _MEMORY_LOCK:
        _MEMORY_EVENTS.clear()