import atexit
import logging
import time
from enum import Enum
from multiprocessing import Process
from multiprocessing import Queue
from queue import Empty
from struct import error
from struct import unpack
from typing import Dict
from typing import List
from typing import Optional
from typing import Sequence
from typing import Tuple

from serial import Serial

from backend.monsun_backend.util import log_function_call

from . import commands
from .commands import Command
from .commands import CommandId
from .commands import Request
from .commands import Response
from .container import get_initialize_container

_logger = logging.getLogger(__name__)

_command_queue: Dict[str, Queue] = dict()
"""role name: command queue"""


class State(Enum):
    """States of the command state machine driven by :func:`enter_fsm`."""

    # Enqueue a heartbeat request if the heartbeat interval has elapsed.
    heart_beat = 0x0
    # Dequeue and execute the next queued command, if any.
    executing_command = 0x1
    # Wait for the response matching the currently executing request.
    executing_command_waiting_for_response = 0x2
    # Poll the serial port for incoming commands and responses.
    receiving_command = 0x10


def worker_process(
    role: str,
    queue: Queue,
):
    """Entry point of the background process serving one role.

    Configures logging and loads the container configuration for ``role``.
    It then runs :func:`enter_fsm` on the serial device. When an
    :class:`OSError` escapes (the connection is treated as lost), it waits
    ``serial_reconnection_wait_timeout`` seconds and reconnects. This function
    never returns.

    :param role: The role name; stored in the container config as ``"role"``.
    :param queue: The command queue consumed by this process.
    """
    logging.basicConfig(
        level=logging.DEBUG,
        format="[%(asctime)s] [%(name)-20s] [%(levelname)-8s] --- %(message)s",
    )
    root_logger = logging.getLogger(role)
    root_logger.setLevel(logging.DEBUG)
    logger = root_logger.getChild("worker_process")
    logger.setLevel(logging.INFO)

    container = get_initialize_container()
    container.config.from_dict(
        {
            "role": role,
        },
    )
    heartbeat_interval = container.config.heartbeat_interval()
    serial_reconnection_wait_timeout = (
        container.config.serial_reconnection_wait_timeout()
    )

    # Only log "connection lost" once per established connection.
    connected = False
    logger.info("entering command loop...")
    while True:
        try:
            with container.serial() as serial:
                logger.info("connected with serial device")
                connected = True
                enter_fsm(
                    root_logger=root_logger,
                    serial=serial,
                    command_queue=queue,
                    heartbeat_interval=heartbeat_interval,
                )
        except OSError:
            if connected:
                logger.warning("connection to serial device lost")
            connected = False
            time.sleep(serial_reconnection_wait_timeout)
            logger.warning("reconnecting...")


