227 lines
6.5 KiB
Python
227 lines
6.5 KiB
Python
from __future__ import annotations
|
|
|
|
import base64
|
|
import hashlib
|
|
import logging
|
|
import secrets
|
|
from dataclasses import dataclass
|
|
from datetime import UTC, datetime, timedelta
|
|
|
|
from sqlalchemy import Select, select
|
|
from sqlalchemy.orm import Session
|
|
|
|
from app.config import Settings
|
|
from app.models.auth import AuthSession, AuthUser
|
|
|
|
logger = logging.getLogger(__name__)

# Cost parameters for *newly created* password hashes. hash_password writes
# N, r and p into every stored hash string and verify_password reads them back
# from there, so changing these values does not break existing hashes.
SCRYPT_N = 1 << 14  # CPU/memory cost; scrypt requires a power of two > 1
SCRYPT_R = 8  # block size (memory use is roughly 128 * N * r bytes = 16 MiB)
SCRYPT_P = 1  # parallelization factor
SCRYPT_DKLEN = 64  # derived key length in bytes
|
|
|
|
|
|
class AuthBootstrapError(RuntimeError):
    """Signal that the auth system could not be initialized safely.

    ``initialize_auth_schema`` raises it when the user table is empty and no
    bootstrap credentials are configured.
    """
|
|
|
|
|
|
class AuthPasswordChangeError(ValueError):
    """Signal that a password-change request was rejected.

    The message names the validation that failed, for example an invalid
    current password or a new password that is too short.
    """
|
|
|
|
|
|
@dataclass(slots=True)
|
|
class AuthenticatedSession:
|
|
user: AuthUser
|
|
session: AuthSession
|
|
|
|
|
|
def initialize_auth_schema(session: Session, settings: Settings) -> None:
    """Create the bootstrap auth user when the user table is empty.

    If at least one ``AuthUser`` row exists, this returns without doing
    anything. Otherwise it reads the bootstrap credentials from ``settings``,
    inserts one active user flagged with ``force_password_change`` and commits.

    Raises:
        AuthBootstrapError: the table is empty and
            ``settings.auth_bootstrap_username`` or
            ``settings.auth_bootstrap_password`` is missing or empty.
    """
    existing_user_id = session.scalar(select(AuthUser.id).limit(1))
    if existing_user_id is not None:
        # Already initialized; never recreate or overwrite users here.
        return

    username = settings.auth_bootstrap_username
    password = settings.auth_bootstrap_password
    if not (username and password):
        raise AuthBootstrapError(
            "Auth DB has no users. Set AUTH_BOOTSTRAP_USERNAME and "
            "AUTH_BOOTSTRAP_PASSWORD before starting the app."
        )

    new_user = AuthUser(
        username=username,
        password_hash=hash_password(password),
        is_active=True,
        # The bootstrap secret comes from configuration, so the user must
        # replace it on first login.
        force_password_change=True,
        created_at=_utc_now(),
    )
    session.add(new_user)
    session.commit()
    logger.warning(
        "Bootstrapped initial auth user '%s'. Rotate AUTH_BOOTSTRAP_PASSWORD after first setup.",
        new_user.username,
    )
|
|
|
|
|
|
def hash_password(password: str) -> str:
    """Hash ``password`` with scrypt and a fresh random 16-byte salt.

    The result is a ``$``-separated string::

        scrypt$<N>$<r>$<p>$<base64 salt>$<base64 derived key>

    The cost parameters are stored next to the salt, so ``verify_password``
    can recompute the key even after the module-level defaults change.
    """
    salt = secrets.token_bytes(16)
    derived = hashlib.scrypt(
        password.encode("utf-8"),
        salt=salt,
        n=SCRYPT_N,
        r=SCRYPT_R,
        p=SCRYPT_P,
        dklen=SCRYPT_DKLEN,
    )
    salt_b64 = base64.b64encode(salt).decode("ascii")
    key_b64 = base64.b64encode(derived).decode("ascii")
    return f"scrypt${SCRYPT_N}${SCRYPT_R}${SCRYPT_P}${salt_b64}${key_b64}"
