You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
393 lines
12 KiB
393 lines
12 KiB
import atexit
|
|
import logging
|
|
import time
|
|
from enum import Enum
|
|
from multiprocessing import Process
|
|
from multiprocessing import Queue
|
|
from struct import error
|
|
from struct import unpack
|
|
from typing import Dict
|
|
from typing import List
|
|
from typing import Optional
|
|
from typing import Sequence
|
|
from typing import Tuple
|
|
|
|
from serial import Serial
|
|
|
|
from backend.monsun_backend.util import log_function_call
|
|
|
|
from . import commands
|
|
from .commands import Command
|
|
from .commands import CommandId
|
|
from .commands import CommandTarget
|
|
from .commands import Request
|
|
from .commands import Response
|
|
from .container import get_initialize_container
|
|
|
|
_logger = logging.getLogger(__file__)

# Role name -> queue feeding that role's worker process. Filled in by
# ``start_backgroup_process`` and read by ``enqueue_command``.
_command_queue: Dict[str, Queue] = {}
|
|
|
|
|
|
class State(Enum):
    """States of the serial command state machine driven by ``enter_fsm``."""

    heart_beat = 0
    executing_command = 1
    executing_command_waiting_for_response = 2
    receiving_command = 16  # 0x10
|
|
|
|
|
|
def worker_process(
    role: str,
    queue: Queue,
):
    """Entry point of the background process that serves a single role.

    Configures logging for this process, initializes the dependency container
    with ``role``, and then loops forever. Each iteration opens the serial
    device provided by the container and runs ``enter_fsm`` on it. An
    ``OSError`` (e.g. the serial device disappearing) is caught. The loop then
    waits ``serial_reconnection_wait_timeout`` seconds and tries again.

    :param role: Role name. It is also used as the logger name and as the key
        into ``CommandTarget``.
    :param queue: Command queue that the parent process created for this role.
    """
    logging.basicConfig(
        level=logging.DEBUG,
        format="[%(asctime)s] [%(name)-20s] [%(levelname)-8s] --- %(message)s",
    )
    role_logger = logging.getLogger(role)
    role_logger.setLevel(logging.DEBUG)
    logger = role_logger.getChild("worker_process")
    logger.setLevel(logging.INFO)

    container = get_initialize_container()
    container.config.from_dict({"role": role})
    heartbeat_interval = container.config.heartbeat_interval()
    reconnect_wait = container.config.serial_reconnection_wait_timeout()

    # Only report "connection lost" if a connection was actually established.
    was_connected = False

    logger.info("entering command loop...")
    while True:
        try:
            with container.serial() as serial:
                logger.info("connected with serial device")
                was_connected = True
                enter_fsm(
                    target=CommandTarget[role],
                    root_logger=role_logger,
                    serial=serial,
                    command_queue=queue,
                    heartbeat_interval=heartbeat_interval,
                )
        except OSError:
            if was_connected:
                logger.warning("connection to serial device lost")
            was_connected = False
            time.sleep(reconnect_wait)
            logger.warning("reconnecting...")
