326 lines
10 KiB
Python

import hashlib
import os
import re
import secrets
from datetime import datetime, timedelta, timezone
from uuid import uuid4

from argon2 import PasswordHasher
from argon2.exceptions import InvalidHash, VerificationError, VerifyMismatchError

from app.services.db import db_connection
def _env_int(name: str, default: int) -> int:
    """Read integer setting *name* from the environment, falling back to *default*.

    A value that is set but is not a valid integer raises ``ValueError`` at import.
    """
    return int(os.getenv(name, str(default)))


# Lifetime of a newly issued or refreshed session, in seconds (default: 7 days).
SESSION_TTL_SECONDS = _env_int("SESSION_TTL_SECONDS", 60 * 60 * 24 * 7)
# Sessions with at most this many seconds left are extended on lookup (default: 1 hour).
SESSION_REFRESH_WINDOW_SECONDS = _env_int("SESSION_REFRESH_WINDOW_SECONDS", 60 * 60)
# Validity period of a password-reset code, in minutes (default: 10).
RESET_OTP_TTL_MINUTES = _env_int("RESET_OTP_TTL_MINUTES", 10)

# Shared Argon2 hasher using the library's default parameters.
PASSWORD_HASHER = PasswordHasher()
# Shape of pre-Argon2 password hashes: a bare lowercase hex SHA-256 digest.
LEGACY_SHA256_RE = re.compile(r"^[0-9a-f]{64}$")

# Secret mixed into stored reset-code hashes; importing this module fails without it.
RESET_OTP_SECRET = os.environ.get("RESET_OTP_SECRET", "").strip()
if not RESET_OTP_SECRET:
    raise RuntimeError("RESET_OTP_SECRET must be configured")
def _now_utc() -> datetime:
return datetime.now(timezone.utc)
def _new_expiry(now: datetime) -> datetime:
    """Return the expiry for a session (re)issued at *now*: ``now + SESSION_TTL_SECONDS``."""
    ttl = timedelta(seconds=SESSION_TTL_SECONDS)
    return now + ttl
def _hash_password(password: str) -> str:
    """Return an encoded Argon2 hash of *password* produced by ``PASSWORD_HASHER``."""
    encoded = PASSWORD_HASHER.hash(password)
    return encoded
def _hash_password_legacy(password: str) -> str:
return hashlib.sha256(password.encode("utf-8")).hexdigest()
def _is_legacy_password_hash(password_hash: str | None) -> bool:
return bool(password_hash and LEGACY_SHA256_RE.fullmatch(password_hash))
def _hash_otp(email: str, otp: str) -> str:
    """Return the hex SHA-256 digest stored in place of a plaintext reset code.

    The digest covers ``email:otp:RESET_OTP_SECRET``, so a code only matches for
    the address it was issued to and cannot be recomputed without the secret.
    """
    # NOTE(review): hmac.new(secret, msg, hashlib.sha256) is the textbook keyed-hash
    # construction; switching would invalidate any codes issued before the change.
    material = "{}:{}:{}".format(email, otp, RESET_OTP_SECRET)
    return hashlib.sha256(material.encode("utf-8")).hexdigest()
def _row_to_user(row):
if not row:
return None
return {
"id": row[0],
"username": row[1],
"password": row[2],
"role": row[3] if len(row) > 3 else None,
}
def get_user_by_username(username: str):
    """Fetch the user with the given username.

    Returns the dict built by ``_row_to_user`` (including the stored password
    hash under ``"password"``), or None when no row matches.
    """
    sql = "SELECT id, username, password_hash, role FROM app_user WHERE username = %s"
    with db_connection() as conn, conn.cursor() as cur:
        cur.execute(sql, (username,))
        found = cur.fetchone()
    return _row_to_user(found)
