wip user reg
All checks were successful
Backend CI / unit-test (push) Successful in 34s

This commit is contained in:
2025-09-22 17:35:10 +02:00
parent e70a63e4f9
commit 466e6ce653
7 changed files with 163 additions and 8 deletions

View File

@@ -2,13 +2,12 @@ import asyncio
from collections.abc import AsyncGenerator from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from fastapi import FastAPI, status from fastapi import FastAPI, HTTPException, Request
import settings import settings
from trading_journal import db from trading_journal import db, service
from trading_journal.dto import TradeCreate, TradeRead from trading_journal.db import Database
from trading_journal.dto import UserCreate, UserRead
API_BASE = "/api/v1"
_db = db.create_database(settings.settings.database_url) _db = db.create_database(settings.settings.database_url)
@@ -23,8 +22,31 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: # noqa: ARG001
app = FastAPI(lifespan=lifespan) app = FastAPI(lifespan=lifespan)
app.add_middleware(service.AuthMiddleWare)
app.state.db_factory = _db
@app.get(f"{API_BASE}/status") @app.get(f"{settings.settings.api_base}/status")
async def get_status() -> dict[str, str]: async def get_status() -> dict[str, str]:
return {"status": "ok"} return {"status": "ok"}
@app.post(f"{settings.settings.api_base}/register")
async def register_user(request: Request, user_in: UserCreate) -> UserRead:
db_factory: Database = request.app.state.db_factory
def sync_work() -> UserRead:
with db_factory.get_session_ctx_manager() as db:
return service.register_user_service(db, user_in)
try:
return await asyncio.to_thread(sync_work)
except service.UserAlreadyExistsError as e:
raise HTTPException(status_code=400, detail=str(e)) from e
except Exception as e:
raise HTTPException(status_code=500, detail="Internal server error" + str(e)) from e
@app.get(f"{settings.settings.api_base}/trades")
async def get_trades() -> dict[str, str]:
return {"trades": []}

View File

@@ -15,6 +15,7 @@ class Settings(BaseSettings):
workers: int = 1 workers: int = 1
log_level: str = "info" log_level: str = "info"
database_url: str = "sqlite:///:memory:" database_url: str = "sqlite:///:memory:"
api_base: str = "/api/v1"
hmac_key: str | None = None hmac_key: str | None = None
model_config = ConfigDict(env_file=".env", env_file_encoding="utf-8") model_config = ConfigDict(env_file=".env", env_file_encoding="utf-8")

View File

@@ -3,7 +3,8 @@ from collections.abc import Generator
import pytest import pytest
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
from app import API_BASE, app import settings
from app import app
@pytest.fixture @pytest.fixture
@@ -13,6 +14,6 @@ def client() -> Generator[TestClient, None, None]:
def test_get_status(client: TestClient) -> None: def test_get_status(client: TestClient) -> None:
response = client.get(f"{API_BASE}/status") response = client.get(f"{settings.settings.api_base}/status")
assert response.status_code == 200 assert response.status_code == 200
assert response.json() == {"status": "ok"} assert response.json() == {"status": "ok"}

View File

@@ -651,6 +651,22 @@ def test_create_user(session: Session) -> None:
assert actual_user.password_hash == user_data["password_hash"] assert actual_user.password_hash == user_data["password_hash"]
def test_get_user_by_id(session: Session) -> None:
user_id = make_user(session, username="fetchuser")
user = crud.get_user_by_id(session, user_id)
assert user is not None
assert user.id == user_id
assert user.username == "fetchuser"
def test_get_user_by_username(session: Session) -> None:
username = "uniqueuser"
make_user(session, username=username)
user = crud.get_user_by_username(session, username)
assert user is not None
assert user.username == username
def test_update_user(session: Session) -> None: def test_update_user(session: Session) -> None:
user_id = make_user(session, username="updatableuser") user_id = make_user(session, username="updatableuser")
@@ -716,6 +732,16 @@ def test_get_login_session_by_token_and_user_id(session: Session) -> None:
assert fetched_session.session_token_hash == created_session.session_token_hash assert fetched_session.session_token_hash == created_session.session_token_hash
def test_get_login_session_by_token(session: Session) -> None:
now = datetime.now(timezone.utc)
created_session = make_login_session(session, now)
fetched_session = crud.get_login_session_by_token_hash(session, created_session.session_token_hash)
assert fetched_session is not None
assert fetched_session.id == created_session.id
assert fetched_session.user_id == created_session.user_id
assert fetched_session.session_token_hash == created_session.session_token_hash
def test_update_login_session(session: Session) -> None: def test_update_login_session(session: Session) -> None:
now = datetime.now(timezone.utc) now = datetime.now(timezone.utc)
created_session = make_login_session(session, now) created_session = make_login_session(session, now)

