# Source listing metadata: 206 lines, 7.9 KiB, Python.
from __future__ import annotations
|
|
|
|
import threading
|
|
from datetime import datetime, timedelta, timezone
|
|
|
|
from indian_paper_trading_strategy.engine import db as engine_db
|
|
from indian_paper_trading_strategy.engine import runner
|
|
|
|
|
|
class _LeaseCursor:
|
|
def __init__(self, leases: dict[str, dict], lock: threading.Lock):
|
|
self._leases = leases
|
|
self._lock = lock
|
|
self._result = None
|
|
|
|
def execute(self, sql, params):
|
|
sql_text = " ".join(sql.split())
|
|
with self._lock:
|
|
if sql_text.startswith("INSERT INTO run_leases"):
|
|
run_id, owner_id, leased_at, expires_at, heartbeat_at = params
|
|
lease = self._leases.get(run_id)
|
|
if lease is None:
|
|
self._leases[run_id] = {
|
|
"owner_id": owner_id,
|
|
"leased_at": leased_at,
|
|
"expires_at": expires_at,
|
|
"heartbeat_at": heartbeat_at,
|
|
}
|
|
self._result = (run_id,)
|
|
else:
|
|
self._result = None
|
|
return
|
|
|
|
if "SELECT owner_id, expires_at FROM run_leases" in sql_text:
|
|
run_id = params[0]
|
|
lease = self._leases.get(run_id)
|
|
self._result = None if lease is None else (lease["owner_id"], lease["expires_at"])
|
|
return
|
|
|
|
if sql_text.startswith("UPDATE run_leases SET leased_at"):
|
|
leased_at, expires_at, heartbeat_at, run_id, owner_id = params
|
|
lease = self._leases.get(run_id)
|
|
if lease and lease["owner_id"] == owner_id:
|
|
lease.update(
|
|
{
|
|
"owner_id": owner_id,
|
|
"leased_at": leased_at,
|
|
"expires_at": expires_at,
|
|
"heartbeat_at": heartbeat_at,
|
|
}
|
|
)
|
|
self._result = (run_id,)
|
|
else:
|
|
self._result = None
|
|
return
|
|
|
|
if sql_text.startswith("UPDATE run_leases SET owner_id"):
|
|
owner_id, leased_at, expires_at, heartbeat_at, run_id, current_time = params
|
|
lease = self._leases.get(run_id)
|
|
if lease and lease["expires_at"] <= current_time:
|
|
lease.update(
|
|
{
|
|
"owner_id": owner_id,
|
|
"leased_at": leased_at,
|
|
"expires_at": expires_at,
|
|
"heartbeat_at": heartbeat_at,
|
|
}
|
|
)
|
|
self._result = (run_id,)
|
|
else:
|
|
self._result = None
|
|
return
|
|
|
|
if sql_text.startswith("UPDATE run_leases SET heartbeat_at"):
|
|
heartbeat_at, expires_at, run_id, owner_id, current_time = params
|
|
lease = self._leases.get(run_id)
|
|
if lease and lease["owner_id"] == owner_id and lease["expires_at"] > current_time:
|
|
lease.update({"heartbeat_at": heartbeat_at, "expires_at": expires_at})
|
|
self._result = (run_id, expires_at)
|
|
else:
|
|
self._result = None
|
|
return
|
|
|
|
if sql_text.startswith("DELETE FROM run_leases"):
|
|
run_id, owner_id = params
|
|
lease = self._leases.get(run_id)
|
|
if lease and lease["owner_id"] == owner_id:
|
|
del self._leases[run_id]
|
|
self._result = (run_id,)
|
|
else:
|
|
self._result = None
|
|
return
|
|
|
|
raise AssertionError(f"Unexpected SQL: {sql_text}")
|
|
|
|
def fetchone(self):
|
|
return self._result
|
|
|
|
|
|
def _patch_lease_storage(monkeypatch):
    """Replace ``engine_db.run_with_retry`` with an in-memory lease backend.

    Each patched call builds a fresh ``_LeaseCursor`` over a single shared
    dict and lock, then returns ``operation(cursor, None)``. The ``retries``
    and ``delay`` arguments are accepted but ignored.

    Returns:
        The shared ``{run_id: lease_row}`` dict, so tests can inspect the
        stored lease state directly.
    """
    store: dict[str, dict] = {}
    store_lock = threading.Lock()

    def fake_run_with_retry(operation, retries=None, delay=None):
        return operation(_LeaseCursor(store, store_lock), None)

    monkeypatch.setattr(engine_db, "run_with_retry", fake_run_with_retry)
    return store
