127 lines
4.0 KiB
Python

import os
from urllib.parse import urlparse
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from app.admin_role_service import bootstrap_super_admin
from app.admin_router import router as admin_router
from app.routers.auth import router as auth_router
from app.routers.broker import router as broker_router
from app.routers.health import router as health_router
from app.routers.paper import router as paper_router
from app.routers.password_reset import router as password_reset_router
from app.routers.strategy import router as strategy_router
from app.routers.support_ticket import router as support_ticket_router
from app.routers.system import router as system_router
from app.routers.zerodha import router as zerodha_router, public_router as zerodha_public_router
from app.services.db import _db_config as _validate_db_config
from app.services.live_equity_service import start_live_equity_snapshot_daemon
from app.services.strategy_service import init_log_state, resume_running_runs
from market import router as market_router
from paper_mtm import router as paper_mtm_router
# The only non-localhost browser origin that CORS validation accepts.
DEFAULT_PRODUCTION_ORIGINS = {"https://app.quantfortune.com"}

# Fallback CORS origins outside production when CORS_ORIGINS is unset:
# both loopback spellings crossed with the usual local frontend dev ports.
DEFAULT_DEV_ORIGINS = {
    f"http://{host}:{port}"
    for host in ("localhost", "127.0.0.1")
    for port in (3000, 5173)
}

# Environment names (compared after strip/lower-casing) that count as production.
PRODUCTION_ENV_NAMES = {"prod", "production"}
def _environment_name() -> str:
return (
os.getenv("APP_ENV")
or os.getenv("ENVIRONMENT")
or os.getenv("FASTAPI_ENV")
or "development"
).strip().lower()
def _normalize_origin(origin: str) -> str:
return origin.strip().rstrip("/")
def _is_dev_origin(origin: str) -> bool:
parsed = urlparse(origin)
return parsed.scheme == "http" and parsed.hostname in {"localhost", "127.0.0.1"}
def _validate_cors_origin(origin: str) -> str:
    """Normalize *origin* and check it against the CORS allow-list.

    Returns the normalized origin (whitespace trimmed, trailing "/" removed).

    Raises:
        RuntimeError: if the normalized origin is empty, or if it is neither in
            DEFAULT_PRODUCTION_ORIGINS nor accepted by _is_dev_origin.
    """
    # NOTE(review): localhost dev origins pass here in every environment,
    # including production, and the app enables allow_credentials=True.
    # Confirm that accepting them in prod is intended.
    candidate = _normalize_origin(origin)
    if candidate == "":
        raise RuntimeError("Empty CORS origin is not allowed")
    is_allowed = candidate in DEFAULT_PRODUCTION_ORIGINS or _is_dev_origin(candidate)
    if not is_allowed:
        raise RuntimeError(
            f"Unsupported CORS origin '{candidate}'. Only app.quantfortune.com and localhost dev origins are allowed."
        )
    return candidate
def _build_cors_origins() -> list[str]:
    """Resolve the validated, de-duplicated list of allowed CORS origins.

    Origins come from the comma-separated CORS_ORIGINS environment variable.
    Blank entries are skipped and each entry is normalized.
    - In production (see PRODUCTION_ENV_NAMES) CORS_ORIGINS is mandatory.
    - Elsewhere, an empty CORS_ORIGINS falls back to the sorted DEFAULT_DEV_ORIGINS.

    Every origin is run through _validate_cors_origin in order. Duplicates are
    dropped, keeping the first occurrence.

    Raises:
        RuntimeError: if CORS_ORIGINS is missing in production, or if any
            origin fails validation.
    """
    raw_setting = os.getenv("CORS_ORIGINS", "")
    configured = [
        _normalize_origin(entry) for entry in raw_setting.split(",") if entry.strip()
    ]
    if _environment_name() in PRODUCTION_ENV_NAMES:
        if not configured:
            raise RuntimeError("CORS_ORIGINS must be configured explicitly in production")
        candidates = configured
    else:
        candidates = configured if configured else sorted(DEFAULT_DEV_ORIGINS)
    # dict preserves insertion order, so fromkeys de-duplicates while keeping
    # the first occurrence; validation still runs on every candidate in order.
    return list(dict.fromkeys(_validate_cors_origin(entry) for entry in candidates))
def create_app() -> FastAPI:
    """Construct and configure the QuantFortune FastAPI application.

    Steps, in order:
    1. _validate_db_config() runs before the app object is created.
    2. CORS middleware is installed with the origins from _build_cors_origins(),
       with credentials enabled and all methods/headers allowed.
    3. All feature routers are mounted.
    4. A startup hook is registered that runs init_log_state,
       bootstrap_super_admin, resume_running_runs and
       start_live_equity_snapshot_daemon, unless DISABLE_STARTUP_TASKS == "1".

    Returns the configured FastAPI instance.
    """
    _validate_db_config()
    application = FastAPI(title="QuantFortune Backend", version="1.0")
    allowed_origins = _build_cors_origins()
    application.add_middleware(
        CORSMiddleware,
        allow_origins=allowed_origins,
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )

    # Keep this sequence stable. Routes are matched in registration order, so
    # if two routers ever declare the same path, the earlier one wins.
    for feature_router in (
        strategy_router,
        auth_router,
        broker_router,
        zerodha_router,
        zerodha_public_router,
        paper_router,
        market_router,
        paper_mtm_router,
        health_router,
        system_router,
        admin_router,
        support_ticket_router,
        password_reset_router,
    ):
        application.include_router(feature_router)

    # NOTE: FastAPI documents on_event as deprecated in favour of `lifespan`.
    # Passing a lifespan stops on_event startup/shutdown handlers from being
    # called, including any contributed by the included routers, so migrate
    # with care.
    @application.on_event("startup")
    def init_app_state():
        # DISABLE_STARTUP_TASKS=1 skips all background/bootstrap work below.
        if os.getenv("DISABLE_STARTUP_TASKS", "0") != "1":
            init_log_state()
            bootstrap_super_admin()
            resume_running_runs()
            start_live_equity_snapshot_daemon()

    return application
# Module-level application instance, built once at import time. Importing this
# module therefore runs create_app(), which validates the DB config and the CORS
# settings and may raise RuntimeError on bad configuration.
# Presumably this is the object an ASGI server loads as `<module>:app` (confirm).
app: FastAPI = create_app()