# engine/state.py
"""Load and persist the strategy engine's per-user, per-run state.

Two modes are supported, each backed by its own table:

* ``LIVE``  -> ``engine_state`` (invested amount, unit holdings, timestamps)
* ``PAPER`` -> ``engine_state_paper`` (the above plus cash and SIP frequency)

Any mode string other than ``"PAPER"`` (case-insensitive) is treated as LIVE.
"""

import copy
from datetime import datetime, timezone

from indian_paper_trading_strategy.engine.db import (
    db_connection,
    get_context,
    insert_engine_event,
    run_with_retry,
)

# Baseline state returned when no row exists for a LIVE run.
DEFAULT_STATE = {
    "initial_cash": 0.0,
    "cash": 0.0,
    "total_invested": 0.0,
    "nifty_units": 0.0,
    "gold_units": 0.0,
    "last_sip_ts": None,
    "last_run": None,
    "sip_frequency": None,
}

# Baseline state for a PAPER run: seeded cash and a default SIP cadence.
DEFAULT_PAPER_STATE = {
    **DEFAULT_STATE,
    "initial_cash": 1_000_000.0,
    "cash": 1_000_000.0,
    "sip_frequency": {"value": 30, "unit": "days"},
}

# Numeric columns, in the exact order they appear in the SELECT / INSERT
# statements below. The order is load-bearing: rows are mapped positionally.
_PAPER_NUMERIC_FIELDS = (
    "initial_cash",
    "cash",
    "total_invested",
    "nifty_units",
    "gold_units",
)
_LIVE_NUMERIC_FIELDS = ("total_invested", "nifty_units", "gold_units")

# State keys copied into the STATE_UPDATED event payload, per mode.
_PAPER_EVENT_FIELDS = (
    "cash",
    "total_invested",
    "nifty_units",
    "gold_units",
    "last_sip_ts",
    "last_run",
)
_LIVE_EVENT_FIELDS = (
    "total_invested",
    "nifty_units",
    "gold_units",
    "last_sip_ts",
    "last_run",
)

_PAPER_UPSERT_SQL = """
    INSERT INTO engine_state_paper (
        user_id, run_id, initial_cash, cash, total_invested,
        nifty_units, gold_units, last_sip_ts, last_run,
        sip_frequency_value, sip_frequency_unit
    )
    VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
    ON CONFLICT (user_id, run_id) DO UPDATE
    SET initial_cash = EXCLUDED.initial_cash,
        cash = EXCLUDED.cash,
        total_invested = EXCLUDED.total_invested,
        nifty_units = EXCLUDED.nifty_units,
        gold_units = EXCLUDED.gold_units,
        last_sip_ts = EXCLUDED.last_sip_ts,
        last_run = EXCLUDED.last_run,
        sip_frequency_value = EXCLUDED.sip_frequency_value,
        sip_frequency_unit = EXCLUDED.sip_frequency_unit
"""

_LIVE_UPSERT_SQL = """
    INSERT INTO engine_state (
        user_id, run_id, total_invested, nifty_units, gold_units,
        last_sip_ts, last_run
    )
    VALUES (%s, %s, %s, %s, %s, %s, %s)
    ON CONFLICT (user_id, run_id) DO UPDATE
    SET total_invested = EXCLUDED.total_invested,
        nifty_units = EXCLUDED.nifty_units,
        gold_units = EXCLUDED.gold_units,
        last_sip_ts = EXCLUDED.last_sip_ts,
        last_run = EXCLUDED.last_run
"""


def _state_key(mode: str | None) -> str:
    """Normalise *mode* to ``"PAPER"`` or ``"LIVE"``.

    ``None``, empty and unrecognised values all fall back to ``"LIVE"``.
    """
    key = (mode or "LIVE").strip().upper()
    return "PAPER" if key == "PAPER" else "LIVE"


def _default_state(mode: str | None) -> dict:
    """Return a fresh, independent copy of the default state for *mode*.

    A deep copy is required: ``DEFAULT_PAPER_STATE["sip_frequency"]`` is a
    nested dict, and a shallow copy would hand every caller the same object,
    letting one caller's mutation leak into the module-level default.
    """
    if _state_key(mode) == "PAPER":
        return copy.deepcopy(DEFAULT_PAPER_STATE)
    return copy.deepcopy(DEFAULT_STATE)


def _local_tz():
    """Return the tzinfo of the system's local zone as of the current moment."""
    return datetime.now().astimezone().tzinfo


def _format_local_ts(value: datetime | None) -> str | None:
    """Render *value* as a naive ISO-8601 string in local time.

    Returns ``None`` when *value* is ``None``. The offset is stripped after
    conversion, so the resulting string carries no timezone designator.
    """
    if value is None:
        return None
    return value.astimezone(_local_tz()).replace(tzinfo=None).isoformat()


