feature/api_endpoint #5
@@ -7,11 +7,12 @@ from datetime import datetime, timezone
|
|||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from fastapi import FastAPI, HTTPException, Request, status
|
from fastapi import FastAPI, HTTPException, Request, status
|
||||||
|
from fastapi.encoders import jsonable_encoder
|
||||||
from fastapi.responses import JSONResponse, Response
|
from fastapi.responses import JSONResponse, Response
|
||||||
|
|
||||||
import settings
|
import settings
|
||||||
from trading_journal import db, service
|
from trading_journal import db, service
|
||||||
from trading_journal.dto import ExchangesBase, SessionsBase, SessionsCreate, UserCreate, UserLogin, UserRead
|
from trading_journal.dto import CycleBase, ExchangesBase, ExchangesRead, SessionsBase, SessionsCreate, UserCreate, UserLogin, UserRead
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from collections.abc import AsyncGenerator
|
from collections.abc import AsyncGenerator
|
||||||
@@ -119,10 +120,10 @@ async def create_exchange(request: Request, exchange_data: ExchangesBase) -> Res
|
|||||||
|
|
||||||
|
|
||||||
@app.get(f"{settings.settings.api_base}/exchanges")
|
@app.get(f"{settings.settings.api_base}/exchanges")
|
||||||
async def get_exchanges(request: Request) -> list[ExchangesBase]:
|
async def get_exchanges(request: Request) -> list[ExchangesRead]:
|
||||||
db_factory: Database = request.app.state.db_factory
|
db_factory: Database = request.app.state.db_factory
|
||||||
|
|
||||||
def sync_work() -> list[ExchangesBase]:
|
def sync_work() -> list[ExchangesRead]:
|
||||||
with db_factory.get_session_ctx_manager() as db:
|
with db_factory.get_session_ctx_manager() as db:
|
||||||
return service.get_exchanges_by_user_service(db, request.state.user_id)
|
return service.get_exchanges_by_user_service(db, request.state.user_id)
|
||||||
|
|
||||||
@@ -133,4 +134,39 @@ async def get_exchanges(request: Request) -> list[ExchangesBase]:
|
|||||||
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Internal server error") from e
|
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Internal server error") from e
|
||||||
|
|
||||||
|
|
||||||
# Trade
|
@app.patch(f"{settings.settings.api_base}/exchanges/{{exchange_id}}")
|
||||||
|
async def update_exchange(request: Request, exchange_id: int, exchange_data: ExchangesBase) -> Response:
|
||||||
|
db_factory: Database = request.app.state.db_factory
|
||||||
|
|
||||||
|
def sync_work() -> ExchangesBase:
|
||||||
|
with db_factory.get_session_ctx_manager() as db:
|
||||||
|
return service.update_exchanges_service(db, request.state.user_id, exchange_id, exchange_data.name, exchange_data.notes)
|
||||||
|
|
||||||
|
try:
|
||||||
|
exchange = await asyncio.to_thread(sync_work)
|
||||||
|
return JSONResponse(status_code=status.HTTP_200_OK, content=exchange.model_dump())
|
||||||
|
except service.ExchangeNotFoundError as e:
|
||||||
|
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e)) from e
|
||||||
|
except service.ExchangeAlreadyExistsError as e:
|
||||||
|
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) from e
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception("Failed to update exchange: \n")
|
||||||
|
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Internal server error") from e
|
||||||
|
|
||||||
|
|
||||||
|
# Cycle
|
||||||
|
@app.post(f"{settings.settings.api_base}/cycles")
|
||||||
|
async def create_cycle(request: Request, cycle_data: CycleBase) -> Response:
|
||||||
|
return JSONResponse(status_code=status.HTTP_405_METHOD_NOT_ALLOWED, content="Not supported.")
|
||||||
|
db_factory: Database = request.app.state.db_factory
|
||||||
|
|
||||||
|
def sync_work() -> CycleBase:
|
||||||
|
with db_factory.get_session_ctx_manager() as db:
|
||||||
|
return service.create_cycle_service(db, request.state.user_id, cycle_data)
|
||||||
|
|
||||||
|
try:
|
||||||
|
cycle = await asyncio.to_thread(sync_work)
|
||||||
|
return JSONResponse(status_code=status.HTTP_201_CREATED, content=jsonable_encoder(cycle))
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception("Failed to create cycle: \n")
|
||||||
|
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Internal server error") from e
|
||||||
|
|||||||
@@ -1,55 +1,11 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from datetime import date, datetime # noqa: TC003
|
from datetime import date, datetime # noqa: TC003
|
||||||
from typing import TYPE_CHECKING
|
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from sqlmodel import SQLModel
|
from sqlmodel import SQLModel
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
from trading_journal.models import TradeStrategy, TradeType, UnderlyingCurrency # noqa: TC001
|
||||||
from trading_journal.models import TradeStrategy, TradeType, UnderlyingCurrency
|
|
||||||
|
|
||||||
|
|
||||||
class ExchangesBase(SQLModel):
|
|
||||||
name: str
|
|
||||||
notes: str | None = None
|
|
||||||
|
|
||||||
|
|
||||||
class ExchangesCreate(ExchangesBase):
|
|
||||||
user_id: int
|
|
||||||
|
|
||||||
|
|
||||||
class TradeBase(SQLModel):
|
|
||||||
user_id: int
|
|
||||||
friendly_name: str | None
|
|
||||||
symbol: str
|
|
||||||
exchange: str
|
|
||||||
underlying_currency: UnderlyingCurrency
|
|
||||||
trade_type: TradeType
|
|
||||||
trade_strategy: TradeStrategy
|
|
||||||
trade_date: date
|
|
||||||
trade_time_utc: datetime
|
|
||||||
quantity: int
|
|
||||||
price_cents: int
|
|
||||||
gross_cash_flow_cents: int
|
|
||||||
commission_cents: int
|
|
||||||
net_cash_flow_cents: int
|
|
||||||
notes: str | None
|
|
||||||
cycle_id: int | None = None
|
|
||||||
|
|
||||||
|
|
||||||
class TradeCreate(TradeBase):
|
|
||||||
expiry_date: date | None = None
|
|
||||||
strike_price_cents: int | None = None
|
|
||||||
is_invalidated: bool = False
|
|
||||||
invalidated_at: datetime | None = None
|
|
||||||
replaced_by_trade_id: int | None = None
|
|
||||||
|
|
||||||
|
|
||||||
class TradeRead(TradeBase):
|
|
||||||
id: int
|
|
||||||
is_invalidated: bool
|
|
||||||
invalidated_at: datetime | None
|
|
||||||
|
|
||||||
|
|
||||||
class UserBase(SQLModel):
|
class UserBase(SQLModel):
|
||||||
@@ -91,3 +47,68 @@ class SessionsUpdate(SQLModel):
|
|||||||
last_seen_at: datetime | None = None
|
last_seen_at: datetime | None = None
|
||||||
last_used_ip: str | None = None
|
last_used_ip: str | None = None
|
||||||
user_agent: str | None = None
|
user_agent: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class ExchangesBase(SQLModel):
|
||||||
|
name: str
|
||||||
|
notes: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class ExchangesCreate(ExchangesBase):
|
||||||
|
user_id: int
|
||||||
|
|
||||||
|
|
||||||
|
class ExchangesRead(ExchangesBase):
|
||||||
|
id: int
|
||||||
|
|
||||||
|
|
||||||
|
class CycleBase(SQLModel):
|
||||||
|
friendly_name: str | None = None
|
||||||
|
symbol: str
|
||||||
|
exchange_id: int
|
||||||
|
underlying_currency: UnderlyingCurrency
|
||||||
|
status: str
|
||||||
|
start_date: date
|
||||||
|
end_date: date | None = None
|
||||||
|
funding_source: str | None = None
|
||||||
|
capital_exposure_cents: int | None = None
|
||||||
|
loan_amount_cents: int | None = None
|
||||||
|
loan_interest_rate_bps: int | None = None
|
||||||
|
trades: list[TradeRead] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class CycleCreate(CycleBase):
|
||||||
|
user_id: int
|
||||||
|
|
||||||
|
|
||||||
|
class TradeBase(SQLModel):
|
||||||
|
user_id: int
|
||||||
|
friendly_name: str | None
|
||||||
|
symbol: str
|
||||||
|
exchange: str
|
||||||
|
underlying_currency: UnderlyingCurrency
|
||||||
|
trade_type: TradeType
|
||||||
|
trade_strategy: TradeStrategy
|
||||||
|
trade_date: date
|
||||||
|
trade_time_utc: datetime
|
||||||
|
quantity: int
|
||||||
|
price_cents: int
|
||||||
|
gross_cash_flow_cents: int
|
||||||
|
commission_cents: int
|
||||||
|
net_cash_flow_cents: int
|
||||||
|
notes: str | None
|
||||||
|
cycle_id: int | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class TradeCreate(TradeBase):
|
||||||
|
expiry_date: date | None = None
|
||||||
|
strike_price_cents: int | None = None
|
||||||
|
is_invalidated: bool = False
|
||||||
|
invalidated_at: datetime | None = None
|
||||||
|
replaced_by_trade_id: int | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class TradeRead(TradeBase):
|
||||||
|
id: int
|
||||||
|
is_invalidated: bool
|
||||||
|
invalidated_at: datetime | None
|
||||||
|
|||||||
@@ -10,9 +10,21 @@ from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoin
|
|||||||
|
|
||||||
import settings
|
import settings
|
||||||
from trading_journal import crud, security
|
from trading_journal import crud, security
|
||||||
from trading_journal.dto import ExchangesBase, ExchangesCreate, SessionsCreate, SessionsUpdate, UserCreate, UserLogin, UserRead
|
from trading_journal.dto import (
|
||||||
|
CycleBase,
|
||||||
|
CycleCreate,
|
||||||
|
ExchangesBase,
|
||||||
|
ExchangesCreate,
|
||||||
|
ExchangesRead,
|
||||||
|
SessionsCreate,
|
||||||
|
SessionsUpdate,
|
||||||
|
UserCreate,
|
||||||
|
UserLogin,
|
||||||
|
UserRead,
|
||||||
|
)
|
||||||
|
|
||||||
SessionsCreate.model_rebuild()
|
SessionsCreate.model_rebuild()
|
||||||
|
CycleBase.model_rebuild()
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from sqlmodel import Session
|
from sqlmodel import Session
|
||||||
@@ -95,6 +107,11 @@ class ExchangeAlreadyExistsError(ServiceError):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class ExchangeNotFoundError(ServiceError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
# User service
|
||||||
def register_user_service(db_session: Session, user_in: UserCreate) -> UserRead:
|
def register_user_service(db_session: Session, user_in: UserCreate) -> UserRead:
|
||||||
if crud.get_user_by_username(db_session, user_in.username):
|
if crud.get_user_by_username(db_session, user_in.username):
|
||||||
raise UserAlreadyExistsError("username already exists")
|
raise UserAlreadyExistsError("username already exists")
|
||||||
@@ -156,7 +173,7 @@ def create_exchange_service(db_session: Session, user_id: int, name: str, notes:
|
|||||||
try:
|
try:
|
||||||
exchange_dto = ExchangesCreate.model_validate(exchange)
|
exchange_dto = ExchangesCreate.model_validate(exchange)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.exception("Failed to convert exchange to ExchangesCreate: ")
|
logger.exception("Failed to convert exchange to ExchangesCreate:")
|
||||||
raise ServiceError("Failed to convert exchange to ExchangesCreate") from e
|
raise ServiceError("Failed to convert exchange to ExchangesCreate") from e
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.exception("Failed to create exchange:")
|
logger.exception("Failed to create exchange:")
|
||||||
@@ -164,9 +181,42 @@ def create_exchange_service(db_session: Session, user_id: int, name: str, notes:
|
|||||||
return exchange_dto
|
return exchange_dto
|
||||||
|
|
||||||
|
|
||||||
def get_exchanges_by_user_service(db_session: Session, user_id: int) -> list[ExchangesBase]:
|
def get_exchanges_by_user_service(db_session: Session, user_id: int) -> list[ExchangesRead]:
|
||||||
exchanges = crud.get_all_exchanges_by_user_id(db_session, user_id)
|
exchanges = crud.get_all_exchanges_by_user_id(db_session, user_id)
|
||||||
return [ExchangesBase.model_validate(exchange) for exchange in exchanges]
|
return [ExchangesRead.model_validate(exchange) for exchange in exchanges]
|
||||||
|
|
||||||
|
|
||||||
|
def update_exchanges_service(db_session: Session, user_id: int, exchange_id: int, name: str | None, notes: str | None) -> ExchangesBase:
|
||||||
|
existing_exchange = crud.get_exchange_by_id(db_session, exchange_id)
|
||||||
|
if not existing_exchange:
|
||||||
|
raise ExchangeNotFoundError("Exchange not found")
|
||||||
|
if existing_exchange.user_id != user_id:
|
||||||
|
raise ExchangeNotFoundError("Exchange not found")
|
||||||
|
|
||||||
|
if name:
|
||||||
|
other_exchange = crud.get_exchange_by_name_and_user_id(db_session, name, user_id)
|
||||||
|
if other_exchange and other_exchange.id != existing_exchange.id:
|
||||||
|
raise ExchangeAlreadyExistsError("Another exchange with the same name already exists for this user")
|
||||||
|
|
||||||
|
exchange_data = ExchangesBase(
|
||||||
|
name=name or existing_exchange.name,
|
||||||
|
notes=notes or existing_exchange.notes,
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
exchange = crud.update_exchange(db_session, cast("int", existing_exchange.id), update_data=exchange_data)
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception("Failed to update exchange: \n")
|
||||||
|
raise ServiceError("Failed to update exchange") from e
|
||||||
|
return ExchangesBase.model_validate(exchange)
|
||||||
|
|
||||||
|
|
||||||
|
# Cycle Service
|
||||||
|
def create_cycle_service(db_session: Session, user_id: int, cycle_data: CycleBase) -> CycleBase:
|
||||||
|
cycle_data_dict = cycle_data.model_dump()
|
||||||
|
cycle_data_dict["user_id"] = user_id
|
||||||
|
cycle_data_with_user_id: CycleCreate = CycleCreate.model_validate(cycle_data_dict)
|
||||||
|
crud.create_cycle(db_session, cycle_data=cycle_data_with_user_id)
|
||||||
|
return cycle_data
|
||||||
|
|
||||||
|
|
||||||
def get_trades_service(db_session: Session, user_id: int) -> list:
|
def get_trades_service(db_session: Session, user_id: int) -> list:
|
||||||
|
|||||||
Reference in New Issue
Block a user