Add session db

This commit is contained in:
2025-09-19 14:06:32 +02:00
parent 0bc85c1faf
commit 2fbf1e9e01
6 changed files with 234 additions and 5 deletions

View File

@@ -1,5 +1,5 @@
from collections.abc import Generator
from datetime import datetime
from datetime import datetime, timedelta, timezone
import pytest
from sqlalchemy import create_engine
@@ -41,7 +41,9 @@ def make_user(session: Session, username: str = "testuser") -> int:
return user.id
def make_cycle(session, user_id: int, friendly_name: str = "Test Cycle") -> int:
def make_cycle(
session: Session, user_id: int, friendly_name: str = "Test Cycle"
) -> int:
cycle = models.Cycles(
user_id=user_id,
friendly_name=friendly_name,
@@ -57,7 +59,7 @@ def make_cycle(session, user_id: int, friendly_name: str = "Test Cycle") -> int:
def make_trade(
session, user_id: int, cycle_id: int, friendly_name: str = "Test Trade"
session: Session, user_id: int, cycle_id: int, friendly_name: str = "Test Trade"
) -> int:
trade = models.Trades(
user_id=user_id,
@@ -82,7 +84,7 @@ def make_trade(
return trade.id
def make_trade_by_trade_data(session, trade_data: dict) -> int:
def make_trade_by_trade_data(session: Session, trade_data: dict) -> int:
trade = models.Trades(**trade_data)
session.add(trade)
session.commit()
@@ -90,6 +92,27 @@ def make_trade_by_trade_data(session, trade_data: dict) -> int:
return trade.id
def make_login_session(session: Session, created_at: datetime) -> models.Sessions:
user_id = make_user(session, username="sessionuser")
session_token_hash = "uniquesessiontokenhash"
if created_at.tzinfo is None:
created_at = created_at.replace(tzinfo=timezone.utc)
login_session = models.Sessions(
user_id=user_id,
session_token_hash=session_token_hash,
created_at=created_at,
expires_at=created_at + timedelta(seconds=86400),
last_seen_at=None,
last_used_ip=None,
user_agent=None,
device_name=None,
)
session.add(login_session)
session.commit()
session.refresh(login_session)
return login_session
def test_create_trade_success_with_cycle(session: Session):
user_id = make_user(session)
cycle_id = make_cycle(session, user_id)
@@ -540,3 +563,60 @@ def test_update_user_immutable_fields(session: Session):
or "field 'username' is immutable" in str(excinfo.value)
or "field 'created_at' is immutable" in str(excinfo.value)
)
# login sessions
def test_create_login_session(session: Session):
user_id = make_user(session, username="testuser")
session_token_hash = "sessiontokenhashed"
login_session = crud.create_login_session(session, user_id, session_token_hash)
assert login_session.id is not None
assert login_session.user_id == user_id
assert login_session.session_token_hash == session_token_hash
def test_create_login_session_with_invalid_user(session: Session):
invalid_user_id = 9999 # Assuming this user ID does not exist
session_token_hash = "sessiontokenhashed"
with pytest.raises(ValueError) as excinfo:
crud.create_login_session(session, invalid_user_id, session_token_hash)
assert "user_id does not exist" in str(excinfo.value)
def test_get_login_session_by_token_and_user_id(session: Session):
now = datetime.now()
created_session = make_login_session(session, now)
fetched_session = crud.get_login_session_by_token_hash_and_user_id(
session, created_session.session_token_hash, created_session.user_id
)
assert fetched_session is not None
assert fetched_session.id == created_session.id
assert fetched_session.user_id == created_session.user_id
assert fetched_session.session_token_hash == created_session.session_token_hash
def test_update_login_session(session: Session):
now = datetime.now()
created_session = make_login_session(session, now)
update_data = {
"last_seen_at": now + timedelta(hours=1),
"last_used_ip": "192.168.1.1",
}
updated_session = crud.update_login_session(
session, created_session.session_token_hash, update_data
)
assert updated_session is not None
assert updated_session.last_seen_at == update_data["last_seen_at"]
assert updated_session.last_used_ip == update_data["last_used_ip"]
def test_delete_login_session(session: Session):
now = datetime.now()
created_session = make_login_session(session, now)
crud.delete_login_session(session, created_session.session_token_hash)
deleted_session = crud.get_login_session_by_token_hash_and_user_id(
session, created_session.session_token_hash, created_session.user_id
)
assert deleted_session is None

View File

@@ -63,6 +63,17 @@ def test_run_migrations_0_to_1(monkeypatch: pytest.MonkeyPatch) -> None:
"net_cash_flow_cents": ("INTEGER", 1, 0),
"cycle_id": ("INTEGER", 0, 0),
},
"sessions": {
"id": ("INTEGER", 1, 1),
"user_id": ("INTEGER", 1, 0),
"session_token_hash": ("TEXT", 1, 0),
"created_at": ("DATETIME", 1, 0),
"expires_at": ("DATETIME", 1, 0),
"last_seen_at": ("DATETIME", 0, 0),
"last_used_ip": ("TEXT", 0, 0),
"user_agent": ("TEXT", 0, 0),
"device_name": ("TEXT", 0, 0),
},
}
expected_fks = {

View File

@@ -1,4 +1,4 @@
from datetime import datetime, timezone
from datetime import datetime, timedelta, timezone
from typing import Mapping
from sqlalchemy.exc import IntegrityError
@@ -300,3 +300,98 @@ def update_user(session: Session, user_id: int, update_data: Mapping) -> models.
raise ValueError("update_user integrity error") from e
session.refresh(user)
return user
# Sessions
def create_login_session(
session: Session,
user_id: int,
session_token_hash: str,
session_length_seconds: int = 86400,
last_used_ip: str | None = None,
user_agent: str | None = None,
device_name: str | None = None,
) -> models.Sessions:
user: models.Users | None = session.get(models.Users, user_id)
if user is None:
raise ValueError("user_id does not exist")
now = datetime.now(timezone.utc)
expires_at = now + timedelta(seconds=session_length_seconds)
s = models.Sessions(
user_id=user.id,
session_token_hash=session_token_hash,
created_at=now,
expires_at=expires_at,
last_seen_at=now,
last_used_ip=last_used_ip,
user_agent=user_agent,
device_name=device_name,
)
session.add(s)
try:
session.flush()
except IntegrityError as e:
session.rollback()
raise ValueError("create_login_session integrity error") from e
session.refresh(s)
return s
def get_login_session_by_token_hash_and_user_id(
session: Session, session_token_hash: str, user_id: int
) -> models.Sessions | None:
statement = select(models.Sessions).where(
models.Sessions.session_token_hash == session_token_hash,
models.Sessions.user_id == user_id,
models.Sessions.expires_at > datetime.now(timezone.utc),
)
return session.exec(statement).first()
IMMUTABLE_SESSION_FIELDS = {"id", "user_id", "session_token_hash", "created_at"}
def update_login_session(
session: Session, session_token_hashed: str, update_session: Mapping
) -> models.Sessions | None:
login_session: models.Sessions | None = session.exec(
select(models.Sessions).where(
models.Sessions.session_token_hash == session_token_hashed,
models.Sessions.expires_at > datetime.now(timezone.utc),
)
).first()
if login_session is None:
return None
if hasattr(update_session, "dict"):
data = update_session.dict(exclude_unset=True)
else:
data = dict(update_session)
allowed = {c.name for c in models.Sessions.__table__.columns}
for k, v in data.items():
if k in allowed and k not in IMMUTABLE_SESSION_FIELDS:
setattr(login_session, k, v)
session.add(login_session)
try:
session.flush()
except IntegrityError as e:
session.rollback()
raise ValueError("update_login_session integrity error") from e
session.refresh(login_session)
return login_session
def delete_login_session(session: Session, session_token_hash: str) -> None:
login_session: models.Sessions | None = session.exec(
select(models.Sessions).where(
models.Sessions.session_token_hash == session_token_hash,
)
).first()
if login_session is None:
return
session.delete(login_session)
try:
session.flush()
except IntegrityError as e:
session.rollback()
raise ValueError("delete_login_session integrity error") from e

View File

@@ -26,6 +26,7 @@ def _mig_0_1(engine: Engine) -> None:
models_v1.Trades.__table__,
models_v1.Cycles.__table__,
models_v1.Users.__table__,
models_v1.Sessions.__table__,
],
)

View File

@@ -142,3 +142,24 @@ class Users(SQLModel, table=True):
username: str = Field(sa_column=Column(Text, nullable=False, unique=True))
password_hash: str = Field(sa_column=Column(Text, nullable=False))
is_active: bool = Field(default=True, nullable=False)
class Sessions(SQLModel, table=True):
__tablename__ = "sessions"
id: int | None = Field(default=None, primary_key=True)
user_id: int = Field(foreign_key="users.id", nullable=False, index=True)
session_token_hash: str = Field(sa_column=Column(Text, nullable=False, unique=True))
created_at: datetime = Field(
sa_column=Column(DateTime(timezone=True), nullable=False)
)
expires_at: datetime = Field(
sa_column=Column(DateTime(timezone=True), nullable=False, index=True)
)
last_seen_at: datetime | None = Field(
sa_column=Column(DateTime(timezone=True), nullable=True)
)
last_used_ip: str | None = Field(
default=None, sa_column=Column(Text, nullable=True)
)
user_agent: str | None = Field(default=None, sa_column=Column(Text, nullable=True))
device_name: str | None = Field(default=None, sa_column=Column(Text, nullable=True))

View File

@@ -142,3 +142,24 @@ class Users(SQLModel, table=True):
username: str = Field(sa_column=Column(Text, nullable=False, unique=True))
password_hash: str = Field(sa_column=Column(Text, nullable=False))
is_active: bool = Field(default=True, nullable=False)
class Sessions(SQLModel, table=True):
__tablename__ = "sessions"
id: int | None = Field(default=None, primary_key=True)
user_id: int = Field(foreign_key="users.id", nullable=False, index=True)
session_token_hash: str = Field(sa_column=Column(Text, nullable=False, unique=True))
created_at: datetime = Field(
sa_column=Column(DateTime(timezone=True), nullable=False)
)
expires_at: datetime = Field(
sa_column=Column(DateTime(timezone=True), nullable=False, index=True)
)
last_seen_at: datetime | None = Field(
sa_column=Column(DateTime(timezone=True), nullable=True)
)
last_used_ip: str | None = Field(
default=None, sa_column=Column(Text, nullable=True)
)
user_agent: str | None = Field(default=None, sa_column=Column(Text, nullable=True))
device_name: str | None = Field(default=None, sa_column=Column(Text, nullable=True))