From eb1f8c0e37f95807c943f0ac66ad280779593be9 Mon Sep 17 00:00:00 2001 From: Tianyu Liu Date: Thu, 18 Sep 2025 14:26:55 +0200 Subject: [PATCH] db ferst version is done. --- backend/tests/test_crud.py | 178 +++++++++++++++++++++++++-- backend/trading_journal/crud.py | 150 +++++++++++++++++++--- backend/trading_journal/models_v1.py | 8 ++ 3 files changed, 312 insertions(+), 24 deletions(-) diff --git a/backend/tests/test_crud.py b/backend/tests/test_crud.py index 325e270..3ee2fce 100644 --- a/backend/tests/test_crud.py +++ b/backend/tests/test_crud.py @@ -74,6 +74,7 @@ def make_trade( commission_cents=500, net_cash_flow_cents=-150500, cycle_id=cycle_id, + notes="Initial test trade", ) session.add(trade) session.commit() @@ -97,9 +98,9 @@ def test_create_trade_success_with_cycle(session: Session): "user_id": user_id, "friendly_name": "Test Trade", "symbol": "AAPL", - "underlying_currency": "USD", - "trade_type": "LONG_SPOT", - "trade_strategy": "SPOT", + "underlying_currency": models.UnderlyingCurrency.USD, + "trade_type": models.TradeType.LONG_SPOT, + "trade_strategy": models.TradeStrategy.SPOT, "trade_time_utc": datetime.now(), "quantity": 10, "price_cents": 15000, @@ -137,9 +138,9 @@ def test_create_trade_with_auto_created_cycle(session: Session): "user_id": user_id, "friendly_name": "Test Trade with Auto Cycle", "symbol": "AAPL", - "underlying_currency": "USD", - "trade_type": "LONG_SPOT", - "trade_strategy": "SPOT", + "underlying_currency": models.UnderlyingCurrency.USD, + "trade_type": models.TradeType.LONG_SPOT, + "trade_strategy": models.TradeStrategy.SPOT, "trade_time_utc": datetime.now(), "quantity": 5, "price_cents": 15500, @@ -179,9 +180,9 @@ def test_create_trade_missing_required_fields(session: Session): "user_id": user_id, "friendly_name": "Incomplete Trade", "symbol": "AAPL", - "underlying_currency": "USD", - "trade_type": "LONG_SPOT", - "trade_strategy": "SPOT", + "underlying_currency": models.UnderlyingCurrency.USD, + "trade_type": models.TradeType.LONG_SPOT, + "trade_strategy": models.TradeStrategy.SPOT, "trade_time_utc": datetime.now(), "quantity": 10, "price_cents": 15000, @@ -338,13 +339,89 @@ def test_get_trades_by_user_id(session: Session): assert friendly_names == {"Trade One", "Trade Two"} +def test_update_trade_note(session: Session): + user_id = make_user(session) + cycle_id = make_cycle(session, user_id) + trade_id = make_trade(session, user_id, cycle_id) + + new_note = "This is an updated note." + updated_trade = crud.update_trade_note(session, trade_id, new_note) + assert updated_trade is not None + assert updated_trade.id == trade_id + assert updated_trade.notes == new_note + + session.refresh(updated_trade) + actual_trade = session.get(models.Trades, trade_id) + assert actual_trade is not None + assert actual_trade.notes == new_note + + +def test_invalidate_trade(session: Session): + user_id = make_user(session) + cycle_id = make_cycle(session, user_id) + trade_id = make_trade(session, user_id, cycle_id) + + invalidated_trade = crud.invalidate_trade(session, trade_id) + assert invalidated_trade is not None + assert invalidated_trade.id == trade_id + assert invalidated_trade.is_invalidated is True + + session.refresh(invalidated_trade) + actual_trade = session.get(models.Trades, trade_id) + assert actual_trade is not None + assert actual_trade.is_invalidated is True + + +def test_replace_trade(session: Session): + user_id = make_user(session) + cycle_id = make_cycle(session, user_id) + old_trade_id = make_trade(session, user_id, cycle_id) + + new_trade_data = { + "user_id": user_id, + "friendly_name": "Replaced Trade", + "symbol": "MSFT", + "underlying_currency": models.UnderlyingCurrency.USD, + "trade_type": models.TradeType.LONG_SPOT, + "trade_strategy": models.TradeStrategy.SPOT, + "trade_time_utc": datetime.now(), + "quantity": 20, + "price_cents": 25000, + } + + new_trade = crud.replace_trade(session, old_trade_id, new_trade_data) + assert new_trade.id is not None + assert new_trade.id != old_trade_id + assert new_trade.user_id == user_id + assert new_trade.symbol == new_trade_data["symbol"] + assert new_trade.quantity == new_trade_data["quantity"] + + # Verify the old trade is invalidated + old_trade = session.get(models.Trades, old_trade_id) + assert old_trade is not None + assert old_trade.is_invalidated is True + + # Verify the new trade exists + session.refresh(new_trade) + actual_new_trade = session.get(models.Trades, new_trade.id) + assert actual_new_trade is not None + assert actual_new_trade.friendly_name == new_trade_data["friendly_name"] + assert actual_new_trade.symbol == new_trade_data["symbol"] + assert actual_new_trade.underlying_currency == new_trade_data["underlying_currency"] + assert actual_new_trade.trade_type == new_trade_data["trade_type"] + assert actual_new_trade.trade_strategy == new_trade_data["trade_strategy"] + assert actual_new_trade.quantity == new_trade_data["quantity"] + assert actual_new_trade.price_cents == new_trade_data["price_cents"] + assert actual_new_trade.replaced_by_trade_id == old_trade_id + + def test_create_cycle(session: Session): user_id = make_user(session) cycle_data = { "user_id": user_id, "friendly_name": "My First Cycle", "symbol": "GOOGL", - "underlying_currency": "USD", + "underlying_currency": models.UnderlyingCurrency.USD, "status": models.CycleStatus.OPEN, "start_date": datetime.now().date(), } @@ -367,6 +444,50 @@ def test_create_cycle(session: Session): assert actual_cycle.start_date == cycle_data["start_date"] +def test_update_cycle(session: Session): + user_id = make_user(session) + cycle_id = make_cycle(session, user_id, friendly_name="Initial Cycle Name") + + update_data = { + "friendly_name": "Updated Cycle Name", + "status": models.CycleStatus.CLOSED, + } + updated_cycle = crud.update_cycle(session, cycle_id, update_data) + assert updated_cycle is not None + assert updated_cycle.id == cycle_id + assert updated_cycle.friendly_name == update_data["friendly_name"] + assert updated_cycle.status == update_data["status"] + + session.refresh(updated_cycle) + actual_cycle = session.get(models.Cycles, cycle_id) + assert actual_cycle is not None + assert actual_cycle.friendly_name == update_data["friendly_name"] + assert actual_cycle.status == update_data["status"] + + +def test_update_cycle_immutable_fields(session: Session): + user_id = make_user(session) + cycle_id = make_cycle(session, user_id, friendly_name="Initial Cycle Name") + + # Attempt to update immutable fields + update_data = { + "id": cycle_id + 1, # Trying to change the ID + "user_id": user_id + 1, # Trying to change the user_id + "start_date": datetime(2020, 1, 1).date(), # Trying to change start_date + "created_at": datetime(2020, 1, 1), # Trying to change created_at + "friendly_name": "Valid Update", # Valid field to update + } + + with pytest.raises(ValueError) as excinfo: + crud.update_cycle(session, cycle_id, update_data) + assert ( + "field 'id' is immutable" in str(excinfo.value) + or "field 'user_id' is immutable" in str(excinfo.value) + or "field 'start_date' is immutable" in str(excinfo.value) + or "field 'created_at' is immutable" in str(excinfo.value) + ) + + def test_create_user(session: Session): user_data = { "username": "newuser", @@ -382,3 +503,40 @@ def test_create_user(session: Session): assert actual_user is not None assert actual_user.username == user_data["username"] assert actual_user.password_hash == user_data["password_hash"] + + +def test_update_user(session: Session): + user_id = make_user(session, username="updatableuser") + + update_data = { + "password_hash": "updatedhashedpassword", + } + updated_user = crud.update_user(session, user_id, update_data) + assert updated_user is not None + assert updated_user.id == user_id + assert updated_user.password_hash == update_data["password_hash"] + + session.refresh(updated_user) + actual_user = session.get(models.Users, user_id) + assert actual_user is not None + assert actual_user.password_hash == update_data["password_hash"] + + +def test_update_user_immutable_fields(session: Session): + user_id = make_user(session, username="immutableuser") + + # Attempt to update immutable fields + update_data = { + "id": user_id + 1, # Trying to change the ID + "username": "newusername", # Trying to change the username + "created_at": datetime(2020, 1, 1), # Trying to change created_at + "password_hash": "validupdate", # Valid field to update + } + + with pytest.raises(ValueError) as excinfo: + crud.update_user(session, user_id, update_data) + assert ( + "field 'id' is immutable" in str(excinfo.value) + or "field 'username' is immutable" in str(excinfo.value) + or "field 'created_at' is immutable" in str(excinfo.value) + ) diff --git a/backend/trading_journal/crud.py b/backend/trading_journal/crud.py index b37e065..386c1f4 100644 --- a/backend/trading_journal/crud.py +++ b/backend/trading_journal/crud.py @@ -47,6 +47,7 @@ def create_trade(session: Session, trade_data: Mapping) -> models.Trades: payload["trade_strategy"] = _check_enum( models.TradeStrategy, payload["trade_strategy"], "trade_strategy" ) + # trade_time_utc is the creation moment: always set to now (caller shouldn't provide) now = datetime.now(timezone.utc) payload.pop("trade_time_utc", None) payload["trade_time_utc"] = now @@ -69,20 +70,24 @@ def create_trade(session: Session, trade_data: Mapping) -> models.Trades: payload["net_cash_flow_cents"] = ( payload["gross_cash_flow_cents"] - commission_cents ) + + # If no cycle_id provided, create Cycle instance but don't call create_cycle() + created_cycle = None if cycle_id is None: - cycle_id = create_cycle( - session, - { - "user_id": user_id, - "symbol": payload["symbol"], - "underlying_currency": payload["underlying_currency"], - "friendly_name": "Auto-created Cycle by trade " - + payload.get("friendly_name", ""), - "status": models.CycleStatus.OPEN, - "start_date": payload["trade_date"], - }, - ).id - payload["cycle_id"] = cycle_id + c_payload = { + "user_id": user_id, + "symbol": payload["symbol"], + "underlying_currency": payload["underlying_currency"], + "friendly_name": "Auto-created Cycle by trade " + + payload.get("friendly_name", ""), + "status": models.CycleStatus.OPEN, + "start_date": payload["trade_date"], + } + created_cycle = models.Cycles(**c_payload) + session.add(created_cycle) + # do NOT flush here; will flush together with trade below + + # If cycle_id provided, validate existence and ownership if cycle_id is not None: cycle = session.get(models.Cycles, cycle_id) if cycle is None: @@ -90,7 +95,16 @@ def create_trade(session: Session, trade_data: Mapping) -> models.Trades: else: if cycle.user_id != user_id: raise ValueError("cycle.user_id does not match trade.user_id") - t = models.Trades(**payload) + + # Build trade instance; if we created a Cycle instance, link via relationship so a single flush will persist both and populate ids + t_payload = dict(payload) + # remove cycle_id if we're using created_cycle; relationship will set it on flush + if created_cycle is not None: + t_payload.pop("cycle_id", None) + t = models.Trades(**t_payload) + if created_cycle is not None: + t.cycle = created_cycle + session.add(t) try: session.flush() @@ -122,6 +136,52 @@ def get_trades_by_user_id(session: Session, user_id: int) -> list[models.Trades] return session.exec(statement).all() +def update_trade_note(session: Session, trade_id: int, note: str) -> models.Trades: + trade: models.Trades | None = session.get(models.Trades, trade_id) + if trade is None: + raise ValueError("trade_id does not exist") + trade.notes = note + session.add(trade) + try: + session.flush() + except IntegrityError as e: + session.rollback() + raise ValueError("update_trade_note integrity error") from e + session.refresh(trade) + return trade + + +def invalidate_trade(session: Session, trade_id: int) -> models.Trades: + trade: models.Trades | None = session.get(models.Trades, trade_id) + if trade is None: + raise ValueError("trade_id does not exist") + if trade.is_invalidated: + raise ValueError("trade is already invalidated") + trade.is_invalidated = True + trade.invalidated_at = datetime.now(timezone.utc) + session.add(trade) + try: + session.flush() + except IntegrityError as e: + session.rollback() + raise ValueError("invalidate_trade integrity error") from e + session.refresh(trade) + return trade + + +def replace_trade( + session: Session, old_trade_id: int, new_trade_data: Mapping +) -> models.Trades: + invalidate_trade(session, old_trade_id) + if hasattr(new_trade_data, "dict"): + data = new_trade_data.dict(exclude_unset=True) + else: + data = dict(new_trade_data) + data["replaced_by_trade_id"] = old_trade_id + new_trade = create_trade(session, data) + return new_trade + + # Cycles def create_cycle(session: Session, cycle_data: Mapping) -> models.Cycles: if hasattr(cycle_data, "dict"): @@ -156,7 +216,45 @@ def create_cycle(session: Session, cycle_data: Mapping) -> models.Cycles: return c +IMMUTABLE_CYCLE_FIELDS = {"id", "user_id", "start_date", "created_at"} + + +def update_cycle( + session: Session, cycle_id: int, update_data: Mapping +) -> models.Cycles: + cycle: models.Cycles | None = session.get(models.Cycles, cycle_id) + if cycle is None: + raise ValueError("cycle_id does not exist") + if hasattr(update_data, "dict"): + data = update_data.dict(exclude_unset=True) + else: + data = dict(update_data) + + allowed = {c.name for c in models.Cycles.__table__.columns} + for k, v in data.items(): + if k in IMMUTABLE_CYCLE_FIELDS: + raise ValueError(f"field {k!r} is immutable") + if k not in allowed: + continue + if k == "underlying_currency": + v = _check_enum(models.UnderlyingCurrency, v, "underlying_currency") + if k == "status": + v = _check_enum(models.CycleStatus, v, "status") + setattr(cycle, k, v) + session.add(cycle) + try: + session.flush() + except IntegrityError as e: + session.rollback() + raise ValueError("update_cycle integrity error") from e + session.refresh(cycle) + return cycle + + # Users +IMMUTABLE_USER_FIELDS = {"id", "username", "created_at"} + + def create_user(session: Session, user_data: Mapping) -> models.Users: if hasattr(user_data, "dict"): data = user_data.dict(exclude_unset=True) @@ -178,3 +276,27 @@ def create_user(session: Session, user_data: Mapping) -> models.Users: raise ValueError("create_user integrity error") from e session.refresh(u) return u + + +def update_user(session: Session, user_id: int, update_data: Mapping) -> models.Users: + user: models.Users | None = session.get(models.Users, user_id) + if user is None: + raise ValueError("user_id does not exist") + if hasattr(update_data, "dict"): + data = update_data.dict(exclude_unset=True) + else: + data = dict(update_data) + allowed = {c.name for c in models.Users.__table__.columns} + for k, v in data.items(): + if k in IMMUTABLE_USER_FIELDS: + raise ValueError(f"field {k!r} is immutable") + if k in allowed: + setattr(user, k, v) + session.add(user) + try: + session.flush() + except IntegrityError as e: + session.rollback() + raise ValueError("update_user integrity error") from e + session.refresh(user) + return user diff --git a/backend/trading_journal/models_v1.py b/backend/trading_journal/models_v1.py index d9d8944..0e5857f 100644 --- a/backend/trading_journal/models_v1.py +++ b/backend/trading_journal/models_v1.py @@ -94,6 +94,14 @@ class Trades(SQLModel, table=True): gross_cash_flow_cents: int = Field(sa_column=Column(Integer, nullable=False)) commission_cents: int = Field(sa_column=Column(Integer, nullable=False)) net_cash_flow_cents: int = Field(sa_column=Column(Integer, nullable=False)) + is_invalidated: bool = Field(default=False, nullable=False) + invalidated_at: datetime | None = Field( + default=None, sa_column=Column(DateTime(timezone=True), nullable=True) + ) + replaced_by_trade_id: int | None = Field( + default=None, foreign_key="trades.id", nullable=True + ) + notes: str | None = Field(default=None, sa_column=Column(Text, nullable=True)) cycle_id: int | None = Field( default=None, foreign_key="cycles.id", nullable=True, index=True )