# 270 lines, 8.8 KiB, Python
import importlib
|
|
from datetime import datetime, timedelta, timezone
|
|
|
|
import pytest
|
|
from fastapi import Response
|
|
from fastapi.testclient import TestClient
|
|
|
|
|
|
def test_legacy_password_hash_upgrades_on_successful_login(monkeypatch):
    """A correct login against a legacy hash re-hashes the password as Argon2id.

    ``get_user_by_username`` and ``_update_password_hash`` are stubbed so no
    storage is touched; the update stub records what the service asked to
    persist so the new hash can be inspected.
    """
    import app.services.auth_service as auth_service

    plaintext = "correct-horse-battery-staple"
    stored_user = {
        "id": "user-1",
        "username": "user@example.com",
        "password": auth_service._hash_password_legacy(plaintext),
        "role": "USER",
    }
    persisted = {}

    def fake_lookup(username):
        # Only the seeded account exists.
        return stored_user if username == stored_user["username"] else None

    def fake_update(user_id, password_hash):
        persisted.update({"user_id": user_id, "password_hash": password_hash})

    monkeypatch.setattr(auth_service, "get_user_by_username", fake_lookup)
    monkeypatch.setattr(auth_service, "_update_password_hash", fake_update)

    result = auth_service.authenticate_user(stored_user["username"], plaintext)

    assert result is stored_user
    assert persisted["user_id"] == "user-1"
    new_hash = persisted["password_hash"]
    assert new_hash.startswith("$argon2id$")
    # The returned user record must carry the upgraded hash as well.
    assert result["password"] == new_hash
|
|
|
|
|
|
def test_secure_session_cookie_flags_are_enforced_in_production(monkeypatch):
    """The production session cookie carries Secure, HttpOnly and SameSite=Strict.

    ``app.routers.auth`` is reloaded under a patched environment so its
    import-time configuration reflects production settings. A reload mutates
    the shared module object, so two things keep later tests from inheriting
    the production cookie configuration:

    - the env patches are scoped to a ``monkeypatch.context()`` block;
    - ``finally`` reloads the module once more under the restored environment.
    """
    import app.routers.auth as auth_router

    try:
        with monkeypatch.context() as patched:
            patched.setenv("APP_ENV", "production")
            patched.delenv("COOKIE_SECURE", raising=False)
            patched.setenv("COOKIE_SAMESITE", "strict")
            importlib.reload(auth_router)

            response = Response()
            auth_router._set_session_cookie(response, "session-123")

        # Inspect only the attributes after the leading ``name=value`` pair so
        # that a matching substring in the cookie name or value cannot satisfy
        # the checks by accident.
        _, *raw_attributes = response.headers["set-cookie"].split(";")
        attributes = {attribute.strip().lower() for attribute in raw_attributes}

        assert "secure" in attributes
        assert "httponly" in attributes
        assert "samesite=strict" in attributes
    finally:
        # The env patches are undone at this point; reload so the module's
        # configuration matches the environment other tests run under.
        importlib.reload(auth_router)
|
|
|
|
|
|
def test_missing_reset_otp_secret_fails_auth_service_import(monkeypatch):
    """Loading auth_service without RESET_OTP_SECRET must raise RuntimeError.

    The module is imported (or fetched from ``sys.modules``) under the ambient
    environment *before* the secret is removed. That makes the reload inside
    ``pytest.raises`` the only load expected to fail, even when this test runs
    on its own.

    A failed reload leaves the module namespace partially re-executed, so the
    module is always reloaded in ``finally`` with a secret configured. This
    holds even when the assertion fails, so later tests never see a broken
    module.
    """
    import os

    import app.services.auth_service as auth_service

    # Restore with the ambient secret when there is one, so the module ends up
    # configured as it was before this test; otherwise use a test value.
    restore_secret = os.environ.get("RESET_OTP_SECRET") or "test-reset-secret"

    monkeypatch.delenv("RESET_OTP_SECRET", raising=False)
    try:
        with pytest.raises(RuntimeError, match="RESET_OTP_SECRET must be configured"):
            importlib.reload(auth_service)
    finally:
        monkeypatch.setenv("RESET_OTP_SECRET", restore_secret)
        importlib.reload(auth_service)
