You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 

436 lines
14 KiB

import atexit
import logging
import time
from enum import Enum
from multiprocessing import Manager
from multiprocessing import Process
from multiprocessing import Queue
from multiprocessing.managers import SyncManager
from queue import Empty
from struct import error
from struct import unpack
from typing import Dict
from typing import List
from typing import Optional
from typing import Sequence
from typing import Tuple
from serial import Serial
from . import commands
from .commands import Command
from .commands import CommandId
from .commands import CommandTarget
from .commands import Request
from .commands import Response
from .commands import get_response_class
from .container import get_initialize_container
from .util import log_function_call
# Module-level logger (named after this file's path via ``__file__``).
_logger = logging.getLogger(__file__)
# Per-role queue of commands waiting to be executed by that role's worker
# process; populated in start_backgroup_process().
_command_queue: Dict[str, Queue] = dict()
"""role name: command queue"""
# Per-role queue on which the worker hands responses back to callers of
# execute_command().
_response_queue: Dict[str, Queue] = dict()
"""role name: response queue"""
# Manager created in start_backgroup_process(); it provides the shared lists
# stored in _awaiting_response_identifier_list.
_manager: Optional[SyncManager] = None
# Per-role list of response identifiers that execute_command() callers are
# waiting for. The worker only forwards a response to the response queue when
# its identifier is present in this list.
_awaiting_response_identifier_list: Dict[str, List] = dict()
"""role name: request queue"""
class State(Enum):
    """States of the state machine driven by ``enter_fsm``."""

    # Possibly enqueue a heartbeat request, then go execute commands.
    heart_beat = 0
    # Take the next command from the command queue and execute it.
    executing_command = 1
    # A request was sent; poll the serial line for its response.
    executing_command_waiting_for_response = 2
    # Nothing to execute; read whatever the device sent on its own.
    receiving_command = 16
def worker_process(
    role: str,
    commnad_queue: Queue,
    response_queue: Queue,
    awaiting_response_identifier_list: List,
):
    """Run the command loop for one role; used as the target of a ``Process``.

    Configures logging, resolves the dependency container for ``role`` and
    then keeps (re)connecting to the serial device forever, handing each
    open connection to ``enter_fsm``.

    :param role: The role name; also used as logger name and as key into
        ``CommandTarget``.
    :param commnad_queue: Queue of commands to execute on the device.
    :param response_queue: Queue on which awaited responses are returned.
    :param awaiting_response_identifier_list: Shared list of response
        identifiers somebody is waiting for.
    """
    logging.basicConfig(
        format="[%(asctime)s] [%(name)-20s] [%(levelname)-8s] --- %(message)s",
        level=logging.DEBUG,
    )
    role_logger = logging.getLogger(role)
    role_logger.setLevel(logging.DEBUG)
    log = role_logger.getChild("worker_process")
    log.setLevel(logging.INFO)

    container = get_initialize_container()
    container.config.from_dict({"role": role})
    heartbeat_interval = container.config.heartbeat_interval()
    reconnect_delay = container.config.serial_reconnection_wait_timeout()

    # Tracks whether the previous attempt had a live connection, so that the
    # "connection lost" warning is logged once per disconnect.
    is_connected = False
    log.info("entering command loop...")
    while True:
        try:
            with container.serial() as device:
                log.info("connected with serial device")
                is_connected = True
                enter_fsm(
                    target=CommandTarget[role],
                    root_logger=role_logger,
                    serial=device,
                    command_queue=commnad_queue,
                    response_queue=response_queue,
                    awaiting_response_identifier_list=(
                        awaiting_response_identifier_list
                    ),
                    heartbeat_interval=heartbeat_interval,
                )
        except OSError:
            if is_connected:
                log.warning("connection to serial device lost")
                is_connected = False
            # Back off before the next connection attempt.
            time.sleep(reconnect_delay)
            log.warning("reconnecting...")
        except Exception as unexpected:
            # Anything else is logged and the loop simply starts over.
            log.exception(unexpected)
            log.error("Unexpected exception happened, recovering...")
