Compare commits

..

20 Commits

Author SHA1 Message Date
ef6dacd0bc service with accrual interest and loan update tested
All checks were successful
Backend CI / unit-test (push) Successful in 48s
2025-10-08 12:34:20 +02:00
0ca660f268 wip loan update
Some checks failed
Backend CI / unit-test (push) Failing after 44s
2025-10-03 11:55:30 +02:00
5e7d801075 Merge pull request 'feature/api_endpoint' (#5) from feature/api_endpoint into main
All checks were successful
Backend CI / unit-test (push) Successful in 43s
Reviewed-on: #5
2025-10-01 15:55:47 +02:00
94fb4705ff add tests for router and openapi, still need to add routes for update interest
All checks were successful
Backend CI / unit-test (push) Successful in 1m10s
Backend CI / unit-test (pull_request) Successful in 44s
2025-10-01 15:53:48 +02:00
bb87b90285 service layer add all tests for existing code
All checks were successful
Backend CI / unit-test (push) Successful in 40s
2025-09-29 16:48:28 +02:00
5eae75b23e wip service test
Some checks failed
Backend CI / unit-test (push) Failing after 37s
2025-09-26 22:37:26 +02:00
6a5f160d83 add interest accural test, improve migration tests
All checks were successful
Backend CI / unit-test (push) Successful in 37s
2025-09-25 22:16:24 +02:00
27b4adaca4 add interest change tables 2025-09-25 12:08:07 +02:00
e66aab99ea basic api is there
All checks were successful
Backend CI / unit-test (push) Successful in 35s
2025-09-24 21:02:21 +02:00
80fc405bf6 Almost finish basic functionalities
All checks were successful
Backend CI / unit-test (push) Successful in 36s
2025-09-24 17:33:27 +02:00
cf6c826468 use utils module 2025-09-24 10:44:32 +02:00
a6592bd140 wip 2025-09-23 23:35:15 +02:00
92c4e0d4fc refine type checking
All checks were successful
Backend CI / unit-test (push) Successful in 35s
2025-09-23 17:37:14 +02:00
544f5e8c92 Merge pull request 'add readme' (#4) from feature/readme into main
All checks were successful
Backend CI / unit-test (push) Successful in 34s
Reviewed-on: #4
2025-09-23 10:44:11 +02:00
b6ba108156 add readme
All checks were successful
Backend CI / unit-test (push) Successful in 33s
Backend CI / unit-test (pull_request) Successful in 33s
2025-09-23 10:43:32 +02:00
b68249f9f1 add create get exchange endpoint
All checks were successful
Backend CI / unit-test (push) Successful in 34s
2025-09-22 23:07:28 +02:00
1750401278 several changes:
All checks were successful
Backend CI / unit-test (push) Successful in 34s
* api calls for auth

* exchange now bind to user
2025-09-22 22:51:59 +02:00
466e6ce653 wip user reg
All checks were successful
Backend CI / unit-test (push) Successful in 34s
2025-09-22 17:35:10 +02:00
e70a63e4f9 add security py 2025-09-22 14:54:29 +02:00
76ed38e9af add crud for exchange 2025-09-22 14:39:33 +02:00
29 changed files with 4649 additions and 217 deletions

19
LICENSE Normal file
View File

@@ -0,0 +1,19 @@
Copyright (c) 2025 Tianyu Liu, Studio TJ
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE
OR OTHER DEALINGS IN THE SOFTWARE.

28
README.md Normal file
View File

@@ -0,0 +1,28 @@
# Trading Journal (Work In Progress)
A simple trading journal application (work in progress).
This repository contains the backend of a trading journal designed to help you record and analyse trades. The system is specially designed to support journaling trades for the "wheel" options strategy, but it also supports other trade types such as long/short spot positions, forex, and more.
Important: the project is still under active development. There is a backend in this repo, but the frontend UI has not been implemented yet.
## Key features
- Journal trades with rich metadata (strategy, entry/exit, P/L, notes).
- Built-in support and data model conveniences for the Wheel strategy (puts/calls lifecycle tracking).
- Flexible support for other trade types: long/short spots, forex, futures, etc.
- Backend-first design with tests and migration helpers.
## Repository layout
- `backend/` — Python backend code (API, models, services, migrations, tests).
- `backend/trading_journal/` — core application modules: CRUD, models, DTOs, services, and security.
- `backend/tests/` — unit tests targeting the backend logic and DB layer.
## License
See the `LICENSE` file in the project root for license details.

2
backend/.gitignore vendored
View File

@@ -15,3 +15,5 @@ __pycache__/
*.db
*.db-shm
*.db-wal
devsettings.yaml

View File

@@ -13,10 +13,14 @@
"app:app",
"--host=0.0.0.0",
"--reload",
"--port=5000"
"--port=18881"
],
"jinja": true,
"autoStartBrowser": true
"autoStartBrowser": false,
"env": {
"CONFIG_FILE": "devsettings.yaml"
},
"console": "integratedTerminal"
}
]
}

View File

@@ -11,5 +11,6 @@
"tests"
],
"python.testing.unittestEnabled": false,
"python.testing.pytestEnabled": true
"python.testing.pytestEnabled": true,
"python.analysis.typeCheckingMode": "standard",
}

View File

@@ -1,17 +1,47 @@
import asyncio
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
from __future__ import annotations
from fastapi import FastAPI, status
import asyncio
import logging
from contextlib import asynccontextmanager
from datetime import datetime, timezone
from typing import TYPE_CHECKING
from fastapi import FastAPI, HTTPException, Request, status
from fastapi.encoders import jsonable_encoder
from fastapi.responses import JSONResponse, Response
import settings
from trading_journal import db
from trading_journal.dto import TradeCreate, TradeRead
from trading_journal import db, service
from trading_journal.dto import (
CycleBase,
CycleRead,
CycleUpdate,
ExchangesBase,
ExchangesRead,
SessionsBase,
SessionsCreate,
TradeCreate,
TradeFriendlyNameUpdate,
TradeNoteUpdate,
TradeRead,
UserCreate,
UserLogin,
UserRead,
)
API_BASE = "/api/v1"
if TYPE_CHECKING:
from collections.abc import AsyncGenerator
from trading_journal.db import Database
_db = db.create_database(settings.settings.database_url)
logging.basicConfig(
level=logging.WARNING,
format="%(asctime)s %(levelname)s %(name)s: %(message)s",
)
logger = logging.getLogger(__name__)
@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: # noqa: ARG001
@@ -22,9 +52,273 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: # noqa: ARG001
await asyncio.to_thread(_db.dispose)
origins = [
"http://127.0.0.1:18881",
]
app = FastAPI(lifespan=lifespan)
app.add_middleware(
service.AuthMiddleWare,
)
app.state.db_factory = _db
@app.get(f"{API_BASE}/status")
@app.get(f"{settings.settings.api_base}/status")
async def get_status() -> dict[str, str]:
return {"status": "ok"}
@app.post(f"{settings.settings.api_base}/register")
async def register_user(request: Request, user_in: UserCreate) -> Response:
db_factory: Database = request.app.state.db_factory
def sync_work() -> UserRead:
with db_factory.get_session_ctx_manager() as db:
return service.register_user_service(db, user_in)
try:
user = await asyncio.to_thread(sync_work)
return JSONResponse(status_code=status.HTTP_201_CREATED, content=user.model_dump())
except service.UserAlreadyExistsError as e:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) from e
except Exception as e:
logger.exception("Failed to register user: \n")
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Internal Server Error") from e
@app.post(f"{settings.settings.api_base}/login")
async def login(request: Request, user_in: UserLogin) -> Response:
db_factory: Database = request.app.state.db_factory
def sync_work() -> tuple[SessionsCreate, str] | None:
with db_factory.get_session_ctx_manager() as db:
return service.authenticate_user_service(db, user_in)
try:
result = await asyncio.to_thread(sync_work)
if result is None:
return JSONResponse(
status_code=status.HTTP_401_UNAUTHORIZED,
content={"detail": "Invalid username or password, or user doesn't exist"},
)
session, token = result
session_return = SessionsBase(user_id=session.user_id)
response = JSONResponse(status_code=status.HTTP_200_OK, content=session_return.model_dump())
expires_sec = int((session.expires_at.replace(tzinfo=timezone.utc) - datetime.now(timezone.utc)).total_seconds())
response.set_cookie(
key="session_token",
value=token,
httponly=True,
secure=True,
samesite="lax",
max_age=expires_sec,
path="/",
)
except Exception as e:
logger.exception("Failed to login user: \n")
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Internal Server Error") from e
else:
return response
# Exchange
@app.post(f"{settings.settings.api_base}/exchanges")
async def create_exchange(request: Request, exchange_data: ExchangesBase) -> Response:
db_factory: Database = request.app.state.db_factory
def sync_work() -> ExchangesBase:
with db_factory.get_session_ctx_manager() as db:
return service.create_exchange_service(db, request.state.user_id, exchange_data.name, exchange_data.notes)
try:
exchange = await asyncio.to_thread(sync_work)
return JSONResponse(status_code=status.HTTP_201_CREATED, content=exchange.model_dump())
except service.ExchangeAlreadyExistsError as e:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) from e
except Exception as e:
logger.exception("Failed to create exchange: \n")
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Internal server error") from e
@app.get(f"{settings.settings.api_base}/exchanges")
async def get_exchanges(request: Request) -> list[ExchangesRead]:
db_factory: Database = request.app.state.db_factory
def sync_work() -> list[ExchangesRead]:
with db_factory.get_session_ctx_manager() as db:
return service.get_exchanges_by_user_service(db, request.state.user_id)
try:
return await asyncio.to_thread(sync_work)
except Exception as e:
logger.exception("Failed to get exchanges: \n")
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Internal server error") from e
@app.patch(f"{settings.settings.api_base}/exchanges/{{exchange_id}}")
async def update_exchange(request: Request, exchange_id: int, exchange_data: ExchangesBase) -> Response:
db_factory: Database = request.app.state.db_factory
def sync_work() -> ExchangesBase:
with db_factory.get_session_ctx_manager() as db:
return service.update_exchanges_service(db, request.state.user_id, exchange_id, exchange_data.name, exchange_data.notes)
try:
exchange = await asyncio.to_thread(sync_work)
return JSONResponse(status_code=status.HTTP_200_OK, content=exchange.model_dump())
except service.ExchangeNotFoundError as e:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e)) from e
except service.ExchangeAlreadyExistsError as e:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) from e
except Exception as e:
logger.exception("Failed to update exchange: \n")
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Internal server error") from e
# Cycle
@app.post(f"{settings.settings.api_base}/cycles")
async def create_cycle(request: Request, cycle_data: CycleBase) -> Response:
return JSONResponse(status_code=status.HTTP_405_METHOD_NOT_ALLOWED, content="Not supported.")
db_factory: Database = request.app.state.db_factory
def sync_work() -> CycleBase:
with db_factory.get_session_ctx_manager() as db:
return service.create_cycle_service(db, request.state.user_id, cycle_data)
try:
cycle = await asyncio.to_thread(sync_work)
return JSONResponse(status_code=status.HTTP_201_CREATED, content=jsonable_encoder(cycle))
except Exception as e:
logger.exception("Failed to create cycle: \n")
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Internal server error") from e
@app.get(f"{settings.settings.api_base}/cycles/{{cycle_id}}")
async def get_cycle_by_id(request: Request, cycle_id: int) -> Response:
db_factory: Database = request.app.state.db_factory
def sync_work() -> CycleBase:
with db_factory.get_session_ctx_manager() as db:
return service.get_cycle_by_id_service(db, request.state.user_id, cycle_id)
try:
cycle = await asyncio.to_thread(sync_work)
return JSONResponse(status_code=status.HTTP_200_OK, content=jsonable_encoder(cycle))
except service.CycleNotFoundError as e:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e)) from e
except Exception as e:
logger.exception("Failed to get cycle by id: \n")
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Internal server error") from e
@app.get(f"{settings.settings.api_base}/cycles/user/{{user_id}}")
async def get_cycles_by_user(request: Request, user_id: int) -> Response:
db_factory: Database = request.app.state.db_factory
def sync_work() -> list[CycleRead]:
with db_factory.get_session_ctx_manager() as db:
return service.get_cycles_by_user_service(db, user_id)
try:
cycles = await asyncio.to_thread(sync_work)
return JSONResponse(status_code=status.HTTP_200_OK, content=jsonable_encoder(cycles))
except Exception as e:
logger.exception("Failed to get cycles by user: \n")
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Internal server error") from e
@app.patch(f"{settings.settings.api_base}/cycles")
async def update_cycle(request: Request, cycle_data: CycleUpdate) -> Response:
db_factory: Database = request.app.state.db_factory
def sync_work() -> CycleRead:
with db_factory.get_session_ctx_manager() as db:
return service.update_cycle_service(db, request.state.user_id, cycle_data)
try:
cycle = await asyncio.to_thread(sync_work)
return JSONResponse(status_code=status.HTTP_200_OK, content=jsonable_encoder(cycle))
except service.InvalidCycleDataError as e:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) from e
except service.CycleNotFoundError as e:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e)) from e
except Exception as e:
logger.exception("Failed to update cycle: \n")
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Internal server error") from e
@app.post(f"{settings.settings.api_base}/trades")
async def create_trade(request: Request, trade_data: TradeCreate) -> Response:
db_factory: Database = request.app.state.db_factory
def sync_work() -> TradeRead:
with db_factory.get_session_ctx_manager() as db:
return service.create_trade_service(db, request.state.user_id, trade_data)
try:
trade = await asyncio.to_thread(sync_work)
return JSONResponse(status_code=status.HTTP_201_CREATED, content=jsonable_encoder(trade))
except service.InvalidTradeDataError as e:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) from e
except Exception as e:
logger.exception("Failed to create trade: \n")
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Internal server error") from e
@app.get(f"{settings.settings.api_base}/trades/{{trade_id}}")
async def get_trade_by_id(request: Request, trade_id: int) -> Response:
db_factory: Database = request.app.state.db_factory
def sync_work() -> TradeRead:
with db_factory.get_session_ctx_manager() as db:
return service.get_trade_by_id_service(db, request.state.user_id, trade_id)
try:
trade = await asyncio.to_thread(sync_work)
return JSONResponse(status_code=status.HTTP_200_OK, content=jsonable_encoder(trade))
except service.TradeNotFoundError as e:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e)) from e
except Exception as e:
logger.exception("Failed to get trade by id: \n")
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Internal server error") from e
@app.patch(f"{settings.settings.api_base}/trades/friendlyname")
async def update_trade_friendly_name(request: Request, friendly_name_update: TradeFriendlyNameUpdate) -> Response:
db_factory: Database = request.app.state.db_factory
def sync_work() -> TradeRead:
with db_factory.get_session_ctx_manager() as db:
return service.update_trade_friendly_name_service(
db,
request.state.user_id,
friendly_name_update.id,
friendly_name_update.friendly_name,
)
try:
trade = await asyncio.to_thread(sync_work)
return JSONResponse(status_code=status.HTTP_200_OK, content=jsonable_encoder(trade))
except service.TradeNotFoundError as e:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e)) from e
except Exception as e:
logger.exception("Failed to update trade friendly name: \n")
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Internal server error") from e
@app.patch(f"{settings.settings.api_base}/trades/notes")
async def update_trade_note(request: Request, note_update: TradeNoteUpdate) -> Response:
db_factory: Database = request.app.state.db_factory
def sync_work() -> TradeRead:
with db_factory.get_session_ctx_manager() as db:
return service.update_trade_note_service(db, request.state.user_id, note_update.id, note_update.notes)
try:
trade = await asyncio.to_thread(sync_work)
return JSONResponse(status_code=status.HTTP_200_OK, content=jsonable_encoder(trade))
except service.TradeNotFoundError as e:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e)) from e
except Exception as e:
logger.exception("Failed to update trade note: \n")
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Internal server error") from e

View File

@@ -17,7 +17,7 @@ anyio==4.10.0 \
argon2-cffi==25.1.0 \
--hash=sha256:694ae5cc8a42f4c4e2bf2ca0e64e51e23a040c6a517a85074683d3959e1346c1 \
--hash=sha256:fdc8b074db390fccb6eb4a3604ae7231f219aa669a2652e0f20e16ba513d5741
# via passlib
# via -r requirements.in
argon2-cffi-bindings==25.1.0 \
--hash=sha256:1db89609c06afa1a214a69a462ea741cf735b29a57530478c06eb81dd403de99 \
--hash=sha256:1e021e87faa76ae0d413b619fe2b65ab9a037f24c60a1e6cc43457ae20de6dc6 \
@@ -230,10 +230,6 @@ packaging==25.0 \
--hash=sha256:29572ef2b1f17581046b3a2227d5c611fb25ec70ca1ba8554b24b0e69331a484 \
--hash=sha256:d443872c98d677bf60f6a1f2f8c1cb748e8fe762d2bf9d3148b5599295b0fc4f
# via pytest
passlib[argon2]==1.7.4 \
--hash=sha256:aa6bca462b8d8bda89c70b382f0c298a20b5560af6cbfa2dce410c0a2fb669f1 \
--hash=sha256:defd50f72b65c5402ab2c573830a6978e5f202ad0d984793c8dde2c4152ebe04
# via -r requirements.in
pluggy==1.6.0 \
--hash=sha256:7dcc130b76258d33b90f61b658791dede3486c3e6bfb003ee5c9bfb396dd22f3 \
--hash=sha256:e920276dd6813095e9377c0bc5566d94c932c33b27a3e3945d8389c374dd4746

554
backend/openapi.yaml Normal file
View File

