Harden backend auth, execution safety, and market session logic
This commit is contained in:
parent
99e48144aa
commit
519addd78f
@ -47,6 +47,8 @@ class AppSession(Base):
|
|||||||
created_at = Column(DateTime(timezone=True), nullable=False)
|
created_at = Column(DateTime(timezone=True), nullable=False)
|
||||||
last_seen_at = Column(DateTime(timezone=True))
|
last_seen_at = Column(DateTime(timezone=True))
|
||||||
expires_at = Column(DateTime(timezone=True), nullable=False)
|
expires_at = Column(DateTime(timezone=True), nullable=False)
|
||||||
|
ip = Column(Text)
|
||||||
|
user_agent = Column(Text)
|
||||||
|
|
||||||
__table_args__ = (
|
__table_args__ = (
|
||||||
Index("idx_app_session_user_id", "user_id"),
|
Index("idx_app_session_user_id", "user_id"),
|
||||||
@ -63,6 +65,8 @@ class UserBroker(Base):
|
|||||||
access_token = Column(Text)
|
access_token = Column(Text)
|
||||||
connected_at = Column(DateTime(timezone=True))
|
connected_at = Column(DateTime(timezone=True))
|
||||||
api_key = Column(Text)
|
api_key = Column(Text)
|
||||||
|
api_secret = Column(Text)
|
||||||
|
auth_state = Column(Text)
|
||||||
user_name = Column(Text)
|
user_name = Column(Text)
|
||||||
broker_user_id = Column(Text)
|
broker_user_id = Column(Text)
|
||||||
pending_broker = Column(Text)
|
pending_broker = Column(Text)
|
||||||
@ -117,7 +121,10 @@ class StrategyRun(Base):
|
|||||||
|
|
||||||
__table_args__ = (
|
__table_args__ = (
|
||||||
UniqueConstraint("user_id", "run_id", name="uq_strategy_run_user_run"),
|
UniqueConstraint("user_id", "run_id", name="uq_strategy_run_user_run"),
|
||||||
CheckConstraint("status IN ('RUNNING','STOPPED','ERROR')", name="chk_strategy_run_status"),
|
CheckConstraint(
|
||||||
|
"status IN ('RUNNING','STOPPED','ERROR','PAUSED_AUTH_EXPIRED')",
|
||||||
|
name="chk_strategy_run_status",
|
||||||
|
),
|
||||||
Index("idx_strategy_run_user_status", "user_id", "status"),
|
Index("idx_strategy_run_user_status", "user_id", "status"),
|
||||||
Index("idx_strategy_run_user_created", "user_id", "created_at"),
|
Index("idx_strategy_run_user_created", "user_id", "created_at"),
|
||||||
Index(
|
Index(
|
||||||
@ -129,6 +136,22 @@ class StrategyRun(Base):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class PasswordResetOtp(Base):
|
||||||
|
__tablename__ = "password_reset_otp"
|
||||||
|
|
||||||
|
id = Column(String, primary_key=True)
|
||||||
|
email = Column(Text, nullable=False)
|
||||||
|
otp_hash = Column(Text, nullable=False)
|
||||||
|
created_at = Column(DateTime(timezone=True), nullable=False, server_default=func.now())
|
||||||
|
expires_at = Column(DateTime(timezone=True), nullable=False)
|
||||||
|
used_at = Column(DateTime(timezone=True))
|
||||||
|
|
||||||
|
__table_args__ = (
|
||||||
|
Index("idx_password_reset_otp_email", "email"),
|
||||||
|
Index("idx_password_reset_otp_expires_at", "expires_at"),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class StrategyConfig(Base):
|
class StrategyConfig(Base):
|
||||||
__tablename__ = "strategy_config"
|
__tablename__ = "strategy_config"
|
||||||
|
|
||||||
@ -420,6 +443,97 @@ class LiveEquitySnapshot(Base):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class SupportTicket(Base):
|
||||||
|
__tablename__ = "support_ticket"
|
||||||
|
|
||||||
|
id = Column(String, primary_key=True)
|
||||||
|
name = Column(Text, nullable=False)
|
||||||
|
email = Column(Text, nullable=False)
|
||||||
|
subject = Column(Text, nullable=False)
|
||||||
|
message = Column(Text, nullable=False)
|
||||||
|
status = Column(Text, nullable=False, server_default=text("'NEW'"))
|
||||||
|
created_at = Column(DateTime(timezone=True), nullable=False, server_default=func.now())
|
||||||
|
updated_at = Column(DateTime(timezone=True), nullable=False, server_default=func.now())
|
||||||
|
|
||||||
|
__table_args__ = (
|
||||||
|
Index("idx_support_ticket_email", "email"),
|
||||||
|
Index("idx_support_ticket_created_at", "created_at"),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class SupportRequestAudit(Base):
|
||||||
|
__tablename__ = "support_request_audit"
|
||||||
|
|
||||||
|
id = Column(BigInteger, primary_key=True, autoincrement=True)
|
||||||
|
endpoint = Column(Text, nullable=False)
|
||||||
|
ip_hash = Column(Text)
|
||||||
|
email_hash = Column(Text)
|
||||||
|
ticket_hash = Column(Text)
|
||||||
|
blocked = Column(Boolean, nullable=False, server_default=text("false"))
|
||||||
|
reason = Column(Text)
|
||||||
|
created_at = Column(DateTime(timezone=True), nullable=False, server_default=func.now())
|
||||||
|
|
||||||
|
__table_args__ = (
|
||||||
|
Index("idx_support_request_audit_endpoint_ip_created", "endpoint", "ip_hash", "created_at"),
|
||||||
|
Index("idx_support_request_audit_ticket_created", "ticket_hash", "created_at"),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class BrokerCallbackState(Base):
|
||||||
|
__tablename__ = "broker_callback_state"
|
||||||
|
|
||||||
|
id = Column(String, primary_key=True)
|
||||||
|
state_hash = Column(Text, nullable=False, unique=True)
|
||||||
|
user_id = Column(String, ForeignKey("app_user.id", ondelete="CASCADE"), nullable=False)
|
||||||
|
session_id = Column(String, ForeignKey("app_session.id", ondelete="CASCADE"), nullable=False)
|
||||||
|
broker = Column(Text, nullable=False)
|
||||||
|
flow = Column(Text, nullable=False)
|
||||||
|
created_at = Column(DateTime(timezone=True), nullable=False)
|
||||||
|
expires_at = Column(DateTime(timezone=True), nullable=False)
|
||||||
|
consumed_at = Column(DateTime(timezone=True))
|
||||||
|
|
||||||
|
__table_args__ = (
|
||||||
|
Index(
|
||||||
|
"idx_broker_callback_state_lookup",
|
||||||
|
"user_id",
|
||||||
|
"session_id",
|
||||||
|
"broker",
|
||||||
|
"flow",
|
||||||
|
"expires_at",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ExecutionClaim(Base):
|
||||||
|
__tablename__ = "execution_claim"
|
||||||
|
|
||||||
|
id = Column(String, primary_key=True)
|
||||||
|
user_id = Column(String, ForeignKey("app_user.id", ondelete="CASCADE"), nullable=False)
|
||||||
|
run_id = Column(String, ForeignKey("strategy_run.run_id", ondelete="CASCADE"), nullable=False)
|
||||||
|
mode = Column(Text, nullable=False)
|
||||||
|
logical_time = Column(DateTime(timezone=True), nullable=False)
|
||||||
|
claimed_at = Column(DateTime(timezone=True), nullable=False, server_default=func.now())
|
||||||
|
|
||||||
|
__table_args__ = (
|
||||||
|
UniqueConstraint("user_id", "run_id", "logical_time", name="uq_execution_claim_scope"),
|
||||||
|
Index("idx_execution_claim_run_claimed", "run_id", "claimed_at"),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class RunLease(Base):
|
||||||
|
__tablename__ = "run_leases"
|
||||||
|
|
||||||
|
run_id = Column(String, ForeignKey("strategy_run.run_id", ondelete="CASCADE"), primary_key=True)
|
||||||
|
owner_id = Column(Text, nullable=False)
|
||||||
|
leased_at = Column(DateTime(timezone=True), nullable=False)
|
||||||
|
expires_at = Column(DateTime(timezone=True), nullable=False)
|
||||||
|
heartbeat_at = Column(DateTime(timezone=True))
|
||||||
|
|
||||||
|
__table_args__ = (
|
||||||
|
Index("idx_run_leases_owner_expires", "owner_id", "expires_at"),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class MTMLedger(Base):
|
class MTMLedger(Base):
|
||||||
__tablename__ = "mtm_ledger"
|
__tablename__ = "mtm_ledger"
|
||||||
|
|
||||||
|
|||||||
@ -1,82 +1,126 @@
|
|||||||
import os
|
import os
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
|
||||||
|
from app.admin_role_service import bootstrap_super_admin
|
||||||
|
from app.admin_router import router as admin_router
|
||||||
from app.routers.auth import router as auth_router
|
from app.routers.auth import router as auth_router
|
||||||
from app.routers.broker import router as broker_router
|
from app.routers.broker import router as broker_router
|
||||||
from app.routers.health import router as health_router
|
from app.routers.health import router as health_router
|
||||||
|
from app.routers.paper import router as paper_router
|
||||||
from app.routers.password_reset import router as password_reset_router
|
from app.routers.password_reset import router as password_reset_router
|
||||||
|
from app.routers.strategy import router as strategy_router
|
||||||
from app.routers.support_ticket import router as support_ticket_router
|
from app.routers.support_ticket import router as support_ticket_router
|
||||||
from app.routers.system import router as system_router
|
from app.routers.system import router as system_router
|
||||||
from app.routers.strategy import router as strategy_router
|
|
||||||
from app.routers.zerodha import router as zerodha_router, public_router as zerodha_public_router
|
from app.routers.zerodha import router as zerodha_router, public_router as zerodha_public_router
|
||||||
from app.routers.paper import router as paper_router
|
from app.services.db import _db_config as _validate_db_config
|
||||||
from market import router as market_router
|
|
||||||
from paper_mtm import router as paper_mtm_router
|
|
||||||
from app.services.live_equity_service import start_live_equity_snapshot_daemon
|
from app.services.live_equity_service import start_live_equity_snapshot_daemon
|
||||||
from app.services.strategy_service import init_log_state, resume_running_runs
|
from app.services.strategy_service import init_log_state, resume_running_runs
|
||||||
from app.admin_router import router as admin_router
|
from market import router as market_router
|
||||||
from app.admin_role_service import bootstrap_super_admin
|
from paper_mtm import router as paper_mtm_router
|
||||||
|
|
||||||
app = FastAPI(
|
DEFAULT_PRODUCTION_ORIGINS = {"https://app.quantfortune.com"}
|
||||||
title="QuantFortune Backend",
|
DEFAULT_DEV_ORIGINS = {
|
||||||
version="1.0"
|
|
||||||
)
|
|
||||||
|
|
||||||
cors_origins = [
|
|
||||||
origin.strip()
|
|
||||||
for origin in os.getenv("CORS_ORIGINS", "").split(",")
|
|
||||||
if origin.strip()
|
|
||||||
]
|
|
||||||
if not cors_origins:
|
|
||||||
cors_origins = [
|
|
||||||
"http://localhost:3000",
|
"http://localhost:3000",
|
||||||
"http://127.0.0.1:3000",
|
"http://127.0.0.1:3000",
|
||||||
]
|
"http://localhost:5173",
|
||||||
|
"http://127.0.0.1:5173",
|
||||||
|
}
|
||||||
|
PRODUCTION_ENV_NAMES = {"prod", "production"}
|
||||||
|
|
||||||
cors_origin_regex = os.getenv("CORS_ORIGIN_REGEX", "").strip()
|
|
||||||
if not cors_origin_regex:
|
def _environment_name() -> str:
|
||||||
cors_origin_regex = (
|
return (
|
||||||
r"https://.*\\.ngrok-free\\.dev"
|
os.getenv("APP_ENV")
|
||||||
r"|https://.*\\.ngrok-free\\.app"
|
or os.getenv("ENVIRONMENT")
|
||||||
r"|https://.*\\.ngrok\\.io"
|
or os.getenv("FASTAPI_ENV")
|
||||||
|
or "development"
|
||||||
|
).strip().lower()
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_origin(origin: str) -> str:
|
||||||
|
return origin.strip().rstrip("/")
|
||||||
|
|
||||||
|
|
||||||
|
def _is_dev_origin(origin: str) -> bool:
|
||||||
|
parsed = urlparse(origin)
|
||||||
|
return parsed.scheme == "http" and parsed.hostname in {"localhost", "127.0.0.1"}
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_cors_origin(origin: str) -> str:
|
||||||
|
normalized = _normalize_origin(origin)
|
||||||
|
if not normalized:
|
||||||
|
raise RuntimeError("Empty CORS origin is not allowed")
|
||||||
|
if normalized in DEFAULT_PRODUCTION_ORIGINS or _is_dev_origin(normalized):
|
||||||
|
return normalized
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Unsupported CORS origin '{normalized}'. Only app.quantfortune.com and localhost dev origins are allowed."
|
||||||
)
|
)
|
||||||
|
|
||||||
# app.add_middleware(
|
|
||||||
# CORSMiddleware,
|
|
||||||
# allow_origins=cors_origins,
|
|
||||||
# allow_origin_regex=cors_origin_regex or None,
|
|
||||||
# allow_credentials=True,
|
|
||||||
# allow_methods=["*"],
|
|
||||||
# allow_headers=["*"],
|
|
||||||
# )
|
|
||||||
|
|
||||||
app.add_middleware(
|
def _build_cors_origins() -> list[str]:
|
||||||
|
configured = [
|
||||||
|
_normalize_origin(origin)
|
||||||
|
for origin in os.getenv("CORS_ORIGINS", "").split(",")
|
||||||
|
if origin.strip()
|
||||||
|
]
|
||||||
|
env_name = _environment_name()
|
||||||
|
|
||||||
|
if env_name in PRODUCTION_ENV_NAMES:
|
||||||
|
if not configured:
|
||||||
|
raise RuntimeError("CORS_ORIGINS must be configured explicitly in production")
|
||||||
|
origins = configured
|
||||||
|
else:
|
||||||
|
origins = configured or sorted(DEFAULT_DEV_ORIGINS)
|
||||||
|
|
||||||
|
deduped: list[str] = []
|
||||||
|
seen: set[str] = set()
|
||||||
|
for origin in origins:
|
||||||
|
validated = _validate_cors_origin(origin)
|
||||||
|
if validated not in seen:
|
||||||
|
seen.add(validated)
|
||||||
|
deduped.append(validated)
|
||||||
|
return deduped
|
||||||
|
|
||||||
|
|
||||||
|
def create_app() -> FastAPI:
|
||||||
|
_validate_db_config()
|
||||||
|
app = FastAPI(title="QuantFortune Backend", version="1.0")
|
||||||
|
cors_origins = _build_cors_origins()
|
||||||
|
app.add_middleware(
|
||||||
CORSMiddleware,
|
CORSMiddleware,
|
||||||
allow_origins=[], # must be empty when using regex
|
allow_origins=cors_origins,
|
||||||
allow_origin_regex=".*", # allow ANY origin
|
|
||||||
allow_credentials=True,
|
allow_credentials=True,
|
||||||
allow_methods=["*"],
|
allow_methods=["*"],
|
||||||
allow_headers=["*"],
|
allow_headers=["*"],
|
||||||
)
|
)
|
||||||
|
|
||||||
app.include_router(strategy_router)
|
app.include_router(strategy_router)
|
||||||
app.include_router(auth_router)
|
app.include_router(auth_router)
|
||||||
app.include_router(broker_router)
|
app.include_router(broker_router)
|
||||||
app.include_router(zerodha_router)
|
app.include_router(zerodha_router)
|
||||||
app.include_router(zerodha_public_router)
|
app.include_router(zerodha_public_router)
|
||||||
app.include_router(paper_router)
|
app.include_router(paper_router)
|
||||||
app.include_router(market_router)
|
app.include_router(market_router)
|
||||||
app.include_router(paper_mtm_router)
|
app.include_router(paper_mtm_router)
|
||||||
app.include_router(health_router)
|
app.include_router(health_router)
|
||||||
app.include_router(system_router)
|
app.include_router(system_router)
|
||||||
app.include_router(admin_router)
|
app.include_router(admin_router)
|
||||||
app.include_router(support_ticket_router)
|
app.include_router(support_ticket_router)
|
||||||
app.include_router(password_reset_router)
|
app.include_router(password_reset_router)
|
||||||
|
|
||||||
@app.on_event("startup")
|
@app.on_event("startup")
|
||||||
def init_app_state():
|
def init_app_state():
|
||||||
|
if os.getenv("DISABLE_STARTUP_TASKS", "0") == "1":
|
||||||
|
return
|
||||||
init_log_state()
|
init_log_state()
|
||||||
bootstrap_super_admin()
|
bootstrap_super_admin()
|
||||||
resume_running_runs()
|
resume_running_runs()
|
||||||
start_live_equity_snapshot_daemon()
|
start_live_equity_snapshot_daemon()
|
||||||
|
|
||||||
|
return app
|
||||||
|
|
||||||
|
|
||||||
|
app = create_app()
|
||||||
|
|||||||
@ -15,8 +15,12 @@ from app.services.email_service import send_email_async
|
|||||||
|
|
||||||
router = APIRouter(prefix="/api")
|
router = APIRouter(prefix="/api")
|
||||||
SESSION_COOKIE_NAME = "session_id"
|
SESSION_COOKIE_NAME = "session_id"
|
||||||
COOKIE_SECURE = os.getenv("COOKIE_SECURE", "0") == "1"
|
APP_ENV = (os.getenv("APP_ENV") or os.getenv("ENVIRONMENT") or os.getenv("FASTAPI_ENV") or "development").strip().lower()
|
||||||
|
IS_PRODUCTION = APP_ENV in {"prod", "production"}
|
||||||
|
COOKIE_SECURE = True if IS_PRODUCTION else os.getenv("COOKIE_SECURE", "0") == "1"
|
||||||
COOKIE_SAMESITE = (os.getenv("COOKIE_SAMESITE") or "lax").lower()
|
COOKIE_SAMESITE = (os.getenv("COOKIE_SAMESITE") or "lax").lower()
|
||||||
|
if IS_PRODUCTION and not COOKIE_SECURE:
|
||||||
|
raise RuntimeError("Secure session cookies are mandatory in production")
|
||||||
|
|
||||||
|
|
||||||
def _set_session_cookie(response: Response, session_id: str):
|
def _set_session_cookie(response: Response, session_id: str):
|
||||||
|
|||||||
@ -15,6 +15,10 @@ from app.broker_store import (
|
|||||||
set_pending_broker,
|
set_pending_broker,
|
||||||
)
|
)
|
||||||
from app.services.auth_service import get_user_for_session
|
from app.services.auth_service import get_user_for_session
|
||||||
|
from app.services.broker_callback_state import (
|
||||||
|
consume_broker_callback_state,
|
||||||
|
create_broker_callback_state,
|
||||||
|
)
|
||||||
from app.services.email_service import send_email_async
|
from app.services.email_service import send_email_async
|
||||||
from app.services.groww_service import (
|
from app.services.groww_service import (
|
||||||
GrowwApiError,
|
GrowwApiError,
|
||||||
@ -60,6 +64,13 @@ def _require_user(request: Request):
|
|||||||
return user
|
return user
|
||||||
|
|
||||||
|
|
||||||
|
def _require_session_id(request: Request) -> str:
|
||||||
|
session_id = request.cookies.get("session_id")
|
||||||
|
if not session_id:
|
||||||
|
raise HTTPException(status_code=401, detail="Not authenticated")
|
||||||
|
return session_id
|
||||||
|
|
||||||
|
|
||||||
def _first_number(*values, default: float = 0.0) -> float:
|
def _first_number(*values, default: float = 0.0) -> float:
|
||||||
for value in values:
|
for value in values:
|
||||||
try:
|
try:
|
||||||
@ -317,6 +328,7 @@ def _normalize_groww_funds(data: dict | None) -> dict:
|
|||||||
def _build_saved_broker_login_url(
|
def _build_saved_broker_login_url(
|
||||||
request: Request,
|
request: Request,
|
||||||
user_id: str,
|
user_id: str,
|
||||||
|
session_id: str,
|
||||||
redirect_url_override: str | None = None,
|
redirect_url_override: str | None = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
entry = get_user_broker(user_id) or {}
|
entry = get_user_broker(user_id) or {}
|
||||||
@ -332,7 +344,13 @@ def _build_saved_broker_login_url(
|
|||||||
if not redirect_url:
|
if not redirect_url:
|
||||||
base = str(request.base_url).rstrip("/")
|
base = str(request.base_url).rstrip("/")
|
||||||
redirect_url = f"{base}/api/broker/callback"
|
redirect_url = f"{base}/api/broker/callback"
|
||||||
return build_login_url(creds["api_key"], redirect_url=redirect_url)
|
state = create_broker_callback_state(
|
||||||
|
user_id=user_id,
|
||||||
|
session_id=session_id,
|
||||||
|
broker="ZERODHA",
|
||||||
|
flow="reconnect",
|
||||||
|
)
|
||||||
|
return build_login_url(creds["api_key"], redirect_url=redirect_url, state=state)
|
||||||
|
|
||||||
|
|
||||||
def _notify_broker_connected(username: str, broker: str, broker_user_id: str | None):
|
def _notify_broker_connected(username: str, broker: str, broker_user_id: str | None):
|
||||||
@ -401,6 +419,7 @@ async def disconnect_broker(request: Request):
|
|||||||
@router.post("/zerodha/login")
|
@router.post("/zerodha/login")
|
||||||
async def zerodha_login(payload: dict, request: Request):
|
async def zerodha_login(payload: dict, request: Request):
|
||||||
user = _require_user(request)
|
user = _require_user(request)
|
||||||
|
session_id = _require_session_id(request)
|
||||||
api_key = (payload.get("apiKey") or "").strip()
|
api_key = (payload.get("apiKey") or "").strip()
|
||||||
api_secret = (payload.get("apiSecret") or "").strip()
|
api_secret = (payload.get("apiSecret") or "").strip()
|
||||||
redirect_url = (payload.get("redirectUrl") or "").strip()
|
redirect_url = (payload.get("redirectUrl") or "").strip()
|
||||||
@ -408,7 +427,13 @@ async def zerodha_login(payload: dict, request: Request):
|
|||||||
raise HTTPException(status_code=400, detail="API key and secret are required")
|
raise HTTPException(status_code=400, detail="API key and secret are required")
|
||||||
|
|
||||||
set_pending_broker(user["id"], "ZERODHA", api_key, api_secret)
|
set_pending_broker(user["id"], "ZERODHA", api_key, api_secret)
|
||||||
return {"loginUrl": build_login_url(api_key, redirect_url=redirect_url or None)}
|
state = create_broker_callback_state(
|
||||||
|
user_id=user["id"],
|
||||||
|
session_id=session_id,
|
||||||
|
broker="ZERODHA",
|
||||||
|
flow="connect",
|
||||||
|
)
|
||||||
|
return {"loginUrl": build_login_url(api_key, redirect_url=redirect_url or None, state=state)}
|
||||||
|
|
||||||
|
|
||||||
@router.post("/groww/connect")
|
@router.post("/groww/connect")
|
||||||
@ -490,11 +515,23 @@ async def groww_reconnect(request: Request):
|
|||||||
|
|
||||||
|
|
||||||
@router.get("/zerodha/callback")
|
@router.get("/zerodha/callback")
|
||||||
async def zerodha_callback(request: Request, request_token: str = ""):
|
async def zerodha_callback(request: Request, request_token: str = "", state: str = ""):
|
||||||
user = _require_user(request)
|
user = _require_user(request)
|
||||||
|
session_id = _require_session_id(request)
|
||||||
token = request_token.strip()
|
token = request_token.strip()
|
||||||
|
callback_state = state.strip()
|
||||||
if not token:
|
if not token:
|
||||||
raise HTTPException(status_code=400, detail="Missing request_token")
|
raise HTTPException(status_code=400, detail="Missing request_token")
|
||||||
|
if not callback_state:
|
||||||
|
raise HTTPException(status_code=400, detail="Missing state")
|
||||||
|
if not consume_broker_callback_state(
|
||||||
|
state=callback_state,
|
||||||
|
user_id=user["id"],
|
||||||
|
session_id=session_id,
|
||||||
|
broker="ZERODHA",
|
||||||
|
flow="connect",
|
||||||
|
):
|
||||||
|
raise HTTPException(status_code=401, detail="Invalid or expired broker callback state")
|
||||||
|
|
||||||
pending = get_pending_broker(user["id"]) or {}
|
pending = get_pending_broker(user["id"]) or {}
|
||||||
api_key = (pending.get("api_key") or "").strip()
|
api_key = (pending.get("api_key") or "").strip()
|
||||||
@ -541,32 +578,46 @@ async def zerodha_callback(request: Request, request_token: str = ""):
|
|||||||
@router.get("/login")
|
@router.get("/login")
|
||||||
async def broker_login(request: Request):
|
async def broker_login(request: Request):
|
||||||
user = _require_user(request)
|
user = _require_user(request)
|
||||||
|
session_id = _require_session_id(request)
|
||||||
redirect_url = (
|
redirect_url = (
|
||||||
(request.query_params.get("redirectUrl") or request.query_params.get("redirect_url") or "")
|
(request.query_params.get("redirectUrl") or request.query_params.get("redirect_url") or "")
|
||||||
.strip()
|
.strip()
|
||||||
or None
|
or None
|
||||||
)
|
)
|
||||||
login_url = _build_saved_broker_login_url(request, user["id"], redirect_url)
|
login_url = _build_saved_broker_login_url(request, user["id"], session_id, redirect_url)
|
||||||
return RedirectResponse(login_url)
|
return RedirectResponse(login_url)
|
||||||
|
|
||||||
|
|
||||||
@router.get("/login-url")
|
@router.get("/login-url")
|
||||||
async def broker_login_url(request: Request):
|
async def broker_login_url(request: Request):
|
||||||
user = _require_user(request)
|
user = _require_user(request)
|
||||||
|
session_id = _require_session_id(request)
|
||||||
redirect_url = (
|
redirect_url = (
|
||||||
(request.query_params.get("redirectUrl") or request.query_params.get("redirect_url") or "")
|
(request.query_params.get("redirectUrl") or request.query_params.get("redirect_url") or "")
|
||||||
.strip()
|
.strip()
|
||||||
or None
|
or None
|
||||||
)
|
)
|
||||||
return {"loginUrl": _build_saved_broker_login_url(request, user["id"], redirect_url)}
|
return {"loginUrl": _build_saved_broker_login_url(request, user["id"], session_id, redirect_url)}
|
||||||
|
|
||||||
|
|
||||||
@router.get("/callback")
|
@router.get("/callback")
|
||||||
async def broker_callback(request: Request, request_token: str = ""):
|
async def broker_callback(request: Request, request_token: str = "", state: str = ""):
|
||||||
user = _require_user(request)
|
user = _require_user(request)
|
||||||
|
session_id = _require_session_id(request)
|
||||||
token = request_token.strip()
|
token = request_token.strip()
|
||||||
|
callback_state = state.strip()
|
||||||
if not token:
|
if not token:
|
||||||
raise HTTPException(status_code=400, detail="Missing request_token")
|
raise HTTPException(status_code=400, detail="Missing request_token")
|
||||||
|
if not callback_state:
|
||||||
|
raise HTTPException(status_code=400, detail="Missing state")
|
||||||
|
if not consume_broker_callback_state(
|
||||||
|
state=callback_state,
|
||||||
|
user_id=user["id"],
|
||||||
|
session_id=session_id,
|
||||||
|
broker="ZERODHA",
|
||||||
|
flow="reconnect",
|
||||||
|
):
|
||||||
|
raise HTTPException(status_code=401, detail="Invalid or expired broker callback state")
|
||||||
creds = get_broker_credentials(user["id"])
|
creds = get_broker_credentials(user["id"])
|
||||||
if not creds:
|
if not creds:
|
||||||
raise HTTPException(status_code=400, detail="Broker credentials not configured")
|
raise HTTPException(status_code=400, detail="Broker credentials not configured")
|
||||||
|
|||||||
@ -1,5 +1,4 @@
|
|||||||
from fastapi import APIRouter, HTTPException, Query, Request
|
from fastapi import APIRouter, HTTPException, Query, Request, status as http_status
|
||||||
from fastapi.responses import JSONResponse
|
|
||||||
from app.models import StrategyStartRequest
|
from app.models import StrategyStartRequest
|
||||||
from app.services.strategy_service import (
|
from app.services.strategy_service import (
|
||||||
start_strategy,
|
start_strategy,
|
||||||
@ -15,6 +14,20 @@ from app.services.tenant import get_request_user_id
|
|||||||
|
|
||||||
router = APIRouter(prefix="/api")
|
router = APIRouter(prefix="/api")
|
||||||
|
|
||||||
|
|
||||||
|
def _raise_strategy_error(payload: dict, *, default_status: int) -> None:
|
||||||
|
message = payload.get("message") or payload.get("detail") or "Strategy operation failed"
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=default_status,
|
||||||
|
detail={
|
||||||
|
"status": payload.get("status", "error"),
|
||||||
|
"message": message,
|
||||||
|
"run_id": payload.get("run_id"),
|
||||||
|
"redirect_url": payload.get("redirect_url"),
|
||||||
|
"broker": payload.get("broker"),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
@router.post("/strategy/start")
|
@router.post("/strategy/start")
|
||||||
def start(req: StrategyStartRequest, request: Request):
|
def start(req: StrategyStartRequest, request: Request):
|
||||||
user_id = get_request_user_id(request)
|
user_id = get_request_user_id(request)
|
||||||
@ -24,35 +37,38 @@ def start(req: StrategyStartRequest, request: Request):
|
|||||||
def stop(request: Request):
|
def stop(request: Request):
|
||||||
try:
|
try:
|
||||||
user_id = get_request_user_id(request)
|
user_id = get_request_user_id(request)
|
||||||
return stop_strategy(user_id)
|
result = stop_strategy(user_id)
|
||||||
|
if result.get("status") not in {"stopped", "already_stopped"}:
|
||||||
|
_raise_strategy_error(result, default_status=http_status.HTTP_409_CONFLICT)
|
||||||
|
return result
|
||||||
except HTTPException:
|
except HTTPException:
|
||||||
raise
|
raise
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
print(f"[STRATEGY] unhandled stop route failure: {exc}", flush=True)
|
print(f"[STRATEGY] unhandled stop route failure: {exc}", flush=True)
|
||||||
return JSONResponse(
|
raise HTTPException(
|
||||||
status_code=200,
|
status_code=http_status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
content={
|
detail={"status": "stop_failed", "message": f"Unable to stop strategy: {exc}"},
|
||||||
"status": "stop_failed",
|
) from exc
|
||||||
"message": f"Unable to stop strategy: {exc}",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
@router.post("/strategy/resume")
|
@router.post("/strategy/resume")
|
||||||
def resume(request: Request):
|
def resume(request: Request):
|
||||||
try:
|
try:
|
||||||
user_id = get_request_user_id(request)
|
user_id = get_request_user_id(request)
|
||||||
return resume_strategy(user_id)
|
result = resume_strategy(user_id)
|
||||||
|
success_statuses = {"resumed", "already_running"}
|
||||||
|
if result.get("status") == "broker_auth_required":
|
||||||
|
_raise_strategy_error(result, default_status=http_status.HTTP_401_UNAUTHORIZED)
|
||||||
|
if result.get("status") not in success_statuses:
|
||||||
|
_raise_strategy_error(result, default_status=http_status.HTTP_409_CONFLICT)
|
||||||
|
return result
|
||||||
except HTTPException:
|
except HTTPException:
|
||||||
raise
|
raise
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
print(f"[STRATEGY] unhandled resume route failure: {exc}", flush=True)
|
print(f"[STRATEGY] unhandled resume route failure: {exc}", flush=True)
|
||||||
return JSONResponse(
|
raise HTTPException(
|
||||||
status_code=200,
|
status_code=http_status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
content={
|
detail={"status": "resume_failed", "message": f"Unable to resume strategy: {exc}"},
|
||||||
"status": "resume_failed",
|
) from exc
|
||||||
"message": f"Unable to resume strategy: {exc}",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
@router.get("/strategy/status")
|
@router.get("/strategy/status")
|
||||||
def status(request: Request):
|
def status(request: Request):
|
||||||
|
|||||||
@ -1,6 +1,7 @@
|
|||||||
from fastapi import APIRouter, HTTPException
|
from fastapi import APIRouter, Header, HTTPException, Request
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from app.services.support_abuse import SupportGuardRejected, enforce_support_guard
|
||||||
from app.services.support_ticket import create_ticket, get_ticket_status
|
from app.services.support_ticket import create_ticket, get_ticket_status
|
||||||
|
|
||||||
|
|
||||||
@ -19,7 +20,20 @@ class TicketStatusRequest(BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
@router.post("/ticket")
|
@router.post("/ticket")
|
||||||
def submit_ticket(payload: TicketCreate):
|
def submit_ticket(
|
||||||
|
payload: TicketCreate,
|
||||||
|
request: Request,
|
||||||
|
support_captcha: str | None = Header(default=None, alias="X-Support-Captcha"),
|
||||||
|
):
|
||||||
|
try:
|
||||||
|
enforce_support_guard(
|
||||||
|
request=request,
|
||||||
|
endpoint="ticket_create",
|
||||||
|
email=payload.email.strip(),
|
||||||
|
captcha_token=support_captcha,
|
||||||
|
)
|
||||||
|
except SupportGuardRejected as exc:
|
||||||
|
raise HTTPException(status_code=exc.status_code, detail=exc.detail) from exc
|
||||||
if not payload.subject.strip() or not payload.message.strip():
|
if not payload.subject.strip() or not payload.message.strip():
|
||||||
raise HTTPException(status_code=400, detail="Subject and message are required")
|
raise HTTPException(status_code=400, detail="Subject and message are required")
|
||||||
ticket = create_ticket(
|
ticket = create_ticket(
|
||||||
@ -32,7 +46,22 @@ def submit_ticket(payload: TicketCreate):
|
|||||||
|
|
||||||
|
|
||||||
@router.post("/ticket/status/{ticket_id}")
|
@router.post("/ticket/status/{ticket_id}")
|
||||||
def ticket_status(ticket_id: str, payload: TicketStatusRequest):
|
def ticket_status(
|
||||||
|
ticket_id: str,
|
||||||
|
payload: TicketStatusRequest,
|
||||||
|
request: Request,
|
||||||
|
support_captcha: str | None = Header(default=None, alias="X-Support-Captcha"),
|
||||||
|
):
|
||||||
|
try:
|
||||||
|
enforce_support_guard(
|
||||||
|
request=request,
|
||||||
|
endpoint="ticket_status",
|
||||||
|
email=payload.email.strip(),
|
||||||
|
ticket_id=ticket_id.strip(),
|
||||||
|
captcha_token=support_captcha,
|
||||||
|
)
|
||||||
|
except SupportGuardRejected as exc:
|
||||||
|
raise HTTPException(status_code=exc.status_code, detail=exc.detail) from exc
|
||||||
status = get_ticket_status(ticket_id.strip(), payload.email.strip())
|
status = get_ticket_status(ticket_id.strip(), payload.email.strip())
|
||||||
if not status:
|
if not status:
|
||||||
raise HTTPException(status_code=404, detail="Ticket not found")
|
raise HTTPException(status_code=404, detail="Ticket not found")
|
||||||
|
|||||||
@ -1,9 +1,13 @@
|
|||||||
import hashlib
|
import hashlib
|
||||||
import os
|
import os
|
||||||
|
import re
|
||||||
import secrets
|
import secrets
|
||||||
from datetime import datetime, timedelta, timezone
|
from datetime import datetime, timedelta, timezone
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
|
|
||||||
|
from argon2 import PasswordHasher
|
||||||
|
from argon2.exceptions import InvalidHash, VerifyMismatchError
|
||||||
|
|
||||||
from app.services.db import db_connection
|
from app.services.db import db_connection
|
||||||
|
|
||||||
SESSION_TTL_SECONDS = int(os.getenv("SESSION_TTL_SECONDS", str(60 * 60 * 24 * 7)))
|
SESSION_TTL_SECONDS = int(os.getenv("SESSION_TTL_SECONDS", str(60 * 60 * 24 * 7)))
|
||||||
@ -11,7 +15,11 @@ SESSION_REFRESH_WINDOW_SECONDS = int(
|
|||||||
os.getenv("SESSION_REFRESH_WINDOW_SECONDS", str(60 * 60))
|
os.getenv("SESSION_REFRESH_WINDOW_SECONDS", str(60 * 60))
|
||||||
)
|
)
|
||||||
RESET_OTP_TTL_MINUTES = int(os.getenv("RESET_OTP_TTL_MINUTES", "10"))
|
RESET_OTP_TTL_MINUTES = int(os.getenv("RESET_OTP_TTL_MINUTES", "10"))
|
||||||
RESET_OTP_SECRET = os.getenv("RESET_OTP_SECRET", "otp_secret")
|
PASSWORD_HASHER = PasswordHasher()
|
||||||
|
LEGACY_SHA256_RE = re.compile(r"^[0-9a-f]{64}$")
|
||||||
|
RESET_OTP_SECRET = (os.getenv("RESET_OTP_SECRET") or "").strip()
|
||||||
|
if not RESET_OTP_SECRET:
|
||||||
|
raise RuntimeError("RESET_OTP_SECRET must be configured")
|
||||||
|
|
||||||
|
|
||||||
def _now_utc() -> datetime:
|
def _now_utc() -> datetime:
|
||||||
@ -23,9 +31,17 @@ def _new_expiry(now: datetime) -> datetime:
|
|||||||
|
|
||||||
|
|
||||||
def _hash_password(password: str) -> str:
|
def _hash_password(password: str) -> str:
|
||||||
|
return PASSWORD_HASHER.hash(password)
|
||||||
|
|
||||||
|
|
||||||
|
def _hash_password_legacy(password: str) -> str:
|
||||||
return hashlib.sha256(password.encode("utf-8")).hexdigest()
|
return hashlib.sha256(password.encode("utf-8")).hexdigest()
|
||||||
|
|
||||||
|
|
||||||
|
def _is_legacy_password_hash(password_hash: str | None) -> bool:
|
||||||
|
return bool(password_hash and LEGACY_SHA256_RE.fullmatch(password_hash))
|
||||||
|
|
||||||
|
|
||||||
def _hash_otp(email: str, otp: str) -> str:
|
def _hash_otp(email: str, otp: str) -> str:
|
||||||
payload = f"{email}:{otp}:{RESET_OTP_SECRET}"
|
payload = f"{email}:{otp}:{RESET_OTP_SECRET}"
|
||||||
return hashlib.sha256(payload.encode("utf-8")).hexdigest()
|
return hashlib.sha256(payload.encode("utf-8")).hexdigest()
|
||||||
@ -80,12 +96,47 @@ def create_user(username: str, password: str):
|
|||||||
return _row_to_user(cur.fetchone())
|
return _row_to_user(cur.fetchone())
|
||||||
|
|
||||||
|
|
||||||
|
def _update_password_hash(user_id: str, password_hash: str):
|
||||||
|
with db_connection() as conn:
|
||||||
|
with conn:
|
||||||
|
with conn.cursor() as cur:
|
||||||
|
cur.execute(
|
||||||
|
"UPDATE app_user SET password_hash = %s WHERE id = %s",
|
||||||
|
(password_hash, user_id),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _verify_password(user_id: str, stored_hash: str | None, password: str) -> tuple[bool, str | None]:
|
||||||
|
if not stored_hash:
|
||||||
|
return False, None
|
||||||
|
if _is_legacy_password_hash(stored_hash):
|
||||||
|
if secrets.compare_digest(stored_hash, _hash_password_legacy(password)):
|
||||||
|
return True, _hash_password(password)
|
||||||
|
return False, None
|
||||||
|
|
||||||
|
try:
|
||||||
|
verified = PASSWORD_HASHER.verify(stored_hash, password)
|
||||||
|
except (VerifyMismatchError, InvalidHash):
|
||||||
|
return False, None
|
||||||
|
|
||||||
|
if not verified:
|
||||||
|
return False, None
|
||||||
|
|
||||||
|
if PASSWORD_HASHER.check_needs_rehash(stored_hash):
|
||||||
|
return True, _hash_password(password)
|
||||||
|
return True, None
|
||||||
|
|
||||||
|
|
||||||
def authenticate_user(username: str, password: str):
|
def authenticate_user(username: str, password: str):
|
||||||
user = get_user_by_username(username)
|
user = get_user_by_username(username)
|
||||||
if not user:
|
if not user:
|
||||||
return None
|
return None
|
||||||
if user.get("password") != _hash_password(password):
|
verified, replacement_hash = _verify_password(user["id"], user.get("password"), password)
|
||||||
|
if not verified:
|
||||||
return None
|
return None
|
||||||
|
if replacement_hash:
|
||||||
|
_update_password_hash(user["id"], replacement_hash)
|
||||||
|
user["password"] = replacement_hash
|
||||||
return user
|
return user
|
||||||
|
|
||||||
|
|
||||||
@ -130,13 +181,7 @@ def get_last_session_meta(user_id: str):
|
|||||||
|
|
||||||
def update_user_password(user_id: str, new_password: str):
|
def update_user_password(user_id: str, new_password: str):
|
||||||
password_hash = _hash_password(new_password)
|
password_hash = _hash_password(new_password)
|
||||||
with db_connection() as conn:
|
_update_password_hash(user_id, password_hash)
|
||||||
with conn:
|
|
||||||
with conn.cursor() as cur:
|
|
||||||
cur.execute(
|
|
||||||
"UPDATE app_user SET password_hash = %s WHERE id = %s",
|
|
||||||
(password_hash, user_id),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def create_password_reset_otp(email: str):
|
def create_password_reset_otp(email: str):
|
||||||
|
|||||||
111
backend/app/services/broker_callback_state.py
Normal file
111
backend/app/services/broker_callback_state.py
Normal file
@ -0,0 +1,111 @@
|
|||||||
|
import hashlib
|
||||||
|
import secrets
|
||||||
|
from datetime import datetime, timedelta, timezone
|
||||||
|
from uuid import uuid4
|
||||||
|
|
||||||
|
from app.services.db import db_connection
|
||||||
|
|
||||||
|
CALLBACK_STATE_TTL_SECONDS = 15 * 60
|
||||||
|
|
||||||
|
|
||||||
|
def _now_utc() -> datetime:
|
||||||
|
return datetime.now(timezone.utc)
|
||||||
|
|
||||||
|
|
||||||
|
def _state_hash(state: str) -> str:
|
||||||
|
return hashlib.sha256(state.encode("utf-8")).hexdigest()
|
||||||
|
|
||||||
|
|
||||||
|
def create_broker_callback_state(
|
||||||
|
*,
|
||||||
|
user_id: str,
|
||||||
|
session_id: str,
|
||||||
|
broker: str,
|
||||||
|
flow: str,
|
||||||
|
ttl_seconds: int = CALLBACK_STATE_TTL_SECONDS,
|
||||||
|
) -> str:
|
||||||
|
state = secrets.token_urlsafe(32)
|
||||||
|
now = _now_utc()
|
||||||
|
expires_at = now + timedelta(seconds=ttl_seconds)
|
||||||
|
state_hash = _state_hash(state)
|
||||||
|
with db_connection() as conn:
|
||||||
|
with conn:
|
||||||
|
with conn.cursor() as cur:
|
||||||
|
cur.execute(
|
||||||
|
"""
|
||||||
|
DELETE FROM broker_callback_state
|
||||||
|
WHERE expires_at <= %s OR consumed_at IS NOT NULL
|
||||||
|
""",
|
||||||
|
(now,),
|
||||||
|
)
|
||||||
|
cur.execute(
|
||||||
|
"""
|
||||||
|
INSERT INTO broker_callback_state (
|
||||||
|
id,
|
||||||
|
state_hash,
|
||||||
|
user_id,
|
||||||
|
session_id,
|
||||||
|
broker,
|
||||||
|
flow,
|
||||||
|
created_at,
|
||||||
|
expires_at,
|
||||||
|
consumed_at
|
||||||
|
)
|
||||||
|
VALUES (%s, %s, %s, %s, %s, %s, %s, %s, NULL)
|
||||||
|
""",
|
||||||
|
(
|
||||||
|
str(uuid4()),
|
||||||
|
state_hash,
|
||||||
|
user_id,
|
||||||
|
session_id,
|
||||||
|
broker.strip().upper(),
|
||||||
|
flow.strip().lower(),
|
||||||
|
now,
|
||||||
|
expires_at,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
return state
|
||||||
|
|
||||||
|
|
||||||
|
def consume_broker_callback_state(
|
||||||
|
*,
|
||||||
|
state: str,
|
||||||
|
user_id: str,
|
||||||
|
session_id: str,
|
||||||
|
broker: str,
|
||||||
|
flow: str,
|
||||||
|
):
|
||||||
|
if not state:
|
||||||
|
return None
|
||||||
|
now = _now_utc()
|
||||||
|
state_hash = _state_hash(state)
|
||||||
|
with db_connection() as conn:
|
||||||
|
with conn:
|
||||||
|
with conn.cursor() as cur:
|
||||||
|
cur.execute(
|
||||||
|
"""
|
||||||
|
UPDATE broker_callback_state
|
||||||
|
SET consumed_at = %s
|
||||||
|
WHERE state_hash = %s
|
||||||
|
AND user_id = %s
|
||||||
|
AND session_id = %s
|
||||||
|
AND broker = %s
|
||||||
|
AND flow = %s
|
||||||
|
AND consumed_at IS NULL
|
||||||
|
AND expires_at > %s
|
||||||
|
RETURNING id, expires_at
|
||||||
|
""",
|
||||||
|
(
|
||||||
|
now,
|
||||||
|
state_hash,
|
||||||
|
user_id,
|
||||||
|
session_id,
|
||||||
|
broker.strip().upper(),
|
||||||
|
flow.strip().lower(),
|
||||||
|
now,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
row = cur.fetchone()
|
||||||
|
if not row:
|
||||||
|
return None
|
||||||
|
return {"id": row[0], "expires_at": row[1].isoformat() if row[1] else None}
|
||||||
@ -16,6 +16,7 @@ Base = declarative_base()
|
|||||||
|
|
||||||
_ENGINE: Engine | None = None
|
_ENGINE: Engine | None = None
|
||||||
_ENGINE_LOCK = threading.Lock()
|
_ENGINE_LOCK = threading.Lock()
|
||||||
|
NON_PROD_ENVIRONMENTS = {"development", "dev", "test", "testing", "local"}
|
||||||
|
|
||||||
|
|
||||||
class _ConnectionProxy:
|
class _ConnectionProxy:
|
||||||
@ -44,16 +45,28 @@ class _ConnectionProxy:
|
|||||||
|
|
||||||
|
|
||||||
def _db_config() -> dict[str, str | int]:
|
def _db_config() -> dict[str, str | int]:
|
||||||
|
env_name = (os.getenv("APP_ENV") or os.getenv("ENVIRONMENT") or os.getenv("FASTAPI_ENV") or "development").strip().lower()
|
||||||
|
is_non_prod = env_name in NON_PROD_ENVIRONMENTS
|
||||||
url = os.getenv("DATABASE_URL")
|
url = os.getenv("DATABASE_URL")
|
||||||
if url:
|
if url:
|
||||||
return {"url": url}
|
return {"url": url}
|
||||||
|
|
||||||
|
password = os.getenv("DB_PASSWORD") or os.getenv("PGPASSWORD")
|
||||||
|
if not password and not is_non_prod:
|
||||||
|
raise RuntimeError("DB_PASSWORD or PGPASSWORD must be configured in non-development environments")
|
||||||
|
|
||||||
|
host = os.getenv("DB_HOST") or os.getenv("PGHOST") or ("localhost" if is_non_prod else None)
|
||||||
|
dbname = os.getenv("DB_NAME") or os.getenv("PGDATABASE") or ("trading_db" if is_non_prod else None)
|
||||||
|
user = os.getenv("DB_USER") or os.getenv("PGUSER") or ("trader" if is_non_prod else None)
|
||||||
|
if not is_non_prod and (not host or not dbname or not user):
|
||||||
|
raise RuntimeError("DB_HOST, DB_NAME, and DB_USER must be configured in non-development environments")
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"host": os.getenv("DB_HOST") or os.getenv("PGHOST") or "localhost",
|
"host": host,
|
||||||
"port": int(os.getenv("DB_PORT") or os.getenv("PGPORT") or "5432"),
|
"port": int(os.getenv("DB_PORT") or os.getenv("PGPORT") or "5432"),
|
||||||
"dbname": os.getenv("DB_NAME") or os.getenv("PGDATABASE") or "trading_db",
|
"dbname": dbname,
|
||||||
"user": os.getenv("DB_USER") or os.getenv("PGUSER") or "trader",
|
"user": user,
|
||||||
"password": os.getenv("DB_PASSWORD") or os.getenv("PGPASSWORD") or "traderpass",
|
"password": password,
|
||||||
"connect_timeout": int(os.getenv("DB_CONNECT_TIMEOUT", "5")),
|
"connect_timeout": int(os.getenv("DB_CONNECT_TIMEOUT", "5")),
|
||||||
"schema": os.getenv("DB_SCHEMA") or os.getenv("PGSCHEMA") or "quant_app",
|
"schema": os.getenv("DB_SCHEMA") or os.getenv("PGSCHEMA") or "quant_app",
|
||||||
}
|
}
|
||||||
|
|||||||
@ -172,6 +172,6 @@ def update_run_status(user_id: str, run_id: str, status: str, meta: dict | None
|
|||||||
""",
|
""",
|
||||||
(status, now, Json(meta or {}), run_id, user_id),
|
(status, now, Json(meta or {}), run_id, user_id),
|
||||||
)
|
)
|
||||||
return True
|
return cur.rowcount > 0
|
||||||
|
|
||||||
return run_with_retry(_op)
|
return run_with_retry(_op)
|
||||||
|
|||||||
@ -4,17 +4,28 @@ import sys
|
|||||||
import threading
|
import threading
|
||||||
from datetime import datetime, timedelta, timezone
|
from datetime import datetime, timedelta, timezone
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from zoneinfo import ZoneInfo
|
|
||||||
|
|
||||||
ENGINE_ROOT = Path(__file__).resolve().parents[3]
|
ENGINE_ROOT = Path(__file__).resolve().parents[3]
|
||||||
if str(ENGINE_ROOT) not in sys.path:
|
if str(ENGINE_ROOT) not in sys.path:
|
||||||
sys.path.append(str(ENGINE_ROOT))
|
sys.path.append(str(ENGINE_ROOT))
|
||||||
|
|
||||||
from indian_paper_trading_strategy.engine.market import is_market_open, align_to_market_open, market_now
|
from indian_paper_trading_strategy.engine.market import (
|
||||||
from indian_paper_trading_strategy.engine.runner import start_engine, stop_engine
|
align_to_market_open,
|
||||||
|
market_now,
|
||||||
|
market_session,
|
||||||
|
next_market_open_after,
|
||||||
|
)
|
||||||
|
from indian_paper_trading_strategy.engine.market_calendar import UnsupportedCalendarYearError
|
||||||
|
from indian_paper_trading_strategy.engine.runner import RunLeaseNotAcquiredError, start_engine, stop_engine
|
||||||
from indian_paper_trading_strategy.engine.state import init_paper_state, load_state, save_state
|
from indian_paper_trading_strategy.engine.state import init_paper_state, load_state, save_state
|
||||||
from indian_paper_trading_strategy.engine.broker import PaperBroker
|
from indian_paper_trading_strategy.engine.broker import PaperBroker
|
||||||
from indian_paper_trading_strategy.engine.time_utils import frequency_to_timedelta
|
from indian_paper_trading_strategy.engine.time_utils import (
|
||||||
|
UTC,
|
||||||
|
frequency_to_timedelta,
|
||||||
|
parse_market_timestamp,
|
||||||
|
parse_persisted_timestamp,
|
||||||
|
serialize_timestamp,
|
||||||
|
)
|
||||||
from indian_paper_trading_strategy.engine.db import engine_context
|
from indian_paper_trading_strategy.engine.db import engine_context
|
||||||
|
|
||||||
from app.broker_store import get_user_broker, set_broker_auth_state
|
from app.broker_store import get_user_broker, set_broker_auth_state
|
||||||
@ -41,7 +52,6 @@ SEQ_LOCK = threading.Lock()
|
|||||||
SEQ = 0
|
SEQ = 0
|
||||||
LAST_WAIT_LOG_TS = {}
|
LAST_WAIT_LOG_TS = {}
|
||||||
WAIT_LOG_INTERVAL = timedelta(seconds=60)
|
WAIT_LOG_INTERVAL = timedelta(seconds=60)
|
||||||
IST = ZoneInfo("Asia/Kolkata")
|
|
||||||
|
|
||||||
def init_log_state():
|
def init_log_state():
|
||||||
global SEQ
|
global SEQ
|
||||||
@ -110,7 +120,7 @@ def emit_event(
|
|||||||
|
|
||||||
evt = {
|
evt = {
|
||||||
"seq": seq,
|
"seq": seq,
|
||||||
"ts": now.isoformat().replace("+00:00", "Z"),
|
"ts": serialize_timestamp(now),
|
||||||
"level": level,
|
"level": level,
|
||||||
"category": category,
|
"category": category,
|
||||||
"event": event,
|
"event": event,
|
||||||
@ -157,14 +167,8 @@ def _maybe_parse_json(value):
|
|||||||
return value
|
return value
|
||||||
|
|
||||||
|
|
||||||
def _local_tz():
|
def _utc_now():
|
||||||
return IST
|
return datetime.now(UTC)
|
||||||
|
|
||||||
|
|
||||||
def _format_local_ts(value: datetime | None):
|
|
||||||
if value is None:
|
|
||||||
return None
|
|
||||||
return value.astimezone(_local_tz()).replace(tzinfo=None).isoformat()
|
|
||||||
|
|
||||||
|
|
||||||
def _load_config(user_id: str, run_id: str):
|
def _load_config(user_id: str, run_id: str):
|
||||||
@ -192,7 +196,7 @@ def _load_config(user_id: str, run_id: str):
|
|||||||
"frequency": _maybe_parse_json(row[7]),
|
"frequency": _maybe_parse_json(row[7]),
|
||||||
"frequency_days": row[8],
|
"frequency_days": row[8],
|
||||||
"unit": row[9],
|
"unit": row[9],
|
||||||
"next_run": _format_local_ts(row[10]),
|
"next_run": serialize_timestamp(row[10]),
|
||||||
}
|
}
|
||||||
if row[2] is not None or row[3] is not None:
|
if row[2] is not None or row[3] is not None:
|
||||||
cfg["sip_frequency"] = {
|
cfg["sip_frequency"] = {
|
||||||
@ -217,13 +221,7 @@ def _save_config(cfg, user_id: str, run_id: str):
|
|||||||
next_run = cfg.get("next_run")
|
next_run = cfg.get("next_run")
|
||||||
next_run_dt = None
|
next_run_dt = None
|
||||||
if isinstance(next_run, str):
|
if isinstance(next_run, str):
|
||||||
try:
|
next_run_dt = parse_persisted_timestamp(next_run)
|
||||||
parsed = datetime.fromisoformat(next_run)
|
|
||||||
if parsed.tzinfo is None:
|
|
||||||
parsed = parsed.replace(tzinfo=_local_tz())
|
|
||||||
next_run_dt = parsed
|
|
||||||
except ValueError:
|
|
||||||
next_run_dt = None
|
|
||||||
|
|
||||||
with db_connection() as conn:
|
with db_connection() as conn:
|
||||||
with conn:
|
with conn:
|
||||||
@ -294,7 +292,7 @@ def reactivate_strategy_config(user_id: str, run_id: str):
|
|||||||
return cfg
|
return cfg
|
||||||
|
|
||||||
def _write_status(user_id: str, run_id: str, status):
|
def _write_status(user_id: str, run_id: str, status):
|
||||||
now_local = market_now()
|
now_local = _utc_now()
|
||||||
with db_connection() as conn:
|
with db_connection() as conn:
|
||||||
with conn:
|
with conn:
|
||||||
with conn.cursor() as cur:
|
with conn.cursor() as cur:
|
||||||
@ -346,6 +344,12 @@ def _effective_running_run_id(user_id: str):
|
|||||||
)
|
)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _set_run_status_or_raise(user_id: str, run_id: str, status: str, meta: dict | None = None):
|
||||||
|
updated = update_run_status(user_id, run_id, status, meta=meta)
|
||||||
|
if not updated:
|
||||||
|
raise RuntimeError(f"Run {run_id} for user {user_id} no longer exists")
|
||||||
|
|
||||||
def validate_frequency(freq: dict, mode: str):
|
def validate_frequency(freq: dict, mode: str):
|
||||||
if not isinstance(freq, dict):
|
if not isinstance(freq, dict):
|
||||||
raise ValueError("Frequency payload is required")
|
raise ValueError("Frequency payload is required")
|
||||||
@ -436,9 +440,8 @@ def _validate_live_broker_session(user_id: str):
|
|||||||
def compute_next_eligible(last_run: str | None, sip_frequency: dict | None):
|
def compute_next_eligible(last_run: str | None, sip_frequency: dict | None):
|
||||||
if not last_run or not sip_frequency:
|
if not last_run or not sip_frequency:
|
||||||
return None
|
return None
|
||||||
try:
|
last_dt = parse_market_timestamp(last_run)
|
||||||
last_dt = datetime.fromisoformat(last_run)
|
if last_dt is None:
|
||||||
except ValueError:
|
|
||||||
return None
|
return None
|
||||||
try:
|
try:
|
||||||
delta = frequency_to_timedelta(sip_frequency)
|
delta = frequency_to_timedelta(sip_frequency)
|
||||||
@ -446,7 +449,7 @@ def compute_next_eligible(last_run: str | None, sip_frequency: dict | None):
|
|||||||
return None
|
return None
|
||||||
next_dt = last_dt + delta
|
next_dt = last_dt + delta
|
||||||
next_dt = align_to_market_open(next_dt)
|
next_dt = align_to_market_open(next_dt)
|
||||||
return next_dt.isoformat()
|
return serialize_timestamp(next_dt)
|
||||||
|
|
||||||
|
|
||||||
def _last_execution_ts(state: dict, mode: str) -> str | None:
|
def _last_execution_ts(state: dict, mode: str) -> str | None:
|
||||||
@ -473,7 +476,10 @@ def start_strategy(req, user_id: str):
|
|||||||
return {"status": "already_running", "run_id": running_run_id}
|
return {"status": "already_running", "run_id": running_run_id}
|
||||||
engine_config = _build_engine_config(user_id, running_run_id, req)
|
engine_config = _build_engine_config(user_id, running_run_id, req)
|
||||||
if engine_config:
|
if engine_config:
|
||||||
|
try:
|
||||||
started = start_engine(engine_config)
|
started = start_engine(engine_config)
|
||||||
|
except RunLeaseNotAcquiredError:
|
||||||
|
return {"status": "already_running", "run_id": running_run_id}
|
||||||
if started:
|
if started:
|
||||||
_write_status(user_id, running_run_id, "RUNNING")
|
_write_status(user_id, running_run_id, "RUNNING")
|
||||||
return {"status": "restarted", "run_id": running_run_id}
|
return {"status": "restarted", "run_id": running_run_id}
|
||||||
@ -573,7 +579,10 @@ def start_strategy(req, user_id: str):
|
|||||||
engine_config["run_id"] = run_id
|
engine_config["run_id"] = run_id
|
||||||
engine_config["user_id"] = user_id
|
engine_config["user_id"] = user_id
|
||||||
engine_config["emit_event"] = emit_event_cb
|
engine_config["emit_event"] = emit_event_cb
|
||||||
|
try:
|
||||||
start_engine(engine_config)
|
start_engine(engine_config)
|
||||||
|
except RunLeaseNotAcquiredError:
|
||||||
|
pass
|
||||||
|
|
||||||
try:
|
try:
|
||||||
user = get_user_by_id(user_id)
|
user = get_user_by_id(user_id)
|
||||||
@ -655,14 +664,17 @@ def resume_running_runs():
|
|||||||
engine_config = _build_engine_config(user_id, run_id, None)
|
engine_config = _build_engine_config(user_id, run_id, None)
|
||||||
if not engine_config:
|
if not engine_config:
|
||||||
continue
|
continue
|
||||||
|
try:
|
||||||
started = start_engine(engine_config)
|
started = start_engine(engine_config)
|
||||||
|
except RunLeaseNotAcquiredError:
|
||||||
|
started = False
|
||||||
if started:
|
if started:
|
||||||
_write_status(user_id, run_id, "RUNNING")
|
_write_status(user_id, run_id, "RUNNING")
|
||||||
|
|
||||||
def stop_strategy(user_id: str):
|
def stop_strategy(user_id: str):
|
||||||
run_id = _effective_running_run_id(user_id)
|
run_id = _effective_running_run_id(user_id)
|
||||||
if not run_id:
|
if not run_id:
|
||||||
latest_run_id = get_active_run_id(user_id)
|
latest_run_id = get_running_run_id(user_id) or get_active_run_id(user_id)
|
||||||
return {"status": "already_stopped", "run_id": latest_run_id}
|
return {"status": "already_stopped", "run_id": latest_run_id}
|
||||||
|
|
||||||
engine_external = os.getenv("ENGINE_EXTERNAL", "").strip().lower() in {"1", "true", "yes"}
|
engine_external = os.getenv("ENGINE_EXTERNAL", "").strip().lower() in {"1", "true", "yes"}
|
||||||
@ -681,7 +693,14 @@ def stop_strategy(user_id: str):
|
|||||||
print(f"[STRATEGY] engine status update failed during stop for {user_id}/{run_id}: {exc}", flush=True)
|
print(f"[STRATEGY] engine status update failed during stop for {user_id}/{run_id}: {exc}", flush=True)
|
||||||
if not stop_warning:
|
if not stop_warning:
|
||||||
stop_warning = str(exc)
|
stop_warning = str(exc)
|
||||||
update_run_status(user_id, run_id, "STOPPED", meta={"reason": "user_request"})
|
try:
|
||||||
|
_set_run_status_or_raise(user_id, run_id, "STOPPED", meta={"reason": "user_request"})
|
||||||
|
except RuntimeError as exc:
|
||||||
|
return {
|
||||||
|
"status": "stop_failed",
|
||||||
|
"run_id": run_id,
|
||||||
|
"message": str(exc),
|
||||||
|
}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
user = get_user_by_id(user_id)
|
user = get_user_by_id(user_id)
|
||||||
@ -704,6 +723,8 @@ def resume_strategy(user_id: str):
|
|||||||
return {"status": "already_running", "run_id": running_run_id}
|
return {"status": "already_running", "run_id": running_run_id}
|
||||||
|
|
||||||
run_id = get_active_run_id(user_id)
|
run_id = get_active_run_id(user_id)
|
||||||
|
if not run_id:
|
||||||
|
return {"status": "no_resumable_run"}
|
||||||
cfg = _load_config(user_id, run_id)
|
cfg = _load_config(user_id, run_id)
|
||||||
strategy_name = (cfg.get("strategy") or "").strip()
|
strategy_name = (cfg.get("strategy") or "").strip()
|
||||||
mode = (cfg.get("mode") or "").strip().upper()
|
mode = (cfg.get("mode") or "").strip().upper()
|
||||||
@ -737,16 +758,26 @@ def resume_strategy(user_id: str):
|
|||||||
}
|
}
|
||||||
|
|
||||||
reactivate_strategy_config(user_id, run_id)
|
reactivate_strategy_config(user_id, run_id)
|
||||||
update_run_status(user_id, run_id, "RUNNING", meta={"reason": "user_resume"})
|
try:
|
||||||
|
_set_run_status_or_raise(user_id, run_id, "RUNNING", meta={"reason": "user_resume"})
|
||||||
|
except RuntimeError as exc:
|
||||||
|
deactivate_strategy_config(user_id, run_id)
|
||||||
|
return {
|
||||||
|
"status": "resume_failed",
|
||||||
|
"run_id": run_id,
|
||||||
|
"message": str(exc),
|
||||||
|
}
|
||||||
_write_status(user_id, run_id, "RUNNING")
|
_write_status(user_id, run_id, "RUNNING")
|
||||||
|
|
||||||
if not engine_external:
|
if not engine_external:
|
||||||
try:
|
try:
|
||||||
started = start_engine(engine_config)
|
started = start_engine(engine_config)
|
||||||
|
except RunLeaseNotAcquiredError:
|
||||||
|
return {"status": "already_running", "run_id": run_id}
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
deactivate_strategy_config(user_id, run_id)
|
deactivate_strategy_config(user_id, run_id)
|
||||||
_write_status(user_id, run_id, "STOPPED")
|
_write_status(user_id, run_id, "STOPPED")
|
||||||
update_run_status(user_id, run_id, "STOPPED", meta={"reason": "resume_start_failed"})
|
_set_run_status_or_raise(user_id, run_id, "STOPPED", meta={"reason": "resume_start_failed"})
|
||||||
return {
|
return {
|
||||||
"status": "resume_failed",
|
"status": "resume_failed",
|
||||||
"run_id": run_id,
|
"run_id": run_id,
|
||||||
@ -755,7 +786,7 @@ def resume_strategy(user_id: str):
|
|||||||
if not started:
|
if not started:
|
||||||
deactivate_strategy_config(user_id, run_id)
|
deactivate_strategy_config(user_id, run_id)
|
||||||
_write_status(user_id, run_id, "STOPPED")
|
_write_status(user_id, run_id, "STOPPED")
|
||||||
update_run_status(user_id, run_id, "STOPPED", meta={"reason": "resume_start_failed"})
|
_set_run_status_or_raise(user_id, run_id, "STOPPED", meta={"reason": "resume_start_failed"})
|
||||||
return {
|
return {
|
||||||
"status": "resume_failed",
|
"status": "resume_failed",
|
||||||
"run_id": run_id,
|
"run_id": run_id,
|
||||||
@ -797,7 +828,7 @@ def get_strategy_status(user_id: str):
|
|||||||
else:
|
else:
|
||||||
status = {
|
status = {
|
||||||
"status": default_status,
|
"status": default_status,
|
||||||
"last_updated": _format_local_ts(engine_row[1]),
|
"last_updated": serialize_timestamp(engine_row[1]),
|
||||||
}
|
}
|
||||||
status["run_id"] = run_id
|
status["run_id"] = run_id
|
||||||
engine_state = str((engine_row or [None])[0] or "").strip().upper()
|
engine_state = str((engine_row or [None])[0] or "").strip().upper()
|
||||||
@ -832,17 +863,9 @@ def get_strategy_status(user_id: str):
|
|||||||
status["last_execution_ts"] = last_execution_ts
|
status["last_execution_ts"] = last_execution_ts
|
||||||
status["next_eligible_ts"] = next_eligible
|
status["next_eligible_ts"] = next_eligible
|
||||||
if next_eligible:
|
if next_eligible:
|
||||||
try:
|
parsed_next = parse_persisted_timestamp(next_eligible)
|
||||||
parsed_next = datetime.fromisoformat(next_eligible)
|
if parsed_next and parsed_next > _utc_now():
|
||||||
now_cmp = (
|
|
||||||
datetime.now(parsed_next.tzinfo)
|
|
||||||
if parsed_next.tzinfo
|
|
||||||
else market_now().replace(tzinfo=None)
|
|
||||||
)
|
|
||||||
if parsed_next > now_cmp:
|
|
||||||
status["status"] = "WAITING"
|
status["status"] = "WAITING"
|
||||||
except ValueError:
|
|
||||||
pass
|
|
||||||
status_key = (status.get("status") or "IDLE").upper()
|
status_key = (status.get("status") or "IDLE").upper()
|
||||||
resumable = bool(cfg.get("strategy")) and bool(cfg.get("mode"))
|
resumable = bool(cfg.get("strategy")) and bool(cfg.get("mode"))
|
||||||
status["can_resume"] = resumable and status_key in {"STOPPED", "PAUSED_AUTH_EXPIRED"}
|
status["can_resume"] = resumable and status_key in {"STOPPED", "PAUSED_AUTH_EXPIRED"}
|
||||||
@ -876,11 +899,7 @@ def get_engine_status(user_id: str):
|
|||||||
status["state"] = row[0]
|
status["state"] = row[0]
|
||||||
last_updated = row[1]
|
last_updated = row[1]
|
||||||
if last_updated is not None:
|
if last_updated is not None:
|
||||||
status["last_heartbeat_ts"] = (
|
status["last_heartbeat_ts"] = serialize_timestamp(last_updated)
|
||||||
last_updated.astimezone(timezone.utc)
|
|
||||||
.isoformat()
|
|
||||||
.replace("+00:00", "Z")
|
|
||||||
)
|
|
||||||
cfg = _load_config(user_id, run_id)
|
cfg = _load_config(user_id, run_id)
|
||||||
mode = (cfg.get("mode") or "LIVE").strip().upper()
|
mode = (cfg.get("mode") or "LIVE").strip().upper()
|
||||||
with engine_context(user_id, run_id):
|
with engine_context(user_id, run_id):
|
||||||
@ -926,10 +945,7 @@ def get_strategy_logs(user_id: str, since_seq: int):
|
|||||||
events = []
|
events = []
|
||||||
for row in rows:
|
for row in rows:
|
||||||
ts = row[1]
|
ts = row[1]
|
||||||
if ts is not None:
|
ts_str = serialize_timestamp(ts)
|
||||||
ts_str = ts.astimezone(timezone.utc).isoformat().replace("+00:00", "Z")
|
|
||||||
else:
|
|
||||||
ts_str = None
|
|
||||||
events.append(
|
events.append(
|
||||||
{
|
{
|
||||||
"seq": row[0],
|
"seq": row[0],
|
||||||
@ -980,6 +996,16 @@ def _issue_message(event: str, message: str | None, data: dict | None, meta: dic
|
|||||||
if event == "ENGINE_ERROR":
|
if event == "ENGINE_ERROR":
|
||||||
return message or "Strategy engine hit an error."
|
return message or "Strategy engine hit an error."
|
||||||
if event == "EXECUTION_BLOCKED":
|
if event == "EXECUTION_BLOCKED":
|
||||||
|
if reason_key == "market_holiday":
|
||||||
|
return "Exchange holiday. Execution will resume next session."
|
||||||
|
if reason_key == "market_weekend":
|
||||||
|
return "Weekend closure. Execution will resume next session."
|
||||||
|
if reason_key == "market_pre_open":
|
||||||
|
return "Market has not opened yet. Execution will begin after 9:15 AM IST."
|
||||||
|
if reason_key == "market_post_close":
|
||||||
|
return "Market is closed for the day. Execution will resume next session."
|
||||||
|
if reason_key == "market_calendar_unavailable":
|
||||||
|
return "Market calendar unavailable. Execution paused for safety."
|
||||||
if reason_key == "market_closed":
|
if reason_key == "market_closed":
|
||||||
return "Market is closed. Execution will resume next session."
|
return "Market is closed. Execution will resume next session."
|
||||||
return f"Execution blocked: {_humanize_reason(reason) or 'Unknown reason'}."
|
return f"Execution blocked: {_humanize_reason(reason) or 'Unknown reason'}."
|
||||||
@ -1019,8 +1045,17 @@ def _issue_is_stale_for_current_state(
|
|||||||
}:
|
}:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
if event == "EXECUTION_BLOCKED" and reason_key == "market_closed":
|
if event == "EXECUTION_BLOCKED" and reason_key.startswith("market_"):
|
||||||
return is_market_open(market_now())
|
current_session = market_session(market_now())
|
||||||
|
current_reason = str(current_session.get("reason") or "").strip().lower()
|
||||||
|
current_status = str(current_session.get("status") or "").strip().upper()
|
||||||
|
if reason_key == "market_holiday":
|
||||||
|
return current_status != "HOLIDAY"
|
||||||
|
if reason_key == "market_calendar_unavailable":
|
||||||
|
return current_reason != "calendar_unavailable"
|
||||||
|
if reason_key in {"market_weekend", "market_pre_open", "market_post_close", "market_closed"}:
|
||||||
|
return current_status == "OPEN"
|
||||||
|
return False
|
||||||
|
|
||||||
if mode != "LIVE":
|
if mode != "LIVE":
|
||||||
return False
|
return False
|
||||||
@ -1085,7 +1120,7 @@ def get_strategy_summary(user_id: str):
|
|||||||
"tone": "error" if event in {"ENGINE_ERROR", "ORDER_REJECTED"} else "warning",
|
"tone": "error" if event in {"ENGINE_ERROR", "ORDER_REJECTED"} else "warning",
|
||||||
"message": _issue_message(event, message, data, meta),
|
"message": _issue_message(event, message, data, meta),
|
||||||
"event": event,
|
"event": event,
|
||||||
"ts": _format_local_ts(ts),
|
"ts": serialize_timestamp(ts),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
return summary
|
return summary
|
||||||
@ -1119,7 +1154,17 @@ def get_strategy_summary(user_id: str):
|
|||||||
|
|
||||||
def get_market_status():
|
def get_market_status():
|
||||||
now = market_now()
|
now = market_now()
|
||||||
|
session = market_session(now)
|
||||||
|
status = str(session.get("status") or "CLOSED")
|
||||||
|
reason = str(session.get("reason") or "")
|
||||||
|
next_open_at = None
|
||||||
|
try:
|
||||||
|
next_open_at = serialize_timestamp(next_market_open_after(now))
|
||||||
|
except UnsupportedCalendarYearError:
|
||||||
|
next_open_at = None
|
||||||
return {
|
return {
|
||||||
"status": "OPEN" if is_market_open(now) else "CLOSED",
|
"status": status,
|
||||||
"checked_at": now.isoformat(),
|
"reason": reason,
|
||||||
|
"checked_at": serialize_timestamp(now),
|
||||||
|
"next_open_at": next_open_at,
|
||||||
}
|
}
|
||||||
|
|||||||
224
backend/app/services/support_abuse.py
Normal file
224
backend/app/services/support_abuse.py
Normal file
@ -0,0 +1,224 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import hashlib
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import threading
|
||||||
|
from collections import deque
|
||||||
|
from datetime import datetime, timedelta, timezone
|
||||||
|
from typing import Deque
|
||||||
|
|
||||||
|
from fastapi import Request
|
||||||
|
|
||||||
|
from app.services.db import db_connection
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
_MEMORY_LOCK = threading.Lock()
|
||||||
|
_MEMORY_EVENTS: list[dict] = []
|
||||||
|
|
||||||
|
|
||||||
|
class SupportGuardRejected(Exception):
|
||||||
|
def __init__(self, status_code: int, detail: str):
|
||||||
|
super().__init__(detail)
|
||||||
|
self.status_code = status_code
|
||||||
|
self.detail = detail
|
||||||
|
|
||||||
|
|
||||||
|
def _now_utc() -> datetime:
|
||||||
|
return datetime.now(timezone.utc)
|
||||||
|
|
||||||
|
|
||||||
|
def _sha256(value: str | None) -> str | None:
|
||||||
|
if not value:
|
||||||
|
return None
|
||||||
|
return hashlib.sha256(value.strip().lower().encode("utf-8")).hexdigest()
|
||||||
|
|
||||||
|
|
||||||
|
def _backend_mode() -> str:
|
||||||
|
return (os.getenv("SUPPORT_GUARD_BACKEND") or "db").strip().lower()
|
||||||
|
|
||||||
|
|
||||||
|
def _window() -> timedelta:
|
||||||
|
return timedelta(seconds=int(os.getenv("SUPPORT_GUARD_WINDOW_SECONDS", "900")))
|
||||||
|
|
||||||
|
|
||||||
|
def _create_limit() -> int:
|
||||||
|
return int(os.getenv("SUPPORT_CREATE_LIMIT", "5"))
|
||||||
|
|
||||||
|
|
||||||
|
def _status_limit() -> int:
|
||||||
|
return int(os.getenv("SUPPORT_STATUS_LIMIT", "15"))
|
||||||
|
|
||||||
|
|
||||||
|
def _ticket_probe_limit() -> int:
|
||||||
|
return int(os.getenv("SUPPORT_STATUS_TICKET_LIMIT", "10"))
|
||||||
|
|
||||||
|
|
||||||
|
def _captcha_secret() -> str | None:
|
||||||
|
return (os.getenv("SUPPORT_CAPTCHA_SECRET") or "").strip() or None
|
||||||
|
|
||||||
|
|
||||||
|
def _request_ip(request: Request) -> str:
|
||||||
|
forwarded = request.headers.get("x-forwarded-for")
|
||||||
|
if forwarded:
|
||||||
|
first = forwarded.split(",")[0].strip()
|
||||||
|
if first:
|
||||||
|
return first
|
||||||
|
return request.client.host if request.client else "unknown"
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_captcha(captcha_token: str | None) -> None:
|
||||||
|
secret = _captcha_secret()
|
||||||
|
if secret and captcha_token != secret:
|
||||||
|
raise SupportGuardRejected(403, "Support verification failed")
|
||||||
|
|
||||||
|
|
||||||
|
def _record_memory_event(record: dict) -> None:
|
||||||
|
cutoff = _now_utc() - _window()
|
||||||
|
with _MEMORY_LOCK:
|
||||||
|
_MEMORY_EVENTS[:] = [entry for entry in _MEMORY_EVENTS if entry["created_at"] >= cutoff]
|
||||||
|
_MEMORY_EVENTS.append(record)
|
||||||
|
|
||||||
|
|
||||||
|
def _memory_count(endpoint: str, ip_hash: str | None, ticket_hash: str | None, cutoff: datetime) -> tuple[int, int]:
|
||||||
|
with _MEMORY_LOCK:
|
||||||
|
ip_count = sum(
|
||||||
|
1
|
||||||
|
for entry in _MEMORY_EVENTS
|
||||||
|
if entry["endpoint"] == endpoint
|
||||||
|
and entry["created_at"] >= cutoff
|
||||||
|
and entry["ip_hash"] == ip_hash
|
||||||
|
)
|
||||||
|
ticket_count = sum(
|
||||||
|
1
|
||||||
|
for entry in _MEMORY_EVENTS
|
||||||
|
if entry["endpoint"] == endpoint
|
||||||
|
and entry["created_at"] >= cutoff
|
||||||
|
and entry["ticket_hash"] == ticket_hash
|
||||||
|
) if ticket_hash else 0
|
||||||
|
return ip_count, ticket_count
|
||||||
|
|
||||||
|
|
||||||
|
def _record_db_event(record: dict) -> None:
|
||||||
|
with db_connection() as conn:
|
||||||
|
with conn:
|
||||||
|
with conn.cursor() as cur:
|
||||||
|
cur.execute(
|
||||||
|
"""
|
||||||
|
INSERT INTO support_request_audit (
|
||||||
|
endpoint, ip_hash, email_hash, ticket_hash, blocked, reason, created_at
|
||||||
|
)
|
||||||
|
VALUES (%s, %s, %s, %s, %s, %s, %s)
|
||||||
|
""",
|
||||||
|
(
|
||||||
|
record["endpoint"],
|
||||||
|
record["ip_hash"],
|
||||||
|
record["email_hash"],
|
||||||
|
record["ticket_hash"],
|
||||||
|
record["blocked"],
|
||||||
|
record["reason"],
|
||||||
|
record["created_at"],
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _db_count(endpoint: str, ip_hash: str | None, ticket_hash: str | None, cutoff: datetime) -> tuple[int, int]:
|
||||||
|
with db_connection() as conn:
|
||||||
|
with conn.cursor() as cur:
|
||||||
|
cur.execute(
|
||||||
|
"""
|
||||||
|
SELECT COUNT(*)
|
||||||
|
FROM support_request_audit
|
||||||
|
WHERE endpoint = %s
|
||||||
|
AND ip_hash IS NOT DISTINCT FROM %s
|
||||||
|
AND created_at >= %s
|
||||||
|
""",
|
||||||
|
(endpoint, ip_hash, cutoff),
|
||||||
|
)
|
||||||
|
ip_count = cur.fetchone()[0] or 0
|
||||||
|
ticket_count = 0
|
||||||
|
if ticket_hash:
|
||||||
|
cur.execute(
|
||||||
|
"""
|
||||||
|
SELECT COUNT(*)
|
||||||
|
FROM support_request_audit
|
||||||
|
WHERE endpoint = %s
|
||||||
|
AND ticket_hash = %s
|
||||||
|
AND created_at >= %s
|
||||||
|
""",
|
||||||
|
(endpoint, ticket_hash, cutoff),
|
||||||
|
)
|
||||||
|
ticket_count = cur.fetchone()[0] or 0
|
||||||
|
return ip_count, ticket_count
|
||||||
|
|
||||||
|
|
||||||
|
def _determine_limits(endpoint: str, ip_count: int, ticket_count: int) -> str | None:
|
||||||
|
if endpoint == "ticket_create" and ip_count >= _create_limit():
|
||||||
|
return "create_rate_limited"
|
||||||
|
if endpoint == "ticket_status" and ip_count >= _status_limit():
|
||||||
|
return "status_rate_limited"
|
||||||
|
if endpoint == "ticket_status" and ticket_count >= _ticket_probe_limit():
|
||||||
|
return "ticket_probe_limited"
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _audit_attempt(record: dict) -> None:
|
||||||
|
if _backend_mode() == "memory":
|
||||||
|
_record_memory_event(record)
|
||||||
|
return
|
||||||
|
_record_db_event(record)
|
||||||
|
|
||||||
|
|
||||||
|
def _count_recent(endpoint: str, ip_hash: str | None, ticket_hash: str | None, cutoff: datetime) -> tuple[int, int]:
|
||||||
|
if _backend_mode() == "memory":
|
||||||
|
return _memory_count(endpoint, ip_hash, ticket_hash, cutoff)
|
||||||
|
return _db_count(endpoint, ip_hash, ticket_hash, cutoff)
|
||||||
|
|
||||||
|
|
||||||
|
def enforce_support_guard(
|
||||||
|
*,
|
||||||
|
request: Request,
|
||||||
|
endpoint: str,
|
||||||
|
email: str | None = None,
|
||||||
|
ticket_id: str | None = None,
|
||||||
|
captcha_token: str | None = None,
|
||||||
|
) -> None:
|
||||||
|
_validate_captcha(captcha_token)
|
||||||
|
|
||||||
|
now = _now_utc()
|
||||||
|
cutoff = now - _window()
|
||||||
|
ip_hash = _sha256(_request_ip(request))
|
||||||
|
email_hash = _sha256(email)
|
||||||
|
ticket_hash = _sha256(ticket_id)
|
||||||
|
|
||||||
|
ip_count, ticket_count = _count_recent(endpoint, ip_hash, ticket_hash, cutoff)
|
||||||
|
reason = _determine_limits(endpoint, ip_count, ticket_count)
|
||||||
|
|
||||||
|
record = {
|
||||||
|
"endpoint": endpoint,
|
||||||
|
"ip_hash": ip_hash,
|
||||||
|
"email_hash": email_hash,
|
||||||
|
"ticket_hash": ticket_hash,
|
||||||
|
"blocked": reason is not None,
|
||||||
|
"reason": reason,
|
||||||
|
"created_at": now,
|
||||||
|
}
|
||||||
|
_audit_attempt(record)
|
||||||
|
|
||||||
|
if reason is not None:
|
||||||
|
logger.warning(
|
||||||
|
"Support request blocked",
|
||||||
|
extra={
|
||||||
|
"endpoint": endpoint,
|
||||||
|
"reason": reason,
|
||||||
|
"ip_hash": ip_hash,
|
||||||
|
"ticket_hash": ticket_hash,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
raise SupportGuardRejected(429, "Too many support requests. Please try again later.")
|
||||||
|
|
||||||
|
|
||||||
|
def reset_memory_support_guard_state() -> None:
|
||||||
|
with _MEMORY_LOCK:
|
||||||
|
_MEMORY_EVENTS.clear()
|
||||||
@ -4,6 +4,7 @@ from uuid import uuid4
|
|||||||
|
|
||||||
from app.services.db import db_connection
|
from app.services.db import db_connection
|
||||||
from app.services.email_service import send_email
|
from app.services.email_service import send_email
|
||||||
|
from indian_paper_trading_strategy.engine.time_utils import serialize_timestamp
|
||||||
|
|
||||||
|
|
||||||
def _now():
|
def _now():
|
||||||
@ -41,7 +42,7 @@ def create_ticket(name: str, email: str, subject: str, message: str) -> dict:
|
|||||||
return {
|
return {
|
||||||
"ticket_id": ticket_id,
|
"ticket_id": ticket_id,
|
||||||
"status": "NEW",
|
"status": "NEW",
|
||||||
"created_at": now.isoformat(),
|
"created_at": serialize_timestamp(now),
|
||||||
"email_sent": email_sent,
|
"email_sent": email_sent,
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -65,6 +66,6 @@ def get_ticket_status(ticket_id: str, email: str) -> dict | None:
|
|||||||
return {
|
return {
|
||||||
"ticket_id": row[0],
|
"ticket_id": row[0],
|
||||||
"status": row[2],
|
"status": row[2],
|
||||||
"created_at": row[3].isoformat() if row[3] else None,
|
"created_at": serialize_timestamp(row[3]) if row[3] else None,
|
||||||
"updated_at": row[4].isoformat() if row[4] else None,
|
"updated_at": serialize_timestamp(row[4]) if row[4] else None,
|
||||||
}
|
}
|
||||||
|
|||||||
@ -5,6 +5,8 @@ from datetime import datetime, timezone
|
|||||||
|
|
||||||
from psycopg2.extras import Json
|
from psycopg2.extras import Json
|
||||||
|
|
||||||
|
from indian_paper_trading_strategy.engine.time_utils import parse_persisted_timestamp, serialize_timestamp
|
||||||
|
|
||||||
from app.broker_store import get_user_broker, set_broker_auth_state
|
from app.broker_store import get_user_broker, set_broker_auth_state
|
||||||
from app.services.db import db_connection
|
from app.services.db import db_connection
|
||||||
from app.services.groww_service import GrowwApiError, GrowwTokenError, fetch_funds as fetch_groww_funds
|
from app.services.groww_service import GrowwApiError, GrowwTokenError, fetch_funds as fetch_groww_funds
|
||||||
@ -59,12 +61,7 @@ def _resolve_sip_frequency(row: dict):
|
|||||||
|
|
||||||
|
|
||||||
def _parse_ts(value: str | None):
|
def _parse_ts(value: str | None):
|
||||||
if not value:
|
return parse_persisted_timestamp(value)
|
||||||
return None
|
|
||||||
try:
|
|
||||||
return datetime.fromisoformat(value)
|
|
||||||
except ValueError:
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def _validate_broker_session(user_id: str):
|
def _validate_broker_session(user_id: str):
|
||||||
@ -180,7 +177,7 @@ def arm_system(user_id: str, client_ip: str | None = None):
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
sip_frequency = _resolve_sip_frequency(run)
|
sip_frequency = _resolve_sip_frequency(run)
|
||||||
last_run = now.isoformat()
|
last_run = serialize_timestamp(now)
|
||||||
next_run = compute_next_eligible(last_run, sip_frequency)
|
next_run = compute_next_eligible(last_run, sip_frequency)
|
||||||
next_run_dt = _parse_ts(next_run)
|
next_run_dt = _parse_ts(next_run)
|
||||||
|
|
||||||
@ -195,7 +192,7 @@ def arm_system(user_id: str, client_ip: str | None = None):
|
|||||||
""",
|
""",
|
||||||
(
|
(
|
||||||
now,
|
now,
|
||||||
Json({"armed_at": now.isoformat()}),
|
Json({"armed_at": serialize_timestamp(now)}),
|
||||||
user_id,
|
user_id,
|
||||||
run["run_id"],
|
run["run_id"],
|
||||||
),
|
),
|
||||||
@ -339,7 +336,7 @@ def arm_system(user_id: str, client_ip: str | None = None):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
broker_state = get_user_broker(user_id) or {}
|
broker_state = get_user_broker(user_id) or {}
|
||||||
next_execution = min(next_runs).isoformat() if next_runs else None
|
next_execution = serialize_timestamp(min(next_runs)) if next_runs else None
|
||||||
return {
|
return {
|
||||||
"ok": True,
|
"ok": True,
|
||||||
"armed_runs": armed_runs,
|
"armed_runs": armed_runs,
|
||||||
@ -378,7 +375,7 @@ def system_status(user_id: str):
|
|||||||
"strategy": row[2],
|
"strategy": row[2],
|
||||||
"mode": row[3],
|
"mode": row[3],
|
||||||
"broker": row[4],
|
"broker": row[4],
|
||||||
"next_run": row[5].isoformat() if row[5] else None,
|
"next_run": serialize_timestamp(row[5]),
|
||||||
"active": bool(row[6]) if row[6] is not None else False,
|
"active": bool(row[6]) if row[6] is not None else False,
|
||||||
"lifecycle": row[1],
|
"lifecycle": row[1],
|
||||||
}
|
}
|
||||||
|
|||||||
@ -1,7 +1,6 @@
|
|||||||
from fastapi import HTTPException, Request
|
from fastapi import HTTPException, Request
|
||||||
|
|
||||||
from app.services.auth_service import get_user_for_session
|
from app.services.auth_service import get_user_for_session
|
||||||
from app.services.run_service import get_default_user_id
|
|
||||||
|
|
||||||
SESSION_COOKIE_NAME = "session_id"
|
SESSION_COOKIE_NAME = "session_id"
|
||||||
|
|
||||||
@ -13,7 +12,4 @@ def get_request_user_id(request: Request) -> str:
|
|||||||
if user:
|
if user:
|
||||||
return user["id"]
|
return user["id"]
|
||||||
|
|
||||||
default_user_id = get_default_user_id()
|
|
||||||
if default_user_id:
|
|
||||||
return default_user_id
|
|
||||||
raise HTTPException(status_code=401, detail="Not authenticated")
|
raise HTTPException(status_code=401, detail="Not authenticated")
|
||||||
|
|||||||
@ -27,11 +27,13 @@ class KitePermissionError(KiteApiError):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
def build_login_url(api_key: str, redirect_url: str | None = None) -> str:
|
def build_login_url(api_key: str, redirect_url: str | None = None, state: str | None = None) -> str:
|
||||||
params = {"api_key": api_key, "v": KITE_VERSION}
|
params = {"api_key": api_key, "v": KITE_VERSION}
|
||||||
redirect_url = (redirect_url or os.getenv("ZERODHA_REDIRECT_URL") or "").strip()
|
redirect_url = (redirect_url or os.getenv("ZERODHA_REDIRECT_URL") or "").strip()
|
||||||
if redirect_url:
|
if redirect_url:
|
||||||
params["redirect_url"] = redirect_url
|
params["redirect_url"] = redirect_url
|
||||||
|
if state:
|
||||||
|
params["state"] = state
|
||||||
query = urllib.parse.urlencode(params)
|
query = urllib.parse.urlencode(params)
|
||||||
return f"{KITE_LOGIN_URL}?{query}"
|
return f"{KITE_LOGIN_URL}?{query}"
|
||||||
|
|
||||||
|
|||||||
@ -41,3 +41,4 @@ websockets==16.0
|
|||||||
yfinance==1.0
|
yfinance==1.0
|
||||||
alembic==1.13.3
|
alembic==1.13.3
|
||||||
pytest==8.3.5
|
pytest==8.3.5
|
||||||
|
argon2-cffi==25.1.0
|
||||||
|
|||||||
14
backend/tests/conftest.py
Normal file
14
backend/tests/conftest.py
Normal file
@ -0,0 +1,14 @@
|
|||||||
|
import os
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
|
||||||
|
BACKEND_ROOT = Path(__file__).resolve().parents[1]
|
||||||
|
if str(BACKEND_ROOT) not in sys.path:
|
||||||
|
sys.path.insert(0, str(BACKEND_ROOT))
|
||||||
|
|
||||||
|
PROJECT_ROOT = Path(__file__).resolve().parents[2]
|
||||||
|
if str(PROJECT_ROOT) not in sys.path:
|
||||||
|
sys.path.insert(0, str(PROJECT_ROOT))
|
||||||
|
|
||||||
|
os.environ.setdefault("RESET_OTP_SECRET", "test-reset-secret")
|
||||||
111
backend/tests/test_api_semantics_and_utc.py
Normal file
111
backend/tests/test_api_semantics_and_utc.py
Normal file
@ -0,0 +1,111 @@
|
|||||||
|
import importlib
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
|
||||||
|
|
||||||
|
def _build_app(monkeypatch):
|
||||||
|
monkeypatch.setenv("APP_ENV", "test")
|
||||||
|
monkeypatch.setenv("DISABLE_STARTUP_TASKS", "1")
|
||||||
|
monkeypatch.setenv("DB_HOST", "localhost")
|
||||||
|
monkeypatch.setenv("DB_NAME", "trading_db")
|
||||||
|
monkeypatch.setenv("DB_USER", "trader")
|
||||||
|
monkeypatch.setenv("DB_PASSWORD", "test-password")
|
||||||
|
monkeypatch.setenv("CORS_ORIGINS", "http://localhost:3000")
|
||||||
|
|
||||||
|
import app.main as app_main
|
||||||
|
|
||||||
|
importlib.reload(app_main)
|
||||||
|
return app_main.create_app()
|
||||||
|
|
||||||
|
|
||||||
|
def test_strategy_stop_failure_returns_non_200(monkeypatch):
|
||||||
|
app = _build_app(monkeypatch)
|
||||||
|
client = TestClient(app)
|
||||||
|
|
||||||
|
import app.routers.strategy as strategy_router
|
||||||
|
|
||||||
|
monkeypatch.setattr(strategy_router, "get_request_user_id", lambda _request: "user-1")
|
||||||
|
monkeypatch.setattr(
|
||||||
|
strategy_router,
|
||||||
|
"stop_strategy",
|
||||||
|
lambda _user_id: {"status": "stop_failed", "message": "run missing", "run_id": "run-1"},
|
||||||
|
)
|
||||||
|
|
||||||
|
response = client.post("/api/strategy/stop", cookies={"session_id": "session-1"})
|
||||||
|
|
||||||
|
assert response.status_code == 409
|
||||||
|
assert response.json()["detail"]["status"] == "stop_failed"
|
||||||
|
|
||||||
|
|
||||||
|
def test_strategy_resume_failure_returns_non_200(monkeypatch):
|
||||||
|
app = _build_app(monkeypatch)
|
||||||
|
client = TestClient(app)
|
||||||
|
|
||||||
|
import app.routers.strategy as strategy_router
|
||||||
|
|
||||||
|
monkeypatch.setattr(strategy_router, "get_request_user_id", lambda _request: "user-1")
|
||||||
|
monkeypatch.setattr(
|
||||||
|
strategy_router,
|
||||||
|
"resume_strategy",
|
||||||
|
lambda _user_id: {"status": "resume_failed", "message": "engine start failed", "run_id": "run-1"},
|
||||||
|
)
|
||||||
|
|
||||||
|
response = client.post("/api/strategy/resume", cookies={"session_id": "session-1"})
|
||||||
|
|
||||||
|
assert response.status_code == 409
|
||||||
|
assert response.json()["detail"]["status"] == "resume_failed"
|
||||||
|
|
||||||
|
|
||||||
|
def test_update_run_status_returns_false_when_no_row_changes(monkeypatch):
|
||||||
|
import app.services.run_service as run_service
|
||||||
|
|
||||||
|
class FakeCursor:
|
||||||
|
rowcount = 0
|
||||||
|
|
||||||
|
def execute(self, sql, params):
|
||||||
|
return None
|
||||||
|
|
||||||
|
monkeypatch.setattr(run_service, "run_with_retry", lambda op: op(FakeCursor(), None))
|
||||||
|
|
||||||
|
updated = run_service.update_run_status("user-1", "missing-run", "STOPPED")
|
||||||
|
|
||||||
|
assert updated is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_backend_emits_utc_iso_timestamps():
|
||||||
|
from app.services.strategy_service import get_market_status
|
||||||
|
from indian_paper_trading_strategy.engine.time_utils import (
|
||||||
|
parse_persisted_timestamp,
|
||||||
|
serialize_timestamp,
|
||||||
|
)
|
||||||
|
|
||||||
|
timestamp = get_market_status()["checked_at"]
|
||||||
|
parsed = parse_persisted_timestamp(timestamp)
|
||||||
|
|
||||||
|
assert timestamp.endswith("+00:00")
|
||||||
|
assert parsed.tzinfo is not None
|
||||||
|
assert serialize_timestamp(parsed) == timestamp
|
||||||
|
|
||||||
|
|
||||||
|
def test_bootstrap_schema_contains_migrated_core_columns_and_tables():
|
||||||
|
schema_path = Path(__file__).resolve().parents[2] / ".." / "SIP_GoldBees_Database" / "schema.sql"
|
||||||
|
schema_sql = schema_path.read_text(encoding="utf-8")
|
||||||
|
|
||||||
|
required_snippets = [
|
||||||
|
"ALTER TABLE app_session",
|
||||||
|
"ADD COLUMN IF NOT EXISTS ip TEXT",
|
||||||
|
"ADD COLUMN IF NOT EXISTS user_agent TEXT",
|
||||||
|
"ALTER TABLE user_broker",
|
||||||
|
"ADD COLUMN IF NOT EXISTS api_secret TEXT",
|
||||||
|
"ADD COLUMN IF NOT EXISTS auth_state TEXT",
|
||||||
|
"CREATE TABLE IF NOT EXISTS broker_callback_state",
|
||||||
|
"CREATE TABLE IF NOT EXISTS execution_claim",
|
||||||
|
"CREATE TABLE IF NOT EXISTS run_leases",
|
||||||
|
"CREATE TABLE IF NOT EXISTS support_request_audit",
|
||||||
|
]
|
||||||
|
|
||||||
|
for snippet in required_snippets:
|
||||||
|
assert snippet in schema_sql
|
||||||
128
backend/tests/test_auth_isolation_and_cors.py
Normal file
128
backend/tests/test_auth_isolation_and_cors.py
Normal file
@ -0,0 +1,128 @@
|
|||||||
|
import importlib
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
|
||||||
|
|
||||||
|
def _build_app(monkeypatch, *, app_env="test", cors_origins=None):
|
||||||
|
monkeypatch.setenv("APP_ENV", app_env)
|
||||||
|
monkeypatch.setenv("DISABLE_STARTUP_TASKS", "1")
|
||||||
|
monkeypatch.setenv("DB_HOST", "localhost")
|
||||||
|
monkeypatch.setenv("DB_NAME", "trading_db")
|
||||||
|
monkeypatch.setenv("DB_USER", "trader")
|
||||||
|
monkeypatch.setenv("DB_PASSWORD", "test-password")
|
||||||
|
if cors_origins is None:
|
||||||
|
monkeypatch.delenv("CORS_ORIGINS", raising=False)
|
||||||
|
else:
|
||||||
|
monkeypatch.setenv("CORS_ORIGINS", cors_origins)
|
||||||
|
|
||||||
|
import app.main as app_main
|
||||||
|
|
||||||
|
importlib.reload(app_main)
|
||||||
|
return app_main.create_app()
|
||||||
|
|
||||||
|
|
||||||
|
def test_strategy_status_requires_auth(monkeypatch):
|
||||||
|
app = _build_app(monkeypatch)
|
||||||
|
client = TestClient(app)
|
||||||
|
|
||||||
|
response = client.get("/api/strategy/status")
|
||||||
|
|
||||||
|
assert response.status_code == 401
|
||||||
|
assert response.json() == {"detail": "Not authenticated"}
|
||||||
|
|
||||||
|
|
||||||
|
def test_strategy_stop_requires_auth(monkeypatch):
|
||||||
|
app = _build_app(monkeypatch)
|
||||||
|
client = TestClient(app)
|
||||||
|
|
||||||
|
response = client.post("/api/strategy/stop")
|
||||||
|
|
||||||
|
assert response.status_code == 401
|
||||||
|
assert response.json() == {"detail": "Not authenticated"}
|
||||||
|
|
||||||
|
|
||||||
|
def test_strategy_routes_use_session_identity_only(monkeypatch):
|
||||||
|
app = _build_app(monkeypatch)
|
||||||
|
client = TestClient(app)
|
||||||
|
|
||||||
|
import app.routers.strategy as strategy_router
|
||||||
|
import app.services.tenant as tenant
|
||||||
|
|
||||||
|
monkeypatch.setattr(tenant, "get_user_for_session", lambda _sid: {"id": "user-a"})
|
||||||
|
|
||||||
|
seen = {}
|
||||||
|
|
||||||
|
def fake_get_strategy_status(user_id):
|
||||||
|
seen["status_user_id"] = user_id
|
||||||
|
return {"user_id": user_id}
|
||||||
|
|
||||||
|
def fake_stop_strategy(user_id):
|
||||||
|
seen["stop_user_id"] = user_id
|
||||||
|
return {"status": "stopped", "user_id": user_id}
|
||||||
|
|
||||||
|
monkeypatch.setattr(strategy_router, "get_strategy_status", fake_get_strategy_status)
|
||||||
|
monkeypatch.setattr(strategy_router, "stop_strategy", fake_stop_strategy)
|
||||||
|
|
||||||
|
status_response = client.get(
|
||||||
|
"/api/strategy/status?user_id=user-b",
|
||||||
|
cookies={"session_id": "session-a"},
|
||||||
|
)
|
||||||
|
stop_response = client.post(
|
||||||
|
"/api/strategy/stop?user_id=user-b",
|
||||||
|
cookies={"session_id": "session-a"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert status_response.status_code == 200
|
||||||
|
assert stop_response.status_code == 200
|
||||||
|
assert status_response.json()["user_id"] == "user-a"
|
||||||
|
assert stop_response.json()["user_id"] == "user-a"
|
||||||
|
assert seen == {"status_user_id": "user-a", "stop_user_id": "user-a"}
|
||||||
|
|
||||||
|
|
||||||
|
def test_allowed_origin_preflight_supports_credentials(monkeypatch):
|
||||||
|
app = _build_app(
|
||||||
|
monkeypatch,
|
||||||
|
app_env="production",
|
||||||
|
cors_origins="https://app.quantfortune.com",
|
||||||
|
)
|
||||||
|
client = TestClient(app)
|
||||||
|
|
||||||
|
response = client.options(
|
||||||
|
"/api/strategy/status",
|
||||||
|
headers={
|
||||||
|
"Origin": "https://app.quantfortune.com",
|
||||||
|
"Access-Control-Request-Method": "GET",
|
||||||
|
"Access-Control-Request-Headers": "content-type",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.headers["access-control-allow-origin"] == "https://app.quantfortune.com"
|
||||||
|
assert response.headers["access-control-allow-credentials"] == "true"
|
||||||
|
|
||||||
|
|
||||||
|
def test_arbitrary_origin_preflight_is_rejected(monkeypatch):
|
||||||
|
app = _build_app(
|
||||||
|
monkeypatch,
|
||||||
|
app_env="production",
|
||||||
|
cors_origins="https://app.quantfortune.com",
|
||||||
|
)
|
||||||
|
client = TestClient(app)
|
||||||
|
|
||||||
|
response = client.options(
|
||||||
|
"/api/strategy/status",
|
||||||
|
headers={
|
||||||
|
"Origin": "https://evil.example",
|
||||||
|
"Access-Control-Request-Method": "GET",
|
||||||
|
"Access-Control-Request-Headers": "content-type",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 400
|
||||||
|
assert "access-control-allow-origin" not in response.headers
|
||||||
|
|
||||||
|
|
||||||
|
def test_production_without_cors_origins_fails_closed(monkeypatch):
|
||||||
|
with pytest.raises(RuntimeError, match="CORS_ORIGINS must be configured explicitly in production"):
|
||||||
|
_build_app(monkeypatch, app_env="production", cors_origins=None)
|
||||||
120
backend/tests/test_execution_claims.py
Normal file
120
backend/tests/test_execution_claims.py
Normal file
@ -0,0 +1,120 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import threading
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
|
from indian_paper_trading_strategy.engine import execution, ledger
|
||||||
|
|
||||||
|
|
||||||
|
def test_claim_execution_window_allows_only_one_winner(monkeypatch):
|
||||||
|
claims: set[tuple[str, str, datetime]] = set()
|
||||||
|
lock = threading.Lock()
|
||||||
|
|
||||||
|
class FakeCursor:
|
||||||
|
def __init__(self):
|
||||||
|
self._result = None
|
||||||
|
|
||||||
|
def execute(self, sql, params):
|
||||||
|
assert "INSERT INTO execution_claim" in sql
|
||||||
|
key = (params[1], params[2], params[4])
|
||||||
|
with lock:
|
||||||
|
if key in claims:
|
||||||
|
self._result = None
|
||||||
|
else:
|
||||||
|
claims.add(key)
|
||||||
|
self._result = (params[0],)
|
||||||
|
|
||||||
|
def fetchone(self):
|
||||||
|
return self._result
|
||||||
|
|
||||||
|
monkeypatch.setattr(ledger, "get_context", lambda user_id=None, run_id=None: ("user-a", "run-a"))
|
||||||
|
monkeypatch.setattr(ledger, "run_with_retry", lambda op, retries=None, delay=None: op(FakeCursor(), None))
|
||||||
|
|
||||||
|
logical_time = datetime(2026, 4, 8, 9, 15, tzinfo=timezone.utc)
|
||||||
|
results: list[bool] = []
|
||||||
|
|
||||||
|
def attempt():
|
||||||
|
results.append(ledger.claim_execution_window(logical_time, mode="LIVE"))
|
||||||
|
|
||||||
|
threads = [threading.Thread(target=attempt) for _ in range(2)]
|
||||||
|
for thread in threads:
|
||||||
|
thread.start()
|
||||||
|
for thread in threads:
|
||||||
|
thread.join()
|
||||||
|
|
||||||
|
assert sorted(results) == [False, True]
|
||||||
|
|
||||||
|
|
||||||
|
def test_try_execute_sip_live_emits_one_order_batch_when_claim_is_lost(monkeypatch):
|
||||||
|
claim_lock = threading.Lock()
|
||||||
|
claimed = False
|
||||||
|
order_calls: list[tuple[str, str, float]] = []
|
||||||
|
|
||||||
|
class FakeBroker:
|
||||||
|
external_orders = True
|
||||||
|
|
||||||
|
def get_funds(self):
|
||||||
|
return {"cash": 10_000.0}
|
||||||
|
|
||||||
|
def place_order(self, symbol, side, qty, price, logical_time=None):
|
||||||
|
order_calls.append((symbol, side, qty))
|
||||||
|
return {
|
||||||
|
"id": f"{symbol}-{len(order_calls)}",
|
||||||
|
"symbol": symbol,
|
||||||
|
"side": side,
|
||||||
|
"status": "COMPLETE",
|
||||||
|
"filled_qty": qty,
|
||||||
|
"average_price": price,
|
||||||
|
"price": price,
|
||||||
|
}
|
||||||
|
|
||||||
|
def fake_claim_execution_window(logical_time, *, mode=None, cur=None, user_id=None, run_id=None):
|
||||||
|
nonlocal claimed
|
||||||
|
with claim_lock:
|
||||||
|
if claimed:
|
||||||
|
return False
|
||||||
|
claimed = True
|
||||||
|
return True
|
||||||
|
|
||||||
|
monkeypatch.setattr(execution, "load_state", lambda *args, **kwargs: {"last_sip_ts": None, "last_run": None})
|
||||||
|
monkeypatch.setattr(
|
||||||
|
execution,
|
||||||
|
"_resolve_timing",
|
||||||
|
lambda state, now_ts, sip_interval: (True, None, datetime(2026, 4, 8, 9, 15, tzinfo=timezone.utc)),
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(execution, "event_exists", lambda *args, **kwargs: False)
|
||||||
|
monkeypatch.setattr(execution, "claim_execution_window", fake_claim_execution_window)
|
||||||
|
monkeypatch.setattr(execution, "log_event", lambda *args, **kwargs: None)
|
||||||
|
monkeypatch.setattr(execution, "run_with_retry", lambda op, retries=None, delay=None: op(object(), None))
|
||||||
|
monkeypatch.setattr(
|
||||||
|
execution,
|
||||||
|
"_finalize_live_execution",
|
||||||
|
lambda **kwargs: ({"last_run": kwargs["now_ts"].isoformat()}, bool(kwargs["orders"])),
|
||||||
|
)
|
||||||
|
|
||||||
|
results: list[bool] = []
|
||||||
|
|
||||||
|
def attempt():
|
||||||
|
_state, executed = execution._try_execute_sip_live(
|
||||||
|
now=datetime(2026, 4, 8, 14, 45, tzinfo=timezone.utc),
|
||||||
|
market_open=True,
|
||||||
|
sip_interval=120,
|
||||||
|
sip_amount=1000.0,
|
||||||
|
sp_price=125.0,
|
||||||
|
gd_price=250.0,
|
||||||
|
eq_w=0.5,
|
||||||
|
gd_w=0.5,
|
||||||
|
broker=FakeBroker(),
|
||||||
|
mode="LIVE",
|
||||||
|
)
|
||||||
|
results.append(executed)
|
||||||
|
|
||||||
|
threads = [threading.Thread(target=attempt) for _ in range(2)]
|
||||||
|
for thread in threads:
|
||||||
|
thread.start()
|
||||||
|
for thread in threads:
|
||||||
|
thread.join()
|
||||||
|
|
||||||
|
assert sorted(results) == [False, True]
|
||||||
|
assert len(order_calls) == 2
|
||||||
|
assert {call[0] for call in order_calls} == {"NIFTYBEES.NS", "GOLDBEES.NS"}
|
||||||
102
backend/tests/test_market_calendar.py
Normal file
102
backend/tests/test_market_calendar.py
Normal file
@ -0,0 +1,102 @@
|
|||||||
|
from datetime import date, datetime, timezone
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
def test_market_calendar_open_until_close_boundary_exclusive():
|
||||||
|
from indian_paper_trading_strategy.engine.market_calendar import (
|
||||||
|
get_market_session,
|
||||||
|
get_market_status,
|
||||||
|
is_market_open,
|
||||||
|
)
|
||||||
|
|
||||||
|
open_now = datetime(2026, 4, 2, 9, 59, 59, tzinfo=timezone.utc) # 15:29:59 IST
|
||||||
|
close_now = datetime(2026, 4, 2, 10, 0, 0, tzinfo=timezone.utc) # 15:30:00 IST
|
||||||
|
|
||||||
|
assert get_market_status(open_now) == "OPEN"
|
||||||
|
assert is_market_open(open_now) is True
|
||||||
|
assert get_market_status(close_now) == "CLOSED"
|
||||||
|
assert is_market_open(close_now) is False
|
||||||
|
|
||||||
|
close_session = get_market_session(close_now)
|
||||||
|
assert close_session["reason"] == "POST_CLOSE"
|
||||||
|
|
||||||
|
|
||||||
|
def test_market_calendar_holiday_during_trading_hours():
|
||||||
|
from indian_paper_trading_strategy.engine.market_calendar import get_market_session, get_market_status
|
||||||
|
|
||||||
|
now_utc = datetime(2026, 4, 3, 4, 0, tzinfo=timezone.utc) # Good Friday, 09:30 IST
|
||||||
|
|
||||||
|
session = get_market_session(now_utc)
|
||||||
|
assert get_market_status(now_utc) == "HOLIDAY"
|
||||||
|
assert session["status"] == "HOLIDAY"
|
||||||
|
assert session["reason"] == "HOLIDAY"
|
||||||
|
|
||||||
|
|
||||||
|
def test_market_calendar_utc_and_ist_inputs_match():
|
||||||
|
from zoneinfo import ZoneInfo
|
||||||
|
|
||||||
|
from indian_paper_trading_strategy.engine.market_calendar import get_market_session
|
||||||
|
|
||||||
|
utc_now = datetime(2026, 4, 2, 4, 0, tzinfo=timezone.utc) # 09:30 IST
|
||||||
|
ist_now = utc_now.astimezone(ZoneInfo("Asia/Kolkata"))
|
||||||
|
|
||||||
|
assert get_market_session(utc_now) == get_market_session(ist_now)
|
||||||
|
|
||||||
|
|
||||||
|
def test_market_calendar_reason_codes_cover_session_windows():
|
||||||
|
from indian_paper_trading_strategy.engine.market_calendar import get_market_session
|
||||||
|
|
||||||
|
pre_open = datetime(2026, 4, 2, 3, 0, tzinfo=timezone.utc) # 08:30 IST
|
||||||
|
in_session = datetime(2026, 4, 2, 4, 0, tzinfo=timezone.utc) # 09:30 IST
|
||||||
|
post_close = datetime(2026, 4, 2, 10, 1, tzinfo=timezone.utc) # 15:31 IST
|
||||||
|
weekend = datetime(2026, 4, 4, 4, 0, tzinfo=timezone.utc) # Saturday
|
||||||
|
|
||||||
|
assert get_market_session(pre_open)["reason"] == "PRE_OPEN"
|
||||||
|
assert get_market_session(in_session)["reason"] == "OPEN"
|
||||||
|
assert get_market_session(post_close)["reason"] == "POST_CLOSE"
|
||||||
|
assert get_market_session(weekend)["reason"] == "WEEKEND"
|
||||||
|
|
||||||
|
|
||||||
|
def test_supported_year_returns_expected_holidays():
|
||||||
|
from indian_paper_trading_strategy.engine.market_calendar import get_nse_holidays
|
||||||
|
|
||||||
|
holidays = get_nse_holidays(2026)
|
||||||
|
|
||||||
|
assert date(2026, 4, 3) in holidays
|
||||||
|
assert date(2026, 1, 26) in holidays
|
||||||
|
|
||||||
|
|
||||||
|
def test_unsupported_year_fails_closed():
|
||||||
|
from indian_paper_trading_strategy.engine.market_calendar import (
|
||||||
|
UnsupportedCalendarYearError,
|
||||||
|
get_market_session,
|
||||||
|
next_market_open,
|
||||||
|
)
|
||||||
|
|
||||||
|
now_utc = datetime(2027, 1, 4, 4, 0, tzinfo=timezone.utc)
|
||||||
|
session = get_market_session(now_utc)
|
||||||
|
|
||||||
|
assert session["status"] == "CLOSED"
|
||||||
|
assert session["reason"] == "CALENDAR_UNAVAILABLE"
|
||||||
|
assert session["calendar_supported"] is False
|
||||||
|
|
||||||
|
with pytest.raises(UnsupportedCalendarYearError):
|
||||||
|
next_market_open(now_utc)
|
||||||
|
|
||||||
|
|
||||||
|
def test_strategy_service_market_status_exposes_reason_and_holiday(monkeypatch):
|
||||||
|
import app.services.strategy_service as strategy_service
|
||||||
|
from indian_paper_trading_strategy.engine.market_calendar import MARKET_TZ
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
strategy_service,
|
||||||
|
"market_now",
|
||||||
|
lambda: datetime(2026, 4, 3, 9, 30, tzinfo=MARKET_TZ),
|
||||||
|
)
|
||||||
|
|
||||||
|
payload = strategy_service.get_market_status()
|
||||||
|
|
||||||
|
assert payload["status"] == "HOLIDAY"
|
||||||
|
assert payload["reason"] == "HOLIDAY"
|
||||||
|
assert payload["next_open_at"] == "2026-04-06T03:45:00+00:00"
|
||||||
205
backend/tests/test_runner_leases.py
Normal file
205
backend/tests/test_runner_leases.py
Normal file
@ -0,0 +1,205 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import threading
|
||||||
|
from datetime import datetime, timedelta, timezone
|
||||||
|
|
||||||
|
from indian_paper_trading_strategy.engine import db as engine_db
|
||||||
|
from indian_paper_trading_strategy.engine import runner
|
||||||
|
|
||||||
|
|
||||||
|
class _LeaseCursor:
|
||||||
|
def __init__(self, leases: dict[str, dict], lock: threading.Lock):
|
||||||
|
self._leases = leases
|
||||||
|
self._lock = lock
|
||||||
|
self._result = None
|
||||||
|
|
||||||
|
def execute(self, sql, params):
|
||||||
|
sql_text = " ".join(sql.split())
|
||||||
|
with self._lock:
|
||||||
|
if sql_text.startswith("INSERT INTO run_leases"):
|
||||||
|
run_id, owner_id, leased_at, expires_at, heartbeat_at = params
|
||||||
|
lease = self._leases.get(run_id)
|
||||||
|
if lease is None:
|
||||||
|
self._leases[run_id] = {
|
||||||
|
"owner_id": owner_id,
|
||||||
|
"leased_at": leased_at,
|
||||||
|
"expires_at": expires_at,
|
||||||
|
"heartbeat_at": heartbeat_at,
|
||||||
|
}
|
||||||
|
self._result = (run_id,)
|
||||||
|
else:
|
||||||
|
self._result = None
|
||||||
|
return
|
||||||
|
|
||||||
|
if "SELECT owner_id, expires_at FROM run_leases" in sql_text:
|
||||||
|
run_id = params[0]
|
||||||
|
lease = self._leases.get(run_id)
|
||||||
|
self._result = None if lease is None else (lease["owner_id"], lease["expires_at"])
|
||||||
|
return
|
||||||
|
|
||||||
|
if sql_text.startswith("UPDATE run_leases SET leased_at"):
|
||||||
|
leased_at, expires_at, heartbeat_at, run_id, owner_id = params
|
||||||
|
lease = self._leases.get(run_id)
|
||||||
|
if lease and lease["owner_id"] == owner_id:
|
||||||
|
lease.update(
|
||||||
|
{
|
||||||
|
"owner_id": owner_id,
|
||||||
|
"leased_at": leased_at,
|
||||||
|
"expires_at": expires_at,
|
||||||
|
"heartbeat_at": heartbeat_at,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
self._result = (run_id,)
|
||||||
|
else:
|
||||||
|
self._result = None
|
||||||
|
return
|
||||||
|
|
||||||
|
if sql_text.startswith("UPDATE run_leases SET owner_id"):
|
||||||
|
owner_id, leased_at, expires_at, heartbeat_at, run_id, current_time = params
|
||||||
|
lease = self._leases.get(run_id)
|
||||||
|
if lease and lease["expires_at"] <= current_time:
|
||||||
|
lease.update(
|
||||||
|
{
|
||||||
|
"owner_id": owner_id,
|
||||||
|
"leased_at": leased_at,
|
||||||
|
"expires_at": expires_at,
|
||||||
|
"heartbeat_at": heartbeat_at,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
self._result = (run_id,)
|
||||||
|
else:
|
||||||
|
self._result = None
|
||||||
|
return
|
||||||
|
|
||||||
|
if sql_text.startswith("UPDATE run_leases SET heartbeat_at"):
|
||||||
|
heartbeat_at, expires_at, run_id, owner_id, current_time = params
|
||||||
|
lease = self._leases.get(run_id)
|
||||||
|
if lease and lease["owner_id"] == owner_id and lease["expires_at"] > current_time:
|
||||||
|
lease.update({"heartbeat_at": heartbeat_at, "expires_at": expires_at})
|
||||||
|
self._result = (run_id, expires_at)
|
||||||
|
else:
|
||||||
|
self._result = None
|
||||||
|
return
|
||||||
|
|
||||||
|
if sql_text.startswith("DELETE FROM run_leases"):
|
||||||
|
run_id, owner_id = params
|
||||||
|
lease = self._leases.get(run_id)
|
||||||
|
if lease and lease["owner_id"] == owner_id:
|
||||||
|
del self._leases[run_id]
|
||||||
|
self._result = (run_id,)
|
||||||
|
else:
|
||||||
|
self._result = None
|
||||||
|
return
|
||||||
|
|
||||||
|
raise AssertionError(f"Unexpected SQL: {sql_text}")
|
||||||
|
|
||||||
|
def fetchone(self):
|
||||||
|
return self._result
|
||||||
|
|
||||||
|
|
||||||
|
def _patch_lease_storage(monkeypatch):
|
||||||
|
leases: dict[str, dict] = {}
|
||||||
|
lock = threading.Lock()
|
||||||
|
|
||||||
|
def fake_run_with_retry(operation, retries=None, delay=None):
|
||||||
|
cursor = _LeaseCursor(leases, lock)
|
||||||
|
return operation(cursor, None)
|
||||||
|
|
||||||
|
monkeypatch.setattr(engine_db, "run_with_retry", fake_run_with_retry)
|
||||||
|
return leases
|
||||||
|
|
||||||
|
|
||||||
|
def test_run_lease_allows_only_one_active_owner(monkeypatch):
|
||||||
|
_patch_lease_storage(monkeypatch)
|
||||||
|
now = datetime(2026, 4, 8, 4, 0, tzinfo=timezone.utc)
|
||||||
|
results = []
|
||||||
|
|
||||||
|
def attempt(owner_id: str):
|
||||||
|
results.append(engine_db.acquire_run_lease("run-1", owner_id, lease_seconds=90, now=now))
|
||||||
|
|
||||||
|
threads = [
|
||||||
|
threading.Thread(target=attempt, args=("owner-a",)),
|
||||||
|
threading.Thread(target=attempt, args=("owner-b",)),
|
||||||
|
]
|
||||||
|
for thread in threads:
|
||||||
|
thread.start()
|
||||||
|
for thread in threads:
|
||||||
|
thread.join()
|
||||||
|
|
||||||
|
acquired = [result for result in results if result["acquired"]]
|
||||||
|
denied = [result for result in results if not result["acquired"]]
|
||||||
|
assert len(acquired) == 1
|
||||||
|
assert len(denied) == 1
|
||||||
|
assert denied[0]["status"] == "DENIED"
|
||||||
|
|
||||||
|
|
||||||
|
def test_expired_run_lease_can_be_reacquired(monkeypatch):
|
||||||
|
_patch_lease_storage(monkeypatch)
|
||||||
|
start = datetime(2026, 4, 8, 4, 0, tzinfo=timezone.utc)
|
||||||
|
later = start + timedelta(seconds=91)
|
||||||
|
|
||||||
|
first = engine_db.acquire_run_lease("run-1", "owner-a", lease_seconds=90, now=start)
|
||||||
|
second = engine_db.acquire_run_lease("run-1", "owner-b", lease_seconds=90, now=later)
|
||||||
|
|
||||||
|
assert first["status"] == "ACQUIRED"
|
||||||
|
assert second["acquired"] is True
|
||||||
|
assert second["status"] == "REACQUIRED"
|
||||||
|
assert second["previous_owner"] == "owner-a"
|
||||||
|
|
||||||
|
|
||||||
|
def test_heartbeat_prevents_takeover(monkeypatch):
|
||||||
|
_patch_lease_storage(monkeypatch)
|
||||||
|
start = datetime(2026, 4, 8, 4, 0, tzinfo=timezone.utc)
|
||||||
|
heartbeat_time = start + timedelta(seconds=30)
|
||||||
|
challenger_time = start + timedelta(seconds=60)
|
||||||
|
|
||||||
|
engine_db.acquire_run_lease("run-1", "owner-a", lease_seconds=90, now=start)
|
||||||
|
heartbeat = engine_db.heartbeat_run_lease("run-1", "owner-a", lease_seconds=90, now=heartbeat_time)
|
||||||
|
challenger = engine_db.acquire_run_lease("run-1", "owner-b", lease_seconds=90, now=challenger_time)
|
||||||
|
|
||||||
|
assert heartbeat["active"] is True
|
||||||
|
assert challenger["acquired"] is False
|
||||||
|
assert challenger["status"] == "DENIED"
|
||||||
|
|
||||||
|
|
||||||
|
def test_runner_exits_cleanly_when_lease_is_lost(monkeypatch):
|
||||||
|
statuses: list[str] = []
|
||||||
|
cleared: list[tuple[str, str]] = []
|
||||||
|
released: list[tuple[str, str]] = []
|
||||||
|
refreshed = {"count": 0}
|
||||||
|
|
||||||
|
monkeypatch.setattr(runner, "get_context", lambda user_id=None, run_id=None: ("user-1", "run-1"))
|
||||||
|
monkeypatch.setattr(runner, "set_context", lambda user_id, run_id: None)
|
||||||
|
monkeypatch.setattr(runner, "log_event", lambda *args, **kwargs: None)
|
||||||
|
monkeypatch.setattr(runner, "PaperBroker", lambda initial_cash=0: object())
|
||||||
|
monkeypatch.setattr(runner, "_set_state", lambda *args, **kwargs: None)
|
||||||
|
monkeypatch.setattr(runner, "_clear_runner", lambda user_id, run_id: cleared.append((user_id, run_id)))
|
||||||
|
monkeypatch.setattr(runner, "_update_engine_status", lambda user_id, run_id, status: statuses.append(status))
|
||||||
|
monkeypatch.setattr(runner, "release_run_lease", lambda run_id, owner_id: released.append((run_id, owner_id)) or True)
|
||||||
|
monkeypatch.setattr(runner, "_log_runner_lease_event", lambda *args, **kwargs: None)
|
||||||
|
|
||||||
|
def fake_refresh(user_id, run_id, owner_id):
|
||||||
|
refreshed["count"] += 1
|
||||||
|
return False
|
||||||
|
|
||||||
|
monkeypatch.setattr(runner, "_refresh_run_lease_or_stop", fake_refresh)
|
||||||
|
|
||||||
|
stop_event = threading.Event()
|
||||||
|
runner._engine_loop(
|
||||||
|
{
|
||||||
|
"user_id": "user-1",
|
||||||
|
"run_id": "run-1",
|
||||||
|
"strategy": "Golden Nifty",
|
||||||
|
"sip_amount": 1000,
|
||||||
|
"sip_frequency": {"value": 2, "unit": "minutes"},
|
||||||
|
"mode": "PAPER",
|
||||||
|
"broker": "paper",
|
||||||
|
"runner_owner_id": "owner-1",
|
||||||
|
},
|
||||||
|
stop_event,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert refreshed["count"] >= 1
|
||||||
|
assert statuses == ["RUNNING"]
|
||||||
|
assert cleared == [("user-1", "run-1")]
|
||||||
|
assert released == [("run-1", "owner-1")]
|
||||||
269
backend/tests/test_security_hardening.py
Normal file
269
backend/tests/test_security_hardening.py
Normal file
@ -0,0 +1,269 @@
|
|||||||
|
import importlib
|
||||||
|
from datetime import datetime, timedelta, timezone
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from fastapi import Response
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
|
||||||
|
|
||||||
|
def test_legacy_password_hash_upgrades_on_successful_login(monkeypatch):
|
||||||
|
import app.services.auth_service as auth_service
|
||||||
|
|
||||||
|
legacy_hash = auth_service._hash_password_legacy("correct-horse-battery-staple")
|
||||||
|
user = {
|
||||||
|
"id": "user-1",
|
||||||
|
"username": "user@example.com",
|
||||||
|
"password": legacy_hash,
|
||||||
|
"role": "USER",
|
||||||
|
}
|
||||||
|
updated = {}
|
||||||
|
|
||||||
|
monkeypatch.setattr(auth_service, "get_user_by_username", lambda username: user if username == user["username"] else None)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
auth_service,
|
||||||
|
"_update_password_hash",
|
||||||
|
lambda user_id, password_hash: updated.update({"user_id": user_id, "password_hash": password_hash}),
|
||||||
|
)
|
||||||
|
|
||||||
|
authenticated = auth_service.authenticate_user(user["username"], "correct-horse-battery-staple")
|
||||||
|
|
||||||
|
assert authenticated is user
|
||||||
|
assert updated["user_id"] == "user-1"
|
||||||
|
assert updated["password_hash"].startswith("$argon2id$")
|
||||||
|
assert authenticated["password"] == updated["password_hash"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_secure_session_cookie_flags_are_enforced_in_production(monkeypatch):
|
||||||
|
monkeypatch.setenv("APP_ENV", "production")
|
||||||
|
monkeypatch.delenv("COOKIE_SECURE", raising=False)
|
||||||
|
monkeypatch.setenv("COOKIE_SAMESITE", "strict")
|
||||||
|
|
||||||
|
import app.routers.auth as auth_router
|
||||||
|
|
||||||
|
importlib.reload(auth_router)
|
||||||
|
|
||||||
|
response = Response()
|
||||||
|
auth_router._set_session_cookie(response, "session-123")
|
||||||
|
set_cookie = response.headers["set-cookie"].lower()
|
||||||
|
|
||||||
|
assert "secure" in set_cookie
|
||||||
|
assert "httponly" in set_cookie
|
||||||
|
assert "samesite=strict" in set_cookie
|
||||||
|
|
||||||
|
|
||||||
|
def test_missing_reset_otp_secret_fails_auth_service_import(monkeypatch):
|
||||||
|
monkeypatch.delenv("RESET_OTP_SECRET", raising=False)
|
||||||
|
|
||||||
|
import app.services.auth_service as auth_service
|
||||||
|
|
||||||
|
with pytest.raises(RuntimeError, match="RESET_OTP_SECRET must be configured"):
|
||||||
|
importlib.reload(auth_service)
|
||||||
|
|
||||||
|
monkeypatch.setenv("RESET_OTP_SECRET", "test-reset-secret")
|
||||||
|
importlib.reload(auth_service)
|
||||||
|
|
||||||
|
|
||||||
|
def test_valid_broker_callback_state_allows_connect_callback(monkeypatch):
|
||||||
|
import app.main as app_main
|
||||||
|
import app.routers.broker as broker_router
|
||||||
|
|
||||||
|
monkeypatch.setenv("APP_ENV", "test")
|
||||||
|
monkeypatch.setenv("DISABLE_STARTUP_TASKS", "1")
|
||||||
|
monkeypatch.setenv("CORS_ORIGINS", "http://localhost:3000")
|
||||||
|
importlib.reload(app_main)
|
||||||
|
app = app_main.create_app()
|
||||||
|
client = TestClient(app)
|
||||||
|
|
||||||
|
monkeypatch.setattr(broker_router, "get_user_for_session", lambda _sid: {"id": "user-1", "username": "user@example.com"})
|
||||||
|
monkeypatch.setattr(
|
||||||
|
broker_router,
|
||||||
|
"consume_broker_callback_state",
|
||||||
|
lambda **kwargs: {"id": "state-1", "expires_at": datetime.now(timezone.utc).isoformat()},
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
broker_router,
|
||||||
|
"get_pending_broker",
|
||||||
|
lambda _user_id: {"api_key": "kite-key", "api_secret": "kite-secret"},
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
broker_router,
|
||||||
|
"exchange_request_token",
|
||||||
|
lambda api_key, api_secret, token: {
|
||||||
|
"access_token": "access-token",
|
||||||
|
"request_token": token,
|
||||||
|
"user_name": "Trader",
|
||||||
|
"user_id": "Z123",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
captured = {}
|
||||||
|
monkeypatch.setattr(
|
||||||
|
broker_router,
|
||||||
|
"set_zerodha_session",
|
||||||
|
lambda user_id, payload: captured.setdefault("zerodha_session", {"user_id": user_id, **payload}),
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
broker_router,
|
||||||
|
"set_connected_broker",
|
||||||
|
lambda user_id, broker, token, **kwargs: captured.setdefault(
|
||||||
|
"connected",
|
||||||
|
{"user_id": user_id, "broker": broker, "token": token, **kwargs},
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
response = client.get(
|
||||||
|
"/api/broker/zerodha/callback",
|
||||||
|
params={"request_token": "request-token", "state": "valid-state"},
|
||||||
|
cookies={"session_id": "session-1"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json() == {
|
||||||
|
"connected": True,
|
||||||
|
"userName": "Trader",
|
||||||
|
"brokerUserId": "Z123",
|
||||||
|
}
|
||||||
|
assert captured["connected"]["broker"] == "ZERODHA"
|
||||||
|
assert captured["connected"]["user_id"] == "user-1"
|
||||||
|
|
||||||
|
|
||||||
|
def test_missing_broker_callback_state_fails(monkeypatch):
|
||||||
|
import app.main as app_main
|
||||||
|
import app.routers.broker as broker_router
|
||||||
|
|
||||||
|
monkeypatch.setenv("APP_ENV", "test")
|
||||||
|
monkeypatch.setenv("DISABLE_STARTUP_TASKS", "1")
|
||||||
|
monkeypatch.setenv("CORS_ORIGINS", "http://localhost:3000")
|
||||||
|
importlib.reload(app_main)
|
||||||
|
app = app_main.create_app()
|
||||||
|
client = TestClient(app)
|
||||||
|
|
||||||
|
monkeypatch.setattr(broker_router, "get_user_for_session", lambda _sid: {"id": "user-1", "username": "user@example.com"})
|
||||||
|
|
||||||
|
response = client.get(
|
||||||
|
"/api/broker/zerodha/callback",
|
||||||
|
params={"request_token": "request-token"},
|
||||||
|
cookies={"session_id": "session-1"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 400
|
||||||
|
assert response.json() == {"detail": "Missing state"}
|
||||||
|
|
||||||
|
|
||||||
|
def test_wrong_or_expired_broker_callback_state_fails(monkeypatch):
|
||||||
|
import app.main as app_main
|
||||||
|
import app.routers.broker as broker_router
|
||||||
|
|
||||||
|
monkeypatch.setenv("APP_ENV", "test")
|
||||||
|
monkeypatch.setenv("DISABLE_STARTUP_TASKS", "1")
|
||||||
|
monkeypatch.setenv("CORS_ORIGINS", "http://localhost:3000")
|
||||||
|
importlib.reload(app_main)
|
||||||
|
app = app_main.create_app()
|
||||||
|
client = TestClient(app)
|
||||||
|
|
||||||
|
monkeypatch.setattr(broker_router, "get_user_for_session", lambda _sid: {"id": "user-1", "username": "user@example.com"})
|
||||||
|
monkeypatch.setattr(broker_router, "consume_broker_callback_state", lambda **kwargs: None)
|
||||||
|
|
||||||
|
response = client.get(
|
||||||
|
"/api/broker/zerodha/callback",
|
||||||
|
params={"request_token": "request-token", "state": "wrong-or-expired"},
|
||||||
|
cookies={"session_id": "session-1"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 401
|
||||||
|
assert response.json() == {"detail": "Invalid or expired broker callback state"}
|
||||||
|
|
||||||
|
|
||||||
|
def test_broker_callback_state_service_rejects_wrong_user_and_expired_state(monkeypatch):
|
||||||
|
import app.services.broker_callback_state as callback_state
|
||||||
|
|
||||||
|
now = datetime(2026, 4, 8, 9, 0, tzinfo=timezone.utc)
|
||||||
|
rows = [
|
||||||
|
{
|
||||||
|
"state_hash": callback_state._state_hash("valid-state"),
|
||||||
|
"user_id": "user-1",
|
||||||
|
"session_id": "session-1",
|
||||||
|
"broker": "ZERODHA",
|
||||||
|
"flow": "connect",
|
||||||
|
"expires_at": now + timedelta(minutes=5),
|
||||||
|
"consumed_at": None,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"state_hash": callback_state._state_hash("expired-state"),
|
||||||
|
"user_id": "user-1",
|
||||||
|
"session_id": "session-1",
|
||||||
|
"broker": "ZERODHA",
|
||||||
|
"flow": "connect",
|
||||||
|
"expires_at": now - timedelta(seconds=1),
|
||||||
|
"consumed_at": None,
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
class FakeCursor:
|
||||||
|
def __init__(self):
|
||||||
|
self._result = None
|
||||||
|
|
||||||
|
def __enter__(self):
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __exit__(self, exc_type, exc, tb):
|
||||||
|
return False
|
||||||
|
|
||||||
|
def execute(self, sql, params):
|
||||||
|
if "UPDATE broker_callback_state" not in sql:
|
||||||
|
raise AssertionError("Unexpected SQL")
|
||||||
|
consumed_at, state_hash, user_id, session_id, broker, flow, now_param = params
|
||||||
|
self._result = None
|
||||||
|
for row in rows:
|
||||||
|
if (
|
||||||
|
row["state_hash"] == state_hash
|
||||||
|
and row["user_id"] == user_id
|
||||||
|
and row["session_id"] == session_id
|
||||||
|
and row["broker"] == broker
|
||||||
|
and row["flow"] == flow
|
||||||
|
and row["consumed_at"] is None
|
||||||
|
and row["expires_at"] > now_param
|
||||||
|
):
|
||||||
|
row["consumed_at"] = consumed_at
|
||||||
|
self._result = ("row-id", row["expires_at"])
|
||||||
|
break
|
||||||
|
|
||||||
|
def fetchone(self):
|
||||||
|
return self._result
|
||||||
|
|
||||||
|
class FakeConnection:
|
||||||
|
def __enter__(self):
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __exit__(self, exc_type, exc, tb):
|
||||||
|
return False
|
||||||
|
|
||||||
|
def cursor(self):
|
||||||
|
return FakeCursor()
|
||||||
|
|
||||||
|
monkeypatch.setattr(callback_state, "_now_utc", lambda: now)
|
||||||
|
monkeypatch.setattr(callback_state, "db_connection", lambda: FakeConnection())
|
||||||
|
|
||||||
|
assert callback_state.consume_broker_callback_state(
|
||||||
|
state="valid-state",
|
||||||
|
user_id="user-2",
|
||||||
|
session_id="session-1",
|
||||||
|
broker="ZERODHA",
|
||||||
|
flow="connect",
|
||||||
|
) is None
|
||||||
|
assert callback_state.consume_broker_callback_state(
|
||||||
|
state="expired-state",
|
||||||
|
user_id="user-1",
|
||||||
|
session_id="session-1",
|
||||||
|
broker="ZERODHA",
|
||||||
|
flow="connect",
|
||||||
|
) is None
|
||||||
|
assert callback_state.consume_broker_callback_state(
|
||||||
|
state="valid-state",
|
||||||
|
user_id="user-1",
|
||||||
|
session_id="session-1",
|
||||||
|
broker="ZERODHA",
|
||||||
|
flow="connect",
|
||||||
|
) == {
|
||||||
|
"id": "row-id",
|
||||||
|
"expires_at": (now + timedelta(minutes=5)).isoformat(),
|
||||||
|
}
|
||||||
131
backend/tests/test_support_throttling.py
Normal file
131
backend/tests/test_support_throttling.py
Normal file
@ -0,0 +1,131 @@
|
|||||||
|
import importlib
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
|
||||||
|
|
||||||
|
def _build_app(monkeypatch):
|
||||||
|
monkeypatch.setenv("APP_ENV", "test")
|
||||||
|
monkeypatch.setenv("DISABLE_STARTUP_TASKS", "1")
|
||||||
|
monkeypatch.setenv("DB_HOST", "localhost")
|
||||||
|
monkeypatch.setenv("DB_NAME", "trading_db")
|
||||||
|
monkeypatch.setenv("DB_USER", "trader")
|
||||||
|
monkeypatch.setenv("DB_PASSWORD", "test-password")
|
||||||
|
monkeypatch.setenv("CORS_ORIGINS", "http://localhost:3000")
|
||||||
|
monkeypatch.setenv("SUPPORT_GUARD_BACKEND", "memory")
|
||||||
|
monkeypatch.setenv("SUPPORT_GUARD_WINDOW_SECONDS", "900")
|
||||||
|
monkeypatch.setenv("SUPPORT_CREATE_LIMIT", "2")
|
||||||
|
monkeypatch.setenv("SUPPORT_STATUS_LIMIT", "3")
|
||||||
|
monkeypatch.setenv("SUPPORT_STATUS_TICKET_LIMIT", "2")
|
||||||
|
|
||||||
|
import app.main as app_main
|
||||||
|
|
||||||
|
importlib.reload(app_main)
|
||||||
|
return app_main.create_app()
|
||||||
|
|
||||||
|
|
||||||
|
def test_support_ticket_creation_is_throttled(monkeypatch, caplog):
|
||||||
|
app = _build_app(monkeypatch)
|
||||||
|
client = TestClient(app)
|
||||||
|
|
||||||
|
import app.services.support_abuse as support_abuse
|
||||||
|
import app.routers.support_ticket as support_router
|
||||||
|
|
||||||
|
support_abuse.reset_memory_support_guard_state()
|
||||||
|
monkeypatch.setattr(
|
||||||
|
support_router,
|
||||||
|
"create_ticket",
|
||||||
|
lambda **kwargs: {"ticket_id": "ticket-1", "status": "NEW", "created_at": "2026-04-08T00:00:00+00:00"},
|
||||||
|
)
|
||||||
|
|
||||||
|
payload = {
|
||||||
|
"name": "Trader",
|
||||||
|
"email": "trader@example.com",
|
||||||
|
"subject": "Need help",
|
||||||
|
"message": "Something happened",
|
||||||
|
}
|
||||||
|
|
||||||
|
with caplog.at_level(logging.WARNING):
|
||||||
|
first = client.post("/api/support/ticket", json=payload)
|
||||||
|
second = client.post("/api/support/ticket", json=payload)
|
||||||
|
third = client.post("/api/support/ticket", json=payload)
|
||||||
|
|
||||||
|
assert first.status_code == 200
|
||||||
|
assert second.status_code == 200
|
||||||
|
assert third.status_code == 429
|
||||||
|
assert "Support request blocked" in caplog.text
|
||||||
|
|
||||||
|
|
||||||
|
def test_invalid_ticket_probing_is_throttled(monkeypatch):
|
||||||
|
app = _build_app(monkeypatch)
|
||||||
|
client = TestClient(app)
|
||||||
|
|
||||||
|
import app.services.support_abuse as support_abuse
|
||||||
|
import app.routers.support_ticket as support_router
|
||||||
|
|
||||||
|
support_abuse.reset_memory_support_guard_state()
|
||||||
|
monkeypatch.setattr(support_router, "get_ticket_status", lambda ticket_id, email: None)
|
||||||
|
|
||||||
|
payload = {"email": "trader@example.com"}
|
||||||
|
|
||||||
|
first = client.post("/api/support/ticket/status/unknown-ticket", json=payload)
|
||||||
|
second = client.post("/api/support/ticket/status/unknown-ticket", json=payload)
|
||||||
|
third = client.post("/api/support/ticket/status/unknown-ticket", json=payload)
|
||||||
|
|
||||||
|
assert first.status_code == 404
|
||||||
|
assert second.status_code == 404
|
||||||
|
assert third.status_code == 429
|
||||||
|
|
||||||
|
|
||||||
|
def test_legitimate_status_lookup_still_works(monkeypatch):
|
||||||
|
app = _build_app(monkeypatch)
|
||||||
|
client = TestClient(app)
|
||||||
|
|
||||||
|
import app.services.support_abuse as support_abuse
|
||||||
|
import app.routers.support_ticket as support_router
|
||||||
|
|
||||||
|
support_abuse.reset_memory_support_guard_state()
|
||||||
|
monkeypatch.setattr(
|
||||||
|
support_router,
|
||||||
|
"get_ticket_status",
|
||||||
|
lambda ticket_id, email: {
|
||||||
|
"ticket_id": ticket_id,
|
||||||
|
"status": "NEW",
|
||||||
|
"created_at": "2026-04-08T00:00:00+00:00",
|
||||||
|
"updated_at": "2026-04-08T00:00:00+00:00",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
response = client.post("/api/support/ticket/status/ticket-1", json={"email": "trader@example.com"})
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json()["ticket_id"] == "ticket-1"
|
||||||
|
|
||||||
|
|
||||||
|
def test_support_captcha_hook_blocks_without_matching_header(monkeypatch):
|
||||||
|
monkeypatch.setenv("SUPPORT_CAPTCHA_SECRET", "expected-captcha")
|
||||||
|
app = _build_app(monkeypatch)
|
||||||
|
client = TestClient(app)
|
||||||
|
|
||||||
|
import app.services.support_abuse as support_abuse
|
||||||
|
import app.routers.support_ticket as support_router
|
||||||
|
|
||||||
|
support_abuse.reset_memory_support_guard_state()
|
||||||
|
monkeypatch.setattr(
|
||||||
|
support_router,
|
||||||
|
"create_ticket",
|
||||||
|
lambda **kwargs: {"ticket_id": "ticket-1", "status": "NEW", "created_at": "2026-04-08T00:00:00+00:00"},
|
||||||
|
)
|
||||||
|
|
||||||
|
response = client.post(
|
||||||
|
"/api/support/ticket",
|
||||||
|
json={
|
||||||
|
"name": "Trader",
|
||||||
|
"email": "trader@example.com",
|
||||||
|
"subject": "Need help",
|
||||||
|
"message": "Something happened",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 403
|
||||||
|
assert response.json() == {"detail": "Support verification failed"}
|
||||||
@ -2,7 +2,7 @@ import os
|
|||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timedelta, timezone
|
||||||
from contextvars import ContextVar
|
from contextvars import ContextVar
|
||||||
|
|
||||||
import psycopg2
|
import psycopg2
|
||||||
@ -14,24 +14,35 @@ _POOL = None
|
|||||||
_POOL_LOCK = threading.Lock()
|
_POOL_LOCK = threading.Lock()
|
||||||
_DEFAULT_USER_ID = None
|
_DEFAULT_USER_ID = None
|
||||||
_DEFAULT_LOCK = threading.Lock()
|
_DEFAULT_LOCK = threading.Lock()
|
||||||
|
NON_PROD_ENVIRONMENTS = {"development", "dev", "test", "testing", "local"}
|
||||||
|
|
||||||
_USER_ID = ContextVar("engine_user_id", default=None)
|
_USER_ID = ContextVar("engine_user_id", default=None)
|
||||||
_RUN_ID = ContextVar("engine_run_id", default=None)
|
_RUN_ID = ContextVar("engine_run_id", default=None)
|
||||||
|
|
||||||
|
|
||||||
def _db_config():
|
def _db_config():
|
||||||
|
env_name = (os.getenv("APP_ENV") or os.getenv("ENVIRONMENT") or os.getenv("FASTAPI_ENV") or "development").strip().lower()
|
||||||
|
is_non_prod = env_name in NON_PROD_ENVIRONMENTS
|
||||||
url = os.getenv("DATABASE_URL")
|
url = os.getenv("DATABASE_URL")
|
||||||
if url:
|
if url:
|
||||||
return {"dsn": url}
|
return {"dsn": url}
|
||||||
|
|
||||||
schema = os.getenv("DB_SCHEMA") or os.getenv("PGSCHEMA") or "quant_app"
|
schema = os.getenv("DB_SCHEMA") or os.getenv("PGSCHEMA") or "quant_app"
|
||||||
|
password = os.getenv("DB_PASSWORD") or os.getenv("PGPASSWORD")
|
||||||
|
host = os.getenv("DB_HOST") or os.getenv("PGHOST") or ("localhost" if is_non_prod else None)
|
||||||
|
dbname = os.getenv("DB_NAME") or os.getenv("PGDATABASE") or ("trading_db" if is_non_prod else None)
|
||||||
|
user = os.getenv("DB_USER") or os.getenv("PGUSER") or ("trader" if is_non_prod else None)
|
||||||
|
if not is_non_prod and not password:
|
||||||
|
raise RuntimeError("DB_PASSWORD or PGPASSWORD must be configured in non-development environments")
|
||||||
|
if not is_non_prod and (not host or not dbname or not user):
|
||||||
|
raise RuntimeError("DB_HOST, DB_NAME, and DB_USER must be configured in non-development environments")
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"host": os.getenv("DB_HOST") or os.getenv("PGHOST") or "localhost",
|
"host": host,
|
||||||
"port": int(os.getenv("DB_PORT") or os.getenv("PGPORT") or "5432"),
|
"port": int(os.getenv("DB_PORT") or os.getenv("PGPORT") or "5432"),
|
||||||
"dbname": os.getenv("DB_NAME") or os.getenv("PGDATABASE") or "trading_db",
|
"dbname": dbname,
|
||||||
"user": os.getenv("DB_USER") or os.getenv("PGUSER") or "trader",
|
"user": user,
|
||||||
"password": os.getenv("DB_PASSWORD") or os.getenv("PGPASSWORD") or "traderpass",
|
"password": password,
|
||||||
"connect_timeout": int(os.getenv("DB_CONNECT_TIMEOUT", "5")),
|
"connect_timeout": int(os.getenv("DB_CONNECT_TIMEOUT", "5")),
|
||||||
"options": f"-csearch_path={schema},public" if schema else None,
|
"options": f"-csearch_path={schema},public" if schema else None,
|
||||||
}
|
}
|
||||||
@ -322,3 +333,150 @@ def insert_engine_event(
|
|||||||
Json(meta) if meta is not None else None,
|
Json(meta) if meta is not None else None,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def acquire_run_lease(
|
||||||
|
run_id: str,
|
||||||
|
owner_id: str,
|
||||||
|
*,
|
||||||
|
lease_seconds: int = 90,
|
||||||
|
now: datetime | None = None,
|
||||||
|
):
|
||||||
|
current_time = now or _utc_now()
|
||||||
|
expires_at = current_time + timedelta(seconds=lease_seconds)
|
||||||
|
|
||||||
|
def _op(cur, _conn):
|
||||||
|
cur.execute(
|
||||||
|
"""
|
||||||
|
INSERT INTO run_leases (run_id, owner_id, leased_at, expires_at, heartbeat_at)
|
||||||
|
VALUES (%s, %s, %s, %s, %s)
|
||||||
|
ON CONFLICT (run_id) DO NOTHING
|
||||||
|
RETURNING run_id
|
||||||
|
""",
|
||||||
|
(run_id, owner_id, current_time, expires_at, current_time),
|
||||||
|
)
|
||||||
|
inserted = cur.fetchone()
|
||||||
|
if inserted:
|
||||||
|
return {
|
||||||
|
"acquired": True,
|
||||||
|
"status": "ACQUIRED",
|
||||||
|
"owner_id": owner_id,
|
||||||
|
"expires_at": expires_at,
|
||||||
|
}
|
||||||
|
|
||||||
|
cur.execute(
|
||||||
|
"""
|
||||||
|
SELECT owner_id, expires_at
|
||||||
|
FROM run_leases
|
||||||
|
WHERE run_id = %s
|
||||||
|
FOR UPDATE
|
||||||
|
""",
|
||||||
|
(run_id,),
|
||||||
|
)
|
||||||
|
row = cur.fetchone()
|
||||||
|
if not row:
|
||||||
|
return {
|
||||||
|
"acquired": False,
|
||||||
|
"status": "DENIED",
|
||||||
|
"owner_id": None,
|
||||||
|
"expires_at": None,
|
||||||
|
}
|
||||||
|
|
||||||
|
current_owner, current_expiry = row
|
||||||
|
if current_owner == owner_id:
|
||||||
|
cur.execute(
|
||||||
|
"""
|
||||||
|
UPDATE run_leases
|
||||||
|
SET leased_at = %s,
|
||||||
|
expires_at = %s,
|
||||||
|
heartbeat_at = %s
|
||||||
|
WHERE run_id = %s AND owner_id = %s
|
||||||
|
RETURNING run_id
|
||||||
|
""",
|
||||||
|
(current_time, expires_at, current_time, run_id, owner_id),
|
||||||
|
)
|
||||||
|
cur.fetchone()
|
||||||
|
return {
|
||||||
|
"acquired": True,
|
||||||
|
"status": "REFRESHED",
|
||||||
|
"owner_id": owner_id,
|
||||||
|
"expires_at": expires_at,
|
||||||
|
}
|
||||||
|
|
||||||
|
if current_expiry <= current_time:
|
||||||
|
cur.execute(
|
||||||
|
"""
|
||||||
|
UPDATE run_leases
|
||||||
|
SET owner_id = %s,
|
||||||
|
leased_at = %s,
|
||||||
|
expires_at = %s,
|
||||||
|
heartbeat_at = %s
|
||||||
|
WHERE run_id = %s AND expires_at <= %s
|
||||||
|
RETURNING run_id
|
||||||
|
""",
|
||||||
|
(owner_id, current_time, expires_at, current_time, run_id, current_time),
|
||||||
|
)
|
||||||
|
replaced = cur.fetchone()
|
||||||
|
if replaced:
|
||||||
|
return {
|
||||||
|
"acquired": True,
|
||||||
|
"status": "REACQUIRED",
|
||||||
|
"owner_id": owner_id,
|
||||||
|
"previous_owner": current_owner,
|
||||||
|
"expires_at": expires_at,
|
||||||
|
}
|
||||||
|
|
||||||
|
return {
|
||||||
|
"acquired": False,
|
||||||
|
"status": "DENIED",
|
||||||
|
"owner_id": current_owner,
|
||||||
|
"expires_at": current_expiry,
|
||||||
|
}
|
||||||
|
|
||||||
|
return run_with_retry(_op)
|
||||||
|
|
||||||
|
|
||||||
|
def heartbeat_run_lease(
|
||||||
|
run_id: str,
|
||||||
|
owner_id: str,
|
||||||
|
*,
|
||||||
|
lease_seconds: int = 90,
|
||||||
|
now: datetime | None = None,
|
||||||
|
):
|
||||||
|
current_time = now or _utc_now()
|
||||||
|
expires_at = current_time + timedelta(seconds=lease_seconds)
|
||||||
|
|
||||||
|
def _op(cur, _conn):
|
||||||
|
cur.execute(
|
||||||
|
"""
|
||||||
|
UPDATE run_leases
|
||||||
|
SET heartbeat_at = %s,
|
||||||
|
expires_at = %s
|
||||||
|
WHERE run_id = %s
|
||||||
|
AND owner_id = %s
|
||||||
|
AND expires_at > %s
|
||||||
|
RETURNING run_id, expires_at
|
||||||
|
""",
|
||||||
|
(current_time, expires_at, run_id, owner_id, current_time),
|
||||||
|
)
|
||||||
|
row = cur.fetchone()
|
||||||
|
if not row:
|
||||||
|
return {"active": False, "expires_at": None}
|
||||||
|
return {"active": True, "expires_at": row[1]}
|
||||||
|
|
||||||
|
return run_with_retry(_op)
|
||||||
|
|
||||||
|
|
||||||
|
def release_run_lease(run_id: str, owner_id: str):
|
||||||
|
def _op(cur, _conn):
|
||||||
|
cur.execute(
|
||||||
|
"""
|
||||||
|
DELETE FROM run_leases
|
||||||
|
WHERE run_id = %s AND owner_id = %s
|
||||||
|
RETURNING run_id
|
||||||
|
""",
|
||||||
|
(run_id, owner_id),
|
||||||
|
)
|
||||||
|
return cur.fetchone() is not None
|
||||||
|
|
||||||
|
return run_with_retry(_op)
|
||||||
|
|||||||
@ -3,7 +3,7 @@ from datetime import datetime, timezone
|
|||||||
from indian_paper_trading_strategy.engine.state import load_state, save_state
|
from indian_paper_trading_strategy.engine.state import load_state, save_state
|
||||||
|
|
||||||
from indian_paper_trading_strategy.engine.broker import Broker, BrokerAuthExpired
|
from indian_paper_trading_strategy.engine.broker import Broker, BrokerAuthExpired
|
||||||
from indian_paper_trading_strategy.engine.ledger import log_event, event_exists
|
from indian_paper_trading_strategy.engine.ledger import claim_execution_window, log_event, event_exists
|
||||||
from indian_paper_trading_strategy.engine.db import insert_engine_event, run_with_retry
|
from indian_paper_trading_strategy.engine.db import insert_engine_event, run_with_retry
|
||||||
from indian_paper_trading_strategy.engine.market import market_now
|
from indian_paper_trading_strategy.engine.market import market_now
|
||||||
from indian_paper_trading_strategy.engine.time_utils import compute_logical_time
|
from indian_paper_trading_strategy.engine.time_utils import compute_logical_time
|
||||||
@ -237,7 +237,7 @@ def _prepare_live_execution(now_ts, sip_interval, sip_amount_val, sp_price_val,
|
|||||||
return {"ready": False, "state": state}
|
return {"ready": False, "state": state}
|
||||||
if event_exists("SIP_EXECUTED", logical_time, cur=cur):
|
if event_exists("SIP_EXECUTED", logical_time, cur=cur):
|
||||||
return {"ready": False, "state": state}
|
return {"ready": False, "state": state}
|
||||||
if event_exists("SIP_ORDER_ATTEMPTED", logical_time, cur=cur):
|
if not claim_execution_window(logical_time, mode=mode, cur=cur):
|
||||||
return {"ready": False, "state": state}
|
return {"ready": False, "state": state}
|
||||||
|
|
||||||
log_event(
|
log_event(
|
||||||
|
|||||||
@ -1,4 +1,5 @@
|
|||||||
# engine/ledger.py
|
# engine/ledger.py
|
||||||
|
import uuid
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
from indian_paper_trading_strategy.engine.db import insert_engine_event, run_with_retry, get_context
|
from indian_paper_trading_strategy.engine.db import insert_engine_event, run_with_retry, get_context
|
||||||
@ -30,6 +31,72 @@ def event_exists(event, logical_time, *, cur=None, user_id: str | None = None, r
|
|||||||
return run_with_retry(_op)
|
return run_with_retry(_op)
|
||||||
|
|
||||||
|
|
||||||
|
def _claim_execution_window_in_tx(
|
||||||
|
cur,
|
||||||
|
logical_time,
|
||||||
|
*,
|
||||||
|
mode: str | None = None,
|
||||||
|
user_id: str | None = None,
|
||||||
|
run_id: str | None = None,
|
||||||
|
):
|
||||||
|
scope_user, scope_run = get_context(user_id, run_id)
|
||||||
|
logical_ts = normalize_logical_time(logical_time)
|
||||||
|
claim_id = str(uuid.uuid4())
|
||||||
|
cur.execute(
|
||||||
|
"""
|
||||||
|
INSERT INTO execution_claim (
|
||||||
|
id,
|
||||||
|
user_id,
|
||||||
|
run_id,
|
||||||
|
mode,
|
||||||
|
logical_time,
|
||||||
|
claimed_at
|
||||||
|
)
|
||||||
|
VALUES (%s, %s, %s, %s, %s, %s)
|
||||||
|
ON CONFLICT (user_id, run_id, logical_time) DO NOTHING
|
||||||
|
RETURNING id
|
||||||
|
""",
|
||||||
|
(
|
||||||
|
claim_id,
|
||||||
|
scope_user,
|
||||||
|
scope_run,
|
||||||
|
(mode or "LIVE").strip().upper(),
|
||||||
|
logical_ts,
|
||||||
|
datetime.now(timezone.utc),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
return cur.fetchone() is not None
|
||||||
|
|
||||||
|
|
||||||
|
def claim_execution_window(
|
||||||
|
logical_time,
|
||||||
|
*,
|
||||||
|
mode: str | None = None,
|
||||||
|
cur=None,
|
||||||
|
user_id: str | None = None,
|
||||||
|
run_id: str | None = None,
|
||||||
|
):
|
||||||
|
if cur is not None:
|
||||||
|
return _claim_execution_window_in_tx(
|
||||||
|
cur,
|
||||||
|
logical_time,
|
||||||
|
mode=mode,
|
||||||
|
user_id=user_id,
|
||||||
|
run_id=run_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _op(cur, _conn):
|
||||||
|
return _claim_execution_window_in_tx(
|
||||||
|
cur,
|
||||||
|
logical_time,
|
||||||
|
mode=mode,
|
||||||
|
user_id=user_id,
|
||||||
|
run_id=run_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
return run_with_retry(_op)
|
||||||
|
|
||||||
|
|
||||||
def _log_event_in_tx(
|
def _log_event_in_tx(
|
||||||
cur,
|
cur,
|
||||||
event,
|
event,
|
||||||
|
|||||||
@ -1,46 +1,51 @@
|
|||||||
# engine/market.py
|
from __future__ import annotations
|
||||||
from datetime import datetime, time as dtime, timedelta
|
|
||||||
import pytz
|
|
||||||
|
|
||||||
_MARKET_TZ = pytz.timezone("Asia/Kolkata")
|
from datetime import datetime
|
||||||
_OPEN_T = dtime(9, 15)
|
|
||||||
_CLOSE_T = dtime(15, 30)
|
from indian_paper_trading_strategy.engine.market_calendar import (
|
||||||
|
MARKET_TZ,
|
||||||
|
get_market_session as _get_market_session,
|
||||||
|
get_market_status as _get_market_status,
|
||||||
|
is_market_open as _is_market_open,
|
||||||
|
market_now_utc,
|
||||||
|
next_market_open as _next_market_open,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def market_now() -> datetime:
|
def market_now() -> datetime:
|
||||||
return datetime.now(_MARKET_TZ)
|
return market_now_utc().astimezone(MARKET_TZ)
|
||||||
|
|
||||||
def _as_market_tz(value: datetime) -> datetime:
|
|
||||||
|
def _to_utc(value: datetime) -> datetime:
|
||||||
if value.tzinfo is None:
|
if value.tzinfo is None:
|
||||||
return _MARKET_TZ.localize(value)
|
return value.replace(tzinfo=MARKET_TZ).astimezone(market_now_utc().tzinfo)
|
||||||
return value.astimezone(_MARKET_TZ)
|
return value.astimezone(market_now_utc().tzinfo)
|
||||||
|
|
||||||
|
|
||||||
def is_market_open(now: datetime) -> bool:
|
def is_market_open(now: datetime) -> bool:
|
||||||
now = _as_market_tz(now)
|
return _is_market_open(_to_utc(now))
|
||||||
return now.weekday() < 5 and _OPEN_T <= now.time() <= _CLOSE_T
|
|
||||||
|
|
||||||
|
def market_session(now: datetime) -> dict[str, object]:
|
||||||
|
return _get_market_session(_to_utc(now))
|
||||||
|
|
||||||
|
|
||||||
|
def market_status(now: datetime) -> str:
|
||||||
|
return _get_market_status(_to_utc(now))
|
||||||
|
|
||||||
|
|
||||||
def india_market_status():
|
def india_market_status():
|
||||||
now = market_now()
|
now = market_now()
|
||||||
|
|
||||||
return is_market_open(now), now
|
return is_market_open(now), now
|
||||||
|
|
||||||
|
|
||||||
def next_market_open_after(value: datetime) -> datetime:
|
def next_market_open_after(value: datetime) -> datetime:
|
||||||
current = _as_market_tz(value)
|
aligned_utc = _next_market_open(_to_utc(value))
|
||||||
while current.weekday() >= 5:
|
aligned_ist = aligned_utc.astimezone(MARKET_TZ)
|
||||||
current = current + timedelta(days=1)
|
if value.tzinfo is None:
|
||||||
current = current.replace(hour=_OPEN_T.hour, minute=_OPEN_T.minute, second=0, microsecond=0)
|
return aligned_ist.replace(tzinfo=None)
|
||||||
if current.time() < _OPEN_T:
|
return aligned_ist
|
||||||
return current.replace(hour=_OPEN_T.hour, minute=_OPEN_T.minute, second=0, microsecond=0)
|
|
||||||
if current.time() > _CLOSE_T:
|
|
||||||
current = current + timedelta(days=1)
|
|
||||||
while current.weekday() >= 5:
|
|
||||||
current = current + timedelta(days=1)
|
|
||||||
return current.replace(hour=_OPEN_T.hour, minute=_OPEN_T.minute, second=0, microsecond=0)
|
|
||||||
return current
|
|
||||||
|
|
||||||
def align_to_market_open(value: datetime) -> datetime:
|
def align_to_market_open(value: datetime) -> datetime:
|
||||||
current = _as_market_tz(value)
|
return next_market_open_after(value)
|
||||||
aligned = current if is_market_open(current) else next_market_open_after(current)
|
|
||||||
if value.tzinfo is None:
|
|
||||||
return aligned.replace(tzinfo=None)
|
|
||||||
return aligned
|
|
||||||
|
|||||||
201
indian_paper_trading_strategy/engine/market_calendar.py
Normal file
201
indian_paper_trading_strategy/engine/market_calendar.py
Normal file
@ -0,0 +1,201 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from datetime import date, datetime, time as dtime, timedelta, timezone
|
||||||
|
from functools import lru_cache
|
||||||
|
from zoneinfo import ZoneInfo
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
UTC = timezone.utc
|
||||||
|
MARKET_TZ = ZoneInfo("Asia/Kolkata")
|
||||||
|
MARKET_OPEN = dtime(9, 15)
|
||||||
|
MARKET_CLOSE = dtime(15, 30)
|
||||||
|
|
||||||
|
STATUS_OPEN = "OPEN"
|
||||||
|
STATUS_CLOSED = "CLOSED"
|
||||||
|
STATUS_HOLIDAY = "HOLIDAY"
|
||||||
|
|
||||||
|
REASON_OPEN = "OPEN"
|
||||||
|
REASON_HOLIDAY = "HOLIDAY"
|
||||||
|
REASON_WEEKEND = "WEEKEND"
|
||||||
|
REASON_PRE_OPEN = "PRE_OPEN"
|
||||||
|
REASON_POST_CLOSE = "POST_CLOSE"
|
||||||
|
REASON_CALENDAR_UNAVAILABLE = "CALENDAR_UNAVAILABLE"
|
||||||
|
|
||||||
|
# Capital market trading holidays for NSE calendar year 2026.
|
||||||
|
# Source: NSE circular dated December 12, 2025.
|
||||||
|
DEFAULT_NSE_HOLIDAYS_BY_YEAR = {
|
||||||
|
2026: frozenset(
|
||||||
|
{
|
||||||
|
date(2026, 1, 26), # Republic Day
|
||||||
|
date(2026, 3, 3), # Holi
|
||||||
|
date(2026, 3, 26), # Shri Ram Navami
|
||||||
|
date(2026, 3, 31), # Shri Mahavir Jayanti
|
||||||
|
date(2026, 4, 3), # Good Friday
|
||||||
|
date(2026, 4, 14), # Dr. Baba Saheb Ambedkar Jayanti
|
||||||
|
date(2026, 5, 1), # Maharashtra Day
|
||||||
|
date(2026, 5, 28), # Bakri Id
|
||||||
|
date(2026, 6, 26), # Muharram
|
||||||
|
date(2026, 9, 14), # Ganesh Chaturthi
|
||||||
|
date(2026, 10, 2), # Mahatma Gandhi Jayanti
|
||||||
|
date(2026, 10, 20), # Dussehra
|
||||||
|
date(2026, 11, 10), # Diwali-Balipratipada
|
||||||
|
date(2026, 11, 24), # Prakash Gurpurb Sri Guru Nanak Dev
|
||||||
|
date(2026, 12, 25), # Christmas
|
||||||
|
}
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class UnsupportedCalendarYearError(RuntimeError):
|
||||||
|
def __init__(self, year: int):
|
||||||
|
super().__init__(f"NSE holiday calendar is not configured for {year}")
|
||||||
|
self.year = year
|
||||||
|
|
||||||
|
|
||||||
|
def market_now_utc() -> datetime:
|
||||||
|
return datetime.now(UTC)
|
||||||
|
|
||||||
|
|
||||||
|
def _as_utc(value: datetime | None) -> datetime:
|
||||||
|
if value is None:
|
||||||
|
return market_now_utc()
|
||||||
|
if value.tzinfo is None:
|
||||||
|
return value.replace(tzinfo=UTC)
|
||||||
|
return value.astimezone(UTC)
|
||||||
|
|
||||||
|
|
||||||
|
def _as_ist(value: datetime | None) -> datetime:
|
||||||
|
return _as_utc(value).astimezone(MARKET_TZ)
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_holiday_token(token: str) -> date:
|
||||||
|
text = token.strip()
|
||||||
|
if not text:
|
||||||
|
raise ValueError("Empty holiday token is not allowed")
|
||||||
|
return date.fromisoformat(text)
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache(maxsize=16)
|
||||||
|
def get_nse_holidays(year: int) -> frozenset[date]:
|
||||||
|
configured = (os.getenv(f"NSE_HOLIDAYS_{year}") or "").strip()
|
||||||
|
if configured:
|
||||||
|
return frozenset(_parse_holiday_token(token) for token in configured.split(","))
|
||||||
|
|
||||||
|
if year in DEFAULT_NSE_HOLIDAYS_BY_YEAR:
|
||||||
|
return DEFAULT_NSE_HOLIDAYS_BY_YEAR[year]
|
||||||
|
|
||||||
|
raise UnsupportedCalendarYearError(year)
|
||||||
|
|
||||||
|
|
||||||
|
def is_trading_holiday(now_utc: datetime | None = None) -> bool:
|
||||||
|
current_ist = _as_ist(now_utc)
|
||||||
|
holidays = get_nse_holidays(current_ist.year)
|
||||||
|
return current_ist.date() in holidays
|
||||||
|
|
||||||
|
|
||||||
|
def _session_from_utc(now_utc: datetime | None = None) -> dict[str, object]:
|
||||||
|
current_utc = _as_utc(now_utc)
|
||||||
|
current_ist = current_utc.astimezone(MARKET_TZ)
|
||||||
|
try:
|
||||||
|
holidays = get_nse_holidays(current_ist.year)
|
||||||
|
except UnsupportedCalendarYearError as exc:
|
||||||
|
logger.error("NSE holiday calendar unavailable for year %s", exc.year)
|
||||||
|
return {
|
||||||
|
"status": STATUS_CLOSED,
|
||||||
|
"reason": REASON_CALENDAR_UNAVAILABLE,
|
||||||
|
"checked_at": current_utc,
|
||||||
|
"checked_at_ist": current_ist,
|
||||||
|
"calendar_supported": False,
|
||||||
|
}
|
||||||
|
|
||||||
|
current_time = current_ist.timetz().replace(tzinfo=None)
|
||||||
|
if current_ist.date() in holidays:
|
||||||
|
return {
|
||||||
|
"status": STATUS_HOLIDAY,
|
||||||
|
"reason": REASON_HOLIDAY,
|
||||||
|
"checked_at": current_utc,
|
||||||
|
"checked_at_ist": current_ist,
|
||||||
|
"calendar_supported": True,
|
||||||
|
}
|
||||||
|
if current_ist.weekday() >= 5:
|
||||||
|
return {
|
||||||
|
"status": STATUS_CLOSED,
|
||||||
|
"reason": REASON_WEEKEND,
|
||||||
|
"checked_at": current_utc,
|
||||||
|
"checked_at_ist": current_ist,
|
||||||
|
"calendar_supported": True,
|
||||||
|
}
|
||||||
|
if current_time < MARKET_OPEN:
|
||||||
|
return {
|
||||||
|
"status": STATUS_CLOSED,
|
||||||
|
"reason": REASON_PRE_OPEN,
|
||||||
|
"checked_at": current_utc,
|
||||||
|
"checked_at_ist": current_ist,
|
||||||
|
"calendar_supported": True,
|
||||||
|
}
|
||||||
|
if MARKET_OPEN <= current_time < MARKET_CLOSE:
|
||||||
|
return {
|
||||||
|
"status": STATUS_OPEN,
|
||||||
|
"reason": REASON_OPEN,
|
||||||
|
"checked_at": current_utc,
|
||||||
|
"checked_at_ist": current_ist,
|
||||||
|
"calendar_supported": True,
|
||||||
|
}
|
||||||
|
return {
|
||||||
|
"status": STATUS_CLOSED,
|
||||||
|
"reason": REASON_POST_CLOSE,
|
||||||
|
"checked_at": current_utc,
|
||||||
|
"checked_at_ist": current_ist,
|
||||||
|
"calendar_supported": True,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def get_market_session(now_utc: datetime | None = None) -> dict[str, object]:
|
||||||
|
return dict(_session_from_utc(now_utc))
|
||||||
|
|
||||||
|
|
||||||
|
def get_market_status(now_utc: datetime | None = None) -> str:
|
||||||
|
return str(_session_from_utc(now_utc)["status"])
|
||||||
|
|
||||||
|
|
||||||
|
def is_market_open(now_utc: datetime | None = None) -> bool:
|
||||||
|
return get_market_status(now_utc) == STATUS_OPEN
|
||||||
|
|
||||||
|
|
||||||
|
def next_market_open(now_utc: datetime | None = None) -> datetime:
|
||||||
|
current_utc = _as_utc(now_utc)
|
||||||
|
session = _session_from_utc(current_utc)
|
||||||
|
if session["reason"] == REASON_CALENDAR_UNAVAILABLE:
|
||||||
|
checked_at_ist = session["checked_at_ist"]
|
||||||
|
raise UnsupportedCalendarYearError(int(checked_at_ist.year))
|
||||||
|
|
||||||
|
if session["status"] == STATUS_OPEN:
|
||||||
|
return current_utc
|
||||||
|
|
||||||
|
current_ist = current_utc.astimezone(MARKET_TZ)
|
||||||
|
candidate_date = current_ist.date()
|
||||||
|
holidays = get_nse_holidays(candidate_date.year)
|
||||||
|
current_time = current_ist.timetz().replace(tzinfo=None)
|
||||||
|
|
||||||
|
if (
|
||||||
|
session["reason"] == REASON_PRE_OPEN
|
||||||
|
and candidate_date not in holidays
|
||||||
|
and candidate_date.weekday() < 5
|
||||||
|
and current_time < MARKET_OPEN
|
||||||
|
):
|
||||||
|
candidate = datetime.combine(candidate_date, MARKET_OPEN, tzinfo=MARKET_TZ)
|
||||||
|
return candidate.astimezone(UTC)
|
||||||
|
|
||||||
|
candidate_date = candidate_date + timedelta(days=1)
|
||||||
|
while True:
|
||||||
|
holidays = get_nse_holidays(candidate_date.year)
|
||||||
|
if candidate_date.weekday() < 5 and candidate_date not in holidays:
|
||||||
|
break
|
||||||
|
candidate_date = candidate_date + timedelta(days=1)
|
||||||
|
|
||||||
|
candidate = datetime.combine(candidate_date, MARKET_OPEN, tzinfo=MARKET_TZ)
|
||||||
|
return candidate.astimezone(UTC)
|
||||||
@ -1,11 +1,17 @@
|
|||||||
import os
|
import os
|
||||||
|
import socket
|
||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
|
import uuid
|
||||||
from datetime import datetime, timedelta, timezone
|
from datetime import datetime, timedelta, timezone
|
||||||
|
|
||||||
from psycopg2.extras import Json
|
from psycopg2.extras import Json
|
||||||
|
|
||||||
from indian_paper_trading_strategy.engine.market import is_market_open, align_to_market_open, market_now
|
from indian_paper_trading_strategy.engine.market import (
|
||||||
|
align_to_market_open,
|
||||||
|
market_now,
|
||||||
|
market_session,
|
||||||
|
)
|
||||||
from indian_paper_trading_strategy.engine.execution import try_execute_sip
|
from indian_paper_trading_strategy.engine.execution import try_execute_sip
|
||||||
from indian_paper_trading_strategy.engine.broker import (
|
from indian_paper_trading_strategy.engine.broker import (
|
||||||
BrokerAuthExpired,
|
BrokerAuthExpired,
|
||||||
@ -21,7 +27,16 @@ from indian_paper_trading_strategy.engine.strategy import allocation
|
|||||||
from indian_paper_trading_strategy.engine.time_utils import normalize_logical_time
|
from indian_paper_trading_strategy.engine.time_utils import normalize_logical_time
|
||||||
from app.services.zerodha_service import KiteTokenError
|
from app.services.zerodha_service import KiteTokenError
|
||||||
|
|
||||||
from indian_paper_trading_strategy.engine.db import db_transaction, insert_engine_event, run_with_retry, get_context, set_context
|
from indian_paper_trading_strategy.engine.db import (
|
||||||
|
acquire_run_lease,
|
||||||
|
db_transaction,
|
||||||
|
heartbeat_run_lease,
|
||||||
|
insert_engine_event,
|
||||||
|
release_run_lease,
|
||||||
|
run_with_retry,
|
||||||
|
get_context,
|
||||||
|
set_context,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _update_engine_status(user_id: str, run_id: str, status: str):
|
def _update_engine_status(user_id: str, run_id: str, status: str):
|
||||||
@ -60,6 +75,17 @@ _RUNNERS_LOCK = threading.Lock()
|
|||||||
|
|
||||||
engine_state = _ENGINE_STATES
|
engine_state = _ENGINE_STATES
|
||||||
|
|
||||||
|
RUNNER_OWNER_ID = os.getenv("RUNNER_OWNER_ID") or f"{socket.gethostname()}:{os.getpid()}:{uuid.uuid4().hex}"
|
||||||
|
RUN_LEASE_SECONDS = int(os.getenv("RUN_LEASE_SECONDS", "90"))
|
||||||
|
|
||||||
|
|
||||||
|
class RunLeaseNotAcquiredError(RuntimeError):
|
||||||
|
def __init__(self, run_id: str, owner_id: str, details: dict | None = None):
|
||||||
|
super().__init__(f"Run lease not acquired for run {run_id}")
|
||||||
|
self.run_id = run_id
|
||||||
|
self.owner_id = owner_id
|
||||||
|
self.details = details or {}
|
||||||
|
|
||||||
|
|
||||||
def _state_key(user_id: str, run_id: str):
|
def _state_key(user_id: str, run_id: str):
|
||||||
return (user_id, run_id)
|
return (user_id, run_id)
|
||||||
@ -124,19 +150,79 @@ def log_event(
|
|||||||
|
|
||||||
run_with_retry(_op)
|
run_with_retry(_op)
|
||||||
|
|
||||||
|
|
||||||
|
def _log_runner_lease_event(
|
||||||
|
user_id: str,
|
||||||
|
run_id: str,
|
||||||
|
event: str,
|
||||||
|
message: str,
|
||||||
|
meta: dict | None = None,
|
||||||
|
):
|
||||||
|
details = meta or {}
|
||||||
|
print(f"[ENGINE] {event} {message} {details}", flush=True)
|
||||||
|
|
||||||
|
def _op(cur, _conn):
|
||||||
|
insert_engine_event(
|
||||||
|
cur,
|
||||||
|
event,
|
||||||
|
data=details,
|
||||||
|
message=message,
|
||||||
|
ts=datetime.utcnow().replace(tzinfo=timezone.utc),
|
||||||
|
user_id=user_id,
|
||||||
|
run_id=run_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
run_with_retry(_op)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def _refresh_run_lease_or_stop(
|
||||||
|
user_id: str,
|
||||||
|
run_id: str,
|
||||||
|
owner_id: str,
|
||||||
|
):
|
||||||
|
lease = heartbeat_run_lease(
|
||||||
|
run_id,
|
||||||
|
owner_id,
|
||||||
|
lease_seconds=RUN_LEASE_SECONDS,
|
||||||
|
)
|
||||||
|
if lease.get("active"):
|
||||||
|
print(
|
||||||
|
f"[ENGINE] RUNNER_LEASE_HEARTBEAT lease heartbeat refreshed "
|
||||||
|
f"{{'run_id': '{run_id}', 'owner_id': '{owner_id}', 'expires_at': '{lease.get('expires_at')}'}}",
|
||||||
|
flush=True,
|
||||||
|
)
|
||||||
|
return True
|
||||||
|
|
||||||
|
_log_runner_lease_event(
|
||||||
|
user_id,
|
||||||
|
run_id,
|
||||||
|
"RUNNER_LEASE_LOST",
|
||||||
|
"Runner exiting due to lost lease",
|
||||||
|
{"owner_id": owner_id},
|
||||||
|
)
|
||||||
|
return False
|
||||||
|
|
||||||
def sleep_with_heartbeat(
|
def sleep_with_heartbeat(
|
||||||
total_seconds: int,
|
total_seconds: int,
|
||||||
stop_event: threading.Event,
|
stop_event: threading.Event,
|
||||||
user_id: str,
|
user_id: str,
|
||||||
run_id: str,
|
run_id: str,
|
||||||
|
owner_id: str,
|
||||||
step_seconds: int = 5,
|
step_seconds: int = 5,
|
||||||
):
|
):
|
||||||
remaining = total_seconds
|
remaining = total_seconds
|
||||||
while remaining > 0 and not stop_event.is_set():
|
while remaining > 0 and not stop_event.is_set():
|
||||||
time.sleep(min(step_seconds, remaining))
|
chunk = min(step_seconds, remaining)
|
||||||
|
time.sleep(chunk)
|
||||||
_set_state(user_id, run_id, last_heartbeat_ts=datetime.utcnow().isoformat() + "Z")
|
_set_state(user_id, run_id, last_heartbeat_ts=datetime.utcnow().isoformat() + "Z")
|
||||||
_update_engine_status(user_id, run_id, "RUNNING")
|
_update_engine_status(user_id, run_id, "RUNNING")
|
||||||
remaining -= step_seconds
|
if not _refresh_run_lease_or_stop(user_id, run_id, owner_id):
|
||||||
|
return False
|
||||||
|
remaining -= chunk
|
||||||
|
return True
|
||||||
|
|
||||||
def _clear_runner(user_id: str, run_id: str):
|
def _clear_runner(user_id: str, run_id: str):
|
||||||
key = _state_key(user_id, run_id)
|
key = _state_key(user_id, run_id)
|
||||||
@ -144,7 +230,20 @@ def _clear_runner(user_id: str, run_id: str):
|
|||||||
_RUNNERS.pop(key, None)
|
_RUNNERS.pop(key, None)
|
||||||
|
|
||||||
def can_execute(now: datetime) -> tuple[bool, str]:
|
def can_execute(now: datetime) -> tuple[bool, str]:
|
||||||
if not is_market_open(now):
|
session = market_session(now)
|
||||||
|
status = str(session.get("status") or "CLOSED").upper()
|
||||||
|
reason = str(session.get("reason") or "").upper()
|
||||||
|
if status == "HOLIDAY":
|
||||||
|
return False, "MARKET_HOLIDAY"
|
||||||
|
if status != "OPEN":
|
||||||
|
if reason == "WEEKEND":
|
||||||
|
return False, "MARKET_WEEKEND"
|
||||||
|
if reason == "PRE_OPEN":
|
||||||
|
return False, "MARKET_PRE_OPEN"
|
||||||
|
if reason == "POST_CLOSE":
|
||||||
|
return False, "MARKET_POST_CLOSE"
|
||||||
|
if reason == "CALENDAR_UNAVAILABLE":
|
||||||
|
return False, "MARKET_CALENDAR_UNAVAILABLE"
|
||||||
return False, "MARKET_CLOSED"
|
return False, "MARKET_CLOSED"
|
||||||
return True, "OK"
|
return True, "OK"
|
||||||
|
|
||||||
@ -229,6 +328,7 @@ def _engine_loop(config, stop_event: threading.Event):
|
|||||||
|
|
||||||
user_id = config.get("user_id")
|
user_id = config.get("user_id")
|
||||||
run_id = config.get("run_id")
|
run_id = config.get("run_id")
|
||||||
|
owner_id = config.get("runner_owner_id") or RUNNER_OWNER_ID
|
||||||
scope_user, scope_run = get_context(user_id, run_id)
|
scope_user, scope_run = get_context(user_id, run_id)
|
||||||
set_context(scope_user, scope_run)
|
set_context(scope_user, scope_run)
|
||||||
|
|
||||||
@ -304,9 +404,13 @@ def _engine_loop(config, stop_event: threading.Event):
|
|||||||
last_heartbeat_ts=datetime.utcnow().isoformat() + "Z",
|
last_heartbeat_ts=datetime.utcnow().isoformat() + "Z",
|
||||||
)
|
)
|
||||||
_update_engine_status(scope_user, scope_run, "RUNNING")
|
_update_engine_status(scope_user, scope_run, "RUNNING")
|
||||||
|
exit_reason = "STOPPED"
|
||||||
|
|
||||||
try:
|
try:
|
||||||
while not stop_event.is_set():
|
while not stop_event.is_set():
|
||||||
|
if not _refresh_run_lease_or_stop(scope_user, scope_run, owner_id):
|
||||||
|
exit_reason = "LEASE_LOST"
|
||||||
|
break
|
||||||
_set_state(scope_user, scope_run, last_heartbeat_ts=datetime.utcnow().isoformat() + "Z")
|
_set_state(scope_user, scope_run, last_heartbeat_ts=datetime.utcnow().isoformat() + "Z")
|
||||||
_update_engine_status(scope_user, scope_run, "RUNNING")
|
_update_engine_status(scope_user, scope_run, "RUNNING")
|
||||||
|
|
||||||
@ -368,7 +472,9 @@ def _engine_loop(config, stop_event: threading.Event):
|
|||||||
"frequency": frequency_label,
|
"frequency": frequency_label,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
sleep_with_heartbeat(wait_seconds, stop_event, scope_user, scope_run)
|
if not sleep_with_heartbeat(wait_seconds, stop_event, scope_user, scope_run, owner_id):
|
||||||
|
exit_reason = "LEASE_LOST"
|
||||||
|
break
|
||||||
continue
|
continue
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -395,7 +501,9 @@ def _engine_loop(config, stop_event: threading.Event):
|
|||||||
break
|
break
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
debug_event("PRICE_FETCH_ERROR", "live price fetch failed", {"error": str(exc)})
|
debug_event("PRICE_FETCH_ERROR", "live price fetch failed", {"error": str(exc)})
|
||||||
sleep_with_heartbeat(30, stop_event, scope_user, scope_run)
|
if not sleep_with_heartbeat(30, stop_event, scope_user, scope_run, owner_id):
|
||||||
|
exit_reason = "LEASE_LOST"
|
||||||
|
break
|
||||||
continue
|
continue
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -416,7 +524,9 @@ def _engine_loop(config, stop_event: threading.Event):
|
|||||||
break
|
break
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
debug_event("HISTORY_LOAD_ERROR", "history load failed", {"error": str(exc)})
|
debug_event("HISTORY_LOAD_ERROR", "history load failed", {"error": str(exc)})
|
||||||
sleep_with_heartbeat(30, stop_event, scope_user, scope_run)
|
if not sleep_with_heartbeat(30, stop_event, scope_user, scope_run, owner_id):
|
||||||
|
exit_reason = "LEASE_LOST"
|
||||||
|
break
|
||||||
continue
|
continue
|
||||||
|
|
||||||
nifty_sma = nifty_hist.rolling(SMA_MONTHS).mean().iloc[-1]
|
nifty_sma = nifty_hist.rolling(SMA_MONTHS).mean().iloc[-1]
|
||||||
@ -565,16 +675,32 @@ def _engine_loop(config, stop_event: threading.Event):
|
|||||||
logical_time=logical_time,
|
logical_time=logical_time,
|
||||||
)
|
)
|
||||||
|
|
||||||
sleep_with_heartbeat(30, stop_event, scope_user, scope_run)
|
if not sleep_with_heartbeat(30, stop_event, scope_user, scope_run, owner_id):
|
||||||
|
exit_reason = "LEASE_LOST"
|
||||||
|
break
|
||||||
except BrokerAuthExpired as exc:
|
except BrokerAuthExpired as exc:
|
||||||
|
exit_reason = "AUTH_EXPIRED"
|
||||||
_pause_for_auth_expiry(scope_user, scope_run, str(exc), emit_event_cb=emit_event_cb)
|
_pause_for_auth_expiry(scope_user, scope_run, str(exc), emit_event_cb=emit_event_cb)
|
||||||
print(f"[ENGINE] broker auth expired for run {scope_run}: {exc}", flush=True)
|
print(f"[ENGINE] broker auth expired for run {scope_run}: {exc}", flush=True)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
exit_reason = "ERROR"
|
||||||
_set_state(scope_user, scope_run, state="ERROR", last_heartbeat_ts=datetime.utcnow().isoformat() + "Z")
|
_set_state(scope_user, scope_run, state="ERROR", last_heartbeat_ts=datetime.utcnow().isoformat() + "Z")
|
||||||
_update_engine_status(scope_user, scope_run, "ERROR")
|
_update_engine_status(scope_user, scope_run, "ERROR")
|
||||||
log_event("ENGINE_ERROR", {"error": str(e)})
|
log_event("ENGINE_ERROR", {"error": str(e)})
|
||||||
raise
|
raise
|
||||||
|
finally:
|
||||||
|
try:
|
||||||
|
released = release_run_lease(scope_run, owner_id)
|
||||||
|
if released:
|
||||||
|
print(
|
||||||
|
f"[ENGINE] RUNNER_LEASE_RELEASED released run lease "
|
||||||
|
f"{{'run_id': '{scope_run}', 'owner_id': '{owner_id}'}}",
|
||||||
|
flush=True,
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
if exit_reason not in {"ERROR", "LEASE_LOST", "AUTH_EXPIRED"}:
|
||||||
log_event("ENGINE_STOP")
|
log_event("ENGINE_STOP")
|
||||||
_set_state(
|
_set_state(
|
||||||
scope_user,
|
scope_user,
|
||||||
@ -584,6 +710,13 @@ def _engine_loop(config, stop_event: threading.Event):
|
|||||||
)
|
)
|
||||||
_update_engine_status(scope_user, scope_run, "STOPPED")
|
_update_engine_status(scope_user, scope_run, "STOPPED")
|
||||||
print("Strategy engine stopped")
|
print("Strategy engine stopped")
|
||||||
|
elif exit_reason == "LEASE_LOST":
|
||||||
|
_set_state(
|
||||||
|
scope_user,
|
||||||
|
scope_run,
|
||||||
|
state="STOPPED",
|
||||||
|
last_heartbeat_ts=datetime.utcnow().isoformat() + "Z",
|
||||||
|
)
|
||||||
_clear_runner(scope_user, scope_run)
|
_clear_runner(scope_user, scope_run)
|
||||||
|
|
||||||
def start_engine(config):
|
def start_engine(config):
|
||||||
@ -600,14 +733,53 @@ def start_engine(config):
|
|||||||
if runner and runner["thread"].is_alive():
|
if runner and runner["thread"].is_alive():
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
lease = acquire_run_lease(
|
||||||
|
run_id,
|
||||||
|
RUNNER_OWNER_ID,
|
||||||
|
lease_seconds=RUN_LEASE_SECONDS,
|
||||||
|
)
|
||||||
|
if not lease.get("acquired"):
|
||||||
|
_log_runner_lease_event(
|
||||||
|
user_id,
|
||||||
|
run_id,
|
||||||
|
"RUNNER_LEASE_DENIED",
|
||||||
|
"Run lease denied",
|
||||||
|
{
|
||||||
|
"owner_id": RUNNER_OWNER_ID,
|
||||||
|
"current_owner": lease.get("owner_id"),
|
||||||
|
"expires_at": lease.get("expires_at").isoformat() if lease.get("expires_at") else None,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
raise RunLeaseNotAcquiredError(run_id, RUNNER_OWNER_ID, lease)
|
||||||
|
|
||||||
|
lease_status = str(lease.get("status") or "ACQUIRED").upper()
|
||||||
|
event_name = "RUNNER_LEASE_REACQUIRED" if lease_status == "REACQUIRED" else "RUNNER_LEASE_ACQUIRED"
|
||||||
|
_log_runner_lease_event(
|
||||||
|
user_id,
|
||||||
|
run_id,
|
||||||
|
event_name,
|
||||||
|
"Run lease acquired" if lease_status != "REACQUIRED" else "Expired run lease reacquired",
|
||||||
|
{
|
||||||
|
"owner_id": RUNNER_OWNER_ID,
|
||||||
|
"expires_at": lease.get("expires_at").isoformat() if lease.get("expires_at") else None,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
stop_event = threading.Event()
|
stop_event = threading.Event()
|
||||||
|
thread_config = dict(config)
|
||||||
|
thread_config["runner_owner_id"] = RUNNER_OWNER_ID
|
||||||
thread = threading.Thread(
|
thread = threading.Thread(
|
||||||
target=_engine_loop,
|
target=_engine_loop,
|
||||||
args=(config, stop_event),
|
args=(thread_config, stop_event),
|
||||||
daemon=True,
|
daemon=True,
|
||||||
)
|
)
|
||||||
_RUNNERS[key] = {"thread": thread, "stop_event": stop_event}
|
_RUNNERS[key] = {"thread": thread, "stop_event": stop_event}
|
||||||
|
try:
|
||||||
thread.start()
|
thread.start()
|
||||||
|
except Exception:
|
||||||
|
_RUNNERS.pop(key, None)
|
||||||
|
release_run_lease(run_id, RUNNER_OWNER_ID)
|
||||||
|
raise
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def stop_engine(user_id: str, run_id: str | None = None, timeout: float | None = 10.0):
|
def stop_engine(user_id: str, run_id: str | None = None, timeout: float | None = 10.0):
|
||||||
|
|||||||
@ -2,7 +2,7 @@
|
|||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
from indian_paper_trading_strategy.engine.db import db_connection, insert_engine_event, run_with_retry, get_context
|
from indian_paper_trading_strategy.engine.db import db_connection, insert_engine_event, run_with_retry, get_context
|
||||||
from indian_paper_trading_strategy.engine.market import market_now
|
from indian_paper_trading_strategy.engine.time_utils import parse_persisted_timestamp, serialize_timestamp
|
||||||
|
|
||||||
DEFAULT_STATE = {
|
DEFAULT_STATE = {
|
||||||
"initial_cash": 0.0,
|
"initial_cash": 0.0,
|
||||||
@ -31,33 +31,8 @@ def _default_state(mode: str | None):
|
|||||||
return DEFAULT_PAPER_STATE.copy()
|
return DEFAULT_PAPER_STATE.copy()
|
||||||
return DEFAULT_STATE.copy()
|
return DEFAULT_STATE.copy()
|
||||||
|
|
||||||
def _local_tz():
|
|
||||||
return market_now().tzinfo
|
|
||||||
|
|
||||||
def _format_local_ts(value: datetime | None):
|
|
||||||
if value is None:
|
|
||||||
return None
|
|
||||||
return value.astimezone(_local_tz()).replace(tzinfo=None).isoformat()
|
|
||||||
|
|
||||||
def _parse_ts(value):
|
def _parse_ts(value):
|
||||||
if value is None:
|
return parse_persisted_timestamp(value)
|
||||||
return None
|
|
||||||
if isinstance(value, datetime):
|
|
||||||
if value.tzinfo is None:
|
|
||||||
return value.replace(tzinfo=_local_tz())
|
|
||||||
return value
|
|
||||||
if isinstance(value, str):
|
|
||||||
text = value.strip()
|
|
||||||
if not text:
|
|
||||||
return None
|
|
||||||
try:
|
|
||||||
parsed = datetime.fromisoformat(text.replace("Z", "+00:00"))
|
|
||||||
except ValueError:
|
|
||||||
return None
|
|
||||||
if parsed.tzinfo is None:
|
|
||||||
parsed = parsed.replace(tzinfo=_local_tz())
|
|
||||||
return parsed
|
|
||||||
return None
|
|
||||||
|
|
||||||
def _resolve_scope(user_id: str | None, run_id: str | None):
|
def _resolve_scope(user_id: str | None, run_id: str | None):
|
||||||
return get_context(user_id, run_id)
|
return get_context(user_id, run_id)
|
||||||
@ -106,8 +81,8 @@ def load_state(
|
|||||||
"total_invested": float(row[2]) if row[2] is not None else merged["total_invested"],
|
"total_invested": float(row[2]) if row[2] is not None else merged["total_invested"],
|
||||||
"nifty_units": float(row[3]) if row[3] is not None else merged["nifty_units"],
|
"nifty_units": float(row[3]) if row[3] is not None else merged["nifty_units"],
|
||||||
"gold_units": float(row[4]) if row[4] is not None else merged["gold_units"],
|
"gold_units": float(row[4]) if row[4] is not None else merged["gold_units"],
|
||||||
"last_sip_ts": _format_local_ts(row[5]),
|
"last_sip_ts": serialize_timestamp(row[5]),
|
||||||
"last_run": _format_local_ts(row[6]),
|
"last_run": serialize_timestamp(row[6]),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
if row[7] is not None or row[8] is not None:
|
if row[7] is not None or row[8] is not None:
|
||||||
@ -143,8 +118,8 @@ def load_state(
|
|||||||
"total_invested": float(row[0]) if row[0] is not None else merged["total_invested"],
|
"total_invested": float(row[0]) if row[0] is not None else merged["total_invested"],
|
||||||
"nifty_units": float(row[1]) if row[1] is not None else merged["nifty_units"],
|
"nifty_units": float(row[1]) if row[1] is not None else merged["nifty_units"],
|
||||||
"gold_units": float(row[2]) if row[2] is not None else merged["gold_units"],
|
"gold_units": float(row[2]) if row[2] is not None else merged["gold_units"],
|
||||||
"last_sip_ts": _format_local_ts(row[3]),
|
"last_sip_ts": serialize_timestamp(row[3]),
|
||||||
"last_run": _format_local_ts(row[4]),
|
"last_run": serialize_timestamp(row[4]),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
return merged
|
return merged
|
||||||
|
|||||||
@ -1,4 +1,9 @@
|
|||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta, timezone
|
||||||
|
from zoneinfo import ZoneInfo
|
||||||
|
|
||||||
|
|
||||||
|
UTC = timezone.utc
|
||||||
|
MARKET_TZ = ZoneInfo("Asia/Kolkata")
|
||||||
|
|
||||||
|
|
||||||
def frequency_to_timedelta(freq: dict) -> timedelta:
|
def frequency_to_timedelta(freq: dict) -> timedelta:
|
||||||
@ -19,6 +24,46 @@ def normalize_logical_time(ts: datetime) -> datetime:
|
|||||||
return ts.replace(microsecond=0)
|
return ts.replace(microsecond=0)
|
||||||
|
|
||||||
|
|
||||||
|
def ensure_aware(ts: datetime, *, default_tz=UTC) -> datetime:
|
||||||
|
if ts.tzinfo is None:
|
||||||
|
return ts.replace(tzinfo=default_tz)
|
||||||
|
return ts
|
||||||
|
|
||||||
|
|
||||||
|
def to_utc(ts: datetime, *, default_tz=UTC) -> datetime:
|
||||||
|
return ensure_aware(ts, default_tz=default_tz).astimezone(UTC)
|
||||||
|
|
||||||
|
|
||||||
|
def serialize_timestamp(ts: datetime | None, *, default_tz=UTC) -> str | None:
|
||||||
|
if ts is None:
|
||||||
|
return None
|
||||||
|
return to_utc(ts, default_tz=default_tz).isoformat()
|
||||||
|
|
||||||
|
|
||||||
|
def parse_persisted_timestamp(value: datetime | str | None, *, default_tz=UTC) -> datetime | None:
|
||||||
|
if value is None:
|
||||||
|
return None
|
||||||
|
if isinstance(value, datetime):
|
||||||
|
return to_utc(value, default_tz=default_tz)
|
||||||
|
if isinstance(value, str):
|
||||||
|
text = value.strip()
|
||||||
|
if not text:
|
||||||
|
return None
|
||||||
|
try:
|
||||||
|
parsed = datetime.fromisoformat(text.replace("Z", "+00:00"))
|
||||||
|
except ValueError:
|
||||||
|
return None
|
||||||
|
return to_utc(parsed, default_tz=default_tz)
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def parse_market_timestamp(value: datetime | str | None) -> datetime | None:
|
||||||
|
parsed = parse_persisted_timestamp(value)
|
||||||
|
if parsed is None:
|
||||||
|
return None
|
||||||
|
return parsed.astimezone(MARKET_TZ)
|
||||||
|
|
||||||
|
|
||||||
def compute_logical_time(
|
def compute_logical_time(
|
||||||
now: datetime,
|
now: datetime,
|
||||||
last_run: str | None,
|
last_run: str | None,
|
||||||
@ -26,15 +71,12 @@ def compute_logical_time(
|
|||||||
) -> datetime:
|
) -> datetime:
|
||||||
base = now
|
base = now
|
||||||
if last_run and interval_seconds:
|
if last_run and interval_seconds:
|
||||||
try:
|
parsed = parse_persisted_timestamp(last_run, default_tz=now.tzinfo or UTC)
|
||||||
parsed = datetime.fromisoformat(last_run.replace("Z", "+00:00"))
|
|
||||||
except ValueError:
|
|
||||||
parsed = None
|
|
||||||
if parsed is not None:
|
if parsed is not None:
|
||||||
if now.tzinfo and parsed.tzinfo is None:
|
if now.tzinfo is None:
|
||||||
parsed = parsed.replace(tzinfo=now.tzinfo)
|
|
||||||
elif now.tzinfo is None and parsed.tzinfo:
|
|
||||||
parsed = parsed.replace(tzinfo=None)
|
parsed = parsed.replace(tzinfo=None)
|
||||||
|
else:
|
||||||
|
parsed = parsed.astimezone(now.tzinfo)
|
||||||
candidate = parsed + timedelta(seconds=interval_seconds)
|
candidate = parsed + timedelta(seconds=interval_seconds)
|
||||||
if now >= candidate:
|
if now >= candidate:
|
||||||
base = candidate
|
base = candidate
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user