several changes:
All checks were successful
Backend CI / unit-test (push) Successful in 34s

* api calls for auth

* exchange now bind to user
This commit is contained in:
2025-09-22 22:51:59 +02:00
parent 466e6ce653
commit 1750401278
14 changed files with 259 additions and 45 deletions

4
backend/.gitignore vendored
View File

@@ -14,4 +14,6 @@ __pycache__/
*.db *.db
*.db-shm *.db-shm
*.db-wal *.db-wal
devsettings.yaml

View File

@@ -1,16 +1,27 @@
from __future__ import annotations
import asyncio import asyncio
import logging
from collections.abc import AsyncGenerator from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from datetime import datetime, timezone
from fastapi import FastAPI, HTTPException, Request from fastapi import FastAPI, HTTPException, Request, status
from fastapi.responses import JSONResponse
import settings import settings
from trading_journal import db, service from trading_journal import db, service
from trading_journal.db import Database from trading_journal.db import Database
from trading_journal.dto import UserCreate, UserRead from trading_journal.dto import SessionsBase, SessionsCreate, UserCreate, UserLogin, UserRead
_db = db.create_database(settings.settings.database_url) _db = db.create_database(settings.settings.database_url)
logging.basicConfig(
level=logging.WARNING,
format="%(asctime)s %(levelname)s %(name)s: %(message)s",
)
logger = logging.getLogger(__name__)
@asynccontextmanager @asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: # noqa: ARG001 async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: # noqa: ARG001
@@ -40,13 +51,57 @@ async def register_user(request: Request, user_in: UserCreate) -> UserRead:
return service.register_user_service(db, user_in) return service.register_user_service(db, user_in)
try: try:
return await asyncio.to_thread(sync_work) user = await asyncio.to_thread(sync_work)
return JSONResponse(status_code=status.HTTP_201_CREATED, content=user.model_dump())
except service.UserAlreadyExistsError as e: except service.UserAlreadyExistsError as e:
raise HTTPException(status_code=400, detail=str(e)) from e raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) from e
except Exception as e: except Exception as e:
raise HTTPException(status_code=500, detail="Internal server error" + str(e)) from e logger.exception("Failed to register user: \n")
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Internal server error") from e
@app.post(f"{settings.settings.api_base}/login")
async def login(request: Request, user_in: UserLogin) -> SessionsBase:
db_factory: Database = request.app.state.db_factory
def sync_work() -> tuple[SessionsCreate, str] | None:
with db_factory.get_session_ctx_manager() as db:
return service.authenticate_user_service(db, user_in)
try:
result = await asyncio.to_thread(sync_work)
if result is None:
return JSONResponse(
status_code=status.HTTP_401_UNAUTHORIZED,
content={"detail": "Invalid username or password, or user doesn't exist"},
)
session, token = result
session_return = SessionsBase(user_id=session.user_id)
response = JSONResponse(status_code=status.HTTP_200_OK, content=session_return.model_dump())
expires_sec = int((session.expires_at.replace(tzinfo=timezone.utc) - datetime.now(timezone.utc)).total_seconds())
response.set_cookie(
key="session_token",
value=token,
httponly=True,
secure=True,
samesite="lax",
max_age=expires_sec,
path="/",
)
except Exception as e:
logger.exception("Failed to login user: \n")
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Internal server error") from e
else:
return response
# Exchange
# @app.post(f"{settings.settings.api_base}/exchanges")
# async def create_exchange(request: Request, name: str, notes: str | None) -> dict:
@app.get(f"{settings.settings.api_base}/trades") @app.get(f"{settings.settings.api_base}/trades")
async def get_trades() -> dict[str, str]: async def get_trades(request: Request) -> list:
return {"trades": []} db_factory: Database = request.app.state.db_factory
with db_factory.get_session_ctx_manager() as db:
return service.get_trades_service(db, request.state.user_id)

View File

