import logging
import os
import sqlite3
from typing import Any, Final, Tuple

from flask import current_app, g
from jinja2 import Environment, FileSystemLoader
from werkzeug.datastructures import FileStorage

from .exceptions import DatabaseException
from .sql_counter import BrickCounter
from .sql_migration_list import BrickSQLMigrationList
from .sql_stats import BrickSQLStats
from .version import __database_version__

# Module logger used by BrickSQL and close() for debug/info messages
logger = logging.getLogger(__name__)

# Attribute names under which the database state is cached on flask.g
# Open sqlite3.Connection shared by every BrickSQL of the context
G_CONNECTION: Final[str] = 'database_connection'
# Jinja2 Environment loading the SQL templates from the sql/ directory
G_ENVIRONMENT: Final[str] = 'database_environment'
# List of (query, parameters) tuples waiting for the next commit()
G_DEFER: Final[str] = 'database_defer'
# BrickSQLStats counters shared by every BrickSQL of the context
G_STATS: Final[str] = 'database_stats'


# SQLite3 client with our extra features
class BrickSQL(object):
    """SQLite3 client with our extra features.

    The first instance created in a Flask application context opens the
    connection, checks the schema version, and caches the connection, its
    stats and its version on ``flask.g``. Later instances in the same
    context reuse them. Queries are Jinja2 templates loaded by name from the
    ``sql/`` directory next to this module.
    """

    # flask.g key for the schema version of the cached connection
    _G_VERSION: Final[str] = 'database_version'

    connection: sqlite3.Connection
    cursor: sqlite3.Cursor
    stats: BrickSQLStats
    version: int

    def __init__(self, /, *, failsafe: bool = False):
        """Open the context-wide connection, or reuse the cached one.

        Args:
            failsafe: when True, do not refuse a database whose schema is
                older than the application expects (e.g. so that upgrade()
                can be reached). Only checked when the connection is opened.

        Raises:
            DatabaseException: if the schema version cannot be read, is newer
                than the application supports, or (unless ``failsafe``) is
                older than expected.
        """
        # Instantiate the database connection in the Flask
        # application context so that it can be used by all
        # requests without re-opening connections
        connection = getattr(g, G_CONNECTION, None)

        # Grab the existing connection if it exists
        if connection is not None:
            self.connection = connection
            self.stats = getattr(g, G_STATS, BrickSQLStats())
            self.cursor = self.connection.cursor()

            # Reuse the version read when the connection was opened, so
            # upgrade_needed()/upgrade() also work on this instance
            version = getattr(g, self._G_VERSION, None)
            self.version = (
                version if version is not None else self._read_version()
            )
            return

        # Instantiate the stats
        self.stats = BrickSQLStats()

        # Stats: connect
        self.stats.connect += 1

        logger.debug('SQLite3: connect')
        self.connection = sqlite3.connect(
            current_app.config['DATABASE_PATH']
        )

        # Setup the row factory to get pseudo-dicts rather than tuples
        self.connection.row_factory = sqlite3.Row

        # Grab a cursor
        self.cursor = self.connection.cursor()

        # Grab the version and check it. The connection is not cached on g
        # yet, so close() cannot reach it: release it here on failure
        try:
            self.version = self._read_version()

            if self.upgrade_too_far():
                raise DatabaseException('Your database version ({version}) is too far ahead for this version of the application. Expected at most {required}'.format(  # noqa: E501
                    version=self.version,
                    required=__database_version__,
                ))
        except BaseException:
            self.connection.close()
            raise

        # Debug: Attach the debugger
        # Uncomment manually because this is ultra verbose
        # self.connection.set_trace_callback(print)

        # Save the connection globally for later use
        setattr(g, G_CONNECTION, self.connection)
        setattr(g, G_STATS, self.stats)
        setattr(g, self._G_VERSION, self.version)

        # Refuse an outdated schema unless asked not to. The connection is
        # cached by now, so close() can release it
        if not failsafe and self.upgrade_needed():
            raise DatabaseException('Your database need to be upgraded from version {version} to version {required}'.format(  # noqa: E501
                version=self.version,
                required=__database_version__,
            ))

    # Read the schema version stored in the database
    def _read_version(self, /) -> int:
        """Return the schema version reported by ``schema/get_version``.

        Raises:
            DatabaseException: if the query fails or returns no row.
        """
        try:
            record = self.fetchone('schema/get_version')

            if record is None:
                raise ValueError('version is None')

            return record[0]
        except Exception as e:
            raise DatabaseException('Could not get the database version: {error}'.format(  # noqa: E501
                error=str(e)
            )) from e

    # Clear the defer stack
    def clear_defer(self, /) -> None:
        """Drop every statement queued with defer()."""
        setattr(g, G_DEFER, [])

    # Shorthand to commit
    def commit(self, /) -> None:
        """Run the deferred statements, then commit the transaction.

        If a deferred statement raises, the defer stack is left untouched
        and nothing is committed.
        """
        # Stats: commit
        self.stats.commit += 1

        # Process the deferred stack in insertion order
        for query, parameters in self.get_defer():
            result = self.raw_execute(query, parameters)

            # Stats: changed (execute() could not count deferred rows)
            if result.rowcount > 0:
                self.stats.changed += result.rowcount

        self.clear_defer()

        logger.debug('SQLite3: commit')
        self.connection.commit()

    # Count the database records
    def count_records(self) -> list[BrickCounter]:
        """Return one BrickCounter per table listed by ``schema/tables``."""
        counters: list[BrickCounter] = []

        # Get all tables
        for table in self.fetchall('schema/tables'):
            counter = BrickCounter(table['name'])

            # Failsafe this one: a table that cannot be counted is still
            # listed, just without a count read from the database
            try:
                record = self.fetchone('schema/count', table=counter.table)

                if record is not None:
                    counter.count = record['count']
            except Exception as e:
                logger.debug('SQLite3: could not count {table}: {error}'.format(  # noqa: E501
                    table=counter.table,
                    error=e,
                ))

            counters.append(counter)

        return counters

    # Defer a call to execute
    def defer(self, query: str, parameters: dict[str, Any], /):
        """Queue a rendered query and its parameters until commit()."""
        defer = self.get_defer()

        logger.debug('SQLite3: defer execute')

        # Add the query and parameters to the defer stack
        defer.append((query, parameters))

        # Save the defer stack
        setattr(g, G_DEFER, defer)

    # Shorthand to execute, returning number of affected rows
    def execute(
        self,
        query: str,
        /,
        *,
        parameters: dict[str, Any] | None = None,
        defer: bool = False,
        **context: Any,
    ) -> Tuple[int, str]:
        """Render a named query and run it, or queue it until commit().

        Args:
            query: template name under ``sql/``, without the ``.sql`` suffix.
            parameters: values bound to the statement's placeholders.
            defer: queue the statement for the next commit() instead.
            **context: Jinja2 context used to render the template.

        Returns:
            ``(rowcount, rendered SQL)``. The rowcount is -1 for a deferred
            statement.
        """
        if parameters is None:
            parameters = {}

        # Stats: execute (deferred statements are counted here too)
        self.stats.execute += 1

        # Load the query
        sql = self.load_query(query, **context)

        # Defer
        if defer:
            self.defer(sql, parameters)

            return -1, sql

        result = self.raw_execute(sql, parameters)

        # Stats: changed
        if result.rowcount > 0:
            self.stats.changed += result.rowcount

        return result.rowcount, sql

    # Shorthand to executescript
    def executescript(self, query: str, /, **context: Any) -> None:
        """Render a named template and run it as a multi-statement script."""
        # Load the query
        sql = self.load_query(query, **context)

        # Stats: executescript
        self.stats.executescript += 1

        logger.debug('SQLite3: executescript')
        self.cursor.executescript(sql)

    # Shorthand to execute and commit
    def execute_and_commit(
        self,
        query: str,
        /,
        *,
        parameters: dict[str, Any] | None = None,
        **context: Any,
    ) -> Tuple[int, str]:
        """Execute a named query, then commit (running deferred ones too).

        Returns:
            ``(rowcount, rendered SQL)`` as returned by execute().
        """
        rows, sql = self.execute(query, parameters=parameters, **context)
        self.commit()

        return rows, sql

    # Shorthand to execute and fetchall
    def fetchall(
        self,
        query: str,
        /,
        *,
        parameters: dict[str, Any] | None = None,
        **context: Any,
    ) -> list[sqlite3.Row]:
        """Execute a named query and return every resulting row."""
        self.execute(query, parameters=parameters, **context)

        # Stats: fetchall
        self.stats.fetchall += 1

        logger.debug('SQLite3: fetchall')
        records = self.cursor.fetchall()

        # Stats: fetched (number of rows)
        self.stats.fetched += len(records)

        return records

    # Shorthand to execute and fetchone
    def fetchone(
        self,
        query: str,
        /,
        *,
        parameters: dict[str, Any] | None = None,
        **context: Any,
    ) -> sqlite3.Row | None:
        """Execute a named query and return its first row, or None."""
        self.execute(query, parameters=parameters, **context)

        # Stats: fetchone
        self.stats.fetchone += 1

        logger.debug('SQLite3: fetchone')
        record = self.cursor.fetchone()

        # Stats: fetched counts rows, as in fetchall(); len(record) would
        # count the row's columns instead
        if record is not None:
            self.stats.fetched += 1

        return record

    # Grab the defer stack
    def get_defer(self, /) -> list[Tuple[str, dict[str, Any]]]:
        """Return the queued (query, parameters) list, or a new empty one."""
        defer: list[Tuple[str, dict[str, Any]]] = getattr(g, G_DEFER, [])

        return defer

    # Load a query by name
    def load_query(self, name: str, /, **context: Any) -> str:
        """Render the Jinja2 template ``sql/<name>.sql`` with ``context``."""
        # Grab the existing environment if it exists
        environment = getattr(g, G_ENVIRONMENT, None)

        # Instantiate Jinja environment for SQL files
        if environment is None:
            logger.debug('SQLite3: instantiating the Jinja loader')
            environment = Environment(
                loader=FileSystemLoader(
                    os.path.join(os.path.dirname(__file__), 'sql/')
                )
            )

            # Save the environment globally for later use
            setattr(g, G_ENVIRONMENT, environment)

        # Only build the (possibly large) debug message when it is emitted
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug('SQLite3: loading {name} (context: {context})'.format(
                name=name,
                context=context,
            ))

        # Grab the template
        template = environment.get_template(f'{name}.sql')

        return template.render(**context)

    # Raw execute the query without any options
    def raw_execute(
        self,
        query: str,
        parameters: dict[str, Any],
        /
    ) -> sqlite3.Cursor:
        """Run already-rendered SQL on the shared cursor, without stats."""
        # clean_query() runs on every statement: skip it unless logged
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug('SQLite3: execute: {query}'.format(
                query=BrickSQL.clean_query(query)
            ))

        return self.cursor.execute(query, parameters)

    # Upgrade the database
    def upgrade(self) -> None:
        """Apply every pending migration, in the order they are listed.

        The version is recorded through ``schema/set_version`` after each
        migration. This method does not call commit() itself.
        """
        if not self.upgrade_needed():
            return

        for pending in BrickSQLMigrationList().pending(self.version):
            logger.info('Applying migration {version}'.format(
                version=pending.version
            ))

            self.executescript(pending.get_query())
            self.execute('schema/set_version', version=pending.version)

            # Keep the in-memory version in sync, so upgrade_needed() and
            # later instances sharing this connection see the new schema
            self.version = pending.version
            setattr(g, self._G_VERSION, self.version)

    # Tells whether the database needs upgrade
    def upgrade_needed(self) -> bool:
        """True when the schema is older than the application expects."""
        return self.version < __database_version__

    # Tells whether the database is too far
    def upgrade_too_far(self) -> bool:
        """True when the schema is newer than the application supports."""
        return self.version > __database_version__

    # Clean the query for debugging
    @staticmethod
    def clean_query(query: str, /) -> str:
        """Collapse a SQL query onto one line for debug logging.

        Drops everything after ``--`` on each line (naively: a ``--`` inside
        a string literal is cut too), strips each line and joins the
        non-empty ones with single spaces.
        """
        cleaned: list[str] = []

        for line in query.splitlines():
            # Keep the non-comment side, stripped
            code = line.partition('--')[0].strip()

            if code:
                cleaned.append(code)

        return ' '.join(cleaned)

    # Delete the database
    @staticmethod
    def delete() -> None:
        """Remove the database file at ``DATABASE_PATH``."""
        os.remove(current_app.config['DATABASE_PATH'])

        # Info
        logger.info('The database has been deleted')

    # Drop the database
    @staticmethod
    def drop() -> None:
        """Run the ``schema/drop`` script on the database."""
        BrickSQL().executescript('schema/drop')

        # Info
        logger.info('The database has been dropped')

    # Replace the database with a new file
    @staticmethod
    def upload(file: FileStorage, /) -> None:
        """Overwrite the database file at ``DATABASE_PATH`` with ``file``.

        NOTE(review): the uploaded content is not validated here.
        """
        file.save(current_app.config['DATABASE_PATH'])

        # Info
        logger.info('The database has been imported using file {file}'.format(
            file=file.filename
        ))


# Close the SQLite3 connection cached on flask.g, if any
def close() -> None:
    """Close the connection that BrickSQL cached on flask.g, if there is one.

    Nothing is committed first. Once the attribute is removed from g, the
    next BrickSQL() opens a fresh connection.
    """
    cached: sqlite3.Connection | None = getattr(g, G_CONNECTION, None)

    # Nothing was opened in this context
    if cached is None:
        return

    logger.debug('SQLite3: close')
    cached.close()

    # Forget it only once it has actually been closed
    delattr(g, G_CONNECTION)