|
|
|
|
|
|
def test_run_lease_allows_only_one_active_owner(monkeypatch):
    """Two owners race for one run: exactly one acquires, the other is DENIED.

    A barrier releases both worker threads together, so the two
    ``acquire_run_lease`` calls start as close together as possible instead of
    running back to back. Any exception raised inside a worker is collected
    and re-raised here rather than being lost inside the thread.
    """
    _patch_lease_storage(monkeypatch)

    now = datetime(2026, 4, 8, 4, 0, tzinfo=timezone.utc)
    owners = ("owner-a", "owner-b")
    results = []
    errors = []
    start_gate = threading.Barrier(len(owners))

    def attempt(owner_id: str):
        try:
            # Park each worker here until every worker has started.
            start_gate.wait(timeout=5)
            results.append(engine_db.acquire_run_lease("run-1", owner_id, lease_seconds=90, now=now))
        except Exception as exc:  # re-raised in the main thread below
            errors.append(exc)

    threads = [threading.Thread(target=attempt, args=(owner,)) for owner in owners]
    for thread in threads:
        thread.start()
    for thread in threads:
        thread.join(timeout=10)

    assert not any(thread.is_alive() for thread in threads), "lease workers did not finish"
    if errors:
        raise errors[0]
    assert len(results) == len(owners)

    acquired = [result for result in results if result["acquired"]]
    denied = [result for result in results if not result["acquired"]]
    assert len(acquired) == 1
    assert len(denied) == 1
    assert denied[0]["status"] == "DENIED"
|
|
|
|
|
|
def test_expired_run_lease_can_be_reacquired(monkeypatch):
    """A lease left to expire can be taken over by a different owner.

    Checks both the status reported by ``acquire_run_lease`` and the stored
    lease row, so a takeover that is reported but not persisted fails.
    """
    leases = _patch_lease_storage(monkeypatch)

    start = datetime(2026, 4, 8, 4, 0, tzinfo=timezone.utc)
    # One second past the 90-second lease taken at ``start``.
    later = start + timedelta(seconds=91)

    first = engine_db.acquire_run_lease("run-1", "owner-a", lease_seconds=90, now=start)
    second = engine_db.acquire_run_lease("run-1", "owner-b", lease_seconds=90, now=later)

    assert first["acquired"] is True
    assert first["status"] == "ACQUIRED"
    assert second["acquired"] is True
    assert second["status"] == "REACQUIRED"
    assert second["previous_owner"] == "owner-a"
    # The takeover must actually land in storage, not just in the return value.
    assert leases["run-1"]["owner_id"] == "owner-b"
