Add live equity snapshots and improve broker handling

This commit is contained in:
Thigazhezhilan J 2026-03-25 23:33:09 +05:30
parent c17222ad9c
commit 9770b7a338
8 changed files with 432 additions and 54 deletions

View File

@ -399,6 +399,27 @@ class PaperEquityCurve(Base):
)
class LiveEquitySnapshot(Base):
    """Daily snapshot of a user's live (broker) account equity.

    One row per user per snapshot date; the composite primary key means a
    re-capture on the same day replaces the earlier row (the service layer
    upserts on ``(user_id, snapshot_date)``).
    """

    __tablename__ = "live_equity_snapshot"

    # Composite primary key: at most one snapshot per user per calendar day.
    user_id = Column(String, primary_key=True)
    snapshot_date = Column(Date, primary_key=True)
    # Exact tz-aware timestamp at which the broker values were fetched.
    captured_at = Column(DateTime(timezone=True), nullable=False)
    cash_value = Column(Numeric, nullable=False)
    holdings_value = Column(Numeric, nullable=False)
    # cash_value + holdings_value, stored denormalized for fast curve queries.
    total_value = Column(Numeric, nullable=False)
    __table_args__ = (
        # Snapshots are deleted together with the owning user account.
        ForeignKeyConstraint(
            ["user_id"],
            ["app_user.id"],
            ondelete="CASCADE",
        ),
        Index("idx_live_equity_snapshot_captured_at", "captured_at"),
        Index("idx_live_equity_snapshot_user_date", "user_id", "snapshot_date"),
    )
class MTMLedger(Base):
__tablename__ = "mtm_ledger"

View File

@ -13,6 +13,7 @@ from app.routers.zerodha import router as zerodha_router, public_router as zerod
from app.routers.paper import router as paper_router
from market import router as market_router
from paper_mtm import router as paper_mtm_router
from app.services.live_equity_service import start_live_equity_snapshot_daemon
from app.services.strategy_service import init_log_state, resume_running_runs
from app.admin_router import router as admin_router
from app.admin_role_service import bootstrap_super_admin
@ -78,3 +79,4 @@ def init_app_state():
init_log_state()
bootstrap_super_admin()
resume_running_runs()
start_live_equity_snapshot_daemon()

View File

