From c17222ad9c2006e06d396ef4d5003bde547310c1 Mon Sep 17 00:00:00 2001 From: Thigazhezhilan J Date: Tue, 24 Mar 2026 21:59:17 +0530 Subject: [PATCH] Refine live strategy execution flow --- backend/app/models.py | 2 +- backend/app/routers/broker.py | 26 +- backend/app/routers/strategy.py | 6 + backend/app/services/strategy_service.py | 252 +++++++- backend/app/services/zerodha_service.py | 120 +++- .../engine/broker.py | 443 +++++++++++-- indian_paper_trading_strategy/engine/data.py | 165 +++-- .../engine/execution.py | 593 ++++++++++++++---- .../engine/history.py | 68 +- .../engine/runner.py | 234 +++++-- 10 files changed, 1520 insertions(+), 389 deletions(-) diff --git a/backend/app/models.py b/backend/app/models.py index 20308dc..a43c64a 100644 --- a/backend/app/models.py +++ b/backend/app/models.py @@ -11,7 +11,7 @@ class StrategyStartRequest(BaseModel): initial_cash: Optional[float] = None sip_amount: float sip_frequency: SipFrequency - mode: Literal["PAPER"] + mode: Literal["PAPER", "LIVE"] @validator("initial_cash") def validate_cash(cls, v): diff --git a/backend/app/routers/broker.py b/backend/app/routers/broker.py index f16443b..188b453 100644 --- a/backend/app/routers/broker.py +++ b/backend/app/routers/broker.py @@ -30,6 +30,17 @@ def _require_user(request: Request): return user +def _build_saved_broker_login_url(request: Request, user_id: str) -> str: + creds = get_broker_credentials(user_id) + if not creds: + raise HTTPException(status_code=400, detail="Broker credentials not configured") + redirect_url = (os.getenv("ZERODHA_REDIRECT_URL") or "").strip() + if not redirect_url: + base = str(request.base_url).rstrip("/") + redirect_url = f"{base}/api/broker/callback" + return build_login_url(creds["api_key"], redirect_url=redirect_url) + + @router.post("/connect") async def connect_broker(payload: dict, request: Request): user = _require_user(request) @@ -153,17 +164,16 @@ async def zerodha_callback(request: Request, request_token: str = ""): @router.get("/login") async def broker_login(request: Request): user = _require_user(request) - creds = get_broker_credentials(user["id"]) - if not creds: - raise HTTPException(status_code=400, detail="Broker credentials not configured") - redirect_url = (os.getenv("ZERODHA_REDIRECT_URL") or "").strip() - if not redirect_url: - base = str(request.base_url).rstrip("/") - redirect_url = f"{base}/api/broker/callback" - login_url = build_login_url(creds["api_key"], redirect_url=redirect_url) + login_url = _build_saved_broker_login_url(request, user["id"]) return RedirectResponse(login_url) +@router.get("/login-url") +async def broker_login_url(request: Request): + user = _require_user(request) + return {"loginUrl": _build_saved_broker_login_url(request, user["id"])} + + @router.get("/callback") async def broker_callback(request: Request, request_token: str = ""): user = _require_user(request) diff --git a/backend/app/routers/strategy.py b/backend/app/routers/strategy.py index 5df9b50..862d765 100644 --- a/backend/app/routers/strategy.py +++ b/backend/app/routers/strategy.py @@ -4,6 +4,7 @@ from app.services.strategy_service import ( start_strategy, stop_strategy, get_strategy_status, + get_strategy_summary, get_engine_status, get_market_status, get_strategy_logs as fetch_strategy_logs, @@ -27,6 +28,11 @@ def status(request: Request): user_id = get_request_user_id(request) return get_strategy_status(user_id) +@router.get("/strategy/summary") +def summary(request: Request): + user_id = get_request_user_id(request) + return get_strategy_summary(user_id) + @router.get("/engine/status") def engine_status(request: Request): user_id = get_request_user_id(request) diff --git a/backend/app/services/strategy_service.py b/backend/app/services/strategy_service.py index 3c83e9a..1447e89 100644 --- a/backend/app/services/strategy_service.py +++ b/backend/app/services/strategy_service.py @@ -11,11 +11,12 @@ if str(ENGINE_ROOT) not in sys.path: from indian_paper_trading_strategy.engine.market import is_market_open, align_to_market_open from indian_paper_trading_strategy.engine.runner import start_engine, stop_engine -from indian_paper_trading_strategy.engine.state import init_paper_state, load_state +from indian_paper_trading_strategy.engine.state import init_paper_state, load_state, save_state from indian_paper_trading_strategy.engine.broker import PaperBroker from indian_paper_trading_strategy.engine.time_utils import frequency_to_timedelta from indian_paper_trading_strategy.engine.db import engine_context +from app.broker_store import get_user_broker, set_broker_auth_state from app.services.db import db_connection from app.services.run_service import ( create_strategy_run, @@ -25,6 +26,11 @@ from app.services.run_service import ( ) from app.services.auth_service import get_user_by_id from app.services.email_service import send_email_async +from app.services.zerodha_service import ( + KiteTokenError, + fetch_funds, +) +from app.services.zerodha_storage import get_session from psycopg2.extras import Json from psycopg2 import errors @@ -298,6 +304,27 @@ def validate_frequency(freq: dict, mode: str): if unit == "days" and value < 1: raise ValueError("Minimum frequency is 1 day") + +def _validate_live_broker_session(user_id: str): + broker_state = get_user_broker(user_id) or {} + broker_name = (broker_state.get("broker") or "").strip().upper() + if not broker_state.get("connected") or broker_name != "ZERODHA": + return False, broker_state, "broker_not_connected" + + session = get_session(user_id) + if not session: + set_broker_auth_state(user_id, "EXPIRED") + return False, broker_state, "broker_auth_required" + + try: + fetch_funds(session["api_key"], session["access_token"]) + except KiteTokenError: + set_broker_auth_state(user_id, "EXPIRED") + return False, broker_state, "broker_auth_required" + + set_broker_auth_state(user_id, "VALID") + return True, broker_state, "ok" + def compute_next_eligible(last_run: str | None, sip_frequency: dict | None): if not last_run or not sip_frequency: return None @@ -313,10 +340,27 @@ def compute_next_eligible(last_run: str | None, sip_frequency: dict | None): next_dt = align_to_market_open(next_dt) return next_dt.isoformat() + +def _last_execution_ts(state: dict, mode: str) -> str | None: + mode_key = (mode or "LIVE").strip().upper() + if mode_key == "LIVE": + return state.get("last_sip_ts") + return state.get("last_run") or state.get("last_sip_ts") + def start_strategy(req, user_id: str): engine_external = os.getenv("ENGINE_EXTERNAL", "").strip().lower() in {"1", "true", "yes"} running_run_id = get_running_run_id(user_id) if running_run_id: + running_cfg = _load_config(user_id, running_run_id) + running_mode = (running_cfg.get("mode") or req.mode or "PAPER").strip().upper() + if running_mode == "LIVE": + is_valid, broker_state, failure_reason = _validate_live_broker_session(user_id) + if not is_valid: + return { + "status": "broker_auth_required", + "redirect_url": "/api/broker/login", + "broker": broker_state.get("broker"), + } if engine_external: return {"status": "already_running", "run_id": running_run_id} engine_config = _build_engine_config(user_id, running_run_id, req) @@ -327,48 +371,77 @@ def start_strategy(req, user_id: str): return {"status": "restarted", "run_id": running_run_id} return {"status": "already_running", "run_id": running_run_id} mode = (req.mode or "PAPER").strip().upper() - if mode != "PAPER": - return {"status": "unsupported_mode"} frequency_payload = req.sip_frequency.dict() if hasattr(req.sip_frequency, "dict") else dict(req.sip_frequency) validate_frequency(frequency_payload, mode) - initial_cash = float(req.initial_cash) if req.initial_cash is not None else 1_000_000.0 + if mode == "PAPER": + initial_cash = float(req.initial_cash) if req.initial_cash is not None else 1_000_000.0 + broker_name = "paper" + elif mode == "LIVE": + is_valid, broker_state, failure_reason = _validate_live_broker_session(user_id) + if not is_valid: + return { + "status": "broker_auth_required", + "redirect_url": "/api/broker/login", + "broker": broker_state.get("broker"), + } + initial_cash = None + broker_name = ((broker_state.get("broker") or "ZERODHA").strip().lower()) + else: + return {"status": "unsupported_mode"} + + meta = { + "sip_amount": req.sip_amount, + "sip_frequency": frequency_payload, + } + if initial_cash is not None: + meta["initial_cash"] = initial_cash try: run_id = create_strategy_run( user_id, strategy=req.strategy_name, mode=mode, - broker="paper", - meta={ - "sip_amount": req.sip_amount, - "sip_frequency": frequency_payload, - "initial_cash": initial_cash, - }, + broker=broker_name, + meta=meta, ) except errors.UniqueViolation: return {"status": "already_running"} with engine_context(user_id, run_id): - init_paper_state(initial_cash, frequency_payload) - with db_connection() as conn: - with conn: - with conn.cursor() as cur: - cur.execute( - """ - INSERT INTO paper_broker_account (user_id, run_id, cash) - VALUES (%s, %s, %s) - ON CONFLICT (user_id, run_id) DO UPDATE - SET cash = EXCLUDED.cash - """, - (user_id, run_id, initial_cash), - ) - PaperBroker(initial_cash=initial_cash) + if mode == "PAPER": + init_paper_state(initial_cash, frequency_payload) + with db_connection() as conn: + with conn: + with conn.cursor() as cur: + cur.execute( + """ + INSERT INTO paper_broker_account (user_id, run_id, cash) + VALUES (%s, %s, %s) + ON CONFLICT (user_id, run_id) DO UPDATE + SET cash = EXCLUDED.cash + """, + (user_id, run_id, initial_cash), + ) + PaperBroker(initial_cash=initial_cash) + else: + save_state( + { + "total_invested": 0.0, + "nifty_units": 0.0, + "gold_units": 0.0, + "last_sip_ts": None, + "last_run": None, + }, + mode="LIVE", + emit_event=True, + event_meta={"source": "live_start"}, + ) config = { "strategy": req.strategy_name, "sip_amount": req.sip_amount, "sip_frequency": frequency_payload, "mode": mode, - "broker": "paper", + "broker": broker_name, "active": True, } save_strategy_config(config, user_id, run_id) @@ -387,7 +460,8 @@ def start_strategy(req, user_id: str): ) engine_config = dict(config) - engine_config["initial_cash"] = initial_cash + if initial_cash is not None: + engine_config["initial_cash"] = initial_cash engine_config["run_id"] = run_id engine_config["user_id"] = user_id engine_config["emit_event"] = emit_event_cb @@ -518,7 +592,7 @@ def get_strategy_status(user_id: str): mode = (cfg.get("mode") or "LIVE").strip().upper() with engine_context(user_id, run_id): state = load_state(mode=mode) - last_execution_ts = state.get("last_run") or state.get("last_sip_ts") + last_execution_ts = _last_execution_ts(state, mode) sip_frequency = cfg.get("sip_frequency") if not isinstance(sip_frequency, dict): frequency = cfg.get("frequency") @@ -578,7 +652,7 @@ def get_engine_status(user_id: str): mode = (cfg.get("mode") or "LIVE").strip().upper() with engine_context(user_id, run_id): state = load_state(mode=mode) - last_execution_ts = state.get("last_run") or state.get("last_sip_ts") + last_execution_ts = _last_execution_ts(state, mode) sip_frequency = cfg.get("sip_frequency") if isinstance(sip_frequency, dict): sip_frequency = { @@ -642,6 +716,128 @@ def get_strategy_logs(user_id: str, since_seq: int): latest_seq = cur.fetchone()[0] return {"events": events, "latest_seq": latest_seq} + +def _humanize_reason(reason: str | None): + if not reason: + return None + return reason.replace("_", " ").strip().capitalize() + + +def _issue_message(event: str, message: str | None, data: dict | None, meta: dict | None): + payload = data if isinstance(data, dict) else {} + extra = meta if isinstance(meta, dict) else {} + reason = payload.get("reason") or extra.get("reason") + reason_key = str(reason or "").strip().lower() + + if event == "SIP_NO_FILL": + if reason_key == "insufficient_funds": + return "Insufficient funds for this SIP." + if reason_key == "broker_auth_expired": + return "Broker session expired. Reconnect broker." + if reason_key == "no_fill": + return "Order was not filled." + return f"SIP not executed: {_humanize_reason(reason) or 'Unknown reason'}." + + if event == "BROKER_AUTH_EXPIRED": + return "Broker session expired. Reconnect broker." + if event == "PRICE_FETCH_ERROR": + return "Could not fetch prices. Retrying." + if event == "HISTORY_LOAD_ERROR": + return "Could not load price history. Retrying." + if event == "ENGINE_ERROR": + return message or "Strategy engine hit an error." + if event == "EXECUTION_BLOCKED": + if reason_key == "market_closed": + return "Market is closed. Execution will resume next session." + return f"Execution blocked: {_humanize_reason(reason) or 'Unknown reason'}." + if event == "ORDER_REJECTED": + return message or payload.get("status_message") or "Broker rejected the order." + if event == "ORDER_CANCELLED": + return message or "Order was cancelled." + + return message or _humanize_reason(reason) or "Strategy update available." + + +def get_strategy_summary(user_id: str): + run_id = get_active_run_id(user_id) + status = get_strategy_status(user_id) + next_eligible_ts = status.get("next_eligible_ts") + + summary = { + "run_id": run_id, + "status": status.get("status"), + "tone": "neutral", + "message": "No active strategy.", + "event": None, + "ts": None, + } + + issue_row = None + if run_id: + with db_connection() as conn: + with conn.cursor() as cur: + cur.execute( + """ + SELECT event, message, data, meta, ts + FROM engine_event + WHERE user_id = %s + AND run_id = %s + AND event IN ( + 'SIP_NO_FILL', + 'BROKER_AUTH_EXPIRED', + 'PRICE_FETCH_ERROR', + 'HISTORY_LOAD_ERROR', + 'ENGINE_ERROR', + 'EXECUTION_BLOCKED', + 'ORDER_REJECTED', + 'ORDER_CANCELLED' + ) + ORDER BY ts DESC + LIMIT 1 + """, + (user_id, run_id), + ) + issue_row = cur.fetchone() + + if issue_row: + event, message, data, meta, ts = issue_row + summary.update( + { + "tone": "error" if event in {"ENGINE_ERROR", "ORDER_REJECTED"} else "warning", + "message": _issue_message(event, message, data, meta), + "event": event, + "ts": _format_local_ts(ts), + } + ) + return summary + + status_key = (status.get("status") or "IDLE").upper() + if status_key == "WAITING" and next_eligible_ts: + summary.update( + { + "tone": "warning", + "message": f"Waiting until {next_eligible_ts}.", + } + ) + return summary + if status_key == "RUNNING": + summary.update( + { + "tone": "success", + "message": "Strategy is running.", + } + ) + return summary + if status_key == "STOPPED": + summary.update( + { + "tone": "neutral", + "message": "Strategy is stopped.", + } + ) + return summary + return summary + def get_market_status(): now = datetime.now() return { diff --git a/backend/app/services/zerodha_service.py b/backend/app/services/zerodha_service.py index a1d8214..1e71698 100644 --- a/backend/app/services/zerodha_service.py +++ b/backend/app/services/zerodha_service.py @@ -23,6 +23,10 @@ class KiteTokenError(KiteApiError): pass +class KitePermissionError(KiteApiError): + pass + + def build_login_url(api_key: str, redirect_url: str | None = None) -> str: params = {"api_key": api_key, "v": KITE_VERSION} redirect_url = (redirect_url or os.getenv("ZERODHA_REDIRECT_URL") or "").strip() @@ -48,7 +52,12 @@ def _request(method: str, url: str, data: dict | None = None, headers: dict | No payload = {} error_type = payload.get("error_type") or payload.get("status") or "unknown_error" message = payload.get("message") or error_body or err.reason - exc_cls = KiteTokenError if error_type == "TokenException" else KiteApiError + if error_type == "TokenException": + exc_cls = KiteTokenError + elif error_type == "PermissionException": + exc_cls = KitePermissionError + else: + exc_cls = KiteApiError raise exc_cls(err.code, error_type, message) from err return json.loads(body) @@ -87,3 +96,112 @@ def fetch_funds(api_key: str, access_token: str) -> dict: url = f"{KITE_API_BASE}/user/margins" response = _request("GET", url, headers=_auth_headers(api_key, access_token)) return response.get("data", {}) + + +def fetch_ltp_quotes(api_key: str, access_token: str, instruments: list[str]) -> dict: + symbols = [str(item).strip() for item in instruments if str(item).strip()] + if not symbols: + return {} + query = urllib.parse.urlencode([("i", symbol) for symbol in symbols]) + url = f"{KITE_API_BASE}/quote/ltp?{query}" + response = _request("GET", url, headers=_auth_headers(api_key, access_token)) + return response.get("data", {}) + + +def fetch_ohlc_quotes(api_key: str, access_token: str, instruments: list[str]) -> dict: + symbols = [str(item).strip() for item in instruments if str(item).strip()] + if not symbols: + return {} + query = urllib.parse.urlencode([("i", symbol) for symbol in symbols]) + url = f"{KITE_API_BASE}/quote/ohlc?{query}" + response = _request("GET", url, headers=_auth_headers(api_key, access_token)) + return response.get("data", {}) + + +def fetch_historical_candles( + api_key: str, + access_token: str, + instrument_token: int | str, + interval: str, + *, + from_dt, + to_dt, + continuous: bool = False, + oi: bool = False, +) -> list: + params = { + "from": from_dt.strftime("%Y-%m-%d %H:%M:%S"), + "to": to_dt.strftime("%Y-%m-%d %H:%M:%S"), + "continuous": 1 if continuous else 0, + "oi": 1 if oi else 0, + } + query = urllib.parse.urlencode(params) + url = f"{KITE_API_BASE}/instruments/historical/{instrument_token}/{interval}?{query}" + response = _request("GET", url, headers=_auth_headers(api_key, access_token)) + return response.get("data", {}).get("candles", []) + + +def place_order( + api_key: str, + access_token: str, + *, + tradingsymbol: str, + exchange: str, + transaction_type: str, + order_type: str, + quantity: int, + product: str, + price: float | None = None, + validity: str = "DAY", + variety: str = "regular", + market_protection: int | None = None, + tag: str | None = None, +) -> dict: + payload = { + "tradingsymbol": tradingsymbol, + "exchange": exchange, + "transaction_type": transaction_type, + "order_type": order_type, + "quantity": int(quantity), + "product": product, + "validity": validity, + } + if price is not None: + payload["price"] = price + if market_protection is not None: + payload["market_protection"] = market_protection + if tag: + payload["tag"] = tag + + url = f"{KITE_API_BASE}/orders/{variety}" + response = _request( + "POST", + url, + data=payload, + headers=_auth_headers(api_key, access_token), + ) + return response.get("data", {}) + + +def fetch_orders(api_key: str, access_token: str) -> list: + url = f"{KITE_API_BASE}/orders" + response = _request("GET", url, headers=_auth_headers(api_key, access_token)) + return response.get("data", []) + + +def fetch_order_history(api_key: str, access_token: str, order_id: str) -> list: + url = f"{KITE_API_BASE}/orders/{order_id}" + response = _request("GET", url, headers=_auth_headers(api_key, access_token)) + return response.get("data", []) + + +def cancel_order( + api_key: str, + access_token: str, + *, + order_id: str, + variety: str = "regular", +) -> dict: + url = f"{KITE_API_BASE}/orders/{variety}/{order_id}" + response = _request("DELETE", url, headers=_auth_headers(api_key, access_token)) + return response.get("data", {}) diff --git a/indian_paper_trading_strategy/engine/broker.py b/indian_paper_trading_strategy/engine/broker.py index b443eba..374e924 100644 --- a/indian_paper_trading_strategy/engine/broker.py +++ b/indian_paper_trading_strategy/engine/broker.py @@ -1,22 +1,27 @@ -from __future__ import annotations - -from abc import ABC, abstractmethod -from dataclasses import dataclass -from datetime import datetime, timezone -import hashlib - -from psycopg2.extras import execute_values - -from indian_paper_trading_strategy.engine.data import fetch_live_price -from indian_paper_trading_strategy.engine.db import db_connection, insert_engine_event, run_with_retry, get_context +from __future__ import annotations + +from abc import ABC, abstractmethod +from dataclasses import dataclass +from datetime import datetime, timezone +import hashlib +import math +import os +import time + +from psycopg2.extras import execute_values + +from indian_paper_trading_strategy.engine.data import fetch_live_price +from indian_paper_trading_strategy.engine.db import db_connection, insert_engine_event, run_with_retry, get_context -class Broker(ABC): - @abstractmethod - def place_order( - self, - symbol: str, - side: str, +class Broker(ABC): + external_orders = False + + @abstractmethod + def place_order( + self, + symbol: str, + side: str, quantity: float, price: float | None = None, logical_time: datetime | None = None, @@ -32,8 +37,16 @@ class Broker(ABC): raise NotImplementedError @abstractmethod - def get_funds(self): - raise NotImplementedError + def get_funds(self): + raise NotImplementedError + + +class BrokerError(Exception): + pass + + +class BrokerAuthExpired(BrokerError): + pass def _local_tz(): @@ -98,12 +111,332 @@ def _deterministic_id(prefix: str, parts: list[str]) -> str: return f"{prefix}_{digest}" -def _resolve_scope(user_id: str | None, run_id: str | None): - return get_context(user_id, run_id) - - -@dataclass -class PaperBroker(Broker): +def _resolve_scope(user_id: str | None, run_id: str | None): + return get_context(user_id, run_id) + + +class LiveZerodhaBroker(Broker): + external_orders = True + + TERMINAL_STATUSES = {"COMPLETE", "REJECTED", "CANCELLED"} + POLL_TIMEOUT_SECONDS = float(os.getenv("ZERODHA_ORDER_POLL_TIMEOUT", "12")) + POLL_INTERVAL_SECONDS = float(os.getenv("ZERODHA_ORDER_POLL_INTERVAL", "1")) + + def __init__(self, user_id: str | None = None, run_id: str | None = None): + self.user_id = user_id + self.run_id = run_id + + def _scope(self): + return _resolve_scope(self.user_id, self.run_id) + + def _session(self): + from app.services.zerodha_storage import get_session + + user_id, _run_id = self._scope() + session = get_session(user_id) + if not session or not session.get("api_key") or not session.get("access_token"): + raise BrokerAuthExpired("Zerodha session missing. Please reconnect broker.") + return session + + def _raise_auth_expired(self, exc: Exception): + from app.broker_store import set_broker_auth_state + + user_id, _run_id = self._scope() + set_broker_auth_state(user_id, "EXPIRED") + raise BrokerAuthExpired(str(exc)) from exc + + def _normalize_symbol(self, symbol: str) -> tuple[str, str]: + cleaned = (symbol or "").strip().upper() + if cleaned.endswith(".NS"): + return cleaned[:-3], "NSE" + if cleaned.endswith(".BO"): + return cleaned[:-3], "BSE" + return cleaned, "NSE" + + def _make_tag(self, logical_time: datetime | None, symbol: str, side: str) -> str: + user_id, run_id = self._scope() + logical_ts = logical_time or datetime.utcnow().replace(tzinfo=timezone.utc) + digest = hashlib.sha1( + f"{user_id}|{run_id}|{_normalize_ts_for_id(logical_ts)}|{symbol}|{side}".encode("utf-8") + ).hexdigest()[:18] + return f"qf{digest}" + + def _normalize_order_payload( + self, + *, + order_id: str, + symbol: str, + side: str, + requested_qty: int, + requested_price: float | None, + history_entry: dict | None, + logical_time: datetime | None, + ) -> dict: + entry = history_entry or {} + raw_status = (entry.get("status") or "").upper() + status = raw_status + if raw_status == "COMPLETE": + status = "FILLED" + elif raw_status in {"REJECTED", "CANCELLED"}: + status = raw_status + elif raw_status: + status = "PENDING" + else: + status = "PENDING" + + quantity = int(entry.get("quantity") or requested_qty or 0) + filled_qty = int(entry.get("filled_quantity") or 0) + average_price = float(entry.get("average_price") or requested_price or 0.0) + price = float(entry.get("price") or requested_price or average_price or 0.0) + timestamp = ( + entry.get("exchange_timestamp") + or entry.get("order_timestamp") + or _format_utc_ts(logical_time or datetime.utcnow().replace(tzinfo=timezone.utc)) + ) + if timestamp and " " in str(timestamp): + timestamp = str(timestamp).replace(" ", "T") + + return { + "id": order_id, + "symbol": symbol, + "side": side.upper().strip(), + "qty": quantity, + "requested_qty": quantity, + "filled_qty": filled_qty, + "price": price, + "requested_price": float(requested_price or price or 0.0), + "average_price": average_price, + "status": status, + "timestamp": timestamp, + "broker_order_id": order_id, + "exchange": entry.get("exchange"), + "tradingsymbol": entry.get("tradingsymbol"), + "status_message": entry.get("status_message") or entry.get("status_message_raw"), + } + + def _wait_for_terminal_order( + self, + session: dict, + order_id: str, + *, + symbol: str, + side: str, + requested_qty: int, + requested_price: float | None, + logical_time: datetime | None, + ) -> dict: + from app.services.zerodha_service import ( + KiteTokenError, + cancel_order, + fetch_order_history, + ) + + started = time.monotonic() + last_payload = self._normalize_order_payload( + order_id=order_id, + symbol=symbol, + side=side, + requested_qty=requested_qty, + requested_price=requested_price, + history_entry=None, + logical_time=logical_time, + ) + + while True: + try: + history = fetch_order_history( + session["api_key"], + session["access_token"], + order_id, + ) + except KiteTokenError as exc: + self._raise_auth_expired(exc) + + if history: + entry = history[-1] + last_payload = self._normalize_order_payload( + order_id=order_id, + symbol=symbol, + side=side, + requested_qty=requested_qty, + requested_price=requested_price, + history_entry=entry, + logical_time=logical_time, + ) + raw_status = (entry.get("status") or "").upper() + if raw_status in self.TERMINAL_STATUSES: + return last_payload + + if time.monotonic() - started >= self.POLL_TIMEOUT_SECONDS: + try: + cancel_order( + session["api_key"], + session["access_token"], + order_id=order_id, + ) + history = fetch_order_history( + session["api_key"], + session["access_token"], + order_id, + ) + if history: + return self._normalize_order_payload( + order_id=order_id, + symbol=symbol, + side=side, + requested_qty=requested_qty, + requested_price=requested_price, + history_entry=history[-1], + logical_time=logical_time, + ) + except KiteTokenError as exc: + self._raise_auth_expired(exc) + return last_payload + + time.sleep(self.POLL_INTERVAL_SECONDS) + + def get_funds(self, cur=None): + from app.services.zerodha_service import KiteTokenError, fetch_funds + + session = self._session() + try: + data = fetch_funds(session["api_key"], session["access_token"]) + except KiteTokenError as exc: + self._raise_auth_expired(exc) + + equity = data.get("equity", {}) if isinstance(data, dict) else {} + available = equity.get("available", {}) if isinstance(equity, dict) else {} + cash = ( + available.get("live_balance") + or available.get("cash") + or available.get("opening_balance") + or equity.get("net") + or 0.0 + ) + return {"cash": float(cash), "raw": data} + + def get_positions(self): + from app.services.zerodha_service import KiteTokenError, fetch_holdings + + session = self._session() + try: + holdings = fetch_holdings(session["api_key"], session["access_token"]) + except KiteTokenError as exc: + self._raise_auth_expired(exc) + + normalized = [] + for item in holdings: + qty = float(item.get("quantity") or item.get("qty") or 0) + avg_price = float(item.get("average_price") or 0) + last_price = float(item.get("last_price") or avg_price or 0) + exchange = item.get("exchange") + suffix = ".NS" if exchange == "NSE" else ".BO" if exchange == "BSE" else "" + normalized.append( + { + "symbol": f"{item.get('tradingsymbol')}{suffix}", + "qty": qty, + "avg_price": avg_price, + "last_price": last_price, + } + ) + return normalized + + def get_orders(self): + from app.services.zerodha_service import KiteTokenError, fetch_orders + + session = self._session() + try: + return fetch_orders(session["api_key"], session["access_token"]) + except KiteTokenError as exc: + self._raise_auth_expired(exc) + + def update_equity( + self, + prices: dict[str, float], + now: datetime, + cur=None, + logical_time: datetime | None = None, + user_id: str | None = None, + run_id: str | None = None, + ): + return None + + def place_order( + self, + symbol: str, + side: str, + quantity: float, + price: float | None = None, + cur=None, + logical_time: datetime | None = None, + user_id: str | None = None, + run_id: str | None = None, + ): + from app.services.zerodha_service import KiteTokenError, place_order + + if user_id is not None: + self.user_id = user_id + if run_id is not None: + self.run_id = run_id + + qty = int(math.floor(float(quantity))) + side = side.upper().strip() + requested_price = float(price) if price is not None else None + if qty <= 0: + return { + "id": _deterministic_id("live_rej", [symbol, side, _stable_num(quantity)]), + "symbol": symbol, + "side": side, + "qty": qty, + "requested_qty": qty, + "filled_qty": 0, + "price": float(price or 0.0), + "requested_price": float(price or 0.0), + "average_price": 0.0, + "status": "REJECTED", + "timestamp": _format_utc_ts(logical_time or datetime.utcnow().replace(tzinfo=timezone.utc)), + "status_message": "Computed quantity is less than 1 share", + } + + session = self._session() + tradingsymbol, exchange = self._normalize_symbol(symbol) + tag = self._make_tag(logical_time, symbol, side) + + try: + placed = place_order( + session["api_key"], + session["access_token"], + tradingsymbol=tradingsymbol, + exchange=exchange, + transaction_type=side, + order_type="MARKET", + quantity=qty, + product="CNC", + validity="DAY", + variety="regular", + market_protection=-1, + tag=tag, + ) + except KiteTokenError as exc: + self._raise_auth_expired(exc) + + order_id = placed.get("order_id") + if not order_id: + raise BrokerError("Zerodha order placement did not return an order_id") + + return self._wait_for_terminal_order( + session, + order_id, + symbol=symbol, + side=side, + requested_qty=qty, + requested_price=requested_price, + logical_time=logical_time, + ) + + +@dataclass +class PaperBroker(Broker): initial_cash: float store_path: str | None = None @@ -578,16 +911,20 @@ class PaperBroker(Broker): ], ) - order = { - "id": order_id, - "symbol": symbol, - "side": side, - "qty": qty, - "price": price, - "status": "REJECTED", - "timestamp": timestamp_str, - "_logical_time": logical_ts_str, - } + order = { + "id": order_id, + "symbol": symbol, + "side": side, + "qty": qty, + "requested_qty": qty, + "filled_qty": 0.0, + "price": price, + "requested_price": price, + "average_price": 0.0, + "status": "REJECTED", + "timestamp": timestamp_str, + "_logical_time": logical_ts_str, + } if qty <= 0 or price <= 0: store.setdefault("orders", []).append(order) @@ -607,16 +944,18 @@ class PaperBroker(Broker): new_qty = float(existing.get("qty", 0)) + qty prev_cost = float(existing.get("qty", 0)) * float(existing.get("avg_price", 0)) avg_price = (prev_cost + cost) / new_qty if new_qty else price - positions[symbol] = { - "qty": new_qty, - "avg_price": avg_price, - "last_price": price, - } - order["status"] = "FILLED" - trade = { - "id": _deterministic_id("trd", [order_id]), - "order_id": order_id, - "symbol": symbol, + positions[symbol] = { + "qty": new_qty, + "avg_price": avg_price, + "last_price": price, + } + order["status"] = "FILLED" + order["filled_qty"] = qty + order["average_price"] = price + trade = { + "id": _deterministic_id("trd", [order_id]), + "order_id": order_id, + "symbol": symbol, "side": side, "qty": qty, "price": price, @@ -634,12 +973,14 @@ class PaperBroker(Broker): existing["last_price"] = price positions[symbol] = existing else: - positions.pop(symbol, None) - order["status"] = "FILLED" - trade = { - "id": _deterministic_id("trd", [order_id]), - "order_id": order_id, - "symbol": symbol, + positions.pop(symbol, None) + order["status"] = "FILLED" + order["filled_qty"] = qty + order["average_price"] = price + trade = { + "id": _deterministic_id("trd", [order_id]), + "order_id": order_id, + "symbol": symbol, "side": side, "qty": qty, "price": price, diff --git a/indian_paper_trading_strategy/engine/data.py b/indian_paper_trading_strategy/engine/data.py index d26917c..ba0793d 100644 --- a/indian_paper_trading_strategy/engine/data.py +++ b/indian_paper_trading_strategy/engine/data.py @@ -1,81 +1,110 @@ # engine/data.py -from datetime import datetime, timezone -from pathlib import Path -import os -import threading - -import pandas as pd -import yfinance as yf +from datetime import datetime, timezone +from pathlib import Path +import os +import threading + +import pandas as pd +import yfinance as yf ENGINE_ROOT = Path(__file__).resolve().parents[1] HISTORY_DIR = ENGINE_ROOT / "storage" / "history" ALLOW_PRICE_CACHE = os.getenv("ALLOW_PRICE_CACHE", "0").strip().lower() in {"1", "true", "yes"} -_LAST_PRICE: dict[str, dict[str, object]] = {} -_LAST_PRICE_LOCK = threading.Lock() - - -def _set_last_price(ticker: str, price: float, source: str): - now = datetime.now(timezone.utc) - with _LAST_PRICE_LOCK: - _LAST_PRICE[ticker] = {"price": float(price), "source": source, "ts": now} - - -def get_price_snapshot(ticker: str) -> dict[str, object] | None: +_LAST_PRICE: dict[str, dict[str, object]] = {} +_LAST_PRICE_LOCK = threading.Lock() + + +def _history_cache_file(ticker: str, provider: str = "yfinance") -> Path: + safe_ticker = (ticker or "").replace(":", "_").replace("/", "_") + return HISTORY_DIR / f"{safe_ticker}.csv" + + +def _set_last_price( + ticker: str, + price: float, + source: str, + *, + provider: str | None = None, + instrument_token: int | None = None, +): + now = datetime.now(timezone.utc) + with _LAST_PRICE_LOCK: + payload = {"price": float(price), "source": source, "ts": now} + if provider: + payload["provider"] = provider + if instrument_token is not None: + payload["instrument_token"] = int(instrument_token) + _LAST_PRICE[ticker] = payload + + +def get_price_snapshot(ticker: str) -> dict[str, object] | None: with _LAST_PRICE_LOCK: data = _LAST_PRICE.get(ticker) if not data: return None - return dict(data) - - -def _get_last_live_price(ticker: str) -> float | None: - with _LAST_PRICE_LOCK: - data = _LAST_PRICE.get(ticker) - if not data: - return None - if data.get("source") == "live": - return float(data.get("price", 0)) - return None - - -def _cached_last_close(ticker: str) -> float | None: - file = HISTORY_DIR / f"{ticker}.csv" - if not file.exists(): - return None - df = pd.read_csv(file) - if df.empty or "Close" not in df.columns: - return None - return float(df["Close"].iloc[-1]) - - -def fetch_live_price(ticker, allow_cache: bool | None = None): - if allow_cache is None: - allow_cache = ALLOW_PRICE_CACHE - try: - df = yf.download( - ticker, - period="1d", - interval="1m", + return dict(data) + + +def _get_last_live_price(ticker: str, provider: str | None = None) -> float | None: + with _LAST_PRICE_LOCK: + data = _LAST_PRICE.get(ticker) + if not data: + return None + if data.get("source") == "live": + if provider and data.get("provider") not in {None, provider}: + return None + return float(data.get("price", 0)) + return None + + +def _cached_last_close(ticker: str, provider: str = "yfinance") -> float | None: + file = _history_cache_file(ticker, provider=provider) + if not file.exists(): + return None + df = pd.read_csv(file) + if df.empty or "Close" not in df.columns: + return None + return float(df["Close"].iloc[-1]) + + +def fetch_live_price( + ticker, + allow_cache: bool | None = None, + *, + provider: str = "yfinance", + user_id: str | None = None, + run_id: str | None = None, +): + if allow_cache is None: + allow_cache = ALLOW_PRICE_CACHE + try: + df = yf.download( + ticker, + period="1d", + interval="1m", auto_adjust=True, progress=False, timeout=5, - ) - if df is not None and not df.empty: - price = float(df["Close"].iloc[-1]) - _set_last_price(ticker, price, "live") - return price - except Exception: - pass - - if allow_cache: - last_live = _get_last_live_price(ticker) - if last_live is not None: - return last_live - - cached = _cached_last_close(ticker) - if cached is not None: - _set_last_price(ticker, cached, "cache") - return cached - - raise RuntimeError(f"No live data for {ticker}") + ) + if df is not None and not df.empty: + close_value = df["Close"].iloc[-1] + if hasattr(close_value, "iloc"): + close_value = close_value.iloc[-1] + price = float(close_value) + _set_last_price(ticker, price, "live", provider="yfinance") + return price + except Exception: + pass + + if allow_cache: + last_live = _get_last_live_price(ticker, provider="yfinance") + if last_live is not None: + return last_live + + cached = _cached_last_close(ticker, provider="yfinance") + if cached is not None: + _set_last_price(ticker, cached, "cache", provider="yfinance") + return cached + + raise RuntimeError(f"No live data for {ticker}") diff --git a/indian_paper_trading_strategy/engine/execution.py b/indian_paper_trading_strategy/engine/execution.py index 9463e09..e2d5e47 100644 --- a/indian_paper_trading_strategy/engine/execution.py +++ b/indian_paper_trading_strategy/engine/execution.py @@ -1,11 +1,11 @@ -# engine/execution.py -from datetime import datetime, timezone -from indian_paper_trading_strategy.engine.state import load_state, save_state - -from indian_paper_trading_strategy.engine.broker import Broker -from indian_paper_trading_strategy.engine.ledger import log_event, event_exists -from indian_paper_trading_strategy.engine.db import run_with_retry -from indian_paper_trading_strategy.engine.time_utils import compute_logical_time +# engine/execution.py +from datetime import datetime, timezone +from indian_paper_trading_strategy.engine.state import load_state, save_state + +from indian_paper_trading_strategy.engine.broker import Broker, BrokerAuthExpired +from indian_paper_trading_strategy.engine.ledger import log_event, event_exists +from indian_paper_trading_strategy.engine.db import insert_engine_event, run_with_retry +from indian_paper_trading_strategy.engine.time_utils import compute_logical_time def _as_float(value): if hasattr(value, "item"): @@ -20,138 +20,457 @@ def _as_float(value): pass return float(value) -def _local_tz(): - return datetime.now().astimezone().tzinfo - -def try_execute_sip( - now, - market_open, +def _local_tz(): + return datetime.now().astimezone().tzinfo + + +def _normalize_now(now): + if now.tzinfo is None: + return now.replace(tzinfo=_local_tz()) + return now + + +def _resolve_timing(state, now_ts, sip_interval): + force_execute = state.get("last_sip_ts") is None + last = state.get("last_sip_ts") or state.get("last_run") + if last and not force_execute: + try: + last_dt = datetime.fromisoformat(last) + except ValueError: + last_dt = None + if last_dt: + if last_dt.tzinfo is None: + last_dt = last_dt.replace(tzinfo=_local_tz()) + if now_ts.tzinfo and last_dt.tzinfo and last_dt.tzinfo != now_ts.tzinfo: + last_dt = last_dt.astimezone(now_ts.tzinfo) + if last_dt and (now_ts - last_dt).total_seconds() < sip_interval: + return False, last, None + logical_time = compute_logical_time(now_ts, last, sip_interval) + return True, last, logical_time + + +def _order_fill(order): + if not isinstance(order, dict): + return 0.0, 0.0 + filled_qty = float(order.get("filled_qty") or 0.0) + average_price = float(order.get("average_price") or order.get("price") or 0.0) + return filled_qty, average_price + + +def _apply_filled_orders_to_state(state, orders): + nifty_filled = 0.0 + gold_filled = 0.0 + total_spent = 0.0 + + for order in orders: + filled_qty, average_price = _order_fill(order) + if filled_qty <= 0: + continue + symbol = (order.get("symbol") or "").upper() + if symbol.startswith("NIFTYBEES"): + nifty_filled += filled_qty + elif symbol.startswith("GOLDBEES"): + gold_filled += filled_qty + total_spent += filled_qty * average_price + + if nifty_filled: + state["nifty_units"] += nifty_filled + if gold_filled: + state["gold_units"] += gold_filled + if total_spent: + state["total_invested"] += total_spent + + return { + "nifty_units": nifty_filled, + "gold_units": gold_filled, + "amount": total_spent, + } + + +def _record_live_order_events(cur, orders, event_ts): + for order in orders: + insert_engine_event(cur, "ORDER_PLACED", data=order, ts=event_ts) + filled_qty, _average_price = _order_fill(order) + status = (order.get("status") or "").upper() + if filled_qty > 0: + insert_engine_event( + cur, + "TRADE_EXECUTED", + data={ + "order_id": order.get("id"), + "symbol": order.get("symbol"), + "side": order.get("side"), + "qty": filled_qty, + "price": order.get("average_price") or order.get("price"), + }, + ts=event_ts, + ) + insert_engine_event(cur, "ORDER_FILLED", data={"order_id": order.get("id")}, ts=event_ts) + elif status == "REJECTED": + insert_engine_event(cur, "ORDER_REJECTED", data=order, ts=event_ts) + elif status == "CANCELLED": + insert_engine_event(cur, "ORDER_CANCELLED", data=order, ts=event_ts) + elif status == "PENDING": + insert_engine_event(cur, "ORDER_PENDING", data=order, ts=event_ts) + + +def _try_execute_sip_paper( + now, + market_open, + sip_interval, + sip_amount, + sp_price, + gd_price, + eq_w, + gd_w, + broker: Broker | None, + mode: str | None, +): + def _op(cur, _conn): + now_ts = _normalize_now(now) + event_ts = now_ts + log_event("DEBUG_ENTER_TRY_EXECUTE", {"now": now_ts.isoformat()}, cur=cur, ts=event_ts) + + state = load_state(mode=mode, cur=cur, for_update=True) + + if not market_open: + return state, False + + should_run, _last, logical_time = _resolve_timing(state, now_ts, sip_interval) + if not should_run: + return state, False + if event_exists("SIP_EXECUTED", logical_time, cur=cur): + return state, False + + sp_price_val = _as_float(sp_price) + gd_price_val = _as_float(gd_price) + eq_w_val = _as_float(eq_w) + gd_w_val = _as_float(gd_w) + sip_amount_val = _as_float(sip_amount) + + nifty_qty = (sip_amount_val * eq_w_val) / sp_price_val + gold_qty = (sip_amount_val * gd_w_val) / gd_price_val + + if broker is None: + return state, False + + funds = broker.get_funds(cur=cur) + cash = funds.get("cash") + if cash is not None and float(cash) < sip_amount_val: + return state, False + + log_event( + "DEBUG_EXECUTION_DECISION", + { + "last_sip_ts": state.get("last_sip_ts"), + "now": now_ts.isoformat(), + }, + cur=cur, + ts=event_ts, + ) + + orders = [ + broker.place_order( + "NIFTYBEES.NS", + "BUY", + nifty_qty, + sp_price_val, + cur=cur, + logical_time=logical_time, + ), + broker.place_order( + "GOLDBEES.NS", + "BUY", + gold_qty, + gd_price_val, + cur=cur, + logical_time=logical_time, + ), + ] + + applied = _apply_filled_orders_to_state(state, orders) + executed = applied["amount"] > 0 + if not executed: + return state, False + + funds_after = broker.get_funds(cur=cur) + cash_after = funds_after.get("cash") + if cash_after is not None: + state["cash"] = float(cash_after) + + state["last_sip_ts"] = now_ts.isoformat() + state["last_run"] = now_ts.isoformat() + + save_state( + state, + mode=mode, + cur=cur, + emit_event=True, + event_meta={"source": "sip"}, + ) + + log_event( + "SIP_EXECUTED", + { + "nifty_units": applied["nifty_units"], + "gold_units": applied["gold_units"], + "nifty_price": sp_price_val, + "gold_price": gd_price_val, + "amount": applied["amount"], + }, + cur=cur, + ts=event_ts, + logical_time=logical_time, + ) + + return state, True + + return run_with_retry(_op) + + +def _prepare_live_execution(now_ts, sip_interval, sip_amount_val, sp_price_val, gd_price_val, nifty_qty, gold_qty, mode): + def _op(cur, _conn): + state = load_state(mode=mode, cur=cur, for_update=True) + should_run, _last, logical_time = _resolve_timing(state, now_ts, sip_interval) + if not should_run: + return {"ready": False, "state": state} + if event_exists("SIP_EXECUTED", logical_time, cur=cur): + return {"ready": False, "state": state} + if event_exists("SIP_ORDER_ATTEMPTED", logical_time, cur=cur): + return {"ready": False, "state": state} + + log_event( + "SIP_ORDER_ATTEMPTED", + { + "nifty_units": nifty_qty, + "gold_units": gold_qty, + "nifty_price": sp_price_val, + "gold_price": gd_price_val, + "amount": sip_amount_val, + }, + cur=cur, + ts=now_ts, + logical_time=logical_time, + ) + return {"ready": True, "state": state, "logical_time": logical_time} + + return run_with_retry(_op) + + +def _finalize_live_execution( + *, + now_ts, + mode, + logical_time, + orders, + funds_after, + sp_price_val, + gd_price_val, + auth_failed: bool = False, + failure_reason: str | None = None, +): + def _op(cur, _conn): + state = load_state(mode=mode, cur=cur, for_update=True) + _record_live_order_events(cur, orders, now_ts) + + applied = _apply_filled_orders_to_state(state, orders) + executed = applied["amount"] > 0 + + if funds_after is not None: + cash_after = funds_after.get("cash") + if cash_after is not None: + state["cash"] = float(cash_after) + + if executed: + state["last_run"] = now_ts.isoformat() + state["last_sip_ts"] = now_ts.isoformat() + + save_state( + state, + mode=mode, + cur=cur, + emit_event=True, + event_meta={"source": "sip_live"}, + ) + + if executed: + log_event( + "SIP_EXECUTED", + { + "nifty_units": applied["nifty_units"], + "gold_units": applied["gold_units"], + "nifty_price": sp_price_val, + "gold_price": gd_price_val, + "amount": applied["amount"], + }, + cur=cur, + ts=now_ts, + logical_time=logical_time, + ) + else: + insert_engine_event( + cur, + "SIP_NO_FILL", + data={ + "reason": failure_reason or ("broker_auth_expired" if auth_failed else "no_fill"), + "orders": orders, + }, + ts=now_ts, + ) + + return state, executed + + return run_with_retry(_op) + + +def _try_execute_sip_live( + now, + market_open, + sip_interval, + sip_amount, + sp_price, + gd_price, + eq_w, + gd_w, + broker: Broker | None, + mode: str | None, +): + now_ts = _normalize_now(now) + if not market_open or broker is None: + return load_state(mode=mode), False + + sp_price_val = _as_float(sp_price) + gd_price_val = _as_float(gd_price) + eq_w_val = _as_float(eq_w) + gd_w_val = _as_float(gd_w) + sip_amount_val = _as_float(sip_amount) + + nifty_qty = (sip_amount_val * eq_w_val) / sp_price_val + gold_qty = (sip_amount_val * gd_w_val) / gd_price_val + + prepared = _prepare_live_execution( + now_ts, + sip_interval, + sip_amount_val, + sp_price_val, + gd_price_val, + nifty_qty, + gold_qty, + mode, + ) + if not prepared.get("ready"): + return prepared.get("state") or load_state(mode=mode), False + + logical_time = prepared["logical_time"] + orders = [] + funds_after = None + failure_reason = None + auth_failed = False + + try: + funds_before = broker.get_funds() + cash = funds_before.get("cash") + if cash is not None and float(cash) < sip_amount_val: + failure_reason = "insufficient_funds" + else: + if nifty_qty > 0: + orders.append( + broker.place_order( + "NIFTYBEES.NS", + "BUY", + nifty_qty, + sp_price_val, + logical_time=logical_time, + ) + ) + if gold_qty > 0: + orders.append( + broker.place_order( + "GOLDBEES.NS", + "BUY", + gold_qty, + gd_price_val, + logical_time=logical_time, + ) + ) + funds_after = broker.get_funds() + except BrokerAuthExpired: + auth_failed = True + try: + funds_after = broker.get_funds() + except Exception: + funds_after = None + except Exception as exc: + failure_reason = str(exc) + try: + funds_after = broker.get_funds() + except Exception: + funds_after = None + state, _executed = _finalize_live_execution( + now_ts=now_ts, + mode=mode, + logical_time=logical_time, + orders=orders, + funds_after=funds_after, + sp_price_val=sp_price_val, + gd_price_val=gd_price_val, + auth_failed=False, + failure_reason=failure_reason, + ) + raise + + state, executed = _finalize_live_execution( + now_ts=now_ts, + mode=mode, + logical_time=logical_time, + orders=orders, + funds_after=funds_after, + sp_price_val=sp_price_val, + gd_price_val=gd_price_val, + auth_failed=auth_failed, + failure_reason=failure_reason, + ) + if auth_failed: + raise BrokerAuthExpired("Broker session expired during live order execution") + return state, executed + +def try_execute_sip( + now, + market_open, sip_interval, sip_amount, sp_price, gd_price, eq_w, - gd_w, - broker: Broker | None = None, - mode: str | None = "LIVE", -): - def _op(cur, _conn): - if now.tzinfo is None: - now_ts = now.replace(tzinfo=_local_tz()) - else: - now_ts = now - event_ts = now_ts - log_event("DEBUG_ENTER_TRY_EXECUTE", { - "now": now_ts.isoformat(), - }, cur=cur, ts=event_ts) - - state = load_state(mode=mode, cur=cur, for_update=True) - - force_execute = state.get("last_sip_ts") is None - - if not market_open: - return state, False - - last = state.get("last_sip_ts") or state.get("last_run") - if last and not force_execute: - try: - last_dt = datetime.fromisoformat(last) - except ValueError: - last_dt = None - if last_dt: - if last_dt.tzinfo is None: - last_dt = last_dt.replace(tzinfo=_local_tz()) - if now_ts.tzinfo and last_dt.tzinfo and last_dt.tzinfo != now_ts.tzinfo: - last_dt = last_dt.astimezone(now_ts.tzinfo) - if last_dt and (now_ts - last_dt).total_seconds() < sip_interval: - return state, False - - logical_time = compute_logical_time(now_ts, last, sip_interval) - if event_exists("SIP_EXECUTED", logical_time, cur=cur): - return state, False - - sp_price_val = _as_float(sp_price) - gd_price_val = _as_float(gd_price) - eq_w_val = _as_float(eq_w) - gd_w_val = _as_float(gd_w) - sip_amount_val = _as_float(sip_amount) - - nifty_qty = (sip_amount_val * eq_w_val) / sp_price_val - gold_qty = (sip_amount_val * gd_w_val) / gd_price_val - - if broker is None: - return state, False - - funds = broker.get_funds(cur=cur) - cash = funds.get("cash") - if cash is not None and float(cash) < sip_amount_val: - return state, False - - log_event("DEBUG_EXECUTION_DECISION", { - "force_execute": force_execute, - "last_sip_ts": state.get("last_sip_ts"), - "now": now_ts.isoformat(), - }, cur=cur, ts=event_ts) - - nifty_order = broker.place_order( - "NIFTYBEES.NS", - "BUY", - nifty_qty, - sp_price_val, - cur=cur, - logical_time=logical_time, - ) - gold_order = broker.place_order( - "GOLDBEES.NS", - "BUY", - gold_qty, - gd_price_val, - cur=cur, - logical_time=logical_time, - ) - orders = [nifty_order, gold_order] - executed = all( - isinstance(order, dict) and order.get("status") == "FILLED" - for order in orders - ) - if not executed: - return state, False - assert len(orders) > 0, "executed=True but no broker orders placed" - - funds_after = broker.get_funds(cur=cur) - cash_after = funds_after.get("cash") - if cash_after is not None: - state["cash"] = float(cash_after) - - state["nifty_units"] += nifty_qty - state["gold_units"] += gold_qty - state["total_invested"] += sip_amount_val - state["last_sip_ts"] = now_ts.isoformat() - state["last_run"] = now_ts.isoformat() - - save_state( - state, - mode=mode, - cur=cur, - emit_event=True, - event_meta={"source": "sip"}, - ) - - log_event( - "SIP_EXECUTED", - { - "nifty_units": nifty_qty, - "gold_units": gold_qty, - "nifty_price": sp_price_val, - "gold_price": gd_price_val, - "amount": sip_amount_val, - }, - cur=cur, - ts=event_ts, - logical_time=logical_time, - ) - - return state, True - - return run_with_retry(_op) + gd_w, + broker: Broker | None = None, + mode: str | None = "LIVE", +): + if broker is None: + return load_state(mode=mode), False + if getattr(broker, "external_orders", False): + return _try_execute_sip_live( + now, + market_open, + sip_interval, + sip_amount, + sp_price, + gd_price, + eq_w, + gd_w, + broker, + mode, + ) + return _try_execute_sip_paper( + now, + market_open, + sip_interval, + sip_amount, + sp_price, + gd_price, + eq_w, + gd_w, + broker, + mode, + ) diff --git a/indian_paper_trading_strategy/engine/history.py b/indian_paper_trading_strategy/engine/history.py index cf6ecf8..503a1b0 100644 --- a/indian_paper_trading_strategy/engine/history.py +++ b/indian_paper_trading_strategy/engine/history.py @@ -1,34 +1,48 @@ # engine/history.py -import yfinance as yf -import pandas as pd -from pathlib import Path - -ENGINE_ROOT = Path(__file__).resolve().parents[1] -STORAGE_DIR = ENGINE_ROOT / "storage" -STORAGE_DIR.mkdir(exist_ok=True) - -CACHE_DIR = STORAGE_DIR / "history" -CACHE_DIR.mkdir(exist_ok=True) - -def load_monthly_close(ticker, years=10): - file = CACHE_DIR / f"{ticker}.csv" - - if file.exists(): - df = pd.read_csv(file, parse_dates=["Date"], index_col="Date") - return df["Close"] - - df = yf.download( - ticker, - period=f"{years}y", - auto_adjust=True, +import yfinance as yf +import pandas as pd +from pathlib import Path + +ENGINE_ROOT = Path(__file__).resolve().parents[1] +STORAGE_DIR = ENGINE_ROOT / "storage" +STORAGE_DIR.mkdir(exist_ok=True) + +CACHE_DIR = STORAGE_DIR / "history" +CACHE_DIR.mkdir(exist_ok=True) + +def _cache_file(ticker: str, provider: str = "yfinance") -> Path: + safe_ticker = (ticker or "").replace(":", "_").replace("/", "_") + return CACHE_DIR / f"{safe_ticker}.csv" + + +def _read_monthly_close(file: Path): + df = pd.read_csv(file, parse_dates=["Date"], index_col="Date") + return df["Close"] +def load_monthly_close( + ticker, + years=10, + *, + provider: str = "yfinance", + user_id: str | None = None, + run_id: str | None = None, +): + file = _cache_file(ticker, provider="yfinance") + + if file.exists(): + return _read_monthly_close(file) + + df = yf.download( + ticker, + period=f"{years}y", + auto_adjust=True, progress=False, timeout=5, ) - if df.empty: - raise RuntimeError(f"No history for {ticker}") - - series = df["Close"].resample("M").last() - series.to_csv(file, header=["Close"]) + if df.empty: + raise RuntimeError(f"No history for {ticker}") + + series = df["Close"].resample("ME").last() + series.to_csv(file, header=["Close"]) return series diff --git a/indian_paper_trading_strategy/engine/runner.py b/indian_paper_trading_strategy/engine/runner.py index bd89f0c..7155703 100644 --- a/indian_paper_trading_strategy/engine/runner.py +++ b/indian_paper_trading_strategy/engine/runner.py @@ -1,17 +1,20 @@ -import os -import threading -import time -from datetime import datetime, timedelta, timezone - -from indian_paper_trading_strategy.engine.market import is_market_open, align_to_market_open -from indian_paper_trading_strategy.engine.execution import try_execute_sip -from indian_paper_trading_strategy.engine.broker import PaperBroker -from indian_paper_trading_strategy.engine.mtm import log_mtm, should_log_mtm -from indian_paper_trading_strategy.engine.state import load_state -from indian_paper_trading_strategy.engine.data import fetch_live_price -from indian_paper_trading_strategy.engine.history import load_monthly_close -from indian_paper_trading_strategy.engine.strategy import allocation -from indian_paper_trading_strategy.engine.time_utils import normalize_logical_time +import os +import threading +import time +from datetime import datetime, timedelta, timezone + +from psycopg2.extras import Json + +from indian_paper_trading_strategy.engine.market import is_market_open, align_to_market_open +from indian_paper_trading_strategy.engine.execution import try_execute_sip +from indian_paper_trading_strategy.engine.broker import PaperBroker, LiveZerodhaBroker, BrokerAuthExpired +from indian_paper_trading_strategy.engine.mtm import log_mtm, should_log_mtm +from indian_paper_trading_strategy.engine.state import load_state +from indian_paper_trading_strategy.engine.data import fetch_live_price +from indian_paper_trading_strategy.engine.history import load_monthly_close +from indian_paper_trading_strategy.engine.strategy import allocation +from indian_paper_trading_strategy.engine.time_utils import normalize_logical_time +from app.services.zerodha_service import KiteTokenError from indian_paper_trading_strategy.engine.db import db_transaction, insert_engine_event, run_with_retry, get_context, set_context @@ -134,12 +137,75 @@ def _clear_runner(user_id: str, run_id: str): with _RUNNERS_LOCK: _RUNNERS.pop(key, None) -def can_execute(now: datetime) -> tuple[bool, str]: - if not is_market_open(now): - return False, "MARKET_CLOSED" - return True, "OK" - -def _engine_loop(config, stop_event: threading.Event): +def can_execute(now: datetime) -> tuple[bool, str]: + if not is_market_open(now): + return False, "MARKET_CLOSED" + return True, "OK" + + +def _last_execution_anchor(state: dict, mode: str) -> str | None: + mode_key = (mode or "LIVE").strip().upper() + if mode_key == "LIVE": + return state.get("last_sip_ts") + return state.get("last_run") or state.get("last_sip_ts") + + +def _pause_for_auth_expiry( + user_id: str, + run_id: str, + reason: str, + emit_event_cb=None, +): + def _op(cur, _conn): + now = datetime.utcnow().replace(tzinfo=timezone.utc) + cur.execute( + """ + UPDATE strategy_run + SET status = 'STOPPED', + stopped_at = %s, + meta = COALESCE(meta, '{}'::jsonb) || %s + WHERE user_id = %s AND run_id = %s + """, + ( + now, + Json({"reason": "auth_expired", "lifecycle": "auth_expired"}), + user_id, + run_id, + ), + ) + + run_with_retry(_op) + _set_state( + user_id, + run_id, + state="STOPPED", + last_heartbeat_ts=datetime.utcnow().isoformat() + "Z", + ) + _update_engine_status(user_id, run_id, "STOPPED") + log_event( + event="BROKER_AUTH_EXPIRED", + message="Broker authentication expired", + meta={"reason": reason}, + ) + log_event( + event="ENGINE_PAUSED", + message="Engine paused until broker reconnect", + meta={"reason": "auth_expired"}, + ) + if callable(emit_event_cb): + emit_event_cb( + event="BROKER_AUTH_EXPIRED", + message="Broker authentication expired", + meta={"reason": reason}, + ) + emit_event_cb( + event="STRATEGY_STOPPED", + message="Strategy stopped", + meta={"reason": "broker_auth_expired"}, + ) + + +def _engine_loop(config, stop_event: threading.Event): print("Strategy engine started with config:", config) user_id = config.get("user_id") @@ -170,35 +236,38 @@ def _engine_loop(config, stop_event: threading.Event): if emit_event_cb: emit_event_cb(event=event, message=message, meta=meta or {}) print(f"[ENGINE] {event} {message} {meta or {}}", flush=True) - mode = (config.get("mode") or "LIVE").strip().upper() - if mode not in {"PAPER", "LIVE"}: - mode = "LIVE" - broker_type = config.get("broker") or "paper" - if broker_type != "paper": - broker_type = "paper" - if broker_type == "paper": - mode = "PAPER" - initial_cash = float(config.get("initial_cash", 0)) - broker = PaperBroker(initial_cash=initial_cash) - log_event( - event="DEBUG_PAPER_STORE_PATH", - message="Paper broker store path", - meta={ - "cwd": os.getcwd(), - "paper_store_path": str(broker.store_path) if hasattr(broker, "store_path") else "NO_STORE_PATH", - "abs_store_path": os.path.abspath(str(broker.store_path)) if hasattr(broker, "store_path") else "N/A", - }, - ) - if emit_event_cb: - emit_event_cb( - event="DEBUG_PAPER_STORE_PATH", - message="Paper broker store path", - meta={ - "cwd": os.getcwd(), - "paper_store_path": str(broker.store_path) if hasattr(broker, "store_path") else "NO_STORE_PATH", - "abs_store_path": os.path.abspath(str(broker.store_path)) if hasattr(broker, "store_path") else "N/A", - }, - ) + mode = (config.get("mode") or "LIVE").strip().upper() + if mode not in {"PAPER", "LIVE"}: + mode = "LIVE" + broker_type = (config.get("broker") or ("paper" if mode == "PAPER" else "zerodha")).strip().lower() + initial_cash = float(config.get("initial_cash", 0)) + if broker_type == "paper": + mode = "PAPER" + broker = PaperBroker(initial_cash=initial_cash) + log_event( + event="DEBUG_PAPER_STORE_PATH", + message="Paper broker store path", + meta={ + "cwd": os.getcwd(), + "paper_store_path": str(broker.store_path) if hasattr(broker, "store_path") else "NO_STORE_PATH", + "abs_store_path": os.path.abspath(str(broker.store_path)) if hasattr(broker, "store_path") else "N/A", + }, + ) + if emit_event_cb: + emit_event_cb( + event="DEBUG_PAPER_STORE_PATH", + message="Paper broker store path", + meta={ + "cwd": os.getcwd(), + "paper_store_path": str(broker.store_path) if hasattr(broker, "store_path") else "NO_STORE_PATH", + "abs_store_path": os.path.abspath(str(broker.store_path)) if hasattr(broker, "store_path") else "N/A", + }, + ) + elif broker_type == "zerodha": + broker = LiveZerodhaBroker(user_id=scope_user, run_id=scope_run) + else: + raise ValueError(f"Unsupported broker: {broker_type}") + market_data_provider = "yfinance" log_event("ENGINE_START", { "strategy": strategy_name, @@ -243,8 +312,8 @@ def _engine_loop(config, stop_event: threading.Event): delta = timedelta(days=freq) # Gate 2: time to SIP - last_run = state.get("last_run") or state.get("last_sip_ts") - is_first_run = last_run is None + last_run = _last_execution_anchor(state, mode) + is_first_run = last_run is None now = datetime.now() debug_event( "ENGINE_LOOP_TICK", @@ -280,25 +349,51 @@ def _engine_loop(config, stop_event: threading.Event): sleep_with_heartbeat(60, stop_event, scope_user, scope_run) continue - try: - debug_event("PRICE_FETCH_START", "fetching live prices", {"tickers": [NIFTY, GOLD]}) - nifty_price = fetch_live_price(NIFTY) - gold_price = fetch_live_price(GOLD) - debug_event( - "PRICE_FETCHED", - "fetched live prices", - {"nifty_price": float(nifty_price), "gold_price": float(gold_price)}, - ) - except Exception as exc: - debug_event("PRICE_FETCH_ERROR", "live price fetch failed", {"error": str(exc)}) + try: + debug_event("PRICE_FETCH_START", "fetching live prices", {"tickers": [NIFTY, GOLD]}) + nifty_price = fetch_live_price( + NIFTY, + provider=market_data_provider, + user_id=scope_user, + run_id=scope_run, + ) + gold_price = fetch_live_price( + GOLD, + provider=market_data_provider, + user_id=scope_user, + run_id=scope_run, + ) + debug_event( + "PRICE_FETCHED", + "fetched live prices", + {"nifty_price": float(nifty_price), "gold_price": float(gold_price)}, + ) + except KiteTokenError as exc: + _pause_for_auth_expiry(scope_user, scope_run, str(exc), emit_event_cb=emit_event_cb) + break + except Exception as exc: + debug_event("PRICE_FETCH_ERROR", "live price fetch failed", {"error": str(exc)}) sleep_with_heartbeat(30, stop_event, scope_user, scope_run) continue - - try: - nifty_hist = load_monthly_close(NIFTY) - gold_hist = load_monthly_close(GOLD) - except Exception as exc: - debug_event("HISTORY_LOAD_ERROR", "history load failed", {"error": str(exc)}) + + try: + nifty_hist = load_monthly_close( + NIFTY, + provider=market_data_provider, + user_id=scope_user, + run_id=scope_run, + ) + gold_hist = load_monthly_close( + GOLD, + provider=market_data_provider, + user_id=scope_user, + run_id=scope_run, + ) + except KiteTokenError as exc: + _pause_for_auth_expiry(scope_user, scope_run, str(exc), emit_event_cb=emit_event_cb) + break + except Exception as exc: + debug_event("HISTORY_LOAD_ERROR", "history load failed", {"error": str(exc)}) sleep_with_heartbeat(30, stop_event, scope_user, scope_run) continue @@ -449,6 +544,9 @@ def _engine_loop(config, stop_event: threading.Event): ) sleep_with_heartbeat(30, stop_event, scope_user, scope_run) + except BrokerAuthExpired as exc: + _pause_for_auth_expiry(scope_user, scope_run, str(exc), emit_event_cb=emit_event_cb) + print(f"[ENGINE] broker auth expired for run {scope_run}: {exc}", flush=True) except Exception as e: _set_state(scope_user, scope_run, state="ERROR", last_heartbeat_ts=datetime.utcnow().isoformat() + "Z") _update_engine_status(scope_user, scope_run, "ERROR")