"""Security regression tests for authentication and broker callback state.

Covers legacy password-hash upgrades on login, production session-cookie
flags, the import-time ``RESET_OTP_SECRET`` requirement, and validation and
consumption of the Zerodha broker callback ``state`` parameter.
"""

import importlib
import os
from datetime import datetime, timedelta, timezone

import pytest
from fastapi import Response
from fastapi.testclient import TestClient


def _import_auth_service(monkeypatch):
    """Import ``app.services.auth_service`` with the secret its import requires.

    Reloading the module without ``RESET_OTP_SECRET`` raises ``RuntimeError``
    (see the reset-OTP test below). A placeholder secret is set here, only when
    the variable is unset or empty, so each test can run in isolation
    regardless of execution order.
    """
    if not os.environ.get("RESET_OTP_SECRET"):
        monkeypatch.setenv("RESET_OTP_SECRET", "test-reset-secret")
    import app.services.auth_service as auth_service

    return auth_service


def _cookie_attributes(set_cookie_header):
    """Return the lower-cased attribute tokens of a ``Set-Cookie`` header.

    The leading ``name=value`` pair is dropped, so a cookie value that merely
    contains e.g. "secure" cannot satisfy an attribute check.
    """
    _name_value, *attributes = set_cookie_header.split(";")
    return {attribute.strip().lower() for attribute in attributes}


def _make_broker_client(monkeypatch):
    """Build a TestClient over a freshly reloaded app with an authenticated session.

    Returns ``(client, broker_router)``. ``get_user_for_session`` is stubbed to
    return ``user-1`` for any session id, so each test only has to stub the
    broker-specific collaborators it cares about.
    """
    import app.main as app_main
    import app.routers.broker as broker_router

    monkeypatch.setenv("APP_ENV", "test")
    monkeypatch.setenv("DISABLE_STARTUP_TASKS", "1")
    monkeypatch.setenv("CORS_ORIGINS", "http://localhost:3000")
    # Reload so module-level configuration in app.main picks up the env above.
    importlib.reload(app_main)
    client = TestClient(app_main.create_app())

    monkeypatch.setattr(
        broker_router,
        "get_user_for_session",
        lambda _sid: {"id": "user-1", "username": "user@example.com"},
    )
    return client, broker_router


def test_legacy_password_hash_upgrades_on_successful_login(monkeypatch):
    auth_service = _import_auth_service(monkeypatch)

    legacy_hash = auth_service._hash_password_legacy("correct-horse-battery-staple")
    user = {
        "id": "user-1",
        "username": "user@example.com",
        "password": legacy_hash,
        "role": "USER",
    }
    updated = {}
    monkeypatch.setattr(
        auth_service,
        "get_user_by_username",
        lambda username: user if username == user["username"] else None,
    )
    monkeypatch.setattr(
        auth_service,
        "_update_password_hash",
        lambda user_id, password_hash: updated.update(
            {"user_id": user_id, "password_hash": password_hash}
        ),
    )

    authenticated = auth_service.authenticate_user(
        user["username"], "correct-horse-battery-staple"
    )

    assert authenticated is user
    assert updated["user_id"] == "user-1"
    assert updated["password_hash"].startswith("$argon2id$")
    # The returned record must carry the upgraded hash, not the stale legacy one.
    assert authenticated["password"] == updated["password_hash"]


def test_secure_session_cookie_flags_are_enforced_in_production(monkeypatch):
    monkeypatch.setenv("APP_ENV", "production")
    monkeypatch.delenv("COOKIE_SECURE", raising=False)
    monkeypatch.setenv("COOKIE_SAMESITE", "strict")
    import app.routers.auth as auth_router

    # Reload so the router's module-level cookie settings see production env.
    importlib.reload(auth_router)
    try:
        response = Response()
        auth_router._set_session_cookie(response, "session-123")
        set_cookie = response.headers["set-cookie"]
    finally:
        # Restore the pre-test environment and reload again so production cookie
        # settings do not leak into later tests that share this module object.
        monkeypatch.undo()
        importlib.reload(auth_router)

    attributes = _cookie_attributes(set_cookie)
    assert "secure" in attributes
    assert "httponly" in attributes
    assert "samesite=strict" in attributes


def test_missing_reset_otp_secret_fails_auth_service_import(monkeypatch):
    # Import while a secret is present; if the module was not imported yet, a
    # bare import after delenv would raise outside pytest.raises below.
    auth_service = _import_auth_service(monkeypatch)
    configured_secret = os.environ["RESET_OTP_SECRET"]

    monkeypatch.delenv("RESET_OTP_SECRET", raising=False)
    try:
        with pytest.raises(RuntimeError, match="RESET_OTP_SECRET must be configured"):
            importlib.reload(auth_service)
    finally:
        # A failed reload leaves the module half-initialised; always reload it
        # with a valid secret so later tests get a usable module.
        monkeypatch.setenv("RESET_OTP_SECRET", configured_secret)
        importlib.reload(auth_service)


