From 1750401278d8df80f2931e5e16bbc5fe31c176ce Mon Sep 17 00:00:00 2001 From: Tianyu Liu Date: Mon, 22 Sep 2025 22:51:59 +0200 Subject: [PATCH] several changes: * api calls for auth * exchange now bind to user --- backend/.gitignore | 4 +- backend/app.py | 69 +++++++++++++++++++--- backend/settings.py | 1 + backend/tests/test_crud.py | 66 ++++++++++++++------- backend/tests/test_db_migration.py | 10 +++- backend/trading_journal/crud.py | 10 +++- backend/trading_journal/db.py | 5 +- backend/trading_journal/db_migration.py | 1 + backend/trading_journal/dto.py | 32 +++++++++- backend/trading_journal/models.py | 8 ++- backend/trading_journal/models_v1.py | 8 ++- backend/trading_journal/service.py | 77 +++++++++++++++++++++++-- backend/utils/__init__.py | 0 backend/utils/db_mirgration.py | 13 +++++ 14 files changed, 259 insertions(+), 45 deletions(-) create mode 100644 backend/utils/__init__.py create mode 100644 backend/utils/db_mirgration.py diff --git a/backend/.gitignore b/backend/.gitignore index 6321b92..837cf41 100644 --- a/backend/.gitignore +++ b/backend/.gitignore @@ -14,4 +14,6 @@ __pycache__/ *.db *.db-shm -*.db-wal \ No newline at end of file +*.db-wal + +devsettings.yaml \ No newline at end of file diff --git a/backend/app.py b/backend/app.py index f485bb2..93a9186 100644 --- a/backend/app.py +++ b/backend/app.py @@ -1,16 +1,27 @@ +from __future__ import annotations + import asyncio +import logging from collections.abc import AsyncGenerator from contextlib import asynccontextmanager +from datetime import datetime, timezone -from fastapi import FastAPI, HTTPException, Request +from fastapi import FastAPI, HTTPException, Request, status +from fastapi.responses import JSONResponse import settings from trading_journal import db, service from trading_journal.db import Database -from trading_journal.dto import UserCreate, UserRead +from trading_journal.dto import SessionsBase, SessionsCreate, UserCreate, UserLogin, UserRead _db = db.create_database(settings.settings.database_url) +logging.basicConfig( + level=logging.WARNING, + format="%(asctime)s %(levelname)s %(name)s: %(message)s", +) +logger = logging.getLogger(__name__) + @asynccontextmanager async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: # noqa: ARG001 @@ -40,13 +51,57 @@ async def register_user(request: Request, user_in: UserCreate) -> UserRead: return service.register_user_service(db, user_in) try: - return await asyncio.to_thread(sync_work) + user = await asyncio.to_thread(sync_work) + return JSONResponse(status_code=status.HTTP_201_CREATED, content=user.model_dump()) except service.UserAlreadyExistsError as e: - raise HTTPException(status_code=400, detail=str(e)) from e + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) from e except Exception as e: - raise HTTPException(status_code=500, detail="Internal server error" + str(e)) from e + logger.exception("Failed to register user: \n") + raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Internal server error") from e + + +@app.post(f"{settings.settings.api_base}/login") +async def login(request: Request, user_in: UserLogin) -> SessionsBase: + db_factory: Database = request.app.state.db_factory + + def sync_work() -> tuple[SessionsCreate, str] | None: + with db_factory.get_session_ctx_manager() as db: + return service.authenticate_user_service(db, user_in) + + try: + result = await asyncio.to_thread(sync_work) + if result is None: + return JSONResponse( + status_code=status.HTTP_401_UNAUTHORIZED, + content={"detail": "Invalid username or password, or user doesn't exist"}, + ) + session, token = result + session_return = SessionsBase(user_id=session.user_id) + response = JSONResponse(status_code=status.HTTP_200_OK, content=session_return.model_dump()) + expires_sec = int((session.expires_at.replace(tzinfo=timezone.utc) - datetime.now(timezone.utc)).total_seconds()) + response.set_cookie( + key="session_token", + value=token, + httponly=True, + secure=True, + samesite="lax", + max_age=expires_sec, + path="/", + ) + except Exception as e: + logger.exception("Failed to login user: \n") + raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Internal server error") from e + else: + return response + + +# Exchange +# @app.post(f"{settings.settings.api_base}/exchanges") +# async def create_exchange(request: Request, name: str, notes: str | None) -> dict: @app.get(f"{settings.settings.api_base}/trades") -async def get_trades() -> dict[str, str]: - return {"trades": []} +async def get_trades(request: Request) -> list: + db_factory: Database = request.app.state.db_factory + with db_factory.get_session_ctx_manager() as db: + return service.get_trades_service(db, request.state.user_id) diff --git a/backend/settings.py b/backend/settings.py index 1e1e29f..eff1071 100644 --- a/backend/settings.py +++ b/backend/settings.py @@ -16,6 +16,7 @@ class Settings(BaseSettings): log_level: str = "info" database_url: str = "sqlite:///:memory:" api_base: str = "/api/v1" + session_expiry_seconds: int = 3600 * 24 * 7 # 7 days hmac_key: str | None = None model_config = ConfigDict(env_file=".env", env_file_encoding="utf-8") diff --git a/backend/tests/test_crud.py b/backend/tests/test_crud.py index 9e0fade..3e02227 100644 --- a/backend/tests/test_crud.py +++ b/backend/tests/test_crud.py @@ -48,8 +48,8 @@ def make_user(session: Session, username: str = "testuser") -> int: return user.id -def make_exchange(session: Session, name: str = "NASDAQ") -> int: - exchange = models.Exchanges(name=name, notes="Test exchange") +def make_exchange(session: Session, user_id: int, name: str = "NASDAQ") -> int: + exchange = models.Exchanges(user_id=user_id, name=name, notes="Test exchange") session.add(exchange) session.commit() session.refresh(exchange) @@ -138,7 +138,7 @@ def _ensure_utc_aware(dt: datetime) -> datetime | None: def test_create_trade_success_with_cycle(session: Session) -> None: user_id = make_user(session) - exchange_id = make_exchange(session) + exchange_id = make_exchange(session, user_id) cycle_id = make_cycle(session, user_id, exchange_id) trade_data = { @@ -180,7 +180,7 @@ def test_create_trade_success_with_cycle(session: Session) -> None: def test_create_trade_with_auto_created_cycle(session: Session) -> None: user_id = make_user(session) - exchange_id = make_exchange(session) + exchange_id = make_exchange(session, user_id) trade_data = { "user_id": user_id, @@ -224,7 +224,7 @@ def test_create_trade_with_auto_created_cycle(session: Session) -> None: def test_create_trade_missing_required_fields(session: Session) -> None: user_id = make_user(session) - exchange_id = make_exchange(session) + exchange_id = make_exchange(session, user_id) base_trade_data = { "user_id": user_id, @@ -291,7 +291,7 @@ def test_create_trade_missing_required_fields(session: Session) -> None: def test_get_trade_by_id(session: Session) -> None: user_id = make_user(session) - exchange_id = make_exchange(session) + exchange_id = make_exchange(session, user_id) cycle_id = make_cycle(session, user_id, exchange_id) trade_data = { "user_id": user_id, @@ -330,7 +330,7 @@ def test_get_trade_by_id(session: Session) -> None: def test_get_trade_by_user_id_and_friendly_name(session: Session) -> None: user_id = make_user(session) - exchange_id = make_exchange(session) + exchange_id = make_exchange(session, user_id) cycle_id = make_cycle(session, user_id, exchange_id) friendly_name = "Unique Trade Name" trade_data = { @@ -359,7 +359,7 @@ def test_get_trade_by_user_id_and_friendly_name(session: Session) -> None: def test_get_trades_by_user_id(session: Session) -> None: user_id = make_user(session) - exchange_id = make_exchange(session) + exchange_id = make_exchange(session, user_id) cycle_id = make_cycle(session, user_id, exchange_id) trade_data_1 = { "user_id": user_id, @@ -406,7 +406,7 @@ def test_get_trades_by_user_id(session: Session) -> None: def test_update_trade_note(session: Session) -> None: user_id = make_user(session) - exchange_id = make_exchange(session) + exchange_id = make_exchange(session, user_id) cycle_id = make_cycle(session, user_id, exchange_id) trade_id = make_trade(session, user_id, cycle_id) @@ -424,7 +424,7 @@ def test_update_trade_note(session: Session) -> None: def test_invalidate_trade(session: Session) -> None: user_id = make_user(session) - exchange_id = make_exchange(session) + exchange_id = make_exchange(session, user_id) cycle_id = make_cycle(session, user_id, exchange_id) trade_id = make_trade(session, user_id, cycle_id) @@ -441,7 +441,7 @@ def test_invalidate_trade(session: Session) -> None: def test_replace_trade(session: Session) -> None: user_id = make_user(session) - exchange_id = make_exchange(session) + exchange_id = make_exchange(session, user_id) cycle_id = make_cycle(session, user_id, exchange_id) old_trade_id = make_trade(session, user_id, cycle_id) @@ -486,7 +486,7 @@ def test_replace_trade(session: Session) -> None: def test_create_cycle(session: Session) -> None: user_id = make_user(session) - exchange_id = make_exchange(session) + exchange_id = make_exchange(session, user_id) cycle_data = { "user_id": user_id, "friendly_name": "My First Cycle", @@ -517,7 +517,7 @@ def test_create_cycle(session: Session) -> None: def test_update_cycle(session: Session) -> None: user_id = make_user(session) - exchange_id = make_exchange(session) + exchange_id = make_exchange(session, user_id) cycle_id = make_cycle(session, user_id, exchange_id, friendly_name="Initial Cycle Name") update_data = { @@ -539,7 +539,7 @@ def test_update_cycle(session: Session) -> None: def test_update_cycle_immutable_fields(session: Session) -> None: user_id = make_user(session) - exchange_id = make_exchange(session) + exchange_id = make_exchange(session, user_id) cycle_id = make_cycle(session, user_id, exchange_id, friendly_name="Initial Cycle Name") # Attempt to update immutable fields @@ -563,42 +563,51 @@ def test_update_cycle_immutable_fields(session: Session) -> None: # Exchanges def test_create_exchange(session: Session) -> None: + user_id = make_user(session) exchange_data = { "name": "NYSE", "notes": "New York Stock Exchange", + "user_id": user_id, } exchange = crud.create_exchange(session, exchange_data) assert exchange.id is not None assert exchange.name == exchange_data["name"] assert exchange.notes == exchange_data["notes"] + assert exchange.user_id == user_id session.refresh(exchange) actual_exchange = session.get(models.Exchanges, exchange.id) assert actual_exchange is not None assert actual_exchange.name == exchange_data["name"] assert actual_exchange.notes == exchange_data["notes"] + assert actual_exchange.user_id == user_id def test_get_exchange_by_id(session: Session) -> None: - exchange_id = make_exchange(session, name="LSE") + user_id = make_user(session) + exchange_id = make_exchange(session, user_id=user_id, name="LSE") exchange = crud.get_exchange_by_id(session, exchange_id) assert exchange is not None assert exchange.id == exchange_id assert exchange.name == "LSE" + assert exchange.user_id == user_id -def test_get_exchange_by_name(session: Session) -> None: +def test_get_exchange_by_name_and_user_id(session: Session) -> None: exchange_name = "TSX" - make_exchange(session, name=exchange_name) - exchange = crud.get_exchange_by_name(session, exchange_name) + user_id = make_user(session) + make_exchange(session, user_id=user_id, name=exchange_name) + exchange = crud.get_exchange_by_name_and_user_id(session, exchange_name, user_id) assert exchange is not None assert exchange.name == exchange_name + assert exchange.user_id == user_id def test_get_all_exchanges(session: Session) -> None: exchange_names = ["NYSE", "NASDAQ", "LSE"] + user_id = make_user(session) for name in exchange_names: - make_exchange(session, name=name) + make_exchange(session, user_id=user_id, name=name) exchanges = crud.get_all_exchanges(session) assert len(exchanges) >= 3 @@ -607,8 +616,22 @@ def test_get_all_exchanges(session: Session) -> None: assert name in fetched_names +def test_get_all_exchanges_by_user_id(session: Session) -> None: + exchange_names = ["NYSE", "NASDAQ"] + user_id = make_user(session) + for name in exchange_names: + make_exchange(session, user_id=user_id, name=name) + + exchanges = crud.get_all_exchanges_by_user_id(session, user_id) + assert len(exchanges) == len(exchange_names) + fetched_names = {ex.name for ex in exchanges} + for name in exchange_names: + assert name in fetched_names + + def test_update_exchange(session: Session) -> None: - exchange_id = make_exchange(session, name="Initial Exchange") + user_id = make_user(session) + exchange_id = make_exchange(session, user_id=user_id, name="Initial Exchange") update_data = { "name": "Updated Exchange", "notes": "Updated notes for the exchange", @@ -627,7 +650,8 @@ def test_update_exchange(session: Session) -> None: def test_delete_exchange(session: Session) -> None: - exchange_id = make_exchange(session, name="Deletable Exchange") + user_id = make_user(session) + exchange_id = make_exchange(session, user_id=user_id, name="Deletable Exchange") crud.delete_exchange(session, exchange_id) deleted_exchange = session.get(models.Exchanges, exchange_id) assert deleted_exchange is None diff --git a/backend/tests/test_db_migration.py b/backend/tests/test_db_migration.py index 15c7fba..343214b 100644 --- a/backend/tests/test_db_migration.py +++ b/backend/tests/test_db_migration.py @@ -70,6 +70,12 @@ def test_run_migrations_0_to_1(monkeypatch: pytest.MonkeyPatch) -> None: "notes": ("TEXT", 0, 0), "cycle_id": ("INTEGER", 0, 0), }, + "exchanges": { + "id": ("INTEGER", 1, 1), + "user_id": ("INTEGER", 1, 0), + "name": ("TEXT", 1, 0), + "notes": ("TEXT", 0, 0), + }, "sessions": { "id": ("INTEGER", 1, 1), "user_id": ("INTEGER", 1, 0), @@ -97,7 +103,9 @@ def test_run_migrations_0_to_1(monkeypatch: pytest.MonkeyPatch) -> None: {"table": "users", "from": "user_id", "to": "id"}, ], "users": [], - "exchanges": [], + "exchanges": [ + {"table": "users", "from": "user_id", "to": "id"}, + ], } with engine.connect() as conn: diff --git a/backend/trading_journal/crud.py b/backend/trading_journal/crud.py index d21e157..75bb8ad 100644 --- a/backend/trading_journal/crud.py +++ b/backend/trading_journal/crud.py @@ -274,9 +274,10 @@ def get_exchange_by_id(session: Session, exchange_id: int) -> models.Exchanges | return session.get(models.Exchanges, exchange_id) -def get_exchange_by_name(session: Session, name: str) -> models.Exchanges | None: +def get_exchange_by_name_and_user_id(session: Session, name: str, user_id: int) -> models.Exchanges | None: statement = select(models.Exchanges).where( models.Exchanges.name == name, + models.Exchanges.user_id == user_id, ) return session.exec(statement).first() @@ -286,6 +287,13 @@ def get_all_exchanges(session: Session) -> list[models.Exchanges]: return session.exec(statement).all() +def get_all_exchanges_by_user_id(session: Session, user_id: int) -> list[models.Exchanges]: + statement = select(models.Exchanges).where( + models.Exchanges.user_id == user_id, + ) + return session.exec(statement).all() + + def update_exchange(session: Session, exchange_id: int, update_data: Mapping) -> models.Exchanges: exchange: models.Exchanges | None = session.get(models.Exchanges, exchange_id) if exchange is None: diff --git a/backend/trading_journal/db.py b/backend/trading_journal/db.py index 93503f3..ddd2dc6 100644 --- a/backend/trading_journal/db.py +++ b/backend/trading_journal/db.py @@ -8,8 +8,6 @@ from sqlalchemy import event from sqlalchemy.pool import StaticPool from sqlmodel import Session, create_engine -from trading_journal import db_migration - if TYPE_CHECKING: from collections.abc import Generator from sqlite3 import Connection as DBAPIConnection @@ -59,7 +57,6 @@ class Database: event.listen(self._engine, "connect", _enable_sqlite_pragmas) def init_db(self) -> None: - # db_migration.run_migrations(self._engine) pass def get_session(self) -> Generator[Session, None, None]: @@ -74,7 +71,7 @@ class Database: session.close() @contextmanager - def get_session_ctx_manager(self) -> Session: + def get_session_ctx_manager(self) -> Generator[Session, None, None]: session = Session(self._engine) try: yield session diff --git a/backend/trading_journal/db_migration.py b/backend/trading_journal/db_migration.py index 2a57464..e3766a2 100644 --- a/backend/trading_journal/db_migration.py +++ b/backend/trading_journal/db_migration.py @@ -27,6 +27,7 @@ def _mig_0_1(engine: Engine) -> None: models_v1.Cycles.__table__, models_v1.Users.__table__, models_v1.Sessions.__table__, + models_v1.Exchanges.__table__, ], ) diff --git a/backend/trading_journal/dto.py b/backend/trading_journal/dto.py index 1a9d478..9c4ce25 100644 --- a/backend/trading_journal/dto.py +++ b/backend/trading_journal/dto.py @@ -1,12 +1,12 @@ from __future__ import annotations +from datetime import date, datetime # noqa: TC003 from typing import TYPE_CHECKING +from pydantic import BaseModel from sqlmodel import SQLModel if TYPE_CHECKING: - from datetime import date, datetime - from trading_journal.models import TradeStrategy, TradeType, UnderlyingCurrency @@ -52,5 +52,33 @@ class UserCreate(UserBase): password: str +class UserLogin(BaseModel): + username: str + password: str + + class UserRead(UserBase): id: int + + +class SessionsBase(SQLModel): + user_id: int + + +class SessionRead(SessionsBase): + id: int + expires_at: datetime + last_seen_at: datetime | None + last_used_ip: str | None + user_agent: str | None + + +class SessionsCreate(SessionsBase): + expires_at: datetime + + +class SessionsUpdate(SQLModel): + expires_at: datetime | None = None + last_seen_at: datetime | None = None + last_used_ip: str | None = None + user_agent: str | None = None diff --git a/backend/trading_journal/models.py b/backend/trading_journal/models.py index 0238a81..9bdc57a 100644 --- a/backend/trading_journal/models.py +++ b/backend/trading_journal/models.py @@ -117,11 +117,14 @@ class Cycles(SQLModel, table=True): class Exchanges(SQLModel, table=True): __tablename__ = "exchanges" + __table_args__ = (UniqueConstraint("user_id", "name", name="uq_exchanges_user_name"),) id: int | None = Field(default=None, primary_key=True) - name: str = Field(sa_column=Column(Text, nullable=False, unique=True)) + user_id: int = Field(foreign_key="users.id", nullable=False, index=True) + name: str = Field(sa_column=Column(Text, nullable=False)) notes: str | None = Field(default=None, sa_column=Column(Text, nullable=True)) trades: list["Trades"] = Relationship(back_populates="exchange") cycles: list["Cycles"] = Relationship(back_populates="exchange") + user: "Users" = Relationship(back_populates="exchanges") class Users(SQLModel, table=True): @@ -131,6 +134,8 @@ class Users(SQLModel, table=True): username: str = Field(sa_column=Column(Text, nullable=False, unique=True)) password_hash: str = Field(sa_column=Column(Text, nullable=False)) is_active: bool = Field(default=True, nullable=False) + sessions: list["Sessions"] = Relationship(back_populates="user") + exchanges: list["Exchanges"] = Relationship(back_populates="user") class Sessions(SQLModel, table=True): @@ -144,3 +149,4 @@ class Sessions(SQLModel, table=True): last_used_ip: str | None = Field(default=None, sa_column=Column(Text, nullable=True)) user_agent: str | None = Field(default=None, sa_column=Column(Text, nullable=True)) device_name: str | None = Field(default=None, sa_column=Column(Text, nullable=True)) + user: "Users" = Relationship(back_populates="sessions") diff --git a/backend/trading_journal/models_v1.py b/backend/trading_journal/models_v1.py index 0238a81..9bdc57a 100644 --- a/backend/trading_journal/models_v1.py +++ b/backend/trading_journal/models_v1.py @@ -117,11 +117,14 @@ class Cycles(SQLModel, table=True): class Exchanges(SQLModel, table=True): __tablename__ = "exchanges" + __table_args__ = (UniqueConstraint("user_id", "name", name="uq_exchanges_user_name"),) id: int | None = Field(default=None, primary_key=True) - name: str = Field(sa_column=Column(Text, nullable=False, unique=True)) + user_id: int = Field(foreign_key="users.id", nullable=False, index=True) + name: str = Field(sa_column=Column(Text, nullable=False)) notes: str | None = Field(default=None, sa_column=Column(Text, nullable=True)) trades: list["Trades"] = Relationship(back_populates="exchange") cycles: list["Cycles"] = Relationship(back_populates="exchange") + user: "Users" = Relationship(back_populates="exchanges") class Users(SQLModel, table=True): @@ -131,6 +134,8 @@ class Users(SQLModel, table=True): username: str = Field(sa_column=Column(Text, nullable=False, unique=True)) password_hash: str = Field(sa_column=Column(Text, nullable=False)) is_active: bool = Field(default=True, nullable=False) + sessions: list["Sessions"] = Relationship(back_populates="user") + exchanges: list["Exchanges"] = Relationship(back_populates="user") class Sessions(SQLModel, table=True): @@ -144,3 +149,4 @@ class Sessions(SQLModel, table=True): last_used_ip: str | None = Field(default=None, sa_column=Column(Text, nullable=True)) user_agent: str | None = Field(default=None, sa_column=Column(Text, nullable=True)) device_name: str | None = Field(default=None, sa_column=Column(Text, nullable=True)) + user: "Users" = Relationship(back_populates="sessions") diff --git a/backend/trading_journal/service.py b/backend/trading_journal/service.py index 404ffb8..8fd8cd4 100644 --- a/backend/trading_journal/service.py +++ b/backend/trading_journal/service.py @@ -1,3 +1,7 @@ +from __future__ import annotations + +import logging +from datetime import datetime, timedelta, timezone from typing import Callable from fastapi import Request, Response, status @@ -8,17 +12,23 @@ from starlette.middleware.base import BaseHTTPMiddleware import settings from trading_journal import crud, security from trading_journal.db import Database -from trading_journal.dto import UserCreate, UserRead +from trading_journal.dto import SessionsCreate, SessionsUpdate, UserCreate, UserLogin, UserRead from trading_journal.models import Sessions +SessionsCreate.model_rebuild() + + EXCEPT_PATHS = [ f"{settings.settings.api_base}/status", f"{settings.settings.api_base}/register", + f"{settings.settings.api_base}/login", ] +logger = logging.getLogger(__name__) + class AuthMiddleWare(BaseHTTPMiddleware): - async def dispatch(self, request: Request, call_next: Callable[[Request], Response]) -> Response: + async def dispatch(self, request: Request, call_next: Callable[[Request], Response]) -> Response: # noqa: PLR0911 if request.url.path in EXCEPT_PATHS: return await call_next(request) @@ -42,10 +52,32 @@ class AuthMiddleWare(BaseHTTPMiddleware): hashed_token = security.hash_session_token_sha256(token) request.state.db_session = request_session login_session: Sessions | None = crud.get_login_session_by_token_hash(request.state.db_session, hashed_token) - except Exception: # noqa: BLE001 - return JSONResponse(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, content={"detail": "db error"}) + if not login_session: + return JSONResponse(status_code=status.HTTP_401_UNAUTHORIZED, content={"detail": "Unauthorized"}) + session_expires_utc = login_session.expires_at.replace(tzinfo=timezone.utc) + if session_expires_utc < datetime.now(timezone.utc): + crud.delete_login_session(request.state.db_session, login_session) + return JSONResponse(status_code=status.HTTP_401_UNAUTHORIZED, content={"detail": "Unauthorized"}) + if login_session.user.is_active is False: + return JSONResponse(status_code=status.HTTP_401_UNAUTHORIZED, content={"detail": "Unauthorized"}) + if session_expires_utc - datetime.now(timezone.utc) < timedelta(seconds=3600): + updated_expiry = datetime.now(timezone.utc) + timedelta(seconds=settings.settings.session_expiry_seconds) + else: + updated_expiry = session_expires_utc + updated_session: SessionsUpdate = SessionsUpdate( + last_seen_at=datetime.now(timezone.utc), + last_used_ip=request.client.host if request.client else None, + user_agent=request.headers.get("User-Agent"), + expires_at=updated_expiry, + ) + user_id = login_session.user_id + request.state.user_id = user_id + crud.update_login_session(request.state.db_session, hashed_token, update_session=updated_session) + except Exception: + logger.exception("Failed to authenticate user: \n") + return JSONResponse(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, content={"detail": "Internal server error"}) - return None + return await call_next(request) class ServiceError(Exception): @@ -60,13 +92,46 @@ def register_user_service(db_session: Session, user_in: UserCreate) -> UserRead: if crud.get_user_by_username(db_session, user_in.username): raise UserAlreadyExistsError("username already exists") hashed = security.hash_password(user_in.password) + user_data: dict = { + "username": user_in.username, + "password_hash": hashed, + } try: - user = crud.create_user(db_session, username=user_in.username, hashed_password=hashed) + user = crud.create_user(db_session, user_data=user_data) try: # prefer pydantic's from_orm if DTO supports orm_mode user = UserRead.model_validate(user) except Exception as e: + logger.exception("Failed to convert user to UserRead: %s", e) raise ServiceError("Failed to convert user to UserRead") from e except Exception as e: + logger.exception("Failed to create user:") raise ServiceError("Failed to create user") from e return user + + +def authenticate_user_service(db_session: Session, user_in: UserLogin) -> tuple[SessionsCreate, str] | None: + user = crud.get_user_by_username(db_session, user_in.username) + if not user: + return None + + if not security.verify_password(user_in.password, user.password_hash): + return None + + token = security.generate_session_token() + token_hashed = security.hash_session_token_sha256(token) + try: + session = crud.create_login_session( + session=db_session, + user_id=user.id, + session_token_hash=token_hashed, + session_length_seconds=settings.settings.session_expiry_seconds, + ) + except Exception as e: + logger.exception("Failed to create login session: \n") + raise ServiceError("Failed to create login session") from e + return SessionsCreate.model_validate(session), token + + +def get_trades_service(db_session: Session, user_id: int) -> list: + return crud.get_trades_by_user_id(db_session, user_id) diff --git a/backend/utils/__init__.py b/backend/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/backend/utils/db_mirgration.py b/backend/utils/db_mirgration.py new file mode 100644 index 0000000..1103349 --- /dev/null +++ b/backend/utils/db_mirgration.py @@ -0,0 +1,13 @@ +import sys +from pathlib import Path + +from sqlmodel import create_engine + +project_parent = Path(__file__).resolve().parents[1] +sys.path.insert(0, str(project_parent)) + +import settings # noqa: E402 +from trading_journal import db_migration # noqa: E402 + +db_engine = create_engine(settings.settings.database_url, echo=True) +db_migration.run_migrations(db_engine)