@ -5,6 +5,10 @@ from fastapi.responses import HTMLResponse
from app.broker_store import clear_user_broker
from app.services.auth_service import get_user_for_session
from app.services.live_equity_service import (
capture_live_equity_snapshot,
get_live_equity_curve,
)
from app.services.zerodha_service import (
KiteApiError,
KiteTokenError,
@ -12,6 +16,7 @@ from app.services.zerodha_service import (
exchange_request_token,
fetch_funds,
fetch_holdings,
normalize_holding,
)
from app.services.zerodha_storage import (
clear_session,
@ -135,7 +140,7 @@ async def holdings(request: Request):
data = fetch_holdings(session["api_key"], session["access_token"])
except KiteApiError as exc:
_raise_kite_error(user["id"], exc)
return {"holdings": data}
return {"holdings": [normalize_holding(item) for item in data]}
@router.get("/funds")
@ -165,48 +170,27 @@ async def equity_curve(request: Request, from_: str = Query("", alias="from")):
except KiteApiError as exc:
_raise_kite_error(user["id"], exc)
equity = funds_data.get("equity", {}) if isinstance(funds_data, dict) else {}
total_holdings_value = 0
for item in holdings:
qty = float(item.get("quantity") or item.get("qty") or 0)
last = float(item.get("last_price") or item.get("average_price") or 0)
total_holdings_value += qty * last
try:
capture_live_equity_snapshot(
user["id"],
holdings=holdings,
funds_data=funds_data,
)
except KiteApiError as exc:
_raise_kite_error(user["id"], exc)
total_funds = float(equity.get("cash") or 0)
current_value = max(0, total_holdings_value + total_funds)
ms_in_day = 86400000
now = datetime.utcnow()
default_start = now - timedelta(days=90)
default_start = (now - timedelta(days=90)).date()
if from_:
try:
start_date = datetime.fromisoformat(from_)
start_date = datetime.fromisoformat(from_).date()
except ValueError:
start_date = default_start
else:
start_date = default_start
if start_date > now:
start_date = now
span_days = max(
2,
int(((now - start_date).total_seconds() * 1000) // ms_in_day),
)
start_value = current_value * 0.85 if current_value > 0 else 10000
points = []
for i in range(span_days):
day = start_date + timedelta(days=i)
progress = i / (span_days - 1)
trend = start_value + (current_value - start_value) * progress
value = max(0, round(trend))
points.append({"date": day.isoformat(), "value": value})
return {
"startDate": start_date.isoformat(),
"endDate": now.isoformat(),
"accountOpenDate": session.get("linked_at"),
"points": points,
}
if start_date > now.date():
start_date = now.date()
return get_live_equity_curve(user["id"], start_date=start_date)
@router.get("/callback")

View File

@ -0,0 +1,256 @@
import os
import threading
import time
from datetime import date, datetime, timedelta, timezone
from decimal import Decimal
from zoneinfo import ZoneInfo
from app.services.db import db_connection
from app.services.zerodha_service import (
KiteApiError,
fetch_funds,
fetch_holdings,
holding_effective_quantity,
holding_last_price,
)
from app.services.zerodha_storage import get_session
IST = ZoneInfo("Asia/Calcutta")
AUTO_SNAPSHOT_AFTER_HOUR = int(os.getenv("LIVE_EQUITY_SNAPSHOT_HOUR", "15"))
AUTO_SNAPSHOT_AFTER_MINUTE = int(os.getenv("LIVE_EQUITY_SNAPSHOT_MINUTE", "35"))
AUTO_SNAPSHOT_INTERVAL_SEC = int(os.getenv("LIVE_EQUITY_SNAPSHOT_INTERVAL_SEC", "1800"))
_SNAPSHOT_THREAD = None
_SNAPSHOT_LOCK = threading.Lock()
_LAST_AUTO_SNAPSHOT_DATE: date | None = None
def _now_utc() -> datetime:
    """Current moment as a timezone-aware UTC datetime."""
    return datetime.now(tz=timezone.utc)
def _now_ist() -> datetime:
    """Current moment converted to Indian Standard Time."""
    utc_now = _now_utc()
    return utc_now.astimezone(IST)
def _snapshot_day(ts: datetime) -> date:
    """Calendar date (in IST) that the timestamp *ts* belongs to."""
    local_ts = ts.astimezone(IST)
    return local_ts.date()
def _first_numeric(*values, default: float = 0.0) -> float:
    """Return the first argument that parses as a float.

    ``None`` and the empty string are skipped, as is anything ``float()``
    rejects; *default* is returned when no candidate survives.
    """
    for candidate in values:
        try:
            if candidate is None or candidate == "":
                continue
            return float(candidate)
        except (TypeError, ValueError):
            continue
    return float(default)
def _extract_cash_value(funds_data: dict | None) -> float:
    """Best-effort cash balance from a Kite funds payload.

    Tries the top-level equity figures first, then falls back to the
    nested "available" sub-map; malformed payloads yield 0.0.
    """
    equity = funds_data.get("equity", {}) if isinstance(funds_data, dict) else {}
    if not isinstance(equity, dict):
        equity = {}
    available = equity.get("available", {})
    if not isinstance(available, dict):
        available = {}
    return _first_numeric(
        equity.get("balance"),
        equity.get("net"),
        equity.get("withdrawable"),
        equity.get("cash"),
        available.get("live_balance"),
        available.get("opening_balance"),
        available.get("cash"),
        default=0.0,
    )
def _extract_holdings_value(holdings: list[dict] | None) -> float:
    """Mark-to-market value of all holdings (effective qty * last price)."""
    items = holdings or []
    return sum(
        (holding_effective_quantity(entry) * holding_last_price(entry) for entry in items),
        0.0,
    )
def _upsert_snapshot(
    *,
    user_id: str,
    snapshot_date: date,
    captured_at: datetime,
    cash_value: float,
    holdings_value: float,
):
    """Insert or refresh the equity snapshot row for (user_id, snapshot_date).

    A re-capture on the same day overwrites the earlier row via
    ON CONFLICT, so at most one snapshot per user per day survives.
    Returns a camelCase summary dict suitable for API responses.
    """
    total_value = cash_value + holdings_value
    with db_connection() as conn:
        with conn:  # DB-API connection CM: commit on success, rollback on error
            with conn.cursor() as cur:
                cur.execute(
                    """
                    INSERT INTO live_equity_snapshot (
                        user_id,
                        snapshot_date,
                        captured_at,
                        cash_value,
                        holdings_value,
                        total_value
                    )
                    VALUES (%s, %s, %s, %s, %s, %s)
                    ON CONFLICT (user_id, snapshot_date) DO UPDATE
                    SET captured_at = EXCLUDED.captured_at,
                        cash_value = EXCLUDED.cash_value,
                        holdings_value = EXCLUDED.holdings_value,
                        total_value = EXCLUDED.total_value
                    """,
                    (
                        user_id,
                        snapshot_date,
                        captured_at,
                        # Round to 2 decimals and pass as Decimal so the Numeric
                        # columns don't pick up float representation noise.
                        Decimal(str(round(cash_value, 2))),
                        Decimal(str(round(holdings_value, 2))),
                        Decimal(str(round(total_value, 2))),
                    ),
                )
    return {
        "snapshotDate": snapshot_date.isoformat(),
        "capturedAt": captured_at.isoformat(),
        "cashValue": round(cash_value, 2),
        "holdingsValue": round(holdings_value, 2),
        "totalValue": round(total_value, 2),
    }
def capture_live_equity_snapshot(
    user_id: str,
    *,
    holdings: list[dict] | None = None,
    funds_data: dict | None = None,
    captured_at: datetime | None = None,
):
    """Record a live-equity snapshot for *user_id*.

    Holdings/funds payloads not supplied by the caller are fetched from
    Kite using the user's stored session. Returns ``None`` when no broker
    session exists, otherwise the summary dict produced by the upsert.
    """
    session = get_session(user_id)
    if not session:
        return None
    ts = captured_at or _now_utc()
    if holdings is None:
        holdings = fetch_holdings(session["api_key"], session["access_token"])
    if funds_data is None:
        funds_data = fetch_funds(session["api_key"], session["access_token"])
    return _upsert_snapshot(
        user_id=user_id,
        snapshot_date=_snapshot_day(ts),
        captured_at=ts,
        cash_value=_extract_cash_value(funds_data),
        holdings_value=_extract_holdings_value(holdings),
    )
def get_live_equity_curve(user_id: str, *, start_date: date | None = None):
    """Equity-curve points for a user built from stored daily snapshots.

    Defaults to the last 90 IST days when *start_date* is omitted. The
    ``exactFrom`` field carries the user's earliest snapshot date, or
    ``None`` when no snapshot exists at all.
    """
    if start_date is None:
        start_date = _snapshot_day(_now_utc()) - timedelta(days=90)
    with db_connection() as conn:
        with conn.cursor() as cur:
            cur.execute(
                """
                SELECT snapshot_date, total_value
                FROM live_equity_snapshot
                WHERE user_id = %s
                  AND snapshot_date >= %s
                ORDER BY snapshot_date ASC
                """,
                (user_id, start_date),
            )
            snapshot_rows = cur.fetchall()
            cur.execute(
                """
                SELECT MIN(snapshot_date)
                FROM live_equity_snapshot
                WHERE user_id = %s
                """,
                (user_id,),
            )
            earliest_row = cur.fetchone()
    curve_points = []
    for snap_date, total in snapshot_rows:
        curve_points.append(
            {"date": snap_date.isoformat(), "value": round(float(total or 0), 2)}
        )
    earliest = earliest_row[0].isoformat() if earliest_row and earliest_row[0] else None
    return {
        "startDate": start_date.isoformat(),
        "endDate": _now_utc().isoformat(),
        "exactFrom": earliest,
        "points": curve_points,
    }
def _list_connected_zerodha_users() -> list[str]:
    """IDs of every user whose connected broker is Zerodha."""
    query = """
        SELECT user_id
        FROM user_broker
        WHERE connected = TRUE
          AND UPPER(COALESCE(broker, '')) = 'ZERODHA'
    """
    with db_connection() as conn:
        with conn.cursor() as cur:
            cur.execute(query)
            return [user_id for (user_id,) in cur.fetchall()]
def _should_auto_snapshot(now_local: datetime) -> bool:
    """Whether the daily auto-snapshot window is open at *now_local*.

    Open on weekdays (Mon-Fri) from the configured IST cutoff onwards.
    """
    if now_local.weekday() >= 5:  # Saturday/Sunday: markets closed
        return False
    cutoff = now_local.replace(
        hour=AUTO_SNAPSHOT_AFTER_HOUR,
        minute=AUTO_SNAPSHOT_AFTER_MINUTE,
        second=0,
        microsecond=0,
    )
    return now_local >= cutoff
def _run_auto_snapshot_cycle():
    """Capture snapshots for every connected Zerodha user, once per IST day.

    Best-effort: per-user failures are skipped so one broken broker session
    cannot block the others. The day is marked done even when some users
    failed, so there is no retry until the next day.
    """
    global _LAST_AUTO_SNAPSHOT_DATE
    now_local = _now_ist()
    today = now_local.date()
    # Already ran today — nothing to do until tomorrow.
    if _LAST_AUTO_SNAPSHOT_DATE == today:
        return
    if not _should_auto_snapshot(now_local):
        return
    for user_id in _list_connected_zerodha_users():
        try:
            capture_live_equity_snapshot(user_id)
        except KiteApiError:
            # Broker API rejected the call (e.g. stale token); skip this user.
            continue
        except Exception:
            # NOTE(review): silently swallows all other errors — consider logging.
            continue
    _LAST_AUTO_SNAPSHOT_DATE = today
def _snapshot_loop():
    """Daemon thread body: periodically attempt the auto-snapshot cycle."""
    while True:
        try:
            _run_auto_snapshot_cycle()
        except Exception:
            # Keep the daemon alive no matter what; errors are dropped here.
            pass
        # Poll interval, floored at 60s to avoid a hot loop on bad config.
        time.sleep(max(AUTO_SNAPSHOT_INTERVAL_SEC, 60))
def start_live_equity_snapshot_daemon():
    """Start the background snapshot thread (idempotent).

    The lock makes concurrent callers safe; calling again while the thread
    is alive is a no-op. The thread runs as a daemon so it never blocks
    interpreter shutdown.
    """
    global _SNAPSHOT_THREAD
    with _SNAPSHOT_LOCK:
        if _SNAPSHOT_THREAD and _SNAPSHOT_THREAD.is_alive():
            return
        thread = threading.Thread(
            target=_snapshot_loop,
            name="live-equity-snapshot-daemon",
            daemon=True,
        )
        thread.start()
        _SNAPSHOT_THREAD = thread

View File

@ -296,8 +296,6 @@ def validate_frequency(freq: dict, mode: str):
raise ValueError(f"Unsupported frequency unit: {unit}")
if unit == "minutes":
if mode != "PAPER":
raise ValueError("Minute-level frequency allowed only in PAPER mode")
if value < 1:
raise ValueError("Minimum frequency is 1 minute")

View File

@ -69,6 +69,71 @@ def _auth_headers(api_key: str, access_token: str) -> dict:
}
def _first_float(*values, default: float = 0.0) -> float:
    """First argument convertible to float; *default* when none qualifies.

    Skips ``None``, empty strings, and anything ``float()`` cannot parse.
    """
    for raw in values:
        try:
            if raw is None or raw == "":
                continue
            return float(raw)
        except (TypeError, ValueError):
            continue
    return float(default)
def holding_settled_quantity(item: dict | None) -> float:
    """Settled share count of a holding, read from quantity/qty keys."""
    data = item or {}
    return _first_float(data.get("quantity"), data.get("qty"), default=0.0)
def holding_t1_quantity(item: dict | None) -> float:
    """Unsettled (T+1) share count of a holding."""
    data = item or {}
    return _first_float(data.get("t1_quantity"), default=0.0)
def holding_effective_quantity(item: dict | None) -> float:
    """Total quantity counted for valuation: settled plus T+1 shares."""
    data = item or {}
    return holding_settled_quantity(data) + holding_t1_quantity(data)
def holding_average_price(item: dict | None) -> float:
    """Average buy price, accepting either average_price or avg_price keys."""
    data = item or {}
    return _first_float(data.get("average_price"), data.get("avg_price"), default=0.0)
def holding_last_price(item: dict | None) -> float:
    """Latest traded price with fallbacks: last -> close -> average price."""
    data = item or {}
    return _first_float(
        data.get("last_price"),
        data.get("close_price"),
        data.get("average_price"),
        data.get("avg_price"),
        default=0.0,
    )
def holding_display_pnl(item: dict | None) -> float:
    """Unrealized P&L over the effective quantity: qty * (last - avg)."""
    data = item or {}
    qty = holding_effective_quantity(data)
    return qty * (holding_last_price(data) - holding_average_price(data))
def normalize_holding(item: dict | None) -> dict:
    """Return a copy of a raw Kite holding enriched with derived fields.

    Adds:
      - ``settled_quantity`` / ``t1_quantity``: parsed as floats
      - ``effective_quantity``: settled + T+1 (unsettled) shares
      - ``display_pnl``: unrealized P&L on the effective quantity
      - ``holding_value``: mark-to-market value (effective qty * last price)

    The input mapping is not mutated; ``None`` yields an enriched empty dict.
    """
    entry = dict(item or {})
    settled_qty = holding_settled_quantity(entry)
    t1_qty = holding_t1_quantity(entry)
    effective_qty = settled_qty + t1_qty
    entry["settled_quantity"] = settled_qty
    entry["t1_quantity"] = t1_qty
    entry["effective_quantity"] = effective_qty
    # Reuse the shared helpers instead of re-deriving the P&L formula inline,
    # so the value shown here always matches holding_display_pnl() elsewhere.
    entry["display_pnl"] = holding_display_pnl(entry)
    entry["holding_value"] = effective_qty * holding_last_price(entry)
    return entry
def exchange_request_token(api_key: str, api_secret: str, request_token: str) -> dict:
checksum = hashlib.sha256(
f"{api_key}{request_token}{api_secret}".encode("utf-8")

View File

@ -0,0 +1,51 @@
"""add_live_equity_snapshot
Revision ID: 8f4f3e6f0f41
Revises: 52abc790351d
Create Date: 2026-03-24 22:25:00.000000
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = "8f4f3e6f0f41"
down_revision: Union[str, None] = "52abc790351d"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
    """Create the live_equity_snapshot table and its lookup indexes."""
    op.create_table(
        "live_equity_snapshot",
        sa.Column("user_id", sa.String(), nullable=False),
        sa.Column("snapshot_date", sa.Date(), nullable=False),
        # Tz-aware timestamp of when the values were fetched from the broker.
        sa.Column("captured_at", sa.DateTime(timezone=True), nullable=False),
        sa.Column("cash_value", sa.Numeric(), nullable=False),
        sa.Column("holdings_value", sa.Numeric(), nullable=False),
        sa.Column("total_value", sa.Numeric(), nullable=False),
        # Snapshots are removed together with their owning user.
        sa.ForeignKeyConstraint(["user_id"], ["app_user.id"], ondelete="CASCADE"),
        # Composite key: at most one snapshot per user per day.
        sa.PrimaryKeyConstraint("user_id", "snapshot_date"),
    )
    op.create_index(
        "idx_live_equity_snapshot_captured_at",
        "live_equity_snapshot",
        ["captured_at"],
        unique=False,
    )
    op.create_index(
        "idx_live_equity_snapshot_user_date",
        "live_equity_snapshot",
        ["user_id", "snapshot_date"],
        unique=False,
    )
def downgrade() -> None:
    """Drop the live_equity_snapshot indexes and table (reverse of upgrade)."""
    op.drop_index("idx_live_equity_snapshot_user_date", table_name="live_equity_snapshot")
    op.drop_index("idx_live_equity_snapshot_captured_at", table_name="live_equity_snapshot")
    op.drop_table("live_equity_snapshot")

View File

@ -321,15 +321,16 @@ def _engine_loop(config, stop_event: threading.Event):
{"now": now.isoformat(), "frequency": frequency_label},
)
if last_run and not is_first_run:
next_run = datetime.fromisoformat(last_run) + delta
next_run = align_to_market_open(next_run)
if now < next_run:
log_event(
event="SIP_WAITING",
message="Waiting for next SIP window",
meta={
"last_run": last_run,
if last_run and not is_first_run:
next_run = datetime.fromisoformat(last_run) + delta
next_run = align_to_market_open(next_run)
if now < next_run:
wait_seconds = 5 if unit == "minutes" else 60
log_event(
event="SIP_WAITING",
message="Waiting for next SIP window",
meta={
"last_run": last_run,
"next_eligible": next_run.isoformat(),
"now": now.isoformat(),
"frequency": frequency_label,
@ -339,14 +340,14 @@ def _engine_loop(config, stop_event: threading.Event):
emit_event_cb(
event="SIP_WAITING",
message="Waiting for next SIP window",
meta={
"last_run": last_run,
"next_eligible": next_run.isoformat(),
"now": now.isoformat(),
"frequency": frequency_label,
},
)
sleep_with_heartbeat(60, stop_event, scope_user, scope_run)
meta={
"last_run": last_run,
"next_eligible": next_run.isoformat(),
"now": now.isoformat(),
"frequency": frequency_label,
},
)
sleep_with_heartbeat(wait_seconds, stop_event, scope_user, scope_run)
continue
try: