""" Communication protocol for laser control module. Encodes commands to bytes and decodes device responses. Faithful re-implementation of the logic in device_commands.py, refactored into a clean, testable class-based API. """ import struct from typing import Optional from enum import IntEnum from datetime import datetime import serial import serial.tools.list_ports from .constants import ( BAUDRATE, SERIAL_TIMEOUT_SEC, GET_DATA_TOTAL_LENGTH, SEND_PARAMS_TOTAL_LENGTH, TASK_ENABLE_COMMAND_LENGTH, CMD_DECODE_ENABLE, CMD_DEFAULT_ENABLE, CMD_TRANS_ENABLE, CMD_REMOVE_FILE, CMD_STATE, CMD_TASK_ENABLE, STATE_DESCRIPTIONS, STATE_OK, ) from .conversions import ( temp_c_to_n, temp_n_to_c, temp_ext_n_to_c, current_ma_to_n, current_n_to_ma, voltage_3v3_n_to_v, voltage_5v_n_to_v, voltage_7v_n_to_v, ) from .models import Measurements, VariationType from .exceptions import ( CommunicationError, PortNotFoundError, CRCError, ProtocolError, ) # Re-export enums so tests can import from protocol module class CommandCode(IntEnum): DECODE_ENABLE = CMD_DECODE_ENABLE DEFAULT_ENABLE = CMD_DEFAULT_ENABLE TRANS_ENABLE = CMD_TRANS_ENABLE REMOVE_FILE = CMD_REMOVE_FILE STATE = CMD_STATE TASK_ENABLE = CMD_TASK_ENABLE class TaskType(IntEnum): MANUAL = 0x00 CHANGE_CURRENT_LD1 = 0x01 CHANGE_CURRENT_LD2 = 0x02 CHANGE_TEMPERATURE_LD1 = 0x03 CHANGE_TEMPERATURE_LD2 = 0x04 class DeviceState(IntEnum): IDLE = 0x0000 RUNNING = 0x0001 BUSY = 0x0002 ERROR = 0x00FF ERROR_OVERHEAT = 0x0100 ERROR_POWER = 0x0200 ERROR_COMMUNICATION = 0x0400 ERROR_INVALID_COMMAND = 0x0800 # ---- Low-level helpers -------------------------------------------------- def _int_to_hex4(value: int) -> str: """Return 4-character lowercase hex string (0–65535).""" if value < 0 or value > 65535: raise ValueError(f"Value {value} out of uint16 range [0, 65535]") return f"{value:04x}" def _flipfour(s: str) -> str: """Swap two byte-pairs: 'aabb' → 'bbaa' (little-endian word).""" if len(s) != 4: raise ValueError(f"Expected 4-char hex string, got '{s}'") return s[2:4] + s[0:2] def _xor_crc(words: list) -> str: """XOR all 16-bit hex words and return 4-char hex CRC.""" result = int(words[0], 16) for w in words[1:]: result ^= int(w, 16) return _int_to_hex4(result) def _build_crc(data_hex: str) -> str: """Calculate XOR CRC over words 1..N of a hex string (skip word 0).""" words = [data_hex[i:i+4] for i in range(0, len(data_hex), 4)] return _xor_crc(words[1:]) def _encode_setup() -> str: """Build the 16-bit setup word (all subsystems enabled, SD save off).""" bits = ['0'] * 16 bits[15] = '1' # enable work bits[14] = '1' # enable 5v1 bits[13] = '1' # enable 5v2 bits[12] = '1' # enable LD1 bits[11] = '1' # enable LD2 bits[10] = '1' # enable REF1 bits[9] = '1' # enable REF2 bits[8] = '1' # enable TEC1 bits[7] = '1' # enable TEC2 bits[6] = '1' # enable temp stab 1 bits[5] = '1' # enable temp stab 2 bits[4] = '0' # enable sd save (disabled) bits[3] = '1' # enable PI1 coef read bits[2] = '1' # enable PI2 coef read bits[1] = '0' # reserved bits[0] = '0' # reserved return f"{int(''.join(bits), 2):04x}" # ---- Response dataclass -------------------------------------------------- class Response: """Decoded device DATA response.""" __slots__ = [ 'current1', 'current2', 'temp1', 'temp2', 'temp_ext1', 'temp_ext2', 'voltage_3v3', 'voltage_5v1', 'voltage_5v2', 'voltage_7v0', 'to6_lsb', 'to6_msb', 'message_id', 'header', ] def to_measurements(self) -> Measurements: return Measurements( current1=self.current1, current2=self.current2, temp1=self.temp1, temp2=self.temp2, temp_ext1=self.temp_ext1, temp_ext2=self.temp_ext2, voltage_3v3=self.voltage_3v3, voltage_5v1=self.voltage_5v1, voltage_5v2=self.voltage_5v2, voltage_7v0=self.voltage_7v0, timestamp=datetime.now(), message_id=self.message_id, to6_counter_lsb=self.to6_lsb, to6_counter_msb=self.to6_msb, ) # ---- Message builder -------------------------------------------------- class Message: """Named container for an encoded command byte array.""" def __init__(self, data: bytearray): self._data = data def to_bytes(self) -> bytes: return bytes(self._data) def __len__(self): return len(self._data) # ---- Protocol class -------------------------------------------------- class Protocol: """ Encodes commands and decodes responses for the laser control board. Can also manage a serial port connection when port is provided. """ def __init__(self, port: Optional[str] = None): self._port_name = port self._serial: Optional[serial.Serial] = None # ---- Connection management def connect(self) -> None: """Open the serial port. Auto-detects if port is None.""" port = self._port_name or self._detect_port() try: self._serial = serial.Serial( port=port, baudrate=BAUDRATE, timeout=SERIAL_TIMEOUT_SEC, ) except Exception as exc: raise CommunicationError( f"Cannot connect to port '{port}': {exc}" ) from exc def disconnect(self) -> None: """Close the serial port if open.""" if self._serial and self._serial.is_open: self._serial.close() @property def is_connected(self) -> bool: return self._serial is not None and self._serial.is_open def _detect_port(self) -> str: """Return first available serial port device path.""" ports = list(serial.tools.list_ports.comports()) if not ports: raise PortNotFoundError() return ports[0].device # ---- Raw I/O def send_raw(self, data: bytes) -> None: self._serial.write(data) def receive_raw(self, length: int) -> bytes: return self._serial.read(length) # ---- Static encoding helpers (no connection required) --------------- @staticmethod def flipfour(value: int) -> int: """Byte-swap a 16-bit integer (little-endian word swap).""" return ((value & 0xFF) << 8) | ((value >> 8) & 0xFF) @staticmethod def pack_float(value: float) -> bytes: return struct.pack(' bytes: return struct.pack(' int: """ XOR CRC over all 16-bit words except the last two bytes (CRC field). Mirrors the original CalculateCRC logic. """ hex_str = data.hex() words = [hex_str[i:i+4] for i in range(0, len(hex_str), 4)] # Skip word 0 (command code) per original firmware expectation crc_words = words[1:] result = int(crc_words[0], 16) for w in crc_words[1:]: result ^= int(w, 16) return result # ---- Command encoders ----------------------------------------------- @staticmethod def encode_decode_enable( temp1: float, temp2: float, current1: float, current2: float, pi_coeff1_p: int, pi_coeff1_i: int, pi_coeff2_p: int, pi_coeff2_i: int, message_id: int, ) -> bytes: """ Build DECODE_ENABLE command (0x1111). Sets temperature and current setpoints for both lasers. Returns 30-byte bytearray. """ if current1 < 0 or current2 < 0: raise ValueError("Current values must not be negative") data = _flipfour(_int_to_hex4(CMD_DECODE_ENABLE)) # Word 0 data += _flipfour(_encode_setup()) # Word 1 data += _flipfour(_int_to_hex4(temp_c_to_n(temp1))) # Word 2 data += _flipfour(_int_to_hex4(temp_c_to_n(temp2))) # Word 3 data += _flipfour('0000') * 3 # Words 4-6 data += _flipfour(_int_to_hex4(pi_coeff1_p)) # Word 7 data += _flipfour(_int_to_hex4(pi_coeff1_i)) # Word 8 data += _flipfour(_int_to_hex4(pi_coeff2_p)) # Word 9 data += _flipfour(_int_to_hex4(pi_coeff2_i)) # Word 10 data += _flipfour(_int_to_hex4(message_id & 0xFFFF)) # Word 11 data += _flipfour(_int_to_hex4(current_ma_to_n(current1))) # Word 12 data += _flipfour(_int_to_hex4(current_ma_to_n(current2))) # Word 13 data += _build_crc(data) # Word 14 result = bytearray.fromhex(data) assert len(result) == SEND_PARAMS_TOTAL_LENGTH, \ f"DECODE_ENABLE length mismatch: {len(result)}" return bytes(result) @staticmethod def encode_task_enable( task_type: TaskType, static_temp1: float, static_temp2: float, static_current1: float, static_current2: float, min_value: float, max_value: float, step: float, time_step: int, delay_time: int, message_id: int, pi_coeff1_p: int = 1, pi_coeff1_i: int = 1, pi_coeff2_p: int = 1, pi_coeff2_i: int = 1, ) -> bytes: """ Build TASK_ENABLE command (0x7777). Starts a measurement task (current or temperature variation). Returns 32-byte bytearray. """ if not isinstance(task_type, TaskType): try: task_type = TaskType(task_type) except ValueError: raise ValueError(f"Invalid task_type: {task_type}") data = _flipfour(_int_to_hex4(CMD_TASK_ENABLE)) # Word 0 data += _flipfour(_encode_setup()) # Word 1 data += _flipfour(_int_to_hex4(task_type.value)) # Word 2 match task_type: case TaskType.CHANGE_CURRENT_LD1: data += _flipfour(_int_to_hex4(current_ma_to_n(min_value))) # Word 3 data += _flipfour(_int_to_hex4(current_ma_to_n(max_value))) # Word 4 data += _flipfour(_int_to_hex4(current_ma_to_n(step))) # Word 5 data += _flipfour(_int_to_hex4(int(delay_time * 100))) # Word 6 data += _flipfour(_int_to_hex4(temp_c_to_n(static_temp1))) # Word 7 data += _flipfour(_int_to_hex4(current_ma_to_n(static_current2)))# Word 8 data += _flipfour(_int_to_hex4(temp_c_to_n(static_temp2))) # Word 9 case TaskType.CHANGE_CURRENT_LD2: data += _flipfour(_int_to_hex4(current_ma_to_n(min_value))) # Word 3 data += _flipfour(_int_to_hex4(current_ma_to_n(max_value))) # Word 4 data += _flipfour(_int_to_hex4(int(step * 100))) # Word 5 data += _flipfour(_int_to_hex4(int(delay_time * 100))) # Word 6 data += _flipfour(_int_to_hex4(temp_c_to_n(static_temp2))) # Word 7 data += _flipfour(_int_to_hex4(current_ma_to_n(static_current1)))# Word 8 data += _flipfour(_int_to_hex4(temp_c_to_n(static_temp1))) # Word 9 case TaskType.CHANGE_TEMPERATURE_LD1 | TaskType.CHANGE_TEMPERATURE_LD2: raise NotImplementedError("Temperature variation is not yet implemented in firmware") case _: raise ValueError(f"Unsupported task type: {task_type}") data += _flipfour(_int_to_hex4(time_step)) # Word 10 data += _flipfour(_int_to_hex4(pi_coeff1_p)) # Word 11 data += _flipfour(_int_to_hex4(pi_coeff1_i)) # Word 12 data += _flipfour(_int_to_hex4(pi_coeff2_p)) # Word 13 data += _flipfour(_int_to_hex4(pi_coeff2_i)) # Word 14 data += _build_crc(data) # Word 15 result = bytearray.fromhex(data) assert len(result) == TASK_ENABLE_COMMAND_LENGTH, \ f"TASK_ENABLE length mismatch: {len(result)}" return bytes(result) @staticmethod def encode_trans_enable(message_id: int = 0) -> bytes: """Build TRANS_ENABLE command (0x4444) — request last data.""" return bytearray.fromhex(_flipfour(_int_to_hex4(CMD_TRANS_ENABLE))) @staticmethod def encode_state(message_id: int = 0) -> bytes: """Build STATE command (0x6666) — request device state.""" return bytearray.fromhex(_flipfour(_int_to_hex4(CMD_STATE))) @staticmethod def encode_default_enable(message_id: int = 0) -> bytes: """Build DEFAULT_ENABLE command (0x2222) — reset device.""" return bytearray.fromhex(_flipfour(_int_to_hex4(CMD_DEFAULT_ENABLE))) @staticmethod def encode_remove_file() -> bytes: """Build REMOVE_FILE command (0x5555) — delete saved data.""" return bytearray.fromhex(_flipfour(_int_to_hex4(CMD_REMOVE_FILE))) # ---- Response decoders ----------------------------------------------- @staticmethod def decode_response(data: bytes) -> Response: """ Decode a 30-byte DATA response from the device. Raises: ProtocolError: If data length is wrong. CRCError: If CRC check fails. """ if len(data) != GET_DATA_TOTAL_LENGTH: raise ProtocolError( f"Expected {GET_DATA_TOTAL_LENGTH} bytes, got {len(data)} bytes" ) hex_str = data.hex() def get_word(num: int) -> str: return _flipfour(hex_str[num*4: num*4+4]) def get_int_word(num: int) -> int: return int(get_word(num), 16) # CRC check: XOR over words 1..13 (wire order), compare with word 14 (wire order) crc_words = [hex_str[i:i+4] for i in range(4, len(hex_str)-4, 4)] computed = int(crc_words[0], 16) for w in crc_words[1:]: computed ^= int(w, 16) stored = int(hex_str[56:60], 16) if computed != stored: raise CRCError(expected=computed, received=stored) resp = Response() resp.header = get_word(0) resp.current1 = current_n_to_ma(get_int_word(1)) resp.current2 = current_n_to_ma(get_int_word(2)) resp.to6_lsb = get_int_word(3) resp.to6_msb = get_int_word(4) resp.temp1 = temp_n_to_c(get_int_word(5)) resp.temp2 = temp_n_to_c(get_int_word(6)) resp.temp_ext1 = temp_ext_n_to_c(get_int_word(7)) resp.temp_ext2 = temp_ext_n_to_c(get_int_word(8)) resp.voltage_3v3 = voltage_3v3_n_to_v(get_int_word(9)) resp.voltage_5v1 = voltage_5v_n_to_v(get_int_word(10)) resp.voltage_5v2 = voltage_5v_n_to_v(get_int_word(11)) resp.voltage_7v0 = voltage_7v_n_to_v(get_int_word(12)) resp.message_id = get_int_word(13) return resp @staticmethod def decode_state(data: bytes) -> int: """ Decode a 2-byte STATE response from the device. Returns: Integer state code (compare with DeviceState enum). """ if len(data) < 2: raise ProtocolError(f"STATE response too short: {len(data)} bytes") hex_str = data.hex() state_hex = _flipfour(hex_str[0:4]) return int(state_hex, 16) @staticmethod def state_to_description(state_hex_str: str) -> str: """Return human-readable description for a state hex string.""" return STATE_DESCRIPTIONS.get(state_hex_str, "Unknown or reserved error.")