225 lines
6.7 KiB
Python
225 lines
6.7 KiB
Python
from __future__ import annotations
|
|
|
|
import hashlib
|
|
import logging
|
|
import os
|
|
import threading
|
|
from collections import deque
|
|
from datetime import datetime, timedelta, timezone
|
|
from typing import Deque
|
|
|
|
from fastapi import Request
|
|
|
|
from app.services.db import db_connection
|
|
|
|
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Held by every function in this module that reads or mutates _MEMORY_EVENTS.
_MEMORY_LOCK = threading.Lock()
# In-process audit records, used when SUPPORT_GUARD_BACKEND resolves to "memory".
_MEMORY_EVENTS: list[dict] = []
|
|
|
|
|
|
class SupportGuardRejected(Exception):
    """Raised when the support guard refuses a request.

    Attributes:
        status_code: Numeric status supplied by the raiser (this module uses
            403 and 429).
        detail: Human-readable explanation; also used as the exception message.
    """

    def __init__(self, status_code: int, detail: str):
        self.status_code = status_code
        self.detail = detail
        # The detail doubles as the exception message, so str(exc) == detail.
        super().__init__(detail)
|
|
|
|
|
|
def _now_utc() -> datetime:
    """Return the current time as a timezone-aware UTC datetime."""
    current = datetime.now(tz=timezone.utc)
    return current
|
|
|
|
|
|
def _sha256(value: str | None) -> str | None:
    """Return the hex SHA-256 digest of *value* after normalisation.

    The value is stripped of surrounding whitespace and lower-cased before
    hashing, so inputs that differ only in case or padding share a digest.

    Args:
        value: Text to hash, or ``None``.

    Returns:
        The 64-character hexadecimal digest, or ``None`` when *value* is
        ``None``, empty, or contains only whitespace.
    """
    if value is None:
        return None
    # Normalise before the emptiness check so whitespace-only input counts as
    # "no value" instead of being hashed as the empty string.
    normalized = value.strip().lower()
    if not normalized:
        return None
    return hashlib.sha256(normalized.encode("utf-8")).hexdigest()
|
|
|
|
|
|
def _backend_mode() -> str:
    """Return the configured audit backend name, trimmed and lower-cased.

    Reads ``SUPPORT_GUARD_BACKEND``; an unset or empty variable yields "db".
    """
    configured = os.getenv("SUPPORT_GUARD_BACKEND")
    if not configured:
        configured = "db"
    return configured.strip().lower()
|
|
|
|
|
|
def _window() -> timedelta:
    """Return the rate-limit look-back window.

    Reads ``SUPPORT_GUARD_WINDOW_SECONDS``. An unset, empty, or
    whitespace-only variable falls back to 900 seconds.

    Raises:
        ValueError: If the variable holds a non-blank value that is not an
            integer.
    """
    # Treat a blank value like an unset one instead of crashing in int("").
    raw = (os.getenv("SUPPORT_GUARD_WINDOW_SECONDS") or "").strip()
    return timedelta(seconds=int(raw or "900"))
|
|
|
|
|
|
def _create_limit() -> int:
    """Return the threshold applied to "ticket_create" requests.

    Reads ``SUPPORT_CREATE_LIMIT``. An unset, empty, or whitespace-only
    variable falls back to 5.

    Raises:
        ValueError: If the variable holds a non-blank value that is not an
            integer.
    """
    # Treat a blank value like an unset one instead of crashing in int("").
    raw = (os.getenv("SUPPORT_CREATE_LIMIT") or "").strip()
    return int(raw or "5")
|
|
|
|
|
|
def _status_limit() -> int:
    """Return the per-IP threshold applied to "ticket_status" requests.

    Reads ``SUPPORT_STATUS_LIMIT``. An unset, empty, or whitespace-only
    variable falls back to 15.

    Raises:
        ValueError: If the variable holds a non-blank value that is not an
            integer.
    """
    # Treat a blank value like an unset one instead of crashing in int("").
    raw = (os.getenv("SUPPORT_STATUS_LIMIT") or "").strip()
    return int(raw or "15")
|
|
|
|
|
|
def _ticket_probe_limit() -> int:
    """Return the per-ticket threshold applied to "ticket_status" requests.

    Reads ``SUPPORT_STATUS_TICKET_LIMIT``. An unset, empty, or
    whitespace-only variable falls back to 10.

    Raises:
        ValueError: If the variable holds a non-blank value that is not an
            integer.
    """
    # Treat a blank value like an unset one instead of crashing in int("").
    raw = (os.getenv("SUPPORT_STATUS_TICKET_LIMIT") or "").strip()
    return int(raw or "10")
