Compare commits

...

9 Commits

Author SHA1 Message Date
5e7d801075 Merge pull request 'feature/api_endpoint' (#5) from feature/api_endpoint into main
All checks were successful
Backend CI / unit-test (push) Successful in 43s
Reviewed-on: #5
2025-10-01 15:55:47 +02:00
94fb4705ff add tests for router and openapi, still need to add routes for update interest
All checks were successful
Backend CI / unit-test (push) Successful in 1m10s
Backend CI / unit-test (pull_request) Successful in 44s
2025-10-01 15:53:48 +02:00
bb87b90285 service layer add all tests for existing code
All checks were successful
Backend CI / unit-test (push) Successful in 40s
2025-09-29 16:48:28 +02:00
5eae75b23e wip service test
Some checks failed
Backend CI / unit-test (push) Failing after 37s
2025-09-26 22:37:26 +02:00
6a5f160d83 add interest accural test, improve migration tests
All checks were successful
Backend CI / unit-test (push) Successful in 37s
2025-09-25 22:16:24 +02:00
27b4adaca4 add interest change tables 2025-09-25 12:08:07 +02:00
e66aab99ea basic api is there
All checks were successful
Backend CI / unit-test (push) Successful in 35s
2025-09-24 21:02:21 +02:00
544f5e8c92 Merge pull request 'add readme' (#4) from feature/readme into main
All checks were successful
Backend CI / unit-test (push) Successful in 34s
Reviewed-on: #4
2025-09-23 10:44:11 +02:00
b6ba108156 add readme
All checks were successful
Backend CI / unit-test (push) Successful in 33s
Backend CI / unit-test (pull_request) Successful in 33s
2025-09-23 10:43:32 +02:00
15 changed files with 2645 additions and 19 deletions

19
LICENSE Normal file
View File

@@ -0,0 +1,19 @@
Copyright (c) 2025 Tianyu Liu, Studio TJ
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE
OR OTHER DEALINGS IN THE SOFTWARE.

28
README.md Normal file
View File

@@ -0,0 +1,28 @@
# Trading Journal (Work In Progress)
A simple trading journal application (work in progress).
This repository contains the backend of a trading journal designed to help you record and analyse trades. The system is specially designed to support journaling trades for the "wheel" options strategy, but it also supports other trade types such as long/short spot positions, forex, and more.
Important: the project is still under active development. There is a backend in this repo, but the frontend UI has not been implemented yet.
## Key features
- Journal trades with rich metadata (strategy, entry/exit, P/L, notes).
- Built-in support and data model conveniences for the Wheel strategy (puts/calls lifecycle tracking).
- Flexible support for other trade types: long/short spots, forex, futures, etc.
- Backend-first design with tests and migration helpers.
## Repository layout
- `backend/` — Python backend code (API, models, services, migrations, tests).
- `backend/trading_journal/` — core application modules: CRUD, models, DTOs, services, and security.
- `backend/tests/` — unit tests targeting the backend logic and DB layer.
## License
See the `LICENSE` file in the project root for license details.

View File

@@ -52,8 +52,14 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: # noqa: ARG001
await asyncio.to_thread(_db.dispose) await asyncio.to_thread(_db.dispose)
origins = [
"http://127.0.0.1:18881",
]
app = FastAPI(lifespan=lifespan) app = FastAPI(lifespan=lifespan)
app.add_middleware(service.AuthMiddleWare) app.add_middleware(
service.AuthMiddleWare,
)
app.state.db_factory = _db app.state.db_factory = _db
@@ -77,7 +83,7 @@ async def register_user(request: Request, user_in: UserCreate) -> Response:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) from e raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) from e
except Exception as e: except Exception as e:
logger.exception("Failed to register user: \n") logger.exception("Failed to register user: \n")
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Internal server error") from e raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Internal Server Error") from e
@app.post(f"{settings.settings.api_base}/login") @app.post(f"{settings.settings.api_base}/login")
@@ -110,7 +116,7 @@ async def login(request: Request, user_in: UserLogin) -> Response:
) )
except Exception as e: except Exception as e:
logger.exception("Failed to login user: \n") logger.exception("Failed to login user: \n")
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Internal server error") from e raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Internal Server Error") from e
else: else:
return response return response

554
backend/openapi.yaml Normal file
View File

