This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, cast
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import create_engine
|
||||
@@ -45,7 +45,7 @@ def make_user(session: Session, username: str = "testuser") -> int:
|
||||
session.add(user)
|
||||
session.commit()
|
||||
session.refresh(user)
|
||||
return user.id
|
||||
return cast("int", user.id)
|
||||
|
||||
|
||||
def make_exchange(session: Session, user_id: int, name: str = "NASDAQ") -> int:
|
||||
@@ -53,7 +53,7 @@ def make_exchange(session: Session, user_id: int, name: str = "NASDAQ") -> int:
|
||||
session.add(exchange)
|
||||
session.commit()
|
||||
session.refresh(exchange)
|
||||
return exchange.id
|
||||
return cast("int", exchange.id)
|
||||
|
||||
|
||||
def make_cycle(session: Session, user_id: int, exchange_id: int, friendly_name: str = "Test Cycle") -> int:
|
||||
@@ -65,15 +65,16 @@ def make_cycle(session: Session, user_id: int, exchange_id: int, friendly_name:
|
||||
underlying_currency=models.UnderlyingCurrency.USD,
|
||||
status=models.CycleStatus.OPEN,
|
||||
start_date=datetime.now(timezone.utc).date(),
|
||||
)
|
||||
) # type: ignore[arg-type]
|
||||
session.add(cycle)
|
||||
session.commit()
|
||||
session.refresh(cycle)
|
||||
return cycle.id
|
||||
return cast("int", cycle.id)
|
||||
|
||||
|
||||
def make_trade(session: Session, user_id: int, cycle_id: int, friendly_name: str = "Test Trade") -> int:
|
||||
cycle: models.Cycles = session.get(models.Cycles, cycle_id)
|
||||
cycle: models.Cycles | None = session.get(models.Cycles, cycle_id)
|
||||
assert cycle is not None
|
||||
exchange_id = cycle.exchange_id
|
||||
trade = models.Trades(
|
||||
user_id=user_id,
|
||||
@@ -96,7 +97,7 @@ def make_trade(session: Session, user_id: int, cycle_id: int, friendly_name: str
|
||||
session.add(trade)
|
||||
session.commit()
|
||||
session.refresh(trade)
|
||||
return trade.id
|
||||
return cast("int", trade.id)
|
||||
|
||||
|
||||
def make_trade_by_trade_data(session: Session, trade_data: dict) -> int:
|
||||
@@ -104,7 +105,7 @@ def make_trade_by_trade_data(session: Session, trade_data: dict) -> int:
|
||||
session.add(trade)
|
||||
session.commit()
|
||||
session.refresh(trade)
|
||||
return trade.id
|
||||
return cast("int", trade.id)
|
||||
|
||||
|
||||
def make_login_session(session: Session, created_at: datetime) -> models.Sessions:
|
||||
@@ -128,7 +129,7 @@ def make_login_session(session: Session, created_at: datetime) -> models.Session
|
||||
return login_session
|
||||
|
||||
|
||||
def _ensure_utc_aware(dt: datetime) -> datetime | None:
|
||||
def _ensure_utc_aware(dt: datetime | None) -> datetime | None:
|
||||
if dt is None:
|
||||
return None
|
||||
if dt.tzinfo is None:
|
||||
@@ -219,7 +220,7 @@ def test_create_trade_with_auto_created_cycle(session: Session) -> None:
|
||||
assert auto_cycle.symbol == trade_data["symbol"]
|
||||
assert auto_cycle.underlying_currency == trade_data["underlying_currency"]
|
||||
assert auto_cycle.status == models.CycleStatus.OPEN
|
||||
assert auto_cycle.friendly_name.startswith("Auto-created Cycle by trade")
|
||||
assert auto_cycle.friendly_name.startswith("Auto-created Cycle by trade") # type: ignore[union-attr]
|
||||
|
||||
|
||||
def test_create_trade_missing_required_fields(session: Session) -> None:
|
||||
|
||||
Reference in New Issue
Block a user