You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
433 lines
13 KiB
433 lines
13 KiB
import atexit
|
|
import logging
|
|
import time
|
|
from enum import Enum
|
|
from multiprocessing import Manager
|
|
from multiprocessing import Process
|
|
from multiprocessing import Queue
|
|
from multiprocessing.managers import SyncManager
|
|
from queue import Empty
|
|
from struct import error
|
|
from struct import unpack
|
|
from typing import Dict
|
|
from typing import List
|
|
from typing import Optional
|
|
from typing import Sequence
|
|
from typing import Tuple
|
|
|
|
from serial import Serial
|
|
|
|
from . import commands
|
|
from .commands import Command
|
|
from .commands import CommandId
|
|
from .commands import CommandTarget
|
|
from .commands import Request
|
|
from .commands import Response
|
|
from .commands import get_response_class
|
|
from .container import get_initialize_container
|
|
from .util import log_function_call
|
|
|
|
_logger = logging.getLogger(__name__)

# role name -> queue of commands handed to that role's worker process
_command_queue: Dict[str, Queue] = dict()

# role name -> queue on which the worker forwards responses to waiting callers
_response_queue: Dict[str, Queue] = dict()

# Manager owning the shared identifier lists below; created by
# start_backgroup_process()
_manager: Optional[SyncManager] = None

# role name -> shared list of response identifiers that callers are waiting for
_awaiting_response_identifier_list: Dict[str, List] = dict()
|
|
|
|
|
|
class State(Enum):
    """States of the command state machine driven by ``enter_fsm``."""

    heart_beat = 0
    executing_command = 1
    executing_command_waiting_for_response = 2
    receiving_command = 16
|
|
|
|
|
|
def worker_process(
    role: str,
    commnad_queue: Queue,
    response_queue: Queue,
    awaiting_response_identifier_list: List,
):
    """Entry point of the background process serving one role.

    Sets up logging for this process, loads the container configuration for
    *role* and then runs the command FSM on the serial device forever.
    Whenever an ``OSError`` escapes the FSM, the connection is treated as
    lost: the worker sleeps ``serial_reconnection_wait_timeout`` seconds and
    opens the serial device again.

    :param role: The role name; also used as the ``CommandTarget`` member name.
    :param commnad_queue: Queue the parent process puts commands on.
    :param response_queue: Queue responses to awaited requests are put on.
    :param awaiting_response_identifier_list: Shared list of response
        identifiers a caller is waiting for.
    """
    logging.basicConfig(
        level=logging.DEBUG,
        format="[%(asctime)s] [%(name)-20s] [%(levelname)-8s] --- %(message)s",
    )
    role_logger = logging.getLogger(role)
    role_logger.setLevel(logging.DEBUG)
    log = role_logger.getChild("worker_process")
    log.setLevel(logging.INFO)

    container = get_initialize_container()
    container.config.from_dict({"role": role})
    config = container.config
    heartbeat_interval = config.heartbeat_interval()
    reconnect_delay = config.serial_reconnection_wait_timeout()

    # Only report a lost connection if one had actually been established.
    was_connected = False

    log.info("entering command loop...")
    while True:
        try:
            with container.serial() as serial:
                log.info("connected with serial device")
                was_connected = True
                enter_fsm(
                    target=CommandTarget[role],
                    root_logger=role_logger,
                    serial=serial,
                    command_queue=commnad_queue,
                    response_queue=response_queue,
                    awaiting_response_identifier_list=awaiting_response_identifier_list,
                    heartbeat_interval=heartbeat_interval,
                )
        except OSError:
            if was_connected:
                log.warning("connection to serial device lost")
                was_connected = False
            time.sleep(reconnect_delay)
            log.warning("reconnecting...")
