Refine live strategy execution flow

This commit is contained in:
Thigazhezhilan J 2026-03-24 21:59:17 +05:30
parent 7677895b05
commit c17222ad9c
10 changed files with 1520 additions and 389 deletions

View File

@ -11,7 +11,7 @@ class StrategyStartRequest(BaseModel):
initial_cash: Optional[float] = None initial_cash: Optional[float] = None
sip_amount: float sip_amount: float
sip_frequency: SipFrequency sip_frequency: SipFrequency
mode: Literal["PAPER"] mode: Literal["PAPER", "LIVE"]
@validator("initial_cash") @validator("initial_cash")
def validate_cash(cls, v): def validate_cash(cls, v):

View File

@ -30,6 +30,17 @@ def _require_user(request: Request):
return user return user
def _build_saved_broker_login_url(request: Request, user_id: str) -> str:
creds = get_broker_credentials(user_id)
if not creds:
raise HTTPException(status_code=400, detail="Broker credentials not configured")
redirect_url = (os.getenv("ZERODHA_REDIRECT_URL") or "").strip()
if not redirect_url:
base = str(request.base_url).rstrip("/")
redirect_url = f"{base}/api/broker/callback"
return build_login_url(creds["api_key"], redirect_url=redirect_url)
@router.post("/connect") @router.post("/connect")
async def connect_broker(payload: dict, request: Request): async def connect_broker(payload: dict, request: Request):
user = _require_user(request) user = _require_user(request)
@ -153,17 +164,16 @@ async def zerodha_callback(request: Request, request_token: str = ""):
@router.get("/login") @router.get("/login")
async def broker_login(request: Request): async def broker_login(request: Request):
user = _require_user(request) user = _require_user(request)
creds = get_broker_credentials(user["id"]) login_url = _build_saved_broker_login_url(request, user["id"])
if not creds:
raise HTTPException(status_code=400, detail="Broker credentials not configured")
redirect_url = (os.getenv("ZERODHA_REDIRECT_URL") or "").strip()
if not redirect_url:
base = str(request.base_url).rstrip("/")
redirect_url = f"{base}/api/broker/callback"
login_url = build_login_url(creds["api_key"], redirect_url=redirect_url)
return RedirectResponse(login_url) return RedirectResponse(login_url)
@router.get("/login-url")
async def broker_login_url(request: Request):
user = _require_user(request)
return {"loginUrl": _build_saved_broker_login_url(request, user["id"])}
@router.get("/callback") @router.get("/callback")
async def broker_callback(request: Request, request_token: str = ""): async def broker_callback(request: Request, request_token: str = ""):
user = _require_user(request) user = _require_user(request)

View File

