initial commit
This commit is contained in:
0
tests/__init__.py
Normal file
0
tests/__init__.py
Normal file
109
tests/conftest.py
Normal file
109
tests/conftest.py
Normal file
@ -0,0 +1,109 @@
|
||||
"""
|
||||
Shared fixtures for laser_control tests.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import struct
|
||||
from unittest.mock import MagicMock, patch
|
||||
from laser_control.protocol import Protocol, _build_crc, _flipfour, _int_to_hex4
|
||||
from laser_control.controller import LaserController
|
||||
from laser_control.conversions import (
|
||||
current_n_to_ma, temp_n_to_c, temp_ext_n_to_c,
|
||||
voltage_3v3_n_to_v, voltage_5v_n_to_v, voltage_7v_n_to_v,
|
||||
)
|
||||
|
||||
|
||||
def make_valid_response(
|
||||
current1_n: int = 10000,
|
||||
current2_n: int = 12000,
|
||||
temp1_n: int = 30000,
|
||||
temp2_n: int = 32000,
|
||||
temp_ext1_n: int = 2048,
|
||||
temp_ext2_n: int = 2048,
|
||||
mon_3v3_n: int = 2703, # ~3.3V
|
||||
mon_5v1_n: int = 2731, # ~5.0V
|
||||
mon_5v2_n: int = 2731,
|
||||
mon_7v0_n: int = 1042, # ~7.0V
|
||||
message_id: int = 12345,
|
||||
) -> bytes:
|
||||
"""
|
||||
Build a syntactically valid 30-byte DATA response.
|
||||
|
||||
Words (each 2 bytes, little-endian via flipfour):
|
||||
0 header
|
||||
1 I1
|
||||
2 I2
|
||||
3 TO6_LSB
|
||||
4 TO6_MSB
|
||||
5 Temp_1
|
||||
6 Temp_2
|
||||
7 Temp_Ext_1
|
||||
8 Temp_Ext_2
|
||||
9 MON_3V3
|
||||
10 MON_5V1
|
||||
11 MON_5V2
|
||||
12 MON_7V0
|
||||
13 Message_ID
|
||||
14 CRC
|
||||
"""
|
||||
words_raw = [
|
||||
0xABCD, # Word 0 header
|
||||
current1_n, # Word 1
|
||||
current2_n, # Word 2
|
||||
0, # Word 3 TO6_LSB
|
||||
0, # Word 4 TO6_MSB
|
||||
temp1_n, # Word 5
|
||||
temp2_n, # Word 6
|
||||
temp_ext1_n, # Word 7
|
||||
temp_ext2_n, # Word 8
|
||||
mon_3v3_n, # Word 9
|
||||
mon_5v1_n, # Word 10
|
||||
mon_5v2_n, # Word 11
|
||||
mon_7v0_n, # Word 12
|
||||
message_id, # Word 13
|
||||
0, # Word 14 CRC placeholder
|
||||
]
|
||||
|
||||
# Build hex string with flipfour applied
|
||||
hex_str = ""
|
||||
for w in words_raw:
|
||||
hex_str += _flipfour(_int_to_hex4(w))
|
||||
|
||||
# Compute CRC over words 1..13 (indices 4..55 in hex, i.e. skip word 0)
|
||||
words_hex = [hex_str[i:i+4] for i in range(0, len(hex_str), 4)]
|
||||
crc_words = words_hex[1:14] # words 1..13
|
||||
crc_val = int(crc_words[0], 16)
|
||||
for w in crc_words[1:]:
|
||||
crc_val ^= int(w, 16)
|
||||
|
||||
# Replace CRC word
|
||||
hex_str = hex_str[:56] + _flipfour(_int_to_hex4(crc_val))
|
||||
return bytes.fromhex(hex_str)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def valid_response_bytes():
|
||||
"""Pre-built valid 30-byte device response."""
|
||||
return make_valid_response()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_serial():
|
||||
"""Mock serial.Serial object."""
|
||||
with patch('serial.Serial') as mock_cls:
|
||||
mock_instance = MagicMock()
|
||||
mock_instance.is_open = True
|
||||
mock_cls.return_value = mock_instance
|
||||
yield mock_instance
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def connected_controller(mock_serial):
|
||||
"""LaserController with mocked serial connection."""
|
||||
mock_serial.read.return_value = make_valid_response()
|
||||
|
||||
ctrl = LaserController(port='/dev/ttyUSB0')
|
||||
with patch('serial.Serial', return_value=mock_serial):
|
||||
ctrl._protocol._serial = mock_serial
|
||||
mock_serial.is_open = True
|
||||
return ctrl
|
||||
294
tests/test_integration.py
Normal file
294
tests/test_integration.py
Normal file
@ -0,0 +1,294 @@
|
||||
"""
|
||||
Integration tests for the laser control module.
|
||||
|
||||
Tests the full call chain: LaserController → Protocol → Serial,
|
||||
using mock serial ports. No real hardware required.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import time
|
||||
from unittest.mock import MagicMock, patch, call
|
||||
from laser_control.controller import LaserController
|
||||
from laser_control.models import VariationType, DeviceState
|
||||
from laser_control.exceptions import (
|
||||
ValidationError,
|
||||
CommunicationError,
|
||||
TemperatureOutOfRangeError,
|
||||
CurrentOutOfRangeError,
|
||||
)
|
||||
from laser_control.protocol import Protocol, CommandCode
|
||||
from .conftest import make_valid_response
|
||||
|
||||
|
||||
class TestManualModeIntegration:
|
||||
"""Integration tests for manual mode operation."""
|
||||
|
||||
def test_full_manual_mode_flow(self, connected_controller, mock_serial):
|
||||
"""Test complete manual mode command flow."""
|
||||
connected_controller.set_manual_mode(
|
||||
temp1=25.0, temp2=30.0,
|
||||
current1=40.0, current2=35.0
|
||||
)
|
||||
|
||||
# Verify command was sent
|
||||
assert mock_serial.write.called
|
||||
sent_data = mock_serial.write.call_args[0][0]
|
||||
assert len(sent_data) == 30 # SEND_PARAMS_TOTAL_LENGTH
|
||||
|
||||
# Verify command code (bytes 0-1, little-endian 0x1111 → 0x11 0x11)
|
||||
assert sent_data[0] == 0x11
|
||||
assert sent_data[1] == 0x11
|
||||
|
||||
def test_manual_mode_validation_rejects_invalid_temp(self, connected_controller):
|
||||
"""Test that manual mode rejects out-of-range temperature."""
|
||||
with pytest.raises(TemperatureOutOfRangeError) as exc_info:
|
||||
connected_controller.set_manual_mode(
|
||||
temp1=50.0, # Too high
|
||||
temp2=30.0,
|
||||
current1=40.0,
|
||||
current2=35.0
|
||||
)
|
||||
assert "temp1" in str(exc_info.value)
|
||||
assert "50.0" in str(exc_info.value)
|
||||
|
||||
def test_manual_mode_validation_rejects_invalid_current(self, connected_controller):
|
||||
"""Test that manual mode rejects out-of-range current."""
|
||||
with pytest.raises(CurrentOutOfRangeError) as exc_info:
|
||||
connected_controller.set_manual_mode(
|
||||
temp1=25.0,
|
||||
temp2=30.0,
|
||||
current1=40.0,
|
||||
current2=70.0 # Too high
|
||||
)
|
||||
assert "current2" in str(exc_info.value)
|
||||
|
||||
def test_manual_mode_no_serial_call_on_validation_failure(
|
||||
self, connected_controller, mock_serial
|
||||
):
|
||||
"""Serial write must not be called when validation fails."""
|
||||
mock_serial.write.reset_mock()
|
||||
with pytest.raises(ValidationError):
|
||||
connected_controller.set_manual_mode(
|
||||
temp1=5.0, # Invalid
|
||||
temp2=30.0,
|
||||
current1=40.0,
|
||||
current2=35.0
|
||||
)
|
||||
mock_serial.write.assert_not_called()
|
||||
|
||||
def test_message_id_increments(self, connected_controller, mock_serial):
|
||||
"""Message ID should increment with each command."""
|
||||
initial_id = connected_controller._message_id
|
||||
connected_controller.set_manual_mode(25.0, 30.0, 40.0, 35.0)
|
||||
assert connected_controller._message_id == (initial_id + 1) & 0xFFFF
|
||||
|
||||
connected_controller.set_manual_mode(26.0, 31.0, 41.0, 36.0)
|
||||
assert connected_controller._message_id == (initial_id + 2) & 0xFFFF
|
||||
|
||||
|
||||
class TestVariationModeIntegration:
|
||||
"""Integration tests for variation mode operation."""
|
||||
|
||||
def test_current_ld1_variation_flow(self, connected_controller, mock_serial):
|
||||
"""Test complete current variation for laser 1."""
|
||||
params = {
|
||||
'min_value': 20.0,
|
||||
'max_value': 50.0,
|
||||
'step': 0.5,
|
||||
'time_step': 50,
|
||||
'delay_time': 5,
|
||||
'static_temp1': 25.0,
|
||||
'static_temp2': 30.0,
|
||||
'static_current1': 35.0,
|
||||
'static_current2': 35.0,
|
||||
}
|
||||
connected_controller.start_variation(VariationType.CHANGE_CURRENT_LD1, params)
|
||||
|
||||
assert mock_serial.write.called
|
||||
sent_data = mock_serial.write.call_args[0][0]
|
||||
assert len(sent_data) == 32 # TASK_ENABLE_COMMAND_LENGTH
|
||||
|
||||
# Verify command code (0x7777)
|
||||
assert sent_data[0] == 0x77
|
||||
assert sent_data[1] == 0x77
|
||||
|
||||
def test_current_ld2_variation_flow(self, connected_controller, mock_serial):
|
||||
"""Test complete current variation for laser 2."""
|
||||
params = {
|
||||
'min_value': 20.0,
|
||||
'max_value': 50.0,
|
||||
'step': 0.5,
|
||||
'time_step': 50,
|
||||
'delay_time': 5,
|
||||
'static_temp1': 25.0,
|
||||
'static_temp2': 30.0,
|
||||
'static_current1': 35.0,
|
||||
'static_current2': 35.0,
|
||||
}
|
||||
connected_controller.start_variation(VariationType.CHANGE_CURRENT_LD2, params)
|
||||
assert mock_serial.write.called
|
||||
|
||||
def test_variation_rejects_invalid_step(self, connected_controller, mock_serial):
|
||||
"""Variation must reject step below minimum."""
|
||||
mock_serial.write.reset_mock()
|
||||
params = {
|
||||
'min_value': 20.0,
|
||||
'max_value': 50.0,
|
||||
'step': 0.001, # Too small
|
||||
'time_step': 50,
|
||||
'delay_time': 5,
|
||||
'static_temp1': 25.0,
|
||||
'static_temp2': 30.0,
|
||||
'static_current1': 35.0,
|
||||
'static_current2': 35.0,
|
||||
}
|
||||
with pytest.raises(ValidationError):
|
||||
connected_controller.start_variation(VariationType.CHANGE_CURRENT_LD1, params)
|
||||
mock_serial.write.assert_not_called()
|
||||
|
||||
def test_variation_rejects_inverted_range(self, connected_controller):
|
||||
"""Variation must reject min > max."""
|
||||
params = {
|
||||
'min_value': 50.0, # min > max
|
||||
'max_value': 20.0,
|
||||
'step': 0.5,
|
||||
'time_step': 50,
|
||||
'delay_time': 5,
|
||||
'static_temp1': 25.0,
|
||||
'static_temp2': 30.0,
|
||||
'static_current1': 35.0,
|
||||
'static_current2': 35.0,
|
||||
}
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
connected_controller.start_variation(VariationType.CHANGE_CURRENT_LD1, params)
|
||||
assert "min" in str(exc_info.value).lower()
|
||||
|
||||
|
||||
class TestMeasurementsIntegration:
|
||||
"""Integration tests for measurement retrieval."""
|
||||
|
||||
def test_get_measurements_returns_data(self, connected_controller, mock_serial):
|
||||
"""get_measurements should decode and return device data."""
|
||||
mock_serial.read.return_value = make_valid_response()
|
||||
measurements = connected_controller.get_measurements()
|
||||
|
||||
assert measurements is not None
|
||||
assert isinstance(measurements.current1, float)
|
||||
assert isinstance(measurements.current2, float)
|
||||
assert isinstance(measurements.temp1, float)
|
||||
assert isinstance(measurements.temp2, float)
|
||||
assert isinstance(measurements.voltage_3v3, float)
|
||||
|
||||
def test_get_measurements_calls_callback(self, mock_serial):
|
||||
"""on_data callback should be triggered on new measurements."""
|
||||
received = []
|
||||
mock_serial.read.return_value = make_valid_response()
|
||||
mock_serial.is_open = True
|
||||
|
||||
ctrl = LaserController(
|
||||
port='/dev/ttyUSB0',
|
||||
on_data=lambda m: received.append(m)
|
||||
)
|
||||
ctrl._protocol._serial = mock_serial
|
||||
|
||||
ctrl.get_measurements()
|
||||
assert len(received) == 1
|
||||
assert received[0].voltage_3v3 > 0
|
||||
|
||||
def test_get_measurements_no_data(self, connected_controller, mock_serial):
|
||||
"""get_measurements returns None when no data received."""
|
||||
mock_serial.read.return_value = b''
|
||||
result = connected_controller.get_measurements()
|
||||
assert result is None
|
||||
|
||||
def test_voltage_rail_check(self, connected_controller, mock_serial):
|
||||
"""Test power rail health check on measurements."""
|
||||
mock_serial.read.return_value = make_valid_response(
|
||||
mon_3v3_n=2703, # ~3.3V
|
||||
mon_5v1_n=2731, # ~5.0V
|
||||
mon_5v2_n=2731,
|
||||
mon_7v0_n=1042, # ~7.0V
|
||||
)
|
||||
measurements = connected_controller.get_measurements()
|
||||
if measurements:
|
||||
rails = measurements.check_power_rails()
|
||||
assert isinstance(rails, dict)
|
||||
assert '3v3' in rails
|
||||
assert '5v1' in rails
|
||||
assert '5v2' in rails
|
||||
assert '7v0' in rails
|
||||
|
||||
|
||||
class TestConnectionManagement:
|
||||
"""Integration tests for connection handling."""
|
||||
|
||||
def test_context_manager(self, mock_serial):
|
||||
"""Test using LaserController as context manager."""
|
||||
mock_serial.is_open = True
|
||||
with patch('serial.Serial', return_value=mock_serial):
|
||||
with LaserController(port='/dev/ttyUSB0') as ctrl:
|
||||
assert ctrl.is_connected
|
||||
mock_serial.close.assert_called()
|
||||
|
||||
def test_send_without_connection_raises(self):
|
||||
"""Sending command without connection raises CommunicationError."""
|
||||
ctrl = LaserController(port='/dev/ttyUSB0')
|
||||
# Don't call connect()
|
||||
with pytest.raises(CommunicationError) as exc_info:
|
||||
ctrl.set_manual_mode(25.0, 30.0, 40.0, 35.0)
|
||||
assert "connect" in str(exc_info.value).lower()
|
||||
|
||||
def test_stop_task_sends_default_enable(self, connected_controller, mock_serial):
|
||||
"""stop_task should send DEFAULT_ENABLE (0x2222)."""
|
||||
mock_serial.write.reset_mock()
|
||||
connected_controller.stop_task()
|
||||
|
||||
assert mock_serial.write.called
|
||||
sent_data = mock_serial.write.call_args[0][0]
|
||||
# DEFAULT_ENABLE: 0x2222 → flipped to bytes 0x22 0x22
|
||||
assert sent_data[0] == 0x22
|
||||
assert sent_data[1] == 0x22
|
||||
|
||||
def test_reset_sends_default_enable(self, connected_controller, mock_serial):
|
||||
"""reset() should also send DEFAULT_ENABLE."""
|
||||
mock_serial.write.reset_mock()
|
||||
connected_controller.reset()
|
||||
assert mock_serial.write.called
|
||||
|
||||
|
||||
class TestConversionsRoundtrip:
|
||||
"""Test that physical unit conversions are self-consistent."""
|
||||
|
||||
def test_temperature_roundtrip(self):
|
||||
"""temp_c_to_n and temp_n_to_c should be inverse of each other."""
|
||||
from laser_control.conversions import temp_c_to_n, temp_n_to_c
|
||||
for temp in [15.0, 20.0, 25.0, 30.0, 35.0, 40.0]:
|
||||
n = temp_c_to_n(temp)
|
||||
recovered = temp_n_to_c(n)
|
||||
assert abs(recovered - temp) < 0.05, \
|
||||
f"Temperature roundtrip failed for {temp}°C: got {recovered}°C"
|
||||
|
||||
def test_current_roundtrip(self):
|
||||
"""current_ma_to_n and current_n_to_ma should be approximately inverse."""
|
||||
from laser_control.conversions import current_ma_to_n, current_n_to_ma
|
||||
# Note: current_n_to_ma is for photodiode readback, not exact inverse
|
||||
# of current_ma_to_n (different calibration paths).
|
||||
# We just test that values are in plausible range.
|
||||
for current in [15.0, 30.0, 45.0, 60.0]:
|
||||
n = current_ma_to_n(current)
|
||||
assert 0 <= n <= 65535
|
||||
|
||||
def test_voltage_conversions_nominal(self):
|
||||
"""Test voltage conversions at nominal counts."""
|
||||
from laser_control.conversions import (
|
||||
voltage_3v3_n_to_v, voltage_5v_n_to_v, voltage_7v_n_to_v
|
||||
)
|
||||
# Approximate nominal ADC counts for each rail
|
||||
v33 = voltage_3v3_n_to_v(2703)
|
||||
assert 3.1 <= v33 <= 3.5, f"3.3V rail: {v33}"
|
||||
|
||||
v5 = voltage_5v_n_to_v(2731)
|
||||
assert 4.8 <= v5 <= 5.3, f"5V rail: {v5}"
|
||||
|
||||
v7 = voltage_7v_n_to_v(1042)
|
||||
assert 6.5 <= v7 <= 7.5, f"7V rail: {v7}"
|
||||
345
tests/test_protocol.py
Normal file
345
tests/test_protocol.py
Normal file
@ -0,0 +1,345 @@
|
||||
"""
|
||||
Tests for communication protocol module.
|
||||
|
||||
Testing command encoding/decoding, CRC calculations,
|
||||
and protocol message structure.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import Mock, MagicMock, patch, call
|
||||
import struct
|
||||
from laser_control.protocol import (
|
||||
Protocol,
|
||||
CommandCode,
|
||||
TaskType,
|
||||
Message,
|
||||
Response
|
||||
)
|
||||
from laser_control.exceptions import (
|
||||
CommunicationError,
|
||||
CRCError,
|
||||
ProtocolError
|
||||
)
|
||||
|
||||
|
||||
class TestCRCCalculation:
|
||||
"""Test CRC calculation and verification."""
|
||||
|
||||
def test_crc_calculation(self):
|
||||
"""Test CRC calculation for known data (at least 2 words needed)."""
|
||||
# calculate_crc skips word 0 and XORs words 1..N
|
||||
# So we need at least 4 bytes (2 words)
|
||||
data = b'\x00\x01\x02\x03\x04\x05\x06\x07'
|
||||
crc = Protocol.calculate_crc(data)
|
||||
assert isinstance(crc, int)
|
||||
assert 0 <= crc <= 0xFFFF
|
||||
|
||||
def test_crc_consistency(self):
|
||||
"""Test CRC calculation consistency."""
|
||||
data = b'\x11\x11' + b'\x00' * 26 + b'\xFF\xFF' # 30 bytes
|
||||
crc1 = Protocol.calculate_crc(data)
|
||||
crc2 = Protocol.calculate_crc(data)
|
||||
assert crc1 == crc2
|
||||
|
||||
def test_crc_different_data(self):
|
||||
"""Test CRC differs for different data."""
|
||||
data1 = b'\x00\x00\x01\x02\x03\x04'
|
||||
data2 = b'\x00\x00\x05\x06\x07\x08'
|
||||
crc1 = Protocol.calculate_crc(data1)
|
||||
crc2 = Protocol.calculate_crc(data2)
|
||||
assert crc1 != crc2
|
||||
|
||||
|
||||
class TestMessageEncoding:
|
||||
"""Test message encoding for device commands."""
|
||||
|
||||
def test_encode_decode_enable_command(self):
|
||||
"""Test encoding DECODE_ENABLE command."""
|
||||
message = Protocol.encode_decode_enable(
|
||||
temp1=25.5,
|
||||
temp2=30.0,
|
||||
current1=40.0,
|
||||
current2=35.0,
|
||||
pi_coeff1_p=1,
|
||||
pi_coeff1_i=1,
|
||||
pi_coeff2_p=1,
|
||||
pi_coeff2_i=1,
|
||||
message_id=12345
|
||||
)
|
||||
|
||||
assert isinstance(message, bytes)
|
||||
assert len(message) == 30 # Expected message length
|
||||
|
||||
# Check command code (0x1111 stored little-endian via flipfour → 0x11 0x11)
|
||||
assert message[0] == 0x11
|
||||
assert message[1] == 0x11
|
||||
|
||||
def test_encode_task_enable_command(self):
|
||||
"""Test encoding TASK_ENABLE command."""
|
||||
message = Protocol.encode_task_enable(
|
||||
task_type=TaskType.CHANGE_CURRENT_LD1,
|
||||
static_temp1=25.0,
|
||||
static_temp2=30.0,
|
||||
static_current1=40.0,
|
||||
static_current2=35.0,
|
||||
min_value=20.0,
|
||||
max_value=50.0,
|
||||
step=0.5,
|
||||
time_step=50,
|
||||
delay_time=5,
|
||||
message_id=54321
|
||||
)
|
||||
|
||||
assert isinstance(message, bytes)
|
||||
assert len(message) > 0
|
||||
|
||||
# Check command code
|
||||
command = struct.unpack('<H', message[0:2])[0]
|
||||
assert command == CommandCode.TASK_ENABLE
|
||||
|
||||
def test_encode_trans_enable_command(self):
|
||||
"""Test encoding TRANS_ENABLE command."""
|
||||
message = Protocol.encode_trans_enable(message_id=11111)
|
||||
|
||||
# encode_trans_enable returns bytearray; ensure it's bytes-like
|
||||
assert len(message) == 2
|
||||
# 0x4444 flipped → bytes 0x44 0x44
|
||||
assert message[0] == 0x44
|
||||
assert message[1] == 0x44
|
||||
|
||||
def test_encode_state_command(self):
|
||||
"""Test encoding STATE command."""
|
||||
message = Protocol.encode_state(message_id=22222)
|
||||
|
||||
assert len(message) == 2
|
||||
# 0x6666 → 0x66 0x66
|
||||
assert message[0] == 0x66
|
||||
assert message[1] == 0x66
|
||||
|
||||
def test_encode_default_enable_command(self):
|
||||
"""Test encoding DEFAULT_ENABLE (reset) command."""
|
||||
message = Protocol.encode_default_enable(message_id=33333)
|
||||
|
||||
assert len(message) == 2
|
||||
# 0x2222 → 0x22 0x22
|
||||
assert message[0] == 0x22
|
||||
assert message[1] == 0x22
|
||||
|
||||
|
||||
class TestResponseDecoding:
|
||||
"""Test response message decoding."""
|
||||
|
||||
def test_decode_valid_response(self):
|
||||
"""Test decoding valid device response using conftest helper."""
|
||||
from tests.conftest import make_valid_response
|
||||
data = make_valid_response(message_id=12345)
|
||||
|
||||
response = Protocol.decode_response(data)
|
||||
|
||||
assert isinstance(response.current1, float)
|
||||
assert isinstance(response.temp1, float)
|
||||
assert isinstance(response.voltage_3v3, float)
|
||||
assert response.message_id == 12345
|
||||
|
||||
def test_decode_response_invalid_crc(self):
|
||||
"""Test decoding response with invalid CRC."""
|
||||
response_data = bytearray(30)
|
||||
struct.pack_into('<H', response_data, 28, 0xFFFF) # Invalid CRC
|
||||
|
||||
with pytest.raises(CRCError):
|
||||
Protocol.decode_response(bytes(response_data))
|
||||
|
||||
def test_decode_response_invalid_length(self):
|
||||
"""Test decoding response with invalid length."""
|
||||
response_data = bytes(20) # Too short (expected 30)
|
||||
|
||||
with pytest.raises(ProtocolError) as exc_info:
|
||||
Protocol.decode_response(response_data)
|
||||
# ProtocolError message includes "bytes"
|
||||
assert "bytes" in str(exc_info.value).lower()
|
||||
|
||||
def test_decode_state_response(self):
|
||||
"""Test decoding IDLE state response (2 bytes, flipfour encoded)."""
|
||||
from laser_control.protocol import _flipfour, _int_to_hex4
|
||||
# STATE IDLE = 0x0000; after flipfour it remains 0x0000
|
||||
state_bytes = bytes.fromhex(_flipfour(_int_to_hex4(0x0000)))
|
||||
state = Protocol.decode_state(state_bytes)
|
||||
assert state == 0x0000 # IDLE
|
||||
|
||||
def test_decode_state_error_conditions(self):
|
||||
"""Test decoding various error state codes."""
|
||||
from laser_control.protocol import _flipfour, _int_to_hex4
|
||||
error_codes = [0x0001, 0x0002, 0x0004, 0x0008, 0x0010]
|
||||
|
||||
for code in error_codes:
|
||||
state_bytes = bytes.fromhex(_flipfour(_int_to_hex4(code)))
|
||||
state = Protocol.decode_state(state_bytes)
|
||||
assert state == code
|
||||
|
||||
|
||||
class TestProtocolHelpers:
|
||||
"""Test protocol helper functions."""
|
||||
|
||||
def test_flipfour_byte_order(self):
|
||||
"""Test byte order flipping for little-endian conversion.
|
||||
|
||||
Protocol.flipfour() operates on 16-bit integers (byte-swap within a word).
|
||||
The underlying _flipfour() operates on 4-char hex strings (word-swap).
|
||||
"""
|
||||
from laser_control.protocol import _flipfour
|
||||
# _flipfour swaps two byte pairs: 'aabb' → 'bbaa'
|
||||
assert _flipfour('1234') == '3412'
|
||||
assert _flipfour('abcd') == 'cdab'
|
||||
assert _flipfour('0000') == '0000'
|
||||
assert _flipfour('1111') == '1111'
|
||||
|
||||
# Protocol.flipfour() byte-swaps a 16-bit int
|
||||
assert Protocol.flipfour(0x1234) == 0x3412
|
||||
assert Protocol.flipfour(0x0000) == 0x0000
|
||||
|
||||
def test_pack_float_conversion(self):
|
||||
"""Test float to bytes conversion."""
|
||||
value = 25.5
|
||||
packed = Protocol.pack_float(value)
|
||||
assert len(packed) == 4
|
||||
|
||||
# Unpack and verify
|
||||
unpacked = struct.unpack('<f', packed)[0]
|
||||
assert abs(unpacked - value) < 0.001
|
||||
|
||||
def test_pack_uint16_conversion(self):
|
||||
"""Test uint16 to bytes conversion."""
|
||||
value = 12345
|
||||
packed = Protocol.pack_uint16(value)
|
||||
assert len(packed) == 2
|
||||
|
||||
unpacked = struct.unpack('<H', packed)[0]
|
||||
assert unpacked == value
|
||||
|
||||
|
||||
class TestSerialCommunication:
|
||||
"""Test serial port communication."""
|
||||
|
||||
@patch('serial.Serial')
|
||||
def test_send_command(self, mock_serial_class):
|
||||
"""Test sending command over serial."""
|
||||
mock_serial = MagicMock()
|
||||
mock_serial_class.return_value = mock_serial
|
||||
|
||||
protocol = Protocol(port='/dev/ttyUSB0')
|
||||
protocol.connect()
|
||||
|
||||
# Send a command
|
||||
message = b'test_message'
|
||||
protocol.send_raw(message)
|
||||
|
||||
mock_serial.write.assert_called_once_with(message)
|
||||
|
||||
@patch('serial.Serial')
|
||||
def test_receive_response(self, mock_serial_class):
|
||||
"""Test receiving response from serial."""
|
||||
mock_serial = MagicMock()
|
||||
mock_serial_class.return_value = mock_serial
|
||||
|
||||
# Mock response data
|
||||
response_data = bytes(30)
|
||||
mock_serial.read.return_value = response_data
|
||||
mock_serial.in_waiting = 30
|
||||
|
||||
protocol = Protocol(port='/dev/ttyUSB0')
|
||||
protocol.connect()
|
||||
|
||||
data = protocol.receive_raw(30)
|
||||
assert data == response_data
|
||||
mock_serial.read.assert_called_once_with(30)
|
||||
|
||||
@patch('serial.Serial')
|
||||
def test_connection_failure(self, mock_serial_class):
|
||||
"""Test handling connection failure."""
|
||||
mock_serial_class.side_effect = Exception("Port not found")
|
||||
|
||||
protocol = Protocol(port='/dev/invalid')
|
||||
with pytest.raises(CommunicationError) as exc_info:
|
||||
protocol.connect()
|
||||
assert "connect" in str(exc_info.value).lower()
|
||||
|
||||
@patch('serial.Serial')
|
||||
def test_auto_port_detection(self, mock_serial_class):
|
||||
"""Test automatic port detection."""
|
||||
with patch('serial.tools.list_ports.comports') as mock_comports:
|
||||
# Mock available ports
|
||||
mock_port = MagicMock()
|
||||
mock_port.device = '/dev/ttyUSB0'
|
||||
mock_comports.return_value = [mock_port]
|
||||
|
||||
protocol = Protocol() # No port specified
|
||||
protocol.connect()
|
||||
|
||||
mock_serial_class.assert_called_with(
|
||||
port='/dev/ttyUSB0',
|
||||
baudrate=115200,
|
||||
timeout=1
|
||||
)
|
||||
|
||||
@patch('serial.Serial')
|
||||
def test_disconnect(self, mock_serial_class):
|
||||
"""Test proper disconnection."""
|
||||
mock_serial = MagicMock()
|
||||
mock_serial_class.return_value = mock_serial
|
||||
|
||||
protocol = Protocol(port='/dev/ttyUSB0')
|
||||
protocol.connect()
|
||||
protocol.disconnect()
|
||||
|
||||
mock_serial.close.assert_called_once()
|
||||
|
||||
|
||||
class TestMessageValidation:
|
||||
"""Test message validation and error handling."""
|
||||
|
||||
def test_invalid_task_type(self):
|
||||
"""Test handling of invalid task type."""
|
||||
with pytest.raises(ValueError):
|
||||
Protocol.encode_task_enable(
|
||||
task_type=99, # Invalid type
|
||||
static_temp1=25.0,
|
||||
static_temp2=30.0,
|
||||
static_current1=40.0,
|
||||
static_current2=35.0,
|
||||
min_value=20.0,
|
||||
max_value=50.0,
|
||||
step=0.5,
|
||||
time_step=50,
|
||||
delay_time=5,
|
||||
message_id=12345
|
||||
)
|
||||
|
||||
def test_message_id_overflow(self):
|
||||
"""encode_decode_enable wraps message_id to 16-bit boundary."""
|
||||
# Message ID > 0xFFFF should wrap (& 0xFFFF in controller)
|
||||
large_id = 0x10000 + 123
|
||||
wrapped = large_id & 0xFFFF
|
||||
message = Protocol.encode_decode_enable(
|
||||
temp1=25.0, temp2=30.0,
|
||||
current1=40.0, current2=35.0,
|
||||
pi_coeff1_p=1, pi_coeff1_i=1,
|
||||
pi_coeff2_p=1, pi_coeff2_i=1,
|
||||
message_id=wrapped,
|
||||
)
|
||||
assert isinstance(message, bytes)
|
||||
assert len(message) == 30
|
||||
|
||||
def test_negative_values_handling(self):
|
||||
"""Test handling of negative values where not allowed."""
|
||||
with pytest.raises(ValueError):
|
||||
Protocol.encode_decode_enable(
|
||||
temp1=25.0,
|
||||
temp2=30.0,
|
||||
current1=-10.0, # Negative current
|
||||
current2=35.0,
|
||||
pi_coeff1_p=1.0,
|
||||
pi_coeff1_i=0.5,
|
||||
pi_coeff2_p=1.0,
|
||||
pi_coeff2_i=0.5,
|
||||
message_id=12345
|
||||
)
|
||||
383
tests/test_validation.py
Normal file
383
tests/test_validation.py
Normal file
@ -0,0 +1,383 @@
|
||||
"""
|
||||
Tests for parameter validation module.
|
||||
|
||||
Testing validation of all input parameters with boundary conditions,
|
||||
invalid types, and edge cases.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import math
|
||||
from laser_control.validators import ParameterValidator
|
||||
from laser_control.exceptions import (
|
||||
ValidationError,
|
||||
TemperatureOutOfRangeError,
|
||||
CurrentOutOfRangeError,
|
||||
InvalidParameterError
|
||||
)
|
||||
from laser_control.models import VariationType
|
||||
|
||||
|
||||
class TestTemperatureValidation:
|
||||
"""Test temperature parameter validation."""
|
||||
|
||||
def test_valid_temperature_range(self):
|
||||
"""Test temperatures within valid range."""
|
||||
# Valid temperatures should pass
|
||||
assert ParameterValidator.validate_temperature(15.0, "temp1") == 15.0
|
||||
assert ParameterValidator.validate_temperature(25.5, "temp2") == 25.5
|
||||
assert ParameterValidator.validate_temperature(40.0, "temp1") == 40.0
|
||||
|
||||
def test_temperature_below_minimum(self):
|
||||
"""Test temperature below minimum threshold."""
|
||||
with pytest.raises(TemperatureOutOfRangeError) as exc_info:
|
||||
ParameterValidator.validate_temperature(10.0, "temp1")
|
||||
assert "temp1" in str(exc_info.value)
|
||||
assert "15.0" in str(exc_info.value) # min value
|
||||
|
||||
def test_temperature_above_maximum(self):
|
||||
"""Test temperature above maximum threshold."""
|
||||
with pytest.raises(TemperatureOutOfRangeError) as exc_info:
|
||||
ParameterValidator.validate_temperature(45.0, "temp2")
|
||||
assert "temp2" in str(exc_info.value)
|
||||
assert "40.0" in str(exc_info.value) # max value
|
||||
|
||||
def test_temperature_invalid_type(self):
|
||||
"""Test invalid temperature type."""
|
||||
with pytest.raises(InvalidParameterError) as exc_info:
|
||||
ParameterValidator.validate_temperature("invalid", "temp1")
|
||||
assert "temp1" in str(exc_info.value)
|
||||
assert "number" in str(exc_info.value).lower()
|
||||
|
||||
def test_temperature_nan_value(self):
|
||||
"""Test NaN temperature value."""
|
||||
with pytest.raises(InvalidParameterError) as exc_info:
|
||||
ParameterValidator.validate_temperature(float('nan'), "temp1")
|
||||
assert "NaN" in str(exc_info.value)
|
||||
|
||||
def test_temperature_inf_value(self):
|
||||
"""Test infinite temperature value."""
|
||||
with pytest.raises(InvalidParameterError) as exc_info:
|
||||
ParameterValidator.validate_temperature(float('inf'), "temp2")
|
||||
assert "infinite" in str(exc_info.value).lower()
|
||||
|
||||
def test_temperature_none_value(self):
|
||||
"""Test None temperature value."""
|
||||
with pytest.raises(InvalidParameterError) as exc_info:
|
||||
ParameterValidator.validate_temperature(None, "temp1")
|
||||
assert "temp1" in str(exc_info.value)
|
||||
|
||||
|
||||
class TestCurrentValidation:
|
||||
"""Test current parameter validation."""
|
||||
|
||||
def test_valid_current_range(self):
|
||||
"""Test currents within valid range."""
|
||||
assert ParameterValidator.validate_current(15.0, "current1") == 15.0
|
||||
assert ParameterValidator.validate_current(37.5, "current2") == 37.5
|
||||
assert ParameterValidator.validate_current(60.0, "current1") == 60.0
|
||||
|
||||
def test_current_below_minimum(self):
|
||||
"""Test current below minimum threshold."""
|
||||
with pytest.raises(CurrentOutOfRangeError) as exc_info:
|
||||
ParameterValidator.validate_current(10.0, "current1")
|
||||
assert "current1" in str(exc_info.value)
|
||||
assert "15.0" in str(exc_info.value) # min value
|
||||
|
||||
def test_current_above_maximum(self):
|
||||
"""Test current above maximum threshold."""
|
||||
with pytest.raises(CurrentOutOfRangeError) as exc_info:
|
||||
ParameterValidator.validate_current(65.0, "current2")
|
||||
assert "current2" in str(exc_info.value)
|
||||
assert "60.0" in str(exc_info.value) # max value
|
||||
|
||||
def test_current_invalid_type(self):
|
||||
"""Test invalid current type."""
|
||||
with pytest.raises(InvalidParameterError) as exc_info:
|
||||
ParameterValidator.validate_current([15, 20], "current1")
|
||||
assert "current1" in str(exc_info.value)
|
||||
|
||||
def test_current_negative_value(self):
|
||||
"""Test negative current value."""
|
||||
with pytest.raises(CurrentOutOfRangeError) as exc_info:
|
||||
ParameterValidator.validate_current(-5.0, "current1")
|
||||
assert "current1" in str(exc_info.value)
|
||||
|
||||
|
||||
class TestVariationParameterValidation:
|
||||
"""Test variation mode parameter validation."""
|
||||
|
||||
def test_valid_current_variation_params(self):
|
||||
"""Test valid parameters for current variation."""
|
||||
params = {
|
||||
'min_value': 20.0,
|
||||
'max_value': 50.0,
|
||||
'step': 0.5,
|
||||
'time_step': 50, # microseconds
|
||||
'delay_time': 5 # milliseconds
|
||||
}
|
||||
validated = ParameterValidator.validate_variation_params(
|
||||
params,
|
||||
VariationType.CHANGE_CURRENT_LD1
|
||||
)
|
||||
assert validated['min_value'] == 20.0
|
||||
assert validated['max_value'] == 50.0
|
||||
assert validated['step'] == 0.5
|
||||
|
||||
def test_variation_min_greater_than_max(self):
|
||||
"""Test min value greater than max value."""
|
||||
params = {
|
||||
'min_value': 50.0,
|
||||
'max_value': 20.0,
|
||||
'step': 0.5,
|
||||
'time_step': 50,
|
||||
'delay_time': 5
|
||||
}
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
ParameterValidator.validate_variation_params(
|
||||
params,
|
||||
VariationType.CHANGE_CURRENT_LD1
|
||||
)
|
||||
assert "min" in str(exc_info.value).lower()
|
||||
assert "max" in str(exc_info.value).lower()
|
||||
|
||||
def test_variation_invalid_step(self):
|
||||
"""Test invalid step values."""
|
||||
# Zero step
|
||||
params = {
|
||||
'min_value': 20.0,
|
||||
'max_value': 50.0,
|
||||
'step': 0,
|
||||
'time_step': 50,
|
||||
'delay_time': 5
|
||||
}
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
ParameterValidator.validate_variation_params(
|
||||
params,
|
||||
VariationType.CHANGE_CURRENT_LD1
|
||||
)
|
||||
assert "step" in str(exc_info.value).lower()
|
||||
|
||||
# Negative step
|
||||
params['step'] = -0.5
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
ParameterValidator.validate_variation_params(
|
||||
params,
|
||||
VariationType.CHANGE_CURRENT_LD1
|
||||
)
|
||||
assert "step" in str(exc_info.value).lower()
|
||||
|
||||
def test_variation_step_too_small(self):
|
||||
"""Test step value too small for current."""
|
||||
params = {
|
||||
'min_value': 20.0,
|
||||
'max_value': 50.0,
|
||||
'step': 0.001, # Too small for current (min 0.002)
|
||||
'time_step': 50,
|
||||
'delay_time': 5
|
||||
}
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
ParameterValidator.validate_variation_params(
|
||||
params,
|
||||
VariationType.CHANGE_CURRENT_LD2
|
||||
)
|
||||
assert "step" in str(exc_info.value).lower()
|
||||
assert "0.002" in str(exc_info.value)
|
||||
|
||||
def test_variation_step_too_large(self):
|
||||
"""Test step value too large."""
|
||||
params = {
|
||||
'min_value': 20.0,
|
||||
'max_value': 50.0,
|
||||
'step': 10.0, # Too large for current (max 0.5)
|
||||
'time_step': 50,
|
||||
'delay_time': 5
|
||||
}
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
ParameterValidator.validate_variation_params(
|
||||
params,
|
||||
VariationType.CHANGE_CURRENT_LD1
|
||||
)
|
||||
assert "step" in str(exc_info.value).lower()
|
||||
assert "0.5" in str(exc_info.value)
|
||||
|
||||
def test_valid_temperature_variation_params(self):
|
||||
"""Test valid parameters for temperature variation."""
|
||||
params = {
|
||||
'min_value': 20.0,
|
||||
'max_value': 35.0,
|
||||
'step': 0.1,
|
||||
'time_step': 50,
|
||||
'delay_time': 5
|
||||
}
|
||||
validated = ParameterValidator.validate_variation_params(
|
||||
params,
|
||||
VariationType.CHANGE_TEMPERATURE_LD1
|
||||
)
|
||||
assert validated['min_value'] == 20.0
|
||||
assert validated['max_value'] == 35.0
|
||||
assert validated['step'] == 0.1
|
||||
|
||||
def test_temperature_variation_step_bounds(self):
|
||||
"""Test temperature variation step boundaries."""
|
||||
params = {
|
||||
'min_value': 20.0,
|
||||
'max_value': 35.0,
|
||||
'step': 0.02, # Too small (min 0.05)
|
||||
'time_step': 50,
|
||||
'delay_time': 5
|
||||
}
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
ParameterValidator.validate_variation_params(
|
||||
params,
|
||||
VariationType.CHANGE_TEMPERATURE_LD2
|
||||
)
|
||||
assert "0.05" in str(exc_info.value)
|
||||
|
||||
params['step'] = 2.0 # Too large (max 1.0)
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
ParameterValidator.validate_variation_params(
|
||||
params,
|
||||
VariationType.CHANGE_TEMPERATURE_LD1
|
||||
)
|
||||
assert "1.0" in str(exc_info.value)
|
||||
|
||||
def test_missing_required_params(self):
|
||||
"""Test missing required parameters."""
|
||||
params = {
|
||||
'min_value': 20.0,
|
||||
'max_value': 50.0
|
||||
# Missing step, time_step, delay_time
|
||||
}
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
ParameterValidator.validate_variation_params(
|
||||
params,
|
||||
VariationType.CHANGE_CURRENT_LD1
|
||||
)
|
||||
assert "required" in str(exc_info.value).lower()
|
||||
|
||||
def test_invalid_variation_type(self):
|
||||
"""Test invalid variation type."""
|
||||
params = {
|
||||
'min_value': 20.0,
|
||||
'max_value': 50.0,
|
||||
'step': 0.5,
|
||||
'time_step': 50,
|
||||
'delay_time': 5
|
||||
}
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
ParameterValidator.validate_variation_params(
|
||||
params,
|
||||
"INVALID_TYPE"
|
||||
)
|
||||
assert "variation type" in str(exc_info.value).lower()
|
||||
|
||||
|
||||
class TestTimeParameterValidation:
|
||||
"""Test time parameter validation."""
|
||||
|
||||
def test_valid_time_params(self):
|
||||
"""Test valid time parameters."""
|
||||
step_time, delay_time = ParameterValidator.validate_time_params(50, 5)
|
||||
assert step_time == 50
|
||||
assert delay_time == 5
|
||||
|
||||
step_time, delay_time = ParameterValidator.validate_time_params(20, 3)
|
||||
assert step_time == 20
|
||||
assert delay_time == 3
|
||||
|
||||
step_time, delay_time = ParameterValidator.validate_time_params(100, 10)
|
||||
assert step_time == 100
|
||||
assert delay_time == 10
|
||||
|
||||
def test_time_step_below_minimum(self):
|
||||
"""Test time step below minimum."""
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
ParameterValidator.validate_time_params(10, 5) # Min is 20
|
||||
assert "time step" in str(exc_info.value).lower()
|
||||
assert "20" in str(exc_info.value)
|
||||
|
||||
def test_time_step_above_maximum(self):
|
||||
"""Test time step above maximum."""
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
ParameterValidator.validate_time_params(150, 5) # Max is 100
|
||||
assert "time step" in str(exc_info.value).lower()
|
||||
assert "100" in str(exc_info.value)
|
||||
|
||||
def test_delay_time_below_minimum(self):
|
||||
"""Test delay time below minimum."""
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
ParameterValidator.validate_time_params(50, 1) # Min is 3
|
||||
assert "delay" in str(exc_info.value).lower()
|
||||
assert "3" in str(exc_info.value)
|
||||
|
||||
def test_delay_time_above_maximum(self):
|
||||
"""Test delay time above maximum."""
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
ParameterValidator.validate_time_params(50, 15) # Max is 10
|
||||
assert "delay" in str(exc_info.value).lower()
|
||||
assert "10" in str(exc_info.value)
|
||||
|
||||
def test_time_params_invalid_type(self):
|
||||
"""Test invalid type for time parameters."""
|
||||
with pytest.raises(InvalidParameterError):
|
||||
ParameterValidator.validate_time_params("50", 5)
|
||||
|
||||
with pytest.raises(InvalidParameterError):
|
||||
ParameterValidator.validate_time_params(50, [5])
|
||||
|
||||
def test_time_params_float_conversion(self):
|
||||
"""Test float to int conversion for time parameters."""
|
||||
step_time, delay_time = ParameterValidator.validate_time_params(50.7, 5.2)
|
||||
assert step_time == 50 # Should be truncated to int
|
||||
assert delay_time == 5
|
||||
|
||||
|
||||
class TestManualModeValidation:
|
||||
"""Test manual mode parameter validation."""
|
||||
|
||||
def test_validate_all_manual_params(self):
|
||||
"""Test validation of all manual mode parameters at once."""
|
||||
result = ParameterValidator.validate_manual_mode_params(
|
||||
temp1=25.0,
|
||||
temp2=30.0,
|
||||
current1=40.0,
|
||||
current2=35.0
|
||||
)
|
||||
assert result['temp1'] == 25.0
|
||||
assert result['temp2'] == 30.0
|
||||
assert result['current1'] == 40.0
|
||||
assert result['current2'] == 35.0
|
||||
|
||||
def test_manual_mode_invalid_combination(self):
|
||||
"""Test invalid parameter combinations in manual mode."""
|
||||
# One invalid parameter should fail all validation
|
||||
with pytest.raises(ValidationError):
|
||||
ParameterValidator.validate_manual_mode_params(
|
||||
temp1=25.0,
|
||||
temp2=30.0,
|
||||
current1=70.0, # Too high
|
||||
current2=35.0
|
||||
)
|
||||
|
||||
def test_manual_mode_boundary_values(self):
|
||||
"""Test boundary values for manual mode."""
|
||||
# All minimum values
|
||||
result = ParameterValidator.validate_manual_mode_params(
|
||||
temp1=15.0,
|
||||
temp2=15.0,
|
||||
current1=15.0,
|
||||
current2=15.0
|
||||
)
|
||||
assert all(v in [15.0] for v in result.values())
|
||||
|
||||
# All maximum values
|
||||
result = ParameterValidator.validate_manual_mode_params(
|
||||
temp1=40.0,
|
||||
temp2=40.0,
|
||||
current1=60.0,
|
||||
current2=60.0
|
||||
)
|
||||
assert result['temp1'] == 40.0
|
||||
assert result['temp2'] == 40.0
|
||||
assert result['current1'] == 60.0
|
||||
assert result['current2'] == 60.0
|
||||
Reference in New Issue
Block a user