2026-02-01 13:57:30 +00:00

304 lines
10 KiB
Python

# engine/state.py
import copy
from datetime import datetime, timezone

from indian_paper_trading_strategy.engine.db import db_connection, insert_engine_event, run_with_retry, get_context
# Baseline engine state: zero cash, zero holdings, no SIP executed or run
# recorded yet, and no SIP cadence configured.
DEFAULT_STATE = dict(
    initial_cash=0.0,
    cash=0.0,
    total_invested=0.0,
    nifty_units=0.0,
    gold_units=0.0,
    last_sip_ts=None,
    last_run=None,
    sip_frequency=None,
)

# Baseline used for the "PAPER" state key: same keys (and key order) as
# DEFAULT_STATE, seeded with 1,000,000.0 of cash and a 30-day SIP cadence.
DEFAULT_PAPER_STATE = DEFAULT_STATE | {
    "initial_cash": 1_000_000.0,
    "cash": 1_000_000.0,
    "sip_frequency": {"value": 30, "unit": "days"},
}
def _state_key(mode: str | None):
key = (mode or "LIVE").strip().upper()
return "PAPER" if key == "PAPER" else "LIVE"
def _default_state(mode: str | None):
    """Return a fresh, fully independent default state dict for *mode*.

    ``DEFAULT_PAPER_STATE`` is the template when *mode* normalises to
    ``"PAPER"``; ``DEFAULT_STATE`` is used for everything else.

    The template is deep-copied rather than shallow-copied: the paper
    template holds a nested ``sip_frequency`` dict, and a shallow copy would
    hand every caller the same inner dict, so mutating one returned state
    would silently rewrite the module-level default.
    """
    template = DEFAULT_PAPER_STATE if _state_key(mode) == "PAPER" else DEFAULT_STATE
    return copy.deepcopy(template)
def _local_tz():
    """Return the ``tzinfo`` attached to the current time in the system's local zone."""
    now_local = datetime.now().astimezone()
    return now_local.tzinfo
def _format_local_ts(value: datetime | None):
if value is None:
return None
return value.astimezone(_local_tz()).replace(tzinfo=None).isoformat()
def _parse_ts(value):
    """Coerce *value* into a timezone-aware ``datetime``, or return ``None``.

    * ``datetime`` -- returned unchanged when already aware; a naive one is
      tagged with the zone from ``_local_tz()`` (no clock conversion).
    * ``str`` -- stripped, every ``"Z"`` replaced by ``"+00:00"``, then parsed
      with ``datetime.fromisoformat``. Blank or unparsable text yields
      ``None``; a naive result is tagged with the local zone as above.
    * anything else, including ``None`` -- ``None``.
    """
    if isinstance(value, str):
        text = value.strip()
        if not text:
            return None
        try:
            value = datetime.fromisoformat(text.replace("Z", "+00:00"))
        except ValueError:
            # Best effort: an unreadable timestamp is treated as absent.
            return None
    elif not isinstance(value, datetime):
        return None

    if value.tzinfo is None:
        return value.replace(tzinfo=_local_tz())
    return value
def _resolve_scope(user_id: str | None, run_id: str | None):
    """Resolve the scope that a state read or write applies to.

    Delegates to ``get_context`` from the engine db module and returns its
    result untouched; callers in this module unpack it as a
    ``(user_id, run_id)`` pair.
    """
    scope = get_context(user_id, run_id)
    return scope
