1039 lines
36 KiB
Python
1039 lines
36 KiB
Python
from __future__ import annotations
|
|
|
|
from abc import ABC, abstractmethod
|
|
from dataclasses import dataclass
|
|
from datetime import datetime, timezone
|
|
import hashlib
|
|
import math
|
|
import os
|
|
import time
|
|
|
|
from psycopg2.extras import execute_values
|
|
|
|
from indian_paper_trading_strategy.engine.data import fetch_live_price
|
|
from indian_paper_trading_strategy.engine.db import db_connection, insert_engine_event, run_with_retry, get_context
|
|
|
|
|
|
class Broker(ABC):
|
|
external_orders = False
|
|
|
|
@abstractmethod
|
|
def place_order(
|
|
self,
|
|
symbol: str,
|
|
side: str,
|
|
quantity: float,
|
|
price: float | None = None,
|
|
logical_time: datetime | None = None,
|
|
):
|
|
raise NotImplementedError
|
|
|
|
@abstractmethod
|
|
def get_positions(self):
|
|
raise NotImplementedError
|
|
|
|
@abstractmethod
|
|
def get_orders(self):
|
|
raise NotImplementedError
|
|
|
|
@abstractmethod
|
|
def get_funds(self):
|
|
raise NotImplementedError
|
|
|
|
|
|
class BrokerError(Exception):
    """Base error raised by broker implementations in this module."""


class BrokerAuthExpired(BrokerError):
    """Raised when the broker session is missing or its credentials were rejected."""
|
|
|
|
|
|
def _local_tz():
|
|
return datetime.now().astimezone().tzinfo
|
|
|
|
|
|
def _format_utc_ts(value: datetime | None):
|
|
if value is None:
|
|
return None
|
|
if value.tzinfo is None:
|
|
value = value.replace(tzinfo=_local_tz())
|
|
return value.astimezone(timezone.utc).isoformat().replace("+00:00", "Z")
|
|
|
|
|
|
def _format_local_ts(value: datetime | None):
|
|
if value is None:
|
|
return None
|
|
if value.tzinfo is None:
|
|
value = value.replace(tzinfo=_local_tz())
|
|
return value.astimezone(_local_tz()).replace(tzinfo=None).isoformat()
|
|
|
|
|
|
def _parse_ts(value, assume_local: bool = True):
|
|
if value is None:
|
|
return None
|
|
if isinstance(value, datetime):
|
|
if value.tzinfo is None:
|
|
return value.replace(tzinfo=_local_tz() if assume_local else timezone.utc)
|
|
return value
|
|
if isinstance(value, str):
|
|
text = value.strip()
|
|
if not text:
|
|
return None
|
|
if text.endswith("Z"):
|
|
try:
|
|
return datetime.fromisoformat(text.replace("Z", "+00:00"))
|
|
except ValueError:
|
|
return None
|
|
try:
|
|
parsed = datetime.fromisoformat(text)
|
|
except ValueError:
|
|
return None
|
|
if parsed.tzinfo is None:
|
|
parsed = parsed.replace(tzinfo=_local_tz() if assume_local else timezone.utc)
|
|
return parsed
|
|
return None
|
|
|
|
|
|
def _stable_num(value: float) -> str:
    """Format ``value`` with exactly 12 decimals so equal numbers hash identically."""
    return "{:.12f}".format(float(value))


def _normalize_ts_for_id(ts: datetime) -> str:
    """Return a second-resolution UTC ISO string used as an id component.

    Unlike the display helpers, naive datetimes are treated as UTC here; this
    keeps previously generated ids stable.
    """
    aware = ts if ts.tzinfo is not None else ts.replace(tzinfo=timezone.utc)
    truncated = aware.astimezone(timezone.utc).replace(microsecond=0)
    return truncated.isoformat()


def _deterministic_id(prefix: str, parts: list[str]) -> str:
    """Build ``<prefix>_<16 hex chars>`` from a SHA-1 of the ``|``-joined parts."""
    joined = "|".join(parts).encode("utf-8")
    return prefix + "_" + hashlib.sha1(joined).hexdigest()[:16]
|
|
|
|
|
|
def _resolve_scope(user_id: str | None, run_id: str | None):
    """Return the effective ``(user_id, run_id)`` pair for a broker call.

    Delegates entirely to ``get_context`` from the engine DB module; callers
    in this file unpack the result into two values.
    """
    resolved = get_context(user_id, run_id)
    return resolved