@@ -16,6 +16,7 @@ class Settings(BaseSettings):
log_level: str = "info" log_level: str = "info"
database_url: str = "sqlite:///:memory:" database_url: str = "sqlite:///:memory:"
api_base: str = "/api/v1" api_base: str = "/api/v1"
session_expiry_seconds: int = 3600 * 24 * 7 # 7 days
hmac_key: str | None = None hmac_key: str | None = None
model_config = ConfigDict(env_file=".env", env_file_encoding="utf-8") model_config = ConfigDict(env_file=".env", env_file_encoding="utf-8")

View File

@@ -48,8 +48,8 @@ def make_user(session: Session, username: str = "testuser") -> int:
return user.id return user.id
def make_exchange(session: Session, name: str = "NASDAQ") -> int: def make_exchange(session: Session, user_id: int, name: str = "NASDAQ") -> int:
exchange = models.Exchanges(name=name, notes="Test exchange") exchange = models.Exchanges(user_id=user_id, name=name, notes="Test exchange")
session.add(exchange) session.add(exchange)
session.commit() session.commit()
session.refresh(exchange) session.refresh(exchange)
@@ -138,7 +138,7 @@ def _ensure_utc_aware(dt: datetime) -> datetime | None:
def test_create_trade_success_with_cycle(session: Session) -> None: def test_create_trade_success_with_cycle(session: Session) -> None:
user_id = make_user(session) user_id = make_user(session)
exchange_id = make_exchange(session) exchange_id = make_exchange(session, user_id)
cycle_id = make_cycle(session, user_id, exchange_id) cycle_id = make_cycle(session, user_id, exchange_id)
trade_data = { trade_data = {
@@ -180,7 +180,7 @@ def test_create_trade_success_with_cycle(session: Session) -> None:
def test_create_trade_with_auto_created_cycle(session: Session) -> None: def test_create_trade_with_auto_created_cycle(session: Session) -> None:
user_id = make_user(session) user_id = make_user(session)
exchange_id = make_exchange(session) exchange_id = make_exchange(session, user_id)
trade_data = { trade_data = {
"user_id": user_id, "user_id": user_id,
@@ -224,7 +224,7 @@ def test_create_trade_with_auto_created_cycle(session: Session) -> None:
def test_create_trade_missing_required_fields(session: Session) -> None: def test_create_trade_missing_required_fields(session: Session) -> None:
user_id = make_user(session) user_id = make_user(session)
exchange_id = make_exchange(session) exchange_id = make_exchange(session, user_id)
base_trade_data = { base_trade_data = {
"user_id": user_id, "user_id": user_id,
@@ -291,7 +291,7 @@ def test_create_trade_missing_required_fields(session: Session) -> None:
def test_get_trade_by_id(session: Session) -> None: def test_get_trade_by_id(session: Session) -> None:
user_id = make_user(session) user_id = make_user(session)
exchange_id = make_exchange(session) exchange_id = make_exchange(session, user_id)
cycle_id = make_cycle(session, user_id, exchange_id) cycle_id = make_cycle(session, user_id, exchange_id)
trade_data = { trade_data = {
"user_id": user_id, "user_id": user_id,
@@ -330,7 +330,7 @@ def test_get_trade_by_id(session: Session) -> None:
def test_get_trade_by_user_id_and_friendly_name(session: Session) -> None: def test_get_trade_by_user_id_and_friendly_name(session: Session) -> None:
user_id = make_user(session) user_id = make_user(session)
exchange_id = make_exchange(session) exchange_id = make_exchange(session, user_id)
cycle_id = make_cycle(session, user_id, exchange_id) cycle_id = make_cycle(session, user_id, exchange_id)
friendly_name = "Unique Trade Name" friendly_name = "Unique Trade Name"
trade_data = { trade_data = {
@@ -359,7 +359,7 @@ def test_get_trade_by_user_id_and_friendly_name(session: Session) -> None:
def test_get_trades_by_user_id(session: Session) -> None: def test_get_trades_by_user_id(session: Session) -> None:
user_id = make_user(session) user_id = make_user(session)
exchange_id = make_exchange(session) exchange_id = make_exchange(session, user_id)
cycle_id = make_cycle(session, user_id, exchange_id) cycle_id = make_cycle(session, user_id, exchange_id)
trade_data_1 = { trade_data_1 = {
"user_id": user_id, "user_id": user_id,
@@ -406,7 +406,7 @@ def test_get_trades_by_user_id(session: Session) -> None:
def test_update_trade_note(session: Session) -> None: def test_update_trade_note(session: Session) -> None:
user_id = make_user(session) user_id = make_user(session)
exchange_id = make_exchange(session) exchange_id = make_exchange(session, user_id)
cycle_id = make_cycle(session, user_id, exchange_id) cycle_id = make_cycle(session, user_id, exchange_id)
trade_id = make_trade(session, user_id, cycle_id) trade_id = make_trade(session, user_id, cycle_id)
@@ -424,7 +424,7 @@ def test_update_trade_note(session: Session) -> None:
def test_invalidate_trade(session: Session) -> None: def test_invalidate_trade(session: Session) -> None:
user_id = make_user(session) user_id = make_user(session)
exchange_id = make_exchange(session) exchange_id = make_exchange(session, user_id)
cycle_id = make_cycle(session, user_id, exchange_id) cycle_id = make_cycle(session, user_id, exchange_id)
trade_id = make_trade(session, user_id, cycle_id) trade_id = make_trade(session, user_id, cycle_id)
@@ -441,7 +441,7 @@ def test_invalidate_trade(session: Session) -> None:
def test_replace_trade(session: Session) -> None: def test_replace_trade(session: Session) -> None:
user_id = make_user(session) user_id = make_user(session)
exchange_id = make_exchange(session) exchange_id = make_exchange(session, user_id)
cycle_id = make_cycle(session, user_id, exchange_id) cycle_id = make_cycle(session, user_id, exchange_id)
old_trade_id = make_trade(session, user_id, cycle_id) old_trade_id = make_trade(session, user_id, cycle_id)
@@ -486,7 +486,7 @@ def test_replace_trade(session: Session) -> None:
def test_create_cycle(session: Session) -> None: def test_create_cycle(session: Session) -> None:
user_id = make_user(session) user_id = make_user(session)
exchange_id = make_exchange(session) exchange_id = make_exchange(session, user_id)
cycle_data = { cycle_data = {
"user_id": user_id, "user_id": user_id,
"friendly_name": "My First Cycle", "friendly_name": "My First Cycle",
@@ -517,7 +517,7 @@ def test_create_cycle(session: Session) -> None:
def test_update_cycle(session: Session) -> None: def test_update_cycle(session: Session) -> None:
user_id = make_user(session) user_id = make_user(session)
exchange_id = make_exchange(session) exchange_id = make_exchange(session, user_id)
cycle_id = make_cycle(session, user_id, exchange_id, friendly_name="Initial Cycle Name") cycle_id = make_cycle(session, user_id, exchange_id, friendly_name="Initial Cycle Name")
update_data = { update_data = {
@@ -539,7 +539,7 @@ def test_update_cycle(session: Session) -> None:
def test_update_cycle_immutable_fields(session: Session) -> None: def test_update_cycle_immutable_fields(session: Session) -> None:
user_id = make_user(session) user_id = make_user(session)
exchange_id = make_exchange(session) exchange_id = make_exchange(session, user_id)
cycle_id = make_cycle(session, user_id, exchange_id, friendly_name="Initial Cycle Name") cycle_id = make_cycle(session, user_id, exchange_id, friendly_name="Initial Cycle Name")
# Attempt to update immutable fields # Attempt to update immutable fields
@@ -563,42 +563,51 @@ def test_update_cycle_immutable_fields(session: Session) -> None:
# Exchanges # Exchanges
def test_create_exchange(session: Session) -> None: def test_create_exchange(session: Session) -> None:
user_id = make_user(session)
exchange_data = { exchange_data = {
"name": "NYSE", "name": "NYSE",
"notes": "New York Stock Exchange", "notes": "New York Stock Exchange",
"user_id": user_id,
} }
exchange = crud.create_exchange(session, exchange_data) exchange = crud.create_exchange(session, exchange_data)
assert exchange.id is not None assert exchange.id is not None
assert exchange.name == exchange_data["name"] assert exchange.name == exchange_data["name"]
assert exchange.notes == exchange_data["notes"] assert exchange.notes == exchange_data["notes"]
assert exchange.user_id == user_id
session.refresh(exchange) session.refresh(exchange)
actual_exchange = session.get(models.Exchanges, exchange.id) actual_exchange = session.get(models.Exchanges, exchange.id)
assert actual_exchange is not None assert actual_exchange is not None
assert actual_exchange.name == exchange_data["name"] assert actual_exchange.name == exchange_data["name"]
assert actual_exchange.notes == exchange_data["notes"] assert actual_exchange.notes == exchange_data["notes"]
assert actual_exchange.user_id == user_id
def test_get_exchange_by_id(session: Session) -> None: def test_get_exchange_by_id(session: Session) -> None:
exchange_id = make_exchange(session, name="LSE") user_id = make_user(session)
exchange_id = make_exchange(session, user_id=user_id, name="LSE")
exchange = crud.get_exchange_by_id(session, exchange_id) exchange = crud.get_exchange_by_id(session, exchange_id)
assert exchange is not None assert exchange is not None
assert exchange.id == exchange_id assert exchange.id == exchange_id
assert exchange.name == "LSE" assert exchange.name == "LSE"
assert exchange.user_id == user_id
def test_get_exchange_by_name(session: Session) -> None: def test_get_exchange_by_name_and_user_id(session: Session) -> None:
exchange_name = "TSX" exchange_name = "TSX"
make_exchange(session, name=exchange_name) user_id = make_user(session)
exchange = crud.get_exchange_by_name(session, exchange_name) make_exchange(session, user_id=user_id, name=exchange_name)
exchange = crud.get_exchange_by_name_and_user_id(session, exchange_name, user_id)
assert exchange is not None assert exchange is not None
assert exchange.name == exchange_name assert exchange.name == exchange_name
assert exchange.user_id == user_id
def test_get_all_exchanges(session: Session) -> None: def test_get_all_exchanges(session: Session) -> None:
exchange_names = ["NYSE", "NASDAQ", "LSE"] exchange_names = ["NYSE", "NASDAQ", "LSE"]
user_id = make_user(session)
for name in exchange_names: for name in exchange_names:
make_exchange(session, name=name) make_exchange(session, user_id=user_id, name=name)
exchanges = crud.get_all_exchanges(session) exchanges = crud.get_all_exchanges(session)
assert len(exchanges) >= 3 assert len(exchanges) >= 3
@@ -607,8 +616,22 @@ def test_get_all_exchanges(session: Session) -> None:
assert name in fetched_names assert name in fetched_names
def test_get_all_exchanges_by_user_id(session: Session) -> None:
exchange_names = ["NYSE", "NASDAQ"]
user_id = make_user(session)
for name in exchange_names:
make_exchange(session, user_id=user_id, name=name)
exchanges = crud.get_all_exchanges_by_user_id(session, user_id)
assert len(exchanges) == len(exchange_names)
fetched_names = {ex.name for ex in exchanges}
for name in exchange_names:
assert name in fetched_names
def test_update_exchange(session: Session) -> None: def test_update_exchange(session: Session) -> None:
exchange_id = make_exchange(session, name="Initial Exchange") user_id = make_user(session)
exchange_id = make_exchange(session, user_id=user_id, name="Initial Exchange")
update_data = { update_data = {
"name": "Updated Exchange", "name": "Updated Exchange",
"notes": "Updated notes for the exchange", "notes": "Updated notes for the exchange",
@@ -627,7 +650,8 @@ def test_update_exchange(session: Session) -> None:
def test_delete_exchange(session: Session) -> None: def test_delete_exchange(session: Session) -> None:
exchange_id = make_exchange(session, name="Deletable Exchange") user_id = make_user(session)
exchange_id = make_exchange(session, user_id=user_id, name="Deletable Exchange")
crud.delete_exchange(session, exchange_id) crud.delete_exchange(session, exchange_id)
deleted_exchange = session.get(models.Exchanges, exchange_id) deleted_exchange = session.get(models.Exchanges, exchange_id)
assert deleted_exchange is None assert deleted_exchange is None

View File

@@ -70,6 +70,12 @@ def test_run_migrations_0_to_1(monkeypatch: pytest.MonkeyPatch) -> None:
"notes": ("TEXT", 0, 0), "notes": ("TEXT", 0, 0),
"cycle_id": ("INTEGER", 0, 0), "cycle_id": ("INTEGER", 0, 0),
}, },
"exchanges": {
"id": ("INTEGER", 1, 1),
"user_id": ("INTEGER", 1, 0),
"name": ("TEXT", 1, 0),
"notes": ("TEXT", 0, 0),
},
"sessions": { "sessions": {
"id": ("INTEGER", 1, 1), "id": ("INTEGER", 1, 1),
"user_id": ("INTEGER", 1, 0), "user_id": ("INTEGER", 1, 0),
@@ -97,7 +103,9 @@ def test_run_migrations_0_to_1(monkeypatch: pytest.MonkeyPatch) -> None:
{"table": "users", "from": "user_id", "to": "id"}, {"table": "users", "from": "user_id", "to": "id"},
], ],
"users": [], "users": [],
"exchanges": [], "exchanges": [
{"table": "users", "from": "user_id", "to": "id"},
],
} }
with engine.connect() as conn: with engine.connect() as conn:

View File

@@ -274,9 +274,10 @@ def get_exchange_by_id(session: Session, exchange_id: int) -> models.Exchanges |
return session.get(models.Exchanges, exchange_id) return session.get(models.Exchanges, exchange_id)
def get_exchange_by_name(session: Session, name: str) -> models.Exchanges | None: def get_exchange_by_name_and_user_id(session: Session, name: str, user_id: int) -> models.Exchanges | None:
statement = select(models.Exchanges).where( statement = select(models.Exchanges).where(
models.Exchanges.name == name, models.Exchanges.name == name,
models.Exchanges.user_id == user_id,
) )
return session.exec(statement).first() return session.exec(statement).first()
@@ -286,6 +287,13 @@ def get_all_exchanges(session: Session) -> list[models.Exchanges]:
return session.exec(statement).all() return session.exec(statement).all()
def get_all_exchanges_by_user_id(session: Session, user_id: int) -> list[models.Exchanges]:
statement = select(models.Exchanges).where(
models.Exchanges.user_id == user_id,
)
return session.exec(statement).all()
def update_exchange(session: Session, exchange_id: int, update_data: Mapping) -> models.Exchanges: def update_exchange(session: Session, exchange_id: int, update_data: Mapping) -> models.Exchanges:
exchange: models.Exchanges | None = session.get(models.Exchanges, exchange_id) exchange: models.Exchanges | None = session.get(models.Exchanges, exchange_id)
if exchange is None: if exchange is None:

View File

@@ -8,8 +8,6 @@ from sqlalchemy import event
from sqlalchemy.pool import StaticPool from sqlalchemy.pool import StaticPool
from sqlmodel import Session, create_engine from sqlmodel import Session, create_engine
from trading_journal import db_migration
if TYPE_CHECKING: if TYPE_CHECKING:
from collections.abc import Generator from collections.abc import Generator
from sqlite3 import Connection as DBAPIConnection from sqlite3 import Connection as DBAPIConnection
@@ -59,7 +57,6 @@ class Database:
event.listen(self._engine, "connect", _enable_sqlite_pragmas) event.listen(self._engine, "connect", _enable_sqlite_pragmas)
def init_db(self) -> None: def init_db(self) -> None:
# db_migration.run_migrations(self._engine)
pass pass
def get_session(self) -> Generator[Session, None, None]: def get_session(self) -> Generator[Session, None, None]:
@@ -74,7 +71,7 @@ class Database:
session.close() session.close()
@contextmanager @contextmanager
def get_session_ctx_manager(self) -> Session: def get_session_ctx_manager(self) -> Generator[Session, None, None]:
session = Session(self._engine) session = Session(self._engine)
try: try:
yield session yield session

View File

@@ -27,6 +27,7 @@ def _mig_0_1(engine: Engine) -> None:
models_v1.Cycles.__table__, models_v1.Cycles.__table__,
models_v1.Users.__table__, models_v1.Users.__table__,
models_v1.Sessions.__table__, models_v1.Sessions.__table__,
models_v1.Exchanges.__table__,
], ],
) )

View File

@@ -1,12 +1,12 @@
from __future__ import annotations from __future__ import annotations
from datetime import date, datetime # noqa: TC003
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from pydantic import BaseModel
from sqlmodel import SQLModel from sqlmodel import SQLModel
if TYPE_CHECKING: if TYPE_CHECKING:
from datetime import date, datetime
from trading_journal.models import TradeStrategy, TradeType, UnderlyingCurrency from trading_journal.models import TradeStrategy, TradeType, UnderlyingCurrency
@@ -52,5 +52,33 @@ class UserCreate(UserBase):
password: str password: str
class UserLogin(BaseModel):
username: str
password: str
class UserRead(UserBase): class UserRead(UserBase):
id: int id: int
class SessionsBase(SQLModel):
user_id: int
class SessionRead(SessionsBase):
id: int
expires_at: datetime
last_seen_at: datetime | None
last_used_ip: str | None
user_agent: str | None
class SessionsCreate(SessionsBase):
expires_at: datetime
class SessionsUpdate(SQLModel):
expires_at: datetime | None = None
last_seen_at: datetime | None = None
last_used_ip: str | None = None
user_agent: str | None = None

View File

@@ -117,11 +117,14 @@ class Cycles(SQLModel, table=True):
class Exchanges(SQLModel, table=True): class Exchanges(SQLModel, table=True):
__tablename__ = "exchanges" __tablename__ = "exchanges"
__table_args__ = (UniqueConstraint("user_id", "name", name="uq_exchanges_user_name"),)
id: int | None = Field(default=None, primary_key=True) id: int | None = Field(default=None, primary_key=True)
name: str = Field(sa_column=Column(Text, nullable=False, unique=True)) user_id: int = Field(foreign_key="users.id", nullable=False, index=True)
name: str = Field(sa_column=Column(Text, nullable=False))
notes: str | None = Field(default=None, sa_column=Column(Text, nullable=True)) notes: str | None = Field(default=None, sa_column=Column(Text, nullable=True))
trades: list["Trades"] = Relationship(back_populates="exchange") trades: list["Trades"] = Relationship(back_populates="exchange")
cycles: list["Cycles"] = Relationship(back_populates="exchange") cycles: list["Cycles"] = Relationship(back_populates="exchange")
user: "Users" = Relationship(back_populates="exchanges")
class Users(SQLModel, table=True): class Users(SQLModel, table=True):
@@ -131,6 +134,8 @@ class Users(SQLModel, table=True):
username: str = Field(sa_column=Column(Text, nullable=False, unique=True)) username: str = Field(sa_column=Column(Text, nullable=False, unique=True))
password_hash: str = Field(sa_column=Column(Text, nullable=False)) password_hash: str = Field(sa_column=Column(Text, nullable=False))
is_active: bool = Field(default=True, nullable=False) is_active: bool = Field(default=True, nullable=False)
sessions: list["Sessions"] = Relationship(back_populates="user")
exchanges: list["Exchanges"] = Relationship(back_populates="user")
class Sessions(SQLModel, table=True): class Sessions(SQLModel, table=True):
@@ -144,3 +149,4 @@ class Sessions(SQLModel, table=True):
last_used_ip: str | None = Field(default=None, sa_column=Column(Text, nullable=True)) last_used_ip: str | None = Field(default=None, sa_column=Column(Text, nullable=True))
user_agent: str | None = Field(default=None, sa_column=Column(Text, nullable=True)) user_agent: str | None = Field(default=None, sa_column=Column(Text, nullable=True))
device_name: str | None = Field(default=None, sa_column=Column(Text, nullable=True)) device_name: str | None = Field(default=None, sa_column=Column(Text, nullable=True))
user: "Users" = Relationship(back_populates="sessions")

View File

@@ -117,11 +117,14 @@ class Cycles(SQLModel, table=True):
class Exchanges(SQLModel, table=True): class Exchanges(SQLModel, table=True):
__tablename__ = "exchanges" __tablename__ = "exchanges"
__table_args__ = (UniqueConstraint("user_id", "name", name="uq_exchanges_user_name"),)
id: int | None = Field(default=None, primary_key=True) id: int | None = Field(default=None, primary_key=True)
name: str = Field(sa_column=Column(Text, nullable=False, unique=True)) user_id: int = Field(foreign_key="users.id", nullable=False, index=True)
name: str = Field(sa_column=Column(Text, nullable=False))
notes: str | None = Field(default=None, sa_column=Column(Text, nullable=True)) notes: str | None = Field(default=None, sa_column=Column(Text, nullable=True))
trades: list["Trades"] = Relationship(back_populates="exchange") trades: list["Trades"] = Relationship(back_populates="exchange")
cycles: list["Cycles"] = Relationship(back_populates="exchange") cycles: list["Cycles"] = Relationship(back_populates="exchange")
user: "Users" = Relationship(back_populates="exchanges")
class Users(SQLModel, table=True): class Users(SQLModel, table=True):
@@ -131,6 +134,8 @@ class Users(SQLModel, table=True):
username: str = Field(sa_column=Column(Text, nullable=False, unique=True)) username: str = Field(sa_column=Column(Text, nullable=False, unique=True))
password_hash: str = Field(sa_column=Column(Text, nullable=False)) password_hash: str = Field(sa_column=Column(Text, nullable=False))
is_active: bool = Field(default=True, nullable=False) is_active: bool = Field(default=True, nullable=False)
sessions: list["Sessions"] = Relationship(back_populates="user")
exchanges: list["Exchanges"] = Relationship(back_populates="user")
class Sessions(SQLModel, table=True): class Sessions(SQLModel, table=True):
@@ -144,3 +149,4 @@ class Sessions(SQLModel, table=True):
last_used_ip: str | None = Field(default=None, sa_column=Column(Text, nullable=True)) last_used_ip: str | None = Field(default=None, sa_column=Column(Text, nullable=True))
user_agent: str | None = Field(default=None, sa_column=Column(Text, nullable=True)) user_agent: str | None = Field(default=None, sa_column=Column(Text, nullable=True))
device_name: str | None = Field(default=None, sa_column=Column(Text, nullable=True)) device_name: str | None = Field(default=None, sa_column=Column(Text, nullable=True))
user: "Users" = Relationship(back_populates="sessions")

View File

@@ -1,3 +1,7 @@
from __future__ import annotations
import logging
from datetime import datetime, timedelta, timezone
from typing import Callable from typing import Callable
from fastapi import Request, Response, status from fastapi import Request, Response, status
@@ -8,17 +12,23 @@ from starlette.middleware.base import BaseHTTPMiddleware
import settings import settings
from trading_journal import crud, security from trading_journal import crud, security
from trading_journal.db import Database from trading_journal.db import Database
from trading_journal.dto import UserCreate, UserRead from trading_journal.dto import SessionsCreate, SessionsUpdate, UserCreate, UserLogin, UserRead
from trading_journal.models import Sessions from trading_journal.models import Sessions
SessionsCreate.model_rebuild()
EXCEPT_PATHS = [ EXCEPT_PATHS = [
f"{settings.settings.api_base}/status", f"{settings.settings.api_base}/status",
f"{settings.settings.api_base}/register", f"{settings.settings.api_base}/register",
f"{settings.settings.api_base}/login",
] ]
logger = logging.getLogger(__name__)
class AuthMiddleWare(BaseHTTPMiddleware): class AuthMiddleWare(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next: Callable[[Request], Response]) -> Response: async def dispatch(self, request: Request, call_next: Callable[[Request], Response]) -> Response: # noqa: PLR0911
if request.url.path in EXCEPT_PATHS: if request.url.path in EXCEPT_PATHS:
return await call_next(request) return await call_next(request)
@@ -42,10 +52,32 @@ class AuthMiddleWare(BaseHTTPMiddleware):
hashed_token = security.hash_session_token_sha256(token) hashed_token = security.hash_session_token_sha256(token)
request.state.db_session = request_session request.state.db_session = request_session
login_session: Sessions | None = crud.get_login_session_by_token_hash(request.state.db_session, hashed_token) login_session: Sessions | None = crud.get_login_session_by_token_hash(request.state.db_session, hashed_token)
except Exception: # noqa: BLE001 if not login_session:
return JSONResponse(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, content={"detail": "db error"}) return JSONResponse(status_code=status.HTTP_401_UNAUTHORIZED, content={"detail": "Unauthorized"})
session_expires_utc = login_session.expires_at.replace(tzinfo=timezone.utc)
if session_expires_utc < datetime.now(timezone.utc):
crud.delete_login_session(request.state.db_session, login_session)
return JSONResponse(status_code=status.HTTP_401_UNAUTHORIZED, content={"detail": "Unauthorized"})
if login_session.user.is_active is False:
return JSONResponse(status_code=status.HTTP_401_UNAUTHORIZED, content={"detail": "Unauthorized"})
if session_expires_utc - datetime.now(timezone.utc) < timedelta(seconds=3600):
updated_expiry = datetime.now(timezone.utc) + timedelta(seconds=settings.settings.session_expiry_seconds)
else:
updated_expiry = session_expires_utc
updated_session: SessionsUpdate = SessionsUpdate(
last_seen_at=datetime.now(timezone.utc),
last_used_ip=request.client.host if request.client else None,
user_agent=request.headers.get("User-Agent"),
expires_at=updated_expiry,
)
user_id = login_session.user_id
request.state.user_id = user_id
crud.update_login_session(request.state.db_session, hashed_token, update_session=updated_session)
except Exception:
logger.exception("Failed to authenticate user: \n")
return JSONResponse(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, content={"detail": "Internal server error"})
return None return await call_next(request)
class ServiceError(Exception): class ServiceError(Exception):
@@ -60,13 +92,46 @@ def register_user_service(db_session: Session, user_in: UserCreate) -> UserRead:
if crud.get_user_by_username(db_session, user_in.username): if crud.get_user_by_username(db_session, user_in.username):
raise UserAlreadyExistsError("username already exists") raise UserAlreadyExistsError("username already exists")
hashed = security.hash_password(user_in.password) hashed = security.hash_password(user_in.password)
user_data: dict = {
"username": user_in.username,
"password_hash": hashed,
}
try: try:
user = crud.create_user(db_session, username=user_in.username, hashed_password=hashed) user = crud.create_user(db_session, user_data=user_data)
try: try:
# prefer pydantic's from_orm if DTO supports orm_mode # prefer pydantic's from_orm if DTO supports orm_mode
user = UserRead.model_validate(user) user = UserRead.model_validate(user)
except Exception as e: except Exception as e:
logger.exception("Failed to convert user to UserRead: %s", e)
raise ServiceError("Failed to convert user to UserRead") from e raise ServiceError("Failed to convert user to UserRead") from e
except Exception as e: except Exception as e:
logger.exception("Failed to create user:")
raise ServiceError("Failed to create user") from e raise ServiceError("Failed to create user") from e
return user return user
def authenticate_user_service(db_session: Session, user_in: UserLogin) -> tuple[SessionsCreate, str] | None:
user = crud.get_user_by_username(db_session, user_in.username)
if not user:
return None
if not security.verify_password(user_in.password, user.password_hash):
return None
token = security.generate_session_token()
token_hashed = security.hash_session_token_sha256(token)
try:
session = crud.create_login_session(
session=db_session,
user_id=user.id,
session_token_hash=token_hashed,
session_length_seconds=settings.settings.session_expiry_seconds,
)
except Exception as e:
logger.exception("Failed to create login session: \n")
raise ServiceError("Failed to create login session") from e
return SessionsCreate.model_validate(session), token
def get_trades_service(db_session: Session, user_id: int) -> list:
return crud.get_trades_by_user_id(db_session, user_id)

View File

View File

@@ -0,0 +1,13 @@
import sys
from pathlib import Path
from sqlmodel import create_engine
project_parent = Path(__file__).resolve().parents[1]
sys.path.insert(0, str(project_parent))
import settings # noqa: E402
from trading_journal import db_migration # noqa: E402
db_engine = create_engine(settings.settings.database_url, echo=True)
db_migration.run_migrations(db_engine)