import hashlib
import hmac
import os
import secrets
from datetime import datetime, timedelta, timezone
from uuid import uuid4

from app.services.db import db_connection

# Lifetime of a session, in seconds (default: 7 days).
SESSION_TTL_SECONDS = int(os.getenv("SESSION_TTL_SECONDS", str(60 * 60 * 24 * 7)))
# When a session has at most this many seconds left, get_user_for_session
# extends it to a fresh full TTL (sliding expiry). Default: 1 hour.
SESSION_REFRESH_WINDOW_SECONDS = int(
    os.getenv("SESSION_REFRESH_WINDOW_SECONDS", str(60 * 60))
)
# Validity window of a password-reset OTP, in minutes.
RESET_OTP_TTL_MINUTES = int(os.getenv("RESET_OTP_TTL_MINUTES", "10"))
# Secret mixed into stored OTP hashes.
# NOTE(review): the fallback value is public; deployments should set
# RESET_OTP_SECRET explicitly.
RESET_OTP_SECRET = os.getenv("RESET_OTP_SECRET", "otp_secret")
# PBKDF2 work factor applied to newly created password hashes. Stored hashes
# record their own iteration count, so raising this does not break old ones.
PASSWORD_HASH_ITERATIONS = int(os.getenv("PASSWORD_HASH_ITERATIONS", "600000"))

# Prefix identifying salted PBKDF2 hashes. Stored values without it are
# legacy unsalted SHA-256 hex digests.
_PBKDF2_PREFIX = "pbkdf2_sha256"


def _now_utc() -> datetime:
    """Return the current time as a timezone-aware UTC datetime."""
    return datetime.now(timezone.utc)


def _as_utc(value: datetime) -> datetime:
    """Return *value* as an aware datetime, treating a naive value as UTC.

    Comparing a naive datetime with the aware value from _now_utc() raises
    TypeError. Naive values can come back from a timestamp column that has
    no time zone.
    """
    if value.tzinfo is None:
        return value.replace(tzinfo=timezone.utc)
    return value


def _new_expiry(now: datetime) -> datetime:
    """Return the expiry for a session created or refreshed at *now*."""
    return now + timedelta(seconds=SESSION_TTL_SECONDS)


def _legacy_hash_password(password: str) -> str:
    """Return the old unsalted SHA-256 hex digest.

    Used only to verify passwords stored before salted hashing existed.
    """
    return hashlib.sha256(password.encode("utf-8")).hexdigest()


def _hash_password(password: str) -> str:
    """Return a salted PBKDF2-HMAC-SHA256 hash of *password*.

    Format: ``pbkdf2_sha256$<iterations>$<salt hex>$<digest hex>``.

    NOTE(review): the result is about 120 characters long. Confirm that
    app_user.password_hash is TEXT or otherwise wide enough.
    """
    salt = secrets.token_bytes(16)
    digest = hashlib.pbkdf2_hmac(
        "sha256", password.encode("utf-8"), salt, PASSWORD_HASH_ITERATIONS
    )
    return f"{_PBKDF2_PREFIX}${PASSWORD_HASH_ITERATIONS}${salt.hex()}${digest.hex()}"


def _verify_password(password: str, stored: str | None) -> bool:
    """Return True when *password* matches the *stored* hash.

    Accepts both the PBKDF2 format from _hash_password and legacy unsalted
    SHA-256 hex digests. Comparisons are constant-time. A malformed stored
    value is treated as a mismatch.
    """
    if not stored:
        return False
    if stored.startswith(_PBKDF2_PREFIX + "$"):
        try:
            _, iterations, salt_hex, expected = stored.split("$", 3)
            candidate = hashlib.pbkdf2_hmac(
                "sha256",
                password.encode("utf-8"),
                bytes.fromhex(salt_hex),
                int(iterations),
            ).hex()
        except ValueError:
            return False
        return hmac.compare_digest(candidate.encode("utf-8"), expected.encode("utf-8"))
    # TODO(review): consider re-hashing legacy passwords with _hash_password
    # after a successful login so they migrate to the salted format.
    return hmac.compare_digest(
        _legacy_hash_password(password).encode("utf-8"), stored.encode("utf-8")
    )


def _hash_otp(email: str, otp: str) -> str:
    """Return the SHA-256 hex digest stored for an (email, otp) pair."""
    payload = f"{email}:{otp}:{RESET_OTP_SECRET}"
    return hashlib.sha256(payload.encode("utf-8")).hexdigest()


def _row_to_user(row):
    """Map an ``(id, username, password_hash[, role])`` row to a user dict.

    The "password" key holds the stored hash, never the plaintext.
    "role" is None when the row has fewer than four columns.
    Returns None for a missing row.
    """
    if not row:
        return None
    return {
        "id": row[0],
        "username": row[1],
        "password": row[2],
        "role": row[3] if len(row) > 3 else None,
    }


def get_user_by_username(username: str):
    """Return the user dict for *username*, or None if no such user exists."""
    with db_connection() as conn:
        with conn.cursor() as cur:
            cur.execute(
                "SELECT id, username, password_hash, role FROM app_user WHERE username = %s",
                (username,),
            )
            return _row_to_user(cur.fetchone())


def get_user_by_id(user_id: str):
    """Return the user dict for *user_id*, or None if no such user exists."""
    with db_connection() as conn:
        with conn.cursor() as cur:
            cur.execute(
                "SELECT id, username, password_hash, role FROM app_user WHERE id = %s",
                (user_id,),
            )
            return _row_to_user(cur.fetchone())


def create_user(username: str, password: str):
    """Insert a new user with role 'USER' and return its user dict.

    Returns None when *username* is already taken: ON CONFLICT DO NOTHING
    produces no RETURNING row.
    """
    user_id = str(uuid4())
    password_hash = _hash_password(password)
    with db_connection() as conn:
        with conn:
            with conn.cursor() as cur:
                cur.execute(
                    """
                    INSERT INTO app_user (id, username, password_hash, role)
                    VALUES (%s, %s, %s, 'USER')
                    ON CONFLICT (username) DO NOTHING
                    RETURNING id, username, password_hash, role
                    """,
                    (user_id, username, password_hash),
                )
                return _row_to_user(cur.fetchone())


def authenticate_user(username: str, password: str):
    """Return the user dict if *password* is correct for *username*, else None."""
    user = get_user_by_username(username)
    if not user:
        return None
    if not _verify_password(password, user.get("password")):
        return None
    return user


def verify_user(username: str, password: str):
    """Alias of authenticate_user."""
    return authenticate_user(username, password)


def create_session(user_id: str, ip: str | None = None, user_agent: str | None = None) -> str:
    """Create a session for *user_id* expiring after SESSION_TTL_SECONDS.

    Returns the new session id.
    """
    session_id = str(uuid4())
    now = _now_utc()
    with db_connection() as conn:
        with conn:
            with conn.cursor() as cur:
                cur.execute(
                    """
                    INSERT INTO app_session (id, user_id, created_at, last_seen_at, expires_at, ip, user_agent)
                    VALUES (%s, %s, %s, %s, %s, %s, %s)
                    """,
                    (session_id, user_id, now, now, _new_expiry(now), ip, user_agent),
                )
    return session_id


def get_last_session_meta(user_id: str):
    """Return ip/user_agent of the user's most recently created session.

    Both values are None when the user has no session.
    """
    with db_connection() as conn:
        with conn.cursor() as cur:
            cur.execute(
                """
                SELECT ip, user_agent
                FROM app_session
                WHERE user_id = %s
                ORDER BY created_at DESC
                LIMIT 1
                """,
                (user_id,),
            )
            row = cur.fetchone()
    if not row:
        return {"ip": None, "user_agent": None}
    return {"ip": row[0], "user_agent": row[1]}


def update_user_password(user_id: str, new_password: str):
    """Replace the stored password hash of *user_id*."""
    password_hash = _hash_password(new_password)
    with db_connection() as conn:
        with conn:
            with conn.cursor() as cur:
                cur.execute(
                    "UPDATE app_user SET password_hash = %s WHERE id = %s",
                    (password_hash, user_id),
                )


def create_password_reset_otp(email: str):
    """Create a 4-digit reset OTP for *email* and return it in plaintext.

    Only the hash is stored. The OTP expires after RESET_OTP_TTL_MINUTES.

    NOTE(review): there are only 10,000 possible codes, and this module does
    not count failed attempts. Confirm that callers rate-limit
    consume_password_reset_otp.
    """
    otp = f"{secrets.randbelow(10000):04d}"
    now = _now_utc()
    expires_at = now + timedelta(minutes=RESET_OTP_TTL_MINUTES)
    otp_hash = _hash_otp(email, otp)
    with db_connection() as conn:
        with conn:
            with conn.cursor() as cur:
                cur.execute(
                    """
                    INSERT INTO password_reset_otp (id, email, otp_hash, created_at, expires_at, used_at)
                    VALUES (%s, %s, %s, %s, %s, NULL)
                    """,
                    (str(uuid4()), email, otp_hash, now, expires_at),
                )
    return otp


