Compare commits
19 Commits
544f5e8c92
...
5e7d801075
| Author | SHA1 | Date | |
|---|---|---|---|
| 5e7d801075 | |||
| 94fb4705ff | |||
| bb87b90285 | |||
| 5eae75b23e | |||
| 6a5f160d83 | |||
| 27b4adaca4 | |||
| e66aab99ea | |||
| 80fc405bf6 | |||
| cf6c826468 | |||
| a6592bd140 | |||
| 92c4e0d4fc | |||
| b68249f9f1 | |||
| 1750401278 | |||
| 466e6ce653 | |||
| e70a63e4f9 | |||
| 76ed38e9af | |||
| 1fbc93353d | |||
| 76cc967c42 | |||
| 442da655c0 |
2
backend/.gitignore
vendored
2
backend/.gitignore
vendored
@@ -15,3 +15,5 @@ __pycache__/
|
||||
*.db
|
||||
*.db-shm
|
||||
*.db-wal
|
||||
|
||||
devsettings.yaml
|
||||
8
backend/.vscode/launch.json
vendored
8
backend/.vscode/launch.json
vendored
@@ -13,10 +13,14 @@
|
||||
"app:app",
|
||||
"--host=0.0.0.0",
|
||||
"--reload",
|
||||
"--port=5000"
|
||||
"--port=18881"
|
||||
],
|
||||
"jinja": true,
|
||||
"autoStartBrowser": true
|
||||
"autoStartBrowser": false,
|
||||
"env": {
|
||||
"CONFIG_FILE": "devsettings.yaml"
|
||||
},
|
||||
"console": "integratedTerminal"
|
||||
}
|
||||
]
|
||||
}
|
||||
3
backend/.vscode/settings.json
vendored
3
backend/.vscode/settings.json
vendored
@@ -11,5 +11,6 @@
|
||||
"tests"
|
||||
],
|
||||
"python.testing.unittestEnabled": false,
|
||||
"python.testing.pytestEnabled": true
|
||||
"python.testing.pytestEnabled": true,
|
||||
"python.analysis.typeCheckingMode": "standard",
|
||||
}
|
||||
337
backend/app.py
337
backend/app.py
@@ -1,33 +1,324 @@
|
||||
from fastapi import FastAPI
|
||||
from __future__ import annotations
|
||||
|
||||
from models import MsgPayload
|
||||
import asyncio
|
||||
import logging
|
||||
from contextlib import asynccontextmanager
|
||||
from datetime import datetime, timezone
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
app = FastAPI()
|
||||
messages_list: dict[int, MsgPayload] = {}
|
||||
from fastapi import FastAPI, HTTPException, Request, status
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
from fastapi.responses import JSONResponse, Response
|
||||
|
||||
import settings
|
||||
from trading_journal import db, service
|
||||
from trading_journal.dto import (
|
||||
CycleBase,
|
||||
CycleRead,
|
||||
CycleUpdate,
|
||||
ExchangesBase,
|
||||
ExchangesRead,
|
||||
SessionsBase,
|
||||
SessionsCreate,
|
||||
TradeCreate,
|
||||
TradeFriendlyNameUpdate,
|
||||
TradeNoteUpdate,
|
||||
TradeRead,
|
||||
UserCreate,
|
||||
UserLogin,
|
||||
UserRead,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import AsyncGenerator
|
||||
|
||||
from trading_journal.db import Database
|
||||
|
||||
_db = db.create_database(settings.settings.database_url)
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.WARNING,
|
||||
format="%(asctime)s %(levelname)s %(name)s: %(message)s",
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@app.get("/")
|
||||
def root() -> dict[str, str]:
|
||||
return {"message": "Hello"}
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: # noqa: ARG001
|
||||
await asyncio.to_thread(_db.init_db)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
await asyncio.to_thread(_db.dispose)
|
||||
|
||||
|
||||
# About page route
|
||||
@app.get("/about")
|
||||
def about() -> dict[str, str]:
|
||||
return {"message": "This is the about page."}
|
||||
origins = [
|
||||
"http://127.0.0.1:18881",
|
||||
]
|
||||
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
app.add_middleware(
|
||||
service.AuthMiddleWare,
|
||||
)
|
||||
app.state.db_factory = _db
|
||||
|
||||
|
||||
# Route to add a message
|
||||
@app.post("/messages/{msg_name}/")
|
||||
def add_msg(msg_name: str) -> dict[str, MsgPayload]:
|
||||
# Generate an ID for the item based on the highest ID in the messages_list
|
||||
msg_id = max(messages_list.keys()) + 1 if messages_list else 0
|
||||
messages_list[msg_id] = MsgPayload(msg_id=msg_id, msg_name=msg_name)
|
||||
|
||||
return {"message": messages_list[msg_id]}
|
||||
@app.get(f"{settings.settings.api_base}/status")
|
||||
async def get_status() -> dict[str, str]:
|
||||
return {"status": "ok"}
|
||||
|
||||
|
||||
# Route to list all messages
|
||||
@app.get("/messages")
|
||||
def message_items() -> dict[str, dict[int, MsgPayload]]:
|
||||
return {"messages:": messages_list}
|
||||
@app.post(f"{settings.settings.api_base}/register")
|
||||
async def register_user(request: Request, user_in: UserCreate) -> Response:
|
||||
db_factory: Database = request.app.state.db_factory
|
||||
|
||||
def sync_work() -> UserRead:
|
||||
with db_factory.get_session_ctx_manager() as db:
|
||||
return service.register_user_service(db, user_in)
|
||||
|
||||
try:
|
||||
user = await asyncio.to_thread(sync_work)
|
||||
return JSONResponse(status_code=status.HTTP_201_CREATED, content=user.model_dump())
|
||||
except service.UserAlreadyExistsError as e:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) from e
|
||||
except Exception as e:
|
||||
logger.exception("Failed to register user: \n")
|
||||
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Internal Server Error") from e
|
||||
|
||||
|
||||
@app.post(f"{settings.settings.api_base}/login")
|
||||
async def login(request: Request, user_in: UserLogin) -> Response:
|
||||
db_factory: Database = request.app.state.db_factory
|
||||
|
||||
def sync_work() -> tuple[SessionsCreate, str] | None:
|
||||
with db_factory.get_session_ctx_manager() as db:
|
||||
return service.authenticate_user_service(db, user_in)
|
||||
|
||||
try:
|
||||
result = await asyncio.to_thread(sync_work)
|
||||
if result is None:
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
content={"detail": "Invalid username or password, or user doesn't exist"},
|
||||
)
|
||||
session, token = result
|
||||
session_return = SessionsBase(user_id=session.user_id)
|
||||
response = JSONResponse(status_code=status.HTTP_200_OK, content=session_return.model_dump())
|
||||
expires_sec = int((session.expires_at.replace(tzinfo=timezone.utc) - datetime.now(timezone.utc)).total_seconds())
|
||||
response.set_cookie(
|
||||
key="session_token",
|
||||
value=token,
|
||||
httponly=True,
|
||||
secure=True,
|
||||
samesite="lax",
|
||||
max_age=expires_sec,
|
||||
path="/",
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception("Failed to login user: \n")
|
||||
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Internal Server Error") from e
|
||||
else:
|
||||
return response
|
||||
|
||||
|
||||
# Exchange
|
||||
@app.post(f"{settings.settings.api_base}/exchanges")
|
||||
async def create_exchange(request: Request, exchange_data: ExchangesBase) -> Response:
|
||||
db_factory: Database = request.app.state.db_factory
|
||||
|
||||
def sync_work() -> ExchangesBase:
|
||||
with db_factory.get_session_ctx_manager() as db:
|
||||
return service.create_exchange_service(db, request.state.user_id, exchange_data.name, exchange_data.notes)
|
||||
|
||||
try:
|
||||
exchange = await asyncio.to_thread(sync_work)
|
||||
return JSONResponse(status_code=status.HTTP_201_CREATED, content=exchange.model_dump())
|
||||
except service.ExchangeAlreadyExistsError as e:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) from e
|
||||
except Exception as e:
|
||||
logger.exception("Failed to create exchange: \n")
|
||||
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Internal server error") from e
|
||||
|
||||
|
||||
@app.get(f"{settings.settings.api_base}/exchanges")
|
||||
async def get_exchanges(request: Request) -> list[ExchangesRead]:
|
||||
db_factory: Database = request.app.state.db_factory
|
||||
|
||||
def sync_work() -> list[ExchangesRead]:
|
||||
with db_factory.get_session_ctx_manager() as db:
|
||||
return service.get_exchanges_by_user_service(db, request.state.user_id)
|
||||
|
||||
try:
|
||||
return await asyncio.to_thread(sync_work)
|
||||
except Exception as e:
|
||||
logger.exception("Failed to get exchanges: \n")
|
||||
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Internal server error") from e
|
||||
|
||||
|
||||
@app.patch(f"{settings.settings.api_base}/exchanges/{{exchange_id}}")
|
||||
async def update_exchange(request: Request, exchange_id: int, exchange_data: ExchangesBase) -> Response:
|
||||
db_factory: Database = request.app.state.db_factory
|
||||
|
||||
def sync_work() -> ExchangesBase:
|
||||
with db_factory.get_session_ctx_manager() as db:
|
||||
return service.update_exchanges_service(db, request.state.user_id, exchange_id, exchange_data.name, exchange_data.notes)
|
||||
|
||||
try:
|
||||
exchange = await asyncio.to_thread(sync_work)
|
||||
return JSONResponse(status_code=status.HTTP_200_OK, content=exchange.model_dump())
|
||||
except service.ExchangeNotFoundError as e:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e)) from e
|
||||
except service.ExchangeAlreadyExistsError as e:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) from e
|
||||
except Exception as e:
|
||||
logger.exception("Failed to update exchange: \n")
|
||||
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Internal server error") from e
|
||||
|
||||
|
||||
# Cycle
|
||||
@app.post(f"{settings.settings.api_base}/cycles")
|
||||
async def create_cycle(request: Request, cycle_data: CycleBase) -> Response:
|
||||
return JSONResponse(status_code=status.HTTP_405_METHOD_NOT_ALLOWED, content="Not supported.")
|
||||
db_factory: Database = request.app.state.db_factory
|
||||
|
||||
def sync_work() -> CycleBase:
|
||||
with db_factory.get_session_ctx_manager() as db:
|
||||
return service.create_cycle_service(db, request.state.user_id, cycle_data)
|
||||
|
||||
try:
|
||||
cycle = await asyncio.to_thread(sync_work)
|
||||
return JSONResponse(status_code=status.HTTP_201_CREATED, content=jsonable_encoder(cycle))
|
||||
except Exception as e:
|
||||
logger.exception("Failed to create cycle: \n")
|
||||
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Internal server error") from e
|
||||
|
||||
|
||||
@app.get(f"{settings.settings.api_base}/cycles/{{cycle_id}}")
|
||||
async def get_cycle_by_id(request: Request, cycle_id: int) -> Response:
|
||||
db_factory: Database = request.app.state.db_factory
|
||||
|
||||
def sync_work() -> CycleBase:
|
||||
with db_factory.get_session_ctx_manager() as db:
|
||||
return service.get_cycle_by_id_service(db, request.state.user_id, cycle_id)
|
||||
|
||||
try:
|
||||
cycle = await asyncio.to_thread(sync_work)
|
||||
return JSONResponse(status_code=status.HTTP_200_OK, content=jsonable_encoder(cycle))
|
||||
except service.CycleNotFoundError as e:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e)) from e
|
||||
except Exception as e:
|
||||
logger.exception("Failed to get cycle by id: \n")
|
||||
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Internal server error") from e
|
||||
|
||||
|
||||
@app.get(f"{settings.settings.api_base}/cycles/user/{{user_id}}")
|
||||
async def get_cycles_by_user(request: Request, user_id: int) -> Response:
|
||||
db_factory: Database = request.app.state.db_factory
|
||||
|
||||
def sync_work() -> list[CycleRead]:
|
||||
with db_factory.get_session_ctx_manager() as db:
|
||||
return service.get_cycles_by_user_service(db, user_id)
|
||||
|
||||
try:
|
||||
cycles = await asyncio.to_thread(sync_work)
|
||||
return JSONResponse(status_code=status.HTTP_200_OK, content=jsonable_encoder(cycles))
|
||||
except Exception as e:
|
||||
logger.exception("Failed to get cycles by user: \n")
|
||||
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Internal server error") from e
|
||||
|
||||
|
||||
@app.patch(f"{settings.settings.api_base}/cycles")
|
||||
async def update_cycle(request: Request, cycle_data: CycleUpdate) -> Response:
|
||||
db_factory: Database = request.app.state.db_factory
|
||||
|
||||
def sync_work() -> CycleRead:
|
||||
with db_factory.get_session_ctx_manager() as db:
|
||||
return service.update_cycle_service(db, request.state.user_id, cycle_data)
|
||||
|
||||
try:
|
||||
cycle = await asyncio.to_thread(sync_work)
|
||||
return JSONResponse(status_code=status.HTTP_200_OK, content=jsonable_encoder(cycle))
|
||||
except service.InvalidCycleDataError as e:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) from e
|
||||
except service.CycleNotFoundError as e:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e)) from e
|
||||
except Exception as e:
|
||||
logger.exception("Failed to update cycle: \n")
|
||||
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Internal server error") from e
|
||||
|
||||
|
||||
@app.post(f"{settings.settings.api_base}/trades")
|
||||
async def create_trade(request: Request, trade_data: TradeCreate) -> Response:
|
||||
db_factory: Database = request.app.state.db_factory
|
||||
|
||||
def sync_work() -> TradeRead:
|
||||
with db_factory.get_session_ctx_manager() as db:
|
||||
return service.create_trade_service(db, request.state.user_id, trade_data)
|
||||
|
||||
try:
|
||||
trade = await asyncio.to_thread(sync_work)
|
||||
return JSONResponse(status_code=status.HTTP_201_CREATED, content=jsonable_encoder(trade))
|
||||
except service.InvalidTradeDataError as e:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) from e
|
||||
except Exception as e:
|
||||
logger.exception("Failed to create trade: \n")
|
||||
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Internal server error") from e
|
||||
|
||||
|
||||
@app.get(f"{settings.settings.api_base}/trades/{{trade_id}}")
|
||||
async def get_trade_by_id(request: Request, trade_id: int) -> Response:
|
||||
db_factory: Database = request.app.state.db_factory
|
||||
|
||||
def sync_work() -> TradeRead:
|
||||
with db_factory.get_session_ctx_manager() as db:
|
||||
return service.get_trade_by_id_service(db, request.state.user_id, trade_id)
|
||||
|
||||
try:
|
||||
trade = await asyncio.to_thread(sync_work)
|
||||
return JSONResponse(status_code=status.HTTP_200_OK, content=jsonable_encoder(trade))
|
||||
except service.TradeNotFoundError as e:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e)) from e
|
||||
except Exception as e:
|
||||
logger.exception("Failed to get trade by id: \n")
|
||||
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Internal server error") from e
|
||||
|
||||
|
||||
@app.patch(f"{settings.settings.api_base}/trades/friendlyname")
|
||||
async def update_trade_friendly_name(request: Request, friendly_name_update: TradeFriendlyNameUpdate) -> Response:
|
||||
db_factory: Database = request.app.state.db_factory
|
||||
|
||||
def sync_work() -> TradeRead:
|
||||
with db_factory.get_session_ctx_manager() as db:
|
||||
return service.update_trade_friendly_name_service(
|
||||
db,
|
||||
request.state.user_id,
|
||||
friendly_name_update.id,
|
||||
friendly_name_update.friendly_name,
|
||||
)
|
||||
|
||||
try:
|
||||
trade = await asyncio.to_thread(sync_work)
|
||||
return JSONResponse(status_code=status.HTTP_200_OK, content=jsonable_encoder(trade))
|
||||
except service.TradeNotFoundError as e:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e)) from e
|
||||
except Exception as e:
|
||||
logger.exception("Failed to update trade friendly name: \n")
|
||||
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Internal server error") from e
|
||||
|
||||
|
||||
@app.patch(f"{settings.settings.api_base}/trades/notes")
|
||||
async def update_trade_note(request: Request, note_update: TradeNoteUpdate) -> Response:
|
||||
db_factory: Database = request.app.state.db_factory
|
||||
|
||||
def sync_work() -> TradeRead:
|
||||
with db_factory.get_session_ctx_manager() as db:
|
||||
return service.update_trade_note_service(db, request.state.user_id, note_update.id, note_update.notes)
|
||||
|
||||
try:
|
||||
trade = await asyncio.to_thread(sync_work)
|
||||
return JSONResponse(status_code=status.HTTP_200_OK, content=jsonable_encoder(trade))
|
||||
except service.TradeNotFoundError as e:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e)) from e
|
||||
except Exception as e:
|
||||
logger.exception("Failed to update trade note: \n")
|
||||
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Internal server error") from e
|
||||
|
||||
@@ -14,12 +14,130 @@ anyio==4.10.0 \
|
||||
# via
|
||||
# httpx
|
||||
# starlette
|
||||
argon2-cffi==25.1.0 \
|
||||
--hash=sha256:694ae5cc8a42f4c4e2bf2ca0e64e51e23a040c6a517a85074683d3959e1346c1 \
|
||||
--hash=sha256:fdc8b074db390fccb6eb4a3604ae7231f219aa669a2652e0f20e16ba513d5741
|
||||
# via -r requirements.in
|
||||
argon2-cffi-bindings==25.1.0 \
|
||||
--hash=sha256:1db89609c06afa1a214a69a462ea741cf735b29a57530478c06eb81dd403de99 \
|
||||
--hash=sha256:1e021e87faa76ae0d413b619fe2b65ab9a037f24c60a1e6cc43457ae20de6dc6 \
|
||||
--hash=sha256:21378b40e1b8d1655dd5310c84a40fc19a9aa5e6366e835ceb8576bf0fea716d \
|
||||
--hash=sha256:2630b6240b495dfab90aebe159ff784d08ea999aa4b0d17efa734055a07d2f44 \
|
||||
--hash=sha256:3c6702abc36bf3ccba3f802b799505def420a1b7039862014a65db3205967f5a \
|
||||
--hash=sha256:3d3f05610594151994ca9ccb3c771115bdb4daef161976a266f0dd8aa9996b8f \
|
||||
--hash=sha256:473bcb5f82924b1becbb637b63303ec8d10e84c8d241119419897a26116515d2 \
|
||||
--hash=sha256:5acb4e41090d53f17ca1110c3427f0a130f944b896fc8c83973219c97f57b690 \
|
||||
--hash=sha256:5d588dec224e2a83edbdc785a5e6f3c6cd736f46bfd4b441bbb5aa1f5085e584 \
|
||||
--hash=sha256:6dca33a9859abf613e22733131fc9194091c1fa7cb3e131c143056b4856aa47e \
|
||||
--hash=sha256:7aef0c91e2c0fbca6fc68e7555aa60ef7008a739cbe045541e438373bc54d2b0 \
|
||||
--hash=sha256:84a461d4d84ae1295871329b346a97f68eade8c53b6ed9a7ca2d7467f3c8ff6f \
|
||||
--hash=sha256:87c33a52407e4c41f3b70a9c2d3f6056d88b10dad7695be708c5021673f55623 \
|
||||
--hash=sha256:8b8efee945193e667a396cbc7b4fb7d357297d6234d30a489905d96caabde56b \
|
||||
--hash=sha256:a1c70058c6ab1e352304ac7e3b52554daadacd8d453c1752e547c76e9c99ac44 \
|
||||
--hash=sha256:a98cd7d17e9f7ce244c0803cad3c23a7d379c301ba618a5fa76a67d116618b98 \
|
||||
--hash=sha256:aecba1723ae35330a008418a91ea6cfcedf6d31e5fbaa056a166462ff066d500 \
|
||||
--hash=sha256:b0fdbcf513833809c882823f98dc2f931cf659d9a1429616ac3adebb49f5db94 \
|
||||
--hash=sha256:b55aec3565b65f56455eebc9b9f34130440404f27fe21c3b375bf1ea4d8fbae6 \
|
||||
--hash=sha256:b957f3e6ea4d55d820e40ff76f450952807013d361a65d7f28acc0acbf29229d \
|
||||
--hash=sha256:ba92837e4a9aa6a508c8d2d7883ed5a8f6c308c89a4790e1e447a220deb79a85 \
|
||||
--hash=sha256:c4f9665de60b1b0e99bcd6be4f17d90339698ce954cfd8d9cf4f91c995165a92 \
|
||||
--hash=sha256:c87b72589133f0346a1cb8d5ecca4b933e3c9b64656c9d175270a000e73b288d \
|
||||
--hash=sha256:d3e924cfc503018a714f94a49a149fdc0b644eaead5d1f089330399134fa028a \
|
||||
--hash=sha256:da0c79c23a63723aa5d782250fbf51b768abca630285262fb5144ba5ae01e520 \
|
||||
--hash=sha256:e2fd3bfbff3c5d74fef31a722f729bf93500910db650c925c2d6ef879a7e51cb
|
||||
# via argon2-cffi
|
||||
certifi==2025.8.3 \
|
||||
--hash=sha256:e564105f78ded564e3ae7c923924435e1daa7463faeab5bb932bc53ffae63407 \
|
||||
--hash=sha256:f6c12493cfb1b06ba2ff328595af9350c65d6644968e5d3a2ffd78699af217a5
|
||||
# via
|
||||
# httpcore
|
||||
# httpx
|
||||
cffi==2.0.0 \
|
||||
--hash=sha256:00bdf7acc5f795150faa6957054fbbca2439db2f775ce831222b66f192f03beb \
|
||||
--hash=sha256:07b271772c100085dd28b74fa0cd81c8fb1a3ba18b21e03d7c27f3436a10606b \
|
||||
--hash=sha256:087067fa8953339c723661eda6b54bc98c5625757ea62e95eb4898ad5e776e9f \
|
||||
--hash=sha256:0a1527a803f0a659de1af2e1fd700213caba79377e27e4693648c2923da066f9 \
|
||||
--hash=sha256:0cf2d91ecc3fcc0625c2c530fe004f82c110405f101548512cce44322fa8ac44 \
|
||||
--hash=sha256:0f6084a0ea23d05d20c3edcda20c3d006f9b6f3fefeac38f59262e10cef47ee2 \
|
||||
--hash=sha256:12873ca6cb9b0f0d3a0da705d6086fe911591737a59f28b7936bdfed27c0d47c \
|
||||
--hash=sha256:19f705ada2530c1167abacb171925dd886168931e0a7b78f5bffcae5c6b5be75 \
|
||||
--hash=sha256:1cd13c99ce269b3ed80b417dcd591415d3372bcac067009b6e0f59c7d4015e65 \
|
||||
--hash=sha256:1e3a615586f05fc4065a8b22b8152f0c1b00cdbc60596d187c2a74f9e3036e4e \
|
||||
--hash=sha256:1f72fb8906754ac8a2cc3f9f5aaa298070652a0ffae577e0ea9bd480dc3c931a \
|
||||
--hash=sha256:1fc9ea04857caf665289b7a75923f2c6ed559b8298a1b8c49e59f7dd95c8481e \
|
||||
--hash=sha256:203a48d1fb583fc7d78a4c6655692963b860a417c0528492a6bc21f1aaefab25 \
|
||||
--hash=sha256:2081580ebb843f759b9f617314a24ed5738c51d2aee65d31e02f6f7a2b97707a \
|
||||
--hash=sha256:21d1152871b019407d8ac3985f6775c079416c282e431a4da6afe7aefd2bccbe \
|
||||
--hash=sha256:24b6f81f1983e6df8db3adc38562c83f7d4a0c36162885ec7f7b77c7dcbec97b \
|
||||
--hash=sha256:256f80b80ca3853f90c21b23ee78cd008713787b1b1e93eae9f3d6a7134abd91 \
|
||||
--hash=sha256:28a3a209b96630bca57cce802da70c266eb08c6e97e5afd61a75611ee6c64592 \
|
||||
--hash=sha256:2c8f814d84194c9ea681642fd164267891702542f028a15fc97d4674b6206187 \
|
||||
--hash=sha256:2de9a304e27f7596cd03d16f1b7c72219bd944e99cc52b84d0145aefb07cbd3c \
|
||||
--hash=sha256:38100abb9d1b1435bc4cc340bb4489635dc2f0da7456590877030c9b3d40b0c1 \
|
||||
--hash=sha256:3925dd22fa2b7699ed2617149842d2e6adde22b262fcbfada50e3d195e4b3a94 \
|
||||
--hash=sha256:3e17ed538242334bf70832644a32a7aae3d83b57567f9fd60a26257e992b79ba \
|
||||
--hash=sha256:3e837e369566884707ddaf85fc1744b47575005c0a229de3327f8f9a20f4efeb \
|
||||
--hash=sha256:3f4d46d8b35698056ec29bca21546e1551a205058ae1a181d871e278b0b28165 \
|
||||
--hash=sha256:44d1b5909021139fe36001ae048dbdde8214afa20200eda0f64c068cac5d5529 \
|
||||
--hash=sha256:45d5e886156860dc35862657e1494b9bae8dfa63bf56796f2fb56e1679fc0bca \
|
||||
--hash=sha256:4647afc2f90d1ddd33441e5b0e85b16b12ddec4fca55f0d9671fef036ecca27c \
|
||||
--hash=sha256:4671d9dd5ec934cb9a73e7ee9676f9362aba54f7f34910956b84d727b0d73fb6 \
|
||||
--hash=sha256:53f77cbe57044e88bbd5ed26ac1d0514d2acf0591dd6bb02a3ae37f76811b80c \
|
||||
--hash=sha256:5eda85d6d1879e692d546a078b44251cdd08dd1cfb98dfb77b670c97cee49ea0 \
|
||||
--hash=sha256:5fed36fccc0612a53f1d4d9a816b50a36702c28a2aa880cb8a122b3466638743 \
|
||||
--hash=sha256:61d028e90346df14fedc3d1e5441df818d095f3b87d286825dfcbd6459b7ef63 \
|
||||
--hash=sha256:66f011380d0e49ed280c789fbd08ff0d40968ee7b665575489afa95c98196ab5 \
|
||||
--hash=sha256:6824f87845e3396029f3820c206e459ccc91760e8fa24422f8b0c3d1731cbec5 \
|
||||
--hash=sha256:6c6c373cfc5c83a975506110d17457138c8c63016b563cc9ed6e056a82f13ce4 \
|
||||
--hash=sha256:6d02d6655b0e54f54c4ef0b94eb6be0607b70853c45ce98bd278dc7de718be5d \
|
||||
--hash=sha256:6d50360be4546678fc1b79ffe7a66265e28667840010348dd69a314145807a1b \
|
||||
--hash=sha256:730cacb21e1bdff3ce90babf007d0a0917cc3e6492f336c2f0134101e0944f93 \
|
||||
--hash=sha256:737fe7d37e1a1bffe70bd5754ea763a62a066dc5913ca57e957824b72a85e205 \
|
||||
--hash=sha256:74a03b9698e198d47562765773b4a8309919089150a0bb17d829ad7b44b60d27 \
|
||||
--hash=sha256:7553fb2090d71822f02c629afe6042c299edf91ba1bf94951165613553984512 \
|
||||
--hash=sha256:7a66c7204d8869299919db4d5069a82f1561581af12b11b3c9f48c584eb8743d \
|
||||
--hash=sha256:7cc09976e8b56f8cebd752f7113ad07752461f48a58cbba644139015ac24954c \
|
||||
--hash=sha256:81afed14892743bbe14dacb9e36d9e0e504cd204e0b165062c488942b9718037 \
|
||||
--hash=sha256:8941aaadaf67246224cee8c3803777eed332a19d909b47e29c9842ef1e79ac26 \
|
||||
--hash=sha256:89472c9762729b5ae1ad974b777416bfda4ac5642423fa93bd57a09204712322 \
|
||||
--hash=sha256:8ea985900c5c95ce9db1745f7933eeef5d314f0565b27625d9a10ec9881e1bfb \
|
||||
--hash=sha256:8eca2a813c1cb7ad4fb74d368c2ffbbb4789d377ee5bb8df98373c2cc0dee76c \
|
||||
--hash=sha256:92b68146a71df78564e4ef48af17551a5ddd142e5190cdf2c5624d0c3ff5b2e8 \
|
||||
--hash=sha256:9332088d75dc3241c702d852d4671613136d90fa6881da7d770a483fd05248b4 \
|
||||
--hash=sha256:94698a9c5f91f9d138526b48fe26a199609544591f859c870d477351dc7b2414 \
|
||||
--hash=sha256:9a67fc9e8eb39039280526379fb3a70023d77caec1852002b4da7e8b270c4dd9 \
|
||||
--hash=sha256:9de40a7b0323d889cf8d23d1ef214f565ab154443c42737dfe52ff82cf857664 \
|
||||
--hash=sha256:a05d0c237b3349096d3981b727493e22147f934b20f6f125a3eba8f994bec4a9 \
|
||||
--hash=sha256:afb8db5439b81cf9c9d0c80404b60c3cc9c3add93e114dcae767f1477cb53775 \
|
||||
--hash=sha256:b18a3ed7d5b3bd8d9ef7a8cb226502c6bf8308df1525e1cc676c3680e7176739 \
|
||||
--hash=sha256:b1e74d11748e7e98e2f426ab176d4ed720a64412b6a15054378afdb71e0f37dc \
|
||||
--hash=sha256:b21e08af67b8a103c71a250401c78d5e0893beff75e28c53c98f4de42f774062 \
|
||||
--hash=sha256:b4c854ef3adc177950a8dfc81a86f5115d2abd545751a304c5bcf2c2c7283cfe \
|
||||
--hash=sha256:b882b3df248017dba09d6b16defe9b5c407fe32fc7c65a9c69798e6175601be9 \
|
||||
--hash=sha256:baf5215e0ab74c16e2dd324e8ec067ef59e41125d3eade2b863d294fd5035c92 \
|
||||
--hash=sha256:c649e3a33450ec82378822b3dad03cc228b8f5963c0c12fc3b1e0ab940f768a5 \
|
||||
--hash=sha256:c654de545946e0db659b3400168c9ad31b5d29593291482c43e3564effbcee13 \
|
||||
--hash=sha256:c6638687455baf640e37344fe26d37c404db8b80d037c3d29f58fe8d1c3b194d \
|
||||
--hash=sha256:c8d3b5532fc71b7a77c09192b4a5a200ea992702734a2e9279a37f2478236f26 \
|
||||
--hash=sha256:cb527a79772e5ef98fb1d700678fe031e353e765d1ca2d409c92263c6d43e09f \
|
||||
--hash=sha256:cf364028c016c03078a23b503f02058f1814320a56ad535686f90565636a9495 \
|
||||
--hash=sha256:d48a880098c96020b02d5a1f7d9251308510ce8858940e6fa99ece33f610838b \
|
||||
--hash=sha256:d68b6cef7827e8641e8ef16f4494edda8b36104d79773a334beaa1e3521430f6 \
|
||||
--hash=sha256:d9b29c1f0ae438d5ee9acb31cadee00a58c46cc9c0b2f9038c6b0b3470877a8c \
|
||||
--hash=sha256:d9b97165e8aed9272a6bb17c01e3cc5871a594a446ebedc996e2397a1c1ea8ef \
|
||||
--hash=sha256:da68248800ad6320861f129cd9c1bf96ca849a2771a59e0344e88681905916f5 \
|
||||
--hash=sha256:da902562c3e9c550df360bfa53c035b2f241fed6d9aef119048073680ace4a18 \
|
||||
--hash=sha256:dbd5c7a25a7cb98f5ca55d258b103a2054f859a46ae11aaf23134f9cc0d356ad \
|
||||
--hash=sha256:dd4f05f54a52fb558f1ba9f528228066954fee3ebe629fc1660d874d040ae5a3 \
|
||||
--hash=sha256:de8dad4425a6ca6e4e5e297b27b5c824ecc7581910bf9aee86cb6835e6812aa7 \
|
||||
--hash=sha256:e11e82b744887154b182fd3e7e8512418446501191994dbf9c9fc1f32cc8efd5 \
|
||||
--hash=sha256:e6e73b9e02893c764e7e8d5bb5ce277f1a009cd5243f8228f75f842bf937c534 \
|
||||
--hash=sha256:f73b96c41e3b2adedc34a7356e64c8eb96e03a3782b535e043a986276ce12a49 \
|
||||
--hash=sha256:f93fd8e5c8c0a4aa1f424d6173f14a892044054871c771f8566e4008eaa359d2 \
|
||||
--hash=sha256:fc33c5141b55ed366cfaad382df24fe7dcbc686de5be719b207bb248e3053dc5 \
|
||||
--hash=sha256:fc7de24befaeae77ba923797c7c87834c73648a05a4bde34b3b7e5588973a453 \
|
||||
--hash=sha256:fe562eb1a64e67dd297ccc4f5addea2501664954f2692b69a76449ec7913ecbf
|
||||
# via argon2-cffi-bindings
|
||||
click==8.2.1 \
|
||||
--hash=sha256:27c491cc05d968d271d5a1db13e3b5a184636d9d930f148c50b038f0d0646202 \
|
||||
--hash=sha256:61a3265b914e850b85317d0b3109c7f8cd35a670f963866005d6ef1d5175a12b
|
||||
@@ -116,6 +234,10 @@ pluggy==1.6.0 \
|
||||
--hash=sha256:7dcc130b76258d33b90f61b658791dede3486c3e6bfb003ee5c9bfb396dd22f3 \
|
||||
--hash=sha256:e920276dd6813095e9377c0bc5566d94c932c33b27a3e3945d8389c374dd4746
|
||||
# via pytest
|
||||
pycparser==2.23 \
|
||||
--hash=sha256:78816d4f24add8f10a06d6f05b4d424ad9e96cfebf68a4ddc99c65c0720d00c2 \
|
||||
--hash=sha256:e5c6e8d3fbad53479cab09ac03729e0a9faf2bee3db8208a550daf5af81a5934
|
||||
# via cffi
|
||||
pydantic==2.11.7 \
|
||||
--hash=sha256:d989c3c6cb79469287b1569f7447a17848c998458d49ebe294e975b9baf0f0db \
|
||||
--hash=sha256:dde5df002701f6de26248661f6835bbe296a47bf73990135c7d07ce741b9623b
|
||||
|
||||
@@ -1,7 +0,0 @@
|
||||
from typing import Optional
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class MsgPayload(BaseModel):
|
||||
msg_id: Optional[int]
|
||||
msg_name: str
|
||||
554
backend/openapi.yaml
Normal file
554
backend/openapi.yaml
Normal file
@@ -0,0 +1,554 @@
|
||||
openapi: "3.0.3"
|
||||
info:
|
||||
title: Trading Journal API
|
||||
version: "1.0.0"
|
||||
description: OpenAPI description generated from [`app.py`](app.py) and DTOs in [`trading_journal/dto.py`](trading_journal/dto.py).
|
||||
servers:
|
||||
- url: "http://127.0.0.1:18881{basePath}"
|
||||
variables:
|
||||
basePath:
|
||||
default: "/api/v1"
|
||||
description: "API base path (matches settings.settings.api_base)"
|
||||
components:
|
||||
securitySchemes:
|
||||
session_cookie:
|
||||
type: apiKey
|
||||
in: cookie
|
||||
name: session_token
|
||||
schemas:
|
||||
UserCreate:
|
||||
$ref: "#/components/schemas/UserCreate_impl"
|
||||
UserCreate_impl:
|
||||
type: object
|
||||
required:
|
||||
- username
|
||||
- password
|
||||
properties:
|
||||
username:
|
||||
type: string
|
||||
is_active:
|
||||
type: boolean
|
||||
default: true
|
||||
password:
|
||||
type: string
|
||||
UserLogin:
|
||||
type: object
|
||||
required:
|
||||
- username
|
||||
- password
|
||||
properties:
|
||||
username:
|
||||
type: string
|
||||
password:
|
||||
type: string
|
||||
UserRead:
|
||||
type: object
|
||||
required:
|
||||
- id
|
||||
- username
|
||||
properties:
|
||||
id:
|
||||
type: integer
|
||||
username:
|
||||
type: string
|
||||
is_active:
|
||||
type: boolean
|
||||
SessionsBase:
|
||||
type: object
|
||||
required:
|
||||
- user_id
|
||||
properties:
|
||||
user_id:
|
||||
type: integer
|
||||
SessionsCreate:
|
||||
allOf:
|
||||
- $ref: "#/components/schemas/SessionsBase"
|
||||
- type: object
|
||||
required:
|
||||
- expires_at
|
||||
properties:
|
||||
expires_at:
|
||||
type: string
|
||||
format: date-time
|
||||
ExchangesBase:
|
||||
type: object
|
||||
required:
|
||||
- name
|
||||
properties:
|
||||
name:
|
||||
type: string
|
||||
notes:
|
||||
type: string
|
||||
nullable: true
|
||||
ExchangesRead:
|
||||
allOf:
|
||||
- $ref: "#/components/schemas/ExchangesBase"
|
||||
- type: object
|
||||
required:
|
||||
- id
|
||||
properties:
|
||||
id:
|
||||
type: integer
|
||||
CycleBase:
|
||||
type: object
|
||||
properties:
|
||||
friendly_name:
|
||||
type: string
|
||||
nullable: true
|
||||
status:
|
||||
type: string
|
||||
end_date:
|
||||
type: string
|
||||
format: date
|
||||
nullable: true
|
||||
funding_source:
|
||||
type: string
|
||||
nullable: true
|
||||
capital_exposure_cents:
|
||||
type: integer
|
||||
nullable: true
|
||||
loan_amount_cents:
|
||||
type: integer
|
||||
nullable: true
|
||||
loan_interest_rate_tenth_bps:
|
||||
type: integer
|
||||
nullable: true
|
||||
trades:
|
||||
type: array
|
||||
items:
|
||||
$ref: "#/components/schemas/TradeRead"
|
||||
nullable: true
|
||||
exchange:
|
||||
$ref: "#/components/schemas/ExchangesRead"
|
||||
nullable: true
|
||||
CycleCreate:
|
||||
allOf:
|
||||
- $ref: "#/components/schemas/CycleBase"
|
||||
- type: object
|
||||
required:
|
||||
- user_id
|
||||
- symbol
|
||||
- exchange_id
|
||||
- underlying_currency
|
||||
- start_date
|
||||
properties:
|
||||
user_id:
|
||||
type: integer
|
||||
symbol:
|
||||
type: string
|
||||
exchange_id:
|
||||
type: integer
|
||||
underlying_currency:
|
||||
type: string
|
||||
start_date:
|
||||
type: string
|
||||
format: date
|
||||
CycleUpdate:
|
||||
allOf:
|
||||
- $ref: "#/components/schemas/CycleBase"
|
||||
- type: object
|
||||
required:
|
||||
- id
|
||||
properties:
|
||||
id:
|
||||
type: integer
|
||||
CycleRead:
|
||||
allOf:
|
||||
- $ref: "#/components/schemas/CycleCreate"
|
||||
- type: object
|
||||
required:
|
||||
- id
|
||||
properties:
|
||||
id:
|
||||
type: integer
|
||||
TradeBase:
|
||||
type: object
|
||||
required:
|
||||
- symbol
|
||||
- underlying_currency
|
||||
- trade_type
|
||||
- trade_strategy
|
||||
- trade_date
|
||||
- quantity
|
||||
- price_cents
|
||||
- commission_cents
|
||||
properties:
|
||||
friendly_name:
|
||||
type: string
|
||||
nullable: true
|
||||
symbol:
|
||||
type: string
|
||||
exchange_id:
|
||||
type: integer
|
||||
underlying_currency:
|
||||
type: string
|
||||
trade_type:
|
||||
type: string
|
||||
trade_strategy:
|
||||
type: string
|
||||
trade_date:
|
||||
type: string
|
||||
format: date
|
||||
quantity:
|
||||
type: integer
|
||||
price_cents:
|
||||
type: integer
|
||||
commission_cents:
|
||||
type: integer
|
||||
notes:
|
||||
type: string
|
||||
nullable: true
|
||||
cycle_id:
|
||||
type: integer
|
||||
nullable: true
|
||||
TradeCreate:
|
||||
allOf:
|
||||
- $ref: "#/components/schemas/TradeBase"
|
||||
- type: object
|
||||
properties:
|
||||
user_id:
|
||||
type: integer
|
||||
nullable: true
|
||||
trade_time_utc:
|
||||
type: string
|
||||
format: date-time
|
||||
nullable: true
|
||||
gross_cash_flow_cents:
|
||||
type: integer
|
||||
nullable: true
|
||||
net_cash_flow_cents:
|
||||
type: integer
|
||||
nullable: true
|
||||
quantity_multiplier:
|
||||
type: integer
|
||||
default: 1
|
||||
expiry_date:
|
||||
type: string
|
||||
format: date
|
||||
nullable: true
|
||||
strike_price_cents:
|
||||
type: integer
|
||||
nullable: true
|
||||
is_invalidated:
|
||||
type: boolean
|
||||
default: false
|
||||
invalidated_at:
|
||||
type: string
|
||||
format: date-time
|
||||
nullable: true
|
||||
replaced_by_trade_id:
|
||||
type: integer
|
||||
nullable: true
|
||||
TradeNoteUpdate:
|
||||
type: object
|
||||
required:
|
||||
- id
|
||||
properties:
|
||||
id:
|
||||
type: integer
|
||||
notes:
|
||||
type: string
|
||||
nullable: true
|
||||
TradeFriendlyNameUpdate:
|
||||
type: object
|
||||
required:
|
||||
- id
|
||||
- friendly_name
|
||||
properties:
|
||||
id:
|
||||
type: integer
|
||||
friendly_name:
|
||||
type: string
|
||||
TradeRead:
|
||||
allOf:
|
||||
- $ref: "#/components/schemas/TradeCreate"
|
||||
- type: object
|
||||
required:
|
||||
- id
|
||||
properties:
|
||||
id:
|
||||
type: integer
|
||||
paths:
|
||||
/status:
|
||||
get:
|
||||
summary: "Get API status"
|
||||
security: [] # no auth required
|
||||
responses:
|
||||
"200":
|
||||
description: OK
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
type: object
|
||||
properties:
|
||||
status:
|
||||
type: string
|
||||
/register:
|
||||
post:
|
||||
summary: "Register user"
|
||||
security: [] # no auth required
|
||||
requestBody:
|
||||
required: true
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/UserCreate"
|
||||
responses:
|
||||
"201":
|
||||
description: Created
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/UserRead"
|
||||
"400":
|
||||
description: Bad Request (user exists)
|
||||
"500":
|
||||
description: Internal Server Error
|
||||
/login:
|
||||
post:
|
||||
summary: "Login"
|
||||
security: [] # no auth required
|
||||
requestBody:
|
||||
required: true
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/UserLogin"
|
||||
responses:
|
||||
"200":
|
||||
description: OK (sets session cookie)
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/SessionsBase"
|
||||
headers:
|
||||
Set-Cookie:
|
||||
description: session cookie
|
||||
schema:
|
||||
type: string
|
||||
"401":
|
||||
description: Unauthorized
|
||||
"500":
|
||||
description: Internal Server Error
|
||||
/exchanges:
|
||||
post:
|
||||
summary: "Create exchange"
|
||||
security:
|
||||
- session_cookie: []
|
||||
requestBody:
|
||||
required: true
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/ExchangesBase"
|
||||
responses:
|
||||
"201":
|
||||
description: Created
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/ExchangesRead"
|
||||
"400":
|
||||
description: Bad Request
|
||||
"401":
|
||||
description: Unauthorized
|
||||
get:
|
||||
summary: "List user exchanges"
|
||||
security:
|
||||
- session_cookie: []
|
||||
responses:
|
||||
"200":
|
||||
description: OK
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
type: array
|
||||
items:
|
||||
$ref: "#/components/schemas/ExchangesRead"
|
||||
"401":
|
||||
description: Unauthorized
|
||||
/exchanges/{exchange_id}:
|
||||
patch:
|
||||
summary: "Update exchange"
|
||||
security:
|
||||
- session_cookie: []
|
||||
parameters:
|
||||
- name: exchange_id
|
||||
in: path
|
||||
required: true
|
||||
schema:
|
||||
type: integer
|
||||
requestBody:
|
||||
required: true
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/ExchangesBase"
|
||||
responses:
|
||||
"200":
|
||||
description: Updated
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/ExchangesRead"
|
||||
"404":
|
||||
description: Not found
|
||||
"400":
|
||||
description: Bad request
|
||||
/cycles:
|
||||
post:
|
||||
summary: "Create cycle (currently returns 405 in code)"
|
||||
security:
|
||||
- session_cookie: []
|
||||
requestBody:
|
||||
required: true
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/CycleBase"
|
||||
responses:
|
||||
"405":
|
||||
description: Method not allowed (app currently returns 405)
|
||||
patch:
|
||||
summary: "Update cycle"
|
||||
security:
|
||||
- session_cookie: []
|
||||
requestBody:
|
||||
required: true
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/CycleUpdate"
|
||||
responses:
|
||||
"200":
|
||||
description: Updated
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/CycleRead"
|
||||
"400":
|
||||
description: Invalid data
|
||||
"404":
|
||||
description: Not found
|
||||
/cycles/{cycle_id}:
|
||||
get:
|
||||
summary: "Get cycle by id"
|
||||
security:
|
||||
- session_cookie: []
|
||||
parameters:
|
||||
- name: cycle_id
|
||||
in: path
|
||||
required: true
|
||||
schema:
|
||||
type: integer
|
||||
responses:
|
||||
"200":
|
||||
description: OK
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/CycleRead"
|
||||
"404":
|
||||
description: Not found
|
||||
/cycles/user/{user_id}:
|
||||
get:
|
||||
summary: "Get cycles by user id"
|
||||
security:
|
||||
- session_cookie: []
|
||||
parameters:
|
||||
- name: user_id
|
||||
in: path
|
||||
required: true
|
||||
schema:
|
||||
type: integer
|
||||
responses:
|
||||
"200":
|
||||
description: OK
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
type: array
|
||||
items:
|
||||
$ref: "#/components/schemas/CycleRead"
|
||||
/trades:
|
||||
post:
|
||||
summary: "Create trade"
|
||||
security:
|
||||
- session_cookie: []
|
||||
requestBody:
|
||||
required: true
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/TradeCreate"
|
||||
responses:
|
||||
"201":
|
||||
description: Created
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/TradeRead"
|
||||
"400":
|
||||
description: Invalid trade data
|
||||
"500":
|
||||
description: Internal Server Error
|
||||
/trades/{trade_id}:
|
||||
get:
|
||||
summary: "Get trade by id"
|
||||
security:
|
||||
- session_cookie: []
|
||||
parameters:
|
||||
- name: trade_id
|
||||
in: path
|
||||
required: true
|
||||
schema:
|
||||
type: integer
|
||||
responses:
|
||||
"200":
|
||||
description: OK
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/TradeRead"
|
||||
"404":
|
||||
description: Not found
|
||||
/trades/friendlyname:
|
||||
patch:
|
||||
summary: "Update trade friendly name"
|
||||
security:
|
||||
- session_cookie: []
|
||||
requestBody:
|
||||
required: true
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/TradeFriendlyNameUpdate"
|
||||
responses:
|
||||
"200":
|
||||
description: Updated
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/TradeRead"
|
||||
"404":
|
||||
description: Not found
|
||||
/trades/notes:
|
||||
patch:
|
||||
summary: "Update trade notes"
|
||||
security:
|
||||
- session_cookie: []
|
||||
requestBody:
|
||||
required: true
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/TradeNoteUpdate"
|
||||
responses:
|
||||
"200":
|
||||
description: Updated
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/TradeRead"
|
||||
"404":
|
||||
description: Not found
|
||||
@@ -4,3 +4,4 @@ httpx
|
||||
pyyaml
|
||||
pydantic-settings
|
||||
sqlmodel
|
||||
argon2-cffi
|
||||
@@ -14,12 +14,130 @@ anyio==4.10.0 \
|
||||
# via
|
||||
# httpx
|
||||
# starlette
|
||||
argon2-cffi==25.1.0 \
|
||||
--hash=sha256:694ae5cc8a42f4c4e2bf2ca0e64e51e23a040c6a517a85074683d3959e1346c1 \
|
||||
--hash=sha256:fdc8b074db390fccb6eb4a3604ae7231f219aa669a2652e0f20e16ba513d5741
|
||||
# via -r requirements.in
|
||||
argon2-cffi-bindings==25.1.0 \
|
||||
--hash=sha256:1db89609c06afa1a214a69a462ea741cf735b29a57530478c06eb81dd403de99 \
|
||||
--hash=sha256:1e021e87faa76ae0d413b619fe2b65ab9a037f24c60a1e6cc43457ae20de6dc6 \
|
||||
--hash=sha256:21378b40e1b8d1655dd5310c84a40fc19a9aa5e6366e835ceb8576bf0fea716d \
|
||||
--hash=sha256:2630b6240b495dfab90aebe159ff784d08ea999aa4b0d17efa734055a07d2f44 \
|
||||
--hash=sha256:3c6702abc36bf3ccba3f802b799505def420a1b7039862014a65db3205967f5a \
|
||||
--hash=sha256:3d3f05610594151994ca9ccb3c771115bdb4daef161976a266f0dd8aa9996b8f \
|
||||
--hash=sha256:473bcb5f82924b1becbb637b63303ec8d10e84c8d241119419897a26116515d2 \
|
||||
--hash=sha256:5acb4e41090d53f17ca1110c3427f0a130f944b896fc8c83973219c97f57b690 \
|
||||
--hash=sha256:5d588dec224e2a83edbdc785a5e6f3c6cd736f46bfd4b441bbb5aa1f5085e584 \
|
||||
--hash=sha256:6dca33a9859abf613e22733131fc9194091c1fa7cb3e131c143056b4856aa47e \
|
||||
--hash=sha256:7aef0c91e2c0fbca6fc68e7555aa60ef7008a739cbe045541e438373bc54d2b0 \
|
||||
--hash=sha256:84a461d4d84ae1295871329b346a97f68eade8c53b6ed9a7ca2d7467f3c8ff6f \
|
||||
--hash=sha256:87c33a52407e4c41f3b70a9c2d3f6056d88b10dad7695be708c5021673f55623 \
|
||||
--hash=sha256:8b8efee945193e667a396cbc7b4fb7d357297d6234d30a489905d96caabde56b \
|
||||
--hash=sha256:a1c70058c6ab1e352304ac7e3b52554daadacd8d453c1752e547c76e9c99ac44 \
|
||||
--hash=sha256:a98cd7d17e9f7ce244c0803cad3c23a7d379c301ba618a5fa76a67d116618b98 \
|
||||
--hash=sha256:aecba1723ae35330a008418a91ea6cfcedf6d31e5fbaa056a166462ff066d500 \
|
||||
--hash=sha256:b0fdbcf513833809c882823f98dc2f931cf659d9a1429616ac3adebb49f5db94 \
|
||||
--hash=sha256:b55aec3565b65f56455eebc9b9f34130440404f27fe21c3b375bf1ea4d8fbae6 \
|
||||
--hash=sha256:b957f3e6ea4d55d820e40ff76f450952807013d361a65d7f28acc0acbf29229d \
|
||||
--hash=sha256:ba92837e4a9aa6a508c8d2d7883ed5a8f6c308c89a4790e1e447a220deb79a85 \
|
||||
--hash=sha256:c4f9665de60b1b0e99bcd6be4f17d90339698ce954cfd8d9cf4f91c995165a92 \
|
||||
--hash=sha256:c87b72589133f0346a1cb8d5ecca4b933e3c9b64656c9d175270a000e73b288d \
|
||||
--hash=sha256:d3e924cfc503018a714f94a49a149fdc0b644eaead5d1f089330399134fa028a \
|
||||
--hash=sha256:da0c79c23a63723aa5d782250fbf51b768abca630285262fb5144ba5ae01e520 \
|
||||
--hash=sha256:e2fd3bfbff3c5d74fef31a722f729bf93500910db650c925c2d6ef879a7e51cb
|
||||
# via argon2-cffi
|
||||
certifi==2025.8.3 \
|
||||
--hash=sha256:e564105f78ded564e3ae7c923924435e1daa7463faeab5bb932bc53ffae63407 \
|
||||
--hash=sha256:f6c12493cfb1b06ba2ff328595af9350c65d6644968e5d3a2ffd78699af217a5
|
||||
# via
|
||||
# httpcore
|
||||
# httpx
|
||||
cffi==2.0.0 \
|
||||
--hash=sha256:00bdf7acc5f795150faa6957054fbbca2439db2f775ce831222b66f192f03beb \
|
||||
--hash=sha256:07b271772c100085dd28b74fa0cd81c8fb1a3ba18b21e03d7c27f3436a10606b \
|
||||
--hash=sha256:087067fa8953339c723661eda6b54bc98c5625757ea62e95eb4898ad5e776e9f \
|
||||
--hash=sha256:0a1527a803f0a659de1af2e1fd700213caba79377e27e4693648c2923da066f9 \
|
||||
--hash=sha256:0cf2d91ecc3fcc0625c2c530fe004f82c110405f101548512cce44322fa8ac44 \
|
||||
--hash=sha256:0f6084a0ea23d05d20c3edcda20c3d006f9b6f3fefeac38f59262e10cef47ee2 \
|
||||
--hash=sha256:12873ca6cb9b0f0d3a0da705d6086fe911591737a59f28b7936bdfed27c0d47c \
|
||||
--hash=sha256:19f705ada2530c1167abacb171925dd886168931e0a7b78f5bffcae5c6b5be75 \
|
||||
--hash=sha256:1cd13c99ce269b3ed80b417dcd591415d3372bcac067009b6e0f59c7d4015e65 \
|
||||
--hash=sha256:1e3a615586f05fc4065a8b22b8152f0c1b00cdbc60596d187c2a74f9e3036e4e \
|
||||
--hash=sha256:1f72fb8906754ac8a2cc3f9f5aaa298070652a0ffae577e0ea9bd480dc3c931a \
|
||||
--hash=sha256:1fc9ea04857caf665289b7a75923f2c6ed559b8298a1b8c49e59f7dd95c8481e \
|
||||
--hash=sha256:203a48d1fb583fc7d78a4c6655692963b860a417c0528492a6bc21f1aaefab25 \
|
||||
--hash=sha256:2081580ebb843f759b9f617314a24ed5738c51d2aee65d31e02f6f7a2b97707a \
|
||||
--hash=sha256:21d1152871b019407d8ac3985f6775c079416c282e431a4da6afe7aefd2bccbe \
|
||||
--hash=sha256:24b6f81f1983e6df8db3adc38562c83f7d4a0c36162885ec7f7b77c7dcbec97b \
|
||||
--hash=sha256:256f80b80ca3853f90c21b23ee78cd008713787b1b1e93eae9f3d6a7134abd91 \
|
||||
--hash=sha256:28a3a209b96630bca57cce802da70c266eb08c6e97e5afd61a75611ee6c64592 \
|
||||
--hash=sha256:2c8f814d84194c9ea681642fd164267891702542f028a15fc97d4674b6206187 \
|
||||
--hash=sha256:2de9a304e27f7596cd03d16f1b7c72219bd944e99cc52b84d0145aefb07cbd3c \
|
||||
--hash=sha256:38100abb9d1b1435bc4cc340bb4489635dc2f0da7456590877030c9b3d40b0c1 \
|
||||
--hash=sha256:3925dd22fa2b7699ed2617149842d2e6adde22b262fcbfada50e3d195e4b3a94 \
|
||||
--hash=sha256:3e17ed538242334bf70832644a32a7aae3d83b57567f9fd60a26257e992b79ba \
|
||||
--hash=sha256:3e837e369566884707ddaf85fc1744b47575005c0a229de3327f8f9a20f4efeb \
|
||||
--hash=sha256:3f4d46d8b35698056ec29bca21546e1551a205058ae1a181d871e278b0b28165 \
|
||||
--hash=sha256:44d1b5909021139fe36001ae048dbdde8214afa20200eda0f64c068cac5d5529 \
|
||||
--hash=sha256:45d5e886156860dc35862657e1494b9bae8dfa63bf56796f2fb56e1679fc0bca \
|
||||
--hash=sha256:4647afc2f90d1ddd33441e5b0e85b16b12ddec4fca55f0d9671fef036ecca27c \
|
||||
--hash=sha256:4671d9dd5ec934cb9a73e7ee9676f9362aba54f7f34910956b84d727b0d73fb6 \
|
||||
--hash=sha256:53f77cbe57044e88bbd5ed26ac1d0514d2acf0591dd6bb02a3ae37f76811b80c \
|
||||
--hash=sha256:5eda85d6d1879e692d546a078b44251cdd08dd1cfb98dfb77b670c97cee49ea0 \
|
||||
--hash=sha256:5fed36fccc0612a53f1d4d9a816b50a36702c28a2aa880cb8a122b3466638743 \
|
||||
--hash=sha256:61d028e90346df14fedc3d1e5441df818d095f3b87d286825dfcbd6459b7ef63 \
|
||||
--hash=sha256:66f011380d0e49ed280c789fbd08ff0d40968ee7b665575489afa95c98196ab5 \
|
||||
--hash=sha256:6824f87845e3396029f3820c206e459ccc91760e8fa24422f8b0c3d1731cbec5 \
|
||||
--hash=sha256:6c6c373cfc5c83a975506110d17457138c8c63016b563cc9ed6e056a82f13ce4 \
|
||||
--hash=sha256:6d02d6655b0e54f54c4ef0b94eb6be0607b70853c45ce98bd278dc7de718be5d \
|
||||
--hash=sha256:6d50360be4546678fc1b79ffe7a66265e28667840010348dd69a314145807a1b \
|
||||
--hash=sha256:730cacb21e1bdff3ce90babf007d0a0917cc3e6492f336c2f0134101e0944f93 \
|
||||
--hash=sha256:737fe7d37e1a1bffe70bd5754ea763a62a066dc5913ca57e957824b72a85e205 \
|
||||
--hash=sha256:74a03b9698e198d47562765773b4a8309919089150a0bb17d829ad7b44b60d27 \
|
||||
--hash=sha256:7553fb2090d71822f02c629afe6042c299edf91ba1bf94951165613553984512 \
|
||||
--hash=sha256:7a66c7204d8869299919db4d5069a82f1561581af12b11b3c9f48c584eb8743d \
|
||||
--hash=sha256:7cc09976e8b56f8cebd752f7113ad07752461f48a58cbba644139015ac24954c \
|
||||
--hash=sha256:81afed14892743bbe14dacb9e36d9e0e504cd204e0b165062c488942b9718037 \
|
||||
--hash=sha256:8941aaadaf67246224cee8c3803777eed332a19d909b47e29c9842ef1e79ac26 \
|
||||
--hash=sha256:89472c9762729b5ae1ad974b777416bfda4ac5642423fa93bd57a09204712322 \
|
||||
--hash=sha256:8ea985900c5c95ce9db1745f7933eeef5d314f0565b27625d9a10ec9881e1bfb \
|
||||
--hash=sha256:8eca2a813c1cb7ad4fb74d368c2ffbbb4789d377ee5bb8df98373c2cc0dee76c \
|
||||
--hash=sha256:92b68146a71df78564e4ef48af17551a5ddd142e5190cdf2c5624d0c3ff5b2e8 \
|
||||
--hash=sha256:9332088d75dc3241c702d852d4671613136d90fa6881da7d770a483fd05248b4 \
|
||||
--hash=sha256:94698a9c5f91f9d138526b48fe26a199609544591f859c870d477351dc7b2414 \
|
||||
--hash=sha256:9a67fc9e8eb39039280526379fb3a70023d77caec1852002b4da7e8b270c4dd9 \
|
||||
--hash=sha256:9de40a7b0323d889cf8d23d1ef214f565ab154443c42737dfe52ff82cf857664 \
|
||||
--hash=sha256:a05d0c237b3349096d3981b727493e22147f934b20f6f125a3eba8f994bec4a9 \
|
||||
--hash=sha256:afb8db5439b81cf9c9d0c80404b60c3cc9c3add93e114dcae767f1477cb53775 \
|
||||
--hash=sha256:b18a3ed7d5b3bd8d9ef7a8cb226502c6bf8308df1525e1cc676c3680e7176739 \
|
||||
--hash=sha256:b1e74d11748e7e98e2f426ab176d4ed720a64412b6a15054378afdb71e0f37dc \
|
||||
--hash=sha256:b21e08af67b8a103c71a250401c78d5e0893beff75e28c53c98f4de42f774062 \
|
||||
--hash=sha256:b4c854ef3adc177950a8dfc81a86f5115d2abd545751a304c5bcf2c2c7283cfe \
|
||||
--hash=sha256:b882b3df248017dba09d6b16defe9b5c407fe32fc7c65a9c69798e6175601be9 \
|
||||
--hash=sha256:baf5215e0ab74c16e2dd324e8ec067ef59e41125d3eade2b863d294fd5035c92 \
|
||||
--hash=sha256:c649e3a33450ec82378822b3dad03cc228b8f5963c0c12fc3b1e0ab940f768a5 \
|
||||
--hash=sha256:c654de545946e0db659b3400168c9ad31b5d29593291482c43e3564effbcee13 \
|
||||
--hash=sha256:c6638687455baf640e37344fe26d37c404db8b80d037c3d29f58fe8d1c3b194d \
|
||||
--hash=sha256:c8d3b5532fc71b7a77c09192b4a5a200ea992702734a2e9279a37f2478236f26 \
|
||||
--hash=sha256:cb527a79772e5ef98fb1d700678fe031e353e765d1ca2d409c92263c6d43e09f \
|
||||
--hash=sha256:cf364028c016c03078a23b503f02058f1814320a56ad535686f90565636a9495 \
|
||||
--hash=sha256:d48a880098c96020b02d5a1f7d9251308510ce8858940e6fa99ece33f610838b \
|
||||
--hash=sha256:d68b6cef7827e8641e8ef16f4494edda8b36104d79773a334beaa1e3521430f6 \
|
||||
--hash=sha256:d9b29c1f0ae438d5ee9acb31cadee00a58c46cc9c0b2f9038c6b0b3470877a8c \
|
||||
--hash=sha256:d9b97165e8aed9272a6bb17c01e3cc5871a594a446ebedc996e2397a1c1ea8ef \
|
||||
--hash=sha256:da68248800ad6320861f129cd9c1bf96ca849a2771a59e0344e88681905916f5 \
|
||||
--hash=sha256:da902562c3e9c550df360bfa53c035b2f241fed6d9aef119048073680ace4a18 \
|
||||
--hash=sha256:dbd5c7a25a7cb98f5ca55d258b103a2054f859a46ae11aaf23134f9cc0d356ad \
|
||||
--hash=sha256:dd4f05f54a52fb558f1ba9f528228066954fee3ebe629fc1660d874d040ae5a3 \
|
||||
--hash=sha256:de8dad4425a6ca6e4e5e297b27b5c824ecc7581910bf9aee86cb6835e6812aa7 \
|
||||
--hash=sha256:e11e82b744887154b182fd3e7e8512418446501191994dbf9c9fc1f32cc8efd5 \
|
||||
--hash=sha256:e6e73b9e02893c764e7e8d5bb5ce277f1a009cd5243f8228f75f842bf937c534 \
|
||||
--hash=sha256:f73b96c41e3b2adedc34a7356e64c8eb96e03a3782b535e043a986276ce12a49 \
|
||||
--hash=sha256:f93fd8e5c8c0a4aa1f424d6173f14a892044054871c771f8566e4008eaa359d2 \
|
||||
--hash=sha256:fc33c5141b55ed366cfaad382df24fe7dcbc686de5be719b207bb248e3053dc5 \
|
||||
--hash=sha256:fc7de24befaeae77ba923797c7c87834c73648a05a4bde34b3b7e5588973a453 \
|
||||
--hash=sha256:fe562eb1a64e67dd297ccc4f5addea2501664954f2692b69a76449ec7913ecbf
|
||||
# via argon2-cffi-bindings
|
||||
click==8.2.1 \
|
||||
--hash=sha256:27c491cc05d968d271d5a1db13e3b5a184636d9d930f148c50b038f0d0646202 \
|
||||
--hash=sha256:61a3265b914e850b85317d0b3109c7f8cd35a670f963866005d6ef1d5175a12b
|
||||
@@ -104,6 +222,10 @@ idna==3.10 \
|
||||
# via
|
||||
# anyio
|
||||
# httpx
|
||||
pycparser==2.23 \
|
||||
--hash=sha256:78816d4f24add8f10a06d6f05b4d424ad9e96cfebf68a4ddc99c65c0720d00c2 \
|
||||
--hash=sha256:e5c6e8d3fbad53479cab09ac03729e0a9faf2bee3db8208a550daf5af81a5934
|
||||
# via cffi
|
||||
pydantic==2.11.7 \
|
||||
--hash=sha256:d989c3c6cb79469287b1569f7447a17848c998458d49ebe294e975b9baf0f0db \
|
||||
--hash=sha256:dde5df002701f6de26248661f6835bbe296a47bf73990135c7d07ce741b9623b
|
||||
|
||||
@@ -13,8 +13,14 @@ ignore = [
|
||||
"TRY003",
|
||||
"EM101",
|
||||
"EM102",
|
||||
"PLC0405",
|
||||
"SIM108",
|
||||
"C901",
|
||||
"PLR0912",
|
||||
"PLR0915",
|
||||
"PLR0913",
|
||||
"PLC0415",
|
||||
]
|
||||
|
||||
[lint.extend-per-file-ignores]
|
||||
"test*.py" = ["S101"]
|
||||
"test*.py" = ["S101", "S105", "S106", "PT011", "PLR2004"]
|
||||
"models*.py" = ["FA102"]
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
@@ -12,6 +14,10 @@ class Settings(BaseSettings):
|
||||
port: int = 8000
|
||||
workers: int = 1
|
||||
log_level: str = "info"
|
||||
database_url: str = "sqlite:///:memory:"
|
||||
api_base: str = "/api/v1"
|
||||
session_expiry_seconds: int = 3600 * 24 * 7 # 7 days
|
||||
hmac_key: str | None = None
|
||||
|
||||
model_config = ConfigDict(env_file=".env", env_file_encoding="utf-8")
|
||||
|
||||
|
||||
56
backend/testhelpers/tradecycles.sh
Executable file
56
backend/testhelpers/tradecycles.sh
Executable file
@@ -0,0 +1,56 @@
|
||||
curl --location '127.0.0.1:18881/api/v1/trades' \
|
||||
--header 'Content-Type: application/json' \
|
||||
--header 'Cookie: session_token=uYsEZZdH9ecQ432HQUdfab292I14suk4GuI12-cAyuw' \
|
||||
--data '{
|
||||
"friendly_name": "20250908-CA-PUT",
|
||||
"symbol": "CA",
|
||||
"exchange_id": 1,
|
||||
"underlying_currency": "EUR",
|
||||
"trade_type": "SELL_PUT",
|
||||
"trade_strategy": "WHEEL",
|
||||
"trade_date": "2025-09-08",
|
||||
"quantity": 1,
|
||||
"quantity_multiplier": 100,
|
||||
"price_cents": 17,
|
||||
"expiry_date": "2025-09-09",
|
||||
"strike_price_cents": 1220,
|
||||
"commission_cents": 114
|
||||
}'
|
||||
|
||||
curl --location '127.0.0.1:18881/api/v1/trades' \
|
||||
--header 'Content-Type: application/json' \
|
||||
--header 'Cookie: session_token=uYsEZZdH9ecQ432HQUdfab292I14suk4GuI12-cAyuw' \
|
||||
--data '{
|
||||
"friendly_name": "20250920-CA-ASSIGN",
|
||||
"symbol": "CA",
|
||||
"exchange_id": 1,
|
||||
"cycle_id": 1,
|
||||
"underlying_currency": "EUR",
|
||||
"trade_type": "ASSIGNMENT",
|
||||
"trade_strategy": "WHEEL",
|
||||
"trade_date": "2025-09-20",
|
||||
"quantity": 100,
|
||||
"quantity_multiplier": 1,
|
||||
"price_cents": 1220,
|
||||
"commission_cents": 0
|
||||
}'
|
||||
|
||||
curl --location '127.0.0.1:18881/api/v1/trades' \
|
||||
--header 'Content-Type: application/json' \
|
||||
--header 'Cookie: session_token=uYsEZZdH9ecQ432HQUdfab292I14suk4GuI12-cAyuw' \
|
||||
--data '{
|
||||
"friendly_name": "20250923-CA-CALL",
|
||||
"symbol": "CA",
|
||||
"exchange_id": 1,
|
||||
"cycle_id": 1,
|
||||
"underlying_currency": "EUR",
|
||||
"trade_type": "SELL_CALL",
|
||||
"trade_strategy": "WHEEL",
|
||||
"trade_date": "2025-09-23",
|
||||
"quantity": 1,
|
||||
"quantity_multiplier": 100,
|
||||
"price_cents": 31,
|
||||
"expiry_date": "2025-10-10",
|
||||
"strike_price_cents": 1200,
|
||||
"commission_cents": 114
|
||||
}'
|
||||
405
backend/tests/test_app.py
Normal file
405
backend/tests/test_app.py
Normal file
@@ -0,0 +1,405 @@
|
||||
from collections.abc import Callable
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from fastapi import FastAPI, status
|
||||
from fastapi.responses import JSONResponse
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
import settings
|
||||
import trading_journal.service as svc
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client_factory(monkeypatch: pytest.MonkeyPatch) -> Callable[..., TestClient]:
|
||||
class NoAuth:
|
||||
def __init__(self, app: FastAPI, **opts) -> None: # noqa: ANN003, ARG002
|
||||
self.app = app
|
||||
|
||||
async def __call__(self, scope, receive, send) -> None: # noqa: ANN001
|
||||
state = scope.get("state")
|
||||
if state is None:
|
||||
scope["state"] = SimpleNamespace()
|
||||
scope["state"]["user_id"] = 1
|
||||
await self.app(scope, receive, send)
|
||||
|
||||
class DeclineAuth:
|
||||
def __init__(self, app: FastAPI, **opts) -> None: # noqa: ANN003, ARG002
|
||||
self.app = app
|
||||
|
||||
async def __call__(self, scope, receive, send) -> None: # noqa: ANN001
|
||||
if scope.get("type") != "http":
|
||||
await self.app(scope, receive, send)
|
||||
return
|
||||
path = scope.get("path", "")
|
||||
# allow public/exempt paths through
|
||||
if getattr(svc, "EXCEPT_PATHS", []) and path in svc.EXCEPT_PATHS:
|
||||
await self.app(scope, receive, send)
|
||||
return
|
||||
# immediately respond 401 for protected paths
|
||||
resp = JSONResponse({"detail": "Unauthorized"}, status_code=status.HTTP_401_UNAUTHORIZED)
|
||||
await resp(scope, receive, send)
|
||||
|
||||
def _factory(*, decline_auth: bool = False, **mocks: dict) -> TestClient:
|
||||
defaults = {
|
||||
"register_user_service": MagicMock(return_value=SimpleNamespace(model_dump=lambda: {"id": 1, "username": "mock"})),
|
||||
"authenticate_user_service": MagicMock(
|
||||
return_value=(SimpleNamespace(user_id=1, expires_at=(datetime.now(timezone.utc) + timedelta(hours=1))), "token"),
|
||||
),
|
||||
"create_exchange_service": MagicMock(
|
||||
return_value=SimpleNamespace(model_dump=lambda: {"name": "Binance", "notes": "some note", "user_id": 1}),
|
||||
),
|
||||
"get_exchanges_by_user_service": MagicMock(return_value=[]),
|
||||
}
|
||||
|
||||
if decline_auth:
|
||||
monkeypatch.setattr(svc, "AuthMiddleWare", DeclineAuth)
|
||||
else:
|
||||
monkeypatch.setattr(svc, "AuthMiddleWare", NoAuth)
|
||||
merged = {**defaults, **mocks}
|
||||
for name, mock in merged.items():
|
||||
monkeypatch.setattr(svc, name, mock)
|
||||
import sys
|
||||
|
||||
if "app" in sys.modules:
|
||||
del sys.modules["app"]
|
||||
from importlib import import_module
|
||||
|
||||
app = import_module("app").app # re-import app module
|
||||
|
||||
return TestClient(app)
|
||||
|
||||
return _factory
|
||||
|
||||
|
||||
def test_get_status(client_factory: Callable[..., TestClient]) -> None:
|
||||
client = client_factory()
|
||||
with client as c:
|
||||
response = c.get(f"{settings.settings.api_base}/status")
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"status": "ok"}
|
||||
|
||||
|
||||
def test_register_success(client_factory: Callable[..., TestClient]) -> None:
|
||||
client = client_factory() # use defaults
|
||||
with client as c:
|
||||
r = c.post(f"{settings.settings.api_base}/register", json={"username": "a", "password": "b"})
|
||||
assert r.status_code == 201
|
||||
|
||||
|
||||
def test_register_user_already_exists(client_factory: Callable[..., TestClient]) -> None:
|
||||
client = client_factory(register_user_service=MagicMock(side_effect=svc.UserAlreadyExistsError("username already exists")))
|
||||
with client as c:
|
||||
r = c.post(f"{settings.settings.api_base}/register", json={"username": "a", "password": "b"})
|
||||
assert r.status_code == status.HTTP_400_BAD_REQUEST
|
||||
assert r.json() == {"detail": "username already exists"}
|
||||
|
||||
|
||||
def test_register_user_internal_server_error(client_factory: Callable[..., TestClient]) -> None:
|
||||
client = client_factory(register_user_service=MagicMock(side_effect=Exception("db is down")))
|
||||
with client as c:
|
||||
r = c.post(f"{settings.settings.api_base}/register", json={"username": "a", "password": "b"})
|
||||
assert r.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
|
||||
assert r.json() == {"detail": "Internal Server Error"}
|
||||
|
||||
|
||||
def test_login_success(client_factory: Callable[..., TestClient]) -> None:
|
||||
client = client_factory() # use defaults
|
||||
with client as c:
|
||||
r = c.post(f"{settings.settings.api_base}/login", json={"username": "a", "password": "b"})
|
||||
assert r.status_code == 200
|
||||
assert r.json() == {"user_id": 1}
|
||||
assert r.cookies.get("session_token") == "token"
|
||||
|
||||
|
||||
def test_login_failed_auth(client_factory: Callable[..., TestClient]) -> None:
|
||||
client = client_factory(authenticate_user_service=MagicMock(return_value=None))
|
||||
with client as c:
|
||||
r = c.post(f"{settings.settings.api_base}/login", json={"username": "a", "password": "b"})
|
||||
assert r.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
assert r.json() == {"detail": "Invalid username or password, or user doesn't exist"}
|
||||
|
||||
|
||||
def test_login_internal_server_error(client_factory: Callable[..., TestClient]) -> None:
|
||||
client = client_factory(authenticate_user_service=MagicMock(side_effect=Exception("db is down")))
|
||||
with client as c:
|
||||
r = c.post(f"{settings.settings.api_base}/login", json={"username": "a", "password": "b"})
|
||||
assert r.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
|
||||
assert r.json() == {"detail": "Internal Server Error"}
|
||||
|
||||
|
||||
def test_create_exchange_success(client_factory: Callable[..., TestClient]) -> None:
|
||||
client = client_factory()
|
||||
with client as c:
|
||||
r = c.post(f"{settings.settings.api_base}/exchanges", json={"name": "Binance"})
|
||||
assert r.status_code == 201
|
||||
assert r.json() == {"user_id": 1, "name": "Binance", "notes": "some note"}
|
||||
|
||||
|
||||
def test_create_exchange_already_exists(client_factory: Callable[..., TestClient]) -> None:
|
||||
client = client_factory(create_exchange_service=MagicMock(side_effect=svc.ExchangeAlreadyExistsError("exchange already exists")))
|
||||
with client as c:
|
||||
r = c.post(f"{settings.settings.api_base}/exchanges", json={"name": "Binance"})
|
||||
assert r.status_code == status.HTTP_400_BAD_REQUEST
|
||||
assert r.json() == {"detail": "exchange already exists"}
|
||||
|
||||
|
||||
def test_get_exchanges_unauthenticated(client_factory: Callable[..., TestClient]) -> None:
|
||||
client = client_factory(decline_auth=True)
|
||||
with client as c:
|
||||
r = c.get(f"{settings.settings.api_base}/exchanges")
|
||||
assert r.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
assert r.json() == {"detail": "Unauthorized"}
|
||||
|
||||
|
||||
def test_get_exchanges_success(client_factory: Callable[..., TestClient]) -> None:
|
||||
client = client_factory()
|
||||
with client as c:
|
||||
r = c.get(f"{settings.settings.api_base}/exchanges")
|
||||
assert r.status_code == 200
|
||||
assert r.json() == []
|
||||
|
||||
|
||||
def test_update_exchanges_success(client_factory: Callable[..., TestClient]) -> None:
|
||||
client = client_factory(
|
||||
update_exchanges_service=MagicMock(
|
||||
return_value=SimpleNamespace(model_dump=lambda: {"user_id": 1, "name": "BinanceUS", "notes": "updated note"}),
|
||||
),
|
||||
)
|
||||
with client as c:
|
||||
r = c.patch(f"{settings.settings.api_base}/exchanges/1", json={"name": "BinanceUS", "notes": "updated note"})
|
||||
assert r.status_code == 200
|
||||
assert r.json() == {"user_id": 1, "name": "BinanceUS", "notes": "updated note"}
|
||||
|
||||
|
||||
def test_update_exchanges_not_found(client_factory: Callable[..., TestClient]) -> None:
|
||||
client = client_factory(update_exchanges_service=MagicMock(side_effect=svc.ExchangeNotFoundError("exchange not found")))
|
||||
with client as c:
|
||||
r = c.patch(f"{settings.settings.api_base}/exchanges/999", json={"name": "NonExistent", "notes": "no note"})
|
||||
assert r.status_code == status.HTTP_404_NOT_FOUND
|
||||
assert r.json() == {"detail": "exchange not found"}
|
||||
|
||||
|
||||
def test_get_cycles_by_id_success(client_factory: Callable[..., TestClient]) -> None:
|
||||
client = client_factory(
|
||||
get_cycle_by_id_service=MagicMock(
|
||||
return_value=SimpleNamespace(
|
||||
friendly_name="Cycle 1",
|
||||
status="active",
|
||||
id=1,
|
||||
),
|
||||
),
|
||||
)
|
||||
with client as c:
|
||||
r = c.get(f"{settings.settings.api_base}/cycles/1")
|
||||
assert r.status_code == 200
|
||||
assert r.json() == {"id": 1, "friendly_name": "Cycle 1", "status": "active"}
|
||||
|
||||
|
||||
def test_get_cycles_by_id_not_found(client_factory: Callable[..., TestClient]) -> None:
|
||||
client = client_factory(get_cycle_by_id_service=MagicMock(side_effect=svc.CycleNotFoundError("cycle not found")))
|
||||
with client as c:
|
||||
r = c.get(f"{settings.settings.api_base}/cycles/999")
|
||||
assert r.status_code == status.HTTP_404_NOT_FOUND
|
||||
assert r.json() == {"detail": "cycle not found"}
|
||||
|
||||
|
||||
def test_get_cycles_by_user_success(client_factory: Callable[..., TestClient]) -> None:
|
||||
client = client_factory(
|
||||
get_cycles_by_user_service=MagicMock(
|
||||
return_value=[
|
||||
SimpleNamespace(
|
||||
friendly_name="Cycle 1",
|
||||
status="active",
|
||||
id=1,
|
||||
),
|
||||
SimpleNamespace(
|
||||
friendly_name="Cycle 2",
|
||||
status="completed",
|
||||
id=2,
|
||||
),
|
||||
],
|
||||
),
|
||||
)
|
||||
with client as c:
|
||||
r = c.get(f"{settings.settings.api_base}/cycles/user/1")
|
||||
assert r.status_code == 200
|
||||
assert r.json() == [
|
||||
{"id": 1, "friendly_name": "Cycle 1", "status": "active"},
|
||||
{"id": 2, "friendly_name": "Cycle 2", "status": "completed"},
|
||||
]
|
||||
|
||||
|
||||
def test_update_cycles_success(client_factory: Callable[..., TestClient]) -> None:
|
||||
client = client_factory(
|
||||
update_cycle_service=MagicMock(
|
||||
return_value=SimpleNamespace(
|
||||
friendly_name="Updated Cycle",
|
||||
status="completed",
|
||||
id=1,
|
||||
),
|
||||
),
|
||||
)
|
||||
with client as c:
|
||||
r = c.patch(f"{settings.settings.api_base}/cycles", json={"friendly_name": "Updated Cycle", "status": "completed", "id": 1})
|
||||
assert r.status_code == 200
|
||||
assert r.json() == {"id": 1, "friendly_name": "Updated Cycle", "status": "completed"}
|
||||
|
||||
|
||||
def test_update_cycles_invalid_cycle_data(client_factory: Callable[..., TestClient]) -> None:
|
||||
client = client_factory(
|
||||
update_cycle_service=MagicMock(side_effect=svc.InvalidCycleDataError("invalid cycle data")),
|
||||
)
|
||||
with client as c:
|
||||
r = c.patch(f"{settings.settings.api_base}/cycles", json={"friendly_name": "", "status": "unknown", "id": 1})
|
||||
assert r.status_code == status.HTTP_400_BAD_REQUEST
|
||||
assert r.json() == {"detail": "invalid cycle data"}
|
||||
|
||||
|
||||
def test_update_cycles_not_found(client_factory: Callable[..., TestClient]) -> None:
|
||||
client = client_factory(update_cycle_service=MagicMock(side_effect=svc.CycleNotFoundError("cycle not found")))
|
||||
with client as c:
|
||||
r = c.patch(f"{settings.settings.api_base}/cycles", json={"friendly_name": "NonExistent", "status": "active", "id": 999})
|
||||
assert r.status_code == status.HTTP_404_NOT_FOUND
|
||||
assert r.json() == {"detail": "cycle not found"}
|
||||
|
||||
|
||||
def test_create_trade_success(client_factory: Callable[..., TestClient]) -> None:
|
||||
client = client_factory(
|
||||
create_trade_service=MagicMock(
|
||||
return_value=SimpleNamespace(),
|
||||
),
|
||||
)
|
||||
with client as c:
|
||||
r = c.post(
|
||||
f"{settings.settings.api_base}/trades",
|
||||
json={
|
||||
"cycle_id": 1,
|
||||
"exchange_id": 1,
|
||||
"symbol": "BTCUSD",
|
||||
"underlying_currency": "USD",
|
||||
"trade_type": "LONG_SPOT",
|
||||
"trade_strategy": "FX",
|
||||
"quantity": 1,
|
||||
"price_cents": 15,
|
||||
"commission_cents": 100,
|
||||
"trade_date": "2025-10-01",
|
||||
},
|
||||
)
|
||||
assert r.status_code == 201
|
||||
|
||||
|
||||
def test_create_trade_invalid_trade_data(client_factory: Callable[..., TestClient]) -> None:
|
||||
client = client_factory(
|
||||
create_trade_service=MagicMock(side_effect=svc.InvalidTradeDataError("invalid trade data")),
|
||||
)
|
||||
with client as c:
|
||||
r = c.post(
|
||||
f"{settings.settings.api_base}/trades",
|
||||
json={
|
||||
"cycle_id": 1,
|
||||
"exchange_id": 1,
|
||||
"symbol": "BTCUSD",
|
||||
"underlying_currency": "USD",
|
||||
"trade_type": "LONG_SPOT",
|
||||
"trade_strategy": "FX",
|
||||
"quantity": 1,
|
||||
"price_cents": 15,
|
||||
"commission_cents": 100,
|
||||
"trade_date": "2025-10-01",
|
||||
},
|
||||
)
|
||||
assert r.status_code == status.HTTP_400_BAD_REQUEST
|
||||
assert r.json() == {"detail": "invalid trade data"}
|
||||
|
||||
|
||||
def test_get_trade_by_id_success(client_factory: Callable[..., TestClient]) -> None:
|
||||
client = client_factory(
|
||||
get_trade_by_id_service=MagicMock(
|
||||
return_value=SimpleNamespace(
|
||||
id=1,
|
||||
cycle_id=1,
|
||||
exchange_id=1,
|
||||
symbol="BTCUSD",
|
||||
underlying_currency="USD",
|
||||
trade_type="LONG_SPOT",
|
||||
trade_strategy="FX",
|
||||
quantity=1,
|
||||
price_cents=1500,
|
||||
commission_cents=100,
|
||||
trade_date=datetime(2025, 10, 1, tzinfo=timezone.utc),
|
||||
),
|
||||
),
|
||||
)
|
||||
with client as c:
|
||||
r = c.get(f"{settings.settings.api_base}/trades/1")
|
||||
assert r.status_code == 200
|
||||
assert r.json() == {
|
||||
"id": 1,
|
||||
"cycle_id": 1,
|
||||
"exchange_id": 1,
|
||||
"symbol": "BTCUSD",
|
||||
"underlying_currency": "USD",
|
||||
"trade_type": "LONG_SPOT",
|
||||
"trade_strategy": "FX",
|
||||
"quantity": 1,
|
||||
"price_cents": 1500,
|
||||
"commission_cents": 100,
|
||||
"trade_date": "2025-10-01T00:00:00+00:00",
|
||||
}
|
||||
|
||||
|
||||
def test_get_trade_by_id_not_found(client_factory: Callable[..., TestClient]) -> None:
|
||||
client = client_factory(get_trade_by_id_service=MagicMock(side_effect=svc.TradeNotFoundError("trade not found")))
|
||||
with client as c:
|
||||
r = c.get(f"{settings.settings.api_base}/trades/999")
|
||||
assert r.status_code == status.HTTP_404_NOT_FOUND
|
||||
assert r.json() == {"detail": "trade not found"}
|
||||
|
||||
|
||||
def test_update_trade_friendly_name_success(client_factory: Callable[..., TestClient]) -> None:
|
||||
client = client_factory(
|
||||
update_trade_friendly_name_service=MagicMock(
|
||||
return_value=SimpleNamespace(
|
||||
id=1,
|
||||
friendly_name="Updated Trade Name",
|
||||
),
|
||||
),
|
||||
)
|
||||
with client as c:
|
||||
r = c.patch(f"{settings.settings.api_base}/trades/friendlyname", json={"id": 1, "friendly_name": "Updated Trade Name"})
|
||||
assert r.status_code == 200
|
||||
assert r.json() == {"id": 1, "friendly_name": "Updated Trade Name"}
|
||||
|
||||
|
||||
def test_update_trade_friendly_name_not_found(client_factory: Callable[..., TestClient]) -> None:
|
||||
client = client_factory(update_trade_friendly_name_service=MagicMock(side_effect=svc.TradeNotFoundError("trade not found")))
|
||||
with client as c:
|
||||
r = c.patch(f"{settings.settings.api_base}/trades/friendlyname", json={"id": 999, "friendly_name": "NonExistent Trade"})
|
||||
assert r.status_code == status.HTTP_404_NOT_FOUND
|
||||
assert r.json() == {"detail": "trade not found"}
|
||||
|
||||
|
||||
def test_update_trade_note_success(client_factory: Callable[..., TestClient]) -> None:
|
||||
client = client_factory(
|
||||
update_trade_note_service=MagicMock(
|
||||
return_value=SimpleNamespace(
|
||||
id=1,
|
||||
note="Updated trade note",
|
||||
),
|
||||
),
|
||||
)
|
||||
with client as c:
|
||||
r = c.patch(f"{settings.settings.api_base}/trades/notes", json={"id": 1, "note": "Updated trade note"})
|
||||
assert r.status_code == 200
|
||||
assert r.json() == {"id": 1, "note": "Updated trade note"}
|
||||
|
||||
|
||||
def test_update_trade_note_not_found(client_factory: Callable[..., TestClient]) -> None:
|
||||
client = client_factory(update_trade_note_service=MagicMock(side_effect=svc.TradeNotFoundError("trade not found")))
|
||||
with client as c:
|
||||
r = c.patch(f"{settings.settings.api_base}/trades/notes", json={"id": 999, "note": "NonExistent Trade Note"})
|
||||
assert r.status_code == status.HTTP_404_NOT_FOUND
|
||||
assert r.json() == {"detail": "trade not found"}
|
||||
@@ -1,15 +1,19 @@
|
||||
from collections.abc import Generator
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import TYPE_CHECKING, cast
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.engine import Engine
|
||||
from sqlalchemy.pool import StaticPool
|
||||
from sqlmodel import Session, SQLModel
|
||||
|
||||
from trading_journal import crud, models
|
||||
|
||||
# TODO: If needed, add failing flow tests, but now only add happy flow.
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Generator
|
||||
|
||||
from sqlalchemy.engine import Engine
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -29,8 +33,11 @@ def engine() -> Generator[Engine, None, None]:
|
||||
|
||||
@pytest.fixture
|
||||
def session(engine: Engine) -> Generator[Session, None, None]:
|
||||
with Session(engine) as s:
|
||||
yield s
|
||||
session = Session(engine)
|
||||
try:
|
||||
yield session
|
||||
finally:
|
||||
session.close()
|
||||
|
||||
|
||||
def make_user(session: Session, username: str = "testuser") -> int:
|
||||
@@ -38,38 +45,47 @@ def make_user(session: Session, username: str = "testuser") -> int:
|
||||
session.add(user)
|
||||
session.commit()
|
||||
session.refresh(user)
|
||||
return user.id
|
||||
return cast("int", user.id)
|
||||
|
||||
|
||||
def make_cycle(
|
||||
session: Session, user_id: int, friendly_name: str = "Test Cycle"
|
||||
) -> int:
|
||||
def make_exchange(session: Session, user_id: int, name: str = "NASDAQ") -> int:
|
||||
exchange = models.Exchanges(user_id=user_id, name=name, notes="Test exchange")
|
||||
session.add(exchange)
|
||||
session.commit()
|
||||
session.refresh(exchange)
|
||||
return cast("int", exchange.id)
|
||||
|
||||
|
||||
def make_cycle(session: Session, user_id: int, exchange_id: int, friendly_name: str = "Test Cycle") -> int:
|
||||
cycle = models.Cycles(
|
||||
user_id=user_id,
|
||||
friendly_name=friendly_name,
|
||||
symbol="AAPL",
|
||||
exchange_id=exchange_id,
|
||||
underlying_currency=models.UnderlyingCurrency.USD,
|
||||
status=models.CycleStatus.OPEN,
|
||||
start_date=datetime.now().date(),
|
||||
)
|
||||
start_date=datetime.now(timezone.utc).date(),
|
||||
) # type: ignore[arg-type]
|
||||
session.add(cycle)
|
||||
session.commit()
|
||||
session.refresh(cycle)
|
||||
return cycle.id
|
||||
return cast("int", cycle.id)
|
||||
|
||||
|
||||
def make_trade(
|
||||
session: Session, user_id: int, cycle_id: int, friendly_name: str = "Test Trade"
|
||||
) -> int:
|
||||
def make_trade(session: Session, user_id: int, cycle_id: int, friendly_name: str = "Test Trade") -> int:
|
||||
cycle: models.Cycles | None = session.get(models.Cycles, cycle_id)
|
||||
assert cycle is not None
|
||||
exchange_id = cycle.exchange_id
|
||||
trade = models.Trades(
|
||||
user_id=user_id,
|
||||
friendly_name=friendly_name,
|
||||
symbol="AAPL",
|
||||
exchange_id=exchange_id,
|
||||
underlying_currency=models.UnderlyingCurrency.USD,
|
||||
trade_type=models.TradeType.LONG_SPOT,
|
||||
trade_strategy=models.TradeStrategy.SPOT,
|
||||
trade_date=datetime.now().date(),
|
||||
trade_time_utc=datetime.now(),
|
||||
trade_date=datetime.now(timezone.utc).date(),
|
||||
trade_time_utc=datetime.now(timezone.utc),
|
||||
quantity=10,
|
||||
price_cents=15000,
|
||||
gross_cash_flow_cents=-150000,
|
||||
@@ -81,7 +97,7 @@ def make_trade(
|
||||
session.add(trade)
|
||||
session.commit()
|
||||
session.refresh(trade)
|
||||
return trade.id
|
||||
return cast("int", trade.id)
|
||||
|
||||
|
||||
def make_trade_by_trade_data(session: Session, trade_data: dict) -> int:
|
||||
@@ -89,7 +105,7 @@ def make_trade_by_trade_data(session: Session, trade_data: dict) -> int:
|
||||
session.add(trade)
|
||||
session.commit()
|
||||
session.refresh(trade)
|
||||
return trade.id
|
||||
return cast("int", trade.id)
|
||||
|
||||
|
||||
def make_login_session(session: Session, created_at: datetime) -> models.Sessions:
|
||||
@@ -113,9 +129,28 @@ def make_login_session(session: Session, created_at: datetime) -> models.Session
|
||||
return login_session
|
||||
|
||||
|
||||
def test_create_trade_success_with_cycle(session: Session):
|
||||
def _ensure_utc_aware(dt: datetime | None) -> datetime | None:
|
||||
if dt is None:
|
||||
return None
|
||||
if dt.tzinfo is None:
|
||||
return dt.replace(tzinfo=timezone.utc)
|
||||
return dt.astimezone(timezone.utc)
|
||||
|
||||
|
||||
def _validate_timestamp(actual: datetime, expected: datetime, tolerance: timedelta) -> None:
|
||||
actual_utc = _ensure_utc_aware(actual)
|
||||
expected_utc = _ensure_utc_aware(expected)
|
||||
assert actual_utc is not None
|
||||
assert expected_utc is not None
|
||||
delta = abs(actual_utc - expected_utc)
|
||||
assert delta <= tolerance, f"Timestamps differ by {delta}, which exceeds tolerance of {tolerance}"
|
||||
|
||||
|
||||
# Trades
|
||||
def test_create_trade_success_with_cycle(session: Session) -> None:
|
||||
user_id = make_user(session)
|
||||
cycle_id = make_cycle(session, user_id)
|
||||
exchange_id = make_exchange(session, user_id)
|
||||
cycle_id = make_cycle(session, user_id, exchange_id)
|
||||
|
||||
trade_data = {
|
||||
"user_id": user_id,
|
||||
@@ -124,7 +159,7 @@ def test_create_trade_success_with_cycle(session: Session):
|
||||
"underlying_currency": models.UnderlyingCurrency.USD,
|
||||
"trade_type": models.TradeType.LONG_SPOT,
|
||||
"trade_strategy": models.TradeStrategy.SPOT,
|
||||
"trade_time_utc": datetime.now(),
|
||||
"trade_time_utc": datetime.now(timezone.utc),
|
||||
"quantity": 10,
|
||||
"price_cents": 15000,
|
||||
"gross_cash_flow_cents": -150000,
|
||||
@@ -147,6 +182,7 @@ def test_create_trade_success_with_cycle(session: Session):
|
||||
assert actual_trade.trade_type == trade_data["trade_type"]
|
||||
assert actual_trade.trade_strategy == trade_data["trade_strategy"]
|
||||
assert actual_trade.quantity == trade_data["quantity"]
|
||||
assert actual_trade.quantity_multiplier == 1
|
||||
assert actual_trade.price_cents == trade_data["price_cents"]
|
||||
assert actual_trade.gross_cash_flow_cents == trade_data["gross_cash_flow_cents"]
|
||||
assert actual_trade.commission_cents == trade_data["commission_cents"]
|
||||
@@ -154,19 +190,68 @@ def test_create_trade_success_with_cycle(session: Session):
|
||||
assert actual_trade.cycle_id == trade_data["cycle_id"]
|
||||
|
||||
|
||||
def test_create_trade_with_auto_created_cycle(session: Session):
|
||||
def test_create_trade_with_custom_multipler(session: Session) -> None:
|
||||
user_id = make_user(session)
|
||||
exchange_id = make_exchange(session, user_id)
|
||||
cycle_id = make_cycle(session, user_id, exchange_id)
|
||||
|
||||
trade_data = {
|
||||
"user_id": user_id,
|
||||
"friendly_name": "Test Trade with Multiplier",
|
||||
"symbol": "AAPL",
|
||||
"underlying_currency": models.UnderlyingCurrency.USD,
|
||||
"trade_type": models.TradeType.LONG_SPOT,
|
||||
"trade_strategy": models.TradeStrategy.SPOT,
|
||||
"trade_time_utc": datetime.now(timezone.utc),
|
||||
"quantity": 10,
|
||||
"quantity_multiplier": 100,
|
||||
"price_cents": 15000,
|
||||
"gross_cash_flow_cents": -1500000,
|
||||
"commission_cents": 50000,
|
||||
"net_cash_flow_cents": -1550000,
|
||||
"cycle_id": cycle_id,
|
||||
}
|
||||
|
||||
trade = crud.create_trade(session, trade_data)
|
||||
assert trade.id is not None
|
||||
assert trade.user_id == user_id
|
||||
assert trade.cycle_id == cycle_id
|
||||
session.refresh(trade)
|
||||
|
||||
actual_trade = session.get(models.Trades, trade.id)
|
||||
assert actual_trade is not None
|
||||
assert actual_trade.friendly_name == trade_data["friendly_name"]
|
||||
assert actual_trade.symbol == trade_data["symbol"]
|
||||
assert actual_trade.underlying_currency == trade_data["underlying_currency"]
|
||||
assert actual_trade.trade_type == trade_data["trade_type"]
|
||||
assert actual_trade.trade_strategy == trade_data["trade_strategy"]
|
||||
assert actual_trade.quantity == trade_data["quantity"]
|
||||
assert actual_trade.quantity_multiplier == trade_data["quantity_multiplier"]
|
||||
assert actual_trade.price_cents == trade_data["price_cents"]
|
||||
assert actual_trade.gross_cash_flow_cents == trade_data["gross_cash_flow_cents"]
|
||||
assert actual_trade.commission_cents == trade_data["commission_cents"]
|
||||
assert actual_trade.net_cash_flow_cents == trade_data["net_cash_flow_cents"]
|
||||
assert actual_trade.cycle_id == trade_data["cycle_id"]
|
||||
|
||||
|
||||
def test_create_trade_with_auto_created_cycle(session: Session) -> None:
|
||||
user_id = make_user(session)
|
||||
exchange_id = make_exchange(session, user_id)
|
||||
|
||||
trade_data = {
|
||||
"user_id": user_id,
|
||||
"friendly_name": "Test Trade with Auto Cycle",
|
||||
"symbol": "AAPL",
|
||||
"exchange_id": exchange_id,
|
||||
"underlying_currency": models.UnderlyingCurrency.USD,
|
||||
"trade_type": models.TradeType.LONG_SPOT,
|
||||
"trade_strategy": models.TradeStrategy.SPOT,
|
||||
"trade_time_utc": datetime.now(),
|
||||
"trade_time_utc": datetime.now(timezone.utc),
|
||||
"quantity": 5,
|
||||
"price_cents": 15500,
|
||||
"gross_cash_flow_cents": -77500,
|
||||
"commission_cents": 300,
|
||||
"net_cash_flow_cents": -77800,
|
||||
}
|
||||
|
||||
trade = crud.create_trade(session, trade_data)
|
||||
@@ -193,20 +278,22 @@ def test_create_trade_with_auto_created_cycle(session: Session):
|
||||
assert auto_cycle.symbol == trade_data["symbol"]
|
||||
assert auto_cycle.underlying_currency == trade_data["underlying_currency"]
|
||||
assert auto_cycle.status == models.CycleStatus.OPEN
|
||||
assert auto_cycle.friendly_name.startswith("Auto-created Cycle by trade")
|
||||
assert auto_cycle.friendly_name.startswith("Auto-created Cycle by trade") # type: ignore[union-attr]
|
||||
|
||||
|
||||
def test_create_trade_missing_required_fields(session: Session):
|
||||
def test_create_trade_missing_required_fields(session: Session) -> None:
|
||||
user_id = make_user(session)
|
||||
exchange_id = make_exchange(session, user_id)
|
||||
|
||||
base_trade_data = {
|
||||
"user_id": user_id,
|
||||
"friendly_name": "Incomplete Trade",
|
||||
"symbol": "AAPL",
|
||||
"exchange_id": exchange_id,
|
||||
"underlying_currency": models.UnderlyingCurrency.USD,
|
||||
"trade_type": models.TradeType.LONG_SPOT,
|
||||
"trade_strategy": models.TradeStrategy.SPOT,
|
||||
"trade_time_utc": datetime.now(),
|
||||
"trade_time_utc": datetime.now(timezone.utc),
|
||||
"quantity": 10,
|
||||
"price_cents": 15000,
|
||||
}
|
||||
@@ -218,6 +305,13 @@ def test_create_trade_missing_required_fields(session: Session):
|
||||
crud.create_trade(session, trade_data)
|
||||
assert "symbol is required" in str(excinfo.value)
|
||||
|
||||
# Missing exchange and cycle together
|
||||
trade_data = base_trade_data.copy()
|
||||
trade_data.pop("exchange_id", None)
|
||||
with pytest.raises(ValueError) as excinfo:
|
||||
crud.create_trade(session, trade_data)
|
||||
assert "exchange_id is required when no cycle is attached" in str(excinfo.value)
|
||||
|
||||
# Missing underlying_currency
|
||||
trade_data = base_trade_data.copy()
|
||||
trade_data.pop("underlying_currency", None)
|
||||
@@ -254,18 +348,20 @@ def test_create_trade_missing_required_fields(session: Session):
|
||||
assert "price_cents is required" in str(excinfo.value)
|
||||
|
||||
|
||||
def test_get_trade_by_id(session: Session):
|
||||
def test_get_trade_by_id(session: Session) -> None:
|
||||
user_id = make_user(session)
|
||||
cycle_id = make_cycle(session, user_id)
|
||||
exchange_id = make_exchange(session, user_id)
|
||||
cycle_id = make_cycle(session, user_id, exchange_id)
|
||||
trade_data = {
|
||||
"user_id": user_id,
|
||||
"friendly_name": "Test Trade for Get",
|
||||
"symbol": "AAPL",
|
||||
"exchange_id": exchange_id,
|
||||
"underlying_currency": models.UnderlyingCurrency.USD,
|
||||
"trade_type": models.TradeType.LONG_SPOT,
|
||||
"trade_strategy": models.TradeStrategy.SPOT,
|
||||
"trade_date": datetime.now().date(),
|
||||
"trade_time_utc": datetime.now(),
|
||||
"trade_date": datetime.now(timezone.utc).date(),
|
||||
"trade_time_utc": datetime.now(timezone.utc),
|
||||
"quantity": 10,
|
||||
"price_cents": 15000,
|
||||
"gross_cash_flow_cents": -150000,
|
||||
@@ -291,19 +387,21 @@ def test_get_trade_by_id(session: Session):
|
||||
assert trade.trade_date == trade_data["trade_date"]
|
||||
|
||||
|
||||
def test_get_trade_by_user_id_and_friendly_name(session: Session):
|
||||
def test_get_trade_by_user_id_and_friendly_name(session: Session) -> None:
|
||||
user_id = make_user(session)
|
||||
cycle_id = make_cycle(session, user_id)
|
||||
exchange_id = make_exchange(session, user_id)
|
||||
cycle_id = make_cycle(session, user_id, exchange_id)
|
||||
friendly_name = "Unique Trade Name"
|
||||
trade_data = {
|
||||
"user_id": user_id,
|
||||
"friendly_name": friendly_name,
|
||||
"symbol": "AAPL",
|
||||
"exchange_id": exchange_id,
|
||||
"underlying_currency": models.UnderlyingCurrency.USD,
|
||||
"trade_type": models.TradeType.LONG_SPOT,
|
||||
"trade_strategy": models.TradeStrategy.SPOT,
|
||||
"trade_date": datetime.now().date(),
|
||||
"trade_time_utc": datetime.now(),
|
||||
"trade_date": datetime.now(timezone.utc).date(),
|
||||
"trade_time_utc": datetime.now(timezone.utc),
|
||||
"quantity": 10,
|
||||
"price_cents": 15000,
|
||||
"gross_cash_flow_cents": -150000,
|
||||
@@ -318,18 +416,20 @@ def test_get_trade_by_user_id_and_friendly_name(session: Session):
|
||||
assert trade.user_id == user_id
|
||||
|
||||
|
||||
def test_get_trades_by_user_id(session: Session):
|
||||
def test_get_trades_by_user_id(session: Session) -> None:
|
||||
user_id = make_user(session)
|
||||
cycle_id = make_cycle(session, user_id)
|
||||
exchange_id = make_exchange(session, user_id)
|
||||
cycle_id = make_cycle(session, user_id, exchange_id)
|
||||
trade_data_1 = {
|
||||
"user_id": user_id,
|
||||
"friendly_name": "Trade One",
|
||||
"symbol": "AAPL",
|
||||
"exchange_id": exchange_id,
|
||||
"underlying_currency": models.UnderlyingCurrency.USD,
|
||||
"trade_type": models.TradeType.LONG_SPOT,
|
||||
"trade_strategy": models.TradeStrategy.SPOT,
|
||||
"trade_date": datetime.now().date(),
|
||||
"trade_time_utc": datetime.now(),
|
||||
"trade_date": datetime.now(timezone.utc).date(),
|
||||
"trade_time_utc": datetime.now(timezone.utc),
|
||||
"quantity": 10,
|
||||
"price_cents": 15000,
|
||||
"gross_cash_flow_cents": -150000,
|
||||
@@ -341,11 +441,12 @@ def test_get_trades_by_user_id(session: Session):
|
||||
"user_id": user_id,
|
||||
"friendly_name": "Trade Two",
|
||||
"symbol": "GOOGL",
|
||||
"exchange_id": exchange_id,
|
||||
"underlying_currency": models.UnderlyingCurrency.USD,
|
||||
"trade_type": models.TradeType.SHORT_SPOT,
|
||||
"trade_strategy": models.TradeStrategy.SPOT,
|
||||
"trade_date": datetime.now().date(),
|
||||
"trade_time_utc": datetime.now(),
|
||||
"trade_date": datetime.now(timezone.utc).date(),
|
||||
"trade_time_utc": datetime.now(timezone.utc),
|
||||
"quantity": 5,
|
||||
"price_cents": 280000,
|
||||
"gross_cash_flow_cents": 1400000,
|
||||
@@ -362,9 +463,28 @@ def test_get_trades_by_user_id(session: Session):
|
||||
assert friendly_names == {"Trade One", "Trade Two"}
|
||||
|
||||
|
||||
def test_update_trade_note(session: Session):
|
||||
def test_update_trade_friendly_name(session: Session) -> None:
|
||||
user_id = make_user(session)
|
||||
cycle_id = make_cycle(session, user_id)
|
||||
exchange_id = make_exchange(session, user_id)
|
||||
cycle_id = make_cycle(session, user_id, exchange_id)
|
||||
trade_id = make_trade(session, user_id, cycle_id)
|
||||
|
||||
new_friendly_name = "Updated Trade Name"
|
||||
updated_trade = crud.update_trade_friendly_name(session, trade_id, new_friendly_name)
|
||||
assert updated_trade is not None
|
||||
assert updated_trade.id == trade_id
|
||||
assert updated_trade.friendly_name == new_friendly_name
|
||||
|
||||
session.refresh(updated_trade)
|
||||
actual_trade = session.get(models.Trades, trade_id)
|
||||
assert actual_trade is not None
|
||||
assert actual_trade.friendly_name == new_friendly_name
|
||||
|
||||
|
||||
def test_update_trade_note(session: Session) -> None:
|
||||
user_id = make_user(session)
|
||||
exchange_id = make_exchange(session, user_id)
|
||||
cycle_id = make_cycle(session, user_id, exchange_id)
|
||||
trade_id = make_trade(session, user_id, cycle_id)
|
||||
|
||||
new_note = "This is an updated note."
|
||||
@@ -379,9 +499,10 @@ def test_update_trade_note(session: Session):
|
||||
assert actual_trade.notes == new_note
|
||||
|
||||
|
||||
def test_invalidate_trade(session: Session):
|
||||
def test_invalidate_trade(session: Session) -> None:
|
||||
user_id = make_user(session)
|
||||
cycle_id = make_cycle(session, user_id)
|
||||
exchange_id = make_exchange(session, user_id)
|
||||
cycle_id = make_cycle(session, user_id, exchange_id)
|
||||
trade_id = make_trade(session, user_id, cycle_id)
|
||||
|
||||
invalidated_trade = crud.invalidate_trade(session, trade_id)
|
||||
@@ -395,21 +516,26 @@ def test_invalidate_trade(session: Session):
|
||||
assert actual_trade.is_invalidated is True
|
||||
|
||||
|
||||
def test_replace_trade(session: Session):
|
||||
def test_replace_trade(session: Session) -> None:
|
||||
user_id = make_user(session)
|
||||
cycle_id = make_cycle(session, user_id)
|
||||
exchange_id = make_exchange(session, user_id)
|
||||
cycle_id = make_cycle(session, user_id, exchange_id)
|
||||
old_trade_id = make_trade(session, user_id, cycle_id)
|
||||
|
||||
new_trade_data = {
|
||||
"user_id": user_id,
|
||||
"friendly_name": "Replaced Trade",
|
||||
"symbol": "MSFT",
|
||||
"exchange_id": exchange_id,
|
||||
"underlying_currency": models.UnderlyingCurrency.USD,
|
||||
"trade_type": models.TradeType.LONG_SPOT,
|
||||
"trade_strategy": models.TradeStrategy.SPOT,
|
||||
"trade_time_utc": datetime.now(),
|
||||
"trade_time_utc": datetime.now(timezone.utc),
|
||||
"quantity": 20,
|
||||
"price_cents": 25000,
|
||||
"gross_cash_flow_cents": -500000,
|
||||
"commission_cents": 1000,
|
||||
"net_cash_flow_cents": -501000,
|
||||
}
|
||||
|
||||
new_trade = crud.replace_trade(session, old_trade_id, new_trade_data)
|
||||
@@ -438,15 +564,18 @@ def test_replace_trade(session: Session):
|
||||
assert actual_new_trade.replaced_by_trade_id == old_trade_id
|
||||
|
||||
|
||||
def test_create_cycle(session: Session):
|
||||
# Cycles
|
||||
def test_create_cycle(session: Session) -> None:
|
||||
user_id = make_user(session)
|
||||
exchange_id = make_exchange(session, user_id)
|
||||
cycle_data = {
|
||||
"user_id": user_id,
|
||||
"friendly_name": "My First Cycle",
|
||||
"symbol": "GOOGL",
|
||||
"exchange_id": exchange_id,
|
||||
"underlying_currency": models.UnderlyingCurrency.USD,
|
||||
"status": models.CycleStatus.OPEN,
|
||||
"start_date": datetime.now().date(),
|
||||
"start_date": datetime.now(timezone.utc).date(),
|
||||
}
|
||||
cycle = crud.create_cycle(session, cycle_data)
|
||||
assert cycle.id is not None
|
||||
@@ -467,9 +596,35 @@ def test_create_cycle(session: Session):
|
||||
assert actual_cycle.start_date == cycle_data["start_date"]
|
||||
|
||||
|
||||
def test_update_cycle(session: Session):
|
||||
def test_get_cycle_by_id(session: Session) -> None:
|
||||
user_id = make_user(session)
|
||||
cycle_id = make_cycle(session, user_id, friendly_name="Initial Cycle Name")
|
||||
exchange_id = make_exchange(session, user_id)
|
||||
cycle_id = make_cycle(session, user_id, exchange_id, friendly_name="Cycle to Get")
|
||||
cycle = crud.get_cycle_by_id(session, cycle_id)
|
||||
assert cycle is not None
|
||||
assert cycle.id == cycle_id
|
||||
assert cycle.friendly_name == "Cycle to Get"
|
||||
assert cycle.user_id == user_id
|
||||
|
||||
|
||||
def test_get_cycles_by_user_id(session: Session) -> None:
|
||||
user_id = make_user(session)
|
||||
exchange_id = make_exchange(session, user_id)
|
||||
cycle_names = ["Cycle One", "Cycle Two", "Cycle Three"]
|
||||
for name in cycle_names:
|
||||
make_cycle(session, user_id, exchange_id, friendly_name=name)
|
||||
|
||||
cycles = crud.get_cycles_by_user_id(session, user_id)
|
||||
assert len(cycles) == len(cycle_names)
|
||||
fetched_names = {cycle.friendly_name for cycle in cycles}
|
||||
for name in cycle_names:
|
||||
assert name in fetched_names
|
||||
|
||||
|
||||
def test_update_cycle(session: Session) -> None:
|
||||
user_id = make_user(session)
|
||||
exchange_id = make_exchange(session, user_id)
|
||||
cycle_id = make_cycle(session, user_id, exchange_id, friendly_name="Initial Cycle Name")
|
||||
|
||||
update_data = {
|
||||
"friendly_name": "Updated Cycle Name",
|
||||
@@ -488,16 +643,17 @@ def test_update_cycle(session: Session):
|
||||
assert actual_cycle.status == update_data["status"]
|
||||
|
||||
|
||||
def test_update_cycle_immutable_fields(session: Session):
|
||||
def test_update_cycle_immutable_fields(session: Session) -> None:
|
||||
user_id = make_user(session)
|
||||
cycle_id = make_cycle(session, user_id, friendly_name="Initial Cycle Name")
|
||||
exchange_id = make_exchange(session, user_id)
|
||||
cycle_id = make_cycle(session, user_id, exchange_id, friendly_name="Initial Cycle Name")
|
||||
|
||||
# Attempt to update immutable fields
|
||||
update_data = {
|
||||
"id": cycle_id + 1, # Trying to change the ID
|
||||
"user_id": user_id + 1, # Trying to change the user_id
|
||||
"start_date": datetime(2020, 1, 1).date(), # Trying to change start_date
|
||||
"created_at": datetime(2020, 1, 1), # Trying to change created_at
|
||||
"start_date": datetime(2020, 1, 1, tzinfo=timezone.utc).date(), # Trying to change start_date
|
||||
"created_at": datetime(2020, 1, 1, tzinfo=timezone.utc), # Trying to change created_at
|
||||
"friendly_name": "Valid Update", # Valid field to update
|
||||
}
|
||||
|
||||
@@ -511,7 +667,314 @@ def test_update_cycle_immutable_fields(session: Session):
|
||||
)
|
||||
|
||||
|
||||
def test_create_user(session: Session):
|
||||
# Cycle loans
|
||||
def test_create_cycle_loan_event(session: Session) -> None:
|
||||
user_id = make_user(session)
|
||||
exchange_id = make_exchange(session, user_id)
|
||||
cycle_id = make_cycle(session, user_id, exchange_id)
|
||||
|
||||
loan_data = {
|
||||
"cycle_id": cycle_id,
|
||||
"loan_amount_cents": 100000,
|
||||
"loan_interest_rate_tenth_bps": 5000, # 5%
|
||||
"notes": "Test loan change for the cycle",
|
||||
}
|
||||
|
||||
loan_event = crud.create_cycle_loan_event(session, loan_data)
|
||||
now = datetime.now(timezone.utc)
|
||||
assert loan_event.id is not None
|
||||
assert loan_event.cycle_id == cycle_id
|
||||
assert loan_event.loan_amount_cents == loan_data["loan_amount_cents"]
|
||||
assert loan_event.loan_interest_rate_tenth_bps == loan_data["loan_interest_rate_tenth_bps"]
|
||||
assert loan_event.notes == loan_data["notes"]
|
||||
assert loan_event.effective_date == now.date()
|
||||
_validate_timestamp(loan_event.created_at, now, timedelta(seconds=1))
|
||||
|
||||
session.refresh(loan_event)
|
||||
actual_loan_event = session.get(models.CycleLoanChangeEvents, loan_event.id)
|
||||
assert actual_loan_event is not None
|
||||
assert actual_loan_event.cycle_id == cycle_id
|
||||
assert actual_loan_event.loan_amount_cents == loan_data["loan_amount_cents"]
|
||||
assert actual_loan_event.loan_interest_rate_tenth_bps == loan_data["loan_interest_rate_tenth_bps"]
|
||||
assert actual_loan_event.notes == loan_data["notes"]
|
||||
assert actual_loan_event.effective_date == now.date()
|
||||
_validate_timestamp(actual_loan_event.created_at, now, timedelta(seconds=1))
|
||||
|
||||
|
||||
def test_get_cycle_loan_events_by_cycle_id(session: Session) -> None:
|
||||
user_id = make_user(session)
|
||||
exchange_id = make_exchange(session, user_id)
|
||||
cycle_id = make_cycle(session, user_id, exchange_id)
|
||||
|
||||
loan_data_1 = {
|
||||
"cycle_id": cycle_id,
|
||||
"loan_amount_cents": 100000,
|
||||
"loan_interest_rate_tenth_bps": 5000,
|
||||
"notes": "First loan event",
|
||||
}
|
||||
yesterday = (datetime.now(timezone.utc) - timedelta(days=1)).date()
|
||||
loan_data_2 = {
|
||||
"cycle_id": cycle_id,
|
||||
"loan_amount_cents": 150000,
|
||||
"loan_interest_rate_tenth_bps": 4500,
|
||||
"effective_date": yesterday,
|
||||
"notes": "Second loan event",
|
||||
}
|
||||
|
||||
crud.create_cycle_loan_event(session, loan_data_1)
|
||||
crud.create_cycle_loan_event(session, loan_data_2)
|
||||
|
||||
loan_events = crud.get_loan_events_by_cycle_id(session, cycle_id)
|
||||
assert len(loan_events) == 2
|
||||
notes = [event.notes for event in loan_events]
|
||||
assert loan_events[0].notes == loan_data_2["notes"]
|
||||
assert loan_events[0].effective_date == yesterday
|
||||
assert notes == ["Second loan event", "First loan event"] # Ordered by effective_date desc
|
||||
|
||||
|
||||
def test_get_cycle_loan_events_by_cycle_id_same_date(session: Session) -> None:
|
||||
user_id = make_user(session)
|
||||
exchange_id = make_exchange(session, user_id)
|
||||
cycle_id = make_cycle(session, user_id, exchange_id)
|
||||
|
||||
loan_data_1 = {
|
||||
"cycle_id": cycle_id,
|
||||
"loan_amount_cents": 100000,
|
||||
"loan_interest_rate_tenth_bps": 5000,
|
||||
"notes": "First loan event",
|
||||
}
|
||||
loan_data_2 = {
|
||||
"cycle_id": cycle_id,
|
||||
"loan_amount_cents": 150000,
|
||||
"loan_interest_rate_tenth_bps": 4500,
|
||||
"notes": "Second loan event",
|
||||
}
|
||||
|
||||
crud.create_cycle_loan_event(session, loan_data_1)
|
||||
crud.create_cycle_loan_event(session, loan_data_2)
|
||||
|
||||
loan_events = crud.get_loan_events_by_cycle_id(session, cycle_id)
|
||||
assert len(loan_events) == 2
|
||||
notes = [event.notes for event in loan_events]
|
||||
assert notes == ["First loan event", "Second loan event"] # Ordered by id desc when effective_date is same
|
||||
|
||||
|
||||
def test_create_cycle_loan_event_single_field(session: Session) -> None:
|
||||
user_id = make_user(session)
|
||||
exchange_id = make_exchange(session, user_id)
|
||||
cycle_id = make_cycle(session, user_id, exchange_id)
|
||||
|
||||
loan_data = {
|
||||
"cycle_id": cycle_id,
|
||||
"loan_amount_cents": 200000,
|
||||
}
|
||||
|
||||
loan_event = crud.create_cycle_loan_event(session, loan_data)
|
||||
now = datetime.now(timezone.utc)
|
||||
assert loan_event.id is not None
|
||||
assert loan_event.cycle_id == cycle_id
|
||||
assert loan_event.loan_amount_cents == loan_data["loan_amount_cents"]
|
||||
assert loan_event.loan_interest_rate_tenth_bps is None
|
||||
assert loan_event.notes is None
|
||||
assert loan_event.effective_date == now.date()
|
||||
_validate_timestamp(loan_event.created_at, now, timedelta(seconds=1))
|
||||
|
||||
session.refresh(loan_event)
|
||||
actual_loan_event = session.get(models.CycleLoanChangeEvents, loan_event.id)
|
||||
assert actual_loan_event is not None
|
||||
assert actual_loan_event.cycle_id == cycle_id
|
||||
assert actual_loan_event.loan_amount_cents == loan_data["loan_amount_cents"]
|
||||
assert actual_loan_event.loan_interest_rate_tenth_bps is None
|
||||
assert actual_loan_event.notes is None
|
||||
assert actual_loan_event.effective_date == now.date()
|
||||
_validate_timestamp(actual_loan_event.created_at, now, timedelta(seconds=1))
|
||||
|
||||
|
||||
def test_create_cycle_daily_accrual(session: Session) -> None:
|
||||
user_id = make_user(session)
|
||||
exchange_id = make_exchange(session, user_id)
|
||||
cycle_id = make_cycle(session, user_id, exchange_id)
|
||||
today = datetime.now(timezone.utc).date()
|
||||
accrual_data = {
|
||||
"cycle_id": cycle_id,
|
||||
"accrual_date": today,
|
||||
"accrued_interest_cents": 150,
|
||||
"notes": "Daily interest accrual",
|
||||
}
|
||||
|
||||
accrual = crud.create_cycle_daily_accrual(session, cycle_id, accrual_data["accrual_date"], accrual_data["accrued_interest_cents"])
|
||||
assert accrual.id is not None
|
||||
assert accrual.cycle_id == cycle_id
|
||||
assert accrual.accrual_date == accrual_data["accrual_date"]
|
||||
assert accrual.accrual_amount_cents == accrual_data["accrued_interest_cents"]
|
||||
|
||||
session.refresh(accrual)
|
||||
actual_accrual = session.get(models.CycleDailyAccrual, accrual.id)
|
||||
assert actual_accrual is not None
|
||||
assert actual_accrual.cycle_id == cycle_id
|
||||
assert actual_accrual.accrual_date == accrual_data["accrual_date"]
|
||||
assert actual_accrual.accrual_amount_cents == accrual_data["accrued_interest_cents"]
|
||||
|
||||
|
||||
def test_get_cycle_daily_accruals_by_cycle_id(session: Session) -> None:
|
||||
user_id = make_user(session)
|
||||
exchange_id = make_exchange(session, user_id)
|
||||
cycle_id = make_cycle(session, user_id, exchange_id)
|
||||
|
||||
today = datetime.now(timezone.utc).date()
|
||||
yesterday = today - timedelta(days=1)
|
||||
|
||||
accrual_data_1 = {
|
||||
"cycle_id": cycle_id,
|
||||
"accrual_date": yesterday,
|
||||
"accrued_interest_cents": 100,
|
||||
}
|
||||
accrual_data_2 = {
|
||||
"cycle_id": cycle_id,
|
||||
"accrual_date": today,
|
||||
"accrued_interest_cents": 150,
|
||||
}
|
||||
|
||||
crud.create_cycle_daily_accrual(session, cycle_id, accrual_data_1["accrual_date"], accrual_data_1["accrued_interest_cents"])
|
||||
crud.create_cycle_daily_accrual(session, cycle_id, accrual_data_2["accrual_date"], accrual_data_2["accrued_interest_cents"])
|
||||
|
||||
accruals = crud.get_cycle_daily_accruals_by_cycle_id(session, cycle_id)
|
||||
assert len(accruals) == 2
|
||||
dates = [accrual.accrual_date for accrual in accruals]
|
||||
assert dates == [yesterday, today] # Ordered by accrual_date asc
|
||||
|
||||
|
||||
def test_get_cycle_daily_accruals_by_cycle_id_and_date(session: Session) -> None:
|
||||
user_id = make_user(session)
|
||||
exchange_id = make_exchange(session, user_id)
|
||||
cycle_id = make_cycle(session, user_id, exchange_id)
|
||||
|
||||
today = datetime.now(timezone.utc).date()
|
||||
yesterday = today - timedelta(days=1)
|
||||
|
||||
accrual_data_1 = {
|
||||
"cycle_id": cycle_id,
|
||||
"accrual_date": yesterday,
|
||||
"accrued_interest_cents": 100,
|
||||
}
|
||||
accrual_data_2 = {
|
||||
"cycle_id": cycle_id,
|
||||
"accrual_date": today,
|
||||
"accrued_interest_cents": 150,
|
||||
}
|
||||
|
||||
crud.create_cycle_daily_accrual(session, cycle_id, accrual_data_1["accrual_date"], accrual_data_1["accrued_interest_cents"])
|
||||
crud.create_cycle_daily_accrual(session, cycle_id, accrual_data_2["accrual_date"], accrual_data_2["accrued_interest_cents"])
|
||||
|
||||
accruals_today = crud.get_cycle_daily_accrual_by_cycle_id_and_date(session, cycle_id, today)
|
||||
assert accruals_today is not None
|
||||
assert accruals_today.accrual_date == today
|
||||
assert accruals_today.accrual_amount_cents == accrual_data_2["accrued_interest_cents"]
|
||||
|
||||
accruals_yesterday = crud.get_cycle_daily_accrual_by_cycle_id_and_date(session, cycle_id, yesterday)
|
||||
assert accruals_yesterday is not None
|
||||
assert accruals_yesterday.accrual_date == yesterday
|
||||
assert accruals_yesterday.accrual_amount_cents == accrual_data_1["accrued_interest_cents"]
|
||||
|
||||
|
||||
# Exchanges
|
||||
def test_create_exchange(session: Session) -> None:
|
||||
user_id = make_user(session)
|
||||
exchange_data = {
|
||||
"name": "NYSE",
|
||||
"notes": "New York Stock Exchange",
|
||||
"user_id": user_id,
|
||||
}
|
||||
exchange = crud.create_exchange(session, exchange_data)
|
||||
assert exchange.id is not None
|
||||
assert exchange.name == exchange_data["name"]
|
||||
assert exchange.notes == exchange_data["notes"]
|
||||
assert exchange.user_id == user_id
|
||||
|
||||
session.refresh(exchange)
|
||||
actual_exchange = session.get(models.Exchanges, exchange.id)
|
||||
assert actual_exchange is not None
|
||||
assert actual_exchange.name == exchange_data["name"]
|
||||
assert actual_exchange.notes == exchange_data["notes"]
|
||||
assert actual_exchange.user_id == user_id
|
||||
|
||||
|
||||
def test_get_exchange_by_id(session: Session) -> None:
|
||||
user_id = make_user(session)
|
||||
exchange_id = make_exchange(session, user_id=user_id, name="LSE")
|
||||
exchange = crud.get_exchange_by_id(session, exchange_id)
|
||||
assert exchange is not None
|
||||
assert exchange.id == exchange_id
|
||||
assert exchange.name == "LSE"
|
||||
assert exchange.user_id == user_id
|
||||
|
||||
|
||||
def test_get_exchange_by_name_and_user_id(session: Session) -> None:
|
||||
exchange_name = "TSX"
|
||||
user_id = make_user(session)
|
||||
make_exchange(session, user_id=user_id, name=exchange_name)
|
||||
exchange = crud.get_exchange_by_name_and_user_id(session, exchange_name, user_id)
|
||||
assert exchange is not None
|
||||
assert exchange.name == exchange_name
|
||||
assert exchange.user_id == user_id
|
||||
|
||||
|
||||
def test_get_all_exchanges(session: Session) -> None:
|
||||
exchange_names = ["NYSE", "NASDAQ", "LSE"]
|
||||
user_id = make_user(session)
|
||||
for name in exchange_names:
|
||||
make_exchange(session, user_id=user_id, name=name)
|
||||
|
||||
exchanges = crud.get_all_exchanges(session)
|
||||
assert len(exchanges) >= 3
|
||||
fetched_names = {ex.name for ex in exchanges}
|
||||
for name in exchange_names:
|
||||
assert name in fetched_names
|
||||
|
||||
|
||||
def test_get_all_exchanges_by_user_id(session: Session) -> None:
|
||||
exchange_names = ["NYSE", "NASDAQ"]
|
||||
user_id = make_user(session)
|
||||
for name in exchange_names:
|
||||
make_exchange(session, user_id=user_id, name=name)
|
||||
|
||||
exchanges = crud.get_all_exchanges_by_user_id(session, user_id)
|
||||
assert len(exchanges) == len(exchange_names)
|
||||
fetched_names = {ex.name for ex in exchanges}
|
||||
for name in exchange_names:
|
||||
assert name in fetched_names
|
||||
|
||||
|
||||
def test_update_exchange(session: Session) -> None:
|
||||
user_id = make_user(session)
|
||||
exchange_id = make_exchange(session, user_id=user_id, name="Initial Exchange")
|
||||
update_data = {
|
||||
"name": "Updated Exchange",
|
||||
"notes": "Updated notes for the exchange",
|
||||
}
|
||||
updated_exchange = crud.update_exchange(session, exchange_id, update_data)
|
||||
assert updated_exchange is not None
|
||||
assert updated_exchange.id == exchange_id
|
||||
assert updated_exchange.name == update_data["name"]
|
||||
assert updated_exchange.notes == update_data["notes"]
|
||||
|
||||
session.refresh(updated_exchange)
|
||||
actual_exchange = session.get(models.Exchanges, exchange_id)
|
||||
assert actual_exchange is not None
|
||||
assert actual_exchange.name == update_data["name"]
|
||||
assert actual_exchange.notes == update_data["notes"]
|
||||
|
||||
|
||||
def test_delete_exchange(session: Session) -> None:
|
||||
user_id = make_user(session)
|
||||
exchange_id = make_exchange(session, user_id=user_id, name="Deletable Exchange")
|
||||
crud.delete_exchange(session, exchange_id)
|
||||
deleted_exchange = session.get(models.Exchanges, exchange_id)
|
||||
assert deleted_exchange is None
|
||||
|
||||
|
||||
# Users
|
||||
def test_create_user(session: Session) -> None:
|
||||
user_data = {
|
||||
"username": "newuser",
|
||||
"password_hash": "newhashedpassword",
|
||||
@@ -528,7 +991,23 @@ def test_create_user(session: Session):
|
||||
assert actual_user.password_hash == user_data["password_hash"]
|
||||
|
||||
|
||||
def test_update_user(session: Session):
|
||||
def test_get_user_by_id(session: Session) -> None:
|
||||
user_id = make_user(session, username="fetchuser")
|
||||
user = crud.get_user_by_id(session, user_id)
|
||||
assert user is not None
|
||||
assert user.id == user_id
|
||||
assert user.username == "fetchuser"
|
||||
|
||||
|
||||
def test_get_user_by_username(session: Session) -> None:
|
||||
username = "uniqueuser"
|
||||
make_user(session, username=username)
|
||||
user = crud.get_user_by_username(session, username)
|
||||
assert user is not None
|
||||
assert user.username == username
|
||||
|
||||
|
||||
def test_update_user(session: Session) -> None:
|
||||
user_id = make_user(session, username="updatableuser")
|
||||
|
||||
update_data = {
|
||||
@@ -545,14 +1024,14 @@ def test_update_user(session: Session):
|
||||
assert actual_user.password_hash == update_data["password_hash"]
|
||||
|
||||
|
||||
def test_update_user_immutable_fields(session: Session):
|
||||
def test_update_user_immutable_fields(session: Session) -> None:
|
||||
user_id = make_user(session, username="immutableuser")
|
||||
|
||||
# Attempt to update immutable fields
|
||||
update_data = {
|
||||
"id": user_id + 1, # Trying to change the ID
|
||||
"username": "newusername", # Trying to change the username
|
||||
"created_at": datetime(2020, 1, 1), # Trying to change created_at
|
||||
"created_at": datetime(2020, 1, 1, tzinfo=timezone.utc), # Trying to change created_at
|
||||
"password_hash": "validupdate", # Valid field to update
|
||||
}
|
||||
|
||||
@@ -566,7 +1045,7 @@ def test_update_user_immutable_fields(session: Session):
|
||||
|
||||
|
||||
# login sessions
|
||||
def test_create_login_session(session: Session):
|
||||
def test_create_login_session(session: Session) -> None:
|
||||
user_id = make_user(session, username="testuser")
|
||||
session_token_hash = "sessiontokenhashed"
|
||||
login_session = crud.create_login_session(session, user_id, session_token_hash)
|
||||
@@ -575,7 +1054,7 @@ def test_create_login_session(session: Session):
|
||||
assert login_session.session_token_hash == session_token_hash
|
||||
|
||||
|
||||
def test_create_login_session_with_invalid_user(session: Session):
|
||||
def test_create_login_session_with_invalid_user(session: Session) -> None:
|
||||
invalid_user_id = 9999 # Assuming this user ID does not exist
|
||||
session_token_hash = "sessiontokenhashed"
|
||||
with pytest.raises(ValueError) as excinfo:
|
||||
@@ -583,40 +1062,44 @@ def test_create_login_session_with_invalid_user(session: Session):
|
||||
assert "user_id does not exist" in str(excinfo.value)
|
||||
|
||||
|
||||
def test_get_login_session_by_token_and_user_id(session: Session):
|
||||
now = datetime.now()
|
||||
def test_get_login_session_by_token_and_user_id(session: Session) -> None:
|
||||
now = datetime.now(timezone.utc)
|
||||
created_session = make_login_session(session, now)
|
||||
fetched_session = crud.get_login_session_by_token_hash_and_user_id(
|
||||
session, created_session.session_token_hash, created_session.user_id
|
||||
)
|
||||
fetched_session = crud.get_login_session_by_token_hash_and_user_id(session, created_session.session_token_hash, created_session.user_id)
|
||||
assert fetched_session is not None
|
||||
assert fetched_session.id == created_session.id
|
||||
assert fetched_session.user_id == created_session.user_id
|
||||
assert fetched_session.session_token_hash == created_session.session_token_hash
|
||||
|
||||
|
||||
def test_update_login_session(session: Session):
|
||||
now = datetime.now()
|
||||
def test_get_login_session_by_token(session: Session) -> None:
|
||||
now = datetime.now(timezone.utc)
|
||||
created_session = make_login_session(session, now)
|
||||
fetched_session = crud.get_login_session_by_token_hash(session, created_session.session_token_hash)
|
||||
assert fetched_session is not None
|
||||
assert fetched_session.id == created_session.id
|
||||
assert fetched_session.user_id == created_session.user_id
|
||||
assert fetched_session.session_token_hash == created_session.session_token_hash
|
||||
|
||||
|
||||
def test_update_login_session(session: Session) -> None:
|
||||
now = datetime.now(timezone.utc)
|
||||
created_session = make_login_session(session, now)
|
||||
|
||||
update_data = {
|
||||
"last_seen_at": now + timedelta(hours=1),
|
||||
"last_used_ip": "192.168.1.1",
|
||||
}
|
||||
updated_session = crud.update_login_session(
|
||||
session, created_session.session_token_hash, update_data
|
||||
)
|
||||
updated_session = crud.update_login_session(session, created_session.session_token_hash, update_data)
|
||||
assert updated_session is not None
|
||||
assert updated_session.last_seen_at == update_data["last_seen_at"]
|
||||
assert _ensure_utc_aware(updated_session.last_seen_at) == update_data["last_seen_at"]
|
||||
assert updated_session.last_used_ip == update_data["last_used_ip"]
|
||||
|
||||
|
||||
def test_delete_login_session(session: Session):
|
||||
now = datetime.now()
|
||||
def test_delete_login_session(session: Session) -> None:
|
||||
now = datetime.now(timezone.utc)
|
||||
created_session = make_login_session(session, now)
|
||||
|
||||
crud.delete_login_session(session, created_session.session_token_hash)
|
||||
deleted_session = crud.get_login_session_by_token_hash_and_user_id(
|
||||
session, created_session.session_token_hash, created_session.user_id
|
||||
)
|
||||
deleted_session = crud.get_login_session_by_token_hash_and_user_id(session, created_session.session_token_hash, created_session.user_id)
|
||||
assert deleted_session is None
|
||||
|
||||
@@ -46,8 +46,7 @@ def database_ctx(db: Database) -> Generator[Database, None, None]:
|
||||
|
||||
def test_select_one_executes() -> None:
|
||||
db = create_database(None) # in-memory by default
|
||||
with database_ctx(db):
|
||||
with session_ctx(db) as session:
|
||||
with database_ctx(db), session_ctx(db) as session:
|
||||
val = session.exec(text("SELECT 1")).scalar_one()
|
||||
assert int(val) == 1
|
||||
|
||||
@@ -56,9 +55,7 @@ def test_in_memory_persists_across_sessions_when_using_staticpool() -> None:
|
||||
db = create_database(None) # in-memory with StaticPool
|
||||
with database_ctx(db):
|
||||
with session_ctx(db) as s1:
|
||||
s1.exec(
|
||||
text("CREATE TABLE IF NOT EXISTS t (id INTEGER PRIMARY KEY, val TEXT);")
|
||||
)
|
||||
s1.exec(text("CREATE TABLE IF NOT EXISTS t (id INTEGER PRIMARY KEY, val TEXT);"))
|
||||
s1.exec(text("INSERT INTO t (val) VALUES (:v)").bindparams(v="hello"))
|
||||
with session_ctx(db) as s2:
|
||||
got = s2.exec(text("SELECT val FROM t")).scalar_one()
|
||||
@@ -67,9 +64,8 @@ def test_in_memory_persists_across_sessions_when_using_staticpool() -> None:
|
||||
|
||||
def test_sqlite_pragmas_applied() -> None:
|
||||
db = create_database(None)
|
||||
with database_ctx(db):
|
||||
with database_ctx(db), session_ctx(db) as session:
|
||||
# PRAGMA returns integer 1 when foreign_keys ON
|
||||
with session_ctx(db) as session:
|
||||
fk = session.exec(text("PRAGMA foreign_keys")).scalar_one()
|
||||
assert int(fk) == 1
|
||||
|
||||
@@ -82,16 +78,8 @@ def test_rollback_on_exception() -> None:
|
||||
# Create table then insert and raise inside the same session to force rollback
|
||||
with pytest.raises(RuntimeError): # noqa: PT012, SIM117
|
||||
with session_ctx(db) as s:
|
||||
s.exec(
|
||||
text(
|
||||
"CREATE TABLE IF NOT EXISTS t_rb (id INTEGER PRIMARY KEY, val TEXT);"
|
||||
)
|
||||
)
|
||||
s.exec(
|
||||
text("INSERT INTO t_rb (val) VALUES (:v)").bindparams(
|
||||
v="will_rollback"
|
||||
)
|
||||
)
|
||||
s.exec(text("CREATE TABLE IF NOT EXISTS t_rb (id INTEGER PRIMARY KEY, val TEXT);"))
|
||||
s.exec(text("INSERT INTO t_rb (val) VALUES (:v)").bindparams(v="will_rollback"))
|
||||
# simulate handler error -> should trigger rollback in get_session
|
||||
raise RuntimeError("simulated failure")
|
||||
|
||||
|
||||
@@ -36,33 +36,66 @@ def test_run_migrations_0_to_1(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"user_id": ("INTEGER", 1, 0),
|
||||
"friendly_name": ("TEXT", 0, 0),
|
||||
"symbol": ("TEXT", 1, 0),
|
||||
"exchange_id": ("INTEGER", 1, 0),
|
||||
"underlying_currency": ("TEXT", 1, 0),
|
||||
"status": ("TEXT", 1, 0),
|
||||
"funding_source": ("TEXT", 0, 0),
|
||||
"capital_exposure_cents": ("INTEGER", 0, 0),
|
||||
"loan_amount_cents": ("INTEGER", 0, 0),
|
||||
"loan_interest_rate_bps": ("INTEGER", 0, 0),
|
||||
"loan_interest_rate_tenth_bps": ("INTEGER", 0, 0),
|
||||
"start_date": ("DATE", 1, 0),
|
||||
"end_date": ("DATE", 0, 0),
|
||||
"latest_interest_accrued_date": ("DATE", 0, 0),
|
||||
"total_accrued_amount_cents": ("INTEGER", 1, 0),
|
||||
},
|
||||
"cycle_loan_change_events": {
|
||||
"id": ("INTEGER", 1, 1),
|
||||
"cycle_id": ("INTEGER", 1, 0),
|
||||
"effective_date": ("DATE", 1, 0),
|
||||
"loan_amount_cents": ("INTEGER", 0, 0),
|
||||
"loan_interest_rate_tenth_bps": ("INTEGER", 0, 0),
|
||||
"related_trade_id": ("INTEGER", 0, 0),
|
||||
"notes": ("TEXT", 0, 0),
|
||||
"created_at": ("DATETIME", 1, 0),
|
||||
},
|
||||
"cycle_daily_accrual": {
|
||||
"id": ("INTEGER", 1, 1),
|
||||
"cycle_id": ("INTEGER", 1, 0),
|
||||
"accrual_date": ("DATE", 1, 0),
|
||||
"accrual_amount_cents": ("INTEGER", 1, 0),
|
||||
"created_at": ("DATETIME", 1, 0),
|
||||
},
|
||||
"trades": {
|
||||
"id": ("INTEGER", 1, 1),
|
||||
"user_id": ("INTEGER", 1, 0),
|
||||
"friendly_name": ("TEXT", 0, 0),
|
||||
"symbol": ("TEXT", 1, 0),
|
||||
"exchange_id": ("INTEGER", 1, 0),
|
||||
"underlying_currency": ("TEXT", 1, 0),
|
||||
"trade_type": ("TEXT", 1, 0),
|
||||
"trade_strategy": ("TEXT", 1, 0),
|
||||
"trade_date": ("DATE", 1, 0),
|
||||
"trade_time_utc": ("DATETIME", 1, 0),
|
||||
"expiry_date": ("DATE", 0, 0),
|
||||
"strike_price_cents": ("INTEGER", 0, 0),
|
||||
"quantity": ("INTEGER", 1, 0),
|
||||
"quantity_multiplier": ("INTEGER", 1, 0),
|
||||
"price_cents": ("INTEGER", 1, 0),
|
||||
"gross_cash_flow_cents": ("INTEGER", 1, 0),
|
||||
"commission_cents": ("INTEGER", 1, 0),
|
||||
"net_cash_flow_cents": ("INTEGER", 1, 0),
|
||||
"is_invalidated": ("BOOLEAN", 1, 0),
|
||||
"invalidated_at": ("DATETIME", 0, 0),
|
||||
"replaced_by_trade_id": ("INTEGER", 0, 0),
|
||||
"notes": ("TEXT", 0, 0),
|
||||
"cycle_id": ("INTEGER", 0, 0),
|
||||
},
|
||||
"exchanges": {
|
||||
"id": ("INTEGER", 1, 1),
|
||||
"user_id": ("INTEGER", 1, 0),
|
||||
"name": ("TEXT", 1, 0),
|
||||
"notes": ("TEXT", 0, 0),
|
||||
},
|
||||
"sessions": {
|
||||
"id": ("INTEGER", 1, 1),
|
||||
"user_id": ("INTEGER", 1, 0),
|
||||
@@ -80,21 +113,35 @@ def test_run_migrations_0_to_1(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"trades": [
|
||||
{"table": "cycles", "from": "cycle_id", "to": "id"},
|
||||
{"table": "users", "from": "user_id", "to": "id"},
|
||||
{"table": "exchanges", "from": "exchange_id", "to": "id"},
|
||||
],
|
||||
"cycles": [
|
||||
{"table": "users", "from": "user_id", "to": "id"},
|
||||
{"table": "exchanges", "from": "exchange_id", "to": "id"},
|
||||
],
|
||||
"cycle_loan_change_events": [
|
||||
{"table": "cycles", "from": "cycle_id", "to": "id"},
|
||||
{"table": "trades", "from": "related_trade_id", "to": "id"},
|
||||
],
|
||||
"cycle_daily_accrual": [
|
||||
{"table": "cycles", "from": "cycle_id", "to": "id"},
|
||||
],
|
||||
"sessions": [
|
||||
{"table": "users", "from": "user_id", "to": "id"},
|
||||
],
|
||||
"users": [],
|
||||
"exchanges": [
|
||||
{"table": "users", "from": "user_id", "to": "id"},
|
||||
],
|
||||
}
|
||||
|
||||
with engine.connect() as conn:
|
||||
# check tables exist
|
||||
rows = conn.execute(
|
||||
text("SELECT name FROM sqlite_master WHERE type='table'")
|
||||
text("SELECT name FROM sqlite_master WHERE type='table'"),
|
||||
).fetchall()
|
||||
found_tables = {r[0] for r in rows}
|
||||
assert set(expected_schema.keys()).issubset(found_tables), (
|
||||
f"missing tables: {set(expected_schema.keys()) - found_tables}"
|
||||
)
|
||||
assert set(expected_schema.keys()).issubset(found_tables), f"missing tables: {set(expected_schema.keys()) - found_tables}"
|
||||
|
||||
# check user_version
|
||||
uv = conn.execute(text("PRAGMA user_version")).fetchone()
|
||||
@@ -103,14 +150,9 @@ def test_run_migrations_0_to_1(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
|
||||
# validate each table columns
|
||||
for tbl_name, cols in expected_schema.items():
|
||||
info_rows = conn.execute(
|
||||
text(f"PRAGMA table_info({tbl_name})")
|
||||
).fetchall()
|
||||
info_rows = conn.execute(text(f"PRAGMA table_info({tbl_name})")).fetchall()
|
||||
# map: name -> (type, notnull, pk)
|
||||
actual = {
|
||||
r[1]: ((r[2] or "").upper(), int(r[3]), int(r[5]))
|
||||
for r in info_rows
|
||||
}
|
||||
actual = {r[1]: ((r[2] or "").upper(), int(r[3]), int(r[5])) for r in info_rows}
|
||||
for colname, (exp_type, exp_notnull, exp_pk) in cols.items():
|
||||
assert colname in actual, f"{tbl_name}: missing column {colname}"
|
||||
act_type, act_notnull, act_pk = actual[colname]
|
||||
@@ -122,22 +164,47 @@ def test_run_migrations_0_to_1(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
assert exp_type in act_base or act_base in exp_type, (
|
||||
f"type mismatch {tbl_name}.{colname}: expected {exp_type}, got {act_base}"
|
||||
)
|
||||
assert act_notnull == exp_notnull, (
|
||||
f"notnull mismatch {tbl_name}.{colname}: expected {exp_notnull}, got {act_notnull}"
|
||||
)
|
||||
assert act_pk == exp_pk, (
|
||||
f"pk mismatch {tbl_name}.{colname}: expected {exp_pk}, got {act_pk}"
|
||||
)
|
||||
assert act_notnull == exp_notnull, f"notnull mismatch {tbl_name}.{colname}: expected {exp_notnull}, got {act_notnull}"
|
||||
assert act_pk == exp_pk, f"pk mismatch {tbl_name}.{colname}: expected {exp_pk}, got {act_pk}"
|
||||
for tbl_name, fks in expected_fks.items():
|
||||
fk_rows = conn.execute(
|
||||
text(f"PRAGMA foreign_key_list('{tbl_name}')")
|
||||
).fetchall()
|
||||
fk_rows = conn.execute(text(f"PRAGMA foreign_key_list('{tbl_name}')")).fetchall()
|
||||
# fk_rows columns: (id, seq, table, from, to, on_update, on_delete, match)
|
||||
actual_fk_list = [
|
||||
{"table": r[2], "from": r[3], "to": r[4]} for r in fk_rows
|
||||
]
|
||||
actual_fk_list = [{"table": r[2], "from": r[3], "to": r[4]} for r in fk_rows]
|
||||
for efk in fks:
|
||||
assert efk in actual_fk_list, f"missing FK on {tbl_name}: {efk}"
|
||||
|
||||
# check trades.replaced_by_trade_id self-referential FK
|
||||
fk_rows = conn.execute(text("PRAGMA foreign_key_list('trades')")).fetchall()
|
||||
actual_fk_list = [{"table": r[2], "from": r[3], "to": r[4]} for r in fk_rows]
|
||||
assert {"table": "trades", "from": "replaced_by_trade_id", "to": "id"} in actual_fk_list, (
|
||||
"missing self FK trades.replaced_by_trade_id -> trades.id"
|
||||
)
|
||||
|
||||
# helper to find unique index on a column
|
||||
def has_unique_index(table: str, column: str) -> bool:
|
||||
idx_rows = conn.execute(text(f"PRAGMA index_list('{table}')")).fetchall()
|
||||
for idx in idx_rows:
|
||||
idx_name = idx[1]
|
||||
is_unique = bool(idx[2])
|
||||
if not is_unique:
|
||||
continue
|
||||
info = conn.execute(text(f"PRAGMA index_info('{idx_name}')")).fetchall()
|
||||
cols = [r[2] for r in info]
|
||||
if column in cols:
|
||||
return True
|
||||
return False
|
||||
|
||||
assert has_unique_index("trades", "friendly_name"), (
|
||||
"expected unique index on trades(friendly_name) per uq_trades_user_friendly_name"
|
||||
)
|
||||
assert has_unique_index("cycles", "friendly_name"), (
|
||||
"expected unique index on cycles(friendly_name) per uq_cycles_user_friendly_name"
|
||||
)
|
||||
assert has_unique_index("exchanges", "name"), "expected unique index on exchanges(name) per uq_exchanges_user_name"
|
||||
assert has_unique_index("sessions", "session_token_hash"), "expected unique index on sessions(session_token_hash)"
|
||||
assert has_unique_index("cycle_loan_change_events", "related_trade_id"), (
|
||||
"expected unique index on cycle_loan_change_events(related_trade_id)"
|
||||
)
|
||||
finally:
|
||||
engine.dispose()
|
||||
SQLModel.metadata.clear()
|
||||
|
||||
@@ -1,22 +0,0 @@
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from app import app
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client():
|
||||
with TestClient(app) as client:
|
||||
yield client
|
||||
|
||||
|
||||
def test_home_route(client):
|
||||
response = client.get("/")
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"message": "Hello"}
|
||||
|
||||
|
||||
def test_about_route(client):
|
||||
response = client.get("/about")
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"message": "This is the about page."}
|
||||
24
backend/tests/test_security.py
Normal file
24
backend/tests/test_security.py
Normal file
@@ -0,0 +1,24 @@
|
||||
from trading_journal import security
|
||||
|
||||
|
||||
def test_hash_and_verify_password() -> None:
|
||||
plain = "password"
|
||||
hashed = security.hash_password(plain)
|
||||
assert hashed != plain
|
||||
assert security.verify_password(plain, hashed)
|
||||
|
||||
|
||||
def test_generate_session_token() -> None:
|
||||
token1 = security.generate_session_token()
|
||||
token2 = security.generate_session_token()
|
||||
assert token1 != token2
|
||||
assert len(token1) > 0
|
||||
assert len(token2) > 0
|
||||
|
||||
|
||||
def test_hash_and_verify_session_token_sha256() -> None:
|
||||
token = security.generate_session_token()
|
||||
token_hash = security.hash_session_token_sha256(token)
|
||||
assert token_hash != token
|
||||
assert security.verify_token_sha256(token, token_hash)
|
||||
assert not security.verify_token_sha256(token + "x", token_hash)
|
||||
1101
backend/tests/test_service.py
Normal file
1101
backend/tests/test_service.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -12,7 +12,7 @@ def test_default_settings(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
|
||||
s = load_settings()
|
||||
assert s.host == "0.0.0.0" # noqa: S104
|
||||
assert s.port == 8000 # noqa: PLR2004
|
||||
assert s.port == 8000
|
||||
assert s.workers == 1
|
||||
assert s.log_level == "info"
|
||||
|
||||
@@ -26,8 +26,8 @@ def test_env_overrides(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
|
||||
s = load_settings()
|
||||
assert s.host == "127.0.0.1"
|
||||
assert s.port == 9000 # noqa: PLR2004
|
||||
assert s.workers == 3 # noqa: PLR2004
|
||||
assert s.port == 9000
|
||||
assert s.workers == 3
|
||||
assert s.log_level == "debug"
|
||||
|
||||
|
||||
@@ -40,6 +40,6 @@ def test_yaml_config_file(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> No
|
||||
|
||||
s = load_settings()
|
||||
assert s.host == "10.0.0.5"
|
||||
assert s.port == 8088 # noqa: PLR2004
|
||||
assert s.workers == 5 # noqa: PLR2004
|
||||
assert s.port == 8088
|
||||
assert s.workers == 5
|
||||
assert s.log_level == "debug"
|
||||
|
||||
@@ -1,13 +1,26 @@
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Mapping
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import date, datetime, timedelta, timezone
|
||||
from typing import TYPE_CHECKING, Any, TypeVar, cast
|
||||
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlmodel import Session, select
|
||||
|
||||
from trading_journal import models
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Mapping
|
||||
from enum import Enum
|
||||
|
||||
def _check_enum(enum_cls, value, field_name: str):
|
||||
from sqlalchemy.sql.elements import ColumnElement
|
||||
|
||||
|
||||
# Generic enum member type
|
||||
T = TypeVar("T", bound="Enum")
|
||||
|
||||
|
||||
def _check_enum(enum_cls: type[T], value: object, field_name: str) -> T:
|
||||
if value is None:
|
||||
raise ValueError(f"{field_name} is required")
|
||||
# already an enum member
|
||||
@@ -22,38 +35,56 @@ def _check_enum(enum_cls, value, field_name: str):
|
||||
raise ValueError(f"Invalid {field_name!s}: {value!r}. Allowed: {allowed}")
|
||||
|
||||
|
||||
def _allowed_columns(model: type[models.SQLModel]) -> set[str]:
|
||||
tbl = cast("models.SQLModel", model).__table__ # type: ignore[attr-defined]
|
||||
return {c.name for c in tbl.columns}
|
||||
|
||||
|
||||
AnyModel = Any
|
||||
|
||||
|
||||
def _data_to_dict(data: AnyModel) -> dict[str, AnyModel]:
|
||||
if isinstance(data, BaseModel):
|
||||
return data.model_dump(exclude_unset=True)
|
||||
if hasattr(data, "dict"):
|
||||
return data.dict(exclude_unset=True)
|
||||
return dict(data)
|
||||
|
||||
|
||||
# Trades
|
||||
def create_trade(session: Session, trade_data: Mapping) -> models.Trades:
|
||||
if hasattr(trade_data, "dict"):
|
||||
data = trade_data.dict(exclude_unset=True)
|
||||
else:
|
||||
data = dict(trade_data)
|
||||
allowed = {c.name for c in models.Trades.__table__.columns}
|
||||
def create_trade(session: Session, trade_data: Mapping[str, Any] | BaseModel) -> models.Trades:
|
||||
data = _data_to_dict(trade_data)
|
||||
allowed = _allowed_columns(models.Trades)
|
||||
payload = {k: v for k, v in data.items() if k in allowed}
|
||||
cycle_id = payload.get("cycle_id")
|
||||
if "symbol" not in payload:
|
||||
raise ValueError("symbol is required")
|
||||
if "exchange_id" not in payload and cycle_id is None:
|
||||
raise ValueError("exchange_id is required when no cycle is attached")
|
||||
# If an exchange_id is provided (and no cycle is attached), ensure the exchange exists
|
||||
# and belongs to the same user as the trade (if user_id is provided).
|
||||
if cycle_id is None and "exchange_id" in payload:
|
||||
ex = session.get(models.Exchanges, payload["exchange_id"])
|
||||
if ex is None:
|
||||
raise ValueError("exchange_id does not exist")
|
||||
user_id = payload.get("user_id")
|
||||
if user_id is not None and ex.user_id != user_id:
|
||||
raise ValueError("exchange.user_id does not match trade.user_id")
|
||||
if "underlying_currency" not in payload:
|
||||
raise ValueError("underlying_currency is required")
|
||||
payload["underlying_currency"] = _check_enum(
|
||||
models.UnderlyingCurrency, payload["underlying_currency"], "underlying_currency"
|
||||
)
|
||||
payload["underlying_currency"] = _check_enum(models.UnderlyingCurrency, payload["underlying_currency"], "underlying_currency")
|
||||
if "trade_type" not in payload:
|
||||
raise ValueError("trade_type is required")
|
||||
payload["trade_type"] = _check_enum(
|
||||
models.TradeType, payload["trade_type"], "trade_type"
|
||||
)
|
||||
payload["trade_type"] = _check_enum(models.TradeType, payload["trade_type"], "trade_type")
|
||||
if "trade_strategy" not in payload:
|
||||
raise ValueError("trade_strategy is required")
|
||||
payload["trade_strategy"] = _check_enum(
|
||||
models.TradeStrategy, payload["trade_strategy"], "trade_strategy"
|
||||
)
|
||||
payload["trade_strategy"] = _check_enum(models.TradeStrategy, payload["trade_strategy"], "trade_strategy")
|
||||
# trade_time_utc is the creation moment: always set to now (caller shouldn't provide)
|
||||
now = datetime.now(timezone.utc)
|
||||
payload.pop("trade_time_utc", None)
|
||||
payload["trade_time_utc"] = now
|
||||
if "trade_date" not in payload or payload.get("trade_date") is None:
|
||||
payload["trade_date"] = payload["trade_time_utc"].date()
|
||||
cycle_id = payload.get("cycle_id")
|
||||
user_id = payload.get("user_id")
|
||||
if "quantity" not in payload:
|
||||
raise ValueError("quantity is required")
|
||||
@@ -61,15 +92,10 @@ def create_trade(session: Session, trade_data: Mapping) -> models.Trades:
|
||||
raise ValueError("price_cents is required")
|
||||
if "commission_cents" not in payload:
|
||||
payload["commission_cents"] = 0
|
||||
quantity: int = payload["quantity"]
|
||||
price_cents: int = payload["price_cents"]
|
||||
commission_cents: int = payload["commission_cents"]
|
||||
if "gross_cash_flow_cents" not in payload:
|
||||
payload["gross_cash_flow_cents"] = -quantity * price_cents
|
||||
raise ValueError("gross_cash_flow_cents is required")
|
||||
if "net_cash_flow_cents" not in payload:
|
||||
payload["net_cash_flow_cents"] = (
|
||||
payload["gross_cash_flow_cents"] - commission_cents
|
||||
)
|
||||
raise ValueError("net_cash_flow_cents is required")
|
||||
|
||||
# If no cycle_id provided, create Cycle instance but don't call create_cycle()
|
||||
created_cycle = None
|
||||
@@ -77,9 +103,9 @@ def create_trade(session: Session, trade_data: Mapping) -> models.Trades:
|
||||
c_payload = {
|
||||
"user_id": user_id,
|
||||
"symbol": payload["symbol"],
|
||||
"exchange_id": payload["exchange_id"],
|
||||
"underlying_currency": payload["underlying_currency"],
|
||||
"friendly_name": "Auto-created Cycle by trade "
|
||||
+ payload.get("friendly_name", ""),
|
||||
"friendly_name": "Auto-created Cycle by trade " + payload.get("friendly_name", ""),
|
||||
"status": models.CycleStatus.OPEN,
|
||||
"start_date": payload["trade_date"],
|
||||
}
|
||||
@@ -90,9 +116,11 @@ def create_trade(session: Session, trade_data: Mapping) -> models.Trades:
|
||||
# If cycle_id provided, validate existence and ownership
|
||||
if cycle_id is not None:
|
||||
cycle = session.get(models.Cycles, cycle_id)
|
||||
|
||||
if cycle is None:
|
||||
raise ValueError("cycle_id does not exist")
|
||||
else:
|
||||
payload.pop("exchange_id", None) # ignore exchange_id if provided; use cycle's exchange_id
|
||||
payload["exchange_id"] = cycle.exchange_id
|
||||
if cycle.user_id != user_id:
|
||||
raise ValueError("cycle.user_id does not match trade.user_id")
|
||||
|
||||
@@ -119,9 +147,7 @@ def get_trade_by_id(session: Session, trade_id: int) -> models.Trades | None:
|
||||
return session.get(models.Trades, trade_id)
|
||||
|
||||
|
||||
def get_trade_by_user_id_and_friendly_name(
|
||||
session: Session, user_id: int, friendly_name: str
|
||||
) -> models.Trades | None:
|
||||
def get_trade_by_user_id_and_friendly_name(session: Session, user_id: int, friendly_name: str) -> models.Trades | None:
|
||||
statement = select(models.Trades).where(
|
||||
models.Trades.user_id == user_id,
|
||||
models.Trades.friendly_name == friendly_name,
|
||||
@@ -133,7 +159,22 @@ def get_trades_by_user_id(session: Session, user_id: int) -> list[models.Trades]
|
||||
statement = select(models.Trades).where(
|
||||
models.Trades.user_id == user_id,
|
||||
)
|
||||
return session.exec(statement).all()
|
||||
return list(session.exec(statement).all())
|
||||
|
||||
|
||||
def update_trade_friendly_name(session: Session, trade_id: int, friendly_name: str) -> models.Trades:
|
||||
trade: models.Trades | None = session.get(models.Trades, trade_id)
|
||||
if trade is None:
|
||||
raise ValueError("trade_id does not exist")
|
||||
trade.friendly_name = friendly_name
|
||||
session.add(trade)
|
||||
try:
|
||||
session.flush()
|
||||
except IntegrityError as e:
|
||||
session.rollback()
|
||||
raise ValueError("update_trade_friendly_name integrity error") from e
|
||||
session.refresh(trade)
|
||||
return trade
|
||||
|
||||
|
||||
def update_trade_note(session: Session, trade_id: int, note: str) -> models.Trades:
|
||||
@@ -169,36 +210,33 @@ def invalidate_trade(session: Session, trade_id: int) -> models.Trades:
|
||||
return trade
|
||||
|
||||
|
||||
def replace_trade(
|
||||
session: Session, old_trade_id: int, new_trade_data: Mapping
|
||||
) -> models.Trades:
|
||||
def replace_trade(session: Session, old_trade_id: int, new_trade_data: Mapping[str, Any] | BaseModel) -> models.Trades:
|
||||
invalidate_trade(session, old_trade_id)
|
||||
if hasattr(new_trade_data, "dict"):
|
||||
data = new_trade_data.dict(exclude_unset=True)
|
||||
else:
|
||||
data = dict(new_trade_data)
|
||||
data = _data_to_dict(new_trade_data)
|
||||
data["replaced_by_trade_id"] = old_trade_id
|
||||
new_trade = create_trade(session, data)
|
||||
return new_trade
|
||||
return create_trade(session, data)
|
||||
|
||||
|
||||
# Cycles
|
||||
def create_cycle(session: Session, cycle_data: Mapping) -> models.Cycles:
|
||||
if hasattr(cycle_data, "dict"):
|
||||
data = cycle_data.dict(exclude_unset=True)
|
||||
else:
|
||||
data = dict(cycle_data)
|
||||
allowed = {c.name for c in models.Cycles.__table__.columns}
|
||||
def create_cycle(session: Session, cycle_data: Mapping[str, Any] | BaseModel) -> models.Cycles:
|
||||
data = _data_to_dict(cycle_data)
|
||||
allowed = _allowed_columns(models.Cycles)
|
||||
payload = {k: v for k, v in data.items() if k in allowed}
|
||||
if "user_id" not in payload:
|
||||
raise ValueError("user_id is required")
|
||||
if "symbol" not in payload:
|
||||
raise ValueError("symbol is required")
|
||||
if "exchange_id" not in payload:
|
||||
raise ValueError("exchange_id is required")
|
||||
# ensure the exchange exists and belongs to the same user
|
||||
ex = session.get(models.Exchanges, payload["exchange_id"])
|
||||
if ex is None:
|
||||
raise ValueError("exchange_id does not exist")
|
||||
if ex.user_id != payload.get("user_id"):
|
||||
raise ValueError("exchange.user_id does not match cycle.user_id")
|
||||
if "underlying_currency" not in payload:
|
||||
raise ValueError("underlying_currency is required")
|
||||
payload["underlying_currency"] = _check_enum(
|
||||
models.UnderlyingCurrency, payload["underlying_currency"], "underlying_currency"
|
||||
)
|
||||
payload["underlying_currency"] = _check_enum(models.UnderlyingCurrency, payload["underlying_currency"], "underlying_currency")
|
||||
if "status" not in payload:
|
||||
raise ValueError("status is required")
|
||||
payload["status"] = _check_enum(models.CycleStatus, payload["status"], "status")
|
||||
@@ -216,30 +254,44 @@ def create_cycle(session: Session, cycle_data: Mapping) -> models.Cycles:
|
||||
return c
|
||||
|
||||
|
||||
IMMUTABLE_CYCLE_FIELDS = {"id", "user_id", "start_date", "created_at"}
|
||||
IMMUTABLE_CYCLE_FIELDS = {"id", "user_id", "start_date"}
|
||||
|
||||
|
||||
def update_cycle(
|
||||
session: Session, cycle_id: int, update_data: Mapping
|
||||
) -> models.Cycles:
|
||||
def get_cycle_by_id(session: Session, cycle_id: int) -> models.Cycles | None:
|
||||
return session.get(models.Cycles, cycle_id)
|
||||
|
||||
|
||||
def get_cycles_by_user_id(session: Session, user_id: int) -> list[models.Cycles]:
|
||||
statement = select(models.Cycles).where(
|
||||
models.Cycles.user_id == user_id,
|
||||
)
|
||||
return list(session.exec(statement).all())
|
||||
|
||||
|
||||
def update_cycle(session: Session, cycle_id: int, update_data: Mapping[str, Any] | BaseModel) -> models.Cycles:
|
||||
cycle: models.Cycles | None = session.get(models.Cycles, cycle_id)
|
||||
if cycle is None:
|
||||
raise ValueError("cycle_id does not exist")
|
||||
if hasattr(update_data, "dict"):
|
||||
data = update_data.dict(exclude_unset=True)
|
||||
else:
|
||||
data = dict(update_data)
|
||||
data = _data_to_dict(update_data)
|
||||
|
||||
allowed = {c.name for c in models.Cycles.__table__.columns}
|
||||
allowed = _allowed_columns(models.Cycles)
|
||||
for k, v in data.items():
|
||||
if k in IMMUTABLE_CYCLE_FIELDS:
|
||||
raise ValueError(f"field {k!r} is immutable")
|
||||
if k not in allowed:
|
||||
continue
|
||||
# If trying to change exchange_id, ensure the new exchange exists and belongs to
|
||||
# the same user as the cycle.
|
||||
if k == "exchange_id":
|
||||
ex = session.get(models.Exchanges, v)
|
||||
if ex is None:
|
||||
raise ValueError("exchange_id does not exist")
|
||||
if ex.user_id != cycle.user_id:
|
||||
raise ValueError("exchange.user_id does not match cycle.user_id")
|
||||
if k == "underlying_currency":
|
||||
v = _check_enum(models.UnderlyingCurrency, v, "underlying_currency")
|
||||
v = _check_enum(models.UnderlyingCurrency, v, "underlying_currency") # noqa: PLW2901
|
||||
if k == "status":
|
||||
v = _check_enum(models.CycleStatus, v, "status")
|
||||
v = _check_enum(models.CycleStatus, v, "status") # noqa: PLW2901
|
||||
setattr(cycle, k, v)
|
||||
session.add(cycle)
|
||||
try:
|
||||
@@ -251,16 +303,179 @@ def update_cycle(
|
||||
return cycle
|
||||
|
||||
|
||||
# Cycle loan and interest
|
||||
def create_cycle_loan_event(session: Session, loan_data: Mapping[str, Any] | BaseModel) -> models.CycleLoanChangeEvents:
|
||||
data = _data_to_dict(loan_data)
|
||||
allowed = _allowed_columns(models.CycleLoanChangeEvents)
|
||||
payload = {k: v for k, v in data.items() if k in allowed}
|
||||
if "cycle_id" not in payload:
|
||||
raise ValueError("cycle_id is required")
|
||||
cycle = session.get(models.Cycles, payload["cycle_id"])
|
||||
if cycle is None:
|
||||
raise ValueError("cycle_id does not exist")
|
||||
|
||||
payload["effective_date"] = payload.get("effective_date") or datetime.now(timezone.utc).date()
|
||||
payload["created_at"] = datetime.now(timezone.utc)
|
||||
cle = models.CycleLoanChangeEvents(**payload)
|
||||
session.add(cle)
|
||||
try:
|
||||
session.flush()
|
||||
except IntegrityError as e:
|
||||
session.rollback()
|
||||
raise ValueError("create_cycle_loan_event integrity error") from e
|
||||
session.refresh(cle)
|
||||
return cle
|
||||
|
||||
|
||||
def get_loan_events_by_cycle_id(session: Session, cycle_id: int) -> list[models.CycleLoanChangeEvents]:
|
||||
eff_col = cast("ColumnElement", models.CycleLoanChangeEvents.effective_date)
|
||||
id_col = cast("ColumnElement", models.CycleLoanChangeEvents.id)
|
||||
statement = (
|
||||
select(models.CycleLoanChangeEvents)
|
||||
.where(
|
||||
models.CycleLoanChangeEvents.cycle_id == cycle_id,
|
||||
)
|
||||
.order_by(eff_col, id_col.asc())
|
||||
)
|
||||
return list(session.exec(statement).all())
|
||||
|
||||
|
||||
def create_cycle_daily_accrual(session: Session, cycle_id: int, accrual_date: date, accrual_amount_cents: int) -> models.CycleDailyAccrual:
|
||||
cycle = session.get(models.Cycles, cycle_id)
|
||||
if cycle is None:
|
||||
raise ValueError("cycle_id does not exist")
|
||||
existing = session.exec(
|
||||
select(models.CycleDailyAccrual).where(
|
||||
models.CycleDailyAccrual.cycle_id == cycle_id,
|
||||
models.CycleDailyAccrual.accrual_date == accrual_date,
|
||||
),
|
||||
).first()
|
||||
if existing:
|
||||
return existing
|
||||
if accrual_amount_cents < 0:
|
||||
raise ValueError("accrual_amount_cents must be non-negative")
|
||||
row = models.CycleDailyAccrual(
|
||||
cycle_id=cycle_id,
|
||||
accrual_date=accrual_date,
|
||||
accrual_amount_cents=accrual_amount_cents,
|
||||
created_at=datetime.now(timezone.utc),
|
||||
)
|
||||
session.add(row)
|
||||
try:
|
||||
session.flush()
|
||||
except IntegrityError as e:
|
||||
session.rollback()
|
||||
raise ValueError("create_cycle_daily_accrual integrity error") from e
|
||||
session.refresh(row)
|
||||
return row
|
||||
|
||||
|
||||
def get_cycle_daily_accruals_by_cycle_id(session: Session, cycle_id: int) -> list[models.CycleDailyAccrual]:
|
||||
date_col = cast("ColumnElement", models.CycleDailyAccrual.accrual_date)
|
||||
statement = (
|
||||
select(models.CycleDailyAccrual)
|
||||
.where(
|
||||
models.CycleDailyAccrual.cycle_id == cycle_id,
|
||||
)
|
||||
.order_by(date_col.asc())
|
||||
)
|
||||
return list(session.exec(statement).all())
|
||||
|
||||
|
||||
def get_cycle_daily_accrual_by_cycle_id_and_date(session: Session, cycle_id: int, accrual_date: date) -> models.CycleDailyAccrual | None:
|
||||
statement = select(models.CycleDailyAccrual).where(
|
||||
models.CycleDailyAccrual.cycle_id == cycle_id,
|
||||
models.CycleDailyAccrual.accrual_date == accrual_date,
|
||||
)
|
||||
return session.exec(statement).first()
|
||||
|
||||
|
||||
# Exchanges
|
||||
IMMUTABLE_EXCHANGE_FIELDS = {"id"}
|
||||
|
||||
|
||||
def create_exchange(session: Session, exchange_data: Mapping[str, Any] | BaseModel) -> models.Exchanges:
|
||||
data = _data_to_dict(exchange_data)
|
||||
allowed = _allowed_columns(models.Exchanges)
|
||||
payload = {k: v for k, v in data.items() if k in allowed}
|
||||
if "name" not in payload:
|
||||
raise ValueError("name is required")
|
||||
|
||||
e = models.Exchanges(**payload)
|
||||
session.add(e)
|
||||
try:
|
||||
session.flush()
|
||||
except IntegrityError as e:
|
||||
session.rollback()
|
||||
raise ValueError("create_exchange integrity error") from e
|
||||
session.refresh(e)
|
||||
return e
|
||||
|
||||
|
||||
def get_exchange_by_id(session: Session, exchange_id: int) -> models.Exchanges | None:
|
||||
return session.get(models.Exchanges, exchange_id)
|
||||
|
||||
|
||||
def get_exchange_by_name_and_user_id(session: Session, name: str, user_id: int) -> models.Exchanges | None:
|
||||
statement = select(models.Exchanges).where(
|
||||
models.Exchanges.name == name,
|
||||
models.Exchanges.user_id == user_id,
|
||||
)
|
||||
return session.exec(statement).first()
|
||||
|
||||
|
||||
def get_all_exchanges(session: Session) -> list[models.Exchanges]:
|
||||
statement = select(models.Exchanges)
|
||||
return list(session.exec(statement).all())
|
||||
|
||||
|
||||
def get_all_exchanges_by_user_id(session: Session, user_id: int) -> list[models.Exchanges]:
|
||||
statement = select(models.Exchanges).where(
|
||||
models.Exchanges.user_id == user_id,
|
||||
)
|
||||
return list(session.exec(statement).all())
|
||||
|
||||
|
||||
def update_exchange(session: Session, exchange_id: int, update_data: Mapping[str, Any] | BaseModel) -> models.Exchanges:
|
||||
exchange: models.Exchanges | None = session.get(models.Exchanges, exchange_id)
|
||||
if exchange is None:
|
||||
raise ValueError("exchange_id does not exist")
|
||||
data = _data_to_dict(update_data)
|
||||
allowed = _allowed_columns(models.Exchanges)
|
||||
for k, v in data.items():
|
||||
if k in IMMUTABLE_EXCHANGE_FIELDS:
|
||||
raise ValueError(f"field {k!r} is immutable")
|
||||
if k in allowed:
|
||||
setattr(exchange, k, v)
|
||||
session.add(exchange)
|
||||
try:
|
||||
session.flush()
|
||||
except IntegrityError as e:
|
||||
session.rollback()
|
||||
raise ValueError("update_exchange integrity error") from e
|
||||
session.refresh(exchange)
|
||||
return exchange
|
||||
|
||||
|
||||
def delete_exchange(session: Session, exchange_id: int) -> None:
|
||||
exchange: models.Exchanges | None = session.get(models.Exchanges, exchange_id)
|
||||
if exchange is None:
|
||||
return
|
||||
session.delete(exchange)
|
||||
try:
|
||||
session.flush()
|
||||
except IntegrityError as e:
|
||||
session.rollback()
|
||||
raise ValueError("delete_exchange integrity error") from e
|
||||
|
||||
|
||||
# Users
|
||||
IMMUTABLE_USER_FIELDS = {"id", "username", "created_at"}
|
||||
|
||||
|
||||
def create_user(session: Session, user_data: Mapping) -> models.Users:
|
||||
if hasattr(user_data, "dict"):
|
||||
data = user_data.dict(exclude_unset=True)
|
||||
else:
|
||||
data = dict(user_data)
|
||||
allowed = {c.name for c in models.Users.__table__.columns}
|
||||
def create_user(session: Session, user_data: Mapping[str, Any] | BaseModel) -> models.Users:
|
||||
data = _data_to_dict(user_data)
|
||||
allowed = _allowed_columns(models.Users)
|
||||
payload = {k: v for k, v in data.items() if k in allowed}
|
||||
if "username" not in payload:
|
||||
raise ValueError("username is required")
|
||||
@@ -278,15 +493,23 @@ def create_user(session: Session, user_data: Mapping) -> models.Users:
|
||||
return u
|
||||
|
||||
|
||||
def update_user(session: Session, user_id: int, update_data: Mapping) -> models.Users:
|
||||
def get_user_by_id(session: Session, user_id: int) -> models.Users | None:
|
||||
return session.get(models.Users, user_id)
|
||||
|
||||
|
||||
def get_user_by_username(session: Session, username: str) -> models.Users | None:
|
||||
statement = select(models.Users).where(
|
||||
models.Users.username == username,
|
||||
)
|
||||
return session.exec(statement).first()
|
||||
|
||||
|
||||
def update_user(session: Session, user_id: int, update_data: Mapping[str, Any] | BaseModel) -> models.Users:
|
||||
user: models.Users | None = session.get(models.Users, user_id)
|
||||
if user is None:
|
||||
raise ValueError("user_id does not exist")
|
||||
if hasattr(update_data, "dict"):
|
||||
data = update_data.dict(exclude_unset=True)
|
||||
else:
|
||||
data = dict(update_data)
|
||||
allowed = {c.name for c in models.Users.__table__.columns}
|
||||
data = _data_to_dict(update_data)
|
||||
allowed = _allowed_columns(models.Users)
|
||||
for k, v in data.items():
|
||||
if k in IMMUTABLE_USER_FIELDS:
|
||||
raise ValueError(f"field {k!r} is immutable")
|
||||
@@ -315,10 +538,11 @@ def create_login_session(
|
||||
user: models.Users | None = session.get(models.Users, user_id)
|
||||
if user is None:
|
||||
raise ValueError("user_id does not exist")
|
||||
user_id_val = cast("int", user.id)
|
||||
now = datetime.now(timezone.utc)
|
||||
expires_at = now + timedelta(seconds=session_length_seconds)
|
||||
s = models.Sessions(
|
||||
user_id=user.id,
|
||||
user_id=user_id_val,
|
||||
session_token_hash=session_token_hash,
|
||||
created_at=now,
|
||||
expires_at=expires_at,
|
||||
@@ -337,9 +561,7 @@ def create_login_session(
|
||||
return s
|
||||
|
||||
|
||||
def get_login_session_by_token_hash_and_user_id(
|
||||
session: Session, session_token_hash: str, user_id: int
|
||||
) -> models.Sessions | None:
|
||||
def get_login_session_by_token_hash_and_user_id(session: Session, session_token_hash: str, user_id: int) -> models.Sessions | None:
|
||||
statement = select(models.Sessions).where(
|
||||
models.Sessions.session_token_hash == session_token_hash,
|
||||
models.Sessions.user_id == user_id,
|
||||
@@ -349,25 +571,29 @@ def get_login_session_by_token_hash_and_user_id(
|
||||
return session.exec(statement).first()
|
||||
|
||||
|
||||
def get_login_session_by_token_hash(session: Session, session_token_hash: str) -> models.Sessions | None:
|
||||
statement = select(models.Sessions).where(
|
||||
models.Sessions.session_token_hash == session_token_hash,
|
||||
models.Sessions.expires_at > datetime.now(timezone.utc),
|
||||
)
|
||||
|
||||
return session.exec(statement).first()
|
||||
|
||||
|
||||
IMMUTABLE_SESSION_FIELDS = {"id", "user_id", "session_token_hash", "created_at"}
|
||||
|
||||
|
||||
def update_login_session(
|
||||
session: Session, session_token_hashed: str, update_session: Mapping
|
||||
) -> models.Sessions | None:
|
||||
def update_login_session(session: Session, session_token_hashed: str, update_session: Mapping[str, Any] | BaseModel) -> models.Sessions | None:
|
||||
login_session: models.Sessions | None = session.exec(
|
||||
select(models.Sessions).where(
|
||||
models.Sessions.session_token_hash == session_token_hashed,
|
||||
models.Sessions.expires_at > datetime.now(timezone.utc),
|
||||
)
|
||||
),
|
||||
).first()
|
||||
if login_session is None:
|
||||
return None
|
||||
if hasattr(update_session, "dict"):
|
||||
data = update_session.dict(exclude_unset=True)
|
||||
else:
|
||||
data = dict(update_session)
|
||||
allowed = {c.name for c in models.Sessions.__table__.columns}
|
||||
data = _data_to_dict(update_session)
|
||||
allowed = _allowed_columns(models.Sessions)
|
||||
for k, v in data.items():
|
||||
if k in allowed and k not in IMMUTABLE_SESSION_FIELDS:
|
||||
setattr(login_session, k, v)
|
||||
@@ -385,7 +611,7 @@ def delete_login_session(session: Session, session_token_hash: str) -> None:
|
||||
login_session: models.Sessions | None = session.exec(
|
||||
select(models.Sessions).where(
|
||||
models.Sessions.session_token_hash == session_token_hash,
|
||||
)
|
||||
),
|
||||
).first()
|
||||
if login_session is None:
|
||||
return
|
||||
|
||||
@@ -1,14 +1,13 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from contextlib import contextmanager
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from sqlalchemy import event
|
||||
from sqlalchemy.pool import StaticPool
|
||||
from sqlmodel import Session, create_engine
|
||||
|
||||
from trading_journal import db_migration
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Generator
|
||||
from sqlite3 import Connection as DBAPIConnection
|
||||
@@ -24,17 +23,13 @@ class Database:
|
||||
) -> None:
|
||||
self._database_url = database_url or "sqlite:///:memory:"
|
||||
|
||||
default_connect = (
|
||||
{"check_same_thread": False, "timeout": 30}
|
||||
if self._database_url.startswith("sqlite")
|
||||
else {}
|
||||
)
|
||||
default_connect = {"check_same_thread": False, "timeout": 30} if self._database_url.startswith("sqlite") else {}
|
||||
merged_connect = {**default_connect, **(connect_args or {})}
|
||||
|
||||
if self._database_url == "sqlite:///:memory:":
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.warning(
|
||||
"Using in-memory SQLite database; all data will be lost when the application stops."
|
||||
"Using in-memory SQLite database; all data will be lost when the application stops.",
|
||||
)
|
||||
self._engine = create_engine(
|
||||
self._database_url,
|
||||
@@ -43,15 +38,11 @@ class Database:
|
||||
poolclass=StaticPool,
|
||||
)
|
||||
else:
|
||||
self._engine = create_engine(
|
||||
self._database_url, echo=echo, connect_args=merged_connect
|
||||
)
|
||||
self._engine = create_engine(self._database_url, echo=echo, connect_args=merged_connect)
|
||||
|
||||
if self._database_url.startswith("sqlite"):
|
||||
|
||||
def _enable_sqlite_pragmas(
|
||||
dbapi_conn: DBAPIConnection, _connection_record: object
|
||||
) -> None:
|
||||
def _enable_sqlite_pragmas(dbapi_conn: DBAPIConnection, _connection_record: object) -> None:
|
||||
try:
|
||||
cur = dbapi_conn.cursor()
|
||||
cur.execute("PRAGMA journal_mode=WAL;")
|
||||
@@ -66,7 +57,7 @@ class Database:
|
||||
event.listen(self._engine, "connect", _enable_sqlite_pragmas)
|
||||
|
||||
def init_db(self) -> None:
|
||||
db_migration.run_migrations(self._engine)
|
||||
pass
|
||||
|
||||
def get_session(self) -> Generator[Session, None, None]:
|
||||
session = Session(self._engine)
|
||||
@@ -79,6 +70,18 @@ class Database:
|
||||
finally:
|
||||
session.close()
|
||||
|
||||
@contextmanager
|
||||
def get_session_ctx_manager(self) -> Generator[Session, None, None]:
|
||||
session = Session(self._engine)
|
||||
try:
|
||||
yield session
|
||||
session.commit()
|
||||
except Exception:
|
||||
session.rollback()
|
||||
raise
|
||||
finally:
|
||||
session.close()
|
||||
|
||||
def dispose(self) -> None:
|
||||
self._engine.dispose()
|
||||
|
||||
|
||||
@@ -23,10 +23,13 @@ def _mig_0_1(engine: Engine) -> None:
|
||||
SQLModel.metadata.create_all(
|
||||
bind=engine,
|
||||
tables=[
|
||||
models_v1.Trades.__table__,
|
||||
models_v1.Cycles.__table__,
|
||||
models_v1.Users.__table__,
|
||||
models_v1.Sessions.__table__,
|
||||
models_v1.Trades.__table__, # type: ignore[attr-defined]
|
||||
models_v1.Cycles.__table__, # type: ignore[attr-defined]
|
||||
models_v1.Users.__table__, # type: ignore[attr-defined]
|
||||
models_v1.Sessions.__table__, # type: ignore[attr-defined]
|
||||
models_v1.Exchanges.__table__, # type: ignore[attr-defined]
|
||||
models_v1.CycleLoanChangeEvents.__table__, # type: ignore[attr-defined]
|
||||
models_v1.CycleDailyAccrual.__table__, # type: ignore[attr-defined]
|
||||
],
|
||||
)
|
||||
|
||||
@@ -60,7 +63,7 @@ def run_migrations(engine: Engine, target_version: int | None = None) -> int:
|
||||
fn = MIGRATIONS.get(cur_version)
|
||||
if fn is None:
|
||||
raise RuntimeError(
|
||||
f"No migration from {cur_version} -> {cur_version + 1}"
|
||||
f"No migration from {cur_version} -> {cur_version + 1}",
|
||||
)
|
||||
# call migration with Engine (fn should use transactions)
|
||||
fn(engine)
|
||||
|
||||
136
backend/trading_journal/dto.py
Normal file
136
backend/trading_journal/dto.py
Normal file
@@ -0,0 +1,136 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import date, datetime # noqa: TC003
|
||||
|
||||
from pydantic import BaseModel
|
||||
from sqlmodel import SQLModel
|
||||
|
||||
from trading_journal.models import TradeStrategy, TradeType, UnderlyingCurrency # noqa: TC001
|
||||
|
||||
|
||||
class UserBase(SQLModel):
|
||||
username: str
|
||||
is_active: bool = True
|
||||
|
||||
|
||||
class UserCreate(UserBase):
|
||||
password: str
|
||||
|
||||
|
||||
class UserLogin(BaseModel):
|
||||
username: str
|
||||
password: str
|
||||
|
||||
|
||||
class UserRead(UserBase):
|
||||
id: int
|
||||
|
||||
|
||||
class SessionsBase(SQLModel):
|
||||
user_id: int
|
||||
|
||||
|
||||
class SessionRead(SessionsBase):
|
||||
id: int
|
||||
expires_at: datetime
|
||||
last_seen_at: datetime | None
|
||||
last_used_ip: str | None
|
||||
user_agent: str | None
|
||||
|
||||
|
||||
class SessionsCreate(SessionsBase):
|
||||
expires_at: datetime
|
||||
|
||||
|
||||
class SessionsUpdate(SQLModel):
|
||||
expires_at: datetime | None = None
|
||||
last_seen_at: datetime | None = None
|
||||
last_used_ip: str | None = None
|
||||
user_agent: str | None = None
|
||||
|
||||
|
||||
class ExchangesBase(SQLModel):
|
||||
name: str
|
||||
notes: str | None = None
|
||||
|
||||
|
||||
class ExchangesCreate(ExchangesBase):
|
||||
user_id: int
|
||||
|
||||
|
||||
class ExchangesRead(ExchangesBase):
|
||||
id: int
|
||||
|
||||
|
||||
class CycleBase(SQLModel):
|
||||
friendly_name: str | None = None
|
||||
status: str
|
||||
end_date: date | None = None
|
||||
funding_source: str | None = None
|
||||
capital_exposure_cents: int | None = None
|
||||
loan_amount_cents: int | None = None
|
||||
loan_interest_rate_tenth_bps: int | None = None
|
||||
trades: list[TradeRead] | None = None
|
||||
exchange: ExchangesRead | None = None
|
||||
|
||||
|
||||
class CycleCreate(CycleBase):
|
||||
user_id: int
|
||||
symbol: str
|
||||
exchange_id: int
|
||||
underlying_currency: UnderlyingCurrency
|
||||
start_date: date
|
||||
|
||||
|
||||
class CycleUpdate(CycleBase):
|
||||
id: int
|
||||
|
||||
|
||||
class CycleRead(CycleCreate):
|
||||
id: int
|
||||
|
||||
|
||||
class TradeBase(SQLModel):
|
||||
friendly_name: str | None = None
|
||||
symbol: str
|
||||
exchange_id: int
|
||||
underlying_currency: UnderlyingCurrency
|
||||
trade_type: TradeType
|
||||
trade_strategy: TradeStrategy
|
||||
trade_date: date
|
||||
quantity: int
|
||||
price_cents: int
|
||||
commission_cents: int
|
||||
notes: str | None = None
|
||||
cycle_id: int | None = None
|
||||
|
||||
|
||||
class TradeCreate(TradeBase):
|
||||
user_id: int | None = None
|
||||
trade_time_utc: datetime | None = None
|
||||
gross_cash_flow_cents: int | None = None
|
||||
net_cash_flow_cents: int | None = None
|
||||
quantity_multiplier: int = 1
|
||||
expiry_date: date | None = None
|
||||
strike_price_cents: int | None = None
|
||||
is_invalidated: bool = False
|
||||
invalidated_at: datetime | None = None
|
||||
replaced_by_trade_id: int | None = None
|
||||
|
||||
|
||||
class TradeNoteUpdate(BaseModel):
|
||||
id: int
|
||||
notes: str | None = None
|
||||
|
||||
|
||||
class TradeFriendlyNameUpdate(BaseModel):
|
||||
id: int
|
||||
friendly_name: str
|
||||
|
||||
|
||||
class TradeRead(TradeCreate):
|
||||
id: int
|
||||
|
||||
|
||||
SessionsCreate.model_rebuild()
|
||||
CycleBase.model_rebuild()
|
||||
@@ -1,11 +1,13 @@
|
||||
from datetime import date, datetime # noqa: TC003
|
||||
from datetime import date, datetime
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
from sqlmodel import (
|
||||
Column,
|
||||
Date,
|
||||
DateTime,
|
||||
Field,
|
||||
ForeignKey,
|
||||
Integer,
|
||||
Relationship,
|
||||
SQLModel,
|
||||
@@ -16,8 +18,10 @@ from sqlmodel import (
|
||||
|
||||
class TradeType(str, Enum):
|
||||
SELL_PUT = "SELL_PUT"
|
||||
CLOSE_SELL_PUT = "CLOSE_SELL_PUT"
|
||||
ASSIGNMENT = "ASSIGNMENT"
|
||||
SELL_CALL = "SELL_CALL"
|
||||
CLOSE_SELL_CALL = "CLOSE_SELL_CALL"
|
||||
EXERCISE_CALL = "EXERCISE_CALL"
|
||||
LONG_SPOT = "LONG_SPOT"
|
||||
CLOSE_LONG_SPOT = "CLOSE_LONG_SPOT"
|
||||
@@ -64,102 +68,132 @@ class FundingSource(str, Enum):
|
||||
|
||||
|
||||
class Trades(SQLModel, table=True):
|
||||
__tablename__ = "trades"
|
||||
__table_args__ = (
|
||||
UniqueConstraint(
|
||||
"user_id", "friendly_name", name="uq_trades_user_friendly_name"
|
||||
),
|
||||
)
|
||||
__tablename__ = "trades" # type: ignore[attr-defined]
|
||||
__table_args__ = (UniqueConstraint("user_id", "friendly_name", name="uq_trades_user_friendly_name"),)
|
||||
|
||||
id: int | None = Field(default=None, primary_key=True)
|
||||
user_id: int = Field(foreign_key="users.id", nullable=False, index=True)
|
||||
# allow null while user may omit friendly_name; uniqueness enforced per-user by constraint
|
||||
friendly_name: str | None = Field(
|
||||
default=None, sa_column=Column(Text, nullable=True)
|
||||
)
|
||||
friendly_name: str | None = Field(default=None, sa_column=Column(Text, nullable=True))
|
||||
symbol: str = Field(sa_column=Column(Text, nullable=False))
|
||||
underlying_currency: UnderlyingCurrency = Field(
|
||||
sa_column=Column(Text, nullable=False)
|
||||
)
|
||||
exchange_id: int = Field(foreign_key="exchanges.id", nullable=False, index=True)
|
||||
exchange: "Exchanges" = Relationship(back_populates="trades")
|
||||
underlying_currency: UnderlyingCurrency = Field(sa_column=Column(Text, nullable=False))
|
||||
trade_type: TradeType = Field(sa_column=Column(Text, nullable=False))
|
||||
trade_strategy: TradeStrategy = Field(sa_column=Column(Text, nullable=False))
|
||||
trade_date: date = Field(sa_column=Column(Date, nullable=False))
|
||||
trade_time_utc: datetime = Field(
|
||||
sa_column=Column(DateTime(timezone=True), nullable=False)
|
||||
)
|
||||
trade_time_utc: datetime = Field(sa_column=Column(DateTime(timezone=True), nullable=False))
|
||||
expiry_date: date | None = Field(default=None, nullable=True)
|
||||
strike_price_cents: int | None = Field(default=None, nullable=True)
|
||||
quantity: int = Field(sa_column=Column(Integer, nullable=False))
|
||||
quantity_multiplier: int = Field(sa_column=Column(Integer, nullable=False), default=1)
|
||||
price_cents: int = Field(sa_column=Column(Integer, nullable=False))
|
||||
gross_cash_flow_cents: int = Field(sa_column=Column(Integer, nullable=False))
|
||||
commission_cents: int = Field(sa_column=Column(Integer, nullable=False))
|
||||
net_cash_flow_cents: int = Field(sa_column=Column(Integer, nullable=False))
|
||||
is_invalidated: bool = Field(default=False, nullable=False)
|
||||
invalidated_at: datetime | None = Field(
|
||||
default=None, sa_column=Column(DateTime(timezone=True), nullable=True)
|
||||
)
|
||||
replaced_by_trade_id: int | None = Field(
|
||||
default=None, foreign_key="trades.id", nullable=True
|
||||
)
|
||||
invalidated_at: datetime | None = Field(default=None, sa_column=Column(DateTime(timezone=True), nullable=True))
|
||||
replaced_by_trade_id: int | None = Field(default=None, foreign_key="trades.id", nullable=True)
|
||||
notes: str | None = Field(default=None, sa_column=Column(Text, nullable=True))
|
||||
cycle_id: int | None = Field(
|
||||
default=None, foreign_key="cycles.id", nullable=True, index=True
|
||||
)
|
||||
cycle_id: int | None = Field(default=None, foreign_key="cycles.id", nullable=True, index=True)
|
||||
|
||||
cycle: "Cycles" = Relationship(back_populates="trades")
|
||||
|
||||
related_loan_change_event: Optional["CycleLoanChangeEvents"] = Relationship(
|
||||
back_populates="trade",
|
||||
sa_relationship_kwargs={"uselist": False},
|
||||
)
|
||||
|
||||
|
||||
class Cycles(SQLModel, table=True):
|
||||
__tablename__ = "cycles"
|
||||
__table_args__ = (
|
||||
UniqueConstraint(
|
||||
"user_id", "friendly_name", name="uq_cycles_user_friendly_name"
|
||||
),
|
||||
)
|
||||
__tablename__ = "cycles" # type: ignore[attr-defined]
|
||||
__table_args__ = (UniqueConstraint("user_id", "friendly_name", name="uq_cycles_user_friendly_name"),)
|
||||
|
||||
id: int | None = Field(default=None, primary_key=True)
|
||||
user_id: int = Field(foreign_key="users.id", nullable=False, index=True)
|
||||
friendly_name: str | None = Field(
|
||||
default=None, sa_column=Column(Text, nullable=True)
|
||||
)
|
||||
friendly_name: str | None = Field(default=None, sa_column=Column(Text, nullable=True))
|
||||
symbol: str = Field(sa_column=Column(Text, nullable=False))
|
||||
underlying_currency: UnderlyingCurrency = Field(
|
||||
sa_column=Column(Text, nullable=False)
|
||||
)
|
||||
exchange_id: int = Field(foreign_key="exchanges.id", nullable=False, index=True)
|
||||
exchange: "Exchanges" = Relationship(back_populates="cycles")
|
||||
underlying_currency: UnderlyingCurrency = Field(sa_column=Column(Text, nullable=False))
|
||||
status: CycleStatus = Field(sa_column=Column(Text, nullable=False))
|
||||
funding_source: FundingSource = Field(sa_column=Column(Text, nullable=True))
|
||||
capital_exposure_cents: int | None = Field(default=None, nullable=True)
|
||||
loan_amount_cents: int | None = Field(default=None, nullable=True)
|
||||
loan_interest_rate_bps: int | None = Field(default=None, nullable=True)
|
||||
start_date: date = Field(sa_column=Column(Date, nullable=False))
|
||||
end_date: date | None = Field(default=None, sa_column=Column(Date, nullable=True))
|
||||
|
||||
trades: list["Trades"] = Relationship(back_populates="cycle")
|
||||
|
||||
loan_amount_cents: int | None = Field(default=None, nullable=True)
|
||||
loan_interest_rate_tenth_bps: int | None = Field(default=None, nullable=True)
|
||||
|
||||
latest_interest_accrued_date: date | None = Field(default=None, sa_column=Column(Date, nullable=True))
|
||||
total_accrued_amount_cents: int = Field(default=0, sa_column=Column(Integer, nullable=False))
|
||||
|
||||
loan_change_events: list["CycleLoanChangeEvents"] = Relationship(back_populates="cycle")
|
||||
daily_accruals: list["CycleDailyAccrual"] = Relationship(back_populates="cycle")
|
||||
|
||||
|
||||
class CycleLoanChangeEvents(SQLModel, table=True):
|
||||
__tablename__ = "cycle_loan_change_events" # type: ignore[attr-defined]
|
||||
id: int | None = Field(default=None, primary_key=True)
|
||||
cycle_id: int = Field(sa_column=Column(Integer, ForeignKey("cycles.id", ondelete="CASCADE"), nullable=False, index=True))
|
||||
effective_date: date = Field(sa_column=Column(Date, nullable=False))
|
||||
loan_amount_cents: int | None = Field(default=None, sa_column=Column(Integer, nullable=True))
|
||||
loan_interest_rate_tenth_bps: int | None = Field(default=None, sa_column=Column(Integer, nullable=True))
|
||||
related_trade_id: int | None = Field(default=None, sa_column=Column(Integer, ForeignKey("trades.id"), nullable=True, unique=True))
|
||||
notes: str | None = Field(default=None, sa_column=Column(Text, nullable=True))
|
||||
created_at: datetime = Field(sa_column=Column(DateTime(timezone=True), nullable=False))
|
||||
|
||||
cycle: "Cycles" = Relationship(back_populates="loan_change_events")
|
||||
trade: Optional["Trades"] = Relationship(back_populates="related_loan_change_event")
|
||||
|
||||
|
||||
class CycleDailyAccrual(SQLModel, table=True):
|
||||
__tablename__ = "cycle_daily_accrual" # type: ignore[attr-defined]
|
||||
__table_args__ = (UniqueConstraint("cycle_id", "accrual_date", name="uq_cycle_daily_accruals_cycle_date"),)
|
||||
|
||||
id: int | None = Field(default=None, primary_key=True)
|
||||
cycle_id: int = Field(sa_column=Column(Integer, ForeignKey("cycles.id", ondelete="CASCADE"), nullable=False, index=True))
|
||||
accrual_date: date = Field(sa_column=Column(Date, nullable=False))
|
||||
accrual_amount_cents: int = Field(sa_column=Column(Integer, nullable=False))
|
||||
created_at: datetime = Field(sa_column=Column(DateTime(timezone=True), nullable=False))
|
||||
|
||||
cycle: "Cycles" = Relationship(back_populates="daily_accruals")
|
||||
|
||||
|
||||
class Exchanges(SQLModel, table=True):
|
||||
__tablename__ = "exchanges" # type: ignore[attr-defined]
|
||||
__table_args__ = (UniqueConstraint("user_id", "name", name="uq_exchanges_user_name"),)
|
||||
id: int | None = Field(default=None, primary_key=True)
|
||||
user_id: int = Field(foreign_key="users.id", nullable=False, index=True)
|
||||
name: str = Field(sa_column=Column(Text, nullable=False))
|
||||
notes: str | None = Field(default=None, sa_column=Column(Text, nullable=True))
|
||||
trades: list["Trades"] = Relationship(back_populates="exchange")
|
||||
cycles: list["Cycles"] = Relationship(back_populates="exchange")
|
||||
user: "Users" = Relationship(back_populates="exchanges")
|
||||
|
||||
|
||||
class Users(SQLModel, table=True):
|
||||
__tablename__ = "users"
|
||||
__tablename__ = "users" # type: ignore[attr-defined]
|
||||
id: int | None = Field(default=None, primary_key=True)
|
||||
# unique=True already creates an index; no need to also set index=True
|
||||
username: str = Field(sa_column=Column(Text, nullable=False, unique=True))
|
||||
password_hash: str = Field(sa_column=Column(Text, nullable=False))
|
||||
is_active: bool = Field(default=True, nullable=False)
|
||||
sessions: list["Sessions"] = Relationship(back_populates="user")
|
||||
exchanges: list["Exchanges"] = Relationship(back_populates="user")
|
||||
|
||||
|
||||
class Sessions(SQLModel, table=True):
|
||||
__tablename__ = "sessions"
|
||||
__tablename__ = "sessions" # type: ignore[attr-defined]
|
||||
id: int | None = Field(default=None, primary_key=True)
|
||||
user_id: int = Field(foreign_key="users.id", nullable=False, index=True)
|
||||
session_token_hash: str = Field(sa_column=Column(Text, nullable=False, unique=True))
|
||||
created_at: datetime = Field(
|
||||
sa_column=Column(DateTime(timezone=True), nullable=False)
|
||||
)
|
||||
expires_at: datetime = Field(
|
||||
sa_column=Column(DateTime(timezone=True), nullable=False, index=True)
|
||||
)
|
||||
last_seen_at: datetime | None = Field(
|
||||
sa_column=Column(DateTime(timezone=True), nullable=True)
|
||||
)
|
||||
last_used_ip: str | None = Field(
|
||||
default=None, sa_column=Column(Text, nullable=True)
|
||||
)
|
||||
created_at: datetime = Field(sa_column=Column(DateTime(timezone=True), nullable=False))
|
||||
expires_at: datetime = Field(sa_column=Column(DateTime(timezone=True), nullable=False, index=True))
|
||||
last_seen_at: datetime | None = Field(sa_column=Column(DateTime(timezone=True), nullable=True))
|
||||
last_used_ip: str | None = Field(default=None, sa_column=Column(Text, nullable=True))
|
||||
user_agent: str | None = Field(default=None, sa_column=Column(Text, nullable=True))
|
||||
device_name: str | None = Field(default=None, sa_column=Column(Text, nullable=True))
|
||||
user: "Users" = Relationship(back_populates="sessions")
|
||||
|
||||
@@ -1,11 +1,13 @@
|
||||
from datetime import date, datetime # noqa: TC003
|
||||
from datetime import date, datetime
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
from sqlmodel import (
|
||||
Column,
|
||||
Date,
|
||||
DateTime,
|
||||
Field,
|
||||
ForeignKey,
|
||||
Integer,
|
||||
Relationship,
|
||||
SQLModel,
|
||||
@@ -16,8 +18,10 @@ from sqlmodel import (
|
||||
|
||||
class TradeType(str, Enum):
|
||||
SELL_PUT = "SELL_PUT"
|
||||
CLOSE_SELL_PUT = "CLOSE_SELL_PUT"
|
||||
ASSIGNMENT = "ASSIGNMENT"
|
||||
SELL_CALL = "SELL_CALL"
|
||||
CLOSE_SELL_CALL = "CLOSE_SELL_CALL"
|
||||
EXERCISE_CALL = "EXERCISE_CALL"
|
||||
LONG_SPOT = "LONG_SPOT"
|
||||
CLOSE_LONG_SPOT = "CLOSE_LONG_SPOT"
|
||||
@@ -64,102 +68,132 @@ class FundingSource(str, Enum):
|
||||
|
||||
|
||||
class Trades(SQLModel, table=True):
|
||||
__tablename__ = "trades"
|
||||
__table_args__ = (
|
||||
UniqueConstraint(
|
||||
"user_id", "friendly_name", name="uq_trades_user_friendly_name"
|
||||
),
|
||||
)
|
||||
__tablename__ = "trades" # type: ignore[attr-defined]
|
||||
__table_args__ = (UniqueConstraint("user_id", "friendly_name", name="uq_trades_user_friendly_name"),)
|
||||
|
||||
id: int | None = Field(default=None, primary_key=True)
|
||||
user_id: int = Field(foreign_key="users.id", nullable=False, index=True)
|
||||
# allow null while user may omit friendly_name; uniqueness enforced per-user by constraint
|
||||
friendly_name: str | None = Field(
|
||||
default=None, sa_column=Column(Text, nullable=True)
|
||||
)
|
||||
friendly_name: str | None = Field(default=None, sa_column=Column(Text, nullable=True))
|
||||
symbol: str = Field(sa_column=Column(Text, nullable=False))
|
||||
underlying_currency: UnderlyingCurrency = Field(
|
||||
sa_column=Column(Text, nullable=False)
|
||||
)
|
||||
exchange_id: int = Field(foreign_key="exchanges.id", nullable=False, index=True)
|
||||
exchange: "Exchanges" = Relationship(back_populates="trades")
|
||||
underlying_currency: UnderlyingCurrency = Field(sa_column=Column(Text, nullable=False))
|
||||
trade_type: TradeType = Field(sa_column=Column(Text, nullable=False))
|
||||
trade_strategy: TradeStrategy = Field(sa_column=Column(Text, nullable=False))
|
||||
trade_date: date = Field(sa_column=Column(Date, nullable=False))
|
||||
trade_time_utc: datetime = Field(
|
||||
sa_column=Column(DateTime(timezone=True), nullable=False)
|
||||
)
|
||||
trade_time_utc: datetime = Field(sa_column=Column(DateTime(timezone=True), nullable=False))
|
||||
expiry_date: date | None = Field(default=None, nullable=True)
|
||||
strike_price_cents: int | None = Field(default=None, nullable=True)
|
||||
quantity: int = Field(sa_column=Column(Integer, nullable=False))
|
||||
quantity_multiplier: int = Field(sa_column=Column(Integer, nullable=False), default=1)
|
||||
price_cents: int = Field(sa_column=Column(Integer, nullable=False))
|
||||
gross_cash_flow_cents: int = Field(sa_column=Column(Integer, nullable=False))
|
||||
commission_cents: int = Field(sa_column=Column(Integer, nullable=False))
|
||||
net_cash_flow_cents: int = Field(sa_column=Column(Integer, nullable=False))
|
||||
is_invalidated: bool = Field(default=False, nullable=False)
|
||||
invalidated_at: datetime | None = Field(
|
||||
default=None, sa_column=Column(DateTime(timezone=True), nullable=True)
|
||||
)
|
||||
replaced_by_trade_id: int | None = Field(
|
||||
default=None, foreign_key="trades.id", nullable=True
|
||||
)
|
||||
invalidated_at: datetime | None = Field(default=None, sa_column=Column(DateTime(timezone=True), nullable=True))
|
||||
replaced_by_trade_id: int | None = Field(default=None, foreign_key="trades.id", nullable=True)
|
||||
notes: str | None = Field(default=None, sa_column=Column(Text, nullable=True))
|
||||
cycle_id: int | None = Field(
|
||||
default=None, foreign_key="cycles.id", nullable=True, index=True
|
||||
)
|
||||
cycle_id: int | None = Field(default=None, foreign_key="cycles.id", nullable=True, index=True)
|
||||
|
||||
cycle: "Cycles" = Relationship(back_populates="trades")
|
||||
|
||||
related_loan_change_event: Optional["CycleLoanChangeEvents"] = Relationship(
|
||||
back_populates="trade",
|
||||
sa_relationship_kwargs={"uselist": False},
|
||||
)
|
||||
|
||||
|
||||
class Cycles(SQLModel, table=True):
|
||||
__tablename__ = "cycles"
|
||||
__table_args__ = (
|
||||
UniqueConstraint(
|
||||
"user_id", "friendly_name", name="uq_cycles_user_friendly_name"
|
||||
),
|
||||
)
|
||||
__tablename__ = "cycles" # type: ignore[attr-defined]
|
||||
__table_args__ = (UniqueConstraint("user_id", "friendly_name", name="uq_cycles_user_friendly_name"),)
|
||||
|
||||
id: int | None = Field(default=None, primary_key=True)
|
||||
user_id: int = Field(foreign_key="users.id", nullable=False, index=True)
|
||||
friendly_name: str | None = Field(
|
||||
default=None, sa_column=Column(Text, nullable=True)
|
||||
)
|
||||
friendly_name: str | None = Field(default=None, sa_column=Column(Text, nullable=True))
|
||||
symbol: str = Field(sa_column=Column(Text, nullable=False))
|
||||
underlying_currency: UnderlyingCurrency = Field(
|
||||
sa_column=Column(Text, nullable=False)
|
||||
)
|
||||
exchange_id: int = Field(foreign_key="exchanges.id", nullable=False, index=True)
|
||||
exchange: "Exchanges" = Relationship(back_populates="cycles")
|
||||
underlying_currency: UnderlyingCurrency = Field(sa_column=Column(Text, nullable=False))
|
||||
status: CycleStatus = Field(sa_column=Column(Text, nullable=False))
|
||||
funding_source: FundingSource = Field(sa_column=Column(Text, nullable=True))
|
||||
capital_exposure_cents: int | None = Field(default=None, nullable=True)
|
||||
loan_amount_cents: int | None = Field(default=None, nullable=True)
|
||||
loan_interest_rate_bps: int | None = Field(default=None, nullable=True)
|
||||
start_date: date = Field(sa_column=Column(Date, nullable=False))
|
||||
end_date: date | None = Field(default=None, sa_column=Column(Date, nullable=True))
|
||||
|
||||
trades: list["Trades"] = Relationship(back_populates="cycle")
|
||||
|
||||
loan_amount_cents: int | None = Field(default=None, nullable=True)
|
||||
loan_interest_rate_tenth_bps: int | None = Field(default=None, nullable=True)
|
||||
|
||||
latest_interest_accrued_date: date | None = Field(default=None, sa_column=Column(Date, nullable=True))
|
||||
total_accrued_amount_cents: int = Field(default=0, sa_column=Column(Integer, nullable=False))
|
||||
|
||||
loan_change_events: list["CycleLoanChangeEvents"] = Relationship(back_populates="cycle")
|
||||
daily_accruals: list["CycleDailyAccrual"] = Relationship(back_populates="cycle")
|
||||
|
||||
|
||||
class CycleLoanChangeEvents(SQLModel, table=True):
|
||||
__tablename__ = "cycle_loan_change_events" # type: ignore[attr-defined]
|
||||
id: int | None = Field(default=None, primary_key=True)
|
||||
cycle_id: int = Field(sa_column=Column(Integer, ForeignKey("cycles.id", ondelete="CASCADE"), nullable=False, index=True))
|
||||
effective_date: date = Field(sa_column=Column(Date, nullable=False))
|
||||
loan_amount_cents: int | None = Field(default=None, sa_column=Column(Integer, nullable=True))
|
||||
loan_interest_rate_tenth_bps: int | None = Field(default=None, sa_column=Column(Integer, nullable=True))
|
||||
related_trade_id: int | None = Field(default=None, sa_column=Column(Integer, ForeignKey("trades.id"), nullable=True, unique=True))
|
||||
notes: str | None = Field(default=None, sa_column=Column(Text, nullable=True))
|
||||
created_at: datetime = Field(sa_column=Column(DateTime(timezone=True), nullable=False))
|
||||
|
||||
cycle: "Cycles" = Relationship(back_populates="loan_change_events")
|
||||
trade: Optional["Trades"] = Relationship(back_populates="related_loan_change_event")
|
||||
|
||||
|
||||
class CycleDailyAccrual(SQLModel, table=True):
|
||||
__tablename__ = "cycle_daily_accrual" # type: ignore[attr-defined]
|
||||
__table_args__ = (UniqueConstraint("cycle_id", "accrual_date", name="uq_cycle_daily_accruals_cycle_date"),)
|
||||
|
||||
id: int | None = Field(default=None, primary_key=True)
|
||||
cycle_id: int = Field(sa_column=Column(Integer, ForeignKey("cycles.id", ondelete="CASCADE"), nullable=False, index=True))
|
||||
accrual_date: date = Field(sa_column=Column(Date, nullable=False))
|
||||
accrual_amount_cents: int = Field(sa_column=Column(Integer, nullable=False))
|
||||
created_at: datetime = Field(sa_column=Column(DateTime(timezone=True), nullable=False))
|
||||
|
||||
cycle: "Cycles" = Relationship(back_populates="daily_accruals")
|
||||
|
||||
|
||||
class Exchanges(SQLModel, table=True):
|
||||
__tablename__ = "exchanges" # type: ignore[attr-defined]
|
||||
__table_args__ = (UniqueConstraint("user_id", "name", name="uq_exchanges_user_name"),)
|
||||
id: int | None = Field(default=None, primary_key=True)
|
||||
user_id: int = Field(foreign_key="users.id", nullable=False, index=True)
|
||||
name: str = Field(sa_column=Column(Text, nullable=False))
|
||||
notes: str | None = Field(default=None, sa_column=Column(Text, nullable=True))
|
||||
trades: list["Trades"] = Relationship(back_populates="exchange")
|
||||
cycles: list["Cycles"] = Relationship(back_populates="exchange")
|
||||
user: "Users" = Relationship(back_populates="exchanges")
|
||||
|
||||
|
||||
class Users(SQLModel, table=True):
|
||||
__tablename__ = "users"
|
||||
__tablename__ = "users" # type: ignore[attr-defined]
|
||||
id: int | None = Field(default=None, primary_key=True)
|
||||
# unique=True already creates an index; no need to also set index=True
|
||||
username: str = Field(sa_column=Column(Text, nullable=False, unique=True))
|
||||
password_hash: str = Field(sa_column=Column(Text, nullable=False))
|
||||
is_active: bool = Field(default=True, nullable=False)
|
||||
sessions: list["Sessions"] = Relationship(back_populates="user")
|
||||
exchanges: list["Exchanges"] = Relationship(back_populates="user")
|
||||
|
||||
|
||||
class Sessions(SQLModel, table=True):
|
||||
__tablename__ = "sessions"
|
||||
__tablename__ = "sessions" # type: ignore[attr-defined]
|
||||
id: int | None = Field(default=None, primary_key=True)
|
||||
user_id: int = Field(foreign_key="users.id", nullable=False, index=True)
|
||||
session_token_hash: str = Field(sa_column=Column(Text, nullable=False, unique=True))
|
||||
created_at: datetime = Field(
|
||||
sa_column=Column(DateTime(timezone=True), nullable=False)
|
||||
)
|
||||
expires_at: datetime = Field(
|
||||
sa_column=Column(DateTime(timezone=True), nullable=False, index=True)
|
||||
)
|
||||
last_seen_at: datetime | None = Field(
|
||||
sa_column=Column(DateTime(timezone=True), nullable=True)
|
||||
)
|
||||
last_used_ip: str | None = Field(
|
||||
default=None, sa_column=Column(Text, nullable=True)
|
||||
)
|
||||
created_at: datetime = Field(sa_column=Column(DateTime(timezone=True), nullable=False))
|
||||
expires_at: datetime = Field(sa_column=Column(DateTime(timezone=True), nullable=False, index=True))
|
||||
last_seen_at: datetime | None = Field(sa_column=Column(DateTime(timezone=True), nullable=True))
|
||||
last_used_ip: str | None = Field(default=None, sa_column=Column(Text, nullable=True))
|
||||
user_agent: str | None = Field(default=None, sa_column=Column(Text, nullable=True))
|
||||
device_name: str | None = Field(default=None, sa_column=Column(Text, nullable=True))
|
||||
user: "Users" = Relationship(back_populates="sessions")
|
||||
|
||||
51
backend/trading_journal/security.py
Normal file
51
backend/trading_journal/security.py
Normal file
@@ -0,0 +1,51 @@
|
||||
import hashlib
|
||||
import hmac
|
||||
import secrets
|
||||
|
||||
from argon2 import PasswordHasher
|
||||
from argon2.exceptions import VerifyMismatchError
|
||||
|
||||
import settings
|
||||
|
||||
ph = PasswordHasher()
|
||||
|
||||
# Utility functions for password hashing and verification
|
||||
|
||||
|
||||
def hash_password(plain: str) -> str:
|
||||
return ph.hash(plain)
|
||||
|
||||
|
||||
def verify_password(plain: str, hashed: str) -> bool:
|
||||
try:
|
||||
return ph.verify(hashed, plain)
|
||||
except VerifyMismatchError:
|
||||
return False
|
||||
|
||||
|
||||
# Session token hash
|
||||
|
||||
|
||||
def generate_session_token(nbytes: int = 32) -> str:
|
||||
return secrets.token_urlsafe(nbytes)
|
||||
|
||||
|
||||
def hash_session_token_sha256(token: str) -> str:
|
||||
return hashlib.sha256(token.encode("utf-8")).hexdigest()
|
||||
|
||||
|
||||
def sign_token_hmac(token: str) -> str:
|
||||
if not settings.settings.hmac_key:
|
||||
return token
|
||||
return hmac.new(settings.settings.hmac_key.encode("utf-8"), token.encode("utf-8"), hashlib.sha256).hexdigest()
|
||||
|
||||
|
||||
def verify_token_sha256(token: str, expected_hash: str) -> bool:
|
||||
return hmac.compare_digest(hash_session_token_sha256(token), expected_hash)
|
||||
|
||||
|
||||
def verify_token_hmac(token: str, expected_hmac: str) -> bool:
|
||||
if not settings.settings.hmac_key:
|
||||
return verify_token_sha256(token, expected_hmac)
|
||||
sig = hmac.new(settings.settings.hmac_key.encode("utf-8"), token.encode("utf-8"), hashlib.sha256).hexdigest()
|
||||
return hmac.compare_digest(sig, expected_hmac)
|
||||
364
backend/trading_journal/service.py
Normal file
364
backend/trading_journal/service.py
Normal file
@@ -0,0 +1,364 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import TYPE_CHECKING, cast
|
||||
|
||||
from fastapi import Request, Response, status
|
||||
from fastapi.responses import JSONResponse
|
||||
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
|
||||
|
||||
import settings
|
||||
from trading_journal import crud, security
|
||||
from trading_journal.dto import (
|
||||
CycleBase,
|
||||
CycleCreate,
|
||||
CycleRead,
|
||||
CycleUpdate,
|
||||
ExchangesBase,
|
||||
ExchangesCreate,
|
||||
ExchangesRead,
|
||||
SessionsCreate,
|
||||
SessionsUpdate,
|
||||
TradeCreate,
|
||||
TradeRead,
|
||||
UserCreate,
|
||||
UserLogin,
|
||||
UserRead,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sqlmodel import Session
|
||||
|
||||
from trading_journal.db import Database
|
||||
from trading_journal.models import Sessions
|
||||
|
||||
|
||||
EXCEPT_PATHS = [
|
||||
f"{settings.settings.api_base}/status",
|
||||
f"{settings.settings.api_base}/register",
|
||||
f"{settings.settings.api_base}/login",
|
||||
]
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AuthMiddleWare(BaseHTTPMiddleware):
|
||||
async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response: # noqa: PLR0911
|
||||
if request.url.path in EXCEPT_PATHS:
|
||||
return await call_next(request)
|
||||
|
||||
token = request.cookies.get("session_token")
|
||||
if not token:
|
||||
auth_header = request.headers.get("Authorization")
|
||||
if auth_header and auth_header.startswith("Bearer "):
|
||||
token = auth_header[len("Bearer ") :]
|
||||
|
||||
if not token:
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
content={"detail": "Unauthorized"},
|
||||
)
|
||||
|
||||
db_factory: Database | None = getattr(request.app.state, "db_factory", None)
|
||||
if db_factory is None:
|
||||
return JSONResponse(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, content={"detail": "db factory not configured"})
|
||||
try:
|
||||
with db_factory.get_session_ctx_manager() as request_session:
|
||||
hashed_token = security.hash_session_token_sha256(token)
|
||||
request.state.db_session = request_session
|
||||
login_session: Sessions | None = crud.get_login_session_by_token_hash(request_session, hashed_token)
|
||||
if not login_session:
|
||||
return JSONResponse(status_code=status.HTTP_401_UNAUTHORIZED, content={"detail": "Unauthorized"})
|
||||
session_expires_utc = login_session.expires_at.replace(tzinfo=timezone.utc)
|
||||
if session_expires_utc < datetime.now(timezone.utc):
|
||||
crud.delete_login_session(request_session, login_session.session_token_hash)
|
||||
return JSONResponse(status_code=status.HTTP_401_UNAUTHORIZED, content={"detail": "Unauthorized"})
|
||||
if login_session.user.is_active is False:
|
||||
return JSONResponse(status_code=status.HTTP_401_UNAUTHORIZED, content={"detail": "Unauthorized"})
|
||||
if session_expires_utc - datetime.now(timezone.utc) < timedelta(seconds=3600):
|
||||
updated_expiry = datetime.now(timezone.utc) + timedelta(seconds=settings.settings.session_expiry_seconds)
|
||||
else:
|
||||
updated_expiry = session_expires_utc
|
||||
updated_session: SessionsUpdate = SessionsUpdate(
|
||||
last_seen_at=datetime.now(timezone.utc),
|
||||
last_used_ip=request.client.host if request.client else None,
|
||||
user_agent=request.headers.get("User-Agent"),
|
||||
expires_at=updated_expiry,
|
||||
)
|
||||
user_id = login_session.user_id
|
||||
request.state.user_id = user_id
|
||||
crud.update_login_session(request_session, hashed_token, update_session=updated_session)
|
||||
except Exception:
|
||||
logger.exception("Failed to authenticate user: \n")
|
||||
return JSONResponse(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, content={"detail": "Internal server error"})
|
||||
|
||||
return await call_next(request)
|
||||
|
||||
|
||||
class ServiceError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class UserAlreadyExistsError(ServiceError):
|
||||
pass
|
||||
|
||||
|
||||
class ExchangeAlreadyExistsError(ServiceError):
|
||||
pass
|
||||
|
||||
|
||||
class ExchangeNotFoundError(ServiceError):
|
||||
pass
|
||||
|
||||
|
||||
class CycleNotFoundError(ServiceError):
|
||||
pass
|
||||
|
||||
|
||||
class TradeNotFoundError(ServiceError):
|
||||
pass
|
||||
|
||||
|
||||
class InvalidTradeDataError(ServiceError):
|
||||
pass
|
||||
|
||||
|
||||
class InvalidCycleDataError(ServiceError):
|
||||
pass
|
||||
|
||||
|
||||
# User service
|
||||
def register_user_service(db_session: Session, user_in: UserCreate) -> UserRead:
|
||||
if crud.get_user_by_username(db_session, user_in.username):
|
||||
raise UserAlreadyExistsError("username already exists")
|
||||
hashed = security.hash_password(user_in.password)
|
||||
user_data: dict = {
|
||||
"username": user_in.username,
|
||||
"password_hash": hashed,
|
||||
}
|
||||
try:
|
||||
user = crud.create_user(db_session, user_data=user_data)
|
||||
try:
|
||||
# prefer pydantic's from_orm if DTO supports orm_mode
|
||||
user = UserRead.model_validate(user)
|
||||
except Exception as e:
|
||||
logger.exception("Failed to convert user to UserRead: ")
|
||||
raise ServiceError("Failed to convert user to UserRead") from e
|
||||
except Exception as e:
|
||||
logger.exception("Failed to create user:")
|
||||
raise ServiceError("Failed to create user") from e
|
||||
return user
|
||||
|
||||
|
||||
def authenticate_user_service(db_session: Session, user_in: UserLogin) -> tuple[SessionsCreate, str] | None:
|
||||
user = crud.get_user_by_username(db_session, user_in.username)
|
||||
if not user:
|
||||
return None
|
||||
user_id_val = cast("int", user.id)
|
||||
|
||||
if not security.verify_password(user_in.password, user.password_hash):
|
||||
return None
|
||||
|
||||
token = security.generate_session_token()
|
||||
token_hashed = security.hash_session_token_sha256(token)
|
||||
try:
|
||||
session = crud.create_login_session(
|
||||
session=db_session,
|
||||
user_id=user_id_val,
|
||||
session_token_hash=token_hashed,
|
||||
session_length_seconds=settings.settings.session_expiry_seconds,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception("Failed to create login session: \n")
|
||||
raise ServiceError("Failed to create login session") from e
|
||||
return SessionsCreate.model_validate(session), token
|
||||
|
||||
|
||||
# Exchanges service
|
||||
def create_exchange_service(db_session: Session, user_id: int, name: str, notes: str | None) -> ExchangesCreate:
|
||||
existing_exchange = crud.get_exchange_by_name_and_user_id(db_session, name, user_id)
|
||||
if existing_exchange:
|
||||
raise ExchangeAlreadyExistsError("Exchange with the same name already exists for this user")
|
||||
exchange_data = ExchangesCreate(
|
||||
user_id=user_id,
|
||||
name=name,
|
||||
notes=notes,
|
||||
)
|
||||
try:
|
||||
exchange = crud.create_exchange(db_session, exchange_data=exchange_data)
|
||||
try:
|
||||
exchange_dto = ExchangesCreate.model_validate(exchange)
|
||||
except Exception as e:
|
||||
logger.exception("Failed to convert exchange to ExchangesCreate:")
|
||||
raise ServiceError("Failed to convert exchange to ExchangesCreate") from e
|
||||
except Exception as e:
|
||||
logger.exception("Failed to create exchange:")
|
||||
raise ServiceError("Failed to create exchange") from e
|
||||
return exchange_dto
|
||||
|
||||
|
||||
def get_exchanges_by_user_service(db_session: Session, user_id: int) -> list[ExchangesRead]:
|
||||
exchanges = crud.get_all_exchanges_by_user_id(db_session, user_id)
|
||||
return [ExchangesRead.model_validate(exchange) for exchange in exchanges]
|
||||
|
||||
|
||||
def update_exchanges_service(db_session: Session, user_id: int, exchange_id: int, name: str | None, notes: str | None) -> ExchangesBase:
|
||||
existing_exchange = crud.get_exchange_by_id(db_session, exchange_id)
|
||||
if not existing_exchange:
|
||||
raise ExchangeNotFoundError("Exchange not found")
|
||||
if existing_exchange.user_id != user_id:
|
||||
raise ExchangeNotFoundError("Exchange not found")
|
||||
|
||||
if name:
|
||||
other_exchange = crud.get_exchange_by_name_and_user_id(db_session, name, user_id)
|
||||
if other_exchange and other_exchange.id != existing_exchange.id:
|
||||
raise ExchangeAlreadyExistsError("Another exchange with the same name already exists for this user")
|
||||
|
||||
exchange_data = ExchangesBase(
|
||||
name=name or existing_exchange.name,
|
||||
notes=notes or existing_exchange.notes,
|
||||
)
|
||||
try:
|
||||
exchange = crud.update_exchange(db_session, cast("int", existing_exchange.id), update_data=exchange_data)
|
||||
except Exception as e:
|
||||
logger.exception("Failed to update exchange: \n")
|
||||
raise ServiceError("Failed to update exchange") from e
|
||||
return ExchangesBase.model_validate(exchange)
|
||||
|
||||
|
||||
# Cycle Service
|
||||
def create_cycle_service(db_session: Session, user_id: int, cycle_data: CycleBase) -> CycleRead:
|
||||
raise NotImplementedError("Cycle creation not implemented")
|
||||
cycle_data_dict = cycle_data.model_dump()
|
||||
cycle_data_dict["user_id"] = user_id
|
||||
cycle_data_with_user_id: CycleCreate = CycleCreate.model_validate(cycle_data_dict)
|
||||
created_cycle = crud.create_cycle(db_session, cycle_data=cycle_data_with_user_id)
|
||||
return CycleRead.model_validate(created_cycle)
|
||||
|
||||
|
||||
def get_cycle_by_id_service(db_session: Session, user_id: int, cycle_id: int) -> CycleRead:
|
||||
cycle = crud.get_cycle_by_id(db_session, cycle_id)
|
||||
if not cycle:
|
||||
raise CycleNotFoundError("Cycle not found")
|
||||
if cycle.user_id != user_id:
|
||||
raise CycleNotFoundError("Cycle not found")
|
||||
return CycleRead.model_validate(cycle)
|
||||
|
||||
|
||||
def get_cycles_by_user_service(db_session: Session, user_id: int) -> list[CycleRead]:
|
||||
cycles = crud.get_cycles_by_user_id(db_session, user_id)
|
||||
return [CycleRead.model_validate(cycle) for cycle in cycles]
|
||||
|
||||
|
||||
def _validate_cycle_update_data(cycle_data: CycleUpdate) -> tuple[bool, str]: # noqa: PLR0911
|
||||
if cycle_data.status == "CLOSED" and cycle_data.end_date is None:
|
||||
return False, "end_date is required when status is CLOSED"
|
||||
if cycle_data.status == "OPEN" and cycle_data.end_date is not None:
|
||||
return False, "end_date must be empty when status is OPEN"
|
||||
if cycle_data.capital_exposure_cents is not None and cycle_data.capital_exposure_cents < 0:
|
||||
return False, "capital_exposure_cents must be non-negative"
|
||||
if (
|
||||
cycle_data.funding_source is not None
|
||||
and cycle_data.funding_source != "CASH"
|
||||
and (cycle_data.loan_amount_cents is None or cycle_data.loan_interest_rate_tenth_bps is None)
|
||||
):
|
||||
return False, "loan_amount_cents and loan_interest_rate_tenth_bps are required when funding_source is not CASH"
|
||||
if cycle_data.loan_amount_cents is not None and cycle_data.loan_amount_cents < 0:
|
||||
return False, "loan_amount_cents must be non-negative"
|
||||
if cycle_data.loan_interest_rate_tenth_bps is not None and cycle_data.loan_interest_rate_tenth_bps < 0:
|
||||
return False, "loan_interest_rate_tenth_bps must be non-negative"
|
||||
return True, ""
|
||||
|
||||
|
||||
def update_cycle_service(db_session: Session, user_id: int, cycle_data: CycleUpdate) -> CycleRead:
|
||||
is_valid, err_msg = _validate_cycle_update_data(cycle_data)
|
||||
if not is_valid:
|
||||
raise InvalidCycleDataError(err_msg)
|
||||
cycle_id = cast("int", cycle_data.id)
|
||||
existing_cycle = crud.get_cycle_by_id(db_session, cycle_id)
|
||||
if not existing_cycle:
|
||||
raise CycleNotFoundError("Cycle not found")
|
||||
if existing_cycle.user_id != user_id:
|
||||
raise CycleNotFoundError("Cycle not found")
|
||||
|
||||
provided_data_dict = cycle_data.model_dump(exclude_unset=True)
|
||||
cycle_data_with_user_id: CycleBase = CycleBase.model_validate(provided_data_dict)
|
||||
|
||||
try:
|
||||
updated_cycle = crud.update_cycle(db_session, cycle_id, update_data=cycle_data_with_user_id)
|
||||
except Exception as e:
|
||||
logger.exception("Failed to update cycle: \n")
|
||||
raise ServiceError("Failed to update cycle") from e
|
||||
return CycleRead.model_validate(updated_cycle)
|
||||
|
||||
|
||||
# Trades service
|
||||
def _append_cashflows(trade_data: TradeCreate) -> TradeCreate:
|
||||
sign_multipler: int
|
||||
if trade_data.trade_type in ("SELL_PUT", "SELL_CALL", "EXERCISE_CALL", "CLOSE_LONG_SPOT", "SHORT_SPOT"):
|
||||
sign_multipler = 1
|
||||
else:
|
||||
sign_multipler = -1
|
||||
quantity = trade_data.quantity * trade_data.quantity_multiplier
|
||||
gross_cash_flow_cents = quantity * trade_data.price_cents * sign_multipler
|
||||
net_cash_flow_cents = gross_cash_flow_cents - trade_data.commission_cents
|
||||
trade_data.gross_cash_flow_cents = gross_cash_flow_cents
|
||||
trade_data.net_cash_flow_cents = net_cash_flow_cents
|
||||
return trade_data
|
||||
|
||||
|
||||
def _validate_trade_data(trade_data: TradeCreate) -> bool:
|
||||
return not (
|
||||
trade_data.trade_type in ("SELL_PUT", "SELL_CALL") and (trade_data.expiry_date is None or trade_data.strike_price_cents is None)
|
||||
)
|
||||
|
||||
|
||||
def create_trade_service(db_session: Session, user_id: int, trade_data: TradeCreate) -> TradeRead:
|
||||
if not _validate_trade_data(trade_data):
|
||||
raise InvalidTradeDataError("Invalid trade data: expiry_date and strike_price_cents are required for SELL_PUT and SELL_CALL trades")
|
||||
trade_data_dict = trade_data.model_dump()
|
||||
trade_data_dict["user_id"] = user_id
|
||||
trade_data_with_user_id: TradeCreate = TradeCreate.model_validate(trade_data_dict)
|
||||
trade_data_with_user_id = _append_cashflows(trade_data_with_user_id)
|
||||
created_trade = crud.create_trade(db_session, trade_data=trade_data_with_user_id)
|
||||
return TradeRead.model_validate(created_trade)
|
||||
|
||||
|
||||
def get_trade_by_id_service(db_session: Session, user_id: int, trade_id: int) -> TradeRead:
|
||||
trade = crud.get_trade_by_id(db_session, trade_id)
|
||||
if not trade:
|
||||
raise TradeNotFoundError("Trade not found")
|
||||
if trade.user_id != user_id:
|
||||
raise TradeNotFoundError("Trade not found")
|
||||
return TradeRead.model_validate(trade)
|
||||
|
||||
|
||||
def update_trade_friendly_name_service(db_session: Session, user_id: int, trade_id: int, friendly_name: str) -> TradeRead:
|
||||
existing_trade = crud.get_trade_by_id(db_session, trade_id)
|
||||
if not existing_trade:
|
||||
raise TradeNotFoundError("Trade not found")
|
||||
if existing_trade.user_id != user_id:
|
||||
raise TradeNotFoundError("Trade not found")
|
||||
try:
|
||||
updated_trade = crud.update_trade_friendly_name(db_session, trade_id, friendly_name)
|
||||
except Exception as e:
|
||||
logger.exception("Failed to update trade friendly name: \n")
|
||||
raise ServiceError("Failed to update trade friendly name") from e
|
||||
return TradeRead.model_validate(updated_trade)
|
||||
|
||||
|
||||
def update_trade_note_service(db_session: Session, user_id: int, trade_id: int, note: str | None) -> TradeRead:
|
||||
existing_trade = crud.get_trade_by_id(db_session, trade_id)
|
||||
if not existing_trade:
|
||||
raise TradeNotFoundError("Trade not found")
|
||||
if existing_trade.user_id != user_id:
|
||||
raise TradeNotFoundError("Trade not found")
|
||||
if note is None:
|
||||
note = ""
|
||||
try:
|
||||
updated_trade = crud.update_trade_note(db_session, trade_id, note)
|
||||
except Exception as e:
|
||||
logger.exception("Failed to update trade notes: \n")
|
||||
raise ServiceError("Failed to update trade notes") from e
|
||||
return TradeRead.model_validate(updated_trade)
|
||||
0
backend/utils/__init__.py
Normal file
0
backend/utils/__init__.py
Normal file
7
backend/utils/db_migration.py
Normal file
7
backend/utils/db_migration.py
Normal file
@@ -0,0 +1,7 @@
|
||||
from sqlmodel import create_engine
|
||||
|
||||
import settings
|
||||
from trading_journal import db_migration
|
||||
|
||||
db_engine = create_engine(settings.settings.database_url, echo=True)
|
||||
db_migration.run_migrations(db_engine)
|
||||
Reference in New Issue
Block a user