# 2026-02-01 13:57:30 +00:00
#
# 211 lines
# 6.1 KiB
# Python

import os
import threading
import time
from contextlib import contextmanager
from typing import Generator
from sqlalchemy import create_engine, schema, text
from sqlalchemy.engine import Engine, URL
from sqlalchemy.exc import InterfaceError as SAInterfaceError
from sqlalchemy.exc import OperationalError as SAOperationalError
from sqlalchemy.orm import declarative_base, sessionmaker
from psycopg2 import OperationalError as PGOperationalError
from psycopg2 import InterfaceError as PGInterfaceError
# Declarative base class for the ORM models.
Base = declarative_base()

# Process-wide engine, created lazily by get_engine(); the lock serialises that
# first creation so concurrent callers do not build two engines.
_ENGINE_LOCK = threading.Lock()
_ENGINE: Engine | None = None
class _ConnectionProxy:
    """Thin wrapper around a DBAPI connection that adds transactional ``with`` support.

    Any attribute not defined on the proxy is looked up on the wrapped connection.
    When used as a context manager, the proxy commits if the block exits cleanly
    and rolls back if it raises. Exceptions are never suppressed.
    """

    def __init__(self, conn):
        self._conn = conn

    def __getattr__(self, name):
        # Only reached for names missing on the proxy, so forward them.
        return getattr(self._conn, name)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        if exc_type is not None:
            # Error path: roll back, but never let a failing rollback hide
            # the exception that is already propagating.
            try:
                self._conn.rollback()
            except Exception:
                pass
            return False
        # Success path: commit; if that fails, undo and surface the commit error.
        try:
            self._conn.commit()
        except Exception:
            self._conn.rollback()
            raise
        return False
def _db_config() -> dict[str, str | int]:
url = os.getenv("DATABASE_URL")
if url:
return {"url": url}
return {
"host": os.getenv("DB_HOST") or os.getenv("PGHOST") or "localhost",
"port": int(os.getenv("DB_PORT") or os.getenv("PGPORT") or "5432"),
"dbname": os.getenv("DB_NAME") or os.getenv("PGDATABASE") or "trading_db",
"user": os.getenv("DB_USER") or os.getenv("PGUSER") or "trader",
"password": os.getenv("DB_PASSWORD") or os.getenv("PGPASSWORD") or "traderpass",
"connect_timeout": int(os.getenv("DB_CONNECT_TIMEOUT", "5")),
"schema": os.getenv("DB_SCHEMA") or os.getenv("PGSCHEMA") or "quant_app",
}
def get_database_url(cfg: dict[str, str | int] | None = None) -> str:
cfg = cfg or _db_config()
if "url" in cfg:
return str(cfg["url"])
schema_name = cfg.get("schema")
query = {"connect_timeout": str(cfg["connect_timeout"])}
if schema_name:
query["options"] = f"-csearch_path={schema_name},public"
url = URL.create(
"postgresql+psycopg2",
username=str(cfg["user"]),
password=str(cfg["password"]),
host=str(cfg["host"]),
port=int(cfg["port"]),
database=str(cfg["dbname"]),
query=query,
)
return url.render_as_string(hide_password=False)
def _create_engine() -> Engine:
    """Build the SQLAlchemy engine, then best-effort create the configured schema.

    Pool settings come from the environment. Empty variables count as unset.

    - ``DB_POOL_SIZE`` (else ``DB_POOL_MIN``, else 5) sets ``pool_size``.
    - ``DB_POOL_MAX`` (default 10) sets ``max_overflow``.
    - ``DB_POOL_TIMEOUT`` (default 30) sets ``pool_timeout``.

    Connections are pre-pinged on checkout. If a schema is configured, a
    ``CREATE SCHEMA IF NOT EXISTS`` is attempted. A failure there is logged as a
    warning and does not prevent the engine from being returned.
    """
    import logging

    cfg = _db_config()
    pool_size = int(os.getenv("DB_POOL_SIZE") or os.getenv("DB_POOL_MIN") or "5")
    # NOTE(review): DB_POOL_MAX feeds max_overflow, i.e. connections *beyond*
    # pool_size, so the effective ceiling is pool_size + DB_POOL_MAX. Confirm intended.
    max_overflow = int(os.getenv("DB_POOL_MAX") or "10")
    pool_timeout = int(os.getenv("DB_POOL_TIMEOUT") or "30")
    engine = create_engine(
        get_database_url(cfg),
        pool_size=pool_size,
        max_overflow=max_overflow,
        pool_timeout=pool_timeout,
        pool_pre_ping=True,
        future=True,
    )
    schema_name = cfg.get("schema")
    if schema_name:
        try:
            with engine.begin() as conn:
                conn.execute(schema.CreateSchema(schema_name, if_not_exists=True))
        except Exception:
            # Best-effort: permissions may be limited in some environments. Keep
            # going, but leave a trace instead of failing silently.
            logging.getLogger(__name__).warning(
                "Could not ensure schema %r exists; continuing anyway",
                schema_name,
                exc_info=True,
            )
    return engine
