refine type checking
All checks were successful
Backend CI / unit-test (push) Successful in 35s

This commit is contained in:
2025-09-23 17:37:14 +02:00
parent b68249f9f1
commit 92c4e0d4fc
8 changed files with 132 additions and 103 deletions

View File

@@ -1,7 +1,7 @@
from __future__ import annotations
from datetime import datetime, timedelta, timezone
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, cast
import pytest
from sqlalchemy import create_engine
@@ -45,7 +45,7 @@ def make_user(session: Session, username: str = "testuser") -> int:
session.add(user)
session.commit()
session.refresh(user)
return user.id
return cast("int", user.id)
def make_exchange(session: Session, user_id: int, name: str = "NASDAQ") -> int:
@@ -53,7 +53,7 @@ def make_exchange(session: Session, user_id: int, name: str = "NASDAQ") -> int:
session.add(exchange)
session.commit()
session.refresh(exchange)
return exchange.id
return cast("int", exchange.id)
def make_cycle(session: Session, user_id: int, exchange_id: int, friendly_name: str = "Test Cycle") -> int:
@@ -65,15 +65,16 @@ def make_cycle(session: Session, user_id: int, exchange_id: int, friendly_name:
underlying_currency=models.UnderlyingCurrency.USD,
status=models.CycleStatus.OPEN,
start_date=datetime.now(timezone.utc).date(),
)
) # type: ignore[arg-type]
session.add(cycle)
session.commit()
session.refresh(cycle)
return cycle.id
return cast("int", cycle.id)
def make_trade(session: Session, user_id: int, cycle_id: int, friendly_name: str = "Test Trade") -> int:
cycle: models.Cycles = session.get(models.Cycles, cycle_id)
cycle: models.Cycles | None = session.get(models.Cycles, cycle_id)
assert cycle is not None
exchange_id = cycle.exchange_id
trade = models.Trades(
user_id=user_id,
@@ -96,7 +97,7 @@ def make_trade(session: Session, user_id: int, cycle_id: int, friendly_name: str
session.add(trade)
session.commit()
session.refresh(trade)
return trade.id
return cast("int", trade.id)
def make_trade_by_trade_data(session: Session, trade_data: dict) -> int:
@@ -104,7 +105,7 @@ def make_trade_by_trade_data(session: Session, trade_data: dict) -> int:
session.add(trade)
session.commit()
session.refresh(trade)
return trade.id
return cast("int", trade.id)
def make_login_session(session: Session, created_at: datetime) -> models.Sessions:
@@ -128,7 +129,7 @@ def make_login_session(session: Session, created_at: datetime) -> models.Session
return login_session
def _ensure_utc_aware(dt: datetime) -> datetime | None:
def _ensure_utc_aware(dt: datetime | None) -> datetime | None:
if dt is None:
return None
if dt.tzinfo is None:
@@ -219,7 +220,7 @@ def test_create_trade_with_auto_created_cycle(session: Session) -> None:
assert auto_cycle.symbol == trade_data["symbol"]
assert auto_cycle.underlying_currency == trade_data["underlying_currency"]
assert auto_cycle.status == models.CycleStatus.OPEN
assert auto_cycle.friendly_name.startswith("Auto-created Cycle by trade")
assert auto_cycle.friendly_name.startswith("Auto-created Cycle by trade") # type: ignore[union-attr]
def test_create_trade_missing_required_fields(session: Session) -> None: