2026-02-01 20:34:57 +00:00

325 lines
9.0 KiB
Python

import os
import threading
import time
from contextlib import contextmanager
from datetime import datetime, timezone
from contextvars import ContextVar
import psycopg2
from psycopg2 import pool
from psycopg2 import OperationalError, InterfaceError
from psycopg2.extras import Json
_POOL = None
_POOL_LOCK = threading.Lock()
_DEFAULT_USER_ID = None
_DEFAULT_LOCK = threading.Lock()
_USER_ID = ContextVar("engine_user_id", default=None)
_RUN_ID = ContextVar("engine_run_id", default=None)
def _db_config():
    """Build psycopg2 connection keyword arguments from the environment.

    When ``DATABASE_URL`` is non-empty it wins outright and only ``{"dsn": url}``
    is returned. Otherwise each setting takes the first non-empty value of
    its ``DB_*`` variable, then its libpq ``PG*`` variable, then a built-in
    default. The schema is put first on the ``search_path`` via ``options``.
    """
    url = os.getenv("DATABASE_URL")
    if url:
        return {"dsn": url}

    def first_env(*names, fallback):
        # Mirrors `getenv(a) or getenv(b) or fallback`: empty strings are skipped.
        for name in names:
            value = os.getenv(name)
            if value:
                return value
        return fallback

    schema = first_env("DB_SCHEMA", "PGSCHEMA", fallback="quant_app")
    host = first_env("DB_HOST", "PGHOST", fallback="localhost")
    port = int(first_env("DB_PORT", "PGPORT", fallback="5432"))
    dbname = first_env("DB_NAME", "PGDATABASE", fallback="trading_db")
    user = first_env("DB_USER", "PGUSER", fallback="trader")
    password = first_env("DB_PASSWORD", "PGPASSWORD", fallback="traderpass")
    timeout = int(os.getenv("DB_CONNECT_TIMEOUT", "5"))
    search_path = f"-csearch_path={schema},public" if schema else None
    return {
        "host": host,
        "port": port,
        "dbname": dbname,
        "user": user,
        "password": password,
        "connect_timeout": timeout,
        "options": search_path,
    }
def _init_pool():
    """Create a psycopg2 ThreadedConnectionPool from ``_db_config()``.

    Pool bounds come from ``DB_POOL_MIN`` (default 1) and ``DB_POOL_MAX``
    (default 10). Config entries whose value is None are dropped before
    they are passed to the driver.
    """
    settings = {key: value for key, value in _db_config().items() if value is not None}
    low = int(os.getenv("DB_POOL_MIN", "1"))
    high = int(os.getenv("DB_POOL_MAX", "10"))
    # A DSN is passed on its own; otherwise every discrete setting is forwarded.
    connect_kwargs = {"dsn": settings["dsn"]} if "dsn" in settings else settings
    return pool.ThreadedConnectionPool(low, high, **connect_kwargs)
def get_pool():
    """Return the shared connection pool, building it on first use.

    The unlocked check keeps the common path cheap. The second check under
    ``_POOL_LOCK`` makes sure only one thread ever calls ``_init_pool``.
    """
    global _POOL
    if _POOL is not None:
        return _POOL
    with _POOL_LOCK:
        if _POOL is None:
            _POOL = _init_pool()
    return _POOL
def _get_connection():
    """Check a raw connection out of the shared pool (no retry, no setup)."""
    shared_pool = get_pool()
    return shared_pool.getconn()
def _put_connection(conn, close=False):
    """Give ``conn`` back to the pool, best effort.

    If the pool rejects it (or cannot be reached), the connection is closed
    directly instead. Failures of that fallback close are ignored on purpose.
    """
    try:
        get_pool().putconn(conn, close=close)
        return
    except Exception:
        pass  # fall through to closing the connection ourselves
    try:
        conn.close()
    except Exception:
        pass
