Compare commits
8 Commits
6a5f160d83
...
feature/lo
| Author | SHA1 | Date | |
|---|---|---|---|
| ef6dacd0bc | |||
| 0ca660f268 | |||
| 5e7d801075 | |||
| 94fb4705ff | |||
| bb87b90285 | |||
| 5eae75b23e | |||
| 544f5e8c92 | |||
| b6ba108156 |
19
LICENSE
Normal file
19
LICENSE
Normal file
@@ -0,0 +1,19 @@
|
||||
Copyright (c) 2025 Tianyu Liu, Studio TJ
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
||||
EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
|
||||
DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
|
||||
OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE
|
||||
OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
28
README.md
Normal file
28
README.md
Normal file
@@ -0,0 +1,28 @@
|
||||
|
||||
# Trading Journal (Work In Progress)
|
||||
|
||||
A simple trading journal application (work in progress).
|
||||
|
||||
This repository contains the backend of a trading journal designed to help you record and analyse trades. The system is specially designed to support journaling trades for the "wheel" options strategy, but it also supports other trade types such as long/short spot positions, forex, and more.
|
||||
|
||||
Important: the project is still under active development. There is a backend in this repo, but the frontend UI has not been implemented yet.
|
||||
|
||||
## Key features
|
||||
|
||||
- Journal trades with rich metadata (strategy, entry/exit, P/L, notes).
|
||||
- Built-in support and data model conveniences for the Wheel strategy (puts/calls lifecycle tracking).
|
||||
- Flexible support for other trade types: long/short spots, forex, futures, etc.
|
||||
- Backend-first design with tests and migration helpers.
|
||||
|
||||
## Repository layout
|
||||
|
||||
- `backend/` — Python backend code (API, models, services, migrations, tests).
|
||||
- `backend/trading_journal/` — core application modules: CRUD, models, DTOs, services, and security.
|
||||
- `backend/tests/` — unit tests targeting the backend logic and DB layer.
|
||||
|
||||
|
||||
## License
|
||||
|
||||
See the `LICENSE` file in the project root for license details.
|
||||
|
||||
|
||||
@@ -52,8 +52,14 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: # noqa: ARG001
|
||||
await asyncio.to_thread(_db.dispose)
|
||||
|
||||
|
||||
origins = [
|
||||
"http://127.0.0.1:18881",
|
||||
]
|
||||
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
app.add_middleware(service.AuthMiddleWare)
|
||||
app.add_middleware(
|
||||
service.AuthMiddleWare,
|
||||
)
|
||||
app.state.db_factory = _db
|
||||
|
||||
|
||||
@@ -77,7 +83,7 @@ async def register_user(request: Request, user_in: UserCreate) -> Response:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) from e
|
||||
except Exception as e:
|
||||
logger.exception("Failed to register user: \n")
|
||||
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Internal server error") from e
|
||||
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Internal Server Error") from e
|
||||
|
||||
|
||||
@app.post(f"{settings.settings.api_base}/login")
|
||||
@@ -110,7 +116,7 @@ async def login(request: Request, user_in: UserLogin) -> Response:
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception("Failed to login user: \n")
|
||||
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Internal server error") from e
|
||||
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Internal Server Error") from e
|
||||
else:
|
||||
return response
|
||||
|
||||
|
||||
554
backend/openapi.yaml
Normal file
554
backend/openapi.yaml
Normal file
@@ -0,0 +1,554 @@
|
||||
openapi: "3.0.3"
|
||||
info:
|
||||
title: Trading Journal API
|
||||
version: "1.0.0"
|
||||
description: OpenAPI description generated from [`app.py`](app.py) and DTOs in [`trading_journal/dto.py`](trading_journal/dto.py).
|
||||
servers:
|
||||
- url: "http://127.0.0.1:18881{basePath}"
|
||||
variables:
|
||||
basePath:
|
||||
default: "/api/v1"
|
||||
description: "API base path (matches settings.settings.api_base)"
|
||||
components:
|
||||
securitySchemes:
|
||||
session_cookie:
|
||||
type: apiKey
|
||||
in: cookie
|
||||
name: session_token
|
||||
schemas:
|
||||
UserCreate:
|
||||
$ref: "#/components/schemas/UserCreate_impl"
|
||||
UserCreate_impl:
|
||||
type: object
|
||||
required:
|
||||
- username
|
||||
- password
|
||||
properties:
|
||||
username:
|
||||
type: string
|
||||
is_active:
|
||||
type: boolean
|
||||
default: true
|
||||
password:
|
||||
type: string
|
||||
UserLogin:
|
||||
type: object
|
||||
required:
|
||||
- username
|
||||
- password
|
||||
properties:
|
||||
username:
|
||||
type: string
|
||||
password:
|
||||
type: string
|
||||
UserRead:
|
||||
type: object
|
||||
required:
|
||||
- id
|
||||
- username
|
||||
properties:
|
||||
id:
|
||||
type: integer
|
||||
username:
|
||||
type: string
|
||||
is_active:
|
||||
type: boolean
|
||||
SessionsBase:
|
||||
type: object
|
||||
required:
|
||||
- user_id
|
||||
properties:
|
||||
user_id:
|
||||
type: integer
|
||||
SessionsCreate:
|
||||
allOf:
|
||||
- $ref: "#/components/schemas/SessionsBase"
|
||||
- type: object
|
||||
required:
|
||||
- expires_at
|
||||
properties:
|
||||
expires_at:
|
||||
type: string
|
||||
format: date-time
|
||||
ExchangesBase:
|
||||
type: object
|
||||
required:
|
||||
- name
|
||||
properties:
|
||||
name:
|
||||
type: string
|
||||
notes:
|
||||
type: string
|
||||
nullable: true
|
||||
ExchangesRead:
|
||||
allOf:
|
||||
- $ref: "#/components/schemas/ExchangesBase"
|
||||
- type: object
|
||||
required:
|
||||
- id
|
||||
properties:
|
||||
id:
|
||||
type: integer
|
||||
CycleBase:
|
||||
type: object
|
||||
properties:
|
||||
friendly_name:
|
||||
type: string
|
||||
nullable: true
|
||||
status:
|
||||
type: string
|
||||
end_date:
|
||||
type: string
|
||||
format: date
|
||||
nullable: true
|
||||
funding_source:
|
||||
type: string
|
||||
nullable: true
|
||||
capital_exposure_cents:
|
||||
type: integer
|
||||
nullable: true
|
||||
loan_amount_cents:
|
||||
type: integer
|
||||
nullable: true
|
||||
loan_interest_rate_tenth_bps:
|
||||
type: integer
|
||||
nullable: true
|
||||
trades:
|
||||
type: array
|
||||
items:
|
||||
$ref: "#/components/schemas/TradeRead"
|
||||
nullable: true
|
||||
exchange:
|
||||
$ref: "#/components/schemas/ExchangesRead"
|
||||
nullable: true
|
||||
CycleCreate:
|
||||
allOf:
|
||||
- $ref: "#/components/schemas/CycleBase"
|
||||
- type: object
|
||||
required:
|
||||
- user_id
|
||||
- symbol
|
||||
- exchange_id
|
||||
- underlying_currency
|
||||
- start_date
|
||||
properties:
|
||||
user_id:
|
||||
type: integer
|
||||
symbol:
|
||||
type: string
|
||||
exchange_id:
|
||||
type: integer
|
||||
underlying_currency:
|
||||
type: string
|
||||
start_date:
|
||||
type: string
|
||||
format: date
|
||||
CycleUpdate:
|
||||
allOf:
|
||||
- $ref: "#/components/schemas/CycleBase"
|
||||
- type: object
|
||||
required:
|
||||
- id
|
||||
properties:
|
||||
id:
|
||||
type: integer
|
||||
CycleRead:
|
||||
allOf:
|
||||
- $ref: "#/components/schemas/CycleCreate"
|
||||
- type: object
|
||||
required:
|
||||
- id
|
||||
properties:
|
||||
id:
|
||||
type: integer
|
||||
TradeBase:
|
||||
type: object
|
||||
required:
|
||||
- symbol
|
||||
- underlying_currency
|
||||
- trade_type
|
||||
- trade_strategy
|
||||
- trade_date
|
||||
- quantity
|
||||
- price_cents
|
||||
- commission_cents
|
||||
properties:
|
||||
friendly_name:
|
||||
type: string
|
||||
nullable: true
|
||||
symbol:
|
||||
type: string
|
||||
exchange_id:
|
||||
type: integer
|
||||
underlying_currency:
|
||||
type: string
|
||||
trade_type:
|
||||
type: string
|
||||
trade_strategy:
|
||||
type: string
|
||||
trade_date:
|
||||
type: string
|
||||
format: date
|
||||
quantity:
|
||||
type: integer
|
||||
price_cents:
|
||||
type: integer
|
||||
commission_cents:
|
||||
type: integer
|
||||
notes:
|
||||
type: string
|
||||
nullable: true
|
||||
cycle_id:
|
||||
type: integer
|
||||
nullable: true
|
||||
TradeCreate:
|
||||
allOf:
|
||||
- $ref: "#/components/schemas/TradeBase"
|
||||
- type: object
|
||||
properties:
|
||||
user_id:
|
||||
type: integer
|
||||
nullable: true
|
||||
trade_time_utc:
|
||||
type: string
|
||||
format: date-time
|
||||
nullable: true
|
||||
gross_cash_flow_cents:
|
||||
type: integer
|
||||
nullable: true
|
||||
net_cash_flow_cents:
|
||||
type: integer
|
||||
nullable: true
|
||||
quantity_multiplier:
|
||||
type: integer
|
||||
default: 1
|
||||
expiry_date:
|
||||
type: string
|
||||
format: date
|
||||
nullable: true
|
||||
strike_price_cents:
|
||||
type: integer
|
||||
nullable: true
|
||||
is_invalidated:
|
||||
type: boolean
|
||||
default: false
|
||||
invalidated_at:
|
||||
type: string
|
||||
format: date-time
|
||||
nullable: true
|
||||
replaced_by_trade_id:
|
||||
type: integer
|
||||
nullable: true
|
||||
TradeNoteUpdate:
|
||||
type: object
|
||||
required:
|
||||
- id
|
||||
properties:
|
||||
id:
|
||||
type: integer
|
||||
notes:
|
||||
type: string
|
||||
nullable: true
|
||||
TradeFriendlyNameUpdate:
|
||||
type: object
|
||||
required:
|
||||
- id
|
||||
- friendly_name
|
||||
properties:
|
||||
id:
|
||||
type: integer
|
||||
friendly_name:
|
||||
type: string
|
||||
TradeRead:
|
||||
allOf:
|
||||
- $ref: "#/components/schemas/TradeCreate"
|
||||
- type: object
|
||||
required:
|
||||
- id
|
||||
properties:
|
||||
id:
|
||||
type: integer
|
||||
paths:
|
||||
/status:
|
||||
get:
|
||||
summary: "Get API status"
|
||||
security: [] # no auth required
|
||||
responses:
|
||||
"200":
|
||||
description: OK
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
type: object
|
||||
properties:
|
||||
status:
|
||||
type: string
|
||||
/register:
|
||||
post:
|
||||
summary: "Register user"
|
||||
security: [] # no auth required
|
||||
requestBody:
|
||||
required: true
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/UserCreate"
|
||||
responses:
|
||||
"201":
|
||||
description: Created
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/UserRead"
|
||||
"400":
|
||||
description: Bad Request (user exists)
|
||||
"500":
|
||||
description: Internal Server Error
|
||||
/login:
|
||||
post:
|
||||
summary: "Login"
|
||||
security: [] # no auth required
|
||||
requestBody:
|
||||
required: true
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/UserLogin"
|
||||
responses:
|
||||
"200":
|
||||
description: OK (sets session cookie)
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/SessionsBase"
|
||||
headers:
|
||||
Set-Cookie:
|
||||
description: session cookie
|
||||
schema:
|
||||
type: string
|
||||
"401":
|
||||
description: Unauthorized
|
||||
"500":
|
||||
description: Internal Server Error
|
||||
/exchanges:
|
||||
post:
|
||||
summary: "Create exchange"
|
||||
security:
|
||||
- session_cookie: []
|
||||
requestBody:
|
||||
required: true
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/ExchangesBase"
|
||||
responses:
|
||||
"201":
|
||||
description: Created
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/ExchangesRead"
|
||||
"400":
|
||||
description: Bad Request
|
||||
"401":
|
||||
description: Unauthorized
|
||||
get:
|
||||
summary: "List user exchanges"
|
||||
security:
|
||||
- session_cookie: []
|
||||
responses:
|
||||
"200":
|
||||
description: OK
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
type: array
|
||||
items:
|
||||
$ref: "#/components/schemas/ExchangesRead"
|
||||
"401":
|
||||
description: Unauthorized
|
||||
/exchanges/{exchange_id}:
|
||||
patch:
|
||||
summary: "Update exchange"
|
||||
security:
|
||||
- session_cookie: []
|
||||
parameters:
|
||||
- name: exchange_id
|
||||
in: path
|
||||
required: true
|
||||
schema:
|
||||
type: integer
|
||||
requestBody:
|
||||
required: true
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/ExchangesBase"
|
||||
responses:
|
||||
"200":
|
||||
description: Updated
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/ExchangesRead"
|
||||
"404":
|
||||
description: Not found
|
||||
"400":
|
||||
description: Bad request
|
||||
/cycles:
|
||||
post:
|
||||
summary: "Create cycle (currently returns 405 in code)"
|
||||
security:
|
||||
- session_cookie: []
|
||||
requestBody:
|
||||
required: true
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/CycleBase"
|
||||
responses:
|
||||
"405":
|
||||
description: Method not allowed (app currently returns 405)
|
||||
patch:
|
||||
summary: "Update cycle"
|
||||
security:
|
||||
- session_cookie: []
|
||||
requestBody:
|
||||
required: true
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/CycleUpdate"
|
||||
responses:
|
||||
"200":
|
||||
description: Updated
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/CycleRead"
|
||||
"400":
|
||||
description: Invalid data
|
||||
"404":
|
||||
description: Not found
|
||||
/cycles/{cycle_id}:
|
||||
get:
|
||||
summary: "Get cycle by id"
|
||||
security:
|
||||
- session_cookie: []
|
||||
parameters:
|
||||
- name: cycle_id
|
||||
in: path
|
||||
required: true
|
||||
schema:
|
||||
type: integer
|
||||
responses:
|
||||
"200":
|
||||
description: OK
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/CycleRead"
|
||||
"404":
|
||||
description: Not found
|
||||
/cycles/user/{user_id}:
|
||||
get:
|
||||
summary: "Get cycles by user id"
|
||||
security:
|
||||
- session_cookie: []
|
||||
parameters:
|
||||
- name: user_id
|
||||
in: path
|
||||
required: true
|
||||
schema:
|
||||
type: integer
|
||||
responses:
|
||||
"200":
|
||||
description: OK
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
type: array
|
||||
items:
|
||||
$ref: "#/components/schemas/CycleRead"
|
||||
/trades:
|
||||
post:
|
||||
summary: "Create trade"
|
||||
security:
|
||||
- session_cookie: []
|
||||
requestBody:
|
||||
required: true
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/TradeCreate"
|
||||
responses:
|
||||
"201":
|
||||
description: Created
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/TradeRead"
|
||||
"400":
|
||||
description: Invalid trade data
|
||||
"500":
|
||||
description: Internal Server Error
|
||||
/trades/{trade_id}:
|
||||
get:
|
||||
summary: "Get trade by id"
|
||||
security:
|
||||
- session_cookie: []
|
||||
parameters:
|
||||
- name: trade_id
|
||||
in: path
|
||||
required: true
|
||||
schema:
|
||||
type: integer
|
||||
responses:
|
||||
"200":
|
||||
description: OK
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/TradeRead"
|
||||
"404":
|
||||
description: Not found
|
||||
/trades/friendlyname:
|
||||
patch:
|
||||
summary: "Update trade friendly name"
|
||||
security:
|
||||
- session_cookie: []
|
||||
requestBody:
|
||||
required: true
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/TradeFriendlyNameUpdate"
|
||||
responses:
|
||||
"200":
|
||||
description: Updated
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/TradeRead"
|
||||
"404":
|
||||
description: Not found
|
||||
/trades/notes:
|
||||
patch:
|
||||
summary: "Update trade notes"
|
||||
security:
|
||||
- session_cookie: []
|
||||
requestBody:
|
||||
required: true
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/TradeNoteUpdate"
|
||||
responses:
|
||||
"200":
|
||||
description: Updated
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/TradeRead"
|
||||
"404":
|
||||
description: Not found
|
||||
@@ -24,3 +24,4 @@ ignore = [
|
||||
[lint.extend-per-file-ignores]
|
||||
"test*.py" = ["S101", "S105", "S106", "PT011", "PLR2004"]
|
||||
"models*.py" = ["FA102"]
|
||||
"dto.py" = ["TC001", "TC003"]
|
||||
|
||||
@@ -1,19 +1,405 @@
|
||||
from collections.abc import Generator
|
||||
from collections.abc import Callable
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from fastapi import FastAPI, status
|
||||
from fastapi.responses import JSONResponse
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
import settings
|
||||
from app import app
|
||||
import trading_journal.service as svc
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client() -> Generator[TestClient, None, None]:
|
||||
with TestClient(app) as client:
|
||||
yield client
|
||||
def client_factory(monkeypatch: pytest.MonkeyPatch) -> Callable[..., TestClient]:
|
||||
class NoAuth:
|
||||
def __init__(self, app: FastAPI, **opts) -> None: # noqa: ANN003, ARG002
|
||||
self.app = app
|
||||
|
||||
async def __call__(self, scope, receive, send) -> None: # noqa: ANN001
|
||||
state = scope.get("state")
|
||||
if state is None:
|
||||
scope["state"] = SimpleNamespace()
|
||||
scope["state"]["user_id"] = 1
|
||||
await self.app(scope, receive, send)
|
||||
|
||||
class DeclineAuth:
|
||||
def __init__(self, app: FastAPI, **opts) -> None: # noqa: ANN003, ARG002
|
||||
self.app = app
|
||||
|
||||
async def __call__(self, scope, receive, send) -> None: # noqa: ANN001
|
||||
if scope.get("type") != "http":
|
||||
await self.app(scope, receive, send)
|
||||
return
|
||||
path = scope.get("path", "")
|
||||
# allow public/exempt paths through
|
||||
if getattr(svc, "EXCEPT_PATHS", []) and path in svc.EXCEPT_PATHS:
|
||||
await self.app(scope, receive, send)
|
||||
return
|
||||
# immediately respond 401 for protected paths
|
||||
resp = JSONResponse({"detail": "Unauthorized"}, status_code=status.HTTP_401_UNAUTHORIZED)
|
||||
await resp(scope, receive, send)
|
||||
|
||||
def _factory(*, decline_auth: bool = False, **mocks: dict) -> TestClient:
|
||||
defaults = {
|
||||
"register_user_service": MagicMock(return_value=SimpleNamespace(model_dump=lambda: {"id": 1, "username": "mock"})),
|
||||
"authenticate_user_service": MagicMock(
|
||||
return_value=(SimpleNamespace(user_id=1, expires_at=(datetime.now(timezone.utc) + timedelta(hours=1))), "token"),
|
||||
),
|
||||
"create_exchange_service": MagicMock(
|
||||
return_value=SimpleNamespace(model_dump=lambda: {"name": "Binance", "notes": "some note", "user_id": 1}),
|
||||
),
|
||||
"get_exchanges_by_user_service": MagicMock(return_value=[]),
|
||||
}
|
||||
|
||||
if decline_auth:
|
||||
monkeypatch.setattr(svc, "AuthMiddleWare", DeclineAuth)
|
||||
else:
|
||||
monkeypatch.setattr(svc, "AuthMiddleWare", NoAuth)
|
||||
merged = {**defaults, **mocks}
|
||||
for name, mock in merged.items():
|
||||
monkeypatch.setattr(svc, name, mock)
|
||||
import sys
|
||||
|
||||
if "app" in sys.modules:
|
||||
del sys.modules["app"]
|
||||
from importlib import import_module
|
||||
|
||||
app = import_module("app").app # re-import app module
|
||||
|
||||
return TestClient(app)
|
||||
|
||||
return _factory
|
||||
|
||||
|
||||
def test_get_status(client: TestClient) -> None:
|
||||
response = client.get(f"{settings.settings.api_base}/status")
|
||||
def test_get_status(client_factory: Callable[..., TestClient]) -> None:
|
||||
client = client_factory()
|
||||
with client as c:
|
||||
response = c.get(f"{settings.settings.api_base}/status")
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"status": "ok"}
|
||||
|
||||
|
||||
def test_register_success(client_factory: Callable[..., TestClient]) -> None:
|
||||
client = client_factory() # use defaults
|
||||
with client as c:
|
||||
r = c.post(f"{settings.settings.api_base}/register", json={"username": "a", "password": "b"})
|
||||
assert r.status_code == 201
|
||||
|
||||
|
||||
def test_register_user_already_exists(client_factory: Callable[..., TestClient]) -> None:
|
||||
client = client_factory(register_user_service=MagicMock(side_effect=svc.UserAlreadyExistsError("username already exists")))
|
||||
with client as c:
|
||||
r = c.post(f"{settings.settings.api_base}/register", json={"username": "a", "password": "b"})
|
||||
assert r.status_code == status.HTTP_400_BAD_REQUEST
|
||||
assert r.json() == {"detail": "username already exists"}
|
||||
|
||||
|
||||
def test_register_user_internal_server_error(client_factory: Callable[..., TestClient]) -> None:
|
||||
client = client_factory(register_user_service=MagicMock(side_effect=Exception("db is down")))
|
||||
with client as c:
|
||||
r = c.post(f"{settings.settings.api_base}/register", json={"username": "a", "password": "b"})
|
||||
assert r.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
|
||||
assert r.json() == {"detail": "Internal Server Error"}
|
||||
|
||||
|
||||
def test_login_success(client_factory: Callable[..., TestClient]) -> None:
|
||||
client = client_factory() # use defaults
|
||||
with client as c:
|
||||
r = c.post(f"{settings.settings.api_base}/login", json={"username": "a", "password": "b"})
|
||||
assert r.status_code == 200
|
||||
assert r.json() == {"user_id": 1}
|
||||
assert r.cookies.get("session_token") == "token"
|
||||
|
||||
|
||||
def test_login_failed_auth(client_factory: Callable[..., TestClient]) -> None:
|
||||
client = client_factory(authenticate_user_service=MagicMock(return_value=None))
|
||||
with client as c:
|
||||
r = c.post(f"{settings.settings.api_base}/login", json={"username": "a", "password": "b"})
|
||||
assert r.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
assert r.json() == {"detail": "Invalid username or password, or user doesn't exist"}
|
||||
|
||||
|
||||
def test_login_internal_server_error(client_factory: Callable[..., TestClient]) -> None:
|
||||
client = client_factory(authenticate_user_service=MagicMock(side_effect=Exception("db is down")))
|
||||
with client as c:
|
||||
r = c.post(f"{settings.settings.api_base}/login", json={"username": "a", "password": "b"})
|
||||
assert r.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
|
||||
assert r.json() == {"detail": "Internal Server Error"}
|
||||
|
||||
|
||||
def test_create_exchange_success(client_factory: Callable[..., TestClient]) -> None:
|
||||
client = client_factory()
|
||||
with client as c:
|
||||
r = c.post(f"{settings.settings.api_base}/exchanges", json={"name": "Binance"})
|
||||
assert r.status_code == 201
|
||||
assert r.json() == {"user_id": 1, "name": "Binance", "notes": "some note"}
|
||||
|
||||
|
||||
def test_create_exchange_already_exists(client_factory: Callable[..., TestClient]) -> None:
|
||||
client = client_factory(create_exchange_service=MagicMock(side_effect=svc.ExchangeAlreadyExistsError("exchange already exists")))
|
||||
with client as c:
|
||||
r = c.post(f"{settings.settings.api_base}/exchanges", json={"name": "Binance"})
|
||||
assert r.status_code == status.HTTP_400_BAD_REQUEST
|
||||
assert r.json() == {"detail": "exchange already exists"}
|
||||
|
||||
|
||||
def test_get_exchanges_unauthenticated(client_factory: Callable[..., TestClient]) -> None:
|
||||
client = client_factory(decline_auth=True)
|
||||
with client as c:
|
||||
r = c.get(f"{settings.settings.api_base}/exchanges")
|
||||
assert r.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
assert r.json() == {"detail": "Unauthorized"}
|
||||
|
||||
|
||||
def test_get_exchanges_success(client_factory: Callable[..., TestClient]) -> None:
|
||||
client = client_factory()
|
||||
with client as c:
|
||||
r = c.get(f"{settings.settings.api_base}/exchanges")
|
||||
assert r.status_code == 200
|
||||
assert r.json() == []
|
||||
|
||||
|
||||
def test_update_exchanges_success(client_factory: Callable[..., TestClient]) -> None:
|
||||
client = client_factory(
|
||||
update_exchanges_service=MagicMock(
|
||||
return_value=SimpleNamespace(model_dump=lambda: {"user_id": 1, "name": "BinanceUS", "notes": "updated note"}),
|
||||
),
|
||||
)
|
||||
with client as c:
|
||||
r = c.patch(f"{settings.settings.api_base}/exchanges/1", json={"name": "BinanceUS", "notes": "updated note"})
|
||||
assert r.status_code == 200
|
||||
assert r.json() == {"user_id": 1, "name": "BinanceUS", "notes": "updated note"}
|
||||
|
||||
|
||||
def test_update_exchanges_not_found(client_factory: Callable[..., TestClient]) -> None:
|
||||
client = client_factory(update_exchanges_service=MagicMock(side_effect=svc.ExchangeNotFoundError("exchange not found")))
|
||||
with client as c:
|
||||
r = c.patch(f"{settings.settings.api_base}/exchanges/999", json={"name": "NonExistent", "notes": "no note"})
|
||||
assert r.status_code == status.HTTP_404_NOT_FOUND
|
||||
assert r.json() == {"detail": "exchange not found"}
|
||||
|
||||
|
||||
def test_get_cycles_by_id_success(client_factory: Callable[..., TestClient]) -> None:
|
||||
client = client_factory(
|
||||
get_cycle_by_id_service=MagicMock(
|
||||
return_value=SimpleNamespace(
|
||||
friendly_name="Cycle 1",
|
||||
status="active",
|
||||
id=1,
|
||||
),
|
||||
),
|
||||
)
|
||||
with client as c:
|
||||
r = c.get(f"{settings.settings.api_base}/cycles/1")
|
||||
assert r.status_code == 200
|
||||
assert r.json() == {"id": 1, "friendly_name": "Cycle 1", "status": "active"}
|
||||
|
||||
|
||||
def test_get_cycles_by_id_not_found(client_factory: Callable[..., TestClient]) -> None:
|
||||
client = client_factory(get_cycle_by_id_service=MagicMock(side_effect=svc.CycleNotFoundError("cycle not found")))
|
||||
with client as c:
|
||||
r = c.get(f"{settings.settings.api_base}/cycles/999")
|
||||
assert r.status_code == status.HTTP_404_NOT_FOUND
|
||||
assert r.json() == {"detail": "cycle not found"}
|
||||
|
||||
|
||||
def test_get_cycles_by_user_success(client_factory: Callable[..., TestClient]) -> None:
|
||||
client = client_factory(
|
||||
get_cycles_by_user_service=MagicMock(
|
||||
return_value=[
|
||||
SimpleNamespace(
|
||||
friendly_name="Cycle 1",
|
||||
status="active",
|
||||
id=1,
|
||||
),
|
||||
SimpleNamespace(
|
||||
friendly_name="Cycle 2",
|
||||
status="completed",
|
||||
id=2,
|
||||
),
|
||||
],
|
||||
),
|
||||
)
|
||||
with client as c:
|
||||
r = c.get(f"{settings.settings.api_base}/cycles/user/1")
|
||||
assert r.status_code == 200
|
||||
assert r.json() == [
|
||||
{"id": 1, "friendly_name": "Cycle 1", "status": "active"},
|
||||
{"id": 2, "friendly_name": "Cycle 2", "status": "completed"},
|
||||
]
|
||||
|
||||
|
||||
def test_update_cycles_success(client_factory: Callable[..., TestClient]) -> None:
|
||||
client = client_factory(
|
||||
update_cycle_service=MagicMock(
|
||||
return_value=SimpleNamespace(
|
||||
friendly_name="Updated Cycle",
|
||||
status="completed",
|
||||
id=1,
|
||||
),
|
||||
),
|
||||
)
|
||||
with client as c:
|
||||
r = c.patch(f"{settings.settings.api_base}/cycles", json={"friendly_name": "Updated Cycle", "status": "completed", "id": 1})
|
||||
assert r.status_code == 200
|
||||
assert r.json() == {"id": 1, "friendly_name": "Updated Cycle", "status": "completed"}
|
||||
|
||||
|
||||
def test_update_cycles_invalid_cycle_data(client_factory: Callable[..., TestClient]) -> None:
|
||||
client = client_factory(
|
||||
update_cycle_service=MagicMock(side_effect=svc.InvalidCycleDataError("invalid cycle data")),
|
||||
)
|
||||
with client as c:
|
||||
r = c.patch(f"{settings.settings.api_base}/cycles", json={"friendly_name": "", "status": "unknown", "id": 1})
|
||||
assert r.status_code == status.HTTP_400_BAD_REQUEST
|
||||
assert r.json() == {"detail": "invalid cycle data"}
|
||||
|
||||
|
||||
def test_update_cycles_not_found(client_factory: Callable[..., TestClient]) -> None:
|
||||
client = client_factory(update_cycle_service=MagicMock(side_effect=svc.CycleNotFoundError("cycle not found")))
|
||||
with client as c:
|
||||
r = c.patch(f"{settings.settings.api_base}/cycles", json={"friendly_name": "NonExistent", "status": "active", "id": 999})
|
||||
assert r.status_code == status.HTTP_404_NOT_FOUND
|
||||
assert r.json() == {"detail": "cycle not found"}
|
||||
|
||||
|
||||
def test_create_trade_success(client_factory: Callable[..., TestClient]) -> None:
|
||||
client = client_factory(
|
||||
create_trade_service=MagicMock(
|
||||
return_value=SimpleNamespace(),
|
||||
),
|
||||
)
|
||||
with client as c:
|
||||
r = c.post(
|
||||
f"{settings.settings.api_base}/trades",
|
||||
json={
|
||||
"cycle_id": 1,
|
||||
"exchange_id": 1,
|
||||
"symbol": "BTCUSD",
|
||||
"underlying_currency": "USD",
|
||||
"trade_type": "LONG_SPOT",
|
||||
"trade_strategy": "FX",
|
||||
"quantity": 1,
|
||||
"price_cents": 15,
|
||||
"commission_cents": 100,
|
||||
"trade_date": "2025-10-01",
|
||||
},
|
||||
)
|
||||
assert r.status_code == 201
|
||||
|
||||
|
||||
def test_create_trade_invalid_trade_data(client_factory: Callable[..., TestClient]) -> None:
|
||||
client = client_factory(
|
||||
create_trade_service=MagicMock(side_effect=svc.InvalidTradeDataError("invalid trade data")),
|
||||
)
|
||||
with client as c:
|
||||
r = c.post(
|
||||
f"{settings.settings.api_base}/trades",
|
||||
json={
|
||||
"cycle_id": 1,
|
||||
"exchange_id": 1,
|
||||
"symbol": "BTCUSD",
|
||||
"underlying_currency": "USD",
|
||||
"trade_type": "LONG_SPOT",
|
||||
"trade_strategy": "FX",
|
||||
"quantity": 1,
|
||||
"price_cents": 15,
|
||||
"commission_cents": 100,
|
||||
"trade_date": "2025-10-01",
|
||||
},
|
||||
)
|
||||
assert r.status_code == status.HTTP_400_BAD_REQUEST
|
||||
assert r.json() == {"detail": "invalid trade data"}
|
||||
|
||||
|
||||
def test_get_trade_by_id_success(client_factory: Callable[..., TestClient]) -> None:
|
||||
client = client_factory(
|
||||
get_trade_by_id_service=MagicMock(
|
||||
return_value=SimpleNamespace(
|
||||
id=1,
|
||||
cycle_id=1,
|
||||
exchange_id=1,
|
||||
symbol="BTCUSD",
|
||||
underlying_currency="USD",
|
||||
trade_type="LONG_SPOT",
|
||||
trade_strategy="FX",
|
||||
quantity=1,
|
||||
price_cents=1500,
|
||||
commission_cents=100,
|
||||
trade_date=datetime(2025, 10, 1, tzinfo=timezone.utc),
|
||||
),
|
||||
),
|
||||
)
|
||||
with client as c:
|
||||
r = c.get(f"{settings.settings.api_base}/trades/1")
|
||||
assert r.status_code == 200
|
||||
assert r.json() == {
|
||||
"id": 1,
|
||||
"cycle_id": 1,
|
||||
"exchange_id": 1,
|
||||
"symbol": "BTCUSD",
|
||||
"underlying_currency": "USD",
|
||||
"trade_type": "LONG_SPOT",
|
||||
"trade_strategy": "FX",
|
||||
"quantity": 1,
|
||||
"price_cents": 1500,
|
||||
"commission_cents": 100,
|
||||
"trade_date": "2025-10-01T00:00:00+00:00",
|
||||
}
|
||||
|
||||
|
||||
def test_get_trade_by_id_not_found(client_factory: Callable[..., TestClient]) -> None:
|
||||
client = client_factory(get_trade_by_id_service=MagicMock(side_effect=svc.TradeNotFoundError("trade not found")))
|
||||
with client as c:
|
||||
r = c.get(f"{settings.settings.api_base}/trades/999")
|
||||
assert r.status_code == status.HTTP_404_NOT_FOUND
|
||||
assert r.json() == {"detail": "trade not found"}
|
||||
|
||||
|
||||
def test_update_trade_friendly_name_success(client_factory: Callable[..., TestClient]) -> None:
|
||||
client = client_factory(
|
||||
update_trade_friendly_name_service=MagicMock(
|
||||
return_value=SimpleNamespace(
|
||||
id=1,
|
||||
friendly_name="Updated Trade Name",
|
||||
),
|
||||
),
|
||||
)
|
||||
with client as c:
|
||||
r = c.patch(f"{settings.settings.api_base}/trades/friendlyname", json={"id": 1, "friendly_name": "Updated Trade Name"})
|
||||
assert r.status_code == 200
|
||||
assert r.json() == {"id": 1, "friendly_name": "Updated Trade Name"}
|
||||
|
||||
|
||||
def test_update_trade_friendly_name_not_found(client_factory: Callable[..., TestClient]) -> None:
|
||||
client = client_factory(update_trade_friendly_name_service=MagicMock(side_effect=svc.TradeNotFoundError("trade not found")))
|
||||
with client as c:
|
||||
r = c.patch(f"{settings.settings.api_base}/trades/friendlyname", json={"id": 999, "friendly_name": "NonExistent Trade"})
|
||||
assert r.status_code == status.HTTP_404_NOT_FOUND
|
||||
assert r.json() == {"detail": "trade not found"}
|
||||
|
||||
|
||||
def test_update_trade_note_success(client_factory: Callable[..., TestClient]) -> None:
|
||||
client = client_factory(
|
||||
update_trade_note_service=MagicMock(
|
||||
return_value=SimpleNamespace(
|
||||
id=1,
|
||||
note="Updated trade note",
|
||||
),
|
||||
),
|
||||
)
|
||||
with client as c:
|
||||
r = c.patch(f"{settings.settings.api_base}/trades/notes", json={"id": 1, "note": "Updated trade note"})
|
||||
assert r.status_code == 200
|
||||
assert r.json() == {"id": 1, "note": "Updated trade note"}
|
||||
|
||||
|
||||
def test_update_trade_note_not_found(client_factory: Callable[..., TestClient]) -> None:
|
||||
client = client_factory(update_trade_note_service=MagicMock(side_effect=svc.TradeNotFoundError("trade not found")))
|
||||
with client as c:
|
||||
r = c.patch(f"{settings.settings.api_base}/trades/notes", json={"id": 999, "note": "NonExistent Trade Note"})
|
||||
assert r.status_code == status.HTTP_404_NOT_FOUND
|
||||
assert r.json() == {"detail": "trade not found"}
|
||||
|
||||
@@ -56,7 +56,9 @@ def make_exchange(session: Session, user_id: int, name: str = "NASDAQ") -> int:
|
||||
return cast("int", exchange.id)
|
||||
|
||||
|
||||
def make_cycle(session: Session, user_id: int, exchange_id: int, friendly_name: str = "Test Cycle") -> int:
|
||||
def make_cycle(
|
||||
session: Session, user_id: int, exchange_id: int, friendly_name: str = "Test Cycle"
|
||||
) -> int:
|
||||
cycle = models.Cycles(
|
||||
user_id=user_id,
|
||||
friendly_name=friendly_name,
|
||||
@@ -72,7 +74,9 @@ def make_cycle(session: Session, user_id: int, exchange_id: int, friendly_name:
|
||||
return cast("int", cycle.id)
|
||||
|
||||
|
||||
def make_trade(session: Session, user_id: int, cycle_id: int, friendly_name: str = "Test Trade") -> int:
|
||||
def make_trade(
|
||||
session: Session, user_id: int, cycle_id: int, friendly_name: str = "Test Trade"
|
||||
) -> int:
|
||||
cycle: models.Cycles | None = session.get(models.Cycles, cycle_id)
|
||||
assert cycle is not None
|
||||
exchange_id = cycle.exchange_id
|
||||
@@ -137,13 +141,17 @@ def _ensure_utc_aware(dt: datetime | None) -> datetime | None:
|
||||
return dt.astimezone(timezone.utc)
|
||||
|
||||
|
||||
def _validate_timestamp(actual: datetime, expected: datetime, tolerance: timedelta) -> None:
|
||||
def _validate_timestamp(
|
||||
actual: datetime, expected: datetime, tolerance: timedelta
|
||||
) -> None:
|
||||
actual_utc = _ensure_utc_aware(actual)
|
||||
expected_utc = _ensure_utc_aware(expected)
|
||||
assert actual_utc is not None
|
||||
assert expected_utc is not None
|
||||
delta = abs(actual_utc - expected_utc)
|
||||
assert delta <= tolerance, f"Timestamps differ by {delta}, which exceeds tolerance of {tolerance}"
|
||||
assert delta <= tolerance, (
|
||||
f"Timestamps differ by {delta}, which exceeds tolerance of {tolerance}"
|
||||
)
|
||||
|
||||
|
||||
# Trades
|
||||
@@ -470,7 +478,9 @@ def test_update_trade_friendly_name(session: Session) -> None:
|
||||
trade_id = make_trade(session, user_id, cycle_id)
|
||||
|
||||
new_friendly_name = "Updated Trade Name"
|
||||
updated_trade = crud.update_trade_friendly_name(session, trade_id, new_friendly_name)
|
||||
updated_trade = crud.update_trade_friendly_name(
|
||||
session, trade_id, new_friendly_name
|
||||
)
|
||||
assert updated_trade is not None
|
||||
assert updated_trade.id == trade_id
|
||||
assert updated_trade.friendly_name == new_friendly_name
|
||||
@@ -624,7 +634,9 @@ def test_get_cycles_by_user_id(session: Session) -> None:
|
||||
def test_update_cycle(session: Session) -> None:
|
||||
user_id = make_user(session)
|
||||
exchange_id = make_exchange(session, user_id)
|
||||
cycle_id = make_cycle(session, user_id, exchange_id, friendly_name="Initial Cycle Name")
|
||||
cycle_id = make_cycle(
|
||||
session, user_id, exchange_id, friendly_name="Initial Cycle Name"
|
||||
)
|
||||
|
||||
update_data = {
|
||||
"friendly_name": "Updated Cycle Name",
|
||||
@@ -646,14 +658,20 @@ def test_update_cycle(session: Session) -> None:
|
||||
def test_update_cycle_immutable_fields(session: Session) -> None:
|
||||
user_id = make_user(session)
|
||||
exchange_id = make_exchange(session, user_id)
|
||||
cycle_id = make_cycle(session, user_id, exchange_id, friendly_name="Initial Cycle Name")
|
||||
cycle_id = make_cycle(
|
||||
session, user_id, exchange_id, friendly_name="Initial Cycle Name"
|
||||
)
|
||||
|
||||
# Attempt to update immutable fields
|
||||
update_data = {
|
||||
"id": cycle_id + 1, # Trying to change the ID
|
||||
"user_id": user_id + 1, # Trying to change the user_id
|
||||
"start_date": datetime(2020, 1, 1, tzinfo=timezone.utc).date(), # Trying to change start_date
|
||||
"created_at": datetime(2020, 1, 1, tzinfo=timezone.utc), # Trying to change created_at
|
||||
"start_date": datetime(
|
||||
2020, 1, 1, tzinfo=timezone.utc
|
||||
).date(), # Trying to change start_date
|
||||
"created_at": datetime(
|
||||
2020, 1, 1, tzinfo=timezone.utc
|
||||
), # Trying to change created_at
|
||||
"friendly_name": "Valid Update", # Valid field to update
|
||||
}
|
||||
|
||||
@@ -685,7 +703,10 @@ def test_create_cycle_loan_event(session: Session) -> None:
|
||||
assert loan_event.id is not None
|
||||
assert loan_event.cycle_id == cycle_id
|
||||
assert loan_event.loan_amount_cents == loan_data["loan_amount_cents"]
|
||||
assert loan_event.loan_interest_rate_tenth_bps == loan_data["loan_interest_rate_tenth_bps"]
|
||||
assert (
|
||||
loan_event.loan_interest_rate_tenth_bps
|
||||
== loan_data["loan_interest_rate_tenth_bps"]
|
||||
)
|
||||
assert loan_event.notes == loan_data["notes"]
|
||||
assert loan_event.effective_date == now.date()
|
||||
_validate_timestamp(loan_event.created_at, now, timedelta(seconds=1))
|
||||
@@ -695,12 +716,41 @@ def test_create_cycle_loan_event(session: Session) -> None:
|
||||
assert actual_loan_event is not None
|
||||
assert actual_loan_event.cycle_id == cycle_id
|
||||
assert actual_loan_event.loan_amount_cents == loan_data["loan_amount_cents"]
|
||||
assert actual_loan_event.loan_interest_rate_tenth_bps == loan_data["loan_interest_rate_tenth_bps"]
|
||||
assert (
|
||||
actual_loan_event.loan_interest_rate_tenth_bps
|
||||
== loan_data["loan_interest_rate_tenth_bps"]
|
||||
)
|
||||
assert actual_loan_event.notes == loan_data["notes"]
|
||||
assert actual_loan_event.effective_date == now.date()
|
||||
_validate_timestamp(actual_loan_event.created_at, now, timedelta(seconds=1))
|
||||
|
||||
|
||||
def test_create_cycle_loan_event_same_date_error(session: Session) -> None:
|
||||
user_id = make_user(session)
|
||||
exchange_id = make_exchange(session, user_id)
|
||||
cycle_id = make_cycle(session, user_id, exchange_id)
|
||||
|
||||
loan_data_1 = {
|
||||
"cycle_id": cycle_id,
|
||||
"loan_amount_cents": 100000,
|
||||
"loan_interest_rate_tenth_bps": 5000,
|
||||
"effective_date": datetime(2023, 1, 1, tzinfo=timezone.utc).date(),
|
||||
"notes": "First loan event",
|
||||
}
|
||||
loan_data_2 = {
|
||||
"cycle_id": cycle_id,
|
||||
"loan_amount_cents": 150000,
|
||||
"loan_interest_rate_tenth_bps": 4500,
|
||||
"effective_date": datetime(2023, 1, 1, tzinfo=timezone.utc).date(),
|
||||
"notes": "Second loan event same date",
|
||||
}
|
||||
|
||||
crud.create_cycle_loan_event(session, loan_data_1)
|
||||
with pytest.raises(ValueError) as excinfo:
|
||||
crud.create_cycle_loan_event(session, loan_data_2)
|
||||
assert "create_cycle_loan_event integrity error" in str(excinfo.value)
|
||||
|
||||
|
||||
def test_get_cycle_loan_events_by_cycle_id(session: Session) -> None:
|
||||
user_id = make_user(session)
|
||||
exchange_id = make_exchange(session, user_id)
|
||||
@@ -729,34 +779,77 @@ def test_get_cycle_loan_events_by_cycle_id(session: Session) -> None:
|
||||
notes = [event.notes for event in loan_events]
|
||||
assert loan_events[0].notes == loan_data_2["notes"]
|
||||
assert loan_events[0].effective_date == yesterday
|
||||
assert notes == ["Second loan event", "First loan event"] # Ordered by effective_date desc
|
||||
assert notes == [
|
||||
"Second loan event",
|
||||
"First loan event",
|
||||
] # Ordered by effective_date desc
|
||||
|
||||
|
||||
def test_get_cycle_loan_events_by_cycle_id_same_date(session: Session) -> None:
|
||||
def test_get_cycle_loan_event_by_cycle_id_and_effective_date(session: Session) -> None:
|
||||
user_id = make_user(session)
|
||||
exchange_id = make_exchange(session, user_id)
|
||||
cycle_id = make_cycle(session, user_id, exchange_id)
|
||||
|
||||
loan_data_1 = {
|
||||
effective_date = datetime(2023, 1, 1, tzinfo=timezone.utc).date()
|
||||
loan_data = {
|
||||
"cycle_id": cycle_id,
|
||||
"loan_amount_cents": 100000,
|
||||
"loan_interest_rate_tenth_bps": 5000,
|
||||
"notes": "First loan event",
|
||||
"effective_date": effective_date,
|
||||
"notes": "Loan event for specific date",
|
||||
}
|
||||
loan_data_2 = {
|
||||
|
||||
crud.create_cycle_loan_event(session, loan_data)
|
||||
loan_event = crud.get_loan_event_by_cycle_id_and_effective_date(
|
||||
session, cycle_id, effective_date
|
||||
)
|
||||
assert loan_event is not None
|
||||
assert loan_event.cycle_id == cycle_id
|
||||
assert loan_event.effective_date == effective_date
|
||||
assert loan_event.notes == loan_data["notes"]
|
||||
|
||||
|
||||
def test_update_cycle_loan_event(session: Session) -> None:
|
||||
user_id = make_user(session)
|
||||
exchange_id = make_exchange(session, user_id)
|
||||
cycle_id = make_cycle(session, user_id, exchange_id)
|
||||
|
||||
loan_data = {
|
||||
"cycle_id": cycle_id,
|
||||
"loan_amount_cents": 150000,
|
||||
"loan_interest_rate_tenth_bps": 4500,
|
||||
"notes": "Second loan event",
|
||||
"loan_amount_cents": 100000,
|
||||
"loan_interest_rate_tenth_bps": 5000,
|
||||
"notes": "Initial loan event",
|
||||
}
|
||||
|
||||
crud.create_cycle_loan_event(session, loan_data_1)
|
||||
crud.create_cycle_loan_event(session, loan_data_2)
|
||||
loan_event = crud.create_cycle_loan_event(session, loan_data)
|
||||
assert loan_event is not None
|
||||
|
||||
loan_events = crud.get_loan_events_by_cycle_id(session, cycle_id)
|
||||
assert len(loan_events) == 2
|
||||
notes = [event.notes for event in loan_events]
|
||||
assert notes == ["First loan event", "Second loan event"] # Ordered by id desc when effective_date is same
|
||||
update_data = {
|
||||
"loan_amount_cents": 120000,
|
||||
"loan_interest_rate_tenth_bps": 4500,
|
||||
"notes": "Updated loan event",
|
||||
}
|
||||
event_id = loan_event.id or 0
|
||||
|
||||
updated_loan_event = crud.update_cycle_loan_event(session, event_id, update_data)
|
||||
assert updated_loan_event is not None
|
||||
assert updated_loan_event.id == loan_event.id
|
||||
assert updated_loan_event.loan_amount_cents == update_data["loan_amount_cents"]
|
||||
assert (
|
||||
updated_loan_event.loan_interest_rate_tenth_bps
|
||||
== update_data["loan_interest_rate_tenth_bps"]
|
||||
)
|
||||
assert updated_loan_event.notes == update_data["notes"]
|
||||
|
||||
session.refresh(updated_loan_event)
|
||||
actual_loan_event = session.get(models.CycleLoanChangeEvents, loan_event.id)
|
||||
assert actual_loan_event is not None
|
||||
assert actual_loan_event.loan_amount_cents == update_data["loan_amount_cents"]
|
||||
assert (
|
||||
actual_loan_event.loan_interest_rate_tenth_bps
|
||||
== update_data["loan_interest_rate_tenth_bps"]
|
||||
)
|
||||
assert actual_loan_event.notes == update_data["notes"]
|
||||
|
||||
|
||||
def test_create_cycle_loan_event_single_field(session: Session) -> None:
|
||||
@@ -802,7 +895,12 @@ def test_create_cycle_daily_accrual(session: Session) -> None:
|
||||
"notes": "Daily interest accrual",
|
||||
}
|
||||
|
||||
accrual = crud.create_cycle_daily_accrual(session, cycle_id, accrual_data["accrual_date"], accrual_data["accrued_interest_cents"])
|
||||
accrual = crud.create_cycle_daily_accrual(
|
||||
session,
|
||||
cycle_id,
|
||||
accrual_data["accrual_date"],
|
||||
accrual_data["accrued_interest_cents"],
|
||||
)
|
||||
assert accrual.id is not None
|
||||
assert accrual.cycle_id == cycle_id
|
||||
assert accrual.accrual_date == accrual_data["accrual_date"]
|
||||
@@ -835,8 +933,18 @@ def test_get_cycle_daily_accruals_by_cycle_id(session: Session) -> None:
|
||||
"accrued_interest_cents": 150,
|
||||
}
|
||||
|
||||
crud.create_cycle_daily_accrual(session, cycle_id, accrual_data_1["accrual_date"], accrual_data_1["accrued_interest_cents"])
|
||||
crud.create_cycle_daily_accrual(session, cycle_id, accrual_data_2["accrual_date"], accrual_data_2["accrued_interest_cents"])
|
||||
crud.create_cycle_daily_accrual(
|
||||
session,
|
||||
cycle_id,
|
||||
accrual_data_1["accrual_date"],
|
||||
accrual_data_1["accrued_interest_cents"],
|
||||
)
|
||||
crud.create_cycle_daily_accrual(
|
||||
session,
|
||||
cycle_id,
|
||||
accrual_data_2["accrual_date"],
|
||||
accrual_data_2["accrued_interest_cents"],
|
||||
)
|
||||
|
||||
accruals = crud.get_cycle_daily_accruals_by_cycle_id(session, cycle_id)
|
||||
assert len(accruals) == 2
|
||||
@@ -863,18 +971,37 @@ def test_get_cycle_daily_accruals_by_cycle_id_and_date(session: Session) -> None
|
||||
"accrued_interest_cents": 150,
|
||||
}
|
||||
|
||||
crud.create_cycle_daily_accrual(session, cycle_id, accrual_data_1["accrual_date"], accrual_data_1["accrued_interest_cents"])
|
||||
crud.create_cycle_daily_accrual(session, cycle_id, accrual_data_2["accrual_date"], accrual_data_2["accrued_interest_cents"])
|
||||
crud.create_cycle_daily_accrual(
|
||||
session,
|
||||
cycle_id,
|
||||
accrual_data_1["accrual_date"],
|
||||
accrual_data_1["accrued_interest_cents"],
|
||||
)
|
||||
crud.create_cycle_daily_accrual(
|
||||
session,
|
||||
cycle_id,
|
||||
accrual_data_2["accrual_date"],
|
||||
accrual_data_2["accrued_interest_cents"],
|
||||
)
|
||||
|
||||
accruals_today = crud.get_cycle_daily_accrual_by_cycle_id_and_date(session, cycle_id, today)
|
||||
accruals_today = crud.get_cycle_daily_accrual_by_cycle_id_and_date(
|
||||
session, cycle_id, today
|
||||
)
|
||||
assert accruals_today is not None
|
||||
assert accruals_today.accrual_date == today
|
||||
assert accruals_today.accrual_amount_cents == accrual_data_2["accrued_interest_cents"]
|
||||
assert (
|
||||
accruals_today.accrual_amount_cents == accrual_data_2["accrued_interest_cents"]
|
||||
)
|
||||
|
||||
accruals_yesterday = crud.get_cycle_daily_accrual_by_cycle_id_and_date(session, cycle_id, yesterday)
|
||||
accruals_yesterday = crud.get_cycle_daily_accrual_by_cycle_id_and_date(
|
||||
session, cycle_id, yesterday
|
||||
)
|
||||
assert accruals_yesterday is not None
|
||||
assert accruals_yesterday.accrual_date == yesterday
|
||||
assert accruals_yesterday.accrual_amount_cents == accrual_data_1["accrued_interest_cents"]
|
||||
assert (
|
||||
accruals_yesterday.accrual_amount_cents
|
||||
== accrual_data_1["accrued_interest_cents"]
|
||||
)
|
||||
|
||||
|
||||
# Exchanges
|
||||
@@ -1031,7 +1158,9 @@ def test_update_user_immutable_fields(session: Session) -> None:
|
||||
update_data = {
|
||||
"id": user_id + 1, # Trying to change the ID
|
||||
"username": "newusername", # Trying to change the username
|
||||
"created_at": datetime(2020, 1, 1, tzinfo=timezone.utc), # Trying to change created_at
|
||||
"created_at": datetime(
|
||||
2020, 1, 1, tzinfo=timezone.utc
|
||||
), # Trying to change created_at
|
||||
"password_hash": "validupdate", # Valid field to update
|
||||
}
|
||||
|
||||
@@ -1065,7 +1194,9 @@ def test_create_login_session_with_invalid_user(session: Session) -> None:
|
||||
def test_get_login_session_by_token_and_user_id(session: Session) -> None:
|
||||
now = datetime.now(timezone.utc)
|
||||
created_session = make_login_session(session, now)
|
||||
fetched_session = crud.get_login_session_by_token_hash_and_user_id(session, created_session.session_token_hash, created_session.user_id)
|
||||
fetched_session = crud.get_login_session_by_token_hash_and_user_id(
|
||||
session, created_session.session_token_hash, created_session.user_id
|
||||
)
|
||||
assert fetched_session is not None
|
||||
assert fetched_session.id == created_session.id
|
||||
assert fetched_session.user_id == created_session.user_id
|
||||
@@ -1075,7 +1206,9 @@ def test_get_login_session_by_token_and_user_id(session: Session) -> None:
|
||||
def test_get_login_session_by_token(session: Session) -> None:
|
||||
now = datetime.now(timezone.utc)
|
||||
created_session = make_login_session(session, now)
|
||||
fetched_session = crud.get_login_session_by_token_hash(session, created_session.session_token_hash)
|
||||
fetched_session = crud.get_login_session_by_token_hash(
|
||||
session, created_session.session_token_hash
|
||||
)
|
||||
assert fetched_session is not None
|
||||
assert fetched_session.id == created_session.id
|
||||
assert fetched_session.user_id == created_session.user_id
|
||||
@@ -1090,9 +1223,13 @@ def test_update_login_session(session: Session) -> None:
|
||||
"last_seen_at": now + timedelta(hours=1),
|
||||
"last_used_ip": "192.168.1.1",
|
||||
}
|
||||
updated_session = crud.update_login_session(session, created_session.session_token_hash, update_data)
|
||||
updated_session = crud.update_login_session(
|
||||
session, created_session.session_token_hash, update_data
|
||||
)
|
||||
assert updated_session is not None
|
||||
assert _ensure_utc_aware(updated_session.last_seen_at) == update_data["last_seen_at"]
|
||||
assert (
|
||||
_ensure_utc_aware(updated_session.last_seen_at) == update_data["last_seen_at"]
|
||||
)
|
||||
assert updated_session.last_used_ip == update_data["last_used_ip"]
|
||||
|
||||
|
||||
@@ -1101,5 +1238,7 @@ def test_delete_login_session(session: Session) -> None:
|
||||
created_session = make_login_session(session, now)
|
||||
|
||||
crud.delete_login_session(session, created_session.session_token_hash)
|
||||
deleted_session = crud.get_login_session_by_token_hash_and_user_id(session, created_session.session_token_hash, created_session.user_id)
|
||||
deleted_session = crud.get_login_session_by_token_hash_and_user_id(
|
||||
session, created_session.session_token_hash, created_session.user_id
|
||||
)
|
||||
assert deleted_session is None
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -340,6 +340,37 @@ def get_loan_events_by_cycle_id(session: Session, cycle_id: int) -> list[models.
|
||||
return list(session.exec(statement).all())
|
||||
|
||||
|
||||
def get_loan_event_by_cycle_id_and_effective_date(session: Session, cycle_id: int, effective_date: date) -> models.CycleLoanChangeEvents | None:
|
||||
statement = select(models.CycleLoanChangeEvents).where(
|
||||
models.CycleLoanChangeEvents.cycle_id == cycle_id,
|
||||
models.CycleLoanChangeEvents.effective_date == effective_date,
|
||||
)
|
||||
return session.exec(statement).first()
|
||||
|
||||
|
||||
def update_cycle_loan_event(session: Session, event_id: int, update_data: Mapping[str, Any] | BaseModel) -> models.CycleLoanChangeEvents:
|
||||
event: models.CycleLoanChangeEvents | None = session.get(models.CycleLoanChangeEvents, event_id)
|
||||
if event is None:
|
||||
raise ValueError("event_id does not exist")
|
||||
data = _data_to_dict(update_data)
|
||||
|
||||
allowed = _allowed_columns(models.CycleLoanChangeEvents)
|
||||
for k, v in data.items():
|
||||
if k in {"id", "cycle_id", "effective_date", "created_at"}:
|
||||
raise ValueError(f"field {k!r} is immutable")
|
||||
if k not in allowed:
|
||||
continue
|
||||
setattr(event, k, v)
|
||||
session.add(event)
|
||||
try:
|
||||
session.flush()
|
||||
except IntegrityError as e:
|
||||
session.rollback()
|
||||
raise ValueError("update_cycle_loan_event integrity error") from e
|
||||
session.refresh(event)
|
||||
return event
|
||||
|
||||
|
||||
def create_cycle_daily_accrual(session: Session, cycle_id: int, accrual_date: date, accrual_amount_cents: int) -> models.CycleDailyAccrual:
|
||||
cycle = session.get(models.Cycles, cycle_id)
|
||||
if cycle is None:
|
||||
@@ -583,7 +614,11 @@ def get_login_session_by_token_hash(session: Session, session_token_hash: str) -
|
||||
IMMUTABLE_SESSION_FIELDS = {"id", "user_id", "session_token_hash", "created_at"}
|
||||
|
||||
|
||||
def update_login_session(session: Session, session_token_hashed: str, update_session: Mapping[str, Any] | BaseModel) -> models.Sessions | None:
|
||||
def update_login_session(
|
||||
session: Session,
|
||||
session_token_hashed: str,
|
||||
update_session: Mapping[str, Any] | BaseModel,
|
||||
) -> models.Sessions | None:
|
||||
login_session: models.Sessions | None = session.exec(
|
||||
select(models.Sessions).where(
|
||||
models.Sessions.session_token_hash == session_token_hashed,
|
||||
|
||||
@@ -1,11 +1,15 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import date, datetime # noqa: TC003
|
||||
from datetime import date, datetime
|
||||
|
||||
from pydantic import BaseModel
|
||||
from sqlmodel import SQLModel
|
||||
|
||||
from trading_journal.models import TradeStrategy, TradeType, UnderlyingCurrency # noqa: TC001
|
||||
from trading_journal.models import (
|
||||
TradeStrategy,
|
||||
TradeType,
|
||||
UnderlyingCurrency,
|
||||
)
|
||||
|
||||
|
||||
class UserBase(SQLModel):
|
||||
@@ -90,6 +94,32 @@ class CycleRead(CycleCreate):
|
||||
id: int
|
||||
|
||||
|
||||
class CycleLoanChangeEventBase(SQLModel):
|
||||
cycle_id: int
|
||||
effective_date: date
|
||||
loan_amount_cents: int | None = None
|
||||
loan_interest_rate_tenth_bps: int | None = None
|
||||
related_trade_id: int | None = None
|
||||
notes: str | None = None
|
||||
created_at: datetime
|
||||
|
||||
|
||||
class CycleLoanChangeEventCreate(CycleLoanChangeEventBase):
|
||||
created_at: datetime
|
||||
|
||||
|
||||
class CycleLoanChangeEventRead(CycleLoanChangeEventBase):
|
||||
id: int
|
||||
created_at: datetime
|
||||
|
||||
|
||||
class CycleInterestAccrualBase(SQLModel):
|
||||
cycle_id: int
|
||||
accrual_date: date
|
||||
accrual_amount_cents: int
|
||||
created_at: datetime
|
||||
|
||||
|
||||
class TradeBase(SQLModel):
|
||||
friendly_name: str | None = None
|
||||
symbol: str
|
||||
|
||||
@@ -69,33 +69,51 @@ class FundingSource(str, Enum):
|
||||
|
||||
class Trades(SQLModel, table=True):
|
||||
__tablename__ = "trades" # type: ignore[attr-defined]
|
||||
__table_args__ = (UniqueConstraint("user_id", "friendly_name", name="uq_trades_user_friendly_name"),)
|
||||
__table_args__ = (
|
||||
UniqueConstraint(
|
||||
"user_id", "friendly_name", name="uq_trades_user_friendly_name"
|
||||
),
|
||||
)
|
||||
|
||||
id: int | None = Field(default=None, primary_key=True)
|
||||
user_id: int = Field(foreign_key="users.id", nullable=False, index=True)
|
||||
# allow null while user may omit friendly_name; uniqueness enforced per-user by constraint
|
||||
friendly_name: str | None = Field(default=None, sa_column=Column(Text, nullable=True))
|
||||
friendly_name: str | None = Field(
|
||||
default=None, sa_column=Column(Text, nullable=True)
|
||||
)
|
||||
symbol: str = Field(sa_column=Column(Text, nullable=False))
|
||||
exchange_id: int = Field(foreign_key="exchanges.id", nullable=False, index=True)
|
||||
exchange: "Exchanges" = Relationship(back_populates="trades")
|
||||
underlying_currency: UnderlyingCurrency = Field(sa_column=Column(Text, nullable=False))
|
||||
underlying_currency: UnderlyingCurrency = Field(
|
||||
sa_column=Column(Text, nullable=False)
|
||||
)
|
||||
trade_type: TradeType = Field(sa_column=Column(Text, nullable=False))
|
||||
trade_strategy: TradeStrategy = Field(sa_column=Column(Text, nullable=False))
|
||||
trade_date: date = Field(sa_column=Column(Date, nullable=False))
|
||||
trade_time_utc: datetime = Field(sa_column=Column(DateTime(timezone=True), nullable=False))
|
||||
trade_time_utc: datetime = Field(
|
||||
sa_column=Column(DateTime(timezone=True), nullable=False)
|
||||
)
|
||||
expiry_date: date | None = Field(default=None, nullable=True)
|
||||
strike_price_cents: int | None = Field(default=None, nullable=True)
|
||||
quantity: int = Field(sa_column=Column(Integer, nullable=False))
|
||||
quantity_multiplier: int = Field(sa_column=Column(Integer, nullable=False), default=1)
|
||||
quantity_multiplier: int = Field(
|
||||
sa_column=Column(Integer, nullable=False), default=1
|
||||
)
|
||||
price_cents: int = Field(sa_column=Column(Integer, nullable=False))
|
||||
gross_cash_flow_cents: int = Field(sa_column=Column(Integer, nullable=False))
|
||||
commission_cents: int = Field(sa_column=Column(Integer, nullable=False))
|
||||
net_cash_flow_cents: int = Field(sa_column=Column(Integer, nullable=False))
|
||||
is_invalidated: bool = Field(default=False, nullable=False)
|
||||
invalidated_at: datetime | None = Field(default=None, sa_column=Column(DateTime(timezone=True), nullable=True))
|
||||
replaced_by_trade_id: int | None = Field(default=None, foreign_key="trades.id", nullable=True)
|
||||
invalidated_at: datetime | None = Field(
|
||||
default=None, sa_column=Column(DateTime(timezone=True), nullable=True)
|
||||
)
|
||||
replaced_by_trade_id: int | None = Field(
|
||||
default=None, foreign_key="trades.id", nullable=True
|
||||
)
|
||||
notes: str | None = Field(default=None, sa_column=Column(Text, nullable=True))
|
||||
cycle_id: int | None = Field(default=None, foreign_key="cycles.id", nullable=True, index=True)
|
||||
cycle_id: int | None = Field(
|
||||
default=None, foreign_key="cycles.id", nullable=True, index=True
|
||||
)
|
||||
|
||||
cycle: "Cycles" = Relationship(back_populates="trades")
|
||||
|
||||
@@ -107,15 +125,23 @@ class Trades(SQLModel, table=True):
|
||||
|
||||
class Cycles(SQLModel, table=True):
|
||||
__tablename__ = "cycles" # type: ignore[attr-defined]
|
||||
__table_args__ = (UniqueConstraint("user_id", "friendly_name", name="uq_cycles_user_friendly_name"),)
|
||||
__table_args__ = (
|
||||
UniqueConstraint(
|
||||
"user_id", "friendly_name", name="uq_cycles_user_friendly_name"
|
||||
),
|
||||
)
|
||||
|
||||
id: int | None = Field(default=None, primary_key=True)
|
||||
user_id: int = Field(foreign_key="users.id", nullable=False, index=True)
|
||||
friendly_name: str | None = Field(default=None, sa_column=Column(Text, nullable=True))
|
||||
friendly_name: str | None = Field(
|
||||
default=None, sa_column=Column(Text, nullable=True)
|
||||
)
|
||||
symbol: str = Field(sa_column=Column(Text, nullable=False))
|
||||
exchange_id: int = Field(foreign_key="exchanges.id", nullable=False, index=True)
|
||||
exchange: "Exchanges" = Relationship(back_populates="cycles")
|
||||
underlying_currency: UnderlyingCurrency = Field(sa_column=Column(Text, nullable=False))
|
||||
underlying_currency: UnderlyingCurrency = Field(
|
||||
sa_column=Column(Text, nullable=False)
|
||||
)
|
||||
status: CycleStatus = Field(sa_column=Column(Text, nullable=False))
|
||||
funding_source: FundingSource = Field(sa_column=Column(Text, nullable=True))
|
||||
capital_exposure_cents: int | None = Field(default=None, nullable=True)
|
||||
@@ -127,23 +153,51 @@ class Cycles(SQLModel, table=True):
|
||||
loan_amount_cents: int | None = Field(default=None, nullable=True)
|
||||
loan_interest_rate_tenth_bps: int | None = Field(default=None, nullable=True)
|
||||
|
||||
latest_interest_accrued_date: date | None = Field(default=None, sa_column=Column(Date, nullable=True))
|
||||
total_accrued_amount_cents: int = Field(default=0, sa_column=Column(Integer, nullable=False))
|
||||
latest_interest_accrued_date: date | None = Field(
|
||||
default=None, sa_column=Column(Date, nullable=True)
|
||||
)
|
||||
total_accrued_amount_cents: int = Field(
|
||||
default=0, sa_column=Column(Integer, nullable=False)
|
||||
)
|
||||
|
||||
loan_change_events: list["CycleLoanChangeEvents"] = Relationship(back_populates="cycle")
|
||||
loan_change_events: list["CycleLoanChangeEvents"] = Relationship(
|
||||
back_populates="cycle"
|
||||
)
|
||||
daily_accruals: list["CycleDailyAccrual"] = Relationship(back_populates="cycle")
|
||||
|
||||
|
||||
class CycleLoanChangeEvents(SQLModel, table=True):
|
||||
__tablename__ = "cycle_loan_change_events" # type: ignore[attr-defined]
|
||||
__table_args__ = (
|
||||
UniqueConstraint(
|
||||
"cycle_id", "effective_date", name="uq_cycle_loan_change_cycle_date"
|
||||
),
|
||||
)
|
||||
|
||||
id: int | None = Field(default=None, primary_key=True)
|
||||
cycle_id: int = Field(sa_column=Column(Integer, ForeignKey("cycles.id", ondelete="CASCADE"), nullable=False, index=True))
|
||||
cycle_id: int = Field(
|
||||
sa_column=Column(
|
||||
Integer,
|
||||
ForeignKey("cycles.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
)
|
||||
effective_date: date = Field(sa_column=Column(Date, nullable=False))
|
||||
loan_amount_cents: int | None = Field(default=None, sa_column=Column(Integer, nullable=True))
|
||||
loan_interest_rate_tenth_bps: int | None = Field(default=None, sa_column=Column(Integer, nullable=True))
|
||||
related_trade_id: int | None = Field(default=None, sa_column=Column(Integer, ForeignKey("trades.id"), nullable=True, unique=True))
|
||||
loan_amount_cents: int | None = Field(
|
||||
default=None, sa_column=Column(Integer, nullable=True)
|
||||
)
|
||||
loan_interest_rate_tenth_bps: int | None = Field(
|
||||
default=None, sa_column=Column(Integer, nullable=True)
|
||||
)
|
||||
related_trade_id: int | None = Field(
|
||||
default=None,
|
||||
sa_column=Column(Integer, ForeignKey("trades.id"), nullable=True, unique=True),
|
||||
) # Not used for now.
|
||||
notes: str | None = Field(default=None, sa_column=Column(Text, nullable=True))
|
||||
created_at: datetime = Field(sa_column=Column(DateTime(timezone=True), nullable=False))
|
||||
created_at: datetime = Field(
|
||||
sa_column=Column(DateTime(timezone=True), nullable=False)
|
||||
)
|
||||
|
||||
cycle: "Cycles" = Relationship(back_populates="loan_change_events")
|
||||
trade: Optional["Trades"] = Relationship(back_populates="related_loan_change_event")
|
||||
@@ -151,20 +205,35 @@ class CycleLoanChangeEvents(SQLModel, table=True):
|
||||
|
||||
class CycleDailyAccrual(SQLModel, table=True):
|
||||
__tablename__ = "cycle_daily_accrual" # type: ignore[attr-defined]
|
||||
__table_args__ = (UniqueConstraint("cycle_id", "accrual_date", name="uq_cycle_daily_accruals_cycle_date"),)
|
||||
__table_args__ = (
|
||||
UniqueConstraint(
|
||||
"cycle_id", "accrual_date", name="uq_cycle_daily_accruals_cycle_date"
|
||||
),
|
||||
)
|
||||
|
||||
id: int | None = Field(default=None, primary_key=True)
|
||||
cycle_id: int = Field(sa_column=Column(Integer, ForeignKey("cycles.id", ondelete="CASCADE"), nullable=False, index=True))
|
||||
cycle_id: int = Field(
|
||||
sa_column=Column(
|
||||
Integer,
|
||||
ForeignKey("cycles.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
)
|
||||
accrual_date: date = Field(sa_column=Column(Date, nullable=False))
|
||||
accrual_amount_cents: int = Field(sa_column=Column(Integer, nullable=False))
|
||||
created_at: datetime = Field(sa_column=Column(DateTime(timezone=True), nullable=False))
|
||||
created_at: datetime = Field(
|
||||
sa_column=Column(DateTime(timezone=True), nullable=False)
|
||||
)
|
||||
|
||||
cycle: "Cycles" = Relationship(back_populates="daily_accruals")
|
||||
|
||||
|
||||
class Exchanges(SQLModel, table=True):
|
||||
__tablename__ = "exchanges" # type: ignore[attr-defined]
|
||||
__table_args__ = (UniqueConstraint("user_id", "name", name="uq_exchanges_user_name"),)
|
||||
__table_args__ = (
|
||||
UniqueConstraint("user_id", "name", name="uq_exchanges_user_name"),
|
||||
)
|
||||
id: int | None = Field(default=None, primary_key=True)
|
||||
user_id: int = Field(foreign_key="users.id", nullable=False, index=True)
|
||||
name: str = Field(sa_column=Column(Text, nullable=False))
|
||||
@@ -190,10 +259,18 @@ class Sessions(SQLModel, table=True):
|
||||
id: int | None = Field(default=None, primary_key=True)
|
||||
user_id: int = Field(foreign_key="users.id", nullable=False, index=True)
|
||||
session_token_hash: str = Field(sa_column=Column(Text, nullable=False, unique=True))
|
||||
created_at: datetime = Field(sa_column=Column(DateTime(timezone=True), nullable=False))
|
||||
expires_at: datetime = Field(sa_column=Column(DateTime(timezone=True), nullable=False, index=True))
|
||||
last_seen_at: datetime | None = Field(sa_column=Column(DateTime(timezone=True), nullable=True))
|
||||
last_used_ip: str | None = Field(default=None, sa_column=Column(Text, nullable=True))
|
||||
created_at: datetime = Field(
|
||||
sa_column=Column(DateTime(timezone=True), nullable=False)
|
||||
)
|
||||
expires_at: datetime = Field(
|
||||
sa_column=Column(DateTime(timezone=True), nullable=False, index=True)
|
||||
)
|
||||
last_seen_at: datetime | None = Field(
|
||||
sa_column=Column(DateTime(timezone=True), nullable=True)
|
||||
)
|
||||
last_used_ip: str | None = Field(
|
||||
default=None, sa_column=Column(Text, nullable=True)
|
||||
)
|
||||
user_agent: str | None = Field(default=None, sa_column=Column(Text, nullable=True))
|
||||
device_name: str | None = Field(default=None, sa_column=Column(Text, nullable=True))
|
||||
user: "Users" = Relationship(back_populates="sessions")
|
||||
|
||||
@@ -69,33 +69,51 @@ class FundingSource(str, Enum):
|
||||
|
||||
class Trades(SQLModel, table=True):
|
||||
__tablename__ = "trades" # type: ignore[attr-defined]
|
||||
__table_args__ = (UniqueConstraint("user_id", "friendly_name", name="uq_trades_user_friendly_name"),)
|
||||
__table_args__ = (
|
||||
UniqueConstraint(
|
||||
"user_id", "friendly_name", name="uq_trades_user_friendly_name"
|
||||
),
|
||||
)
|
||||
|
||||
id: int | None = Field(default=None, primary_key=True)
|
||||
user_id: int = Field(foreign_key="users.id", nullable=False, index=True)
|
||||
# allow null while user may omit friendly_name; uniqueness enforced per-user by constraint
|
||||
friendly_name: str | None = Field(default=None, sa_column=Column(Text, nullable=True))
|
||||
friendly_name: str | None = Field(
|
||||
default=None, sa_column=Column(Text, nullable=True)
|
||||
)
|
||||
symbol: str = Field(sa_column=Column(Text, nullable=False))
|
||||
exchange_id: int = Field(foreign_key="exchanges.id", nullable=False, index=True)
|
||||
exchange: "Exchanges" = Relationship(back_populates="trades")
|
||||
underlying_currency: UnderlyingCurrency = Field(sa_column=Column(Text, nullable=False))
|
||||
underlying_currency: UnderlyingCurrency = Field(
|
||||
sa_column=Column(Text, nullable=False)
|
||||
)
|
||||
trade_type: TradeType = Field(sa_column=Column(Text, nullable=False))
|
||||
trade_strategy: TradeStrategy = Field(sa_column=Column(Text, nullable=False))
|
||||
trade_date: date = Field(sa_column=Column(Date, nullable=False))
|
||||
trade_time_utc: datetime = Field(sa_column=Column(DateTime(timezone=True), nullable=False))
|
||||
trade_time_utc: datetime = Field(
|
||||
sa_column=Column(DateTime(timezone=True), nullable=False)
|
||||
)
|
||||
expiry_date: date | None = Field(default=None, nullable=True)
|
||||
strike_price_cents: int | None = Field(default=None, nullable=True)
|
||||
quantity: int = Field(sa_column=Column(Integer, nullable=False))
|
||||
quantity_multiplier: int = Field(sa_column=Column(Integer, nullable=False), default=1)
|
||||
quantity_multiplier: int = Field(
|
||||
sa_column=Column(Integer, nullable=False), default=1
|
||||
)
|
||||
price_cents: int = Field(sa_column=Column(Integer, nullable=False))
|
||||
gross_cash_flow_cents: int = Field(sa_column=Column(Integer, nullable=False))
|
||||
commission_cents: int = Field(sa_column=Column(Integer, nullable=False))
|
||||
net_cash_flow_cents: int = Field(sa_column=Column(Integer, nullable=False))
|
||||
is_invalidated: bool = Field(default=False, nullable=False)
|
||||
invalidated_at: datetime | None = Field(default=None, sa_column=Column(DateTime(timezone=True), nullable=True))
|
||||
replaced_by_trade_id: int | None = Field(default=None, foreign_key="trades.id", nullable=True)
|
||||
invalidated_at: datetime | None = Field(
|
||||
default=None, sa_column=Column(DateTime(timezone=True), nullable=True)
|
||||
)
|
||||
replaced_by_trade_id: int | None = Field(
|
||||
default=None, foreign_key="trades.id", nullable=True
|
||||
)
|
||||
notes: str | None = Field(default=None, sa_column=Column(Text, nullable=True))
|
||||
cycle_id: int | None = Field(default=None, foreign_key="cycles.id", nullable=True, index=True)
|
||||
cycle_id: int | None = Field(
|
||||
default=None, foreign_key="cycles.id", nullable=True, index=True
|
||||
)
|
||||
|
||||
cycle: "Cycles" = Relationship(back_populates="trades")
|
||||
|
||||
@@ -107,15 +125,23 @@ class Trades(SQLModel, table=True):
|
||||
|
||||
class Cycles(SQLModel, table=True):
|
||||
__tablename__ = "cycles" # type: ignore[attr-defined]
|
||||
__table_args__ = (UniqueConstraint("user_id", "friendly_name", name="uq_cycles_user_friendly_name"),)
|
||||
__table_args__ = (
|
||||
UniqueConstraint(
|
||||
"user_id", "friendly_name", name="uq_cycles_user_friendly_name"
|
||||
),
|
||||
)
|
||||
|
||||
id: int | None = Field(default=None, primary_key=True)
|
||||
user_id: int = Field(foreign_key="users.id", nullable=False, index=True)
|
||||
friendly_name: str | None = Field(default=None, sa_column=Column(Text, nullable=True))
|
||||
friendly_name: str | None = Field(
|
||||
default=None, sa_column=Column(Text, nullable=True)
|
||||
)
|
||||
symbol: str = Field(sa_column=Column(Text, nullable=False))
|
||||
exchange_id: int = Field(foreign_key="exchanges.id", nullable=False, index=True)
|
||||
exchange: "Exchanges" = Relationship(back_populates="cycles")
|
||||
underlying_currency: UnderlyingCurrency = Field(sa_column=Column(Text, nullable=False))
|
||||
underlying_currency: UnderlyingCurrency = Field(
|
||||
sa_column=Column(Text, nullable=False)
|
||||
)
|
||||
status: CycleStatus = Field(sa_column=Column(Text, nullable=False))
|
||||
funding_source: FundingSource = Field(sa_column=Column(Text, nullable=True))
|
||||
capital_exposure_cents: int | None = Field(default=None, nullable=True)
|
||||
@@ -127,23 +153,51 @@ class Cycles(SQLModel, table=True):
|
||||
loan_amount_cents: int | None = Field(default=None, nullable=True)
|
||||
loan_interest_rate_tenth_bps: int | None = Field(default=None, nullable=True)
|
||||
|
||||
latest_interest_accrued_date: date | None = Field(default=None, sa_column=Column(Date, nullable=True))
|
||||
total_accrued_amount_cents: int = Field(default=0, sa_column=Column(Integer, nullable=False))
|
||||
latest_interest_accrued_date: date | None = Field(
|
||||
default=None, sa_column=Column(Date, nullable=True)
|
||||
)
|
||||
total_accrued_amount_cents: int = Field(
|
||||
default=0, sa_column=Column(Integer, nullable=False)
|
||||
)
|
||||
|
||||
loan_change_events: list["CycleLoanChangeEvents"] = Relationship(back_populates="cycle")
|
||||
loan_change_events: list["CycleLoanChangeEvents"] = Relationship(
|
||||
back_populates="cycle"
|
||||
)
|
||||
daily_accruals: list["CycleDailyAccrual"] = Relationship(back_populates="cycle")
|
||||
|
||||
|
||||
class CycleLoanChangeEvents(SQLModel, table=True):
|
||||
__tablename__ = "cycle_loan_change_events" # type: ignore[attr-defined]
|
||||
__table_args__ = (
|
||||
UniqueConstraint(
|
||||
"cycle_id", "effective_date", name="uq_cycle_loan_change_cycle_date"
|
||||
),
|
||||
)
|
||||
|
||||
id: int | None = Field(default=None, primary_key=True)
|
||||
cycle_id: int = Field(sa_column=Column(Integer, ForeignKey("cycles.id", ondelete="CASCADE"), nullable=False, index=True))
|
||||
cycle_id: int = Field(
|
||||
sa_column=Column(
|
||||
Integer,
|
||||
ForeignKey("cycles.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
)
|
||||
effective_date: date = Field(sa_column=Column(Date, nullable=False))
|
||||
loan_amount_cents: int | None = Field(default=None, sa_column=Column(Integer, nullable=True))
|
||||
loan_interest_rate_tenth_bps: int | None = Field(default=None, sa_column=Column(Integer, nullable=True))
|
||||
related_trade_id: int | None = Field(default=None, sa_column=Column(Integer, ForeignKey("trades.id"), nullable=True, unique=True))
|
||||
loan_amount_cents: int | None = Field(
|
||||
default=None, sa_column=Column(Integer, nullable=True)
|
||||
)
|
||||
loan_interest_rate_tenth_bps: int | None = Field(
|
||||
default=None, sa_column=Column(Integer, nullable=True)
|
||||
)
|
||||
related_trade_id: int | None = Field(
|
||||
default=None,
|
||||
sa_column=Column(Integer, ForeignKey("trades.id"), nullable=True, unique=True),
|
||||
) # Not used for now.
|
||||
notes: str | None = Field(default=None, sa_column=Column(Text, nullable=True))
|
||||
created_at: datetime = Field(sa_column=Column(DateTime(timezone=True), nullable=False))
|
||||
created_at: datetime = Field(
|
||||
sa_column=Column(DateTime(timezone=True), nullable=False)
|
||||
)
|
||||
|
||||
cycle: "Cycles" = Relationship(back_populates="loan_change_events")
|
||||
trade: Optional["Trades"] = Relationship(back_populates="related_loan_change_event")
|
||||
@@ -151,20 +205,35 @@ class CycleLoanChangeEvents(SQLModel, table=True):
|
||||
|
||||
class CycleDailyAccrual(SQLModel, table=True):
|
||||
__tablename__ = "cycle_daily_accrual" # type: ignore[attr-defined]
|
||||
__table_args__ = (UniqueConstraint("cycle_id", "accrual_date", name="uq_cycle_daily_accruals_cycle_date"),)
|
||||
__table_args__ = (
|
||||
UniqueConstraint(
|
||||
"cycle_id", "accrual_date", name="uq_cycle_daily_accruals_cycle_date"
|
||||
),
|
||||
)
|
||||
|
||||
id: int | None = Field(default=None, primary_key=True)
|
||||
cycle_id: int = Field(sa_column=Column(Integer, ForeignKey("cycles.id", ondelete="CASCADE"), nullable=False, index=True))
|
||||
cycle_id: int = Field(
|
||||
sa_column=Column(
|
||||
Integer,
|
||||
ForeignKey("cycles.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
)
|
||||
accrual_date: date = Field(sa_column=Column(Date, nullable=False))
|
||||
accrual_amount_cents: int = Field(sa_column=Column(Integer, nullable=False))
|
||||
created_at: datetime = Field(sa_column=Column(DateTime(timezone=True), nullable=False))
|
||||
created_at: datetime = Field(
|
||||
sa_column=Column(DateTime(timezone=True), nullable=False)
|
||||
)
|
||||
|
||||
cycle: "Cycles" = Relationship(back_populates="daily_accruals")
|
||||
|
||||
|
||||
class Exchanges(SQLModel, table=True):
|
||||
__tablename__ = "exchanges" # type: ignore[attr-defined]
|
||||
__table_args__ = (UniqueConstraint("user_id", "name", name="uq_exchanges_user_name"),)
|
||||
__table_args__ = (
|
||||
UniqueConstraint("user_id", "name", name="uq_exchanges_user_name"),
|
||||
)
|
||||
id: int | None = Field(default=None, primary_key=True)
|
||||
user_id: int = Field(foreign_key="users.id", nullable=False, index=True)
|
||||
name: str = Field(sa_column=Column(Text, nullable=False))
|
||||
@@ -190,10 +259,18 @@ class Sessions(SQLModel, table=True):
|
||||
id: int | None = Field(default=None, primary_key=True)
|
||||
user_id: int = Field(foreign_key="users.id", nullable=False, index=True)
|
||||
session_token_hash: str = Field(sa_column=Column(Text, nullable=False, unique=True))
|
||||
created_at: datetime = Field(sa_column=Column(DateTime(timezone=True), nullable=False))
|
||||
expires_at: datetime = Field(sa_column=Column(DateTime(timezone=True), nullable=False, index=True))
|
||||
last_seen_at: datetime | None = Field(sa_column=Column(DateTime(timezone=True), nullable=True))
|
||||
last_used_ip: str | None = Field(default=None, sa_column=Column(Text, nullable=True))
|
||||
created_at: datetime = Field(
|
||||
sa_column=Column(DateTime(timezone=True), nullable=False)
|
||||
)
|
||||
expires_at: datetime = Field(
|
||||
sa_column=Column(DateTime(timezone=True), nullable=False, index=True)
|
||||
)
|
||||
last_seen_at: datetime | None = Field(
|
||||
sa_column=Column(DateTime(timezone=True), nullable=True)
|
||||
)
|
||||
last_used_ip: str | None = Field(
|
||||
default=None, sa_column=Column(Text, nullable=True)
|
||||
)
|
||||
user_agent: str | None = Field(default=None, sa_column=Column(Text, nullable=True))
|
||||
device_name: str | None = Field(default=None, sa_column=Column(Text, nullable=True))
|
||||
user: "Users" = Relationship(back_populates="sessions")
|
||||
|
||||
@@ -13,6 +13,7 @@ from trading_journal import crud, security
|
||||
from trading_journal.dto import (
|
||||
CycleBase,
|
||||
CycleCreate,
|
||||
CycleLoanChangeEventBase,
|
||||
CycleRead,
|
||||
CycleUpdate,
|
||||
ExchangesBase,
|
||||
@@ -26,6 +27,17 @@ from trading_journal.dto import (
|
||||
UserLogin,
|
||||
UserRead,
|
||||
)
|
||||
from trading_journal.service_error import (
|
||||
CycleLoanEventExistsError,
|
||||
CycleNotFoundError,
|
||||
ExchangeAlreadyExistsError,
|
||||
ExchangeNotFoundError,
|
||||
InvalidCycleDataError,
|
||||
InvalidTradeDataError,
|
||||
ServiceError,
|
||||
TradeNotFoundError,
|
||||
UserAlreadyExistsError,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sqlmodel import Session
|
||||
@@ -62,20 +74,32 @@ class AuthMiddleWare(BaseHTTPMiddleware):
|
||||
|
||||
db_factory: Database | None = getattr(request.app.state, "db_factory", None)
|
||||
if db_factory is None:
|
||||
return JSONResponse(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, content={"detail": "db factory not configured"})
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
content={"detail": "db factory not configured"},
|
||||
)
|
||||
try:
|
||||
with db_factory.get_session_ctx_manager() as request_session:
|
||||
hashed_token = security.hash_session_token_sha256(token)
|
||||
request.state.db_session = request_session
|
||||
login_session: Sessions | None = crud.get_login_session_by_token_hash(request_session, hashed_token)
|
||||
if not login_session:
|
||||
return JSONResponse(status_code=status.HTTP_401_UNAUTHORIZED, content={"detail": "Unauthorized"})
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
content={"detail": "Unauthorized"},
|
||||
)
|
||||
session_expires_utc = login_session.expires_at.replace(tzinfo=timezone.utc)
|
||||
if session_expires_utc < datetime.now(timezone.utc):
|
||||
crud.delete_login_session(request_session, login_session.session_token_hash)
|
||||
return JSONResponse(status_code=status.HTTP_401_UNAUTHORIZED, content={"detail": "Unauthorized"})
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
content={"detail": "Unauthorized"},
|
||||
)
|
||||
if login_session.user.is_active is False:
|
||||
return JSONResponse(status_code=status.HTTP_401_UNAUTHORIZED, content={"detail": "Unauthorized"})
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
content={"detail": "Unauthorized"},
|
||||
)
|
||||
if session_expires_utc - datetime.now(timezone.utc) < timedelta(seconds=3600):
|
||||
updated_expiry = datetime.now(timezone.utc) + timedelta(seconds=settings.settings.session_expiry_seconds)
|
||||
else:
|
||||
@@ -91,43 +115,14 @@ class AuthMiddleWare(BaseHTTPMiddleware):
|
||||
crud.update_login_session(request_session, hashed_token, update_session=updated_session)
|
||||
except Exception:
|
||||
logger.exception("Failed to authenticate user: \n")
|
||||
return JSONResponse(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, content={"detail": "Internal server error"})
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
content={"detail": "Internal server error"},
|
||||
)
|
||||
|
||||
return await call_next(request)
|
||||
|
||||
|
||||
class ServiceError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class UserAlreadyExistsError(ServiceError):
|
||||
pass
|
||||
|
||||
|
||||
class ExchangeAlreadyExistsError(ServiceError):
|
||||
pass
|
||||
|
||||
|
||||
class ExchangeNotFoundError(ServiceError):
|
||||
pass
|
||||
|
||||
|
||||
class CycleNotFoundError(ServiceError):
|
||||
pass
|
||||
|
||||
|
||||
class TradeNotFoundError(ServiceError):
|
||||
pass
|
||||
|
||||
|
||||
class InvalidTradeDataError(ServiceError):
|
||||
pass
|
||||
|
||||
|
||||
class InvalidCycleDataError(ServiceError):
|
||||
pass
|
||||
|
||||
|
||||
# User service
|
||||
def register_user_service(db_session: Session, user_in: UserCreate) -> UserRead:
|
||||
if crud.get_user_by_username(db_session, user_in.username):
|
||||
@@ -203,7 +198,13 @@ def get_exchanges_by_user_service(db_session: Session, user_id: int) -> list[Exc
|
||||
return [ExchangesRead.model_validate(exchange) for exchange in exchanges]
|
||||
|
||||
|
||||
def update_exchanges_service(db_session: Session, user_id: int, exchange_id: int, name: str | None, notes: str | None) -> ExchangesBase:
|
||||
def update_exchanges_service(
|
||||
db_session: Session,
|
||||
user_id: int,
|
||||
exchange_id: int,
|
||||
name: str | None,
|
||||
notes: str | None,
|
||||
) -> ExchangesBase:
|
||||
existing_exchange = crud.get_exchange_by_id(db_session, exchange_id)
|
||||
if not existing_exchange:
|
||||
raise ExchangeNotFoundError("Exchange not found")
|
||||
@@ -229,6 +230,7 @@ def update_exchanges_service(db_session: Session, user_id: int, exchange_id: int
|
||||
|
||||
# Cycle Service
|
||||
def create_cycle_service(db_session: Session, user_id: int, cycle_data: CycleBase) -> CycleRead:
|
||||
raise NotImplementedError("Cycle creation not implemented")
|
||||
cycle_data_dict = cycle_data.model_dump()
|
||||
cycle_data_dict["user_id"] = user_id
|
||||
cycle_data_with_user_id: CycleCreate = CycleCreate.model_validate(cycle_data_dict)
|
||||
@@ -262,7 +264,10 @@ def _validate_cycle_update_data(cycle_data: CycleUpdate) -> tuple[bool, str]: #
|
||||
and cycle_data.funding_source != "CASH"
|
||||
and (cycle_data.loan_amount_cents is None or cycle_data.loan_interest_rate_tenth_bps is None)
|
||||
):
|
||||
return False, "loan_amount_cents and loan_interest_rate_tenth_bps are required when funding_source is not CASH"
|
||||
return (
|
||||
False,
|
||||
"loan_amount_cents and loan_interest_rate_tenth_bps are required when funding_source is not CASH",
|
||||
)
|
||||
if cycle_data.loan_amount_cents is not None and cycle_data.loan_amount_cents < 0:
|
||||
return False, "loan_amount_cents must be non-negative"
|
||||
if cycle_data.loan_interest_rate_tenth_bps is not None and cycle_data.loan_interest_rate_tenth_bps < 0:
|
||||
@@ -270,6 +275,31 @@ def _validate_cycle_update_data(cycle_data: CycleUpdate) -> tuple[bool, str]: #
|
||||
return True, ""
|
||||
|
||||
|
||||
def _create_cycle_loan_event(
|
||||
db_session: Session,
|
||||
cycle_id: int,
|
||||
loan_amount_cents: int | None,
|
||||
loan_interest_rate_tenth_bps: int | None,
|
||||
) -> None:
|
||||
now = datetime.now(timezone.utc)
|
||||
today = now.date()
|
||||
existing_loan_event = crud.get_loan_event_by_cycle_id_and_effective_date(db_session, cycle_id, today)
|
||||
if existing_loan_event:
|
||||
raise CycleLoanEventExistsError("A loan event with the same effective_date already exists for this cycle.")
|
||||
loan_event_data = CycleLoanChangeEventBase(
|
||||
cycle_id=cycle_id,
|
||||
effective_date=today,
|
||||
loan_amount_cents=loan_amount_cents,
|
||||
loan_interest_rate_tenth_bps=loan_interest_rate_tenth_bps,
|
||||
created_at=now,
|
||||
)
|
||||
try:
|
||||
crud.create_cycle_loan_event(db_session, loan_event_data)
|
||||
except Exception as e:
|
||||
logger.exception("Failed to create cycle loan event: \n")
|
||||
raise ServiceError("Failed to create cycle loan event") from e
|
||||
|
||||
|
||||
def update_cycle_service(db_session: Session, user_id: int, cycle_data: CycleUpdate) -> CycleRead:
|
||||
is_valid, err_msg = _validate_cycle_update_data(cycle_data)
|
||||
if not is_valid:
|
||||
@@ -280,6 +310,13 @@ def update_cycle_service(db_session: Session, user_id: int, cycle_data: CycleUpd
|
||||
raise CycleNotFoundError("Cycle not found")
|
||||
if existing_cycle.user_id != user_id:
|
||||
raise CycleNotFoundError("Cycle not found")
|
||||
if cycle_data.loan_amount_cents is not None or cycle_data.loan_interest_rate_tenth_bps is not None:
|
||||
_create_cycle_loan_event(
|
||||
db_session,
|
||||
cycle_id,
|
||||
cycle_data.loan_amount_cents,
|
||||
cycle_data.loan_interest_rate_tenth_bps,
|
||||
)
|
||||
|
||||
provided_data_dict = cycle_data.model_dump(exclude_unset=True)
|
||||
cycle_data_with_user_id: CycleBase = CycleBase.model_validate(provided_data_dict)
|
||||
@@ -292,10 +329,42 @@ def update_cycle_service(db_session: Session, user_id: int, cycle_data: CycleUpd
|
||||
return CycleRead.model_validate(updated_cycle)
|
||||
|
||||
|
||||
def accrual_interest_service(db_session: Session, cycle_id: int) -> None:
|
||||
cycle = crud.get_cycle_by_id(db_session, cycle_id)
|
||||
if not cycle:
|
||||
logger.exception("Cycle not found for interest accrual")
|
||||
raise CycleNotFoundError("Cycle not found")
|
||||
if cycle.loan_amount_cents is None or cycle.loan_interest_rate_tenth_bps is None:
|
||||
logger.info("Cycle has no loan, skipping interest accrual")
|
||||
return
|
||||
today = datetime.now(timezone.utc).date()
|
||||
amount_cents = round(cycle.loan_amount_cents * cycle.loan_interest_rate_tenth_bps / 100000 / 365)
|
||||
try:
|
||||
crud.create_cycle_daily_accrual(
|
||||
db_session,
|
||||
cycle_id=cycle_id,
|
||||
accrual_date=today,
|
||||
accrual_amount_cents=amount_cents,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception("Failed to create cycle interest accrual: \n")
|
||||
raise ServiceError("Failed to create cycle interest accrual") from e
|
||||
|
||||
|
||||
def flush_interest_accruals_service(db_session: Session) -> None:
|
||||
pass
|
||||
|
||||
|
||||
# Trades service
|
||||
def _append_cashflows(trade_data: TradeCreate) -> TradeCreate:
|
||||
sign_multipler: int
|
||||
if trade_data.trade_type in ("SELL_PUT", "SELL_CALL", "EXERCISE_CALL", "CLOSE_LONG_SPOT", "SHORT_SPOT"):
|
||||
if trade_data.trade_type in (
|
||||
"SELL_PUT",
|
||||
"SELL_CALL",
|
||||
"EXERCISE_CALL",
|
||||
"CLOSE_LONG_SPOT",
|
||||
"SHORT_SPOT",
|
||||
):
|
||||
sign_multipler = 1
|
||||
else:
|
||||
sign_multipler = -1
|
||||
|
||||
34
backend/trading_journal/service_error.py
Normal file
34
backend/trading_journal/service_error.py
Normal file
@@ -0,0 +1,34 @@
|
||||
class ServiceError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class UserAlreadyExistsError(ServiceError):
|
||||
pass
|
||||
|
||||
|
||||
class ExchangeAlreadyExistsError(ServiceError):
|
||||
pass
|
||||
|
||||
|
||||
class ExchangeNotFoundError(ServiceError):
|
||||
pass
|
||||
|
||||
|
||||
class CycleNotFoundError(ServiceError):
|
||||
pass
|
||||
|
||||
|
||||
class TradeNotFoundError(ServiceError):
|
||||
pass
|
||||
|
||||
|
||||
class InvalidTradeDataError(ServiceError):
|
||||
pass
|
||||
|
||||
|
||||
class InvalidCycleDataError(ServiceError):
|
||||
pass
|
||||
|
||||
|
||||
class CycleLoanEventExistsError(ServiceError):
|
||||
pass
|
||||
Reference in New Issue
Block a user