"""Codec for the UART protocol implemented by the current firmware.""" from __future__ import annotations from datetime import datetime import struct from .constants import ( AD9102_FLAG_ENABLE, AD9102_FLAG_SRAM, AD9102_FLAG_SRAM_FORMAT_ALT, AD9102_FLAG_TRIANGLE, AD9102_WAVE_MAX_CHUNK_SAMPLES, AD9102_WAVE_OPCODE_BEGIN, AD9102_WAVE_OPCODE_CANCEL, AD9102_WAVE_OPCODE_COMMIT, AD9102_WAVE_SAMPLE_MAX, AD9102_WAVE_SAMPLE_MIN, AD9833_FLAG_ENABLE, AD9833_FLAG_TRIANGLE, CMD_DECODE_ENABLE, CMD_DEFAULT_ENABLE, CMD_PROFILE_SAVE_CONTROL, CMD_PROFILE_SAVE_DATA, CMD_AD9102_CONTROL, CMD_AD9102_WAVE_CONTROL, CMD_AD9102_WAVE_DATA, CMD_AD9833_CONTROL, CMD_DS1809_CONTROL, CMD_STATE, CMD_STM32_DAC_CONTROL, CMD_TRANS_ENABLE, DEFAULT_SETUP_WORD, DS1809_FLAG_DECREMENT, DS1809_FLAG_INCREMENT, GET_DATA_TOTAL_LENGTH, PROFILE_NAME_MAX_LENGTH, PROFILE_SAVE_CONTROL_TOTAL_LENGTH, PROFILE_SAVE_DATA_CHUNK_BYTES, PROFILE_SAVE_DATA_TOTAL_LENGTH, PROFILE_SAVE_OPCODE_BEGIN, PROFILE_SAVE_OPCODE_CANCEL, PROFILE_SAVE_OPCODE_COMMIT, PROFILE_SAVE_SECTION_PROFILE_TEXT, PROFILE_SAVE_SECTION_WAVEFORM_TEXT, SEND_PARAMS_TOTAL_LENGTH, SHORT_CONTROL_TOTAL_LENGTH, STM32_DAC_FLAG_ENABLE, STATUS_DESCRIPTIONS, STATUS_RESPONSE_LENGTH, WAVE_DATA_TOTAL_LENGTH, ) from .conversions import ( current_ma_to_n, current_n_to_ma, temp_c_to_n, temp_ext_n_to_c, temp_n_to_c, voltage_3v3_n_to_v, voltage_5v_n_to_v, voltage_7v_n_to_v, ) from .exceptions import CRCError, ProtocolError from .models import DeviceState, Measurements def _int_to_hex4(value: int) -> str: """Return a zero-padded four-digit lowercase hex string.""" if value < 0 or value > 0xFFFF: raise ValueError(f"Value {value} out of uint16 range") return f"{value:04x}" def _flipfour(value: str) -> str: """Swap byte pairs in a four-character hex word.""" if len(value) != 4: raise ValueError(f"Expected 4 hex chars, got {value!r}") return value[2:4] + value[0:2] def _build_crc(data_hex: str) -> str: """Return the checksum word for a wire-order hex packet without CRC.""" if len(data_hex) % 4 != 0: raise ValueError("Packet hex string must contain complete 16-bit words") words = [data_hex[index:index + 4] for index in range(0, len(data_hex), 4)] checksum = 0 for word in words[1:]: checksum ^= int(word, 16) return _int_to_hex4(checksum) def _pack_words(words: list[int]) -> bytes: return struct.pack("<" + "H" * len(words), *words) def _unpack_words(data: bytes) -> tuple[int, ...]: if len(data) % 2 != 0: raise ProtocolError(f"Packet length must be even, got {len(data)} bytes") return struct.unpack("<" + "H" * (len(data) // 2), data) def _payload_checksum(words: list[int]) -> int: checksum = 0 for word in words: checksum ^= word return checksum & 0xFFFF def _ensure_uint(value: int, name: str, minimum: int, maximum: int) -> int: if not isinstance(value, int): raise ValueError(f"{name} must be an integer") if not minimum <= value <= maximum: raise ValueError(f"{name} must be in range [{minimum}, {maximum}]") return value def _encode_ascii_name_words(profile_name: str) -> tuple[list[int], int]: if not isinstance(profile_name, str): raise ValueError("profile_name must be a string") try: encoded = profile_name.encode("ascii") except UnicodeEncodeError as exc: raise ValueError("profile_name must contain ASCII characters only") from exc if not 1 <= len(encoded) <= PROFILE_NAME_MAX_LENGTH: raise ValueError( f"profile_name length must be in range [1, {PROFILE_NAME_MAX_LENGTH}]" ) padded = encoded + (b"\x00" * (PROFILE_NAME_MAX_LENGTH - len(encoded))) words = [ padded[index] | (padded[index + 1] << 8) for index in range(0, PROFILE_NAME_MAX_LENGTH, 2) ] return words, len(encoded) class Protocol: """Static helpers for encoding commands and decoding responses.""" @staticmethod def calculate_crc(data: bytes) -> int: """Calculate XOR checksum over all words except the first header word.""" words = _unpack_words(data) if len(words) <= 1: return 0 return _payload_checksum(list(words[1:])) @staticmethod def encode_decode_enable( temp1: float, temp2: float, current1: float, current2: float, pi_coeff1_p: int, pi_coeff1_i: int, pi_coeff2_p: int, pi_coeff2_i: int, message_id: int, ) -> bytes: """Build the 30-byte DECODE_ENABLE command.""" words = [ CMD_DECODE_ENABLE, DEFAULT_SETUP_WORD, temp_c_to_n(temp1), temp_c_to_n(temp2), 0, 0, 0, pi_coeff1_p & 0xFFFF, pi_coeff1_i & 0xFFFF, pi_coeff2_p & 0xFFFF, pi_coeff2_i & 0xFFFF, message_id & 0xFFFF, current_ma_to_n(current1), current_ma_to_n(current2), ] words.append(_payload_checksum(words[1:])) packet = _pack_words(words) if len(packet) != SEND_PARAMS_TOTAL_LENGTH: raise ProtocolError( f"DECODE_ENABLE length mismatch: {len(packet)} bytes" ) return packet @staticmethod def encode_trans_enable() -> bytes: """Build the short TRANS_ENABLE command.""" return _pack_words([CMD_TRANS_ENABLE]) @staticmethod def encode_state() -> bytes: """Build the short STATE command.""" return _pack_words([CMD_STATE]) @staticmethod def encode_default_enable() -> bytes: """Build the short DEFAULT_ENABLE command.""" return _pack_words([CMD_DEFAULT_ENABLE]) @staticmethod def encode_ad9102_control( *, enabled: bool, triangle: bool, sram_mode: bool, param0: int, param1: int, alt_format: bool = False, ) -> bytes: """Build an AD9102 control packet.""" flags = 0 if enabled: flags |= AD9102_FLAG_ENABLE if triangle: flags |= AD9102_FLAG_TRIANGLE if sram_mode: flags |= AD9102_FLAG_SRAM if alt_format: flags |= AD9102_FLAG_SRAM_FORMAT_ALT return Protocol._encode_short_control( CMD_AD9102_CONTROL, flags, _ensure_uint(param0, "param0", 0, 0xFFFF), _ensure_uint(param1, "param1", 0, 0xFFFF), ) @staticmethod def encode_ad9833_control(*, enabled: bool, triangle: bool, frequency_word: int) -> bytes: """Build an AD9833 control packet.""" flags = 0 if enabled: flags |= AD9833_FLAG_ENABLE if triangle: flags |= AD9833_FLAG_TRIANGLE frequency_word = _ensure_uint(frequency_word, "frequency_word", 0, 0x0FFFFFFF) return Protocol._encode_short_control( CMD_AD9833_CONTROL, flags, frequency_word & 0x3FFF, (frequency_word >> 14) & 0x3FFF, ) @staticmethod def encode_ds1809_control(*, increment: bool, decrement: bool, count: int, pulse_ms: int) -> bytes: """Build a DS1809 control packet.""" if increment and decrement: raise ValueError("increment and decrement cannot both be true") flags = 0 if increment: flags |= DS1809_FLAG_INCREMENT if decrement: flags |= DS1809_FLAG_DECREMENT return Protocol._encode_short_control( CMD_DS1809_CONTROL, flags, _ensure_uint(count, "count", 0, 0xFFFF), _ensure_uint(pulse_ms, "pulse_ms", 0, 0xFFFF), ) @staticmethod def encode_stm32_dac_control(*, enabled: bool, dac_code: int) -> bytes: """Build an STM32 DAC control packet.""" flags = STM32_DAC_FLAG_ENABLE if enabled else 0 return Protocol._encode_short_control( CMD_STM32_DAC_CONTROL, flags, _ensure_uint(dac_code, "dac_code", 0, 0x0FFF), 0, ) @staticmethod def encode_ad9102_wave_begin(sample_count: int) -> bytes: """Build an AD9102 custom-wave upload BEGIN packet.""" return Protocol._encode_short_control( CMD_AD9102_WAVE_CONTROL, AD9102_WAVE_OPCODE_BEGIN, _ensure_uint(sample_count, "sample_count", 0, 0xFFFF), 0, ) @staticmethod def encode_ad9102_wave_commit() -> bytes: """Build an AD9102 custom-wave upload COMMIT packet.""" return Protocol._encode_short_control( CMD_AD9102_WAVE_CONTROL, AD9102_WAVE_OPCODE_COMMIT, 0, 0, ) @staticmethod def encode_ad9102_wave_cancel() -> bytes: """Build an AD9102 custom-wave upload CANCEL packet.""" return Protocol._encode_short_control( CMD_AD9102_WAVE_CONTROL, AD9102_WAVE_OPCODE_CANCEL, 0, 0, ) @staticmethod def encode_ad9102_wave_data(samples: list[int]) -> bytes: """Build one fixed-size AD9102 custom-wave data chunk packet.""" if not samples: raise ValueError("samples must not be empty") if len(samples) > AD9102_WAVE_MAX_CHUNK_SAMPLES: raise ValueError( f"samples length must be <= {AD9102_WAVE_MAX_CHUNK_SAMPLES}" ) encoded_samples = [] for index, sample in enumerate(samples): if not isinstance(sample, int): raise ValueError(f"sample[{index}] must be an integer") if not AD9102_WAVE_SAMPLE_MIN <= sample <= AD9102_WAVE_SAMPLE_MAX: raise ValueError( f"sample[{index}] must be in range " f"[{AD9102_WAVE_SAMPLE_MIN}, {AD9102_WAVE_SAMPLE_MAX}]" ) encoded_samples.append(sample & 0xFFFF) padded_samples = encoded_samples + [0] * (AD9102_WAVE_MAX_CHUNK_SAMPLES - len(samples)) words = [CMD_AD9102_WAVE_DATA, len(samples), *padded_samples] words.append(_payload_checksum(words[1:])) packet = _pack_words(words) if len(packet) != WAVE_DATA_TOTAL_LENGTH: raise ProtocolError(f"AD9102_WAVE_DATA length mismatch: {len(packet)} bytes") return packet @staticmethod def encode_profile_save_begin( *, profile_name: str, profile_text_bytes: int, waveform_text_bytes: int, ) -> bytes: """Build the fixed-size BEGIN packet for a streamed SD profile save.""" name_words, name_length = _encode_ascii_name_words(profile_name) payload_words = [ PROFILE_SAVE_OPCODE_BEGIN, _ensure_uint(profile_text_bytes, "profile_text_bytes", 1, 0xFFFF), _ensure_uint(waveform_text_bytes, "waveform_text_bytes", 0, 0xFFFF), name_length, *name_words, 0, ] payload_words.append(_payload_checksum(payload_words)) packet = _pack_words([CMD_PROFILE_SAVE_CONTROL, *payload_words]) if len(packet) != PROFILE_SAVE_CONTROL_TOTAL_LENGTH: raise ProtocolError( f"PROFILE_SAVE_BEGIN length mismatch: {len(packet)} bytes" ) return packet @staticmethod def encode_profile_save_commit() -> bytes: """Build the fixed-size COMMIT packet for a streamed SD profile save.""" payload_words = [PROFILE_SAVE_OPCODE_COMMIT] + ([0] * 12) payload_words.append(_payload_checksum(payload_words)) packet = _pack_words([CMD_PROFILE_SAVE_CONTROL, *payload_words]) if len(packet) != PROFILE_SAVE_CONTROL_TOTAL_LENGTH: raise ProtocolError( f"PROFILE_SAVE_COMMIT length mismatch: {len(packet)} bytes" ) return packet @staticmethod def encode_profile_save_cancel() -> bytes: """Build the fixed-size CANCEL packet for a streamed SD profile save.""" payload_words = [PROFILE_SAVE_OPCODE_CANCEL] + ([0] * 12) payload_words.append(_payload_checksum(payload_words)) packet = _pack_words([CMD_PROFILE_SAVE_CONTROL, *payload_words]) if len(packet) != PROFILE_SAVE_CONTROL_TOTAL_LENGTH: raise ProtocolError( f"PROFILE_SAVE_CANCEL length mismatch: {len(packet)} bytes" ) return packet @staticmethod def encode_profile_save_data(*, section_id: int, chunk: bytes) -> bytes: """Build one fixed-size data packet carrying profile or waveform text.""" if not isinstance(chunk, (bytes, bytearray)): raise ValueError("chunk must be bytes") if not chunk: raise ValueError("chunk must not be empty") if len(chunk) > PROFILE_SAVE_DATA_CHUNK_BYTES: raise ValueError( f"chunk length must be <= {PROFILE_SAVE_DATA_CHUNK_BYTES}" ) if section_id not in ( PROFILE_SAVE_SECTION_PROFILE_TEXT, PROFILE_SAVE_SECTION_WAVEFORM_TEXT, ): raise ValueError("section_id is invalid") padded = bytes(chunk) + (b"\x00" * (PROFILE_SAVE_DATA_CHUNK_BYTES - len(chunk))) data_words = [ padded[index] | (padded[index + 1] << 8) for index in range(0, PROFILE_SAVE_DATA_CHUNK_BYTES, 2) ] payload_words = [section_id, len(chunk), *data_words] payload_words.append(_payload_checksum(payload_words)) packet = _pack_words([CMD_PROFILE_SAVE_DATA, *payload_words]) if len(packet) != PROFILE_SAVE_DATA_TOTAL_LENGTH: raise ProtocolError( f"PROFILE_SAVE_DATA length mismatch: {len(packet)} bytes" ) return packet @staticmethod def _encode_short_control(header: int, word0: int, word1: int, word2: int) -> bytes: words = [header, word0 & 0xFFFF, word1 & 0xFFFF, word2 & 0xFFFF] words.append(_payload_checksum(words[1:])) packet = _pack_words(words) if len(packet) != SHORT_CONTROL_TOTAL_LENGTH: raise ProtocolError(f"Short control length mismatch: {len(packet)} bytes") return packet @staticmethod def decode_response(data: bytes) -> Measurements: """Decode a 30-byte telemetry frame into a Measurements object.""" if len(data) != GET_DATA_TOTAL_LENGTH: raise ProtocolError( f"Expected {GET_DATA_TOTAL_LENGTH} bytes, got {len(data)} bytes" ) words = _unpack_words(data) expected_crc = _payload_checksum(list(words[1:14])) if words[14] != expected_crc: raise CRCError(expected=expected_crc, received=words[14]) return Measurements( current1=current_n_to_ma(words[1]), current2=current_n_to_ma(words[2]), temp1=temp_n_to_c(words[5]), temp2=temp_n_to_c(words[6]), temp_ext1=temp_ext_n_to_c(words[7]), temp_ext2=temp_ext_n_to_c(words[8]), voltage_3v3=voltage_3v3_n_to_v(words[9]), voltage_5v1=voltage_5v_n_to_v(words[10]), voltage_5v2=voltage_5v_n_to_v(words[11]), voltage_7v0=voltage_7v_n_to_v(words[12]), message_id=words[13], to6_counter_lsb=words[3], to6_counter_msb=words[4], timestamp=datetime.now(), ) @staticmethod def decode_status(data: bytes) -> tuple[DeviceState, int]: """Decode the two-byte firmware status response into flags and detail.""" if len(data) != STATUS_RESPONSE_LENGTH: raise ProtocolError( f"Expected {STATUS_RESPONSE_LENGTH} status bytes, got {len(data)}" ) raw_word = _unpack_words(data)[0] flags = DeviceState(raw_word & 0x00FF) detail = (raw_word >> 8) & 0x00FF return flags, detail @staticmethod def decode_state(data: bytes) -> int: """Compatibility helper returning only the low-byte status mask.""" flags, _detail = Protocol.decode_status(data) return int(flags) @staticmethod def state_to_description(state: DeviceState | int) -> str: """Return a readable description for a status mask.""" state = DeviceState(int(state)) if state == DeviceState.OK: return "All ok." parts = [ text for mask, text in STATUS_DESCRIPTIONS.items() if (state & DeviceState(mask)) == DeviceState(mask) ] if parts: return "; ".join(parts) return f"Unknown status mask: 0x{int(state):02X}" __all__ = ["Protocol", "_build_crc", "_flipfour", "_int_to_hex4"]