@@ -0,0 +1,554 @@
openapi: "3.0.3"
info:
title: Trading Journal API
version: "1.0.0"
description: OpenAPI description generated from [`app.py`](app.py) and DTOs in [`trading_journal/dto.py`](trading_journal/dto.py).
servers:
- url: "http://127.0.0.1:18881{basePath}"
variables:
basePath:
default: "/api/v1"
description: "API base path (matches settings.settings.api_base)"
components:
securitySchemes:
session_cookie:
type: apiKey
in: cookie
name: session_token
schemas:
UserCreate:
$ref: "#/components/schemas/UserCreate_impl"
UserCreate_impl:
type: object
required:
- username
- password
properties:
username:
type: string
is_active:
type: boolean
default: true
password:
type: string
UserLogin:
type: object
required:
- username
- password
properties:
username:
type: string
password:
type: string
UserRead:
type: object
required:
- id
- username
properties:
id:
type: integer
username:
type: string
is_active:
type: boolean
SessionsBase:
type: object
required:
- user_id
properties:
user_id:
type: integer
SessionsCreate:
allOf:
- $ref: "#/components/schemas/SessionsBase"
- type: object
required:
- expires_at
properties:
expires_at:
type: string
format: date-time
ExchangesBase:
type: object
required:
- name
properties:
name:
type: string
notes:
type: string
nullable: true
ExchangesRead:
allOf:
- $ref: "#/components/schemas/ExchangesBase"
- type: object
required:
- id
properties:
id:
type: integer
CycleBase:
type: object
properties:
friendly_name:
type: string
nullable: true
status:
type: string
end_date:
type: string
format: date
nullable: true
funding_source:
type: string
nullable: true
capital_exposure_cents:
type: integer
nullable: true
loan_amount_cents:
type: integer
nullable: true
loan_interest_rate_tenth_bps:
type: integer
nullable: true
trades:
type: array
items:
$ref: "#/components/schemas/TradeRead"
nullable: true
exchange:
$ref: "#/components/schemas/ExchangesRead"
nullable: true
CycleCreate:
allOf:
- $ref: "#/components/schemas/CycleBase"
- type: object
required:
- user_id
- symbol
- exchange_id
- underlying_currency
- start_date
properties:
user_id:
type: integer
symbol:
type: string
exchange_id:
type: integer
underlying_currency:
type: string
start_date:
type: string
format: date
CycleUpdate:
allOf:
- $ref: "#/components/schemas/CycleBase"
- type: object
required:
- id
properties:
id:
type: integer
CycleRead:
allOf:
- $ref: "#/components/schemas/CycleCreate"
- type: object
required:
- id
properties:
id:
type: integer
TradeBase:
type: object
required:
- symbol
- underlying_currency
- trade_type
- trade_strategy
- trade_date
- quantity
- price_cents
- commission_cents
properties:
friendly_name:
type: string
nullable: true
symbol:
type: string
exchange_id:
type: integer
underlying_currency:
type: string
trade_type:
type: string
trade_strategy:
type: string
trade_date:
type: string
format: date
quantity:
type: integer
price_cents:
type: integer
commission_cents:
type: integer
notes:
type: string
nullable: true
cycle_id:
type: integer
nullable: true
TradeCreate:
allOf:
- $ref: "#/components/schemas/TradeBase"
- type: object
properties:
user_id:
type: integer
nullable: true
trade_time_utc:
type: string
format: date-time
nullable: true
gross_cash_flow_cents:
type: integer
nullable: true
net_cash_flow_cents:
type: integer
nullable: true
quantity_multiplier:
type: integer
default: 1
expiry_date:
type: string
format: date
nullable: true
strike_price_cents:
type: integer
nullable: true
is_invalidated:
type: boolean
default: false
invalidated_at:
type: string
format: date-time
nullable: true
replaced_by_trade_id:
type: integer
nullable: true
TradeNoteUpdate:
type: object
required:
- id
properties:
id:
type: integer
notes:
type: string
nullable: true
TradeFriendlyNameUpdate:
type: object
required:
- id
- friendly_name
properties:
id:
type: integer
friendly_name:
type: string
TradeRead:
allOf:
- $ref: "#/components/schemas/TradeCreate"
- type: object
required:
- id
properties:
id:
type: integer
paths:
/status:
get:
summary: "Get API status"
security: [] # no auth required
responses:
"200":
description: OK
content:
application/json:
schema:
type: object
properties:
status:
type: string
/register:
post:
summary: "Register user"
security: [] # no auth required
requestBody:
required: true
content:
application/json:
schema:
$ref: "#/components/schemas/UserCreate"
responses:
"201":
description: Created
content:
application/json:
schema:
$ref: "#/components/schemas/UserRead"
"400":
description: Bad Request (user exists)
"500":
description: Internal Server Error
/login:
post:
summary: "Login"
security: [] # no auth required
requestBody:
required: true
content:
application/json:
schema:
$ref: "#/components/schemas/UserLogin"
responses:
"200":
description: OK (sets session cookie)
content:
application/json:
schema:
$ref: "#/components/schemas/SessionsBase"
headers:
Set-Cookie:
description: session cookie
schema:
type: string
"401":
description: Unauthorized
"500":
description: Internal Server Error
/exchanges:
post:
summary: "Create exchange"
security:
- session_cookie: []
requestBody:
required: true
content:
application/json:
schema:
$ref: "#/components/schemas/ExchangesBase"
responses:
"201":
description: Created
content:
application/json:
schema:
$ref: "#/components/schemas/ExchangesRead"
"400":
description: Bad Request
"401":
description: Unauthorized
get:
summary: "List user exchanges"
security:
- session_cookie: []
responses:
"200":
description: OK
content:
application/json:
schema:
type: array
items:
$ref: "#/components/schemas/ExchangesRead"
"401":
description: Unauthorized
/exchanges/{exchange_id}:
patch:
summary: "Update exchange"
security:
- session_cookie: []
parameters:
- name: exchange_id
in: path
required: true
schema:
type: integer
requestBody:
required: true
content:
application/json:
schema:
$ref: "#/components/schemas/ExchangesBase"
responses:
"200":
description: Updated
content:
application/json:
schema:
$ref: "#/components/schemas/ExchangesRead"
"404":
description: Not found
"400":
description: Bad request
/cycles:
post:
summary: "Create cycle (currently returns 405 in code)"
security:
- session_cookie: []
requestBody:
required: true
content:
application/json:
schema:
$ref: "#/components/schemas/CycleBase"
responses:
"405":
description: Method not allowed (app currently returns 405)
patch:
summary: "Update cycle"
security:
- session_cookie: []
requestBody:
required: true
content:
application/json:
schema:
$ref: "#/components/schemas/CycleUpdate"
responses:
"200":
description: Updated
content:
application/json:
schema:
$ref: "#/components/schemas/CycleRead"
"400":
description: Invalid data
"404":
description: Not found
/cycles/{cycle_id}:
get:
summary: "Get cycle by id"
security:
- session_cookie: []
parameters:
- name: cycle_id
in: path
required: true
schema:
type: integer
responses:
"200":
description: OK
content:
application/json:
schema:
$ref: "#/components/schemas/CycleRead"
"404":
description: Not found
/cycles/user/{user_id}:
get:
summary: "Get cycles by user id"
security:
- session_cookie: []
parameters:
- name: user_id
in: path
required: true
schema:
type: integer
responses:
"200":
description: OK
content:
application/json:
schema:
type: array
items:
$ref: "#/components/schemas/CycleRead"
/trades:
post:
summary: "Create trade"
security:
- session_cookie: []
requestBody:
required: true
content:
application/json:
schema:
$ref: "#/components/schemas/TradeCreate"
responses:
"201":
description: Created
content:
application/json:
schema:
$ref: "#/components/schemas/TradeRead"
"400":
description: Invalid trade data
"500":
description: Internal Server Error
/trades/{trade_id}:
get:
summary: "Get trade by id"
security:
- session_cookie: []
parameters:
- name: trade_id
in: path
required: true
schema:
type: integer
responses:
"200":
description: OK
content:
application/json:
schema:
$ref: "#/components/schemas/TradeRead"
"404":
description: Not found
/trades/friendlyname:
patch:
summary: "Update trade friendly name"
security:
- session_cookie: []
requestBody:
required: true
content:
application/json:
schema:
$ref: "#/components/schemas/TradeFriendlyNameUpdate"
responses:
"200":
description: Updated
content:
application/json:
schema:
$ref: "#/components/schemas/TradeRead"
"404":
description: Not found
/trades/notes:
patch:
summary: "Update trade notes"
security:
- session_cookie: []
requestBody:
required: true
content:
application/json:
schema:
$ref: "#/components/schemas/TradeNoteUpdate"
responses:
"200":
description: Updated
content:
application/json:
schema:
$ref: "#/components/schemas/TradeRead"
"404":
description: Not found

View File

@@ -0,0 +1,56 @@
curl --location '127.0.0.1:18881/api/v1/trades' \
--header 'Content-Type: application/json' \
--header 'Cookie: session_token=uYsEZZdH9ecQ432HQUdfab292I14suk4GuI12-cAyuw' \
--data '{
"friendly_name": "20250908-CA-PUT",
"symbol": "CA",
"exchange_id": 1,
"underlying_currency": "EUR",
"trade_type": "SELL_PUT",
"trade_strategy": "WHEEL",
"trade_date": "2025-09-08",
"quantity": 1,
"quantity_multiplier": 100,
"price_cents": 17,
"expiry_date": "2025-09-09",
"strike_price_cents": 1220,
"commission_cents": 114
}'
curl --location '127.0.0.1:18881/api/v1/trades' \
--header 'Content-Type: application/json' \
--header 'Cookie: session_token=uYsEZZdH9ecQ432HQUdfab292I14suk4GuI12-cAyuw' \
--data '{
"friendly_name": "20250920-CA-ASSIGN",
"symbol": "CA",
"exchange_id": 1,
"cycle_id": 1,
"underlying_currency": "EUR",
"trade_type": "ASSIGNMENT",
"trade_strategy": "WHEEL",
"trade_date": "2025-09-20",
"quantity": 100,
"quantity_multiplier": 1,
"price_cents": 1220,
"commission_cents": 0
}'
curl --location '127.0.0.1:18881/api/v1/trades' \
--header 'Content-Type: application/json' \
--header 'Cookie: session_token=uYsEZZdH9ecQ432HQUdfab292I14suk4GuI12-cAyuw' \
--data '{
"friendly_name": "20250923-CA-CALL",
"symbol": "CA",
"exchange_id": 1,
"cycle_id": 1,
"underlying_currency": "EUR",
"trade_type": "SELL_CALL",
"trade_strategy": "WHEEL",
"trade_date": "2025-09-23",
"quantity": 1,
"quantity_multiplier": 100,
"price_cents": 31,
"expiry_date": "2025-10-10",
"strike_price_cents": 1200,
"commission_cents": 114
}'

View File