def _acquire_connection(attempts, backoff):
    """Check a connection out of the pool, retrying transient driver errors.

    Makes ``max(1, attempts)`` tries, sleeping ``backoff * 2**attempt``
    seconds between tries (never after the last one), and sets
    ``autocommit = False`` on the connection it returns. A connection that
    fails while being prepared is closed rather than pooled.

    Raises:
        OperationalError / InterfaceError: the last error once every try failed.
    """
    attempts = max(1, attempts)
    last_error = None
    for attempt in range(attempts):
        conn = None
        try:
            conn = _get_connection()
            conn.autocommit = False
            return conn
        except (OperationalError, InterfaceError) as exc:
            last_error = exc
            if conn is not None:
                _put_connection(conn, close=True)
        except BaseException:
            # Unexpected failure: don't leak a checked-out connection.
            if conn is not None:
                _put_connection(conn, close=conn.closed != 0)
            raise
        if attempt + 1 < attempts:
            time.sleep(backoff * (2 ** attempt))
    raise last_error


@contextmanager
def db_connection(retries: int | None = None, delay: float | None = None):
    """Yield a pooled connection with ``autocommit`` turned off.

    Acquiring the connection is retried up to ``retries`` times (default
    ``DB_RETRY_COUNT`` or 3) on OperationalError/InterfaceError, with
    exponential backoff starting at ``delay`` seconds (default
    ``DB_RETRY_DELAY`` or 0.2). Values below 1 still make one attempt.

    Exceptions raised inside the ``with`` body are not retried. A generator
    context manager may yield only once, so such errors reach the caller
    unchanged; use ``run_with_retry`` to re-run a whole operation. On exit
    the connection goes back to the pool, and it is closed if ``conn.closed``
    is non-zero.

    Raises:
        OperationalError / InterfaceError: the last acquisition error once
            all attempts are exhausted.
    """
    attempts = retries if retries is not None else int(os.getenv("DB_RETRY_COUNT", "3"))
    backoff = delay if delay is not None else float(os.getenv("DB_RETRY_DELAY", "0.2"))
    conn = _acquire_connection(attempts, backoff)
    try:
        yield conn
    finally:
        _put_connection(conn, close=conn.closed != 0)
def run_with_retry(operation, retries: int | None = None, delay: float | None = None):
    """Run ``operation(cur, conn)`` in its own transaction, retrying transient errors.

    Each attempt checks out a fresh connection, calls ``operation`` with an
    open cursor and the connection, then commits and returns its result.
    OperationalError/InterfaceError raised while acquiring the connection,
    running the operation, or committing trigger a retry. Retries use
    exponential backoff (``delay * 2**attempt`` seconds, no sleep after the
    final attempt). Any other exception rolls back and propagates at once.

    ``retries`` defaults to ``DB_RETRY_COUNT`` (3) and ``delay`` to
    ``DB_RETRY_DELAY`` (0.2). Values of ``retries`` below 1 still make one
    attempt. A retried ``operation`` runs again from scratch, so it must be
    safe to repeat.

    Raises:
        OperationalError / InterfaceError: the last error when every attempt failed.
    """
    attempts = retries if retries is not None else int(os.getenv("DB_RETRY_COUNT", "3"))
    attempts = max(1, attempts)
    backoff = delay if delay is not None else float(os.getenv("DB_RETRY_DELAY", "0.2"))
    last_error = None
    for attempt in range(attempts):
        try:
            # Acquisition happens inside the try so connection failures are retried too.
            with db_connection(retries=1) as conn:
                try:
                    with conn.cursor() as cur:
                        result = operation(cur, conn)
                    conn.commit()
                except Exception:
                    try:
                        conn.rollback()
                    except (OperationalError, InterfaceError):
                        # Connection already broken; keep the original error.
                        pass
                    raise
                return result
        except (OperationalError, InterfaceError) as exc:
            last_error = exc
            if attempt + 1 < attempts:
                time.sleep(backoff * (2 ** attempt))
    raise last_error
