from __future__ import annotations import hashlib import logging import secrets from dataclasses import dataclass from datetime import UTC, datetime, timedelta from argon2 import PasswordHasher from argon2.exceptions import InvalidHashError, VerificationError, VerifyMismatchError from sqlalchemy import Select, select from sqlalchemy.orm import Session from app.config import Settings from app.models.auth import AuthSession, AuthUser logger = logging.getLogger(__name__) password_hasher = PasswordHasher() class AuthBootstrapError(RuntimeError): """Raised when the auth system cannot be safely initialized.""" class AuthPasswordChangeError(ValueError): """Raised when a password change request is invalid.""" @dataclass(slots=True) class AuthenticatedSession: user: AuthUser session: AuthSession def initialize_auth_schema(session: Session, settings: Settings) -> None: has_any_user = session.scalar(select(AuthUser.id).limit(1)) is not None if has_any_user: return if not settings.auth_bootstrap_username or not settings.auth_bootstrap_password: raise AuthBootstrapError( "Auth DB has no users. Set AUTH_BOOTSTRAP_USERNAME and " "AUTH_BOOTSTRAP_PASSWORD before starting the app." ) bootstrap_user = AuthUser( username=settings.auth_bootstrap_username, password_hash=hash_password(settings.auth_bootstrap_password), is_active=True, force_password_change=True, created_at=_utc_now(), ) session.add(bootstrap_user) session.commit() logger.warning( "Bootstrapped initial auth user '%s'. Rotate AUTH_BOOTSTRAP_PASSWORD after first setup.", bootstrap_user.username, ) def hash_password(password: str) -> str: return password_hasher.hash(password) def verify_password(password: str, stored_hash: str) -> bool: try: return password_hasher.verify(stored_hash, password) except VerifyMismatchError: return False except (InvalidHashError, VerificationError): return False def authenticate_user(session: Session, *, username: str, password: str) -> AuthUser | None: user = session.scalar(select(AuthUser).where(AuthUser.username == username).limit(1)) if user is None or not user.is_active: logger.info("Failed login for unknown or inactive user '%s'", username) return None if not verify_password(password, user.password_hash): logger.info("Failed login due to invalid password for user '%s'", username) return None return user def create_session(session: Session, *, user: AuthUser, settings: Settings) -> tuple[AuthSession, str]: raw_token = secrets.token_urlsafe(32) auth_session = AuthSession( user_id=user.id, token_hash=_hash_token(raw_token), csrf_token=secrets.token_urlsafe(24), created_at=_utc_now(), expires_at=_utc_now() + timedelta(hours=settings.auth_session_ttl_hours), revoked_at=None, ) session.add(auth_session) session.commit() session.refresh(auth_session) return auth_session, raw_token def get_authenticated_session(session: Session, *, raw_token: str | None) -> AuthenticatedSession | None: if not raw_token: return None stmt: Select[tuple[AuthSession, AuthUser]] = ( select(AuthSession, AuthUser) .join(AuthUser, AuthSession.user_id == AuthUser.id) .where(AuthSession.token_hash == _hash_token(raw_token)) .limit(1) ) result = session.execute(stmt).first() if result is None: return None auth_session, user = result now = _utc_now() expires_at = _as_utc(auth_session.expires_at) revoked_at = _as_utc(auth_session.revoked_at) if expires_at is None: logger.warning("Auth session %s has no expires_at; treating it as invalid", auth_session.id) return None if revoked_at is not None or expires_at <= now or not user.is_active: if revoked_at is None and expires_at <= now: auth_session.revoked_at = now session.commit() return None return AuthenticatedSession(user=user, session=auth_session) def revoke_session(session: Session, *, auth_session: AuthSession) -> None: if auth_session.revoked_at is not None: return auth_session.revoked_at = _utc_now() session.commit() def change_password( session: Session, *, user: AuthUser, current_password: str, new_password: str, confirm_password: str, ) -> None: if not verify_password(current_password, user.password_hash): raise AuthPasswordChangeError("current password is invalid") if not new_password: raise AuthPasswordChangeError("new password must not be empty") if new_password != confirm_password: raise AuthPasswordChangeError("new password confirmation does not match") if len(new_password) < 8: raise AuthPasswordChangeError("new password must be at least 8 characters long") if verify_password(new_password, user.password_hash): raise AuthPasswordChangeError("new password must be different from the current password") user.password_hash = hash_password(new_password) user.force_password_change = False session.commit() def issue_login_csrf_token() -> str: return secrets.token_urlsafe(24) def validate_csrf_token(*, expected: str | None, actual: str | None) -> bool: if not expected or not actual: return False return secrets.compare_digest(expected, actual) def _hash_token(raw_token: str) -> str: return hashlib.sha256(raw_token.encode("utf-8")).hexdigest() def _utc_now() -> datetime: return datetime.now(UTC) def _as_utc(value: datetime | None) -> datetime | None: if value is None: return None if value.tzinfo is None: return value.replace(tzinfo=UTC) return value.astimezone(UTC)