325 lines
9.0 KiB
Python
325 lines
9.0 KiB
Python
import os
|
|
import threading
|
|
import time
|
|
from contextlib import contextmanager
|
|
from datetime import datetime, timezone
|
|
from contextvars import ContextVar
|
|
|
|
import psycopg2
|
|
from psycopg2 import pool
|
|
from psycopg2 import OperationalError, InterfaceError
|
|
from psycopg2.extras import Json
|
|
|
|
_POOL = None
|
|
_POOL_LOCK = threading.Lock()
|
|
_DEFAULT_USER_ID = None
|
|
_DEFAULT_LOCK = threading.Lock()
|
|
|
|
_USER_ID = ContextVar("engine_user_id", default=None)
|
|
_RUN_ID = ContextVar("engine_run_id", default=None)
|
|
|
|
|
|
def _db_config():
|
|
url = os.getenv("DATABASE_URL")
|
|
if url:
|
|
return {"dsn": url}
|
|
|
|
schema = os.getenv("DB_SCHEMA") or os.getenv("PGSCHEMA") or "quant_app"
|
|
|
|
return {
|
|
"host": os.getenv("DB_HOST") or os.getenv("PGHOST") or "localhost",
|
|
"port": int(os.getenv("DB_PORT") or os.getenv("PGPORT") or "5432"),
|
|
"dbname": os.getenv("DB_NAME") or os.getenv("PGDATABASE") or "trading_db",
|
|
"user": os.getenv("DB_USER") or os.getenv("PGUSER") or "trader",
|
|
"password": os.getenv("DB_PASSWORD") or os.getenv("PGPASSWORD") or "traderpass",
|
|
"connect_timeout": int(os.getenv("DB_CONNECT_TIMEOUT", "5")),
|
|
"options": f"-csearch_path={schema},public" if schema else None,
|
|
}
|
|
|
|
|
|
def _init_pool():
    """Create a new threaded connection pool from the environment settings.

    Pool bounds come from ``DB_POOL_MIN`` (default 1) and ``DB_POOL_MAX``
    (default 10). Settings whose value is ``None`` are dropped before being
    handed to the pool constructor.

    Returns:
        psycopg2.pool.ThreadedConnectionPool: the freshly built pool.
    """
    settings = {key: value for key, value in _db_config().items() if value is not None}
    lower = int(os.getenv("DB_POOL_MIN", "1"))
    upper = int(os.getenv("DB_POOL_MAX", "10"))

    if "dsn" not in settings:
        return pool.ThreadedConnectionPool(lower, upper, **settings)

    # With a DSN only the DSN itself is forwarded to the pool.
    return pool.ThreadedConnectionPool(lower, upper, dsn=settings["dsn"])
|
|
|
|
|
|
def get_pool():
    """Return the module-wide connection pool, creating it on first use.

    The pool is built at most once: creation happens while holding
    ``_POOL_LOCK`` and the ``None`` check is repeated after the lock is taken.
    """
    global _POOL

    if _POOL is not None:
        return _POOL

    with _POOL_LOCK:
        # Another thread may have built the pool while we waited for the lock.
        if _POOL is None:
            _POOL = _init_pool()
        return _POOL
|
|
|
|
|
|
def _get_connection():
    """Check a connection out of the shared pool."""
    shared_pool = get_pool()
    return shared_pool.getconn()
|
|
|
|
|
|
def _put_connection(conn, close=False):
    """Hand a connection back to the shared pool, best effort.

    Args:
        conn: the connection previously obtained from the pool.
        close: forwarded to ``putconn``; asks the pool to close the
            connection instead of keeping it.

    Never raises: if returning the connection fails for any reason, the
    connection is closed directly and any error from that close is ignored.
    """
    try:
        shared_pool = get_pool()
        shared_pool.putconn(conn, close=close)
    except Exception:
        # The pool refused the connection; make sure it does not leak.
        try:
            conn.close()
        except Exception:
            pass