|
|
|
|
|
|
def enter_fsm(
    target: CommandTarget,
    root_logger: logging.Logger,
    serial: Serial,
    command_queue: Queue,
    response_queue: Queue,
    awaiting_response_identifier_list: List,
    heartbeat_interval: float,
):
    """Run the command state machine for one serial connection; never returns.

    The loop cycles through the states of ``State``:

    * ``executing_command``: take the next command off *command_queue* and
      execute it on *serial*. A ``Request`` moves the FSM to
      ``executing_command_waiting_for_response``; an empty queue moves it to
      ``receiving_command``.
    * ``executing_command_waiting_for_response``: read frames from *serial*
      until the response matching the current request arrives or
      ``request.timeout`` seconds have passed.
    * ``receiving_command``: read frames from *serial* without waiting for
      anything in particular, then go to ``heart_beat``.
    * ``heart_beat``: queue a heartbeat request if at least
      *heartbeat_interval* seconds passed since the last one.

    Commands decoded from the serial stream are put on *command_queue* and
    executed like any other command. A response to a request whose identifier
    is listed in *awaiting_response_identifier_list* is forwarded to
    *response_queue* and the identifier is removed from the list. The
    identifier is also removed when the request times out, so the shared list
    does not accumulate stale entries.

    Exceptions raised while talking to the serial device propagate.

    :param target: The role this FSM serves; heartbeats are created for it.
    :param root_logger: Parent logger for the FSM and its helpers.
    :param serial: The open serial device.
    :param command_queue: Queue of commands to execute; heartbeats and
        commands decoded from the serial stream are also put here.
    :param response_queue: Queue responses to awaited requests go to.
    :param awaiting_response_identifier_list: Shared list of response
        identifiers some caller is waiting for.
    :param heartbeat_interval: Minimum number of seconds between heartbeats.
    :raises RuntimeError: If the FSM reaches an inconsistent state.
    """
    logger = root_logger.getChild("FSM")

    state = State.executing_command
    current_command: Optional[Command] = None
    # Responses read from the serial device but not yet matched to a request.
    responses_received: List[Response] = list()
    time_at_beginning_waiting_for_response: float = 0.0
    last_heart_beat_time: float = 0.0
    serial_receiver = SerialReceiver(
        root_logger=root_logger,
        target=target,
    )

    def receive_from_serial() -> None:
        """Read pending frames; queue decoded commands, buffer responses."""
        commands_, responses = serial_receiver.receive(serial=serial)
        responses_received.extend(responses)
        for command in commands_:
            command_queue.put(command)

    def stop_awaiting(response_identifier) -> None:
        """Drop *response_identifier* from the shared awaited list, if listed."""
        if response_identifier in awaiting_response_identifier_list:
            awaiting_response_identifier_list.remove(response_identifier)

    while True:
        if state == State.heart_beat:
            if time.time() - heartbeat_interval > last_heart_beat_time:
                command_queue.put(
                    commands.HeartbeatRequest(
                        root_logger=root_logger,
                        target=target,
                    ),
                )
                # NOTE: a client -> ble -> server -> ble -> client heartbeat
                # (additionally queuing a HeartbeatRequest targeted at
                # CommandTarget.server when target is CommandTarget.client)
                # was sketched here but is disabled.
                last_heart_beat_time = time.time()
            state = State.executing_command

        elif state == State.executing_command:
            current_command = dequeue_command(queue=command_queue)
            if current_command is None:
                state = State.receiving_command
                continue

            current_command.execute(serial=serial)

            if isinstance(current_command, Request):
                time_at_beginning_waiting_for_response = time.time()
                state = State.executing_command_waiting_for_response

        elif state == State.executing_command_waiting_for_response:
            if not isinstance(current_command, Request):
                raise RuntimeError(
                    "entered state 'executing_command_waiting_for_response' but "
                    "current command does not expect a response.",
                )
            request: Request = current_command
            receive_from_serial()

            response_matched = False
            while responses_received:
                received_response: Response = responses_received.pop(0)
                if request.response_identifier == received_response.identifier:
                    request.process_response(
                        response=received_response,
                    )
                    # Forward only if some caller registered interest in it.
                    if request.response_identifier in awaiting_response_identifier_list:
                        response_queue.put(received_response)
                        awaiting_response_identifier_list.remove(
                            request.response_identifier,
                        )
                    response_matched = True
                    state = State.executing_command
                    break
                else:
                    # Responses that do not belong to this request are dropped.
                    logger.warning(
                        f"received response with ID {received_response.identifier} "
                        "but expected response with ID "
                        f"{request.response_identifier}",
                    )

            if (
                not response_matched
                and time.time() - request.timeout
                > time_at_beginning_waiting_for_response
            ):
                logger.error(
                    "Timeout while waiting for response with ID "
                    f"{request.response_identifier}",
                )
                # Without this the identifier stays listed forever and a late
                # or unrelated response with the same ID would be pushed onto
                # the bounded response queue with nobody left to read it.
                stop_awaiting(request.response_identifier)
                current_command = None
                state = State.executing_command

        elif state == State.receiving_command:
            receive_from_serial()
            state = State.heart_beat

        else:
            raise RuntimeError(f"Invalid state: {state}")
