initial commit
This commit is contained in:
345
tests/test_protocol.py
Normal file
345
tests/test_protocol.py
Normal file
@ -0,0 +1,345 @@
|
||||
"""
|
||||
Tests for communication protocol module.
|
||||
|
||||
Testing command encoding/decoding, CRC calculations,
|
||||
and protocol message structure.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import Mock, MagicMock, patch, call
|
||||
import struct
|
||||
from laser_control.protocol import (
|
||||
Protocol,
|
||||
CommandCode,
|
||||
TaskType,
|
||||
Message,
|
||||
Response
|
||||
)
|
||||
from laser_control.exceptions import (
|
||||
CommunicationError,
|
||||
CRCError,
|
||||
ProtocolError
|
||||
)
|
||||
|
||||
|
||||
class TestCRCCalculation:
|
||||
"""Test CRC calculation and verification."""
|
||||
|
||||
def test_crc_calculation(self):
|
||||
"""Test CRC calculation for known data (at least 2 words needed)."""
|
||||
# calculate_crc skips word 0 and XORs words 1..N
|
||||
# So we need at least 4 bytes (2 words)
|
||||
data = b'\x00\x01\x02\x03\x04\x05\x06\x07'
|
||||
crc = Protocol.calculate_crc(data)
|
||||
assert isinstance(crc, int)
|
||||
assert 0 <= crc <= 0xFFFF
|
||||
|
||||
def test_crc_consistency(self):
|
||||
"""Test CRC calculation consistency."""
|
||||
data = b'\x11\x11' + b'\x00' * 26 + b'\xFF\xFF' # 30 bytes
|
||||
crc1 = Protocol.calculate_crc(data)
|
||||
crc2 = Protocol.calculate_crc(data)
|
||||
assert crc1 == crc2
|
||||
|
||||
def test_crc_different_data(self):
|
||||
"""Test CRC differs for different data."""
|
||||
data1 = b'\x00\x00\x01\x02\x03\x04'
|
||||
data2 = b'\x00\x00\x05\x06\x07\x08'
|
||||
crc1 = Protocol.calculate_crc(data1)
|
||||
crc2 = Protocol.calculate_crc(data2)
|
||||
assert crc1 != crc2
|
||||
|
||||
|
||||
class TestMessageEncoding:
|
||||
"""Test message encoding for device commands."""
|
||||
|
||||
def test_encode_decode_enable_command(self):
|
||||
"""Test encoding DECODE_ENABLE command."""
|
||||
message = Protocol.encode_decode_enable(
|
||||
temp1=25.5,
|
||||
temp2=30.0,
|
||||
current1=40.0,
|
||||
current2=35.0,
|
||||
pi_coeff1_p=1,
|
||||
pi_coeff1_i=1,
|
||||
pi_coeff2_p=1,
|
||||
pi_coeff2_i=1,
|
||||
message_id=12345
|
||||
)
|
||||
|
||||
assert isinstance(message, bytes)
|
||||
assert len(message) == 30 # Expected message length
|
||||
|
||||
# Check command code (0x1111 stored little-endian via flipfour → 0x11 0x11)
|
||||
assert message[0] == 0x11
|
||||
assert message[1] == 0x11
|
||||
|
||||
def test_encode_task_enable_command(self):
|
||||
"""Test encoding TASK_ENABLE command."""
|
||||
message = Protocol.encode_task_enable(
|
||||
task_type=TaskType.CHANGE_CURRENT_LD1,
|
||||
static_temp1=25.0,
|
||||
static_temp2=30.0,
|
||||
static_current1=40.0,
|
||||
static_current2=35.0,
|
||||
min_value=20.0,
|
||||
max_value=50.0,
|
||||
step=0.5,
|
||||
time_step=50,
|
||||
delay_time=5,
|
||||
message_id=54321
|
||||
)
|
||||
|
||||
assert isinstance(message, bytes)
|
||||
assert len(message) > 0
|
||||
|
||||
# Check command code
|
||||
command = struct.unpack('<H', message[0:2])[0]
|
||||
assert command == CommandCode.TASK_ENABLE
|
||||
|
||||
def test_encode_trans_enable_command(self):
|
||||
"""Test encoding TRANS_ENABLE command."""
|
||||
message = Protocol.encode_trans_enable(message_id=11111)
|
||||
|
||||
# encode_trans_enable returns bytearray; ensure it's bytes-like
|
||||
assert len(message) == 2
|
||||
# 0x4444 flipped → bytes 0x44 0x44
|
||||
assert message[0] == 0x44
|
||||
assert message[1] == 0x44
|
||||
|
||||
def test_encode_state_command(self):
|
||||
"""Test encoding STATE command."""
|
||||
message = Protocol.encode_state(message_id=22222)
|
||||
|
||||
assert len(message) == 2
|
||||
# 0x6666 → 0x66 0x66
|
||||
assert message[0] == 0x66
|
||||
assert message[1] == 0x66
|
||||
|
||||
def test_encode_default_enable_command(self):
|
||||
"""Test encoding DEFAULT_ENABLE (reset) command."""
|
||||
message = Protocol.encode_default_enable(message_id=33333)
|
||||
|
||||
assert len(message) == 2
|
||||
# 0x2222 → 0x22 0x22
|
||||
assert message[0] == 0x22
|
||||
assert message[1] == 0x22
|
||||
|
||||
|
||||
class TestResponseDecoding:
|
||||
"""Test response message decoding."""
|
||||
|
||||
def test_decode_valid_response(self):
|
||||
"""Test decoding valid device response using conftest helper."""
|
||||
from tests.conftest import make_valid_response
|
||||
data = make_valid_response(message_id=12345)
|
||||
|
||||
response = Protocol.decode_response(data)
|
||||
|
||||
assert isinstance(response.current1, float)
|
||||
assert isinstance(response.temp1, float)
|
||||
assert isinstance(response.voltage_3v3, float)
|
||||
assert response.message_id == 12345
|
||||
|
||||
def test_decode_response_invalid_crc(self):
|
||||
"""Test decoding response with invalid CRC."""
|
||||
response_data = bytearray(30)
|
||||
struct.pack_into('<H', response_data, 28, 0xFFFF) # Invalid CRC
|
||||
|
||||
with pytest.raises(CRCError):
|
||||
Protocol.decode_response(bytes(response_data))
|
||||
|
||||
def test_decode_response_invalid_length(self):
|
||||
"""Test decoding response with invalid length."""
|
||||
response_data = bytes(20) # Too short (expected 30)
|
||||
|
||||
with pytest.raises(ProtocolError) as exc_info:
|
||||
Protocol.decode_response(response_data)
|
||||
# ProtocolError message includes "bytes"
|
||||
assert "bytes" in str(exc_info.value).lower()
|
||||
|
||||
def test_decode_state_response(self):
|
||||
"""Test decoding IDLE state response (2 bytes, flipfour encoded)."""
|
||||
from laser_control.protocol import _flipfour, _int_to_hex4
|
||||
# STATE IDLE = 0x0000; after flipfour it remains 0x0000
|
||||
state_bytes = bytes.fromhex(_flipfour(_int_to_hex4(0x0000)))
|
||||
state = Protocol.decode_state(state_bytes)
|
||||
assert state == 0x0000 # IDLE
|
||||
|
||||
def test_decode_state_error_conditions(self):
|
||||
"""Test decoding various error state codes."""
|
||||
from laser_control.protocol import _flipfour, _int_to_hex4
|
||||
error_codes = [0x0001, 0x0002, 0x0004, 0x0008, 0x0010]
|
||||
|
||||
for code in error_codes:
|
||||
state_bytes = bytes.fromhex(_flipfour(_int_to_hex4(code)))
|
||||
state = Protocol.decode_state(state_bytes)
|
||||
assert state == code
|
||||
|
||||
|
||||
class TestProtocolHelpers:
|
||||
"""Test protocol helper functions."""
|
||||
|
||||
def test_flipfour_byte_order(self):
|
||||
"""Test byte order flipping for little-endian conversion.
|
||||
|
||||
Protocol.flipfour() operates on 16-bit integers (byte-swap within a word).
|
||||
The underlying _flipfour() operates on 4-char hex strings (word-swap).
|
||||
"""
|
||||
from laser_control.protocol import _flipfour
|
||||
# _flipfour swaps two byte pairs: 'aabb' → 'bbaa'
|
||||
assert _flipfour('1234') == '3412'
|
||||
assert _flipfour('abcd') == 'cdab'
|
||||
assert _flipfour('0000') == '0000'
|
||||
assert _flipfour('1111') == '1111'
|
||||
|
||||
# Protocol.flipfour() byte-swaps a 16-bit int
|
||||
assert Protocol.flipfour(0x1234) == 0x3412
|
||||
assert Protocol.flipfour(0x0000) == 0x0000
|
||||
|
||||
def test_pack_float_conversion(self):
|
||||
"""Test float to bytes conversion."""
|
||||
value = 25.5
|
||||
packed = Protocol.pack_float(value)
|
||||
assert len(packed) == 4
|
||||
|
||||
# Unpack and verify
|
||||
unpacked = struct.unpack('<f', packed)[0]
|
||||
assert abs(unpacked - value) < 0.001
|
||||
|
||||
def test_pack_uint16_conversion(self):
|
||||
"""Test uint16 to bytes conversion."""
|
||||
value = 12345
|
||||
packed = Protocol.pack_uint16(value)
|
||||
assert len(packed) == 2
|
||||
|
||||
unpacked = struct.unpack('<H', packed)[0]
|
||||
assert unpacked == value
|
||||
|
||||
|
||||
class TestSerialCommunication:
|
||||
"""Test serial port communication."""
|
||||
|
||||
@patch('serial.Serial')
|
||||
def test_send_command(self, mock_serial_class):
|
||||
"""Test sending command over serial."""
|
||||
mock_serial = MagicMock()
|
||||
mock_serial_class.return_value = mock_serial
|
||||
|
||||
protocol = Protocol(port='/dev/ttyUSB0')
|
||||
protocol.connect()
|
||||
|
||||
# Send a command
|
||||
message = b'test_message'
|
||||
protocol.send_raw(message)
|
||||
|
||||
mock_serial.write.assert_called_once_with(message)
|
||||
|
||||
@patch('serial.Serial')
|
||||
def test_receive_response(self, mock_serial_class):
|
||||
"""Test receiving response from serial."""
|
||||
mock_serial = MagicMock()
|
||||
mock_serial_class.return_value = mock_serial
|
||||
|
||||
# Mock response data
|
||||
response_data = bytes(30)
|
||||
mock_serial.read.return_value = response_data
|
||||
mock_serial.in_waiting = 30
|
||||
|
||||
protocol = Protocol(port='/dev/ttyUSB0')
|
||||
protocol.connect()
|
||||
|
||||
data = protocol.receive_raw(30)
|
||||
assert data == response_data
|
||||
mock_serial.read.assert_called_once_with(30)
|
||||
|
||||
@patch('serial.Serial')
|
||||
def test_connection_failure(self, mock_serial_class):
|
||||
"""Test handling connection failure."""
|
||||
mock_serial_class.side_effect = Exception("Port not found")
|
||||
|
||||
protocol = Protocol(port='/dev/invalid')
|
||||
with pytest.raises(CommunicationError) as exc_info:
|
||||
protocol.connect()
|
||||
assert "connect" in str(exc_info.value).lower()
|
||||
|
||||
@patch('serial.Serial')
|
||||
def test_auto_port_detection(self, mock_serial_class):
|
||||
"""Test automatic port detection."""
|
||||
with patch('serial.tools.list_ports.comports') as mock_comports:
|
||||
# Mock available ports
|
||||
mock_port = MagicMock()
|
||||
mock_port.device = '/dev/ttyUSB0'
|
||||
mock_comports.return_value = [mock_port]
|
||||
|
||||
protocol = Protocol() # No port specified
|
||||
protocol.connect()
|
||||
|
||||
mock_serial_class.assert_called_with(
|
||||
port='/dev/ttyUSB0',
|
||||
baudrate=115200,
|
||||
timeout=1
|
||||
)
|
||||
|
||||
@patch('serial.Serial')
|
||||
def test_disconnect(self, mock_serial_class):
|
||||
"""Test proper disconnection."""
|
||||
mock_serial = MagicMock()
|
||||
mock_serial_class.return_value = mock_serial
|
||||
|
||||
protocol = Protocol(port='/dev/ttyUSB0')
|
||||
protocol.connect()
|
||||
protocol.disconnect()
|
||||
|
||||
mock_serial.close.assert_called_once()
|
||||
|
||||
|
||||
class TestMessageValidation:
|
||||
"""Test message validation and error handling."""
|
||||
|
||||
def test_invalid_task_type(self):
|
||||
"""Test handling of invalid task type."""
|
||||
with pytest.raises(ValueError):
|
||||
Protocol.encode_task_enable(
|
||||
task_type=99, # Invalid type
|
||||
static_temp1=25.0,
|
||||
static_temp2=30.0,
|
||||
static_current1=40.0,
|
||||
static_current2=35.0,
|
||||
min_value=20.0,
|
||||
max_value=50.0,
|
||||
step=0.5,
|
||||
time_step=50,
|
||||
delay_time=5,
|
||||
message_id=12345
|
||||
)
|
||||
|
||||
def test_message_id_overflow(self):
|
||||
"""encode_decode_enable wraps message_id to 16-bit boundary."""
|
||||
# Message ID > 0xFFFF should wrap (& 0xFFFF in controller)
|
||||
large_id = 0x10000 + 123
|
||||
wrapped = large_id & 0xFFFF
|
||||
message = Protocol.encode_decode_enable(
|
||||
temp1=25.0, temp2=30.0,
|
||||
current1=40.0, current2=35.0,
|
||||
pi_coeff1_p=1, pi_coeff1_i=1,
|
||||
pi_coeff2_p=1, pi_coeff2_i=1,
|
||||
message_id=wrapped,
|
||||
)
|
||||
assert isinstance(message, bytes)
|
||||
assert len(message) == 30
|
||||
|
||||
def test_negative_values_handling(self):
|
||||
"""Test handling of negative values where not allowed."""
|
||||
with pytest.raises(ValueError):
|
||||
Protocol.encode_decode_enable(
|
||||
temp1=25.0,
|
||||
temp2=30.0,
|
||||
current1=-10.0, # Negative current
|
||||
current2=35.0,
|
||||
pi_coeff1_p=1.0,
|
||||
pi_coeff1_i=0.5,
|
||||
pi_coeff2_p=1.0,
|
||||
pi_coeff2_i=0.5,
|
||||
message_id=12345
|
||||
)
|
||||
Reference in New Issue
Block a user