"""Database engine, ORM session factory, and raw-connection helpers.

Connection settings come either from ``DATABASE_URL`` or from the individual
``DB_*`` / ``PG*`` environment variables (see ``_db_config``).
"""

import logging
import os
import threading
import time
from contextlib import contextmanager
from typing import Generator

from psycopg2 import InterfaceError as PGInterfaceError
from psycopg2 import OperationalError as PGOperationalError
from sqlalchemy import create_engine, schema, text
from sqlalchemy.engine import URL, Engine
from sqlalchemy.exc import InterfaceError as SAInterfaceError
from sqlalchemy.exc import OperationalError as SAOperationalError
from sqlalchemy.orm import declarative_base, sessionmaker

logger = logging.getLogger(__name__)

Base = declarative_base()

_ENGINE: Engine | None = None
_ENGINE_LOCK = threading.Lock()

# Errors treated as transient (lost/refused connection) and therefore retried.
# Errors surfacing through SQLAlchemy use its wrapper classes, while cursors
# obtained from a raw DBAPI connection raise the psycopg2 classes directly,
# so both families are listed.
_TRANSIENT_ERRORS = (
    SAOperationalError,
    SAInterfaceError,
    PGOperationalError,
    PGInterfaceError,
)


def _safe_rollback(conn) -> None:
    """Roll back *conn*, logging instead of raising if the rollback fails.

    Only called while another exception is already being handled, so that a
    failed rollback on a broken connection cannot mask the original error.
    """
    try:
        conn.rollback()
    except Exception:
        logger.debug("Rollback failed; connection is likely unusable", exc_info=True)


class _ConnectionProxy:
    """Wrap a raw DBAPI connection so it can also act as a transaction scope.

    Attribute access is forwarded to the wrapped connection. When used as a
    context manager it commits if the block exits cleanly and rolls back if
    it raises; it never closes the connection itself.
    """

    def __init__(self, conn):
        self._conn = conn

    def __getattr__(self, name):
        return getattr(self._conn, name)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        if exc_type is None:
            try:
                self._conn.commit()
            except Exception:
                # Surface the commit failure even if the rollback also fails.
                _safe_rollback(self._conn)
                raise
        else:
            _safe_rollback(self._conn)
        return False  # never suppress the caller's exception


def _db_config() -> dict[str, str | int]:
    """Build connection settings from the environment.

    Returns ``{"url": DATABASE_URL}`` when that variable is set; otherwise the
    individual settings, preferring ``DB_*`` names, then libpq's ``PG*`` names,
    then hard-coded defaults. Empty variables are treated as unset.
    """
    url = os.getenv("DATABASE_URL")
    if url:
        return {"url": url}
    return {
        "host": os.getenv("DB_HOST") or os.getenv("PGHOST") or "localhost",
        "port": int(os.getenv("DB_PORT") or os.getenv("PGPORT") or "5432"),
        "dbname": os.getenv("DB_NAME") or os.getenv("PGDATABASE") or "trading_db",
        "user": os.getenv("DB_USER") or os.getenv("PGUSER") or "trader",
        "password": os.getenv("DB_PASSWORD") or os.getenv("PGPASSWORD") or "traderpass",
        # `or` (not a getenv default) so an empty variable falls back like the others.
        "connect_timeout": int(os.getenv("DB_CONNECT_TIMEOUT") or "5"),
        "schema": os.getenv("DB_SCHEMA") or os.getenv("PGSCHEMA") or "quant_app",
    }


def get_database_url(cfg: dict[str, str | int] | None = None) -> str:
    """Return a ``postgresql+psycopg2`` URL string for *cfg* (default: ``_db_config()``).

    An explicit ``"url"`` entry is returned verbatim. Otherwise a URL is
    assembled with ``connect_timeout`` and, when a schema is configured, a
    ``search_path`` of ``<schema>,public`` passed through libpq ``options``.
    The password is rendered in clear text because the result is fed to
    ``create_engine``.
    """
    cfg = cfg or _db_config()
    if "url" in cfg:
        return str(cfg["url"])
    schema_name = cfg.get("schema")
    query: dict[str, str] = {"connect_timeout": str(cfg["connect_timeout"])}
    if schema_name:
        query["options"] = f"-csearch_path={schema_name},public"
    url = URL.create(
        "postgresql+psycopg2",
        username=str(cfg["user"]),
        password=str(cfg["password"]),
        host=str(cfg["host"]),
        port=int(cfg["port"]),
        database=str(cfg["dbname"]),
        query=query,
    )
    return url.render_as_string(hide_password=False)


def _create_engine() -> Engine:
    """Create the engine from env settings and best-effort create the schema."""
    cfg = _db_config()
    pool_size = int(os.getenv("DB_POOL_SIZE", os.getenv("DB_POOL_MIN", "5")))
    max_overflow = int(os.getenv("DB_POOL_MAX", "10"))
    pool_timeout = int(os.getenv("DB_POOL_TIMEOUT", "30"))
    engine = create_engine(
        get_database_url(cfg),
        pool_size=pool_size,
        max_overflow=max_overflow,
        pool_timeout=pool_timeout,
        pool_pre_ping=True,
        future=True,
    )
    schema_name = cfg.get("schema")
    if schema_name:
        try:
            with engine.begin() as conn:
                conn.execute(schema.CreateSchema(schema_name, if_not_exists=True))
        except Exception as exc:
            # Best-effort: permissions may be limited in some environments,
            # but record the failure instead of hiding it completely.
            logger.warning(
                "Could not ensure schema %r exists (continuing): %s", schema_name, exc
            )
    return engine