|
|
|
|
|
|
def enter_fsm(
    target: CommandTarget,
    root_logger: logging.Logger,
    serial: Serial,
    command_queue: Queue,
    heartbeat_interval: float,
    *,
    poll_interval: float = 0.001,
):
    """Run the serial command state machine; only leaves by raising.

    The machine cycles through the members of ``State``:

    * ``executing_command``: take the next command from ``command_queue`` and
      execute it on ``serial``.
      - If the command is a ``Request``, wait for its response next.
      - If the queue is empty, switch to ``receiving_command``.
    * ``executing_command_waiting_for_response``: poll ``serial`` until a
      response with the request's ``response_identifier`` arrives or the
      request's ``timeout`` (seconds) elapses. Non-matching responses are
      logged and dropped.
    * ``receiving_command``: poll ``serial`` once, then go to ``heart_beat``.
    * ``heart_beat``: enqueue a ``HeartbeatRequest`` if more than
      ``heartbeat_interval`` seconds passed since the last one, then go back
      to ``executing_command``.

    Every poll puts received commands back onto ``command_queue`` and buffers
    received responses.

    :param target: Target used for heartbeat requests and the serial receiver.
    :param root_logger: Parent logger; this function logs to its ``FSM`` child.
    :param serial: Open serial connection.
    :param command_queue: Queue of commands to execute.
    :param heartbeat_interval: Minimum seconds between two heartbeat requests.
    :param poll_interval: Seconds to sleep after a poll that returned nothing.
        Without it the loop busy-spins a CPU core while idle, because
        ``serial.read(serial.in_waiting)`` never blocks. ``0`` disables the
        back-off.
    :raises RuntimeError: If the machine reaches an inconsistent state.
    """
    logger = root_logger.getChild("FSM")

    state = State.executing_command
    current_command: Optional[Command] = None
    responses_received: List[Response] = list()
    time_at_beginning_waiting_for_response: float = 0.0
    last_heart_beat_time: float = 0.0
    serial_receiver = SerialReceiver(
        root_logger=root_logger,
        target=target,
    )

    def poll_serial() -> bool:
        """Read once from ``serial``, re-queue commands, buffer responses.

        :returns: ``True`` if at least one command or response was received.
        """
        commands_, responses = serial_receiver.receive(serial=serial)
        responses_received.extend(responses)
        for command in commands_:
            command_queue.put(command)
        return bool(commands_) or bool(responses)

    def idle() -> None:
        """Back off briefly so an idle loop does not spin at 100% CPU."""
        if poll_interval > 0:
            time.sleep(poll_interval)

    while True:
        if state == State.heart_beat:
            if time.time() - heartbeat_interval > last_heart_beat_time:
                command_queue.put(
                    commands.HeartbeatRequest(
                        root_logger=root_logger,
                        target=target,
                    ),
                )

                # heartbeat: client -> ble -> server -> ble -> client
                # if target == CommandTarget.client:
                #     command_queue.put(
                #         commands.HeartbeatRequest(
                #             root_logger=root_logger,
                #             target=CommandTarget.server,
                #         ),
                #     )

                last_heart_beat_time = time.time()

            state = State.executing_command

        elif state == State.executing_command:
            current_command = dequeue_command(queue=command_queue)
            if current_command is None:
                state = State.receiving_command
                continue

            current_command.execute(serial=serial)

            if isinstance(current_command, Request):
                time_at_beginning_waiting_for_response = time.time()
                state = State.executing_command_waiting_for_response

        elif state == State.executing_command_waiting_for_response:
            if not isinstance(current_command, Request):
                raise RuntimeError(
                    "entered state 'executing_command_waiting_for_response' but "
                    "current command does not expect a response.",
                )
            request: Request = current_command
            received_anything = poll_serial()

            # Consume buffered responses in arrival order until one matches
            # the pending request; any response before the match is dropped.
            while responses_received:
                received_response: Response = responses_received.pop(0)
                if request.response_identifier == received_response.identifier:
                    request.process_response(
                        response=received_response,
                    )
                    state = State.executing_command
                    break
                logger.warning(
                    f"received response with ID {received_response.identifier} "
                    "but expected response with ID "
                    f"{request.response_identifier}",
                )
            else:
                # No matching response (yet).
                if (
                    time.time() - request.timeout
                    > time_at_beginning_waiting_for_response
                ):
                    logger.error(
                        "Timeout while waiting for response with ID "
                        f"{request.response_identifier}",
                    )
                    current_command = None
                    state = State.executing_command
                elif not received_anything:
                    idle()

        elif state == State.receiving_command:
            if not poll_serial():
                idle()
            state = State.heart_beat

        else:
            raise RuntimeError(f"Invalid state: {state}")
