"""Broker implementations for the trading engine.

Defines the abstract :class:`Broker` interface, a live Zerodha (Kite) broker
that places CNC market orders, and a paper broker whose account state is
persisted in PostgreSQL tables scoped by ``(user_id, run_id)``.
"""

from __future__ import annotations

from abc import ABC, abstractmethod
from dataclasses import dataclass
from datetime import datetime, timezone
import hashlib
import math
import os
import time

from psycopg2.extras import execute_values

from indian_paper_trading_strategy.engine.data import fetch_live_price
from indian_paper_trading_strategy.engine.db import db_connection, insert_engine_event, run_with_retry, get_context


class Broker(ABC):
    """Abstract interface shared by the live and paper brokers."""

    # Overridden to True by LiveZerodhaBroker, whose orders go to Zerodha
    # rather than being simulated locally.
    external_orders = False

    @abstractmethod
    def place_order(
        self,
        symbol: str,
        side: str,
        quantity: float,
        price: float | None = None,
        logical_time: datetime | None = None,
    ):
        """Submit an order; implementations return a dict describing the outcome."""
        raise NotImplementedError

    @abstractmethod
    def get_positions(self):
        """Return the current positions."""
        raise NotImplementedError

    @abstractmethod
    def get_orders(self):
        """Return the orders known to the broker."""
        raise NotImplementedError

    @abstractmethod
    def get_funds(self):
        """Return a dict describing available funds."""
        raise NotImplementedError


class BrokerError(Exception):
    """Base error raised by broker implementations."""


class BrokerAuthExpired(BrokerError):
    """Raised when the broker session is missing or its token was rejected."""


def _utc_now() -> datetime:
    """Return the current time as a timezone-aware UTC datetime."""
    return datetime.now(timezone.utc)


def _local_tz():
    """Return the tzinfo of the host's current local time."""
    return datetime.now().astimezone().tzinfo


def _format_utc_ts(value: datetime | None):
    """Format *value* as an ISO-8601 UTC string with a trailing ``Z``.

    Naive datetimes are interpreted as host-local time. ``None`` yields ``None``.
    """
    if value is None:
        return None
    if value.tzinfo is None:
        value = value.replace(tzinfo=_local_tz())
    return value.astimezone(timezone.utc).isoformat().replace("+00:00", "Z")


def _format_local_ts(value: datetime | None):
    """Format *value* as a naive ISO-8601 string in host-local time.

    Naive datetimes are interpreted as host-local time. ``None`` yields ``None``.
    """
    if value is None:
        return None
    if value.tzinfo is None:
        value = value.replace(tzinfo=_local_tz())
    return value.astimezone(_local_tz()).replace(tzinfo=None).isoformat()


def _parse_ts(value, assume_local: bool = True):
    """Parse a datetime or ISO-8601 string into an aware datetime.

    Naive inputs get the host-local zone when *assume_local* is true, else UTC.
    Returns ``None`` for ``None``, blank or unparseable strings, and other types.
    """
    if value is None:
        return None
    if isinstance(value, datetime):
        if value.tzinfo is None:
            return value.replace(tzinfo=_local_tz() if assume_local else timezone.utc)
        return value
    if isinstance(value, str):
        text = value.strip()
        if not text:
            return None
        if text.endswith("Z"):
            try:
                return datetime.fromisoformat(text.replace("Z", "+00:00"))
            except ValueError:
                return None
        try:
            parsed = datetime.fromisoformat(text)
        except ValueError:
            return None
        if parsed.tzinfo is None:
            parsed = parsed.replace(tzinfo=_local_tz() if assume_local else timezone.utc)
        return parsed
    return None


def _stable_num(value: float) -> str:
    """Render a number with fixed precision so it hashes identically across calls."""
    return f"{float(value):.12f}"


def _normalize_ts_for_id(ts: datetime) -> str:
    """Render *ts* as a second-precision UTC ISO string for use in IDs.

    Naive values are treated as UTC here (unlike the display formatters), and
    this must stay that way or previously generated IDs would change.
    """
    if ts.tzinfo is None:
        ts = ts.replace(tzinfo=timezone.utc)
    return ts.astimezone(timezone.utc).replace(microsecond=0).isoformat()


def _deterministic_id(prefix: str, parts: list[str]) -> str:
    """Build ``<prefix>_<16 hex chars>`` from a SHA-1 of the ``|``-joined parts."""
    payload = "|".join(parts)
    digest = hashlib.sha1(payload.encode("utf-8")).hexdigest()[:16]
    return f"{prefix}_{digest}"