|
|
|
|
|
|
def test_heartbeat_prevents_takeover(monkeypatch):
    """A heartbeat extends the lease far enough to keep a challenger out.

    The challenger arrives after the ORIGINAL expiry (``start + 90s``) and
    before the heartbeat-extended expiry. That expiry is assumed to be
    ``heartbeat_time + lease_seconds = start + 120s``. Only a working
    heartbeat can make the challenger be denied. Challenging at
    ``start + 60s``, as before, proved nothing, because the un-renewed lease
    was still live then.
    """
    leases = _patch_lease_storage(monkeypatch)

    start = datetime(2026, 4, 8, 4, 0, tzinfo=timezone.utc)
    heartbeat_time = start + timedelta(seconds=30)
    # Past start + 90s (original expiry), before start + 120s (renewed expiry).
    challenger_time = start + timedelta(seconds=100)

    engine_db.acquire_run_lease("run-1", "owner-a", lease_seconds=90, now=start)
    heartbeat = engine_db.heartbeat_run_lease("run-1", "owner-a", lease_seconds=90, now=heartbeat_time)
    challenger = engine_db.acquire_run_lease("run-1", "owner-b", lease_seconds=90, now=challenger_time)

    assert heartbeat["active"] is True
    assert challenger["acquired"] is False
    assert challenger["status"] == "DENIED"
    # The original owner must still hold the stored lease.
    assert leases["run-1"]["owner_id"] == "owner-a"
|
|
|
|
|
|
def test_runner_exits_cleanly_when_lease_is_lost(monkeypatch):
    """When the lease refresh reports loss, the engine loop shuts itself down.

    The test never sets ``stop_event``, so the loop has to exit on its own
    after ``_refresh_run_lease_or_stop`` returns ``False``. The loop runs in a
    daemon worker with a deadline, so a regression (the loop ignoring the lost
    lease) fails this test instead of hanging the whole suite. After a timeout
    ``stop_event`` is set only so the worker can wind down.
    """
    statuses: list[str] = []
    cleared: list[tuple[str, str]] = []
    released: list[tuple[str, str]] = []
    refreshed = {"count": 0}

    monkeypatch.setattr(runner, "get_context", lambda user_id=None, run_id=None: ("user-1", "run-1"))
    monkeypatch.setattr(runner, "set_context", lambda user_id, run_id: None)
    monkeypatch.setattr(runner, "log_event", lambda *args, **kwargs: None)
    monkeypatch.setattr(runner, "PaperBroker", lambda initial_cash=0: object())
    monkeypatch.setattr(runner, "_set_state", lambda *args, **kwargs: None)
    monkeypatch.setattr(runner, "_clear_runner", lambda user_id, run_id: cleared.append((user_id, run_id)))
    monkeypatch.setattr(runner, "_update_engine_status", lambda user_id, run_id, status: statuses.append(status))
    monkeypatch.setattr(runner, "release_run_lease", lambda run_id, owner_id: released.append((run_id, owner_id)) or True)
    monkeypatch.setattr(runner, "_log_runner_lease_event", lambda *args, **kwargs: None)

    def fake_refresh(user_id, run_id, owner_id):
        refreshed["count"] += 1
        return False

    monkeypatch.setattr(runner, "_refresh_run_lease_or_stop", fake_refresh)

    stop_event = threading.Event()
    config = {
        "user_id": "user-1",
        "run_id": "run-1",
        "strategy": "Golden Nifty",
        "sip_amount": 1000,
        "sip_frequency": {"value": 2, "unit": "minutes"},
        "mode": "PAPER",
        "broker": "paper",
        "runner_owner_id": "owner-1",
    }
    errors: list[BaseException] = []
    deadline_seconds = 30

    def run_loop():
        try:
            runner._engine_loop(config, stop_event)
        except BaseException as exc:  # includes SystemExit, which a thread would swallow
            errors.append(exc)

    worker = threading.Thread(target=run_loop, daemon=True)
    worker.start()
    worker.join(timeout=deadline_seconds)
    if worker.is_alive():
        # Ask the loop to stop so the daemon thread does not outlive the test.
        stop_event.set()
        worker.join(timeout=5)
        raise AssertionError("_engine_loop did not exit after the run lease was lost")
    if errors:
        raise errors[0]

    assert refreshed["count"] >= 1
    assert statuses == ["RUNNING"]
    assert cleared == [("user-1", "run-1")]
    assert released == [("run-1", "owner-1")]
|