|
|
|
|
|
|
def _captcha_secret() -> str | None:
    """Return the trimmed ``SUPPORT_CAPTCHA_SECRET`` value, or ``None``.

    ``None`` is returned when the variable is unset, empty, or only
    whitespace.
    """
    configured = os.getenv("SUPPORT_CAPTCHA_SECRET")
    if not configured:
        return None
    trimmed = configured.strip()
    return trimmed if trimmed else None
|
|
|
|
|
|
def _request_ip(request: Request) -> str:
    """Return the address string used to identify the caller.

    The first non-empty entry of the ``x-forwarded-for`` header wins; failing
    that, ``request.client.host`` is used, and "unknown" when no client is
    attached to the request.
    """
    # NOTE(review): the header value is supplied by the caller. Unless a
    # trusted proxy overwrites it, a client could vary it per request --
    # confirm the deployment sits behind such a proxy.
    header_value = request.headers.get("x-forwarded-for") or ""
    leading_hop = header_value.partition(",")[0].strip()
    if leading_hop:
        return leading_hop

    peer = request.client
    if peer:
        return peer.host
    return "unknown"
|
|
|
|
|
|
def _validate_captcha(captcha_token: str | None) -> None:
    """Reject the request when a captcha secret is configured and not matched.

    When ``_captcha_secret()`` returns nothing, verification is skipped.

    Args:
        captcha_token: Token supplied by the caller, or ``None``.

    Raises:
        SupportGuardRejected: With status 403 when a secret is configured and
            *captcha_token* does not equal it.
    """
    # Function-scope import keeps this block self-contained.
    import hmac

    secret = _captcha_secret()
    if not secret:
        return

    # Compare in constant time so response timing does not reveal how much of
    # the secret a guess got right. Encoding to bytes lets compare_digest
    # accept non-ASCII text; a missing token is compared as the empty string,
    # which can never equal a non-empty secret.
    supplied = (captcha_token or "").encode("utf-8")
    if not hmac.compare_digest(supplied, secret.encode("utf-8")):
        raise SupportGuardRejected(403, "Support verification failed")
|
|
|
|
|
|
def _record_memory_event(record: dict) -> None:
    """Append *record* to the in-memory audit list, pruning expired entries.

    Entries whose "created_at" is older than now minus the configured window
    are discarded before the new record is added.
    """
    oldest_allowed = _now_utc() - _window()
    with _MEMORY_LOCK:
        retained = []
        for event in _MEMORY_EVENTS:
            if event["created_at"] >= oldest_allowed:
                retained.append(event)
        retained.append(record)
        # Slice-assign so the module-level list object itself is updated.
        _MEMORY_EVENTS[:] = retained
|
|
|
|
|
|
def _memory_count(endpoint: str, ip_hash: str | None, ticket_hash: str | None, cutoff: datetime) -> tuple[int, int]:
    """Count in-memory audit entries for *endpoint* created at or after *cutoff*.

    Returns:
        ``(ip_count, ticket_count)``: entries whose "ip_hash" equals *ip_hash*,
        and entries whose "ticket_hash" equals *ticket_hash*. The ticket count
        is 0 when *ticket_hash* is falsy.
    """
    ip_hits = 0
    ticket_hits = 0
    with _MEMORY_LOCK:
        for event in _MEMORY_EVENTS:
            if event["endpoint"] == endpoint and event["created_at"] >= cutoff:
                if event["ip_hash"] == ip_hash:
                    ip_hits += 1
                if ticket_hash and event["ticket_hash"] == ticket_hash:
                    ticket_hits += 1
    return ip_hits, ticket_hits
|
|
|
|
|
|
def _record_db_event(record: dict) -> None:
    """Insert *record* as one row of the ``support_request_audit`` table.

    The record must carry the keys "endpoint", "ip_hash", "email_hash",
    "ticket_hash", "blocked", "reason" and "created_at".
    """
    insert_sql = """
                    INSERT INTO support_request_audit (
                        endpoint, ip_hash, email_hash, ticket_hash, blocked, reason, created_at
                    )
                    VALUES (%s, %s, %s, %s, %s, %s, %s)
                    """
    column_order = (
        "endpoint",
        "ip_hash",
        "email_hash",
        "ticket_hash",
        "blocked",
        "reason",
        "created_at",
    )
    # Same nesting as three stacked with-blocks: connection, then the
    # connection's own context manager, then a cursor.
    with db_connection() as conn, conn, conn.cursor() as cur:
        cur.execute(insert_sql, tuple(record[column] for column in column_order))
