193 lines
5.8 KiB
Python
193 lines
5.8 KiB
Python
from __future__ import annotations

import hashlib
import logging
import secrets
from dataclasses import dataclass
from datetime import UTC, datetime, timedelta
from functools import cache

from argon2 import PasswordHasher
from argon2.exceptions import InvalidHashError, VerificationError, VerifyMismatchError
from sqlalchemy import Select, select
from sqlalchemy.orm import Session

from app.config import Settings
from app.models.auth import AuthSession, AuthUser
|
|
|
|
# Logger named after this module's import path.
logger: logging.Logger = logging.getLogger(__name__)

# Single Argon2 hasher shared by hash_password / verify_password; constructed
# with no arguments, i.e. the argon2 library's default parameters.
password_hasher: PasswordHasher = PasswordHasher()
|
|
|
|
|
|
class AuthBootstrapError(RuntimeError):
    """Signal that auth setup cannot proceed safely.

    Raised, for example, when the user table is empty and no bootstrap
    credentials were configured.
    """
|
|
|
|
|
|
class AuthPasswordChangeError(ValueError):
    """Signal that a password-change request was rejected.

    The message is a short human-readable reason (wrong current password,
    mismatch, too short, unchanged, ...).
    """
|
|
|
|
|
|
@dataclass(slots=True)
|
|
class AuthenticatedSession:
|
|
user: AuthUser
|
|
session: AuthSession
|
|
|
|
|
|
def initialize_auth_schema(session: Session, settings: Settings) -> None:
    """Seed the auth user table with a bootstrap account when it is empty.

    Does nothing if at least one AuthUser row already exists. Otherwise creates
    an active user from ``settings.auth_bootstrap_username`` and
    ``settings.auth_bootstrap_password`` (stored via hash_password), flags it
    with ``force_password_change=True``, commits, and logs a warning reminding
    the operator to rotate the bootstrap password.

    Raises:
        AuthBootstrapError: the table is empty and either bootstrap setting is
            unset or empty.
    """
    existing_id = session.scalar(select(AuthUser.id).limit(1))
    if existing_id is not None:
        # Already initialized; existing accounts are never touched.
        return

    username = settings.auth_bootstrap_username
    password = settings.auth_bootstrap_password
    if not (username and password):
        raise AuthBootstrapError(
            "Auth DB has no users. Set AUTH_BOOTSTRAP_USERNAME and "
            "AUTH_BOOTSTRAP_PASSWORD before starting the app."
        )

    bootstrap_user = AuthUser(
        username=username,
        password_hash=hash_password(password),
        is_active=True,
        # Cleared by change_password() once the user sets their own password.
        force_password_change=True,
        created_at=_utc_now(),
    )
    session.add(bootstrap_user)
    session.commit()
    logger.warning(
        "Bootstrapped initial auth user '%s'. Rotate AUTH_BOOTSTRAP_PASSWORD after first setup.",
        bootstrap_user.username,
    )
|
|
|
|
|
|
def hash_password(password: str) -> str:
    """Hash *password* with the module's shared Argon2 PasswordHasher.

    Returns the encoded hash string suitable for storing in
    ``AuthUser.password_hash`` and later checking with verify_password().
    """
    encoded_hash = password_hasher.hash(password)
    return encoded_hash
|
|
|
|
|
|
def verify_password(password: str, stored_hash: str) -> bool:
    """Check *password* against an Argon2 *stored_hash*.

    Never raises for bad input. Returns True on a match, False otherwise.

    A plain wrong password is an expected outcome and returns False quietly.
    A missing/empty stored hash, or one argon2 cannot parse or verify, also
    returns False. That case is logged as a warning, because it points to
    corrupt credential data rather than a user typo.
    """
    if not stored_hash:
        # argon2 would crash on None (it calls .encode on the hash); treat as no match.
        logger.warning("Password verification attempted against an empty stored hash")
        return False
    try:
        return password_hasher.verify(stored_hash, password)
    except VerifyMismatchError:
        # Ordinary wrong password. Must be caught before its base class
        # VerificationError below.
        return False
    except (InvalidHashError, VerificationError) as exc:
        logger.warning(
            "Stored password hash could not be verified (%s); treating as mismatch",
            type(exc).__name__,
        )
        return False
|
|
|
|
|
|
@cache
def _dummy_password_hash() -> str:
    """Return a throwaway Argon2 hash, computed once per process, for timing equalization."""
    return hash_password(secrets.token_urlsafe(16))


def authenticate_user(session: Session, *, username: str, password: str) -> AuthUser | None:
    """Return the active user whose username and password match, else None.

    Failed attempts are logged at INFO with the submitted username.

    For unknown or inactive usernames, the submitted password is still run
    through an Argon2 verify against a dummy hash. That makes those failures
    cost roughly the same as a wrong-password failure, so response timing does
    not reveal which usernames exist or are active.
    """
    user = session.scalar(select(AuthUser).where(AuthUser.username == username).limit(1))
    if user is None or not user.is_active:
        # Result intentionally ignored; this call only spends comparable KDF time.
        verify_password(password, _dummy_password_hash())
        logger.info("Failed login for unknown or inactive user '%s'", username)
        return None

    if not verify_password(password, user.password_hash):
        logger.info("Failed login due to invalid password for user '%s'", username)
        return None

    return user