def enter_fsm(
    target: CommandTarget,
    root_logger: logging.Logger,
    serial: Serial,
    command_queue: Queue,
    response_queue: Queue,
    awaiting_response_identifier_list: List,
    heartbeat_interval: float,
):
    """Run the command state machine on an open serial connection.

    The loop has no exit of its own; it ends only when an exception
    propagates (raised here or by one of the called objects). While idle the
    states cycle ``executing_command`` -> ``receiving_command`` ->
    ``heart_beat`` -> ``executing_command``. Executing a ``Request`` detours
    through ``executing_command_waiting_for_response`` until the matching
    response arrives or the request times out.

    :param target: Target passed on to heartbeat requests and the receiver.
    :param root_logger: Logger from which child loggers are derived.
    :param serial: The open serial connection.
    :param command_queue: Commands to execute; received commands and
        heartbeat requests are put on it as well.
    :param response_queue: Receives responses whose identifier is listed in
        ``awaiting_response_identifier_list``.
    :param awaiting_response_identifier_list: Identifiers of responses that
        should be forwarded to ``response_queue``.
    :param heartbeat_interval: Minimum time between two heartbeat requests,
        compared against ``time.time()`` differences.
    :raises RuntimeError: If the machine reaches an inconsistent state.
    """
    logger = root_logger.getChild("FSM")
    state = State.executing_command
    current_command: Optional[Command] = None
    # Responses read from the device that have not been matched yet.
    responses_received: List[Response] = list()
    time_at_beginning_waiting_for_response: float = 0.0
    last_heart_beat_time: float = 0.0
    serial_receiver = SerialReceiver(
        root_logger=root_logger,
        target=target,
    )
    while True:
        if state == State.heart_beat:
            # Enqueue a heartbeat only if the interval has elapsed since the
            # last one; either way continue with command execution.
            if time.time() - heartbeat_interval > last_heart_beat_time:
                command_queue.put(
                    commands.HeartbeatRequest(
                        root_logger=root_logger,
                        target=target,
                    ),
                )
                # heartbeat: client -> ble -> server -> ble -> client
                # if target == CommandTarget.client:
                # command_queue.put(
                # commands.HeartbeatRequest(
                # root_logger=root_logger,
                # target=CommandTarget.server,
                # ),
                # )
                last_heart_beat_time = time.time()
            state = State.executing_command
            continue
        elif state == State.executing_command:
            current_command = dequeue_command(queue=command_queue)
            if current_command is None:
                # Nothing queued: go read from the device instead.
                state = State.receiving_command
                continue
            current_command.execute(serial=serial)
            if isinstance(current_command, Request):
                # Requests expect an answer; remember when waiting started so
                # the timeout can be checked.
                time_at_beginning_waiting_for_response = time.time()
                state = State.executing_command_waiting_for_response
            continue
        elif state == State.executing_command_waiting_for_response:
            if not isinstance(current_command, Request):
                raise RuntimeError(
                    "entered state 'executing_command_waiting_for_response' but "
                    "current command does not expect a response.",
                )
            else:
                request: Request = current_command
            commands_, responses = serial_receiver.receive(
                serial=serial,
            )
            responses_received.extend(responses)
            # Commands sent by the device are queued for later execution.
            for command in commands_:
                command_queue.put(command)
            while responses_received:
                received_response: Response = responses_received.pop(0)
                if request.response_identifier == received_response.identifier:
                    request.process_response(
                        response=received_response,
                    )
                    # Forward the response only if a caller registered
                    # interest in this identifier.
                    if request.response_identifier in awaiting_response_identifier_list:
                        response_queue.put(received_response)
                        awaiting_response_identifier_list.remove(
                            request.response_identifier,
                        )
                    state = State.executing_command
                    break
                else:
                    # Non-matching responses are logged and discarded.
                    logger.warning(
                        f"received response with ID {received_response.identifier} "
                        "but expected response with ID "
                        f"{request.response_identifier}",
                    )
            else:
                # Reached only when no matching response was found (the loop
                # above did not ``break``): give up once the request's
                # timeout has elapsed, otherwise poll again.
                if (
                    time.time() - request.timeout
                    > time_at_beginning_waiting_for_response
                ):
                    logger.error(
                        "Timeout while waiting for response with ID "
                        f"{request.response_identifier}",
                    )
                    current_command = None
                    state = State.executing_command
            continue
        elif state == State.receiving_command:
            commands_, responses = serial_receiver.receive(
                serial=serial,
            )
            # Responses are kept for the next waiting phase.
            responses_received.extend(responses)
            for command in commands_:
                command_queue.put(command)
            state = State.heart_beat
            continue
        else:
            raise RuntimeError(f"Invalid state: {state}")