def get_user_by_id(user_id: str):
    """Fetch the user whose primary key is *user_id*.

    Returns the dict built by ``_row_to_user`` or None when no row matches.
    """
    sql = "SELECT id, username, password_hash, role FROM app_user WHERE id = %s"
    with db_connection() as conn, conn.cursor() as cur:
        cur.execute(sql, (user_id,))
        found = cur.fetchone()
    return _row_to_user(found)
def create_user(username: str, password: str):
    """Insert a new user with role ``'USER'`` and an Argon2-hashed password.

    Returns the created user dict, or None when *username* is already taken
    (the insert is skipped by ``ON CONFLICT (username) DO NOTHING``, so
    ``RETURNING`` yields no row).
    """
    new_id = str(uuid4())
    hashed = _hash_password(password)
    insert_sql = """
        INSERT INTO app_user (id, username, password_hash, role)
        VALUES (%s, %s, %s, 'USER')
        ON CONFLICT (username) DO NOTHING
        RETURNING id, username, password_hash, role
    """
    # The middle ``conn`` context presumably scopes a transaction (psycopg-style) — confirm.
    with db_connection() as conn, conn, conn.cursor() as cur:
        cur.execute(insert_sql, (new_id, username, hashed))
        created = cur.fetchone()
    return _row_to_user(created)
def _update_password_hash(user_id: str, password_hash: str):
    """Overwrite the stored password hash for *user_id*.

    An unknown id simply matches no rows; no error is raised.
    """
    statement = "UPDATE app_user SET password_hash = %s WHERE id = %s"
    with db_connection() as conn, conn, conn.cursor() as cur:
        cur.execute(statement, (password_hash, user_id))
def _verify_password(user_id: str, stored_hash: str | None, password: str) -> tuple[bool, str | None]:
if not stored_hash:
return False, None
if _is_legacy_password_hash(stored_hash):
if secrets.compare_digest(stored_hash, _hash_password_legacy(password)):
return True, _hash_password(password)
return False, None
try:
verified = PASSWORD_HASHER.verify(stored_hash, password)
except (VerifyMismatchError, InvalidHash):
return False, None
if not verified:
return False, None
if PASSWORD_HASHER.check_needs_rehash(stored_hash):
return True, _hash_password(password)
return True, None
# Lazily created Argon2 hash of a random string; only used to burn verify time.
_DUMMY_PASSWORD_HASH: str | None = None


def _dummy_password_hash() -> str:
    """Return a throwaway Argon2 hash (created on first use) for timing equalization."""
    global _DUMMY_PASSWORD_HASH
    if _DUMMY_PASSWORD_HASH is None:
        _DUMMY_PASSWORD_HASH = _hash_password(secrets.token_hex(16))
    return _DUMMY_PASSWORD_HASH


def authenticate_user(username: str, password: str):
    """Return the user dict when *username*/*password* are valid, else None.

    On success, a legacy or outdated stored hash is replaced with a fresh Argon2
    hash. The new hash is persisted and also reflected in the returned dict.

    For an unknown username, the password is still verified against a dummy
    Argon2 hash. A miss therefore costs roughly as much as a wrong password on
    an existing Argon2 account, which makes username enumeration via response
    timing harder.
    """
    user = get_user_by_username(username)
    if not user:
        _verify_password("", _dummy_password_hash(), password)
        return None

    verified, replacement_hash = _verify_password(user["id"], user.get("password"), password)
    if not verified:
        return None

    if replacement_hash:
        # Transparent upgrade: persist the stronger hash while we know the plaintext.
        _update_password_hash(user["id"], replacement_hash)
        user["password"] = replacement_hash
    return user
def verify_user(username: str, password: str):
    """Alias of ``authenticate_user``: return the user dict on valid credentials, else None."""
    authenticated = authenticate_user(username, password)
    return authenticated
