BrickTracker/bricktracker/sql.py

377 lines
11 KiB
Python
Raw Normal View History

2025-01-17 11:03:00 +01:00
import logging
import os
import sqlite3
from typing import Any, Final, Tuple
2025-01-17 11:03:00 +01:00
from flask import current_app, g
from jinja2 import Environment, FileSystemLoader
from werkzeug.datastructures import FileStorage
from .exceptions import DatabaseException
from .sql_counter import BrickCounter
from .sql_migration_list import BrickSQLMigrationList
from .sql_stats import BrickSQLStats
from .version import __database_version__
2025-01-17 11:03:00 +01:00
# Module-wide logger
logger = logging.getLogger(__name__)

# Attribute names used to keep per-context state on flask.g
G_CONNECTION: Final[str] = 'database_connection'
G_ENVIRONMENT: Final[str] = 'database_environment'
G_DEFER: Final[str] = 'database_defer'
G_STATS: Final[str] = 'database_stats'

# Record counters filled in by BrickSQL.count_records().
# Each spec is (name, table, icon); the table is the one being counted.
COUNTERS: Final[list[BrickCounter]] = [
    BrickCounter(name, table, icon=icon)
    for name, table, icon in (
        ('Sets', 'sets', 'hashtag'),
        ('Minifigures', 'minifigures', 'group-line'),
        ('Parts', 'inventory', 'shapes-line'),
        ('Missing', 'missing', 'error-warning-line'),
        ('Wishlist', 'wishlist', 'gift-line'),
    )
]
2025-01-17 11:03:00 +01:00
# SQLite3 client with our extra features
class BrickSQL(object):
    """SQLite3 client bound to the Flask application context.

    The connection and its stats object are stored on ``flask.g``. Every
    ``BrickSQL`` instance created while they are there shares them instead
    of opening a new connection. Queries are Jinja templates rendered from
    the ``sql/`` directory next to this module.
    """

    connection: sqlite3.Connection
    cursor: sqlite3.Cursor
    stats: BrickSQLStats
    version: int

    # Attribute name used to cache the schema version on flask.g, so that
    # instances reusing an existing connection also know the version
    _G_VERSION: Final[str] = 'database_version'

    def __init__(self, /, *, failsafe: bool = False):
        """Open a new connection, or reuse the one stored on ``flask.g``.

        Args:
            failsafe: only used when a new connection is opened. When False
                (the default), raise if the schema version is older than
                ``__database_version__``. When True, return a usable client
                on an outdated database (e.g. to call ``upgrade()``).

        Raises:
            DatabaseException: when a new connection is opened and the
                schema version cannot be read, is newer than this
                application supports, or (unless ``failsafe``) is older.
        """
        # Share one connection between every BrickSQL instance created
        # while it is stored on flask.g, instead of re-opening connections
        connection = getattr(g, G_CONNECTION, None)

        # Grab the existing connection if it exists
        if connection is not None:
            self.connection = connection
            self.stats = getattr(g, G_STATS, BrickSQLStats())

            # Reuse the version read when the connection was opened so that
            # upgrade_needed() / upgrade() also work on this instance
            self.version = getattr(g, self._G_VERSION, 0)

            # Grab a cursor
            self.cursor = self.connection.cursor()
        else:
            # Instantiate the stats
            self.stats = BrickSQLStats()

            # Stats: connect
            self.stats.connect += 1

            logger.debug('SQLite3: connect')
            self.connection = sqlite3.connect(
                current_app.config['DATABASE_PATH']
            )

            # Setup the row factory to get pseudo-dicts rather than tuples
            self.connection.row_factory = sqlite3.Row

            # Grab a cursor
            self.cursor = self.connection.cursor()

            # Grab the version and check
            try:
                version = self.fetchone('schema/get_version')

                if version is None:
                    raise Exception('version is None')

                self.version = version[0]
            except Exception as e:
                self.version = 0

                # The connection is not saved on g yet, so close() would
                # never reach it: close it here instead of leaking it
                self.connection.close()

                raise DatabaseException('Could not get the database version: {error}'.format(  # noqa: E501
                    error=str(e)
                )) from e

            if self.upgrade_too_far():
                # Same as above: not on g yet, close it before bailing out
                self.connection.close()

                raise DatabaseException('Your database version ({version}) is too far ahead for this version of the application. Expected at most {required}'.format(  # noqa: E501
                    version=self.version,
                    required=__database_version__,
                ))

            # Debug: Attach the debugger
            # Uncomment manually because this is ultra verbose
            # self.connection.set_trace_callback(print)

            # Save the connection globally for later use
            setattr(g, G_CONNECTION, self.connection)
            setattr(g, G_STATS, self.stats)
            setattr(g, self._G_VERSION, self.version)

            if not failsafe:
                if self.upgrade_needed():
                    raise DatabaseException('Your database need to be upgraded from version {version} to version {required}'.format(  # noqa: E501
                        version=self.version,
                        required=__database_version__,
                    ))

    # Clear the defer stack
    def clear_defer(self, /) -> None:
        """Reset the deferred-query stack stored on flask.g."""
        setattr(g, G_DEFER, [])

    # Shorthand to commit
    def commit(self, /) -> None:
        """Run every deferred query, clear the stack, then commit."""
        # Stats: commit
        self.stats.commit += 1

        # Process the deferred stack
        for query, parameters in self.get_defer():
            self.raw_execute(query, parameters)

        self.clear_defer()

        logger.debug('SQLite3: commit')
        return self.connection.commit()

    # Count the database records
    def count_records(self) -> list[BrickCounter]:
        """Refresh and return the module-level COUNTERS.

        Best-effort: a failing count is logged and skipped. That counter
        keeps whatever value it held before.
        """
        # NOTE(review): this mutates the shared module-level COUNTERS
        # objects in place, so every caller sees the same instances
        for counter in COUNTERS:
            # Failsafe this one
            try:
                record = self.fetchone('schema/count', table=counter.table)

                if record is not None:
                    counter.count = record['count']
            except Exception as e:
                # Do not break the caller, but leave a trace
                logger.debug('SQLite3: could not count {table}: {error}'.format(  # noqa: E501
                    table=counter.table,
                    error=str(e),
                ))

        return COUNTERS

    # Defer a call to execute
    def defer(self, query: str, parameters: dict[str, Any], /):
        """Push an already-rendered query onto the stack run by commit()."""
        defer = self.get_defer()

        logger.debug('SQLite3: defer execute')

        # Add the query and parameters to the defer stack
        defer.append((query, parameters))

        # Save the defer stack
        setattr(g, G_DEFER, defer)

    # Shorthand to execute, returning number of affected rows
    def execute(
        self,
        query: str,
        /,
        *,
        parameters: dict[str, Any] | None = None,
        defer: bool = False,
        **context: Any,
    ) -> Tuple[int, str]:
        """Render the named query template and execute (or defer) it.

        Args:
            query: template name, relative to ``sql/`` and without ``.sql``.
            parameters: bound SQL parameters (``None`` means none).
            defer: queue the query until the next ``commit()`` instead.
            **context: Jinja context used to render the template.

        Returns:
            ``(rowcount, rendered_query)``; rowcount is ``-1`` if deferred.
        """
        if parameters is None:
            parameters = {}

        # Stats: execute
        self.stats.execute += 1

        # Load the query
        query = self.load_query(query, **context)

        # Defer
        if defer:
            self.defer(query, parameters)

            return -1, query
        else:
            result = self.raw_execute(query, parameters)

            # Stats: changed
            if result.rowcount > 0:
                self.stats.changed += result.rowcount

            return result.rowcount, query

    # Shorthand to executescript
    def executescript(self, query: str, /, **context: Any) -> None:
        """Render the named query template and run it with executescript."""
        # Load the query
        query = self.load_query(query, **context)

        # Stats: executescript
        self.stats.executescript += 1

        logger.debug('SQLite3: executescript')
        self.cursor.executescript(query)

    # Shorthand to execute and commit
    def execute_and_commit(
        self,
        query: str,
        /,
        *,
        parameters: dict[str, Any] | None = None,
        **context: Any,
    ) -> Tuple[int, str]:
        """Execute the named query then commit; returns execute()'s tuple."""
        rows, query = self.execute(query, parameters=parameters, **context)
        self.commit()

        return rows, query

    # Shorthand to execute and fetchall
    def fetchall(
        self,
        query: str,
        /,
        *,
        parameters: dict[str, Any] | None = None,
        **context: Any,
    ) -> list[sqlite3.Row]:
        """Execute the named query and return every resulting row."""
        self.execute(query, parameters=parameters, **context)

        # Stats: fetchall
        self.stats.fetchall += 1

        logger.debug('SQLite3: fetchall')
        records = self.cursor.fetchall()

        # Stats: fetched (number of rows)
        self.stats.fetched += len(records)

        return records

    # Shorthand to execute and fetchone
    def fetchone(
        self,
        query: str,
        /,
        *,
        parameters: dict[str, Any] | None = None,
        **context: Any,
    ) -> sqlite3.Row | None:
        """Execute the named query and return its first row, or None."""
        self.execute(query, parameters=parameters, **context)

        # Stats: fetchone
        self.stats.fetchone += 1

        logger.debug('SQLite3: fetchone')
        record = self.cursor.fetchone()

        # Stats: fetched (one row; len(record) would count its columns)
        if record is not None:
            self.stats.fetched += 1

        return record

    # Grab the defer stack
    def get_defer(self, /) -> list[Tuple[str, dict[str, Any]]]:
        """Return the deferred (query, parameters) stack from flask.g."""
        defer: list[Tuple[str, dict[str, Any]]] = getattr(g, G_DEFER, [])

        return defer

    # Load a query by name
    def load_query(self, name: str, /, **context: Any) -> str:
        """Render ``sql/<name>.sql`` with the given Jinja context.

        The Jinja environment is created once and cached on flask.g.
        """
        # Grab the existing environment if it exists
        environment = getattr(g, G_ENVIRONMENT, None)

        # Instantiate Jinja environment for SQL files
        if environment is None:
            environment = Environment(
                loader=FileSystemLoader(
                    os.path.join(os.path.dirname(__file__), 'sql/')
                )
            )

            # Save the environment globally for later use
            setattr(g, G_ENVIRONMENT, environment)

        # Grab the template
        logger.debug('SQLite: loading {name} (context: {context})'.format(
            name=name,
            context=context,
        ))
        template = environment.get_template('{name}.sql'.format(
            name=name,
        ))

        return template.render(**context)

    # Raw execute the query without any options
    def raw_execute(
        self,
        query: str,
        parameters: dict[str, Any],
        /
    ) -> sqlite3.Cursor:
        """Execute an already-rendered query on the shared cursor."""
        # clean_query() walks the whole query: only pay for it when the
        # debug line is actually going to be emitted
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug('SQLite3: execute: {query}'.format(
                query=BrickSQL.clean_query(query)
            ))

        return self.cursor.execute(query, parameters)

    # Upgrade the database
    def upgrade(self) -> None:
        """Apply every pending migration after the current version."""
        if self.upgrade_needed():
            for pending in BrickSQLMigrationList().pending(self.version):
                logger.info('Applying migration {version}'.format(
                    version=pending.version
                ))

                self.executescript(pending.get_query())
                self.execute('schema/set_version', version=pending.version)

                # Keep the known version in sync so upgrade_needed() and
                # later instances sharing this connection see the new value
                self.version = pending.version
                setattr(g, self._G_VERSION, self.version)

    # Tells whether the database needs upgrade
    def upgrade_needed(self) -> bool:
        """True when the schema is older than __database_version__."""
        return self.version < __database_version__

    # Tells whether the database is too far
    def upgrade_too_far(self) -> bool:
        """True when the schema is newer than __database_version__."""
        return self.version > __database_version__

    # Clean the query for debugging
    @staticmethod
    def clean_query(query: str, /) -> str:
        """Collapse a query onto one line and drop ``--`` comments.

        Only meant for log output: a ``--`` inside a string literal is
        also treated as a comment start.
        """
        cleaned: list[str] = []
        for line in query.splitlines():
            # Keep the non-comment side
            line, _, _ = line.partition('--')

            # Clean the non-comment side
            line = line.strip()

            if line:
                cleaned.append(line)

        return ' '.join(cleaned)

    # Delete the database
    @staticmethod
    def delete() -> None:
        """Remove the database file at DATABASE_PATH."""
        os.remove(current_app.config['DATABASE_PATH'])

        # Info
        logger.info('The database has been deleted')

    # Drop the database
    @staticmethod
    def drop() -> None:
        """Run the ``schema/drop`` script."""
        BrickSQL().executescript('schema/drop')

        # Info
        logger.info('The database has been dropped')

    # Initialize the database
    @staticmethod
    def initialize() -> None:
        """Run the ``migrations/init`` script."""
        BrickSQL().executescript('migrations/init')

        # Info
        logger.info('The database has been initialized')

    # Replace the database with a new file
    @staticmethod
    def upload(file: FileStorage, /) -> None:
        """Overwrite the file at DATABASE_PATH with the uploaded file."""
        file.save(current_app.config['DATABASE_PATH'])

        # Info
        logger.info('The database has been imported using file {file}'.format(
            file=file.filename
        ))
# Close all existing SQLite3 connections
def close() -> None:
    """Close the SQLite3 connection stored on flask.g, if there is one.

    The connection is also removed from flask.g. Calling this when no
    connection was opened is a no-op.
    """
    connection: sqlite3.Connection | None = getattr(g, G_CONNECTION, None)

    # Nothing was opened in this context
    if connection is None:
        return

    logger.debug('SQLite3: close')
    connection.close()

    # Forget the closed connection so it is never handed out again
    delattr(g, G_CONNECTION)