Compare commits
25 Commits
39fc10572e
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
| 5e7d801075 | |||
| 94fb4705ff | |||
| bb87b90285 | |||
| 5eae75b23e | |||
| 6a5f160d83 | |||
| 27b4adaca4 | |||
| e66aab99ea | |||
| 80fc405bf6 | |||
| cf6c826468 | |||
| a6592bd140 | |||
| 92c4e0d4fc | |||
| 544f5e8c92 | |||
| b6ba108156 | |||
| b68249f9f1 | |||
| 1750401278 | |||
| 466e6ce653 | |||
| e70a63e4f9 | |||
| 76ed38e9af | |||
| 1fbc93353d | |||
| 76cc967c42 | |||
| 442da655c0 | |||
| 07d33c4568 | |||
| 9f3010d300 | |||
| bc264c8014 | |||
| afd342b31f |
3
.github/workflows/backend-ci.yml
vendored
3
.github/workflows/backend-ci.yml
vendored
@@ -25,8 +25,9 @@ jobs:
|
|||||||
run: pip install -r dev-requirements.txt
|
run: pip install -r dev-requirements.txt
|
||||||
|
|
||||||
- name: Run models vs snapshot check
|
- name: Run models vs snapshot check
|
||||||
|
working-directory: ${{ github.workspace }}
|
||||||
run: |
|
run: |
|
||||||
python .github/scripts/compare_models.py trading_journal/models.py
|
python .github/script/compare_models.py backend/trading_journal/models.py
|
||||||
|
|
||||||
- name: Run tests
|
- name: Run tests
|
||||||
run: |
|
run: |
|
||||||
|
|||||||
19
LICENSE
Normal file
19
LICENSE
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
Copyright (c) 2025 Tianyu Liu, Studio TJ
|
||||||
|
|
||||||
|
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||||
|
of this software and associated documentation files (the "Software"), to deal
|
||||||
|
in the Software without restriction, including without limitation the rights
|
||||||
|
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||||
|
copies of the Software, and to permit persons to whom the Software is
|
||||||
|
furnished to do so, subject to the following conditions:
|
||||||
|
|
||||||
|
The above copyright notice and this permission notice shall be included in all
|
||||||
|
copies or substantial portions of the Software.
|
||||||
|
|
||||||
|
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
||||||
|
EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||||
|
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||||
|
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
|
||||||
|
DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
|
||||||
|
OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE
|
||||||
|
OR OTHER DEALINGS IN THE SOFTWARE.
|
||||||
28
README.md
Normal file
28
README.md
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
|
||||||
|
# Trading Journal (Work In Progress)
|
||||||
|
|
||||||
|
A simple trading journal application (work in progress).
|
||||||
|
|
||||||
|
This repository contains the backend of a trading journal designed to help you record and analyse trades. The system is specially designed to support journaling trades for the "wheel" options strategy, but it also supports other trade types such as long/short spot positions, forex, and more.
|
||||||
|
|
||||||
|
Important: the project is still under active development. There is a backend in this repo, but the frontend UI has not been implemented yet.
|
||||||
|
|
||||||
|
## Key features
|
||||||
|
|
||||||
|
- Journal trades with rich metadata (strategy, entry/exit, P/L, notes).
|
||||||
|
- Built-in support and data model conveniences for the Wheel strategy (puts/calls lifecycle tracking).
|
||||||
|
- Flexible support for other trade types: long/short spots, forex, futures, etc.
|
||||||
|
- Backend-first design with tests and migration helpers.
|
||||||
|
|
||||||
|
## Repository layout
|
||||||
|
|
||||||
|
- `backend/` — Python backend code (API, models, services, migrations, tests).
|
||||||
|
- `backend/trading_journal/` — core application modules: CRUD, models, DTOs, services, and security.
|
||||||
|
- `backend/tests/` — unit tests targeting the backend logic and DB layer.
|
||||||
|
|
||||||
|
|
||||||
|
## License
|
||||||
|
|
||||||
|
See the `LICENSE` file in the project root for license details.
|
||||||
|
|
||||||
|
|
||||||
2
backend/.gitignore
vendored
2
backend/.gitignore
vendored
@@ -15,3 +15,5 @@ __pycache__/
|
|||||||
*.db
|
*.db
|
||||||
*.db-shm
|
*.db-shm
|
||||||
*.db-wal
|
*.db-wal
|
||||||
|
|
||||||
|
devsettings.yaml
|
||||||
8
backend/.vscode/launch.json
vendored
8
backend/.vscode/launch.json
vendored
@@ -13,10 +13,14 @@
|
|||||||
"app:app",
|
"app:app",
|
||||||
"--host=0.0.0.0",
|
"--host=0.0.0.0",
|
||||||
"--reload",
|
"--reload",
|
||||||
"--port=5000"
|
"--port=18881"
|
||||||
],
|
],
|
||||||
"jinja": true,
|
"jinja": true,
|
||||||
"autoStartBrowser": true
|
"autoStartBrowser": false,
|
||||||
|
"env": {
|
||||||
|
"CONFIG_FILE": "devsettings.yaml"
|
||||||
|
},
|
||||||
|
"console": "integratedTerminal"
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
3
backend/.vscode/settings.json
vendored
3
backend/.vscode/settings.json
vendored
@@ -11,5 +11,6 @@
|
|||||||
"tests"
|
"tests"
|
||||||
],
|
],
|
||||||
"python.testing.unittestEnabled": false,
|
"python.testing.unittestEnabled": false,
|
||||||
"python.testing.pytestEnabled": true
|
"python.testing.pytestEnabled": true,
|
||||||
|
"python.analysis.typeCheckingMode": "standard",
|
||||||
}
|
}
|
||||||
337
backend/app.py
337
backend/app.py
@@ -1,33 +1,324 @@
|
|||||||
from fastapi import FastAPI
|
from __future__ import annotations
|
||||||
|
|
||||||
from models import MsgPayload
|
import asyncio
|
||||||
|
import logging
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
app = FastAPI()
|
from fastapi import FastAPI, HTTPException, Request, status
|
||||||
messages_list: dict[int, MsgPayload] = {}
|
from fastapi.encoders import jsonable_encoder
|
||||||
|
from fastapi.responses import JSONResponse, Response
|
||||||
|
|
||||||
|
import settings
|
||||||
|
from trading_journal import db, service
|
||||||
|
from trading_journal.dto import (
|
||||||
|
CycleBase,
|
||||||
|
CycleRead,
|
||||||
|
CycleUpdate,
|
||||||
|
ExchangesBase,
|
||||||
|
ExchangesRead,
|
||||||
|
SessionsBase,
|
||||||
|
SessionsCreate,
|
||||||
|
TradeCreate,
|
||||||
|
TradeFriendlyNameUpdate,
|
||||||
|
TradeNoteUpdate,
|
||||||
|
TradeRead,
|
||||||
|
UserCreate,
|
||||||
|
UserLogin,
|
||||||
|
UserRead,
|
||||||
|
)
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from collections.abc import AsyncGenerator
|
||||||
|
|
||||||
|
from trading_journal.db import Database
|
||||||
|
|
||||||
|
_db = db.create_database(settings.settings.database_url)
|
||||||
|
|
||||||
|
logging.basicConfig(
|
||||||
|
level=logging.WARNING,
|
||||||
|
format="%(asctime)s %(levelname)s %(name)s: %(message)s",
|
||||||
|
)
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@app.get("/")
|
@asynccontextmanager
|
||||||
def root() -> dict[str, str]:
|
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: # noqa: ARG001
|
||||||
return {"message": "Hello"}
|
await asyncio.to_thread(_db.init_db)
|
||||||
|
try:
|
||||||
|
yield
|
||||||
|
finally:
|
||||||
|
await asyncio.to_thread(_db.dispose)
|
||||||
|
|
||||||
|
|
||||||
# About page route
|
origins = [
|
||||||
@app.get("/about")
|
"http://127.0.0.1:18881",
|
||||||
def about() -> dict[str, str]:
|
]
|
||||||
return {"message": "This is the about page."}
|
|
||||||
|
app = FastAPI(lifespan=lifespan)
|
||||||
|
app.add_middleware(
|
||||||
|
service.AuthMiddleWare,
|
||||||
|
)
|
||||||
|
app.state.db_factory = _db
|
||||||
|
|
||||||
|
|
||||||
# Route to add a message
|
@app.get(f"{settings.settings.api_base}/status")
|
||||||
@app.post("/messages/{msg_name}/")
|
async def get_status() -> dict[str, str]:
|
||||||
def add_msg(msg_name: str) -> dict[str, MsgPayload]:
|
return {"status": "ok"}
|
||||||
# Generate an ID for the item based on the highest ID in the messages_list
|
|
||||||
msg_id = max(messages_list.keys()) + 1 if messages_list else 0
|
|
||||||
messages_list[msg_id] = MsgPayload(msg_id=msg_id, msg_name=msg_name)
|
|
||||||
|
|
||||||
return {"message": messages_list[msg_id]}
|
|
||||||
|
|
||||||
|
|
||||||
# Route to list all messages
|
@app.post(f"{settings.settings.api_base}/register")
|
||||||
@app.get("/messages")
|
async def register_user(request: Request, user_in: UserCreate) -> Response:
|
||||||
def message_items() -> dict[str, dict[int, MsgPayload]]:
|
db_factory: Database = request.app.state.db_factory
|
||||||
return {"messages:": messages_list}
|
|
||||||
|
def sync_work() -> UserRead:
|
||||||
|
with db_factory.get_session_ctx_manager() as db:
|
||||||
|
return service.register_user_service(db, user_in)
|
||||||
|
|
||||||
|
try:
|
||||||
|
user = await asyncio.to_thread(sync_work)
|
||||||
|
return JSONResponse(status_code=status.HTTP_201_CREATED, content=user.model_dump())
|
||||||
|
except service.UserAlreadyExistsError as e:
|
||||||
|
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) from e
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception("Failed to register user: \n")
|
||||||
|
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Internal Server Error") from e
|
||||||
|
|
||||||
|
|
||||||
|
@app.post(f"{settings.settings.api_base}/login")
|
||||||
|
async def login(request: Request, user_in: UserLogin) -> Response:
|
||||||
|
db_factory: Database = request.app.state.db_factory
|
||||||
|
|
||||||
|
def sync_work() -> tuple[SessionsCreate, str] | None:
|
||||||
|
with db_factory.get_session_ctx_manager() as db:
|
||||||
|
return service.authenticate_user_service(db, user_in)
|
||||||
|
|
||||||
|
try:
|
||||||
|
result = await asyncio.to_thread(sync_work)
|
||||||
|
if result is None:
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
content={"detail": "Invalid username or password, or user doesn't exist"},
|
||||||
|
)
|
||||||
|
session, token = result
|
||||||
|
session_return = SessionsBase(user_id=session.user_id)
|
||||||
|
response = JSONResponse(status_code=status.HTTP_200_OK, content=session_return.model_dump())
|
||||||
|
expires_sec = int((session.expires_at.replace(tzinfo=timezone.utc) - datetime.now(timezone.utc)).total_seconds())
|
||||||
|
response.set_cookie(
|
||||||
|
key="session_token",
|
||||||
|
value=token,
|
||||||
|
httponly=True,
|
||||||
|
secure=True,
|
||||||
|
samesite="lax",
|
||||||
|
max_age=expires_sec,
|
||||||
|
path="/",
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception("Failed to login user: \n")
|
||||||
|
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Internal Server Error") from e
|
||||||
|
else:
|
||||||
|
return response
|
||||||
|
|
||||||
|
|
||||||
|
# Exchange
|
||||||
|
@app.post(f"{settings.settings.api_base}/exchanges")
|
||||||
|
async def create_exchange(request: Request, exchange_data: ExchangesBase) -> Response:
|
||||||
|
db_factory: Database = request.app.state.db_factory
|
||||||
|
|
||||||
|
def sync_work() -> ExchangesBase:
|
||||||
|
with db_factory.get_session_ctx_manager() as db:
|
||||||
|
return service.create_exchange_service(db, request.state.user_id, exchange_data.name, exchange_data.notes)
|
||||||
|
|
||||||
|
try:
|
||||||
|
exchange = await asyncio.to_thread(sync_work)
|
||||||
|
return JSONResponse(status_code=status.HTTP_201_CREATED, content=exchange.model_dump())
|
||||||
|
except service.ExchangeAlreadyExistsError as e:
|
||||||
|
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) from e
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception("Failed to create exchange: \n")
|
||||||
|
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Internal server error") from e
|
||||||
|
|
||||||
|
|
||||||
|
@app.get(f"{settings.settings.api_base}/exchanges")
|
||||||
|
async def get_exchanges(request: Request) -> list[ExchangesRead]:
|
||||||
|
db_factory: Database = request.app.state.db_factory
|
||||||
|
|
||||||
|
def sync_work() -> list[ExchangesRead]:
|
||||||
|
with db_factory.get_session_ctx_manager() as db:
|
||||||
|
return service.get_exchanges_by_user_service(db, request.state.user_id)
|
||||||
|
|
||||||
|
try:
|
||||||
|
return await asyncio.to_thread(sync_work)
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception("Failed to get exchanges: \n")
|
||||||
|
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Internal server error") from e
|
||||||
|
|
||||||
|
|
||||||
|
@app.patch(f"{settings.settings.api_base}/exchanges/{{exchange_id}}")
|
||||||
|
async def update_exchange(request: Request, exchange_id: int, exchange_data: ExchangesBase) -> Response:
|
||||||
|
db_factory: Database = request.app.state.db_factory
|
||||||
|
|
||||||
|
def sync_work() -> ExchangesBase:
|
||||||
|
with db_factory.get_session_ctx_manager() as db:
|
||||||
|
return service.update_exchanges_service(db, request.state.user_id, exchange_id, exchange_data.name, exchange_data.notes)
|
||||||
|
|
||||||
|
try:
|
||||||
|
exchange = await asyncio.to_thread(sync_work)
|
||||||
|
return JSONResponse(status_code=status.HTTP_200_OK, content=exchange.model_dump())
|
||||||
|
except service.ExchangeNotFoundError as e:
|
||||||
|
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e)) from e
|
||||||
|
except service.ExchangeAlreadyExistsError as e:
|
||||||
|
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) from e
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception("Failed to update exchange: \n")
|
||||||
|
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Internal server error") from e
|
||||||
|
|
||||||
|
|
||||||
|
# Cycle
|
||||||
|
@app.post(f"{settings.settings.api_base}/cycles")
|
||||||
|
async def create_cycle(request: Request, cycle_data: CycleBase) -> Response:
|
||||||
|
return JSONResponse(status_code=status.HTTP_405_METHOD_NOT_ALLOWED, content="Not supported.")
|
||||||
|
db_factory: Database = request.app.state.db_factory
|
||||||
|
|
||||||
|
def sync_work() -> CycleBase:
|
||||||
|
with db_factory.get_session_ctx_manager() as db:
|
||||||
|
return service.create_cycle_service(db, request.state.user_id, cycle_data)
|
||||||
|
|
||||||
|
try:
|
||||||
|
cycle = await asyncio.to_thread(sync_work)
|
||||||
|
return JSONResponse(status_code=status.HTTP_201_CREATED, content=jsonable_encoder(cycle))
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception("Failed to create cycle: \n")
|
||||||
|
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Internal server error") from e
|
||||||
|
|
||||||
|
|
||||||
|
@app.get(f"{settings.settings.api_base}/cycles/{{cycle_id}}")
|
||||||
|
async def get_cycle_by_id(request: Request, cycle_id: int) -> Response:
|
||||||
|
db_factory: Database = request.app.state.db_factory
|
||||||
|
|
||||||
|
def sync_work() -> CycleBase:
|
||||||
|
with db_factory.get_session_ctx_manager() as db:
|
||||||
|
return service.get_cycle_by_id_service(db, request.state.user_id, cycle_id)
|
||||||
|
|
||||||
|
try:
|
||||||
|
cycle = await asyncio.to_thread(sync_work)
|
||||||
|
return JSONResponse(status_code=status.HTTP_200_OK, content=jsonable_encoder(cycle))
|
||||||
|
except service.CycleNotFoundError as e:
|
||||||
|
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e)) from e
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception("Failed to get cycle by id: \n")
|
||||||
|
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Internal server error") from e
|
||||||
|
|
||||||
|
|
||||||
|
@app.get(f"{settings.settings.api_base}/cycles/user/{{user_id}}")
|
||||||
|
async def get_cycles_by_user(request: Request, user_id: int) -> Response:
|
||||||
|
db_factory: Database = request.app.state.db_factory
|
||||||
|
|
||||||
|
def sync_work() -> list[CycleRead]:
|
||||||
|
with db_factory.get_session_ctx_manager() as db:
|
||||||
|
return service.get_cycles_by_user_service(db, user_id)
|
||||||
|
|
||||||
|
try:
|
||||||
|
cycles = await asyncio.to_thread(sync_work)
|
||||||
|
return JSONResponse(status_code=status.HTTP_200_OK, content=jsonable_encoder(cycles))
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception("Failed to get cycles by user: \n")
|
||||||
|
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Internal server error") from e
|
||||||
|
|
||||||
|
|
||||||
|
@app.patch(f"{settings.settings.api_base}/cycles")
|
||||||
|
async def update_cycle(request: Request, cycle_data: CycleUpdate) -> Response:
|
||||||
|
db_factory: Database = request.app.state.db_factory
|
||||||
|
|
||||||
|
def sync_work() -> CycleRead:
|
||||||
|
with db_factory.get_session_ctx_manager() as db:
|
||||||
|
return service.update_cycle_service(db, request.state.user_id, cycle_data)
|
||||||
|
|
||||||
|
try:
|
||||||
|
cycle = await asyncio.to_thread(sync_work)
|
||||||
|
return JSONResponse(status_code=status.HTTP_200_OK, content=jsonable_encoder(cycle))
|
||||||
|
except service.InvalidCycleDataError as e:
|
||||||
|
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) from e
|
||||||
|
except service.CycleNotFoundError as e:
|
||||||
|
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e)) from e
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception("Failed to update cycle: \n")
|
||||||
|
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Internal server error") from e
|
||||||
|
|
||||||
|
|
||||||
|
@app.post(f"{settings.settings.api_base}/trades")
|
||||||
|
async def create_trade(request: Request, trade_data: TradeCreate) -> Response:
|
||||||
|
db_factory: Database = request.app.state.db_factory
|
||||||
|
|
||||||
|
def sync_work() -> TradeRead:
|
||||||
|
with db_factory.get_session_ctx_manager() as db:
|
||||||
|
return service.create_trade_service(db, request.state.user_id, trade_data)
|
||||||
|
|
||||||
|
try:
|
||||||
|
trade = await asyncio.to_thread(sync_work)
|
||||||
|
return JSONResponse(status_code=status.HTTP_201_CREATED, content=jsonable_encoder(trade))
|
||||||
|
except service.InvalidTradeDataError as e:
|
||||||
|
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) from e
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception("Failed to create trade: \n")
|
||||||
|
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Internal server error") from e
|
||||||
|
|
||||||
|
|
||||||
|
@app.get(f"{settings.settings.api_base}/trades/{{trade_id}}")
|
||||||
|
async def get_trade_by_id(request: Request, trade_id: int) -> Response:
|
||||||
|
db_factory: Database = request.app.state.db_factory
|
||||||
|
|
||||||
|
def sync_work() -> TradeRead:
|
||||||
|
with db_factory.get_session_ctx_manager() as db:
|
||||||
|
return service.get_trade_by_id_service(db, request.state.user_id, trade_id)
|
||||||
|
|
||||||
|
try:
|
||||||
|
trade = await asyncio.to_thread(sync_work)
|
||||||
|
return JSONResponse(status_code=status.HTTP_200_OK, content=jsonable_encoder(trade))
|
||||||
|
except service.TradeNotFoundError as e:
|
||||||
|
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e)) from e
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception("Failed to get trade by id: \n")
|
||||||
|
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Internal server error") from e
|
||||||
|
|
||||||
|
|
||||||
|
@app.patch(f"{settings.settings.api_base}/trades/friendlyname")
|
||||||
|
async def update_trade_friendly_name(request: Request, friendly_name_update: TradeFriendlyNameUpdate) -> Response:
|
||||||
|
db_factory: Database = request.app.state.db_factory
|
||||||
|
|
||||||
|
def sync_work() -> TradeRead:
|
||||||
|
with db_factory.get_session_ctx_manager() as db:
|
||||||
|
return service.update_trade_friendly_name_service(
|
||||||
|
db,
|
||||||
|
request.state.user_id,
|
||||||
|
friendly_name_update.id,
|
||||||
|
friendly_name_update.friendly_name,
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
trade = await asyncio.to_thread(sync_work)
|
||||||
|
return JSONResponse(status_code=status.HTTP_200_OK, content=jsonable_encoder(trade))
|
||||||
|
except service.TradeNotFoundError as e:
|
||||||
|
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e)) from e
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception("Failed to update trade friendly name: \n")
|
||||||
|
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Internal server error") from e
|
||||||
|
|
||||||
|
|
||||||
|
@app.patch(f"{settings.settings.api_base}/trades/notes")
|
||||||
|
async def update_trade_note(request: Request, note_update: TradeNoteUpdate) -> Response:
|
||||||
|
db_factory: Database = request.app.state.db_factory
|
||||||
|
|
||||||
|
def sync_work() -> TradeRead:
|
||||||
|
with db_factory.get_session_ctx_manager() as db:
|
||||||
|
return service.update_trade_note_service(db, request.state.user_id, note_update.id, note_update.notes)
|
||||||
|
|
||||||
|
try:
|
||||||
|
trade = await asyncio.to_thread(sync_work)
|
||||||
|
return JSONResponse(status_code=status.HTTP_200_OK, content=jsonable_encoder(trade))
|
||||||
|
except service.TradeNotFoundError as e:
|
||||||
|
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e)) from e
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception("Failed to update trade note: \n")
|
||||||
|
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Internal server error") from e
|
||||||
|
|||||||
@@ -14,12 +14,130 @@ anyio==4.10.0 \
|
|||||||
# via
|
# via
|
||||||
# httpx
|
# httpx
|
||||||
# starlette
|
# starlette
|
||||||
|
argon2-cffi==25.1.0 \
|
||||||
|
--hash=sha256:694ae5cc8a42f4c4e2bf2ca0e64e51e23a040c6a517a85074683d3959e1346c1 \
|
||||||
|
--hash=sha256:fdc8b074db390fccb6eb4a3604ae7231f219aa669a2652e0f20e16ba513d5741
|
||||||
|
# via -r requirements.in
|
||||||
|
argon2-cffi-bindings==25.1.0 \
|
||||||
|
--hash=sha256:1db89609c06afa1a214a69a462ea741cf735b29a57530478c06eb81dd403de99 \
|
||||||
|
--hash=sha256:1e021e87faa76ae0d413b619fe2b65ab9a037f24c60a1e6cc43457ae20de6dc6 \
|
||||||
|
--hash=sha256:21378b40e1b8d1655dd5310c84a40fc19a9aa5e6366e835ceb8576bf0fea716d \
|
||||||
|
--hash=sha256:2630b6240b495dfab90aebe159ff784d08ea999aa4b0d17efa734055a07d2f44 \
|
||||||
|
--hash=sha256:3c6702abc36bf3ccba3f802b799505def420a1b7039862014a65db3205967f5a \
|
||||||
|
--hash=sha256:3d3f05610594151994ca9ccb3c771115bdb4daef161976a266f0dd8aa9996b8f \
|
||||||
|
--hash=sha256:473bcb5f82924b1becbb637b63303ec8d10e84c8d241119419897a26116515d2 \
|
||||||
|
--hash=sha256:5acb4e41090d53f17ca1110c3427f0a130f944b896fc8c83973219c97f57b690 \
|
||||||
|
--hash=sha256:5d588dec224e2a83edbdc785a5e6f3c6cd736f46bfd4b441bbb5aa1f5085e584 \
|
||||||
|
--hash=sha256:6dca33a9859abf613e22733131fc9194091c1fa7cb3e131c143056b4856aa47e \
|
||||||
|
--hash=sha256:7aef0c91e2c0fbca6fc68e7555aa60ef7008a739cbe045541e438373bc54d2b0 \
|
||||||
|
--hash=sha256:84a461d4d84ae1295871329b346a97f68eade8c53b6ed9a7ca2d7467f3c8ff6f \
|
||||||
|
--hash=sha256:87c33a52407e4c41f3b70a9c2d3f6056d88b10dad7695be708c5021673f55623 \
|
||||||
|
--hash=sha256:8b8efee945193e667a396cbc7b4fb7d357297d6234d30a489905d96caabde56b \
|
||||||
|
--hash=sha256:a1c70058c6ab1e352304ac7e3b52554daadacd8d453c1752e547c76e9c99ac44 \
|
||||||
|
--hash=sha256:a98cd7d17e9f7ce244c0803cad3c23a7d379c301ba618a5fa76a67d116618b98 \
|
||||||
|
--hash=sha256:aecba1723ae35330a008418a91ea6cfcedf6d31e5fbaa056a166462ff066d500 \
|
||||||
|
--hash=sha256:b0fdbcf513833809c882823f98dc2f931cf659d9a1429616ac3adebb49f5db94 \
|
||||||
|
--hash=sha256:b55aec3565b65f56455eebc9b9f34130440404f27fe21c3b375bf1ea4d8fbae6 \
|
||||||
|
--hash=sha256:b957f3e6ea4d55d820e40ff76f450952807013d361a65d7f28acc0acbf29229d \
|
||||||
|
--hash=sha256:ba92837e4a9aa6a508c8d2d7883ed5a8f6c308c89a4790e1e447a220deb79a85 \
|
||||||
|
--hash=sha256:c4f9665de60b1b0e99bcd6be4f17d90339698ce954cfd8d9cf4f91c995165a92 \
|
||||||
|
--hash=sha256:c87b72589133f0346a1cb8d5ecca4b933e3c9b64656c9d175270a000e73b288d \
|
||||||
|
--hash=sha256:d3e924cfc503018a714f94a49a149fdc0b644eaead5d1f089330399134fa028a \
|
||||||
|
--hash=sha256:da0c79c23a63723aa5d782250fbf51b768abca630285262fb5144ba5ae01e520 \
|
||||||
|
--hash=sha256:e2fd3bfbff3c5d74fef31a722f729bf93500910db650c925c2d6ef879a7e51cb
|
||||||
|
# via argon2-cffi
|
||||||
certifi==2025.8.3 \
|
certifi==2025.8.3 \
|
||||||
--hash=sha256:e564105f78ded564e3ae7c923924435e1daa7463faeab5bb932bc53ffae63407 \
|
--hash=sha256:e564105f78ded564e3ae7c923924435e1daa7463faeab5bb932bc53ffae63407 \
|
||||||
--hash=sha256:f6c12493cfb1b06ba2ff328595af9350c65d6644968e5d3a2ffd78699af217a5
|
--hash=sha256:f6c12493cfb1b06ba2ff328595af9350c65d6644968e5d3a2ffd78699af217a5
|
||||||
# via
|
# via
|
||||||
# httpcore
|
# httpcore
|
||||||
# httpx
|
# httpx
|
||||||
|
cffi==2.0.0 \
|
||||||
|
--hash=sha256:00bdf7acc5f795150faa6957054fbbca2439db2f775ce831222b66f192f03beb \
|
||||||
|
--hash=sha256:07b271772c100085dd28b74fa0cd81c8fb1a3ba18b21e03d7c27f3436a10606b \
|
||||||
|
--hash=sha256:087067fa8953339c723661eda6b54bc98c5625757ea62e95eb4898ad5e776e9f \
|
||||||
|
--hash=sha256:0a1527a803f0a659de1af2e1fd700213caba79377e27e4693648c2923da066f9 \
|
||||||
|
--hash=sha256:0cf2d91ecc3fcc0625c2c530fe004f82c110405f101548512cce44322fa8ac44 \
|
||||||
|
--hash=sha256:0f6084a0ea23d05d20c3edcda20c3d006f9b6f3fefeac38f59262e10cef47ee2 \
|
||||||
|
--hash=sha256:12873ca6cb9b0f0d3a0da705d6086fe911591737a59f28b7936bdfed27c0d47c \
|
||||||
|
--hash=sha256:19f705ada2530c1167abacb171925dd886168931e0a7b78f5bffcae5c6b5be75 \
|
||||||
|
--hash=sha256:1cd13c99ce269b3ed80b417dcd591415d3372bcac067009b6e0f59c7d4015e65 \
|
||||||
|
--hash=sha256:1e3a615586f05fc4065a8b22b8152f0c1b00cdbc60596d187c2a74f9e3036e4e \
|
||||||
|
--hash=sha256:1f72fb8906754ac8a2cc3f9f5aaa298070652a0ffae577e0ea9bd480dc3c931a \
|
||||||
|
--hash=sha256:1fc9ea04857caf665289b7a75923f2c6ed559b8298a1b8c49e59f7dd95c8481e \
|
||||||
|
--hash=sha256:203a48d1fb583fc7d78a4c6655692963b860a417c0528492a6bc21f1aaefab25 \
|
||||||
|
--hash=sha256:2081580ebb843f759b9f617314a24ed5738c51d2aee65d31e02f6f7a2b97707a \
|
||||||
|
--hash=sha256:21d1152871b019407d8ac3985f6775c079416c282e431a4da6afe7aefd2bccbe \
|
||||||
|
--hash=sha256:24b6f81f1983e6df8db3adc38562c83f7d4a0c36162885ec7f7b77c7dcbec97b \
|
||||||
|
--hash=sha256:256f80b80ca3853f90c21b23ee78cd008713787b1b1e93eae9f3d6a7134abd91 \
|
||||||
|
--hash=sha256:28a3a209b96630bca57cce802da70c266eb08c6e97e5afd61a75611ee6c64592 \
|
||||||
|
--hash=sha256:2c8f814d84194c9ea681642fd164267891702542f028a15fc97d4674b6206187 \
|
||||||
|
--hash=sha256:2de9a304e27f7596cd03d16f1b7c72219bd944e99cc52b84d0145aefb07cbd3c \
|
||||||
|
--hash=sha256:38100abb9d1b1435bc4cc340bb4489635dc2f0da7456590877030c9b3d40b0c1 \
|
||||||
|
--hash=sha256:3925dd22fa2b7699ed2617149842d2e6adde22b262fcbfada50e3d195e4b3a94 \
|
||||||
|
--hash=sha256:3e17ed538242334bf70832644a32a7aae3d83b57567f9fd60a26257e992b79ba \
|
||||||
|
--hash=sha256:3e837e369566884707ddaf85fc1744b47575005c0a229de3327f8f9a20f4efeb \
|
||||||
|
--hash=sha256:3f4d46d8b35698056ec29bca21546e1551a205058ae1a181d871e278b0b28165 \
|
||||||
|
--hash=sha256:44d1b5909021139fe36001ae048dbdde8214afa20200eda0f64c068cac5d5529 \
|
||||||
|
--hash=sha256:45d5e886156860dc35862657e1494b9bae8dfa63bf56796f2fb56e1679fc0bca \
|
||||||
|
--hash=sha256:4647afc2f90d1ddd33441e5b0e85b16b12ddec4fca55f0d9671fef036ecca27c \
|
||||||
|
--hash=sha256:4671d9dd5ec934cb9a73e7ee9676f9362aba54f7f34910956b84d727b0d73fb6 \
|
||||||
|
--hash=sha256:53f77cbe57044e88bbd5ed26ac1d0514d2acf0591dd6bb02a3ae37f76811b80c \
|
||||||
|
--hash=sha256:5eda85d6d1879e692d546a078b44251cdd08dd1cfb98dfb77b670c97cee49ea0 \
|
||||||
|
--hash=sha256:5fed36fccc0612a53f1d4d9a816b50a36702c28a2aa880cb8a122b3466638743 \
|
||||||
|
--hash=sha256:61d028e90346df14fedc3d1e5441df818d095f3b87d286825dfcbd6459b7ef63 \
|
||||||
|
--hash=sha256:66f011380d0e49ed280c789fbd08ff0d40968ee7b665575489afa95c98196ab5 \
|
||||||
|
--hash=sha256:6824f87845e3396029f3820c206e459ccc91760e8fa24422f8b0c3d1731cbec5 \
|
||||||
|
--hash=sha256:6c6c373cfc5c83a975506110d17457138c8c63016b563cc9ed6e056a82f13ce4 \
|
||||||
|
--hash=sha256:6d02d6655b0e54f54c4ef0b94eb6be0607b70853c45ce98bd278dc7de718be5d \
|
||||||
|
--hash=sha256:6d50360be4546678fc1b79ffe7a66265e28667840010348dd69a314145807a1b \
|
||||||
|
--hash=sha256:730cacb21e1bdff3ce90babf007d0a0917cc3e6492f336c2f0134101e0944f93 \
|
||||||
|
--hash=sha256:737fe7d37e1a1bffe70bd5754ea763a62a066dc5913ca57e957824b72a85e205 \
|
||||||
|
--hash=sha256:74a03b9698e198d47562765773b4a8309919089150a0bb17d829ad7b44b60d27 \
|
||||||
|
--hash=sha256:7553fb2090d71822f02c629afe6042c299edf91ba1bf94951165613553984512 \
|
||||||
|
--hash=sha256:7a66c7204d8869299919db4d5069a82f1561581af12b11b3c9f48c584eb8743d \
|
||||||
|
--hash=sha256:7cc09976e8b56f8cebd752f7113ad07752461f48a58cbba644139015ac24954c \
|
||||||
|
--hash=sha256:81afed14892743bbe14dacb9e36d9e0e504cd204e0b165062c488942b9718037 \
|
||||||
|
--hash=sha256:8941aaadaf67246224cee8c3803777eed332a19d909b47e29c9842ef1e79ac26 \
|
||||||
|
--hash=sha256:89472c9762729b5ae1ad974b777416bfda4ac5642423fa93bd57a09204712322 \
|
||||||
|
--hash=sha256:8ea985900c5c95ce9db1745f7933eeef5d314f0565b27625d9a10ec9881e1bfb \
|
||||||
|
--hash=sha256:8eca2a813c1cb7ad4fb74d368c2ffbbb4789d377ee5bb8df98373c2cc0dee76c \
|
||||||
|
--hash=sha256:92b68146a71df78564e4ef48af17551a5ddd142e5190cdf2c5624d0c3ff5b2e8 \
|
||||||
|
--hash=sha256:9332088d75dc3241c702d852d4671613136d90fa6881da7d770a483fd05248b4 \
|
||||||
|
--hash=sha256:94698a9c5f91f9d138526b48fe26a199609544591f859c870d477351dc7b2414 \
|
||||||
|
--hash=sha256:9a67fc9e8eb39039280526379fb3a70023d77caec1852002b4da7e8b270c4dd9 \
|
||||||
|
--hash=sha256:9de40a7b0323d889cf8d23d1ef214f565ab154443c42737dfe52ff82cf857664 \
|
||||||
|
--hash=sha256:a05d0c237b3349096d3981b727493e22147f934b20f6f125a3eba8f994bec4a9 \
|
||||||
|
--hash=sha256:afb8db5439b81cf9c9d0c80404b60c3cc9c3add93e114dcae767f1477cb53775 \
|
||||||
|
--hash=sha256:b18a3ed7d5b3bd8d9ef7a8cb226502c6bf8308df1525e1cc676c3680e7176739 \
|
||||||
|
--hash=sha256:b1e74d11748e7e98e2f426ab176d4ed720a64412b6a15054378afdb71e0f37dc \
|
||||||
|
--hash=sha256:b21e08af67b8a103c71a250401c78d5e0893beff75e28c53c98f4de42f774062 \
|
||||||
|
--hash=sha256:b4c854ef3adc177950a8dfc81a86f5115d2abd545751a304c5bcf2c2c7283cfe \
|
||||||
|
--hash=sha256:b882b3df248017dba09d6b16defe9b5c407fe32fc7c65a9c69798e6175601be9 \
|
||||||
|
--hash=sha256:baf5215e0ab74c16e2dd324e8ec067ef59e41125d3eade2b863d294fd5035c92 \
|
||||||
|
--hash=sha256:c649e3a33450ec82378822b3dad03cc228b8f5963c0c12fc3b1e0ab940f768a5 \
|
||||||
|
--hash=sha256:c654de545946e0db659b3400168c9ad31b5d29593291482c43e3564effbcee13 \
|
||||||
|
--hash=sha256:c6638687455baf640e37344fe26d37c404db8b80d037c3d29f58fe8d1c3b194d \
|
||||||
|
--hash=sha256:c8d3b5532fc71b7a77c09192b4a5a200ea992702734a2e9279a37f2478236f26 \
|
||||||
|
--hash=sha256:cb527a79772e5ef98fb1d700678fe031e353e765d1ca2d409c92263c6d43e09f \
|
||||||
|
--hash=sha256:cf364028c016c03078a23b503f02058f1814320a56ad535686f90565636a9495 \
|
||||||
|
--hash=sha256:d48a880098c96020b02d5a1f7d9251308510ce8858940e6fa99ece33f610838b \
|
||||||
|
--hash=sha256:d68b6cef7827e8641e8ef16f4494edda8b36104d79773a334beaa1e3521430f6 \
|
||||||
|
--hash=sha256:d9b29c1f0ae438d5ee9acb31cadee00a58c46cc9c0b2f9038c6b0b3470877a8c \
|
||||||
|
--hash=sha256:d9b97165e8aed9272a6bb17c01e3cc5871a594a446ebedc996e2397a1c1ea8ef \
|
||||||
|
--hash=sha256:da68248800ad6320861f129cd9c1bf96ca849a2771a59e0344e88681905916f5 \
|
||||||
|
--hash=sha256:da902562c3e9c550df360bfa53c035b2f241fed6d9aef119048073680ace4a18 \
|
||||||
|
--hash=sha256:dbd5c7a25a7cb98f5ca55d258b103a2054f859a46ae11aaf23134f9cc0d356ad \
|
||||||
|
--hash=sha256:dd4f05f54a52fb558f1ba9f528228066954fee3ebe629fc1660d874d040ae5a3 \
|
||||||
|
--hash=sha256:de8dad4425a6ca6e4e5e297b27b5c824ecc7581910bf9aee86cb6835e6812aa7 \
|
||||||
|
--hash=sha256:e11e82b744887154b182fd3e7e8512418446501191994dbf9c9fc1f32cc8efd5 \
|
||||||
|
--hash=sha256:e6e73b9e02893c764e7e8d5bb5ce277f1a009cd5243f8228f75f842bf937c534 \
|
||||||
|
--hash=sha256:f73b96c41e3b2adedc34a7356e64c8eb96e03a3782b535e043a986276ce12a49 \
|
||||||
|
--hash=sha256:f93fd8e5c8c0a4aa1f424d6173f14a892044054871c771f8566e4008eaa359d2 \
|
||||||
|
--hash=sha256:fc33c5141b55ed366cfaad382df24fe7dcbc686de5be719b207bb248e3053dc5 \
|
||||||
|
--hash=sha256:fc7de24befaeae77ba923797c7c87834c73648a05a4bde34b3b7e5588973a453 \
|
||||||
|
--hash=sha256:fe562eb1a64e67dd297ccc4f5addea2501664954f2692b69a76449ec7913ecbf
|
||||||
|
# via argon2-cffi-bindings
|
||||||
click==8.2.1 \
|
click==8.2.1 \
|
||||||
--hash=sha256:27c491cc05d968d271d5a1db13e3b5a184636d9d930f148c50b038f0d0646202 \
|
--hash=sha256:27c491cc05d968d271d5a1db13e3b5a184636d9d930f148c50b038f0d0646202 \
|
||||||
--hash=sha256:61a3265b914e850b85317d0b3109c7f8cd35a670f963866005d6ef1d5175a12b
|
--hash=sha256:61a3265b914e850b85317d0b3109c7f8cd35a670f963866005d6ef1d5175a12b
|
||||||
@@ -116,6 +234,10 @@ pluggy==1.6.0 \
|
|||||||
--hash=sha256:7dcc130b76258d33b90f61b658791dede3486c3e6bfb003ee5c9bfb396dd22f3 \
|
--hash=sha256:7dcc130b76258d33b90f61b658791dede3486c3e6bfb003ee5c9bfb396dd22f3 \
|
||||||
--hash=sha256:e920276dd6813095e9377c0bc5566d94c932c33b27a3e3945d8389c374dd4746
|
--hash=sha256:e920276dd6813095e9377c0bc5566d94c932c33b27a3e3945d8389c374dd4746
|
||||||
# via pytest
|
# via pytest
|
||||||
|
pycparser==2.23 \
|
||||||
|
--hash=sha256:78816d4f24add8f10a06d6f05b4d424ad9e96cfebf68a4ddc99c65c0720d00c2 \
|
||||||
|
--hash=sha256:e5c6e8d3fbad53479cab09ac03729e0a9faf2bee3db8208a550daf5af81a5934
|
||||||
|
# via cffi
|
||||||
pydantic==2.11.7 \
|
pydantic==2.11.7 \
|
||||||
--hash=sha256:d989c3c6cb79469287b1569f7447a17848c998458d49ebe294e975b9baf0f0db \
|
--hash=sha256:d989c3c6cb79469287b1569f7447a17848c998458d49ebe294e975b9baf0f0db \
|
||||||
--hash=sha256:dde5df002701f6de26248661f6835bbe296a47bf73990135c7d07ce741b9623b
|
--hash=sha256:dde5df002701f6de26248661f6835bbe296a47bf73990135c7d07ce741b9623b
|
||||||
|
|||||||
@@ -1,7 +0,0 @@
|
|||||||
from typing import Optional
|
|
||||||
from pydantic import BaseModel
|
|
||||||
|
|
||||||
|
|
||||||
class MsgPayload(BaseModel):
|
|
||||||
msg_id: Optional[int]
|
|
||||||
msg_name: str
|
|
||||||
554
backend/openapi.yaml
Normal file
554
backend/openapi.yaml
Normal file
@@ -0,0 +1,554 @@
|
|||||||
|
openapi: "3.0.3"
|
||||||
|
info:
|
||||||
|
title: Trading Journal API
|
||||||
|
version: "1.0.0"
|
||||||
|
description: OpenAPI description generated from [`app.py`](app.py) and DTOs in [`trading_journal/dto.py`](trading_journal/dto.py).
|
||||||
|
servers:
|
||||||
|
- url: "http://127.0.0.1:18881{basePath}"
|
||||||
|
variables:
|
||||||
|
basePath:
|
||||||
|
default: "/api/v1"
|
||||||
|
description: "API base path (matches settings.settings.api_base)"
|
||||||
|
components:
|
||||||
|
securitySchemes:
|
||||||
|
session_cookie:
|
||||||
|
type: apiKey
|
||||||
|
in: cookie
|
||||||
|
name: session_token
|
||||||
|
schemas:
|
||||||
|
UserCreate:
|
||||||
|
$ref: "#/components/schemas/UserCreate_impl"
|
||||||
|
UserCreate_impl:
|
||||||
|
type: object
|
||||||
|
required:
|
||||||
|
- username
|
||||||
|
- password
|
||||||
|
properties:
|
||||||
|
username:
|
||||||
|
type: string
|
||||||
|
is_active:
|
||||||
|
type: boolean
|
||||||
|
default: true
|
||||||
|
password:
|
||||||
|
type: string
|
||||||
|
UserLogin:
|
||||||
|
type: object
|
||||||
|
required:
|
||||||
|
- username
|
||||||
|
- password
|
||||||
|
properties:
|
||||||
|
username:
|
||||||
|
type: string
|
||||||
|
password:
|
||||||
|
type: string
|
||||||
|
UserRead:
|
||||||
|
type: object
|
||||||
|
required:
|
||||||
|
- id
|
||||||
|
- username
|
||||||
|
properties:
|
||||||
|
id:
|
||||||
|
type: integer
|
||||||
|
username:
|
||||||
|
type: string
|
||||||
|
is_active:
|
||||||
|
type: boolean
|
||||||
|
SessionsBase:
|
||||||
|
type: object
|
||||||
|
required:
|
||||||
|
- user_id
|
||||||
|
properties:
|
||||||
|
user_id:
|
||||||
|
type: integer
|
||||||
|
SessionsCreate:
|
||||||
|
allOf:
|
||||||
|
- $ref: "#/components/schemas/SessionsBase"
|
||||||
|
- type: object
|
||||||
|
required:
|
||||||
|
- expires_at
|
||||||
|
properties:
|
||||||
|
expires_at:
|
||||||
|
type: string
|
||||||
|
format: date-time
|
||||||
|
ExchangesBase:
|
||||||
|
type: object
|
||||||
|
required:
|
||||||
|
- name
|
||||||
|
properties:
|
||||||
|
name:
|
||||||
|
type: string
|
||||||
|
notes:
|
||||||
|
type: string
|
||||||
|
nullable: true
|
||||||
|
ExchangesRead:
|
||||||
|
allOf:
|
||||||
|
- $ref: "#/components/schemas/ExchangesBase"
|
||||||
|
- type: object
|
||||||
|
required:
|
||||||
|
- id
|
||||||
|
properties:
|
||||||
|
id:
|
||||||
|
type: integer
|
||||||
|
CycleBase:
|
||||||
|
type: object
|
||||||
|
properties:
|
||||||
|
friendly_name:
|
||||||
|
type: string
|
||||||
|
nullable: true
|
||||||
|
status:
|
||||||
|
type: string
|
||||||
|
end_date:
|
||||||
|
type: string
|
||||||
|
format: date
|
||||||
|
nullable: true
|
||||||
|
funding_source:
|
||||||
|
type: string
|
||||||
|
nullable: true
|
||||||
|
capital_exposure_cents:
|
||||||
|
type: integer
|
||||||
|
nullable: true
|
||||||
|
loan_amount_cents:
|
||||||
|
type: integer
|
||||||
|
nullable: true
|
||||||
|
loan_interest_rate_tenth_bps:
|
||||||
|
type: integer
|
||||||
|
nullable: true
|
||||||
|
trades:
|
||||||
|
type: array
|
||||||
|
items:
|
||||||
|
$ref: "#/components/schemas/TradeRead"
|
||||||
|
nullable: true
|
||||||
|
exchange:
|
||||||
|
$ref: "#/components/schemas/ExchangesRead"
|
||||||
|
nullable: true
|
||||||
|
CycleCreate:
|
||||||
|
allOf:
|
||||||
|
- $ref: "#/components/schemas/CycleBase"
|
||||||
|
- type: object
|
||||||
|
required:
|
||||||
|
- user_id
|
||||||
|
- symbol
|
||||||
|
- exchange_id
|
||||||
|
- underlying_currency
|
||||||
|
- start_date
|
||||||
|
properties:
|
||||||
|
user_id:
|
||||||
|
type: integer
|
||||||
|
symbol:
|
||||||
|
type: string
|
||||||
|
exchange_id:
|
||||||
|
type: integer
|
||||||
|
underlying_currency:
|
||||||
|
type: string
|
||||||
|
start_date:
|
||||||
|
type: string
|
||||||
|
format: date
|
||||||
|
CycleUpdate:
|
||||||
|
allOf:
|
||||||
|
- $ref: "#/components/schemas/CycleBase"
|
||||||
|
- type: object
|
||||||
|
required:
|
||||||
|
- id
|
||||||
|
properties:
|
||||||
|
id:
|
||||||
|
type: integer
|
||||||
|
CycleRead:
|
||||||
|
allOf:
|
||||||
|
- $ref: "#/components/schemas/CycleCreate"
|
||||||
|
- type: object
|
||||||
|
required:
|
||||||
|
- id
|
||||||
|
properties:
|
||||||
|
id:
|
||||||
|
type: integer
|
||||||
|
TradeBase:
|
||||||
|
type: object
|
||||||
|
required:
|
||||||
|
- symbol
|
||||||
|
- underlying_currency
|
||||||
|
- trade_type
|
||||||
|
- trade_strategy
|
||||||
|
- trade_date
|
||||||
|
- quantity
|
||||||
|
- price_cents
|
||||||
|
- commission_cents
|
||||||
|
properties:
|
||||||
|
friendly_name:
|
||||||
|
type: string
|
||||||
|
nullable: true
|
||||||
|
symbol:
|
||||||
|
type: string
|
||||||
|
exchange_id:
|
||||||
|
type: integer
|
||||||
|
underlying_currency:
|
||||||
|
type: string
|
||||||
|
trade_type:
|
||||||
|
type: string
|
||||||
|
trade_strategy:
|
||||||
|
type: string
|
||||||
|
trade_date:
|
||||||
|
type: string
|
||||||
|
format: date
|
||||||
|
quantity:
|
||||||
|
type: integer
|
||||||
|
price_cents:
|
||||||
|
type: integer
|
||||||
|
commission_cents:
|
||||||
|
type: integer
|
||||||
|
notes:
|
||||||
|
type: string
|
||||||
|
nullable: true
|
||||||
|
cycle_id:
|
||||||
|
type: integer
|
||||||
|
nullable: true
|
||||||
|
TradeCreate:
|
||||||
|
allOf:
|
||||||
|
- $ref: "#/components/schemas/TradeBase"
|
||||||
|
- type: object
|
||||||
|
properties:
|
||||||
|
user_id:
|
||||||
|
type: integer
|
||||||
|
nullable: true
|
||||||
|
trade_time_utc:
|
||||||
|
type: string
|
||||||
|
format: date-time
|
||||||
|
nullable: true
|
||||||
|
gross_cash_flow_cents:
|
||||||
|
type: integer
|
||||||
|
nullable: true
|
||||||
|
net_cash_flow_cents:
|
||||||
|
type: integer
|
||||||
|
nullable: true
|
||||||
|
quantity_multiplier:
|
||||||
|
type: integer
|
||||||
|
default: 1
|
||||||
|
expiry_date:
|
||||||
|
type: string
|
||||||
|
format: date
|
||||||
|
nullable: true
|
||||||
|
strike_price_cents:
|
||||||
|
type: integer
|
||||||
|
nullable: true
|
||||||
|
is_invalidated:
|
||||||
|
type: boolean
|
||||||
|
default: false
|
||||||
|
invalidated_at:
|
||||||
|
type: string
|
||||||
|
format: date-time
|
||||||
|
nullable: true
|
||||||
|
replaced_by_trade_id:
|
||||||
|
type: integer
|
||||||
|
nullable: true
|
||||||
|
TradeNoteUpdate:
|
||||||
|
type: object
|
||||||
|
required:
|
||||||
|
- id
|
||||||
|
properties:
|
||||||
|
id:
|
||||||
|
type: integer
|
||||||
|
notes:
|
||||||
|
type: string
|
||||||
|
nullable: true
|
||||||
|
TradeFriendlyNameUpdate:
|
||||||
|
type: object
|
||||||
|
required:
|
||||||
|
- id
|
||||||
|
- friendly_name
|
||||||
|
properties:
|
||||||
|
id:
|
||||||
|
type: integer
|
||||||
|
friendly_name:
|
||||||
|
type: string
|
||||||
|
TradeRead:
|
||||||
|
allOf:
|
||||||
|
- $ref: "#/components/schemas/TradeCreate"
|
||||||
|
- type: object
|
||||||
|
required:
|
||||||
|
- id
|
||||||
|
properties:
|
||||||
|
id:
|
||||||
|
type: integer
|
||||||
|
paths:
|
||||||
|
/status:
|
||||||
|
get:
|
||||||
|
summary: "Get API status"
|
||||||
|
security: [] # no auth required
|
||||||
|
responses:
|
||||||
|
"200":
|
||||||
|
description: OK
|
||||||
|
content:
|
||||||
|
application/json:
|
||||||
|
schema:
|
||||||
|
type: object
|
||||||
|
properties:
|
||||||
|
status:
|
||||||
|
type: string
|
||||||
|
/register:
|
||||||
|
post:
|
||||||
|
summary: "Register user"
|
||||||
|
security: [] # no auth required
|
||||||
|
requestBody:
|
||||||
|
required: true
|
||||||
|
content:
|
||||||
|
application/json:
|
||||||
|
schema:
|
||||||
|
$ref: "#/components/schemas/UserCreate"
|
||||||
|
responses:
|
||||||
|
"201":
|
||||||
|
description: Created
|
||||||
|
content:
|
||||||
|
application/json:
|
||||||
|
schema:
|
||||||
|
$ref: "#/components/schemas/UserRead"
|
||||||
|
"400":
|
||||||
|
description: Bad Request (user exists)
|
||||||
|
"500":
|
||||||
|
description: Internal Server Error
|
||||||
|
/login:
|
||||||
|
post:
|
||||||
|
summary: "Login"
|
||||||
|
security: [] # no auth required
|
||||||
|
requestBody:
|
||||||
|
required: true
|
||||||
|
content:
|
||||||
|
application/json:
|
||||||
|
schema:
|
||||||
|
$ref: "#/components/schemas/UserLogin"
|
||||||
|
responses:
|
||||||
|
"200":
|
||||||
|
description: OK (sets session cookie)
|
||||||
|
content:
|
||||||
|
application/json:
|
||||||
|
schema:
|
||||||
|
$ref: "#/components/schemas/SessionsBase"
|
||||||
|
headers:
|
||||||
|
Set-Cookie:
|
||||||
|
description: session cookie
|
||||||
|
schema:
|
||||||
|
type: string
|
||||||
|
"401":
|
||||||
|
description: Unauthorized
|
||||||
|
"500":
|
||||||
|
description: Internal Server Error
|
||||||
|
/exchanges:
|
||||||
|
post:
|
||||||
|
summary: "Create exchange"
|
||||||
|
security:
|
||||||
|
- session_cookie: []
|
||||||
|
requestBody:
|
||||||
|
required: true
|
||||||
|
content:
|
||||||
|
application/json:
|
||||||
|
schema:
|
||||||
|
$ref: "#/components/schemas/ExchangesBase"
|
||||||
|
responses:
|
||||||
|
"201":
|
||||||
|
description: Created
|
||||||
|
content:
|
||||||
|
application/json:
|
||||||
|
schema:
|
||||||
|
$ref: "#/components/schemas/ExchangesRead"
|
||||||
|
"400":
|
||||||
|
description: Bad Request
|
||||||
|
"401":
|
||||||
|
description: Unauthorized
|
||||||
|
get:
|
||||||
|
summary: "List user exchanges"
|
||||||
|
security:
|
||||||
|
- session_cookie: []
|
||||||
|
responses:
|
||||||
|
"200":
|
||||||
|
description: OK
|
||||||
|
content:
|
||||||
|
application/json:
|
||||||
|
schema:
|
||||||
|
type: array
|
||||||
|
items:
|
||||||
|
$ref: "#/components/schemas/ExchangesRead"
|
||||||
|
"401":
|
||||||
|
description: Unauthorized
|
||||||
|
/exchanges/{exchange_id}:
|
||||||
|
patch:
|
||||||
|
summary: "Update exchange"
|
||||||
|
security:
|
||||||
|
- session_cookie: []
|
||||||
|
parameters:
|
||||||
|
- name: exchange_id
|
||||||
|
in: path
|
||||||
|
required: true
|
||||||
|
schema:
|
||||||
|
type: integer
|
||||||
|
requestBody:
|
||||||
|
required: true
|
||||||
|
content:
|
||||||
|
application/json:
|
||||||
|
schema:
|
||||||
|
$ref: "#/components/schemas/ExchangesBase"
|
||||||
|
responses:
|
||||||
|
"200":
|
||||||
|
description: Updated
|
||||||
|
content:
|
||||||
|
application/json:
|
||||||
|
schema:
|
||||||
|
$ref: "#/components/schemas/ExchangesRead"
|
||||||
|
"404":
|
||||||
|
description: Not found
|
||||||
|
"400":
|
||||||
|
description: Bad request
|
||||||
|
/cycles:
|
||||||
|
post:
|
||||||
|
summary: "Create cycle (currently returns 405 in code)"
|
||||||
|
security:
|
||||||
|
- session_cookie: []
|
||||||
|
requestBody:
|
||||||
|
required: true
|
||||||
|
content:
|
||||||
|
application/json:
|
||||||
|
schema:
|
||||||
|
$ref: "#/components/schemas/CycleBase"
|
||||||
|
responses:
|
||||||
|
"405":
|
||||||
|
description: Method not allowed (app currently returns 405)
|
||||||
|
patch:
|
||||||
|
summary: "Update cycle"
|
||||||
|
security:
|
||||||
|
- session_cookie: []
|
||||||
|
requestBody:
|
||||||
|
required: true
|
||||||
|
content:
|
||||||
|
application/json:
|
||||||
|
schema:
|
||||||
|
$ref: "#/components/schemas/CycleUpdate"
|
||||||
|
responses:
|
||||||
|
"200":
|
||||||
|
description: Updated
|
||||||
|
content:
|
||||||
|
application/json:
|
||||||
|
schema:
|
||||||
|
$ref: "#/components/schemas/CycleRead"
|
||||||
|
"400":
|
||||||
|
description: Invalid data
|
||||||
|
"404":
|
||||||
|
description: Not found
|
||||||
|
/cycles/{cycle_id}:
|
||||||
|
get:
|
||||||
|
summary: "Get cycle by id"
|
||||||
|
security:
|
||||||
|
- session_cookie: []
|
||||||
|
parameters:
|
||||||
|
- name: cycle_id
|
||||||
|
in: path
|
||||||
|
required: true
|
||||||
|
schema:
|
||||||
|
type: integer
|
||||||
|
responses:
|
||||||
|
"200":
|
||||||
|
description: OK
|
||||||
|
content:
|
||||||
|
application/json:
|
||||||
|
schema:
|
||||||
|
$ref: "#/components/schemas/CycleRead"
|
||||||
|
"404":
|
||||||
|
description: Not found
|
||||||
|
/cycles/user/{user_id}:
|
||||||
|
get:
|
||||||
|
summary: "Get cycles by user id"
|
||||||
|
security:
|
||||||
|
- session_cookie: []
|
||||||
|
parameters:
|
||||||
|
- name: user_id
|
||||||
|
in: path
|
||||||
|
required: true
|
||||||
|
schema:
|
||||||
|
type: integer
|
||||||
|
responses:
|
||||||
|
"200":
|
||||||
|
description: OK
|
||||||
|
content:
|
||||||
|
application/json:
|
||||||
|
schema:
|
||||||
|
type: array
|
||||||
|
items:
|
||||||
|
$ref: "#/components/schemas/CycleRead"
|
||||||
|
/trades:
|
||||||
|
post:
|
||||||
|
summary: "Create trade"
|
||||||
|
security:
|
||||||
|
- session_cookie: []
|
||||||
|
requestBody:
|
||||||
|
required: true
|
||||||
|
content:
|
||||||
|
application/json:
|
||||||
|
schema:
|
||||||
|
$ref: "#/components/schemas/TradeCreate"
|
||||||
|
responses:
|
||||||
|
"201":
|
||||||
|
description: Created
|
||||||
|
content:
|
||||||
|
application/json:
|
||||||
|
schema:
|
||||||
|
$ref: "#/components/schemas/TradeRead"
|
||||||
|
"400":
|
||||||
|
description: Invalid trade data
|
||||||
|
"500":
|
||||||
|
description: Internal Server Error
|
||||||
|
/trades/{trade_id}:
|
||||||
|
get:
|
||||||
|
summary: "Get trade by id"
|
||||||
|
security:
|
||||||
|
- session_cookie: []
|
||||||
|
parameters:
|
||||||
|
- name: trade_id
|
||||||
|
in: path
|
||||||
|
required: true
|
||||||
|
schema:
|
||||||
|
type: integer
|
||||||
|
responses:
|
||||||
|
"200":
|
||||||
|
description: OK
|
||||||
|
content:
|
||||||
|
application/json:
|
||||||
|
schema:
|
||||||
|
$ref: "#/components/schemas/TradeRead"
|
||||||
|
"404":
|
||||||
|
description: Not found
|
||||||
|
/trades/friendlyname:
|
||||||
|
patch:
|
||||||
|
summary: "Update trade friendly name"
|
||||||
|
security:
|
||||||
|
- session_cookie: []
|
||||||
|
requestBody:
|
||||||
|
required: true
|
||||||
|
content:
|
||||||
|
application/json:
|
||||||
|
schema:
|
||||||
|
$ref: "#/components/schemas/TradeFriendlyNameUpdate"
|
||||||
|
responses:
|
||||||
|
"200":
|
||||||
|
description: Updated
|
||||||
|
content:
|
||||||
|
application/json:
|
||||||
|
schema:
|
||||||
|
$ref: "#/components/schemas/TradeRead"
|
||||||
|
"404":
|
||||||
|
description: Not found
|
||||||
|
/trades/notes:
|
||||||
|
patch:
|
||||||
|
summary: "Update trade notes"
|
||||||
|
security:
|
||||||
|
- session_cookie: []
|
||||||
|
requestBody:
|
||||||
|
required: true
|
||||||
|
content:
|
||||||
|
application/json:
|
||||||
|
schema:
|
||||||
|
$ref: "#/components/schemas/TradeNoteUpdate"
|
||||||
|
responses:
|
||||||
|
"200":
|
||||||
|
description: Updated
|
||||||
|
content:
|
||||||
|
application/json:
|
||||||
|
schema:
|
||||||
|
$ref: "#/components/schemas/TradeRead"
|
||||||
|
"404":
|
||||||
|
description: Not found
|
||||||
@@ -4,3 +4,4 @@ httpx
|
|||||||
pyyaml
|
pyyaml
|
||||||
pydantic-settings
|
pydantic-settings
|
||||||
sqlmodel
|
sqlmodel
|
||||||
|
argon2-cffi
|
||||||
@@ -14,12 +14,130 @@ anyio==4.10.0 \
|
|||||||
# via
|
# via
|
||||||
# httpx
|
# httpx
|
||||||
# starlette
|
# starlette
|
||||||
|
argon2-cffi==25.1.0 \
|
||||||
|
--hash=sha256:694ae5cc8a42f4c4e2bf2ca0e64e51e23a040c6a517a85074683d3959e1346c1 \
|
||||||
|
--hash=sha256:fdc8b074db390fccb6eb4a3604ae7231f219aa669a2652e0f20e16ba513d5741
|
||||||
|
# via -r requirements.in
|
||||||
|
argon2-cffi-bindings==25.1.0 \
|
||||||
|
--hash=sha256:1db89609c06afa1a214a69a462ea741cf735b29a57530478c06eb81dd403de99 \
|
||||||
|
--hash=sha256:1e021e87faa76ae0d413b619fe2b65ab9a037f24c60a1e6cc43457ae20de6dc6 \
|
||||||
|
--hash=sha256:21378b40e1b8d1655dd5310c84a40fc19a9aa5e6366e835ceb8576bf0fea716d \
|
||||||
|
--hash=sha256:2630b6240b495dfab90aebe159ff784d08ea999aa4b0d17efa734055a07d2f44 \
|
||||||
|
--hash=sha256:3c6702abc36bf3ccba3f802b799505def420a1b7039862014a65db3205967f5a \
|
||||||
|
--hash=sha256:3d3f05610594151994ca9ccb3c771115bdb4daef161976a266f0dd8aa9996b8f \
|
||||||
|
--hash=sha256:473bcb5f82924b1becbb637b63303ec8d10e84c8d241119419897a26116515d2 \
|
||||||
|
--hash=sha256:5acb4e41090d53f17ca1110c3427f0a130f944b896fc8c83973219c97f57b690 \
|
||||||
|
--hash=sha256:5d588dec224e2a83edbdc785a5e6f3c6cd736f46bfd4b441bbb5aa1f5085e584 \
|
||||||
|
--hash=sha256:6dca33a9859abf613e22733131fc9194091c1fa7cb3e131c143056b4856aa47e \
|
||||||
|
--hash=sha256:7aef0c91e2c0fbca6fc68e7555aa60ef7008a739cbe045541e438373bc54d2b0 \
|
||||||
|
--hash=sha256:84a461d4d84ae1295871329b346a97f68eade8c53b6ed9a7ca2d7467f3c8ff6f \
|
||||||
|
--hash=sha256:87c33a52407e4c41f3b70a9c2d3f6056d88b10dad7695be708c5021673f55623 \
|
||||||
|
--hash=sha256:8b8efee945193e667a396cbc7b4fb7d357297d6234d30a489905d96caabde56b \
|
||||||
|
--hash=sha256:a1c70058c6ab1e352304ac7e3b52554daadacd8d453c1752e547c76e9c99ac44 \
|
||||||
|
--hash=sha256:a98cd7d17e9f7ce244c0803cad3c23a7d379c301ba618a5fa76a67d116618b98 \
|
||||||
|
--hash=sha256:aecba1723ae35330a008418a91ea6cfcedf6d31e5fbaa056a166462ff066d500 \
|
||||||
|
--hash=sha256:b0fdbcf513833809c882823f98dc2f931cf659d9a1429616ac3adebb49f5db94 \
|
||||||
|
--hash=sha256:b55aec3565b65f56455eebc9b9f34130440404f27fe21c3b375bf1ea4d8fbae6 \
|
||||||
|
--hash=sha256:b957f3e6ea4d55d820e40ff76f450952807013d361a65d7f28acc0acbf29229d \
|
||||||
|
--hash=sha256:ba92837e4a9aa6a508c8d2d7883ed5a8f6c308c89a4790e1e447a220deb79a85 \
|
||||||
|
--hash=sha256:c4f9665de60b1b0e99bcd6be4f17d90339698ce954cfd8d9cf4f91c995165a92 \
|
||||||
|
--hash=sha256:c87b72589133f0346a1cb8d5ecca4b933e3c9b64656c9d175270a000e73b288d \
|
||||||
|
--hash=sha256:d3e924cfc503018a714f94a49a149fdc0b644eaead5d1f089330399134fa028a \
|
||||||
|
--hash=sha256:da0c79c23a63723aa5d782250fbf51b768abca630285262fb5144ba5ae01e520 \
|
||||||
|
--hash=sha256:e2fd3bfbff3c5d74fef31a722f729bf93500910db650c925c2d6ef879a7e51cb
|
||||||
|
# via argon2-cffi
|
||||||
certifi==2025.8.3 \
|
certifi==2025.8.3 \
|
||||||
--hash=sha256:e564105f78ded564e3ae7c923924435e1daa7463faeab5bb932bc53ffae63407 \
|
--hash=sha256:e564105f78ded564e3ae7c923924435e1daa7463faeab5bb932bc53ffae63407 \
|
||||||
--hash=sha256:f6c12493cfb1b06ba2ff328595af9350c65d6644968e5d3a2ffd78699af217a5
|
--hash=sha256:f6c12493cfb1b06ba2ff328595af9350c65d6644968e5d3a2ffd78699af217a5
|
||||||
# via
|
# via
|
||||||
# httpcore
|
# httpcore
|
||||||
# httpx
|
# httpx
|
||||||
|
cffi==2.0.0 \
|
||||||
|
--hash=sha256:00bdf7acc5f795150faa6957054fbbca2439db2f775ce831222b66f192f03beb \
|
||||||
|
--hash=sha256:07b271772c100085dd28b74fa0cd81c8fb1a3ba18b21e03d7c27f3436a10606b \
|
||||||
|
--hash=sha256:087067fa8953339c723661eda6b54bc98c5625757ea62e95eb4898ad5e776e9f \
|
||||||
|
--hash=sha256:0a1527a803f0a659de1af2e1fd700213caba79377e27e4693648c2923da066f9 \
|
||||||
|
--hash=sha256:0cf2d91ecc3fcc0625c2c530fe004f82c110405f101548512cce44322fa8ac44 \
|
||||||
|
--hash=sha256:0f6084a0ea23d05d20c3edcda20c3d006f9b6f3fefeac38f59262e10cef47ee2 \
|
||||||
|
--hash=sha256:12873ca6cb9b0f0d3a0da705d6086fe911591737a59f28b7936bdfed27c0d47c \
|
||||||
|
--hash=sha256:19f705ada2530c1167abacb171925dd886168931e0a7b78f5bffcae5c6b5be75 \
|
||||||
|
--hash=sha256:1cd13c99ce269b3ed80b417dcd591415d3372bcac067009b6e0f59c7d4015e65 \
|
||||||
|
--hash=sha256:1e3a615586f05fc4065a8b22b8152f0c1b00cdbc60596d187c2a74f9e3036e4e \
|
||||||
|
--hash=sha256:1f72fb8906754ac8a2cc3f9f5aaa298070652a0ffae577e0ea9bd480dc3c931a \
|
||||||
|
--hash=sha256:1fc9ea04857caf665289b7a75923f2c6ed559b8298a1b8c49e59f7dd95c8481e \
|
||||||
|
--hash=sha256:203a48d1fb583fc7d78a4c6655692963b860a417c0528492a6bc21f1aaefab25 \
|
||||||
|
--hash=sha256:2081580ebb843f759b9f617314a24ed5738c51d2aee65d31e02f6f7a2b97707a \
|
||||||
|
--hash=sha256:21d1152871b019407d8ac3985f6775c079416c282e431a4da6afe7aefd2bccbe \
|
||||||
|
--hash=sha256:24b6f81f1983e6df8db3adc38562c83f7d4a0c36162885ec7f7b77c7dcbec97b \
|
||||||
|
--hash=sha256:256f80b80ca3853f90c21b23ee78cd008713787b1b1e93eae9f3d6a7134abd91 \
|
||||||
|
--hash=sha256:28a3a209b96630bca57cce802da70c266eb08c6e97e5afd61a75611ee6c64592 \
|
||||||
|
--hash=sha256:2c8f814d84194c9ea681642fd164267891702542f028a15fc97d4674b6206187 \
|
||||||
|
--hash=sha256:2de9a304e27f7596cd03d16f1b7c72219bd944e99cc52b84d0145aefb07cbd3c \
|
||||||
|
--hash=sha256:38100abb9d1b1435bc4cc340bb4489635dc2f0da7456590877030c9b3d40b0c1 \
|
||||||
|
--hash=sha256:3925dd22fa2b7699ed2617149842d2e6adde22b262fcbfada50e3d195e4b3a94 \
|
||||||
|
--hash=sha256:3e17ed538242334bf70832644a32a7aae3d83b57567f9fd60a26257e992b79ba \
|
||||||
|
--hash=sha256:3e837e369566884707ddaf85fc1744b47575005c0a229de3327f8f9a20f4efeb \
|
||||||
|
--hash=sha256:3f4d46d8b35698056ec29bca21546e1551a205058ae1a181d871e278b0b28165 \
|
||||||
|
--hash=sha256:44d1b5909021139fe36001ae048dbdde8214afa20200eda0f64c068cac5d5529 \
|
||||||
|
--hash=sha256:45d5e886156860dc35862657e1494b9bae8dfa63bf56796f2fb56e1679fc0bca \
|
||||||
|
--hash=sha256:4647afc2f90d1ddd33441e5b0e85b16b12ddec4fca55f0d9671fef036ecca27c \
|
||||||
|
--hash=sha256:4671d9dd5ec934cb9a73e7ee9676f9362aba54f7f34910956b84d727b0d73fb6 \
|
||||||
|
--hash=sha256:53f77cbe57044e88bbd5ed26ac1d0514d2acf0591dd6bb02a3ae37f76811b80c \
|
||||||
|
--hash=sha256:5eda85d6d1879e692d546a078b44251cdd08dd1cfb98dfb77b670c97cee49ea0 \
|
||||||
|
--hash=sha256:5fed36fccc0612a53f1d4d9a816b50a36702c28a2aa880cb8a122b3466638743 \
|
||||||
|
--hash=sha256:61d028e90346df14fedc3d1e5441df818d095f3b87d286825dfcbd6459b7ef63 \
|
||||||
|
--hash=sha256:66f011380d0e49ed280c789fbd08ff0d40968ee7b665575489afa95c98196ab5 \
|
||||||
|
--hash=sha256:6824f87845e3396029f3820c206e459ccc91760e8fa24422f8b0c3d1731cbec5 \
|
||||||
|
--hash=sha256:6c6c373cfc5c83a975506110d17457138c8c63016b563cc9ed6e056a82f13ce4 \
|
||||||
|
--hash=sha256:6d02d6655b0e54f54c4ef0b94eb6be0607b70853c45ce98bd278dc7de718be5d \
|
||||||
|
--hash=sha256:6d50360be4546678fc1b79ffe7a66265e28667840010348dd69a314145807a1b \
|
||||||
|
--hash=sha256:730cacb21e1bdff3ce90babf007d0a0917cc3e6492f336c2f0134101e0944f93 \
|
||||||
|
--hash=sha256:737fe7d37e1a1bffe70bd5754ea763a62a066dc5913ca57e957824b72a85e205 \
|
||||||
|
--hash=sha256:74a03b9698e198d47562765773b4a8309919089150a0bb17d829ad7b44b60d27 \
|
||||||
|
--hash=sha256:7553fb2090d71822f02c629afe6042c299edf91ba1bf94951165613553984512 \
|
||||||
|
--hash=sha256:7a66c7204d8869299919db4d5069a82f1561581af12b11b3c9f48c584eb8743d \
|
||||||
|
--hash=sha256:7cc09976e8b56f8cebd752f7113ad07752461f48a58cbba644139015ac24954c \
|
||||||
|
--hash=sha256:81afed14892743bbe14dacb9e36d9e0e504cd204e0b165062c488942b9718037 \
|
||||||
|
--hash=sha256:8941aaadaf67246224cee8c3803777eed332a19d909b47e29c9842ef1e79ac26 \
|
||||||
|
--hash=sha256:89472c9762729b5ae1ad974b777416bfda4ac5642423fa93bd57a09204712322 \
|
||||||
|
--hash=sha256:8ea985900c5c95ce9db1745f7933eeef5d314f0565b27625d9a10ec9881e1bfb \
|
||||||
|
--hash=sha256:8eca2a813c1cb7ad4fb74d368c2ffbbb4789d377ee5bb8df98373c2cc0dee76c \
|
||||||
|
--hash=sha256:92b68146a71df78564e4ef48af17551a5ddd142e5190cdf2c5624d0c3ff5b2e8 \
|
||||||
|
--hash=sha256:9332088d75dc3241c702d852d4671613136d90fa6881da7d770a483fd05248b4 \
|
||||||
|
--hash=sha256:94698a9c5f91f9d138526b48fe26a199609544591f859c870d477351dc7b2414 \
|
||||||
|
--hash=sha256:9a67fc9e8eb39039280526379fb3a70023d77caec1852002b4da7e8b270c4dd9 \
|
||||||
|
--hash=sha256:9de40a7b0323d889cf8d23d1ef214f565ab154443c42737dfe52ff82cf857664 \
|
||||||
|
--hash=sha256:a05d0c237b3349096d3981b727493e22147f934b20f6f125a3eba8f994bec4a9 \
|
||||||
|
--hash=sha256:afb8db5439b81cf9c9d0c80404b60c3cc9c3add93e114dcae767f1477cb53775 \
|
||||||
|
--hash=sha256:b18a3ed7d5b3bd8d9ef7a8cb226502c6bf8308df1525e1cc676c3680e7176739 \
|
||||||
|
--hash=sha256:b1e74d11748e7e98e2f426ab176d4ed720a64412b6a15054378afdb71e0f37dc \
|
||||||
|
--hash=sha256:b21e08af67b8a103c71a250401c78d5e0893beff75e28c53c98f4de42f774062 \
|
||||||
|
--hash=sha256:b4c854ef3adc177950a8dfc81a86f5115d2abd545751a304c5bcf2c2c7283cfe \
|
||||||
|
--hash=sha256:b882b3df248017dba09d6b16defe9b5c407fe32fc7c65a9c69798e6175601be9 \
|
||||||
|
--hash=sha256:baf5215e0ab74c16e2dd324e8ec067ef59e41125d3eade2b863d294fd5035c92 \
|
||||||
|
--hash=sha256:c649e3a33450ec82378822b3dad03cc228b8f5963c0c12fc3b1e0ab940f768a5 \
|
||||||
|
--hash=sha256:c654de545946e0db659b3400168c9ad31b5d29593291482c43e3564effbcee13 \
|
||||||
|
--hash=sha256:c6638687455baf640e37344fe26d37c404db8b80d037c3d29f58fe8d1c3b194d \
|
||||||
|
--hash=sha256:c8d3b5532fc71b7a77c09192b4a5a200ea992702734a2e9279a37f2478236f26 \
|
||||||
|
--hash=sha256:cb527a79772e5ef98fb1d700678fe031e353e765d1ca2d409c92263c6d43e09f \
|
||||||
|
--hash=sha256:cf364028c016c03078a23b503f02058f1814320a56ad535686f90565636a9495 \
|
||||||
|
--hash=sha256:d48a880098c96020b02d5a1f7d9251308510ce8858940e6fa99ece33f610838b \
|
||||||
|
--hash=sha256:d68b6cef7827e8641e8ef16f4494edda8b36104d79773a334beaa1e3521430f6 \
|
||||||
|
--hash=sha256:d9b29c1f0ae438d5ee9acb31cadee00a58c46cc9c0b2f9038c6b0b3470877a8c \
|
||||||
|
--hash=sha256:d9b97165e8aed9272a6bb17c01e3cc5871a594a446ebedc996e2397a1c1ea8ef \
|
||||||
|
--hash=sha256:da68248800ad6320861f129cd9c1bf96ca849a2771a59e0344e88681905916f5 \
|
||||||
|
--hash=sha256:da902562c3e9c550df360bfa53c035b2f241fed6d9aef119048073680ace4a18 \
|
||||||
|
--hash=sha256:dbd5c7a25a7cb98f5ca55d258b103a2054f859a46ae11aaf23134f9cc0d356ad \
|
||||||
|
--hash=sha256:dd4f05f54a52fb558f1ba9f528228066954fee3ebe629fc1660d874d040ae5a3 \
|
||||||
|
--hash=sha256:de8dad4425a6ca6e4e5e297b27b5c824ecc7581910bf9aee86cb6835e6812aa7 \
|
||||||
|
--hash=sha256:e11e82b744887154b182fd3e7e8512418446501191994dbf9c9fc1f32cc8efd5 \
|
||||||
|
--hash=sha256:e6e73b9e02893c764e7e8d5bb5ce277f1a009cd5243f8228f75f842bf937c534 \
|
||||||
|
--hash=sha256:f73b96c41e3b2adedc34a7356e64c8eb96e03a3782b535e043a986276ce12a49 \
|
||||||
|
--hash=sha256:f93fd8e5c8c0a4aa1f424d6173f14a892044054871c771f8566e4008eaa359d2 \
|
||||||
|
--hash=sha256:fc33c5141b55ed366cfaad382df24fe7dcbc686de5be719b207bb248e3053dc5 \
|
||||||
|
--hash=sha256:fc7de24befaeae77ba923797c7c87834c73648a05a4bde34b3b7e5588973a453 \
|
||||||
|
--hash=sha256:fe562eb1a64e67dd297ccc4f5addea2501664954f2692b69a76449ec7913ecbf
|
||||||
|
# via argon2-cffi-bindings
|
||||||
click==8.2.1 \
|
click==8.2.1 \
|
||||||
--hash=sha256:27c491cc05d968d271d5a1db13e3b5a184636d9d930f148c50b038f0d0646202 \
|
--hash=sha256:27c491cc05d968d271d5a1db13e3b5a184636d9d930f148c50b038f0d0646202 \
|
||||||
--hash=sha256:61a3265b914e850b85317d0b3109c7f8cd35a670f963866005d6ef1d5175a12b
|
--hash=sha256:61a3265b914e850b85317d0b3109c7f8cd35a670f963866005d6ef1d5175a12b
|
||||||
@@ -104,6 +222,10 @@ idna==3.10 \
|
|||||||
# via
|
# via
|
||||||
# anyio
|
# anyio
|
||||||
# httpx
|
# httpx
|
||||||
|
pycparser==2.23 \
|
||||||
|
--hash=sha256:78816d4f24add8f10a06d6f05b4d424ad9e96cfebf68a4ddc99c65c0720d00c2 \
|
||||||
|
--hash=sha256:e5c6e8d3fbad53479cab09ac03729e0a9faf2bee3db8208a550daf5af81a5934
|
||||||
|
# via cffi
|
||||||
pydantic==2.11.7 \
|
pydantic==2.11.7 \
|
||||||
--hash=sha256:d989c3c6cb79469287b1569f7447a17848c998458d49ebe294e975b9baf0f0db \
|
--hash=sha256:d989c3c6cb79469287b1569f7447a17848c998458d49ebe294e975b9baf0f0db \
|
||||||
--hash=sha256:dde5df002701f6de26248661f6835bbe296a47bf73990135c7d07ce741b9623b
|
--hash=sha256:dde5df002701f6de26248661f6835bbe296a47bf73990135c7d07ce741b9623b
|
||||||
|
|||||||
@@ -13,8 +13,14 @@ ignore = [
|
|||||||
"TRY003",
|
"TRY003",
|
||||||
"EM101",
|
"EM101",
|
||||||
"EM102",
|
"EM102",
|
||||||
"PLC0405",
|
"SIM108",
|
||||||
|
"C901",
|
||||||
|
"PLR0912",
|
||||||
|
"PLR0915",
|
||||||
|
"PLR0913",
|
||||||
|
"PLC0415",
|
||||||
]
|
]
|
||||||
|
|
||||||
[lint.extend-per-file-ignores]
|
[lint.extend-per-file-ignores]
|
||||||
"test*.py" = ["S101"]
|
"test*.py" = ["S101", "S105", "S106", "PT011", "PLR2004"]
|
||||||
|
"models*.py" = ["FA102"]
|
||||||
|
|||||||
@@ -1,3 +1,5 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import os
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any
|
from typing import Any
|
||||||
@@ -12,6 +14,10 @@ class Settings(BaseSettings):
|
|||||||
port: int = 8000
|
port: int = 8000
|
||||||
workers: int = 1
|
workers: int = 1
|
||||||
log_level: str = "info"
|
log_level: str = "info"
|
||||||
|
database_url: str = "sqlite:///:memory:"
|
||||||
|
api_base: str = "/api/v1"
|
||||||
|
session_expiry_seconds: int = 3600 * 24 * 7 # 7 days
|
||||||
|
hmac_key: str | None = None
|
||||||
|
|
||||||
model_config = ConfigDict(env_file=".env", env_file_encoding="utf-8")
|
model_config = ConfigDict(env_file=".env", env_file_encoding="utf-8")
|
||||||
|
|
||||||
|
|||||||
56
backend/testhelpers/tradecycles.sh
Executable file
56
backend/testhelpers/tradecycles.sh
Executable file
@@ -0,0 +1,56 @@
|
|||||||
|
curl --location '127.0.0.1:18881/api/v1/trades' \
|
||||||
|
--header 'Content-Type: application/json' \
|
||||||
|
--header 'Cookie: session_token=uYsEZZdH9ecQ432HQUdfab292I14suk4GuI12-cAyuw' \
|
||||||
|
--data '{
|
||||||
|
"friendly_name": "20250908-CA-PUT",
|
||||||
|
"symbol": "CA",
|
||||||
|
"exchange_id": 1,
|
||||||
|
"underlying_currency": "EUR",
|
||||||
|
"trade_type": "SELL_PUT",
|
||||||
|
"trade_strategy": "WHEEL",
|
||||||
|
"trade_date": "2025-09-08",
|
||||||
|
"quantity": 1,
|
||||||
|
"quantity_multiplier": 100,
|
||||||
|
"price_cents": 17,
|
||||||
|
"expiry_date": "2025-09-09",
|
||||||
|
"strike_price_cents": 1220,
|
||||||
|
"commission_cents": 114
|
||||||
|
}'
|
||||||
|
|
||||||
|
curl --location '127.0.0.1:18881/api/v1/trades' \
|
||||||
|
--header 'Content-Type: application/json' \
|
||||||
|
--header 'Cookie: session_token=uYsEZZdH9ecQ432HQUdfab292I14suk4GuI12-cAyuw' \
|
||||||
|
--data '{
|
||||||
|
"friendly_name": "20250920-CA-ASSIGN",
|
||||||
|
"symbol": "CA",
|
||||||
|
"exchange_id": 1,
|
||||||
|
"cycle_id": 1,
|
||||||
|
"underlying_currency": "EUR",
|
||||||
|
"trade_type": "ASSIGNMENT",
|
||||||
|
"trade_strategy": "WHEEL",
|
||||||
|
"trade_date": "2025-09-20",
|
||||||
|
"quantity": 100,
|
||||||
|
"quantity_multiplier": 1,
|
||||||
|
"price_cents": 1220,
|
||||||
|
"commission_cents": 0
|
||||||
|
}'
|
||||||
|
|
||||||
|
curl --location '127.0.0.1:18881/api/v1/trades' \
|
||||||
|
--header 'Content-Type: application/json' \
|
||||||
|
--header 'Cookie: session_token=uYsEZZdH9ecQ432HQUdfab292I14suk4GuI12-cAyuw' \
|
||||||
|
--data '{
|
||||||
|
"friendly_name": "20250923-CA-CALL",
|
||||||
|
"symbol": "CA",
|
||||||
|
"exchange_id": 1,
|
||||||
|
"cycle_id": 1,
|
||||||
|
"underlying_currency": "EUR",
|
||||||
|
"trade_type": "SELL_CALL",
|
||||||
|
"trade_strategy": "WHEEL",
|
||||||
|
"trade_date": "2025-09-23",
|
||||||
|
"quantity": 1,
|
||||||
|
"quantity_multiplier": 100,
|
||||||
|
"price_cents": 31,
|
||||||
|
"expiry_date": "2025-10-10",
|
||||||
|
"strike_price_cents": 1200,
|
||||||
|
"commission_cents": 114
|
||||||
|
}'
|
||||||
405
backend/tests/test_app.py
Normal file
405
backend/tests/test_app.py
Normal file
@@ -0,0 +1,405 @@
|
|||||||
|
from collections.abc import Callable
|
||||||
|
from datetime import datetime, timedelta, timezone
|
||||||
|
from types import SimpleNamespace
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from fastapi import FastAPI, status
|
||||||
|
from fastapi.responses import JSONResponse
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
|
||||||
|
import settings
|
||||||
|
import trading_journal.service as svc
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def client_factory(monkeypatch: pytest.MonkeyPatch) -> Callable[..., TestClient]:
|
||||||
|
class NoAuth:
|
||||||
|
def __init__(self, app: FastAPI, **opts) -> None: # noqa: ANN003, ARG002
|
||||||
|
self.app = app
|
||||||
|
|
||||||
|
async def __call__(self, scope, receive, send) -> None: # noqa: ANN001
|
||||||
|
state = scope.get("state")
|
||||||
|
if state is None:
|
||||||
|
scope["state"] = SimpleNamespace()
|
||||||
|
scope["state"]["user_id"] = 1
|
||||||
|
await self.app(scope, receive, send)
|
||||||
|
|
||||||
|
class DeclineAuth:
|
||||||
|
def __init__(self, app: FastAPI, **opts) -> None: # noqa: ANN003, ARG002
|
||||||
|
self.app = app
|
||||||
|
|
||||||
|
async def __call__(self, scope, receive, send) -> None: # noqa: ANN001
|
||||||
|
if scope.get("type") != "http":
|
||||||
|
await self.app(scope, receive, send)
|
||||||
|
return
|
||||||
|
path = scope.get("path", "")
|
||||||
|
# allow public/exempt paths through
|
||||||
|
if getattr(svc, "EXCEPT_PATHS", []) and path in svc.EXCEPT_PATHS:
|
||||||
|
await self.app(scope, receive, send)
|
||||||
|
return
|
||||||
|
# immediately respond 401 for protected paths
|
||||||
|
resp = JSONResponse({"detail": "Unauthorized"}, status_code=status.HTTP_401_UNAUTHORIZED)
|
||||||
|
await resp(scope, receive, send)
|
||||||
|
|
||||||
|
def _factory(*, decline_auth: bool = False, **mocks: dict) -> TestClient:
|
||||||
|
defaults = {
|
||||||
|
"register_user_service": MagicMock(return_value=SimpleNamespace(model_dump=lambda: {"id": 1, "username": "mock"})),
|
||||||
|
"authenticate_user_service": MagicMock(
|
||||||
|
return_value=(SimpleNamespace(user_id=1, expires_at=(datetime.now(timezone.utc) + timedelta(hours=1))), "token"),
|
||||||
|
),
|
||||||
|
"create_exchange_service": MagicMock(
|
||||||
|
return_value=SimpleNamespace(model_dump=lambda: {"name": "Binance", "notes": "some note", "user_id": 1}),
|
||||||
|
),
|
||||||
|
"get_exchanges_by_user_service": MagicMock(return_value=[]),
|
||||||
|
}
|
||||||
|
|
||||||
|
if decline_auth:
|
||||||
|
monkeypatch.setattr(svc, "AuthMiddleWare", DeclineAuth)
|
||||||
|
else:
|
||||||
|
monkeypatch.setattr(svc, "AuthMiddleWare", NoAuth)
|
||||||
|
merged = {**defaults, **mocks}
|
||||||
|
for name, mock in merged.items():
|
||||||
|
monkeypatch.setattr(svc, name, mock)
|
||||||
|
import sys
|
||||||
|
|
||||||
|
if "app" in sys.modules:
|
||||||
|
del sys.modules["app"]
|
||||||
|
from importlib import import_module
|
||||||
|
|
||||||
|
app = import_module("app").app # re-import app module
|
||||||
|
|
||||||
|
return TestClient(app)
|
||||||
|
|
||||||
|
return _factory
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_status(client_factory: Callable[..., TestClient]) -> None:
|
||||||
|
client = client_factory()
|
||||||
|
with client as c:
|
||||||
|
response = c.get(f"{settings.settings.api_base}/status")
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json() == {"status": "ok"}
|
||||||
|
|
||||||
|
|
||||||
|
def test_register_success(client_factory: Callable[..., TestClient]) -> None:
|
||||||
|
client = client_factory() # use defaults
|
||||||
|
with client as c:
|
||||||
|
r = c.post(f"{settings.settings.api_base}/register", json={"username": "a", "password": "b"})
|
||||||
|
assert r.status_code == 201
|
||||||
|
|
||||||
|
|
||||||
|
def test_register_user_already_exists(client_factory: Callable[..., TestClient]) -> None:
|
||||||
|
client = client_factory(register_user_service=MagicMock(side_effect=svc.UserAlreadyExistsError("username already exists")))
|
||||||
|
with client as c:
|
||||||
|
r = c.post(f"{settings.settings.api_base}/register", json={"username": "a", "password": "b"})
|
||||||
|
assert r.status_code == status.HTTP_400_BAD_REQUEST
|
||||||
|
assert r.json() == {"detail": "username already exists"}
|
||||||
|
|
||||||
|
|
||||||
|
def test_register_user_internal_server_error(client_factory: Callable[..., TestClient]) -> None:
|
||||||
|
client = client_factory(register_user_service=MagicMock(side_effect=Exception("db is down")))
|
||||||
|
with client as c:
|
||||||
|
r = c.post(f"{settings.settings.api_base}/register", json={"username": "a", "password": "b"})
|
||||||
|
assert r.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
|
||||||
|
assert r.json() == {"detail": "Internal Server Error"}
|
||||||
|
|
||||||
|
|
||||||
|
def test_login_success(client_factory: Callable[..., TestClient]) -> None:
|
||||||
|
client = client_factory() # use defaults
|
||||||
|
with client as c:
|
||||||
|
r = c.post(f"{settings.settings.api_base}/login", json={"username": "a", "password": "b"})
|
||||||
|
assert r.status_code == 200
|
||||||
|
assert r.json() == {"user_id": 1}
|
||||||
|
assert r.cookies.get("session_token") == "token"
|
||||||
|
|
||||||
|
|
||||||
|
def test_login_failed_auth(client_factory: Callable[..., TestClient]) -> None:
|
||||||
|
client = client_factory(authenticate_user_service=MagicMock(return_value=None))
|
||||||
|
with client as c:
|
||||||
|
r = c.post(f"{settings.settings.api_base}/login", json={"username": "a", "password": "b"})
|
||||||
|
assert r.status_code == status.HTTP_401_UNAUTHORIZED
|
||||||
|
assert r.json() == {"detail": "Invalid username or password, or user doesn't exist"}
|
||||||
|
|
||||||
|
|
||||||
|
def test_login_internal_server_error(client_factory: Callable[..., TestClient]) -> None:
|
||||||
|
client = client_factory(authenticate_user_service=MagicMock(side_effect=Exception("db is down")))
|
||||||
|
with client as c:
|
||||||
|
r = c.post(f"{settings.settings.api_base}/login", json={"username": "a", "password": "b"})
|
||||||
|
assert r.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
|
||||||
|
assert r.json() == {"detail": "Internal Server Error"}
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_exchange_success(client_factory: Callable[..., TestClient]) -> None:
|
||||||
|
client = client_factory()
|
||||||
|
with client as c:
|
||||||
|
r = c.post(f"{settings.settings.api_base}/exchanges", json={"name": "Binance"})
|
||||||
|
assert r.status_code == 201
|
||||||
|
assert r.json() == {"user_id": 1, "name": "Binance", "notes": "some note"}
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_exchange_already_exists(client_factory: Callable[..., TestClient]) -> None:
|
||||||
|
client = client_factory(create_exchange_service=MagicMock(side_effect=svc.ExchangeAlreadyExistsError("exchange already exists")))
|
||||||
|
with client as c:
|
||||||
|
r = c.post(f"{settings.settings.api_base}/exchanges", json={"name": "Binance"})
|
||||||
|
assert r.status_code == status.HTTP_400_BAD_REQUEST
|
||||||
|
assert r.json() == {"detail": "exchange already exists"}
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_exchanges_unauthenticated(client_factory: Callable[..., TestClient]) -> None:
|
||||||
|
client = client_factory(decline_auth=True)
|
||||||
|
with client as c:
|
||||||
|
r = c.get(f"{settings.settings.api_base}/exchanges")
|
||||||
|
assert r.status_code == status.HTTP_401_UNAUTHORIZED
|
||||||
|
assert r.json() == {"detail": "Unauthorized"}
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_exchanges_success(client_factory: Callable[..., TestClient]) -> None:
|
||||||
|
client = client_factory()
|
||||||
|
with client as c:
|
||||||
|
r = c.get(f"{settings.settings.api_base}/exchanges")
|
||||||
|
assert r.status_code == 200
|
||||||
|
assert r.json() == []
|
||||||
|
|
||||||
|
|
||||||
|
def test_update_exchanges_success(client_factory: Callable[..., TestClient]) -> None:
|
||||||
|
client = client_factory(
|
||||||
|
update_exchanges_service=MagicMock(
|
||||||
|
return_value=SimpleNamespace(model_dump=lambda: {"user_id": 1, "name": "BinanceUS", "notes": "updated note"}),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
with client as c:
|
||||||
|
r = c.patch(f"{settings.settings.api_base}/exchanges/1", json={"name": "BinanceUS", "notes": "updated note"})
|
||||||
|
assert r.status_code == 200
|
||||||
|
assert r.json() == {"user_id": 1, "name": "BinanceUS", "notes": "updated note"}
|
||||||
|
|
||||||
|
|
||||||
|
def test_update_exchanges_not_found(client_factory: Callable[..., TestClient]) -> None:
|
||||||
|
client = client_factory(update_exchanges_service=MagicMock(side_effect=svc.ExchangeNotFoundError("exchange not found")))
|
||||||
|
with client as c:
|
||||||
|
r = c.patch(f"{settings.settings.api_base}/exchanges/999", json={"name": "NonExistent", "notes": "no note"})
|
||||||
|
assert r.status_code == status.HTTP_404_NOT_FOUND
|
||||||
|
assert r.json() == {"detail": "exchange not found"}
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_cycles_by_id_success(client_factory: Callable[..., TestClient]) -> None:
|
||||||
|
client = client_factory(
|
||||||
|
get_cycle_by_id_service=MagicMock(
|
||||||
|
return_value=SimpleNamespace(
|
||||||
|
friendly_name="Cycle 1",
|
||||||
|
status="active",
|
||||||
|
id=1,
|
||||||
|
),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
with client as c:
|
||||||
|
r = c.get(f"{settings.settings.api_base}/cycles/1")
|
||||||
|
assert r.status_code == 200
|
||||||
|
assert r.json() == {"id": 1, "friendly_name": "Cycle 1", "status": "active"}
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_cycles_by_id_not_found(client_factory: Callable[..., TestClient]) -> None:
|
||||||
|
client = client_factory(get_cycle_by_id_service=MagicMock(side_effect=svc.CycleNotFoundError("cycle not found")))
|
||||||
|
with client as c:
|
||||||
|
r = c.get(f"{settings.settings.api_base}/cycles/999")
|
||||||
|
assert r.status_code == status.HTTP_404_NOT_FOUND
|
||||||
|
assert r.json() == {"detail": "cycle not found"}
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_cycles_by_user_success(client_factory: Callable[..., TestClient]) -> None:
|
||||||
|
client = client_factory(
|
||||||
|
get_cycles_by_user_service=MagicMock(
|
||||||
|
return_value=[
|
||||||
|
SimpleNamespace(
|
||||||
|
friendly_name="Cycle 1",
|
||||||
|
status="active",
|
||||||
|
id=1,
|
||||||
|
),
|
||||||
|
SimpleNamespace(
|
||||||
|
friendly_name="Cycle 2",
|
||||||
|
status="completed",
|
||||||
|
id=2,
|
||||||
|
),
|
||||||
|
],
|
||||||
|
),
|
||||||
|
)
|
||||||
|
with client as c:
|
||||||
|
r = c.get(f"{settings.settings.api_base}/cycles/user/1")
|
||||||
|
assert r.status_code == 200
|
||||||
|
assert r.json() == [
|
||||||
|
{"id": 1, "friendly_name": "Cycle 1", "status": "active"},
|
||||||
|
{"id": 2, "friendly_name": "Cycle 2", "status": "completed"},
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def test_update_cycles_success(client_factory: Callable[..., TestClient]) -> None:
|
||||||
|
client = client_factory(
|
||||||
|
update_cycle_service=MagicMock(
|
||||||
|
return_value=SimpleNamespace(
|
||||||
|
friendly_name="Updated Cycle",
|
||||||
|
status="completed",
|
||||||
|
id=1,
|
||||||
|
),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
with client as c:
|
||||||
|
r = c.patch(f"{settings.settings.api_base}/cycles", json={"friendly_name": "Updated Cycle", "status": "completed", "id": 1})
|
||||||
|
assert r.status_code == 200
|
||||||
|
assert r.json() == {"id": 1, "friendly_name": "Updated Cycle", "status": "completed"}
|
||||||
|
|
||||||
|
|
||||||
|
def test_update_cycles_invalid_cycle_data(client_factory: Callable[..., TestClient]) -> None:
|
||||||
|
client = client_factory(
|
||||||
|
update_cycle_service=MagicMock(side_effect=svc.InvalidCycleDataError("invalid cycle data")),
|
||||||
|
)
|
||||||
|
with client as c:
|
||||||
|
r = c.patch(f"{settings.settings.api_base}/cycles", json={"friendly_name": "", "status": "unknown", "id": 1})
|
||||||
|
assert r.status_code == status.HTTP_400_BAD_REQUEST
|
||||||
|
assert r.json() == {"detail": "invalid cycle data"}
|
||||||
|
|
||||||
|
|
||||||
|
def test_update_cycles_not_found(client_factory: Callable[..., TestClient]) -> None:
|
||||||
|
client = client_factory(update_cycle_service=MagicMock(side_effect=svc.CycleNotFoundError("cycle not found")))
|
||||||
|
with client as c:
|
||||||
|
r = c.patch(f"{settings.settings.api_base}/cycles", json={"friendly_name": "NonExistent", "status": "active", "id": 999})
|
||||||
|
assert r.status_code == status.HTTP_404_NOT_FOUND
|
||||||
|
assert r.json() == {"detail": "cycle not found"}
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_trade_success(client_factory: Callable[..., TestClient]) -> None:
|
||||||
|
client = client_factory(
|
||||||
|
create_trade_service=MagicMock(
|
||||||
|
return_value=SimpleNamespace(),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
with client as c:
|
||||||
|
r = c.post(
|
||||||
|
f"{settings.settings.api_base}/trades",
|
||||||
|
json={
|
||||||
|
"cycle_id": 1,
|
||||||
|
"exchange_id": 1,
|
||||||
|
"symbol": "BTCUSD",
|
||||||
|
"underlying_currency": "USD",
|
||||||
|
"trade_type": "LONG_SPOT",
|
||||||
|
"trade_strategy": "FX",
|
||||||
|
"quantity": 1,
|
||||||
|
"price_cents": 15,
|
||||||
|
"commission_cents": 100,
|
||||||
|
"trade_date": "2025-10-01",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
assert r.status_code == 201
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_trade_invalid_trade_data(client_factory: Callable[..., TestClient]) -> None:
|
||||||
|
client = client_factory(
|
||||||
|
create_trade_service=MagicMock(side_effect=svc.InvalidTradeDataError("invalid trade data")),
|
||||||
|
)
|
||||||
|
with client as c:
|
||||||
|
r = c.post(
|
||||||
|
f"{settings.settings.api_base}/trades",
|
||||||
|
json={
|
||||||
|
"cycle_id": 1,
|
||||||
|
"exchange_id": 1,
|
||||||
|
"symbol": "BTCUSD",
|
||||||
|
"underlying_currency": "USD",
|
||||||
|
"trade_type": "LONG_SPOT",
|
||||||
|
"trade_strategy": "FX",
|
||||||
|
"quantity": 1,
|
||||||
|
"price_cents": 15,
|
||||||
|
"commission_cents": 100,
|
||||||
|
"trade_date": "2025-10-01",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
assert r.status_code == status.HTTP_400_BAD_REQUEST
|
||||||
|
assert r.json() == {"detail": "invalid trade data"}
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_trade_by_id_success(client_factory: Callable[..., TestClient]) -> None:
|
||||||
|
client = client_factory(
|
||||||
|
get_trade_by_id_service=MagicMock(
|
||||||
|
return_value=SimpleNamespace(
|
||||||
|
id=1,
|
||||||
|
cycle_id=1,
|
||||||
|
exchange_id=1,
|
||||||
|
symbol="BTCUSD",
|
||||||
|
underlying_currency="USD",
|
||||||
|
trade_type="LONG_SPOT",
|
||||||
|
trade_strategy="FX",
|
||||||
|
quantity=1,
|
||||||
|
price_cents=1500,
|
||||||
|
commission_cents=100,
|
||||||
|
trade_date=datetime(2025, 10, 1, tzinfo=timezone.utc),
|
||||||
|
),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
with client as c:
|
||||||
|
r = c.get(f"{settings.settings.api_base}/trades/1")
|
||||||
|
assert r.status_code == 200
|
||||||
|
assert r.json() == {
|
||||||
|
"id": 1,
|
||||||
|
"cycle_id": 1,
|
||||||
|
"exchange_id": 1,
|
||||||
|
"symbol": "BTCUSD",
|
||||||
|
"underlying_currency": "USD",
|
||||||
|
"trade_type": "LONG_SPOT",
|
||||||
|
"trade_strategy": "FX",
|
||||||
|
"quantity": 1,
|
||||||
|
"price_cents": 1500,
|
||||||
|
"commission_cents": 100,
|
||||||
|
"trade_date": "2025-10-01T00:00:00+00:00",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_trade_by_id_not_found(client_factory: Callable[..., TestClient]) -> None:
|
||||||
|
client = client_factory(get_trade_by_id_service=MagicMock(side_effect=svc.TradeNotFoundError("trade not found")))
|
||||||
|
with client as c:
|
||||||
|
r = c.get(f"{settings.settings.api_base}/trades/999")
|
||||||
|
assert r.status_code == status.HTTP_404_NOT_FOUND
|
||||||
|
assert r.json() == {"detail": "trade not found"}
|
||||||
|
|
||||||
|
|
||||||
|
def test_update_trade_friendly_name_success(client_factory: Callable[..., TestClient]) -> None:
|
||||||
|
client = client_factory(
|
||||||
|
update_trade_friendly_name_service=MagicMock(
|
||||||
|
return_value=SimpleNamespace(
|
||||||
|
id=1,
|
||||||
|
friendly_name="Updated Trade Name",
|
||||||
|
),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
with client as c:
|
||||||
|
r = c.patch(f"{settings.settings.api_base}/trades/friendlyname", json={"id": 1, "friendly_name": "Updated Trade Name"})
|
||||||
|
assert r.status_code == 200
|
||||||
|
assert r.json() == {"id": 1, "friendly_name": "Updated Trade Name"}
|
||||||
|
|
||||||
|
|
||||||
|
def test_update_trade_friendly_name_not_found(client_factory: Callable[..., TestClient]) -> None:
|
||||||
|
client = client_factory(update_trade_friendly_name_service=MagicMock(side_effect=svc.TradeNotFoundError("trade not found")))
|
||||||
|
with client as c:
|
||||||
|
r = c.patch(f"{settings.settings.api_base}/trades/friendlyname", json={"id": 999, "friendly_name": "NonExistent Trade"})
|
||||||
|
assert r.status_code == status.HTTP_404_NOT_FOUND
|
||||||
|
assert r.json() == {"detail": "trade not found"}
|
||||||
|
|
||||||
|
|
||||||
|
def test_update_trade_note_success(client_factory: Callable[..., TestClient]) -> None:
|
||||||
|
client = client_factory(
|
||||||
|
update_trade_note_service=MagicMock(
|
||||||
|
return_value=SimpleNamespace(
|
||||||
|
id=1,
|
||||||
|
note="Updated trade note",
|
||||||
|
),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
with client as c:
|
||||||
|
r = c.patch(f"{settings.settings.api_base}/trades/notes", json={"id": 1, "note": "Updated trade note"})
|
||||||
|
assert r.status_code == 200
|
||||||
|
assert r.json() == {"id": 1, "note": "Updated trade note"}
|
||||||
|
|
||||||
|
|
||||||
|
def test_update_trade_note_not_found(client_factory: Callable[..., TestClient]) -> None:
|
||||||
|
client = client_factory(update_trade_note_service=MagicMock(side_effect=svc.TradeNotFoundError("trade not found")))
|
||||||
|
with client as c:
|
||||||
|
r = c.patch(f"{settings.settings.api_base}/trades/notes", json={"id": 999, "note": "NonExistent Trade Note"})
|
||||||
|
assert r.status_code == status.HTTP_404_NOT_FOUND
|
||||||
|
assert r.json() == {"detail": "trade not found"}
|
||||||
@@ -1,15 +1,19 @@
|
|||||||
from collections.abc import Generator
|
from __future__ import annotations
|
||||||
|
|
||||||
from datetime import datetime, timedelta, timezone
|
from datetime import datetime, timedelta, timezone
|
||||||
|
from typing import TYPE_CHECKING, cast
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from sqlalchemy import create_engine
|
from sqlalchemy import create_engine
|
||||||
from sqlalchemy.engine import Engine
|
|
||||||
from sqlalchemy.pool import StaticPool
|
from sqlalchemy.pool import StaticPool
|
||||||
from sqlmodel import Session, SQLModel
|
from sqlmodel import Session, SQLModel
|
||||||
|
|
||||||
from trading_journal import crud, models
|
from trading_journal import crud, models
|
||||||
|
|
||||||
# TODO: If needed, add failing flow tests, but now only add happy flow.
|
if TYPE_CHECKING:
|
||||||
|
from collections.abc import Generator
|
||||||
|
|
||||||
|
from sqlalchemy.engine import Engine
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
@@ -29,8 +33,11 @@ def engine() -> Generator[Engine, None, None]:
|
|||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def session(engine: Engine) -> Generator[Session, None, None]:
|
def session(engine: Engine) -> Generator[Session, None, None]:
|
||||||
with Session(engine) as s:
|
session = Session(engine)
|
||||||
yield s
|
try:
|
||||||
|
yield session
|
||||||
|
finally:
|
||||||
|
session.close()
|
||||||
|
|
||||||
|
|
||||||
def make_user(session: Session, username: str = "testuser") -> int:
|
def make_user(session: Session, username: str = "testuser") -> int:
|
||||||
@@ -38,38 +45,47 @@ def make_user(session: Session, username: str = "testuser") -> int:
|
|||||||
session.add(user)
|
session.add(user)
|
||||||
session.commit()
|
session.commit()
|
||||||
session.refresh(user)
|
session.refresh(user)
|
||||||
return user.id
|
return cast("int", user.id)
|
||||||
|
|
||||||
|
|
||||||
def make_cycle(
|
def make_exchange(session: Session, user_id: int, name: str = "NASDAQ") -> int:
|
||||||
session: Session, user_id: int, friendly_name: str = "Test Cycle"
|
exchange = models.Exchanges(user_id=user_id, name=name, notes="Test exchange")
|
||||||
) -> int:
|
session.add(exchange)
|
||||||
|
session.commit()
|
||||||
|
session.refresh(exchange)
|
||||||
|
return cast("int", exchange.id)
|
||||||
|
|
||||||
|
|
||||||
|
def make_cycle(session: Session, user_id: int, exchange_id: int, friendly_name: str = "Test Cycle") -> int:
|
||||||
cycle = models.Cycles(
|
cycle = models.Cycles(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
friendly_name=friendly_name,
|
friendly_name=friendly_name,
|
||||||
symbol="AAPL",
|
symbol="AAPL",
|
||||||
|
exchange_id=exchange_id,
|
||||||
underlying_currency=models.UnderlyingCurrency.USD,
|
underlying_currency=models.UnderlyingCurrency.USD,
|
||||||
status=models.CycleStatus.OPEN,
|
status=models.CycleStatus.OPEN,
|
||||||
start_date=datetime.now().date(),
|
start_date=datetime.now(timezone.utc).date(),
|
||||||
)
|
) # type: ignore[arg-type]
|
||||||
session.add(cycle)
|
session.add(cycle)
|
||||||
session.commit()
|
session.commit()
|
||||||
session.refresh(cycle)
|
session.refresh(cycle)
|
||||||
return cycle.id
|
return cast("int", cycle.id)
|
||||||
|
|
||||||
|
|
||||||
def make_trade(
|
def make_trade(session: Session, user_id: int, cycle_id: int, friendly_name: str = "Test Trade") -> int:
|
||||||
session: Session, user_id: int, cycle_id: int, friendly_name: str = "Test Trade"
|
cycle: models.Cycles | None = session.get(models.Cycles, cycle_id)
|
||||||
) -> int:
|
assert cycle is not None
|
||||||
|
exchange_id = cycle.exchange_id
|
||||||
trade = models.Trades(
|
trade = models.Trades(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
friendly_name=friendly_name,
|
friendly_name=friendly_name,
|
||||||
symbol="AAPL",
|
symbol="AAPL",
|
||||||
|
exchange_id=exchange_id,
|
||||||
underlying_currency=models.UnderlyingCurrency.USD,
|
underlying_currency=models.UnderlyingCurrency.USD,
|
||||||
trade_type=models.TradeType.LONG_SPOT,
|
trade_type=models.TradeType.LONG_SPOT,
|
||||||
trade_strategy=models.TradeStrategy.SPOT,
|
trade_strategy=models.TradeStrategy.SPOT,
|
||||||
trade_date=datetime.now().date(),
|
trade_date=datetime.now(timezone.utc).date(),
|
||||||
trade_time_utc=datetime.now(),
|
trade_time_utc=datetime.now(timezone.utc),
|
||||||
quantity=10,
|
quantity=10,
|
||||||
price_cents=15000,
|
price_cents=15000,
|
||||||
gross_cash_flow_cents=-150000,
|
gross_cash_flow_cents=-150000,
|
||||||
@@ -81,7 +97,7 @@ def make_trade(
|
|||||||
session.add(trade)
|
session.add(trade)
|
||||||
session.commit()
|
session.commit()
|
||||||
session.refresh(trade)
|
session.refresh(trade)
|
||||||
return trade.id
|
return cast("int", trade.id)
|
||||||
|
|
||||||
|
|
||||||
def make_trade_by_trade_data(session: Session, trade_data: dict) -> int:
|
def make_trade_by_trade_data(session: Session, trade_data: dict) -> int:
|
||||||
@@ -89,7 +105,7 @@ def make_trade_by_trade_data(session: Session, trade_data: dict) -> int:
|
|||||||
session.add(trade)
|
session.add(trade)
|
||||||
session.commit()
|
session.commit()
|
||||||
session.refresh(trade)
|
session.refresh(trade)
|
||||||
return trade.id
|
return cast("int", trade.id)
|
||||||
|
|
||||||
|
|
||||||
def make_login_session(session: Session, created_at: datetime) -> models.Sessions:
|
def make_login_session(session: Session, created_at: datetime) -> models.Sessions:
|
||||||
@@ -113,9 +129,28 @@ def make_login_session(session: Session, created_at: datetime) -> models.Session
|
|||||||
return login_session
|
return login_session
|
||||||
|
|
||||||
|
|
||||||
def test_create_trade_success_with_cycle(session: Session):
|
def _ensure_utc_aware(dt: datetime | None) -> datetime | None:
|
||||||
|
if dt is None:
|
||||||
|
return None
|
||||||
|
if dt.tzinfo is None:
|
||||||
|
return dt.replace(tzinfo=timezone.utc)
|
||||||
|
return dt.astimezone(timezone.utc)
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_timestamp(actual: datetime, expected: datetime, tolerance: timedelta) -> None:
|
||||||
|
actual_utc = _ensure_utc_aware(actual)
|
||||||
|
expected_utc = _ensure_utc_aware(expected)
|
||||||
|
assert actual_utc is not None
|
||||||
|
assert expected_utc is not None
|
||||||
|
delta = abs(actual_utc - expected_utc)
|
||||||
|
assert delta <= tolerance, f"Timestamps differ by {delta}, which exceeds tolerance of {tolerance}"
|
||||||
|
|
||||||
|
|
||||||
|
# Trades
|
||||||
|
def test_create_trade_success_with_cycle(session: Session) -> None:
|
||||||
user_id = make_user(session)
|
user_id = make_user(session)
|
||||||
cycle_id = make_cycle(session, user_id)
|
exchange_id = make_exchange(session, user_id)
|
||||||
|
cycle_id = make_cycle(session, user_id, exchange_id)
|
||||||
|
|
||||||
trade_data = {
|
trade_data = {
|
||||||
"user_id": user_id,
|
"user_id": user_id,
|
||||||
@@ -124,7 +159,7 @@ def test_create_trade_success_with_cycle(session: Session):
|
|||||||
"underlying_currency": models.UnderlyingCurrency.USD,
|
"underlying_currency": models.UnderlyingCurrency.USD,
|
||||||
"trade_type": models.TradeType.LONG_SPOT,
|
"trade_type": models.TradeType.LONG_SPOT,
|
||||||
"trade_strategy": models.TradeStrategy.SPOT,
|
"trade_strategy": models.TradeStrategy.SPOT,
|
||||||
"trade_time_utc": datetime.now(),
|
"trade_time_utc": datetime.now(timezone.utc),
|
||||||
"quantity": 10,
|
"quantity": 10,
|
||||||
"price_cents": 15000,
|
"price_cents": 15000,
|
||||||
"gross_cash_flow_cents": -150000,
|
"gross_cash_flow_cents": -150000,
|
||||||
@@ -147,6 +182,7 @@ def test_create_trade_success_with_cycle(session: Session):
|
|||||||
assert actual_trade.trade_type == trade_data["trade_type"]
|
assert actual_trade.trade_type == trade_data["trade_type"]
|
||||||
assert actual_trade.trade_strategy == trade_data["trade_strategy"]
|
assert actual_trade.trade_strategy == trade_data["trade_strategy"]
|
||||||
assert actual_trade.quantity == trade_data["quantity"]
|
assert actual_trade.quantity == trade_data["quantity"]
|
||||||
|
assert actual_trade.quantity_multiplier == 1
|
||||||
assert actual_trade.price_cents == trade_data["price_cents"]
|
assert actual_trade.price_cents == trade_data["price_cents"]
|
||||||
assert actual_trade.gross_cash_flow_cents == trade_data["gross_cash_flow_cents"]
|
assert actual_trade.gross_cash_flow_cents == trade_data["gross_cash_flow_cents"]
|
||||||
assert actual_trade.commission_cents == trade_data["commission_cents"]
|
assert actual_trade.commission_cents == trade_data["commission_cents"]
|
||||||
@@ -154,19 +190,68 @@ def test_create_trade_success_with_cycle(session: Session):
|
|||||||
assert actual_trade.cycle_id == trade_data["cycle_id"]
|
assert actual_trade.cycle_id == trade_data["cycle_id"]
|
||||||
|
|
||||||
|
|
||||||
def test_create_trade_with_auto_created_cycle(session: Session):
|
def test_create_trade_with_custom_multipler(session: Session) -> None:
|
||||||
user_id = make_user(session)
|
user_id = make_user(session)
|
||||||
|
exchange_id = make_exchange(session, user_id)
|
||||||
|
cycle_id = make_cycle(session, user_id, exchange_id)
|
||||||
|
|
||||||
|
trade_data = {
|
||||||
|
"user_id": user_id,
|
||||||
|
"friendly_name": "Test Trade with Multiplier",
|
||||||
|
"symbol": "AAPL",
|
||||||
|
"underlying_currency": models.UnderlyingCurrency.USD,
|
||||||
|
"trade_type": models.TradeType.LONG_SPOT,
|
||||||
|
"trade_strategy": models.TradeStrategy.SPOT,
|
||||||
|
"trade_time_utc": datetime.now(timezone.utc),
|
||||||
|
"quantity": 10,
|
||||||
|
"quantity_multiplier": 100,
|
||||||
|
"price_cents": 15000,
|
||||||
|
"gross_cash_flow_cents": -1500000,
|
||||||
|
"commission_cents": 50000,
|
||||||
|
"net_cash_flow_cents": -1550000,
|
||||||
|
"cycle_id": cycle_id,
|
||||||
|
}
|
||||||
|
|
||||||
|
trade = crud.create_trade(session, trade_data)
|
||||||
|
assert trade.id is not None
|
||||||
|
assert trade.user_id == user_id
|
||||||
|
assert trade.cycle_id == cycle_id
|
||||||
|
session.refresh(trade)
|
||||||
|
|
||||||
|
actual_trade = session.get(models.Trades, trade.id)
|
||||||
|
assert actual_trade is not None
|
||||||
|
assert actual_trade.friendly_name == trade_data["friendly_name"]
|
||||||
|
assert actual_trade.symbol == trade_data["symbol"]
|
||||||
|
assert actual_trade.underlying_currency == trade_data["underlying_currency"]
|
||||||
|
assert actual_trade.trade_type == trade_data["trade_type"]
|
||||||
|
assert actual_trade.trade_strategy == trade_data["trade_strategy"]
|
||||||
|
assert actual_trade.quantity == trade_data["quantity"]
|
||||||
|
assert actual_trade.quantity_multiplier == trade_data["quantity_multiplier"]
|
||||||
|
assert actual_trade.price_cents == trade_data["price_cents"]
|
||||||
|
assert actual_trade.gross_cash_flow_cents == trade_data["gross_cash_flow_cents"]
|
||||||
|
assert actual_trade.commission_cents == trade_data["commission_cents"]
|
||||||
|
assert actual_trade.net_cash_flow_cents == trade_data["net_cash_flow_cents"]
|
||||||
|
assert actual_trade.cycle_id == trade_data["cycle_id"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_trade_with_auto_created_cycle(session: Session) -> None:
|
||||||
|
user_id = make_user(session)
|
||||||
|
exchange_id = make_exchange(session, user_id)
|
||||||
|
|
||||||
trade_data = {
|
trade_data = {
|
||||||
"user_id": user_id,
|
"user_id": user_id,
|
||||||
"friendly_name": "Test Trade with Auto Cycle",
|
"friendly_name": "Test Trade with Auto Cycle",
|
||||||
"symbol": "AAPL",
|
"symbol": "AAPL",
|
||||||
|
"exchange_id": exchange_id,
|
||||||
"underlying_currency": models.UnderlyingCurrency.USD,
|
"underlying_currency": models.UnderlyingCurrency.USD,
|
||||||
"trade_type": models.TradeType.LONG_SPOT,
|
"trade_type": models.TradeType.LONG_SPOT,
|
||||||
"trade_strategy": models.TradeStrategy.SPOT,
|
"trade_strategy": models.TradeStrategy.SPOT,
|
||||||
"trade_time_utc": datetime.now(),
|
"trade_time_utc": datetime.now(timezone.utc),
|
||||||
"quantity": 5,
|
"quantity": 5,
|
||||||
"price_cents": 15500,
|
"price_cents": 15500,
|
||||||
|
"gross_cash_flow_cents": -77500,
|
||||||
|
"commission_cents": 300,
|
||||||
|
"net_cash_flow_cents": -77800,
|
||||||
}
|
}
|
||||||
|
|
||||||
trade = crud.create_trade(session, trade_data)
|
trade = crud.create_trade(session, trade_data)
|
||||||
@@ -193,20 +278,22 @@ def test_create_trade_with_auto_created_cycle(session: Session):
|
|||||||
assert auto_cycle.symbol == trade_data["symbol"]
|
assert auto_cycle.symbol == trade_data["symbol"]
|
||||||
assert auto_cycle.underlying_currency == trade_data["underlying_currency"]
|
assert auto_cycle.underlying_currency == trade_data["underlying_currency"]
|
||||||
assert auto_cycle.status == models.CycleStatus.OPEN
|
assert auto_cycle.status == models.CycleStatus.OPEN
|
||||||
assert auto_cycle.friendly_name.startswith("Auto-created Cycle by trade")
|
assert auto_cycle.friendly_name.startswith("Auto-created Cycle by trade") # type: ignore[union-attr]
|
||||||
|
|
||||||
|
|
||||||
def test_create_trade_missing_required_fields(session: Session):
|
def test_create_trade_missing_required_fields(session: Session) -> None:
|
||||||
user_id = make_user(session)
|
user_id = make_user(session)
|
||||||
|
exchange_id = make_exchange(session, user_id)
|
||||||
|
|
||||||
base_trade_data = {
|
base_trade_data = {
|
||||||
"user_id": user_id,
|
"user_id": user_id,
|
||||||
"friendly_name": "Incomplete Trade",
|
"friendly_name": "Incomplete Trade",
|
||||||
"symbol": "AAPL",
|
"symbol": "AAPL",
|
||||||
|
"exchange_id": exchange_id,
|
||||||
"underlying_currency": models.UnderlyingCurrency.USD,
|
"underlying_currency": models.UnderlyingCurrency.USD,
|
||||||
"trade_type": models.TradeType.LONG_SPOT,
|
"trade_type": models.TradeType.LONG_SPOT,
|
||||||
"trade_strategy": models.TradeStrategy.SPOT,
|
"trade_strategy": models.TradeStrategy.SPOT,
|
||||||
"trade_time_utc": datetime.now(),
|
"trade_time_utc": datetime.now(timezone.utc),
|
||||||
"quantity": 10,
|
"quantity": 10,
|
||||||
"price_cents": 15000,
|
"price_cents": 15000,
|
||||||
}
|
}
|
||||||
@@ -218,6 +305,13 @@ def test_create_trade_missing_required_fields(session: Session):
|
|||||||
crud.create_trade(session, trade_data)
|
crud.create_trade(session, trade_data)
|
||||||
assert "symbol is required" in str(excinfo.value)
|
assert "symbol is required" in str(excinfo.value)
|
||||||
|
|
||||||
|
# Missing exchange and cycle together
|
||||||
|
trade_data = base_trade_data.copy()
|
||||||
|
trade_data.pop("exchange_id", None)
|
||||||
|
with pytest.raises(ValueError) as excinfo:
|
||||||
|
crud.create_trade(session, trade_data)
|
||||||
|
assert "exchange_id is required when no cycle is attached" in str(excinfo.value)
|
||||||
|
|
||||||
# Missing underlying_currency
|
# Missing underlying_currency
|
||||||
trade_data = base_trade_data.copy()
|
trade_data = base_trade_data.copy()
|
||||||
trade_data.pop("underlying_currency", None)
|
trade_data.pop("underlying_currency", None)
|
||||||
@@ -254,18 +348,20 @@ def test_create_trade_missing_required_fields(session: Session):
|
|||||||
assert "price_cents is required" in str(excinfo.value)
|
assert "price_cents is required" in str(excinfo.value)
|
||||||
|
|
||||||
|
|
||||||
def test_get_trade_by_id(session: Session):
|
def test_get_trade_by_id(session: Session) -> None:
|
||||||
user_id = make_user(session)
|
user_id = make_user(session)
|
||||||
cycle_id = make_cycle(session, user_id)
|
exchange_id = make_exchange(session, user_id)
|
||||||
|
cycle_id = make_cycle(session, user_id, exchange_id)
|
||||||
trade_data = {
|
trade_data = {
|
||||||
"user_id": user_id,
|
"user_id": user_id,
|
||||||
"friendly_name": "Test Trade for Get",
|
"friendly_name": "Test Trade for Get",
|
||||||
"symbol": "AAPL",
|
"symbol": "AAPL",
|
||||||
|
"exchange_id": exchange_id,
|
||||||
"underlying_currency": models.UnderlyingCurrency.USD,
|
"underlying_currency": models.UnderlyingCurrency.USD,
|
||||||
"trade_type": models.TradeType.LONG_SPOT,
|
"trade_type": models.TradeType.LONG_SPOT,
|
||||||
"trade_strategy": models.TradeStrategy.SPOT,
|
"trade_strategy": models.TradeStrategy.SPOT,
|
||||||
"trade_date": datetime.now().date(),
|
"trade_date": datetime.now(timezone.utc).date(),
|
||||||
"trade_time_utc": datetime.now(),
|
"trade_time_utc": datetime.now(timezone.utc),
|
||||||
"quantity": 10,
|
"quantity": 10,
|
||||||
"price_cents": 15000,
|
"price_cents": 15000,
|
||||||
"gross_cash_flow_cents": -150000,
|
"gross_cash_flow_cents": -150000,
|
||||||
@@ -291,19 +387,21 @@ def test_get_trade_by_id(session: Session):
|
|||||||
assert trade.trade_date == trade_data["trade_date"]
|
assert trade.trade_date == trade_data["trade_date"]
|
||||||
|
|
||||||
|
|
||||||
def test_get_trade_by_user_id_and_friendly_name(session: Session):
|
def test_get_trade_by_user_id_and_friendly_name(session: Session) -> None:
|
||||||
user_id = make_user(session)
|
user_id = make_user(session)
|
||||||
cycle_id = make_cycle(session, user_id)
|
exchange_id = make_exchange(session, user_id)
|
||||||
|
cycle_id = make_cycle(session, user_id, exchange_id)
|
||||||
friendly_name = "Unique Trade Name"
|
friendly_name = "Unique Trade Name"
|
||||||
trade_data = {
|
trade_data = {
|
||||||
"user_id": user_id,
|
"user_id": user_id,
|
||||||
"friendly_name": friendly_name,
|
"friendly_name": friendly_name,
|
||||||
"symbol": "AAPL",
|
"symbol": "AAPL",
|
||||||
|
"exchange_id": exchange_id,
|
||||||
"underlying_currency": models.UnderlyingCurrency.USD,
|
"underlying_currency": models.UnderlyingCurrency.USD,
|
||||||
"trade_type": models.TradeType.LONG_SPOT,
|
"trade_type": models.TradeType.LONG_SPOT,
|
||||||
"trade_strategy": models.TradeStrategy.SPOT,
|
"trade_strategy": models.TradeStrategy.SPOT,
|
||||||
"trade_date": datetime.now().date(),
|
"trade_date": datetime.now(timezone.utc).date(),
|
||||||
"trade_time_utc": datetime.now(),
|
"trade_time_utc": datetime.now(timezone.utc),
|
||||||
"quantity": 10,
|
"quantity": 10,
|
||||||
"price_cents": 15000,
|
"price_cents": 15000,
|
||||||
"gross_cash_flow_cents": -150000,
|
"gross_cash_flow_cents": -150000,
|
||||||
@@ -318,18 +416,20 @@ def test_get_trade_by_user_id_and_friendly_name(session: Session):
|
|||||||
assert trade.user_id == user_id
|
assert trade.user_id == user_id
|
||||||
|
|
||||||
|
|
||||||
def test_get_trades_by_user_id(session: Session):
|
def test_get_trades_by_user_id(session: Session) -> None:
|
||||||
user_id = make_user(session)
|
user_id = make_user(session)
|
||||||
cycle_id = make_cycle(session, user_id)
|
exchange_id = make_exchange(session, user_id)
|
||||||
|
cycle_id = make_cycle(session, user_id, exchange_id)
|
||||||
trade_data_1 = {
|
trade_data_1 = {
|
||||||
"user_id": user_id,
|
"user_id": user_id,
|
||||||
"friendly_name": "Trade One",
|
"friendly_name": "Trade One",
|
||||||
"symbol": "AAPL",
|
"symbol": "AAPL",
|
||||||
|
"exchange_id": exchange_id,
|
||||||
"underlying_currency": models.UnderlyingCurrency.USD,
|
"underlying_currency": models.UnderlyingCurrency.USD,
|
||||||
"trade_type": models.TradeType.LONG_SPOT,
|
"trade_type": models.TradeType.LONG_SPOT,
|
||||||
"trade_strategy": models.TradeStrategy.SPOT,
|
"trade_strategy": models.TradeStrategy.SPOT,
|
||||||
"trade_date": datetime.now().date(),
|
"trade_date": datetime.now(timezone.utc).date(),
|
||||||
"trade_time_utc": datetime.now(),
|
"trade_time_utc": datetime.now(timezone.utc),
|
||||||
"quantity": 10,
|
"quantity": 10,
|
||||||
"price_cents": 15000,
|
"price_cents": 15000,
|
||||||
"gross_cash_flow_cents": -150000,
|
"gross_cash_flow_cents": -150000,
|
||||||
@@ -341,11 +441,12 @@ def test_get_trades_by_user_id(session: Session):
|
|||||||
"user_id": user_id,
|
"user_id": user_id,
|
||||||
"friendly_name": "Trade Two",
|
"friendly_name": "Trade Two",
|
||||||
"symbol": "GOOGL",
|
"symbol": "GOOGL",
|
||||||
|
"exchange_id": exchange_id,
|
||||||
"underlying_currency": models.UnderlyingCurrency.USD,
|
"underlying_currency": models.UnderlyingCurrency.USD,
|
||||||
"trade_type": models.TradeType.SHORT_SPOT,
|
"trade_type": models.TradeType.SHORT_SPOT,
|
||||||
"trade_strategy": models.TradeStrategy.SPOT,
|
"trade_strategy": models.TradeStrategy.SPOT,
|
||||||
"trade_date": datetime.now().date(),
|
"trade_date": datetime.now(timezone.utc).date(),
|
||||||
"trade_time_utc": datetime.now(),
|
"trade_time_utc": datetime.now(timezone.utc),
|
||||||
"quantity": 5,
|
"quantity": 5,
|
||||||
"price_cents": 280000,
|
"price_cents": 280000,
|
||||||
"gross_cash_flow_cents": 1400000,
|
"gross_cash_flow_cents": 1400000,
|
||||||
@@ -362,9 +463,28 @@ def test_get_trades_by_user_id(session: Session):
|
|||||||
assert friendly_names == {"Trade One", "Trade Two"}
|
assert friendly_names == {"Trade One", "Trade Two"}
|
||||||
|
|
||||||
|
|
||||||
def test_update_trade_note(session: Session):
|
def test_update_trade_friendly_name(session: Session) -> None:
|
||||||
user_id = make_user(session)
|
user_id = make_user(session)
|
||||||
cycle_id = make_cycle(session, user_id)
|
exchange_id = make_exchange(session, user_id)
|
||||||
|
cycle_id = make_cycle(session, user_id, exchange_id)
|
||||||
|
trade_id = make_trade(session, user_id, cycle_id)
|
||||||
|
|
||||||
|
new_friendly_name = "Updated Trade Name"
|
||||||
|
updated_trade = crud.update_trade_friendly_name(session, trade_id, new_friendly_name)
|
||||||
|
assert updated_trade is not None
|
||||||
|
assert updated_trade.id == trade_id
|
||||||
|
assert updated_trade.friendly_name == new_friendly_name
|
||||||
|
|
||||||
|
session.refresh(updated_trade)
|
||||||
|
actual_trade = session.get(models.Trades, trade_id)
|
||||||
|
assert actual_trade is not None
|
||||||
|
assert actual_trade.friendly_name == new_friendly_name
|
||||||
|
|
||||||
|
|
||||||
|
def test_update_trade_note(session: Session) -> None:
|
||||||
|
user_id = make_user(session)
|
||||||
|
exchange_id = make_exchange(session, user_id)
|
||||||
|
cycle_id = make_cycle(session, user_id, exchange_id)
|
||||||
trade_id = make_trade(session, user_id, cycle_id)
|
trade_id = make_trade(session, user_id, cycle_id)
|
||||||
|
|
||||||
new_note = "This is an updated note."
|
new_note = "This is an updated note."
|
||||||
@@ -379,9 +499,10 @@ def test_update_trade_note(session: Session):
|
|||||||
assert actual_trade.notes == new_note
|
assert actual_trade.notes == new_note
|
||||||
|
|
||||||
|
|
||||||
def test_invalidate_trade(session: Session):
|
def test_invalidate_trade(session: Session) -> None:
|
||||||
user_id = make_user(session)
|
user_id = make_user(session)
|
||||||
cycle_id = make_cycle(session, user_id)
|
exchange_id = make_exchange(session, user_id)
|
||||||
|
cycle_id = make_cycle(session, user_id, exchange_id)
|
||||||
trade_id = make_trade(session, user_id, cycle_id)
|
trade_id = make_trade(session, user_id, cycle_id)
|
||||||
|
|
||||||
invalidated_trade = crud.invalidate_trade(session, trade_id)
|
invalidated_trade = crud.invalidate_trade(session, trade_id)
|
||||||
@@ -395,21 +516,26 @@ def test_invalidate_trade(session: Session):
|
|||||||
assert actual_trade.is_invalidated is True
|
assert actual_trade.is_invalidated is True
|
||||||
|
|
||||||
|
|
||||||
def test_replace_trade(session: Session):
|
def test_replace_trade(session: Session) -> None:
|
||||||
user_id = make_user(session)
|
user_id = make_user(session)
|
||||||
cycle_id = make_cycle(session, user_id)
|
exchange_id = make_exchange(session, user_id)
|
||||||
|
cycle_id = make_cycle(session, user_id, exchange_id)
|
||||||
old_trade_id = make_trade(session, user_id, cycle_id)
|
old_trade_id = make_trade(session, user_id, cycle_id)
|
||||||
|
|
||||||
new_trade_data = {
|
new_trade_data = {
|
||||||
"user_id": user_id,
|
"user_id": user_id,
|
||||||
"friendly_name": "Replaced Trade",
|
"friendly_name": "Replaced Trade",
|
||||||
"symbol": "MSFT",
|
"symbol": "MSFT",
|
||||||
|
"exchange_id": exchange_id,
|
||||||
"underlying_currency": models.UnderlyingCurrency.USD,
|
"underlying_currency": models.UnderlyingCurrency.USD,
|
||||||
"trade_type": models.TradeType.LONG_SPOT,
|
"trade_type": models.TradeType.LONG_SPOT,
|
||||||
"trade_strategy": models.TradeStrategy.SPOT,
|
"trade_strategy": models.TradeStrategy.SPOT,
|
||||||
"trade_time_utc": datetime.now(),
|
"trade_time_utc": datetime.now(timezone.utc),
|
||||||
"quantity": 20,
|
"quantity": 20,
|
||||||
"price_cents": 25000,
|
"price_cents": 25000,
|
||||||
|
"gross_cash_flow_cents": -500000,
|
||||||
|
"commission_cents": 1000,
|
||||||
|
"net_cash_flow_cents": -501000,
|
||||||
}
|
}
|
||||||
|
|
||||||
new_trade = crud.replace_trade(session, old_trade_id, new_trade_data)
|
new_trade = crud.replace_trade(session, old_trade_id, new_trade_data)
|
||||||
@@ -438,15 +564,18 @@ def test_replace_trade(session: Session):
|
|||||||
assert actual_new_trade.replaced_by_trade_id == old_trade_id
|
assert actual_new_trade.replaced_by_trade_id == old_trade_id
|
||||||
|
|
||||||
|
|
||||||
def test_create_cycle(session: Session):
|
# Cycles
|
||||||
|
def test_create_cycle(session: Session) -> None:
|
||||||
user_id = make_user(session)
|
user_id = make_user(session)
|
||||||
|
exchange_id = make_exchange(session, user_id)
|
||||||
cycle_data = {
|
cycle_data = {
|
||||||
"user_id": user_id,
|
"user_id": user_id,
|
||||||
"friendly_name": "My First Cycle",
|
"friendly_name": "My First Cycle",
|
||||||
"symbol": "GOOGL",
|
"symbol": "GOOGL",
|
||||||
|
"exchange_id": exchange_id,
|
||||||
"underlying_currency": models.UnderlyingCurrency.USD,
|
"underlying_currency": models.UnderlyingCurrency.USD,
|
||||||
"status": models.CycleStatus.OPEN,
|
"status": models.CycleStatus.OPEN,
|
||||||
"start_date": datetime.now().date(),
|
"start_date": datetime.now(timezone.utc).date(),
|
||||||
}
|
}
|
||||||
cycle = crud.create_cycle(session, cycle_data)
|
cycle = crud.create_cycle(session, cycle_data)
|
||||||
assert cycle.id is not None
|
assert cycle.id is not None
|
||||||
@@ -467,9 +596,35 @@ def test_create_cycle(session: Session):
|
|||||||
assert actual_cycle.start_date == cycle_data["start_date"]
|
assert actual_cycle.start_date == cycle_data["start_date"]
|
||||||
|
|
||||||
|
|
||||||
def test_update_cycle(session: Session):
|
def test_get_cycle_by_id(session: Session) -> None:
|
||||||
user_id = make_user(session)
|
user_id = make_user(session)
|
||||||
cycle_id = make_cycle(session, user_id, friendly_name="Initial Cycle Name")
|
exchange_id = make_exchange(session, user_id)
|
||||||
|
cycle_id = make_cycle(session, user_id, exchange_id, friendly_name="Cycle to Get")
|
||||||
|
cycle = crud.get_cycle_by_id(session, cycle_id)
|
||||||
|
assert cycle is not None
|
||||||
|
assert cycle.id == cycle_id
|
||||||
|
assert cycle.friendly_name == "Cycle to Get"
|
||||||
|
assert cycle.user_id == user_id
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_cycles_by_user_id(session: Session) -> None:
|
||||||
|
user_id = make_user(session)
|
||||||
|
exchange_id = make_exchange(session, user_id)
|
||||||
|
cycle_names = ["Cycle One", "Cycle Two", "Cycle Three"]
|
||||||
|
for name in cycle_names:
|
||||||
|
make_cycle(session, user_id, exchange_id, friendly_name=name)
|
||||||
|
|
||||||
|
cycles = crud.get_cycles_by_user_id(session, user_id)
|
||||||
|
assert len(cycles) == len(cycle_names)
|
||||||
|
fetched_names = {cycle.friendly_name for cycle in cycles}
|
||||||
|
for name in cycle_names:
|
||||||
|
assert name in fetched_names
|
||||||
|
|
||||||
|
|
||||||
|
def test_update_cycle(session: Session) -> None:
|
||||||
|
user_id = make_user(session)
|
||||||
|
exchange_id = make_exchange(session, user_id)
|
||||||
|
cycle_id = make_cycle(session, user_id, exchange_id, friendly_name="Initial Cycle Name")
|
||||||
|
|
||||||
update_data = {
|
update_data = {
|
||||||
"friendly_name": "Updated Cycle Name",
|
"friendly_name": "Updated Cycle Name",
|
||||||
@@ -488,16 +643,17 @@ def test_update_cycle(session: Session):
|
|||||||
assert actual_cycle.status == update_data["status"]
|
assert actual_cycle.status == update_data["status"]
|
||||||
|
|
||||||
|
|
||||||
def test_update_cycle_immutable_fields(session: Session):
|
def test_update_cycle_immutable_fields(session: Session) -> None:
|
||||||
user_id = make_user(session)
|
user_id = make_user(session)
|
||||||
cycle_id = make_cycle(session, user_id, friendly_name="Initial Cycle Name")
|
exchange_id = make_exchange(session, user_id)
|
||||||
|
cycle_id = make_cycle(session, user_id, exchange_id, friendly_name="Initial Cycle Name")
|
||||||
|
|
||||||
# Attempt to update immutable fields
|
# Attempt to update immutable fields
|
||||||
update_data = {
|
update_data = {
|
||||||
"id": cycle_id + 1, # Trying to change the ID
|
"id": cycle_id + 1, # Trying to change the ID
|
||||||
"user_id": user_id + 1, # Trying to change the user_id
|
"user_id": user_id + 1, # Trying to change the user_id
|
||||||
"start_date": datetime(2020, 1, 1).date(), # Trying to change start_date
|
"start_date": datetime(2020, 1, 1, tzinfo=timezone.utc).date(), # Trying to change start_date
|
||||||
"created_at": datetime(2020, 1, 1), # Trying to change created_at
|
"created_at": datetime(2020, 1, 1, tzinfo=timezone.utc), # Trying to change created_at
|
||||||
"friendly_name": "Valid Update", # Valid field to update
|
"friendly_name": "Valid Update", # Valid field to update
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -511,7 +667,314 @@ def test_update_cycle_immutable_fields(session: Session):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_create_user(session: Session):
|
# Cycle loans
|
||||||
|
def test_create_cycle_loan_event(session: Session) -> None:
|
||||||
|
user_id = make_user(session)
|
||||||
|
exchange_id = make_exchange(session, user_id)
|
||||||
|
cycle_id = make_cycle(session, user_id, exchange_id)
|
||||||
|
|
||||||
|
loan_data = {
|
||||||
|
"cycle_id": cycle_id,
|
||||||
|
"loan_amount_cents": 100000,
|
||||||
|
"loan_interest_rate_tenth_bps": 5000, # 5%
|
||||||
|
"notes": "Test loan change for the cycle",
|
||||||
|
}
|
||||||
|
|
||||||
|
loan_event = crud.create_cycle_loan_event(session, loan_data)
|
||||||
|
now = datetime.now(timezone.utc)
|
||||||
|
assert loan_event.id is not None
|
||||||
|
assert loan_event.cycle_id == cycle_id
|
||||||
|
assert loan_event.loan_amount_cents == loan_data["loan_amount_cents"]
|
||||||
|
assert loan_event.loan_interest_rate_tenth_bps == loan_data["loan_interest_rate_tenth_bps"]
|
||||||
|
assert loan_event.notes == loan_data["notes"]
|
||||||
|
assert loan_event.effective_date == now.date()
|
||||||
|
_validate_timestamp(loan_event.created_at, now, timedelta(seconds=1))
|
||||||
|
|
||||||
|
session.refresh(loan_event)
|
||||||
|
actual_loan_event = session.get(models.CycleLoanChangeEvents, loan_event.id)
|
||||||
|
assert actual_loan_event is not None
|
||||||
|
assert actual_loan_event.cycle_id == cycle_id
|
||||||
|
assert actual_loan_event.loan_amount_cents == loan_data["loan_amount_cents"]
|
||||||
|
assert actual_loan_event.loan_interest_rate_tenth_bps == loan_data["loan_interest_rate_tenth_bps"]
|
||||||
|
assert actual_loan_event.notes == loan_data["notes"]
|
||||||
|
assert actual_loan_event.effective_date == now.date()
|
||||||
|
_validate_timestamp(actual_loan_event.created_at, now, timedelta(seconds=1))
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_cycle_loan_events_by_cycle_id(session: Session) -> None:
|
||||||
|
user_id = make_user(session)
|
||||||
|
exchange_id = make_exchange(session, user_id)
|
||||||
|
cycle_id = make_cycle(session, user_id, exchange_id)
|
||||||
|
|
||||||
|
loan_data_1 = {
|
||||||
|
"cycle_id": cycle_id,
|
||||||
|
"loan_amount_cents": 100000,
|
||||||
|
"loan_interest_rate_tenth_bps": 5000,
|
||||||
|
"notes": "First loan event",
|
||||||
|
}
|
||||||
|
yesterday = (datetime.now(timezone.utc) - timedelta(days=1)).date()
|
||||||
|
loan_data_2 = {
|
||||||
|
"cycle_id": cycle_id,
|
||||||
|
"loan_amount_cents": 150000,
|
||||||
|
"loan_interest_rate_tenth_bps": 4500,
|
||||||
|
"effective_date": yesterday,
|
||||||
|
"notes": "Second loan event",
|
||||||
|
}
|
||||||
|
|
||||||
|
crud.create_cycle_loan_event(session, loan_data_1)
|
||||||
|
crud.create_cycle_loan_event(session, loan_data_2)
|
||||||
|
|
||||||
|
loan_events = crud.get_loan_events_by_cycle_id(session, cycle_id)
|
||||||
|
assert len(loan_events) == 2
|
||||||
|
notes = [event.notes for event in loan_events]
|
||||||
|
assert loan_events[0].notes == loan_data_2["notes"]
|
||||||
|
assert loan_events[0].effective_date == yesterday
|
||||||
|
assert notes == ["Second loan event", "First loan event"] # Ordered by effective_date desc
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_cycle_loan_events_by_cycle_id_same_date(session: Session) -> None:
|
||||||
|
user_id = make_user(session)
|
||||||
|
exchange_id = make_exchange(session, user_id)
|
||||||
|
cycle_id = make_cycle(session, user_id, exchange_id)
|
||||||
|
|
||||||
|
loan_data_1 = {
|
||||||
|
"cycle_id": cycle_id,
|
||||||
|
"loan_amount_cents": 100000,
|
||||||
|
"loan_interest_rate_tenth_bps": 5000,
|
||||||
|
"notes": "First loan event",
|
||||||
|
}
|
||||||
|
loan_data_2 = {
|
||||||
|
"cycle_id": cycle_id,
|
||||||
|
"loan_amount_cents": 150000,
|
||||||
|
"loan_interest_rate_tenth_bps": 4500,
|
||||||
|
"notes": "Second loan event",
|
||||||
|
}
|
||||||
|
|
||||||
|
crud.create_cycle_loan_event(session, loan_data_1)
|
||||||
|
crud.create_cycle_loan_event(session, loan_data_2)
|
||||||
|
|
||||||
|
loan_events = crud.get_loan_events_by_cycle_id(session, cycle_id)
|
||||||
|
assert len(loan_events) == 2
|
||||||
|
notes = [event.notes for event in loan_events]
|
||||||
|
assert notes == ["First loan event", "Second loan event"] # Ordered by id desc when effective_date is same
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_cycle_loan_event_single_field(session: Session) -> None:
|
||||||
|
user_id = make_user(session)
|
||||||
|
exchange_id = make_exchange(session, user_id)
|
||||||
|
cycle_id = make_cycle(session, user_id, exchange_id)
|
||||||
|
|
||||||
|
loan_data = {
|
||||||
|
"cycle_id": cycle_id,
|
||||||
|
"loan_amount_cents": 200000,
|
||||||
|
}
|
||||||
|
|
||||||
|
loan_event = crud.create_cycle_loan_event(session, loan_data)
|
||||||
|
now = datetime.now(timezone.utc)
|
||||||
|
assert loan_event.id is not None
|
||||||
|
assert loan_event.cycle_id == cycle_id
|
||||||
|
assert loan_event.loan_amount_cents == loan_data["loan_amount_cents"]
|
||||||
|
assert loan_event.loan_interest_rate_tenth_bps is None
|
||||||
|
assert loan_event.notes is None
|
||||||
|
assert loan_event.effective_date == now.date()
|
||||||
|
_validate_timestamp(loan_event.created_at, now, timedelta(seconds=1))
|
||||||
|
|
||||||
|
session.refresh(loan_event)
|
||||||
|
actual_loan_event = session.get(models.CycleLoanChangeEvents, loan_event.id)
|
||||||
|
assert actual_loan_event is not None
|
||||||
|
assert actual_loan_event.cycle_id == cycle_id
|
||||||
|
assert actual_loan_event.loan_amount_cents == loan_data["loan_amount_cents"]
|
||||||
|
assert actual_loan_event.loan_interest_rate_tenth_bps is None
|
||||||
|
assert actual_loan_event.notes is None
|
||||||
|
assert actual_loan_event.effective_date == now.date()
|
||||||
|
_validate_timestamp(actual_loan_event.created_at, now, timedelta(seconds=1))
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_cycle_daily_accrual(session: Session) -> None:
|
||||||
|
user_id = make_user(session)
|
||||||
|
exchange_id = make_exchange(session, user_id)
|
||||||
|
cycle_id = make_cycle(session, user_id, exchange_id)
|
||||||
|
today = datetime.now(timezone.utc).date()
|
||||||
|
accrual_data = {
|
||||||
|
"cycle_id": cycle_id,
|
||||||
|
"accrual_date": today,
|
||||||
|
"accrued_interest_cents": 150,
|
||||||
|
"notes": "Daily interest accrual",
|
||||||
|
}
|
||||||
|
|
||||||
|
accrual = crud.create_cycle_daily_accrual(session, cycle_id, accrual_data["accrual_date"], accrual_data["accrued_interest_cents"])
|
||||||
|
assert accrual.id is not None
|
||||||
|
assert accrual.cycle_id == cycle_id
|
||||||
|
assert accrual.accrual_date == accrual_data["accrual_date"]
|
||||||
|
assert accrual.accrual_amount_cents == accrual_data["accrued_interest_cents"]
|
||||||
|
|
||||||
|
session.refresh(accrual)
|
||||||
|
actual_accrual = session.get(models.CycleDailyAccrual, accrual.id)
|
||||||
|
assert actual_accrual is not None
|
||||||
|
assert actual_accrual.cycle_id == cycle_id
|
||||||
|
assert actual_accrual.accrual_date == accrual_data["accrual_date"]
|
||||||
|
assert actual_accrual.accrual_amount_cents == accrual_data["accrued_interest_cents"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_cycle_daily_accruals_by_cycle_id(session: Session) -> None:
|
||||||
|
user_id = make_user(session)
|
||||||
|
exchange_id = make_exchange(session, user_id)
|
||||||
|
cycle_id = make_cycle(session, user_id, exchange_id)
|
||||||
|
|
||||||
|
today = datetime.now(timezone.utc).date()
|
||||||
|
yesterday = today - timedelta(days=1)
|
||||||
|
|
||||||
|
accrual_data_1 = {
|
||||||
|
"cycle_id": cycle_id,
|
||||||
|
"accrual_date": yesterday,
|
||||||
|
"accrued_interest_cents": 100,
|
||||||
|
}
|
||||||
|
accrual_data_2 = {
|
||||||
|
"cycle_id": cycle_id,
|
||||||
|
"accrual_date": today,
|
||||||
|
"accrued_interest_cents": 150,
|
||||||
|
}
|
||||||
|
|
||||||
|
crud.create_cycle_daily_accrual(session, cycle_id, accrual_data_1["accrual_date"], accrual_data_1["accrued_interest_cents"])
|
||||||
|
crud.create_cycle_daily_accrual(session, cycle_id, accrual_data_2["accrual_date"], accrual_data_2["accrued_interest_cents"])
|
||||||
|
|
||||||
|
accruals = crud.get_cycle_daily_accruals_by_cycle_id(session, cycle_id)
|
||||||
|
assert len(accruals) == 2
|
||||||
|
dates = [accrual.accrual_date for accrual in accruals]
|
||||||
|
assert dates == [yesterday, today] # Ordered by accrual_date asc
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_cycle_daily_accruals_by_cycle_id_and_date(session: Session) -> None:
|
||||||
|
user_id = make_user(session)
|
||||||
|
exchange_id = make_exchange(session, user_id)
|
||||||
|
cycle_id = make_cycle(session, user_id, exchange_id)
|
||||||
|
|
||||||
|
today = datetime.now(timezone.utc).date()
|
||||||
|
yesterday = today - timedelta(days=1)
|
||||||
|
|
||||||
|
accrual_data_1 = {
|
||||||
|
"cycle_id": cycle_id,
|
||||||
|
"accrual_date": yesterday,
|
||||||
|
"accrued_interest_cents": 100,
|
||||||
|
}
|
||||||
|
accrual_data_2 = {
|
||||||
|
"cycle_id": cycle_id,
|
||||||
|
"accrual_date": today,
|
||||||
|
"accrued_interest_cents": 150,
|
||||||
|
}
|
||||||
|
|
||||||
|
crud.create_cycle_daily_accrual(session, cycle_id, accrual_data_1["accrual_date"], accrual_data_1["accrued_interest_cents"])
|
||||||
|
crud.create_cycle_daily_accrual(session, cycle_id, accrual_data_2["accrual_date"], accrual_data_2["accrued_interest_cents"])
|
||||||
|
|
||||||
|
accruals_today = crud.get_cycle_daily_accrual_by_cycle_id_and_date(session, cycle_id, today)
|
||||||
|
assert accruals_today is not None
|
||||||
|
assert accruals_today.accrual_date == today
|
||||||
|
assert accruals_today.accrual_amount_cents == accrual_data_2["accrued_interest_cents"]
|
||||||
|
|
||||||
|
accruals_yesterday = crud.get_cycle_daily_accrual_by_cycle_id_and_date(session, cycle_id, yesterday)
|
||||||
|
assert accruals_yesterday is not None
|
||||||
|
assert accruals_yesterday.accrual_date == yesterday
|
||||||
|
assert accruals_yesterday.accrual_amount_cents == accrual_data_1["accrued_interest_cents"]
|
||||||
|
|
||||||
|
|
||||||
|
# Exchanges
|
||||||
|
def test_create_exchange(session: Session) -> None:
|
||||||
|
user_id = make_user(session)
|
||||||
|
exchange_data = {
|
||||||
|
"name": "NYSE",
|
||||||
|
"notes": "New York Stock Exchange",
|
||||||
|
"user_id": user_id,
|
||||||
|
}
|
||||||
|
exchange = crud.create_exchange(session, exchange_data)
|
||||||
|
assert exchange.id is not None
|
||||||
|
assert exchange.name == exchange_data["name"]
|
||||||
|
assert exchange.notes == exchange_data["notes"]
|
||||||
|
assert exchange.user_id == user_id
|
||||||
|
|
||||||
|
session.refresh(exchange)
|
||||||
|
actual_exchange = session.get(models.Exchanges, exchange.id)
|
||||||
|
assert actual_exchange is not None
|
||||||
|
assert actual_exchange.name == exchange_data["name"]
|
||||||
|
assert actual_exchange.notes == exchange_data["notes"]
|
||||||
|
assert actual_exchange.user_id == user_id
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_exchange_by_id(session: Session) -> None:
|
||||||
|
user_id = make_user(session)
|
||||||
|
exchange_id = make_exchange(session, user_id=user_id, name="LSE")
|
||||||
|
exchange = crud.get_exchange_by_id(session, exchange_id)
|
||||||
|
assert exchange is not None
|
||||||
|
assert exchange.id == exchange_id
|
||||||
|
assert exchange.name == "LSE"
|
||||||
|
assert exchange.user_id == user_id
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_exchange_by_name_and_user_id(session: Session) -> None:
|
||||||
|
exchange_name = "TSX"
|
||||||
|
user_id = make_user(session)
|
||||||
|
make_exchange(session, user_id=user_id, name=exchange_name)
|
||||||
|
exchange = crud.get_exchange_by_name_and_user_id(session, exchange_name, user_id)
|
||||||
|
assert exchange is not None
|
||||||
|
assert exchange.name == exchange_name
|
||||||
|
assert exchange.user_id == user_id
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_all_exchanges(session: Session) -> None:
|
||||||
|
exchange_names = ["NYSE", "NASDAQ", "LSE"]
|
||||||
|
user_id = make_user(session)
|
||||||
|
for name in exchange_names:
|
||||||
|
make_exchange(session, user_id=user_id, name=name)
|
||||||
|
|
||||||
|
exchanges = crud.get_all_exchanges(session)
|
||||||
|
assert len(exchanges) >= 3
|
||||||
|
fetched_names = {ex.name for ex in exchanges}
|
||||||
|
for name in exchange_names:
|
||||||
|
assert name in fetched_names
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_all_exchanges_by_user_id(session: Session) -> None:
|
||||||
|
exchange_names = ["NYSE", "NASDAQ"]
|
||||||
|
user_id = make_user(session)
|
||||||
|
for name in exchange_names:
|
||||||
|
make_exchange(session, user_id=user_id, name=name)
|
||||||
|
|
||||||
|
exchanges = crud.get_all_exchanges_by_user_id(session, user_id)
|
||||||
|
assert len(exchanges) == len(exchange_names)
|
||||||
|
fetched_names = {ex.name for ex in exchanges}
|
||||||
|
for name in exchange_names:
|
||||||
|
assert name in fetched_names
|
||||||
|
|
||||||
|
|
||||||
|
def test_update_exchange(session: Session) -> None:
|
||||||
|
user_id = make_user(session)
|
||||||
|
exchange_id = make_exchange(session, user_id=user_id, name="Initial Exchange")
|
||||||
|
update_data = {
|
||||||
|
"name": "Updated Exchange",
|
||||||
|
"notes": "Updated notes for the exchange",
|
||||||
|
}
|
||||||
|
updated_exchange = crud.update_exchange(session, exchange_id, update_data)
|
||||||
|
assert updated_exchange is not None
|
||||||
|
assert updated_exchange.id == exchange_id
|
||||||
|
assert updated_exchange.name == update_data["name"]
|
||||||
|
assert updated_exchange.notes == update_data["notes"]
|
||||||
|
|
||||||
|
session.refresh(updated_exchange)
|
||||||
|
actual_exchange = session.get(models.Exchanges, exchange_id)
|
||||||
|
assert actual_exchange is not None
|
||||||
|
assert actual_exchange.name == update_data["name"]
|
||||||
|
assert actual_exchange.notes == update_data["notes"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_delete_exchange(session: Session) -> None:
|
||||||
|
user_id = make_user(session)
|
||||||
|
exchange_id = make_exchange(session, user_id=user_id, name="Deletable Exchange")
|
||||||
|
crud.delete_exchange(session, exchange_id)
|
||||||
|
deleted_exchange = session.get(models.Exchanges, exchange_id)
|
||||||
|
assert deleted_exchange is None
|
||||||
|
|
||||||
|
|
||||||
|
# Users
|
||||||
|
def test_create_user(session: Session) -> None:
|
||||||
user_data = {
|
user_data = {
|
||||||
"username": "newuser",
|
"username": "newuser",
|
||||||
"password_hash": "newhashedpassword",
|
"password_hash": "newhashedpassword",
|
||||||
@@ -528,7 +991,23 @@ def test_create_user(session: Session):
|
|||||||
assert actual_user.password_hash == user_data["password_hash"]
|
assert actual_user.password_hash == user_data["password_hash"]
|
||||||
|
|
||||||
|
|
||||||
def test_update_user(session: Session):
|
def test_get_user_by_id(session: Session) -> None:
|
||||||
|
user_id = make_user(session, username="fetchuser")
|
||||||
|
user = crud.get_user_by_id(session, user_id)
|
||||||
|
assert user is not None
|
||||||
|
assert user.id == user_id
|
||||||
|
assert user.username == "fetchuser"
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_user_by_username(session: Session) -> None:
|
||||||
|
username = "uniqueuser"
|
||||||
|
make_user(session, username=username)
|
||||||
|
user = crud.get_user_by_username(session, username)
|
||||||
|
assert user is not None
|
||||||
|
assert user.username == username
|
||||||
|
|
||||||
|
|
||||||
|
def test_update_user(session: Session) -> None:
|
||||||
user_id = make_user(session, username="updatableuser")
|
user_id = make_user(session, username="updatableuser")
|
||||||
|
|
||||||
update_data = {
|
update_data = {
|
||||||
@@ -545,14 +1024,14 @@ def test_update_user(session: Session):
|
|||||||
assert actual_user.password_hash == update_data["password_hash"]
|
assert actual_user.password_hash == update_data["password_hash"]
|
||||||
|
|
||||||
|
|
||||||
def test_update_user_immutable_fields(session: Session):
|
def test_update_user_immutable_fields(session: Session) -> None:
|
||||||
user_id = make_user(session, username="immutableuser")
|
user_id = make_user(session, username="immutableuser")
|
||||||
|
|
||||||
# Attempt to update immutable fields
|
# Attempt to update immutable fields
|
||||||
update_data = {
|
update_data = {
|
||||||
"id": user_id + 1, # Trying to change the ID
|
"id": user_id + 1, # Trying to change the ID
|
||||||
"username": "newusername", # Trying to change the username
|
"username": "newusername", # Trying to change the username
|
||||||
"created_at": datetime(2020, 1, 1), # Trying to change created_at
|
"created_at": datetime(2020, 1, 1, tzinfo=timezone.utc), # Trying to change created_at
|
||||||
"password_hash": "validupdate", # Valid field to update
|
"password_hash": "validupdate", # Valid field to update
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -566,7 +1045,7 @@ def test_update_user_immutable_fields(session: Session):
|
|||||||
|
|
||||||
|
|
||||||
# login sessions
|
# login sessions
|
||||||
def test_create_login_session(session: Session):
|
def test_create_login_session(session: Session) -> None:
|
||||||
user_id = make_user(session, username="testuser")
|
user_id = make_user(session, username="testuser")
|
||||||
session_token_hash = "sessiontokenhashed"
|
session_token_hash = "sessiontokenhashed"
|
||||||
login_session = crud.create_login_session(session, user_id, session_token_hash)
|
login_session = crud.create_login_session(session, user_id, session_token_hash)
|
||||||
@@ -575,7 +1054,7 @@ def test_create_login_session(session: Session):
|
|||||||
assert login_session.session_token_hash == session_token_hash
|
assert login_session.session_token_hash == session_token_hash
|
||||||
|
|
||||||
|
|
||||||
def test_create_login_session_with_invalid_user(session: Session):
|
def test_create_login_session_with_invalid_user(session: Session) -> None:
|
||||||
invalid_user_id = 9999 # Assuming this user ID does not exist
|
invalid_user_id = 9999 # Assuming this user ID does not exist
|
||||||
session_token_hash = "sessiontokenhashed"
|
session_token_hash = "sessiontokenhashed"
|
||||||
with pytest.raises(ValueError) as excinfo:
|
with pytest.raises(ValueError) as excinfo:
|
||||||
@@ -583,40 +1062,44 @@ def test_create_login_session_with_invalid_user(session: Session):
|
|||||||
assert "user_id does not exist" in str(excinfo.value)
|
assert "user_id does not exist" in str(excinfo.value)
|
||||||
|
|
||||||
|
|
||||||
def test_get_login_session_by_token_and_user_id(session: Session):
|
def test_get_login_session_by_token_and_user_id(session: Session) -> None:
|
||||||
now = datetime.now()
|
now = datetime.now(timezone.utc)
|
||||||
created_session = make_login_session(session, now)
|
created_session = make_login_session(session, now)
|
||||||
fetched_session = crud.get_login_session_by_token_hash_and_user_id(
|
fetched_session = crud.get_login_session_by_token_hash_and_user_id(session, created_session.session_token_hash, created_session.user_id)
|
||||||
session, created_session.session_token_hash, created_session.user_id
|
|
||||||
)
|
|
||||||
assert fetched_session is not None
|
assert fetched_session is not None
|
||||||
assert fetched_session.id == created_session.id
|
assert fetched_session.id == created_session.id
|
||||||
assert fetched_session.user_id == created_session.user_id
|
assert fetched_session.user_id == created_session.user_id
|
||||||
assert fetched_session.session_token_hash == created_session.session_token_hash
|
assert fetched_session.session_token_hash == created_session.session_token_hash
|
||||||
|
|
||||||
|
|
||||||
def test_update_login_session(session: Session):
|
def test_get_login_session_by_token(session: Session) -> None:
|
||||||
now = datetime.now()
|
now = datetime.now(timezone.utc)
|
||||||
|
created_session = make_login_session(session, now)
|
||||||
|
fetched_session = crud.get_login_session_by_token_hash(session, created_session.session_token_hash)
|
||||||
|
assert fetched_session is not None
|
||||||
|
assert fetched_session.id == created_session.id
|
||||||
|
assert fetched_session.user_id == created_session.user_id
|
||||||
|
assert fetched_session.session_token_hash == created_session.session_token_hash
|
||||||
|
|
||||||
|
|
||||||
|
def test_update_login_session(session: Session) -> None:
|
||||||
|
now = datetime.now(timezone.utc)
|
||||||
created_session = make_login_session(session, now)
|
created_session = make_login_session(session, now)
|
||||||
|
|
||||||
update_data = {
|
update_data = {
|
||||||
"last_seen_at": now + timedelta(hours=1),
|
"last_seen_at": now + timedelta(hours=1),
|
||||||
"last_used_ip": "192.168.1.1",
|
"last_used_ip": "192.168.1.1",
|
||||||
}
|
}
|
||||||
updated_session = crud.update_login_session(
|
updated_session = crud.update_login_session(session, created_session.session_token_hash, update_data)
|
||||||
session, created_session.session_token_hash, update_data
|
|
||||||
)
|
|
||||||
assert updated_session is not None
|
assert updated_session is not None
|
||||||
assert updated_session.last_seen_at == update_data["last_seen_at"]
|
assert _ensure_utc_aware(updated_session.last_seen_at) == update_data["last_seen_at"]
|
||||||
assert updated_session.last_used_ip == update_data["last_used_ip"]
|
assert updated_session.last_used_ip == update_data["last_used_ip"]
|
||||||
|
|
||||||
|
|
||||||
def test_delete_login_session(session: Session):
|
def test_delete_login_session(session: Session) -> None:
|
||||||
now = datetime.now()
|
now = datetime.now(timezone.utc)
|
||||||
created_session = make_login_session(session, now)
|
created_session = make_login_session(session, now)
|
||||||
|
|
||||||
crud.delete_login_session(session, created_session.session_token_hash)
|
crud.delete_login_session(session, created_session.session_token_hash)
|
||||||
deleted_session = crud.get_login_session_by_token_hash_and_user_id(
|
deleted_session = crud.get_login_session_by_token_hash_and_user_id(session, created_session.session_token_hash, created_session.user_id)
|
||||||
session, created_session.session_token_hash, created_session.user_id
|
|
||||||
)
|
|
||||||
assert deleted_session is None
|
assert deleted_session is None
|
||||||
|
|||||||
@@ -46,9 +46,8 @@ def database_ctx(db: Database) -> Generator[Database, None, None]:
|
|||||||
|
|
||||||
def test_select_one_executes() -> None:
|
def test_select_one_executes() -> None:
|
||||||
db = create_database(None) # in-memory by default
|
db = create_database(None) # in-memory by default
|
||||||
with database_ctx(db):
|
with database_ctx(db), session_ctx(db) as session:
|
||||||
with session_ctx(db) as session:
|
val = session.exec(text("SELECT 1")).scalar_one()
|
||||||
val = session.exec(text("SELECT 1")).scalar_one()
|
|
||||||
assert int(val) == 1
|
assert int(val) == 1
|
||||||
|
|
||||||
|
|
||||||
@@ -56,9 +55,7 @@ def test_in_memory_persists_across_sessions_when_using_staticpool() -> None:
|
|||||||
db = create_database(None) # in-memory with StaticPool
|
db = create_database(None) # in-memory with StaticPool
|
||||||
with database_ctx(db):
|
with database_ctx(db):
|
||||||
with session_ctx(db) as s1:
|
with session_ctx(db) as s1:
|
||||||
s1.exec(
|
s1.exec(text("CREATE TABLE IF NOT EXISTS t (id INTEGER PRIMARY KEY, val TEXT);"))
|
||||||
text("CREATE TABLE IF NOT EXISTS t (id INTEGER PRIMARY KEY, val TEXT);")
|
|
||||||
)
|
|
||||||
s1.exec(text("INSERT INTO t (val) VALUES (:v)").bindparams(v="hello"))
|
s1.exec(text("INSERT INTO t (val) VALUES (:v)").bindparams(v="hello"))
|
||||||
with session_ctx(db) as s2:
|
with session_ctx(db) as s2:
|
||||||
got = s2.exec(text("SELECT val FROM t")).scalar_one()
|
got = s2.exec(text("SELECT val FROM t")).scalar_one()
|
||||||
@@ -67,10 +64,9 @@ def test_in_memory_persists_across_sessions_when_using_staticpool() -> None:
|
|||||||
|
|
||||||
def test_sqlite_pragmas_applied() -> None:
|
def test_sqlite_pragmas_applied() -> None:
|
||||||
db = create_database(None)
|
db = create_database(None)
|
||||||
with database_ctx(db):
|
with database_ctx(db), session_ctx(db) as session:
|
||||||
# PRAGMA returns integer 1 when foreign_keys ON
|
# PRAGMA returns integer 1 when foreign_keys ON
|
||||||
with session_ctx(db) as session:
|
fk = session.exec(text("PRAGMA foreign_keys")).scalar_one()
|
||||||
fk = session.exec(text("PRAGMA foreign_keys")).scalar_one()
|
|
||||||
assert int(fk) == 1
|
assert int(fk) == 1
|
||||||
|
|
||||||
|
|
||||||
@@ -82,16 +78,8 @@ def test_rollback_on_exception() -> None:
|
|||||||
# Create table then insert and raise inside the same session to force rollback
|
# Create table then insert and raise inside the same session to force rollback
|
||||||
with pytest.raises(RuntimeError): # noqa: PT012, SIM117
|
with pytest.raises(RuntimeError): # noqa: PT012, SIM117
|
||||||
with session_ctx(db) as s:
|
with session_ctx(db) as s:
|
||||||
s.exec(
|
s.exec(text("CREATE TABLE IF NOT EXISTS t_rb (id INTEGER PRIMARY KEY, val TEXT);"))
|
||||||
text(
|
s.exec(text("INSERT INTO t_rb (val) VALUES (:v)").bindparams(v="will_rollback"))
|
||||||
"CREATE TABLE IF NOT EXISTS t_rb (id INTEGER PRIMARY KEY, val TEXT);"
|
|
||||||
)
|
|
||||||
)
|
|
||||||
s.exec(
|
|
||||||
text("INSERT INTO t_rb (val) VALUES (:v)").bindparams(
|
|
||||||
v="will_rollback"
|
|
||||||
)
|
|
||||||
)
|
|
||||||
# simulate handler error -> should trigger rollback in get_session
|
# simulate handler error -> should trigger rollback in get_session
|
||||||
raise RuntimeError("simulated failure")
|
raise RuntimeError("simulated failure")
|
||||||
|
|
||||||
|
|||||||
@@ -36,33 +36,66 @@ def test_run_migrations_0_to_1(monkeypatch: pytest.MonkeyPatch) -> None:
|
|||||||
"user_id": ("INTEGER", 1, 0),
|
"user_id": ("INTEGER", 1, 0),
|
||||||
"friendly_name": ("TEXT", 0, 0),
|
"friendly_name": ("TEXT", 0, 0),
|
||||||
"symbol": ("TEXT", 1, 0),
|
"symbol": ("TEXT", 1, 0),
|
||||||
|
"exchange_id": ("INTEGER", 1, 0),
|
||||||
"underlying_currency": ("TEXT", 1, 0),
|
"underlying_currency": ("TEXT", 1, 0),
|
||||||
"status": ("TEXT", 1, 0),
|
"status": ("TEXT", 1, 0),
|
||||||
"funding_source": ("TEXT", 0, 0),
|
"funding_source": ("TEXT", 0, 0),
|
||||||
"capital_exposure_cents": ("INTEGER", 0, 0),
|
"capital_exposure_cents": ("INTEGER", 0, 0),
|
||||||
"loan_amount_cents": ("INTEGER", 0, 0),
|
"loan_amount_cents": ("INTEGER", 0, 0),
|
||||||
"loan_interest_rate_bps": ("INTEGER", 0, 0),
|
"loan_interest_rate_tenth_bps": ("INTEGER", 0, 0),
|
||||||
"start_date": ("DATE", 1, 0),
|
"start_date": ("DATE", 1, 0),
|
||||||
"end_date": ("DATE", 0, 0),
|
"end_date": ("DATE", 0, 0),
|
||||||
|
"latest_interest_accrued_date": ("DATE", 0, 0),
|
||||||
|
"total_accrued_amount_cents": ("INTEGER", 1, 0),
|
||||||
|
},
|
||||||
|
"cycle_loan_change_events": {
|
||||||
|
"id": ("INTEGER", 1, 1),
|
||||||
|
"cycle_id": ("INTEGER", 1, 0),
|
||||||
|
"effective_date": ("DATE", 1, 0),
|
||||||
|
"loan_amount_cents": ("INTEGER", 0, 0),
|
||||||
|
"loan_interest_rate_tenth_bps": ("INTEGER", 0, 0),
|
||||||
|
"related_trade_id": ("INTEGER", 0, 0),
|
||||||
|
"notes": ("TEXT", 0, 0),
|
||||||
|
"created_at": ("DATETIME", 1, 0),
|
||||||
|
},
|
||||||
|
"cycle_daily_accrual": {
|
||||||
|
"id": ("INTEGER", 1, 1),
|
||||||
|
"cycle_id": ("INTEGER", 1, 0),
|
||||||
|
"accrual_date": ("DATE", 1, 0),
|
||||||
|
"accrual_amount_cents": ("INTEGER", 1, 0),
|
||||||
|
"created_at": ("DATETIME", 1, 0),
|
||||||
},
|
},
|
||||||
"trades": {
|
"trades": {
|
||||||
"id": ("INTEGER", 1, 1),
|
"id": ("INTEGER", 1, 1),
|
||||||
"user_id": ("INTEGER", 1, 0),
|
"user_id": ("INTEGER", 1, 0),
|
||||||
"friendly_name": ("TEXT", 0, 0),
|
"friendly_name": ("TEXT", 0, 0),
|
||||||
"symbol": ("TEXT", 1, 0),
|
"symbol": ("TEXT", 1, 0),
|
||||||
|
"exchange_id": ("INTEGER", 1, 0),
|
||||||
"underlying_currency": ("TEXT", 1, 0),
|
"underlying_currency": ("TEXT", 1, 0),
|
||||||
"trade_type": ("TEXT", 1, 0),
|
"trade_type": ("TEXT", 1, 0),
|
||||||
"trade_strategy": ("TEXT", 1, 0),
|
"trade_strategy": ("TEXT", 1, 0),
|
||||||
|
"trade_date": ("DATE", 1, 0),
|
||||||
"trade_time_utc": ("DATETIME", 1, 0),
|
"trade_time_utc": ("DATETIME", 1, 0),
|
||||||
"expiry_date": ("DATE", 0, 0),
|
"expiry_date": ("DATE", 0, 0),
|
||||||
"strike_price_cents": ("INTEGER", 0, 0),
|
"strike_price_cents": ("INTEGER", 0, 0),
|
||||||
"quantity": ("INTEGER", 1, 0),
|
"quantity": ("INTEGER", 1, 0),
|
||||||
|
"quantity_multiplier": ("INTEGER", 1, 0),
|
||||||
"price_cents": ("INTEGER", 1, 0),
|
"price_cents": ("INTEGER", 1, 0),
|
||||||
"gross_cash_flow_cents": ("INTEGER", 1, 0),
|
"gross_cash_flow_cents": ("INTEGER", 1, 0),
|
||||||
"commission_cents": ("INTEGER", 1, 0),
|
"commission_cents": ("INTEGER", 1, 0),
|
||||||
"net_cash_flow_cents": ("INTEGER", 1, 0),
|
"net_cash_flow_cents": ("INTEGER", 1, 0),
|
||||||
|
"is_invalidated": ("BOOLEAN", 1, 0),
|
||||||
|
"invalidated_at": ("DATETIME", 0, 0),
|
||||||
|
"replaced_by_trade_id": ("INTEGER", 0, 0),
|
||||||
|
"notes": ("TEXT", 0, 0),
|
||||||
"cycle_id": ("INTEGER", 0, 0),
|
"cycle_id": ("INTEGER", 0, 0),
|
||||||
},
|
},
|
||||||
|
"exchanges": {
|
||||||
|
"id": ("INTEGER", 1, 1),
|
||||||
|
"user_id": ("INTEGER", 1, 0),
|
||||||
|
"name": ("TEXT", 1, 0),
|
||||||
|
"notes": ("TEXT", 0, 0),
|
||||||
|
},
|
||||||
"sessions": {
|
"sessions": {
|
||||||
"id": ("INTEGER", 1, 1),
|
"id": ("INTEGER", 1, 1),
|
||||||
"user_id": ("INTEGER", 1, 0),
|
"user_id": ("INTEGER", 1, 0),
|
||||||
@@ -80,21 +113,35 @@ def test_run_migrations_0_to_1(monkeypatch: pytest.MonkeyPatch) -> None:
|
|||||||
"trades": [
|
"trades": [
|
||||||
{"table": "cycles", "from": "cycle_id", "to": "id"},
|
{"table": "cycles", "from": "cycle_id", "to": "id"},
|
||||||
{"table": "users", "from": "user_id", "to": "id"},
|
{"table": "users", "from": "user_id", "to": "id"},
|
||||||
|
{"table": "exchanges", "from": "exchange_id", "to": "id"},
|
||||||
],
|
],
|
||||||
"cycles": [
|
"cycles": [
|
||||||
{"table": "users", "from": "user_id", "to": "id"},
|
{"table": "users", "from": "user_id", "to": "id"},
|
||||||
|
{"table": "exchanges", "from": "exchange_id", "to": "id"},
|
||||||
|
],
|
||||||
|
"cycle_loan_change_events": [
|
||||||
|
{"table": "cycles", "from": "cycle_id", "to": "id"},
|
||||||
|
{"table": "trades", "from": "related_trade_id", "to": "id"},
|
||||||
|
],
|
||||||
|
"cycle_daily_accrual": [
|
||||||
|
{"table": "cycles", "from": "cycle_id", "to": "id"},
|
||||||
|
],
|
||||||
|
"sessions": [
|
||||||
|
{"table": "users", "from": "user_id", "to": "id"},
|
||||||
|
],
|
||||||
|
"users": [],
|
||||||
|
"exchanges": [
|
||||||
|
{"table": "users", "from": "user_id", "to": "id"},
|
||||||
],
|
],
|
||||||
}
|
}
|
||||||
|
|
||||||
with engine.connect() as conn:
|
with engine.connect() as conn:
|
||||||
# check tables exist
|
# check tables exist
|
||||||
rows = conn.execute(
|
rows = conn.execute(
|
||||||
text("SELECT name FROM sqlite_master WHERE type='table'")
|
text("SELECT name FROM sqlite_master WHERE type='table'"),
|
||||||
).fetchall()
|
).fetchall()
|
||||||
found_tables = {r[0] for r in rows}
|
found_tables = {r[0] for r in rows}
|
||||||
assert set(expected_schema.keys()).issubset(found_tables), (
|
assert set(expected_schema.keys()).issubset(found_tables), f"missing tables: {set(expected_schema.keys()) - found_tables}"
|
||||||
f"missing tables: {set(expected_schema.keys()) - found_tables}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# check user_version
|
# check user_version
|
||||||
uv = conn.execute(text("PRAGMA user_version")).fetchone()
|
uv = conn.execute(text("PRAGMA user_version")).fetchone()
|
||||||
@@ -103,14 +150,9 @@ def test_run_migrations_0_to_1(monkeypatch: pytest.MonkeyPatch) -> None:
|
|||||||
|
|
||||||
# validate each table columns
|
# validate each table columns
|
||||||
for tbl_name, cols in expected_schema.items():
|
for tbl_name, cols in expected_schema.items():
|
||||||
info_rows = conn.execute(
|
info_rows = conn.execute(text(f"PRAGMA table_info({tbl_name})")).fetchall()
|
||||||
text(f"PRAGMA table_info({tbl_name})")
|
|
||||||
).fetchall()
|
|
||||||
# map: name -> (type, notnull, pk)
|
# map: name -> (type, notnull, pk)
|
||||||
actual = {
|
actual = {r[1]: ((r[2] or "").upper(), int(r[3]), int(r[5])) for r in info_rows}
|
||||||
r[1]: ((r[2] or "").upper(), int(r[3]), int(r[5]))
|
|
||||||
for r in info_rows
|
|
||||||
}
|
|
||||||
for colname, (exp_type, exp_notnull, exp_pk) in cols.items():
|
for colname, (exp_type, exp_notnull, exp_pk) in cols.items():
|
||||||
assert colname in actual, f"{tbl_name}: missing column {colname}"
|
assert colname in actual, f"{tbl_name}: missing column {colname}"
|
||||||
act_type, act_notnull, act_pk = actual[colname]
|
act_type, act_notnull, act_pk = actual[colname]
|
||||||
@@ -122,22 +164,47 @@ def test_run_migrations_0_to_1(monkeypatch: pytest.MonkeyPatch) -> None:
|
|||||||
assert exp_type in act_base or act_base in exp_type, (
|
assert exp_type in act_base or act_base in exp_type, (
|
||||||
f"type mismatch {tbl_name}.{colname}: expected {exp_type}, got {act_base}"
|
f"type mismatch {tbl_name}.{colname}: expected {exp_type}, got {act_base}"
|
||||||
)
|
)
|
||||||
assert act_notnull == exp_notnull, (
|
assert act_notnull == exp_notnull, f"notnull mismatch {tbl_name}.{colname}: expected {exp_notnull}, got {act_notnull}"
|
||||||
f"notnull mismatch {tbl_name}.{colname}: expected {exp_notnull}, got {act_notnull}"
|
assert act_pk == exp_pk, f"pk mismatch {tbl_name}.{colname}: expected {exp_pk}, got {act_pk}"
|
||||||
)
|
|
||||||
assert act_pk == exp_pk, (
|
|
||||||
f"pk mismatch {tbl_name}.{colname}: expected {exp_pk}, got {act_pk}"
|
|
||||||
)
|
|
||||||
for tbl_name, fks in expected_fks.items():
|
for tbl_name, fks in expected_fks.items():
|
||||||
fk_rows = conn.execute(
|
fk_rows = conn.execute(text(f"PRAGMA foreign_key_list('{tbl_name}')")).fetchall()
|
||||||
text(f"PRAGMA foreign_key_list('{tbl_name}')")
|
|
||||||
).fetchall()
|
|
||||||
# fk_rows columns: (id, seq, table, from, to, on_update, on_delete, match)
|
# fk_rows columns: (id, seq, table, from, to, on_update, on_delete, match)
|
||||||
actual_fk_list = [
|
actual_fk_list = [{"table": r[2], "from": r[3], "to": r[4]} for r in fk_rows]
|
||||||
{"table": r[2], "from": r[3], "to": r[4]} for r in fk_rows
|
|
||||||
]
|
|
||||||
for efk in fks:
|
for efk in fks:
|
||||||
assert efk in actual_fk_list, f"missing FK on {tbl_name}: {efk}"
|
assert efk in actual_fk_list, f"missing FK on {tbl_name}: {efk}"
|
||||||
|
|
||||||
|
# check trades.replaced_by_trade_id self-referential FK
|
||||||
|
fk_rows = conn.execute(text("PRAGMA foreign_key_list('trades')")).fetchall()
|
||||||
|
actual_fk_list = [{"table": r[2], "from": r[3], "to": r[4]} for r in fk_rows]
|
||||||
|
assert {"table": "trades", "from": "replaced_by_trade_id", "to": "id"} in actual_fk_list, (
|
||||||
|
"missing self FK trades.replaced_by_trade_id -> trades.id"
|
||||||
|
)
|
||||||
|
|
||||||
|
# helper to find unique index on a column
|
||||||
|
def has_unique_index(table: str, column: str) -> bool:
|
||||||
|
idx_rows = conn.execute(text(f"PRAGMA index_list('{table}')")).fetchall()
|
||||||
|
for idx in idx_rows:
|
||||||
|
idx_name = idx[1]
|
||||||
|
is_unique = bool(idx[2])
|
||||||
|
if not is_unique:
|
||||||
|
continue
|
||||||
|
info = conn.execute(text(f"PRAGMA index_info('{idx_name}')")).fetchall()
|
||||||
|
cols = [r[2] for r in info]
|
||||||
|
if column in cols:
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
assert has_unique_index("trades", "friendly_name"), (
|
||||||
|
"expected unique index on trades(friendly_name) per uq_trades_user_friendly_name"
|
||||||
|
)
|
||||||
|
assert has_unique_index("cycles", "friendly_name"), (
|
||||||
|
"expected unique index on cycles(friendly_name) per uq_cycles_user_friendly_name"
|
||||||
|
)
|
||||||
|
assert has_unique_index("exchanges", "name"), "expected unique index on exchanges(name) per uq_exchanges_user_name"
|
||||||
|
assert has_unique_index("sessions", "session_token_hash"), "expected unique index on sessions(session_token_hash)"
|
||||||
|
assert has_unique_index("cycle_loan_change_events", "related_trade_id"), (
|
||||||
|
"expected unique index on cycle_loan_change_events(related_trade_id)"
|
||||||
|
)
|
||||||
finally:
|
finally:
|
||||||
engine.dispose()
|
engine.dispose()
|
||||||
SQLModel.metadata.clear()
|
SQLModel.metadata.clear()
|
||||||
|
|||||||
@@ -1,22 +0,0 @@
|
|||||||
import pytest
|
|
||||||
from fastapi.testclient import TestClient
|
|
||||||
|
|
||||||
from app import app
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def client():
|
|
||||||
with TestClient(app) as client:
|
|
||||||
yield client
|
|
||||||
|
|
||||||
|
|
||||||
def test_home_route(client):
|
|
||||||
response = client.get("/")
|
|
||||||
assert response.status_code == 200
|
|
||||||
assert response.json() == {"message": "Hello"}
|
|
||||||
|
|
||||||
|
|
||||||
def test_about_route(client):
|
|
||||||
response = client.get("/about")
|
|
||||||
assert response.status_code == 200
|
|
||||||
assert response.json() == {"message": "This is the about page."}
|
|
||||||
24
backend/tests/test_security.py
Normal file
24
backend/tests/test_security.py
Normal file
@@ -0,0 +1,24 @@
|
|||||||
|
from trading_journal import security
|
||||||
|
|
||||||
|
|
||||||
|
def test_hash_and_verify_password() -> None:
|
||||||
|
plain = "password"
|
||||||
|
hashed = security.hash_password(plain)
|
||||||
|
assert hashed != plain
|
||||||
|
assert security.verify_password(plain, hashed)
|
||||||
|
|
||||||
|
|
||||||
|
def test_generate_session_token() -> None:
|
||||||
|
token1 = security.generate_session_token()
|
||||||
|
token2 = security.generate_session_token()
|
||||||
|
assert token1 != token2
|
||||||
|
assert len(token1) > 0
|
||||||
|
assert len(token2) > 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_hash_and_verify_session_token_sha256() -> None:
|
||||||
|
token = security.generate_session_token()
|
||||||
|
token_hash = security.hash_session_token_sha256(token)
|
||||||
|
assert token_hash != token
|
||||||
|
assert security.verify_token_sha256(token, token_hash)
|
||||||
|
assert not security.verify_token_sha256(token + "x", token_hash)
|
||||||
1101
backend/tests/test_service.py
Normal file
1101
backend/tests/test_service.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -12,7 +12,7 @@ def test_default_settings(monkeypatch: pytest.MonkeyPatch) -> None:
|
|||||||
|
|
||||||
s = load_settings()
|
s = load_settings()
|
||||||
assert s.host == "0.0.0.0" # noqa: S104
|
assert s.host == "0.0.0.0" # noqa: S104
|
||||||
assert s.port == 8000 # noqa: PLR2004
|
assert s.port == 8000
|
||||||
assert s.workers == 1
|
assert s.workers == 1
|
||||||
assert s.log_level == "info"
|
assert s.log_level == "info"
|
||||||
|
|
||||||
@@ -26,8 +26,8 @@ def test_env_overrides(monkeypatch: pytest.MonkeyPatch) -> None:
|
|||||||
|
|
||||||
s = load_settings()
|
s = load_settings()
|
||||||
assert s.host == "127.0.0.1"
|
assert s.host == "127.0.0.1"
|
||||||
assert s.port == 9000 # noqa: PLR2004
|
assert s.port == 9000
|
||||||
assert s.workers == 3 # noqa: PLR2004
|
assert s.workers == 3
|
||||||
assert s.log_level == "debug"
|
assert s.log_level == "debug"
|
||||||
|
|
||||||
|
|
||||||
@@ -40,6 +40,6 @@ def test_yaml_config_file(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> No
|
|||||||
|
|
||||||
s = load_settings()
|
s = load_settings()
|
||||||
assert s.host == "10.0.0.5"
|
assert s.host == "10.0.0.5"
|
||||||
assert s.port == 8088 # noqa: PLR2004
|
assert s.port == 8088
|
||||||
assert s.workers == 5 # noqa: PLR2004
|
assert s.workers == 5
|
||||||
assert s.log_level == "debug"
|
assert s.log_level == "debug"
|
||||||
|
|||||||
@@ -1,13 +1,26 @@
|
|||||||
from datetime import datetime, timedelta, timezone
|
from __future__ import annotations
|
||||||
from typing import Mapping
|
|
||||||
|
|
||||||
|
from datetime import date, datetime, timedelta, timezone
|
||||||
|
from typing import TYPE_CHECKING, Any, TypeVar, cast
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
from sqlalchemy.exc import IntegrityError
|
from sqlalchemy.exc import IntegrityError
|
||||||
from sqlmodel import Session, select
|
from sqlmodel import Session, select
|
||||||
|
|
||||||
from trading_journal import models
|
from trading_journal import models
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from collections.abc import Mapping
|
||||||
|
from enum import Enum
|
||||||
|
|
||||||
def _check_enum(enum_cls, value, field_name: str):
|
from sqlalchemy.sql.elements import ColumnElement
|
||||||
|
|
||||||
|
|
||||||
|
# Generic enum member type
|
||||||
|
T = TypeVar("T", bound="Enum")
|
||||||
|
|
||||||
|
|
||||||
|
def _check_enum(enum_cls: type[T], value: object, field_name: str) -> T:
|
||||||
if value is None:
|
if value is None:
|
||||||
raise ValueError(f"{field_name} is required")
|
raise ValueError(f"{field_name} is required")
|
||||||
# already an enum member
|
# already an enum member
|
||||||
@@ -22,38 +35,56 @@ def _check_enum(enum_cls, value, field_name: str):
|
|||||||
raise ValueError(f"Invalid {field_name!s}: {value!r}. Allowed: {allowed}")
|
raise ValueError(f"Invalid {field_name!s}: {value!r}. Allowed: {allowed}")
|
||||||
|
|
||||||
|
|
||||||
|
def _allowed_columns(model: type[models.SQLModel]) -> set[str]:
|
||||||
|
tbl = cast("models.SQLModel", model).__table__ # type: ignore[attr-defined]
|
||||||
|
return {c.name for c in tbl.columns}
|
||||||
|
|
||||||
|
|
||||||
|
AnyModel = Any
|
||||||
|
|
||||||
|
|
||||||
|
def _data_to_dict(data: AnyModel) -> dict[str, AnyModel]:
|
||||||
|
if isinstance(data, BaseModel):
|
||||||
|
return data.model_dump(exclude_unset=True)
|
||||||
|
if hasattr(data, "dict"):
|
||||||
|
return data.dict(exclude_unset=True)
|
||||||
|
return dict(data)
|
||||||
|
|
||||||
|
|
||||||
# Trades
|
# Trades
|
||||||
def create_trade(session: Session, trade_data: Mapping) -> models.Trades:
|
def create_trade(session: Session, trade_data: Mapping[str, Any] | BaseModel) -> models.Trades:
|
||||||
if hasattr(trade_data, "dict"):
|
data = _data_to_dict(trade_data)
|
||||||
data = trade_data.dict(exclude_unset=True)
|
allowed = _allowed_columns(models.Trades)
|
||||||
else:
|
|
||||||
data = dict(trade_data)
|
|
||||||
allowed = {c.name for c in models.Trades.__table__.columns}
|
|
||||||
payload = {k: v for k, v in data.items() if k in allowed}
|
payload = {k: v for k, v in data.items() if k in allowed}
|
||||||
|
cycle_id = payload.get("cycle_id")
|
||||||
if "symbol" not in payload:
|
if "symbol" not in payload:
|
||||||
raise ValueError("symbol is required")
|
raise ValueError("symbol is required")
|
||||||
|
if "exchange_id" not in payload and cycle_id is None:
|
||||||
|
raise ValueError("exchange_id is required when no cycle is attached")
|
||||||
|
# If an exchange_id is provided (and no cycle is attached), ensure the exchange exists
|
||||||
|
# and belongs to the same user as the trade (if user_id is provided).
|
||||||
|
if cycle_id is None and "exchange_id" in payload:
|
||||||
|
ex = session.get(models.Exchanges, payload["exchange_id"])
|
||||||
|
if ex is None:
|
||||||
|
raise ValueError("exchange_id does not exist")
|
||||||
|
user_id = payload.get("user_id")
|
||||||
|
if user_id is not None and ex.user_id != user_id:
|
||||||
|
raise ValueError("exchange.user_id does not match trade.user_id")
|
||||||
if "underlying_currency" not in payload:
|
if "underlying_currency" not in payload:
|
||||||
raise ValueError("underlying_currency is required")
|
raise ValueError("underlying_currency is required")
|
||||||
payload["underlying_currency"] = _check_enum(
|
payload["underlying_currency"] = _check_enum(models.UnderlyingCurrency, payload["underlying_currency"], "underlying_currency")
|
||||||
models.UnderlyingCurrency, payload["underlying_currency"], "underlying_currency"
|
|
||||||
)
|
|
||||||
if "trade_type" not in payload:
|
if "trade_type" not in payload:
|
||||||
raise ValueError("trade_type is required")
|
raise ValueError("trade_type is required")
|
||||||
payload["trade_type"] = _check_enum(
|
payload["trade_type"] = _check_enum(models.TradeType, payload["trade_type"], "trade_type")
|
||||||
models.TradeType, payload["trade_type"], "trade_type"
|
|
||||||
)
|
|
||||||
if "trade_strategy" not in payload:
|
if "trade_strategy" not in payload:
|
||||||
raise ValueError("trade_strategy is required")
|
raise ValueError("trade_strategy is required")
|
||||||
payload["trade_strategy"] = _check_enum(
|
payload["trade_strategy"] = _check_enum(models.TradeStrategy, payload["trade_strategy"], "trade_strategy")
|
||||||
models.TradeStrategy, payload["trade_strategy"], "trade_strategy"
|
|
||||||
)
|
|
||||||
# trade_time_utc is the creation moment: always set to now (caller shouldn't provide)
|
# trade_time_utc is the creation moment: always set to now (caller shouldn't provide)
|
||||||
now = datetime.now(timezone.utc)
|
now = datetime.now(timezone.utc)
|
||||||
payload.pop("trade_time_utc", None)
|
payload.pop("trade_time_utc", None)
|
||||||
payload["trade_time_utc"] = now
|
payload["trade_time_utc"] = now
|
||||||
if "trade_date" not in payload or payload.get("trade_date") is None:
|
if "trade_date" not in payload or payload.get("trade_date") is None:
|
||||||
payload["trade_date"] = payload["trade_time_utc"].date()
|
payload["trade_date"] = payload["trade_time_utc"].date()
|
||||||
cycle_id = payload.get("cycle_id")
|
|
||||||
user_id = payload.get("user_id")
|
user_id = payload.get("user_id")
|
||||||
if "quantity" not in payload:
|
if "quantity" not in payload:
|
||||||
raise ValueError("quantity is required")
|
raise ValueError("quantity is required")
|
||||||
@@ -61,15 +92,10 @@ def create_trade(session: Session, trade_data: Mapping) -> models.Trades:
|
|||||||
raise ValueError("price_cents is required")
|
raise ValueError("price_cents is required")
|
||||||
if "commission_cents" not in payload:
|
if "commission_cents" not in payload:
|
||||||
payload["commission_cents"] = 0
|
payload["commission_cents"] = 0
|
||||||
quantity: int = payload["quantity"]
|
|
||||||
price_cents: int = payload["price_cents"]
|
|
||||||
commission_cents: int = payload["commission_cents"]
|
|
||||||
if "gross_cash_flow_cents" not in payload:
|
if "gross_cash_flow_cents" not in payload:
|
||||||
payload["gross_cash_flow_cents"] = -quantity * price_cents
|
raise ValueError("gross_cash_flow_cents is required")
|
||||||
if "net_cash_flow_cents" not in payload:
|
if "net_cash_flow_cents" not in payload:
|
||||||
payload["net_cash_flow_cents"] = (
|
raise ValueError("net_cash_flow_cents is required")
|
||||||
payload["gross_cash_flow_cents"] - commission_cents
|
|
||||||
)
|
|
||||||
|
|
||||||
# If no cycle_id provided, create Cycle instance but don't call create_cycle()
|
# If no cycle_id provided, create Cycle instance but don't call create_cycle()
|
||||||
created_cycle = None
|
created_cycle = None
|
||||||
@@ -77,9 +103,9 @@ def create_trade(session: Session, trade_data: Mapping) -> models.Trades:
|
|||||||
c_payload = {
|
c_payload = {
|
||||||
"user_id": user_id,
|
"user_id": user_id,
|
||||||
"symbol": payload["symbol"],
|
"symbol": payload["symbol"],
|
||||||
|
"exchange_id": payload["exchange_id"],
|
||||||
"underlying_currency": payload["underlying_currency"],
|
"underlying_currency": payload["underlying_currency"],
|
||||||
"friendly_name": "Auto-created Cycle by trade "
|
"friendly_name": "Auto-created Cycle by trade " + payload.get("friendly_name", ""),
|
||||||
+ payload.get("friendly_name", ""),
|
|
||||||
"status": models.CycleStatus.OPEN,
|
"status": models.CycleStatus.OPEN,
|
||||||
"start_date": payload["trade_date"],
|
"start_date": payload["trade_date"],
|
||||||
}
|
}
|
||||||
@@ -90,11 +116,13 @@ def create_trade(session: Session, trade_data: Mapping) -> models.Trades:
|
|||||||
# If cycle_id provided, validate existence and ownership
|
# If cycle_id provided, validate existence and ownership
|
||||||
if cycle_id is not None:
|
if cycle_id is not None:
|
||||||
cycle = session.get(models.Cycles, cycle_id)
|
cycle = session.get(models.Cycles, cycle_id)
|
||||||
|
|
||||||
if cycle is None:
|
if cycle is None:
|
||||||
raise ValueError("cycle_id does not exist")
|
raise ValueError("cycle_id does not exist")
|
||||||
else:
|
payload.pop("exchange_id", None) # ignore exchange_id if provided; use cycle's exchange_id
|
||||||
if cycle.user_id != user_id:
|
payload["exchange_id"] = cycle.exchange_id
|
||||||
raise ValueError("cycle.user_id does not match trade.user_id")
|
if cycle.user_id != user_id:
|
||||||
|
raise ValueError("cycle.user_id does not match trade.user_id")
|
||||||
|
|
||||||
# Build trade instance; if we created a Cycle instance, link via relationship so a single flush will persist both and populate ids
|
# Build trade instance; if we created a Cycle instance, link via relationship so a single flush will persist both and populate ids
|
||||||
t_payload = dict(payload)
|
t_payload = dict(payload)
|
||||||
@@ -119,9 +147,7 @@ def get_trade_by_id(session: Session, trade_id: int) -> models.Trades | None:
|
|||||||
return session.get(models.Trades, trade_id)
|
return session.get(models.Trades, trade_id)
|
||||||
|
|
||||||
|
|
||||||
def get_trade_by_user_id_and_friendly_name(
|
def get_trade_by_user_id_and_friendly_name(session: Session, user_id: int, friendly_name: str) -> models.Trades | None:
|
||||||
session: Session, user_id: int, friendly_name: str
|
|
||||||
) -> models.Trades | None:
|
|
||||||
statement = select(models.Trades).where(
|
statement = select(models.Trades).where(
|
||||||
models.Trades.user_id == user_id,
|
models.Trades.user_id == user_id,
|
||||||
models.Trades.friendly_name == friendly_name,
|
models.Trades.friendly_name == friendly_name,
|
||||||
@@ -133,7 +159,22 @@ def get_trades_by_user_id(session: Session, user_id: int) -> list[models.Trades]
|
|||||||
statement = select(models.Trades).where(
|
statement = select(models.Trades).where(
|
||||||
models.Trades.user_id == user_id,
|
models.Trades.user_id == user_id,
|
||||||
)
|
)
|
||||||
return session.exec(statement).all()
|
return list(session.exec(statement).all())
|
||||||
|
|
||||||
|
|
||||||
|
def update_trade_friendly_name(session: Session, trade_id: int, friendly_name: str) -> models.Trades:
|
||||||
|
trade: models.Trades | None = session.get(models.Trades, trade_id)
|
||||||
|
if trade is None:
|
||||||
|
raise ValueError("trade_id does not exist")
|
||||||
|
trade.friendly_name = friendly_name
|
||||||
|
session.add(trade)
|
||||||
|
try:
|
||||||
|
session.flush()
|
||||||
|
except IntegrityError as e:
|
||||||
|
session.rollback()
|
||||||
|
raise ValueError("update_trade_friendly_name integrity error") from e
|
||||||
|
session.refresh(trade)
|
||||||
|
return trade
|
||||||
|
|
||||||
|
|
||||||
def update_trade_note(session: Session, trade_id: int, note: str) -> models.Trades:
|
def update_trade_note(session: Session, trade_id: int, note: str) -> models.Trades:
|
||||||
@@ -169,36 +210,33 @@ def invalidate_trade(session: Session, trade_id: int) -> models.Trades:
|
|||||||
return trade
|
return trade
|
||||||
|
|
||||||
|
|
||||||
def replace_trade(
|
def replace_trade(session: Session, old_trade_id: int, new_trade_data: Mapping[str, Any] | BaseModel) -> models.Trades:
|
||||||
session: Session, old_trade_id: int, new_trade_data: Mapping
|
|
||||||
) -> models.Trades:
|
|
||||||
invalidate_trade(session, old_trade_id)
|
invalidate_trade(session, old_trade_id)
|
||||||
if hasattr(new_trade_data, "dict"):
|
data = _data_to_dict(new_trade_data)
|
||||||
data = new_trade_data.dict(exclude_unset=True)
|
|
||||||
else:
|
|
||||||
data = dict(new_trade_data)
|
|
||||||
data["replaced_by_trade_id"] = old_trade_id
|
data["replaced_by_trade_id"] = old_trade_id
|
||||||
new_trade = create_trade(session, data)
|
return create_trade(session, data)
|
||||||
return new_trade
|
|
||||||
|
|
||||||
|
|
||||||
# Cycles
|
# Cycles
|
||||||
def create_cycle(session: Session, cycle_data: Mapping) -> models.Cycles:
|
def create_cycle(session: Session, cycle_data: Mapping[str, Any] | BaseModel) -> models.Cycles:
|
||||||
if hasattr(cycle_data, "dict"):
|
data = _data_to_dict(cycle_data)
|
||||||
data = cycle_data.dict(exclude_unset=True)
|
allowed = _allowed_columns(models.Cycles)
|
||||||
else:
|
|
||||||
data = dict(cycle_data)
|
|
||||||
allowed = {c.name for c in models.Cycles.__table__.columns}
|
|
||||||
payload = {k: v for k, v in data.items() if k in allowed}
|
payload = {k: v for k, v in data.items() if k in allowed}
|
||||||
if "user_id" not in payload:
|
if "user_id" not in payload:
|
||||||
raise ValueError("user_id is required")
|
raise ValueError("user_id is required")
|
||||||
if "symbol" not in payload:
|
if "symbol" not in payload:
|
||||||
raise ValueError("symbol is required")
|
raise ValueError("symbol is required")
|
||||||
|
if "exchange_id" not in payload:
|
||||||
|
raise ValueError("exchange_id is required")
|
||||||
|
# ensure the exchange exists and belongs to the same user
|
||||||
|
ex = session.get(models.Exchanges, payload["exchange_id"])
|
||||||
|
if ex is None:
|
||||||
|
raise ValueError("exchange_id does not exist")
|
||||||
|
if ex.user_id != payload.get("user_id"):
|
||||||
|
raise ValueError("exchange.user_id does not match cycle.user_id")
|
||||||
if "underlying_currency" not in payload:
|
if "underlying_currency" not in payload:
|
||||||
raise ValueError("underlying_currency is required")
|
raise ValueError("underlying_currency is required")
|
||||||
payload["underlying_currency"] = _check_enum(
|
payload["underlying_currency"] = _check_enum(models.UnderlyingCurrency, payload["underlying_currency"], "underlying_currency")
|
||||||
models.UnderlyingCurrency, payload["underlying_currency"], "underlying_currency"
|
|
||||||
)
|
|
||||||
if "status" not in payload:
|
if "status" not in payload:
|
||||||
raise ValueError("status is required")
|
raise ValueError("status is required")
|
||||||
payload["status"] = _check_enum(models.CycleStatus, payload["status"], "status")
|
payload["status"] = _check_enum(models.CycleStatus, payload["status"], "status")
|
||||||
@@ -216,30 +254,44 @@ def create_cycle(session: Session, cycle_data: Mapping) -> models.Cycles:
|
|||||||
return c
|
return c
|
||||||
|
|
||||||
|
|
||||||
IMMUTABLE_CYCLE_FIELDS = {"id", "user_id", "start_date", "created_at"}
|
IMMUTABLE_CYCLE_FIELDS = {"id", "user_id", "start_date"}
|
||||||
|
|
||||||
|
|
||||||
def update_cycle(
|
def get_cycle_by_id(session: Session, cycle_id: int) -> models.Cycles | None:
|
||||||
session: Session, cycle_id: int, update_data: Mapping
|
return session.get(models.Cycles, cycle_id)
|
||||||
) -> models.Cycles:
|
|
||||||
|
|
||||||
|
def get_cycles_by_user_id(session: Session, user_id: int) -> list[models.Cycles]:
|
||||||
|
statement = select(models.Cycles).where(
|
||||||
|
models.Cycles.user_id == user_id,
|
||||||
|
)
|
||||||
|
return list(session.exec(statement).all())
|
||||||
|
|
||||||
|
|
||||||
|
def update_cycle(session: Session, cycle_id: int, update_data: Mapping[str, Any] | BaseModel) -> models.Cycles:
|
||||||
cycle: models.Cycles | None = session.get(models.Cycles, cycle_id)
|
cycle: models.Cycles | None = session.get(models.Cycles, cycle_id)
|
||||||
if cycle is None:
|
if cycle is None:
|
||||||
raise ValueError("cycle_id does not exist")
|
raise ValueError("cycle_id does not exist")
|
||||||
if hasattr(update_data, "dict"):
|
data = _data_to_dict(update_data)
|
||||||
data = update_data.dict(exclude_unset=True)
|
|
||||||
else:
|
|
||||||
data = dict(update_data)
|
|
||||||
|
|
||||||
allowed = {c.name for c in models.Cycles.__table__.columns}
|
allowed = _allowed_columns(models.Cycles)
|
||||||
for k, v in data.items():
|
for k, v in data.items():
|
||||||
if k in IMMUTABLE_CYCLE_FIELDS:
|
if k in IMMUTABLE_CYCLE_FIELDS:
|
||||||
raise ValueError(f"field {k!r} is immutable")
|
raise ValueError(f"field {k!r} is immutable")
|
||||||
if k not in allowed:
|
if k not in allowed:
|
||||||
continue
|
continue
|
||||||
|
# If trying to change exchange_id, ensure the new exchange exists and belongs to
|
||||||
|
# the same user as the cycle.
|
||||||
|
if k == "exchange_id":
|
||||||
|
ex = session.get(models.Exchanges, v)
|
||||||
|
if ex is None:
|
||||||
|
raise ValueError("exchange_id does not exist")
|
||||||
|
if ex.user_id != cycle.user_id:
|
||||||
|
raise ValueError("exchange.user_id does not match cycle.user_id")
|
||||||
if k == "underlying_currency":
|
if k == "underlying_currency":
|
||||||
v = _check_enum(models.UnderlyingCurrency, v, "underlying_currency")
|
v = _check_enum(models.UnderlyingCurrency, v, "underlying_currency") # noqa: PLW2901
|
||||||
if k == "status":
|
if k == "status":
|
||||||
v = _check_enum(models.CycleStatus, v, "status")
|
v = _check_enum(models.CycleStatus, v, "status") # noqa: PLW2901
|
||||||
setattr(cycle, k, v)
|
setattr(cycle, k, v)
|
||||||
session.add(cycle)
|
session.add(cycle)
|
||||||
try:
|
try:
|
||||||
@@ -251,16 +303,179 @@ def update_cycle(
|
|||||||
return cycle
|
return cycle
|
||||||
|
|
||||||
|
|
||||||
|
# Cycle loan and interest
|
||||||
|
def create_cycle_loan_event(session: Session, loan_data: Mapping[str, Any] | BaseModel) -> models.CycleLoanChangeEvents:
|
||||||
|
data = _data_to_dict(loan_data)
|
||||||
|
allowed = _allowed_columns(models.CycleLoanChangeEvents)
|
||||||
|
payload = {k: v for k, v in data.items() if k in allowed}
|
||||||
|
if "cycle_id" not in payload:
|
||||||
|
raise ValueError("cycle_id is required")
|
||||||
|
cycle = session.get(models.Cycles, payload["cycle_id"])
|
||||||
|
if cycle is None:
|
||||||
|
raise ValueError("cycle_id does not exist")
|
||||||
|
|
||||||
|
payload["effective_date"] = payload.get("effective_date") or datetime.now(timezone.utc).date()
|
||||||
|
payload["created_at"] = datetime.now(timezone.utc)
|
||||||
|
cle = models.CycleLoanChangeEvents(**payload)
|
||||||
|
session.add(cle)
|
||||||
|
try:
|
||||||
|
session.flush()
|
||||||
|
except IntegrityError as e:
|
||||||
|
session.rollback()
|
||||||
|
raise ValueError("create_cycle_loan_event integrity error") from e
|
||||||
|
session.refresh(cle)
|
||||||
|
return cle
|
||||||
|
|
||||||
|
|
||||||
|
def get_loan_events_by_cycle_id(session: Session, cycle_id: int) -> list[models.CycleLoanChangeEvents]:
|
||||||
|
eff_col = cast("ColumnElement", models.CycleLoanChangeEvents.effective_date)
|
||||||
|
id_col = cast("ColumnElement", models.CycleLoanChangeEvents.id)
|
||||||
|
statement = (
|
||||||
|
select(models.CycleLoanChangeEvents)
|
||||||
|
.where(
|
||||||
|
models.CycleLoanChangeEvents.cycle_id == cycle_id,
|
||||||
|
)
|
||||||
|
.order_by(eff_col, id_col.asc())
|
||||||
|
)
|
||||||
|
return list(session.exec(statement).all())
|
||||||
|
|
||||||
|
|
||||||
|
def create_cycle_daily_accrual(session: Session, cycle_id: int, accrual_date: date, accrual_amount_cents: int) -> models.CycleDailyAccrual:
|
||||||
|
cycle = session.get(models.Cycles, cycle_id)
|
||||||
|
if cycle is None:
|
||||||
|
raise ValueError("cycle_id does not exist")
|
||||||
|
existing = session.exec(
|
||||||
|
select(models.CycleDailyAccrual).where(
|
||||||
|
models.CycleDailyAccrual.cycle_id == cycle_id,
|
||||||
|
models.CycleDailyAccrual.accrual_date == accrual_date,
|
||||||
|
),
|
||||||
|
).first()
|
||||||
|
if existing:
|
||||||
|
return existing
|
||||||
|
if accrual_amount_cents < 0:
|
||||||
|
raise ValueError("accrual_amount_cents must be non-negative")
|
||||||
|
row = models.CycleDailyAccrual(
|
||||||
|
cycle_id=cycle_id,
|
||||||
|
accrual_date=accrual_date,
|
||||||
|
accrual_amount_cents=accrual_amount_cents,
|
||||||
|
created_at=datetime.now(timezone.utc),
|
||||||
|
)
|
||||||
|
session.add(row)
|
||||||
|
try:
|
||||||
|
session.flush()
|
||||||
|
except IntegrityError as e:
|
||||||
|
session.rollback()
|
||||||
|
raise ValueError("create_cycle_daily_accrual integrity error") from e
|
||||||
|
session.refresh(row)
|
||||||
|
return row
|
||||||
|
|
||||||
|
|
||||||
|
def get_cycle_daily_accruals_by_cycle_id(session: Session, cycle_id: int) -> list[models.CycleDailyAccrual]:
|
||||||
|
date_col = cast("ColumnElement", models.CycleDailyAccrual.accrual_date)
|
||||||
|
statement = (
|
||||||
|
select(models.CycleDailyAccrual)
|
||||||
|
.where(
|
||||||
|
models.CycleDailyAccrual.cycle_id == cycle_id,
|
||||||
|
)
|
||||||
|
.order_by(date_col.asc())
|
||||||
|
)
|
||||||
|
return list(session.exec(statement).all())
|
||||||
|
|
||||||
|
|
||||||
|
def get_cycle_daily_accrual_by_cycle_id_and_date(session: Session, cycle_id: int, accrual_date: date) -> models.CycleDailyAccrual | None:
|
||||||
|
statement = select(models.CycleDailyAccrual).where(
|
||||||
|
models.CycleDailyAccrual.cycle_id == cycle_id,
|
||||||
|
models.CycleDailyAccrual.accrual_date == accrual_date,
|
||||||
|
)
|
||||||
|
return session.exec(statement).first()
|
||||||
|
|
||||||
|
|
||||||
|
# Exchanges
|
||||||
|
IMMUTABLE_EXCHANGE_FIELDS = {"id"}
|
||||||
|
|
||||||
|
|
||||||
|
def create_exchange(session: Session, exchange_data: Mapping[str, Any] | BaseModel) -> models.Exchanges:
|
||||||
|
data = _data_to_dict(exchange_data)
|
||||||
|
allowed = _allowed_columns(models.Exchanges)
|
||||||
|
payload = {k: v for k, v in data.items() if k in allowed}
|
||||||
|
if "name" not in payload:
|
||||||
|
raise ValueError("name is required")
|
||||||
|
|
||||||
|
e = models.Exchanges(**payload)
|
||||||
|
session.add(e)
|
||||||
|
try:
|
||||||
|
session.flush()
|
||||||
|
except IntegrityError as e:
|
||||||
|
session.rollback()
|
||||||
|
raise ValueError("create_exchange integrity error") from e
|
||||||
|
session.refresh(e)
|
||||||
|
return e
|
||||||
|
|
||||||
|
|
||||||
|
def get_exchange_by_id(session: Session, exchange_id: int) -> models.Exchanges | None:
|
||||||
|
return session.get(models.Exchanges, exchange_id)
|
||||||
|
|
||||||
|
|
||||||
|
def get_exchange_by_name_and_user_id(session: Session, name: str, user_id: int) -> models.Exchanges | None:
|
||||||
|
statement = select(models.Exchanges).where(
|
||||||
|
models.Exchanges.name == name,
|
||||||
|
models.Exchanges.user_id == user_id,
|
||||||
|
)
|
||||||
|
return session.exec(statement).first()
|
||||||
|
|
||||||
|
|
||||||
|
def get_all_exchanges(session: Session) -> list[models.Exchanges]:
|
||||||
|
statement = select(models.Exchanges)
|
||||||
|
return list(session.exec(statement).all())
|
||||||
|
|
||||||
|
|
||||||
|
def get_all_exchanges_by_user_id(session: Session, user_id: int) -> list[models.Exchanges]:
|
||||||
|
statement = select(models.Exchanges).where(
|
||||||
|
models.Exchanges.user_id == user_id,
|
||||||
|
)
|
||||||
|
return list(session.exec(statement).all())
|
||||||
|
|
||||||
|
|
||||||
|
def update_exchange(session: Session, exchange_id: int, update_data: Mapping[str, Any] | BaseModel) -> models.Exchanges:
|
||||||
|
exchange: models.Exchanges | None = session.get(models.Exchanges, exchange_id)
|
||||||
|
if exchange is None:
|
||||||
|
raise ValueError("exchange_id does not exist")
|
||||||
|
data = _data_to_dict(update_data)
|
||||||
|
allowed = _allowed_columns(models.Exchanges)
|
||||||
|
for k, v in data.items():
|
||||||
|
if k in IMMUTABLE_EXCHANGE_FIELDS:
|
||||||
|
raise ValueError(f"field {k!r} is immutable")
|
||||||
|
if k in allowed:
|
||||||
|
setattr(exchange, k, v)
|
||||||
|
session.add(exchange)
|
||||||
|
try:
|
||||||
|
session.flush()
|
||||||
|
except IntegrityError as e:
|
||||||
|
session.rollback()
|
||||||
|
raise ValueError("update_exchange integrity error") from e
|
||||||
|
session.refresh(exchange)
|
||||||
|
return exchange
|
||||||
|
|
||||||
|
|
||||||
|
def delete_exchange(session: Session, exchange_id: int) -> None:
|
||||||
|
exchange: models.Exchanges | None = session.get(models.Exchanges, exchange_id)
|
||||||
|
if exchange is None:
|
||||||
|
return
|
||||||
|
session.delete(exchange)
|
||||||
|
try:
|
||||||
|
session.flush()
|
||||||
|
except IntegrityError as e:
|
||||||
|
session.rollback()
|
||||||
|
raise ValueError("delete_exchange integrity error") from e
|
||||||
|
|
||||||
|
|
||||||
# Users
|
# Users
|
||||||
IMMUTABLE_USER_FIELDS = {"id", "username", "created_at"}
|
IMMUTABLE_USER_FIELDS = {"id", "username", "created_at"}
|
||||||
|
|
||||||
|
|
||||||
def create_user(session: Session, user_data: Mapping) -> models.Users:
|
def create_user(session: Session, user_data: Mapping[str, Any] | BaseModel) -> models.Users:
|
||||||
if hasattr(user_data, "dict"):
|
data = _data_to_dict(user_data)
|
||||||
data = user_data.dict(exclude_unset=True)
|
allowed = _allowed_columns(models.Users)
|
||||||
else:
|
|
||||||
data = dict(user_data)
|
|
||||||
allowed = {c.name for c in models.Users.__table__.columns}
|
|
||||||
payload = {k: v for k, v in data.items() if k in allowed}
|
payload = {k: v for k, v in data.items() if k in allowed}
|
||||||
if "username" not in payload:
|
if "username" not in payload:
|
||||||
raise ValueError("username is required")
|
raise ValueError("username is required")
|
||||||
@@ -278,15 +493,23 @@ def create_user(session: Session, user_data: Mapping) -> models.Users:
|
|||||||
return u
|
return u
|
||||||
|
|
||||||
|
|
||||||
def update_user(session: Session, user_id: int, update_data: Mapping) -> models.Users:
|
def get_user_by_id(session: Session, user_id: int) -> models.Users | None:
|
||||||
|
return session.get(models.Users, user_id)
|
||||||
|
|
||||||
|
|
||||||
|
def get_user_by_username(session: Session, username: str) -> models.Users | None:
|
||||||
|
statement = select(models.Users).where(
|
||||||
|
models.Users.username == username,
|
||||||
|
)
|
||||||
|
return session.exec(statement).first()
|
||||||
|
|
||||||
|
|
||||||
|
def update_user(session: Session, user_id: int, update_data: Mapping[str, Any] | BaseModel) -> models.Users:
|
||||||
user: models.Users | None = session.get(models.Users, user_id)
|
user: models.Users | None = session.get(models.Users, user_id)
|
||||||
if user is None:
|
if user is None:
|
||||||
raise ValueError("user_id does not exist")
|
raise ValueError("user_id does not exist")
|
||||||
if hasattr(update_data, "dict"):
|
data = _data_to_dict(update_data)
|
||||||
data = update_data.dict(exclude_unset=True)
|
allowed = _allowed_columns(models.Users)
|
||||||
else:
|
|
||||||
data = dict(update_data)
|
|
||||||
allowed = {c.name for c in models.Users.__table__.columns}
|
|
||||||
for k, v in data.items():
|
for k, v in data.items():
|
||||||
if k in IMMUTABLE_USER_FIELDS:
|
if k in IMMUTABLE_USER_FIELDS:
|
||||||
raise ValueError(f"field {k!r} is immutable")
|
raise ValueError(f"field {k!r} is immutable")
|
||||||
@@ -315,10 +538,11 @@ def create_login_session(
|
|||||||
user: models.Users | None = session.get(models.Users, user_id)
|
user: models.Users | None = session.get(models.Users, user_id)
|
||||||
if user is None:
|
if user is None:
|
||||||
raise ValueError("user_id does not exist")
|
raise ValueError("user_id does not exist")
|
||||||
|
user_id_val = cast("int", user.id)
|
||||||
now = datetime.now(timezone.utc)
|
now = datetime.now(timezone.utc)
|
||||||
expires_at = now + timedelta(seconds=session_length_seconds)
|
expires_at = now + timedelta(seconds=session_length_seconds)
|
||||||
s = models.Sessions(
|
s = models.Sessions(
|
||||||
user_id=user.id,
|
user_id=user_id_val,
|
||||||
session_token_hash=session_token_hash,
|
session_token_hash=session_token_hash,
|
||||||
created_at=now,
|
created_at=now,
|
||||||
expires_at=expires_at,
|
expires_at=expires_at,
|
||||||
@@ -337,9 +561,7 @@ def create_login_session(
|
|||||||
return s
|
return s
|
||||||
|
|
||||||
|
|
||||||
def get_login_session_by_token_hash_and_user_id(
|
def get_login_session_by_token_hash_and_user_id(session: Session, session_token_hash: str, user_id: int) -> models.Sessions | None:
|
||||||
session: Session, session_token_hash: str, user_id: int
|
|
||||||
) -> models.Sessions | None:
|
|
||||||
statement = select(models.Sessions).where(
|
statement = select(models.Sessions).where(
|
||||||
models.Sessions.session_token_hash == session_token_hash,
|
models.Sessions.session_token_hash == session_token_hash,
|
||||||
models.Sessions.user_id == user_id,
|
models.Sessions.user_id == user_id,
|
||||||
@@ -349,25 +571,29 @@ def get_login_session_by_token_hash_and_user_id(
|
|||||||
return session.exec(statement).first()
|
return session.exec(statement).first()
|
||||||
|
|
||||||
|
|
||||||
|
def get_login_session_by_token_hash(session: Session, session_token_hash: str) -> models.Sessions | None:
|
||||||
|
statement = select(models.Sessions).where(
|
||||||
|
models.Sessions.session_token_hash == session_token_hash,
|
||||||
|
models.Sessions.expires_at > datetime.now(timezone.utc),
|
||||||
|
)
|
||||||
|
|
||||||
|
return session.exec(statement).first()
|
||||||
|
|
||||||
|
|
||||||
IMMUTABLE_SESSION_FIELDS = {"id", "user_id", "session_token_hash", "created_at"}
|
IMMUTABLE_SESSION_FIELDS = {"id", "user_id", "session_token_hash", "created_at"}
|
||||||
|
|
||||||
|
|
||||||
def update_login_session(
|
def update_login_session(session: Session, session_token_hashed: str, update_session: Mapping[str, Any] | BaseModel) -> models.Sessions | None:
|
||||||
session: Session, session_token_hashed: str, update_session: Mapping
|
|
||||||
) -> models.Sessions | None:
|
|
||||||
login_session: models.Sessions | None = session.exec(
|
login_session: models.Sessions | None = session.exec(
|
||||||
select(models.Sessions).where(
|
select(models.Sessions).where(
|
||||||
models.Sessions.session_token_hash == session_token_hashed,
|
models.Sessions.session_token_hash == session_token_hashed,
|
||||||
models.Sessions.expires_at > datetime.now(timezone.utc),
|
models.Sessions.expires_at > datetime.now(timezone.utc),
|
||||||
)
|
),
|
||||||
).first()
|
).first()
|
||||||
if login_session is None:
|
if login_session is None:
|
||||||
return None
|
return None
|
||||||
if hasattr(update_session, "dict"):
|
data = _data_to_dict(update_session)
|
||||||
data = update_session.dict(exclude_unset=True)
|
allowed = _allowed_columns(models.Sessions)
|
||||||
else:
|
|
||||||
data = dict(update_session)
|
|
||||||
allowed = {c.name for c in models.Sessions.__table__.columns}
|
|
||||||
for k, v in data.items():
|
for k, v in data.items():
|
||||||
if k in allowed and k not in IMMUTABLE_SESSION_FIELDS:
|
if k in allowed and k not in IMMUTABLE_SESSION_FIELDS:
|
||||||
setattr(login_session, k, v)
|
setattr(login_session, k, v)
|
||||||
@@ -385,7 +611,7 @@ def delete_login_session(session: Session, session_token_hash: str) -> None:
|
|||||||
login_session: models.Sessions | None = session.exec(
|
login_session: models.Sessions | None = session.exec(
|
||||||
select(models.Sessions).where(
|
select(models.Sessions).where(
|
||||||
models.Sessions.session_token_hash == session_token_hash,
|
models.Sessions.session_token_hash == session_token_hash,
|
||||||
)
|
),
|
||||||
).first()
|
).first()
|
||||||
if login_session is None:
|
if login_session is None:
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -1,14 +1,13 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
from contextlib import contextmanager
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from sqlalchemy import event
|
from sqlalchemy import event
|
||||||
from sqlalchemy.pool import StaticPool
|
from sqlalchemy.pool import StaticPool
|
||||||
from sqlmodel import Session, create_engine
|
from sqlmodel import Session, create_engine
|
||||||
|
|
||||||
from trading_journal import db_migration
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from collections.abc import Generator
|
from collections.abc import Generator
|
||||||
from sqlite3 import Connection as DBAPIConnection
|
from sqlite3 import Connection as DBAPIConnection
|
||||||
@@ -24,17 +23,13 @@ class Database:
|
|||||||
) -> None:
|
) -> None:
|
||||||
self._database_url = database_url or "sqlite:///:memory:"
|
self._database_url = database_url or "sqlite:///:memory:"
|
||||||
|
|
||||||
default_connect = (
|
default_connect = {"check_same_thread": False, "timeout": 30} if self._database_url.startswith("sqlite") else {}
|
||||||
{"check_same_thread": False, "timeout": 30}
|
|
||||||
if self._database_url.startswith("sqlite")
|
|
||||||
else {}
|
|
||||||
)
|
|
||||||
merged_connect = {**default_connect, **(connect_args or {})}
|
merged_connect = {**default_connect, **(connect_args or {})}
|
||||||
|
|
||||||
if self._database_url == "sqlite:///:memory:":
|
if self._database_url == "sqlite:///:memory:":
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"Using in-memory SQLite database; all data will be lost when the application stops."
|
"Using in-memory SQLite database; all data will be lost when the application stops.",
|
||||||
)
|
)
|
||||||
self._engine = create_engine(
|
self._engine = create_engine(
|
||||||
self._database_url,
|
self._database_url,
|
||||||
@@ -43,15 +38,11 @@ class Database:
|
|||||||
poolclass=StaticPool,
|
poolclass=StaticPool,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self._engine = create_engine(
|
self._engine = create_engine(self._database_url, echo=echo, connect_args=merged_connect)
|
||||||
self._database_url, echo=echo, connect_args=merged_connect
|
|
||||||
)
|
|
||||||
|
|
||||||
if self._database_url.startswith("sqlite"):
|
if self._database_url.startswith("sqlite"):
|
||||||
|
|
||||||
def _enable_sqlite_pragmas(
|
def _enable_sqlite_pragmas(dbapi_conn: DBAPIConnection, _connection_record: object) -> None:
|
||||||
dbapi_conn: DBAPIConnection, _connection_record: object
|
|
||||||
) -> None:
|
|
||||||
try:
|
try:
|
||||||
cur = dbapi_conn.cursor()
|
cur = dbapi_conn.cursor()
|
||||||
cur.execute("PRAGMA journal_mode=WAL;")
|
cur.execute("PRAGMA journal_mode=WAL;")
|
||||||
@@ -66,7 +57,7 @@ class Database:
|
|||||||
event.listen(self._engine, "connect", _enable_sqlite_pragmas)
|
event.listen(self._engine, "connect", _enable_sqlite_pragmas)
|
||||||
|
|
||||||
def init_db(self) -> None:
|
def init_db(self) -> None:
|
||||||
db_migration.run_migrations(self._engine)
|
pass
|
||||||
|
|
||||||
def get_session(self) -> Generator[Session, None, None]:
|
def get_session(self) -> Generator[Session, None, None]:
|
||||||
session = Session(self._engine)
|
session = Session(self._engine)
|
||||||
@@ -79,6 +70,18 @@ class Database:
|
|||||||
finally:
|
finally:
|
||||||
session.close()
|
session.close()
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def get_session_ctx_manager(self) -> Generator[Session, None, None]:
|
||||||
|
session = Session(self._engine)
|
||||||
|
try:
|
||||||
|
yield session
|
||||||
|
session.commit()
|
||||||
|
except Exception:
|
||||||
|
session.rollback()
|
||||||
|
raise
|
||||||
|
finally:
|
||||||
|
session.close()
|
||||||
|
|
||||||
def dispose(self) -> None:
|
def dispose(self) -> None:
|
||||||
self._engine.dispose()
|
self._engine.dispose()
|
||||||
|
|
||||||
|
|||||||
@@ -23,10 +23,13 @@ def _mig_0_1(engine: Engine) -> None:
|
|||||||
SQLModel.metadata.create_all(
|
SQLModel.metadata.create_all(
|
||||||
bind=engine,
|
bind=engine,
|
||||||
tables=[
|
tables=[
|
||||||
models_v1.Trades.__table__,
|
models_v1.Trades.__table__, # type: ignore[attr-defined]
|
||||||
models_v1.Cycles.__table__,
|
models_v1.Cycles.__table__, # type: ignore[attr-defined]
|
||||||
models_v1.Users.__table__,
|
models_v1.Users.__table__, # type: ignore[attr-defined]
|
||||||
models_v1.Sessions.__table__,
|
models_v1.Sessions.__table__, # type: ignore[attr-defined]
|
||||||
|
models_v1.Exchanges.__table__, # type: ignore[attr-defined]
|
||||||
|
models_v1.CycleLoanChangeEvents.__table__, # type: ignore[attr-defined]
|
||||||
|
models_v1.CycleDailyAccrual.__table__, # type: ignore[attr-defined]
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -60,7 +63,7 @@ def run_migrations(engine: Engine, target_version: int | None = None) -> int:
|
|||||||
fn = MIGRATIONS.get(cur_version)
|
fn = MIGRATIONS.get(cur_version)
|
||||||
if fn is None:
|
if fn is None:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
f"No migration from {cur_version} -> {cur_version + 1}"
|
f"No migration from {cur_version} -> {cur_version + 1}",
|
||||||
)
|
)
|
||||||
# call migration with Engine (fn should use transactions)
|
# call migration with Engine (fn should use transactions)
|
||||||
fn(engine)
|
fn(engine)
|
||||||
|
|||||||
136
backend/trading_journal/dto.py
Normal file
136
backend/trading_journal/dto.py
Normal file
@@ -0,0 +1,136 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from datetime import date, datetime # noqa: TC003
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from sqlmodel import SQLModel
|
||||||
|
|
||||||
|
from trading_journal.models import TradeStrategy, TradeType, UnderlyingCurrency # noqa: TC001
|
||||||
|
|
||||||
|
|
||||||
|
class UserBase(SQLModel):
|
||||||
|
username: str
|
||||||
|
is_active: bool = True
|
||||||
|
|
||||||
|
|
||||||
|
class UserCreate(UserBase):
|
||||||
|
password: str
|
||||||
|
|
||||||
|
|
||||||
|
class UserLogin(BaseModel):
|
||||||
|
username: str
|
||||||
|
password: str
|
||||||
|
|
||||||
|
|
||||||
|
class UserRead(UserBase):
|
||||||
|
id: int
|
||||||
|
|
||||||
|
|
||||||
|
class SessionsBase(SQLModel):
|
||||||
|
user_id: int
|
||||||
|
|
||||||
|
|
||||||
|
class SessionRead(SessionsBase):
|
||||||
|
id: int
|
||||||
|
expires_at: datetime
|
||||||
|
last_seen_at: datetime | None
|
||||||
|
last_used_ip: str | None
|
||||||
|
user_agent: str | None
|
||||||
|
|
||||||
|
|
||||||
|
class SessionsCreate(SessionsBase):
|
||||||
|
expires_at: datetime
|
||||||
|
|
||||||
|
|
||||||
|
class SessionsUpdate(SQLModel):
|
||||||
|
expires_at: datetime | None = None
|
||||||
|
last_seen_at: datetime | None = None
|
||||||
|
last_used_ip: str | None = None
|
||||||
|
user_agent: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class ExchangesBase(SQLModel):
|
||||||
|
name: str
|
||||||
|
notes: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class ExchangesCreate(ExchangesBase):
|
||||||
|
user_id: int
|
||||||
|
|
||||||
|
|
||||||
|
class ExchangesRead(ExchangesBase):
|
||||||
|
id: int
|
||||||
|
|
||||||
|
|
||||||
|
class CycleBase(SQLModel):
|
||||||
|
friendly_name: str | None = None
|
||||||
|
status: str
|
||||||
|
end_date: date | None = None
|
||||||
|
funding_source: str | None = None
|
||||||
|
capital_exposure_cents: int | None = None
|
||||||
|
loan_amount_cents: int | None = None
|
||||||
|
loan_interest_rate_tenth_bps: int | None = None
|
||||||
|
trades: list[TradeRead] | None = None
|
||||||
|
exchange: ExchangesRead | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class CycleCreate(CycleBase):
|
||||||
|
user_id: int
|
||||||
|
symbol: str
|
||||||
|
exchange_id: int
|
||||||
|
underlying_currency: UnderlyingCurrency
|
||||||
|
start_date: date
|
||||||
|
|
||||||
|
|
||||||
|
class CycleUpdate(CycleBase):
|
||||||
|
id: int
|
||||||
|
|
||||||
|
|
||||||
|
class CycleRead(CycleCreate):
|
||||||
|
id: int
|
||||||
|
|
||||||
|
|
||||||
|
class TradeBase(SQLModel):
|
||||||
|
friendly_name: str | None = None
|
||||||
|
symbol: str
|
||||||
|
exchange_id: int
|
||||||
|
underlying_currency: UnderlyingCurrency
|
||||||
|
trade_type: TradeType
|
||||||
|
trade_strategy: TradeStrategy
|
||||||
|
trade_date: date
|
||||||
|
quantity: int
|
||||||
|
price_cents: int
|
||||||
|
commission_cents: int
|
||||||
|
notes: str | None = None
|
||||||
|
cycle_id: int | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class TradeCreate(TradeBase):
|
||||||
|
user_id: int | None = None
|
||||||
|
trade_time_utc: datetime | None = None
|
||||||
|
gross_cash_flow_cents: int | None = None
|
||||||
|
net_cash_flow_cents: int | None = None
|
||||||
|
quantity_multiplier: int = 1
|
||||||
|
expiry_date: date | None = None
|
||||||
|
strike_price_cents: int | None = None
|
||||||
|
is_invalidated: bool = False
|
||||||
|
invalidated_at: datetime | None = None
|
||||||
|
replaced_by_trade_id: int | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class TradeNoteUpdate(BaseModel):
|
||||||
|
id: int
|
||||||
|
notes: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class TradeFriendlyNameUpdate(BaseModel):
|
||||||
|
id: int
|
||||||
|
friendly_name: str
|
||||||
|
|
||||||
|
|
||||||
|
class TradeRead(TradeCreate):
|
||||||
|
id: int
|
||||||
|
|
||||||
|
|
||||||
|
SessionsCreate.model_rebuild()
|
||||||
|
CycleBase.model_rebuild()
|
||||||
@@ -1,11 +1,13 @@
|
|||||||
from datetime import date, datetime # noqa: TC003
|
from datetime import date, datetime
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
from sqlmodel import (
|
from sqlmodel import (
|
||||||
Column,
|
Column,
|
||||||
Date,
|
Date,
|
||||||
DateTime,
|
DateTime,
|
||||||
Field,
|
Field,
|
||||||
|
ForeignKey,
|
||||||
Integer,
|
Integer,
|
||||||
Relationship,
|
Relationship,
|
||||||
SQLModel,
|
SQLModel,
|
||||||
@@ -16,8 +18,10 @@ from sqlmodel import (
|
|||||||
|
|
||||||
class TradeType(str, Enum):
|
class TradeType(str, Enum):
|
||||||
SELL_PUT = "SELL_PUT"
|
SELL_PUT = "SELL_PUT"
|
||||||
|
CLOSE_SELL_PUT = "CLOSE_SELL_PUT"
|
||||||
ASSIGNMENT = "ASSIGNMENT"
|
ASSIGNMENT = "ASSIGNMENT"
|
||||||
SELL_CALL = "SELL_CALL"
|
SELL_CALL = "SELL_CALL"
|
||||||
|
CLOSE_SELL_CALL = "CLOSE_SELL_CALL"
|
||||||
EXERCISE_CALL = "EXERCISE_CALL"
|
EXERCISE_CALL = "EXERCISE_CALL"
|
||||||
LONG_SPOT = "LONG_SPOT"
|
LONG_SPOT = "LONG_SPOT"
|
||||||
CLOSE_LONG_SPOT = "CLOSE_LONG_SPOT"
|
CLOSE_LONG_SPOT = "CLOSE_LONG_SPOT"
|
||||||
@@ -64,102 +68,132 @@ class FundingSource(str, Enum):
|
|||||||
|
|
||||||
|
|
||||||
class Trades(SQLModel, table=True):
|
class Trades(SQLModel, table=True):
|
||||||
__tablename__ = "trades"
|
__tablename__ = "trades" # type: ignore[attr-defined]
|
||||||
__table_args__ = (
|
__table_args__ = (UniqueConstraint("user_id", "friendly_name", name="uq_trades_user_friendly_name"),)
|
||||||
UniqueConstraint(
|
|
||||||
"user_id", "friendly_name", name="uq_trades_user_friendly_name"
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
id: int | None = Field(default=None, primary_key=True)
|
id: int | None = Field(default=None, primary_key=True)
|
||||||
user_id: int = Field(foreign_key="users.id", nullable=False, index=True)
|
user_id: int = Field(foreign_key="users.id", nullable=False, index=True)
|
||||||
# allow null while user may omit friendly_name; uniqueness enforced per-user by constraint
|
# allow null while user may omit friendly_name; uniqueness enforced per-user by constraint
|
||||||
friendly_name: str | None = Field(
|
friendly_name: str | None = Field(default=None, sa_column=Column(Text, nullable=True))
|
||||||
default=None, sa_column=Column(Text, nullable=True)
|
|
||||||
)
|
|
||||||
symbol: str = Field(sa_column=Column(Text, nullable=False))
|
symbol: str = Field(sa_column=Column(Text, nullable=False))
|
||||||
underlying_currency: UnderlyingCurrency = Field(
|
exchange_id: int = Field(foreign_key="exchanges.id", nullable=False, index=True)
|
||||||
sa_column=Column(Text, nullable=False)
|
exchange: "Exchanges" = Relationship(back_populates="trades")
|
||||||
)
|
underlying_currency: UnderlyingCurrency = Field(sa_column=Column(Text, nullable=False))
|
||||||
trade_type: TradeType = Field(sa_column=Column(Text, nullable=False))
|
trade_type: TradeType = Field(sa_column=Column(Text, nullable=False))
|
||||||
trade_strategy: TradeStrategy = Field(sa_column=Column(Text, nullable=False))
|
trade_strategy: TradeStrategy = Field(sa_column=Column(Text, nullable=False))
|
||||||
trade_date: date = Field(sa_column=Column(Date, nullable=False))
|
trade_date: date = Field(sa_column=Column(Date, nullable=False))
|
||||||
trade_time_utc: datetime = Field(
|
trade_time_utc: datetime = Field(sa_column=Column(DateTime(timezone=True), nullable=False))
|
||||||
sa_column=Column(DateTime(timezone=True), nullable=False)
|
|
||||||
)
|
|
||||||
expiry_date: date | None = Field(default=None, nullable=True)
|
expiry_date: date | None = Field(default=None, nullable=True)
|
||||||
strike_price_cents: int | None = Field(default=None, nullable=True)
|
strike_price_cents: int | None = Field(default=None, nullable=True)
|
||||||
quantity: int = Field(sa_column=Column(Integer, nullable=False))
|
quantity: int = Field(sa_column=Column(Integer, nullable=False))
|
||||||
|
quantity_multiplier: int = Field(sa_column=Column(Integer, nullable=False), default=1)
|
||||||
price_cents: int = Field(sa_column=Column(Integer, nullable=False))
|
price_cents: int = Field(sa_column=Column(Integer, nullable=False))
|
||||||
gross_cash_flow_cents: int = Field(sa_column=Column(Integer, nullable=False))
|
gross_cash_flow_cents: int = Field(sa_column=Column(Integer, nullable=False))
|
||||||
commission_cents: int = Field(sa_column=Column(Integer, nullable=False))
|
commission_cents: int = Field(sa_column=Column(Integer, nullable=False))
|
||||||
net_cash_flow_cents: int = Field(sa_column=Column(Integer, nullable=False))
|
net_cash_flow_cents: int = Field(sa_column=Column(Integer, nullable=False))
|
||||||
is_invalidated: bool = Field(default=False, nullable=False)
|
is_invalidated: bool = Field(default=False, nullable=False)
|
||||||
invalidated_at: datetime | None = Field(
|
invalidated_at: datetime | None = Field(default=None, sa_column=Column(DateTime(timezone=True), nullable=True))
|
||||||
default=None, sa_column=Column(DateTime(timezone=True), nullable=True)
|
replaced_by_trade_id: int | None = Field(default=None, foreign_key="trades.id", nullable=True)
|
||||||
)
|
|
||||||
replaced_by_trade_id: int | None = Field(
|
|
||||||
default=None, foreign_key="trades.id", nullable=True
|
|
||||||
)
|
|
||||||
notes: str | None = Field(default=None, sa_column=Column(Text, nullable=True))
|
notes: str | None = Field(default=None, sa_column=Column(Text, nullable=True))
|
||||||
cycle_id: int | None = Field(
|
cycle_id: int | None = Field(default=None, foreign_key="cycles.id", nullable=True, index=True)
|
||||||
default=None, foreign_key="cycles.id", nullable=True, index=True
|
|
||||||
)
|
|
||||||
cycle: "Cycles" = Relationship(back_populates="trades")
|
cycle: "Cycles" = Relationship(back_populates="trades")
|
||||||
|
|
||||||
|
related_loan_change_event: Optional["CycleLoanChangeEvents"] = Relationship(
|
||||||
|
back_populates="trade",
|
||||||
|
sa_relationship_kwargs={"uselist": False},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class Cycles(SQLModel, table=True):
|
class Cycles(SQLModel, table=True):
|
||||||
__tablename__ = "cycles"
|
__tablename__ = "cycles" # type: ignore[attr-defined]
|
||||||
__table_args__ = (
|
__table_args__ = (UniqueConstraint("user_id", "friendly_name", name="uq_cycles_user_friendly_name"),)
|
||||||
UniqueConstraint(
|
|
||||||
"user_id", "friendly_name", name="uq_cycles_user_friendly_name"
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
id: int | None = Field(default=None, primary_key=True)
|
id: int | None = Field(default=None, primary_key=True)
|
||||||
user_id: int = Field(foreign_key="users.id", nullable=False, index=True)
|
user_id: int = Field(foreign_key="users.id", nullable=False, index=True)
|
||||||
friendly_name: str | None = Field(
|
friendly_name: str | None = Field(default=None, sa_column=Column(Text, nullable=True))
|
||||||
default=None, sa_column=Column(Text, nullable=True)
|
|
||||||
)
|
|
||||||
symbol: str = Field(sa_column=Column(Text, nullable=False))
|
symbol: str = Field(sa_column=Column(Text, nullable=False))
|
||||||
underlying_currency: UnderlyingCurrency = Field(
|
exchange_id: int = Field(foreign_key="exchanges.id", nullable=False, index=True)
|
||||||
sa_column=Column(Text, nullable=False)
|
exchange: "Exchanges" = Relationship(back_populates="cycles")
|
||||||
)
|
underlying_currency: UnderlyingCurrency = Field(sa_column=Column(Text, nullable=False))
|
||||||
status: CycleStatus = Field(sa_column=Column(Text, nullable=False))
|
status: CycleStatus = Field(sa_column=Column(Text, nullable=False))
|
||||||
funding_source: FundingSource = Field(sa_column=Column(Text, nullable=True))
|
funding_source: FundingSource = Field(sa_column=Column(Text, nullable=True))
|
||||||
capital_exposure_cents: int | None = Field(default=None, nullable=True)
|
capital_exposure_cents: int | None = Field(default=None, nullable=True)
|
||||||
loan_amount_cents: int | None = Field(default=None, nullable=True)
|
|
||||||
loan_interest_rate_bps: int | None = Field(default=None, nullable=True)
|
|
||||||
start_date: date = Field(sa_column=Column(Date, nullable=False))
|
start_date: date = Field(sa_column=Column(Date, nullable=False))
|
||||||
end_date: date | None = Field(default=None, sa_column=Column(Date, nullable=True))
|
end_date: date | None = Field(default=None, sa_column=Column(Date, nullable=True))
|
||||||
|
|
||||||
trades: list["Trades"] = Relationship(back_populates="cycle")
|
trades: list["Trades"] = Relationship(back_populates="cycle")
|
||||||
|
|
||||||
|
loan_amount_cents: int | None = Field(default=None, nullable=True)
|
||||||
|
loan_interest_rate_tenth_bps: int | None = Field(default=None, nullable=True)
|
||||||
|
|
||||||
|
latest_interest_accrued_date: date | None = Field(default=None, sa_column=Column(Date, nullable=True))
|
||||||
|
total_accrued_amount_cents: int = Field(default=0, sa_column=Column(Integer, nullable=False))
|
||||||
|
|
||||||
|
loan_change_events: list["CycleLoanChangeEvents"] = Relationship(back_populates="cycle")
|
||||||
|
daily_accruals: list["CycleDailyAccrual"] = Relationship(back_populates="cycle")
|
||||||
|
|
||||||
|
|
||||||
|
class CycleLoanChangeEvents(SQLModel, table=True):
|
||||||
|
__tablename__ = "cycle_loan_change_events" # type: ignore[attr-defined]
|
||||||
|
id: int | None = Field(default=None, primary_key=True)
|
||||||
|
cycle_id: int = Field(sa_column=Column(Integer, ForeignKey("cycles.id", ondelete="CASCADE"), nullable=False, index=True))
|
||||||
|
effective_date: date = Field(sa_column=Column(Date, nullable=False))
|
||||||
|
loan_amount_cents: int | None = Field(default=None, sa_column=Column(Integer, nullable=True))
|
||||||
|
loan_interest_rate_tenth_bps: int | None = Field(default=None, sa_column=Column(Integer, nullable=True))
|
||||||
|
related_trade_id: int | None = Field(default=None, sa_column=Column(Integer, ForeignKey("trades.id"), nullable=True, unique=True))
|
||||||
|
notes: str | None = Field(default=None, sa_column=Column(Text, nullable=True))
|
||||||
|
created_at: datetime = Field(sa_column=Column(DateTime(timezone=True), nullable=False))
|
||||||
|
|
||||||
|
cycle: "Cycles" = Relationship(back_populates="loan_change_events")
|
||||||
|
trade: Optional["Trades"] = Relationship(back_populates="related_loan_change_event")
|
||||||
|
|
||||||
|
|
||||||
|
class CycleDailyAccrual(SQLModel, table=True):
|
||||||
|
__tablename__ = "cycle_daily_accrual" # type: ignore[attr-defined]
|
||||||
|
__table_args__ = (UniqueConstraint("cycle_id", "accrual_date", name="uq_cycle_daily_accruals_cycle_date"),)
|
||||||
|
|
||||||
|
id: int | None = Field(default=None, primary_key=True)
|
||||||
|
cycle_id: int = Field(sa_column=Column(Integer, ForeignKey("cycles.id", ondelete="CASCADE"), nullable=False, index=True))
|
||||||
|
accrual_date: date = Field(sa_column=Column(Date, nullable=False))
|
||||||
|
accrual_amount_cents: int = Field(sa_column=Column(Integer, nullable=False))
|
||||||
|
created_at: datetime = Field(sa_column=Column(DateTime(timezone=True), nullable=False))
|
||||||
|
|
||||||
|
cycle: "Cycles" = Relationship(back_populates="daily_accruals")
|
||||||
|
|
||||||
|
|
||||||
|
class Exchanges(SQLModel, table=True):
|
||||||
|
__tablename__ = "exchanges" # type: ignore[attr-defined]
|
||||||
|
__table_args__ = (UniqueConstraint("user_id", "name", name="uq_exchanges_user_name"),)
|
||||||
|
id: int | None = Field(default=None, primary_key=True)
|
||||||
|
user_id: int = Field(foreign_key="users.id", nullable=False, index=True)
|
||||||
|
name: str = Field(sa_column=Column(Text, nullable=False))
|
||||||
|
notes: str | None = Field(default=None, sa_column=Column(Text, nullable=True))
|
||||||
|
trades: list["Trades"] = Relationship(back_populates="exchange")
|
||||||
|
cycles: list["Cycles"] = Relationship(back_populates="exchange")
|
||||||
|
user: "Users" = Relationship(back_populates="exchanges")
|
||||||
|
|
||||||
|
|
||||||
class Users(SQLModel, table=True):
|
class Users(SQLModel, table=True):
|
||||||
__tablename__ = "users"
|
__tablename__ = "users" # type: ignore[attr-defined]
|
||||||
id: int | None = Field(default=None, primary_key=True)
|
id: int | None = Field(default=None, primary_key=True)
|
||||||
# unique=True already creates an index; no need to also set index=True
|
# unique=True already creates an index; no need to also set index=True
|
||||||
username: str = Field(sa_column=Column(Text, nullable=False, unique=True))
|
username: str = Field(sa_column=Column(Text, nullable=False, unique=True))
|
||||||
password_hash: str = Field(sa_column=Column(Text, nullable=False))
|
password_hash: str = Field(sa_column=Column(Text, nullable=False))
|
||||||
is_active: bool = Field(default=True, nullable=False)
|
is_active: bool = Field(default=True, nullable=False)
|
||||||
|
sessions: list["Sessions"] = Relationship(back_populates="user")
|
||||||
|
exchanges: list["Exchanges"] = Relationship(back_populates="user")
|
||||||
|
|
||||||
|
|
||||||
class Sessions(SQLModel, table=True):
|
class Sessions(SQLModel, table=True):
|
||||||
__tablename__ = "sessions"
|
__tablename__ = "sessions" # type: ignore[attr-defined]
|
||||||
id: int | None = Field(default=None, primary_key=True)
|
id: int | None = Field(default=None, primary_key=True)
|
||||||
user_id: int = Field(foreign_key="users.id", nullable=False, index=True)
|
user_id: int = Field(foreign_key="users.id", nullable=False, index=True)
|
||||||
session_token_hash: str = Field(sa_column=Column(Text, nullable=False, unique=True))
|
session_token_hash: str = Field(sa_column=Column(Text, nullable=False, unique=True))
|
||||||
created_at: datetime = Field(
|
created_at: datetime = Field(sa_column=Column(DateTime(timezone=True), nullable=False))
|
||||||
sa_column=Column(DateTime(timezone=True), nullable=False)
|
expires_at: datetime = Field(sa_column=Column(DateTime(timezone=True), nullable=False, index=True))
|
||||||
)
|
last_seen_at: datetime | None = Field(sa_column=Column(DateTime(timezone=True), nullable=True))
|
||||||
expires_at: datetime = Field(
|
last_used_ip: str | None = Field(default=None, sa_column=Column(Text, nullable=True))
|
||||||
sa_column=Column(DateTime(timezone=True), nullable=False, index=True)
|
|
||||||
)
|
|
||||||
last_seen_at: datetime | None = Field(
|
|
||||||
sa_column=Column(DateTime(timezone=True), nullable=True)
|
|
||||||
)
|
|
||||||
last_used_ip: str | None = Field(
|
|
||||||
default=None, sa_column=Column(Text, nullable=True)
|
|
||||||
)
|
|
||||||
user_agent: str | None = Field(default=None, sa_column=Column(Text, nullable=True))
|
user_agent: str | None = Field(default=None, sa_column=Column(Text, nullable=True))
|
||||||
device_name: str | None = Field(default=None, sa_column=Column(Text, nullable=True))
|
device_name: str | None = Field(default=None, sa_column=Column(Text, nullable=True))
|
||||||
|
user: "Users" = Relationship(back_populates="sessions")
|
||||||
|
|||||||
@@ -1,11 +1,13 @@
|
|||||||
from datetime import date, datetime # noqa: TC003
|
from datetime import date, datetime
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
from sqlmodel import (
|
from sqlmodel import (
|
||||||
Column,
|
Column,
|
||||||
Date,
|
Date,
|
||||||
DateTime,
|
DateTime,
|
||||||
Field,
|
Field,
|
||||||
|
ForeignKey,
|
||||||
Integer,
|
Integer,
|
||||||
Relationship,
|
Relationship,
|
||||||
SQLModel,
|
SQLModel,
|
||||||
@@ -16,8 +18,10 @@ from sqlmodel import (
|
|||||||
|
|
||||||
class TradeType(str, Enum):
|
class TradeType(str, Enum):
|
||||||
SELL_PUT = "SELL_PUT"
|
SELL_PUT = "SELL_PUT"
|
||||||
|
CLOSE_SELL_PUT = "CLOSE_SELL_PUT"
|
||||||
ASSIGNMENT = "ASSIGNMENT"
|
ASSIGNMENT = "ASSIGNMENT"
|
||||||
SELL_CALL = "SELL_CALL"
|
SELL_CALL = "SELL_CALL"
|
||||||
|
CLOSE_SELL_CALL = "CLOSE_SELL_CALL"
|
||||||
EXERCISE_CALL = "EXERCISE_CALL"
|
EXERCISE_CALL = "EXERCISE_CALL"
|
||||||
LONG_SPOT = "LONG_SPOT"
|
LONG_SPOT = "LONG_SPOT"
|
||||||
CLOSE_LONG_SPOT = "CLOSE_LONG_SPOT"
|
CLOSE_LONG_SPOT = "CLOSE_LONG_SPOT"
|
||||||
@@ -64,102 +68,132 @@ class FundingSource(str, Enum):
|
|||||||
|
|
||||||
|
|
||||||
class Trades(SQLModel, table=True):
|
class Trades(SQLModel, table=True):
|
||||||
__tablename__ = "trades"
|
__tablename__ = "trades" # type: ignore[attr-defined]
|
||||||
__table_args__ = (
|
__table_args__ = (UniqueConstraint("user_id", "friendly_name", name="uq_trades_user_friendly_name"),)
|
||||||
UniqueConstraint(
|
|
||||||
"user_id", "friendly_name", name="uq_trades_user_friendly_name"
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
id: int | None = Field(default=None, primary_key=True)
|
id: int | None = Field(default=None, primary_key=True)
|
||||||
user_id: int = Field(foreign_key="users.id", nullable=False, index=True)
|
user_id: int = Field(foreign_key="users.id", nullable=False, index=True)
|
||||||
# allow null while user may omit friendly_name; uniqueness enforced per-user by constraint
|
# allow null while user may omit friendly_name; uniqueness enforced per-user by constraint
|
||||||
friendly_name: str | None = Field(
|
friendly_name: str | None = Field(default=None, sa_column=Column(Text, nullable=True))
|
||||||
default=None, sa_column=Column(Text, nullable=True)
|
|
||||||
)
|
|
||||||
symbol: str = Field(sa_column=Column(Text, nullable=False))
|
symbol: str = Field(sa_column=Column(Text, nullable=False))
|
||||||
underlying_currency: UnderlyingCurrency = Field(
|
exchange_id: int = Field(foreign_key="exchanges.id", nullable=False, index=True)
|
||||||
sa_column=Column(Text, nullable=False)
|
exchange: "Exchanges" = Relationship(back_populates="trades")
|
||||||
)
|
underlying_currency: UnderlyingCurrency = Field(sa_column=Column(Text, nullable=False))
|
||||||
trade_type: TradeType = Field(sa_column=Column(Text, nullable=False))
|
trade_type: TradeType = Field(sa_column=Column(Text, nullable=False))
|
||||||
trade_strategy: TradeStrategy = Field(sa_column=Column(Text, nullable=False))
|
trade_strategy: TradeStrategy = Field(sa_column=Column(Text, nullable=False))
|
||||||
trade_date: date = Field(sa_column=Column(Date, nullable=False))
|
trade_date: date = Field(sa_column=Column(Date, nullable=False))
|
||||||
trade_time_utc: datetime = Field(
|
trade_time_utc: datetime = Field(sa_column=Column(DateTime(timezone=True), nullable=False))
|
||||||
sa_column=Column(DateTime(timezone=True), nullable=False)
|
|
||||||
)
|
|
||||||
expiry_date: date | None = Field(default=None, nullable=True)
|
expiry_date: date | None = Field(default=None, nullable=True)
|
||||||
strike_price_cents: int | None = Field(default=None, nullable=True)
|
strike_price_cents: int | None = Field(default=None, nullable=True)
|
||||||
quantity: int = Field(sa_column=Column(Integer, nullable=False))
|
quantity: int = Field(sa_column=Column(Integer, nullable=False))
|
||||||
|
quantity_multiplier: int = Field(sa_column=Column(Integer, nullable=False), default=1)
|
||||||
price_cents: int = Field(sa_column=Column(Integer, nullable=False))
|
price_cents: int = Field(sa_column=Column(Integer, nullable=False))
|
||||||
gross_cash_flow_cents: int = Field(sa_column=Column(Integer, nullable=False))
|
gross_cash_flow_cents: int = Field(sa_column=Column(Integer, nullable=False))
|
||||||
commission_cents: int = Field(sa_column=Column(Integer, nullable=False))
|
commission_cents: int = Field(sa_column=Column(Integer, nullable=False))
|
||||||
net_cash_flow_cents: int = Field(sa_column=Column(Integer, nullable=False))
|
net_cash_flow_cents: int = Field(sa_column=Column(Integer, nullable=False))
|
||||||
is_invalidated: bool = Field(default=False, nullable=False)
|
is_invalidated: bool = Field(default=False, nullable=False)
|
||||||
invalidated_at: datetime | None = Field(
|
invalidated_at: datetime | None = Field(default=None, sa_column=Column(DateTime(timezone=True), nullable=True))
|
||||||
default=None, sa_column=Column(DateTime(timezone=True), nullable=True)
|
replaced_by_trade_id: int | None = Field(default=None, foreign_key="trades.id", nullable=True)
|
||||||
)
|
|
||||||
replaced_by_trade_id: int | None = Field(
|
|
||||||
default=None, foreign_key="trades.id", nullable=True
|
|
||||||
)
|
|
||||||
notes: str | None = Field(default=None, sa_column=Column(Text, nullable=True))
|
notes: str | None = Field(default=None, sa_column=Column(Text, nullable=True))
|
||||||
cycle_id: int | None = Field(
|
cycle_id: int | None = Field(default=None, foreign_key="cycles.id", nullable=True, index=True)
|
||||||
default=None, foreign_key="cycles.id", nullable=True, index=True
|
|
||||||
)
|
|
||||||
cycle: "Cycles" = Relationship(back_populates="trades")
|
cycle: "Cycles" = Relationship(back_populates="trades")
|
||||||
|
|
||||||
|
related_loan_change_event: Optional["CycleLoanChangeEvents"] = Relationship(
|
||||||
|
back_populates="trade",
|
||||||
|
sa_relationship_kwargs={"uselist": False},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class Cycles(SQLModel, table=True):
|
class Cycles(SQLModel, table=True):
|
||||||
__tablename__ = "cycles"
|
__tablename__ = "cycles" # type: ignore[attr-defined]
|
||||||
__table_args__ = (
|
__table_args__ = (UniqueConstraint("user_id", "friendly_name", name="uq_cycles_user_friendly_name"),)
|
||||||
UniqueConstraint(
|
|
||||||
"user_id", "friendly_name", name="uq_cycles_user_friendly_name"
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
id: int | None = Field(default=None, primary_key=True)
|
id: int | None = Field(default=None, primary_key=True)
|
||||||
user_id: int = Field(foreign_key="users.id", nullable=False, index=True)
|
user_id: int = Field(foreign_key="users.id", nullable=False, index=True)
|
||||||
friendly_name: str | None = Field(
|
friendly_name: str | None = Field(default=None, sa_column=Column(Text, nullable=True))
|
||||||
default=None, sa_column=Column(Text, nullable=True)
|
|
||||||
)
|
|
||||||
symbol: str = Field(sa_column=Column(Text, nullable=False))
|
symbol: str = Field(sa_column=Column(Text, nullable=False))
|
||||||
underlying_currency: UnderlyingCurrency = Field(
|
exchange_id: int = Field(foreign_key="exchanges.id", nullable=False, index=True)
|
||||||
sa_column=Column(Text, nullable=False)
|
exchange: "Exchanges" = Relationship(back_populates="cycles")
|
||||||
)
|
underlying_currency: UnderlyingCurrency = Field(sa_column=Column(Text, nullable=False))
|
||||||
status: CycleStatus = Field(sa_column=Column(Text, nullable=False))
|
status: CycleStatus = Field(sa_column=Column(Text, nullable=False))
|
||||||
funding_source: FundingSource = Field(sa_column=Column(Text, nullable=True))
|
funding_source: FundingSource = Field(sa_column=Column(Text, nullable=True))
|
||||||
capital_exposure_cents: int | None = Field(default=None, nullable=True)
|
capital_exposure_cents: int | None = Field(default=None, nullable=True)
|
||||||
loan_amount_cents: int | None = Field(default=None, nullable=True)
|
|
||||||
loan_interest_rate_bps: int | None = Field(default=None, nullable=True)
|
|
||||||
start_date: date = Field(sa_column=Column(Date, nullable=False))
|
start_date: date = Field(sa_column=Column(Date, nullable=False))
|
||||||
end_date: date | None = Field(default=None, sa_column=Column(Date, nullable=True))
|
end_date: date | None = Field(default=None, sa_column=Column(Date, nullable=True))
|
||||||
|
|
||||||
trades: list["Trades"] = Relationship(back_populates="cycle")
|
trades: list["Trades"] = Relationship(back_populates="cycle")
|
||||||
|
|
||||||
|
loan_amount_cents: int | None = Field(default=None, nullable=True)
|
||||||
|
loan_interest_rate_tenth_bps: int | None = Field(default=None, nullable=True)
|
||||||
|
|
||||||
|
latest_interest_accrued_date: date | None = Field(default=None, sa_column=Column(Date, nullable=True))
|
||||||
|
total_accrued_amount_cents: int = Field(default=0, sa_column=Column(Integer, nullable=False))
|
||||||
|
|
||||||
|
loan_change_events: list["CycleLoanChangeEvents"] = Relationship(back_populates="cycle")
|
||||||
|
daily_accruals: list["CycleDailyAccrual"] = Relationship(back_populates="cycle")
|
||||||
|
|
||||||
|
|
||||||
|
class CycleLoanChangeEvents(SQLModel, table=True):
|
||||||
|
__tablename__ = "cycle_loan_change_events" # type: ignore[attr-defined]
|
||||||
|
id: int | None = Field(default=None, primary_key=True)
|
||||||
|
cycle_id: int = Field(sa_column=Column(Integer, ForeignKey("cycles.id", ondelete="CASCADE"), nullable=False, index=True))
|
||||||
|
effective_date: date = Field(sa_column=Column(Date, nullable=False))
|
||||||
|
loan_amount_cents: int | None = Field(default=None, sa_column=Column(Integer, nullable=True))
|
||||||
|
loan_interest_rate_tenth_bps: int | None = Field(default=None, sa_column=Column(Integer, nullable=True))
|
||||||
|
related_trade_id: int | None = Field(default=None, sa_column=Column(Integer, ForeignKey("trades.id"), nullable=True, unique=True))
|
||||||
|
notes: str | None = Field(default=None, sa_column=Column(Text, nullable=True))
|
||||||
|
created_at: datetime = Field(sa_column=Column(DateTime(timezone=True), nullable=False))
|
||||||
|
|
||||||
|
cycle: "Cycles" = Relationship(back_populates="loan_change_events")
|
||||||
|
trade: Optional["Trades"] = Relationship(back_populates="related_loan_change_event")
|
||||||
|
|
||||||
|
|
||||||
|
class CycleDailyAccrual(SQLModel, table=True):
|
||||||
|
__tablename__ = "cycle_daily_accrual" # type: ignore[attr-defined]
|
||||||
|
__table_args__ = (UniqueConstraint("cycle_id", "accrual_date", name="uq_cycle_daily_accruals_cycle_date"),)
|
||||||
|
|
||||||
|
id: int | None = Field(default=None, primary_key=True)
|
||||||
|
cycle_id: int = Field(sa_column=Column(Integer, ForeignKey("cycles.id", ondelete="CASCADE"), nullable=False, index=True))
|
||||||
|
accrual_date: date = Field(sa_column=Column(Date, nullable=False))
|
||||||
|
accrual_amount_cents: int = Field(sa_column=Column(Integer, nullable=False))
|
||||||
|
created_at: datetime = Field(sa_column=Column(DateTime(timezone=True), nullable=False))
|
||||||
|
|
||||||
|
cycle: "Cycles" = Relationship(back_populates="daily_accruals")
|
||||||
|
|
||||||
|
|
||||||
|
class Exchanges(SQLModel, table=True):
|
||||||
|
__tablename__ = "exchanges" # type: ignore[attr-defined]
|
||||||
|
__table_args__ = (UniqueConstraint("user_id", "name", name="uq_exchanges_user_name"),)
|
||||||
|
id: int | None = Field(default=None, primary_key=True)
|
||||||
|
user_id: int = Field(foreign_key="users.id", nullable=False, index=True)
|
||||||
|
name: str = Field(sa_column=Column(Text, nullable=False))
|
||||||
|
notes: str | None = Field(default=None, sa_column=Column(Text, nullable=True))
|
||||||
|
trades: list["Trades"] = Relationship(back_populates="exchange")
|
||||||
|
cycles: list["Cycles"] = Relationship(back_populates="exchange")
|
||||||
|
user: "Users" = Relationship(back_populates="exchanges")
|
||||||
|
|
||||||
|
|
||||||
class Users(SQLModel, table=True):
|
class Users(SQLModel, table=True):
|
||||||
__tablename__ = "users"
|
__tablename__ = "users" # type: ignore[attr-defined]
|
||||||
id: int | None = Field(default=None, primary_key=True)
|
id: int | None = Field(default=None, primary_key=True)
|
||||||
# unique=True already creates an index; no need to also set index=True
|
# unique=True already creates an index; no need to also set index=True
|
||||||
username: str = Field(sa_column=Column(Text, nullable=False, unique=True))
|
username: str = Field(sa_column=Column(Text, nullable=False, unique=True))
|
||||||
password_hash: str = Field(sa_column=Column(Text, nullable=False))
|
password_hash: str = Field(sa_column=Column(Text, nullable=False))
|
||||||
is_active: bool = Field(default=True, nullable=False)
|
is_active: bool = Field(default=True, nullable=False)
|
||||||
|
sessions: list["Sessions"] = Relationship(back_populates="user")
|
||||||
|
exchanges: list["Exchanges"] = Relationship(back_populates="user")
|
||||||
|
|
||||||
|
|
||||||
class Sessions(SQLModel, table=True):
|
class Sessions(SQLModel, table=True):
|
||||||
__tablename__ = "sessions"
|
__tablename__ = "sessions" # type: ignore[attr-defined]
|
||||||
id: int | None = Field(default=None, primary_key=True)
|
id: int | None = Field(default=None, primary_key=True)
|
||||||
user_id: int = Field(foreign_key="users.id", nullable=False, index=True)
|
user_id: int = Field(foreign_key="users.id", nullable=False, index=True)
|
||||||
session_token_hash: str = Field(sa_column=Column(Text, nullable=False, unique=True))
|
session_token_hash: str = Field(sa_column=Column(Text, nullable=False, unique=True))
|
||||||
created_at: datetime = Field(
|
created_at: datetime = Field(sa_column=Column(DateTime(timezone=True), nullable=False))
|
||||||
sa_column=Column(DateTime(timezone=True), nullable=False)
|
expires_at: datetime = Field(sa_column=Column(DateTime(timezone=True), nullable=False, index=True))
|
||||||
)
|
last_seen_at: datetime | None = Field(sa_column=Column(DateTime(timezone=True), nullable=True))
|
||||||
expires_at: datetime = Field(
|
last_used_ip: str | None = Field(default=None, sa_column=Column(Text, nullable=True))
|
||||||
sa_column=Column(DateTime(timezone=True), nullable=False, index=True)
|
|
||||||
)
|
|
||||||
last_seen_at: datetime | None = Field(
|
|
||||||
sa_column=Column(DateTime(timezone=True), nullable=True)
|
|
||||||
)
|
|
||||||
last_used_ip: str | None = Field(
|
|
||||||
default=None, sa_column=Column(Text, nullable=True)
|
|
||||||
)
|
|
||||||
user_agent: str | None = Field(default=None, sa_column=Column(Text, nullable=True))
|
user_agent: str | None = Field(default=None, sa_column=Column(Text, nullable=True))
|
||||||
device_name: str | None = Field(default=None, sa_column=Column(Text, nullable=True))
|
device_name: str | None = Field(default=None, sa_column=Column(Text, nullable=True))
|
||||||
|
user: "Users" = Relationship(back_populates="sessions")
|
||||||
|
|||||||
51
backend/trading_journal/security.py
Normal file
51
backend/trading_journal/security.py
Normal file
@@ -0,0 +1,51 @@
|
|||||||
|
import hashlib
|
||||||
|
import hmac
|
||||||
|
import secrets
|
||||||
|
|
||||||
|
from argon2 import PasswordHasher
|
||||||
|
from argon2.exceptions import VerifyMismatchError
|
||||||
|
|
||||||
|
import settings
|
||||||
|
|
||||||
|
ph = PasswordHasher()
|
||||||
|
|
||||||
|
# Utility functions for password hashing and verification
|
||||||
|
|
||||||
|
|
||||||
|
def hash_password(plain: str) -> str:
|
||||||
|
return ph.hash(plain)
|
||||||
|
|
||||||
|
|
||||||
|
def verify_password(plain: str, hashed: str) -> bool:
|
||||||
|
try:
|
||||||
|
return ph.verify(hashed, plain)
|
||||||
|
except VerifyMismatchError:
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
# Session token hash
|
||||||
|
|
||||||
|
|
||||||
|
def generate_session_token(nbytes: int = 32) -> str:
|
||||||
|
return secrets.token_urlsafe(nbytes)
|
||||||
|
|
||||||
|
|
||||||
|
def hash_session_token_sha256(token: str) -> str:
|
||||||
|
return hashlib.sha256(token.encode("utf-8")).hexdigest()
|
||||||
|
|
||||||
|
|
||||||
|
def sign_token_hmac(token: str) -> str:
|
||||||
|
if not settings.settings.hmac_key:
|
||||||
|
return token
|
||||||
|
return hmac.new(settings.settings.hmac_key.encode("utf-8"), token.encode("utf-8"), hashlib.sha256).hexdigest()
|
||||||
|
|
||||||
|
|
||||||
|
def verify_token_sha256(token: str, expected_hash: str) -> bool:
|
||||||
|
return hmac.compare_digest(hash_session_token_sha256(token), expected_hash)
|
||||||
|
|
||||||
|
|
||||||
|
def verify_token_hmac(token: str, expected_hmac: str) -> bool:
|
||||||
|
if not settings.settings.hmac_key:
|
||||||
|
return verify_token_sha256(token, expected_hmac)
|
||||||
|
sig = hmac.new(settings.settings.hmac_key.encode("utf-8"), token.encode("utf-8"), hashlib.sha256).hexdigest()
|
||||||
|
return hmac.compare_digest(sig, expected_hmac)
|
||||||
364
backend/trading_journal/service.py
Normal file
364
backend/trading_journal/service.py
Normal file
@@ -0,0 +1,364 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from datetime import datetime, timedelta, timezone
|
||||||
|
from typing import TYPE_CHECKING, cast
|
||||||
|
|
||||||
|
from fastapi import Request, Response, status
|
||||||
|
from fastapi.responses import JSONResponse
|
||||||
|
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
|
||||||
|
|
||||||
|
import settings
|
||||||
|
from trading_journal import crud, security
|
||||||
|
from trading_journal.dto import (
|
||||||
|
CycleBase,
|
||||||
|
CycleCreate,
|
||||||
|
CycleRead,
|
||||||
|
CycleUpdate,
|
||||||
|
ExchangesBase,
|
||||||
|
ExchangesCreate,
|
||||||
|
ExchangesRead,
|
||||||
|
SessionsCreate,
|
||||||
|
SessionsUpdate,
|
||||||
|
TradeCreate,
|
||||||
|
TradeRead,
|
||||||
|
UserCreate,
|
||||||
|
UserLogin,
|
||||||
|
UserRead,
|
||||||
|
)
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from sqlmodel import Session
|
||||||
|
|
||||||
|
from trading_journal.db import Database
|
||||||
|
from trading_journal.models import Sessions
|
||||||
|
|
||||||
|
|
||||||
|
EXCEPT_PATHS = [
|
||||||
|
f"{settings.settings.api_base}/status",
|
||||||
|
f"{settings.settings.api_base}/register",
|
||||||
|
f"{settings.settings.api_base}/login",
|
||||||
|
]
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class AuthMiddleWare(BaseHTTPMiddleware):
|
||||||
|
async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response: # noqa: PLR0911
|
||||||
|
if request.url.path in EXCEPT_PATHS:
|
||||||
|
return await call_next(request)
|
||||||
|
|
||||||
|
token = request.cookies.get("session_token")
|
||||||
|
if not token:
|
||||||
|
auth_header = request.headers.get("Authorization")
|
||||||
|
if auth_header and auth_header.startswith("Bearer "):
|
||||||
|
token = auth_header[len("Bearer ") :]
|
||||||
|
|
||||||
|
if not token:
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
content={"detail": "Unauthorized"},
|
||||||
|
)
|
||||||
|
|
||||||
|
db_factory: Database | None = getattr(request.app.state, "db_factory", None)
|
||||||
|
if db_factory is None:
|
||||||
|
return JSONResponse(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, content={"detail": "db factory not configured"})
|
||||||
|
try:
|
||||||
|
with db_factory.get_session_ctx_manager() as request_session:
|
||||||
|
hashed_token = security.hash_session_token_sha256(token)
|
||||||
|
request.state.db_session = request_session
|
||||||
|
login_session: Sessions | None = crud.get_login_session_by_token_hash(request_session, hashed_token)
|
||||||
|
if not login_session:
|
||||||
|
return JSONResponse(status_code=status.HTTP_401_UNAUTHORIZED, content={"detail": "Unauthorized"})
|
||||||
|
session_expires_utc = login_session.expires_at.replace(tzinfo=timezone.utc)
|
||||||
|
if session_expires_utc < datetime.now(timezone.utc):
|
||||||
|
crud.delete_login_session(request_session, login_session.session_token_hash)
|
||||||
|
return JSONResponse(status_code=status.HTTP_401_UNAUTHORIZED, content={"detail": "Unauthorized"})
|
||||||
|
if login_session.user.is_active is False:
|
||||||
|
return JSONResponse(status_code=status.HTTP_401_UNAUTHORIZED, content={"detail": "Unauthorized"})
|
||||||
|
if session_expires_utc - datetime.now(timezone.utc) < timedelta(seconds=3600):
|
||||||
|
updated_expiry = datetime.now(timezone.utc) + timedelta(seconds=settings.settings.session_expiry_seconds)
|
||||||
|
else:
|
||||||
|
updated_expiry = session_expires_utc
|
||||||
|
updated_session: SessionsUpdate = SessionsUpdate(
|
||||||
|
last_seen_at=datetime.now(timezone.utc),
|
||||||
|
last_used_ip=request.client.host if request.client else None,
|
||||||
|
user_agent=request.headers.get("User-Agent"),
|
||||||
|
expires_at=updated_expiry,
|
||||||
|
)
|
||||||
|
user_id = login_session.user_id
|
||||||
|
request.state.user_id = user_id
|
||||||
|
crud.update_login_session(request_session, hashed_token, update_session=updated_session)
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Failed to authenticate user: \n")
|
||||||
|
return JSONResponse(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, content={"detail": "Internal server error"})
|
||||||
|
|
||||||
|
return await call_next(request)
|
||||||
|
|
||||||
|
|
||||||
|
class ServiceError(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class UserAlreadyExistsError(ServiceError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class ExchangeAlreadyExistsError(ServiceError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class ExchangeNotFoundError(ServiceError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class CycleNotFoundError(ServiceError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class TradeNotFoundError(ServiceError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class InvalidTradeDataError(ServiceError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class InvalidCycleDataError(ServiceError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
# User service
|
||||||
|
def register_user_service(db_session: Session, user_in: UserCreate) -> UserRead:
|
||||||
|
if crud.get_user_by_username(db_session, user_in.username):
|
||||||
|
raise UserAlreadyExistsError("username already exists")
|
||||||
|
hashed = security.hash_password(user_in.password)
|
||||||
|
user_data: dict = {
|
||||||
|
"username": user_in.username,
|
||||||
|
"password_hash": hashed,
|
||||||
|
}
|
||||||
|
try:
|
||||||
|
user = crud.create_user(db_session, user_data=user_data)
|
||||||
|
try:
|
||||||
|
# prefer pydantic's from_orm if DTO supports orm_mode
|
||||||
|
user = UserRead.model_validate(user)
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception("Failed to convert user to UserRead: ")
|
||||||
|
raise ServiceError("Failed to convert user to UserRead") from e
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception("Failed to create user:")
|
||||||
|
raise ServiceError("Failed to create user") from e
|
||||||
|
return user
|
||||||
|
|
||||||
|
|
||||||
|
def authenticate_user_service(db_session: Session, user_in: UserLogin) -> tuple[SessionsCreate, str] | None:
|
||||||
|
user = crud.get_user_by_username(db_session, user_in.username)
|
||||||
|
if not user:
|
||||||
|
return None
|
||||||
|
user_id_val = cast("int", user.id)
|
||||||
|
|
||||||
|
if not security.verify_password(user_in.password, user.password_hash):
|
||||||
|
return None
|
||||||
|
|
||||||
|
token = security.generate_session_token()
|
||||||
|
token_hashed = security.hash_session_token_sha256(token)
|
||||||
|
try:
|
||||||
|
session = crud.create_login_session(
|
||||||
|
session=db_session,
|
||||||
|
user_id=user_id_val,
|
||||||
|
session_token_hash=token_hashed,
|
||||||
|
session_length_seconds=settings.settings.session_expiry_seconds,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception("Failed to create login session: \n")
|
||||||
|
raise ServiceError("Failed to create login session") from e
|
||||||
|
return SessionsCreate.model_validate(session), token
|
||||||
|
|
||||||
|
|
||||||
|
# Exchanges service
|
||||||
|
def create_exchange_service(db_session: Session, user_id: int, name: str, notes: str | None) -> ExchangesCreate:
|
||||||
|
existing_exchange = crud.get_exchange_by_name_and_user_id(db_session, name, user_id)
|
||||||
|
if existing_exchange:
|
||||||
|
raise ExchangeAlreadyExistsError("Exchange with the same name already exists for this user")
|
||||||
|
exchange_data = ExchangesCreate(
|
||||||
|
user_id=user_id,
|
||||||
|
name=name,
|
||||||
|
notes=notes,
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
exchange = crud.create_exchange(db_session, exchange_data=exchange_data)
|
||||||
|
try:
|
||||||
|
exchange_dto = ExchangesCreate.model_validate(exchange)
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception("Failed to convert exchange to ExchangesCreate:")
|
||||||
|
raise ServiceError("Failed to convert exchange to ExchangesCreate") from e
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception("Failed to create exchange:")
|
||||||
|
raise ServiceError("Failed to create exchange") from e
|
||||||
|
return exchange_dto
|
||||||
|
|
||||||
|
|
||||||
|
def get_exchanges_by_user_service(db_session: Session, user_id: int) -> list[ExchangesRead]:
|
||||||
|
exchanges = crud.get_all_exchanges_by_user_id(db_session, user_id)
|
||||||
|
return [ExchangesRead.model_validate(exchange) for exchange in exchanges]
|
||||||
|
|
||||||
|
|
||||||
|
def update_exchanges_service(db_session: Session, user_id: int, exchange_id: int, name: str | None, notes: str | None) -> ExchangesBase:
|
||||||
|
existing_exchange = crud.get_exchange_by_id(db_session, exchange_id)
|
||||||
|
if not existing_exchange:
|
||||||
|
raise ExchangeNotFoundError("Exchange not found")
|
||||||
|
if existing_exchange.user_id != user_id:
|
||||||
|
raise ExchangeNotFoundError("Exchange not found")
|
||||||
|
|
||||||
|
if name:
|
||||||
|
other_exchange = crud.get_exchange_by_name_and_user_id(db_session, name, user_id)
|
||||||
|
if other_exchange and other_exchange.id != existing_exchange.id:
|
||||||
|
raise ExchangeAlreadyExistsError("Another exchange with the same name already exists for this user")
|
||||||
|
|
||||||
|
exchange_data = ExchangesBase(
|
||||||
|
name=name or existing_exchange.name,
|
||||||
|
notes=notes or existing_exchange.notes,
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
exchange = crud.update_exchange(db_session, cast("int", existing_exchange.id), update_data=exchange_data)
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception("Failed to update exchange: \n")
|
||||||
|
raise ServiceError("Failed to update exchange") from e
|
||||||
|
return ExchangesBase.model_validate(exchange)
|
||||||
|
|
||||||
|
|
||||||
|
# Cycle Service
|
||||||
|
def create_cycle_service(db_session: Session, user_id: int, cycle_data: CycleBase) -> CycleRead:
|
||||||
|
raise NotImplementedError("Cycle creation not implemented")
|
||||||
|
cycle_data_dict = cycle_data.model_dump()
|
||||||
|
cycle_data_dict["user_id"] = user_id
|
||||||
|
cycle_data_with_user_id: CycleCreate = CycleCreate.model_validate(cycle_data_dict)
|
||||||
|
created_cycle = crud.create_cycle(db_session, cycle_data=cycle_data_with_user_id)
|
||||||
|
return CycleRead.model_validate(created_cycle)
|
||||||
|
|
||||||
|
|
||||||
|
def get_cycle_by_id_service(db_session: Session, user_id: int, cycle_id: int) -> CycleRead:
|
||||||
|
cycle = crud.get_cycle_by_id(db_session, cycle_id)
|
||||||
|
if not cycle:
|
||||||
|
raise CycleNotFoundError("Cycle not found")
|
||||||
|
if cycle.user_id != user_id:
|
||||||
|
raise CycleNotFoundError("Cycle not found")
|
||||||
|
return CycleRead.model_validate(cycle)
|
||||||
|
|
||||||
|
|
||||||
|
def get_cycles_by_user_service(db_session: Session, user_id: int) -> list[CycleRead]:
|
||||||
|
cycles = crud.get_cycles_by_user_id(db_session, user_id)
|
||||||
|
return [CycleRead.model_validate(cycle) for cycle in cycles]
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_cycle_update_data(cycle_data: CycleUpdate) -> tuple[bool, str]: # noqa: PLR0911
|
||||||
|
if cycle_data.status == "CLOSED" and cycle_data.end_date is None:
|
||||||
|
return False, "end_date is required when status is CLOSED"
|
||||||
|
if cycle_data.status == "OPEN" and cycle_data.end_date is not None:
|
||||||
|
return False, "end_date must be empty when status is OPEN"
|
||||||
|
if cycle_data.capital_exposure_cents is not None and cycle_data.capital_exposure_cents < 0:
|
||||||
|
return False, "capital_exposure_cents must be non-negative"
|
||||||
|
if (
|
||||||
|
cycle_data.funding_source is not None
|
||||||
|
and cycle_data.funding_source != "CASH"
|
||||||
|
and (cycle_data.loan_amount_cents is None or cycle_data.loan_interest_rate_tenth_bps is None)
|
||||||
|
):
|
||||||
|
return False, "loan_amount_cents and loan_interest_rate_tenth_bps are required when funding_source is not CASH"
|
||||||
|
if cycle_data.loan_amount_cents is not None and cycle_data.loan_amount_cents < 0:
|
||||||
|
return False, "loan_amount_cents must be non-negative"
|
||||||
|
if cycle_data.loan_interest_rate_tenth_bps is not None and cycle_data.loan_interest_rate_tenth_bps < 0:
|
||||||
|
return False, "loan_interest_rate_tenth_bps must be non-negative"
|
||||||
|
return True, ""
|
||||||
|
|
||||||
|
|
||||||
|
def update_cycle_service(db_session: Session, user_id: int, cycle_data: CycleUpdate) -> CycleRead:
|
||||||
|
is_valid, err_msg = _validate_cycle_update_data(cycle_data)
|
||||||
|
if not is_valid:
|
||||||
|
raise InvalidCycleDataError(err_msg)
|
||||||
|
cycle_id = cast("int", cycle_data.id)
|
||||||
|
existing_cycle = crud.get_cycle_by_id(db_session, cycle_id)
|
||||||
|
if not existing_cycle:
|
||||||
|
raise CycleNotFoundError("Cycle not found")
|
||||||
|
if existing_cycle.user_id != user_id:
|
||||||
|
raise CycleNotFoundError("Cycle not found")
|
||||||
|
|
||||||
|
provided_data_dict = cycle_data.model_dump(exclude_unset=True)
|
||||||
|
cycle_data_with_user_id: CycleBase = CycleBase.model_validate(provided_data_dict)
|
||||||
|
|
||||||
|
try:
|
||||||
|
updated_cycle = crud.update_cycle(db_session, cycle_id, update_data=cycle_data_with_user_id)
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception("Failed to update cycle: \n")
|
||||||
|
raise ServiceError("Failed to update cycle") from e
|
||||||
|
return CycleRead.model_validate(updated_cycle)
|
||||||
|
|
||||||
|
|
||||||
|
# Trades service
|
||||||
|
def _append_cashflows(trade_data: TradeCreate) -> TradeCreate:
|
||||||
|
sign_multipler: int
|
||||||
|
if trade_data.trade_type in ("SELL_PUT", "SELL_CALL", "EXERCISE_CALL", "CLOSE_LONG_SPOT", "SHORT_SPOT"):
|
||||||
|
sign_multipler = 1
|
||||||
|
else:
|
||||||
|
sign_multipler = -1
|
||||||
|
quantity = trade_data.quantity * trade_data.quantity_multiplier
|
||||||
|
gross_cash_flow_cents = quantity * trade_data.price_cents * sign_multipler
|
||||||
|
net_cash_flow_cents = gross_cash_flow_cents - trade_data.commission_cents
|
||||||
|
trade_data.gross_cash_flow_cents = gross_cash_flow_cents
|
||||||
|
trade_data.net_cash_flow_cents = net_cash_flow_cents
|
||||||
|
return trade_data
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_trade_data(trade_data: TradeCreate) -> bool:
|
||||||
|
return not (
|
||||||
|
trade_data.trade_type in ("SELL_PUT", "SELL_CALL") and (trade_data.expiry_date is None or trade_data.strike_price_cents is None)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def create_trade_service(db_session: Session, user_id: int, trade_data: TradeCreate) -> TradeRead:
|
||||||
|
if not _validate_trade_data(trade_data):
|
||||||
|
raise InvalidTradeDataError("Invalid trade data: expiry_date and strike_price_cents are required for SELL_PUT and SELL_CALL trades")
|
||||||
|
trade_data_dict = trade_data.model_dump()
|
||||||
|
trade_data_dict["user_id"] = user_id
|
||||||
|
trade_data_with_user_id: TradeCreate = TradeCreate.model_validate(trade_data_dict)
|
||||||
|
trade_data_with_user_id = _append_cashflows(trade_data_with_user_id)
|
||||||
|
created_trade = crud.create_trade(db_session, trade_data=trade_data_with_user_id)
|
||||||
|
return TradeRead.model_validate(created_trade)
|
||||||
|
|
||||||
|
|
||||||
|
def get_trade_by_id_service(db_session: Session, user_id: int, trade_id: int) -> TradeRead:
|
||||||
|
trade = crud.get_trade_by_id(db_session, trade_id)
|
||||||
|
if not trade:
|
||||||
|
raise TradeNotFoundError("Trade not found")
|
||||||
|
if trade.user_id != user_id:
|
||||||
|
raise TradeNotFoundError("Trade not found")
|
||||||
|
return TradeRead.model_validate(trade)
|
||||||
|
|
||||||
|
|
||||||
|
def update_trade_friendly_name_service(db_session: Session, user_id: int, trade_id: int, friendly_name: str) -> TradeRead:
|
||||||
|
existing_trade = crud.get_trade_by_id(db_session, trade_id)
|
||||||
|
if not existing_trade:
|
||||||
|
raise TradeNotFoundError("Trade not found")
|
||||||
|
if existing_trade.user_id != user_id:
|
||||||
|
raise TradeNotFoundError("Trade not found")
|
||||||
|
try:
|
||||||
|
updated_trade = crud.update_trade_friendly_name(db_session, trade_id, friendly_name)
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception("Failed to update trade friendly name: \n")
|
||||||
|
raise ServiceError("Failed to update trade friendly name") from e
|
||||||
|
return TradeRead.model_validate(updated_trade)
|
||||||
|
|
||||||
|
|
||||||
|
def update_trade_note_service(db_session: Session, user_id: int, trade_id: int, note: str | None) -> TradeRead:
|
||||||
|
existing_trade = crud.get_trade_by_id(db_session, trade_id)
|
||||||
|
if not existing_trade:
|
||||||
|
raise TradeNotFoundError("Trade not found")
|
||||||
|
if existing_trade.user_id != user_id:
|
||||||
|
raise TradeNotFoundError("Trade not found")
|
||||||
|
if note is None:
|
||||||
|
note = ""
|
||||||
|
try:
|
||||||
|
updated_trade = crud.update_trade_note(db_session, trade_id, note)
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception("Failed to update trade notes: \n")
|
||||||
|
raise ServiceError("Failed to update trade notes") from e
|
||||||
|
return TradeRead.model_validate(updated_trade)
|
||||||
0
backend/utils/__init__.py
Normal file
0
backend/utils/__init__.py
Normal file
7
backend/utils/db_migration.py
Normal file
7
backend/utils/db_migration.py
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
from sqlmodel import create_engine
|
||||||
|
|
||||||
|
import settings
|
||||||
|
from trading_journal import db_migration
|
||||||
|
|
||||||
|
db_engine = create_engine(settings.settings.database_url, echo=True)
|
||||||
|
db_migration.run_migrations(db_engine)
|
||||||
Reference in New Issue
Block a user