def get_engine() -> Engine:
    """Return the process-wide engine, creating it on first call.

    The unlocked check keeps the common path lock-free. The re-check under
    ``_ENGINE_LOCK`` ensures only one thread runs ``_create_engine``.
    """
    global _ENGINE
    if _ENGINE is not None:
        return _ENGINE
    with _ENGINE_LOCK:
        if _ENGINE is None:
            _ENGINE = _create_engine()
    return _ENGINE
# ORM session factory bound to the shared engine. Sessions do not autoflush and
# keep loaded attributes after commit. NOTE: get_engine() runs at import time
# here, so the engine is built (and schema creation may be attempted) on import.
SessionLocal = sessionmaker(
    bind=get_engine(),
    autoflush=False,
    autocommit=False,
    expire_on_commit=False,
)
def _get_connection():
    """Check out a raw DBAPI connection from the shared engine's pool."""
    engine = get_engine()
    return engine.raw_connection()
def _put_connection(conn, close=False):
    """Call ``conn.close()`` to release the connection.

    Args:
        conn: Connection to release.
        close: When true, any exception from ``conn.close()`` is suppressed.
            When false, it propagates to the caller.
    """
    try:
        conn.close()
    except Exception:
        if close:
            return
        raise
@contextmanager
def db_connection(retries: int | None = None, delay: float | None = None):
    """Yield a pooled raw DBAPI connection wrapped in ``_ConnectionProxy``.

    Only *acquiring* the connection is retried. That happens on SQLAlchemy or
    psycopg2 ``OperationalError``/``InterfaceError``:

    - Up to ``retries`` attempts (default ``DB_RETRY_COUNT`` or 3, minimum 1).
    - ``delay * 2**attempt`` seconds between attempts (default
      ``DB_RETRY_DELAY`` or 0.2). There is no sleep after the final attempt.

    Errors raised inside the caller's ``with`` body propagate unchanged. A
    generator-based context manager cannot re-run that body, and catching the
    error there would turn it into ``RuntimeError("generator didn't stop after
    throw()")``.

    No commit is issued here. The caller commits, or uses the proxy as a
    context manager. The connection is released via ``_put_connection`` on exit.

    Raises:
        The last retriable error, if every acquisition attempt failed.
    """
    attempts = retries if retries is not None else int(os.getenv("DB_RETRY_COUNT", "3"))
    backoff = delay if delay is not None else float(os.getenv("DB_RETRY_DELAY", "0.2"))
    # A non-positive count used to end in "generator didn't yield"; always try once.
    attempts = max(1, attempts)
    retriable = (SAOperationalError, SAInterfaceError, PGOperationalError, PGInterfaceError)

    conn = None
    last_error = None
    for attempt in range(attempts):
        try:
            conn = _get_connection()
            conn.autocommit = False
            break
        except retriable as exc:
            last_error = exc
            if conn is not None:
                # Discard the half-initialised connection. A failure while closing
                # it must not hide the error being retried.
                _put_connection(conn, close=True)
                conn = None
        except BaseException:
            if conn is not None:
                _put_connection(conn, close=True)
            raise
        if attempt + 1 < attempts:
            time.sleep(backoff * (2 ** attempt))

    if conn is None:
        raise last_error

    try:
        yield _ConnectionProxy(conn)
    finally:
        _put_connection(conn, close=conn.closed != 0)
