From 616232b76d856acb5b2c9cc5e47cc782bd880fa7 Mon Sep 17 00:00:00 2001 From: Tianyu Liu Date: Sat, 13 Sep 2025 21:14:14 +0200 Subject: [PATCH] add migration and enable ci --- .github/workflows/backend-ci.yml | 29 +++++++++++ backend/ruff.toml | 12 ++++- backend/tests/test_db.py | 2 +- backend/tests/test_db_migration.py | 77 ++++++++++++++++++++++++++++++ 4 files changed, 118 insertions(+), 2 deletions(-) create mode 100644 .github/workflows/backend-ci.yml create mode 100644 backend/tests/test_db_migration.py diff --git a/.github/workflows/backend-ci.yml b/.github/workflows/backend-ci.yml new file mode 100644 index 0000000..c00d4cd --- /dev/null +++ b/.github/workflows/backend-ci.yml @@ -0,0 +1,29 @@ +name: Backend CI + +on: + push: + pull_request: + workflow_dispatch: + +jobs: + unit-test: + runs-on: ubuntu-latest + defaults: + run: + working-directory: backend + + steps: + - name: Checkout + uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: '3.11' + + - name: Install deps + run: pip install -r dev-requirements.txt + + - name: Run tests + run: | + pytest -q \ No newline at end of file diff --git a/backend/ruff.toml b/backend/ruff.toml index 4bcadc2..c5a0394 100644 --- a/backend/ruff.toml +++ b/backend/ruff.toml @@ -4,7 +4,17 @@ line-length = 144 [lint] select = ["ALL"] fixable = ["UP034", "I001"] -ignore = ["T201", "D", "ANN101", "TD002", "TD003", "TRY003", "EM102"] +ignore = [ + "T201", + "D", + "ANN101", + "TD002", + "TD003", + "TRY003", + "EM101", + "EM102", + "PLC0405", +] [lint.extend-per-file-ignores] "test*.py" = ["S101"] diff --git a/backend/tests/test_db.py b/backend/tests/test_db.py index 9990959..6d5c505 100644 --- a/backend/tests/test_db.py +++ b/backend/tests/test_db.py @@ -64,7 +64,7 @@ def test_rollback_on_exception() -> None: s.exec(text("CREATE TABLE IF NOT EXISTS t_rb (id INTEGER PRIMARY KEY, val TEXT);")) s.exec(text("INSERT INTO t_rb (val) VALUES (:v)").bindparams(v="will_rollback")) # simulate handler error -> should trigger rollback in get_session - raise RuntimeError("simulated failure") # noqa: TRY003, EM101 + raise RuntimeError("simulated failure") # New session should not see the inserted row with session_ctx(db) as s2: diff --git a/backend/tests/test_db_migration.py b/backend/tests/test_db_migration.py new file mode 100644 index 0000000..f68bdc4 --- /dev/null +++ b/backend/tests/test_db_migration.py @@ -0,0 +1,77 @@ +import pytest +from sqlalchemy import text +from sqlalchemy.pool import StaticPool +from sqlmodel import create_engine + +from trading_journal import db_migration + + +def _base_type_of(compiled: str) -> str: + """Return base type name (e.g. VARCHAR from VARCHAR(13)), upper-cased.""" + return compiled.split("(")[0].strip().upper() + + +def test_run_migrations_0_to_1(monkeypatch: pytest.MonkeyPatch) -> None: + # in-memory engine that preserves the same connection (StaticPool) + engine = create_engine( + "sqlite:///:memory:", + connect_args={"check_same_thread": False}, + poolclass=StaticPool, + ) + + # ensure target is the LATEST_VERSION we expect for the test + monkeypatch.setattr(db_migration, "LATEST_VERSION", 1) + + # run real migrations (will import trading_journal.models_v1 inside _mig_0_1) + final_version = db_migration.run_migrations(engine) + assert final_version == 1 + + # import snapshot models to validate schema + from trading_journal import models_v1 + + expected_tables = { + "trades": models_v1.Trades.__table__, + "cycles": models_v1.Cycles.__table__, + } + + with engine.connect() as conn: + # check tables exist + rows = conn.execute( + text("SELECT name FROM sqlite_master WHERE type='table'") + ).fetchall() + found_tables = {r[0] for r in rows} + assert set(expected_tables.keys()).issubset(found_tables), ( + f"missing tables: {set(expected_tables.keys()) - found_tables}" + ) + + # check user_version + uv = conn.execute(text("PRAGMA user_version")).fetchone() + assert uv is not None + assert int(uv[0]) == 1 + + # validate columns and (base) types for each expected table + dialect = conn.dialect + for tbl_name, table in expected_tables.items(): + info_rows = conn.execute(text(f"PRAGMA table_info({tbl_name})")).fetchall() + # build mapping: column name -> declared type (upper) + actual_cols = {r[1]: (r[2] or "").upper() for r in info_rows} + for col in table.columns: + assert col.name in actual_cols, ( + f"column {col.name} missing in table {tbl_name}" + ) + # compile expected type against this dialect + try: + compiled = col.type.compile( + dialect=dialect + ) # e.g. VARCHAR(13), DATETIME + except Exception: + compiled = str(col.type) + expected_base = _base_type_of(compiled) + actual_type = actual_cols[col.name] + actual_base = _base_type_of(actual_type) if actual_type else "" + # accept either direction (some dialect vs sqlite naming differences) + assert (expected_base in actual_base) or ( + actual_base in expected_base + ), ( + f"type mismatch for {tbl_name}.{col.name}: expected {expected_base}, got {actual_base}" + )