|
|
|
|
|
|
def dequeue_command(
    queue: Queue,
) -> "Optional[Command]":
    """Return the next command from *queue* without blocking.

    Uses a single non-blocking ``get`` rather than ``empty()`` followed by a
    blocking ``get()``: the latter is a check-then-act race and
    ``multiprocessing.Queue.empty()`` is documented as unreliable.

    :param queue: The command queue of a worker process.
    :returns: The next queued command, or ``None`` if the queue is empty.
    """
    try:
        return queue.get_nowait()
    except Empty:
        return None
|
|
|
|
|
|
def execute_command(
    role: str,
    command: Command,
    timeout: int = 2,
) -> Optional[Response]:
    """Add a command to the command queue of a role's worker process.

    For a ``Request``, its ``response_identifier`` is registered as awaited
    before the request is queued, so the worker forwards the matching
    response to the role's response queue; this call then blocks for up to
    *timeout* seconds for a response. Other commands are only queued: the
    worker never puts anything on the response queue for them.

    :param role: The role name
    :param command: The command to enqueue
    :param timeout: Seconds to wait for the response to a request.
    :returns: The response taken from the role's response queue, or ``None``
        if *command* is not a request or no response arrived in time.
    """
    is_request = isinstance(command, Request)
    if is_request:
        _awaiting_response_identifier_list[role].append(command.response_identifier)

    _command_queue[role].put(command)

    if not is_request:
        # Waiting here would only stall for `timeout` seconds, or take a
        # response from the shared per-role queue that belongs to another
        # caller.
        return None

    try:
        return _response_queue[role].get(timeout=timeout)
    except Empty:
        return None
|
|
|
|
|
|
class CommandInterpretationError(Exception):
    """Raised in case the command could not be interpreted"""


class CommandBytesReadInsufficient(CommandInterpretationError):
    """Raised when the available bytes end before a complete command frame.

    A frame consists of the header, the payload and the stop byte; if any of
    them is not fully present yet, more bytes are needed before the command
    can be interpreted.
    """


class InvalidStopByteError(CommandInterpretationError):
    """Raised in case of an invalid stop byte."""
|
|
|
|
|
|
class CommandInterpreter:
    """Parse one command frame from the front of a byte buffer.

    As parsed by :meth:`interpret`, a frame is a 4-byte big-endian header
    (struct format ``>BHB``: command id, payload length, one ignored byte),
    followed by ``data_length`` payload bytes and a ``0xFF`` stop byte.
    The parsed fields are stored on the instance.
    """

    header_size = 4

    def __init__(
        self,
        root_logger: logging.Logger,
    ) -> None:
        self._logger = root_logger.getChild(self.__class__.__name__)
        self._logger.setLevel(logging.INFO)
        # Fields of the most recently interpreted frame.
        self.command_id_int = 0
        self.data_length = 0
        self.payload = bytes()

    def interpret(
        self,
        bytes_read: bytes,
    ) -> bytes:
        """Interpret the first command in a byte stream.

        :param bytes_read: The bytes which are not yet parsed.
        :returns: The bytes following the interpreted frame (not yet parsed).
        :raises CommandBytesReadInsufficient: Not enough bytes to fully parse
            the command (header, payload or stop byte missing).
        :raises InvalidStopByteError: The byte after the payload is not 0xFF.
        """
        self._logger.debug(f"bytes: {bytes_read.hex()}")
        try:
            self.command_id_int, self.data_length, _ = unpack(
                ">BHB",
                bytes_read[: self.header_size],
            )
        except error as err:
            # Usually only part of the header has arrived so far; this is an
            # expected situation while streaming, not an error.
            self._logger.debug("not enough bytes to interpret command header")
            raise CommandBytesReadInsufficient() from err

        payload_end = self.header_size + self.data_length
        # The complete frame includes the stop byte right after the payload.
        frame_length = payload_end + 1
        if len(bytes_read) < frame_length:
            self._logger.debug(
                "There are less bytes than expected: "
                f"Expected={frame_length}, "
                f"received={len(bytes_read)}",
            )
            raise CommandBytesReadInsufficient()

        self.payload = bytes(bytes_read[self.header_size : payload_end])

        if bytes_read[payload_end] != 0xFF:
            self._logger.error("Invalid stop byte")
            raise InvalidStopByteError()

        return bytes_read[frame_length:]