def _resolve_scope(user_id: str | None, run_id: str | None):
    """Resolve a ``(user_id, run_id)`` pair through the engine's ``get_context``."""
    return get_context(user_id, run_id)


def _float_or_zero(value) -> float:
    """Convert a nullable DB numeric to float, mapping ``None`` to 0.0."""
    return float(value) if value is not None else 0.0


def _mark_price(position: dict) -> float:
    """Return the price used to value a paper position.

    Uses ``last_price`` when it is set and non-zero, otherwise ``avg_price``.
    ``PaperBroker._load_store`` always populates ``last_price`` (0.0 for NULL
    columns), so a ``dict.get`` default alone would never fall back.
    """
    return float(position.get("last_price") or position.get("avg_price") or 0.0)


class LiveZerodhaBroker(Broker):
    """Broker that routes orders to Zerodha (Kite) using the stored user session."""

    external_orders = True
    # Raw Kite order statuses after which polling stops.
    TERMINAL_STATUSES = {"COMPLETE", "REJECTED", "CANCELLED"}
    # Polling budget (seconds) for an order to reach a terminal status; env-overridable.
    POLL_TIMEOUT_SECONDS = float(os.getenv("ZERODHA_ORDER_POLL_TIMEOUT", "12"))
    POLL_INTERVAL_SECONDS = float(os.getenv("ZERODHA_ORDER_POLL_INTERVAL", "1"))

    def __init__(self, user_id: str | None = None, run_id: str | None = None):
        self.user_id = user_id
        self.run_id = run_id

    def _scope(self):
        """Resolve this broker's ``(user_id, run_id)`` via the engine context."""
        return _resolve_scope(self.user_id, self.run_id)

    def _session(self):
        """Return the stored Zerodha session for the current user.

        Raises:
            BrokerAuthExpired: if no session, API key, or access token is stored.
        """
        from app.services.zerodha_storage import get_session

        user_id, _run_id = self._scope()
        session = get_session(user_id)
        if not session or not session.get("api_key") or not session.get("access_token"):
            raise BrokerAuthExpired("Zerodha session missing. Please reconnect broker.")
        return session

    def _raise_auth_expired(self, exc: Exception):
        """Mark the user's broker auth state EXPIRED, then raise BrokerAuthExpired from *exc*."""
        from app.broker_store import set_broker_auth_state

        user_id, _run_id = self._scope()
        set_broker_auth_state(user_id, "EXPIRED")
        raise BrokerAuthExpired(str(exc)) from exc

    def _normalize_symbol(self, symbol: str) -> tuple[str, str]:
        """Split an engine symbol into ``(tradingsymbol, exchange)``.

        ``.NS`` maps to NSE and ``.BO`` to BSE. Unsuffixed symbols default to NSE.
        """
        cleaned = (symbol or "").strip().upper()
        if cleaned.endswith(".NS"):
            return cleaned[:-3], "NSE"
        if cleaned.endswith(".BO"):
            return cleaned[:-3], "BSE"
        return cleaned, "NSE"

    def _make_tag(self, logical_time: datetime | None, symbol: str, side: str) -> str:
        """Build a 20-character order tag (``qf`` + 18 hex chars) from scope, time, symbol and side."""
        user_id, run_id = self._scope()
        logical_ts = logical_time or _utc_now()
        digest = hashlib.sha1(
            f"{user_id}|{run_id}|{_normalize_ts_for_id(logical_ts)}|{symbol}|{side}".encode("utf-8")
        ).hexdigest()[:18]
        return f"qf{digest}"

    def _normalize_order_payload(
        self,
        *,
        order_id: str,
        symbol: str,
        side: str,
        requested_qty: int,
        requested_price: float | None,
        history_entry: dict | None,
        logical_time: datetime | None,
    ) -> dict:
        """Convert a Kite order-history entry into the engine's order dict.

        Kite ``COMPLETE`` maps to ``FILLED``. ``REJECTED`` and ``CANCELLED`` pass
        through. Anything else, including no entry at all, becomes ``PENDING``.
        """
        entry = history_entry or {}
        raw_status = (entry.get("status") or "").upper()
        if raw_status == "COMPLETE":
            status = "FILLED"
        elif raw_status in {"REJECTED", "CANCELLED"}:
            status = raw_status
        else:
            status = "PENDING"
        quantity = int(entry.get("quantity") or requested_qty or 0)
        filled_qty = int(entry.get("filled_quantity") or 0)
        average_price = float(entry.get("average_price") or requested_price or 0.0)
        price = float(entry.get("price") or requested_price or average_price or 0.0)
        timestamp = (
            entry.get("exchange_timestamp")
            or entry.get("order_timestamp")
            or _format_utc_ts(logical_time or _utc_now())
        )
        # "YYYY-MM-DD HH:MM:SS" -> ISO-style "YYYY-MM-DDTHH:MM:SS".
        if timestamp and " " in str(timestamp):
            timestamp = str(timestamp).replace(" ", "T")
        return {
            "id": order_id,
            "symbol": symbol,
            "side": side.upper().strip(),
            "qty": quantity,
            # Report what was asked for, not what the broker echoed back.
            "requested_qty": int(requested_qty or quantity),
            "filled_qty": filled_qty,
            "price": price,
            "requested_price": float(requested_price or price or 0.0),
            "average_price": average_price,
            "status": status,
            "timestamp": timestamp,
            "broker_order_id": order_id,
            "exchange": entry.get("exchange"),
            "tradingsymbol": entry.get("tradingsymbol"),
            "status_message": entry.get("status_message") or entry.get("status_message_raw"),
        }

    def _wait_for_terminal_order(
        self,
        session: dict,
        order_id: str,
        *,
        symbol: str,
        side: str,
        requested_qty: int,
        requested_price: float | None,
        logical_time: datetime | None,
    ) -> dict:
        """Poll order history until a terminal status, cancelling on timeout.

        Returns the normalized payload of the latest known history entry.

        Raises:
            BrokerAuthExpired: if the Kite token is rejected while polling.
        """
        from app.services.zerodha_service import (
            KiteTokenError,
            cancel_order,
            fetch_order_history,
        )

        def _payload(entry: dict | None) -> dict:
            return self._normalize_order_payload(
                order_id=order_id,
                symbol=symbol,
                side=side,
                requested_qty=requested_qty,
                requested_price=requested_price,
                history_entry=entry,
                logical_time=logical_time,
            )

        started = time.monotonic()
        last_payload = _payload(None)
        while True:
            try:
                history = fetch_order_history(session["api_key"], session["access_token"], order_id)
            except KiteTokenError as exc:
                self._raise_auth_expired(exc)
            if history:
                entry = history[-1]
                last_payload = _payload(entry)
                if (entry.get("status") or "").upper() in self.TERMINAL_STATUSES:
                    return last_payload

            if time.monotonic() - started >= self.POLL_TIMEOUT_SECONDS:
                try:
                    cancel_order(session["api_key"], session["access_token"], order_id=order_id)
                except KiteTokenError as exc:
                    self._raise_auth_expired(exc)
                except Exception:
                    # Best-effort cancel: it may fail e.g. if the order reached a
                    # terminal state meanwhile. The order exists at the broker, so
                    # report its actual status below rather than raising here.
                    pass
                try:
                    history = fetch_order_history(session["api_key"], session["access_token"], order_id)
                except KiteTokenError as exc:
                    self._raise_auth_expired(exc)
                if history:
                    return _payload(history[-1])
                return last_payload

            time.sleep(self.POLL_INTERVAL_SECONDS)

    def get_funds(self, cur=None):
        """Return ``{"cash": float, "raw": <broker response>}`` for the equity segment.

        *cur* is accepted for interface parity with PaperBroker and is unused.
        """
        from app.services.zerodha_service import KiteTokenError, fetch_funds

        session = self._session()
        try:
            data = fetch_funds(session["api_key"], session["access_token"])
        except KiteTokenError as exc:
            self._raise_auth_expired(exc)
        equity = data.get("equity", {}) if isinstance(data, dict) else {}
        available = equity.get("available", {}) if isinstance(equity, dict) else {}
        cash = (
            available.get("live_balance")
            or available.get("cash")
            or available.get("opening_balance")
            or equity.get("net")
            or 0.0
        )
        return {"cash": float(cash), "raw": data}

    def get_positions(self):
        """Return holdings as dicts with engine-style symbols (``.NS`` / ``.BO`` suffix)."""
        from app.services.zerodha_service import KiteTokenError, fetch_holdings

        session = self._session()
        try:
            holdings = fetch_holdings(session["api_key"], session["access_token"])
        except KiteTokenError as exc:
            self._raise_auth_expired(exc)
        exchange_suffix = {"NSE": ".NS", "BSE": ".BO"}
        normalized = []
        for item in holdings or []:
            avg_price = float(item.get("average_price") or 0)
            normalized.append(
                {
                    "symbol": f"{item.get('tradingsymbol')}{exchange_suffix.get(item.get('exchange'), '')}",
                    "qty": float(item.get("quantity") or item.get("qty") or 0),
                    "avg_price": avg_price,
                    "last_price": float(item.get("last_price") or avg_price or 0),
                }
            )
        return normalized

    def get_orders(self):
        """Return the broker's order list as provided by the Zerodha service."""
        from app.services.zerodha_service import KiteTokenError, fetch_orders

        session = self._session()
        try:
            return fetch_orders(session["api_key"], session["access_token"])
        except KiteTokenError as exc:
            self._raise_auth_expired(exc)

    def update_equity(
        self,
        prices: dict[str, float],
        now: datetime,
        cur=None,
        logical_time: datetime | None = None,
        user_id: str | None = None,
        run_id: str | None = None,
    ):
        """No-op for the live broker; always returns ``None``."""
        return None

    def place_order(
        self,
        symbol: str,
        side: str,
        quantity: float,
        price: float | None = None,
        cur=None,
        logical_time: datetime | None = None,
        user_id: str | None = None,
        run_id: str | None = None,
    ):
        """Place a CNC market order and wait for it to reach a terminal state.

        *quantity* is floored to whole shares. A result below one share returns
        a local ``REJECTED`` payload without contacting the broker. *price* is
        only recorded as the requested price; the order itself is MARKET.

        Raises:
            BrokerAuthExpired: if the session is missing or the token is rejected.
            BrokerError: if the broker does not return an order id.
        """
        from app.services.zerodha_service import KiteTokenError, place_order

        if user_id is not None:
            self.user_id = user_id
        if run_id is not None:
            self.run_id = run_id
        qty = int(math.floor(float(quantity)))
        side = side.upper().strip()
        requested_price = float(price) if price is not None else None
        if qty <= 0:
            return {
                "id": _deterministic_id("live_rej", [symbol, side, _stable_num(quantity)]),
                "symbol": symbol,
                "side": side,
                "qty": qty,
                "requested_qty": qty,
                "filled_qty": 0,
                "price": float(price or 0.0),
                "requested_price": float(price or 0.0),
                "average_price": 0.0,
                "status": "REJECTED",
                "timestamp": _format_utc_ts(logical_time or _utc_now()),
                "status_message": "Computed quantity is less than 1 share",
            }

        session = self._session()
        tradingsymbol, exchange = self._normalize_symbol(symbol)
        tag = self._make_tag(logical_time, symbol, side)
        try:
            placed = place_order(
                session["api_key"],
                session["access_token"],
                tradingsymbol=tradingsymbol,
                exchange=exchange,
                transaction_type=side,
                order_type="MARKET",
                quantity=qty,
                product="CNC",
                validity="DAY",
                variety="regular",
                market_protection=-1,
                tag=tag,
            )
        except KiteTokenError as exc:
            self._raise_auth_expired(exc)

        # Accept either a response dict or a bare order-id string; anything
        # else (e.g. None) is treated as "no order id".
        if isinstance(placed, dict):
            order_id = placed.get("order_id")
        elif isinstance(placed, str):
            order_id = placed
        else:
            order_id = None
        if not order_id:
            raise BrokerError("Zerodha order placement did not return an order_id")
        return self._wait_for_terminal_order(
            session,
            order_id,
            symbol=symbol,
            side=side,
            requested_qty=qty,
            requested_price=requested_price,
            logical_time=logical_time,
        )