|
|
|
|
|
|
def test_valid_broker_callback_state_allows_connect_callback(monkeypatch):
    """A successfully consumed callback state lets the Zerodha connect callback finish.

    All collaborators of the callback handler are stubbed on
    ``app.routers.broker``. The stubs record what they were called with, so
    the test can also check two things:

    - the handler consumed the ``state`` it received, for the session's user;
    - it exchanged the supplied ``request_token``.
    """
    # Set the environment before importing so that even a first-time import
    # of app.main (not just the reload) sees the test configuration.
    monkeypatch.setenv("APP_ENV", "test")
    monkeypatch.setenv("DISABLE_STARTUP_TASKS", "1")
    monkeypatch.setenv("CORS_ORIGINS", "http://localhost:3000")

    import app.main as app_main
    import app.routers.broker as broker_router

    importlib.reload(app_main)
    app = app_main.create_app()
    client = TestClient(app)
    # Put the session cookie on the client: per-request ``cookies=`` is
    # deprecated by httpx, which backs Starlette's TestClient.
    client.cookies.set("session_id", "session-1")

    captured = {}

    def fake_consume_state(**kwargs):
        captured["consume_state"] = kwargs
        return {"id": "state-1", "expires_at": datetime.now(timezone.utc).isoformat()}

    def fake_exchange(api_key, api_secret, token):
        captured["exchange_token"] = token
        return {
            "access_token": "access-token",
            "request_token": token,
            "user_name": "Trader",
            "user_id": "Z123",
        }

    monkeypatch.setattr(
        broker_router,
        "get_user_for_session",
        lambda _sid: {"id": "user-1", "username": "user@example.com"},
    )
    monkeypatch.setattr(broker_router, "consume_broker_callback_state", fake_consume_state)
    monkeypatch.setattr(
        broker_router,
        "get_pending_broker",
        lambda _user_id: {"api_key": "kite-key", "api_secret": "kite-secret"},
    )
    monkeypatch.setattr(broker_router, "exchange_request_token", fake_exchange)
    monkeypatch.setattr(
        broker_router,
        "set_zerodha_session",
        lambda user_id, payload: captured.setdefault("zerodha_session", {"user_id": user_id, **payload}),
    )
    monkeypatch.setattr(
        broker_router,
        "set_connected_broker",
        lambda user_id, broker, token, **kwargs: captured.setdefault(
            "connected",
            {"user_id": user_id, "broker": broker, "token": token, **kwargs},
        ),
    )

    response = client.get(
        "/api/broker/zerodha/callback",
        params={"request_token": "request-token", "state": "valid-state"},
    )

    assert response.status_code == 200
    assert response.json() == {
        "connected": True,
        "userName": "Trader",
        "brokerUserId": "Z123",
    }
    # The handler must consume exactly the state it was given, for the
    # user resolved from the session cookie.
    assert captured["consume_state"]["state"] == "valid-state"
    assert captured["consume_state"]["user_id"] == "user-1"
    assert captured["exchange_token"] == "request-token"
    assert captured["connected"]["broker"] == "ZERODHA"
    assert captured["connected"]["user_id"] == "user-1"
|
|
|
|
|
|
def test_missing_broker_callback_state_fails(monkeypatch):
    """A callback without ``state`` is rejected with 400 before any state is consumed."""
    # Set the environment before importing so that even a first-time import
    # of app.main (not just the reload) sees the test configuration.
    monkeypatch.setenv("APP_ENV", "test")
    monkeypatch.setenv("DISABLE_STARTUP_TASKS", "1")
    monkeypatch.setenv("CORS_ORIGINS", "http://localhost:3000")

    import app.main as app_main
    import app.routers.broker as broker_router

    importlib.reload(app_main)
    app = app_main.create_app()
    client = TestClient(app)
    # Put the session cookie on the client: per-request ``cookies=`` is
    # deprecated by httpx, which backs Starlette's TestClient.
    client.cookies.set("session_id", "session-1")

    def fail_if_called(*args, **kwargs):
        raise AssertionError(
            f"consume_broker_callback_state must not run without a state: {args!r} {kwargs!r}"
        )

    monkeypatch.setattr(
        broker_router,
        "get_user_for_session",
        lambda _sid: {"id": "user-1", "username": "user@example.com"},
    )
    # Keep the test hermetic: a regression that tries to consume a missing
    # state fails loudly here instead of reaching the real state store.
    monkeypatch.setattr(broker_router, "consume_broker_callback_state", fail_if_called)

    response = client.get(
        "/api/broker/zerodha/callback",
        params={"request_token": "request-token"},
    )

    assert response.status_code == 400
    assert response.json() == {"detail": "Missing state"}
|
|
|
|
|
|
def test_wrong_or_expired_broker_callback_state_fails(monkeypatch):
    """A state the store refuses to consume yields 401 from the callback.

    ``consume_broker_callback_state`` is stubbed to return ``None``, which
    stands in for both a wrong and an expired state. The stub records its
    arguments so the test also checks that the handler forwarded the state it
    received.
    """
    # Set the environment before importing so that even a first-time import
    # of app.main (not just the reload) sees the test configuration.
    monkeypatch.setenv("APP_ENV", "test")
    monkeypatch.setenv("DISABLE_STARTUP_TASKS", "1")
    monkeypatch.setenv("CORS_ORIGINS", "http://localhost:3000")

    import app.main as app_main
    import app.routers.broker as broker_router

    importlib.reload(app_main)
    app = app_main.create_app()
    client = TestClient(app)
    # Put the session cookie on the client: per-request ``cookies=`` is
    # deprecated by httpx, which backs Starlette's TestClient.
    client.cookies.set("session_id", "session-1")

    consume_calls = []

    def reject_state(**kwargs):
        consume_calls.append(kwargs)
        return None

    monkeypatch.setattr(
        broker_router,
        "get_user_for_session",
        lambda _sid: {"id": "user-1", "username": "user@example.com"},
    )
    monkeypatch.setattr(broker_router, "consume_broker_callback_state", reject_state)

    response = client.get(
        "/api/broker/zerodha/callback",
        params={"request_token": "request-token", "state": "wrong-or-expired"},
    )

    assert response.status_code == 401
    assert response.json() == {"detail": "Invalid or expired broker callback state"}
    assert [call["state"] for call in consume_calls] == ["wrong-or-expired"]
|
|
|
|
|
|
def test_broker_callback_state_service_rejects_wrong_user_and_expired_state(monkeypatch):
    """consume_broker_callback_state only succeeds for a matching, unexpired state.

    The database is replaced by an in-memory fake. Its cursor emulates the
    ``UPDATE broker_callback_state`` the service issues: the first row is
    consumed when all of the following hold:

    - every binding column in the parameters matches;
    - the row has not been consumed yet;
    - the row has not expired at the frozen current time.
    """
    import app.services.broker_callback_state as callback_state

    frozen_now = datetime(2026, 4, 8, 9, 0, tzinfo=timezone.utc)
    valid_expiry = frozen_now + timedelta(minutes=5)

    def make_row(raw_state, expires_at):
        # All seeded rows belong to the same user/session/broker/flow and
        # differ only in their state and expiry.
        return {
            "state_hash": callback_state._state_hash(raw_state),
            "user_id": "user-1",
            "session_id": "session-1",
            "broker": "ZERODHA",
            "flow": "connect",
            "expires_at": expires_at,
            "consumed_at": None,
        }

    rows = [
        make_row("valid-state", valid_expiry),
        make_row("expired-state", frozen_now - timedelta(seconds=1)),
    ]

    class FakeCursor:
        def __init__(self):
            self._fetched = None

        def __enter__(self):
            return self

        def __exit__(self, exc_type, exc, tb):
            return False

        def execute(self, sql, params):
            if "UPDATE broker_callback_state" not in sql:
                raise AssertionError("Unexpected SQL")
            # Parameter order mirrors the service's UPDATE statement.
            consumed_at, state_hash, user_id, session_id, broker, flow, now_param = params
            wanted = {
                "state_hash": state_hash,
                "user_id": user_id,
                "session_id": session_id,
                "broker": broker,
                "flow": flow,
            }
            self._fetched = None
            hit = next(
                (
                    row
                    for row in rows
                    if all(row[column] == value for column, value in wanted.items())
                    and row["consumed_at"] is None
                    and row["expires_at"] > now_param
                ),
                None,
            )
            if hit is not None:
                hit["consumed_at"] = consumed_at
                self._fetched = ("row-id", hit["expires_at"])

        def fetchone(self):
            return self._fetched

    class FakeConnection:
        def __enter__(self):
            return self

        def __exit__(self, exc_type, exc, tb):
            return False

        def cursor(self):
            return FakeCursor()

    monkeypatch.setattr(callback_state, "_now_utc", lambda: frozen_now)
    monkeypatch.setattr(callback_state, "db_connection", lambda: FakeConnection())

    def consume(state, user_id):
        return callback_state.consume_broker_callback_state(
            state=state,
            user_id=user_id,
            session_id="session-1",
            broker="ZERODHA",
            flow="connect",
        )

    # Right state, wrong user: rejected. The row stays unconsumed, which is
    # why the last call below can still use it.
    assert consume("valid-state", "user-2") is None
    # Right user, expired state: rejected.
    assert consume("expired-state", "user-1") is None
    # Matching, unexpired state: consumed and returned.
    assert consume("valid-state", "user-1") == {
        "id": "row-id",
        "expires_at": valid_expiry.isoformat(),
    }
|