|
|
|
|
|
|
class LiveZerodhaBroker(Broker):
    """Broker adapter that sends real orders to a Zerodha (Kite) account.

    Orders are placed as MARKET / CNC / DAY, then polled with
    ``fetch_order_history`` until they reach one of ``TERMINAL_STATUSES`` or
    ``POLL_TIMEOUT_SECONDS`` elapses, after which a cancel is attempted.
    Any ``KiteTokenError`` marks the user's broker auth state as EXPIRED and is
    re-raised as :class:`BrokerAuthExpired`.
    """

    external_orders = True

    # Raw Kite statuses after which polling stops.
    TERMINAL_STATUSES = {"COMPLETE", "REJECTED", "CANCELLED"}
    # Polling knobs, read once from the environment when the class is defined.
    POLL_TIMEOUT_SECONDS = float(os.getenv("ZERODHA_ORDER_POLL_TIMEOUT", "12"))
    POLL_INTERVAL_SECONDS = float(os.getenv("ZERODHA_ORDER_POLL_INTERVAL", "1"))

    def __init__(self, user_id: str | None = None, run_id: str | None = None):
        # Either may be None; _scope() resolves the effective pair.
        self.user_id = user_id
        self.run_id = run_id

    def _scope(self):
        """Return the effective ``(user_id, run_id)`` for this broker."""
        return _resolve_scope(self.user_id, self.run_id)

    def _session(self):
        """Load the stored Kite session for the scoped user.

        Raises:
            BrokerAuthExpired: when no session, ``api_key`` or ``access_token`` is stored.
        """
        from app.services.zerodha_storage import get_session

        user_id, _run_id = self._scope()
        session = get_session(user_id)
        if not session or not session.get("api_key") or not session.get("access_token"):
            raise BrokerAuthExpired("Zerodha session missing. Please reconnect broker.")
        return session

    def _raise_auth_expired(self, exc: Exception):
        """Mark the user's broker auth as EXPIRED and raise BrokerAuthExpired from ``exc``."""
        from app.broker_store import set_broker_auth_state

        user_id, _run_id = self._scope()
        set_broker_auth_state(user_id, "EXPIRED")
        raise BrokerAuthExpired(str(exc)) from exc

    def _normalize_symbol(self, symbol: str) -> tuple[str, str]:
        """Split a suffixed symbol into ``(tradingsymbol, exchange)``.

        ``"INFY.NS"`` -> ``("INFY", "NSE")``, ``"INFY.BO"`` -> ``("INFY", "BSE")``;
        symbols without a recognised suffix default to NSE.
        """
        cleaned = (symbol or "").strip().upper()
        if cleaned.endswith(".NS"):
            return cleaned[:-3], "NSE"
        if cleaned.endswith(".BO"):
            return cleaned[:-3], "BSE"
        return cleaned, "NSE"

    def _make_tag(self, logical_time: datetime | None, symbol: str, side: str) -> str:
        """Build an order tag derived from (user, run, logical second, symbol, side)."""
        user_id, run_id = self._scope()
        logical_ts = logical_time or datetime.now(timezone.utc)
        key = f"{user_id}|{run_id}|{_normalize_ts_for_id(logical_ts)}|{symbol}|{side}"
        digest = hashlib.sha1(key.encode("utf-8")).hexdigest()[:18]
        # "qf" + 18 hex chars = 20 characters.
        # NOTE(review): presumably sized to Kite's tag-length limit; confirm before lengthening.
        return f"qf{digest}"

    def _normalize_order_payload(
        self,
        *,
        order_id: str,
        symbol: str,
        side: str,
        requested_qty: int,
        requested_price: float | None,
        history_entry: dict | None,
        logical_time: datetime | None,
    ) -> dict:
        """Convert a Kite order-history entry into the engine's order dict.

        Status mapping: ``COMPLETE`` -> ``FILLED``; ``REJECTED``/``CANCELLED``
        pass through; any other status, or no entry yet, -> ``PENDING``.
        Missing broker fields fall back to the requested quantity/price, and
        space-separated timestamps are rewritten with a ``T``.
        """
        entry = history_entry or {}
        raw_status = (entry.get("status") or "").upper()
        if raw_status == "COMPLETE":
            status = "FILLED"
        elif raw_status in {"REJECTED", "CANCELLED"}:
            status = raw_status
        else:
            status = "PENDING"

        quantity = int(entry.get("quantity") or requested_qty or 0)
        filled_qty = int(entry.get("filled_quantity") or 0)
        average_price = float(entry.get("average_price") or requested_price or 0.0)
        price = float(entry.get("price") or requested_price or average_price or 0.0)
        timestamp = (
            entry.get("exchange_timestamp")
            or entry.get("order_timestamp")
            or _format_utc_ts(logical_time or datetime.now(timezone.utc))
        )
        if timestamp and " " in str(timestamp):
            timestamp = str(timestamp).replace(" ", "T")

        return {
            "id": order_id,
            "symbol": symbol,
            "side": side.upper().strip(),
            "qty": quantity,
            # What the engine asked for, not what the broker echoed back.
            "requested_qty": int(requested_qty) if requested_qty else quantity,
            "filled_qty": filled_qty,
            "price": price,
            "requested_price": float(requested_price or price or 0.0),
            "average_price": average_price,
            "status": status,
            "timestamp": timestamp,
            "broker_order_id": order_id,
            "exchange": entry.get("exchange"),
            "tradingsymbol": entry.get("tradingsymbol"),
            "status_message": entry.get("status_message") or entry.get("status_message_raw"),
        }

    def _wait_for_terminal_order(
        self,
        session: dict,
        order_id: str,
        *,
        symbol: str,
        side: str,
        requested_qty: int,
        requested_price: float | None,
        logical_time: datetime | None,
    ) -> dict:
        """Poll ``order_id`` until it is terminal or the poll timeout elapses.

        On timeout a cancel is attempted and the latest history entry is
        returned. If the cancel itself fails (for example because the order
        reached a terminal state in the meantime), history is still re-read so
        an executed order is reported instead of an exception; the cancel
        error is re-raised only when no history is available.

        Raises:
            BrokerAuthExpired: on any Kite token error.
        """
        from app.services.zerodha_service import (
            KiteTokenError,
            cancel_order,
            fetch_order_history,
        )

        def _payload(entry):
            return self._normalize_order_payload(
                order_id=order_id,
                symbol=symbol,
                side=side,
                requested_qty=requested_qty,
                requested_price=requested_price,
                history_entry=entry,
                logical_time=logical_time,
            )

        def _history():
            try:
                return fetch_order_history(
                    session["api_key"],
                    session["access_token"],
                    order_id,
                )
            except KiteTokenError as exc:
                self._raise_auth_expired(exc)

        started = time.monotonic()
        last_payload = _payload(None)

        while True:
            history = _history()
            if history:
                entry = history[-1]
                last_payload = _payload(entry)
                if (entry.get("status") or "").upper() in self.TERMINAL_STATUSES:
                    return last_payload

            if time.monotonic() - started >= self.POLL_TIMEOUT_SECONDS:
                cancel_error = None
                try:
                    cancel_order(
                        session["api_key"],
                        session["access_token"],
                        order_id=order_id,
                    )
                except KiteTokenError as exc:
                    self._raise_auth_expired(exc)
                except Exception as exc:  # noqa: BLE001 - re-raised below if unresolved
                    # The order may have filled/been rejected after the last
                    # poll; check its final state before surfacing the error.
                    cancel_error = exc

                history = _history()
                if history:
                    return _payload(history[-1])
                if cancel_error is not None:
                    raise cancel_error
                return last_payload

            time.sleep(self.POLL_INTERVAL_SECONDS)

    def get_funds(self, cur=None):
        """Return ``{"cash": float, "raw": <kite funds payload>}``.

        Cash is the first value present among ``equity.available.live_balance``,
        ``.cash``, ``.opening_balance`` and ``equity.net``; 0.0 if none is.
        ``cur`` is accepted for interface parity and ignored.
        """
        from app.services.zerodha_service import KiteTokenError, fetch_funds

        session = self._session()
        try:
            data = fetch_funds(session["api_key"], session["access_token"])
        except KiteTokenError as exc:
            self._raise_auth_expired(exc)

        equity = data.get("equity") if isinstance(data, dict) else None
        if not isinstance(equity, dict):
            equity = {}
        available = equity.get("available")
        if not isinstance(available, dict):
            available = {}

        candidates = (
            available.get("live_balance"),
            available.get("cash"),
            available.get("opening_balance"),
            equity.get("net"),
        )
        # A genuine 0 balance must be reported as 0, not fall through to
        # opening_balance (which would over-report spendable cash).
        cash = next((value for value in candidates if value not in (None, "")), 0.0)
        return {"cash": float(cash), "raw": data}

    def get_positions(self):
        """Return Kite holdings as ``[{"symbol", "qty", "avg_price", "last_price"}]``.

        Symbols get a ``.NS`` / ``.BO`` suffix for NSE / BSE holdings.
        """
        from app.services.zerodha_service import KiteTokenError, fetch_holdings

        session = self._session()
        try:
            holdings = fetch_holdings(session["api_key"], session["access_token"])
        except KiteTokenError as exc:
            self._raise_auth_expired(exc)

        normalized = []
        for item in holdings or []:
            qty = float(item.get("quantity") or item.get("qty") or 0)
            avg_price = float(item.get("average_price") or 0)
            last_price = float(item.get("last_price") or avg_price or 0)
            exchange = item.get("exchange")
            suffix = ".NS" if exchange == "NSE" else ".BO" if exchange == "BSE" else ""
            normalized.append(
                {
                    "symbol": f"{item.get('tradingsymbol')}{suffix}",
                    "qty": qty,
                    "avg_price": avg_price,
                    "last_price": last_price,
                }
            )
        return normalized

    def get_orders(self):
        """Return the raw order list from Kite."""
        from app.services.zerodha_service import KiteTokenError, fetch_orders

        session = self._session()
        try:
            return fetch_orders(session["api_key"], session["access_token"])
        except KiteTokenError as exc:
            self._raise_auth_expired(exc)

    def update_equity(
        self,
        prices: dict[str, float],
        now: datetime,
        cur=None,
        logical_time: datetime | None = None,
        user_id: str | None = None,
        run_id: str | None = None,
    ):
        """No-op for the live broker; equity is not tracked here."""
        return None

    def place_order(
        self,
        symbol: str,
        side: str,
        quantity: float,
        price: float | None = None,
        cur=None,
        logical_time: datetime | None = None,
        user_id: str | None = None,
        run_id: str | None = None,
    ):
        """Place a MARKET/CNC order on Zerodha and wait for its outcome.

        ``quantity`` is floored to whole shares; below one share a REJECTED
        payload is returned without contacting the broker. ``price`` is only
        used as the requested/fallback price in the returned payload.
        Passing ``user_id``/``run_id`` rebinds this broker's scope.

        Raises:
            BrokerAuthExpired: when the session is missing or the token is rejected.
            BrokerError: when Kite does not return an ``order_id``.
        """
        from app.services.zerodha_service import KiteTokenError, place_order

        if user_id is not None:
            self.user_id = user_id
        if run_id is not None:
            self.run_id = run_id

        qty = int(math.floor(float(quantity)))
        side = side.upper().strip()
        requested_price = float(price) if price is not None else None
        if qty <= 0:
            logical_ts = logical_time or datetime.now(timezone.utc)
            return {
                # Logical time keeps rejections from different ticks distinct.
                "id": _deterministic_id(
                    "live_rej",
                    [symbol, side, _stable_num(quantity), _normalize_ts_for_id(logical_ts)],
                ),
                "symbol": symbol,
                "side": side,
                "qty": qty,
                "requested_qty": qty,
                "filled_qty": 0,
                "price": float(price or 0.0),
                "requested_price": float(price or 0.0),
                "average_price": 0.0,
                "status": "REJECTED",
                "timestamp": _format_utc_ts(logical_ts),
                "status_message": "Computed quantity is less than 1 share",
            }

        session = self._session()
        tradingsymbol, exchange = self._normalize_symbol(symbol)
        tag = self._make_tag(logical_time, symbol, side)

        try:
            placed = place_order(
                session["api_key"],
                session["access_token"],
                tradingsymbol=tradingsymbol,
                exchange=exchange,
                transaction_type=side,
                order_type="MARKET",
                quantity=qty,
                product="CNC",
                validity="DAY",
                variety="regular",
                market_protection=-1,
                tag=tag,
            )
        except KiteTokenError as exc:
            self._raise_auth_expired(exc)

        order_id = placed.get("order_id")
        if not order_id:
            raise BrokerError("Zerodha order placement did not return an order_id")

        return self._wait_for_terminal_order(
            session,
            order_id,
            symbol=symbol,
            side=side,
            requested_qty=qty,
            requested_price=requested_price,
            logical_time=logical_time,
        )