@contextmanager
def db_transaction():
    """Yield a cursor whose work is committed when the ``with`` body succeeds.

    Any exception from the body or from the commit rolls the transaction back
    and is re-raised. If the rollback itself fails because the connection is
    already broken, that secondary OperationalError/InterfaceError is
    suppressed, so the caller sees the original exception.
    """
    with db_connection() as conn:
        try:
            with conn.cursor() as cur:
                yield cur
            conn.commit()
        except Exception:
            try:
                conn.rollback()
            except (OperationalError, InterfaceError):
                pass
            raise
def _utc_now():
    """Return the current time as a timezone-aware UTC ``datetime``.

    Uses ``datetime.now(timezone.utc)`` instead of ``datetime.utcnow()``,
    which is deprecated since Python 3.12 and returns a naive value.
    """
    return datetime.now(timezone.utc)
def set_context(user_id: str | None, run_id: str | None):
token_user = _USER_ID.set(user_id)
token_run = _RUN_ID.set(run_id)
return token_user, token_run
def reset_context(token_user, token_run):
    """Restore the ids that were bound before the matching ``set_context``."""
    for var, token in ((_USER_ID, token_user), (_RUN_ID, token_run)):
        var.reset(token)
@contextmanager
def engine_context(user_id: str, run_id: str):
    """Scope ``user_id``/``run_id`` to the body of a ``with`` block.

    The previous values are restored on exit, even if the body raises.
    """
    tokens = set_context(user_id, run_id)
    try:
        yield
    finally:
        reset_context(*tokens)
def _resolve_context(user_id: str | None = None, run_id: str | None = None):
ctx_user = user_id or _USER_ID.get()
ctx_run = run_id or _RUN_ID.get()
if ctx_user and ctx_run:
return ctx_user, ctx_run
env_user = os.getenv("ENGINE_USER_ID")
env_run = os.getenv("ENGINE_RUN_ID")
if not ctx_user and env_user:
ctx_user = env_user
if not ctx_run and env_run:
ctx_run = env_run
if ctx_user and ctx_run:
return ctx_user, ctx_run
if not ctx_user:
ctx_user = get_default_user_id()
if ctx_user and not ctx_run:
ctx_run = get_active_run_id(ctx_user)
if not ctx_user or not ctx_run:
raise ValueError("engine context missing user_id/run_id")
return ctx_user, ctx_run
def get_context(user_id: str | None = None, run_id: str | None = None):
    """Public entry point for ``_resolve_context``; returns ``(user_id, run_id)``."""
    resolved = _resolve_context(user_id=user_id, run_id=run_id)
    return resolved
def get_default_user_id():
    """Return the id of the first ``app_user`` ordered by username.

    A truthy result is cached in ``_DEFAULT_USER_ID`` and returned without a
    query on later calls. When no user exists, None is returned and nothing
    is cached.
    """
    global _DEFAULT_USER_ID
    cached = _DEFAULT_USER_ID
    if cached:
        return cached

    def _first_user(cur, _conn):
        cur.execute("SELECT id FROM app_user ORDER BY username LIMIT 1")
        row = cur.fetchone()
        if not row:
            return None
        return row[0]

    found = run_with_retry(_first_user)
    if found:
        with _DEFAULT_LOCK:
            _DEFAULT_USER_ID = found
    return found
def _default_run_id(user_id: str) -> str:
    """Build the deterministic id of a user's fallback run."""
    run_id = f"default_{user_id}"
    return run_id