@ -4,6 +4,7 @@ from app.services.strategy_service import (
start_strategy, start_strategy,
stop_strategy, stop_strategy,
get_strategy_status, get_strategy_status,
get_strategy_summary,
get_engine_status, get_engine_status,
get_market_status, get_market_status,
get_strategy_logs as fetch_strategy_logs, get_strategy_logs as fetch_strategy_logs,
@ -27,6 +28,11 @@ def status(request: Request):
user_id = get_request_user_id(request) user_id = get_request_user_id(request)
return get_strategy_status(user_id) return get_strategy_status(user_id)
@router.get("/strategy/summary")
def summary(request: Request):
user_id = get_request_user_id(request)
return get_strategy_summary(user_id)
@router.get("/engine/status") @router.get("/engine/status")
def engine_status(request: Request): def engine_status(request: Request):
user_id = get_request_user_id(request) user_id = get_request_user_id(request)

View File

@ -11,11 +11,12 @@ if str(ENGINE_ROOT) not in sys.path:
from indian_paper_trading_strategy.engine.market import is_market_open, align_to_market_open from indian_paper_trading_strategy.engine.market import is_market_open, align_to_market_open
from indian_paper_trading_strategy.engine.runner import start_engine, stop_engine from indian_paper_trading_strategy.engine.runner import start_engine, stop_engine
from indian_paper_trading_strategy.engine.state import init_paper_state, load_state from indian_paper_trading_strategy.engine.state import init_paper_state, load_state, save_state
from indian_paper_trading_strategy.engine.broker import PaperBroker from indian_paper_trading_strategy.engine.broker import PaperBroker
from indian_paper_trading_strategy.engine.time_utils import frequency_to_timedelta from indian_paper_trading_strategy.engine.time_utils import frequency_to_timedelta
from indian_paper_trading_strategy.engine.db import engine_context from indian_paper_trading_strategy.engine.db import engine_context
from app.broker_store import get_user_broker, set_broker_auth_state
from app.services.db import db_connection from app.services.db import db_connection
from app.services.run_service import ( from app.services.run_service import (
create_strategy_run, create_strategy_run,
@ -25,6 +26,11 @@ from app.services.run_service import (
) )
from app.services.auth_service import get_user_by_id from app.services.auth_service import get_user_by_id
from app.services.email_service import send_email_async from app.services.email_service import send_email_async
from app.services.zerodha_service import (
KiteTokenError,
fetch_funds,
)
from app.services.zerodha_storage import get_session
from psycopg2.extras import Json from psycopg2.extras import Json
from psycopg2 import errors from psycopg2 import errors
@ -298,6 +304,27 @@ def validate_frequency(freq: dict, mode: str):
if unit == "days" and value < 1: if unit == "days" and value < 1:
raise ValueError("Minimum frequency is 1 day") raise ValueError("Minimum frequency is 1 day")
def _validate_live_broker_session(user_id: str):
broker_state = get_user_broker(user_id) or {}
broker_name = (broker_state.get("broker") or "").strip().upper()
if not broker_state.get("connected") or broker_name != "ZERODHA":
return False, broker_state, "broker_not_connected"
session = get_session(user_id)
if not session:
set_broker_auth_state(user_id, "EXPIRED")
return False, broker_state, "broker_auth_required"
try:
fetch_funds(session["api_key"], session["access_token"])
except KiteTokenError:
set_broker_auth_state(user_id, "EXPIRED")
return False, broker_state, "broker_auth_required"
set_broker_auth_state(user_id, "VALID")
return True, broker_state, "ok"
def compute_next_eligible(last_run: str | None, sip_frequency: dict | None): def compute_next_eligible(last_run: str | None, sip_frequency: dict | None):
if not last_run or not sip_frequency: if not last_run or not sip_frequency:
return None return None
@ -313,10 +340,27 @@ def compute_next_eligible(last_run: str | None, sip_frequency: dict | None):
next_dt = align_to_market_open(next_dt) next_dt = align_to_market_open(next_dt)
return next_dt.isoformat() return next_dt.isoformat()
def _last_execution_ts(state: dict, mode: str) -> str | None:
mode_key = (mode or "LIVE").strip().upper()
if mode_key == "LIVE":
return state.get("last_sip_ts")
return state.get("last_run") or state.get("last_sip_ts")
def start_strategy(req, user_id: str): def start_strategy(req, user_id: str):
engine_external = os.getenv("ENGINE_EXTERNAL", "").strip().lower() in {"1", "true", "yes"} engine_external = os.getenv("ENGINE_EXTERNAL", "").strip().lower() in {"1", "true", "yes"}
running_run_id = get_running_run_id(user_id) running_run_id = get_running_run_id(user_id)
if running_run_id: if running_run_id:
running_cfg = _load_config(user_id, running_run_id)
running_mode = (running_cfg.get("mode") or req.mode or "PAPER").strip().upper()
if running_mode == "LIVE":
is_valid, broker_state, failure_reason = _validate_live_broker_session(user_id)
if not is_valid:
return {
"status": "broker_auth_required",
"redirect_url": "/api/broker/login",
"broker": broker_state.get("broker"),
}
if engine_external: if engine_external:
return {"status": "already_running", "run_id": running_run_id} return {"status": "already_running", "run_id": running_run_id}
engine_config = _build_engine_config(user_id, running_run_id, req) engine_config = _build_engine_config(user_id, running_run_id, req)
@ -327,48 +371,77 @@ def start_strategy(req, user_id: str):
return {"status": "restarted", "run_id": running_run_id} return {"status": "restarted", "run_id": running_run_id}
return {"status": "already_running", "run_id": running_run_id} return {"status": "already_running", "run_id": running_run_id}
mode = (req.mode or "PAPER").strip().upper() mode = (req.mode or "PAPER").strip().upper()
if mode != "PAPER":
return {"status": "unsupported_mode"}
frequency_payload = req.sip_frequency.dict() if hasattr(req.sip_frequency, "dict") else dict(req.sip_frequency) frequency_payload = req.sip_frequency.dict() if hasattr(req.sip_frequency, "dict") else dict(req.sip_frequency)
validate_frequency(frequency_payload, mode) validate_frequency(frequency_payload, mode)
initial_cash = float(req.initial_cash) if req.initial_cash is not None else 1_000_000.0 if mode == "PAPER":
initial_cash = float(req.initial_cash) if req.initial_cash is not None else 1_000_000.0
broker_name = "paper"
elif mode == "LIVE":
is_valid, broker_state, failure_reason = _validate_live_broker_session(user_id)
if not is_valid:
return {
"status": "broker_auth_required",
"redirect_url": "/api/broker/login",
"broker": broker_state.get("broker"),
}
initial_cash = None
broker_name = ((broker_state.get("broker") or "ZERODHA").strip().lower())
else:
return {"status": "unsupported_mode"}
meta = {
"sip_amount": req.sip_amount,
"sip_frequency": frequency_payload,
}
if initial_cash is not None:
meta["initial_cash"] = initial_cash
try: try:
run_id = create_strategy_run( run_id = create_strategy_run(
user_id, user_id,
strategy=req.strategy_name, strategy=req.strategy_name,
mode=mode, mode=mode,
broker="paper", broker=broker_name,
meta={ meta=meta,
"sip_amount": req.sip_amount,
"sip_frequency": frequency_payload,
"initial_cash": initial_cash,
},
) )
except errors.UniqueViolation: except errors.UniqueViolation:
return {"status": "already_running"} return {"status": "already_running"}
with engine_context(user_id, run_id): with engine_context(user_id, run_id):
init_paper_state(initial_cash, frequency_payload) if mode == "PAPER":
with db_connection() as conn: init_paper_state(initial_cash, frequency_payload)
with conn: with db_connection() as conn:
with conn.cursor() as cur: with conn:
cur.execute( with conn.cursor() as cur:
""" cur.execute(
INSERT INTO paper_broker_account (user_id, run_id, cash) """
VALUES (%s, %s, %s) INSERT INTO paper_broker_account (user_id, run_id, cash)
ON CONFLICT (user_id, run_id) DO UPDATE VALUES (%s, %s, %s)
SET cash = EXCLUDED.cash ON CONFLICT (user_id, run_id) DO UPDATE
""", SET cash = EXCLUDED.cash
(user_id, run_id, initial_cash), """,
) (user_id, run_id, initial_cash),
PaperBroker(initial_cash=initial_cash) )
PaperBroker(initial_cash=initial_cash)
else:
save_state(
{
"total_invested": 0.0,
"nifty_units": 0.0,
"gold_units": 0.0,
"last_sip_ts": None,
"last_run": None,
},
mode="LIVE",
emit_event=True,
event_meta={"source": "live_start"},
)
config = { config = {
"strategy": req.strategy_name, "strategy": req.strategy_name,
"sip_amount": req.sip_amount, "sip_amount": req.sip_amount,
"sip_frequency": frequency_payload, "sip_frequency": frequency_payload,
"mode": mode, "mode": mode,
"broker": "paper", "broker": broker_name,
"active": True, "active": True,
} }
save_strategy_config(config, user_id, run_id) save_strategy_config(config, user_id, run_id)
@ -387,7 +460,8 @@ def start_strategy(req, user_id: str):
) )
engine_config = dict(config) engine_config = dict(config)
engine_config["initial_cash"] = initial_cash if initial_cash is not None:
engine_config["initial_cash"] = initial_cash
engine_config["run_id"] = run_id engine_config["run_id"] = run_id
engine_config["user_id"] = user_id engine_config["user_id"] = user_id
engine_config["emit_event"] = emit_event_cb engine_config["emit_event"] = emit_event_cb
@ -518,7 +592,7 @@ def get_strategy_status(user_id: str):
mode = (cfg.get("mode") or "LIVE").strip().upper() mode = (cfg.get("mode") or "LIVE").strip().upper()
with engine_context(user_id, run_id): with engine_context(user_id, run_id):
state = load_state(mode=mode) state = load_state(mode=mode)
last_execution_ts = state.get("last_run") or state.get("last_sip_ts") last_execution_ts = _last_execution_ts(state, mode)
sip_frequency = cfg.get("sip_frequency") sip_frequency = cfg.get("sip_frequency")
if not isinstance(sip_frequency, dict): if not isinstance(sip_frequency, dict):
frequency = cfg.get("frequency") frequency = cfg.get("frequency")
@ -578,7 +652,7 @@ def get_engine_status(user_id: str):
mode = (cfg.get("mode") or "LIVE").strip().upper() mode = (cfg.get("mode") or "LIVE").strip().upper()
with engine_context(user_id, run_id): with engine_context(user_id, run_id):
state = load_state(mode=mode) state = load_state(mode=mode)
last_execution_ts = state.get("last_run") or state.get("last_sip_ts") last_execution_ts = _last_execution_ts(state, mode)
sip_frequency = cfg.get("sip_frequency") sip_frequency = cfg.get("sip_frequency")
if isinstance(sip_frequency, dict): if isinstance(sip_frequency, dict):
sip_frequency = { sip_frequency = {
@ -642,6 +716,128 @@ def get_strategy_logs(user_id: str, since_seq: int):
latest_seq = cur.fetchone()[0] latest_seq = cur.fetchone()[0]
return {"events": events, "latest_seq": latest_seq} return {"events": events, "latest_seq": latest_seq}
def _humanize_reason(reason: str | None):
if not reason:
return None
return reason.replace("_", " ").strip().capitalize()
def _issue_message(event: str, message: str | None, data: dict | None, meta: dict | None):
payload = data if isinstance(data, dict) else {}
extra = meta if isinstance(meta, dict) else {}
reason = payload.get("reason") or extra.get("reason")
reason_key = str(reason or "").strip().lower()
if event == "SIP_NO_FILL":
if reason_key == "insufficient_funds":
return "Insufficient funds for this SIP."
if reason_key == "broker_auth_expired":
return "Broker session expired. Reconnect broker."
if reason_key == "no_fill":
return "Order was not filled."
return f"SIP not executed: {_humanize_reason(reason) or 'Unknown reason'}."
if event == "BROKER_AUTH_EXPIRED":
return "Broker session expired. Reconnect broker."
if event == "PRICE_FETCH_ERROR":
return "Could not fetch prices. Retrying."
if event == "HISTORY_LOAD_ERROR":
return "Could not load price history. Retrying."
if event == "ENGINE_ERROR":
return message or "Strategy engine hit an error."
if event == "EXECUTION_BLOCKED":
if reason_key == "market_closed":
return "Market is closed. Execution will resume next session."
return f"Execution blocked: {_humanize_reason(reason) or 'Unknown reason'}."
if event == "ORDER_REJECTED":
return message or payload.get("status_message") or "Broker rejected the order."
if event == "ORDER_CANCELLED":
return message or "Order was cancelled."
return message or _humanize_reason(reason) or "Strategy update available."
def get_strategy_summary(user_id: str):
run_id = get_active_run_id(user_id)
status = get_strategy_status(user_id)
next_eligible_ts = status.get("next_eligible_ts")
summary = {
"run_id": run_id,
"status": status.get("status"),
"tone": "neutral",
"message": "No active strategy.",
"event": None,
"ts": None,
}
issue_row = None
if run_id:
with db_connection() as conn:
with conn.cursor() as cur:
cur.execute(
"""
SELECT event, message, data, meta, ts
FROM engine_event
WHERE user_id = %s
AND run_id = %s
AND event IN (
'SIP_NO_FILL',
'BROKER_AUTH_EXPIRED',
'PRICE_FETCH_ERROR',
'HISTORY_LOAD_ERROR',
'ENGINE_ERROR',
'EXECUTION_BLOCKED',
'ORDER_REJECTED',
'ORDER_CANCELLED'
)
ORDER BY ts DESC
LIMIT 1
""",
(user_id, run_id),
)
issue_row = cur.fetchone()
if issue_row:
event, message, data, meta, ts = issue_row
summary.update(
{
"tone": "error" if event in {"ENGINE_ERROR", "ORDER_REJECTED"} else "warning",
"message": _issue_message(event, message, data, meta),
"event": event,
"ts": _format_local_ts(ts),
}
)
return summary
status_key = (status.get("status") or "IDLE").upper()
if status_key == "WAITING" and next_eligible_ts:
summary.update(
{
"tone": "warning",
"message": f"Waiting until {next_eligible_ts}.",
}
)
return summary
if status_key == "RUNNING":
summary.update(
{
"tone": "success",
"message": "Strategy is running.",
}
)
return summary
if status_key == "STOPPED":
summary.update(
{
"tone": "neutral",
"message": "Strategy is stopped.",
}
)
return summary
return summary
def get_market_status(): def get_market_status():
now = datetime.now() now = datetime.now()
return { return {

View File

@ -23,6 +23,10 @@ class KiteTokenError(KiteApiError):
pass pass
class KitePermissionError(KiteApiError):
pass
def build_login_url(api_key: str, redirect_url: str | None = None) -> str: def build_login_url(api_key: str, redirect_url: str | None = None) -> str:
params = {"api_key": api_key, "v": KITE_VERSION} params = {"api_key": api_key, "v": KITE_VERSION}
redirect_url = (redirect_url or os.getenv("ZERODHA_REDIRECT_URL") or "").strip() redirect_url = (redirect_url or os.getenv("ZERODHA_REDIRECT_URL") or "").strip()
@ -48,7 +52,12 @@ def _request(method: str, url: str, data: dict | None = None, headers: dict | No
payload = {} payload = {}
error_type = payload.get("error_type") or payload.get("status") or "unknown_error" error_type = payload.get("error_type") or payload.get("status") or "unknown_error"
message = payload.get("message") or error_body or err.reason message = payload.get("message") or error_body or err.reason
exc_cls = KiteTokenError if error_type == "TokenException" else KiteApiError if error_type == "TokenException":
exc_cls = KiteTokenError
elif error_type == "PermissionException":
exc_cls = KitePermissionError
else:
exc_cls = KiteApiError
raise exc_cls(err.code, error_type, message) from err raise exc_cls(err.code, error_type, message) from err
return json.loads(body) return json.loads(body)
@ -87,3 +96,112 @@ def fetch_funds(api_key: str, access_token: str) -> dict:
url = f"{KITE_API_BASE}/user/margins" url = f"{KITE_API_BASE}/user/margins"
response = _request("GET", url, headers=_auth_headers(api_key, access_token)) response = _request("GET", url, headers=_auth_headers(api_key, access_token))
return response.get("data", {}) return response.get("data", {})
def fetch_ltp_quotes(api_key: str, access_token: str, instruments: list[str]) -> dict:
symbols = [str(item).strip() for item in instruments if str(item).strip()]
if not symbols:
return {}
query = urllib.parse.urlencode([("i", symbol) for symbol in symbols])
url = f"{KITE_API_BASE}/quote/ltp?{query}"
response = _request("GET", url, headers=_auth_headers(api_key, access_token))
return response.get("data", {})
def fetch_ohlc_quotes(api_key: str, access_token: str, instruments: list[str]) -> dict:
symbols = [str(item).strip() for item in instruments if str(item).strip()]
if not symbols:
return {}
query = urllib.parse.urlencode([("i", symbol) for symbol in symbols])
url = f"{KITE_API_BASE}/quote/ohlc?{query}"
response = _request("GET", url, headers=_auth_headers(api_key, access_token))
return response.get("data", {})
def fetch_historical_candles(
api_key: str,
access_token: str,
instrument_token: int | str,
interval: str,
*,
from_dt,
to_dt,
continuous: bool = False,
oi: bool = False,
) -> list:
params = {
"from": from_dt.strftime("%Y-%m-%d %H:%M:%S"),
"to": to_dt.strftime("%Y-%m-%d %H:%M:%S"),
"continuous": 1 if continuous else 0,
"oi": 1 if oi else 0,
}
query = urllib.parse.urlencode(params)
url = f"{KITE_API_BASE}/instruments/historical/{instrument_token}/{interval}?{query}"
response = _request("GET", url, headers=_auth_headers(api_key, access_token))
return response.get("data", {}).get("candles", [])
def place_order(
api_key: str,
access_token: str,
*,
tradingsymbol: str,
exchange: str,
transaction_type: str,
order_type: str,
quantity: int,
product: str,
price: float | None = None,
validity: str = "DAY",
variety: str = "regular",
market_protection: int | None = None,
tag: str | None = None,
) -> dict:
payload = {
"tradingsymbol": tradingsymbol,
"exchange": exchange,
"transaction_type": transaction_type,
"order_type": order_type,
"quantity": int(quantity),
"product": product,
"validity": validity,
}
if price is not None:
payload["price"] = price
if market_protection is not None:
payload["market_protection"] = market_protection
if tag:
payload["tag"] = tag
url = f"{KITE_API_BASE}/orders/{variety}"
response = _request(
"POST",
url,
data=payload,
headers=_auth_headers(api_key, access_token),
)
return response.get("data", {})
def fetch_orders(api_key: str, access_token: str) -> list:
url = f"{KITE_API_BASE}/orders"
response = _request("GET", url, headers=_auth_headers(api_key, access_token))
return response.get("data", [])
def fetch_order_history(api_key: str, access_token: str, order_id: str) -> list:
url = f"{KITE_API_BASE}/orders/{order_id}"
response = _request("GET", url, headers=_auth_headers(api_key, access_token))
return response.get("data", [])
def cancel_order(
api_key: str,
access_token: str,
*,
order_id: str,
variety: str = "regular",
) -> dict:
url = f"{KITE_API_BASE}/orders/{variety}/{order_id}"
response = _request("DELETE", url, headers=_auth_headers(api_key, access_token))
return response.get("data", {})

View File

@ -4,6 +4,9 @@ from abc import ABC, abstractmethod
from dataclasses import dataclass from dataclasses import dataclass
from datetime import datetime, timezone from datetime import datetime, timezone
import hashlib import hashlib
import math
import os
import time
from psycopg2.extras import execute_values from psycopg2.extras import execute_values
@ -12,6 +15,8 @@ from indian_paper_trading_strategy.engine.db import db_connection, insert_engine
class Broker(ABC): class Broker(ABC):
external_orders = False
@abstractmethod @abstractmethod
def place_order( def place_order(
self, self,
@ -36,6 +41,14 @@ class Broker(ABC):
raise NotImplementedError raise NotImplementedError
class BrokerError(Exception):
pass
class BrokerAuthExpired(BrokerError):
pass
def _local_tz(): def _local_tz():
return datetime.now().astimezone().tzinfo return datetime.now().astimezone().tzinfo
@ -102,6 +115,326 @@ def _resolve_scope(user_id: str | None, run_id: str | None):
return get_context(user_id, run_id) return get_context(user_id, run_id)
class LiveZerodhaBroker(Broker):
external_orders = True
TERMINAL_STATUSES = {"COMPLETE", "REJECTED", "CANCELLED"}
POLL_TIMEOUT_SECONDS = float(os.getenv("ZERODHA_ORDER_POLL_TIMEOUT", "12"))
POLL_INTERVAL_SECONDS = float(os.getenv("ZERODHA_ORDER_POLL_INTERVAL", "1"))
def __init__(self, user_id: str | None = None, run_id: str | None = None):
self.user_id = user_id
self.run_id = run_id
def _scope(self):
return _resolve_scope(self.user_id, self.run_id)
def _session(self):
from app.services.zerodha_storage import get_session
user_id, _run_id = self._scope()
session = get_session(user_id)
if not session or not session.get("api_key") or not session.get("access_token"):
raise BrokerAuthExpired("Zerodha session missing. Please reconnect broker.")
return session
def _raise_auth_expired(self, exc: Exception):
from app.broker_store import set_broker_auth_state
user_id, _run_id = self._scope()
set_broker_auth_state(user_id, "EXPIRED")
raise BrokerAuthExpired(str(exc)) from exc
def _normalize_symbol(self, symbol: str) -> tuple[str, str]:
cleaned = (symbol or "").strip().upper()
if cleaned.endswith(".NS"):
return cleaned[:-3], "NSE"
if cleaned.endswith(".BO"):
return cleaned[:-3], "BSE"
return cleaned, "NSE"
def _make_tag(self, logical_time: datetime | None, symbol: str, side: str) -> str:
user_id, run_id = self._scope()
logical_ts = logical_time or datetime.utcnow().replace(tzinfo=timezone.utc)
digest = hashlib.sha1(
f"{user_id}|{run_id}|{_normalize_ts_for_id(logical_ts)}|{symbol}|{side}".encode("utf-8")
).hexdigest()[:18]
return f"qf{digest}"
def _normalize_order_payload(
self,
*,
order_id: str,
symbol: str,
side: str,
requested_qty: int,
requested_price: float | None,
history_entry: dict | None,
logical_time: datetime | None,
) -> dict:
entry = history_entry or {}
raw_status = (entry.get("status") or "").upper()
status = raw_status
if raw_status == "COMPLETE":
status = "FILLED"
elif raw_status in {"REJECTED", "CANCELLED"}:
status = raw_status
elif raw_status:
status = "PENDING"
else:
status = "PENDING"
quantity = int(entry.get("quantity") or requested_qty or 0)
filled_qty = int(entry.get("filled_quantity") or 0)
average_price = float(entry.get("average_price") or requested_price or 0.0)
price = float(entry.get("price") or requested_price or average_price or 0.0)
timestamp = (
entry.get("exchange_timestamp")
or entry.get("order_timestamp")
or _format_utc_ts(logical_time or datetime.utcnow().replace(tzinfo=timezone.utc))
)
if timestamp and " " in str(timestamp):
timestamp = str(timestamp).replace(" ", "T")
return {
"id": order_id,
"symbol": symbol,
"side": side.upper().strip(),
"qty": quantity,
"requested_qty": quantity,
"filled_qty": filled_qty,
"price": price,
"requested_price": float(requested_price or price or 0.0),
"average_price": average_price,
"status": status,
"timestamp": timestamp,
"broker_order_id": order_id,
"exchange": entry.get("exchange"),
"tradingsymbol": entry.get("tradingsymbol"),
"status_message": entry.get("status_message") or entry.get("status_message_raw"),
}
def _wait_for_terminal_order(
self,
session: dict,
order_id: str,
*,
symbol: str,
side: str,
requested_qty: int,
requested_price: float | None,
logical_time: datetime | None,
) -> dict:
from app.services.zerodha_service import (
KiteTokenError,
cancel_order,
fetch_order_history,
)
started = time.monotonic()
last_payload = self._normalize_order_payload(
order_id=order_id,
symbol=symbol,
side=side,
requested_qty=requested_qty,
requested_price=requested_price,
history_entry=None,
logical_time=logical_time,
)
while True:
try:
history = fetch_order_history(
session["api_key"],
session["access_token"],
order_id,
)
except KiteTokenError as exc:
self._raise_auth_expired(exc)
if history:
entry = history[-1]
last_payload = self._normalize_order_payload(
order_id=order_id,
symbol=symbol,
side=side,
requested_qty=requested_qty,
requested_price=requested_price,
history_entry=entry,
logical_time=logical_time,
)
raw_status = (entry.get("status") or "").upper()
if raw_status in self.TERMINAL_STATUSES:
return last_payload
if time.monotonic() - started >= self.POLL_TIMEOUT_SECONDS:
try:
cancel_order(
session["api_key"],
session["access_token"],
order_id=order_id,
)
history = fetch_order_history(
session["api_key"],
session["access_token"],
order_id,
)
if history:
return self._normalize_order_payload(
order_id=order_id,
symbol=symbol,
side=side,
requested_qty=requested_qty,
requested_price=requested_price,
history_entry=history[-1],
logical_time=logical_time,
)
except KiteTokenError as exc:
self._raise_auth_expired(exc)
return last_payload
time.sleep(self.POLL_INTERVAL_SECONDS)
def get_funds(self, cur=None):
from app.services.zerodha_service import KiteTokenError, fetch_funds
session = self._session()
try:
data = fetch_funds(session["api_key"], session["access_token"])
except KiteTokenError as exc:
self._raise_auth_expired(exc)
equity = data.get("equity", {}) if isinstance(data, dict) else {}
available = equity.get("available", {}) if isinstance(equity, dict) else {}
cash = (
available.get("live_balance")
or available.get("cash")
or available.get("opening_balance")
or equity.get("net")
or 0.0
)
return {"cash": float(cash), "raw": data}
def get_positions(self):
from app.services.zerodha_service import KiteTokenError, fetch_holdings
session = self._session()
try:
holdings = fetch_holdings(session["api_key"], session["access_token"])
except KiteTokenError as exc:
self._raise_auth_expired(exc)
normalized = []
for item in holdings:
qty = float(item.get("quantity") or item.get("qty") or 0)
avg_price = float(item.get("average_price") or 0)
last_price = float(item.get("last_price") or avg_price or 0)
exchange = item.get("exchange")
suffix = ".NS" if exchange == "NSE" else ".BO" if exchange == "BSE" else ""
normalized.append(
{
"symbol": f"{item.get('tradingsymbol')}{suffix}",
"qty": qty,
"avg_price": avg_price,
"last_price": last_price,
}
)
return normalized
def get_orders(self):
from app.services.zerodha_service import KiteTokenError, fetch_orders
session = self._session()
try:
return fetch_orders(session["api_key"], session["access_token"])
except KiteTokenError as exc:
self._raise_auth_expired(exc)
def update_equity(
self,
prices: dict[str, float],
now: datetime,
cur=None,
logical_time: datetime | None = None,
user_id: str | None = None,
run_id: str | None = None,
):
return None
def place_order(
self,
symbol: str,
side: str,
quantity: float,
price: float | None = None,
cur=None,
logical_time: datetime | None = None,
user_id: str | None = None,
run_id: str | None = None,
):
from app.services.zerodha_service import KiteTokenError, place_order
if user_id is not None:
self.user_id = user_id
if run_id is not None:
self.run_id = run_id
qty = int(math.floor(float(quantity)))
side = side.upper().strip()
requested_price = float(price) if price is not None else None
if qty <= 0:
return {
"id": _deterministic_id("live_rej", [symbol, side, _stable_num(quantity)]),
"symbol": symbol,
"side": side,
"qty": qty,
"requested_qty": qty,
"filled_qty": 0,
"price": float(price or 0.0),
"requested_price": float(price or 0.0),
"average_price": 0.0,
"status": "REJECTED",
"timestamp": _format_utc_ts(logical_time or datetime.utcnow().replace(tzinfo=timezone.utc)),
"status_message": "Computed quantity is less than 1 share",
}
session = self._session()
tradingsymbol, exchange = self._normalize_symbol(symbol)
tag = self._make_tag(logical_time, symbol, side)
try:
placed = place_order(
session["api_key"],
session["access_token"],
tradingsymbol=tradingsymbol,
exchange=exchange,
transaction_type=side,
order_type="MARKET",
quantity=qty,
product="CNC",
validity="DAY",
variety="regular",
market_protection=-1,
tag=tag,
)
except KiteTokenError as exc:
self._raise_auth_expired(exc)
order_id = placed.get("order_id")
if not order_id:
raise BrokerError("Zerodha order placement did not return an order_id")
return self._wait_for_terminal_order(
session,
order_id,
symbol=symbol,
side=side,
requested_qty=qty,
requested_price=requested_price,
logical_time=logical_time,
)
@dataclass @dataclass
class PaperBroker(Broker): class PaperBroker(Broker):
initial_cash: float initial_cash: float
@ -583,7 +916,11 @@ class PaperBroker(Broker):
"symbol": symbol, "symbol": symbol,
"side": side, "side": side,
"qty": qty, "qty": qty,
"requested_qty": qty,
"filled_qty": 0.0,
"price": price, "price": price,
"requested_price": price,
"average_price": 0.0,
"status": "REJECTED", "status": "REJECTED",
"timestamp": timestamp_str, "timestamp": timestamp_str,
"_logical_time": logical_ts_str, "_logical_time": logical_ts_str,
@ -613,6 +950,8 @@ class PaperBroker(Broker):
"last_price": price, "last_price": price,
} }
order["status"] = "FILLED" order["status"] = "FILLED"
order["filled_qty"] = qty
order["average_price"] = price
trade = { trade = {
"id": _deterministic_id("trd", [order_id]), "id": _deterministic_id("trd", [order_id]),
"order_id": order_id, "order_id": order_id,
@ -636,6 +975,8 @@ class PaperBroker(Broker):
else: else:
positions.pop(symbol, None) positions.pop(symbol, None)
order["status"] = "FILLED" order["status"] = "FILLED"
order["filled_qty"] = qty
order["average_price"] = price
trade = { trade = {
"id": _deterministic_id("trd", [order_id]), "id": _deterministic_id("trd", [order_id]),
"order_id": order_id, "order_id": order_id,

View File

@ -15,10 +15,27 @@ _LAST_PRICE: dict[str, dict[str, object]] = {}
_LAST_PRICE_LOCK = threading.Lock() _LAST_PRICE_LOCK = threading.Lock()
def _set_last_price(ticker: str, price: float, source: str): def _history_cache_file(ticker: str, provider: str = "yfinance") -> Path:
safe_ticker = (ticker or "").replace(":", "_").replace("/", "_")
return HISTORY_DIR / f"{safe_ticker}.csv"
def _set_last_price(
ticker: str,
price: float,
source: str,
*,
provider: str | None = None,
instrument_token: int | None = None,
):
now = datetime.now(timezone.utc) now = datetime.now(timezone.utc)
with _LAST_PRICE_LOCK: with _LAST_PRICE_LOCK:
_LAST_PRICE[ticker] = {"price": float(price), "source": source, "ts": now} payload = {"price": float(price), "source": source, "ts": now}
if provider:
payload["provider"] = provider
if instrument_token is not None:
payload["instrument_token"] = int(instrument_token)
_LAST_PRICE[ticker] = payload
def get_price_snapshot(ticker: str) -> dict[str, object] | None: def get_price_snapshot(ticker: str) -> dict[str, object] | None:
@ -29,18 +46,20 @@ def get_price_snapshot(ticker: str) -> dict[str, object] | None:
return dict(data) return dict(data)
def _get_last_live_price(ticker: str) -> float | None: def _get_last_live_price(ticker: str, provider: str | None = None) -> float | None:
with _LAST_PRICE_LOCK: with _LAST_PRICE_LOCK:
data = _LAST_PRICE.get(ticker) data = _LAST_PRICE.get(ticker)
if not data: if not data:
return None return None
if data.get("source") == "live": if data.get("source") == "live":
if provider and data.get("provider") not in {None, provider}:
return None
return float(data.get("price", 0)) return float(data.get("price", 0))
return None return None
def _cached_last_close(ticker: str) -> float | None: def _cached_last_close(ticker: str, provider: str = "yfinance") -> float | None:
file = HISTORY_DIR / f"{ticker}.csv" file = _history_cache_file(ticker, provider=provider)
if not file.exists(): if not file.exists():
return None return None
df = pd.read_csv(file) df = pd.read_csv(file)
@ -49,7 +68,14 @@ def _cached_last_close(ticker: str) -> float | None:
return float(df["Close"].iloc[-1]) return float(df["Close"].iloc[-1])
def fetch_live_price(ticker, allow_cache: bool | None = None): def fetch_live_price(
ticker,
allow_cache: bool | None = None,
*,
provider: str = "yfinance",
user_id: str | None = None,
run_id: str | None = None,
):
if allow_cache is None: if allow_cache is None:
allow_cache = ALLOW_PRICE_CACHE allow_cache = ALLOW_PRICE_CACHE
try: try:
@ -62,20 +88,23 @@ def fetch_live_price(ticker, allow_cache: bool | None = None):
timeout=5, timeout=5,
) )
if df is not None and not df.empty: if df is not None and not df.empty:
price = float(df["Close"].iloc[-1]) close_value = df["Close"].iloc[-1]
_set_last_price(ticker, price, "live") if hasattr(close_value, "iloc"):
close_value = close_value.iloc[-1]
price = float(close_value)
_set_last_price(ticker, price, "live", provider="yfinance")
return price return price
except Exception: except Exception:
pass pass
if allow_cache: if allow_cache:
last_live = _get_last_live_price(ticker) last_live = _get_last_live_price(ticker, provider="yfinance")
if last_live is not None: if last_live is not None:
return last_live return last_live
cached = _cached_last_close(ticker) cached = _cached_last_close(ticker, provider="yfinance")
if cached is not None: if cached is not None:
_set_last_price(ticker, cached, "cache") _set_last_price(ticker, cached, "cache", provider="yfinance")
return cached return cached
raise RuntimeError(f"No live data for {ticker}") raise RuntimeError(f"No live data for {ticker}")

View File

@ -2,9 +2,9 @@
from datetime import datetime, timezone from datetime import datetime, timezone
from indian_paper_trading_strategy.engine.state import load_state, save_state from indian_paper_trading_strategy.engine.state import load_state, save_state
from indian_paper_trading_strategy.engine.broker import Broker from indian_paper_trading_strategy.engine.broker import Broker, BrokerAuthExpired
from indian_paper_trading_strategy.engine.ledger import log_event, event_exists from indian_paper_trading_strategy.engine.ledger import log_event, event_exists
from indian_paper_trading_strategy.engine.db import run_with_retry from indian_paper_trading_strategy.engine.db import insert_engine_event, run_with_retry
from indian_paper_trading_strategy.engine.time_utils import compute_logical_time from indian_paper_trading_strategy.engine.time_utils import compute_logical_time
def _as_float(value): def _as_float(value):
@ -23,7 +23,98 @@ def _as_float(value):
def _local_tz(): def _local_tz():
return datetime.now().astimezone().tzinfo return datetime.now().astimezone().tzinfo
def try_execute_sip(
def _normalize_now(now):
if now.tzinfo is None:
return now.replace(tzinfo=_local_tz())
return now
def _resolve_timing(state, now_ts, sip_interval):
force_execute = state.get("last_sip_ts") is None
last = state.get("last_sip_ts") or state.get("last_run")
if last and not force_execute:
try:
last_dt = datetime.fromisoformat(last)
except ValueError:
last_dt = None
if last_dt:
if last_dt.tzinfo is None:
last_dt = last_dt.replace(tzinfo=_local_tz())
if now_ts.tzinfo and last_dt.tzinfo and last_dt.tzinfo != now_ts.tzinfo:
last_dt = last_dt.astimezone(now_ts.tzinfo)
if last_dt and (now_ts - last_dt).total_seconds() < sip_interval:
return False, last, None
logical_time = compute_logical_time(now_ts, last, sip_interval)
return True, last, logical_time
def _order_fill(order):
if not isinstance(order, dict):
return 0.0, 0.0
filled_qty = float(order.get("filled_qty") or 0.0)
average_price = float(order.get("average_price") or order.get("price") or 0.0)
return filled_qty, average_price
def _apply_filled_orders_to_state(state, orders):
nifty_filled = 0.0
gold_filled = 0.0
total_spent = 0.0
for order in orders:
filled_qty, average_price = _order_fill(order)
if filled_qty <= 0:
continue
symbol = (order.get("symbol") or "").upper()
if symbol.startswith("NIFTYBEES"):
nifty_filled += filled_qty
elif symbol.startswith("GOLDBEES"):
gold_filled += filled_qty
total_spent += filled_qty * average_price
if nifty_filled:
state["nifty_units"] += nifty_filled
if gold_filled:
state["gold_units"] += gold_filled
if total_spent:
state["total_invested"] += total_spent
return {
"nifty_units": nifty_filled,
"gold_units": gold_filled,
"amount": total_spent,
}
def _record_live_order_events(cur, orders, event_ts):
for order in orders:
insert_engine_event(cur, "ORDER_PLACED", data=order, ts=event_ts)
filled_qty, _average_price = _order_fill(order)
status = (order.get("status") or "").upper()
if filled_qty > 0:
insert_engine_event(
cur,
"TRADE_EXECUTED",
data={
"order_id": order.get("id"),
"symbol": order.get("symbol"),
"side": order.get("side"),
"qty": filled_qty,
"price": order.get("average_price") or order.get("price"),
},
ts=event_ts,
)
insert_engine_event(cur, "ORDER_FILLED", data={"order_id": order.get("id")}, ts=event_ts)
elif status == "REJECTED":
insert_engine_event(cur, "ORDER_REJECTED", data=order, ts=event_ts)
elif status == "CANCELLED":
insert_engine_event(cur, "ORDER_CANCELLED", data=order, ts=event_ts)
elif status == "PENDING":
insert_engine_event(cur, "ORDER_PENDING", data=order, ts=event_ts)
def _try_execute_sip_paper(
now, now,
market_open, market_open,
sip_interval, sip_interval,
@ -32,41 +123,22 @@ def try_execute_sip(
gd_price, gd_price,
eq_w, eq_w,
gd_w, gd_w,
broker: Broker | None = None, broker: Broker | None,
mode: str | None = "LIVE", mode: str | None,
): ):
def _op(cur, _conn): def _op(cur, _conn):
if now.tzinfo is None: now_ts = _normalize_now(now)
now_ts = now.replace(tzinfo=_local_tz())
else:
now_ts = now
event_ts = now_ts event_ts = now_ts
log_event("DEBUG_ENTER_TRY_EXECUTE", { log_event("DEBUG_ENTER_TRY_EXECUTE", {"now": now_ts.isoformat()}, cur=cur, ts=event_ts)
"now": now_ts.isoformat(),
}, cur=cur, ts=event_ts)
state = load_state(mode=mode, cur=cur, for_update=True) state = load_state(mode=mode, cur=cur, for_update=True)
force_execute = state.get("last_sip_ts") is None
if not market_open: if not market_open:
return state, False return state, False
last = state.get("last_sip_ts") or state.get("last_run") should_run, _last, logical_time = _resolve_timing(state, now_ts, sip_interval)
if last and not force_execute: if not should_run:
try: return state, False
last_dt = datetime.fromisoformat(last)
except ValueError:
last_dt = None
if last_dt:
if last_dt.tzinfo is None:
last_dt = last_dt.replace(tzinfo=_local_tz())
if now_ts.tzinfo and last_dt.tzinfo and last_dt.tzinfo != now_ts.tzinfo:
last_dt = last_dt.astimezone(now_ts.tzinfo)
if last_dt and (now_ts - last_dt).total_seconds() < sip_interval:
return state, False
logical_time = compute_logical_time(now_ts, last, sip_interval)
if event_exists("SIP_EXECUTED", logical_time, cur=cur): if event_exists("SIP_EXECUTED", logical_time, cur=cur):
return state, False return state, False
@ -87,45 +159,45 @@ def try_execute_sip(
if cash is not None and float(cash) < sip_amount_val: if cash is not None and float(cash) < sip_amount_val:
return state, False return state, False
log_event("DEBUG_EXECUTION_DECISION", { log_event(
"force_execute": force_execute, "DEBUG_EXECUTION_DECISION",
"last_sip_ts": state.get("last_sip_ts"), {
"now": now_ts.isoformat(), "last_sip_ts": state.get("last_sip_ts"),
}, cur=cur, ts=event_ts) "now": now_ts.isoformat(),
},
cur=cur,
ts=event_ts,
)
nifty_order = broker.place_order( orders = [
"NIFTYBEES.NS", broker.place_order(
"BUY", "NIFTYBEES.NS",
nifty_qty, "BUY",
sp_price_val, nifty_qty,
cur=cur, sp_price_val,
logical_time=logical_time, cur=cur,
) logical_time=logical_time,
gold_order = broker.place_order( ),
"GOLDBEES.NS", broker.place_order(
"BUY", "GOLDBEES.NS",
gold_qty, "BUY",
gd_price_val, gold_qty,
cur=cur, gd_price_val,
logical_time=logical_time, cur=cur,
) logical_time=logical_time,
orders = [nifty_order, gold_order] ),
executed = all( ]
isinstance(order, dict) and order.get("status") == "FILLED"
for order in orders applied = _apply_filled_orders_to_state(state, orders)
) executed = applied["amount"] > 0
if not executed: if not executed:
return state, False return state, False
assert len(orders) > 0, "executed=True but no broker orders placed"
funds_after = broker.get_funds(cur=cur) funds_after = broker.get_funds(cur=cur)
cash_after = funds_after.get("cash") cash_after = funds_after.get("cash")
if cash_after is not None: if cash_after is not None:
state["cash"] = float(cash_after) state["cash"] = float(cash_after)
state["nifty_units"] += nifty_qty
state["gold_units"] += gold_qty
state["total_invested"] += sip_amount_val
state["last_sip_ts"] = now_ts.isoformat() state["last_sip_ts"] = now_ts.isoformat()
state["last_run"] = now_ts.isoformat() state["last_run"] = now_ts.isoformat()
@ -140,11 +212,11 @@ def try_execute_sip(
log_event( log_event(
"SIP_EXECUTED", "SIP_EXECUTED",
{ {
"nifty_units": nifty_qty, "nifty_units": applied["nifty_units"],
"gold_units": gold_qty, "gold_units": applied["gold_units"],
"nifty_price": sp_price_val, "nifty_price": sp_price_val,
"gold_price": gd_price_val, "gold_price": gd_price_val,
"amount": sip_amount_val, "amount": applied["amount"],
}, },
cur=cur, cur=cur,
ts=event_ts, ts=event_ts,
@ -155,3 +227,250 @@ def try_execute_sip(
return run_with_retry(_op) return run_with_retry(_op)
def _prepare_live_execution(now_ts, sip_interval, sip_amount_val, sp_price_val, gd_price_val, nifty_qty, gold_qty, mode):
def _op(cur, _conn):
state = load_state(mode=mode, cur=cur, for_update=True)
should_run, _last, logical_time = _resolve_timing(state, now_ts, sip_interval)
if not should_run:
return {"ready": False, "state": state}
if event_exists("SIP_EXECUTED", logical_time, cur=cur):
return {"ready": False, "state": state}
if event_exists("SIP_ORDER_ATTEMPTED", logical_time, cur=cur):
return {"ready": False, "state": state}
log_event(
"SIP_ORDER_ATTEMPTED",
{
"nifty_units": nifty_qty,
"gold_units": gold_qty,
"nifty_price": sp_price_val,
"gold_price": gd_price_val,
"amount": sip_amount_val,
},
cur=cur,
ts=now_ts,
logical_time=logical_time,
)
return {"ready": True, "state": state, "logical_time": logical_time}
return run_with_retry(_op)
def _finalize_live_execution(
*,
now_ts,
mode,
logical_time,
orders,
funds_after,
sp_price_val,
gd_price_val,
auth_failed: bool = False,
failure_reason: str | None = None,
):
def _op(cur, _conn):
state = load_state(mode=mode, cur=cur, for_update=True)
_record_live_order_events(cur, orders, now_ts)
applied = _apply_filled_orders_to_state(state, orders)
executed = applied["amount"] > 0
if funds_after is not None:
cash_after = funds_after.get("cash")
if cash_after is not None:
state["cash"] = float(cash_after)
if executed:
state["last_run"] = now_ts.isoformat()
state["last_sip_ts"] = now_ts.isoformat()
save_state(
state,
mode=mode,
cur=cur,
emit_event=True,
event_meta={"source": "sip_live"},
)
if executed:
log_event(
"SIP_EXECUTED",
{
"nifty_units": applied["nifty_units"],
"gold_units": applied["gold_units"],
"nifty_price": sp_price_val,
"gold_price": gd_price_val,
"amount": applied["amount"],
},
cur=cur,
ts=now_ts,
logical_time=logical_time,
)
else:
insert_engine_event(
cur,
"SIP_NO_FILL",
data={
"reason": failure_reason or ("broker_auth_expired" if auth_failed else "no_fill"),
"orders": orders,
},
ts=now_ts,
)
return state, executed
return run_with_retry(_op)
def _try_execute_sip_live(
now,
market_open,
sip_interval,
sip_amount,
sp_price,
gd_price,
eq_w,
gd_w,
broker: Broker | None,
mode: str | None,
):
now_ts = _normalize_now(now)
if not market_open or broker is None:
return load_state(mode=mode), False
sp_price_val = _as_float(sp_price)
gd_price_val = _as_float(gd_price)
eq_w_val = _as_float(eq_w)
gd_w_val = _as_float(gd_w)
sip_amount_val = _as_float(sip_amount)
nifty_qty = (sip_amount_val * eq_w_val) / sp_price_val
gold_qty = (sip_amount_val * gd_w_val) / gd_price_val
prepared = _prepare_live_execution(
now_ts,
sip_interval,
sip_amount_val,
sp_price_val,
gd_price_val,
nifty_qty,
gold_qty,
mode,
)
if not prepared.get("ready"):
return prepared.get("state") or load_state(mode=mode), False
logical_time = prepared["logical_time"]
orders = []
funds_after = None
failure_reason = None
auth_failed = False
try:
funds_before = broker.get_funds()
cash = funds_before.get("cash")
if cash is not None and float(cash) < sip_amount_val:
failure_reason = "insufficient_funds"
else:
if nifty_qty > 0:
orders.append(
broker.place_order(
"NIFTYBEES.NS",
"BUY",
nifty_qty,
sp_price_val,
logical_time=logical_time,
)
)
if gold_qty > 0:
orders.append(
broker.place_order(
"GOLDBEES.NS",
"BUY",
gold_qty,
gd_price_val,
logical_time=logical_time,
)
)
funds_after = broker.get_funds()
except BrokerAuthExpired:
auth_failed = True
try:
funds_after = broker.get_funds()
except Exception:
funds_after = None
except Exception as exc:
failure_reason = str(exc)
try:
funds_after = broker.get_funds()
except Exception:
funds_after = None
state, _executed = _finalize_live_execution(
now_ts=now_ts,
mode=mode,
logical_time=logical_time,
orders=orders,
funds_after=funds_after,
sp_price_val=sp_price_val,
gd_price_val=gd_price_val,
auth_failed=False,
failure_reason=failure_reason,
)
raise
state, executed = _finalize_live_execution(
now_ts=now_ts,
mode=mode,
logical_time=logical_time,
orders=orders,
funds_after=funds_after,
sp_price_val=sp_price_val,
gd_price_val=gd_price_val,
auth_failed=auth_failed,
failure_reason=failure_reason,
)
if auth_failed:
raise BrokerAuthExpired("Broker session expired during live order execution")
return state, executed
def try_execute_sip(
now,
market_open,
sip_interval,
sip_amount,
sp_price,
gd_price,
eq_w,
gd_w,
broker: Broker | None = None,
mode: str | None = "LIVE",
):
if broker is None:
return load_state(mode=mode), False
if getattr(broker, "external_orders", False):
return _try_execute_sip_live(
now,
market_open,
sip_interval,
sip_amount,
sp_price,
gd_price,
eq_w,
gd_w,
broker,
mode,
)
return _try_execute_sip_paper(
now,
market_open,
sip_interval,
sip_amount,
sp_price,
gd_price,
eq_w,
gd_w,
broker,
mode,
)

View File

@ -10,12 +10,26 @@ STORAGE_DIR.mkdir(exist_ok=True)
CACHE_DIR = STORAGE_DIR / "history" CACHE_DIR = STORAGE_DIR / "history"
CACHE_DIR.mkdir(exist_ok=True) CACHE_DIR.mkdir(exist_ok=True)
def load_monthly_close(ticker, years=10): def _cache_file(ticker: str, provider: str = "yfinance") -> Path:
file = CACHE_DIR / f"{ticker}.csv" safe_ticker = (ticker or "").replace(":", "_").replace("/", "_")
return CACHE_DIR / f"{safe_ticker}.csv"
def _read_monthly_close(file: Path):
df = pd.read_csv(file, parse_dates=["Date"], index_col="Date")
return df["Close"]
def load_monthly_close(
ticker,
years=10,
*,
provider: str = "yfinance",
user_id: str | None = None,
run_id: str | None = None,
):
file = _cache_file(ticker, provider="yfinance")
if file.exists(): if file.exists():
df = pd.read_csv(file, parse_dates=["Date"], index_col="Date") return _read_monthly_close(file)
return df["Close"]
df = yf.download( df = yf.download(
ticker, ticker,
@ -28,7 +42,7 @@ def load_monthly_close(ticker, years=10):
if df.empty: if df.empty:
raise RuntimeError(f"No history for {ticker}") raise RuntimeError(f"No history for {ticker}")
series = df["Close"].resample("M").last() series = df["Close"].resample("ME").last()
series.to_csv(file, header=["Close"]) series.to_csv(file, header=["Close"])
return series return series

View File

@ -3,15 +3,18 @@ import threading
import time import time
from datetime import datetime, timedelta, timezone from datetime import datetime, timedelta, timezone
from psycopg2.extras import Json
from indian_paper_trading_strategy.engine.market import is_market_open, align_to_market_open from indian_paper_trading_strategy.engine.market import is_market_open, align_to_market_open
from indian_paper_trading_strategy.engine.execution import try_execute_sip from indian_paper_trading_strategy.engine.execution import try_execute_sip
from indian_paper_trading_strategy.engine.broker import PaperBroker from indian_paper_trading_strategy.engine.broker import PaperBroker, LiveZerodhaBroker, BrokerAuthExpired
from indian_paper_trading_strategy.engine.mtm import log_mtm, should_log_mtm from indian_paper_trading_strategy.engine.mtm import log_mtm, should_log_mtm
from indian_paper_trading_strategy.engine.state import load_state from indian_paper_trading_strategy.engine.state import load_state
from indian_paper_trading_strategy.engine.data import fetch_live_price from indian_paper_trading_strategy.engine.data import fetch_live_price
from indian_paper_trading_strategy.engine.history import load_monthly_close from indian_paper_trading_strategy.engine.history import load_monthly_close
from indian_paper_trading_strategy.engine.strategy import allocation from indian_paper_trading_strategy.engine.strategy import allocation
from indian_paper_trading_strategy.engine.time_utils import normalize_logical_time from indian_paper_trading_strategy.engine.time_utils import normalize_logical_time
from app.services.zerodha_service import KiteTokenError
from indian_paper_trading_strategy.engine.db import db_transaction, insert_engine_event, run_with_retry, get_context, set_context from indian_paper_trading_strategy.engine.db import db_transaction, insert_engine_event, run_with_retry, get_context, set_context
@ -139,6 +142,69 @@ def can_execute(now: datetime) -> tuple[bool, str]:
return False, "MARKET_CLOSED" return False, "MARKET_CLOSED"
return True, "OK" return True, "OK"
def _last_execution_anchor(state: dict, mode: str) -> str | None:
mode_key = (mode or "LIVE").strip().upper()
if mode_key == "LIVE":
return state.get("last_sip_ts")
return state.get("last_run") or state.get("last_sip_ts")
def _pause_for_auth_expiry(
user_id: str,
run_id: str,
reason: str,
emit_event_cb=None,
):
def _op(cur, _conn):
now = datetime.utcnow().replace(tzinfo=timezone.utc)
cur.execute(
"""
UPDATE strategy_run
SET status = 'STOPPED',
stopped_at = %s,
meta = COALESCE(meta, '{}'::jsonb) || %s
WHERE user_id = %s AND run_id = %s
""",
(
now,
Json({"reason": "auth_expired", "lifecycle": "auth_expired"}),
user_id,
run_id,
),
)
run_with_retry(_op)
_set_state(
user_id,
run_id,
state="STOPPED",
last_heartbeat_ts=datetime.utcnow().isoformat() + "Z",
)
_update_engine_status(user_id, run_id, "STOPPED")
log_event(
event="BROKER_AUTH_EXPIRED",
message="Broker authentication expired",
meta={"reason": reason},
)
log_event(
event="ENGINE_PAUSED",
message="Engine paused until broker reconnect",
meta={"reason": "auth_expired"},
)
if callable(emit_event_cb):
emit_event_cb(
event="BROKER_AUTH_EXPIRED",
message="Broker authentication expired",
meta={"reason": reason},
)
emit_event_cb(
event="STRATEGY_STOPPED",
message="Strategy stopped",
meta={"reason": "broker_auth_expired"},
)
def _engine_loop(config, stop_event: threading.Event): def _engine_loop(config, stop_event: threading.Event):
print("Strategy engine started with config:", config) print("Strategy engine started with config:", config)
@ -173,24 +239,12 @@ def _engine_loop(config, stop_event: threading.Event):
mode = (config.get("mode") or "LIVE").strip().upper() mode = (config.get("mode") or "LIVE").strip().upper()
if mode not in {"PAPER", "LIVE"}: if mode not in {"PAPER", "LIVE"}:
mode = "LIVE" mode = "LIVE"
broker_type = config.get("broker") or "paper" broker_type = (config.get("broker") or ("paper" if mode == "PAPER" else "zerodha")).strip().lower()
if broker_type != "paper": initial_cash = float(config.get("initial_cash", 0))
broker_type = "paper"
if broker_type == "paper": if broker_type == "paper":
mode = "PAPER" mode = "PAPER"
initial_cash = float(config.get("initial_cash", 0)) broker = PaperBroker(initial_cash=initial_cash)
broker = PaperBroker(initial_cash=initial_cash) log_event(
log_event(
event="DEBUG_PAPER_STORE_PATH",
message="Paper broker store path",
meta={
"cwd": os.getcwd(),
"paper_store_path": str(broker.store_path) if hasattr(broker, "store_path") else "NO_STORE_PATH",
"abs_store_path": os.path.abspath(str(broker.store_path)) if hasattr(broker, "store_path") else "N/A",
},
)
if emit_event_cb:
emit_event_cb(
event="DEBUG_PAPER_STORE_PATH", event="DEBUG_PAPER_STORE_PATH",
message="Paper broker store path", message="Paper broker store path",
meta={ meta={
@ -199,6 +253,21 @@ def _engine_loop(config, stop_event: threading.Event):
"abs_store_path": os.path.abspath(str(broker.store_path)) if hasattr(broker, "store_path") else "N/A", "abs_store_path": os.path.abspath(str(broker.store_path)) if hasattr(broker, "store_path") else "N/A",
}, },
) )
if emit_event_cb:
emit_event_cb(
event="DEBUG_PAPER_STORE_PATH",
message="Paper broker store path",
meta={
"cwd": os.getcwd(),
"paper_store_path": str(broker.store_path) if hasattr(broker, "store_path") else "NO_STORE_PATH",
"abs_store_path": os.path.abspath(str(broker.store_path)) if hasattr(broker, "store_path") else "N/A",
},
)
elif broker_type == "zerodha":
broker = LiveZerodhaBroker(user_id=scope_user, run_id=scope_run)
else:
raise ValueError(f"Unsupported broker: {broker_type}")
market_data_provider = "yfinance"
log_event("ENGINE_START", { log_event("ENGINE_START", {
"strategy": strategy_name, "strategy": strategy_name,
@ -243,7 +312,7 @@ def _engine_loop(config, stop_event: threading.Event):
delta = timedelta(days=freq) delta = timedelta(days=freq)
# Gate 2: time to SIP # Gate 2: time to SIP
last_run = state.get("last_run") or state.get("last_sip_ts") last_run = _last_execution_anchor(state, mode)
is_first_run = last_run is None is_first_run = last_run is None
now = datetime.now() now = datetime.now()
debug_event( debug_event(
@ -282,21 +351,47 @@ def _engine_loop(config, stop_event: threading.Event):
try: try:
debug_event("PRICE_FETCH_START", "fetching live prices", {"tickers": [NIFTY, GOLD]}) debug_event("PRICE_FETCH_START", "fetching live prices", {"tickers": [NIFTY, GOLD]})
nifty_price = fetch_live_price(NIFTY) nifty_price = fetch_live_price(
gold_price = fetch_live_price(GOLD) NIFTY,
provider=market_data_provider,
user_id=scope_user,
run_id=scope_run,
)
gold_price = fetch_live_price(
GOLD,
provider=market_data_provider,
user_id=scope_user,
run_id=scope_run,
)
debug_event( debug_event(
"PRICE_FETCHED", "PRICE_FETCHED",
"fetched live prices", "fetched live prices",
{"nifty_price": float(nifty_price), "gold_price": float(gold_price)}, {"nifty_price": float(nifty_price), "gold_price": float(gold_price)},
) )
except KiteTokenError as exc:
_pause_for_auth_expiry(scope_user, scope_run, str(exc), emit_event_cb=emit_event_cb)
break
except Exception as exc: except Exception as exc:
debug_event("PRICE_FETCH_ERROR", "live price fetch failed", {"error": str(exc)}) debug_event("PRICE_FETCH_ERROR", "live price fetch failed", {"error": str(exc)})
sleep_with_heartbeat(30, stop_event, scope_user, scope_run) sleep_with_heartbeat(30, stop_event, scope_user, scope_run)
continue continue
try: try:
nifty_hist = load_monthly_close(NIFTY) nifty_hist = load_monthly_close(
gold_hist = load_monthly_close(GOLD) NIFTY,
provider=market_data_provider,
user_id=scope_user,
run_id=scope_run,
)
gold_hist = load_monthly_close(
GOLD,
provider=market_data_provider,
user_id=scope_user,
run_id=scope_run,
)
except KiteTokenError as exc:
_pause_for_auth_expiry(scope_user, scope_run, str(exc), emit_event_cb=emit_event_cb)
break
except Exception as exc: except Exception as exc:
debug_event("HISTORY_LOAD_ERROR", "history load failed", {"error": str(exc)}) debug_event("HISTORY_LOAD_ERROR", "history load failed", {"error": str(exc)})
sleep_with_heartbeat(30, stop_event, scope_user, scope_run) sleep_with_heartbeat(30, stop_event, scope_user, scope_run)
@ -449,6 +544,9 @@ def _engine_loop(config, stop_event: threading.Event):
) )
sleep_with_heartbeat(30, stop_event, scope_user, scope_run) sleep_with_heartbeat(30, stop_event, scope_user, scope_run)
except BrokerAuthExpired as exc:
_pause_for_auth_expiry(scope_user, scope_run, str(exc), emit_event_cb=emit_event_cb)
print(f"[ENGINE] broker auth expired for run {scope_run}: {exc}", flush=True)
except Exception as e: except Exception as e:
_set_state(scope_user, scope_run, state="ERROR", last_heartbeat_ts=datetime.utcnow().isoformat() + "Z") _set_state(scope_user, scope_run, state="ERROR", last_heartbeat_ts=datetime.utcnow().isoformat() + "Z")
_update_engine_status(scope_user, scope_run, "ERROR") _update_engine_status(scope_user, scope_run, "ERROR")