225 lines
6.7 KiB
Python

from __future__ import annotations

import hashlib
import hmac
import logging
import os
import threading
from collections import deque
from datetime import datetime, timedelta, timezone
from typing import Deque

from fastapi import Request

from app.services.db import db_connection
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Held around every read and write of _MEMORY_EVENTS in this module.
_MEMORY_LOCK = threading.Lock()
# Audit records kept in process memory; only used when the backend mode is
# "memory". Each record is a dict with at least "endpoint", "ip_hash",
# "ticket_hash" and "created_at" keys (those are the ones read back here).
_MEMORY_EVENTS: list[dict] = []
class SupportGuardRejected(Exception):
    """Raised when the support guard refuses a request.

    Attributes are set in ``__init__``:
      * ``status_code`` - numeric status supplied by the raiser
        (this module uses 403 and 429).
      * ``detail`` - human-readable message; also used as the exception text.
    """

    def __init__(self, status_code: int, detail: str):
        """Store the status code and detail, passing *detail* to ``Exception``."""
        super().__init__(detail)
        self.detail = detail
        self.status_code = status_code
def _now_utc() -> datetime:
    """Return the current time as a timezone-aware ``datetime`` in UTC."""
    current = datetime.now(tz=timezone.utc)
    return current
def _sha256(value: str | None) -> str | None:
    """Return the SHA-256 hex digest of the normalised *value*, or ``None``.

    The value is stripped of surrounding whitespace and lower-cased before it
    is hashed, so inputs that differ only in case or padding share a digest.

    ``None``, the empty string and whitespace-only strings all yield ``None``.
    Checking emptiness *after* normalisation matters: otherwise a blank field
    such as ``"   "`` would be stored as the digest of the empty string and be
    treated like a real identifier by the callers that test the hash for
    truthiness.
    """
    if not value:
        return None
    normalized = value.strip().lower()
    if not normalized:
        return None
    return hashlib.sha256(normalized.encode("utf-8")).hexdigest()
def _backend_mode() -> str:
    """Return the configured audit backend name, trimmed and lower-cased.

    Reads ``SUPPORT_GUARD_BACKEND`` on every call; an unset or empty variable
    means ``"db"``.
    """
    configured = os.getenv("SUPPORT_GUARD_BACKEND")
    if not configured:
        configured = "db"
    return configured.strip().lower()
def _window() -> timedelta:
    """Return the look-back window used when counting recent requests.

    The length in seconds is read from ``SUPPORT_GUARD_WINDOW_SECONDS`` on
    every call and defaults to 900. An unset, empty or whitespace-only value
    falls back to the default instead of crashing in ``int("")``; this matches
    how ``_backend_mode`` and ``_captcha_secret`` treat blank variables.

    Raises:
        ValueError: If the variable holds text that is not an integer.
    """
    raw = (os.getenv("SUPPORT_GUARD_WINDOW_SECONDS") or "").strip() or "900"
    return timedelta(seconds=int(raw))
def _create_limit() -> int:
    """Return the per-IP count at which ``ticket_create`` requests are blocked.

    Read from ``SUPPORT_CREATE_LIMIT`` on every call; defaults to 5. An unset,
    empty or whitespace-only value falls back to the default instead of
    crashing in ``int("")``.

    Raises:
        ValueError: If the variable holds text that is not an integer.
    """
    raw = (os.getenv("SUPPORT_CREATE_LIMIT") or "").strip() or "5"
    return int(raw)
def _status_limit() -> int:
    """Return the per-IP count at which ``ticket_status`` requests are blocked.

    Read from ``SUPPORT_STATUS_LIMIT`` on every call; defaults to 15. An unset,
    empty or whitespace-only value falls back to the default instead of
    crashing in ``int("")``.

    Raises:
        ValueError: If the variable holds text that is not an integer.
    """
    raw = (os.getenv("SUPPORT_STATUS_LIMIT") or "").strip() or "15"
    return int(raw)
def _ticket_probe_limit() -> int:
    """Return the per-ticket count at which ``ticket_status`` lookups are blocked.

    Read from ``SUPPORT_STATUS_TICKET_LIMIT`` on every call; defaults to 10.
    An unset, empty or whitespace-only value falls back to the default instead
    of crashing in ``int("")``.

    Raises:
        ValueError: If the variable holds text that is not an integer.
    """
    raw = (os.getenv("SUPPORT_STATUS_TICKET_LIMIT") or "").strip() or "10"
    return int(raw)
def _captcha_secret() -> str | None:
    """Return the trimmed ``SUPPORT_CAPTCHA_SECRET`` value, or ``None``.

    ``None`` is returned when the variable is unset, empty, or contains only
    whitespace.
    """
    raw = os.getenv("SUPPORT_CAPTCHA_SECRET")
    trimmed = raw.strip() if raw else ""
    if trimmed:
        return trimmed
    return None