def test_valid_broker_callback_state_allows_connect_callback(monkeypatch):
    client, broker_router = _make_broker_client(monkeypatch)

    consume_calls = []

    def fake_consume_broker_callback_state(**kwargs):
        consume_calls.append(kwargs)
        return {"id": "state-1", "expires_at": datetime.now(timezone.utc).isoformat()}

    monkeypatch.setattr(
        broker_router, "consume_broker_callback_state", fake_consume_broker_callback_state
    )
    monkeypatch.setattr(
        broker_router,
        "get_pending_broker",
        lambda _user_id: {"api_key": "kite-key", "api_secret": "kite-secret"},
    )
    monkeypatch.setattr(
        broker_router,
        "exchange_request_token",
        lambda api_key, api_secret, token: {
            "access_token": "access-token",
            "request_token": token,
            "user_name": "Trader",
            "user_id": "Z123",
        },
    )
    captured = {}
    monkeypatch.setattr(
        broker_router,
        "set_zerodha_session",
        lambda user_id, payload: captured.setdefault(
            "zerodha_session", {"user_id": user_id, **payload}
        ),
    )
    monkeypatch.setattr(
        broker_router,
        "set_connected_broker",
        lambda user_id, broker, token, **kwargs: captured.setdefault(
            "connected",
            {"user_id": user_id, "broker": broker, "token": token, **kwargs},
        ),
    )

    response = client.get(
        "/api/broker/zerodha/callback",
        params={"request_token": "request-token", "state": "valid-state"},
        cookies={"session_id": "session-1"},
    )

    assert response.status_code == 200
    assert response.json() == {
        "connected": True,
        "userName": "Trader",
        "brokerUserId": "Z123",
    }
    # The router must actually validate the submitted state for this user;
    # without this check a router that ignored `state` would still pass.
    assert len(consume_calls) == 1
    assert consume_calls[0]["state"] == "valid-state"
    assert consume_calls[0]["user_id"] == "user-1"
    assert captured["connected"]["broker"] == "ZERODHA"
    assert captured["connected"]["user_id"] == "user-1"


def test_missing_broker_callback_state_fails(monkeypatch):
    client, _broker_router = _make_broker_client(monkeypatch)

    response = client.get(
        "/api/broker/zerodha/callback",
        params={"request_token": "request-token"},
        cookies={"session_id": "session-1"},
    )

    assert response.status_code == 400
    assert response.json() == {"detail": "Missing state"}


def test_wrong_or_expired_broker_callback_state_fails(monkeypatch):
    client, broker_router = _make_broker_client(monkeypatch)
    monkeypatch.setattr(broker_router, "consume_broker_callback_state", lambda **kwargs: None)

    response = client.get(
        "/api/broker/zerodha/callback",
        params={"request_token": "request-token", "state": "wrong-or-expired"},
        cookies={"session_id": "session-1"},
    )

    assert response.status_code == 401
    assert response.json() == {"detail": "Invalid or expired broker callback state"}


def test_broker_callback_state_service_rejects_wrong_user_and_expired_state(monkeypatch):
    import app.services.broker_callback_state as callback_state

    now = datetime(2026, 4, 8, 9, 0, tzinfo=timezone.utc)
    rows = [
        {
            "state_hash": callback_state._state_hash("valid-state"),
            "user_id": "user-1",
            "session_id": "session-1",
            "broker": "ZERODHA",
            "flow": "connect",
            "expires_at": now + timedelta(minutes=5),
            "consumed_at": None,
        },
        {
            "state_hash": callback_state._state_hash("expired-state"),
            "user_id": "user-1",
            "session_id": "session-1",
            "broker": "ZERODHA",
            "flow": "connect",
            "expires_at": now - timedelta(seconds=1),
            "consumed_at": None,
        },
    ]

    class FakeCursor:
        """In-memory stand-in for the conditional UPDATE the service issues.

        Mimics an ``UPDATE ... RETURNING`` that only matches an unconsumed,
        unexpired row for the exact (state, user, session, broker, flow).
        """

        def __init__(self):
            self._result = None

        def __enter__(self):
            return self

        def __exit__(self, exc_type, exc, tb):
            return False

        def execute(self, sql, params):
            if "UPDATE broker_callback_state" not in sql:
                raise AssertionError("Unexpected SQL")
            # Positional order mirrors the service's parameter tuple.
            consumed_at, state_hash, user_id, session_id, broker, flow, now_param = params
            self._result = None
            for row in rows:
                if (
                    row["state_hash"] == state_hash
                    and row["user_id"] == user_id
                    and row["session_id"] == session_id
                    and row["broker"] == broker
                    and row["flow"] == flow
                    and row["consumed_at"] is None
                    and row["expires_at"] > now_param
                ):
                    row["consumed_at"] = consumed_at
                    self._result = ("row-id", row["expires_at"])
                    break

        def fetchone(self):
            return self._result

    class FakeConnection:
        def __enter__(self):
            return self

        def __exit__(self, exc_type, exc, tb):
            return False

        def cursor(self):
            return FakeCursor()

    monkeypatch.setattr(callback_state, "_now_utc", lambda: now)
    monkeypatch.setattr(callback_state, "db_connection", lambda: FakeConnection())

    def consume(state, user_id="user-1"):
        return callback_state.consume_broker_callback_state(
            state=state,
            user_id=user_id,
            session_id="session-1",
            broker="ZERODHA",
            flow="connect",
        )

    # A state issued to user-1 must not be redeemable by another user.
    assert consume("valid-state", user_id="user-2") is None
    assert consume("expired-state") is None
    assert consume("valid-state") == {
        "id": "row-id",
        "expires_at": (now + timedelta(minutes=5)).isoformat(),
    }
    # States are single-use: replaying an already-consumed state is rejected.
    assert consume("valid-state") is None