def dequeue_command(
    queue: Queue,
) -> Optional[Command]:
    """Return the next command from ``queue`` without blocking.

    A single non-blocking ``get_nowait()`` is used instead of checking
    ``empty()`` first: ``empty()`` gives no guarantee that a following
    ``get()`` will not block, so the check-then-get form could stall the
    caller if the queue was drained in between.

    :param queue: The queue to take a command from.
    :returns: The dequeued command, or ``None`` if the queue is empty.
    """
    try:
        return queue.get_nowait()
    except Empty:
        return None
def execute_command(
    role: str,
    command: Command,
    timeout: int = 2,
) -> Optional[Response]:
    """Enqueue a command for a role and wait for an item on its response queue.

    If the command is a ``Request``, its response identifier is registered
    first so that the worker forwards the matching response.

    :param role: The role name
    :param command: The command to enqueue
    :param timeout: Timeout handed to the response queue's ``get``
    :returns: The item taken from the role's response queue, or ``None`` if
        nothing arrived before the timeout.
    """
    expects_response = isinstance(command, Request)
    if expects_response:
        pending_identifiers = _awaiting_response_identifier_list[role]
        pending_identifiers.append(command.response_identifier)

    _command_queue[role].put(command)

    answers = _response_queue[role]
    try:
        answer = answers.get(timeout=timeout)
    except Empty:
        answer = None
    return answer
class CommandInterpretationError(Exception):
    """Raised in case the command could not be interpreted.

    Base class of the more specific interpretation errors in this module.
    """
class CommandBytesReadInsufficient(CommandInterpretationError):
    """Raised in case there are not enough bytes to fully parse the command."""
class InvalidStopByteError(CommandInterpretationError):
    """Raised in case of an invalid stop byte (anything other than 0xFF)."""
class CommandInterpreter:
    """Parse one framed command from the head of a byte buffer.

    A frame, as read by :meth:`interpret`, consists of a 4-byte big-endian
    header (``>BHB``: command id, payload length and one byte that is not
    used here), the payload, and a terminating stop byte ``0xFF``.

    After a successful :meth:`interpret` call the parsed values are available
    as ``command_id_int``, ``data_length`` and ``payload``.
    """

    header_size = 4

    def __init__(
        self,
        root_logger: logging.Logger,
    ) -> None:
        """Create an interpreter logging to a child of ``root_logger``."""
        self._logger = root_logger.getChild(self.__class__.__name__)
        self._logger.setLevel(logging.INFO)
        self.command_id_int = 0
        self.data_length = 0
        self.payload = bytes()

    def interpret(
        self,
        bytes_read: bytes,
    ) -> bytes:
        """Interpret the first command in a byte stream.

        :param bytes_read: The bytes which are not yet parsed.
        :returns: The bytes which are not yet parsed (everything after the
            stop byte of the first command).
        :raises CommandBytesReadInsufficient: Not enough bytes to fully parse
            the command.
        :raises InvalidStopByteError: The byte after the payload is not
            ``0xFF``.
        """
        self._logger.debug(f"bytes: {bytes_read.hex()}")
        try:
            self.command_id_int, self.data_length, _ = unpack(
                ">BHB",
                bytes_read[: self.header_size],
            )
        except error as exc:
            self._logger.error("error while interpreting command header")
            raise CommandBytesReadInsufficient() from exc

        # Index of the stop byte, i.e. one past the last payload byte.
        payload_end = self.header_size + self.data_length
        self.payload = bytes(bytes_read[self.header_size : payload_end])

        # Slicing never raises on short input, so the frame length has to be
        # checked explicitly: the stop byte must be present in the buffer.
        if len(bytes_read) <= payload_end:
            self._logger.debug("could not get stop byte")
            raise CommandBytesReadInsufficient()

        if bytes_read[payload_end] != 0xFF:
            self._logger.error("Invalid stop byte")
            raise InvalidStopByteError()

        return bytes_read[payload_end + 1 :]
class SerialReceiver:
    """Read framed data from a serial connection and turn it into objects.

    Bytes that do not yet form a complete frame are kept between calls to
    :meth:`receive`.
    """

    def __init__(
        self,
        root_logger: logging.Logger,
        target: CommandTarget,
    ) -> None:
        """Create a receiver.

        :param root_logger: Logger from which child loggers are derived.
        :param target: Target passed on to commands created from frames.
        """
        self.root_logger = root_logger
        self.target = target
        self._logger = root_logger.getChild(self.__class__.__name__)
        self._logger.setLevel(logging.INFO)
        self._bytes_unread = bytearray()

    def _drop_through_next_stop_byte(self) -> None:
        """Discard buffered bytes up to and including the next ``0xFF``.

        If the buffer holds no ``0xFF`` at all, it is emptied.
        """
        stop_index = self._bytes_unread.find(0xFF)
        if stop_index < 0:
            self._bytes_unread.clear()
        else:
            del self._bytes_unread[: stop_index + 1]

    def receive(
        self,
        serial: Serial,
    ) -> Tuple[Sequence[Command], Sequence[Response]]:
        """Read all waiting bytes and parse as many complete frames as possible.

        :param serial: The serial connection to read from.
        :returns: The commands and the responses parsed from the buffer.
        :raises RuntimeError: If no response class is known for a received
            command id.
        """
        commands_received: List[Command] = []
        responses_received: List[Response] = []
        self._bytes_unread.extend(serial.read(serial.in_waiting))
        while self._bytes_unread:
            command_interpreter = CommandInterpreter(
                root_logger=self.root_logger,
            )
            try:
                self._bytes_unread = bytearray(
                    command_interpreter.interpret(
                        bytes_read=self._bytes_unread,
                    ),
                )
            except InvalidStopByteError:
                # The frame is corrupt: resynchronise on the next stop byte
                # and do NOT dispatch the values parsed from the bad frame.
                self._drop_through_next_stop_byte()
                continue
            except CommandInterpretationError:
                # Incomplete frame: keep the bytes for the next call.
                return commands_received, responses_received

            try:
                command_id = CommandId(command_interpreter.command_id_int)
            except ValueError:
                self._logger.error(
                    f"invalid command {command_interpreter.command_id_int} with "
                    f"payload {str(command_interpreter.payload)}",
                )
                # Skip the unknown frame; there is no valid command id to
                # dispatch on.
                continue

            if command_id == CommandId.command_log:
                commands_received.append(
                    commands.LogCommand(
                        root_logger=self.root_logger,
                        target=self.target,
                        data=command_interpreter.payload,
                    ),
                )
                continue

            try:
                response_class = get_response_class(command_id=command_id)
            except KeyError as exc:
                raise RuntimeError(
                    f"no response class registered for command {command_id}",
                ) from exc
            responses_received.append(
                response_class(command_interpreter.payload),
            )
        return commands_received, responses_received
# Worker processes started by start_backgroup_process(), keyed by role name;
# _end_running_process() kills every entry.
_process: Dict[str, Process] = dict()
"""role name: process"""
def _end_running_process():
    """Kill every worker process registered in ``_process``."""
    for role_name in _process:
        worker: Process = _process[role_name]
        worker.kill()
@log_function_call
def start_backgroup_process():
    """Start one worker process per configured role.

    For every role name returned by ``container.config.roles()`` this creates
    the command queue, the response queue (bounded to 32 entries) and a
    manager-backed list of awaited response identifiers, stores them in the
    module-level registries and starts a ``Process`` running
    ``worker_process`` with them. ``_end_running_process`` is registered with
    ``atexit`` so the workers are killed at interpreter exit.
    """
    global _process
    global _command_queue
    global _manager
    container = get_initialize_container()
    # The manager provides list objects that can be shared with the workers.
    _manager = Manager()
    role_name: str
    for role_name in container.config.roles():
        _command_queue[role_name] = Queue()
        _response_queue[role_name] = Queue(maxsize=32)
        _awaiting_response_identifier_list[role_name] = _manager.list()
        _process[role_name] = Process(
            target=worker_process,
            args=(
                role_name,
                _command_queue[role_name],
                _response_queue[role_name],
                _awaiting_response_identifier_list[role_name],
            ),
        )
        _process[role_name].start()
    # NOTE(review): the original indentation was lost; this registration is
    # assumed to happen once, after all workers were started - confirm.
    atexit.register(_end_running_process)