304 lines
10 KiB
Python
304 lines
10 KiB
Python
# engine/state.py
|
|
from datetime import datetime, timezone
|
|
|
|
from indian_paper_trading_strategy.engine.db import db_connection, insert_engine_event, run_with_retry, get_context
|
|
|
|
# Baseline (LIVE) engine state: zeroed balances and holdings, no timestamps,
# and no SIP cadence configured.
DEFAULT_STATE = dict(
    initial_cash=0.0,
    cash=0.0,
    total_invested=0.0,
    nifty_units=0.0,
    gold_units=0.0,
    last_sip_ts=None,
    last_run=None,
    sip_frequency=None,
)
|
|
|
|
# Paper-trading defaults: DEFAULT_STATE with a 1,000,000 starting balance
# (both initial and current cash) and a 30-day SIP cadence.
DEFAULT_PAPER_STATE = dict(
    DEFAULT_STATE,
    initial_cash=1_000_000.0,
    cash=1_000_000.0,
    sip_frequency={"value": 30, "unit": "days"},
)
|
|
|
|
def _state_key(mode: str | None):
|
|
key = (mode or "LIVE").strip().upper()
|
|
return "PAPER" if key == "PAPER" else "LIVE"
|
|
|
|
def _default_state(mode: str | None):
    """Return a fresh, independent copy of the default state for *mode*.

    PAPER mode starts from ``DEFAULT_PAPER_STATE``; every other mode starts
    from ``DEFAULT_STATE``. A plain ``dict.copy()`` is shallow, so the nested
    ``sip_frequency`` dict would otherwise be shared with the module-level
    default, and any caller mutating it would corrupt the defaults for every
    later call. That nested dict is therefore copied as well.
    """
    template = DEFAULT_PAPER_STATE if _state_key(mode) == "PAPER" else DEFAULT_STATE
    state = dict(template)
    frequency = state.get("sip_frequency")
    if isinstance(frequency, dict):
        state["sip_frequency"] = dict(frequency)
    return state
|
|
|
|
def _local_tz():
    """Return the host's local ``tzinfo``, as a fixed offset valid right now."""
    now_local = datetime.now().astimezone()
    return now_local.tzinfo
|
|
|
|
def _format_local_ts(value: datetime | None):
|
|
if value is None:
|
|
return None
|
|
return value.astimezone(_local_tz()).replace(tzinfo=None).isoformat()
|
|
|
|
def _parse_ts(value):
    """Coerce *value* into a timezone-aware ``datetime``, or ``None``.

    Accepted inputs:
      * ``datetime``: returned unchanged if aware; if naive, it is assumed to
        be local time.
      * ``str``: parsed with ``datetime.fromisoformat`` after mapping a
        trailing ``Z`` to ``+00:00``. Blank or unparseable strings give
        ``None``.
      * anything else (including ``None``) gives ``None``.

    Naive values are localised with ``astimezone()``, which uses the system
    UTC offset valid at that particular moment. This keeps DST-observing
    hosts correct, unlike attaching the offset that applies *today*.
    """
    if isinstance(value, str):
        text = value.strip()
        if not text:
            return None
        try:
            value = datetime.fromisoformat(text.replace("Z", "+00:00"))
        except ValueError:
            return None
    elif not isinstance(value, datetime):
        return None

    if value.tzinfo is None:
        # Naive -> local wall-clock time with the offset for that instant.
        return value.astimezone()
    return value
|
|
|
|
def _resolve_scope(user_id: str | None, run_id: str | None):
    """Resolve the effective ``(user_id, run_id)`` scope for a state operation.

    This is a thin delegation to ``get_context``. Callers in this module
    unpack the result into exactly two values.
    """
    resolved = get_context(user_id, run_id)
    return resolved
|
|
|
|
|
|
def _coerce_float(value, fallback):
    """Return ``float(value)``, or *fallback* when the column value is NULL/None."""
    return fallback if value is None else float(value)