def load_state(
    mode: str | None = "LIVE",
    *,
    cur=None,
    for_update: bool = False,
    user_id: str | None = None,
    run_id: str | None = None,
):
    """Load the persisted engine state for one user/run as a plain dict.

    The row is read from ``engine_state_paper`` when *mode* normalises to
    ``"PAPER"`` and from ``engine_state`` otherwise. The result always starts
    from ``_default_state(mode)``; non-NULL numeric columns overwrite the
    defaults as ``float`` and the two timestamps are rendered through
    ``_format_local_ts``. When no row exists the defaults are returned as-is.

    Args:
        mode: State flavour; anything other than "PAPER" is treated as "LIVE".
        cur: Open DB cursor to run the query on. When ``None`` a connection
            and cursor are opened via ``db_connection()`` for this call only.
        for_update: Append ``FOR UPDATE`` to the SELECT.
            NOTE(review): with ``cur=None`` the lock is taken on a connection
            that is closed again before this function returns -- confirm
            callers that need the lock always pass their own cursor.
        user_id: Optional scope override, resolved through ``_resolve_scope``.
        run_id: Optional scope override, resolved through ``_resolve_scope``.

    Returns:
        dict with the keys of ``DEFAULT_STATE``. In LIVE mode ``initial_cash``,
        ``cash`` and ``sip_frequency`` are never read from the database and
        keep their default values.
    """
    scope_user, scope_run = _resolve_scope(user_id, run_id)
    key = _state_key(mode)

    # No cursor supplied: open one and re-enter with the scope already resolved.
    if cur is None:
        with db_connection() as conn:
            with conn.cursor() as fresh_cur:
                return load_state(
                    mode=mode,
                    cur=fresh_cur,
                    for_update=for_update,
                    user_id=scope_user,
                    run_id=scope_run,
                )

    lock_clause = " FOR UPDATE" if for_update else ""

    def _hydrate(row, numeric_columns):
        """Overlay *row* on the defaults.

        *row* holds the numeric columns first, followed by last_sip_ts and
        last_run.
        """
        state = _default_state(mode)
        if not row:
            return state
        for position, column in enumerate(numeric_columns):
            raw = row[position]
            if raw is not None:
                state[column] = float(raw)
        first_ts = len(numeric_columns)
        state["last_sip_ts"] = _format_local_ts(row[first_ts])
        state["last_run"] = _format_local_ts(row[first_ts + 1])
        return state

    if key != "PAPER":
        cur.execute(
            f"""
            SELECT total_invested, nifty_units, gold_units, last_sip_ts, last_run
            FROM engine_state
            WHERE user_id = %s AND run_id = %s{lock_clause}
            LIMIT 1
            """,
            (scope_user, scope_run),
        )
        return _hydrate(
            cur.fetchone(),
            ("total_invested", "nifty_units", "gold_units"),
        )

    cur.execute(
        f"""
        SELECT initial_cash, cash, total_invested, nifty_units, gold_units,
        last_sip_ts, last_run, sip_frequency_value, sip_frequency_unit
        FROM engine_state_paper
        WHERE user_id = %s AND run_id = %s{lock_clause}
        LIMIT 1
        """,
        (scope_user, scope_run),
    )
    row = cur.fetchone()
    state = _hydrate(
        row,
        ("initial_cash", "cash", "total_invested", "nifty_units", "gold_units"),
    )
    if row:
        sip_value, sip_unit = row[7], row[8]
        # Keep the default cadence only when both stored parts are NULL.
        if sip_value is not None or sip_unit is not None:
            state["sip_frequency"] = {"value": sip_value, "unit": sip_unit}
    return state
def init_paper_state(
    initial_cash: float,
    sip_frequency: dict | None = None,
    *,
    cur=None,
    user_id: str | None = None,
    run_id: str | None = None,
):
    """Create, persist and return a brand-new paper-trading state.

    The state starts with ``initial_cash`` as both ``initial_cash`` and
    ``cash``, no holdings and no recorded SIP or run timestamps. It is written
    through ``save_state(..., mode="PAPER", emit_event=True)``.

    Args:
        initial_cash: Starting cash; coerced with ``float()``.
        sip_frequency: Cadence mapping stored under ``"sip_frequency"``. A
            falsy value (``None`` or an empty dict) selects the cadence from
            ``DEFAULT_PAPER_STATE``.
        cur: Optional open DB cursor, forwarded to ``save_state``.
        user_id: Optional scope override, forwarded to ``save_state``.
        run_id: Optional scope override, forwarded to ``save_state``.

    Returns:
        The state dict that was saved.
    """
    opening_cash = float(initial_cash)

    # Deep copy: a shallow copy would leave the returned state pointing at
    # DEFAULT_PAPER_STATE's own nested sip_frequency dict, so a caller editing
    # the cadence in place would rewrite the default for every later run.
    state = copy.deepcopy(DEFAULT_PAPER_STATE)
    state.update(
        {
            "initial_cash": opening_cash,
            "cash": opening_cash,
            "total_invested": 0.0,
            "nifty_units": 0.0,
            "gold_units": 0.0,
            "last_sip_ts": None,
            "last_run": None,
        }
    )
    if sip_frequency:
        state["sip_frequency"] = sip_frequency

    save_state(state, mode="PAPER", cur=cur, emit_event=True, user_id=user_id, run_id=run_id)
    return state