@@ -0,0 +1,554 @@
openapi: "3.0.3"
info:
title: Trading Journal API
version: "1.0.0"
description: OpenAPI description generated from [`app.py`](app.py) and DTOs in [`trading_journal/dto.py`](trading_journal/dto.py).
servers:
- url: "http://127.0.0.1:18881{basePath}"
variables:
basePath:
default: "/api/v1"
description: "API base path (matches settings.settings.api_base)"
components:
securitySchemes:
session_cookie:
type: apiKey
in: cookie
name: session_token
schemas:
UserCreate:
$ref: "#/components/schemas/UserCreate_impl"
UserCreate_impl:
type: object
required:
- username
- password
properties:
username:
type: string
is_active:
type: boolean
default: true
password:
type: string
UserLogin:
type: object
required:
- username
- password
properties:
username:
type: string
password:
type: string
UserRead:
type: object
required:
- id
- username
properties:
id:
type: integer
username:
type: string
is_active:
type: boolean
SessionsBase:
type: object
required:
- user_id
properties:
user_id:
type: integer
SessionsCreate:
allOf:
- $ref: "#/components/schemas/SessionsBase"
- type: object
required:
- expires_at
properties:
expires_at:
type: string
format: date-time
ExchangesBase:
type: object
required:
- name
properties:
name:
type: string
notes:
type: string
nullable: true
ExchangesRead:
allOf:
- $ref: "#/components/schemas/ExchangesBase"
- type: object
required:
- id
properties:
id:
type: integer
CycleBase:
type: object
properties:
friendly_name:
type: string
nullable: true
status:
type: string
end_date:
type: string
format: date
nullable: true
funding_source:
type: string
nullable: true
capital_exposure_cents:
type: integer
nullable: true
loan_amount_cents:
type: integer
nullable: true
loan_interest_rate_tenth_bps:
type: integer
nullable: true
trades:
type: array
items:
$ref: "#/components/schemas/TradeRead"
nullable: true
exchange:
$ref: "#/components/schemas/ExchangesRead"
nullable: true
CycleCreate:
allOf:
- $ref: "#/components/schemas/CycleBase"
- type: object
required:
- user_id
- symbol
- exchange_id
- underlying_currency
- start_date
properties:
user_id:
type: integer
symbol:
type: string
exchange_id:
type: integer
underlying_currency:
type: string
start_date:
type: string
format: date
CycleUpdate:
allOf:
- $ref: "#/components/schemas/CycleBase"
- type: object
required:
- id
properties:
id:
type: integer
CycleRead:
allOf:
- $ref: "#/components/schemas/CycleCreate"
- type: object
required:
- id
properties:
id:
type: integer
TradeBase:
type: object
required:
- symbol
- underlying_currency
- trade_type
- trade_strategy
- trade_date
- quantity
- price_cents
- commission_cents
properties:
friendly_name:
type: string
nullable: true
symbol:
type: string
exchange_id:
type: integer
underlying_currency:
type: string
trade_type:
type: string
trade_strategy:
type: string
trade_date:
type: string
format: date
quantity:
type: integer
price_cents:
type: integer
commission_cents:
type: integer
notes:
type: string
nullable: true
cycle_id:
type: integer
nullable: true
TradeCreate:
allOf:
- $ref: "#/components/schemas/TradeBase"
- type: object
properties:
user_id:
type: integer
nullable: true
trade_time_utc:
type: string
format: date-time
nullable: true
gross_cash_flow_cents:
type: integer
nullable: true
net_cash_flow_cents:
type: integer
nullable: true
quantity_multiplier:
type: integer
default: 1
expiry_date:
type: string
format: date
nullable: true
strike_price_cents:
type: integer
nullable: true
is_invalidated:
type: boolean
default: false
invalidated_at:
type: string
format: date-time
nullable: true
replaced_by_trade_id:
type: integer
nullable: true
TradeNoteUpdate:
type: object
required:
- id
properties:
id:
type: integer
notes:
type: string
nullable: true
TradeFriendlyNameUpdate:
type: object
required:
- id
- friendly_name
properties:
id:
type: integer
friendly_name:
type: string
TradeRead:
allOf:
- $ref: "#/components/schemas/TradeCreate"
- type: object
required:
- id
properties:
id:
type: integer
paths:
/status:
get:
summary: "Get API status"
security: [] # no auth required
responses:
"200":
description: OK
content:
application/json:
schema:
type: object
properties:
status:
type: string
/register:
post:
summary: "Register user"
security: [] # no auth required
requestBody:
required: true
content:
application/json:
schema:
$ref: "#/components/schemas/UserCreate"
responses:
"201":
description: Created
content:
application/json:
schema:
$ref: "#/components/schemas/UserRead"
"400":
description: Bad Request (user exists)
"500":
description: Internal Server Error
/login:
post:
summary: "Login"
security: [] # no auth required
requestBody:
required: true
content:
application/json:
schema:
$ref: "#/components/schemas/UserLogin"
responses:
"200":
description: OK (sets session cookie)
content:
application/json:
schema:
$ref: "#/components/schemas/SessionsBase"
headers:
Set-Cookie:
description: session cookie
schema:
type: string
"401":
description: Unauthorized
"500":
description: Internal Server Error
/exchanges:
post:
summary: "Create exchange"
security:
- session_cookie: []
requestBody:
required: true
content:
application/json:
schema:
$ref: "#/components/schemas/ExchangesBase"
responses:
"201":
description: Created
content:
application/json:
schema:
$ref: "#/components/schemas/ExchangesRead"
"400":
description: Bad Request
"401":
description: Unauthorized
get:
summary: "List user exchanges"
security:
- session_cookie: []
responses:
"200":
description: OK
content:
application/json:
schema:
type: array
items:
$ref: "#/components/schemas/ExchangesRead"
"401":
description: Unauthorized
/exchanges/{exchange_id}:
patch:
summary: "Update exchange"
security:
- session_cookie: []
parameters:
- name: exchange_id
in: path
required: true
schema:
type: integer
requestBody:
required: true
content:
application/json:
schema:
$ref: "#/components/schemas/ExchangesBase"
responses:
"200":
description: Updated
content:
application/json:
schema:
$ref: "#/components/schemas/ExchangesRead"
"404":
description: Not found
"400":
description: Bad request
/cycles:
post:
summary: "Create cycle (currently returns 405 in code)"
security:
- session_cookie: []
requestBody:
required: true
content:
application/json:
schema:
$ref: "#/components/schemas/CycleBase"
responses:
"405":
description: Method not allowed (app currently returns 405)
patch:
summary: "Update cycle"
security:
- session_cookie: []
requestBody:
required: true
content:
application/json:
schema:
$ref: "#/components/schemas/CycleUpdate"
responses:
"200":
description: Updated
content:
application/json:
schema:
$ref: "#/components/schemas/CycleRead"
"400":
description: Invalid data
"404":
description: Not found
/cycles/{cycle_id}:
get:
summary: "Get cycle by id"
security:
- session_cookie: []
parameters:
- name: cycle_id
in: path
required: true
schema:
type: integer
responses:
"200":
description: OK
content:
application/json:
schema:
$ref: "#/components/schemas/CycleRead"
"404":
description: Not found
/cycles/user/{user_id}:
get:
summary: "Get cycles by user id"
security:
- session_cookie: []
parameters:
- name: user_id
in: path
required: true
schema:
type: integer
responses:
"200":
description: OK
content:
application/json:
schema:
type: array
items:
$ref: "#/components/schemas/CycleRead"
/trades:
post:
summary: "Create trade"
security:
- session_cookie: []
requestBody:
required: true
content:
application/json:
schema:
$ref: "#/components/schemas/TradeCreate"
responses:
"201":
description: Created
content:
application/json:
schema:
$ref: "#/components/schemas/TradeRead"
"400":
description: Invalid trade data
"500":
description: Internal Server Error
/trades/{trade_id}:
get:
summary: "Get trade by id"
security:
- session_cookie: []
parameters:
- name: trade_id
in: path
required: true
schema:
type: integer
responses:
"200":
description: OK
content:
application/json:
schema:
$ref: "#/components/schemas/TradeRead"
"404":
description: Not found
/trades/friendlyname:
patch:
summary: "Update trade friendly name"
security:
- session_cookie: []
requestBody:
required: true
content:
application/json:
schema:
$ref: "#/components/schemas/TradeFriendlyNameUpdate"
responses:
"200":
description: Updated
content:
application/json:
schema:
$ref: "#/components/schemas/TradeRead"
"404":
description: Not found
/trades/notes:
patch:
summary: "Update trade notes"
security:
- session_cookie: []
requestBody:
required: true
content:
application/json:
schema:
$ref: "#/components/schemas/TradeNoteUpdate"
responses:
"200":
description: Updated
content:
application/json:
schema:
$ref: "#/components/schemas/TradeRead"
"404":
description: Not found

View File

@@ -4,4 +4,4 @@ httpx
pyyaml
pydantic-settings
sqlmodel
passlib[argon2]
argon2-cffi

View File

@@ -17,7 +17,7 @@ anyio==4.10.0 \
argon2-cffi==25.1.0 \
--hash=sha256:694ae5cc8a42f4c4e2bf2ca0e64e51e23a040c6a517a85074683d3959e1346c1 \
--hash=sha256:fdc8b074db390fccb6eb4a3604ae7231f219aa669a2652e0f20e16ba513d5741
# via passlib
# via -r requirements.in
argon2-cffi-bindings==25.1.0 \
--hash=sha256:1db89609c06afa1a214a69a462ea741cf735b29a57530478c06eb81dd403de99 \
--hash=sha256:1e021e87faa76ae0d413b619fe2b65ab9a037f24c60a1e6cc43457ae20de6dc6 \
@@ -222,10 +222,6 @@ idna==3.10 \
# via
# anyio
# httpx
passlib[argon2]==1.7.4 \
--hash=sha256:aa6bca462b8d8bda89c70b382f0c298a20b5560af6cbfa2dce410c0a2fb669f1 \
--hash=sha256:defd50f72b65c5402ab2c573830a6978e5f202ad0d984793c8dde2c4152ebe04
# via -r requirements.in
pycparser==2.23 \
--hash=sha256:78816d4f24add8f10a06d6f05b4d424ad9e96cfebf68a4ddc99c65c0720d00c2 \
--hash=sha256:e5c6e8d3fbad53479cab09ac03729e0a9faf2bee3db8208a550daf5af81a5934

View File

@@ -24,3 +24,4 @@ ignore = [
[lint.extend-per-file-ignores]
"test*.py" = ["S101", "S105", "S106", "PT011", "PLR2004"]
"models*.py" = ["FA102"]
"dto.py" = ["TC001", "TC003"]

View File

@@ -1,3 +1,5 @@
from __future__ import annotations
import os
from pathlib import Path
from typing import Any
@@ -13,6 +15,9 @@ class Settings(BaseSettings):
workers: int = 1
log_level: str = "info"
database_url: str = "sqlite:///:memory:"
api_base: str = "/api/v1"
session_expiry_seconds: int = 3600 * 24 * 7 # 7 days
hmac_key: str | None = None
model_config = ConfigDict(env_file=".env", env_file_encoding="utf-8")

View File

@@ -0,0 +1,56 @@
curl --location '127.0.0.1:18881/api/v1/trades' \
--header 'Content-Type: application/json' \
--header 'Cookie: session_token=uYsEZZdH9ecQ432HQUdfab292I14suk4GuI12-cAyuw' \
--data '{
"friendly_name": "20250908-CA-PUT",
"symbol": "CA",
"exchange_id": 1,
"underlying_currency": "EUR",
"trade_type": "SELL_PUT",
"trade_strategy": "WHEEL",
"trade_date": "2025-09-08",
"quantity": 1,
"quantity_multiplier": 100,
"price_cents": 17,
"expiry_date": "2025-09-09",
"strike_price_cents": 1220,
"commission_cents": 114
}'
curl --location '127.0.0.1:18881/api/v1/trades' \
--header 'Content-Type: application/json' \
--header 'Cookie: session_token=uYsEZZdH9ecQ432HQUdfab292I14suk4GuI12-cAyuw' \
--data '{
"friendly_name": "20250920-CA-ASSIGN",
"symbol": "CA",
"exchange_id": 1,
"cycle_id": 1,
"underlying_currency": "EUR",
"trade_type": "ASSIGNMENT",
"trade_strategy": "WHEEL",
"trade_date": "2025-09-20",
"quantity": 100,
"quantity_multiplier": 1,
"price_cents": 1220,
"commission_cents": 0
}'
curl --location '127.0.0.1:18881/api/v1/trades' \
--header 'Content-Type: application/json' \
--header 'Cookie: session_token=uYsEZZdH9ecQ432HQUdfab292I14suk4GuI12-cAyuw' \
--data '{
"friendly_name": "20250923-CA-CALL",
"symbol": "CA",
"exchange_id": 1,
"cycle_id": 1,
"underlying_currency": "EUR",
"trade_type": "SELL_CALL",
"trade_strategy": "WHEEL",
"trade_date": "2025-09-23",
"quantity": 1,
"quantity_multiplier": 100,
"price_cents": 31,
"expiry_date": "2025-10-10",
"strike_price_cents": 1200,
"commission_cents": 114
}'

View File

