Files
trading-journal/backend/trading_journal/db.py

85 lines
2.7 KiB
Python
Raw Normal View History

2025-09-12 21:19:36 +00:00
from __future__ import annotations
import logging
from typing import TYPE_CHECKING
from sqlalchemy import event
from sqlalchemy.pool import StaticPool
2025-09-13 12:58:46 +02:00
from sqlmodel import Session, create_engine
2025-09-13 18:46:16 +02:00
from trading_journal import db_migration
2025-09-12 21:19:36 +00:00
if TYPE_CHECKING:
from collections.abc import Generator
2025-09-13 12:58:46 +02:00
from sqlite3 import Connection as DBAPIConnection
2025-09-12 21:19:36 +00:00
class Database:
    """Own a SQLAlchemy engine and hand out SQLModel sessions bound to it.

    With no URL an in-memory SQLite database is used. For SQLite URLs the DBAPI
    connect args default to ``check_same_thread=False`` and ``timeout=30``
    (caller-supplied ``connect_args`` win on conflicts), and every new SQLite
    connection gets WAL journaling, ``synchronous=NORMAL``, foreign-key
    enforcement and a 30 000 ms busy timeout applied on a best-effort basis.
    """

    def __init__(
        self,
        database_url: str | None = None,
        *,
        echo: bool = False,
        connect_args: dict | None = None,
    ) -> None:
        """Create the engine for *database_url*.

        Args:
            database_url: SQLAlchemy database URL. ``None`` or ``""`` selects
                ``"sqlite:///:memory:"``.
            echo: Forwarded to ``create_engine``.
            connect_args: Extra DBAPI ``connect()`` keyword arguments, merged
                over the SQLite defaults (these values take precedence).
        """
        self._database_url = database_url or "sqlite:///:memory:"
        is_sqlite = self._database_url.startswith("sqlite")
        default_connect = {"check_same_thread": False, "timeout": 30} if is_sqlite else {}
        merged_connect = {**default_connect, **(connect_args or {})}

        if is_sqlite and self._is_sqlite_memory_url(self._database_url):
            logger = logging.getLogger(__name__)
            logger.warning(
                "Using in-memory SQLite database; all data will be lost when the application stops.",
            )
            # An in-memory SQLite database exists only inside the connection that
            # created it, so all sessions must reuse one connection: StaticPool.
            self._engine = create_engine(
                self._database_url,
                echo=echo,
                connect_args=merged_connect,
                poolclass=StaticPool,
            )
        else:
            self._engine = create_engine(self._database_url, echo=echo, connect_args=merged_connect)

        if is_sqlite:
            # Fires once for every new DBAPI connection the pool opens.
            event.listen(self._engine, "connect", self._enable_sqlite_pragmas)

    @staticmethod
    def _is_sqlite_memory_url(url: str) -> bool:
        """Return True if *url* names an in-memory SQLite database.

        Matches ``sqlite://``, ``sqlite:///``, ``sqlite:///:memory:`` (any
        ``sqlite+driver`` scheme), with or without a query string.
        """
        base = url.split("?", 1)[0]  # "?cache=shared" etc. don't change the target
        scheme, sep, rest = base.partition("://")
        if not sep or not scheme.startswith("sqlite"):
            return False
        # After "sqlite://", one "/" separates the (empty) host from the path;
        # an empty path or ":memory:" means an in-memory database.
        path = rest[1:] if rest.startswith("/") else rest
        return path in ("", ":memory:")

    @staticmethod
    def _enable_sqlite_pragmas(dbapi_conn: DBAPIConnection, _connection_record: object) -> None:
        """Apply this project's SQLite PRAGMAs to a freshly opened connection.

        Failures are logged and swallowed so a connection is still handed out
        even if a PRAGMA cannot be applied.
        """
        try:
            cur = dbapi_conn.cursor()
            try:
                cur.execute("PRAGMA journal_mode=WAL;")
                cur.execute("PRAGMA synchronous=NORMAL;")
                cur.execute("PRAGMA foreign_keys=ON;")
                cur.execute("PRAGMA busy_timeout=30000;")
            finally:
                # Close the cursor even when one of the PRAGMAs fails.
                cur.close()
        except Exception:
            logger = logging.getLogger(__name__)
            logger.exception("Failed to set sqlite pragmas on new connection: ")

    def init_db(self) -> None:
        """Run ``db_migration.run_migrations`` against this engine."""
        db_migration.run_migrations(self._engine)

    def get_session(self) -> Generator[Session, None, None]:
        """Yield a ``Session`` bound to the engine, as a generator dependency.

        Commits after the consumer finishes without error; on an exception it
        rolls back and re-raises. The session is always closed.
        """
        session = Session(self._engine)
        try:
            yield session
            session.commit()
        except Exception:
            session.rollback()
            raise
        finally:
            session.close()

    def dispose(self) -> None:
        """Dispose of the engine's connection pool."""
        self._engine.dispose()
2025-09-14 15:40:11 +02:00
def create_database(
    database_url: str | None = None,
    *,
    echo: bool = False,
    connect_args: dict | None = None,
) -> Database:
    """Factory that builds a :class:`Database`, forwarding every argument as-is.

    Args:
        database_url: SQLAlchemy database URL; ``None`` lets ``Database`` choose
            its default.
        echo: Passed through to ``Database``.
        connect_args: Passed through to ``Database``.

    Returns:
        A new ``Database``. Migrations are not run here; call ``init_db()``.
    """
    engine_options = {"echo": echo, "connect_args": connect_args}
    return Database(database_url, **engine_options)