This commit is contained in:
2025-09-23 23:35:15 +02:00
parent 92c4e0d4fc
commit a6592bd140
3 changed files with 160 additions and 53 deletions

View File

@@ -7,11 +7,12 @@ from datetime import datetime, timezone
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from fastapi import FastAPI, HTTPException, Request, status from fastapi import FastAPI, HTTPException, Request, status
from fastapi.encoders import jsonable_encoder
from fastapi.responses import JSONResponse, Response from fastapi.responses import JSONResponse, Response
import settings import settings
from trading_journal import db, service from trading_journal import db, service
from trading_journal.dto import ExchangesBase, SessionsBase, SessionsCreate, UserCreate, UserLogin, UserRead from trading_journal.dto import CycleBase, ExchangesBase, ExchangesRead, SessionsBase, SessionsCreate, UserCreate, UserLogin, UserRead
if TYPE_CHECKING: if TYPE_CHECKING:
from collections.abc import AsyncGenerator from collections.abc import AsyncGenerator
@@ -119,10 +120,10 @@ async def create_exchange(request: Request, exchange_data: ExchangesBase) -> Res
@app.get(f"{settings.settings.api_base}/exchanges") @app.get(f"{settings.settings.api_base}/exchanges")
async def get_exchanges(request: Request) -> list[ExchangesBase]: async def get_exchanges(request: Request) -> list[ExchangesRead]:
db_factory: Database = request.app.state.db_factory db_factory: Database = request.app.state.db_factory
def sync_work() -> list[ExchangesBase]: def sync_work() -> list[ExchangesRead]:
with db_factory.get_session_ctx_manager() as db: with db_factory.get_session_ctx_manager() as db:
return service.get_exchanges_by_user_service(db, request.state.user_id) return service.get_exchanges_by_user_service(db, request.state.user_id)
@@ -133,4 +134,39 @@ async def get_exchanges(request: Request) -> list[ExchangesBase]:
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Internal server error") from e raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Internal server error") from e
# Trade @app.patch(f"{settings.settings.api_base}/exchanges/{{exchange_id}}")
async def update_exchange(request: Request, exchange_id: int, exchange_data: ExchangesBase) -> Response:
db_factory: Database = request.app.state.db_factory
def sync_work() -> ExchangesBase:
with db_factory.get_session_ctx_manager() as db:
return service.update_exchanges_service(db, request.state.user_id, exchange_id, exchange_data.name, exchange_data.notes)
try:
exchange = await asyncio.to_thread(sync_work)
return JSONResponse(status_code=status.HTTP_200_OK, content=exchange.model_dump())
except service.ExchangeNotFoundError as e:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e)) from e
except service.ExchangeAlreadyExistsError as e:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) from e
except Exception as e:
logger.exception("Failed to update exchange: \n")
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Internal server error") from e
# Cycle
@app.post(f"{settings.settings.api_base}/cycles")
async def create_cycle(request: Request, cycle_data: CycleBase) -> Response:
return JSONResponse(status_code=status.HTTP_405_METHOD_NOT_ALLOWED, content="Not supported.")
db_factory: Database = request.app.state.db_factory
def sync_work() -> CycleBase:
with db_factory.get_session_ctx_manager() as db:
return service.create_cycle_service(db, request.state.user_id, cycle_data)
try:
cycle = await asyncio.to_thread(sync_work)
return JSONResponse(status_code=status.HTTP_201_CREATED, content=jsonable_encoder(cycle))
except Exception as e:
logger.exception("Failed to create cycle: \n")
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Internal server error") from e

View File

@@ -1,55 +1,11 @@
from __future__ import annotations from __future__ import annotations
from datetime import date, datetime # noqa: TC003 from datetime import date, datetime # noqa: TC003
from typing import TYPE_CHECKING
from pydantic import BaseModel from pydantic import BaseModel
from sqlmodel import SQLModel from sqlmodel import SQLModel
if TYPE_CHECKING: from trading_journal.models import TradeStrategy, TradeType, UnderlyingCurrency # noqa: TC001
from trading_journal.models import TradeStrategy, TradeType, UnderlyingCurrency
class ExchangesBase(SQLModel):
name: str
notes: str | None = None
class ExchangesCreate(ExchangesBase):
user_id: int
class TradeBase(SQLModel):
user_id: int
friendly_name: str | None
symbol: str
exchange: str
underlying_currency: UnderlyingCurrency
trade_type: TradeType
trade_strategy: TradeStrategy
trade_date: date
trade_time_utc: datetime
quantity: int
price_cents: int
gross_cash_flow_cents: int
commission_cents: int
net_cash_flow_cents: int
notes: str | None
cycle_id: int | None = None
class TradeCreate(TradeBase):
expiry_date: date | None = None
strike_price_cents: int | None = None
is_invalidated: bool = False
invalidated_at: datetime | None = None
replaced_by_trade_id: int | None = None
class TradeRead(TradeBase):
id: int
is_invalidated: bool
invalidated_at: datetime | None
class UserBase(SQLModel): class UserBase(SQLModel):
@@ -91,3 +47,68 @@ class SessionsUpdate(SQLModel):
last_seen_at: datetime | None = None last_seen_at: datetime | None = None
last_used_ip: str | None = None last_used_ip: str | None = None
user_agent: str | None = None user_agent: str | None = None
class ExchangesBase(SQLModel):
name: str
notes: str | None = None
class ExchangesCreate(ExchangesBase):
user_id: int
class ExchangesRead(ExchangesBase):
id: int
class CycleBase(SQLModel):
friendly_name: str | None = None
symbol: str
exchange_id: int
underlying_currency: UnderlyingCurrency
status: str
start_date: date
end_date: date | None = None
funding_source: str | None = None
capital_exposure_cents: int | None = None
loan_amount_cents: int | None = None
loan_interest_rate_bps: int | None = None
trades: list[TradeRead] | None = None
class CycleCreate(CycleBase):
user_id: int
class TradeBase(SQLModel):
user_id: int
friendly_name: str | None
symbol: str
exchange: str
underlying_currency: UnderlyingCurrency
trade_type: TradeType
trade_strategy: TradeStrategy
trade_date: date
trade_time_utc: datetime
quantity: int
price_cents: int
gross_cash_flow_cents: int
commission_cents: int
net_cash_flow_cents: int
notes: str | None
cycle_id: int | None = None
class TradeCreate(TradeBase):
expiry_date: date | None = None
strike_price_cents: int | None = None
is_invalidated: bool = False
invalidated_at: datetime | None = None
replaced_by_trade_id: int | None = None
class TradeRead(TradeBase):
id: int
is_invalidated: bool
invalidated_at: datetime | None

