From 80fc405bf6831d94be95d7c0817371f20b46124e Mon Sep 17 00:00:00 2001 From: Tianyu Liu Date: Wed, 24 Sep 2025 17:33:27 +0200 Subject: [PATCH] Almost finish basic functionalities --- backend/.vscode/launch.json | 10 +- backend/app.py | 148 ++++++++++++++++++++++++++- backend/tests/test_crud.py | 94 +++++++++++++++++ backend/tests/test_db_migration.py | 3 +- backend/trading_journal/crud.py | 35 +++++-- backend/trading_journal/dto.py | 50 ++++++--- backend/trading_journal/models.py | 3 +- backend/trading_journal/models_v1.py | 3 +- backend/trading_journal/service.py | 144 ++++++++++++++++++++++++-- 9 files changed, 455 insertions(+), 35 deletions(-) diff --git a/backend/.vscode/launch.json b/backend/.vscode/launch.json index 93b32ee..929a971 100644 --- a/backend/.vscode/launch.json +++ b/backend/.vscode/launch.json @@ -13,10 +13,14 @@ "app:app", "--host=0.0.0.0", "--reload", - "--port=5000" + "--port=18881" ], "jinja": true, - "autoStartBrowser": true + "autoStartBrowser": false, + "env": { + "CONFIG_FILE": "devsettings.yaml" + }, + "console": "integratedTerminal" } ] -} +} \ No newline at end of file diff --git a/backend/app.py b/backend/app.py index aae91b7..e0e2799 100644 --- a/backend/app.py +++ b/backend/app.py @@ -12,7 +12,22 @@ from fastapi.responses import JSONResponse, Response import settings from trading_journal import db, service -from trading_journal.dto import CycleBase, ExchangesBase, ExchangesRead, SessionsBase, SessionsCreate, UserCreate, UserLogin, UserRead +from trading_journal.dto import ( + CycleBase, + CycleRead, + CycleUpdate, + ExchangesBase, + ExchangesRead, + SessionsBase, + SessionsCreate, + TradeCreate, + TradeFriendlyNameUpdate, + TradeNoteUpdate, + TradeRead, + UserCreate, + UserLogin, + UserRead, +) if TYPE_CHECKING: from collections.abc import AsyncGenerator @@ -170,3 +185,134 @@ async def create_cycle(request: Request, cycle_data: CycleBase) -> Response: except Exception as e: logger.exception("Failed to create cycle: \n") raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Internal server error") from e + + +@app.get(f"{settings.settings.api_base}/cycles/{{cycle_id}}") +async def get_cycle_by_id(request: Request, cycle_id: int) -> Response: + db_factory: Database = request.app.state.db_factory + + def sync_work() -> CycleBase: + with db_factory.get_session_ctx_manager() as db: + return service.get_cycle_by_id_service(db, request.state.user_id, cycle_id) + + try: + cycle = await asyncio.to_thread(sync_work) + return JSONResponse(status_code=status.HTTP_200_OK, content=jsonable_encoder(cycle)) + except service.CycleNotFoundError as e: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e)) from e + except Exception as e: + logger.exception("Failed to get cycle by id: \n") + raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Internal server error") from e + + +@app.get(f"{settings.settings.api_base}/cycles/user/{{user_id}}") +async def get_cycles_by_user(request: Request, user_id: int) -> Response: + db_factory: Database = request.app.state.db_factory + + def sync_work() -> list[CycleRead]: + with db_factory.get_session_ctx_manager() as db: + return service.get_cycles_by_user_service(db, user_id) + + try: + cycles = await asyncio.to_thread(sync_work) + return JSONResponse(status_code=status.HTTP_200_OK, content=jsonable_encoder(cycles)) + except Exception as e: + logger.exception("Failed to get cycles by user: \n") + raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Internal server error") from e + + +@app.patch(f"{settings.settings.api_base}/cycles") +async def update_cycle(request: Request, cycle_data: CycleUpdate) -> Response: + db_factory: Database = request.app.state.db_factory + + def sync_work() -> CycleRead: + with db_factory.get_session_ctx_manager() as db: + return service.update_cycle_service(db, request.state.user_id, cycle_data) + + try: + cycle = await asyncio.to_thread(sync_work) + return JSONResponse(status_code=status.HTTP_200_OK, content=jsonable_encoder(cycle)) + except service.InvalidCycleDataError as e: + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) from e + except service.CycleNotFoundError as e: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e)) from e + except Exception as e: + logger.exception("Failed to update cycle: \n") + raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Internal server error") from e + + +@app.post(f"{settings.settings.api_base}/trades") +async def create_trade(request: Request, trade_data: TradeCreate) -> Response: + db_factory: Database = request.app.state.db_factory + + def sync_work() -> TradeRead: + with db_factory.get_session_ctx_manager() as db: + return service.create_trade_service(db, request.state.user_id, trade_data) + + try: + trade = await asyncio.to_thread(sync_work) + return JSONResponse(status_code=status.HTTP_201_CREATED, content=jsonable_encoder(trade)) + except service.InvalidTradeDataError as e: + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) from e + except Exception as e: + logger.exception("Failed to create trade: \n") + raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Internal server error") from e + + +@app.get(f"{settings.settings.api_base}/trades/{{trade_id}}") +async def get_trade_by_id(request: Request, trade_id: int) -> Response: + db_factory: Database = request.app.state.db_factory + + def sync_work() -> TradeRead: + with db_factory.get_session_ctx_manager() as db: + return service.get_trade_by_id_service(db, request.state.user_id, trade_id) + + try: + trade = await asyncio.to_thread(sync_work) + return JSONResponse(status_code=status.HTTP_200_OK, content=jsonable_encoder(trade)) + except service.TradeNotFoundError as e: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e)) from e + except Exception as e: + logger.exception("Failed to get trade by id: \n") + raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Internal server error") from e + + +@app.patch(f"{settings.settings.api_base}/trades/friendlyname") +async def update_trade_friendly_name(request: Request, friendly_name_update: TradeFriendlyNameUpdate) -> Response: + db_factory: Database = request.app.state.db_factory + + def sync_work() -> TradeRead: + with db_factory.get_session_ctx_manager() as db: + return service.update_trade_friendly_name_service( + db, + request.state.user_id, + friendly_name_update.id, + friendly_name_update.friendly_name, + ) + + try: + trade = await asyncio.to_thread(sync_work) + return JSONResponse(status_code=status.HTTP_200_OK, content=jsonable_encoder(trade)) + except service.TradeNotFoundError as e: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e)) from e + except Exception as e: + logger.exception("Failed to update trade friendly name: \n") + raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Internal server error") from e + + +@app.patch(f"{settings.settings.api_base}/trades/notes") +async def update_trade_note(request: Request, note_update: TradeNoteUpdate) -> Response: + db_factory: Database = request.app.state.db_factory + + def sync_work() -> TradeRead: + with db_factory.get_session_ctx_manager() as db: + return service.update_trade_note_service(db, request.state.user_id, note_update.id, note_update.notes) + + try: + trade = await asyncio.to_thread(sync_work) + return JSONResponse(status_code=status.HTTP_200_OK, content=jsonable_encoder(trade)) + except service.TradeNotFoundError as e: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e)) from e + except Exception as e: + logger.exception("Failed to update trade note: \n") + raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Internal server error") from e diff --git a/backend/tests/test_crud.py b/backend/tests/test_crud.py index 0620fed..d7ae484 100644 --- a/backend/tests/test_crud.py +++ b/backend/tests/test_crud.py @@ -172,6 +172,51 @@ def test_create_trade_success_with_cycle(session: Session) -> None: assert actual_trade.trade_type == trade_data["trade_type"] assert actual_trade.trade_strategy == trade_data["trade_strategy"] assert actual_trade.quantity == trade_data["quantity"] + assert actual_trade.quantity_multiplier == 1 + assert actual_trade.price_cents == trade_data["price_cents"] + assert actual_trade.gross_cash_flow_cents == trade_data["gross_cash_flow_cents"] + assert actual_trade.commission_cents == trade_data["commission_cents"] + assert actual_trade.net_cash_flow_cents == trade_data["net_cash_flow_cents"] + assert actual_trade.cycle_id == trade_data["cycle_id"] + + +def test_create_trade_with_custom_multipler(session: Session) -> None: + user_id = make_user(session) + exchange_id = make_exchange(session, user_id) + cycle_id = make_cycle(session, user_id, exchange_id) + + trade_data = { + "user_id": user_id, + "friendly_name": "Test Trade with Multiplier", + "symbol": "AAPL", + "underlying_currency": models.UnderlyingCurrency.USD, + "trade_type": models.TradeType.LONG_SPOT, + "trade_strategy": models.TradeStrategy.SPOT, + "trade_time_utc": datetime.now(timezone.utc), + "quantity": 10, + "quantity_multiplier": 100, + "price_cents": 15000, + "gross_cash_flow_cents": -1500000, + "commission_cents": 50000, + "net_cash_flow_cents": -1550000, + "cycle_id": cycle_id, + } + + trade = crud.create_trade(session, trade_data) + assert trade.id is not None + assert trade.user_id == user_id + assert trade.cycle_id == cycle_id + session.refresh(trade) + + actual_trade = session.get(models.Trades, trade.id) + assert actual_trade is not None + assert actual_trade.friendly_name == trade_data["friendly_name"] + assert actual_trade.symbol == trade_data["symbol"] + assert actual_trade.underlying_currency == trade_data["underlying_currency"] + assert actual_trade.trade_type == trade_data["trade_type"] + assert actual_trade.trade_strategy == trade_data["trade_strategy"] + assert actual_trade.quantity == trade_data["quantity"] + assert actual_trade.quantity_multiplier == trade_data["quantity_multiplier"] assert actual_trade.price_cents == trade_data["price_cents"] assert actual_trade.gross_cash_flow_cents == trade_data["gross_cash_flow_cents"] assert actual_trade.commission_cents == trade_data["commission_cents"] @@ -194,6 +239,9 @@ def test_create_trade_with_auto_created_cycle(session: Session) -> None: "trade_time_utc": datetime.now(timezone.utc), "quantity": 5, "price_cents": 15500, + "gross_cash_flow_cents": -77500, + "commission_cents": 300, + "net_cash_flow_cents": -77800, } trade = crud.create_trade(session, trade_data) @@ -405,6 +453,24 @@ def test_get_trades_by_user_id(session: Session) -> None: assert friendly_names == {"Trade One", "Trade Two"} +def test_update_trade_friendly_name(session: Session) -> None: + user_id = make_user(session) + exchange_id = make_exchange(session, user_id) + cycle_id = make_cycle(session, user_id, exchange_id) + trade_id = make_trade(session, user_id, cycle_id) + + new_friendly_name = "Updated Trade Name" + updated_trade = crud.update_trade_friendly_name(session, trade_id, new_friendly_name) + assert updated_trade is not None + assert updated_trade.id == trade_id + assert updated_trade.friendly_name == new_friendly_name + + session.refresh(updated_trade) + actual_trade = session.get(models.Trades, trade_id) + assert actual_trade is not None + assert actual_trade.friendly_name == new_friendly_name + + def test_update_trade_note(session: Session) -> None: user_id = make_user(session) exchange_id = make_exchange(session, user_id) @@ -457,6 +523,9 @@ def test_replace_trade(session: Session) -> None: "trade_time_utc": datetime.now(timezone.utc), "quantity": 20, "price_cents": 25000, + "gross_cash_flow_cents": -500000, + "commission_cents": 1000, + "net_cash_flow_cents": -501000, } new_trade = crud.replace_trade(session, old_trade_id, new_trade_data) @@ -516,6 +585,31 @@ def test_create_cycle(session: Session) -> None: assert actual_cycle.start_date == cycle_data["start_date"] +def test_get_cycle_by_id(session: Session) -> None: + user_id = make_user(session) + exchange_id = make_exchange(session, user_id) + cycle_id = make_cycle(session, user_id, exchange_id, friendly_name="Cycle to Get") + cycle = crud.get_cycle_by_id(session, cycle_id) + assert cycle is not None + assert cycle.id == cycle_id + assert cycle.friendly_name == "Cycle to Get" + assert cycle.user_id == user_id + + +def test_get_cycles_by_user_id(session: Session) -> None: + user_id = make_user(session) + exchange_id = make_exchange(session, user_id) + cycle_names = ["Cycle One", "Cycle Two", "Cycle Three"] + for name in cycle_names: + make_cycle(session, user_id, exchange_id, friendly_name=name) + + cycles = crud.get_cycles_by_user_id(session, user_id) + assert len(cycles) == len(cycle_names) + fetched_names = {cycle.friendly_name for cycle in cycles} + for name in cycle_names: + assert name in fetched_names + + def test_update_cycle(session: Session) -> None: user_id = make_user(session) exchange_id = make_exchange(session, user_id) diff --git a/backend/tests/test_db_migration.py b/backend/tests/test_db_migration.py index 343214b..042bb54 100644 --- a/backend/tests/test_db_migration.py +++ b/backend/tests/test_db_migration.py @@ -42,7 +42,7 @@ def test_run_migrations_0_to_1(monkeypatch: pytest.MonkeyPatch) -> None: "funding_source": ("TEXT", 0, 0), "capital_exposure_cents": ("INTEGER", 0, 0), "loan_amount_cents": ("INTEGER", 0, 0), - "loan_interest_rate_bps": ("INTEGER", 0, 0), + "loan_interest_rate_tenth_bps": ("INTEGER", 0, 0), "start_date": ("DATE", 1, 0), "end_date": ("DATE", 0, 0), }, @@ -60,6 +60,7 @@ def test_run_migrations_0_to_1(monkeypatch: pytest.MonkeyPatch) -> None: "expiry_date": ("DATE", 0, 0), "strike_price_cents": ("INTEGER", 0, 0), "quantity": ("INTEGER", 1, 0), + "quantity_multiplier": ("INTEGER", 1, 0), "price_cents": ("INTEGER", 1, 0), "gross_cash_flow_cents": ("INTEGER", 1, 0), "commission_cents": ("INTEGER", 1, 0), diff --git a/backend/trading_journal/crud.py b/backend/trading_journal/crud.py index 1918b5d..da8e237 100644 --- a/backend/trading_journal/crud.py +++ b/backend/trading_journal/crud.py @@ -90,13 +90,10 @@ def create_trade(session: Session, trade_data: Mapping[str, Any] | BaseModel) -> raise ValueError("price_cents is required") if "commission_cents" not in payload: payload["commission_cents"] = 0 - quantity: int = payload["quantity"] - price_cents: int = payload["price_cents"] - commission_cents: int = payload["commission_cents"] if "gross_cash_flow_cents" not in payload: - payload["gross_cash_flow_cents"] = -quantity * price_cents + raise ValueError("gross_cash_flow_cents is required") if "net_cash_flow_cents" not in payload: - payload["net_cash_flow_cents"] = payload["gross_cash_flow_cents"] - commission_cents + raise ValueError("net_cash_flow_cents is required") # If no cycle_id provided, create Cycle instance but don't call create_cycle() created_cycle = None @@ -163,6 +160,21 @@ def get_trades_by_user_id(session: Session, user_id: int) -> list[models.Trades] return list(session.exec(statement).all()) +def update_trade_friendly_name(session: Session, trade_id: int, friendly_name: str) -> models.Trades: + trade: models.Trades | None = session.get(models.Trades, trade_id) + if trade is None: + raise ValueError("trade_id does not exist") + trade.friendly_name = friendly_name + session.add(trade) + try: + session.flush() + except IntegrityError as e: + session.rollback() + raise ValueError("update_trade_friendly_name integrity error") from e + session.refresh(trade) + return trade + + def update_trade_note(session: Session, trade_id: int, note: str) -> models.Trades: trade: models.Trades | None = session.get(models.Trades, trade_id) if trade is None: @@ -240,7 +252,18 @@ def create_cycle(session: Session, cycle_data: Mapping[str, Any] | BaseModel) -> return c -IMMUTABLE_CYCLE_FIELDS = {"id", "user_id", "start_date", "created_at"} +IMMUTABLE_CYCLE_FIELDS = {"id", "user_id", "start_date"} + + +def get_cycle_by_id(session: Session, cycle_id: int) -> models.Cycles | None: + return session.get(models.Cycles, cycle_id) + + +def get_cycles_by_user_id(session: Session, user_id: int) -> list[models.Cycles]: + statement = select(models.Cycles).where( + models.Cycles.user_id == user_id, + ) + return list(session.exec(statement).all()) def update_cycle(session: Session, cycle_id: int, update_data: Mapping[str, Any] | BaseModel) -> models.Cycles: diff --git a/backend/trading_journal/dto.py b/backend/trading_journal/dto.py index 7b377f9..1851b0d 100644 --- a/backend/trading_journal/dto.py +++ b/backend/trading_journal/dto.py @@ -64,43 +64,53 @@ class ExchangesRead(ExchangesBase): class CycleBase(SQLModel): friendly_name: str | None = None - symbol: str - exchange_id: int - underlying_currency: UnderlyingCurrency status: str - start_date: date end_date: date | None = None funding_source: str | None = None capital_exposure_cents: int | None = None loan_amount_cents: int | None = None loan_interest_rate_bps: int | None = None trades: list[TradeRead] | None = None + exchange: ExchangesRead | None = None class CycleCreate(CycleBase): user_id: int + symbol: str + exchange_id: int + underlying_currency: UnderlyingCurrency + start_date: date + + +class CycleUpdate(CycleBase): + id: int + + +class CycleRead(CycleCreate): + id: int class TradeBase(SQLModel): - user_id: int - friendly_name: str | None + friendly_name: str | None = None symbol: str - exchange: str + exchange_id: int underlying_currency: UnderlyingCurrency trade_type: TradeType trade_strategy: TradeStrategy trade_date: date - trade_time_utc: datetime quantity: int price_cents: int - gross_cash_flow_cents: int commission_cents: int - net_cash_flow_cents: int - notes: str | None + notes: str | None = None cycle_id: int | None = None class TradeCreate(TradeBase): + user_id: int | None = None + trade_time_utc: datetime | None = None + gross_cash_flow_cents: int | None = None + net_cash_flow_cents: int | None = None + quantity_multiplier: int = 1 expiry_date: date | None = None strike_price_cents: int | None = None is_invalidated: bool = False @@ -108,7 +118,19 @@ class TradeCreate(TradeBase): replaced_by_trade_id: int | None = None -class TradeRead(TradeBase): +class TradeNoteUpdate(BaseModel): id: int - is_invalidated: bool - invalidated_at: datetime | None + notes: str | None = None + + +class TradeFriendlyNameUpdate(BaseModel): + id: int + friendly_name: str + + +class TradeRead(TradeCreate): + id: int + + +SessionsCreate.model_rebuild() +CycleBase.model_rebuild() diff --git a/backend/trading_journal/models.py b/backend/trading_journal/models.py index a7d364d..e8dc281 100644 --- a/backend/trading_journal/models.py +++ b/backend/trading_journal/models.py @@ -82,6 +82,7 @@ class Trades(SQLModel, table=True): expiry_date: date | None = Field(default=None, nullable=True) strike_price_cents: int | None = Field(default=None, nullable=True) quantity: int = Field(sa_column=Column(Integer, nullable=False)) + quantity_multiplier: int = Field(sa_column=Column(Integer, nullable=False), default=1) price_cents: int = Field(sa_column=Column(Integer, nullable=False)) gross_cash_flow_cents: int = Field(sa_column=Column(Integer, nullable=False)) commission_cents: int = Field(sa_column=Column(Integer, nullable=False)) @@ -109,7 +110,7 @@ class Cycles(SQLModel, table=True): funding_source: FundingSource = Field(sa_column=Column(Text, nullable=True)) capital_exposure_cents: int | None = Field(default=None, nullable=True) loan_amount_cents: int | None = Field(default=None, nullable=True) - loan_interest_rate_bps: int | None = Field(default=None, nullable=True) + loan_interest_rate_tenth_bps: int | None = Field(default=None, nullable=True) start_date: date = Field(sa_column=Column(Date, nullable=False)) end_date: date | None = Field(default=None, sa_column=Column(Date, nullable=True)) trades: list["Trades"] = Relationship(back_populates="cycle") diff --git a/backend/trading_journal/models_v1.py b/backend/trading_journal/models_v1.py index a7d364d..e8dc281 100644 --- a/backend/trading_journal/models_v1.py +++ b/backend/trading_journal/models_v1.py @@ -82,6 +82,7 @@ class Trades(SQLModel, table=True): expiry_date: date | None = Field(default=None, nullable=True) strike_price_cents: int | None = Field(default=None, nullable=True) quantity: int = Field(sa_column=Column(Integer, nullable=False)) + quantity_multiplier: int = Field(sa_column=Column(Integer, nullable=False), default=1) price_cents: int = Field(sa_column=Column(Integer, nullable=False)) gross_cash_flow_cents: int = Field(sa_column=Column(Integer, nullable=False)) commission_cents: int = Field(sa_column=Column(Integer, nullable=False)) @@ -109,7 +110,7 @@ class Cycles(SQLModel, table=True): funding_source: FundingSource = Field(sa_column=Column(Text, nullable=True)) capital_exposure_cents: int | None = Field(default=None, nullable=True) loan_amount_cents: int | None = Field(default=None, nullable=True) - loan_interest_rate_bps: int | None = Field(default=None, nullable=True) + loan_interest_rate_tenth_bps: int | None = Field(default=None, nullable=True) start_date: date = Field(sa_column=Column(Date, nullable=False)) end_date: date | None = Field(default=None, sa_column=Column(Date, nullable=True)) trades: list["Trades"] = Relationship(back_populates="cycle") diff --git a/backend/trading_journal/service.py b/backend/trading_journal/service.py index 5650515..07a3b19 100644 --- a/backend/trading_journal/service.py +++ b/backend/trading_journal/service.py @@ -13,19 +13,20 @@ from trading_journal import crud, security from trading_journal.dto import ( CycleBase, CycleCreate, + CycleRead, + CycleUpdate, ExchangesBase, ExchangesCreate, ExchangesRead, SessionsCreate, SessionsUpdate, + TradeCreate, + TradeRead, UserCreate, UserLogin, UserRead, ) -SessionsCreate.model_rebuild() -CycleBase.model_rebuild() - if TYPE_CHECKING: from sqlmodel import Session @@ -111,6 +112,22 @@ class ExchangeNotFoundError(ServiceError): pass +class CycleNotFoundError(ServiceError): + pass + + +class TradeNotFoundError(ServiceError): + pass + + +class InvalidTradeDataError(ServiceError): + pass + + +class InvalidCycleDataError(ServiceError): + pass + + # User service def register_user_service(db_session: Session, user_in: UserCreate) -> UserRead: if crud.get_user_by_username(db_session, user_in.username): @@ -211,13 +228,124 @@ def update_exchanges_service(db_session: Session, user_id: int, exchange_id: int # Cycle Service -def create_cycle_service(db_session: Session, user_id: int, cycle_data: CycleBase) -> CycleBase: +def create_cycle_service(db_session: Session, user_id: int, cycle_data: CycleBase) -> CycleRead: cycle_data_dict = cycle_data.model_dump() cycle_data_dict["user_id"] = user_id cycle_data_with_user_id: CycleCreate = CycleCreate.model_validate(cycle_data_dict) - crud.create_cycle(db_session, cycle_data=cycle_data_with_user_id) - return cycle_data + created_cycle = crud.create_cycle(db_session, cycle_data=cycle_data_with_user_id) + return CycleRead.model_validate(created_cycle) -def get_trades_service(db_session: Session, user_id: int) -> list: - return crud.get_trades_by_user_id(db_session, user_id) +def get_cycle_by_id_service(db_session: Session, user_id: int, cycle_id: int) -> CycleRead: + cycle = crud.get_cycle_by_id(db_session, cycle_id) + if not cycle: + raise CycleNotFoundError("Cycle not found") + if cycle.user_id != user_id: + raise CycleNotFoundError("Cycle not found") + return CycleRead.model_validate(cycle) + + +def get_cycles_by_user_service(db_session: Session, user_id: int) -> list[CycleRead]: + cycles = crud.get_cycles_by_user_id(db_session, user_id) + return [CycleRead.model_validate(cycle) for cycle in cycles] + + +def _validate_cycle_update_data(cycle_data: CycleUpdate) -> tuple[bool, str]: + if cycle_data.status == "CLOSED" and cycle_data.end_date is None: + return False, "end_date is required when status is CLOSED" + if cycle_data.status == "OPEN" and cycle_data.end_date is not None: + return False, "end_date must be empty when status is OPEN" + return True, "" + + +def update_cycle_service(db_session: Session, user_id: int, cycle_data: CycleUpdate) -> CycleRead: + is_valid, err_msg = _validate_cycle_update_data(cycle_data) + if not is_valid: + raise InvalidCycleDataError(err_msg) + cycle_id = cast("int", cycle_data.id) + existing_cycle = crud.get_cycle_by_id(db_session, cycle_id) + if not existing_cycle: + raise CycleNotFoundError("Cycle not found") + if existing_cycle.user_id != user_id: + raise CycleNotFoundError("Cycle not found") + + provided_data_dict = cycle_data.model_dump(exclude_unset=True) + cycle_data_with_user_id: CycleBase = CycleBase.model_validate(provided_data_dict) + + try: + updated_cycle = crud.update_cycle(db_session, cycle_id, update_data=cycle_data_with_user_id) + except Exception as e: + logger.exception("Failed to update cycle: \n") + raise ServiceError("Failed to update cycle") from e + return CycleRead.model_validate(updated_cycle) + + +# Trades service +def _append_cashflows(trade_data: TradeCreate) -> TradeCreate: + sign_multipler: int + if trade_data.trade_type in ("SELL_PUT", "SELL_CALL", "EXERCISE_CALL", "CLOSE_LONG_SPOT", "SHORT_SPOT"): + sign_multipler = 1 + else: + sign_multipler = -1 + quantity = trade_data.quantity * trade_data.quantity_multiplier + gross_cash_flow_cents = quantity * trade_data.price_cents * sign_multipler + net_cash_flow_cents = gross_cash_flow_cents - trade_data.commission_cents + trade_data.gross_cash_flow_cents = gross_cash_flow_cents + trade_data.net_cash_flow_cents = net_cash_flow_cents + return trade_data + + +def _validate_trade_data(trade_data: TradeCreate) -> bool: + return not ( + trade_data.trade_type in ("SELL_PUT", "SELL_CALL") and (trade_data.expiry_date is None or trade_data.strike_price_cents is None) + ) + + +def create_trade_service(db_session: Session, user_id: int, trade_data: TradeCreate) -> TradeRead: + if not _validate_trade_data(trade_data): + raise InvalidTradeDataError("Invalid trade data: expiry_date and strike_price_cents are required for SELL_PUT and SELL_CALL trades") + trade_data_dict = trade_data.model_dump() + trade_data_dict["user_id"] = user_id + trade_data_with_user_id: TradeCreate = TradeCreate.model_validate(trade_data_dict) + trade_data_with_user_id = _append_cashflows(trade_data_with_user_id) + created_trade = crud.create_trade(db_session, trade_data=trade_data_with_user_id) + return TradeRead.model_validate(created_trade) + + +def get_trade_by_id_service(db_session: Session, user_id: int, trade_id: int) -> TradeRead: + trade = crud.get_trade_by_id(db_session, trade_id) + if not trade: + raise TradeNotFoundError("Trade not found") + if trade.user_id != user_id: + raise TradeNotFoundError("Trade not found") + return TradeRead.model_validate(trade) + + +def update_trade_friendly_name_service(db_session: Session, user_id: int, trade_id: int, friendly_name: str) -> TradeRead: + existing_trade = crud.get_trade_by_id(db_session, trade_id) + if not existing_trade: + raise TradeNotFoundError("Trade not found") + if existing_trade.user_id != user_id: + raise TradeNotFoundError("Trade not found") + try: + updated_trade = crud.update_trade_friendly_name(db_session, trade_id, friendly_name) + except Exception as e: + logger.exception("Failed to update trade friendly name: \n") + raise ServiceError("Failed to update trade friendly name") from e + return TradeRead.model_validate(updated_trade) + + +def update_trade_note_service(db_session: Session, user_id: int, trade_id: int, note: str | None) -> TradeRead: + existing_trade = crud.get_trade_by_id(db_session, trade_id) + if not existing_trade: + raise TradeNotFoundError("Trade not found") + if existing_trade.user_id != user_id: + raise TradeNotFoundError("Trade not found") + if note is None: + note = "" + try: + updated_trade = crud.update_trade_note(db_session, trade_id, note) + except Exception as e: + logger.exception("Failed to update trade notes: \n") + raise ServiceError("Failed to update trade notes") from e + return TradeRead.model_validate(updated_trade)