feature/api_endpoint #5

Merged
tliu93 merged 18 commits from feature/api_endpoint into main 2025-10-01 15:55:47 +02:00
8 changed files with 132 additions and 103 deletions
Showing only changes of commit 92c4e0d4fc - Show all commits

View File

@@ -11,5 +11,6 @@
"tests" "tests"
], ],
"python.testing.unittestEnabled": false, "python.testing.unittestEnabled": false,
"python.testing.pytestEnabled": true "python.testing.pytestEnabled": true,
"python.analysis.typeCheckingMode": "standard",
} }

View File

@@ -2,18 +2,22 @@ from __future__ import annotations
import asyncio import asyncio
import logging import logging
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from datetime import datetime, timezone from datetime import datetime, timezone
from typing import TYPE_CHECKING
from fastapi import FastAPI, HTTPException, Request, status from fastapi import FastAPI, HTTPException, Request, status
from fastapi.responses import JSONResponse from fastapi.responses import JSONResponse, Response
import settings import settings
from trading_journal import db, service from trading_journal import db, service
from trading_journal.db import Database
from trading_journal.dto import ExchangesBase, SessionsBase, SessionsCreate, UserCreate, UserLogin, UserRead from trading_journal.dto import ExchangesBase, SessionsBase, SessionsCreate, UserCreate, UserLogin, UserRead
if TYPE_CHECKING:
from collections.abc import AsyncGenerator
from trading_journal.db import Database
_db = db.create_database(settings.settings.database_url) _db = db.create_database(settings.settings.database_url)
logging.basicConfig( logging.basicConfig(
@@ -43,7 +47,7 @@ async def get_status() -> dict[str, str]:
@app.post(f"{settings.settings.api_base}/register") @app.post(f"{settings.settings.api_base}/register")
async def register_user(request: Request, user_in: UserCreate) -> UserRead: async def register_user(request: Request, user_in: UserCreate) -> Response:
db_factory: Database = request.app.state.db_factory db_factory: Database = request.app.state.db_factory
def sync_work() -> UserRead: def sync_work() -> UserRead:
@@ -61,7 +65,7 @@ async def register_user(request: Request, user_in: UserCreate) -> UserRead:
@app.post(f"{settings.settings.api_base}/login") @app.post(f"{settings.settings.api_base}/login")
async def login(request: Request, user_in: UserLogin) -> SessionsBase: async def login(request: Request, user_in: UserLogin) -> Response:
db_factory: Database = request.app.state.db_factory db_factory: Database = request.app.state.db_factory
def sync_work() -> tuple[SessionsCreate, str] | None: def sync_work() -> tuple[SessionsCreate, str] | None:
@@ -97,7 +101,7 @@ async def login(request: Request, user_in: UserLogin) -> SessionsBase:
# Exchange # Exchange
@app.post(f"{settings.settings.api_base}/exchanges") @app.post(f"{settings.settings.api_base}/exchanges")
async def create_exchange(request: Request, exchange_data: ExchangesBase) -> dict: async def create_exchange(request: Request, exchange_data: ExchangesBase) -> Response:
db_factory: Database = request.app.state.db_factory db_factory: Database = request.app.state.db_factory
def sync_work() -> ExchangesBase: def sync_work() -> ExchangesBase:

View File

@@ -1,7 +1,7 @@
from __future__ import annotations from __future__ import annotations
from datetime import datetime, timedelta, timezone from datetime import datetime, timedelta, timezone
from typing import TYPE_CHECKING from typing import TYPE_CHECKING, cast
import pytest import pytest
from sqlalchemy import create_engine from sqlalchemy import create_engine
@@ -45,7 +45,7 @@ def make_user(session: Session, username: str = "testuser") -> int:
session.add(user) session.add(user)
session.commit() session.commit()
session.refresh(user) session.refresh(user)
return user.id return cast("int", user.id)
def make_exchange(session: Session, user_id: int, name: str = "NASDAQ") -> int: def make_exchange(session: Session, user_id: int, name: str = "NASDAQ") -> int:
@@ -53,7 +53,7 @@ def make_exchange(session: Session, user_id: int, name: str = "NASDAQ") -> int:
session.add(exchange) session.add(exchange)
session.commit() session.commit()
session.refresh(exchange) session.refresh(exchange)
return exchange.id return cast("int", exchange.id)
def make_cycle(session: Session, user_id: int, exchange_id: int, friendly_name: str = "Test Cycle") -> int: def make_cycle(session: Session, user_id: int, exchange_id: int, friendly_name: str = "Test Cycle") -> int:
@@ -65,15 +65,16 @@ def make_cycle(session: Session, user_id: int, exchange_id: int, friendly_name:
underlying_currency=models.UnderlyingCurrency.USD, underlying_currency=models.UnderlyingCurrency.USD,
status=models.CycleStatus.OPEN, status=models.CycleStatus.OPEN,
start_date=datetime.now(timezone.utc).date(), start_date=datetime.now(timezone.utc).date(),
) ) # type: ignore[arg-type]
session.add(cycle) session.add(cycle)
session.commit() session.commit()
session.refresh(cycle) session.refresh(cycle)
return cycle.id return cast("int", cycle.id)
def make_trade(session: Session, user_id: int, cycle_id: int, friendly_name: str = "Test Trade") -> int: def make_trade(session: Session, user_id: int, cycle_id: int, friendly_name: str = "Test Trade") -> int:
cycle: models.Cycles = session.get(models.Cycles, cycle_id) cycle: models.Cycles | None = session.get(models.Cycles, cycle_id)
assert cycle is not None
exchange_id = cycle.exchange_id exchange_id = cycle.exchange_id
trade = models.Trades( trade = models.Trades(
user_id=user_id, user_id=user_id,
@@ -96,7 +97,7 @@ def make_trade(session: Session, user_id: int, cycle_id: int, friendly_name: str
session.add(trade) session.add(trade)
session.commit() session.commit()
session.refresh(trade) session.refresh(trade)
return trade.id return cast("int", trade.id)
def make_trade_by_trade_data(session: Session, trade_data: dict) -> int: def make_trade_by_trade_data(session: Session, trade_data: dict) -> int:
@@ -104,7 +105,7 @@ def make_trade_by_trade_data(session: Session, trade_data: dict) -> int:
session.add(trade) session.add(trade)
session.commit() session.commit()
session.refresh(trade) session.refresh(trade)
return trade.id return cast("int", trade.id)
def make_login_session(session: Session, created_at: datetime) -> models.Sessions: def make_login_session(session: Session, created_at: datetime) -> models.Sessions:
@@ -128,7 +129,7 @@ def make_login_session(session: Session, created_at: datetime) -> models.Session
return login_session return login_session
def _ensure_utc_aware(dt: datetime) -> datetime | None: def _ensure_utc_aware(dt: datetime | None) -> datetime | None:
if dt is None: if dt is None:
return None return None
if dt.tzinfo is None: if dt.tzinfo is None:
@@ -219,7 +220,7 @@ def test_create_trade_with_auto_created_cycle(session: Session) -> None:
assert auto_cycle.symbol == trade_data["symbol"] assert auto_cycle.symbol == trade_data["symbol"]
assert auto_cycle.underlying_currency == trade_data["underlying_currency"] assert auto_cycle.underlying_currency == trade_data["underlying_currency"]
assert auto_cycle.status == models.CycleStatus.OPEN assert auto_cycle.status == models.CycleStatus.OPEN
assert auto_cycle.friendly_name.startswith("Auto-created Cycle by trade") assert auto_cycle.friendly_name.startswith("Auto-created Cycle by trade") # type: ignore[union-attr]
def test_create_trade_missing_required_fields(session: Session) -> None: def test_create_trade_missing_required_fields(session: Session) -> None:

View File

@@ -1,8 +1,9 @@
from __future__ import annotations from __future__ import annotations
from datetime import datetime, timedelta, timezone from datetime import datetime, timedelta, timezone
from typing import TYPE_CHECKING from typing import TYPE_CHECKING, Any, TypeVar, cast
from pydantic import BaseModel
from sqlalchemy.exc import IntegrityError from sqlalchemy.exc import IntegrityError
from sqlmodel import Session, select from sqlmodel import Session, select
@@ -10,9 +11,14 @@ from trading_journal import models
if TYPE_CHECKING: if TYPE_CHECKING:
from collections.abc import Mapping from collections.abc import Mapping
from enum import Enum
def _check_enum(enum_cls: any, value: any, field_name: str) -> any: # Generic enum member type
T = TypeVar("T", bound="Enum")
def _check_enum(enum_cls: type[T], value: object, field_name: str) -> T:
if value is None: if value is None:
raise ValueError(f"{field_name} is required") raise ValueError(f"{field_name} is required")
# already an enum member # already an enum member
@@ -27,19 +33,41 @@ def _check_enum(enum_cls: any, value: any, field_name: str) -> any:
raise ValueError(f"Invalid {field_name!s}: {value!r}. Allowed: {allowed}") raise ValueError(f"Invalid {field_name!s}: {value!r}. Allowed: {allowed}")
def _allowed_columns(model: type[models.SQLModel]) -> set[str]:
tbl = cast("models.SQLModel", model).__table__ # type: ignore[attr-defined]
return {c.name for c in tbl.columns}
AnyModel = Any
def _data_to_dict(data: AnyModel) -> dict[str, AnyModel]:
if isinstance(data, BaseModel):
return data.model_dump(exclude_unset=True)
if hasattr(data, "dict"):
return data.dict(exclude_unset=True)
return dict(data)
# Trades # Trades
def create_trade(session: Session, trade_data: Mapping) -> models.Trades: def create_trade(session: Session, trade_data: Mapping[str, Any] | BaseModel) -> models.Trades:
if hasattr(trade_data, "dict"): data = _data_to_dict(trade_data)
data = trade_data.dict(exclude_unset=True) allowed = _allowed_columns(models.Trades)
else:
data = dict(trade_data)
allowed = {c.name for c in models.Trades.__table__.columns}
payload = {k: v for k, v in data.items() if k in allowed} payload = {k: v for k, v in data.items() if k in allowed}
cycle_id = payload.get("cycle_id") cycle_id = payload.get("cycle_id")
if "symbol" not in payload: if "symbol" not in payload:
raise ValueError("symbol is required") raise ValueError("symbol is required")
if "exchange_id" not in payload and cycle_id is None: if "exchange_id" not in payload and cycle_id is None:
raise ValueError("exchange_id is required when no cycle is attached") raise ValueError("exchange_id is required when no cycle is attached")
# If an exchange_id is provided (and no cycle is attached), ensure the exchange exists
# and belongs to the same user as the trade (if user_id is provided).
if cycle_id is None and "exchange_id" in payload:
ex = session.get(models.Exchanges, payload["exchange_id"])
if ex is None:
raise ValueError("exchange_id does not exist")
user_id = payload.get("user_id")
if user_id is not None and ex.user_id != user_id:
raise ValueError("exchange.user_id does not match trade.user_id")
if "underlying_currency" not in payload: if "underlying_currency" not in payload:
raise ValueError("underlying_currency is required") raise ValueError("underlying_currency is required")
payload["underlying_currency"] = _check_enum(models.UnderlyingCurrency, payload["underlying_currency"], "underlying_currency") payload["underlying_currency"] = _check_enum(models.UnderlyingCurrency, payload["underlying_currency"], "underlying_currency")
@@ -132,7 +160,7 @@ def get_trades_by_user_id(session: Session, user_id: int) -> list[models.Trades]
statement = select(models.Trades).where( statement = select(models.Trades).where(
models.Trades.user_id == user_id, models.Trades.user_id == user_id,
) )
return session.exec(statement).all() return list(session.exec(statement).all())
def update_trade_note(session: Session, trade_id: int, note: str) -> models.Trades: def update_trade_note(session: Session, trade_id: int, note: str) -> models.Trades:
@@ -168,23 +196,17 @@ def invalidate_trade(session: Session, trade_id: int) -> models.Trades:
return trade return trade
def replace_trade(session: Session, old_trade_id: int, new_trade_data: Mapping) -> models.Trades: def replace_trade(session: Session, old_trade_id: int, new_trade_data: Mapping[str, Any] | BaseModel) -> models.Trades:
invalidate_trade(session, old_trade_id) invalidate_trade(session, old_trade_id)
if hasattr(new_trade_data, "dict"): data = _data_to_dict(new_trade_data)
data = new_trade_data.dict(exclude_unset=True)
else:
data = dict(new_trade_data)
data["replaced_by_trade_id"] = old_trade_id data["replaced_by_trade_id"] = old_trade_id
return create_trade(session, data) return create_trade(session, data)
# Cycles # Cycles
def create_cycle(session: Session, cycle_data: Mapping) -> models.Cycles: def create_cycle(session: Session, cycle_data: Mapping[str, Any] | BaseModel) -> models.Cycles:
if hasattr(cycle_data, "dict"): data = _data_to_dict(cycle_data)
data = cycle_data.dict(exclude_unset=True) allowed = _allowed_columns(models.Cycles)
else:
data = dict(cycle_data)
allowed = {c.name for c in models.Cycles.__table__.columns}
payload = {k: v for k, v in data.items() if k in allowed} payload = {k: v for k, v in data.items() if k in allowed}
if "user_id" not in payload: if "user_id" not in payload:
raise ValueError("user_id is required") raise ValueError("user_id is required")
@@ -192,6 +214,12 @@ def create_cycle(session: Session, cycle_data: Mapping) -> models.Cycles:
raise ValueError("symbol is required") raise ValueError("symbol is required")
if "exchange_id" not in payload: if "exchange_id" not in payload:
raise ValueError("exchange_id is required") raise ValueError("exchange_id is required")
# ensure the exchange exists and belongs to the same user
ex = session.get(models.Exchanges, payload["exchange_id"])
if ex is None:
raise ValueError("exchange_id does not exist")
if ex.user_id != payload.get("user_id"):
raise ValueError("exchange.user_id does not match cycle.user_id")
if "underlying_currency" not in payload: if "underlying_currency" not in payload:
raise ValueError("underlying_currency is required") raise ValueError("underlying_currency is required")
payload["underlying_currency"] = _check_enum(models.UnderlyingCurrency, payload["underlying_currency"], "underlying_currency") payload["underlying_currency"] = _check_enum(models.UnderlyingCurrency, payload["underlying_currency"], "underlying_currency")
@@ -215,21 +243,26 @@ def create_cycle(session: Session, cycle_data: Mapping) -> models.Cycles:
IMMUTABLE_CYCLE_FIELDS = {"id", "user_id", "start_date", "created_at"} IMMUTABLE_CYCLE_FIELDS = {"id", "user_id", "start_date", "created_at"}
def update_cycle(session: Session, cycle_id: int, update_data: Mapping) -> models.Cycles: def update_cycle(session: Session, cycle_id: int, update_data: Mapping[str, Any] | BaseModel) -> models.Cycles:
cycle: models.Cycles | None = session.get(models.Cycles, cycle_id) cycle: models.Cycles | None = session.get(models.Cycles, cycle_id)
if cycle is None: if cycle is None:
raise ValueError("cycle_id does not exist") raise ValueError("cycle_id does not exist")
if hasattr(update_data, "dict"): data = _data_to_dict(update_data)
data = update_data.dict(exclude_unset=True)
else:
data = dict(update_data)
allowed = {c.name for c in models.Cycles.__table__.columns} allowed = _allowed_columns(models.Cycles)
for k, v in data.items(): for k, v in data.items():
if k in IMMUTABLE_CYCLE_FIELDS: if k in IMMUTABLE_CYCLE_FIELDS:
raise ValueError(f"field {k!r} is immutable") raise ValueError(f"field {k!r} is immutable")
if k not in allowed: if k not in allowed:
continue continue
# If trying to change exchange_id, ensure the new exchange exists and belongs to
# the same user as the cycle.
if k == "exchange_id":
ex = session.get(models.Exchanges, v)
if ex is None:
raise ValueError("exchange_id does not exist")
if ex.user_id != cycle.user_id:
raise ValueError("exchange.user_id does not match cycle.user_id")
if k == "underlying_currency": if k == "underlying_currency":
v = _check_enum(models.UnderlyingCurrency, v, "underlying_currency") # noqa: PLW2901 v = _check_enum(models.UnderlyingCurrency, v, "underlying_currency") # noqa: PLW2901
if k == "status": if k == "status":
@@ -249,12 +282,9 @@ def update_cycle(session: Session, cycle_id: int, update_data: Mapping) -> model
IMMUTABLE_EXCHANGE_FIELDS = {"id"} IMMUTABLE_EXCHANGE_FIELDS = {"id"}
def create_exchange(session: Session, exchange_data: Mapping) -> models.Exchanges: def create_exchange(session: Session, exchange_data: Mapping[str, Any] | BaseModel) -> models.Exchanges:
if hasattr(exchange_data, "dict"): data = _data_to_dict(exchange_data)
data = exchange_data.dict(exclude_unset=True) allowed = _allowed_columns(models.Exchanges)
else:
data = dict(exchange_data)
allowed = {c.name for c in models.Exchanges.__table__.columns}
payload = {k: v for k, v in data.items() if k in allowed} payload = {k: v for k, v in data.items() if k in allowed}
if "name" not in payload: if "name" not in payload:
raise ValueError("name is required") raise ValueError("name is required")
@@ -284,25 +314,22 @@ def get_exchange_by_name_and_user_id(session: Session, name: str, user_id: int)
def get_all_exchanges(session: Session) -> list[models.Exchanges]: def get_all_exchanges(session: Session) -> list[models.Exchanges]:
statement = select(models.Exchanges) statement = select(models.Exchanges)
return session.exec(statement).all() return list(session.exec(statement).all())
def get_all_exchanges_by_user_id(session: Session, user_id: int) -> list[models.Exchanges]: def get_all_exchanges_by_user_id(session: Session, user_id: int) -> list[models.Exchanges]:
statement = select(models.Exchanges).where( statement = select(models.Exchanges).where(
models.Exchanges.user_id == user_id, models.Exchanges.user_id == user_id,
) )
return session.exec(statement).all() return list(session.exec(statement).all())
def update_exchange(session: Session, exchange_id: int, update_data: Mapping) -> models.Exchanges: def update_exchange(session: Session, exchange_id: int, update_data: Mapping[str, Any] | BaseModel) -> models.Exchanges:
exchange: models.Exchanges | None = session.get(models.Exchanges, exchange_id) exchange: models.Exchanges | None = session.get(models.Exchanges, exchange_id)
if exchange is None: if exchange is None:
raise ValueError("exchange_id does not exist") raise ValueError("exchange_id does not exist")
if hasattr(update_data, "dict"): data = _data_to_dict(update_data)
data = update_data.dict(exclude_unset=True) allowed = _allowed_columns(models.Exchanges)
else:
data = dict(update_data)
allowed = {c.name for c in models.Exchanges.__table__.columns}
for k, v in data.items(): for k, v in data.items():
if k in IMMUTABLE_EXCHANGE_FIELDS: if k in IMMUTABLE_EXCHANGE_FIELDS:
raise ValueError(f"field {k!r} is immutable") raise ValueError(f"field {k!r} is immutable")
@@ -334,12 +361,9 @@ def delete_exchange(session: Session, exchange_id: int) -> None:
IMMUTABLE_USER_FIELDS = {"id", "username", "created_at"} IMMUTABLE_USER_FIELDS = {"id", "username", "created_at"}
def create_user(session: Session, user_data: Mapping) -> models.Users: def create_user(session: Session, user_data: Mapping[str, Any] | BaseModel) -> models.Users:
if hasattr(user_data, "dict"): data = _data_to_dict(user_data)
data = user_data.dict(exclude_unset=True) allowed = _allowed_columns(models.Users)
else:
data = dict(user_data)
allowed = {c.name for c in models.Users.__table__.columns}
payload = {k: v for k, v in data.items() if k in allowed} payload = {k: v for k, v in data.items() if k in allowed}
if "username" not in payload: if "username" not in payload:
raise ValueError("username is required") raise ValueError("username is required")
@@ -368,15 +392,12 @@ def get_user_by_username(session: Session, username: str) -> models.Users | None
return session.exec(statement).first() return session.exec(statement).first()
def update_user(session: Session, user_id: int, update_data: Mapping) -> models.Users: def update_user(session: Session, user_id: int, update_data: Mapping[str, Any] | BaseModel) -> models.Users:
user: models.Users | None = session.get(models.Users, user_id) user: models.Users | None = session.get(models.Users, user_id)
if user is None: if user is None:
raise ValueError("user_id does not exist") raise ValueError("user_id does not exist")
if hasattr(update_data, "dict"): data = _data_to_dict(update_data)
data = update_data.dict(exclude_unset=True) allowed = _allowed_columns(models.Users)
else:
data = dict(update_data)
allowed = {c.name for c in models.Users.__table__.columns}
for k, v in data.items(): for k, v in data.items():
if k in IMMUTABLE_USER_FIELDS: if k in IMMUTABLE_USER_FIELDS:
raise ValueError(f"field {k!r} is immutable") raise ValueError(f"field {k!r} is immutable")
@@ -405,10 +426,11 @@ def create_login_session(
user: models.Users | None = session.get(models.Users, user_id) user: models.Users | None = session.get(models.Users, user_id)
if user is None: if user is None:
raise ValueError("user_id does not exist") raise ValueError("user_id does not exist")
user_id_val = cast("int", user.id)
now = datetime.now(timezone.utc) now = datetime.now(timezone.utc)
expires_at = now + timedelta(seconds=session_length_seconds) expires_at = now + timedelta(seconds=session_length_seconds)
s = models.Sessions( s = models.Sessions(
user_id=user.id, user_id=user_id_val,
session_token_hash=session_token_hash, session_token_hash=session_token_hash,
created_at=now, created_at=now,
expires_at=expires_at, expires_at=expires_at,
@@ -449,7 +471,7 @@ def get_login_session_by_token_hash(session: Session, session_token_hash: str) -
IMMUTABLE_SESSION_FIELDS = {"id", "user_id", "session_token_hash", "created_at"} IMMUTABLE_SESSION_FIELDS = {"id", "user_id", "session_token_hash", "created_at"}
def update_login_session(session: Session, session_token_hashed: str, update_session: Mapping) -> models.Sessions | None: def update_login_session(session: Session, session_token_hashed: str, update_session: Mapping[str, Any] | BaseModel) -> models.Sessions | None:
login_session: models.Sessions | None = session.exec( login_session: models.Sessions | None = session.exec(
select(models.Sessions).where( select(models.Sessions).where(
models.Sessions.session_token_hash == session_token_hashed, models.Sessions.session_token_hash == session_token_hashed,
@@ -458,11 +480,8 @@ def update_login_session(session: Session, session_token_hashed: str, update_ses
).first() ).first()
if login_session is None: if login_session is None:
return None return None
if hasattr(update_session, "dict"): data = _data_to_dict(update_session)
data = update_session.dict(exclude_unset=True) allowed = _allowed_columns(models.Sessions)
else:
data = dict(update_session)
allowed = {c.name for c in models.Sessions.__table__.columns}
for k, v in data.items(): for k, v in data.items():
if k in allowed and k not in IMMUTABLE_SESSION_FIELDS: if k in allowed and k not in IMMUTABLE_SESSION_FIELDS:
setattr(login_session, k, v) setattr(login_session, k, v)

View File

@@ -23,11 +23,11 @@ def _mig_0_1(engine: Engine) -> None:
SQLModel.metadata.create_all( SQLModel.metadata.create_all(
bind=engine, bind=engine,
tables=[ tables=[
models_v1.Trades.__table__, models_v1.Trades.__table__, # type: ignore[attr-defined]
models_v1.Cycles.__table__, models_v1.Cycles.__table__, # type: ignore[attr-defined]
models_v1.Users.__table__, models_v1.Users.__table__, # type: ignore[attr-defined]
models_v1.Sessions.__table__, models_v1.Sessions.__table__, # type: ignore[attr-defined]
models_v1.Exchanges.__table__, models_v1.Exchanges.__table__, # type: ignore[attr-defined]
], ],
) )

View File

@@ -64,7 +64,7 @@ class FundingSource(str, Enum):
class Trades(SQLModel, table=True): class Trades(SQLModel, table=True):
__tablename__ = "trades" __tablename__ = "trades" # type: ignore[attr-defined]
__table_args__ = (UniqueConstraint("user_id", "friendly_name", name="uq_trades_user_friendly_name"),) __table_args__ = (UniqueConstraint("user_id", "friendly_name", name="uq_trades_user_friendly_name"),)
id: int | None = Field(default=None, primary_key=True) id: int | None = Field(default=None, primary_key=True)
@@ -95,7 +95,7 @@ class Trades(SQLModel, table=True):
class Cycles(SQLModel, table=True): class Cycles(SQLModel, table=True):
__tablename__ = "cycles" __tablename__ = "cycles" # type: ignore[attr-defined]
__table_args__ = (UniqueConstraint("user_id", "friendly_name", name="uq_cycles_user_friendly_name"),) __table_args__ = (UniqueConstraint("user_id", "friendly_name", name="uq_cycles_user_friendly_name"),)
id: int | None = Field(default=None, primary_key=True) id: int | None = Field(default=None, primary_key=True)
@@ -116,7 +116,7 @@ class Cycles(SQLModel, table=True):
class Exchanges(SQLModel, table=True): class Exchanges(SQLModel, table=True):
__tablename__ = "exchanges" __tablename__ = "exchanges" # type: ignore[attr-defined]
__table_args__ = (UniqueConstraint("user_id", "name", name="uq_exchanges_user_name"),) __table_args__ = (UniqueConstraint("user_id", "name", name="uq_exchanges_user_name"),)
id: int | None = Field(default=None, primary_key=True) id: int | None = Field(default=None, primary_key=True)
user_id: int = Field(foreign_key="users.id", nullable=False, index=True) user_id: int = Field(foreign_key="users.id", nullable=False, index=True)
@@ -128,7 +128,7 @@ class Exchanges(SQLModel, table=True):
class Users(SQLModel, table=True): class Users(SQLModel, table=True):
__tablename__ = "users" __tablename__ = "users" # type: ignore[attr-defined]
id: int | None = Field(default=None, primary_key=True) id: int | None = Field(default=None, primary_key=True)
# unique=True already creates an index; no need to also set index=True # unique=True already creates an index; no need to also set index=True
username: str = Field(sa_column=Column(Text, nullable=False, unique=True)) username: str = Field(sa_column=Column(Text, nullable=False, unique=True))
@@ -139,7 +139,7 @@ class Users(SQLModel, table=True):
class Sessions(SQLModel, table=True): class Sessions(SQLModel, table=True):
__tablename__ = "sessions" __tablename__ = "sessions" # type: ignore[attr-defined]
id: int | None = Field(default=None, primary_key=True) id: int | None = Field(default=None, primary_key=True)
user_id: int = Field(foreign_key="users.id", nullable=False, index=True) user_id: int = Field(foreign_key="users.id", nullable=False, index=True)
session_token_hash: str = Field(sa_column=Column(Text, nullable=False, unique=True)) session_token_hash: str = Field(sa_column=Column(Text, nullable=False, unique=True))

View File

@@ -64,7 +64,7 @@ class FundingSource(str, Enum):
class Trades(SQLModel, table=True): class Trades(SQLModel, table=True):
__tablename__ = "trades" __tablename__ = "trades" # type: ignore[attr-defined]
__table_args__ = (UniqueConstraint("user_id", "friendly_name", name="uq_trades_user_friendly_name"),) __table_args__ = (UniqueConstraint("user_id", "friendly_name", name="uq_trades_user_friendly_name"),)
id: int | None = Field(default=None, primary_key=True) id: int | None = Field(default=None, primary_key=True)
@@ -95,7 +95,7 @@ class Trades(SQLModel, table=True):
class Cycles(SQLModel, table=True): class Cycles(SQLModel, table=True):
__tablename__ = "cycles" __tablename__ = "cycles" # type: ignore[attr-defined]
__table_args__ = (UniqueConstraint("user_id", "friendly_name", name="uq_cycles_user_friendly_name"),) __table_args__ = (UniqueConstraint("user_id", "friendly_name", name="uq_cycles_user_friendly_name"),)
id: int | None = Field(default=None, primary_key=True) id: int | None = Field(default=None, primary_key=True)
@@ -116,7 +116,7 @@ class Cycles(SQLModel, table=True):
class Exchanges(SQLModel, table=True): class Exchanges(SQLModel, table=True):
__tablename__ = "exchanges" __tablename__ = "exchanges" # type: ignore[attr-defined]
__table_args__ = (UniqueConstraint("user_id", "name", name="uq_exchanges_user_name"),) __table_args__ = (UniqueConstraint("user_id", "name", name="uq_exchanges_user_name"),)
id: int | None = Field(default=None, primary_key=True) id: int | None = Field(default=None, primary_key=True)
user_id: int = Field(foreign_key="users.id", nullable=False, index=True) user_id: int = Field(foreign_key="users.id", nullable=False, index=True)
@@ -128,7 +128,7 @@ class Exchanges(SQLModel, table=True):
class Users(SQLModel, table=True): class Users(SQLModel, table=True):
__tablename__ = "users" __tablename__ = "users" # type: ignore[attr-defined]
id: int | None = Field(default=None, primary_key=True) id: int | None = Field(default=None, primary_key=True)
# unique=True already creates an index; no need to also set index=True # unique=True already creates an index; no need to also set index=True
username: str = Field(sa_column=Column(Text, nullable=False, unique=True)) username: str = Field(sa_column=Column(Text, nullable=False, unique=True))
@@ -139,7 +139,7 @@ class Users(SQLModel, table=True):
class Sessions(SQLModel, table=True): class Sessions(SQLModel, table=True):
__tablename__ = "sessions" __tablename__ = "sessions" # type: ignore[attr-defined]
id: int | None = Field(default=None, primary_key=True) id: int | None = Field(default=None, primary_key=True)
user_id: int = Field(foreign_key="users.id", nullable=False, index=True) user_id: int = Field(foreign_key="users.id", nullable=False, index=True)
session_token_hash: str = Field(sa_column=Column(Text, nullable=False, unique=True)) session_token_hash: str = Field(sa_column=Column(Text, nullable=False, unique=True))

View File

@@ -2,21 +2,24 @@ from __future__ import annotations
import logging import logging
from datetime import datetime, timedelta, timezone from datetime import datetime, timedelta, timezone
from typing import Callable from typing import TYPE_CHECKING, cast
from fastapi import Request, Response, status from fastapi import Request, Response, status
from fastapi.responses import JSONResponse from fastapi.responses import JSONResponse
from sqlmodel import Session from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
from starlette.middleware.base import BaseHTTPMiddleware
import settings import settings
from trading_journal import crud, security from trading_journal import crud, security
from trading_journal.db import Database
from trading_journal.dto import ExchangesBase, ExchangesCreate, SessionsCreate, SessionsUpdate, UserCreate, UserLogin, UserRead from trading_journal.dto import ExchangesBase, ExchangesCreate, SessionsCreate, SessionsUpdate, UserCreate, UserLogin, UserRead
from trading_journal.models import Sessions
SessionsCreate.model_rebuild() SessionsCreate.model_rebuild()
if TYPE_CHECKING:
from sqlmodel import Session
from trading_journal.db import Database
from trading_journal.models import Sessions
EXCEPT_PATHS = [ EXCEPT_PATHS = [
f"{settings.settings.api_base}/status", f"{settings.settings.api_base}/status",
@@ -28,7 +31,7 @@ logger = logging.getLogger(__name__)
class AuthMiddleWare(BaseHTTPMiddleware): class AuthMiddleWare(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next: Callable[[Request], Response]) -> Response: # noqa: PLR0911 async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response: # noqa: PLR0911
if request.url.path in EXCEPT_PATHS: if request.url.path in EXCEPT_PATHS:
return await call_next(request) return await call_next(request)
@@ -51,12 +54,12 @@ class AuthMiddleWare(BaseHTTPMiddleware):
with db_factory.get_session_ctx_manager() as request_session: with db_factory.get_session_ctx_manager() as request_session:
hashed_token = security.hash_session_token_sha256(token) hashed_token = security.hash_session_token_sha256(token)
request.state.db_session = request_session request.state.db_session = request_session
login_session: Sessions | None = crud.get_login_session_by_token_hash(request.state.db_session, hashed_token) login_session: Sessions | None = crud.get_login_session_by_token_hash(request_session, hashed_token)
if not login_session: if not login_session:
return JSONResponse(status_code=status.HTTP_401_UNAUTHORIZED, content={"detail": "Unauthorized"}) return JSONResponse(status_code=status.HTTP_401_UNAUTHORIZED, content={"detail": "Unauthorized"})
session_expires_utc = login_session.expires_at.replace(tzinfo=timezone.utc) session_expires_utc = login_session.expires_at.replace(tzinfo=timezone.utc)
if session_expires_utc < datetime.now(timezone.utc): if session_expires_utc < datetime.now(timezone.utc):
crud.delete_login_session(request.state.db_session, login_session) crud.delete_login_session(request_session, login_session.session_token_hash)
return JSONResponse(status_code=status.HTTP_401_UNAUTHORIZED, content={"detail": "Unauthorized"}) return JSONResponse(status_code=status.HTTP_401_UNAUTHORIZED, content={"detail": "Unauthorized"})
if login_session.user.is_active is False: if login_session.user.is_active is False:
return JSONResponse(status_code=status.HTTP_401_UNAUTHORIZED, content={"detail": "Unauthorized"}) return JSONResponse(status_code=status.HTTP_401_UNAUTHORIZED, content={"detail": "Unauthorized"})
@@ -72,7 +75,7 @@ class AuthMiddleWare(BaseHTTPMiddleware):
) )
user_id = login_session.user_id user_id = login_session.user_id
request.state.user_id = user_id request.state.user_id = user_id
crud.update_login_session(request.state.db_session, hashed_token, update_session=updated_session) crud.update_login_session(request_session, hashed_token, update_session=updated_session)
except Exception: except Exception:
logger.exception("Failed to authenticate user: \n") logger.exception("Failed to authenticate user: \n")
return JSONResponse(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, content={"detail": "Internal server error"}) return JSONResponse(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, content={"detail": "Internal server error"})
@@ -106,7 +109,7 @@ def register_user_service(db_session: Session, user_in: UserCreate) -> UserRead:
# prefer pydantic's from_orm if DTO supports orm_mode # prefer pydantic's from_orm if DTO supports orm_mode
user = UserRead.model_validate(user) user = UserRead.model_validate(user)
except Exception as e: except Exception as e:
logger.exception("Failed to convert user to UserRead: %s", e) logger.exception("Failed to convert user to UserRead: ")
raise ServiceError("Failed to convert user to UserRead") from e raise ServiceError("Failed to convert user to UserRead") from e
except Exception as e: except Exception as e:
logger.exception("Failed to create user:") logger.exception("Failed to create user:")
@@ -118,6 +121,7 @@ def authenticate_user_service(db_session: Session, user_in: UserLogin) -> tuple[
user = crud.get_user_by_username(db_session, user_in.username) user = crud.get_user_by_username(db_session, user_in.username)
if not user: if not user:
return None return None
user_id_val = cast("int", user.id)
if not security.verify_password(user_in.password, user.password_hash): if not security.verify_password(user_in.password, user.password_hash):
return None return None
@@ -127,7 +131,7 @@ def authenticate_user_service(db_session: Session, user_in: UserLogin) -> tuple[
try: try:
session = crud.create_login_session( session = crud.create_login_session(
session=db_session, session=db_session,
user_id=user.id, user_id=user_id_val,
session_token_hash=token_hashed, session_token_hash=token_hashed,
session_length_seconds=settings.settings.session_expiry_seconds, session_length_seconds=settings.settings.session_expiry_seconds,
) )