|
|
|
|
|
|
def verify_password(password: str, stored_hash: str) -> bool:
    """Check ``password`` against a string produced by ``hash_password``.

    The scrypt parameters, the salt and the key length all come from
    ``stored_hash`` itself. A malformed or non-scrypt hash string yields
    ``False`` instead of raising.
    """
    fields = stored_hash.split("$")
    if len(fields) != 6:
        return False
    scheme, n_text, r_text, p_text, salt_b64, key_b64 = fields
    if scheme != "scrypt":
        return False

    try:
        salt = base64.b64decode(salt_b64.encode("ascii"))
        expected = base64.b64decode(key_b64.encode("ascii"))
        candidate = hashlib.scrypt(
            password.encode("utf-8"),
            salt=salt,
            n=int(n_text),
            r=int(r_text),
            p=int(p_text),
            dklen=len(expected),
        )
    except (ValueError, TypeError):
        # Bad base64, non-numeric or invalid cost parameters, unencodable input.
        return False

    # Constant-time comparison, so response timing does not leak matching bytes.
    return secrets.compare_digest(candidate, expected)
|
|
|
|
|
|
def _spend_dummy_password_work(password: str) -> None:
    """Run one throwaway scrypt hash so failed lookups cost about as much as real checks."""
    try:
        hash_password(password)
    except ValueError:
        # e.g. a password with unencodable surrogates. The timing padding is
        # best-effort and must never turn a failed login into an error.
        pass


def authenticate_user(session: Session, *, username: str, password: str) -> AuthUser | None:
    """Return the active user matching ``username`` and ``password``, else ``None``.

    Each failure is logged at INFO level with the attempted username; the
    password is never logged.

    Unknown and inactive usernames still pay for one scrypt computation.
    Otherwise a missing account would answer almost instantly while an
    existing one costs a full key derivation, and response timing would
    reveal which usernames exist.
    """
    user = session.scalar(select(AuthUser).where(AuthUser.username == username).limit(1))

    if user is None or not user.is_active:
        _spend_dummy_password_work(password)
        logger.info("Failed login for unknown or inactive user '%s'", username)
        return None

    if not verify_password(password, user.password_hash):
        logger.info("Failed login due to invalid password for user '%s'", username)
        return None

    return user
|
|
|
|
|
|
def create_session(session: Session, *, user: AuthUser, settings: Settings) -> tuple[AuthSession, str]:
    """Persist a new login session for ``user`` and return it with its raw token.

    Only ``_hash_token(raw_token)`` is stored. The raw token is returned once,
    so the caller can hand it to the client. Each session also gets its own
    random CSRF token and expires ``settings.auth_session_ttl_hours`` hours
    after creation.

    Returns:
        ``(auth_session, raw_token)``. ``auth_session`` has already been
        committed and refreshed from the database.
    """
    raw_token = secrets.token_urlsafe(32)
    # Read the clock once, so expires_at is exactly created_at + TTL instead of
    # drifting by the gap between two separate clock reads.
    now = _utc_now()
    auth_session = AuthSession(
        user_id=user.id,
        token_hash=_hash_token(raw_token),
        csrf_token=secrets.token_urlsafe(24),
        created_at=now,
        expires_at=now + timedelta(hours=settings.auth_session_ttl_hours),
        revoked_at=None,
    )
    session.add(auth_session)
    session.commit()
    session.refresh(auth_session)
    return auth_session, raw_token