def _parse_ts(value) -> datetime | None:
    """Coerce *value* into a timezone-aware ``datetime``, or ``None``.

    Accepts ``datetime`` objects and ISO-8601 strings (a trailing ``Z`` is
    read as UTC). Naive inputs are assumed to be in the local zone. Blank
    strings, unparseable strings and any other type yield ``None`` rather
    than raising.
    """
    if value is None:
        return None

    if isinstance(value, datetime):
        if value.tzinfo is None:
            return value.replace(tzinfo=_local_tz())
        return value

    if isinstance(value, str):
        text = value.strip()
        if not text:
            return None
        try:
            parsed = datetime.fromisoformat(text.replace("Z", "+00:00"))
        except ValueError:
            return None
        if parsed.tzinfo is None:
            parsed = parsed.replace(tzinfo=_local_tz())
        return parsed

    return None


def _resolve_scope(user_id: str | None, run_id: str | None):
    """Resolve the ``(user_id, run_id)`` pair used to key state rows.

    Thin wrapper over ``get_context``.
    NOTE(review): presumably fills in missing ids from ambient context and is
    idempotent for already-resolved ids -- confirm against ``engine.db``.
    """
    return get_context(user_id, run_id)


def _merge_numeric(merged: dict, fields: tuple, row) -> None:
    """Overwrite ``merged[field]`` with ``float(row[i])`` for non-NULL columns.

    *fields* lists the state keys in the same order as the leading columns of
    *row*; NULL columns leave the existing default in place.
    """
    for name, raw in zip(fields, row):
        if raw is not None:
            merged[name] = float(raw)


def load_state(
    mode: str | None = "LIVE",
    *,
    cur=None,
    for_update: bool = False,
    user_id: str | None = None,
    run_id: str | None = None,
) -> dict:
    """Load the stored engine state for a user/run, or the mode's defaults.

    Args:
        mode: ``"PAPER"`` selects ``engine_state_paper``; anything else
            selects ``engine_state``.
        cur: Open DB cursor to query with. When ``None`` a connection and
            cursor are opened via ``db_connection()`` for this call only.
        for_update: Append ``FOR UPDATE`` to the SELECT to lock the row.
        user_id: User scope; passed through ``get_context``.
        run_id: Run scope; passed through ``get_context``.

    Returns:
        A new dict with the keys of ``DEFAULT_STATE``. Timestamps are naive
        local ISO-8601 strings (or ``None``). NULL numeric columns keep the
        mode's default value. In LIVE mode the cash and SIP-frequency keys
        always hold their defaults because that table does not store them.
    """
    scope_user, scope_run = _resolve_scope(user_id, run_id)
    key = _state_key(mode)

    if cur is None:
        # No cursor supplied: open one and re-enter with it so the query
        # logic below exists in exactly one place.
        with db_connection() as conn:
            with conn.cursor() as own_cur:
                return load_state(
                    mode=mode,
                    cur=own_cur,
                    for_update=for_update,
                    user_id=scope_user,
                    run_id=scope_run,
                )

    # lock_clause is one of two fixed literals, never caller-supplied text,
    # so interpolating it into the statement is safe.
    lock_clause = " FOR UPDATE" if for_update else ""

    if key == "PAPER":
        cur.execute(
            f"""
            SELECT initial_cash, cash, total_invested, nifty_units, gold_units,
                   last_sip_ts, last_run, sip_frequency_value, sip_frequency_unit
            FROM engine_state_paper
            WHERE user_id = %s AND run_id = %s{lock_clause}
            LIMIT 1
            """,
            (scope_user, scope_run),
        )
        row = cur.fetchone()
        merged = _default_state(mode)
        if not row:
            return merged

        _merge_numeric(merged, _PAPER_NUMERIC_FIELDS, row)
        merged["last_sip_ts"] = _format_local_ts(row[5])
        merged["last_run"] = _format_local_ts(row[6])
        # Only replace the default cadence when the row stores at least one
        # of the two frequency columns.
        if row[7] is not None or row[8] is not None:
            merged["sip_frequency"] = {"value": row[7], "unit": row[8]}
        return merged

    cur.execute(
        f"""
        SELECT total_invested, nifty_units, gold_units, last_sip_ts, last_run
        FROM engine_state
        WHERE user_id = %s AND run_id = %s{lock_clause}
        LIMIT 1
        """,
        (scope_user, scope_run),
    )
    row = cur.fetchone()
    merged = _default_state(mode)
    if not row:
        return merged

    _merge_numeric(merged, _LIVE_NUMERIC_FIELDS, row)
    merged["last_sip_ts"] = _format_local_ts(row[3])
    merged["last_run"] = _format_local_ts(row[4])
    return merged