View File

@@ -349,6 +349,17 @@ def create_user(session: Session, user_data: Mapping) -> models.Users:
return u return u
def get_user_by_id(session: Session, user_id: int) -> models.Users | None:
return session.get(models.Users, user_id)
def get_user_by_username(session: Session, username: str) -> models.Users | None:
statement = select(models.Users).where(
models.Users.username == username,
)
return session.exec(statement).first()
def update_user(session: Session, user_id: int, update_data: Mapping) -> models.Users: def update_user(session: Session, user_id: int, update_data: Mapping) -> models.Users:
user: models.Users | None = session.get(models.Users, user_id) user: models.Users | None = session.get(models.Users, user_id)
if user is None: if user is None:
@@ -418,6 +429,15 @@ def get_login_session_by_token_hash_and_user_id(session: Session, session_token_
return session.exec(statement).first() return session.exec(statement).first()
def get_login_session_by_token_hash(session: Session, session_token_hash: str) -> models.Sessions | None:
statement = select(models.Sessions).where(
models.Sessions.session_token_hash == session_token_hash,
models.Sessions.expires_at > datetime.now(timezone.utc),
)
return session.exec(statement).first()
IMMUTABLE_SESSION_FIELDS = {"id", "user_id", "session_token_hash", "created_at"} IMMUTABLE_SESSION_FIELDS = {"id", "user_id", "session_token_hash", "created_at"}

View File

@@ -1,6 +1,7 @@
from __future__ import annotations from __future__ import annotations
import logging import logging
from contextlib import contextmanager
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from sqlalchemy import event from sqlalchemy import event
@@ -72,6 +73,18 @@ class Database:
finally: finally:
session.close() session.close()
@contextmanager
def get_session_ctx_manager(self) -> Session:
session = Session(self._engine)
try:
yield session
session.commit()
except Exception:
session.rollback()
raise
finally:
session.close()
def dispose(self) -> None: def dispose(self) -> None:
self._engine.dispose() self._engine.dispose()

View File

@@ -0,0 +1,72 @@
from typing import Callable
from fastapi import Request, Response, status
from fastapi.responses import JSONResponse
from sqlmodel import Session
from starlette.middleware.base import BaseHTTPMiddleware
import settings
from trading_journal import crud, security
from trading_journal.db import Database
from trading_journal.dto import UserCreate, UserRead
from trading_journal.models import Sessions
EXCEPT_PATHS = [
f"{settings.settings.api_base}/status",
f"{settings.settings.api_base}/register",
]
class AuthMiddleWare(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next: Callable[[Request], Response]) -> Response:
if request.url.path in EXCEPT_PATHS:
return await call_next(request)
token = request.cookies.get("session_token")
if not token:
auth_header = request.headers.get("Authorization")
if auth_header and auth_header.startswith("Bearer "):
token = auth_header[len("Bearer ") :]
if not token:
return JSONResponse(
status_code=status.HTTP_401_UNAUTHORIZED,
content={"detail": "Unauthorized"},
)
db_factory: Database | None = getattr(request.app.state, "db_factory", None)
if db_factory is None:
return JSONResponse(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, content={"detail": "db factory not configured"})
try:
with db_factory.get_session_ctx_manager() as request_session:
hashed_token = security.hash_session_token_sha256(token)
request.state.db_session = request_session
login_session: Sessions | None = crud.get_login_session_by_token_hash(request.state.db_session, hashed_token)
except Exception: # noqa: BLE001
return JSONResponse(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, content={"detail": "db error"})
return None
class ServiceError(Exception):
pass
class UserAlreadyExistsError(ServiceError):
pass
def register_user_service(db_session: Session, user_in: UserCreate) -> UserRead:
if crud.get_user_by_username(db_session, user_in.username):
raise UserAlreadyExistsError("username already exists")
hashed = security.hash_password(user_in.password)
try:
user = crud.create_user(db_session, username=user_in.username, hashed_password=hashed)
try:
# prefer pydantic's from_orm if DTO supports orm_mode
user = UserRead.model_validate(user)
except Exception as e:
raise ServiceError("Failed to convert user to UserRead") from e
except Exception as e:
raise ServiceError("Failed to create user") from e
return user