From 2fbf1e9e0137bdb542769beec9650db8014acdd1 Mon Sep 17 00:00:00 2001 From: Tianyu Liu Date: Fri, 19 Sep 2025 14:06:32 +0200 Subject: [PATCH 1/5] Add session db --- backend/tests/test_crud.py | 88 +++++++++++++++++++++- backend/tests/test_db_migration.py | 11 +++ backend/trading_journal/crud.py | 97 ++++++++++++++++++++++++- backend/trading_journal/db_migration.py | 1 + backend/trading_journal/models.py | 21 ++++++ backend/trading_journal/models_v1.py | 21 ++++++ 6 files changed, 234 insertions(+), 5 deletions(-) diff --git a/backend/tests/test_crud.py b/backend/tests/test_crud.py index 3ee2fce..b542d90 100644 --- a/backend/tests/test_crud.py +++ b/backend/tests/test_crud.py @@ -1,5 +1,5 @@ from collections.abc import Generator -from datetime import datetime +from datetime import datetime, timedelta, timezone import pytest from sqlalchemy import create_engine @@ -41,7 +41,9 @@ def make_user(session: Session, username: str = "testuser") -> int: return user.id -def make_cycle(session, user_id: int, friendly_name: str = "Test Cycle") -> int: +def make_cycle( + session: Session, user_id: int, friendly_name: str = "Test Cycle" +) -> int: cycle = models.Cycles( user_id=user_id, friendly_name=friendly_name, @@ -57,7 +59,7 @@ def make_cycle(session, user_id: int, friendly_name: str = "Test Cycle") -> int: def make_trade( - session, user_id: int, cycle_id: int, friendly_name: str = "Test Trade" + session: Session, user_id: int, cycle_id: int, friendly_name: str = "Test Trade" ) -> int: trade = models.Trades( user_id=user_id, @@ -82,7 +84,7 @@ def make_trade( return trade.id -def make_trade_by_trade_data(session, trade_data: dict) -> int: +def make_trade_by_trade_data(session: Session, trade_data: dict) -> int: trade = models.Trades(**trade_data) session.add(trade) session.commit() @@ -90,6 +92,27 @@ def make_trade_by_trade_data(session, trade_data: dict) -> int: return trade.id +def make_login_session(session: Session, created_at: datetime) -> models.Sessions: + user_id = make_user(session, username="sessionuser") + session_token_hash = "uniquesessiontokenhash" + if created_at.tzinfo is None: + created_at = created_at.replace(tzinfo=timezone.utc) + login_session = models.Sessions( + user_id=user_id, + session_token_hash=session_token_hash, + created_at=created_at, + expires_at=created_at + timedelta(seconds=86400), + last_seen_at=None, + last_used_ip=None, + user_agent=None, + device_name=None, + ) + session.add(login_session) + session.commit() + session.refresh(login_session) + return login_session + + def test_create_trade_success_with_cycle(session: Session): user_id = make_user(session) cycle_id = make_cycle(session, user_id) @@ -540,3 +563,60 @@ def test_update_user_immutable_fields(session: Session): or "field 'username' is immutable" in str(excinfo.value) or "field 'created_at' is immutable" in str(excinfo.value) ) + + +# login sessions +def test_create_login_session(session: Session): + user_id = make_user(session, username="testuser") + session_token_hash = "sessiontokenhashed" + login_session = crud.create_login_session(session, user_id, session_token_hash) + assert login_session.id is not None + assert login_session.user_id == user_id + assert login_session.session_token_hash == session_token_hash + + +def test_create_login_session_with_invalid_user(session: Session): + invalid_user_id = 9999 # Assuming this user ID does not exist + session_token_hash = "sessiontokenhashed" + with pytest.raises(ValueError) as excinfo: + crud.create_login_session(session, invalid_user_id, session_token_hash) + assert "user_id does not exist" in str(excinfo.value) + + +def test_get_login_session_by_token_and_user_id(session: Session): + now = datetime.now() + created_session = make_login_session(session, now) + fetched_session = crud.get_login_session_by_token_hash_and_user_id( + session, created_session.session_token_hash, created_session.user_id + ) + assert fetched_session is not None + assert fetched_session.id == created_session.id + assert fetched_session.user_id == created_session.user_id + assert fetched_session.session_token_hash == created_session.session_token_hash + + +def test_update_login_session(session: Session): + now = datetime.now() + created_session = make_login_session(session, now) + + update_data = { + "last_seen_at": now + timedelta(hours=1), + "last_used_ip": "192.168.1.1", + } + updated_session = crud.update_login_session( + session, created_session.session_token_hash, update_data + ) + assert updated_session is not None + assert updated_session.last_seen_at == update_data["last_seen_at"] + assert updated_session.last_used_ip == update_data["last_used_ip"] + + +def test_delete_login_session(session: Session): + now = datetime.now() + created_session = make_login_session(session, now) + + crud.delete_login_session(session, created_session.session_token_hash) + deleted_session = crud.get_login_session_by_token_hash_and_user_id( + session, created_session.session_token_hash, created_session.user_id + ) + assert deleted_session is None diff --git a/backend/tests/test_db_migration.py b/backend/tests/test_db_migration.py index e1c8850..d274e6a 100644 --- a/backend/tests/test_db_migration.py +++ b/backend/tests/test_db_migration.py @@ -63,6 +63,17 @@ def test_run_migrations_0_to_1(monkeypatch: pytest.MonkeyPatch) -> None: "net_cash_flow_cents": ("INTEGER", 1, 0), "cycle_id": ("INTEGER", 0, 0), }, + "sessions": { + "id": ("INTEGER", 1, 1), + "user_id": ("INTEGER", 1, 0), + "session_token_hash": ("TEXT", 1, 0), + "created_at": ("DATETIME", 1, 0), + "expires_at": ("DATETIME", 1, 0), + "last_seen_at": ("DATETIME", 0, 0), + "last_used_ip": ("TEXT", 0, 0), + "user_agent": ("TEXT", 0, 0), + "device_name": ("TEXT", 0, 0), + }, } expected_fks = { diff --git a/backend/trading_journal/crud.py b/backend/trading_journal/crud.py index 386c1f4..5f051d8 100644 --- a/backend/trading_journal/crud.py +++ b/backend/trading_journal/crud.py @@ -1,4 +1,4 @@ -from datetime import datetime, timezone +from datetime import datetime, timedelta, timezone from typing import Mapping from sqlalchemy.exc import IntegrityError @@ -300,3 +300,98 @@ def update_user(session: Session, user_id: int, update_data: Mapping) -> models. raise ValueError("update_user integrity error") from e session.refresh(user) return user + + +# Sessions +def create_login_session( + session: Session, + user_id: int, + session_token_hash: str, + session_length_seconds: int = 86400, + last_used_ip: str | None = None, + user_agent: str | None = None, + device_name: str | None = None, +) -> models.Sessions: + user: models.Users | None = session.get(models.Users, user_id) + if user is None: + raise ValueError("user_id does not exist") + now = datetime.now(timezone.utc) + expires_at = now + timedelta(seconds=session_length_seconds) + s = models.Sessions( + user_id=user.id, + session_token_hash=session_token_hash, + created_at=now, + expires_at=expires_at, + last_seen_at=now, + last_used_ip=last_used_ip, + user_agent=user_agent, + device_name=device_name, + ) + session.add(s) + try: + session.flush() + except IntegrityError as e: + session.rollback() + raise ValueError("create_login_session integrity error") from e + session.refresh(s) + return s + + +def get_login_session_by_token_hash_and_user_id( + session: Session, session_token_hash: str, user_id: int +) -> models.Sessions | None: + statement = select(models.Sessions).where( + models.Sessions.session_token_hash == session_token_hash, + models.Sessions.user_id == user_id, + models.Sessions.expires_at > datetime.now(timezone.utc), + ) + + return session.exec(statement).first() + + +IMMUTABLE_SESSION_FIELDS = {"id", "user_id", "session_token_hash", "created_at"} + + +def update_login_session( + session: Session, session_token_hashed: str, update_session: Mapping +) -> models.Sessions | None: + login_session: models.Sessions | None = session.exec( + select(models.Sessions).where( + models.Sessions.session_token_hash == session_token_hashed, + models.Sessions.expires_at > datetime.now(timezone.utc), + ) + ).first() + if login_session is None: + return None + if hasattr(update_session, "dict"): + data = update_session.dict(exclude_unset=True) + else: + data = dict(update_session) + allowed = {c.name for c in models.Sessions.__table__.columns} + for k, v in data.items(): + if k in allowed and k not in IMMUTABLE_SESSION_FIELDS: + setattr(login_session, k, v) + session.add(login_session) + try: + session.flush() + except IntegrityError as e: + session.rollback() + raise ValueError("update_login_session integrity error") from e + session.refresh(login_session) + return login_session + + +def delete_login_session(session: Session, session_token_hash: str) -> None: + login_session: models.Sessions | None = session.exec( + select(models.Sessions).where( + models.Sessions.session_token_hash == session_token_hash, + ) + ).first() + if login_session is None: + return + session.delete(login_session) + try: + session.flush() + except IntegrityError as e: + session.rollback() + raise ValueError("delete_login_session integrity error") from e diff --git a/backend/trading_journal/db_migration.py b/backend/trading_journal/db_migration.py index c59e3b0..d55a6a0 100644 --- a/backend/trading_journal/db_migration.py +++ b/backend/trading_journal/db_migration.py @@ -26,6 +26,7 @@ def _mig_0_1(engine: Engine) -> None: models_v1.Trades.__table__, models_v1.Cycles.__table__, models_v1.Users.__table__, + models_v1.Sessions.__table__, ], ) diff --git a/backend/trading_journal/models.py b/backend/trading_journal/models.py index 0e5857f..17de397 100644 --- a/backend/trading_journal/models.py +++ b/backend/trading_journal/models.py @@ -142,3 +142,24 @@ class Users(SQLModel, table=True): username: str = Field(sa_column=Column(Text, nullable=False, unique=True)) password_hash: str = Field(sa_column=Column(Text, nullable=False)) is_active: bool = Field(default=True, nullable=False) + + +class Sessions(SQLModel, table=True): + __tablename__ = "sessions" + id: int | None = Field(default=None, primary_key=True) + user_id: int = Field(foreign_key="users.id", nullable=False, index=True) + session_token_hash: str = Field(sa_column=Column(Text, nullable=False, unique=True)) + created_at: datetime = Field( + sa_column=Column(DateTime(timezone=True), nullable=False) + ) + expires_at: datetime = Field( + sa_column=Column(DateTime(timezone=True), nullable=False, index=True) + ) + last_seen_at: datetime | None = Field( + sa_column=Column(DateTime(timezone=True), nullable=True) + ) + last_used_ip: str | None = Field( + default=None, sa_column=Column(Text, nullable=True) + ) + user_agent: str | None = Field(default=None, sa_column=Column(Text, nullable=True)) + device_name: str | None = Field(default=None, sa_column=Column(Text, nullable=True)) diff --git a/backend/trading_journal/models_v1.py b/backend/trading_journal/models_v1.py index 0e5857f..17de397 100644 --- a/backend/trading_journal/models_v1.py +++ b/backend/trading_journal/models_v1.py @@ -142,3 +142,24 @@ class Users(SQLModel, table=True): username: str = Field(sa_column=Column(Text, nullable=False, unique=True)) password_hash: str = Field(sa_column=Column(Text, nullable=False)) is_active: bool = Field(default=True, nullable=False) + + +class Sessions(SQLModel, table=True): + __tablename__ = "sessions" + id: int | None = Field(default=None, primary_key=True) + user_id: int = Field(foreign_key="users.id", nullable=False, index=True) + session_token_hash: str = Field(sa_column=Column(Text, nullable=False, unique=True)) + created_at: datetime = Field( + sa_column=Column(DateTime(timezone=True), nullable=False) + ) + expires_at: datetime = Field( + sa_column=Column(DateTime(timezone=True), nullable=False, index=True) + ) + last_seen_at: datetime | None = Field( + sa_column=Column(DateTime(timezone=True), nullable=True) + ) + last_used_ip: str | None = Field( + default=None, sa_column=Column(Text, nullable=True) + ) + user_agent: str | None = Field(default=None, sa_column=Column(Text, nullable=True)) + device_name: str | None = Field(default=None, sa_column=Column(Text, nullable=True)) From 39fc10572ee2fc22e023e08e56bebb394866b927 Mon Sep 17 00:00:00 2001 From: Tianyu Liu Date: Fri, 19 Sep 2025 14:13:05 +0200 Subject: [PATCH 2/5] add ci to compare models --- .github/script/compare_models.py | 125 +++++++++++++++++++++++++++++++ .github/workflows/backend-ci.yml | 4 + 2 files changed, 129 insertions(+) create mode 100644 .github/script/compare_models.py diff --git a/.github/script/compare_models.py b/.github/script/compare_models.py new file mode 100644 index 0000000..805853a --- /dev/null +++ b/.github/script/compare_models.py @@ -0,0 +1,125 @@ +import ast +import json +import sys +from pathlib import Path + +# Ensure the "backend" package directory is on sys.path so `import trading_journal` works +# Find repo root by walking upwards until we find a "backend" directory. +p = Path(__file__).resolve() +repo_root = None +while True: + if (p / "backend").exists(): + repo_root = p + break + if p.parent == p: + break + p = p.parent +# fallback: two levels up (covers common .github/script layout) +if repo_root is None: + repo_root = Path(__file__).resolve().parents[2] + +backend_dir = repo_root / "backend" +if backend_dir.exists(): + sys.path.insert(0, str(backend_dir)) + + +def load_struct(path: Path): + src = path.read_text(encoding="utf-8") + mod = ast.parse(src) + out = {} + for node in mod.body: + if not isinstance(node, ast.ClassDef): + continue + # detect SQLModel table classes: + is_table = any( + ( + kw.arg == "table" + and isinstance(kw.value, ast.Constant) + and kw.value.value is True + ) + for kw in getattr(node, "keywords", []) + ) or any( + getattr(b, "id", None) == "SQLModel" + or getattr(getattr(b, "attr", None), "id", None) == "SQLModel" + for b in getattr(node, "bases", []) + ) + if not is_table: + continue + fields = [] + for item in node.body: + # annotated assignment: name: type = value + if isinstance(item, ast.AnnAssign) and getattr(item.target, "id", None): + name = item.target.id + ann = ( + ast.unparse(item.annotation) + if item.annotation is not None + else None + ) + val = ast.unparse(item.value) if item.value is not None else None + fields.append((name, ann, val)) + # simple assign: name = value (rare for Field, but include) + elif isinstance(item, ast.Assign): + for t in item.targets: + if getattr(t, "id", None): + name = t.id + ann = None + val = ( + ast.unparse(item.value) if item.value is not None else None + ) + fields.append((name, ann, val)) + # sort fields by name for deterministic comparison + fields.sort(key=lambda x: x[0]) + out[node.name] = fields + return out + + +def main(): + if len(sys.argv) == 1: + print( + "usage: compare_models.py [snapshot_model_path]", + file=sys.stderr, + ) + sys.exit(2) + + live = Path(sys.argv[1]) + snap = None + if len(sys.argv) >= 3: + snap = Path(sys.argv[2]) + else: + # auto-detect snapshot via db_migration.LATEST_VERSION + try: + import importlib + + dbm = importlib.import_module("trading_journal.db_migration") + latest = getattr(dbm, "LATEST_VERSION") + snap = Path(live.parent) / f"models_v{latest}.py" + except Exception as e: + print("failed to determine snapshot path:", e, file=sys.stderr) + sys.exit(2) + + if not live.exists() or not snap.exists(): + print( + f"file missing: live={live.exists()} snap={snap.exists()}", file=sys.stderr + ) + sys.exit(2) + + a = load_struct(live) + b = load_struct(snap) + if a != b: + print("models mismatch\n") + diff = { + "live_only_classes": sorted(set(a) - set(b)), + "snapshot_only_classes": sorted(set(b) - set(a)), + "mismatched_classes": {}, + } + for cls in set(a) & set(b): + if a[cls] != b[cls]: + diff["mismatched_classes"][cls] = {"live": a[cls], "snapshot": b[cls]} + print(json.dumps(diff, indent=2, ensure_ascii=False)) + sys.exit(1) + print("models match snapshot") + sys.exit(0) + + +if __name__ == "__main__": + main() diff --git a/.github/workflows/backend-ci.yml b/.github/workflows/backend-ci.yml index c00d4cd..e910897 100644 --- a/.github/workflows/backend-ci.yml +++ b/.github/workflows/backend-ci.yml @@ -24,6 +24,10 @@ jobs: - name: Install deps run: pip install -r dev-requirements.txt + - name: Run models vs snapshot check + run: | + python .github/scripts/compare_models.py trading_journal/models.py + - name: Run tests run: | pytest -q \ No newline at end of file From afd342b31f79dcb74671b4249e59f3abf572d637 Mon Sep 17 00:00:00 2001 From: Tianyu Liu Date: Fri, 19 Sep 2025 14:14:51 +0200 Subject: [PATCH 3/5] cwd for compare model --- .github/workflows/backend-ci.yml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.github/workflows/backend-ci.yml b/.github/workflows/backend-ci.yml index e910897..003e4a5 100644 --- a/.github/workflows/backend-ci.yml +++ b/.github/workflows/backend-ci.yml @@ -25,8 +25,9 @@ jobs: run: pip install -r dev-requirements.txt - name: Run models vs snapshot check + working-directory: .. run: | - python .github/scripts/compare_models.py trading_journal/models.py + python .github/scripts/compare_models.py backend/trading_journal/models.py - name: Run tests run: | From bc264c801425c54450089b409fda3d5fa314e044 Mon Sep 17 00:00:00 2001 From: Tianyu Liu Date: Fri, 19 Sep 2025 14:16:25 +0200 Subject: [PATCH 4/5] use workspace --- .github/workflows/backend-ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/backend-ci.yml b/.github/workflows/backend-ci.yml index 003e4a5..5555997 100644 --- a/.github/workflows/backend-ci.yml +++ b/.github/workflows/backend-ci.yml @@ -25,7 +25,7 @@ jobs: run: pip install -r dev-requirements.txt - name: Run models vs snapshot check - working-directory: .. + working-directory: ${{ github.workspace }} run: | python .github/scripts/compare_models.py backend/trading_journal/models.py From 9f3010d3007cc3254451780f70abadb5aeb15c35 Mon Sep 17 00:00:00 2001 From: Tianyu Liu Date: Fri, 19 Sep 2025 14:18:21 +0200 Subject: [PATCH 5/5] fix path typo --- .github/workflows/backend-ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/backend-ci.yml b/.github/workflows/backend-ci.yml index 5555997..e7ddfb5 100644 --- a/.github/workflows/backend-ci.yml +++ b/.github/workflows/backend-ci.yml @@ -27,7 +27,7 @@ jobs: - name: Run models vs snapshot check working-directory: ${{ github.workspace }} run: | - python .github/scripts/compare_models.py backend/trading_journal/models.py + python .github/script/compare_models.py backend/trading_journal/models.py - name: Run tests run: |