|
|
|
|
|
|
def dequeue_command(
    queue: Queue,
) -> "Optional[Command]":
    """Take the next command from ``queue`` without blocking.

    :param queue: The command queue to read from.
    :returns: The oldest queued command, or ``None`` if the queue is empty.
    """
    # Both ``queue.Queue`` and ``multiprocessing.Queue`` raise ``queue.Empty``
    # from ``get_nowait``. With ``multiprocessing.Queue``, ``get_nowait``
    # polls and reads under one lock, unlike a separate ``empty()`` + blocking
    # ``get()``. This import binds only ``Empty``, so the ``queue`` parameter
    # is not shadowed.
    from queue import Empty

    try:
        return queue.get_nowait()
    except Empty:
        return None
|
|
|
|
|
|
def enqueue_command(
    role: str,
    command: Command,
):
    """Hand a command to the worker process serving ``role``.

    :param role: Name of the role whose command queue receives ``command``.
    :param command: The command for that role's worker to execute.
    :raises KeyError: If no command queue exists for ``role``. Queues are
        created in ``start_backgroup_process``.
    """
    role_queue = _command_queue[role]
    role_queue.put(command)
|
|
|
|
|
|
class CommandInterpretationError(Exception):
    """Raised when a byte sequence cannot be interpreted as a command frame."""


class CommandBytesReadInsufficient(CommandInterpretationError):
    """Raised when too few bytes are buffered to interpret a complete frame.

    Unlike its base class, this is not a framing error. The caller is expected
    to keep the buffered bytes and retry once more data has been read.
    """
|
|
|
|
|
|
class CommandInterpreter:
    """Parse one command frame from the front of a byte buffer.

    ``interpret`` reads frames laid out as follows (multi-byte fields
    big-endian)::

        command id (1 byte) | data length (2 bytes) | 1 byte, ignored |
        payload (data length bytes) | stop byte 0xFF

    After a successful ``interpret`` call, ``command_id_int``, ``data_length``
    and ``payload`` describe the parsed frame.
    """

    header_size = 4

    def __init__(
        self,
        root_logger: logging.Logger,
    ) -> None:
        self._logger = root_logger.getChild(self.__class__.__name__)
        self._logger.setLevel(logging.INFO)
        self.command_id_int = 0
        self.data_length = 0
        self.payload = bytes()

    def interpret(
        self,
        bytes_read: bytes,
    ) -> bytes:
        """Interpret the first command frame in a byte stream.

        :param bytes_read: The bytes which are not yet parsed.
        :returns: The bytes following the parsed frame (not yet parsed). The
            result is a slice of ``bytes_read``, so it has the same type.
        :raises CommandBytesReadInsufficient: Not enough bytes to fully parse
            the frame (header, payload or stop byte missing).
        :raises CommandInterpretationError: The stop byte is not ``0xFF``.
        """
        self._logger.debug(f"bytes: {bytes_read.hex()}")
        try:
            self.command_id_int, self.data_length, _ = unpack(
                ">BHB",
                bytes_read[: self.header_size],
            )
        except error:
            # Fewer than ``header_size`` bytes are buffered. That is a partial
            # frame the caller retries later, not a framing error.
            self._logger.debug("error while interpreting command header")
            raise CommandBytesReadInsufficient()

        payload_end = self.header_size + self.data_length
        # Slicing never raises IndexError, so check the length explicitly.
        # One more byte after the payload is needed for the stop byte.
        if len(bytes_read) <= payload_end:
            self._logger.debug(
                "There are less bytes than expected: "
                f"Expected={payload_end + 1}, "
                f"received={len(bytes_read)}",
            )
            raise CommandBytesReadInsufficient()

        self.payload = bytes(bytes_read[self.header_size : payload_end])

        if bytes_read[payload_end] != 0xFF:
            self._logger.error("Invalid stop byte")
            raise CommandInterpretationError()

        return bytes_read[payload_end + 1 :]