|
|
|
|
|
|
@contextmanager
def db_connection(retries: int | None = None, delay: float | None = None):
    """Yield a pooled connection with autocommit disabled.

    Only *acquiring* the connection is retried. The caller's ``with`` body
    runs exactly once: a ``@contextmanager`` generator must yield a single
    time, so an error raised inside the body is propagated, never retried
    here (use ``run_with_retry`` to retry a whole operation).

    Args:
        retries: number of checkout attempts; defaults to ``DB_RETRY_COUNT``
            (3). Values below 1 are treated as 1.
        delay: base backoff in seconds, doubled after each failed attempt;
            defaults to ``DB_RETRY_DELAY`` (0.2).

    Yields:
        The checked-out connection. It is returned to the pool on exit and
        discarded (closed) if the body raised ``OperationalError`` /
        ``InterfaceError`` or if it reports itself as closed.

    Raises:
        OperationalError, InterfaceError: if every checkout attempt failed,
            or re-raised unchanged from the ``with`` body.
    """
    attempts = retries if retries is not None else int(os.getenv("DB_RETRY_COUNT", "3"))
    backoff = delay if delay is not None else float(os.getenv("DB_RETRY_DELAY", "0.2"))
    # The generator has to yield once, so always make at least one attempt.
    attempts = max(1, attempts)

    conn = None
    for attempt in range(attempts):
        try:
            conn = _get_connection()
            conn.autocommit = False
            break
        except (OperationalError, InterfaceError):
            if conn is not None:
                # Checked out but unusable: drop it rather than recycle it.
                _put_connection(conn, close=True)
                conn = None
            if attempt == attempts - 1:
                raise
            time.sleep(backoff * (2 ** attempt))

    try:
        yield conn
    except (OperationalError, InterfaceError):
        # The connection failed while in use; do not put it back for reuse.
        _put_connection(conn, close=True)
        conn = None
        raise
    finally:
        if conn is not None:
            _put_connection(conn, close=conn.closed != 0)
|
|
|
|
|
|
def run_with_retry(operation, retries: int | None = None, delay: float | None = None):
    """Run ``operation(cur, conn)`` in its own transaction, retrying DB drops.

    Each attempt checks out a connection, opens a cursor, calls the
    operation, and commits. ``OperationalError`` / ``InterfaceError`` raised
    anywhere in the attempt (including while obtaining the connection) cause
    a retry with exponential backoff; any other exception rolls back and
    propagates immediately.

    Args:
        operation: callable taking ``(cursor, connection)``; its return value
            is returned after a successful commit. It may be called more than
            once, so it should be safe to repeat.
        retries: total number of attempts; defaults to ``DB_RETRY_COUNT``
            (3). Values below 1 are treated as 1.
        delay: base backoff in seconds, doubled after each failed attempt;
            defaults to ``DB_RETRY_DELAY`` (0.2).

    Returns:
        Whatever ``operation`` returned on the successful attempt.

    Raises:
        OperationalError, InterfaceError: the last error once all attempts
            are used up.
        Exception: any other error raised by ``operation``, after rollback.
    """
    attempts = retries if retries is not None else int(os.getenv("DB_RETRY_COUNT", "3"))
    backoff = delay if delay is not None else float(os.getenv("DB_RETRY_DELAY", "0.2"))
    attempts = max(1, attempts)

    for attempt in range(attempts):
        try:
            # Single checkout attempt: this loop owns the retry policy.
            with db_connection(retries=1) as conn:
                try:
                    with conn.cursor() as cur:
                        result = operation(cur, conn)
                    conn.commit()
                    return result
                except Exception:
                    try:
                        conn.rollback()
                    except (OperationalError, InterfaceError):
                        # Connection is already dead; keep the original error.
                        pass
                    raise
        except (OperationalError, InterfaceError):
            if attempt == attempts - 1:
                raise
            time.sleep(backoff * (2 ** attempt))
|
|
|
|
|
|
@contextmanager
def db_transaction():
    """Yield a cursor whose work is committed on success, rolled back on error.

    Yields:
        A cursor on a connection obtained via ``db_connection()``.

    Raises:
        Exception: whatever the ``with`` body (or the commit) raised, after a
            rollback has been attempted.
    """
    with db_connection() as conn:
        try:
            with conn.cursor() as cur:
                yield cur
            conn.commit()
        except Exception:
            try:
                conn.rollback()
            except (OperationalError, InterfaceError):
                # A broken connection cannot roll back; surface the original
                # failure instead of the rollback error.
                pass
            raise