|
|
|
|
|
|
def create_session(session: Session, *, user: AuthUser, settings: Settings) -> tuple[AuthSession, str]:
    """Create, commit, and return a new login session for *user*.

    Only the SHA-256 digest of the bearer token (via _hash_token) is
    persisted. A fresh per-session CSRF token is also generated and stored.
    The session expires ``settings.auth_session_ttl_hours`` after creation.

    Returns:
        ``(auth_session, raw_token)``: the refreshed ORM row and the plaintext
        token. The caller must deliver the raw token to the client; it is not
        stored anywhere.
    """
    raw_token = secrets.token_urlsafe(32)
    # Take a single timestamp so expires_at - created_at is exactly the TTL.
    now = _utc_now()
    auth_session = AuthSession(
        user_id=user.id,
        token_hash=_hash_token(raw_token),
        csrf_token=secrets.token_urlsafe(24),
        created_at=now,
        expires_at=now + timedelta(hours=settings.auth_session_ttl_hours),
        revoked_at=None,
    )
    session.add(auth_session)
    session.commit()
    session.refresh(auth_session)
    return auth_session, raw_token
|
|
|
|
|
|
def get_authenticated_session(session: Session, *, raw_token: str | None) -> AuthenticatedSession | None:
    """Resolve a raw session token to its live session and owning user.

    The session row is looked up by the token's SHA-256 digest (_hash_token).
    Returns None when:
      * no token was supplied;
      * no session row matches;
      * the session has no ``expires_at`` (logged as a warning);
      * the session is already revoked;
      * the session has expired (it is also stamped ``revoked_at=now`` and committed);
      * the owning user is inactive.
    Otherwise returns an AuthenticatedSession.
    """
    if not raw_token:
        return None

    token_digest = _hash_token(raw_token)
    stmt: Select[tuple[AuthSession, AuthUser]] = (
        select(AuthSession, AuthUser)
        .join(AuthUser, AuthSession.user_id == AuthUser.id)
        .where(AuthSession.token_hash == token_digest)
        .limit(1)
    )
    row = session.execute(stmt).first()
    if row is None:
        return None

    auth_session, user = row
    now = _utc_now()
    expires_at = _as_utc(auth_session.expires_at)

    if expires_at is None:
        logger.warning("Auth session %s has no expires_at; treating it as invalid", auth_session.id)
        return None

    if auth_session.revoked_at is not None:
        return None

    if expires_at <= now:
        # Record the expiry as a revocation; later lookups then stop at the
        # revoked check above without writing again.
        auth_session.revoked_at = now
        session.commit()
        return None

    if not user.is_active:
        return None

    return AuthenticatedSession(user=user, session=auth_session)
|
|
|
|
|
|
def revoke_session(session: Session, *, auth_session: AuthSession) -> None:
    """Stamp *auth_session* as revoked now and commit.

    An already-revoked session is left untouched and no commit is issued.
    """
    if auth_session.revoked_at is None:
        auth_session.revoked_at = _utc_now()
        session.commit()
|
|
|
|
|
|
def change_password(
    session: Session,
    *,
    user: AuthUser,
    current_password: str,
    new_password: str,
    confirm_password: str,
    min_length: int = 8,
) -> None:
    """Replace *user*'s password after validating the request, then commit.

    On success the new password is hashed into ``user.password_hash`` and
    ``user.force_password_change`` is cleared.

    Args:
        min_length: minimum accepted length of the new password (default 8).

    Raises:
        AuthPasswordChangeError: the current password is wrong, or the new
            password is empty, does not match its confirmation, is shorter
            than *min_length*, or equals the current password.
    """
    if not verify_password(current_password, user.password_hash):
        raise AuthPasswordChangeError("current password is invalid")

    if not new_password:
        raise AuthPasswordChangeError("new password must not be empty")

    if new_password != confirm_password:
        raise AuthPasswordChangeError("new password confirmation does not match")

    if len(new_password) < min_length:
        raise AuthPasswordChangeError(f"new password must be at least {min_length} characters long")

    # current_password was verified against the stored hash above. So a
    # plaintext comparison is equivalent (barring a hash collision) to
    # re-verifying new_password against that hash, and avoids a second,
    # deliberately slow Argon2 run.
    if new_password == current_password:
        raise AuthPasswordChangeError("new password must be different from the current password")

    user.password_hash = hash_password(new_password)
    user.force_password_change = False
    # NOTE(review): other sessions of this user are not revoked here; confirm
    # whether callers are expected to do that after a password change.
    session.commit()
|
|
|
|
|
|
def issue_login_csrf_token() -> str:
    """Generate a fresh CSRF token for the login form.

    The token is 24 bytes from the ``secrets`` CSPRNG, URL-safe base64 encoded
    (32 characters).
    """
    token = secrets.token_urlsafe(24)
    return token
|
|
|
|
|
|
def validate_csrf_token(*, expected: str | None, actual: str | None) -> bool:
|
|
if not expected or not actual:
|
|
return False
|
|
return secrets.compare_digest(expected, actual)
|
|
|
|
|
|
def _hash_token(raw_token: str) -> str:
    """Return the lowercase hex SHA-256 digest of *raw_token* encoded as UTF-8.

    Session rows store and are looked up by this digest rather than the raw token.
    """
    hasher = hashlib.sha256()
    hasher.update(raw_token.encode("utf-8"))
    return hasher.hexdigest()
|
|
|
|
|
|
def _utc_now() -> datetime:
|
|
return datetime.now(UTC)
|
|
|
|
|
|
def _as_utc(value: datetime | None) -> datetime | None:
|
|
if value is None:
|
|
return None
|
|
if value.tzinfo is None:
|
|
return value.replace(tzinfo=UTC)
|
|
return value.astimezone(UTC)
|