From 901add5218b3e05a7e0e18f57a91efd78416ded8 Mon Sep 17 00:00:00 2001 From: Gregoo Date: Mon, 27 Jan 2025 14:15:07 +0100 Subject: [PATCH] Provide decorator for socket actions, for repetitive tasks like checking if authenticated or ready for Rebrickable actions --- bricktracker/socket.py | 74 +++---------------------- bricktracker/socket_decorator.py | 93 ++++++++++++++++++++++++++++++++ 2 files changed, 101 insertions(+), 66 deletions(-) create mode 100644 bricktracker/socket_decorator.py diff --git a/bricktracker/socket.py b/bricktracker/socket.py index c7215ae..3fbd6cd 100644 --- a/bricktracker/socket.py +++ b/bricktracker/socket.py @@ -1,14 +1,13 @@ import logging from typing import Any, Final, Tuple -from flask import copy_current_request_context, Flask, request +from flask import Flask, request from flask_socketio import SocketIO -from .configuration_list import BrickConfigurationList from .instructions import BrickInstructions from .instructions_list import BrickInstructionsList -from .login import LoginManager from .set import BrickSet +from .socket_decorator import authenticated_socket, rebrickable_socket from .sql import close as sql_close logger = logging.getLogger(__name__) @@ -87,12 +86,8 @@ class BrickSocket(object): self.disconnected() @self.socket.on(MESSAGES['DOWNLOAD_INSTRUCTIONS'], namespace=self.namespace) # noqa: E501 + @authenticated_socket(self) def download_instructions(data: dict[str, Any], /) -> None: - # Needs to be authenticated - if LoginManager.is_not_authenticated(): - self.fail(message='You need to be authenticated') - return - instructions = BrickInstructions( '{name}.pdf'.format(name=data.get('alt', '')), socket=self @@ -107,71 +102,18 @@ class BrickSocket(object): except Exception: pass - # Start it in a thread if requested - if self.threaded: - @copy_current_request_context - def do_download() -> None: - instructions.download(path) + instructions.download(path) - BrickInstructionsList(force=True) - - self.socket.start_background_task(do_download) - else: - instructions.download(path) - - BrickInstructionsList(force=True) + BrickInstructionsList(force=True) @self.socket.on(MESSAGES['IMPORT_SET'], namespace=self.namespace) + @rebrickable_socket(self) def import_set(data: dict[str, Any], /) -> None: - # Needs to be authenticated - if LoginManager.is_not_authenticated(): - self.fail(message='You need to be authenticated') - return - - # Needs the Rebrickable API key - try: - BrickConfigurationList.error_unless_is_set('REBRICKABLE_API_KEY') # noqa: E501 - except Exception as e: - self.fail(message=str(e)) - return - - brickset = BrickSet(socket=self) - - # Start it in a thread if requested - if self.threaded: - @copy_current_request_context - def do_download() -> None: - brickset.download(data) - - self.socket.start_background_task(do_download) - else: - brickset.download(data) + BrickSet(socket=self).download(data) @self.socket.on(MESSAGES['LOAD_SET'], namespace=self.namespace) def load_set(data: dict[str, Any], /) -> None: - # Needs to be authenticated - if LoginManager.is_not_authenticated(): - self.fail(message='You need to be authenticated') - return - - # Needs the Rebrickable API key - try: - BrickConfigurationList.error_unless_is_set('REBRICKABLE_API_KEY') # noqa: E501 - except Exception as e: - self.fail(message=str(e)) - return - - brickset = BrickSet(socket=self) - - # Start it in a thread if requested - if self.threaded: - @copy_current_request_context - def do_load() -> None: - brickset.load(data) - - self.socket.start_background_task(do_load) - else: - brickset.load(data) + BrickSet(socket=self).load(data) # Update the progress auto-incrementing def auto_progress( diff --git a/bricktracker/socket_decorator.py b/bricktracker/socket_decorator.py new file mode 100644 index 0000000..331b457 --- /dev/null +++ b/bricktracker/socket_decorator.py @@ -0,0 +1,93 @@ +from functools import wraps +from threading import Thread +from typing import Callable, ParamSpec, TYPE_CHECKING, Union + +from flask import copy_current_request_context + +from .configuration_list import BrickConfigurationList +from .login import LoginManager +if TYPE_CHECKING: + from .socket import BrickSocket + +# What a threaded function can return (None or Thread) +SocketReturn = Union[None, Thread] + +# Threaded signature (*arg, **kwargs -> (None or Thread) +P = ParamSpec('P') +SocketCallable = Callable[P, SocketReturn] + + +# Fail if not authenticated +def authenticated_socket( + self: 'BrickSocket', + /, + *, + threaded: bool = True, +) -> Callable[[SocketCallable], SocketCallable]: + def outer(function: SocketCallable, /) -> SocketCallable: + @wraps(function) + def wrapper(*args, **kwargs) -> SocketReturn: + # Needs to be authenticated + if LoginManager.is_not_authenticated(): + self.fail(message='You need to be authenticated') + return + + # Apply threading + if threaded: + return threaded_socket(self)(function)(*args, **kwargs) + else: + return function(*args, **kwargs) + + return wrapper + return outer + + +# Fail if not ready for Rebrickable (authenticated, API key) +# Automatically makes it threaded +def rebrickable_socket( + self: 'BrickSocket', + /, + *, + threaded: bool = True, +) -> Callable[[SocketCallable], SocketCallable]: + def outer(function: SocketCallable, /) -> SocketCallable: + @wraps(function) + # Automatically authenticated + @authenticated_socket(self, threaded=False) + def wrapper(*args, **kwargs) -> SocketReturn: + # Needs the Rebrickable API key + try: + BrickConfigurationList.error_unless_is_set('REBRICKABLE_API_KEY') # noqa: E501 + except Exception as e: + self.fail(message=str(e)) + return + + # Apply threading + if threaded: + return threaded_socket(self)(function)(*args, **kwargs) + else: + return function(*args, **kwargs) + + return wrapper + return outer + + +# Start the function in a thread if the socket is threaded +def threaded_socket( + self: 'BrickSocket', + / +) -> Callable[[SocketCallable], SocketCallable]: + def outer(function: SocketCallable, /) -> SocketCallable: + @wraps(function) + def wrapper(*args, **kwargs) -> SocketReturn: + # Start it in a thread if requested + if self.threaded: + @copy_current_request_context + def do_function() -> None: + function(*args, **kwargs) + + return self.socket.start_background_task(do_function) + else: + return function(*args, **kwargs) + return wrapper + return outer