from __future__ import annotations

import logging
from datetime import datetime, timedelta, timezone
from typing import Awaitable, Callable

from fastapi import Request, Response, status
from fastapi.responses import JSONResponse
from sqlmodel import Session
from starlette.middleware.base import BaseHTTPMiddleware

import settings
from trading_journal import crud, security
from trading_journal.db import Database
from trading_journal.dto import SessionsCreate, SessionsUpdate, UserCreate, UserLogin, UserRead
from trading_journal.models import Sessions

SessionsCreate.model_rebuild()

# Request paths that AuthMiddleWare lets through without a session token
# (the status, register and login endpoints under the configured API base).
EXCEPT_PATHS = [
    f"{settings.settings.api_base}/status",
    f"{settings.settings.api_base}/register",
    f"{settings.settings.api_base}/login",
]

# When a valid session has less than this much lifetime left, AuthMiddleWare pushes its
# expiry forward to now + settings.settings.session_expiry_seconds (sliding expiration).
SESSION_REFRESH_THRESHOLD = timedelta(seconds=3600)

logger = logging.getLogger(__name__)


def _unauthorized() -> JSONResponse:
    """Build the uniform 401 response used for every authentication failure."""
    return JSONResponse(status_code=status.HTTP_401_UNAUTHORIZED, content={"detail": "Unauthorized"})


def _extract_token(request: Request) -> str | None:
    """Return the session token from the ``session_token`` cookie, falling back to a Bearer header.

    Returns ``None`` when neither source provides a token. An ``Authorization: Bearer `` header
    with nothing after the prefix yields an empty string, which callers treat as missing.
    """
    token = request.cookies.get("session_token")
    if token:
        return token
    auth_header = request.headers.get("Authorization")
    if auth_header and auth_header.startswith("Bearer "):
        return auth_header[len("Bearer ") :]
    return None


def _as_utc(value: datetime) -> datetime:
    """Return *value* as a timezone-aware UTC datetime.

    Naive values are assumed to already be in UTC (the previous behavior). Aware values are
    converted rather than relabelled, so a non-UTC offset is not silently discarded.
    """
    if value.tzinfo is None:
        return value.replace(tzinfo=timezone.utc)
    return value.astimezone(timezone.utc)


class AuthMiddleWare(BaseHTTPMiddleware):
    """Reject requests that lack a valid, unexpired session for an active user.

    Paths listed in ``EXCEPT_PATHS`` are passed straight through. For every other request the
    token is hashed and looked up. Expired sessions are deleted. On success, ``request.state.user_id``
    and ``request.state.db_session`` are set, the session's last-seen metadata is updated, and its
    expiry is extended when fewer than ``SESSION_REFRESH_THRESHOLD`` remain.
    """

    async def dispatch(self, request: Request, call_next: Callable[[Request], Awaitable[Response]]) -> Response:  # noqa: PLR0911
        """Authenticate *request*, then hand it to *call_next*, or return an error response.

        Returns:
            - 401 when the token is missing, unknown or expired, or the user is inactive.
            - 500 when no ``db_factory`` is configured on ``app.state``, or when any error
              occurs during the lookup/update.
            - Otherwise, the downstream response.
        """
        if request.url.path in EXCEPT_PATHS:
            return await call_next(request)

        token = _extract_token(request)
        if not token:
            return _unauthorized()

        db_factory: Database | None = getattr(request.app.state, "db_factory", None)
        if db_factory is None:
            return JSONResponse(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, content={"detail": "db factory not configured"})

        try:
            # NOTE(review): request.state.db_session is assigned inside this `with`, but the block
            # exits before `call_next` runs below. Downstream handlers that read it therefore get a
            # session whose context-manager exit logic has already run. Confirm this is intended.
            with db_factory.get_session_ctx_manager() as request_session:
                hashed_token = security.hash_session_token_sha256(token)
                request.state.db_session = request_session
                login_session: Sessions | None = crud.get_login_session_by_token_hash(request_session, hashed_token)
                if not login_session:
                    return _unauthorized()

                # Read the clock once so the expiry check, refresh decision and timestamps agree.
                now = datetime.now(timezone.utc)
                session_expires_utc = _as_utc(login_session.expires_at)
                if session_expires_utc < now:
                    crud.delete_login_session(request_session, login_session)
                    return _unauthorized()
                if login_session.user.is_active is False:
                    return _unauthorized()

                # Sliding expiration: only extend sessions that are close to expiring.
                if session_expires_utc - now < SESSION_REFRESH_THRESHOLD:
                    updated_expiry = now + timedelta(seconds=settings.settings.session_expiry_seconds)
                else:
                    updated_expiry = session_expires_utc

                updated_session = SessionsUpdate(
                    last_seen_at=now,
                    last_used_ip=request.client.host if request.client else None,
                    user_agent=request.headers.get("User-Agent"),
                    expires_at=updated_expiry,
                )
                request.state.user_id = login_session.user_id
                crud.update_login_session(request_session, hashed_token, update_session=updated_session)
        except Exception:
            logger.exception("Failed to authenticate user: \n")
            return JSONResponse(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, content={"detail": "Internal server error"})

        return await call_next(request)


