from __future__ import annotations from abc import ABC, abstractmethod from dataclasses import dataclass from datetime import datetime, timezone import hashlib from psycopg2.extras import execute_values from indian_paper_trading_strategy.engine.data import fetch_live_price from indian_paper_trading_strategy.engine.db import db_connection, insert_engine_event, run_with_retry, get_context class Broker(ABC): @abstractmethod def place_order( self, symbol: str, side: str, quantity: float, price: float | None = None, logical_time: datetime | None = None, ): raise NotImplementedError @abstractmethod def get_positions(self): raise NotImplementedError @abstractmethod def get_orders(self): raise NotImplementedError @abstractmethod def get_funds(self): raise NotImplementedError def _local_tz(): return datetime.now().astimezone().tzinfo def _format_utc_ts(value: datetime | None): if value is None: return None if value.tzinfo is None: value = value.replace(tzinfo=_local_tz()) return value.astimezone(timezone.utc).isoformat().replace("+00:00", "Z") def _format_local_ts(value: datetime | None): if value is None: return None if value.tzinfo is None: value = value.replace(tzinfo=_local_tz()) return value.astimezone(_local_tz()).replace(tzinfo=None).isoformat() def _parse_ts(value, assume_local: bool = True): if value is None: return None if isinstance(value, datetime): if value.tzinfo is None: return value.replace(tzinfo=_local_tz() if assume_local else timezone.utc) return value if isinstance(value, str): text = value.strip() if not text: return None if text.endswith("Z"): try: return datetime.fromisoformat(text.replace("Z", "+00:00")) except ValueError: return None try: parsed = datetime.fromisoformat(text) except ValueError: return None if parsed.tzinfo is None: parsed = parsed.replace(tzinfo=_local_tz() if assume_local else timezone.utc) return parsed return None def _stable_num(value: float) -> str: return f"{float(value):.12f}" def _normalize_ts_for_id(ts: datetime) -> str: if ts.tzinfo is None: ts = ts.replace(tzinfo=timezone.utc) return ts.astimezone(timezone.utc).replace(microsecond=0).isoformat() def _deterministic_id(prefix: str, parts: list[str]) -> str: payload = "|".join(parts) digest = hashlib.sha1(payload.encode("utf-8")).hexdigest()[:16] return f"{prefix}_{digest}" def _resolve_scope(user_id: str | None, run_id: str | None): return get_context(user_id, run_id) @dataclass class PaperBroker(Broker): initial_cash: float store_path: str | None = None def _default_store(self): return { "cash": float(self.initial_cash), "positions": {}, "orders": [], "trades": [], "equity_curve": [], } def _load_store(self, cur=None, for_update: bool = False, user_id: str | None = None, run_id: str | None = None): scope_user, scope_run = _resolve_scope(user_id, run_id) if cur is None: with db_connection() as conn: with conn.cursor() as cur: return self._load_store( cur=cur, for_update=for_update, user_id=scope_user, run_id=scope_run, ) store = self._default_store() lock_clause = " FOR UPDATE" if for_update else "" cur.execute( f"SELECT cash FROM paper_broker_account WHERE user_id = %s AND run_id = %s{lock_clause} LIMIT 1", (scope_user, scope_run), ) row = cur.fetchone() if row and row[0] is not None: store["cash"] = float(row[0]) cur.execute( f""" SELECT symbol, qty, avg_price, last_price FROM paper_position WHERE user_id = %s AND run_id = %s{lock_clause} """ , (scope_user, scope_run), ) positions = {} for symbol, qty, avg_price, last_price in cur.fetchall(): positions[symbol] = { "qty": float(qty) if qty is not None else 0.0, "avg_price": float(avg_price) if avg_price is not None else 0.0, "last_price": float(last_price) if last_price is not None else 0.0, } store["positions"] = positions cur.execute( """ SELECT id, symbol, side, qty, price, status, timestamp, logical_time FROM paper_order WHERE user_id = %s AND run_id = %s ORDER BY timestamp, id """ , (scope_user, scope_run), ) orders = [] for order_id, symbol, side, qty, price, status, ts, logical_ts in cur.fetchall(): orders.append( { "id": order_id, "symbol": symbol, "side": side, "qty": float(qty) if qty is not None else 0.0, "price": float(price) if price is not None else 0.0, "status": status, "timestamp": _format_utc_ts(ts), "_logical_time": _format_utc_ts(logical_ts), } ) store["orders"] = orders cur.execute( """ SELECT id, order_id, symbol, side, qty, price, timestamp, logical_time FROM paper_trade WHERE user_id = %s AND run_id = %s ORDER BY timestamp, id """ , (scope_user, scope_run), ) trades = [] for trade_id, order_id, symbol, side, qty, price, ts, logical_ts in cur.fetchall(): trades.append( { "id": trade_id, "order_id": order_id, "symbol": symbol, "side": side, "qty": float(qty) if qty is not None else 0.0, "price": float(price) if price is not None else 0.0, "timestamp": _format_utc_ts(ts), "_logical_time": _format_utc_ts(logical_ts), } ) store["trades"] = trades cur.execute( """ SELECT timestamp, logical_time, equity, pnl FROM paper_equity_curve WHERE user_id = %s AND run_id = %s ORDER BY timestamp """ , (scope_user, scope_run), ) equity_curve = [] for ts, logical_ts, equity, pnl in cur.fetchall(): equity_curve.append( { "timestamp": _format_local_ts(ts), "_logical_time": _format_local_ts(logical_ts), "equity": float(equity) if equity is not None else 0.0, "pnl": float(pnl) if pnl is not None else 0.0, } ) store["equity_curve"] = equity_curve return store def _save_store(self, store, cur=None, user_id: str | None = None, run_id: str | None = None): scope_user, scope_run = _resolve_scope(user_id, run_id) if cur is None: def _persist(cur, _conn): self._save_store(store, cur=cur, user_id=scope_user, run_id=scope_run) return run_with_retry(_persist) cash = store.get("cash") if cash is not None: cur.execute( """ INSERT INTO paper_broker_account (user_id, run_id, cash) VALUES (%s, %s, %s) ON CONFLICT (user_id, run_id) DO UPDATE SET cash = EXCLUDED.cash """, (scope_user, scope_run, float(cash)), ) positions = store.get("positions") if isinstance(positions, dict): symbols = [s for s in positions.keys() if s] if symbols: cur.execute( "DELETE FROM paper_position WHERE user_id = %s AND run_id = %s AND symbol NOT IN %s", (scope_user, scope_run, tuple(symbols)), ) else: cur.execute( "DELETE FROM paper_position WHERE user_id = %s AND run_id = %s", (scope_user, scope_run), ) if symbols: rows = [] updated_at = datetime.now(timezone.utc) for symbol, data in positions.items(): if not symbol or not isinstance(data, dict): continue rows.append( ( scope_user, scope_run, symbol, float(data.get("qty", 0.0)), float(data.get("avg_price", 0.0)), float(data.get("last_price", 0.0)), updated_at, ) ) if rows: execute_values( cur, """ INSERT INTO paper_position ( user_id, run_id, symbol, qty, avg_price, last_price, updated_at ) VALUES %s ON CONFLICT (user_id, run_id, symbol) DO UPDATE SET qty = EXCLUDED.qty, avg_price = EXCLUDED.avg_price, last_price = EXCLUDED.last_price, updated_at = EXCLUDED.updated_at """, rows, ) orders = store.get("orders") if isinstance(orders, list) and orders: rows = [] for order in orders: if not isinstance(order, dict): continue order_id = order.get("id") if not order_id: continue ts = _parse_ts(order.get("timestamp"), assume_local=False) logical_ts = _parse_ts(order.get("_logical_time"), assume_local=False) or ts rows.append( ( scope_user, scope_run, order_id, order.get("symbol"), order.get("side"), float(order.get("qty", 0.0)), float(order.get("price", 0.0)), order.get("status"), ts, logical_ts, ) ) if rows: execute_values( cur, """ INSERT INTO paper_order ( user_id, run_id, id, symbol, side, qty, price, status, timestamp, logical_time ) VALUES %s ON CONFLICT DO NOTHING """, rows, ) trades = store.get("trades") if isinstance(trades, list) and trades: rows = [] for trade in trades: if not isinstance(trade, dict): continue trade_id = trade.get("id") if not trade_id: continue ts = _parse_ts(trade.get("timestamp"), assume_local=False) logical_ts = _parse_ts(trade.get("_logical_time"), assume_local=False) or ts rows.append( ( scope_user, scope_run, trade_id, trade.get("order_id"), trade.get("symbol"), trade.get("side"), float(trade.get("qty", 0.0)), float(trade.get("price", 0.0)), ts, logical_ts, ) ) if rows: execute_values( cur, """ INSERT INTO paper_trade ( user_id, run_id, id, order_id, symbol, side, qty, price, timestamp, logical_time ) VALUES %s ON CONFLICT DO NOTHING """, rows, ) equity_curve = store.get("equity_curve") if isinstance(equity_curve, list) and equity_curve: rows = [] for point in equity_curve: if not isinstance(point, dict): continue ts = _parse_ts(point.get("timestamp"), assume_local=True) logical_ts = _parse_ts(point.get("_logical_time"), assume_local=True) or ts if ts is None: continue rows.append( ( scope_user, scope_run, ts, logical_ts, float(point.get("equity", 0.0)), float(point.get("pnl", 0.0)), ) ) if rows: execute_values( cur, """ INSERT INTO paper_equity_curve (user_id, run_id, timestamp, logical_time, equity, pnl) VALUES %s ON CONFLICT DO NOTHING """, rows, ) def get_funds(self, cur=None): store = self._load_store(cur=cur) cash = float(store.get("cash", 0)) positions = store.get("positions", {}) positions_value = 0.0 for position in positions.values(): qty = float(position.get("qty", 0)) last_price = float(position.get("last_price", position.get("avg_price", 0))) positions_value += qty * last_price total_equity = cash + positions_value return { "cash_available": cash, "invested_value": positions_value, "cash": cash, "used_margin": 0.0, "available": cash, "net": total_equity, "total_equity": total_equity, } def get_positions(self, cur=None): store = self._load_store(cur=cur) positions = store.get("positions", {}) return [ { "symbol": symbol, "qty": float(data.get("qty", 0)), "avg_price": float(data.get("avg_price", 0)), "last_price": float(data.get("last_price", data.get("avg_price", 0))), } for symbol, data in positions.items() ] def get_orders(self, cur=None): store = self._load_store(cur=cur) orders = [] for order in store.get("orders", []): if isinstance(order, dict): order = {k: v for k, v in order.items() if k != "_logical_time"} orders.append(order) return orders def get_trades(self, cur=None): store = self._load_store(cur=cur) trades = [] for trade in store.get("trades", []): if isinstance(trade, dict): trade = {k: v for k, v in trade.items() if k != "_logical_time"} trades.append(trade) return trades def get_equity_curve(self, cur=None): store = self._load_store(cur=cur) points = [] for point in store.get("equity_curve", []): if isinstance(point, dict): point = {k: v for k, v in point.items() if k != "_logical_time"} points.append(point) return points def _update_equity_in_tx( self, cur, prices: dict[str, float], now: datetime, logical_time: datetime | None = None, user_id: str | None = None, run_id: str | None = None, ): store = self._load_store(cur=cur, for_update=True, user_id=user_id, run_id=run_id) positions = store.get("positions", {}) for symbol, price in prices.items(): if symbol in positions: positions[symbol]["last_price"] = float(price) cash = float(store.get("cash", 0)) positions_value = 0.0 for symbol, position in positions.items(): qty = float(position.get("qty", 0)) price = float(position.get("last_price", position.get("avg_price", 0))) positions_value += qty * price equity = cash + positions_value pnl = equity - float(self.initial_cash) ts_for_equity = logical_time or now store.setdefault("equity_curve", []).append( { "timestamp": _format_local_ts(ts_for_equity), "_logical_time": _format_local_ts(ts_for_equity), "equity": equity, "pnl": pnl, } ) store["positions"] = positions self._save_store(store, cur=cur, user_id=user_id, run_id=run_id) insert_engine_event( cur, "EQUITY_UPDATED", data={ "timestamp": _format_utc_ts(ts_for_equity), "equity": equity, "pnl": pnl, }, ) return equity def update_equity( self, prices: dict[str, float], now: datetime, cur=None, logical_time: datetime | None = None, user_id: str | None = None, run_id: str | None = None, ): if cur is not None: return self._update_equity_in_tx( cur, prices, now, logical_time=logical_time, user_id=user_id, run_id=run_id, ) def _op(cur, _conn): return self._update_equity_in_tx( cur, prices, now, logical_time=logical_time, user_id=user_id, run_id=run_id, ) return run_with_retry(_op) def _place_order_in_tx( self, cur, symbol: str, side: str, quantity: float, price: float | None, logical_time: datetime | None = None, user_id: str | None = None, run_id: str | None = None, ): scope_user, scope_run = _resolve_scope(user_id, run_id) store = self._load_store(cur=cur, for_update=True, user_id=scope_user, run_id=scope_run) side = side.upper().strip() qty = float(quantity) if price is None: price = fetch_live_price(symbol) price = float(price) logical_ts = logical_time or datetime.utcnow().replace(tzinfo=timezone.utc) timestamp = logical_ts timestamp_str = _format_utc_ts(timestamp) logical_ts_str = _format_utc_ts(logical_ts) order_id = _deterministic_id( "ord", [ scope_user, scope_run, _normalize_ts_for_id(logical_ts), symbol, side, _stable_num(qty), _stable_num(price), ], ) order = { "id": order_id, "symbol": symbol, "side": side, "qty": qty, "price": price, "status": "REJECTED", "timestamp": timestamp_str, "_logical_time": logical_ts_str, } if qty <= 0 or price <= 0: store.setdefault("orders", []).append(order) self._save_store(store, cur=cur, user_id=user_id, run_id=run_id) insert_engine_event(cur, "ORDER_PLACED", data=order) return order positions = store.get("positions", {}) cash = float(store.get("cash", 0)) trade = None if side == "BUY": cost = qty * price if cash >= cost: cash -= cost existing = positions.get(symbol, {"qty": 0.0, "avg_price": 0.0, "last_price": price}) new_qty = float(existing.get("qty", 0)) + qty prev_cost = float(existing.get("qty", 0)) * float(existing.get("avg_price", 0)) avg_price = (prev_cost + cost) / new_qty if new_qty else price positions[symbol] = { "qty": new_qty, "avg_price": avg_price, "last_price": price, } order["status"] = "FILLED" trade = { "id": _deterministic_id("trd", [order_id]), "order_id": order_id, "symbol": symbol, "side": side, "qty": qty, "price": price, "timestamp": timestamp_str, "_logical_time": logical_ts_str, } store.setdefault("trades", []).append(trade) elif side == "SELL": existing = positions.get(symbol) if existing and float(existing.get("qty", 0)) >= qty: cash += qty * price remaining = float(existing.get("qty", 0)) - qty if remaining > 0: existing["qty"] = remaining existing["last_price"] = price positions[symbol] = existing else: positions.pop(symbol, None) order["status"] = "FILLED" trade = { "id": _deterministic_id("trd", [order_id]), "order_id": order_id, "symbol": symbol, "side": side, "qty": qty, "price": price, "timestamp": timestamp_str, "_logical_time": logical_ts_str, } store.setdefault("trades", []).append(trade) store["cash"] = cash store["positions"] = positions store.setdefault("orders", []).append(order) self._save_store(store, cur=cur, user_id=user_id, run_id=run_id) insert_engine_event(cur, "ORDER_PLACED", data=order) if trade is not None: insert_engine_event(cur, "TRADE_EXECUTED", data=trade) insert_engine_event(cur, "ORDER_FILLED", data={"order_id": order_id}) return order def place_order( self, symbol: str, side: str, quantity: float, price: float | None = None, cur=None, logical_time: datetime | None = None, user_id: str | None = None, run_id: str | None = None, ): if cur is not None: return self._place_order_in_tx( cur, symbol, side, quantity, price, logical_time=logical_time, user_id=user_id, run_id=run_id, ) def _op(cur, _conn): return self._place_order_in_tx( cur, symbol, side, quantity, price, logical_time=logical_time, user_id=user_id, run_id=run_id, ) return run_with_retry(_op)