def ensure_default_run(user_id: str):
    """Make sure the user's fallback ``strategy_run`` row exists; return its id.

    Inserts a STOPPED run with id ``default_<user_id>``, empty ``meta`` and
    no strategy/mode/broker. An existing row with that id is left untouched
    (``ON CONFLICT (run_id) DO NOTHING``).
    """
    run_id = _default_run_id(user_id)
    insert_sql = """
INSERT INTO strategy_run (
run_id, user_id, created_at, started_at, stopped_at, status, strategy, mode, broker, meta
)
VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
ON CONFLICT (run_id) DO NOTHING
"""

    def _insert(cur, _conn):
        created_at = _utc_now()
        params = (
            run_id,
            user_id,
            created_at,
            None,  # started_at
            None,  # stopped_at
            "STOPPED",
            None,  # strategy
            None,  # mode
            None,  # broker
            Json({}),
        )
        cur.execute(insert_sql, params)
        return run_id

    return run_with_retry(_insert)
def get_active_run_id(user_id: str):
    """Pick the run that engine work for ``user_id`` should be attributed to.

    Preference: the newest RUNNING run, then the newest run of any status
    (both ordered by ``created_at DESC``). If neither exists, the user's
    default run is returned, created by ``ensure_default_run`` if needed.
    """
    lookups = (
        """
SELECT run_id
FROM strategy_run
WHERE user_id = %s AND status = 'RUNNING'
ORDER BY created_at DESC
LIMIT 1
""",
        """
SELECT run_id
FROM strategy_run
WHERE user_id = %s
ORDER BY created_at DESC
LIMIT 1
""",
    )

    def _latest(cur, _conn):
        for sql in lookups:
            cur.execute(sql, (user_id,))
            found = cur.fetchone()
            if found:
                return found[0]
        return None

    return run_with_retry(_latest) or ensure_default_run(user_id)
def get_running_runs(user_id: str | None = None):
    """Return ``(user_id, run_id)`` rows for RUNNING runs, newest first.

    When ``user_id`` is truthy, only that user's runs are returned.
    """
    def _fetch(cur, _conn):
        if not user_id:
            cur.execute(
                """
SELECT user_id, run_id
FROM strategy_run
WHERE status = 'RUNNING'
ORDER BY created_at DESC
"""
            )
            return cur.fetchall()
        cur.execute(
            """
SELECT user_id, run_id
FROM strategy_run
WHERE user_id = %s AND status = 'RUNNING'
ORDER BY created_at DESC
""",
            (user_id,),
        )
        return cur.fetchall()

    return run_with_retry(_fetch)
def insert_engine_event(
    cur,
    event: str,
    data=None,
    message: str | None = None,
    meta=None,
    ts=None,
    user_id: str | None = None,
    run_id: str | None = None,
):
    """Insert one row into ``engine_event`` using the caller-supplied cursor.

    Args:
        cur: Open cursor to execute on. Presumably the caller owns the
            surrounding transaction (e.g. one from ``db_transaction``) —
            NOTE(review): confirm.
        event: Event name stored in the ``event`` column.
        data: Optional payload; wrapped in ``Json`` when not None.
        message: Optional free-text message.
        meta: Optional metadata; wrapped in ``Json`` when not None.
        ts: Event timestamp. Any falsy value falls back to ``_utc_now()``.
        user_id: Explicit user scope. If missing, ``_resolve_context`` fills it in.
        run_id: Explicit run scope. If missing, ``_resolve_context`` fills it in.

    Raises:
        ValueError: from ``_resolve_context`` when the user/run scope cannot
            be determined.

    Note: resolving a missing scope may query the database through
    ``get_default_user_id`` / ``get_active_run_id``, and may insert a
    default run.
    """
    # `ts or ...` replaces any falsy ts, not only None.
    when = ts or _utc_now()
    scope_user, scope_run = _resolve_context(user_id, run_id)
    cur.execute(
        """
INSERT INTO engine_event (user_id, run_id, ts, event, data, message, meta)
VALUES (%s, %s, %s, %s, %s, %s, %s)
""",
        (
            scope_user,
            scope_run,
            when,
            event,
            Json(data) if data is not None else None,
            message,
            Json(meta) if meta is not None else None,
        ),
    )