211 lines
6.1 KiB
Python
211 lines
6.1 KiB
Python
import os
|
|
import threading
|
|
import time
|
|
from contextlib import contextmanager
|
|
from typing import Generator
|
|
|
|
from sqlalchemy import create_engine, schema, text
|
|
from sqlalchemy.engine import Engine, URL
|
|
from sqlalchemy.exc import InterfaceError as SAInterfaceError
|
|
from sqlalchemy.exc import OperationalError as SAOperationalError
|
|
from sqlalchemy.orm import declarative_base, sessionmaker
|
|
from psycopg2 import OperationalError as PGOperationalError
|
|
from psycopg2 import InterfaceError as PGInterfaceError
|
|
|
|
# Declarative base class for ORM models (presumably subclassed by model
# modules elsewhere in the project — confirm).
Base = declarative_base()

# Process-wide engine singleton; stays None until get_engine() creates it.
_ENGINE: Engine | None = None
# Guards first-time creation of _ENGINE in get_engine() so concurrent
# callers do not build more than one engine.
_ENGINE_LOCK = threading.Lock()
|
|
|
|
|
|
class _ConnectionProxy:
|
|
def __init__(self, conn):
|
|
self._conn = conn
|
|
|
|
def __getattr__(self, name):
|
|
return getattr(self._conn, name)
|
|
|
|
def __enter__(self):
|
|
return self
|
|
|
|
def __exit__(self, exc_type, exc, tb):
|
|
if exc_type is None:
|
|
try:
|
|
self._conn.commit()
|
|
except Exception:
|
|
self._conn.rollback()
|
|
raise
|
|
else:
|
|
try:
|
|
self._conn.rollback()
|
|
except Exception:
|
|
pass
|
|
return False
|
|
|
|
|
|
def _db_config() -> dict[str, str | int]:
|
|
url = os.getenv("DATABASE_URL")
|
|
if url:
|
|
return {"url": url}
|
|
|
|
return {
|
|
"host": os.getenv("DB_HOST") or os.getenv("PGHOST") or "localhost",
|
|
"port": int(os.getenv("DB_PORT") or os.getenv("PGPORT") or "5432"),
|
|
"dbname": os.getenv("DB_NAME") or os.getenv("PGDATABASE") or "trading_db",
|
|
"user": os.getenv("DB_USER") or os.getenv("PGUSER") or "trader",
|
|
"password": os.getenv("DB_PASSWORD") or os.getenv("PGPASSWORD") or "traderpass",
|
|
"connect_timeout": int(os.getenv("DB_CONNECT_TIMEOUT", "5")),
|
|
"schema": os.getenv("DB_SCHEMA") or os.getenv("PGSCHEMA") or "quant_app",
|
|
}
|
|
|
|
|
|
def get_database_url(cfg: dict[str, str | int] | None = None) -> str:
|
|
cfg = cfg or _db_config()
|
|
if "url" in cfg:
|
|
return str(cfg["url"])
|
|
schema_name = cfg.get("schema")
|
|
query = {"connect_timeout": str(cfg["connect_timeout"])}
|
|
if schema_name:
|
|
query["options"] = f"-csearch_path={schema_name},public"
|
|
url = URL.create(
|
|
"postgresql+psycopg2",
|
|
username=str(cfg["user"]),
|
|
password=str(cfg["password"]),
|
|
host=str(cfg["host"]),
|
|
port=int(cfg["port"]),
|
|
database=str(cfg["dbname"]),
|
|
query=query,
|
|
)
|
|
return url.render_as_string(hide_password=False)
|
|
|
|
|
|
def _create_engine() -> Engine:
    """Create the SQLAlchemy engine from environment configuration.

    Pool sizing comes from ``DB_POOL_SIZE`` (falling back to ``DB_POOL_MIN``,
    default 5), ``DB_POOL_MAX`` (passed as ``max_overflow``, default 10) and
    ``DB_POOL_TIMEOUT`` (seconds, default 30). ``pool_pre_ping`` is enabled.

    When a schema is configured, a best-effort ``CREATE SCHEMA IF NOT EXISTS`` is
    issued. Any failure (limited privileges, unreachable server, unsupported
    option, ...) is logged as a warning and the engine is still returned.
    """
    import logging

    cfg = _db_config()
    pool_size = int(os.getenv("DB_POOL_SIZE", os.getenv("DB_POOL_MIN", "5")))
    # NOTE(review): DB_POOL_MAX is passed as max_overflow, i.e. connections *in
    # addition to* pool_size, so the effective ceiling is pool_size + DB_POOL_MAX.
    # Confirm this is intended if the variable was meant as a total cap.
    max_overflow = int(os.getenv("DB_POOL_MAX", "10"))
    pool_timeout = int(os.getenv("DB_POOL_TIMEOUT", "30"))

    engine = create_engine(
        get_database_url(cfg),
        pool_size=pool_size,
        max_overflow=max_overflow,
        pool_timeout=pool_timeout,
        pool_pre_ping=True,
        future=True,
    )

    schema_name = cfg.get("schema")
    if schema_name:
        try:
            with engine.begin() as conn:
                conn.execute(schema.CreateSchema(schema_name, if_not_exists=True))
        except Exception:
            # Best-effort: permissions might be limited or the server may not be
            # reachable yet. Log rather than silently discard, so a missing
            # schema can be diagnosed later.
            logging.getLogger(__name__).warning(
                "Could not ensure database schema %r exists", schema_name, exc_info=True
            )

    return engine
|
|
|
|
|
|
def get_engine() -> Engine:
    """Return the shared engine, building it on the first call.

    The unlocked fast path returns an already-created engine; otherwise the
    creation is done under ``_ENGINE_LOCK`` with a re-check, so concurrent first
    callers end up sharing a single engine.
    """
    global _ENGINE
    if _ENGINE is not None:
        return _ENGINE
    with _ENGINE_LOCK:
        if _ENGINE is None:
            _ENGINE = _create_engine()
    return _ENGINE
|
|
|
|
|
|
# ORM session factory bound to the shared engine. Because get_engine() is called
# here, the engine (including its best-effort schema creation attempt) is built
# as soon as this module is imported.
SessionLocal = sessionmaker(
    bind=get_engine(),
    autoflush=False,
    autocommit=False,
    expire_on_commit=False,
)
|
|
|
|
|
|
def _get_connection():
    """Check out a raw DB-API connection from the shared engine's pool."""
    engine = get_engine()
    pooled = engine.raw_connection()
    return pooled
|
|
|
|
|
|
def _put_connection(conn, close=False):
|
|
try:
|
|
conn.close()
|
|
except Exception:
|
|
if not close:
|
|
raise
|
|
|
|
|
|
@contextmanager
def db_connection(retries: int | None = None, delay: float | None = None):
    """Check out a pooled DB-API connection and yield it wrapped in ``_ConnectionProxy``.

    Only *acquiring* the connection is retried. Connection-level errors
    (SQLAlchemy/psycopg2 ``OperationalError``/``InterfaceError``) during checkout
    trigger up to ``retries`` attempts in total, sleeping ``delay * 2**attempt``
    seconds between attempts. Exceptions raised by the caller's ``with`` body are
    never retried; they propagate unchanged after the connection is released.
    Retrying around the ``yield`` would make the generator yield twice, which
    ``contextlib.contextmanager`` turns into a ``RuntimeError``.

    No commit or rollback is issued here. Callers commit explicitly or use the
    yielded proxy as its own context manager.

    Args:
        retries: Total checkout attempts; defaults to the ``DB_RETRY_COUNT``
            environment variable (3). Values below 1 are treated as 1.
        delay: Base backoff in seconds; defaults to ``DB_RETRY_DELAY`` (0.2).

    Raises:
        The last connection error if every checkout attempt failed.
    """
    attempts = retries if retries is not None else int(os.getenv("DB_RETRY_COUNT", "3"))
    backoff = delay if delay is not None else float(os.getenv("DB_RETRY_DELAY", "0.2"))
    attempts = max(attempts, 1)

    conn = None
    for attempt in range(attempts):
        try:
            conn = _get_connection()
            # NOTE(review): on a pooled proxy this assignment may not reach the
            # driver connection; psycopg2 defaults to autocommit=False anyway.
            conn.autocommit = False
            break
        except (SAOperationalError, SAInterfaceError, PGOperationalError, PGInterfaceError):
            if conn is not None:
                # Checkout succeeded but the connection is unusable; release it
                # without letting a close error abort the retry loop.
                _put_connection(conn, close=True)
                conn = None
            if attempt == attempts - 1:
                raise
            time.sleep(backoff * (2 ** attempt))

    try:
        yield _ConnectionProxy(conn)
    finally:
        # A connection the driver reports as closed may fail to close cleanly;
        # suppress that error so it cannot mask the caller's exception.
        _put_connection(conn, close=getattr(conn, "closed", 0) != 0)
