# 2026-02-01 20:34:57 +00:00
#
# 698 lines
# 24 KiB
# Python

from __future__ import annotations
from abc import ABC, abstractmethod
from dataclasses import dataclass
from datetime import datetime, timezone
import hashlib
from psycopg2.extras import execute_values
from indian_paper_trading_strategy.engine.data import fetch_live_price
from indian_paper_trading_strategy.engine.db import db_connection, insert_engine_event, run_with_retry, get_context
class Broker(ABC):
    """Abstract interface that every broker backend must implement.

    A concrete broker exposes order placement plus read access to its funds,
    orders and positions. The base implementations only raise
    ``NotImplementedError``; subclasses must override all four methods before
    they can be instantiated.
    """

    @abstractmethod
    def get_funds(self):
        """Return the account's funds summary (shape defined by the subclass)."""
        raise NotImplementedError

    @abstractmethod
    def get_orders(self):
        """Return the orders known to the broker (shape defined by the subclass)."""
        raise NotImplementedError

    @abstractmethod
    def get_positions(self):
        """Return the currently held positions (shape defined by the subclass)."""
        raise NotImplementedError

    @abstractmethod
    def place_order(
        self,
        symbol: str,
        side: str,
        quantity: float,
        price: float | None = None,
        logical_time: datetime | None = None,
    ):
        """Submit an order for *quantity* units of *symbol* on *side*.

        Args:
            symbol: Instrument identifier.
            side: Order side, e.g. ``"BUY"`` / ``"SELL"``.
            quantity: Number of units.
            price: Limit/fill price, or ``None`` to let the implementation
                choose one.
            logical_time: Engine time at which the order is placed. Pass it by
                keyword: implementations may accept extra parameters before it.
        """
        raise NotImplementedError
def _local_tz():
return datetime.now().astimezone().tzinfo
def _format_utc_ts(value: datetime | None):
if value is None:
return None
if value.tzinfo is None:
value = value.replace(tzinfo=_local_tz())
return value.astimezone(timezone.utc).isoformat().replace("+00:00", "Z")
def _format_local_ts(value: datetime | None):
if value is None:
return None
if value.tzinfo is None:
value = value.replace(tzinfo=_local_tz())
return value.astimezone(_local_tz()).replace(tzinfo=None).isoformat()
def _parse_ts(value, assume_local: bool = True):
if value is None:
return None
if isinstance(value, datetime):
if value.tzinfo is None:
return value.replace(tzinfo=_local_tz() if assume_local else timezone.utc)
return value
if isinstance(value, str):
text = value.strip()
if not text:
return None
if text.endswith("Z"):
try:
return datetime.fromisoformat(text.replace("Z", "+00:00"))
except ValueError:
return None
try:
parsed = datetime.fromisoformat(text)
except ValueError:
return None
if parsed.tzinfo is None:
parsed = parsed.replace(tzinfo=_local_tz() if assume_local else timezone.utc)
return parsed
return None
def _stable_num(value: float) -> str:
return f"{float(value):.12f}"
def _normalize_ts_for_id(ts: datetime) -> str:
if ts.tzinfo is None:
ts = ts.replace(tzinfo=timezone.utc)
return ts.astimezone(timezone.utc).replace(microsecond=0).isoformat()
def _deterministic_id(prefix: str, parts: list[str]) -> str:
payload = "|".join(parts)
digest = hashlib.sha1(payload.encode("utf-8")).hexdigest()[:16]
return f"{prefix}_{digest}"
def _resolve_scope(user_id: str | None, run_id: str | None):
    """Resolve the ``(user_id, run_id)`` scope used to key every DB query.

    Thin delegate to ``get_context``; callers unpack the result as a 2-tuple.
    Presumably ``get_context`` fills in missing values from the active engine
    context — TODO confirm against ``engine.db``.
    """
    scope = get_context(user_id, run_id)
    return scope
