From 466e6ce653c1912d564b45a216f9253957a4f231 Mon Sep 17 00:00:00 2001 From: Tianyu Liu Date: Mon, 22 Sep 2025 17:35:10 +0200 Subject: [PATCH] wip user reg --- backend/app.py | 34 +++++++++++--- backend/settings.py | 1 + backend/tests/test_app.py | 5 ++- backend/tests/test_crud.py | 26 +++++++++++ backend/trading_journal/crud.py | 20 +++++++++ backend/trading_journal/db.py | 13 ++++++ backend/trading_journal/service.py | 72 ++++++++++++++++++++++++++++++ 7 files changed, 163 insertions(+), 8 deletions(-) create mode 100644 backend/trading_journal/service.py diff --git a/backend/app.py b/backend/app.py index 812f896..f485bb2 100644 --- a/backend/app.py +++ b/backend/app.py @@ -2,13 +2,12 @@ import asyncio from collections.abc import AsyncGenerator from contextlib import asynccontextmanager -from fastapi import FastAPI, status +from fastapi import FastAPI, HTTPException, Request import settings -from trading_journal import db -from trading_journal.dto import TradeCreate, TradeRead - -API_BASE = "/api/v1" +from trading_journal import db, service +from trading_journal.db import Database +from trading_journal.dto import UserCreate, UserRead _db = db.create_database(settings.settings.database_url) @@ -23,8 +22,31 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: # noqa: ARG001 app = FastAPI(lifespan=lifespan) +app.add_middleware(service.AuthMiddleWare) +app.state.db_factory = _db -@app.get(f"{API_BASE}/status") +@app.get(f"{settings.settings.api_base}/status") async def get_status() -> dict[str, str]: return {"status": "ok"} + + +@app.post(f"{settings.settings.api_base}/register") +async def register_user(request: Request, user_in: UserCreate) -> UserRead: + db_factory: Database = request.app.state.db_factory + + def sync_work() -> UserRead: + with db_factory.get_session_ctx_manager() as db: + return service.register_user_service(db, user_in) + + try: + return await asyncio.to_thread(sync_work) + except service.UserAlreadyExistsError as e: + raise HTTPException(status_code=400, detail=str(e)) from e + except Exception as e: + raise HTTPException(status_code=500, detail="Internal server error" + str(e)) from e + + +@app.get(f"{settings.settings.api_base}/trades") +async def get_trades() -> dict[str, str]: + return {"trades": []} diff --git a/backend/settings.py b/backend/settings.py index 62305be..1e1e29f 100644 --- a/backend/settings.py +++ b/backend/settings.py @@ -15,6 +15,7 @@ class Settings(BaseSettings): workers: int = 1 log_level: str = "info" database_url: str = "sqlite:///:memory:" + api_base: str = "/api/v1" hmac_key: str | None = None model_config = ConfigDict(env_file=".env", env_file_encoding="utf-8") diff --git a/backend/tests/test_app.py b/backend/tests/test_app.py index d6123a6..78cf8ad 100644 --- a/backend/tests/test_app.py +++ b/backend/tests/test_app.py @@ -3,7 +3,8 @@ from collections.abc import Generator import pytest from fastapi.testclient import TestClient -from app import API_BASE, app +import settings +from app import app @pytest.fixture @@ -13,6 +14,6 @@ def client() -> Generator[TestClient, None, None]: def test_get_status(client: TestClient) -> None: - response = client.get(f"{API_BASE}/status") + response = client.get(f"{settings.settings.api_base}/status") assert response.status_code == 200 assert response.json() == {"status": "ok"} diff --git a/backend/tests/test_crud.py b/backend/tests/test_crud.py index 1ae5a55..9e0fade 100644 --- a/backend/tests/test_crud.py +++ b/backend/tests/test_crud.py @@ -651,6 +651,22 @@ def test_create_user(session: Session) -> None: assert actual_user.password_hash == user_data["password_hash"] +def test_get_user_by_id(session: Session) -> None: + user_id = make_user(session, username="fetchuser") + user = crud.get_user_by_id(session, user_id) + assert user is not None + assert user.id == user_id + assert user.username == "fetchuser" + + +def test_get_user_by_username(session: Session) -> None: + username = "uniqueuser" + make_user(session, username=username) + user = crud.get_user_by_username(session, username) + assert user is not None + assert user.username == username + + def test_update_user(session: Session) -> None: user_id = make_user(session, username="updatableuser") @@ -716,6 +732,16 @@ def test_get_login_session_by_token_and_user_id(session: Session) -> None: assert fetched_session.session_token_hash == created_session.session_token_hash +def test_get_login_session_by_token(session: Session) -> None: + now = datetime.now(timezone.utc) + created_session = make_login_session(session, now) + fetched_session = crud.get_login_session_by_token_hash(session, created_session.session_token_hash) + assert fetched_session is not None + assert fetched_session.id == created_session.id + assert fetched_session.user_id == created_session.user_id + assert fetched_session.session_token_hash == created_session.session_token_hash + + def test_update_login_session(session: Session) -> None: now = datetime.now(timezone.utc) created_session = make_login_session(session, now) diff --git a/backend/trading_journal/crud.py b/backend/trading_journal/crud.py index 9e998bd..d21e157 100644 --- a/backend/trading_journal/crud.py +++ b/backend/trading_journal/crud.py @@ -349,6 +349,17 @@ def create_user(session: Session, user_data: Mapping) -> models.Users: return u +def get_user_by_id(session: Session, user_id: int) -> models.Users | None: + return session.get(models.Users, user_id) + + +def get_user_by_username(session: Session, username: str) -> models.Users | None: + statement = select(models.Users).where( + models.Users.username == username, + ) + return session.exec(statement).first() + + def update_user(session: Session, user_id: int, update_data: Mapping) -> models.Users: user: models.Users | None = session.get(models.Users, user_id) if user is None: @@ -418,6 +429,15 @@ def get_login_session_by_token_hash_and_user_id(session: Session, session_token_ return session.exec(statement).first() +def get_login_session_by_token_hash(session: Session, session_token_hash: str) -> models.Sessions | None: + statement = select(models.Sessions).where( + models.Sessions.session_token_hash == session_token_hash, + models.Sessions.expires_at > datetime.now(timezone.utc), + ) + + return session.exec(statement).first() + + IMMUTABLE_SESSION_FIELDS = {"id", "user_id", "session_token_hash", "created_at"} diff --git a/backend/trading_journal/db.py b/backend/trading_journal/db.py index 039727c..93503f3 100644 --- a/backend/trading_journal/db.py +++ b/backend/trading_journal/db.py @@ -1,6 +1,7 @@ from __future__ import annotations import logging +from contextlib import contextmanager from typing import TYPE_CHECKING from sqlalchemy import event @@ -72,6 +73,18 @@ class Database: finally: session.close() + @contextmanager + def get_session_ctx_manager(self) -> Session: + session = Session(self._engine) + try: + yield session + session.commit() + except Exception: + session.rollback() + raise + finally: + session.close() + def dispose(self) -> None: self._engine.dispose() diff --git a/backend/trading_journal/service.py b/backend/trading_journal/service.py new file mode 100644 index 0000000..404ffb8 --- /dev/null +++ b/backend/trading_journal/service.py @@ -0,0 +1,72 @@ +from typing import Callable + +from fastapi import Request, Response, status +from fastapi.responses import JSONResponse +from sqlmodel import Session +from starlette.middleware.base import BaseHTTPMiddleware + +import settings +from trading_journal import crud, security +from trading_journal.db import Database +from trading_journal.dto import UserCreate, UserRead +from trading_journal.models import Sessions + +EXCEPT_PATHS = [ + f"{settings.settings.api_base}/status", + f"{settings.settings.api_base}/register", +] + + +class AuthMiddleWare(BaseHTTPMiddleware): + async def dispatch(self, request: Request, call_next: Callable[[Request], Response]) -> Response: + if request.url.path in EXCEPT_PATHS: + return await call_next(request) + + token = request.cookies.get("session_token") + if not token: + auth_header = request.headers.get("Authorization") + if auth_header and auth_header.startswith("Bearer "): + token = auth_header[len("Bearer ") :] + + if not token: + return JSONResponse( + status_code=status.HTTP_401_UNAUTHORIZED, + content={"detail": "Unauthorized"}, + ) + + db_factory: Database | None = getattr(request.app.state, "db_factory", None) + if db_factory is None: + return JSONResponse(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, content={"detail": "db factory not configured"}) + try: + with db_factory.get_session_ctx_manager() as request_session: + hashed_token = security.hash_session_token_sha256(token) + request.state.db_session = request_session + login_session: Sessions | None = crud.get_login_session_by_token_hash(request.state.db_session, hashed_token) + except Exception: # noqa: BLE001 + return JSONResponse(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, content={"detail": "db error"}) + + return None + + +class ServiceError(Exception): + pass + + +class UserAlreadyExistsError(ServiceError): + pass + + +def register_user_service(db_session: Session, user_in: UserCreate) -> UserRead: + if crud.get_user_by_username(db_session, user_in.username): + raise UserAlreadyExistsError("username already exists") + hashed = security.hash_password(user_in.password) + try: + user = crud.create_user(db_session, username=user_in.username, hashed_password=hashed) + try: + # prefer pydantic's from_orm if DTO supports orm_mode + user = UserRead.model_validate(user) + except Exception as e: + raise ServiceError("Failed to convert user to UserRead") from e + except Exception as e: + raise ServiceError("Failed to create user") from e + return user