"""Daily live-equity snapshots for users with a connected Zerodha account.

Each snapshot stores cash, holdings value and their total for one user on one
IST calendar day in the ``live_equity_snapshot`` table. Rows are upserted on
``(user_id, snapshot_date)``, so a later capture on the same day overwrites the
earlier one. A background daemon thread captures a snapshot for every
connected Zerodha user on weekdays once the configured IST cutoff has passed.

NOTE(review): the daemon is per-process; if several worker processes call
``start_live_equity_snapshot_daemon`` each will run its own loop (harmless
because of the upsert, but redundant) — confirm against the deployment setup.
"""

import logging
import math
import os
import threading
import time
from datetime import date, datetime, timedelta, timezone
from decimal import Decimal
from zoneinfo import ZoneInfo

from app.services.db import db_connection
from app.services.zerodha_service import (
    KiteApiError,
    fetch_funds,
    fetch_holdings,
    holding_effective_quantity,
    holding_last_price,
)
from app.services.zerodha_storage import get_session

logger = logging.getLogger(__name__)

# Indian Standard Time. "Asia/Kolkata" is the canonical IANA zone name;
# "Asia/Calcutta" is only a backward-compatibility alias for it.
IST = ZoneInfo("Asia/Kolkata")

# IST wall-clock time (default 15:35) after which the daemon may take the
# day's automatic snapshot.
AUTO_SNAPSHOT_AFTER_HOUR = int(os.getenv("LIVE_EQUITY_SNAPSHOT_HOUR", "15"))
AUTO_SNAPSHOT_AFTER_MINUTE = int(os.getenv("LIVE_EQUITY_SNAPSHOT_MINUTE", "35"))
# Seconds between daemon wake-ups; the loop never sleeps less than 60s.
AUTO_SNAPSHOT_INTERVAL_SEC = int(os.getenv("LIVE_EQUITY_SNAPSHOT_INTERVAL_SEC", "1800"))

_SNAPSHOT_THREAD: threading.Thread | None = None
_SNAPSHOT_LOCK = threading.Lock()
# IST day on which every connected user was captured without error; once set
# for today, further cycles are skipped until the next day.
_LAST_AUTO_SNAPSHOT_DATE: date | None = None
# user_id -> IST day of that user's last successful automatic snapshot, so a
# user whose capture failed is retried on the next cycle without re-capturing
# the users that already succeeded. Only the daemon thread mutates this.
_AUTO_SNAPSHOT_DONE: dict[str, date] = {}


def _now_utc() -> datetime:
    """Return the current time as a timezone-aware UTC datetime."""
    return datetime.now(timezone.utc)


def _now_ist() -> datetime:
    """Return the current time as a timezone-aware IST datetime."""
    return _now_utc().astimezone(IST)


def _snapshot_day(ts: datetime) -> date:
    """Return the IST calendar day that ``ts`` falls on.

    ``ts`` should be timezone-aware; a naive datetime is interpreted by
    ``astimezone`` as system local time.
    """
    return ts.astimezone(IST).date()


def _first_numeric(*values, default: float = 0.0) -> float:
    """Return the first of ``values`` that converts to a finite float.

    ``None``, empty strings, values ``float()`` rejects, and non-finite
    results (NaN / +-inf) are skipped. Returns ``float(default)`` when no
    candidate qualifies.
    """
    for value in values:
        try:
            if value is None or value == "":
                continue
            number = float(value)
        except (TypeError, ValueError):
            continue
        # Reject "nan"/"inf" so they never reach the NUMERIC columns.
        if math.isfinite(number):
            return number
    return float(default)


def _extract_cash_value(funds_data: dict | None) -> float:
    """Pick the account's cash value out of a funds payload.

    Looks under ``funds_data["equity"]`` first, then under
    ``funds_data["equity"]["available"]``, trying the keys below in priority
    order and returning the first usable number (0.0 if none). Non-dict
    containers at any level are treated as empty.
    """
    equity = funds_data.get("equity") if isinstance(funds_data, dict) else None
    if not isinstance(equity, dict):
        equity = {}
    available = equity.get("available")
    if not isinstance(available, dict):
        available = {}
    return _first_numeric(
        equity.get("balance"),
        equity.get("net"),
        equity.get("withdrawable"),
        equity.get("cash"),
        available.get("live_balance"),
        available.get("opening_balance"),
        available.get("cash"),
        default=0.0,
    )