@dataclass
class PaperBroker(Broker):
    """Simulated broker whose ledger is persisted in PostgreSQL.

    State is stored per ``(user_id, run_id)`` scope in ``paper_broker_account``
    (cash), ``paper_position``, ``paper_order``, ``paper_trade`` and
    ``paper_equity_curve``. Mutating operations load the scope with row locks,
    apply the change in memory and write it back inside one transaction: the
    caller's cursor ``cur`` when given, otherwise a fresh transaction obtained
    through ``run_with_retry``.

    Attributes:
        initial_cash: Cash used when the scope has no account row yet; also the
            baseline for the ``pnl`` recorded on the equity curve.
        store_path: Not read anywhere in this class (presumably a leftover of an
            older file-backed store — TODO confirm before removing).
    """

    initial_cash: float
    store_path: str | None = None

    # ------------------------------------------------------------------
    # Small helpers
    # ------------------------------------------------------------------
    def _default_store(self):
        """Return an empty ledger seeded with ``initial_cash``."""
        return {
            "cash": float(self.initial_cash),
            "positions": {},
            "orders": [],
            "trades": [],
            "equity_curve": [],
        }

    @staticmethod
    def _num(value, default: float = 0.0) -> float:
        """Convert a nullable numeric DB value to ``float``; ``None`` -> *default*."""
        return float(value) if value is not None else default

    @staticmethod
    def _positions_value(positions) -> float:
        """Return the mark-to-market value of *positions* (sum of qty * mark).

        The mark is ``last_price``, falling back to ``avg_price`` when that key
        is absent.
        """
        total = 0.0
        for position in positions.values():
            qty = float(position.get("qty", 0))
            mark = float(position.get("last_price", position.get("avg_price", 0)))
            total += qty * mark
        return total

    @staticmethod
    def _without_internal_keys(records):
        """Copy *records*, dropping the private ``_logical_time`` key from dicts.

        Non-dict entries are passed through unchanged.
        """
        return [
            {key: val for key, val in record.items() if key != "_logical_time"}
            if isinstance(record, dict)
            else record
            for record in records
        ]

    # ------------------------------------------------------------------
    # Loading
    # ------------------------------------------------------------------
    def _load_store(self, cur=None, for_update: bool = False, user_id: str | None = None, run_id: str | None = None):
        """Load the ledger of one ``(user_id, run_id)`` scope from the database.

        Args:
            cur: Cursor to read with; when ``None`` a connection is opened via
                ``db_connection`` just for this read.
            for_update: Append ``FOR UPDATE`` to the account and position
                queries so those rows stay locked until the transaction ends.
                Order, trade and equity rows are never locked.
            user_id, run_id: Scope, resolved through ``_resolve_scope``.

        Returns:
            A dict with ``cash``, ``positions`` (symbol -> ``qty``/``avg_price``/
            ``last_price``), ``orders``, ``trades`` and ``equity_curve``.
            Order/trade timestamps are UTC ``...Z`` strings; equity-curve
            timestamps are naive local-time strings. NULL numbers load as
            ``0.0``, except a NULL ``last_price``, which falls back to the
            position's ``avg_price``.
        """
        scope_user, scope_run = _resolve_scope(user_id, run_id)
        if cur is None:
            with db_connection() as conn:
                with conn.cursor() as own_cur:
                    return self._load_store(
                        cur=own_cur,
                        for_update=for_update,
                        user_id=scope_user,
                        run_id=scope_run,
                    )
        num = self._num
        scope = (scope_user, scope_run)
        store = self._default_store()
        lock_clause = " FOR UPDATE" if for_update else ""

        cur.execute(
            f"SELECT cash FROM paper_broker_account WHERE user_id = %s AND run_id = %s{lock_clause} LIMIT 1",
            scope,
        )
        row = cur.fetchone()
        if row and row[0] is not None:
            store["cash"] = float(row[0])

        cur.execute(
            f"""
            SELECT symbol, qty, avg_price, last_price
            FROM paper_position
            WHERE user_id = %s AND run_id = %s{lock_clause}
            """,
            scope,
        )
        positions = {}
        for symbol, qty, avg_price, last_price in cur.fetchall():
            avg = num(avg_price)
            positions[symbol] = {
                "qty": num(qty),
                "avg_price": avg,
                # A NULL mark falls back to cost. Coercing it to 0.0 would value
                # the position at zero, and the readers' ".get('last_price',
                # avg_price)" fallback can never fire once the key is present.
                "last_price": num(last_price, default=avg),
            }
        store["positions"] = positions

        cur.execute(
            """
            SELECT id, symbol, side, qty, price, status, timestamp, logical_time
            FROM paper_order
            WHERE user_id = %s AND run_id = %s
            ORDER BY timestamp, id
            """,
            scope,
        )
        store["orders"] = [
            {
                "id": order_id,
                "symbol": symbol,
                "side": side,
                "qty": num(qty),
                "price": num(price),
                "status": status,
                "timestamp": _format_utc_ts(ts),
                "_logical_time": _format_utc_ts(logical_ts),
            }
            for order_id, symbol, side, qty, price, status, ts, logical_ts in cur.fetchall()
        ]

        cur.execute(
            """
            SELECT id, order_id, symbol, side, qty, price, timestamp, logical_time
            FROM paper_trade
            WHERE user_id = %s AND run_id = %s
            ORDER BY timestamp, id
            """,
            scope,
        )
        store["trades"] = [
            {
                "id": trade_id,
                "order_id": order_id,
                "symbol": symbol,
                "side": side,
                "qty": num(qty),
                "price": num(price),
                "timestamp": _format_utc_ts(ts),
                "_logical_time": _format_utc_ts(logical_ts),
            }
            for trade_id, order_id, symbol, side, qty, price, ts, logical_ts in cur.fetchall()
        ]

        cur.execute(
            """
            SELECT timestamp, logical_time, equity, pnl
            FROM paper_equity_curve
            WHERE user_id = %s AND run_id = %s
            ORDER BY timestamp
            """,
            scope,
        )
        store["equity_curve"] = [
            {
                "timestamp": _format_local_ts(ts),
                "_logical_time": _format_local_ts(logical_ts),
                "equity": num(equity),
                "pnl": num(pnl),
            }
            for ts, logical_ts, equity, pnl in cur.fetchall()
        ]
        return store

    # ------------------------------------------------------------------
    # Saving
    # ------------------------------------------------------------------
    def _save_store(self, store, cur=None, user_id: str | None = None, run_id: str | None = None):
        """Write *store* to the database for one scope.

        Each key is handled independently; a missing or wrongly-typed key is
        skipped:

        * ``cash`` is upserted into ``paper_broker_account``.
        * ``positions`` is treated as the complete set: rows for symbols not in
          the dict are deleted and the rest are upserted.
        * ``orders`` / ``trades`` / ``equity_curve`` are inserted with
          ``ON CONFLICT DO NOTHING``, so entries colliding with an existing
          unique key are left untouched. Passing only new entries is enough.

        When *cur* is ``None`` the write runs in its own transaction via
        ``run_with_retry``.
        """
        scope_user, scope_run = _resolve_scope(user_id, run_id)
        if cur is None:
            def _persist(cur, _conn):
                self._save_store(store, cur=cur, user_id=scope_user, run_id=scope_run)
            return run_with_retry(_persist)
        cash = store.get("cash")
        if cash is not None:
            cur.execute(
                """
                INSERT INTO paper_broker_account (user_id, run_id, cash)
                VALUES (%s, %s, %s)
                ON CONFLICT (user_id, run_id) DO UPDATE
                SET cash = EXCLUDED.cash
                """,
                (scope_user, scope_run, float(cash)),
            )
        self._save_positions(cur, scope_user, scope_run, store.get("positions"))
        self._save_orders(cur, scope_user, scope_run, store.get("orders"))
        self._save_trades(cur, scope_user, scope_run, store.get("trades"))
        self._save_equity_curve(cur, scope_user, scope_run, store.get("equity_curve"))

    @staticmethod
    def _save_positions(cur, scope_user, scope_run, positions):
        """Make ``paper_position`` mirror *positions* (delete missing, upsert present)."""
        if not isinstance(positions, dict):
            return
        symbols = [s for s in positions.keys() if s]
        if not symbols:
            cur.execute(
                "DELETE FROM paper_position WHERE user_id = %s AND run_id = %s",
                (scope_user, scope_run),
            )
            return
        cur.execute(
            "DELETE FROM paper_position WHERE user_id = %s AND run_id = %s AND symbol NOT IN %s",
            (scope_user, scope_run, tuple(symbols)),
        )
        updated_at = datetime.now(timezone.utc)
        rows = [
            (
                scope_user,
                scope_run,
                symbol,
                float(data.get("qty", 0.0)),
                float(data.get("avg_price", 0.0)),
                float(data.get("last_price", 0.0)),
                updated_at,
            )
            for symbol, data in positions.items()
            if symbol and isinstance(data, dict)
        ]
        if rows:
            execute_values(
                cur,
                """
                INSERT INTO paper_position (
                    user_id, run_id, symbol, qty, avg_price, last_price, updated_at
                )
                VALUES %s
                ON CONFLICT (user_id, run_id, symbol) DO UPDATE
                SET qty = EXCLUDED.qty,
                    avg_price = EXCLUDED.avg_price,
                    last_price = EXCLUDED.last_price,
                    updated_at = EXCLUDED.updated_at
                """,
                rows,
            )

    @staticmethod
    def _save_orders(cur, scope_user, scope_run, orders):
        """Insert *orders* into ``paper_order``, skipping entries without an id."""
        if not isinstance(orders, list) or not orders:
            return
        rows = []
        for order in orders:
            if not isinstance(order, dict) or not order.get("id"):
                continue
            ts = _parse_ts(order.get("timestamp"), assume_local=False)
            logical_ts = _parse_ts(order.get("_logical_time"), assume_local=False) or ts
            rows.append(
                (
                    scope_user,
                    scope_run,
                    order.get("id"),
                    order.get("symbol"),
                    order.get("side"),
                    float(order.get("qty", 0.0)),
                    float(order.get("price", 0.0)),
                    order.get("status"),
                    ts,
                    logical_ts,
                )
            )
        if rows:
            execute_values(
                cur,
                """
                INSERT INTO paper_order (
                    user_id, run_id, id, symbol, side, qty, price, status, timestamp, logical_time
                )
                VALUES %s
                ON CONFLICT DO NOTHING
                """,
                rows,
            )

    @staticmethod
    def _save_trades(cur, scope_user, scope_run, trades):
        """Insert *trades* into ``paper_trade``, skipping entries without an id."""
        if not isinstance(trades, list) or not trades:
            return
        rows = []
        for trade in trades:
            if not isinstance(trade, dict) or not trade.get("id"):
                continue
            ts = _parse_ts(trade.get("timestamp"), assume_local=False)
            logical_ts = _parse_ts(trade.get("_logical_time"), assume_local=False) or ts
            rows.append(
                (
                    scope_user,
                    scope_run,
                    trade.get("id"),
                    trade.get("order_id"),
                    trade.get("symbol"),
                    trade.get("side"),
                    float(trade.get("qty", 0.0)),
                    float(trade.get("price", 0.0)),
                    ts,
                    logical_ts,
                )
            )
        if rows:
            execute_values(
                cur,
                """
                INSERT INTO paper_trade (
                    user_id, run_id, id, order_id, symbol, side, qty, price, timestamp, logical_time
                )
                VALUES %s
                ON CONFLICT DO NOTHING
                """,
                rows,
            )

    @staticmethod
    def _save_equity_curve(cur, scope_user, scope_run, equity_curve):
        """Insert *equity_curve* points, skipping any whose timestamp cannot be parsed."""
        if not isinstance(equity_curve, list) or not equity_curve:
            return
        rows = []
        for point in equity_curve:
            if not isinstance(point, dict):
                continue
            # Equity timestamps are stored as naive local-time strings.
            ts = _parse_ts(point.get("timestamp"), assume_local=True)
            if ts is None:
                continue
            logical_ts = _parse_ts(point.get("_logical_time"), assume_local=True) or ts
            rows.append(
                (
                    scope_user,
                    scope_run,
                    ts,
                    logical_ts,
                    float(point.get("equity", 0.0)),
                    float(point.get("pnl", 0.0)),
                )
            )
        if rows:
            execute_values(
                cur,
                """
                INSERT INTO paper_equity_curve (user_id, run_id, timestamp, logical_time, equity, pnl)
                VALUES %s
                ON CONFLICT DO NOTHING
                """,
                rows,
            )

    # ------------------------------------------------------------------
    # Read API
    # ------------------------------------------------------------------
    def get_funds(self, cur=None, user_id: str | None = None, run_id: str | None = None):
        """Return cash, marked position value and total equity for the scope.

        ``used_margin`` is always ``0.0``; ``cash``, ``cash_available`` and
        ``available`` all report free cash; ``net`` and ``total_equity`` are
        cash plus marked positions.
        """
        store = self._load_store(cur=cur, user_id=user_id, run_id=run_id)
        cash = float(store.get("cash", 0))
        positions_value = self._positions_value(store.get("positions", {}))
        total_equity = cash + positions_value
        return {
            "cash_available": cash,
            "invested_value": positions_value,
            "cash": cash,
            "used_margin": 0.0,
            "available": cash,
            "net": total_equity,
            "total_equity": total_equity,
        }

    def get_positions(self, cur=None, user_id: str | None = None, run_id: str | None = None):
        """Return open positions as a list of ``symbol``/``qty``/``avg_price``/``last_price`` dicts."""
        store = self._load_store(cur=cur, user_id=user_id, run_id=run_id)
        return [
            {
                "symbol": symbol,
                "qty": float(data.get("qty", 0)),
                "avg_price": float(data.get("avg_price", 0)),
                "last_price": float(data.get("last_price", data.get("avg_price", 0))),
            }
            for symbol, data in store.get("positions", {}).items()
        ]

    def get_orders(self, cur=None, user_id: str | None = None, run_id: str | None = None):
        """Return all orders of the scope, oldest first, without internal keys."""
        store = self._load_store(cur=cur, user_id=user_id, run_id=run_id)
        return self._without_internal_keys(store.get("orders", []))

    def get_trades(self, cur=None, user_id: str | None = None, run_id: str | None = None):
        """Return all trades of the scope, oldest first, without internal keys."""
        store = self._load_store(cur=cur, user_id=user_id, run_id=run_id)
        return self._without_internal_keys(store.get("trades", []))

    def get_equity_curve(self, cur=None, user_id: str | None = None, run_id: str | None = None):
        """Return the equity curve points of the scope, oldest first, without internal keys."""
        store = self._load_store(cur=cur, user_id=user_id, run_id=run_id)
        return self._without_internal_keys(store.get("equity_curve", []))

    # ------------------------------------------------------------------
    # Equity updates
    # ------------------------------------------------------------------
    def _update_equity_in_tx(
        self,
        cur,
        prices: dict[str, float],
        now: datetime,
        logical_time: datetime | None = None,
        user_id: str | None = None,
        run_id: str | None = None,
    ):
        """Transactional body of :meth:`update_equity` (runs on *cur*)."""
        scope_user, scope_run = _resolve_scope(user_id, run_id)
        store = self._load_store(cur=cur, for_update=True, user_id=scope_user, run_id=scope_run)
        positions = store.get("positions", {})
        for symbol, price in prices.items():
            if symbol in positions:
                positions[symbol]["last_price"] = float(price)
        cash = float(store.get("cash", 0))
        equity = cash + self._positions_value(positions)
        pnl = equity - float(self.initial_cash)
        ts_for_equity = logical_time or now
        point = {
            "timestamp": _format_local_ts(ts_for_equity),
            "_logical_time": _format_local_ts(ts_for_equity),
            "equity": equity,
            "pnl": pnl,
        }
        # Persist only what changed. Earlier equity points already exist and
        # would be skipped by ON CONFLICT anyway; re-sending the whole history
        # on every update made write volume grow with the run's length.
        self._save_store(
            {"cash": store.get("cash"), "positions": positions, "equity_curve": [point]},
            cur=cur,
            user_id=scope_user,
            run_id=scope_run,
        )
        insert_engine_event(
            cur,
            "EQUITY_UPDATED",
            data={
                "timestamp": _format_utc_ts(ts_for_equity),
                "equity": equity,
                "pnl": pnl,
            },
        )
        return equity

    def update_equity(
        self,
        prices: dict[str, float],
        now: datetime,
        cur=None,
        logical_time: datetime | None = None,
        user_id: str | None = None,
        run_id: str | None = None,
    ):
        """Re-mark positions with *prices* and append a point to the equity curve.

        Args:
            prices: Symbol -> latest price; symbols without a position are ignored.
            now: Timestamp of the equity point when *logical_time* is ``None``.
            cur: Cursor of an enclosing transaction; when ``None`` the update
                runs in its own transaction via ``run_with_retry``.
            logical_time: Engine time of the point; takes precedence over *now*.
            user_id, run_id: Scope, resolved through ``_resolve_scope``.

        Returns:
            The new total equity (cash plus marked position value).
        """
        if cur is not None:
            return self._update_equity_in_tx(
                cur,
                prices,
                now,
                logical_time=logical_time,
                user_id=user_id,
                run_id=run_id,
            )

        def _op(cur, _conn):
            return self._update_equity_in_tx(
                cur,
                prices,
                now,
                logical_time=logical_time,
                user_id=user_id,
                run_id=run_id,
            )

        return run_with_retry(_op)

    # ------------------------------------------------------------------
    # Orders
    # ------------------------------------------------------------------
    def _place_order_in_tx(
        self,
        cur,
        symbol: str,
        side: str,
        quantity: float,
        price: float | None,
        logical_time: datetime | None = None,
        user_id: str | None = None,
        run_id: str | None = None,
    ):
        """Transactional body of :meth:`place_order` (runs on *cur*)."""
        scope_user, scope_run = _resolve_scope(user_id, run_id)
        store = self._load_store(cur=cur, for_update=True, user_id=scope_user, run_id=scope_run)
        side = side.upper().strip()
        qty = float(quantity)
        if price is None:
            price = fetch_live_price(symbol)
        price = float(price)
        logical_ts = logical_time or datetime.now(timezone.utc)
        timestamp_str = _format_utc_ts(logical_ts)
        order_id = _deterministic_id(
            "ord",
            [
                scope_user,
                scope_run,
                _normalize_ts_for_id(logical_ts),
                symbol,
                side,
                _stable_num(qty),
                _stable_num(price),
            ],
        )

        # Replay guard: an order with this deterministic id was already
        # recorded. Re-applying it would move cash/positions again while
        # ON CONFLICT DO NOTHING silently drops the duplicate order/trade rows,
        # leaving the ledger inconsistent, so return the stored order instead.
        for previous in store.get("orders", []):
            if isinstance(previous, dict) and previous.get("id") == order_id:
                return dict(previous)

        order = {
            "id": order_id,
            "symbol": symbol,
            "side": side,
            "qty": qty,
            "price": price,
            "status": "REJECTED",
            "timestamp": timestamp_str,
            "_logical_time": timestamp_str,
        }
        positions = store.get("positions", {})
        cash = float(store.get("cash", 0))

        def _fill():
            """Mark the order FILLED and build its trade record."""
            order["status"] = "FILLED"
            return {
                "id": _deterministic_id("trd", [order_id]),
                "order_id": order_id,
                "symbol": symbol,
                "side": side,
                "qty": qty,
                "price": price,
                "timestamp": timestamp_str,
                "_logical_time": timestamp_str,
            }

        trade = None
        # Non-positive quantity/price, an unknown side, insufficient cash or
        # insufficient holdings all leave the order REJECTED.
        if qty > 0 and price > 0:
            if side == "BUY":
                cost = qty * price
                if cash >= cost:
                    cash -= cost
                    held = positions.get(symbol, {"qty": 0.0, "avg_price": 0.0, "last_price": price})
                    held_qty = float(held.get("qty", 0))
                    new_qty = held_qty + qty
                    prev_cost = held_qty * float(held.get("avg_price", 0))
                    positions[symbol] = {
                        "qty": new_qty,
                        "avg_price": (prev_cost + cost) / new_qty if new_qty else price,
                        "last_price": price,
                    }
                    trade = _fill()
            elif side == "SELL":
                held = positions.get(symbol)
                # NOTE(review): exact float comparison; fractional quantities
                # may leave dust or be rejected by rounding error.
                if held and float(held.get("qty", 0)) >= qty:
                    cash += qty * price
                    remaining = float(held.get("qty", 0)) - qty
                    if remaining > 0:
                        # avg_price is intentionally kept: selling does not change cost basis.
                        held["qty"] = remaining
                        held["last_price"] = price
                        positions[symbol] = held
                    else:
                        positions.pop(symbol, None)
                    trade = _fill()

        # Persist the new state plus only the new order/trade rows; earlier
        # rows already exist and re-inserting them is pure overhead.
        self._save_store(
            {
                "cash": cash,
                "positions": positions,
                "orders": [order],
                "trades": [trade] if trade is not None else [],
            },
            cur=cur,
            user_id=scope_user,
            run_id=scope_run,
        )
        insert_engine_event(cur, "ORDER_PLACED", data=order)
        if trade is not None:
            insert_engine_event(cur, "TRADE_EXECUTED", data=trade)
            insert_engine_event(cur, "ORDER_FILLED", data={"order_id": order_id})
        return order

    def place_order(
        self,
        symbol: str,
        side: str,
        quantity: float,
        price: float | None = None,
        cur=None,
        logical_time: datetime | None = None,
        user_id: str | None = None,
        run_id: str | None = None,
    ):
        """Simulate an immediate, all-or-nothing fill of an order.

        A BUY fills when cash covers ``quantity * price``. A SELL fills when the
        held quantity covers ``quantity``. Otherwise the order is recorded as
        ``REJECTED``; this also applies to a non-positive quantity or price and
        to a side other than BUY/SELL. No slippage or fees are modelled.

        Args:
            symbol: Instrument identifier.
            side: ``"BUY"`` or ``"SELL"`` (case- and whitespace-insensitive).
            quantity: Units to trade.
            price: Fill price; when ``None`` it is fetched via ``fetch_live_price``.
            cur: Cursor of an enclosing transaction; when ``None`` the order
                runs in its own transaction via ``run_with_retry``. NOTE: this
                sits before ``logical_time`` positionally, unlike
                ``Broker.place_order``, so pass ``logical_time`` by keyword.
            logical_time: Engine time of the order (defaults to current UTC).
            user_id, run_id: Scope, resolved through ``_resolve_scope``.

        Returns:
            The order dict: ``id``, ``symbol``, ``side``, ``qty``, ``price``,
            ``status`` (``"FILLED"``/``"REJECTED"``), ``timestamp`` and
            ``_logical_time``. The id is derived from the scope, the logical
            time (to the second), symbol, side, qty and price. If an order
            with that id already exists, it is returned unchanged and nothing
            is re-applied.
        """
        if cur is not None:
            return self._place_order_in_tx(
                cur,
                symbol,
                side,
                quantity,
                price,
                logical_time=logical_time,
                user_id=user_id,
                run_id=run_id,
            )

        def _op(cur, _conn):
            return self._place_order_in_tx(
                cur,
                symbol,
                side,
                quantity,
                price,
                logical_time=logical_time,
                user_id=user_id,
                run_id=run_id,
            )

        return run_with_retry(_op)