class ServiceError(Exception):
    """Base error raised by the service functions in this module."""


class UserAlreadyExistsError(ServiceError):
    """Raised by register_user_service when the requested username is already taken."""


def register_user_service(db_session: Session, user_in: UserCreate) -> UserRead:
    """Create a new user with a hashed password and return it as a ``UserRead`` DTO.

    Raises:
        UserAlreadyExistsError: a user with ``user_in.username`` already exists.
        ServiceError: creating the user, or converting the created row to ``UserRead``, failed.
            The message says which step failed.
    """
    if crud.get_user_by_username(db_session, user_in.username):
        raise UserAlreadyExistsError("username already exists")

    user_data: dict = {
        "username": user_in.username,
        "password_hash": security.hash_password(user_in.password),
    }

    # Creation and conversion are guarded separately so a conversion failure is reported
    # (and logged) once, under its own message, instead of being re-wrapped as a create failure.
    try:
        user = crud.create_user(db_session, user_data=user_data)
    except Exception as e:
        logger.exception("Failed to create user:")
        raise ServiceError("Failed to create user") from e

    try:
        # Pydantic v2 validation from the ORM object; presumably UserRead enables
        # from_attributes — confirm in trading_journal.dto.
        return UserRead.model_validate(user)
    except Exception as e:
        # NOTE(review): the user row was already created at this point; whether it persists
        # depends on crud.create_user / the session's transaction handling.
        logger.exception("Failed to convert user to UserRead: %s", e)
        raise ServiceError("Failed to convert user to UserRead") from e


def authenticate_user_service(db_session: Session, user_in: UserLogin) -> tuple[SessionsCreate, str] | None:
    """Verify credentials and, on success, open a new login session.

    Returns:
        ``None`` if the username is unknown or the password does not verify. Otherwise a tuple
        of the created session as a ``SessionsCreate`` DTO and the raw session token. Only the
        token's hash (from ``security.hash_session_token_sha256``) is passed to crud for storage.

    Raises:
        ServiceError: the login session could not be created.
    """
    user = crud.get_user_by_username(db_session, user_in.username)
    # NOTE(review): unknown usernames return before verify_password runs, so response timing may
    # reveal whether a username exists. Consider verifying against a dummy hash on this path.
    if not user:
        return None
    if not security.verify_password(user_in.password, user.password_hash):
        return None
    # NOTE(review): user.is_active is not checked here, while AuthMiddleWare rejects inactive users,
    # so an inactive user can obtain a session that every later request refuses. Confirm intended.

    token = security.generate_session_token()
    token_hashed = security.hash_session_token_sha256(token)
    try:
        session = crud.create_login_session(
            session=db_session,
            user_id=user.id,
            session_token_hash=token_hashed,
            session_length_seconds=settings.settings.session_expiry_seconds,
        )
    except Exception as e:
        logger.exception("Failed to create login session: \n")
        raise ServiceError("Failed to create login session") from e
    return SessionsCreate.model_validate(session), token


def get_trades_service(db_session: Session, user_id: int) -> list:
    """Return the trades belonging to *user_id*, as provided by ``crud.get_trades_by_user_id``."""
    return crud.get_trades_by_user_id(db_session, user_id)