"""User authentication, session management, and password-reset OTP storage.

All persistence goes through ``db_connection()``; statements use ``%s``
placeholders so values are always passed as bound parameters.
"""

import hashlib
import os
import re
import secrets
from datetime import datetime, timedelta, timezone
from uuid import uuid4

from argon2 import PasswordHasher
from argon2.exceptions import InvalidHash, VerifyMismatchError

from app.services.db import db_connection

# Session lifetime in seconds (default: 7 days).
SESSION_TTL_SECONDS = int(os.getenv("SESSION_TTL_SECONDS", str(60 * 60 * 24 * 7)))
# A session whose remaining lifetime is at or below this many seconds gets its
# expiry pushed forward on access (default: 1 hour).
SESSION_REFRESH_WINDOW_SECONDS = int(
    os.getenv("SESSION_REFRESH_WINDOW_SECONDS", str(60 * 60))
)
# Lifetime of a password-reset OTP in minutes.
RESET_OTP_TTL_MINUTES = int(os.getenv("RESET_OTP_TTL_MINUTES", "10"))
PASSWORD_HASHER = PasswordHasher()
# Matches a lowercase hex SHA-256 digest, the format of legacy password hashes.
LEGACY_SHA256_RE = re.compile(r"^[0-9a-f]{64}$")


def _get_reset_otp_secret() -> str:
    """Return the stripped ``RESET_OTP_SECRET`` environment value.

    Raises:
        RuntimeError: If the variable is unset or blank.
    """
    secret = (os.getenv("RESET_OTP_SECRET") or "").strip()
    if not secret:
        raise RuntimeError("RESET_OTP_SECRET is not configured on this server")
    return secret


def _now_utc() -> datetime:
    """Return the current time as a timezone-aware UTC datetime."""
    return datetime.now(timezone.utc)


def _new_expiry(now: datetime) -> datetime:
    """Return *now* plus ``SESSION_TTL_SECONDS``."""
    return now + timedelta(seconds=SESSION_TTL_SECONDS)


def _hash_password(password: str) -> str:
    """Hash *password* with the module-level argon2 ``PasswordHasher``."""
    return PASSWORD_HASHER.hash(password)


def _hash_password_legacy(password: str) -> str:
    """Return the hex SHA-256 digest of *password* (legacy scheme)."""
    return hashlib.sha256(password.encode("utf-8")).hexdigest()


def _is_legacy_password_hash(password_hash: str | None) -> bool:
    """Return True if *password_hash* looks like a 64-char lowercase hex digest."""
    return bool(password_hash and LEGACY_SHA256_RE.fullmatch(password_hash))


def _hash_otp(email: str, otp: str) -> str:
    """Return the hex SHA-256 digest of ``"{email}:{otp}:{secret}"``.

    Raises:
        RuntimeError: If ``RESET_OTP_SECRET`` is not configured.
    """
    payload = f"{email}:{otp}:{_get_reset_otp_secret()}"
    return hashlib.sha256(payload.encode("utf-8")).hexdigest()


def _row_to_user(row):
    """Convert an ``(id, username, password_hash[, role])`` row into a dict.

    Returns None for a falsy *row*. The stored hash is exposed under the
    ``"password"`` key; ``"role"`` is None when the row has no fourth column.
    """
    if not row:
        return None
    return {
        "id": row[0],
        "username": row[1],
        "password": row[2],
        "role": row[3] if len(row) > 3 else None,
    }


def get_user_by_username(username: str):
    """Return the user dict for *username*, or None if no row matches."""
    with db_connection() as conn:
        with conn.cursor() as cur:
            cur.execute(
                "SELECT id, username, password_hash, role FROM app_user WHERE username = %s",
                (username,),
            )
            return _row_to_user(cur.fetchone())


def get_user_by_id(user_id: str):
    """Return the user dict for *user_id*, or None if no row matches."""
    with db_connection() as conn:
        with conn.cursor() as cur:
            cur.execute(
                "SELECT id, username, password_hash, role FROM app_user WHERE id = %s",
                (user_id,),
            )
            return _row_to_user(cur.fetchone())


def create_user(username: str, password: str):
    """Insert a new user with role ``'USER'`` and an argon2 password hash.

    Returns:
        The created user dict, or None when the username already exists
        (``ON CONFLICT ... DO NOTHING`` yields no ``RETURNING`` row).
    """
    user_id = str(uuid4())
    password_hash = _hash_password(password)
    with db_connection() as conn:
        with conn:
            with conn.cursor() as cur:
                cur.execute(
                    """
                    INSERT INTO app_user (id, username, password_hash, role)
                    VALUES (%s, %s, %s, 'USER')
                    ON CONFLICT (username) DO NOTHING
                    RETURNING id, username, password_hash, role
                    """,
                    (user_id, username, password_hash),
                )
                return _row_to_user(cur.fetchone())


def _update_password_hash(user_id: str, password_hash: str):
    """Store *password_hash* as the password hash of user *user_id*."""
    with db_connection() as conn:
        with conn:
            with conn.cursor() as cur:
                cur.execute(
                    "UPDATE app_user SET password_hash = %s WHERE id = %s",
                    (password_hash, user_id),
                )


def _verify_password(
    user_id: str, stored_hash: str | None, password: str
) -> tuple[bool, str | None]:
    """Check *password* against *stored_hash*.

    Args:
        user_id: Not used by the check itself; kept for the existing signature.
        stored_hash: Legacy SHA-256 hex digest, argon2 hash, or None.
        password: Candidate plaintext password.

    Returns:
        ``(verified, replacement_hash)``. ``replacement_hash`` is a fresh
        argon2 hash when the password verified against a legacy digest or
        against an argon2 hash that needs rehashing; otherwise None.
    """
    if not stored_hash:
        return False, None

    if _is_legacy_password_hash(stored_hash):
        # Constant-time comparison of the two hex digests.
        if secrets.compare_digest(stored_hash, _hash_password_legacy(password)):
            return True, _hash_password(password)
        return False, None

    try:
        verified = PASSWORD_HASHER.verify(stored_hash, password)
    except (VerifyMismatchError, InvalidHash):
        return False, None

    if not verified:
        return False, None

    if PASSWORD_HASHER.check_needs_rehash(stored_hash):
        return True, _hash_password(password)
    return True, None


def authenticate_user(username: str, password: str):
    """Return the user dict if *username*/*password* are valid, else None.

    When verification yields a replacement hash (legacy digest or outdated
    argon2 parameters), the new hash is persisted and reflected in the
    returned dict.
    """
    user = get_user_by_username(username)
    if not user:
        return None

    verified, replacement_hash = _verify_password(
        user["id"], user.get("password"), password
    )
    if not verified:
        return None

    if replacement_hash:
        _update_password_hash(user["id"], replacement_hash)
        user["password"] = replacement_hash
    return user


def verify_user(username: str, password: str):
    """Alias for :func:`authenticate_user`."""
    return authenticate_user(username, password)


def create_session(
    user_id: str, ip: str | None = None, user_agent: str | None = None
) -> str:
    """Insert a session row for *user_id* and return the new session id.

    ``created_at`` and ``last_seen_at`` are set to the current UTC time and
    ``expires_at`` to that time plus ``SESSION_TTL_SECONDS``.
    """
    session_id = str(uuid4())
    now = _now_utc()
    with db_connection() as conn:
        with conn:
            with conn.cursor() as cur:
                cur.execute(
                    """
                    INSERT INTO app_session
                        (id, user_id, created_at, last_seen_at, expires_at, ip, user_agent)
                    VALUES (%s, %s, %s, %s, %s, %s, %s)
                    """,
                    (session_id, user_id, now, now, _new_expiry(now), ip, user_agent),
                )
    return session_id


def get_last_session_meta(user_id: str):
    """Return ``{"ip", "user_agent"}`` of the user's most recently created session.

    Both values are None when the user has no session rows.
    """
    with db_connection() as conn:
        with conn.cursor() as cur:
            cur.execute(
                """
                SELECT ip, user_agent
                FROM app_session
                WHERE user_id = %s
                ORDER BY created_at DESC
                LIMIT 1
                """,
                (user_id,),
            )
            row = cur.fetchone()
    if not row:
        return {"ip": None, "user_agent": None}
    return {"ip": row[0], "user_agent": row[1]}


def update_user_password(user_id: str, new_password: str):
    """Hash *new_password* with argon2 and store it for *user_id*."""
    password_hash = _hash_password(new_password)
    _update_password_hash(user_id, password_hash)


def create_password_reset_otp(email: str):
    """Create and store a password-reset OTP for *email*; return the plaintext OTP.

    Only the keyed hash of the OTP is stored. The OTP expires after
    ``RESET_OTP_TTL_MINUTES`` minutes.

    Raises:
        RuntimeError: If ``RESET_OTP_SECRET`` is not configured.
    """
    # NOTE(review): a 4-digit OTP has only 10,000 possible values and nothing
    # in this module limits verification attempts - confirm that callers
    # rate-limit or lock out repeated guesses.
    otp = f"{secrets.randbelow(10000):04d}"
    now = _now_utc()
    expires_at = now + timedelta(minutes=RESET_OTP_TTL_MINUTES)
    otp_hash = _hash_otp(email, otp)
    with db_connection() as conn:
        with conn:
            with conn.cursor() as cur:
                cur.execute(
                    """
                    INSERT INTO password_reset_otp
                        (id, email, otp_hash, created_at, expires_at, used_at)
                    VALUES (%s, %s, %s, %s, %s, NULL)
                    """,
                    (str(uuid4()), email, otp_hash, now, expires_at),
                )
    return otp


def consume_password_reset_otp(email: str, otp: str) -> bool:
    """Mark a matching, unused, unexpired OTP as used.

    Returns:
        True only if this call is the one that marked the OTP as used;
        False if no matching OTP exists or it was consumed concurrently.

    Raises:
        RuntimeError: If ``RESET_OTP_SECRET`` is not configured.
    """
    now = _now_utc()
    otp_hash = _hash_otp(email, otp)
    with db_connection() as conn:
        with conn:
            with conn.cursor() as cur:
                cur.execute(
                    """
                    SELECT id
                    FROM password_reset_otp
                    WHERE email = %s
                      AND otp_hash = %s
                      AND used_at IS NULL
                      AND expires_at > %s
                    ORDER BY created_at DESC
                    LIMIT 1
                    """,
                    (email, otp_hash, now),
                )
                row = cur.fetchone()
                if not row:
                    return False
                # Re-check "unused" inside the UPDATE itself so two concurrent
                # requests cannot both consume the same OTP: the loser's
                # UPDATE matches zero rows. (Assumes the database re-evaluates
                # the WHERE clause after waiting on the row lock, as
                # PostgreSQL does under READ COMMITTED - confirm for the
                # deployed database.)
                cur.execute(
                    "UPDATE password_reset_otp SET used_at = %s "
                    "WHERE id = %s AND used_at IS NULL",
                    (now, row[0]),
                )
                return cur.rowcount == 1


def get_session(session_id: str):
    """Return the session row as a dict with ISO-8601 timestamps, or None.

    Expiry is not checked here; timestamps that are NULL in the row are
    returned as None.
    """
    with db_connection() as conn:
        with conn.cursor() as cur:
            cur.execute(
                """
                SELECT id, user_id, created_at, last_seen_at, expires_at
                FROM app_session
                WHERE id = %s
                """,
                (session_id,),
            )
            row = cur.fetchone()
    if not row:
        return None
    created_at = row[2].isoformat() if row[2] else None
    last_seen_at = row[3].isoformat() if row[3] else None
    expires_at = row[4].isoformat() if row[4] else None
    return {
        "id": row[0],
        "user_id": row[1],
        "created_at": created_at,
        "last_seen_at": last_seen_at,
        "expires_at": expires_at,
    }


def delete_session(session_id: str):
    """Delete the session row with id *session_id*, if any."""
    with db_connection() as conn:
        with conn:
            with conn.cursor() as cur:
                cur.execute("DELETE FROM app_session WHERE id = %s", (session_id,))


def get_user_for_session(session_id: str):
    """Resolve *session_id* to its user dict, maintaining session expiry.

    Steps, all on one connection:
      1. Delete every session whose ``expires_at`` is at or before now.
      2. Load the requested session; return None if it does not exist.
      3. If it has no expiry, assign one (now + TTL) and update ``last_seen_at``.
      4. If it is expired, delete it and return None.
      5. If it expires within ``SESSION_REFRESH_WINDOW_SECONDS``, extend it.
      6. Return the owning user, or None if the user row is missing.
    """
    if not session_id:
        return None

    now = _now_utc()
    with db_connection() as conn:
        with conn:
            with conn.cursor() as cur:
                cur.execute(
                    """
                    DELETE FROM app_session
                    WHERE expires_at IS NOT NULL
                      AND expires_at <= %s
                    """,
                    (now,),
                )
                cur.execute(
                    """
                    SELECT id, user_id, created_at, last_seen_at, expires_at
                    FROM app_session
                    WHERE id = %s
                    """,
                    (session_id,),
                )
                row = cur.fetchone()
                if not row:
                    return None

                expires_at = row[4]
                if expires_at is None:
                    # Session row without an expiry: give it a full TTL from now.
                    new_expiry = _new_expiry(now)
                    cur.execute(
                        """
                        UPDATE app_session
                        SET expires_at = %s, last_seen_at = %s
                        WHERE id = %s
                        """,
                        (new_expiry, now, session_id),
                    )
                    expires_at = new_expiry

                if expires_at <= now:
                    cur.execute(
                        "DELETE FROM app_session WHERE id = %s", (session_id,)
                    )
                    return None

                if (expires_at - now).total_seconds() <= SESSION_REFRESH_WINDOW_SECONDS:
                    # Sliding expiration: close to expiry, so extend the session.
                    new_expiry = _new_expiry(now)
                    cur.execute(
                        """
                        UPDATE app_session
                        SET expires_at = %s, last_seen_at = %s
                        WHERE id = %s
                        """,
                        (new_expiry, now, session_id),
                    )

                cur.execute(
                    "SELECT id, username, password_hash, role FROM app_user WHERE id = %s",
                    (row[1],),
                )
                return _row_to_user(cur.fetchone())