def save_state(
    state,
    mode: str | None = "LIVE",
    *,
    cur=None,
    emit_event: bool = False,
    event_meta: dict | None = None,
    user_id: str | None = None,
    run_id: str | None = None,
):
    """Upsert *state* for one user/run and optionally record an engine event.

    PAPER mode writes to ``engine_state_paper`` (cash columns and SIP cadence
    included); every other mode writes to ``engine_state``. Both statements
    are ``INSERT ... ON CONFLICT (user_id, run_id) DO UPDATE``.

    Args:
        state: Mapping with the state keys. Missing numeric keys default to
            0.0; ``last_sip_ts`` / ``last_run`` are normalised by ``_parse_ts``.
        mode: State flavour; anything other than "PAPER" is treated as "LIVE".
        cur: Open DB cursor. When given, the write happens on it directly;
            when ``None`` the write is submitted through ``run_with_retry``.
        emit_event: When true, also insert a ``STATE_UPDATED`` engine event on
            the same cursor, stamped with the current UTC time.
        event_meta: Passed through as the event's ``meta``.
        user_id: Optional scope override, resolved through ``_resolve_scope``.
        run_id: Optional scope override, resolved through ``_resolve_scope``.

    Returns:
        ``None`` when *cur* was supplied, otherwise whatever
        ``run_with_retry`` returns.
    """
    scope_user, scope_run = _resolve_scope(user_id, run_id)
    key = _state_key(mode)
    last_sip_ts = _parse_ts(state.get("last_sip_ts"))
    last_run = _parse_ts(state.get("last_run"))

    def _holdings_snapshot():
        """Event payload fields shared by both modes, in their emitted order."""
        return {
            "total_invested": state.get("total_invested"),
            "nifty_units": state.get("nifty_units"),
            "gold_units": state.get("gold_units"),
            "last_sip_ts": state.get("last_sip_ts"),
            "last_run": state.get("last_run"),
        }

    def _emit(db_cur, data):
        """Insert the STATE_UPDATED event on the cursor that did the write."""
        insert_engine_event(
            db_cur,
            "STATE_UPDATED",
            data=data,
            meta=event_meta,
            # datetime.utcnow() is deprecated (Python 3.12+); request an
            # aware UTC timestamp directly instead of patching tzinfo on.
            ts=datetime.now(timezone.utc),
        )

    if key == "PAPER":
        sip_frequency = state.get("sip_frequency")
        sip_value = None
        sip_unit = None
        # Anything that is not a dict is stored as a NULL cadence.
        if isinstance(sip_frequency, dict):
            sip_value = sip_frequency.get("value")
            sip_unit = sip_frequency.get("unit")

        def _save(db_cur):
            db_cur.execute(
                """
                INSERT INTO engine_state_paper (
                user_id, run_id, initial_cash, cash, total_invested, nifty_units, gold_units,
                last_sip_ts, last_run, sip_frequency_value, sip_frequency_unit
                )
                VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
                ON CONFLICT (user_id, run_id) DO UPDATE
                SET initial_cash = EXCLUDED.initial_cash,
                cash = EXCLUDED.cash,
                total_invested = EXCLUDED.total_invested,
                nifty_units = EXCLUDED.nifty_units,
                gold_units = EXCLUDED.gold_units,
                last_sip_ts = EXCLUDED.last_sip_ts,
                last_run = EXCLUDED.last_run,
                sip_frequency_value = EXCLUDED.sip_frequency_value,
                sip_frequency_unit = EXCLUDED.sip_frequency_unit
                """,
                (
                    scope_user,
                    scope_run,
                    float(state.get("initial_cash", 0.0)),
                    float(state.get("cash", 0.0)),
                    float(state.get("total_invested", 0.0)),
                    float(state.get("nifty_units", 0.0)),
                    float(state.get("gold_units", 0.0)),
                    last_sip_ts,
                    last_run,
                    sip_value,
                    sip_unit,
                ),
            )
            if emit_event:
                _emit(
                    db_cur,
                    {"mode": "PAPER", "cash": state.get("cash"), **_holdings_snapshot()},
                )

    else:

        def _save(db_cur):
            db_cur.execute(
                """
                INSERT INTO engine_state (
                user_id, run_id, total_invested, nifty_units, gold_units, last_sip_ts, last_run
                )
                VALUES (%s, %s, %s, %s, %s, %s, %s)
                ON CONFLICT (user_id, run_id) DO UPDATE
                SET total_invested = EXCLUDED.total_invested,
                nifty_units = EXCLUDED.nifty_units,
                gold_units = EXCLUDED.gold_units,
                last_sip_ts = EXCLUDED.last_sip_ts,
                last_run = EXCLUDED.last_run
                """,
                (
                    scope_user,
                    scope_run,
                    float(state.get("total_invested", 0.0)),
                    float(state.get("nifty_units", 0.0)),
                    float(state.get("gold_units", 0.0)),
                    last_sip_ts,
                    last_run,
                ),
            )
            if emit_event:
                _emit(db_cur, {"mode": "LIVE", **_holdings_snapshot()})

    # Caller-owned cursor: write on it directly and leave transaction control
    # to the caller.
    if cur is not None:
        _save(cur)
        return None

    def _op(cur, _conn):
        _save(cur)

    return run_with_retry(_op)