|
|
|
|
|
|
class SerialReceiver:
    """Read bytes from a serial device and decode them into frames.

    Bytes are buffered across calls to :meth:`receive`, so a frame split over
    several reads is decoded once it is complete. Frames with command id
    ``CommandId.command_log`` become ``commands.LogCommand`` objects for
    *target*; every other frame is turned into a response through
    ``get_response_class``.
    """

    def __init__(
        self,
        root_logger: logging.Logger,
        target: CommandTarget,
    ) -> None:
        self.root_logger = root_logger
        self.target = target

        self._logger = root_logger.getChild(self.__class__.__name__)
        self._logger.setLevel(logging.INFO)

        # Bytes read from the serial device but not yet decoded into a frame.
        self._bytes_unread = bytearray()

    def _discard_through_next_stop_byte(self) -> None:
        """Drop buffered bytes up to and including the next 0xFF byte.

        Used to resynchronise after a frame with an invalid stop byte; if no
        0xFF is buffered, the whole buffer is dropped.
        """
        stop_index = self._bytes_unread.find(0xFF)
        if stop_index < 0:
            self._bytes_unread.clear()
        else:
            del self._bytes_unread[: stop_index + 1]

    def receive(
        self,
        serial: Serial,
    ) -> Tuple[Sequence[Command], Sequence[Response]]:
        """Read everything waiting on *serial* and decode all complete frames.

        :param serial: The open serial device to read from.
        :returns: The commands and the responses decoded from complete
            frames, each in arrival order. An incomplete trailing frame stays
            buffered for the next call.
        :raises RuntimeError: If a frame has a known command id for which no
            response class is available.
        """
        commands_received: List[Command] = list()
        responses_received: List[Response] = list()

        self._bytes_unread.extend(serial.read(serial.in_waiting))

        while self._bytes_unread:
            command_interpreter = CommandInterpreter(
                root_logger=self.root_logger,
            )
            try:
                bytes_left = command_interpreter.interpret(
                    bytes_read=self._bytes_unread,
                )
            except InvalidStopByteError:
                # Corrupt frame: skip past the next stop byte and retry. The
                # frame must not be decoded below.
                self._discard_through_next_stop_byte()
                continue
            except CommandInterpretationError:
                # Incomplete frame: keep the bytes until more data arrives.
                break
            self._bytes_unread = bytearray(bytes_left)

            try:
                command_id = CommandId(command_interpreter.command_id_int)
            except ValueError:
                self._logger.error(
                    f"invalid command {command_interpreter.command_id_int} with "
                    f"payload {str(command_interpreter.payload)}",
                )
                # The frame is already consumed; continue with the next one
                # instead of using an unbound or stale command id.
                continue

            if command_id == CommandId.command_log:
                commands_received.append(
                    commands.LogCommand(
                        root_logger=self.root_logger,
                        target=self.target,
                        data=command_interpreter.payload,
                    ),
                )
                continue

            try:
                response_class = get_response_class(command_id=command_id)
            except KeyError as err:
                raise RuntimeError(
                    f"no response class available for command {command_id}",
                ) from err

            responses_received.append(
                response_class(command_interpreter.payload),
            )

        return commands_received, responses_received
|
|
|
|
|
|
# role name -> worker process serving that role
_process: Dict[str, Process] = {}
|
|
|
|
|
|
def _end_running_process():
    """Kill every worker process recorded in ``_process``.

    Registered by this module as an ``atexit`` handler.
    """
    workers = _process.values()
    for worker in workers:
        worker.kill()
|
|
|
|
|
|
@log_function_call
def start_backgroup_process():
    """Start one background worker process per configured role.

    For every role in the container's ``roles`` configuration this creates
    the role's command queue, a response queue holding at most 32 entries
    and a manager-backed list of awaited response identifiers. It then starts
    a process running ``worker_process`` for that role. The workers are
    killed at interpreter exit by the ``_end_running_process`` atexit handler.
    """
    # Only `_manager` is rebound; the other module-level dicts are mutated.
    global _manager

    container = get_initialize_container()

    _manager = Manager()

    # Register the cleanup before spawning anything: if starting a later
    # worker fails, the already-started (non-daemonic) workers must still be
    # killed, otherwise the interpreter would wait for them at exit.
    atexit.register(_end_running_process)

    role_name: str
    for role_name in container.config.roles():
        _command_queue[role_name] = Queue()
        _response_queue[role_name] = Queue(maxsize=32)
        _awaiting_response_identifier_list[role_name] = _manager.list()

        _process[role_name] = Process(
            target=worker_process,
            args=(
                role_name,
                _command_queue[role_name],
                _response_queue[role_name],
                _awaiting_response_identifier_list[role_name],
            ),
        )
        _process[role_name].start()
|
|
|