diff --git a/gui.py b/gui.py index 36c469b..f651a1b 100644 --- a/gui.py +++ b/gui.py @@ -70,31 +70,31 @@ def get_screen_size(): screen_width, screen_height = window.get_screen_size() window.close() - COMPACT_LAYOUT = screen_width <= 1280 or screen_height <= 800 + COMPACT_LAYOUT = True margin_w, margin_h = WINDOW_MARGIN min_w, min_h = MIN_WINDOW_SIZE window_width = min(screen_width, max(min_w, screen_width - margin_w)) - window_height = min(screen_height, max(min_h, screen_height - margin_h)) + window_height = min(screen_height, max(min_h, screen_height - margin_h))//2 WINDOW_SIZE = (window_width, window_height) if COMPACT_LAYOUT: SET_TEXT_WIDTH = 30 SET_TEXT_WIDTH_NEW = 34 - graph_width = min(int(screen_width / 3.6), int(window_width / 3.1)) - graph_height = max(110, int(screen_height / 6.5)) - H_SEPARATOR_PAD = (1, 12) - OUTPUT_TEXT_PAD = (5, (12, 5)) + graph_width = min(int(screen_width / 7.2), int(window_width / 5.6)) + graph_height = max(90, int(screen_height / 16)) + H_SEPARATOR_PAD = (1, 8) + OUTPUT_TEXT_PAD = (5, (8, 3)) else: SET_TEXT_WIDTH = 34 SET_TEXT_WIDTH_NEW = 40 - graph_width = int(screen_width / 3.5) - graph_height = int(screen_width / (3 * 2.75)) - H_SEPARATOR_PAD = (1, 20) - OUTPUT_TEXT_PAD = (5, (20, 5)) + graph_width = int(screen_width / 4) + graph_height = int(screen_width / (3 * 3.5)) + H_SEPARATOR_PAD = (1, 15) + OUTPUT_TEXT_PAD = (5, (15, 5)) - graph_width = max(220, graph_width) + graph_width = max(180, graph_width) GRAPH_CANVAS_SIZE = (graph_width, graph_height) return WINDOW_SIZE diff --git a/laser_control/__init__.py b/laser_control/__init__.py new file mode 100644 index 0000000..32a9600 --- /dev/null +++ b/laser_control/__init__.py @@ -0,0 +1,35 @@ +""" +Laser Control Module + +A standalone module for controlling dual laser systems with temperature and current regulation. +Provides a clean API for integration into any Python application. +""" + +from .controller import LaserController +from .models import ( + DeviceStatus, + Measurements, + ManualModeParams, + VariationParams, + VariationType +) +from .exceptions import ( + LaserControlError, + ValidationError, + CommunicationError, + DeviceError +) + +__version__ = "1.0.0" +__all__ = [ + "LaserController", + "DeviceStatus", + "Measurements", + "ManualModeParams", + "VariationParams", + "VariationType", + "LaserControlError", + "ValidationError", + "CommunicationError", + "DeviceError" +] \ No newline at end of file diff --git a/laser_control/constants.py b/laser_control/constants.py new file mode 100644 index 0000000..ac63179 --- /dev/null +++ b/laser_control/constants.py @@ -0,0 +1,122 @@ +""" +Constants for laser control module. + +Physical constraints, protocol parameters, and operational limits +extracted from original device_commands.py and device_conversion.py. +""" + +# ---- Protocol constants + +BAUDRATE = 115200 +SERIAL_TIMEOUT_SEC = 1.0 + +GET_DATA_TOTAL_LENGTH = 30 # bytes in device DATA response +SEND_PARAMS_TOTAL_LENGTH = 30 # bytes in DECODE_ENABLE command +TASK_ENABLE_COMMAND_LENGTH = 32 # bytes in TASK_ENABLE command + +WAIT_AFTER_SEND_SEC = 0.15 # delay after sending a command +GUI_POLL_INTERVAL_MS = 5 # GUI event loop timeout + +# ---- Command codes (as sent to device, already flipped to LE) + +CMD_DECODE_ENABLE = 0x1111 # Set control parameters +CMD_DEFAULT_ENABLE = 0x2222 # Reset device +CMD_TRANSS_ENABLE = 0x3333 # Request all saved data (not implemented) +CMD_TRANS_ENABLE = 0x4444 # Request last data +CMD_REMOVE_FILE = 0x5555 # Delete saved data +CMD_STATE = 0x6666 # Request state +CMD_TASK_ENABLE = 0x7777 # Start a task + +# ---- Error codes from device STATE response (after flipfour) + +STATE_OK = '0000' +STATE_SD_ERR = '0001' # SD Card read/write error +STATE_UART_ERR = '0002' # Command (UART) error +STATE_UART_DECODE_ERR = '0004' # Wrong parameter value +STATE_TEC1_ERR = '0008' # Laser 1 TEC driver overheat +STATE_TEC2_ERR = '0010' # Laser 2 TEC driver overheat +STATE_DEFAULT_ERR = '0020' # System reset error +STATE_REMOVE_ERR = '0040' # File deletion error + +STATE_DESCRIPTIONS = { + STATE_OK: "All ok.", + STATE_SD_ERR: "SD Card reading/writing error (SD_ERR).", + STATE_UART_ERR: "Command error (UART_ERR).", + STATE_UART_DECODE_ERR:"Wrong parameter value error (UART_DECODE_ERR).", + STATE_TEC1_ERR: "Laser 1: TEC driver overheat (TEC1_ERR).", + STATE_TEC2_ERR: "Laser 2: TEC driver overheat (TEC2_ERR).", + STATE_DEFAULT_ERR: "Resetting system error (DEFAULT_ERR).", + STATE_REMOVE_ERR: "File deletion error (REMOVE_ERR).", +} + +# ---- Physical / hardware constants (from device_conversion.py) + +VREF = 2.5 # Reference voltage, Volts + +# Bridge resistors for temperature measurement +R1 = 10000 # Ohm +R2 = 2200 # Ohm +R3 = 27000 # Ohm +R4 = 30000 # Ohm +R5 = 27000 # Ohm +R6 = 56000 # Ohm + +RREF = 10 # Current-setting resistor, Ohm + # (@1550 nm – 28.7 Ohm; @840 nm – 10 Ohm) + +# External thermistor divider resistors +R7 = 22000 # Ohm +R8 = 22000 # Ohm +R9 = 5100 # Ohm +R10 = 180000 # Ohm + +# Thermistor Steinhart–Hart B-coefficient (internal / external) +BETA_INTERNAL = 3900 # K +BETA_EXTERNAL = 3455 # K +T0_K = 298 # Kelvin (25 °C reference) +R0 = 10000 # Ohm (thermistor nominal at 25 °C) + +# ADC resolution +ADC_BITS_16 = 65535 # 2^16 - 1 +ADC_BITS_12 = 4095 # 2^12 - 1 + +# Voltage conversion coefficients +U3V3_COEFF = 1.221e-3 # counts → Volts for 3.3V rail +U5V_COEFF = 1.8315e-3 # counts → Volts for 5V rails +U7V_COEFF = 6.72e-3 # counts → Volts for 7V rail + +# ---- Operational limits (validated in validators.py) + +TEMP_MIN_C = 15.0 # Minimum allowed laser temperature, °C +TEMP_MAX_C = 40.0 # Maximum allowed laser temperature, °C + +CURRENT_MIN_MA = 15.0 # Minimum allowed laser current, mA +CURRENT_MAX_MA = 60.0 # Maximum allowed laser current, mA + +# Variation step limits +CURRENT_STEP_MIN_MA = 0.002 # Minimum current variation step, mA +CURRENT_STEP_MAX_MA = 0.5 # Maximum current variation step, mA + +TEMP_STEP_MIN_C = 0.05 # Minimum temperature variation step, °C +TEMP_STEP_MAX_C = 1.0 # Maximum temperature variation step, °C + +# Time parameter limits +TIME_STEP_MIN_US = 20 # Minimum time step, microseconds +TIME_STEP_MAX_US = 100 # Maximum time step, microseconds + +DELAY_TIME_MIN_MS = 3 # Minimum delay between pulses, milliseconds +DELAY_TIME_MAX_MS = 10 # Maximum delay between pulses, milliseconds + +# ---- Acceptable voltage tolerances for power rail health check + +VOLT_3V3_MIN = 3.1 +VOLT_3V3_MAX = 3.5 +VOLT_5V_MIN = 4.8 +VOLT_5V_MAX = 5.3 +VOLT_7V_MIN = 6.5 +VOLT_7V_MAX = 7.5 + +# ---- Data buffer limits + +MAX_DATA_POINTS = 1000 # Max stored measurement points +PLOT_POINTS = 100 # Points shown in real-time plots \ No newline at end of file diff --git a/laser_control/controller.py b/laser_control/controller.py new file mode 100644 index 0000000..1e0c6ac --- /dev/null +++ b/laser_control/controller.py @@ -0,0 +1,326 @@ +""" +Main laser controller for the laser control module. + +Provides a high-level API for controlling dual laser systems. +All input parameters are validated before being sent to the device. +Can be embedded in any Python application without GUI dependencies. +""" + +import time +import logging +from typing import Optional, Callable + +from .protocol import Protocol, TaskType as ProtoTaskType +from .validators import ParameterValidator +from .models import ( + ManualModeParams, + VariationParams, + VariationType, + Measurements, + DeviceStatus, + DeviceState, +) +from .exceptions import ( + ValidationError, + CommunicationError, + DeviceNotRespondingError, + DeviceStateError, +) +from .constants import WAIT_AFTER_SEND_SEC + +logger = logging.getLogger(__name__) + +# Default PI regulator coefficients (match firmware defaults) +DEFAULT_PI_P = 1 +DEFAULT_PI_I = 1 + + +class LaserController: + """ + High-level controller for the dual laser board. + + Usage example:: + + ctrl = LaserController(port='/dev/ttyUSB0') + ctrl.connect() + ctrl.set_manual_mode(temp1=25.0, temp2=30.0, + current1=40.0, current2=35.0) + data = ctrl.get_measurements() + print(data.voltage_3v3) + ctrl.disconnect() + + All public methods raise :class:`ValidationError` for bad parameters + and :class:`CommunicationError` for transport-level problems. + """ + + def __init__( + self, + port: Optional[str] = None, + pi_coeff1_p: int = DEFAULT_PI_P, + pi_coeff1_i: int = DEFAULT_PI_I, + pi_coeff2_p: int = DEFAULT_PI_P, + pi_coeff2_i: int = DEFAULT_PI_I, + on_data: Optional[Callable[[Measurements], None]] = None, + ): + """ + Args: + port: Serial port (e.g. '/dev/ttyUSB0'). None = auto-detect. + pi_coeff1_p: Proportional coefficient for laser 1 PI regulator. + pi_coeff1_i: Integral coefficient for laser 1 PI regulator. + pi_coeff2_p: Proportional coefficient for laser 2 PI regulator. + pi_coeff2_i: Integral coefficient for laser 2 PI regulator. + on_data: Optional callback called whenever new measurements + are received. Signature: ``callback(Measurements)``. + """ + self._protocol = Protocol(port) + self._pi1_p = pi_coeff1_p + self._pi1_i = pi_coeff1_i + self._pi2_p = pi_coeff2_p + self._pi2_i = pi_coeff2_i + self._on_data = on_data + self._message_id = 0 + self._last_measurements: Optional[Measurements] = None + + # ---- Connection ------------------------------------------------------- + + def connect(self) -> bool: + """ + Open connection to the device. + + Returns: + True if connection succeeded. + + Raises: + CommunicationError: If the port cannot be opened. + """ + self._protocol.connect() + logger.info("Connected to laser controller on port %s", + self._protocol._port_name or "auto") + return True + + def disconnect(self) -> None: + """Close the serial port gracefully.""" + self._protocol.disconnect() + logger.info("Disconnected from laser controller") + + @property + def is_connected(self) -> bool: + """True if the serial port is open.""" + return self._protocol.is_connected + + # ---- Public API ------------------------------------------------------- + + def set_manual_mode( + self, + temp1: float, + temp2: float, + current1: float, + current2: float, + ) -> None: + """ + Set manual control parameters for both lasers. + + Args: + temp1: Setpoint temperature for laser 1, °C. + Valid range: [15.0 … 40.0] °C. + temp2: Setpoint temperature for laser 2, °C. + Valid range: [15.0 … 40.0] °C. + current1: Drive current for laser 1, mA. + Valid range: [15.0 … 60.0] mA. + current2: Drive current for laser 2, mA. + Valid range: [15.0 … 60.0] mA. + + Raises: + ValidationError: If any parameter is out of range. + CommunicationError: If the command cannot be sent. + """ + validated = ParameterValidator.validate_manual_mode_params( + temp1, temp2, current1, current2 + ) + self._message_id = (self._message_id + 1) & 0xFFFF + + cmd = Protocol.encode_decode_enable( + temp1=validated['temp1'], + temp2=validated['temp2'], + current1=validated['current1'], + current2=validated['current2'], + pi_coeff1_p=self._pi1_p, + pi_coeff1_i=self._pi1_i, + pi_coeff2_p=self._pi2_p, + pi_coeff2_i=self._pi2_i, + message_id=self._message_id, + ) + self._send(cmd) + logger.debug("Manual mode set: T1=%.2f T2=%.2f I1=%.2f I2=%.2f", + validated['temp1'], validated['temp2'], + validated['current1'], validated['current2']) + + def start_variation( + self, + variation_type: VariationType, + params: dict, + ) -> None: + """ + Start a parameter variation task. + + Args: + variation_type: Which parameter to vary + (:class:`VariationType.CHANGE_CURRENT_LD1` or + :class:`VariationType.CHANGE_CURRENT_LD2`). + params: Dictionary with the following keys: + + - ``min_value`` – minimum value of the varied parameter. + - ``max_value`` – maximum value of the varied parameter. + - ``step`` – step size. + - ``time_step`` – discretisation time step, µs [20 … 100]. + - ``delay_time``– delay between pulses, ms [3 … 10]. + - ``static_temp1`` – fixed temperature for laser 1, °C. + - ``static_temp2`` – fixed temperature for laser 2, °C. + - ``static_current1`` – fixed current for laser 1, mA. + - ``static_current2`` – fixed current for laser 2, mA. + + Raises: + ValidationError: If any parameter fails validation. + CommunicationError: If the command cannot be sent. + """ + # Validate variation-specific params + validated = ParameterValidator.validate_variation_params( + params, variation_type + ) + + # Validate static parameters + static_temp1 = ParameterValidator.validate_temperature( + params.get('static_temp1', 25.0), 'static_temp1' + ) + static_temp2 = ParameterValidator.validate_temperature( + params.get('static_temp2', 25.0), 'static_temp2' + ) + static_current1 = ParameterValidator.validate_current( + params.get('static_current1', 30.0), 'static_current1' + ) + static_current2 = ParameterValidator.validate_current( + params.get('static_current2', 30.0), 'static_current2' + ) + + # Map VariationType → protocol TaskType + task_type_map = { + VariationType.CHANGE_CURRENT_LD1: ProtoTaskType.CHANGE_CURRENT_LD1, + VariationType.CHANGE_CURRENT_LD2: ProtoTaskType.CHANGE_CURRENT_LD2, + VariationType.CHANGE_TEMPERATURE_LD1: ProtoTaskType.CHANGE_TEMPERATURE_LD1, + VariationType.CHANGE_TEMPERATURE_LD2: ProtoTaskType.CHANGE_TEMPERATURE_LD2, + } + proto_task = task_type_map[validated['variation_type']] + + cmd = Protocol.encode_task_enable( + task_type=proto_task, + static_temp1=static_temp1, + static_temp2=static_temp2, + static_current1=static_current1, + static_current2=static_current2, + min_value=validated['min_value'], + max_value=validated['max_value'], + step=validated['step'], + time_step=validated['time_step'], + delay_time=validated['delay_time'], + message_id=self._message_id, + pi_coeff1_p=self._pi1_p, + pi_coeff1_i=self._pi1_i, + pi_coeff2_p=self._pi2_p, + pi_coeff2_i=self._pi2_i, + ) + self._send(cmd) + logger.info("Variation task started: type=%s min=%.3f max=%.3f step=%.3f", + validated['variation_type'].name, + validated['min_value'], + validated['max_value'], + validated['step']) + + def stop_task(self) -> None: + """Stop the current task by sending DEFAULT_ENABLE (reset).""" + cmd = Protocol.encode_default_enable() + self._send(cmd) + logger.info("Task stopped (DEFAULT_ENABLE sent)") + + def get_measurements(self) -> Optional[Measurements]: + """ + Request and return the latest measurements from the device. + + Returns: + :class:`Measurements` dataclass, or None if no data available. + + Raises: + CommunicationError: On transport errors. + """ + cmd = Protocol.encode_trans_enable() + self._send(cmd) + + raw = self._protocol.receive_raw(30) + if not raw or len(raw) != 30: + logger.warning("No data received from device") + return None + + response = Protocol.decode_response(raw) + measurements = response.to_measurements() + self._last_measurements = measurements + + if self._on_data: + self._on_data(measurements) + + return measurements + + def get_status(self) -> DeviceStatus: + """ + Request and return the current device status. + + Returns: + :class:`DeviceStatus` with state and latest measurements. + + Raises: + CommunicationError: On transport errors. + """ + cmd = Protocol.encode_state() + self._send(cmd) + + raw = self._protocol.receive_raw(2) + if not raw or len(raw) < 2: + raise DeviceNotRespondingError() + + state_code = Protocol.decode_state(raw) + + # Try to get measurements as well + measurements = self._last_measurements + + return DeviceStatus( + state=DeviceState(state_code) if state_code in DeviceState._value2member_map_ + else DeviceState.ERROR, + measurements=measurements, + is_connected=self.is_connected, + last_command_id=self._message_id, + error_message=Protocol.state_to_description(f"{state_code:04x}") + if state_code != 0 else None, + ) + + def reset(self) -> None: + """Send a hardware reset command to the device.""" + cmd = Protocol.encode_default_enable() + self._send(cmd) + logger.info("Device reset command sent") + + # ---- Internal helpers ------------------------------------------------- + + def _send(self, cmd: bytes) -> None: + """Send command bytes and wait for the device to process.""" + if not self.is_connected: + raise CommunicationError("Not connected to device. Call connect() first.") + self._protocol.send_raw(cmd) + time.sleep(WAIT_AFTER_SEND_SEC) + + # ---- Context manager support ----------------------------------------- + + def __enter__(self): + self.connect() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.disconnect() + return False \ No newline at end of file diff --git a/laser_control/conversions.py b/laser_control/conversions.py new file mode 100644 index 0000000..a0eced7 --- /dev/null +++ b/laser_control/conversions.py @@ -0,0 +1,114 @@ +""" +Physical unit conversions for laser control module. + +Converts between physical quantities (°C, mA, V) and +raw ADC/DAC integer values used by the device firmware. + +All formulas are taken directly from the original device_conversion.py. +""" + +import math +from .constants import ( + VREF, R1, R3, R4, R5, R6, + R7, R8, R9, R10, + RREF, + BETA_INTERNAL, BETA_EXTERNAL, T0_K, R0, + ADC_BITS_16, ADC_BITS_12, + U3V3_COEFF, U5V_COEFF, U7V_COEFF, +) + + +def temp_c_to_n(temp_c: float) -> int: + """ + Convert temperature (°C) to 16-bit DAC integer (Wheatstone bridge setpoint). + + Args: + temp_c: Temperature in degrees Celsius. + + Returns: + Integer in [0, 65535] for the DAC. + """ + rt = R0 * math.exp(BETA_INTERNAL / (temp_c + 273) - BETA_INTERNAL / T0_K) + u = VREF / (R5 * (R3 + R4)) * ( + R1 * R4 * (R5 + R6) - rt * (R3 * R6 - R4 * R5) + ) / (rt + R1) + n = int(u * ADC_BITS_16 / VREF) + n = max(0, min(ADC_BITS_16, n)) + return n + + +def temp_n_to_c(n: int) -> float: + """ + Convert 16-bit ADC integer to temperature (°C). + + Args: + n: Raw ADC value in [0, 65535]. + + Returns: + Temperature in degrees Celsius. + """ + u = n * VREF / ADC_BITS_16 + rt = R1 * (VREF * R4 * (R5 + R6) - u * R5 * (R3 + R4)) / ( + u * R5 * (R3 + R4) + VREF * R3 * R6 - VREF * R4 * R5 + ) + t = 1 / (1 / T0_K + 1 / BETA_INTERNAL * math.log(rt / R0)) - 273 + return t + + +def temp_ext_n_to_c(n: int) -> float: + """ + Convert 12-bit ADC integer to external thermistor temperature (°C). + + Args: + n: Raw 12-bit ADC value in [0, 4095]. + + Returns: + Temperature in degrees Celsius. + """ + u = n * VREF / ADC_BITS_12 * 1 / (1 + 100000 / R10) + VREF * R9 / (R8 + R9) + rt = R7 * u / (VREF - u) + t = 1 / (1 / T0_K + 1 / BETA_EXTERNAL * math.log(rt / R0)) - 273 + return t + + +def current_ma_to_n(current_ma: float) -> int: + """ + Convert laser drive current (mA) to 16-bit DAC integer. + + Args: + current_ma: Current in milliamps. + + Returns: + Integer in [0, 65535] for the DAC. + """ + n = int(ADC_BITS_16 / 2000 * RREF * current_ma) + n = max(0, min(ADC_BITS_16, n)) + return n + + +def current_n_to_ma(n: int) -> float: + """ + Convert raw ADC integer to photodiode current (mA). + + Args: + n: Raw ADC value in [0, 65535]. + + Returns: + Current in milliamps. + """ + return n * 2.5 / (ADC_BITS_16 * 4.4) - 1 / 20.4 + + +def voltage_3v3_n_to_v(n: int) -> float: + """Convert 3.3V rail ADC count to volts.""" + return n * U3V3_COEFF + + +def voltage_5v_n_to_v(n: int) -> float: + """Convert 5V rail ADC count to volts (both 5V1 and 5V2).""" + return n * U5V_COEFF + + +def voltage_7v_n_to_v(n: int) -> float: + """Convert 7V rail ADC count to volts.""" + return n * U7V_COEFF \ No newline at end of file diff --git a/laser_control/example_usage.py b/laser_control/example_usage.py new file mode 100644 index 0000000..a828e25 --- /dev/null +++ b/laser_control/example_usage.py @@ -0,0 +1,107 @@ +""" +Example: how to embed laser_control into any Python application. + +Run: + python3 laser_control/example_usage.py +""" + +import sys +import time +from laser_control import ( + LaserController, + VariationType, + ValidationError, + CommunicationError, +) + + +def example_manual_mode(port: str = None): + """Manual mode: set fixed temperatures and currents.""" + with LaserController(port=port) as ctrl: + try: + ctrl.set_manual_mode( + temp1=25.0, + temp2=30.0, + current1=40.0, + current2=35.0, + ) + print("Manual parameters sent.") + + data = ctrl.get_measurements() + if data: + print(f" Temp1: {data.temp1:.2f} °C") + print(f" Temp2: {data.temp2:.2f} °C") + print(f" I1: {data.current1:.3f} mA") + print(f" I2: {data.current2:.3f} mA") + print(f" 3.3V: {data.voltage_3v3:.3f} V") + print(f" 5V: {data.voltage_5v1:.3f} V") + print(f" 7V: {data.voltage_7v0:.3f} V") + + except ValidationError as e: + print(f"Parameter validation error: {e}") + except CommunicationError as e: + print(f"Communication error: {e}") + + +def example_variation_mode(port: str = None): + """Variation mode: sweep current of laser 1.""" + collected = [] + + def on_measurement(m): + collected.append(m) + print(f" t={m.timestamp.isoformat(timespec='milliseconds')} " + f"I1={m.current1:.3f} mA T1={m.temp1:.2f} °C") + + with LaserController(port=port, on_data=on_measurement) as ctrl: + try: + ctrl.start_variation( + variation_type=VariationType.CHANGE_CURRENT_LD1, + params={ + 'min_value': 20.0, # mA + 'max_value': 50.0, # mA + 'step': 0.5, # mA + 'time_step': 50, # µs + 'delay_time': 5, # ms + 'static_temp1': 25.0, + 'static_temp2': 30.0, + 'static_current1': 35.0, + 'static_current2': 35.0, + } + ) + print("Variation task started. Collecting data for 2 s...") + time.sleep(2) + + ctrl.stop_task() + print(f"Done. Collected {len(collected)} measurements.") + + except ValidationError as e: + print(f"Parameter validation error: {e}") + except CommunicationError as e: + print(f"Communication error: {e}") + + +def example_embed_in_app(): + """ + Minimal embedding pattern for use inside another application. + + The controller can be created once and kept alive for the lifetime + of the host application. No GUI dependency whatsoever. + """ + ctrl = LaserController(port=None) # auto-detect port + try: + ctrl.connect() + except CommunicationError as e: + print(f"Cannot connect: {e}") + return ctrl + + return ctrl # caller owns the controller; call ctrl.disconnect() when done + + +if __name__ == '__main__': + port = sys.argv[1] if len(sys.argv) > 1 else None + + print("=== Manual mode example ===") + example_manual_mode(port) + + print("\n=== Variation mode example ===") + example_variation_mode(port) \ No newline at end of file diff --git a/laser_control/exceptions.py b/laser_control/exceptions.py new file mode 100644 index 0000000..fd3fad3 --- /dev/null +++ b/laser_control/exceptions.py @@ -0,0 +1,139 @@ +""" +Custom exceptions for laser control module. + +Provides a hierarchy of exceptions for different error conditions +that may occur during laser control operations. +""" + + +class LaserControlError(Exception): + """Base exception for all laser control errors.""" + pass + + +class ValidationError(LaserControlError): + """Base exception for validation errors.""" + pass + + +class TemperatureOutOfRangeError(ValidationError): + """Exception raised when temperature is outside valid range.""" + + def __init__(self, param_name: str, value: float, min_val: float, max_val: float): + self.param_name = param_name + self.value = value + self.min_val = min_val + self.max_val = max_val + super().__init__( + f"{param_name}: Temperature {value}°C is out of range " + f"[{min_val}°C - {max_val}°C]" + ) + + +class CurrentOutOfRangeError(ValidationError): + """Exception raised when current is outside valid range.""" + + def __init__(self, param_name: str, value: float, min_val: float, max_val: float): + self.param_name = param_name + self.value = value + self.min_val = min_val + self.max_val = max_val + super().__init__( + f"{param_name}: Current {value}mA is out of range " + f"[{min_val}mA - {max_val}mA]" + ) + + +class InvalidParameterError(ValidationError): + """Exception raised for invalid parameter types or values.""" + + def __init__(self, param_name: str, message: str): + self.param_name = param_name + super().__init__(f"{param_name}: {message}") + + +class CommunicationError(LaserControlError): + """Base exception for communication errors.""" + pass + + +class PortNotFoundError(CommunicationError): + """Exception raised when serial port cannot be found.""" + + def __init__(self, port: str = None): + if port: + message = f"Serial port '{port}' not found" + else: + message = "No suitable serial port found for device connection" + super().__init__(message) + + +class DeviceNotRespondingError(CommunicationError): + """Exception raised when device doesn't respond to commands.""" + + def __init__(self, timeout: float = None): + if timeout: + message = f"Device did not respond within {timeout} seconds" + else: + message = "Device is not responding to commands" + super().__init__(message) + + +class CRCError(CommunicationError): + """Exception raised when CRC check fails.""" + + def __init__(self, expected: int = None, received: int = None): + if expected is not None and received is not None: + message = f"CRC check failed. Expected: 0x{expected:04X}, Received: 0x{received:04X}" + else: + message = "CRC check failed on received data" + super().__init__(message) + + +class ProtocolError(CommunicationError): + """Exception raised for protocol-level errors.""" + + def __init__(self, message: str): + super().__init__(f"Protocol error: {message}") + + +class DeviceError(LaserControlError): + """Base exception for device-level errors.""" + pass + + +class DeviceOverheatingError(DeviceError): + """Exception raised when device reports overheating.""" + + def __init__(self, laser_id: int = None, temperature: float = None): + if laser_id and temperature: + message = f"Laser {laser_id} overheating: {temperature}°C" + else: + message = "Device overheating detected" + super().__init__(message) + + +class PowerSupplyError(DeviceError): + """Exception raised when power supply issues are detected.""" + + def __init__(self, rail: str = None, voltage: float = None, expected: float = None): + if rail and voltage is not None: + if expected: + message = f"Power supply {rail}: {voltage}V (expected ~{expected}V)" + else: + message = f"Power supply {rail}: abnormal voltage {voltage}V" + else: + message = "Power supply error detected" + super().__init__(message) + + +class DeviceStateError(DeviceError): + """Exception raised when device is in an error state.""" + + def __init__(self, state_code: int, state_name: str = None): + self.state_code = state_code + if state_name: + message = f"Device error state: {state_name} (0x{state_code:04X})" + else: + message = f"Device error state: 0x{state_code:04X}" + super().__init__(message) \ No newline at end of file diff --git a/laser_control/models.py b/laser_control/models.py new file mode 100644 index 0000000..df547d5 --- /dev/null +++ b/laser_control/models.py @@ -0,0 +1,219 @@ +""" +Data models for laser control module. + +Provides dataclasses and enums for structured data representation +throughout the laser control system. +""" + +from dataclasses import dataclass +from enum import IntEnum +from typing import Optional, Dict, Any +from datetime import datetime + + +class VariationType(IntEnum): + """Types of parameter variation modes.""" + MANUAL = 0x00 + CHANGE_CURRENT_LD1 = 0x01 + CHANGE_CURRENT_LD2 = 0x02 + CHANGE_TEMPERATURE_LD1 = 0x03 + CHANGE_TEMPERATURE_LD2 = 0x04 + + +class DeviceState(IntEnum): + """Device operational states.""" + IDLE = 0x0000 + RUNNING = 0x0001 + BUSY = 0x0002 + ERROR = 0x00FF + ERROR_OVERHEAT = 0x0100 + ERROR_POWER = 0x0200 + ERROR_COMMUNICATION = 0x0400 + ERROR_INVALID_COMMAND = 0x0800 + + +@dataclass +class ManualModeParams: + """Parameters for manual control mode.""" + temp1: float # Temperature for laser 1 (°C) + temp2: float # Temperature for laser 2 (°C) + current1: float # Current for laser 1 (mA) + current2: float # Current for laser 2 (mA) + pi_coeff1_p: float = 1.0 # PI controller proportional coefficient for laser 1 + pi_coeff1_i: float = 0.5 # PI controller integral coefficient for laser 1 + pi_coeff2_p: float = 1.0 # PI controller proportional coefficient for laser 2 + pi_coeff2_i: float = 0.5 # PI controller integral coefficient for laser 2 + + def to_dict(self) -> Dict[str, float]: + """Convert to dictionary representation.""" + return { + 'temp1': self.temp1, + 'temp2': self.temp2, + 'current1': self.current1, + 'current2': self.current2, + 'pi_coeff1_p': self.pi_coeff1_p, + 'pi_coeff1_i': self.pi_coeff1_i, + 'pi_coeff2_p': self.pi_coeff2_p, + 'pi_coeff2_i': self.pi_coeff2_i + } + + +@dataclass +class VariationParams: + """Parameters for variation mode.""" + variation_type: VariationType + # Static parameters (fixed during variation) + static_temp1: float + static_temp2: float + static_current1: float + static_current2: float + # Variation range + min_value: float # Minimum value for varied parameter + max_value: float # Maximum value for varied parameter + step: float # Step size for variation + # Time parameters + time_step: int # Time step in microseconds (20-100) + delay_time: int # Delay between measurements in milliseconds (3-10) + + def to_dict(self) -> Dict[str, Any]: + """Convert to dictionary representation.""" + return { + 'variation_type': self.variation_type.value, + 'static_temp1': self.static_temp1, + 'static_temp2': self.static_temp2, + 'static_current1': self.static_current1, + 'static_current2': self.static_current2, + 'min_value': self.min_value, + 'max_value': self.max_value, + 'step': self.step, + 'time_step': self.time_step, + 'delay_time': self.delay_time + } + + +@dataclass +class Measurements: + """Real-time measurements from the device.""" + # Photodiode currents + current1: float # Photodiode current for laser 1 (mA) + current2: float # Photodiode current for laser 2 (mA) + # Temperatures + temp1: float # Temperature of laser 1 (°C) + temp2: float # Temperature of laser 2 (°C) + temp_ext1: Optional[float] = None # External thermistor 1 temperature (°C) + temp_ext2: Optional[float] = None # External thermistor 2 temperature (°C) + # Power supply voltages + voltage_3v3: float = 0.0 # 3.3V rail voltage + voltage_5v1: float = 0.0 # 5V rail 1 voltage + voltage_5v2: float = 0.0 # 5V rail 2 voltage + voltage_7v0: float = 0.0 # 7V rail voltage + # Metadata + timestamp: Optional[datetime] = None + message_id: Optional[int] = None + to6_counter_lsb: Optional[int] = None + to6_counter_msb: Optional[int] = None + + def __post_init__(self): + """Set timestamp if not provided.""" + if self.timestamp is None: + self.timestamp = datetime.now() + + def to_dict(self) -> Dict[str, Any]: + """Convert to dictionary representation.""" + return { + 'current1': self.current1, + 'current2': self.current2, + 'temp1': self.temp1, + 'temp2': self.temp2, + 'temp_ext1': self.temp_ext1, + 'temp_ext2': self.temp_ext2, + 'voltage_3v3': self.voltage_3v3, + 'voltage_5v1': self.voltage_5v1, + 'voltage_5v2': self.voltage_5v2, + 'voltage_7v0': self.voltage_7v0, + 'timestamp': self.timestamp.isoformat() if self.timestamp else None, + 'message_id': self.message_id + } + + def check_power_rails(self) -> Dict[str, bool]: + """Check if power supply voltages are within acceptable range.""" + return { + '3v3': 3.1 <= self.voltage_3v3 <= 3.5, + '5v1': 4.8 <= self.voltage_5v1 <= 5.3, + '5v2': 4.8 <= self.voltage_5v2 <= 5.3, + '7v0': 6.5 <= self.voltage_7v0 <= 7.5 + } + + +@dataclass +class DeviceStatus: + """Complete device status information.""" + state: DeviceState + measurements: Optional[Measurements] = None + is_connected: bool = False + last_command_id: Optional[int] = None + error_message: Optional[str] = None + + @property + def is_idle(self) -> bool: + """Check if device is idle.""" + return self.state == DeviceState.IDLE + + @property + def is_running(self) -> bool: + """Check if device is running a task.""" + return self.state == DeviceState.RUNNING + + @property + def has_error(self) -> bool: + """Check if device has any error.""" + return self.state >= DeviceState.ERROR + + @property + def error_type(self) -> Optional[str]: + """Get human-readable error type.""" + if not self.has_error: + return None + + error_map = { + DeviceState.ERROR_OVERHEAT: "Overheating", + DeviceState.ERROR_POWER: "Power supply issue", + DeviceState.ERROR_COMMUNICATION: "Communication error", + DeviceState.ERROR_INVALID_COMMAND: "Invalid command" + } + return error_map.get(self.state, "Unknown error") + + def to_dict(self) -> Dict[str, Any]: + """Convert to dictionary representation.""" + return { + 'state': self.state.value, + 'state_name': self.state.name, + 'measurements': self.measurements.to_dict() if self.measurements else None, + 'is_connected': self.is_connected, + 'last_command_id': self.last_command_id, + 'error_message': self.error_message, + 'is_idle': self.is_idle, + 'is_running': self.is_running, + 'has_error': self.has_error, + 'error_type': self.error_type + } + + +@dataclass +class CalibrationData: + """Calibration data for device sensors.""" + # Temperature calibration coefficients + temp1_offset: float = 0.0 + temp1_scale: float = 1.0 + temp2_offset: float = 0.0 + temp2_scale: float = 1.0 + # Current calibration coefficients + current1_offset: float = 0.0 + current1_scale: float = 1.0 + current2_offset: float = 0.0 + current2_scale: float = 1.0 + # Voltage calibration + voltage_3v3_scale: float = 1.0 + voltage_5v1_scale: float = 1.0 + voltage_5v2_scale: float = 1.0 + voltage_7v0_scale: float = 1.0 \ No newline at end of file diff --git a/laser_control/protocol.py b/laser_control/protocol.py new file mode 100644 index 0000000..262826a --- /dev/null +++ b/laser_control/protocol.py @@ -0,0 +1,451 @@ +""" +Communication protocol for laser control module. + +Encodes commands to bytes and decodes device responses. +Faithful re-implementation of the logic in device_commands.py, +refactored into a clean, testable class-based API. +""" + +import struct +from typing import Optional +from enum import IntEnum +from datetime import datetime + +import serial +import serial.tools.list_ports + +from .constants import ( + BAUDRATE, SERIAL_TIMEOUT_SEC, + GET_DATA_TOTAL_LENGTH, + SEND_PARAMS_TOTAL_LENGTH, + TASK_ENABLE_COMMAND_LENGTH, + CMD_DECODE_ENABLE, CMD_DEFAULT_ENABLE, + CMD_TRANS_ENABLE, CMD_REMOVE_FILE, + CMD_STATE, CMD_TASK_ENABLE, + STATE_DESCRIPTIONS, STATE_OK, +) +from .conversions import ( + temp_c_to_n, temp_n_to_c, + temp_ext_n_to_c, + current_ma_to_n, current_n_to_ma, + voltage_3v3_n_to_v, voltage_5v_n_to_v, voltage_7v_n_to_v, +) +from .models import Measurements, VariationType +from .exceptions import ( + CommunicationError, + PortNotFoundError, + CRCError, + ProtocolError, +) + + +# Re-export enums so tests can import from protocol module +class CommandCode(IntEnum): + DECODE_ENABLE = CMD_DECODE_ENABLE + DEFAULT_ENABLE = CMD_DEFAULT_ENABLE + TRANS_ENABLE = CMD_TRANS_ENABLE + REMOVE_FILE = CMD_REMOVE_FILE + STATE = CMD_STATE + TASK_ENABLE = CMD_TASK_ENABLE + + +class TaskType(IntEnum): + MANUAL = 0x00 + CHANGE_CURRENT_LD1 = 0x01 + CHANGE_CURRENT_LD2 = 0x02 + CHANGE_TEMPERATURE_LD1 = 0x03 + CHANGE_TEMPERATURE_LD2 = 0x04 + + +class DeviceState(IntEnum): + IDLE = 0x0000 + RUNNING = 0x0001 + BUSY = 0x0002 + ERROR = 0x00FF + ERROR_OVERHEAT = 0x0100 + ERROR_POWER = 0x0200 + ERROR_COMMUNICATION = 0x0400 + ERROR_INVALID_COMMAND = 0x0800 + + +# ---- Low-level helpers -------------------------------------------------- + +def _int_to_hex4(value: int) -> str: + """Return 4-character lowercase hex string (0–65535).""" + if value < 0 or value > 65535: + raise ValueError(f"Value {value} out of uint16 range [0, 65535]") + return f"{value:04x}" + + +def _flipfour(s: str) -> str: + """Swap two byte-pairs: 'aabb' → 'bbaa' (little-endian word).""" + if len(s) != 4: + raise ValueError(f"Expected 4-char hex string, got '{s}'") + return s[2:4] + s[0:2] + + +def _xor_crc(words: list) -> str: + """XOR all 16-bit hex words and return 4-char hex CRC.""" + result = int(words[0], 16) + for w in words[1:]: + result ^= int(w, 16) + return _int_to_hex4(result) + + +def _build_crc(data_hex: str) -> str: + """Calculate XOR CRC over words 1..N of a hex string (skip word 0).""" + words = [data_hex[i:i+4] for i in range(0, len(data_hex), 4)] + return _xor_crc(words[1:]) + + +def _encode_setup() -> str: + """Build the 16-bit setup word (all subsystems enabled, SD save off).""" + bits = ['0'] * 16 + bits[15] = '1' # enable work + bits[14] = '1' # enable 5v1 + bits[13] = '1' # enable 5v2 + bits[12] = '1' # enable LD1 + bits[11] = '1' # enable LD2 + bits[10] = '1' # enable REF1 + bits[9] = '1' # enable REF2 + bits[8] = '1' # enable TEC1 + bits[7] = '1' # enable TEC2 + bits[6] = '1' # enable temp stab 1 + bits[5] = '1' # enable temp stab 2 + bits[4] = '0' # enable sd save (disabled) + bits[3] = '1' # enable PI1 coef read + bits[2] = '1' # enable PI2 coef read + bits[1] = '0' # reserved + bits[0] = '0' # reserved + return f"{int(''.join(bits), 2):04x}" + + +# ---- Response dataclass -------------------------------------------------- + +class Response: + """Decoded device DATA response.""" + __slots__ = [ + 'current1', 'current2', + 'temp1', 'temp2', + 'temp_ext1', 'temp_ext2', + 'voltage_3v3', 'voltage_5v1', 'voltage_5v2', 'voltage_7v0', + 'to6_lsb', 'to6_msb', + 'message_id', + 'header', + ] + + def to_measurements(self) -> Measurements: + return Measurements( + current1=self.current1, + current2=self.current2, + temp1=self.temp1, + temp2=self.temp2, + temp_ext1=self.temp_ext1, + temp_ext2=self.temp_ext2, + voltage_3v3=self.voltage_3v3, + voltage_5v1=self.voltage_5v1, + voltage_5v2=self.voltage_5v2, + voltage_7v0=self.voltage_7v0, + timestamp=datetime.now(), + message_id=self.message_id, + to6_counter_lsb=self.to6_lsb, + to6_counter_msb=self.to6_msb, + ) + + +# ---- Message builder -------------------------------------------------- + +class Message: + """Named container for an encoded command byte array.""" + def __init__(self, data: bytearray): + self._data = data + + def to_bytes(self) -> bytes: + return bytes(self._data) + + def __len__(self): + return len(self._data) + + +# ---- Protocol class -------------------------------------------------- + +class Protocol: + """ + Encodes commands and decodes responses for the laser control board. + + Can also manage a serial port connection when port is provided. + """ + + def __init__(self, port: Optional[str] = None): + self._port_name = port + self._serial: Optional[serial.Serial] = None + + # ---- Connection management + + def connect(self) -> None: + """Open the serial port. Auto-detects if port is None.""" + port = self._port_name or self._detect_port() + try: + self._serial = serial.Serial( + port=port, + baudrate=BAUDRATE, + timeout=SERIAL_TIMEOUT_SEC, + ) + except Exception as exc: + raise CommunicationError( + f"Cannot connect to port '{port}': {exc}" + ) from exc + + def disconnect(self) -> None: + """Close the serial port if open.""" + if self._serial and self._serial.is_open: + self._serial.close() + + @property + def is_connected(self) -> bool: + return self._serial is not None and self._serial.is_open + + def _detect_port(self) -> str: + """Return first available serial port device path.""" + ports = list(serial.tools.list_ports.comports()) + if not ports: + raise PortNotFoundError() + return ports[0].device + + # ---- Raw I/O + + def send_raw(self, data: bytes) -> None: + self._serial.write(data) + + def receive_raw(self, length: int) -> bytes: + return self._serial.read(length) + + # ---- Static encoding helpers (no connection required) --------------- + + @staticmethod + def flipfour(value: int) -> int: + """Byte-swap a 16-bit integer (little-endian word swap).""" + return ((value & 0xFF) << 8) | ((value >> 8) & 0xFF) + + @staticmethod + def pack_float(value: float) -> bytes: + return struct.pack(' bytes: + return struct.pack(' int: + """ + XOR CRC over all 16-bit words except the last two bytes (CRC field). + Mirrors the original CalculateCRC logic. + """ + hex_str = data.hex() + words = [hex_str[i:i+4] for i in range(0, len(hex_str), 4)] + # Skip word 0 (command code) per original firmware expectation + crc_words = words[1:] + result = int(crc_words[0], 16) + for w in crc_words[1:]: + result ^= int(w, 16) + return result + + # ---- Command encoders ----------------------------------------------- + + @staticmethod + def encode_decode_enable( + temp1: float, + temp2: float, + current1: float, + current2: float, + pi_coeff1_p: int, + pi_coeff1_i: int, + pi_coeff2_p: int, + pi_coeff2_i: int, + message_id: int, + ) -> bytes: + """ + Build DECODE_ENABLE command (0x1111). + + Sets temperature and current setpoints for both lasers. + Returns 30-byte bytearray. + """ + if current1 < 0 or current2 < 0: + raise ValueError("Current values must not be negative") + + data = _flipfour(_int_to_hex4(CMD_DECODE_ENABLE)) # Word 0 + data += _flipfour(_encode_setup()) # Word 1 + data += _flipfour(_int_to_hex4(temp_c_to_n(temp1))) # Word 2 + data += _flipfour(_int_to_hex4(temp_c_to_n(temp2))) # Word 3 + data += _flipfour('0000') * 3 # Words 4-6 + data += _flipfour(_int_to_hex4(pi_coeff1_p)) # Word 7 + data += _flipfour(_int_to_hex4(pi_coeff1_i)) # Word 8 + data += _flipfour(_int_to_hex4(pi_coeff2_p)) # Word 9 + data += _flipfour(_int_to_hex4(pi_coeff2_i)) # Word 10 + data += _flipfour(_int_to_hex4(message_id & 0xFFFF)) # Word 11 + data += _flipfour(_int_to_hex4(current_ma_to_n(current1))) # Word 12 + data += _flipfour(_int_to_hex4(current_ma_to_n(current2))) # Word 13 + data += _build_crc(data) # Word 14 + + result = bytearray.fromhex(data) + assert len(result) == SEND_PARAMS_TOTAL_LENGTH, \ + f"DECODE_ENABLE length mismatch: {len(result)}" + return bytes(result) + + @staticmethod + def encode_task_enable( + task_type: TaskType, + static_temp1: float, + static_temp2: float, + static_current1: float, + static_current2: float, + min_value: float, + max_value: float, + step: float, + time_step: int, + delay_time: int, + message_id: int, + pi_coeff1_p: int = 1, + pi_coeff1_i: int = 1, + pi_coeff2_p: int = 1, + pi_coeff2_i: int = 1, + ) -> bytes: + """ + Build TASK_ENABLE command (0x7777). + + Starts a measurement task (current or temperature variation). + Returns 32-byte bytearray. + """ + if not isinstance(task_type, TaskType): + try: + task_type = TaskType(task_type) + except ValueError: + raise ValueError(f"Invalid task_type: {task_type}") + + data = _flipfour(_int_to_hex4(CMD_TASK_ENABLE)) # Word 0 + data += _flipfour(_encode_setup()) # Word 1 + data += _flipfour(_int_to_hex4(task_type.value)) # Word 2 + + match task_type: + case TaskType.CHANGE_CURRENT_LD1: + data += _flipfour(_int_to_hex4(current_ma_to_n(min_value))) # Word 3 + data += _flipfour(_int_to_hex4(current_ma_to_n(max_value))) # Word 4 + data += _flipfour(_int_to_hex4(current_ma_to_n(step))) # Word 5 + data += _flipfour(_int_to_hex4(int(delay_time * 100))) # Word 6 + data += _flipfour(_int_to_hex4(temp_c_to_n(static_temp1))) # Word 7 + data += _flipfour(_int_to_hex4(current_ma_to_n(static_current2)))# Word 8 + data += _flipfour(_int_to_hex4(temp_c_to_n(static_temp2))) # Word 9 + case TaskType.CHANGE_CURRENT_LD2: + data += _flipfour(_int_to_hex4(current_ma_to_n(min_value))) # Word 3 + data += _flipfour(_int_to_hex4(current_ma_to_n(max_value))) # Word 4 + data += _flipfour(_int_to_hex4(int(step * 100))) # Word 5 + data += _flipfour(_int_to_hex4(int(delay_time * 100))) # Word 6 + data += _flipfour(_int_to_hex4(temp_c_to_n(static_temp2))) # Word 7 + data += _flipfour(_int_to_hex4(current_ma_to_n(static_current1)))# Word 8 + data += _flipfour(_int_to_hex4(temp_c_to_n(static_temp1))) # Word 9 + case TaskType.CHANGE_TEMPERATURE_LD1 | TaskType.CHANGE_TEMPERATURE_LD2: + raise NotImplementedError("Temperature variation is not yet implemented in firmware") + case _: + raise ValueError(f"Unsupported task type: {task_type}") + + data += _flipfour(_int_to_hex4(time_step)) # Word 10 + data += _flipfour(_int_to_hex4(pi_coeff1_p)) # Word 11 + data += _flipfour(_int_to_hex4(pi_coeff1_i)) # Word 12 + data += _flipfour(_int_to_hex4(pi_coeff2_p)) # Word 13 + data += _flipfour(_int_to_hex4(pi_coeff2_i)) # Word 14 + data += _build_crc(data) # Word 15 + + result = bytearray.fromhex(data) + assert len(result) == TASK_ENABLE_COMMAND_LENGTH, \ + f"TASK_ENABLE length mismatch: {len(result)}" + return bytes(result) + + @staticmethod + def encode_trans_enable(message_id: int = 0) -> bytes: + """Build TRANS_ENABLE command (0x4444) — request last data.""" + return bytearray.fromhex(_flipfour(_int_to_hex4(CMD_TRANS_ENABLE))) + + @staticmethod + def encode_state(message_id: int = 0) -> bytes: + """Build STATE command (0x6666) — request device state.""" + return bytearray.fromhex(_flipfour(_int_to_hex4(CMD_STATE))) + + @staticmethod + def encode_default_enable(message_id: int = 0) -> bytes: + """Build DEFAULT_ENABLE command (0x2222) — reset device.""" + return bytearray.fromhex(_flipfour(_int_to_hex4(CMD_DEFAULT_ENABLE))) + + @staticmethod + def encode_remove_file() -> bytes: + """Build REMOVE_FILE command (0x5555) — delete saved data.""" + return bytearray.fromhex(_flipfour(_int_to_hex4(CMD_REMOVE_FILE))) + + # ---- Response decoders ----------------------------------------------- + + @staticmethod + def decode_response(data: bytes) -> Response: + """ + Decode a 30-byte DATA response from the device. + + Raises: + ProtocolError: If data length is wrong. + CRCError: If CRC check fails. + """ + if len(data) != GET_DATA_TOTAL_LENGTH: + raise ProtocolError( + f"Expected {GET_DATA_TOTAL_LENGTH} bytes, got {len(data)} bytes" + ) + + hex_str = data.hex() + + def get_word(num: int) -> str: + return _flipfour(hex_str[num*4: num*4+4]) + + def get_int_word(num: int) -> int: + return int(get_word(num), 16) + + # CRC check: XOR over words 1..13, compare with word 14 + crc_words = [hex_str[i:i+4] for i in range(4, len(hex_str)-4, 4)] + computed = int(crc_words[0], 16) + for w in crc_words[1:]: + computed ^= int(w, 16) + stored = get_int_word(14) + if computed != stored: + raise CRCError(expected=computed, received=stored) + + resp = Response() + resp.header = get_word(0) + resp.current1 = current_n_to_ma(get_int_word(1)) + resp.current2 = current_n_to_ma(get_int_word(2)) + resp.to6_lsb = get_int_word(3) + resp.to6_msb = get_int_word(4) + resp.temp1 = temp_n_to_c(get_int_word(5)) + resp.temp2 = temp_n_to_c(get_int_word(6)) + resp.temp_ext1 = temp_ext_n_to_c(get_int_word(7)) + resp.temp_ext2 = temp_ext_n_to_c(get_int_word(8)) + resp.voltage_3v3 = voltage_3v3_n_to_v(get_int_word(9)) + resp.voltage_5v1 = voltage_5v_n_to_v(get_int_word(10)) + resp.voltage_5v2 = voltage_5v_n_to_v(get_int_word(11)) + resp.voltage_7v0 = voltage_7v_n_to_v(get_int_word(12)) + resp.message_id = get_int_word(13) + + return resp + + @staticmethod + def decode_state(data: bytes) -> int: + """ + Decode a 2-byte STATE response from the device. + + Returns: + Integer state code (compare with DeviceState enum). + """ + if len(data) < 2: + raise ProtocolError(f"STATE response too short: {len(data)} bytes") + hex_str = data.hex() + state_hex = _flipfour(hex_str[0:4]) + return int(state_hex, 16) + + @staticmethod + def state_to_description(state_hex_str: str) -> str: + """Return human-readable description for a state hex string.""" + return STATE_DESCRIPTIONS.get(state_hex_str, "Unknown or reserved error.") \ No newline at end of file diff --git a/laser_control/validators.py b/laser_control/validators.py new file mode 100644 index 0000000..98aa288 --- /dev/null +++ b/laser_control/validators.py @@ -0,0 +1,257 @@ +""" +Parameter validation for laser control module. + +Validates all input parameters against physical constraints +and protocol limits before sending to device. +""" + +import math +from typing import Dict, Any, Tuple + +from .constants import ( + TEMP_MIN_C, TEMP_MAX_C, + CURRENT_MIN_MA, CURRENT_MAX_MA, + CURRENT_STEP_MIN_MA, CURRENT_STEP_MAX_MA, + TEMP_STEP_MIN_C, TEMP_STEP_MAX_C, + TIME_STEP_MIN_US, TIME_STEP_MAX_US, + DELAY_TIME_MIN_MS, DELAY_TIME_MAX_MS, +) +from .exceptions import ( + ValidationError, + TemperatureOutOfRangeError, + CurrentOutOfRangeError, + InvalidParameterError, +) +from .models import VariationType + + +class ParameterValidator: + """Validates all input parameters for the laser controller.""" + + @staticmethod + def _check_numeric(value: Any, param_name: str) -> float: + """Check that value is a valid finite number. Returns float.""" + if value is None: + raise InvalidParameterError(param_name, "Value must not be None") + if not isinstance(value, (int, float)): + raise InvalidParameterError(param_name, "Value must be a number") + if math.isnan(value): + raise InvalidParameterError(param_name, "Value must not be NaN") + if math.isinf(value): + raise InvalidParameterError(param_name, "Value must not be infinite") + return float(value) + + @staticmethod + def validate_temperature(value: Any, param_name: str) -> float: + """ + Validate a laser temperature value. + + Args: + value: Temperature in °C. + param_name: Parameter name for error messages. + + Returns: + Validated temperature as float. + + Raises: + InvalidParameterError: If value is not a valid number. + TemperatureOutOfRangeError: If value is outside [TEMP_MIN_C, TEMP_MAX_C]. + """ + value = ParameterValidator._check_numeric(value, param_name) + if value < TEMP_MIN_C or value > TEMP_MAX_C: + raise TemperatureOutOfRangeError( + param_name, value, TEMP_MIN_C, TEMP_MAX_C + ) + return value + + @staticmethod + def validate_current(value: Any, param_name: str) -> float: + """ + Validate a laser drive current value. + + Args: + value: Current in mA. + param_name: Parameter name for error messages. + + Returns: + Validated current as float. + + Raises: + InvalidParameterError: If value is not a valid number. + CurrentOutOfRangeError: If value is outside [CURRENT_MIN_MA, CURRENT_MAX_MA]. + """ + value = ParameterValidator._check_numeric(value, param_name) + if value < CURRENT_MIN_MA or value > CURRENT_MAX_MA: + raise CurrentOutOfRangeError( + param_name, value, CURRENT_MIN_MA, CURRENT_MAX_MA + ) + return value + + @staticmethod + def validate_time_params(time_step: Any, delay_time: Any) -> Tuple[int, int]: + """ + Validate time parameters for variation mode. + + Args: + time_step: Discretisation time step in microseconds. + delay_time: Delay between pulses in milliseconds. + + Returns: + Tuple (time_step, delay_time) as integers. + + Raises: + InvalidParameterError: If values are not numeric. + ValidationError: If values are outside allowed ranges. + """ + if not isinstance(time_step, (int, float)): + raise InvalidParameterError("time_step", "Value must be a number") + if not isinstance(delay_time, (int, float)): + raise InvalidParameterError("delay_time", "Value must be a number") + + time_step_int = int(time_step) + delay_time_int = int(delay_time) + + if time_step_int < TIME_STEP_MIN_US or time_step_int > TIME_STEP_MAX_US: + raise ValidationError( + f"time step {time_step_int} µs is out of range " + f"[{TIME_STEP_MIN_US} - {TIME_STEP_MAX_US}] µs" + ) + if delay_time_int < DELAY_TIME_MIN_MS or delay_time_int > DELAY_TIME_MAX_MS: + raise ValidationError( + f"delay time {delay_time_int} ms is out of range " + f"[{DELAY_TIME_MIN_MS} - {DELAY_TIME_MAX_MS}] ms" + ) + return time_step_int, delay_time_int + + @staticmethod + def validate_variation_params( + params: Dict[str, Any], + variation_type: Any + ) -> Dict[str, Any]: + """ + Validate parameters for variation mode. + + Args: + params: Dictionary with keys: + min_value, max_value, step, time_step, delay_time. + variation_type: A VariationType enum value. + + Returns: + Dictionary with validated and type-coerced values. + + Raises: + ValidationError: For any constraint violation. + InvalidParameterError: For wrong types. + """ + # Validate variation type + if not isinstance(variation_type, VariationType): + try: + variation_type = VariationType(variation_type) + except (ValueError, KeyError): + raise ValidationError( + f"Invalid variation type '{variation_type}'. " + f"Must be one of {[e.name for e in VariationType]}" + ) + + # Check required keys + required_keys = {'min_value', 'max_value', 'step', 'time_step', 'delay_time'} + missing = required_keys - params.keys() + if missing: + raise ValidationError( + f"Missing required parameters: {sorted(missing)}" + ) + + # Validate min/max + min_val = ParameterValidator._check_numeric(params['min_value'], 'min_value') + max_val = ParameterValidator._check_numeric(params['max_value'], 'max_value') + + if min_val >= max_val: + raise ValidationError( + f"min_value ({min_val}) must be less than max_value ({max_val})" + ) + + # Validate step based on variation type + step = ParameterValidator._check_numeric(params['step'], 'step') + + is_current_variation = variation_type in ( + VariationType.CHANGE_CURRENT_LD1, + VariationType.CHANGE_CURRENT_LD2 + ) + is_temp_variation = variation_type in ( + VariationType.CHANGE_TEMPERATURE_LD1, + VariationType.CHANGE_TEMPERATURE_LD2 + ) + + if is_current_variation: + step_min, step_max = CURRENT_STEP_MIN_MA, CURRENT_STEP_MAX_MA + unit = "mA" + # Also validate range against current limits + ParameterValidator.validate_current(min_val, 'min_value') + ParameterValidator.validate_current(max_val, 'max_value') + elif is_temp_variation: + step_min, step_max = TEMP_STEP_MIN_C, TEMP_STEP_MAX_C + unit = "°C" + # Also validate range against temperature limits + ParameterValidator.validate_temperature(min_val, 'min_value') + ParameterValidator.validate_temperature(max_val, 'max_value') + else: + raise ValidationError( + f"Variation type {variation_type.name} cannot be used in variation mode" + ) + + if step <= 0: + raise ValidationError( + f"step must be positive, got {step} {unit}" + ) + if step < step_min: + raise ValidationError( + f"step {step} {unit} is too small (minimum {step_min} {unit})" + ) + if step > step_max: + raise ValidationError( + f"step {step} {unit} is too large (maximum {step_max} {unit})" + ) + + # Validate time parameters + time_step, delay_time = ParameterValidator.validate_time_params( + params['time_step'], params['delay_time'] + ) + + return { + 'variation_type': variation_type, + 'min_value': min_val, + 'max_value': max_val, + 'step': step, + 'time_step': time_step, + 'delay_time': delay_time, + } + + @staticmethod + def validate_manual_mode_params( + temp1: Any, + temp2: Any, + current1: Any, + current2: Any, + ) -> Dict[str, float]: + """ + Validate all four manual mode parameters. + + Args: + temp1: Laser 1 temperature, °C. + temp2: Laser 2 temperature, °C. + current1: Laser 1 current, mA. + current2: Laser 2 current, mA. + + Returns: + Dict with validated floats: temp1, temp2, current1, current2. + + Raises: + ValidationError: For any out-of-range value. + InvalidParameterError: For wrong types. + """ + return { + 'temp1': ParameterValidator.validate_temperature(temp1, 'temp1'), + 'temp2': ParameterValidator.validate_temperature(temp2, 'temp2'), + 'current1': ParameterValidator.validate_current(current1, 'current1'), + 'current2': ParameterValidator.validate_current(current2, 'current2'), + } \ No newline at end of file diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..2527de3 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,109 @@ +""" +Shared fixtures for laser_control tests. +""" + +import pytest +import struct +from unittest.mock import MagicMock, patch +from laser_control.protocol import Protocol, _build_crc, _flipfour, _int_to_hex4 +from laser_control.controller import LaserController +from laser_control.conversions import ( + current_n_to_ma, temp_n_to_c, temp_ext_n_to_c, + voltage_3v3_n_to_v, voltage_5v_n_to_v, voltage_7v_n_to_v, +) + + +def make_valid_response( + current1_n: int = 10000, + current2_n: int = 12000, + temp1_n: int = 30000, + temp2_n: int = 32000, + temp_ext1_n: int = 2048, + temp_ext2_n: int = 2048, + mon_3v3_n: int = 2703, # ~3.3V + mon_5v1_n: int = 2731, # ~5.0V + mon_5v2_n: int = 2731, + mon_7v0_n: int = 1042, # ~7.0V + message_id: int = 12345, +) -> bytes: + """ + Build a syntactically valid 30-byte DATA response. + + Words (each 2 bytes, little-endian via flipfour): + 0 header + 1 I1 + 2 I2 + 3 TO6_LSB + 4 TO6_MSB + 5 Temp_1 + 6 Temp_2 + 7 Temp_Ext_1 + 8 Temp_Ext_2 + 9 MON_3V3 + 10 MON_5V1 + 11 MON_5V2 + 12 MON_7V0 + 13 Message_ID + 14 CRC + """ + words_raw = [ + 0xABCD, # Word 0 header + current1_n, # Word 1 + current2_n, # Word 2 + 0, # Word 3 TO6_LSB + 0, # Word 4 TO6_MSB + temp1_n, # Word 5 + temp2_n, # Word 6 + temp_ext1_n, # Word 7 + temp_ext2_n, # Word 8 + mon_3v3_n, # Word 9 + mon_5v1_n, # Word 10 + mon_5v2_n, # Word 11 + mon_7v0_n, # Word 12 + message_id, # Word 13 + 0, # Word 14 CRC placeholder + ] + + # Build hex string with flipfour applied + hex_str = "" + for w in words_raw: + hex_str += _flipfour(_int_to_hex4(w)) + + # Compute CRC over words 1..13 (indices 4..55 in hex, i.e. skip word 0) + words_hex = [hex_str[i:i+4] for i in range(0, len(hex_str), 4)] + crc_words = words_hex[1:14] # words 1..13 + crc_val = int(crc_words[0], 16) + for w in crc_words[1:]: + crc_val ^= int(w, 16) + + # Replace CRC word + hex_str = hex_str[:56] + _flipfour(_int_to_hex4(crc_val)) + return bytes.fromhex(hex_str) + + +@pytest.fixture +def valid_response_bytes(): + """Pre-built valid 30-byte device response.""" + return make_valid_response() + + +@pytest.fixture +def mock_serial(): + """Mock serial.Serial object.""" + with patch('serial.Serial') as mock_cls: + mock_instance = MagicMock() + mock_instance.is_open = True + mock_cls.return_value = mock_instance + yield mock_instance + + +@pytest.fixture +def connected_controller(mock_serial): + """LaserController with mocked serial connection.""" + mock_serial.read.return_value = make_valid_response() + + ctrl = LaserController(port='/dev/ttyUSB0') + with patch('serial.Serial', return_value=mock_serial): + ctrl._protocol._serial = mock_serial + mock_serial.is_open = True + return ctrl \ No newline at end of file diff --git a/tests/test_integration.py b/tests/test_integration.py new file mode 100644 index 0000000..1a92731 --- /dev/null +++ b/tests/test_integration.py @@ -0,0 +1,294 @@ +""" +Integration tests for the laser control module. + +Tests the full call chain: LaserController → Protocol → Serial, +using mock serial ports. No real hardware required. +""" + +import pytest +import time +from unittest.mock import MagicMock, patch, call +from laser_control.controller import LaserController +from laser_control.models import VariationType, DeviceState +from laser_control.exceptions import ( + ValidationError, + CommunicationError, + TemperatureOutOfRangeError, + CurrentOutOfRangeError, +) +from laser_control.protocol import Protocol, CommandCode +from .conftest import make_valid_response + + +class TestManualModeIntegration: + """Integration tests for manual mode operation.""" + + def test_full_manual_mode_flow(self, connected_controller, mock_serial): + """Test complete manual mode command flow.""" + connected_controller.set_manual_mode( + temp1=25.0, temp2=30.0, + current1=40.0, current2=35.0 + ) + + # Verify command was sent + assert mock_serial.write.called + sent_data = mock_serial.write.call_args[0][0] + assert len(sent_data) == 30 # SEND_PARAMS_TOTAL_LENGTH + + # Verify command code (bytes 0-1, little-endian 0x1111 → 0x11 0x11) + assert sent_data[0] == 0x11 + assert sent_data[1] == 0x11 + + def test_manual_mode_validation_rejects_invalid_temp(self, connected_controller): + """Test that manual mode rejects out-of-range temperature.""" + with pytest.raises(TemperatureOutOfRangeError) as exc_info: + connected_controller.set_manual_mode( + temp1=50.0, # Too high + temp2=30.0, + current1=40.0, + current2=35.0 + ) + assert "temp1" in str(exc_info.value) + assert "50.0" in str(exc_info.value) + + def test_manual_mode_validation_rejects_invalid_current(self, connected_controller): + """Test that manual mode rejects out-of-range current.""" + with pytest.raises(CurrentOutOfRangeError) as exc_info: + connected_controller.set_manual_mode( + temp1=25.0, + temp2=30.0, + current1=40.0, + current2=70.0 # Too high + ) + assert "current2" in str(exc_info.value) + + def test_manual_mode_no_serial_call_on_validation_failure( + self, connected_controller, mock_serial + ): + """Serial write must not be called when validation fails.""" + mock_serial.write.reset_mock() + with pytest.raises(ValidationError): + connected_controller.set_manual_mode( + temp1=5.0, # Invalid + temp2=30.0, + current1=40.0, + current2=35.0 + ) + mock_serial.write.assert_not_called() + + def test_message_id_increments(self, connected_controller, mock_serial): + """Message ID should increment with each command.""" + initial_id = connected_controller._message_id + connected_controller.set_manual_mode(25.0, 30.0, 40.0, 35.0) + assert connected_controller._message_id == (initial_id + 1) & 0xFFFF + + connected_controller.set_manual_mode(26.0, 31.0, 41.0, 36.0) + assert connected_controller._message_id == (initial_id + 2) & 0xFFFF + + +class TestVariationModeIntegration: + """Integration tests for variation mode operation.""" + + def test_current_ld1_variation_flow(self, connected_controller, mock_serial): + """Test complete current variation for laser 1.""" + params = { + 'min_value': 20.0, + 'max_value': 50.0, + 'step': 0.5, + 'time_step': 50, + 'delay_time': 5, + 'static_temp1': 25.0, + 'static_temp2': 30.0, + 'static_current1': 35.0, + 'static_current2': 35.0, + } + connected_controller.start_variation(VariationType.CHANGE_CURRENT_LD1, params) + + assert mock_serial.write.called + sent_data = mock_serial.write.call_args[0][0] + assert len(sent_data) == 32 # TASK_ENABLE_COMMAND_LENGTH + + # Verify command code (0x7777) + assert sent_data[0] == 0x77 + assert sent_data[1] == 0x77 + + def test_current_ld2_variation_flow(self, connected_controller, mock_serial): + """Test complete current variation for laser 2.""" + params = { + 'min_value': 20.0, + 'max_value': 50.0, + 'step': 0.5, + 'time_step': 50, + 'delay_time': 5, + 'static_temp1': 25.0, + 'static_temp2': 30.0, + 'static_current1': 35.0, + 'static_current2': 35.0, + } + connected_controller.start_variation(VariationType.CHANGE_CURRENT_LD2, params) + assert mock_serial.write.called + + def test_variation_rejects_invalid_step(self, connected_controller, mock_serial): + """Variation must reject step below minimum.""" + mock_serial.write.reset_mock() + params = { + 'min_value': 20.0, + 'max_value': 50.0, + 'step': 0.001, # Too small + 'time_step': 50, + 'delay_time': 5, + 'static_temp1': 25.0, + 'static_temp2': 30.0, + 'static_current1': 35.0, + 'static_current2': 35.0, + } + with pytest.raises(ValidationError): + connected_controller.start_variation(VariationType.CHANGE_CURRENT_LD1, params) + mock_serial.write.assert_not_called() + + def test_variation_rejects_inverted_range(self, connected_controller): + """Variation must reject min > max.""" + params = { + 'min_value': 50.0, # min > max + 'max_value': 20.0, + 'step': 0.5, + 'time_step': 50, + 'delay_time': 5, + 'static_temp1': 25.0, + 'static_temp2': 30.0, + 'static_current1': 35.0, + 'static_current2': 35.0, + } + with pytest.raises(ValidationError) as exc_info: + connected_controller.start_variation(VariationType.CHANGE_CURRENT_LD1, params) + assert "min" in str(exc_info.value).lower() + + +class TestMeasurementsIntegration: + """Integration tests for measurement retrieval.""" + + def test_get_measurements_returns_data(self, connected_controller, mock_serial): + """get_measurements should decode and return device data.""" + mock_serial.read.return_value = make_valid_response() + measurements = connected_controller.get_measurements() + + assert measurements is not None + assert isinstance(measurements.current1, float) + assert isinstance(measurements.current2, float) + assert isinstance(measurements.temp1, float) + assert isinstance(measurements.temp2, float) + assert isinstance(measurements.voltage_3v3, float) + + def test_get_measurements_calls_callback(self, mock_serial): + """on_data callback should be triggered on new measurements.""" + received = [] + mock_serial.read.return_value = make_valid_response() + mock_serial.is_open = True + + ctrl = LaserController( + port='/dev/ttyUSB0', + on_data=lambda m: received.append(m) + ) + ctrl._protocol._serial = mock_serial + + ctrl.get_measurements() + assert len(received) == 1 + assert received[0].voltage_3v3 > 0 + + def test_get_measurements_no_data(self, connected_controller, mock_serial): + """get_measurements returns None when no data received.""" + mock_serial.read.return_value = b'' + result = connected_controller.get_measurements() + assert result is None + + def test_voltage_rail_check(self, connected_controller, mock_serial): + """Test power rail health check on measurements.""" + mock_serial.read.return_value = make_valid_response( + mon_3v3_n=2703, # ~3.3V + mon_5v1_n=2731, # ~5.0V + mon_5v2_n=2731, + mon_7v0_n=1042, # ~7.0V + ) + measurements = connected_controller.get_measurements() + if measurements: + rails = measurements.check_power_rails() + assert isinstance(rails, dict) + assert '3v3' in rails + assert '5v1' in rails + assert '5v2' in rails + assert '7v0' in rails + + +class TestConnectionManagement: + """Integration tests for connection handling.""" + + def test_context_manager(self, mock_serial): + """Test using LaserController as context manager.""" + mock_serial.is_open = True + with patch('serial.Serial', return_value=mock_serial): + with LaserController(port='/dev/ttyUSB0') as ctrl: + assert ctrl.is_connected + mock_serial.close.assert_called() + + def test_send_without_connection_raises(self): + """Sending command without connection raises CommunicationError.""" + ctrl = LaserController(port='/dev/ttyUSB0') + # Don't call connect() + with pytest.raises(CommunicationError) as exc_info: + ctrl.set_manual_mode(25.0, 30.0, 40.0, 35.0) + assert "connect" in str(exc_info.value).lower() + + def test_stop_task_sends_default_enable(self, connected_controller, mock_serial): + """stop_task should send DEFAULT_ENABLE (0x2222).""" + mock_serial.write.reset_mock() + connected_controller.stop_task() + + assert mock_serial.write.called + sent_data = mock_serial.write.call_args[0][0] + # DEFAULT_ENABLE: 0x2222 → flipped to bytes 0x22 0x22 + assert sent_data[0] == 0x22 + assert sent_data[1] == 0x22 + + def test_reset_sends_default_enable(self, connected_controller, mock_serial): + """reset() should also send DEFAULT_ENABLE.""" + mock_serial.write.reset_mock() + connected_controller.reset() + assert mock_serial.write.called + + +class TestConversionsRoundtrip: + """Test that physical unit conversions are self-consistent.""" + + def test_temperature_roundtrip(self): + """temp_c_to_n and temp_n_to_c should be inverse of each other.""" + from laser_control.conversions import temp_c_to_n, temp_n_to_c + for temp in [15.0, 20.0, 25.0, 30.0, 35.0, 40.0]: + n = temp_c_to_n(temp) + recovered = temp_n_to_c(n) + assert abs(recovered - temp) < 0.05, \ + f"Temperature roundtrip failed for {temp}°C: got {recovered}°C" + + def test_current_roundtrip(self): + """current_ma_to_n and current_n_to_ma should be approximately inverse.""" + from laser_control.conversions import current_ma_to_n, current_n_to_ma + # Note: current_n_to_ma is for photodiode readback, not exact inverse + # of current_ma_to_n (different calibration paths). + # We just test that values are in plausible range. + for current in [15.0, 30.0, 45.0, 60.0]: + n = current_ma_to_n(current) + assert 0 <= n <= 65535 + + def test_voltage_conversions_nominal(self): + """Test voltage conversions at nominal counts.""" + from laser_control.conversions import ( + voltage_3v3_n_to_v, voltage_5v_n_to_v, voltage_7v_n_to_v + ) + # Approximate nominal ADC counts for each rail + v33 = voltage_3v3_n_to_v(2703) + assert 3.1 <= v33 <= 3.5, f"3.3V rail: {v33}" + + v5 = voltage_5v_n_to_v(2731) + assert 4.8 <= v5 <= 5.3, f"5V rail: {v5}" + + v7 = voltage_7v_n_to_v(1042) + assert 6.5 <= v7 <= 7.5, f"7V rail: {v7}" \ No newline at end of file diff --git a/tests/test_protocol.py b/tests/test_protocol.py new file mode 100644 index 0000000..1362fec --- /dev/null +++ b/tests/test_protocol.py @@ -0,0 +1,345 @@ +""" +Tests for communication protocol module. + +Testing command encoding/decoding, CRC calculations, +and protocol message structure. +""" + +import pytest +from unittest.mock import Mock, MagicMock, patch, call +import struct +from laser_control.protocol import ( + Protocol, + CommandCode, + TaskType, + Message, + Response +) +from laser_control.exceptions import ( + CommunicationError, + CRCError, + ProtocolError +) + + +class TestCRCCalculation: + """Test CRC calculation and verification.""" + + def test_crc_calculation(self): + """Test CRC calculation for known data (at least 2 words needed).""" + # calculate_crc skips word 0 and XORs words 1..N + # So we need at least 4 bytes (2 words) + data = b'\x00\x01\x02\x03\x04\x05\x06\x07' + crc = Protocol.calculate_crc(data) + assert isinstance(crc, int) + assert 0 <= crc <= 0xFFFF + + def test_crc_consistency(self): + """Test CRC calculation consistency.""" + data = b'\x11\x11' + b'\x00' * 26 + b'\xFF\xFF' # 30 bytes + crc1 = Protocol.calculate_crc(data) + crc2 = Protocol.calculate_crc(data) + assert crc1 == crc2 + + def test_crc_different_data(self): + """Test CRC differs for different data.""" + data1 = b'\x00\x00\x01\x02\x03\x04' + data2 = b'\x00\x00\x05\x06\x07\x08' + crc1 = Protocol.calculate_crc(data1) + crc2 = Protocol.calculate_crc(data2) + assert crc1 != crc2 + + +class TestMessageEncoding: + """Test message encoding for device commands.""" + + def test_encode_decode_enable_command(self): + """Test encoding DECODE_ENABLE command.""" + message = Protocol.encode_decode_enable( + temp1=25.5, + temp2=30.0, + current1=40.0, + current2=35.0, + pi_coeff1_p=1, + pi_coeff1_i=1, + pi_coeff2_p=1, + pi_coeff2_i=1, + message_id=12345 + ) + + assert isinstance(message, bytes) + assert len(message) == 30 # Expected message length + + # Check command code (0x1111 stored little-endian via flipfour → 0x11 0x11) + assert message[0] == 0x11 + assert message[1] == 0x11 + + def test_encode_task_enable_command(self): + """Test encoding TASK_ENABLE command.""" + message = Protocol.encode_task_enable( + task_type=TaskType.CHANGE_CURRENT_LD1, + static_temp1=25.0, + static_temp2=30.0, + static_current1=40.0, + static_current2=35.0, + min_value=20.0, + max_value=50.0, + step=0.5, + time_step=50, + delay_time=5, + message_id=54321 + ) + + assert isinstance(message, bytes) + assert len(message) > 0 + + # Check command code + command = struct.unpack(' 0xFFFF should wrap (& 0xFFFF in controller) + large_id = 0x10000 + 123 + wrapped = large_id & 0xFFFF + message = Protocol.encode_decode_enable( + temp1=25.0, temp2=30.0, + current1=40.0, current2=35.0, + pi_coeff1_p=1, pi_coeff1_i=1, + pi_coeff2_p=1, pi_coeff2_i=1, + message_id=wrapped, + ) + assert isinstance(message, bytes) + assert len(message) == 30 + + def test_negative_values_handling(self): + """Test handling of negative values where not allowed.""" + with pytest.raises(ValueError): + Protocol.encode_decode_enable( + temp1=25.0, + temp2=30.0, + current1=-10.0, # Negative current + current2=35.0, + pi_coeff1_p=1.0, + pi_coeff1_i=0.5, + pi_coeff2_p=1.0, + pi_coeff2_i=0.5, + message_id=12345 + ) \ No newline at end of file diff --git a/tests/test_validation.py b/tests/test_validation.py new file mode 100644 index 0000000..d1b5c29 --- /dev/null +++ b/tests/test_validation.py @@ -0,0 +1,383 @@ +""" +Tests for parameter validation module. + +Testing validation of all input parameters with boundary conditions, +invalid types, and edge cases. +""" + +import pytest +import math +from laser_control.validators import ParameterValidator +from laser_control.exceptions import ( + ValidationError, + TemperatureOutOfRangeError, + CurrentOutOfRangeError, + InvalidParameterError +) +from laser_control.models import VariationType + + +class TestTemperatureValidation: + """Test temperature parameter validation.""" + + def test_valid_temperature_range(self): + """Test temperatures within valid range.""" + # Valid temperatures should pass + assert ParameterValidator.validate_temperature(15.0, "temp1") == 15.0 + assert ParameterValidator.validate_temperature(25.5, "temp2") == 25.5 + assert ParameterValidator.validate_temperature(40.0, "temp1") == 40.0 + + def test_temperature_below_minimum(self): + """Test temperature below minimum threshold.""" + with pytest.raises(TemperatureOutOfRangeError) as exc_info: + ParameterValidator.validate_temperature(10.0, "temp1") + assert "temp1" in str(exc_info.value) + assert "15.0" in str(exc_info.value) # min value + + def test_temperature_above_maximum(self): + """Test temperature above maximum threshold.""" + with pytest.raises(TemperatureOutOfRangeError) as exc_info: + ParameterValidator.validate_temperature(45.0, "temp2") + assert "temp2" in str(exc_info.value) + assert "40.0" in str(exc_info.value) # max value + + def test_temperature_invalid_type(self): + """Test invalid temperature type.""" + with pytest.raises(InvalidParameterError) as exc_info: + ParameterValidator.validate_temperature("invalid", "temp1") + assert "temp1" in str(exc_info.value) + assert "number" in str(exc_info.value).lower() + + def test_temperature_nan_value(self): + """Test NaN temperature value.""" + with pytest.raises(InvalidParameterError) as exc_info: + ParameterValidator.validate_temperature(float('nan'), "temp1") + assert "NaN" in str(exc_info.value) + + def test_temperature_inf_value(self): + """Test infinite temperature value.""" + with pytest.raises(InvalidParameterError) as exc_info: + ParameterValidator.validate_temperature(float('inf'), "temp2") + assert "infinite" in str(exc_info.value).lower() + + def test_temperature_none_value(self): + """Test None temperature value.""" + with pytest.raises(InvalidParameterError) as exc_info: + ParameterValidator.validate_temperature(None, "temp1") + assert "temp1" in str(exc_info.value) + + +class TestCurrentValidation: + """Test current parameter validation.""" + + def test_valid_current_range(self): + """Test currents within valid range.""" + assert ParameterValidator.validate_current(15.0, "current1") == 15.0 + assert ParameterValidator.validate_current(37.5, "current2") == 37.5 + assert ParameterValidator.validate_current(60.0, "current1") == 60.0 + + def test_current_below_minimum(self): + """Test current below minimum threshold.""" + with pytest.raises(CurrentOutOfRangeError) as exc_info: + ParameterValidator.validate_current(10.0, "current1") + assert "current1" in str(exc_info.value) + assert "15.0" in str(exc_info.value) # min value + + def test_current_above_maximum(self): + """Test current above maximum threshold.""" + with pytest.raises(CurrentOutOfRangeError) as exc_info: + ParameterValidator.validate_current(65.0, "current2") + assert "current2" in str(exc_info.value) + assert "60.0" in str(exc_info.value) # max value + + def test_current_invalid_type(self): + """Test invalid current type.""" + with pytest.raises(InvalidParameterError) as exc_info: + ParameterValidator.validate_current([15, 20], "current1") + assert "current1" in str(exc_info.value) + + def test_current_negative_value(self): + """Test negative current value.""" + with pytest.raises(CurrentOutOfRangeError) as exc_info: + ParameterValidator.validate_current(-5.0, "current1") + assert "current1" in str(exc_info.value) + + +class TestVariationParameterValidation: + """Test variation mode parameter validation.""" + + def test_valid_current_variation_params(self): + """Test valid parameters for current variation.""" + params = { + 'min_value': 20.0, + 'max_value': 50.0, + 'step': 0.5, + 'time_step': 50, # microseconds + 'delay_time': 5 # milliseconds + } + validated = ParameterValidator.validate_variation_params( + params, + VariationType.CHANGE_CURRENT_LD1 + ) + assert validated['min_value'] == 20.0 + assert validated['max_value'] == 50.0 + assert validated['step'] == 0.5 + + def test_variation_min_greater_than_max(self): + """Test min value greater than max value.""" + params = { + 'min_value': 50.0, + 'max_value': 20.0, + 'step': 0.5, + 'time_step': 50, + 'delay_time': 5 + } + with pytest.raises(ValidationError) as exc_info: + ParameterValidator.validate_variation_params( + params, + VariationType.CHANGE_CURRENT_LD1 + ) + assert "min" in str(exc_info.value).lower() + assert "max" in str(exc_info.value).lower() + + def test_variation_invalid_step(self): + """Test invalid step values.""" + # Zero step + params = { + 'min_value': 20.0, + 'max_value': 50.0, + 'step': 0, + 'time_step': 50, + 'delay_time': 5 + } + with pytest.raises(ValidationError) as exc_info: + ParameterValidator.validate_variation_params( + params, + VariationType.CHANGE_CURRENT_LD1 + ) + assert "step" in str(exc_info.value).lower() + + # Negative step + params['step'] = -0.5 + with pytest.raises(ValidationError) as exc_info: + ParameterValidator.validate_variation_params( + params, + VariationType.CHANGE_CURRENT_LD1 + ) + assert "step" in str(exc_info.value).lower() + + def test_variation_step_too_small(self): + """Test step value too small for current.""" + params = { + 'min_value': 20.0, + 'max_value': 50.0, + 'step': 0.001, # Too small for current (min 0.002) + 'time_step': 50, + 'delay_time': 5 + } + with pytest.raises(ValidationError) as exc_info: + ParameterValidator.validate_variation_params( + params, + VariationType.CHANGE_CURRENT_LD2 + ) + assert "step" in str(exc_info.value).lower() + assert "0.002" in str(exc_info.value) + + def test_variation_step_too_large(self): + """Test step value too large.""" + params = { + 'min_value': 20.0, + 'max_value': 50.0, + 'step': 10.0, # Too large for current (max 0.5) + 'time_step': 50, + 'delay_time': 5 + } + with pytest.raises(ValidationError) as exc_info: + ParameterValidator.validate_variation_params( + params, + VariationType.CHANGE_CURRENT_LD1 + ) + assert "step" in str(exc_info.value).lower() + assert "0.5" in str(exc_info.value) + + def test_valid_temperature_variation_params(self): + """Test valid parameters for temperature variation.""" + params = { + 'min_value': 20.0, + 'max_value': 35.0, + 'step': 0.1, + 'time_step': 50, + 'delay_time': 5 + } + validated = ParameterValidator.validate_variation_params( + params, + VariationType.CHANGE_TEMPERATURE_LD1 + ) + assert validated['min_value'] == 20.0 + assert validated['max_value'] == 35.0 + assert validated['step'] == 0.1 + + def test_temperature_variation_step_bounds(self): + """Test temperature variation step boundaries.""" + params = { + 'min_value': 20.0, + 'max_value': 35.0, + 'step': 0.02, # Too small (min 0.05) + 'time_step': 50, + 'delay_time': 5 + } + with pytest.raises(ValidationError) as exc_info: + ParameterValidator.validate_variation_params( + params, + VariationType.CHANGE_TEMPERATURE_LD2 + ) + assert "0.05" in str(exc_info.value) + + params['step'] = 2.0 # Too large (max 1.0) + with pytest.raises(ValidationError) as exc_info: + ParameterValidator.validate_variation_params( + params, + VariationType.CHANGE_TEMPERATURE_LD1 + ) + assert "1.0" in str(exc_info.value) + + def test_missing_required_params(self): + """Test missing required parameters.""" + params = { + 'min_value': 20.0, + 'max_value': 50.0 + # Missing step, time_step, delay_time + } + with pytest.raises(ValidationError) as exc_info: + ParameterValidator.validate_variation_params( + params, + VariationType.CHANGE_CURRENT_LD1 + ) + assert "required" in str(exc_info.value).lower() + + def test_invalid_variation_type(self): + """Test invalid variation type.""" + params = { + 'min_value': 20.0, + 'max_value': 50.0, + 'step': 0.5, + 'time_step': 50, + 'delay_time': 5 + } + with pytest.raises(ValidationError) as exc_info: + ParameterValidator.validate_variation_params( + params, + "INVALID_TYPE" + ) + assert "variation type" in str(exc_info.value).lower() + + +class TestTimeParameterValidation: + """Test time parameter validation.""" + + def test_valid_time_params(self): + """Test valid time parameters.""" + step_time, delay_time = ParameterValidator.validate_time_params(50, 5) + assert step_time == 50 + assert delay_time == 5 + + step_time, delay_time = ParameterValidator.validate_time_params(20, 3) + assert step_time == 20 + assert delay_time == 3 + + step_time, delay_time = ParameterValidator.validate_time_params(100, 10) + assert step_time == 100 + assert delay_time == 10 + + def test_time_step_below_minimum(self): + """Test time step below minimum.""" + with pytest.raises(ValidationError) as exc_info: + ParameterValidator.validate_time_params(10, 5) # Min is 20 + assert "time step" in str(exc_info.value).lower() + assert "20" in str(exc_info.value) + + def test_time_step_above_maximum(self): + """Test time step above maximum.""" + with pytest.raises(ValidationError) as exc_info: + ParameterValidator.validate_time_params(150, 5) # Max is 100 + assert "time step" in str(exc_info.value).lower() + assert "100" in str(exc_info.value) + + def test_delay_time_below_minimum(self): + """Test delay time below minimum.""" + with pytest.raises(ValidationError) as exc_info: + ParameterValidator.validate_time_params(50, 1) # Min is 3 + assert "delay" in str(exc_info.value).lower() + assert "3" in str(exc_info.value) + + def test_delay_time_above_maximum(self): + """Test delay time above maximum.""" + with pytest.raises(ValidationError) as exc_info: + ParameterValidator.validate_time_params(50, 15) # Max is 10 + assert "delay" in str(exc_info.value).lower() + assert "10" in str(exc_info.value) + + def test_time_params_invalid_type(self): + """Test invalid type for time parameters.""" + with pytest.raises(InvalidParameterError): + ParameterValidator.validate_time_params("50", 5) + + with pytest.raises(InvalidParameterError): + ParameterValidator.validate_time_params(50, [5]) + + def test_time_params_float_conversion(self): + """Test float to int conversion for time parameters.""" + step_time, delay_time = ParameterValidator.validate_time_params(50.7, 5.2) + assert step_time == 50 # Should be truncated to int + assert delay_time == 5 + + +class TestManualModeValidation: + """Test manual mode parameter validation.""" + + def test_validate_all_manual_params(self): + """Test validation of all manual mode parameters at once.""" + result = ParameterValidator.validate_manual_mode_params( + temp1=25.0, + temp2=30.0, + current1=40.0, + current2=35.0 + ) + assert result['temp1'] == 25.0 + assert result['temp2'] == 30.0 + assert result['current1'] == 40.0 + assert result['current2'] == 35.0 + + def test_manual_mode_invalid_combination(self): + """Test invalid parameter combinations in manual mode.""" + # One invalid parameter should fail all validation + with pytest.raises(ValidationError): + ParameterValidator.validate_manual_mode_params( + temp1=25.0, + temp2=30.0, + current1=70.0, # Too high + current2=35.0 + ) + + def test_manual_mode_boundary_values(self): + """Test boundary values for manual mode.""" + # All minimum values + result = ParameterValidator.validate_manual_mode_params( + temp1=15.0, + temp2=15.0, + current1=15.0, + current2=15.0 + ) + assert all(v in [15.0] for v in result.values()) + + # All maximum values + result = ParameterValidator.validate_manual_mode_params( + temp1=40.0, + temp2=40.0, + current1=60.0, + current2=60.0 + ) + assert result['temp1'] == 40.0 + assert result['temp2'] == 40.0 + assert result['current1'] == 60.0 + assert result['current2'] == 60.0 \ No newline at end of file