Compare commits

...

3 Commits

Author SHA1 Message Date
466e6ce653 wip user reg
All checks were successful
Backend CI / unit-test (push) Successful in 34s
2025-09-22 17:35:10 +02:00
e70a63e4f9 add security py 2025-09-22 14:54:29 +02:00
76ed38e9af add crud for exchange 2025-09-22 14:39:33 +02:00
12 changed files with 384 additions and 24 deletions

View File

@@ -2,13 +2,12 @@ import asyncio
from collections.abc import AsyncGenerator from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from fastapi import FastAPI, status from fastapi import FastAPI, HTTPException, Request
import settings import settings
from trading_journal import db from trading_journal import db, service
from trading_journal.dto import TradeCreate, TradeRead from trading_journal.db import Database
from trading_journal.dto import UserCreate, UserRead
API_BASE = "/api/v1"
_db = db.create_database(settings.settings.database_url) _db = db.create_database(settings.settings.database_url)
@@ -23,8 +22,31 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: # noqa: ARG001
app = FastAPI(lifespan=lifespan) app = FastAPI(lifespan=lifespan)
app.add_middleware(service.AuthMiddleWare)
app.state.db_factory = _db
@app.get(f"{API_BASE}/status") @app.get(f"{settings.settings.api_base}/status")
async def get_status() -> dict[str, str]: async def get_status() -> dict[str, str]:
return {"status": "ok"} return {"status": "ok"}
@app.post(f"{settings.settings.api_base}/register")
async def register_user(request: Request, user_in: UserCreate) -> UserRead:
db_factory: Database = request.app.state.db_factory
def sync_work() -> UserRead:
with db_factory.get_session_ctx_manager() as db:
return service.register_user_service(db, user_in)
try:
return await asyncio.to_thread(sync_work)
except service.UserAlreadyExistsError as e:
raise HTTPException(status_code=400, detail=str(e)) from e
except Exception as e:
raise HTTPException(status_code=500, detail="Internal server error" + str(e)) from e
@app.get(f"{settings.settings.api_base}/trades")
async def get_trades() -> dict[str, str]:
return {"trades": []}

View File

@@ -17,7 +17,7 @@ anyio==4.10.0 \
argon2-cffi==25.1.0 \ argon2-cffi==25.1.0 \
--hash=sha256:694ae5cc8a42f4c4e2bf2ca0e64e51e23a040c6a517a85074683d3959e1346c1 \ --hash=sha256:694ae5cc8a42f4c4e2bf2ca0e64e51e23a040c6a517a85074683d3959e1346c1 \
--hash=sha256:fdc8b074db390fccb6eb4a3604ae7231f219aa669a2652e0f20e16ba513d5741 --hash=sha256:fdc8b074db390fccb6eb4a3604ae7231f219aa669a2652e0f20e16ba513d5741
# via passlib # via -r requirements.in
argon2-cffi-bindings==25.1.0 \ argon2-cffi-bindings==25.1.0 \
--hash=sha256:1db89609c06afa1a214a69a462ea741cf735b29a57530478c06eb81dd403de99 \ --hash=sha256:1db89609c06afa1a214a69a462ea741cf735b29a57530478c06eb81dd403de99 \
--hash=sha256:1e021e87faa76ae0d413b619fe2b65ab9a037f24c60a1e6cc43457ae20de6dc6 \ --hash=sha256:1e021e87faa76ae0d413b619fe2b65ab9a037f24c60a1e6cc43457ae20de6dc6 \
@@ -230,10 +230,6 @@ packaging==25.0 \
--hash=sha256:29572ef2b1f17581046b3a2227d5c611fb25ec70ca1ba8554b24b0e69331a484 \ --hash=sha256:29572ef2b1f17581046b3a2227d5c611fb25ec70ca1ba8554b24b0e69331a484 \
--hash=sha256:d443872c98d677bf60f6a1f2f8c1cb748e8fe762d2bf9d3148b5599295b0fc4f --hash=sha256:d443872c98d677bf60f6a1f2f8c1cb748e8fe762d2bf9d3148b5599295b0fc4f
# via pytest # via pytest
passlib[argon2]==1.7.4 \
--hash=sha256:aa6bca462b8d8bda89c70b382f0c298a20b5560af6cbfa2dce410c0a2fb669f1 \
--hash=sha256:defd50f72b65c5402ab2c573830a6978e5f202ad0d984793c8dde2c4152ebe04
# via -r requirements.in
pluggy==1.6.0 \ pluggy==1.6.0 \
--hash=sha256:7dcc130b76258d33b90f61b658791dede3486c3e6bfb003ee5c9bfb396dd22f3 \ --hash=sha256:7dcc130b76258d33b90f61b658791dede3486c3e6bfb003ee5c9bfb396dd22f3 \
--hash=sha256:e920276dd6813095e9377c0bc5566d94c932c33b27a3e3945d8389c374dd4746 --hash=sha256:e920276dd6813095e9377c0bc5566d94c932c33b27a3e3945d8389c374dd4746