|
|
|
|
|
|
def _utc_now():
|
|
return datetime.utcnow().replace(tzinfo=timezone.utc)
|
|
|
|
|
|
def set_context(user_id: str | None, run_id: str | None):
    """Bind ``user_id`` and ``run_id`` to the current context.

    Returns:
        tuple: ``(token_user, token_run)``, the tokens needed by
        ``reset_context`` to restore the previous values.
    """
    # Tuple elements are evaluated left to right: user first, then run.
    return _USER_ID.set(user_id), _RUN_ID.set(run_id)
|
|
|
|
|
|
def reset_context(token_user, token_run):
    """Restore the user and run context variables using tokens from ``set_context``."""
    for variable, token in ((_USER_ID, token_user), (_RUN_ID, token_run)):
        variable.reset(token)
|
|
|
|
|
|
@contextmanager
def engine_context(user_id: str, run_id: str):
    """Temporarily bind ``user_id`` / ``run_id`` for the duration of a ``with`` block.

    The previous values are restored on exit, including when the body raises.
    """
    tokens = set_context(user_id, run_id)
    try:
        yield
    finally:
        reset_context(*tokens)
|
|
|
|
|
|
def _resolve_context(user_id: str | None = None, run_id: str | None = None):
    """Work out the effective ``(user_id, run_id)`` pair.

    Each value is taken from the first source that provides a truthy value:

    1. the explicit argument,
    2. the context variable set via ``set_context`` / ``engine_context``,
    3. the ``ENGINE_USER_ID`` / ``ENGINE_RUN_ID`` environment variable,
    4. the database: ``get_default_user_id()`` for the user, then
       ``get_active_run_id(user)`` for the run.

    Later sources are consulted only while the pair is still incomplete.

    Returns:
        tuple: ``(user_id, run_id)``, both truthy.

    Raises:
        ValueError: if either value is still missing after every source.
    """
    scope_user = user_id or _USER_ID.get()
    scope_run = run_id or _RUN_ID.get()

    if not (scope_user and scope_run):
        # Fill whichever half is missing from the environment.
        scope_user = scope_user or os.getenv("ENGINE_USER_ID")
        scope_run = scope_run or os.getenv("ENGINE_RUN_ID")

    if not (scope_user and scope_run):
        # Last resort: ask the database.
        scope_user = scope_user or get_default_user_id()
        if scope_user and not scope_run:
            scope_run = get_active_run_id(scope_user)

    if scope_user and scope_run:
        return scope_user, scope_run
    raise ValueError("engine context missing user_id/run_id")
|
|
|
|
|
|
def get_context(user_id: str | None = None, run_id: str | None = None):
    """Public entry point for resolving ``(user_id, run_id)``; see ``_resolve_context``."""
    return _resolve_context(user_id=user_id, run_id=run_id)
|
|
|
|
|
|
def get_default_user_id():
    """Return the id of the first ``app_user`` row ordered by username.

    A truthy result is cached in ``_DEFAULT_USER_ID`` so later calls skip the
    query. A falsy result (no rows) is returned as-is and not cached.
    """
    global _DEFAULT_USER_ID

    if _DEFAULT_USER_ID:
        return _DEFAULT_USER_ID

    def _first_user(cur, _conn):
        cur.execute("SELECT id FROM app_user ORDER BY username LIMIT 1")
        first = cur.fetchone()
        if not first:
            return None
        return first[0]

    found = run_with_retry(_first_user)
    if not found:
        return found

    with _DEFAULT_LOCK:
        _DEFAULT_USER_ID = found
    return found