View File

@@ -10,9 +10,21 @@ from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoin
import settings import settings
from trading_journal import crud, security from trading_journal import crud, security
from trading_journal.dto import ExchangesBase, ExchangesCreate, SessionsCreate, SessionsUpdate, UserCreate, UserLogin, UserRead from trading_journal.dto import (
CycleBase,
CycleCreate,
ExchangesBase,
ExchangesCreate,
ExchangesRead,
SessionsCreate,
SessionsUpdate,
UserCreate,
UserLogin,
UserRead,
)
SessionsCreate.model_rebuild() SessionsCreate.model_rebuild()
CycleBase.model_rebuild()
if TYPE_CHECKING: if TYPE_CHECKING:
from sqlmodel import Session from sqlmodel import Session
@@ -95,6 +107,11 @@ class ExchangeAlreadyExistsError(ServiceError):
pass pass
class ExchangeNotFoundError(ServiceError):
pass
# User service
def register_user_service(db_session: Session, user_in: UserCreate) -> UserRead: def register_user_service(db_session: Session, user_in: UserCreate) -> UserRead:
if crud.get_user_by_username(db_session, user_in.username): if crud.get_user_by_username(db_session, user_in.username):
raise UserAlreadyExistsError("username already exists") raise UserAlreadyExistsError("username already exists")
@@ -156,7 +173,7 @@ def create_exchange_service(db_session: Session, user_id: int, name: str, notes:
try: try:
exchange_dto = ExchangesCreate.model_validate(exchange) exchange_dto = ExchangesCreate.model_validate(exchange)
except Exception as e: except Exception as e:
logger.exception("Failed to convert exchange to ExchangesCreate: ") logger.exception("Failed to convert exchange to ExchangesCreate:")
raise ServiceError("Failed to convert exchange to ExchangesCreate") from e raise ServiceError("Failed to convert exchange to ExchangesCreate") from e
except Exception as e: except Exception as e:
logger.exception("Failed to create exchange:") logger.exception("Failed to create exchange:")
@@ -164,9 +181,42 @@ def create_exchange_service(db_session: Session, user_id: int, name: str, notes:
return exchange_dto return exchange_dto
def get_exchanges_by_user_service(db_session: Session, user_id: int) -> list[ExchangesBase]: def get_exchanges_by_user_service(db_session: Session, user_id: int) -> list[ExchangesRead]:
exchanges = crud.get_all_exchanges_by_user_id(db_session, user_id) exchanges = crud.get_all_exchanges_by_user_id(db_session, user_id)
return [ExchangesBase.model_validate(exchange) for exchange in exchanges] return [ExchangesRead.model_validate(exchange) for exchange in exchanges]
def update_exchanges_service(db_session: Session, user_id: int, exchange_id: int, name: str | None, notes: str | None) -> ExchangesBase:
existing_exchange = crud.get_exchange_by_id(db_session, exchange_id)
if not existing_exchange:
raise ExchangeNotFoundError("Exchange not found")
if existing_exchange.user_id != user_id:
raise ExchangeNotFoundError("Exchange not found")
if name:
other_exchange = crud.get_exchange_by_name_and_user_id(db_session, name, user_id)
if other_exchange and other_exchange.id != existing_exchange.id:
raise ExchangeAlreadyExistsError("Another exchange with the same name already exists for this user")
exchange_data = ExchangesBase(
name=name or existing_exchange.name,
notes=notes or existing_exchange.notes,
)
try:
exchange = crud.update_exchange(db_session, cast("int", existing_exchange.id), update_data=exchange_data)
except Exception as e:
logger.exception("Failed to update exchange: \n")
raise ServiceError("Failed to update exchange") from e
return ExchangesBase.model_validate(exchange)
# Cycle Service
def create_cycle_service(db_session: Session, user_id: int, cycle_data: CycleBase) -> CycleBase:
cycle_data_dict = cycle_data.model_dump()
cycle_data_dict["user_id"] = user_id
cycle_data_with_user_id: CycleCreate = CycleCreate.model_validate(cycle_data_dict)
crud.create_cycle(db_session, cycle_data=cycle_data_with_user_id)
return cycle_data
def get_trades_service(db_session: Session, user_id: int) -> list: def get_trades_service(db_session: Session, user_id: int) -> list: