2026-02-01 13:57:30 +00:00

281 lines
8.9 KiB
Python

import hashlib
import hmac
import os
import secrets
from datetime import datetime, timedelta, timezone
from uuid import uuid4

from app.services.db import db_connection
# All settings below are read from the environment once, at import time; int()
# raises ValueError at import if a numeric variable is set to a non-integer.

# Lifetime given to a newly created or extended session, in seconds
# (default 60 * 60 * 24 * 7 = 7 days).
SESSION_TTL_SECONDS = int(os.getenv("SESSION_TTL_SECONDS", str(60 * 60 * 24 * 7)))
# get_user_for_session extends a session once its remaining lifetime is at or
# below this many seconds (default 60 * 60 = 1 hour).
SESSION_REFRESH_WINDOW_SECONDS = int(
    os.getenv("SESSION_REFRESH_WINDOW_SECONDS", str(60 * 60))
)
# How long a password-reset OTP stays valid after issue, in minutes (default 10).
RESET_OTP_TTL_MINUTES = int(os.getenv("RESET_OTP_TTL_MINUTES", "10"))
# Server-side secret mixed into every stored OTP hash.
# NOTE(security): when RESET_OTP_SECRET is unset this silently falls back to the
# publicly visible literal "otp_secret"; every deployment should set it explicitly.
RESET_OTP_SECRET = os.getenv("RESET_OTP_SECRET", "otp_secret")
def _now_utc() -> datetime:
return datetime.now(timezone.utc)
def _new_expiry(now: datetime) -> datetime:
    """Return the expiry instant for a session issued or extended at *now*.

    The result is *now* plus ``SESSION_TTL_SECONDS``.
    """
    lifetime = timedelta(seconds=SESSION_TTL_SECONDS)
    return now + lifetime
def _hash_password(password: str) -> str:
return hashlib.sha256(password.encode("utf-8")).hexdigest()
def _hash_otp(email: str, otp: str) -> str:
    """Return the hex SHA-256 of ``"<email>:<otp>:<RESET_OTP_SECRET>"``.

    Binding the email into the digest means a code is only valid for the
    address it was issued to.

    NOTE(review): this is a plain hash over secret-suffixed input, not a keyed
    HMAC (``hmac.new``). Switching constructions would invalidate any codes
    still outstanding at deploy time.
    """
    material = "{}:{}:{}".format(email, otp, RESET_OTP_SECRET)
    return hashlib.sha256(material.encode("utf-8")).hexdigest()
def _row_to_user(row):
if not row:
return None
return {
"id": row[0],
"username": row[1],
"password": row[2],
"role": row[3] if len(row) > 3 else None,
}
def get_user_by_username(username: str):
    """Fetch the account whose ``username`` column equals *username*.

    Returns a dict with keys ``id``, ``username``, ``password`` (the stored
    ``password_hash`` value) and ``role``, or ``None`` when no row matches.
    """
    lookup_sql = "SELECT id, username, password_hash, role FROM app_user WHERE username = %s"
    with db_connection() as conn, conn.cursor() as cur:
        cur.execute(lookup_sql, (username,))
        found = cur.fetchone()
    return _row_to_user(found)
def get_user_by_id(user_id: str):
    """Fetch the account with primary key *user_id*.

    Returns the same dict shape as ``_row_to_user`` (``id``, ``username``,
    ``password`` = stored hash, ``role``), or ``None`` if there is no such row.
    """
    lookup_sql = "SELECT id, username, password_hash, role FROM app_user WHERE id = %s"
    with db_connection() as conn, conn.cursor() as cur:
        cur.execute(lookup_sql, (user_id,))
        found = cur.fetchone()
    return _row_to_user(found)
def create_user(username: str, password: str):
    """Insert a new account with role ``'USER'`` and return it as a user dict.

    The id is a fresh UUID4 string and the password is stored via
    ``_hash_password``. Because of ``ON CONFLICT (username) DO NOTHING``, a
    taken username inserts nothing, ``RETURNING`` yields no row, and the
    function returns ``None``.
    """
    new_id = str(uuid4())
    hashed = _hash_password(password)
    insert_sql = """
INSERT INTO app_user (id, username, password_hash, role)
VALUES (%s, %s, %s, 'USER')
ON CONFLICT (username) DO NOTHING
RETURNING id, username, password_hash, role
"""
    with db_connection() as conn:
        with conn, conn.cursor() as cur:
            cur.execute(insert_sql, (new_id, username, hashed))
            return _row_to_user(cur.fetchone())
def authenticate_user(username: str, password: str):
    """Return the user dict for *username* if *password* matches, else ``None``.

    The candidate is hashed with ``_hash_password`` and compared with the
    stored hash via ``hmac.compare_digest``. That comparison does not
    short-circuit on the first differing character, so response time does not
    reveal how much of the hash an attacker has guessed.

    The returned dict is the one from ``get_user_by_username`` and still
    carries the stored hash under ``"password"``.
    """
    user = get_user_by_username(username)
    if not user:
        return None
    stored_hash = user.get("password")
    # A NULL (or non-text) stored hash can never match, as with the old `!=` check.
    if not isinstance(stored_hash, str):
        return None
    candidate_hash = _hash_password(password)
    # Compare as bytes: compare_digest rejects str arguments containing non-ASCII.
    if not hmac.compare_digest(stored_hash.encode("utf-8"), candidate_hash.encode("utf-8")):
        return None
    return user
def verify_user(username: str, password: str):
    """Alias of ``authenticate_user``: the user dict on success, otherwise ``None``."""
    outcome = authenticate_user(username, password)
    return outcome
def create_session(user_id: str, ip: str | None = None, user_agent: str | None = None) -> str:
    """Persist a new session for *user_id* and return its id (a UUID4 string).

    ``created_at`` and ``last_seen_at`` are both set to the current UTC time.
    ``expires_at`` is that time plus ``SESSION_TTL_SECONDS``. The optional
    client *ip* and *user_agent* are stored alongside.
    """
    new_session_id = str(uuid4())
    issued_at = _now_utc()
    expiry = _new_expiry(issued_at)
    values = (new_session_id, user_id, issued_at, issued_at, expiry, ip, user_agent)
    with db_connection() as conn:
        with conn, conn.cursor() as cur:
            cur.execute(
                """
INSERT INTO app_session (id, user_id, created_at, last_seen_at, expires_at, ip, user_agent)
VALUES (%s, %s, %s, %s, %s, %s, %s)
""",
                values,
            )
    return new_session_id
def get_last_session_meta(user_id: str):
    """Return the client ip / user agent recorded on the user's newest session.

    "Newest" means the greatest ``created_at``. The result is always a dict
    ``{"ip": ..., "user_agent": ...}``. Both values are ``None`` when the user
    has no session rows.
    """
    with db_connection() as conn, conn.cursor() as cur:
        cur.execute(
            """
SELECT ip, user_agent
FROM app_session
WHERE user_id = %s
ORDER BY created_at DESC
LIMIT 1
""",
            (user_id,),
        )
        latest = cur.fetchone()
    if latest:
        return {"ip": latest[0], "user_agent": latest[1]}
    return {"ip": None, "user_agent": None}
def update_user_password(user_id: str, new_password: str):
    """Overwrite the stored password hash for *user_id* with the hash of *new_password*.

    If no row has this id the UPDATE simply matches nothing; no error is raised.

    NOTE(security): only ``app_user`` is touched, so the user's existing
    sessions stay valid. If this backs a password reset, consider revoking them.
    """
    new_hash = _hash_password(new_password)
    with db_connection() as conn:
        with conn, conn.cursor() as cur:
            cur.execute(
                "UPDATE app_user SET password_hash = %s WHERE id = %s",
                (new_hash, user_id),
            )
def create_password_reset_otp(email: str):
    """Issue a zero-padded 4-digit reset code for *email* and return it in plaintext.

    Only ``_hash_otp(email, code)`` is stored. The row expires
    ``RESET_OTP_TTL_MINUTES`` after issue and starts with ``used_at`` NULL.
    Earlier codes for the same email are left in place.

    NOTE(security): a 4-digit code has only 10,000 possible values, so guesses
    against it must be rate-limited somewhere.
    """
    code = str(secrets.randbelow(10000)).zfill(4)
    issued_at = _now_utc()
    row_values = (
        str(uuid4()),
        email,
        _hash_otp(email, code),
        issued_at,
        issued_at + timedelta(minutes=RESET_OTP_TTL_MINUTES),
    )
    with db_connection() as conn:
        with conn, conn.cursor() as cur:
            cur.execute(
                """
INSERT INTO password_reset_otp (id, email, otp_hash, created_at, expires_at, used_at)
VALUES (%s, %s, %s, %s, %s, NULL)
""",
                row_values,
            )
    return code
def consume_password_reset_otp(email: str, otp: str) -> bool:
    """Redeem a password-reset code for *email*; return True exactly once per code.

    Looks up the newest row for *email* whose hash matches *otp*, whose
    ``used_at`` is NULL and whose ``expires_at`` is still in the future. It
    then stamps ``used_at`` with the current UTC time. Returns False when no
    such row exists, or when a concurrent call marked it used first.

    NOTE(security): failed attempts are not counted here; callers should
    rate-limit, since reset codes are short.
    """
    now = _now_utc()
    otp_hash = _hash_otp(email, otp)
    with db_connection() as conn:
        with conn, conn.cursor() as cur:
            cur.execute(
                """
SELECT id
FROM password_reset_otp
WHERE email = %s
AND otp_hash = %s
AND used_at IS NULL
AND expires_at > %s
ORDER BY created_at DESC
LIMIT 1
""",
                (email, otp_hash, now),
            )
            row = cur.fetchone()
            if not row:
                return False
            # Re-check `used_at IS NULL` in the UPDATE itself. Two requests can
            # both pass the SELECT above. Under PostgreSQL READ COMMITTED the
            # second UPDATE waits for the first, then re-evaluates the predicate
            # and matches nothing. Without the guard, both would return True.
            cur.execute(
                """
UPDATE password_reset_otp
SET used_at = %s
WHERE id = %s
AND used_at IS NULL
RETURNING id
""",
                (now, row[0]),
            )
            return cur.fetchone() is not None
def get_session(session_id: str):
    """Load one session row and return it with timestamps as ISO-8601 strings.

    Returns a dict with ``id``, ``user_id``, ``created_at``, ``last_seen_at``
    and ``expires_at``. Each timestamp is rendered with ``.isoformat()``, or is
    ``None`` when the stored value is falsy (e.g. NULL). Returns ``None`` if no
    row has this id.

    NOTE(review): the query does not filter on ``expires_at``, so an expired
    session that has not been purged yet is still returned.
    """
    with db_connection() as conn, conn.cursor() as cur:
        cur.execute(
            """
SELECT id, user_id, created_at, last_seen_at, expires_at
FROM app_session
WHERE id = %s
""",
            (session_id,),
        )
        record = cur.fetchone()
    if not record:
        return None

    def as_iso(value):
        return value.isoformat() if value else None

    return {
        "id": record[0],
        "user_id": record[1],
        "created_at": as_iso(record[2]),
        "last_seen_at": as_iso(record[3]),
        "expires_at": as_iso(record[4]),
    }
def delete_session(session_id: str):
    """Delete the session row with *session_id*; an unknown id deletes nothing."""
    delete_sql = "DELETE FROM app_session WHERE id = %s"
    with db_connection() as conn:
        with conn, conn.cursor() as cur:
            cur.execute(delete_sql, (session_id,))
def _purge_expired_sessions(cur, now: datetime) -> None:
    """Delete every session whose non-NULL ``expires_at`` is at or before *now*."""
    cur.execute(
        """
DELETE FROM app_session
WHERE expires_at IS NOT NULL AND expires_at <= %s
""",
        (now,),
    )


def _extend_session(cur, session_id: str, now: datetime) -> datetime:
    """Set the session's expiry to a full TTL from *now*, stamp ``last_seen_at``; return the new expiry."""
    new_expiry = _new_expiry(now)
    cur.execute(
        """
UPDATE app_session
SET expires_at = %s, last_seen_at = %s
WHERE id = %s
""",
        (new_expiry, now, session_id),
    )
    return new_expiry


def get_user_for_session(session_id: str):
    """Resolve *session_id* to its user dict, extending the session when due.

    Everything runs inside one ``with conn:`` block:

    1. every session already expired as of now is deleted;
    2. the session row is loaded, and a NULL ``expires_at`` is backfilled
       with a fresh TTL;
    3. a session that is still expired is deleted and ``None`` is returned;
    4. if at most ``SESSION_REFRESH_WINDOW_SECONDS`` of lifetime remain, the
       expiry is pushed to a full TTL from now;
    5. the owning user row is loaded.

    ``last_seen_at`` is only written when the expiry is extended (steps 2/4),
    not on every call.

    Returns ``None`` for a falsy id, an unknown or expired session, or a
    session whose user row no longer exists. Otherwise returns the
    ``_row_to_user`` dict, which includes the stored hash under ``"password"``.
    """
    if not session_id:
        return None
    now = _now_utc()
    with db_connection() as conn:
        with conn, conn.cursor() as cur:
            # NOTE(perf): this is a table-wide DELETE on every authenticated
            # lookup. Consider moving the purge to a periodic job.
            _purge_expired_sessions(cur, now)
            cur.execute(
                """
SELECT user_id, expires_at
FROM app_session
WHERE id = %s
""",
                (session_id,),
            )
            row = cur.fetchone()
            if not row:
                return None
            user_id = row[0]
            expires_at = row[1]
            if expires_at is None:
                # Row without an expiry (e.g. legacy data): give it a full TTL.
                expires_at = _extend_session(cur, session_id, now)
            # NOTE(review): comparing against the aware `now` assumes the driver
            # returns timezone-aware datetimes (e.g. a timestamptz column) — confirm.
            # Defensive: normally unreachable, since the purge above used the same `now`.
            if expires_at <= now:
                cur.execute("DELETE FROM app_session WHERE id = %s", (session_id,))
                return None
            if (expires_at - now).total_seconds() <= SESSION_REFRESH_WINDOW_SECONDS:
                _extend_session(cur, session_id, now)
            cur.execute(
                "SELECT id, username, password_hash, role FROM app_user WHERE id = %s",
                (user_id,),
            )
            return _row_to_user(cur.fetchone())