add create get exchange endpoint
All checks were successful
Backend CI / unit-test (push) Successful in 34s
All checks were successful
Backend CI / unit-test (push) Successful in 34s
This commit is contained in:
@@ -12,7 +12,7 @@ from fastapi.responses import JSONResponse
|
|||||||
import settings
|
import settings
|
||||||
from trading_journal import db, service
|
from trading_journal import db, service
|
||||||
from trading_journal.db import Database
|
from trading_journal.db import Database
|
||||||
from trading_journal.dto import SessionsBase, SessionsCreate, UserCreate, UserLogin, UserRead
|
from trading_journal.dto import ExchangesBase, SessionsBase, SessionsCreate, UserCreate, UserLogin, UserRead
|
||||||
|
|
||||||
_db = db.create_database(settings.settings.database_url)
|
_db = db.create_database(settings.settings.database_url)
|
||||||
|
|
||||||
@@ -96,12 +96,37 @@ async def login(request: Request, user_in: UserLogin) -> SessionsBase:
|
|||||||
|
|
||||||
|
|
||||||
# Exchange
|
# Exchange
|
||||||
# @app.post(f"{settings.settings.api_base}/exchanges")
|
@app.post(f"{settings.settings.api_base}/exchanges")
|
||||||
# async def create_exchange(request: Request, name: str, notes: str | None) -> dict:
|
async def create_exchange(request: Request, exchange_data: ExchangesBase) -> dict:
|
||||||
|
|
||||||
|
|
||||||
@app.get(f"{settings.settings.api_base}/trades")
|
|
||||||
async def get_trades(request: Request) -> list:
|
|
||||||
db_factory: Database = request.app.state.db_factory
|
db_factory: Database = request.app.state.db_factory
|
||||||
|
|
||||||
|
def sync_work() -> ExchangesBase:
|
||||||
with db_factory.get_session_ctx_manager() as db:
|
with db_factory.get_session_ctx_manager() as db:
|
||||||
return service.get_trades_service(db, request.state.user_id)
|
return service.create_exchange_service(db, request.state.user_id, exchange_data.name, exchange_data.notes)
|
||||||
|
|
||||||
|
try:
|
||||||
|
exchange = await asyncio.to_thread(sync_work)
|
||||||
|
return JSONResponse(status_code=status.HTTP_201_CREATED, content=exchange.model_dump())
|
||||||
|
except service.ExchangeAlreadyExistsError as e:
|
||||||
|
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) from e
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception("Failed to create exchange: \n")
|
||||||
|
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Internal server error") from e
|
||||||
|
|
||||||
|
|
||||||
|
@app.get(f"{settings.settings.api_base}/exchanges")
|
||||||
|
async def get_exchanges(request: Request) -> list[ExchangesBase]:
|
||||||
|
db_factory: Database = request.app.state.db_factory
|
||||||
|
|
||||||
|
def sync_work() -> list[ExchangesBase]:
|
||||||
|
with db_factory.get_session_ctx_manager() as db:
|
||||||
|
return service.get_exchanges_by_user_service(db, request.state.user_id)
|
||||||
|
|
||||||
|
try:
|
||||||
|
return await asyncio.to_thread(sync_work)
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception("Failed to get exchanges: \n")
|
||||||
|
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Internal server error") from e
|
||||||
|
|
||||||
|
|
||||||
|
# Trade
|
||||||
|
|||||||
@@ -10,6 +10,15 @@ if TYPE_CHECKING:
|
|||||||
from trading_journal.models import TradeStrategy, TradeType, UnderlyingCurrency
|
from trading_journal.models import TradeStrategy, TradeType, UnderlyingCurrency
|
||||||
|
|
||||||
|
|
||||||
|
class ExchangesBase(SQLModel):
|
||||||
|
name: str
|
||||||
|
notes: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class ExchangesCreate(ExchangesBase):
|
||||||
|
user_id: int
|
||||||
|
|
||||||
|
|
||||||
class TradeBase(SQLModel):
|
class TradeBase(SQLModel):
|
||||||
user_id: int
|
user_id: int
|
||||||
friendly_name: str | None
|
friendly_name: str | None
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ from starlette.middleware.base import BaseHTTPMiddleware
|
|||||||
import settings
|
import settings
|
||||||
from trading_journal import crud, security
|
from trading_journal import crud, security
|
||||||
from trading_journal.db import Database
|
from trading_journal.db import Database
|
||||||
from trading_journal.dto import SessionsCreate, SessionsUpdate, UserCreate, UserLogin, UserRead
|
from trading_journal.dto import ExchangesBase, ExchangesCreate, SessionsCreate, SessionsUpdate, UserCreate, UserLogin, UserRead
|
||||||
from trading_journal.models import Sessions
|
from trading_journal.models import Sessions
|
||||||
|
|
||||||
SessionsCreate.model_rebuild()
|
SessionsCreate.model_rebuild()
|
||||||
@@ -88,6 +88,10 @@ class UserAlreadyExistsError(ServiceError):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class ExchangeAlreadyExistsError(ServiceError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
def register_user_service(db_session: Session, user_in: UserCreate) -> UserRead:
|
def register_user_service(db_session: Session, user_in: UserCreate) -> UserRead:
|
||||||
if crud.get_user_by_username(db_session, user_in.username):
|
if crud.get_user_by_username(db_session, user_in.username):
|
||||||
raise UserAlreadyExistsError("username already exists")
|
raise UserAlreadyExistsError("username already exists")
|
||||||
@@ -133,5 +137,33 @@ def authenticate_user_service(db_session: Session, user_in: UserLogin) -> tuple[
|
|||||||
return SessionsCreate.model_validate(session), token
|
return SessionsCreate.model_validate(session), token
|
||||||
|
|
||||||
|
|
||||||
|
# Exchanges service
|
||||||
|
def create_exchange_service(db_session: Session, user_id: int, name: str, notes: str | None) -> ExchangesCreate:
|
||||||
|
existing_exchange = crud.get_exchange_by_name_and_user_id(db_session, name, user_id)
|
||||||
|
if existing_exchange:
|
||||||
|
raise ExchangeAlreadyExistsError("Exchange with the same name already exists for this user")
|
||||||
|
exchange_data = ExchangesCreate(
|
||||||
|
user_id=user_id,
|
||||||
|
name=name,
|
||||||
|
notes=notes,
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
exchange = crud.create_exchange(db_session, exchange_data=exchange_data)
|
||||||
|
try:
|
||||||
|
exchange_dto = ExchangesCreate.model_validate(exchange)
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception("Failed to convert exchange to ExchangesCreate: ")
|
||||||
|
raise ServiceError("Failed to convert exchange to ExchangesCreate") from e
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception("Failed to create exchange:")
|
||||||
|
raise ServiceError("Failed to create exchange") from e
|
||||||
|
return exchange_dto
|
||||||
|
|
||||||
|
|
||||||
|
def get_exchanges_by_user_service(db_session: Session, user_id: int) -> list[ExchangesBase]:
|
||||||
|
exchanges = crud.get_all_exchanges_by_user_id(db_session, user_id)
|
||||||
|
return [ExchangesBase.model_validate(exchange) for exchange in exchanges]
|
||||||
|
|
||||||
|
|
||||||
def get_trades_service(db_session: Session, user_id: int) -> list:
|
def get_trades_service(db_session: Session, user_id: int) -> list:
|
||||||
return crud.get_trades_by_user_id(db_session, user_id)
|
return crud.get_trades_by_user_id(db_session, user_id)
|
||||||
|
|||||||
Reference in New Issue
Block a user