|
|
|
|
|
|
def run_with_retry(operation, retries: int | None = None, delay: float | None = None):
    """Run ``operation(cursor, conn)`` in a transaction, retrying on connection errors.

    Each attempt checks out a fresh connection, calls ``operation`` with a cursor
    and the connection, and then commits. A connection-level error
    (SQLAlchemy/psycopg2 ``OperationalError``/``InterfaceError``) triggers a new
    attempt after ``delay * 2**attempt`` seconds. This applies whether the error
    is raised while connecting, inside ``operation``, or during commit. Any other
    exception is rolled back and re-raised immediately.

    A failed commit may still have been applied server-side, so ``operation``
    should be safe to run more than once.

    Args:
        operation: Callable ``(cursor, conn) -> result``.
        retries: Total attempts; defaults to ``DB_RETRY_COUNT`` (3). Values below
            1 are treated as 1.
        delay: Base backoff in seconds; defaults to ``DB_RETRY_DELAY`` (0.2).

    Returns:
        Whatever ``operation`` returned on the successful attempt.

    Raises:
        The last connection error once all attempts are exhausted, or the first
        non-retryable exception raised by ``operation``.
    """
    attempts = retries if retries is not None else int(os.getenv("DB_RETRY_COUNT", "3"))
    backoff = delay if delay is not None else float(os.getenv("DB_RETRY_DELAY", "0.2"))
    attempts = max(attempts, 1)
    retryable = (SAOperationalError, SAInterfaceError, PGOperationalError, PGInterfaceError)

    for attempt in range(attempts):
        try:
            # Acquisition is inside the try so connect failures are retried too.
            with db_connection(retries=1) as conn:
                try:
                    with conn.cursor() as cur:
                        result = operation(cur, conn)
                    conn.commit()
                    return result
                except Exception:
                    # Best-effort rollback: on a dead connection rollback itself
                    # fails, and that must not replace the original error.
                    try:
                        conn.rollback()
                    except Exception:
                        pass
                    raise
        except retryable:
            if attempt == attempts - 1:
                raise
            time.sleep(backoff * (2 ** attempt))
|
|
|
|
|
|
@contextmanager
def db_transaction():
    """Yield a cursor whose work runs in one transaction on a pooled connection.

    The transaction is committed when the ``with`` body completes normally. It is
    rolled back if the body (or the commit) raises, and that exception is then
    re-raised unchanged.
    """
    with db_connection() as conn:
        try:
            with conn.cursor() as cur:
                yield cur
            conn.commit()
        except Exception:
            # Best-effort rollback: if the connection itself is broken the
            # rollback can fail too, and that failure must not hide the
            # original error.
            try:
                conn.rollback()
            except Exception:
                pass
            raise
|
|
|
|
|
|
def get_db() -> Generator:
    """Yield a fresh ``SessionLocal`` session and close it when the generator ends.

    The ``finally`` clause makes ``close()`` run whether the generator is
    exhausted, closed early, or has an exception thrown into it.
    """
    session = SessionLocal()
    try:
        yield session
    finally:
        session.close()
|
|
|
|
|
|
def health_check() -> bool:
    """Report whether ``SELECT 1`` can be executed on a connection from the engine.

    Any exception (engine creation, connect, execute or close) yields ``False``.
    """
    try:
        with get_engine().connect() as connection:
            connection.execute(text("SELECT 1"))
    except Exception:
        return False
    return True
|