@@ -1,19 +1,405 @@
from collections.abc import Generator from collections.abc import Callable
from datetime import datetime, timedelta, timezone
from types import SimpleNamespace
from unittest.mock import MagicMock
import pytest import pytest
from fastapi import FastAPI, status
from fastapi.responses import JSONResponse
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
import settings import settings
from app import app import trading_journal.service as svc
@pytest.fixture @pytest.fixture
def client() -> Generator[TestClient, None, None]: def client_factory(monkeypatch: pytest.MonkeyPatch) -> Callable[..., TestClient]:
with TestClient(app) as client: class NoAuth:
yield client def __init__(self, app: FastAPI, **opts) -> None: # noqa: ANN003, ARG002
self.app = app
async def __call__(self, scope, receive, send) -> None: # noqa: ANN001
state = scope.get("state")
if state is None:
scope["state"] = SimpleNamespace()
scope["state"]["user_id"] = 1
await self.app(scope, receive, send)
class DeclineAuth:
def __init__(self, app: FastAPI, **opts) -> None: # noqa: ANN003, ARG002
self.app = app
async def __call__(self, scope, receive, send) -> None: # noqa: ANN001
if scope.get("type") != "http":
await self.app(scope, receive, send)
return
path = scope.get("path", "")
# allow public/exempt paths through
if getattr(svc, "EXCEPT_PATHS", []) and path in svc.EXCEPT_PATHS:
await self.app(scope, receive, send)
return
# immediately respond 401 for protected paths
resp = JSONResponse({"detail": "Unauthorized"}, status_code=status.HTTP_401_UNAUTHORIZED)
await resp(scope, receive, send)
def _factory(*, decline_auth: bool = False, **mocks: dict) -> TestClient:
defaults = {
"register_user_service": MagicMock(return_value=SimpleNamespace(model_dump=lambda: {"id": 1, "username": "mock"})),
"authenticate_user_service": MagicMock(
return_value=(SimpleNamespace(user_id=1, expires_at=(datetime.now(timezone.utc) + timedelta(hours=1))), "token"),
),
"create_exchange_service": MagicMock(
return_value=SimpleNamespace(model_dump=lambda: {"name": "Binance", "notes": "some note", "user_id": 1}),
),
"get_exchanges_by_user_service": MagicMock(return_value=[]),
}
if decline_auth:
monkeypatch.setattr(svc, "AuthMiddleWare", DeclineAuth)
else:
monkeypatch.setattr(svc, "AuthMiddleWare", NoAuth)
merged = {**defaults, **mocks}
for name, mock in merged.items():
monkeypatch.setattr(svc, name, mock)
import sys
if "app" in sys.modules:
del sys.modules["app"]
from importlib import import_module
app = import_module("app").app # re-import app module
return TestClient(app)
return _factory
def test_get_status(client: TestClient) -> None: def test_get_status(client_factory: Callable[..., TestClient]) -> None:
response = client.get(f"{settings.settings.api_base}/status") client = client_factory()
with client as c:
response = c.get(f"{settings.settings.api_base}/status")
assert response.status_code == 200 assert response.status_code == 200
assert response.json() == {"status": "ok"} assert response.json() == {"status": "ok"}
def test_register_success(client_factory: Callable[..., TestClient]) -> None:
client = client_factory() # use defaults
with client as c:
r = c.post(f"{settings.settings.api_base}/register", json={"username": "a", "password": "b"})
assert r.status_code == 201
def test_register_user_already_exists(client_factory: Callable[..., TestClient]) -> None:
client = client_factory(register_user_service=MagicMock(side_effect=svc.UserAlreadyExistsError("username already exists")))
with client as c:
r = c.post(f"{settings.settings.api_base}/register", json={"username": "a", "password": "b"})
assert r.status_code == status.HTTP_400_BAD_REQUEST
assert r.json() == {"detail": "username already exists"}
def test_register_user_internal_server_error(client_factory: Callable[..., TestClient]) -> None:
client = client_factory(register_user_service=MagicMock(side_effect=Exception("db is down")))
with client as c:
r = c.post(f"{settings.settings.api_base}/register", json={"username": "a", "password": "b"})
assert r.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
assert r.json() == {"detail": "Internal Server Error"}
def test_login_success(client_factory: Callable[..., TestClient]) -> None:
client = client_factory() # use defaults
with client as c:
r = c.post(f"{settings.settings.api_base}/login", json={"username": "a", "password": "b"})
assert r.status_code == 200
assert r.json() == {"user_id": 1}
assert r.cookies.get("session_token") == "token"
def test_login_failed_auth(client_factory: Callable[..., TestClient]) -> None:
client = client_factory(authenticate_user_service=MagicMock(return_value=None))
with client as c:
r = c.post(f"{settings.settings.api_base}/login", json={"username": "a", "password": "b"})
assert r.status_code == status.HTTP_401_UNAUTHORIZED
assert r.json() == {"detail": "Invalid username or password, or user doesn't exist"}
def test_login_internal_server_error(client_factory: Callable[..., TestClient]) -> None:
client = client_factory(authenticate_user_service=MagicMock(side_effect=Exception("db is down")))
with client as c:
r = c.post(f"{settings.settings.api_base}/login", json={"username": "a", "password": "b"})
assert r.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
assert r.json() == {"detail": "Internal Server Error"}
def test_create_exchange_success(client_factory: Callable[..., TestClient]) -> None:
client = client_factory()
with client as c:
r = c.post(f"{settings.settings.api_base}/exchanges", json={"name": "Binance"})
assert r.status_code == 201
assert r.json() == {"user_id": 1, "name": "Binance", "notes": "some note"}
def test_create_exchange_already_exists(client_factory: Callable[..., TestClient]) -> None:
client = client_factory(create_exchange_service=MagicMock(side_effect=svc.ExchangeAlreadyExistsError("exchange already exists")))
with client as c:
r = c.post(f"{settings.settings.api_base}/exchanges", json={"name": "Binance"})
assert r.status_code == status.HTTP_400_BAD_REQUEST
assert r.json() == {"detail": "exchange already exists"}
def test_get_exchanges_unauthenticated(client_factory: Callable[..., TestClient]) -> None:
client = client_factory(decline_auth=True)
with client as c:
r = c.get(f"{settings.settings.api_base}/exchanges")
assert r.status_code == status.HTTP_401_UNAUTHORIZED
assert r.json() == {"detail": "Unauthorized"}
def test_get_exchanges_success(client_factory: Callable[..., TestClient]) -> None:
client = client_factory()
with client as c:
r = c.get(f"{settings.settings.api_base}/exchanges")
assert r.status_code == 200
assert r.json() == []
def test_update_exchanges_success(client_factory: Callable[..., TestClient]) -> None:
client = client_factory(
update_exchanges_service=MagicMock(
return_value=SimpleNamespace(model_dump=lambda: {"user_id": 1, "name": "BinanceUS", "notes": "updated note"}),
),
)
with client as c:
r = c.patch(f"{settings.settings.api_base}/exchanges/1", json={"name": "BinanceUS", "notes": "updated note"})
assert r.status_code == 200
assert r.json() == {"user_id": 1, "name": "BinanceUS", "notes": "updated note"}
def test_update_exchanges_not_found(client_factory: Callable[..., TestClient]) -> None:
client = client_factory(update_exchanges_service=MagicMock(side_effect=svc.ExchangeNotFoundError("exchange not found")))
with client as c:
r = c.patch(f"{settings.settings.api_base}/exchanges/999", json={"name": "NonExistent", "notes": "no note"})
assert r.status_code == status.HTTP_404_NOT_FOUND
assert r.json() == {"detail": "exchange not found"}
def test_get_cycles_by_id_success(client_factory: Callable[..., TestClient]) -> None:
client = client_factory(
get_cycle_by_id_service=MagicMock(
return_value=SimpleNamespace(
friendly_name="Cycle 1",
status="active",
id=1,
),
),
)
with client as c:
r = c.get(f"{settings.settings.api_base}/cycles/1")
assert r.status_code == 200
assert r.json() == {"id": 1, "friendly_name": "Cycle 1", "status": "active"}
def test_get_cycles_by_id_not_found(client_factory: Callable[..., TestClient]) -> None:
client = client_factory(get_cycle_by_id_service=MagicMock(side_effect=svc.CycleNotFoundError("cycle not found")))
with client as c:
r = c.get(f"{settings.settings.api_base}/cycles/999")
assert r.status_code == status.HTTP_404_NOT_FOUND
assert r.json() == {"detail": "cycle not found"}
def test_get_cycles_by_user_success(client_factory: Callable[..., TestClient]) -> None:
client = client_factory(
get_cycles_by_user_service=MagicMock(
return_value=[
SimpleNamespace(
friendly_name="Cycle 1",
status="active",
id=1,
),
SimpleNamespace(
friendly_name="Cycle 2",
status="completed",
id=2,
),
],
),
)
with client as c:
r = c.get(f"{settings.settings.api_base}/cycles/user/1")
assert r.status_code == 200
assert r.json() == [
{"id": 1, "friendly_name": "Cycle 1", "status": "active"},
{"id": 2, "friendly_name": "Cycle 2", "status": "completed"},
]
def test_update_cycles_success(client_factory: Callable[..., TestClient]) -> None:
client = client_factory(
update_cycle_service=MagicMock(
return_value=SimpleNamespace(
friendly_name="Updated Cycle",
status="completed",
id=1,
),
),
)
with client as c:
r = c.patch(f"{settings.settings.api_base}/cycles", json={"friendly_name": "Updated Cycle", "status": "completed", "id": 1})
assert r.status_code == 200
assert r.json() == {"id": 1, "friendly_name": "Updated Cycle", "status": "completed"}
def test_update_cycles_invalid_cycle_data(client_factory: Callable[..., TestClient]) -> None:
client = client_factory(
update_cycle_service=MagicMock(side_effect=svc.InvalidCycleDataError("invalid cycle data")),
)
with client as c:
r = c.patch(f"{settings.settings.api_base}/cycles", json={"friendly_name": "", "status": "unknown", "id": 1})
assert r.status_code == status.HTTP_400_BAD_REQUEST
assert r.json() == {"detail": "invalid cycle data"}
def test_update_cycles_not_found(client_factory: Callable[..., TestClient]) -> None:
client = client_factory(update_cycle_service=MagicMock(side_effect=svc.CycleNotFoundError("cycle not found")))
with client as c:
r = c.patch(f"{settings.settings.api_base}/cycles", json={"friendly_name": "NonExistent", "status": "active", "id": 999})
assert r.status_code == status.HTTP_404_NOT_FOUND
assert r.json() == {"detail": "cycle not found"}
def test_create_trade_success(client_factory: Callable[..., TestClient]) -> None:
client = client_factory(
create_trade_service=MagicMock(
return_value=SimpleNamespace(),
),
)
with client as c:
r = c.post(
f"{settings.settings.api_base}/trades",
json={
"cycle_id": 1,
"exchange_id": 1,
"symbol": "BTCUSD",
"underlying_currency": "USD",
"trade_type": "LONG_SPOT",
"trade_strategy": "FX",
"quantity": 1,
"price_cents": 15,
"commission_cents": 100,
"trade_date": "2025-10-01",
},
)
assert r.status_code == 201
def test_create_trade_invalid_trade_data(client_factory: Callable[..., TestClient]) -> None:
client = client_factory(
create_trade_service=MagicMock(side_effect=svc.InvalidTradeDataError("invalid trade data")),
)
with client as c:
r = c.post(
f"{settings.settings.api_base}/trades",
json={
"cycle_id": 1,
"exchange_id": 1,
"symbol": "BTCUSD",
"underlying_currency": "USD",
"trade_type": "LONG_SPOT",
"trade_strategy": "FX",
"quantity": 1,
"price_cents": 15,
"commission_cents": 100,
"trade_date": "2025-10-01",
},
)
assert r.status_code == status.HTTP_400_BAD_REQUEST
assert r.json() == {"detail": "invalid trade data"}
def test_get_trade_by_id_success(client_factory: Callable[..., TestClient]) -> None:
client = client_factory(
get_trade_by_id_service=MagicMock(
return_value=SimpleNamespace(
id=1,
cycle_id=1,
exchange_id=1,
symbol="BTCUSD",
underlying_currency="USD",
trade_type="LONG_SPOT",
trade_strategy="FX",
quantity=1,
price_cents=1500,
commission_cents=100,
trade_date=datetime(2025, 10, 1, tzinfo=timezone.utc),
),
),
)
with client as c:
r = c.get(f"{settings.settings.api_base}/trades/1")
assert r.status_code == 200
assert r.json() == {
"id": 1,
"cycle_id": 1,
"exchange_id": 1,
"symbol": "BTCUSD",
"underlying_currency": "USD",
"trade_type": "LONG_SPOT",
"trade_strategy": "FX",
"quantity": 1,
"price_cents": 1500,
"commission_cents": 100,
"trade_date": "2025-10-01T00:00:00+00:00",
}
def test_get_trade_by_id_not_found(client_factory: Callable[..., TestClient]) -> None:
client = client_factory(get_trade_by_id_service=MagicMock(side_effect=svc.TradeNotFoundError("trade not found")))
with client as c:
r = c.get(f"{settings.settings.api_base}/trades/999")
assert r.status_code == status.HTTP_404_NOT_FOUND
assert r.json() == {"detail": "trade not found"}
def test_update_trade_friendly_name_success(client_factory: Callable[..., TestClient]) -> None:
client = client_factory(
update_trade_friendly_name_service=MagicMock(
return_value=SimpleNamespace(
id=1,
friendly_name="Updated Trade Name",
),
),
)
with client as c:
r = c.patch(f"{settings.settings.api_base}/trades/friendlyname", json={"id": 1, "friendly_name": "Updated Trade Name"})
assert r.status_code == 200
assert r.json() == {"id": 1, "friendly_name": "Updated Trade Name"}
def test_update_trade_friendly_name_not_found(client_factory: Callable[..., TestClient]) -> None:
client = client_factory(update_trade_friendly_name_service=MagicMock(side_effect=svc.TradeNotFoundError("trade not found")))
with client as c:
r = c.patch(f"{settings.settings.api_base}/trades/friendlyname", json={"id": 999, "friendly_name": "NonExistent Trade"})
assert r.status_code == status.HTTP_404_NOT_FOUND
assert r.json() == {"detail": "trade not found"}
def test_update_trade_note_success(client_factory: Callable[..., TestClient]) -> None:
client = client_factory(
update_trade_note_service=MagicMock(
return_value=SimpleNamespace(
id=1,
note="Updated trade note",
),
),
)
with client as c:
r = c.patch(f"{settings.settings.api_base}/trades/notes", json={"id": 1, "note": "Updated trade note"})
assert r.status_code == 200
assert r.json() == {"id": 1, "note": "Updated trade note"}
def test_update_trade_note_not_found(client_factory: Callable[..., TestClient]) -> None:
client = client_factory(update_trade_note_service=MagicMock(side_effect=svc.TradeNotFoundError("trade not found")))
with client as c:
r = c.patch(f"{settings.settings.api_base}/trades/notes", json={"id": 999, "note": "NonExistent Trade Note"})
assert r.status_code == status.HTTP_404_NOT_FOUND
assert r.json() == {"detail": "trade not found"}