def _read_state(cur, mode, key, scope_user, scope_run, for_update):
    """Run the scoped state SELECT for *key* on *cur* and merge it over the defaults."""
    # Locking clause placed after LIMIT, the order shown in PostgreSQL's SELECT
    # synopsis (and the only order MySQL accepts).
    lock_clause = " FOR UPDATE" if for_update else ""
    state = _default_state(mode)

    if key == "PAPER":
        cur.execute(
            f"""
            SELECT initial_cash, cash, total_invested, nifty_units, gold_units,
                   last_sip_ts, last_run, sip_frequency_value, sip_frequency_unit
            FROM engine_state_paper
            WHERE user_id = %s AND run_id = %s
            LIMIT 1{lock_clause}
            """,
            (scope_user, scope_run),
        )
        row = cur.fetchone()
        if not row:
            return state
        (
            initial_cash,
            cash,
            total_invested,
            nifty_units,
            gold_units,
            last_sip_ts,
            last_run,
            sip_value,
            sip_unit,
        ) = row
        state["initial_cash"] = _coerce_float(initial_cash, state["initial_cash"])
        state["cash"] = _coerce_float(cash, state["cash"])
    else:
        cur.execute(
            f"""
            SELECT total_invested, nifty_units, gold_units, last_sip_ts, last_run
            FROM engine_state
            WHERE user_id = %s AND run_id = %s
            LIMIT 1{lock_clause}
            """,
            (scope_user, scope_run),
        )
        row = cur.fetchone()
        if not row:
            return state
        total_invested, nifty_units, gold_units, last_sip_ts, last_run = row
        # The LIVE table stores no cash or SIP-frequency columns; keep defaults.
        sip_value = sip_unit = None

    state["total_invested"] = _coerce_float(total_invested, state["total_invested"])
    state["nifty_units"] = _coerce_float(nifty_units, state["nifty_units"])
    state["gold_units"] = _coerce_float(gold_units, state["gold_units"])
    state["last_sip_ts"] = _format_local_ts(last_sip_ts)
    state["last_run"] = _format_local_ts(last_run)
    # Only override the default cadence when the row actually stores one.
    if sip_value is not None or sip_unit is not None:
        state["sip_frequency"] = {"value": sip_value, "unit": sip_unit}
    return state


def load_state(
    mode: str | None = "LIVE",
    *,
    cur=None,
    for_update: bool = False,
    user_id: str | None = None,
    run_id: str | None = None,
):
    """Load the persisted engine state for one ``(user_id, run_id)`` scope.

    Args:
        mode: ``"PAPER"`` reads ``engine_state_paper``. Anything else reads
            ``engine_state`` (see ``_state_key``).
        cur: Optional DB cursor to run the query on. When ``None``, a
            connection is opened with ``db_connection()`` for this call only.
        for_update: Append ``FOR UPDATE`` to lock the selected row. The lock
            belongs to the transaction that owns the cursor. When *cur* is
            ``None`` it presumably ends as soon as the internal connection is
            released (TODO confirm ``db_connection`` semantics), so pass your
            own cursor to hold it.
        user_id, run_id: Scope identifiers, resolved via ``_resolve_scope``.

    Returns:
        A dict shaped like the mode's default state.
          * Numeric columns are coerced to ``float``; NULL columns keep the
            default value.
          * Timestamps are rendered by ``_format_local_ts``.
          * In LIVE mode only ``total_invested``, the unit counts and the
            timestamps come from the DB; the cash fields stay at their
            defaults.
          * If no row exists, the defaults are returned unchanged.
    """
    scope_user, scope_run = _resolve_scope(user_id, run_id)
    key = _state_key(mode)

    if cur is not None:
        return _read_state(cur, mode, key, scope_user, scope_run, for_update)

    with db_connection() as conn, conn.cursor() as own_cur:
        return _read_state(own_cur, mode, key, scope_user, scope_run, for_update)
|
|
|
|
def init_paper_state(
    initial_cash: float,
    sip_frequency: dict | None = None,
    *,
    cur=None,
    user_id: str | None = None,
    run_id: str | None = None,
):
    """Reset the PAPER state for a scope to a fresh portfolio and persist it.

    The new state has ``initial_cash`` and ``cash`` set to *initial_cash*,
    zero investment and holdings, and no timestamps. It is upserted via
    ``save_state(..., mode="PAPER", emit_event=True)``.

    Args:
        initial_cash: Starting balance; coerced to ``float``.
        sip_frequency: Optional ``{"value": ..., "unit": ...}`` cadence. A
            falsy value (``None`` or ``{}``) falls back to the
            ``DEFAULT_PAPER_STATE`` cadence.
        cur, user_id, run_id: Forwarded unchanged to ``save_state``.

    Returns:
        The state dict that was saved. Its ``sip_frequency`` is a private
        copy, so mutating it cannot corrupt ``DEFAULT_PAPER_STATE`` (or the
        caller's dict).
    """
    starting_cash = float(initial_cash)
    frequency = sip_frequency or DEFAULT_PAPER_STATE.get("sip_frequency")
    if isinstance(frequency, dict):
        frequency = dict(frequency)

    state = DEFAULT_PAPER_STATE.copy()
    state.update(
        {
            "initial_cash": starting_cash,
            "cash": starting_cash,
            "total_invested": 0.0,
            "nifty_units": 0.0,
            "gold_units": 0.0,
            "last_sip_ts": None,
            "last_run": None,
            "sip_frequency": frequency,
        }
    )

    save_state(state, mode="PAPER", cur=cur, emit_event=True, user_id=user_id, run_id=run_id)
    return state
