From 519addd78f7cce6d3a9d21e3a6eea7b95579b9e5 Mon Sep 17 00:00:00 2001 From: Thigazhezhilan J Date: Wed, 8 Apr 2026 22:02:24 +0530 Subject: [PATCH] Harden backend auth, execution safety, and market session logic --- backend/app/db_models.py | 116 +++++++- backend/app/main.py | 166 +++++++---- backend/app/routers/auth.py | 6 +- backend/app/routers/broker.py | 63 +++- backend/app/routers/strategy.py | 52 ++-- backend/app/routers/support_ticket.py | 35 ++- backend/app/services/auth_service.py | 63 +++- backend/app/services/broker_callback_state.py | 111 ++++++++ backend/app/services/db.py | 21 +- backend/app/services/run_service.py | 2 +- backend/app/services/strategy_service.py | 167 +++++++---- backend/app/services/support_abuse.py | 224 +++++++++++++++ backend/app/services/support_ticket.py | 7 +- backend/app/services/system_service.py | 17 +- backend/app/services/tenant.py | 4 - backend/app/services/zerodha_service.py | 4 +- backend/requirements.txt | 1 + backend/tests/conftest.py | 14 + backend/tests/test_api_semantics_and_utc.py | 111 ++++++++ backend/tests/test_auth_isolation_and_cors.py | 128 +++++++++ backend/tests/test_execution_claims.py | 120 ++++++++ backend/tests/test_market_calendar.py | 102 +++++++ backend/tests/test_runner_leases.py | 205 +++++++++++++ backend/tests/test_security_hardening.py | 269 ++++++++++++++++++ backend/tests/test_support_throttling.py | 131 +++++++++ indian_paper_trading_strategy/engine/db.py | 200 +++++++++++-- .../engine/execution.py | 4 +- .../engine/ledger.py | 79 ++++- .../engine/market.py | 79 ++--- .../engine/market_calendar.py | 201 +++++++++++++ .../engine/runner.py | 256 ++++++++++++++--- indian_paper_trading_strategy/engine/state.py | 65 ++--- .../engine/time_utils.py | 90 ++++-- 33 files changed, 2753 insertions(+), 360 deletions(-) create mode 100644 backend/app/services/broker_callback_state.py create mode 100644 backend/app/services/support_abuse.py create mode 100644 backend/tests/conftest.py create mode 100644 backend/tests/test_api_semantics_and_utc.py create mode 100644 backend/tests/test_auth_isolation_and_cors.py create mode 100644 backend/tests/test_execution_claims.py create mode 100644 backend/tests/test_market_calendar.py create mode 100644 backend/tests/test_runner_leases.py create mode 100644 backend/tests/test_security_hardening.py create mode 100644 backend/tests/test_support_throttling.py create mode 100644 indian_paper_trading_strategy/engine/market_calendar.py diff --git a/backend/app/db_models.py b/backend/app/db_models.py index 7b905c3..f40cc25 100644 --- a/backend/app/db_models.py +++ b/backend/app/db_models.py @@ -47,6 +47,8 @@ class AppSession(Base): created_at = Column(DateTime(timezone=True), nullable=False) last_seen_at = Column(DateTime(timezone=True)) expires_at = Column(DateTime(timezone=True), nullable=False) + ip = Column(Text) + user_agent = Column(Text) __table_args__ = ( Index("idx_app_session_user_id", "user_id"), @@ -63,6 +65,8 @@ class UserBroker(Base): access_token = Column(Text) connected_at = Column(DateTime(timezone=True)) api_key = Column(Text) + api_secret = Column(Text) + auth_state = Column(Text) user_name = Column(Text) broker_user_id = Column(Text) pending_broker = Column(Text) @@ -117,7 +121,10 @@ class StrategyRun(Base): __table_args__ = ( UniqueConstraint("user_id", "run_id", name="uq_strategy_run_user_run"), - CheckConstraint("status IN ('RUNNING','STOPPED','ERROR')", name="chk_strategy_run_status"), + CheckConstraint( + "status IN ('RUNNING','STOPPED','ERROR','PAUSED_AUTH_EXPIRED')", + name="chk_strategy_run_status", + ), Index("idx_strategy_run_user_status", "user_id", "status"), Index("idx_strategy_run_user_created", "user_id", "created_at"), Index( @@ -129,6 +136,22 @@ class StrategyRun(Base): ) +class PasswordResetOtp(Base): + __tablename__ = "password_reset_otp" + + id = Column(String, primary_key=True) + email = Column(Text, nullable=False) + otp_hash = Column(Text, nullable=False) + created_at = Column(DateTime(timezone=True), nullable=False, server_default=func.now()) + expires_at = Column(DateTime(timezone=True), nullable=False) + used_at = Column(DateTime(timezone=True)) + + __table_args__ = ( + Index("idx_password_reset_otp_email", "email"), + Index("idx_password_reset_otp_expires_at", "expires_at"), + ) + + class StrategyConfig(Base): __tablename__ = "strategy_config" @@ -420,6 +443,97 @@ class LiveEquitySnapshot(Base): ) +class SupportTicket(Base): + __tablename__ = "support_ticket" + + id = Column(String, primary_key=True) + name = Column(Text, nullable=False) + email = Column(Text, nullable=False) + subject = Column(Text, nullable=False) + message = Column(Text, nullable=False) + status = Column(Text, nullable=False, server_default=text("'NEW'")) + created_at = Column(DateTime(timezone=True), nullable=False, server_default=func.now()) + updated_at = Column(DateTime(timezone=True), nullable=False, server_default=func.now()) + + __table_args__ = ( + Index("idx_support_ticket_email", "email"), + Index("idx_support_ticket_created_at", "created_at"), + ) + + +class SupportRequestAudit(Base): + __tablename__ = "support_request_audit" + + id = Column(BigInteger, primary_key=True, autoincrement=True) + endpoint = Column(Text, nullable=False) + ip_hash = Column(Text) + email_hash = Column(Text) + ticket_hash = Column(Text) + blocked = Column(Boolean, nullable=False, server_default=text("false")) + reason = Column(Text) + created_at = Column(DateTime(timezone=True), nullable=False, server_default=func.now()) + + __table_args__ = ( + Index("idx_support_request_audit_endpoint_ip_created", "endpoint", "ip_hash", "created_at"), + Index("idx_support_request_audit_ticket_created", "ticket_hash", "created_at"), + ) + + +class BrokerCallbackState(Base): + __tablename__ = "broker_callback_state" + + id = Column(String, primary_key=True) + state_hash = Column(Text, nullable=False, unique=True) + user_id = Column(String, ForeignKey("app_user.id", ondelete="CASCADE"), nullable=False) + session_id = Column(String, ForeignKey("app_session.id", ondelete="CASCADE"), nullable=False) + broker = Column(Text, nullable=False) + flow = Column(Text, nullable=False) + created_at = Column(DateTime(timezone=True), nullable=False) + expires_at = Column(DateTime(timezone=True), nullable=False) + consumed_at = Column(DateTime(timezone=True)) + + __table_args__ = ( + Index( + "idx_broker_callback_state_lookup", + "user_id", + "session_id", + "broker", + "flow", + "expires_at", + ), + ) + + +class ExecutionClaim(Base): + __tablename__ = "execution_claim" + + id = Column(String, primary_key=True) + user_id = Column(String, ForeignKey("app_user.id", ondelete="CASCADE"), nullable=False) + run_id = Column(String, ForeignKey("strategy_run.run_id", ondelete="CASCADE"), nullable=False) + mode = Column(Text, nullable=False) + logical_time = Column(DateTime(timezone=True), nullable=False) + claimed_at = Column(DateTime(timezone=True), nullable=False, server_default=func.now()) + + __table_args__ = ( + UniqueConstraint("user_id", "run_id", "logical_time", name="uq_execution_claim_scope"), + Index("idx_execution_claim_run_claimed", "run_id", "claimed_at"), + ) + + +class RunLease(Base): + __tablename__ = "run_leases" + + run_id = Column(String, ForeignKey("strategy_run.run_id", ondelete="CASCADE"), primary_key=True) + owner_id = Column(Text, nullable=False) + leased_at = Column(DateTime(timezone=True), nullable=False) + expires_at = Column(DateTime(timezone=True), nullable=False) + heartbeat_at = Column(DateTime(timezone=True)) + + __table_args__ = ( + Index("idx_run_leases_owner_expires", "owner_id", "expires_at"), + ) + + class MTMLedger(Base): __tablename__ = "mtm_ledger" diff --git a/backend/app/main.py b/backend/app/main.py index 72fc3b0..c9c4bec 100644 --- a/backend/app/main.py +++ b/backend/app/main.py @@ -1,82 +1,126 @@ import os +from urllib.parse import urlparse from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware + +from app.admin_role_service import bootstrap_super_admin +from app.admin_router import router as admin_router from app.routers.auth import router as auth_router from app.routers.broker import router as broker_router from app.routers.health import router as health_router +from app.routers.paper import router as paper_router from app.routers.password_reset import router as password_reset_router +from app.routers.strategy import router as strategy_router from app.routers.support_ticket import router as support_ticket_router from app.routers.system import router as system_router -from app.routers.strategy import router as strategy_router from app.routers.zerodha import router as zerodha_router, public_router as zerodha_public_router -from app.routers.paper import router as paper_router -from market import router as market_router -from paper_mtm import router as paper_mtm_router +from app.services.db import _db_config as _validate_db_config from app.services.live_equity_service import start_live_equity_snapshot_daemon from app.services.strategy_service import init_log_state, resume_running_runs -from app.admin_router import router as admin_router -from app.admin_role_service import bootstrap_super_admin +from market import router as market_router +from paper_mtm import router as paper_mtm_router -app = FastAPI( - title="QuantFortune Backend", - version="1.0" -) +DEFAULT_PRODUCTION_ORIGINS = {"https://app.quantfortune.com"} +DEFAULT_DEV_ORIGINS = { + "http://localhost:3000", + "http://127.0.0.1:3000", + "http://localhost:5173", + "http://127.0.0.1:5173", +} +PRODUCTION_ENV_NAMES = {"prod", "production"} -cors_origins = [ - origin.strip() - for origin in os.getenv("CORS_ORIGINS", "").split(",") - if origin.strip() -] -if not cors_origins: - cors_origins = [ - "http://localhost:3000", - "http://127.0.0.1:3000", - ] -cors_origin_regex = os.getenv("CORS_ORIGIN_REGEX", "").strip() -if not cors_origin_regex: - cors_origin_regex = ( - r"https://.*\\.ngrok-free\\.dev" - r"|https://.*\\.ngrok-free\\.app" - r"|https://.*\\.ngrok\\.io" +def _environment_name() -> str: + return ( + os.getenv("APP_ENV") + or os.getenv("ENVIRONMENT") + or os.getenv("FASTAPI_ENV") + or "development" + ).strip().lower() + + +def _normalize_origin(origin: str) -> str: + return origin.strip().rstrip("/") + + +def _is_dev_origin(origin: str) -> bool: + parsed = urlparse(origin) + return parsed.scheme == "http" and parsed.hostname in {"localhost", "127.0.0.1"} + + +def _validate_cors_origin(origin: str) -> str: + normalized = _normalize_origin(origin) + if not normalized: + raise RuntimeError("Empty CORS origin is not allowed") + if normalized in DEFAULT_PRODUCTION_ORIGINS or _is_dev_origin(normalized): + return normalized + raise RuntimeError( + f"Unsupported CORS origin '{normalized}'. Only app.quantfortune.com and localhost dev origins are allowed." ) -# app.add_middleware( -# CORSMiddleware, -# allow_origins=cors_origins, -# allow_origin_regex=cors_origin_regex or None, -# allow_credentials=True, -# allow_methods=["*"], -# allow_headers=["*"], -# ) -app.add_middleware( - CORSMiddleware, - allow_origins=[], # must be empty when using regex - allow_origin_regex=".*", # allow ANY origin - allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], -) +def _build_cors_origins() -> list[str]: + configured = [ + _normalize_origin(origin) + for origin in os.getenv("CORS_ORIGINS", "").split(",") + if origin.strip() + ] + env_name = _environment_name() -app.include_router(strategy_router) -app.include_router(auth_router) -app.include_router(broker_router) -app.include_router(zerodha_router) -app.include_router(zerodha_public_router) -app.include_router(paper_router) -app.include_router(market_router) -app.include_router(paper_mtm_router) -app.include_router(health_router) -app.include_router(system_router) -app.include_router(admin_router) -app.include_router(support_ticket_router) -app.include_router(password_reset_router) + if env_name in PRODUCTION_ENV_NAMES: + if not configured: + raise RuntimeError("CORS_ORIGINS must be configured explicitly in production") + origins = configured + else: + origins = configured or sorted(DEFAULT_DEV_ORIGINS) -@app.on_event("startup") -def init_app_state(): - init_log_state() - bootstrap_super_admin() - resume_running_runs() - start_live_equity_snapshot_daemon() + deduped: list[str] = [] + seen: set[str] = set() + for origin in origins: + validated = _validate_cors_origin(origin) + if validated not in seen: + seen.add(validated) + deduped.append(validated) + return deduped + + +def create_app() -> FastAPI: + _validate_db_config() + app = FastAPI(title="QuantFortune Backend", version="1.0") + cors_origins = _build_cors_origins() + app.add_middleware( + CORSMiddleware, + allow_origins=cors_origins, + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], + ) + + app.include_router(strategy_router) + app.include_router(auth_router) + app.include_router(broker_router) + app.include_router(zerodha_router) + app.include_router(zerodha_public_router) + app.include_router(paper_router) + app.include_router(market_router) + app.include_router(paper_mtm_router) + app.include_router(health_router) + app.include_router(system_router) + app.include_router(admin_router) + app.include_router(support_ticket_router) + app.include_router(password_reset_router) + + @app.on_event("startup") + def init_app_state(): + if os.getenv("DISABLE_STARTUP_TASKS", "0") == "1": + return + init_log_state() + bootstrap_super_admin() + resume_running_runs() + start_live_equity_snapshot_daemon() + + return app + + +app = create_app() diff --git a/backend/app/routers/auth.py b/backend/app/routers/auth.py index c794c97..84733e6 100644 --- a/backend/app/routers/auth.py +++ b/backend/app/routers/auth.py @@ -15,8 +15,12 @@ from app.services.email_service import send_email_async router = APIRouter(prefix="/api") SESSION_COOKIE_NAME = "session_id" -COOKIE_SECURE = os.getenv("COOKIE_SECURE", "0") == "1" +APP_ENV = (os.getenv("APP_ENV") or os.getenv("ENVIRONMENT") or os.getenv("FASTAPI_ENV") or "development").strip().lower() +IS_PRODUCTION = APP_ENV in {"prod", "production"} +COOKIE_SECURE = True if IS_PRODUCTION else os.getenv("COOKIE_SECURE", "0") == "1" COOKIE_SAMESITE = (os.getenv("COOKIE_SAMESITE") or "lax").lower() +if IS_PRODUCTION and not COOKIE_SECURE: + raise RuntimeError("Secure session cookies are mandatory in production") def _set_session_cookie(response: Response, session_id: str): diff --git a/backend/app/routers/broker.py b/backend/app/routers/broker.py index 4adc7f5..9b63fdb 100644 --- a/backend/app/routers/broker.py +++ b/backend/app/routers/broker.py @@ -15,6 +15,10 @@ from app.broker_store import ( set_pending_broker, ) from app.services.auth_service import get_user_for_session +from app.services.broker_callback_state import ( + consume_broker_callback_state, + create_broker_callback_state, +) from app.services.email_service import send_email_async from app.services.groww_service import ( GrowwApiError, @@ -60,6 +64,13 @@ def _require_user(request: Request): return user +def _require_session_id(request: Request) -> str: + session_id = request.cookies.get("session_id") + if not session_id: + raise HTTPException(status_code=401, detail="Not authenticated") + return session_id + + def _first_number(*values, default: float = 0.0) -> float: for value in values: try: @@ -317,6 +328,7 @@ def _normalize_groww_funds(data: dict | None) -> dict: def _build_saved_broker_login_url( request: Request, user_id: str, + session_id: str, redirect_url_override: str | None = None, ) -> str: entry = get_user_broker(user_id) or {} @@ -332,7 +344,13 @@ def _build_saved_broker_login_url( if not redirect_url: base = str(request.base_url).rstrip("/") redirect_url = f"{base}/api/broker/callback" - return build_login_url(creds["api_key"], redirect_url=redirect_url) + state = create_broker_callback_state( + user_id=user_id, + session_id=session_id, + broker="ZERODHA", + flow="reconnect", + ) + return build_login_url(creds["api_key"], redirect_url=redirect_url, state=state) def _notify_broker_connected(username: str, broker: str, broker_user_id: str | None): @@ -401,6 +419,7 @@ async def disconnect_broker(request: Request): @router.post("/zerodha/login") async def zerodha_login(payload: dict, request: Request): user = _require_user(request) + session_id = _require_session_id(request) api_key = (payload.get("apiKey") or "").strip() api_secret = (payload.get("apiSecret") or "").strip() redirect_url = (payload.get("redirectUrl") or "").strip() @@ -408,7 +427,13 @@ async def zerodha_login(payload: dict, request: Request): raise HTTPException(status_code=400, detail="API key and secret are required") set_pending_broker(user["id"], "ZERODHA", api_key, api_secret) - return {"loginUrl": build_login_url(api_key, redirect_url=redirect_url or None)} + state = create_broker_callback_state( + user_id=user["id"], + session_id=session_id, + broker="ZERODHA", + flow="connect", + ) + return {"loginUrl": build_login_url(api_key, redirect_url=redirect_url or None, state=state)} @router.post("/groww/connect") @@ -490,11 +515,23 @@ async def groww_reconnect(request: Request): @router.get("/zerodha/callback") -async def zerodha_callback(request: Request, request_token: str = ""): +async def zerodha_callback(request: Request, request_token: str = "", state: str = ""): user = _require_user(request) + session_id = _require_session_id(request) token = request_token.strip() + callback_state = state.strip() if not token: raise HTTPException(status_code=400, detail="Missing request_token") + if not callback_state: + raise HTTPException(status_code=400, detail="Missing state") + if not consume_broker_callback_state( + state=callback_state, + user_id=user["id"], + session_id=session_id, + broker="ZERODHA", + flow="connect", + ): + raise HTTPException(status_code=401, detail="Invalid or expired broker callback state") pending = get_pending_broker(user["id"]) or {} api_key = (pending.get("api_key") or "").strip() @@ -541,32 +578,46 @@ async def zerodha_callback(request: Request, request_token: str = ""): @router.get("/login") async def broker_login(request: Request): user = _require_user(request) + session_id = _require_session_id(request) redirect_url = ( (request.query_params.get("redirectUrl") or request.query_params.get("redirect_url") or "") .strip() or None ) - login_url = _build_saved_broker_login_url(request, user["id"], redirect_url) + login_url = _build_saved_broker_login_url(request, user["id"], session_id, redirect_url) return RedirectResponse(login_url) @router.get("/login-url") async def broker_login_url(request: Request): user = _require_user(request) + session_id = _require_session_id(request) redirect_url = ( (request.query_params.get("redirectUrl") or request.query_params.get("redirect_url") or "") .strip() or None ) - return {"loginUrl": _build_saved_broker_login_url(request, user["id"], redirect_url)} + return {"loginUrl": _build_saved_broker_login_url(request, user["id"], session_id, redirect_url)} @router.get("/callback") -async def broker_callback(request: Request, request_token: str = ""): +async def broker_callback(request: Request, request_token: str = "", state: str = ""): user = _require_user(request) + session_id = _require_session_id(request) token = request_token.strip() + callback_state = state.strip() if not token: raise HTTPException(status_code=400, detail="Missing request_token") + if not callback_state: + raise HTTPException(status_code=400, detail="Missing state") + if not consume_broker_callback_state( + state=callback_state, + user_id=user["id"], + session_id=session_id, + broker="ZERODHA", + flow="reconnect", + ): + raise HTTPException(status_code=401, detail="Invalid or expired broker callback state") creds = get_broker_credentials(user["id"]) if not creds: raise HTTPException(status_code=400, detail="Broker credentials not configured") diff --git a/backend/app/routers/strategy.py b/backend/app/routers/strategy.py index 3e41544..b3d1e5d 100644 --- a/backend/app/routers/strategy.py +++ b/backend/app/routers/strategy.py @@ -1,5 +1,4 @@ -from fastapi import APIRouter, HTTPException, Query, Request -from fastapi.responses import JSONResponse +from fastapi import APIRouter, HTTPException, Query, Request, status as http_status from app.models import StrategyStartRequest from app.services.strategy_service import ( start_strategy, @@ -15,6 +14,20 @@ from app.services.tenant import get_request_user_id router = APIRouter(prefix="/api") + +def _raise_strategy_error(payload: dict, *, default_status: int) -> None: + message = payload.get("message") or payload.get("detail") or "Strategy operation failed" + raise HTTPException( + status_code=default_status, + detail={ + "status": payload.get("status", "error"), + "message": message, + "run_id": payload.get("run_id"), + "redirect_url": payload.get("redirect_url"), + "broker": payload.get("broker"), + }, + ) + @router.post("/strategy/start") def start(req: StrategyStartRequest, request: Request): user_id = get_request_user_id(request) @@ -24,35 +37,38 @@ def start(req: StrategyStartRequest, request: Request): def stop(request: Request): try: user_id = get_request_user_id(request) - return stop_strategy(user_id) + result = stop_strategy(user_id) + if result.get("status") not in {"stopped", "already_stopped"}: + _raise_strategy_error(result, default_status=http_status.HTTP_409_CONFLICT) + return result except HTTPException: raise except Exception as exc: print(f"[STRATEGY] unhandled stop route failure: {exc}", flush=True) - return JSONResponse( - status_code=200, - content={ - "status": "stop_failed", - "message": f"Unable to stop strategy: {exc}", - }, - ) + raise HTTPException( + status_code=http_status.HTTP_500_INTERNAL_SERVER_ERROR, + detail={"status": "stop_failed", "message": f"Unable to stop strategy: {exc}"}, + ) from exc @router.post("/strategy/resume") def resume(request: Request): try: user_id = get_request_user_id(request) - return resume_strategy(user_id) + result = resume_strategy(user_id) + success_statuses = {"resumed", "already_running"} + if result.get("status") == "broker_auth_required": + _raise_strategy_error(result, default_status=http_status.HTTP_401_UNAUTHORIZED) + if result.get("status") not in success_statuses: + _raise_strategy_error(result, default_status=http_status.HTTP_409_CONFLICT) + return result except HTTPException: raise except Exception as exc: print(f"[STRATEGY] unhandled resume route failure: {exc}", flush=True) - return JSONResponse( - status_code=200, - content={ - "status": "resume_failed", - "message": f"Unable to resume strategy: {exc}", - }, - ) + raise HTTPException( + status_code=http_status.HTTP_500_INTERNAL_SERVER_ERROR, + detail={"status": "resume_failed", "message": f"Unable to resume strategy: {exc}"}, + ) from exc @router.get("/strategy/status") def status(request: Request): diff --git a/backend/app/routers/support_ticket.py b/backend/app/routers/support_ticket.py index 06db908..db31325 100644 --- a/backend/app/routers/support_ticket.py +++ b/backend/app/routers/support_ticket.py @@ -1,6 +1,7 @@ -from fastapi import APIRouter, HTTPException +from fastapi import APIRouter, Header, HTTPException, Request from pydantic import BaseModel +from app.services.support_abuse import SupportGuardRejected, enforce_support_guard from app.services.support_ticket import create_ticket, get_ticket_status @@ -19,7 +20,20 @@ class TicketStatusRequest(BaseModel): @router.post("/ticket") -def submit_ticket(payload: TicketCreate): +def submit_ticket( + payload: TicketCreate, + request: Request, + support_captcha: str | None = Header(default=None, alias="X-Support-Captcha"), +): + try: + enforce_support_guard( + request=request, + endpoint="ticket_create", + email=payload.email.strip(), + captcha_token=support_captcha, + ) + except SupportGuardRejected as exc: + raise HTTPException(status_code=exc.status_code, detail=exc.detail) from exc if not payload.subject.strip() or not payload.message.strip(): raise HTTPException(status_code=400, detail="Subject and message are required") ticket = create_ticket( @@ -32,7 +46,22 @@ def submit_ticket(payload: TicketCreate): @router.post("/ticket/status/{ticket_id}") -def ticket_status(ticket_id: str, payload: TicketStatusRequest): +def ticket_status( + ticket_id: str, + payload: TicketStatusRequest, + request: Request, + support_captcha: str | None = Header(default=None, alias="X-Support-Captcha"), +): + try: + enforce_support_guard( + request=request, + endpoint="ticket_status", + email=payload.email.strip(), + ticket_id=ticket_id.strip(), + captcha_token=support_captcha, + ) + except SupportGuardRejected as exc: + raise HTTPException(status_code=exc.status_code, detail=exc.detail) from exc status = get_ticket_status(ticket_id.strip(), payload.email.strip()) if not status: raise HTTPException(status_code=404, detail="Ticket not found") diff --git a/backend/app/services/auth_service.py b/backend/app/services/auth_service.py index ca6a6b5..813cbac 100644 --- a/backend/app/services/auth_service.py +++ b/backend/app/services/auth_service.py @@ -1,9 +1,13 @@ import hashlib import os +import re import secrets from datetime import datetime, timedelta, timezone from uuid import uuid4 +from argon2 import PasswordHasher +from argon2.exceptions import InvalidHash, VerifyMismatchError + from app.services.db import db_connection SESSION_TTL_SECONDS = int(os.getenv("SESSION_TTL_SECONDS", str(60 * 60 * 24 * 7))) @@ -11,7 +15,11 @@ SESSION_REFRESH_WINDOW_SECONDS = int( os.getenv("SESSION_REFRESH_WINDOW_SECONDS", str(60 * 60)) ) RESET_OTP_TTL_MINUTES = int(os.getenv("RESET_OTP_TTL_MINUTES", "10")) -RESET_OTP_SECRET = os.getenv("RESET_OTP_SECRET", "otp_secret") +PASSWORD_HASHER = PasswordHasher() +LEGACY_SHA256_RE = re.compile(r"^[0-9a-f]{64}$") +RESET_OTP_SECRET = (os.getenv("RESET_OTP_SECRET") or "").strip() +if not RESET_OTP_SECRET: + raise RuntimeError("RESET_OTP_SECRET must be configured") def _now_utc() -> datetime: @@ -23,9 +31,17 @@ def _new_expiry(now: datetime) -> datetime: def _hash_password(password: str) -> str: + return PASSWORD_HASHER.hash(password) + + +def _hash_password_legacy(password: str) -> str: return hashlib.sha256(password.encode("utf-8")).hexdigest() +def _is_legacy_password_hash(password_hash: str | None) -> bool: + return bool(password_hash and LEGACY_SHA256_RE.fullmatch(password_hash)) + + def _hash_otp(email: str, otp: str) -> str: payload = f"{email}:{otp}:{RESET_OTP_SECRET}" return hashlib.sha256(payload.encode("utf-8")).hexdigest() @@ -80,12 +96,47 @@ def create_user(username: str, password: str): return _row_to_user(cur.fetchone()) +def _update_password_hash(user_id: str, password_hash: str): + with db_connection() as conn: + with conn: + with conn.cursor() as cur: + cur.execute( + "UPDATE app_user SET password_hash = %s WHERE id = %s", + (password_hash, user_id), + ) + + +def _verify_password(user_id: str, stored_hash: str | None, password: str) -> tuple[bool, str | None]: + if not stored_hash: + return False, None + if _is_legacy_password_hash(stored_hash): + if secrets.compare_digest(stored_hash, _hash_password_legacy(password)): + return True, _hash_password(password) + return False, None + + try: + verified = PASSWORD_HASHER.verify(stored_hash, password) + except (VerifyMismatchError, InvalidHash): + return False, None + + if not verified: + return False, None + + if PASSWORD_HASHER.check_needs_rehash(stored_hash): + return True, _hash_password(password) + return True, None + + def authenticate_user(username: str, password: str): user = get_user_by_username(username) if not user: return None - if user.get("password") != _hash_password(password): + verified, replacement_hash = _verify_password(user["id"], user.get("password"), password) + if not verified: return None + if replacement_hash: + _update_password_hash(user["id"], replacement_hash) + user["password"] = replacement_hash return user @@ -130,13 +181,7 @@ def get_last_session_meta(user_id: str): def update_user_password(user_id: str, new_password: str): password_hash = _hash_password(new_password) - with db_connection() as conn: - with conn: - with conn.cursor() as cur: - cur.execute( - "UPDATE app_user SET password_hash = %s WHERE id = %s", - (password_hash, user_id), - ) + _update_password_hash(user_id, password_hash) def create_password_reset_otp(email: str): diff --git a/backend/app/services/broker_callback_state.py b/backend/app/services/broker_callback_state.py new file mode 100644 index 0000000..9631c20 --- /dev/null +++ b/backend/app/services/broker_callback_state.py @@ -0,0 +1,111 @@ +import hashlib +import secrets +from datetime import datetime, timedelta, timezone +from uuid import uuid4 + +from app.services.db import db_connection + +CALLBACK_STATE_TTL_SECONDS = 15 * 60 + + +def _now_utc() -> datetime: + return datetime.now(timezone.utc) + + +def _state_hash(state: str) -> str: + return hashlib.sha256(state.encode("utf-8")).hexdigest() + + +def create_broker_callback_state( + *, + user_id: str, + session_id: str, + broker: str, + flow: str, + ttl_seconds: int = CALLBACK_STATE_TTL_SECONDS, +) -> str: + state = secrets.token_urlsafe(32) + now = _now_utc() + expires_at = now + timedelta(seconds=ttl_seconds) + state_hash = _state_hash(state) + with db_connection() as conn: + with conn: + with conn.cursor() as cur: + cur.execute( + """ + DELETE FROM broker_callback_state + WHERE expires_at <= %s OR consumed_at IS NOT NULL + """, + (now,), + ) + cur.execute( + """ + INSERT INTO broker_callback_state ( + id, + state_hash, + user_id, + session_id, + broker, + flow, + created_at, + expires_at, + consumed_at + ) + VALUES (%s, %s, %s, %s, %s, %s, %s, %s, NULL) + """, + ( + str(uuid4()), + state_hash, + user_id, + session_id, + broker.strip().upper(), + flow.strip().lower(), + now, + expires_at, + ), + ) + return state + + +def consume_broker_callback_state( + *, + state: str, + user_id: str, + session_id: str, + broker: str, + flow: str, +): + if not state: + return None + now = _now_utc() + state_hash = _state_hash(state) + with db_connection() as conn: + with conn: + with conn.cursor() as cur: + cur.execute( + """ + UPDATE broker_callback_state + SET consumed_at = %s + WHERE state_hash = %s + AND user_id = %s + AND session_id = %s + AND broker = %s + AND flow = %s + AND consumed_at IS NULL + AND expires_at > %s + RETURNING id, expires_at + """, + ( + now, + state_hash, + user_id, + session_id, + broker.strip().upper(), + flow.strip().lower(), + now, + ), + ) + row = cur.fetchone() + if not row: + return None + return {"id": row[0], "expires_at": row[1].isoformat() if row[1] else None} diff --git a/backend/app/services/db.py b/backend/app/services/db.py index 97796a3..b793cc7 100644 --- a/backend/app/services/db.py +++ b/backend/app/services/db.py @@ -16,6 +16,7 @@ Base = declarative_base() _ENGINE: Engine | None = None _ENGINE_LOCK = threading.Lock() +NON_PROD_ENVIRONMENTS = {"development", "dev", "test", "testing", "local"} class _ConnectionProxy: @@ -44,16 +45,28 @@ class _ConnectionProxy: def _db_config() -> dict[str, str | int]: + env_name = (os.getenv("APP_ENV") or os.getenv("ENVIRONMENT") or os.getenv("FASTAPI_ENV") or "development").strip().lower() + is_non_prod = env_name in NON_PROD_ENVIRONMENTS url = os.getenv("DATABASE_URL") if url: return {"url": url} + password = os.getenv("DB_PASSWORD") or os.getenv("PGPASSWORD") + if not password and not is_non_prod: + raise RuntimeError("DB_PASSWORD or PGPASSWORD must be configured in non-development environments") + + host = os.getenv("DB_HOST") or os.getenv("PGHOST") or ("localhost" if is_non_prod else None) + dbname = os.getenv("DB_NAME") or os.getenv("PGDATABASE") or ("trading_db" if is_non_prod else None) + user = os.getenv("DB_USER") or os.getenv("PGUSER") or ("trader" if is_non_prod else None) + if not is_non_prod and (not host or not dbname or not user): + raise RuntimeError("DB_HOST, DB_NAME, and DB_USER must be configured in non-development environments") + return { - "host": os.getenv("DB_HOST") or os.getenv("PGHOST") or "localhost", + "host": host, "port": int(os.getenv("DB_PORT") or os.getenv("PGPORT") or "5432"), - "dbname": os.getenv("DB_NAME") or os.getenv("PGDATABASE") or "trading_db", - "user": os.getenv("DB_USER") or os.getenv("PGUSER") or "trader", - "password": os.getenv("DB_PASSWORD") or os.getenv("PGPASSWORD") or "traderpass", + "dbname": dbname, + "user": user, + "password": password, "connect_timeout": int(os.getenv("DB_CONNECT_TIMEOUT", "5")), "schema": os.getenv("DB_SCHEMA") or os.getenv("PGSCHEMA") or "quant_app", } diff --git a/backend/app/services/run_service.py b/backend/app/services/run_service.py index bf57f54..d507dd5 100644 --- a/backend/app/services/run_service.py +++ b/backend/app/services/run_service.py @@ -172,6 +172,6 @@ def update_run_status(user_id: str, run_id: str, status: str, meta: dict | None """, (status, now, Json(meta or {}), run_id, user_id), ) - return True + return cur.rowcount > 0 return run_with_retry(_op) diff --git a/backend/app/services/strategy_service.py b/backend/app/services/strategy_service.py index 5f3cd22..a5ae7f8 100644 --- a/backend/app/services/strategy_service.py +++ b/backend/app/services/strategy_service.py @@ -4,17 +4,28 @@ import sys import threading from datetime import datetime, timedelta, timezone from pathlib import Path -from zoneinfo import ZoneInfo ENGINE_ROOT = Path(__file__).resolve().parents[3] if str(ENGINE_ROOT) not in sys.path: sys.path.append(str(ENGINE_ROOT)) -from indian_paper_trading_strategy.engine.market import is_market_open, align_to_market_open, market_now -from indian_paper_trading_strategy.engine.runner import start_engine, stop_engine +from indian_paper_trading_strategy.engine.market import ( + align_to_market_open, + market_now, + market_session, + next_market_open_after, +) +from indian_paper_trading_strategy.engine.market_calendar import UnsupportedCalendarYearError +from indian_paper_trading_strategy.engine.runner import RunLeaseNotAcquiredError, start_engine, stop_engine from indian_paper_trading_strategy.engine.state import init_paper_state, load_state, save_state from indian_paper_trading_strategy.engine.broker import PaperBroker -from indian_paper_trading_strategy.engine.time_utils import frequency_to_timedelta +from indian_paper_trading_strategy.engine.time_utils import ( + UTC, + frequency_to_timedelta, + parse_market_timestamp, + parse_persisted_timestamp, + serialize_timestamp, +) from indian_paper_trading_strategy.engine.db import engine_context from app.broker_store import get_user_broker, set_broker_auth_state @@ -41,7 +52,6 @@ SEQ_LOCK = threading.Lock() SEQ = 0 LAST_WAIT_LOG_TS = {} WAIT_LOG_INTERVAL = timedelta(seconds=60) -IST = ZoneInfo("Asia/Kolkata") def init_log_state(): global SEQ @@ -110,7 +120,7 @@ def emit_event( evt = { "seq": seq, - "ts": now.isoformat().replace("+00:00", "Z"), + "ts": serialize_timestamp(now), "level": level, "category": category, "event": event, @@ -157,14 +167,8 @@ def _maybe_parse_json(value): return value -def _local_tz(): - return IST - - -def _format_local_ts(value: datetime | None): - if value is None: - return None - return value.astimezone(_local_tz()).replace(tzinfo=None).isoformat() +def _utc_now(): + return datetime.now(UTC) def _load_config(user_id: str, run_id: str): @@ -192,7 +196,7 @@ def _load_config(user_id: str, run_id: str): "frequency": _maybe_parse_json(row[7]), "frequency_days": row[8], "unit": row[9], - "next_run": _format_local_ts(row[10]), + "next_run": serialize_timestamp(row[10]), } if row[2] is not None or row[3] is not None: cfg["sip_frequency"] = { @@ -217,13 +221,7 @@ def _save_config(cfg, user_id: str, run_id: str): next_run = cfg.get("next_run") next_run_dt = None if isinstance(next_run, str): - try: - parsed = datetime.fromisoformat(next_run) - if parsed.tzinfo is None: - parsed = parsed.replace(tzinfo=_local_tz()) - next_run_dt = parsed - except ValueError: - next_run_dt = None + next_run_dt = parse_persisted_timestamp(next_run) with db_connection() as conn: with conn: @@ -294,7 +292,7 @@ def reactivate_strategy_config(user_id: str, run_id: str): return cfg def _write_status(user_id: str, run_id: str, status): - now_local = market_now() + now_local = _utc_now() with db_connection() as conn: with conn: with conn.cursor() as cur: @@ -346,6 +344,12 @@ def _effective_running_run_id(user_id: str): ) return None + +def _set_run_status_or_raise(user_id: str, run_id: str, status: str, meta: dict | None = None): + updated = update_run_status(user_id, run_id, status, meta=meta) + if not updated: + raise RuntimeError(f"Run {run_id} for user {user_id} no longer exists") + def validate_frequency(freq: dict, mode: str): if not isinstance(freq, dict): raise ValueError("Frequency payload is required") @@ -436,9 +440,8 @@ def _validate_live_broker_session(user_id: str): def compute_next_eligible(last_run: str | None, sip_frequency: dict | None): if not last_run or not sip_frequency: return None - try: - last_dt = datetime.fromisoformat(last_run) - except ValueError: + last_dt = parse_market_timestamp(last_run) + if last_dt is None: return None try: delta = frequency_to_timedelta(sip_frequency) @@ -446,7 +449,7 @@ def compute_next_eligible(last_run: str | None, sip_frequency: dict | None): return None next_dt = last_dt + delta next_dt = align_to_market_open(next_dt) - return next_dt.isoformat() + return serialize_timestamp(next_dt) def _last_execution_ts(state: dict, mode: str) -> str | None: @@ -473,7 +476,10 @@ def start_strategy(req, user_id: str): return {"status": "already_running", "run_id": running_run_id} engine_config = _build_engine_config(user_id, running_run_id, req) if engine_config: - started = start_engine(engine_config) + try: + started = start_engine(engine_config) + except RunLeaseNotAcquiredError: + return {"status": "already_running", "run_id": running_run_id} if started: _write_status(user_id, running_run_id, "RUNNING") return {"status": "restarted", "run_id": running_run_id} @@ -573,7 +579,10 @@ def start_strategy(req, user_id: str): engine_config["run_id"] = run_id engine_config["user_id"] = user_id engine_config["emit_event"] = emit_event_cb - start_engine(engine_config) + try: + start_engine(engine_config) + except RunLeaseNotAcquiredError: + pass try: user = get_user_by_id(user_id) @@ -655,14 +664,17 @@ def resume_running_runs(): engine_config = _build_engine_config(user_id, run_id, None) if not engine_config: continue - started = start_engine(engine_config) + try: + started = start_engine(engine_config) + except RunLeaseNotAcquiredError: + started = False if started: _write_status(user_id, run_id, "RUNNING") def stop_strategy(user_id: str): run_id = _effective_running_run_id(user_id) if not run_id: - latest_run_id = get_active_run_id(user_id) + latest_run_id = get_running_run_id(user_id) or get_active_run_id(user_id) return {"status": "already_stopped", "run_id": latest_run_id} engine_external = os.getenv("ENGINE_EXTERNAL", "").strip().lower() in {"1", "true", "yes"} @@ -681,7 +693,14 @@ def stop_strategy(user_id: str): print(f"[STRATEGY] engine status update failed during stop for {user_id}/{run_id}: {exc}", flush=True) if not stop_warning: stop_warning = str(exc) - update_run_status(user_id, run_id, "STOPPED", meta={"reason": "user_request"}) + try: + _set_run_status_or_raise(user_id, run_id, "STOPPED", meta={"reason": "user_request"}) + except RuntimeError as exc: + return { + "status": "stop_failed", + "run_id": run_id, + "message": str(exc), + } try: user = get_user_by_id(user_id) @@ -704,6 +723,8 @@ def resume_strategy(user_id: str): return {"status": "already_running", "run_id": running_run_id} run_id = get_active_run_id(user_id) + if not run_id: + return {"status": "no_resumable_run"} cfg = _load_config(user_id, run_id) strategy_name = (cfg.get("strategy") or "").strip() mode = (cfg.get("mode") or "").strip().upper() @@ -737,16 +758,26 @@ def resume_strategy(user_id: str): } reactivate_strategy_config(user_id, run_id) - update_run_status(user_id, run_id, "RUNNING", meta={"reason": "user_resume"}) + try: + _set_run_status_or_raise(user_id, run_id, "RUNNING", meta={"reason": "user_resume"}) + except RuntimeError as exc: + deactivate_strategy_config(user_id, run_id) + return { + "status": "resume_failed", + "run_id": run_id, + "message": str(exc), + } _write_status(user_id, run_id, "RUNNING") if not engine_external: try: started = start_engine(engine_config) + except RunLeaseNotAcquiredError: + return {"status": "already_running", "run_id": run_id} except Exception as exc: deactivate_strategy_config(user_id, run_id) _write_status(user_id, run_id, "STOPPED") - update_run_status(user_id, run_id, "STOPPED", meta={"reason": "resume_start_failed"}) + _set_run_status_or_raise(user_id, run_id, "STOPPED", meta={"reason": "resume_start_failed"}) return { "status": "resume_failed", "run_id": run_id, @@ -755,7 +786,7 @@ def resume_strategy(user_id: str): if not started: deactivate_strategy_config(user_id, run_id) _write_status(user_id, run_id, "STOPPED") - update_run_status(user_id, run_id, "STOPPED", meta={"reason": "resume_start_failed"}) + _set_run_status_or_raise(user_id, run_id, "STOPPED", meta={"reason": "resume_start_failed"}) return { "status": "resume_failed", "run_id": run_id, @@ -797,7 +828,7 @@ def get_strategy_status(user_id: str): else: status = { "status": default_status, - "last_updated": _format_local_ts(engine_row[1]), + "last_updated": serialize_timestamp(engine_row[1]), } status["run_id"] = run_id engine_state = str((engine_row or [None])[0] or "").strip().upper() @@ -832,17 +863,9 @@ def get_strategy_status(user_id: str): status["last_execution_ts"] = last_execution_ts status["next_eligible_ts"] = next_eligible if next_eligible: - try: - parsed_next = datetime.fromisoformat(next_eligible) - now_cmp = ( - datetime.now(parsed_next.tzinfo) - if parsed_next.tzinfo - else market_now().replace(tzinfo=None) - ) - if parsed_next > now_cmp: - status["status"] = "WAITING" - except ValueError: - pass + parsed_next = parse_persisted_timestamp(next_eligible) + if parsed_next and parsed_next > _utc_now(): + status["status"] = "WAITING" status_key = (status.get("status") or "IDLE").upper() resumable = bool(cfg.get("strategy")) and bool(cfg.get("mode")) status["can_resume"] = resumable and status_key in {"STOPPED", "PAUSED_AUTH_EXPIRED"} @@ -876,11 +899,7 @@ def get_engine_status(user_id: str): status["state"] = row[0] last_updated = row[1] if last_updated is not None: - status["last_heartbeat_ts"] = ( - last_updated.astimezone(timezone.utc) - .isoformat() - .replace("+00:00", "Z") - ) + status["last_heartbeat_ts"] = serialize_timestamp(last_updated) cfg = _load_config(user_id, run_id) mode = (cfg.get("mode") or "LIVE").strip().upper() with engine_context(user_id, run_id): @@ -926,10 +945,7 @@ def get_strategy_logs(user_id: str, since_seq: int): events = [] for row in rows: ts = row[1] - if ts is not None: - ts_str = ts.astimezone(timezone.utc).isoformat().replace("+00:00", "Z") - else: - ts_str = None + ts_str = serialize_timestamp(ts) events.append( { "seq": row[0], @@ -980,6 +996,16 @@ def _issue_message(event: str, message: str | None, data: dict | None, meta: dic if event == "ENGINE_ERROR": return message or "Strategy engine hit an error." if event == "EXECUTION_BLOCKED": + if reason_key == "market_holiday": + return "Exchange holiday. Execution will resume next session." + if reason_key == "market_weekend": + return "Weekend closure. Execution will resume next session." + if reason_key == "market_pre_open": + return "Market has not opened yet. Execution will begin after 9:15 AM IST." + if reason_key == "market_post_close": + return "Market is closed for the day. Execution will resume next session." + if reason_key == "market_calendar_unavailable": + return "Market calendar unavailable. Execution paused for safety." if reason_key == "market_closed": return "Market is closed. Execution will resume next session." return f"Execution blocked: {_humanize_reason(reason) or 'Unknown reason'}." @@ -1019,8 +1045,17 @@ def _issue_is_stale_for_current_state( }: return True - if event == "EXECUTION_BLOCKED" and reason_key == "market_closed": - return is_market_open(market_now()) + if event == "EXECUTION_BLOCKED" and reason_key.startswith("market_"): + current_session = market_session(market_now()) + current_reason = str(current_session.get("reason") or "").strip().lower() + current_status = str(current_session.get("status") or "").strip().upper() + if reason_key == "market_holiday": + return current_status != "HOLIDAY" + if reason_key == "market_calendar_unavailable": + return current_reason != "calendar_unavailable" + if reason_key in {"market_weekend", "market_pre_open", "market_post_close", "market_closed"}: + return current_status == "OPEN" + return False if mode != "LIVE": return False @@ -1085,7 +1120,7 @@ def get_strategy_summary(user_id: str): "tone": "error" if event in {"ENGINE_ERROR", "ORDER_REJECTED"} else "warning", "message": _issue_message(event, message, data, meta), "event": event, - "ts": _format_local_ts(ts), + "ts": serialize_timestamp(ts), } ) return summary @@ -1119,7 +1154,17 @@ def get_strategy_summary(user_id: str): def get_market_status(): now = market_now() + session = market_session(now) + status = str(session.get("status") or "CLOSED") + reason = str(session.get("reason") or "") + next_open_at = None + try: + next_open_at = serialize_timestamp(next_market_open_after(now)) + except UnsupportedCalendarYearError: + next_open_at = None return { - "status": "OPEN" if is_market_open(now) else "CLOSED", - "checked_at": now.isoformat(), + "status": status, + "reason": reason, + "checked_at": serialize_timestamp(now), + "next_open_at": next_open_at, } diff --git a/backend/app/services/support_abuse.py b/backend/app/services/support_abuse.py new file mode 100644 index 0000000..391b369 --- /dev/null +++ b/backend/app/services/support_abuse.py @@ -0,0 +1,224 @@ +from __future__ import annotations + +import hashlib +import logging +import os +import threading +from collections import deque +from datetime import datetime, timedelta, timezone +from typing import Deque + +from fastapi import Request + +from app.services.db import db_connection + +logger = logging.getLogger(__name__) + +_MEMORY_LOCK = threading.Lock() +_MEMORY_EVENTS: list[dict] = [] + + +class SupportGuardRejected(Exception): + def __init__(self, status_code: int, detail: str): + super().__init__(detail) + self.status_code = status_code + self.detail = detail + + +def _now_utc() -> datetime: + return datetime.now(timezone.utc) + + +def _sha256(value: str | None) -> str | None: + if not value: + return None + return hashlib.sha256(value.strip().lower().encode("utf-8")).hexdigest() + + +def _backend_mode() -> str: + return (os.getenv("SUPPORT_GUARD_BACKEND") or "db").strip().lower() + + +def _window() -> timedelta: + return timedelta(seconds=int(os.getenv("SUPPORT_GUARD_WINDOW_SECONDS", "900"))) + + +def _create_limit() -> int: + return int(os.getenv("SUPPORT_CREATE_LIMIT", "5")) + + +def _status_limit() -> int: + return int(os.getenv("SUPPORT_STATUS_LIMIT", "15")) + + +def _ticket_probe_limit() -> int: + return int(os.getenv("SUPPORT_STATUS_TICKET_LIMIT", "10")) + + +def _captcha_secret() -> str | None: + return (os.getenv("SUPPORT_CAPTCHA_SECRET") or "").strip() or None + + +def _request_ip(request: Request) -> str: + forwarded = request.headers.get("x-forwarded-for") + if forwarded: + first = forwarded.split(",")[0].strip() + if first: + return first + return request.client.host if request.client else "unknown" + + +def _validate_captcha(captcha_token: str | None) -> None: + secret = _captcha_secret() + if secret and captcha_token != secret: + raise SupportGuardRejected(403, "Support verification failed") + + +def _record_memory_event(record: dict) -> None: + cutoff = _now_utc() - _window() + with _MEMORY_LOCK: + _MEMORY_EVENTS[:] = [entry for entry in _MEMORY_EVENTS if entry["created_at"] >= cutoff] + _MEMORY_EVENTS.append(record) + + +def _memory_count(endpoint: str, ip_hash: str | None, ticket_hash: str | None, cutoff: datetime) -> tuple[int, int]: + with _MEMORY_LOCK: + ip_count = sum( + 1 + for entry in _MEMORY_EVENTS + if entry["endpoint"] == endpoint + and entry["created_at"] >= cutoff + and entry["ip_hash"] == ip_hash + ) + ticket_count = sum( + 1 + for entry in _MEMORY_EVENTS + if entry["endpoint"] == endpoint + and entry["created_at"] >= cutoff + and entry["ticket_hash"] == ticket_hash + ) if ticket_hash else 0 + return ip_count, ticket_count + + +def _record_db_event(record: dict) -> None: + with db_connection() as conn: + with conn: + with conn.cursor() as cur: + cur.execute( + """ + INSERT INTO support_request_audit ( + endpoint, ip_hash, email_hash, ticket_hash, blocked, reason, created_at + ) + VALUES (%s, %s, %s, %s, %s, %s, %s) + """, + ( + record["endpoint"], + record["ip_hash"], + record["email_hash"], + record["ticket_hash"], + record["blocked"], + record["reason"], + record["created_at"], + ), + ) + + +def _db_count(endpoint: str, ip_hash: str | None, ticket_hash: str | None, cutoff: datetime) -> tuple[int, int]: + with db_connection() as conn: + with conn.cursor() as cur: + cur.execute( + """ + SELECT COUNT(*) + FROM support_request_audit + WHERE endpoint = %s + AND ip_hash IS NOT DISTINCT FROM %s + AND created_at >= %s + """, + (endpoint, ip_hash, cutoff), + ) + ip_count = cur.fetchone()[0] or 0 + ticket_count = 0 + if ticket_hash: + cur.execute( + """ + SELECT COUNT(*) + FROM support_request_audit + WHERE endpoint = %s + AND ticket_hash = %s + AND created_at >= %s + """, + (endpoint, ticket_hash, cutoff), + ) + ticket_count = cur.fetchone()[0] or 0 + return ip_count, ticket_count + + +def _determine_limits(endpoint: str, ip_count: int, ticket_count: int) -> str | None: + if endpoint == "ticket_create" and ip_count >= _create_limit(): + return "create_rate_limited" + if endpoint == "ticket_status" and ip_count >= _status_limit(): + return "status_rate_limited" + if endpoint == "ticket_status" and ticket_count >= _ticket_probe_limit(): + return "ticket_probe_limited" + return None + + +def _audit_attempt(record: dict) -> None: + if _backend_mode() == "memory": + _record_memory_event(record) + return + _record_db_event(record) + + +def _count_recent(endpoint: str, ip_hash: str | None, ticket_hash: str | None, cutoff: datetime) -> tuple[int, int]: + if _backend_mode() == "memory": + return _memory_count(endpoint, ip_hash, ticket_hash, cutoff) + return _db_count(endpoint, ip_hash, ticket_hash, cutoff) + + +def enforce_support_guard( + *, + request: Request, + endpoint: str, + email: str | None = None, + ticket_id: str | None = None, + captcha_token: str | None = None, +) -> None: + _validate_captcha(captcha_token) + + now = _now_utc() + cutoff = now - _window() + ip_hash = _sha256(_request_ip(request)) + email_hash = _sha256(email) + ticket_hash = _sha256(ticket_id) + + ip_count, ticket_count = _count_recent(endpoint, ip_hash, ticket_hash, cutoff) + reason = _determine_limits(endpoint, ip_count, ticket_count) + + record = { + "endpoint": endpoint, + "ip_hash": ip_hash, + "email_hash": email_hash, + "ticket_hash": ticket_hash, + "blocked": reason is not None, + "reason": reason, + "created_at": now, + } + _audit_attempt(record) + + if reason is not None: + logger.warning( + "Support request blocked", + extra={ + "endpoint": endpoint, + "reason": reason, + "ip_hash": ip_hash, + "ticket_hash": ticket_hash, + }, + ) + raise SupportGuardRejected(429, "Too many support requests. Please try again later.") + + +def reset_memory_support_guard_state() -> None: + with _MEMORY_LOCK: + _MEMORY_EVENTS.clear() diff --git a/backend/app/services/support_ticket.py b/backend/app/services/support_ticket.py index fdbfa1c..0c832d6 100644 --- a/backend/app/services/support_ticket.py +++ b/backend/app/services/support_ticket.py @@ -4,6 +4,7 @@ from uuid import uuid4 from app.services.db import db_connection from app.services.email_service import send_email +from indian_paper_trading_strategy.engine.time_utils import serialize_timestamp def _now(): @@ -41,7 +42,7 @@ def create_ticket(name: str, email: str, subject: str, message: str) -> dict: return { "ticket_id": ticket_id, "status": "NEW", - "created_at": now.isoformat(), + "created_at": serialize_timestamp(now), "email_sent": email_sent, } @@ -65,6 +66,6 @@ def get_ticket_status(ticket_id: str, email: str) -> dict | None: return { "ticket_id": row[0], "status": row[2], - "created_at": row[3].isoformat() if row[3] else None, - "updated_at": row[4].isoformat() if row[4] else None, + "created_at": serialize_timestamp(row[3]) if row[3] else None, + "updated_at": serialize_timestamp(row[4]) if row[4] else None, } diff --git a/backend/app/services/system_service.py b/backend/app/services/system_service.py index 5253042..1f79851 100644 --- a/backend/app/services/system_service.py +++ b/backend/app/services/system_service.py @@ -5,6 +5,8 @@ from datetime import datetime, timezone from psycopg2.extras import Json +from indian_paper_trading_strategy.engine.time_utils import parse_persisted_timestamp, serialize_timestamp + from app.broker_store import get_user_broker, set_broker_auth_state from app.services.db import db_connection from app.services.groww_service import GrowwApiError, GrowwTokenError, fetch_funds as fetch_groww_funds @@ -59,12 +61,7 @@ def _resolve_sip_frequency(row: dict): def _parse_ts(value: str | None): - if not value: - return None - try: - return datetime.fromisoformat(value) - except ValueError: - return None + return parse_persisted_timestamp(value) def _validate_broker_session(user_id: str): @@ -180,7 +177,7 @@ def arm_system(user_id: str, client_ip: str | None = None): continue sip_frequency = _resolve_sip_frequency(run) - last_run = now.isoformat() + last_run = serialize_timestamp(now) next_run = compute_next_eligible(last_run, sip_frequency) next_run_dt = _parse_ts(next_run) @@ -195,7 +192,7 @@ def arm_system(user_id: str, client_ip: str | None = None): """, ( now, - Json({"armed_at": now.isoformat()}), + Json({"armed_at": serialize_timestamp(now)}), user_id, run["run_id"], ), @@ -339,7 +336,7 @@ def arm_system(user_id: str, client_ip: str | None = None): pass broker_state = get_user_broker(user_id) or {} - next_execution = min(next_runs).isoformat() if next_runs else None + next_execution = serialize_timestamp(min(next_runs)) if next_runs else None return { "ok": True, "armed_runs": armed_runs, @@ -378,7 +375,7 @@ def system_status(user_id: str): "strategy": row[2], "mode": row[3], "broker": row[4], - "next_run": row[5].isoformat() if row[5] else None, + "next_run": serialize_timestamp(row[5]), "active": bool(row[6]) if row[6] is not None else False, "lifecycle": row[1], } diff --git a/backend/app/services/tenant.py b/backend/app/services/tenant.py index 5270cf0..a48f6b3 100644 --- a/backend/app/services/tenant.py +++ b/backend/app/services/tenant.py @@ -1,7 +1,6 @@ from fastapi import HTTPException, Request from app.services.auth_service import get_user_for_session -from app.services.run_service import get_default_user_id SESSION_COOKIE_NAME = "session_id" @@ -13,7 +12,4 @@ def get_request_user_id(request: Request) -> str: if user: return user["id"] - default_user_id = get_default_user_id() - if default_user_id: - return default_user_id raise HTTPException(status_code=401, detail="Not authenticated") diff --git a/backend/app/services/zerodha_service.py b/backend/app/services/zerodha_service.py index ecb87dc..310ccd9 100644 --- a/backend/app/services/zerodha_service.py +++ b/backend/app/services/zerodha_service.py @@ -27,11 +27,13 @@ class KitePermissionError(KiteApiError): pass -def build_login_url(api_key: str, redirect_url: str | None = None) -> str: +def build_login_url(api_key: str, redirect_url: str | None = None, state: str | None = None) -> str: params = {"api_key": api_key, "v": KITE_VERSION} redirect_url = (redirect_url or os.getenv("ZERODHA_REDIRECT_URL") or "").strip() if redirect_url: params["redirect_url"] = redirect_url + if state: + params["state"] = state query = urllib.parse.urlencode(params) return f"{KITE_LOGIN_URL}?{query}" diff --git a/backend/requirements.txt b/backend/requirements.txt index 9b376a5..9f1303b 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -41,3 +41,4 @@ websockets==16.0 yfinance==1.0 alembic==1.13.3 pytest==8.3.5 +argon2-cffi==25.1.0 diff --git a/backend/tests/conftest.py b/backend/tests/conftest.py new file mode 100644 index 0000000..4de919e --- /dev/null +++ b/backend/tests/conftest.py @@ -0,0 +1,14 @@ +import os +import sys +from pathlib import Path + + +BACKEND_ROOT = Path(__file__).resolve().parents[1] +if str(BACKEND_ROOT) not in sys.path: + sys.path.insert(0, str(BACKEND_ROOT)) + +PROJECT_ROOT = Path(__file__).resolve().parents[2] +if str(PROJECT_ROOT) not in sys.path: + sys.path.insert(0, str(PROJECT_ROOT)) + +os.environ.setdefault("RESET_OTP_SECRET", "test-reset-secret") diff --git a/backend/tests/test_api_semantics_and_utc.py b/backend/tests/test_api_semantics_and_utc.py new file mode 100644 index 0000000..64647ce --- /dev/null +++ b/backend/tests/test_api_semantics_and_utc.py @@ -0,0 +1,111 @@ +import importlib +from datetime import datetime, timezone +from pathlib import Path + +import pytest +from fastapi.testclient import TestClient + + +def _build_app(monkeypatch): + monkeypatch.setenv("APP_ENV", "test") + monkeypatch.setenv("DISABLE_STARTUP_TASKS", "1") + monkeypatch.setenv("DB_HOST", "localhost") + monkeypatch.setenv("DB_NAME", "trading_db") + monkeypatch.setenv("DB_USER", "trader") + monkeypatch.setenv("DB_PASSWORD", "test-password") + monkeypatch.setenv("CORS_ORIGINS", "http://localhost:3000") + + import app.main as app_main + + importlib.reload(app_main) + return app_main.create_app() + + +def test_strategy_stop_failure_returns_non_200(monkeypatch): + app = _build_app(monkeypatch) + client = TestClient(app) + + import app.routers.strategy as strategy_router + + monkeypatch.setattr(strategy_router, "get_request_user_id", lambda _request: "user-1") + monkeypatch.setattr( + strategy_router, + "stop_strategy", + lambda _user_id: {"status": "stop_failed", "message": "run missing", "run_id": "run-1"}, + ) + + response = client.post("/api/strategy/stop", cookies={"session_id": "session-1"}) + + assert response.status_code == 409 + assert response.json()["detail"]["status"] == "stop_failed" + + +def test_strategy_resume_failure_returns_non_200(monkeypatch): + app = _build_app(monkeypatch) + client = TestClient(app) + + import app.routers.strategy as strategy_router + + monkeypatch.setattr(strategy_router, "get_request_user_id", lambda _request: "user-1") + monkeypatch.setattr( + strategy_router, + "resume_strategy", + lambda _user_id: {"status": "resume_failed", "message": "engine start failed", "run_id": "run-1"}, + ) + + response = client.post("/api/strategy/resume", cookies={"session_id": "session-1"}) + + assert response.status_code == 409 + assert response.json()["detail"]["status"] == "resume_failed" + + +def test_update_run_status_returns_false_when_no_row_changes(monkeypatch): + import app.services.run_service as run_service + + class FakeCursor: + rowcount = 0 + + def execute(self, sql, params): + return None + + monkeypatch.setattr(run_service, "run_with_retry", lambda op: op(FakeCursor(), None)) + + updated = run_service.update_run_status("user-1", "missing-run", "STOPPED") + + assert updated is False + + +def test_backend_emits_utc_iso_timestamps(): + from app.services.strategy_service import get_market_status + from indian_paper_trading_strategy.engine.time_utils import ( + parse_persisted_timestamp, + serialize_timestamp, + ) + + timestamp = get_market_status()["checked_at"] + parsed = parse_persisted_timestamp(timestamp) + + assert timestamp.endswith("+00:00") + assert parsed.tzinfo is not None + assert serialize_timestamp(parsed) == timestamp + + +def test_bootstrap_schema_contains_migrated_core_columns_and_tables(): + schema_path = Path(__file__).resolve().parents[2] / ".." / "SIP_GoldBees_Database" / "schema.sql" + schema_sql = schema_path.read_text(encoding="utf-8") + + required_snippets = [ + "ALTER TABLE app_session", + "ADD COLUMN IF NOT EXISTS ip TEXT", + "ADD COLUMN IF NOT EXISTS user_agent TEXT", + "ALTER TABLE user_broker", + "ADD COLUMN IF NOT EXISTS api_secret TEXT", + "ADD COLUMN IF NOT EXISTS auth_state TEXT", + "CREATE TABLE IF NOT EXISTS broker_callback_state", + "CREATE TABLE IF NOT EXISTS execution_claim", + "CREATE TABLE IF NOT EXISTS run_leases", + "CREATE TABLE IF NOT EXISTS support_request_audit", + ] + + for snippet in required_snippets: + assert snippet in schema_sql diff --git a/backend/tests/test_auth_isolation_and_cors.py b/backend/tests/test_auth_isolation_and_cors.py new file mode 100644 index 0000000..65eabd5 --- /dev/null +++ b/backend/tests/test_auth_isolation_and_cors.py @@ -0,0 +1,128 @@ +import importlib + +import pytest +from fastapi.testclient import TestClient + + +def _build_app(monkeypatch, *, app_env="test", cors_origins=None): + monkeypatch.setenv("APP_ENV", app_env) + monkeypatch.setenv("DISABLE_STARTUP_TASKS", "1") + monkeypatch.setenv("DB_HOST", "localhost") + monkeypatch.setenv("DB_NAME", "trading_db") + monkeypatch.setenv("DB_USER", "trader") + monkeypatch.setenv("DB_PASSWORD", "test-password") + if cors_origins is None: + monkeypatch.delenv("CORS_ORIGINS", raising=False) + else: + monkeypatch.setenv("CORS_ORIGINS", cors_origins) + + import app.main as app_main + + importlib.reload(app_main) + return app_main.create_app() + + +def test_strategy_status_requires_auth(monkeypatch): + app = _build_app(monkeypatch) + client = TestClient(app) + + response = client.get("/api/strategy/status") + + assert response.status_code == 401 + assert response.json() == {"detail": "Not authenticated"} + + +def test_strategy_stop_requires_auth(monkeypatch): + app = _build_app(monkeypatch) + client = TestClient(app) + + response = client.post("/api/strategy/stop") + + assert response.status_code == 401 + assert response.json() == {"detail": "Not authenticated"} + + +def test_strategy_routes_use_session_identity_only(monkeypatch): + app = _build_app(monkeypatch) + client = TestClient(app) + + import app.routers.strategy as strategy_router + import app.services.tenant as tenant + + monkeypatch.setattr(tenant, "get_user_for_session", lambda _sid: {"id": "user-a"}) + + seen = {} + + def fake_get_strategy_status(user_id): + seen["status_user_id"] = user_id + return {"user_id": user_id} + + def fake_stop_strategy(user_id): + seen["stop_user_id"] = user_id + return {"status": "stopped", "user_id": user_id} + + monkeypatch.setattr(strategy_router, "get_strategy_status", fake_get_strategy_status) + monkeypatch.setattr(strategy_router, "stop_strategy", fake_stop_strategy) + + status_response = client.get( + "/api/strategy/status?user_id=user-b", + cookies={"session_id": "session-a"}, + ) + stop_response = client.post( + "/api/strategy/stop?user_id=user-b", + cookies={"session_id": "session-a"}, + ) + + assert status_response.status_code == 200 + assert stop_response.status_code == 200 + assert status_response.json()["user_id"] == "user-a" + assert stop_response.json()["user_id"] == "user-a" + assert seen == {"status_user_id": "user-a", "stop_user_id": "user-a"} + + +def test_allowed_origin_preflight_supports_credentials(monkeypatch): + app = _build_app( + monkeypatch, + app_env="production", + cors_origins="https://app.quantfortune.com", + ) + client = TestClient(app) + + response = client.options( + "/api/strategy/status", + headers={ + "Origin": "https://app.quantfortune.com", + "Access-Control-Request-Method": "GET", + "Access-Control-Request-Headers": "content-type", + }, + ) + + assert response.status_code == 200 + assert response.headers["access-control-allow-origin"] == "https://app.quantfortune.com" + assert response.headers["access-control-allow-credentials"] == "true" + + +def test_arbitrary_origin_preflight_is_rejected(monkeypatch): + app = _build_app( + monkeypatch, + app_env="production", + cors_origins="https://app.quantfortune.com", + ) + client = TestClient(app) + + response = client.options( + "/api/strategy/status", + headers={ + "Origin": "https://evil.example", + "Access-Control-Request-Method": "GET", + "Access-Control-Request-Headers": "content-type", + }, + ) + + assert response.status_code == 400 + assert "access-control-allow-origin" not in response.headers + + +def test_production_without_cors_origins_fails_closed(monkeypatch): + with pytest.raises(RuntimeError, match="CORS_ORIGINS must be configured explicitly in production"): + _build_app(monkeypatch, app_env="production", cors_origins=None) diff --git a/backend/tests/test_execution_claims.py b/backend/tests/test_execution_claims.py new file mode 100644 index 0000000..7c90599 --- /dev/null +++ b/backend/tests/test_execution_claims.py @@ -0,0 +1,120 @@ +from __future__ import annotations + +import threading +from datetime import datetime, timezone + +from indian_paper_trading_strategy.engine import execution, ledger + + +def test_claim_execution_window_allows_only_one_winner(monkeypatch): + claims: set[tuple[str, str, datetime]] = set() + lock = threading.Lock() + + class FakeCursor: + def __init__(self): + self._result = None + + def execute(self, sql, params): + assert "INSERT INTO execution_claim" in sql + key = (params[1], params[2], params[4]) + with lock: + if key in claims: + self._result = None + else: + claims.add(key) + self._result = (params[0],) + + def fetchone(self): + return self._result + + monkeypatch.setattr(ledger, "get_context", lambda user_id=None, run_id=None: ("user-a", "run-a")) + monkeypatch.setattr(ledger, "run_with_retry", lambda op, retries=None, delay=None: op(FakeCursor(), None)) + + logical_time = datetime(2026, 4, 8, 9, 15, tzinfo=timezone.utc) + results: list[bool] = [] + + def attempt(): + results.append(ledger.claim_execution_window(logical_time, mode="LIVE")) + + threads = [threading.Thread(target=attempt) for _ in range(2)] + for thread in threads: + thread.start() + for thread in threads: + thread.join() + + assert sorted(results) == [False, True] + + +def test_try_execute_sip_live_emits_one_order_batch_when_claim_is_lost(monkeypatch): + claim_lock = threading.Lock() + claimed = False + order_calls: list[tuple[str, str, float]] = [] + + class FakeBroker: + external_orders = True + + def get_funds(self): + return {"cash": 10_000.0} + + def place_order(self, symbol, side, qty, price, logical_time=None): + order_calls.append((symbol, side, qty)) + return { + "id": f"{symbol}-{len(order_calls)}", + "symbol": symbol, + "side": side, + "status": "COMPLETE", + "filled_qty": qty, + "average_price": price, + "price": price, + } + + def fake_claim_execution_window(logical_time, *, mode=None, cur=None, user_id=None, run_id=None): + nonlocal claimed + with claim_lock: + if claimed: + return False + claimed = True + return True + + monkeypatch.setattr(execution, "load_state", lambda *args, **kwargs: {"last_sip_ts": None, "last_run": None}) + monkeypatch.setattr( + execution, + "_resolve_timing", + lambda state, now_ts, sip_interval: (True, None, datetime(2026, 4, 8, 9, 15, tzinfo=timezone.utc)), + ) + monkeypatch.setattr(execution, "event_exists", lambda *args, **kwargs: False) + monkeypatch.setattr(execution, "claim_execution_window", fake_claim_execution_window) + monkeypatch.setattr(execution, "log_event", lambda *args, **kwargs: None) + monkeypatch.setattr(execution, "run_with_retry", lambda op, retries=None, delay=None: op(object(), None)) + monkeypatch.setattr( + execution, + "_finalize_live_execution", + lambda **kwargs: ({"last_run": kwargs["now_ts"].isoformat()}, bool(kwargs["orders"])), + ) + + results: list[bool] = [] + + def attempt(): + _state, executed = execution._try_execute_sip_live( + now=datetime(2026, 4, 8, 14, 45, tzinfo=timezone.utc), + market_open=True, + sip_interval=120, + sip_amount=1000.0, + sp_price=125.0, + gd_price=250.0, + eq_w=0.5, + gd_w=0.5, + broker=FakeBroker(), + mode="LIVE", + ) + results.append(executed) + + threads = [threading.Thread(target=attempt) for _ in range(2)] + for thread in threads: + thread.start() + for thread in threads: + thread.join() + + assert sorted(results) == [False, True] + assert len(order_calls) == 2 + assert {call[0] for call in order_calls} == {"NIFTYBEES.NS", "GOLDBEES.NS"} diff --git a/backend/tests/test_market_calendar.py b/backend/tests/test_market_calendar.py new file mode 100644 index 0000000..828b4d7 --- /dev/null +++ b/backend/tests/test_market_calendar.py @@ -0,0 +1,102 @@ +from datetime import date, datetime, timezone + +import pytest + + +def test_market_calendar_open_until_close_boundary_exclusive(): + from indian_paper_trading_strategy.engine.market_calendar import ( + get_market_session, + get_market_status, + is_market_open, + ) + + open_now = datetime(2026, 4, 2, 9, 59, 59, tzinfo=timezone.utc) # 15:29:59 IST + close_now = datetime(2026, 4, 2, 10, 0, 0, tzinfo=timezone.utc) # 15:30:00 IST + + assert get_market_status(open_now) == "OPEN" + assert is_market_open(open_now) is True + assert get_market_status(close_now) == "CLOSED" + assert is_market_open(close_now) is False + + close_session = get_market_session(close_now) + assert close_session["reason"] == "POST_CLOSE" + + +def test_market_calendar_holiday_during_trading_hours(): + from indian_paper_trading_strategy.engine.market_calendar import get_market_session, get_market_status + + now_utc = datetime(2026, 4, 3, 4, 0, tzinfo=timezone.utc) # Good Friday, 09:30 IST + + session = get_market_session(now_utc) + assert get_market_status(now_utc) == "HOLIDAY" + assert session["status"] == "HOLIDAY" + assert session["reason"] == "HOLIDAY" + + +def test_market_calendar_utc_and_ist_inputs_match(): + from zoneinfo import ZoneInfo + + from indian_paper_trading_strategy.engine.market_calendar import get_market_session + + utc_now = datetime(2026, 4, 2, 4, 0, tzinfo=timezone.utc) # 09:30 IST + ist_now = utc_now.astimezone(ZoneInfo("Asia/Kolkata")) + + assert get_market_session(utc_now) == get_market_session(ist_now) + + +def test_market_calendar_reason_codes_cover_session_windows(): + from indian_paper_trading_strategy.engine.market_calendar import get_market_session + + pre_open = datetime(2026, 4, 2, 3, 0, tzinfo=timezone.utc) # 08:30 IST + in_session = datetime(2026, 4, 2, 4, 0, tzinfo=timezone.utc) # 09:30 IST + post_close = datetime(2026, 4, 2, 10, 1, tzinfo=timezone.utc) # 15:31 IST + weekend = datetime(2026, 4, 4, 4, 0, tzinfo=timezone.utc) # Saturday + + assert get_market_session(pre_open)["reason"] == "PRE_OPEN" + assert get_market_session(in_session)["reason"] == "OPEN" + assert get_market_session(post_close)["reason"] == "POST_CLOSE" + assert get_market_session(weekend)["reason"] == "WEEKEND" + + +def test_supported_year_returns_expected_holidays(): + from indian_paper_trading_strategy.engine.market_calendar import get_nse_holidays + + holidays = get_nse_holidays(2026) + + assert date(2026, 4, 3) in holidays + assert date(2026, 1, 26) in holidays + + +def test_unsupported_year_fails_closed(): + from indian_paper_trading_strategy.engine.market_calendar import ( + UnsupportedCalendarYearError, + get_market_session, + next_market_open, + ) + + now_utc = datetime(2027, 1, 4, 4, 0, tzinfo=timezone.utc) + session = get_market_session(now_utc) + + assert session["status"] == "CLOSED" + assert session["reason"] == "CALENDAR_UNAVAILABLE" + assert session["calendar_supported"] is False + + with pytest.raises(UnsupportedCalendarYearError): + next_market_open(now_utc) + + +def test_strategy_service_market_status_exposes_reason_and_holiday(monkeypatch): + import app.services.strategy_service as strategy_service + from indian_paper_trading_strategy.engine.market_calendar import MARKET_TZ + + monkeypatch.setattr( + strategy_service, + "market_now", + lambda: datetime(2026, 4, 3, 9, 30, tzinfo=MARKET_TZ), + ) + + payload = strategy_service.get_market_status() + + assert payload["status"] == "HOLIDAY" + assert payload["reason"] == "HOLIDAY" + assert payload["next_open_at"] == "2026-04-06T03:45:00+00:00" diff --git a/backend/tests/test_runner_leases.py b/backend/tests/test_runner_leases.py new file mode 100644 index 0000000..35d03d5 --- /dev/null +++ b/backend/tests/test_runner_leases.py @@ -0,0 +1,205 @@ +from __future__ import annotations + +import threading +from datetime import datetime, timedelta, timezone + +from indian_paper_trading_strategy.engine import db as engine_db +from indian_paper_trading_strategy.engine import runner + + +class _LeaseCursor: + def __init__(self, leases: dict[str, dict], lock: threading.Lock): + self._leases = leases + self._lock = lock + self._result = None + + def execute(self, sql, params): + sql_text = " ".join(sql.split()) + with self._lock: + if sql_text.startswith("INSERT INTO run_leases"): + run_id, owner_id, leased_at, expires_at, heartbeat_at = params + lease = self._leases.get(run_id) + if lease is None: + self._leases[run_id] = { + "owner_id": owner_id, + "leased_at": leased_at, + "expires_at": expires_at, + "heartbeat_at": heartbeat_at, + } + self._result = (run_id,) + else: + self._result = None + return + + if "SELECT owner_id, expires_at FROM run_leases" in sql_text: + run_id = params[0] + lease = self._leases.get(run_id) + self._result = None if lease is None else (lease["owner_id"], lease["expires_at"]) + return + + if sql_text.startswith("UPDATE run_leases SET leased_at"): + leased_at, expires_at, heartbeat_at, run_id, owner_id = params + lease = self._leases.get(run_id) + if lease and lease["owner_id"] == owner_id: + lease.update( + { + "owner_id": owner_id, + "leased_at": leased_at, + "expires_at": expires_at, + "heartbeat_at": heartbeat_at, + } + ) + self._result = (run_id,) + else: + self._result = None + return + + if sql_text.startswith("UPDATE run_leases SET owner_id"): + owner_id, leased_at, expires_at, heartbeat_at, run_id, current_time = params + lease = self._leases.get(run_id) + if lease and lease["expires_at"] <= current_time: + lease.update( + { + "owner_id": owner_id, + "leased_at": leased_at, + "expires_at": expires_at, + "heartbeat_at": heartbeat_at, + } + ) + self._result = (run_id,) + else: + self._result = None + return + + if sql_text.startswith("UPDATE run_leases SET heartbeat_at"): + heartbeat_at, expires_at, run_id, owner_id, current_time = params + lease = self._leases.get(run_id) + if lease and lease["owner_id"] == owner_id and lease["expires_at"] > current_time: + lease.update({"heartbeat_at": heartbeat_at, "expires_at": expires_at}) + self._result = (run_id, expires_at) + else: + self._result = None + return + + if sql_text.startswith("DELETE FROM run_leases"): + run_id, owner_id = params + lease = self._leases.get(run_id) + if lease and lease["owner_id"] == owner_id: + del self._leases[run_id] + self._result = (run_id,) + else: + self._result = None + return + + raise AssertionError(f"Unexpected SQL: {sql_text}") + + def fetchone(self): + return self._result + + +def _patch_lease_storage(monkeypatch): + leases: dict[str, dict] = {} + lock = threading.Lock() + + def fake_run_with_retry(operation, retries=None, delay=None): + cursor = _LeaseCursor(leases, lock) + return operation(cursor, None) + + monkeypatch.setattr(engine_db, "run_with_retry", fake_run_with_retry) + return leases + + +def test_run_lease_allows_only_one_active_owner(monkeypatch): + _patch_lease_storage(monkeypatch) + now = datetime(2026, 4, 8, 4, 0, tzinfo=timezone.utc) + results = [] + + def attempt(owner_id: str): + results.append(engine_db.acquire_run_lease("run-1", owner_id, lease_seconds=90, now=now)) + + threads = [ + threading.Thread(target=attempt, args=("owner-a",)), + threading.Thread(target=attempt, args=("owner-b",)), + ] + for thread in threads: + thread.start() + for thread in threads: + thread.join() + + acquired = [result for result in results if result["acquired"]] + denied = [result for result in results if not result["acquired"]] + assert len(acquired) == 1 + assert len(denied) == 1 + assert denied[0]["status"] == "DENIED" + + +def test_expired_run_lease_can_be_reacquired(monkeypatch): + _patch_lease_storage(monkeypatch) + start = datetime(2026, 4, 8, 4, 0, tzinfo=timezone.utc) + later = start + timedelta(seconds=91) + + first = engine_db.acquire_run_lease("run-1", "owner-a", lease_seconds=90, now=start) + second = engine_db.acquire_run_lease("run-1", "owner-b", lease_seconds=90, now=later) + + assert first["status"] == "ACQUIRED" + assert second["acquired"] is True + assert second["status"] == "REACQUIRED" + assert second["previous_owner"] == "owner-a" + + +def test_heartbeat_prevents_takeover(monkeypatch): + _patch_lease_storage(monkeypatch) + start = datetime(2026, 4, 8, 4, 0, tzinfo=timezone.utc) + heartbeat_time = start + timedelta(seconds=30) + challenger_time = start + timedelta(seconds=60) + + engine_db.acquire_run_lease("run-1", "owner-a", lease_seconds=90, now=start) + heartbeat = engine_db.heartbeat_run_lease("run-1", "owner-a", lease_seconds=90, now=heartbeat_time) + challenger = engine_db.acquire_run_lease("run-1", "owner-b", lease_seconds=90, now=challenger_time) + + assert heartbeat["active"] is True + assert challenger["acquired"] is False + assert challenger["status"] == "DENIED" + + +def test_runner_exits_cleanly_when_lease_is_lost(monkeypatch): + statuses: list[str] = [] + cleared: list[tuple[str, str]] = [] + released: list[tuple[str, str]] = [] + refreshed = {"count": 0} + + monkeypatch.setattr(runner, "get_context", lambda user_id=None, run_id=None: ("user-1", "run-1")) + monkeypatch.setattr(runner, "set_context", lambda user_id, run_id: None) + monkeypatch.setattr(runner, "log_event", lambda *args, **kwargs: None) + monkeypatch.setattr(runner, "PaperBroker", lambda initial_cash=0: object()) + monkeypatch.setattr(runner, "_set_state", lambda *args, **kwargs: None) + monkeypatch.setattr(runner, "_clear_runner", lambda user_id, run_id: cleared.append((user_id, run_id))) + monkeypatch.setattr(runner, "_update_engine_status", lambda user_id, run_id, status: statuses.append(status)) + monkeypatch.setattr(runner, "release_run_lease", lambda run_id, owner_id: released.append((run_id, owner_id)) or True) + monkeypatch.setattr(runner, "_log_runner_lease_event", lambda *args, **kwargs: None) + + def fake_refresh(user_id, run_id, owner_id): + refreshed["count"] += 1 + return False + + monkeypatch.setattr(runner, "_refresh_run_lease_or_stop", fake_refresh) + + stop_event = threading.Event() + runner._engine_loop( + { + "user_id": "user-1", + "run_id": "run-1", + "strategy": "Golden Nifty", + "sip_amount": 1000, + "sip_frequency": {"value": 2, "unit": "minutes"}, + "mode": "PAPER", + "broker": "paper", + "runner_owner_id": "owner-1", + }, + stop_event, + ) + + assert refreshed["count"] >= 1 + assert statuses == ["RUNNING"] + assert cleared == [("user-1", "run-1")] + assert released == [("run-1", "owner-1")] diff --git a/backend/tests/test_security_hardening.py b/backend/tests/test_security_hardening.py new file mode 100644 index 0000000..02f6774 --- /dev/null +++ b/backend/tests/test_security_hardening.py @@ -0,0 +1,269 @@ +import importlib +from datetime import datetime, timedelta, timezone + +import pytest +from fastapi import Response +from fastapi.testclient import TestClient + + +def test_legacy_password_hash_upgrades_on_successful_login(monkeypatch): + import app.services.auth_service as auth_service + + legacy_hash = auth_service._hash_password_legacy("correct-horse-battery-staple") + user = { + "id": "user-1", + "username": "user@example.com", + "password": legacy_hash, + "role": "USER", + } + updated = {} + + monkeypatch.setattr(auth_service, "get_user_by_username", lambda username: user if username == user["username"] else None) + monkeypatch.setattr( + auth_service, + "_update_password_hash", + lambda user_id, password_hash: updated.update({"user_id": user_id, "password_hash": password_hash}), + ) + + authenticated = auth_service.authenticate_user(user["username"], "correct-horse-battery-staple") + + assert authenticated is user + assert updated["user_id"] == "user-1" + assert updated["password_hash"].startswith("$argon2id$") + assert authenticated["password"] == updated["password_hash"] + + +def test_secure_session_cookie_flags_are_enforced_in_production(monkeypatch): + monkeypatch.setenv("APP_ENV", "production") + monkeypatch.delenv("COOKIE_SECURE", raising=False) + monkeypatch.setenv("COOKIE_SAMESITE", "strict") + + import app.routers.auth as auth_router + + importlib.reload(auth_router) + + response = Response() + auth_router._set_session_cookie(response, "session-123") + set_cookie = response.headers["set-cookie"].lower() + + assert "secure" in set_cookie + assert "httponly" in set_cookie + assert "samesite=strict" in set_cookie + + +def test_missing_reset_otp_secret_fails_auth_service_import(monkeypatch): + monkeypatch.delenv("RESET_OTP_SECRET", raising=False) + + import app.services.auth_service as auth_service + + with pytest.raises(RuntimeError, match="RESET_OTP_SECRET must be configured"): + importlib.reload(auth_service) + + monkeypatch.setenv("RESET_OTP_SECRET", "test-reset-secret") + importlib.reload(auth_service) + + +def test_valid_broker_callback_state_allows_connect_callback(monkeypatch): + import app.main as app_main + import app.routers.broker as broker_router + + monkeypatch.setenv("APP_ENV", "test") + monkeypatch.setenv("DISABLE_STARTUP_TASKS", "1") + monkeypatch.setenv("CORS_ORIGINS", "http://localhost:3000") + importlib.reload(app_main) + app = app_main.create_app() + client = TestClient(app) + + monkeypatch.setattr(broker_router, "get_user_for_session", lambda _sid: {"id": "user-1", "username": "user@example.com"}) + monkeypatch.setattr( + broker_router, + "consume_broker_callback_state", + lambda **kwargs: {"id": "state-1", "expires_at": datetime.now(timezone.utc).isoformat()}, + ) + monkeypatch.setattr( + broker_router, + "get_pending_broker", + lambda _user_id: {"api_key": "kite-key", "api_secret": "kite-secret"}, + ) + monkeypatch.setattr( + broker_router, + "exchange_request_token", + lambda api_key, api_secret, token: { + "access_token": "access-token", + "request_token": token, + "user_name": "Trader", + "user_id": "Z123", + }, + ) + captured = {} + monkeypatch.setattr( + broker_router, + "set_zerodha_session", + lambda user_id, payload: captured.setdefault("zerodha_session", {"user_id": user_id, **payload}), + ) + monkeypatch.setattr( + broker_router, + "set_connected_broker", + lambda user_id, broker, token, **kwargs: captured.setdefault( + "connected", + {"user_id": user_id, "broker": broker, "token": token, **kwargs}, + ), + ) + + response = client.get( + "/api/broker/zerodha/callback", + params={"request_token": "request-token", "state": "valid-state"}, + cookies={"session_id": "session-1"}, + ) + + assert response.status_code == 200 + assert response.json() == { + "connected": True, + "userName": "Trader", + "brokerUserId": "Z123", + } + assert captured["connected"]["broker"] == "ZERODHA" + assert captured["connected"]["user_id"] == "user-1" + + +def test_missing_broker_callback_state_fails(monkeypatch): + import app.main as app_main + import app.routers.broker as broker_router + + monkeypatch.setenv("APP_ENV", "test") + monkeypatch.setenv("DISABLE_STARTUP_TASKS", "1") + monkeypatch.setenv("CORS_ORIGINS", "http://localhost:3000") + importlib.reload(app_main) + app = app_main.create_app() + client = TestClient(app) + + monkeypatch.setattr(broker_router, "get_user_for_session", lambda _sid: {"id": "user-1", "username": "user@example.com"}) + + response = client.get( + "/api/broker/zerodha/callback", + params={"request_token": "request-token"}, + cookies={"session_id": "session-1"}, + ) + + assert response.status_code == 400 + assert response.json() == {"detail": "Missing state"} + + +def test_wrong_or_expired_broker_callback_state_fails(monkeypatch): + import app.main as app_main + import app.routers.broker as broker_router + + monkeypatch.setenv("APP_ENV", "test") + monkeypatch.setenv("DISABLE_STARTUP_TASKS", "1") + monkeypatch.setenv("CORS_ORIGINS", "http://localhost:3000") + importlib.reload(app_main) + app = app_main.create_app() + client = TestClient(app) + + monkeypatch.setattr(broker_router, "get_user_for_session", lambda _sid: {"id": "user-1", "username": "user@example.com"}) + monkeypatch.setattr(broker_router, "consume_broker_callback_state", lambda **kwargs: None) + + response = client.get( + "/api/broker/zerodha/callback", + params={"request_token": "request-token", "state": "wrong-or-expired"}, + cookies={"session_id": "session-1"}, + ) + + assert response.status_code == 401 + assert response.json() == {"detail": "Invalid or expired broker callback state"} + + +def test_broker_callback_state_service_rejects_wrong_user_and_expired_state(monkeypatch): + import app.services.broker_callback_state as callback_state + + now = datetime(2026, 4, 8, 9, 0, tzinfo=timezone.utc) + rows = [ + { + "state_hash": callback_state._state_hash("valid-state"), + "user_id": "user-1", + "session_id": "session-1", + "broker": "ZERODHA", + "flow": "connect", + "expires_at": now + timedelta(minutes=5), + "consumed_at": None, + }, + { + "state_hash": callback_state._state_hash("expired-state"), + "user_id": "user-1", + "session_id": "session-1", + "broker": "ZERODHA", + "flow": "connect", + "expires_at": now - timedelta(seconds=1), + "consumed_at": None, + }, + ] + + class FakeCursor: + def __init__(self): + self._result = None + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + def execute(self, sql, params): + if "UPDATE broker_callback_state" not in sql: + raise AssertionError("Unexpected SQL") + consumed_at, state_hash, user_id, session_id, broker, flow, now_param = params + self._result = None + for row in rows: + if ( + row["state_hash"] == state_hash + and row["user_id"] == user_id + and row["session_id"] == session_id + and row["broker"] == broker + and row["flow"] == flow + and row["consumed_at"] is None + and row["expires_at"] > now_param + ): + row["consumed_at"] = consumed_at + self._result = ("row-id", row["expires_at"]) + break + + def fetchone(self): + return self._result + + class FakeConnection: + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + def cursor(self): + return FakeCursor() + + monkeypatch.setattr(callback_state, "_now_utc", lambda: now) + monkeypatch.setattr(callback_state, "db_connection", lambda: FakeConnection()) + + assert callback_state.consume_broker_callback_state( + state="valid-state", + user_id="user-2", + session_id="session-1", + broker="ZERODHA", + flow="connect", + ) is None + assert callback_state.consume_broker_callback_state( + state="expired-state", + user_id="user-1", + session_id="session-1", + broker="ZERODHA", + flow="connect", + ) is None + assert callback_state.consume_broker_callback_state( + state="valid-state", + user_id="user-1", + session_id="session-1", + broker="ZERODHA", + flow="connect", + ) == { + "id": "row-id", + "expires_at": (now + timedelta(minutes=5)).isoformat(), + } diff --git a/backend/tests/test_support_throttling.py b/backend/tests/test_support_throttling.py new file mode 100644 index 0000000..e2138cb --- /dev/null +++ b/backend/tests/test_support_throttling.py @@ -0,0 +1,131 @@ +import importlib +import logging + +from fastapi.testclient import TestClient + + +def _build_app(monkeypatch): + monkeypatch.setenv("APP_ENV", "test") + monkeypatch.setenv("DISABLE_STARTUP_TASKS", "1") + monkeypatch.setenv("DB_HOST", "localhost") + monkeypatch.setenv("DB_NAME", "trading_db") + monkeypatch.setenv("DB_USER", "trader") + monkeypatch.setenv("DB_PASSWORD", "test-password") + monkeypatch.setenv("CORS_ORIGINS", "http://localhost:3000") + monkeypatch.setenv("SUPPORT_GUARD_BACKEND", "memory") + monkeypatch.setenv("SUPPORT_GUARD_WINDOW_SECONDS", "900") + monkeypatch.setenv("SUPPORT_CREATE_LIMIT", "2") + monkeypatch.setenv("SUPPORT_STATUS_LIMIT", "3") + monkeypatch.setenv("SUPPORT_STATUS_TICKET_LIMIT", "2") + + import app.main as app_main + + importlib.reload(app_main) + return app_main.create_app() + + +def test_support_ticket_creation_is_throttled(monkeypatch, caplog): + app = _build_app(monkeypatch) + client = TestClient(app) + + import app.services.support_abuse as support_abuse + import app.routers.support_ticket as support_router + + support_abuse.reset_memory_support_guard_state() + monkeypatch.setattr( + support_router, + "create_ticket", + lambda **kwargs: {"ticket_id": "ticket-1", "status": "NEW", "created_at": "2026-04-08T00:00:00+00:00"}, + ) + + payload = { + "name": "Trader", + "email": "trader@example.com", + "subject": "Need help", + "message": "Something happened", + } + + with caplog.at_level(logging.WARNING): + first = client.post("/api/support/ticket", json=payload) + second = client.post("/api/support/ticket", json=payload) + third = client.post("/api/support/ticket", json=payload) + + assert first.status_code == 200 + assert second.status_code == 200 + assert third.status_code == 429 + assert "Support request blocked" in caplog.text + + +def test_invalid_ticket_probing_is_throttled(monkeypatch): + app = _build_app(monkeypatch) + client = TestClient(app) + + import app.services.support_abuse as support_abuse + import app.routers.support_ticket as support_router + + support_abuse.reset_memory_support_guard_state() + monkeypatch.setattr(support_router, "get_ticket_status", lambda ticket_id, email: None) + + payload = {"email": "trader@example.com"} + + first = client.post("/api/support/ticket/status/unknown-ticket", json=payload) + second = client.post("/api/support/ticket/status/unknown-ticket", json=payload) + third = client.post("/api/support/ticket/status/unknown-ticket", json=payload) + + assert first.status_code == 404 + assert second.status_code == 404 + assert third.status_code == 429 + + +def test_legitimate_status_lookup_still_works(monkeypatch): + app = _build_app(monkeypatch) + client = TestClient(app) + + import app.services.support_abuse as support_abuse + import app.routers.support_ticket as support_router + + support_abuse.reset_memory_support_guard_state() + monkeypatch.setattr( + support_router, + "get_ticket_status", + lambda ticket_id, email: { + "ticket_id": ticket_id, + "status": "NEW", + "created_at": "2026-04-08T00:00:00+00:00", + "updated_at": "2026-04-08T00:00:00+00:00", + }, + ) + + response = client.post("/api/support/ticket/status/ticket-1", json={"email": "trader@example.com"}) + + assert response.status_code == 200 + assert response.json()["ticket_id"] == "ticket-1" + + +def test_support_captcha_hook_blocks_without_matching_header(monkeypatch): + monkeypatch.setenv("SUPPORT_CAPTCHA_SECRET", "expected-captcha") + app = _build_app(monkeypatch) + client = TestClient(app) + + import app.services.support_abuse as support_abuse + import app.routers.support_ticket as support_router + + support_abuse.reset_memory_support_guard_state() + monkeypatch.setattr( + support_router, + "create_ticket", + lambda **kwargs: {"ticket_id": "ticket-1", "status": "NEW", "created_at": "2026-04-08T00:00:00+00:00"}, + ) + + response = client.post( + "/api/support/ticket", + json={ + "name": "Trader", + "email": "trader@example.com", + "subject": "Need help", + "message": "Something happened", + }, + ) + + assert response.status_code == 403 + assert response.json() == {"detail": "Support verification failed"} diff --git a/indian_paper_trading_strategy/engine/db.py b/indian_paper_trading_strategy/engine/db.py index fc8b1b7..4b913d8 100644 --- a/indian_paper_trading_strategy/engine/db.py +++ b/indian_paper_trading_strategy/engine/db.py @@ -1,37 +1,48 @@ -import os -import threading -import time -from contextlib import contextmanager -from datetime import datetime, timezone -from contextvars import ContextVar +import os +import threading +import time +from contextlib import contextmanager +from datetime import datetime, timedelta, timezone +from contextvars import ContextVar import psycopg2 from psycopg2 import pool from psycopg2 import OperationalError, InterfaceError from psycopg2.extras import Json -_POOL = None -_POOL_LOCK = threading.Lock() -_DEFAULT_USER_ID = None -_DEFAULT_LOCK = threading.Lock() +_POOL = None +_POOL_LOCK = threading.Lock() +_DEFAULT_USER_ID = None +_DEFAULT_LOCK = threading.Lock() +NON_PROD_ENVIRONMENTS = {"development", "dev", "test", "testing", "local"} _USER_ID = ContextVar("engine_user_id", default=None) _RUN_ID = ContextVar("engine_run_id", default=None) def _db_config(): + env_name = (os.getenv("APP_ENV") or os.getenv("ENVIRONMENT") or os.getenv("FASTAPI_ENV") or "development").strip().lower() + is_non_prod = env_name in NON_PROD_ENVIRONMENTS url = os.getenv("DATABASE_URL") if url: return {"dsn": url} schema = os.getenv("DB_SCHEMA") or os.getenv("PGSCHEMA") or "quant_app" + password = os.getenv("DB_PASSWORD") or os.getenv("PGPASSWORD") + host = os.getenv("DB_HOST") or os.getenv("PGHOST") or ("localhost" if is_non_prod else None) + dbname = os.getenv("DB_NAME") or os.getenv("PGDATABASE") or ("trading_db" if is_non_prod else None) + user = os.getenv("DB_USER") or os.getenv("PGUSER") or ("trader" if is_non_prod else None) + if not is_non_prod and not password: + raise RuntimeError("DB_PASSWORD or PGPASSWORD must be configured in non-development environments") + if not is_non_prod and (not host or not dbname or not user): + raise RuntimeError("DB_HOST, DB_NAME, and DB_USER must be configured in non-development environments") return { - "host": os.getenv("DB_HOST") or os.getenv("PGHOST") or "localhost", + "host": host, "port": int(os.getenv("DB_PORT") or os.getenv("PGPORT") or "5432"), - "dbname": os.getenv("DB_NAME") or os.getenv("PGDATABASE") or "trading_db", - "user": os.getenv("DB_USER") or os.getenv("PGUSER") or "trader", - "password": os.getenv("DB_PASSWORD") or os.getenv("PGPASSWORD") or "traderpass", + "dbname": dbname, + "user": user, + "password": password, "connect_timeout": int(os.getenv("DB_CONNECT_TIMEOUT", "5")), "options": f"-csearch_path={schema},public" if schema else None, } @@ -295,7 +306,7 @@ def get_running_runs(user_id: str | None = None): return run_with_retry(_op) -def insert_engine_event( +def insert_engine_event( cur, event: str, data=None, @@ -307,10 +318,10 @@ def insert_engine_event( ): when = ts or _utc_now() scope_user, scope_run = _resolve_context(user_id, run_id) - cur.execute( - """ - INSERT INTO engine_event (user_id, run_id, ts, event, data, message, meta) - VALUES (%s, %s, %s, %s, %s, %s, %s) + cur.execute( + """ + INSERT INTO engine_event (user_id, run_id, ts, event, data, message, meta) + VALUES (%s, %s, %s, %s, %s, %s, %s) """, ( scope_user, @@ -320,5 +331,152 @@ def insert_engine_event( Json(data) if data is not None else None, message, Json(meta) if meta is not None else None, - ), - ) + ), + ) + + +def acquire_run_lease( + run_id: str, + owner_id: str, + *, + lease_seconds: int = 90, + now: datetime | None = None, +): + current_time = now or _utc_now() + expires_at = current_time + timedelta(seconds=lease_seconds) + + def _op(cur, _conn): + cur.execute( + """ + INSERT INTO run_leases (run_id, owner_id, leased_at, expires_at, heartbeat_at) + VALUES (%s, %s, %s, %s, %s) + ON CONFLICT (run_id) DO NOTHING + RETURNING run_id + """, + (run_id, owner_id, current_time, expires_at, current_time), + ) + inserted = cur.fetchone() + if inserted: + return { + "acquired": True, + "status": "ACQUIRED", + "owner_id": owner_id, + "expires_at": expires_at, + } + + cur.execute( + """ + SELECT owner_id, expires_at + FROM run_leases + WHERE run_id = %s + FOR UPDATE + """, + (run_id,), + ) + row = cur.fetchone() + if not row: + return { + "acquired": False, + "status": "DENIED", + "owner_id": None, + "expires_at": None, + } + + current_owner, current_expiry = row + if current_owner == owner_id: + cur.execute( + """ + UPDATE run_leases + SET leased_at = %s, + expires_at = %s, + heartbeat_at = %s + WHERE run_id = %s AND owner_id = %s + RETURNING run_id + """, + (current_time, expires_at, current_time, run_id, owner_id), + ) + cur.fetchone() + return { + "acquired": True, + "status": "REFRESHED", + "owner_id": owner_id, + "expires_at": expires_at, + } + + if current_expiry <= current_time: + cur.execute( + """ + UPDATE run_leases + SET owner_id = %s, + leased_at = %s, + expires_at = %s, + heartbeat_at = %s + WHERE run_id = %s AND expires_at <= %s + RETURNING run_id + """, + (owner_id, current_time, expires_at, current_time, run_id, current_time), + ) + replaced = cur.fetchone() + if replaced: + return { + "acquired": True, + "status": "REACQUIRED", + "owner_id": owner_id, + "previous_owner": current_owner, + "expires_at": expires_at, + } + + return { + "acquired": False, + "status": "DENIED", + "owner_id": current_owner, + "expires_at": current_expiry, + } + + return run_with_retry(_op) + + +def heartbeat_run_lease( + run_id: str, + owner_id: str, + *, + lease_seconds: int = 90, + now: datetime | None = None, +): + current_time = now or _utc_now() + expires_at = current_time + timedelta(seconds=lease_seconds) + + def _op(cur, _conn): + cur.execute( + """ + UPDATE run_leases + SET heartbeat_at = %s, + expires_at = %s + WHERE run_id = %s + AND owner_id = %s + AND expires_at > %s + RETURNING run_id, expires_at + """, + (current_time, expires_at, run_id, owner_id, current_time), + ) + row = cur.fetchone() + if not row: + return {"active": False, "expires_at": None} + return {"active": True, "expires_at": row[1]} + + return run_with_retry(_op) + + +def release_run_lease(run_id: str, owner_id: str): + def _op(cur, _conn): + cur.execute( + """ + DELETE FROM run_leases + WHERE run_id = %s AND owner_id = %s + RETURNING run_id + """, + (run_id, owner_id), + ) + return cur.fetchone() is not None + + return run_with_retry(_op) diff --git a/indian_paper_trading_strategy/engine/execution.py b/indian_paper_trading_strategy/engine/execution.py index f179cd5..5068a5a 100644 --- a/indian_paper_trading_strategy/engine/execution.py +++ b/indian_paper_trading_strategy/engine/execution.py @@ -3,7 +3,7 @@ from datetime import datetime, timezone from indian_paper_trading_strategy.engine.state import load_state, save_state from indian_paper_trading_strategy.engine.broker import Broker, BrokerAuthExpired -from indian_paper_trading_strategy.engine.ledger import log_event, event_exists +from indian_paper_trading_strategy.engine.ledger import claim_execution_window, log_event, event_exists from indian_paper_trading_strategy.engine.db import insert_engine_event, run_with_retry from indian_paper_trading_strategy.engine.market import market_now from indian_paper_trading_strategy.engine.time_utils import compute_logical_time @@ -237,7 +237,7 @@ def _prepare_live_execution(now_ts, sip_interval, sip_amount_val, sp_price_val, return {"ready": False, "state": state} if event_exists("SIP_EXECUTED", logical_time, cur=cur): return {"ready": False, "state": state} - if event_exists("SIP_ORDER_ATTEMPTED", logical_time, cur=cur): + if not claim_execution_window(logical_time, mode=mode, cur=cur): return {"ready": False, "state": state} log_event( diff --git a/indian_paper_trading_strategy/engine/ledger.py b/indian_paper_trading_strategy/engine/ledger.py index ebd2caf..240b618 100644 --- a/indian_paper_trading_strategy/engine/ledger.py +++ b/indian_paper_trading_strategy/engine/ledger.py @@ -1,8 +1,9 @@ # engine/ledger.py -from datetime import datetime, timezone - -from indian_paper_trading_strategy.engine.db import insert_engine_event, run_with_retry, get_context -from indian_paper_trading_strategy.engine.time_utils import normalize_logical_time +import uuid +from datetime import datetime, timezone + +from indian_paper_trading_strategy.engine.db import insert_engine_event, run_with_retry, get_context +from indian_paper_trading_strategy.engine.time_utils import normalize_logical_time def _event_exists_in_tx(cur, event, logical_time, user_id: str | None = None, run_id: str | None = None): @@ -20,14 +21,80 @@ def _event_exists_in_tx(cur, event, logical_time, user_id: str | None = None, ru return cur.fetchone() is not None -def event_exists(event, logical_time, *, cur=None, user_id: str | None = None, run_id: str | None = None): +def event_exists(event, logical_time, *, cur=None, user_id: str | None = None, run_id: str | None = None): if cur is not None: return _event_exists_in_tx(cur, event, logical_time, user_id=user_id, run_id=run_id) def _op(cur, _conn): return _event_exists_in_tx(cur, event, logical_time, user_id=user_id, run_id=run_id) - return run_with_retry(_op) + return run_with_retry(_op) + + +def _claim_execution_window_in_tx( + cur, + logical_time, + *, + mode: str | None = None, + user_id: str | None = None, + run_id: str | None = None, +): + scope_user, scope_run = get_context(user_id, run_id) + logical_ts = normalize_logical_time(logical_time) + claim_id = str(uuid.uuid4()) + cur.execute( + """ + INSERT INTO execution_claim ( + id, + user_id, + run_id, + mode, + logical_time, + claimed_at + ) + VALUES (%s, %s, %s, %s, %s, %s) + ON CONFLICT (user_id, run_id, logical_time) DO NOTHING + RETURNING id + """, + ( + claim_id, + scope_user, + scope_run, + (mode or "LIVE").strip().upper(), + logical_ts, + datetime.now(timezone.utc), + ), + ) + return cur.fetchone() is not None + + +def claim_execution_window( + logical_time, + *, + mode: str | None = None, + cur=None, + user_id: str | None = None, + run_id: str | None = None, +): + if cur is not None: + return _claim_execution_window_in_tx( + cur, + logical_time, + mode=mode, + user_id=user_id, + run_id=run_id, + ) + + def _op(cur, _conn): + return _claim_execution_window_in_tx( + cur, + logical_time, + mode=mode, + user_id=user_id, + run_id=run_id, + ) + + return run_with_retry(_op) def _log_event_in_tx( diff --git a/indian_paper_trading_strategy/engine/market.py b/indian_paper_trading_strategy/engine/market.py index a999420..4108e0a 100644 --- a/indian_paper_trading_strategy/engine/market.py +++ b/indian_paper_trading_strategy/engine/market.py @@ -1,46 +1,51 @@ -# engine/market.py -from datetime import datetime, time as dtime, timedelta -import pytz +from __future__ import annotations -_MARKET_TZ = pytz.timezone("Asia/Kolkata") -_OPEN_T = dtime(9, 15) -_CLOSE_T = dtime(15, 30) +from datetime import datetime + +from indian_paper_trading_strategy.engine.market_calendar import ( + MARKET_TZ, + get_market_session as _get_market_session, + get_market_status as _get_market_status, + is_market_open as _is_market_open, + market_now_utc, + next_market_open as _next_market_open, +) def market_now() -> datetime: - return datetime.now(_MARKET_TZ) + return market_now_utc().astimezone(MARKET_TZ) -def _as_market_tz(value: datetime) -> datetime: + +def _to_utc(value: datetime) -> datetime: if value.tzinfo is None: - return _MARKET_TZ.localize(value) - return value.astimezone(_MARKET_TZ) - -def is_market_open(now: datetime) -> bool: - now = _as_market_tz(now) - return now.weekday() < 5 and _OPEN_T <= now.time() <= _CLOSE_T - + return value.replace(tzinfo=MARKET_TZ).astimezone(market_now_utc().tzinfo) + return value.astimezone(market_now_utc().tzinfo) + + +def is_market_open(now: datetime) -> bool: + return _is_market_open(_to_utc(now)) + + +def market_session(now: datetime) -> dict[str, object]: + return _get_market_session(_to_utc(now)) + + +def market_status(now: datetime) -> str: + return _get_market_status(_to_utc(now)) + + def india_market_status(): now = market_now() - return is_market_open(now), now - -def next_market_open_after(value: datetime) -> datetime: - current = _as_market_tz(value) - while current.weekday() >= 5: - current = current + timedelta(days=1) - current = current.replace(hour=_OPEN_T.hour, minute=_OPEN_T.minute, second=0, microsecond=0) - if current.time() < _OPEN_T: - return current.replace(hour=_OPEN_T.hour, minute=_OPEN_T.minute, second=0, microsecond=0) - if current.time() > _CLOSE_T: - current = current + timedelta(days=1) - while current.weekday() >= 5: - current = current + timedelta(days=1) - return current.replace(hour=_OPEN_T.hour, minute=_OPEN_T.minute, second=0, microsecond=0) - return current - -def align_to_market_open(value: datetime) -> datetime: - current = _as_market_tz(value) - aligned = current if is_market_open(current) else next_market_open_after(current) - if value.tzinfo is None: - return aligned.replace(tzinfo=None) - return aligned + + +def next_market_open_after(value: datetime) -> datetime: + aligned_utc = _next_market_open(_to_utc(value)) + aligned_ist = aligned_utc.astimezone(MARKET_TZ) + if value.tzinfo is None: + return aligned_ist.replace(tzinfo=None) + return aligned_ist + + +def align_to_market_open(value: datetime) -> datetime: + return next_market_open_after(value) diff --git a/indian_paper_trading_strategy/engine/market_calendar.py b/indian_paper_trading_strategy/engine/market_calendar.py new file mode 100644 index 0000000..0a5a767 --- /dev/null +++ b/indian_paper_trading_strategy/engine/market_calendar.py @@ -0,0 +1,201 @@ +from __future__ import annotations + +import logging +import os +from datetime import date, datetime, time as dtime, timedelta, timezone +from functools import lru_cache +from zoneinfo import ZoneInfo + + +logger = logging.getLogger(__name__) + +UTC = timezone.utc +MARKET_TZ = ZoneInfo("Asia/Kolkata") +MARKET_OPEN = dtime(9, 15) +MARKET_CLOSE = dtime(15, 30) + +STATUS_OPEN = "OPEN" +STATUS_CLOSED = "CLOSED" +STATUS_HOLIDAY = "HOLIDAY" + +REASON_OPEN = "OPEN" +REASON_HOLIDAY = "HOLIDAY" +REASON_WEEKEND = "WEEKEND" +REASON_PRE_OPEN = "PRE_OPEN" +REASON_POST_CLOSE = "POST_CLOSE" +REASON_CALENDAR_UNAVAILABLE = "CALENDAR_UNAVAILABLE" + +# Capital market trading holidays for NSE calendar year 2026. +# Source: NSE circular dated December 12, 2025. +DEFAULT_NSE_HOLIDAYS_BY_YEAR = { + 2026: frozenset( + { + date(2026, 1, 26), # Republic Day + date(2026, 3, 3), # Holi + date(2026, 3, 26), # Shri Ram Navami + date(2026, 3, 31), # Shri Mahavir Jayanti + date(2026, 4, 3), # Good Friday + date(2026, 4, 14), # Dr. Baba Saheb Ambedkar Jayanti + date(2026, 5, 1), # Maharashtra Day + date(2026, 5, 28), # Bakri Id + date(2026, 6, 26), # Muharram + date(2026, 9, 14), # Ganesh Chaturthi + date(2026, 10, 2), # Mahatma Gandhi Jayanti + date(2026, 10, 20), # Dussehra + date(2026, 11, 10), # Diwali-Balipratipada + date(2026, 11, 24), # Prakash Gurpurb Sri Guru Nanak Dev + date(2026, 12, 25), # Christmas + } + ), +} + + +class UnsupportedCalendarYearError(RuntimeError): + def __init__(self, year: int): + super().__init__(f"NSE holiday calendar is not configured for {year}") + self.year = year + + +def market_now_utc() -> datetime: + return datetime.now(UTC) + + +def _as_utc(value: datetime | None) -> datetime: + if value is None: + return market_now_utc() + if value.tzinfo is None: + return value.replace(tzinfo=UTC) + return value.astimezone(UTC) + + +def _as_ist(value: datetime | None) -> datetime: + return _as_utc(value).astimezone(MARKET_TZ) + + +def _parse_holiday_token(token: str) -> date: + text = token.strip() + if not text: + raise ValueError("Empty holiday token is not allowed") + return date.fromisoformat(text) + + +@lru_cache(maxsize=16) +def get_nse_holidays(year: int) -> frozenset[date]: + configured = (os.getenv(f"NSE_HOLIDAYS_{year}") or "").strip() + if configured: + return frozenset(_parse_holiday_token(token) for token in configured.split(",")) + + if year in DEFAULT_NSE_HOLIDAYS_BY_YEAR: + return DEFAULT_NSE_HOLIDAYS_BY_YEAR[year] + + raise UnsupportedCalendarYearError(year) + + +def is_trading_holiday(now_utc: datetime | None = None) -> bool: + current_ist = _as_ist(now_utc) + holidays = get_nse_holidays(current_ist.year) + return current_ist.date() in holidays + + +def _session_from_utc(now_utc: datetime | None = None) -> dict[str, object]: + current_utc = _as_utc(now_utc) + current_ist = current_utc.astimezone(MARKET_TZ) + try: + holidays = get_nse_holidays(current_ist.year) + except UnsupportedCalendarYearError as exc: + logger.error("NSE holiday calendar unavailable for year %s", exc.year) + return { + "status": STATUS_CLOSED, + "reason": REASON_CALENDAR_UNAVAILABLE, + "checked_at": current_utc, + "checked_at_ist": current_ist, + "calendar_supported": False, + } + + current_time = current_ist.timetz().replace(tzinfo=None) + if current_ist.date() in holidays: + return { + "status": STATUS_HOLIDAY, + "reason": REASON_HOLIDAY, + "checked_at": current_utc, + "checked_at_ist": current_ist, + "calendar_supported": True, + } + if current_ist.weekday() >= 5: + return { + "status": STATUS_CLOSED, + "reason": REASON_WEEKEND, + "checked_at": current_utc, + "checked_at_ist": current_ist, + "calendar_supported": True, + } + if current_time < MARKET_OPEN: + return { + "status": STATUS_CLOSED, + "reason": REASON_PRE_OPEN, + "checked_at": current_utc, + "checked_at_ist": current_ist, + "calendar_supported": True, + } + if MARKET_OPEN <= current_time < MARKET_CLOSE: + return { + "status": STATUS_OPEN, + "reason": REASON_OPEN, + "checked_at": current_utc, + "checked_at_ist": current_ist, + "calendar_supported": True, + } + return { + "status": STATUS_CLOSED, + "reason": REASON_POST_CLOSE, + "checked_at": current_utc, + "checked_at_ist": current_ist, + "calendar_supported": True, + } + + +def get_market_session(now_utc: datetime | None = None) -> dict[str, object]: + return dict(_session_from_utc(now_utc)) + + +def get_market_status(now_utc: datetime | None = None) -> str: + return str(_session_from_utc(now_utc)["status"]) + + +def is_market_open(now_utc: datetime | None = None) -> bool: + return get_market_status(now_utc) == STATUS_OPEN + + +def next_market_open(now_utc: datetime | None = None) -> datetime: + current_utc = _as_utc(now_utc) + session = _session_from_utc(current_utc) + if session["reason"] == REASON_CALENDAR_UNAVAILABLE: + checked_at_ist = session["checked_at_ist"] + raise UnsupportedCalendarYearError(int(checked_at_ist.year)) + + if session["status"] == STATUS_OPEN: + return current_utc + + current_ist = current_utc.astimezone(MARKET_TZ) + candidate_date = current_ist.date() + holidays = get_nse_holidays(candidate_date.year) + current_time = current_ist.timetz().replace(tzinfo=None) + + if ( + session["reason"] == REASON_PRE_OPEN + and candidate_date not in holidays + and candidate_date.weekday() < 5 + and current_time < MARKET_OPEN + ): + candidate = datetime.combine(candidate_date, MARKET_OPEN, tzinfo=MARKET_TZ) + return candidate.astimezone(UTC) + + candidate_date = candidate_date + timedelta(days=1) + while True: + holidays = get_nse_holidays(candidate_date.year) + if candidate_date.weekday() < 5 and candidate_date not in holidays: + break + candidate_date = candidate_date + timedelta(days=1) + + candidate = datetime.combine(candidate_date, MARKET_OPEN, tzinfo=MARKET_TZ) + return candidate.astimezone(UTC) diff --git a/indian_paper_trading_strategy/engine/runner.py b/indian_paper_trading_strategy/engine/runner.py index 74e1d3a..e2ec579 100644 --- a/indian_paper_trading_strategy/engine/runner.py +++ b/indian_paper_trading_strategy/engine/runner.py @@ -1,11 +1,17 @@ import os +import socket import threading import time +import uuid from datetime import datetime, timedelta, timezone from psycopg2.extras import Json -from indian_paper_trading_strategy.engine.market import is_market_open, align_to_market_open, market_now +from indian_paper_trading_strategy.engine.market import ( + align_to_market_open, + market_now, + market_session, +) from indian_paper_trading_strategy.engine.execution import try_execute_sip from indian_paper_trading_strategy.engine.broker import ( BrokerAuthExpired, @@ -21,7 +27,16 @@ from indian_paper_trading_strategy.engine.strategy import allocation from indian_paper_trading_strategy.engine.time_utils import normalize_logical_time from app.services.zerodha_service import KiteTokenError -from indian_paper_trading_strategy.engine.db import db_transaction, insert_engine_event, run_with_retry, get_context, set_context +from indian_paper_trading_strategy.engine.db import ( + acquire_run_lease, + db_transaction, + heartbeat_run_lease, + insert_engine_event, + release_run_lease, + run_with_retry, + get_context, + set_context, +) def _update_engine_status(user_id: str, run_id: str, status: str): @@ -58,7 +73,18 @@ _ENGINE_STATES_LOCK = threading.Lock() _RUNNERS = {} _RUNNERS_LOCK = threading.Lock() -engine_state = _ENGINE_STATES +engine_state = _ENGINE_STATES + +RUNNER_OWNER_ID = os.getenv("RUNNER_OWNER_ID") or f"{socket.gethostname()}:{os.getpid()}:{uuid.uuid4().hex}" +RUN_LEASE_SECONDS = int(os.getenv("RUN_LEASE_SECONDS", "90")) + + +class RunLeaseNotAcquiredError(RuntimeError): + def __init__(self, run_id: str, owner_id: str, details: dict | None = None): + super().__init__(f"Run lease not acquired for run {run_id}") + self.run_id = run_id + self.owner_id = owner_id + self.details = details or {} def _state_key(user_id: str, run_id: str): @@ -93,7 +119,7 @@ def get_engine_state(user_id: str, run_id: str): state = _get_state(user_id, run_id) return dict(state) -def log_event( +def log_event( event: str, data: dict | None = None, message: str | None = None, @@ -121,22 +147,82 @@ def log_event( meta=meta, ts=event_ts, ) - - run_with_retry(_op) - + + run_with_retry(_op) + + +def _log_runner_lease_event( + user_id: str, + run_id: str, + event: str, + message: str, + meta: dict | None = None, +): + details = meta or {} + print(f"[ENGINE] {event} {message} {details}", flush=True) + + def _op(cur, _conn): + insert_engine_event( + cur, + event, + data=details, + message=message, + ts=datetime.utcnow().replace(tzinfo=timezone.utc), + user_id=user_id, + run_id=run_id, + ) + + try: + run_with_retry(_op) + except Exception: + pass + + +def _refresh_run_lease_or_stop( + user_id: str, + run_id: str, + owner_id: str, +): + lease = heartbeat_run_lease( + run_id, + owner_id, + lease_seconds=RUN_LEASE_SECONDS, + ) + if lease.get("active"): + print( + f"[ENGINE] RUNNER_LEASE_HEARTBEAT lease heartbeat refreshed " + f"{{'run_id': '{run_id}', 'owner_id': '{owner_id}', 'expires_at': '{lease.get('expires_at')}'}}", + flush=True, + ) + return True + + _log_runner_lease_event( + user_id, + run_id, + "RUNNER_LEASE_LOST", + "Runner exiting due to lost lease", + {"owner_id": owner_id}, + ) + return False + def sleep_with_heartbeat( total_seconds: int, stop_event: threading.Event, user_id: str, run_id: str, + owner_id: str, step_seconds: int = 5, ): remaining = total_seconds while remaining > 0 and not stop_event.is_set(): - time.sleep(min(step_seconds, remaining)) + chunk = min(step_seconds, remaining) + time.sleep(chunk) _set_state(user_id, run_id, last_heartbeat_ts=datetime.utcnow().isoformat() + "Z") _update_engine_status(user_id, run_id, "RUNNING") - remaining -= step_seconds + if not _refresh_run_lease_or_stop(user_id, run_id, owner_id): + return False + remaining -= chunk + return True def _clear_runner(user_id: str, run_id: str): key = _state_key(user_id, run_id) @@ -144,7 +230,20 @@ def _clear_runner(user_id: str, run_id: str): _RUNNERS.pop(key, None) def can_execute(now: datetime) -> tuple[bool, str]: - if not is_market_open(now): + session = market_session(now) + status = str(session.get("status") or "CLOSED").upper() + reason = str(session.get("reason") or "").upper() + if status == "HOLIDAY": + return False, "MARKET_HOLIDAY" + if status != "OPEN": + if reason == "WEEKEND": + return False, "MARKET_WEEKEND" + if reason == "PRE_OPEN": + return False, "MARKET_PRE_OPEN" + if reason == "POST_CLOSE": + return False, "MARKET_POST_CLOSE" + if reason == "CALENDAR_UNAVAILABLE": + return False, "MARKET_CALENDAR_UNAVAILABLE" return False, "MARKET_CLOSED" return True, "OK" @@ -225,12 +324,13 @@ def _pause_for_auth_expiry( def _engine_loop(config, stop_event: threading.Event): - print("Strategy engine started with config:", config) - - user_id = config.get("user_id") - run_id = config.get("run_id") - scope_user, scope_run = get_context(user_id, run_id) - set_context(scope_user, scope_run) + print("Strategy engine started with config:", config) + + user_id = config.get("user_id") + run_id = config.get("run_id") + owner_id = config.get("runner_owner_id") or RUNNER_OWNER_ID + scope_user, scope_run = get_context(user_id, run_id) + set_context(scope_user, scope_run) strategy_name = config.get("strategy_name") or config.get("strategy") or "golden_nifty" sip_amount = config["sip_amount"] @@ -303,10 +403,14 @@ def _engine_loop(config, stop_event: threading.Event): state="RUNNING", last_heartbeat_ts=datetime.utcnow().isoformat() + "Z", ) - _update_engine_status(scope_user, scope_run, "RUNNING") - - try: - while not stop_event.is_set(): + _update_engine_status(scope_user, scope_run, "RUNNING") + exit_reason = "STOPPED" + + try: + while not stop_event.is_set(): + if not _refresh_run_lease_or_stop(scope_user, scope_run, owner_id): + exit_reason = "LEASE_LOST" + break _set_state(scope_user, scope_run, last_heartbeat_ts=datetime.utcnow().isoformat() + "Z") _update_engine_status(scope_user, scope_run, "RUNNING") @@ -357,10 +461,10 @@ def _engine_loop(config, stop_event: threading.Event): "frequency": frequency_label, }, ) - if emit_event_cb: - emit_event_cb( - event="SIP_WAITING", - message="Waiting for next SIP window", + if emit_event_cb: + emit_event_cb( + event="SIP_WAITING", + message="Waiting for next SIP window", meta={ "last_run": last_run, "next_eligible": next_run.isoformat(), @@ -368,7 +472,9 @@ def _engine_loop(config, stop_event: threading.Event): "frequency": frequency_label, }, ) - sleep_with_heartbeat(wait_seconds, stop_event, scope_user, scope_run) + if not sleep_with_heartbeat(wait_seconds, stop_event, scope_user, scope_run, owner_id): + exit_reason = "LEASE_LOST" + break continue try: @@ -395,7 +501,9 @@ def _engine_loop(config, stop_event: threading.Event): break except Exception as exc: debug_event("PRICE_FETCH_ERROR", "live price fetch failed", {"error": str(exc)}) - sleep_with_heartbeat(30, stop_event, scope_user, scope_run) + if not sleep_with_heartbeat(30, stop_event, scope_user, scope_run, owner_id): + exit_reason = "LEASE_LOST" + break continue try: @@ -416,7 +524,9 @@ def _engine_loop(config, stop_event: threading.Event): break except Exception as exc: debug_event("HISTORY_LOAD_ERROR", "history load failed", {"error": str(exc)}) - sleep_with_heartbeat(30, stop_event, scope_user, scope_run) + if not sleep_with_heartbeat(30, stop_event, scope_user, scope_run, owner_id): + exit_reason = "LEASE_LOST" + break continue nifty_sma = nifty_hist.rolling(SMA_MONTHS).mean().iloc[-1] @@ -565,26 +675,49 @@ def _engine_loop(config, stop_event: threading.Event): logical_time=logical_time, ) - sleep_with_heartbeat(30, stop_event, scope_user, scope_run) + if not sleep_with_heartbeat(30, stop_event, scope_user, scope_run, owner_id): + exit_reason = "LEASE_LOST" + break except BrokerAuthExpired as exc: + exit_reason = "AUTH_EXPIRED" _pause_for_auth_expiry(scope_user, scope_run, str(exc), emit_event_cb=emit_event_cb) print(f"[ENGINE] broker auth expired for run {scope_run}: {exc}", flush=True) except Exception as e: + exit_reason = "ERROR" _set_state(scope_user, scope_run, state="ERROR", last_heartbeat_ts=datetime.utcnow().isoformat() + "Z") _update_engine_status(scope_user, scope_run, "ERROR") log_event("ENGINE_ERROR", {"error": str(e)}) raise + finally: + try: + released = release_run_lease(scope_run, owner_id) + if released: + print( + f"[ENGINE] RUNNER_LEASE_RELEASED released run lease " + f"{{'run_id': '{scope_run}', 'owner_id': '{owner_id}'}}", + flush=True, + ) + except Exception: + pass - log_event("ENGINE_STOP") - _set_state( - scope_user, - scope_run, - state="STOPPED", - last_heartbeat_ts=datetime.utcnow().isoformat() + "Z", - ) - _update_engine_status(scope_user, scope_run, "STOPPED") - print("Strategy engine stopped") - _clear_runner(scope_user, scope_run) + if exit_reason not in {"ERROR", "LEASE_LOST", "AUTH_EXPIRED"}: + log_event("ENGINE_STOP") + _set_state( + scope_user, + scope_run, + state="STOPPED", + last_heartbeat_ts=datetime.utcnow().isoformat() + "Z", + ) + _update_engine_status(scope_user, scope_run, "STOPPED") + print("Strategy engine stopped") + elif exit_reason == "LEASE_LOST": + _set_state( + scope_user, + scope_run, + state="STOPPED", + last_heartbeat_ts=datetime.utcnow().isoformat() + "Z", + ) + _clear_runner(scope_user, scope_run) def start_engine(config): user_id = config.get("user_id") @@ -600,14 +733,53 @@ def start_engine(config): if runner and runner["thread"].is_alive(): return False + lease = acquire_run_lease( + run_id, + RUNNER_OWNER_ID, + lease_seconds=RUN_LEASE_SECONDS, + ) + if not lease.get("acquired"): + _log_runner_lease_event( + user_id, + run_id, + "RUNNER_LEASE_DENIED", + "Run lease denied", + { + "owner_id": RUNNER_OWNER_ID, + "current_owner": lease.get("owner_id"), + "expires_at": lease.get("expires_at").isoformat() if lease.get("expires_at") else None, + }, + ) + raise RunLeaseNotAcquiredError(run_id, RUNNER_OWNER_ID, lease) + + lease_status = str(lease.get("status") or "ACQUIRED").upper() + event_name = "RUNNER_LEASE_REACQUIRED" if lease_status == "REACQUIRED" else "RUNNER_LEASE_ACQUIRED" + _log_runner_lease_event( + user_id, + run_id, + event_name, + "Run lease acquired" if lease_status != "REACQUIRED" else "Expired run lease reacquired", + { + "owner_id": RUNNER_OWNER_ID, + "expires_at": lease.get("expires_at").isoformat() if lease.get("expires_at") else None, + }, + ) + stop_event = threading.Event() - thread = threading.Thread( - target=_engine_loop, - args=(config, stop_event), + thread_config = dict(config) + thread_config["runner_owner_id"] = RUNNER_OWNER_ID + thread = threading.Thread( + target=_engine_loop, + args=(thread_config, stop_event), daemon=True, ) _RUNNERS[key] = {"thread": thread, "stop_event": stop_event} - thread.start() + try: + thread.start() + except Exception: + _RUNNERS.pop(key, None) + release_run_lease(run_id, RUNNER_OWNER_ID) + raise return True def stop_engine(user_id: str, run_id: str | None = None, timeout: float | None = 10.0): diff --git a/indian_paper_trading_strategy/engine/state.py b/indian_paper_trading_strategy/engine/state.py index d14998c..b1ad5c7 100644 --- a/indian_paper_trading_strategy/engine/state.py +++ b/indian_paper_trading_strategy/engine/state.py @@ -1,8 +1,8 @@ -# engine/state.py +# engine/state.py from datetime import datetime, timezone from indian_paper_trading_strategy.engine.db import db_connection, insert_engine_event, run_with_retry, get_context -from indian_paper_trading_strategy.engine.market import market_now +from indian_paper_trading_strategy.engine.time_utils import parse_persisted_timestamp, serialize_timestamp DEFAULT_STATE = { "initial_cash": 0.0, @@ -31,33 +31,8 @@ def _default_state(mode: str | None): return DEFAULT_PAPER_STATE.copy() return DEFAULT_STATE.copy() -def _local_tz(): - return market_now().tzinfo - -def _format_local_ts(value: datetime | None): - if value is None: - return None - return value.astimezone(_local_tz()).replace(tzinfo=None).isoformat() - -def _parse_ts(value): - if value is None: - return None - if isinstance(value, datetime): - if value.tzinfo is None: - return value.replace(tzinfo=_local_tz()) - return value - if isinstance(value, str): - text = value.strip() - if not text: - return None - try: - parsed = datetime.fromisoformat(text.replace("Z", "+00:00")) - except ValueError: - return None - if parsed.tzinfo is None: - parsed = parsed.replace(tzinfo=_local_tz()) - return parsed - return None +def _parse_ts(value): + return parse_persisted_timestamp(value) def _resolve_scope(user_id: str | None, run_id: str | None): return get_context(user_id, run_id) @@ -101,15 +76,15 @@ def load_state( merged = _default_state(mode) merged.update( { - "initial_cash": float(row[0]) if row[0] is not None else merged["initial_cash"], - "cash": float(row[1]) if row[1] is not None else merged["cash"], - "total_invested": float(row[2]) if row[2] is not None else merged["total_invested"], - "nifty_units": float(row[3]) if row[3] is not None else merged["nifty_units"], - "gold_units": float(row[4]) if row[4] is not None else merged["gold_units"], - "last_sip_ts": _format_local_ts(row[5]), - "last_run": _format_local_ts(row[6]), - } - ) + "initial_cash": float(row[0]) if row[0] is not None else merged["initial_cash"], + "cash": float(row[1]) if row[1] is not None else merged["cash"], + "total_invested": float(row[2]) if row[2] is not None else merged["total_invested"], + "nifty_units": float(row[3]) if row[3] is not None else merged["nifty_units"], + "gold_units": float(row[4]) if row[4] is not None else merged["gold_units"], + "last_sip_ts": serialize_timestamp(row[5]), + "last_run": serialize_timestamp(row[6]), + } + ) if row[7] is not None or row[8] is not None: merged["sip_frequency"] = {"value": row[7], "unit": row[8]} return merged @@ -140,13 +115,13 @@ def load_state( merged = _default_state(mode) merged.update( { - "total_invested": float(row[0]) if row[0] is not None else merged["total_invested"], - "nifty_units": float(row[1]) if row[1] is not None else merged["nifty_units"], - "gold_units": float(row[2]) if row[2] is not None else merged["gold_units"], - "last_sip_ts": _format_local_ts(row[3]), - "last_run": _format_local_ts(row[4]), - } - ) + "total_invested": float(row[0]) if row[0] is not None else merged["total_invested"], + "nifty_units": float(row[1]) if row[1] is not None else merged["nifty_units"], + "gold_units": float(row[2]) if row[2] is not None else merged["gold_units"], + "last_sip_ts": serialize_timestamp(row[3]), + "last_run": serialize_timestamp(row[4]), + } + ) return merged def init_paper_state( diff --git a/indian_paper_trading_strategy/engine/time_utils.py b/indian_paper_trading_strategy/engine/time_utils.py index 9e4943d..b4f6a9f 100644 --- a/indian_paper_trading_strategy/engine/time_utils.py +++ b/indian_paper_trading_strategy/engine/time_utils.py @@ -1,7 +1,12 @@ -from datetime import datetime, timedelta +from datetime import datetime, timedelta, timezone +from zoneinfo import ZoneInfo + + +UTC = timezone.utc +MARKET_TZ = ZoneInfo("Asia/Kolkata") -def frequency_to_timedelta(freq: dict) -> timedelta: +def frequency_to_timedelta(freq: dict) -> timedelta: value = int(freq.get("value", 0)) unit = freq.get("unit") @@ -15,27 +20,64 @@ def frequency_to_timedelta(freq: dict) -> timedelta: raise ValueError(f"Unsupported frequency unit: {unit}") -def normalize_logical_time(ts: datetime) -> datetime: - return ts.replace(microsecond=0) +def normalize_logical_time(ts: datetime) -> datetime: + return ts.replace(microsecond=0) + + +def ensure_aware(ts: datetime, *, default_tz=UTC) -> datetime: + if ts.tzinfo is None: + return ts.replace(tzinfo=default_tz) + return ts + + +def to_utc(ts: datetime, *, default_tz=UTC) -> datetime: + return ensure_aware(ts, default_tz=default_tz).astimezone(UTC) + + +def serialize_timestamp(ts: datetime | None, *, default_tz=UTC) -> str | None: + if ts is None: + return None + return to_utc(ts, default_tz=default_tz).isoformat() + + +def parse_persisted_timestamp(value: datetime | str | None, *, default_tz=UTC) -> datetime | None: + if value is None: + return None + if isinstance(value, datetime): + return to_utc(value, default_tz=default_tz) + if isinstance(value, str): + text = value.strip() + if not text: + return None + try: + parsed = datetime.fromisoformat(text.replace("Z", "+00:00")) + except ValueError: + return None + return to_utc(parsed, default_tz=default_tz) + return None + + +def parse_market_timestamp(value: datetime | str | None) -> datetime | None: + parsed = parse_persisted_timestamp(value) + if parsed is None: + return None + return parsed.astimezone(MARKET_TZ) -def compute_logical_time( - now: datetime, - last_run: str | None, - interval_seconds: float | None, -) -> datetime: - base = now - if last_run and interval_seconds: - try: - parsed = datetime.fromisoformat(last_run.replace("Z", "+00:00")) - except ValueError: - parsed = None - if parsed is not None: - if now.tzinfo and parsed.tzinfo is None: - parsed = parsed.replace(tzinfo=now.tzinfo) - elif now.tzinfo is None and parsed.tzinfo: - parsed = parsed.replace(tzinfo=None) - candidate = parsed + timedelta(seconds=interval_seconds) - if now >= candidate: - base = candidate - return normalize_logical_time(base) +def compute_logical_time( + now: datetime, + last_run: str | None, + interval_seconds: float | None, +) -> datetime: + base = now + if last_run and interval_seconds: + parsed = parse_persisted_timestamp(last_run, default_tz=now.tzinfo or UTC) + if parsed is not None: + if now.tzinfo is None: + parsed = parsed.replace(tzinfo=None) + else: + parsed = parsed.astimezone(now.tzinfo) + candidate = parsed + timedelta(seconds=interval_seconds) + if now >= candidate: + base = candidate + return normalize_logical_time(base)