commit 53be845b6e3affe63f0f77d26f116e77805cee49 Author: thigazhezhilan Date: Sun Feb 1 13:57:30 2026 +0000 Backend full repo clean diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..a3b3c3b --- /dev/null +++ b/.gitignore @@ -0,0 +1,52 @@ +# ===================== +# Node / JS +# ===================== +node_modules/ +npm-debug.log* +yarn.lock +pnpm-lock.yaml + +# ===================== +# Environment / Secrets +# ===================== +.env +.env.* +!.env.example + +# ===================== +# Build output +# ===================== +dist/ +build/ + +# ===================== +# Python / venv +# ===================== +__pycache__/ +*.pyc +.venv/ +venv/ +Lib/ +Scripts/ + +# ===================== +# Orchestration / temp +# ===================== +.orchestration/ +tmp/ +temp/ + +# OS / Editor +# ===================== +.vscode/ +.idea/ +.DS_Store +Thumbs.db + +# ===================== +# Logs +# ===================== +*.log +*.err +*.zip +.local/ diff --git a/ADMIN_DASHBOARD.md b/ADMIN_DASHBOARD.md new file mode 100644 index 0000000..2d853c7 --- /dev/null +++ b/ADMIN_DASHBOARD.md @@ -0,0 +1,40 @@ +# Admin Dashboard + +## Mark a user as admin + +In PostgreSQL: + +```sql +UPDATE app_user SET is_admin = true WHERE username = 'you@example.com'; +``` + +## Run migrations + +```powershell +Get-Content db_migrations\20260118_admin_rbac_views.sql | docker exec -i trading_postgres psql -U trader -d trading_db +``` + +## Run backend + +```powershell +cd backend +.\.venv\Scripts\python -m uvicorn app.main:app --reload --port 8000 +``` + +## Run frontend + +```powershell +cd frontend +npm install +npm run dev +``` + +## Open admin + +Visit: + +``` +http://localhost:3000/admin +``` + +Non-admin users will see Not Found. diff --git a/README_ORCHESTRATION.md b/README_ORCHESTRATION.md new file mode 100644 index 0000000..320f9a7 --- /dev/null +++ b/README_ORCHESTRATION.md @@ -0,0 +1,50 @@ +# Local Orchestration + +One-click scripts to run the full stack locally. + +## Windows (PowerShell) + +From repo root: + +``` +.\start_all.ps1 +``` + +Stop everything: + +``` +.\stop_all.ps1 +``` + +## Linux / macOS + +From repo root: + +``` +chmod +x start_all.sh stop_all.sh +./start_all.sh +``` + +Stop everything: + +``` +./stop_all.sh +``` + +## What the scripts do + +- Start PostgreSQL via Docker if not running +- Apply migrations (Alembic if present, otherwise `db_migrations/*.sql`) +- Start FastAPI backend +- Start engine runner +- Start React frontend +- Wait for `http://localhost:8000/health` +- Open `http://localhost:3000/admin` + +## Requirements + +- Docker Desktop +- Python venv at `.venv` +- Node + npm + +If you are using Alembic, place `alembic.ini` at repo root or `backend/alembic.ini`. diff --git a/README_test.md b/README_test.md new file mode 100644 index 0000000..756710f --- /dev/null +++ b/README_test.md @@ -0,0 +1,31 @@ +# Test Execution + +Prerequisites: +- Backend API running at `BASE_URL` (default `http://127.0.0.1:8000`) +- PostgreSQL reachable via `DB_DSN` + +Environment: +- `BASE_URL` (optional) +- `DB_DSN` (optional) + +Install deps: + +``` +pip install pytest requests psycopg2-binary +``` + +Run all tests: + +``` +pytest -q +``` + +Run a subset: + +``` +pytest -q tests/db_invariants +pytest -q tests/e2e_api +pytest -q tests/e2e_engine +pytest -q tests/concurrency +pytest -q tests/failure_injection +``` diff --git a/SYSTEM_ARM.md b/SYSTEM_ARM.md new file mode 100644 index 0000000..b2f1815 --- /dev/null +++ b/SYSTEM_ARM.md @@ -0,0 +1,27 @@ +# System Arm + +## Daily Login +- Zerodha Kite access tokens expire daily. +- Users must complete a broker login once per trading day. +- Use `/api/broker/login` to start the login flow. + +## Arm Flow +1) User logs in to Zerodha. +2) UI calls `POST /api/system/arm`. +3) Backend validates broker session and arms all active runs. +4) Scheduler resumes from the latest committed state and starts execution. + +## Failure States +- Broker auth expired: `POST /api/system/arm` returns 401 with `redirect_url`. +- Run status `ERROR`: skipped and returned in `failed_runs`. +- Missing broker credentials: `/api/broker/login` returns 400. + +## Recovery +- Reconnect broker via `/api/broker/login`. +- Reset runs in `ERROR` (admin or manual reset), then re-arm. +- Re-run `POST /api/system/arm` to resume. + +## Determinism Guarantees +- Arm is idempotent: already `RUNNING` runs are not re-written. +- Event ledger uses logical time uniqueness to prevent duplicate events. +- Next execution is computed from stored strategy frequency and latest state. diff --git a/TEST_PLAN.md b/TEST_PLAN.md new file mode 100644 index 0000000..91eea49 --- /dev/null +++ b/TEST_PLAN.md @@ -0,0 +1,154 @@ +# Production Test Plan: Multi-User, Multi-Run Trading Engine + +## Scope +- Backend API, engine, broker, DB constraints/triggers +- No UI tests + +## Environment +- `BASE_URL` for API +- `DB_DSN` for PostgreSQL + +## Test Case Matrix + +| ID | Category | Steps | Expected Result | DB Assertions | +|---|---|---|---|---| +| RL-01 | Lifecycle | Create RUNNING run; insert engine_state | Insert succeeds | engine_state row exists | +| RL-02 | Lifecycle | STOPPED run; insert engine_state | Insert rejected | no new rows | +| RL-03 | Lifecycle | ERROR run; insert engine_state | Insert rejected | no new rows | +| RL-04 | Lifecycle | STOPPED -> RUNNING | Update rejected | status unchanged | +| RL-05 | Lifecycle | Insert 2nd RUNNING run same user | Insert rejected | unique violation | +| RL-06 | Lifecycle | Insert row with user_id/run_id mismatch | Insert rejected | FK violation | +| RL-07 | Lifecycle | Delete strategy_run | Cascades delete | no orphans | +| RL-08 | Lifecycle | Start run via API | status RUNNING | engine_status row | +| RL-09 | Lifecycle | Stop run via API | status STOPPED | no new writes allowed | +| OR-01 | Orders | place market BUY | FILLED | order+trade rows | +| OR-02 | Orders | place order qty=0 | REJECTED | no trade | +| OR-03 | Orders | place order qty<0 | REJECTED | no trade | +| OR-04 | Orders | cash < cost | REJECTED | no trade | +| OR-05 | Orders | limit order below price | PENDING | no trade | +| OR-06 | Orders | stop run | pending orders canceled | status updated | +| MT-01 | Market time | market closed SIP | no execution | no order/trade | +| MT-02 | Market time | scheduled SIP | executes once | event_ledger row | +| MT-03 | Market time | repeat same logical_time | no dup | unique constraint | +| MT-04 | Market time | rebalance once | one event | event_ledger | +| IR-01 | Idempotency | tick T twice | no duplicate rows | counts unchanged | +| IR-02 | Idempotency | crash mid-tx | rollback | no partial rows | +| IR-03 | Idempotency | retry same tick | exactly-once | counts unchanged | +| RK-01 | Risk | insufficient cash | reject | no trade | +| RK-02 | Risk | max position | reject | no trade | +| RK-03 | Risk | forbidden symbol | reject | no trade | +| LG-01 | Ledger | order->trade->position | consistent | FK + counts | +| LG-02 | Ledger | equity = cash + mtm | matches | tolerance | +| MU-01 | Multi-user | two users run | isolated | no cross rows | +| MU-02 | Multi-run | same user two runs | isolated | run_id separation | +| DR-01 | Determinism | replay same feed | identical outputs | diff=0 | +| CC-01 | Concurrency | tick vs stop | no partial writes | consistent | +| CC-02 | Concurrency | two ticks same time | dedupe | unique constraint | +| FI-01 | Failure | injected error | rollback | no partial rows | + +## Detailed Cases (Mandatory) + +### Run Lifecycle & State Integrity +- RL-02 STOPPED rejects writes + - Steps: create STOPPED run; attempt inserts into engine_state, engine_status, paper_order, paper_trade, mtm_ledger, event_ledger, paper_equity_curve + - Expected: each insert fails + - DB assertions: no new rows for the run +- RL-03 ERROR rejects writes + - Steps: create ERROR run; attempt same inserts as RL-02 + - Expected: each insert fails + - DB assertions: no new rows for the run +- RL-04 STOPPED cannot revive + - Steps: create STOPPED run; UPDATE status=RUNNING + - Expected: update rejected + - DB assertions: status remains STOPPED +- RL-05 One RUNNING run per user + - Steps: insert RUNNING run; insert another RUNNING run for same user + - Expected: unique violation + - DB assertions: only one RUNNING row +- RL-06 Composite FK enforcement + - Steps: create run for user A and run for user B; attempt insert with user A + run B + - Expected: FK violation + - DB assertions: no inserted row +- RL-07 Cascades + - Steps: insert run and child rows; delete strategy_run + - Expected: child rows removed + - DB assertions: zero rows for run_id in children + +### Order Execution Semantics +- OR-01 Market order fill + - Steps: RUNNING run; place market BUY via broker tick + - Expected: order FILLED, trade created, position updated, equity updated + - DB assertions: rows in paper_order, paper_trade, paper_position, paper_equity_curve +- OR-10/11 Reject invalid qty + - Steps: place order qty=0 and qty<0 + - Expected: REJECTED + - DB assertions: no trade rows for that run +- OR-12 Insufficient cash + - Steps: initial_cash too low; place buy above cash + - Expected: REJECTED + - DB assertions: no trade rows for that run + +### Market Time & Session Logic +- MT-01 Market closed SIP + - Steps: call try_execute_sip with market_open=False + - Expected: no order/trade + - DB assertions: counts unchanged +- MT-03 Idempotent logical_time + - Steps: execute tick twice at same logical_time + - Expected: no duplicate orders/trades/mtm/equity/events + - DB assertions: unique logical_time counts <= 1 + +### Engine Idempotency & Crash Recovery +- IR-02 Crash mid-transaction + - Steps: open transaction, write state, raise exception + - Expected: rollback, no partial rows + - DB assertions: counts unchanged +- IR-03 Retry same tick + - Steps: rerun tick for same logical_time + - Expected: exactly-once side effects + - DB assertions: no duplicates + +### Multi-User / Multi-Run Isolation +- MU-01 Two users, isolated runs + - Steps: create two users + runs; execute tick for each + - Expected: separate data + - DB assertions: rows scoped by (user_id, run_id) +- MU-02 Same user, two runs + - Steps: run A tick; stop; run B tick + - Expected: separate data + - DB assertions: counts separate per run_id + +### Deterministic Replay +- DR-01 Replay determinism + - Steps: run N ticks with fixed feed for run A and run B + - Expected: identical ledgers (excluding surrogate IDs) + - DB assertions: normalized rows equal + +### Concurrency & Race Conditions +- CC-01 Tick vs Stop race + - Steps: lock run row in thread A; update status in thread B + - Expected: consistent final status; no partial writes + - DB assertions: no half-written rows +- CC-02 Two ticks same logical_time + - Steps: run two ticks concurrently with same logical_time + - Expected: dedupe + - DB assertions: SIP_EXECUTED count <= 1 + +### Failure Injection +- FI-01 Crash after trade before MTM + - Steps: execute SIP (orders/trades), then log MTM in separate tx + - Expected: no MTM before explicit insert; MTM inserts once when run + - DB assertions: MTM logical_time count == 1 + +## DB Assertions (Post-Run) + +- Orphan check: + - `SELECT COUNT(*) FROM engine_state es LEFT JOIN strategy_run sr ON sr.user_id=es.user_id AND sr.run_id=es.run_id WHERE sr.run_id IS NULL;` +- Dedupe check: + - `SELECT user_id, run_id, logical_time, COUNT(*) FROM mtm_ledger GROUP BY user_id, run_id, logical_time HAVING COUNT(*) > 1;` +- Lifecycle check: + - `SELECT user_id, COUNT(*) FROM strategy_run WHERE status='RUNNING' GROUP BY user_id HAVING COUNT(*) > 1;` + +## Test Execution + +Use `pytest -q` from repo root. See `README_test.md`. diff --git a/backend/README.md b/backend/README.md new file mode 100644 index 0000000..71b5b09 --- /dev/null +++ b/backend/README.md @@ -0,0 +1 @@ +Control plane API skeleton. diff --git a/backend/alembic.ini b/backend/alembic.ini new file mode 100644 index 0000000..14e63b9 --- /dev/null +++ b/backend/alembic.ini @@ -0,0 +1,117 @@ +# A generic, single database configuration. + +[alembic] +# path to migration scripts +# Use forward slashes (/) also on windows to provide an os agnostic path +script_location = migrations + +# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s +# Uncomment the line below if you want the files to be prepended with date and time +# see https://alembic.sqlalchemy.org/en/latest/tutorial.html#editing-the-ini-file +# for all available tokens +# file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s + +# sys.path path, will be prepended to sys.path if present. +# defaults to the current working directory. +prepend_sys_path = . + +# timezone to use when rendering the date within the migration file +# as well as the filename. +# If specified, requires the python>=3.9 or backports.zoneinfo library. +# Any required deps can installed by adding `alembic[tz]` to the pip requirements +# string value is passed to ZoneInfo() +# leave blank for localtime +# timezone = + +# max length of characters to apply to the "slug" field +# truncate_slug_length = 40 + +# set to 'true' to run the environment during +# the 'revision' command, regardless of autogenerate +# revision_environment = false + +# set to 'true' to allow .pyc and .pyo files without +# a source .py file to be detected as revisions in the +# versions/ directory +# sourceless = false + +# version location specification; This defaults +# to migrations/versions. When using multiple version +# directories, initial revisions must be specified with --version-path. +# The path separator used here should be the separator specified by "version_path_separator" below. +# version_locations = %(here)s/bar:%(here)s/bat:migrations/versions + +# version path separator; As mentioned above, this is the character used to split +# version_locations. The default within new alembic.ini files is "os", which uses os.pathsep. +# If this key is omitted entirely, it falls back to the legacy behavior of splitting on spaces and/or commas. +# Valid values for version_path_separator are: +# +# version_path_separator = : +# version_path_separator = ; +# version_path_separator = space +# version_path_separator = newline +version_path_separator = os # Use os.pathsep. Default configuration used for new projects. + +# set to 'true' to search source files recursively +# in each "version_locations" directory +# new in Alembic version 1.10 +# recursive_version_locations = false + +# the output encoding used when revision files +# are written from script.py.mako +# output_encoding = utf-8 + +sqlalchemy.url = + + +[post_write_hooks] +# post_write_hooks defines scripts or Python functions that are run +# on newly generated revision scripts. See the documentation for further +# detail and examples + +# format using "black" - use the console_scripts runner, against the "black" entrypoint +# hooks = black +# black.type = console_scripts +# black.entrypoint = black +# black.options = -l 79 REVISION_SCRIPT_FILENAME + +# lint with attempts to fix using "ruff" - use the exec runner, execute a binary +# hooks = ruff +# ruff.type = exec +# ruff.executable = %(here)s/.venv/bin/ruff +# ruff.options = --fix REVISION_SCRIPT_FILENAME + +# Logging configuration +[loggers] +keys = root,sqlalchemy,alembic + +[handlers] +keys = console + +[formatters] +keys = generic + +[logger_root] +level = WARN +handlers = console +qualname = + +[logger_sqlalchemy] +level = WARN +handlers = +qualname = sqlalchemy.engine + +[logger_alembic] +level = INFO +handlers = +qualname = alembic + +[handler_console] +class = StreamHandler +args = (sys.stderr,) +level = NOTSET +formatter = generic + +[formatter_generic] +format = %(levelname)-5.5s [%(name)s] %(message)s +datefmt = %H:%M:%S diff --git a/backend/app/__init__.py b/backend/app/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/backend/app/admin_auth.py b/backend/app/admin_auth.py new file mode 100644 index 0000000..f1d4646 --- /dev/null +++ b/backend/app/admin_auth.py @@ -0,0 +1,71 @@ +from fastapi import HTTPException, Request + +from app.services.auth_service import get_user_for_session +from app.services.db import db_connection + +SESSION_COOKIE_NAME = "session_id" + + +def _resolve_role(row) -> str: + role = row[2] + if role: + return role + if row[4]: + return "SUPER_ADMIN" + if row[3]: + return "ADMIN" + return "USER" + + +def require_admin(request: Request): + session_id = request.cookies.get(SESSION_COOKIE_NAME) + if not session_id: + raise HTTPException(status_code=401, detail="Not authenticated") + user = get_user_for_session(session_id) + if not user: + raise HTTPException(status_code=401, detail="Not authenticated") + + with db_connection() as conn: + with conn.cursor() as cur: + cur.execute( + "SELECT id, username, role, is_admin, is_super_admin FROM app_user WHERE id = %s", + (user["id"],), + ) + row = cur.fetchone() + if not row: + raise HTTPException(status_code=403, detail="Admin access required") + role = _resolve_role(row) + if role not in ("ADMIN", "SUPER_ADMIN"): + raise HTTPException(status_code=403, detail="Admin access required") + return { + "id": row[0], + "username": row[1], + "role": role, + } + + +def require_super_admin(request: Request): + session_id = request.cookies.get(SESSION_COOKIE_NAME) + if not session_id: + raise HTTPException(status_code=401, detail="Not authenticated") + user = get_user_for_session(session_id) + if not user: + raise HTTPException(status_code=401, detail="Not authenticated") + + with db_connection() as conn: + with conn.cursor() as cur: + cur.execute( + "SELECT id, username, role, is_admin, is_super_admin FROM app_user WHERE id = %s", + (user["id"],), + ) + row = cur.fetchone() + if not row: + raise HTTPException(status_code=403, detail="Super admin access required") + role = _resolve_role(row) + if role != "SUPER_ADMIN": + raise HTTPException(status_code=403, detail="Super admin access required") + return { + "id": row[0], + "username": row[1], + "role": role, + } diff --git a/backend/app/admin_models.py b/backend/app/admin_models.py new file mode 100644 index 0000000..6c9ca86 --- /dev/null +++ b/backend/app/admin_models.py @@ -0,0 +1,163 @@ +from datetime import datetime +from typing import Any, Optional + +from pydantic import BaseModel + + +class TopError(BaseModel): + ts: Optional[datetime] + event: str + message: Optional[str] + source: str + user_id: Optional[str] + run_id: Optional[str] + + +class OverviewResponse(BaseModel): + total_users: int + users_logged_in_last_24h: int + total_runs: int + running_runs: int + stopped_runs: int + error_runs: int + live_runs_count: int + paper_runs_count: int + orders_last_24h: int + trades_last_24h: int + sip_executed_last_24h: int + top_errors: list[TopError] + + +class UserSummary(BaseModel): + user_id: str + username: str + role: str + is_admin: bool + created_at: Optional[datetime] + last_login_at: Optional[datetime] + active_run_id: Optional[str] + active_run_status: Optional[str] + runs_count: int + broker_connected: bool + + +class UsersResponse(BaseModel): + page: int + page_size: int + total: int + users: list[UserSummary] + + +class RunSummary(BaseModel): + run_id: str + user_id: str + status: str + created_at: Optional[datetime] + started_at: Optional[datetime] + stopped_at: Optional[datetime] + strategy: Optional[str] + mode: Optional[str] + broker: Optional[str] + sip_amount: Optional[float] + sip_frequency_value: Optional[int] + sip_frequency_unit: Optional[str] + last_event_time: Optional[datetime] + last_sip_time: Optional[datetime] + next_sip_time: Optional[datetime] + order_count: int + trade_count: int + equity_latest: Optional[float] + pnl_latest: Optional[float] + + +class RunsResponse(BaseModel): + page: int + page_size: int + total: int + runs: list[RunSummary] + + +class EventItem(BaseModel): + ts: Optional[datetime] + source: str + event: str + message: Optional[str] + level: Optional[str] + run_id: Optional[str] + meta: Optional[dict[str, Any]] + + +class CapitalSummary(BaseModel): + cash: Optional[float] + invested: Optional[float] + mtm: Optional[float] + equity: Optional[float] + pnl: Optional[float] + + +class UserDetailResponse(BaseModel): + user: UserSummary + runs: list[RunSummary] + current_config: Optional[dict[str, Any]] + events: list[EventItem] + capital_summary: CapitalSummary + + +class EngineStatusResponse(BaseModel): + status: Optional[str] + last_updated: Optional[datetime] + + +class RunDetailResponse(BaseModel): + run: RunSummary + config: Optional[dict[str, Any]] + engine_status: Optional[EngineStatusResponse] + state_snapshot: Optional[dict[str, Any]] + ledger_events: list[dict[str, Any]] + orders: list[dict[str, Any]] + trades: list[dict[str, Any]] + invariants: dict[str, Any] + + +class InvariantsResponse(BaseModel): + running_runs_per_user_violations: int + orphan_rows: int + duplicate_logical_time: int + negative_cash: int + invalid_qty: int + stale_running_runs: int + + +class SupportTicketSummary(BaseModel): + ticket_id: str + name: str + email: str + subject: str + message: str + status: str + created_at: Optional[datetime] + updated_at: Optional[datetime] + + +class SupportTicketsResponse(BaseModel): + page: int + page_size: int + total: int + tickets: list[SupportTicketSummary] + + +class DeleteSupportTicketResponse(BaseModel): + ticket_id: str + deleted: bool + + +class DeleteUserResponse(BaseModel): + user_id: str + deleted: dict[str, int] + audit_id: int + + +class HardResetResponse(BaseModel): + user_id: str + deleted: dict[str, int] + audit_id: int diff --git a/backend/app/admin_role_service.py b/backend/app/admin_role_service.py new file mode 100644 index 0000000..c1e1995 --- /dev/null +++ b/backend/app/admin_role_service.py @@ -0,0 +1,109 @@ +import os +from app.services.auth_service import create_user, get_user_by_username +from app.services.db import db_connection + +VALID_ROLES = {"USER", "ADMIN", "SUPER_ADMIN"} + + +def _sync_legacy_flags(cur, user_id: str, role: str): + cur.execute( + """ + UPDATE app_user + SET is_admin = %s, is_super_admin = %s + WHERE id = %s + """, + (role in ("ADMIN", "SUPER_ADMIN"), role == "SUPER_ADMIN", user_id), + ) + + +def set_user_role(actor_id: str, target_id: str, new_role: str): + if new_role not in VALID_ROLES: + return {"error": "invalid_role"} + + with db_connection() as conn: + with conn: + with conn.cursor() as cur: + cur.execute( + "SELECT role FROM app_user WHERE id = %s", + (target_id,), + ) + row = cur.fetchone() + if not row: + return None + old_role = row[0] + + if actor_id == target_id and old_role == "SUPER_ADMIN" and new_role != "SUPER_ADMIN": + return {"error": "cannot_demote_self"} + + if old_role == new_role: + return { + "user_id": target_id, + "old_role": old_role, + "new_role": new_role, + } + + cur.execute( + """ + UPDATE app_user + SET role = %s + WHERE id = %s + """, + (new_role, target_id), + ) + _sync_legacy_flags(cur, target_id, new_role) + + cur.execute( + """ + INSERT INTO admin_role_audit + (actor_user_id, target_user_id, old_role, new_role) + VALUES (%s, %s, %s, %s) + """, + (actor_id, target_id, old_role, new_role), + ) + return { + "user_id": target_id, + "old_role": old_role, + "new_role": new_role, + } + + +def bootstrap_super_admin(): + email = (os.getenv("SUPER_ADMIN_EMAIL") or "").strip() + if not email: + return + + existing = get_user_by_username(email) + if existing: + with db_connection() as conn: + with conn: + with conn.cursor() as cur: + cur.execute( + """ + UPDATE app_user + SET role = 'SUPER_ADMIN' + WHERE id = %s + """, + (existing["id"],), + ) + _sync_legacy_flags(cur, existing["id"], "SUPER_ADMIN") + return + + password = (os.getenv("SUPER_ADMIN_PASSWORD") or "").strip() + if not password: + raise RuntimeError("SUPER_ADMIN_PASSWORD must be set to bootstrap SUPER_ADMIN") + + user = create_user(email, password) + if not user: + return + with db_connection() as conn: + with conn: + with conn.cursor() as cur: + cur.execute( + """ + UPDATE app_user + SET role = 'SUPER_ADMIN' + WHERE id = %s + """, + (user["id"],), + ) + _sync_legacy_flags(cur, user["id"], "SUPER_ADMIN") diff --git a/backend/app/admin_router.py b/backend/app/admin_router.py new file mode 100644 index 0000000..376e87f --- /dev/null +++ b/backend/app/admin_router.py @@ -0,0 +1,151 @@ +from fastapi import APIRouter, Depends, HTTPException, Query + +from app.admin_auth import require_admin, require_super_admin +from app.admin_models import ( + DeleteUserResponse, + HardResetResponse, + InvariantsResponse, + SupportTicketsResponse, + DeleteSupportTicketResponse, + OverviewResponse, + RunsResponse, + RunDetailResponse, + UsersResponse, + UserDetailResponse, +) +from app.admin_service import ( + delete_user_hard, + hard_reset_user_data, + get_invariants, + get_support_tickets, + delete_support_ticket, + get_overview, + get_run_detail, + get_runs, + get_user_detail, + get_users, +) +from app.admin_role_service import set_user_role + +router = APIRouter(prefix="/api/admin", dependencies=[Depends(require_admin)]) + + +@router.get("/overview", response_model=OverviewResponse) +def admin_overview(): + return get_overview() + + +@router.get("/users", response_model=UsersResponse) +def admin_users( + page: int = Query(1, ge=1), + page_size: int = Query(50, ge=1, le=200), + query: str | None = None, +): + return get_users(page, page_size, query) + + +@router.get("/users/{user_id}", response_model=UserDetailResponse) +def admin_user_detail(user_id: str): + detail = get_user_detail(user_id) + if not detail: + raise HTTPException(status_code=404, detail="User not found") + return detail + + +@router.delete("/users/{user_id}", response_model=DeleteUserResponse) +def admin_delete_user( + user_id: str, + hard: bool = Query(False), + admin_user: dict = Depends(require_super_admin), +): + if not hard: + raise HTTPException(status_code=400, detail="Hard delete requires hard=true") + result = delete_user_hard(user_id, admin_user) + if result is None: + raise HTTPException(status_code=404, detail="User not found") + return result + + +@router.post("/users/{user_id}/hard-reset", response_model=HardResetResponse) +def admin_hard_reset_user( + user_id: str, + admin_user: dict = Depends(require_super_admin), +): + result = hard_reset_user_data(user_id, admin_user) + if result is None: + raise HTTPException(status_code=404, detail="User not found") + return result + + +@router.post("/users/{user_id}/make-admin") +def admin_make_admin(user_id: str, admin_user: dict = Depends(require_super_admin)): + result = set_user_role(admin_user["id"], user_id, "ADMIN") + if result is None: + raise HTTPException(status_code=404, detail="User not found") + if result.get("error") == "cannot_demote_self": + raise HTTPException(status_code=400, detail="Cannot demote self") + if result.get("error") == "invalid_role": + raise HTTPException(status_code=400, detail="Invalid role") + return result + + +@router.post("/users/{user_id}/revoke-admin") +def admin_revoke_admin(user_id: str, admin_user: dict = Depends(require_super_admin)): + result = set_user_role(admin_user["id"], user_id, "USER") + if result is None: + raise HTTPException(status_code=404, detail="User not found") + if result.get("error") == "cannot_demote_self": + raise HTTPException(status_code=400, detail="Cannot demote self") + if result.get("error") == "invalid_role": + raise HTTPException(status_code=400, detail="Invalid role") + return result + + +@router.post("/users/{user_id}/make-super-admin") +def admin_make_super_admin(user_id: str, admin_user: dict = Depends(require_super_admin)): + result = set_user_role(admin_user["id"], user_id, "SUPER_ADMIN") + if result is None: + raise HTTPException(status_code=404, detail="User not found") + if result.get("error") == "invalid_role": + raise HTTPException(status_code=400, detail="Invalid role") + return result + + +@router.get("/runs", response_model=RunsResponse) +def admin_runs( + page: int = Query(1, ge=1), + page_size: int = Query(50, ge=1, le=200), + status: str | None = None, + mode: str | None = None, + user_id: str | None = None, +): + return get_runs(page, page_size, status, mode, user_id) + + +@router.get("/runs/{run_id}", response_model=RunDetailResponse) +def admin_run_detail(run_id: str): + detail = get_run_detail(run_id) + if not detail: + raise HTTPException(status_code=404, detail="Run not found") + return detail + + +@router.get("/health/invariants", response_model=InvariantsResponse) +def admin_invariants(): + return get_invariants() + + +@router.get("/support-tickets", response_model=SupportTicketsResponse) +def admin_support_tickets( + page: int = Query(1, ge=1), + page_size: int = Query(50, ge=1, le=200), +): + return get_support_tickets(page, page_size) + + +@router.delete("/support-tickets/{ticket_id}", response_model=DeleteSupportTicketResponse) +def admin_delete_support_ticket(ticket_id: str): + result = delete_support_ticket(ticket_id) + if not result: + raise HTTPException(status_code=404, detail="Ticket not found") + return result diff --git a/backend/app/admin_service.py b/backend/app/admin_service.py new file mode 100644 index 0000000..3c6a51d --- /dev/null +++ b/backend/app/admin_service.py @@ -0,0 +1,762 @@ +from datetime import datetime, timedelta, timezone +import hashlib +import os + +from psycopg2.extras import Json +from psycopg2.extras import RealDictCursor + +from app.services.db import db_connection +from app.services.run_service import get_running_run_id +from indian_paper_trading_strategy.engine.runner import stop_engine + + +def _paginate(page: int, page_size: int): + page = max(page, 1) + page_size = max(min(page_size, 200), 1) + offset = (page - 1) * page_size + return page, page_size, offset + + +def get_overview(): + now = datetime.now(timezone.utc) + since = now - timedelta(hours=24) + with db_connection() as conn: + with conn.cursor() as cur: + cur.execute("SELECT COUNT(*) FROM app_user") + total_users = cur.fetchone()[0] + cur.execute( + """ + SELECT COUNT(DISTINCT user_id) + FROM app_session + WHERE COALESCE(last_seen_at, created_at) >= %s + """, + (since,), + ) + users_logged_in_last_24h = cur.fetchone()[0] + cur.execute( + """ + SELECT + COUNT(*) AS total_runs, + COUNT(*) FILTER (WHERE status = 'RUNNING') AS running_runs, + COUNT(*) FILTER (WHERE status = 'STOPPED') AS stopped_runs, + COUNT(*) FILTER (WHERE status = 'ERROR') AS error_runs, + COUNT(*) FILTER (WHERE mode = 'LIVE') AS live_runs_count, + COUNT(*) FILTER (WHERE mode = 'PAPER') AS paper_runs_count + FROM strategy_run + """ + ) + run_row = cur.fetchone() + cur.execute( + """ + SELECT COUNT(*) FROM paper_order WHERE "timestamp" >= %s + """, + (since,), + ) + orders_last_24h = cur.fetchone()[0] + cur.execute( + """ + SELECT COUNT(*) FROM paper_trade WHERE "timestamp" >= %s + """, + (since,), + ) + trades_last_24h = cur.fetchone()[0] + cur.execute( + """ + SELECT COUNT(*) + FROM event_ledger + WHERE event = 'SIP_EXECUTED' AND "timestamp" >= %s + """, + (since,), + ) + sip_executed_last_24h = cur.fetchone()[0] + cur.execute( + """ + SELECT ts, event, message, source, user_id, run_id + FROM ( + SELECT ts, event, message, 'engine_event' AS source, user_id, run_id + FROM engine_event + WHERE event ILIKE '%ERROR%' + UNION ALL + SELECT ts, event, message, 'strategy_log' AS source, user_id, run_id + FROM strategy_log + WHERE level = 'ERROR' + ) t + ORDER BY ts DESC NULLS LAST + LIMIT 10 + """ + ) + top_errors = [ + { + "ts": row[0], + "event": row[1], + "message": row[2], + "source": row[3], + "user_id": row[4], + "run_id": row[5], + } + for row in cur.fetchall() + ] + return { + "total_users": total_users, + "users_logged_in_last_24h": users_logged_in_last_24h, + "total_runs": run_row[0], + "running_runs": run_row[1], + "stopped_runs": run_row[2], + "error_runs": run_row[3], + "live_runs_count": run_row[4], + "paper_runs_count": run_row[5], + "orders_last_24h": orders_last_24h, + "trades_last_24h": trades_last_24h, + "sip_executed_last_24h": sip_executed_last_24h, + "top_errors": top_errors, + } + + +def get_users(page: int, page_size: int, query: str | None): + page, page_size, offset = _paginate(page, page_size) + params = [] + where = "" + if query: + where = "WHERE username ILIKE %s OR user_id = %s" + params = [f"%{query}%", query] + with db_connection() as conn: + with conn.cursor(cursor_factory=RealDictCursor) as cur: + cur.execute(f"SELECT COUNT(*) FROM admin_user_metrics {where}", params) + total = cur.fetchone()["count"] + cur.execute( + f""" + SELECT * + FROM admin_user_metrics + {where} + ORDER BY created_at DESC NULLS LAST + LIMIT %s OFFSET %s + """, + (*params, page_size, offset), + ) + rows = cur.fetchall() + return { + "page": page, + "page_size": page_size, + "total": total, + "users": rows, + } + + +def _get_active_run_id(cur, user_id: str): + cur.execute( + """ + SELECT run_id + FROM strategy_run + WHERE user_id = %s AND status = 'RUNNING' + ORDER BY created_at DESC + LIMIT 1 + """, + (user_id,), + ) + row = cur.fetchone() + if row: + return row[0] + cur.execute( + """ + SELECT run_id + FROM strategy_run + WHERE user_id = %s + ORDER BY created_at DESC + LIMIT 1 + """, + (user_id,), + ) + row = cur.fetchone() + return row[0] if row else None + + +def get_user_detail(user_id: str): + with db_connection() as conn: + with conn.cursor(cursor_factory=RealDictCursor) as cur: + cur.execute("SELECT * FROM admin_user_metrics WHERE user_id = %s", (user_id,)) + user = cur.fetchone() + if not user: + return None + + cur.execute( + """ + SELECT * FROM admin_run_metrics + WHERE user_id = %s + ORDER BY created_at DESC NULLS LAST + LIMIT 20 + """, + (user_id,), + ) + runs = cur.fetchall() + + active_run_id = _get_active_run_id(cur, user_id) + config = None + if active_run_id: + cur.execute( + """ + SELECT strategy, sip_amount, sip_frequency_value, sip_frequency_unit, + mode, broker, active, frequency, frequency_days, unit, next_run + FROM strategy_config + WHERE user_id = %s AND run_id = %s + LIMIT 1 + """, + (user_id, active_run_id), + ) + cfg_row = cur.fetchone() + if cfg_row: + config = dict(cfg_row) + + cur.execute( + """ + SELECT ts, event, message, level, run_id, meta, 'strategy_log' AS source + FROM strategy_log + WHERE user_id = %s + UNION ALL + SELECT ts, event, message, NULL AS level, run_id, meta, 'engine_event' AS source + FROM engine_event + WHERE user_id = %s + ORDER BY ts DESC NULLS LAST + LIMIT 50 + """, + (user_id, user_id), + ) + events = [ + { + "ts": row[0], + "event": row[1], + "message": row[2], + "level": row[3], + "run_id": row[4], + "meta": row[5], + "source": row[6], + } + for row in cur.fetchall() + ] + + capital_summary = { + "cash": None, + "invested": None, + "mtm": None, + "equity": None, + "pnl": None, + } + if active_run_id: + cur.execute( + """ + SELECT + (SELECT cash FROM paper_broker_account WHERE user_id = %s AND run_id = %s LIMIT 1) AS cash, + (SELECT total_invested FROM engine_state_paper WHERE user_id = %s AND run_id = %s LIMIT 1) AS invested, + (SELECT portfolio_value FROM mtm_ledger WHERE user_id = %s AND run_id = %s ORDER BY "timestamp" DESC LIMIT 1) AS mtm, + (SELECT equity FROM paper_equity_curve WHERE user_id = %s AND run_id = %s ORDER BY "timestamp" DESC LIMIT 1) AS equity, + (SELECT pnl FROM paper_equity_curve WHERE user_id = %s AND run_id = %s ORDER BY "timestamp" DESC LIMIT 1) AS pnl + """, + ( + user_id, + active_run_id, + user_id, + active_run_id, + user_id, + active_run_id, + user_id, + active_run_id, + user_id, + active_run_id, + ), + ) + row = cur.fetchone() + if row: + capital_summary = { + "cash": row[0], + "invested": row[1], + "mtm": row[2], + "equity": row[3], + "pnl": row[4], + } + + return { + "user": user, + "runs": runs, + "current_config": config, + "events": events, + "capital_summary": capital_summary, + } + + +def get_runs(page: int, page_size: int, status: str | None, mode: str | None, user_id: str | None): + page, page_size, offset = _paginate(page, page_size) + filters = [] + params = [] + if status: + filters.append("status = %s") + params.append(status) + if mode: + filters.append("mode = %s") + params.append(mode) + if user_id: + filters.append("user_id = %s") + params.append(user_id) + where = f"WHERE {' AND '.join(filters)}" if filters else "" + + with db_connection() as conn: + with conn.cursor(cursor_factory=RealDictCursor) as cur: + cur.execute(f"SELECT COUNT(*) FROM admin_run_metrics {where}", params) + total = cur.fetchone()["count"] + cur.execute( + f""" + SELECT * + FROM admin_run_metrics + {where} + ORDER BY created_at DESC NULLS LAST + LIMIT %s OFFSET %s + """, + (*params, page_size, offset), + ) + runs = cur.fetchall() + return { + "page": page, + "page_size": page_size, + "total": total, + "runs": runs, + } + + +def get_run_detail(run_id: str): + with db_connection() as conn: + with conn.cursor(cursor_factory=RealDictCursor) as cur: + cur.execute("SELECT * FROM admin_run_metrics WHERE run_id = %s", (run_id,)) + run = cur.fetchone() + if not run: + return None + + user_id = run["user_id"] + + cur.execute( + """ + SELECT strategy, sip_amount, sip_frequency_value, sip_frequency_unit, + mode, broker, active, frequency, frequency_days, unit, next_run + FROM strategy_config + WHERE user_id = %s AND run_id = %s + LIMIT 1 + """, + (user_id, run_id), + ) + config = cur.fetchone() + + cur.execute( + """ + SELECT status, last_updated + FROM engine_status + WHERE user_id = %s AND run_id = %s + LIMIT 1 + """, + (user_id, run_id), + ) + engine_status = cur.fetchone() + + cur.execute( + """ + SELECT initial_cash, cash, total_invested, nifty_units, gold_units, + last_sip_ts, last_run, sip_frequency_value, sip_frequency_unit + FROM engine_state_paper + WHERE user_id = %s AND run_id = %s + LIMIT 1 + """, + (user_id, run_id), + ) + state = cur.fetchone() + state_snapshot = dict(state) if state else None + + cur.execute( + """ + SELECT event, "timestamp", logical_time, nifty_units, gold_units, nifty_price, gold_price, amount + FROM event_ledger + WHERE user_id = %s AND run_id = %s + ORDER BY "timestamp" DESC + LIMIT 100 + """, + (user_id, run_id), + ) + ledger_events = cur.fetchall() + + cur.execute( + """ + SELECT id, symbol, side, qty, price, status, "timestamp" + FROM paper_order + WHERE user_id = %s AND run_id = %s + ORDER BY "timestamp" DESC + LIMIT 50 + """, + (user_id, run_id), + ) + orders = cur.fetchall() + + cur.execute( + """ + SELECT id, order_id, symbol, side, qty, price, "timestamp" + FROM paper_trade + WHERE user_id = %s AND run_id = %s + ORDER BY "timestamp" DESC + LIMIT 50 + """, + (user_id, run_id), + ) + trades = cur.fetchall() + + cur.execute( + """ + SELECT COUNT(*) FROM ( + SELECT logical_time FROM event_ledger + WHERE user_id = %s AND run_id = %s + GROUP BY logical_time, event + HAVING COUNT(*) > 1 + ) t + """, + (user_id, run_id), + ) + dup_event = cur.fetchone()["count"] + + cur.execute( + """ + SELECT COUNT(*) FROM ( + SELECT logical_time FROM mtm_ledger + WHERE user_id = %s AND run_id = %s + GROUP BY logical_time + HAVING COUNT(*) > 1 + ) t + """, + (user_id, run_id), + ) + dup_mtm = cur.fetchone()["count"] + + cur.execute( + """ + SELECT COUNT(*) FROM paper_broker_account + WHERE user_id = %s AND run_id = %s AND cash < 0 + """, + (user_id, run_id), + ) + neg_cash = cur.fetchone()["count"] + + cur.execute( + """ + SELECT COUNT(*) FROM paper_order + WHERE user_id = %s AND run_id = %s AND qty <= 0 + """, + (user_id, run_id), + ) + bad_qty = cur.fetchone()["count"] + + invariants = { + "duplicate_event_logical_time": dup_event, + "duplicate_mtm_logical_time": dup_mtm, + "negative_cash": neg_cash, + "invalid_qty": bad_qty, + } + + return { + "run": run, + "config": dict(config) if config else None, + "engine_status": dict(engine_status) if engine_status else None, + "state_snapshot": state_snapshot, + "ledger_events": ledger_events, + "orders": orders, + "trades": trades, + "invariants": invariants, + } + + +def get_invariants(stale_minutes: int = 30): + cutoff = datetime.now(timezone.utc) - timedelta(minutes=stale_minutes) + with db_connection() as conn: + with conn.cursor() as cur: + cur.execute( + """ + SELECT COUNT(*) FROM ( + SELECT user_id FROM strategy_run + WHERE status = 'RUNNING' + GROUP BY user_id + HAVING COUNT(*) > 1 + ) t + """ + ) + running_runs_per_user_violations = cur.fetchone()[0] + + cur.execute( + """ + SELECT COUNT(*) FROM ( + SELECT user_id, run_id FROM engine_state + UNION ALL + SELECT user_id, run_id FROM engine_status + UNION ALL + SELECT user_id, run_id FROM paper_order + UNION ALL + SELECT user_id, run_id FROM paper_trade + ) t + LEFT JOIN strategy_run sr + ON sr.user_id = t.user_id AND sr.run_id = t.run_id + WHERE sr.run_id IS NULL + """ + ) + orphan_rows = cur.fetchone()[0] + + cur.execute( + """ + SELECT COUNT(*) FROM ( + SELECT user_id, run_id, logical_time, event + FROM event_ledger + GROUP BY user_id, run_id, logical_time, event + HAVING COUNT(*) > 1 + ) t + """ + ) + dup_event = cur.fetchone()[0] + + cur.execute( + """ + SELECT COUNT(*) FROM ( + SELECT user_id, run_id, logical_time + FROM mtm_ledger + GROUP BY user_id, run_id, logical_time + HAVING COUNT(*) > 1 + ) t + """ + ) + dup_mtm = cur.fetchone()[0] + + cur.execute( + "SELECT COUNT(*) FROM paper_broker_account WHERE cash < 0" + ) + negative_cash = cur.fetchone()[0] + + cur.execute( + "SELECT COUNT(*) FROM paper_order WHERE qty <= 0" + ) + invalid_qty = cur.fetchone()[0] + + cur.execute( + """ + SELECT COUNT(*) FROM strategy_run sr + LEFT JOIN ( + SELECT user_id, run_id, MAX(ts) AS last_ts + FROM ( + SELECT user_id, run_id, ts FROM engine_event + UNION ALL + SELECT user_id, run_id, ts FROM strategy_log + UNION ALL + SELECT user_id, run_id, "timestamp" AS ts FROM event_ledger + ) t + GROUP BY user_id, run_id + ) activity + ON activity.user_id = sr.user_id AND activity.run_id = sr.run_id + WHERE sr.status = 'RUNNING' AND (activity.last_ts IS NULL OR activity.last_ts < %s) + """, + (cutoff,), + ) + stale_running_runs = cur.fetchone()[0] + + return { + "running_runs_per_user_violations": running_runs_per_user_violations, + "orphan_rows": orphan_rows, + "duplicate_logical_time": dup_event + dup_mtm, + "negative_cash": negative_cash, + "invalid_qty": invalid_qty, + "stale_running_runs": stale_running_runs, + } + + +def get_support_tickets(page: int, page_size: int): + page, page_size, offset = _paginate(page, page_size) + with db_connection() as conn: + with conn.cursor(cursor_factory=RealDictCursor) as cur: + cur.execute("SELECT COUNT(*) FROM support_ticket") + total = cur.fetchone()["count"] + cur.execute( + """ + SELECT id AS ticket_id, name, email, subject, message, status, created_at, updated_at + FROM support_ticket + ORDER BY created_at DESC NULLS LAST + LIMIT %s OFFSET %s + """, + (page_size, offset), + ) + rows = cur.fetchall() + tickets = [] + for row in rows: + ticket = dict(row) + ticket["ticket_id"] = str(ticket.get("ticket_id")) + if ticket.get("created_at"): + ticket["created_at"] = ticket["created_at"] + if ticket.get("updated_at"): + ticket["updated_at"] = ticket["updated_at"] + tickets.append(ticket) + return { + "page": page, + "page_size": page_size, + "total": total, + "tickets": tickets, + } + + +def delete_support_ticket(ticket_id: str) -> dict | None: + with db_connection() as conn: + with conn: + with conn.cursor() as cur: + cur.execute("DELETE FROM support_ticket WHERE id = %s", (ticket_id,)) + if cur.rowcount == 0: + return None + return {"ticket_id": ticket_id, "deleted": True} + + +def _hash_value(value: str | None) -> str | None: + if value is None: + return None + return hashlib.sha256(value.encode("utf-8")).hexdigest() + + +def delete_user_hard(user_id: str, admin_user: dict): + table_counts = [ + ("app_user", "SELECT COUNT(*) FROM app_user WHERE id = %s"), + ("app_session", "SELECT COUNT(*) FROM app_session WHERE user_id = %s"), + ("user_broker", "SELECT COUNT(*) FROM user_broker WHERE user_id = %s"), + ("zerodha_session", "SELECT COUNT(*) FROM zerodha_session WHERE user_id = %s"), + ("zerodha_request_token", "SELECT COUNT(*) FROM zerodha_request_token WHERE user_id = %s"), + ("strategy_run", "SELECT COUNT(*) FROM strategy_run WHERE user_id = %s"), + ("strategy_config", "SELECT COUNT(*) FROM strategy_config WHERE user_id = %s"), + ("strategy_log", "SELECT COUNT(*) FROM strategy_log WHERE user_id = %s"), + ("engine_status", "SELECT COUNT(*) FROM engine_status WHERE user_id = %s"), + ("engine_state", "SELECT COUNT(*) FROM engine_state WHERE user_id = %s"), + ("engine_state_paper", "SELECT COUNT(*) FROM engine_state_paper WHERE user_id = %s"), + ("engine_event", "SELECT COUNT(*) FROM engine_event WHERE user_id = %s"), + ("paper_broker_account", "SELECT COUNT(*) FROM paper_broker_account WHERE user_id = %s"), + ("paper_position", "SELECT COUNT(*) FROM paper_position WHERE user_id = %s"), + ("paper_order", "SELECT COUNT(*) FROM paper_order WHERE user_id = %s"), + ("paper_trade", "SELECT COUNT(*) FROM paper_trade WHERE user_id = %s"), + ("paper_equity_curve", "SELECT COUNT(*) FROM paper_equity_curve WHERE user_id = %s"), + ("mtm_ledger", "SELECT COUNT(*) FROM mtm_ledger WHERE user_id = %s"), + ("event_ledger", "SELECT COUNT(*) FROM event_ledger WHERE user_id = %s"), + ] + + with db_connection() as conn: + with conn: + with conn.cursor() as cur: + cur.execute( + "SELECT id, username FROM app_user WHERE id = %s", + (user_id,), + ) + row = cur.fetchone() + if not row: + return None + target_username = row[1] + + counts = {} + for name, query in table_counts: + cur.execute(query, (user_id,)) + counts[name] = cur.fetchone()[0] + + cur.execute("DELETE FROM app_user WHERE id = %s", (user_id,)) + if cur.rowcount == 0: + return None + + audit_meta = {"deleted": counts, "hard": True} + cur.execute( + """ + INSERT INTO admin_audit_log + (actor_user_hash, target_user_hash, target_username_hash, action, meta) + VALUES (%s, %s, %s, %s, %s) + RETURNING id + """, + ( + _hash_value(admin_user["id"]), + _hash_value(user_id), + _hash_value(target_username), + "HARD_DELETE_USER", + Json(audit_meta), + ), + ) + audit_id = cur.fetchone()[0] + + return { + "user_id": user_id, + "deleted": counts, + "audit_id": audit_id, + } + + +def hard_reset_user_data(user_id: str, admin_user: dict): + table_counts = [ + ("strategy_run", "SELECT COUNT(*) FROM strategy_run WHERE user_id = %s"), + ("strategy_config", "SELECT COUNT(*) FROM strategy_config WHERE user_id = %s"), + ("strategy_log", "SELECT COUNT(*) FROM strategy_log WHERE user_id = %s"), + ("engine_status", "SELECT COUNT(*) FROM engine_status WHERE user_id = %s"), + ("engine_state", "SELECT COUNT(*) FROM engine_state WHERE user_id = %s"), + ("engine_state_paper", "SELECT COUNT(*) FROM engine_state_paper WHERE user_id = %s"), + ("engine_event", "SELECT COUNT(*) FROM engine_event WHERE user_id = %s"), + ("paper_broker_account", "SELECT COUNT(*) FROM paper_broker_account WHERE user_id = %s"), + ("paper_position", "SELECT COUNT(*) FROM paper_position WHERE user_id = %s"), + ("paper_order", "SELECT COUNT(*) FROM paper_order WHERE user_id = %s"), + ("paper_trade", "SELECT COUNT(*) FROM paper_trade WHERE user_id = %s"), + ("paper_equity_curve", "SELECT COUNT(*) FROM paper_equity_curve WHERE user_id = %s"), + ("mtm_ledger", "SELECT COUNT(*) FROM mtm_ledger WHERE user_id = %s"), + ("event_ledger", "SELECT COUNT(*) FROM event_ledger WHERE user_id = %s"), + ] + + engine_external = os.getenv("ENGINE_EXTERNAL", "").strip().lower() in {"1", "true", "yes"} + running_run_id = get_running_run_id(user_id) + if running_run_id and not engine_external: + stop_engine(user_id, timeout=15.0) + + with db_connection() as conn: + with conn: + with conn.cursor() as cur: + cur.execute( + "SELECT id, username FROM app_user WHERE id = %s", + (user_id,), + ) + row = cur.fetchone() + if not row: + return None + target_username = row[1] + + counts = {} + for name, query in table_counts: + cur.execute(query, (user_id,)) + counts[name] = cur.fetchone()[0] + + cur.execute("DELETE FROM strategy_log WHERE user_id = %s", (user_id,)) + cur.execute("DELETE FROM engine_event WHERE user_id = %s", (user_id,)) + cur.execute("DELETE FROM paper_equity_curve WHERE user_id = %s", (user_id,)) + cur.execute("DELETE FROM paper_trade WHERE user_id = %s", (user_id,)) + cur.execute("DELETE FROM paper_order WHERE user_id = %s", (user_id,)) + cur.execute("DELETE FROM paper_position WHERE user_id = %s", (user_id,)) + cur.execute("DELETE FROM paper_broker_account WHERE user_id = %s", (user_id,)) + cur.execute("DELETE FROM mtm_ledger WHERE user_id = %s", (user_id,)) + cur.execute("DELETE FROM event_ledger WHERE user_id = %s", (user_id,)) + cur.execute("DELETE FROM engine_state_paper WHERE user_id = %s", (user_id,)) + cur.execute("DELETE FROM engine_state WHERE user_id = %s", (user_id,)) + cur.execute("DELETE FROM engine_status WHERE user_id = %s", (user_id,)) + cur.execute("DELETE FROM strategy_config WHERE user_id = %s", (user_id,)) + cur.execute("DELETE FROM strategy_run WHERE user_id = %s", (user_id,)) + + audit_meta = {"reset": counts, "hard": True} + cur.execute( + """ + INSERT INTO admin_audit_log + (actor_user_hash, target_user_hash, target_username_hash, action, meta) + VALUES (%s, %s, %s, %s, %s) + RETURNING id + """, + ( + _hash_value(admin_user["id"]), + _hash_value(user_id), + _hash_value(target_username), + "HARD_RESET_USER", + Json(audit_meta), + ), + ) + audit_id = cur.fetchone()[0] + + return { + "user_id": user_id, + "deleted": counts, + "audit_id": audit_id, + } diff --git a/backend/app/broker_store.py b/backend/app/broker_store.py new file mode 100644 index 0000000..09d15fa --- /dev/null +++ b/backend/app/broker_store.py @@ -0,0 +1,296 @@ +from datetime import datetime, timezone + +from app.services.crypto_service import decrypt_value, encrypt_value +from app.services.db import db_transaction + + +def _row_to_entry(row): + ( + user_id, + broker, + connected, + access_token, + connected_at, + api_key, + api_secret, + user_name, + broker_user_id, + auth_state, + pending_broker, + pending_api_key, + pending_api_secret, + pending_started_at, + ) = row + entry = { + "broker": broker, + "connected": bool(connected), + "connected_at": connected_at, + "api_key": api_key, + "auth_state": auth_state, + "user_name": user_name, + "broker_user_id": broker_user_id, + } + if pending_broker or pending_api_key or pending_api_secret or pending_started_at: + pending = { + "broker": pending_broker, + "api_key": pending_api_key, + "api_secret": decrypt_value(pending_api_secret) + if pending_api_secret + else None, + "started_at": pending_started_at, + } + entry["pending"] = pending + return entry + + +def load_user_brokers(): + with db_transaction() as cur: + cur.execute( + """ + SELECT user_id, broker, connected, access_token, connected_at, + api_key, api_secret, user_name, broker_user_id, auth_state, + pending_broker, pending_api_key, pending_api_secret, pending_started_at + FROM user_broker + """ + ) + rows = cur.fetchall() + return {row[0]: _row_to_entry(row) for row in rows} + + +def save_user_brokers(data): + with db_transaction() as cur: + for user_id, entry in data.items(): + cur.execute( + """ + INSERT INTO user_broker ( + user_id, broker, connected, access_token, connected_at, + api_key, api_secret, user_name, broker_user_id, auth_state, + pending_broker, pending_api_key, pending_api_secret, pending_started_at + ) + VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s) + ON CONFLICT (user_id) + DO UPDATE SET + broker = EXCLUDED.broker, + connected = EXCLUDED.connected, + access_token = EXCLUDED.access_token, + connected_at = EXCLUDED.connected_at, + api_key = EXCLUDED.api_key, + api_secret = EXCLUDED.api_secret, + user_name = EXCLUDED.user_name, + broker_user_id = EXCLUDED.broker_user_id, + auth_state = EXCLUDED.auth_state, + pending_broker = EXCLUDED.pending_broker, + pending_api_key = EXCLUDED.pending_api_key, + pending_api_secret = EXCLUDED.pending_api_secret, + pending_started_at = EXCLUDED.pending_started_at + """, + ( + user_id, + entry.get("broker"), + bool(entry.get("connected")), + encrypt_value(entry.get("access_token")) + if entry.get("access_token") + else None, + entry.get("connected_at"), + entry.get("api_key"), + encrypt_value(entry.get("api_secret")) + if entry.get("api_secret") + else None, + entry.get("user_name"), + entry.get("broker_user_id"), + entry.get("auth_state"), + (entry.get("pending") or {}).get("broker"), + (entry.get("pending") or {}).get("api_key"), + encrypt_value((entry.get("pending") or {}).get("api_secret")) + if (entry.get("pending") or {}).get("api_secret") + else None, + (entry.get("pending") or {}).get("started_at"), + ), + ) + + +def now_utc(): + return datetime.now(timezone.utc) + + +def get_user_broker(user_id: str): + with db_transaction() as cur: + cur.execute( + """ + SELECT user_id, broker, connected, access_token, connected_at, + api_key, api_secret, user_name, broker_user_id, auth_state, + pending_broker, pending_api_key, pending_api_secret, pending_started_at + FROM user_broker + WHERE user_id = %s + """, + (user_id,), + ) + row = cur.fetchone() + if not row: + return None + return _row_to_entry(row) + + +def clear_user_broker(user_id: str): + with db_transaction() as cur: + cur.execute("DELETE FROM user_broker WHERE user_id = %s", (user_id,)) + + +def set_pending_broker(user_id: str, broker: str, api_key: str, api_secret: str): + started_at = now_utc() + with db_transaction() as cur: + cur.execute( + """ + INSERT INTO user_broker ( + user_id, pending_broker, pending_api_key, pending_api_secret, pending_started_at, + api_key, api_secret, auth_state + ) + VALUES (%s, %s, %s, %s, %s, %s, %s, %s) + ON CONFLICT (user_id) + DO UPDATE SET + pending_broker = EXCLUDED.pending_broker, + pending_api_key = EXCLUDED.pending_api_key, + pending_api_secret = EXCLUDED.pending_api_secret, + pending_started_at = EXCLUDED.pending_started_at, + api_key = EXCLUDED.api_key, + api_secret = EXCLUDED.api_secret, + auth_state = EXCLUDED.auth_state + """, + ( + user_id, + broker, + api_key, + encrypt_value(api_secret), + started_at, + api_key, + encrypt_value(api_secret), + "PENDING", + ), + ) + return { + "broker": broker, + "api_key": api_key, + "api_secret": api_secret, + "started_at": started_at, + } + + +def get_pending_broker(user_id: str): + with db_transaction() as cur: + cur.execute( + """ + SELECT pending_broker, pending_api_key, pending_api_secret, pending_started_at + FROM user_broker + WHERE user_id = %s + """, + (user_id,), + ) + row = cur.fetchone() + if not row: + return None + if not row[0] or not row[1] or not row[2]: + return None + return { + "broker": row[0], + "api_key": row[1], + "api_secret": decrypt_value(row[2]), + "started_at": row[3], + } + + +def get_broker_credentials(user_id: str): + with db_transaction() as cur: + cur.execute( + """ + SELECT api_key, api_secret, pending_api_key, pending_api_secret + FROM user_broker + WHERE user_id = %s + """, + (user_id,), + ) + row = cur.fetchone() + if not row: + return None + api_key, api_secret, pending_key, pending_secret = row + key = api_key or pending_key + secret = api_secret or pending_secret + if not key or not secret: + return None + return { + "api_key": key, + "api_secret": decrypt_value(secret), + } + + +def set_broker_auth_state(user_id: str, auth_state: str): + with db_transaction() as cur: + cur.execute( + """ + UPDATE user_broker + SET auth_state = %s + WHERE user_id = %s + """, + (auth_state, user_id), + ) + + +def set_connected_broker( + user_id: str, + broker: str, + access_token: str, + api_key: str | None = None, + api_secret: str | None = None, + user_name: str | None = None, + broker_user_id: str | None = None, + auth_state: str | None = None, +): + connected_at = now_utc() + with db_transaction() as cur: + cur.execute( + """ + INSERT INTO user_broker ( + user_id, broker, connected, access_token, connected_at, + api_key, api_secret, user_name, broker_user_id, auth_state, + pending_broker, pending_api_key, pending_api_secret, pending_started_at + ) + VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, NULL, NULL, NULL, NULL) + ON CONFLICT (user_id) + DO UPDATE SET + broker = EXCLUDED.broker, + connected = EXCLUDED.connected, + access_token = EXCLUDED.access_token, + connected_at = EXCLUDED.connected_at, + api_key = EXCLUDED.api_key, + api_secret = EXCLUDED.api_secret, + user_name = EXCLUDED.user_name, + broker_user_id = EXCLUDED.broker_user_id, + auth_state = EXCLUDED.auth_state, + pending_broker = NULL, + pending_api_key = NULL, + pending_api_secret = NULL, + pending_started_at = NULL + """, + ( + user_id, + broker, + True, + encrypt_value(access_token), + connected_at, + api_key, + encrypt_value(api_secret) if api_secret else None, + user_name, + broker_user_id, + auth_state, + ), + ) + return { + "broker": broker, + "connected": True, + "access_token": access_token, + "connected_at": connected_at, + "api_key": api_key, + "api_secret": api_secret, + "user_name": user_name, + "broker_user_id": broker_user_id, + "auth_state": auth_state, + } diff --git a/backend/app/db_models.py b/backend/app/db_models.py new file mode 100644 index 0000000..9c3c16d --- /dev/null +++ b/backend/app/db_models.py @@ -0,0 +1,491 @@ +from sqlalchemy import ( + BigInteger, + Boolean, + CheckConstraint, + Column, + Date, + DateTime, + ForeignKey, + ForeignKeyConstraint, + Index, + Integer, + Numeric, + String, + Text, + UniqueConstraint, + func, + text, +) +from sqlalchemy.dialects.postgresql import JSONB + +from app.services.db import Base + + +class AppUser(Base): + __tablename__ = "app_user" + + id = Column(String, primary_key=True) + username = Column(String, nullable=False, unique=True) + password_hash = Column(String, nullable=False) + is_admin = Column(Boolean, nullable=False, server_default=text("false")) + is_super_admin = Column(Boolean, nullable=False, server_default=text("false")) + role = Column(String, nullable=False, server_default=text("'USER'")) + + __table_args__ = ( + CheckConstraint("role IN ('USER','ADMIN','SUPER_ADMIN')", name="chk_app_user_role"), + Index("idx_app_user_role", "role"), + Index("idx_app_user_is_admin", "is_admin"), + Index("idx_app_user_is_super_admin", "is_super_admin"), + ) + + +class AppSession(Base): + __tablename__ = "app_session" + + id = Column(String, primary_key=True) + user_id = Column(String, ForeignKey("app_user.id", ondelete="CASCADE"), nullable=False) + created_at = Column(DateTime(timezone=True), nullable=False) + last_seen_at = Column(DateTime(timezone=True)) + expires_at = Column(DateTime(timezone=True), nullable=False) + + __table_args__ = ( + Index("idx_app_session_user_id", "user_id"), + Index("idx_app_session_expires_at", "expires_at"), + ) + + +class UserBroker(Base): + __tablename__ = "user_broker" + + user_id = Column(String, ForeignKey("app_user.id", ondelete="CASCADE"), primary_key=True) + broker = Column(String) + connected = Column(Boolean, nullable=False, server_default=text("false")) + access_token = Column(Text) + connected_at = Column(DateTime(timezone=True)) + api_key = Column(Text) + user_name = Column(Text) + broker_user_id = Column(Text) + pending_broker = Column(Text) + pending_api_key = Column(Text) + pending_api_secret = Column(Text) + pending_started_at = Column(DateTime(timezone=True)) + + __table_args__ = ( + Index("idx_user_broker_broker", "broker"), + Index("idx_user_broker_connected", "connected"), + ) + + +class ZerodhaSession(Base): + __tablename__ = "zerodha_session" + + id = Column(BigInteger, primary_key=True, autoincrement=True) + user_id = Column(String, ForeignKey("app_user.id", ondelete="CASCADE"), nullable=False) + linked_at = Column(DateTime(timezone=True), nullable=False) + api_key = Column(Text) + access_token = Column(Text) + request_token = Column(Text) + user_name = Column(Text) + broker_user_id = Column(Text) + + __table_args__ = ( + Index("idx_zerodha_session_user_id", "user_id"), + Index("idx_zerodha_session_linked_at", "linked_at"), + ) + + +class ZerodhaRequestToken(Base): + __tablename__ = "zerodha_request_token" + + user_id = Column(String, ForeignKey("app_user.id", ondelete="CASCADE"), primary_key=True) + request_token = Column(Text, nullable=False) + + +class StrategyRun(Base): + __tablename__ = "strategy_run" + + run_id = Column(String, primary_key=True) + user_id = Column(String, ForeignKey("app_user.id", ondelete="CASCADE"), nullable=False) + created_at = Column(DateTime(timezone=True), nullable=False, server_default=func.now()) + started_at = Column(DateTime(timezone=True)) + stopped_at = Column(DateTime(timezone=True)) + status = Column(String, nullable=False) + strategy = Column(String) + mode = Column(String) + broker = Column(String) + meta = Column(JSONB) + + __table_args__ = ( + UniqueConstraint("user_id", "run_id", name="uq_strategy_run_user_run"), + CheckConstraint("status IN ('RUNNING','STOPPED','ERROR')", name="chk_strategy_run_status"), + Index("idx_strategy_run_user_status", "user_id", "status"), + Index("idx_strategy_run_user_created", "user_id", "created_at"), + Index( + "uq_one_running_run_per_user", + "user_id", + unique=True, + postgresql_where=text("status = 'RUNNING'"), + ), + ) + + +class StrategyConfig(Base): + __tablename__ = "strategy_config" + + id = Column(BigInteger, primary_key=True, autoincrement=True) + strategy = Column(String) + sip_amount = Column(Numeric) + sip_frequency_value = Column(Integer) + sip_frequency_unit = Column(String) + mode = Column(String) + broker = Column(String) + active = Column(Boolean) + frequency = Column(Text) + frequency_days = Column(Integer) + unit = Column(String) + next_run = Column(DateTime(timezone=True)) + user_id = Column(String, ForeignKey("app_user.id", ondelete="CASCADE"), nullable=False) + run_id = Column(String, ForeignKey("strategy_run.run_id", ondelete="CASCADE"), nullable=False) + + __table_args__ = ( + UniqueConstraint("user_id", "run_id", name="uq_strategy_config_user_run"), + ) + + +class StrategyLog(Base): + __tablename__ = "strategy_log" + + seq = Column(BigInteger, primary_key=True) + ts = Column(DateTime(timezone=True), nullable=False) + level = Column(String) + category = Column(String) + event = Column(String) + message = Column(Text) + user_id = Column(String, ForeignKey("app_user.id", ondelete="CASCADE"), nullable=False) + run_id = Column(String, ForeignKey("strategy_run.run_id", ondelete="CASCADE"), nullable=False) + meta = Column(JSONB) + + __table_args__ = ( + Index("idx_strategy_log_ts", "ts"), + Index("idx_strategy_log_event", "event"), + Index("idx_strategy_log_user_run_ts", "user_id", "run_id", "ts"), + ) + + +class EngineStatus(Base): + __tablename__ = "engine_status" + + id = Column(BigInteger, primary_key=True, autoincrement=True) + user_id = Column(String, ForeignKey("app_user.id", ondelete="CASCADE"), nullable=False) + run_id = Column(String, nullable=False) + status = Column(String, nullable=False) + last_updated = Column(DateTime(timezone=True), nullable=False) + + __table_args__ = ( + UniqueConstraint("user_id", "run_id", name="uq_engine_status_user_run"), + ForeignKeyConstraint( + ["user_id", "run_id"], + ["strategy_run.user_id", "strategy_run.run_id"], + ondelete="CASCADE", + ), + Index("idx_engine_status_user_run", "user_id", "run_id"), + ) + + +class EngineState(Base): + __tablename__ = "engine_state" + + id = Column(BigInteger, primary_key=True, autoincrement=True) + user_id = Column(String, ForeignKey("app_user.id", ondelete="CASCADE"), nullable=False) + run_id = Column(String, nullable=False) + total_invested = Column(Numeric) + nifty_units = Column(Numeric) + gold_units = Column(Numeric) + last_sip_ts = Column(DateTime(timezone=True)) + last_run = Column(DateTime(timezone=True)) + + __table_args__ = ( + UniqueConstraint("user_id", "run_id", name="uq_engine_state_user_run"), + ForeignKeyConstraint( + ["user_id", "run_id"], + ["strategy_run.user_id", "strategy_run.run_id"], + ondelete="CASCADE", + ), + ) + + +class EngineStatePaper(Base): + __tablename__ = "engine_state_paper" + + id = Column(BigInteger, primary_key=True, autoincrement=True) + user_id = Column(String, ForeignKey("app_user.id", ondelete="CASCADE"), nullable=False) + run_id = Column(String, nullable=False) + initial_cash = Column(Numeric) + cash = Column(Numeric) + total_invested = Column(Numeric) + nifty_units = Column(Numeric) + gold_units = Column(Numeric) + last_sip_ts = Column(DateTime(timezone=True)) + last_run = Column(DateTime(timezone=True)) + sip_frequency_value = Column(Integer) + sip_frequency_unit = Column(String) + + __table_args__ = ( + UniqueConstraint("user_id", "run_id", name="uq_engine_state_paper_user_run"), + ForeignKeyConstraint( + ["user_id", "run_id"], + ["strategy_run.user_id", "strategy_run.run_id"], + ondelete="CASCADE", + ), + CheckConstraint("cash >= 0", name="chk_engine_state_paper_cash_non_negative"), + ) + + +class EngineEvent(Base): + __tablename__ = "engine_event" + + id = Column(BigInteger, primary_key=True, autoincrement=True) + ts = Column(DateTime(timezone=True), nullable=False) + event = Column(String) + data = Column(JSONB) + message = Column(Text) + meta = Column(JSONB) + user_id = Column(String, ForeignKey("app_user.id", ondelete="CASCADE"), nullable=False) + run_id = Column(String, ForeignKey("strategy_run.run_id", ondelete="CASCADE"), nullable=False) + + __table_args__ = ( + Index("idx_engine_event_ts", "ts"), + Index("idx_engine_event_user_run_ts", "user_id", "run_id", "ts"), + ) + + +class PaperBrokerAccount(Base): + __tablename__ = "paper_broker_account" + + id = Column(BigInteger, primary_key=True, autoincrement=True) + user_id = Column(String, ForeignKey("app_user.id", ondelete="CASCADE"), nullable=False) + run_id = Column(String, nullable=False) + cash = Column(Numeric, nullable=False) + + __table_args__ = ( + UniqueConstraint("user_id", "run_id", name="uq_paper_broker_account_user_run"), + ForeignKeyConstraint( + ["user_id", "run_id"], + ["strategy_run.user_id", "strategy_run.run_id"], + ondelete="CASCADE", + ), + CheckConstraint("cash >= 0", name="chk_paper_broker_cash_non_negative"), + ) + + +class PaperPosition(Base): + __tablename__ = "paper_position" + + user_id = Column(String, primary_key=True) + run_id = Column(String, primary_key=True) + symbol = Column(String, primary_key=True) + qty = Column(Numeric, nullable=False) + avg_price = Column(Numeric) + last_price = Column(Numeric) + updated_at = Column(DateTime(timezone=True), nullable=False, server_default=func.now()) + + __table_args__ = ( + ForeignKeyConstraint( + ["user_id", "run_id"], + ["strategy_run.user_id", "strategy_run.run_id"], + ondelete="CASCADE", + ), + CheckConstraint("qty > 0", name="chk_paper_position_qty_positive"), + UniqueConstraint("user_id", "run_id", "symbol", name="uq_paper_position_scope"), + Index("idx_paper_position_user_run", "user_id", "run_id"), + ) + + +class PaperOrder(Base): + __tablename__ = "paper_order" + + id = Column(String, primary_key=True) + user_id = Column(String, ForeignKey("app_user.id", ondelete="CASCADE"), nullable=False) + run_id = Column(String, nullable=False) + symbol = Column(String, nullable=False) + side = Column(String, nullable=False) + qty = Column(Numeric, nullable=False) + price = Column(Numeric) + status = Column(String, nullable=False) + timestamp = Column("timestamp", DateTime(timezone=True), nullable=False) + logical_time = Column(DateTime(timezone=True), nullable=False) + + __table_args__ = ( + ForeignKeyConstraint( + ["user_id", "run_id"], + ["strategy_run.user_id", "strategy_run.run_id"], + ondelete="CASCADE", + ), + UniqueConstraint("user_id", "run_id", "id", name="uq_paper_order_scope_id"), + UniqueConstraint( + "user_id", + "run_id", + "logical_time", + "symbol", + "side", + name="uq_paper_order_logical_key", + ), + CheckConstraint("qty > 0", name="chk_paper_order_qty_positive"), + CheckConstraint("price >= 0", name="chk_paper_order_price_non_negative"), + Index("idx_paper_order_ts", "timestamp"), + Index("idx_paper_order_user_run_ts", "user_id", "run_id", "timestamp"), + ) + + +class PaperTrade(Base): + __tablename__ = "paper_trade" + + id = Column(String, primary_key=True) + order_id = Column(String) + user_id = Column(String, ForeignKey("app_user.id", ondelete="CASCADE"), nullable=False) + run_id = Column(String, nullable=False) + symbol = Column(String, nullable=False) + side = Column(String, nullable=False) + qty = Column(Numeric, nullable=False) + price = Column(Numeric, nullable=False) + timestamp = Column("timestamp", DateTime(timezone=True), nullable=False) + logical_time = Column(DateTime(timezone=True), nullable=False) + + __table_args__ = ( + ForeignKeyConstraint( + ["user_id", "run_id"], + ["strategy_run.user_id", "strategy_run.run_id"], + ondelete="CASCADE", + ), + ForeignKeyConstraint( + ["user_id", "run_id", "order_id"], + ["paper_order.user_id", "paper_order.run_id", "paper_order.id"], + ondelete="CASCADE", + ), + UniqueConstraint("user_id", "run_id", "id", name="uq_paper_trade_scope_id"), + UniqueConstraint( + "user_id", + "run_id", + "logical_time", + "symbol", + "side", + name="uq_paper_trade_logical_key", + ), + CheckConstraint("qty > 0", name="chk_paper_trade_qty_positive"), + CheckConstraint("price >= 0", name="chk_paper_trade_price_non_negative"), + Index("idx_paper_trade_ts", "timestamp"), + Index("idx_paper_trade_user_run_ts", "user_id", "run_id", "timestamp"), + ) + + +class PaperEquityCurve(Base): + __tablename__ = "paper_equity_curve" + + user_id = Column(String, primary_key=True) + run_id = Column(String, primary_key=True) + timestamp = Column("timestamp", DateTime(timezone=True), nullable=False) + logical_time = Column(DateTime(timezone=True), primary_key=True) + equity = Column(Numeric, nullable=False) + pnl = Column(Numeric) + + __table_args__ = ( + ForeignKeyConstraint( + ["user_id", "run_id"], + ["strategy_run.user_id", "strategy_run.run_id"], + ondelete="CASCADE", + ), + Index("idx_paper_equity_curve_ts", "timestamp"), + Index("idx_paper_equity_curve_user_run_ts", "user_id", "run_id", "timestamp"), + ) + + +class MTMLedger(Base): + __tablename__ = "mtm_ledger" + + user_id = Column(String, primary_key=True) + run_id = Column(String, primary_key=True) + timestamp = Column("timestamp", DateTime(timezone=True), nullable=False) + logical_time = Column(DateTime(timezone=True), primary_key=True) + nifty_units = Column(Numeric) + gold_units = Column(Numeric) + nifty_price = Column(Numeric) + gold_price = Column(Numeric) + nifty_value = Column(Numeric) + gold_value = Column(Numeric) + portfolio_value = Column(Numeric) + total_invested = Column(Numeric) + pnl = Column(Numeric) + + __table_args__ = ( + ForeignKeyConstraint( + ["user_id", "run_id"], + ["strategy_run.user_id", "strategy_run.run_id"], + ondelete="CASCADE", + ), + Index("idx_mtm_ledger_ts", "timestamp"), + Index("idx_mtm_ledger_user_run_ts", "user_id", "run_id", "timestamp"), + ) + + +class EventLedger(Base): + __tablename__ = "event_ledger" + + id = Column(BigInteger, primary_key=True, autoincrement=True) + user_id = Column(String, ForeignKey("app_user.id", ondelete="CASCADE"), nullable=False) + run_id = Column(String, nullable=False) + timestamp = Column("timestamp", DateTime(timezone=True), nullable=False) + logical_time = Column(DateTime(timezone=True), nullable=False) + event = Column(String, nullable=False) + nifty_units = Column(Numeric) + gold_units = Column(Numeric) + nifty_price = Column(Numeric) + gold_price = Column(Numeric) + amount = Column(Numeric) + + __table_args__ = ( + ForeignKeyConstraint( + ["user_id", "run_id"], + ["strategy_run.user_id", "strategy_run.run_id"], + ondelete="CASCADE", + ), + UniqueConstraint("user_id", "run_id", "event", "logical_time", name="uq_event_ledger_event_time"), + Index("idx_event_ledger_user_run_logical", "user_id", "run_id", "logical_time"), + Index("idx_event_ledger_ts", "timestamp"), + Index("idx_event_ledger_user_run_ts", "user_id", "run_id", "timestamp"), + ) + + +class MarketClose(Base): + __tablename__ = "market_close" + + symbol = Column(String, primary_key=True) + date = Column(Date, primary_key=True) + close = Column(Numeric, nullable=False) + + __table_args__ = ( + Index("idx_market_close_symbol", "symbol"), + Index("idx_market_close_date", "date"), + ) + + +class AdminAuditLog(Base): + __tablename__ = "admin_audit_log" + + id = Column(BigInteger, primary_key=True, autoincrement=True) + ts = Column(DateTime(timezone=True), nullable=False, server_default=func.now()) + actor_user_hash = Column(Text, nullable=False) + target_user_hash = Column(Text, nullable=False) + target_username_hash = Column(Text) + action = Column(Text, nullable=False) + meta = Column(JSONB) + + +class AdminRoleAudit(Base): + __tablename__ = "admin_role_audit" + + id = Column(BigInteger, primary_key=True, autoincrement=True) + actor_user_id = Column(String, nullable=False) + target_user_id = Column(String, nullable=False) + old_role = Column(String, nullable=False) + new_role = Column(String, nullable=False) + changed_at = Column(DateTime(timezone=True), nullable=False, server_default=func.now()) diff --git a/backend/app/main.py b/backend/app/main.py new file mode 100644 index 0000000..1b616f3 --- /dev/null +++ b/backend/app/main.py @@ -0,0 +1,71 @@ +import os + +from fastapi import FastAPI +from fastapi.middleware.cors import CORSMiddleware +from app.routers.auth import router as auth_router +from app.routers.broker import router as broker_router +from app.routers.health import router as health_router +from app.routers.password_reset import router as password_reset_router +from app.routers.support_ticket import router as support_ticket_router +from app.routers.system import router as system_router +from app.routers.strategy import router as strategy_router +from app.routers.zerodha import router as zerodha_router, public_router as zerodha_public_router +from app.routers.paper import router as paper_router +from market import router as market_router +from paper_mtm import router as paper_mtm_router +from app.services.strategy_service import init_log_state, resume_running_runs +from app.admin_router import router as admin_router +from app.admin_role_service import bootstrap_super_admin + +app = FastAPI( + title="QuantFortune Backend", + version="1.0" +) + +cors_origins = [ + origin.strip() + for origin in os.getenv("CORS_ORIGINS", "").split(",") + if origin.strip() +] +if not cors_origins: + cors_origins = [ + "http://localhost:3000", + "http://127.0.0.1:3000", + ] + +cors_origin_regex = os.getenv("CORS_ORIGIN_REGEX", "").strip() +if not cors_origin_regex: + cors_origin_regex = ( + r"https://.*\\.ngrok-free\\.dev" + r"|https://.*\\.ngrok-free\\.app" + r"|https://.*\\.ngrok\\.io" + ) + +app.add_middleware( + CORSMiddleware, + allow_origins=cors_origins, + allow_origin_regex=cors_origin_regex or None, + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + +app.include_router(strategy_router) +app.include_router(auth_router) +app.include_router(broker_router) +app.include_router(zerodha_router) +app.include_router(zerodha_public_router) +app.include_router(paper_router) +app.include_router(market_router) +app.include_router(paper_mtm_router) +app.include_router(health_router) +app.include_router(system_router) +app.include_router(admin_router) +app.include_router(support_ticket_router) +app.include_router(password_reset_router) + +@app.on_event("startup") +def init_app_state(): + init_log_state() + bootstrap_super_admin() + resume_running_runs() diff --git a/backend/app/models.py b/backend/app/models.py new file mode 100644 index 0000000..20308dc --- /dev/null +++ b/backend/app/models.py @@ -0,0 +1,37 @@ +from pydantic import BaseModel, validator +from typing import Literal, Optional + + +class SipFrequency(BaseModel): + value: int + unit: Literal["days", "minutes"] + +class StrategyStartRequest(BaseModel): + strategy_name: str + initial_cash: Optional[float] = None + sip_amount: float + sip_frequency: SipFrequency + mode: Literal["PAPER"] + + @validator("initial_cash") + def validate_cash(cls, v): + if v is None: + return v + if v < 10000: + raise ValueError("Initial cash must be at least 10,000") + return v + +class AuthPayload(BaseModel): + email: Optional[str] = None + username: Optional[str] = None + password: Optional[str] = None + + +class PasswordResetRequest(BaseModel): + email: str + + +class PasswordResetConfirm(BaseModel): + email: str + otp: str + new_password: str diff --git a/backend/app/routers/__init__.py b/backend/app/routers/__init__.py new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/backend/app/routers/__init__.py @@ -0,0 +1 @@ + diff --git a/backend/app/routers/auth.py b/backend/app/routers/auth.py new file mode 100644 index 0000000..1ac7f16 --- /dev/null +++ b/backend/app/routers/auth.py @@ -0,0 +1,116 @@ +import os + +from fastapi import APIRouter, HTTPException, Request, Response +from app.models import AuthPayload +from app.services.auth_service import ( + SESSION_TTL_SECONDS, + create_session, + create_user, + delete_session, + get_user_for_session, + get_last_session_meta, + verify_user, +) +from app.services.email_service import send_email + +router = APIRouter(prefix="/api") +SESSION_COOKIE_NAME = "session_id" +COOKIE_SECURE = os.getenv("COOKIE_SECURE", "0") == "1" +COOKIE_SAMESITE = (os.getenv("COOKIE_SAMESITE") or "lax").lower() + + +def _set_session_cookie(response: Response, session_id: str): + same_site = COOKIE_SAMESITE if COOKIE_SAMESITE in {"lax", "strict", "none"} else "lax" + response.set_cookie( + SESSION_COOKIE_NAME, + session_id, + httponly=True, + samesite=same_site, + max_age=SESSION_TTL_SECONDS, + secure=COOKIE_SECURE, + path="/", + ) + + +def _get_identifier(payload: AuthPayload) -> str: + identifier = payload.username or payload.email or "" + return identifier.strip() + + +@router.post("/signup") +def signup(payload: AuthPayload, response: Response): + identifier = _get_identifier(payload) + if not identifier or not payload.password: + raise HTTPException(status_code=400, detail="Email and password are required") + + user = create_user(identifier, payload.password) + if not user: + raise HTTPException(status_code=409, detail="User already exists") + + session_id = create_session(user["id"]) + _set_session_cookie(response, session_id) + try: + body = ( + "Welcome to Quantfortune!\n\n" + "Your account has been created successfully.\n\n" + "You can now log in and start using the platform.\n\n" + "Quantfortune Support" + ) + send_email(user["username"], "Welcome to Quantfortune", body) + except Exception: + pass + return {"id": user["id"], "username": user["username"], "role": user.get("role")} + + +@router.post("/login") +def login(payload: AuthPayload, response: Response, request: Request): + identifier = _get_identifier(payload) + if not identifier or not payload.password: + raise HTTPException(status_code=400, detail="Email and password are required") + + user = verify_user(identifier, payload.password) + if not user: + raise HTTPException(status_code=401, detail="Invalid email or password") + + client_ip = request.client.host if request.client else None + user_agent = request.headers.get("user-agent") + last_meta = get_last_session_meta(user["id"]) + if last_meta.get("ip") and ( + last_meta.get("ip") != client_ip or last_meta.get("user_agent") != user_agent + ): + try: + body = ( + "New login detected on your Quantfortune account.\n\n" + f"IP: {client_ip or 'unknown'}\n" + f"Device: {user_agent or 'unknown'}\n\n" + "If this wasn't you, please reset your password immediately." + ) + send_email(user["username"], "New login detected", body) + except Exception: + pass + + session_id = create_session(user["id"], ip=client_ip, user_agent=user_agent) + _set_session_cookie(response, session_id) + return {"id": user["id"], "username": user["username"], "role": user.get("role")} + + +@router.post("/logout") +def logout(request: Request, response: Response): + session_id = request.cookies.get(SESSION_COOKIE_NAME) + if session_id: + delete_session(session_id) + response.delete_cookie(SESSION_COOKIE_NAME, path="/") + return {"ok": True} + + +@router.get("/me") +def me(request: Request): + session_id = request.cookies.get(SESSION_COOKIE_NAME) + if not session_id: + raise HTTPException(status_code=401, detail="Not authenticated") + + user = get_user_for_session(session_id) + if not user: + raise HTTPException(status_code=401, detail="Not authenticated") + + return {"id": user["id"], "username": user["username"], "role": user.get("role")} diff --git a/backend/app/routers/broker.py b/backend/app/routers/broker.py new file mode 100644 index 0000000..8e5f7d6 --- /dev/null +++ b/backend/app/routers/broker.py @@ -0,0 +1,205 @@ +import os + +from fastapi import APIRouter, HTTPException, Request +from fastapi.responses import RedirectResponse + +from app.broker_store import ( + clear_user_broker, + get_broker_credentials, + get_pending_broker, + get_user_broker, + set_broker_auth_state, + set_connected_broker, + set_pending_broker, +) +from app.services.auth_service import get_user_for_session +from app.services.zerodha_service import build_login_url, exchange_request_token +from app.services.email_service import send_email +from app.services.zerodha_storage import set_session + +router = APIRouter(prefix="/api/broker") + + +def _require_user(request: Request): + session_id = request.cookies.get("session_id") + if not session_id: + raise HTTPException(status_code=401, detail="Not authenticated") + user = get_user_for_session(session_id) + if not user: + raise HTTPException(status_code=401, detail="Not authenticated") + return user + + +@router.post("/connect") +async def connect_broker(payload: dict, request: Request): + user = _require_user(request) + broker = (payload.get("broker") or "").strip() + token = (payload.get("token") or "").strip() + user_name = (payload.get("userName") or "").strip() + broker_user_id = (payload.get("brokerUserId") or "").strip() + if not broker or not token: + raise HTTPException(status_code=400, detail="Broker and token are required") + + set_connected_broker( + user["id"], + broker, + token, + user_name=user_name or None, + broker_user_id=broker_user_id or None, + ) + try: + body = ( + "Your broker has been connected to Quantfortune.\n\n" + f"Broker: {broker}\n" + f"Broker User ID: {broker_user_id or 'N/A'}\n" + ) + send_email(user["username"], "Broker connected", body) + except Exception: + pass + return {"connected": True} + + +@router.get("/status") +async def broker_status(request: Request): + user = _require_user(request) + entry = get_user_broker(user["id"]) + if not entry or not entry.get("connected"): + return {"connected": False} + return { + "connected": True, + "broker": entry.get("broker"), + "connected_at": entry.get("connected_at"), + "userName": entry.get("user_name"), + "brokerUserId": entry.get("broker_user_id"), + "authState": entry.get("auth_state"), + } + + +@router.post("/disconnect") +async def disconnect_broker(request: Request): + user = _require_user(request) + clear_user_broker(user["id"]) + set_broker_auth_state(user["id"], "DISCONNECTED") + try: + body = "Your broker connection has been disconnected from Quantfortune." + send_email(user["username"], "Broker disconnected", body) + except Exception: + pass + return {"connected": False} + + +@router.post("/zerodha/login") +async def zerodha_login(payload: dict, request: Request): + user = _require_user(request) + api_key = (payload.get("apiKey") or "").strip() + api_secret = (payload.get("apiSecret") or "").strip() + redirect_url = (payload.get("redirectUrl") or "").strip() + if not api_key or not api_secret: + raise HTTPException(status_code=400, detail="API key and secret are required") + + set_pending_broker(user["id"], "ZERODHA", api_key, api_secret) + return {"loginUrl": build_login_url(api_key, redirect_url=redirect_url or None)} + + +@router.get("/zerodha/callback") +async def zerodha_callback(request: Request, request_token: str = ""): + user = _require_user(request) + token = request_token.strip() + if not token: + raise HTTPException(status_code=400, detail="Missing request_token") + + pending = get_pending_broker(user["id"]) or {} + api_key = (pending.get("api_key") or "").strip() + api_secret = (pending.get("api_secret") or "").strip() + if not api_key or not api_secret: + raise HTTPException(status_code=400, detail="Zerodha login not initialized") + + try: + session_data = exchange_request_token(api_key, api_secret, token) + except Exception as exc: + raise HTTPException(status_code=400, detail=str(exc)) from exc + + access_token = session_data.get("access_token") + if not access_token: + raise HTTPException(status_code=400, detail="Missing access token from Zerodha") + + saved = set_session( + user["id"], + { + "api_key": api_key, + "access_token": access_token, + "request_token": session_data.get("request_token", token), + "user_name": session_data.get("user_name"), + "broker_user_id": session_data.get("user_id"), + }, + ) + set_connected_broker( + user["id"], + "ZERODHA", + access_token, + api_key=api_key, + api_secret=api_secret, + user_name=session_data.get("user_name"), + broker_user_id=session_data.get("user_id"), + auth_state="VALID", + ) + return { + "connected": True, + "userName": saved.get("user_name"), + "brokerUserId": saved.get("broker_user_id"), + } + + +@router.get("/login") +async def broker_login(request: Request): + user = _require_user(request) + creds = get_broker_credentials(user["id"]) + if not creds: + raise HTTPException(status_code=400, detail="Broker credentials not configured") + redirect_url = (os.getenv("ZERODHA_REDIRECT_URL") or "").strip() + if not redirect_url: + base = str(request.base_url).rstrip("/") + redirect_url = f"{base}/api/broker/callback" + login_url = build_login_url(creds["api_key"], redirect_url=redirect_url) + return RedirectResponse(login_url) + + +@router.get("/callback") +async def broker_callback(request: Request, request_token: str = ""): + user = _require_user(request) + token = request_token.strip() + if not token: + raise HTTPException(status_code=400, detail="Missing request_token") + creds = get_broker_credentials(user["id"]) + if not creds: + raise HTTPException(status_code=400, detail="Broker credentials not configured") + try: + session_data = exchange_request_token(creds["api_key"], creds["api_secret"], token) + except Exception as exc: + raise HTTPException(status_code=400, detail=str(exc)) from exc + access_token = session_data.get("access_token") + if not access_token: + raise HTTPException(status_code=400, detail="Missing access token from Zerodha") + + set_session( + user["id"], + { + "api_key": creds["api_key"], + "access_token": access_token, + "request_token": session_data.get("request_token", token), + "user_name": session_data.get("user_name"), + "broker_user_id": session_data.get("user_id"), + }, + ) + set_connected_broker( + user["id"], + "ZERODHA", + access_token, + api_key=creds["api_key"], + api_secret=creds["api_secret"], + user_name=session_data.get("user_name"), + broker_user_id=session_data.get("user_id"), + auth_state="VALID", + ) + target_url = os.getenv("BROKER_DASHBOARD_URL") or "/dashboard?armed=false" + return RedirectResponse(target_url) diff --git a/backend/app/routers/health.py b/backend/app/routers/health.py new file mode 100644 index 0000000..9ec315b --- /dev/null +++ b/backend/app/routers/health.py @@ -0,0 +1,12 @@ +from fastapi import APIRouter, HTTPException + +from app.services.db import health_check + +router = APIRouter() + + +@router.get("/health") +def health(): + if not health_check(): + raise HTTPException(status_code=503, detail="db_unavailable") + return {"status": "ok", "db": "ok"} diff --git a/backend/app/routers/paper.py b/backend/app/routers/paper.py new file mode 100644 index 0000000..c7b564c --- /dev/null +++ b/backend/app/routers/paper.py @@ -0,0 +1,75 @@ +from fastapi import APIRouter, HTTPException, Request + +from app.services.paper_broker_service import ( + add_cash, + get_equity_curve, + get_funds, + get_orders, + get_positions, + get_trades, + reset_paper_state, +) +from app.services.tenant import get_request_user_id + +router = APIRouter(prefix="/api/paper") + + +@router.get("/funds") +def funds(request: Request): + user_id = get_request_user_id(request) + return {"funds": get_funds(user_id)} + + +@router.get("/positions") +def positions(request: Request): + user_id = get_request_user_id(request) + return {"positions": get_positions(user_id)} + + +@router.get("/orders") +def orders(request: Request): + user_id = get_request_user_id(request) + return {"orders": get_orders(user_id)} + + +@router.get("/trades") +def trades(request: Request): + user_id = get_request_user_id(request) + return {"trades": get_trades(user_id)} + + +@router.get("/equity-curve") +def equity_curve(request: Request): + user_id = get_request_user_id(request) + return get_equity_curve(user_id) + + +@router.post("/add-cash") +def add_cash_endpoint(request: Request, payload: dict): + try: + amount = float(payload.get("amount", 0)) + except (TypeError, ValueError): + raise HTTPException(status_code=400, detail="Invalid amount") + if amount <= 0: + raise HTTPException(status_code=400, detail="Amount must be positive") + try: + user_id = get_request_user_id(request) + add_cash(user_id, amount) + except ValueError as exc: + raise HTTPException(status_code=400, detail=str(exc)) from exc + return {"funds": get_funds(user_id)} + + +@router.post("/reset") +def reset_paper(request: Request): + try: + from app.services.strategy_service import stop_strategy + + user_id = get_request_user_id(request) + stop_strategy(user_id) + except Exception: + pass + user_id = get_request_user_id(request) + reset_paper_state(user_id) + + return {"ok": True, "message": "Paper reset completed"} diff --git a/backend/app/routers/password_reset.py b/backend/app/routers/password_reset.py new file mode 100644 index 0000000..e5e0751 --- /dev/null +++ b/backend/app/routers/password_reset.py @@ -0,0 +1,59 @@ +from fastapi import APIRouter, HTTPException + +from app.models import PasswordResetConfirm, PasswordResetRequest +from app.services.auth_service import ( + consume_password_reset_otp, + create_password_reset_otp, + get_user_by_username, + update_user_password, +) +from app.services.email_service import send_email + +router = APIRouter(prefix="/api/password-reset") + + +@router.post("/request") +def request_reset(payload: PasswordResetRequest): + email = payload.email.strip() + if not email: + raise HTTPException(status_code=400, detail="Email is required") + + user = get_user_by_username(email) + if not user: + return {"ok": True} + + otp = create_password_reset_otp(email) + body = ( + "Hi,\n\n" + "We received a request to reset your Quantfortune password.\n\n" + f"Your OTP code is: {otp}\n" + "This code is valid for 10 minutes.\n\n" + "If you did not request this, you can ignore this email.\n\n" + "Quantfortune Support" + ) + try: + ok = send_email(email, "Quantfortune Password Reset OTP", body) + except Exception as exc: + raise HTTPException(status_code=500, detail=f"Email send failed: {exc}") from exc + if not ok: + raise HTTPException(status_code=500, detail="Email send failed: SMTP not configured") + return {"ok": True} + + +@router.post("/confirm") +def confirm_reset(payload: PasswordResetConfirm): + email = payload.email.strip() + otp = payload.otp.strip() + new_password = payload.new_password + if not email or not otp or not new_password: + raise HTTPException(status_code=400, detail="Email, OTP, and new password are required") + + user = get_user_by_username(email) + if not user: + raise HTTPException(status_code=400, detail="Invalid OTP or email") + + if not consume_password_reset_otp(email, otp): + raise HTTPException(status_code=400, detail="Invalid or expired OTP") + + update_user_password(user["id"], new_password) + return {"ok": True} diff --git a/backend/app/routers/strategy.py b/backend/app/routers/strategy.py new file mode 100644 index 0000000..5df9b50 --- /dev/null +++ b/backend/app/routers/strategy.py @@ -0,0 +1,47 @@ +from fastapi import APIRouter, Query, Request +from app.models import StrategyStartRequest +from app.services.strategy_service import ( + start_strategy, + stop_strategy, + get_strategy_status, + get_engine_status, + get_market_status, + get_strategy_logs as fetch_strategy_logs, +) +from app.services.tenant import get_request_user_id + +router = APIRouter(prefix="/api") + +@router.post("/strategy/start") +def start(req: StrategyStartRequest, request: Request): + user_id = get_request_user_id(request) + return start_strategy(req, user_id) + +@router.post("/strategy/stop") +def stop(request: Request): + user_id = get_request_user_id(request) + return stop_strategy(user_id) + +@router.get("/strategy/status") +def status(request: Request): + user_id = get_request_user_id(request) + return get_strategy_status(user_id) + +@router.get("/engine/status") +def engine_status(request: Request): + user_id = get_request_user_id(request) + return get_engine_status(user_id) + +@router.get("/market/status") +def market_status(): + return get_market_status() + +@router.get("/logs") +def get_logs(request: Request, since_seq: int = Query(0)): + user_id = get_request_user_id(request) + return fetch_strategy_logs(user_id, since_seq) + +@router.get("/strategy/logs") +def get_strategy_logs_endpoint(request: Request, since_seq: int = Query(0)): + user_id = get_request_user_id(request) + return fetch_strategy_logs(user_id, since_seq) diff --git a/backend/app/routers/support_ticket.py b/backend/app/routers/support_ticket.py new file mode 100644 index 0000000..06db908 --- /dev/null +++ b/backend/app/routers/support_ticket.py @@ -0,0 +1,39 @@ +from fastapi import APIRouter, HTTPException +from pydantic import BaseModel + +from app.services.support_ticket import create_ticket, get_ticket_status + + +router = APIRouter(prefix="/api/support") + + +class TicketCreate(BaseModel): + name: str + email: str + subject: str + message: str + + +class TicketStatusRequest(BaseModel): + email: str + + +@router.post("/ticket") +def submit_ticket(payload: TicketCreate): + if not payload.subject.strip() or not payload.message.strip(): + raise HTTPException(status_code=400, detail="Subject and message are required") + ticket = create_ticket( + name=payload.name.strip(), + email=payload.email.strip(), + subject=payload.subject.strip(), + message=payload.message.strip(), + ) + return ticket + + +@router.post("/ticket/status/{ticket_id}") +def ticket_status(ticket_id: str, payload: TicketStatusRequest): + status = get_ticket_status(ticket_id.strip(), payload.email.strip()) + if not status: + raise HTTPException(status_code=404, detail="Ticket not found") + return status diff --git a/backend/app/routers/system.py b/backend/app/routers/system.py new file mode 100644 index 0000000..2609f3d --- /dev/null +++ b/backend/app/routers/system.py @@ -0,0 +1,41 @@ +from fastapi import APIRouter, HTTPException, Request + +from app.services.auth_service import get_user_for_session +from app.services.system_service import arm_system, system_status +from app.services.zerodha_service import KiteApiError + +router = APIRouter(prefix="/api/system") + + +def _require_user(request: Request): + session_id = request.cookies.get("session_id") + if not session_id: + raise HTTPException(status_code=401, detail="Not authenticated") + user = get_user_for_session(session_id) + if not user: + raise HTTPException(status_code=401, detail="Not authenticated") + return user + + +@router.post("/arm") +def arm(request: Request): + user = _require_user(request) + try: + result = arm_system(user["id"], client_ip=request.client.host if request.client else None) + except KiteApiError as exc: + raise HTTPException(status_code=502, detail=str(exc)) from exc + + if not result.get("ok"): + if result.get("code") == "BROKER_AUTH_REQUIRED": + raise HTTPException( + status_code=401, + detail={"redirect_url": result.get("redirect_url")}, + ) + raise HTTPException(status_code=400, detail="Unable to arm system") + return result + + +@router.get("/status") +def status(request: Request): + user = _require_user(request) + return system_status(user["id"]) diff --git a/backend/app/routers/zerodha.py b/backend/app/routers/zerodha.py new file mode 100644 index 0000000..9aec21f --- /dev/null +++ b/backend/app/routers/zerodha.py @@ -0,0 +1,234 @@ +from datetime import datetime, timedelta + +from fastapi import APIRouter, HTTPException, Query, Request +from fastapi.responses import HTMLResponse + +from app.broker_store import clear_user_broker +from app.services.auth_service import get_user_for_session +from app.services.zerodha_service import ( + KiteApiError, + KiteTokenError, + build_login_url, + exchange_request_token, + fetch_funds, + fetch_holdings, +) +from app.services.zerodha_storage import ( + clear_session, + consume_request_token, + get_session, + set_session, + store_request_token, +) + +router = APIRouter(prefix="/api/zerodha") +public_router = APIRouter() + + +def _require_user(request: Request): + session_id = request.cookies.get("session_id") + if not session_id: + raise HTTPException(status_code=401, detail="Not authenticated") + user = get_user_for_session(session_id) + if not user: + raise HTTPException(status_code=401, detail="Not authenticated") + return user + + +def _capture_request_token(request: Request, request_token: str): + user = _require_user(request) + token = request_token.strip() + if not token: + raise HTTPException(status_code=400, detail="Missing request_token") + store_request_token(user["id"], token) + + +def _clear_broker_session(user_id: str): + clear_user_broker(user_id) + clear_session(user_id) + + +def _raise_kite_error(user_id: str, exc: KiteApiError): + if isinstance(exc, KiteTokenError): + _clear_broker_session(user_id) + raise HTTPException( + status_code=401, detail="Zerodha session expired. Please reconnect." + ) from exc + raise HTTPException(status_code=502, detail=str(exc)) from exc + + +@router.post("/login-url") +async def login_url(payload: dict, request: Request): + _require_user(request) + api_key = (payload.get("apiKey") or "").strip() + if not api_key: + raise HTTPException(status_code=400, detail="API key is required") + return {"loginUrl": build_login_url(api_key)} + + +@router.post("/session") +async def create_session(payload: dict, request: Request): + user = _require_user(request) + api_key = (payload.get("apiKey") or "").strip() + api_secret = (payload.get("apiSecret") or "").strip() + request_token = (payload.get("requestToken") or "").strip() + if not api_key or not api_secret or not request_token: + raise HTTPException( + status_code=400, detail="API key, secret, and request token are required" + ) + + try: + session_data = exchange_request_token(api_key, api_secret, request_token) + except Exception as exc: + raise HTTPException(status_code=400, detail=str(exc)) from exc + + saved = set_session( + user["id"], + { + "api_key": api_key, + "access_token": session_data.get("access_token"), + "request_token": session_data.get("request_token", request_token), + "user_name": session_data.get("user_name"), + "broker_user_id": session_data.get("user_id"), + }, + ) + + return { + "connected": True, + "userName": saved.get("user_name"), + "brokerUserId": saved.get("broker_user_id"), + "accessToken": saved.get("access_token"), + } + + +@router.get("/status") +async def status(request: Request): + user = _require_user(request) + session = get_session(user["id"]) + if not session: + return {"connected": False} + + return { + "connected": True, + "broker": "zerodha", + "userName": session.get("user_name"), + "linkedAt": session.get("linked_at"), + } + + +@router.get("/request-token") +async def request_token(request: Request): + user = _require_user(request) + token = consume_request_token(user["id"]) + if not token: + raise HTTPException(status_code=404, detail="No request token available.") + return {"requestToken": token} + + +@router.get("/holdings") +async def holdings(request: Request): + user = _require_user(request) + session = get_session(user["id"]) + if not session: + raise HTTPException(status_code=400, detail="Zerodha is not connected") + try: + data = fetch_holdings(session["api_key"], session["access_token"]) + except KiteApiError as exc: + _raise_kite_error(user["id"], exc) + return {"holdings": data} + + +@router.get("/funds") +async def funds(request: Request): + user = _require_user(request) + session = get_session(user["id"]) + if not session: + raise HTTPException(status_code=400, detail="Zerodha is not connected") + try: + data = fetch_funds(session["api_key"], session["access_token"]) + except KiteApiError as exc: + _raise_kite_error(user["id"], exc) + equity = data.get("equity", {}) if isinstance(data, dict) else {} + return {"funds": {**equity, "raw": data}} + + +@router.get("/equity-curve") +async def equity_curve(request: Request, from_: str = Query("", alias="from")): + user = _require_user(request) + session = get_session(user["id"]) + if not session: + raise HTTPException(status_code=400, detail="Zerodha is not connected") + + try: + holdings = fetch_holdings(session["api_key"], session["access_token"]) + funds_data = fetch_funds(session["api_key"], session["access_token"]) + except KiteApiError as exc: + _raise_kite_error(user["id"], exc) + + equity = funds_data.get("equity", {}) if isinstance(funds_data, dict) else {} + total_holdings_value = 0 + for item in holdings: + qty = float(item.get("quantity") or item.get("qty") or 0) + last = float(item.get("last_price") or item.get("average_price") or 0) + total_holdings_value += qty * last + + total_funds = float(equity.get("cash") or 0) + current_value = max(0, total_holdings_value + total_funds) + + ms_in_day = 86400000 + now = datetime.utcnow() + default_start = now - timedelta(days=90) + if from_: + try: + start_date = datetime.fromisoformat(from_) + except ValueError: + start_date = default_start + else: + start_date = default_start + if start_date > now: + start_date = now + + span_days = max( + 2, + int(((now - start_date).total_seconds() * 1000) // ms_in_day), + ) + start_value = current_value * 0.85 if current_value > 0 else 10000 + points = [] + for i in range(span_days): + day = start_date + timedelta(days=i) + progress = i / (span_days - 1) + trend = start_value + (current_value - start_value) * progress + value = max(0, round(trend)) + points.append({"date": day.isoformat(), "value": value}) + + return { + "startDate": start_date.isoformat(), + "endDate": now.isoformat(), + "accountOpenDate": session.get("linked_at"), + "points": points, + } + + +@router.get("/callback") +async def callback(request: Request, request_token: str = ""): + _capture_request_token(request, request_token) + return { + "status": "ok", + "message": "Request token captured. You can close this tab.", + } + + +@router.get("/login") +async def login_redirect(request: Request, request_token: str = ""): + return await callback(request, request_token=request_token) + + +@public_router.get("/login", response_class=HTMLResponse) +async def login_capture(request: Request, request_token: str = ""): + _capture_request_token(request, request_token) + return ( + "" + "

Request token captured

" + "

You can close this tab and return to QuantFortune.

" + "" + ) diff --git a/backend/app/services/__init__.py b/backend/app/services/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/backend/app/services/auth_service.py b/backend/app/services/auth_service.py new file mode 100644 index 0000000..ca6a6b5 --- /dev/null +++ b/backend/app/services/auth_service.py @@ -0,0 +1,280 @@ +import hashlib +import os +import secrets +from datetime import datetime, timedelta, timezone +from uuid import uuid4 + +from app.services.db import db_connection + +SESSION_TTL_SECONDS = int(os.getenv("SESSION_TTL_SECONDS", str(60 * 60 * 24 * 7))) +SESSION_REFRESH_WINDOW_SECONDS = int( + os.getenv("SESSION_REFRESH_WINDOW_SECONDS", str(60 * 60)) +) +RESET_OTP_TTL_MINUTES = int(os.getenv("RESET_OTP_TTL_MINUTES", "10")) +RESET_OTP_SECRET = os.getenv("RESET_OTP_SECRET", "otp_secret") + + +def _now_utc() -> datetime: + return datetime.now(timezone.utc) + + +def _new_expiry(now: datetime) -> datetime: + return now + timedelta(seconds=SESSION_TTL_SECONDS) + + +def _hash_password(password: str) -> str: + return hashlib.sha256(password.encode("utf-8")).hexdigest() + + +def _hash_otp(email: str, otp: str) -> str: + payload = f"{email}:{otp}:{RESET_OTP_SECRET}" + return hashlib.sha256(payload.encode("utf-8")).hexdigest() + + +def _row_to_user(row): + if not row: + return None + return { + "id": row[0], + "username": row[1], + "password": row[2], + "role": row[3] if len(row) > 3 else None, + } + + +def get_user_by_username(username: str): + with db_connection() as conn: + with conn.cursor() as cur: + cur.execute( + "SELECT id, username, password_hash, role FROM app_user WHERE username = %s", + (username,), + ) + return _row_to_user(cur.fetchone()) + + +def get_user_by_id(user_id: str): + with db_connection() as conn: + with conn.cursor() as cur: + cur.execute( + "SELECT id, username, password_hash, role FROM app_user WHERE id = %s", + (user_id,), + ) + return _row_to_user(cur.fetchone()) + + +def create_user(username: str, password: str): + user_id = str(uuid4()) + password_hash = _hash_password(password) + with db_connection() as conn: + with conn: + with conn.cursor() as cur: + cur.execute( + """ + INSERT INTO app_user (id, username, password_hash, role) + VALUES (%s, %s, %s, 'USER') + ON CONFLICT (username) DO NOTHING + RETURNING id, username, password_hash, role + """, + (user_id, username, password_hash), + ) + return _row_to_user(cur.fetchone()) + + +def authenticate_user(username: str, password: str): + user = get_user_by_username(username) + if not user: + return None + if user.get("password") != _hash_password(password): + return None + return user + + +def verify_user(username: str, password: str): + return authenticate_user(username, password) + + +def create_session(user_id: str, ip: str | None = None, user_agent: str | None = None) -> str: + session_id = str(uuid4()) + now = _now_utc() + with db_connection() as conn: + with conn: + with conn.cursor() as cur: + cur.execute( + """ + INSERT INTO app_session (id, user_id, created_at, last_seen_at, expires_at, ip, user_agent) + VALUES (%s, %s, %s, %s, %s, %s, %s) + """, + (session_id, user_id, now, now, _new_expiry(now), ip, user_agent), + ) + return session_id + + +def get_last_session_meta(user_id: str): + with db_connection() as conn: + with conn.cursor() as cur: + cur.execute( + """ + SELECT ip, user_agent + FROM app_session + WHERE user_id = %s + ORDER BY created_at DESC + LIMIT 1 + """, + (user_id,), + ) + row = cur.fetchone() + if not row: + return {"ip": None, "user_agent": None} + return {"ip": row[0], "user_agent": row[1]} + + +def update_user_password(user_id: str, new_password: str): + password_hash = _hash_password(new_password) + with db_connection() as conn: + with conn: + with conn.cursor() as cur: + cur.execute( + "UPDATE app_user SET password_hash = %s WHERE id = %s", + (password_hash, user_id), + ) + + +def create_password_reset_otp(email: str): + otp = f"{secrets.randbelow(10000):04d}" + now = _now_utc() + expires_at = now + timedelta(minutes=RESET_OTP_TTL_MINUTES) + otp_hash = _hash_otp(email, otp) + with db_connection() as conn: + with conn: + with conn.cursor() as cur: + cur.execute( + """ + INSERT INTO password_reset_otp (id, email, otp_hash, created_at, expires_at, used_at) + VALUES (%s, %s, %s, %s, %s, NULL) + """, + (str(uuid4()), email, otp_hash, now, expires_at), + ) + return otp + + +def consume_password_reset_otp(email: str, otp: str) -> bool: + now = _now_utc() + otp_hash = _hash_otp(email, otp) + with db_connection() as conn: + with conn: + with conn.cursor() as cur: + cur.execute( + """ + SELECT id + FROM password_reset_otp + WHERE email = %s + AND otp_hash = %s + AND used_at IS NULL + AND expires_at > %s + ORDER BY created_at DESC + LIMIT 1 + """, + (email, otp_hash, now), + ) + row = cur.fetchone() + if not row: + return False + cur.execute( + "UPDATE password_reset_otp SET used_at = %s WHERE id = %s", + (now, row[0]), + ) + return True + + +def get_session(session_id: str): + with db_connection() as conn: + with conn.cursor() as cur: + cur.execute( + """ + SELECT id, user_id, created_at, last_seen_at, expires_at + FROM app_session + WHERE id = %s + """, + (session_id,), + ) + row = cur.fetchone() + if not row: + return None + created_at = row[2].isoformat() if row[2] else None + last_seen_at = row[3].isoformat() if row[3] else None + expires_at = row[4].isoformat() if row[4] else None + return { + "id": row[0], + "user_id": row[1], + "created_at": created_at, + "last_seen_at": last_seen_at, + "expires_at": expires_at, + } + + +def delete_session(session_id: str): + with db_connection() as conn: + with conn: + with conn.cursor() as cur: + cur.execute("DELETE FROM app_session WHERE id = %s", (session_id,)) + + +def get_user_for_session(session_id: str): + if not session_id: + return None + now = _now_utc() + with db_connection() as conn: + with conn: + with conn.cursor() as cur: + cur.execute( + """ + DELETE FROM app_session + WHERE expires_at IS NOT NULL AND expires_at <= %s + """, + (now,), + ) + cur.execute( + """ + SELECT id, user_id, created_at, last_seen_at, expires_at + FROM app_session + WHERE id = %s + """, + (session_id,), + ) + row = cur.fetchone() + if not row: + return None + + expires_at = row[4] + if expires_at is None: + new_expiry = _new_expiry(now) + cur.execute( + """ + UPDATE app_session + SET expires_at = %s, last_seen_at = %s + WHERE id = %s + """, + (new_expiry, now, session_id), + ) + expires_at = new_expiry + + if expires_at <= now: + cur.execute("DELETE FROM app_session WHERE id = %s", (session_id,)) + return None + + if (expires_at - now).total_seconds() <= SESSION_REFRESH_WINDOW_SECONDS: + new_expiry = _new_expiry(now) + cur.execute( + """ + UPDATE app_session + SET expires_at = %s, last_seen_at = %s + WHERE id = %s + """, + (new_expiry, now, session_id), + ) + + cur.execute( + "SELECT id, username, password_hash, role FROM app_user WHERE id = %s", + (row[1],), + ) + return _row_to_user(cur.fetchone()) diff --git a/backend/app/services/broker_service.py b/backend/app/services/broker_service.py new file mode 100644 index 0000000..e69de29 diff --git a/backend/app/services/crypto_service.py b/backend/app/services/crypto_service.py new file mode 100644 index 0000000..d637569 --- /dev/null +++ b/backend/app/services/crypto_service.py @@ -0,0 +1,39 @@ +import os + +from cryptography.fernet import Fernet, InvalidToken + +ENCRYPTION_PREFIX = "enc:" +KEY_ENV_VAR = "BROKER_TOKEN_KEY" + + +def _get_fernet() -> Fernet: + key = (os.getenv(KEY_ENV_VAR) or "").strip() + if not key: + raise RuntimeError(f"{KEY_ENV_VAR} is not set") + try: + return Fernet(key.encode("utf-8")) + except Exception as exc: + raise RuntimeError( + f"{KEY_ENV_VAR} must be a urlsafe base64-encoded 32-byte key" + ) from exc + + +def encrypt_value(value: str | None) -> str | None: + if not value: + return value + if value.startswith(ENCRYPTION_PREFIX): + return value + token = _get_fernet().encrypt(value.encode("utf-8")).decode("utf-8") + return f"{ENCRYPTION_PREFIX}{token}" + + +def decrypt_value(value: str | None) -> str | None: + if not value: + return value + if not value.startswith(ENCRYPTION_PREFIX): + return value + token = value[len(ENCRYPTION_PREFIX) :] + try: + return _get_fernet().decrypt(token.encode("utf-8")).decode("utf-8") + except InvalidToken as exc: + raise RuntimeError("Unable to decrypt token; invalid BROKER_TOKEN_KEY") from exc diff --git a/backend/app/services/db.py b/backend/app/services/db.py new file mode 100644 index 0000000..97796a3 --- /dev/null +++ b/backend/app/services/db.py @@ -0,0 +1,210 @@ +import os +import threading +import time +from contextlib import contextmanager +from typing import Generator + +from sqlalchemy import create_engine, schema, text +from sqlalchemy.engine import Engine, URL +from sqlalchemy.exc import InterfaceError as SAInterfaceError +from sqlalchemy.exc import OperationalError as SAOperationalError +from sqlalchemy.orm import declarative_base, sessionmaker +from psycopg2 import OperationalError as PGOperationalError +from psycopg2 import InterfaceError as PGInterfaceError + +Base = declarative_base() + +_ENGINE: Engine | None = None +_ENGINE_LOCK = threading.Lock() + + +class _ConnectionProxy: + def __init__(self, conn): + self._conn = conn + + def __getattr__(self, name): + return getattr(self._conn, name) + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + if exc_type is None: + try: + self._conn.commit() + except Exception: + self._conn.rollback() + raise + else: + try: + self._conn.rollback() + except Exception: + pass + return False + + +def _db_config() -> dict[str, str | int]: + url = os.getenv("DATABASE_URL") + if url: + return {"url": url} + + return { + "host": os.getenv("DB_HOST") or os.getenv("PGHOST") or "localhost", + "port": int(os.getenv("DB_PORT") or os.getenv("PGPORT") or "5432"), + "dbname": os.getenv("DB_NAME") or os.getenv("PGDATABASE") or "trading_db", + "user": os.getenv("DB_USER") or os.getenv("PGUSER") or "trader", + "password": os.getenv("DB_PASSWORD") or os.getenv("PGPASSWORD") or "traderpass", + "connect_timeout": int(os.getenv("DB_CONNECT_TIMEOUT", "5")), + "schema": os.getenv("DB_SCHEMA") or os.getenv("PGSCHEMA") or "quant_app", + } + + +def get_database_url(cfg: dict[str, str | int] | None = None) -> str: + cfg = cfg or _db_config() + if "url" in cfg: + return str(cfg["url"]) + schema_name = cfg.get("schema") + query = {"connect_timeout": str(cfg["connect_timeout"])} + if schema_name: + query["options"] = f"-csearch_path={schema_name},public" + url = URL.create( + "postgresql+psycopg2", + username=str(cfg["user"]), + password=str(cfg["password"]), + host=str(cfg["host"]), + port=int(cfg["port"]), + database=str(cfg["dbname"]), + query=query, + ) + return url.render_as_string(hide_password=False) + + +def _create_engine() -> Engine: + cfg = _db_config() + pool_size = int(os.getenv("DB_POOL_SIZE", os.getenv("DB_POOL_MIN", "5"))) + max_overflow = int(os.getenv("DB_POOL_MAX", "10")) + pool_timeout = int(os.getenv("DB_POOL_TIMEOUT", "30")) + engine = create_engine( + get_database_url(cfg), + pool_size=pool_size, + max_overflow=max_overflow, + pool_timeout=pool_timeout, + pool_pre_ping=True, + future=True, + ) + schema_name = cfg.get("schema") + if schema_name: + try: + with engine.begin() as conn: + conn.execute(schema.CreateSchema(schema_name, if_not_exists=True)) + except Exception: + # Schema creation is best-effort; permissions might be limited in some environments. + pass + return engine + + +def get_engine() -> Engine: + global _ENGINE + if _ENGINE is None: + with _ENGINE_LOCK: + if _ENGINE is None: + _ENGINE = _create_engine() + return _ENGINE + + +SessionLocal = sessionmaker( + autocommit=False, + autoflush=False, + expire_on_commit=False, + bind=get_engine(), +) + + +def _get_connection(): + return get_engine().raw_connection() + + +def _put_connection(conn, close=False): + try: + conn.close() + except Exception: + if not close: + raise + + +@contextmanager +def db_connection(retries: int | None = None, delay: float | None = None): + attempts = retries if retries is not None else int(os.getenv("DB_RETRY_COUNT", "3")) + backoff = delay if delay is not None else float(os.getenv("DB_RETRY_DELAY", "0.2")) + last_error = None + for attempt in range(attempts): + conn = None + try: + conn = _get_connection() + conn.autocommit = False + yield _ConnectionProxy(conn) + return + except (SAOperationalError, SAInterfaceError, PGOperationalError, PGInterfaceError) as exc: + last_error = exc + if conn is not None: + _put_connection(conn) + conn = None + time.sleep(backoff * (2 ** attempt)) + continue + finally: + if conn is not None: + _put_connection(conn, close=conn.closed != 0) + if last_error: + raise last_error + + +def run_with_retry(operation, retries: int | None = None, delay: float | None = None): + attempts = retries if retries is not None else int(os.getenv("DB_RETRY_COUNT", "3")) + backoff = delay if delay is not None else float(os.getenv("DB_RETRY_DELAY", "0.2")) + last_error = None + for attempt in range(attempts): + with db_connection(retries=1) as conn: + try: + with conn.cursor() as cur: + result = operation(cur, conn) + conn.commit() + return result + except (SAOperationalError, SAInterfaceError, PGOperationalError, PGInterfaceError) as exc: + conn.rollback() + last_error = exc + time.sleep(backoff * (2 ** attempt)) + continue + except Exception: + conn.rollback() + raise + if last_error: + raise last_error + + +@contextmanager +def db_transaction(): + with db_connection() as conn: + try: + with conn.cursor() as cur: + yield cur + conn.commit() + except Exception: + conn.rollback() + raise + + +def get_db() -> Generator: + db = SessionLocal() + try: + yield db + finally: + db.close() + + +def health_check() -> bool: + try: + with get_engine().connect() as conn: + conn.execute(text("SELECT 1")) + return True + except Exception: + return False diff --git a/backend/app/services/email_service.py b/backend/app/services/email_service.py new file mode 100644 index 0000000..f7213b0 --- /dev/null +++ b/backend/app/services/email_service.py @@ -0,0 +1,28 @@ +import os +import smtplib +import ssl +from email.message import EmailMessage + + +def send_email(to_email: str, subject: str, body_text: str) -> bool: + smtp_user = (os.getenv("SMTP_USER") or "").strip() + smtp_pass = (os.getenv("SMTP_PASS") or "").replace(" ", "").strip() + smtp_host = (os.getenv("SMTP_HOST") or "smtp.gmail.com").strip() + smtp_port = int((os.getenv("SMTP_PORT") or "587").strip()) + from_name = (os.getenv("SMTP_FROM_NAME") or "Quantfortune Support").strip() + + if not smtp_user or not smtp_pass: + return False + + msg = EmailMessage() + msg["From"] = f"{from_name} <{smtp_user}>" + msg["To"] = to_email + msg["Subject"] = subject + msg.set_content(body_text) + + context = ssl.create_default_context() + with smtplib.SMTP(smtp_host, smtp_port) as server: + server.starttls(context=context) + server.login(smtp_user, smtp_pass) + server.send_message(msg) + return True diff --git a/backend/app/services/paper_broker_service.py b/backend/app/services/paper_broker_service.py new file mode 100644 index 0000000..db2101b --- /dev/null +++ b/backend/app/services/paper_broker_service.py @@ -0,0 +1,191 @@ +import os +import sys +from pathlib import Path + +PROJECT_ROOT = Path(__file__).resolve().parents[3] +if str(PROJECT_ROOT) not in sys.path: + sys.path.append(str(PROJECT_ROOT)) + +from indian_paper_trading_strategy.engine.broker import PaperBroker +from indian_paper_trading_strategy.engine.state import load_state, save_state +from indian_paper_trading_strategy.engine.db import engine_context, insert_engine_event +from app.services.db import run_with_retry +from app.services.run_service import get_active_run_id, get_running_run_id + +_logged_path = False + + +def _broker(): + global _logged_path + state = load_state(mode="PAPER") + initial_cash = float(state.get("initial_cash", 0)) + broker = PaperBroker(initial_cash=initial_cash) + if not _logged_path: + _logged_path = True + print( + "PaperBroker store path:", + { + "cwd": os.getcwd(), + "paper_store_path": str(broker.store_path) if hasattr(broker, "store_path") else "NO_STORE_PATH", + "abs_store_path": os.path.abspath(str(broker.store_path)) if hasattr(broker, "store_path") else "N/A", + }, + ) + return broker + + +def get_paper_broker(user_id: str): + run_id = get_active_run_id(user_id) + with engine_context(user_id, run_id): + return _broker() + + +def get_funds(user_id: str): + run_id = get_active_run_id(user_id) + with engine_context(user_id, run_id): + return _broker().get_funds() + + +def get_positions(user_id: str): + run_id = get_active_run_id(user_id) + with engine_context(user_id, run_id): + positions = _broker().get_positions() + enriched = [] + for item in positions: + qty = float(item.get("qty", 0)) + avg = float(item.get("avg_price", 0)) + ltp = float(item.get("last_price", 0)) + pnl = (ltp - avg) * qty + pnl_pct = ((ltp - avg) / avg * 100) if avg else 0.0 + enriched.append( + { + **item, + "pnl": pnl, + "pnl_pct": pnl_pct, + } + ) + return enriched + + +def get_orders(user_id: str): + run_id = get_active_run_id(user_id) + with engine_context(user_id, run_id): + return _broker().get_orders() + + +def get_trades(user_id: str): + run_id = get_active_run_id(user_id) + with engine_context(user_id, run_id): + return _broker().get_trades() + + +def get_equity_curve(user_id: str): + run_id = get_active_run_id(user_id) + with engine_context(user_id, run_id): + broker = _broker() + points = broker.get_equity_curve() + if not points: + return [] + + state = load_state(mode="PAPER") + initial_cash = float(state.get("initial_cash", 0)) + response = [] + for point in points: + equity = float(point.get("equity", 0)) + pnl = point.get("pnl") + if pnl is None: + pnl = equity - float(initial_cash) + response.append( + { + "timestamp": point.get("timestamp"), + "equity": equity, + "pnl": float(pnl), + } + ) + return response + + +def add_cash(user_id: str, amount: float): + if amount <= 0: + raise ValueError("Amount must be positive") + run_id = get_running_run_id(user_id) + if not run_id: + raise ValueError("Strategy must be running to add cash") + + def _op(cur, _conn): + with engine_context(user_id, run_id): + state = load_state(mode="PAPER", cur=cur, for_update=True) + initial_cash = float(state.get("initial_cash", 0)) + broker = PaperBroker(initial_cash=initial_cash) + store = broker._load_store(cur=cur, for_update=True) + cash = float(store.get("cash", 0)) + amount + store["cash"] = cash + broker._save_store(store, cur=cur) + + state["cash"] = cash + state["initial_cash"] = initial_cash + amount + state["total_invested"] = float(state.get("total_invested", 0)) + amount + save_state( + state, + mode="PAPER", + cur=cur, + emit_event=True, + event_meta={"source": "add_cash"}, + ) + insert_engine_event( + cur, + "CASH_ADDED", + data={"amount": amount, "cash": cash}, + ) + return state + + return run_with_retry(_op) + + +def reset_paper_state(user_id: str): + run_id = get_active_run_id(user_id) + + def _op(cur, _conn): + with engine_context(user_id, run_id): + cur.execute( + "DELETE FROM strategy_log WHERE user_id = %s AND run_id = %s", + (user_id, run_id), + ) + cur.execute( + "DELETE FROM engine_event WHERE user_id = %s AND run_id = %s", + (user_id, run_id), + ) + cur.execute( + "DELETE FROM paper_equity_curve WHERE user_id = %s AND run_id = %s", + (user_id, run_id), + ) + cur.execute( + "DELETE FROM paper_trade WHERE user_id = %s AND run_id = %s", + (user_id, run_id), + ) + cur.execute( + "DELETE FROM paper_order WHERE user_id = %s AND run_id = %s", + (user_id, run_id), + ) + cur.execute( + "DELETE FROM paper_position WHERE user_id = %s AND run_id = %s", + (user_id, run_id), + ) + cur.execute( + "DELETE FROM paper_broker_account WHERE user_id = %s AND run_id = %s", + (user_id, run_id), + ) + cur.execute( + "DELETE FROM mtm_ledger WHERE user_id = %s AND run_id = %s", + (user_id, run_id), + ) + cur.execute( + "DELETE FROM event_ledger WHERE user_id = %s AND run_id = %s", + (user_id, run_id), + ) + cur.execute( + "DELETE FROM engine_state_paper WHERE user_id = %s AND run_id = %s", + (user_id, run_id), + ) + insert_engine_event(cur, "PAPER_RESET", data={}) + + run_with_retry(_op) diff --git a/backend/app/services/run_lifecycle.py b/backend/app/services/run_lifecycle.py new file mode 100644 index 0000000..5606031 --- /dev/null +++ b/backend/app/services/run_lifecycle.py @@ -0,0 +1,22 @@ +class RunLifecycleError(Exception): + pass + + +class RunLifecycleManager: + ARMABLE = {"STOPPED", "PAUSED_AUTH_EXPIRED"} + + @classmethod + def assert_can_arm(cls, status: str): + normalized = (status or "").strip().upper() + if normalized == "RUNNING": + raise RunLifecycleError("Run already RUNNING") + if normalized == "ERROR": + raise RunLifecycleError("Run in ERROR must be reset before arming") + if normalized not in cls.ARMABLE: + raise RunLifecycleError(f"Run cannot be armed from status {normalized}") + return normalized + + @classmethod + def is_armable(cls, status: str) -> bool: + normalized = (status or "").strip().upper() + return normalized in cls.ARMABLE diff --git a/backend/app/services/run_service.py b/backend/app/services/run_service.py new file mode 100644 index 0000000..4e1e578 --- /dev/null +++ b/backend/app/services/run_service.py @@ -0,0 +1,176 @@ +import threading +from datetime import datetime, timezone +from uuid import uuid4 + +from psycopg2.extras import Json + +from app.services.db import run_with_retry + +_DEFAULT_USER_ID = None +_DEFAULT_LOCK = threading.Lock() + + +def _utc_now(): + return datetime.now(timezone.utc) + + +def get_default_user_id(): + global _DEFAULT_USER_ID + if _DEFAULT_USER_ID: + return _DEFAULT_USER_ID + + def _op(cur, _conn): + cur.execute("SELECT id FROM app_user ORDER BY username LIMIT 1") + row = cur.fetchone() + return row[0] if row else None + + user_id = run_with_retry(_op) + if user_id: + with _DEFAULT_LOCK: + _DEFAULT_USER_ID = user_id + return user_id + + +def _default_run_id(user_id: str) -> str: + return f"default_{user_id}" + + +def ensure_default_run(user_id: str): + run_id = _default_run_id(user_id) + + def _op(cur, _conn): + now = _utc_now() + cur.execute( + """ + INSERT INTO strategy_run ( + run_id, user_id, created_at, started_at, stopped_at, status, strategy, mode, broker, meta + ) + VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s) + ON CONFLICT (run_id) DO NOTHING + """, + ( + run_id, + user_id, + now, + None, + None, + "STOPPED", + None, + None, + None, + Json({}), + ), + ) + return run_id + + return run_with_retry(_op) + + +def get_active_run_id(user_id: str): + def _op(cur, _conn): + cur.execute( + """ + SELECT run_id + FROM strategy_run + WHERE user_id = %s AND status = 'RUNNING' + ORDER BY created_at DESC + LIMIT 1 + """, + (user_id,), + ) + row = cur.fetchone() + if row: + return row[0] + cur.execute( + """ + SELECT run_id + FROM strategy_run + WHERE user_id = %s + ORDER BY created_at DESC + LIMIT 1 + """, + (user_id,), + ) + row = cur.fetchone() + if row: + return row[0] + return None + + run_id = run_with_retry(_op) + if run_id: + return run_id + return ensure_default_run(user_id) + + +def get_running_run_id(user_id: str): + def _op(cur, _conn): + cur.execute( + """ + SELECT run_id + FROM strategy_run + WHERE user_id = %s AND status = 'RUNNING' + ORDER BY created_at DESC + LIMIT 1 + """, + (user_id,), + ) + row = cur.fetchone() + return row[0] if row else None + + return run_with_retry(_op) + + +def create_strategy_run(user_id: str, strategy: str | None, mode: str | None, broker: str | None, meta: dict | None): + run_id = str(uuid4()) + + def _op(cur, _conn): + now = _utc_now() + cur.execute( + """ + INSERT INTO strategy_run ( + run_id, user_id, created_at, started_at, stopped_at, status, strategy, mode, broker, meta + ) + VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s) + """, + ( + run_id, + user_id, + now, + now, + None, + "RUNNING", + strategy, + mode, + broker, + Json(meta or {}), + ), + ) + return run_id + + return run_with_retry(_op) + + +def update_run_status(user_id: str, run_id: str, status: str, meta: dict | None = None): + def _op(cur, _conn): + now = _utc_now() + if status == "RUNNING": + cur.execute( + """ + UPDATE strategy_run + SET status = %s, started_at = COALESCE(started_at, %s), meta = COALESCE(meta, '{}'::jsonb) || %s + WHERE run_id = %s AND user_id = %s + """, + (status, now, Json(meta or {}), run_id, user_id), + ) + else: + cur.execute( + """ + UPDATE strategy_run + SET status = %s, stopped_at = %s, meta = COALESCE(meta, '{}'::jsonb) || %s + WHERE run_id = %s AND user_id = %s + """, + (status, now, Json(meta or {}), run_id, user_id), + ) + return True + + return run_with_retry(_op) diff --git a/backend/app/services/strategy_service.py b/backend/app/services/strategy_service.py new file mode 100644 index 0000000..f31c591 --- /dev/null +++ b/backend/app/services/strategy_service.py @@ -0,0 +1,650 @@ +import json +import os +import sys +import threading +from datetime import datetime, timedelta, timezone +from pathlib import Path + +ENGINE_ROOT = Path(__file__).resolve().parents[3] +if str(ENGINE_ROOT) not in sys.path: + sys.path.append(str(ENGINE_ROOT)) + +from indian_paper_trading_strategy.engine.market import is_market_open, align_to_market_open +from indian_paper_trading_strategy.engine.runner import start_engine, stop_engine +from indian_paper_trading_strategy.engine.state import init_paper_state, load_state +from indian_paper_trading_strategy.engine.broker import PaperBroker +from indian_paper_trading_strategy.engine.time_utils import frequency_to_timedelta +from indian_paper_trading_strategy.engine.db import engine_context + +from app.services.db import db_connection +from app.services.run_service import ( + create_strategy_run, + get_active_run_id, + get_running_run_id, + update_run_status, +) +from app.services.auth_service import get_user_by_id +from app.services.email_service import send_email +from psycopg2.extras import Json +from psycopg2 import errors + +SEQ_LOCK = threading.Lock() +SEQ = 0 +LAST_WAIT_LOG_TS = {} +WAIT_LOG_INTERVAL = timedelta(seconds=60) + +def init_log_state(): + global SEQ + + with db_connection() as conn: + with conn.cursor() as cur: + cur.execute("SELECT COALESCE(MAX(seq), 0) FROM strategy_log") + row = cur.fetchone() + SEQ = row[0] if row and row[0] is not None else 0 + +def start_new_run(user_id: str, run_id: str): + LAST_WAIT_LOG_TS.pop(run_id, None) + emit_event( + user_id=user_id, + run_id=run_id, + event="STRATEGY_STARTED", + message="Strategy started", + meta={}, + ) + + +def stop_run(user_id: str, run_id: str, reason="user_request"): + emit_event( + user_id=user_id, + run_id=run_id, + event="STRATEGY_STOPPED", + message="Strategy stopped", + meta={"reason": reason}, + ) + + +def emit_event( + *, + user_id: str, + run_id: str, + event: str, + message: str, + level: str = "INFO", + category: str = "ENGINE", + meta: dict | None = None +): + global SEQ, LAST_WAIT_LOG_TS + if not user_id or not run_id: + return + + now = datetime.now(timezone.utc) + if event == "SIP_WAITING": + last_ts = LAST_WAIT_LOG_TS.get(run_id) + if last_ts and (now - last_ts) < WAIT_LOG_INTERVAL: + return + LAST_WAIT_LOG_TS[run_id] = now + + with SEQ_LOCK: + SEQ += 1 + seq = SEQ + + evt = { + "seq": seq, + "ts": now.isoformat().replace("+00:00", "Z"), + "level": level, + "category": category, + "event": event, + "message": message, + "run_id": run_id, + "meta": meta or {} + } + + with db_connection() as conn: + with conn: + with conn.cursor() as cur: + cur.execute( + """ + INSERT INTO strategy_log ( + seq, ts, level, category, event, message, user_id, run_id, meta + ) + VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s) + ON CONFLICT (seq) DO NOTHING + """, + ( + evt["seq"], + now, + evt["level"], + evt["category"], + evt["event"], + evt["message"], + user_id, + evt["run_id"], + Json(evt["meta"]), + ), + ) + +def _maybe_parse_json(value): + if value is None: + return None + if not isinstance(value, str): + return value + text = value.strip() + if not text: + return None + try: + return json.loads(text) + except Exception: + return value + + +def _local_tz(): + return datetime.now().astimezone().tzinfo + + +def _format_local_ts(value: datetime | None): + if value is None: + return None + return value.astimezone(_local_tz()).replace(tzinfo=None).isoformat() + + +def _load_config(user_id: str, run_id: str): + with db_connection() as conn: + with conn.cursor() as cur: + cur.execute( + """ + SELECT strategy, sip_amount, sip_frequency_value, sip_frequency_unit, + mode, broker, active, frequency, frequency_days, unit, next_run + FROM strategy_config + WHERE user_id = %s AND run_id = %s + LIMIT 1 + """, + (user_id, run_id), + ) + row = cur.fetchone() + if not row: + return {} + cfg = { + "strategy": row[0], + "sip_amount": float(row[1]) if row[1] is not None else None, + "mode": row[4], + "broker": row[5], + "active": row[6], + "frequency": _maybe_parse_json(row[7]), + "frequency_days": row[8], + "unit": row[9], + "next_run": _format_local_ts(row[10]), + } + if row[2] is not None or row[3] is not None: + cfg["sip_frequency"] = { + "value": row[2], + "unit": row[3], + } + return cfg + + +def _save_config(cfg, user_id: str, run_id: str): + sip_frequency = cfg.get("sip_frequency") + sip_value = None + sip_unit = None + if isinstance(sip_frequency, dict): + sip_value = sip_frequency.get("value") + sip_unit = sip_frequency.get("unit") + + frequency = cfg.get("frequency") + if not isinstance(frequency, str) and frequency is not None: + frequency = json.dumps(frequency) + + next_run = cfg.get("next_run") + next_run_dt = None + if isinstance(next_run, str): + try: + parsed = datetime.fromisoformat(next_run) + if parsed.tzinfo is None: + parsed = parsed.replace(tzinfo=_local_tz()) + next_run_dt = parsed + except ValueError: + next_run_dt = None + + with db_connection() as conn: + with conn: + with conn.cursor() as cur: + cur.execute( + """ + INSERT INTO strategy_config ( + user_id, + run_id, + strategy, + sip_amount, + sip_frequency_value, + sip_frequency_unit, + mode, + broker, + active, + frequency, + frequency_days, + unit, + next_run + ) + VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s) + ON CONFLICT (user_id, run_id) DO UPDATE + SET strategy = EXCLUDED.strategy, + sip_amount = EXCLUDED.sip_amount, + sip_frequency_value = EXCLUDED.sip_frequency_value, + sip_frequency_unit = EXCLUDED.sip_frequency_unit, + mode = EXCLUDED.mode, + broker = EXCLUDED.broker, + active = EXCLUDED.active, + frequency = EXCLUDED.frequency, + frequency_days = EXCLUDED.frequency_days, + unit = EXCLUDED.unit, + next_run = EXCLUDED.next_run + """, + ( + user_id, + run_id, + cfg.get("strategy"), + cfg.get("sip_amount"), + sip_value, + sip_unit, + cfg.get("mode"), + cfg.get("broker"), + cfg.get("active"), + frequency, + cfg.get("frequency_days"), + cfg.get("unit"), + next_run_dt, + ), + ) + +def save_strategy_config(cfg, user_id: str, run_id: str): + _save_config(cfg, user_id, run_id) + +def deactivate_strategy_config(user_id: str, run_id: str): + cfg = _load_config(user_id, run_id) + cfg["active"] = False + _save_config(cfg, user_id, run_id) + +def _write_status(user_id: str, run_id: str, status): + now_local = datetime.now().astimezone() + with db_connection() as conn: + with conn: + with conn.cursor() as cur: + cur.execute( + """ + INSERT INTO engine_status (user_id, run_id, status, last_updated) + VALUES (%s, %s, %s, %s) + ON CONFLICT (user_id, run_id) DO UPDATE + SET status = EXCLUDED.status, + last_updated = EXCLUDED.last_updated + """, + (user_id, run_id, status, now_local), + ) + +def validate_frequency(freq: dict, mode: str): + if not isinstance(freq, dict): + raise ValueError("Frequency payload is required") + value = int(freq.get("value", 0)) + unit = freq.get("unit") + + if unit not in {"minutes", "days"}: + raise ValueError(f"Unsupported frequency unit: {unit}") + + if unit == "minutes": + if mode != "PAPER": + raise ValueError("Minute-level frequency allowed only in PAPER mode") + if value < 1: + raise ValueError("Minimum frequency is 1 minute") + + if unit == "days" and value < 1: + raise ValueError("Minimum frequency is 1 day") + +def compute_next_eligible(last_run: str | None, sip_frequency: dict | None): + if not last_run or not sip_frequency: + return None + try: + last_dt = datetime.fromisoformat(last_run) + except ValueError: + return None + try: + delta = frequency_to_timedelta(sip_frequency) + except ValueError: + return None + next_dt = last_dt + delta + next_dt = align_to_market_open(next_dt) + return next_dt.isoformat() + +def start_strategy(req, user_id: str): + engine_external = os.getenv("ENGINE_EXTERNAL", "").strip().lower() in {"1", "true", "yes"} + running_run_id = get_running_run_id(user_id) + if running_run_id: + if engine_external: + return {"status": "already_running", "run_id": running_run_id} + engine_config = _build_engine_config(user_id, running_run_id, req) + if engine_config: + started = start_engine(engine_config) + if started: + _write_status(user_id, running_run_id, "RUNNING") + return {"status": "restarted", "run_id": running_run_id} + return {"status": "already_running", "run_id": running_run_id} + mode = (req.mode or "PAPER").strip().upper() + if mode != "PAPER": + return {"status": "unsupported_mode"} + frequency_payload = req.sip_frequency.dict() if hasattr(req.sip_frequency, "dict") else dict(req.sip_frequency) + validate_frequency(frequency_payload, mode) + initial_cash = float(req.initial_cash) if req.initial_cash is not None else 1_000_000.0 + + try: + run_id = create_strategy_run( + user_id, + strategy=req.strategy_name, + mode=mode, + broker="paper", + meta={ + "sip_amount": req.sip_amount, + "sip_frequency": frequency_payload, + "initial_cash": initial_cash, + }, + ) + except errors.UniqueViolation: + return {"status": "already_running"} + + with engine_context(user_id, run_id): + init_paper_state(initial_cash, frequency_payload) + with db_connection() as conn: + with conn: + with conn.cursor() as cur: + cur.execute( + """ + INSERT INTO paper_broker_account (user_id, run_id, cash) + VALUES (%s, %s, %s) + ON CONFLICT (user_id, run_id) DO UPDATE + SET cash = EXCLUDED.cash + """, + (user_id, run_id, initial_cash), + ) + PaperBroker(initial_cash=initial_cash) + config = { + "strategy": req.strategy_name, + "sip_amount": req.sip_amount, + "sip_frequency": frequency_payload, + "mode": mode, + "broker": "paper", + "active": True, + } + save_strategy_config(config, user_id, run_id) + start_new_run(user_id, run_id) + _write_status(user_id, run_id, "RUNNING") + if not engine_external: + def emit_event_cb(*, event: str, message: str, level: str = "INFO", category: str = "ENGINE", meta: dict | None = None): + emit_event( + user_id=user_id, + run_id=run_id, + event=event, + message=message, + level=level, + category=category, + meta=meta, + ) + + engine_config = dict(config) + engine_config["initial_cash"] = initial_cash + engine_config["run_id"] = run_id + engine_config["user_id"] = user_id + engine_config["emit_event"] = emit_event_cb + start_engine(engine_config) + + try: + user = get_user_by_id(user_id) + if user: + body = ( + "Your strategy has been started.\n\n" + f"Strategy: {req.strategy_name}\n" + f"Mode: {mode}\n" + f"Run ID: {run_id}\n" + ) + send_email(user["username"], "Strategy started", body) + except Exception: + pass + + return {"status": "started", "run_id": run_id} + + +def _build_engine_config(user_id: str, run_id: str, req=None): + cfg = _load_config(user_id, run_id) + sip_frequency = cfg.get("sip_frequency") + if not isinstance(sip_frequency, dict) and req is not None: + sip_frequency = req.sip_frequency.dict() if hasattr(req.sip_frequency, "dict") else dict(req.sip_frequency) + if not isinstance(sip_frequency, dict): + sip_frequency = {"value": cfg.get("frequency_days") or 1, "unit": cfg.get("unit") or "days"} + + sip_amount = cfg.get("sip_amount") + if sip_amount is None and req is not None: + sip_amount = req.sip_amount + + mode = (cfg.get("mode") or (req.mode if req is not None else "PAPER") or "PAPER").strip().upper() + broker = cfg.get("broker") or "paper" + strategy_name = cfg.get("strategy") or cfg.get("strategy_name") or (req.strategy_name if req is not None else None) + + with engine_context(user_id, run_id): + state = load_state(mode=mode) + initial_cash = float(state.get("initial_cash") or 1_000_000.0) + + def emit_event_cb(*, event: str, message: str, level: str = "INFO", category: str = "ENGINE", meta: dict | None = None): + emit_event( + user_id=user_id, + run_id=run_id, + event=event, + message=message, + level=level, + category=category, + meta=meta, + ) + + return { + "strategy": strategy_name or "Golden Nifty", + "sip_amount": sip_amount or 0, + "sip_frequency": sip_frequency, + "mode": mode, + "broker": broker, + "active": cfg.get("active", True), + "initial_cash": initial_cash, + "user_id": user_id, + "run_id": run_id, + "emit_event": emit_event_cb, + } + + +def resume_running_runs(): + engine_external = os.getenv("ENGINE_EXTERNAL", "").strip().lower() in {"1", "true", "yes"} + if engine_external: + return + with db_connection() as conn: + with conn.cursor() as cur: + cur.execute( + """ + SELECT user_id, run_id + FROM strategy_run + WHERE status = 'RUNNING' + ORDER BY created_at DESC + """ + ) + runs = cur.fetchall() + for user_id, run_id in runs: + engine_config = _build_engine_config(user_id, run_id, None) + if not engine_config: + continue + started = start_engine(engine_config) + if started: + _write_status(user_id, run_id, "RUNNING") + +def stop_strategy(user_id: str): + run_id = get_active_run_id(user_id) + engine_external = os.getenv("ENGINE_EXTERNAL", "").strip().lower() in {"1", "true", "yes"} + if not engine_external: + stop_engine(user_id, run_id, timeout=15.0) + deactivate_strategy_config(user_id, run_id) + stop_run(user_id, run_id, reason="user_request") + _write_status(user_id, run_id, "STOPPED") + update_run_status(user_id, run_id, "STOPPED", meta={"reason": "user_request"}) + + try: + user = get_user_by_id(user_id) + if user: + body = "Your strategy has been stopped." + send_email(user["username"], "Strategy stopped", body) + except Exception: + pass + + return {"status": "stopped"} + +def get_strategy_status(user_id: str): + run_id = get_active_run_id(user_id) + with db_connection() as conn: + with conn.cursor() as cur: + cur.execute( + "SELECT status, last_updated FROM engine_status WHERE user_id = %s AND run_id = %s", + (user_id, run_id), + ) + row = cur.fetchone() + if not row: + status = {"status": "IDLE", "last_updated": None} + else: + status = { + "status": row[0], + "last_updated": _format_local_ts(row[1]), + } + if status.get("status") == "RUNNING": + cfg = _load_config(user_id, run_id) + mode = (cfg.get("mode") or "LIVE").strip().upper() + with engine_context(user_id, run_id): + state = load_state(mode=mode) + last_execution_ts = state.get("last_run") or state.get("last_sip_ts") + sip_frequency = cfg.get("sip_frequency") + if not isinstance(sip_frequency, dict): + frequency = cfg.get("frequency") + unit = cfg.get("unit") + if isinstance(frequency, dict): + unit = frequency.get("unit", unit) + frequency = frequency.get("value") + if frequency is None and cfg.get("frequency_days") is not None: + frequency = cfg.get("frequency_days") + unit = unit or "days" + if frequency is not None and unit: + sip_frequency = {"value": frequency, "unit": unit} + next_eligible = compute_next_eligible(last_execution_ts, sip_frequency) + status["last_execution_ts"] = last_execution_ts + status["next_eligible_ts"] = next_eligible + if next_eligible: + try: + parsed_next = datetime.fromisoformat(next_eligible) + now_cmp = datetime.now(parsed_next.tzinfo) if parsed_next.tzinfo else datetime.now() + if parsed_next > now_cmp: + status["status"] = "WAITING" + except ValueError: + pass + return status + +def get_engine_status(user_id: str): + run_id = get_active_run_id(user_id) + status = { + "state": "STOPPED", + "run_id": run_id, + "user_id": user_id, + "last_heartbeat_ts": None, + } + with db_connection() as conn: + with conn.cursor() as cur: + cur.execute( + """ + SELECT status, last_updated + FROM engine_status + WHERE user_id = %s AND run_id = %s + ORDER BY last_updated DESC + LIMIT 1 + """, + (user_id, run_id), + ) + row = cur.fetchone() + if row: + status["state"] = row[0] + last_updated = row[1] + if last_updated is not None: + status["last_heartbeat_ts"] = ( + last_updated.astimezone(timezone.utc) + .isoformat() + .replace("+00:00", "Z") + ) + cfg = _load_config(user_id, run_id) + mode = (cfg.get("mode") or "LIVE").strip().upper() + with engine_context(user_id, run_id): + state = load_state(mode=mode) + last_execution_ts = state.get("last_run") or state.get("last_sip_ts") + sip_frequency = cfg.get("sip_frequency") + if isinstance(sip_frequency, dict): + sip_frequency = { + "value": sip_frequency.get("value"), + "unit": sip_frequency.get("unit"), + } + else: + frequency = cfg.get("frequency") + unit = cfg.get("unit") + if isinstance(frequency, dict): + unit = frequency.get("unit", unit) + frequency = frequency.get("value") + if frequency is None and cfg.get("frequency_days") is not None: + frequency = cfg.get("frequency_days") + unit = unit or "days" + if frequency is not None and unit: + sip_frequency = {"value": frequency, "unit": unit} + status["last_execution_ts"] = last_execution_ts + status["next_eligible_ts"] = compute_next_eligible(last_execution_ts, sip_frequency) + status["run_id"] = run_id + return status + + +def get_strategy_logs(user_id: str, since_seq: int): + run_id = get_active_run_id(user_id) + with db_connection() as conn: + with conn.cursor() as cur: + cur.execute( + """ + SELECT seq, ts, level, category, event, message, run_id, meta + FROM strategy_log + WHERE user_id = %s AND run_id = %s AND seq > %s + ORDER BY seq + """, + (user_id, run_id, since_seq), + ) + rows = cur.fetchall() + events = [] + for row in rows: + ts = row[1] + if ts is not None: + ts_str = ts.astimezone(timezone.utc).isoformat().replace("+00:00", "Z") + else: + ts_str = None + events.append( + { + "seq": row[0], + "ts": ts_str, + "level": row[2], + "category": row[3], + "event": row[4], + "message": row[5], + "run_id": row[6], + "meta": row[7] if isinstance(row[7], dict) else {}, + } + ) + cur.execute( + "SELECT COALESCE(MAX(seq), 0) FROM strategy_log WHERE user_id = %s AND run_id = %s", + (user_id, run_id), + ) + latest_seq = cur.fetchone()[0] + return {"events": events, "latest_seq": latest_seq} + +def get_market_status(): + now = datetime.now() + return { + "status": "OPEN" if is_market_open(now) else "CLOSED", + "checked_at": now.isoformat(), + } diff --git a/backend/app/services/support_ticket.py b/backend/app/services/support_ticket.py new file mode 100644 index 0000000..fdbfa1c --- /dev/null +++ b/backend/app/services/support_ticket.py @@ -0,0 +1,70 @@ +import os +from datetime import datetime, timezone +from uuid import uuid4 + +from app.services.db import db_connection +from app.services.email_service import send_email + + +def _now(): + return datetime.now(timezone.utc) + + +def create_ticket(name: str, email: str, subject: str, message: str) -> dict: + ticket_id = str(uuid4()) + now = _now() + with db_connection() as conn: + with conn: + with conn.cursor() as cur: + cur.execute( + """ + INSERT INTO support_ticket + (id, name, email, subject, message, status, created_at, updated_at) + VALUES (%s, %s, %s, %s, %s, %s, %s, %s) + """, + (ticket_id, name, email, subject, message, "NEW", now, now), + ) + email_sent = False + try: + email_body = ( + "Hi,\n\n" + "Your support ticket has been created.\n\n" + f"Ticket ID: {ticket_id}\n" + f"Subject: {subject}\n" + "Status: NEW\n\n" + "We will get back to you shortly.\n\n" + "Quantfortune Support" + ) + email_sent = send_email(email, "Quantfortune Support Ticket Created", email_body) + except Exception: + email_sent = False + return { + "ticket_id": ticket_id, + "status": "NEW", + "created_at": now.isoformat(), + "email_sent": email_sent, + } + + +def get_ticket_status(ticket_id: str, email: str) -> dict | None: + with db_connection() as conn: + with conn.cursor() as cur: + cur.execute( + """ + SELECT id, email, status, created_at, updated_at + FROM support_ticket + WHERE id = %s + """, + (ticket_id,), + ) + row = cur.fetchone() + if not row: + return None + if row[1].lower() != email.lower(): + return None + return { + "ticket_id": row[0], + "status": row[2], + "created_at": row[3].isoformat() if row[3] else None, + "updated_at": row[4].isoformat() if row[4] else None, + } diff --git a/backend/app/services/system_service.py b/backend/app/services/system_service.py new file mode 100644 index 0000000..8681e12 --- /dev/null +++ b/backend/app/services/system_service.py @@ -0,0 +1,378 @@ +import hashlib +import json +import os +from datetime import datetime, timezone + +from psycopg2.extras import Json + +from app.broker_store import get_user_broker, set_broker_auth_state +from app.services.db import db_connection +from app.services.run_lifecycle import RunLifecycleError, RunLifecycleManager +from app.services.strategy_service import compute_next_eligible, resume_running_runs +from app.services.zerodha_service import KiteTokenError, fetch_funds +from app.services.zerodha_storage import get_session + + +def _hash_value(value: str | None) -> str | None: + if value is None: + return None + return hashlib.sha256(value.encode("utf-8")).hexdigest() + + +def _parse_frequency(raw_value): + if raw_value is None: + return None + if isinstance(raw_value, dict): + return raw_value + if isinstance(raw_value, str): + text = raw_value.strip() + if not text: + return None + try: + return json.loads(text) + except Exception: + return None + return None + + +def _resolve_sip_frequency(row: dict): + value = row.get("sip_frequency_value") + unit = row.get("sip_frequency_unit") + if value is not None and unit: + return {"value": int(value), "unit": unit} + + frequency = _parse_frequency(row.get("frequency")) + if isinstance(frequency, dict): + freq_value = frequency.get("value") + freq_unit = frequency.get("unit") + if freq_value is not None and freq_unit: + return {"value": int(freq_value), "unit": freq_unit} + + fallback_value = row.get("frequency_days") + fallback_unit = row.get("unit") or "days" + if fallback_value is not None: + return {"value": int(fallback_value), "unit": fallback_unit} + + return None + + +def _parse_ts(value: str | None): + if not value: + return None + try: + return datetime.fromisoformat(value) + except ValueError: + return None + + +def _validate_broker_session(user_id: str): + session = get_session(user_id) + if not session: + return False + if os.getenv("BROKER_VALIDATION_MODE", "").strip().lower() == "skip": + return True + try: + fetch_funds(session["api_key"], session["access_token"]) + except KiteTokenError: + set_broker_auth_state(user_id, "EXPIRED") + return False + return True + + +def arm_system(user_id: str, client_ip: str | None = None): + if not _validate_broker_session(user_id): + return { + "ok": False, + "code": "BROKER_AUTH_REQUIRED", + "redirect_url": "/api/broker/login", + } + + now = datetime.now(timezone.utc) + armed_runs = [] + failed_runs = [] + next_runs = [] + + with db_connection() as conn: + with conn: + with conn.cursor() as cur: + cur.execute( + """ + SELECT sr.run_id, sr.status, sr.strategy, sr.mode, sr.broker, + sc.active, sc.sip_frequency_value, sc.sip_frequency_unit, + sc.frequency, sc.frequency_days, sc.unit, sc.next_run + FROM strategy_run sr + LEFT JOIN strategy_config sc + ON sc.user_id = sr.user_id AND sc.run_id = sr.run_id + WHERE sr.user_id = %s AND COALESCE(sc.active, false) = true + ORDER BY sr.created_at DESC + """, + (user_id,), + ) + rows = cur.fetchall() + + cur.execute("SELECT username FROM app_user WHERE id = %s", (user_id,)) + user_row = cur.fetchone() + username = user_row[0] if user_row else None + + for row in rows: + run = { + "run_id": row[0], + "status": row[1], + "strategy": row[2], + "mode": row[3], + "broker": row[4], + "active": row[5], + "sip_frequency_value": row[6], + "sip_frequency_unit": row[7], + "frequency": row[8], + "frequency_days": row[9], + "unit": row[10], + "next_run": row[11], + } + status = (run["status"] or "").strip().upper() + if status == "RUNNING": + armed_runs.append( + { + "run_id": run["run_id"], + "status": status, + "already_running": True, + } + ) + if run.get("next_run"): + next_runs.append(run["next_run"]) + continue + if status == "ERROR": + failed_runs.append( + { + "run_id": run["run_id"], + "status": status, + "reason": "ERROR", + } + ) + continue + try: + RunLifecycleManager.assert_can_arm(status) + except RunLifecycleError as exc: + failed_runs.append( + { + "run_id": run["run_id"], + "status": status, + "reason": str(exc), + } + ) + continue + + sip_frequency = _resolve_sip_frequency(run) + last_run = now.isoformat() + next_run = compute_next_eligible(last_run, sip_frequency) + next_run_dt = _parse_ts(next_run) + + cur.execute( + """ + UPDATE strategy_run + SET status = 'RUNNING', + started_at = COALESCE(started_at, %s), + stopped_at = NULL, + meta = COALESCE(meta, '{}'::jsonb) || %s + WHERE user_id = %s AND run_id = %s + """, + ( + now, + Json({"armed_at": now.isoformat()}), + user_id, + run["run_id"], + ), + ) + + cur.execute( + """ + INSERT INTO engine_status (user_id, run_id, status, last_updated) + VALUES (%s, %s, %s, %s) + ON CONFLICT (user_id, run_id) DO UPDATE + SET status = EXCLUDED.status, + last_updated = EXCLUDED.last_updated + """, + (user_id, run["run_id"], "RUNNING", now), + ) + + if (run.get("mode") or "").strip().upper() == "PAPER": + cur.execute( + """ + INSERT INTO engine_state_paper (user_id, run_id, last_run) + VALUES (%s, %s, %s) + ON CONFLICT (user_id, run_id) DO UPDATE + SET last_run = EXCLUDED.last_run + """, + (user_id, run["run_id"], now), + ) + else: + cur.execute( + """ + INSERT INTO engine_state (user_id, run_id, last_run) + VALUES (%s, %s, %s) + ON CONFLICT (user_id, run_id) DO UPDATE + SET last_run = EXCLUDED.last_run + """, + (user_id, run["run_id"], now), + ) + + cur.execute( + """ + UPDATE strategy_config + SET next_run = %s + WHERE user_id = %s AND run_id = %s + """, + (next_run_dt, user_id, run["run_id"]), + ) + + logical_time = now.replace(microsecond=0) + cur.execute( + """ + INSERT INTO engine_event (user_id, run_id, ts, event, message, meta) + VALUES (%s, %s, %s, %s, %s, %s) + """, + ( + user_id, + run["run_id"], + now, + "SYSTEM_ARMED", + "System armed", + Json({"next_run": next_run}), + ), + ) + cur.execute( + """ + INSERT INTO engine_event (user_id, run_id, ts, event, message, meta) + VALUES (%s, %s, %s, %s, %s, %s) + """, + ( + user_id, + run["run_id"], + now, + "RUN_REARMED", + "Run re-armed", + Json({"next_run": next_run}), + ), + ) + cur.execute( + """ + INSERT INTO event_ledger ( + user_id, run_id, timestamp, logical_time, event + ) + VALUES (%s, %s, %s, %s, %s) + ON CONFLICT (user_id, run_id, event, logical_time) DO NOTHING + """, + ( + user_id, + run["run_id"], + now, + logical_time, + "SYSTEM_ARMED", + ), + ) + cur.execute( + """ + INSERT INTO event_ledger ( + user_id, run_id, timestamp, logical_time, event + ) + VALUES (%s, %s, %s, %s, %s) + ON CONFLICT (user_id, run_id, event, logical_time) DO NOTHING + """, + ( + user_id, + run["run_id"], + now, + logical_time, + "RUN_REARMED", + ), + ) + + armed_runs.append( + { + "run_id": run["run_id"], + "status": "RUNNING", + "next_run": next_run, + } + ) + if next_run_dt: + next_runs.append(next_run_dt) + + audit_meta = { + "run_count": len(armed_runs), + "ip": client_ip, + } + cur.execute( + """ + INSERT INTO admin_audit_log + (actor_user_hash, target_user_hash, target_username_hash, action, meta) + VALUES (%s, %s, %s, %s, %s) + """, + ( + _hash_value(user_id), + _hash_value(user_id), + _hash_value(username), + "SYSTEM_ARM", + Json(audit_meta), + ), + ) + + try: + resume_running_runs() + except Exception: + pass + + broker_state = get_user_broker(user_id) or {} + next_execution = min(next_runs).isoformat() if next_runs else None + return { + "ok": True, + "armed_runs": armed_runs, + "failed_runs": failed_runs, + "next_execution": next_execution, + "broker_state": { + "connected": bool(broker_state.get("connected")), + "auth_state": broker_state.get("auth_state"), + "broker": broker_state.get("broker"), + "user_name": broker_state.get("user_name"), + }, + } + + +def system_status(user_id: str): + broker_state = get_user_broker(user_id) or {} + with db_connection() as conn: + with conn.cursor() as cur: + cur.execute( + """ + SELECT sr.run_id, sr.status, sr.strategy, sr.mode, sr.broker, + sc.next_run, sc.active + FROM strategy_run sr + LEFT JOIN strategy_config sc + ON sc.user_id = sr.user_id AND sc.run_id = sr.run_id + WHERE sr.user_id = %s + ORDER BY sr.created_at DESC + """, + (user_id,), + ) + rows = cur.fetchall() + runs = [ + { + "run_id": row[0], + "status": row[1], + "strategy": row[2], + "mode": row[3], + "broker": row[4], + "next_run": row[5].isoformat() if row[5] else None, + "active": bool(row[6]) if row[6] is not None else False, + "lifecycle": row[1], + } + for row in rows + ] + return { + "runs": runs, + "broker_state": { + "connected": bool(broker_state.get("connected")), + "auth_state": broker_state.get("auth_state"), + "broker": broker_state.get("broker"), + "user_name": broker_state.get("user_name"), + }, + } diff --git a/backend/app/services/tenant.py b/backend/app/services/tenant.py new file mode 100644 index 0000000..5270cf0 --- /dev/null +++ b/backend/app/services/tenant.py @@ -0,0 +1,19 @@ +from fastapi import HTTPException, Request + +from app.services.auth_service import get_user_for_session +from app.services.run_service import get_default_user_id + +SESSION_COOKIE_NAME = "session_id" + + +def get_request_user_id(request: Request) -> str: + session_id = request.cookies.get(SESSION_COOKIE_NAME) + if session_id: + user = get_user_for_session(session_id) + if user: + return user["id"] + + default_user_id = get_default_user_id() + if default_user_id: + return default_user_id + raise HTTPException(status_code=401, detail="Not authenticated") diff --git a/backend/app/services/zerodha_service.py b/backend/app/services/zerodha_service.py new file mode 100644 index 0000000..a1d8214 --- /dev/null +++ b/backend/app/services/zerodha_service.py @@ -0,0 +1,89 @@ +import hashlib +import json +import os +import urllib.error +import urllib.parse +import urllib.request + + +KITE_API_BASE = os.getenv("KITE_API_BASE", "https://api.kite.trade") +KITE_LOGIN_URL = os.getenv("KITE_LOGIN_URL", "https://kite.trade/connect/login") +KITE_VERSION = "3" + + +class KiteApiError(Exception): + def __init__(self, status_code: int, error_type: str, message: str): + super().__init__(f"Kite API error {status_code}: {error_type} - {message}") + self.status_code = status_code + self.error_type = error_type + self.message = message + + +class KiteTokenError(KiteApiError): + pass + + +def build_login_url(api_key: str, redirect_url: str | None = None) -> str: + params = {"api_key": api_key, "v": KITE_VERSION} + redirect_url = (redirect_url or os.getenv("ZERODHA_REDIRECT_URL") or "").strip() + if redirect_url: + params["redirect_url"] = redirect_url + query = urllib.parse.urlencode(params) + return f"{KITE_LOGIN_URL}?{query}" + + +def _request(method: str, url: str, data: dict | None = None, headers: dict | None = None): + payload = None + if data is not None: + payload = urllib.parse.urlencode(data).encode("utf-8") + req = urllib.request.Request(url, data=payload, headers=headers or {}, method=method) + try: + with urllib.request.urlopen(req, timeout=20) as resp: + body = resp.read().decode("utf-8") + except urllib.error.HTTPError as err: + error_body = err.read().decode("utf-8") if err.fp else "" + try: + payload = json.loads(error_body) if error_body else {} + except json.JSONDecodeError: + payload = {} + error_type = payload.get("error_type") or payload.get("status") or "unknown_error" + message = payload.get("message") or error_body or err.reason + exc_cls = KiteTokenError if error_type == "TokenException" else KiteApiError + raise exc_cls(err.code, error_type, message) from err + return json.loads(body) + + +def _auth_headers(api_key: str, access_token: str) -> dict: + return { + "X-Kite-Version": KITE_VERSION, + "Authorization": f"token {api_key}:{access_token}", + } + + +def exchange_request_token(api_key: str, api_secret: str, request_token: str) -> dict: + checksum = hashlib.sha256( + f"{api_key}{request_token}{api_secret}".encode("utf-8") + ).hexdigest() + url = f"{KITE_API_BASE}/session/token" + response = _request( + "POST", + url, + data={ + "api_key": api_key, + "request_token": request_token, + "checksum": checksum, + }, + ) + return response.get("data", {}) + + +def fetch_holdings(api_key: str, access_token: str) -> list: + url = f"{KITE_API_BASE}/portfolio/holdings" + response = _request("GET", url, headers=_auth_headers(api_key, access_token)) + return response.get("data", []) + + +def fetch_funds(api_key: str, access_token: str) -> dict: + url = f"{KITE_API_BASE}/user/margins" + response = _request("GET", url, headers=_auth_headers(api_key, access_token)) + return response.get("data", {}) diff --git a/backend/app/services/zerodha_storage.py b/backend/app/services/zerodha_storage.py new file mode 100644 index 0000000..13a291e --- /dev/null +++ b/backend/app/services/zerodha_storage.py @@ -0,0 +1,125 @@ +from datetime import datetime, timezone + +from app.services.crypto_service import decrypt_value, encrypt_value +from app.services.db import db_transaction + + +def _row_to_session(row): + access_token = decrypt_value(row[1]) if row[1] else None + request_token = decrypt_value(row[2]) if row[2] else None + return { + "api_key": row[0], + "access_token": access_token, + "request_token": request_token, + "user_name": row[3], + "broker_user_id": row[4], + "linked_at": row[5], + } + + +def get_session(user_id: str): + with db_transaction() as cur: + cur.execute( + """ + SELECT api_key, access_token, request_token, user_name, broker_user_id, linked_at + FROM zerodha_session + WHERE user_id = %s + ORDER BY linked_at DESC NULLS LAST, id DESC + LIMIT 1 + """, + (user_id,), + ) + row = cur.fetchone() + if row: + return _row_to_session(row) + + with db_transaction() as cur: + cur.execute( + """ + SELECT broker, connected, access_token, api_key, user_name, broker_user_id, connected_at + FROM user_broker + WHERE user_id = %s + """, + (user_id,), + ) + row = cur.fetchone() + if not row: + return None + broker, connected, access_token, api_key, user_name, broker_user_id, connected_at = row + if not connected or not access_token or not api_key: + return None + if (broker or "").strip().upper() != "ZERODHA": + return None + return { + "api_key": api_key, + "access_token": decrypt_value(access_token), + "request_token": None, + "user_name": user_name, + "broker_user_id": broker_user_id, + "linked_at": connected_at, + } + + +def set_session(user_id: str, data: dict): + access_token = data.get("access_token") + request_token = data.get("request_token") + linked_at = datetime.now(timezone.utc) + with db_transaction() as cur: + cur.execute( + """ + INSERT INTO zerodha_session ( + user_id, linked_at, api_key, access_token, request_token, user_name, broker_user_id + ) + VALUES (%s, %s, %s, %s, %s, %s, %s) + RETURNING linked_at + """, + ( + user_id, + linked_at, + data.get("api_key"), + encrypt_value(access_token) if access_token else None, + encrypt_value(request_token) if request_token else None, + data.get("user_name"), + data.get("broker_user_id"), + ), + ) + linked_at = cur.fetchone()[0] + return { + **data, + "user_id": user_id, + "linked_at": linked_at, + "access_token": access_token, + "request_token": request_token, + } + + +def store_request_token(user_id: str, request_token: str): + with db_transaction() as cur: + cur.execute( + """ + INSERT INTO zerodha_request_token (user_id, request_token) + VALUES (%s, %s) + ON CONFLICT (user_id) + DO UPDATE SET request_token = EXCLUDED.request_token + """, + (user_id, encrypt_value(request_token)), + ) + + +def consume_request_token(user_id: str): + with db_transaction() as cur: + cur.execute( + "SELECT request_token FROM zerodha_request_token WHERE user_id = %s", + (user_id,), + ) + row = cur.fetchone() + if not row: + return None + cur.execute("DELETE FROM zerodha_request_token WHERE user_id = %s", (user_id,)) + return decrypt_value(row[0]) + + +def clear_session(user_id: str): + with db_transaction() as cur: + cur.execute("DELETE FROM zerodha_session WHERE user_id = %s", (user_id,)) + cur.execute("DELETE FROM zerodha_request_token WHERE user_id = %s", (user_id,)) diff --git a/backend/market.py b/backend/market.py new file mode 100644 index 0000000..c4b6df9 --- /dev/null +++ b/backend/market.py @@ -0,0 +1,91 @@ +from datetime import datetime, timezone +from pathlib import Path +from typing import Any, Dict +import sys +import time + +from fastapi import APIRouter + +PROJECT_ROOT = Path(__file__).resolve().parents[1] +if str(PROJECT_ROOT) not in sys.path: + sys.path.append(str(PROJECT_ROOT)) + +from indian_paper_trading_strategy.engine.data import fetch_live_price, get_price_snapshot + +NIFTY = "NIFTYBEES.NS" +GOLD = "GOLDBEES.NS" + +router = APIRouter(prefix="/api/market", tags=["market"]) + +_LTP_CACHE: Dict[str, Any] = { + "ts_epoch": 0.0, + "data": None, +} + +CACHE_TTL_SECONDS = 5 +STALE_SECONDS = 60 + + +@router.get("/ltp") +def get_ltp(allow_cache: bool = False): + now_epoch = time.time() + cached = _LTP_CACHE["data"] + if cached is not None and (now_epoch - _LTP_CACHE["ts_epoch"]) < CACHE_TTL_SECONDS: + return cached + + nifty_ltp = None + gold_ltp = None + try: + nifty_ltp = fetch_live_price(NIFTY) + except Exception: + nifty_ltp = None + try: + gold_ltp = fetch_live_price(GOLD) + except Exception: + gold_ltp = None + + nifty_meta = get_price_snapshot(NIFTY) or {} + gold_meta = get_price_snapshot(GOLD) or {} + now = datetime.now(timezone.utc) + + def _is_stale(meta: Dict[str, Any], ltp: float | None) -> bool: + if ltp is None: + return True + source = meta.get("source") + ts = meta.get("ts") + if source != "live": + return True + if isinstance(ts, datetime): + return (now - ts).total_seconds() > STALE_SECONDS + return False + + nifty_source = str(nifty_meta.get("source") or "").lower() + gold_source = str(gold_meta.get("source") or "").lower() + stale_map = { + NIFTY: _is_stale(nifty_meta, nifty_ltp), + GOLD: _is_stale(gold_meta, gold_ltp), + } + stale_any = stale_map[NIFTY] or stale_map[GOLD] + if allow_cache and stale_any: + cache_sources = {"cache", "cached", "history"} + if nifty_source in cache_sources and gold_source in cache_sources: + stale_map = {NIFTY: False, GOLD: False} + stale_any = False + + payload = { + "ts": now.isoformat(), + "ltp": { + NIFTY: float(nifty_ltp) if nifty_ltp is not None else None, + GOLD: float(gold_ltp) if gold_ltp is not None else None, + }, + "source": { + NIFTY: nifty_meta.get("source"), + GOLD: gold_meta.get("source"), + }, + "stale": stale_map, + "stale_any": stale_any, + } + + _LTP_CACHE["ts_epoch"] = now_epoch + _LTP_CACHE["data"] = payload + return payload diff --git a/backend/migrations/README b/backend/migrations/README new file mode 100644 index 0000000..98e4f9c --- /dev/null +++ b/backend/migrations/README @@ -0,0 +1 @@ +Generic single-database configuration. \ No newline at end of file diff --git a/backend/migrations/env.py b/backend/migrations/env.py new file mode 100644 index 0000000..db75070 --- /dev/null +++ b/backend/migrations/env.py @@ -0,0 +1,87 @@ +import os +from logging.config import fileConfig + +from sqlalchemy import engine_from_config +from sqlalchemy import pool + +from alembic import context +from app.services.db import Base, get_database_url +import app.db_models # noqa: F401 + +# this is the Alembic Config object, which provides +# access to the values within the .ini file in use. +config = context.config + +# Interpret the config file for Python logging. +# This line sets up loggers basically. +if config.config_file_name is not None: + fileConfig(config.config_file_name) + +db_url = get_database_url() +if "%" in db_url: + db_url = db_url.replace("%", "%%") +config.set_main_option("sqlalchemy.url", db_url) +schema_name = os.getenv("DB_SCHEMA") or os.getenv("PGSCHEMA") or "quant_app" + +target_metadata = Base.metadata + +# other values from the config, defined by the needs of env.py, +# can be acquired: +# my_important_option = config.get_main_option("my_important_option") +# ... etc. + + +def run_migrations_offline() -> None: + """Run migrations in 'offline' mode. + + This configures the context with just a URL + and not an Engine, though an Engine is acceptable + here as well. By skipping the Engine creation + we don't even need a DBAPI to be available. + + Calls to context.execute() here emit the given string to the + script output. + + """ + url = config.get_main_option("sqlalchemy.url") + context.configure( + url=url, + target_metadata=target_metadata, + literal_binds=True, + version_table_schema=schema_name, + dialect_opts={"paramstyle": "named"}, + ) + + with context.begin_transaction(): + context.run_migrations() + + +def run_migrations_online() -> None: + """Run migrations in 'online' mode. + + In this scenario we need to create an Engine + and associate a connection with the context. + + """ + configuration = config.get_section(config.config_ini_section, {}) + connectable = engine_from_config( + configuration, + prefix="sqlalchemy.", + poolclass=pool.NullPool, + ) + + with connectable.connect() as connection: + context.configure( + connection=connection, + target_metadata=target_metadata, + version_table_schema=schema_name, + ) + + with context.begin_transaction(): + context.run_migrations() + + +if context.is_offline_mode(): + run_migrations_offline() +else: + run_migrations_online() diff --git a/backend/migrations/script.py.mako b/backend/migrations/script.py.mako new file mode 100644 index 0000000..fbc4b07 --- /dev/null +++ b/backend/migrations/script.py.mako @@ -0,0 +1,26 @@ +"""${message} + +Revision ID: ${up_revision} +Revises: ${down_revision | comma,n} +Create Date: ${create_date} + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +${imports if imports else ""} + +# revision identifiers, used by Alembic. +revision: str = ${repr(up_revision)} +down_revision: Union[str, None] = ${repr(down_revision)} +branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)} +depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)} + + +def upgrade() -> None: + ${upgrades if upgrades else "pass"} + + +def downgrade() -> None: + ${downgrades if downgrades else "pass"} diff --git a/backend/migrations/versions/52abc790351d_initial_schema.py b/backend/migrations/versions/52abc790351d_initial_schema.py new file mode 100644 index 0000000..d0ee30f --- /dev/null +++ b/backend/migrations/versions/52abc790351d_initial_schema.py @@ -0,0 +1,674 @@ +"""initial_schema + +Revision ID: 52abc790351d +Revises: +Create Date: 2026-01-18 08:34:50.268181 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision: str = '52abc790351d' +down_revision: Union[str, None] = None +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.create_table('admin_audit_log', + sa.Column('id', sa.BigInteger(), autoincrement=True, nullable=False), + sa.Column('ts', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False), + sa.Column('actor_user_hash', sa.Text(), nullable=False), + sa.Column('target_user_hash', sa.Text(), nullable=False), + sa.Column('target_username_hash', sa.Text(), nullable=True), + sa.Column('action', sa.Text(), nullable=False), + sa.Column('meta', postgresql.JSONB(astext_type=sa.Text()), nullable=True), + sa.PrimaryKeyConstraint('id') + ) + op.create_table('admin_role_audit', + sa.Column('id', sa.BigInteger(), autoincrement=True, nullable=False), + sa.Column('actor_user_id', sa.String(), nullable=False), + sa.Column('target_user_id', sa.String(), nullable=False), + sa.Column('old_role', sa.String(), nullable=False), + sa.Column('new_role', sa.String(), nullable=False), + sa.Column('changed_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False), + sa.PrimaryKeyConstraint('id') + ) + op.create_table('app_user', + sa.Column('id', sa.String(), nullable=False), + sa.Column('username', sa.String(), nullable=False), + sa.Column('password_hash', sa.String(), nullable=False), + sa.Column('is_admin', sa.Boolean(), server_default=sa.text('false'), nullable=False), + sa.Column('is_super_admin', sa.Boolean(), server_default=sa.text('false'), nullable=False), + sa.Column('role', sa.String(), server_default=sa.text("'USER'"), nullable=False), + sa.CheckConstraint("role IN ('USER','ADMIN','SUPER_ADMIN')", name='chk_app_user_role'), + sa.PrimaryKeyConstraint('id'), + sa.UniqueConstraint('username') + ) + op.create_index('idx_app_user_is_admin', 'app_user', ['is_admin'], unique=False) + op.create_index('idx_app_user_is_super_admin', 'app_user', ['is_super_admin'], unique=False) + op.create_index('idx_app_user_role', 'app_user', ['role'], unique=False) + op.create_table('market_close', + sa.Column('symbol', sa.String(), nullable=False), + sa.Column('date', sa.Date(), nullable=False), + sa.Column('close', sa.Numeric(), nullable=False), + sa.PrimaryKeyConstraint('symbol', 'date') + ) + op.create_index('idx_market_close_date', 'market_close', ['date'], unique=False) + op.create_index('idx_market_close_symbol', 'market_close', ['symbol'], unique=False) + op.create_table('app_session', + sa.Column('id', sa.String(), nullable=False), + sa.Column('user_id', sa.String(), nullable=False), + sa.Column('created_at', sa.DateTime(timezone=True), nullable=False), + sa.Column('last_seen_at', sa.DateTime(timezone=True), nullable=True), + sa.Column('expires_at', sa.DateTime(timezone=True), nullable=False), + sa.ForeignKeyConstraint(['user_id'], ['app_user.id'], ondelete='CASCADE'), + sa.PrimaryKeyConstraint('id') + ) + op.create_index('idx_app_session_expires_at', 'app_session', ['expires_at'], unique=False) + op.create_index('idx_app_session_user_id', 'app_session', ['user_id'], unique=False) + op.create_table('strategy_run', + sa.Column('run_id', sa.String(), nullable=False), + sa.Column('user_id', sa.String(), nullable=False), + sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False), + sa.Column('started_at', sa.DateTime(timezone=True), nullable=True), + sa.Column('stopped_at', sa.DateTime(timezone=True), nullable=True), + sa.Column('status', sa.String(), nullable=False), + sa.Column('strategy', sa.String(), nullable=True), + sa.Column('mode', sa.String(), nullable=True), + sa.Column('broker', sa.String(), nullable=True), + sa.Column('meta', postgresql.JSONB(astext_type=sa.Text()), nullable=True), + sa.CheckConstraint("status IN ('RUNNING','STOPPED','ERROR')", name='chk_strategy_run_status'), + sa.ForeignKeyConstraint(['user_id'], ['app_user.id'], ondelete='CASCADE'), + sa.PrimaryKeyConstraint('run_id'), + sa.UniqueConstraint('user_id', 'run_id', name='uq_strategy_run_user_run') + ) + op.create_index('idx_strategy_run_user_created', 'strategy_run', ['user_id', 'created_at'], unique=False) + op.create_index('idx_strategy_run_user_status', 'strategy_run', ['user_id', 'status'], unique=False) + op.create_index('uq_one_running_run_per_user', 'strategy_run', ['user_id'], unique=True, postgresql_where=sa.text("status = 'RUNNING'")) + op.create_table('user_broker', + sa.Column('user_id', sa.String(), nullable=False), + sa.Column('broker', sa.String(), nullable=True), + sa.Column('connected', sa.Boolean(), server_default=sa.text('false'), nullable=False), + sa.Column('access_token', sa.Text(), nullable=True), + sa.Column('connected_at', sa.DateTime(timezone=True), nullable=True), + sa.Column('api_key', sa.Text(), nullable=True), + sa.Column('user_name', sa.Text(), nullable=True), + sa.Column('broker_user_id', sa.Text(), nullable=True), + sa.Column('pending_broker', sa.Text(), nullable=True), + sa.Column('pending_api_key', sa.Text(), nullable=True), + sa.Column('pending_api_secret', sa.Text(), nullable=True), + sa.Column('pending_started_at', sa.DateTime(timezone=True), nullable=True), + sa.ForeignKeyConstraint(['user_id'], ['app_user.id'], ondelete='CASCADE'), + sa.PrimaryKeyConstraint('user_id') + ) + op.create_index('idx_user_broker_broker', 'user_broker', ['broker'], unique=False) + op.create_index('idx_user_broker_connected', 'user_broker', ['connected'], unique=False) + op.create_table('zerodha_request_token', + sa.Column('user_id', sa.String(), nullable=False), + sa.Column('request_token', sa.Text(), nullable=False), + sa.ForeignKeyConstraint(['user_id'], ['app_user.id'], ondelete='CASCADE'), + sa.PrimaryKeyConstraint('user_id') + ) + op.create_table('zerodha_session', + sa.Column('id', sa.BigInteger(), autoincrement=True, nullable=False), + sa.Column('user_id', sa.String(), nullable=False), + sa.Column('linked_at', sa.DateTime(timezone=True), nullable=False), + sa.Column('api_key', sa.Text(), nullable=True), + sa.Column('access_token', sa.Text(), nullable=True), + sa.Column('request_token', sa.Text(), nullable=True), + sa.Column('user_name', sa.Text(), nullable=True), + sa.Column('broker_user_id', sa.Text(), nullable=True), + sa.ForeignKeyConstraint(['user_id'], ['app_user.id'], ondelete='CASCADE'), + sa.PrimaryKeyConstraint('id') + ) + op.create_index('idx_zerodha_session_linked_at', 'zerodha_session', ['linked_at'], unique=False) + op.create_index('idx_zerodha_session_user_id', 'zerodha_session', ['user_id'], unique=False) + op.create_table('engine_event', + sa.Column('id', sa.BigInteger(), autoincrement=True, nullable=False), + sa.Column('ts', sa.DateTime(timezone=True), nullable=False), + sa.Column('event', sa.String(), nullable=True), + sa.Column('data', postgresql.JSONB(astext_type=sa.Text()), nullable=True), + sa.Column('message', sa.Text(), nullable=True), + sa.Column('meta', postgresql.JSONB(astext_type=sa.Text()), nullable=True), + sa.Column('user_id', sa.String(), nullable=False), + sa.Column('run_id', sa.String(), nullable=False), + sa.ForeignKeyConstraint(['run_id'], ['strategy_run.run_id'], ondelete='CASCADE'), + sa.ForeignKeyConstraint(['user_id'], ['app_user.id'], ondelete='CASCADE'), + sa.PrimaryKeyConstraint('id') + ) + op.create_index('idx_engine_event_ts', 'engine_event', ['ts'], unique=False) + op.create_index('idx_engine_event_user_run_ts', 'engine_event', ['user_id', 'run_id', 'ts'], unique=False) + op.create_table('engine_state', + sa.Column('id', sa.BigInteger(), autoincrement=True, nullable=False), + sa.Column('user_id', sa.String(), nullable=False), + sa.Column('run_id', sa.String(), nullable=False), + sa.Column('total_invested', sa.Numeric(), nullable=True), + sa.Column('nifty_units', sa.Numeric(), nullable=True), + sa.Column('gold_units', sa.Numeric(), nullable=True), + sa.Column('last_sip_ts', sa.DateTime(timezone=True), nullable=True), + sa.Column('last_run', sa.DateTime(timezone=True), nullable=True), + sa.ForeignKeyConstraint(['user_id', 'run_id'], ['strategy_run.user_id', 'strategy_run.run_id'], ondelete='CASCADE'), + sa.ForeignKeyConstraint(['user_id'], ['app_user.id'], ondelete='CASCADE'), + sa.PrimaryKeyConstraint('id'), + sa.UniqueConstraint('user_id', 'run_id', name='uq_engine_state_user_run') + ) + op.create_table('engine_state_paper', + sa.Column('id', sa.BigInteger(), autoincrement=True, nullable=False), + sa.Column('user_id', sa.String(), nullable=False), + sa.Column('run_id', sa.String(), nullable=False), + sa.Column('initial_cash', sa.Numeric(), nullable=True), + sa.Column('cash', sa.Numeric(), nullable=True), + sa.Column('total_invested', sa.Numeric(), nullable=True), + sa.Column('nifty_units', sa.Numeric(), nullable=True), + sa.Column('gold_units', sa.Numeric(), nullable=True), + sa.Column('last_sip_ts', sa.DateTime(timezone=True), nullable=True), + sa.Column('last_run', sa.DateTime(timezone=True), nullable=True), + sa.Column('sip_frequency_value', sa.Integer(), nullable=True), + sa.Column('sip_frequency_unit', sa.String(), nullable=True), + sa.CheckConstraint('cash >= 0', name='chk_engine_state_paper_cash_non_negative'), + sa.ForeignKeyConstraint(['user_id', 'run_id'], ['strategy_run.user_id', 'strategy_run.run_id'], ondelete='CASCADE'), + sa.ForeignKeyConstraint(['user_id'], ['app_user.id'], ondelete='CASCADE'), + sa.PrimaryKeyConstraint('id'), + sa.UniqueConstraint('user_id', 'run_id', name='uq_engine_state_paper_user_run') + ) + op.create_table('engine_status', + sa.Column('id', sa.BigInteger(), autoincrement=True, nullable=False), + sa.Column('user_id', sa.String(), nullable=False), + sa.Column('run_id', sa.String(), nullable=False), + sa.Column('status', sa.String(), nullable=False), + sa.Column('last_updated', sa.DateTime(timezone=True), nullable=False), + sa.ForeignKeyConstraint(['user_id', 'run_id'], ['strategy_run.user_id', 'strategy_run.run_id'], ondelete='CASCADE'), + sa.ForeignKeyConstraint(['user_id'], ['app_user.id'], ondelete='CASCADE'), + sa.PrimaryKeyConstraint('id'), + sa.UniqueConstraint('user_id', 'run_id', name='uq_engine_status_user_run') + ) + op.create_index('idx_engine_status_user_run', 'engine_status', ['user_id', 'run_id'], unique=False) + op.create_table('event_ledger', + sa.Column('id', sa.BigInteger(), autoincrement=True, nullable=False), + sa.Column('user_id', sa.String(), nullable=False), + sa.Column('run_id', sa.String(), nullable=False), + sa.Column('timestamp', sa.DateTime(timezone=True), nullable=False), + sa.Column('logical_time', sa.DateTime(timezone=True), nullable=False), + sa.Column('event', sa.String(), nullable=False), + sa.Column('nifty_units', sa.Numeric(), nullable=True), + sa.Column('gold_units', sa.Numeric(), nullable=True), + sa.Column('nifty_price', sa.Numeric(), nullable=True), + sa.Column('gold_price', sa.Numeric(), nullable=True), + sa.Column('amount', sa.Numeric(), nullable=True), + sa.ForeignKeyConstraint(['user_id', 'run_id'], ['strategy_run.user_id', 'strategy_run.run_id'], ondelete='CASCADE'), + sa.ForeignKeyConstraint(['user_id'], ['app_user.id'], ondelete='CASCADE'), + sa.PrimaryKeyConstraint('id'), + sa.UniqueConstraint('user_id', 'run_id', 'event', 'logical_time', name='uq_event_ledger_event_time') + ) + op.create_index('idx_event_ledger_ts', 'event_ledger', ['timestamp'], unique=False) + op.create_index('idx_event_ledger_user_run_ts', 'event_ledger', ['user_id', 'run_id', 'timestamp'], unique=False) + op.create_table('mtm_ledger', + sa.Column('user_id', sa.String(), nullable=False), + sa.Column('run_id', sa.String(), nullable=False), + sa.Column('timestamp', sa.DateTime(timezone=True), nullable=False), + sa.Column('logical_time', sa.DateTime(timezone=True), nullable=False), + sa.Column('nifty_units', sa.Numeric(), nullable=True), + sa.Column('gold_units', sa.Numeric(), nullable=True), + sa.Column('nifty_price', sa.Numeric(), nullable=True), + sa.Column('gold_price', sa.Numeric(), nullable=True), + sa.Column('nifty_value', sa.Numeric(), nullable=True), + sa.Column('gold_value', sa.Numeric(), nullable=True), + sa.Column('portfolio_value', sa.Numeric(), nullable=True), + sa.Column('total_invested', sa.Numeric(), nullable=True), + sa.Column('pnl', sa.Numeric(), nullable=True), + sa.ForeignKeyConstraint(['user_id', 'run_id'], ['strategy_run.user_id', 'strategy_run.run_id'], ondelete='CASCADE'), + sa.PrimaryKeyConstraint('user_id', 'run_id', 'logical_time') + ) + op.create_index('idx_mtm_ledger_ts', 'mtm_ledger', ['timestamp'], unique=False) + op.create_index('idx_mtm_ledger_user_run_ts', 'mtm_ledger', ['user_id', 'run_id', 'timestamp'], unique=False) + op.create_table('paper_broker_account', + sa.Column('id', sa.BigInteger(), autoincrement=True, nullable=False), + sa.Column('user_id', sa.String(), nullable=False), + sa.Column('run_id', sa.String(), nullable=False), + sa.Column('cash', sa.Numeric(), nullable=False), + sa.CheckConstraint('cash >= 0', name='chk_paper_broker_cash_non_negative'), + sa.ForeignKeyConstraint(['user_id', 'run_id'], ['strategy_run.user_id', 'strategy_run.run_id'], ondelete='CASCADE'), + sa.ForeignKeyConstraint(['user_id'], ['app_user.id'], ondelete='CASCADE'), + sa.PrimaryKeyConstraint('id'), + sa.UniqueConstraint('user_id', 'run_id', name='uq_paper_broker_account_user_run') + ) + op.create_table('paper_equity_curve', + sa.Column('user_id', sa.String(), nullable=False), + sa.Column('run_id', sa.String(), nullable=False), + sa.Column('timestamp', sa.DateTime(timezone=True), nullable=False), + sa.Column('logical_time', sa.DateTime(timezone=True), nullable=False), + sa.Column('equity', sa.Numeric(), nullable=False), + sa.Column('pnl', sa.Numeric(), nullable=True), + sa.ForeignKeyConstraint(['user_id', 'run_id'], ['strategy_run.user_id', 'strategy_run.run_id'], ondelete='CASCADE'), + sa.PrimaryKeyConstraint('user_id', 'run_id', 'logical_time') + ) + op.create_index('idx_paper_equity_curve_ts', 'paper_equity_curve', ['timestamp'], unique=False) + op.create_index('idx_paper_equity_curve_user_run_ts', 'paper_equity_curve', ['user_id', 'run_id', 'timestamp'], unique=False) + op.create_table('paper_order', + sa.Column('id', sa.String(), nullable=False), + sa.Column('user_id', sa.String(), nullable=False), + sa.Column('run_id', sa.String(), nullable=False), + sa.Column('symbol', sa.String(), nullable=False), + sa.Column('side', sa.String(), nullable=False), + sa.Column('qty', sa.Numeric(), nullable=False), + sa.Column('price', sa.Numeric(), nullable=True), + sa.Column('status', sa.String(), nullable=False), + sa.Column('timestamp', sa.DateTime(timezone=True), nullable=False), + sa.Column('logical_time', sa.DateTime(timezone=True), nullable=False), + sa.CheckConstraint('price >= 0', name='chk_paper_order_price_non_negative'), + sa.CheckConstraint('qty > 0', name='chk_paper_order_qty_positive'), + sa.ForeignKeyConstraint(['user_id', 'run_id'], ['strategy_run.user_id', 'strategy_run.run_id'], ondelete='CASCADE'), + sa.ForeignKeyConstraint(['user_id'], ['app_user.id'], ondelete='CASCADE'), + sa.PrimaryKeyConstraint('id'), + sa.UniqueConstraint('user_id', 'run_id', 'id', name='uq_paper_order_scope_id'), + sa.UniqueConstraint('user_id', 'run_id', 'logical_time', 'symbol', 'side', name='uq_paper_order_logical_key') + ) + op.create_index('idx_paper_order_ts', 'paper_order', ['timestamp'], unique=False) + op.create_index('idx_paper_order_user_run_ts', 'paper_order', ['user_id', 'run_id', 'timestamp'], unique=False) + op.create_table('paper_position', + sa.Column('user_id', sa.String(), nullable=False), + sa.Column('run_id', sa.String(), nullable=False), + sa.Column('symbol', sa.String(), nullable=False), + sa.Column('qty', sa.Numeric(), nullable=False), + sa.Column('avg_price', sa.Numeric(), nullable=True), + sa.Column('last_price', sa.Numeric(), nullable=True), + sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False), + sa.CheckConstraint('qty > 0', name='chk_paper_position_qty_positive'), + sa.ForeignKeyConstraint(['user_id', 'run_id'], ['strategy_run.user_id', 'strategy_run.run_id'], ondelete='CASCADE'), + sa.PrimaryKeyConstraint('user_id', 'run_id', 'symbol'), + sa.UniqueConstraint('user_id', 'run_id', 'symbol', name='uq_paper_position_scope') + ) + op.create_index('idx_paper_position_user_run', 'paper_position', ['user_id', 'run_id'], unique=False) + op.create_table('strategy_config', + sa.Column('id', sa.BigInteger(), autoincrement=True, nullable=False), + sa.Column('strategy', sa.String(), nullable=True), + sa.Column('sip_amount', sa.Numeric(), nullable=True), + sa.Column('sip_frequency_value', sa.Integer(), nullable=True), + sa.Column('sip_frequency_unit', sa.String(), nullable=True), + sa.Column('mode', sa.String(), nullable=True), + sa.Column('broker', sa.String(), nullable=True), + sa.Column('active', sa.Boolean(), nullable=True), + sa.Column('frequency', sa.Text(), nullable=True), + sa.Column('frequency_days', sa.Integer(), nullable=True), + sa.Column('unit', sa.String(), nullable=True), + sa.Column('next_run', sa.DateTime(timezone=True), nullable=True), + sa.Column('user_id', sa.String(), nullable=False), + sa.Column('run_id', sa.String(), nullable=False), + sa.ForeignKeyConstraint(['run_id'], ['strategy_run.run_id'], ondelete='CASCADE'), + sa.ForeignKeyConstraint(['user_id'], ['app_user.id'], ondelete='CASCADE'), + sa.PrimaryKeyConstraint('id'), + sa.UniqueConstraint('user_id', 'run_id', name='uq_strategy_config_user_run') + ) + op.create_table('strategy_log', + sa.Column('seq', sa.BigInteger(), nullable=False), + sa.Column('ts', sa.DateTime(timezone=True), nullable=False), + sa.Column('level', sa.String(), nullable=True), + sa.Column('category', sa.String(), nullable=True), + sa.Column('event', sa.String(), nullable=True), + sa.Column('message', sa.Text(), nullable=True), + sa.Column('user_id', sa.String(), nullable=False), + sa.Column('run_id', sa.String(), nullable=False), + sa.Column('meta', postgresql.JSONB(astext_type=sa.Text()), nullable=True), + sa.ForeignKeyConstraint(['run_id'], ['strategy_run.run_id'], ondelete='CASCADE'), + sa.ForeignKeyConstraint(['user_id'], ['app_user.id'], ondelete='CASCADE'), + sa.PrimaryKeyConstraint('seq') + ) + op.create_index('idx_strategy_log_event', 'strategy_log', ['event'], unique=False) + op.create_index('idx_strategy_log_ts', 'strategy_log', ['ts'], unique=False) + op.create_index('idx_strategy_log_user_run_ts', 'strategy_log', ['user_id', 'run_id', 'ts'], unique=False) + op.create_table('paper_trade', + sa.Column('id', sa.String(), nullable=False), + sa.Column('order_id', sa.String(), nullable=True), + sa.Column('user_id', sa.String(), nullable=False), + sa.Column('run_id', sa.String(), nullable=False), + sa.Column('symbol', sa.String(), nullable=False), + sa.Column('side', sa.String(), nullable=False), + sa.Column('qty', sa.Numeric(), nullable=False), + sa.Column('price', sa.Numeric(), nullable=False), + sa.Column('timestamp', sa.DateTime(timezone=True), nullable=False), + sa.Column('logical_time', sa.DateTime(timezone=True), nullable=False), + sa.CheckConstraint('price >= 0', name='chk_paper_trade_price_non_negative'), + sa.CheckConstraint('qty > 0', name='chk_paper_trade_qty_positive'), + sa.ForeignKeyConstraint(['user_id', 'run_id', 'order_id'], ['paper_order.user_id', 'paper_order.run_id', 'paper_order.id'], ondelete='CASCADE'), + sa.ForeignKeyConstraint(['user_id', 'run_id'], ['strategy_run.user_id', 'strategy_run.run_id'], ondelete='CASCADE'), + sa.ForeignKeyConstraint(['user_id'], ['app_user.id'], ondelete='CASCADE'), + sa.PrimaryKeyConstraint('id'), + sa.UniqueConstraint('user_id', 'run_id', 'id', name='uq_paper_trade_scope_id'), + sa.UniqueConstraint('user_id', 'run_id', 'logical_time', 'symbol', 'side', name='uq_paper_trade_logical_key') + ) + op.create_index('idx_paper_trade_ts', 'paper_trade', ['timestamp'], unique=False) + op.create_index('idx_paper_trade_user_run_ts', 'paper_trade', ['user_id', 'run_id', 'timestamp'], unique=False) + # admin views and protections + op.execute( + """ + CREATE OR REPLACE FUNCTION prevent_super_admin_delete() + RETURNS trigger AS $$ + BEGIN + IF OLD.role = 'SUPER_ADMIN' OR OLD.is_super_admin THEN + RAISE EXCEPTION 'cannot delete super admin user'; + END IF; + RETURN OLD; + END; + $$ LANGUAGE plpgsql; + """ + ) + op.execute( + """ + DO $$ + BEGIN + IF NOT EXISTS (SELECT 1 FROM pg_trigger WHERE tgname = 'trg_prevent_super_admin_delete') THEN + CREATE TRIGGER trg_prevent_super_admin_delete + BEFORE DELETE ON app_user + FOR EACH ROW + EXECUTE FUNCTION prevent_super_admin_delete(); + END IF; + END $$; + """ + ) + op.create_index('idx_event_ledger_user_run_logical', 'event_ledger', ['user_id', 'run_id', 'logical_time'], unique=False) + op.execute( + """ + CREATE OR REPLACE VIEW admin_user_metrics AS + WITH session_stats AS ( + SELECT + user_id, + MIN(created_at) AS first_session_at, + MAX(COALESCE(last_seen_at, created_at)) AS last_login_at + FROM app_session + GROUP BY user_id + ), + run_stats AS ( + SELECT + user_id, + COUNT(*) AS runs_count, + MAX(CASE WHEN status = 'RUNNING' THEN run_id END) AS active_run_id, + MAX(CASE WHEN status = 'RUNNING' THEN status END) AS active_run_status, + MIN(created_at) AS first_run_at + FROM strategy_run + GROUP BY user_id + ), + broker_stats AS ( + SELECT user_id, BOOL_OR(connected) AS broker_connected + FROM user_broker + GROUP BY user_id + ) + SELECT + u.id AS user_id, + u.username, + u.role, + (u.role IN ('ADMIN','SUPER_ADMIN')) AS is_admin, + COALESCE(session_stats.first_session_at, run_stats.first_run_at) AS created_at, + session_stats.last_login_at, + COALESCE(run_stats.runs_count, 0) AS runs_count, + run_stats.active_run_id, + run_stats.active_run_status, + COALESCE(broker_stats.broker_connected, FALSE) AS broker_connected + FROM app_user u + LEFT JOIN session_stats ON session_stats.user_id = u.id + LEFT JOIN run_stats ON run_stats.user_id = u.id + LEFT JOIN broker_stats ON broker_stats.user_id = u.id; + """ + ) + op.execute( + """ + CREATE OR REPLACE VIEW admin_run_metrics AS + WITH order_stats AS ( + SELECT user_id, run_id, COUNT(*) AS order_count, MAX("timestamp") AS last_order_time + FROM paper_order + GROUP BY user_id, run_id + ), + trade_stats AS ( + SELECT user_id, run_id, COUNT(*) AS trade_count, MAX("timestamp") AS last_trade_time + FROM paper_trade + GROUP BY user_id, run_id + ), + event_stats AS ( + SELECT + user_id, + run_id, + MAX("timestamp") AS last_event_time, + MAX(CASE WHEN event = 'SIP_EXECUTED' THEN "timestamp" END) AS last_sip_time + FROM event_ledger + GROUP BY user_id, run_id + ), + equity_latest AS ( + SELECT DISTINCT ON (user_id, run_id) + user_id, + run_id, + equity AS equity_latest, + pnl AS pnl_latest, + "timestamp" AS equity_ts + FROM paper_equity_curve + ORDER BY user_id, run_id, "timestamp" DESC + ), + mtm_latest AS ( + SELECT DISTINCT ON (user_id, run_id) + user_id, + run_id, + "timestamp" AS mtm_ts + FROM mtm_ledger + ORDER BY user_id, run_id, "timestamp" DESC + ), + log_latest AS ( + SELECT user_id, run_id, MAX(ts) AS last_log_time + FROM strategy_log + GROUP BY user_id, run_id + ), + engine_latest AS ( + SELECT user_id, run_id, MAX(ts) AS last_engine_time + FROM engine_event + GROUP BY user_id, run_id + ), + activity AS ( + SELECT user_id, run_id, MAX(ts) AS last_event_time + FROM ( + SELECT user_id, run_id, ts FROM engine_event + UNION ALL + SELECT user_id, run_id, ts FROM strategy_log + UNION ALL + SELECT user_id, run_id, "timestamp" AS ts FROM paper_order + UNION ALL + SELECT user_id, run_id, "timestamp" AS ts FROM paper_trade + UNION ALL + SELECT user_id, run_id, "timestamp" AS ts FROM mtm_ledger + UNION ALL + SELECT user_id, run_id, "timestamp" AS ts FROM paper_equity_curve + UNION ALL + SELECT user_id, run_id, "timestamp" AS ts FROM event_ledger + ) t + GROUP BY user_id, run_id + ) + SELECT + sr.run_id, + sr.user_id, + sr.status, + sr.created_at, + sr.started_at, + sr.stopped_at, + sr.strategy, + sr.mode, + sr.broker, + sc.sip_amount, + sc.sip_frequency_value, + sc.sip_frequency_unit, + sc.next_run AS next_sip_time, + activity.last_event_time, + event_stats.last_sip_time, + COALESCE(order_stats.order_count, 0) AS order_count, + COALESCE(trade_stats.trade_count, 0) AS trade_count, + equity_latest.equity_latest, + equity_latest.pnl_latest + FROM strategy_run sr + LEFT JOIN strategy_config sc + ON sc.user_id = sr.user_id AND sc.run_id = sr.run_id + LEFT JOIN order_stats + ON order_stats.user_id = sr.user_id AND order_stats.run_id = sr.run_id + LEFT JOIN trade_stats + ON trade_stats.user_id = sr.user_id AND trade_stats.run_id = sr.run_id + LEFT JOIN event_stats + ON event_stats.user_id = sr.user_id AND event_stats.run_id = sr.run_id + LEFT JOIN equity_latest + ON equity_latest.user_id = sr.user_id AND equity_latest.run_id = sr.run_id + LEFT JOIN mtm_latest + ON mtm_latest.user_id = sr.user_id AND mtm_latest.run_id = sr.run_id + LEFT JOIN log_latest + ON log_latest.user_id = sr.user_id AND log_latest.run_id = sr.run_id + LEFT JOIN engine_latest + ON engine_latest.user_id = sr.user_id AND engine_latest.run_id = sr.run_id + LEFT JOIN activity + ON activity.user_id = sr.user_id AND activity.run_id = sr.run_id; + """ + ) + op.execute( + """ + CREATE OR REPLACE VIEW admin_engine_health AS + WITH activity AS ( + SELECT user_id, run_id, MAX(ts) AS last_event_time + FROM ( + SELECT user_id, run_id, ts FROM engine_event + UNION ALL + SELECT user_id, run_id, ts FROM strategy_log + UNION ALL + SELECT user_id, run_id, "timestamp" AS ts FROM event_ledger + ) t + GROUP BY user_id, run_id + ) + SELECT + sr.run_id, + sr.user_id, + sr.status, + activity.last_event_time, + es.status AS engine_status, + es.last_updated AS engine_status_ts + FROM strategy_run sr + LEFT JOIN activity + ON activity.user_id = sr.user_id AND activity.run_id = sr.run_id + LEFT JOIN engine_status es + ON es.user_id = sr.user_id AND es.run_id = sr.run_id; + """ + ) + op.execute( + """ + CREATE OR REPLACE VIEW admin_order_stats AS + SELECT + user_id, + run_id, + COUNT(*) AS total_orders, + COUNT(*) FILTER (WHERE "timestamp" >= now() - interval '24 hours') AS orders_last_24h, + COUNT(*) FILTER (WHERE status = 'FILLED') AS filled_orders + FROM paper_order + GROUP BY user_id, run_id; + """ + ) + op.execute( + """ + CREATE OR REPLACE VIEW admin_ledger_stats AS + WITH mtm_latest AS ( + SELECT DISTINCT ON (user_id, run_id) + user_id, + run_id, + portfolio_value, + pnl, + "timestamp" AS mtm_ts + FROM mtm_ledger + ORDER BY user_id, run_id, "timestamp" DESC + ), + equity_latest AS ( + SELECT DISTINCT ON (user_id, run_id) + user_id, + run_id, + equity, + pnl, + "timestamp" AS equity_ts + FROM paper_equity_curve + ORDER BY user_id, run_id, "timestamp" DESC + ) + SELECT + sr.user_id, + sr.run_id, + mtm_latest.portfolio_value AS mtm_value, + mtm_latest.pnl AS mtm_pnl, + mtm_latest.mtm_ts, + equity_latest.equity AS equity_value, + equity_latest.pnl AS equity_pnl, + equity_latest.equity_ts + FROM strategy_run sr + LEFT JOIN mtm_latest + ON mtm_latest.user_id = sr.user_id AND mtm_latest.run_id = sr.run_id + LEFT JOIN equity_latest + ON equity_latest.user_id = sr.user_id AND equity_latest.run_id = sr.run_id; + """ + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.execute("DROP VIEW IF EXISTS admin_ledger_stats;") + op.execute("DROP VIEW IF EXISTS admin_order_stats;") + op.execute("DROP VIEW IF EXISTS admin_engine_health;") + op.execute("DROP VIEW IF EXISTS admin_run_metrics;") + op.execute("DROP VIEW IF EXISTS admin_user_metrics;") + op.execute("DROP TRIGGER IF EXISTS trg_prevent_super_admin_delete ON app_user;") + op.execute("DROP FUNCTION IF EXISTS prevent_super_admin_delete;") + op.drop_index('idx_paper_trade_user_run_ts', table_name='paper_trade') + op.drop_index('idx_paper_trade_ts', table_name='paper_trade') + op.drop_table('paper_trade') + op.drop_index('idx_strategy_log_user_run_ts', table_name='strategy_log') + op.drop_index('idx_strategy_log_ts', table_name='strategy_log') + op.drop_index('idx_strategy_log_event', table_name='strategy_log') + op.drop_table('strategy_log') + op.drop_table('strategy_config') + op.drop_index('idx_paper_position_user_run', table_name='paper_position') + op.drop_table('paper_position') + op.drop_index('idx_paper_order_user_run_ts', table_name='paper_order') + op.drop_index('idx_paper_order_ts', table_name='paper_order') + op.drop_table('paper_order') + op.drop_index('idx_paper_equity_curve_user_run_ts', table_name='paper_equity_curve') + op.drop_index('idx_paper_equity_curve_ts', table_name='paper_equity_curve') + op.drop_table('paper_equity_curve') + op.drop_table('paper_broker_account') + op.drop_index('idx_mtm_ledger_user_run_ts', table_name='mtm_ledger') + op.drop_index('idx_mtm_ledger_ts', table_name='mtm_ledger') + op.drop_table('mtm_ledger') + op.drop_index('idx_event_ledger_user_run_logical', table_name='event_ledger') + op.drop_index('idx_event_ledger_user_run_ts', table_name='event_ledger') + op.drop_index('idx_event_ledger_ts', table_name='event_ledger') + op.drop_table('event_ledger') + op.drop_index('idx_engine_status_user_run', table_name='engine_status') + op.drop_table('engine_status') + op.drop_table('engine_state_paper') + op.drop_table('engine_state') + op.drop_index('idx_engine_event_user_run_ts', table_name='engine_event') + op.drop_index('idx_engine_event_ts', table_name='engine_event') + op.drop_table('engine_event') + op.drop_index('idx_zerodha_session_user_id', table_name='zerodha_session') + op.drop_index('idx_zerodha_session_linked_at', table_name='zerodha_session') + op.drop_table('zerodha_session') + op.drop_table('zerodha_request_token') + op.drop_index('idx_user_broker_connected', table_name='user_broker') + op.drop_index('idx_user_broker_broker', table_name='user_broker') + op.drop_table('user_broker') + op.drop_index('uq_one_running_run_per_user', table_name='strategy_run', postgresql_where=sa.text("status = 'RUNNING'")) + op.drop_index('idx_strategy_run_user_status', table_name='strategy_run') + op.drop_index('idx_strategy_run_user_created', table_name='strategy_run') + op.drop_table('strategy_run') + op.drop_index('idx_app_session_user_id', table_name='app_session') + op.drop_index('idx_app_session_expires_at', table_name='app_session') + op.drop_table('app_session') + op.drop_index('idx_market_close_symbol', table_name='market_close') + op.drop_index('idx_market_close_date', table_name='market_close') + op.drop_table('market_close') + op.drop_index('idx_app_user_role', table_name='app_user') + op.drop_index('idx_app_user_is_super_admin', table_name='app_user') + op.drop_index('idx_app_user_is_admin', table_name='app_user') + op.drop_table('app_user') + op.drop_table('admin_role_audit') + op.drop_table('admin_audit_log') + # ### end Alembic commands ### diff --git a/backend/paper_mtm.py b/backend/paper_mtm.py new file mode 100644 index 0000000..f9192f0 --- /dev/null +++ b/backend/paper_mtm.py @@ -0,0 +1,76 @@ +from typing import Any, Dict +from pathlib import Path +import sys + +from fastapi import APIRouter, Request + +from app.services.paper_broker_service import get_paper_broker +from app.services.tenant import get_request_user_id +from app.services.run_service import get_active_run_id +PROJECT_ROOT = Path(__file__).resolve().parents[1] +if str(PROJECT_ROOT) not in sys.path: + sys.path.append(str(PROJECT_ROOT)) + +from indian_paper_trading_strategy.engine.db import engine_context +from market import get_ltp + +from indian_paper_trading_strategy.engine.state import load_state + +router = APIRouter(prefix="/api/paper", tags=["paper-mtm"]) + + +@router.get("/mtm") +def paper_mtm(request: Request) -> Dict[str, Any]: + user_id = get_request_user_id(request) + run_id = get_active_run_id(user_id) + with engine_context(user_id, run_id): + broker = get_paper_broker(user_id) + + positions = broker.get_positions() + state = load_state(mode="PAPER") + cash = float(state.get("cash", 0)) + initial_cash = float(state.get("initial_cash", 0)) + + ltp_payload = get_ltp(allow_cache=True) + ltp_map = ltp_payload["ltp"] + + mtm_positions = [] + positions_value = 0.0 + + for pos in positions: + symbol = pos.get("symbol") + if not symbol: + continue + qty = float(pos.get("qty", 0)) + avg_price = float(pos.get("avg_price", 0)) + ltp = ltp_map.get(symbol) + if ltp is None: + continue + + pnl = (ltp - avg_price) * qty + positions_value += qty * ltp + + mtm_positions.append( + { + "symbol": symbol, + "qty": qty, + "avg_price": avg_price, + "ltp": ltp, + "pnl": pnl, + } + ) + + equity = cash + positions_value + unrealized_pnl = equity - float(initial_cash) + + return { + "ts": ltp_payload["ts"], + "initial_cash": initial_cash, + "cash": cash, + "positions_value": positions_value, + "equity": equity, + "unrealized_pnl": unrealized_pnl, + "positions": mtm_positions, + "price_stale": ltp_payload.get("stale_any", False), + "price_source": ltp_payload.get("source", {}), + } diff --git a/backend/requirements.txt b/backend/requirements.txt new file mode 100644 index 0000000..1036150 --- /dev/null +++ b/backend/requirements.txt @@ -0,0 +1,43 @@ +annotated-doc==0.0.4 +annotated-types==0.7.0 +anyio==4.12.1 +beautifulsoup4==4.14.3 +certifi==2026.1.4 +cffi==2.0.0 +charset-normalizer==3.4.4 +click==8.3.1 +colorama==0.4.6 +cryptography==46.0.3 +curl_cffi==0.13.0 +fastapi==0.128.0 +frozendict==2.4.7 +h11==0.16.0 +idna==3.11 +httpx==0.27.2 +multitasking==0.0.12 +numpy==2.4.1 +pandas==2.3.3 +peewee==3.19.0 +platformdirs==4.5.1 +protobuf==6.33.4 +psycopg2-binary==2.9.11 +SQLAlchemy==2.0.36 +pycparser==2.23 +pydantic==2.12.5 +pydantic_core==2.41.5 +python-dateutil==2.9.0.post0 +pytz==2025.2 +requests==2.32.5 +six==1.17.0 +soupsieve==2.8.1 +starlette==0.50.0 +ta==0.11.0 +typing-inspection==0.4.2 +typing_extensions==4.15.0 +tzdata==2025.3 +urllib3==2.6.3 +uvicorn==0.40.0 +websockets==16.0 +yfinance==1.0 +alembic==1.13.3 +pytest==8.3.5 diff --git a/backend/run_backend.ps1 b/backend/run_backend.ps1 new file mode 100644 index 0000000..e312e4d --- /dev/null +++ b/backend/run_backend.ps1 @@ -0,0 +1,29 @@ +Set-Location $PSScriptRoot +if (-not $env:DB_HOST) { $env:DB_HOST = 'localhost' } +if (-not $env:DB_PORT) { $env:DB_PORT = '5432' } +if (-not $env:DB_NAME) { $env:DB_NAME = 'trading_db' } +if (-not $env:DB_USER) { $env:DB_USER = 'trader' } +if (-not $env:DB_PASSWORD) { $env:DB_PASSWORD = 'traderpass' } +if (-not $env:DB_SCHEMA) { $env:DB_SCHEMA = 'quant_app' } +if (-not $env:DB_CONNECT_TIMEOUT) { $env:DB_CONNECT_TIMEOUT = '5' } +$frontendUrlFile = Join-Path (Split-Path -Parent $PSScriptRoot) 'ngrok_frontend_url.txt' +$env:ZERODHA_REDIRECT_URL = 'http://localhost:3000/login' +if (Test-Path $frontendUrlFile) { + $frontendUrl = (Get-Content $frontendUrlFile -Raw).Trim() + if ($frontendUrl) { + $env:CORS_ORIGINS = "http://localhost:3000,http://127.0.0.1:3000,$frontendUrl" + $env:COOKIE_SECURE = '1' + $env:COOKIE_SAMESITE = 'none' + $env:ZERODHA_REDIRECT_URL = "$frontendUrl/login" + } +} +if (-not $env:BROKER_TOKEN_KEY) { $env:BROKER_TOKEN_KEY = 'CHANGE_ME' } +if (-not $env:SUPER_ADMIN_EMAIL) { $env:SUPER_ADMIN_EMAIL = 'admin@example.com' } +if (-not $env:SUPER_ADMIN_PASSWORD) { $env:SUPER_ADMIN_PASSWORD = 'AdminPass123!' } +if (-not $env:SMTP_HOST) { $env:SMTP_HOST = 'smtp.gmail.com' } +if (-not $env:SMTP_PORT) { $env:SMTP_PORT = '587' } +if (-not $env:SMTP_USER) { $env:SMTP_USER = 'quantfortune@gmail.com' } +if (-not $env:SMTP_PASS) { $env:SMTP_PASS = 'CHANGE_ME' } +if (-not $env:SMTP_FROM_NAME) { $env:SMTP_FROM_NAME = 'Quantfortune Support' } +if (-not $env:RESET_OTP_SECRET) { $env:RESET_OTP_SECRET = 'CHANGE_ME' } +.\venv\Scripts\uvicorn.exe app.main:app --host 0.0.0.0 --port 8000 diff --git a/docker-compose.yml b/docker-compose.yml new file mode 100644 index 0000000..10ee50f --- /dev/null +++ b/docker-compose.yml @@ -0,0 +1,18 @@ +version: "3.9" + +services: + postgres: + image: postgres:15 + container_name: trading_postgres + restart: unless-stopped + environment: + POSTGRES_USER: trader + POSTGRES_PASSWORD: traderpass + POSTGRES_DB: trading_db + ports: + - "5432:5432" + volumes: + - pgdata:/var/lib/postgresql/data + +volumes: + pgdata: diff --git a/indian_paper_trading_strategy/app/streamlit_app.py b/indian_paper_trading_strategy/app/streamlit_app.py new file mode 100644 index 0000000..8ce6e63 --- /dev/null +++ b/indian_paper_trading_strategy/app/streamlit_app.py @@ -0,0 +1,208 @@ +import streamlit as st +import time +from datetime import datetime +from pathlib import Path +import sys +import pandas as pd + +PROJECT_ROOT = Path(__file__).resolve().parents[2] +if str(PROJECT_ROOT) not in sys.path: + sys.path.insert(0, str(PROJECT_ROOT)) + +from indian_paper_trading_strategy.engine.history import load_monthly_close +from indian_paper_trading_strategy.engine.market import india_market_status +from indian_paper_trading_strategy.engine.data import fetch_live_price +from indian_paper_trading_strategy.engine.strategy import allocation +from indian_paper_trading_strategy.engine.execution import try_execute_sip +from indian_paper_trading_strategy.engine.state import load_state, save_state +from indian_paper_trading_strategy.engine.db import db_connection, insert_engine_event, run_with_retry, get_default_user_id, get_active_run_id, set_context +from indian_paper_trading_strategy.engine.mtm import log_mtm, should_log_mtm + +_STREAMLIT_USER_ID = get_default_user_id() +_STREAMLIT_RUN_ID = get_active_run_id(_STREAMLIT_USER_ID) if _STREAMLIT_USER_ID else None +if _STREAMLIT_USER_ID and _STREAMLIT_RUN_ID: + set_context(_STREAMLIT_USER_ID, _STREAMLIT_RUN_ID) + +def reset_runtime_state(): + def _op(cur, _conn): + cur.execute( + "DELETE FROM mtm_ledger WHERE user_id = %s AND run_id = %s", + (_STREAMLIT_USER_ID, _STREAMLIT_RUN_ID), + ) + cur.execute( + "DELETE FROM event_ledger WHERE user_id = %s AND run_id = %s", + (_STREAMLIT_USER_ID, _STREAMLIT_RUN_ID), + ) + cur.execute( + "DELETE FROM engine_state WHERE user_id = %s AND run_id = %s", + (_STREAMLIT_USER_ID, _STREAMLIT_RUN_ID), + ) + insert_engine_event(cur, "LIVE_RESET", data={}) + + run_with_retry(_op) + +def load_mtm_df(): + with db_connection() as conn: + return pd.read_sql_query( + "SELECT timestamp, pnl FROM mtm_ledger WHERE user_id = %s AND run_id = %s ORDER BY timestamp", + conn, + params=(_STREAMLIT_USER_ID, _STREAMLIT_RUN_ID), + ) + +def is_engine_running(): + state = load_state(mode="LIVE") + return state.get("total_invested", 0) > 0 or \ + state.get("nifty_units", 0) > 0 or \ + state.get("gold_units", 0) > 0 + +if "engine_active" not in st.session_state: + st.session_state.engine_active = is_engine_running() + +NIFTY = "NIFTYBEES.NS" +GOLD = "GOLDBEES.NS" +SMA_MONTHS = 36 + +def get_prices(): + try: + nifty = fetch_live_price(NIFTY) + gold = fetch_live_price(GOLD) + return nifty, gold + except Exception as e: + st.error(e) + return None, None + +SIP_AMOUNT = st.number_input("SIP Amount (\u20B9)", 500, 100000, 5000) +SIP_INTERVAL_SEC = st.number_input("SIP Interval (sec) [TEST]", 30, 3600, 120) +REFRESH_SEC = st.slider("Refresh interval (sec)", 5, 60, 10) + +st.title("SIPXAR INDIA - Phase-1 Safe Engine") + +market_open, market_time = india_market_status() +st.info(f"NSE Market {'OPEN' if market_open else 'CLOSED'} | IST {market_time}") +if not market_open: + st.info("Market is closed. Portfolio values are frozen at last available prices.") + +col1, col2 = st.columns(2) + +with col1: + if st.button("START ENGINE"): + if is_engine_running(): + st.info("Engine already running. Resuming.") + st.session_state.engine_active = True + else: + st.session_state.engine_active = True + + # HARD RESET ONLY ON FIRST START + reset_runtime_state() + + save_state({ + "total_invested": 0.0, + "nifty_units": 0.0, + "gold_units": 0.0, + "last_sip_ts": None, + }, mode="LIVE", emit_event=True, event_meta={"source": "streamlit_start"}) + + st.success("Engine started") + +with col2: + if st.button("KILL ENGINE"): + st.session_state.engine_active = False + + reset_runtime_state() + + st.warning("Engine killed and state wiped") + st.stop() + +if not st.session_state.engine_active: + st.stop() + +state = load_state(mode="LIVE") +nifty_price, gold_price = get_prices() + +if nifty_price is None: + st.stop() + +st.subheader("Latest Market Prices (LTP)") + +c1, c2 = st.columns(2) + +with c1: + st.metric( + label="NIFTYBEES", + value=f"\u20B9{nifty_price:,.2f}", + help="Last traded price (delayed)" + ) + +with c2: + st.metric( + label="GOLDBEES", + value=f"\u20B9{gold_price:,.2f}", + help="Last traded price (delayed)" + ) + +st.caption(f"Price timestamp: {datetime.now().strftime('%H:%M:%S')}") + +nifty_hist = load_monthly_close(NIFTY) +gold_hist = load_monthly_close(GOLD) + +nifty_sma = nifty_hist.rolling(SMA_MONTHS).mean().iloc[-1] +gold_sma = gold_hist.rolling(SMA_MONTHS).mean().iloc[-1] + +eq_w, gd_w = allocation( + sp_price=nifty_price, + gd_price=gold_price, + sp_sma=nifty_sma, + gd_sma=gold_sma +) + +state, executed = try_execute_sip( + now=datetime.now(), + market_open=market_open, + sip_interval=SIP_INTERVAL_SEC, + sip_amount=SIP_AMOUNT, + sp_price=nifty_price, + gd_price=gold_price, + eq_w=eq_w, + gd_w=gd_w, + mode="LIVE", +) + +now = datetime.now() + +if market_open and should_log_mtm(None, now): + portfolio_value, pnl = log_mtm( + nifty_units=state["nifty_units"], + gold_units=state["gold_units"], + nifty_price=nifty_price, + gold_price=gold_price, + total_invested=state["total_invested"], + ) +else: + # Market closed -> freeze valuation (do NOT log) + portfolio_value = ( + state["nifty_units"] * nifty_price + + state["gold_units"] * gold_price + ) + pnl = portfolio_value - state["total_invested"] + +st.subheader("Equity Curve (Unrealized PnL)") + +mtm_df = load_mtm_df() + +if "timestamp" in mtm_df.columns and "pnl" in mtm_df.columns and len(mtm_df) > 1: + mtm_df["timestamp"] = pd.to_datetime(mtm_df["timestamp"]) + mtm_df = mtm_df.sort_values("timestamp").set_index("timestamp") + + st.line_chart(mtm_df["pnl"], height=350) +else: + st.warning("Not enough MTM data or missing columns. Expected: timestamp, pnl.") + +st.metric("Total Invested", f"\u20B9{state['total_invested']:,.0f}") +st.metric("NIFTY Units", round(state["nifty_units"], 4)) +st.metric("Gold Units", round(state["gold_units"], 4)) +st.metric("Portfolio Value", f"\u20B9{portfolio_value:,.0f}") +st.metric("PnL", f"\u20B9{pnl:,.0f}") + +time.sleep(REFRESH_SEC) +st.rerun() + diff --git a/indian_paper_trading_strategy/engine/__init__.py b/indian_paper_trading_strategy/engine/__init__.py new file mode 100644 index 0000000..bf0d74e --- /dev/null +++ b/indian_paper_trading_strategy/engine/__init__.py @@ -0,0 +1 @@ +"""Engine package for the India paper trading strategy.""" diff --git a/indian_paper_trading_strategy/engine/broker.py b/indian_paper_trading_strategy/engine/broker.py new file mode 100644 index 0000000..cc3ec87 --- /dev/null +++ b/indian_paper_trading_strategy/engine/broker.py @@ -0,0 +1,697 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from dataclasses import dataclass +from datetime import datetime, timezone +import hashlib + +from psycopg2.extras import execute_values + +from indian_paper_trading_strategy.engine.data import fetch_live_price +from indian_paper_trading_strategy.engine.db import db_connection, insert_engine_event, run_with_retry, get_context + + +class Broker(ABC): + @abstractmethod + def place_order( + self, + symbol: str, + side: str, + quantity: float, + price: float | None = None, + logical_time: datetime | None = None, + ): + raise NotImplementedError + + @abstractmethod + def get_positions(self): + raise NotImplementedError + + @abstractmethod + def get_orders(self): + raise NotImplementedError + + @abstractmethod + def get_funds(self): + raise NotImplementedError + + +def _local_tz(): + return datetime.now().astimezone().tzinfo + + +def _format_utc_ts(value: datetime | None): + if value is None: + return None + if value.tzinfo is None: + value = value.replace(tzinfo=_local_tz()) + return value.astimezone(timezone.utc).isoformat().replace("+00:00", "Z") + + +def _format_local_ts(value: datetime | None): + if value is None: + return None + if value.tzinfo is None: + value = value.replace(tzinfo=_local_tz()) + return value.astimezone(_local_tz()).replace(tzinfo=None).isoformat() + + +def _parse_ts(value, assume_local: bool = True): + if value is None: + return None + if isinstance(value, datetime): + if value.tzinfo is None: + return value.replace(tzinfo=_local_tz() if assume_local else timezone.utc) + return value + if isinstance(value, str): + text = value.strip() + if not text: + return None + if text.endswith("Z"): + try: + return datetime.fromisoformat(text.replace("Z", "+00:00")) + except ValueError: + return None + try: + parsed = datetime.fromisoformat(text) + except ValueError: + return None + if parsed.tzinfo is None: + parsed = parsed.replace(tzinfo=_local_tz() if assume_local else timezone.utc) + return parsed + return None + + +def _stable_num(value: float) -> str: + return f"{float(value):.12f}" + + +def _normalize_ts_for_id(ts: datetime) -> str: + if ts.tzinfo is None: + ts = ts.replace(tzinfo=timezone.utc) + return ts.astimezone(timezone.utc).replace(microsecond=0).isoformat() + + +def _deterministic_id(prefix: str, parts: list[str]) -> str: + payload = "|".join(parts) + digest = hashlib.sha1(payload.encode("utf-8")).hexdigest()[:16] + return f"{prefix}_{digest}" + + +def _resolve_scope(user_id: str | None, run_id: str | None): + return get_context(user_id, run_id) + + +@dataclass +class PaperBroker(Broker): + initial_cash: float + store_path: str | None = None + + def _default_store(self): + return { + "cash": float(self.initial_cash), + "positions": {}, + "orders": [], + "trades": [], + "equity_curve": [], + } + + def _load_store(self, cur=None, for_update: bool = False, user_id: str | None = None, run_id: str | None = None): + scope_user, scope_run = _resolve_scope(user_id, run_id) + if cur is None: + with db_connection() as conn: + with conn.cursor() as cur: + return self._load_store( + cur=cur, + for_update=for_update, + user_id=scope_user, + run_id=scope_run, + ) + + store = self._default_store() + lock_clause = " FOR UPDATE" if for_update else "" + cur.execute( + f"SELECT cash FROM paper_broker_account WHERE user_id = %s AND run_id = %s{lock_clause} LIMIT 1", + (scope_user, scope_run), + ) + row = cur.fetchone() + if row and row[0] is not None: + store["cash"] = float(row[0]) + + cur.execute( + f""" + SELECT symbol, qty, avg_price, last_price + FROM paper_position + WHERE user_id = %s AND run_id = %s{lock_clause} + """ + , + (scope_user, scope_run), + ) + positions = {} + for symbol, qty, avg_price, last_price in cur.fetchall(): + positions[symbol] = { + "qty": float(qty) if qty is not None else 0.0, + "avg_price": float(avg_price) if avg_price is not None else 0.0, + "last_price": float(last_price) if last_price is not None else 0.0, + } + store["positions"] = positions + + cur.execute( + """ + SELECT id, symbol, side, qty, price, status, timestamp, logical_time + FROM paper_order + WHERE user_id = %s AND run_id = %s + ORDER BY timestamp, id + """ + , + (scope_user, scope_run), + ) + orders = [] + for order_id, symbol, side, qty, price, status, ts, logical_ts in cur.fetchall(): + orders.append( + { + "id": order_id, + "symbol": symbol, + "side": side, + "qty": float(qty) if qty is not None else 0.0, + "price": float(price) if price is not None else 0.0, + "status": status, + "timestamp": _format_utc_ts(ts), + "_logical_time": _format_utc_ts(logical_ts), + } + ) + store["orders"] = orders + + cur.execute( + """ + SELECT id, order_id, symbol, side, qty, price, timestamp, logical_time + FROM paper_trade + WHERE user_id = %s AND run_id = %s + ORDER BY timestamp, id + """ + , + (scope_user, scope_run), + ) + trades = [] + for trade_id, order_id, symbol, side, qty, price, ts, logical_ts in cur.fetchall(): + trades.append( + { + "id": trade_id, + "order_id": order_id, + "symbol": symbol, + "side": side, + "qty": float(qty) if qty is not None else 0.0, + "price": float(price) if price is not None else 0.0, + "timestamp": _format_utc_ts(ts), + "_logical_time": _format_utc_ts(logical_ts), + } + ) + store["trades"] = trades + + cur.execute( + """ + SELECT timestamp, logical_time, equity, pnl + FROM paper_equity_curve + WHERE user_id = %s AND run_id = %s + ORDER BY timestamp + """ + , + (scope_user, scope_run), + ) + equity_curve = [] + for ts, logical_ts, equity, pnl in cur.fetchall(): + equity_curve.append( + { + "timestamp": _format_local_ts(ts), + "_logical_time": _format_local_ts(logical_ts), + "equity": float(equity) if equity is not None else 0.0, + "pnl": float(pnl) if pnl is not None else 0.0, + } + ) + store["equity_curve"] = equity_curve + return store + + def _save_store(self, store, cur=None, user_id: str | None = None, run_id: str | None = None): + scope_user, scope_run = _resolve_scope(user_id, run_id) + if cur is None: + def _persist(cur, _conn): + self._save_store(store, cur=cur, user_id=scope_user, run_id=scope_run) + return run_with_retry(_persist) + + cash = store.get("cash") + if cash is not None: + cur.execute( + """ + INSERT INTO paper_broker_account (user_id, run_id, cash) + VALUES (%s, %s, %s) + ON CONFLICT (user_id, run_id) DO UPDATE + SET cash = EXCLUDED.cash + """, + (scope_user, scope_run, float(cash)), + ) + + positions = store.get("positions") + if isinstance(positions, dict): + symbols = [s for s in positions.keys() if s] + if symbols: + cur.execute( + "DELETE FROM paper_position WHERE user_id = %s AND run_id = %s AND symbol NOT IN %s", + (scope_user, scope_run, tuple(symbols)), + ) + else: + cur.execute( + "DELETE FROM paper_position WHERE user_id = %s AND run_id = %s", + (scope_user, scope_run), + ) + + if symbols: + rows = [] + updated_at = datetime.now(timezone.utc) + for symbol, data in positions.items(): + if not symbol or not isinstance(data, dict): + continue + rows.append( + ( + scope_user, + scope_run, + symbol, + float(data.get("qty", 0.0)), + float(data.get("avg_price", 0.0)), + float(data.get("last_price", 0.0)), + updated_at, + ) + ) + if rows: + execute_values( + cur, + """ + INSERT INTO paper_position ( + user_id, run_id, symbol, qty, avg_price, last_price, updated_at + ) + VALUES %s + ON CONFLICT (user_id, run_id, symbol) DO UPDATE + SET qty = EXCLUDED.qty, + avg_price = EXCLUDED.avg_price, + last_price = EXCLUDED.last_price, + updated_at = EXCLUDED.updated_at + """, + rows, + ) + + orders = store.get("orders") + if isinstance(orders, list) and orders: + rows = [] + for order in orders: + if not isinstance(order, dict): + continue + order_id = order.get("id") + if not order_id: + continue + ts = _parse_ts(order.get("timestamp"), assume_local=False) + logical_ts = _parse_ts(order.get("_logical_time"), assume_local=False) or ts + rows.append( + ( + scope_user, + scope_run, + order_id, + order.get("symbol"), + order.get("side"), + float(order.get("qty", 0.0)), + float(order.get("price", 0.0)), + order.get("status"), + ts, + logical_ts, + ) + ) + if rows: + execute_values( + cur, + """ + INSERT INTO paper_order ( + user_id, run_id, id, symbol, side, qty, price, status, timestamp, logical_time + ) + VALUES %s + ON CONFLICT DO NOTHING + """, + rows, + ) + + trades = store.get("trades") + if isinstance(trades, list) and trades: + rows = [] + for trade in trades: + if not isinstance(trade, dict): + continue + trade_id = trade.get("id") + if not trade_id: + continue + ts = _parse_ts(trade.get("timestamp"), assume_local=False) + logical_ts = _parse_ts(trade.get("_logical_time"), assume_local=False) or ts + rows.append( + ( + scope_user, + scope_run, + trade_id, + trade.get("order_id"), + trade.get("symbol"), + trade.get("side"), + float(trade.get("qty", 0.0)), + float(trade.get("price", 0.0)), + ts, + logical_ts, + ) + ) + if rows: + execute_values( + cur, + """ + INSERT INTO paper_trade ( + user_id, run_id, id, order_id, symbol, side, qty, price, timestamp, logical_time + ) + VALUES %s + ON CONFLICT DO NOTHING + """, + rows, + ) + + equity_curve = store.get("equity_curve") + if isinstance(equity_curve, list) and equity_curve: + rows = [] + for point in equity_curve: + if not isinstance(point, dict): + continue + ts = _parse_ts(point.get("timestamp"), assume_local=True) + logical_ts = _parse_ts(point.get("_logical_time"), assume_local=True) or ts + if ts is None: + continue + rows.append( + ( + scope_user, + scope_run, + ts, + logical_ts, + float(point.get("equity", 0.0)), + float(point.get("pnl", 0.0)), + ) + ) + if rows: + execute_values( + cur, + """ + INSERT INTO paper_equity_curve (user_id, run_id, timestamp, logical_time, equity, pnl) + VALUES %s + ON CONFLICT DO NOTHING + """, + rows, + ) + + def get_funds(self, cur=None): + store = self._load_store(cur=cur) + cash = float(store.get("cash", 0)) + positions = store.get("positions", {}) + positions_value = 0.0 + for position in positions.values(): + qty = float(position.get("qty", 0)) + last_price = float(position.get("last_price", position.get("avg_price", 0))) + positions_value += qty * last_price + total_equity = cash + positions_value + return { + "cash_available": cash, + "invested_value": positions_value, + "cash": cash, + "used_margin": 0.0, + "available": cash, + "net": total_equity, + "total_equity": total_equity, + } + + def get_positions(self, cur=None): + store = self._load_store(cur=cur) + positions = store.get("positions", {}) + return [ + { + "symbol": symbol, + "qty": float(data.get("qty", 0)), + "avg_price": float(data.get("avg_price", 0)), + "last_price": float(data.get("last_price", data.get("avg_price", 0))), + } + for symbol, data in positions.items() + ] + + def get_orders(self, cur=None): + store = self._load_store(cur=cur) + orders = [] + for order in store.get("orders", []): + if isinstance(order, dict): + order = {k: v for k, v in order.items() if k != "_logical_time"} + orders.append(order) + return orders + + def get_trades(self, cur=None): + store = self._load_store(cur=cur) + trades = [] + for trade in store.get("trades", []): + if isinstance(trade, dict): + trade = {k: v for k, v in trade.items() if k != "_logical_time"} + trades.append(trade) + return trades + + def get_equity_curve(self, cur=None): + store = self._load_store(cur=cur) + points = [] + for point in store.get("equity_curve", []): + if isinstance(point, dict): + point = {k: v for k, v in point.items() if k != "_logical_time"} + points.append(point) + return points + + def _update_equity_in_tx( + self, + cur, + prices: dict[str, float], + now: datetime, + logical_time: datetime | None = None, + user_id: str | None = None, + run_id: str | None = None, + ): + store = self._load_store(cur=cur, for_update=True, user_id=user_id, run_id=run_id) + positions = store.get("positions", {}) + for symbol, price in prices.items(): + if symbol in positions: + positions[symbol]["last_price"] = float(price) + + cash = float(store.get("cash", 0)) + positions_value = 0.0 + for symbol, position in positions.items(): + qty = float(position.get("qty", 0)) + price = float(position.get("last_price", position.get("avg_price", 0))) + positions_value += qty * price + + equity = cash + positions_value + pnl = equity - float(self.initial_cash) + ts_for_equity = logical_time or now + store.setdefault("equity_curve", []).append( + { + "timestamp": _format_local_ts(ts_for_equity), + "_logical_time": _format_local_ts(ts_for_equity), + "equity": equity, + "pnl": pnl, + } + ) + store["positions"] = positions + self._save_store(store, cur=cur, user_id=user_id, run_id=run_id) + insert_engine_event( + cur, + "EQUITY_UPDATED", + data={ + "timestamp": _format_utc_ts(ts_for_equity), + "equity": equity, + "pnl": pnl, + }, + ) + return equity + + def update_equity( + self, + prices: dict[str, float], + now: datetime, + cur=None, + logical_time: datetime | None = None, + user_id: str | None = None, + run_id: str | None = None, + ): + if cur is not None: + return self._update_equity_in_tx( + cur, + prices, + now, + logical_time=logical_time, + user_id=user_id, + run_id=run_id, + ) + + def _op(cur, _conn): + return self._update_equity_in_tx( + cur, + prices, + now, + logical_time=logical_time, + user_id=user_id, + run_id=run_id, + ) + + return run_with_retry(_op) + + def _place_order_in_tx( + self, + cur, + symbol: str, + side: str, + quantity: float, + price: float | None, + logical_time: datetime | None = None, + user_id: str | None = None, + run_id: str | None = None, + ): + scope_user, scope_run = _resolve_scope(user_id, run_id) + store = self._load_store(cur=cur, for_update=True, user_id=scope_user, run_id=scope_run) + side = side.upper().strip() + qty = float(quantity) + if price is None: + price = fetch_live_price(symbol) + price = float(price) + + logical_ts = logical_time or datetime.utcnow().replace(tzinfo=timezone.utc) + timestamp = logical_ts + timestamp_str = _format_utc_ts(timestamp) + logical_ts_str = _format_utc_ts(logical_ts) + order_id = _deterministic_id( + "ord", + [ + scope_user, + scope_run, + _normalize_ts_for_id(logical_ts), + symbol, + side, + _stable_num(qty), + _stable_num(price), + ], + ) + + order = { + "id": order_id, + "symbol": symbol, + "side": side, + "qty": qty, + "price": price, + "status": "REJECTED", + "timestamp": timestamp_str, + "_logical_time": logical_ts_str, + } + + if qty <= 0 or price <= 0: + store.setdefault("orders", []).append(order) + self._save_store(store, cur=cur, user_id=user_id, run_id=run_id) + insert_engine_event(cur, "ORDER_PLACED", data=order) + return order + + positions = store.get("positions", {}) + cash = float(store.get("cash", 0)) + trade = None + + if side == "BUY": + cost = qty * price + if cash >= cost: + cash -= cost + existing = positions.get(symbol, {"qty": 0.0, "avg_price": 0.0, "last_price": price}) + new_qty = float(existing.get("qty", 0)) + qty + prev_cost = float(existing.get("qty", 0)) * float(existing.get("avg_price", 0)) + avg_price = (prev_cost + cost) / new_qty if new_qty else price + positions[symbol] = { + "qty": new_qty, + "avg_price": avg_price, + "last_price": price, + } + order["status"] = "FILLED" + trade = { + "id": _deterministic_id("trd", [order_id]), + "order_id": order_id, + "symbol": symbol, + "side": side, + "qty": qty, + "price": price, + "timestamp": timestamp_str, + "_logical_time": logical_ts_str, + } + store.setdefault("trades", []).append(trade) + elif side == "SELL": + existing = positions.get(symbol) + if existing and float(existing.get("qty", 0)) >= qty: + cash += qty * price + remaining = float(existing.get("qty", 0)) - qty + if remaining > 0: + existing["qty"] = remaining + existing["last_price"] = price + positions[symbol] = existing + else: + positions.pop(symbol, None) + order["status"] = "FILLED" + trade = { + "id": _deterministic_id("trd", [order_id]), + "order_id": order_id, + "symbol": symbol, + "side": side, + "qty": qty, + "price": price, + "timestamp": timestamp_str, + "_logical_time": logical_ts_str, + } + store.setdefault("trades", []).append(trade) + + store["cash"] = cash + store["positions"] = positions + store.setdefault("orders", []).append(order) + self._save_store(store, cur=cur, user_id=user_id, run_id=run_id) + insert_engine_event(cur, "ORDER_PLACED", data=order) + if trade is not None: + insert_engine_event(cur, "TRADE_EXECUTED", data=trade) + insert_engine_event(cur, "ORDER_FILLED", data={"order_id": order_id}) + return order + + def place_order( + self, + symbol: str, + side: str, + quantity: float, + price: float | None = None, + cur=None, + logical_time: datetime | None = None, + user_id: str | None = None, + run_id: str | None = None, + ): + if cur is not None: + return self._place_order_in_tx( + cur, + symbol, + side, + quantity, + price, + logical_time=logical_time, + user_id=user_id, + run_id=run_id, + ) + + def _op(cur, _conn): + return self._place_order_in_tx( + cur, + symbol, + side, + quantity, + price, + logical_time=logical_time, + user_id=user_id, + run_id=run_id, + ) + + return run_with_retry(_op) + diff --git a/indian_paper_trading_strategy/engine/config.py b/indian_paper_trading_strategy/engine/config.py new file mode 100644 index 0000000..9321b89 --- /dev/null +++ b/indian_paper_trading_strategy/engine/config.py @@ -0,0 +1,150 @@ +import json +from datetime import datetime + +from indian_paper_trading_strategy.engine.db import db_connection, get_context + +DEFAULT_CONFIG = { + "active": False, + "sip_amount": 0, + "sip_frequency": {"value": 30, "unit": "days"}, + "next_run": None +} + +def _maybe_parse_json(value): + if value is None: + return None + if not isinstance(value, str): + return value + text = value.strip() + if not text: + return None + try: + return json.loads(text) + except Exception: + return value + + +def _format_ts(value: datetime | None): + if value is None: + return None + return value.isoformat() + + +def load_strategy_config(user_id: str | None = None, run_id: str | None = None): + scope_user, scope_run = get_context(user_id, run_id) + with db_connection() as conn: + with conn.cursor() as cur: + cur.execute( + """ + SELECT strategy, sip_amount, sip_frequency_value, sip_frequency_unit, + mode, broker, active, frequency, frequency_days, unit, next_run + FROM strategy_config + WHERE user_id = %s AND run_id = %s + LIMIT 1 + """, + (scope_user, scope_run), + ) + row = cur.fetchone() + if not row: + return DEFAULT_CONFIG.copy() + + cfg = DEFAULT_CONFIG.copy() + cfg["strategy"] = row[0] + cfg["strategy_name"] = row[0] + cfg["sip_amount"] = float(row[1]) if row[1] is not None else cfg.get("sip_amount") + cfg["mode"] = row[4] + cfg["broker"] = row[5] + cfg["active"] = row[6] if row[6] is not None else cfg.get("active") + cfg["frequency"] = _maybe_parse_json(row[7]) + cfg["frequency_days"] = row[8] + cfg["unit"] = row[9] + cfg["next_run"] = _format_ts(row[10]) + if row[2] is not None or row[3] is not None: + cfg["sip_frequency"] = {"value": row[2], "unit": row[3]} + else: + value = cfg.get("frequency") + unit = cfg.get("unit") + if isinstance(value, dict): + unit = value.get("unit", unit) + value = value.get("value") + if value is None and cfg.get("frequency_days") is not None: + value = cfg.get("frequency_days") + unit = unit or "days" + if value is not None and unit: + cfg["sip_frequency"] = {"value": value, "unit": unit} + return cfg + +def save_strategy_config(cfg, user_id: str | None = None, run_id: str | None = None): + scope_user, scope_run = get_context(user_id, run_id) + sip_frequency = cfg.get("sip_frequency") + sip_value = None + sip_unit = None + if isinstance(sip_frequency, dict): + sip_value = sip_frequency.get("value") + sip_unit = sip_frequency.get("unit") + + frequency = cfg.get("frequency") + if not isinstance(frequency, str) and frequency is not None: + frequency = json.dumps(frequency) + + next_run = cfg.get("next_run") + next_run_dt = None + if isinstance(next_run, str): + try: + next_run_dt = datetime.fromisoformat(next_run) + except ValueError: + next_run_dt = None + + strategy = cfg.get("strategy") or cfg.get("strategy_name") + + with db_connection() as conn: + with conn: + with conn.cursor() as cur: + cur.execute( + """ + INSERT INTO strategy_config ( + user_id, + run_id, + strategy, + sip_amount, + sip_frequency_value, + sip_frequency_unit, + mode, + broker, + active, + frequency, + frequency_days, + unit, + next_run + ) + VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s) + ON CONFLICT (user_id, run_id) DO UPDATE + SET strategy = EXCLUDED.strategy, + sip_amount = EXCLUDED.sip_amount, + sip_frequency_value = EXCLUDED.sip_frequency_value, + sip_frequency_unit = EXCLUDED.sip_frequency_unit, + mode = EXCLUDED.mode, + broker = EXCLUDED.broker, + active = EXCLUDED.active, + frequency = EXCLUDED.frequency, + frequency_days = EXCLUDED.frequency_days, + unit = EXCLUDED.unit, + next_run = EXCLUDED.next_run + """, + ( + scope_user, + scope_run, + strategy, + cfg.get("sip_amount"), + sip_value, + sip_unit, + cfg.get("mode"), + cfg.get("broker"), + cfg.get("active"), + frequency, + cfg.get("frequency_days"), + cfg.get("unit"), + next_run_dt, + ), + ) + diff --git a/indian_paper_trading_strategy/engine/data.py b/indian_paper_trading_strategy/engine/data.py new file mode 100644 index 0000000..c60aba2 --- /dev/null +++ b/indian_paper_trading_strategy/engine/data.py @@ -0,0 +1,81 @@ +# engine/data.py +from datetime import datetime, timezone +from pathlib import Path +import os +import threading + +import pandas as pd +import yfinance as yf + +ENGINE_ROOT = Path(__file__).resolve().parents[1] +HISTORY_DIR = ENGINE_ROOT / "storage" / "history" +ALLOW_PRICE_CACHE = os.getenv("ALLOW_PRICE_CACHE", "0").strip().lower() in {"1", "true", "yes"} + +_LAST_PRICE: dict[str, dict[str, object]] = {} +_LAST_PRICE_LOCK = threading.Lock() + + +def _set_last_price(ticker: str, price: float, source: str): + now = datetime.now(timezone.utc) + with _LAST_PRICE_LOCK: + _LAST_PRICE[ticker] = {"price": float(price), "source": source, "ts": now} + + +def get_price_snapshot(ticker: str) -> dict[str, object] | None: + with _LAST_PRICE_LOCK: + data = _LAST_PRICE.get(ticker) + if not data: + return None + return dict(data) + + +def _get_last_live_price(ticker: str) -> float | None: + with _LAST_PRICE_LOCK: + data = _LAST_PRICE.get(ticker) + if not data: + return None + if data.get("source") == "live": + return float(data.get("price", 0)) + return None + + +def _cached_last_close(ticker: str) -> float | None: + file = HISTORY_DIR / f"{ticker}.csv" + if not file.exists(): + return None + df = pd.read_csv(file) + if df.empty or "Close" not in df.columns: + return None + return float(df["Close"].iloc[-1]) + + +def fetch_live_price(ticker, allow_cache: bool | None = None): + if allow_cache is None: + allow_cache = ALLOW_PRICE_CACHE + try: + df = yf.download( + ticker, + period="1d", + interval="1m", + auto_adjust=True, + progress=False, + timeout=5, + ) + if df is not None and not df.empty: + price = float(df["Close"].iloc[-1]) + _set_last_price(ticker, price, "live") + return price + except Exception: + pass + + if allow_cache: + last_live = _get_last_live_price(ticker) + if last_live is not None: + return last_live + + cached = _cached_last_close(ticker) + if cached is not None: + _set_last_price(ticker, cached, "cache") + return cached + + raise RuntimeError(f"No live data for {ticker}") diff --git a/indian_paper_trading_strategy/engine/engine_runner.py b/indian_paper_trading_strategy/engine/engine_runner.py new file mode 100644 index 0000000..0be5f65 --- /dev/null +++ b/indian_paper_trading_strategy/engine/engine_runner.py @@ -0,0 +1,198 @@ +import time +from datetime import datetime, timezone + +from indian_paper_trading_strategy.engine.db import ( + run_with_retry, + insert_engine_event, + get_default_user_id, + get_active_run_id, + get_running_runs, + engine_context, +) + +def log_event(event: str, data: dict | None = None): + now = datetime.utcnow().replace(tzinfo=timezone.utc) + payload = data or {} + + def _op(cur, _conn): + insert_engine_event(cur, event, data=payload, ts=now) + + run_with_retry(_op) + +def _update_engine_status(user_id: str, run_id: str, status: str): + now = datetime.utcnow().replace(tzinfo=timezone.utc) + + def _op(cur, _conn): + cur.execute( + """ + INSERT INTO engine_status (user_id, run_id, status, last_updated) + VALUES (%s, %s, %s, %s) + ON CONFLICT (user_id, run_id) DO UPDATE + SET status = EXCLUDED.status, + last_updated = EXCLUDED.last_updated + """, + (user_id, run_id, status, now), + ) + + run_with_retry(_op) + +from indian_paper_trading_strategy.engine.config import load_strategy_config, save_strategy_config +from indian_paper_trading_strategy.engine.market import india_market_status +from indian_paper_trading_strategy.engine.execution import try_execute_sip +from indian_paper_trading_strategy.engine.state import load_state +from indian_paper_trading_strategy.engine.broker import PaperBroker +from indian_paper_trading_strategy.engine.data import fetch_live_price +from indian_paper_trading_strategy.engine.mtm import log_mtm, should_log_mtm +from indian_paper_trading_strategy.engine.history import load_monthly_close +from indian_paper_trading_strategy.engine.strategy import allocation +from indian_paper_trading_strategy.engine.time_utils import frequency_to_timedelta, normalize_logical_time + +NIFTY = "NIFTYBEES.NS" +GOLD = "GOLDBEES.NS" +SMA_MONTHS = 36 + +def run_engine(user_id: str | None = None, run_id: str | None = None): + print("Strategy engine started") + active_runs: dict[tuple[str, str], bool] = {} + + if run_id and not user_id: + raise ValueError("user_id is required when run_id is provided") + + while True: + try: + if user_id and run_id: + runs = [(user_id, run_id)] + elif user_id: + runs = get_running_runs(user_id) + else: + runs = get_running_runs() + if not runs: + default_user = get_default_user_id() + if default_user: + runs = get_running_runs(default_user) + + seen = set() + for scope_user, scope_run in runs: + if not scope_user or not scope_run: + continue + seen.add((scope_user, scope_run)) + with engine_context(scope_user, scope_run): + cfg = load_strategy_config(user_id=scope_user, run_id=scope_run) + if not cfg.get("active"): + continue + + strategy_name = cfg.get("strategy_name", "golden_nifty") + sip_amount = cfg.get("sip_amount", 0) + configured_frequency = cfg.get("sip_frequency") or {} + if not isinstance(configured_frequency, dict): + configured_frequency = {} + frequency_value = int(configured_frequency.get("value", cfg.get("frequency", 0))) + frequency_unit = configured_frequency.get("unit", cfg.get("unit", "days")) + frequency_info = {"value": frequency_value, "unit": frequency_unit} + frequency_label = f"{frequency_value} {frequency_unit}" + + if not active_runs.get((scope_user, scope_run)): + log_event( + "ENGINE_START", + { + "strategy": strategy_name, + "sip_amount": sip_amount, + "frequency": frequency_label, + }, + ) + active_runs[(scope_user, scope_run)] = True + + _update_engine_status(scope_user, scope_run, "RUNNING") + + market_open, _ = india_market_status() + if not market_open: + log_event("MARKET_CLOSED", {"reason": "Outside market hours"}) + continue + + now = datetime.now() + mode = (cfg.get("mode") or "PAPER").strip().upper() + if mode not in {"PAPER", "LIVE"}: + mode = "PAPER" + state = load_state(mode=mode) + initial_cash = float(state.get("initial_cash") or 0.0) + broker = PaperBroker(initial_cash=initial_cash) if mode == "PAPER" else None + + nifty_price = fetch_live_price(NIFTY) + gold_price = fetch_live_price(GOLD) + + next_run = cfg.get("next_run") + if next_run is None or now >= datetime.fromisoformat(next_run): + nifty_hist = load_monthly_close(NIFTY) + gold_hist = load_monthly_close(GOLD) + + nifty_sma = nifty_hist.rolling(SMA_MONTHS).mean().iloc[-1] + gold_sma = gold_hist.rolling(SMA_MONTHS).mean().iloc[-1] + + eq_w, gd_w = allocation( + sp_price=nifty_price, + gd_price=gold_price, + sp_sma=nifty_sma, + gd_sma=gold_sma, + ) + + weights = {"equity": eq_w, "gold": gd_w} + state, executed = try_execute_sip( + now=now, + market_open=True, + sip_interval=frequency_to_timedelta(frequency_info).total_seconds(), + sip_amount=sip_amount, + sp_price=nifty_price, + gd_price=gold_price, + eq_w=eq_w, + gd_w=gd_w, + broker=broker, + mode=mode, + ) + + if executed: + log_event( + "SIP_TRIGGERED", + { + "date": now.date().isoformat(), + "allocation": weights, + "cash_used": sip_amount, + }, + ) + portfolio_value = ( + state["nifty_units"] * nifty_price + + state["gold_units"] * gold_price + ) + log_event( + "PORTFOLIO_UPDATED", + { + "nifty_units": state["nifty_units"], + "gold_units": state["gold_units"], + "portfolio_value": portfolio_value, + }, + ) + cfg["next_run"] = (now + frequency_to_timedelta(frequency_info)).isoformat() + save_strategy_config(cfg, user_id=scope_user, run_id=scope_run) + + if should_log_mtm(None, now): + state = load_state(mode=mode) + log_mtm( + nifty_units=state["nifty_units"], + gold_units=state["gold_units"], + nifty_price=nifty_price, + gold_price=gold_price, + total_invested=state["total_invested"], + logical_time=normalize_logical_time(now), + ) + + for key in list(active_runs.keys()): + if key not in seen: + active_runs.pop(key, None) + + time.sleep(30) + except Exception as e: + log_event("ENGINE_ERROR", {"error": str(e)}) + raise + +if __name__ == "__main__": + run_engine() + diff --git a/indian_paper_trading_strategy/engine/execution.py b/indian_paper_trading_strategy/engine/execution.py new file mode 100644 index 0000000..e135a24 --- /dev/null +++ b/indian_paper_trading_strategy/engine/execution.py @@ -0,0 +1,157 @@ +# engine/execution.py +from datetime import datetime, timezone +from indian_paper_trading_strategy.engine.state import load_state, save_state + +from indian_paper_trading_strategy.engine.broker import Broker +from indian_paper_trading_strategy.engine.ledger import log_event, event_exists +from indian_paper_trading_strategy.engine.db import run_with_retry +from indian_paper_trading_strategy.engine.time_utils import compute_logical_time + +def _as_float(value): + if hasattr(value, "item"): + try: + return float(value.item()) + except Exception: + pass + if hasattr(value, "iloc"): + try: + return float(value.iloc[-1]) + except Exception: + pass + return float(value) + +def _local_tz(): + return datetime.now().astimezone().tzinfo + +def try_execute_sip( + now, + market_open, + sip_interval, + sip_amount, + sp_price, + gd_price, + eq_w, + gd_w, + broker: Broker | None = None, + mode: str | None = "LIVE", +): + def _op(cur, _conn): + if now.tzinfo is None: + now_ts = now.replace(tzinfo=_local_tz()) + else: + now_ts = now + event_ts = now_ts + log_event("DEBUG_ENTER_TRY_EXECUTE", { + "now": now_ts.isoformat(), + }, cur=cur, ts=event_ts) + + state = load_state(mode=mode, cur=cur, for_update=True) + + force_execute = state.get("last_sip_ts") is None + + if not market_open: + return state, False + + last = state.get("last_sip_ts") or state.get("last_run") + if last and not force_execute: + try: + last_dt = datetime.fromisoformat(last) + except ValueError: + last_dt = None + if last_dt: + if last_dt.tzinfo is None: + last_dt = last_dt.replace(tzinfo=_local_tz()) + if now_ts.tzinfo and last_dt.tzinfo and last_dt.tzinfo != now_ts.tzinfo: + last_dt = last_dt.astimezone(now_ts.tzinfo) + if last_dt and (now_ts - last_dt).total_seconds() < sip_interval: + return state, False + + logical_time = compute_logical_time(now_ts, last, sip_interval) + if event_exists("SIP_EXECUTED", logical_time, cur=cur): + return state, False + + sp_price_val = _as_float(sp_price) + gd_price_val = _as_float(gd_price) + eq_w_val = _as_float(eq_w) + gd_w_val = _as_float(gd_w) + sip_amount_val = _as_float(sip_amount) + + nifty_qty = (sip_amount_val * eq_w_val) / sp_price_val + gold_qty = (sip_amount_val * gd_w_val) / gd_price_val + + if broker is None: + return state, False + + funds = broker.get_funds(cur=cur) + cash = funds.get("cash") + if cash is not None and float(cash) < sip_amount_val: + return state, False + + log_event("DEBUG_EXECUTION_DECISION", { + "force_execute": force_execute, + "last_sip_ts": state.get("last_sip_ts"), + "now": now_ts.isoformat(), + }, cur=cur, ts=event_ts) + + nifty_order = broker.place_order( + "NIFTYBEES.NS", + "BUY", + nifty_qty, + sp_price_val, + cur=cur, + logical_time=logical_time, + ) + gold_order = broker.place_order( + "GOLDBEES.NS", + "BUY", + gold_qty, + gd_price_val, + cur=cur, + logical_time=logical_time, + ) + orders = [nifty_order, gold_order] + executed = all( + isinstance(order, dict) and order.get("status") == "FILLED" + for order in orders + ) + if not executed: + return state, False + assert len(orders) > 0, "executed=True but no broker orders placed" + + funds_after = broker.get_funds(cur=cur) + cash_after = funds_after.get("cash") + if cash_after is not None: + state["cash"] = float(cash_after) + + state["nifty_units"] += nifty_qty + state["gold_units"] += gold_qty + state["total_invested"] += sip_amount_val + state["last_sip_ts"] = now_ts.isoformat() + state["last_run"] = now_ts.isoformat() + + save_state( + state, + mode=mode, + cur=cur, + emit_event=True, + event_meta={"source": "sip"}, + ) + + log_event( + "SIP_EXECUTED", + { + "nifty_units": nifty_qty, + "gold_units": gold_qty, + "nifty_price": sp_price_val, + "gold_price": gd_price_val, + "amount": sip_amount_val, + }, + cur=cur, + ts=event_ts, + logical_time=logical_time, + ) + + return state, True + + return run_with_retry(_op) + diff --git a/indian_paper_trading_strategy/engine/history.py b/indian_paper_trading_strategy/engine/history.py new file mode 100644 index 0000000..28e4697 --- /dev/null +++ b/indian_paper_trading_strategy/engine/history.py @@ -0,0 +1,34 @@ +# engine/history.py +import yfinance as yf +import pandas as pd +from pathlib import Path + +ENGINE_ROOT = Path(__file__).resolve().parents[1] +STORAGE_DIR = ENGINE_ROOT / "storage" +STORAGE_DIR.mkdir(exist_ok=True) + +CACHE_DIR = STORAGE_DIR / "history" +CACHE_DIR.mkdir(exist_ok=True) + +def load_monthly_close(ticker, years=10): + file = CACHE_DIR / f"{ticker}.csv" + + if file.exists(): + df = pd.read_csv(file, parse_dates=["Date"], index_col="Date") + return df["Close"] + + df = yf.download( + ticker, + period=f"{years}y", + auto_adjust=True, + progress=False, + timeout=5, + ) + + if df.empty: + raise RuntimeError(f"No history for {ticker}") + + series = df["Close"].resample("M").last() + series.to_csv(file, header=["Close"]) + + return series diff --git a/indian_paper_trading_strategy/engine/ledger.py b/indian_paper_trading_strategy/engine/ledger.py new file mode 100644 index 0000000..874a5fa --- /dev/null +++ b/indian_paper_trading_strategy/engine/ledger.py @@ -0,0 +1,113 @@ +# engine/ledger.py +from datetime import datetime, timezone + +from indian_paper_trading_strategy.engine.db import insert_engine_event, run_with_retry, get_context +from indian_paper_trading_strategy.engine.time_utils import normalize_logical_time + + +def _event_exists_in_tx(cur, event, logical_time, user_id: str | None = None, run_id: str | None = None): + scope_user, scope_run = get_context(user_id, run_id) + logical_ts = normalize_logical_time(logical_time) + cur.execute( + """ + SELECT 1 + FROM event_ledger + WHERE user_id = %s AND run_id = %s AND event = %s AND logical_time = %s + LIMIT 1 + """, + (scope_user, scope_run, event, logical_ts), + ) + return cur.fetchone() is not None + + +def event_exists(event, logical_time, *, cur=None, user_id: str | None = None, run_id: str | None = None): + if cur is not None: + return _event_exists_in_tx(cur, event, logical_time, user_id=user_id, run_id=run_id) + + def _op(cur, _conn): + return _event_exists_in_tx(cur, event, logical_time, user_id=user_id, run_id=run_id) + + return run_with_retry(_op) + + +def _log_event_in_tx( + cur, + event, + payload, + ts, + logical_time=None, + user_id: str | None = None, + run_id: str | None = None, +): + scope_user, scope_run = get_context(user_id, run_id) + logical_ts = normalize_logical_time(logical_time or ts) + cur.execute( + """ + INSERT INTO event_ledger ( + user_id, + run_id, + timestamp, + logical_time, + event, + nifty_units, + gold_units, + nifty_price, + gold_price, + amount + ) + VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s) + ON CONFLICT DO NOTHING + """, + ( + scope_user, + scope_run, + ts, + logical_ts, + event, + payload.get("nifty_units"), + payload.get("gold_units"), + payload.get("nifty_price"), + payload.get("gold_price"), + payload.get("amount"), + ), + ) + if cur.rowcount: + insert_engine_event(cur, event, data=payload, ts=ts) + + +def log_event( + event, + payload, + *, + cur=None, + ts=None, + logical_time=None, + user_id: str | None = None, + run_id: str | None = None, +): + now = ts or logical_time or datetime.utcnow().replace(tzinfo=timezone.utc) + if cur is not None: + _log_event_in_tx( + cur, + event, + payload, + now, + logical_time=logical_time, + user_id=user_id, + run_id=run_id, + ) + return + + def _op(cur, _conn): + _log_event_in_tx( + cur, + event, + payload, + now, + logical_time=logical_time, + user_id=user_id, + run_id=run_id, + ) + + return run_with_retry(_op) + diff --git a/indian_paper_trading_strategy/engine/market.py b/indian_paper_trading_strategy/engine/market.py new file mode 100644 index 0000000..c16f5de --- /dev/null +++ b/indian_paper_trading_strategy/engine/market.py @@ -0,0 +1,42 @@ +# engine/market.py +from datetime import datetime, time as dtime, timedelta +import pytz + +_MARKET_TZ = pytz.timezone("Asia/Kolkata") +_OPEN_T = dtime(9, 15) +_CLOSE_T = dtime(15, 30) + +def _as_market_tz(value: datetime) -> datetime: + if value.tzinfo is None: + return _MARKET_TZ.localize(value) + return value.astimezone(_MARKET_TZ) + +def is_market_open(now: datetime) -> bool: + now = _as_market_tz(now) + return now.weekday() < 5 and _OPEN_T <= now.time() <= _CLOSE_T + +def india_market_status(): + now = datetime.now(_MARKET_TZ) + + return is_market_open(now), now + +def next_market_open_after(value: datetime) -> datetime: + current = _as_market_tz(value) + while current.weekday() >= 5: + current = current + timedelta(days=1) + current = current.replace(hour=_OPEN_T.hour, minute=_OPEN_T.minute, second=0, microsecond=0) + if current.time() < _OPEN_T: + return current.replace(hour=_OPEN_T.hour, minute=_OPEN_T.minute, second=0, microsecond=0) + if current.time() > _CLOSE_T: + current = current + timedelta(days=1) + while current.weekday() >= 5: + current = current + timedelta(days=1) + return current.replace(hour=_OPEN_T.hour, minute=_OPEN_T.minute, second=0, microsecond=0) + return current + +def align_to_market_open(value: datetime) -> datetime: + current = _as_market_tz(value) + aligned = current if is_market_open(current) else next_market_open_after(current) + if value.tzinfo is None: + return aligned.replace(tzinfo=None) + return aligned diff --git a/indian_paper_trading_strategy/engine/mtm.py b/indian_paper_trading_strategy/engine/mtm.py new file mode 100644 index 0000000..90d6b1a --- /dev/null +++ b/indian_paper_trading_strategy/engine/mtm.py @@ -0,0 +1,154 @@ +from datetime import datetime, timezone +from pathlib import Path + +from indian_paper_trading_strategy.engine.db import db_connection, insert_engine_event, run_with_retry, get_context +from indian_paper_trading_strategy.engine.time_utils import normalize_logical_time + +ENGINE_ROOT = Path(__file__).resolve().parents[1] +STORAGE_DIR = ENGINE_ROOT / "storage" +MTM_FILE = STORAGE_DIR / "mtm_ledger.csv" + +MTM_INTERVAL_SECONDS = 60 + +def _log_mtm_in_tx( + cur, + nifty_units, + gold_units, + nifty_price, + gold_price, + total_invested, + ts, + logical_time=None, + user_id: str | None = None, + run_id: str | None = None, +): + scope_user, scope_run = get_context(user_id, run_id) + logical_ts = normalize_logical_time(logical_time or ts) + nifty_value = nifty_units * nifty_price + gold_value = gold_units * gold_price + portfolio_value = nifty_value + gold_value + pnl = portfolio_value - total_invested + + row = { + "timestamp": ts.isoformat(), + "logical_time": logical_ts.isoformat(), + "nifty_units": nifty_units, + "gold_units": gold_units, + "nifty_price": nifty_price, + "gold_price": gold_price, + "nifty_value": nifty_value, + "gold_value": gold_value, + "portfolio_value": portfolio_value, + "total_invested": total_invested, + "pnl": pnl, + } + cur.execute( + """ + INSERT INTO mtm_ledger ( + user_id, + run_id, + timestamp, + logical_time, + nifty_units, + gold_units, + nifty_price, + gold_price, + nifty_value, + gold_value, + portfolio_value, + total_invested, + pnl + ) + VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s) + ON CONFLICT DO NOTHING + """, + ( + scope_user, + scope_run, + ts, + logical_ts, + row["nifty_units"], + row["gold_units"], + row["nifty_price"], + row["gold_price"], + row["nifty_value"], + row["gold_value"], + row["portfolio_value"], + row["total_invested"], + row["pnl"], + ), + ) + if cur.rowcount: + insert_engine_event(cur, "MTM_UPDATED", data=row, ts=ts) + return portfolio_value, pnl + +def log_mtm( + nifty_units, + gold_units, + nifty_price, + gold_price, + total_invested, + *, + cur=None, + logical_time=None, + user_id: str | None = None, + run_id: str | None = None, +): + ts = logical_time or datetime.now(timezone.utc) + if cur is not None: + return _log_mtm_in_tx( + cur, + nifty_units, + gold_units, + nifty_price, + gold_price, + total_invested, + ts, + logical_time=logical_time, + user_id=user_id, + run_id=run_id, + ) + + def _op(cur, _conn): + return _log_mtm_in_tx( + cur, + nifty_units, + gold_units, + nifty_price, + gold_price, + total_invested, + ts, + logical_time=logical_time, + user_id=user_id, + run_id=run_id, + ) + + return run_with_retry(_op) + +def _get_last_mtm_ts(user_id: str | None = None, run_id: str | None = None): + scope_user, scope_run = get_context(user_id, run_id) + with db_connection() as conn: + with conn.cursor() as cur: + cur.execute( + "SELECT MAX(timestamp) FROM mtm_ledger WHERE user_id = %s AND run_id = %s", + (scope_user, scope_run), + ) + row = cur.fetchone() + if not row or row[0] is None: + return None + return row[0].astimezone().replace(tzinfo=None) + +def should_log_mtm(df, now, user_id: str | None = None, run_id: str | None = None): + if df is None: + last_ts = _get_last_mtm_ts(user_id=user_id, run_id=run_id) + if last_ts is None: + return True + return (now - last_ts).total_seconds() >= MTM_INTERVAL_SECONDS + if getattr(df, "empty", False): + return True + try: + last_ts = datetime.fromisoformat(str(df.iloc[-1]["timestamp"])) + except Exception: + return True + return (now - last_ts).total_seconds() >= MTM_INTERVAL_SECONDS + diff --git a/indian_paper_trading_strategy/engine/runner.py b/indian_paper_trading_strategy/engine/runner.py new file mode 100644 index 0000000..e0e6a6c --- /dev/null +++ b/indian_paper_trading_strategy/engine/runner.py @@ -0,0 +1,518 @@ +import os +import threading +import time +from datetime import datetime, timedelta, timezone + +from indian_paper_trading_strategy.engine.market import is_market_open, align_to_market_open +from indian_paper_trading_strategy.engine.execution import try_execute_sip +from indian_paper_trading_strategy.engine.broker import PaperBroker +from indian_paper_trading_strategy.engine.mtm import log_mtm, should_log_mtm +from indian_paper_trading_strategy.engine.state import load_state +from indian_paper_trading_strategy.engine.data import fetch_live_price +from indian_paper_trading_strategy.engine.history import load_monthly_close +from indian_paper_trading_strategy.engine.strategy import allocation +from indian_paper_trading_strategy.engine.time_utils import normalize_logical_time + +from indian_paper_trading_strategy.engine.db import db_transaction, insert_engine_event, run_with_retry, get_context, set_context + + +def _update_engine_status(user_id: str, run_id: str, status: str): + now = datetime.utcnow().replace(tzinfo=timezone.utc) + + def _op(cur, _conn): + cur.execute( + """ + INSERT INTO engine_status (user_id, run_id, status, last_updated) + VALUES (%s, %s, %s, %s) + ON CONFLICT (user_id, run_id) DO UPDATE + SET status = EXCLUDED.status, + last_updated = EXCLUDED.last_updated + """, + (user_id, run_id, status, now), + ) + + run_with_retry(_op) + +NIFTY = "NIFTYBEES.NS" +GOLD = "GOLDBEES.NS" +SMA_MONTHS = 36 + +_DEFAULT_ENGINE_STATE = { + "state": "STOPPED", + "run_id": None, + "user_id": None, + "last_heartbeat_ts": None, +} + +_ENGINE_STATES = {} +_ENGINE_STATES_LOCK = threading.Lock() + +_RUNNERS = {} +_RUNNERS_LOCK = threading.Lock() + +engine_state = _ENGINE_STATES + + +def _state_key(user_id: str, run_id: str): + return (user_id, run_id) + + +def _get_state(user_id: str, run_id: str): + key = _state_key(user_id, run_id) + with _ENGINE_STATES_LOCK: + state = _ENGINE_STATES.get(key) + if state is None: + state = dict(_DEFAULT_ENGINE_STATE) + state["user_id"] = user_id + state["run_id"] = run_id + _ENGINE_STATES[key] = state + return state + + +def _set_state(user_id: str, run_id: str, **updates): + key = _state_key(user_id, run_id) + with _ENGINE_STATES_LOCK: + state = _ENGINE_STATES.get(key) + if state is None: + state = dict(_DEFAULT_ENGINE_STATE) + state["user_id"] = user_id + state["run_id"] = run_id + _ENGINE_STATES[key] = state + state.update(updates) + + +def get_engine_state(user_id: str, run_id: str): + state = _get_state(user_id, run_id) + return dict(state) + +def log_event( + event: str, + data: dict | None = None, + message: str | None = None, + meta: dict | None = None, +): + entry = { + "ts": datetime.utcnow().replace(tzinfo=timezone.utc).isoformat(), + "event": event, + } + if message is not None or meta is not None: + entry["message"] = message or "" + entry["meta"] = meta or {} + else: + entry["data"] = data or {} + event_ts = datetime.fromisoformat(entry["ts"].replace("Z", "+00:00")) + data = entry.get("data") if "data" in entry else None + meta = entry.get("meta") if "meta" in entry else None + + def _op(cur, _conn): + insert_engine_event( + cur, + entry.get("event"), + data=data, + message=entry.get("message"), + meta=meta, + ts=event_ts, + ) + + run_with_retry(_op) + +def sleep_with_heartbeat( + total_seconds: int, + stop_event: threading.Event, + user_id: str, + run_id: str, + step_seconds: int = 5, +): + remaining = total_seconds + while remaining > 0 and not stop_event.is_set(): + time.sleep(min(step_seconds, remaining)) + _set_state(user_id, run_id, last_heartbeat_ts=datetime.utcnow().isoformat() + "Z") + remaining -= step_seconds + +def _clear_runner(user_id: str, run_id: str): + key = _state_key(user_id, run_id) + with _RUNNERS_LOCK: + _RUNNERS.pop(key, None) + +def can_execute(now: datetime) -> tuple[bool, str]: + if not is_market_open(now): + return False, "MARKET_CLOSED" + return True, "OK" + +def _engine_loop(config, stop_event: threading.Event): + print("Strategy engine started with config:", config) + + user_id = config.get("user_id") + run_id = config.get("run_id") + scope_user, scope_run = get_context(user_id, run_id) + set_context(scope_user, scope_run) + + strategy_name = config.get("strategy_name") or config.get("strategy") or "golden_nifty" + sip_amount = config["sip_amount"] + configured_frequency = config.get("sip_frequency") or {} + if not isinstance(configured_frequency, dict): + configured_frequency = {} + frequency_value = int(configured_frequency.get("value", config.get("frequency", 0))) + frequency_unit = configured_frequency.get("unit", config.get("unit", "days")) + frequency_label = f"{frequency_value} {frequency_unit}" + emit_event_cb = config.get("emit_event") + if not callable(emit_event_cb): + emit_event_cb = None + debug_enabled = os.getenv("ENGINE_DEBUG", "1").strip().lower() not in {"0", "false", "no"} + + def debug_event(event: str, message: str, meta: dict | None = None): + if not debug_enabled: + return + try: + log_event(event=event, message=message, meta=meta or {}) + except Exception: + pass + if emit_event_cb: + emit_event_cb(event=event, message=message, meta=meta or {}) + print(f"[ENGINE] {event} {message} {meta or {}}", flush=True) + mode = (config.get("mode") or "LIVE").strip().upper() + if mode not in {"PAPER", "LIVE"}: + mode = "LIVE" + broker_type = config.get("broker") or "paper" + if broker_type != "paper": + broker_type = "paper" + if broker_type == "paper": + mode = "PAPER" + initial_cash = float(config.get("initial_cash", 0)) + broker = PaperBroker(initial_cash=initial_cash) + log_event( + event="DEBUG_PAPER_STORE_PATH", + message="Paper broker store path", + meta={ + "cwd": os.getcwd(), + "paper_store_path": str(broker.store_path) if hasattr(broker, "store_path") else "NO_STORE_PATH", + "abs_store_path": os.path.abspath(str(broker.store_path)) if hasattr(broker, "store_path") else "N/A", + }, + ) + if emit_event_cb: + emit_event_cb( + event="DEBUG_PAPER_STORE_PATH", + message="Paper broker store path", + meta={ + "cwd": os.getcwd(), + "paper_store_path": str(broker.store_path) if hasattr(broker, "store_path") else "NO_STORE_PATH", + "abs_store_path": os.path.abspath(str(broker.store_path)) if hasattr(broker, "store_path") else "N/A", + }, + ) + + log_event("ENGINE_START", { + "strategy": strategy_name, + "sip_amount": sip_amount, + "frequency": frequency_label, + }) + debug_event("ENGINE_START_DEBUG", "engine loop started", {"run_id": scope_run, "user_id": scope_user}) + + _set_state( + scope_user, + scope_run, + state="RUNNING", + last_heartbeat_ts=datetime.utcnow().isoformat() + "Z", + ) + _update_engine_status(scope_user, scope_run, "RUNNING") + + try: + while not stop_event.is_set(): + _set_state(scope_user, scope_run, last_heartbeat_ts=datetime.utcnow().isoformat() + "Z") + _update_engine_status(scope_user, scope_run, "RUNNING") + + state = load_state(mode=mode) + debug_event( + "STATE_LOADED", + "loaded engine state", + { + "last_sip_ts": state.get("last_sip_ts"), + "last_run": state.get("last_run"), + "cash": state.get("cash"), + "total_invested": state.get("total_invested"), + }, + ) + state_frequency = state.get("sip_frequency") + if not isinstance(state_frequency, dict): + state_frequency = {"value": frequency_value, "unit": frequency_unit} + freq = int(state_frequency.get("value", frequency_value)) + unit = state_frequency.get("unit", frequency_unit) + frequency_label = f"{freq} {unit}" + if unit == "minutes": + delta = timedelta(minutes=freq) + else: + delta = timedelta(days=freq) + + # Gate 2: time to SIP + last_run = state.get("last_run") or state.get("last_sip_ts") + is_first_run = last_run is None + now = datetime.now() + debug_event( + "ENGINE_LOOP_TICK", + "engine loop tick", + {"now": now.isoformat(), "frequency": frequency_label}, + ) + + if last_run and not is_first_run: + next_run = datetime.fromisoformat(last_run) + delta + next_run = align_to_market_open(next_run) + if now < next_run: + log_event( + event="SIP_WAITING", + message="Waiting for next SIP window", + meta={ + "last_run": last_run, + "next_eligible": next_run.isoformat(), + "now": now.isoformat(), + "frequency": frequency_label, + }, + ) + if emit_event_cb: + emit_event_cb( + event="SIP_WAITING", + message="Waiting for next SIP window", + meta={ + "last_run": last_run, + "next_eligible": next_run.isoformat(), + "now": now.isoformat(), + "frequency": frequency_label, + }, + ) + sleep_with_heartbeat(60, stop_event, scope_user, scope_run) + continue + + try: + debug_event("PRICE_FETCH_START", "fetching live prices", {"tickers": [NIFTY, GOLD]}) + nifty_price = fetch_live_price(NIFTY) + gold_price = fetch_live_price(GOLD) + debug_event( + "PRICE_FETCHED", + "fetched live prices", + {"nifty_price": float(nifty_price), "gold_price": float(gold_price)}, + ) + except Exception as exc: + debug_event("PRICE_FETCH_ERROR", "live price fetch failed", {"error": str(exc)}) + sleep_with_heartbeat(30, stop_event, scope_user, scope_run) + continue + + try: + nifty_hist = load_monthly_close(NIFTY) + gold_hist = load_monthly_close(GOLD) + except Exception as exc: + debug_event("HISTORY_LOAD_ERROR", "history load failed", {"error": str(exc)}) + sleep_with_heartbeat(30, stop_event, scope_user, scope_run) + continue + + nifty_sma = nifty_hist.rolling(SMA_MONTHS).mean().iloc[-1] + gold_sma = gold_hist.rolling(SMA_MONTHS).mean().iloc[-1] + + eq_w, gd_w = allocation( + sp_price=nifty_price, + gd_price=gold_price, + sp_sma=nifty_sma, + gd_sma=gold_sma + ) + debug_event( + "WEIGHTS_COMPUTED", + "computed allocation weights", + {"equity_weight": float(eq_w), "gold_weight": float(gd_w)}, + ) + + weights = {"equity": eq_w, "gold": gd_w} + allowed, reason = can_execute(now) + executed = False + if not allowed: + log_event( + event="EXECUTION_BLOCKED", + message="Execution blocked by market gate", + meta={ + "reason": reason, + "eligible_since": last_run, + "checked_at": now.isoformat(), + }, + ) + debug_event("MARKET_GATE", "market closed", {"reason": reason}) + if emit_event_cb: + emit_event_cb( + event="EXECUTION_BLOCKED", + message="Execution blocked by market gate", + meta={ + "reason": reason, + "eligible_since": last_run, + "checked_at": now.isoformat(), + }, + ) + else: + log_event( + event="DEBUG_BEFORE_TRY_EXECUTE", + message="About to call try_execute_sip", + meta={ + "last_run": last_run, + "frequency": frequency_label, + "allowed": allowed, + "reason": reason, + "sip_amount": sip_amount, + "broker": type(broker).__name__, + "now": now.isoformat(), + }, + ) + if emit_event_cb: + emit_event_cb( + event="DEBUG_BEFORE_TRY_EXECUTE", + message="About to call try_execute_sip", + meta={ + "last_run": last_run, + "frequency": frequency_label, + "allowed": allowed, + "reason": reason, + "sip_amount": sip_amount, + "broker": type(broker).__name__, + "now": now.isoformat(), + }, + ) + debug_event( + "TRY_EXECUTE_START", + "calling try_execute_sip", + {"sip_interval_sec": delta.total_seconds(), "sip_amount": sip_amount}, + ) + state, executed = try_execute_sip( + now=now, + market_open=True, + sip_interval=delta.total_seconds(), + sip_amount=sip_amount, + sp_price=nifty_price, + gd_price=gold_price, + eq_w=eq_w, + gd_w=gd_w, + broker=broker, + mode=mode, + ) + log_event( + event="DEBUG_AFTER_TRY_EXECUTE", + message="Returned from try_execute_sip", + meta={ + "executed": executed, + "state_last_run": state.get("last_run"), + "state_last_sip_ts": state.get("last_sip_ts"), + }, + ) + if emit_event_cb: + emit_event_cb( + event="DEBUG_AFTER_TRY_EXECUTE", + message="Returned from try_execute_sip", + meta={ + "executed": executed, + "state_last_run": state.get("last_run"), + "state_last_sip_ts": state.get("last_sip_ts"), + }, + ) + debug_event( + "TRY_EXECUTE_DONE", + "try_execute_sip finished", + {"executed": executed, "last_run": state.get("last_run")}, + ) + + if executed: + log_event("SIP_TRIGGERED", { + "date": now.date().isoformat(), + "allocation": weights, + "cash_used": sip_amount + }) + debug_event("SIP_TRIGGERED", "sip executed", {"cash_used": sip_amount}) + portfolio_value = ( + state["nifty_units"] * nifty_price + + state["gold_units"] * gold_price + ) + log_event("PORTFOLIO_UPDATED", { + "nifty_units": state["nifty_units"], + "gold_units": state["gold_units"], + "portfolio_value": portfolio_value + }) + print("SIP executed at", now) + + if should_log_mtm(None, now): + logical_time = normalize_logical_time(now) + with db_transaction() as cur: + log_mtm( + nifty_units=state["nifty_units"], + gold_units=state["gold_units"], + nifty_price=nifty_price, + gold_price=gold_price, + total_invested=state["total_invested"], + cur=cur, + logical_time=logical_time, + ) + broker.update_equity( + {NIFTY: nifty_price, GOLD: gold_price}, + now, + cur=cur, + logical_time=logical_time, + ) + + sleep_with_heartbeat(30, stop_event, scope_user, scope_run) + except Exception as e: + _set_state(scope_user, scope_run, state="ERROR", last_heartbeat_ts=datetime.utcnow().isoformat() + "Z") + _update_engine_status(scope_user, scope_run, "ERROR") + log_event("ENGINE_ERROR", {"error": str(e)}) + raise + + log_event("ENGINE_STOP") + _set_state( + scope_user, + scope_run, + state="STOPPED", + last_heartbeat_ts=datetime.utcnow().isoformat() + "Z", + ) + _update_engine_status(scope_user, scope_run, "STOPPED") + print("Strategy engine stopped") + _clear_runner(scope_user, scope_run) + +def start_engine(config): + user_id = config.get("user_id") + run_id = config.get("run_id") + if not user_id: + raise ValueError("user_id is required to start engine") + if not run_id: + raise ValueError("run_id is required to start engine") + + with _RUNNERS_LOCK: + key = _state_key(user_id, run_id) + runner = _RUNNERS.get(key) + if runner and runner["thread"].is_alive(): + return False + + stop_event = threading.Event() + thread = threading.Thread( + target=_engine_loop, + args=(config, stop_event), + daemon=True, + ) + _RUNNERS[key] = {"thread": thread, "stop_event": stop_event} + thread.start() + return True + +def stop_engine(user_id: str, run_id: str | None = None, timeout: float | None = 10.0): + runners = [] + with _RUNNERS_LOCK: + if run_id: + key = _state_key(user_id, run_id) + runner = _RUNNERS.get(key) + if runner: + runners.append((key, runner)) + else: + for key, runner in list(_RUNNERS.items()): + if key[0] == user_id: + runners.append((key, runner)) + for _key, runner in runners: + runner["stop_event"].set() + stopped_all = True + for key, runner in runners: + thread = runner["thread"] + if timeout is not None: + thread.join(timeout=timeout) + stopped = not thread.is_alive() + if stopped: + _clear_runner(key[0], key[1]) + else: + stopped_all = False + return stopped_all + diff --git a/indian_paper_trading_strategy/engine/state.py b/indian_paper_trading_strategy/engine/state.py new file mode 100644 index 0000000..9ec4ccc --- /dev/null +++ b/indian_paper_trading_strategy/engine/state.py @@ -0,0 +1,303 @@ +# engine/state.py +from datetime import datetime, timezone + +from indian_paper_trading_strategy.engine.db import db_connection, insert_engine_event, run_with_retry, get_context + +DEFAULT_STATE = { + "initial_cash": 0.0, + "cash": 0.0, + "total_invested": 0.0, + "nifty_units": 0.0, + "gold_units": 0.0, + "last_sip_ts": None, + "last_run": None, + "sip_frequency": None, +} + +DEFAULT_PAPER_STATE = { + **DEFAULT_STATE, + "initial_cash": 1_000_000.0, + "cash": 1_000_000.0, + "sip_frequency": {"value": 30, "unit": "days"}, +} + +def _state_key(mode: str | None): + key = (mode or "LIVE").strip().upper() + return "PAPER" if key == "PAPER" else "LIVE" + +def _default_state(mode: str | None): + if _state_key(mode) == "PAPER": + return DEFAULT_PAPER_STATE.copy() + return DEFAULT_STATE.copy() + +def _local_tz(): + return datetime.now().astimezone().tzinfo + +def _format_local_ts(value: datetime | None): + if value is None: + return None + return value.astimezone(_local_tz()).replace(tzinfo=None).isoformat() + +def _parse_ts(value): + if value is None: + return None + if isinstance(value, datetime): + if value.tzinfo is None: + return value.replace(tzinfo=_local_tz()) + return value + if isinstance(value, str): + text = value.strip() + if not text: + return None + try: + parsed = datetime.fromisoformat(text.replace("Z", "+00:00")) + except ValueError: + return None + if parsed.tzinfo is None: + parsed = parsed.replace(tzinfo=_local_tz()) + return parsed + return None + +def _resolve_scope(user_id: str | None, run_id: str | None): + return get_context(user_id, run_id) + + +def load_state( + mode: str | None = "LIVE", + *, + cur=None, + for_update: bool = False, + user_id: str | None = None, + run_id: str | None = None, +): + scope_user, scope_run = _resolve_scope(user_id, run_id) + key = _state_key(mode) + if key == "PAPER": + if cur is None: + with db_connection() as conn: + with conn.cursor() as cur: + return load_state( + mode=mode, + cur=cur, + for_update=for_update, + user_id=scope_user, + run_id=scope_run, + ) + lock_clause = " FOR UPDATE" if for_update else "" + cur.execute( + f""" + SELECT initial_cash, cash, total_invested, nifty_units, gold_units, + last_sip_ts, last_run, sip_frequency_value, sip_frequency_unit + FROM engine_state_paper + WHERE user_id = %s AND run_id = %s{lock_clause} + LIMIT 1 + """, + (scope_user, scope_run), + ) + row = cur.fetchone() + if not row: + return _default_state(mode) + merged = _default_state(mode) + merged.update( + { + "initial_cash": float(row[0]) if row[0] is not None else merged["initial_cash"], + "cash": float(row[1]) if row[1] is not None else merged["cash"], + "total_invested": float(row[2]) if row[2] is not None else merged["total_invested"], + "nifty_units": float(row[3]) if row[3] is not None else merged["nifty_units"], + "gold_units": float(row[4]) if row[4] is not None else merged["gold_units"], + "last_sip_ts": _format_local_ts(row[5]), + "last_run": _format_local_ts(row[6]), + } + ) + if row[7] is not None or row[8] is not None: + merged["sip_frequency"] = {"value": row[7], "unit": row[8]} + return merged + + if cur is None: + with db_connection() as conn: + with conn.cursor() as cur: + return load_state( + mode=mode, + cur=cur, + for_update=for_update, + user_id=scope_user, + run_id=scope_run, + ) + lock_clause = " FOR UPDATE" if for_update else "" + cur.execute( + f""" + SELECT total_invested, nifty_units, gold_units, last_sip_ts, last_run + FROM engine_state + WHERE user_id = %s AND run_id = %s{lock_clause} + LIMIT 1 + """, + (scope_user, scope_run), + ) + row = cur.fetchone() + if not row: + return _default_state(mode) + merged = _default_state(mode) + merged.update( + { + "total_invested": float(row[0]) if row[0] is not None else merged["total_invested"], + "nifty_units": float(row[1]) if row[1] is not None else merged["nifty_units"], + "gold_units": float(row[2]) if row[2] is not None else merged["gold_units"], + "last_sip_ts": _format_local_ts(row[3]), + "last_run": _format_local_ts(row[4]), + } + ) + return merged + +def init_paper_state( + initial_cash: float, + sip_frequency: dict | None = None, + *, + cur=None, + user_id: str | None = None, + run_id: str | None = None, +): + state = DEFAULT_PAPER_STATE.copy() + state.update( + { + "initial_cash": float(initial_cash), + "cash": float(initial_cash), + "total_invested": 0.0, + "nifty_units": 0.0, + "gold_units": 0.0, + "last_sip_ts": None, + "last_run": None, + "sip_frequency": sip_frequency or state.get("sip_frequency"), + } + ) + save_state(state, mode="PAPER", cur=cur, emit_event=True, user_id=user_id, run_id=run_id) + return state + +def save_state( + state, + mode: str | None = "LIVE", + *, + cur=None, + emit_event: bool = False, + event_meta: dict | None = None, + user_id: str | None = None, + run_id: str | None = None, +): + scope_user, scope_run = _resolve_scope(user_id, run_id) + key = _state_key(mode) + last_sip_ts = _parse_ts(state.get("last_sip_ts")) + last_run = _parse_ts(state.get("last_run")) + if key == "PAPER": + sip_frequency = state.get("sip_frequency") + sip_value = None + sip_unit = None + if isinstance(sip_frequency, dict): + sip_value = sip_frequency.get("value") + sip_unit = sip_frequency.get("unit") + def _save(cur): + cur.execute( + """ + INSERT INTO engine_state_paper ( + user_id, run_id, initial_cash, cash, total_invested, nifty_units, gold_units, + last_sip_ts, last_run, sip_frequency_value, sip_frequency_unit + ) + VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s) + ON CONFLICT (user_id, run_id) DO UPDATE + SET initial_cash = EXCLUDED.initial_cash, + cash = EXCLUDED.cash, + total_invested = EXCLUDED.total_invested, + nifty_units = EXCLUDED.nifty_units, + gold_units = EXCLUDED.gold_units, + last_sip_ts = EXCLUDED.last_sip_ts, + last_run = EXCLUDED.last_run, + sip_frequency_value = EXCLUDED.sip_frequency_value, + sip_frequency_unit = EXCLUDED.sip_frequency_unit + """, + ( + scope_user, + scope_run, + float(state.get("initial_cash", 0.0)), + float(state.get("cash", 0.0)), + float(state.get("total_invested", 0.0)), + float(state.get("nifty_units", 0.0)), + float(state.get("gold_units", 0.0)), + last_sip_ts, + last_run, + sip_value, + sip_unit, + ), + ) + if emit_event: + insert_engine_event( + cur, + "STATE_UPDATED", + data={ + "mode": "PAPER", + "cash": state.get("cash"), + "total_invested": state.get("total_invested"), + "nifty_units": state.get("nifty_units"), + "gold_units": state.get("gold_units"), + "last_sip_ts": state.get("last_sip_ts"), + "last_run": state.get("last_run"), + }, + meta=event_meta, + ts=datetime.utcnow().replace(tzinfo=timezone.utc), + ) + + if cur is not None: + _save(cur) + return + + def _op(cur, _conn): + _save(cur) + + return run_with_retry(_op) + + def _save(cur): + cur.execute( + """ + INSERT INTO engine_state ( + user_id, run_id, total_invested, nifty_units, gold_units, last_sip_ts, last_run + ) + VALUES (%s, %s, %s, %s, %s, %s, %s) + ON CONFLICT (user_id, run_id) DO UPDATE + SET total_invested = EXCLUDED.total_invested, + nifty_units = EXCLUDED.nifty_units, + gold_units = EXCLUDED.gold_units, + last_sip_ts = EXCLUDED.last_sip_ts, + last_run = EXCLUDED.last_run + """, + ( + scope_user, + scope_run, + float(state.get("total_invested", 0.0)), + float(state.get("nifty_units", 0.0)), + float(state.get("gold_units", 0.0)), + last_sip_ts, + last_run, + ), + ) + if emit_event: + insert_engine_event( + cur, + "STATE_UPDATED", + data={ + "mode": "LIVE", + "total_invested": state.get("total_invested"), + "nifty_units": state.get("nifty_units"), + "gold_units": state.get("gold_units"), + "last_sip_ts": state.get("last_sip_ts"), + "last_run": state.get("last_run"), + }, + meta=event_meta, + ts=datetime.utcnow().replace(tzinfo=timezone.utc), + ) + + if cur is not None: + _save(cur) + return + + def _op(cur, _conn): + _save(cur) + + return run_with_retry(_op) + diff --git a/indian_paper_trading_strategy/engine/strategy.py b/indian_paper_trading_strategy/engine/strategy.py new file mode 100644 index 0000000..504034e --- /dev/null +++ b/indian_paper_trading_strategy/engine/strategy.py @@ -0,0 +1,12 @@ +# engine/strategy.py +import numpy as np + +def allocation(sp_price, gd_price, sp_sma, gd_sma, + base=0.6, tilt_mult=1.5, + max_tilt=0.25, min_eq=0.2, max_eq=0.9): + + rd = (sp_price / sp_sma) - (gd_price / gd_sma) + tilt = np.clip(-rd * tilt_mult, -max_tilt, max_tilt) + + eq_w = np.clip(base * (1 + tilt), min_eq, max_eq) + return eq_w, 1 - eq_w diff --git a/indian_paper_trading_strategy/engine/time_utils.py b/indian_paper_trading_strategy/engine/time_utils.py new file mode 100644 index 0000000..3315338 --- /dev/null +++ b/indian_paper_trading_strategy/engine/time_utils.py @@ -0,0 +1,41 @@ +from datetime import datetime, timedelta + + +def frequency_to_timedelta(freq: dict) -> timedelta: + value = int(freq.get("value", 0)) + unit = freq.get("unit") + + if value <= 0: + raise ValueError("Frequency value must be > 0") + + if unit == "minutes": + return timedelta(minutes=value) + if unit == "days": + return timedelta(days=value) + raise ValueError(f"Unsupported frequency unit: {unit}") + + +def normalize_logical_time(ts: datetime) -> datetime: + return ts.replace(microsecond=0) + + +def compute_logical_time( + now: datetime, + last_run: str | None, + interval_seconds: float | None, +) -> datetime: + base = now + if last_run and interval_seconds: + try: + parsed = datetime.fromisoformat(last_run.replace("Z", "+00:00")) + except ValueError: + parsed = None + if parsed is not None: + if now.tzinfo and parsed.tzinfo is None: + parsed = parsed.replace(tzinfo=now.tzinfo) + elif now.tzinfo is None and parsed.tzinfo: + parsed = parsed.replace(tzinfo=None) + candidate = parsed + timedelta(seconds=interval_seconds) + if now >= candidate: + base = candidate + return normalize_logical_time(base) diff --git a/paper_live_trading/appstreamlit_app.py b/paper_live_trading/appstreamlit_app.py new file mode 100644 index 0000000..582b140 --- /dev/null +++ b/paper_live_trading/appstreamlit_app.py @@ -0,0 +1,182 @@ +# app/streamlit_app.py +import streamlit as st +import time +from datetime import datetime +from pathlib import Path +import pandas as pd +from history import load_monthly_close +from market import us_market_status +from data import fetch_live_price +from strategy import allocation +from execution import try_execute_sip +from state import load_state, save_state +from mtm import log_mtm, should_log_mtm + +if "engine_active" not in st.session_state: + st.session_state.engine_active = False + +SP500 = "SPY" +GOLD = "GLD" +GOLD_FALLBACK = "IAU" +SMA_MONTHS = 36 + +def get_prices(): + try: + sp = fetch_live_price("SPY") + except Exception as e: + st.error(f"SPY price error: {e}") + return None, None + + try: + gd = fetch_live_price("GLD") + except Exception: + try: + gd = fetch_live_price("IAU") + except Exception as e: + st.error(f"Gold price error: {e}") + return None, None + + return sp, gd + +SIP_AMOUNT = st.number_input("SIP Amount ($)", 100, 5000, 1000) +SIP_INTERVAL_SEC = st.number_input("SIP Interval (sec) [TEST]", 30, 3600, 120) + +st.title("SIPXAR — Phase-1 Safe Engine") + +market_open, market_time = us_market_status() +st.info(f"Market {'OPEN' if market_open else 'CLOSED'} | ET {market_time}") + +col1, col2 = st.columns(2) + +with col1: + if st.button("START ENGINE"): + st.session_state.engine_active = True + + # HARD RESET + for f in [ + Path("storage/state.json"), + Path("storage/mtm_ledger.csv"), + Path("storage/ledger.csv"), + ]: + if f.exists(): + f.unlink() + + fresh_state = { + "total_invested": 0.0, + "sp_units": 0.0, + "gd_units": 0.0, + "last_sip_ts": datetime.utcnow().isoformat(), + } + save_state(fresh_state) + + st.success("Engine started") + +with col2: + if st.button("KILL ENGINE"): + st.session_state.engine_active = False + + # HARD RESET + from pathlib import Path + + for f in [ + Path("storage/state.json"), + Path("storage/mtm_ledger.csv"), + Path("storage/ledger.csv"), + ]: + if f.exists(): + f.unlink() + + st.session_state.clear() + st.warning("Engine killed and state wiped") + + st.stop() + +if not st.session_state.engine_active: + st.info("Engine is stopped. Click START to begin.") + st.stop() + +state = load_state() + +sp_price, gd_price = get_prices() + +if sp_price is None or gd_price is None: + st.stop() + +st.write("Engine Prices") +st.write("SP Price:", sp_price) +st.write("Gold Price:", gd_price) + +sp_hist = load_monthly_close(SP500) +gd_hist = load_monthly_close(GOLD_FALLBACK) + +sp_sma = sp_hist.rolling(SMA_MONTHS).mean().iloc[-1] +gd_sma = gd_hist.rolling(SMA_MONTHS).mean().iloc[-1] + +eq_w, gd_w = allocation( + sp_price=sp_price, + gd_price=gd_price, + sp_sma=sp_sma, + gd_sma=gd_sma +) + +state, executed = try_execute_sip( + now=datetime.utcnow(), + market_open=market_open, + sip_interval=SIP_INTERVAL_SEC, + sip_amount=SIP_AMOUNT, + sp_price=sp_price, + gd_price=gd_price, + eq_w=eq_w, + gd_w=gd_w +) + +MTM_FILE = Path("storage/mtm_ledger.csv") + +if MTM_FILE.exists(): + mtm_df = pd.read_csv(MTM_FILE) +else: + mtm_df = pd.DataFrame() + +now = datetime.utcnow() + +if market_open and should_log_mtm(mtm_df, now): + portfolio_value, pnl = log_mtm( + sp_units=state["sp_units"], + gd_units=state["gd_units"], + sp_price=sp_price, + gd_price=gd_price, + total_invested=state["total_invested"], + ) +else: + # Do NOT log MTM when market is closed + portfolio_value = ( + state["sp_units"] * sp_price + + state["gd_units"] * gd_price + ) + pnl = portfolio_value - state["total_invested"] + +st.metric("Total Invested", f"${state['total_invested']:,.2f}") +st.metric("SP Units", round(state["sp_units"], 4)) +st.metric("Gold Units", round(state["gd_units"], 4)) +st.metric("Portfolio Value", f"${portfolio_value:,.2f}") +st.metric("Unrealized PnL", f"${pnl:,.2f}", delta=f"{pnl:,.2f}") + +st.subheader("Equity Curve (Unrealized PnL)") + +if not mtm_df.empty and len(mtm_df) > 1: + mtm_df["timestamp"] = pd.to_datetime(mtm_df["timestamp"]) + mtm_df = mtm_df.sort_values("timestamp") + mtm_df.set_index("timestamp", inplace=True) + + st.line_chart( + mtm_df["pnl"], + height=350, + ) +else: + st.info("Equity curve will appear after sufficient MTM data.") + +if executed: + st.success("SIP Executed") + +time.sleep(5) +st.rerun() diff --git a/paper_live_trading/data.py b/paper_live_trading/data.py new file mode 100644 index 0000000..892eb3c --- /dev/null +++ b/paper_live_trading/data.py @@ -0,0 +1,16 @@ +# engine/data.py +import yfinance as yf + +def fetch_live_price(ticker): + df = yf.download( + ticker, + period="1d", + interval="1m", + auto_adjust=True, + progress=False, + ) + + if df.empty: + raise RuntimeError(f"No live data for {ticker}") + + return float(df["Close"].iloc[-1]) diff --git a/paper_live_trading/execution.py b/paper_live_trading/execution.py new file mode 100644 index 0000000..44ac79a --- /dev/null +++ b/paper_live_trading/execution.py @@ -0,0 +1,62 @@ +# engine/execution.py +from datetime import datetime +from state import load_state, save_state +from ledger import log_event + +def _as_float(value): + if hasattr(value, "item"): + try: + return float(value.item()) + except Exception: + pass + if hasattr(value, "iloc"): + try: + return float(value.iloc[-1]) + except Exception: + pass + return float(value) + +def try_execute_sip( + now, + market_open, + sip_interval, + sip_amount, + sp_price, + gd_price, + eq_w, + gd_w, +): + state = load_state() + + if not market_open: + return state, False + + last = state["last_sip_ts"] + if last and (now - datetime.fromisoformat(last)).total_seconds() < sip_interval: + return state, False + + sp_price = _as_float(sp_price) + gd_price = _as_float(gd_price) + eq_w = _as_float(eq_w) + gd_w = _as_float(gd_w) + sip_amount = _as_float(sip_amount) + + sp_qty = (sip_amount * eq_w) / sp_price + gd_qty = (sip_amount * gd_w) / gd_price + + state["sp_units"] += sp_qty + state["gd_units"] += gd_qty + state["total_invested"] += sip_amount + state["last_sip_ts"] = now.isoformat() + + save_state(state) + + log_event("SIP_EXECUTED", { + "sp_units": sp_qty, + "gd_units": gd_qty, + "sp_price": sp_price, + "gd_price": gd_price, + "amount": sip_amount + }) + + return state, True diff --git a/paper_live_trading/history.py b/paper_live_trading/history.py new file mode 100644 index 0000000..ccb695f --- /dev/null +++ b/paper_live_trading/history.py @@ -0,0 +1,29 @@ +# engine/history.py +import yfinance as yf +import pandas as pd +from pathlib import Path + +CACHE_DIR = Path("storage/history") +CACHE_DIR.mkdir(exist_ok=True) + +def load_monthly_close(ticker, years=10): + file = CACHE_DIR / f"{ticker}.csv" + + if file.exists(): + df = pd.read_csv(file, parse_dates=["Date"], index_col="Date") + return df["Close"] + + df = yf.download( + ticker, + period=f"{years}y", + auto_adjust=True, + progress=False + ) + + if df.empty: + raise RuntimeError(f"No history for {ticker}") + + series = df["Close"].resample("M").last() + series.to_csv(file, header=["Close"]) + + return series diff --git a/paper_live_trading/ledger.py b/paper_live_trading/ledger.py new file mode 100644 index 0000000..beed78a --- /dev/null +++ b/paper_live_trading/ledger.py @@ -0,0 +1,22 @@ +# engine/ledger.py +import pandas as pd +from pathlib import Path +from datetime import datetime + +LEDGER_FILE = Path("storage/ledger.csv") + +def log_event(event, payload): + row = { + "timestamp": datetime.utcnow().isoformat(), + "event": event, + **payload + } + + df = pd.DataFrame([row]) + + LEDGER_FILE.parent.mkdir(exist_ok=True) + + if LEDGER_FILE.exists(): + df.to_csv(LEDGER_FILE, mode="a", header=False, index=False) + else: + df.to_csv(LEDGER_FILE, index=False) diff --git a/paper_live_trading/market.py b/paper_live_trading/market.py new file mode 100644 index 0000000..17bfc0e --- /dev/null +++ b/paper_live_trading/market.py @@ -0,0 +1,14 @@ +# engine/market.py +from datetime import datetime, time as dtime +import pytz + +def us_market_status(): + tz = pytz.timezone("America/New_York") + now = datetime.now(tz) + + open_t = dtime(9, 30) + close_t = dtime(16, 0) + + is_open = now.weekday() < 5 and open_t <= now.time() <= close_t + + return is_open, now diff --git a/paper_live_trading/mtm.py b/paper_live_trading/mtm.py new file mode 100644 index 0000000..565617a --- /dev/null +++ b/paper_live_trading/mtm.py @@ -0,0 +1,47 @@ +import pandas as pd +from pathlib import Path +from datetime import datetime + +MTM_FILE = Path("storage/mtm_ledger.csv") + +def log_mtm( + sp_units, + gd_units, + sp_price, + gd_price, + total_invested, +): + sp_value = sp_units * sp_price + gd_value = gd_units * gd_price + portfolio_value = sp_value + gd_value + pnl = portfolio_value - total_invested + + row = { + "timestamp": datetime.utcnow().isoformat(), + "sp_units": sp_units, + "gd_units": gd_units, + "sp_price": sp_price, + "gd_price": gd_price, + "sp_value": sp_value, + "gd_value": gd_value, + "portfolio_value": portfolio_value, + "total_invested": total_invested, + "pnl": pnl, + } + + df = pd.DataFrame([row]) + + MTM_FILE.parent.mkdir(exist_ok=True) + + if MTM_FILE.exists(): + df.to_csv(MTM_FILE, mode="a", header=False, index=False) + else: + df.to_csv(MTM_FILE, index=False) + + return portfolio_value, pnl + +def should_log_mtm(df, current_ts): + if df.empty: + return True + last_ts = pd.to_datetime(df.iloc[-1]["timestamp"]) + return current_ts.minute != last_ts.minute diff --git a/paper_live_trading/state.py b/paper_live_trading/state.py new file mode 100644 index 0000000..e75857d --- /dev/null +++ b/paper_live_trading/state.py @@ -0,0 +1,21 @@ +# engine/state.py +import json +from pathlib import Path + +STATE_FILE = Path("storage/state.json") + +DEFAULT_STATE = { + "total_invested": 0.0, + "sp_units": 0.0, + "gd_units": 0.0, + "last_sip_ts": None +} + +def load_state(): + if not STATE_FILE.exists(): + return DEFAULT_STATE.copy() + return json.loads(STATE_FILE.read_text()) + +def save_state(state): + STATE_FILE.parent.mkdir(exist_ok=True) + STATE_FILE.write_text(json.dumps(state, indent=2)) diff --git a/paper_live_trading/strategy.py b/paper_live_trading/strategy.py new file mode 100644 index 0000000..504034e --- /dev/null +++ b/paper_live_trading/strategy.py @@ -0,0 +1,12 @@ +# engine/strategy.py +import numpy as np + +def allocation(sp_price, gd_price, sp_sma, gd_sma, + base=0.6, tilt_mult=1.5, + max_tilt=0.25, min_eq=0.2, max_eq=0.9): + + rd = (sp_price / sp_sma) - (gd_price / gd_sma) + tilt = np.clip(-rd * tilt_mult, -max_tilt, max_tilt) + + eq_w = np.clip(base * (1 + tilt), min_eq, max_eq) + return eq_w, 1 - eq_w diff --git a/signup_test.py b/signup_test.py new file mode 100644 index 0000000..b2bdfd0 --- /dev/null +++ b/signup_test.py @@ -0,0 +1,7 @@ +import requests +import json +url = 'http://localhost:8000/api/signup' +data = {'email': 'testuser@example.com', 'password': 'TestPass123!'} +resp = requests.post(url, json=data) +print(resp.status_code) +print(resp.text) diff --git a/start_all.ps1 b/start_all.ps1 new file mode 100644 index 0000000..26ad2d1 --- /dev/null +++ b/start_all.ps1 @@ -0,0 +1,168 @@ +$ErrorActionPreference = "Stop" + +$root = Split-Path -Parent $MyInvocation.MyCommand.Path +$stateDir = Join-Path $root ".orchestration" +$pidFile = Join-Path $stateDir "pids.json" +$python = Join-Path $root ".venv\\Scripts\\python.exe" +$status = [ordered]@{} +$engineExternal = $env:ENGINE_EXTERNAL -and $env:ENGINE_EXTERNAL.ToLower() -in @("1", "true", "yes") + +if (-not $env:DB_HOST) { $env:DB_HOST = "localhost" } +if (-not $env:DB_PORT) { $env:DB_PORT = "5432" } +if (-not $env:DB_NAME) { $env:DB_NAME = "trading_db" } +if (-not $env:DB_USER) { $env:DB_USER = "trader" } +if (-not $env:DB_PASSWORD) { $env:DB_PASSWORD = "traderpass" } +if (-not $env:PGHOST) { $env:PGHOST = $env:DB_HOST } +if (-not $env:PGPORT) { $env:PGPORT = $env:DB_PORT } +if (-not $env:PGDATABASE) { $env:PGDATABASE = $env:DB_NAME } +if (-not $env:PGUSER) { $env:PGUSER = $env:DB_USER } +if (-not $env:PGPASSWORD) { $env:PGPASSWORD = $env:DB_PASSWORD } + +function Write-Status { + param( + [string]$Label, + [bool]$Ok, + [string]$Detail = "" + ) + if ($Ok) { + $status[$Label] = "OK" + Write-Host ("[OK] {0} {1}" -f $Label, $Detail) -ForegroundColor Green + } else { + $status[$Label] = "FAIL" + Write-Host ("[FAIL] {0} {1}" -f $Label, $Detail) -ForegroundColor Red + } +} + +if (-not (Test-Path $python)) { + Write-Status "Python venv" $false $python + exit 1 +} +Write-Status "Python venv" $true $python + +New-Item -ItemType Directory -Force -Path $stateDir | Out-Null + +if (Test-Path $pidFile) { + try { + $pids = Get-Content $pidFile | ConvertFrom-Json + $running = $false + foreach ($name in @("backend", "engine", "frontend")) { + $pid = $pids.$name + if ($pid) { + $proc = Get-Process -Id $pid -ErrorAction SilentlyContinue + if ($proc) { $running = $true } + } + } + if ($running) { + Write-Host "Services already running. Use .\\stop_all.ps1 to stop." + exit 0 + } + } catch { + Remove-Item $pidFile -Force -ErrorAction SilentlyContinue + } +} + +Write-Host "Starting PostgreSQL (docker)..." +$pgRunning = "" +docker ps --filter "name=trading_postgres" --format "{{.ID}}" | ForEach-Object { $pgRunning = $_ } +if ($LASTEXITCODE -ne 0) { + Write-Status "Docker" $false "not running" + exit 1 +} +Write-Status "Docker" $true "running" +if (-not $pgRunning) { + docker compose up -d postgres + if ($LASTEXITCODE -ne 0) { + Write-Status "PostgreSQL" $false "failed to start" + exit 1 + } +} +Write-Status "PostgreSQL" $true "container up" + +Write-Host "Applying migrations..." +if (Test-Path (Join-Path $root "alembic.ini")) { + & $python -m alembic upgrade head +} elseif (Test-Path (Join-Path $root "backend\\alembic.ini")) { + & $python -m alembic -c (Join-Path $root "backend\\alembic.ini") upgrade head +} +& $python (Join-Path $root "backend\\scripts\\run_migrations.py") +if ($LASTEXITCODE -ne 0) { + Write-Status "Migrations" $false "failed" + exit 1 +} +Write-Status "Migrations" $true "applied" + +Write-Host "Starting backend..." +$backendCmd = "& { `$env:PYTHONPATH = '$root;$root\\backend'; cd '$root\\backend'; & '$python' -m uvicorn app.main:app --reload }" +$backendProc = Start-Process -FilePath "powershell" -ArgumentList @( + "-NoExit", + "-Command", + $backendCmd +) -PassThru +Write-Status "Backend" $true "pid $($backendProc.Id)" + +if ($engineExternal) { + Write-Host "Starting engine runner..." + $engineCmd = "& { `$env:PYTHONPATH = '$root;$root\\backend'; cd '$root'; & '$python' -m indian_paper_trading_strategy.engine.engine_runner }" + $engineProc = Start-Process -FilePath "powershell" -ArgumentList @( + "-NoExit", + "-Command", + $engineCmd + ) -PassThru + Write-Status "Engine" $true "pid $($engineProc.Id)" +} else { + Write-Host "Engine runner skipped (ENGINE_EXTERNAL not set). Backend will run engine." + Write-Status "Engine" $true "embedded" +} + +Write-Host "Starting frontend..." +$frontendProc = Start-Process -FilePath "powershell" -ArgumentList @( + "-NoExit", + "-Command", + "cd '$root\\frontend'; npm run dev" +) -PassThru +Write-Status "Frontend" $true "pid $($frontendProc.Id)" + +@{ + backend = $backendProc.Id + engine = if ($engineExternal) { $engineProc.Id } else { $null } + frontend = $frontendProc.Id + started_at = (Get-Date).ToString("o") +} | ConvertTo-Json | Set-Content -Encoding ascii $pidFile + +Write-Host "Waiting for backend health..." +$healthUrl = "http://localhost:8000/health" +$deadline = [DateTime]::UtcNow.AddMinutes(2) +$healthy = $false +while ([DateTime]::UtcNow -lt $deadline) { + try { + $resp = Invoke-WebRequest -Uri $healthUrl -UseBasicParsing -TimeoutSec 2 + if ($resp.StatusCode -eq 200) { + $healthy = $true + break + } + } catch { + Start-Sleep -Seconds 2 + } +} + +if ($healthy) { + Write-Status "Backend health" $true $healthUrl + Write-Host "" + Write-Host "System online at http://localhost:3000/admin" + Start-Process "http://localhost:3000/admin" +} else { + Write-Status "Backend health" $false $healthUrl + Write-Host "Backend health check failed. Open http://localhost:3000/admin manually." +} + +Write-Host "" +Write-Host "Summary:" +foreach ($entry in $status.GetEnumerator()) { + $label = $entry.Key + $value = $entry.Value + if ($value -eq "OK") { + Write-Host (" [OK] {0}" -f $label) -ForegroundColor Green + } else { + Write-Host (" [FAIL] {0}" -f $label) -ForegroundColor Red + } +} diff --git a/start_all.sh b/start_all.sh new file mode 100644 index 0000000..58690c2 --- /dev/null +++ b/start_all.sh @@ -0,0 +1,68 @@ +#!/usr/bin/env bash +set -euo pipefail + +ROOT="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +STATE_DIR="${ROOT}/.orchestration" +PID_FILE="${STATE_DIR}/pids" +PYTHON="${ROOT}/.venv/bin/python" + +if [[ ! -x "${PYTHON}" ]]; then + echo "Missing venv python at ${PYTHON}" + exit 1 +fi + +mkdir -p "${STATE_DIR}" + +echo "Starting PostgreSQL (docker)..." +if ! docker ps --format '{{.Names}}' | grep -q '^trading_postgres$'; then + docker compose up -d postgres +fi + +echo "Applying migrations..." +if [[ -f "${ROOT}/alembic.ini" ]]; then + "${PYTHON}" -m alembic upgrade head +elif [[ -f "${ROOT}/backend/alembic.ini" ]]; then + "${PYTHON}" -m alembic -c "${ROOT}/backend/alembic.ini" upgrade head +fi +"${PYTHON}" "${ROOT}/backend/scripts/run_migrations.py" + +echo "Starting backend..." +(cd "${ROOT}/backend" && PYTHONPATH="${ROOT}:${ROOT}/backend" "${PYTHON}" -m uvicorn app.main:app --reload) & +BACKEND_PID=$! + +echo "Starting engine runner..." +(cd "${ROOT}" && PYTHONPATH="${ROOT}:${ROOT}/backend" "${PYTHON}" -m indian_paper_trading_strategy.engine.engine_runner) & +ENGINE_PID=$! + +echo "Starting frontend..." +(cd "${ROOT}/frontend" && npm run dev) & +FRONTEND_PID=$! + +cat > "${PID_FILE}" </dev/null 2>&1; then + READY=1 + break + fi + sleep 2 +done + +if [[ "${READY}" -eq 1 ]]; then + if command -v xdg-open >/dev/null 2>&1; then + xdg-open "http://localhost:3000/admin" >/dev/null 2>&1 || true + elif command -v open >/dev/null 2>&1; then + open "http://localhost:3000/admin" >/dev/null 2>&1 || true + else + echo "Open http://localhost:3000/admin in your browser." + fi +else + echo "Backend health check failed. Open http://localhost:3000/admin manually." +fi diff --git a/start_ngrok.ps1 b/start_ngrok.ps1 new file mode 100644 index 0000000..6f324d6 --- /dev/null +++ b/start_ngrok.ps1 @@ -0,0 +1,27 @@ +$ngrok = 'C:\Tools\ngrok\ngrok.exe' +$config = 'C:\Users\quantfortune\SIP\SIP_India\ngrok.yml' + +if (-not (Test-Path $ngrok)) { + Write-Host 'ngrok not found at C:\Tools\ngrok\ngrok.exe' + exit 1 +} + +Start-Process -FilePath $ngrok -ArgumentList @('start', 'frontend', '--config', $config) | Out-Null +Start-Sleep -Seconds 3 + +$info = Invoke-RestMethod -Uri 'http://127.0.0.1:4040/api/tunnels' +$frontend = ($info.tunnels | Where-Object { $_.name -eq 'frontend' }).public_url +if (-not $frontend) { + $frontend = ($info.tunnels | Select-Object -First 1).public_url +} +$backend = $frontend + +if ($frontend) { + Set-Content -Path 'C:\Users\quantfortune\SIP\SIP_India\ngrok_frontend_url.txt' -Value $frontend -Encoding ascii +} +if ($backend) { + Set-Content -Path 'C:\Users\quantfortune\SIP\SIP_India\ngrok_backend_url.txt' -Value $backend -Encoding ascii +} + +Write-Host "Frontend URL: $frontend" +Write-Host "Backend URL: $backend" diff --git a/steps_instruction.txt b/steps_instruction.txt new file mode 100644 index 0000000..49c076a --- /dev/null +++ b/steps_instruction.txt @@ -0,0 +1,25 @@ +docker exec -it trading_postgres psql -U trader -d trading_db ( to enter the database) +\dt( to see all the tables) +SELECT * FROM app_session;(opening the storage) + +SELECT run_id, status, started_at +FROM strategy_run +ORDER BY created_at DESC +LIMIT 1; +(to check whether startedy running or not) + +2. check how many users are running +SELECT user_id, run_id, status, started_at +FROM strategy_run +WHERE status = 'RUNNING'; + +http://localhost:3000/admin/users( admin dashboard) + +powershell -ExecutionPolicy Bypass -File "C:\Users\91995\Desktop\thigal\Data-Sage\Data-Sage\stop_all.ps1" +powershell -ExecutionPolicy Bypass -File "C:\Users\91995\Desktop\thigal\Data-Sage\Data-Sage\start_all.ps1" + + + +how to test +1. terminal 1 - backend should be running +2. terminal 2 - C:\Users\91995\Desktop\thigal\Data-Sage\Data-Sage>.\.venv\Scripts\python -m pytest -q diff --git a/stop_all.ps1 b/stop_all.ps1 new file mode 100644 index 0000000..9e320db --- /dev/null +++ b/stop_all.ps1 @@ -0,0 +1,28 @@ +$ErrorActionPreference = "Stop" + +$root = Split-Path -Parent $MyInvocation.MyCommand.Path +$pidFile = Join-Path $root ".orchestration\\pids.json" + +if (Test-Path $pidFile) { + $pids = Get-Content $pidFile | ConvertFrom-Json + foreach ($name in @("frontend", "engine", "backend")) { + $pid = $pids.$name + if ($pid) { + $proc = Get-Process -Id $pid -ErrorAction SilentlyContinue + if ($proc) { + Write-Host "Stopping $name (pid $pid)..." + Stop-Process -Id $pid + } + } + } + Remove-Item $pidFile -Force +} else { + Write-Host "No pid file found." +} + +Write-Host "Stopping PostgreSQL (docker)..." +try { + docker compose stop postgres | Out-Null +} catch { + Write-Host "Docker not available or postgres not running." +} diff --git a/stop_all.sh b/stop_all.sh new file mode 100644 index 0000000..c515b42 --- /dev/null +++ b/stop_all.sh @@ -0,0 +1,21 @@ +#!/usr/bin/env bash +set -euo pipefail + +ROOT="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +PID_FILE="${ROOT}/.orchestration/pids" + +if [[ -f "${PID_FILE}" ]]; then + source "${PID_FILE}" + for pid in "${frontend:-}" "${engine:-}" "${backend:-}"; do + if [[ -n "${pid}" ]] && kill -0 "${pid}" >/dev/null 2>&1; then + echo "Stopping pid ${pid}..." + kill "${pid}" || true + fi + done + rm -f "${PID_FILE}" +else + echo "No pid file found." +fi + +echo "Stopping PostgreSQL (docker)..." +docker compose stop postgres >/dev/null 2>&1 || true diff --git a/strategy_code/US_paper_trading_yfinance.py b/strategy_code/US_paper_trading_yfinance.py new file mode 100644 index 0000000..6d25688 --- /dev/null +++ b/strategy_code/US_paper_trading_yfinance.py @@ -0,0 +1,255 @@ +import streamlit as st +import pandas as pd +import numpy as np +import yfinance as yf +from datetime import datetime, timedelta, time as dtime +import pytz +import time + +# ========================= +# CONFIG +# ========================= +SP500 = "SPY" +GOLD = "GLD" + +SMA_MONTHS = 36 +BASE_EQUITY = 0.60 +TILT_MULT = 1.5 +MAX_TILT = 0.25 +MIN_EQUITY = 0.20 +MAX_EQUITY = 0.90 + +PRICE_REFRESH_SEC = 5 + +# ========================= +# DATA +# ========================= +@st.cache_data(ttl=3600) +def load_history(ticker): + df = yf.download( + ticker, + period="10y", + auto_adjust=True, + progress=False, + ) + + if df.empty: + raise RuntimeError(f"No historical data for {ticker}") + + if isinstance(df.columns, pd.MultiIndex): + df.columns = df.columns.get_level_values(0) + + if "Close" not in df.columns: + raise RuntimeError(f"'Close' not found for {ticker}") + + series = df["Close"].copy() + + if not isinstance(series.index, pd.DatetimeIndex): + raise RuntimeError(f"Index is not DatetimeIndex for {ticker}") + + series = series.resample("M").last() + + return series + + +@st.cache_data(ttl=15) +def live_price(ticker): + df = yf.download( + ticker, + period="1d", + interval="1m", + progress=False, + ) + + if df.empty: + raise RuntimeError(f"No live data for {ticker}") + + if isinstance(df.columns, pd.MultiIndex): + df.columns = df.columns.get_level_values(0) + + price = df["Close"].iloc[-1] + + return float(price) + + +def us_market_status(): + tz = pytz.timezone("America/New_York") + now = datetime.now(tz) + + market_open = dtime(9, 30) + market_close = dtime(16, 0) + + is_weekday = now.weekday() < 5 + is_open = is_weekday and market_open <= now.time() <= market_close + + return { + "is_open": is_open, + "now_et": now.strftime("%Y-%m-%d %H:%M:%S"), + "session": "OPEN" if is_open else "CLOSED", + } + + +sp_hist = load_history(SP500) +gd_hist = load_history(GOLD) + +hist_prices = pd.concat( + [sp_hist, gd_hist], + axis=1, + keys=["SP500", "GOLD"], +).dropna() + +sma_sp = hist_prices["SP500"].rolling(SMA_MONTHS).mean() +sma_gd = hist_prices["GOLD"].rolling(SMA_MONTHS).mean() + +# ========================= +# STATE INIT +# ========================= +if "running" not in st.session_state: + st.session_state.running = False + st.session_state.last_sip = None + st.session_state.total_invested = 0.0 + st.session_state.sp_units = 0.0 + st.session_state.gd_units = 0.0 + st.session_state.last_sp_price = None + st.session_state.last_gd_price = None +if "pnl_ledger" not in st.session_state: + st.session_state.pnl_ledger = pd.DataFrame( + columns=["timestamp", "pnl"] + ) + +# ========================= +# UI +# ========================= +st.title("SIPXAR - Live SIP Portfolio (US)") +market = us_market_status() + +if market["is_open"]: + st.success(f"US Market OPEN (ET {market['now_et']})") +else: + st.warning(f"US Market CLOSED (ET {market['now_et']})") + +sip_amount = st.number_input("SIP Amount ($)", min_value=10, value=1000, step=100) +sip_minutes = st.number_input( + "SIP Frequency (minutes) - TEST MODE", + min_value=1, + value=2, + step=1, +) +sip_interval = timedelta(minutes=sip_minutes) + +col1, col2 = st.columns(2) +if col1.button("START"): + st.session_state.running = True + if st.session_state.last_sip is None: + st.session_state.last_sip = datetime.utcnow() - sip_interval + +if col2.button("STOP"): + st.session_state.running = False + +# ========================= +# ENGINE LOOP +# ========================= +if st.session_state.running: + now = datetime.utcnow() + + if market["is_open"]: + sp_price = live_price(SP500) + gd_price = live_price(GOLD) + st.session_state.last_sp_price = sp_price + st.session_state.last_gd_price = gd_price + else: + sp_price = st.session_state.get("last_sp_price") + gd_price = st.session_state.get("last_gd_price") + + if sp_price is None or pd.isna(sp_price): + sp_price = hist_prices["SP500"].iloc[-1] + if gd_price is None or pd.isna(gd_price): + gd_price = hist_prices["GOLD"].iloc[-1] + + if st.session_state.last_sip is None: + st.session_state.last_sip = now - sip_interval + + # SIP trigger only when market is open + if ( + market["is_open"] + and (now - st.session_state.last_sip) >= sip_interval + ): + rd = ( + (hist_prices["SP500"].iloc[-1] / sma_sp.iloc[-1]) + - (hist_prices["GOLD"].iloc[-1] / sma_gd.iloc[-1]) + ) + + tilt = np.clip(-rd * TILT_MULT, -MAX_TILT, MAX_TILT) + eq_w = np.clip(BASE_EQUITY * (1 + tilt), MIN_EQUITY, MAX_EQUITY) + gd_w = 1 - eq_w + + sp_buy = sip_amount * eq_w + gd_buy = sip_amount * gd_w + + st.session_state.sp_units += sp_buy / sp_price + st.session_state.gd_units += gd_buy / gd_price + st.session_state.total_invested += sip_amount + st.session_state.last_sip = now + + # MTM + sp_val = st.session_state.sp_units * sp_price + gd_val = st.session_state.gd_units * gd_price + port_val = sp_val + gd_val + pnl = port_val - st.session_state.total_invested + st.session_state.pnl_ledger = pd.concat( + [ + st.session_state.pnl_ledger, + pd.DataFrame( + { + "timestamp": [datetime.utcnow()], + "pnl": [pnl], + } + ), + ], + ignore_index=True, + ) + + # ========================= + # DISPLAY + # ========================= + st.subheader("Portfolio Snapshot") + st.caption( + "Prices updating live" if market["is_open"] else "Prices frozen - market closed" + ) + + c1, c2, c3 = st.columns(3) + c1.metric("Total Invested", f"${st.session_state.total_invested:,.2f}") + c2.metric("Portfolio Value", f"${port_val:,.2f}") + c3.metric("Unrealized PnL", f"${pnl:,.2f}", delta=f"{pnl:,.2f}") + + next_sip_in = sip_interval - (now - st.session_state.last_sip) + next_sip_sec = max(0, int(next_sip_in.total_seconds())) + st.caption(f"Next SIP in ~ {next_sip_sec} seconds (TEST MODE)") + + st.subheader("Equity Curve (PnL)") + if len(st.session_state.pnl_ledger) > 1: + pnl_df = st.session_state.pnl_ledger.copy() + pnl_df["timestamp"] = pd.to_datetime(pnl_df["timestamp"]) + pnl_df.set_index("timestamp", inplace=True) + + st.line_chart( + pnl_df["pnl"], + height=350, + ) + else: + st.info("Equity curve will appear after portfolio updates.") + + st.subheader("Holdings") + st.dataframe( + pd.DataFrame( + { + "Asset": ["SP500 (SPY)", "Gold (GLD)"], + "Units": [st.session_state.sp_units, st.session_state.gd_units], + "Price": [sp_price, gd_price], + "Value": [sp_val, gd_val], + } + ) + ) + + time.sleep(PRICE_REFRESH_SEC) + st.rerun() diff --git a/strategy_code/paper_trading_yfinance.py b/strategy_code/paper_trading_yfinance.py new file mode 100644 index 0000000..e283496 --- /dev/null +++ b/strategy_code/paper_trading_yfinance.py @@ -0,0 +1,166 @@ +# ========================================================= +# SIPXAR — PAPER TRADING (SESSION STATE ONLY) +# ========================================================= + +import math +import numpy as np +import pandas as pd +import yfinance as yf +import streamlit as st +from datetime import datetime + +# ========================================================= +# CONFIG +# ========================================================= + +MONTHLY_SIP = 10000 + +NIFTY = "NIFTYBEES.NS" +GOLD = "GOLDBEES.NS" + +SMA_MONTHS = 36 + +BASE_EQUITY = 0.60 +TILT_MULT = 1.5 +MAX_TILT = 0.25 +MIN_EQUITY = 0.20 +MAX_EQUITY = 0.90 + +# ========================================================= +# DATA +# ========================================================= + +@st.cache_data(ttl=300) +def fetch_price(ticker): + df = yf.download( + ticker, + period="5d", + interval="1d", + auto_adjust=True, + progress=False + ) + return float(df["Close"].iloc[-1].item()) + +@st.cache_data(ttl=3600) +def fetch_sma(ticker): + df = yf.download( + ticker, + period=f"{SMA_MONTHS+2}mo", + interval="1mo", + auto_adjust=True, + progress=False + ).dropna() + return float(df["Close"].tail(SMA_MONTHS).mean()) + +# ========================================================= +# STRATEGY +# ========================================================= + +def compute_weights(n_price, g_price, n_sma, g_sma): + dev_n = (n_price / n_sma) - 1 + dev_g = (g_price / g_sma) - 1 + rel = dev_n - dev_g + + tilt = np.clip(-rel * TILT_MULT, -MAX_TILT, MAX_TILT) + eq_w = BASE_EQUITY * (1 + tilt) + eq_w = min(max(eq_w, MIN_EQUITY), MAX_EQUITY) + + return eq_w, 1 - eq_w + +# ========================================================= +# SESSION STATE INIT +# ========================================================= + +if "nifty_units" not in st.session_state: + st.session_state.nifty_units = 0 + st.session_state.gold_units = 0 + st.session_state.invested = 0 + st.session_state.ledger = [] + +# ========================================================= +# STREAMLIT UI +# ========================================================= + +st.set_page_config(page_title="SIPXAR Paper Trading", layout="centered") +st.title("📊 SIPXAR — Paper Trading (Session Only)") + +# Fetch market data +n_price = fetch_price(NIFTY) +g_price = fetch_price(GOLD) +n_sma = fetch_sma(NIFTY) +g_sma = fetch_sma(GOLD) + +eq_w, g_w = compute_weights(n_price, g_price, n_sma, g_sma) + +# Market snapshot +st.subheader("Market Snapshot") +st.write(f"**NIFTY:** ₹{n_price:.2f}") +st.write(f"**GOLD:** ₹{g_price:.2f}") + +# Allocation +st.subheader("Allocation Weights") +st.progress(eq_w) +st.write(f"Equity: **{eq_w:.2%}** | Gold: **{g_w:.2%}**") + +# ========================================================= +# SIP ACTION +# ========================================================= + +if st.button("Run Paper SIP (Once)"): + n_amt = MONTHLY_SIP * eq_w + g_amt = MONTHLY_SIP * g_w + + n_qty = math.floor(n_amt / n_price) + g_qty = math.floor(g_amt / g_price) + + if n_qty == 0 and g_qty == 0: + st.warning( + "SIP amount too small to buy ETF units at current prices." + ) + else: + st.session_state.nifty_units += n_qty + st.session_state.gold_units += g_qty + st.session_state.invested += MONTHLY_SIP + + port_value = ( + st.session_state.nifty_units * n_price + + st.session_state.gold_units * g_price + ) + + st.session_state.ledger.append({ + "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"), + "nifty_units": st.session_state.nifty_units, + "gold_units": st.session_state.gold_units, + "invested": st.session_state.invested, + "portfolio_value": port_value, + "pnl": port_value - st.session_state.invested + }) + + st.success("Paper SIP executed") + +# ========================================================= +# PORTFOLIO VIEW +# ========================================================= + +if st.session_state.invested > 0: + port_value = ( + st.session_state.nifty_units * n_price + + st.session_state.gold_units * g_price + ) + + pnl = port_value - st.session_state.invested + + st.subheader("Portfolio Summary") + st.metric("Total Invested", f"₹{st.session_state.invested:,.0f}") + st.metric("Portfolio Value", f"₹{port_value:,.0f}") + st.metric("PnL", f"₹{pnl:,.0f}") + + df = pd.DataFrame(st.session_state.ledger) + + st.subheader("Equity Curve") + st.line_chart(df[["portfolio_value", "invested"]]) + + st.subheader("Ledger") + st.dataframe(df, use_container_width=True) +else: + st.info("No SIPs yet. Click the button to start.") diff --git a/strategy_code/sma_momemtum_sip_model.py b/strategy_code/sma_momemtum_sip_model.py new file mode 100644 index 0000000..673b8ba --- /dev/null +++ b/strategy_code/sma_momemtum_sip_model.py @@ -0,0 +1,590 @@ +import pandas as pd +import numpy as np +import yfinance as yf + +# ========================================================= +# CONFIG +# ========================================================= + +START_DATE = "2010-01-01" +END_DATE = "2025-12-31" + +SIP_START_DATE = "2018-04-01" +SIP_END_DATE = "2025-10-31" # set None for "till last data" + +MONTHLY_SIP = 100 + +# Order frequency parameter (N trading days) +ORDER_EVERY_N = 5 # 5 => every 5 trading days, 30 => every 30 trading days + +# Whether to keep MONTHLY_SIP constant (recommended) +# True => scales cash per order so approx monthly investment stays ~MONTHLY_SIP +# False => invests MONTHLY_SIP every order (be careful: for N=5 this is much larger than monthly SIP) +KEEP_MONTHLY_BUDGET_CONSTANT = True + +NIFTY = "NIFTYBEES.NS" +GOLD = "GOLDBEES.NS" + +SIP_START_DATE = pd.to_datetime(SIP_START_DATE) +SIP_END_DATE = pd.to_datetime(SIP_END_DATE) if SIP_END_DATE else None + +# ========================================================= +# CENTRAL SIP DATE WINDOW +# ========================================================= + +def in_sip_window(date): + if date < SIP_START_DATE: + return False + if SIP_END_DATE and date > SIP_END_DATE: + return False + return True + +# ========================================================= +# VALUATION (SMA MEAN REVERSION) +# ========================================================= + +SMA_MONTHS = 36 +TILT_MULT = 1.5 +MAX_TILT = 0.25 +BASE_EQUITY = 0.60 + +MIN_EQUITY = 0.20 +MAX_EQUITY = 0.90 + +# For daily data +TRADING_DAYS_PER_MONTH = 21 +SMA_DAYS = SMA_MONTHS * TRADING_DAYS_PER_MONTH # ~36 months on daily series + +# ========================================================= +# DATA LOAD +# ========================================================= + +def load_price(ticker): + df = yf.download( + ticker, + start=START_DATE, + end=END_DATE, + auto_adjust=True, + progress=False + ) + if isinstance(df.columns, pd.MultiIndex): + df.columns = df.columns.get_level_values(0) + return df["Close"] + +prices = pd.DataFrame({ + "NIFTY": load_price(NIFTY), + "GOLD": load_price(GOLD) +}).dropna() + +# Use business-day frequency (daily for trading) +prices = prices.resample("B").last().dropna() + +# ========================================================= +# ORDER SCHEDULE: EVERY N TRADING DAYS +# ========================================================= + +def get_order_dates(index: pd.DatetimeIndex, n: int) -> pd.DatetimeIndex: + window = index[(index >= SIP_START_DATE) & ((index <= SIP_END_DATE) if SIP_END_DATE else True)] + return window[::n] + +ORDER_DATES = get_order_dates(prices.index, ORDER_EVERY_N) + +# Cash per order +if KEEP_MONTHLY_BUDGET_CONSTANT: + # Approx monthly orders = 21 / N + orders_per_month = TRADING_DAYS_PER_MONTH / ORDER_EVERY_N + SIP_AMOUNT_PER_ORDER = MONTHLY_SIP / orders_per_month +else: + SIP_AMOUNT_PER_ORDER = MONTHLY_SIP + +# ========================================================= +# (OPTIONAL) CHECK PRICE ON A GIVEN DATE (UNCHANGED) +# ========================================================= + +def check_price_on_date(ticker, date_str): + date = pd.to_datetime(date_str) + + df = yf.download( + ticker, + start=date - pd.Timedelta(days=5), + end=date + pd.Timedelta(days=5), + auto_adjust=False, + progress=False + ) + + if df.empty: + print(f"No data returned for {ticker}") + return + + if isinstance(df.columns, pd.MultiIndex): + df.columns = df.columns.get_level_values(0) + + if date not in df.index: + print(f"{ticker} | {date.date()} → No trading data (holiday / no volume)") + return + + row = df.loc[date] + + print(f"\n{ticker} — {date.date()}") + print(f"Open : ₹{row['Open']:.2f}") + print(f"High : ₹{row['High']:.2f}") + print(f"Low : ₹{row['Low']:.2f}") + print(f"Close: ₹{row['Close']:.2f}") + print(f"Volume: {int(row['Volume'])}") + +check_price_on_date("NIFTYBEES.NS", "2025-11-10") + +# ========================================================= +# SMA DEVIATION (DAILY) +# ========================================================= + +sma_nifty = prices["NIFTY"].rolling(SMA_DAYS).mean() +sma_gold = prices["GOLD"].rolling(SMA_DAYS).mean() + +dev_nifty = (prices["NIFTY"] / sma_nifty) - 1 +dev_gold = (prices["GOLD"] / sma_gold) - 1 + +rel_dev = dev_nifty - dev_gold + +# ========================================================= +# SIPXAR ENGINE (FLOW ONLY, NO REBALANCE) — EXECUTES ON ORDER_DATES +# ========================================================= + +def run_sipxar(prices, rel_dev, order_dates, sip_amount_per_order): + + nifty_units = 0.0 + gold_units = 0.0 + + total_invested = 0.0 + prev_value = None + + rows = [] + + for date in order_dates: + + nifty_price = prices.loc[date, "NIFTY"] + gold_price = prices.loc[date, "GOLD"] + + cash = float(sip_amount_per_order) + total_invested += cash + + rd = rel_dev.loc[date] + + tilt = 0.0 + if not pd.isna(rd): + tilt = np.clip(-rd * TILT_MULT, -MAX_TILT, MAX_TILT) + + equity_w = BASE_EQUITY * (1 + tilt) + equity_w = min(max(equity_w, MIN_EQUITY), MAX_EQUITY) + gold_w = 1 - equity_w + + nifty_buy = cash * equity_w + gold_buy = cash * gold_w + + nifty_units += nifty_buy / nifty_price + gold_units += gold_buy / gold_price + + nifty_val = nifty_units * nifty_price + gold_val = gold_units * gold_price + port_val = nifty_val + gold_val + + unrealized = port_val - total_invested + + if prev_value is None: + period_pnl = 0.0 + else: + period_pnl = port_val - prev_value - cash + + prev_value = port_val + + rows.append({ + "Date": date, + "Cash_Added": round(cash, 4), + "Total_Invested": round(total_invested, 4), + + "Equity_Weight": round(equity_w, 3), + "Gold_Weight": round(gold_w, 3), + + "Rel_Deviation": round(float(rd) if not pd.isna(rd) else np.nan, 4), + + "NIFTY_Units": round(nifty_units, 6), + "GOLD_Units": round(gold_units, 6), + + "NIFTY_Value": round(nifty_val, 4), + "GOLD_Value": round(gold_val, 4), + + "Portfolio_Value": round(port_val, 4), + "Period_PnL": round(period_pnl, 4), + "Unrealized_PnL": round(unrealized, 4) + }) + + return pd.DataFrame(rows).set_index("Date") + +# ========================================================= +# RUN SIPXAR +# ========================================================= + +sipxar_ledger = run_sipxar( + prices=prices, + rel_dev=rel_dev, + order_dates=ORDER_DATES, + sip_amount_per_order=SIP_AMOUNT_PER_ORDER +) + +start_dt = sipxar_ledger.index.min() +end_dt = sipxar_ledger.index.max() + +# ========================================================= +# XIRR +# ========================================================= + +def xirr(cashflows): + dates = np.array([cf[0] for cf in cashflows], dtype="datetime64[D]") + amounts = np.array([cf[1] for cf in cashflows], dtype=float) + + def npv(rate): + years = (dates - dates[0]).astype(int) / 365.25 + return np.sum(amounts / ((1 + rate) ** years)) + + low, high = -0.99, 5.0 + for _ in range(200): + mid = (low + high) / 2 + val = npv(mid) + if abs(val) < 1e-6: + return mid + if val > 0: + low = mid + else: + high = mid + + return mid + +cashflows_sipxar = [] +for date, row in sipxar_ledger.iterrows(): + cashflows_sipxar.append((date, -row["Cash_Added"])) + +final_date = sipxar_ledger.index[-1] +final_value = sipxar_ledger["Portfolio_Value"].iloc[-1] +cashflows_sipxar.append((final_date, final_value)) + +sipxar_xirr = xirr(cashflows_sipxar) + +# ========================================================= +# COMPARISONS: NIFTY-ONLY + STATIC 60/40 — EXECUTES ON ORDER_DATES +# ========================================================= + +def run_nifty_sip(prices, order_dates, sip_amount): + units = 0.0 + rows = [] + for date in order_dates: + price = prices.loc[date, "NIFTY"] + units += sip_amount / price + rows.append((date, units * price)) + return pd.DataFrame(rows, columns=["Date", "Value"]).set_index("Date") + +def run_static_sip(prices, order_dates, sip_amount, eq_w=0.6): + n_units = 0.0 + g_units = 0.0 + rows = [] + for date in order_dates: + n_price = prices.loc[date, "NIFTY"] + g_price = prices.loc[date, "GOLD"] + + n_units += (sip_amount * eq_w) / n_price + g_units += (sip_amount * (1 - eq_w)) / g_price + + rows.append((date, n_units * n_price + g_units * g_price)) + return pd.DataFrame(rows, columns=["Date", "Value"]).set_index("Date") + +def build_sip_ledger(value_df, sip_amount): + total = 0.0 + rows = [] + prev_value = None + + for date, row in value_df.iterrows(): + total += sip_amount + value = float(row.iloc[0]) + + if prev_value is None: + period_pnl = 0.0 + else: + period_pnl = value - prev_value - sip_amount + + prev_value = value + + rows.append({ + "Date": date, + "Cash_Added": round(sip_amount, 4), + "Total_Invested": round(total, 4), + "Portfolio_Value": round(value, 4), + "Period_PnL": round(period_pnl, 4) + }) + + return pd.DataFrame(rows).set_index("Date") + +nifty_sip = run_nifty_sip(prices, ORDER_DATES, SIP_AMOUNT_PER_ORDER) +static_sip = run_static_sip(prices, ORDER_DATES, SIP_AMOUNT_PER_ORDER, 0.6) + +nifty_sip = nifty_sip.loc[start_dt:end_dt] +static_sip = static_sip.loc[start_dt:end_dt] + +nifty_ledger = build_sip_ledger(nifty_sip, SIP_AMOUNT_PER_ORDER) +static_ledger = build_sip_ledger(static_sip, SIP_AMOUNT_PER_ORDER) + +cashflows_nifty = [(d, -SIP_AMOUNT_PER_ORDER) for d in nifty_sip.index] +cashflows_nifty.append((nifty_sip.index[-1], float(nifty_sip["Value"].iloc[-1]))) +nifty_xirr = xirr(cashflows_nifty) + +cashflows_static = [(d, -SIP_AMOUNT_PER_ORDER) for d in static_sip.index] +cashflows_static.append((static_sip.index[-1], float(static_sip["Value"].iloc[-1]))) +static_xirr = xirr(cashflows_static) + +# ========================================================= +# PRINTS +# ========================================================= + +print("\n=== CONFIG SUMMARY ===") +print(f"Order every N trading days: {ORDER_EVERY_N}") +print(f"Keep monthly budget constant: {KEEP_MONTHLY_BUDGET_CONSTANT}") +print(f"MONTHLY_SIP (target): ₹{MONTHLY_SIP}") +print(f"SIP_AMOUNT_PER_ORDER: ₹{SIP_AMOUNT_PER_ORDER:.4f}") +print(f"Orders executed: {len(ORDER_DATES)}") +print(f"Period: {start_dt.date()} → {end_dt.date()}") + +print("\n=== SIPXAR LEDGER (LAST 12 ROWS) ===") +print(sipxar_ledger.tail(12)) + +print("\n=== EQUITY WEIGHT DISTRIBUTION ===") +print(sipxar_ledger["Equity_Weight"].describe()) + +print("\n=== STEP 1: XIRR COMPARISON ===") +print(f"SIPXAR XIRR : {sipxar_xirr*100:.2f}%") +print(f"NIFTY SIP XIRR : {nifty_xirr*100:.2f}%") +print(f"60/40 SIP XIRR : {static_xirr*100:.2f}%") + +# ========================================================= +# EXPORT +# ========================================================= + +output_file = "SIPXAR_Momentum_SIP.xlsx" + +with pd.ExcelWriter(output_file, engine="xlsxwriter") as writer: + sipxar_ledger.to_excel(writer, sheet_name="Ledger") + + yearly = sipxar_ledger.copy() + yearly["Year"] = yearly.index.year + + yearly_summary = yearly.groupby("Year").agg({ + "Cash_Added": "sum", + "Total_Invested": "last", + "Portfolio_Value": "last", + "Unrealized_PnL": "last" + }) + yearly_summary.to_excel(writer, sheet_name="Yearly_Summary") + +print(f"\nExcel exported successfully: {output_file}") + +# ========================================================= +# PHASE 2: CRASH & SIDEWAYS REGIME BACKTEST (ADAPTED TO DAILY INDEX) +# ========================================================= + +def rolling_cagr(series: pd.Series, periods: int, years: float): + r = series / series.shift(periods) + return (r ** (1 / years)) - 1 + +def window_xirr_from_value(value_df, start, end, sip_amount): + df = value_df.loc[start:end] + if len(df) < 6: + return np.nan + cashflows = [(d, -sip_amount) for d in df.index] + cashflows.append((df.index[-1], float(df.iloc[-1, 0]))) + return xirr(cashflows) + +def sip_max_drawdown(ledger): + value = ledger["Portfolio_Value"] + peak = value.cummax() + dd = value / peak - 1 + trough = dd.idxmin() + peak_date = value.loc[:trough].idxmax() + return {"Peak": peak_date, "Trough": trough, "Max_Drawdown": float(dd.min())} + +def worst_rolling_xirr(ledger, periods: int): + dates = ledger.index + worst = None + + for i in range(len(dates) - periods): + start = dates[i] + end = dates[i + periods] + window = ledger.loc[start:end] + if len(window) < max(6, periods // 2): + continue + + cashflows = [(d, -float(row["Cash_Added"])) for d, row in window.iterrows()] + cashflows.append((end, float(window["Portfolio_Value"].iloc[-1]))) + + try: + rx = xirr(cashflows) + if worst is None or rx < worst.get("XIRR", np.inf): + worst = {"Start": start, "End": end, "XIRR": rx} + except Exception: + pass + + return worst + +# 1) Identify crash windows from NIFTY drawdowns (daily) +nifty_price = prices["NIFTY"].loc[ + (prices.index >= SIP_START_DATE) & + ((prices.index <= SIP_END_DATE) if SIP_END_DATE else True) +] + +peak = nifty_price.cummax() +drawdown = nifty_price / peak - 1.0 + +CRASH_THRESHOLD = -0.15 +in_crash = drawdown <= CRASH_THRESHOLD + +crash_windows = [] +groups = (in_crash != in_crash.shift()).cumsum() + +for _, block in in_crash.groupby(groups): + if block.iloc[0]: + crash_windows.append((block.index[0], block.index[-1])) + +print("\n=== CRASH WINDOWS (NIFTY DD <= -15%) ===") +for s, e in crash_windows: + print(s.date(), "->", e.date()) + +# 2) Sideways windows using ~36M rolling CAGR on daily data +ROLL_MONTHS = 36 +SIDEWAYS_CAGR = 0.06 + +ROLL_DAYS = ROLL_MONTHS * TRADING_DAYS_PER_MONTH +cagr_36 = rolling_cagr(nifty_price, periods=ROLL_DAYS, years=ROLL_MONTHS / 12.0) + +in_sideways = cagr_36 <= SIDEWAYS_CAGR + +sideways_windows = [] +groups = (in_sideways != in_sideways.shift()).cumsum() + +for _, block in in_sideways.groupby(groups): + if block.iloc[0]: + sideways_windows.append((block.index[0], block.index[-1])) + +print("\n=== SIDEWAYS WINDOWS (~36M CAGR <= 6%) ===") +for s, e in sideways_windows: + print(s.date(), "->", e.date()) + +# 3) Score each regime window using order-based ledgers +rows = [] +for label, windows in [("CRASH", crash_windows), ("SIDEWAYS", sideways_windows)]: + for s, e in windows: + # align to nearest available ledger dates + s2 = sipxar_ledger.index[sipxar_ledger.index.get_indexer([s], method="nearest")[0]] + e2 = sipxar_ledger.index[sipxar_ledger.index.get_indexer([e], method="nearest")[0]] + + months_like = (e2.year - s2.year) * 12 + (e2.month - s2.month) + 1 + + rows.append({ + "Regime": label, + "Start": s2.date(), + "End": e2.date(), + "MonthsLike": months_like, + "SIPXAR_XIRR": window_xirr_from_value(sipxar_ledger[["Portfolio_Value"]], s2, e2, SIP_AMOUNT_PER_ORDER), + "NIFTY_SIP_XIRR": window_xirr_from_value(nifty_sip, s2, e2, SIP_AMOUNT_PER_ORDER), + "STATIC_60_40_XIRR": window_xirr_from_value(static_sip, s2, e2, SIP_AMOUNT_PER_ORDER) + }) + +regime_results = pd.DataFrame(rows) + +print("\n=== REGIME PERFORMANCE SUMMARY ===") +if len(regime_results) == 0: + print("No regime windows detected (check thresholds / data range).") +else: + print(regime_results.to_string(index=False)) + +# ========================================================= +# METRIC 1: TIME UNDERWATER +# ========================================================= + +sipxar_ledger["Underwater"] = (sipxar_ledger["Portfolio_Value"] < sipxar_ledger["Total_Invested"]) +periods_underwater = int(sipxar_ledger["Underwater"].sum()) + +print("\n=== TIME UNDERWATER ===") +print(f"Periods underwater: {periods_underwater} / {len(sipxar_ledger)}") +print(f"% Time underwater : {periods_underwater / len(sipxar_ledger) * 100:.1f}%") + +# ========================================================= +# METRIC 2: SIP-AWARE MAX DRAWDOWN +# ========================================================= + +dd_sipxar = sip_max_drawdown(sipxar_ledger) +dd_nifty = sip_max_drawdown(nifty_ledger) +dd_static = sip_max_drawdown(static_ledger) + +print("\n=== SIP-AWARE MAX DRAWDOWN ===") +for name, dd in [("SIPXAR", dd_sipxar), ("NIFTY SIP", dd_nifty), ("60/40 SIP", dd_static)]: + print( + f"{name:10s} | " + f"Peak: {dd['Peak'].date()} | " + f"Trough: {dd['Trough'].date()} | " + f"DD: {dd['Max_Drawdown']*100:.2f}%" + ) + +# ========================================================= +# METRIC 3: WORST ROLLING 24M SIP XIRR (ORDER-BASED) +# ========================================================= + +# Convert 24 months to "order periods" approximately +ORDERS_PER_MONTH = TRADING_DAYS_PER_MONTH / ORDER_EVERY_N +ROLL_24M_PERIODS = int(round(24 * ORDERS_PER_MONTH)) +ROLL_24M_PERIODS = max(6, ROLL_24M_PERIODS) + +worst_24_sipxar = worst_rolling_xirr(sipxar_ledger, ROLL_24M_PERIODS) +worst_24_nifty = worst_rolling_xirr(nifty_ledger, ROLL_24M_PERIODS) +worst_24_static = worst_rolling_xirr(static_ledger, ROLL_24M_PERIODS) + +print("\n=== WORST ROLLING ~24M XIRR (by order periods) ===") +print(f"Using {ROLL_24M_PERIODS} periods (~24 months) given N={ORDER_EVERY_N}") +for name, w in [("SIPXAR", worst_24_sipxar), ("NIFTY SIP", worst_24_nifty), ("60/40 SIP", worst_24_static)]: + if not w or pd.isna(w.get("XIRR", np.nan)): + print(f"{name:10s} | insufficient data") + continue + print(f"{name:10s} | {w['Start'].date()} → {w['End'].date()} | {w['XIRR']*100:.2f}%") + +# ========================================================= +# METRIC 4: PnL VOLATILITY (PER-ORDER) +# ========================================================= + +period_pnl = sipxar_ledger["Period_PnL"] +pnl_std = float(period_pnl.std()) +pnl_mean = float(period_pnl.mean()) + +print("\n=== PnL VOLATILITY (PER ORDER PERIOD) ===") +print(f"Avg Period PnL : ₹{pnl_mean:,.2f}") +print(f"PnL Std Dev : ₹{pnl_std:,.2f}") +print(f"Volatility % : {pnl_std / SIP_AMOUNT_PER_ORDER * 100:.1f}% of per-order SIP") + +# ========================================================= +# SIP GRAPH: INVESTED vs PORTFOLIO VALUE (HEADLESS SAFE) +# ========================================================= + +import matplotlib +matplotlib.use("Agg") +import matplotlib.pyplot as plt + +sipxar_ledger["Invested_Capital"] = sipxar_ledger["Total_Invested"] + +plt.figure(figsize=(10, 5)) +plt.plot(sipxar_ledger.index, sipxar_ledger["Portfolio_Value"], label="Portfolio Value") +plt.plot(sipxar_ledger.index, sipxar_ledger["Invested_Capital"], label="Total Invested") + +plt.xlabel("Date") +plt.ylabel("Value (₹)") +plt.title(f"SIPXAR – SIP Performance (Every {ORDER_EVERY_N} Trading Days)") +plt.legend() +plt.tight_layout() + +plt.savefig("sipxar_performance.png", dpi=150) +plt.close() + +print("Plot saved: sipxar_performance.png") diff --git a/strategy_code/sma_momemtum_sip_model_optimization.py b/strategy_code/sma_momemtum_sip_model_optimization.py new file mode 100644 index 0000000..e8bd4cd --- /dev/null +++ b/strategy_code/sma_momemtum_sip_model_optimization.py @@ -0,0 +1,365 @@ +import pandas as pd +import numpy as np +import yfinance as yf + +# ========================================================= +# CONFIG +# ========================================================= +START_DATE = "2010-01-01" +END_DATE = "2025-12-31" + +SIP_START_DATE = "2018-04-01" +SIP_END_DATE = "2025-10-31" # set None for till last data + +# Monthly salary SIP budget +MONTHLY_SIP = 10_000 + +# Tickers +NIFTY = "NIFTYBEES.NS" +GOLD = "GOLDBEES.NS" + +# Optimization range: N trading days between deployments +N_MIN = 2 +N_MAX = 30 # inclusive + +# Salary credit timing: +# "MS" = Month Start, "ME" = Month End +SALARY_CREDIT = "MS" + +# Valuation tilt params (unchanged) +SMA_MONTHS = 36 +TILT_MULT = 1.5 +MAX_TILT = 0.25 +BASE_EQUITY = 0.60 +MIN_EQUITY = 0.20 +MAX_EQUITY_W = 0.90 + +# Daily approximation +TRADING_DAYS_PER_MONTH = 21 +SMA_DAYS = SMA_MONTHS * TRADING_DAYS_PER_MONTH + +# Crash/sideways definitions +CRASH_THRESHOLD = -0.15 +SIDEWAYS_CAGR = 0.06 +ROLL_MONTHS_SIDEWAYS = 36 + +# ========================================================= +# HELPERS +# ========================================================= +SIP_START_DATE = pd.to_datetime(SIP_START_DATE) +SIP_END_DATE = pd.to_datetime(SIP_END_DATE) if SIP_END_DATE else None + +def in_sip_window(date): + if date < SIP_START_DATE: + return False + if SIP_END_DATE and date > SIP_END_DATE: + return False + return True + +def load_price(ticker): + df = yf.download( + ticker, + start=START_DATE, + end=END_DATE, + auto_adjust=True, + progress=False + ) + if isinstance(df.columns, pd.MultiIndex): + df.columns = df.columns.get_level_values(0) + return df["Close"] + +def xirr(cashflows): + dates = np.array([cf[0] for cf in cashflows], dtype="datetime64[D]") + amounts = np.array([cf[1] for cf in cashflows], dtype=float) + + def npv(rate): + years = (dates - dates[0]).astype(int) / 365.25 + return np.sum(amounts / ((1 + rate) ** years)) + + low, high = -0.99, 5.0 + for _ in range(250): + mid = (low + high) / 2 + val = npv(mid) + if abs(val) < 1e-6: + return mid + if val > 0: + low = mid + else: + high = mid + return mid + +def sip_max_drawdown(value_series: pd.Series): + peak = value_series.cummax() + dd = value_series / peak - 1 + trough = dd.idxmin() + peak_date = value_series.loc[:trough].idxmax() + return peak_date, trough, float(dd.min()) + +def worst_rolling_xirr(ledger_df: pd.DataFrame, periods: int): + dates = ledger_df.index + worst = None + if len(dates) < periods + 2: + return None + + for i in range(len(dates) - periods): + start = dates[i] + end = dates[i + periods] + window = ledger_df.loc[start:end] + if len(window) < max(6, periods // 2): + continue + + cashflows = [(d, -float(row["Cash_Added"])) for d, row in window.iterrows()] + cashflows.append((end, float(window["Portfolio_Value"].iloc[-1]))) + + try: + rx = xirr(cashflows) + if worst is None or rx < worst["XIRR"]: + worst = {"Start": start, "End": end, "XIRR": rx} + except Exception: + pass + + return worst + +def rolling_cagr(series: pd.Series, periods: int, years: float): + r = series / series.shift(periods) + return (r ** (1 / years)) - 1 + +# ========================================================= +# DATA +# ========================================================= +prices = pd.DataFrame({ + "NIFTY": load_price(NIFTY), + "GOLD": load_price(GOLD) +}).dropna() + +# business day frequency +prices = prices.resample("B").last().dropna() + +# SMA / deviation on daily data +sma_nifty = prices["NIFTY"].rolling(SMA_DAYS).mean() +sma_gold = prices["GOLD"].rolling(SMA_DAYS).mean() + +dev_nifty = (prices["NIFTY"] / sma_nifty) - 1 +dev_gold = (prices["GOLD"] / sma_gold) - 1 +rel_dev = dev_nifty - dev_gold + +# Restrict to SIP window once for speed +sip_prices = prices.loc[ + (prices.index >= SIP_START_DATE) & + ((prices.index <= SIP_END_DATE) if SIP_END_DATE else True) +].copy() + +sip_rel_dev = rel_dev.loc[sip_prices.index] + +# ========================================================= +# SALARY CREDIT DATES (MONTH START/END) ALIGNED TO TRADING DAYS +# ========================================================= +def get_salary_dates(trading_index: pd.DatetimeIndex, credit: str): + if credit not in ("MS", "ME"): + raise ValueError("SALARY_CREDIT must be 'MS' or 'ME'.") + + # calendar month anchors + start = trading_index.min().to_period("M").to_timestamp() + end = trading_index.max().to_period("M").to_timestamp(how="end") + + # choose all months in range + months = pd.date_range(start=start, end=end, freq="MS") + if credit == "ME": + months = months + pd.offsets.MonthEnd(0) + + # map each month anchor to nearest trading day in index + out = [] + for d in months: + if d < trading_index.min() or d > trading_index.max(): + continue + loc = trading_index.get_indexer([d], method="nearest")[0] + td = trading_index[loc] + if in_sip_window(td): + out.append(td) + + return pd.DatetimeIndex(sorted(set(out))) + +SALARY_DATES = get_salary_dates(sip_prices.index, SALARY_CREDIT) + +# ========================================================= +# CORE ENGINE: MONTHLY SALARY -> DEPLOY IN EQUAL PARTS ON EVERY Nth TRADING DAY +# ========================================================= +def build_deployment_schedule(trading_index: pd.DatetimeIndex, salary_dates: pd.DatetimeIndex, n: int): + """ + For each month: credit MONTHLY_SIP on salary date, then deploy it in equal chunks + every n trading days until next salary date (exclusive). + """ + deploy_map = {} # date -> cash_to_invest + + for i, s_date in enumerate(salary_dates): + if i < len(salary_dates) - 1: + next_salary = salary_dates[i + 1] + month_window = trading_index[(trading_index >= s_date) & (trading_index < next_salary)] + else: + month_window = trading_index[trading_index >= s_date] + + if len(month_window) == 0: + continue + + # deployment dates: every nth trading day starting at salary date + deploy_dates = month_window[::n] + if len(deploy_dates) == 0: + continue + + per_order = MONTHLY_SIP / len(deploy_dates) + + for d in deploy_dates: + deploy_map[d] = deploy_map.get(d, 0.0) + per_order + + deploy_dates = pd.DatetimeIndex(sorted(deploy_map.keys())) + return deploy_dates, deploy_map + +def run_strategy_for_n(sip_prices: pd.DataFrame, sip_rel_dev: pd.Series, salary_dates: pd.DatetimeIndex, n: int): + deploy_dates, deploy_map = build_deployment_schedule(sip_prices.index, salary_dates, n) + + nifty_units = 0.0 + gold_units = 0.0 + total_invested = 0.0 + prev_value = None + + rows = [] + + for date in deploy_dates: + cash = float(deploy_map[date]) + total_invested += cash + + nifty_price = float(sip_prices.loc[date, "NIFTY"]) + gold_price = float(sip_prices.loc[date, "GOLD"]) + rd = float(sip_rel_dev.loc[date]) if not pd.isna(sip_rel_dev.loc[date]) else np.nan + + tilt = 0.0 + if not np.isnan(rd): + tilt = np.clip(-rd * TILT_MULT, -MAX_TILT, MAX_TILT) + + equity_w = BASE_EQUITY * (1 + tilt) + equity_w = min(max(equity_w, MIN_EQUITY), MAX_EQUITY_W) + gold_w = 1 - equity_w + + nifty_buy = cash * equity_w + gold_buy = cash * gold_w + + nifty_units += nifty_buy / nifty_price + gold_units += gold_buy / gold_price + + nifty_val = nifty_units * nifty_price + gold_val = gold_units * gold_price + port_val = nifty_val + gold_val + + unrealized = port_val - total_invested + if prev_value is None: + period_pnl = 0.0 + else: + period_pnl = port_val - prev_value - cash + prev_value = port_val + + rows.append({ + "Date": date, + "Cash_Added": cash, + "Total_Invested": total_invested, + "Equity_Weight": equity_w, + "Gold_Weight": gold_w, + "Rel_Deviation": rd, + "Portfolio_Value": port_val, + "Period_PnL": period_pnl, + "Unrealized_PnL": unrealized + }) + + ledger = pd.DataFrame(rows).set_index("Date") + if ledger.empty: + return None + + # XIRR + cashflows = [(d, -float(r["Cash_Added"])) for d, r in ledger.iterrows()] + cashflows.append((ledger.index[-1], float(ledger["Portfolio_Value"].iloc[-1]))) + strat_xirr = xirr(cashflows) + + # Drawdown + peak_date, trough_date, mdd = sip_max_drawdown(ledger["Portfolio_Value"]) + + # Worst rolling ~24M XIRR: convert 24 months -> periods using actual deploy frequency for this N + # Estimate periods/month using actual average deployments per month + months = ledger.index.to_period("M") + dep_per_month = months.value_counts().mean() + roll_periods = int(round(24 * dep_per_month)) + roll_periods = max(6, roll_periods) + + worst = worst_rolling_xirr(ledger, roll_periods) + worst_xirr = np.nan if worst is None else float(worst["XIRR"]) + + # Time underwater + underwater = (ledger["Portfolio_Value"] < ledger["Total_Invested"]).mean() + + return { + "N_days": n, + "Deployments": int(len(ledger)), + "XIRR": float(strat_xirr), + "Max_Drawdown": float(mdd), + "DD_Peak": peak_date.date(), + "DD_Trough": trough_date.date(), + "Worst_Rolling_24M_XIRR": worst_xirr, + "Underwater_%": float(underwater * 100), + "Avg_Deployments_Per_Month": float(dep_per_month), + "Rolling_24M_Periods_Used": int(roll_periods) + }, ledger + +# ========================================================= +# OPTIMIZATION RUN +# ========================================================= +results = [] +best_ledgers = {} # optionally store some ledgers + +for n in range(N_MIN, N_MAX + 1): + out = run_strategy_for_n(sip_prices, sip_rel_dev, SALARY_DATES, n) + if out is None: + continue + metrics, ledger = out + results.append(metrics) + + # Store a few ledgers if you want to inspect later + if n in (2, 5, 10, 15, 20, 30): + best_ledgers[n] = ledger + +res_df = pd.DataFrame(results) +if res_df.empty: + raise RuntimeError("No results computed. Check date ranges and tickers.") + +# Sort by worst rolling 24M XIRR first (most important for SIP survivability), +# then by Max Drawdown (less negative is better), then by overall XIRR. +res_df_sorted = res_df.sort_values( + by=["Worst_Rolling_24M_XIRR", "Max_Drawdown", "XIRR"], + ascending=[False, False, False] +) + +pd.set_option("display.width", 200) +pd.set_option("display.max_columns", 50) + +print("\n=== OPTIMIZATION SUMMARY ===") +print(f"Salary credit: {SALARY_CREDIT} | Monthly SIP: ₹{MONTHLY_SIP:,}") +print(f"N range: {N_MIN} → {N_MAX} (trading days)") +print("\nTop 15 configs (ranked by Worst Rolling ~24M XIRR, then DD, then XIRR):") +print(res_df_sorted.head(15).to_string(index=False)) + +print("\nWorst 10 configs:") +print(res_df_sorted.tail(10).to_string(index=False)) + +# Export results +out_file = "SIPXAR_N_Optimization.xlsx" +with pd.ExcelWriter(out_file, engine="xlsxwriter") as writer: + res_df_sorted.to_excel(writer, sheet_name="Optimization_Ranked", index=False) + res_df.sort_values("N_days").to_excel(writer, sheet_name="Optimization_By_N", index=False) + +print(f"\nExcel exported: {out_file}") + +# ========================================================= +# OPTIONAL: SAVE LEDGERS FOR SELECTED N VALUES +# ========================================================= +# Uncomment if you want ledgers saved as separate sheets +# with pd.ExcelWriter("SIPXAR_Selected_Ledgers.xlsx", engine="xlsxwriter") as writer: +# for n, led in best_ledgers.items(): +# led.to_excel(writer, sheet_name=f"N_{n}") +# print("Selected ledgers exported: SIPXAR_Selected_Ledgers.xlsx") diff --git a/strategy_code/zerodha_live_monthly.py b/strategy_code/zerodha_live_monthly.py new file mode 100644 index 0000000..be1714a --- /dev/null +++ b/strategy_code/zerodha_live_monthly.py @@ -0,0 +1,290 @@ +# ========================================================= +# SIPXAR LIVE — SINGLE FILE +# ========================================================= +import os +import yfinance as yf +import numpy as np +import math +import sqlite3 +import logging +import uuid +import datetime as dt +from datetime import datetime, date +from kiteconnect import KiteConnect + +# ========================================================= +# CONFIG +# ========================================================= + +MONTHLY_SIP = 100 + +ALLOW_AFTER_MARKET_TEST = True # ONLY for testing + +NIFTY_YF = "NIFTYBEES.NS" +GOLD_YF = "GOLDBEES.NS" + +NIFTY_KITE = "NIFTYBEES" +GOLD_KITE = "GOLDBEES" + +SMA_MONTHS = 36 + +BASE_EQUITY = 0.60 +TILT_MULT = 1.5 +MAX_TILT = 0.25 +MIN_EQUITY = 0.20 +MAX_EQUITY = 0.90 + +STRATEGY_VERSION = "SIPXAR_v1.0" + +KITE_API_KEY = "YOUR_API_KEY" +KITE_ACCESS_TOKEN = "YOUR_ACCESS_TOKEN" + +KILL_SWITCH_FILE = "KILL_SWITCH" +DB_FILE = "sipxar_live.db" +LOG_FILE = f"sipxar_{datetime.now().strftime('%Y_%m')}.log" + +# ========================================================= +# LOGGER +# ========================================================= + +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s | %(levelname)s | %(message)s", + handlers=[ + logging.FileHandler(LOG_FILE), + logging.StreamHandler() + ] +) +log = logging.getLogger("SIPXAR") + +# ========================================================= +# LEDGER (SQLite) +# ========================================================= + +def db(): + return sqlite3.connect(DB_FILE) + +def init_db(): + with db() as conn: + conn.execute(""" + CREATE TABLE IF NOT EXISTS runs ( + run_date TEXT PRIMARY KEY, + run_id TEXT + ) + """) + conn.execute(""" + CREATE TABLE IF NOT EXISTS trades ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + run_id TEXT, + run_date TEXT, + symbol TEXT, + strategy_version TEXT, + equity_weight REAL, + allocated_amount REAL, + price_used REAL, + quantity INTEGER, + order_id TEXT, + order_status TEXT, + timestamp TEXT, + notes TEXT + ) + """) + +def already_ran_today(): + today = date.today().isoformat() + with db() as conn: + cur = conn.execute("SELECT 1 FROM runs WHERE run_date=?", (today,)) + return cur.fetchone() is not None + +def record_run(run_id): + with db() as conn: + conn.execute( + "INSERT INTO runs (run_date, run_id) VALUES (?, ?)", + (date.today().isoformat(), run_id) + ) + +def record_trade(**r): + with db() as conn: + conn.execute(""" + INSERT INTO trades ( + run_id, run_date, symbol, strategy_version, + equity_weight, allocated_amount, price_used, + quantity, order_id, order_status, timestamp, notes + ) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + """, ( + r["run_id"], r["run_date"], r["symbol"], r["strategy_version"], + r["equity_weight"], r["allocated_amount"], r["price_used"], + r["quantity"], r["order_id"], r["order_status"], + datetime.now().isoformat(), r.get("notes") + )) + +# ========================================================= +# DATA (YFINANCE) +# ========================================================= + +def fetch_sma(ticker): + df = yf.download( + ticker, + period=f"{SMA_MONTHS + 2}mo", + interval="1mo", + auto_adjust=True, + progress=False + ).dropna() + return float(df["Close"].tail(SMA_MONTHS).mean()) + +def fetch_price(ticker): + df = yf.download( + ticker, + period="5d", + interval="1d", + auto_adjust=True, + progress=False + ) + return float(df["Close"].iloc[-1].item()) + +# ========================================================= +# STRATEGY +# ========================================================= + +def kill_switch_active(): + return os.path.exists(KILL_SWITCH_FILE) + +def market_open(): + now = dt.datetime.now().time() + return dt.time(9, 15) <= now <= dt.time(15, 30) + +def sanity_check(n_qty, g_qty): + if n_qty <= 0 and g_qty <= 0: + raise RuntimeError("Sanity check failed: both quantities zero") + +def compute_weights(n_price, g_price, n_sma, g_sma): + assert isinstance(n_price, float) + assert isinstance(g_price, float) + assert isinstance(n_sma, float) + assert isinstance(g_sma, float) + + dev_n = (n_price / n_sma) - 1 + dev_g = (g_price / g_sma) - 1 + rel = dev_n - dev_g + + tilt = np.clip(-rel * TILT_MULT, -MAX_TILT, MAX_TILT) + + eq_w = BASE_EQUITY * (1 + tilt) + eq_w = min(max(eq_w, MIN_EQUITY), MAX_EQUITY) + + return eq_w, 1 - eq_w + +# ========================================================= +# ZERODHA EXECUTION +# ========================================================= + +kite = KiteConnect(api_key=KITE_API_KEY) +kite.set_access_token(KITE_ACCESS_TOKEN) + +def place_buy(symbol, qty): + if qty <= 0: + return None + return kite.place_order( + variety=kite.VARIETY_REGULAR, + exchange=kite.EXCHANGE_NSE, + tradingsymbol=symbol, + transaction_type=kite.TRANSACTION_TYPE_BUY, + quantity=qty, + order_type=kite.ORDER_TYPE_MARKET, + product=kite.PRODUCT_CNC + ) + +# ========================================================= +# MAIN RUN +# ========================================================= + +def main(): + init_db() + + if kill_switch_active(): + log.critical("KILL SWITCH ACTIVE — ABORTING EXECUTION") + return + + if not market_open() and not ALLOW_AFTER_MARKET_TEST: + log.error("Market closed — aborting") + return + + if not market_open() and ALLOW_AFTER_MARKET_TEST: + log.warning("Market closed — TEST MODE override enabled") + + if already_ran_today(): + log.error("ABORT: Strategy already executed today") + return + + run_id = str(uuid.uuid4()) + record_run(run_id) + + log.info(f"Starting SIPXAR LIVE | run_id={run_id}") + + n_price = fetch_price(NIFTY_YF) + g_price = fetch_price(GOLD_YF) + + n_sma = fetch_sma(NIFTY_YF) + g_sma = fetch_sma(GOLD_YF) + + eq_w, g_w = compute_weights(n_price, g_price, n_sma, g_sma) + + n_amt = MONTHLY_SIP * eq_w + g_amt = MONTHLY_SIP * g_w + + n_qty = math.floor(n_amt / n_price) + g_qty = math.floor(g_amt / g_price) + sanity_check(n_qty, g_qty) + log.info(f"Weights → Equity={eq_w:.2f}, Gold={g_w:.2f}") + log.info(f"Qty → NIFTY={n_qty}, GOLD={g_qty}") + + for sym, qty, price, amt in [ + (NIFTY_KITE, n_qty, n_price, n_amt), + (GOLD_KITE, g_qty, g_price, g_amt) + ]: + if qty <= 0: + log.warning(f"Skipping {sym}, qty=0") + continue + + try: + oid = place_buy(sym, qty) + log.info(f"Order placed {sym} | order_id={oid}") + + record_trade( + run_id=run_id, + run_date=date.today().isoformat(), + symbol=sym, + strategy_version=STRATEGY_VERSION, + equity_weight=eq_w, + allocated_amount=amt, + price_used=price, + quantity=qty, + order_id=oid, + order_status="PLACED", + notes="LIVE order" + ) + + except Exception as e: + log.error(f"Order failed {sym}: {e}") + + record_trade( + run_id=run_id, + run_date=date.today().isoformat(), + symbol=sym, + strategy_version=STRATEGY_VERSION, + equity_weight=eq_w, + allocated_amount=amt, + price_used=price, + quantity=qty, + order_id=None, + order_status="FAILED", + notes=str(e) + ) + + log.info("SIPXAR LIVE run completed") + +# ========================================================= + +if __name__ == "__main__": + main()