|
|
|
|
def save_state(
    state,
    mode: str | None = "LIVE",
    *,
    cur=None,
    emit_event: bool = False,
    event_meta: dict | None = None,
    user_id: str | None = None,
    run_id: str | None = None,
):
    """Upsert *state* for one ``(user_id, run_id)`` scope.

    PAPER mode writes ``engine_state_paper``, including cash, initial cash
    and the SIP frequency (taken from a ``{"value", "unit"}`` dict if
    present). Any other mode writes ``engine_state``. Both use
    ``ON CONFLICT (user_id, run_id) DO UPDATE``. ``last_sip_ts`` and
    ``last_run`` are converted with ``_parse_ts`` before binding.

    Args:
        state: Mapping read with ``.get``. Missing numeric keys default to
            ``0.0`` and the present ones are passed through ``float()``.
        mode: Storage selector, normalised by ``_state_key``.
        cur: Optional cursor. When given, the write runs on it and ``None``
            is returned; the caller owns the transaction. Otherwise the
            write is wrapped in ``run_with_retry`` and its result returned.
        emit_event: When true, also insert a ``STATE_UPDATED`` engine event
            through the same cursor.
        event_meta: Passed as ``meta`` to ``insert_engine_event``.
        user_id, run_id: Scope identifiers, resolved via ``_resolve_scope``.
    """
    scope_user, scope_run = _resolve_scope(user_id, run_id)
    is_paper = _state_key(mode) == "PAPER"
    last_sip_ts = _parse_ts(state.get("last_sip_ts"))
    last_run = _parse_ts(state.get("last_run"))

    sip_value = sip_unit = None
    if is_paper:
        sip_frequency = state.get("sip_frequency")
        if isinstance(sip_frequency, dict):
            sip_value = sip_frequency.get("value")
            sip_unit = sip_frequency.get("unit")

    def _save(cursor):
        # Parameters are built here (not eagerly) so any conversion error
        # surfaces inside the retry wrapper, exactly as before.
        if is_paper:
            cursor.execute(
                """
                INSERT INTO engine_state_paper (
                    user_id, run_id, initial_cash, cash, total_invested, nifty_units, gold_units,
                    last_sip_ts, last_run, sip_frequency_value, sip_frequency_unit
                )
                VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
                ON CONFLICT (user_id, run_id) DO UPDATE
                SET initial_cash = EXCLUDED.initial_cash,
                    cash = EXCLUDED.cash,
                    total_invested = EXCLUDED.total_invested,
                    nifty_units = EXCLUDED.nifty_units,
                    gold_units = EXCLUDED.gold_units,
                    last_sip_ts = EXCLUDED.last_sip_ts,
                    last_run = EXCLUDED.last_run,
                    sip_frequency_value = EXCLUDED.sip_frequency_value,
                    sip_frequency_unit = EXCLUDED.sip_frequency_unit
                """,
                (
                    scope_user,
                    scope_run,
                    float(state.get("initial_cash", 0.0)),
                    float(state.get("cash", 0.0)),
                    float(state.get("total_invested", 0.0)),
                    float(state.get("nifty_units", 0.0)),
                    float(state.get("gold_units", 0.0)),
                    last_sip_ts,
                    last_run,
                    sip_value,
                    sip_unit,
                ),
            )
            event_data = {
                "mode": "PAPER",
                "cash": state.get("cash"),
                "total_invested": state.get("total_invested"),
                "nifty_units": state.get("nifty_units"),
                "gold_units": state.get("gold_units"),
                "last_sip_ts": state.get("last_sip_ts"),
                "last_run": state.get("last_run"),
            }
        else:
            cursor.execute(
                """
                INSERT INTO engine_state (
                    user_id, run_id, total_invested, nifty_units, gold_units, last_sip_ts, last_run
                )
                VALUES (%s, %s, %s, %s, %s, %s, %s)
                ON CONFLICT (user_id, run_id) DO UPDATE
                SET total_invested = EXCLUDED.total_invested,
                    nifty_units = EXCLUDED.nifty_units,
                    gold_units = EXCLUDED.gold_units,
                    last_sip_ts = EXCLUDED.last_sip_ts,
                    last_run = EXCLUDED.last_run
                """,
                (
                    scope_user,
                    scope_run,
                    float(state.get("total_invested", 0.0)),
                    float(state.get("nifty_units", 0.0)),
                    float(state.get("gold_units", 0.0)),
                    last_sip_ts,
                    last_run,
                ),
            )
            event_data = {
                "mode": "LIVE",
                "total_invested": state.get("total_invested"),
                "nifty_units": state.get("nifty_units"),
                "gold_units": state.get("gold_units"),
                "last_sip_ts": state.get("last_sip_ts"),
                "last_run": state.get("last_run"),
            }

        if emit_event:
            insert_engine_event(
                cursor,
                "STATE_UPDATED",
                data=event_data,
                meta=event_meta,
                # Aware UTC "now"; replaces the deprecated datetime.utcnow().
                ts=datetime.now(timezone.utc),
            )

    if cur is not None:
        _save(cur)
        return None

    def _op(cursor, _conn):
        _save(cursor)

    return run_with_retry(_op)
|
|
|