def consume_password_reset_otp(email: str, otp: str) -> bool:
    """Mark the newest matching, unused, unexpired OTP as used.

    Returns True only if this call is the one that marked it.
    """
    now = _now_utc()
    otp_hash = _hash_otp(email, otp)
    with db_connection() as conn:
        with conn:
            with conn.cursor() as cur:
                cur.execute(
                    """
                    SELECT id
                    FROM password_reset_otp
                    WHERE email = %s
                      AND otp_hash = %s
                      AND used_at IS NULL
                      AND expires_at > %s
                    ORDER BY created_at DESC
                    LIMIT 1
                    """,
                    (email, otp_hash, now),
                )
                row = cur.fetchone()
                if not row:
                    return False
                # Re-check used_at in the UPDATE itself. Two concurrent
                # consumers can both see the row as unused in the SELECT, but
                # only one UPDATE can flip used_at from NULL.
                cur.execute(
                    "UPDATE password_reset_otp SET used_at = %s WHERE id = %s AND used_at IS NULL",
                    (now, row[0]),
                )
                return cur.rowcount == 1


def get_session(session_id: str):
    """Return the session row as a dict with ISO-8601 timestamps, or None.

    No expiry check is done here.
    """
    with db_connection() as conn:
        with conn.cursor() as cur:
            cur.execute(
                """
                SELECT id, user_id, created_at, last_seen_at, expires_at
                FROM app_session
                WHERE id = %s
                """,
                (session_id,),
            )
            row = cur.fetchone()
    if not row:
        return None
    created_at = row[2].isoformat() if row[2] else None
    last_seen_at = row[3].isoformat() if row[3] else None
    expires_at = row[4].isoformat() if row[4] else None
    return {
        "id": row[0],
        "user_id": row[1],
        "created_at": created_at,
        "last_seen_at": last_seen_at,
        "expires_at": expires_at,
    }


def delete_session(session_id: str):
    """Delete the session with *session_id*; a missing id is a no-op."""
    with db_connection() as conn:
        with conn:
            with conn.cursor() as cur:
                cur.execute("DELETE FROM app_session WHERE id = %s", (session_id,))


def _touch_session(cur, session_id: str, now: datetime) -> datetime:
    """Give the session a fresh full TTL from *now*, set last_seen_at, and return the new expiry."""
    new_expiry = _new_expiry(now)
    cur.execute(
        """
        UPDATE app_session
        SET expires_at = %s, last_seen_at = %s
        WHERE id = %s
        """,
        (new_expiry, now, session_id),
    )
    return new_expiry


def get_user_for_session(session_id: str):
    """Resolve *session_id* to its user dict, applying sliding expiry.

    Returns None for an empty, unknown, or expired session. A session with
    no expiry is given one. A session within
    SESSION_REFRESH_WINDOW_SECONDS of expiring is extended to a full TTL.
    """
    if not session_id:
        return None
    now = _now_utc()
    with db_connection() as conn:
        with conn:
            with conn.cursor() as cur:
                # Purge every expired session (for all users), not just this one.
                # NOTE(review): this is a table-wide DELETE on every lookup;
                # consider moving it to a periodic job if the table grows.
                cur.execute(
                    """
                    DELETE FROM app_session
                    WHERE expires_at IS NOT NULL AND expires_at <= %s
                    """,
                    (now,),
                )
                cur.execute(
                    """
                    SELECT id, user_id, created_at, last_seen_at, expires_at
                    FROM app_session
                    WHERE id = %s
                    """,
                    (session_id,),
                )
                row = cur.fetchone()
                if not row:
                    return None

                expires_at = row[4]
                if expires_at is None:
                    # Session stored without an expiry: start a full TTL now.
                    expires_at = _touch_session(cur, session_id, now)
                else:
                    expires_at = _as_utc(expires_at)

                # Defensive re-check: the purge above should already have
                # removed expired rows.
                if expires_at <= now:
                    cur.execute("DELETE FROM app_session WHERE id = %s", (session_id,))
                    return None

                if (expires_at - now).total_seconds() <= SESSION_REFRESH_WINDOW_SECONDS:
                    _touch_session(cur, session_id, now)

                cur.execute(
                    "SELECT id, username, password_hash, role FROM app_user WHERE id = %s",
                    (row[1],),
                )
                return _row_to_user(cur.fetchone())