BrickTracker/bricktracker/socket.py
2025-01-17 11:03:00 +01:00

232 lines
6.4 KiB
Python

import logging
from typing import Any, Callable, Final, Tuple

from flask import copy_current_request_context, Flask, request
from flask_socketio import SocketIO

from .configuration_list import BrickConfigurationList
from .login import LoginManager
from .rebrickable_set import RebrickableSet
from .sql import close as sql_close
logger = logging.getLogger(__name__)

# Socket event names, keyed by the symbolic name the code passes around
# (BrickSocket.emit looks the event up as MESSAGES[name], e.g. 'PROGRESS').
MESSAGES: Final[dict[str, str]] = dict(
    ADD_SET='add_set',
    COMPLETE='complete',
    CONNECT='connect',
    DISCONNECT='disconnect',
    FAIL='fail',
    IMPORT_SET='import_set',
    LOAD_SET='load_set',
    PROGRESS='progress',
    SET_LOADED='set_loaded',
)
# Flask socket.io with our extra features
class BrickSocket(object):
    """Flask-SocketIO wrapper that runs Rebrickable set operations.

    On construction it creates the ``SocketIO`` server for ``app``, stores
    itself in ``app.config['_SOCKET']`` and registers handlers for the
    connect, disconnect, import-set and load-set events on the configured
    namespace. It also tracks a progress message/count/total that is sent
    to clients as ``PROGRESS`` messages.

    NOTE(review): the progress counters live on this single instance and
    every registered handler closes over it, so all connected clients share
    them; two concurrent operations would interleave their counts.
    """

    app: Flask
    socket: SocketIO
    threaded: bool
    namespace: str

    # Progress
    progress_message: str
    progress_total: int
    progress_count: int

    def __init__(
        self,
        app: Flask,
        *args,
        threaded: bool = True,
        **kwargs
    ):
        """Create the socket server and register the event handlers.

        Args:
            app: The Flask application. Its config is read for the
                ``SOCKET_NAMESPACE``, ``SOCKET_PATH`` and ``DOMAIN_NAME``
                entries (each accessed through ``.value``).
            *args: Forwarded positionally to ``SocketIO``.
            threaded: When True, set import/load work runs through
                ``SocketIO.start_background_task`` instead of inside the
                event handler.
            **kwargs: Forwarded to ``SocketIO``. Must not contain ``path``
                or ``async_mode``: those are always passed explicitly, and a
                duplicate keyword raises ``TypeError``.
        """
        # Save the app
        self.app = app

        # Progress
        self.progress_message = ''
        self.progress_count = 0
        self.progress_total = 0

        # Save the threaded flag
        self.threaded = threaded

        # Compute the namespace
        self.namespace = '/{namespace}'.format(
            namespace=app.config['SOCKET_NAMESPACE'].value
        )

        # Inject CORS if a domain is defined
        if app.config['DOMAIN_NAME'].value != '':
            kwargs['cors_allowed_origins'] = app.config['DOMAIN_NAME'].value

        # Instantiate the socket
        self.socket = SocketIO(
            self.app,
            *args,
            **kwargs,
            path=app.config['SOCKET_PATH'].value,
            async_mode='eventlet',
        )

        # Store the socket in the app config
        self.app.config['_SOCKET'] = self

        # Setup the socket
        @self.socket.on(MESSAGES['CONNECT'], namespace=self.namespace)
        def connect() -> None:
            self.connected()

        @self.socket.on(MESSAGES['DISCONNECT'], namespace=self.namespace)
        def disconnect() -> None:
            self.disconnected()

        @self.socket.on(MESSAGES['IMPORT_SET'], namespace=self.namespace)
        def import_set(data: dict[str, Any], /) -> None:
            if not self._can_use_rebrickable():
                return

            brickset = RebrickableSet(self)
            self._run_task(lambda: brickset.download(data))

        @self.socket.on(MESSAGES['LOAD_SET'], namespace=self.namespace)
        def load_set(data: dict[str, Any], /) -> None:
            if not self._can_use_rebrickable():
                return

            brickset = RebrickableSet(self)
            self._run_task(lambda: brickset.load(data))

    def _can_use_rebrickable(self, /) -> bool:
        """Check whether the current client may run a Rebrickable operation.

        Returns False after sending a ``FAIL`` message when the client is not
        authenticated, or when ``BrickConfigurationList.error_unless_is_set``
        raises for ``REBRICKABLE_API_KEY``. Returns True otherwise.
        """
        # Needs to be authenticated
        if LoginManager.is_not_authenticated():
            self.fail(message='You need to be authenticated')
            return False

        # Needs the Rebrickable API key
        try:
            BrickConfigurationList.error_unless_is_set('REBRICKABLE_API_KEY')
        except Exception as e:
            # Forward whatever the configuration check reported to the client
            self.fail(message=str(e))
            return False

        return True

    def _run_task(self, task: Callable[[], None], /) -> None:
        """Execute ``task`` inline, or as a background task when threaded.

        Must be called while a request context is active (i.e. from inside
        a socket event handler). In threaded mode that context is copied into
        the task so ``request.sid``, which ``emit`` reads, still targets the
        calling client.
        """
        if self.threaded:
            self.socket.start_background_task(
                copy_current_request_context(task)
            )
        else:
            task()

    def auto_progress(
        self,
        /,
        message: str | None = None,
        increment_total: bool = False,
    ) -> None:
        """Increment the progress count, then emit the progress.

        Args:
            message: New progress message; None keeps the previous one.
            increment_total: Also increment the progress total by one.
        """
        # Auto-increment
        self.progress_count += 1

        if increment_total:
            self.progress_total += 1

        self.progress(message=message)

    def complete(self, /, **data: Any) -> None:
        """Send a ``COMPLETE`` message carrying ``data`` as its payload."""
        self.emit('COMPLETE', data)

        # Close any dangling connection
        sql_close()

    def connected(self, /) -> Tuple[str, int]:
        """Log a client connection.

        NOTE(review): the registered ``connect`` handler discards this
        return value, so the ``('', 301)`` tuple currently has no effect.
        """
        logger.debug('Socket: client connected')

        return '', 301

    def disconnected(self, /) -> None:
        """Log a client disconnection."""
        logger.debug('Socket: client disconnected')

    def emit(self, name: str, *arg, all: bool = False) -> None:
        """Send the ``MESSAGES[name]`` event with ``arg`` as its payload.

        Args:
            name: Key into ``MESSAGES`` (e.g. ``'PROGRESS'``).
            *arg: Payload forwarded to ``SocketIO.emit``.
            all: When True, send with ``to=None`` to reach every client on
                the namespace. Otherwise target the sid of the current
                request. If that sid cannot be read, fall back to
                ``to=None`` as well.
        """
        # Emit to all sockets
        if all:
            to = None
        else:
            # Grab the request SID
            # This keeps message isolated between clients (and tabs!)
            try:
                to = request.sid  # type: ignore
            except Exception:
                # Best effort: with no usable request (presumably when called
                # outside a socket request context) send to everyone instead
                logger.debug('Unable to load request.sid')
                to = None

        # Lazy %-formatting: only rendered when DEBUG is enabled
        logger.debug('Socket: %s=%s (to: %s)', name, arg, to)

        self.socket.emit(
            MESSAGES[name],
            *arg,
            namespace=self.namespace,
            to=to,
        )

    def fail(self, /, **data: Any) -> None:
        """Send a ``FAIL`` message carrying ``data`` as its payload."""
        self.emit('FAIL', data)

        # Close any dangling connection
        sql_close()

    def progress(self, /, message: str | None = None) -> None:
        """Emit the current progress, optionally replacing its message.

        Args:
            message: New progress message; None reuses the last one saved.
        """
        # Save the last message
        if message is not None:
            self.progress_message = message

        # Prepare data
        data: dict[str, Any] = {
            'message': self.progress_message,
            'count': self.progress_count,
            'total': self.progress_total,
        }

        self.emit('PROGRESS', data)

    def update_total(self, total: int, /, add: bool = False) -> None:
        """Set the progress total (or add to it when ``add``) without emitting."""
        if add:
            self.progress_total += total
        else:
            self.progress_total = total

    def total_progress(self, total: int, /, add: bool = False) -> None:
        """Update the progress total like ``update_total``, then emit progress."""
        self.update_total(total, add=add)

        self.progress()