|
|
|
|
|
|
def get_authenticated_session(session: Session, *, raw_token: str | None) -> AuthenticatedSession | None:
    """Resolve a raw session token to its live session and owning user.

    Returns ``None`` in any of these cases:

    - no token was given;
    - no session matches the token's hash;
    - the session is revoked;
    - the session is past ``expires_at``;
    - the owning user is inactive.

    Side effect: a session found to be expired but not yet revoked gets
    ``revoked_at`` set to the current time, and the change is committed.
    """
    if not raw_token:
        return None

    token_hash = _hash_token(raw_token)
    stmt: Select[tuple[AuthSession, AuthUser]] = (
        select(AuthSession, AuthUser)
        .join(AuthUser, AuthSession.user_id == AuthUser.id)
        .where(AuthSession.token_hash == token_hash)
        .limit(1)
    )
    row = session.execute(stmt).first()
    if row is None:
        return None
    auth_session, user = row

    now = _utc_now()
    expires_at = _as_utc(auth_session.expires_at)
    revoked_at = _as_utc(auth_session.revoked_at)

    if revoked_at is not None:
        return None
    if expires_at <= now:
        # Persist the expiry, so later lookups see the session as revoked.
        auth_session.revoked_at = now
        session.commit()
        return None
    if not user.is_active:
        return None

    return AuthenticatedSession(user=user, session=auth_session)
|
|
|
|
|
|
def revoke_session(session: Session, *, auth_session: AuthSession) -> None:
    """Mark ``auth_session`` as revoked now and commit.

    An already-revoked session is left untouched: it keeps its original
    ``revoked_at`` and no commit is issued. Calling this twice is therefore
    harmless.
    """
    if auth_session.revoked_at is None:
        auth_session.revoked_at = _utc_now()
        session.commit()
|
|
|
|
|
|
def change_password(
    session: Session,
    *,
    user: AuthUser,
    current_password: str,
    new_password: str,
    confirm_password: str,
) -> None:
    """Replace ``user``'s password after validating the request, then commit.

    The checks run in this order, and the first failure raises
    ``AuthPasswordChangeError``:

    1. ``current_password`` must match the stored hash.
    2. ``new_password`` must be non-empty.
    3. ``confirm_password`` must equal ``new_password``.
    4. ``new_password`` must be at least 8 characters long.
    5. ``new_password`` must differ from the current password.

    On success the new hash is stored, ``force_password_change`` is cleared,
    and the session is committed.
    """
    if not verify_password(current_password, user.password_hash):
        raise AuthPasswordChangeError("current password is invalid")

    if not new_password:
        raise AuthPasswordChangeError("new password must not be empty")

    if new_password != confirm_password:
        raise AuthPasswordChangeError("new password confirmation does not match")

    if len(new_password) < 8:
        raise AuthPasswordChangeError("new password must be at least 8 characters long")

    # current_password was just verified against the stored hash, so comparing
    # the plaintexts answers "is this the current password?" without a second,
    # deliberately slow scrypt run.
    if new_password == current_password:
        raise AuthPasswordChangeError("new password must be different from the current password")

    user.password_hash = hash_password(new_password)
    user.force_password_change = False
    session.commit()
|
|
|
|
|
|
def issue_login_csrf_token() -> str:
    """Generate a fresh CSRF token for the login flow.

    The token encodes 24 random bytes as 32 URL-safe base64 characters. This
    is the same size ``create_session`` uses for per-session CSRF tokens.
    """
    token = secrets.token_urlsafe(24)
    return token
|
|
|
|
|
|
def validate_csrf_token(*, expected: str | None, actual: str | None) -> bool:
|
|
if not expected or not actual:
|
|
return False
|
|
return secrets.compare_digest(expected, actual)
|
|
|
|
|
|
def _hash_token(raw_token: str) -> str:
|
|
return hashlib.sha256(raw_token.encode("utf-8")).hexdigest()
|
|
|
|
|
|
def _utc_now() -> datetime:
|
|
return datetime.now(UTC)
|
|
|
|
|
|
def _as_utc(value: datetime | None) -> datetime | None:
|
|
if value is None:
|
|
return None
|
|
if value.tzinfo is None:
|
|
return value.replace(tzinfo=UTC)
|
|
return value.astimezone(UTC)
|