add crud for exchange
This commit is contained in:
@@ -561,6 +561,79 @@ def test_update_cycle_immutable_fields(session: Session) -> None:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Exchanges
|
||||||
|
def test_create_exchange(session: Session) -> None:
|
||||||
|
exchange_data = {
|
||||||
|
"name": "NYSE",
|
||||||
|
"notes": "New York Stock Exchange",
|
||||||
|
}
|
||||||
|
exchange = crud.create_exchange(session, exchange_data)
|
||||||
|
assert exchange.id is not None
|
||||||
|
assert exchange.name == exchange_data["name"]
|
||||||
|
assert exchange.notes == exchange_data["notes"]
|
||||||
|
|
||||||
|
session.refresh(exchange)
|
||||||
|
actual_exchange = session.get(models.Exchanges, exchange.id)
|
||||||
|
assert actual_exchange is not None
|
||||||
|
assert actual_exchange.name == exchange_data["name"]
|
||||||
|
assert actual_exchange.notes == exchange_data["notes"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_exchange_by_id(session: Session) -> None:
|
||||||
|
exchange_id = make_exchange(session, name="LSE")
|
||||||
|
exchange = crud.get_exchange_by_id(session, exchange_id)
|
||||||
|
assert exchange is not None
|
||||||
|
assert exchange.id == exchange_id
|
||||||
|
assert exchange.name == "LSE"
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_exchange_by_name(session: Session) -> None:
|
||||||
|
exchange_name = "TSX"
|
||||||
|
make_exchange(session, name=exchange_name)
|
||||||
|
exchange = crud.get_exchange_by_name(session, exchange_name)
|
||||||
|
assert exchange is not None
|
||||||
|
assert exchange.name == exchange_name
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_all_exchanges(session: Session) -> None:
|
||||||
|
exchange_names = ["NYSE", "NASDAQ", "LSE"]
|
||||||
|
for name in exchange_names:
|
||||||
|
make_exchange(session, name=name)
|
||||||
|
|
||||||
|
exchanges = crud.get_all_exchanges(session)
|
||||||
|
assert len(exchanges) >= 3
|
||||||
|
fetched_names = {ex.name for ex in exchanges}
|
||||||
|
for name in exchange_names:
|
||||||
|
assert name in fetched_names
|
||||||
|
|
||||||
|
|
||||||
|
def test_update_exchange(session: Session) -> None:
|
||||||
|
exchange_id = make_exchange(session, name="Initial Exchange")
|
||||||
|
update_data = {
|
||||||
|
"name": "Updated Exchange",
|
||||||
|
"notes": "Updated notes for the exchange",
|
||||||
|
}
|
||||||
|
updated_exchange = crud.update_exchange(session, exchange_id, update_data)
|
||||||
|
assert updated_exchange is not None
|
||||||
|
assert updated_exchange.id == exchange_id
|
||||||
|
assert updated_exchange.name == update_data["name"]
|
||||||
|
assert updated_exchange.notes == update_data["notes"]
|
||||||
|
|
||||||
|
session.refresh(updated_exchange)
|
||||||
|
actual_exchange = session.get(models.Exchanges, exchange_id)
|
||||||
|
assert actual_exchange is not None
|
||||||
|
assert actual_exchange.name == update_data["name"]
|
||||||
|
assert actual_exchange.notes == update_data["notes"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_delete_exchange(session: Session) -> None:
|
||||||
|
exchange_id = make_exchange(session, name="Deletable Exchange")
|
||||||
|
crud.delete_exchange(session, exchange_id)
|
||||||
|
deleted_exchange = session.get(models.Exchanges, exchange_id)
|
||||||
|
assert deleted_exchange is None
|
||||||
|
|
||||||
|
|
||||||
|
# Users
|
||||||
def test_create_user(session: Session) -> None:
|
def test_create_user(session: Session) -> None:
|
||||||
user_data = {
|
user_data = {
|
||||||
"username": "newuser",
|
"username": "newuser",
|
||||||
|
|||||||
@@ -245,6 +245,83 @@ def update_cycle(session: Session, cycle_id: int, update_data: Mapping) -> model
|
|||||||
return cycle
|
return cycle
|
||||||
|
|
||||||
|
|
||||||
|
# Exchanges
|
||||||
|
IMMUTABLE_EXCHANGE_FIELDS = {"id"}
|
||||||
|
|
||||||
|
|
||||||
|
def create_exchange(session: Session, exchange_data: Mapping) -> models.Exchanges:
|
||||||
|
if hasattr(exchange_data, "dict"):
|
||||||
|
data = exchange_data.dict(exclude_unset=True)
|
||||||
|
else:
|
||||||
|
data = dict(exchange_data)
|
||||||
|
allowed = {c.name for c in models.Exchanges.__table__.columns}
|
||||||
|
payload = {k: v for k, v in data.items() if k in allowed}
|
||||||
|
if "name" not in payload:
|
||||||
|
raise ValueError("name is required")
|
||||||
|
|
||||||
|
e = models.Exchanges(**payload)
|
||||||
|
session.add(e)
|
||||||
|
try:
|
||||||
|
session.flush()
|
||||||
|
except IntegrityError as e:
|
||||||
|
session.rollback()
|
||||||
|
raise ValueError("create_exchange integrity error") from e
|
||||||
|
session.refresh(e)
|
||||||
|
return e
|
||||||
|
|
||||||
|
|
||||||
|
def get_exchange_by_id(session: Session, exchange_id: int) -> models.Exchanges | None:
|
||||||
|
return session.get(models.Exchanges, exchange_id)
|
||||||
|
|
||||||
|
|
||||||
|
def get_exchange_by_name(session: Session, name: str) -> models.Exchanges | None:
|
||||||
|
statement = select(models.Exchanges).where(
|
||||||
|
models.Exchanges.name == name,
|
||||||
|
)
|
||||||
|
return session.exec(statement).first()
|
||||||
|
|
||||||
|
|
||||||
|
def get_all_exchanges(session: Session) -> list[models.Exchanges]:
|
||||||
|
statement = select(models.Exchanges)
|
||||||
|
return session.exec(statement).all()
|
||||||
|
|
||||||
|
|
||||||
|
def update_exchange(session: Session, exchange_id: int, update_data: Mapping) -> models.Exchanges:
|
||||||
|
exchange: models.Exchanges | None = session.get(models.Exchanges, exchange_id)
|
||||||
|
if exchange is None:
|
||||||
|
raise ValueError("exchange_id does not exist")
|
||||||
|
if hasattr(update_data, "dict"):
|
||||||
|
data = update_data.dict(exclude_unset=True)
|
||||||
|
else:
|
||||||
|
data = dict(update_data)
|
||||||
|
allowed = {c.name for c in models.Exchanges.__table__.columns}
|
||||||
|
for k, v in data.items():
|
||||||
|
if k in IMMUTABLE_EXCHANGE_FIELDS:
|
||||||
|
raise ValueError(f"field {k!r} is immutable")
|
||||||
|
if k in allowed:
|
||||||
|
setattr(exchange, k, v)
|
||||||
|
session.add(exchange)
|
||||||
|
try:
|
||||||
|
session.flush()
|
||||||
|
except IntegrityError as e:
|
||||||
|
session.rollback()
|
||||||
|
raise ValueError("update_exchange integrity error") from e
|
||||||
|
session.refresh(exchange)
|
||||||
|
return exchange
|
||||||
|
|
||||||
|
|
||||||
|
def delete_exchange(session: Session, exchange_id: int) -> None:
|
||||||
|
exchange: models.Exchanges | None = session.get(models.Exchanges, exchange_id)
|
||||||
|
if exchange is None:
|
||||||
|
return
|
||||||
|
session.delete(exchange)
|
||||||
|
try:
|
||||||
|
session.flush()
|
||||||
|
except IntegrityError as e:
|
||||||
|
session.rollback()
|
||||||
|
raise ValueError("delete_exchange integrity error") from e
|
||||||
|
|
||||||
|
|
||||||
# Users
|
# Users
|
||||||
IMMUTABLE_USER_FIELDS = {"id", "username", "created_at"}
|
IMMUTABLE_USER_FIELDS = {"id", "username", "created_at"}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user