def enter_fsm(
    root_logger: logging.Logger,
    serial: Serial,
    command_queue: Queue,
    heartbeat_interval: float,
):
    """Run the command state machine on an open serial connection.

    The machine cycles ``executing_command`` -> ``receiving_command`` ->
    ``heart_beat`` -> ``executing_command``. After executing a
    :class:`Request`, it instead enters
    ``executing_command_waiting_for_response``. It stays there until a
    response with the matching identifier arrives or ``request.timeout``
    seconds have elapsed. Commands received from the device are put onto
    ``command_queue``, so they are later executed like locally enqueued
    commands.

    The loop has no normal exit; it only ends by raising, e.g. an
    :class:`OSError` from the serial port, which :func:`worker_process`
    handles.

    :param root_logger: Parent logger for the loggers created here.
    :param serial: The open serial connection.
    :param command_queue: Queue of commands to execute.
    :param heartbeat_interval: Minimum seconds between two heartbeat requests.
    :raises RuntimeError: If the machine reaches an inconsistent state.
    """
    logger = root_logger.getChild("FSM")

    state = State.executing_command
    current_command: Optional[Command] = None
    # Responses read from the serial port but not yet matched to a request.
    responses_received: List[Response] = list()
    time_at_beginning_waiting_for_response: float = 0.0
    last_heart_beat_time: float = 0.0

    serial_receiver = SerialReceiver(
        root_logger=root_logger,
    )

    # NOTE(review): there is no sleep between iterations, so this loop appears
    # to busy-poll the queue and the serial port.
    while True:
        if state == State.heart_beat:
            if time.time() - heartbeat_interval > last_heart_beat_time:
                command_queue.put(
                    commands.HeartbeatRequest(
                        root_logger=root_logger,
                    ),
                )
                last_heart_beat_time = time.time()
            state = State.executing_command
            continue

        elif state == State.executing_command:
            current_command = dequeue_command(queue=command_queue)
            if current_command is None:
                state = State.receiving_command
                continue

            current_command.execute(serial=serial)
            if isinstance(current_command, Request):
                time_at_beginning_waiting_for_response = time.time()
                state = State.executing_command_waiting_for_response
            continue

        elif state == State.executing_command_waiting_for_response:
            if not isinstance(current_command, Request):
                raise RuntimeError(
                    "entered state 'executing_command_waiting_for_response' but "
                    "current command does not expect a response.",
                )
            else:
                request: Request = current_command
                commands_, responses = serial_receiver.receive(
                    serial=serial,
                )
                responses_received.extend(responses)
                for command in commands_:
                    command_queue.put(command)

                # Responses that do not match the pending request are
                # discarded with a warning.
                while responses_received:
                    received_response: Response = responses_received.pop(0)
                    if request.response_identifier == received_response.identifier:
                        request.process_response(
                            response=received_response,
                        )
                        state = State.executing_command
                        break
                    else:
                        logger.warning(
                            f"received response with ID {received_response.identifier} "
                            "but expected response with ID "
                            f"{request.response_identifier}",
                        )
                else:
                    # No matching response yet (the loop did not ``break``).
                    if (
                        time.time() - request.timeout
                        > time_at_beginning_waiting_for_response
                    ):
                        logger.error(
                            "Timeout while waiting for response with ID "
                            f"{request.response_identifier}",
                        )
                        current_command = None
                        state = State.executing_command
                continue

        elif state == State.receiving_command:
            commands_, responses = serial_receiver.receive(
                serial=serial,
            )
            responses_received.extend(responses)
            for command in commands_:
                command_queue.put(command)
            state = State.heart_beat
            continue

        else:
            raise RuntimeError(f"Invalid state: {state}")


def dequeue_command(
    queue: Queue,
) -> Optional[Command]:
    """Return the next queued command without blocking.

    :param queue: The command queue.
    :returns: The next command, or ``None`` if the queue is currently empty.
    """
    # EAFP: ``Queue.empty()`` is documented as unreliable, so try the
    # non-blocking get directly instead of checking first.
    try:
        return queue.get_nowait()
    except Empty:
        return None


def enqueue_command(
    role: str,
    command: Command,
):
    """Add a command to the command queue

    :param role: The role name
    :param command: The command to enqueue
    """
    _command_queue[role].put(command)


class CommandInterpretationError(Exception):
    """Raised in case the command could not be interpreted"""


class CommandBytesReadInsufficient(CommandInterpretationError):
    """Raised when the bytes read do not yet contain a complete command frame"""


class CommandInterpreter:
    """Parse a single command frame from the beginning of a byte buffer.

    A frame consists of:

    * a 4-byte header, unpacked as ``">BHB"``: command ID, big-endian payload
      length, and one further byte that is ignored here;
    * ``data_length`` payload bytes;
    * a ``0xFF`` stop byte.

    After a successful :meth:`interpret`, ``command_id_int``, ``data_length``
    and ``payload`` hold the parsed values.
    """

    header_size = 4

    def __init__(
        self,
        root_logger: logging.Logger,
    ) -> None:
        self._logger = root_logger.getChild(self.__class__.__name__)
        self._logger.setLevel(logging.INFO)

        self.command_id_int = 0
        self.data_length = 0
        self.payload = bytes()

    def interpret(
        self,
        bytes_read: bytes,
    ) -> bytes:
        """Interpret the first command in a byte stream.

        :param bytes_read: The bytes which are not yet parsed.
        :returns: The bytes following the parsed frame, which are not yet
            parsed.
        :raises CommandInterpretationError: If the stop byte is invalid.
        :raises CommandBytesReadInsufficient: Not enough bytes to fully parse
            the command.
        """
        self._logger.debug(f"bytes: {bytes_read.hex()}")
        try:
            self.command_id_int, self.data_length, _ = unpack(
                ">BHB",
                bytes_read[: self.header_size],
            )
        except error:
            # A partial header is expected while a frame is still arriving,
            # so this is not logged as an error.
            self._logger.debug("error while interpreting command header")
            raise CommandBytesReadInsufficient()

        payload_end = self.header_size + self.data_length
        # The stop byte sits at index ``payload_end``, so a complete frame needs
        # ``payload_end + 1`` bytes. Slicing never raises IndexError, so the
        # length must be checked explicitly.
        if len(bytes_read) <= payload_end:
            self._logger.debug(
                "There are less bytes than expected: "
                f"Expected={payload_end + 1}, "
                f"received={len(bytes_read)}",
            )
            raise CommandBytesReadInsufficient()

        self.payload = bytes(bytes_read[self.header_size : payload_end])

        if bytes_read[payload_end] != 0xFF:
            self._logger.error("Invalid stop byte")
            raise CommandInterpretationError()

        return bytes(bytes_read[payload_end + 1 :])


class SerialReceiver:
    """Read from a serial port and split the byte stream into frames.

    Bytes of an incomplete trailing frame are buffered between calls to
    :meth:`receive`.
    """

    def __init__(
        self,
        root_logger: logging.Logger,
    ) -> None:
        self.root_logger = root_logger
        self._logger = root_logger.getChild(self.__class__.__name__)
        self._logger.setLevel(logging.INFO)

        # Bytes read from the serial port that have not been parsed yet.
        self._bytes_unread = bytearray()

    def receive(
        self,
        serial: Serial,
    ) -> Tuple[Sequence[Command], Sequence[Response]]:
        """Read all pending serial bytes and parse every complete frame.

        Frames with an unknown command ID are logged and skipped. If the
        buffer does not start with a valid frame, one byte is discarded and
        parsing retries, so the stream can resynchronise.

        :param serial: The serial port to read from.
        :returns: The received commands and the received responses, each in
            arrival order.
        :raises RuntimeError: If a frame has a known command ID for which
            this receiver has no handler.
        """
        commands_received: List[Command] = list()
        responses_received: List[Response] = list()

        self._bytes_unread.extend(serial.read(serial.in_waiting))
        while self._bytes_unread:
            command_interpreter = CommandInterpreter(
                root_logger=self.root_logger,
            )
            try:
                remaining = command_interpreter.interpret(
                    bytes_read=self._bytes_unread,
                )
            except CommandBytesReadInsufficient:
                # Keep the partial frame buffered until more bytes arrive.
                break
            except CommandInterpretationError:
                # Returning with the buffer unchanged would fail on the same
                # bytes on every future call; drop one byte and resynchronise.
                self._logger.warning(
                    f"discarding byte {self._bytes_unread[0]:#04x} "
                    "to resynchronise",
                )
                del self._bytes_unread[0]
                continue
            self._bytes_unread = bytearray(remaining)

            try:
                command_id = CommandId(command_interpreter.command_id_int)
            except ValueError:
                self._logger.error(
                    f"invalid command {command_interpreter.command_id_int} with "
                    f"payload {str(command_interpreter.payload)}",
                )
                # Skip the frame: ``command_id`` would otherwise be unbound or
                # still hold the previous frame's ID.
                continue

            if command_id == CommandId.command_log:
                commands_received.append(
                    commands.LogCommand(
                        root_logger=self.root_logger,
                        data=command_interpreter.payload,
                    ),
                )
            elif command_id == CommandId.command_heartbeat_response:
                responses_received.append(
                    commands.HeartbeatResponse(command_interpreter.payload),
                )
            elif command_id == CommandId.command_led_response:
                responses_received.append(
                    commands.LEDResponse(command_interpreter.payload),
                )
            elif command_id == CommandId.command_gp_response:
                responses_received.append(
                    commands.GPResponse(command_interpreter.payload),
                )
            else:
                raise RuntimeError(f"no handler for received command {command_id}")

        return commands_received, responses_received


_process: Dict[str, Process] = dict()
"""role name: process"""


def _end_running_process():
    """Kill every background worker process that was started."""
    process: Process
    for process in _process.values():
        process.kill()


@log_function_call
def start_backgroup_process():
    """Start one background worker process per configured role.

    For each role in ``container.config.roles()``, creates a command queue
    (reachable through :func:`enqueue_command`) and starts a
    :func:`worker_process` on it. All started processes are killed at
    interpreter exit.
    """
    container = get_initialize_container()

    role_name: str
    for role_name in container.config.roles():
        command_queue = Queue()
        _command_queue[role_name] = command_queue
        _process[role_name] = Process(
            target=worker_process,
            args=(
                role_name,
                command_queue,
            ),
        )
        _process[role_name].start()

    # A single handler kills every process, so register it once rather than
    # once per role.
    atexit.register(_end_running_process)