def init_paper_state(
    initial_cash: float,
    sip_frequency: dict | None = None,
    *,
    cur=None,
    user_id: str | None = None,
    run_id: str | None = None,
) -> dict:
    """Reset the PAPER state for a user/run and persist it.

    Cash and initial cash are both set to *initial_cash*; holdings, invested
    amount and timestamps are cleared. A ``STATE_UPDATED`` event is emitted.

    Args:
        initial_cash: Starting cash balance.
        sip_frequency: ``{"value": ..., "unit": ...}``. When falsy, a private
            copy of the default cadence is used.
        cur: Optional open cursor, forwarded to ``save_state``.
        user_id: User scope, forwarded to ``save_state``.
        run_id: Run scope, forwarded to ``save_state``.

    Returns:
        The state dict that was saved.
    """
    # Deep copy so the fallback sip_frequency is not the module-level dict.
    state = _default_state("PAPER")
    state.update(
        {
            "initial_cash": float(initial_cash),
            "cash": float(initial_cash),
            "total_invested": 0.0,
            "nifty_units": 0.0,
            "gold_units": 0.0,
            "last_sip_ts": None,
            "last_run": None,
            "sip_frequency": sip_frequency or state.get("sip_frequency"),
        }
    )
    save_state(
        state,
        mode="PAPER",
        cur=cur,
        emit_event=True,
        user_id=user_id,
        run_id=run_id,
    )
    return state


def save_state(
    state,
    mode: str | None = "LIVE",
    *,
    cur=None,
    emit_event: bool = False,
    event_meta: dict | None = None,
    user_id: str | None = None,
    run_id: str | None = None,
):
    """Upsert *state* into the table for *mode*, optionally emitting an event.

    Args:
        state: Mapping with the keys of ``DEFAULT_STATE``. Missing numeric
            keys are stored as ``0.0``; timestamps go through ``_parse_ts``.
        mode: ``"PAPER"`` writes ``engine_state_paper``; anything else writes
            ``engine_state`` (cash and SIP frequency are not stored there).
        cur: Open cursor to write with. When given, the write happens on it
            directly and the caller owns the transaction.
        emit_event: Also insert a ``STATE_UPDATED`` engine event on the same
            cursor as the upsert.
        event_meta: Passed as ``meta`` to ``insert_engine_event``.
        user_id: User scope; passed through ``get_context``.
        run_id: Run scope; passed through ``get_context``.

    Returns:
        ``None`` when *cur* is supplied; otherwise whatever
        ``run_with_retry`` returns.
    """
    scope_user, scope_run = _resolve_scope(user_id, run_id)
    key = _state_key(mode)
    last_sip_ts = _parse_ts(state.get("last_sip_ts"))
    last_run = _parse_ts(state.get("last_run"))

    if key == "PAPER":
        sip_frequency = state.get("sip_frequency")
        sip_value = None
        sip_unit = None
        if isinstance(sip_frequency, dict):
            sip_value = sip_frequency.get("value")
            sip_unit = sip_frequency.get("unit")
        sql = _PAPER_UPSERT_SQL
        numeric_fields = _PAPER_NUMERIC_FIELDS
        trailing_params = (last_sip_ts, last_run, sip_value, sip_unit)
        event_fields = _PAPER_EVENT_FIELDS
    else:
        sql = _LIVE_UPSERT_SQL
        numeric_fields = _LIVE_NUMERIC_FIELDS
        trailing_params = (last_sip_ts, last_run)
        event_fields = _LIVE_EVENT_FIELDS

    def _save(active_cur):
        # Parameters are built at execution time, so each retry attempt
        # reads the state dict afresh.
        params = (
            scope_user,
            scope_run,
            *(float(state.get(name, 0.0)) for name in numeric_fields),
            *trailing_params,
        )
        active_cur.execute(sql, params)
        if emit_event:
            insert_engine_event(
                active_cur,
                "STATE_UPDATED",
                data={
                    "mode": key,
                    **{name: state.get(name) for name in event_fields},
                },
                meta=event_meta,
                # Aware UTC "now"; datetime.utcnow() is deprecated.
                ts=datetime.now(timezone.utc),
            )

    if cur is not None:
        _save(cur)
        return None

    def _op(retry_cur, _conn):
        _save(retry_cur)

    # NOTE(review): run_with_retry presumably supplies (cursor, connection)
    # and handles commit and retry -- confirm against engine.db.
    return run_with_retry(_op)