def create_session(user_id: str, ip: str | None = None, user_agent: str | None = None) -> str:
    """Persist a new session for *user_id* and return its id (a random UUID4 string).

    ``created_at`` and ``last_seen_at`` are set to the current UTC time, and
    ``expires_at`` to ``SESSION_TTL_SECONDS`` later. *ip* and *user_agent* are
    stored as given.
    """
    session_id = str(uuid4())
    issued_at = _now_utc()
    values = (session_id, user_id, issued_at, issued_at, _new_expiry(issued_at), ip, user_agent)
    with db_connection() as conn, conn, conn.cursor() as cur:
        cur.execute(
            """
            INSERT INTO app_session (id, user_id, created_at, last_seen_at, expires_at, ip, user_agent)
            VALUES (%s, %s, %s, %s, %s, %s, %s)
            """,
            values,
        )
    return session_id
def get_last_session_meta(user_id: str):
    """Return ``{"ip", "user_agent"}`` of *user_id*'s most recently created stored session.

    Both values are None when the user has no session row.
    """
    with db_connection() as conn, conn.cursor() as cur:
        cur.execute(
            """
            SELECT ip, user_agent
            FROM app_session
            WHERE user_id = %s
            ORDER BY created_at DESC
            LIMIT 1
            """,
            (user_id,),
        )
        latest = cur.fetchone()
    ip, user_agent = (latest[0], latest[1]) if latest else (None, None)
    return {"ip": ip, "user_agent": user_agent}
def update_user_password(user_id: str, new_password: str):
    """Replace *user_id*'s stored password with an Argon2 hash of *new_password*."""
    _update_password_hash(user_id, _hash_password(new_password))
def create_password_reset_otp(email: str):
    """Issue a new 4-digit password-reset code for *email* and return it in plaintext.

    Only ``_hash_otp(email, code)`` is stored, and the code expires after
    ``RESET_OTP_TTL_MINUTES``. Issuing a code first expires every still-open
    code for the same email, so at most one code per address can be redeemed.
    """
    otp = f"{secrets.randbelow(10000):04d}"
    now = _now_utc()
    expires_at = now + timedelta(minutes=RESET_OTP_TTL_MINUTES)
    otp_hash = _hash_otp(email, otp)
    with db_connection() as conn:
        with conn:
            with conn.cursor() as cur:
                # With only 10,000 possible codes, each extra live code for the same
                # address multiplies an attacker's odds per guess, so retire older ones.
                # expires_at (not used_at) is moved so used_at keeps meaning "redeemed".
                cur.execute(
                    """
                    UPDATE password_reset_otp
                    SET expires_at = %s
                    WHERE email = %s
                      AND used_at IS NULL
                      AND expires_at > %s
                    """,
                    (now, email, now),
                )
                cur.execute(
                    """
                    INSERT INTO password_reset_otp (id, email, otp_hash, created_at, expires_at, used_at)
                    VALUES (%s, %s, %s, %s, %s, NULL)
                    """,
                    (str(uuid4()), email, otp_hash, now, expires_at),
                )
    return otp
def consume_password_reset_otp(email: str, otp: str) -> bool:
    """Redeem a reset code for *email*; return True at most once per valid code.

    A code is valid when its hash matches, it has not expired, and it has not
    been used. The newest valid match is marked used in the same transaction.
    The marking UPDATE re-checks ``used_at IS NULL``, so when two requests race
    with the same code the loser updates 0 rows and gets False. This relies on
    PostgreSQL row locking; the file's ``ON CONFLICT`` usage suggests Postgres.
    """
    now = _now_utc()
    otp_hash = _hash_otp(email, otp)
    with db_connection() as conn:
        with conn:
            with conn.cursor() as cur:
                cur.execute(
                    """
                    SELECT id
                    FROM password_reset_otp
                    WHERE email = %s
                      AND otp_hash = %s
                      AND used_at IS NULL
                      AND expires_at > %s
                    ORDER BY created_at DESC
                    LIMIT 1
                    """,
                    (email, otp_hash, now),
                )
                row = cur.fetchone()
                if not row:
                    return False
                cur.execute(
                    """
                    UPDATE password_reset_otp
                    SET used_at = %s
                    WHERE id = %s AND used_at IS NULL
                    """,
                    (now, row[0]),
                )
                # 0 rows means a concurrent transaction consumed this code first.
                return cur.rowcount == 1