|
|
|
|
|
|
def _default_run_id(user_id: str) -> str:
|
|
return f"default_{user_id}"
|
|
|
|
|
|
def ensure_default_run(user_id: str):
    """Insert the user's placeholder ``strategy_run`` row if it does not exist.

    The row is created with status ``STOPPED``, the current UTC time as
    ``created_at``, an empty JSON ``meta`` and every other column ``NULL``.
    An existing row with the same ``run_id`` is left untouched.

    Returns:
        str: the placeholder run id, whether or not a row was inserted.
    """
    run_id = _default_run_id(user_id)
    insert_sql = """
        INSERT INTO strategy_run (
            run_id, user_id, created_at, started_at, stopped_at, status, strategy, mode, broker, meta
        )
        VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
        ON CONFLICT (run_id) DO NOTHING
        """

    def _insert_placeholder(cur, _conn):
        # Timestamp is taken per attempt, inside the retried operation.
        created_at = _utc_now()
        values = (
            run_id,
            user_id,
            created_at,
            None,  # started_at
            None,  # stopped_at
            "STOPPED",
            None,  # strategy
            None,  # mode
            None,  # broker
            Json({}),
        )
        cur.execute(insert_sql, values)
        return run_id

    return run_with_retry(_insert_placeholder)
|
|
|
|
|
|
def get_active_run_id(user_id: str):
    """Pick the run id to use for ``user_id``.

    Preference order: the newest run with status ``RUNNING``, then the newest
    run of any status, and finally the placeholder run created by
    ``ensure_default_run``.
    """
    newest_running_sql = """
        SELECT run_id
        FROM strategy_run
        WHERE user_id = %s AND status = 'RUNNING'
        ORDER BY created_at DESC
        LIMIT 1
        """
    newest_any_sql = """
        SELECT run_id
        FROM strategy_run
        WHERE user_id = %s
        ORDER BY created_at DESC
        LIMIT 1
        """

    def _lookup(cur, _conn):
        # Try the stricter query first; stop at the first one returning a row.
        for statement in (newest_running_sql, newest_any_sql):
            cur.execute(statement, (user_id,))
            hit = cur.fetchone()
            if hit:
                return hit[0]
        return None

    found = run_with_retry(_lookup)
    if not found:
        return ensure_default_run(user_id)
    return found
|
|
|
|
|
|
def get_running_runs(user_id: str | None = None):
    """List ``(user_id, run_id)`` rows for runs with status ``RUNNING``, newest first.

    Args:
        user_id: when truthy, only that user's runs are returned; otherwise
            running runs of every user are returned.

    Returns:
        The rows produced by ``cursor.fetchall()``.
    """
    for_one_user_sql = """
        SELECT user_id, run_id
        FROM strategy_run
        WHERE user_id = %s AND status = 'RUNNING'
        ORDER BY created_at DESC
        """
    for_everyone_sql = """
        SELECT user_id, run_id
        FROM strategy_run
        WHERE status = 'RUNNING'
        ORDER BY created_at DESC
        """

    def _select_running(cur, _conn):
        if not user_id:
            cur.execute(for_everyone_sql)
        else:
            cur.execute(for_one_user_sql, (user_id,))
        return cur.fetchall()

    return run_with_retry(_select_running)
|
|
|
|
|
|
def insert_engine_event(
    cur,
    event: str,
    data=None,
    message: str | None = None,
    meta=None,
    ts=None,
    user_id: str | None = None,
    run_id: str | None = None,
):
    """Insert one row into ``engine_event`` using the caller's cursor.

    Args:
        cur: cursor to execute on; this function does not commit.
        event: value for the ``event`` column.
        data: optional payload, wrapped in ``Json`` unless ``None``.
        message: optional value for the ``message`` column.
        meta: optional metadata, wrapped in ``Json`` unless ``None``.
        ts: event timestamp; the current UTC time is used when falsy.
        user_id: explicit user scope; resolved via ``_resolve_context`` if omitted.
        run_id: explicit run scope; resolved via ``_resolve_context`` if omitted.

    Raises:
        ValueError: propagated from ``_resolve_context`` when no user/run
            scope can be determined.
    """
    when = ts if ts else _utc_now()
    scope_user, scope_run = _resolve_context(user_id, run_id)

    # None stays NULL in the database; anything else is sent as JSON.
    data_param = None if data is None else Json(data)
    meta_param = None if meta is None else Json(meta)

    insert_sql = """
        INSERT INTO engine_event (user_id, run_id, ts, event, data, message, meta)
        VALUES (%s, %s, %s, %s, %s, %s, %s)
        """
    row = (scope_user, scope_run, when, event, data_param, message, meta_param)
    cur.execute(insert_sql, row)
|