@@ -1,18 +1,405 @@
from collections.abc import Generator
from collections.abc import Callable
from datetime import datetime, timedelta, timezone
from types import SimpleNamespace
from unittest.mock import MagicMock
import pytest
from fastapi import FastAPI, status
from fastapi.responses import JSONResponse
from fastapi.testclient import TestClient
from app import API_BASE, app
import settings
import trading_journal.service as svc
@pytest.fixture
def client() -> Generator[TestClient, None, None]:
with TestClient(app) as client:
yield client
def client_factory(monkeypatch: pytest.MonkeyPatch) -> Callable[..., TestClient]:
class NoAuth:
def __init__(self, app: FastAPI, **opts) -> None: # noqa: ANN003, ARG002
self.app = app
async def __call__(self, scope, receive, send) -> None: # noqa: ANN001
state = scope.get("state")
if state is None:
scope["state"] = SimpleNamespace()
scope["state"]["user_id"] = 1
await self.app(scope, receive, send)
class DeclineAuth:
def __init__(self, app: FastAPI, **opts) -> None: # noqa: ANN003, ARG002
self.app = app
async def __call__(self, scope, receive, send) -> None: # noqa: ANN001
if scope.get("type") != "http":
await self.app(scope, receive, send)
return
path = scope.get("path", "")
# allow public/exempt paths through
if getattr(svc, "EXCEPT_PATHS", []) and path in svc.EXCEPT_PATHS:
await self.app(scope, receive, send)
return
# immediately respond 401 for protected paths
resp = JSONResponse({"detail": "Unauthorized"}, status_code=status.HTTP_401_UNAUTHORIZED)
await resp(scope, receive, send)
def _factory(*, decline_auth: bool = False, **mocks: dict) -> TestClient:
defaults = {
"register_user_service": MagicMock(return_value=SimpleNamespace(model_dump=lambda: {"id": 1, "username": "mock"})),
"authenticate_user_service": MagicMock(
return_value=(SimpleNamespace(user_id=1, expires_at=(datetime.now(timezone.utc) + timedelta(hours=1))), "token"),
),
"create_exchange_service": MagicMock(
return_value=SimpleNamespace(model_dump=lambda: {"name": "Binance", "notes": "some note", "user_id": 1}),
),
"get_exchanges_by_user_service": MagicMock(return_value=[]),
}
if decline_auth:
monkeypatch.setattr(svc, "AuthMiddleWare", DeclineAuth)
else:
monkeypatch.setattr(svc, "AuthMiddleWare", NoAuth)
merged = {**defaults, **mocks}
for name, mock in merged.items():
monkeypatch.setattr(svc, name, mock)
import sys
if "app" in sys.modules:
del sys.modules["app"]
from importlib import import_module
app = import_module("app").app # re-import app module
return TestClient(app)
return _factory
def test_get_status(client: TestClient) -> None:
response = client.get(f"{API_BASE}/status")
def test_get_status(client_factory: Callable[..., TestClient]) -> None:
client = client_factory()
with client as c:
response = c.get(f"{settings.settings.api_base}/status")
assert response.status_code == 200
assert response.json() == {"status": "ok"}
def test_register_success(client_factory: Callable[..., TestClient]) -> None:
client = client_factory() # use defaults
with client as c:
r = c.post(f"{settings.settings.api_base}/register", json={"username": "a", "password": "b"})
assert r.status_code == 201
def test_register_user_already_exists(client_factory: Callable[..., TestClient]) -> None:
client = client_factory(register_user_service=MagicMock(side_effect=svc.UserAlreadyExistsError("username already exists")))
with client as c:
r = c.post(f"{settings.settings.api_base}/register", json={"username": "a", "password": "b"})
assert r.status_code == status.HTTP_400_BAD_REQUEST
assert r.json() == {"detail": "username already exists"}
def test_register_user_internal_server_error(client_factory: Callable[..., TestClient]) -> None:
client = client_factory(register_user_service=MagicMock(side_effect=Exception("db is down")))
with client as c:
r = c.post(f"{settings.settings.api_base}/register", json={"username": "a", "password": "b"})
assert r.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
assert r.json() == {"detail": "Internal Server Error"}
def test_login_success(client_factory: Callable[..., TestClient]) -> None:
client = client_factory() # use defaults
with client as c:
r = c.post(f"{settings.settings.api_base}/login", json={"username": "a", "password": "b"})
assert r.status_code == 200
assert r.json() == {"user_id": 1}
assert r.cookies.get("session_token") == "token"
def test_login_failed_auth(client_factory: Callable[..., TestClient]) -> None:
client = client_factory(authenticate_user_service=MagicMock(return_value=None))
with client as c:
r = c.post(f"{settings.settings.api_base}/login", json={"username": "a", "password": "b"})
assert r.status_code == status.HTTP_401_UNAUTHORIZED
assert r.json() == {"detail": "Invalid username or password, or user doesn't exist"}
def test_login_internal_server_error(client_factory: Callable[..., TestClient]) -> None:
client = client_factory(authenticate_user_service=MagicMock(side_effect=Exception("db is down")))
with client as c:
r = c.post(f"{settings.settings.api_base}/login", json={"username": "a", "password": "b"})
assert r.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
assert r.json() == {"detail": "Internal Server Error"}
def test_create_exchange_success(client_factory: Callable[..., TestClient]) -> None:
client = client_factory()
with client as c:
r = c.post(f"{settings.settings.api_base}/exchanges", json={"name": "Binance"})
assert r.status_code == 201
assert r.json() == {"user_id": 1, "name": "Binance", "notes": "some note"}
def test_create_exchange_already_exists(client_factory: Callable[..., TestClient]) -> None:
client = client_factory(create_exchange_service=MagicMock(side_effect=svc.ExchangeAlreadyExistsError("exchange already exists")))
with client as c:
r = c.post(f"{settings.settings.api_base}/exchanges", json={"name": "Binance"})
assert r.status_code == status.HTTP_400_BAD_REQUEST
assert r.json() == {"detail": "exchange already exists"}
def test_get_exchanges_unauthenticated(client_factory: Callable[..., TestClient]) -> None:
client = client_factory(decline_auth=True)
with client as c:
r = c.get(f"{settings.settings.api_base}/exchanges")
assert r.status_code == status.HTTP_401_UNAUTHORIZED
assert r.json() == {"detail": "Unauthorized"}
def test_get_exchanges_success(client_factory: Callable[..., TestClient]) -> None:
client = client_factory()
with client as c:
r = c.get(f"{settings.settings.api_base}/exchanges")
assert r.status_code == 200
assert r.json() == []
def test_update_exchanges_success(client_factory: Callable[..., TestClient]) -> None:
client = client_factory(
update_exchanges_service=MagicMock(
return_value=SimpleNamespace(model_dump=lambda: {"user_id": 1, "name": "BinanceUS", "notes": "updated note"}),
),
)
with client as c:
r = c.patch(f"{settings.settings.api_base}/exchanges/1", json={"name": "BinanceUS", "notes": "updated note"})
assert r.status_code == 200
assert r.json() == {"user_id": 1, "name": "BinanceUS", "notes": "updated note"}
def test_update_exchanges_not_found(client_factory: Callable[..., TestClient]) -> None:
client = client_factory(update_exchanges_service=MagicMock(side_effect=svc.ExchangeNotFoundError("exchange not found")))
with client as c:
r = c.patch(f"{settings.settings.api_base}/exchanges/999", json={"name": "NonExistent", "notes": "no note"})
assert r.status_code == status.HTTP_404_NOT_FOUND
assert r.json() == {"detail": "exchange not found"}
def test_get_cycles_by_id_success(client_factory: Callable[..., TestClient]) -> None:
client = client_factory(
get_cycle_by_id_service=MagicMock(
return_value=SimpleNamespace(
friendly_name="Cycle 1",
status="active",
id=1,
),
),
)
with client as c:
r = c.get(f"{settings.settings.api_base}/cycles/1")
assert r.status_code == 200
assert r.json() == {"id": 1, "friendly_name": "Cycle 1", "status": "active"}
def test_get_cycles_by_id_not_found(client_factory: Callable[..., TestClient]) -> None:
client = client_factory(get_cycle_by_id_service=MagicMock(side_effect=svc.CycleNotFoundError("cycle not found")))
with client as c:
r = c.get(f"{settings.settings.api_base}/cycles/999")
assert r.status_code == status.HTTP_404_NOT_FOUND
assert r.json() == {"detail": "cycle not found"}
def test_get_cycles_by_user_success(client_factory: Callable[..., TestClient]) -> None:
client = client_factory(
get_cycles_by_user_service=MagicMock(
return_value=[
SimpleNamespace(
friendly_name="Cycle 1",
status="active",
id=1,
),
SimpleNamespace(
friendly_name="Cycle 2",
status="completed",
id=2,
),
],
),
)
with client as c:
r = c.get(f"{settings.settings.api_base}/cycles/user/1")
assert r.status_code == 200
assert r.json() == [
{"id": 1, "friendly_name": "Cycle 1", "status": "active"},
{"id": 2, "friendly_name": "Cycle 2", "status": "completed"},
]
def test_update_cycles_success(client_factory: Callable[..., TestClient]) -> None:
client = client_factory(
update_cycle_service=MagicMock(
return_value=SimpleNamespace(
friendly_name="Updated Cycle",
status="completed",
id=1,
),
),
)
with client as c:
r = c.patch(f"{settings.settings.api_base}/cycles", json={"friendly_name": "Updated Cycle", "status": "completed", "id": 1})
assert r.status_code == 200
assert r.json() == {"id": 1, "friendly_name": "Updated Cycle", "status": "completed"}
def test_update_cycles_invalid_cycle_data(client_factory: Callable[..., TestClient]) -> None:
client = client_factory(
update_cycle_service=MagicMock(side_effect=svc.InvalidCycleDataError("invalid cycle data")),
)
with client as c:
r = c.patch(f"{settings.settings.api_base}/cycles", json={"friendly_name": "", "status": "unknown", "id": 1})
assert r.status_code == status.HTTP_400_BAD_REQUEST
assert r.json() == {"detail": "invalid cycle data"}
def test_update_cycles_not_found(client_factory: Callable[..., TestClient]) -> None:
client = client_factory(update_cycle_service=MagicMock(side_effect=svc.CycleNotFoundError("cycle not found")))
with client as c:
r = c.patch(f"{settings.settings.api_base}/cycles", json={"friendly_name": "NonExistent", "status": "active", "id": 999})
assert r.status_code == status.HTTP_404_NOT_FOUND
assert r.json() == {"detail": "cycle not found"}
def test_create_trade_success(client_factory: Callable[..., TestClient]) -> None:
client = client_factory(
create_trade_service=MagicMock(
return_value=SimpleNamespace(),
),
)
with client as c:
r = c.post(
f"{settings.settings.api_base}/trades",
json={
"cycle_id": 1,
"exchange_id": 1,
"symbol": "BTCUSD",
"underlying_currency": "USD",
"trade_type": "LONG_SPOT",
"trade_strategy": "FX",
"quantity": 1,
"price_cents": 15,
"commission_cents": 100,
"trade_date": "2025-10-01",
},
)
assert r.status_code == 201
def test_create_trade_invalid_trade_data(client_factory: Callable[..., TestClient]) -> None:
client = client_factory(
create_trade_service=MagicMock(side_effect=svc.InvalidTradeDataError("invalid trade data")),
)
with client as c:
r = c.post(
f"{settings.settings.api_base}/trades",
json={
"cycle_id": 1,
"exchange_id": 1,
"symbol": "BTCUSD",
"underlying_currency": "USD",
"trade_type": "LONG_SPOT",
"trade_strategy": "FX",
"quantity": 1,
"price_cents": 15,
"commission_cents": 100,
"trade_date": "2025-10-01",
},
)
assert r.status_code == status.HTTP_400_BAD_REQUEST
assert r.json() == {"detail": "invalid trade data"}
def test_get_trade_by_id_success(client_factory: Callable[..., TestClient]) -> None:
client = client_factory(
get_trade_by_id_service=MagicMock(
return_value=SimpleNamespace(
id=1,
cycle_id=1,
exchange_id=1,
symbol="BTCUSD",
underlying_currency="USD",
trade_type="LONG_SPOT",
trade_strategy="FX",
quantity=1,
price_cents=1500,
commission_cents=100,
trade_date=datetime(2025, 10, 1, tzinfo=timezone.utc),
),
),
)
with client as c:
r = c.get(f"{settings.settings.api_base}/trades/1")
assert r.status_code == 200
assert r.json() == {
"id": 1,
"cycle_id": 1,
"exchange_id": 1,
"symbol": "BTCUSD",
"underlying_currency": "USD",
"trade_type": "LONG_SPOT",
"trade_strategy": "FX",
"quantity": 1,
"price_cents": 1500,
"commission_cents": 100,
"trade_date": "2025-10-01T00:00:00+00:00",
}
def test_get_trade_by_id_not_found(client_factory: Callable[..., TestClient]) -> None:
client = client_factory(get_trade_by_id_service=MagicMock(side_effect=svc.TradeNotFoundError("trade not found")))
with client as c:
r = c.get(f"{settings.settings.api_base}/trades/999")
assert r.status_code == status.HTTP_404_NOT_FOUND
assert r.json() == {"detail": "trade not found"}
def test_update_trade_friendly_name_success(client_factory: Callable[..., TestClient]) -> None:
client = client_factory(
update_trade_friendly_name_service=MagicMock(
return_value=SimpleNamespace(
id=1,
friendly_name="Updated Trade Name",
),
),
)
with client as c:
r = c.patch(f"{settings.settings.api_base}/trades/friendlyname", json={"id": 1, "friendly_name": "Updated Trade Name"})
assert r.status_code == 200
assert r.json() == {"id": 1, "friendly_name": "Updated Trade Name"}
def test_update_trade_friendly_name_not_found(client_factory: Callable[..., TestClient]) -> None:
client = client_factory(update_trade_friendly_name_service=MagicMock(side_effect=svc.TradeNotFoundError("trade not found")))
with client as c:
r = c.patch(f"{settings.settings.api_base}/trades/friendlyname", json={"id": 999, "friendly_name": "NonExistent Trade"})
assert r.status_code == status.HTTP_404_NOT_FOUND
assert r.json() == {"detail": "trade not found"}
def test_update_trade_note_success(client_factory: Callable[..., TestClient]) -> None:
client = client_factory(
update_trade_note_service=MagicMock(
return_value=SimpleNamespace(
id=1,
note="Updated trade note",
),
),
)
with client as c:
r = c.patch(f"{settings.settings.api_base}/trades/notes", json={"id": 1, "note": "Updated trade note"})
assert r.status_code == 200
assert r.json() == {"id": 1, "note": "Updated trade note"}
def test_update_trade_note_not_found(client_factory: Callable[..., TestClient]) -> None:
client = client_factory(update_trade_note_service=MagicMock(side_effect=svc.TradeNotFoundError("trade not found")))
with client as c:
r = c.patch(f"{settings.settings.api_base}/trades/notes", json={"id": 999, "note": "NonExistent Trade Note"})
assert r.status_code == status.HTTP_404_NOT_FOUND
assert r.json() == {"detail": "trade not found"}

View File

