forked from FrederikBaerentsen/BrickTracker
Massive rewrite
This commit is contained in:
231
bricktracker/socket.py
Normal file
231
bricktracker/socket.py
Normal file
@@ -0,0 +1,231 @@
|
||||
import logging
|
||||
from typing import Any, Final, Tuple
|
||||
|
||||
from flask import copy_current_request_context, Flask, request
|
||||
from flask_socketio import SocketIO
|
||||
|
||||
from .configuration_list import BrickConfigurationList
|
||||
from .login import LoginManager
|
||||
from .rebrickable_set import RebrickableSet
|
||||
from .sql import close as sql_close
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Messages valid through the socket
|
||||
MESSAGES: Final[dict[str, str]] = {
|
||||
'ADD_SET': 'add_set',
|
||||
'COMPLETE': 'complete',
|
||||
'CONNECT': 'connect',
|
||||
'DISCONNECT': 'disconnect',
|
||||
'FAIL': 'fail',
|
||||
'IMPORT_SET': 'import_set',
|
||||
'LOAD_SET': 'load_set',
|
||||
'PROGRESS': 'progress',
|
||||
'SET_LOADED': 'set_loaded',
|
||||
}
|
||||
|
||||
|
||||
# Flask socket.io with our extra features
|
||||
class BrickSocket(object):
|
||||
app: Flask
|
||||
socket: SocketIO
|
||||
threaded: bool
|
||||
|
||||
# Progress
|
||||
progress_message: str
|
||||
progress_total: int
|
||||
progress_count: int
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
app: Flask,
|
||||
*args,
|
||||
threaded: bool = True,
|
||||
**kwargs
|
||||
):
|
||||
# Save the app
|
||||
self.app = app
|
||||
|
||||
# Progress
|
||||
self.progress_message = ''
|
||||
self.progress_count = 0
|
||||
self.progress_total = 0
|
||||
|
||||
# Save the threaded flag
|
||||
self.threaded = threaded
|
||||
|
||||
# Compute the namespace
|
||||
self.namespace = '/{namespace}'.format(
|
||||
namespace=app.config['SOCKET_NAMESPACE'].value
|
||||
)
|
||||
|
||||
# Inject CORS if a domain is defined
|
||||
if app.config['DOMAIN_NAME'].value != '':
|
||||
kwargs['cors_allowed_origins'] = app.config['DOMAIN_NAME'].value
|
||||
|
||||
# Instantiate the socket
|
||||
self.socket = SocketIO(
|
||||
self.app,
|
||||
*args,
|
||||
**kwargs,
|
||||
path=app.config['SOCKET_PATH'].value,
|
||||
async_mode='eventlet',
|
||||
)
|
||||
|
||||
# Store the socket in the app config
|
||||
self.app.config['_SOCKET'] = self
|
||||
|
||||
# Setup the socket
|
||||
@self.socket.on(MESSAGES['CONNECT'], namespace=self.namespace)
|
||||
def connect() -> None:
|
||||
self.connected()
|
||||
|
||||
@self.socket.on(MESSAGES['DISCONNECT'], namespace=self.namespace)
|
||||
def disconnect() -> None:
|
||||
self.disconnected()
|
||||
|
||||
@self.socket.on(MESSAGES['IMPORT_SET'], namespace=self.namespace)
|
||||
def import_set(data: dict[str, Any], /) -> None:
|
||||
# Needs to be authenticated
|
||||
if LoginManager.is_not_authenticated():
|
||||
self.fail(message='You need to be authenticated')
|
||||
return
|
||||
|
||||
# Needs the Rebrickable API key
|
||||
try:
|
||||
BrickConfigurationList.error_unless_is_set('REBRICKABLE_API_KEY') # noqa: E501
|
||||
except Exception as e:
|
||||
self.fail(message=str(e))
|
||||
return
|
||||
|
||||
brickset = RebrickableSet(self)
|
||||
|
||||
# Start it in a thread if requested
|
||||
if self.threaded:
|
||||
@copy_current_request_context
|
||||
def do_download() -> None:
|
||||
brickset.download(data)
|
||||
|
||||
self.socket.start_background_task(do_download)
|
||||
else:
|
||||
brickset.download(data)
|
||||
|
||||
@self.socket.on(MESSAGES['LOAD_SET'], namespace=self.namespace)
|
||||
def load_set(data: dict[str, Any], /) -> None:
|
||||
# Needs to be authenticated
|
||||
if LoginManager.is_not_authenticated():
|
||||
self.fail(message='You need to be authenticated')
|
||||
return
|
||||
|
||||
# Needs the Rebrickable API key
|
||||
try:
|
||||
BrickConfigurationList.error_unless_is_set('REBRICKABLE_API_KEY') # noqa: E501
|
||||
except Exception as e:
|
||||
self.fail(message=str(e))
|
||||
return
|
||||
|
||||
brickset = RebrickableSet(self)
|
||||
|
||||
# Start it in a thread if requested
|
||||
if self.threaded:
|
||||
@copy_current_request_context
|
||||
def do_load() -> None:
|
||||
brickset.load(data)
|
||||
|
||||
self.socket.start_background_task(do_load)
|
||||
else:
|
||||
brickset.load(data)
|
||||
|
||||
# Update the progress auto-incrementing
|
||||
def auto_progress(
|
||||
self,
|
||||
/,
|
||||
message: str | None = None,
|
||||
increment_total=False,
|
||||
) -> None:
|
||||
# Auto-increment
|
||||
self.progress_count += 1
|
||||
|
||||
if increment_total:
|
||||
self.progress_total += 1
|
||||
|
||||
self.progress(message=message)
|
||||
|
||||
# Send a complete
|
||||
def complete(self, /, **data: Any) -> None:
|
||||
self.emit('COMPLETE', data)
|
||||
|
||||
# Close any dangling connection
|
||||
sql_close()
|
||||
|
||||
# Socket is connected
|
||||
def connected(self, /) -> Tuple[str, int]:
|
||||
logger.debug('Socket: client connected')
|
||||
|
||||
return '', 301
|
||||
|
||||
# Socket is disconnected
|
||||
def disconnected(self, /) -> None:
|
||||
logger.debug('Socket: client disconnected')
|
||||
|
||||
# Emit a message through the socket
|
||||
def emit(self, name: str, *arg, all=False) -> None:
|
||||
# Emit to all sockets
|
||||
if all:
|
||||
to = None
|
||||
else:
|
||||
# Grab the request SID
|
||||
# This keeps message isolated between clients (and tabs!)
|
||||
try:
|
||||
to = request.sid # type: ignore
|
||||
except Exception:
|
||||
logger.debug('Unable to load request.sid')
|
||||
to = None
|
||||
|
||||
logger.debug('Socket: {name}={args} (to: {to})'.format(
|
||||
name=name,
|
||||
args=arg,
|
||||
to=to,
|
||||
))
|
||||
|
||||
self.socket.emit(
|
||||
MESSAGES[name],
|
||||
*arg,
|
||||
namespace=self.namespace,
|
||||
to=to,
|
||||
)
|
||||
|
||||
# Send a failed
|
||||
def fail(self, /, **data: Any) -> None:
|
||||
self.emit('FAIL', data)
|
||||
|
||||
# Close any dangling connection
|
||||
sql_close()
|
||||
|
||||
# Update the progress
|
||||
def progress(self, /, message: str | None = None) -> None:
|
||||
# Save the las message
|
||||
if message is not None:
|
||||
self.progress_message = message
|
||||
|
||||
# Prepare data
|
||||
data: dict[str, Any] = {
|
||||
'message': self.progress_message,
|
||||
'count': self.progress_count,
|
||||
'total': self.progress_total,
|
||||
}
|
||||
|
||||
self.emit('PROGRESS', data)
|
||||
|
||||
# Update the progress total only
|
||||
def update_total(self, total: int, /, add: bool = False) -> None:
|
||||
if add:
|
||||
self.progress_total += total
|
||||
else:
|
||||
self.progress_total = total
|
||||
|
||||
# Update the total
|
||||
def total_progress(self, total: int, /, add: bool = False) -> None:
|
||||
self.update_total(total, add=add)
|
||||
|
||||
self.progress()
|
||||
Reference in New Issue
Block a user