def get_engine() -> Engine:
    """Return the module-wide engine, creating it on first use.

    The lock plus re-check ensures concurrent first callers create only one
    engine; later calls skip the lock entirely.
    """
    global _ENGINE
    if _ENGINE is None:
        with _ENGINE_LOCK:
            if _ENGINE is None:
                _ENGINE = _create_engine()
    return _ENGINE


# Bound at import time, so importing this module creates the engine (and
# attempts schema creation).
SessionLocal = sessionmaker(
    autocommit=False,
    autoflush=False,
    expire_on_commit=False,
    bind=get_engine(),
)


def _get_connection():
    """Check out a raw DBAPI connection from the engine's pool."""
    return get_engine().raw_connection()


def _put_connection(conn, close=False):
    """Close/release *conn*; when *close* is true, errors while closing are ignored."""
    try:
        conn.close()
    except Exception:
        if not close:
            raise


def _retry_settings(retries: int | None, delay: float | None) -> tuple[int, float]:
    """Resolve the attempt count and base backoff, falling back to env vars.

    The count is clamped to at least one so every call makes an attempt.
    """
    attempts = retries if retries is not None else int(os.getenv("DB_RETRY_COUNT", "3"))
    backoff = delay if delay is not None else float(os.getenv("DB_RETRY_DELAY", "0.2"))
    return max(1, attempts), backoff


def _acquire_connection(attempts: int, backoff: float):
    """Check out a connection, retrying transient failures with exponential backoff.

    Raises the last transient error once all attempts fail; any other error is
    raised immediately after releasing a partially set-up connection.
    """
    last_error = None
    attempts = max(1, attempts)
    for attempt in range(attempts):
        conn = None
        try:
            conn = _get_connection()
            # NOTE(review): assigned on the pool's connection wrapper; psycopg2
            # defaults to autocommit=False anyway — confirm this reaches the driver.
            conn.autocommit = False
            return conn
        except _TRANSIENT_ERRORS as exc:
            last_error = exc
            if conn is not None:
                # The connection may already be broken; release it best-effort.
                _put_connection(conn, close=True)
        except BaseException:
            if conn is not None:
                _put_connection(conn, close=True)
            raise
        if attempt + 1 < attempts:
            # No point sleeping after the final attempt.
            time.sleep(backoff * (2 ** attempt))
    raise last_error


@contextmanager
def db_connection(retries: int | None = None, delay: float | None = None):
    """Yield a pooled raw connection wrapped in ``_ConnectionProxy``.

    Acquiring the connection is retried up to *retries* times (default
    ``DB_RETRY_COUNT``, minimum one) on transient errors, sleeping
    ``delay * 2**attempt`` seconds between attempts (default
    ``DB_RETRY_DELAY``). Exceptions raised by the caller's ``with`` body are
    not retried: a generator-based context manager may yield only once, so
    they propagate unchanged. The connection is released via
    ``_put_connection`` when the block exits.

    Raises:
        The last transient error if every acquisition attempt fails.
    """
    attempts, backoff = _retry_settings(retries, delay)
    conn = _acquire_connection(attempts, backoff)
    try:
        yield _ConnectionProxy(conn)
    finally:
        # psycopg2 reports 0 for an open connection; anything else means it is
        # already closed or broken, so errors while releasing it are ignored.
        _put_connection(conn, close=getattr(conn, "closed", 0) != 0)


def run_with_retry(operation, retries: int | None = None, delay: float | None = None):
    """Run ``operation(cursor, conn)`` in its own transaction, retrying transient errors.

    Each attempt checks out a connection, calls *operation* with a fresh
    cursor, and commits. Any exception rolls the attempt back. Transient
    errors (including failures to connect) are retried with exponential
    backoff; all other exceptions are re-raised immediately.

    Returns:
        Whatever *operation* returns on the first successful attempt.

    Raises:
        The last transient error once all attempts are exhausted.
    """
    attempts, backoff = _retry_settings(retries, delay)
    last_error = None
    for attempt in range(attempts):
        try:
            with db_connection(retries=1) as conn:
                try:
                    with conn.cursor() as cur:
                        result = operation(cur, conn)
                    conn.commit()
                except Exception:
                    # Guarded so a dead connection cannot replace the real error.
                    _safe_rollback(conn)
                    raise
                return result
        except _TRANSIENT_ERRORS as exc:
            last_error = exc
            if attempt + 1 < attempts:
                time.sleep(backoff * (2 ** attempt))
    raise last_error


@contextmanager
def db_transaction():
    """Yield a cursor inside a transaction that commits when the block succeeds.

    Any exception from the ``with`` body or the commit rolls the transaction
    back and is re-raised; a failing rollback is logged rather than masking
    the original error.
    """
    with db_connection() as conn:
        try:
            with conn.cursor() as cur:
                yield cur
            conn.commit()
        except Exception:
            _safe_rollback(conn)
            raise


def get_db() -> Generator:
    """Yield a new ORM session from ``SessionLocal`` and always close it afterwards."""
    db = SessionLocal()
    try:
        yield db
    finally:
        db.close()


def health_check() -> bool:
    """Return True if ``SELECT 1`` succeeds on the engine, False on any error."""
    try:
        with get_engine().connect() as conn:
            conn.execute(text("SELECT 1"))
        return True
    except Exception:
        return False