forked from FrederikBaerentsen/BrickTracker
232 lines
6.4 KiB
Python
232 lines
6.4 KiB
Python
import logging
from collections.abc import Callable
from typing import Any, Final, Tuple

from flask import copy_current_request_context, Flask, request
from flask_socketio import SocketIO

from .configuration_list import BrickConfigurationList
from .login import LoginManager
from .rebrickable_set import RebrickableSet
from .sql import close as sql_close
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
# Messages valid through the socket.
# Each key is the upper-case identifier used in the code; its value is the
# same identifier lower-cased.
MESSAGES: Final[dict[str, str]] = {
    identifier: identifier.lower()
    for identifier in (
        'ADD_SET',
        'COMPLETE',
        'CONNECT',
        'DISCONNECT',
        'FAIL',
        'IMPORT_SET',
        'LOAD_SET',
        'PROGRESS',
        'SET_LOADED',
    )
}
|
|
|
|
|
|
# Flask socket.io with our extra features
class BrickSocket(object):
    """Wrap a flask_socketio ``SocketIO`` bound to the Flask app.

    On construction the instance registers the connect, disconnect,
    import-set and load-set event handlers on its namespace and stores
    itself in ``app.config['_SOCKET']``.

    The progress fields (message, count, total) live on the instance and
    are sent to the client by ``progress()``.
    """

    app: Flask
    socket: SocketIO
    threaded: bool
    namespace: str

    # Progress
    progress_message: str
    progress_total: int
    progress_count: int

    def __init__(
        self,
        app: Flask,
        *args,
        threaded: bool = True,
        **kwargs
    ) -> None:
        """Create the socket and register its event handlers.

        Args:
            app: Flask application. Its config must provide
                ``SOCKET_NAMESPACE``, ``SOCKET_PATH`` and ``DOMAIN_NAME``
                entries exposing a ``.value`` attribute.
            *args: Extra positional arguments forwarded to ``SocketIO``.
            threaded: When true, set downloads/loads are started through
                ``SocketIO.start_background_task`` instead of running
                inline in the event handler.
            **kwargs: Extra keyword arguments forwarded to ``SocketIO``.
                ``cors_allowed_origins`` is overwritten when a domain
                name is configured.
        """
        # Save the app
        self.app = app

        # Progress
        self.progress_message = ''
        self.progress_count = 0
        self.progress_total = 0

        # Save the threaded flag
        self.threaded = threaded

        # Compute the namespace
        self.namespace = '/{namespace}'.format(
            namespace=app.config['SOCKET_NAMESPACE'].value
        )

        # Inject CORS if a domain is defined
        if app.config['DOMAIN_NAME'].value != '':
            kwargs['cors_allowed_origins'] = app.config['DOMAIN_NAME'].value

        # Instantiate the socket
        self.socket = SocketIO(
            self.app,
            *args,
            **kwargs,
            path=app.config['SOCKET_PATH'].value,
            async_mode='eventlet',
        )

        # Store the socket in the app config
        self.app.config['_SOCKET'] = self

        # Setup the socket
        @self.socket.on(MESSAGES['CONNECT'], namespace=self.namespace)
        def connect() -> None:
            self.connected()

        @self.socket.on(MESSAGES['DISCONNECT'], namespace=self.namespace)
        def disconnect() -> None:
            self.disconnected()

        @self.socket.on(MESSAGES['IMPORT_SET'], namespace=self.namespace)
        def import_set(data: dict[str, Any], /) -> None:
            self._run_set_task(RebrickableSet.download, data)

        @self.socket.on(MESSAGES['LOAD_SET'], namespace=self.namespace)
        def load_set(data: dict[str, Any], /) -> None:
            self._run_set_task(RebrickableSet.load, data)

    # Shared guard + dispatch for the set-related events
    def _run_set_task(
        self,
        task: Callable[[RebrickableSet, dict[str, Any]], Any],
        data: dict[str, Any],
        /,
    ) -> None:
        """Run a ``RebrickableSet`` operation on behalf of a socket event.

        Sends a FAIL message and returns without running anything when
        the client is not authenticated or the Rebrickable API key is
        not set.

        Args:
            task: Unbound ``RebrickableSet`` method to run; it is called
                as ``task(brickset, data)``.
            data: Payload received with the event, passed through as-is.
        """
        # Needs to be authenticated
        if LoginManager.is_not_authenticated():
            self.fail(message='You need to be authenticated')
            return

        # Needs the Rebrickable API key
        try:
            BrickConfigurationList.error_unless_is_set('REBRICKABLE_API_KEY')
        except Exception as e:
            self.fail(message=str(e))
            return

        brickset = RebrickableSet(self)

        if not self.threaded:
            task(brickset, data)
            return

        # Start it in a background task, carrying over the current request
        # context so code running there can still read `request`
        @copy_current_request_context
        def run_in_background() -> None:
            task(brickset, data)

        self.socket.start_background_task(run_in_background)

    # Update the progress auto-incrementing
    def auto_progress(
        self,
        /,
        message: str | None = None,
        increment_total: bool = False,
    ) -> None:
        """Advance the progress count by one and send a progress update.

        Args:
            message: New progress message; ``None`` keeps the last one.
            increment_total: Also grow the total by one.
        """
        # Auto-increment
        self.progress_count += 1

        if increment_total:
            self.progress_total += 1

        self.progress(message=message)

    # Send a complete
    def complete(self, /, **data: Any) -> None:
        """Emit a COMPLETE message carrying ``data`` as its payload."""
        try:
            self.emit('COMPLETE', data)
        finally:
            # Close any dangling connection, even if the emit failed
            sql_close()

    # Socket is connected
    def connected(self, /) -> Tuple[str, int]:
        """Log the client connection.

        Returns:
            An empty string and the integer 301. The connect handler
            registered in ``__init__`` discards this value.
        """
        logger.debug('Socket: client connected')

        return '', 301

    # Socket is disconnected
    def disconnected(self, /) -> None:
        """Log the client disconnection."""
        logger.debug('Socket: client disconnected')

    # Emit a message through the socket
    def emit(self, name: str, *arg, all=False) -> None:
        """Emit the message registered under ``name`` on the namespace.

        Args:
            name: Key of ``MESSAGES`` (e.g. ``'PROGRESS'``); an unknown
                key raises ``KeyError``.
            *arg: Payload forwarded to ``SocketIO.emit``.
            all: When true, emit with no recipient (``to=None``) instead
                of targeting the current request's SID. The name shadows
                the builtin but is kept for caller compatibility.
        """
        # Emit to all sockets unless a request SID can be found below
        to = None

        if not all:
            # Grab the request SID
            # This keeps message isolated between clients (and tabs!)
            try:
                to = request.sid  # type: ignore
            except Exception:
                # No SID available: falls back to the same recipient value
                # as all=True
                logger.debug('Unable to load request.sid')

        logger.debug('Socket: %s=%s (to: %s)', name, arg, to)

        self.socket.emit(
            MESSAGES[name],
            *arg,
            namespace=self.namespace,
            to=to,
        )

    # Send a failed
    def fail(self, /, **data: Any) -> None:
        """Emit a FAIL message carrying ``data`` as its payload."""
        try:
            self.emit('FAIL', data)
        finally:
            # Close any dangling connection, even if the emit failed
            sql_close()

    # Update the progress
    def progress(self, /, message: str | None = None) -> None:
        """Emit a PROGRESS message with the current message/count/total.

        Args:
            message: New progress message; ``None`` re-sends the last one.
        """
        # Save the last message
        if message is not None:
            self.progress_message = message

        # Prepare data
        data: dict[str, Any] = {
            'message': self.progress_message,
            'count': self.progress_count,
            'total': self.progress_total,
        }

        self.emit('PROGRESS', data)

    # Update the progress total only
    def update_total(self, total: int, /, add: bool = False) -> None:
        """Set the progress total, or add to it, without emitting.

        Args:
            total: New total, or the amount to add when ``add`` is true.
            add: Add ``total`` to the current total instead of replacing.
        """
        if add:
            self.progress_total += total
        else:
            self.progress_total = total

    # Update the total
    def total_progress(self, total: int, /, add: bool = False) -> None:
        """Update the progress total, then send a progress update.

        Args:
            total: New total, or the amount to add when ``add`` is true.
            add: Add ``total`` to the current total instead of replacing.
        """
        self.update_total(total, add=add)

        self.progress()
|