def _extract_holdings_value(holdings: list[dict] | None) -> float:
    """Return the sum of effective quantity x last price over all holdings.

    ``None`` or an empty list yields 0.0.
    """
    return float(
        sum(
            (holding_effective_quantity(item) * holding_last_price(item) for item in holdings or []),
            0.0,
        )
    )


def _to_money(value: float) -> Decimal:
    """Round ``value`` to 2 decimal places and return it as a Decimal."""
    return Decimal(str(round(value, 2)))


def _upsert_snapshot(
    *,
    user_id: str,
    snapshot_date: date,
    captured_at: datetime,
    cash_value: float,
    holdings_value: float,
) -> dict:
    """Insert or overwrite the ``(user_id, snapshot_date)`` snapshot row.

    Amounts are rounded to 2 decimals. The total is the sum of the rounded
    components, so the stored ``total_value`` always equals stored
    ``cash_value + holdings_value``. The write is committed through the
    ``with conn:`` block.

    Returns:
        A JSON-friendly dict describing the stored snapshot.
    """
    cash = _to_money(cash_value)
    holdings = _to_money(holdings_value)
    total = cash + holdings
    with db_connection() as conn:
        with conn:
            with conn.cursor() as cur:
                cur.execute(
                    """
                    INSERT INTO live_equity_snapshot (
                        user_id, snapshot_date, captured_at, cash_value, holdings_value, total_value
                    )
                    VALUES (%s, %s, %s, %s, %s, %s)
                    ON CONFLICT (user_id, snapshot_date) DO UPDATE
                    SET captured_at = EXCLUDED.captured_at,
                        cash_value = EXCLUDED.cash_value,
                        holdings_value = EXCLUDED.holdings_value,
                        total_value = EXCLUDED.total_value
                    """,
                    (
                        user_id,
                        snapshot_date,
                        captured_at,
                        cash,
                        holdings,
                        total,
                    ),
                )
    return {
        "snapshotDate": snapshot_date.isoformat(),
        "capturedAt": captured_at.isoformat(),
        "cashValue": float(cash),
        "holdingsValue": float(holdings),
        "totalValue": float(total),
    }


def capture_live_equity_snapshot(
    user_id: str,
    *,
    holdings: list[dict] | None = None,
    funds_data: dict | None = None,
    captured_at: datetime | None = None,
) -> dict | None:
    """Capture and store today's equity snapshot for ``user_id``.

    Args:
        user_id: User whose stored Zerodha session is used.
        holdings: Pre-fetched holdings; fetched from the API when ``None``.
        funds_data: Pre-fetched funds payload; fetched when ``None``.
        captured_at: Capture time (should be timezone-aware); defaults to now
            in UTC. Its IST calendar day becomes the snapshot date.

    Returns:
        The stored snapshot dict, or ``None`` if the user has no stored session.

    Errors raised by ``fetch_holdings`` / ``fetch_funds`` (presumably
    ``KiteApiError``) and by the database layer propagate to the caller.
    """
    session = get_session(user_id)
    if not session:
        return None
    captured_at = captured_at or _now_utc()
    if holdings is None:
        holdings = fetch_holdings(session["api_key"], session["access_token"])
    if funds_data is None:
        funds_data = fetch_funds(session["api_key"], session["access_token"])
    return _upsert_snapshot(
        user_id=user_id,
        snapshot_date=_snapshot_day(captured_at),
        captured_at=captured_at,
        cash_value=_extract_cash_value(funds_data),
        holdings_value=_extract_holdings_value(holdings),
    )


