Add auth foundation and app DB management
This commit is contained in:
@@ -0,0 +1,226 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import hashlib
|
||||
import logging
|
||||
import secrets
|
||||
from dataclasses import dataclass
|
||||
from datetime import UTC, datetime, timedelta
|
||||
|
||||
from sqlalchemy import Select, select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.config import Settings
|
||||
from app.models.auth import AuthSession, AuthUser
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
SCRYPT_N = 2**14
|
||||
SCRYPT_R = 8
|
||||
SCRYPT_P = 1
|
||||
SCRYPT_DKLEN = 64
|
||||
|
||||
|
||||
class AuthBootstrapError(RuntimeError):
|
||||
"""Raised when the auth system cannot be safely initialized."""
|
||||
|
||||
|
||||
class AuthPasswordChangeError(ValueError):
|
||||
"""Raised when a password change request is invalid."""
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class AuthenticatedSession:
|
||||
user: AuthUser
|
||||
session: AuthSession
|
||||
|
||||
|
||||
def initialize_auth_schema(session: Session, settings: Settings) -> None:
|
||||
has_any_user = session.scalar(select(AuthUser.id).limit(1)) is not None
|
||||
if has_any_user:
|
||||
return
|
||||
|
||||
if not settings.auth_bootstrap_username or not settings.auth_bootstrap_password:
|
||||
raise AuthBootstrapError(
|
||||
"Auth DB has no users. Set AUTH_BOOTSTRAP_USERNAME and "
|
||||
"AUTH_BOOTSTRAP_PASSWORD before starting the app."
|
||||
)
|
||||
|
||||
bootstrap_user = AuthUser(
|
||||
username=settings.auth_bootstrap_username,
|
||||
password_hash=hash_password(settings.auth_bootstrap_password),
|
||||
is_active=True,
|
||||
force_password_change=True,
|
||||
created_at=_utc_now(),
|
||||
)
|
||||
session.add(bootstrap_user)
|
||||
session.commit()
|
||||
logger.warning(
|
||||
"Bootstrapped initial auth user '%s'. Rotate AUTH_BOOTSTRAP_PASSWORD after first setup.",
|
||||
bootstrap_user.username,
|
||||
)
|
||||
|
||||
|
||||
def hash_password(password: str) -> str:
|
||||
salt = secrets.token_bytes(16)
|
||||
derived_key = hashlib.scrypt(
|
||||
password.encode("utf-8"),
|
||||
salt=salt,
|
||||
n=SCRYPT_N,
|
||||
r=SCRYPT_R,
|
||||
p=SCRYPT_P,
|
||||
dklen=SCRYPT_DKLEN,
|
||||
)
|
||||
return "$".join(
|
||||
[
|
||||
"scrypt",
|
||||
str(SCRYPT_N),
|
||||
str(SCRYPT_R),
|
||||
str(SCRYPT_P),
|
||||
base64.b64encode(salt).decode("ascii"),
|
||||
base64.b64encode(derived_key).decode("ascii"),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def verify_password(password: str, stored_hash: str) -> bool:
|
||||
try:
|
||||
algorithm, n, r, p, encoded_salt, encoded_key = stored_hash.split("$")
|
||||
except ValueError:
|
||||
return False
|
||||
|
||||
if algorithm != "scrypt":
|
||||
return False
|
||||
|
||||
try:
|
||||
salt = base64.b64decode(encoded_salt.encode("ascii"))
|
||||
expected_key = base64.b64decode(encoded_key.encode("ascii"))
|
||||
derived_key = hashlib.scrypt(
|
||||
password.encode("utf-8"),
|
||||
salt=salt,
|
||||
n=int(n),
|
||||
r=int(r),
|
||||
p=int(p),
|
||||
dklen=len(expected_key),
|
||||
)
|
||||
except (ValueError, TypeError):
|
||||
return False
|
||||
|
||||
return secrets.compare_digest(derived_key, expected_key)
|
||||
|
||||
|
||||
def authenticate_user(session: Session, *, username: str, password: str) -> AuthUser | None:
|
||||
user = session.scalar(select(AuthUser).where(AuthUser.username == username).limit(1))
|
||||
if user is None or not user.is_active:
|
||||
logger.info("Failed login for unknown or inactive user '%s'", username)
|
||||
return None
|
||||
|
||||
if not verify_password(password, user.password_hash):
|
||||
logger.info("Failed login due to invalid password for user '%s'", username)
|
||||
return None
|
||||
|
||||
return user
|
||||
|
||||
|
||||
def create_session(session: Session, *, user: AuthUser, settings: Settings) -> tuple[AuthSession, str]:
|
||||
raw_token = secrets.token_urlsafe(32)
|
||||
auth_session = AuthSession(
|
||||
user_id=user.id,
|
||||
token_hash=_hash_token(raw_token),
|
||||
csrf_token=secrets.token_urlsafe(24),
|
||||
created_at=_utc_now(),
|
||||
expires_at=_utc_now() + timedelta(hours=settings.auth_session_ttl_hours),
|
||||
revoked_at=None,
|
||||
)
|
||||
session.add(auth_session)
|
||||
session.commit()
|
||||
session.refresh(auth_session)
|
||||
return auth_session, raw_token
|
||||
|
||||
|
||||
def get_authenticated_session(session: Session, *, raw_token: str | None) -> AuthenticatedSession | None:
|
||||
if not raw_token:
|
||||
return None
|
||||
|
||||
stmt: Select[tuple[AuthSession, AuthUser]] = (
|
||||
select(AuthSession, AuthUser)
|
||||
.join(AuthUser, AuthSession.user_id == AuthUser.id)
|
||||
.where(AuthSession.token_hash == _hash_token(raw_token))
|
||||
.limit(1)
|
||||
)
|
||||
result = session.execute(stmt).first()
|
||||
if result is None:
|
||||
return None
|
||||
|
||||
auth_session, user = result
|
||||
now = _utc_now()
|
||||
expires_at = _as_utc(auth_session.expires_at)
|
||||
revoked_at = _as_utc(auth_session.revoked_at)
|
||||
if revoked_at is not None or expires_at <= now or not user.is_active:
|
||||
if revoked_at is None and expires_at <= now:
|
||||
auth_session.revoked_at = now
|
||||
session.commit()
|
||||
return None
|
||||
|
||||
return AuthenticatedSession(user=user, session=auth_session)
|
||||
|
||||
|
||||
def revoke_session(session: Session, *, auth_session: AuthSession) -> None:
|
||||
if auth_session.revoked_at is not None:
|
||||
return
|
||||
auth_session.revoked_at = _utc_now()
|
||||
session.commit()
|
||||
|
||||
|
||||
def change_password(
|
||||
session: Session,
|
||||
*,
|
||||
user: AuthUser,
|
||||
current_password: str,
|
||||
new_password: str,
|
||||
confirm_password: str,
|
||||
) -> None:
|
||||
if not verify_password(current_password, user.password_hash):
|
||||
raise AuthPasswordChangeError("current password is invalid")
|
||||
|
||||
if not new_password:
|
||||
raise AuthPasswordChangeError("new password must not be empty")
|
||||
|
||||
if new_password != confirm_password:
|
||||
raise AuthPasswordChangeError("new password confirmation does not match")
|
||||
|
||||
if len(new_password) < 8:
|
||||
raise AuthPasswordChangeError("new password must be at least 8 characters long")
|
||||
|
||||
if verify_password(new_password, user.password_hash):
|
||||
raise AuthPasswordChangeError("new password must be different from the current password")
|
||||
|
||||
user.password_hash = hash_password(new_password)
|
||||
user.force_password_change = False
|
||||
session.commit()
|
||||
|
||||
|
||||
def issue_login_csrf_token() -> str:
|
||||
return secrets.token_urlsafe(24)
|
||||
|
||||
|
||||
def validate_csrf_token(*, expected: str | None, actual: str | None) -> bool:
|
||||
if not expected or not actual:
|
||||
return False
|
||||
return secrets.compare_digest(expected, actual)
|
||||
|
||||
|
||||
def _hash_token(raw_token: str) -> str:
|
||||
return hashlib.sha256(raw_token.encode("utf-8")).hexdigest()
|
||||
|
||||
|
||||
def _utc_now() -> datetime:
|
||||
return datetime.now(UTC)
|
||||
|
||||
|
||||
def _as_utc(value: datetime | None) -> datetime | None:
|
||||
if value is None:
|
||||
return None
|
||||
if value.tzinfo is None:
|
||||
return value.replace(tzinfo=UTC)
|
||||
return value.astimezone(UTC)
|
||||
Reference in New Issue
Block a user