def get_session(session_id: str):
    """Return the stored session record for *session_id*, or None if there is none.

    Timestamp columns are rendered with ``isoformat()`` (None when NULL). The row
    is returned as stored: this function neither checks nor extends expiry.
    """
    query = """
        SELECT id, user_id, created_at, last_seen_at, expires_at
        FROM app_session
        WHERE id = %s
    """
    with db_connection() as conn, conn.cursor() as cur:
        cur.execute(query, (session_id,))
        row = cur.fetchone()
    if not row:
        return None
    created_at, last_seen_at, expires_at = (
        stamp.isoformat() if stamp else None for stamp in row[2:5]
    )
    return {
        "id": row[0],
        "user_id": row[1],
        "created_at": created_at,
        "last_seen_at": last_seen_at,
        "expires_at": expires_at,
    }
def delete_session(session_id: str):
    """Delete the session with id *session_id*; an unknown id matches no rows and is not an error."""
    statement = "DELETE FROM app_session WHERE id = %s"
    with db_connection() as conn, conn, conn.cursor() as cur:
        cur.execute(statement, (session_id,))
def _touch_session(cur, session_id: str, now: datetime) -> datetime:
    """Set a session's expiry to a full TTL from *now*, stamp last_seen_at, and return the new expiry."""
    renewed_until = _new_expiry(now)
    cur.execute(
        """
        UPDATE app_session
        SET expires_at = %s, last_seen_at = %s
        WHERE id = %s
        """,
        (renewed_until, now, session_id),
    )
    return renewed_until


def get_user_for_session(session_id: str):
    """Resolve *session_id* to its user dict, enforcing and sliding session expiry.

    Within a single ``with conn`` block this:
      1. deletes every session (for any user) whose expires_at has passed;
      2. loads the requested session and returns None if it is missing;
      3. gives a session with NULL expires_at a full TTL;
      4. deletes the session and returns None if it has expired;
      5. extends it by a full TTL once at most SESSION_REFRESH_WINDOW_SECONDS remain;
      6. returns the owning user via ``_row_to_user`` (None if the user row is gone).

    last_seen_at is only written when the expiry is set or extended.
    """
    if not session_id:
        return None
    now = _now_utc()
    with db_connection() as conn, conn, conn.cursor() as cur:
        # Opportunistic global sweep of expired sessions, run on every lookup.
        cur.execute(
            """
            DELETE FROM app_session
            WHERE expires_at IS NOT NULL AND expires_at <= %s
            """,
            (now,),
        )
        cur.execute(
            """
            SELECT id, user_id, created_at, last_seen_at, expires_at
            FROM app_session
            WHERE id = %s
            """,
            (session_id,),
        )
        session_row = cur.fetchone()
        if not session_row:
            return None

        expires_at = session_row[4]
        if expires_at is None:
            expires_at = _touch_session(cur, session_id, now)

        # Normally unreachable for stored expiries, because the sweep above used the
        # same `now`. It can still fire for a just-backfilled expiry when
        # SESSION_TTL_SECONDS <= 0.
        if expires_at <= now:
            cur.execute("DELETE FROM app_session WHERE id = %s", (session_id,))
            return None

        seconds_left = (expires_at - now).total_seconds()
        if seconds_left <= SESSION_REFRESH_WINDOW_SECONDS:
            _touch_session(cur, session_id, now)

        cur.execute(
            "SELECT id, username, password_hash, role FROM app_user WHERE id = %s",
            (session_row[1],),
        )
        return _row_to_user(cur.fetchone())