From 28ec6c9a4d19a0af1396e89c398b9cb523a29bef Mon Sep 17 00:00:00 2001 From: Thigazhezhilan J Date: Sun, 5 Apr 2026 19:42:08 +0530 Subject: [PATCH] Add Groww live broker integration --- backend/app/routers/broker.py | 454 +++++++++++++++++- backend/app/services/groww_service.py | 355 ++++++++++++++ backend/app/services/groww_storage.py | 30 ++ backend/app/services/live_equity_service.py | 117 ++++- backend/app/services/strategy_service.py | 56 ++- backend/app/services/system_service.py | 29 +- .../engine/broker.py | 388 +++++++++++++++ .../engine/runner.py | 9 +- 8 files changed, 1394 insertions(+), 44 deletions(-) create mode 100644 backend/app/services/groww_service.py create mode 100644 backend/app/services/groww_storage.py diff --git a/backend/app/routers/broker.py b/backend/app/routers/broker.py index b41a4f4..e9d757d 100644 --- a/backend/app/routers/broker.py +++ b/backend/app/routers/broker.py @@ -1,10 +1,12 @@ import os +from datetime import datetime, timedelta -from fastapi import APIRouter, HTTPException, Request +from fastapi import APIRouter, HTTPException, Query, Request from fastapi.responses import RedirectResponse from app.broker_store import ( clear_user_broker, + expire_user_broker_session, get_broker_credentials, get_pending_broker, get_user_broker, @@ -13,9 +15,33 @@ from app.broker_store import ( set_pending_broker, ) from app.services.auth_service import get_user_for_session -from app.services.zerodha_service import build_login_url, exchange_request_token from app.services.email_service import send_email_async -from app.services.zerodha_storage import set_session +from app.services.groww_service import ( + GrowwApiError, + GrowwTokenError, + fetch_funds as fetch_groww_funds, + fetch_holdings as fetch_groww_holdings, + fetch_ltp as fetch_groww_ltp, + fetch_profile as fetch_groww_profile, + generate_access_token, + normalize_holding as normalize_groww_holding, +) +from app.services.groww_storage import get_session as get_groww_session +from app.services.live_equity_service import capture_live_equity_snapshot, get_live_equity_curve +from app.services.zerodha_service import ( + KiteApiError, + KiteTokenError, + build_login_url, + exchange_request_token, + fetch_funds as fetch_zerodha_funds, + fetch_holdings as fetch_zerodha_holdings, + normalize_holding as normalize_zerodha_holding, +) +from app.services.zerodha_storage import ( + clear_session as clear_zerodha_session, + get_session as get_zerodha_session, + set_session as set_zerodha_session, +) router = APIRouter(prefix="/api/broker") @@ -30,14 +56,226 @@ def _require_user(request: Request): return user +def _first_number(*values, default: float = 0.0) -> float: + for value in values: + try: + if value is None or value == "": + continue + return float(value) + except (TypeError, ValueError): + continue + return float(default) + + +def _first_text(*values, default: str = "") -> str: + for value in values: + if value is None: + continue + text = str(value).strip() + if text: + return text + return default + + +def _clear_zerodha_broker_session(user_id: str): + expire_user_broker_session(user_id) + clear_zerodha_session(user_id) + + +def _raise_zerodha_error(user_id: str, exc: KiteApiError): + if isinstance(exc, KiteTokenError): + _clear_zerodha_broker_session(user_id) + raise HTTPException( + status_code=401, + detail="Zerodha session expired. Please reconnect.", + ) from exc + raise HTTPException(status_code=502, detail=str(exc)) from exc + + +def _raise_groww_error(user_id: str, exc: GrowwApiError): + if isinstance(exc, GrowwTokenError): + expire_user_broker_session(user_id) + raise HTTPException( + status_code=401, + detail="Groww session expired. Please reconnect.", + ) from exc + raise HTTPException(status_code=502, detail=str(exc)) from exc + + +def _resolve_connected_broker(user_id: str): + entry = get_user_broker(user_id) or {} + broker_name = (entry.get("broker") or "").strip().upper() + if not entry.get("connected") or not broker_name: + raise HTTPException(status_code=400, detail="Broker is not connected") + return entry, broker_name + + +def _groww_access_token(payload: dict | None) -> str: + entry = payload or {} + return _first_text( + entry.get("access_token"), + entry.get("accessToken"), + entry.get("token"), + entry.get("jwt_token"), + entry.get("jwtToken"), + default="", + ) + + +def _groww_user_name(profile: dict | None) -> str | None: + value = _first_text( + (profile or {}).get("user_name"), + (profile or {}).get("full_name"), + (profile or {}).get("name"), + (profile or {}).get("display_name"), + default="", + ) + return value or None + + +def _groww_user_id(profile: dict | None) -> str | None: + value = _first_text( + (profile or {}).get("user_id"), + (profile or {}).get("client_id"), + (profile or {}).get("customer_id"), + (profile or {}).get("account_id"), + default="", + ) + return value or None + + +def _groww_holding_tradingsymbol(item: dict | None) -> str: + return _first_text( + (item or {}).get("tradingsymbol"), + (item or {}).get("trading_symbol"), + (item or {}).get("symbol"), + (item or {}).get("instrument_name"), + default="", + ) + + +def _groww_holding_exchange(item: dict | None) -> str: + exchange = _first_text( + (item or {}).get("exchange"), + (item or {}).get("exchange_segment"), + (item or {}).get("exchange_name"), + default="NSE", + ).upper() + if exchange in {"NSE", "BSE"}: + return exchange + if "BSE" in exchange: + return "BSE" + return "NSE" + + +def _groww_holding_segment(item: dict | None) -> str: + segment = _first_text( + (item or {}).get("segment"), + (item or {}).get("product_segment"), + default="CASH", + ).upper() + return segment or "CASH" + + +def _fetch_normalized_groww_holdings(access_token: str) -> list[dict]: + items = fetch_groww_holdings(access_token) + holdings: list[dict] = [] + for item in items: + normalized = normalize_groww_holding(item) + tradingsymbol = _groww_holding_tradingsymbol(normalized) + exchange = _groww_holding_exchange(normalized) + segment = _groww_holding_segment(normalized) + if tradingsymbol and not normalized.get("last_price"): + try: + ltp_data = fetch_groww_ltp( + access_token, + exchange=exchange, + segment=segment, + trading_symbol=tradingsymbol, + ) + normalized["last_price"] = _first_number( + ltp_data.get("ltp"), + ltp_data.get("last_price"), + ltp_data.get("price"), + normalized.get("last_price"), + default=0.0, + ) + normalized["close_price"] = normalized["last_price"] + normalized["holding_value"] = normalized.get("effective_quantity", 0) * normalized["last_price"] + normalized["display_pnl"] = normalized.get("effective_quantity", 0) * ( + normalized["last_price"] - normalized.get("average_price", 0) + ) + except GrowwApiError: + pass + holdings.append(normalized) + return holdings + + +def _normalize_groww_funds(data: dict | None) -> dict: + payload = data if isinstance(data, dict) else {} + available = payload.get("available") if isinstance(payload.get("available"), dict) else {} + equity = payload.get("equity") if isinstance(payload.get("equity"), dict) else {} + equity_available = equity.get("available") if isinstance(equity.get("available"), dict) else {} + + cash = _first_number( + payload.get("cash"), + payload.get("available_cash"), + payload.get("available_balance"), + available.get("cash"), + available.get("available_cash"), + available.get("balance"), + equity.get("cash"), + equity_available.get("cash"), + equity_available.get("live_balance"), + ) + net = _first_number( + payload.get("net"), + payload.get("total"), + payload.get("margin_available"), + equity.get("net"), + cash, + ) + withdrawable = _first_number( + payload.get("withdrawable"), + payload.get("available_to_withdraw"), + available.get("withdrawable"), + cash, + ) + balance = _first_number( + payload.get("balance"), + payload.get("available_balance"), + available.get("balance"), + cash, + ) + + return { + "net": net, + "cash": cash, + "withdrawable": withdrawable, + "balance": balance, + "available": { + "live_balance": cash, + "cash": cash, + "opening_balance": balance, + }, + "raw": payload, + } + + def _build_saved_broker_login_url( request: Request, user_id: str, redirect_url_override: str | None = None, ) -> str: + entry = get_user_broker(user_id) or {} + broker_name = (entry.get("broker") or "").strip().upper() + if broker_name and broker_name != "ZERODHA": + raise HTTPException(status_code=400, detail="Saved login is only available for Zerodha") + creds = get_broker_credentials(user_id) if not creds: raise HTTPException(status_code=400, detail="Broker credentials not configured") + redirect_url = (redirect_url_override or os.getenv("ZERODHA_REDIRECT_URL") or "").strip() if not redirect_url: base = str(request.base_url).rstrip("/") @@ -45,6 +283,18 @@ def _build_saved_broker_login_url( return build_login_url(creds["api_key"], redirect_url=redirect_url) +def _notify_broker_connected(username: str, broker: str, broker_user_id: str | None): + try: + body = ( + "Your broker has been connected to Quantfortune.\n\n" + f"Broker: {broker}\n" + f"Broker User ID: {broker_user_id or 'N/A'}\n" + ) + send_email_async(username, "Broker connected", body) + except Exception: + pass + + @router.post("/connect") async def connect_broker(payload: dict, request: Request): user = _require_user(request) @@ -62,15 +312,7 @@ async def connect_broker(payload: dict, request: Request): user_name=user_name or None, broker_user_id=broker_user_id or None, ) - try: - body = ( - "Your broker has been connected to Quantfortune.\n\n" - f"Broker: {broker}\n" - f"Broker User ID: {broker_user_id or 'N/A'}\n" - ) - send_email_async(user["username"], "Broker connected", body) - except Exception: - pass + _notify_broker_connected(user["username"], broker, broker_user_id or None) return {"connected": True} @@ -94,6 +336,7 @@ async def broker_status(request: Request): async def disconnect_broker(request: Request): user = _require_user(request) clear_user_broker(user["id"]) + clear_zerodha_session(user["id"]) set_broker_auth_state(user["id"], "DISCONNECTED") try: body = "Your broker connection has been disconnected from Quantfortune." @@ -116,6 +359,84 @@ async def zerodha_login(payload: dict, request: Request): return {"loginUrl": build_login_url(api_key, redirect_url=redirect_url or None)} +@router.post("/groww/connect") +async def groww_connect(payload: dict, request: Request): + user = _require_user(request) + api_key = (payload.get("apiKey") or "").strip() + api_secret = (payload.get("apiSecret") or "").strip() + if not api_key or not api_secret: + raise HTTPException(status_code=400, detail="API key and secret are required") + + try: + token_payload = generate_access_token(api_key, api_secret) + access_token = _groww_access_token(token_payload) + if not access_token: + raise HTTPException(status_code=502, detail="Groww did not return an access token") + profile = fetch_groww_profile(access_token) + except GrowwApiError as exc: + _raise_groww_error(user["id"], exc) + + user_name = _groww_user_name(profile) + broker_user_id = _groww_user_id(profile) + set_connected_broker( + user["id"], + "GROWW", + access_token, + api_key=api_key, + api_secret=api_secret, + user_name=user_name, + broker_user_id=broker_user_id, + auth_state="VALID", + ) + _notify_broker_connected(user["username"], "GROWW", broker_user_id) + return { + "connected": True, + "broker": "GROWW", + "userName": user_name, + "brokerUserId": broker_user_id, + } + + +@router.post("/groww/reconnect") +async def groww_reconnect(request: Request): + user = _require_user(request) + entry = get_user_broker(user["id"]) or {} + if (entry.get("broker") or "").strip().upper() not in {"", "GROWW"}: + raise HTTPException(status_code=400, detail="Current broker is not Groww") + + creds = get_broker_credentials(user["id"]) + if not creds: + raise HTTPException(status_code=400, detail="Broker credentials not configured") + + try: + token_payload = generate_access_token(creds["api_key"], creds["api_secret"]) + access_token = _groww_access_token(token_payload) + if not access_token: + raise HTTPException(status_code=502, detail="Groww did not return an access token") + profile = fetch_groww_profile(access_token) + except GrowwApiError as exc: + _raise_groww_error(user["id"], exc) + + user_name = _groww_user_name(profile) or entry.get("user_name") + broker_user_id = _groww_user_id(profile) or entry.get("broker_user_id") + set_connected_broker( + user["id"], + "GROWW", + access_token, + api_key=creds["api_key"], + api_secret=creds["api_secret"], + user_name=user_name, + broker_user_id=broker_user_id, + auth_state="VALID", + ) + return { + "connected": True, + "broker": "GROWW", + "userName": user_name, + "brokerUserId": broker_user_id, + } + + @router.get("/zerodha/callback") async def zerodha_callback(request: Request, request_token: str = ""): user = _require_user(request) @@ -138,7 +459,7 @@ async def zerodha_callback(request: Request, request_token: str = ""): if not access_token: raise HTTPException(status_code=400, detail="Missing access token from Zerodha") - saved = set_session( + saved = set_zerodha_session( user["id"], { "api_key": api_key, @@ -205,7 +526,7 @@ async def broker_callback(request: Request, request_token: str = ""): if not access_token: raise HTTPException(status_code=400, detail="Missing access token from Zerodha") - set_session( + set_zerodha_session( user["id"], { "api_key": creds["api_key"], @@ -227,3 +548,108 @@ async def broker_callback(request: Request, request_token: str = ""): ) target_url = os.getenv("BROKER_DASHBOARD_URL") or "/dashboard?armed=false" return RedirectResponse(target_url) + + +@router.get("/holdings") +async def broker_holdings(request: Request): + user = _require_user(request) + _entry, broker_name = _resolve_connected_broker(user["id"]) + if broker_name == "ZERODHA": + session = get_zerodha_session(user["id"]) + if not session: + raise HTTPException(status_code=400, detail="Zerodha is not connected") + try: + data = fetch_zerodha_holdings(session["api_key"], session["access_token"]) + except KiteApiError as exc: + _raise_zerodha_error(user["id"], exc) + return {"broker": broker_name, "holdings": [normalize_zerodha_holding(item) for item in data]} + + if broker_name == "GROWW": + session = get_groww_session(user["id"]) + if not session or not session.get("access_token"): + raise HTTPException(status_code=400, detail="Groww is not connected") + try: + holdings = _fetch_normalized_groww_holdings(session["access_token"]) + except GrowwApiError as exc: + _raise_groww_error(user["id"], exc) + return {"broker": broker_name, "holdings": holdings} + + raise HTTPException(status_code=400, detail=f"Unsupported broker: {broker_name}") + + +@router.get("/funds") +async def broker_funds(request: Request): + user = _require_user(request) + _entry, broker_name = _resolve_connected_broker(user["id"]) + if broker_name == "ZERODHA": + session = get_zerodha_session(user["id"]) + if not session: + raise HTTPException(status_code=400, detail="Zerodha is not connected") + try: + data = fetch_zerodha_funds(session["api_key"], session["access_token"]) + except KiteApiError as exc: + _raise_zerodha_error(user["id"], exc) + equity = data.get("equity", {}) if isinstance(data, dict) else {} + return {"broker": broker_name, "funds": {**equity, "raw": data}} + + if broker_name == "GROWW": + session = get_groww_session(user["id"]) + if not session or not session.get("access_token"): + raise HTTPException(status_code=400, detail="Groww is not connected") + try: + data = fetch_groww_funds(session["access_token"]) + except GrowwApiError as exc: + _raise_groww_error(user["id"], exc) + return {"broker": broker_name, "funds": _normalize_groww_funds(data)} + + raise HTTPException(status_code=400, detail=f"Unsupported broker: {broker_name}") + + +@router.get("/equity-curve") +async def broker_equity_curve(request: Request, from_: str = Query("", alias="from")): + user = _require_user(request) + _entry, broker_name = _resolve_connected_broker(user["id"]) + + if broker_name == "ZERODHA": + session = get_zerodha_session(user["id"]) + if not session: + raise HTTPException(status_code=400, detail="Zerodha is not connected") + try: + holdings = [ + normalize_zerodha_holding(item) + for item in fetch_zerodha_holdings(session["api_key"], session["access_token"]) + ] + raw_funds = fetch_zerodha_funds(session["api_key"], session["access_token"]) + funds_data = {**(raw_funds.get("equity", {}) or {}), "raw": raw_funds} + except KiteApiError as exc: + _raise_zerodha_error(user["id"], exc) + elif broker_name == "GROWW": + session = get_groww_session(user["id"]) + if not session or not session.get("access_token"): + raise HTTPException(status_code=400, detail="Groww is not connected") + try: + holdings = _fetch_normalized_groww_holdings(session["access_token"]) + funds_data = _normalize_groww_funds(fetch_groww_funds(session["access_token"])) + except GrowwApiError as exc: + _raise_groww_error(user["id"], exc) + else: + raise HTTPException(status_code=400, detail=f"Unsupported broker: {broker_name}") + + capture_live_equity_snapshot( + user["id"], + holdings=holdings, + funds_data=funds_data, + ) + + now = datetime.utcnow() + default_start = (now - timedelta(days=90)).date() + if from_: + try: + start_date = datetime.fromisoformat(from_).date() + except ValueError: + start_date = default_start + else: + start_date = default_start + if start_date > now.date(): + start_date = now.date() + return get_live_equity_curve(user["id"], start_date=start_date) diff --git a/backend/app/services/groww_service.py b/backend/app/services/groww_service.py new file mode 100644 index 0000000..efa05b2 --- /dev/null +++ b/backend/app/services/groww_service.py @@ -0,0 +1,355 @@ +import hashlib +import json +import os +import time +import urllib.error +import urllib.parse +import urllib.request + + +GROWW_API_BASE = os.getenv("GROWW_API_BASE", "https://api.groww.in").rstrip("/") +GROWW_API_VERSION = os.getenv("GROWW_API_VERSION", "1.0") + + +class GrowwApiError(Exception): + def __init__(self, status_code: int, error_type: str, message: str): + super().__init__(f"Groww API error {status_code}: {error_type} - {message}") + self.status_code = status_code + self.error_type = error_type + self.message = message + + +class GrowwTokenError(GrowwApiError): + pass + + +class GrowwPermissionError(GrowwApiError): + pass + + +def _json_headers(extra: dict | None = None) -> dict: + headers = { + "Accept": "application/json", + "Content-Type": "application/json", + "X-API-VERSION": GROWW_API_VERSION, + } + if extra: + headers.update(extra) + return headers + + +def _request( + method: str, + url: str, + *, + data: dict | None = None, + headers: dict | None = None, +): + payload = None + if data is not None: + payload = json.dumps(data).encode("utf-8") + + req = urllib.request.Request( + url, + data=payload, + headers=headers or {}, + method=method, + ) + try: + with urllib.request.urlopen(req, timeout=20) as resp: + body = resp.read().decode("utf-8") + except urllib.error.HTTPError as err: + error_body = err.read().decode("utf-8") if err.fp else "" + try: + parsed = json.loads(error_body) if error_body else {} + except json.JSONDecodeError: + parsed = {} + + error = parsed.get("error") if isinstance(parsed.get("error"), dict) else {} + error_type = ( + error.get("code") + or parsed.get("error_code") + or parsed.get("error_type") + or parsed.get("status") + or "unknown_error" + ) + message = ( + error.get("message") + or parsed.get("message") + or parsed.get("detail") + or error_body + or err.reason + ) + + normalized_error = str(error_type).strip().lower() + exc_cls = GrowwApiError + if err.code in {401, 403} or "token" in normalized_error or "auth" in normalized_error: + exc_cls = GrowwTokenError + elif "permission" in normalized_error: + exc_cls = GrowwPermissionError + raise exc_cls(err.code, str(error_type), str(message)) from err + + if not body: + return {} + return json.loads(body) + + +def _first_data(payload: dict | None): + if not isinstance(payload, dict): + return payload + data = payload.get("data") + return data if data is not None else payload + + +def _auth_headers(access_token: str) -> dict: + return _json_headers({"Authorization": f"Bearer {access_token}"}) + + +def _api_key_headers(api_key: str) -> dict: + return _json_headers({"Authorization": f"Bearer {api_key}"}) + + +def _single_query_url(path: str, **params) -> str: + query = urllib.parse.urlencode( + [(key, value) for key, value in params.items() if value is not None and value != ""] + ) + if query: + return f"{GROWW_API_BASE}{path}?{query}" + return f"{GROWW_API_BASE}{path}" + + +def generate_access_token(api_key: str, api_secret: str) -> dict: + timestamp = str(int(time.time())) + checksum = hashlib.sha256(f"{api_secret}{timestamp}".encode("utf-8")).hexdigest() + response = _request( + "POST", + f"{GROWW_API_BASE}/v1/token/api/access", + data={ + "key_type": "approval", + "checksum": checksum, + "timestamp": timestamp, + }, + headers=_api_key_headers(api_key), + ) + return _first_data(response) or {} + + +def fetch_profile(access_token: str) -> dict: + response = _request( + "GET", + f"{GROWW_API_BASE}/v1/user/detail", + headers=_auth_headers(access_token), + ) + return _first_data(response) or {} + + +def fetch_holdings(access_token: str) -> list: + response = _request( + "GET", + f"{GROWW_API_BASE}/v1/holdings/user", + headers=_auth_headers(access_token), + ) + data = _first_data(response) + if isinstance(data, list): + return data + if isinstance(data, dict): + for key in ("holdings", "items", "records"): + if isinstance(data.get(key), list): + return data[key] + return [] + + +def fetch_positions(access_token: str) -> list: + response = _request( + "GET", + f"{GROWW_API_BASE}/v1/positions/user", + headers=_auth_headers(access_token), + ) + data = _first_data(response) + if isinstance(data, list): + return data + if isinstance(data, dict): + for key in ("positions", "items", "records"): + if isinstance(data.get(key), list): + return data[key] + return [] + + +def fetch_funds(access_token: str) -> dict: + response = _request( + "GET", + f"{GROWW_API_BASE}/v1/margins/detail/user", + headers=_auth_headers(access_token), + ) + return _first_data(response) or {} + + +def fetch_ltp(access_token: str, *, exchange: str, segment: str, trading_symbol: str) -> dict: + url = _single_query_url( + "/v1/live-data/ltp", + exchange=exchange, + segment=segment, + trading_symbol=trading_symbol, + ) + response = _request("GET", url, headers=_auth_headers(access_token)) + return _first_data(response) or {} + + +def place_order( + access_token: str, + *, + trading_symbol: str, + exchange: str, + segment: str, + transaction_type: str, + order_type: str, + quantity: int, + product: str, + validity: str = "DAY", + price: float | None = None, + trigger_price: float | None = None, + order_reference_id: str | None = None, +) -> dict: + payload = { + "trading_symbol": trading_symbol, + "quantity": int(quantity), + "validity": validity, + "exchange": exchange, + "segment": segment, + "product": product, + "order_type": order_type, + "transaction_type": transaction_type, + } + if price is not None: + payload["price"] = float(price) + if trigger_price is not None: + payload["trigger_price"] = float(trigger_price) + if order_reference_id: + payload["order_reference_id"] = order_reference_id + + response = _request( + "POST", + f"{GROWW_API_BASE}/v1/order/create", + data=payload, + headers=_auth_headers(access_token), + ) + return _first_data(response) or {} + + +def fetch_order_status(access_token: str, groww_order_id: str, *, segment: str = "CASH") -> dict: + url = _single_query_url( + f"/v1/order/status/{urllib.parse.quote(str(groww_order_id).strip())}", + segment=segment, + ) + response = _request("GET", url, headers=_auth_headers(access_token)) + return _first_data(response) or {} + + +def fetch_order_detail(access_token: str, groww_order_id: str, *, segment: str = "CASH") -> dict: + url = _single_query_url( + f"/v1/order/detail/{urllib.parse.quote(str(groww_order_id).strip())}", + segment=segment, + ) + response = _request("GET", url, headers=_auth_headers(access_token)) + return _first_data(response) or {} + + +def fetch_orders(access_token: str, *, segment: str = "CASH") -> list: + url = _single_query_url("/v1/order/list", segment=segment) + response = _request("GET", url, headers=_auth_headers(access_token)) + data = _first_data(response) + if isinstance(data, list): + return data + if isinstance(data, dict): + for key in ("orders", "items", "records"): + if isinstance(data.get(key), list): + return data[key] + return [] + + +def _first_float(*values, default: float = 0.0) -> float: + for value in values: + try: + if value is None or value == "": + continue + return float(value) + except (TypeError, ValueError): + continue + return float(default) + + +def _first_text(*values, default: str = "") -> str: + for value in values: + if value is None: + continue + text = str(value).strip() + if text: + return text + return default + + +def holding_quantity(item: dict | None) -> float: + entry = item or {} + return _first_float( + entry.get("quantity"), + entry.get("available_quantity"), + entry.get("net_quantity"), + default=0.0, + ) + + +def holding_average_price(item: dict | None) -> float: + entry = item or {} + return _first_float(entry.get("average_price"), entry.get("avg_price"), default=0.0) + + +def holding_last_price(item: dict | None) -> float: + entry = item or {} + return _first_float( + entry.get("last_price"), + entry.get("ltp"), + entry.get("close_price"), + entry.get("average_price"), + default=0.0, + ) + + +def normalize_holding(item: dict | None) -> dict: + entry = dict(item or {}) + quantity = holding_quantity(entry) + average_price = holding_average_price(entry) + last_price = holding_last_price(entry) + tradingsymbol = _first_text( + entry.get("trading_symbol"), + entry.get("tradingsymbol"), + entry.get("symbol"), + entry.get("instrument_name"), + default="", + ) + exchange = _first_text( + entry.get("exchange"), + entry.get("exchange_segment"), + entry.get("exchange_name"), + default="NSE", + ).upper() + segment = _first_text(entry.get("segment"), entry.get("product_segment"), default="CASH").upper() + symbol = tradingsymbol + if tradingsymbol and not tradingsymbol.endswith((".NS", ".BO")): + if exchange == "NSE": + symbol = f"{tradingsymbol}.NS" + elif exchange == "BSE": + symbol = f"{tradingsymbol}.BO" + entry["settled_quantity"] = quantity + entry["t1_quantity"] = 0.0 + entry["effective_quantity"] = quantity + entry["quantity"] = quantity + entry["average_price"] = average_price + entry["last_price"] = last_price + entry["close_price"] = last_price + entry["exchange"] = exchange + entry["segment"] = segment + entry["tradingsymbol"] = tradingsymbol + entry["symbol"] = symbol + entry["display_pnl"] = quantity * (last_price - average_price) + entry["holding_value"] = quantity * last_price + return entry diff --git a/backend/app/services/groww_storage.py b/backend/app/services/groww_storage.py new file mode 100644 index 0000000..ec76968 --- /dev/null +++ b/backend/app/services/groww_storage.py @@ -0,0 +1,30 @@ +from app.services.crypto_service import decrypt_value +from app.services.db import db_transaction + + +def get_session(user_id: str): + with db_transaction() as cur: + cur.execute( + """ + SELECT broker, connected, access_token, api_key, user_name, broker_user_id, connected_at + FROM user_broker + WHERE user_id = %s + LIMIT 1 + """, + (user_id,), + ) + row = cur.fetchone() + if not row: + return None + broker, connected, access_token, api_key, user_name, broker_user_id, connected_at = row + if not connected or not access_token: + return None + if (broker or "").strip().upper() != "GROWW": + return None + return { + "api_key": api_key, + "access_token": decrypt_value(access_token), + "user_name": user_name, + "broker_user_id": broker_user_id, + "linked_at": connected_at, + } diff --git a/backend/app/services/live_equity_service.py b/backend/app/services/live_equity_service.py index 7886727..2cff204 100644 --- a/backend/app/services/live_equity_service.py +++ b/backend/app/services/live_equity_service.py @@ -5,15 +5,24 @@ from datetime import date, datetime, timedelta, timezone from decimal import Decimal from zoneinfo import ZoneInfo +from app.broker_store import get_user_broker from app.services.db import db_connection +from app.services.groww_service import ( + GrowwApiError, + fetch_funds as fetch_groww_funds, + fetch_holdings as fetch_groww_holdings, + normalize_holding as normalize_groww_holding, +) +from app.services.groww_storage import get_session as get_groww_session from app.services.zerodha_service import ( KiteApiError, - fetch_funds, - fetch_holdings, + fetch_funds as fetch_zerodha_funds, + fetch_holdings as fetch_zerodha_holdings, holding_effective_quantity, holding_last_price, + normalize_holding as normalize_zerodha_holding, ) -from app.services.zerodha_storage import get_session +from app.services.zerodha_storage import get_session as get_zerodha_session IST = ZoneInfo("Asia/Calcutta") AUTO_SNAPSHOT_AFTER_HOUR = int(os.getenv("LIVE_EQUITY_SNAPSHOT_HOUR", "15")) @@ -72,6 +81,57 @@ def _extract_holdings_value(holdings: list[dict] | None) -> float: return total +def _normalize_groww_funds(data: dict | None) -> dict: + payload = data if isinstance(data, dict) else {} + available = payload.get("available") if isinstance(payload.get("available"), dict) else {} + equity = payload.get("equity") if isinstance(payload.get("equity"), dict) else {} + equity_available = equity.get("available") if isinstance(equity.get("available"), dict) else {} + + cash = _first_numeric( + payload.get("cash"), + payload.get("available_cash"), + payload.get("available_balance"), + available.get("cash"), + available.get("available_cash"), + available.get("balance"), + equity.get("cash"), + equity_available.get("cash"), + equity_available.get("live_balance"), + ) + net = _first_numeric( + payload.get("net"), + payload.get("total"), + payload.get("margin_available"), + equity.get("net"), + cash, + ) + withdrawable = _first_numeric( + payload.get("withdrawable"), + payload.get("available_to_withdraw"), + available.get("withdrawable"), + cash, + ) + balance = _first_numeric( + payload.get("balance"), + payload.get("available_balance"), + available.get("balance"), + cash, + ) + + return { + "net": net, + "cash": cash, + "withdrawable": withdrawable, + "balance": balance, + "available": { + "live_balance": cash, + "cash": cash, + "opening_balance": balance, + }, + "raw": payload, + } + + def _upsert_snapshot( *, user_id: str, @@ -126,15 +186,44 @@ def capture_live_equity_snapshot( funds_data: dict | None = None, captured_at: datetime | None = None, ): - session = get_session(user_id) - if not session: - return None + broker_state = get_user_broker(user_id) or {} + broker_name = (broker_state.get("broker") or "").strip().upper() captured_at = captured_at or _now_utc() if holdings is None: - holdings = fetch_holdings(session["api_key"], session["access_token"]) + if broker_name == "ZERODHA": + session = get_zerodha_session(user_id) + if not session: + return None + holdings = [ + normalize_zerodha_holding(item) + for item in fetch_zerodha_holdings(session["api_key"], session["access_token"]) + ] + elif broker_name == "GROWW": + session = get_groww_session(user_id) + if not session: + return None + holdings = [ + normalize_groww_holding(item) + for item in fetch_groww_holdings(session["access_token"]) + ] + else: + return None if funds_data is None: - funds_data = fetch_funds(session["api_key"], session["access_token"]) + if broker_name == "ZERODHA": + session = get_zerodha_session(user_id) + if not session: + return None + raw_funds = fetch_zerodha_funds(session["api_key"], session["access_token"]) + equity = raw_funds.get("equity", {}) if isinstance(raw_funds, dict) else {} + funds_data = {**equity, "raw": raw_funds} + elif broker_name == "GROWW": + session = get_groww_session(user_id) + if not session: + return None + funds_data = _normalize_groww_funds(fetch_groww_funds(session["access_token"])) + else: + return None cash_value = _extract_cash_value(funds_data) holdings_value = _extract_holdings_value(holdings) @@ -187,18 +276,18 @@ def get_live_equity_curve(user_id: str, *, start_date: date | None = None): } -def _list_connected_zerodha_users() -> list[str]: +def _list_connected_live_brokers() -> list[tuple[str, str]]: with db_connection() as conn: with conn.cursor() as cur: cur.execute( """ - SELECT user_id + SELECT user_id, UPPER(COALESCE(broker, '')) FROM user_broker WHERE connected = TRUE - AND UPPER(COALESCE(broker, '')) = 'ZERODHA' + AND UPPER(COALESCE(broker, '')) IN ('ZERODHA', 'GROWW') """ ) - return [row[0] for row in cur.fetchall()] + return [(row[0], row[1]) for row in cur.fetchall()] def _should_auto_snapshot(now_local: datetime) -> bool: @@ -222,11 +311,13 @@ def _run_auto_snapshot_cycle(): if not _should_auto_snapshot(now_local): return - for user_id in _list_connected_zerodha_users(): + for user_id, _broker_name in _list_connected_live_brokers(): try: capture_live_equity_snapshot(user_id) except KiteApiError: continue + except GrowwApiError: + continue except Exception: continue diff --git a/backend/app/services/strategy_service.py b/backend/app/services/strategy_service.py index 3f4bb6d..528f121 100644 --- a/backend/app/services/strategy_service.py +++ b/backend/app/services/strategy_service.py @@ -27,11 +27,13 @@ from app.services.run_service import ( ) from app.services.auth_service import get_user_by_id from app.services.email_service import send_email_async +from app.services.groww_service import GrowwApiError, GrowwTokenError, fetch_funds as fetch_groww_funds +from app.services.groww_storage import get_session as get_groww_session from app.services.zerodha_service import ( KiteTokenError, - fetch_funds, + fetch_funds as fetch_zerodha_funds, ) -from app.services.zerodha_storage import get_session +from app.services.zerodha_storage import get_session as get_zerodha_session from psycopg2.extras import Json from psycopg2 import errors @@ -327,13 +329,44 @@ def validate_frequency(freq: dict, mode: str): def _validate_live_broker_session(user_id: str): broker_state = get_user_broker(user_id) or {} broker_name = (broker_state.get("broker") or "").strip().upper() - if not broker_state.get("connected") or broker_name != "ZERODHA": + if not broker_state.get("connected") or broker_name not in {"ZERODHA", "GROWW"}: return False, broker_state, "broker_not_connected" + if broker_name == "ZERODHA": + try: + session = get_zerodha_session(user_id) + except Exception as exc: + print(f"[STRATEGY] failed to load Zerodha session for {user_id}: {exc}", flush=True) + set_broker_auth_state(user_id, "EXPIRED") + return False, broker_state, "broker_auth_required" + + if not session: + set_broker_auth_state(user_id, "EXPIRED") + return False, broker_state, "broker_auth_required" + + api_key = str(session.get("api_key") or "").strip() + access_token = str(session.get("access_token") or "").strip() + if not api_key or not access_token: + set_broker_auth_state(user_id, "EXPIRED") + return False, broker_state, "broker_auth_required" + + try: + fetch_zerodha_funds(api_key, access_token) + except KiteTokenError: + set_broker_auth_state(user_id, "EXPIRED") + return False, broker_state, "broker_auth_required" + except Exception as exc: + print(f"[STRATEGY] failed to validate Zerodha session for {user_id}: {exc}", flush=True) + set_broker_auth_state(user_id, "EXPIRED") + return False, broker_state, "broker_auth_required" + + set_broker_auth_state(user_id, "VALID") + return True, broker_state, "ok" + try: - session = get_session(user_id) + session = get_groww_session(user_id) except Exception as exc: - print(f"[STRATEGY] failed to load Zerodha session for {user_id}: {exc}", flush=True) + print(f"[STRATEGY] failed to load Groww session for {user_id}: {exc}", flush=True) set_broker_auth_state(user_id, "EXPIRED") return False, broker_state, "broker_auth_required" @@ -341,19 +374,22 @@ def _validate_live_broker_session(user_id: str): set_broker_auth_state(user_id, "EXPIRED") return False, broker_state, "broker_auth_required" - api_key = str(session.get("api_key") or "").strip() access_token = str(session.get("access_token") or "").strip() - if not api_key or not access_token: + if not access_token: set_broker_auth_state(user_id, "EXPIRED") return False, broker_state, "broker_auth_required" try: - fetch_funds(api_key, access_token) - except KiteTokenError: + fetch_groww_funds(access_token) + except GrowwTokenError: + set_broker_auth_state(user_id, "EXPIRED") + return False, broker_state, "broker_auth_required" + except GrowwApiError as exc: + print(f"[STRATEGY] failed to validate Groww session for {user_id}: {exc}", flush=True) set_broker_auth_state(user_id, "EXPIRED") return False, broker_state, "broker_auth_required" except Exception as exc: - print(f"[STRATEGY] failed to validate Zerodha session for {user_id}: {exc}", flush=True) + print(f"[STRATEGY] failed to validate Groww session for {user_id}: {exc}", flush=True) set_broker_auth_state(user_id, "EXPIRED") return False, broker_state, "broker_auth_required" diff --git a/backend/app/services/system_service.py b/backend/app/services/system_service.py index 8681e12..5253042 100644 --- a/backend/app/services/system_service.py +++ b/backend/app/services/system_service.py @@ -7,10 +7,12 @@ from psycopg2.extras import Json from app.broker_store import get_user_broker, set_broker_auth_state from app.services.db import db_connection +from app.services.groww_service import GrowwApiError, GrowwTokenError, fetch_funds as fetch_groww_funds +from app.services.groww_storage import get_session as get_groww_session from app.services.run_lifecycle import RunLifecycleError, RunLifecycleManager from app.services.strategy_service import compute_next_eligible, resume_running_runs -from app.services.zerodha_service import KiteTokenError, fetch_funds -from app.services.zerodha_storage import get_session +from app.services.zerodha_service import KiteTokenError, fetch_funds as fetch_zerodha_funds +from app.services.zerodha_storage import get_session as get_zerodha_session def _hash_value(value: str | None) -> str | None: @@ -66,14 +68,29 @@ def _parse_ts(value: str | None): def _validate_broker_session(user_id: str): - session = get_session(user_id) - if not session: + broker_state = get_user_broker(user_id) or {} + broker_name = (broker_state.get("broker") or "").strip().upper() + if broker_name not in {"ZERODHA", "GROWW"}: return False if os.getenv("BROKER_VALIDATION_MODE", "").strip().lower() == "skip": return True + if broker_name == "ZERODHA": + session = get_zerodha_session(user_id) + if not session: + return False + try: + fetch_zerodha_funds(session["api_key"], session["access_token"]) + except KiteTokenError: + set_broker_auth_state(user_id, "EXPIRED") + return False + return True + + session = get_groww_session(user_id) + if not session: + return False try: - fetch_funds(session["api_key"], session["access_token"]) - except KiteTokenError: + fetch_groww_funds(session["access_token"]) + except (GrowwTokenError, GrowwApiError): set_broker_auth_state(user_id, "EXPIRED") return False return True diff --git a/indian_paper_trading_strategy/engine/broker.py b/indian_paper_trading_strategy/engine/broker.py index 21f61e6..32f0f6c 100644 --- a/indian_paper_trading_strategy/engine/broker.py +++ b/indian_paper_trading_strategy/engine/broker.py @@ -462,6 +462,394 @@ class LiveZerodhaBroker(Broker): ) +class LiveGrowwBroker(Broker): + external_orders = True + + FILLED_STATUSES = {"EXECUTED", "DELIVERY_AWAITED", "COMPLETED"} + REJECTED_STATUSES = {"REJECTED", "FAILED"} + CANCELLED_STATUSES = {"CANCELLED", "CANCELLATION_REQUESTED"} + TERMINAL_STATUSES = FILLED_STATUSES | REJECTED_STATUSES | CANCELLED_STATUSES + POLL_TIMEOUT_SECONDS = float(os.getenv("GROWW_ORDER_POLL_TIMEOUT", "15")) + POLL_INTERVAL_SECONDS = float(os.getenv("GROWW_ORDER_POLL_INTERVAL", "1")) + + def __init__(self, user_id: str | None = None, run_id: str | None = None): + self.user_id = user_id + self.run_id = run_id + + def _scope(self): + return _resolve_scope(self.user_id, self.run_id) + + def _session(self): + from app.services.groww_storage import get_session + + user_id, _run_id = self._scope() + session = get_session(user_id) + if not session or not session.get("access_token"): + raise BrokerAuthExpired("Groww session missing. Please reconnect broker.") + return session + + def _raise_auth_expired(self, exc: Exception): + from app.broker_store import expire_user_broker_session + + user_id, _run_id = self._scope() + expire_user_broker_session(user_id) + raise BrokerAuthExpired(str(exc)) from exc + + def _normalize_symbol(self, symbol: str) -> tuple[str, str, str]: + cleaned = (symbol or "").strip().upper() + if cleaned.endswith(".NS"): + return cleaned[:-3], "NSE", "CASH" + if cleaned.endswith(".BO"): + return cleaned[:-3], "BSE", "CASH" + return cleaned, "NSE", "CASH" + + def _make_reference_id(self, logical_time: datetime | None, symbol: str, side: str) -> str: + user_id, run_id = self._scope() + logical_ts = logical_time or datetime.utcnow().replace(tzinfo=timezone.utc) + digest = hashlib.sha1( + f"{user_id}|{run_id}|{_normalize_ts_for_id(logical_ts)}|{symbol}|{side}".encode("utf-8") + ).hexdigest()[:18] + return f"qfg{digest}" + + def _first_text(self, *values, default: str = "") -> str: + for value in values: + if value is None: + continue + text = str(value).strip() + if text: + return text + return default + + def _first_float(self, *values, default: float = 0.0) -> float: + for value in values: + try: + if value is None or value == "": + continue + return float(value) + except (TypeError, ValueError): + continue + return float(default) + + def _extract_order_id(self, payload: dict | None) -> str: + entry = payload or {} + return self._first_text( + entry.get("groww_order_id"), + entry.get("order_id"), + entry.get("id"), + default="", + ) + + def _normalize_order_payload( + self, + *, + order_id: str, + symbol: str, + side: str, + requested_qty: int, + requested_price: float | None, + order_entry: dict | None, + logical_time: datetime | None, + ) -> dict: + entry = order_entry or {} + raw_status = self._first_text( + entry.get("order_status"), + entry.get("status"), + entry.get("state"), + default="", + ).upper() + if raw_status in self.FILLED_STATUSES: + status = "FILLED" + elif raw_status in self.REJECTED_STATUSES: + status = "REJECTED" + elif raw_status in self.CANCELLED_STATUSES: + status = "CANCELLED" + else: + status = "PENDING" + + quantity = int(self._first_float(entry.get("quantity"), requested_qty, default=0)) + filled_qty = int( + self._first_float( + entry.get("filled_quantity"), + entry.get("executed_quantity"), + entry.get("filled_qty"), + default=0, + ) + ) + average_price = self._first_float( + entry.get("average_price"), + entry.get("avg_price"), + entry.get("average_execution_price"), + requested_price, + default=0.0, + ) + price = self._first_float(entry.get("price"), requested_price, average_price, default=0.0) + timestamp = self._first_text( + entry.get("order_timestamp"), + entry.get("timestamp"), + entry.get("updated_at"), + entry.get("created_at"), + default=_format_utc_ts(logical_time or datetime.utcnow().replace(tzinfo=timezone.utc)) or "", + ) + if timestamp and " " in timestamp: + timestamp = timestamp.replace(" ", "T") + + return { + "id": order_id, + "symbol": symbol, + "side": side.upper().strip(), + "qty": quantity, + "requested_qty": quantity, + "filled_qty": filled_qty, + "price": price, + "requested_price": float(requested_price or price or 0.0), + "average_price": average_price, + "status": status, + "timestamp": timestamp, + "broker_order_id": order_id, + "exchange": self._first_text(entry.get("exchange"), default=None), + "tradingsymbol": self._first_text( + entry.get("trading_symbol"), + entry.get("tradingsymbol"), + entry.get("symbol"), + default=None, + ), + "status_message": self._first_text( + entry.get("remark"), + entry.get("status_message"), + entry.get("message"), + entry.get("error_message"), + default=None, + ), + } + + def _wait_for_terminal_order( + self, + session: dict, + order_id: str, + *, + symbol: str, + side: str, + requested_qty: int, + requested_price: float | None, + logical_time: datetime | None, + segment: str, + ) -> dict: + from app.services.groww_service import ( + GrowwApiError, + GrowwTokenError, + fetch_order_detail, + fetch_order_status, + ) + + started = time.monotonic() + last_payload = self._normalize_order_payload( + order_id=order_id, + symbol=symbol, + side=side, + requested_qty=requested_qty, + requested_price=requested_price, + order_entry=None, + logical_time=logical_time, + ) + + while True: + try: + detail = fetch_order_detail(session["access_token"], order_id, segment=segment) + status_payload = fetch_order_status(session["access_token"], order_id, segment=segment) + merged = {} + if isinstance(detail, dict): + merged.update(detail) + if isinstance(status_payload, dict): + merged.update(status_payload) + except GrowwTokenError as exc: + self._raise_auth_expired(exc) + except GrowwApiError as exc: + merged = { + "groww_order_id": order_id, + "order_status": "FAILED", + "remark": getattr(exc, "message", str(exc)), + } + + last_payload = self._normalize_order_payload( + order_id=order_id, + symbol=symbol, + side=side, + requested_qty=requested_qty, + requested_price=requested_price, + order_entry=merged, + logical_time=logical_time, + ) + raw_status = self._first_text( + merged.get("order_status"), + merged.get("status"), + merged.get("state"), + default="", + ).upper() + if raw_status in self.TERMINAL_STATUSES: + return last_payload + + if time.monotonic() - started >= self.POLL_TIMEOUT_SECONDS: + return last_payload + + time.sleep(self.POLL_INTERVAL_SECONDS) + + def get_funds(self, cur=None): + from app.services.groww_service import GrowwTokenError, fetch_funds + + session = self._session() + try: + data = fetch_funds(session["access_token"]) + except GrowwTokenError as exc: + self._raise_auth_expired(exc) + + available = data.get("available") if isinstance(data.get("available"), dict) else {} + equity = data.get("equity") if isinstance(data.get("equity"), dict) else {} + equity_available = equity.get("available") if isinstance(equity.get("available"), dict) else {} + cash = self._first_float( + data.get("cash"), + data.get("available_cash"), + data.get("available_balance"), + available.get("cash"), + available.get("available_cash"), + available.get("balance"), + equity.get("cash"), + equity_available.get("cash"), + equity_available.get("live_balance"), + default=0.0, + ) + return {"cash": float(cash), "raw": data} + + def get_positions(self): + from app.services.groww_service import GrowwTokenError, fetch_holdings, normalize_holding + + session = self._session() + try: + holdings = fetch_holdings(session["access_token"]) + except GrowwTokenError as exc: + self._raise_auth_expired(exc) + + normalized = [] + for item in holdings: + entry = normalize_holding(item) + normalized.append( + { + "symbol": entry.get("symbol"), + "qty": float(entry.get("effective_quantity") or 0.0), + "avg_price": float(entry.get("average_price") or 0.0), + "last_price": float(entry.get("last_price") or 0.0), + } + ) + return normalized + + def get_orders(self): + from app.services.groww_service import GrowwTokenError, fetch_orders + + session = self._session() + try: + return fetch_orders(session["access_token"]) + except GrowwTokenError as exc: + self._raise_auth_expired(exc) + + def place_order( + self, + symbol: str, + side: str, + quantity: float, + price: float | None = None, + cur=None, + logical_time: datetime | None = None, + user_id: str | None = None, + run_id: str | None = None, + ): + from app.services.groww_service import GrowwApiError, GrowwTokenError, place_order + + if user_id is not None: + self.user_id = user_id + if run_id is not None: + self.run_id = run_id + + qty = int(math.floor(float(quantity))) + side = side.upper().strip() + requested_price = float(price) if price is not None else None + if qty <= 0: + return { + "id": _deterministic_id("groww_rej", [symbol, side, _stable_num(quantity)]), + "symbol": symbol, + "side": side, + "qty": qty, + "requested_qty": qty, + "filled_qty": 0, + "price": float(price or 0.0), + "requested_price": float(price or 0.0), + "average_price": 0.0, + "status": "REJECTED", + "timestamp": _format_utc_ts(logical_time or datetime.utcnow().replace(tzinfo=timezone.utc)), + "status_message": "Computed quantity is less than 1 share", + } + + session = self._session() + trading_symbol, exchange, segment = self._normalize_symbol(symbol) + order_reference_id = self._make_reference_id(logical_time, symbol, side) + rejected_timestamp = _format_utc_ts(logical_time or datetime.utcnow().replace(tzinfo=timezone.utc)) + + try: + placed = place_order( + session["access_token"], + trading_symbol=trading_symbol, + exchange=exchange, + segment=segment, + transaction_type=side, + order_type="MARKET", + quantity=qty, + product="CNC", + validity="DAY", + price=requested_price, + order_reference_id=order_reference_id, + ) + except GrowwTokenError as exc: + self._raise_auth_expired(exc) + except GrowwApiError as exc: + return { + "id": _deterministic_id( + "groww_rej", + [ + symbol, + side, + _stable_num(quantity), + _stable_num(requested_price or 0.0), + getattr(exc, "error_type", "groww_error"), + ], + ), + "symbol": symbol, + "side": side, + "qty": qty, + "requested_qty": qty, + "filled_qty": 0, + "price": float(requested_price or 0.0), + "requested_price": float(requested_price or 0.0), + "average_price": 0.0, + "status": "REJECTED", + "timestamp": rejected_timestamp, + "status_message": getattr(exc, "message", str(exc)), + "error_type": getattr(exc, "error_type", None), + } + + order_id = self._extract_order_id(placed) + if not order_id: + raise BrokerError("Groww order placement did not return an order id") + + return self._wait_for_terminal_order( + session, + order_id, + symbol=symbol, + side=side, + requested_qty=qty, + requested_price=requested_price, + logical_time=logical_time, + segment=segment, + ) + + @dataclass class PaperBroker(Broker): initial_cash: float diff --git a/indian_paper_trading_strategy/engine/runner.py b/indian_paper_trading_strategy/engine/runner.py index 40c6784..7410309 100644 --- a/indian_paper_trading_strategy/engine/runner.py +++ b/indian_paper_trading_strategy/engine/runner.py @@ -7,7 +7,12 @@ from psycopg2.extras import Json from indian_paper_trading_strategy.engine.market import is_market_open, align_to_market_open, market_now from indian_paper_trading_strategy.engine.execution import try_execute_sip -from indian_paper_trading_strategy.engine.broker import PaperBroker, LiveZerodhaBroker, BrokerAuthExpired +from indian_paper_trading_strategy.engine.broker import ( + BrokerAuthExpired, + LiveGrowwBroker, + LiveZerodhaBroker, + PaperBroker, +) from indian_paper_trading_strategy.engine.mtm import log_mtm, should_log_mtm from indian_paper_trading_strategy.engine.state import load_state from indian_paper_trading_strategy.engine.data import fetch_live_price @@ -266,6 +271,8 @@ def _engine_loop(config, stop_event: threading.Event): ) elif broker_type == "zerodha": broker = LiveZerodhaBroker(user_id=scope_user, run_id=scope_run) + elif broker_type == "groww": + broker = LiveGrowwBroker(user_id=scope_user, run_id=scope_run) else: raise ValueError(f"Unsupported broker: {broker_type}") market_data_provider = "yfinance"