@dataclass
class PaperBroker(Broker):
    """Simulated broker whose state lives in the ``paper_*`` PostgreSQL tables.

    All reads and writes are scoped by the ``(user_id, run_id)`` pair resolved
    through ``get_context``.
    """

    initial_cash: float
    store_path: str | None = None

    def _default_store(self):
        """Return an empty account state seeded with ``initial_cash``."""
        return {
            "cash": float(self.initial_cash),
            "positions": {},
            "orders": [],
            "trades": [],
            "equity_curve": [],
        }

    def _load_store(
        self,
        cur=None,
        for_update: bool = False,
        user_id: str | None = None,
        run_id: str | None = None,
    ):
        """Load the full paper-account state for a scope into a dict.

        When *cur* is None a dedicated connection is opened for the read. With
        ``for_update=True`` the account and position rows are selected
        ``FOR UPDATE`` so they stay locked for the caller's transaction.
        Order/trade timestamps are formatted as UTC strings; equity-curve
        timestamps as naive local strings (``_save_store`` parses them back
        with the matching assumption).
        """
        scope_user, scope_run = _resolve_scope(user_id, run_id)
        if cur is None:
            with db_connection() as conn:
                with conn.cursor() as own_cur:
                    return self._load_store(
                        cur=own_cur,
                        for_update=for_update,
                        user_id=scope_user,
                        run_id=scope_run,
                    )

        store = self._default_store()
        lock_clause = " FOR UPDATE" if for_update else ""
        scope = (scope_user, scope_run)

        cur.execute(
            f"SELECT cash FROM paper_broker_account WHERE user_id = %s AND run_id = %s LIMIT 1{lock_clause}",
            scope,
        )
        row = cur.fetchone()
        if row and row[0] is not None:
            store["cash"] = float(row[0])

        cur.execute(
            f"""
            SELECT symbol, qty, avg_price, last_price
            FROM paper_position
            WHERE user_id = %s AND run_id = %s{lock_clause}
            """,
            scope,
        )
        store["positions"] = {
            symbol: {
                "qty": _float_or_zero(qty),
                "avg_price": _float_or_zero(avg_price),
                "last_price": _float_or_zero(last_price),
            }
            for symbol, qty, avg_price, last_price in cur.fetchall()
        }

        cur.execute(
            """
            SELECT id, symbol, side, qty, price, status, timestamp, logical_time
            FROM paper_order
            WHERE user_id = %s AND run_id = %s
            ORDER BY timestamp, id
            """,
            scope,
        )
        store["orders"] = [
            {
                "id": order_id,
                "symbol": symbol,
                "side": side,
                "qty": _float_or_zero(qty),
                "price": _float_or_zero(price),
                "status": status,
                "timestamp": _format_utc_ts(ts),
                "_logical_time": _format_utc_ts(logical_ts),
            }
            for order_id, symbol, side, qty, price, status, ts, logical_ts in cur.fetchall()
        ]

        cur.execute(
            """
            SELECT id, order_id, symbol, side, qty, price, timestamp, logical_time
            FROM paper_trade
            WHERE user_id = %s AND run_id = %s
            ORDER BY timestamp, id
            """,
            scope,
        )
        store["trades"] = [
            {
                "id": trade_id,
                "order_id": order_id,
                "symbol": symbol,
                "side": side,
                "qty": _float_or_zero(qty),
                "price": _float_or_zero(price),
                "timestamp": _format_utc_ts(ts),
                "_logical_time": _format_utc_ts(logical_ts),
            }
            for trade_id, order_id, symbol, side, qty, price, ts, logical_ts in cur.fetchall()
        ]

        cur.execute(
            """
            SELECT timestamp, logical_time, equity, pnl
            FROM paper_equity_curve
            WHERE user_id = %s AND run_id = %s
            ORDER BY timestamp
            """,
            scope,
        )
        store["equity_curve"] = [
            {
                "timestamp": _format_local_ts(ts),
                "_logical_time": _format_local_ts(logical_ts),
                "equity": _float_or_zero(equity),
                "pnl": _float_or_zero(pnl),
            }
            for ts, logical_ts, equity, pnl in cur.fetchall()
        ]
        return store

    def _save_store(self, store, cur=None, user_id: str | None = None, run_id: str | None = None):
        """Persist an account-state dict for a scope.

        Cash is upserted. Positions are replaced: rows for symbols absent from
        the store are deleted and the rest upserted. Orders, trades and equity
        points are insert-only (``ON CONFLICT DO NOTHING``). When *cur* is None
        the write runs in its own transaction via ``run_with_retry``.
        """
        scope_user, scope_run = _resolve_scope(user_id, run_id)
        if cur is None:
            def _persist(cur, _conn):
                self._save_store(store, cur=cur, user_id=scope_user, run_id=scope_run)

            return run_with_retry(_persist)

        cash = store.get("cash")
        if cash is not None:
            cur.execute(
                """
                INSERT INTO paper_broker_account (user_id, run_id, cash)
                VALUES (%s, %s, %s)
                ON CONFLICT (user_id, run_id) DO UPDATE
                SET cash = EXCLUDED.cash
                """,
                (scope_user, scope_run, float(cash)),
            )

        positions = store.get("positions")
        if isinstance(positions, dict):
            symbols = [s for s in positions.keys() if s]
            if not symbols:
                cur.execute(
                    "DELETE FROM paper_position WHERE user_id = %s AND run_id = %s",
                    (scope_user, scope_run),
                )
            else:
                cur.execute(
                    "DELETE FROM paper_position WHERE user_id = %s AND run_id = %s AND symbol NOT IN %s",
                    (scope_user, scope_run, tuple(symbols)),
                )
                updated_at = datetime.now(timezone.utc)
                position_rows = [
                    (
                        scope_user,
                        scope_run,
                        symbol,
                        float(data.get("qty", 0.0)),
                        float(data.get("avg_price", 0.0)),
                        float(data.get("last_price", 0.0)),
                        updated_at,
                    )
                    for symbol, data in positions.items()
                    if symbol and isinstance(data, dict)
                ]
                if position_rows:
                    execute_values(
                        cur,
                        """
                        INSERT INTO paper_position (
                            user_id, run_id, symbol, qty, avg_price, last_price, updated_at
                        )
                        VALUES %s
                        ON CONFLICT (user_id, run_id, symbol) DO UPDATE
                        SET qty = EXCLUDED.qty,
                            avg_price = EXCLUDED.avg_price,
                            last_price = EXCLUDED.last_price,
                            updated_at = EXCLUDED.updated_at
                        """,
                        position_rows,
                    )

        orders = store.get("orders")
        if isinstance(orders, list) and orders:
            order_rows = []
            for order in orders:
                if not isinstance(order, dict) or not order.get("id"):
                    continue
                ts = _parse_ts(order.get("timestamp"), assume_local=False)
                logical_ts = _parse_ts(order.get("_logical_time"), assume_local=False) or ts
                order_rows.append(
                    (
                        scope_user,
                        scope_run,
                        order.get("id"),
                        order.get("symbol"),
                        order.get("side"),
                        float(order.get("qty", 0.0)),
                        float(order.get("price", 0.0)),
                        order.get("status"),
                        ts,
                        logical_ts,
                    )
                )
            if order_rows:
                execute_values(
                    cur,
                    """
                    INSERT INTO paper_order (
                        user_id, run_id, id, symbol, side, qty, price, status, timestamp, logical_time
                    )
                    VALUES %s
                    ON CONFLICT DO NOTHING
                    """,
                    order_rows,
                )

        trades = store.get("trades")
        if isinstance(trades, list) and trades:
            trade_rows = []
            for trade in trades:
                if not isinstance(trade, dict) or not trade.get("id"):
                    continue
                ts = _parse_ts(trade.get("timestamp"), assume_local=False)
                logical_ts = _parse_ts(trade.get("_logical_time"), assume_local=False) or ts
                trade_rows.append(
                    (
                        scope_user,
                        scope_run,
                        trade.get("id"),
                        trade.get("order_id"),
                        trade.get("symbol"),
                        trade.get("side"),
                        float(trade.get("qty", 0.0)),
                        float(trade.get("price", 0.0)),
                        ts,
                        logical_ts,
                    )
                )
            if trade_rows:
                execute_values(
                    cur,
                    """
                    INSERT INTO paper_trade (
                        user_id, run_id, id, order_id, symbol, side, qty, price, timestamp, logical_time
                    )
                    VALUES %s
                    ON CONFLICT DO NOTHING
                    """,
                    trade_rows,
                )

        equity_curve = store.get("equity_curve")
        if isinstance(equity_curve, list) and equity_curve:
            point_rows = []
            for point in equity_curve:
                if not isinstance(point, dict):
                    continue
                ts = _parse_ts(point.get("timestamp"), assume_local=True)
                if ts is None:
                    continue
                logical_ts = _parse_ts(point.get("_logical_time"), assume_local=True) or ts
                point_rows.append(
                    (
                        scope_user,
                        scope_run,
                        ts,
                        logical_ts,
                        float(point.get("equity", 0.0)),
                        float(point.get("pnl", 0.0)),
                    )
                )
            if point_rows:
                execute_values(
                    cur,
                    """
                    INSERT INTO paper_equity_curve (user_id, run_id, timestamp, logical_time, equity, pnl)
                    VALUES %s
                    ON CONFLICT DO NOTHING
                    """,
                    point_rows,
                )

    def get_funds(self, cur=None):
        """Return cash, marked position value and total equity for the current scope."""
        store = self._load_store(cur=cur)
        cash = float(store.get("cash", 0))
        positions_value = sum(
            (float(p.get("qty", 0)) * _mark_price(p) for p in store.get("positions", {}).values()),
            0.0,
        )
        total_equity = cash + positions_value
        return {
            "cash_available": cash,
            "invested_value": positions_value,
            "cash": cash,
            "used_margin": 0.0,
            "available": cash,
            "net": total_equity,
            "total_equity": total_equity,
        }

    def get_positions(self, cur=None):
        """Return positions as a list of dicts (``last_price`` falls back to ``avg_price``)."""
        store = self._load_store(cur=cur)
        return [
            {
                "symbol": symbol,
                "qty": float(data.get("qty", 0)),
                "avg_price": float(data.get("avg_price", 0)),
                "last_price": _mark_price(data),
            }
            for symbol, data in store.get("positions", {}).items()
        ]

    @staticmethod
    def _strip_internal(items):
        """Drop the internal ``_logical_time`` key from every dict item; other items pass through."""
        return [
            {k: v for k, v in item.items() if k != "_logical_time"} if isinstance(item, dict) else item
            for item in items
        ]

    def get_orders(self, cur=None):
        """Return recorded orders without internal keys."""
        return self._strip_internal(self._load_store(cur=cur).get("orders", []))

    def get_trades(self, cur=None):
        """Return recorded trades without internal keys."""
        return self._strip_internal(self._load_store(cur=cur).get("trades", []))

    def get_equity_curve(self, cur=None):
        """Return equity-curve points without internal keys."""
        return self._strip_internal(self._load_store(cur=cur).get("equity_curve", []))

    def _update_equity_in_tx(
        self,
        cur,
        prices: dict[str, float],
        now: datetime,
        logical_time: datetime | None = None,
        user_id: str | None = None,
        run_id: str | None = None,
    ):
        """Mark held positions to *prices*, append an equity point, and return equity.

        Runs inside the caller's transaction and emits an ``EQUITY_UPDATED`` event.
        The point is stamped with *logical_time* when given, else *now*.
        """
        store = self._load_store(cur=cur, for_update=True, user_id=user_id, run_id=run_id)
        positions = store.get("positions", {})
        for symbol, price in prices.items():
            if symbol in positions:
                positions[symbol]["last_price"] = float(price)
        cash = float(store.get("cash", 0))
        positions_value = sum(
            (float(p.get("qty", 0)) * _mark_price(p) for p in positions.values()),
            0.0,
        )
        equity = cash + positions_value
        pnl = equity - float(self.initial_cash)
        ts_for_equity = logical_time or now
        store.setdefault("equity_curve", []).append(
            {
                "timestamp": _format_local_ts(ts_for_equity),
                "_logical_time": _format_local_ts(ts_for_equity),
                "equity": equity,
                "pnl": pnl,
            }
        )
        store["positions"] = positions
        self._save_store(store, cur=cur, user_id=user_id, run_id=run_id)
        insert_engine_event(
            cur,
            "EQUITY_UPDATED",
            data={
                "timestamp": _format_utc_ts(ts_for_equity),
                "equity": equity,
                "pnl": pnl,
            },
        )
        return equity

    def update_equity(
        self,
        prices: dict[str, float],
        now: datetime,
        cur=None,
        logical_time: datetime | None = None,
        user_id: str | None = None,
        run_id: str | None = None,
    ):
        """Update equity in the given transaction, or in a new retried one when *cur* is None."""
        if cur is not None:
            return self._update_equity_in_tx(
                cur, prices, now, logical_time=logical_time, user_id=user_id, run_id=run_id
            )

        def _op(cur, _conn):
            return self._update_equity_in_tx(
                cur, prices, now, logical_time=logical_time, user_id=user_id, run_id=run_id
            )

        return run_with_retry(_op)

    @staticmethod
    def _replayed_order(stored: dict) -> dict:
        """Rebuild the full order payload for an order already recorded in the store."""
        filled = stored.get("status") == "FILLED"
        qty = float(stored.get("qty", 0.0))
        price = float(stored.get("price", 0.0))
        return {
            "id": stored.get("id"),
            "symbol": stored.get("symbol"),
            "side": stored.get("side"),
            "qty": qty,
            "requested_qty": qty,
            "filled_qty": qty if filled else 0.0,
            "price": price,
            "requested_price": price,
            "average_price": price if filled else 0.0,
            "status": stored.get("status"),
            "timestamp": stored.get("timestamp"),
            "_logical_time": stored.get("_logical_time"),
        }

    def _place_order_in_tx(
        self,
        cur,
        symbol: str,
        side: str,
        quantity: float,
        price: float | None,
        logical_time: datetime | None = None,
        user_id: str | None = None,
        run_id: str | None = None,
    ):
        """Simulate an order fill inside the caller's transaction and return the order dict.

        A BUY fills only if cash covers ``qty * price``. A SELL fills only if
        the held quantity covers *qty*. Everything else (including non-positive
        qty/price or an unknown side) is recorded as ``REJECTED``. A missing
        *price* is fetched with ``fetch_live_price``. The order ID is
        deterministic in (scope, logical time, symbol, side, qty, price). If
        that order is already recorded, it is returned unchanged and no cash,
        positions or events are touched again.
        """
        scope_user, scope_run = _resolve_scope(user_id, run_id)
        store = self._load_store(cur=cur, for_update=True, user_id=scope_user, run_id=scope_run)
        side = side.upper().strip()
        qty = float(quantity)
        if price is None:
            price = fetch_live_price(symbol)
        price = float(price)
        logical_ts = logical_time or _utc_now()
        timestamp_str = _format_utc_ts(logical_ts)
        logical_ts_str = _format_utc_ts(logical_ts)
        order_id = _deterministic_id(
            "ord",
            [
                scope_user,
                scope_run,
                _normalize_ts_for_id(logical_ts),
                symbol,
                side,
                _stable_num(qty),
                _stable_num(price),
            ],
        )

        # Idempotency: the order/trade rows for this ID would be skipped by
        # ON CONFLICT DO NOTHING, so re-applying the fill here would move cash
        # and positions without any matching order or trade record.
        for stored in store.get("orders", []):
            if isinstance(stored, dict) and stored.get("id") == order_id:
                return self._replayed_order(stored)

        order = {
            "id": order_id,
            "symbol": symbol,
            "side": side,
            "qty": qty,
            "requested_qty": qty,
            "filled_qty": 0.0,
            "price": price,
            "requested_price": price,
            "average_price": 0.0,
            "status": "REJECTED",
            "timestamp": timestamp_str,
            "_logical_time": logical_ts_str,
        }
        if qty <= 0 or price <= 0:
            store.setdefault("orders", []).append(order)
            self._save_store(store, cur=cur, user_id=scope_user, run_id=scope_run)
            insert_engine_event(cur, "ORDER_PLACED", data=order)
            return order

        def _fill() -> dict:
            # Mark the order filled at *price* and build its trade record.
            order["status"] = "FILLED"
            order["filled_qty"] = qty
            order["average_price"] = price
            return {
                "id": _deterministic_id("trd", [order_id]),
                "order_id": order_id,
                "symbol": symbol,
                "side": side,
                "qty": qty,
                "price": price,
                "timestamp": timestamp_str,
                "_logical_time": logical_ts_str,
            }

        positions = store.get("positions", {})
        cash = float(store.get("cash", 0))
        trade = None
        if side == "BUY":
            cost = qty * price
            if cash >= cost:
                cash -= cost
                existing = positions.get(symbol, {"qty": 0.0, "avg_price": 0.0, "last_price": price})
                held_qty = float(existing.get("qty", 0))
                new_qty = held_qty + qty
                prev_cost = held_qty * float(existing.get("avg_price", 0))
                avg_price = (prev_cost + cost) / new_qty if new_qty else price
                positions[symbol] = {
                    "qty": new_qty,
                    "avg_price": avg_price,
                    "last_price": price,
                }
                trade = _fill()
        elif side == "SELL":
            existing = positions.get(symbol)
            if existing and float(existing.get("qty", 0)) >= qty:
                cash += qty * price
                remaining = float(existing.get("qty", 0)) - qty
                if remaining > 0:
                    existing["qty"] = remaining
                    existing["last_price"] = price
                    positions[symbol] = existing
                else:
                    positions.pop(symbol, None)
                trade = _fill()

        if trade is not None:
            store.setdefault("trades", []).append(trade)
        store["cash"] = cash
        store["positions"] = positions
        store.setdefault("orders", []).append(order)
        self._save_store(store, cur=cur, user_id=scope_user, run_id=scope_run)
        insert_engine_event(cur, "ORDER_PLACED", data=order)
        if trade is not None:
            insert_engine_event(cur, "TRADE_EXECUTED", data=trade)
            insert_engine_event(cur, "ORDER_FILLED", data={"order_id": order_id})
        return order

    def place_order(
        self,
        symbol: str,
        side: str,
        quantity: float,
        price: float | None = None,
        cur=None,
        logical_time: datetime | None = None,
        user_id: str | None = None,
        run_id: str | None = None,
    ):
        """Place a paper order in the given transaction, or a new retried one when *cur* is None."""
        if cur is not None:
            return self._place_order_in_tx(
                cur,
                symbol,
                side,
                quantity,
                price,
                logical_time=logical_time,
                user_id=user_id,
                run_id=run_id,
            )

        def _op(cur, _conn):
            return self._place_order_in_tx(
                cur,
                symbol,
                side,
                quantity,
                price,
                logical_time=logical_time,
                user_id=user_id,
                run_id=run_id,
            )

        return run_with_retry(_op)