From a6592bd1408a10839c966b9442ea2d8347b8f005 Mon Sep 17 00:00:00 2001 From: Tianyu Liu Date: Tue, 23 Sep 2025 23:35:15 +0200 Subject: [PATCH] wip --- backend/app.py | 44 ++++++++++-- backend/trading_journal/dto.py | 111 +++++++++++++++++------------ backend/trading_journal/service.py | 58 +++++++++++++-- 3 files changed, 160 insertions(+), 53 deletions(-) diff --git a/backend/app.py b/backend/app.py index 4c6f61d..aae91b7 100644 --- a/backend/app.py +++ b/backend/app.py @@ -7,11 +7,12 @@ from datetime import datetime, timezone from typing import TYPE_CHECKING from fastapi import FastAPI, HTTPException, Request, status +from fastapi.encoders import jsonable_encoder from fastapi.responses import JSONResponse, Response import settings from trading_journal import db, service -from trading_journal.dto import ExchangesBase, SessionsBase, SessionsCreate, UserCreate, UserLogin, UserRead +from trading_journal.dto import CycleBase, ExchangesBase, ExchangesRead, SessionsBase, SessionsCreate, UserCreate, UserLogin, UserRead if TYPE_CHECKING: from collections.abc import AsyncGenerator @@ -119,10 +120,10 @@ async def create_exchange(request: Request, exchange_data: ExchangesBase) -> Res @app.get(f"{settings.settings.api_base}/exchanges") -async def get_exchanges(request: Request) -> list[ExchangesBase]: +async def get_exchanges(request: Request) -> list[ExchangesRead]: db_factory: Database = request.app.state.db_factory - def sync_work() -> list[ExchangesBase]: + def sync_work() -> list[ExchangesRead]: with db_factory.get_session_ctx_manager() as db: return service.get_exchanges_by_user_service(db, request.state.user_id) @@ -133,4 +134,39 @@ async def get_exchanges(request: Request) -> list[ExchangesBase]: raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Internal server error") from e -# Trade +@app.patch(f"{settings.settings.api_base}/exchanges/{{exchange_id}}") +async def update_exchange(request: Request, exchange_id: int, exchange_data: ExchangesBase) -> Response: + db_factory: Database = request.app.state.db_factory + + def sync_work() -> ExchangesBase: + with db_factory.get_session_ctx_manager() as db: + return service.update_exchanges_service(db, request.state.user_id, exchange_id, exchange_data.name, exchange_data.notes) + + try: + exchange = await asyncio.to_thread(sync_work) + return JSONResponse(status_code=status.HTTP_200_OK, content=exchange.model_dump()) + except service.ExchangeNotFoundError as e: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e)) from e + except service.ExchangeAlreadyExistsError as e: + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) from e + except Exception as e: + logger.exception("Failed to update exchange: \n") + raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Internal server error") from e + + +# Cycle +@app.post(f"{settings.settings.api_base}/cycles") +async def create_cycle(request: Request, cycle_data: CycleBase) -> Response: + return JSONResponse(status_code=status.HTTP_405_METHOD_NOT_ALLOWED, content="Not supported.") + db_factory: Database = request.app.state.db_factory + + def sync_work() -> CycleBase: + with db_factory.get_session_ctx_manager() as db: + return service.create_cycle_service(db, request.state.user_id, cycle_data) + + try: + cycle = await asyncio.to_thread(sync_work) + return JSONResponse(status_code=status.HTTP_201_CREATED, content=jsonable_encoder(cycle)) + except Exception as e: + logger.exception("Failed to create cycle: \n") + raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Internal server error") from e diff --git a/backend/trading_journal/dto.py b/backend/trading_journal/dto.py index 5ea709c..7b377f9 100644 --- a/backend/trading_journal/dto.py +++ b/backend/trading_journal/dto.py @@ -1,55 +1,11 @@ from __future__ import annotations from datetime import date, datetime # noqa: TC003 -from typing import TYPE_CHECKING from pydantic import BaseModel from sqlmodel import SQLModel -if TYPE_CHECKING: - from trading_journal.models import TradeStrategy, TradeType, UnderlyingCurrency - - -class ExchangesBase(SQLModel): - name: str - notes: str | None = None - - -class ExchangesCreate(ExchangesBase): - user_id: int - - -class TradeBase(SQLModel): - user_id: int - friendly_name: str | None - symbol: str - exchange: str - underlying_currency: UnderlyingCurrency - trade_type: TradeType - trade_strategy: TradeStrategy - trade_date: date - trade_time_utc: datetime - quantity: int - price_cents: int - gross_cash_flow_cents: int - commission_cents: int - net_cash_flow_cents: int - notes: str | None - cycle_id: int | None = None - - -class TradeCreate(TradeBase): - expiry_date: date | None = None - strike_price_cents: int | None = None - is_invalidated: bool = False - invalidated_at: datetime | None = None - replaced_by_trade_id: int | None = None - - -class TradeRead(TradeBase): - id: int - is_invalidated: bool - invalidated_at: datetime | None +from trading_journal.models import TradeStrategy, TradeType, UnderlyingCurrency # noqa: TC001 class UserBase(SQLModel): @@ -91,3 +47,68 @@ class SessionsUpdate(SQLModel): last_seen_at: datetime | None = None last_used_ip: str | None = None user_agent: str | None = None + + +class ExchangesBase(SQLModel): + name: str + notes: str | None = None + + +class ExchangesCreate(ExchangesBase): + user_id: int + + +class ExchangesRead(ExchangesBase): + id: int + + +class CycleBase(SQLModel): + friendly_name: str | None = None + symbol: str + exchange_id: int + underlying_currency: UnderlyingCurrency + status: str + start_date: date + end_date: date | None = None + funding_source: str | None = None + capital_exposure_cents: int | None = None + loan_amount_cents: int | None = None + loan_interest_rate_bps: int | None = None + trades: list[TradeRead] | None = None + + +class CycleCreate(CycleBase): + user_id: int + + +class TradeBase(SQLModel): + user_id: int + friendly_name: str | None + symbol: str + exchange: str + underlying_currency: UnderlyingCurrency + trade_type: TradeType + trade_strategy: TradeStrategy + trade_date: date + trade_time_utc: datetime + quantity: int + price_cents: int + gross_cash_flow_cents: int + commission_cents: int + net_cash_flow_cents: int + notes: str | None + cycle_id: int | None = None + + +class TradeCreate(TradeBase): + expiry_date: date | None = None + strike_price_cents: int | None = None + is_invalidated: bool = False + invalidated_at: datetime | None = None + replaced_by_trade_id: int | None = None + + +class TradeRead(TradeBase): + id: int + is_invalidated: bool + invalidated_at: datetime | None diff --git a/backend/trading_journal/service.py b/backend/trading_journal/service.py index a9ba8c5..5650515 100644 --- a/backend/trading_journal/service.py +++ b/backend/trading_journal/service.py @@ -10,9 +10,21 @@ from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoin import settings from trading_journal import crud, security -from trading_journal.dto import ExchangesBase, ExchangesCreate, SessionsCreate, SessionsUpdate, UserCreate, UserLogin, UserRead +from trading_journal.dto import ( + CycleBase, + CycleCreate, + ExchangesBase, + ExchangesCreate, + ExchangesRead, + SessionsCreate, + SessionsUpdate, + UserCreate, + UserLogin, + UserRead, +) SessionsCreate.model_rebuild() +CycleBase.model_rebuild() if TYPE_CHECKING: from sqlmodel import Session @@ -95,6 +107,11 @@ class ExchangeAlreadyExistsError(ServiceError): pass +class ExchangeNotFoundError(ServiceError): + pass + + +# User service def register_user_service(db_session: Session, user_in: UserCreate) -> UserRead: if crud.get_user_by_username(db_session, user_in.username): raise UserAlreadyExistsError("username already exists") @@ -156,7 +173,7 @@ def create_exchange_service(db_session: Session, user_id: int, name: str, notes: try: exchange_dto = ExchangesCreate.model_validate(exchange) except Exception as e: - logger.exception("Failed to convert exchange to ExchangesCreate: ") + logger.exception("Failed to convert exchange to ExchangesCreate:") raise ServiceError("Failed to convert exchange to ExchangesCreate") from e except Exception as e: logger.exception("Failed to create exchange:") @@ -164,9 +181,42 @@ def create_exchange_service(db_session: Session, user_id: int, name: str, notes: return exchange_dto -def get_exchanges_by_user_service(db_session: Session, user_id: int) -> list[ExchangesBase]: +def get_exchanges_by_user_service(db_session: Session, user_id: int) -> list[ExchangesRead]: exchanges = crud.get_all_exchanges_by_user_id(db_session, user_id) - return [ExchangesBase.model_validate(exchange) for exchange in exchanges] + return [ExchangesRead.model_validate(exchange) for exchange in exchanges] + + +def update_exchanges_service(db_session: Session, user_id: int, exchange_id: int, name: str | None, notes: str | None) -> ExchangesBase: + existing_exchange = crud.get_exchange_by_id(db_session, exchange_id) + if not existing_exchange: + raise ExchangeNotFoundError("Exchange not found") + if existing_exchange.user_id != user_id: + raise ExchangeNotFoundError("Exchange not found") + + if name: + other_exchange = crud.get_exchange_by_name_and_user_id(db_session, name, user_id) + if other_exchange and other_exchange.id != existing_exchange.id: + raise ExchangeAlreadyExistsError("Another exchange with the same name already exists for this user") + + exchange_data = ExchangesBase( + name=name or existing_exchange.name, + notes=notes or existing_exchange.notes, + ) + try: + exchange = crud.update_exchange(db_session, cast("int", existing_exchange.id), update_data=exchange_data) + except Exception as e: + logger.exception("Failed to update exchange: \n") + raise ServiceError("Failed to update exchange") from e + return ExchangesBase.model_validate(exchange) + + +# Cycle Service +def create_cycle_service(db_session: Session, user_id: int, cycle_data: CycleBase) -> CycleBase: + cycle_data_dict = cycle_data.model_dump() + cycle_data_dict["user_id"] = user_id + cycle_data_with_user_id: CycleCreate = CycleCreate.model_validate(cycle_data_dict) + crud.create_cycle(db_session, cycle_data=cycle_data_with_user_id) + return cycle_data def get_trades_service(db_session: Session, user_id: int) -> list: