from __future__ import annotations import base64 import hashlib import logging import secrets from dataclasses import dataclass from datetime import UTC, datetime, timedelta from sqlalchemy import Select, select from sqlalchemy.orm import Session from app.config import Settings from app.models.auth import AuthSession, AuthUser logger = logging.getLogger(__name__) SCRYPT_N = 2**14 SCRYPT_R = 8 SCRYPT_P = 1 SCRYPT_DKLEN = 64 class AuthBootstrapError(RuntimeError): """Raised when the auth system cannot be safely initialized.""" class AuthPasswordChangeError(ValueError): """Raised when a password change request is invalid.""" @dataclass(slots=True) class AuthenticatedSession: user: AuthUser session: AuthSession def initialize_auth_schema(session: Session, settings: Settings) -> None: has_any_user = session.scalar(select(AuthUser.id).limit(1)) is not None if has_any_user: return if not settings.auth_bootstrap_username or not settings.auth_bootstrap_password: raise AuthBootstrapError( "Auth DB has no users. Set AUTH_BOOTSTRAP_USERNAME and " "AUTH_BOOTSTRAP_PASSWORD before starting the app." ) bootstrap_user = AuthUser( username=settings.auth_bootstrap_username, password_hash=hash_password(settings.auth_bootstrap_password), is_active=True, force_password_change=True, created_at=_utc_now(), ) session.add(bootstrap_user) session.commit() logger.warning( "Bootstrapped initial auth user '%s'. Rotate AUTH_BOOTSTRAP_PASSWORD after first setup.", bootstrap_user.username, ) def hash_password(password: str) -> str: salt = secrets.token_bytes(16) derived_key = hashlib.scrypt( password.encode("utf-8"), salt=salt, n=SCRYPT_N, r=SCRYPT_R, p=SCRYPT_P, dklen=SCRYPT_DKLEN, ) return "$".join( [ "scrypt", str(SCRYPT_N), str(SCRYPT_R), str(SCRYPT_P), base64.b64encode(salt).decode("ascii"), base64.b64encode(derived_key).decode("ascii"), ] ) def verify_password(password: str, stored_hash: str) -> bool: try: algorithm, n, r, p, encoded_salt, encoded_key = stored_hash.split("$") except ValueError: return False if algorithm != "scrypt": return False try: salt = base64.b64decode(encoded_salt.encode("ascii")) expected_key = base64.b64decode(encoded_key.encode("ascii")) derived_key = hashlib.scrypt( password.encode("utf-8"), salt=salt, n=int(n), r=int(r), p=int(p), dklen=len(expected_key), ) except (ValueError, TypeError): return False return secrets.compare_digest(derived_key, expected_key) def authenticate_user(session: Session, *, username: str, password: str) -> AuthUser | None: user = session.scalar(select(AuthUser).where(AuthUser.username == username).limit(1)) if user is None or not user.is_active: logger.info("Failed login for unknown or inactive user '%s'", username) return None if not verify_password(password, user.password_hash): logger.info("Failed login due to invalid password for user '%s'", username) return None return user def create_session(session: Session, *, user: AuthUser, settings: Settings) -> tuple[AuthSession, str]: raw_token = secrets.token_urlsafe(32) auth_session = AuthSession( user_id=user.id, token_hash=_hash_token(raw_token), csrf_token=secrets.token_urlsafe(24), created_at=_utc_now(), expires_at=_utc_now() + timedelta(hours=settings.auth_session_ttl_hours), revoked_at=None, ) session.add(auth_session) session.commit() session.refresh(auth_session) return auth_session, raw_token def get_authenticated_session(session: Session, *, raw_token: str | None) -> AuthenticatedSession | None: if not raw_token: return None stmt: Select[tuple[AuthSession, AuthUser]] = ( select(AuthSession, AuthUser) .join(AuthUser, AuthSession.user_id == AuthUser.id) .where(AuthSession.token_hash == _hash_token(raw_token)) .limit(1) ) result = session.execute(stmt).first() if result is None: return None auth_session, user = result now = _utc_now() expires_at = _as_utc(auth_session.expires_at) revoked_at = _as_utc(auth_session.revoked_at) if revoked_at is not None or expires_at <= now or not user.is_active: if revoked_at is None and expires_at <= now: auth_session.revoked_at = now session.commit() return None return AuthenticatedSession(user=user, session=auth_session) def revoke_session(session: Session, *, auth_session: AuthSession) -> None: if auth_session.revoked_at is not None: return auth_session.revoked_at = _utc_now() session.commit() def change_password( session: Session, *, user: AuthUser, current_password: str, new_password: str, confirm_password: str, ) -> None: if not verify_password(current_password, user.password_hash): raise AuthPasswordChangeError("current password is invalid") if not new_password: raise AuthPasswordChangeError("new password must not be empty") if new_password != confirm_password: raise AuthPasswordChangeError("new password confirmation does not match") if len(new_password) < 8: raise AuthPasswordChangeError("new password must be at least 8 characters long") if verify_password(new_password, user.password_hash): raise AuthPasswordChangeError("new password must be different from the current password") user.password_hash = hash_password(new_password) user.force_password_change = False session.commit() def issue_login_csrf_token() -> str: return secrets.token_urlsafe(24) def validate_csrf_token(*, expected: str | None, actual: str | None) -> bool: if not expected or not actual: return False return secrets.compare_digest(expected, actual) def _hash_token(raw_token: str) -> str: return hashlib.sha256(raw_token.encode("utf-8")).hexdigest() def _utc_now() -> datetime: return datetime.now(UTC) def _as_utc(value: datetime | None) -> datetime | None: if value is None: return None if value.tzinfo is None: return value.replace(tzinfo=UTC) return value.astimezone(UTC)