Files
trading-journal/backend/trading_journal/service.py
Tianyu Liu b68249f9f1
All checks were successful
Backend CI / unit-test (push) Successful in 34s
add create get exchange endpoint
2025-09-22 23:07:28 +02:00

170 lines
7.0 KiB
Python

from __future__ import annotations

import logging
from datetime import datetime, timedelta, timezone
from typing import Awaitable, Callable

from fastapi import Request, Response, status
from fastapi.responses import JSONResponse
from sqlmodel import Session
from starlette.middleware.base import BaseHTTPMiddleware

import settings
from trading_journal import crud, security
from trading_journal.db import Database
from trading_journal.dto import ExchangesBase, ExchangesCreate, SessionsCreate, SessionsUpdate, UserCreate, UserLogin, UserRead
from trading_journal.models import Sessions
# Rebuild the pydantic schema of SessionsCreate at import time, before it is used
# via model_validate below (presumably to resolve forward references — confirm).
SessionsCreate.model_rebuild()

# Request paths AuthMiddleWare lets through without a session token.
# They are compared for exact equality with request.url.path.
EXCEPT_PATHS = [
    f"{settings.settings.api_base}/{endpoint}"
    for endpoint in ("status", "register", "login")
]

logger = logging.getLogger(__name__)
class AuthMiddleWare(BaseHTTPMiddleware):
    """Authenticate every request whose path is not in ``EXCEPT_PATHS``.

    The session token is read from the ``session_token`` cookie, falling back to an
    ``Authorization: Bearer <token>`` header. Its SHA-256 hash is looked up with
    ``crud.get_login_session_by_token_hash``.

    A 401 JSON response is returned when:
    - there is no token;
    - the token is unknown;
    - the token is expired (the expired session is deleted);
    - the owning user has ``is_active`` set to ``False``.

    Otherwise the session's ``last_seen_at``, ``last_used_ip`` and ``user_agent`` are
    updated. Its ``expires_at`` is pushed out to a full
    ``settings.settings.session_expiry_seconds`` lifetime when less than
    ``_REFRESH_THRESHOLD`` remains. ``request.state.db_session`` and
    ``request.state.user_id`` are set, and the request is forwarded.
    """

    # Sessions this close to expiry get a fresh full-length lifetime on use.
    _REFRESH_THRESHOLD = timedelta(seconds=3600)

    @staticmethod
    def _unauthorized() -> JSONResponse:
        """Build the uniform 401 response (no detail about *why* auth failed)."""
        return JSONResponse(status_code=status.HTTP_401_UNAUTHORIZED, content={"detail": "Unauthorized"})

    @staticmethod
    def _extract_token(request: Request) -> str | None:
        """Return the raw session token from the cookie or a Bearer header, or None."""
        token = request.cookies.get("session_token")
        if token:
            return token
        auth_header = request.headers.get("Authorization")
        if not auth_header:
            return None
        scheme, _, credentials = auth_header.partition(" ")
        # The auth-scheme token is case-insensitive (RFC 7235, section 2.1).
        if scheme.lower() != "bearer":
            return None
        return credentials.strip() or None

    @staticmethod
    def _as_utc(value: datetime) -> datetime:
        """Return ``value`` as an aware UTC datetime.

        Naive values are treated as already being in UTC. Aware values are converted
        rather than relabelled, so a non-UTC offset is not silently discarded.
        """
        if value.tzinfo is None:
            return value.replace(tzinfo=timezone.utc)
        return value.astimezone(timezone.utc)

    async def dispatch(self, request: Request, call_next: Callable[[Request], Awaitable[Response]]) -> Response:  # noqa: PLR0911
        if request.url.path in EXCEPT_PATHS:
            return await call_next(request)

        token = self._extract_token(request)
        if not token:
            return self._unauthorized()

        db_factory: Database | None = getattr(request.app.state, "db_factory", None)
        if db_factory is None:
            return JSONResponse(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, content={"detail": "db factory not configured"})

        try:
            with db_factory.get_session_ctx_manager() as request_session:
                hashed_token = security.hash_session_token_sha256(token)
                # NOTE(review): downstream handlers receive this session after the
                # context manager has exited (call_next runs below, outside the
                # `with`). Confirm get_session_ctx_manager leaves it usable.
                request.state.db_session = request_session
                login_session: Sessions | None = crud.get_login_session_by_token_hash(request_session, hashed_token)
                if not login_session:
                    return self._unauthorized()

                # A single timestamp for every comparison and write in this request.
                now = datetime.now(timezone.utc)
                session_expires_utc = self._as_utc(login_session.expires_at)
                if session_expires_utc < now:
                    crud.delete_login_session(request_session, login_session)
                    return self._unauthorized()
                if login_session.user.is_active is False:
                    return self._unauthorized()

                if session_expires_utc - now < self._REFRESH_THRESHOLD:
                    updated_expiry = now + timedelta(seconds=settings.settings.session_expiry_seconds)
                else:
                    updated_expiry = session_expires_utc
                updated_session = SessionsUpdate(
                    last_seen_at=now,
                    last_used_ip=request.client.host if request.client else None,
                    user_agent=request.headers.get("User-Agent"),
                    expires_at=updated_expiry,
                )
                request.state.user_id = login_session.user_id
                crud.update_login_session(request_session, hashed_token, update_session=updated_session)
        except Exception:
            logger.exception("Failed to authenticate user: \n")
            return JSONResponse(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, content={"detail": "Internal server error"})
        # Kept outside the try so exceptions from downstream handlers are not
        # reported (and swallowed) as authentication failures.
        return await call_next(request)
class ServiceError(Exception):
    """Base class of, and generic failure raised by, the service functions in this module."""


class UserAlreadyExistsError(ServiceError):
    """Raised when registering a username that already belongs to a user."""


class ExchangeAlreadyExistsError(ServiceError):
    """Raised when a user already owns an exchange with the requested name."""
def register_user_service(db_session: Session, user_in: UserCreate) -> UserRead:
    """Create a user with a hashed password and return it as a ``UserRead`` DTO.

    Raises:
        UserAlreadyExistsError: ``crud.get_user_by_username`` already finds the username.
        ServiceError: ``crud.create_user`` failed ("Failed to create user"), or the
            created row could not be validated into ``UserRead``
            ("Failed to convert user to UserRead").
    """
    # NOTE(review): check-then-insert; unless the DB enforces unique usernames,
    # two concurrent registrations could both pass this check.
    if crud.get_user_by_username(db_session, user_in.username):
        raise UserAlreadyExistsError("username already exists")
    hashed = security.hash_password(user_in.password)
    user_data: dict = {
        "username": user_in.username,
        "password_hash": hashed,
    }
    # Two separate try blocks so a conversion failure is logged once and keeps its
    # own error message instead of being re-wrapped as a creation failure.
    try:
        user = crud.create_user(db_session, user_data=user_data)
    except Exception as e:
        logger.exception("Failed to create user:")
        raise ServiceError("Failed to create user") from e
    try:
        return UserRead.model_validate(user)
    except Exception as e:
        logger.exception("Failed to convert user to UserRead: %s", e)
        raise ServiceError("Failed to convert user to UserRead") from e
def authenticate_user_service(db_session: Session, user_in: UserLogin) -> tuple[SessionsCreate, str] | None:
    """Verify the user's credentials and open a new login session.

    Returns:
        ``(session_dto, raw_token)`` on success. Only the SHA-256 hash of
        ``raw_token`` is handed to the CRUD layer; the raw token goes back to the
        caller. Returns ``None`` in either of these cases:
        - the username is unknown or the password does not verify;
        - the user has ``is_active`` set to ``False``.

    Raises:
        ServiceError: the login session could not be created, or could not be
            validated into ``SessionsCreate``.
    """
    user = crud.get_user_by_username(db_session, user_in.username)
    # NOTE(review): unknown usernames return before verify_password runs. If that
    # check is deliberately slow (e.g. a password KDF), response timing can reveal
    # which usernames exist.
    if not user or not security.verify_password(user_in.password, user.password_hash):
        return None
    # AuthMiddleWare rejects every session of an inactive user, so don't issue one.
    # This is checked only after the password verifies, so account status isn't
    # leaked to callers who lack the password.
    if user.is_active is False:
        return None

    token = security.generate_session_token()
    token_hashed = security.hash_session_token_sha256(token)
    try:
        session = crud.create_login_session(
            session=db_session,
            user_id=user.id,
            session_token_hash=token_hashed,
            session_length_seconds=settings.settings.session_expiry_seconds,
        )
    except Exception as e:
        logger.exception("Failed to create login session: \n")
        raise ServiceError("Failed to create login session") from e
    try:
        session_dto = SessionsCreate.model_validate(session)
    except Exception as e:
        logger.exception("Failed to convert login session to SessionsCreate:")
        raise ServiceError("Failed to convert login session to SessionsCreate") from e
    return session_dto, token
# Exchanges service
def create_exchange_service(db_session: Session, user_id: int, name: str, notes: str | None) -> ExchangesCreate:
    """Create an exchange called ``name`` for ``user_id`` and return it as an ``ExchangesCreate`` DTO.

    Raises:
        ExchangeAlreadyExistsError: ``crud.get_exchange_by_name_and_user_id`` already
            finds an exchange with this name for the user.
        ServiceError: ``crud.create_exchange`` failed ("Failed to create exchange"),
            or the stored row could not be validated into ``ExchangesCreate``
            ("Failed to convert exchange to ExchangesCreate").
    """
    # NOTE(review): check-then-insert; unless the DB enforces (user_id, name)
    # uniqueness, two concurrent requests could both pass this check.
    if crud.get_exchange_by_name_and_user_id(db_session, name, user_id):
        raise ExchangeAlreadyExistsError("Exchange with the same name already exists for this user")
    exchange_data = ExchangesCreate(
        user_id=user_id,
        name=name,
        notes=notes,
    )
    # Two separate try blocks so a conversion failure is logged once and keeps its
    # own error message instead of being re-wrapped as a creation failure.
    try:
        exchange = crud.create_exchange(db_session, exchange_data=exchange_data)
    except Exception as e:
        logger.exception("Failed to create exchange:")
        raise ServiceError("Failed to create exchange") from e
    try:
        return ExchangesCreate.model_validate(exchange)
    except Exception as e:
        logger.exception("Failed to convert exchange to ExchangesCreate: ")
        raise ServiceError("Failed to convert exchange to ExchangesCreate") from e
def get_exchanges_by_user_service(db_session: Session, user_id: int) -> list[ExchangesBase]:
    """Return all exchanges of ``user_id``, each validated into an ``ExchangesBase`` DTO."""
    rows = crud.get_all_exchanges_by_user_id(db_session, user_id)
    return list(map(ExchangesBase.model_validate, rows))
def get_trades_service(db_session: Session, user_id: int) -> list:
    """Return the trades of ``user_id`` exactly as ``crud.get_trades_by_user_id`` yields them (no DTO conversion)."""
    user_trades = crud.get_trades_by_user_id(db_session, user_id)
    return user_trades