|
|
|
|
|
|
class SerialReceiver:
    """Accumulate bytes from a serial device and decode complete frames.

    Bytes that do not yet form a complete frame are kept in an internal buffer
    and parsed on a later ``receive`` call.
    """

    def __init__(
        self,
        root_logger: logging.Logger,
        target: CommandTarget,
    ) -> None:
        self.root_logger = root_logger
        self.target = target

        self._logger = root_logger.getChild(self.__class__.__name__)
        self._logger.setLevel(logging.INFO)

        # Bytes read from the device that have not been parsed into a frame.
        self._bytes_unread = bytearray()

    def receive(
        self,
        serial: Serial,
    ) -> Tuple[Sequence[Command], Sequence[Response]]:
        """Read all waiting bytes from ``serial`` and decode complete frames.

        Log frames become ``commands.LogCommand`` objects. Heartbeat, LED and
        GP response frames become the matching ``commands.*Response`` objects.
        A frame with an unknown command id is logged and skipped. A malformed
        frame (bad stop byte) causes the buffered bytes to be discarded so
        the stream can resynchronize.

        :param serial: Open serial connection to read from.
        :returns: The commands and the responses decoded during this call.
        :raises RuntimeError: If a frame has a valid ``CommandId`` that this
            receiver does not handle.
        """
        commands_received: List[Command] = list()
        responses_received: List[Response] = list()

        self._bytes_unread.extend(serial.read(serial.in_waiting))

        while self._bytes_unread:
            command_interpreter = CommandInterpreter(
                root_logger=self.root_logger,
            )
            try:
                bytes_remaining = command_interpreter.interpret(
                    bytes_read=self._bytes_unread,
                )
            except CommandBytesReadInsufficient:
                # Partial frame: keep the bytes and finish it on a later call.
                break
            except CommandInterpretationError:
                # Malformed frame. Leaving the bytes in place would make every
                # later call fail on them again and the buffer would never
                # drain, so drop them and resynchronize on fresh data.
                self._logger.warning(
                    f"discarding {len(self._bytes_unread)} unparsable bytes: "
                    f"{self._bytes_unread.hex()}",
                )
                self._bytes_unread.clear()
                break
            self._bytes_unread = bytearray(bytes_remaining)

            try:
                command_id = CommandId(command_interpreter.command_id_int)
            except ValueError:
                # The frame was already consumed above. Skip it instead of
                # falling through with an unbound or stale ``command_id``.
                self._logger.error(
                    f"invalid command {command_interpreter.command_id_int} with "
                    f"payload {str(command_interpreter.payload)}",
                )
                continue

            payload = command_interpreter.payload
            if command_id == CommandId.command_log:
                commands_received.append(
                    commands.LogCommand(
                        root_logger=self.root_logger,
                        target=self.target,
                        data=payload,
                    ),
                )
            elif command_id == CommandId.command_heartbeat_response:
                responses_received.append(commands.HeartbeatResponse(payload))
            elif command_id == CommandId.command_led_response:
                responses_received.append(commands.LEDResponse(payload))
            elif command_id == CommandId.command_gp_response:
                responses_received.append(commands.GPResponse(payload))
            else:
                raise RuntimeError(f"unhandled command id: {command_id}")

        return commands_received, responses_received
|
|
|
|
|
|
# Role name -> background worker process serving that role.
_process: Dict[str, Process] = {}


def _end_running_process():
    """Kill every background worker process recorded in ``_process``.

    Registered with ``atexit`` by ``start_backgroup_process``.
    """
    for role_name in _process:
        _process[role_name].kill()
|
|
|
|
|
|
@log_function_call
def start_backgroup_process():
    """Start one background worker process per configured role.

    For each role in ``container.config.roles()``, create a fresh command
    queue in ``_command_queue`` and start a process running ``worker_process``
    with that role and queue. Started processes are recorded in ``_process``
    and killed at interpreter exit by ``_end_running_process``.
    """
    container = get_initialize_container()

    # Register the cleanup hook once, before anything is started. If starting
    # a later role raises, the processes that already run are still killed at
    # exit instead of keeping the interpreter alive.
    atexit.register(_end_running_process)

    role_name: str
    for role_name in container.config.roles():
        command_queue = Queue()
        _command_queue[role_name] = command_queue

        process = Process(
            target=worker_process,
            args=(
                role_name,
                command_queue,
            ),
        )
        process.start()
        # Record the process only after it started, so the exit hook never
        # calls ``kill()`` on a Process that was never started.
        _process[role_name] = process
|
|
|