def _request_ip(request: Request) -> str:
    """Return the address this request should be attributed to.

    The first comma-separated entry of the ``x-forwarded-for`` header wins
    when it is non-blank; otherwise the connection's peer host is used, and
    ``"unknown"`` when no client information is available.
    """
    # NOTE(review): the header is taken at face value here. Unless a trusted
    # proxy overwrites it, a caller could choose its own value - confirm
    # against the deployment.
    header_value = request.headers.get("x-forwarded-for")
    if header_value:
        leading_hop = header_value.partition(",")[0].strip()
        if leading_hop:
            return leading_hop
    if request.client:
        return request.client.host
    return "unknown"
def _validate_captcha(captcha_token: str | None) -> None:
    """Reject the request unless *captcha_token* matches the configured secret.

    When no secret is configured (``_captcha_secret()`` returns ``None``) the
    check is skipped entirely. A missing or empty token never matches a
    configured secret.

    Raises:
        SupportGuardRejected: With status 403 when a secret is configured and
            the supplied token does not equal it.
    """
    secret = _captcha_secret()
    if not secret:
        return
    # Compare as bytes with a constant-time primitive so the response time
    # does not reveal how many leading characters of the secret were guessed.
    # "surrogatepass" keeps encoding from raising on lone surrogates, which
    # can appear in caller-supplied text.
    supplied = (captcha_token or "").encode("utf-8", "surrogatepass")
    expected = secret.encode("utf-8", "surrogatepass")
    if not hmac.compare_digest(supplied, expected):
        raise SupportGuardRejected(403, "Support verification failed")
def _record_memory_event(record: dict) -> None:
    """Append *record* to the in-memory audit list, pruning expired entries.

    Entries whose ``"created_at"`` is older than now minus the guard window
    are dropped first. The list is updated in place while holding
    ``_MEMORY_LOCK``.
    """
    oldest_allowed = _now_utc() - _window()
    with _MEMORY_LOCK:
        retained = [
            existing
            for existing in _MEMORY_EVENTS
            if existing["created_at"] >= oldest_allowed
        ]
        retained.append(record)
        # Slice assignment keeps the same list object that other functions
        # in this module reference.
        _MEMORY_EVENTS[:] = retained
def _memory_count(endpoint: str, ip_hash: str | None, ticket_hash: str | None, cutoff: datetime) -> tuple[int, int]:
    """Count recent in-memory audit records for an endpoint.

    Only records for *endpoint* with ``"created_at"`` at or after *cutoff*
    are considered. The scan runs while holding ``_MEMORY_LOCK``.

    Returns:
        ``(ip_count, ticket_count)``: how many of those records carry
        *ip_hash*, and how many carry *ticket_hash*. The ticket count is 0
        when *ticket_hash* is falsy.
    """
    ip_matches = 0
    ticket_matches = 0
    with _MEMORY_LOCK:
        for entry in _MEMORY_EVENTS:
            if entry["endpoint"] != endpoint:
                continue
            if entry["created_at"] < cutoff:
                continue
            if entry["ip_hash"] == ip_hash:
                ip_matches += 1
            if ticket_hash and entry["ticket_hash"] == ticket_hash:
                ticket_matches += 1
    return ip_matches, ticket_matches
def _record_db_event(record: dict) -> None:
    """Insert one audit *record* into the ``support_request_audit`` table.

    *record* must provide the keys ``endpoint``, ``ip_hash``, ``email_hash``,
    ``ticket_hash``, ``blocked``, ``reason`` and ``created_at``; they are
    bound to the statement in that order.
    """
    # Assembled from pieces so the SQL text sent to the driver is unchanged
    # regardless of how this function is indented.
    insert_sql = (
        "\n"
        "INSERT INTO support_request_audit (\n"
        "endpoint, ip_hash, email_hash, ticket_hash, blocked, reason, created_at\n"
        ")\n"
        "VALUES (%s, %s, %s, %s, %s, %s, %s)\n"
    )
    # Entering ``conn`` itself presumably scopes a transaction (driver
    # dependent) - confirm against db_connection.
    with db_connection() as conn, conn, conn.cursor() as cur:
        params = (
            record["endpoint"],
            record["ip_hash"],
            record["email_hash"],
            record["ticket_hash"],
            record["blocked"],
            record["reason"],
            record["created_at"],
        )
        cur.execute(insert_sql, params)
def _db_count(endpoint: str, ip_hash: str | None, ticket_hash: str | None, cutoff: datetime) -> tuple[int, int]:
    """Count recent ``support_request_audit`` rows for an endpoint.

    Only rows for *endpoint* with ``created_at`` at or after *cutoff* are
    counted.

    Returns:
        ``(ip_count, ticket_count)``: rows matching *ip_hash* (compared with
        ``IS NOT DISTINCT FROM``, so a ``None`` hash matches NULL rows) and
        rows matching *ticket_hash*. The ticket query is skipped and 0 is
        returned for it when *ticket_hash* is falsy.
    """
    # Assembled from pieces so the SQL text sent to the driver is unchanged
    # regardless of how this function is indented.
    by_ip_sql = (
        "\n"
        "SELECT COUNT(*)\n"
        "FROM support_request_audit\n"
        "WHERE endpoint = %s\n"
        "AND ip_hash IS NOT DISTINCT FROM %s\n"
        "AND created_at >= %s\n"
    )
    by_ticket_sql = (
        "\n"
        "SELECT COUNT(*)\n"
        "FROM support_request_audit\n"
        "WHERE endpoint = %s\n"
        "AND ticket_hash = %s\n"
        "AND created_at >= %s\n"
    )
    ticket_total = 0
    with db_connection() as conn, conn.cursor() as cur:
        cur.execute(by_ip_sql, (endpoint, ip_hash, cutoff))
        ip_total = cur.fetchone()[0] or 0
        if ticket_hash:
            cur.execute(by_ticket_sql, (endpoint, ticket_hash, cutoff))
            ticket_total = cur.fetchone()[0] or 0
    return ip_total, ticket_total
