Harden backend auth, execution safety, and market session logic

This commit is contained in:
Thigazhezhilan J 2026-04-08 22:02:24 +05:30
parent 99e48144aa
commit 519addd78f
33 changed files with 2753 additions and 360 deletions

View File

@ -47,6 +47,8 @@ class AppSession(Base):
created_at = Column(DateTime(timezone=True), nullable=False)
last_seen_at = Column(DateTime(timezone=True))
expires_at = Column(DateTime(timezone=True), nullable=False)
ip = Column(Text)
user_agent = Column(Text)
__table_args__ = (
Index("idx_app_session_user_id", "user_id"),
@ -63,6 +65,8 @@ class UserBroker(Base):
access_token = Column(Text)
connected_at = Column(DateTime(timezone=True))
api_key = Column(Text)
api_secret = Column(Text)
auth_state = Column(Text)
user_name = Column(Text)
broker_user_id = Column(Text)
pending_broker = Column(Text)
@ -117,7 +121,10 @@ class StrategyRun(Base):
__table_args__ = (
UniqueConstraint("user_id", "run_id", name="uq_strategy_run_user_run"),
CheckConstraint("status IN ('RUNNING','STOPPED','ERROR')", name="chk_strategy_run_status"),
CheckConstraint(
"status IN ('RUNNING','STOPPED','ERROR','PAUSED_AUTH_EXPIRED')",
name="chk_strategy_run_status",
),
Index("idx_strategy_run_user_status", "user_id", "status"),
Index("idx_strategy_run_user_created", "user_id", "created_at"),
Index(
@ -129,6 +136,22 @@ class StrategyRun(Base):
)
class PasswordResetOtp(Base):
    """One-time password issued during the password-reset flow.

    Only a hash of the OTP is stored (``otp_hash``); the plaintext code is
    delivered out-of-band — presumably by email, confirm with the caller.
    """

    __tablename__ = "password_reset_otp"

    # Surrogate primary key; no default here, so callers appear to supply it.
    id = Column(String, primary_key=True)
    # Address the reset was requested for; indexed for lookups below.
    email = Column(Text, nullable=False)
    # Hash of the OTP code — never the plaintext code itself.
    otp_hash = Column(Text, nullable=False)
    created_at = Column(DateTime(timezone=True), nullable=False, server_default=func.now())
    # Hard expiry; rows past this instant must be treated as invalid.
    expires_at = Column(DateTime(timezone=True), nullable=False)
    # Set when the OTP is redeemed; NULL while still usable.
    used_at = Column(DateTime(timezone=True))

    __table_args__ = (
        Index("idx_password_reset_otp_email", "email"),
        Index("idx_password_reset_otp_expires_at", "expires_at"),
    )
class StrategyConfig(Base):
__tablename__ = "strategy_config"
@ -420,6 +443,97 @@ class LiveEquitySnapshot(Base):
)
class SupportTicket(Base):
    """Support request submitted through the public support endpoints."""

    __tablename__ = "support_ticket"

    id = Column(String, primary_key=True)
    name = Column(Text, nullable=False)
    email = Column(Text, nullable=False)
    subject = Column(Text, nullable=False)
    message = Column(Text, nullable=False)
    # Workflow state; new rows start as 'NEW' via the DB-side default.
    status = Column(Text, nullable=False, server_default=text("'NEW'"))
    created_at = Column(DateTime(timezone=True), nullable=False, server_default=func.now())
    # NOTE(review): only a server_default is declared — no onupdate — so the
    # application presumably bumps updated_at explicitly; confirm.
    updated_at = Column(DateTime(timezone=True), nullable=False, server_default=func.now())

    __table_args__ = (
        Index("idx_support_ticket_email", "email"),
        Index("idx_support_ticket_created_at", "created_at"),
    )
class SupportRequestAudit(Base):
    """Audit trail of support-endpoint hits, used by the abuse guard.

    Stores only hashes of the IP/email/ticket identifiers, not raw values.
    """

    __tablename__ = "support_request_audit"

    id = Column(BigInteger, primary_key=True, autoincrement=True)
    # Logical endpoint name; the guard passes e.g. 'ticket_create'/'ticket_status'.
    endpoint = Column(Text, nullable=False)
    ip_hash = Column(Text)
    email_hash = Column(Text)
    ticket_hash = Column(Text)
    # True when the guard rejected the request; 'reason' records why.
    blocked = Column(Boolean, nullable=False, server_default=text("false"))
    reason = Column(Text)
    created_at = Column(DateTime(timezone=True), nullable=False, server_default=func.now())

    __table_args__ = (
        Index("idx_support_request_audit_endpoint_ip_created", "endpoint", "ip_hash", "created_at"),
        Index("idx_support_request_audit_ticket_created", "ticket_hash", "created_at"),
    )
class BrokerCallbackState(Base):
    """Single-use anti-CSRF state for broker OAuth-style callbacks.

    Only a hash of the state token is persisted (SHA-256 per the companion
    service module); the row is bound to the user, session, broker, and
    flow ('connect'/'reconnect') that minted it.
    """

    __tablename__ = "broker_callback_state"

    id = Column(String, primary_key=True)
    # Hash of the opaque state token handed to the broker redirect URL.
    state_hash = Column(Text, nullable=False, unique=True)
    user_id = Column(String, ForeignKey("app_user.id", ondelete="CASCADE"), nullable=False)
    session_id = Column(String, ForeignKey("app_session.id", ondelete="CASCADE"), nullable=False)
    broker = Column(Text, nullable=False)
    flow = Column(Text, nullable=False)
    created_at = Column(DateTime(timezone=True), nullable=False)
    expires_at = Column(DateTime(timezone=True), nullable=False)
    # Set exactly once when the callback consumes the state (single use).
    consumed_at = Column(DateTime(timezone=True))

    __table_args__ = (
        Index(
            "idx_broker_callback_state_lookup",
            "user_id",
            "session_id",
            "broker",
            "flow",
            "expires_at",
        ),
    )
class ExecutionClaim(Base):
    """Per-slot claim used to deduplicate strategy executions.

    Uniqueness of (user_id, run_id, logical_time) means at most one claim
    can exist per logical execution slot — presumably competing workers
    INSERT and the constraint picks a single winner; confirm against the
    engine's execution path.
    """

    __tablename__ = "execution_claim"

    id = Column(String, primary_key=True)
    user_id = Column(String, ForeignKey("app_user.id", ondelete="CASCADE"), nullable=False)
    run_id = Column(String, ForeignKey("strategy_run.run_id", ondelete="CASCADE"), nullable=False)
    # Run mode stored as free text; the value set is not visible in this model.
    mode = Column(Text, nullable=False)
    # The logical execution timestamp being claimed (not the wall-clock claim time).
    logical_time = Column(DateTime(timezone=True), nullable=False)
    claimed_at = Column(DateTime(timezone=True), nullable=False, server_default=func.now())

    __table_args__ = (
        UniqueConstraint("user_id", "run_id", "logical_time", name="uq_execution_claim_scope"),
        Index("idx_execution_claim_run_claimed", "run_id", "claimed_at"),
    )
class RunLease(Base):
    """Exclusive lease on a strategy run held by one engine owner.

    NOTE(review): table name is plural ('run_leases') unlike the other
    singular table names here — presumably matches what the engine expects;
    confirm before renaming.
    """

    __tablename__ = "run_leases"

    # One lease per run: run_id is the primary key.
    run_id = Column(String, ForeignKey("strategy_run.run_id", ondelete="CASCADE"), primary_key=True)
    # Identifier of the process/host currently holding the lease.
    owner_id = Column(Text, nullable=False)
    leased_at = Column(DateTime(timezone=True), nullable=False)
    # Lease expiry; presumably owners renew before this passes (see heartbeat_at).
    expires_at = Column(DateTime(timezone=True), nullable=False)
    heartbeat_at = Column(DateTime(timezone=True))

    __table_args__ = (
        Index("idx_run_leases_owner_expires", "owner_id", "expires_at"),
    )
class MTMLedger(Base):
__tablename__ = "mtm_ledger"

View File

@ -1,82 +1,126 @@
import os
from urllib.parse import urlparse
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from app.admin_role_service import bootstrap_super_admin
from app.admin_router import router as admin_router
from app.routers.auth import router as auth_router
from app.routers.broker import router as broker_router
from app.routers.health import router as health_router
from app.routers.paper import router as paper_router
from app.routers.password_reset import router as password_reset_router
from app.routers.strategy import router as strategy_router
from app.routers.support_ticket import router as support_ticket_router
from app.routers.system import router as system_router
from app.routers.strategy import router as strategy_router
from app.routers.zerodha import router as zerodha_router, public_router as zerodha_public_router
from app.routers.paper import router as paper_router
from market import router as market_router
from paper_mtm import router as paper_mtm_router
from app.services.db import _db_config as _validate_db_config
from app.services.live_equity_service import start_live_equity_snapshot_daemon
from app.services.strategy_service import init_log_state, resume_running_runs
from app.admin_router import router as admin_router
from app.admin_role_service import bootstrap_super_admin
from market import router as market_router
from paper_mtm import router as paper_mtm_router
app = FastAPI(
title="QuantFortune Backend",
version="1.0"
)
DEFAULT_PRODUCTION_ORIGINS = {"https://app.quantfortune.com"}
DEFAULT_DEV_ORIGINS = {
"http://localhost:3000",
"http://127.0.0.1:3000",
"http://localhost:5173",
"http://127.0.0.1:5173",
}
PRODUCTION_ENV_NAMES = {"prod", "production"}
cors_origins = [
origin.strip()
for origin in os.getenv("CORS_ORIGINS", "").split(",")
if origin.strip()
]
if not cors_origins:
cors_origins = [
"http://localhost:3000",
"http://127.0.0.1:3000",
]
cors_origin_regex = os.getenv("CORS_ORIGIN_REGEX", "").strip()
if not cors_origin_regex:
cors_origin_regex = (
r"https://.*\\.ngrok-free\\.dev"
r"|https://.*\\.ngrok-free\\.app"
r"|https://.*\\.ngrok\\.io"
def _environment_name() -> str:
return (
os.getenv("APP_ENV")
or os.getenv("ENVIRONMENT")
or os.getenv("FASTAPI_ENV")
or "development"
).strip().lower()
def _normalize_origin(origin: str) -> str:
return origin.strip().rstrip("/")
def _is_dev_origin(origin: str) -> bool:
parsed = urlparse(origin)
return parsed.scheme == "http" and parsed.hostname in {"localhost", "127.0.0.1"}
def _validate_cors_origin(origin: str) -> str:
    """Validate a single CORS origin, returning its normalized form.

    Only the production app origin(s) and localhost dev origins pass;
    anything else raises RuntimeError so misconfiguration fails at boot.
    """
    candidate = _normalize_origin(origin)
    if not candidate:
        raise RuntimeError("Empty CORS origin is not allowed")
    allowed = candidate in DEFAULT_PRODUCTION_ORIGINS or _is_dev_origin(candidate)
    if allowed:
        return candidate
    raise RuntimeError(
        f"Unsupported CORS origin '{candidate}'. Only app.quantfortune.com and localhost dev origins are allowed."
    )
# app.add_middleware(
# CORSMiddleware,
# allow_origins=cors_origins,
# allow_origin_regex=cors_origin_regex or None,
# allow_credentials=True,
# allow_methods=["*"],
# allow_headers=["*"],
# )
app.add_middleware(
CORSMiddleware,
allow_origins=[], # must be empty when using regex
allow_origin_regex=".*", # allow ANY origin
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
def _build_cors_origins() -> list[str]:
configured = [
_normalize_origin(origin)
for origin in os.getenv("CORS_ORIGINS", "").split(",")
if origin.strip()
]
env_name = _environment_name()
app.include_router(strategy_router)
app.include_router(auth_router)
app.include_router(broker_router)
app.include_router(zerodha_router)
app.include_router(zerodha_public_router)
app.include_router(paper_router)
app.include_router(market_router)
app.include_router(paper_mtm_router)
app.include_router(health_router)
app.include_router(system_router)
app.include_router(admin_router)
app.include_router(support_ticket_router)
app.include_router(password_reset_router)
if env_name in PRODUCTION_ENV_NAMES:
if not configured:
raise RuntimeError("CORS_ORIGINS must be configured explicitly in production")
origins = configured
else:
origins = configured or sorted(DEFAULT_DEV_ORIGINS)
@app.on_event("startup")
def init_app_state():
init_log_state()
bootstrap_super_admin()
resume_running_runs()
start_live_equity_snapshot_daemon()
deduped: list[str] = []
seen: set[str] = set()
for origin in origins:
validated = _validate_cors_origin(origin)
if validated not in seen:
seen.add(validated)
deduped.append(validated)
return deduped
def create_app() -> FastAPI:
    """Application factory: validate config, install CORS, mount routers.

    DB configuration is validated up front so a misconfigured deployment
    fails when the module is imported (this factory is invoked at module
    level below) rather than on the first query.
    """
    _validate_db_config()  # raises if required DB settings are missing
    app = FastAPI(title="QuantFortune Backend", version="1.0")
    cors_origins = _build_cors_origins()
    app.add_middleware(
        CORSMiddleware,
        allow_origins=cors_origins,
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )
    app.include_router(strategy_router)
    app.include_router(auth_router)
    app.include_router(broker_router)
    app.include_router(zerodha_router)
    app.include_router(zerodha_public_router)
    app.include_router(paper_router)
    app.include_router(market_router)
    app.include_router(paper_mtm_router)
    app.include_router(health_router)
    app.include_router(system_router)
    app.include_router(admin_router)
    app.include_router(support_ticket_router)
    app.include_router(password_reset_router)

    @app.on_event("startup")
    def init_app_state():
        # Escape hatch for tests/tooling that import the app but must not
        # start background daemons or touch the database.
        if os.getenv("DISABLE_STARTUP_TASKS", "0") == "1":
            return
        init_log_state()
        bootstrap_super_admin()
        resume_running_runs()
        start_live_equity_snapshot_daemon()

    return app


# Module-level ASGI entry point (e.g. `uvicorn app.main:app`).
app = create_app()

View File

@ -15,8 +15,12 @@ from app.services.email_service import send_email_async
router = APIRouter(prefix="/api")
SESSION_COOKIE_NAME = "session_id"
COOKIE_SECURE = os.getenv("COOKIE_SECURE", "0") == "1"
APP_ENV = (os.getenv("APP_ENV") or os.getenv("ENVIRONMENT") or os.getenv("FASTAPI_ENV") or "development").strip().lower()
IS_PRODUCTION = APP_ENV in {"prod", "production"}
COOKIE_SECURE = True if IS_PRODUCTION else os.getenv("COOKIE_SECURE", "0") == "1"
COOKIE_SAMESITE = (os.getenv("COOKIE_SAMESITE") or "lax").lower()
if IS_PRODUCTION and not COOKIE_SECURE:
raise RuntimeError("Secure session cookies are mandatory in production")
def _set_session_cookie(response: Response, session_id: str):

View File

@ -15,6 +15,10 @@ from app.broker_store import (
set_pending_broker,
)
from app.services.auth_service import get_user_for_session
from app.services.broker_callback_state import (
consume_broker_callback_state,
create_broker_callback_state,
)
from app.services.email_service import send_email_async
from app.services.groww_service import (
GrowwApiError,
@ -60,6 +64,13 @@ def _require_user(request: Request):
return user
def _require_session_id(request: Request) -> str:
session_id = request.cookies.get("session_id")
if not session_id:
raise HTTPException(status_code=401, detail="Not authenticated")
return session_id
def _first_number(*values, default: float = 0.0) -> float:
for value in values:
try:
@ -317,6 +328,7 @@ def _normalize_groww_funds(data: dict | None) -> dict:
def _build_saved_broker_login_url(
request: Request,
user_id: str,
session_id: str,
redirect_url_override: str | None = None,
) -> str:
entry = get_user_broker(user_id) or {}
@ -332,7 +344,13 @@ def _build_saved_broker_login_url(
if not redirect_url:
base = str(request.base_url).rstrip("/")
redirect_url = f"{base}/api/broker/callback"
return build_login_url(creds["api_key"], redirect_url=redirect_url)
state = create_broker_callback_state(
user_id=user_id,
session_id=session_id,
broker="ZERODHA",
flow="reconnect",
)
return build_login_url(creds["api_key"], redirect_url=redirect_url, state=state)
def _notify_broker_connected(username: str, broker: str, broker_user_id: str | None):
@ -401,6 +419,7 @@ async def disconnect_broker(request: Request):
@router.post("/zerodha/login")
async def zerodha_login(payload: dict, request: Request):
user = _require_user(request)
session_id = _require_session_id(request)
api_key = (payload.get("apiKey") or "").strip()
api_secret = (payload.get("apiSecret") or "").strip()
redirect_url = (payload.get("redirectUrl") or "").strip()
@ -408,7 +427,13 @@ async def zerodha_login(payload: dict, request: Request):
raise HTTPException(status_code=400, detail="API key and secret are required")
set_pending_broker(user["id"], "ZERODHA", api_key, api_secret)
return {"loginUrl": build_login_url(api_key, redirect_url=redirect_url or None)}
state = create_broker_callback_state(
user_id=user["id"],
session_id=session_id,
broker="ZERODHA",
flow="connect",
)
return {"loginUrl": build_login_url(api_key, redirect_url=redirect_url or None, state=state)}
@router.post("/groww/connect")
@ -490,11 +515,23 @@ async def groww_reconnect(request: Request):
@router.get("/zerodha/callback")
async def zerodha_callback(request: Request, request_token: str = ""):
async def zerodha_callback(request: Request, request_token: str = "", state: str = ""):
user = _require_user(request)
session_id = _require_session_id(request)
token = request_token.strip()
callback_state = state.strip()
if not token:
raise HTTPException(status_code=400, detail="Missing request_token")
if not callback_state:
raise HTTPException(status_code=400, detail="Missing state")
if not consume_broker_callback_state(
state=callback_state,
user_id=user["id"],
session_id=session_id,
broker="ZERODHA",
flow="connect",
):
raise HTTPException(status_code=401, detail="Invalid or expired broker callback state")
pending = get_pending_broker(user["id"]) or {}
api_key = (pending.get("api_key") or "").strip()
@ -541,32 +578,46 @@ async def zerodha_callback(request: Request, request_token: str = ""):
@router.get("/login")
async def broker_login(request: Request):
user = _require_user(request)
session_id = _require_session_id(request)
redirect_url = (
(request.query_params.get("redirectUrl") or request.query_params.get("redirect_url") or "")
.strip()
or None
)
login_url = _build_saved_broker_login_url(request, user["id"], redirect_url)
login_url = _build_saved_broker_login_url(request, user["id"], session_id, redirect_url)
return RedirectResponse(login_url)
@router.get("/login-url")
async def broker_login_url(request: Request):
user = _require_user(request)
session_id = _require_session_id(request)
redirect_url = (
(request.query_params.get("redirectUrl") or request.query_params.get("redirect_url") or "")
.strip()
or None
)
return {"loginUrl": _build_saved_broker_login_url(request, user["id"], redirect_url)}
return {"loginUrl": _build_saved_broker_login_url(request, user["id"], session_id, redirect_url)}
@router.get("/callback")
async def broker_callback(request: Request, request_token: str = ""):
async def broker_callback(request: Request, request_token: str = "", state: str = ""):
user = _require_user(request)
session_id = _require_session_id(request)
token = request_token.strip()
callback_state = state.strip()
if not token:
raise HTTPException(status_code=400, detail="Missing request_token")
if not callback_state:
raise HTTPException(status_code=400, detail="Missing state")
if not consume_broker_callback_state(
state=callback_state,
user_id=user["id"],
session_id=session_id,
broker="ZERODHA",
flow="reconnect",
):
raise HTTPException(status_code=401, detail="Invalid or expired broker callback state")
creds = get_broker_credentials(user["id"])
if not creds:
raise HTTPException(status_code=400, detail="Broker credentials not configured")

View File

@ -1,5 +1,4 @@
from fastapi import APIRouter, HTTPException, Query, Request
from fastapi.responses import JSONResponse
from fastapi import APIRouter, HTTPException, Query, Request, status as http_status
from app.models import StrategyStartRequest
from app.services.strategy_service import (
start_strategy,
@ -15,6 +14,20 @@ from app.services.tenant import get_request_user_id
router = APIRouter(prefix="/api")
def _raise_strategy_error(payload: dict, *, default_status: int) -> None:
    """Translate a service-layer error payload into an HTTPException.

    The detail dict carries the service's status/run/broker context through
    to the client (including redirect_url, e.g. for broker re-auth).
    """
    message = (
        payload.get("message")
        or payload.get("detail")
        or "Strategy operation failed"
    )
    detail = {
        "status": payload.get("status", "error"),
        "message": message,
        "run_id": payload.get("run_id"),
        "redirect_url": payload.get("redirect_url"),
        "broker": payload.get("broker"),
    }
    raise HTTPException(status_code=default_status, detail=detail)
@router.post("/strategy/start")
def start(req: StrategyStartRequest, request: Request):
user_id = get_request_user_id(request)
@ -24,35 +37,38 @@ def start(req: StrategyStartRequest, request: Request):
def stop(request: Request):
try:
user_id = get_request_user_id(request)
return stop_strategy(user_id)
result = stop_strategy(user_id)
if result.get("status") not in {"stopped", "already_stopped"}:
_raise_strategy_error(result, default_status=http_status.HTTP_409_CONFLICT)
return result
except HTTPException:
raise
except Exception as exc:
print(f"[STRATEGY] unhandled stop route failure: {exc}", flush=True)
return JSONResponse(
status_code=200,
content={
"status": "stop_failed",
"message": f"Unable to stop strategy: {exc}",
},
)
raise HTTPException(
status_code=http_status.HTTP_500_INTERNAL_SERVER_ERROR,
detail={"status": "stop_failed", "message": f"Unable to stop strategy: {exc}"},
) from exc
@router.post("/strategy/resume")
def resume(request: Request):
try:
user_id = get_request_user_id(request)
return resume_strategy(user_id)
result = resume_strategy(user_id)
success_statuses = {"resumed", "already_running"}
if result.get("status") == "broker_auth_required":
_raise_strategy_error(result, default_status=http_status.HTTP_401_UNAUTHORIZED)
if result.get("status") not in success_statuses:
_raise_strategy_error(result, default_status=http_status.HTTP_409_CONFLICT)
return result
except HTTPException:
raise
except Exception as exc:
print(f"[STRATEGY] unhandled resume route failure: {exc}", flush=True)
return JSONResponse(
status_code=200,
content={
"status": "resume_failed",
"message": f"Unable to resume strategy: {exc}",
},
)
raise HTTPException(
status_code=http_status.HTTP_500_INTERNAL_SERVER_ERROR,
detail={"status": "resume_failed", "message": f"Unable to resume strategy: {exc}"},
) from exc
@router.get("/strategy/status")
def status(request: Request):

View File

@ -1,6 +1,7 @@
from fastapi import APIRouter, HTTPException
from fastapi import APIRouter, Header, HTTPException, Request
from pydantic import BaseModel
from app.services.support_abuse import SupportGuardRejected, enforce_support_guard
from app.services.support_ticket import create_ticket, get_ticket_status
@ -19,7 +20,20 @@ class TicketStatusRequest(BaseModel):
@router.post("/ticket")
def submit_ticket(payload: TicketCreate):
def submit_ticket(
payload: TicketCreate,
request: Request,
support_captcha: str | None = Header(default=None, alias="X-Support-Captcha"),
):
try:
enforce_support_guard(
request=request,
endpoint="ticket_create",
email=payload.email.strip(),
captcha_token=support_captcha,
)
except SupportGuardRejected as exc:
raise HTTPException(status_code=exc.status_code, detail=exc.detail) from exc
if not payload.subject.strip() or not payload.message.strip():
raise HTTPException(status_code=400, detail="Subject and message are required")
ticket = create_ticket(
@ -32,7 +46,22 @@ def submit_ticket(payload: TicketCreate):
@router.post("/ticket/status/{ticket_id}")
def ticket_status(ticket_id: str, payload: TicketStatusRequest):
def ticket_status(
ticket_id: str,
payload: TicketStatusRequest,
request: Request,
support_captcha: str | None = Header(default=None, alias="X-Support-Captcha"),
):
try:
enforce_support_guard(
request=request,
endpoint="ticket_status",
email=payload.email.strip(),
ticket_id=ticket_id.strip(),
captcha_token=support_captcha,
)
except SupportGuardRejected as exc:
raise HTTPException(status_code=exc.status_code, detail=exc.detail) from exc
status = get_ticket_status(ticket_id.strip(), payload.email.strip())
if not status:
raise HTTPException(status_code=404, detail="Ticket not found")

View File

@ -1,9 +1,13 @@
import hashlib
import os
import re
import secrets
from datetime import datetime, timedelta, timezone
from uuid import uuid4
from argon2 import PasswordHasher
from argon2.exceptions import InvalidHash, VerifyMismatchError
from app.services.db import db_connection
SESSION_TTL_SECONDS = int(os.getenv("SESSION_TTL_SECONDS", str(60 * 60 * 24 * 7)))
@ -11,7 +15,11 @@ SESSION_REFRESH_WINDOW_SECONDS = int(
os.getenv("SESSION_REFRESH_WINDOW_SECONDS", str(60 * 60))
)
RESET_OTP_TTL_MINUTES = int(os.getenv("RESET_OTP_TTL_MINUTES", "10"))
RESET_OTP_SECRET = os.getenv("RESET_OTP_SECRET", "otp_secret")
PASSWORD_HASHER = PasswordHasher()
LEGACY_SHA256_RE = re.compile(r"^[0-9a-f]{64}$")
RESET_OTP_SECRET = (os.getenv("RESET_OTP_SECRET") or "").strip()
if not RESET_OTP_SECRET:
raise RuntimeError("RESET_OTP_SECRET must be configured")
def _now_utc() -> datetime:
@ -23,9 +31,17 @@ def _new_expiry(now: datetime) -> datetime:
def _hash_password(password: str) -> str:
    """Hash *password* with the module-level Argon2 hasher."""
    hashed = PASSWORD_HASHER.hash(password)
    return hashed
def _hash_password_legacy(password: str) -> str:
return hashlib.sha256(password.encode("utf-8")).hexdigest()
def _is_legacy_password_hash(password_hash: str | None) -> bool:
    """True when the stored value looks like a legacy 64-char lowercase-hex SHA-256."""
    if not password_hash or len(password_hash) != 64:
        return False
    return all(ch in "0123456789abcdef" for ch in password_hash)
def _hash_otp(email: str, otp: str) -> str:
    """Bind an OTP to an email and the server secret, returning a hex digest.

    NOTE(review): plain SHA-256 with a concatenated secret; an HMAC would be
    the more conventional construction — changing it would invalidate any
    outstanding OTPs, so coordinate before switching.
    """
    material = f"{email}:{otp}:{RESET_OTP_SECRET}"
    return hashlib.sha256(material.encode("utf-8")).hexdigest()
@ -80,12 +96,47 @@ def create_user(username: str, password: str):
return _row_to_user(cur.fetchone())
def _update_password_hash(user_id: str, password_hash: str):
    """Persist a new password hash for *user_id*.

    Also used by the login path to upgrade legacy/outdated hashes in place.
    """
    with db_connection() as conn:
        with conn:  # transaction scope — presumably commits on success (psycopg-style); verify
            with conn.cursor() as cur:
                cur.execute(
                    "UPDATE app_user SET password_hash = %s WHERE id = %s",
                    (password_hash, user_id),
                )
def _verify_password(user_id: str, stored_hash: str | None, password: str) -> tuple[bool, str | None]:
    """Check *password* against *stored_hash*.

    Returns ``(verified, replacement_hash)``; *replacement_hash* is a fresh
    Argon2 hash the caller should persist (legacy SHA-256 upgrade or Argon2
    parameter rehash), or ``None`` when no rehash is needed.

    NOTE(review): *user_id* is unused here — the caller does the persisting;
    consider dropping the parameter or moving the update into this helper.
    """
    if not stored_hash:
        return False, None
    # Legacy plain-SHA256 rows: constant-time compare, then upgrade to Argon2.
    if _is_legacy_password_hash(stored_hash):
        if secrets.compare_digest(stored_hash, _hash_password_legacy(password)):
            return True, _hash_password(password)
        return False, None
    try:
        verified = PASSWORD_HASHER.verify(stored_hash, password)
    except (VerifyMismatchError, InvalidHash):
        return False, None
    if not verified:
        # Defensive only: argon2's verify() raises on mismatch rather than
        # returning False, so this branch should be unreachable.
        return False, None
    if PASSWORD_HASHER.check_needs_rehash(stored_hash):
        return True, _hash_password(password)
    return True, None
def authenticate_user(username: str, password: str):
user = get_user_by_username(username)
if not user:
return None
if user.get("password") != _hash_password(password):
verified, replacement_hash = _verify_password(user["id"], user.get("password"), password)
if not verified:
return None
if replacement_hash:
_update_password_hash(user["id"], replacement_hash)
user["password"] = replacement_hash
return user
@ -130,13 +181,7 @@ def get_last_session_meta(user_id: str):
def update_user_password(user_id: str, new_password: str):
password_hash = _hash_password(new_password)
with db_connection() as conn:
with conn:
with conn.cursor() as cur:
cur.execute(
"UPDATE app_user SET password_hash = %s WHERE id = %s",
(password_hash, user_id),
)
_update_password_hash(user_id, password_hash)
def create_password_reset_otp(email: str):

View File

@ -0,0 +1,111 @@
import hashlib
import secrets
from datetime import datetime, timedelta, timezone
from uuid import uuid4
from app.services.db import db_connection
CALLBACK_STATE_TTL_SECONDS = 15 * 60
def _now_utc() -> datetime:
return datetime.now(timezone.utc)
def _state_hash(state: str) -> str:
return hashlib.sha256(state.encode("utf-8")).hexdigest()
def create_broker_callback_state(
    *,
    user_id: str,
    session_id: str,
    broker: str,
    flow: str,
    ttl_seconds: int = CALLBACK_STATE_TTL_SECONDS,
) -> str:
    """Mint a single-use broker-callback state token and persist its hash.

    Returns the raw token (for embedding in the broker login URL); only its
    SHA-256 is stored, bound to user/session/broker/flow with a TTL.
    """
    state = secrets.token_urlsafe(32)  # 32 random bytes, URL-safe encoded
    now = _now_utc()
    expires_at = now + timedelta(seconds=ttl_seconds)
    state_hash = _state_hash(state)
    with db_connection() as conn:
        with conn:  # transaction scope for the cleanup + insert pair
            with conn.cursor() as cur:
                # Opportunistic cleanup: purge expired or already-consumed
                # rows table-wide (not just this user's) on every mint.
                cur.execute(
                    """
                    DELETE FROM broker_callback_state
                    WHERE expires_at <= %s OR consumed_at IS NOT NULL
                    """,
                    (now,),
                )
                cur.execute(
                    """
                    INSERT INTO broker_callback_state (
                        id,
                        state_hash,
                        user_id,
                        session_id,
                        broker,
                        flow,
                        created_at,
                        expires_at,
                        consumed_at
                    )
                    VALUES (%s, %s, %s, %s, %s, %s, %s, %s, NULL)
                    """,
                    (
                        str(uuid4()),
                        state_hash,
                        user_id,
                        # Normalized the same way consume_broker_callback_state
                        # does, so lookups match regardless of caller casing.
                        broker.strip().upper(),
                        flow.strip().lower(),
                        now,
                        expires_at,
                    ),
                )
    return state
def consume_broker_callback_state(
    *,
    state: str,
    user_id: str,
    session_id: str,
    broker: str,
    flow: str,
):
    """Atomically redeem a broker-callback state token.

    A single UPDATE ... RETURNING marks the matching, unexpired,
    not-yet-consumed row as consumed, so two concurrent callbacks cannot
    both succeed. Returns a small dict on success, or None when the token
    is empty, unknown, expired, already used, or bound to a different
    user/session/broker/flow.
    """
    if not state:
        return None
    now = _now_utc()
    state_hash = _state_hash(state)
    with db_connection() as conn:
        with conn:  # transaction scope for the consume UPDATE
            with conn.cursor() as cur:
                cur.execute(
                    """
                    UPDATE broker_callback_state
                    SET consumed_at = %s
                    WHERE state_hash = %s
                      AND user_id = %s
                      AND session_id = %s
                      AND broker = %s
                      AND flow = %s
                      AND consumed_at IS NULL
                      AND expires_at > %s
                    RETURNING id, expires_at
                    """,
                    (
                        now,
                        state_hash,
                        user_id,
                        # Same normalization as create_broker_callback_state.
                        broker.strip().upper(),
                        flow.strip().lower(),
                        now,
                    ),
                )
                row = cur.fetchone()
    if not row:
        return None
    return {"id": row[0], "expires_at": row[1].isoformat() if row[1] else None}

View File

@ -16,6 +16,7 @@ Base = declarative_base()
_ENGINE: Engine | None = None
_ENGINE_LOCK = threading.Lock()
NON_PROD_ENVIRONMENTS = {"development", "dev", "test", "testing", "local"}
class _ConnectionProxy:
@ -44,16 +45,28 @@ class _ConnectionProxy:
def _db_config() -> dict[str, str | int]:
env_name = (os.getenv("APP_ENV") or os.getenv("ENVIRONMENT") or os.getenv("FASTAPI_ENV") or "development").strip().lower()
is_non_prod = env_name in NON_PROD_ENVIRONMENTS
url = os.getenv("DATABASE_URL")
if url:
return {"url": url}
password = os.getenv("DB_PASSWORD") or os.getenv("PGPASSWORD")
if not password and not is_non_prod:
raise RuntimeError("DB_PASSWORD or PGPASSWORD must be configured in non-development environments")
host = os.getenv("DB_HOST") or os.getenv("PGHOST") or ("localhost" if is_non_prod else None)
dbname = os.getenv("DB_NAME") or os.getenv("PGDATABASE") or ("trading_db" if is_non_prod else None)
user = os.getenv("DB_USER") or os.getenv("PGUSER") or ("trader" if is_non_prod else None)
if not is_non_prod and (not host or not dbname or not user):
raise RuntimeError("DB_HOST, DB_NAME, and DB_USER must be configured in non-development environments")
return {
"host": os.getenv("DB_HOST") or os.getenv("PGHOST") or "localhost",
"host": host,
"port": int(os.getenv("DB_PORT") or os.getenv("PGPORT") or "5432"),
"dbname": os.getenv("DB_NAME") or os.getenv("PGDATABASE") or "trading_db",
"user": os.getenv("DB_USER") or os.getenv("PGUSER") or "trader",
"password": os.getenv("DB_PASSWORD") or os.getenv("PGPASSWORD") or "traderpass",
"dbname": dbname,
"user": user,
"password": password,
"connect_timeout": int(os.getenv("DB_CONNECT_TIMEOUT", "5")),
"schema": os.getenv("DB_SCHEMA") or os.getenv("PGSCHEMA") or "quant_app",
}

View File

@ -172,6 +172,6 @@ def update_run_status(user_id: str, run_id: str, status: str, meta: dict | None
""",
(status, now, Json(meta or {}), run_id, user_id),
)
return True
return cur.rowcount > 0
return run_with_retry(_op)

View File

@ -4,17 +4,28 @@ import sys
import threading
from datetime import datetime, timedelta, timezone
from pathlib import Path
from zoneinfo import ZoneInfo
ENGINE_ROOT = Path(__file__).resolve().parents[3]
if str(ENGINE_ROOT) not in sys.path:
sys.path.append(str(ENGINE_ROOT))
from indian_paper_trading_strategy.engine.market import is_market_open, align_to_market_open, market_now
from indian_paper_trading_strategy.engine.runner import start_engine, stop_engine
from indian_paper_trading_strategy.engine.market import (
align_to_market_open,
market_now,
market_session,
next_market_open_after,
)
from indian_paper_trading_strategy.engine.market_calendar import UnsupportedCalendarYearError
from indian_paper_trading_strategy.engine.runner import RunLeaseNotAcquiredError, start_engine, stop_engine
from indian_paper_trading_strategy.engine.state import init_paper_state, load_state, save_state
from indian_paper_trading_strategy.engine.broker import PaperBroker
from indian_paper_trading_strategy.engine.time_utils import frequency_to_timedelta
from indian_paper_trading_strategy.engine.time_utils import (
UTC,
frequency_to_timedelta,
parse_market_timestamp,
parse_persisted_timestamp,
serialize_timestamp,
)
from indian_paper_trading_strategy.engine.db import engine_context
from app.broker_store import get_user_broker, set_broker_auth_state
@ -41,7 +52,6 @@ SEQ_LOCK = threading.Lock()
SEQ = 0
LAST_WAIT_LOG_TS = {}
WAIT_LOG_INTERVAL = timedelta(seconds=60)
IST = ZoneInfo("Asia/Kolkata")
def init_log_state():
global SEQ
@ -110,7 +120,7 @@ def emit_event(
evt = {
"seq": seq,
"ts": now.isoformat().replace("+00:00", "Z"),
"ts": serialize_timestamp(now),
"level": level,
"category": category,
"event": event,
@ -157,14 +167,8 @@ def _maybe_parse_json(value):
return value
def _local_tz():
return IST
def _format_local_ts(value: datetime | None):
if value is None:
return None
return value.astimezone(_local_tz()).replace(tzinfo=None).isoformat()
def _utc_now():
    """Timezone-aware 'now' in UTC (UTC constant comes from engine time_utils)."""
    return datetime.now(tz=UTC)
def _load_config(user_id: str, run_id: str):
@ -192,7 +196,7 @@ def _load_config(user_id: str, run_id: str):
"frequency": _maybe_parse_json(row[7]),
"frequency_days": row[8],
"unit": row[9],
"next_run": _format_local_ts(row[10]),
"next_run": serialize_timestamp(row[10]),
}
if row[2] is not None or row[3] is not None:
cfg["sip_frequency"] = {
@ -217,13 +221,7 @@ def _save_config(cfg, user_id: str, run_id: str):
next_run = cfg.get("next_run")
next_run_dt = None
if isinstance(next_run, str):
try:
parsed = datetime.fromisoformat(next_run)
if parsed.tzinfo is None:
parsed = parsed.replace(tzinfo=_local_tz())
next_run_dt = parsed
except ValueError:
next_run_dt = None
next_run_dt = parse_persisted_timestamp(next_run)
with db_connection() as conn:
with conn:
@ -294,7 +292,7 @@ def reactivate_strategy_config(user_id: str, run_id: str):
return cfg
def _write_status(user_id: str, run_id: str, status):
now_local = market_now()
now_local = _utc_now()
with db_connection() as conn:
with conn:
with conn.cursor() as cur:
@ -346,6 +344,12 @@ def _effective_running_run_id(user_id: str):
)
return None
def _set_run_status_or_raise(user_id: str, run_id: str, status: str, meta: dict | None = None):
    """Update the run's status, raising RuntimeError if no row matched.

    update_run_status() reports whether a row was updated; a miss means the
    run disappeared (e.g. deleted concurrently), which callers treat as a
    hard failure rather than a silent no-op.
    """
    if not update_run_status(user_id, run_id, status, meta=meta):
        raise RuntimeError(f"Run {run_id} for user {user_id} no longer exists")
def validate_frequency(freq: dict, mode: str):
if not isinstance(freq, dict):
raise ValueError("Frequency payload is required")
@ -436,9 +440,8 @@ def _validate_live_broker_session(user_id: str):
def compute_next_eligible(last_run: str | None, sip_frequency: dict | None):
if not last_run or not sip_frequency:
return None
try:
last_dt = datetime.fromisoformat(last_run)
except ValueError:
last_dt = parse_market_timestamp(last_run)
if last_dt is None:
return None
try:
delta = frequency_to_timedelta(sip_frequency)
@ -446,7 +449,7 @@ def compute_next_eligible(last_run: str | None, sip_frequency: dict | None):
return None
next_dt = last_dt + delta
next_dt = align_to_market_open(next_dt)
return next_dt.isoformat()
return serialize_timestamp(next_dt)
def _last_execution_ts(state: dict, mode: str) -> str | None:
@ -473,7 +476,10 @@ def start_strategy(req, user_id: str):
return {"status": "already_running", "run_id": running_run_id}
engine_config = _build_engine_config(user_id, running_run_id, req)
if engine_config:
started = start_engine(engine_config)
try:
started = start_engine(engine_config)
except RunLeaseNotAcquiredError:
return {"status": "already_running", "run_id": running_run_id}
if started:
_write_status(user_id, running_run_id, "RUNNING")
return {"status": "restarted", "run_id": running_run_id}
@ -573,7 +579,10 @@ def start_strategy(req, user_id: str):
engine_config["run_id"] = run_id
engine_config["user_id"] = user_id
engine_config["emit_event"] = emit_event_cb
start_engine(engine_config)
try:
start_engine(engine_config)
except RunLeaseNotAcquiredError:
pass
try:
user = get_user_by_id(user_id)
@ -655,14 +664,17 @@ def resume_running_runs():
engine_config = _build_engine_config(user_id, run_id, None)
if not engine_config:
continue
started = start_engine(engine_config)
try:
started = start_engine(engine_config)
except RunLeaseNotAcquiredError:
started = False
if started:
_write_status(user_id, run_id, "RUNNING")
def stop_strategy(user_id: str):
run_id = _effective_running_run_id(user_id)
if not run_id:
latest_run_id = get_active_run_id(user_id)
latest_run_id = get_running_run_id(user_id) or get_active_run_id(user_id)
return {"status": "already_stopped", "run_id": latest_run_id}
engine_external = os.getenv("ENGINE_EXTERNAL", "").strip().lower() in {"1", "true", "yes"}
@ -681,7 +693,14 @@ def stop_strategy(user_id: str):
print(f"[STRATEGY] engine status update failed during stop for {user_id}/{run_id}: {exc}", flush=True)
if not stop_warning:
stop_warning = str(exc)
update_run_status(user_id, run_id, "STOPPED", meta={"reason": "user_request"})
try:
_set_run_status_or_raise(user_id, run_id, "STOPPED", meta={"reason": "user_request"})
except RuntimeError as exc:
return {
"status": "stop_failed",
"run_id": run_id,
"message": str(exc),
}
try:
user = get_user_by_id(user_id)
@ -704,6 +723,8 @@ def resume_strategy(user_id: str):
return {"status": "already_running", "run_id": running_run_id}
run_id = get_active_run_id(user_id)
if not run_id:
return {"status": "no_resumable_run"}
cfg = _load_config(user_id, run_id)
strategy_name = (cfg.get("strategy") or "").strip()
mode = (cfg.get("mode") or "").strip().upper()
@ -737,16 +758,26 @@ def resume_strategy(user_id: str):
}
reactivate_strategy_config(user_id, run_id)
update_run_status(user_id, run_id, "RUNNING", meta={"reason": "user_resume"})
try:
_set_run_status_or_raise(user_id, run_id, "RUNNING", meta={"reason": "user_resume"})
except RuntimeError as exc:
deactivate_strategy_config(user_id, run_id)
return {
"status": "resume_failed",
"run_id": run_id,
"message": str(exc),
}
_write_status(user_id, run_id, "RUNNING")
if not engine_external:
try:
started = start_engine(engine_config)
except RunLeaseNotAcquiredError:
return {"status": "already_running", "run_id": run_id}
except Exception as exc:
deactivate_strategy_config(user_id, run_id)
_write_status(user_id, run_id, "STOPPED")
update_run_status(user_id, run_id, "STOPPED", meta={"reason": "resume_start_failed"})
_set_run_status_or_raise(user_id, run_id, "STOPPED", meta={"reason": "resume_start_failed"})
return {
"status": "resume_failed",
"run_id": run_id,
@ -755,7 +786,7 @@ def resume_strategy(user_id: str):
if not started:
deactivate_strategy_config(user_id, run_id)
_write_status(user_id, run_id, "STOPPED")
update_run_status(user_id, run_id, "STOPPED", meta={"reason": "resume_start_failed"})
_set_run_status_or_raise(user_id, run_id, "STOPPED", meta={"reason": "resume_start_failed"})
return {
"status": "resume_failed",
"run_id": run_id,
@ -797,7 +828,7 @@ def get_strategy_status(user_id: str):
else:
status = {
"status": default_status,
"last_updated": _format_local_ts(engine_row[1]),
"last_updated": serialize_timestamp(engine_row[1]),
}
status["run_id"] = run_id
engine_state = str((engine_row or [None])[0] or "").strip().upper()
@ -832,17 +863,9 @@ def get_strategy_status(user_id: str):
status["last_execution_ts"] = last_execution_ts
status["next_eligible_ts"] = next_eligible
if next_eligible:
try:
parsed_next = datetime.fromisoformat(next_eligible)
now_cmp = (
datetime.now(parsed_next.tzinfo)
if parsed_next.tzinfo
else market_now().replace(tzinfo=None)
)
if parsed_next > now_cmp:
status["status"] = "WAITING"
except ValueError:
pass
parsed_next = parse_persisted_timestamp(next_eligible)
if parsed_next and parsed_next > _utc_now():
status["status"] = "WAITING"
status_key = (status.get("status") or "IDLE").upper()
resumable = bool(cfg.get("strategy")) and bool(cfg.get("mode"))
status["can_resume"] = resumable and status_key in {"STOPPED", "PAUSED_AUTH_EXPIRED"}
@ -876,11 +899,7 @@ def get_engine_status(user_id: str):
status["state"] = row[0]
last_updated = row[1]
if last_updated is not None:
status["last_heartbeat_ts"] = (
last_updated.astimezone(timezone.utc)
.isoformat()
.replace("+00:00", "Z")
)
status["last_heartbeat_ts"] = serialize_timestamp(last_updated)
cfg = _load_config(user_id, run_id)
mode = (cfg.get("mode") or "LIVE").strip().upper()
with engine_context(user_id, run_id):
@ -926,10 +945,7 @@ def get_strategy_logs(user_id: str, since_seq: int):
events = []
for row in rows:
ts = row[1]
if ts is not None:
ts_str = ts.astimezone(timezone.utc).isoformat().replace("+00:00", "Z")
else:
ts_str = None
ts_str = serialize_timestamp(ts)
events.append(
{
"seq": row[0],
@ -980,6 +996,16 @@ def _issue_message(event: str, message: str | None, data: dict | None, meta: dic
if event == "ENGINE_ERROR":
return message or "Strategy engine hit an error."
if event == "EXECUTION_BLOCKED":
if reason_key == "market_holiday":
return "Exchange holiday. Execution will resume next session."
if reason_key == "market_weekend":
return "Weekend closure. Execution will resume next session."
if reason_key == "market_pre_open":
return "Market has not opened yet. Execution will begin after 9:15 AM IST."
if reason_key == "market_post_close":
return "Market is closed for the day. Execution will resume next session."
if reason_key == "market_calendar_unavailable":
return "Market calendar unavailable. Execution paused for safety."
if reason_key == "market_closed":
return "Market is closed. Execution will resume next session."
return f"Execution blocked: {_humanize_reason(reason) or 'Unknown reason'}."
@ -1019,8 +1045,17 @@ def _issue_is_stale_for_current_state(
}:
return True
if event == "EXECUTION_BLOCKED" and reason_key == "market_closed":
return is_market_open(market_now())
if event == "EXECUTION_BLOCKED" and reason_key.startswith("market_"):
current_session = market_session(market_now())
current_reason = str(current_session.get("reason") or "").strip().lower()
current_status = str(current_session.get("status") or "").strip().upper()
if reason_key == "market_holiday":
return current_status != "HOLIDAY"
if reason_key == "market_calendar_unavailable":
return current_reason != "calendar_unavailable"
if reason_key in {"market_weekend", "market_pre_open", "market_post_close", "market_closed"}:
return current_status == "OPEN"
return False
if mode != "LIVE":
return False
@ -1085,7 +1120,7 @@ def get_strategy_summary(user_id: str):
"tone": "error" if event in {"ENGINE_ERROR", "ORDER_REJECTED"} else "warning",
"message": _issue_message(event, message, data, meta),
"event": event,
"ts": _format_local_ts(ts),
"ts": serialize_timestamp(ts),
}
)
return summary
@ -1119,7 +1154,17 @@ def get_strategy_summary(user_id: str):
def get_market_status():
now = market_now()
session = market_session(now)
status = str(session.get("status") or "CLOSED")
reason = str(session.get("reason") or "")
next_open_at = None
try:
next_open_at = serialize_timestamp(next_market_open_after(now))
except UnsupportedCalendarYearError:
next_open_at = None
return {
"status": "OPEN" if is_market_open(now) else "CLOSED",
"checked_at": now.isoformat(),
"status": status,
"reason": reason,
"checked_at": serialize_timestamp(now),
"next_open_at": next_open_at,
}

View File

@ -0,0 +1,224 @@
from __future__ import annotations
import hashlib
import logging
import os
import threading
from collections import deque
from datetime import datetime, timedelta, timezone
from typing import Deque
from fastapi import Request
from app.services.db import db_connection
logger = logging.getLogger(__name__)
_MEMORY_LOCK = threading.Lock()
_MEMORY_EVENTS: list[dict] = []
class SupportGuardRejected(Exception):
    """Raised when a support request fails a guard check (captcha or rate limit)."""

    def __init__(self, status_code: int, detail: str):
        # Keep the HTTP status and detail so the API layer can translate
        # this into an HTTPException-style response.
        self.status_code = status_code
        self.detail = detail
        super().__init__(detail)
def _now_utc() -> datetime:
return datetime.now(timezone.utc)
def _sha256(value: str | None) -> str | None:
if not value:
return None
return hashlib.sha256(value.strip().lower().encode("utf-8")).hexdigest()
def _backend_mode() -> str:
return (os.getenv("SUPPORT_GUARD_BACKEND") or "db").strip().lower()
def _window() -> timedelta:
return timedelta(seconds=int(os.getenv("SUPPORT_GUARD_WINDOW_SECONDS", "900")))
def _create_limit() -> int:
return int(os.getenv("SUPPORT_CREATE_LIMIT", "5"))
def _status_limit() -> int:
return int(os.getenv("SUPPORT_STATUS_LIMIT", "15"))
def _ticket_probe_limit() -> int:
return int(os.getenv("SUPPORT_STATUS_TICKET_LIMIT", "10"))
def _captcha_secret() -> str | None:
return (os.getenv("SUPPORT_CAPTCHA_SECRET") or "").strip() or None
def _request_ip(request: Request) -> str:
    """Best-effort client IP: first X-Forwarded-For hop, else socket peer.

    NOTE(review): X-Forwarded-For is client-controlled unless a trusted
    proxy strips it — acceptable here since it only keys rate limiting.
    """
    forwarded = request.headers.get("x-forwarded-for")
    if forwarded:
        candidate = forwarded.split(",", 1)[0].strip()
        if candidate:
            return candidate
    client = request.client
    return client.host if client else "unknown"
def _validate_captcha(captcha_token: str | None) -> None:
    """Check the shared-secret captcha token when one is configured.

    Uses a constant-time comparison (hmac.compare_digest) so the secret
    cannot be recovered byte-by-byte via response-timing differences —
    the original ``!=`` comparison short-circuits on the first mismatch.

    Raises:
        SupportGuardRejected: 403 when a secret is configured and the
            supplied token does not match (or is missing).
    """
    import hmac  # local import: only needed when a secret is configured

    secret = _captcha_secret()
    if secret is None:
        return
    # Encode both sides so compare_digest accepts any str content.
    if captcha_token is None or not hmac.compare_digest(
        captcha_token.encode("utf-8"), secret.encode("utf-8")
    ):
        raise SupportGuardRejected(403, "Support verification failed")
def _record_memory_event(record: dict) -> None:
    """Append a guard attempt to the in-memory buffer, dropping stale entries."""
    stale_before = _now_utc() - _window()
    with _MEMORY_LOCK:
        # Prune expired events first so the buffer stays bounded by the window.
        kept = [event for event in _MEMORY_EVENTS if event["created_at"] >= stale_before]
        kept.append(record)
        _MEMORY_EVENTS[:] = kept
def _memory_count(endpoint: str, ip_hash: str | None, ticket_hash: str | None, cutoff: datetime) -> tuple[int, int]:
    """Count recent in-memory attempts per IP hash and (optionally) ticket hash."""
    with _MEMORY_LOCK:
        # Snapshot the matching window under the lock, then count outside it.
        recent = [
            event
            for event in _MEMORY_EVENTS
            if event["endpoint"] == endpoint and event["created_at"] >= cutoff
        ]
    ip_count = sum(1 for event in recent if event["ip_hash"] == ip_hash)
    if ticket_hash:
        ticket_count = sum(1 for event in recent if event["ticket_hash"] == ticket_hash)
    else:
        ticket_count = 0
    return ip_count, ticket_count
def _record_db_event(record: dict) -> None:
    """Persist one guard attempt to the support_request_audit table.

    Only hashed identifiers (ip_hash/email_hash/ticket_hash) are stored;
    the caller is responsible for hashing before building ``record``.
    """
    with db_connection() as conn:
        with conn:  # transaction scope: commit on success, roll back on error
            with conn.cursor() as cur:
                cur.execute(
                    """
                    INSERT INTO support_request_audit (
                        endpoint, ip_hash, email_hash, ticket_hash, blocked, reason, created_at
                    )
                    VALUES (%s, %s, %s, %s, %s, %s, %s)
                    """,
                    (
                        record["endpoint"],
                        record["ip_hash"],
                        record["email_hash"],
                        record["ticket_hash"],
                        record["blocked"],
                        record["reason"],
                        record["created_at"],
                    ),
                )
def _db_count(endpoint: str, ip_hash: str | None, ticket_hash: str | None, cutoff: datetime) -> tuple[int, int]:
    """Count recent audit rows per (endpoint, ip) and (endpoint, ticket).

    Returns:
        (ip_count, ticket_count); ticket_count is 0 when no ticket hash
        was supplied.
    """
    with db_connection() as conn:
        with conn.cursor() as cur:
            # IS NOT DISTINCT FROM treats NULL ip hashes as equal, unlike "=".
            cur.execute(
                """
                SELECT COUNT(*)
                FROM support_request_audit
                WHERE endpoint = %s
                AND ip_hash IS NOT DISTINCT FROM %s
                AND created_at >= %s
                """,
                (endpoint, ip_hash, cutoff),
            )
            ip_count = cur.fetchone()[0] or 0
            ticket_count = 0
            if ticket_hash:
                cur.execute(
                    """
                    SELECT COUNT(*)
                    FROM support_request_audit
                    WHERE endpoint = %s
                    AND ticket_hash = %s
                    AND created_at >= %s
                    """,
                    (endpoint, ticket_hash, cutoff),
                )
                ticket_count = cur.fetchone()[0] or 0
    return ip_count, ticket_count
def _determine_limits(endpoint: str, ip_count: int, ticket_count: int) -> str | None:
    """Return the block reason when a limit is exceeded, else None."""
    if endpoint == "ticket_create":
        return "create_rate_limited" if ip_count >= _create_limit() else None
    if endpoint == "ticket_status":
        if ip_count >= _status_limit():
            return "status_rate_limited"
        if ticket_count >= _ticket_probe_limit():
            return "ticket_probe_limited"
    return None
def _audit_attempt(record: dict) -> None:
    """Route one attempt record to the configured backend (memory or DB)."""
    if _backend_mode() == "memory":
        _record_memory_event(record)
    else:
        _record_db_event(record)


def _count_recent(endpoint: str, ip_hash: str | None, ticket_hash: str | None, cutoff: datetime) -> tuple[int, int]:
    """Count recent attempts via the configured backend (memory or DB)."""
    if _backend_mode() == "memory":
        return _memory_count(endpoint, ip_hash, ticket_hash, cutoff)
    return _db_count(endpoint, ip_hash, ticket_hash, cutoff)
def enforce_support_guard(
    *,
    request: Request,
    endpoint: str,
    email: str | None = None,
    ticket_id: str | None = None,
    captcha_token: str | None = None,
) -> None:
    """Validate and rate-limit a public support request.

    Checks the optional shared-secret captcha first, then counts recent
    attempts (per hashed client IP and, for status checks, per hashed
    ticket id) inside the configured window. Every attempt — allowed or
    blocked — is written to the audit backend before any rejection is raised.

    Raises:
        SupportGuardRejected: 403 on captcha failure, 429 when a rate
            limit is exceeded.
    """
    _validate_captcha(captcha_token)
    now = _now_utc()
    cutoff = now - _window()
    # Only unsalted SHA-256 hashes are stored; raw IPs, emails, and ticket
    # ids never reach the audit trail.
    ip_hash = _sha256(_request_ip(request))
    email_hash = _sha256(email)
    ticket_hash = _sha256(ticket_id)
    ip_count, ticket_count = _count_recent(endpoint, ip_hash, ticket_hash, cutoff)
    reason = _determine_limits(endpoint, ip_count, ticket_count)
    record = {
        "endpoint": endpoint,
        "ip_hash": ip_hash,
        "email_hash": email_hash,
        "ticket_hash": ticket_hash,
        "blocked": reason is not None,
        "reason": reason,
        "created_at": now,
    }
    # Audit before rejecting so blocked attempts are visible too.
    _audit_attempt(record)
    if reason is not None:
        logger.warning(
            "Support request blocked",
            extra={
                "endpoint": endpoint,
                "reason": reason,
                "ip_hash": ip_hash,
                "ticket_hash": ticket_hash,
            },
        )
        raise SupportGuardRejected(429, "Too many support requests. Please try again later.")
def reset_memory_support_guard_state() -> None:
    """Test helper: drop every buffered in-memory guard event."""
    with _MEMORY_LOCK:
        del _MEMORY_EVENTS[:]

View File

@ -4,6 +4,7 @@ from uuid import uuid4
from app.services.db import db_connection
from app.services.email_service import send_email
from indian_paper_trading_strategy.engine.time_utils import serialize_timestamp
def _now():
@ -41,7 +42,7 @@ def create_ticket(name: str, email: str, subject: str, message: str) -> dict:
return {
"ticket_id": ticket_id,
"status": "NEW",
"created_at": now.isoformat(),
"created_at": serialize_timestamp(now),
"email_sent": email_sent,
}
@ -65,6 +66,6 @@ def get_ticket_status(ticket_id: str, email: str) -> dict | None:
return {
"ticket_id": row[0],
"status": row[2],
"created_at": row[3].isoformat() if row[3] else None,
"updated_at": row[4].isoformat() if row[4] else None,
"created_at": serialize_timestamp(row[3]) if row[3] else None,
"updated_at": serialize_timestamp(row[4]) if row[4] else None,
}

View File

@ -5,6 +5,8 @@ from datetime import datetime, timezone
from psycopg2.extras import Json
from indian_paper_trading_strategy.engine.time_utils import parse_persisted_timestamp, serialize_timestamp
from app.broker_store import get_user_broker, set_broker_auth_state
from app.services.db import db_connection
from app.services.groww_service import GrowwApiError, GrowwTokenError, fetch_funds as fetch_groww_funds
@ -59,12 +61,7 @@ def _resolve_sip_frequency(row: dict):
def _parse_ts(value: str | None):
if not value:
return None
try:
return datetime.fromisoformat(value)
except ValueError:
return None
return parse_persisted_timestamp(value)
def _validate_broker_session(user_id: str):
@ -180,7 +177,7 @@ def arm_system(user_id: str, client_ip: str | None = None):
continue
sip_frequency = _resolve_sip_frequency(run)
last_run = now.isoformat()
last_run = serialize_timestamp(now)
next_run = compute_next_eligible(last_run, sip_frequency)
next_run_dt = _parse_ts(next_run)
@ -195,7 +192,7 @@ def arm_system(user_id: str, client_ip: str | None = None):
""",
(
now,
Json({"armed_at": now.isoformat()}),
Json({"armed_at": serialize_timestamp(now)}),
user_id,
run["run_id"],
),
@ -339,7 +336,7 @@ def arm_system(user_id: str, client_ip: str | None = None):
pass
broker_state = get_user_broker(user_id) or {}
next_execution = min(next_runs).isoformat() if next_runs else None
next_execution = serialize_timestamp(min(next_runs)) if next_runs else None
return {
"ok": True,
"armed_runs": armed_runs,
@ -378,7 +375,7 @@ def system_status(user_id: str):
"strategy": row[2],
"mode": row[3],
"broker": row[4],
"next_run": row[5].isoformat() if row[5] else None,
"next_run": serialize_timestamp(row[5]),
"active": bool(row[6]) if row[6] is not None else False,
"lifecycle": row[1],
}

View File

@ -1,7 +1,6 @@
from fastapi import HTTPException, Request
from app.services.auth_service import get_user_for_session
from app.services.run_service import get_default_user_id
SESSION_COOKIE_NAME = "session_id"
@ -13,7 +12,4 @@ def get_request_user_id(request: Request) -> str:
if user:
return user["id"]
default_user_id = get_default_user_id()
if default_user_id:
return default_user_id
raise HTTPException(status_code=401, detail="Not authenticated")

View File

@ -27,11 +27,13 @@ class KitePermissionError(KiteApiError):
pass
def build_login_url(api_key: str, redirect_url: str | None = None) -> str:
def build_login_url(api_key: str, redirect_url: str | None = None, state: str | None = None) -> str:
params = {"api_key": api_key, "v": KITE_VERSION}
redirect_url = (redirect_url or os.getenv("ZERODHA_REDIRECT_URL") or "").strip()
if redirect_url:
params["redirect_url"] = redirect_url
if state:
params["state"] = state
query = urllib.parse.urlencode(params)
return f"{KITE_LOGIN_URL}?{query}"

View File

@ -41,3 +41,4 @@ websockets==16.0
yfinance==1.0
alembic==1.13.3
pytest==8.3.5
argon2-cffi==25.1.0

14
backend/tests/conftest.py Normal file
View File

@ -0,0 +1,14 @@
import os
import sys
from pathlib import Path

# Make the backend package and the repository root importable for tests,
# backend root taking precedence over the project root.
BACKEND_ROOT = Path(__file__).resolve().parents[1]
PROJECT_ROOT = Path(__file__).resolve().parents[2]
for _root in (BACKEND_ROOT, PROJECT_ROOT):
    _root_str = str(_root)
    if _root_str not in sys.path:
        sys.path.insert(0, _root_str)

# Provide a deterministic OTP secret unless the environment overrides it.
os.environ.setdefault("RESET_OTP_SECRET", "test-reset-secret")

View File

@ -0,0 +1,111 @@
import importlib
from datetime import datetime, timezone
from pathlib import Path
import pytest
from fastapi.testclient import TestClient
def _build_app(monkeypatch):
    """Reload app.main under a controlled test environment and build the app."""
    settings = {
        "APP_ENV": "test",
        "DISABLE_STARTUP_TASKS": "1",
        "DB_HOST": "localhost",
        "DB_NAME": "trading_db",
        "DB_USER": "trader",
        "DB_PASSWORD": "test-password",
        "CORS_ORIGINS": "http://localhost:3000",
    }
    for key, value in settings.items():
        monkeypatch.setenv(key, value)
    # Reload so module-level configuration re-reads the patched environment.
    import app.main as app_main

    importlib.reload(app_main)
    return app_main.create_app()
def test_strategy_stop_failure_returns_non_200(monkeypatch):
    """A service-level stop failure must surface as HTTP 409, not a 200."""
    app = _build_app(monkeypatch)
    client = TestClient(app)
    import app.routers.strategy as strategy_router

    # Bypass auth and force the service layer to report a failed stop.
    monkeypatch.setattr(strategy_router, "get_request_user_id", lambda _request: "user-1")
    monkeypatch.setattr(
        strategy_router,
        "stop_strategy",
        lambda _user_id: {"status": "stop_failed", "message": "run missing", "run_id": "run-1"},
    )
    response = client.post("/api/strategy/stop", cookies={"session_id": "session-1"})
    assert response.status_code == 409
    assert response.json()["detail"]["status"] == "stop_failed"


def test_strategy_resume_failure_returns_non_200(monkeypatch):
    """A service-level resume failure must surface as HTTP 409, not a 200."""
    app = _build_app(monkeypatch)
    client = TestClient(app)
    import app.routers.strategy as strategy_router

    monkeypatch.setattr(strategy_router, "get_request_user_id", lambda _request: "user-1")
    monkeypatch.setattr(
        strategy_router,
        "resume_strategy",
        lambda _user_id: {"status": "resume_failed", "message": "engine start failed", "run_id": "run-1"},
    )
    response = client.post("/api/strategy/resume", cookies={"session_id": "session-1"})
    assert response.status_code == 409
    assert response.json()["detail"]["status"] == "resume_failed"
def test_update_run_status_returns_false_when_no_row_changes(monkeypatch):
    """update_run_status must report False when the UPDATE matches zero rows."""
    import app.services.run_service as run_service

    class FakeCursor:
        # Simulates an UPDATE that touched no rows.
        rowcount = 0

        def execute(self, sql, params):
            return None

    monkeypatch.setattr(run_service, "run_with_retry", lambda op: op(FakeCursor(), None))
    updated = run_service.update_run_status("user-1", "missing-run", "STOPPED")
    assert updated is False
def test_backend_emits_utc_iso_timestamps():
    """Market-status timestamps round-trip through the shared time utils.

    NOTE(review): asserts an explicit "+00:00" offset (not "Z"), so
    serialize_timestamp is expected to emit ISO-8601 with a numeric offset.
    """
    from app.services.strategy_service import get_market_status
    from indian_paper_trading_strategy.engine.time_utils import (
        parse_persisted_timestamp,
        serialize_timestamp,
    )

    timestamp = get_market_status()["checked_at"]
    parsed = parse_persisted_timestamp(timestamp)
    assert timestamp.endswith("+00:00")
    assert parsed.tzinfo is not None
    # Serialization must be the exact inverse of parsing.
    assert serialize_timestamp(parsed) == timestamp
def test_bootstrap_schema_contains_migrated_core_columns_and_tables():
    """The bootstrap SQL must include every migrated column and table.

    Guards against the bootstrap schema drifting behind the migrations.
    """
    # parents[3] is the directory that contains SIP_GoldBees_Database —
    # equivalent to the old `parents[2] / ".."` without the ".." hop.
    schema_path = Path(__file__).resolve().parents[3] / "SIP_GoldBees_Database" / "schema.sql"
    schema_sql = schema_path.read_text(encoding="utf-8")
    required_snippets = [
        "ALTER TABLE app_session",
        "ADD COLUMN IF NOT EXISTS ip TEXT",
        "ADD COLUMN IF NOT EXISTS user_agent TEXT",
        "ALTER TABLE user_broker",
        "ADD COLUMN IF NOT EXISTS api_secret TEXT",
        "ADD COLUMN IF NOT EXISTS auth_state TEXT",
        "CREATE TABLE IF NOT EXISTS broker_callback_state",
        "CREATE TABLE IF NOT EXISTS execution_claim",
        "CREATE TABLE IF NOT EXISTS run_leases",
        "CREATE TABLE IF NOT EXISTS support_request_audit",
    ]
    # Collect all misses so one failure reports every missing snippet.
    missing = [snippet for snippet in required_snippets if snippet not in schema_sql]
    assert not missing, f"bootstrap schema missing: {missing}"

View File

@ -0,0 +1,128 @@
import importlib
import pytest
from fastapi.testclient import TestClient
def _build_app(monkeypatch, *, app_env="test", cors_origins=None):
    """Reload app.main with a controlled environment and build the app.

    cors_origins=None removes CORS_ORIGINS entirely so the app sees an
    unconfigured environment (used by the fail-closed production test).
    """
    settings = {
        "APP_ENV": app_env,
        "DISABLE_STARTUP_TASKS": "1",
        "DB_HOST": "localhost",
        "DB_NAME": "trading_db",
        "DB_USER": "trader",
        "DB_PASSWORD": "test-password",
    }
    for key, value in settings.items():
        monkeypatch.setenv(key, value)
    if cors_origins is None:
        monkeypatch.delenv("CORS_ORIGINS", raising=False)
    else:
        monkeypatch.setenv("CORS_ORIGINS", cors_origins)
    # Reload so module-level configuration re-reads the patched environment.
    import app.main as app_main

    importlib.reload(app_main)
    return app_main.create_app()
def test_strategy_status_requires_auth(monkeypatch):
    """GET /api/strategy/status without a session cookie must return 401."""
    app = _build_app(monkeypatch)
    client = TestClient(app)
    response = client.get("/api/strategy/status")
    assert response.status_code == 401
    assert response.json() == {"detail": "Not authenticated"}


def test_strategy_stop_requires_auth(monkeypatch):
    """POST /api/strategy/stop without a session cookie must return 401."""
    app = _build_app(monkeypatch)
    client = TestClient(app)
    response = client.post("/api/strategy/stop")
    assert response.status_code == 401
    assert response.json() == {"detail": "Not authenticated"}
def test_strategy_routes_use_session_identity_only(monkeypatch):
    """Routes must resolve the user from the session, ignoring ?user_id=.

    Prevents horizontal privilege escalation via query-string user ids.
    """
    app = _build_app(monkeypatch)
    client = TestClient(app)
    import app.routers.strategy as strategy_router
    import app.services.tenant as tenant

    # Session "session-a" resolves to user-a; the query string claims user-b.
    monkeypatch.setattr(tenant, "get_user_for_session", lambda _sid: {"id": "user-a"})
    seen = {}

    def fake_get_strategy_status(user_id):
        seen["status_user_id"] = user_id
        return {"user_id": user_id}

    def fake_stop_strategy(user_id):
        seen["stop_user_id"] = user_id
        return {"status": "stopped", "user_id": user_id}

    monkeypatch.setattr(strategy_router, "get_strategy_status", fake_get_strategy_status)
    monkeypatch.setattr(strategy_router, "stop_strategy", fake_stop_strategy)
    status_response = client.get(
        "/api/strategy/status?user_id=user-b",
        cookies={"session_id": "session-a"},
    )
    stop_response = client.post(
        "/api/strategy/stop?user_id=user-b",
        cookies={"session_id": "session-a"},
    )
    assert status_response.status_code == 200
    assert stop_response.status_code == 200
    # Both routes must have acted on the session identity only.
    assert status_response.json()["user_id"] == "user-a"
    assert stop_response.json()["user_id"] == "user-a"
    assert seen == {"status_user_id": "user-a", "stop_user_id": "user-a"}
def test_allowed_origin_preflight_supports_credentials(monkeypatch):
    """A configured origin passes CORS preflight with credentials allowed."""
    app = _build_app(
        monkeypatch,
        app_env="production",
        cors_origins="https://app.quantfortune.com",
    )
    client = TestClient(app)
    response = client.options(
        "/api/strategy/status",
        headers={
            "Origin": "https://app.quantfortune.com",
            "Access-Control-Request-Method": "GET",
            "Access-Control-Request-Headers": "content-type",
        },
    )
    assert response.status_code == 200
    assert response.headers["access-control-allow-origin"] == "https://app.quantfortune.com"
    assert response.headers["access-control-allow-credentials"] == "true"


def test_arbitrary_origin_preflight_is_rejected(monkeypatch):
    """An unlisted origin must fail preflight with no CORS allow headers."""
    app = _build_app(
        monkeypatch,
        app_env="production",
        cors_origins="https://app.quantfortune.com",
    )
    client = TestClient(app)
    response = client.options(
        "/api/strategy/status",
        headers={
            "Origin": "https://evil.example",
            "Access-Control-Request-Method": "GET",
            "Access-Control-Request-Headers": "content-type",
        },
    )
    assert response.status_code == 400
    assert "access-control-allow-origin" not in response.headers


def test_production_without_cors_origins_fails_closed(monkeypatch):
    """Production startup must refuse to run without an explicit CORS list."""
    with pytest.raises(RuntimeError, match="CORS_ORIGINS must be configured explicitly in production"):
        _build_app(monkeypatch, app_env="production", cors_origins=None)

View File

@ -0,0 +1,120 @@
from __future__ import annotations
import threading
from datetime import datetime, timezone
from indian_paper_trading_strategy.engine import execution, ledger
def test_claim_execution_window_allows_only_one_winner(monkeypatch):
    """Concurrent claims for one execution window: exactly one wins.

    Models the execution_claim unique INSERT with an in-memory set guarded
    by a lock: a second INSERT for the same key returns no row.
    """
    claims: set[tuple[str, str, datetime]] = set()
    lock = threading.Lock()

    class FakeCursor:
        def __init__(self):
            self._result = None

        def execute(self, sql, params):
            assert "INSERT INTO execution_claim" in sql
            # (params[1], params[2], params[4]) — presumably
            # (user_id, run_id, window_ts) as the claim's unique key;
            # confirm against the real execution_claim DDL.
            key = (params[1], params[2], params[4])
            with lock:
                if key in claims:
                    self._result = None
                else:
                    claims.add(key)
                    self._result = (params[0],)

        def fetchone(self):
            return self._result

    monkeypatch.setattr(ledger, "get_context", lambda user_id=None, run_id=None: ("user-a", "run-a"))
    monkeypatch.setattr(ledger, "run_with_retry", lambda op, retries=None, delay=None: op(FakeCursor(), None))
    logical_time = datetime(2026, 4, 8, 9, 15, tzinfo=timezone.utc)
    results: list[bool] = []

    def attempt():
        results.append(ledger.claim_execution_window(logical_time, mode="LIVE"))

    threads = [threading.Thread(target=attempt) for _ in range(2)]
    for thread in threads:
        thread.start()
    for thread in threads:
        thread.join()
    # One thread wins the claim, the other loses — never both.
    assert sorted(results) == [False, True]
def test_try_execute_sip_live_emits_one_order_batch_when_claim_is_lost(monkeypatch):
    """Two concurrent LIVE executions: the claim loser places no orders.

    The winner emits exactly one batch (one equity + one gold order); the
    loser observes the claim as taken and returns executed=False.
    """
    claim_lock = threading.Lock()
    claimed = False
    order_calls: list[tuple[str, str, float]] = []

    class FakeBroker:
        # external_orders=True — presumably selects the live-order code
        # path rather than simulated fills; confirm against execution.py.
        external_orders = True

        def get_funds(self):
            return {"cash": 10_000.0}

        def place_order(self, symbol, side, qty, price, logical_time=None):
            order_calls.append((symbol, side, qty))
            return {
                "id": f"{symbol}-{len(order_calls)}",
                "symbol": symbol,
                "side": side,
                "status": "COMPLETE",
                "filled_qty": qty,
                "average_price": price,
                "price": price,
            }

    def fake_claim_execution_window(logical_time, *, mode=None, cur=None, user_id=None, run_id=None):
        # First caller wins the claim; every later caller loses.
        nonlocal claimed
        with claim_lock:
            if claimed:
                return False
            claimed = True
        return True

    monkeypatch.setattr(execution, "load_state", lambda *args, **kwargs: {"last_sip_ts": None, "last_run": None})
    monkeypatch.setattr(
        execution,
        "_resolve_timing",
        lambda state, now_ts, sip_interval: (True, None, datetime(2026, 4, 8, 9, 15, tzinfo=timezone.utc)),
    )
    monkeypatch.setattr(execution, "event_exists", lambda *args, **kwargs: False)
    monkeypatch.setattr(execution, "claim_execution_window", fake_claim_execution_window)
    monkeypatch.setattr(execution, "log_event", lambda *args, **kwargs: None)
    monkeypatch.setattr(execution, "run_with_retry", lambda op, retries=None, delay=None: op(object(), None))
    monkeypatch.setattr(
        execution,
        "_finalize_live_execution",
        lambda **kwargs: ({"last_run": kwargs["now_ts"].isoformat()}, bool(kwargs["orders"])),
    )
    results: list[bool] = []

    def attempt():
        _state, executed = execution._try_execute_sip_live(
            now=datetime(2026, 4, 8, 14, 45, tzinfo=timezone.utc),
            market_open=True,
            sip_interval=120,
            sip_amount=1000.0,
            sp_price=125.0,
            gd_price=250.0,
            eq_w=0.5,
            gd_w=0.5,
            broker=FakeBroker(),
            mode="LIVE",
        )
        results.append(executed)

    threads = [threading.Thread(target=attempt) for _ in range(2)]
    for thread in threads:
        thread.start()
    for thread in threads:
        thread.join()
    assert sorted(results) == [False, True]
    # The winner places exactly two orders: one per instrument.
    assert len(order_calls) == 2
    assert {call[0] for call in order_calls} == {"NIFTYBEES.NS", "GOLDBEES.NS"}

View File

@ -0,0 +1,102 @@
from datetime import date, datetime, timezone
import pytest
def test_market_calendar_open_until_close_boundary_exclusive():
    """15:29:59 IST is OPEN; exactly 15:30:00 IST is CLOSED (POST_CLOSE)."""
    from indian_paper_trading_strategy.engine.market_calendar import (
        get_market_session,
        get_market_status,
        is_market_open,
    )

    open_now = datetime(2026, 4, 2, 9, 59, 59, tzinfo=timezone.utc)  # 15:29:59 IST
    close_now = datetime(2026, 4, 2, 10, 0, 0, tzinfo=timezone.utc)  # 15:30:00 IST
    assert get_market_status(open_now) == "OPEN"
    assert is_market_open(open_now) is True
    # The close boundary is exclusive: the close instant itself is CLOSED.
    assert get_market_status(close_now) == "CLOSED"
    assert is_market_open(close_now) is False
    close_session = get_market_session(close_now)
    assert close_session["reason"] == "POST_CLOSE"
def test_market_calendar_holiday_during_trading_hours():
    """Good Friday during trading hours reports HOLIDAY status and reason."""
    from indian_paper_trading_strategy.engine.market_calendar import get_market_session, get_market_status

    now_utc = datetime(2026, 4, 3, 4, 0, tzinfo=timezone.utc)  # Good Friday, 09:30 IST
    session = get_market_session(now_utc)
    assert get_market_status(now_utc) == "HOLIDAY"
    assert session["status"] == "HOLIDAY"
    assert session["reason"] == "HOLIDAY"


def test_market_calendar_utc_and_ist_inputs_match():
    """The same instant expressed in UTC or IST yields an identical session."""
    from zoneinfo import ZoneInfo

    from indian_paper_trading_strategy.engine.market_calendar import get_market_session

    utc_now = datetime(2026, 4, 2, 4, 0, tzinfo=timezone.utc)  # 09:30 IST
    ist_now = utc_now.astimezone(ZoneInfo("Asia/Kolkata"))
    assert get_market_session(utc_now) == get_market_session(ist_now)


def test_market_calendar_reason_codes_cover_session_windows():
    """Each window maps to its reason code: PRE_OPEN/OPEN/POST_CLOSE/WEEKEND."""
    from indian_paper_trading_strategy.engine.market_calendar import get_market_session

    pre_open = datetime(2026, 4, 2, 3, 0, tzinfo=timezone.utc)  # 08:30 IST
    in_session = datetime(2026, 4, 2, 4, 0, tzinfo=timezone.utc)  # 09:30 IST
    post_close = datetime(2026, 4, 2, 10, 1, tzinfo=timezone.utc)  # 15:31 IST
    weekend = datetime(2026, 4, 4, 4, 0, tzinfo=timezone.utc)  # Saturday
    assert get_market_session(pre_open)["reason"] == "PRE_OPEN"
    assert get_market_session(in_session)["reason"] == "OPEN"
    assert get_market_session(post_close)["reason"] == "POST_CLOSE"
    assert get_market_session(weekend)["reason"] == "WEEKEND"
def test_supported_year_returns_expected_holidays():
    """Known 2026 NSE holidays (Good Friday, Republic Day) are present."""
    from indian_paper_trading_strategy.engine.market_calendar import get_nse_holidays

    holidays = get_nse_holidays(2026)
    assert date(2026, 4, 3) in holidays
    assert date(2026, 1, 26) in holidays


def test_unsupported_year_fails_closed():
    """A year with no holiday calendar must report CLOSED, never guess OPEN."""
    from indian_paper_trading_strategy.engine.market_calendar import (
        UnsupportedCalendarYearError,
        get_market_session,
        next_market_open,
    )

    now_utc = datetime(2027, 1, 4, 4, 0, tzinfo=timezone.utc)
    session = get_market_session(now_utc)
    assert session["status"] == "CLOSED"
    assert session["reason"] == "CALENDAR_UNAVAILABLE"
    assert session["calendar_supported"] is False
    # Asking for the next open in an unsupported year raises, not guesses.
    with pytest.raises(UnsupportedCalendarYearError):
        next_market_open(now_utc)
def test_strategy_service_market_status_exposes_reason_and_holiday(monkeypatch):
    """get_market_status surfaces HOLIDAY status/reason and the next open time."""
    import app.services.strategy_service as strategy_service
    from indian_paper_trading_strategy.engine.market_calendar import MARKET_TZ

    # Pin "now" to Good Friday 09:30 IST.
    monkeypatch.setattr(
        strategy_service,
        "market_now",
        lambda: datetime(2026, 4, 3, 9, 30, tzinfo=MARKET_TZ),
    )
    payload = strategy_service.get_market_status()
    assert payload["status"] == "HOLIDAY"
    assert payload["reason"] == "HOLIDAY"
    # Next session: Monday 2026-04-06, 09:15 IST == 03:45 UTC.
    assert payload["next_open_at"] == "2026-04-06T03:45:00+00:00"

View File

@ -0,0 +1,205 @@
from __future__ import annotations
import threading
from datetime import datetime, timedelta, timezone
from indian_paper_trading_strategy.engine import db as engine_db
from indian_paper_trading_strategy.engine import runner
class _LeaseCursor:
    """Minimal DB-API-style cursor fake backed by an in-memory lease table.

    Recognises exactly the SQL statements the run-lease helpers issue
    (insert, select, same-owner refresh, expired takeover, heartbeat,
    delete) and emulates their row-level semantics under a shared lock so
    concurrent test threads behave as if they hit a single database.

    NOTE: dispatch relies on the literal column order in each statement's
    prefix ("SET leased_at" vs "SET owner_id" vs "SET heartbeat_at"), so
    the real SQL and this fake must stay in sync.
    """

    def __init__(self, leases: dict[str, dict], lock: threading.Lock):
        # Shared store: run_id -> lease row dict; guarded by `lock`.
        self._leases = leases
        self._lock = lock
        # Result of the most recent execute(), returned by fetchone().
        self._result = None

    def execute(self, sql, params):
        # Collapse all whitespace so matching is robust to SQL layout.
        sql_text = " ".join(sql.split())
        with self._lock:
            if sql_text.startswith("INSERT INTO run_leases"):
                # INSERT ... ON CONFLICT DO NOTHING: only succeeds when absent.
                run_id, owner_id, leased_at, expires_at, heartbeat_at = params
                lease = self._leases.get(run_id)
                if lease is None:
                    self._leases[run_id] = {
                        "owner_id": owner_id,
                        "leased_at": leased_at,
                        "expires_at": expires_at,
                        "heartbeat_at": heartbeat_at,
                    }
                    self._result = (run_id,)
                else:
                    self._result = None
                return
            if "SELECT owner_id, expires_at FROM run_leases" in sql_text:
                run_id = params[0]
                lease = self._leases.get(run_id)
                self._result = None if lease is None else (lease["owner_id"], lease["expires_at"])
                return
            if sql_text.startswith("UPDATE run_leases SET leased_at"):
                # Same-owner refresh: only the current owner may renew.
                leased_at, expires_at, heartbeat_at, run_id, owner_id = params
                lease = self._leases.get(run_id)
                if lease and lease["owner_id"] == owner_id:
                    lease.update(
                        {
                            "owner_id": owner_id,
                            "leased_at": leased_at,
                            "expires_at": expires_at,
                            "heartbeat_at": heartbeat_at,
                        }
                    )
                    self._result = (run_id,)
                else:
                    self._result = None
                return
            if sql_text.startswith("UPDATE run_leases SET owner_id"):
                # Takeover: only allowed once the existing lease has expired.
                owner_id, leased_at, expires_at, heartbeat_at, run_id, current_time = params
                lease = self._leases.get(run_id)
                if lease and lease["expires_at"] <= current_time:
                    lease.update(
                        {
                            "owner_id": owner_id,
                            "leased_at": leased_at,
                            "expires_at": expires_at,
                            "heartbeat_at": heartbeat_at,
                        }
                    )
                    self._result = (run_id,)
                else:
                    self._result = None
                return
            if sql_text.startswith("UPDATE run_leases SET heartbeat_at"):
                # Heartbeat: owner must still hold an unexpired lease.
                heartbeat_at, expires_at, run_id, owner_id, current_time = params
                lease = self._leases.get(run_id)
                if lease and lease["owner_id"] == owner_id and lease["expires_at"] > current_time:
                    lease.update({"heartbeat_at": heartbeat_at, "expires_at": expires_at})
                    self._result = (run_id, expires_at)
                else:
                    self._result = None
                return
            if sql_text.startswith("DELETE FROM run_leases"):
                # Release: only the owner may delete its own lease row.
                run_id, owner_id = params
                lease = self._leases.get(run_id)
                if lease and lease["owner_id"] == owner_id:
                    del self._leases[run_id]
                    self._result = (run_id,)
                else:
                    self._result = None
                return
            # Any other SQL means production code drifted from the fake.
            raise AssertionError(f"Unexpected SQL: {sql_text}")

    def fetchone(self):
        return self._result
def _patch_lease_storage(monkeypatch):
    """Route engine_db.run_with_retry through an in-memory lease store."""
    lease_store: dict[str, dict] = {}
    store_lock = threading.Lock()

    def fake_run_with_retry(operation, retries=None, delay=None):
        # Each "transaction" gets a fresh cursor over the shared store.
        return operation(_LeaseCursor(lease_store, store_lock), None)

    monkeypatch.setattr(engine_db, "run_with_retry", fake_run_with_retry)
    return lease_store
def test_run_lease_allows_only_one_active_owner(monkeypatch):
    """Two concurrent acquirers race for one lease; exactly one wins."""
    _patch_lease_storage(monkeypatch)
    # Both attempts use the same fixed clock so neither sees an expired lease.
    now = datetime(2026, 4, 8, 4, 0, tzinfo=timezone.utc)
    results = []
    def attempt(owner_id: str):
        results.append(engine_db.acquire_run_lease("run-1", owner_id, lease_seconds=90, now=now))
    threads = [
        threading.Thread(target=attempt, args=("owner-a",)),
        threading.Thread(target=attempt, args=("owner-b",)),
    ]
    for thread in threads:
        thread.start()
    for thread in threads:
        thread.join()
    acquired = [result for result in results if result["acquired"]]
    denied = [result for result in results if not result["acquired"]]
    # Mutual exclusion: one winner, one DENIED loser, regardless of ordering.
    assert len(acquired) == 1
    assert len(denied) == 1
    assert denied[0]["status"] == "DENIED"
def test_expired_run_lease_can_be_reacquired(monkeypatch):
    """A lease past its expiry can be taken over by a different owner."""
    _patch_lease_storage(monkeypatch)
    lease_start = datetime(2026, 4, 8, 4, 0, tzinfo=timezone.utc)
    original = engine_db.acquire_run_lease("run-1", "owner-a", lease_seconds=90, now=lease_start)
    assert original["status"] == "ACQUIRED"
    # One second beyond the 90s window: the takeover path must fire.
    takeover = engine_db.acquire_run_lease(
        "run-1", "owner-b", lease_seconds=90, now=lease_start + timedelta(seconds=91)
    )
    assert takeover["acquired"] is True
    assert takeover["status"] == "REACQUIRED"
    assert takeover["previous_owner"] == "owner-a"
def test_heartbeat_prevents_takeover(monkeypatch):
    """A heartbeat extends the lease, so a later challenger is denied."""
    _patch_lease_storage(monkeypatch)
    base = datetime(2026, 4, 8, 4, 0, tzinfo=timezone.utc)
    engine_db.acquire_run_lease("run-1", "owner-a", lease_seconds=90, now=base)
    # Heartbeat at +30s pushes expiry to +120s, past the challenger's +60s.
    heartbeat = engine_db.heartbeat_run_lease(
        "run-1", "owner-a", lease_seconds=90, now=base + timedelta(seconds=30)
    )
    assert heartbeat["active"] is True
    challenger = engine_db.acquire_run_lease(
        "run-1", "owner-b", lease_seconds=90, now=base + timedelta(seconds=60)
    )
    assert challenger["acquired"] is False
    assert challenger["status"] == "DENIED"
def test_runner_exits_cleanly_when_lease_is_lost(monkeypatch):
    """When the lease refresh fails, the engine loop tears down cleanly.

    All runner collaborators are stubbed; the loop is driven until the
    (always-failing) lease refresh forces an exit, then teardown side
    effects are checked: no spurious status flips, runner slot cleared,
    lease released back under the original owner id.
    """
    statuses: list[str] = []
    cleared: list[tuple[str, str]] = []
    released: list[tuple[str, str]] = []
    refreshed = {"count": 0}
    monkeypatch.setattr(runner, "get_context", lambda user_id=None, run_id=None: ("user-1", "run-1"))
    monkeypatch.setattr(runner, "set_context", lambda user_id, run_id: None)
    monkeypatch.setattr(runner, "log_event", lambda *args, **kwargs: None)
    monkeypatch.setattr(runner, "PaperBroker", lambda initial_cash=0: object())
    monkeypatch.setattr(runner, "_set_state", lambda *args, **kwargs: None)
    monkeypatch.setattr(runner, "_clear_runner", lambda user_id, run_id: cleared.append((user_id, run_id)))
    monkeypatch.setattr(runner, "_update_engine_status", lambda user_id, run_id, status: statuses.append(status))
    # `or True` makes the recording lambda also report a successful release.
    monkeypatch.setattr(runner, "release_run_lease", lambda run_id, owner_id: released.append((run_id, owner_id)) or True)
    monkeypatch.setattr(runner, "_log_runner_lease_event", lambda *args, **kwargs: None)
    def fake_refresh(user_id, run_id, owner_id):
        # Simulate a lost lease: every refresh attempt reports failure.
        refreshed["count"] += 1
        return False
    monkeypatch.setattr(runner, "_refresh_run_lease_or_stop", fake_refresh)
    stop_event = threading.Event()
    runner._engine_loop(
        {
            "user_id": "user-1",
            "run_id": "run-1",
            "strategy": "Golden Nifty",
            "sip_amount": 1000,
            "sip_frequency": {"value": 2, "unit": "minutes"},
            "mode": "PAPER",
            "broker": "paper",
            "runner_owner_id": "owner-1",
        },
        stop_event,
    )
    assert refreshed["count"] >= 1
    # Only the initial RUNNING transition; the lease loss must not add more.
    assert statuses == ["RUNNING"]
    assert cleared == [("user-1", "run-1")]
    assert released == [("run-1", "owner-1")]

View File

@ -0,0 +1,269 @@
import importlib
from datetime import datetime, timedelta, timezone
import pytest
from fastapi import Response
from fastapi.testclient import TestClient
def test_legacy_password_hash_upgrades_on_successful_login(monkeypatch):
    """A correct login against a legacy hash rewrites it as Argon2id.

    Verifies opportunistic hash migration: the stored record is updated
    both in the persistence layer (stubbed) and on the in-memory user.
    """
    import app.services.auth_service as auth_service
    legacy_hash = auth_service._hash_password_legacy("correct-horse-battery-staple")
    user = {
        "id": "user-1",
        "username": "user@example.com",
        "password": legacy_hash,
        "role": "USER",
    }
    updated = {}
    monkeypatch.setattr(auth_service, "get_user_by_username", lambda username: user if username == user["username"] else None)
    monkeypatch.setattr(
        auth_service,
        "_update_password_hash",
        lambda user_id, password_hash: updated.update({"user_id": user_id, "password_hash": password_hash}),
    )
    authenticated = auth_service.authenticate_user(user["username"], "correct-horse-battery-staple")
    assert authenticated is user
    assert updated["user_id"] == "user-1"
    # New hash must use the Argon2id PHC prefix.
    assert updated["password_hash"].startswith("$argon2id$")
    assert authenticated["password"] == updated["password_hash"]
def test_secure_session_cookie_flags_are_enforced_in_production(monkeypatch):
    """In production the session cookie carries Secure/HttpOnly/SameSite.

    COOKIE_SECURE is deliberately unset so the Secure flag must come from
    the production default, not explicit configuration.
    """
    monkeypatch.setenv("APP_ENV", "production")
    monkeypatch.delenv("COOKIE_SECURE", raising=False)
    monkeypatch.setenv("COOKIE_SAMESITE", "strict")
    import app.routers.auth as auth_router
    # Reload so module-level cookie settings re-read the patched env.
    importlib.reload(auth_router)
    response = Response()
    auth_router._set_session_cookie(response, "session-123")
    set_cookie = response.headers["set-cookie"].lower()
    assert "secure" in set_cookie
    assert "httponly" in set_cookie
    assert "samesite=strict" in set_cookie
def test_missing_reset_otp_secret_fails_auth_service_import(monkeypatch):
    """auth_service refuses to load without RESET_OTP_SECRET configured."""
    monkeypatch.delenv("RESET_OTP_SECRET", raising=False)
    import app.services.auth_service as auth_service
    with pytest.raises(RuntimeError, match="RESET_OTP_SECRET must be configured"):
        importlib.reload(auth_service)
    # Restore a secret and reload so the module is usable for later tests.
    # NOTE(review): monkeypatch undoes the env var at teardown but not the
    # reload itself, so the module keeps this test's state — confirm no
    # later test depends on the original module-level secret.
    monkeypatch.setenv("RESET_OTP_SECRET", "test-reset-secret")
    importlib.reload(auth_service)
def test_valid_broker_callback_state_allows_connect_callback(monkeypatch):
    """Happy-path Zerodha OAuth callback with a valid anti-CSRF state.

    Session lookup, state consumption, pending credentials, and the token
    exchange are all stubbed; the test asserts the JSON response shape and
    that the connected-broker record is persisted for the right user.
    """
    import app.main as app_main
    import app.routers.broker as broker_router
    monkeypatch.setenv("APP_ENV", "test")
    monkeypatch.setenv("DISABLE_STARTUP_TASKS", "1")
    monkeypatch.setenv("CORS_ORIGINS", "http://localhost:3000")
    importlib.reload(app_main)
    app = app_main.create_app()
    client = TestClient(app)
    monkeypatch.setattr(broker_router, "get_user_for_session", lambda _sid: {"id": "user-1", "username": "user@example.com"})
    # State check passes: return an unconsumed, unexpired state record.
    monkeypatch.setattr(
        broker_router,
        "consume_broker_callback_state",
        lambda **kwargs: {"id": "state-1", "expires_at": datetime.now(timezone.utc).isoformat()},
    )
    monkeypatch.setattr(
        broker_router,
        "get_pending_broker",
        lambda _user_id: {"api_key": "kite-key", "api_secret": "kite-secret"},
    )
    # Stub the Kite token exchange so no network call happens.
    monkeypatch.setattr(
        broker_router,
        "exchange_request_token",
        lambda api_key, api_secret, token: {
            "access_token": "access-token",
            "request_token": token,
            "user_name": "Trader",
            "user_id": "Z123",
        },
    )
    captured = {}
    monkeypatch.setattr(
        broker_router,
        "set_zerodha_session",
        lambda user_id, payload: captured.setdefault("zerodha_session", {"user_id": user_id, **payload}),
    )
    monkeypatch.setattr(
        broker_router,
        "set_connected_broker",
        lambda user_id, broker, token, **kwargs: captured.setdefault(
            "connected",
            {"user_id": user_id, "broker": broker, "token": token, **kwargs},
        ),
    )
    response = client.get(
        "/api/broker/zerodha/callback",
        params={"request_token": "request-token", "state": "valid-state"},
        cookies={"session_id": "session-1"},
    )
    assert response.status_code == 200
    assert response.json() == {
        "connected": True,
        "userName": "Trader",
        "brokerUserId": "Z123",
    }
    assert captured["connected"]["broker"] == "ZERODHA"
    assert captured["connected"]["user_id"] == "user-1"
def test_missing_broker_callback_state_fails(monkeypatch):
    """A callback without the `state` query parameter is rejected with 400."""
    import app.main as app_main
    import app.routers.broker as broker_router
    monkeypatch.setenv("APP_ENV", "test")
    monkeypatch.setenv("DISABLE_STARTUP_TASKS", "1")
    monkeypatch.setenv("CORS_ORIGINS", "http://localhost:3000")
    importlib.reload(app_main)
    app = app_main.create_app()
    client = TestClient(app)
    # Authenticated session, but no state param in the callback URL.
    monkeypatch.setattr(broker_router, "get_user_for_session", lambda _sid: {"id": "user-1", "username": "user@example.com"})
    response = client.get(
        "/api/broker/zerodha/callback",
        params={"request_token": "request-token"},
        cookies={"session_id": "session-1"},
    )
    assert response.status_code == 400
    assert response.json() == {"detail": "Missing state"}
def test_wrong_or_expired_broker_callback_state_fails(monkeypatch):
    """An unknown or expired `state` value is rejected with 401."""
    import app.main as app_main
    import app.routers.broker as broker_router
    monkeypatch.setenv("APP_ENV", "test")
    monkeypatch.setenv("DISABLE_STARTUP_TASKS", "1")
    monkeypatch.setenv("CORS_ORIGINS", "http://localhost:3000")
    importlib.reload(app_main)
    app = app_main.create_app()
    client = TestClient(app)
    monkeypatch.setattr(broker_router, "get_user_for_session", lambda _sid: {"id": "user-1", "username": "user@example.com"})
    # State consumption fails (None): wrong value or past its expiry.
    monkeypatch.setattr(broker_router, "consume_broker_callback_state", lambda **kwargs: None)
    response = client.get(
        "/api/broker/zerodha/callback",
        params={"request_token": "request-token", "state": "wrong-or-expired"},
        cookies={"session_id": "session-1"},
    )
    assert response.status_code == 401
    assert response.json() == {"detail": "Invalid or expired broker callback state"}
def test_broker_callback_state_service_rejects_wrong_user_and_expired_state(monkeypatch):
    """consume_broker_callback_state is single-use, user-bound, and expiring.

    A fake cursor emulates the atomic UPDATE...RETURNING consume query:
    a row matches only when the hash, user, session, broker, and flow all
    agree, it is not yet consumed, and it has not expired.  Three probes
    cover wrong-user, expired, and the single valid consume.
    """
    import app.services.broker_callback_state as callback_state
    now = datetime(2026, 4, 8, 9, 0, tzinfo=timezone.utc)
    rows = [
        {
            "state_hash": callback_state._state_hash("valid-state"),
            "user_id": "user-1",
            "session_id": "session-1",
            "broker": "ZERODHA",
            "flow": "connect",
            "expires_at": now + timedelta(minutes=5),
            "consumed_at": None,
        },
        {
            "state_hash": callback_state._state_hash("expired-state"),
            "user_id": "user-1",
            "session_id": "session-1",
            "broker": "ZERODHA",
            "flow": "connect",
            "expires_at": now - timedelta(seconds=1),
            "consumed_at": None,
        },
    ]
    class FakeCursor:
        def __init__(self):
            self._result = None
        def __enter__(self):
            return self
        def __exit__(self, exc_type, exc, tb):
            return False
        def execute(self, sql, params):
            if "UPDATE broker_callback_state" not in sql:
                raise AssertionError("Unexpected SQL")
            # Positional unpack mirrors the real query's parameter order.
            consumed_at, state_hash, user_id, session_id, broker, flow, now_param = params
            self._result = None
            for row in rows:
                if (
                    row["state_hash"] == state_hash
                    and row["user_id"] == user_id
                    and row["session_id"] == session_id
                    and row["broker"] == broker
                    and row["flow"] == flow
                    and row["consumed_at"] is None
                    and row["expires_at"] > now_param
                ):
                    # Marking consumed_at makes the state single-use.
                    row["consumed_at"] = consumed_at
                    self._result = ("row-id", row["expires_at"])
                    break
        def fetchone(self):
            return self._result
    class FakeConnection:
        def __enter__(self):
            return self
        def __exit__(self, exc_type, exc, tb):
            return False
        def cursor(self):
            return FakeCursor()
    monkeypatch.setattr(callback_state, "_now_utc", lambda: now)
    monkeypatch.setattr(callback_state, "db_connection", lambda: FakeConnection())
    # Right state, wrong user: must not consume.
    assert callback_state.consume_broker_callback_state(
        state="valid-state",
        user_id="user-2",
        session_id="session-1",
        broker="ZERODHA",
        flow="connect",
    ) is None
    # Expired state: must not consume.
    assert callback_state.consume_broker_callback_state(
        state="expired-state",
        user_id="user-1",
        session_id="session-1",
        broker="ZERODHA",
        flow="connect",
    ) is None
    # Correct state + binding: consumes exactly once.
    assert callback_state.consume_broker_callback_state(
        state="valid-state",
        user_id="user-1",
        session_id="session-1",
        broker="ZERODHA",
        flow="connect",
    ) == {
        "id": "row-id",
        "expires_at": (now + timedelta(minutes=5)).isoformat(),
    }

View File

@ -0,0 +1,131 @@
import importlib
import logging
from fastapi.testclient import TestClient
def _build_app(monkeypatch):
    """Reload app.main under a deterministic test environment and build the app."""
    test_env = {
        "APP_ENV": "test",
        "DISABLE_STARTUP_TASKS": "1",
        "DB_HOST": "localhost",
        "DB_NAME": "trading_db",
        "DB_USER": "trader",
        "DB_PASSWORD": "test-password",
        "CORS_ORIGINS": "http://localhost:3000",
        "SUPPORT_GUARD_BACKEND": "memory",
        "SUPPORT_GUARD_WINDOW_SECONDS": "900",
        "SUPPORT_CREATE_LIMIT": "2",
        "SUPPORT_STATUS_LIMIT": "3",
        "SUPPORT_STATUS_TICKET_LIMIT": "2",
    }
    for name, value in test_env.items():
        monkeypatch.setenv(name, value)
    import app.main as app_main
    # Reload so module-level settings re-read the patched environment.
    importlib.reload(app_main)
    return app_main.create_app()
def test_support_ticket_creation_is_throttled(monkeypatch, caplog):
    """The third ticket-create within the window hits the rate limit (429)."""
    app = _build_app(monkeypatch)
    client = TestClient(app)
    import app.services.support_abuse as support_abuse
    import app.routers.support_ticket as support_router
    # Start from a clean in-memory guard so earlier tests don't leak counts.
    support_abuse.reset_memory_support_guard_state()
    monkeypatch.setattr(
        support_router,
        "create_ticket",
        lambda **kwargs: {"ticket_id": "ticket-1", "status": "NEW", "created_at": "2026-04-08T00:00:00+00:00"},
    )
    payload = {
        "name": "Trader",
        "email": "trader@example.com",
        "subject": "Need help",
        "message": "Something happened",
    }
    with caplog.at_level(logging.WARNING):
        first = client.post("/api/support/ticket", json=payload)
        second = client.post("/api/support/ticket", json=payload)
        third = client.post("/api/support/ticket", json=payload)
    # SUPPORT_CREATE_LIMIT=2 in _build_app: two pass, the third is blocked.
    assert first.status_code == 200
    assert second.status_code == 200
    assert third.status_code == 429
    assert "Support request blocked" in caplog.text
def test_invalid_ticket_probing_is_throttled(monkeypatch):
    """Repeated lookups of a nonexistent ticket get throttled, not enumerated."""
    app = _build_app(monkeypatch)
    client = TestClient(app)
    import app.services.support_abuse as support_abuse
    import app.routers.support_ticket as support_router
    support_abuse.reset_memory_support_guard_state()
    # Every lookup misses: simulates an attacker probing ticket ids.
    monkeypatch.setattr(support_router, "get_ticket_status", lambda ticket_id, email: None)
    payload = {"email": "trader@example.com"}
    first = client.post("/api/support/ticket/status/unknown-ticket", json=payload)
    second = client.post("/api/support/ticket/status/unknown-ticket", json=payload)
    third = client.post("/api/support/ticket/status/unknown-ticket", json=payload)
    # SUPPORT_STATUS_TICKET_LIMIT=2: two misses allowed, then 429.
    assert first.status_code == 404
    assert second.status_code == 404
    assert third.status_code == 429
def test_legitimate_status_lookup_still_works(monkeypatch):
    """A valid ticket/email pair passes the guard and returns its status."""
    app = _build_app(monkeypatch)
    client = TestClient(app)
    import app.services.support_abuse as support_abuse
    import app.routers.support_ticket as support_router
    support_abuse.reset_memory_support_guard_state()
    monkeypatch.setattr(
        support_router,
        "get_ticket_status",
        lambda ticket_id, email: {
            "ticket_id": ticket_id,
            "status": "NEW",
            "created_at": "2026-04-08T00:00:00+00:00",
            "updated_at": "2026-04-08T00:00:00+00:00",
        },
    )
    response = client.post("/api/support/ticket/status/ticket-1", json={"email": "trader@example.com"})
    assert response.status_code == 200
    assert response.json()["ticket_id"] == "ticket-1"
def test_support_captcha_hook_blocks_without_matching_header(monkeypatch):
    """With a captcha secret configured, requests lacking the header get 403."""
    # Secret must be set BEFORE _build_app reloads app.main.
    monkeypatch.setenv("SUPPORT_CAPTCHA_SECRET", "expected-captcha")
    app = _build_app(monkeypatch)
    client = TestClient(app)
    import app.services.support_abuse as support_abuse
    import app.routers.support_ticket as support_router
    support_abuse.reset_memory_support_guard_state()
    # create_ticket is stubbed but must never be reached in this test.
    monkeypatch.setattr(
        support_router,
        "create_ticket",
        lambda **kwargs: {"ticket_id": "ticket-1", "status": "NEW", "created_at": "2026-04-08T00:00:00+00:00"},
    )
    response = client.post(
        "/api/support/ticket",
        json={
            "name": "Trader",
            "email": "trader@example.com",
            "subject": "Need help",
            "message": "Something happened",
        },
    )
    assert response.status_code == 403
    assert response.json() == {"detail": "Support verification failed"}

View File

@ -1,37 +1,48 @@
import os
import threading
import time
from contextlib import contextmanager
from datetime import datetime, timezone
from contextvars import ContextVar
import os
import threading
import time
from contextlib import contextmanager
from datetime import datetime, timedelta, timezone
from contextvars import ContextVar
import psycopg2
from psycopg2 import pool
from psycopg2 import OperationalError, InterfaceError
from psycopg2.extras import Json
_POOL = None
_POOL_LOCK = threading.Lock()
_DEFAULT_USER_ID = None
_DEFAULT_LOCK = threading.Lock()
_POOL = None
_POOL_LOCK = threading.Lock()
_DEFAULT_USER_ID = None
_DEFAULT_LOCK = threading.Lock()
NON_PROD_ENVIRONMENTS = {"development", "dev", "test", "testing", "local"}
_USER_ID = ContextVar("engine_user_id", default=None)
_RUN_ID = ContextVar("engine_run_id", default=None)
def _db_config():
env_name = (os.getenv("APP_ENV") or os.getenv("ENVIRONMENT") or os.getenv("FASTAPI_ENV") or "development").strip().lower()
is_non_prod = env_name in NON_PROD_ENVIRONMENTS
url = os.getenv("DATABASE_URL")
if url:
return {"dsn": url}
schema = os.getenv("DB_SCHEMA") or os.getenv("PGSCHEMA") or "quant_app"
password = os.getenv("DB_PASSWORD") or os.getenv("PGPASSWORD")
host = os.getenv("DB_HOST") or os.getenv("PGHOST") or ("localhost" if is_non_prod else None)
dbname = os.getenv("DB_NAME") or os.getenv("PGDATABASE") or ("trading_db" if is_non_prod else None)
user = os.getenv("DB_USER") or os.getenv("PGUSER") or ("trader" if is_non_prod else None)
if not is_non_prod and not password:
raise RuntimeError("DB_PASSWORD or PGPASSWORD must be configured in non-development environments")
if not is_non_prod and (not host or not dbname or not user):
raise RuntimeError("DB_HOST, DB_NAME, and DB_USER must be configured in non-development environments")
return {
"host": os.getenv("DB_HOST") or os.getenv("PGHOST") or "localhost",
"host": host,
"port": int(os.getenv("DB_PORT") or os.getenv("PGPORT") or "5432"),
"dbname": os.getenv("DB_NAME") or os.getenv("PGDATABASE") or "trading_db",
"user": os.getenv("DB_USER") or os.getenv("PGUSER") or "trader",
"password": os.getenv("DB_PASSWORD") or os.getenv("PGPASSWORD") or "traderpass",
"dbname": dbname,
"user": user,
"password": password,
"connect_timeout": int(os.getenv("DB_CONNECT_TIMEOUT", "5")),
"options": f"-csearch_path={schema},public" if schema else None,
}
@ -295,7 +306,7 @@ def get_running_runs(user_id: str | None = None):
return run_with_retry(_op)
def insert_engine_event(
def insert_engine_event(
cur,
event: str,
data=None,
@ -307,10 +318,10 @@ def insert_engine_event(
):
when = ts or _utc_now()
scope_user, scope_run = _resolve_context(user_id, run_id)
cur.execute(
"""
INSERT INTO engine_event (user_id, run_id, ts, event, data, message, meta)
VALUES (%s, %s, %s, %s, %s, %s, %s)
cur.execute(
"""
INSERT INTO engine_event (user_id, run_id, ts, event, data, message, meta)
VALUES (%s, %s, %s, %s, %s, %s, %s)
""",
(
scope_user,
@ -320,5 +331,152 @@ def insert_engine_event(
Json(data) if data is not None else None,
message,
Json(meta) if meta is not None else None,
),
)
),
)
def acquire_run_lease(
    run_id: str,
    owner_id: str,
    *,
    lease_seconds: int = 90,
    now: datetime | None = None,
):
    """Try to take (or renew) the exclusive lease for *run_id*.

    Resolution order inside one transaction:
      1. Insert a fresh lease (ON CONFLICT DO NOTHING) -> ACQUIRED.
      2. If a row exists and we already own it -> refresh -> REFRESHED.
      3. If the row's lease has expired -> take it over -> REACQUIRED
         (includes "previous_owner").
      4. Otherwise -> DENIED with the current holder and its expiry.

    Returns a dict with at least "acquired", "status", "owner_id",
    "expires_at".  `now` is injectable for deterministic tests.
    """
    current_time = now or _utc_now()
    expires_at = current_time + timedelta(seconds=lease_seconds)
    def _op(cur, _conn):
        cur.execute(
            """
            INSERT INTO run_leases (run_id, owner_id, leased_at, expires_at, heartbeat_at)
            VALUES (%s, %s, %s, %s, %s)
            ON CONFLICT (run_id) DO NOTHING
            RETURNING run_id
            """,
            (run_id, owner_id, current_time, expires_at, current_time),
        )
        inserted = cur.fetchone()
        if inserted:
            return {
                "acquired": True,
                "status": "ACQUIRED",
                "owner_id": owner_id,
                "expires_at": expires_at,
            }
        # Lock the existing row so refresh/takeover decisions are race-free.
        cur.execute(
            """
            SELECT owner_id, expires_at
            FROM run_leases
            WHERE run_id = %s
            FOR UPDATE
            """,
            (run_id,),
        )
        row = cur.fetchone()
        if not row:
            # Row vanished between INSERT and SELECT; treat as contention.
            return {
                "acquired": False,
                "status": "DENIED",
                "owner_id": None,
                "expires_at": None,
            }
        current_owner, current_expiry = row
        if current_owner == owner_id:
            # We already hold the lease: extend it in place.
            cur.execute(
                """
                UPDATE run_leases
                SET leased_at = %s,
                    expires_at = %s,
                    heartbeat_at = %s
                WHERE run_id = %s AND owner_id = %s
                RETURNING run_id
                """,
                (current_time, expires_at, current_time, run_id, owner_id),
            )
            cur.fetchone()
            return {
                "acquired": True,
                "status": "REFRESHED",
                "owner_id": owner_id,
                "expires_at": expires_at,
            }
        if current_expiry <= current_time:
            # Holder went silent past expiry: conditional takeover guarded
            # by "expires_at <= now" to stay correct under concurrency.
            cur.execute(
                """
                UPDATE run_leases
                SET owner_id = %s,
                    leased_at = %s,
                    expires_at = %s,
                    heartbeat_at = %s
                WHERE run_id = %s AND expires_at <= %s
                RETURNING run_id
                """,
                (owner_id, current_time, expires_at, current_time, run_id, current_time),
            )
            replaced = cur.fetchone()
            if replaced:
                return {
                    "acquired": True,
                    "status": "REACQUIRED",
                    "owner_id": owner_id,
                    "previous_owner": current_owner,
                    "expires_at": expires_at,
                }
        return {
            "acquired": False,
            "status": "DENIED",
            "owner_id": current_owner,
            "expires_at": current_expiry,
        }
    return run_with_retry(_op)
def heartbeat_run_lease(
    run_id: str,
    owner_id: str,
    *,
    lease_seconds: int = 90,
    now: datetime | None = None,
):
    """Extend an owned, unexpired lease; report whether it is still active.

    The UPDATE only matches while we are still the owner AND the lease has
    not yet expired, so a lost or taken-over lease returns active=False
    instead of being silently resurrected.
    """
    current_time = now or _utc_now()
    expires_at = current_time + timedelta(seconds=lease_seconds)
    def _op(cur, _conn):
        cur.execute(
            """
            UPDATE run_leases
            SET heartbeat_at = %s,
                expires_at = %s
            WHERE run_id = %s
              AND owner_id = %s
              AND expires_at > %s
            RETURNING run_id, expires_at
            """,
            (current_time, expires_at, run_id, owner_id, current_time),
        )
        row = cur.fetchone()
        if not row:
            return {"active": False, "expires_at": None}
        return {"active": True, "expires_at": row[1]}
    return run_with_retry(_op)
def release_run_lease(run_id: str, owner_id: str):
    """Delete the lease if we still own it; True when a row was removed.

    Owner-scoped delete: a lease that was already taken over by another
    owner is left untouched and False is returned.
    """
    def _op(cur, _conn):
        cur.execute(
            """
            DELETE FROM run_leases
            WHERE run_id = %s AND owner_id = %s
            RETURNING run_id
            """,
            (run_id, owner_id),
        )
        return cur.fetchone() is not None
    return run_with_retry(_op)

View File

@ -3,7 +3,7 @@ from datetime import datetime, timezone
from indian_paper_trading_strategy.engine.state import load_state, save_state
from indian_paper_trading_strategy.engine.broker import Broker, BrokerAuthExpired
from indian_paper_trading_strategy.engine.ledger import log_event, event_exists
from indian_paper_trading_strategy.engine.ledger import claim_execution_window, log_event, event_exists
from indian_paper_trading_strategy.engine.db import insert_engine_event, run_with_retry
from indian_paper_trading_strategy.engine.market import market_now
from indian_paper_trading_strategy.engine.time_utils import compute_logical_time
@ -237,7 +237,7 @@ def _prepare_live_execution(now_ts, sip_interval, sip_amount_val, sp_price_val,
return {"ready": False, "state": state}
if event_exists("SIP_EXECUTED", logical_time, cur=cur):
return {"ready": False, "state": state}
if event_exists("SIP_ORDER_ATTEMPTED", logical_time, cur=cur):
if not claim_execution_window(logical_time, mode=mode, cur=cur):
return {"ready": False, "state": state}
log_event(

View File

@ -1,8 +1,9 @@
# engine/ledger.py
from datetime import datetime, timezone
from indian_paper_trading_strategy.engine.db import insert_engine_event, run_with_retry, get_context
from indian_paper_trading_strategy.engine.time_utils import normalize_logical_time
import uuid
from datetime import datetime, timezone
from indian_paper_trading_strategy.engine.db import insert_engine_event, run_with_retry, get_context
from indian_paper_trading_strategy.engine.time_utils import normalize_logical_time
def _event_exists_in_tx(cur, event, logical_time, user_id: str | None = None, run_id: str | None = None):
@ -20,14 +21,80 @@ def _event_exists_in_tx(cur, event, logical_time, user_id: str | None = None, ru
return cur.fetchone() is not None
def event_exists(event, logical_time, *, cur=None, user_id: str | None = None, run_id: str | None = None):
def event_exists(event, logical_time, *, cur=None, user_id: str | None = None, run_id: str | None = None):
if cur is not None:
return _event_exists_in_tx(cur, event, logical_time, user_id=user_id, run_id=run_id)
def _op(cur, _conn):
return _event_exists_in_tx(cur, event, logical_time, user_id=user_id, run_id=run_id)
return run_with_retry(_op)
return run_with_retry(_op)
def _claim_execution_window_in_tx(
    cur,
    logical_time,
    *,
    mode: str | None = None,
    user_id: str | None = None,
    run_id: str | None = None,
):
    """Insert an execution claim row inside the caller's transaction.

    Idempotency anchor for SIP execution: the unique key
    (user_id, run_id, logical_time) plus ON CONFLICT DO NOTHING means at
    most one worker ever wins a given logical window.  Returns True when
    this call inserted the claim, False when it already existed.
    """
    # Fall back to the engine contextvars when ids are not passed in.
    scope_user, scope_run = get_context(user_id, run_id)
    logical_ts = normalize_logical_time(logical_time)
    claim_id = str(uuid.uuid4())
    cur.execute(
        """
        INSERT INTO execution_claim (
            id,
            user_id,
            run_id,
            mode,
            logical_time,
            claimed_at
        )
        VALUES (%s, %s, %s, %s, %s, %s)
        ON CONFLICT (user_id, run_id, logical_time) DO NOTHING
        RETURNING id
        """,
        (
            claim_id,
            scope_user,
            scope_run,
            (mode or "LIVE").strip().upper(),  # normalized; defaults to LIVE
            logical_ts,
            datetime.now(timezone.utc),
        ),
    )
    # RETURNING yields a row only when the INSERT actually happened.
    return cur.fetchone() is not None
def claim_execution_window(
    logical_time,
    *,
    mode: str | None = None,
    cur=None,
    user_id: str | None = None,
    run_id: str | None = None,
):
    """Atomically claim the execution slot for *logical_time*.

    Returns True when this caller won the claim, False when another
    worker already holds it.  Runs inside the supplied cursor when one
    is given, otherwise in its own retried transaction.
    """
    def _claim(tx_cur):
        return _claim_execution_window_in_tx(
            tx_cur,
            logical_time,
            mode=mode,
            user_id=user_id,
            run_id=run_id,
        )

    if cur is not None:
        return _claim(cur)
    return run_with_retry(lambda tx_cur, _conn: _claim(tx_cur))
def _log_event_in_tx(

View File

@ -1,46 +1,51 @@
# engine/market.py
from datetime import datetime, time as dtime, timedelta
import pytz
from __future__ import annotations
_MARKET_TZ = pytz.timezone("Asia/Kolkata")
_OPEN_T = dtime(9, 15)
_CLOSE_T = dtime(15, 30)
from datetime import datetime
from indian_paper_trading_strategy.engine.market_calendar import (
MARKET_TZ,
get_market_session as _get_market_session,
get_market_status as _get_market_status,
is_market_open as _is_market_open,
market_now_utc,
next_market_open as _next_market_open,
)
def market_now() -> datetime:
return datetime.now(_MARKET_TZ)
return market_now_utc().astimezone(MARKET_TZ)
def _as_market_tz(value: datetime) -> datetime:
def _to_utc(value: datetime) -> datetime:
if value.tzinfo is None:
return _MARKET_TZ.localize(value)
return value.astimezone(_MARKET_TZ)
def is_market_open(now: datetime) -> bool:
now = _as_market_tz(now)
return now.weekday() < 5 and _OPEN_T <= now.time() <= _CLOSE_T
return value.replace(tzinfo=MARKET_TZ).astimezone(market_now_utc().tzinfo)
return value.astimezone(market_now_utc().tzinfo)
def is_market_open(now: datetime) -> bool:
return _is_market_open(_to_utc(now))
def market_session(now: datetime) -> dict[str, object]:
return _get_market_session(_to_utc(now))
def market_status(now: datetime) -> str:
return _get_market_status(_to_utc(now))
def india_market_status():
now = market_now()
return is_market_open(now), now
def next_market_open_after(value: datetime) -> datetime:
current = _as_market_tz(value)
while current.weekday() >= 5:
current = current + timedelta(days=1)
current = current.replace(hour=_OPEN_T.hour, minute=_OPEN_T.minute, second=0, microsecond=0)
if current.time() < _OPEN_T:
return current.replace(hour=_OPEN_T.hour, minute=_OPEN_T.minute, second=0, microsecond=0)
if current.time() > _CLOSE_T:
current = current + timedelta(days=1)
while current.weekday() >= 5:
current = current + timedelta(days=1)
return current.replace(hour=_OPEN_T.hour, minute=_OPEN_T.minute, second=0, microsecond=0)
return current
def align_to_market_open(value: datetime) -> datetime:
current = _as_market_tz(value)
aligned = current if is_market_open(current) else next_market_open_after(current)
if value.tzinfo is None:
return aligned.replace(tzinfo=None)
return aligned
def next_market_open_after(value: datetime) -> datetime:
aligned_utc = _next_market_open(_to_utc(value))
aligned_ist = aligned_utc.astimezone(MARKET_TZ)
if value.tzinfo is None:
return aligned_ist.replace(tzinfo=None)
return aligned_ist
def align_to_market_open(value: datetime) -> datetime:
return next_market_open_after(value)

View File

@ -0,0 +1,201 @@
from __future__ import annotations
import logging
import os
from datetime import date, datetime, time as dtime, timedelta, timezone
from functools import lru_cache
from zoneinfo import ZoneInfo
logger = logging.getLogger(__name__)
UTC = timezone.utc
MARKET_TZ = ZoneInfo("Asia/Kolkata")
MARKET_OPEN = dtime(9, 15)
MARKET_CLOSE = dtime(15, 30)
STATUS_OPEN = "OPEN"
STATUS_CLOSED = "CLOSED"
STATUS_HOLIDAY = "HOLIDAY"
REASON_OPEN = "OPEN"
REASON_HOLIDAY = "HOLIDAY"
REASON_WEEKEND = "WEEKEND"
REASON_PRE_OPEN = "PRE_OPEN"
REASON_POST_CLOSE = "POST_CLOSE"
REASON_CALENDAR_UNAVAILABLE = "CALENDAR_UNAVAILABLE"
# Capital market trading holidays for NSE calendar year 2026.
# Source: NSE circular dated December 12, 2025.
DEFAULT_NSE_HOLIDAYS_BY_YEAR = {
2026: frozenset(
{
date(2026, 1, 26), # Republic Day
date(2026, 3, 3), # Holi
date(2026, 3, 26), # Shri Ram Navami
date(2026, 3, 31), # Shri Mahavir Jayanti
date(2026, 4, 3), # Good Friday
date(2026, 4, 14), # Dr. Baba Saheb Ambedkar Jayanti
date(2026, 5, 1), # Maharashtra Day
date(2026, 5, 28), # Bakri Id
date(2026, 6, 26), # Muharram
date(2026, 9, 14), # Ganesh Chaturthi
date(2026, 10, 2), # Mahatma Gandhi Jayanti
date(2026, 10, 20), # Dussehra
date(2026, 11, 10), # Diwali-Balipratipada
date(2026, 11, 24), # Prakash Gurpurb Sri Guru Nanak Dev
date(2026, 12, 25), # Christmas
}
),
}
class UnsupportedCalendarYearError(RuntimeError):
    """Raised when no NSE holiday calendar is configured for a year."""

    def __init__(self, year: int):
        message = f"NSE holiday calendar is not configured for {year}"
        super().__init__(message)
        # Keep the offending year for structured logging by callers.
        self.year = year
def market_now_utc() -> datetime:
    """Current wall-clock time as an aware UTC datetime."""
    return datetime.now(timezone.utc)
def _as_utc(value: datetime | None) -> datetime:
    """Coerce *value* to an aware UTC datetime; None means "now".

    Naive datetimes are treated as already-UTC and merely tagged (no
    shift); aware datetimes are converted.
    """
    if value is None:
        return market_now_utc()
    return value.replace(tzinfo=UTC) if value.tzinfo is None else value.astimezone(UTC)
def _as_ist(value: datetime | None) -> datetime:
    """Coerce *value* (or "now" when None) to Asia/Kolkata local time."""
    utc_value = _as_utc(value)
    return utc_value.astimezone(MARKET_TZ)
def _parse_holiday_token(token: str) -> date:
text = token.strip()
if not text:
raise ValueError("Empty holiday token is not allowed")
return date.fromisoformat(text)
@lru_cache(maxsize=16)
def get_nse_holidays(year: int) -> frozenset[date]:
    """Return the NSE trading holidays for *year* as a frozenset of dates.

    Resolution order: the NSE_HOLIDAYS_<year> env var (comma-separated ISO
    dates) wins over the built-in table; unknown years fail closed with
    UnsupportedCalendarYearError.

    NOTE(review): results are lru_cached per year, so env-var changes after
    the first lookup for a year are ignored for the process lifetime —
    confirm intended; tests overriding the env must set it before the
    first call or clear the cache.
    """
    configured = (os.getenv(f"NSE_HOLIDAYS_{year}") or "").strip()
    if configured:
        return frozenset(_parse_holiday_token(token) for token in configured.split(","))
    if year in DEFAULT_NSE_HOLIDAYS_BY_YEAR:
        return DEFAULT_NSE_HOLIDAYS_BY_YEAR[year]
    raise UnsupportedCalendarYearError(year)
def is_trading_holiday(now_utc: datetime | None = None) -> bool:
    """True when the given instant (or now) falls on an NSE trading holiday.

    Propagates UnsupportedCalendarYearError for unconfigured years.
    """
    local_now = _as_ist(now_utc)
    return local_now.date() in get_nse_holidays(local_now.year)
def _session_from_utc(now_utc: datetime | None = None) -> dict[str, object]:
    """Classify the given (or current) instant into an NSE market session.

    Returns a dict with: ``status``, ``reason``, ``checked_at`` (UTC),
    ``checked_at_ist``, and ``calendar_supported``.
    """
    current_utc = _as_utc(now_utc)
    current_ist = current_utc.astimezone(MARKET_TZ)

    def result(status, reason, *, supported=True):
        # Every outcome shares the same envelope; only status/reason vary.
        return {
            "status": status,
            "reason": reason,
            "checked_at": current_utc,
            "checked_at_ist": current_ist,
            "calendar_supported": supported,
        }

    try:
        holidays = get_nse_holidays(current_ist.year)
    except UnsupportedCalendarYearError as exc:
        # Fail closed: without a holiday calendar we cannot declare the market open.
        logger.error("NSE holiday calendar unavailable for year %s", exc.year)
        return result(STATUS_CLOSED, REASON_CALENDAR_UNAVAILABLE, supported=False)

    if current_ist.date() in holidays:
        return result(STATUS_HOLIDAY, REASON_HOLIDAY)
    if current_ist.weekday() >= 5:  # Saturday (5) / Sunday (6)
        return result(STATUS_CLOSED, REASON_WEEKEND)

    # Naive IST wall-clock time, comparable to the MARKET_OPEN/MARKET_CLOSE times.
    wall_time = current_ist.timetz().replace(tzinfo=None)
    if wall_time < MARKET_OPEN:
        return result(STATUS_CLOSED, REASON_PRE_OPEN)
    if wall_time < MARKET_CLOSE:
        return result(STATUS_OPEN, REASON_OPEN)
    return result(STATUS_CLOSED, REASON_POST_CLOSE)
def get_market_session(now_utc: datetime | None = None) -> dict[str, object]:
    """Public wrapper: return a fresh copy of the session classification."""
    session = _session_from_utc(now_utc)
    return dict(session)
def get_market_status(now_utc: datetime | None = None) -> str:
    """Return only the session status (e.g. OPEN/CLOSED/HOLIDAY) as a string."""
    session = _session_from_utc(now_utc)
    return str(session["status"])
def is_market_open(now_utc: datetime | None = None) -> bool:
    """True only when the session status is exactly STATUS_OPEN."""
    status = get_market_status(now_utc)
    return status == STATUS_OPEN
def next_market_open(now_utc: datetime | None = None) -> datetime:
    """Return the UTC instant of the next market open (or *now* if already open).

    Raises UnsupportedCalendarYearError when the holiday calendar for the
    relevant year is missing (for the current year, or while scanning
    forward across a year boundary).
    """
    current_utc = _as_utc(now_utc)
    session = _session_from_utc(current_utc)

    if session["reason"] == REASON_CALENDAR_UNAVAILABLE:
        # Surface the missing calendar instead of guessing an open time.
        checked_at_ist = session["checked_at_ist"]
        raise UnsupportedCalendarYearError(int(checked_at_ist.year))
    if session["status"] == STATUS_OPEN:
        return current_utc

    def _is_trading_day(day: date) -> bool:
        # A weekday that is not in that year's holiday calendar.
        return day.weekday() < 5 and day not in get_nse_holidays(day.year)

    current_ist = current_utc.astimezone(MARKET_TZ)
    wall_time = current_ist.timetz().replace(tzinfo=None)
    candidate = current_ist.date()

    # Before the bell on a normal trading day: today's own open is next.
    if (
        session["reason"] == REASON_PRE_OPEN
        and _is_trading_day(candidate)
        and wall_time < MARKET_OPEN
    ):
        opens_at = datetime.combine(candidate, MARKET_OPEN, tzinfo=MARKET_TZ)
        return opens_at.astimezone(UTC)

    # Otherwise scan forward, one calendar day at a time.
    candidate = candidate + timedelta(days=1)
    while not _is_trading_day(candidate):
        candidate = candidate + timedelta(days=1)
    opens_at = datetime.combine(candidate, MARKET_OPEN, tzinfo=MARKET_TZ)
    return opens_at.astimezone(UTC)

View File

@ -1,11 +1,17 @@
import os
import socket
import threading
import time
import uuid
from datetime import datetime, timedelta, timezone
from psycopg2.extras import Json
from indian_paper_trading_strategy.engine.market import is_market_open, align_to_market_open, market_now
from indian_paper_trading_strategy.engine.market import (
align_to_market_open,
market_now,
market_session,
)
from indian_paper_trading_strategy.engine.execution import try_execute_sip
from indian_paper_trading_strategy.engine.broker import (
BrokerAuthExpired,
@ -21,7 +27,16 @@ from indian_paper_trading_strategy.engine.strategy import allocation
from indian_paper_trading_strategy.engine.time_utils import normalize_logical_time
from app.services.zerodha_service import KiteTokenError
from indian_paper_trading_strategy.engine.db import db_transaction, insert_engine_event, run_with_retry, get_context, set_context
from indian_paper_trading_strategy.engine.db import (
acquire_run_lease,
db_transaction,
heartbeat_run_lease,
insert_engine_event,
release_run_lease,
run_with_retry,
get_context,
set_context,
)
def _update_engine_status(user_id: str, run_id: str, status: str):
@ -58,7 +73,18 @@ _ENGINE_STATES_LOCK = threading.Lock()
_RUNNERS = {}
_RUNNERS_LOCK = threading.Lock()
engine_state = _ENGINE_STATES
engine_state = _ENGINE_STATES
RUNNER_OWNER_ID = os.getenv("RUNNER_OWNER_ID") or f"{socket.gethostname()}:{os.getpid()}:{uuid.uuid4().hex}"
RUN_LEASE_SECONDS = int(os.getenv("RUN_LEASE_SECONDS", "90"))
class RunLeaseNotAcquiredError(RuntimeError):
def __init__(self, run_id: str, owner_id: str, details: dict | None = None):
super().__init__(f"Run lease not acquired for run {run_id}")
self.run_id = run_id
self.owner_id = owner_id
self.details = details or {}
def _state_key(user_id: str, run_id: str):
@ -93,7 +119,7 @@ def get_engine_state(user_id: str, run_id: str):
state = _get_state(user_id, run_id)
return dict(state)
def log_event(
def log_event(
event: str,
data: dict | None = None,
message: str | None = None,
@ -121,22 +147,82 @@ def log_event(
meta=meta,
ts=event_ts,
)
run_with_retry(_op)
run_with_retry(_op)
def _log_runner_lease_event(
    user_id: str,
    run_id: str,
    event: str,
    message: str,
    meta: dict | None = None,
):
    """Best-effort logging of a runner-lease lifecycle event.

    Always prints to stdout; additionally tries to persist an engine event
    row. Persistence failures are swallowed deliberately — lease
    bookkeeping must never crash the runner.
    """
    details = meta or {}
    print(f"[ENGINE] {event} {message} {details}", flush=True)

    def _op(cur, _conn):
        insert_engine_event(
            cur,
            event,
            data=details,
            message=message,
            # Timezone-aware "now"; datetime.utcnow() is deprecated (3.12+)
            # and the equivalent replace(tzinfo=...) dance is unnecessary.
            ts=datetime.now(timezone.utc),
            user_id=user_id,
            run_id=run_id,
        )

    try:
        run_with_retry(_op)
    except Exception:
        # Intentional: event persistence is best-effort only.
        pass
def _refresh_run_lease_or_stop(
    user_id: str,
    run_id: str,
    owner_id: str,
):
    """Heartbeat this runner's lease; return False when the lease was lost."""
    lease = heartbeat_run_lease(
        run_id,
        owner_id,
        lease_seconds=RUN_LEASE_SECONDS,
    )
    if not lease.get("active"):
        # Lease expired or was claimed by another owner: tell the caller to exit.
        _log_runner_lease_event(
            user_id,
            run_id,
            "RUNNER_LEASE_LOST",
            "Runner exiting due to lost lease",
            {"owner_id": owner_id},
        )
        return False
    expires_at = lease.get("expires_at")
    print(
        f"[ENGINE] RUNNER_LEASE_HEARTBEAT lease heartbeat refreshed "
        f"{{'run_id': '{run_id}', 'owner_id': '{owner_id}', 'expires_at': '{expires_at}'}}",
        flush=True,
    )
    return True
def sleep_with_heartbeat(
total_seconds: int,
stop_event: threading.Event,
user_id: str,
run_id: str,
owner_id: str,
step_seconds: int = 5,
):
remaining = total_seconds
while remaining > 0 and not stop_event.is_set():
time.sleep(min(step_seconds, remaining))
chunk = min(step_seconds, remaining)
time.sleep(chunk)
_set_state(user_id, run_id, last_heartbeat_ts=datetime.utcnow().isoformat() + "Z")
_update_engine_status(user_id, run_id, "RUNNING")
remaining -= step_seconds
if not _refresh_run_lease_or_stop(user_id, run_id, owner_id):
return False
remaining -= chunk
return True
def _clear_runner(user_id: str, run_id: str):
key = _state_key(user_id, run_id)
@ -144,7 +230,20 @@ def _clear_runner(user_id: str, run_id: str):
_RUNNERS.pop(key, None)
def can_execute(now: datetime) -> tuple[bool, str]:
if not is_market_open(now):
session = market_session(now)
status = str(session.get("status") or "CLOSED").upper()
reason = str(session.get("reason") or "").upper()
if status == "HOLIDAY":
return False, "MARKET_HOLIDAY"
if status != "OPEN":
if reason == "WEEKEND":
return False, "MARKET_WEEKEND"
if reason == "PRE_OPEN":
return False, "MARKET_PRE_OPEN"
if reason == "POST_CLOSE":
return False, "MARKET_POST_CLOSE"
if reason == "CALENDAR_UNAVAILABLE":
return False, "MARKET_CALENDAR_UNAVAILABLE"
return False, "MARKET_CLOSED"
return True, "OK"
@ -225,12 +324,13 @@ def _pause_for_auth_expiry(
def _engine_loop(config, stop_event: threading.Event):
print("Strategy engine started with config:", config)
user_id = config.get("user_id")
run_id = config.get("run_id")
scope_user, scope_run = get_context(user_id, run_id)
set_context(scope_user, scope_run)
print("Strategy engine started with config:", config)
user_id = config.get("user_id")
run_id = config.get("run_id")
owner_id = config.get("runner_owner_id") or RUNNER_OWNER_ID
scope_user, scope_run = get_context(user_id, run_id)
set_context(scope_user, scope_run)
strategy_name = config.get("strategy_name") or config.get("strategy") or "golden_nifty"
sip_amount = config["sip_amount"]
@ -303,10 +403,14 @@ def _engine_loop(config, stop_event: threading.Event):
state="RUNNING",
last_heartbeat_ts=datetime.utcnow().isoformat() + "Z",
)
_update_engine_status(scope_user, scope_run, "RUNNING")
try:
while not stop_event.is_set():
_update_engine_status(scope_user, scope_run, "RUNNING")
exit_reason = "STOPPED"
try:
while not stop_event.is_set():
if not _refresh_run_lease_or_stop(scope_user, scope_run, owner_id):
exit_reason = "LEASE_LOST"
break
_set_state(scope_user, scope_run, last_heartbeat_ts=datetime.utcnow().isoformat() + "Z")
_update_engine_status(scope_user, scope_run, "RUNNING")
@ -357,10 +461,10 @@ def _engine_loop(config, stop_event: threading.Event):
"frequency": frequency_label,
},
)
if emit_event_cb:
emit_event_cb(
event="SIP_WAITING",
message="Waiting for next SIP window",
if emit_event_cb:
emit_event_cb(
event="SIP_WAITING",
message="Waiting for next SIP window",
meta={
"last_run": last_run,
"next_eligible": next_run.isoformat(),
@ -368,7 +472,9 @@ def _engine_loop(config, stop_event: threading.Event):
"frequency": frequency_label,
},
)
sleep_with_heartbeat(wait_seconds, stop_event, scope_user, scope_run)
if not sleep_with_heartbeat(wait_seconds, stop_event, scope_user, scope_run, owner_id):
exit_reason = "LEASE_LOST"
break
continue
try:
@ -395,7 +501,9 @@ def _engine_loop(config, stop_event: threading.Event):
break
except Exception as exc:
debug_event("PRICE_FETCH_ERROR", "live price fetch failed", {"error": str(exc)})
sleep_with_heartbeat(30, stop_event, scope_user, scope_run)
if not sleep_with_heartbeat(30, stop_event, scope_user, scope_run, owner_id):
exit_reason = "LEASE_LOST"
break
continue
try:
@ -416,7 +524,9 @@ def _engine_loop(config, stop_event: threading.Event):
break
except Exception as exc:
debug_event("HISTORY_LOAD_ERROR", "history load failed", {"error": str(exc)})
sleep_with_heartbeat(30, stop_event, scope_user, scope_run)
if not sleep_with_heartbeat(30, stop_event, scope_user, scope_run, owner_id):
exit_reason = "LEASE_LOST"
break
continue
nifty_sma = nifty_hist.rolling(SMA_MONTHS).mean().iloc[-1]
@ -565,26 +675,49 @@ def _engine_loop(config, stop_event: threading.Event):
logical_time=logical_time,
)
sleep_with_heartbeat(30, stop_event, scope_user, scope_run)
if not sleep_with_heartbeat(30, stop_event, scope_user, scope_run, owner_id):
exit_reason = "LEASE_LOST"
break
except BrokerAuthExpired as exc:
exit_reason = "AUTH_EXPIRED"
_pause_for_auth_expiry(scope_user, scope_run, str(exc), emit_event_cb=emit_event_cb)
print(f"[ENGINE] broker auth expired for run {scope_run}: {exc}", flush=True)
except Exception as e:
exit_reason = "ERROR"
_set_state(scope_user, scope_run, state="ERROR", last_heartbeat_ts=datetime.utcnow().isoformat() + "Z")
_update_engine_status(scope_user, scope_run, "ERROR")
log_event("ENGINE_ERROR", {"error": str(e)})
raise
finally:
try:
released = release_run_lease(scope_run, owner_id)
if released:
print(
f"[ENGINE] RUNNER_LEASE_RELEASED released run lease "
f"{{'run_id': '{scope_run}', 'owner_id': '{owner_id}'}}",
flush=True,
)
except Exception:
pass
log_event("ENGINE_STOP")
_set_state(
scope_user,
scope_run,
state="STOPPED",
last_heartbeat_ts=datetime.utcnow().isoformat() + "Z",
)
_update_engine_status(scope_user, scope_run, "STOPPED")
print("Strategy engine stopped")
_clear_runner(scope_user, scope_run)
if exit_reason not in {"ERROR", "LEASE_LOST", "AUTH_EXPIRED"}:
log_event("ENGINE_STOP")
_set_state(
scope_user,
scope_run,
state="STOPPED",
last_heartbeat_ts=datetime.utcnow().isoformat() + "Z",
)
_update_engine_status(scope_user, scope_run, "STOPPED")
print("Strategy engine stopped")
elif exit_reason == "LEASE_LOST":
_set_state(
scope_user,
scope_run,
state="STOPPED",
last_heartbeat_ts=datetime.utcnow().isoformat() + "Z",
)
_clear_runner(scope_user, scope_run)
def start_engine(config):
user_id = config.get("user_id")
@ -600,14 +733,53 @@ def start_engine(config):
if runner and runner["thread"].is_alive():
return False
lease = acquire_run_lease(
run_id,
RUNNER_OWNER_ID,
lease_seconds=RUN_LEASE_SECONDS,
)
if not lease.get("acquired"):
_log_runner_lease_event(
user_id,
run_id,
"RUNNER_LEASE_DENIED",
"Run lease denied",
{
"owner_id": RUNNER_OWNER_ID,
"current_owner": lease.get("owner_id"),
"expires_at": lease.get("expires_at").isoformat() if lease.get("expires_at") else None,
},
)
raise RunLeaseNotAcquiredError(run_id, RUNNER_OWNER_ID, lease)
lease_status = str(lease.get("status") or "ACQUIRED").upper()
event_name = "RUNNER_LEASE_REACQUIRED" if lease_status == "REACQUIRED" else "RUNNER_LEASE_ACQUIRED"
_log_runner_lease_event(
user_id,
run_id,
event_name,
"Run lease acquired" if lease_status != "REACQUIRED" else "Expired run lease reacquired",
{
"owner_id": RUNNER_OWNER_ID,
"expires_at": lease.get("expires_at").isoformat() if lease.get("expires_at") else None,
},
)
stop_event = threading.Event()
thread = threading.Thread(
target=_engine_loop,
args=(config, stop_event),
thread_config = dict(config)
thread_config["runner_owner_id"] = RUNNER_OWNER_ID
thread = threading.Thread(
target=_engine_loop,
args=(thread_config, stop_event),
daemon=True,
)
_RUNNERS[key] = {"thread": thread, "stop_event": stop_event}
thread.start()
try:
thread.start()
except Exception:
_RUNNERS.pop(key, None)
release_run_lease(run_id, RUNNER_OWNER_ID)
raise
return True
def stop_engine(user_id: str, run_id: str | None = None, timeout: float | None = 10.0):

View File

@ -1,8 +1,8 @@
# engine/state.py
# engine/state.py
from datetime import datetime, timezone
from indian_paper_trading_strategy.engine.db import db_connection, insert_engine_event, run_with_retry, get_context
from indian_paper_trading_strategy.engine.market import market_now
from indian_paper_trading_strategy.engine.time_utils import parse_persisted_timestamp, serialize_timestamp
DEFAULT_STATE = {
"initial_cash": 0.0,
@ -31,33 +31,8 @@ def _default_state(mode: str | None):
return DEFAULT_PAPER_STATE.copy()
return DEFAULT_STATE.copy()
def _local_tz():
return market_now().tzinfo
def _format_local_ts(value: datetime | None):
if value is None:
return None
return value.astimezone(_local_tz()).replace(tzinfo=None).isoformat()
def _parse_ts(value):
if value is None:
return None
if isinstance(value, datetime):
if value.tzinfo is None:
return value.replace(tzinfo=_local_tz())
return value
if isinstance(value, str):
text = value.strip()
if not text:
return None
try:
parsed = datetime.fromisoformat(text.replace("Z", "+00:00"))
except ValueError:
return None
if parsed.tzinfo is None:
parsed = parsed.replace(tzinfo=_local_tz())
return parsed
return None
def _parse_ts(value):
return parse_persisted_timestamp(value)
def _resolve_scope(user_id: str | None, run_id: str | None):
return get_context(user_id, run_id)
@ -101,15 +76,15 @@ def load_state(
merged = _default_state(mode)
merged.update(
{
"initial_cash": float(row[0]) if row[0] is not None else merged["initial_cash"],
"cash": float(row[1]) if row[1] is not None else merged["cash"],
"total_invested": float(row[2]) if row[2] is not None else merged["total_invested"],
"nifty_units": float(row[3]) if row[3] is not None else merged["nifty_units"],
"gold_units": float(row[4]) if row[4] is not None else merged["gold_units"],
"last_sip_ts": _format_local_ts(row[5]),
"last_run": _format_local_ts(row[6]),
}
)
"initial_cash": float(row[0]) if row[0] is not None else merged["initial_cash"],
"cash": float(row[1]) if row[1] is not None else merged["cash"],
"total_invested": float(row[2]) if row[2] is not None else merged["total_invested"],
"nifty_units": float(row[3]) if row[3] is not None else merged["nifty_units"],
"gold_units": float(row[4]) if row[4] is not None else merged["gold_units"],
"last_sip_ts": serialize_timestamp(row[5]),
"last_run": serialize_timestamp(row[6]),
}
)
if row[7] is not None or row[8] is not None:
merged["sip_frequency"] = {"value": row[7], "unit": row[8]}
return merged
@ -140,13 +115,13 @@ def load_state(
merged = _default_state(mode)
merged.update(
{
"total_invested": float(row[0]) if row[0] is not None else merged["total_invested"],
"nifty_units": float(row[1]) if row[1] is not None else merged["nifty_units"],
"gold_units": float(row[2]) if row[2] is not None else merged["gold_units"],
"last_sip_ts": _format_local_ts(row[3]),
"last_run": _format_local_ts(row[4]),
}
)
"total_invested": float(row[0]) if row[0] is not None else merged["total_invested"],
"nifty_units": float(row[1]) if row[1] is not None else merged["nifty_units"],
"gold_units": float(row[2]) if row[2] is not None else merged["gold_units"],
"last_sip_ts": serialize_timestamp(row[3]),
"last_run": serialize_timestamp(row[4]),
}
)
return merged
def init_paper_state(

View File

@ -1,7 +1,12 @@
from datetime import datetime, timedelta
from datetime import datetime, timedelta, timezone
from zoneinfo import ZoneInfo
UTC = timezone.utc
MARKET_TZ = ZoneInfo("Asia/Kolkata")
def frequency_to_timedelta(freq: dict) -> timedelta:
def frequency_to_timedelta(freq: dict) -> timedelta:
value = int(freq.get("value", 0))
unit = freq.get("unit")
@ -15,27 +20,64 @@ def frequency_to_timedelta(freq: dict) -> timedelta:
raise ValueError(f"Unsupported frequency unit: {unit}")
def normalize_logical_time(ts: datetime) -> datetime:
return ts.replace(microsecond=0)
def normalize_logical_time(ts: datetime) -> datetime:
return ts.replace(microsecond=0)
def ensure_aware(ts: datetime, *, default_tz=UTC) -> datetime:
if ts.tzinfo is None:
return ts.replace(tzinfo=default_tz)
return ts
def to_utc(ts: datetime, *, default_tz=UTC) -> datetime:
    """Normalize *ts* to UTC, treating naive values as being in *default_tz*."""
    aware = ensure_aware(ts, default_tz=default_tz)
    return aware.astimezone(UTC)
def serialize_timestamp(ts: datetime | None, *, default_tz=UTC) -> str | None:
    """ISO-8601 string of *ts* normalized to UTC; ``None`` passes through."""
    return None if ts is None else to_utc(ts, default_tz=default_tz).isoformat()
def parse_persisted_timestamp(value: datetime | str | None, *, default_tz=UTC) -> datetime | None:
    """Best-effort parse of a persisted timestamp into an aware UTC datetime.

    Accepts datetime objects (naive ones are assumed to be in *default_tz*)
    and ISO-8601 strings (a "Z" suffix is tolerated). Anything else —
    ``None``, blank strings, unparseable text, other types — yields ``None``.
    """
    if isinstance(value, datetime):
        return to_utc(value, default_tz=default_tz)
    if not isinstance(value, str):
        return None
    text = value.strip()
    if not text:
        return None
    try:
        parsed = datetime.fromisoformat(text.replace("Z", "+00:00"))
    except ValueError:
        return None
    return to_utc(parsed, default_tz=default_tz)
def parse_market_timestamp(value: datetime | str | None) -> datetime | None:
    """Like parse_persisted_timestamp, but converted to IST (MARKET_TZ)."""
    parsed = parse_persisted_timestamp(value)
    return None if parsed is None else parsed.astimezone(MARKET_TZ)
def compute_logical_time(
now: datetime,
last_run: str | None,
interval_seconds: float | None,
) -> datetime:
base = now
if last_run and interval_seconds:
try:
parsed = datetime.fromisoformat(last_run.replace("Z", "+00:00"))
except ValueError:
parsed = None
if parsed is not None:
if now.tzinfo and parsed.tzinfo is None:
parsed = parsed.replace(tzinfo=now.tzinfo)
elif now.tzinfo is None and parsed.tzinfo:
parsed = parsed.replace(tzinfo=None)
candidate = parsed + timedelta(seconds=interval_seconds)
if now >= candidate:
base = candidate
return normalize_logical_time(base)
def compute_logical_time(
now: datetime,
last_run: str | None,
interval_seconds: float | None,
) -> datetime:
base = now
if last_run and interval_seconds:
parsed = parse_persisted_timestamp(last_run, default_tz=now.tzinfo or UTC)
if parsed is not None:
if now.tzinfo is None:
parsed = parsed.replace(tzinfo=None)
else:
parsed = parsed.astimezone(now.tzinfo)
candidate = parsed + timedelta(seconds=interval_seconds)
if now >= candidate:
base = candidate
return normalize_logical_time(base)