|
|
|
|
|
|
@dataclass
class PaperBroker(Broker):
    """Simulated broker whose ledger is persisted in Postgres ``paper_*`` tables.

    Cash, positions, orders, trades and the equity curve are keyed by the
    ``(user_id, run_id)`` pair returned by ``_resolve_scope``. Mutating calls
    load the scoped state (locking the account and position rows), apply the
    change in memory and write it back on the same cursor; without a cursor
    the work runs inside ``run_with_retry``.
    """

    initial_cash: float
    # Not read by any method of this class; kept for constructor compatibility.
    store_path: str | None = None

    def _default_store(self):
        """Return an empty ledger seeded with ``initial_cash``."""
        return {
            "cash": float(self.initial_cash),
            "positions": {},
            "orders": [],
            "trades": [],
            "equity_curve": [],
        }

    @staticmethod
    def _mark_price(position) -> float:
        """Return the price used to value ``position``.

        Prefers ``last_price`` and falls back to ``avg_price`` when it is
        missing or zero; ``_load_store`` maps a NULL ``last_price`` to 0.0,
        which would otherwise value the position at nothing.
        """
        return float(position.get("last_price") or position.get("avg_price") or 0.0)

    @staticmethod
    def _without_internal(rows):
        """Copy dict rows without the private ``_logical_time`` key; pass other rows through."""
        return [
            {k: v for k, v in row.items() if k != "_logical_time"} if isinstance(row, dict) else row
            for row in rows
        ]

    @staticmethod
    def _existing_order(store, order_id):
        """Return a previously stored order with ``order_id`` (fill fields restored), or None."""
        for stored in store.get("orders", []):
            if not isinstance(stored, dict) or stored.get("id") != order_id:
                continue
            filled = stored.get("status") == "FILLED"
            qty = float(stored.get("qty", 0.0))
            price = float(stored.get("price", 0.0))
            return {
                **stored,
                "requested_qty": qty,
                "filled_qty": qty if filled else 0.0,
                "requested_price": price,
                "average_price": price if filled else 0.0,
            }
        return None

    def _load_store(self, cur=None, for_update: bool = False, user_id: str | None = None, run_id: str | None = None):
        """Load the scoped ledger from the database.

        Args:
            cur: Cursor to use; when None a connection is opened just for this read.
            for_update: Lock the account and position rows (``FOR UPDATE``).
            user_id, run_id: Scope overrides, resolved via ``_resolve_scope``.

        Returns:
            A dict shaped like ``_default_store()``; order/trade timestamps are
            UTC ``Z`` strings, equity-curve timestamps are naive local strings.
        """
        scope_user, scope_run = _resolve_scope(user_id, run_id)
        if cur is None:
            with db_connection() as conn:
                with conn.cursor() as tx_cur:
                    return self._load_store(
                        cur=tx_cur,
                        for_update=for_update,
                        user_id=scope_user,
                        run_id=scope_run,
                    )

        store = self._default_store()
        lock_clause = " FOR UPDATE" if for_update else ""
        cur.execute(
            f"SELECT cash FROM paper_broker_account WHERE user_id = %s AND run_id = %s{lock_clause} LIMIT 1",
            (scope_user, scope_run),
        )
        row = cur.fetchone()
        if row and row[0] is not None:
            store["cash"] = float(row[0])

        cur.execute(
            f"""
            SELECT symbol, qty, avg_price, last_price
            FROM paper_position
            WHERE user_id = %s AND run_id = %s{lock_clause}
            """,
            (scope_user, scope_run),
        )
        positions = {}
        for symbol, qty, avg_price, last_price in cur.fetchall():
            positions[symbol] = {
                "qty": float(qty) if qty is not None else 0.0,
                "avg_price": float(avg_price) if avg_price is not None else 0.0,
                "last_price": float(last_price) if last_price is not None else 0.0,
            }
        store["positions"] = positions

        cur.execute(
            """
            SELECT id, symbol, side, qty, price, status, timestamp, logical_time
            FROM paper_order
            WHERE user_id = %s AND run_id = %s
            ORDER BY timestamp, id
            """,
            (scope_user, scope_run),
        )
        orders = []
        for order_id, symbol, side, qty, price, status, ts, logical_ts in cur.fetchall():
            orders.append(
                {
                    "id": order_id,
                    "symbol": symbol,
                    "side": side,
                    "qty": float(qty) if qty is not None else 0.0,
                    "price": float(price) if price is not None else 0.0,
                    "status": status,
                    "timestamp": _format_utc_ts(ts),
                    "_logical_time": _format_utc_ts(logical_ts),
                }
            )
        store["orders"] = orders

        cur.execute(
            """
            SELECT id, order_id, symbol, side, qty, price, timestamp, logical_time
            FROM paper_trade
            WHERE user_id = %s AND run_id = %s
            ORDER BY timestamp, id
            """,
            (scope_user, scope_run),
        )
        trades = []
        for trade_id, order_id, symbol, side, qty, price, ts, logical_ts in cur.fetchall():
            trades.append(
                {
                    "id": trade_id,
                    "order_id": order_id,
                    "symbol": symbol,
                    "side": side,
                    "qty": float(qty) if qty is not None else 0.0,
                    "price": float(price) if price is not None else 0.0,
                    "timestamp": _format_utc_ts(ts),
                    "_logical_time": _format_utc_ts(logical_ts),
                }
            )
        store["trades"] = trades

        cur.execute(
            """
            SELECT timestamp, logical_time, equity, pnl
            FROM paper_equity_curve
            WHERE user_id = %s AND run_id = %s
            ORDER BY timestamp
            """,
            (scope_user, scope_run),
        )
        equity_curve = []
        for ts, logical_ts, equity, pnl in cur.fetchall():
            equity_curve.append(
                {
                    "timestamp": _format_local_ts(ts),
                    "_logical_time": _format_local_ts(logical_ts),
                    "equity": float(equity) if equity is not None else 0.0,
                    "pnl": float(pnl) if pnl is not None else 0.0,
                }
            )
        store["equity_curve"] = equity_curve
        return store

    def _save_store(self, store, cur=None, user_id: str | None = None, run_id: str | None = None):
        """Persist ``store`` for the resolved scope.

        Cash is upserted; positions are replaced (rows for symbols absent from
        ``store`` are deleted, the rest upserted); orders, trades and equity
        points are inserted with ``ON CONFLICT DO NOTHING``. Without ``cur``
        the write runs inside ``run_with_retry``.
        """
        scope_user, scope_run = _resolve_scope(user_id, run_id)
        if cur is None:
            def _persist(tx_cur, _conn):
                self._save_store(store, cur=tx_cur, user_id=scope_user, run_id=scope_run)

            return run_with_retry(_persist)

        cash = store.get("cash")
        if cash is not None:
            cur.execute(
                """
                INSERT INTO paper_broker_account (user_id, run_id, cash)
                VALUES (%s, %s, %s)
                ON CONFLICT (user_id, run_id) DO UPDATE
                SET cash = EXCLUDED.cash
                """,
                (scope_user, scope_run, float(cash)),
            )

        positions = store.get("positions")
        if isinstance(positions, dict):
            symbols = [s for s in positions.keys() if s]
            if symbols:
                cur.execute(
                    "DELETE FROM paper_position WHERE user_id = %s AND run_id = %s AND symbol NOT IN %s",
                    (scope_user, scope_run, tuple(symbols)),
                )
            else:
                cur.execute(
                    "DELETE FROM paper_position WHERE user_id = %s AND run_id = %s",
                    (scope_user, scope_run),
                )

            if symbols:
                rows = []
                updated_at = datetime.now(timezone.utc)
                for symbol, data in positions.items():
                    if not symbol or not isinstance(data, dict):
                        continue
                    rows.append(
                        (
                            scope_user,
                            scope_run,
                            symbol,
                            float(data.get("qty", 0.0)),
                            float(data.get("avg_price", 0.0)),
                            float(data.get("last_price", 0.0)),
                            updated_at,
                        )
                    )
                if rows:
                    execute_values(
                        cur,
                        """
                        INSERT INTO paper_position (
                            user_id, run_id, symbol, qty, avg_price, last_price, updated_at
                        )
                        VALUES %s
                        ON CONFLICT (user_id, run_id, symbol) DO UPDATE
                        SET qty = EXCLUDED.qty,
                            avg_price = EXCLUDED.avg_price,
                            last_price = EXCLUDED.last_price,
                            updated_at = EXCLUDED.updated_at
                        """,
                        rows,
                    )

        # NOTE(review): the full order/trade/equity history is re-inserted on every
        # save and relies on unique constraints to de-duplicate. Confirm that
        # paper_equity_curve has one, otherwise points are duplicated each tick.
        orders = store.get("orders")
        if isinstance(orders, list) and orders:
            rows = []
            for order in orders:
                if not isinstance(order, dict):
                    continue
                order_id = order.get("id")
                if not order_id:
                    continue
                ts = _parse_ts(order.get("timestamp"), assume_local=False)
                logical_ts = _parse_ts(order.get("_logical_time"), assume_local=False) or ts
                rows.append(
                    (
                        scope_user,
                        scope_run,
                        order_id,
                        order.get("symbol"),
                        order.get("side"),
                        float(order.get("qty", 0.0)),
                        float(order.get("price", 0.0)),
                        order.get("status"),
                        ts,
                        logical_ts,
                    )
                )
            if rows:
                execute_values(
                    cur,
                    """
                    INSERT INTO paper_order (
                        user_id, run_id, id, symbol, side, qty, price, status, timestamp, logical_time
                    )
                    VALUES %s
                    ON CONFLICT DO NOTHING
                    """,
                    rows,
                )

        trades = store.get("trades")
        if isinstance(trades, list) and trades:
            rows = []
            for trade in trades:
                if not isinstance(trade, dict):
                    continue
                trade_id = trade.get("id")
                if not trade_id:
                    continue
                ts = _parse_ts(trade.get("timestamp"), assume_local=False)
                logical_ts = _parse_ts(trade.get("_logical_time"), assume_local=False) or ts
                rows.append(
                    (
                        scope_user,
                        scope_run,
                        trade_id,
                        trade.get("order_id"),
                        trade.get("symbol"),
                        trade.get("side"),
                        float(trade.get("qty", 0.0)),
                        float(trade.get("price", 0.0)),
                        ts,
                        logical_ts,
                    )
                )
            if rows:
                execute_values(
                    cur,
                    """
                    INSERT INTO paper_trade (
                        user_id, run_id, id, order_id, symbol, side, qty, price, timestamp, logical_time
                    )
                    VALUES %s
                    ON CONFLICT DO NOTHING
                    """,
                    rows,
                )

        equity_curve = store.get("equity_curve")
        if isinstance(equity_curve, list) and equity_curve:
            rows = []
            for point in equity_curve:
                if not isinstance(point, dict):
                    continue
                ts = _parse_ts(point.get("timestamp"), assume_local=True)
                logical_ts = _parse_ts(point.get("_logical_time"), assume_local=True) or ts
                if ts is None:
                    continue
                rows.append(
                    (
                        scope_user,
                        scope_run,
                        ts,
                        logical_ts,
                        float(point.get("equity", 0.0)),
                        float(point.get("pnl", 0.0)),
                    )
                )
            if rows:
                execute_values(
                    cur,
                    """
                    INSERT INTO paper_equity_curve (user_id, run_id, timestamp, logical_time, equity, pnl)
                    VALUES %s
                    ON CONFLICT DO NOTHING
                    """,
                    rows,
                )

    def get_funds(self, cur=None):
        """Return cash, marked position value and total equity for the current scope."""
        store = self._load_store(cur=cur)
        cash = float(store.get("cash", 0))
        positions_value = sum(
            (float(position.get("qty", 0)) * self._mark_price(position)
             for position in store.get("positions", {}).values()),
            0.0,
        )
        total_equity = cash + positions_value
        return {
            "cash_available": cash,
            "invested_value": positions_value,
            "cash": cash,
            "used_margin": 0.0,
            "available": cash,
            "net": total_equity,
            "total_equity": total_equity,
        }

    def get_positions(self, cur=None):
        """Return positions as ``[{"symbol", "qty", "avg_price", "last_price"}]``."""
        store = self._load_store(cur=cur)
        return [
            {
                "symbol": symbol,
                "qty": float(data.get("qty", 0)),
                "avg_price": float(data.get("avg_price", 0)),
                "last_price": self._mark_price(data),
            }
            for symbol, data in store.get("positions", {}).items()
        ]

    def get_orders(self, cur=None):
        """Return stored orders without the internal ``_logical_time`` key."""
        return self._without_internal(self._load_store(cur=cur).get("orders", []))

    def get_trades(self, cur=None):
        """Return stored trades without the internal ``_logical_time`` key."""
        return self._without_internal(self._load_store(cur=cur).get("trades", []))

    def get_equity_curve(self, cur=None):
        """Return equity-curve points without the internal ``_logical_time`` key."""
        return self._without_internal(self._load_store(cur=cur).get("equity_curve", []))

    def _update_equity_in_tx(
        self,
        cur,
        prices: dict[str, float],
        now: datetime,
        logical_time: datetime | None = None,
        user_id: str | None = None,
        run_id: str | None = None,
    ):
        """Re-mark held positions with ``prices`` and append an equity point.

        ``None`` and non-finite prices are ignored so a missing quote keeps the
        previous mark instead of raising or turning equity into NaN.

        Returns:
            The new total equity (cash + marked positions).
        """
        scope_user, scope_run = _resolve_scope(user_id, run_id)
        store = self._load_store(cur=cur, for_update=True, user_id=scope_user, run_id=scope_run)
        positions = store.get("positions", {})
        for symbol, price in (prices or {}).items():
            if symbol not in positions or price is None:
                continue
            mark = float(price)
            if math.isfinite(mark):
                positions[symbol]["last_price"] = mark

        cash = float(store.get("cash", 0))
        positions_value = sum(
            (float(position.get("qty", 0)) * self._mark_price(position) for position in positions.values()),
            0.0,
        )
        equity = cash + positions_value
        pnl = equity - float(self.initial_cash)
        ts_for_equity = logical_time or now
        point_ts = _format_local_ts(ts_for_equity)
        store.setdefault("equity_curve", []).append(
            {
                "timestamp": point_ts,
                "_logical_time": point_ts,
                "equity": equity,
                "pnl": pnl,
            }
        )
        store["positions"] = positions
        self._save_store(store, cur=cur, user_id=scope_user, run_id=scope_run)
        insert_engine_event(
            cur,
            "EQUITY_UPDATED",
            data={
                "timestamp": _format_utc_ts(ts_for_equity),
                "equity": equity,
                "pnl": pnl,
            },
        )
        return equity

    def update_equity(
        self,
        prices: dict[str, float],
        now: datetime,
        cur=None,
        logical_time: datetime | None = None,
        user_id: str | None = None,
        run_id: str | None = None,
    ):
        """Mark to market on ``cur``, or in a retried transaction when ``cur`` is None."""
        def _op(tx_cur, _conn):
            return self._update_equity_in_tx(
                tx_cur,
                prices,
                now,
                logical_time=logical_time,
                user_id=user_id,
                run_id=run_id,
            )

        if cur is not None:
            return _op(cur, None)
        return run_with_retry(_op)

    def _place_order_in_tx(
        self,
        cur,
        symbol: str,
        side: str,
        quantity: float,
        price: float | None,
        logical_time: datetime | None = None,
        user_id: str | None = None,
        run_id: str | None = None,
    ):
        """Simulate an immediate fill of a BUY/SELL order inside the caller's transaction.

        The order id is deterministic over (scope, logical second, symbol, side,
        qty, price). Re-submitting an order whose id is already stored returns
        the stored order without touching cash or positions again, matching the
        ``ON CONFLICT DO NOTHING`` persistence of order/trade rows.

        Orders with a non-positive or non-finite quantity/price, a BUY without
        enough cash, or a SELL exceeding the held quantity are recorded as
        REJECTED.

        Returns:
            The order dict (status ``FILLED`` or ``REJECTED``).
        """
        scope_user, scope_run = _resolve_scope(user_id, run_id)
        store = self._load_store(cur=cur, for_update=True, user_id=scope_user, run_id=scope_run)
        side = side.upper().strip()
        qty = float(quantity)
        if price is None:
            price = fetch_live_price(symbol)
        price = float(price) if price is not None else 0.0
        # Unusable inputs are recorded as 0 so the order is rejected below rather
        # than writing NaN/inf into the ledger.
        if not math.isfinite(price):
            price = 0.0
        if not math.isfinite(qty):
            qty = 0.0

        logical_ts = logical_time or datetime.now(timezone.utc)
        timestamp_str = _format_utc_ts(logical_ts)
        logical_ts_str = _format_utc_ts(logical_ts)
        order_id = _deterministic_id(
            "ord",
            [
                scope_user,
                scope_run,
                _normalize_ts_for_id(logical_ts),
                symbol,
                side,
                _stable_num(qty),
                _stable_num(price),
            ],
        )

        replayed = self._existing_order(store, order_id)
        if replayed is not None:
            return replayed

        order = {
            "id": order_id,
            "symbol": symbol,
            "side": side,
            "qty": qty,
            "requested_qty": qty,
            "filled_qty": 0.0,
            "price": price,
            "requested_price": price,
            "average_price": 0.0,
            "status": "REJECTED",
            "timestamp": timestamp_str,
            "_logical_time": logical_ts_str,
        }

        if qty <= 0 or price <= 0:
            store.setdefault("orders", []).append(order)
            self._save_store(store, cur=cur, user_id=scope_user, run_id=scope_run)
            insert_engine_event(cur, "ORDER_PLACED", data=order)
            return order

        positions = store.get("positions", {})
        cash = float(store.get("cash", 0))

        def _fill():
            # Mark the order filled at ``price`` and record the matching trade.
            order["status"] = "FILLED"
            order["filled_qty"] = qty
            order["average_price"] = price
            fill = {
                "id": _deterministic_id("trd", [order_id]),
                "order_id": order_id,
                "symbol": symbol,
                "side": side,
                "qty": qty,
                "price": price,
                "timestamp": timestamp_str,
                "_logical_time": logical_ts_str,
            }
            store.setdefault("trades", []).append(fill)
            return fill

        trade = None
        if side == "BUY":
            cost = qty * price
            if cash >= cost:
                cash -= cost
                existing = positions.get(symbol, {"qty": 0.0, "avg_price": 0.0, "last_price": price})
                held = float(existing.get("qty", 0))
                new_qty = held + qty
                prev_cost = held * float(existing.get("avg_price", 0))
                avg_price = (prev_cost + cost) / new_qty if new_qty else price
                positions[symbol] = {
                    "qty": new_qty,
                    "avg_price": avg_price,
                    "last_price": price,
                }
                trade = _fill()
        elif side == "SELL":
            existing = positions.get(symbol)
            if existing and float(existing.get("qty", 0)) >= qty:
                cash += qty * price
                remaining = float(existing.get("qty", 0)) - qty
                if remaining > 0:
                    existing["qty"] = remaining
                    existing["last_price"] = price
                    positions[symbol] = existing
                else:
                    positions.pop(symbol, None)
                trade = _fill()

        store["cash"] = cash
        store["positions"] = positions
        store.setdefault("orders", []).append(order)
        self._save_store(store, cur=cur, user_id=scope_user, run_id=scope_run)
        insert_engine_event(cur, "ORDER_PLACED", data=order)
        if trade is not None:
            insert_engine_event(cur, "TRADE_EXECUTED", data=trade)
            insert_engine_event(cur, "ORDER_FILLED", data={"order_id": order_id})
        return order

    def place_order(
        self,
        symbol: str,
        side: str,
        quantity: float,
        price: float | None = None,
        cur=None,
        logical_time: datetime | None = None,
        user_id: str | None = None,
        run_id: str | None = None,
    ):
        """Place a simulated order on ``cur``, or in a retried transaction when ``cur`` is None.

        When ``price`` is None the live price is fetched via ``fetch_live_price``.
        """
        def _op(tx_cur, _conn):
            return self._place_order_in_tx(
                tx_cur,
                symbol,
                side,
                quantity,
                price,
                logical_time=logical_time,
                user_id=user_id,
                run_id=run_id,
            )

        if cur is not None:
            return _op(cur, None)
        return run_with_retry(_op)
|
|
|