def run_with_retry(operation, retries: int | None = None, delay: float | None = None):
    """Run ``operation(cursor, conn)`` in a transaction, retrying transient errors.

    Each attempt does the following:

    1. Checks out a fresh connection.
    2. Calls ``operation`` with a cursor and the connection proxy.
    3. Commits and returns the operation's result.

    A SQLAlchemy or psycopg2 ``OperationalError``/``InterfaceError`` restarts the
    whole attempt. This applies whether the error comes from connecting, from
    ``operation`` or from commit. Limits and timing:

    - Up to ``retries`` attempts (default ``DB_RETRY_COUNT`` or 3, minimum 1).
    - ``delay * 2**attempt`` seconds between attempts (default
      ``DB_RETRY_DELAY`` or 0.2).

    Errors after the connection is obtained are rolled back where possible. Any
    non-retriable exception is re-raised immediately.

    NOTE: ``operation`` may run more than once, so it should be safe to repeat.

    Raises:
        The last retriable error once all attempts are exhausted.
    """
    attempts = retries if retries is not None else int(os.getenv("DB_RETRY_COUNT", "3"))
    backoff = delay if delay is not None else float(os.getenv("DB_RETRY_DELAY", "0.2"))
    # A non-positive count used to return None without running the operation.
    attempts = max(1, attempts)
    last_error = None
    for attempt in range(attempts):
        try:
            # This loop owns the retry and backoff policy. That includes connection
            # acquisition, which previously escaped the retry entirely. So
            # db_connection makes a single attempt with no sleep of its own.
            with db_connection(retries=1, delay=0.0) as conn:
                try:
                    with conn.cursor() as cur:
                        result = operation(cur, conn)
                    conn.commit()
                except Exception:
                    try:
                        conn.rollback()
                    except Exception:
                        # The connection may already be dead. Keep the original error,
                        # which decides whether we retry.
                        pass
                    raise
                return result
        except (SAOperationalError, SAInterfaceError, PGOperationalError, PGInterfaceError) as exc:
            last_error = exc
            # The connection has been released by now, so we do not sleep while holding it.
            if attempt + 1 < attempts:
                time.sleep(backoff * (2 ** attempt))
    raise last_error
@contextmanager
def db_transaction():
    """Yield a cursor inside a transaction on a pooled connection.

    The transaction is committed if the ``with`` body completes. It is rolled
    back and the exception re-raised if the body (or the commit) raises. The
    connection comes from ``db_connection()`` with its default retry settings.
    """
    with db_connection() as conn:
        try:
            with conn.cursor() as cur:
                yield cur
            conn.commit()
        except Exception:
            try:
                conn.rollback()
            except Exception:
                # The connection may already be unusable (e.g. dropped). Re-raise the
                # original error rather than replacing it with the rollback failure.
                pass
            raise
def get_db() -> Generator:
    """Yield a new ORM session from ``SessionLocal``.

    The session is closed when the generator is finished or closed.
    """
    session = SessionLocal()
    try:
        yield session
    finally:
        session.close()
def health_check() -> bool:
    """Return True if ``SELECT 1`` runs successfully on the shared engine.

    Any exception, including failure to build the engine, yields False.
    """
    try:
        engine = get_engine()
        with engine.connect() as connection:
            connection.execute(text("SELECT 1"))
    except Exception:
        return False
    return True