View File

@@ -4,4 +4,4 @@ httpx
pyyaml pyyaml
pydantic-settings pydantic-settings
sqlmodel sqlmodel
passlib[argon2] argon2-cffi

View File

@@ -17,7 +17,7 @@ anyio==4.10.0 \
argon2-cffi==25.1.0 \ argon2-cffi==25.1.0 \
--hash=sha256:694ae5cc8a42f4c4e2bf2ca0e64e51e23a040c6a517a85074683d3959e1346c1 \ --hash=sha256:694ae5cc8a42f4c4e2bf2ca0e64e51e23a040c6a517a85074683d3959e1346c1 \
--hash=sha256:fdc8b074db390fccb6eb4a3604ae7231f219aa669a2652e0f20e16ba513d5741 --hash=sha256:fdc8b074db390fccb6eb4a3604ae7231f219aa669a2652e0f20e16ba513d5741
# via passlib # via -r requirements.in
argon2-cffi-bindings==25.1.0 \ argon2-cffi-bindings==25.1.0 \
--hash=sha256:1db89609c06afa1a214a69a462ea741cf735b29a57530478c06eb81dd403de99 \ --hash=sha256:1db89609c06afa1a214a69a462ea741cf735b29a57530478c06eb81dd403de99 \
--hash=sha256:1e021e87faa76ae0d413b619fe2b65ab9a037f24c60a1e6cc43457ae20de6dc6 \ --hash=sha256:1e021e87faa76ae0d413b619fe2b65ab9a037f24c60a1e6cc43457ae20de6dc6 \
@@ -222,10 +222,6 @@ idna==3.10 \
# via # via
# anyio # anyio
# httpx # httpx
passlib[argon2]==1.7.4 \
--hash=sha256:aa6bca462b8d8bda89c70b382f0c298a20b5560af6cbfa2dce410c0a2fb669f1 \
--hash=sha256:defd50f72b65c5402ab2c573830a6978e5f202ad0d984793c8dde2c4152ebe04
# via -r requirements.in
pycparser==2.23 \ pycparser==2.23 \
--hash=sha256:78816d4f24add8f10a06d6f05b4d424ad9e96cfebf68a4ddc99c65c0720d00c2 \ --hash=sha256:78816d4f24add8f10a06d6f05b4d424ad9e96cfebf68a4ddc99c65c0720d00c2 \
--hash=sha256:e5c6e8d3fbad53479cab09ac03729e0a9faf2bee3db8208a550daf5af81a5934 --hash=sha256:e5c6e8d3fbad53479cab09ac03729e0a9faf2bee3db8208a550daf5af81a5934

View File

@@ -1,3 +1,5 @@
from __future__ import annotations
import os import os
from pathlib import Path from pathlib import Path
from typing import Any from typing import Any
@@ -13,6 +15,8 @@ class Settings(BaseSettings):
workers: int = 1 workers: int = 1
log_level: str = "info" log_level: str = "info"
database_url: str = "sqlite:///:memory:" database_url: str = "sqlite:///:memory:"
api_base: str = "/api/v1"
hmac_key: str | None = None
model_config = ConfigDict(env_file=".env", env_file_encoding="utf-8") model_config = ConfigDict(env_file=".env", env_file_encoding="utf-8")

View File

@@ -3,7 +3,8 @@ from collections.abc import Generator
import pytest import pytest
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
from app import API_BASE, app import settings
from app import app
@pytest.fixture @pytest.fixture
@@ -13,6 +14,6 @@ def client() -> Generator[TestClient, None, None]:
def test_get_status(client: TestClient) -> None: def test_get_status(client: TestClient) -> None:
response = client.get(f"{API_BASE}/status") response = client.get(f"{settings.settings.api_base}/status")
assert response.status_code == 200 assert response.status_code == 200
assert response.json() == {"status": "ok"} assert response.json() == {"status": "ok"}

View File

