diff --git a/embedded-service/src/hid/command.rs b/embedded-service/src/hid/command.rs index 0798995d6..340aed334 100644 --- a/embedded-service/src/hid/command.rs +++ b/embedded-service/src/hid/command.rs @@ -225,7 +225,14 @@ impl Opcode { #[allow(missing_docs)] pub enum Command<'a> { Reset, - GetReport(ReportType, ReportId), + GetReport { + /// Report type (Input or Feature) + report_type: ReportType, + /// Report ID + report_id: ReportId, + /// Expected payload size in bytes, used to constrain the response read length + expected_payload_size: Option, + }, SetReport(ReportType, ReportId, SharedRef<'a, u8>), GetIdle(ReportId), SetIdle(ReportId, ReportFreq), @@ -239,7 +246,7 @@ impl From> for Opcode { fn from(value: Command<'_>) -> Self { match value { Command::Reset => Opcode::Reset, - Command::GetReport(_, _) => Opcode::GetReport, + Command::GetReport { .. } => Opcode::GetReport, Command::SetReport(_, _, _) => Opcode::SetReport, Command::GetIdle(_) => Opcode::GetIdle, Command::SetIdle(_, _) => Opcode::SetIdle, @@ -293,6 +300,7 @@ impl<'a> Command<'a> { report_type: Option, report_id: Option, data: Option>, + expected_response_size: Option, ) -> Result { if opcode.requires_report_id() && report_id.is_none() { return Err(Error::RequiresReportId); @@ -310,7 +318,11 @@ impl<'a> Command<'a> { Opcode::Reset => Command::Reset, Opcode::GetReport => { if report_type? == ReportType::Input || report_type? == ReportType::Feature { - Command::GetReport(report_type?, report_id.ok_or(Error::RequiresReportId)?) + Command::GetReport { + report_type: report_type?, + report_id: report_id.ok_or(Error::RequiresReportId)?, + expected_payload_size: expected_response_size, + } } else { return Err(Error::InvalidReportType); } @@ -476,7 +488,11 @@ impl<'a> Command<'a> { let (command_len, _) = Self::encode_basic_op(buf, Opcode::Reset)?; len += command_len; } - Command::GetReport(report_type, report_id) => { + Command::GetReport { + report_type, + report_id, + expected_payload_size: _, + } => { let (command_len, buf) = Self::encode_common(buf, Opcode::GetReport, Some(*report_type), *report_id)?; len += command_len; @@ -594,37 +610,57 @@ mod test { let mut test_buffer = [0u8; 7]; // Test input report - let len = Command::GetReport(ReportType::Input, REPORT_ID) - .encode_into_slice(&mut test_buffer, None, None) - .unwrap(); + let len = Command::GetReport { + report_type: ReportType::Input, + report_id: REPORT_ID, + expected_payload_size: None, + } + .encode_into_slice(&mut test_buffer, None, None) + .unwrap(); assert_eq!(&test_buffer[0..len], [0x18, 0x02]); // Test feature report test_buffer.fill(0); - let len = Command::GetReport(ReportType::Feature, REPORT_ID) - .encode_into_slice(&mut test_buffer, None, None) - .unwrap(); + let len = Command::GetReport { + report_type: ReportType::Feature, + report_id: REPORT_ID, + expected_payload_size: None, + } + .encode_into_slice(&mut test_buffer, None, None) + .unwrap(); assert_eq!(&test_buffer[0..len], [0x38, 0x02]); // Test extended report test_buffer.fill(0); - let len = Command::GetReport(ReportType::Input, EXT_REPORT_ID) - .encode_into_slice(&mut test_buffer, None, None) - .unwrap(); + let len = Command::GetReport { + report_type: ReportType::Input, + report_id: EXT_REPORT_ID, + expected_payload_size: None, + } + .encode_into_slice(&mut test_buffer, None, None) + .unwrap(); assert_eq!(&test_buffer[0..len], [0x1f, 0x02, EXTENDED_REPORT_ID]); // Test standard report id with registers test_buffer.fill(0); - let len = Command::GetReport(ReportType::Feature, REPORT_ID) - .encode_into_slice(&mut test_buffer, Some(CMD_REG), Some(DATA_REG)) - .unwrap(); + let len = Command::GetReport { + report_type: ReportType::Feature, + report_id: REPORT_ID, + expected_payload_size: Some(64), + } + .encode_into_slice(&mut test_buffer, Some(CMD_REG), Some(DATA_REG)) + .unwrap(); assert_eq!(&test_buffer[0..len], [0x05, 0x00, 0x38, 0x02, 0x06, 0x00]); // Test extended report id with registers test_buffer.fill(0); - let len = Command::GetReport(ReportType::Input, EXT_REPORT_ID) - .encode_into_slice(&mut test_buffer, Some(CMD_REG), Some(DATA_REG)) - .unwrap(); + let len = Command::GetReport { + report_type: ReportType::Input, + report_id: EXT_REPORT_ID, + expected_payload_size: None, + } + .encode_into_slice(&mut test_buffer, Some(CMD_REG), Some(DATA_REG)) + .unwrap(); assert_eq!( &test_buffer[0..len], [0x05, 0x00, 0x1f, 0x02, EXTENDED_REPORT_ID, 0x06, 0x00] diff --git a/hid-service/src/i2c/device.rs b/hid-service/src/i2c/device.rs index 85c7596f0..4263a36c5 100644 --- a/hid-service/src/i2c/device.rs +++ b/hid-service/src/i2c/device.rs @@ -9,6 +9,8 @@ use embedded_services::{error, hid, info, trace}; use crate::Error; +const LENGTH_PREFIX_SIZE: usize = 2; + /// Timeout configuration for I2C HID device operations. pub struct Config { /// Timeout for descriptor reads and commands. @@ -209,6 +211,20 @@ impl> Device { Error::Hid(hid::Error::Serialize) })?; + let (response_size, constrained) = match cmd { + hid::Command::GetReport { + expected_payload_size: Some(expected_payload_size), + .. + } => (*expected_payload_size as usize + LENGTH_PREFIX_SIZE, true), + _ => (buffer_len, false), + }; + let read_buf = + buf.get_mut(0..response_size) + .ok_or(Error::Hid(hid::Error::InvalidSize(InvalidSizeError { + expected: response_size, + actual: buffer_len, + })))?; + let mut bus = self.bus.lock().await; with_timeout( @@ -221,7 +237,7 @@ impl> Device { expected: len, actual: temp_w_buf.len(), })))?, - buf, + read_buf, ), ) .await @@ -234,7 +250,32 @@ impl> Device { Error::Bus(e) })?; - Ok(Some(Response::FeatureReport(self.buffer.reference()))) + let returned_len = if constrained { + let actual_frame_len = read_buf + .first_chunk::() + .map(|b| u16::from_le_bytes(*b) as usize) + .ok_or(Error::Hid(hid::Error::InvalidSize(InvalidSizeError { + expected: LENGTH_PREFIX_SIZE, + actual: read_buf.len(), + })))?; + if actual_frame_len < LENGTH_PREFIX_SIZE || actual_frame_len > response_size { + error!( + "Length mismatch: declared={} expected<={} min={}", + actual_frame_len, response_size, LENGTH_PREFIX_SIZE + ); + return Err(Error::Hid(hid::Error::InvalidSize(InvalidSizeError { + expected: response_size, + actual: actual_frame_len, + }))); + } + actual_frame_len + } else { + response_size + }; + + Ok(Some(Response::FeatureReport( + self.buffer.reference().slice(0..returned_len).map_err(Error::Buffer)?, + ))) } else { let len = cmd .encode_into_slice( diff --git a/hid-service/src/i2c/host.rs b/hid-service/src/i2c/host.rs index 973800437..f65b9ff1a 100644 --- a/hid-service/src/i2c/host.rs +++ b/hid-service/src/i2c/host.rs @@ -227,7 +227,7 @@ impl Host { // Create command let report_type = hid::ReportType::try_from(cmd).ok(); - let command = hid::Command::new(cmd, opcode, report_type, report_id, buffer); + let command = hid::Command::new(cmd, opcode, report_type, report_id, buffer, None); match command { Ok(command) => Ok(command), Err(e) => { diff --git a/keyboard-service/src/hid_kb.rs b/keyboard-service/src/hid_kb.rs index f72d9a471..7fb8f0b28 100644 --- a/keyboard-service/src/hid_kb.rs +++ b/keyboard-service/src/hid_kb.rs @@ -168,7 +168,11 @@ pub async fn handle_keyboard(mut hid_kb: T) -> Result { + hid::Command::GetReport { + report_type, + report_id, + expected_payload_size: _, + } => { { let report = hid_kb.get_report(report_type, report_id); let report = HidI2cReport::from_report_slice(report, max_input_len).to_bytes();