Add session db
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
from datetime import datetime, timezone
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Mapping
|
||||
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
@@ -300,3 +300,98 @@ def update_user(session: Session, user_id: int, update_data: Mapping) -> models.
|
||||
raise ValueError("update_user integrity error") from e
|
||||
session.refresh(user)
|
||||
return user
|
||||
|
||||
|
||||
# Sessions
|
||||
def create_login_session(
|
||||
session: Session,
|
||||
user_id: int,
|
||||
session_token_hash: str,
|
||||
session_length_seconds: int = 86400,
|
||||
last_used_ip: str | None = None,
|
||||
user_agent: str | None = None,
|
||||
device_name: str | None = None,
|
||||
) -> models.Sessions:
|
||||
user: models.Users | None = session.get(models.Users, user_id)
|
||||
if user is None:
|
||||
raise ValueError("user_id does not exist")
|
||||
now = datetime.now(timezone.utc)
|
||||
expires_at = now + timedelta(seconds=session_length_seconds)
|
||||
s = models.Sessions(
|
||||
user_id=user.id,
|
||||
session_token_hash=session_token_hash,
|
||||
created_at=now,
|
||||
expires_at=expires_at,
|
||||
last_seen_at=now,
|
||||
last_used_ip=last_used_ip,
|
||||
user_agent=user_agent,
|
||||
device_name=device_name,
|
||||
)
|
||||
session.add(s)
|
||||
try:
|
||||
session.flush()
|
||||
except IntegrityError as e:
|
||||
session.rollback()
|
||||
raise ValueError("create_login_session integrity error") from e
|
||||
session.refresh(s)
|
||||
return s
|
||||
|
||||
|
||||
def get_login_session_by_token_hash_and_user_id(
|
||||
session: Session, session_token_hash: str, user_id: int
|
||||
) -> models.Sessions | None:
|
||||
statement = select(models.Sessions).where(
|
||||
models.Sessions.session_token_hash == session_token_hash,
|
||||
models.Sessions.user_id == user_id,
|
||||
models.Sessions.expires_at > datetime.now(timezone.utc),
|
||||
)
|
||||
|
||||
return session.exec(statement).first()
|
||||
|
||||
|
||||
IMMUTABLE_SESSION_FIELDS = {"id", "user_id", "session_token_hash", "created_at"}
|
||||
|
||||
|
||||
def update_login_session(
|
||||
session: Session, session_token_hashed: str, update_session: Mapping
|
||||
) -> models.Sessions | None:
|
||||
login_session: models.Sessions | None = session.exec(
|
||||
select(models.Sessions).where(
|
||||
models.Sessions.session_token_hash == session_token_hashed,
|
||||
models.Sessions.expires_at > datetime.now(timezone.utc),
|
||||
)
|
||||
).first()
|
||||
if login_session is None:
|
||||
return None
|
||||
if hasattr(update_session, "dict"):
|
||||
data = update_session.dict(exclude_unset=True)
|
||||
else:
|
||||
data = dict(update_session)
|
||||
allowed = {c.name for c in models.Sessions.__table__.columns}
|
||||
for k, v in data.items():
|
||||
if k in allowed and k not in IMMUTABLE_SESSION_FIELDS:
|
||||
setattr(login_session, k, v)
|
||||
session.add(login_session)
|
||||
try:
|
||||
session.flush()
|
||||
except IntegrityError as e:
|
||||
session.rollback()
|
||||
raise ValueError("update_login_session integrity error") from e
|
||||
session.refresh(login_session)
|
||||
return login_session
|
||||
|
||||
|
||||
def delete_login_session(session: Session, session_token_hash: str) -> None:
|
||||
login_session: models.Sessions | None = session.exec(
|
||||
select(models.Sessions).where(
|
||||
models.Sessions.session_token_hash == session_token_hash,
|
||||
)
|
||||
).first()
|
||||
if login_session is None:
|
||||
return
|
||||
session.delete(login_session)
|
||||
try:
|
||||
session.flush()
|
||||
except IntegrityError as e:
|
||||
session.rollback()
|
||||
raise ValueError("delete_login_session integrity error") from e
|
||||
|
||||
@@ -26,6 +26,7 @@ def _mig_0_1(engine: Engine) -> None:
|
||||
models_v1.Trades.__table__,
|
||||
models_v1.Cycles.__table__,
|
||||
models_v1.Users.__table__,
|
||||
models_v1.Sessions.__table__,
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
@@ -142,3 +142,24 @@ class Users(SQLModel, table=True):
|
||||
username: str = Field(sa_column=Column(Text, nullable=False, unique=True))
|
||||
password_hash: str = Field(sa_column=Column(Text, nullable=False))
|
||||
is_active: bool = Field(default=True, nullable=False)
|
||||
|
||||
|
||||
class Sessions(SQLModel, table=True):
|
||||
__tablename__ = "sessions"
|
||||
id: int | None = Field(default=None, primary_key=True)
|
||||
user_id: int = Field(foreign_key="users.id", nullable=False, index=True)
|
||||
session_token_hash: str = Field(sa_column=Column(Text, nullable=False, unique=True))
|
||||
created_at: datetime = Field(
|
||||
sa_column=Column(DateTime(timezone=True), nullable=False)
|
||||
)
|
||||
expires_at: datetime = Field(
|
||||
sa_column=Column(DateTime(timezone=True), nullable=False, index=True)
|
||||
)
|
||||
last_seen_at: datetime | None = Field(
|
||||
sa_column=Column(DateTime(timezone=True), nullable=True)
|
||||
)
|
||||
last_used_ip: str | None = Field(
|
||||
default=None, sa_column=Column(Text, nullable=True)
|
||||
)
|
||||
user_agent: str | None = Field(default=None, sa_column=Column(Text, nullable=True))
|
||||
device_name: str | None = Field(default=None, sa_column=Column(Text, nullable=True))
|
||||
|
||||
@@ -142,3 +142,24 @@ class Users(SQLModel, table=True):
|
||||
username: str = Field(sa_column=Column(Text, nullable=False, unique=True))
|
||||
password_hash: str = Field(sa_column=Column(Text, nullable=False))
|
||||
is_active: bool = Field(default=True, nullable=False)
|
||||
|
||||
|
||||
class Sessions(SQLModel, table=True):
|
||||
__tablename__ = "sessions"
|
||||
id: int | None = Field(default=None, primary_key=True)
|
||||
user_id: int = Field(foreign_key="users.id", nullable=False, index=True)
|
||||
session_token_hash: str = Field(sa_column=Column(Text, nullable=False, unique=True))
|
||||
created_at: datetime = Field(
|
||||
sa_column=Column(DateTime(timezone=True), nullable=False)
|
||||
)
|
||||
expires_at: datetime = Field(
|
||||
sa_column=Column(DateTime(timezone=True), nullable=False, index=True)
|
||||
)
|
||||
last_seen_at: datetime | None = Field(
|
||||
sa_column=Column(DateTime(timezone=True), nullable=True)
|
||||
)
|
||||
last_used_ip: str | None = Field(
|
||||
default=None, sa_column=Column(Text, nullable=True)
|
||||
)
|
||||
user_agent: str | None = Field(default=None, sa_column=Column(Text, nullable=True))
|
||||
device_name: str | None = Field(default=None, sa_column=Column(Text, nullable=True))
|
||||
|
||||
Reference in New Issue
Block a user