Compare commits: 0bc85c1faf...39fc10572e (2 commits)

| SHA1 | Author | Date |
|---|---|---|
| 39fc10572e | | |
| 2fbf1e9e01 | | |
.github/script/compare_models.py (vendored, normal file, 125 lines added)
@@ -0,0 +1,125 @@
import ast
import json
import sys
from pathlib import Path

# Ensure the "backend" package directory is on sys.path so `import trading_journal` works.
# Find the repo root by walking upwards until we find a "backend" directory.
p = Path(__file__).resolve()
repo_root = None
while True:
    if (p / "backend").exists():
        repo_root = p
        break
    if p.parent == p:
        break
    p = p.parent
# fallback: two levels up (covers the common .github/script layout)
if repo_root is None:
    repo_root = Path(__file__).resolve().parents[2]

backend_dir = repo_root / "backend"
if backend_dir.exists():
    sys.path.insert(0, str(backend_dir))


def load_struct(path: Path):
    src = path.read_text(encoding="utf-8")
    mod = ast.parse(src)
    out = {}
    for node in mod.body:
        if not isinstance(node, ast.ClassDef):
            continue
        # detect SQLModel table classes:
        is_table = any(
            (
                kw.arg == "table"
                and isinstance(kw.value, ast.Constant)
                and kw.value.value is True
            )
            for kw in getattr(node, "keywords", [])
        ) or any(
            getattr(b, "id", None) == "SQLModel"
            # for attribute bases such as sqlmodel.SQLModel, .attr is the name string
            or getattr(b, "attr", None) == "SQLModel"
            for b in getattr(node, "bases", [])
        )
        if not is_table:
            continue
        fields = []
        for item in node.body:
            # annotated assignment: name: type = value
            if isinstance(item, ast.AnnAssign) and getattr(item.target, "id", None):
                name = item.target.id
                ann = (
                    ast.unparse(item.annotation)
                    if item.annotation is not None
                    else None
                )
                val = ast.unparse(item.value) if item.value is not None else None
                fields.append((name, ann, val))
            # simple assign: name = value (rare for Field, but include)
            elif isinstance(item, ast.Assign):
                for t in item.targets:
                    if getattr(t, "id", None):
                        name = t.id
                        ann = None
                        val = (
                            ast.unparse(item.value) if item.value is not None else None
                        )
                        fields.append((name, ann, val))
        # sort fields by name for deterministic comparison
        fields.sort(key=lambda x: x[0])
        out[node.name] = fields
    return out


def main():
    if len(sys.argv) == 1:
        print(
            "usage: compare_models.py <live_model_path> [snapshot_model_path]",
            file=sys.stderr,
        )
        sys.exit(2)

    live = Path(sys.argv[1])
    snap = None
    if len(sys.argv) >= 3:
        snap = Path(sys.argv[2])
    else:
        # auto-detect snapshot via db_migration.LATEST_VERSION
        try:
            import importlib

            dbm = importlib.import_module("trading_journal.db_migration")
            latest = getattr(dbm, "LATEST_VERSION")
            snap = Path(live.parent) / f"models_v{latest}.py"
        except Exception as e:
            print("failed to determine snapshot path:", e, file=sys.stderr)
            sys.exit(2)

    if not live.exists() or not snap.exists():
        print(
            f"file missing: live={live.exists()} snap={snap.exists()}", file=sys.stderr
        )
        sys.exit(2)

    a = load_struct(live)
    b = load_struct(snap)
    if a != b:
        print("models mismatch\n")
        diff = {
            "live_only_classes": sorted(set(a) - set(b)),
            "snapshot_only_classes": sorted(set(b) - set(a)),
            "mismatched_classes": {},
        }
        for cls in set(a) & set(b):
            if a[cls] != b[cls]:
                diff["mismatched_classes"][cls] = {"live": a[cls], "snapshot": b[cls]}
        print(json.dumps(diff, indent=2, ensure_ascii=False))
        sys.exit(1)
    print("models match snapshot")
    sys.exit(0)


if __name__ == "__main__":
    main()
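For orientation, here is a minimal sketch of what `load_struct` extracts. The `Widgets` model, the temporary file handling, and the module-loading path are illustrative assumptions, not part of the change above:

```python
# Illustrative sketch only (not part of the diff): loads the new script as a module
# and shows the structure load_struct extracts from a hypothetical model file.
# Assumes it is run from the repository root so the relative script path resolves.
import importlib.util
import tempfile
from pathlib import Path

spec = importlib.util.spec_from_file_location(
    "compare_models", ".github/script/compare_models.py"
)
compare_models = importlib.util.module_from_spec(spec)
spec.loader.exec_module(compare_models)

example_src = """
from sqlmodel import SQLModel, Field

class Widgets(SQLModel, table=True):
    id: int | None = Field(default=None, primary_key=True)
    name: str
"""

with tempfile.NamedTemporaryFile("w", suffix=".py", delete=False) as fh:
    fh.write(example_src)
    tmp = Path(fh.name)

# Field tuples are (name, annotation, default), sorted by name; roughly:
# {"Widgets": [("id", "int | None", "Field(default=None, primary_key=True)"),
#              ("name", "str", None)]}
print(compare_models.load_struct(tmp))
```

Comparing two such dictionaries is all the CI check does: any difference in class names or field tuples is printed as JSON and the step exits non-zero.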
.github/workflows/backend-ci.yml (vendored, 4 changes)
@@ -24,6 +24,10 @@ jobs:
      - name: Install deps
        run: pip install -r dev-requirements.txt

      - name: Run models vs snapshot check
        run: |
          python .github/script/compare_models.py trading_journal/models.py

      - name: Run tests
        run: |
          pytest -q
@@ -1,5 +1,5 @@
from collections.abc import Generator
from datetime import datetime
from datetime import datetime, timedelta, timezone

import pytest
from sqlalchemy import create_engine
@@ -41,7 +41,9 @@ def make_user(session: Session, username: str = "testuser") -> int:
    return user.id


def make_cycle(session, user_id: int, friendly_name: str = "Test Cycle") -> int:
def make_cycle(
    session: Session, user_id: int, friendly_name: str = "Test Cycle"
) -> int:
    cycle = models.Cycles(
        user_id=user_id,
        friendly_name=friendly_name,
@@ -57,7 +59,7 @@ def make_cycle(session, user_id: int, friendly_name: str = "Test Cycle") -> int:


def make_trade(
    session, user_id: int, cycle_id: int, friendly_name: str = "Test Trade"
    session: Session, user_id: int, cycle_id: int, friendly_name: str = "Test Trade"
) -> int:
    trade = models.Trades(
        user_id=user_id,
@@ -82,7 +84,7 @@ def make_trade(
    return trade.id


def make_trade_by_trade_data(session, trade_data: dict) -> int:
def make_trade_by_trade_data(session: Session, trade_data: dict) -> int:
    trade = models.Trades(**trade_data)
    session.add(trade)
    session.commit()
@@ -90,6 +92,27 @@ def make_trade_by_trade_data(session, trade_data: dict) -> int:
    return trade.id


def make_login_session(session: Session, created_at: datetime) -> models.Sessions:
    user_id = make_user(session, username="sessionuser")
    session_token_hash = "uniquesessiontokenhash"
    if created_at.tzinfo is None:
        created_at = created_at.replace(tzinfo=timezone.utc)
    login_session = models.Sessions(
        user_id=user_id,
        session_token_hash=session_token_hash,
        created_at=created_at,
        expires_at=created_at + timedelta(seconds=86400),
        last_seen_at=None,
        last_used_ip=None,
        user_agent=None,
        device_name=None,
    )
    session.add(login_session)
    session.commit()
    session.refresh(login_session)
    return login_session


def test_create_trade_success_with_cycle(session: Session):
    user_id = make_user(session)
    cycle_id = make_cycle(session, user_id)
@@ -540,3 +563,60 @@ def test_update_user_immutable_fields(session: Session):
        or "field 'username' is immutable" in str(excinfo.value)
        or "field 'created_at' is immutable" in str(excinfo.value)
    )


# login sessions
def test_create_login_session(session: Session):
    user_id = make_user(session, username="testuser")
    session_token_hash = "sessiontokenhashed"
    login_session = crud.create_login_session(session, user_id, session_token_hash)
    assert login_session.id is not None
    assert login_session.user_id == user_id
    assert login_session.session_token_hash == session_token_hash


def test_create_login_session_with_invalid_user(session: Session):
    invalid_user_id = 9999  # Assuming this user ID does not exist
    session_token_hash = "sessiontokenhashed"
    with pytest.raises(ValueError) as excinfo:
        crud.create_login_session(session, invalid_user_id, session_token_hash)
    assert "user_id does not exist" in str(excinfo.value)


def test_get_login_session_by_token_and_user_id(session: Session):
    now = datetime.now()
    created_session = make_login_session(session, now)
    fetched_session = crud.get_login_session_by_token_hash_and_user_id(
        session, created_session.session_token_hash, created_session.user_id
    )
    assert fetched_session is not None
    assert fetched_session.id == created_session.id
    assert fetched_session.user_id == created_session.user_id
    assert fetched_session.session_token_hash == created_session.session_token_hash


def test_update_login_session(session: Session):
    now = datetime.now()
    created_session = make_login_session(session, now)

    update_data = {
        "last_seen_at": now + timedelta(hours=1),
        "last_used_ip": "192.168.1.1",
    }
    updated_session = crud.update_login_session(
        session, created_session.session_token_hash, update_data
    )
    assert updated_session is not None
    assert updated_session.last_seen_at == update_data["last_seen_at"]
    assert updated_session.last_used_ip == update_data["last_used_ip"]


def test_delete_login_session(session: Session):
    now = datetime.now()
    created_session = make_login_session(session, now)

    crud.delete_login_session(session, created_session.session_token_hash)
    deleted_session = crud.get_login_session_by_token_hash_and_user_id(
        session, created_session.session_token_hash, created_session.user_id
    )
    assert deleted_session is None
@@ -63,6 +63,17 @@ def test_run_migrations_0_to_1(monkeypatch: pytest.MonkeyPatch) -> None:
            "net_cash_flow_cents": ("INTEGER", 1, 0),
            "cycle_id": ("INTEGER", 0, 0),
        },
        "sessions": {
            "id": ("INTEGER", 1, 1),
            "user_id": ("INTEGER", 1, 0),
            "session_token_hash": ("TEXT", 1, 0),
            "created_at": ("DATETIME", 1, 0),
            "expires_at": ("DATETIME", 1, 0),
            "last_seen_at": ("DATETIME", 0, 0),
            "last_used_ip": ("TEXT", 0, 0),
            "user_agent": ("TEXT", 0, 0),
            "device_name": ("TEXT", 0, 0),
        },
    }

    expected_fks = {
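The expected tuples above appear to be the (type, notnull, pk) triples that SQLite's PRAGMA table_info reports for each column; that reading is an assumption, not stated in the hunk. As a standalone sketch, the same triples can be reproduced from the Sessions model shown later in this diff (copied here with the foreign key omitted, since the users table is not defined in this snippet):

```python
# Standalone sketch, not part of the diff: builds the sessions table in an
# in-memory SQLite database and prints PRAGMA table_info rows, whose
# (type, notnull, pk) values correspond to the tuples the migration test expects.
from datetime import datetime

from sqlalchemy import Column, DateTime, Text, text
from sqlmodel import Field, SQLModel, create_engine


class Sessions(SQLModel, table=True):
    __tablename__ = "sessions"
    id: int | None = Field(default=None, primary_key=True)
    # foreign_key="users.id" omitted here because the users table is not defined
    user_id: int = Field(nullable=False, index=True)
    session_token_hash: str = Field(sa_column=Column(Text, nullable=False, unique=True))
    created_at: datetime = Field(sa_column=Column(DateTime(timezone=True), nullable=False))
    expires_at: datetime = Field(
        sa_column=Column(DateTime(timezone=True), nullable=False, index=True)
    )
    last_seen_at: datetime | None = Field(
        sa_column=Column(DateTime(timezone=True), nullable=True)
    )
    last_used_ip: str | None = Field(default=None, sa_column=Column(Text, nullable=True))
    user_agent: str | None = Field(default=None, sa_column=Column(Text, nullable=True))
    device_name: str | None = Field(default=None, sa_column=Column(Text, nullable=True))


engine = create_engine("sqlite://")
SQLModel.metadata.create_all(engine)

with engine.connect() as conn:
    for cid, name, col_type, notnull, default, pk in conn.execute(
        text("PRAGMA table_info(sessions)")
    ):
        print(name, (col_type, notnull, pk))  # e.g. id ('INTEGER', 1, 1)
```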
@@ -1,4 +1,4 @@
from datetime import datetime, timezone
from datetime import datetime, timedelta, timezone
from typing import Mapping

from sqlalchemy.exc import IntegrityError
@@ -300,3 +300,98 @@ def update_user(session: Session, user_id: int, update_data: Mapping) -> models.
        raise ValueError("update_user integrity error") from e
    session.refresh(user)
    return user


# Sessions
def create_login_session(
    session: Session,
    user_id: int,
    session_token_hash: str,
    session_length_seconds: int = 86400,
    last_used_ip: str | None = None,
    user_agent: str | None = None,
    device_name: str | None = None,
) -> models.Sessions:
    user: models.Users | None = session.get(models.Users, user_id)
    if user is None:
        raise ValueError("user_id does not exist")
    now = datetime.now(timezone.utc)
    expires_at = now + timedelta(seconds=session_length_seconds)
    s = models.Sessions(
        user_id=user.id,
        session_token_hash=session_token_hash,
        created_at=now,
        expires_at=expires_at,
        last_seen_at=now,
        last_used_ip=last_used_ip,
        user_agent=user_agent,
        device_name=device_name,
    )
    session.add(s)
    try:
        session.flush()
    except IntegrityError as e:
        session.rollback()
        raise ValueError("create_login_session integrity error") from e
    session.refresh(s)
    return s


def get_login_session_by_token_hash_and_user_id(
    session: Session, session_token_hash: str, user_id: int
) -> models.Sessions | None:
    statement = select(models.Sessions).where(
        models.Sessions.session_token_hash == session_token_hash,
        models.Sessions.user_id == user_id,
        models.Sessions.expires_at > datetime.now(timezone.utc),
    )

    return session.exec(statement).first()


IMMUTABLE_SESSION_FIELDS = {"id", "user_id", "session_token_hash", "created_at"}


def update_login_session(
    session: Session, session_token_hashed: str, update_session: Mapping
) -> models.Sessions | None:
    login_session: models.Sessions | None = session.exec(
        select(models.Sessions).where(
            models.Sessions.session_token_hash == session_token_hashed,
            models.Sessions.expires_at > datetime.now(timezone.utc),
        )
    ).first()
    if login_session is None:
        return None
    if hasattr(update_session, "dict"):
        data = update_session.dict(exclude_unset=True)
    else:
        data = dict(update_session)
    allowed = {c.name for c in models.Sessions.__table__.columns}
    for k, v in data.items():
        if k in allowed and k not in IMMUTABLE_SESSION_FIELDS:
            setattr(login_session, k, v)
    session.add(login_session)
    try:
        session.flush()
    except IntegrityError as e:
        session.rollback()
        raise ValueError("update_login_session integrity error") from e
    session.refresh(login_session)
    return login_session


def delete_login_session(session: Session, session_token_hash: str) -> None:
    login_session: models.Sessions | None = session.exec(
        select(models.Sessions).where(
            models.Sessions.session_token_hash == session_token_hash,
        )
    ).first()
    if login_session is None:
        return
    session.delete(login_session)
    try:
        session.flush()
    except IntegrityError as e:
        session.rollback()
        raise ValueError("delete_login_session integrity error") from e
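To show how these helpers fit together, a minimal end-to-end sketch follows. It assumes the backend package is importable (for example, run from the backend directory), an in-memory SQLite database, and that models.Users can be constructed from just username and password_hash; none of this is part of the change itself:

```python
# Hedged sketch, not part of the diff: exercises the new session helpers against
# an in-memory SQLite database. Assumes `trading_journal` is importable and that
# Users only needs username/password_hash here (other columns use defaults).
from sqlmodel import Session, SQLModel, create_engine

from trading_journal import crud, models

engine = create_engine("sqlite://")
SQLModel.metadata.create_all(engine)

with Session(engine) as session:
    user = models.Users(username="demo", password_hash="notarealhash")
    session.add(user)
    session.commit()
    session.refresh(user)

    # create -> fetch -> update -> delete
    created = crud.create_login_session(session, user.id, "demotokenhash")
    fetched = crud.get_login_session_by_token_hash_and_user_id(
        session, "demotokenhash", user.id
    )
    assert fetched is not None and fetched.id == created.id

    crud.update_login_session(session, "demotokenhash", {"last_used_ip": "127.0.0.1"})
    crud.delete_login_session(session, "demotokenhash")
    assert (
        crud.get_login_session_by_token_hash_and_user_id(
            session, "demotokenhash", user.id
        )
        is None
    )
```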
@@ -26,6 +26,7 @@ def _mig_0_1(engine: Engine) -> None:
            models_v1.Trades.__table__,
            models_v1.Cycles.__table__,
            models_v1.Users.__table__,
            models_v1.Sessions.__table__,
        ],
    )
@@ -142,3 +142,24 @@ class Users(SQLModel, table=True):
    username: str = Field(sa_column=Column(Text, nullable=False, unique=True))
    password_hash: str = Field(sa_column=Column(Text, nullable=False))
    is_active: bool = Field(default=True, nullable=False)


class Sessions(SQLModel, table=True):
    __tablename__ = "sessions"
    id: int | None = Field(default=None, primary_key=True)
    user_id: int = Field(foreign_key="users.id", nullable=False, index=True)
    session_token_hash: str = Field(sa_column=Column(Text, nullable=False, unique=True))
    created_at: datetime = Field(
        sa_column=Column(DateTime(timezone=True), nullable=False)
    )
    expires_at: datetime = Field(
        sa_column=Column(DateTime(timezone=True), nullable=False, index=True)
    )
    last_seen_at: datetime | None = Field(
        sa_column=Column(DateTime(timezone=True), nullable=True)
    )
    last_used_ip: str | None = Field(
        default=None, sa_column=Column(Text, nullable=True)
    )
    user_agent: str | None = Field(default=None, sa_column=Column(Text, nullable=True))
    device_name: str | None = Field(default=None, sa_column=Column(Text, nullable=True))
@@ -142,3 +142,24 @@ class Users(SQLModel, table=True):
    username: str = Field(sa_column=Column(Text, nullable=False, unique=True))
    password_hash: str = Field(sa_column=Column(Text, nullable=False))
    is_active: bool = Field(default=True, nullable=False)


class Sessions(SQLModel, table=True):
    __tablename__ = "sessions"
    id: int | None = Field(default=None, primary_key=True)
    user_id: int = Field(foreign_key="users.id", nullable=False, index=True)
    session_token_hash: str = Field(sa_column=Column(Text, nullable=False, unique=True))
    created_at: datetime = Field(
        sa_column=Column(DateTime(timezone=True), nullable=False)
    )
    expires_at: datetime = Field(
        sa_column=Column(DateTime(timezone=True), nullable=False, index=True)
    )
    last_seen_at: datetime | None = Field(
        sa_column=Column(DateTime(timezone=True), nullable=True)
    )
    last_used_ip: str | None = Field(
        default=None, sa_column=Column(Text, nullable=True)
    )
    user_agent: str | None = Field(default=None, sa_column=Column(Text, nullable=True))
    device_name: str | None = Field(default=None, sa_column=Column(Text, nullable=True))