|
|
|
|
|
|
def _db_count(endpoint: str, ip_hash: str | None, ticket_hash: str | None, cutoff: datetime) -> tuple[int, int]:
    """Count ``support_request_audit`` rows for *endpoint* created at or after *cutoff*.

    Returns:
        ``(ip_count, ticket_count)``: rows matching *ip_hash* (a NULL-safe
        comparison), and rows matching *ticket_hash*. The ticket query is
        skipped and 0 is reported when *ticket_hash* is falsy.
    """
    by_ip_sql = """
                SELECT COUNT(*)
                FROM support_request_audit
                WHERE endpoint = %s
                  AND ip_hash IS NOT DISTINCT FROM %s
                  AND created_at >= %s
                """
    by_ticket_sql = """
                    SELECT COUNT(*)
                    FROM support_request_audit
                    WHERE endpoint = %s
                      AND ticket_hash = %s
                      AND created_at >= %s
                    """
    with db_connection() as conn, conn.cursor() as cur:
        cur.execute(by_ip_sql, (endpoint, ip_hash, cutoff))
        ip_total = cur.fetchone()[0] or 0

        ticket_total = 0
        if ticket_hash:
            cur.execute(by_ticket_sql, (endpoint, ticket_hash, cutoff))
            ticket_total = cur.fetchone()[0] or 0
    return ip_total, ticket_total
|
|
|
|
|
|
def _determine_limits(endpoint: str, ip_count: int, ticket_count: int) -> str | None:
    """Return the reason a request must be blocked, or ``None`` to allow it.

    "ticket_create" is checked against the create limit by IP count.
    "ticket_status" is checked against the status limit by IP count first,
    then against the ticket probe limit by ticket count. Any other endpoint
    is never limited.
    """
    if endpoint == "ticket_create":
        if ip_count >= _create_limit():
            return "create_rate_limited"
        return None

    if endpoint == "ticket_status":
        if ip_count >= _status_limit():
            return "status_rate_limited"
        if ticket_count >= _ticket_probe_limit():
            return "ticket_probe_limited"

    return None
|
|
|
|
|
|
def _audit_attempt(record: dict) -> None:
    """Persist *record* through the configured backend.

    The in-memory store is used when the backend mode is "memory"; every
    other mode writes to the database.
    """
    use_memory = _backend_mode() == "memory"
    store = _record_memory_event if use_memory else _record_db_event
    store(record)
|
|
|
|
|
|
def _count_recent(endpoint: str, ip_hash: str | None, ticket_hash: str | None, cutoff: datetime) -> tuple[int, int]:
    """Return ``(ip_count, ticket_count)`` from the configured backend.

    The in-memory store is queried when the backend mode is "memory"; every
    other mode queries the database.
    """
    use_memory = _backend_mode() == "memory"
    counter = _memory_count if use_memory else _db_count
    return counter(endpoint, ip_hash, ticket_hash, cutoff)
|
|
|
|
|
|
def enforce_support_guard(
    *,
    request: Request,
    endpoint: str,
    email: str | None = None,
    ticket_id: str | None = None,
    captcha_token: str | None = None,
) -> None:
    """Validate, rate-limit and audit one support request.

    Steps, in order: check the captcha token, count recent attempts for this
    endpoint by hashed IP and hashed ticket id, decide whether a limit is
    exceeded, write an audit record (for allowed and blocked attempts alike),
    and finally raise if the attempt was blocked.

    Args:
        request: Incoming request; used only to derive the caller address.
        endpoint: Endpoint label, e.g. "ticket_create" or "ticket_status".
        email: Optional caller e-mail; stored only as a hash.
        ticket_id: Optional ticket identifier; stored only as a hash.
        captcha_token: Optional verification token.

    Raises:
        SupportGuardRejected: 403 from captcha validation, or 429 when a rate
            limit is exceeded.
    """
    _validate_captcha(captcha_token)

    observed_at = _now_utc()
    window_start = observed_at - _window()

    hashed_ip = _sha256(_request_ip(request))
    hashed_email = _sha256(email)
    hashed_ticket = _sha256(ticket_id)

    ip_hits, ticket_hits = _count_recent(endpoint, hashed_ip, hashed_ticket, window_start)
    block_reason = _determine_limits(endpoint, ip_hits, ticket_hits)
    is_blocked = block_reason is not None

    # The attempt is recorded before any rejection, so blocked attempts are
    # audited as well.
    _audit_attempt(
        dict(
            endpoint=endpoint,
            ip_hash=hashed_ip,
            email_hash=hashed_email,
            ticket_hash=hashed_ticket,
            blocked=is_blocked,
            reason=block_reason,
            created_at=observed_at,
        )
    )

    if not is_blocked:
        return

    log_context = {
        "endpoint": endpoint,
        "reason": block_reason,
        "ip_hash": hashed_ip,
        "ticket_hash": hashed_ticket,
    }
    logger.warning("Support request blocked", extra=log_context)
    raise SupportGuardRejected(429, "Too many support requests. Please try again later.")
|
|
|
|
|
|
def reset_memory_support_guard_state() -> None:
    """Discard every audit record held by the in-memory backend."""
    with _MEMORY_LOCK:
        # Empty the list in place so existing references stay valid.
        del _MEMORY_EVENTS[:]
|