View File

@@ -137,6 +137,16 @@ def _ensure_utc_aware(dt: datetime | None) -> datetime | None:
return dt.astimezone(timezone.utc) return dt.astimezone(timezone.utc)
def _validate_timestamp(actual: datetime, expected: datetime, tolerance: timedelta) -> None:
actual_utc = _ensure_utc_aware(actual)
expected_utc = _ensure_utc_aware(expected)
assert actual_utc is not None
assert expected_utc is not None
delta = abs(actual_utc - expected_utc)
assert delta <= tolerance, f"Timestamps differ by {delta}, which exceeds tolerance of {tolerance}"
# Trades
def test_create_trade_success_with_cycle(session: Session) -> None: def test_create_trade_success_with_cycle(session: Session) -> None:
user_id = make_user(session) user_id = make_user(session)
exchange_id = make_exchange(session, user_id) exchange_id = make_exchange(session, user_id)
@@ -554,6 +564,7 @@ def test_replace_trade(session: Session) -> None:
assert actual_new_trade.replaced_by_trade_id == old_trade_id assert actual_new_trade.replaced_by_trade_id == old_trade_id
# Cycles
def test_create_cycle(session: Session) -> None: def test_create_cycle(session: Session) -> None:
user_id = make_user(session) user_id = make_user(session)
exchange_id = make_exchange(session, user_id) exchange_id = make_exchange(session, user_id)
@@ -656,6 +667,216 @@ def test_update_cycle_immutable_fields(session: Session) -> None:
) )
# Cycle loans
def test_create_cycle_loan_event(session: Session) -> None:
user_id = make_user(session)
exchange_id = make_exchange(session, user_id)
cycle_id = make_cycle(session, user_id, exchange_id)
loan_data = {
"cycle_id": cycle_id,
"loan_amount_cents": 100000,
"loan_interest_rate_tenth_bps": 5000, # 5%
"notes": "Test loan change for the cycle",
}
loan_event = crud.create_cycle_loan_event(session, loan_data)
now = datetime.now(timezone.utc)
assert loan_event.id is not None
assert loan_event.cycle_id == cycle_id
assert loan_event.loan_amount_cents == loan_data["loan_amount_cents"]
assert loan_event.loan_interest_rate_tenth_bps == loan_data["loan_interest_rate_tenth_bps"]
assert loan_event.notes == loan_data["notes"]
assert loan_event.effective_date == now.date()
_validate_timestamp(loan_event.created_at, now, timedelta(seconds=1))
session.refresh(loan_event)
actual_loan_event = session.get(models.CycleLoanChangeEvents, loan_event.id)
assert actual_loan_event is not None
assert actual_loan_event.cycle_id == cycle_id
assert actual_loan_event.loan_amount_cents == loan_data["loan_amount_cents"]
assert actual_loan_event.loan_interest_rate_tenth_bps == loan_data["loan_interest_rate_tenth_bps"]
assert actual_loan_event.notes == loan_data["notes"]
assert actual_loan_event.effective_date == now.date()
_validate_timestamp(actual_loan_event.created_at, now, timedelta(seconds=1))
def test_get_cycle_loan_events_by_cycle_id(session: Session) -> None:
user_id = make_user(session)
exchange_id = make_exchange(session, user_id)
cycle_id = make_cycle(session, user_id, exchange_id)
loan_data_1 = {
"cycle_id": cycle_id,
"loan_amount_cents": 100000,
"loan_interest_rate_tenth_bps": 5000,
"notes": "First loan event",
}
yesterday = (datetime.now(timezone.utc) - timedelta(days=1)).date()
loan_data_2 = {
"cycle_id": cycle_id,
"loan_amount_cents": 150000,
"loan_interest_rate_tenth_bps": 4500,
"effective_date": yesterday,
"notes": "Second loan event",
}
crud.create_cycle_loan_event(session, loan_data_1)
crud.create_cycle_loan_event(session, loan_data_2)
loan_events = crud.get_loan_events_by_cycle_id(session, cycle_id)
assert len(loan_events) == 2
notes = [event.notes for event in loan_events]
assert loan_events[0].notes == loan_data_2["notes"]
assert loan_events[0].effective_date == yesterday
assert notes == ["Second loan event", "First loan event"] # Ordered by effective_date desc
def test_get_cycle_loan_events_by_cycle_id_same_date(session: Session) -> None:
user_id = make_user(session)
exchange_id = make_exchange(session, user_id)
cycle_id = make_cycle(session, user_id, exchange_id)
loan_data_1 = {
"cycle_id": cycle_id,
"loan_amount_cents": 100000,
"loan_interest_rate_tenth_bps": 5000,
"notes": "First loan event",
}
loan_data_2 = {
"cycle_id": cycle_id,
"loan_amount_cents": 150000,
"loan_interest_rate_tenth_bps": 4500,
"notes": "Second loan event",
}
crud.create_cycle_loan_event(session, loan_data_1)
crud.create_cycle_loan_event(session, loan_data_2)
loan_events = crud.get_loan_events_by_cycle_id(session, cycle_id)
assert len(loan_events) == 2
notes = [event.notes for event in loan_events]
assert notes == ["First loan event", "Second loan event"] # Ordered by id desc when effective_date is same
def test_create_cycle_loan_event_single_field(session: Session) -> None:
user_id = make_user(session)
exchange_id = make_exchange(session, user_id)
cycle_id = make_cycle(session, user_id, exchange_id)
loan_data = {
"cycle_id": cycle_id,
"loan_amount_cents": 200000,
}
loan_event = crud.create_cycle_loan_event(session, loan_data)
now = datetime.now(timezone.utc)
assert loan_event.id is not None
assert loan_event.cycle_id == cycle_id
assert loan_event.loan_amount_cents == loan_data["loan_amount_cents"]
assert loan_event.loan_interest_rate_tenth_bps is None
assert loan_event.notes is None
assert loan_event.effective_date == now.date()
_validate_timestamp(loan_event.created_at, now, timedelta(seconds=1))
session.refresh(loan_event)
actual_loan_event = session.get(models.CycleLoanChangeEvents, loan_event.id)
assert actual_loan_event is not None
assert actual_loan_event.cycle_id == cycle_id
assert actual_loan_event.loan_amount_cents == loan_data["loan_amount_cents"]
assert actual_loan_event.loan_interest_rate_tenth_bps is None
assert actual_loan_event.notes is None
assert actual_loan_event.effective_date == now.date()
_validate_timestamp(actual_loan_event.created_at, now, timedelta(seconds=1))
def test_create_cycle_daily_accrual(session: Session) -> None:
user_id = make_user(session)
exchange_id = make_exchange(session, user_id)
cycle_id = make_cycle(session, user_id, exchange_id)
today = datetime.now(timezone.utc).date()
accrual_data = {
"cycle_id": cycle_id,
"accrual_date": today,
"accrued_interest_cents": 150,
"notes": "Daily interest accrual",
}
accrual = crud.create_cycle_daily_accrual(session, cycle_id, accrual_data["accrual_date"], accrual_data["accrued_interest_cents"])
assert accrual.id is not None
assert accrual.cycle_id == cycle_id
assert accrual.accrual_date == accrual_data["accrual_date"]
assert accrual.accrual_amount_cents == accrual_data["accrued_interest_cents"]
session.refresh(accrual)
actual_accrual = session.get(models.CycleDailyAccrual, accrual.id)
assert actual_accrual is not None
assert actual_accrual.cycle_id == cycle_id
assert actual_accrual.accrual_date == accrual_data["accrual_date"]
assert actual_accrual.accrual_amount_cents == accrual_data["accrued_interest_cents"]
def test_get_cycle_daily_accruals_by_cycle_id(session: Session) -> None:
user_id = make_user(session)
exchange_id = make_exchange(session, user_id)
cycle_id = make_cycle(session, user_id, exchange_id)
today = datetime.now(timezone.utc).date()
yesterday = today - timedelta(days=1)
accrual_data_1 = {
"cycle_id": cycle_id,
"accrual_date": yesterday,
"accrued_interest_cents": 100,
}
accrual_data_2 = {
"cycle_id": cycle_id,
"accrual_date": today,
"accrued_interest_cents": 150,
}
crud.create_cycle_daily_accrual(session, cycle_id, accrual_data_1["accrual_date"], accrual_data_1["accrued_interest_cents"])
crud.create_cycle_daily_accrual(session, cycle_id, accrual_data_2["accrual_date"], accrual_data_2["accrued_interest_cents"])
accruals = crud.get_cycle_daily_accruals_by_cycle_id(session, cycle_id)
assert len(accruals) == 2
dates = [accrual.accrual_date for accrual in accruals]
assert dates == [yesterday, today] # Ordered by accrual_date asc
def test_get_cycle_daily_accruals_by_cycle_id_and_date(session: Session) -> None:
user_id = make_user(session)
exchange_id = make_exchange(session, user_id)
cycle_id = make_cycle(session, user_id, exchange_id)
today = datetime.now(timezone.utc).date()
yesterday = today - timedelta(days=1)
accrual_data_1 = {
"cycle_id": cycle_id,
"accrual_date": yesterday,
"accrued_interest_cents": 100,
}
accrual_data_2 = {
"cycle_id": cycle_id,
"accrual_date": today,
"accrued_interest_cents": 150,
}
crud.create_cycle_daily_accrual(session, cycle_id, accrual_data_1["accrual_date"], accrual_data_1["accrued_interest_cents"])
crud.create_cycle_daily_accrual(session, cycle_id, accrual_data_2["accrual_date"], accrual_data_2["accrued_interest_cents"])
accruals_today = crud.get_cycle_daily_accrual_by_cycle_id_and_date(session, cycle_id, today)
assert accruals_today is not None
assert accruals_today.accrual_date == today
assert accruals_today.accrual_amount_cents == accrual_data_2["accrued_interest_cents"]
accruals_yesterday = crud.get_cycle_daily_accrual_by_cycle_id_and_date(session, cycle_id, yesterday)
assert accruals_yesterday is not None
assert accruals_yesterday.accrual_date == yesterday
assert accruals_yesterday.accrual_amount_cents == accrual_data_1["accrued_interest_cents"]
# Exchanges # Exchanges
def test_create_exchange(session: Session) -> None: def test_create_exchange(session: Session) -> None:
user_id = make_user(session) user_id = make_user(session)

View File

@@ -45,6 +45,25 @@ def test_run_migrations_0_to_1(monkeypatch: pytest.MonkeyPatch) -> None:
"loan_interest_rate_tenth_bps": ("INTEGER", 0, 0), "loan_interest_rate_tenth_bps": ("INTEGER", 0, 0),
"start_date": ("DATE", 1, 0), "start_date": ("DATE", 1, 0),
"end_date": ("DATE", 0, 0), "end_date": ("DATE", 0, 0),
"latest_interest_accrued_date": ("DATE", 0, 0),
"total_accrued_amount_cents": ("INTEGER", 1, 0),
},
"cycle_loan_change_events": {
"id": ("INTEGER", 1, 1),
"cycle_id": ("INTEGER", 1, 0),
"effective_date": ("DATE", 1, 0),
"loan_amount_cents": ("INTEGER", 0, 0),
"loan_interest_rate_tenth_bps": ("INTEGER", 0, 0),
"related_trade_id": ("INTEGER", 0, 0),
"notes": ("TEXT", 0, 0),
"created_at": ("DATETIME", 1, 0),
},
"cycle_daily_accrual": {
"id": ("INTEGER", 1, 1),
"cycle_id": ("INTEGER", 1, 0),
"accrual_date": ("DATE", 1, 0),
"accrual_amount_cents": ("INTEGER", 1, 0),
"created_at": ("DATETIME", 1, 0),
}, },
"trades": { "trades": {
"id": ("INTEGER", 1, 1), "id": ("INTEGER", 1, 1),
@@ -100,6 +119,13 @@ def test_run_migrations_0_to_1(monkeypatch: pytest.MonkeyPatch) -> None:
{"table": "users", "from": "user_id", "to": "id"}, {"table": "users", "from": "user_id", "to": "id"},
{"table": "exchanges", "from": "exchange_id", "to": "id"}, {"table": "exchanges", "from": "exchange_id", "to": "id"},
], ],
"cycle_loan_change_events": [
{"table": "cycles", "from": "cycle_id", "to": "id"},
{"table": "trades", "from": "related_trade_id", "to": "id"},
],
"cycle_daily_accrual": [
{"table": "cycles", "from": "cycle_id", "to": "id"},
],
"sessions": [ "sessions": [
{"table": "users", "from": "user_id", "to": "id"}, {"table": "users", "from": "user_id", "to": "id"},
], ],
@@ -146,6 +172,39 @@ def test_run_migrations_0_to_1(monkeypatch: pytest.MonkeyPatch) -> None:
actual_fk_list = [{"table": r[2], "from": r[3], "to": r[4]} for r in fk_rows] actual_fk_list = [{"table": r[2], "from": r[3], "to": r[4]} for r in fk_rows]
for efk in fks: for efk in fks:
assert efk in actual_fk_list, f"missing FK on {tbl_name}: {efk}" assert efk in actual_fk_list, f"missing FK on {tbl_name}: {efk}"
# check trades.replaced_by_trade_id self-referential FK
fk_rows = conn.execute(text("PRAGMA foreign_key_list('trades')")).fetchall()
actual_fk_list = [{"table": r[2], "from": r[3], "to": r[4]} for r in fk_rows]
assert {"table": "trades", "from": "replaced_by_trade_id", "to": "id"} in actual_fk_list, (
"missing self FK trades.replaced_by_trade_id -> trades.id"
)
# helper to find unique index on a column
def has_unique_index(table: str, column: str) -> bool:
idx_rows = conn.execute(text(f"PRAGMA index_list('{table}')")).fetchall()
for idx in idx_rows:
idx_name = idx[1]
is_unique = bool(idx[2])
if not is_unique:
continue
info = conn.execute(text(f"PRAGMA index_info('{idx_name}')")).fetchall()
cols = [r[2] for r in info]
if column in cols:
return True
return False
assert has_unique_index("trades", "friendly_name"), (
"expected unique index on trades(friendly_name) per uq_trades_user_friendly_name"
)
assert has_unique_index("cycles", "friendly_name"), (
"expected unique index on cycles(friendly_name) per uq_cycles_user_friendly_name"
)
assert has_unique_index("exchanges", "name"), "expected unique index on exchanges(name) per uq_exchanges_user_name"
assert has_unique_index("sessions", "session_token_hash"), "expected unique index on sessions(session_token_hash)"
assert has_unique_index("cycle_loan_change_events", "related_trade_id"), (
"expected unique index on cycle_loan_change_events(related_trade_id)"
)
finally: finally:
engine.dispose() engine.dispose()
SQLModel.metadata.clear() SQLModel.metadata.clear()

File diff suppressed because it is too large Load Diff

View File

@@ -1,6 +1,6 @@
from __future__ import annotations from __future__ import annotations
from datetime import datetime, timedelta, timezone from datetime import date, datetime, timedelta, timezone
from typing import TYPE_CHECKING, Any, TypeVar, cast from typing import TYPE_CHECKING, Any, TypeVar, cast
from pydantic import BaseModel from pydantic import BaseModel
@@ -13,6 +13,8 @@ if TYPE_CHECKING:
from collections.abc import Mapping from collections.abc import Mapping
from enum import Enum from enum import Enum
from sqlalchemy.sql.elements import ColumnElement
# Generic enum member type # Generic enum member type
T = TypeVar("T", bound="Enum") T = TypeVar("T", bound="Enum")
@@ -301,6 +303,93 @@ def update_cycle(session: Session, cycle_id: int, update_data: Mapping[str, Any]
return cycle return cycle
# Cycle loan and interest
def create_cycle_loan_event(session: Session, loan_data: Mapping[str, Any] | BaseModel) -> models.CycleLoanChangeEvents:
data = _data_to_dict(loan_data)
allowed = _allowed_columns(models.CycleLoanChangeEvents)
payload = {k: v for k, v in data.items() if k in allowed}
if "cycle_id" not in payload:
raise ValueError("cycle_id is required")
cycle = session.get(models.Cycles, payload["cycle_id"])
if cycle is None:
raise ValueError("cycle_id does not exist")
payload["effective_date"] = payload.get("effective_date") or datetime.now(timezone.utc).date()
payload["created_at"] = datetime.now(timezone.utc)
cle = models.CycleLoanChangeEvents(**payload)
session.add(cle)
try:
session.flush()
except IntegrityError as e:
session.rollback()
raise ValueError("create_cycle_loan_event integrity error") from e
session.refresh(cle)
return cle
def get_loan_events_by_cycle_id(session: Session, cycle_id: int) -> list[models.CycleLoanChangeEvents]:
eff_col = cast("ColumnElement", models.CycleLoanChangeEvents.effective_date)
id_col = cast("ColumnElement", models.CycleLoanChangeEvents.id)
statement = (
select(models.CycleLoanChangeEvents)
.where(
models.CycleLoanChangeEvents.cycle_id == cycle_id,
)
.order_by(eff_col, id_col.asc())
)
return list(session.exec(statement).all())
def create_cycle_daily_accrual(session: Session, cycle_id: int, accrual_date: date, accrual_amount_cents: int) -> models.CycleDailyAccrual:
cycle = session.get(models.Cycles, cycle_id)
if cycle is None:
raise ValueError("cycle_id does not exist")
existing = session.exec(
select(models.CycleDailyAccrual).where(
models.CycleDailyAccrual.cycle_id == cycle_id,
models.CycleDailyAccrual.accrual_date == accrual_date,
),
).first()
if existing:
return existing
if accrual_amount_cents < 0:
raise ValueError("accrual_amount_cents must be non-negative")
row = models.CycleDailyAccrual(
cycle_id=cycle_id,
accrual_date=accrual_date,
accrual_amount_cents=accrual_amount_cents,
created_at=datetime.now(timezone.utc),
)
session.add(row)
try:
session.flush()
except IntegrityError as e:
session.rollback()
raise ValueError("create_cycle_daily_accrual integrity error") from e
session.refresh(row)
return row
def get_cycle_daily_accruals_by_cycle_id(session: Session, cycle_id: int) -> list[models.CycleDailyAccrual]:
date_col = cast("ColumnElement", models.CycleDailyAccrual.accrual_date)
statement = (
select(models.CycleDailyAccrual)
.where(
models.CycleDailyAccrual.cycle_id == cycle_id,
)
.order_by(date_col.asc())
)
return list(session.exec(statement).all())
def get_cycle_daily_accrual_by_cycle_id_and_date(session: Session, cycle_id: int, accrual_date: date) -> models.CycleDailyAccrual | None:
statement = select(models.CycleDailyAccrual).where(
models.CycleDailyAccrual.cycle_id == cycle_id,
models.CycleDailyAccrual.accrual_date == accrual_date,
)
return session.exec(statement).first()
# Exchanges # Exchanges
IMMUTABLE_EXCHANGE_FIELDS = {"id"} IMMUTABLE_EXCHANGE_FIELDS = {"id"}

View File

@@ -28,6 +28,8 @@ def _mig_0_1(engine: Engine) -> None:
models_v1.Users.__table__, # type: ignore[attr-defined] models_v1.Users.__table__, # type: ignore[attr-defined]
models_v1.Sessions.__table__, # type: ignore[attr-defined] models_v1.Sessions.__table__, # type: ignore[attr-defined]
models_v1.Exchanges.__table__, # type: ignore[attr-defined] models_v1.Exchanges.__table__, # type: ignore[attr-defined]
models_v1.CycleLoanChangeEvents.__table__, # type: ignore[attr-defined]
models_v1.CycleDailyAccrual.__table__, # type: ignore[attr-defined]
], ],
) )

View File

@@ -69,7 +69,7 @@ class CycleBase(SQLModel):
funding_source: str | None = None funding_source: str | None = None
capital_exposure_cents: int | None = None capital_exposure_cents: int | None = None
loan_amount_cents: int | None = None loan_amount_cents: int | None = None
loan_interest_rate_bps: int | None = None loan_interest_rate_tenth_bps: int | None = None
trades: list[TradeRead] | None = None trades: list[TradeRead] | None = None
exchange: ExchangesRead | None = None exchange: ExchangesRead | None = None

View File

@@ -1,11 +1,13 @@
from datetime import date, datetime from datetime import date, datetime
from enum import Enum from enum import Enum
from typing import Optional
from sqlmodel import ( from sqlmodel import (
Column, Column,
Date, Date,
DateTime, DateTime,
Field, Field,
ForeignKey,
Integer, Integer,
Relationship, Relationship,
SQLModel, SQLModel,
@@ -16,8 +18,10 @@ from sqlmodel import (
class TradeType(str, Enum): class TradeType(str, Enum):
SELL_PUT = "SELL_PUT" SELL_PUT = "SELL_PUT"
CLOSE_SELL_PUT = "CLOSE_SELL_PUT"
ASSIGNMENT = "ASSIGNMENT" ASSIGNMENT = "ASSIGNMENT"
SELL_CALL = "SELL_CALL" SELL_CALL = "SELL_CALL"
CLOSE_SELL_CALL = "CLOSE_SELL_CALL"
EXERCISE_CALL = "EXERCISE_CALL" EXERCISE_CALL = "EXERCISE_CALL"
LONG_SPOT = "LONG_SPOT" LONG_SPOT = "LONG_SPOT"
CLOSE_LONG_SPOT = "CLOSE_LONG_SPOT" CLOSE_LONG_SPOT = "CLOSE_LONG_SPOT"
@@ -92,8 +96,14 @@ class Trades(SQLModel, table=True):
replaced_by_trade_id: int | None = Field(default=None, foreign_key="trades.id", nullable=True) replaced_by_trade_id: int | None = Field(default=None, foreign_key="trades.id", nullable=True)
notes: str | None = Field(default=None, sa_column=Column(Text, nullable=True)) notes: str | None = Field(default=None, sa_column=Column(Text, nullable=True))
cycle_id: int | None = Field(default=None, foreign_key="cycles.id", nullable=True, index=True) cycle_id: int | None = Field(default=None, foreign_key="cycles.id", nullable=True, index=True)
cycle: "Cycles" = Relationship(back_populates="trades") cycle: "Cycles" = Relationship(back_populates="trades")
related_loan_change_event: Optional["CycleLoanChangeEvents"] = Relationship(
back_populates="trade",
sa_relationship_kwargs={"uselist": False},
)
class Cycles(SQLModel, table=True): class Cycles(SQLModel, table=True):
__tablename__ = "cycles" # type: ignore[attr-defined] __tablename__ = "cycles" # type: ignore[attr-defined]
@@ -109,12 +119,48 @@ class Cycles(SQLModel, table=True):
status: CycleStatus = Field(sa_column=Column(Text, nullable=False)) status: CycleStatus = Field(sa_column=Column(Text, nullable=False))
funding_source: FundingSource = Field(sa_column=Column(Text, nullable=True)) funding_source: FundingSource = Field(sa_column=Column(Text, nullable=True))
capital_exposure_cents: int | None = Field(default=None, nullable=True) capital_exposure_cents: int | None = Field(default=None, nullable=True)
loan_amount_cents: int | None = Field(default=None, nullable=True)
loan_interest_rate_tenth_bps: int | None = Field(default=None, nullable=True)
start_date: date = Field(sa_column=Column(Date, nullable=False)) start_date: date = Field(sa_column=Column(Date, nullable=False))
end_date: date | None = Field(default=None, sa_column=Column(Date, nullable=True)) end_date: date | None = Field(default=None, sa_column=Column(Date, nullable=True))
trades: list["Trades"] = Relationship(back_populates="cycle") trades: list["Trades"] = Relationship(back_populates="cycle")
loan_amount_cents: int | None = Field(default=None, nullable=True)
loan_interest_rate_tenth_bps: int | None = Field(default=None, nullable=True)
latest_interest_accrued_date: date | None = Field(default=None, sa_column=Column(Date, nullable=True))
total_accrued_amount_cents: int = Field(default=0, sa_column=Column(Integer, nullable=False))
loan_change_events: list["CycleLoanChangeEvents"] = Relationship(back_populates="cycle")
daily_accruals: list["CycleDailyAccrual"] = Relationship(back_populates="cycle")
class CycleLoanChangeEvents(SQLModel, table=True):
__tablename__ = "cycle_loan_change_events" # type: ignore[attr-defined]
id: int | None = Field(default=None, primary_key=True)
cycle_id: int = Field(sa_column=Column(Integer, ForeignKey("cycles.id", ondelete="CASCADE"), nullable=False, index=True))
effective_date: date = Field(sa_column=Column(Date, nullable=False))
loan_amount_cents: int | None = Field(default=None, sa_column=Column(Integer, nullable=True))
loan_interest_rate_tenth_bps: int | None = Field(default=None, sa_column=Column(Integer, nullable=True))
related_trade_id: int | None = Field(default=None, sa_column=Column(Integer, ForeignKey("trades.id"), nullable=True, unique=True))
notes: str | None = Field(default=None, sa_column=Column(Text, nullable=True))
created_at: datetime = Field(sa_column=Column(DateTime(timezone=True), nullable=False))
cycle: "Cycles" = Relationship(back_populates="loan_change_events")
trade: Optional["Trades"] = Relationship(back_populates="related_loan_change_event")
class CycleDailyAccrual(SQLModel, table=True):
__tablename__ = "cycle_daily_accrual" # type: ignore[attr-defined]
__table_args__ = (UniqueConstraint("cycle_id", "accrual_date", name="uq_cycle_daily_accruals_cycle_date"),)
id: int | None = Field(default=None, primary_key=True)
cycle_id: int = Field(sa_column=Column(Integer, ForeignKey("cycles.id", ondelete="CASCADE"), nullable=False, index=True))
accrual_date: date = Field(sa_column=Column(Date, nullable=False))
accrual_amount_cents: int = Field(sa_column=Column(Integer, nullable=False))
created_at: datetime = Field(sa_column=Column(DateTime(timezone=True), nullable=False))
cycle: "Cycles" = Relationship(back_populates="daily_accruals")
class Exchanges(SQLModel, table=True): class Exchanges(SQLModel, table=True):
__tablename__ = "exchanges" # type: ignore[attr-defined] __tablename__ = "exchanges" # type: ignore[attr-defined]

View File

@@ -1,11 +1,13 @@
from datetime import date, datetime from datetime import date, datetime
from enum import Enum from enum import Enum
from typing import Optional
from sqlmodel import ( from sqlmodel import (
Column, Column,
Date, Date,
DateTime, DateTime,
Field, Field,
ForeignKey,
Integer, Integer,
Relationship, Relationship,
SQLModel, SQLModel,
@@ -16,8 +18,10 @@ from sqlmodel import (
class TradeType(str, Enum): class TradeType(str, Enum):
SELL_PUT = "SELL_PUT" SELL_PUT = "SELL_PUT"
CLOSE_SELL_PUT = "CLOSE_SELL_PUT"
ASSIGNMENT = "ASSIGNMENT" ASSIGNMENT = "ASSIGNMENT"
SELL_CALL = "SELL_CALL" SELL_CALL = "SELL_CALL"
CLOSE_SELL_CALL = "CLOSE_SELL_CALL"
EXERCISE_CALL = "EXERCISE_CALL" EXERCISE_CALL = "EXERCISE_CALL"
LONG_SPOT = "LONG_SPOT" LONG_SPOT = "LONG_SPOT"
CLOSE_LONG_SPOT = "CLOSE_LONG_SPOT" CLOSE_LONG_SPOT = "CLOSE_LONG_SPOT"
@@ -92,8 +96,14 @@ class Trades(SQLModel, table=True):
replaced_by_trade_id: int | None = Field(default=None, foreign_key="trades.id", nullable=True) replaced_by_trade_id: int | None = Field(default=None, foreign_key="trades.id", nullable=True)
notes: str | None = Field(default=None, sa_column=Column(Text, nullable=True)) notes: str | None = Field(default=None, sa_column=Column(Text, nullable=True))
cycle_id: int | None = Field(default=None, foreign_key="cycles.id", nullable=True, index=True) cycle_id: int | None = Field(default=None, foreign_key="cycles.id", nullable=True, index=True)
cycle: "Cycles" = Relationship(back_populates="trades") cycle: "Cycles" = Relationship(back_populates="trades")
related_loan_change_event: Optional["CycleLoanChangeEvents"] = Relationship(
back_populates="trade",
sa_relationship_kwargs={"uselist": False},
)
class Cycles(SQLModel, table=True): class Cycles(SQLModel, table=True):
__tablename__ = "cycles" # type: ignore[attr-defined] __tablename__ = "cycles" # type: ignore[attr-defined]
@@ -109,12 +119,48 @@ class Cycles(SQLModel, table=True):
status: CycleStatus = Field(sa_column=Column(Text, nullable=False)) status: CycleStatus = Field(sa_column=Column(Text, nullable=False))
funding_source: FundingSource = Field(sa_column=Column(Text, nullable=True)) funding_source: FundingSource = Field(sa_column=Column(Text, nullable=True))
capital_exposure_cents: int | None = Field(default=None, nullable=True) capital_exposure_cents: int | None = Field(default=None, nullable=True)
loan_amount_cents: int | None = Field(default=None, nullable=True)
loan_interest_rate_tenth_bps: int | None = Field(default=None, nullable=True)
start_date: date = Field(sa_column=Column(Date, nullable=False)) start_date: date = Field(sa_column=Column(Date, nullable=False))
end_date: date | None = Field(default=None, sa_column=Column(Date, nullable=True)) end_date: date | None = Field(default=None, sa_column=Column(Date, nullable=True))
trades: list["Trades"] = Relationship(back_populates="cycle") trades: list["Trades"] = Relationship(back_populates="cycle")
loan_amount_cents: int | None = Field(default=None, nullable=True)
loan_interest_rate_tenth_bps: int | None = Field(default=None, nullable=True)
latest_interest_accrued_date: date | None = Field(default=None, sa_column=Column(Date, nullable=True))
total_accrued_amount_cents: int = Field(default=0, sa_column=Column(Integer, nullable=False))
loan_change_events: list["CycleLoanChangeEvents"] = Relationship(back_populates="cycle")
daily_accruals: list["CycleDailyAccrual"] = Relationship(back_populates="cycle")
class CycleLoanChangeEvents(SQLModel, table=True):
__tablename__ = "cycle_loan_change_events" # type: ignore[attr-defined]
id: int | None = Field(default=None, primary_key=True)
cycle_id: int = Field(sa_column=Column(Integer, ForeignKey("cycles.id", ondelete="CASCADE"), nullable=False, index=True))
effective_date: date = Field(sa_column=Column(Date, nullable=False))
loan_amount_cents: int | None = Field(default=None, sa_column=Column(Integer, nullable=True))
loan_interest_rate_tenth_bps: int | None = Field(default=None, sa_column=Column(Integer, nullable=True))
related_trade_id: int | None = Field(default=None, sa_column=Column(Integer, ForeignKey("trades.id"), nullable=True, unique=True))
notes: str | None = Field(default=None, sa_column=Column(Text, nullable=True))
created_at: datetime = Field(sa_column=Column(DateTime(timezone=True), nullable=False))
cycle: "Cycles" = Relationship(back_populates="loan_change_events")
trade: Optional["Trades"] = Relationship(back_populates="related_loan_change_event")
class CycleDailyAccrual(SQLModel, table=True):
__tablename__ = "cycle_daily_accrual" # type: ignore[attr-defined]
__table_args__ = (UniqueConstraint("cycle_id", "accrual_date", name="uq_cycle_daily_accruals_cycle_date"),)
id: int | None = Field(default=None, primary_key=True)
cycle_id: int = Field(sa_column=Column(Integer, ForeignKey("cycles.id", ondelete="CASCADE"), nullable=False, index=True))
accrual_date: date = Field(sa_column=Column(Date, nullable=False))
accrual_amount_cents: int = Field(sa_column=Column(Integer, nullable=False))
created_at: datetime = Field(sa_column=Column(DateTime(timezone=True), nullable=False))
cycle: "Cycles" = Relationship(back_populates="daily_accruals")
class Exchanges(SQLModel, table=True): class Exchanges(SQLModel, table=True):
__tablename__ = "exchanges" # type: ignore[attr-defined] __tablename__ = "exchanges" # type: ignore[attr-defined]

View File

@@ -229,6 +229,7 @@ def update_exchanges_service(db_session: Session, user_id: int, exchange_id: int
# Cycle Service # Cycle Service
def create_cycle_service(db_session: Session, user_id: int, cycle_data: CycleBase) -> CycleRead: def create_cycle_service(db_session: Session, user_id: int, cycle_data: CycleBase) -> CycleRead:
raise NotImplementedError("Cycle creation not implemented")
cycle_data_dict = cycle_data.model_dump() cycle_data_dict = cycle_data.model_dump()
cycle_data_dict["user_id"] = user_id cycle_data_dict["user_id"] = user_id
cycle_data_with_user_id: CycleCreate = CycleCreate.model_validate(cycle_data_dict) cycle_data_with_user_id: CycleCreate = CycleCreate.model_validate(cycle_data_dict)
@@ -250,11 +251,23 @@ def get_cycles_by_user_service(db_session: Session, user_id: int) -> list[CycleR
return [CycleRead.model_validate(cycle) for cycle in cycles] return [CycleRead.model_validate(cycle) for cycle in cycles]
def _validate_cycle_update_data(cycle_data: CycleUpdate) -> tuple[bool, str]: def _validate_cycle_update_data(cycle_data: CycleUpdate) -> tuple[bool, str]: # noqa: PLR0911
if cycle_data.status == "CLOSED" and cycle_data.end_date is None: if cycle_data.status == "CLOSED" and cycle_data.end_date is None:
return False, "end_date is required when status is CLOSED" return False, "end_date is required when status is CLOSED"
if cycle_data.status == "OPEN" and cycle_data.end_date is not None: if cycle_data.status == "OPEN" and cycle_data.end_date is not None:
return False, "end_date must be empty when status is OPEN" return False, "end_date must be empty when status is OPEN"
if cycle_data.capital_exposure_cents is not None and cycle_data.capital_exposure_cents < 0:
return False, "capital_exposure_cents must be non-negative"
if (
cycle_data.funding_source is not None
and cycle_data.funding_source != "CASH"
and (cycle_data.loan_amount_cents is None or cycle_data.loan_interest_rate_tenth_bps is None)
):
return False, "loan_amount_cents and loan_interest_rate_tenth_bps are required when funding_source is not CASH"
if cycle_data.loan_amount_cents is not None and cycle_data.loan_amount_cents < 0:
return False, "loan_amount_cents must be non-negative"
if cycle_data.loan_interest_rate_tenth_bps is not None and cycle_data.loan_interest_rate_tenth_bps < 0:
return False, "loan_interest_rate_tenth_bps must be non-negative"
return True, "" return True, ""