def _determine_limits(endpoint: str, ip_count: int, ticket_count: int) -> str | None:
    """Return the reason a request must be blocked, or ``None`` to allow it.

    * ``"ticket_create"``: blocked with ``"create_rate_limited"`` once
      *ip_count* reaches the create limit.
    * ``"ticket_status"``: blocked with ``"status_rate_limited"`` once
      *ip_count* reaches the status limit, otherwise with
      ``"ticket_probe_limited"`` once *ticket_count* reaches the per-ticket
      limit.
    * Any other endpoint is never limited.
    """
    if endpoint == "ticket_create":
        if ip_count >= _create_limit():
            return "create_rate_limited"
        return None
    if endpoint == "ticket_status":
        # The per-IP check takes precedence over the per-ticket check.
        if ip_count >= _status_limit():
            return "status_rate_limited"
        if ticket_count >= _ticket_probe_limit():
            return "ticket_probe_limited"
    return None
def _audit_attempt(record: dict) -> None:
    """Persist *record* using the backend selected by ``_backend_mode()``.

    The ``"memory"`` mode stores it in process memory; every other mode
    writes it to the database.
    """
    if _backend_mode() != "memory":
        _record_db_event(record)
    else:
        _record_memory_event(record)
def _count_recent(endpoint: str, ip_hash: str | None, ticket_hash: str | None, cutoff: datetime) -> tuple[int, int]:
    """Return ``(ip_count, ticket_count)`` from the active audit backend.

    The ``"memory"`` mode counts in-process records; every other mode queries
    the database. Arguments are forwarded unchanged.
    """
    use_memory = _backend_mode() == "memory"
    counter = _memory_count if use_memory else _db_count
    return counter(endpoint, ip_hash, ticket_hash, cutoff)
def enforce_support_guard(
    *,
    request: Request,
    endpoint: str,
    email: str | None = None,
    ticket_id: str | None = None,
    captcha_token: str | None = None,
) -> None:
    """Validate, rate-limit and audit one support request.

    Steps, in order:
      1. Check *captcha_token*; a failure raises before anything is counted
         or audited.
      2. Hash the caller's IP, *email* and *ticket_id*.
      3. Count recent audit records for *endpoint* inside the guard window
         and decide whether a limit has been reached.
      4. Write an audit record for this attempt, blocked or not.
      5. If blocked, log a warning and raise.

    Args:
        request: Incoming request; used only to derive the client address.
        endpoint: Endpoint label used for counting and limit selection.
        email: Optional address; stored only as a hash in the audit record.
        ticket_id: Optional ticket identifier; hashed for per-ticket counts.
        captcha_token: Optional verification token.

    Raises:
        SupportGuardRejected: 403 from the captcha check, or 429 when a rate
            limit has been reached.
    """
    _validate_captcha(captcha_token)

    checked_at = _now_utc()
    window_start = checked_at - _window()
    hashed_ip = _sha256(_request_ip(request))
    hashed_email = _sha256(email)
    hashed_ticket = _sha256(ticket_id)

    recent_by_ip, recent_by_ticket = _count_recent(
        endpoint, hashed_ip, hashed_ticket, window_start
    )
    block_reason = _determine_limits(endpoint, recent_by_ip, recent_by_ticket)
    is_blocked = block_reason is not None

    # The attempt is recorded before raising, so blocked attempts are stored
    # as well.
    _audit_attempt(
        {
            "endpoint": endpoint,
            "ip_hash": hashed_ip,
            "email_hash": hashed_email,
            "ticket_hash": hashed_ticket,
            "blocked": is_blocked,
            "reason": block_reason,
            "created_at": checked_at,
        }
    )

    if not is_blocked:
        return

    log_fields = {
        "endpoint": endpoint,
        "reason": block_reason,
        "ip_hash": hashed_ip,
        "ticket_hash": hashed_ticket,
    }
    logger.warning("Support request blocked", extra=log_fields)
    raise SupportGuardRejected(429, "Too many support requests. Please try again later.")
def reset_memory_support_guard_state() -> None:
    """Remove every record from the in-memory audit list.

    The list is emptied in place while holding ``_MEMORY_LOCK``.
    """
    with _MEMORY_LOCK:
        del _MEMORY_EVENTS[:]