@@ -1,7 +1,7 @@
from __future__ import annotations
from datetime import datetime, timedelta, timezone
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, cast
import pytest
from sqlalchemy import create_engine
@@ -45,18 +45,20 @@ def make_user(session: Session, username: str = "testuser") -> int:
session.add(user)
session.commit()
session.refresh(user)
return user.id
return cast("int", user.id)
def make_exchange(session: Session, name: str = "NASDAQ") -> int:
exchange = models.Exchanges(name=name, notes="Test exchange")
def make_exchange(session: Session, user_id: int, name: str = "NASDAQ") -> int:
exchange = models.Exchanges(user_id=user_id, name=name, notes="Test exchange")
session.add(exchange)
session.commit()
session.refresh(exchange)
return exchange.id
return cast("int", exchange.id)
def make_cycle(session: Session, user_id: int, exchange_id: int, friendly_name: str = "Test Cycle") -> int:
def make_cycle(
session: Session, user_id: int, exchange_id: int, friendly_name: str = "Test Cycle"
) -> int:
cycle = models.Cycles(
user_id=user_id,
friendly_name=friendly_name,
@@ -65,15 +67,18 @@ def make_cycle(session: Session, user_id: int, exchange_id: int, friendly_name:
underlying_currency=models.UnderlyingCurrency.USD,
status=models.CycleStatus.OPEN,
start_date=datetime.now(timezone.utc).date(),
)
) # type: ignore[arg-type]
session.add(cycle)
session.commit()
session.refresh(cycle)
return cycle.id
return cast("int", cycle.id)
def make_trade(session: Session, user_id: int, cycle_id: int, friendly_name: str = "Test Trade") -> int:
cycle: models.Cycles = session.get(models.Cycles, cycle_id)
def make_trade(
session: Session, user_id: int, cycle_id: int, friendly_name: str = "Test Trade"
) -> int:
cycle: models.Cycles | None = session.get(models.Cycles, cycle_id)
assert cycle is not None
exchange_id = cycle.exchange_id
trade = models.Trades(
user_id=user_id,
@@ -96,7 +101,7 @@ def make_trade(session: Session, user_id: int, cycle_id: int, friendly_name: str
session.add(trade)
session.commit()
session.refresh(trade)
return trade.id
return cast("int", trade.id)
def make_trade_by_trade_data(session: Session, trade_data: dict) -> int:
@@ -104,7 +109,7 @@ def make_trade_by_trade_data(session: Session, trade_data: dict) -> int:
session.add(trade)
session.commit()
session.refresh(trade)
return trade.id
return cast("int", trade.id)
def make_login_session(session: Session, created_at: datetime) -> models.Sessions:
@@ -128,7 +133,7 @@ def make_login_session(session: Session, created_at: datetime) -> models.Session
return login_session
def _ensure_utc_aware(dt: datetime) -> datetime | None:
def _ensure_utc_aware(dt: datetime | None) -> datetime | None:
if dt is None:
return None
if dt.tzinfo is None:
@@ -136,9 +141,23 @@ def _ensure_utc_aware(dt: datetime) -> datetime | None:
return dt.astimezone(timezone.utc)
def _validate_timestamp(
actual: datetime, expected: datetime, tolerance: timedelta
) -> None:
actual_utc = _ensure_utc_aware(actual)
expected_utc = _ensure_utc_aware(expected)
assert actual_utc is not None
assert expected_utc is not None
delta = abs(actual_utc - expected_utc)
assert delta <= tolerance, (
f"Timestamps differ by {delta}, which exceeds tolerance of {tolerance}"
)
# Trades
def test_create_trade_success_with_cycle(session: Session) -> None:
user_id = make_user(session)
exchange_id = make_exchange(session)
exchange_id = make_exchange(session, user_id)
cycle_id = make_cycle(session, user_id, exchange_id)
trade_data = {
@@ -171,6 +190,51 @@ def test_create_trade_success_with_cycle(session: Session) -> None:
assert actual_trade.trade_type == trade_data["trade_type"]
assert actual_trade.trade_strategy == trade_data["trade_strategy"]
assert actual_trade.quantity == trade_data["quantity"]
assert actual_trade.quantity_multiplier == 1
assert actual_trade.price_cents == trade_data["price_cents"]
assert actual_trade.gross_cash_flow_cents == trade_data["gross_cash_flow_cents"]
assert actual_trade.commission_cents == trade_data["commission_cents"]
assert actual_trade.net_cash_flow_cents == trade_data["net_cash_flow_cents"]
assert actual_trade.cycle_id == trade_data["cycle_id"]
def test_create_trade_with_custom_multipler(session: Session) -> None:
user_id = make_user(session)
exchange_id = make_exchange(session, user_id)
cycle_id = make_cycle(session, user_id, exchange_id)
trade_data = {
"user_id": user_id,
"friendly_name": "Test Trade with Multiplier",
"symbol": "AAPL",
"underlying_currency": models.UnderlyingCurrency.USD,
"trade_type": models.TradeType.LONG_SPOT,
"trade_strategy": models.TradeStrategy.SPOT,
"trade_time_utc": datetime.now(timezone.utc),
"quantity": 10,
"quantity_multiplier": 100,
"price_cents": 15000,
"gross_cash_flow_cents": -1500000,
"commission_cents": 50000,
"net_cash_flow_cents": -1550000,
"cycle_id": cycle_id,
}
trade = crud.create_trade(session, trade_data)
assert trade.id is not None
assert trade.user_id == user_id
assert trade.cycle_id == cycle_id
session.refresh(trade)
actual_trade = session.get(models.Trades, trade.id)
assert actual_trade is not None
assert actual_trade.friendly_name == trade_data["friendly_name"]
assert actual_trade.symbol == trade_data["symbol"]
assert actual_trade.underlying_currency == trade_data["underlying_currency"]
assert actual_trade.trade_type == trade_data["trade_type"]
assert actual_trade.trade_strategy == trade_data["trade_strategy"]
assert actual_trade.quantity == trade_data["quantity"]
assert actual_trade.quantity_multiplier == trade_data["quantity_multiplier"]
assert actual_trade.price_cents == trade_data["price_cents"]
assert actual_trade.gross_cash_flow_cents == trade_data["gross_cash_flow_cents"]
assert actual_trade.commission_cents == trade_data["commission_cents"]
@@ -180,7 +244,7 @@ def test_create_trade_success_with_cycle(session: Session) -> None:
def test_create_trade_with_auto_created_cycle(session: Session) -> None:
user_id = make_user(session)
exchange_id = make_exchange(session)
exchange_id = make_exchange(session, user_id)
trade_data = {
"user_id": user_id,
@@ -193,6 +257,9 @@ def test_create_trade_with_auto_created_cycle(session: Session) -> None:
"trade_time_utc": datetime.now(timezone.utc),
"quantity": 5,
"price_cents": 15500,
"gross_cash_flow_cents": -77500,
"commission_cents": 300,
"net_cash_flow_cents": -77800,
}
trade = crud.create_trade(session, trade_data)
@@ -219,12 +286,12 @@ def test_create_trade_with_auto_created_cycle(session: Session) -> None:
assert auto_cycle.symbol == trade_data["symbol"]
assert auto_cycle.underlying_currency == trade_data["underlying_currency"]
assert auto_cycle.status == models.CycleStatus.OPEN
assert auto_cycle.friendly_name.startswith("Auto-created Cycle by trade")
assert auto_cycle.friendly_name.startswith("Auto-created Cycle by trade") # type: ignore[union-attr]
def test_create_trade_missing_required_fields(session: Session) -> None:
user_id = make_user(session)
exchange_id = make_exchange(session)
exchange_id = make_exchange(session, user_id)
base_trade_data = {
"user_id": user_id,
@@ -291,7 +358,7 @@ def test_create_trade_missing_required_fields(session: Session) -> None:
def test_get_trade_by_id(session: Session) -> None:
user_id = make_user(session)
exchange_id = make_exchange(session)
exchange_id = make_exchange(session, user_id)
cycle_id = make_cycle(session, user_id, exchange_id)
trade_data = {
"user_id": user_id,
@@ -330,7 +397,7 @@ def test_get_trade_by_id(session: Session) -> None:
def test_get_trade_by_user_id_and_friendly_name(session: Session) -> None:
user_id = make_user(session)
exchange_id = make_exchange(session)
exchange_id = make_exchange(session, user_id)
cycle_id = make_cycle(session, user_id, exchange_id)
friendly_name = "Unique Trade Name"
trade_data = {
@@ -359,7 +426,7 @@ def test_get_trade_by_user_id_and_friendly_name(session: Session) -> None:
def test_get_trades_by_user_id(session: Session) -> None:
user_id = make_user(session)
exchange_id = make_exchange(session)
exchange_id = make_exchange(session, user_id)
cycle_id = make_cycle(session, user_id, exchange_id)
trade_data_1 = {
"user_id": user_id,
@@ -404,9 +471,29 @@ def test_get_trades_by_user_id(session: Session) -> None:
assert friendly_names == {"Trade One", "Trade Two"}
def test_update_trade_friendly_name(session: Session) -> None:
user_id = make_user(session)
exchange_id = make_exchange(session, user_id)
cycle_id = make_cycle(session, user_id, exchange_id)
trade_id = make_trade(session, user_id, cycle_id)
new_friendly_name = "Updated Trade Name"
updated_trade = crud.update_trade_friendly_name(
session, trade_id, new_friendly_name
)
assert updated_trade is not None
assert updated_trade.id == trade_id
assert updated_trade.friendly_name == new_friendly_name
session.refresh(updated_trade)
actual_trade = session.get(models.Trades, trade_id)
assert actual_trade is not None
assert actual_trade.friendly_name == new_friendly_name
def test_update_trade_note(session: Session) -> None:
user_id = make_user(session)
exchange_id = make_exchange(session)
exchange_id = make_exchange(session, user_id)
cycle_id = make_cycle(session, user_id, exchange_id)
trade_id = make_trade(session, user_id, cycle_id)
@@ -424,7 +511,7 @@ def test_update_trade_note(session: Session) -> None:
def test_invalidate_trade(session: Session) -> None:
user_id = make_user(session)
exchange_id = make_exchange(session)
exchange_id = make_exchange(session, user_id)
cycle_id = make_cycle(session, user_id, exchange_id)
trade_id = make_trade(session, user_id, cycle_id)
@@ -441,7 +528,7 @@ def test_invalidate_trade(session: Session) -> None:
def test_replace_trade(session: Session) -> None:
user_id = make_user(session)
exchange_id = make_exchange(session)
exchange_id = make_exchange(session, user_id)
cycle_id = make_cycle(session, user_id, exchange_id)
old_trade_id = make_trade(session, user_id, cycle_id)
@@ -456,6 +543,9 @@ def test_replace_trade(session: Session) -> None:
"trade_time_utc": datetime.now(timezone.utc),
"quantity": 20,
"price_cents": 25000,
"gross_cash_flow_cents": -500000,
"commission_cents": 1000,
"net_cash_flow_cents": -501000,
}
new_trade = crud.replace_trade(session, old_trade_id, new_trade_data)
@@ -484,9 +574,10 @@ def test_replace_trade(session: Session) -> None:
assert actual_new_trade.replaced_by_trade_id == old_trade_id
# Cycles
def test_create_cycle(session: Session) -> None:
user_id = make_user(session)
exchange_id = make_exchange(session)
exchange_id = make_exchange(session, user_id)
cycle_data = {
"user_id": user_id,
"friendly_name": "My First Cycle",
@@ -515,10 +606,37 @@ def test_create_cycle(session: Session) -> None:
assert actual_cycle.start_date == cycle_data["start_date"]
def test_get_cycle_by_id(session: Session) -> None:
user_id = make_user(session)
exchange_id = make_exchange(session, user_id)
cycle_id = make_cycle(session, user_id, exchange_id, friendly_name="Cycle to Get")
cycle = crud.get_cycle_by_id(session, cycle_id)
assert cycle is not None
assert cycle.id == cycle_id
assert cycle.friendly_name == "Cycle to Get"
assert cycle.user_id == user_id
def test_get_cycles_by_user_id(session: Session) -> None:
user_id = make_user(session)
exchange_id = make_exchange(session, user_id)
cycle_names = ["Cycle One", "Cycle Two", "Cycle Three"]
for name in cycle_names:
make_cycle(session, user_id, exchange_id, friendly_name=name)
cycles = crud.get_cycles_by_user_id(session, user_id)
assert len(cycles) == len(cycle_names)
fetched_names = {cycle.friendly_name for cycle in cycles}
for name in cycle_names:
assert name in fetched_names
def test_update_cycle(session: Session) -> None:
user_id = make_user(session)
exchange_id = make_exchange(session)
cycle_id = make_cycle(session, user_id, exchange_id, friendly_name="Initial Cycle Name")
exchange_id = make_exchange(session, user_id)
cycle_id = make_cycle(
session, user_id, exchange_id, friendly_name="Initial Cycle Name"
)
update_data = {
"friendly_name": "Updated Cycle Name",
@@ -539,15 +657,21 @@ def test_update_cycle(session: Session) -> None:
def test_update_cycle_immutable_fields(session: Session) -> None:
user_id = make_user(session)
exchange_id = make_exchange(session)
cycle_id = make_cycle(session, user_id, exchange_id, friendly_name="Initial Cycle Name")
exchange_id = make_exchange(session, user_id)
cycle_id = make_cycle(
session, user_id, exchange_id, friendly_name="Initial Cycle Name"
)
# Attempt to update immutable fields
update_data = {
"id": cycle_id + 1, # Trying to change the ID
"user_id": user_id + 1, # Trying to change the user_id
"start_date": datetime(2020, 1, 1, tzinfo=timezone.utc).date(), # Trying to change start_date
"created_at": datetime(2020, 1, 1, tzinfo=timezone.utc), # Trying to change created_at
"start_date": datetime(
2020, 1, 1, tzinfo=timezone.utc
).date(), # Trying to change start_date
"created_at": datetime(
2020, 1, 1, tzinfo=timezone.utc
), # Trying to change created_at
"friendly_name": "Valid Update", # Valid field to update
}
@@ -561,6 +685,422 @@ def test_update_cycle_immutable_fields(session: Session) -> None:
)
# Cycle loans
def test_create_cycle_loan_event(session: Session) -> None:
user_id = make_user(session)
exchange_id = make_exchange(session, user_id)
cycle_id = make_cycle(session, user_id, exchange_id)
loan_data = {
"cycle_id": cycle_id,
"loan_amount_cents": 100000,
"loan_interest_rate_tenth_bps": 5000, # 5%
"notes": "Test loan change for the cycle",
}
loan_event = crud.create_cycle_loan_event(session, loan_data)
now = datetime.now(timezone.utc)
assert loan_event.id is not None
assert loan_event.cycle_id == cycle_id
assert loan_event.loan_amount_cents == loan_data["loan_amount_cents"]
assert (
loan_event.loan_interest_rate_tenth_bps
== loan_data["loan_interest_rate_tenth_bps"]
)
assert loan_event.notes == loan_data["notes"]
assert loan_event.effective_date == now.date()
_validate_timestamp(loan_event.created_at, now, timedelta(seconds=1))
session.refresh(loan_event)
actual_loan_event = session.get(models.CycleLoanChangeEvents, loan_event.id)
assert actual_loan_event is not None
assert actual_loan_event.cycle_id == cycle_id
assert actual_loan_event.loan_amount_cents == loan_data["loan_amount_cents"]
assert (
actual_loan_event.loan_interest_rate_tenth_bps
== loan_data["loan_interest_rate_tenth_bps"]
)
assert actual_loan_event.notes == loan_data["notes"]
assert actual_loan_event.effective_date == now.date()
_validate_timestamp(actual_loan_event.created_at, now, timedelta(seconds=1))
def test_create_cycle_loan_event_same_date_error(session: Session) -> None:
user_id = make_user(session)
exchange_id = make_exchange(session, user_id)
cycle_id = make_cycle(session, user_id, exchange_id)
loan_data_1 = {
"cycle_id": cycle_id,
"loan_amount_cents": 100000,
"loan_interest_rate_tenth_bps": 5000,
"effective_date": datetime(2023, 1, 1, tzinfo=timezone.utc).date(),
"notes": "First loan event",
}
loan_data_2 = {
"cycle_id": cycle_id,
"loan_amount_cents": 150000,
"loan_interest_rate_tenth_bps": 4500,
"effective_date": datetime(2023, 1, 1, tzinfo=timezone.utc).date(),
"notes": "Second loan event same date",
}
crud.create_cycle_loan_event(session, loan_data_1)
with pytest.raises(ValueError) as excinfo:
crud.create_cycle_loan_event(session, loan_data_2)
assert "create_cycle_loan_event integrity error" in str(excinfo.value)
def test_get_cycle_loan_events_by_cycle_id(session: Session) -> None:
user_id = make_user(session)
exchange_id = make_exchange(session, user_id)
cycle_id = make_cycle(session, user_id, exchange_id)
loan_data_1 = {
"cycle_id": cycle_id,
"loan_amount_cents": 100000,
"loan_interest_rate_tenth_bps": 5000,
"notes": "First loan event",
}
yesterday = (datetime.now(timezone.utc) - timedelta(days=1)).date()
loan_data_2 = {
"cycle_id": cycle_id,
"loan_amount_cents": 150000,
"loan_interest_rate_tenth_bps": 4500,
"effective_date": yesterday,
"notes": "Second loan event",
}
crud.create_cycle_loan_event(session, loan_data_1)
crud.create_cycle_loan_event(session, loan_data_2)
loan_events = crud.get_loan_events_by_cycle_id(session, cycle_id)
assert len(loan_events) == 2
notes = [event.notes for event in loan_events]
assert loan_events[0].notes == loan_data_2["notes"]
assert loan_events[0].effective_date == yesterday
assert notes == [
"Second loan event",
"First loan event",
] # Ordered by effective_date desc
def test_get_cycle_loan_event_by_cycle_id_and_effective_date(session: Session) -> None:
user_id = make_user(session)
exchange_id = make_exchange(session, user_id)
cycle_id = make_cycle(session, user_id, exchange_id)
effective_date = datetime(2023, 1, 1, tzinfo=timezone.utc).date()
loan_data = {
"cycle_id": cycle_id,
"loan_amount_cents": 100000,
"loan_interest_rate_tenth_bps": 5000,
"effective_date": effective_date,
"notes": "Loan event for specific date",
}
crud.create_cycle_loan_event(session, loan_data)
loan_event = crud.get_loan_event_by_cycle_id_and_effective_date(
session, cycle_id, effective_date
)
assert loan_event is not None
assert loan_event.cycle_id == cycle_id
assert loan_event.effective_date == effective_date
assert loan_event.notes == loan_data["notes"]
def test_update_cycle_loan_event(session: Session) -> None:
user_id = make_user(session)
exchange_id = make_exchange(session, user_id)
cycle_id = make_cycle(session, user_id, exchange_id)
loan_data = {
"cycle_id": cycle_id,
"loan_amount_cents": 100000,
"loan_interest_rate_tenth_bps": 5000,
"notes": "Initial loan event",
}
loan_event = crud.create_cycle_loan_event(session, loan_data)
assert loan_event is not None
update_data = {
"loan_amount_cents": 120000,
"loan_interest_rate_tenth_bps": 4500,
"notes": "Updated loan event",
}
event_id = loan_event.id or 0
updated_loan_event = crud.update_cycle_loan_event(session, event_id, update_data)
assert updated_loan_event is not None
assert updated_loan_event.id == loan_event.id
assert updated_loan_event.loan_amount_cents == update_data["loan_amount_cents"]
assert (
updated_loan_event.loan_interest_rate_tenth_bps
== update_data["loan_interest_rate_tenth_bps"]
)
assert updated_loan_event.notes == update_data["notes"]
session.refresh(updated_loan_event)
actual_loan_event = session.get(models.CycleLoanChangeEvents, loan_event.id)
assert actual_loan_event is not None
assert actual_loan_event.loan_amount_cents == update_data["loan_amount_cents"]
assert (
actual_loan_event.loan_interest_rate_tenth_bps
== update_data["loan_interest_rate_tenth_bps"]
)
assert actual_loan_event.notes == update_data["notes"]
def test_create_cycle_loan_event_single_field(session: Session) -> None:
user_id = make_user(session)
exchange_id = make_exchange(session, user_id)
cycle_id = make_cycle(session, user_id, exchange_id)
loan_data = {
"cycle_id": cycle_id,
"loan_amount_cents": 200000,
}
loan_event = crud.create_cycle_loan_event(session, loan_data)
now = datetime.now(timezone.utc)
assert loan_event.id is not None
assert loan_event.cycle_id == cycle_id
assert loan_event.loan_amount_cents == loan_data["loan_amount_cents"]
assert loan_event.loan_interest_rate_tenth_bps is None
assert loan_event.notes is None
assert loan_event.effective_date == now.date()
_validate_timestamp(loan_event.created_at, now, timedelta(seconds=1))
session.refresh(loan_event)
actual_loan_event = session.get(models.CycleLoanChangeEvents, loan_event.id)
assert actual_loan_event is not None
assert actual_loan_event.cycle_id == cycle_id
assert actual_loan_event.loan_amount_cents == loan_data["loan_amount_cents"]
assert actual_loan_event.loan_interest_rate_tenth_bps is None
assert actual_loan_event.notes is None
assert actual_loan_event.effective_date == now.date()
_validate_timestamp(actual_loan_event.created_at, now, timedelta(seconds=1))
def test_create_cycle_daily_accrual(session: Session) -> None:
user_id = make_user(session)
exchange_id = make_exchange(session, user_id)
cycle_id = make_cycle(session, user_id, exchange_id)
today = datetime.now(timezone.utc).date()
accrual_data = {
"cycle_id": cycle_id,
"accrual_date": today,
"accrued_interest_cents": 150,
"notes": "Daily interest accrual",
}
accrual = crud.create_cycle_daily_accrual(
session,
cycle_id,
accrual_data["accrual_date"],
accrual_data["accrued_interest_cents"],
)
assert accrual.id is not None
assert accrual.cycle_id == cycle_id
assert accrual.accrual_date == accrual_data["accrual_date"]
assert accrual.accrual_amount_cents == accrual_data["accrued_interest_cents"]
session.refresh(accrual)
actual_accrual = session.get(models.CycleDailyAccrual, accrual.id)
assert actual_accrual is not None
assert actual_accrual.cycle_id == cycle_id
assert actual_accrual.accrual_date == accrual_data["accrual_date"]
assert actual_accrual.accrual_amount_cents == accrual_data["accrued_interest_cents"]
def test_get_cycle_daily_accruals_by_cycle_id(session: Session) -> None:
user_id = make_user(session)
exchange_id = make_exchange(session, user_id)
cycle_id = make_cycle(session, user_id, exchange_id)
today = datetime.now(timezone.utc).date()
yesterday = today - timedelta(days=1)
accrual_data_1 = {
"cycle_id": cycle_id,
"accrual_date": yesterday,
"accrued_interest_cents": 100,
}
accrual_data_2 = {
"cycle_id": cycle_id,
"accrual_date": today,
"accrued_interest_cents": 150,
}
crud.create_cycle_daily_accrual(
session,
cycle_id,
accrual_data_1["accrual_date"],
accrual_data_1["accrued_interest_cents"],
)
crud.create_cycle_daily_accrual(
session,
cycle_id,
accrual_data_2["accrual_date"],
accrual_data_2["accrued_interest_cents"],
)
accruals = crud.get_cycle_daily_accruals_by_cycle_id(session, cycle_id)
assert len(accruals) == 2
dates = [accrual.accrual_date for accrual in accruals]
assert dates == [yesterday, today] # Ordered by accrual_date asc
def test_get_cycle_daily_accruals_by_cycle_id_and_date(session: Session) -> None:
user_id = make_user(session)
exchange_id = make_exchange(session, user_id)
cycle_id = make_cycle(session, user_id, exchange_id)
today = datetime.now(timezone.utc).date()
yesterday = today - timedelta(days=1)
accrual_data_1 = {
"cycle_id": cycle_id,
"accrual_date": yesterday,
"accrued_interest_cents": 100,
}
accrual_data_2 = {
"cycle_id": cycle_id,
"accrual_date": today,
"accrued_interest_cents": 150,
}
crud.create_cycle_daily_accrual(
session,
cycle_id,
accrual_data_1["accrual_date"],
accrual_data_1["accrued_interest_cents"],
)
crud.create_cycle_daily_accrual(
session,
cycle_id,
accrual_data_2["accrual_date"],
accrual_data_2["accrued_interest_cents"],
)
accruals_today = crud.get_cycle_daily_accrual_by_cycle_id_and_date(
session, cycle_id, today
)
assert accruals_today is not None
assert accruals_today.accrual_date == today
assert (
accruals_today.accrual_amount_cents == accrual_data_2["accrued_interest_cents"]
)
accruals_yesterday = crud.get_cycle_daily_accrual_by_cycle_id_and_date(
session, cycle_id, yesterday
)
assert accruals_yesterday is not None
assert accruals_yesterday.accrual_date == yesterday
assert (
accruals_yesterday.accrual_amount_cents
== accrual_data_1["accrued_interest_cents"]
)
# Exchanges
def test_create_exchange(session: Session) -> None:
user_id = make_user(session)
exchange_data = {
"name": "NYSE",
"notes": "New York Stock Exchange",
"user_id": user_id,
}
exchange = crud.create_exchange(session, exchange_data)
assert exchange.id is not None
assert exchange.name == exchange_data["name"]
assert exchange.notes == exchange_data["notes"]
assert exchange.user_id == user_id
session.refresh(exchange)
actual_exchange = session.get(models.Exchanges, exchange.id)
assert actual_exchange is not None
assert actual_exchange.name == exchange_data["name"]
assert actual_exchange.notes == exchange_data["notes"]
assert actual_exchange.user_id == user_id
def test_get_exchange_by_id(session: Session) -> None:
user_id = make_user(session)
exchange_id = make_exchange(session, user_id=user_id, name="LSE")
exchange = crud.get_exchange_by_id(session, exchange_id)
assert exchange is not None
assert exchange.id == exchange_id
assert exchange.name == "LSE"
assert exchange.user_id == user_id
def test_get_exchange_by_name_and_user_id(session: Session) -> None:
exchange_name = "TSX"
user_id = make_user(session)
make_exchange(session, user_id=user_id, name=exchange_name)
exchange = crud.get_exchange_by_name_and_user_id(session, exchange_name, user_id)
assert exchange is not None
assert exchange.name == exchange_name
assert exchange.user_id == user_id
def test_get_all_exchanges(session: Session) -> None:
exchange_names = ["NYSE", "NASDAQ", "LSE"]
user_id = make_user(session)
for name in exchange_names:
make_exchange(session, user_id=user_id, name=name)
exchanges = crud.get_all_exchanges(session)
assert len(exchanges) >= 3
fetched_names = {ex.name for ex in exchanges}
for name in exchange_names:
assert name in fetched_names
def test_get_all_exchanges_by_user_id(session: Session) -> None:
exchange_names = ["NYSE", "NASDAQ"]
user_id = make_user(session)
for name in exchange_names:
make_exchange(session, user_id=user_id, name=name)
exchanges = crud.get_all_exchanges_by_user_id(session, user_id)
assert len(exchanges) == len(exchange_names)
fetched_names = {ex.name for ex in exchanges}
for name in exchange_names:
assert name in fetched_names
def test_update_exchange(session: Session) -> None:
user_id = make_user(session)
exchange_id = make_exchange(session, user_id=user_id, name="Initial Exchange")
update_data = {
"name": "Updated Exchange",
"notes": "Updated notes for the exchange",
}
updated_exchange = crud.update_exchange(session, exchange_id, update_data)
assert updated_exchange is not None
assert updated_exchange.id == exchange_id
assert updated_exchange.name == update_data["name"]
assert updated_exchange.notes == update_data["notes"]
session.refresh(updated_exchange)
actual_exchange = session.get(models.Exchanges, exchange_id)
assert actual_exchange is not None
assert actual_exchange.name == update_data["name"]
assert actual_exchange.notes == update_data["notes"]
def test_delete_exchange(session: Session) -> None:
user_id = make_user(session)
exchange_id = make_exchange(session, user_id=user_id, name="Deletable Exchange")
crud.delete_exchange(session, exchange_id)
deleted_exchange = session.get(models.Exchanges, exchange_id)
assert deleted_exchange is None
# Users
def test_create_user(session: Session) -> None:
user_data = {
"username": "newuser",
@@ -578,6 +1118,22 @@ def test_create_user(session: Session) -> None:
assert actual_user.password_hash == user_data["password_hash"]
def test_get_user_by_id(session: Session) -> None:
user_id = make_user(session, username="fetchuser")
user = crud.get_user_by_id(session, user_id)
assert user is not None
assert user.id == user_id
assert user.username == "fetchuser"
def test_get_user_by_username(session: Session) -> None:
username = "uniqueuser"
make_user(session, username=username)
user = crud.get_user_by_username(session, username)
assert user is not None
assert user.username == username
def test_update_user(session: Session) -> None:
user_id = make_user(session, username="updatableuser")
@@ -602,7 +1158,9 @@ def test_update_user_immutable_fields(session: Session) -> None:
update_data = {
"id": user_id + 1, # Trying to change the ID
"username": "newusername", # Trying to change the username
"created_at": datetime(2020, 1, 1, tzinfo=timezone.utc), # Trying to change created_at
"created_at": datetime(
2020, 1, 1, tzinfo=timezone.utc
), # Trying to change created_at
"password_hash": "validupdate", # Valid field to update
}
@@ -636,7 +1194,21 @@ def test_create_login_session_with_invalid_user(session: Session) -> None:
def test_get_login_session_by_token_and_user_id(session: Session) -> None:
now = datetime.now(timezone.utc)
created_session = make_login_session(session, now)
fetched_session = crud.get_login_session_by_token_hash_and_user_id(session, created_session.session_token_hash, created_session.user_id)
fetched_session = crud.get_login_session_by_token_hash_and_user_id(
session, created_session.session_token_hash, created_session.user_id
)
assert fetched_session is not None
assert fetched_session.id == created_session.id
assert fetched_session.user_id == created_session.user_id
assert fetched_session.session_token_hash == created_session.session_token_hash
def test_get_login_session_by_token(session: Session) -> None:
now = datetime.now(timezone.utc)
created_session = make_login_session(session, now)
fetched_session = crud.get_login_session_by_token_hash(
session, created_session.session_token_hash
)
assert fetched_session is not None
assert fetched_session.id == created_session.id
assert fetched_session.user_id == created_session.user_id
@@ -651,9 +1223,13 @@ def test_update_login_session(session: Session) -> None:
"last_seen_at": now + timedelta(hours=1),
"last_used_ip": "192.168.1.1",
}
updated_session = crud.update_login_session(session, created_session.session_token_hash, update_data)
updated_session = crud.update_login_session(
session, created_session.session_token_hash, update_data
)
assert updated_session is not None
assert _ensure_utc_aware(updated_session.last_seen_at) == update_data["last_seen_at"]
assert (
_ensure_utc_aware(updated_session.last_seen_at) == update_data["last_seen_at"]
)
assert updated_session.last_used_ip == update_data["last_used_ip"]
@@ -662,5 +1238,7 @@ def test_delete_login_session(session: Session) -> None:
created_session = make_login_session(session, now)
crud.delete_login_session(session, created_session.session_token_hash)
deleted_session = crud.get_login_session_by_token_hash_and_user_id(session, created_session.session_token_hash, created_session.user_id)
deleted_session = crud.get_login_session_by_token_hash_and_user_id(
session, created_session.session_token_hash, created_session.user_id
)
assert deleted_session is None

View File

@@ -42,9 +42,28 @@ def test_run_migrations_0_to_1(monkeypatch: pytest.MonkeyPatch) -> None:
"funding_source": ("TEXT", 0, 0),
"capital_exposure_cents": ("INTEGER", 0, 0),
"loan_amount_cents": ("INTEGER", 0, 0),
"loan_interest_rate_bps": ("INTEGER", 0, 0),
"loan_interest_rate_tenth_bps": ("INTEGER", 0, 0),
"start_date": ("DATE", 1, 0),
"end_date": ("DATE", 0, 0),
"latest_interest_accrued_date": ("DATE", 0, 0),
"total_accrued_amount_cents": ("INTEGER", 1, 0),
},
"cycle_loan_change_events": {
"id": ("INTEGER", 1, 1),
"cycle_id": ("INTEGER", 1, 0),
"effective_date": ("DATE", 1, 0),
"loan_amount_cents": ("INTEGER", 0, 0),
"loan_interest_rate_tenth_bps": ("INTEGER", 0, 0),
"related_trade_id": ("INTEGER", 0, 0),
"notes": ("TEXT", 0, 0),
"created_at": ("DATETIME", 1, 0),
},
"cycle_daily_accrual": {
"id": ("INTEGER", 1, 1),
"cycle_id": ("INTEGER", 1, 0),
"accrual_date": ("DATE", 1, 0),
"accrual_amount_cents": ("INTEGER", 1, 0),
"created_at": ("DATETIME", 1, 0),
},
"trades": {
"id": ("INTEGER", 1, 1),
@@ -60,6 +79,7 @@ def test_run_migrations_0_to_1(monkeypatch: pytest.MonkeyPatch) -> None:
"expiry_date": ("DATE", 0, 0),
"strike_price_cents": ("INTEGER", 0, 0),
"quantity": ("INTEGER", 1, 0),
"quantity_multiplier": ("INTEGER", 1, 0),
"price_cents": ("INTEGER", 1, 0),
"gross_cash_flow_cents": ("INTEGER", 1, 0),
"commission_cents": ("INTEGER", 1, 0),
@@ -70,6 +90,12 @@ def test_run_migrations_0_to_1(monkeypatch: pytest.MonkeyPatch) -> None:
"notes": ("TEXT", 0, 0),
"cycle_id": ("INTEGER", 0, 0),
},
"exchanges": {
"id": ("INTEGER", 1, 1),
"user_id": ("INTEGER", 1, 0),
"name": ("TEXT", 1, 0),
"notes": ("TEXT", 0, 0),
},
"sessions": {
"id": ("INTEGER", 1, 1),
"user_id": ("INTEGER", 1, 0),
@@ -93,11 +119,20 @@ def test_run_migrations_0_to_1(monkeypatch: pytest.MonkeyPatch) -> None:
{"table": "users", "from": "user_id", "to": "id"},
{"table": "exchanges", "from": "exchange_id", "to": "id"},
],
"cycle_loan_change_events": [
{"table": "cycles", "from": "cycle_id", "to": "id"},
{"table": "trades", "from": "related_trade_id", "to": "id"},
],
"cycle_daily_accrual": [
{"table": "cycles", "from": "cycle_id", "to": "id"},
],
"sessions": [
{"table": "users", "from": "user_id", "to": "id"},
],
"users": [],
"exchanges": [],
"exchanges": [
{"table": "users", "from": "user_id", "to": "id"},
],
}
with engine.connect() as conn:
@@ -137,6 +172,39 @@ def test_run_migrations_0_to_1(monkeypatch: pytest.MonkeyPatch) -> None:
actual_fk_list = [{"table": r[2], "from": r[3], "to": r[4]} for r in fk_rows]
for efk in fks:
assert efk in actual_fk_list, f"missing FK on {tbl_name}: {efk}"
# check trades.replaced_by_trade_id self-referential FK
fk_rows = conn.execute(text("PRAGMA foreign_key_list('trades')")).fetchall()
actual_fk_list = [{"table": r[2], "from": r[3], "to": r[4]} for r in fk_rows]
assert {"table": "trades", "from": "replaced_by_trade_id", "to": "id"} in actual_fk_list, (
"missing self FK trades.replaced_by_trade_id -> trades.id"
)
# helper to find unique index on a column
def has_unique_index(table: str, column: str) -> bool:
idx_rows = conn.execute(text(f"PRAGMA index_list('{table}')")).fetchall()
for idx in idx_rows:
idx_name = idx[1]
is_unique = bool(idx[2])
if not is_unique:
continue
info = conn.execute(text(f"PRAGMA index_info('{idx_name}')")).fetchall()
cols = [r[2] for r in info]
if column in cols:
return True
return False
assert has_unique_index("trades", "friendly_name"), (
"expected unique index on trades(friendly_name) per uq_trades_user_friendly_name"
)
assert has_unique_index("cycles", "friendly_name"), (
"expected unique index on cycles(friendly_name) per uq_cycles_user_friendly_name"
)
assert has_unique_index("exchanges", "name"), "expected unique index on exchanges(name) per uq_exchanges_user_name"
assert has_unique_index("sessions", "session_token_hash"), "expected unique index on sessions(session_token_hash)"
assert has_unique_index("cycle_loan_change_events", "related_trade_id"), (
"expected unique index on cycle_loan_change_events(related_trade_id)"
)
finally:
engine.dispose()
SQLModel.metadata.clear()

View File

@@ -1,4 +1,24 @@
from trading_journal import security
def test_hash_password() -> None:
def test_hash_and_verify_password() -> None:
plain = "password"
hashed = security.hash_password(plain)
assert hashed != plain
assert security.verify_password(plain, hashed)
def test_generate_session_token() -> None:
token1 = security.generate_session_token()
token2 = security.generate_session_token()
assert token1 != token2
assert len(token1) > 0
assert len(token2) > 0
def test_hash_and_verify_session_token_sha256() -> None:
token = security.generate_session_token()
token_hash = security.hash_session_token_sha256(token)
assert token_hash != token
assert security.verify_token_sha256(token, token_hash)
assert not security.verify_token_sha256(token + "x", token_hash)

File diff suppressed because it is too large Load Diff

View File

@@ -1,8 +1,9 @@
from __future__ import annotations
from datetime import datetime, timedelta, timezone
from typing import TYPE_CHECKING
from datetime import date, datetime, timedelta, timezone
from typing import TYPE_CHECKING, Any, TypeVar, cast
from pydantic import BaseModel
from sqlalchemy.exc import IntegrityError
from sqlmodel import Session, select
@@ -10,9 +11,16 @@ from trading_journal import models
if TYPE_CHECKING:
from collections.abc import Mapping
from enum import Enum
from sqlalchemy.sql.elements import ColumnElement
def _check_enum(enum_cls: any, value: any, field_name: str) -> any:
# Generic enum member type
T = TypeVar("T", bound="Enum")
def _check_enum(enum_cls: type[T], value: object, field_name: str) -> T:
if value is None:
raise ValueError(f"{field_name} is required")
# already an enum member
@@ -27,19 +35,41 @@ def _check_enum(enum_cls: any, value: any, field_name: str) -> any:
raise ValueError(f"Invalid {field_name!s}: {value!r}. Allowed: {allowed}")
def _allowed_columns(model: type[models.SQLModel]) -> set[str]:
tbl = cast("models.SQLModel", model).__table__ # type: ignore[attr-defined]
return {c.name for c in tbl.columns}
AnyModel = Any
def _data_to_dict(data: AnyModel) -> dict[str, AnyModel]:
if isinstance(data, BaseModel):
return data.model_dump(exclude_unset=True)
if hasattr(data, "dict"):
return data.dict(exclude_unset=True)
return dict(data)
# Trades
def create_trade(session: Session, trade_data: Mapping) -> models.Trades:
if hasattr(trade_data, "dict"):
data = trade_data.dict(exclude_unset=True)
else:
data = dict(trade_data)
allowed = {c.name for c in models.Trades.__table__.columns}
def create_trade(session: Session, trade_data: Mapping[str, Any] | BaseModel) -> models.Trades:
data = _data_to_dict(trade_data)
allowed = _allowed_columns(models.Trades)
payload = {k: v for k, v in data.items() if k in allowed}
cycle_id = payload.get("cycle_id")
if "symbol" not in payload:
raise ValueError("symbol is required")
if "exchange_id" not in payload and cycle_id is None:
raise ValueError("exchange_id is required when no cycle is attached")
# If an exchange_id is provided (and no cycle is attached), ensure the exchange exists
# and belongs to the same user as the trade (if user_id is provided).
if cycle_id is None and "exchange_id" in payload:
ex = session.get(models.Exchanges, payload["exchange_id"])
if ex is None:
raise ValueError("exchange_id does not exist")
user_id = payload.get("user_id")
if user_id is not None and ex.user_id != user_id:
raise ValueError("exchange.user_id does not match trade.user_id")
if "underlying_currency" not in payload:
raise ValueError("underlying_currency is required")
payload["underlying_currency"] = _check_enum(models.UnderlyingCurrency, payload["underlying_currency"], "underlying_currency")
@@ -62,13 +92,10 @@ def create_trade(session: Session, trade_data: Mapping) -> models.Trades:
raise ValueError("price_cents is required")
if "commission_cents" not in payload:
payload["commission_cents"] = 0
quantity: int = payload["quantity"]
price_cents: int = payload["price_cents"]
commission_cents: int = payload["commission_cents"]
if "gross_cash_flow_cents" not in payload:
payload["gross_cash_flow_cents"] = -quantity * price_cents
raise ValueError("gross_cash_flow_cents is required")
if "net_cash_flow_cents" not in payload:
payload["net_cash_flow_cents"] = payload["gross_cash_flow_cents"] - commission_cents
raise ValueError("net_cash_flow_cents is required")
# If no cycle_id provided, create Cycle instance but don't call create_cycle()
created_cycle = None
@@ -132,7 +159,22 @@ def get_trades_by_user_id(session: Session, user_id: int) -> list[models.Trades]
statement = select(models.Trades).where(
models.Trades.user_id == user_id,
)
return session.exec(statement).all()
return list(session.exec(statement).all())
def update_trade_friendly_name(session: Session, trade_id: int, friendly_name: str) -> models.Trades:
trade: models.Trades | None = session.get(models.Trades, trade_id)
if trade is None:
raise ValueError("trade_id does not exist")
trade.friendly_name = friendly_name
session.add(trade)
try:
session.flush()
except IntegrityError as e:
session.rollback()
raise ValueError("update_trade_friendly_name integrity error") from e
session.refresh(trade)
return trade
def update_trade_note(session: Session, trade_id: int, note: str) -> models.Trades:
@@ -168,23 +210,17 @@ def invalidate_trade(session: Session, trade_id: int) -> models.Trades:
return trade
def replace_trade(session: Session, old_trade_id: int, new_trade_data: Mapping) -> models.Trades:
def replace_trade(session: Session, old_trade_id: int, new_trade_data: Mapping[str, Any] | BaseModel) -> models.Trades:
invalidate_trade(session, old_trade_id)
if hasattr(new_trade_data, "dict"):
data = new_trade_data.dict(exclude_unset=True)
else:
data = dict(new_trade_data)
data = _data_to_dict(new_trade_data)
data["replaced_by_trade_id"] = old_trade_id
return create_trade(session, data)
# Cycles
def create_cycle(session: Session, cycle_data: Mapping) -> models.Cycles:
if hasattr(cycle_data, "dict"):
data = cycle_data.dict(exclude_unset=True)
else:
data = dict(cycle_data)
allowed = {c.name for c in models.Cycles.__table__.columns}
def create_cycle(session: Session, cycle_data: Mapping[str, Any] | BaseModel) -> models.Cycles:
data = _data_to_dict(cycle_data)
allowed = _allowed_columns(models.Cycles)
payload = {k: v for k, v in data.items() if k in allowed}
if "user_id" not in payload:
raise ValueError("user_id is required")
@@ -192,6 +228,12 @@ def create_cycle(session: Session, cycle_data: Mapping) -> models.Cycles:
raise ValueError("symbol is required")
if "exchange_id" not in payload:
raise ValueError("exchange_id is required")
# ensure the exchange exists and belongs to the same user
ex = session.get(models.Exchanges, payload["exchange_id"])
if ex is None:
raise ValueError("exchange_id does not exist")
if ex.user_id != payload.get("user_id"):
raise ValueError("exchange.user_id does not match cycle.user_id")
if "underlying_currency" not in payload:
raise ValueError("underlying_currency is required")
payload["underlying_currency"] = _check_enum(models.UnderlyingCurrency, payload["underlying_currency"], "underlying_currency")
@@ -212,24 +254,40 @@ def create_cycle(session: Session, cycle_data: Mapping) -> models.Cycles:
return c
IMMUTABLE_CYCLE_FIELDS = {"id", "user_id", "start_date", "created_at"}
IMMUTABLE_CYCLE_FIELDS = {"id", "user_id", "start_date"}
def update_cycle(session: Session, cycle_id: int, update_data: Mapping) -> models.Cycles:
def get_cycle_by_id(session: Session, cycle_id: int) -> models.Cycles | None:
return session.get(models.Cycles, cycle_id)
def get_cycles_by_user_id(session: Session, user_id: int) -> list[models.Cycles]:
statement = select(models.Cycles).where(
models.Cycles.user_id == user_id,
)
return list(session.exec(statement).all())
def update_cycle(session: Session, cycle_id: int, update_data: Mapping[str, Any] | BaseModel) -> models.Cycles:
cycle: models.Cycles | None = session.get(models.Cycles, cycle_id)
if cycle is None:
raise ValueError("cycle_id does not exist")
if hasattr(update_data, "dict"):
data = update_data.dict(exclude_unset=True)
else:
data = dict(update_data)
data = _data_to_dict(update_data)
allowed = {c.name for c in models.Cycles.__table__.columns}
allowed = _allowed_columns(models.Cycles)
for k, v in data.items():
if k in IMMUTABLE_CYCLE_FIELDS:
raise ValueError(f"field {k!r} is immutable")
if k not in allowed:
continue
# If trying to change exchange_id, ensure the new exchange exists and belongs to
# the same user as the cycle.
if k == "exchange_id":
ex = session.get(models.Exchanges, v)
if ex is None:
raise ValueError("exchange_id does not exist")
if ex.user_id != cycle.user_id:
raise ValueError("exchange.user_id does not match cycle.user_id")
if k == "underlying_currency":
v = _check_enum(models.UnderlyingCurrency, v, "underlying_currency") # noqa: PLW2901
if k == "status":
@@ -245,16 +303,210 @@ def update_cycle(session: Session, cycle_id: int, update_data: Mapping) -> model
return cycle
# Cycle loan and interest
def create_cycle_loan_event(session: Session, loan_data: Mapping[str, Any] | BaseModel) -> models.CycleLoanChangeEvents:
data = _data_to_dict(loan_data)
allowed = _allowed_columns(models.CycleLoanChangeEvents)
payload = {k: v for k, v in data.items() if k in allowed}
if "cycle_id" not in payload:
raise ValueError("cycle_id is required")
cycle = session.get(models.Cycles, payload["cycle_id"])
if cycle is None:
raise ValueError("cycle_id does not exist")
payload["effective_date"] = payload.get("effective_date") or datetime.now(timezone.utc).date()
payload["created_at"] = datetime.now(timezone.utc)
cle = models.CycleLoanChangeEvents(**payload)
session.add(cle)
try:
session.flush()
except IntegrityError as e:
session.rollback()
raise ValueError("create_cycle_loan_event integrity error") from e
session.refresh(cle)
return cle
def get_loan_events_by_cycle_id(session: Session, cycle_id: int) -> list[models.CycleLoanChangeEvents]:
eff_col = cast("ColumnElement", models.CycleLoanChangeEvents.effective_date)
id_col = cast("ColumnElement", models.CycleLoanChangeEvents.id)
statement = (
select(models.CycleLoanChangeEvents)
.where(
models.CycleLoanChangeEvents.cycle_id == cycle_id,
)
.order_by(eff_col, id_col.asc())
)
return list(session.exec(statement).all())
def get_loan_event_by_cycle_id_and_effective_date(session: Session, cycle_id: int, effective_date: date) -> models.CycleLoanChangeEvents | None:
statement = select(models.CycleLoanChangeEvents).where(
models.CycleLoanChangeEvents.cycle_id == cycle_id,
models.CycleLoanChangeEvents.effective_date == effective_date,
)
return session.exec(statement).first()
def update_cycle_loan_event(session: Session, event_id: int, update_data: Mapping[str, Any] | BaseModel) -> models.CycleLoanChangeEvents:
event: models.CycleLoanChangeEvents | None = session.get(models.CycleLoanChangeEvents, event_id)
if event is None:
raise ValueError("event_id does not exist")
data = _data_to_dict(update_data)
allowed = _allowed_columns(models.CycleLoanChangeEvents)
for k, v in data.items():
if k in {"id", "cycle_id", "effective_date", "created_at"}:
raise ValueError(f"field {k!r} is immutable")
if k not in allowed:
continue
setattr(event, k, v)
session.add(event)
try:
session.flush()
except IntegrityError as e:
session.rollback()
raise ValueError("update_cycle_loan_event integrity error") from e
session.refresh(event)
return event
def create_cycle_daily_accrual(session: Session, cycle_id: int, accrual_date: date, accrual_amount_cents: int) -> models.CycleDailyAccrual:
cycle = session.get(models.Cycles, cycle_id)
if cycle is None:
raise ValueError("cycle_id does not exist")
existing = session.exec(
select(models.CycleDailyAccrual).where(
models.CycleDailyAccrual.cycle_id == cycle_id,
models.CycleDailyAccrual.accrual_date == accrual_date,
),
).first()
if existing:
return existing
if accrual_amount_cents < 0:
raise ValueError("accrual_amount_cents must be non-negative")
row = models.CycleDailyAccrual(
cycle_id=cycle_id,
accrual_date=accrual_date,
accrual_amount_cents=accrual_amount_cents,
created_at=datetime.now(timezone.utc),
)
session.add(row)
try:
session.flush()
except IntegrityError as e:
session.rollback()
raise ValueError("create_cycle_daily_accrual integrity error") from e
session.refresh(row)
return row
def get_cycle_daily_accruals_by_cycle_id(session: Session, cycle_id: int) -> list[models.CycleDailyAccrual]:
date_col = cast("ColumnElement", models.CycleDailyAccrual.accrual_date)
statement = (
select(models.CycleDailyAccrual)
.where(
models.CycleDailyAccrual.cycle_id == cycle_id,
)
.order_by(date_col.asc())
)
return list(session.exec(statement).all())
def get_cycle_daily_accrual_by_cycle_id_and_date(session: Session, cycle_id: int, accrual_date: date) -> models.CycleDailyAccrual | None:
statement = select(models.CycleDailyAccrual).where(
models.CycleDailyAccrual.cycle_id == cycle_id,
models.CycleDailyAccrual.accrual_date == accrual_date,
)
return session.exec(statement).first()
# Exchanges
IMMUTABLE_EXCHANGE_FIELDS = {"id"}
def create_exchange(session: Session, exchange_data: Mapping[str, Any] | BaseModel) -> models.Exchanges:
data = _data_to_dict(exchange_data)
allowed = _allowed_columns(models.Exchanges)
payload = {k: v for k, v in data.items() if k in allowed}
if "name" not in payload:
raise ValueError("name is required")
e = models.Exchanges(**payload)
session.add(e)
try:
session.flush()
except IntegrityError as e:
session.rollback()
raise ValueError("create_exchange integrity error") from e
session.refresh(e)
return e
def get_exchange_by_id(session: Session, exchange_id: int) -> models.Exchanges | None:
return session.get(models.Exchanges, exchange_id)
def get_exchange_by_name_and_user_id(session: Session, name: str, user_id: int) -> models.Exchanges | None:
statement = select(models.Exchanges).where(
models.Exchanges.name == name,
models.Exchanges.user_id == user_id,
)
return session.exec(statement).first()
def get_all_exchanges(session: Session) -> list[models.Exchanges]:
statement = select(models.Exchanges)
return list(session.exec(statement).all())
def get_all_exchanges_by_user_id(session: Session, user_id: int) -> list[models.Exchanges]:
statement = select(models.Exchanges).where(
models.Exchanges.user_id == user_id,
)
return list(session.exec(statement).all())
def update_exchange(session: Session, exchange_id: int, update_data: Mapping[str, Any] | BaseModel) -> models.Exchanges:
exchange: models.Exchanges | None = session.get(models.Exchanges, exchange_id)
if exchange is None:
raise ValueError("exchange_id does not exist")
data = _data_to_dict(update_data)
allowed = _allowed_columns(models.Exchanges)
for k, v in data.items():
if k in IMMUTABLE_EXCHANGE_FIELDS:
raise ValueError(f"field {k!r} is immutable")
if k in allowed:
setattr(exchange, k, v)
session.add(exchange)
try:
session.flush()
except IntegrityError as e:
session.rollback()
raise ValueError("update_exchange integrity error") from e
session.refresh(exchange)
return exchange
def delete_exchange(session: Session, exchange_id: int) -> None:
exchange: models.Exchanges | None = session.get(models.Exchanges, exchange_id)
if exchange is None:
return
session.delete(exchange)
try:
session.flush()
except IntegrityError as e:
session.rollback()
raise ValueError("delete_exchange integrity error") from e
# Users
IMMUTABLE_USER_FIELDS = {"id", "username", "created_at"}
def create_user(session: Session, user_data: Mapping) -> models.Users:
if hasattr(user_data, "dict"):
data = user_data.dict(exclude_unset=True)
else:
data = dict(user_data)
allowed = {c.name for c in models.Users.__table__.columns}
def create_user(session: Session, user_data: Mapping[str, Any] | BaseModel) -> models.Users:
data = _data_to_dict(user_data)
allowed = _allowed_columns(models.Users)
payload = {k: v for k, v in data.items() if k in allowed}
if "username" not in payload:
raise ValueError("username is required")
@@ -272,15 +524,23 @@ def create_user(session: Session, user_data: Mapping) -> models.Users:
return u
def update_user(session: Session, user_id: int, update_data: Mapping) -> models.Users:
def get_user_by_id(session: Session, user_id: int) -> models.Users | None:
return session.get(models.Users, user_id)
def get_user_by_username(session: Session, username: str) -> models.Users | None:
statement = select(models.Users).where(
models.Users.username == username,
)
return session.exec(statement).first()
def update_user(session: Session, user_id: int, update_data: Mapping[str, Any] | BaseModel) -> models.Users:
user: models.Users | None = session.get(models.Users, user_id)
if user is None:
raise ValueError("user_id does not exist")
if hasattr(update_data, "dict"):
data = update_data.dict(exclude_unset=True)
else:
data = dict(update_data)
allowed = {c.name for c in models.Users.__table__.columns}
data = _data_to_dict(update_data)
allowed = _allowed_columns(models.Users)
for k, v in data.items():
if k in IMMUTABLE_USER_FIELDS:
raise ValueError(f"field {k!r} is immutable")
@@ -309,10 +569,11 @@ def create_login_session(
user: models.Users | None = session.get(models.Users, user_id)
if user is None:
raise ValueError("user_id does not exist")
user_id_val = cast("int", user.id)
now = datetime.now(timezone.utc)
expires_at = now + timedelta(seconds=session_length_seconds)
s = models.Sessions(
user_id=user.id,
user_id=user_id_val,
session_token_hash=session_token_hash,
created_at=now,
expires_at=expires_at,
@@ -341,10 +602,23 @@ def get_login_session_by_token_hash_and_user_id(session: Session, session_token_
return session.exec(statement).first()
def get_login_session_by_token_hash(session: Session, session_token_hash: str) -> models.Sessions | None:
statement = select(models.Sessions).where(
models.Sessions.session_token_hash == session_token_hash,
models.Sessions.expires_at > datetime.now(timezone.utc),
)
return session.exec(statement).first()
IMMUTABLE_SESSION_FIELDS = {"id", "user_id", "session_token_hash", "created_at"}
def update_login_session(session: Session, session_token_hashed: str, update_session: Mapping) -> models.Sessions | None:
def update_login_session(
session: Session,
session_token_hashed: str,
update_session: Mapping[str, Any] | BaseModel,
) -> models.Sessions | None:
login_session: models.Sessions | None = session.exec(
select(models.Sessions).where(
models.Sessions.session_token_hash == session_token_hashed,
@@ -353,11 +627,8 @@ def update_login_session(session: Session, session_token_hashed: str, update_ses
).first()
if login_session is None:
return None
if hasattr(update_session, "dict"):
data = update_session.dict(exclude_unset=True)
else:
data = dict(update_session)
allowed = {c.name for c in models.Sessions.__table__.columns}
data = _data_to_dict(update_session)
allowed = _allowed_columns(models.Sessions)
for k, v in data.items():
if k in allowed and k not in IMMUTABLE_SESSION_FIELDS:
setattr(login_session, k, v)

View File

@@ -1,14 +1,13 @@
from __future__ import annotations
import logging
from contextlib import contextmanager
from typing import TYPE_CHECKING
from sqlalchemy import event
from sqlalchemy.pool import StaticPool
from sqlmodel import Session, create_engine
from trading_journal import db_migration
if TYPE_CHECKING:
from collections.abc import Generator
from sqlite3 import Connection as DBAPIConnection
@@ -58,7 +57,6 @@ class Database:
event.listen(self._engine, "connect", _enable_sqlite_pragmas)
def init_db(self) -> None:
# db_migration.run_migrations(self._engine)
pass
def get_session(self) -> Generator[Session, None, None]:
@@ -72,6 +70,18 @@ class Database:
finally:
session.close()
@contextmanager
def get_session_ctx_manager(self) -> Generator[Session, None, None]:
session = Session(self._engine)
try:
yield session
session.commit()
except Exception:
session.rollback()
raise
finally:
session.close()
def dispose(self) -> None:
self._engine.dispose()

View File

@@ -23,10 +23,13 @@ def _mig_0_1(engine: Engine) -> None:
SQLModel.metadata.create_all(
bind=engine,
tables=[
models_v1.Trades.__table__,
models_v1.Cycles.__table__,
models_v1.Users.__table__,
models_v1.Sessions.__table__,
models_v1.Trades.__table__, # type: ignore[attr-defined]
models_v1.Cycles.__table__, # type: ignore[attr-defined]
models_v1.Users.__table__, # type: ignore[attr-defined]
models_v1.Sessions.__table__, # type: ignore[attr-defined]
models_v1.Exchanges.__table__, # type: ignore[attr-defined]
models_v1.CycleLoanChangeEvents.__table__, # type: ignore[attr-defined]
models_v1.CycleDailyAccrual.__table__, # type: ignore[attr-defined]
],
)

View File

@@ -1,46 +1,15 @@
from __future__ import annotations
from typing import TYPE_CHECKING
from datetime import date, datetime
from pydantic import BaseModel
from sqlmodel import SQLModel
if TYPE_CHECKING:
from datetime import date, datetime
from trading_journal.models import TradeStrategy, TradeType, UnderlyingCurrency
class TradeBase(SQLModel):
user_id: int
friendly_name: str | None
symbol: str
exchange: str
underlying_currency: UnderlyingCurrency
trade_type: TradeType
trade_strategy: TradeStrategy
trade_date: date
trade_time_utc: datetime
quantity: int
price_cents: int
gross_cash_flow_cents: int
commission_cents: int
net_cash_flow_cents: int
notes: str | None
cycle_id: int | None = None
class TradeCreate(TradeBase):
expiry_date: date | None = None
strike_price_cents: int | None = None
is_invalidated: bool = False
invalidated_at: datetime | None = None
replaced_by_trade_id: int | None = None
class TradeRead(TradeBase):
id: int
is_invalidated: bool
invalidated_at: datetime | None
from trading_journal.models import (
TradeStrategy,
TradeType,
UnderlyingCurrency,
)
class UserBase(SQLModel):
@@ -52,5 +21,146 @@ class UserCreate(UserBase):
password: str
class UserLogin(BaseModel):
username: str
password: str
class UserRead(UserBase):
id: int
class SessionsBase(SQLModel):
user_id: int
class SessionRead(SessionsBase):
id: int
expires_at: datetime
last_seen_at: datetime | None
last_used_ip: str | None
user_agent: str | None
class SessionsCreate(SessionsBase):
expires_at: datetime
class SessionsUpdate(SQLModel):
expires_at: datetime | None = None
last_seen_at: datetime | None = None
last_used_ip: str | None = None
user_agent: str | None = None
class ExchangesBase(SQLModel):
name: str
notes: str | None = None
class ExchangesCreate(ExchangesBase):
user_id: int
class ExchangesRead(ExchangesBase):
id: int
class CycleBase(SQLModel):
friendly_name: str | None = None
status: str
end_date: date | None = None
funding_source: str | None = None
capital_exposure_cents: int | None = None
loan_amount_cents: int | None = None
loan_interest_rate_tenth_bps: int | None = None
trades: list[TradeRead] | None = None
exchange: ExchangesRead | None = None
class CycleCreate(CycleBase):
user_id: int
symbol: str
exchange_id: int
underlying_currency: UnderlyingCurrency
start_date: date
class CycleUpdate(CycleBase):
id: int
class CycleRead(CycleCreate):
id: int
class CycleLoanChangeEventBase(SQLModel):
cycle_id: int
effective_date: date
loan_amount_cents: int | None = None
loan_interest_rate_tenth_bps: int | None = None
related_trade_id: int | None = None
notes: str | None = None
created_at: datetime
class CycleLoanChangeEventCreate(CycleLoanChangeEventBase):
created_at: datetime
class CycleLoanChangeEventRead(CycleLoanChangeEventBase):
id: int
created_at: datetime
class CycleInterestAccrualBase(SQLModel):
cycle_id: int
accrual_date: date
accrual_amount_cents: int
created_at: datetime
class TradeBase(SQLModel):
friendly_name: str | None = None
symbol: str
exchange_id: int
underlying_currency: UnderlyingCurrency
trade_type: TradeType
trade_strategy: TradeStrategy
trade_date: date
quantity: int
price_cents: int
commission_cents: int
notes: str | None = None
cycle_id: int | None = None
class TradeCreate(TradeBase):
user_id: int | None = None
trade_time_utc: datetime | None = None
gross_cash_flow_cents: int | None = None
net_cash_flow_cents: int | None = None
quantity_multiplier: int = 1
expiry_date: date | None = None
strike_price_cents: int | None = None
is_invalidated: bool = False
invalidated_at: datetime | None = None
replaced_by_trade_id: int | None = None
class TradeNoteUpdate(BaseModel):
id: int
notes: str | None = None
class TradeFriendlyNameUpdate(BaseModel):
id: int
friendly_name: str
class TradeRead(TradeCreate):
id: int
SessionsCreate.model_rebuild()
CycleBase.model_rebuild()

View File

@@ -1,11 +1,13 @@
from datetime import date, datetime
from enum import Enum
from typing import Optional
from sqlmodel import (
Column,
Date,
DateTime,
Field,
ForeignKey,
Integer,
Relationship,
SQLModel,
@@ -16,8 +18,10 @@ from sqlmodel import (
class TradeType(str, Enum):
SELL_PUT = "SELL_PUT"
CLOSE_SELL_PUT = "CLOSE_SELL_PUT"
ASSIGNMENT = "ASSIGNMENT"
SELL_CALL = "SELL_CALL"
CLOSE_SELL_CALL = "CLOSE_SELL_CALL"
EXERCISE_CALL = "EXERCISE_CALL"
LONG_SPOT = "LONG_SPOT"
CLOSE_LONG_SPOT = "CLOSE_LONG_SPOT"
@@ -64,83 +68,209 @@ class FundingSource(str, Enum):
class Trades(SQLModel, table=True):
__tablename__ = "trades"
__table_args__ = (UniqueConstraint("user_id", "friendly_name", name="uq_trades_user_friendly_name"),)
__tablename__ = "trades" # type: ignore[attr-defined]
__table_args__ = (
UniqueConstraint(
"user_id", "friendly_name", name="uq_trades_user_friendly_name"
),
)
id: int | None = Field(default=None, primary_key=True)
user_id: int = Field(foreign_key="users.id", nullable=False, index=True)
# allow null while user may omit friendly_name; uniqueness enforced per-user by constraint
friendly_name: str | None = Field(default=None, sa_column=Column(Text, nullable=True))
friendly_name: str | None = Field(
default=None, sa_column=Column(Text, nullable=True)
)
symbol: str = Field(sa_column=Column(Text, nullable=False))
exchange_id: int = Field(foreign_key="exchanges.id", nullable=False, index=True)
exchange: "Exchanges" = Relationship(back_populates="trades")
underlying_currency: UnderlyingCurrency = Field(sa_column=Column(Text, nullable=False))
underlying_currency: UnderlyingCurrency = Field(
sa_column=Column(Text, nullable=False)
)
trade_type: TradeType = Field(sa_column=Column(Text, nullable=False))
trade_strategy: TradeStrategy = Field(sa_column=Column(Text, nullable=False))
trade_date: date = Field(sa_column=Column(Date, nullable=False))
trade_time_utc: datetime = Field(sa_column=Column(DateTime(timezone=True), nullable=False))
trade_time_utc: datetime = Field(
sa_column=Column(DateTime(timezone=True), nullable=False)
)
expiry_date: date | None = Field(default=None, nullable=True)
strike_price_cents: int | None = Field(default=None, nullable=True)
quantity: int = Field(sa_column=Column(Integer, nullable=False))
quantity_multiplier: int = Field(
sa_column=Column(Integer, nullable=False), default=1
)
price_cents: int = Field(sa_column=Column(Integer, nullable=False))
gross_cash_flow_cents: int = Field(sa_column=Column(Integer, nullable=False))
commission_cents: int = Field(sa_column=Column(Integer, nullable=False))
net_cash_flow_cents: int = Field(sa_column=Column(Integer, nullable=False))
is_invalidated: bool = Field(default=False, nullable=False)
invalidated_at: datetime | None = Field(default=None, sa_column=Column(DateTime(timezone=True), nullable=True))
replaced_by_trade_id: int | None = Field(default=None, foreign_key="trades.id", nullable=True)
invalidated_at: datetime | None = Field(
default=None, sa_column=Column(DateTime(timezone=True), nullable=True)
)
replaced_by_trade_id: int | None = Field(
default=None, foreign_key="trades.id", nullable=True
)
notes: str | None = Field(default=None, sa_column=Column(Text, nullable=True))
cycle_id: int | None = Field(default=None, foreign_key="cycles.id", nullable=True, index=True)
cycle_id: int | None = Field(
default=None, foreign_key="cycles.id", nullable=True, index=True
)
cycle: "Cycles" = Relationship(back_populates="trades")
related_loan_change_event: Optional["CycleLoanChangeEvents"] = Relationship(
back_populates="trade",
sa_relationship_kwargs={"uselist": False},
)
class Cycles(SQLModel, table=True):
__tablename__ = "cycles"
__table_args__ = (UniqueConstraint("user_id", "friendly_name", name="uq_cycles_user_friendly_name"),)
__tablename__ = "cycles" # type: ignore[attr-defined]
__table_args__ = (
UniqueConstraint(
"user_id", "friendly_name", name="uq_cycles_user_friendly_name"
),
)
id: int | None = Field(default=None, primary_key=True)
user_id: int = Field(foreign_key="users.id", nullable=False, index=True)
friendly_name: str | None = Field(default=None, sa_column=Column(Text, nullable=True))
friendly_name: str | None = Field(
default=None, sa_column=Column(Text, nullable=True)
)
symbol: str = Field(sa_column=Column(Text, nullable=False))
exchange_id: int = Field(foreign_key="exchanges.id", nullable=False, index=True)
exchange: "Exchanges" = Relationship(back_populates="cycles")
underlying_currency: UnderlyingCurrency = Field(sa_column=Column(Text, nullable=False))
underlying_currency: UnderlyingCurrency = Field(
sa_column=Column(Text, nullable=False)
)
status: CycleStatus = Field(sa_column=Column(Text, nullable=False))
funding_source: FundingSource = Field(sa_column=Column(Text, nullable=True))
capital_exposure_cents: int | None = Field(default=None, nullable=True)
loan_amount_cents: int | None = Field(default=None, nullable=True)
loan_interest_rate_bps: int | None = Field(default=None, nullable=True)
start_date: date = Field(sa_column=Column(Date, nullable=False))
end_date: date | None = Field(default=None, sa_column=Column(Date, nullable=True))
trades: list["Trades"] = Relationship(back_populates="cycle")
loan_amount_cents: int | None = Field(default=None, nullable=True)
loan_interest_rate_tenth_bps: int | None = Field(default=None, nullable=True)
latest_interest_accrued_date: date | None = Field(
default=None, sa_column=Column(Date, nullable=True)
)
total_accrued_amount_cents: int = Field(
default=0, sa_column=Column(Integer, nullable=False)
)
loan_change_events: list["CycleLoanChangeEvents"] = Relationship(
back_populates="cycle"
)
daily_accruals: list["CycleDailyAccrual"] = Relationship(back_populates="cycle")
class CycleLoanChangeEvents(SQLModel, table=True):
__tablename__ = "cycle_loan_change_events" # type: ignore[attr-defined]
__table_args__ = (
UniqueConstraint(
"cycle_id", "effective_date", name="uq_cycle_loan_change_cycle_date"
),
)
id: int | None = Field(default=None, primary_key=True)
cycle_id: int = Field(
sa_column=Column(
Integer,
ForeignKey("cycles.id", ondelete="CASCADE"),
nullable=False,
index=True,
)
)
effective_date: date = Field(sa_column=Column(Date, nullable=False))
loan_amount_cents: int | None = Field(
default=None, sa_column=Column(Integer, nullable=True)
)
loan_interest_rate_tenth_bps: int | None = Field(
default=None, sa_column=Column(Integer, nullable=True)
)
related_trade_id: int | None = Field(
default=None,
sa_column=Column(Integer, ForeignKey("trades.id"), nullable=True, unique=True),
) # Not used for now.
notes: str | None = Field(default=None, sa_column=Column(Text, nullable=True))
created_at: datetime = Field(
sa_column=Column(DateTime(timezone=True), nullable=False)
)
cycle: "Cycles" = Relationship(back_populates="loan_change_events")
trade: Optional["Trades"] = Relationship(back_populates="related_loan_change_event")
class CycleDailyAccrual(SQLModel, table=True):
__tablename__ = "cycle_daily_accrual" # type: ignore[attr-defined]
__table_args__ = (
UniqueConstraint(
"cycle_id", "accrual_date", name="uq_cycle_daily_accruals_cycle_date"
),
)
id: int | None = Field(default=None, primary_key=True)
cycle_id: int = Field(
sa_column=Column(
Integer,
ForeignKey("cycles.id", ondelete="CASCADE"),
nullable=False,
index=True,
)
)
accrual_date: date = Field(sa_column=Column(Date, nullable=False))
accrual_amount_cents: int = Field(sa_column=Column(Integer, nullable=False))
created_at: datetime = Field(
sa_column=Column(DateTime(timezone=True), nullable=False)
)
cycle: "Cycles" = Relationship(back_populates="daily_accruals")
class Exchanges(SQLModel, table=True):
__tablename__ = "exchanges"
__tablename__ = "exchanges" # type: ignore[attr-defined]
__table_args__ = (
UniqueConstraint("user_id", "name", name="uq_exchanges_user_name"),
)
id: int | None = Field(default=None, primary_key=True)
name: str = Field(sa_column=Column(Text, nullable=False, unique=True))
user_id: int = Field(foreign_key="users.id", nullable=False, index=True)
name: str = Field(sa_column=Column(Text, nullable=False))
notes: str | None = Field(default=None, sa_column=Column(Text, nullable=True))
trades: list["Trades"] = Relationship(back_populates="exchange")
cycles: list["Cycles"] = Relationship(back_populates="exchange")
user: "Users" = Relationship(back_populates="exchanges")
class Users(SQLModel, table=True):
__tablename__ = "users"
__tablename__ = "users" # type: ignore[attr-defined]
id: int | None = Field(default=None, primary_key=True)
# unique=True already creates an index; no need to also set index=True
username: str = Field(sa_column=Column(Text, nullable=False, unique=True))
password_hash: str = Field(sa_column=Column(Text, nullable=False))
is_active: bool = Field(default=True, nullable=False)
sessions: list["Sessions"] = Relationship(back_populates="user")
exchanges: list["Exchanges"] = Relationship(back_populates="user")
class Sessions(SQLModel, table=True):
__tablename__ = "sessions"
__tablename__ = "sessions" # type: ignore[attr-defined]
id: int | None = Field(default=None, primary_key=True)
user_id: int = Field(foreign_key="users.id", nullable=False, index=True)
session_token_hash: str = Field(sa_column=Column(Text, nullable=False, unique=True))
created_at: datetime = Field(sa_column=Column(DateTime(timezone=True), nullable=False))
expires_at: datetime = Field(sa_column=Column(DateTime(timezone=True), nullable=False, index=True))
last_seen_at: datetime | None = Field(sa_column=Column(DateTime(timezone=True), nullable=True))
last_used_ip: str | None = Field(default=None, sa_column=Column(Text, nullable=True))
created_at: datetime = Field(
sa_column=Column(DateTime(timezone=True), nullable=False)
)
expires_at: datetime = Field(
sa_column=Column(DateTime(timezone=True), nullable=False, index=True)
)
last_seen_at: datetime | None = Field(
sa_column=Column(DateTime(timezone=True), nullable=True)
)
last_used_ip: str | None = Field(
default=None, sa_column=Column(Text, nullable=True)
)
user_agent: str | None = Field(default=None, sa_column=Column(Text, nullable=True))
device_name: str | None = Field(default=None, sa_column=Column(Text, nullable=True))
user: "Users" = Relationship(back_populates="sessions")

View File

@@ -1,11 +1,13 @@
from datetime import date, datetime
from enum import Enum
from typing import Optional
from sqlmodel import (
Column,
Date,
DateTime,
Field,
ForeignKey,
Integer,
Relationship,
SQLModel,
@@ -16,8 +18,10 @@ from sqlmodel import (
class TradeType(str, Enum):
SELL_PUT = "SELL_PUT"
CLOSE_SELL_PUT = "CLOSE_SELL_PUT"
ASSIGNMENT = "ASSIGNMENT"
SELL_CALL = "SELL_CALL"
CLOSE_SELL_CALL = "CLOSE_SELL_CALL"
EXERCISE_CALL = "EXERCISE_CALL"
LONG_SPOT = "LONG_SPOT"
CLOSE_LONG_SPOT = "CLOSE_LONG_SPOT"
@@ -64,83 +68,209 @@ class FundingSource(str, Enum):
class Trades(SQLModel, table=True):
__tablename__ = "trades"
__table_args__ = (UniqueConstraint("user_id", "friendly_name", name="uq_trades_user_friendly_name"),)
__tablename__ = "trades" # type: ignore[attr-defined]
__table_args__ = (
UniqueConstraint(
"user_id", "friendly_name", name="uq_trades_user_friendly_name"
),
)
id: int | None = Field(default=None, primary_key=True)
user_id: int = Field(foreign_key="users.id", nullable=False, index=True)
# allow null while user may omit friendly_name; uniqueness enforced per-user by constraint
friendly_name: str | None = Field(default=None, sa_column=Column(Text, nullable=True))
friendly_name: str | None = Field(
default=None, sa_column=Column(Text, nullable=True)
)
symbol: str = Field(sa_column=Column(Text, nullable=False))
exchange_id: int = Field(foreign_key="exchanges.id", nullable=False, index=True)
exchange: "Exchanges" = Relationship(back_populates="trades")
underlying_currency: UnderlyingCurrency = Field(sa_column=Column(Text, nullable=False))
underlying_currency: UnderlyingCurrency = Field(
sa_column=Column(Text, nullable=False)
)
trade_type: TradeType = Field(sa_column=Column(Text, nullable=False))
trade_strategy: TradeStrategy = Field(sa_column=Column(Text, nullable=False))
trade_date: date = Field(sa_column=Column(Date, nullable=False))
trade_time_utc: datetime = Field(sa_column=Column(DateTime(timezone=True), nullable=False))
trade_time_utc: datetime = Field(
sa_column=Column(DateTime(timezone=True), nullable=False)
)
expiry_date: date | None = Field(default=None, nullable=True)
strike_price_cents: int | None = Field(default=None, nullable=True)
quantity: int = Field(sa_column=Column(Integer, nullable=False))
quantity_multiplier: int = Field(
sa_column=Column(Integer, nullable=False), default=1
)
price_cents: int = Field(sa_column=Column(Integer, nullable=False))
gross_cash_flow_cents: int = Field(sa_column=Column(Integer, nullable=False))
commission_cents: int = Field(sa_column=Column(Integer, nullable=False))
net_cash_flow_cents: int = Field(sa_column=Column(Integer, nullable=False))
is_invalidated: bool = Field(default=False, nullable=False)
invalidated_at: datetime | None = Field(default=None, sa_column=Column(DateTime(timezone=True), nullable=True))
replaced_by_trade_id: int | None = Field(default=None, foreign_key="trades.id", nullable=True)
invalidated_at: datetime | None = Field(
default=None, sa_column=Column(DateTime(timezone=True), nullable=True)
)
replaced_by_trade_id: int | None = Field(
default=None, foreign_key="trades.id", nullable=True
)
notes: str | None = Field(default=None, sa_column=Column(Text, nullable=True))
cycle_id: int | None = Field(default=None, foreign_key="cycles.id", nullable=True, index=True)
cycle_id: int | None = Field(
default=None, foreign_key="cycles.id", nullable=True, index=True
)
cycle: "Cycles" = Relationship(back_populates="trades")
related_loan_change_event: Optional["CycleLoanChangeEvents"] = Relationship(
back_populates="trade",
sa_relationship_kwargs={"uselist": False},
)
class Cycles(SQLModel, table=True):
__tablename__ = "cycles"
__table_args__ = (UniqueConstraint("user_id", "friendly_name", name="uq_cycles_user_friendly_name"),)
__tablename__ = "cycles" # type: ignore[attr-defined]
__table_args__ = (
UniqueConstraint(
"user_id", "friendly_name", name="uq_cycles_user_friendly_name"
),
)
id: int | None = Field(default=None, primary_key=True)
user_id: int = Field(foreign_key="users.id", nullable=False, index=True)
friendly_name: str | None = Field(default=None, sa_column=Column(Text, nullable=True))
friendly_name: str | None = Field(
default=None, sa_column=Column(Text, nullable=True)
)
symbol: str = Field(sa_column=Column(Text, nullable=False))
exchange_id: int = Field(foreign_key="exchanges.id", nullable=False, index=True)
exchange: "Exchanges" = Relationship(back_populates="cycles")
underlying_currency: UnderlyingCurrency = Field(sa_column=Column(Text, nullable=False))
underlying_currency: UnderlyingCurrency = Field(
sa_column=Column(Text, nullable=False)
)
status: CycleStatus = Field(sa_column=Column(Text, nullable=False))
funding_source: FundingSource = Field(sa_column=Column(Text, nullable=True))
capital_exposure_cents: int | None = Field(default=None, nullable=True)
loan_amount_cents: int | None = Field(default=None, nullable=True)
loan_interest_rate_bps: int | None = Field(default=None, nullable=True)
start_date: date = Field(sa_column=Column(Date, nullable=False))
end_date: date | None = Field(default=None, sa_column=Column(Date, nullable=True))
trades: list["Trades"] = Relationship(back_populates="cycle")
loan_amount_cents: int | None = Field(default=None, nullable=True)
loan_interest_rate_tenth_bps: int | None = Field(default=None, nullable=True)
latest_interest_accrued_date: date | None = Field(
default=None, sa_column=Column(Date, nullable=True)
)
total_accrued_amount_cents: int = Field(
default=0, sa_column=Column(Integer, nullable=False)
)
loan_change_events: list["CycleLoanChangeEvents"] = Relationship(
back_populates="cycle"
)
daily_accruals: list["CycleDailyAccrual"] = Relationship(back_populates="cycle")
class CycleLoanChangeEvents(SQLModel, table=True):
__tablename__ = "cycle_loan_change_events" # type: ignore[attr-defined]
__table_args__ = (
UniqueConstraint(
"cycle_id", "effective_date", name="uq_cycle_loan_change_cycle_date"
),
)
id: int | None = Field(default=None, primary_key=True)
cycle_id: int = Field(
sa_column=Column(
Integer,
ForeignKey("cycles.id", ondelete="CASCADE"),
nullable=False,
index=True,
)
)
effective_date: date = Field(sa_column=Column(Date, nullable=False))
loan_amount_cents: int | None = Field(
default=None, sa_column=Column(Integer, nullable=True)
)
loan_interest_rate_tenth_bps: int | None = Field(
default=None, sa_column=Column(Integer, nullable=True)
)
related_trade_id: int | None = Field(
default=None,
sa_column=Column(Integer, ForeignKey("trades.id"), nullable=True, unique=True),
) # Not used for now.
notes: str | None = Field(default=None, sa_column=Column(Text, nullable=True))
created_at: datetime = Field(
sa_column=Column(DateTime(timezone=True), nullable=False)
)
cycle: "Cycles" = Relationship(back_populates="loan_change_events")
trade: Optional["Trades"] = Relationship(back_populates="related_loan_change_event")
class CycleDailyAccrual(SQLModel, table=True):
__tablename__ = "cycle_daily_accrual" # type: ignore[attr-defined]
__table_args__ = (
UniqueConstraint(
"cycle_id", "accrual_date", name="uq_cycle_daily_accruals_cycle_date"
),
)
id: int | None = Field(default=None, primary_key=True)
cycle_id: int = Field(
sa_column=Column(
Integer,
ForeignKey("cycles.id", ondelete="CASCADE"),
nullable=False,
index=True,
)
)
accrual_date: date = Field(sa_column=Column(Date, nullable=False))
accrual_amount_cents: int = Field(sa_column=Column(Integer, nullable=False))
created_at: datetime = Field(
sa_column=Column(DateTime(timezone=True), nullable=False)
)
cycle: "Cycles" = Relationship(back_populates="daily_accruals")
class Exchanges(SQLModel, table=True):
__tablename__ = "exchanges"
__tablename__ = "exchanges" # type: ignore[attr-defined]
__table_args__ = (
UniqueConstraint("user_id", "name", name="uq_exchanges_user_name"),
)
id: int | None = Field(default=None, primary_key=True)
name: str = Field(sa_column=Column(Text, nullable=False, unique=True))
user_id: int = Field(foreign_key="users.id", nullable=False, index=True)
name: str = Field(sa_column=Column(Text, nullable=False))
notes: str | None = Field(default=None, sa_column=Column(Text, nullable=True))
trades: list["Trades"] = Relationship(back_populates="exchange")
cycles: list["Cycles"] = Relationship(back_populates="exchange")
user: "Users" = Relationship(back_populates="exchanges")
class Users(SQLModel, table=True):
__tablename__ = "users"
__tablename__ = "users" # type: ignore[attr-defined]
id: int | None = Field(default=None, primary_key=True)
# unique=True already creates an index; no need to also set index=True
username: str = Field(sa_column=Column(Text, nullable=False, unique=True))
password_hash: str = Field(sa_column=Column(Text, nullable=False))
is_active: bool = Field(default=True, nullable=False)
sessions: list["Sessions"] = Relationship(back_populates="user")
exchanges: list["Exchanges"] = Relationship(back_populates="user")
class Sessions(SQLModel, table=True):
__tablename__ = "sessions"
__tablename__ = "sessions" # type: ignore[attr-defined]
id: int | None = Field(default=None, primary_key=True)
user_id: int = Field(foreign_key="users.id", nullable=False, index=True)
session_token_hash: str = Field(sa_column=Column(Text, nullable=False, unique=True))
created_at: datetime = Field(sa_column=Column(DateTime(timezone=True), nullable=False))
expires_at: datetime = Field(sa_column=Column(DateTime(timezone=True), nullable=False, index=True))
last_seen_at: datetime | None = Field(sa_column=Column(DateTime(timezone=True), nullable=True))
last_used_ip: str | None = Field(default=None, sa_column=Column(Text, nullable=True))
created_at: datetime = Field(
sa_column=Column(DateTime(timezone=True), nullable=False)
)
expires_at: datetime = Field(
sa_column=Column(DateTime(timezone=True), nullable=False, index=True)
)
last_seen_at: datetime | None = Field(
sa_column=Column(DateTime(timezone=True), nullable=True)
)
last_used_ip: str | None = Field(
default=None, sa_column=Column(Text, nullable=True)
)
user_agent: str | None = Field(default=None, sa_column=Column(Text, nullable=True))
device_name: str | None = Field(default=None, sa_column=Column(Text, nullable=True))
user: "Users" = Relationship(back_populates="sessions")

View File

@@ -1,11 +1,51 @@
from passlib.context import CryptContext
import hashlib
import hmac
import secrets
pwd_ctx = CryptContext(schemes=["argon2"], deprecated="auto")
from argon2 import PasswordHasher
from argon2.exceptions import VerifyMismatchError
import settings
ph = PasswordHasher()
# Utility functions for password hashing and verification
def hash_password(plain: str) -> str:
return pwd_ctx.hash(plain)
return ph.hash(plain)
def verify_password(plain: str, hashed: str) -> bool:
return pwd_ctx.verify(plain, hashed)
try:
return ph.verify(hashed, plain)
except VerifyMismatchError:
return False
# Session token hash
def generate_session_token(nbytes: int = 32) -> str:
return secrets.token_urlsafe(nbytes)
def hash_session_token_sha256(token: str) -> str:
return hashlib.sha256(token.encode("utf-8")).hexdigest()
def sign_token_hmac(token: str) -> str:
if not settings.settings.hmac_key:
return token
return hmac.new(settings.settings.hmac_key.encode("utf-8"), token.encode("utf-8"), hashlib.sha256).hexdigest()
def verify_token_sha256(token: str, expected_hash: str) -> bool:
return hmac.compare_digest(hash_session_token_sha256(token), expected_hash)
def verify_token_hmac(token: str, expected_hmac: str) -> bool:
if not settings.settings.hmac_key:
return verify_token_sha256(token, expected_hmac)
sig = hmac.new(settings.settings.hmac_key.encode("utf-8"), token.encode("utf-8"), hashlib.sha256).hexdigest()
return hmac.compare_digest(sig, expected_hmac)

View File

@@ -0,0 +1,432 @@
from __future__ import annotations
import logging
from datetime import datetime, timedelta, timezone
from typing import TYPE_CHECKING, cast
from fastapi import Request, Response, status
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
import settings
from trading_journal import crud, security
from trading_journal.dto import (
CycleBase,
CycleCreate,
CycleLoanChangeEventBase,
CycleRead,
CycleUpdate,
ExchangesBase,
ExchangesCreate,
ExchangesRead,
SessionsCreate,
SessionsUpdate,
TradeCreate,
TradeRead,
UserCreate,
UserLogin,
UserRead,
)
from trading_journal.service_error import (
CycleLoanEventExistsError,
CycleNotFoundError,
ExchangeAlreadyExistsError,
ExchangeNotFoundError,
InvalidCycleDataError,
InvalidTradeDataError,
ServiceError,
TradeNotFoundError,
UserAlreadyExistsError,
)
if TYPE_CHECKING:
from sqlmodel import Session
from trading_journal.db import Database
from trading_journal.models import Sessions
EXCEPT_PATHS = [
f"{settings.settings.api_base}/status",
f"{settings.settings.api_base}/register",
f"{settings.settings.api_base}/login",
]
logger = logging.getLogger(__name__)
class AuthMiddleWare(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response: # noqa: PLR0911
if request.url.path in EXCEPT_PATHS:
return await call_next(request)
token = request.cookies.get("session_token")
if not token:
auth_header = request.headers.get("Authorization")
if auth_header and auth_header.startswith("Bearer "):
token = auth_header[len("Bearer ") :]
if not token:
return JSONResponse(
status_code=status.HTTP_401_UNAUTHORIZED,
content={"detail": "Unauthorized"},
)
db_factory: Database | None = getattr(request.app.state, "db_factory", None)
if db_factory is None:
return JSONResponse(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
content={"detail": "db factory not configured"},
)
try:
with db_factory.get_session_ctx_manager() as request_session:
hashed_token = security.hash_session_token_sha256(token)
request.state.db_session = request_session
login_session: Sessions | None = crud.get_login_session_by_token_hash(request_session, hashed_token)
if not login_session:
return JSONResponse(
status_code=status.HTTP_401_UNAUTHORIZED,
content={"detail": "Unauthorized"},
)
session_expires_utc = login_session.expires_at.replace(tzinfo=timezone.utc)
if session_expires_utc < datetime.now(timezone.utc):
crud.delete_login_session(request_session, login_session.session_token_hash)
return JSONResponse(
status_code=status.HTTP_401_UNAUTHORIZED,
content={"detail": "Unauthorized"},
)
if login_session.user.is_active is False:
return JSONResponse(
status_code=status.HTTP_401_UNAUTHORIZED,
content={"detail": "Unauthorized"},
)
if session_expires_utc - datetime.now(timezone.utc) < timedelta(seconds=3600):
updated_expiry = datetime.now(timezone.utc) + timedelta(seconds=settings.settings.session_expiry_seconds)
else:
updated_expiry = session_expires_utc
updated_session: SessionsUpdate = SessionsUpdate(
last_seen_at=datetime.now(timezone.utc),
last_used_ip=request.client.host if request.client else None,
user_agent=request.headers.get("User-Agent"),
expires_at=updated_expiry,
)
user_id = login_session.user_id
request.state.user_id = user_id
crud.update_login_session(request_session, hashed_token, update_session=updated_session)
except Exception:
logger.exception("Failed to authenticate user: \n")
return JSONResponse(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
content={"detail": "Internal server error"},
)
return await call_next(request)
# User service
def register_user_service(db_session: Session, user_in: UserCreate) -> UserRead:
if crud.get_user_by_username(db_session, user_in.username):
raise UserAlreadyExistsError("username already exists")
hashed = security.hash_password(user_in.password)
user_data: dict = {
"username": user_in.username,
"password_hash": hashed,
}
try:
user = crud.create_user(db_session, user_data=user_data)
try:
# prefer pydantic's from_orm if DTO supports orm_mode
user = UserRead.model_validate(user)
except Exception as e:
logger.exception("Failed to convert user to UserRead: ")
raise ServiceError("Failed to convert user to UserRead") from e
except Exception as e:
logger.exception("Failed to create user:")
raise ServiceError("Failed to create user") from e
return user
def authenticate_user_service(db_session: Session, user_in: UserLogin) -> tuple[SessionsCreate, str] | None:
user = crud.get_user_by_username(db_session, user_in.username)
if not user:
return None
user_id_val = cast("int", user.id)
if not security.verify_password(user_in.password, user.password_hash):
return None
token = security.generate_session_token()
token_hashed = security.hash_session_token_sha256(token)
try:
session = crud.create_login_session(
session=db_session,
user_id=user_id_val,
session_token_hash=token_hashed,
session_length_seconds=settings.settings.session_expiry_seconds,
)
except Exception as e:
logger.exception("Failed to create login session: \n")
raise ServiceError("Failed to create login session") from e
return SessionsCreate.model_validate(session), token
# Exchanges service
def create_exchange_service(db_session: Session, user_id: int, name: str, notes: str | None) -> ExchangesCreate:
existing_exchange = crud.get_exchange_by_name_and_user_id(db_session, name, user_id)
if existing_exchange:
raise ExchangeAlreadyExistsError("Exchange with the same name already exists for this user")
exchange_data = ExchangesCreate(
user_id=user_id,
name=name,
notes=notes,
)
try:
exchange = crud.create_exchange(db_session, exchange_data=exchange_data)
try:
exchange_dto = ExchangesCreate.model_validate(exchange)
except Exception as e:
logger.exception("Failed to convert exchange to ExchangesCreate:")
raise ServiceError("Failed to convert exchange to ExchangesCreate") from e
except Exception as e:
logger.exception("Failed to create exchange:")
raise ServiceError("Failed to create exchange") from e
return exchange_dto
def get_exchanges_by_user_service(db_session: Session, user_id: int) -> list[ExchangesRead]:
exchanges = crud.get_all_exchanges_by_user_id(db_session, user_id)
return [ExchangesRead.model_validate(exchange) for exchange in exchanges]
def update_exchanges_service(
db_session: Session,
user_id: int,
exchange_id: int,
name: str | None,
notes: str | None,
) -> ExchangesBase:
existing_exchange = crud.get_exchange_by_id(db_session, exchange_id)
if not existing_exchange:
raise ExchangeNotFoundError("Exchange not found")
if existing_exchange.user_id != user_id:
raise ExchangeNotFoundError("Exchange not found")
if name:
other_exchange = crud.get_exchange_by_name_and_user_id(db_session, name, user_id)
if other_exchange and other_exchange.id != existing_exchange.id:
raise ExchangeAlreadyExistsError("Another exchange with the same name already exists for this user")
exchange_data = ExchangesBase(
name=name or existing_exchange.name,
notes=notes or existing_exchange.notes,
)
try:
exchange = crud.update_exchange(db_session, cast("int", existing_exchange.id), update_data=exchange_data)
except Exception as e:
logger.exception("Failed to update exchange: \n")
raise ServiceError("Failed to update exchange") from e
return ExchangesBase.model_validate(exchange)
# Cycle Service
def create_cycle_service(db_session: Session, user_id: int, cycle_data: CycleBase) -> CycleRead:
raise NotImplementedError("Cycle creation not implemented")
cycle_data_dict = cycle_data.model_dump()
cycle_data_dict["user_id"] = user_id
cycle_data_with_user_id: CycleCreate = CycleCreate.model_validate(cycle_data_dict)
created_cycle = crud.create_cycle(db_session, cycle_data=cycle_data_with_user_id)
return CycleRead.model_validate(created_cycle)
def get_cycle_by_id_service(db_session: Session, user_id: int, cycle_id: int) -> CycleRead:
cycle = crud.get_cycle_by_id(db_session, cycle_id)
if not cycle:
raise CycleNotFoundError("Cycle not found")
if cycle.user_id != user_id:
raise CycleNotFoundError("Cycle not found")
return CycleRead.model_validate(cycle)
def get_cycles_by_user_service(db_session: Session, user_id: int) -> list[CycleRead]:
cycles = crud.get_cycles_by_user_id(db_session, user_id)
return [CycleRead.model_validate(cycle) for cycle in cycles]
def _validate_cycle_update_data(cycle_data: CycleUpdate) -> tuple[bool, str]: # noqa: PLR0911
if cycle_data.status == "CLOSED" and cycle_data.end_date is None:
return False, "end_date is required when status is CLOSED"
if cycle_data.status == "OPEN" and cycle_data.end_date is not None:
return False, "end_date must be empty when status is OPEN"
if cycle_data.capital_exposure_cents is not None and cycle_data.capital_exposure_cents < 0:
return False, "capital_exposure_cents must be non-negative"
if (
cycle_data.funding_source is not None
and cycle_data.funding_source != "CASH"
and (cycle_data.loan_amount_cents is None or cycle_data.loan_interest_rate_tenth_bps is None)
):
return (
False,
"loan_amount_cents and loan_interest_rate_tenth_bps are required when funding_source is not CASH",
)
if cycle_data.loan_amount_cents is not None and cycle_data.loan_amount_cents < 0:
return False, "loan_amount_cents must be non-negative"
if cycle_data.loan_interest_rate_tenth_bps is not None and cycle_data.loan_interest_rate_tenth_bps < 0:
return False, "loan_interest_rate_tenth_bps must be non-negative"
return True, ""
def _create_cycle_loan_event(
db_session: Session,
cycle_id: int,
loan_amount_cents: int | None,
loan_interest_rate_tenth_bps: int | None,
) -> None:
now = datetime.now(timezone.utc)
today = now.date()
existing_loan_event = crud.get_loan_event_by_cycle_id_and_effective_date(db_session, cycle_id, today)
if existing_loan_event:
raise CycleLoanEventExistsError("A loan event with the same effective_date already exists for this cycle.")
loan_event_data = CycleLoanChangeEventBase(
cycle_id=cycle_id,
effective_date=today,
loan_amount_cents=loan_amount_cents,
loan_interest_rate_tenth_bps=loan_interest_rate_tenth_bps,
created_at=now,
)
try:
crud.create_cycle_loan_event(db_session, loan_event_data)
except Exception as e:
logger.exception("Failed to create cycle loan event: \n")
raise ServiceError("Failed to create cycle loan event") from e
def update_cycle_service(db_session: Session, user_id: int, cycle_data: CycleUpdate) -> CycleRead:
is_valid, err_msg = _validate_cycle_update_data(cycle_data)
if not is_valid:
raise InvalidCycleDataError(err_msg)
cycle_id = cast("int", cycle_data.id)
existing_cycle = crud.get_cycle_by_id(db_session, cycle_id)
if not existing_cycle:
raise CycleNotFoundError("Cycle not found")
if existing_cycle.user_id != user_id:
raise CycleNotFoundError("Cycle not found")
if cycle_data.loan_amount_cents is not None or cycle_data.loan_interest_rate_tenth_bps is not None:
_create_cycle_loan_event(
db_session,
cycle_id,
cycle_data.loan_amount_cents,
cycle_data.loan_interest_rate_tenth_bps,
)
provided_data_dict = cycle_data.model_dump(exclude_unset=True)
cycle_data_with_user_id: CycleBase = CycleBase.model_validate(provided_data_dict)
try:
updated_cycle = crud.update_cycle(db_session, cycle_id, update_data=cycle_data_with_user_id)
except Exception as e:
logger.exception("Failed to update cycle: \n")
raise ServiceError("Failed to update cycle") from e
return CycleRead.model_validate(updated_cycle)
def accrual_interest_service(db_session: Session, cycle_id: int) -> None:
cycle = crud.get_cycle_by_id(db_session, cycle_id)
if not cycle:
logger.exception("Cycle not found for interest accrual")
raise CycleNotFoundError("Cycle not found")
if cycle.loan_amount_cents is None or cycle.loan_interest_rate_tenth_bps is None:
logger.info("Cycle has no loan, skipping interest accrual")
return
today = datetime.now(timezone.utc).date()
amount_cents = round(cycle.loan_amount_cents * cycle.loan_interest_rate_tenth_bps / 100000 / 365)
try:
crud.create_cycle_daily_accrual(
db_session,
cycle_id=cycle_id,
accrual_date=today,
accrual_amount_cents=amount_cents,
)
except Exception as e:
logger.exception("Failed to create cycle interest accrual: \n")
raise ServiceError("Failed to create cycle interest accrual") from e
def flush_interest_accruals_service(db_session: Session) -> None:
pass
# Trades service
def _append_cashflows(trade_data: TradeCreate) -> TradeCreate:
sign_multipler: int
if trade_data.trade_type in (
"SELL_PUT",
"SELL_CALL",
"EXERCISE_CALL",
"CLOSE_LONG_SPOT",
"SHORT_SPOT",
):
sign_multipler = 1
else:
sign_multipler = -1
quantity = trade_data.quantity * trade_data.quantity_multiplier
gross_cash_flow_cents = quantity * trade_data.price_cents * sign_multipler
net_cash_flow_cents = gross_cash_flow_cents - trade_data.commission_cents
trade_data.gross_cash_flow_cents = gross_cash_flow_cents
trade_data.net_cash_flow_cents = net_cash_flow_cents
return trade_data
def _validate_trade_data(trade_data: TradeCreate) -> bool:
return not (
trade_data.trade_type in ("SELL_PUT", "SELL_CALL") and (trade_data.expiry_date is None or trade_data.strike_price_cents is None)
)
def create_trade_service(db_session: Session, user_id: int, trade_data: TradeCreate) -> TradeRead:
if not _validate_trade_data(trade_data):
raise InvalidTradeDataError("Invalid trade data: expiry_date and strike_price_cents are required for SELL_PUT and SELL_CALL trades")
trade_data_dict = trade_data.model_dump()
trade_data_dict["user_id"] = user_id
trade_data_with_user_id: TradeCreate = TradeCreate.model_validate(trade_data_dict)
trade_data_with_user_id = _append_cashflows(trade_data_with_user_id)
created_trade = crud.create_trade(db_session, trade_data=trade_data_with_user_id)
return TradeRead.model_validate(created_trade)
def get_trade_by_id_service(db_session: Session, user_id: int, trade_id: int) -> TradeRead:
trade = crud.get_trade_by_id(db_session, trade_id)
if not trade:
raise TradeNotFoundError("Trade not found")
if trade.user_id != user_id:
raise TradeNotFoundError("Trade not found")
return TradeRead.model_validate(trade)
def update_trade_friendly_name_service(db_session: Session, user_id: int, trade_id: int, friendly_name: str) -> TradeRead:
existing_trade = crud.get_trade_by_id(db_session, trade_id)
if not existing_trade:
raise TradeNotFoundError("Trade not found")
if existing_trade.user_id != user_id:
raise TradeNotFoundError("Trade not found")
try:
updated_trade = crud.update_trade_friendly_name(db_session, trade_id, friendly_name)
except Exception as e:
logger.exception("Failed to update trade friendly name: \n")
raise ServiceError("Failed to update trade friendly name") from e
return TradeRead.model_validate(updated_trade)
def update_trade_note_service(db_session: Session, user_id: int, trade_id: int, note: str | None) -> TradeRead:
existing_trade = crud.get_trade_by_id(db_session, trade_id)
if not existing_trade:
raise TradeNotFoundError("Trade not found")
if existing_trade.user_id != user_id:
raise TradeNotFoundError("Trade not found")
if note is None:
note = ""
try:
updated_trade = crud.update_trade_note(db_session, trade_id, note)
except Exception as e:
logger.exception("Failed to update trade notes: \n")
raise ServiceError("Failed to update trade notes") from e
return TradeRead.model_validate(updated_trade)

View File

@@ -0,0 +1,34 @@
class ServiceError(Exception):
pass
class UserAlreadyExistsError(ServiceError):
pass
class ExchangeAlreadyExistsError(ServiceError):
pass
class ExchangeNotFoundError(ServiceError):
pass
class CycleNotFoundError(ServiceError):
pass
class TradeNotFoundError(ServiceError):
pass
class InvalidTradeDataError(ServiceError):
pass
class InvalidCycleDataError(ServiceError):
pass
class CycleLoanEventExistsError(ServiceError):
pass

View File

View File

@@ -0,0 +1,7 @@
from sqlmodel import create_engine
import settings
from trading_journal import db_migration
db_engine = create_engine(settings.settings.database_url, echo=True)
db_migration.run_migrations(db_engine)