@@ -561,6 +561,79 @@ def test_update_cycle_immutable_fields(session: Session) -> None:
) )
# Exchanges
def test_create_exchange(session: Session) -> None:
exchange_data = {
"name": "NYSE",
"notes": "New York Stock Exchange",
}
exchange = crud.create_exchange(session, exchange_data)
assert exchange.id is not None
assert exchange.name == exchange_data["name"]
assert exchange.notes == exchange_data["notes"]
session.refresh(exchange)
actual_exchange = session.get(models.Exchanges, exchange.id)
assert actual_exchange is not None
assert actual_exchange.name == exchange_data["name"]
assert actual_exchange.notes == exchange_data["notes"]
def test_get_exchange_by_id(session: Session) -> None:
exchange_id = make_exchange(session, name="LSE")
exchange = crud.get_exchange_by_id(session, exchange_id)
assert exchange is not None
assert exchange.id == exchange_id
assert exchange.name == "LSE"
def test_get_exchange_by_name(session: Session) -> None:
exchange_name = "TSX"
make_exchange(session, name=exchange_name)
exchange = crud.get_exchange_by_name(session, exchange_name)
assert exchange is not None
assert exchange.name == exchange_name
def test_get_all_exchanges(session: Session) -> None:
exchange_names = ["NYSE", "NASDAQ", "LSE"]
for name in exchange_names:
make_exchange(session, name=name)
exchanges = crud.get_all_exchanges(session)
assert len(exchanges) >= 3
fetched_names = {ex.name for ex in exchanges}
for name in exchange_names:
assert name in fetched_names
def test_update_exchange(session: Session) -> None:
exchange_id = make_exchange(session, name="Initial Exchange")
update_data = {
"name": "Updated Exchange",
"notes": "Updated notes for the exchange",
}
updated_exchange = crud.update_exchange(session, exchange_id, update_data)
assert updated_exchange is not None
assert updated_exchange.id == exchange_id
assert updated_exchange.name == update_data["name"]
assert updated_exchange.notes == update_data["notes"]
session.refresh(updated_exchange)
actual_exchange = session.get(models.Exchanges, exchange_id)
assert actual_exchange is not None
assert actual_exchange.name == update_data["name"]
assert actual_exchange.notes == update_data["notes"]
def test_delete_exchange(session: Session) -> None:
exchange_id = make_exchange(session, name="Deletable Exchange")
crud.delete_exchange(session, exchange_id)
deleted_exchange = session.get(models.Exchanges, exchange_id)
assert deleted_exchange is None
# Users
def test_create_user(session: Session) -> None: def test_create_user(session: Session) -> None:
user_data = { user_data = {
"username": "newuser", "username": "newuser",
@@ -578,6 +651,22 @@ def test_create_user(session: Session) -> None:
assert actual_user.password_hash == user_data["password_hash"] assert actual_user.password_hash == user_data["password_hash"]
def test_get_user_by_id(session: Session) -> None:
user_id = make_user(session, username="fetchuser")
user = crud.get_user_by_id(session, user_id)
assert user is not None
assert user.id == user_id
assert user.username == "fetchuser"
def test_get_user_by_username(session: Session) -> None:
username = "uniqueuser"
make_user(session, username=username)
user = crud.get_user_by_username(session, username)
assert user is not None
assert user.username == username
def test_update_user(session: Session) -> None: def test_update_user(session: Session) -> None:
user_id = make_user(session, username="updatableuser") user_id = make_user(session, username="updatableuser")
@@ -643,6 +732,16 @@ def test_get_login_session_by_token_and_user_id(session: Session) -> None:
assert fetched_session.session_token_hash == created_session.session_token_hash assert fetched_session.session_token_hash == created_session.session_token_hash
def test_get_login_session_by_token(session: Session) -> None:
now = datetime.now(timezone.utc)
created_session = make_login_session(session, now)
fetched_session = crud.get_login_session_by_token_hash(session, created_session.session_token_hash)
assert fetched_session is not None
assert fetched_session.id == created_session.id
assert fetched_session.user_id == created_session.user_id
assert fetched_session.session_token_hash == created_session.session_token_hash
def test_update_login_session(session: Session) -> None: def test_update_login_session(session: Session) -> None:
now = datetime.now(timezone.utc) now = datetime.now(timezone.utc)
created_session = make_login_session(session, now) created_session = make_login_session(session, now)

View File

@@ -1,4 +1,24 @@
from trading_journal import security from trading_journal import security
def test_hash_password() -> None:
def test_hash_and_verify_password() -> None:
plain = "password" plain = "password"
hashed = security.hash_password(plain)
assert hashed != plain
assert security.verify_password(plain, hashed)
def test_generate_session_token() -> None:
token1 = security.generate_session_token()
token2 = security.generate_session_token()
assert token1 != token2
assert len(token1) > 0
assert len(token2) > 0
def test_hash_and_verify_session_token_sha256() -> None:
token = security.generate_session_token()
token_hash = security.hash_session_token_sha256(token)
assert token_hash != token
assert security.verify_token_sha256(token, token_hash)
assert not security.verify_token_sha256(token + "x", token_hash)

View File

@@ -245,6 +245,83 @@ def update_cycle(session: Session, cycle_id: int, update_data: Mapping) -> model
return cycle return cycle
# Exchanges
IMMUTABLE_EXCHANGE_FIELDS = {"id"}
def create_exchange(session: Session, exchange_data: Mapping) -> models.Exchanges:
if hasattr(exchange_data, "dict"):
data = exchange_data.dict(exclude_unset=True)
else:
data = dict(exchange_data)
allowed = {c.name for c in models.Exchanges.__table__.columns}
payload = {k: v for k, v in data.items() if k in allowed}
if "name" not in payload:
raise ValueError("name is required")
e = models.Exchanges(**payload)
session.add(e)
try:
session.flush()
except IntegrityError as e:
session.rollback()
raise ValueError("create_exchange integrity error") from e
session.refresh(e)
return e
def get_exchange_by_id(session: Session, exchange_id: int) -> models.Exchanges | None:
return session.get(models.Exchanges, exchange_id)
def get_exchange_by_name(session: Session, name: str) -> models.Exchanges | None:
statement = select(models.Exchanges).where(
models.Exchanges.name == name,
)
return session.exec(statement).first()
def get_all_exchanges(session: Session) -> list[models.Exchanges]:
statement = select(models.Exchanges)
return session.exec(statement).all()
def update_exchange(session: Session, exchange_id: int, update_data: Mapping) -> models.Exchanges:
exchange: models.Exchanges | None = session.get(models.Exchanges, exchange_id)
if exchange is None:
raise ValueError("exchange_id does not exist")
if hasattr(update_data, "dict"):
data = update_data.dict(exclude_unset=True)
else:
data = dict(update_data)
allowed = {c.name for c in models.Exchanges.__table__.columns}
for k, v in data.items():
if k in IMMUTABLE_EXCHANGE_FIELDS:
raise ValueError(f"field {k!r} is immutable")
if k in allowed:
setattr(exchange, k, v)
session.add(exchange)
try:
session.flush()
except IntegrityError as e:
session.rollback()
raise ValueError("update_exchange integrity error") from e
session.refresh(exchange)
return exchange
def delete_exchange(session: Session, exchange_id: int) -> None:
exchange: models.Exchanges | None = session.get(models.Exchanges, exchange_id)
if exchange is None:
return
session.delete(exchange)
try:
session.flush()
except IntegrityError as e:
session.rollback()
raise ValueError("delete_exchange integrity error") from e
# Users # Users
IMMUTABLE_USER_FIELDS = {"id", "username", "created_at"} IMMUTABLE_USER_FIELDS = {"id", "username", "created_at"}
@@ -272,6 +349,17 @@ def create_user(session: Session, user_data: Mapping) -> models.Users:
return u return u
def get_user_by_id(session: Session, user_id: int) -> models.Users | None:
return session.get(models.Users, user_id)
def get_user_by_username(session: Session, username: str) -> models.Users | None:
statement = select(models.Users).where(
models.Users.username == username,
)
return session.exec(statement).first()
def update_user(session: Session, user_id: int, update_data: Mapping) -> models.Users: def update_user(session: Session, user_id: int, update_data: Mapping) -> models.Users:
user: models.Users | None = session.get(models.Users, user_id) user: models.Users | None = session.get(models.Users, user_id)
if user is None: if user is None:
@@ -341,6 +429,15 @@ def get_login_session_by_token_hash_and_user_id(session: Session, session_token_
return session.exec(statement).first() return session.exec(statement).first()
def get_login_session_by_token_hash(session: Session, session_token_hash: str) -> models.Sessions | None:
statement = select(models.Sessions).where(
models.Sessions.session_token_hash == session_token_hash,
models.Sessions.expires_at > datetime.now(timezone.utc),
)
return session.exec(statement).first()
IMMUTABLE_SESSION_FIELDS = {"id", "user_id", "session_token_hash", "created_at"} IMMUTABLE_SESSION_FIELDS = {"id", "user_id", "session_token_hash", "created_at"}

View File

@@ -1,6 +1,7 @@
from __future__ import annotations from __future__ import annotations
import logging import logging
from contextlib import contextmanager
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from sqlalchemy import event from sqlalchemy import event
@@ -72,6 +73,18 @@ class Database:
finally: finally:
session.close() session.close()
@contextmanager
def get_session_ctx_manager(self) -> Session:
session = Session(self._engine)
try:
yield session
session.commit()
except Exception:
session.rollback()
raise
finally:
session.close()
def dispose(self) -> None: def dispose(self) -> None:
self._engine.dispose() self._engine.dispose()

View File

@@ -1,11 +1,51 @@
from passlib.context import CryptContext import hashlib
import hmac
import secrets
pwd_ctx = CryptContext(schemes=["argon2"], deprecated="auto") from argon2 import PasswordHasher
from argon2.exceptions import VerifyMismatchError
import settings
ph = PasswordHasher()
# Utility functions for password hashing and verification
def hash_password(plain: str) -> str: def hash_password(plain: str) -> str:
return pwd_ctx.hash(plain) return ph.hash(plain)
def verify_password(plain: str, hashed: str) -> bool: def verify_password(plain: str, hashed: str) -> bool:
return pwd_ctx.verify(plain, hashed) try:
return ph.verify(hashed, plain)
except VerifyMismatchError:
return False
# Session token hash
def generate_session_token(nbytes: int = 32) -> str:
return secrets.token_urlsafe(nbytes)
def hash_session_token_sha256(token: str) -> str:
return hashlib.sha256(token.encode("utf-8")).hexdigest()
def sign_token_hmac(token: str) -> str:
if not settings.settings.hmac_key:
return token
return hmac.new(settings.settings.hmac_key.encode("utf-8"), token.encode("utf-8"), hashlib.sha256).hexdigest()
def verify_token_sha256(token: str, expected_hash: str) -> bool:
return hmac.compare_digest(hash_session_token_sha256(token), expected_hash)
def verify_token_hmac(token: str, expected_hmac: str) -> bool:
if not settings.settings.hmac_key:
return verify_token_sha256(token, expected_hmac)
sig = hmac.new(settings.settings.hmac_key.encode("utf-8"), token.encode("utf-8"), hashlib.sha256).hexdigest()
return hmac.compare_digest(sig, expected_hmac)

View File

@@ -0,0 +1,72 @@
from typing import Callable
from fastapi import Request, Response, status
from fastapi.responses import JSONResponse
from sqlmodel import Session
from starlette.middleware.base import BaseHTTPMiddleware
import settings
from trading_journal import crud, security
from trading_journal.db import Database
from trading_journal.dto import UserCreate, UserRead
from trading_journal.models import Sessions
EXCEPT_PATHS = [
f"{settings.settings.api_base}/status",
f"{settings.settings.api_base}/register",
]
class AuthMiddleWare(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next: Callable[[Request], Response]) -> Response:
if request.url.path in EXCEPT_PATHS:
return await call_next(request)
token = request.cookies.get("session_token")
if not token:
auth_header = request.headers.get("Authorization")
if auth_header and auth_header.startswith("Bearer "):
token = auth_header[len("Bearer ") :]
if not token:
return JSONResponse(
status_code=status.HTTP_401_UNAUTHORIZED,
content={"detail": "Unauthorized"},
)
db_factory: Database | None = getattr(request.app.state, "db_factory", None)
if db_factory is None:
return JSONResponse(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, content={"detail": "db factory not configured"})
try:
with db_factory.get_session_ctx_manager() as request_session:
hashed_token = security.hash_session_token_sha256(token)
request.state.db_session = request_session
login_session: Sessions | None = crud.get_login_session_by_token_hash(request.state.db_session, hashed_token)
except Exception: # noqa: BLE001
return JSONResponse(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, content={"detail": "db error"})
return None
class ServiceError(Exception):
pass
class UserAlreadyExistsError(ServiceError):
pass
def register_user_service(db_session: Session, user_in: UserCreate) -> UserRead:
if crud.get_user_by_username(db_session, user_in.username):
raise UserAlreadyExistsError("username already exists")
hashed = security.hash_password(user_in.password)
try:
user = crud.create_user(db_session, username=user_in.username, hashed_password=hashed)
try:
# prefer pydantic's from_orm if DTO supports orm_mode
user = UserRead.model_validate(user)
except Exception as e:
raise ServiceError("Failed to convert user to UserRead") from e
except Exception as e:
raise ServiceError("Failed to create user") from e
return user