def get_live_equity_curve(user_id: str, *, start_date: date | None = None) -> dict:
    """Return the stored daily total-equity series for ``user_id``.

    Args:
        user_id: User whose snapshots are read.
        start_date: First day to include; defaults to 90 days before today (IST).

    Returns:
        A dict with ``startDate`` (ISO date), ``endDate`` (ISO UTC timestamp of
        the call, not a date), ``exactFrom`` (ISO date of the user's earliest
        snapshot ever, or ``None``) and ``points`` (``{"date", "value"}`` dicts
        in ascending date order).
    """
    if start_date is None:
        start_date = _snapshot_day(_now_utc()) - timedelta(days=90)
    with db_connection() as conn:
        with conn.cursor() as cur:
            cur.execute(
                """
                SELECT snapshot_date, total_value
                FROM live_equity_snapshot
                WHERE user_id = %s AND snapshot_date >= %s
                ORDER BY snapshot_date ASC
                """,
                (user_id, start_date),
            )
            rows = cur.fetchall()
            # Earliest snapshot regardless of start_date, so callers can tell
            # from which day the series is backed by real captures.
            cur.execute(
                """
                SELECT MIN(snapshot_date)
                FROM live_equity_snapshot
                WHERE user_id = %s
                """,
                (user_id,),
            )
            first_row = cur.fetchone()
    points = [
        {"date": snapshot_date.isoformat(), "value": round(float(total_value or 0), 2)}
        for snapshot_date, total_value in rows
    ]
    first_snapshot = first_row[0].isoformat() if first_row and first_row[0] else None
    return {
        "startDate": start_date.isoformat(),
        "endDate": _now_utc().isoformat(),
        "exactFrom": first_snapshot,
        "points": points,
    }


def _list_connected_zerodha_users() -> list[str]:
    """Return the ids of users whose broker row is a connected Zerodha account."""
    with db_connection() as conn:
        with conn.cursor() as cur:
            cur.execute(
                """
                SELECT user_id
                FROM user_broker
                WHERE connected = TRUE AND UPPER(COALESCE(broker, '')) = 'ZERODHA'
                """
            )
            return [row[0] for row in cur.fetchall()]


def _should_auto_snapshot(now_local: datetime) -> bool:
    """Return True on a weekday at or after the configured IST cutoff time.

    Only Saturday and Sunday are excluded; exchange holidays are not checked.
    """
    if now_local.weekday() >= 5:
        return False
    snapshot_cutoff = now_local.replace(
        hour=AUTO_SNAPSHOT_AFTER_HOUR,
        minute=AUTO_SNAPSHOT_AFTER_MINUTE,
        second=0,
        microsecond=0,
    )
    return now_local >= snapshot_cutoff


def _run_auto_snapshot_cycle():
    """Capture today's snapshot for each connected user not yet done today.

    Per-user failures are logged and do not stop the other users. A failed
    user is retried on the next cycle. The day is marked complete (and later
    cycles skipped) only when a cycle finishes with no failures.
    """
    global _LAST_AUTO_SNAPSHOT_DATE
    now_local = _now_ist()
    today = now_local.date()
    if _LAST_AUTO_SNAPSHOT_DATE == today:
        return
    if not _should_auto_snapshot(now_local):
        return

    failures = 0
    for user_id in _list_connected_zerodha_users():
        if _AUTO_SNAPSHOT_DONE.get(user_id) == today:
            continue
        try:
            capture_live_equity_snapshot(user_id)
        except KiteApiError as exc:
            failures += 1
            logger.warning("Live equity snapshot failed for user %s: %s", user_id, exc)
            continue
        except Exception:
            failures += 1
            logger.exception("Unexpected error capturing live equity snapshot for user %s", user_id)
            continue
        _AUTO_SNAPSHOT_DONE[user_id] = today

    if failures == 0:
        _LAST_AUTO_SNAPSHOT_DATE = today


def _snapshot_loop():
    """Run snapshot cycles forever, sleeping between them.

    Any error escaping a cycle is logged and the loop keeps running.
    """
    interval = max(AUTO_SNAPSHOT_INTERVAL_SEC, 60)
    while True:
        try:
            _run_auto_snapshot_cycle()
        except Exception:
            logger.exception("Live equity auto-snapshot cycle failed")
        time.sleep(interval)


def start_live_equity_snapshot_daemon():
    """Start the background snapshot thread unless one is already alive.

    The lock makes concurrent calls within one process start at most one thread.
    """
    global _SNAPSHOT_THREAD
    with _SNAPSHOT_LOCK:
        if _SNAPSHOT_THREAD and _SNAPSHOT_THREAD.is_alive():
            return
        thread = threading.Thread(
            target=_snapshot_loop,
            name="live-equity-snapshot-daemon",
            daemon=True,
        )
        thread.start()
        _SNAPSHOT_THREAD = thread