diff --git a/backend/tests/test_crud.py b/backend/tests/test_crud.py index 5eac231..1ae5a55 100644 --- a/backend/tests/test_crud.py +++ b/backend/tests/test_crud.py @@ -561,6 +561,79 @@ def test_update_cycle_immutable_fields(session: Session) -> None: ) +# Exchanges +def test_create_exchange(session: Session) -> None: + exchange_data = { + "name": "NYSE", + "notes": "New York Stock Exchange", + } + exchange = crud.create_exchange(session, exchange_data) + assert exchange.id is not None + assert exchange.name == exchange_data["name"] + assert exchange.notes == exchange_data["notes"] + + session.refresh(exchange) + actual_exchange = session.get(models.Exchanges, exchange.id) + assert actual_exchange is not None + assert actual_exchange.name == exchange_data["name"] + assert actual_exchange.notes == exchange_data["notes"] + + +def test_get_exchange_by_id(session: Session) -> None: + exchange_id = make_exchange(session, name="LSE") + exchange = crud.get_exchange_by_id(session, exchange_id) + assert exchange is not None + assert exchange.id == exchange_id + assert exchange.name == "LSE" + + +def test_get_exchange_by_name(session: Session) -> None: + exchange_name = "TSX" + make_exchange(session, name=exchange_name) + exchange = crud.get_exchange_by_name(session, exchange_name) + assert exchange is not None + assert exchange.name == exchange_name + + +def test_get_all_exchanges(session: Session) -> None: + exchange_names = ["NYSE", "NASDAQ", "LSE"] + for name in exchange_names: + make_exchange(session, name=name) + + exchanges = crud.get_all_exchanges(session) + assert len(exchanges) >= 3 + fetched_names = {ex.name for ex in exchanges} + for name in exchange_names: + assert name in fetched_names + + +def test_update_exchange(session: Session) -> None: + exchange_id = make_exchange(session, name="Initial Exchange") + update_data = { + "name": "Updated Exchange", + "notes": "Updated notes for the exchange", + } + updated_exchange = crud.update_exchange(session, exchange_id, update_data) + assert updated_exchange is not None + assert updated_exchange.id == exchange_id + assert updated_exchange.name == update_data["name"] + assert updated_exchange.notes == update_data["notes"] + + session.refresh(updated_exchange) + actual_exchange = session.get(models.Exchanges, exchange_id) + assert actual_exchange is not None + assert actual_exchange.name == update_data["name"] + assert actual_exchange.notes == update_data["notes"] + + +def test_delete_exchange(session: Session) -> None: + exchange_id = make_exchange(session, name="Deletable Exchange") + crud.delete_exchange(session, exchange_id) + deleted_exchange = session.get(models.Exchanges, exchange_id) + assert deleted_exchange is None + + +# Users def test_create_user(session: Session) -> None: user_data = { "username": "newuser", diff --git a/backend/trading_journal/crud.py b/backend/trading_journal/crud.py index 5ce6fb4..9e998bd 100644 --- a/backend/trading_journal/crud.py +++ b/backend/trading_journal/crud.py @@ -245,6 +245,83 @@ def update_cycle(session: Session, cycle_id: int, update_data: Mapping) -> model return cycle +# Exchanges +IMMUTABLE_EXCHANGE_FIELDS = {"id"} + + +def create_exchange(session: Session, exchange_data: Mapping) -> models.Exchanges: + if hasattr(exchange_data, "dict"): + data = exchange_data.dict(exclude_unset=True) + else: + data = dict(exchange_data) + allowed = {c.name for c in models.Exchanges.__table__.columns} + payload = {k: v for k, v in data.items() if k in allowed} + if "name" not in payload: + raise ValueError("name is required") + + e = models.Exchanges(**payload) + session.add(e) + try: + session.flush() + except IntegrityError as e: + session.rollback() + raise ValueError("create_exchange integrity error") from e + session.refresh(e) + return e + + +def get_exchange_by_id(session: Session, exchange_id: int) -> models.Exchanges | None: + return session.get(models.Exchanges, exchange_id) + + +def get_exchange_by_name(session: Session, name: str) -> models.Exchanges | None: + statement = select(models.Exchanges).where( + models.Exchanges.name == name, + ) + return session.exec(statement).first() + + +def get_all_exchanges(session: Session) -> list[models.Exchanges]: + statement = select(models.Exchanges) + return session.exec(statement).all() + + +def update_exchange(session: Session, exchange_id: int, update_data: Mapping) -> models.Exchanges: + exchange: models.Exchanges | None = session.get(models.Exchanges, exchange_id) + if exchange is None: + raise ValueError("exchange_id does not exist") + if hasattr(update_data, "dict"): + data = update_data.dict(exclude_unset=True) + else: + data = dict(update_data) + allowed = {c.name for c in models.Exchanges.__table__.columns} + for k, v in data.items(): + if k in IMMUTABLE_EXCHANGE_FIELDS: + raise ValueError(f"field {k!r} is immutable") + if k in allowed: + setattr(exchange, k, v) + session.add(exchange) + try: + session.flush() + except IntegrityError as e: + session.rollback() + raise ValueError("update_exchange integrity error") from e + session.refresh(exchange) + return exchange + + +def delete_exchange(session: Session, exchange_id: int) -> None: + exchange: models.Exchanges | None = session.get(models.Exchanges, exchange_id) + if exchange is None: + return + session.delete(exchange) + try: + session.flush() + except IntegrityError as e: + session.rollback() + raise ValueError("delete_exchange integrity error") from e + + # Users IMMUTABLE_USER_FIELDS = {"id", "username", "created_at"}