From 4c8fecc81704d7b8815edd3e5d1000930327b867 Mon Sep 17 00:00:00 2001 From: John Schock Date: Mon, 4 May 2026 13:36:34 -0700 Subject: [PATCH 1/3] Replace placeholder with uefi_hid component Remove the placeholder crate and add the uefi_hid component which consumes the HidIo protocol and produces standard UEFI input protocols (SimpleTextInput, SimpleTextInputEx, AbsolutePointer) for keyboard and pointer HID devices. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- Cargo.toml | 11 +- placeholder/Cargo.toml | 12 - placeholder/src/lib.rs | 8 - uefi_hid/Cargo.toml | 28 + uefi_hid/README.md | 98 + uefi_hid/src/hid.rs | 207 ++ uefi_hid/src/hid_io/mod.rs | 547 +++++ uefi_hid/src/hid_io/protocol.rs | 93 + uefi_hid/src/keyboard/key_queue.rs | 1436 +++++++++++++ uefi_hid/src/keyboard/layout.rs | 726 +++++++ uefi_hid/src/keyboard/mod.rs | 2190 ++++++++++++++++++++ uefi_hid/src/keyboard/simple_text_in.rs | 576 +++++ uefi_hid/src/keyboard/simple_text_in_ex.rs | 826 ++++++++ uefi_hid/src/lib.rs | 180 ++ uefi_hid/src/pointer/absolute_pointer.rs | 682 ++++++ uefi_hid/src/pointer/mod.rs | 794 +++++++ uefi_hid/src/test_stubs.rs | 71 + 17 files changed, 8463 insertions(+), 22 deletions(-) delete mode 100644 placeholder/Cargo.toml delete mode 100644 placeholder/src/lib.rs create mode 100644 uefi_hid/Cargo.toml create mode 100644 uefi_hid/README.md create mode 100644 uefi_hid/src/hid.rs create mode 100644 uefi_hid/src/hid_io/mod.rs create mode 100644 uefi_hid/src/hid_io/protocol.rs create mode 100644 uefi_hid/src/keyboard/key_queue.rs create mode 100644 uefi_hid/src/keyboard/layout.rs create mode 100644 uefi_hid/src/keyboard/mod.rs create mode 100644 uefi_hid/src/keyboard/simple_text_in.rs create mode 100644 uefi_hid/src/keyboard/simple_text_in_ex.rs create mode 100644 uefi_hid/src/lib.rs create mode 100644 uefi_hid/src/pointer/absolute_pointer.rs create mode 100644 uefi_hid/src/pointer/mod.rs create mode 100644 uefi_hid/src/test_stubs.rs diff --git a/Cargo.toml b/Cargo.toml index 66d96d2..8918ae2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [workspace] resolver = "3" -members = ["placeholder"] +members = ["uefi_hid"] [workspace.package] version = "0.0.1" @@ -11,7 +11,14 @@ description = "Patina components" repository = "https://github.com/OpenDevicePartnership/patina-components" [workspace.dependencies] -patina = { version = "21"} +hidparser = { version = "1" } +log = { version = "0.4", default-features = false } +mockall = { version = "0.13.0" } +num_enum = { version = "0.7", default-features = false } +# git path is a temporary placeholder until the appropriate protocol branches are checked in upstream. +patina = { git = "https://github.com/joschock/patina", branch = "usb_hid_protocol_defs" } +r-efi = { version = "5.0.0", default-features = false } +scroll = { version = "0.13", default-features = false, features = ["derive"] } [workspace.lints.clippy] undocumented_unsafe_blocks = "warn" diff --git a/placeholder/Cargo.toml b/placeholder/Cargo.toml deleted file mode 100644 index d932bfd..0000000 --- a/placeholder/Cargo.toml +++ /dev/null @@ -1,12 +0,0 @@ -[package] -name = "placeholder" -description = "Placeholder crate to be deleted once components are added" -version.workspace = true -license.workspace = true -edition.workspace = true -rust-version.workspace = true -repository.workspace = true -publish = false - -[features] -doc = [] diff --git a/placeholder/src/lib.rs b/placeholder/src/lib.rs deleted file mode 100644 index 3ccd8ce..0000000 --- a/placeholder/src/lib.rs +++ /dev/null @@ -1,8 +0,0 @@ -// Placeholder crate to be deleted once components are added -#![no_std] - -#[cfg(test)] -mod tests { - #[test] - fn placeholder() {} -} diff --git a/uefi_hid/Cargo.toml b/uefi_hid/Cargo.toml new file mode 100644 index 0000000..55219e9 --- /dev/null +++ b/uefi_hid/Cargo.toml @@ -0,0 +1,28 @@ +[package] +name = "uefi_hid" +description = "UEFI HID (Human Interface Device) support as a Patina component." +version.workspace = true +license.workspace = true +edition.workspace = true +rust-version.workspace = true +repository.workspace = true +publish = false + +[lints] +workspace = true + +[features] +default = ["ctrl-alt-del"] +ctrl-alt-del = [] + +[dependencies] +hidparser = { workspace = true } +log = { workspace = true } +num_enum = { workspace = true } +patina = { workspace = true } +r-efi = { workspace = true } +scroll = { workspace = true } + +[dev-dependencies] +mockall = { workspace = true } +patina = { workspace = true, features = ["mockall"] } diff --git a/uefi_hid/README.md b/uefi_hid/README.md new file mode 100644 index 0000000..8e5aa07 --- /dev/null +++ b/uefi_hid/README.md @@ -0,0 +1,98 @@ + +# UEFI HID + +## Overview + +This Patina component provides Human Interface Device (HID) support for UEFI. It consumes the +[HidIo](https://github.com/microsoft/mu_plus/blob/release/202502/HidPkg/Include/Protocol/HidIo.h) protocol and +produces standard UEFI input protocols for keyboard and pointer HID devices: + +- **SimpleTextInput** (`EFI_SIMPLE_TEXT_INPUT_PROTOCOL`) +- **SimpleTextInputEx** (`EFI_SIMPLE_TEXT_INPUT_EX_PROTOCOL`) +- **AbsolutePointer** (`EFI_ABSOLUTE_POINTER_PROTOCOL`) + +## Architecture + +The component installs a UEFI Driver Binding that manages HID device instances. When the driver is +started on a controller that exposes the HidIo protocol, it: + +1. Opens the HidIo protocol on the controller. +2. Parses the HID report descriptor to identify keyboard and pointer usages. +3. Creates the appropriate input protocol handlers (keyboard and/or pointer). +4. Installs the corresponding UEFI input protocols on the controller handle: + - **SimpleTextInput** and **SimpleTextInputEx** for keyboard devices. + - **AbsolutePointer** for pointer and touch devices. +5. Registers a report callback to receive asynchronous HID input reports. + +### Report Processing + +Incoming HID reports are buffered through a `ReportQueue` rather than being processed inline from the +HidIo producer's callback. This ensures all report processing occurs at a consistent `TPL_CALLBACK` +regardless of the producer's calling TPL: + +1. **Report callback** (any TPL): pushes raw HID report bytes onto a queue and signals a `TPL_CALLBACK` event. +2. **Event handler** (`TPL_CALLBACK`): dequeues all pending reports and dispatches them to receivers. + +## Modules + +| Module | Description | +|---|---| +| `hid` | Driver binding implementation that manages HID instances on controllers. | +| `hid_io` | HidIo protocol FFI bindings, report queue, and receiver traits (`HidIo`, `HidReportReceiver`). | +| `keyboard` | Keyboard HID handler — translates HID key reports into UEFI keystrokes using HII keyboard layouts, and produces SimpleTextInput / SimpleTextInputEx protocol interfaces. | +| `pointer` | Pointer HID handler — translates HID pointer/touch reports into absolute pointer state and produces the AbsolutePointer protocol interface. | + +## Features + +| Feature | Default | Description | +|---|---|---| +| `ctrl-alt-del` | ✅ | Enables Ctrl+Alt+Delete to trigger a system reset via UEFI Runtime Services. | + +## Dependencies + +Key crate dependencies (see `Cargo.toml` for the full list): + +- [`hidparser`](https://crates.io/crates/hidparser) — HID report descriptor parsing. +- [`patina`](https://crates.io/crates/patina) — Patina component SDK (boot services, driver binding, protocol interfaces). +- [`r-efi`](https://crates.io/crates/r-efi) — Rust UEFI type definitions. + +## Platform Integration + +To include `uefi_hid` in a Patina binary, add the crate as a dependency and register the component +in the platform's `ComponentInfo` implementation. + +1. Add the dependency to the binary crate's `Cargo.toml`: + + ```toml + [dependencies] + uefi_hid = { version = "20" } + ``` + +2. Register the component in the `components` function: + + ```rust + impl ComponentInfo for MyPlatform { + fn components(mut add: Add) { + // ...other components... + add.component(uefi_hid::UefiHidComponent); + } + } + ``` + +The driver binding will automatically attach to any controller that exposes the HidIo protocol. A +HidIo producer (e.g. a USB HID driver) must be present in the platform firmware for this component +to be functional. + +The `ctrl-alt-del` feature is enabled by default. To disable it: + +```toml +uefi_hid = { path = "../components/uefi_hid", default-features = false } +``` + +## Testing + +Unit tests use `mockall` and `patina`'s mock boot services: + +```sh +cargo test -p uefi_hid +``` diff --git a/uefi_hid/src/hid.rs b/uefi_hid/src/hid.rs new file mode 100644 index 0000000..16bd7b8 --- /dev/null +++ b/uefi_hid/src/hid.rs @@ -0,0 +1,207 @@ +//! HID driver binding implementation. +//! +//! The [`HidDriver`] implements [`patina::driver_binding::DriverBinding`] to +//! manage HID instances on controllers that support the HidIo protocol. +//! +//! When started on a controller, it creates a [`crate::hid_io::UefiHidIo`] +//! instance, instantiates keyboard and pointer receivers, initializes them +//! via `&dyn HidIo`, and starts report reception through the device. +//! +//! ## License +//! +//! Copyright (c) Microsoft Corporation. +//! +//! SPDX-License-Identifier: Apache-2.0 +//! + +use alloc::{boxed::Box, vec::Vec}; +use core::{ffi::c_void, ptr::NonNull}; + +use r_efi::{efi, protocols::device_path::Protocol as EfiDevicePathProtocol}; + +use patina::{ + BinaryGuid, boot_services::BootServices, driver_binding::DriverBinding, uefi_protocol::ProtocolInterface, +}; + +use crate::{ + hid_io::{HidReportReceiver, ReceiverFactory, UefiHidIo}, + keyboard::KeyboardHidHandler, + pointer::PointerHidHandler, +}; + +/// Per-controller context installed as a private protocol to track the HID instance. +struct HidInstance { + _device: UefiHidIo, +} + +// SAFETY: HidInstance GUID uniquely identifies this private protocol. +unsafe impl ProtocolInterface for HidInstance { + const PROTOCOL_GUID: BinaryGuid = BinaryGuid::from_string("0a87cfdb-c482-48e4-ade7-d9f99620e169"); +} + +impl HidInstance { + // Creates a new HidInstance wrapping the given device. + fn new(device: UefiHidIo) -> Self { + Self { _device: device } + } +} + +/// HID driver that implements [`DriverBinding`]. +/// +/// Creates and manages HID instances on controllers that expose the HidIo +/// protocol. Directly constructs [`UefiHidIo`] devices and creates keyboard +/// and pointer handlers as report receivers. +pub struct HidDriver { + boot_services: &'static T, + agent: efi::Handle, +} + +impl HidDriver { + /// Creates a new HID driver bound to the given agent handle. + /// + /// `agent` is the image handle for this driver, used for protocol operations. + pub fn new(boot_services: &'static T, agent: efi::Handle) -> Self { + Self { boot_services, agent } + } + + // Creates factory functions for HID report receivers. + fn new_receiver_factories(&self) -> Vec { + let bs = self.boot_services; + alloc::vec![ + Box::new(move |controller, hid_io| { + Ok(PointerHidHandler::new(bs, controller, hid_io)? as Box) + }), + Box::new(move |controller, hid_io| { + Ok(KeyboardHidHandler::new(bs, controller, hid_io)? as Box) + }), + ] + } +} + +// controller is an efi::Handle (raw pointer) from the DriverBinding trait. efi::Handle is defined as *mut c_void, but +// essentially an opaque type that happens to be a pointer. The unsafe deref warning will be resolved once latest +// r_efi with unsafe API is integrated. +#[allow(clippy::not_unsafe_ptr_arg_deref)] +// This is a wrapper trait to abstract driver binding for FFI; core logic is all tested elsewhere. +#[coverage(off)] +impl DriverBinding for HidDriver { + /// Tests if the given controller supports the HidIo protocol. + fn driver_binding_supported( + &self, + _boot_services: &'static U, + controller: efi::Handle, + _remaining_device_path: Option>, + ) -> Result { + Ok(UefiHidIo::supports(self.boot_services, self.agent, controller)) + } + + /// Starts HID support for the given controller. + /// + /// Creates a UefiHidIo device with keyboard and pointer receivers, and + /// installs a private protocol to track the instance context. + fn driver_binding_start( + &mut self, + boot_services: &'static U, + controller: efi::Handle, + _remaining_device_path: Option>, + ) -> Result<(), efi::Status> { + log::trace!("driver_binding_start: starting HID on controller {:?}", controller); + let device = UefiHidIo::new(self.boot_services, self.agent, controller, self.new_receiver_factories())?; + + let hid_instance = Box::new(HidInstance::new(device)); + boot_services.install_protocol_interface(Some(controller), hid_instance)?; + + Ok(()) + } + + /// Stops HID support for the given controller. + /// + /// Retrieves and drops the HID instance context, reclaiming all resources. + fn driver_binding_stop( + &mut self, + boot_services: &'static U, + controller: efi::Handle, + _number_of_children: usize, + _child_handle_buffer: Option>, + ) -> Result<(), efi::Status> { + log::trace!("driver_binding_stop: stopping HID on controller {:?}", controller); + // SAFETY: The private protocol was installed on this controller by start. + let hid_instance = unsafe { + boot_services.open_protocol_unchecked( + controller, + &HidInstance::::PROTOCOL_GUID, + self.agent, + controller, + efi::OPEN_PROTOCOL_GET_PROTOCOL, + ) + }? as *mut HidInstance; + + // SAFETY: Uninstalling our private protocol interface. + if let Err(status) = unsafe { + boot_services.uninstall_protocol_interface_unchecked( + controller, + &HidInstance::::PROTOCOL_GUID, + hid_instance as *mut c_void, + ) + } { + log::error!("hid::driver_binding_stop: failed to uninstall protocol: {status:x?}"); + return Err(status); + } + + // SAFETY: hid_instance was created via Box::into_raw (through install_protocol_interface) in start. + drop(unsafe { Box::from_raw(hid_instance) }); + Ok(()) + } +} + +#[cfg(test)] +mod test { + use super::HidDriver; + use crate::hid_io::protocol::HidIoProtocol; + use patina::{boot_services::MockBootServices, driver_binding::DriverBinding}; + use r_efi::efi; + + fn mock_boot_services() -> &'static mut MockBootServices { + let mut mock = MockBootServices::new(); + mock.expect_raise_tpl().returning(|_| patina::boot_services::tpl::Tpl::APPLICATION); + mock.expect_restore_tpl().returning(|_| ()); + // SAFETY: Leaked mock for test use with 'static lifetime requirement. + unsafe { Box::into_raw(Box::new(mock)).as_mut().unwrap() } + } + + #[test] + fn supported_returns_true_when_hid_io_present() { + let boot_services = mock_boot_services(); + boot_services + .expect_open_protocol::() + .withf_st(|controller, _, _, _| *controller == 0x3 as efi::Handle) + .returning(|_, _, _, _| { + // SAFETY: supports() never dereferences the protocol; zeroed is fine. + Ok(crate::test_stubs::hid_io_stub()) + }); + boot_services.expect_open_protocol::().returning(|_, _, _, _| Err(efi::Status::NOT_FOUND)); + + let hid_driver = HidDriver::new(boot_services, 0x1 as efi::Handle); + + assert_eq!(hid_driver.driver_binding_supported(boot_services, 0x2 as efi::Handle, None), Ok(false)); + assert_eq!(hid_driver.driver_binding_supported(boot_services, 0x3 as efi::Handle, None), Ok(true)); + } + + #[test] + fn start_returns_unsupported_with_no_receivers() { + let boot_services = mock_boot_services(); + // open_protocol succeeds (protocol exists) but receiver factories fail → UNSUPPORTED. + boot_services + .expect_open_protocol::() + .returning(|_, _, _, _| Ok(crate::test_stubs::hid_io_stub())); + boot_services.expect_close_protocol().returning(|_, _, _, _| Ok(())); + + let mut hid_driver = HidDriver::new(boot_services, 0x1 as efi::Handle); + + // All receiver factories fail (stub protocol has no real descriptor) → UNSUPPORTED. + assert_eq!( + hid_driver.driver_binding_start(boot_services, 0x2 as efi::Handle, None), + Err(efi::Status::UNSUPPORTED) + ); + } +} diff --git a/uefi_hid/src/hid_io/mod.rs b/uefi_hid/src/hid_io/mod.rs new file mode 100644 index 0000000..a5f84f8 --- /dev/null +++ b/uefi_hid/src/hid_io/mod.rs @@ -0,0 +1,547 @@ +//! HidIo Support. +//! +//! Abstractions for interacting with HID devices via the UEFI HidIo protocol. +//! +//! ## Architecture +//! +//! Incoming reports are buffered through a `ReportQueue` rather than being +//! processed inline from the HidIo producer's callback. This ensures all report +//! processing occurs at a consistent TPL_CALLBACK regardless of the producer's +//! calling TPL: +//! +//! 1. **Report callback** (any TPL): locks [`TplMutex`] at TPL_NOTIFY, pushes +//! raw bytes onto a [`VecDeque`], signals a TPL_CALLBACK event. +//! 2. **Event handler** (TPL_CALLBACK): locks the same mutex, dequeues all +//! pending reports, then dispatches them to all receivers. +//! +//! ## Traits +//! +//! - [`HidIo`] — narrow, receiver-facing interface for reading descriptors and +//! sending output reports. +//! - [`HidReportReceiver`] — interface for logic that receives HID reports. +//! +//! ## License +//! +//! Copyright (c) Microsoft Corporation. +//! +//! SPDX-License-Identifier: Apache-2.0 +//! +pub mod protocol; + +use alloc::{boxed::Box, collections::VecDeque, vec::Vec}; +use core::{cell::Cell, ffi::c_void, mem::ManuallyDrop, slice::from_raw_parts}; + +#[cfg(test)] +use mockall::automock; +use r_efi::efi; + +use self::protocol::{HID_IO_PROTOCOL_GUID, HidIoProtocol}; +use hidparser::ReportDescriptor; +use patina::{ + boot_services::{BootServices, event::EventType, tpl::Tpl}, + tpl_mutex::TplMutex, +}; + +/// Interface for logic that receives HID reports from a device. +#[cfg_attr(test, automock)] +pub trait HidReportReceiver { + /// Passes an incoming report to the receiver for processing. + fn receive_report(&mut self, report: &[u8], hid_io: &dyn HidIo); +} + +/// Factory function that creates a fully initialized HidReportReceiver for a controller. +/// +/// Returns the receiver on success, or an error status if the device is not supported +/// by this receiver type (e.g. wrong report descriptor). +pub type ReceiverFactory = Box Result, efi::Status>>; + +/// Receiver-facing interface for interacting with a HID device. +/// +/// Provides read-only access to the report descriptor and the ability to send +/// output reports (e.g. keyboard LED state). Does not manage device lifecycle. +#[cfg_attr(test, automock)] +pub trait HidIo { + /// Returns the parsed report descriptor for the device. + fn get_report_descriptor(&self) -> Result; + /// Sends an output report to the device. + fn set_output_report(&self, id: Option, report: &[u8]) -> Result<(), efi::Status>; +} + +// -- ReportQueue ---------------------------------------------------------- + +/// State for active report reception. +/// +/// Heap-allocated via [`Box`] for address stability. Shared between the HidIo +/// report callback (pushes reports) and the TPL_CALLBACK event handler +/// (processes them and fans out to all receivers). +struct ReportQueue { + /// Pending report bytes, protected at TPL_NOTIFY. + queue: TplMutex>, T>, + /// Boot services reference for signaling the drain event. + boot_services: &'static T, + /// TPL_CALLBACK event that triggers [`UefiHidIo::process_queued_reports`]. + process_queue_event: efi::Event, + /// The receivers that process drained reports. + receivers: Vec>, + /// Raw pointer to the HidIo protocol for passing `&dyn HidIo` to receivers. + hid_io: *const HidIoProtocol, + /// Set to `true` during drop if unregister_report_callback fails. Checked by + /// `report_callback` to bail early, preventing access to freed resources. + poisoned: Cell, +} + +// -- UefiHidIo -------------------------------------------------------------- + +/// HID device using UEFI boot services to interact with HidIo controllers. +/// +/// Reports are buffered in a [`TplMutex`]-protected queue and processed at +/// TPL_CALLBACK via a UEFI event, rather than inline from the producer's +/// callback. The device is fully active from construction; teardown happens on +/// drop. +pub struct UefiHidIo { + hid_io: *const HidIoProtocol, + boot_services: &'static T, + controller: efi::Handle, + agent: efi::Handle, + report_queue: ManuallyDrop>>, +} + +impl UefiHidIo { + /// Returns true if the given controller supports the HidIo protocol. + #[allow(clippy::not_unsafe_ptr_arg_deref)] // efi::Handle is an opaque *mut c_void, not dereferenced + pub fn supports(boot_services: &'static T, agent: efi::Handle, controller: efi::Handle) -> bool { + // SAFETY: HidIoProtocol layout matches the HidIo GUID; ProtocolInterface is correctly implemented. + // We only care that the protocol exists, we do not use the resulting reference. + unsafe { + boot_services + .open_protocol::(controller, agent, controller, efi::OPEN_PROTOCOL_GET_PROTOCOL) + .is_ok() + } + } + + /// Creates a new UefiHidIo bound to the given controller. + /// + /// Opens the device `by_driver`, runs each factory to create receivers (keeping + /// only those that succeed), creates the report queue, and registers the + /// protocol callback. Returns `UNSUPPORTED` if no receivers initialize + /// successfully. The device is released on drop. + #[allow(clippy::not_unsafe_ptr_arg_deref)] // efi::Handle is an opaque *mut c_void, not dereferenced + pub fn new( + boot_services: &'static T, + agent: efi::Handle, + controller: efi::Handle, + receiver_factories: Vec, + ) -> Result { + // SAFETY: HidIoProtocol layout matches the HidIo GUID; ProtocolInterface is correctly implemented. + // Open BY_DRIVER to ensure exclusive access to the underlying protocol. + let hid_io_raw = unsafe { + boot_services.open_protocol::(controller, agent, controller, efi::OPEN_PROTOCOL_BY_DRIVER) + }?; + + let hid_io = hid_io_raw as *const HidIoProtocol; + + // SAFETY: hid_io points to a protocol opened BY_DRIVER above; valid for our lifetime. + let hid_io_ref = unsafe { &*hid_io }; + + // Create receivers from factories, keeping only those that succeed. + let receivers: Vec<_> = + receiver_factories.into_iter().filter_map(|factory| factory(controller, hid_io_ref).ok()).collect(); + + if receivers.is_empty() { + // No receivers initialized — close the protocol and report unsupported. + log::trace!("UefiHidIo::new: no receivers initialized, returning UNSUPPORTED"); + if let Err(status) = + boot_services.close_protocol(controller, HID_IO_PROTOCOL_GUID.as_efi_guid(), agent, controller) + { + log::error!("Unexpected error closing HidIo protocol: {status:x?}"); + } + return Err(efi::Status::UNSUPPORTED); + } + + // Build the report queue. + let mut report_queue = Box::new(ReportQueue { + queue: TplMutex::new((*boot_services).clone(), Tpl::NOTIFY, VecDeque::new()), + boot_services, + process_queue_event: core::ptr::null_mut(), + receivers, + hid_io, + poisoned: Cell::new(false), + }); + + let queue_ptr: *mut ReportQueue = &mut *report_queue; + + // SAFETY: queue_ptr is heap-allocated via Box and remains valid for the event's lifetime + // (the event is closed before the Box is dropped). + let process_queue_event = unsafe { + boot_services.create_event_unchecked::>( + EventType::NOTIFY_SIGNAL, + Tpl::CALLBACK, + Some(Self::process_queued_reports), + queue_ptr, + ) + } + .inspect_err(|_| { + let _ = boot_services.close_protocol(controller, HID_IO_PROTOCOL_GUID.as_efi_guid(), agent, controller); + })?; + + report_queue.process_queue_event = process_queue_event; + + // Register the HidIo callback. This must happen after process_queue_event is stored so the + // report_callback can safely signal it. This also configures HidIo to start sending reports. + // SAFETY: hid_io points to a protocol opened BY_DRIVER above; valid for our lifetime. + let hid_io_protocol = unsafe { &*hid_io }; + // SAFETY: hid_io points to a valid protocol; callback and context are valid. + match unsafe { + (hid_io_protocol.register_report_callback)(hid_io, Self::report_callback, queue_ptr as *mut c_void) + } { + efi::Status::SUCCESS => (), + err => { + let _ = boot_services.close_event(process_queue_event); + let _ = boot_services.close_protocol(controller, HID_IO_PROTOCOL_GUID.as_efi_guid(), agent, controller); + return Err(err); + } + } + + log::trace!( + "UefiHidIo::new: initialized with {:?} receivers on controller {:?}", + report_queue.receivers.len(), + controller + ); + + Ok(Self { hid_io, boot_services, controller, agent, report_queue: ManuallyDrop::new(report_queue) }) + } + + /// HidIo protocol report callback. Enqueues the report and signals the drain event. This runs at + /// whatever TPL the HidIo instance generating the report runs at, which may vary by controller and report type. + /// + /// # Safety + /// + /// `context` must be a valid pointer to the [`ReportQueue`] that was passed during + /// `register_report_callback`. `report_buffer` must be valid for `report_buffer_size` bytes. + unsafe extern "efiapi" fn report_callback( + report_buffer_size: u16, + report_buffer: *mut c_void, + context: *mut c_void, + ) { + // SAFETY: context is a valid *mut ReportQueue set during register_report_callback in new(). + // The HidIo protocol callback signature requires *mut c_void, but this function only takes a + // shared reference; no mutable aliasing occurs because mutation goes through the TplMutex on the queue. + let report_queue = + unsafe { (context as *const ReportQueue).as_ref() }.expect("null report_callback context"); + if report_queue.poisoned.get() { + return; + } + // SAFETY: report_buffer is valid for report_buffer_size bytes per HidIo protocol contract. + let report = unsafe { from_raw_parts(report_buffer as *const u8, report_buffer_size as usize) }; + log::trace!("report_callback: received report, size: {:?}", report_buffer_size); + { + let mut queue = report_queue.queue.lock(); + queue.push_back(report.to_vec()); + } + let _ = report_queue.boot_services.signal_event(report_queue.process_queue_event); + } + + /// Event handler that dequeues pending reports and dispatches them to all receivers. + /// This runs at TPL_CALLBACK, ensuring that processing for events happens at the lowest possible TPL regardless of + /// the TPL the report was received at. + extern "efiapi" fn process_queued_reports(_event: efi::Event, context: *mut ReportQueue) { + // SAFETY: context is a valid *mut ReportQueue set during event creation. + let report_queue = unsafe { context.as_mut() }.expect("null process_queued_reports context"); + let reports: Vec> = { + let mut queue = report_queue.queue.lock(); + queue.drain(..).collect() + }; + + log::trace!("process_queued_reports: draining {:?} queued reports", reports.len()); + + // Any additional events that are queued while processing will also trigger a signal on this event, + // which will re-queue this function. So we don't need to worry about missing reports. + + // SAFETY: hid_io points to a protocol opened BY_DRIVER; valid for device lifetime. + let hid_io = unsafe { &*report_queue.hid_io }; + for report in &reports { + for receiver in &mut report_queue.receivers { + receiver.receive_report(report, hid_io); + } + } + } +} + +impl Drop for UefiHidIo { + fn drop(&mut self) { + // SAFETY: hid_io points to a protocol opened BY_DRIVER in new; valid until this drop. + let hid_io_protocol = unsafe { &*self.hid_io }; + // SAFETY: hid_io points to a valid protocol opened BY_DRIVER. + let unregister_status = + unsafe { (hid_io_protocol.unregister_report_callback)(self.hid_io, Self::report_callback) }; + + if unregister_status != efi::Status::SUCCESS { + // Callback may still fire — poison the queue so report_callback becomes a no-op, + // and leak the Box so the memory it dereferences remains valid. + self.report_queue.poisoned.set(true); + log::error!( + "Failed to unregister report callback: {unregister_status:x?}. \ + Leaking ReportQueue to prevent use-after-free." + ); + } + + let _ = self.boot_services.close_event(self.report_queue.process_queue_event); + if let Err(status) = self.boot_services.close_protocol( + self.controller, + HID_IO_PROTOCOL_GUID.as_efi_guid(), + self.agent, + self.controller, + ) { + log::error!("Unexpected error closing HidIo protocol: {status:x?}"); + } + + if unregister_status == efi::Status::SUCCESS { + // SAFETY: No more callbacks can fire; safe to free the report queue. + unsafe { ManuallyDrop::drop(&mut self.report_queue) }; + } + } +} + +#[cfg(test)] +mod test { + use alloc::{boxed::Box, vec}; + use core::{ + ffi::c_void, + sync::atomic::{AtomicPtr, Ordering}, + }; + + use r_efi::efi; + + use super::protocol::{HidIoReportCallback, HidReportType}; + + use patina::boot_services::MockBootServices; + + use super::*; + + fn mock_boot_services() -> &'static mut MockBootServices { + let mut mock = MockBootServices::new(); + mock.expect_raise_tpl().returning(|_| Tpl::APPLICATION); + mock.expect_restore_tpl().returning(|_| ()); + // SAFETY: Leaked mock for test use with 'static lifetime requirement. + unsafe { Box::into_raw(Box::new(mock)).as_mut().unwrap() } + } + + fn mock_hid_io_protocol() -> &'static mut HidIoProtocol { + crate::test_stubs::hid_io_stub() + } + + /// A factory that always succeeds, returning a no-op receiver. + fn ok_factory() -> ReceiverFactory { + Box::new(|_, _| Ok(Box::new(MockHidReportReceiver::new()) as Box)) + } + + /// A factory that always fails. + fn err_factory() -> ReceiverFactory { + Box::new(|_, _| Err(efi::Status::UNSUPPORTED)) + } + + /// Creates a UefiHidIo with a mock protocol and a single no-op receiver. + /// `setup` is called on the receiver mock before it's boxed. + fn make_hid_device( + setup: impl FnOnce(&mut MockBootServices, &mut MockHidReportReceiver), + ) -> UefiHidIo { + let boot_services = mock_boot_services(); + boot_services.expect_close_protocol().returning(|_, _, _, _| Ok(())); + boot_services.expect_close_event().returning(|_| Ok(())); + boot_services + .expect_create_event_unchecked::>() + .returning(|_, _, _, _| Ok(0xE0E as efi::Event)); + + let mut receiver = MockHidReportReceiver::new(); + setup(boot_services, &mut receiver); + + let hid_io_protocol = mock_hid_io_protocol(); + let hid_io = hid_io_protocol as *const HidIoProtocol; + + let receivers: Vec> = vec![Box::new(receiver)]; + let mut report_queue = Box::new(ReportQueue { + poisoned: Cell::new(false), + queue: TplMutex::new((*boot_services).clone(), Tpl::NOTIFY, VecDeque::new()), + boot_services, + process_queue_event: core::ptr::null_mut(), + receivers, + hid_io, + }); + report_queue.process_queue_event = 0xE0E as efi::Event; + + UefiHidIo { + hid_io, + boot_services, + controller: core::ptr::null_mut(), + agent: core::ptr::null_mut(), + report_queue: ManuallyDrop::new(report_queue), + } + } + + /// Returns a mutable pointer to the mock protocol for setting up test stubs. + /// + /// # Safety + /// + /// Only safe in tests where the protocol was leaked from `mock_hid_io_protocol()`. + /// The caller must ensure no aliasing references exist. + unsafe fn mock_protocol(hid_io: *const HidIoProtocol) -> &'static mut HidIoProtocol { + // SAFETY: hid_io was leaked from mock_hid_io_protocol and is valid for the test lifetime. + unsafe { &mut *(hid_io as *mut HidIoProtocol) } + } + + #[test] + fn new_returns_error_when_open_protocol_fails() { + let boot_services = mock_boot_services(); + boot_services.expect_open_protocol::().returning(|_, _, _, _| Err(efi::Status::NOT_FOUND)); + + let result = UefiHidIo::new(boot_services, core::ptr::null_mut(), core::ptr::null_mut(), vec![ok_factory()]); + assert_eq!(result.err(), Some(efi::Status::NOT_FOUND)); + } + + #[test] + fn new_returns_unsupported_when_no_receivers_initialize() { + let boot_services = mock_boot_services(); + boot_services.expect_open_protocol::().returning(|_, _, _, _| Ok(mock_hid_io_protocol())); + boot_services.expect_close_protocol().returning(|_, _, _, _| Ok(())); + + let result = UefiHidIo::new(boot_services, core::ptr::null_mut(), core::ptr::null_mut(), vec![err_factory()]); + assert_eq!(result.err(), Some(efi::Status::UNSUPPORTED)); + } + + #[test] + fn new_returns_unsupported_when_receivers_vec_is_empty() { + let boot_services = mock_boot_services(); + boot_services.expect_open_protocol::().returning(|_, _, _, _| Ok(mock_hid_io_protocol())); + boot_services.expect_close_protocol().returning(|_, _, _, _| Ok(())); + + let result = UefiHidIo::new(boot_services, core::ptr::null_mut(), core::ptr::null_mut(), vec![]); + assert_eq!(result.err(), Some(efi::Status::UNSUPPORTED)); + } + + #[test] + fn hid_io_set_output_report_calls_protocol() { + let device = make_hid_device(|_, _| {}); + + extern "efiapi" fn mock_set_report( + _this: *const HidIoProtocol, + report_id: u8, + report_type: HidReportType, + report_buffer_size: usize, + report_buffer: *mut c_void, + ) -> efi::Status { + assert_eq!(report_id, 5); + assert_eq!(report_type, HidReportType::OutputReport); + assert_eq!(report_buffer_size, 4); + // SAFETY: report_buffer is valid for report_buffer_size bytes, as guaranteed by the HID I/O protocol contract. + let report = unsafe { core::slice::from_raw_parts(report_buffer as *const u8, report_buffer_size) }; + assert_eq!(report, [0x00, 0x01, 0x02, 0x03]); + efi::Status::SUCCESS + } + // SAFETY: device.hid_io was leaked from mock_hid_io_protocol; no aliasing references exist. + unsafe { mock_protocol(device.hid_io) }.set_report = mock_set_report; + + // SAFETY: device.hid_io is a valid pointer leaked from mock_hid_io_protocol. + assert_eq!(unsafe { &*device.hid_io }.set_output_report(Some(5), &[0x00, 0x01, 0x02, 0x03]), Ok(())); + } + + #[test] + fn report_callback_enqueues_and_process_queued_reports_delivers_to_all() { + static CALLBACK_CONTEXT: AtomicPtr = AtomicPtr::new(core::ptr::null_mut()); + + let boot_services = mock_boot_services(); + boot_services.expect_close_protocol().returning(|_, _, _, _| Ok(())); + boot_services.expect_close_event().returning(|_| Ok(())); + boot_services + .expect_create_event_unchecked::>() + .returning(|_, _, _, _| Ok(0xE0E as efi::Event)); + boot_services.expect_signal_event().returning(|_| Ok(())); + + let hid_io_protocol = mock_hid_io_protocol(); + + extern "efiapi" fn register_cb( + _this: *const HidIoProtocol, + _callback: HidIoReportCallback, + context: *mut c_void, + ) -> efi::Status { + CALLBACK_CONTEXT.store(context, Ordering::Relaxed); + efi::Status::SUCCESS + } + hid_io_protocol.register_report_callback = register_cb; + + let hid_io = hid_io_protocol as *const HidIoProtocol; + + let mut r1 = MockHidReportReceiver::new(); + r1.expect_receive_report().withf(|report, _| report == [0x10u8, 0x20, 0x30]).times(1).return_const(()); + let mut r2 = MockHidReportReceiver::new(); + r2.expect_receive_report().withf(|report, _| report == [0x10u8, 0x20, 0x30]).times(1).return_const(()); + + let receivers: Vec> = vec![Box::new(r1), Box::new(r2)]; + let mut report_queue = Box::new(ReportQueue { + poisoned: Cell::new(false), + queue: TplMutex::new((*boot_services).clone(), Tpl::NOTIFY, VecDeque::new()), + boot_services, + process_queue_event: core::ptr::null_mut(), + receivers, + hid_io, + }); + report_queue.process_queue_event = 0xE0E as efi::Event; + + let mut device = UefiHidIo { + hid_io, + boot_services, + controller: core::ptr::null_mut(), + agent: core::ptr::null_mut(), + report_queue: ManuallyDrop::new(report_queue), + }; + + // Simulate the HidIo producer calling report_callback. + let report_data = [0x10u8, 0x20, 0x30]; + let queue_ptr: *mut ReportQueue = &mut **device.report_queue as *mut _; + // SAFETY: test-only call with valid report data and context. + unsafe { + UefiHidIo::::report_callback( + report_data.len() as u16, + report_data.as_ptr() as *mut c_void, + queue_ptr as *mut c_void, + ); + } + + // Manually invoke the process_queued_reports handler (normally triggered by the UEFI event system). + UefiHidIo::::process_queued_reports(core::ptr::null_mut(), queue_ptr); + // MockHidReportReceivers will verify expectations on drop. + } + + #[test] + fn drop_unregisters_callback_and_closes_protocol() { + let device = make_hid_device(|_, _| {}); + drop(device); + // MockBootServices expectations for close_protocol and close_event + // are verified on drop. + } + + #[test] + fn new_cleans_up_on_register_callback_failure() { + extern "efiapi" fn failing_register( + _this: *const HidIoProtocol, + _callback: HidIoReportCallback, + _context: *mut c_void, + ) -> efi::Status { + efi::Status::DEVICE_ERROR + } + + let boot_services = mock_boot_services(); + boot_services.expect_open_protocol::().returning(|_, _, _, _| { + let protocol = mock_hid_io_protocol(); + protocol.register_report_callback = failing_register; + Ok(protocol) + }); + boot_services.expect_close_protocol().returning(|_, _, _, _| Ok(())); + boot_services.expect_close_event().returning(|_| Ok(())); + boot_services + .expect_create_event_unchecked::>() + .returning(|_, _, _, _| Ok(0xE0E as efi::Event)); + + let result = UefiHidIo::new(boot_services, core::ptr::null_mut(), core::ptr::null_mut(), vec![ok_factory()]); + assert_eq!(result.err(), Some(efi::Status::DEVICE_ERROR)); + } +} diff --git a/uefi_hid/src/hid_io/protocol.rs b/uefi_hid/src/hid_io/protocol.rs new file mode 100644 index 0000000..b31cf55 --- /dev/null +++ b/uefi_hid/src/hid_io/protocol.rs @@ -0,0 +1,93 @@ +//! HidIo protocol re-exports and consumer-side helpers. +//! +//! The FFI types are defined in the [`patina::vendor_protocols::hid_io`] module. This module +//! re-exports them for internal use and provides helper functions that depend +//! on `hidparser` for parsing report descriptors. +//! +//! ## License +//! +//! Copyright (c) Microsoft Corporation. +//! +//! SPDX-License-Identifier: Apache-2.0 +//! +pub use patina::vendor_protocols::hid_io::{HID_IO_PROTOCOL_GUID, HidIoProtocol, HidIoReportCallback, HidReportType}; + +use alloc::vec; +use core::ffi::c_void; + +use r_efi::efi; + +use super::HidIo; +use hidparser::ReportDescriptor; + +/// Initial buffer size for `get_report_descriptor`. Covers virtually all real +/// devices in a single transfer; larger descriptors fall back to a second call. +const INITIAL_REPORT_DESCRIPTOR_SIZE: usize = 4096; + +/// Retrieves and parses the report descriptor via the HidIo protocol. +/// +/// Tries an initial 4KB buffer first. If the device returns `BUFFER_TOO_SMALL`, +/// re-allocates to the exact required size and retries. +fn get_report_descriptor_impl(hid_io: &HidIoProtocol) -> Result { + let mut report_descriptor_size = INITIAL_REPORT_DESCRIPTOR_SIZE; + let mut buffer = vec![0u8; report_descriptor_size]; + + // SAFETY: hid_io points to a valid HidIoProtocol; buffer and size are valid. + match unsafe { + (hid_io.get_report_descriptor)( + hid_io as *const HidIoProtocol, + &mut report_descriptor_size, + buffer.as_mut_ptr() as *mut c_void, + ) + } { + efi::Status::SUCCESS => { + buffer.truncate(report_descriptor_size); + } + efi::Status::BUFFER_TOO_SMALL => { + buffer.resize(report_descriptor_size, 0); + // SAFETY: hid_io points to a valid HidIoProtocol; buffer and size are valid. + match unsafe { + (hid_io.get_report_descriptor)( + hid_io as *const HidIoProtocol, + &mut report_descriptor_size, + buffer.as_mut_ptr() as *mut c_void, + ) + } { + efi::Status::SUCCESS => { + buffer.truncate(report_descriptor_size); + } + err => return Err(err), + } + } + err => return Err(err), + } + + hidparser::parse_report_descriptor(&buffer).map_err(|_| efi::Status::DEVICE_ERROR) +} + +/// Sends an output report through the HidIo protocol. +fn set_output_report_impl(hid_io: &HidIoProtocol, id: Option, report: &[u8]) -> Result<(), efi::Status> { + // SAFETY: hid_io points to a valid HidIoProtocol; report buffer and size are valid. + match unsafe { + (hid_io.set_report)( + hid_io as *const HidIoProtocol, + id.unwrap_or(0), + HidReportType::OutputReport, + report.len(), + report.as_ptr() as *mut c_void, + ) + } { + efi::Status::SUCCESS => Ok(()), + err => Err(err), + } +} + +impl HidIo for HidIoProtocol { + fn get_report_descriptor(&self) -> Result { + get_report_descriptor_impl(self) + } + + fn set_output_report(&self, id: Option, report: &[u8]) -> Result<(), efi::Status> { + set_output_report_impl(self, id, report) + } +} diff --git a/uefi_hid/src/keyboard/key_queue.rs b/uefi_hid/src/keyboard/key_queue.rs new file mode 100644 index 0000000..db6818b --- /dev/null +++ b/uefi_hid/src/keyboard/key_queue.rs @@ -0,0 +1,1436 @@ +//! Key queue support for HID driver. +//! +//! Manages pending keystrokes, keyboard state, and translates between HID +//! usages and EFI keyboard primitives using the active HII keyboard layout. +//! +//! ## License +//! +//! Copyright (c) Microsoft Corporation. +//! +//! SPDX-License-Identifier: Apache-2.0 +//! + +use alloc::{ + collections::{BTreeSet, VecDeque}, + vec::Vec, +}; +use core::ops::Deref; +use hidparser::report_data_types::Usage; + +use crate::keyboard::layout::{EfiKey, HiiKey, HiiKeyboardLayout, HiiNsKeyDescriptor}; + +use r_efi::protocols::{self, hii_database::*, simple_text_input::InputKey, simple_text_input_ex::*}; + +// The set of HID usages that represent modifier keys this driver is interested in. +#[rustfmt::skip] +const KEYBOARD_MODIFIERS: &[u16] = &[ + LEFT_CONTROL_MODIFIER, RIGHT_CONTROL_MODIFIER, LEFT_SHIFT_MODIFIER, RIGHT_SHIFT_MODIFIER, LEFT_ALT_MODIFIER, + RIGHT_ALT_MODIFIER, LEFT_LOGO_MODIFIER, RIGHT_LOGO_MODIFIER, MENU_MODIFIER, PRINT_MODIFIER, SYS_REQUEST_MODIFIER, + ALT_GR_MODIFIER, +]; + +// The set of HID usages that represent modifier keys that toggle state (as opposed to remain active while pressed). +const TOGGLE_MODIFIERS: &[u16] = &[NUM_LOCK_MODIFIER, CAPS_LOCK_MODIFIER, SCROLL_LOCK_MODIFIER]; + +// Shift modifiers. +const SHIFT_MODIFIERS: &[u16] = &[LEFT_SHIFT_MODIFIER, RIGHT_SHIFT_MODIFIER]; + +// Mapping from HID modifier to the corresponding key_shift_state flag. +#[rustfmt::skip] +const SHIFT_STATE_MAP: &[(u16, u32)] = &[ + (LEFT_CONTROL_MODIFIER, LEFT_CONTROL_PRESSED), + (RIGHT_CONTROL_MODIFIER, RIGHT_CONTROL_PRESSED), + (LEFT_ALT_MODIFIER, LEFT_ALT_PRESSED), + (RIGHT_ALT_MODIFIER, RIGHT_ALT_PRESSED), + (LEFT_SHIFT_MODIFIER, LEFT_SHIFT_PRESSED), + (RIGHT_SHIFT_MODIFIER, RIGHT_SHIFT_PRESSED), + (LEFT_LOGO_MODIFIER, LEFT_LOGO_PRESSED), + (RIGHT_LOGO_MODIFIER, RIGHT_LOGO_PRESSED), + (MENU_MODIFIER, MENU_KEY_PRESSED), + (SYS_REQUEST_MODIFIER, SYS_REQ_PRESSED), + (PRINT_MODIFIER, SYS_REQ_PRESSED), +]; + +// Mapping from key_toggle_state flag to the corresponding HID modifier. +#[rustfmt::skip] +const TOGGLE_STATE_MAP: &[(KeyToggleState, u16)] = &[ + (SCROLL_LOCK_ACTIVE, SCROLL_LOCK_MODIFIER), + (NUM_LOCK_ACTIVE, NUM_LOCK_MODIFIER), + (CAPS_LOCK_ACTIVE, CAPS_LOCK_MODIFIER), +]; + +/// Defines whether a key stroke represents a key being pressed (KeyDown) or released (KeyUp) +#[derive(Debug, PartialEq, Eq)] +pub(crate) enum KeyAction { + /// Key is being pressed + KeyDown, + /// Key is being released + KeyUp, +} + +/// A wrapper for the KeyData type that allows definition of the Ord trait and additional registration matching logic. +#[derive(Debug, Clone)] +pub(crate) struct OrdKeyData(pub protocols::simple_text_input_ex::KeyData); + +impl Deref for OrdKeyData { + type Target = protocols::simple_text_input_ex::KeyData; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl Ord for OrdKeyData { + fn cmp(&self, other: &Self) -> core::cmp::Ordering { + let e = self.key.unicode_char.cmp(&other.key.unicode_char); + if !e.is_eq() { + return e; + } + let e = self.key.scan_code.cmp(&other.key.scan_code); + if !e.is_eq() { + return e; + } + let e = self.key_state.key_shift_state.cmp(&other.key_state.key_shift_state); + if !e.is_eq() { + return e; + } + self.key_state.key_toggle_state.cmp(&other.key_state.key_toggle_state) + } +} + +impl PartialOrd for OrdKeyData { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl PartialEq for OrdKeyData { + fn eq(&self, other: &Self) -> bool { + self.cmp(other).is_eq() + } +} + +impl Eq for OrdKeyData {} + +impl OrdKeyData { + // Returns whether this key matches the given registration. Note that this is not a straight compare - UEFI spec + // allows for some degree of wildcard matching. Refer to UEFI spec 2.10 section 12.2.5. + pub(crate) fn matches_registered_key(&self, registration: &Self) -> bool { + // char and scan must match (per the reference implementation in the EDK2 C code). + self.key.unicode_char == registration.key.unicode_char + && self.key.scan_code == registration.key.scan_code + // shift state must be zero (wildcard) or must match. + && (registration.key_state.key_shift_state == 0 + || registration.key_state.key_shift_state == self.key_state.key_shift_state) + // toggle state must be zero (wildcard) or must match. + && (registration.key_state.key_toggle_state == 0 + || registration.key_state.key_toggle_state == self.key_state.key_toggle_state) + } +} + +/// Manages the queue of pending keystrokes. +#[derive(Debug, Default)] +pub(crate) struct KeyQueue { + layout: Option, + active_modifiers: BTreeSet, + active_ns_key: Option, + partial_key_support_active: bool, + key_queue: VecDeque, + registered_keys: BTreeSet, + notified_key_queue: VecDeque, +} + +impl KeyQueue { + /// Resets the key queue to its initial state. + pub(crate) fn reset(&mut self, extended_reset: bool) { + if extended_reset { + self.active_modifiers.clear(); + } else { + self.active_modifiers.retain(|x| modifier_to_led_usage(*x).is_some()); + } + self.active_ns_key = None; + self.partial_key_support_active = false; + self.key_queue.clear(); + } + + /// Processes a keystroke and updates the key queue. + pub(crate) fn keystroke(&mut self, key: Usage, action: KeyAction) { + log::trace!("keystroke: usage 0x{:08X}, action {:?}", u32::from(key), action); + let Some(ref active_layout) = self.layout else { + //nothing to do if no layout. This is unexpected: layout should be initialized with default if not present. + log::warn!("key_queue::keystroke: Received keystroke without layout."); + return; + }; + + let Some(efi_key) = usage_to_efi_key(key) else { + //unsupported key usage, nothing to do. + return; + }; + + // Check if it is a dependent key of a currently active "non-spacing" (ns) key. + // Non-spacing key handling is described in UEFI spec 2.10 section 33.2.4.3. + let mut current_descriptor = None; + if let Some(ref ns_key) = self.active_ns_key { + for descriptor in &ns_key.dependent_keys { + if descriptor.key == efi_key { + // found a dependent key for a previously active ns key. + // de-activate the ns key and process the dependent descriptor. + current_descriptor = Some(*descriptor); + self.active_ns_key = None; + break; + } + } + } + + // If it is not a dependent key of a currently active ns key, then check if it is a regular or ns key. + if current_descriptor.is_none() { + for key in &active_layout.keys { + match key { + HiiKey::Key(descriptor) if descriptor.key == efi_key => { + current_descriptor = Some(*descriptor); + break; + } + HiiKey::NsKey(ns_descriptor) if ns_descriptor.descriptor.key == efi_key => { + // if it is an ns_key, set it as the active ns key, and no further processing is needed. + self.active_ns_key = Some(ns_descriptor.clone()); + return; + } + _ => continue, + } + } + } + + let Some(current_descriptor) = current_descriptor else { + return; //could not find descriptor, nothing to do. + }; + + //handle modifiers that are active as long as they are pressed + if KEYBOARD_MODIFIERS.contains(¤t_descriptor.modifier) { + match action { + KeyAction::KeyDown => { + self.active_modifiers.insert(current_descriptor.modifier); + log::trace!("keystroke: modifier 0x{:04X} pressed", current_descriptor.modifier); + } + KeyAction::KeyUp => { + self.active_modifiers.remove(¤t_descriptor.modifier); + log::trace!("keystroke: modifier 0x{:04X} released", current_descriptor.modifier); + } + } + } + + //handle modifiers that toggle each time the key is pressed. + if TOGGLE_MODIFIERS.contains(¤t_descriptor.modifier) && action == KeyAction::KeyDown { + if self.active_modifiers.contains(¤t_descriptor.modifier) { + self.active_modifiers.remove(¤t_descriptor.modifier); + log::trace!("keystroke: toggle modifier 0x{:04X} deactivated", current_descriptor.modifier); + } else { + self.active_modifiers.insert(current_descriptor.modifier); + log::trace!("keystroke: toggle modifier 0x{:04X} activated", current_descriptor.modifier); + } + } + + if action == KeyAction::KeyUp { + //nothing else to do. + return; + } + + // process the keystroke to construct a KeyData item to add to the queue. + let mut key_data = protocols::simple_text_input_ex::KeyData { + key: InputKey { + unicode_char: current_descriptor.unicode, + scan_code: modifier_to_scan(current_descriptor.modifier), + }, + ..Default::default() + }; + + // retrieve relevant modifier state that may need to be applied to the key data. + let shift_active = SHIFT_MODIFIERS.iter().any(|x| self.active_modifiers.contains(x)); + let alt_gr_active = self.active_modifiers.contains(&ALT_GR_MODIFIER); + let caps_lock_active = self.active_modifiers.contains(&CAPS_LOCK_MODIFIER); + let num_lock_active = self.active_modifiers.contains(&NUM_LOCK_MODIFIER); + + // Apply the shift modifier if needed. shift_applied tracks whether shift was consumed so it can be removed from + // the key state later (see UEFI spec 2.10 section 12.2.3). + let affected_by_shift = (current_descriptor.affected_attribute & AFFECTED_BY_STANDARD_SHIFT) != 0; + let shift_applied = affected_by_shift && shift_active; + + if affected_by_shift { + match (shift_active, alt_gr_active) { + (true, true) => key_data.key.unicode_char = current_descriptor.shifted_alt_gr_unicode, + (true, false) => key_data.key.unicode_char = current_descriptor.shifted_unicode, + (false, true) => key_data.key.unicode_char = current_descriptor.alt_gr_unicode, + (false, false) => {} // unicode_char already set to default + } + } + + // if capslock is active, then invert the shift state of the key. + if (current_descriptor.affected_attribute & AFFECTED_BY_CAPS_LOCK) != 0 && caps_lock_active { + //Note: reference EDK2 implementation does not apply capslock to alt_gr. + if key_data.key.unicode_char == current_descriptor.unicode { + key_data.key.unicode_char = current_descriptor.shifted_unicode; + } else if key_data.key.unicode_char == current_descriptor.shifted_unicode { + key_data.key.unicode_char = current_descriptor.unicode; + } + } + + // for the num pad, numlock (and shift state) controls whether a number key or a control key (e.g. arrow) is queued. + if (current_descriptor.affected_attribute & AFFECTED_BY_NUM_LOCK) != 0 { + if num_lock_active && !shift_active { + key_data.key.scan_code = SCAN_NULL; + } else { + key_data.key.unicode_char = 0x0000; + } + } + + //special handling for unicode ESC (0x1B). + const ESC_UNICODE: u16 = 0x1B; + if key_data.key.unicode_char == ESC_UNICODE && key_data.key.scan_code == SCAN_NULL { + key_data.key.scan_code = SCAN_ESC; + key_data.key.unicode_char = 0x0000; + } + + if !self.partial_key_support_active && key_data.key.unicode_char == 0 && key_data.key.scan_code == SCAN_NULL { + return; // no further processing required if there is no key or scancode and partial support is not active. + } + + //initialize key state from active modifiers + key_data.key_state = self.init_key_state(); + + // if shift was applied above, then remove shift from key state. See UEFI spec 2.10 section 12.2.3. + if shift_applied { + key_data.key_state.key_shift_state &= !(LEFT_SHIFT_PRESSED | RIGHT_SHIFT_PRESSED); + } + + // if a callback has been registered matching this key, enqueue it in the callback queue. + if self.is_registered_key(key_data) { + self.notified_key_queue.push_back(key_data); + } + + // enqueue the key data. + log::trace!( + "keystroke: enqueuing key unicode=0x{:04X} scan=0x{:04X}", + key_data.key.unicode_char, + key_data.key.scan_code, + ); + self.key_queue.push_back(key_data); + } + + // Returns true if the key matches any registered notification key. + fn is_registered_key(&self, current_key: KeyData) -> bool { + for registered_key in &self.registered_keys { + if OrdKeyData(current_key).matches_registered_key(registered_key) { + return true; + } + } + false + } + + /// Returns a new KeyState reflecting the current modifier state. + pub(crate) fn init_key_state(&self) -> KeyState { + let key_shift_state = SHIFT_STATE_MAP + .iter() + .filter(|(modifier, _)| self.active_modifiers.contains(modifier)) + .fold(SHIFT_STATE_VALID, |state, (_, pressed)| state | pressed); + + let key_toggle_state = TOGGLE_STATE_MAP + .iter() + .filter(|(_, modifier)| self.active_modifiers.contains(modifier)) + .fold(TOGGLE_STATE_VALID, |state, (flag, _)| state | flag) + | if self.partial_key_support_active { KEY_STATE_EXPOSED } else { 0 }; + KeyState { key_shift_state, key_toggle_state } + } + + /// Removes and returns the next pending keystroke. + pub(crate) fn pop_key(&mut self) -> Option { + self.key_queue.pop_front() + } + + /// Returns the next pending keystroke without removing it. + pub(crate) fn peek_key(&self) -> Option { + self.key_queue.front().cloned() + } + + /// Removes and returns the next pending notify keystroke. + pub(crate) fn pop_notify_key(&mut self) -> Option { + self.notified_key_queue.pop_front() + } + + /// Returns the next pending notify keystroke without removing it. + pub(crate) fn peek_notify_key(&self) -> Option { + self.notified_key_queue.front().cloned() + } + + /// Sets the toggle state for scroll/caps/num lock and partial key exposure. + pub(crate) fn set_key_toggle_state(&mut self, toggle_state: KeyToggleState) { + log::trace!("set_key_toggle_state: 0x{:02X}", toggle_state); + for &(flag, modifier) in TOGGLE_STATE_MAP { + if (toggle_state & flag) != 0 { + self.active_modifiers.insert(modifier); + } else { + self.active_modifiers.remove(&modifier); + } + } + + self.partial_key_support_active = (toggle_state & KEY_STATE_EXPOSED) != 0; + } + + /// Returns the HID LED usages corresponding to active toggle modifiers. + pub(crate) fn active_leds(&self) -> Vec { + self.active_modifiers.iter().copied().filter_map(modifier_to_led_usage).collect() + } + + /// Returns the current keyboard layout. + pub(crate) fn layout(&self) -> Option { + self.layout.clone() + } + + /// Sets the keyboard layout used for keystroke translation. + pub(crate) fn set_layout(&mut self, new_layout: Option) { + self.layout = new_layout; + } + + /// Registers a key for notification matching. + pub(crate) fn add_notify_key(&mut self, key_data: OrdKeyData) { + self.registered_keys.insert(key_data); + } + + /// Returns whether the given usage represents a key that should support key repeat. + /// Modifier keys (Shift, Ctrl, Alt, etc.), toggle keys (CapsLock, NumLock, ScrollLock), and + /// non-spacing (dead) keys are excluded. + pub(crate) fn is_repeatable_key(&self, usage: Usage) -> bool { + let Some(ref active_layout) = self.layout else { + return false; + }; + + let Some(efi_key) = usage_to_efi_key(usage) else { + return false; + }; + + for key in &active_layout.keys { + match key { + HiiKey::Key(descriptor) if descriptor.key == efi_key => { + return !KEYBOARD_MODIFIERS.contains(&descriptor.modifier) + && !TOGGLE_MODIFIERS.contains(&descriptor.modifier); + } + HiiKey::NsKey(ns_descriptor) if ns_descriptor.descriptor.key == efi_key => { + return false; + } + _ => continue, + } + } + false + } + + /// Unregisters a notification key. + pub(crate) fn remove_notify_key(&mut self, key_data: &OrdKeyData) { + self.registered_keys.remove(key_data); + } +} + +// Helper routine that converts a HID Usage to the corresponding EfiKey. +fn usage_to_efi_key(usage: Usage) -> Option { + //Refer to UEFI spec version 2.10 figure 34.3 + match usage.into() { + 0x00070001..=0x00070003 => None, //Keyboard error codes. + 0x00070004 => Some(EfiKey::C1), + 0x00070005 => Some(EfiKey::B5), + 0x00070006 => Some(EfiKey::B3), + 0x00070007 => Some(EfiKey::C3), + 0x00070008 => Some(EfiKey::D3), + 0x00070009 => Some(EfiKey::C4), + 0x0007000A => Some(EfiKey::C5), + 0x0007000B => Some(EfiKey::C6), + 0x0007000C => Some(EfiKey::D8), + 0x0007000D => Some(EfiKey::C7), + 0x0007000E => Some(EfiKey::C8), + 0x0007000F => Some(EfiKey::C9), + 0x00070010 => Some(EfiKey::B7), + 0x00070011 => Some(EfiKey::B6), + 0x00070012 => Some(EfiKey::D9), + 0x00070013 => Some(EfiKey::D10), + 0x00070014 => Some(EfiKey::D1), + 0x00070015 => Some(EfiKey::D4), + 0x00070016 => Some(EfiKey::C2), + 0x00070017 => Some(EfiKey::D5), + 0x00070018 => Some(EfiKey::D7), + 0x00070019 => Some(EfiKey::B4), + 0x0007001A => Some(EfiKey::D2), + 0x0007001B => Some(EfiKey::B2), + 0x0007001C => Some(EfiKey::D6), + 0x0007001D => Some(EfiKey::B1), + 0x0007001E => Some(EfiKey::E1), + 0x0007001F => Some(EfiKey::E2), + 0x00070020 => Some(EfiKey::E3), + 0x00070021 => Some(EfiKey::E4), + 0x00070022 => Some(EfiKey::E5), + 0x00070023 => Some(EfiKey::E6), + 0x00070024 => Some(EfiKey::E7), + 0x00070025 => Some(EfiKey::E8), + 0x00070026 => Some(EfiKey::E9), + 0x00070027 => Some(EfiKey::E10), + 0x00070028 => Some(EfiKey::Enter), + 0x00070029 => Some(EfiKey::Esc), + 0x0007002A => Some(EfiKey::BackSpace), + 0x0007002B => Some(EfiKey::Tab), + 0x0007002C => Some(EfiKey::SpaceBar), + 0x0007002D => Some(EfiKey::E11), + 0x0007002E => Some(EfiKey::E12), + 0x0007002F => Some(EfiKey::D11), + 0x00070030 => Some(EfiKey::D12), + 0x00070031 => Some(EfiKey::D13), + 0x00070032 => Some(EfiKey::C12), + 0x00070033 => Some(EfiKey::C10), + 0x00070034 => Some(EfiKey::C11), + 0x00070035 => Some(EfiKey::E0), + 0x00070036 => Some(EfiKey::B8), + 0x00070037 => Some(EfiKey::B9), + 0x00070038 => Some(EfiKey::B10), + 0x00070039 => Some(EfiKey::CapsLock), + 0x0007003A => Some(EfiKey::F1), + 0x0007003B => Some(EfiKey::F2), + 0x0007003C => Some(EfiKey::F3), + 0x0007003D => Some(EfiKey::F4), + 0x0007003E => Some(EfiKey::F5), + 0x0007003F => Some(EfiKey::F6), + 0x00070040 => Some(EfiKey::F7), + 0x00070041 => Some(EfiKey::F8), + 0x00070042 => Some(EfiKey::F9), + 0x00070043 => Some(EfiKey::F10), + 0x00070044 => Some(EfiKey::F11), + 0x00070045 => Some(EfiKey::F12), + 0x00070046 => Some(EfiKey::Print), + 0x00070047 => Some(EfiKey::SLck), + 0x00070048 => Some(EfiKey::Pause), + 0x00070049 => Some(EfiKey::Ins), + 0x0007004A => Some(EfiKey::Home), + 0x0007004B => Some(EfiKey::PgUp), + 0x0007004C => Some(EfiKey::Del), + 0x0007004D => Some(EfiKey::End), + 0x0007004E => Some(EfiKey::PgDn), + 0x0007004F => Some(EfiKey::RightArrow), + 0x00070050 => Some(EfiKey::LeftArrow), + 0x00070051 => Some(EfiKey::DownArrow), + 0x00070052 => Some(EfiKey::UpArrow), + 0x00070053 => Some(EfiKey::NLck), + 0x00070054 => Some(EfiKey::Slash), + 0x00070055 => Some(EfiKey::Asterisk), + 0x00070056 => Some(EfiKey::Minus), + 0x00070057 => Some(EfiKey::Plus), + 0x00070058 => Some(EfiKey::Enter), + 0x00070059 => Some(EfiKey::One), + 0x0007005A => Some(EfiKey::Two), + 0x0007005B => Some(EfiKey::Three), + 0x0007005C => Some(EfiKey::Four), + 0x0007005D => Some(EfiKey::Five), + 0x0007005E => Some(EfiKey::Six), + 0x0007005F => Some(EfiKey::Seven), + 0x00070060 => Some(EfiKey::Eight), + 0x00070061 => Some(EfiKey::Nine), + 0x00070062 => Some(EfiKey::Zero), + 0x00070063 => Some(EfiKey::Period), + 0x00070064 => Some(EfiKey::B0), + 0x00070065 => Some(EfiKey::A4), + 0x00070066..=0x000700DF => None, // not used by EFI keyboard layout. + 0x000700E0 => Some(EfiKey::LCtrl), + 0x000700E1 => Some(EfiKey::LShift), + 0x000700E2 => Some(EfiKey::LAlt), + 0x000700E3 => Some(EfiKey::A0), + 0x000700E4 => Some(EfiKey::RCtrl), + 0x000700E5 => Some(EfiKey::RShift), + 0x000700E6 => Some(EfiKey::A2), + 0x000700E7 => Some(EfiKey::A3), + _ => None, // all other usages not used by EFI keyboard layout. + } +} + +//These should be defined in r_efi::protocols::simple_text_input + +/// UEFI scan code: no key pressed. +pub const SCAN_NULL: u16 = 0x0000; +/// UEFI scan code: Up arrow. +pub const SCAN_UP: u16 = 0x0001; +/// UEFI scan code: Down arrow. +pub const SCAN_DOWN: u16 = 0x0002; +/// UEFI scan code: Right arrow. +pub const SCAN_RIGHT: u16 = 0x0003; +/// UEFI scan code: Left arrow. +pub const SCAN_LEFT: u16 = 0x0004; +/// UEFI scan code: Home. +pub const SCAN_HOME: u16 = 0x0005; +/// UEFI scan code: End. +pub const SCAN_END: u16 = 0x0006; +/// UEFI scan code: Insert. +pub const SCAN_INSERT: u16 = 0x0007; +/// UEFI scan code: Delete. +pub const SCAN_DELETE: u16 = 0x0008; +/// UEFI scan code: Page Up. +pub const SCAN_PAGE_UP: u16 = 0x0009; +/// UEFI scan code: Page Down. +pub const SCAN_PAGE_DOWN: u16 = 0x000A; +/// UEFI scan code: F1. +pub const SCAN_F1: u16 = 0x000B; +/// UEFI scan code: F2. +pub const SCAN_F2: u16 = 0x000C; +/// UEFI scan code: F3. +pub const SCAN_F3: u16 = 0x000D; +/// UEFI scan code: F4. +pub const SCAN_F4: u16 = 0x000E; +/// UEFI scan code: F5. +pub const SCAN_F5: u16 = 0x000F; +/// UEFI scan code: F6. +pub const SCAN_F6: u16 = 0x0010; +/// UEFI scan code: F7. +pub const SCAN_F7: u16 = 0x0011; +/// UEFI scan code: F8. +pub const SCAN_F8: u16 = 0x0012; +/// UEFI scan code: F9. +pub const SCAN_F9: u16 = 0x0013; +/// UEFI scan code: F10. +pub const SCAN_F10: u16 = 0x0014; +/// UEFI scan code: F11. +pub const SCAN_F11: u16 = 0x0015; +/// UEFI scan code: F12. +pub const SCAN_F12: u16 = 0x0016; +/// UEFI scan code: Escape. +pub const SCAN_ESC: u16 = 0x0017; +/// UEFI scan code: Pause. +pub const SCAN_PAUSE: u16 = 0x0048; + +// helper routine that converts the given modifier to the corresponding SCAN code +fn modifier_to_scan(modifier: u16) -> u16 { + match modifier { + INSERT_MODIFIER => SCAN_INSERT, + DELETE_MODIFIER => SCAN_DELETE, + PAGE_DOWN_MODIFIER => SCAN_PAGE_DOWN, + PAGE_UP_MODIFIER => SCAN_PAGE_UP, + HOME_MODIFIER => SCAN_HOME, + END_MODIFIER => SCAN_END, + LEFT_ARROW_MODIFIER => SCAN_LEFT, + RIGHT_ARROW_MODIFIER => SCAN_RIGHT, + DOWN_ARROW_MODIFIER => SCAN_DOWN, + UP_ARROW_MODIFIER => SCAN_UP, + FUNCTION_KEY_ONE_MODIFIER => SCAN_F1, + FUNCTION_KEY_TWO_MODIFIER => SCAN_F2, + FUNCTION_KEY_THREE_MODIFIER => SCAN_F3, + FUNCTION_KEY_FOUR_MODIFIER => SCAN_F4, + FUNCTION_KEY_FIVE_MODIFIER => SCAN_F5, + FUNCTION_KEY_SIX_MODIFIER => SCAN_F6, + FUNCTION_KEY_SEVEN_MODIFIER => SCAN_F7, + FUNCTION_KEY_EIGHT_MODIFIER => SCAN_F8, + FUNCTION_KEY_NINE_MODIFIER => SCAN_F9, + FUNCTION_KEY_TEN_MODIFIER => SCAN_F10, + FUNCTION_KEY_ELEVEN_MODIFIER => SCAN_F11, + FUNCTION_KEY_TWELVE_MODIFIER => SCAN_F12, + PAUSE_MODIFIER => SCAN_PAUSE, + _ => SCAN_NULL, + } +} + +// helper routine that converts the given modifier to the corresponding HID Usage. +fn modifier_to_led_usage(modifier: u16) -> Option { + match modifier { + NUM_LOCK_MODIFIER => Some(Usage::from(0x00080001)), + CAPS_LOCK_MODIFIER => Some(Usage::from(0x00080002)), + SCROLL_LOCK_MODIFIER => Some(Usage::from(0x00080003)), + _ => None, + } +} + +#[cfg(test)] +mod test { + + use hidparser::report_data_types::Usage; + use r_efi::protocols::{ + self, + hii_database::{ + AFFECTED_BY_CAPS_LOCK, AFFECTED_BY_STANDARD_SHIFT, NS_KEY_DEPENDENCY_MODIFIER, NS_KEY_MODIFIER, + }, + }; + + use crate::keyboard::layout::{EfiKey, HiiKey, HiiKeyDescriptor, HiiNsKeyDescriptor}; + + use crate::keyboard::key_queue::{OrdKeyData, SCAN_DOWN, SCAN_END, SCAN_ESC, SCAN_NULL}; + + use super::KeyQueue; + + // HID usages for keys used in tests. + fn usage_a() -> Usage { + Usage::from(0x00070004u32) + } // EfiKey::C1 = 'a'/'A' + fn usage_1() -> Usage { + Usage::from(0x0007001Eu32) + } // EfiKey::E1 = '1'/'!' + fn usage_esc() -> Usage { + Usage::from(0x00070029u32) + } // EfiKey::Esc + fn usage_lshift() -> Usage { + Usage::from(0x000700E1u32) + } // EfiKey::LShift + fn usage_lctrl() -> Usage { + Usage::from(0x000700E0u32) + } // EfiKey::LCtrl + fn usage_capslock() -> Usage { + Usage::from(0x00070039u32) + } // EfiKey::CapsLock + fn usage_numlock() -> Usage { + Usage::from(0x00070053u32) + } // EfiKey::NLck + fn usage_scrolllock() -> Usage { + Usage::from(0x00070047u32) + } // EfiKey::SLck + fn usage_numpad1() -> Usage { + Usage::from(0x00070059u32) + } // EfiKey::One (numpad) + + fn key_queue_with_default_layout() -> KeyQueue { + let mut kq = KeyQueue::default(); + kq.set_layout(Some(crate::keyboard::layout::get_default_keyboard_layout())); + kq + } + + fn press_key(kq: &mut KeyQueue, usage: Usage) { + kq.keystroke(usage, super::KeyAction::KeyDown); + } + + fn release_key(kq: &mut KeyQueue, usage: Usage) { + kq.keystroke(usage, super::KeyAction::KeyUp); + } + + fn tap_key(kq: &mut KeyQueue, usage: Usage) { + press_key(kq, usage); + release_key(kq, usage); + } + + // convenience macro for defining HiiKeyDescriptor structures. + // note: for unicode characters, these are encoded as u16 for compliance with UEFI spec. UEFI only supports UCS-2 + // encoding - so unicode characters that require more than two bytes under UTF-16 are not supported (and will panic). + macro_rules! key_descriptor { + ($key:expr, $unicode:literal, $shifted:literal, $alt_gr:literal, $shifted_alt_gr:literal, $modifier:expr, $affected:expr ) => { + HiiKeyDescriptor { + key: $key, + unicode: $unicode.encode_utf16(&mut [0u16; 1])[0], + shifted_unicode: $shifted.encode_utf16(&mut [0u16; 1])[0], + alt_gr_unicode: $alt_gr.encode_utf16(&mut [0u16; 1])[0], + shifted_alt_gr_unicode: $shifted_alt_gr.encode_utf16(&mut [0u16; 1])[0], + modifier: $modifier, + affected_attribute: $affected, + } + }; + } + + #[test] + fn test_ord_key_comparisons() { + let mut key_data1: protocols::simple_text_input_ex::KeyData = Default::default(); + let mut key_data2: protocols::simple_text_input_ex::KeyData = Default::default(); + + assert_eq!(OrdKeyData(key_data1), OrdKeyData(key_data2)); + key_data1.key.unicode_char = 'a' as u16; + assert_ne!(OrdKeyData(key_data1), OrdKeyData(key_data2)); + key_data2.key.unicode_char = 'a' as u16; + assert_eq!(OrdKeyData(key_data1), OrdKeyData(key_data2)); + + key_data1.key.scan_code = SCAN_DOWN; + assert_ne!(OrdKeyData(key_data1), OrdKeyData(key_data2)); + key_data2.key.scan_code = SCAN_DOWN; + assert_eq!(OrdKeyData(key_data1), OrdKeyData(key_data2)); + + key_data1.key_state.key_shift_state = + protocols::simple_text_input_ex::SHIFT_STATE_VALID | protocols::simple_text_input_ex::LEFT_SHIFT_PRESSED; + assert_ne!(OrdKeyData(key_data1), OrdKeyData(key_data2)); + key_data2.key_state.key_shift_state = + protocols::simple_text_input_ex::SHIFT_STATE_VALID | protocols::simple_text_input_ex::LEFT_SHIFT_PRESSED; + assert_eq!(OrdKeyData(key_data1), OrdKeyData(key_data2)); + + key_data1.key_state.key_toggle_state = + protocols::simple_text_input_ex::TOGGLE_STATE_VALID | protocols::simple_text_input_ex::CAPS_LOCK_ACTIVE; + assert_ne!(OrdKeyData(key_data1), OrdKeyData(key_data2)); + key_data2.key_state.key_toggle_state = + protocols::simple_text_input_ex::TOGGLE_STATE_VALID | protocols::simple_text_input_ex::CAPS_LOCK_ACTIVE; + assert_eq!(OrdKeyData(key_data1), OrdKeyData(key_data2)); + + assert_eq!(OrdKeyData(key_data1).partial_cmp(&OrdKeyData(key_data2)), Some(core::cmp::Ordering::Equal)); + } + + #[test] + fn test_ns_keystroke() { + let mut key_queue = KeyQueue::default(); + + let mut ns_key_layout = crate::keyboard::layout::get_default_keyboard_layout(); + + let keys = &mut ns_key_layout.keys; + + let (index, _) = keys + .iter() + .enumerate() + .find(|(_, element)| if let HiiKey::Key(key) = element { key.key == EfiKey::E0 } else { false }) + .unwrap(); + + #[rustfmt::skip] + let ns_key = HiiKey::NsKey(HiiNsKeyDescriptor { + descriptor: + key_descriptor!(EfiKey::E0, '\0', '\0', '\0', '\0', NS_KEY_MODIFIER, 0), + dependent_keys: vec![ + key_descriptor!(EfiKey::C1, '\u{00E2}', '\u{00C2}', '\0', '\0', NS_KEY_DEPENDENCY_MODIFIER, AFFECTED_BY_STANDARD_SHIFT | AFFECTED_BY_CAPS_LOCK), + key_descriptor!(EfiKey::D3, '\u{00EA}', '\u{00CA}', '\0', '\0', NS_KEY_DEPENDENCY_MODIFIER, AFFECTED_BY_STANDARD_SHIFT | AFFECTED_BY_CAPS_LOCK), + key_descriptor!(EfiKey::D8, '\u{00EC}', '\u{00CC}', '\0', '\0', NS_KEY_DEPENDENCY_MODIFIER, AFFECTED_BY_STANDARD_SHIFT | AFFECTED_BY_CAPS_LOCK), + key_descriptor!(EfiKey::D9, '\u{00F4}', '\u{00D4}', '\0', '\0', NS_KEY_DEPENDENCY_MODIFIER, AFFECTED_BY_STANDARD_SHIFT | AFFECTED_BY_CAPS_LOCK), + key_descriptor!(EfiKey::D7, '\u{00FB}', '\u{00CB}', '\0', '\0', NS_KEY_DEPENDENCY_MODIFIER, AFFECTED_BY_STANDARD_SHIFT | AFFECTED_BY_CAPS_LOCK) + ]}); + + keys[index] = ns_key.clone(); + + key_queue.set_layout(Some(ns_key_layout)); + + let key = Usage::from(0x00070035); //E0 + + key_queue.keystroke(key, super::KeyAction::KeyDown); + key_queue.keystroke(key, super::KeyAction::KeyUp); + + let HiiKey::NsKey(expected_key) = ns_key else { panic!() }; + assert_eq!(key_queue.active_ns_key, Some(expected_key)); + + assert!(key_queue.peek_key().is_none()); + + let key = Usage::from(0x00070004); //C1 + key_queue.keystroke(key, super::KeyAction::KeyDown); + key_queue.keystroke(key, super::KeyAction::KeyUp); + + let stroke = key_queue.pop_key().unwrap(); + assert_eq!(stroke.key.unicode_char, '\u{00E2}' as u16); + } + + // --- keystroke queuing tests --- + + #[test] + fn basic_key_press_produces_correct_unicode() { + let mut kq = key_queue_with_default_layout(); + press_key(&mut kq, usage_a()); + let stroke = kq.pop_key().unwrap(); + assert_eq!(stroke.key.unicode_char, 'a' as u16); + assert_eq!(stroke.key.scan_code, SCAN_NULL); + } + + #[test] + fn key_up_does_not_enqueue() { + let mut kq = key_queue_with_default_layout(); + release_key(&mut kq, usage_a()); + assert!(kq.pop_key().is_none()); + } + + #[test] + fn shifted_key_produces_shifted_unicode() { + let mut kq = key_queue_with_default_layout(); + press_key(&mut kq, usage_lshift()); + press_key(&mut kq, usage_a()); + // skip the shift key entry if any, find the 'A' key + let stroke = kq.pop_key().unwrap(); + assert_eq!(stroke.key.unicode_char, 'A' as u16); + } + + #[test] + fn shift_removed_from_key_state_when_applied() { + let mut kq = key_queue_with_default_layout(); + press_key(&mut kq, usage_lshift()); + press_key(&mut kq, usage_a()); + let stroke = kq.pop_key().unwrap(); + // shift was consumed to produce 'A', so shift should not appear in key_state + assert_eq!( + stroke.key_state.key_shift_state + & (protocols::simple_text_input_ex::LEFT_SHIFT_PRESSED + | protocols::simple_text_input_ex::RIGHT_SHIFT_PRESSED), + 0 + ); + } + + #[test] + fn shift_not_removed_when_not_applied() { + let mut kq = key_queue_with_default_layout(); + press_key(&mut kq, usage_lshift()); + press_key(&mut kq, usage_esc()); // ESC is not AFFECTED_BY_STANDARD_SHIFT + let stroke = kq.pop_key().unwrap(); + // shift was NOT consumed, so it should remain in key_state + assert_ne!(stroke.key_state.key_shift_state & protocols::simple_text_input_ex::LEFT_SHIFT_PRESSED, 0); + } + + #[test] + fn caps_lock_inverts_to_shifted() { + let mut kq = key_queue_with_default_layout(); + // toggle caps lock on + tap_key(&mut kq, usage_capslock()); + press_key(&mut kq, usage_a()); + let stroke = kq.pop_key().unwrap(); + assert_eq!(stroke.key.unicode_char, 'A' as u16); + } + + #[test] + fn caps_lock_plus_shift_inverts_to_unshifted() { + let mut kq = key_queue_with_default_layout(); + // toggle caps lock on + tap_key(&mut kq, usage_capslock()); + press_key(&mut kq, usage_lshift()); + press_key(&mut kq, usage_a()); + let stroke = kq.pop_key().unwrap(); + assert_eq!(stroke.key.unicode_char, 'a' as u16); + } + + #[test] + fn caps_lock_does_not_affect_number_keys() { + let mut kq = key_queue_with_default_layout(); + tap_key(&mut kq, usage_capslock()); + press_key(&mut kq, usage_1()); + let stroke = kq.pop_key().unwrap(); + // '1' is AFFECTED_BY_STANDARD_SHIFT but not AFFECTED_BY_CAPS_LOCK + assert_eq!(stroke.key.unicode_char, '1' as u16); + } + + #[test] + fn num_lock_on_numpad_produces_number() { + let mut kq = key_queue_with_default_layout(); + // toggle num lock on + tap_key(&mut kq, usage_numlock()); + press_key(&mut kq, usage_numpad1()); + let stroke = kq.pop_key().unwrap(); + assert_eq!(stroke.key.unicode_char, '1' as u16); + assert_eq!(stroke.key.scan_code, SCAN_NULL); + } + + #[test] + fn num_lock_off_numpad_produces_scan_code() { + let mut kq = key_queue_with_default_layout(); + // num lock is off by default + press_key(&mut kq, usage_numpad1()); + let stroke = kq.pop_key().unwrap(); + assert_eq!(stroke.key.unicode_char, 0); + assert_eq!(stroke.key.scan_code, SCAN_END); + } + + #[test] + fn esc_produces_scan_code() { + let mut kq = key_queue_with_default_layout(); + press_key(&mut kq, usage_esc()); + let stroke = kq.pop_key().unwrap(); + assert_eq!(stroke.key.scan_code, SCAN_ESC); + assert_eq!(stroke.key.unicode_char, 0); + } + + #[test] + fn keystroke_without_layout_does_not_enqueue() { + let mut kq = KeyQueue::default(); // no layout set + press_key(&mut kq, usage_a()); + assert!(kq.pop_key().is_none()); + } + + // --- init_key_state tests --- + + #[test] + fn init_key_state_reflects_active_shift_modifiers() { + let mut kq = key_queue_with_default_layout(); + press_key(&mut kq, usage_lctrl()); + let state = kq.init_key_state(); + assert_ne!(state.key_shift_state & protocols::simple_text_input_ex::LEFT_CONTROL_PRESSED, 0); + } + + #[test] + fn init_key_state_reflects_toggle_modifiers() { + let mut kq = key_queue_with_default_layout(); + tap_key(&mut kq, usage_capslock()); + let state = kq.init_key_state(); + assert_ne!(state.key_toggle_state & protocols::simple_text_input_ex::CAPS_LOCK_ACTIVE, 0); + } + + // --- set_key_toggle_state tests --- + + #[test] + fn set_key_toggle_state_scroll_lock() { + let mut kq = KeyQueue::default(); + kq.set_key_toggle_state( + protocols::simple_text_input_ex::TOGGLE_STATE_VALID | protocols::simple_text_input_ex::SCROLL_LOCK_ACTIVE, + ); + let state = kq.init_key_state(); + assert_ne!(state.key_toggle_state & protocols::simple_text_input_ex::SCROLL_LOCK_ACTIVE, 0); + assert_eq!(state.key_toggle_state & protocols::simple_text_input_ex::NUM_LOCK_ACTIVE, 0); + } + + #[test] + fn set_key_toggle_state_num_lock() { + let mut kq = KeyQueue::default(); + kq.set_key_toggle_state( + protocols::simple_text_input_ex::TOGGLE_STATE_VALID | protocols::simple_text_input_ex::NUM_LOCK_ACTIVE, + ); + let state = kq.init_key_state(); + assert_ne!(state.key_toggle_state & protocols::simple_text_input_ex::NUM_LOCK_ACTIVE, 0); + } + + #[test] + fn set_key_toggle_state_key_state_exposed() { + let mut kq = KeyQueue::default(); + kq.set_key_toggle_state( + protocols::simple_text_input_ex::TOGGLE_STATE_VALID | protocols::simple_text_input_ex::KEY_STATE_EXPOSED, + ); + let state = kq.init_key_state(); + assert_ne!(state.key_toggle_state & protocols::simple_text_input_ex::KEY_STATE_EXPOSED, 0); + } + + #[test] + fn set_key_toggle_state_clears_previously_set_toggles() { + let mut kq = KeyQueue::default(); + kq.set_key_toggle_state( + protocols::simple_text_input_ex::TOGGLE_STATE_VALID + | protocols::simple_text_input_ex::CAPS_LOCK_ACTIVE + | protocols::simple_text_input_ex::NUM_LOCK_ACTIVE, + ); + // now clear num lock + kq.set_key_toggle_state( + protocols::simple_text_input_ex::TOGGLE_STATE_VALID | protocols::simple_text_input_ex::CAPS_LOCK_ACTIVE, + ); + let state = kq.init_key_state(); + assert_ne!(state.key_toggle_state & protocols::simple_text_input_ex::CAPS_LOCK_ACTIVE, 0); + assert_eq!(state.key_toggle_state & protocols::simple_text_input_ex::NUM_LOCK_ACTIVE, 0); + } + + // --- reset tests --- + + #[test] + fn reset_non_extended_retains_toggle_modifiers() { + let mut kq = key_queue_with_default_layout(); + // activate caps lock (toggle) and left shift (non-toggle) + tap_key(&mut kq, usage_capslock()); + press_key(&mut kq, usage_lshift()); + // drain any queued keys + while kq.pop_key().is_some() {} + + kq.reset(false); + let state = kq.init_key_state(); + // caps lock (toggle/LED modifier) should be retained + assert_ne!(state.key_toggle_state & protocols::simple_text_input_ex::CAPS_LOCK_ACTIVE, 0); + // left shift (non-toggle) should be cleared + assert_eq!(state.key_shift_state & protocols::simple_text_input_ex::LEFT_SHIFT_PRESSED, 0); + } + + #[test] + fn reset_extended_clears_all_modifiers() { + let mut kq = key_queue_with_default_layout(); + tap_key(&mut kq, usage_capslock()); + press_key(&mut kq, usage_lshift()); + while kq.pop_key().is_some() {} + + kq.reset(true); + let state = kq.init_key_state(); + assert_eq!(state.key_shift_state, protocols::simple_text_input_ex::SHIFT_STATE_VALID); + assert_eq!(state.key_toggle_state, protocols::simple_text_input_ex::TOGGLE_STATE_VALID); + } + + // --- matches_registered_key tests --- + + #[test] + fn matches_registered_key_with_wildcard_shift_state() { + let mut key_data: protocols::simple_text_input_ex::KeyData = Default::default(); + key_data.key.unicode_char = 'a' as u16; + key_data.key_state.key_shift_state = + protocols::simple_text_input_ex::SHIFT_STATE_VALID | protocols::simple_text_input_ex::LEFT_CONTROL_PRESSED; + + // registration with zero shift state should match any shift state + let mut registration: protocols::simple_text_input_ex::KeyData = Default::default(); + registration.key.unicode_char = 'a' as u16; + registration.key_state.key_shift_state = 0; + + assert!(OrdKeyData(key_data).matches_registered_key(&OrdKeyData(registration))); + } + + #[test] + fn matches_registered_key_with_mismatched_shift_state() { + let mut key_data: protocols::simple_text_input_ex::KeyData = Default::default(); + key_data.key.unicode_char = 'a' as u16; + key_data.key_state.key_shift_state = + protocols::simple_text_input_ex::SHIFT_STATE_VALID | protocols::simple_text_input_ex::LEFT_CONTROL_PRESSED; + + let mut registration: protocols::simple_text_input_ex::KeyData = Default::default(); + registration.key.unicode_char = 'a' as u16; + registration.key_state.key_shift_state = + protocols::simple_text_input_ex::SHIFT_STATE_VALID | protocols::simple_text_input_ex::RIGHT_CONTROL_PRESSED; + + assert!(!OrdKeyData(key_data).matches_registered_key(&OrdKeyData(registration))); + } + + #[test] + fn matches_registered_key_with_wildcard_toggle_state() { + let mut key_data: protocols::simple_text_input_ex::KeyData = Default::default(); + key_data.key.unicode_char = 'a' as u16; + key_data.key_state.key_toggle_state = + protocols::simple_text_input_ex::TOGGLE_STATE_VALID | protocols::simple_text_input_ex::CAPS_LOCK_ACTIVE; + + let mut registration: protocols::simple_text_input_ex::KeyData = Default::default(); + registration.key.unicode_char = 'a' as u16; + registration.key_state.key_toggle_state = 0; + + assert!(OrdKeyData(key_data).matches_registered_key(&OrdKeyData(registration))); + } + + // --- active_leds tests --- + + #[test] + fn active_leds_reflects_toggle_modifiers() { + let mut kq = key_queue_with_default_layout(); + assert!(kq.active_leds().is_empty()); + + tap_key(&mut kq, usage_capslock()); + let leds = kq.active_leds(); + assert_eq!(leds.len(), 1); + assert_eq!(leds[0], Usage::from(0x00080002)); // caps lock LED usage + } + + #[test] + fn active_leds_multiple_toggles() { + let mut kq = key_queue_with_default_layout(); + tap_key(&mut kq, usage_capslock()); + tap_key(&mut kq, usage_scrolllock()); + let leds = kq.active_leds(); + assert_eq!(leds.len(), 2); + } + + // --- registered key notification queuing --- + + #[test] + fn registered_key_enqueues_to_notify_queue() { + let mut kq = key_queue_with_default_layout(); + let mut reg_key: protocols::simple_text_input_ex::KeyData = Default::default(); + reg_key.key.unicode_char = 'a' as u16; + kq.add_notify_key(OrdKeyData(reg_key)); + + press_key(&mut kq, usage_a()); + assert!(kq.peek_notify_key().is_some()); + let notify = kq.pop_notify_key().unwrap(); + assert_eq!(notify.key.unicode_char, 'a' as u16); + } + + #[test] + fn unregistered_key_does_not_enqueue_to_notify_queue() { + let mut kq = key_queue_with_default_layout(); + press_key(&mut kq, usage_a()); + assert!(kq.peek_notify_key().is_none()); + } + + #[test] + fn num_lock_plus_shift_numpad_produces_scan_code() { + let mut kq = key_queue_with_default_layout(); + tap_key(&mut kq, usage_numlock()); + press_key(&mut kq, usage_lshift()); + press_key(&mut kq, usage_numpad1()); + let stroke = kq.pop_key().unwrap(); + assert_eq!(stroke.key.unicode_char, 0); + assert_eq!(stroke.key.scan_code, SCAN_END); + } + + #[test] + fn toggle_modifier_toggles_off_on_second_press() { + let mut kq = key_queue_with_default_layout(); + tap_key(&mut kq, usage_capslock()); + let state = kq.init_key_state(); + assert_ne!(state.key_toggle_state & protocols::simple_text_input_ex::CAPS_LOCK_ACTIVE, 0); + + tap_key(&mut kq, usage_capslock()); + let state = kq.init_key_state(); + assert_eq!(state.key_toggle_state & protocols::simple_text_input_ex::CAPS_LOCK_ACTIVE, 0); + } + + #[test] + fn partial_key_support_enqueues_empty_key() { + let mut kq = key_queue_with_default_layout(); + kq.set_key_toggle_state( + protocols::simple_text_input_ex::TOGGLE_STATE_VALID | protocols::simple_text_input_ex::KEY_STATE_EXPOSED, + ); + // left ctrl produces no unicode or scan code, but should still be enqueued with partial support + press_key(&mut kq, usage_lctrl()); + assert!(kq.pop_key().is_some()); + } + + #[test] + fn keys_dequeue_in_fifo_order() { + let mut kq = key_queue_with_default_layout(); + press_key(&mut kq, usage_a()); + press_key(&mut kq, usage_1()); + let first = kq.pop_key().unwrap(); + let second = kq.pop_key().unwrap(); + assert_eq!(first.key.unicode_char, 'a' as u16); + assert_eq!(second.key.unicode_char, '1' as u16); + assert!(kq.pop_key().is_none()); + } + + // --- usage_to_efi_key coverage --- + + #[test] + fn usage_to_efi_key_covers_all_letter_keys() { + // Test letter keys 'a'..'z' (usages 0x04..0x1D) + for usage_val in 0x00070004u32..=0x0007001Du32 { + assert!( + super::usage_to_efi_key(Usage::from(usage_val)).is_some(), + "usage_to_efi_key should return Some for usage 0x{:08X}", + usage_val, + ); + } + } + + #[test] + fn usage_to_efi_key_covers_digit_and_symbol_keys() { + // 0x1E..0x38 covers digits '1'..'0', enter, esc, backspace, tab, space, and symbols + for usage_val in 0x0007001Eu32..=0x00070038u32 { + assert!( + super::usage_to_efi_key(Usage::from(usage_val)).is_some(), + "usage_to_efi_key should return Some for usage 0x{:08X}", + usage_val, + ); + } + } + + #[test] + fn usage_to_efi_key_covers_f_keys_and_navigation() { + // 0x39..0x65 covers capslock, F1..F12, print, scroll lock, pause, ins, home, pgup, + // del, end, pgdn, arrows, numlock, numpad keys, B0, A4 + for usage_val in 0x00070039u32..=0x00070065u32 { + assert!( + super::usage_to_efi_key(Usage::from(usage_val)).is_some(), + "usage_to_efi_key should return Some for usage 0x{:08X}", + usage_val, + ); + } + } + + #[test] + fn usage_to_efi_key_covers_modifier_keys() { + // 0xE0..0xE7 covers LCtrl, LShift, LAlt, A0, RCtrl, RShift, A2, A3 + for usage_val in 0x000700E0u32..=0x000700E7u32 { + assert!( + super::usage_to_efi_key(Usage::from(usage_val)).is_some(), + "usage_to_efi_key should return Some for usage 0x{:08X}", + usage_val, + ); + } + } + + #[test] + fn usage_to_efi_key_returns_none_for_error_codes() { + for usage_val in 0x00070001u32..=0x00070003u32 { + assert!(super::usage_to_efi_key(Usage::from(usage_val)).is_none()); + } + } + + #[test] + fn usage_to_efi_key_returns_none_for_unused_range() { + assert!(super::usage_to_efi_key(Usage::from(0x00070066u32)).is_none()); + assert!(super::usage_to_efi_key(Usage::from(0x000700DFu32)).is_none()); + } + + #[test] + fn usage_to_efi_key_returns_none_for_out_of_range() { + assert!(super::usage_to_efi_key(Usage::from(0x000700F0u32)).is_none()); + assert!(super::usage_to_efi_key(Usage::from(0x00000000u32)).is_none()); + } + + // --- modifier_to_scan coverage --- + + #[test] + fn modifier_to_scan_covers_all_function_keys() { + use r_efi::protocols::hii_database::*; + let cases = [ + (INSERT_MODIFIER, super::SCAN_INSERT), + (DELETE_MODIFIER, super::SCAN_DELETE), + (PAGE_DOWN_MODIFIER, super::SCAN_PAGE_DOWN), + (PAGE_UP_MODIFIER, super::SCAN_PAGE_UP), + (HOME_MODIFIER, super::SCAN_HOME), + (END_MODIFIER, super::SCAN_END), + (LEFT_ARROW_MODIFIER, super::SCAN_LEFT), + (RIGHT_ARROW_MODIFIER, super::SCAN_RIGHT), + (DOWN_ARROW_MODIFIER, super::SCAN_DOWN), + (UP_ARROW_MODIFIER, super::SCAN_UP), + (FUNCTION_KEY_ONE_MODIFIER, super::SCAN_F1), + (FUNCTION_KEY_TWO_MODIFIER, super::SCAN_F2), + (FUNCTION_KEY_THREE_MODIFIER, super::SCAN_F3), + (FUNCTION_KEY_FOUR_MODIFIER, super::SCAN_F4), + (FUNCTION_KEY_FIVE_MODIFIER, super::SCAN_F5), + (FUNCTION_KEY_SIX_MODIFIER, super::SCAN_F6), + (FUNCTION_KEY_SEVEN_MODIFIER, super::SCAN_F7), + (FUNCTION_KEY_EIGHT_MODIFIER, super::SCAN_F8), + (FUNCTION_KEY_NINE_MODIFIER, super::SCAN_F9), + (FUNCTION_KEY_TEN_MODIFIER, super::SCAN_F10), + (FUNCTION_KEY_ELEVEN_MODIFIER, super::SCAN_F11), + (FUNCTION_KEY_TWELVE_MODIFIER, super::SCAN_F12), + (PAUSE_MODIFIER, super::SCAN_PAUSE), + ]; + for (modifier, expected_scan) in cases { + assert_eq!( + super::modifier_to_scan(modifier), + expected_scan, + "modifier_to_scan(0x{:04X}) should be 0x{:04X}", + modifier, + expected_scan + ); + } + } + + #[test] + fn modifier_to_scan_returns_null_for_unknown() { + assert_eq!(super::modifier_to_scan(0xFFFF), SCAN_NULL); + } + + // --- alt_gr key coverage --- + + #[test] + fn alt_gr_without_shift_produces_alt_gr_unicode() { + use r_efi::protocols::hii_database::ALT_GR_MODIFIER; + let mut kq = KeyQueue::default(); + // Layout with both the alt_gr modifier key and a key that has alt_gr mapping + let layout = crate::keyboard::layout::HiiKeyboardLayout { + keys: alloc::vec![ + HiiKey::Key(key_descriptor!(EfiKey::A2, '\0', '\0', '\0', '\0', ALT_GR_MODIFIER, 0)), + HiiKey::Key(key_descriptor!(EfiKey::C1, 'a', 'A', 'ä', 'Ä', 0, AFFECTED_BY_STANDARD_SHIFT)), + ], + ..crate::keyboard::layout::get_default_keyboard_layout() + }; + kq.set_layout(Some(layout)); + + // Press right alt (alt_gr), then 'a' + press_key(&mut kq, Usage::from(0x000700E6u32)); // A2 = right alt + press_key(&mut kq, usage_a()); + + let key_data = kq.pop_key().unwrap(); + assert_eq!(key_data.key.unicode_char, 'ä' as u16); + } + + #[test] + fn shift_plus_alt_gr_produces_shifted_alt_gr_unicode() { + use r_efi::protocols::hii_database::ALT_GR_MODIFIER; + let mut kq = KeyQueue::default(); + let layout = crate::keyboard::layout::HiiKeyboardLayout { + keys: alloc::vec![ + HiiKey::Key(key_descriptor!(EfiKey::A2, '\0', '\0', '\0', '\0', ALT_GR_MODIFIER, 0)), + HiiKey::Key(key_descriptor!( + EfiKey::LShift, + '\0', + '\0', + '\0', + '\0', + r_efi::protocols::hii_database::LEFT_SHIFT_MODIFIER, + 0 + )), + HiiKey::Key(key_descriptor!(EfiKey::C1, 'a', 'A', 'ä', 'Ä', 0, AFFECTED_BY_STANDARD_SHIFT)), + ], + ..crate::keyboard::layout::get_default_keyboard_layout() + }; + kq.set_layout(Some(layout)); + + press_key(&mut kq, usage_lshift()); + press_key(&mut kq, Usage::from(0x000700E6u32)); // right alt + press_key(&mut kq, usage_a()); + + let key_data = kq.pop_key().unwrap(); + assert_eq!(key_data.key.unicode_char, 'Ä' as u16); + } + + // --- modifier key up removes from active state --- + + #[test] + fn modifier_key_up_clears_shift_state() { + let mut kq = key_queue_with_default_layout(); + kq.set_key_toggle_state( + protocols::simple_text_input_ex::TOGGLE_STATE_VALID | protocols::simple_text_input_ex::KEY_STATE_EXPOSED, + ); + + press_key(&mut kq, usage_lshift()); + let state = kq.init_key_state(); + assert_ne!(state.key_shift_state & protocols::simple_text_input_ex::LEFT_SHIFT_PRESSED, 0); + + release_key(&mut kq, usage_lshift()); + let state = kq.init_key_state(); + assert_eq!(state.key_shift_state & protocols::simple_text_input_ex::LEFT_SHIFT_PRESSED, 0); + } + + // --- keystroke without layout does nothing --- + + #[test] + fn keystroke_without_layout_does_nothing() { + let mut kq = KeyQueue::default(); + // No layout set + press_key(&mut kq, usage_a()); + assert!(kq.pop_key().is_none()); + } + + // --- unsupported usage --- + + #[test] + fn unsupported_usage_is_ignored() { + let mut kq = key_queue_with_default_layout(); + // Usage 0x00070066 is in the "not used" range + press_key(&mut kq, Usage::from(0x00070066u32)); + assert!(kq.pop_key().is_none()); + } + + // --- key not found in layout --- + + #[test] + fn key_not_in_layout_is_ignored() { + let mut kq = KeyQueue::default(); + // Set a layout with only one key + let layout = crate::keyboard::layout::HiiKeyboardLayout { + keys: alloc::vec![HiiKey::Key(key_descriptor!( + EfiKey::C1, + 'a', + 'A', + '\0', + '\0', + 0, + AFFECTED_BY_STANDARD_SHIFT + ))], + ..crate::keyboard::layout::get_default_keyboard_layout() + }; + kq.set_layout(Some(layout)); + // Press a key that maps to EfiKey::E1 which is not in this minimal layout + press_key(&mut kq, usage_1()); + assert!(kq.pop_key().is_none()); + } + + // --- is_registered_key returns false when char mismatch --- + + #[test] + fn matches_registered_key_returns_false_on_char_mismatch() { + let mut key_data: protocols::simple_text_input_ex::KeyData = Default::default(); + key_data.key.unicode_char = 'a' as u16; + let mut reg_key: protocols::simple_text_input_ex::KeyData = Default::default(); + reg_key.key.unicode_char = 'b' as u16; + assert!(!OrdKeyData(key_data).matches_registered_key(&OrdKeyData(reg_key))); + } + + #[test] + fn matches_registered_key_returns_false_on_shift_mismatch() { + let mut key_data: protocols::simple_text_input_ex::KeyData = Default::default(); + key_data.key.unicode_char = 'a' as u16; + key_data.key_state.key_shift_state = protocols::simple_text_input_ex::SHIFT_STATE_VALID; + let mut reg_key: protocols::simple_text_input_ex::KeyData = Default::default(); + reg_key.key.unicode_char = 'a' as u16; + reg_key.key_state.key_shift_state = + protocols::simple_text_input_ex::SHIFT_STATE_VALID | protocols::simple_text_input_ex::LEFT_SHIFT_PRESSED; + assert!(!OrdKeyData(key_data).matches_registered_key(&OrdKeyData(reg_key))); + } + + #[test] + fn matches_registered_key_returns_false_on_toggle_mismatch() { + let mut key_data: protocols::simple_text_input_ex::KeyData = Default::default(); + key_data.key.unicode_char = 'a' as u16; + key_data.key_state.key_toggle_state = protocols::simple_text_input_ex::TOGGLE_STATE_VALID; + let mut reg_key: protocols::simple_text_input_ex::KeyData = Default::default(); + reg_key.key.unicode_char = 'a' as u16; + reg_key.key_state.key_toggle_state = + protocols::simple_text_input_ex::TOGGLE_STATE_VALID | protocols::simple_text_input_ex::NUM_LOCK_ACTIVE; + assert!(!OrdKeyData(key_data).matches_registered_key(&OrdKeyData(reg_key))); + } + + // --- layout getter/setter --- + + #[test] + fn layout_getter_returns_set_layout() { + let mut kq = KeyQueue::default(); + assert!(kq.layout().is_none()); + let layout = crate::keyboard::layout::get_default_keyboard_layout(); + kq.set_layout(Some(layout.clone())); + assert!(kq.layout().is_some()); + } +} diff --git a/uefi_hid/src/keyboard/layout.rs b/uefi_hid/src/keyboard/layout.rs new file mode 100644 index 0000000..b55c71a --- /dev/null +++ b/uefi_hid/src/keyboard/layout.rs @@ -0,0 +1,726 @@ +//! HII Keyboard Layout support. +//! +//! Provides UEFI HII Keyboard Layout structures and serialization. +//! Folded in from the `HiiKeyboardLayout` crate in mu_plus. +//! +//! ## License +//! +//! Copyright (c) Microsoft Corporation. +//! +//! SPDX-License-Identifier: Apache-2.0 +//! + +use alloc::{format, string::String, vec, vec::Vec}; +use core::mem; + +use num_enum::TryFromPrimitive; + +use r_efi::{ + efi, + hii::{self, PACKAGE_END}, + protocols::hii_database::*, +}; +use scroll::{Pread, Pwrite, ctx}; + +/// GUID for default keyboard layout. +pub const DEFAULT_KEYBOARD_LAYOUT_GUID: efi::Guid = + efi::Guid::from_fields(0x3a4d7a7c, 0x18a, 0x4b42, 0x81, 0xb3, &[0xdc, 0x10, 0xe3, 0xb5, 0x91, 0xbd]); + +/// HII Keyboard Package List +/// Refer to UEFI spec version 2.10 section 33.3.1.2 which defines the generic header structure. This implementation +/// only supports HII Keyboard Packages; other HII package types (or mixes) are not supported. +#[derive(Debug, PartialEq, Eq)] +pub struct HiiKeyboardPkgList { + /// The GUID associated with this package list. + pub package_list_guid: efi::Guid, + /// The HiiKeyboardPkg contained in this package list. + pub package: HiiKeyboardPkg, +} + +/// HII Keyboard Package +/// Refer to UEFI spec version 2.10 section 33.3.9 which defines the keyboard package structure. +#[derive(Debug, PartialEq, Eq)] +pub struct HiiKeyboardPkg { + /// The list of keyboard layouts in this package. + pub layouts: Vec, +} + +/// HII Keyboard Layout +/// Refer to UEFI spec version 2.10 section 34.8.10 which defines the keyboard layout structure. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct HiiKeyboardLayout { + /// The unique ID associated with this keyboard layout. + pub guid: efi::Guid, + /// A list of key descriptors + pub keys: Vec, + /// A list of descriptions for this keyboard layout. + pub descriptions: Vec, +} + +/// HII Key descriptor +/// Refer to UEFI spec version 2.10 section 34.10.10 which defines the key descriptor structure. +#[derive(Debug, Pread, Pwrite, PartialEq, Eq, Clone, Copy)] +#[repr(C)] +pub struct HiiKeyDescriptor { + /// Describes the physical key on the keyboard. + pub key: EfiKey, + /// Unicode character for the key (note: UEFI only supports UCS-2 encoding). + pub unicode: u16, + /// Unicode character for the key with the shift key being held down. + pub shifted_unicode: u16, + /// Unicode character for the key with the Alt-GR being held down. + pub alt_gr_unicode: u16, + /// Unicode character for the key with the Alt-GR and shift keys being held down. + pub shifted_alt_gr_unicode: u16, + /// Modifier keys are defined to allow for special functionality that is not necessarily accomplished by a printable + /// character. Many of these modifier keys are flags to toggle certain state bits on and off inside of a keyboard + /// driver. See [`r_efi::protocols::hii_database`] for modifier definitions. + pub modifier: u16, + /// Indicates what modifiers affect this key. See [`r_efi::protocols::hii_database`] for "affected by" definitions. + pub affected_attribute: u16, +} + +/// Non-Spacing HII Key Descriptor variant. Used for "non-spacing" keys. +/// Refer to discussion in UEFI spec version 2.10 section 33.2.4.3 for information on "non-spacing" keys and how they +/// are used. +#[derive(Debug, PartialEq, Eq, Clone)] +pub struct HiiNsKeyDescriptor { + /// The descriptor for the "non-spacing key" itself. + pub descriptor: HiiKeyDescriptor, + /// The list of descriptors that are active if the "non-spacing" key has been pressed. + pub dependent_keys: Vec, +} + +/// HII Key descriptor enumeration. +/// HII spec allows for two types of key descriptors - normal and "non-spacing". +/// Refer to UEFI spec version 2.10 section 33.2.4.3 +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum HiiKey { + /// A standard key descriptor. + Key(HiiKeyDescriptor), + /// A non-spacing key descriptor. + NsKey(HiiNsKeyDescriptor), +} + +/// Enumeration of physical keys. +/// Refer to UEFI spec version 2.10 section 34.8.10 and section 33.2.4.1. +#[allow(missing_docs)] +#[rustfmt::skip] +#[derive(Debug, Clone, Copy, PartialEq, Eq, TryFromPrimitive)] +//Note: UEFI specifies this as an C enum. That means the size is a bit ambiguous; but most compilers +//will make it 32-bit, so that's what this implementation assumes. +#[repr(u32)] +pub enum EfiKey { + LCtrl = 0, A0 = 1, LAlt = 2, SpaceBar = 3, A2 = 4, A3 = 5, A4 = 6, RCtrl = 7, LeftArrow = 8, DownArrow = 9, + RightArrow = 10, Zero = 11, Period = 12, Enter = 13, LShift = 14, B0 = 15, B1 = 16, B2 = 17, B3 = 18, B4 = 19, + B5 = 20, B6 = 21, B7 = 22, B8 = 23, B9 = 24, B10 = 25, RShift = 26, UpArrow = 27, One = 28, Two = 29, Three = 30, + CapsLock = 31, C1 = 32, C2 = 33, C3 = 34, C4 = 35, C5 = 36, C6 = 37, C7 = 38, C8 = 39, C9 = 40, C10 = 41, + C11 = 42, C12 = 43, Four = 44, Five = 45, Six = 46, Plus = 47, Tab = 48, D1 = 49, D2 = 50, D3 = 51, D4 = 52, + D5 = 53, D6 = 54, D7 = 55, D8 = 56, D9 = 57, D10 = 58, D11 = 59, D12 = 60, D13 = 61, Del = 62, End = 63, + PgDn = 64, Seven = 65, Eight = 66, Nine = 67, E0 = 68, E1 = 69, E2 = 70, E3 = 71, E4 = 72, E5 = 73, E6 = 74, + E7 = 75, E8 = 76, E9 = 77, E10 = 78, E11 = 79, E12 = 80, BackSpace = 81, Ins = 82, Home = 83, PgUp = 84, + NLck = 85, Slash = 86, Asterisk = 87, Minus = 88, Esc = 89, F1 = 90, F2 = 91, F3 = 92, F4 = 93, F5 = 94, + F6 = 95, F7 = 96, F8 = 97, F9 = 98, F10 = 99, F11 = 100, F12 = 101, Print = 102, SLck = 103, Pause = 104, + Intl0 = 105, Intl1 = 106, Intl2 = 107, Intl3 = 108, Intl4 = 109, Intl5 = 110, Intl6 = 111, Intl7 = 112, + Intl8 = 113, Intl9 = 114, +} + +/// Description for a keyboard layout. +/// Refer to UEFI spec version 2.10 section 34.8.10 +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct HiiKeyboardDescription { + /// The language code for the description (e.g. "en-US") + pub language: String, + /// The description (e.g. "English Keyboard") + pub description: String, +} + +fn gread_guid(src: &[u8], offset: &mut usize) -> Result { + Ok(efi::Guid::from_fields( + src.gread(offset)?, + src.gread(offset)?, + src.gread(offset)?, + src.gread(offset)?, + src.gread(offset)?, + src.gread_with::<&[u8]>(offset, 6)? + .try_into() + .map_err(|_| scroll::Error::BadInput { size: 0, msg: "GUID node6 must be 6 bytes" })?, + )) +} + +impl ctx::TryFromCtx<'_> for HiiKeyboardPkgList { + type Error = scroll::Error; + fn try_from_ctx(src: &'_ [u8], _ctx: ()) -> Result<(Self, usize), Self::Error> { + //Note: This is not a general purpose HII package list reader: it only supports a package list with a single + //keyboard layout package in it. + let offset = &mut 0; + //EFI_HII_PACKAGE_LIST_HEADER::PackageListGuid + let guid = gread_guid(src, offset)?; + //EFI_HII_PACKAGE_LIST_HEADER::PackageLength + let _package_length: u32 = src.gread(offset)?; + + //Read HiiKeyboard Pkg + let hii_keyboard_pkg: HiiKeyboardPkg = src.gread(offset)?; + + //Read EFI_HHI_PACKAGE_END package + let _pkg_end_length_type: u32 = src.gread(offset)?; + + Ok((HiiKeyboardPkgList { package_list_guid: guid, package: hii_keyboard_pkg }, *offset)) + } +} + +impl ctx::TryIntoCtx for &HiiKeyboardPkgList { + type Error = scroll::Error; + fn try_into_ctx(self, dest: &mut [u8], _ctx: ()) -> Result { + let offset = &mut 0; + //EFI_HII_PACKAGE_LIST_HEADER::PackageListGuid + dest.gwrite(&self.package_list_guid.as_bytes()[..], offset)?; + + //EFI_HII_PACKAGE_LIST_HEADER::PackageLength will be updated at the end. + let mut package_length_offset = *offset; + *offset += 4; + + //Write HiiKeyboardPkg + dest.gwrite(&self.package, offset)?; + + //EFI_HII_PACKAGE_END + let length_type: u32 = 4 | ((PACKAGE_END as u32) << 24); + dest.gwrite(length_type, offset)?; + + //go back and update EFI_HII_PACKAGE_LIST_HEADER::PackageLength + dest.gwrite(*offset as u32, &mut package_length_offset)?; + + Ok(*offset) + } +} + +impl ctx::TryFromCtx<'_> for HiiKeyboardPkg { + type Error = scroll::Error; + fn try_from_ctx(src: &'_ [u8], ctx: ()) -> Result<(Self, usize), Self::Error> { + let offset = &mut 0; + //EFI_HII_KEYBOARD_PACKAGE_HDR::Header (bitfield as single u32) + let length_type: u32 = src.gread(offset)?; + let pkg_type = (length_type >> 24) as u8; + if pkg_type != hii::PACKAGE_KEYBOARD_LAYOUT { + return Err(scroll::Error::BadInput { size: 0, msg: "Unsupported Pkg Type" }); + } + //EFI_HII_KEYBOARD_PACKAGE_HDR::LayoutCount + let layout_count: u16 = src.gread(offset)?; + + //EFI_HII_KEYBOARD_PACKAGE_HDR::Layout[] array into vector. + let mut layouts = Vec::with_capacity(layout_count as usize); + for _ in 0..layout_count { + layouts.push(src.gread_with(offset, ctx)?); + } + + Ok((HiiKeyboardPkg { layouts }, *offset)) + } +} + +impl ctx::TryIntoCtx for &HiiKeyboardPkg { + type Error = scroll::Error; + fn try_into_ctx(self, dest: &mut [u8], _ctx: ()) -> Result { + let offset = &mut 0; + //EFI_HII_KEYBOARD_PKG_HDR::Header::Length will be updated at the end. + *offset += 4; + //EFI_HII_KEYBOARD_PKG_HDR::LayoutCount + dest.gwrite(self.layouts.len() as u16, offset)?; + //EFI_HII_KEYBOARD_PKG_HDR::Layout[] + for layout in &self.layouts { + dest.gwrite(layout, offset)?; + } + //update EFI_HII_KEYBOARD_PKG_HEADER at offset zero. + let length = *offset; + let length_type: u32 = (hii::PACKAGE_KEYBOARD_LAYOUT as u32) << 24; + let length_type = length_type | (length & 0xFFFFFF) as u32; + dest.gwrite(length_type, &mut 0)?; + + Ok(*offset) + } +} + +impl ctx::TryFromCtx<'_> for HiiKeyboardLayout { + type Error = scroll::Error; + fn try_from_ctx(src: &'_ [u8], _ctx: ()) -> Result<(Self, usize), Self::Error> { + let offset = &mut 0; + //EFI_HII_KEYBOARD_LAYOUT::LayoutLength + let _layout_length: u16 = src.gread(offset)?; + //EFI_HII_KEYBOARD_LAYOUT::Guid + let guid = gread_guid(src, offset)?; + //EFI_HII_KEYBOARD_LAYOUT::LayoutDescriptorStringOffset + let layout_descriptor_string_offset: u32 = src.gread(offset)?; + //EFI_HII_KEYBOARD_LAYOUT::DescriptorCount + let descriptor_count: u8 = src.gread(offset)?; + + //EFI_HII_KEYBOARD_LAYOUT::Descriptors[] array into vector. Note: descriptor_count is not used for iteration + //since ns_keys may consume multiple descriptors which are included in the count, resulting in a vector of "real" + //descriptors that is smaller than the descriptor_count. + let descriptor_start = *offset; + let mut descriptors = vec![]; + while *offset < layout_descriptor_string_offset as usize { + descriptors.push(src.gread(offset)?); + } + let parsed_count = (*offset - descriptor_start) / mem::size_of::(); + if parsed_count != descriptor_count as usize { + return Err(scroll::Error::BadInput { + size: 0, + msg: "parsed descriptor byte count does not match declared descriptor_count", + }); + } + + //EFI_DESCRIPTION_STRING_BUNDLE::DescriptionCount + let description_count: u16 = src.gread(offset)?; + let mut descriptions = Vec::with_capacity(description_count as usize); + //EFI_DESCRIPTION_STRING_BUNDLE::DescriptionString[] + for _ in 0..description_count { + descriptions.push(src.gread(offset)?); + } + + Ok((HiiKeyboardLayout { guid, keys: descriptors, descriptions }, *offset)) + } +} + +impl ctx::TryIntoCtx for &HiiKeyboardLayout { + type Error = scroll::Error; + fn try_into_ctx(self, dest: &mut [u8], _ctx: ()) -> Result { + let offset = &mut 0; + //EFI_HII_KEYBOARD_LAYOUT::LayoutLength will be updated at the end. + *offset += 2; + //EFI_HII_KEYBOARD_LAYOUT::Guid + dest.gwrite(&self.guid.as_bytes()[..], offset)?; + //EFI_HII_KEYBOARD_LAYOUT::LayoutDescriptorStringOffset will be updated after writing out the descriptors. + let mut descriptor_string_offset = *offset; + *offset += 4; + + //EFI_HII_KEYBOARD_LAYOUT::DescriptorCount will be updated after writing out the descriptors. + let mut descriptor_count_offset = *offset; + *offset += 1; + + let descriptor_start = *offset; + //EFI_HII_KEYBOARD_LAYOUT::Descriptors[] + for descriptor in &self.keys { + //Note: may expand to more than one descriptor due to non-spacing keys. + dest.gwrite(descriptor, offset)?; + } + + //Go back and update EFI_HII_KEYBOARD_LAYOUT::DescriptorCount + let descriptor_count = (*offset - descriptor_start) / mem::size_of::(); + dest.gwrite(descriptor_count as u8, &mut descriptor_count_offset)?; + + //Go back and update EFI_HII_KEYBOARD_LAYOUT::LayoutDescriptorStringOffset. + dest.gwrite(*offset as u32, &mut descriptor_string_offset)?; + + //EFI_DESCRIPTION_STRING_BUNDLE::DescriptionCount + dest.gwrite(self.descriptions.len() as u16, offset)?; + + //EFI_DESCRIPTION_STRING_BUNDLE::DescriptionString[] + for description in &self.descriptions { + dest.gwrite(description, offset)?; + } + + //Go back and update EFI_HII_KEYBOARD_LAYOUT::LayoutLength + dest.gwrite(*offset as u16, &mut 0)?; + + Ok(*offset) + } +} + +impl ctx::TryFromCtx<'_> for HiiKey { + type Error = scroll::Error; + fn try_from_ctx(src: &'_ [u8], _ctx: ()) -> Result<(Self, usize), Self::Error> { + let offset = &mut 0; + let descriptor: HiiKeyDescriptor = src.gread(offset)?; + if descriptor.modifier == NS_KEY_MODIFIER { + //For Non-Spacing keys, consume descriptors until we find one without EFI_NS_KEY_DEPENDENCY_MODIFIER or run out. + //Refer to UEFI spec 2.10 section 33.2.4.3 for details. + let mut dependent_keys = vec![]; + while let Ok(dependent_key) = src.pread::(*offset) { + if dependent_key.modifier == NS_KEY_DEPENDENCY_MODIFIER { + //found a dependent descriptor. Re-read it with gread to update offset. + dependent_keys.push(src.gread(offset)?); + } else { + //found a descriptor without EFI_NS_KEY_DEPENDENCY_MODIFIER + break; + } + } + Ok((HiiKey::NsKey(HiiNsKeyDescriptor { descriptor, dependent_keys }), *offset)) + } else { + Ok((HiiKey::Key(descriptor), *offset)) + } + } +} + +impl ctx::TryIntoCtx for &HiiKey { + type Error = scroll::Error; + fn try_into_ctx(self, dest: &mut [u8], _ctx: ()) -> Result { + let offset = &mut 0; + match self { + HiiKey::Key(descriptor) => { + dest.gwrite(descriptor, offset)?; + } + HiiKey::NsKey(ns_descriptor) => { + dest.gwrite(&ns_descriptor.descriptor, offset)?; + for descriptor in &ns_descriptor.dependent_keys { + dest.gwrite(descriptor, offset)?; + } + } + } + Ok(*offset) + } +} + +impl ctx::TryFromCtx<'_> for HiiKeyboardDescription { + type Error = scroll::Error; + fn try_from_ctx(src: &'_ [u8], _ctx: ()) -> Result<(Self, usize), Self::Error> { + let offset = &mut 0; + //consume u16 characters until NULL. + let mut desc_chars = vec![]; + loop { + let desc_char: u16 = src.gread(offset)?; + if desc_char == 0 { + break; + } + desc_chars.push(desc_char); + } + //convert to string. Note: UEFI spec uses UCS-2 encoding, so all valid inputs should translate to UTF-16 without + //error. + let desc_string = String::from_utf16(&desc_chars) + .map_err(|_| scroll::Error::BadInput { size: 0, msg: "Invalid string in keyboard description." })?; + + //split the resulting string on the first space - this gives us language and description. + if let Some((lang, desc)) = desc_string.split_once(' ') { + Ok((HiiKeyboardDescription { language: String::from(lang), description: String::from(desc) }, *offset)) + } else { + Err(scroll::Error::BadInput { size: 0, msg: "No space in keyboard description." }) + } + } +} + +impl ctx::TryIntoCtx for &HiiKeyboardDescription { + type Error = scroll::Error; + fn try_into_ctx(self, dest: &mut [u8], _ctx: ()) -> Result { + let offset = &mut 0; + //Format as EFI_DESCRIPTION_STRING per UEFI spec 2.10 section 34.8.10. + let desc_string = format!("{} {}", self.language, self.description); + let mut characters: Vec = desc_string.encode_utf16().collect(); + characters.push(0); + for character in characters { + dest.gwrite(character, offset)?; + } + Ok(*offset) + } +} + +impl ctx::TryFromCtx<'_, scroll::Endian> for EfiKey { + type Error = scroll::Error; + fn try_from_ctx(src: &'_ [u8], _ctx: scroll::Endian) -> Result<(Self, usize), Self::Error> { + let offset = &mut 0; + let efi_key = EfiKey::try_from(src.gread::(offset)?) + .map_err(|_| scroll::Error::BadInput { size: 0, msg: "Invalid EfiKey enum value" })?; + Ok((efi_key, *offset)) + } +} + +impl ctx::TryIntoCtx for &EfiKey { + type Error = scroll::Error; + fn try_into_ctx(self, dest: &mut [u8], _ctx: scroll::Endian) -> Result { + let offset = &mut 0; + dest.gwrite(*self as u32, offset)?; + Ok(*offset) + } +} + +// Convenience macro for defining HiiKey::Key structures. +macro_rules! key { + ($key:expr, $unicode:literal, $shifted:literal, $alt_gr:literal, $shifted_alt_gr:literal, $modifier:expr, $affected:expr ) => { + HiiKey::Key(key_descriptor!($key, $unicode, $shifted, $alt_gr, $shifted_alt_gr, $modifier, $affected)) + }; +} + +// convenience macro for defining HiiKeyDescriptor structures. +// note: for unicode characters, these are encoded as u16 for compliance with UEFI spec. UEFI only supports UCS-2 +// encoding - so unicode characters that require more than two bytes under UTF-16 are not supported (and will panic). +macro_rules! key_descriptor { + ($key:expr, $unicode:literal, $shifted:literal, $alt_gr:literal, $shifted_alt_gr:literal, $modifier:expr, $affected:expr ) => { + HiiKeyDescriptor { + key: $key, + unicode: $unicode.encode_utf16(&mut [0u16; 1])[0], + shifted_unicode: $shifted.encode_utf16(&mut [0u16; 1])[0], + alt_gr_unicode: $alt_gr.encode_utf16(&mut [0u16; 1])[0], + shifted_alt_gr_unicode: $shifted_alt_gr.encode_utf16(&mut [0u16; 1])[0], + modifier: $modifier, + affected_attribute: $affected, + } + }; +} + +/// Returns a default HiiKeyboardLayout (which is a standard US-104 layout) +#[rustfmt::skip] +pub fn get_default_keyboard_layout() -> HiiKeyboardLayout { + HiiKeyboardLayout { + guid: DEFAULT_KEYBOARD_LAYOUT_GUID, + keys: vec![ + key!(EfiKey::C1, 'a', 'A', '\0', '\0', NULL_MODIFIER, AFFECTED_BY_STANDARD_SHIFT | AFFECTED_BY_CAPS_LOCK), + key!(EfiKey::B5, 'b', 'B', '\0', '\0', NULL_MODIFIER, AFFECTED_BY_STANDARD_SHIFT | AFFECTED_BY_CAPS_LOCK), + key!(EfiKey::B3, 'c', 'C', '\0', '\0', NULL_MODIFIER, AFFECTED_BY_STANDARD_SHIFT | AFFECTED_BY_CAPS_LOCK), + key!(EfiKey::C3, 'd', 'D', '\0', '\0', NULL_MODIFIER, AFFECTED_BY_STANDARD_SHIFT | AFFECTED_BY_CAPS_LOCK), + key!(EfiKey::D3, 'e', 'E', '\0', '\0', NULL_MODIFIER, AFFECTED_BY_STANDARD_SHIFT | AFFECTED_BY_CAPS_LOCK), + key!(EfiKey::C4, 'f', 'F', '\0', '\0', NULL_MODIFIER, AFFECTED_BY_STANDARD_SHIFT | AFFECTED_BY_CAPS_LOCK), + key!(EfiKey::C5, 'g', 'G', '\0', '\0', NULL_MODIFIER, AFFECTED_BY_STANDARD_SHIFT | AFFECTED_BY_CAPS_LOCK), + key!(EfiKey::C6, 'h', 'H', '\0', '\0', NULL_MODIFIER, AFFECTED_BY_STANDARD_SHIFT | AFFECTED_BY_CAPS_LOCK), + key!(EfiKey::D8, 'i', 'I', '\0', '\0', NULL_MODIFIER, AFFECTED_BY_STANDARD_SHIFT | AFFECTED_BY_CAPS_LOCK), + key!(EfiKey::C7, 'j', 'J', '\0', '\0', NULL_MODIFIER, AFFECTED_BY_STANDARD_SHIFT | AFFECTED_BY_CAPS_LOCK), + key!(EfiKey::C8, 'k', 'K', '\0', '\0', NULL_MODIFIER, AFFECTED_BY_STANDARD_SHIFT | AFFECTED_BY_CAPS_LOCK), + key!(EfiKey::C9, 'l', 'L', '\0', '\0', NULL_MODIFIER, AFFECTED_BY_STANDARD_SHIFT | AFFECTED_BY_CAPS_LOCK), + key!(EfiKey::B7, 'm', 'M', '\0', '\0', NULL_MODIFIER, AFFECTED_BY_STANDARD_SHIFT | AFFECTED_BY_CAPS_LOCK), + key!(EfiKey::B6, 'n', 'N', '\0', '\0', NULL_MODIFIER, AFFECTED_BY_STANDARD_SHIFT | AFFECTED_BY_CAPS_LOCK), + key!(EfiKey::D9, 'o', 'O', '\0', '\0', NULL_MODIFIER, AFFECTED_BY_STANDARD_SHIFT | AFFECTED_BY_CAPS_LOCK), + key!(EfiKey::D10, 'p', 'P', '\0', '\0', NULL_MODIFIER, AFFECTED_BY_STANDARD_SHIFT | AFFECTED_BY_CAPS_LOCK), + key!(EfiKey::D1, 'q', 'Q', '\0', '\0', NULL_MODIFIER, AFFECTED_BY_STANDARD_SHIFT | AFFECTED_BY_CAPS_LOCK), + key!(EfiKey::D4, 'r', 'R', '\0', '\0', NULL_MODIFIER, AFFECTED_BY_STANDARD_SHIFT | AFFECTED_BY_CAPS_LOCK), + key!(EfiKey::C2, 's', 'S', '\0', '\0', NULL_MODIFIER, AFFECTED_BY_STANDARD_SHIFT | AFFECTED_BY_CAPS_LOCK), + key!(EfiKey::D5, 't', 'T', '\0', '\0', NULL_MODIFIER, AFFECTED_BY_STANDARD_SHIFT | AFFECTED_BY_CAPS_LOCK), + key!(EfiKey::D7, 'u', 'U', '\0', '\0', NULL_MODIFIER, AFFECTED_BY_STANDARD_SHIFT | AFFECTED_BY_CAPS_LOCK), + key!(EfiKey::B4, 'v', 'V', '\0', '\0', NULL_MODIFIER, AFFECTED_BY_STANDARD_SHIFT | AFFECTED_BY_CAPS_LOCK), + key!(EfiKey::D2, 'w', 'W', '\0', '\0', NULL_MODIFIER, AFFECTED_BY_STANDARD_SHIFT | AFFECTED_BY_CAPS_LOCK), + key!(EfiKey::B2, 'x', 'X', '\0', '\0', NULL_MODIFIER, AFFECTED_BY_STANDARD_SHIFT | AFFECTED_BY_CAPS_LOCK), + key!(EfiKey::D6, 'y', 'Y', '\0', '\0', NULL_MODIFIER, AFFECTED_BY_STANDARD_SHIFT | AFFECTED_BY_CAPS_LOCK), + key!(EfiKey::B1, 'z', 'Z', '\0', '\0', NULL_MODIFIER, AFFECTED_BY_STANDARD_SHIFT | AFFECTED_BY_CAPS_LOCK), + key!(EfiKey::E1, '1', '!', '\0', '\0', NULL_MODIFIER, AFFECTED_BY_STANDARD_SHIFT ), + key!(EfiKey::E2, '2', '@', '\0', '\0', NULL_MODIFIER, AFFECTED_BY_STANDARD_SHIFT ), + key!(EfiKey::E3, '3', '#', '\0', '\0', NULL_MODIFIER, AFFECTED_BY_STANDARD_SHIFT ), + key!(EfiKey::E4, '4', '$', '\0', '\0', NULL_MODIFIER, AFFECTED_BY_STANDARD_SHIFT ), + key!(EfiKey::E5, '5', '%', '\0', '\0', NULL_MODIFIER, AFFECTED_BY_STANDARD_SHIFT ), + key!(EfiKey::E6, '6', '^', '\0', '\0', NULL_MODIFIER, AFFECTED_BY_STANDARD_SHIFT ), + key!(EfiKey::E7, '7', '&', '\0', '\0', NULL_MODIFIER, AFFECTED_BY_STANDARD_SHIFT ), + key!(EfiKey::E8, '8', '*', '\0', '\0', NULL_MODIFIER, AFFECTED_BY_STANDARD_SHIFT ), + key!(EfiKey::E9, '9', '(', '\0', '\0', NULL_MODIFIER, AFFECTED_BY_STANDARD_SHIFT ), + key!(EfiKey::E10, '0', ')', '\0', '\0', NULL_MODIFIER, AFFECTED_BY_STANDARD_SHIFT ), + key!(EfiKey::Enter, '\x0d', '\x0d', '\0', '\0', NULL_MODIFIER, 0 ), + key!(EfiKey::Esc, '\x1b', '\x1b', '\0', '\0', NULL_MODIFIER, 0 ), + key!(EfiKey::BackSpace, '\x08', '\x08', '\0', '\0', NULL_MODIFIER, 0 ), + key!(EfiKey::Tab, '\x09', '\x09', '\0', '\0', NULL_MODIFIER, 0 ), + key!(EfiKey::SpaceBar, ' ', ' ', '\0', '\0', NULL_MODIFIER, 0 ), + key!(EfiKey::E11, '-', '_', '\0', '\0', NULL_MODIFIER, AFFECTED_BY_STANDARD_SHIFT ), + key!(EfiKey::E12, '=', '+', '\0', '\0', NULL_MODIFIER, AFFECTED_BY_STANDARD_SHIFT ), + key!(EfiKey::D11, '[', '{', '\0', '\0', NULL_MODIFIER, AFFECTED_BY_STANDARD_SHIFT ), + key!(EfiKey::D12, ']', '}', '\0', '\0', NULL_MODIFIER, AFFECTED_BY_STANDARD_SHIFT ), + key!(EfiKey::D13, '\\', '|', '\0', '\0', NULL_MODIFIER, AFFECTED_BY_STANDARD_SHIFT ), + key!(EfiKey::C12, '\\', '|', '\0', '\0', NULL_MODIFIER, AFFECTED_BY_STANDARD_SHIFT ), + key!(EfiKey::C10, ';', ':', '\0', '\0', NULL_MODIFIER, AFFECTED_BY_STANDARD_SHIFT ), + key!(EfiKey::C11, '\'', '"', '\0', '\0', NULL_MODIFIER, AFFECTED_BY_STANDARD_SHIFT ), + key!(EfiKey::E0, '`', '~', '\0', '\0', NULL_MODIFIER, AFFECTED_BY_STANDARD_SHIFT ), + key!(EfiKey::B8, ',', '<', '\0', '\0', NULL_MODIFIER, AFFECTED_BY_STANDARD_SHIFT ), + key!(EfiKey::B9, '.', '>', '\0', '\0', NULL_MODIFIER, AFFECTED_BY_STANDARD_SHIFT ), + key!(EfiKey::B10, '/', '?', '\0', '\0', NULL_MODIFIER, AFFECTED_BY_STANDARD_SHIFT ), + key!(EfiKey::CapsLock, '\0', '\0', '\0', '\0', CAPS_LOCK_MODIFIER, 0 ), + key!(EfiKey::F1, '\0', '\0', '\0', '\0', FUNCTION_KEY_ONE_MODIFIER, 0 ), + key!(EfiKey::F2, '\0', '\0', '\0', '\0', FUNCTION_KEY_TWO_MODIFIER, 0 ), + key!(EfiKey::F3, '\0', '\0', '\0', '\0', FUNCTION_KEY_THREE_MODIFIER, 0 ), + key!(EfiKey::F4, '\0', '\0', '\0', '\0', FUNCTION_KEY_FOUR_MODIFIER, 0 ), + key!(EfiKey::F5, '\0', '\0', '\0', '\0', FUNCTION_KEY_FIVE_MODIFIER, 0 ), + key!(EfiKey::F6, '\0', '\0', '\0', '\0', FUNCTION_KEY_SIX_MODIFIER, 0 ), + key!(EfiKey::F7, '\0', '\0', '\0', '\0', FUNCTION_KEY_SEVEN_MODIFIER, 0 ), + key!(EfiKey::F8, '\0', '\0', '\0', '\0', FUNCTION_KEY_EIGHT_MODIFIER, 0 ), + key!(EfiKey::F9, '\0', '\0', '\0', '\0', FUNCTION_KEY_NINE_MODIFIER, 0 ), + key!(EfiKey::F10, '\0', '\0', '\0', '\0', FUNCTION_KEY_TEN_MODIFIER, 0 ), + key!(EfiKey::F11, '\0', '\0', '\0', '\0', FUNCTION_KEY_ELEVEN_MODIFIER, 0 ), + key!(EfiKey::F12, '\0', '\0', '\0', '\0', FUNCTION_KEY_TWELVE_MODIFIER, 0 ), + key!(EfiKey::Print, '\0', '\0', '\0', '\0', PRINT_MODIFIER, 0 ), + key!(EfiKey::SLck, '\0', '\0', '\0', '\0', SCROLL_LOCK_MODIFIER, 0 ), + key!(EfiKey::Pause, '\0', '\0', '\0', '\0', PAUSE_MODIFIER, 0 ), + key!(EfiKey::Ins, '\0', '\0', '\0', '\0', INSERT_MODIFIER, 0 ), + key!(EfiKey::Home, '\0', '\0', '\0', '\0', HOME_MODIFIER, 0 ), + key!(EfiKey::PgUp, '\0', '\0', '\0', '\0', PAGE_UP_MODIFIER, 0 ), + key!(EfiKey::Del, '\0', '\0', '\0', '\0', DELETE_MODIFIER, 0 ), + key!(EfiKey::End, '\0', '\0', '\0', '\0', END_MODIFIER, 0 ), + key!(EfiKey::PgDn, '\0', '\0', '\0', '\0', PAGE_DOWN_MODIFIER, 0 ), + key!(EfiKey::RightArrow, '\0', '\0', '\0', '\0', RIGHT_ARROW_MODIFIER, 0 ), + key!(EfiKey::LeftArrow, '\0', '\0', '\0', '\0', LEFT_ARROW_MODIFIER, 0 ), + key!(EfiKey::DownArrow, '\0', '\0', '\0', '\0', DOWN_ARROW_MODIFIER, 0 ), + key!(EfiKey::UpArrow, '\0', '\0', '\0', '\0', UP_ARROW_MODIFIER, 0 ), + key!(EfiKey::NLck, '\0', '\0', '\0', '\0', NUM_LOCK_MODIFIER, 0 ), + key!(EfiKey::Slash, '/', '/', '\0', '\0', NULL_MODIFIER, 0 ), + key!(EfiKey::Asterisk, '*', '*', '\0', '\0', NULL_MODIFIER, 0 ), + key!(EfiKey::Minus, '-', '-', '\0', '\0', NULL_MODIFIER, 0 ), + key!(EfiKey::Plus, '+', '+', '\0', '\0', NULL_MODIFIER, 0 ), + key!(EfiKey::Enter, '\x0d', '\x0d', '\0', '\0', NULL_MODIFIER, 0 ), + key!(EfiKey::One, '1', '1', '\0', '\0', END_MODIFIER, AFFECTED_BY_STANDARD_SHIFT | AFFECTED_BY_NUM_LOCK ), + key!(EfiKey::Two, '2', '2', '\0', '\0', DOWN_ARROW_MODIFIER, AFFECTED_BY_STANDARD_SHIFT | AFFECTED_BY_NUM_LOCK ), + key!(EfiKey::Three, '3', '3', '\0', '\0', PAGE_DOWN_MODIFIER, AFFECTED_BY_STANDARD_SHIFT | AFFECTED_BY_NUM_LOCK ), + key!(EfiKey::Four, '4', '4', '\0', '\0', LEFT_ARROW_MODIFIER, AFFECTED_BY_STANDARD_SHIFT | AFFECTED_BY_NUM_LOCK ), + key!(EfiKey::Five, '5', '5', '\0', '\0', NULL_MODIFIER, AFFECTED_BY_STANDARD_SHIFT | AFFECTED_BY_NUM_LOCK ), + key!(EfiKey::Six, '6', '6', '\0', '\0', RIGHT_ARROW_MODIFIER, AFFECTED_BY_STANDARD_SHIFT | AFFECTED_BY_NUM_LOCK ), + key!(EfiKey::Seven, '7', '7', '\0', '\0', HOME_MODIFIER, AFFECTED_BY_STANDARD_SHIFT | AFFECTED_BY_NUM_LOCK ), + key!(EfiKey::Eight, '8', '8', '\0', '\0', UP_ARROW_MODIFIER, AFFECTED_BY_STANDARD_SHIFT | AFFECTED_BY_NUM_LOCK ), + key!(EfiKey::Nine, '9', '9', '\0', '\0', PAGE_UP_MODIFIER, AFFECTED_BY_STANDARD_SHIFT | AFFECTED_BY_NUM_LOCK ), + key!(EfiKey::Zero, '0', '0', '\0', '\0', INSERT_MODIFIER, AFFECTED_BY_STANDARD_SHIFT | AFFECTED_BY_NUM_LOCK ), + key!(EfiKey::Period, '.', '.', '\0', '\0', DELETE_MODIFIER, AFFECTED_BY_STANDARD_SHIFT | AFFECTED_BY_NUM_LOCK ), + key!(EfiKey::A4, '\0', '\0', '\0', '\0', MENU_MODIFIER, 0 ), + key!(EfiKey::LCtrl, '\0', '\0', '\0', '\0', LEFT_CONTROL_MODIFIER, 0 ), + key!(EfiKey::LShift, '\0', '\0', '\0', '\0', LEFT_SHIFT_MODIFIER, 0 ), + key!(EfiKey::LAlt, '\0', '\0', '\0', '\0', LEFT_ALT_MODIFIER, 0 ), + key!(EfiKey::A0, '\0', '\0', '\0', '\0', LEFT_LOGO_MODIFIER, 0 ), + key!(EfiKey::RCtrl, '\0', '\0', '\0', '\0', RIGHT_CONTROL_MODIFIER, 0 ), + key!(EfiKey::RShift, '\0', '\0', '\0', '\0', RIGHT_SHIFT_MODIFIER, 0 ), + key!(EfiKey::A2, '\0', '\0', '\0', '\0', RIGHT_ALT_MODIFIER, 0 ), + key!(EfiKey::A3, '\0', '\0', '\0', '\0', RIGHT_LOGO_MODIFIER, 0 ), + ], + descriptions: vec![ + HiiKeyboardDescription { + language: String::from("en-US"), + description: String::from("English Keyboard") + } + ] + } +} + +/// Returns a default keyboard layout package. +pub fn get_default_keyboard_pkg() -> HiiKeyboardPkg { + HiiKeyboardPkg { layouts: vec![get_default_keyboard_layout()] } +} + +/// Returns a default keyboard layout package list. +pub fn get_default_keyboard_pkg_list() -> HiiKeyboardPkgList { + HiiKeyboardPkgList { + package_list_guid: efi::Guid::from_fields( + 0xc0f3b43, + 0x44de, + 0x4907, + 0xb4, + 0x78, + &[0x22, 0x5f, 0x6f, 0x62, 0x89, 0xdc], + ), + package: get_default_keyboard_pkg(), + } +} + +/// Returns a default keyboard layout package list as a byte vector. +/// +/// This is suitable for use with [`r_efi::protocols::hii_database::ProtocolNewPackageList`]. +/// +/// # Panics +/// Panics if the built-in default keyboard layout fails to serialize. This should never occur +/// as the layout is a compile-time constant. +pub fn get_default_keyboard_pkg_list_buffer() -> Vec { + // Upper bound: ~86 keys × 16 bytes + headers + descriptions. Buffer is resized to actual size after writing. + let mut buffer = vec![0u8; 4096]; + + let result = buffer.pwrite(&get_default_keyboard_pkg_list(), 0); + if let Ok(buffer_size) = result { + buffer.resize(buffer_size, 0); + buffer + } else { + panic!("Unexpected error serializing HII Keyboard Package List: {:?}", result); + } +} + +/// Defines errors that can occur while parsing keyboard layout. +#[derive(Debug)] +pub enum LayoutError { + /// Malformed key buffer + ParseError(scroll::Error), +} + +/// Returns a HiiKeyboardLayout structure parsed from the given buffer. +pub fn keyboard_layout_from_buffer(buffer: &[u8]) -> Result { + let layout = buffer.pread::(0).map_err(LayoutError::ParseError)?; + log::trace!( + "keyboard_layout_from_buffer: parsed layout with {:?} keys, {:?} descriptions", + layout.keys.len(), + layout.descriptions.len(), + ); + Ok(layout) +} + +#[cfg(test)] +mod tests { + use scroll::{Pread, Pwrite}; + + use super::*; + + #[test] + fn hii_keyboard_package_serialize_deserialize_should_produce_consistent_results() { + let mut buffer = [0u8; 4096]; + + let package = get_default_keyboard_pkg(); + buffer.pwrite(&package, 0).unwrap(); + + let package2: HiiKeyboardPkg = buffer.pread(0).unwrap(); + assert_eq!(package, package2); + } + + #[test] + fn efi_key_from_u32_valid_values() { + assert_eq!(EfiKey::try_from(0u32), Ok(EfiKey::LCtrl)); + assert_eq!(EfiKey::try_from(32u32), Ok(EfiKey::C1)); + assert_eq!(EfiKey::try_from(114u32), Ok(EfiKey::Intl9)); + } + + #[test] + fn efi_key_from_u32_invalid_value() { + assert!(EfiKey::try_from(115u32).is_err()); + assert!(EfiKey::try_from(u32::MAX).is_err()); + } + + #[test] + fn efi_key_roundtrip_through_u32() { + for val in 0..=114u32 { + let key = EfiKey::try_from(val).unwrap(); + assert_eq!(key as u32, val); + } + } + + #[test] + fn hii_keyboard_pkg_list_roundtrip() { + let pkg_list = get_default_keyboard_pkg_list(); + let buffer = get_default_keyboard_pkg_list_buffer(); + let parsed: HiiKeyboardPkgList = buffer.pread(0).unwrap(); + assert_eq!(pkg_list, parsed); + } + + #[test] + fn keyboard_layout_from_buffer_valid() { + let layout = get_default_keyboard_layout(); + let mut buffer = [0u8; 4096]; + let size = buffer.pwrite(&layout, 0).unwrap(); + let parsed = keyboard_layout_from_buffer(&buffer[..size]).unwrap(); + assert_eq!(layout, parsed); + } + + #[test] + fn keyboard_layout_from_buffer_empty_fails() { + let result = keyboard_layout_from_buffer(&[]); + assert!(matches!(result, Err(LayoutError::ParseError(_)))); + } + + #[test] + fn keyboard_layout_from_buffer_truncated_fails() { + let result = keyboard_layout_from_buffer(&[0u8; 4]); + assert!(matches!(result, Err(LayoutError::ParseError(_)))); + } + + #[test] + fn default_layout_has_expected_key_count() { + let layout = get_default_keyboard_layout(); + // US-104 default layout: 105 key entries (includes numpad duplicates and modifier keys) + assert_eq!(layout.keys.len(), 105); + } + + #[test] + fn default_layout_has_description() { + let layout = get_default_keyboard_layout(); + assert_eq!(layout.descriptions.len(), 1); + assert_eq!(layout.descriptions[0].language, "en-US"); + } + + #[test] + fn hii_keyboard_pkg_invalid_type_fails() { + let mut buffer = [0u8; 64]; + // Write a package header with wrong type (type = 0x01 instead of PACKAGE_KEYBOARD_LAYOUT) + let bad_length_type: u32 = 0x01_000006; // type=1, length=6 + buffer.pwrite(bad_length_type, 0).unwrap(); + let result = buffer.pread::(0); + assert!(result.is_err()); + } +} diff --git a/uefi_hid/src/keyboard/mod.rs b/uefi_hid/src/keyboard/mod.rs new file mode 100644 index 0000000..9b7c2b0 --- /dev/null +++ b/uefi_hid/src/keyboard/mod.rs @@ -0,0 +1,2190 @@ +//! Provides Keyboard HID support. +//! +//! This module handles the core logic for processing keystrokes from HID +//! devices. +//! +//! ## License +//! +//! Copyright (c) Microsoft Corporation. +//! +//! SPDX-License-Identifier: Apache-2.0 +//! +//! +pub mod key_queue; +pub mod layout; +pub(crate) mod simple_text_in; +pub(crate) mod simple_text_in_ex; + +use alloc::{ + boxed::Box, + collections::{BTreeMap, BTreeSet}, + vec, + vec::Vec, +}; +use core::ptr; + +use r_efi::{efi, protocols}; + +use hidparser::{ + ArrayField, ReportDescriptor, ReportField, VariableField, + report_data_types::{ReportId, Usage}, +}; + +use patina::{ + boot_services::{ + BootServices, + c_ptr::PtrMetadata, + event::{EventTimerType, EventType}, + tpl::Tpl, + }, + tpl_mutex::TplMutex, +}; + +use crate::hid_io::{HidIo, HidReportReceiver}; + +#[cfg(feature = "ctrl-alt-del")] +use r_efi::protocols::simple_text_input_ex::{ + LEFT_ALT_PRESSED, LEFT_CONTROL_PRESSED, RIGHT_ALT_PRESSED, RIGHT_CONTROL_PRESSED, SHIFT_STATE_VALID, +}; + +use self::key_queue::OrdKeyData; + +// Repeat key delay: 500ms in 100ns units. +const REPEAT_KEY_DELAY: u64 = 5_000_000; +// Repeat key rate: 20ms in 100ns units (~50 keys/sec). +const REPEAT_KEY_RATE: u64 = 200_000; + +// usages supported by this module +const KEYBOARD_MODIFIER_USAGE_MIN: u32 = 0x000700E0; +const KEYBOARD_MODIFIER_USAGE_MAX: u32 = 0x000700E7; +const KEYBOARD_USAGE_MIN: u32 = 0x00070001; +const KEYBOARD_USAGE_MAX: u32 = 0x00070065; +const LED_USAGE_MIN: u32 = 0x00080001; +const LED_USAGE_MAX: u32 = 0x00080005; + +// maps a given field to a routine that handles input from it. +struct KeyInputFieldHandler { + field: F, + report_handler: fn(current_keys: &mut BTreeSet, field: &F, report: &[u8]), +} + +// maps a given field to a routine that builds output reports for it. +struct KeyOutputFieldBuilder { + field: VariableField, + field_builder: fn(led_state: &BTreeSet, field: &VariableField, report: &mut [u8]), +} + +// Defines an input report and the fields of interest in it. +#[derive(Default)] +struct KeyboardInputReportSpec { + report_id: Option, + report_size: usize, + relevant_variable_fields: Vec>, + relevant_array_fields: Vec>, +} + +// Defines an output report and the fields of interest in it. +#[derive(Default)] +struct KeyboardOutputReportSpec { + report_id: Option, + report_size: usize, + relevant_variable_fields: Vec, +} + +// Result of processing a single HID report. +struct ProcessReportResult { + should_signal_notify: bool, + output_reports: Vec<(Option, Vec)>, + pressed_keys: Vec, + released_keys: Vec, +} + +// Core keyboard data processing logic, independent of UEFI boot services. +struct KeyboardProcessor { + input_reports: BTreeMap, KeyboardInputReportSpec>, + output_builders: Vec, + report_id_present: bool, + last_keys: BTreeSet, + current_keys: BTreeSet, + led_state: BTreeSet, + notification_callbacks: BTreeMap, + next_notify_handle: usize, +} + +impl KeyboardProcessor { + // Creates a new processor with default state. + fn new() -> Self { + Self { + input_reports: BTreeMap::new(), + output_builders: Vec::new(), + report_id_present: false, + last_keys: BTreeSet::new(), + current_keys: BTreeSet::new(), + led_state: BTreeSet::new(), + notification_callbacks: BTreeMap::new(), + next_notify_handle: 0, + } + } + + // Parses a report descriptor and registers input/output field handlers. + fn process_descriptor(&mut self, descriptor: ReportDescriptor) -> Result<(), efi::Status> { + let multiple_reports = + descriptor.input_reports.len() > 1 || descriptor.output_reports.len() > 1 || descriptor.features.len() > 1; + + for report in &descriptor.input_reports { + let mut report_data_spec = KeyboardInputReportSpec { report_id: report.report_id, ..Default::default() }; + + self.report_id_present = report.report_id.is_some(); + + if multiple_reports && !self.report_id_present { + return Err(efi::Status::DEVICE_ERROR); + } + + report_data_spec.report_size = report.size_in_bits.div_ceil(8); + + for field in &report.fields { + match field { + ReportField::Variable(field) => { + if let KEYBOARD_MODIFIER_USAGE_MIN..=KEYBOARD_MODIFIER_USAGE_MAX = field.usage.into() { + report_data_spec.relevant_variable_fields.push(KeyInputFieldHandler { + field: field.clone(), + report_handler: handle_variable_key, + }); + } + } + ReportField::Array(field) => { + for usage_list in &field.usage_list { + if usage_list.contains(Usage::from(KEYBOARD_USAGE_MIN)) + || usage_list.contains(Usage::from(KEYBOARD_USAGE_MAX)) + { + report_data_spec.relevant_array_fields.push(KeyInputFieldHandler { + field: field.clone(), + report_handler: handle_array_key, + }); + break; + } + } + } + ReportField::Padding(_) => (), + } + } + if !(report_data_spec.relevant_variable_fields.is_empty() + && report_data_spec.relevant_array_fields.is_empty()) + { + self.input_reports.insert(report_data_spec.report_id, report_data_spec); + } + } + + for report in &descriptor.output_reports { + let mut report_builder = KeyboardOutputReportSpec { report_id: report.report_id, ..Default::default() }; + + self.report_id_present = report.report_id.is_some(); + + if multiple_reports && !self.report_id_present { + return Err(efi::Status::DEVICE_ERROR); + } + + report_builder.report_size = usize::div_ceil(report.size_in_bits, 8); + + for field in &report.fields { + match field { + ReportField::Variable(field) => { + if let LED_USAGE_MIN..=LED_USAGE_MAX = field.usage.into() { + report_builder + .relevant_variable_fields + .push(KeyOutputFieldBuilder { field: field.clone(), field_builder: build_led_report }) + } + } + ReportField::Array(_) | ReportField::Padding(_) => (), + } + } + if !report_builder.relevant_variable_fields.is_empty() { + self.output_builders.push(report_builder); + } + } + + if self.input_reports.is_empty() && self.output_builders.is_empty() { + log::trace!("process_descriptor: no relevant keyboard fields found"); + Err(efi::Status::UNSUPPORTED) + } else { + log::trace!( + "process_descriptor: {:?} input report(s) with {:?} variable/{:?} array fields, {:?} output builder(s)", + self.input_reports.len(), + self.input_reports.values().map(|r| r.relevant_variable_fields.len()).sum::(), + self.input_reports.values().map(|r| r.relevant_array_fields.len()).sum::(), + self.output_builders.len(), + ); + Ok(()) + } + } + + // Resets key tracking state and the key queue. + fn reset(&mut self, kq: &mut key_queue::KeyQueue, extended_verification: bool) { + self.last_keys.clear(); + self.current_keys.clear(); + kq.reset(extended_verification); + if extended_verification { + self.led_state.clear(); + } + } + + // Builds output reports reflecting the current LED state. + fn build_led_output_reports(&self) -> Vec<(Option, Vec)> { + let mut output_vec = Vec::new(); + for output_builder in &self.output_builders { + let mut report_buffer = vec![0u8; output_builder.report_size]; + for field_builder in &output_builder.relevant_variable_fields { + (field_builder.field_builder)(&self.led_state, &field_builder.field, report_buffer.as_mut_slice()); + } + output_vec.push((output_builder.report_id, report_buffer)); + } + output_vec + } + + // Processes an incoming HID report, queuing keystrokes and building LED output. + fn process_report(&mut self, report: &[u8], kq: &mut key_queue::KeyQueue) -> ProcessReportResult { + let mut result = ProcessReportResult { + should_signal_notify: false, + output_reports: Vec::new(), + pressed_keys: Vec::new(), + released_keys: Vec::new(), + }; + if report.is_empty() { + return result; + } + let (report_id, report) = match self.report_id_present { + true => (Some(ReportId::from(&report[0..1])), &report[1..]), + false => (None, &report[0..]), + }; + + if report.is_empty() { + return result; + } + + if let Some(report_data) = self.input_reports.get(&report_id) { + if report.len() != report_data.report_size { + log::trace!( + "receive_report: unexpected report length for report_id: {:?}. expected {:?}, actual {:?}", + report_id, + report_data.report_size, + report.len() + ); + } + + self.current_keys.clear(); + + for field in &report_data.relevant_variable_fields { + (field.report_handler)(&mut self.current_keys, &field.field, report); + } + + for field in &report_data.relevant_array_fields { + (field.report_handler)(&mut self.current_keys, &field.field, report); + } + + if self.last_keys != self.current_keys { + let mut released_keys = Vec::new(); + let mut pressed_keys = Vec::new(); + // XOR the key sets to find keys that changed state between reports. + for changed_key in (&self.last_keys ^ &self.current_keys).into_iter().rev() { + if self.last_keys.contains(&changed_key) { + released_keys.push(changed_key); + } else { + pressed_keys.push(changed_key); + } + } + + log::trace!( + "process_report: {:?} key(s) released, {:?} key(s) pressed", + released_keys.len(), + pressed_keys.len(), + ); + + for key in &released_keys { + kq.keystroke(*key, key_queue::KeyAction::KeyUp); + } + for key in &pressed_keys { + kq.keystroke(*key, key_queue::KeyAction::KeyDown); + } + + if kq.peek_notify_key().is_some() { + result.should_signal_notify = true; + } + + // Only send LED output reports when LED state actually changes. + let current_leds: BTreeSet = kq.active_leds().iter().cloned().collect(); + if current_leds != self.led_state { + log::trace!("process_report: LED state changed, generating output reports"); + self.led_state = current_leds; + result.output_reports = self.build_led_output_reports(); + } + + result.pressed_keys = pressed_keys; + result.released_keys = released_keys; + } + + core::mem::swap(&mut self.last_keys, &mut self.current_keys); + } + result + } + + // Returns a reference to the set of currently held keys (valid after process_report). + fn held_keys(&self) -> &BTreeSet { + &self.last_keys + } + + // Registers a key notification callback, returning its handle. + fn insert_key_notify_callback( + &mut self, + key_data: protocols::simple_text_input_ex::KeyData, + key_notification_function: protocols::simple_text_input_ex::KeyNotifyFunction, + kq: &mut key_queue::KeyQueue, + ) -> usize { + let key_data = OrdKeyData(key_data); + for (handle, entry) in &self.notification_callbacks { + if entry.0 == key_data && ptr::fn_addr_eq(entry.1, key_notification_function) { + return *handle; + } + } + self.next_notify_handle += 1; + self.notification_callbacks.insert(self.next_notify_handle, (key_data.clone(), key_notification_function)); + kq.add_notify_key(key_data); + self.next_notify_handle + } + + // Unregisters a key notification callback by handle. + fn remove_key_notify_callback( + &mut self, + notification_handle: usize, + kq: &mut key_queue::KeyQueue, + ) -> Result<(), efi::Status> { + if let Some(entry) = self.notification_callbacks.remove(¬ification_handle) { + let removed_key = entry.0; + if !self.notification_callbacks.values().any(|(key, _)| *key == removed_key) { + kq.remove_notify_key(&removed_key); + } + Ok(()) + } else { + Err(efi::Status::INVALID_PARAMETER) + } + } + + // Returns the next pending notify key and its matching callbacks. + fn pending_callbacks( + &self, + kq: &mut key_queue::KeyQueue, + ) -> (Option, Vec) + { + if let Some(pending_notify_key) = kq.pop_notify_key() { + let mut pending_callbacks = Vec::new(); + for (key, callback) in self.notification_callbacks.values() { + if OrdKeyData(pending_notify_key).matches_registered_key(key) { + pending_callbacks.push(*callback); + } + } + (Some(pending_notify_key), pending_callbacks) + } else { + (None, Vec::new()) + } + } +} + +// Context passed to the keyboard layout change event callback. +pub(crate) struct LayoutChangeContext { + boot_services: &'static T, + keyboard_handler: *mut KeyboardHidHandler, +} + +// Context passed to the key repeat timer event callback. +#[repr(C)] +struct RepeatTimerContext { + keyboard_handler: *mut KeyboardHidHandler, +} + +/// Keyboard HID handler that processes reports and produces UEFI SimpleTextIn keystrokes. +pub struct KeyboardHidHandler { + boot_services: &'static T, + controller: efi::Handle, + hid_io: Option<*const dyn HidIo>, + simple_text_in_key: Option>>>, + simple_text_in_ex_key: Option>>>, + processor: KeyboardProcessor, + pub(crate) state: TplMutex, + pub(crate) key_notify_event: efi::Event, + layout_change_event: efi::Event, + layout_context: *mut LayoutChangeContext, + repeat_timer_event: efi::Event, + repeat_context: *mut RepeatTimerContext, + pub(crate) repeat_key: Option, +} + +impl KeyboardHidHandler { + /// Creates a fully initialized Keyboard HID handler for the given controller. + /// + /// Returns a boxed handler because protocol installation stores raw pointers to `self`. + /// Boxing first ensures those pointers remain valid (the handler is never moved after boxing). + pub fn new( + boot_services: &'static T, + controller: efi::Handle, + hid_io: &dyn HidIo, + ) -> Result, efi::Status> { + let mut processor = KeyboardProcessor::new(); + let descriptor = hid_io.get_report_descriptor()?; + processor.process_descriptor(descriptor)?; + + let mut handler = Box::new(Self { + boot_services, + controller, + // SAFETY: hid_io is valid for the device lifetime (backed by a BY_DRIVER protocol + // reference in UefiHidIo). Transmute erases the borrow lifetime for raw pointer storage. + hid_io: Some(unsafe { + core::mem::transmute::<*const dyn HidIo, *const dyn HidIo>(hid_io as *const dyn HidIo) + }), + simple_text_in_key: None, + simple_text_in_ex_key: None, + processor, + state: TplMutex::new((*boot_services).clone(), Tpl::NOTIFY, key_queue::KeyQueue::default()), + key_notify_event: core::ptr::null_mut(), + layout_change_event: core::ptr::null_mut(), + layout_context: core::ptr::null_mut(), + repeat_timer_event: core::ptr::null_mut(), + repeat_context: core::ptr::null_mut(), + repeat_key: None, + }); + + handler.reset(true); + handler.install_protocol_interfaces()?; + handler.initialize_keyboard_layout()?; + handler.install_repeat_timer()?; + + // Register Ctrl-Alt-Delete handler to reset the system. Register only for DEL scan code; + // CTRL-ALT shift state is validated in the callback to handle any left/right combination. + #[cfg(feature = "ctrl-alt-del")] + { + let reset_key_data = protocols::simple_text_input_ex::KeyData { + key: protocols::simple_text_input::InputKey { scan_code: key_queue::SCAN_DELETE, unicode_char: 0 }, + key_state: protocols::simple_text_input_ex::KeyState { key_toggle_state: 0, key_shift_state: 0 }, + }; + handler.insert_key_notify_callback(reset_key_data, reset_notification_function); + } + + Ok(handler) + } + + // Installs SimpleTextIn and SimpleTextInEx protocol interfaces. + fn install_protocol_interfaces(&mut self) -> Result<(), efi::Status> { + let sti_key = simple_text_in::SimpleTextInFfi::install(self.boot_services, self.controller, self)?; + let sti_ex_key = match simple_text_in_ex::SimpleTextInExFfi::install(self.boot_services, self.controller, self) + { + Ok(key) => key, + Err(status) => { + let _ = simple_text_in::SimpleTextInFfi::::uninstall(self.boot_services, self.controller, sti_key); + return Err(status); + } + }; + self.simple_text_in_key = Some(sti_key); + self.simple_text_in_ex_key = Some(sti_ex_key); + Ok(()) + } + + // Creates and registers the keyboard layout change event. + fn install_layout_change_event(&mut self) -> Result<(), efi::Status> { + let context = LayoutChangeContext { boot_services: self.boot_services, keyboard_handler: self as *mut Self }; + let context_ptr = Box::into_raw(Box::new(context)); + + // SAFETY: context_ptr is valid from Box::into_raw and will remain valid for the event lifetime. + let layout_change_event = unsafe { + self.boot_services.create_event_ex_unchecked( + EventType::NOTIFY_SIGNAL, + Tpl::NOTIFY, + Some(Self::on_layout_update), + context_ptr, + &protocols::hii_database::SET_KEYBOARD_LAYOUT_EVENT_GUID, + ) + }; + + match layout_change_event { + Ok(event) => { + self.layout_change_event = event; + self.layout_context = context_ptr; + Ok(()) + } + Err(status) => { + // SAFETY: context_ptr was created via Box::into_raw above and is being reclaimed on the error path. + drop(unsafe { Box::from_raw(context_ptr) }); + Err(status) + } + } + } + + // Closes the layout change event and frees its context. + fn uninstall_layout_change_event(&mut self) -> Result<(), efi::Status> { + if !self.layout_change_event.is_null() { + let layout_change_event = self.layout_change_event; + if let Err(status) = self.boot_services.close_event(layout_change_event) { + log::error!("Failed to close layout_change_event event, status: {:x?}", status); + // SAFETY: layout_context is valid while self exists; nulling the handler prevents stale callbacks. + unsafe { + (*self.layout_context).keyboard_handler = ptr::null_mut(); + } + return Err(status); + } + // SAFETY: layout_context was created via Box::into_raw during install_layout_change_event. + drop(unsafe { Box::from_raw(self.layout_context) }); + self.layout_context = ptr::null_mut(); + self.layout_change_event = ptr::null_mut(); + } + Ok(()) + } + + // Creates the repeat key timer event used for keystroke repeat when a key is held. + fn install_repeat_timer(&mut self) -> Result<(), efi::Status> { + let context = RepeatTimerContext { keyboard_handler: self as *mut Self }; + let context_ptr = Box::into_raw(Box::new(context)); + + // SAFETY: context_ptr is valid from Box::into_raw and will remain valid for the event lifetime. + let repeat_timer = unsafe { + self.boot_services.create_event_unchecked( + EventType::TIMER | EventType::NOTIFY_SIGNAL, + Tpl::NOTIFY, + Some(Self::on_repeat_timer), + context_ptr, + ) + }; + + match repeat_timer { + Ok(event) => { + self.repeat_timer_event = event; + self.repeat_context = context_ptr; + Ok(()) + } + Err(status) => { + // SAFETY: context_ptr was created via Box::into_raw above and is being reclaimed on the error path. + drop(unsafe { Box::from_raw(context_ptr) }); + Err(status) + } + } + } + + // Closes the repeat timer event and frees the context. + fn uninstall_repeat_timer(&mut self) -> Result<(), efi::Status> { + if !self.repeat_timer_event.is_null() { + self.repeat_key = None; + if let Err(status) = self.boot_services.set_timer(self.repeat_timer_event, EventTimerType::Cancel, 0) { + log::error!("Failed to cancel repeat_timer, status: {:x?}", status); + // SAFETY: repeat_context is valid while self exists; nulling the handler prevents stale callbacks. + unsafe { + (*self.repeat_context).keyboard_handler = ptr::null_mut(); + } + return Err(status); + } + + if let Err(status) = self.boot_services.close_event(self.repeat_timer_event) { + log::error!("Failed to close repeat_timer event, status: {:x?}", status); + // SAFETY: repeat_context is valid while self exists; nulling the handler prevents stale callbacks. + unsafe { + (*self.repeat_context).keyboard_handler = ptr::null_mut(); + } + return Err(status); + } + // SAFETY: repeat_context was created via Box::into_raw during install_repeat_timer. + drop(unsafe { Box::from_raw(self.repeat_context) }); + self.repeat_context = ptr::null_mut(); + self.repeat_timer_event = ptr::null_mut(); + } + Ok(()) + } + + // Installs the default US keyboard layout via the HII database. + fn install_default_layout(&mut self) -> Result<(), efi::Status> { + // SAFETY: We locate the HII database protocol and call its methods per UEFI spec. + let hii_database_protocol_ptr = unsafe { + self.boot_services.locate_protocol_unchecked(&protocols::hii_database::PROTOCOL_GUID, ptr::null_mut()) + }; + + let hii_database_protocol_ptr = match hii_database_protocol_ptr { + Ok(ptr) => ptr as *mut protocols::hii_database::Protocol, + Err(status) => { + log::error!("keyboard::install_default_layout: Could not locate hii_database protocol: {:x?}", status); + return Err(status); + } + }; + + // SAFETY: Dereferencing the protocol pointer returned from locate_protocol; null handled by the else branch. + let Some(hii_database_protocol) = (unsafe { hii_database_protocol_ptr.as_mut() }) else { + log::error!("keyboard::install_default_layout: locate_protocol returned null pointer."); + return Err(efi::Status::NOT_FOUND); + }; + + let mut hii_handle: r_efi::hii::Handle = ptr::null_mut(); + let status = (hii_database_protocol.new_package_list)( + hii_database_protocol_ptr, + layout::get_default_keyboard_pkg_list_buffer().as_ptr() as *const r_efi::hii::PackageListHeader, + ptr::null_mut(), + &mut hii_handle as *mut r_efi::hii::Handle, + ); + + if status.is_error() { + log::error!("keyboard::install_default_layout: Failed to install keyboard layout package: {:x?}", status); + return Err(status); + } + + let status = (hii_database_protocol.set_keyboard_layout)( + hii_database_protocol_ptr, + &layout::DEFAULT_KEYBOARD_LAYOUT_GUID as *const efi::Guid as *mut efi::Guid, + ); + if status.is_error() { + log::error!("keyboard::install_default_layout: Failed to set keyboard layout: {:x?}", status); + return Err(status); + } + + Ok(()) + } + + // Initializes the keyboard layout from the HII database or installs a default. + fn initialize_keyboard_layout(&mut self) -> Result<(), efi::Status> { + log::trace!("initialize_keyboard_layout: setting up keyboard layout"); + self.install_layout_change_event()?; + + // fake signal event to pick up any existing layout + Self::on_layout_update(self.layout_change_event, self.layout_context); + + // install a default layout if no layout is installed. + // Note: the guard must be dropped before install_default_layout to avoid + // re-entrant TplMutex acquisition if a notification fires during installation. + let needs_default = self.state.lock().layout().is_none(); + if needs_default { + log::trace!("initialize_keyboard_layout: no existing layout found, installing default"); + self.install_default_layout()?; + } + Ok(()) + } + + // Sends HID output reports to the device. + pub(crate) fn send_output_reports( + &mut self, + hid_io: &dyn HidIo, + output_reports: Vec<(Option, Vec)>, + ) -> Result<(), efi::Status> { + if !output_reports.is_empty() { + log::trace!("send_output_reports: sending {:?} report(s)", output_reports.len()); + } + for (id, output_report) in output_reports { + let result = hid_io.set_output_report(id.map(|x| u32::from(x) as u8), &output_report); + if let Err(result) = result { + log::error!("send_output_reports: unexpected error sending output report: {:?}", result); + return Err(result); + } + } + Ok(()) + } + + /// Resets the keyboard driver state. + pub fn reset(&mut self, extended_verification: bool) { + let mut kq = self.state.lock(); + self.processor.reset(&mut kq, extended_verification); + // Cancel any active repeat timer. + self.repeat_key = None; + if !self.repeat_timer_event.is_null() + && let Err(status) = self.boot_services.set_timer(self.repeat_timer_event, EventTimerType::Cancel, 0) + { + log::error!("Failed to cancel repeat_timer during reset, status: {:x?}", status); + } + } + + /// Returns a clone of the keystroke at the front of the keystroke queue. + pub fn peek_key(&self) -> Option { + self.state.lock().peek_key() + } + + /// Removes and returns the keystroke at the front of the keystroke queue. + pub fn pop_key(&self) -> Option { + self.state.lock().pop_key() + } + + /// Returns the current key state (i.e. the SHIFT and TOGGLE state). + pub fn get_key_state(&self) -> protocols::simple_text_input_ex::KeyState { + self.state.lock().init_key_state() + } + + /// Sets the keyboard toggle state and sends updated LED reports to the device. + pub fn set_key_toggle_state(&mut self, toggle_state: u8) { + let current_leds: BTreeSet = { + let mut kq = self.state.lock(); + kq.set_key_toggle_state(toggle_state); + kq.active_leds().iter().cloned().collect() + }; + if current_leds != self.processor.led_state { + self.processor.led_state = current_leds; + let output_reports = self.processor.build_led_output_reports(); + if let Some(hid_io_ptr) = self.hid_io { + // SAFETY: hid_io is valid for the device lifetime, set during construction. + let hid_io = unsafe { &*hid_io_ptr }; + if let Err(e) = self.send_output_reports(hid_io, output_reports) { + log::error!("set_key_toggle_state: failed to send LED reports: {:?}", e); + } + } + } + } + + /// Registers a new key notify callback function. + pub fn insert_key_notify_callback( + &mut self, + key_data: protocols::simple_text_input_ex::KeyData, + key_notification_function: protocols::simple_text_input_ex::KeyNotifyFunction, + ) -> usize { + let mut kq = self.state.lock(); + self.processor.insert_key_notify_callback(key_data, key_notification_function, &mut kq) + } + + /// Unregisters a previously registered key notify callback function. + pub fn remove_key_notify_callback(&mut self, notification_handle: usize) -> Result<(), efi::Status> { + let mut kq = self.state.lock(); + self.processor.remove_key_notify_callback(notification_handle, &mut kq) + } + + /// Returns the set of keys that have pending callbacks. + pub fn pending_callbacks( + &mut self, + ) -> (Option, Vec) + { + let mut kq = self.state.lock(); + self.processor.pending_callbacks(&mut kq) + } + + /// Returns the controller associated with this KeyboardHidHandler. + pub fn controller(&self) -> efi::Handle { + self.controller + } + + #[cfg(test)] + pub fn new_for_test(boot_services: &'static T) -> Self { + Self { + boot_services, + controller: core::ptr::null_mut(), + hid_io: None, + simple_text_in_key: None, + simple_text_in_ex_key: None, + processor: KeyboardProcessor::new(), + state: TplMutex::new((*boot_services).clone(), Tpl::NOTIFY, key_queue::KeyQueue::default()), + key_notify_event: core::ptr::null_mut(), + layout_change_event: core::ptr::null_mut(), + layout_context: core::ptr::null_mut(), + repeat_timer_event: core::ptr::null_mut(), + repeat_context: core::ptr::null_mut(), + repeat_key: None, + } + } + + #[cfg(test)] + pub fn set_layout(&mut self, layout: Option) { + self.state.lock().set_layout(layout) + } + + #[cfg(test)] + pub fn set_notify_event(&mut self, event: efi::Event) { + self.key_notify_event = event; + } + + #[cfg(test)] + pub fn process_descriptor(&mut self, descriptor: ReportDescriptor) -> Result<(), efi::Status> { + self.processor.process_descriptor(descriptor) + } + + // Handles the repeat timer event. When a repeatable key is held, this fires after the initial + // delay (and then at the repeat rate) to re-enqueue the held key into the key queue. + extern "efiapi" fn on_repeat_timer(_event: efi::Event, context: *mut RepeatTimerContext) { + // SAFETY: context was set during event registration via Box::into_raw and remains valid for the event lifetime. + let Some(context) = (unsafe { context.as_mut() }) else { + log::error!("on_repeat_timer invoked with null context pointer"); + return; + }; + + // SAFETY: keyboard_handler is set during install_repeat_timer and remains valid until uninstall. + let Some(keyboard_handler) = (unsafe { context.keyboard_handler.as_mut() }) else { + log::error!("on_repeat_timer invoked with invalid handler"); + return; + }; + + let Some(repeat_usage) = keyboard_handler.repeat_key else { + return; + }; + + // Re-process the held key as a new KeyDown event. This picks up the current modifier/toggle state. + { + let mut kq = keyboard_handler.state.lock(); + kq.keystroke(repeat_usage, key_queue::KeyAction::KeyDown); + } + + // Signal key notify event if the repeated keystroke matched a registered callback. + { + let kq = keyboard_handler.state.lock(); + if kq.peek_notify_key().is_some() { + let _ = keyboard_handler.boot_services.signal_event(keyboard_handler.key_notify_event); + } + } + + // Re-arm the timer at the repeat rate for the next repeat. + if let Err(status) = keyboard_handler.boot_services.set_timer( + keyboard_handler.repeat_timer_event, + EventTimerType::Relative, + REPEAT_KEY_RATE, + ) { + log::error!("on_repeat_timer: failed to re-arm repeat timer, status: {:x?}", status); + } + } + + // Handles keyboard layout change events from the HII database. + extern "efiapi" fn on_layout_update(_event: efi::Event, context: *mut LayoutChangeContext) { + log::trace!("on_layout_update: keyboard layout change event received"); + // SAFETY: context was set during event registration via Box::into_raw and remains valid for the event lifetime. + let context = unsafe { context.as_mut() }.expect("bad context pointer"); + + if context.keyboard_handler.is_null() { + log::error!("on_layout_update invoked with invalid handler"); + return; + } + + // SAFETY: keyboard_handler is null-checked above. + let keyboard_handler = unsafe { context.keyboard_handler.as_mut() }.expect("bad keyboard handler"); + + // SAFETY: We locate the HII database protocol per UEFI spec. + let hii_database_protocol_ptr = unsafe { + context.boot_services.locate_protocol_unchecked(&protocols::hii_database::PROTOCOL_GUID, ptr::null_mut()) + }; + + let Ok(hii_database_protocol_ptr) = + hii_database_protocol_ptr.map(|p| p as *mut protocols::hii_database::Protocol) + else { + return; + }; + + // SAFETY: Dereferencing the protocol pointer returned from locate_protocol; null handled by the else branch. + let Some(hii_database_protocol) = (unsafe { hii_database_protocol_ptr.as_mut() }) else { + log::error!("on_layout_update: locate_protocol returned null pointer."); + return; + }; + + // retrieve keyboard layout size + let mut layout_buffer_len: u16 = 0; + match (hii_database_protocol.get_keyboard_layout)( + hii_database_protocol_ptr, + ptr::null_mut(), + &mut layout_buffer_len as *mut u16, + ptr::null_mut(), + ) { + efi::Status::NOT_FOUND => return, + status if status != efi::Status::BUFFER_TOO_SMALL => { + log::error!( + "on_layout_update: unexpected return from get_keyboard_layout when determining length: {:x?}", + status + ); + return; + } + _ => (), + } + + let mut keyboard_layout_buffer = vec![0u8; layout_buffer_len as usize]; + let status = (hii_database_protocol.get_keyboard_layout)( + hii_database_protocol_ptr, + ptr::null_mut(), + &mut layout_buffer_len as *mut u16, + keyboard_layout_buffer.as_mut_ptr() as *mut protocols::hii_database::KeyboardLayout<0>, + ); + + if status.is_error() { + log::error!("Unexpected return from get_keyboard_layout: {:x?}", status); + return; + } + + match layout::keyboard_layout_from_buffer(&keyboard_layout_buffer) { + Ok(keyboard_layout) => { + log::trace!("on_layout_update: successfully parsed layout with {:?} keys", keyboard_layout.keys.len()); + keyboard_handler.state.lock().set_layout(Some(keyboard_layout)); + } + Err(_) => { + log::error!("keyboard::on_layout_update: Could not parse keyboard layout buffer."); + } + } + } +} + +// Inserts the field's usage into the active key set if the field value is nonzero. +fn handle_variable_key(current_keys: &mut BTreeSet, field: &VariableField, report: &[u8]) { + match field.field_value(report) { + Some(x) if x != 0 => _ = current_keys.insert(field.usage), + _ => (), + } +} + +// Resolves an array field value to a usage and inserts it into the active key set. +fn handle_array_key(current_keys: &mut BTreeSet, field: &ArrayField, report: &[u8]) { + match field.field_value(report) { + Some(index) if index != 0 => { + let mut index = (index as u32 - u32::from(field.logical_minimum)) as usize; + let usage = field.usage_list.iter().find_map(|x| { + let range_size = (x.end() - x.start()) as usize; + if index <= range_size { + x.range().nth(index) + } else { + index -= range_size; + None + } + }); + if let Some(usage) = usage { + current_keys.insert(Usage::from(usage)); + } + } + _ => (), + } +} + +// Sets a single LED field in the output report buffer. +fn build_led_report(led_state: &BTreeSet, field: &VariableField, report: &mut [u8]) { + let status = field.set_field_value(led_state.contains(&field.usage).into(), report); + if status.is_err() { + log::warn!("build_led_report: failed to set field value: {:?}", status); + } +} + +// Notification function called when Ctrl-Alt-Delete is pressed. +// Any DEL key press triggers this callback; CTRL-ALT qualification is checked here to handle +// arbitrary left/right modifier combinations without needing separate registrations. +#[cfg(feature = "ctrl-alt-del")] +extern "efiapi" fn reset_notification_function(key_data: *mut protocols::simple_text_input_ex::KeyData) -> efi::Status { + if key_data.is_null() { + return efi::Status::INVALID_PARAMETER; + } + + // SAFETY: null-checked above, using read_unaligned to avoid any alignment issues. + let key_data = unsafe { key_data.read_unaligned() }; + if key_data.key.scan_code != key_queue::SCAN_DELETE { + return efi::Status::SUCCESS; + } + + // Check that DEL is qualified with valid CTRL-ALT state. + if key_data.key_state.key_shift_state & SHIFT_STATE_VALID == 0 + || key_data.key_state.key_shift_state & (LEFT_CONTROL_PRESSED | RIGHT_CONTROL_PRESSED) == 0 + || key_data.key_state.key_shift_state & (LEFT_ALT_PRESSED | RIGHT_ALT_PRESSED) == 0 + { + return efi::Status::SUCCESS; + } + + log::warn!("Ctrl-Alt-Del pressed, resetting system."); + let rt_ptr = crate::RUNTIME_SERVICES.load(core::sync::atomic::Ordering::SeqCst); + // SAFETY: rt_ptr is loaded from a global atomic; null is handled by the if-let. + if let Some(runtime_services) = unsafe { rt_ptr.as_ref() } { + (runtime_services.reset_system)(efi::RESET_COLD, efi::Status::SUCCESS, 0, core::ptr::null_mut()); + } + // reset_system should not return; if it does, there is nothing useful to do. + efi::Status::SUCCESS +} + +impl HidReportReceiver for KeyboardHidHandler { + fn receive_report(&mut self, report: &[u8], hid_io: &dyn HidIo) { + log::trace!("keyboard::receive_report: {:?} bytes", report.len()); + let result = { + let mut kq = self.state.lock(); + self.processor.process_report(report, &mut kq) + }; + + // Handle key repeat logic for released and pressed keys. This is done here rather than + // in process_report to avoid exposing the processor to boot services and maintain a clean + // separation between pure key processing logic and UEFI timer management. + if !result.released_keys.is_empty() || !result.pressed_keys.is_empty() { + let kq = self.state.lock(); + + // Determine the new repeat candidate: prefer a newly pressed repeatable key, + // otherwise hand off to a still-held repeatable key if the current repeat key + // was released. + let new_repeat_key = result.pressed_keys.iter().rev().find(|k| kq.is_repeatable_key(**k)).copied(); + + let repeat_key_released = + self.repeat_key.is_some_and(|repeat_usage| result.released_keys.contains(&repeat_usage)); + + let next_repeat = if let Some(usage) = new_repeat_key { + Some(usage) + } else if repeat_key_released { + self.processor.held_keys().iter().rev().find(|k| kq.is_repeatable_key(**k)).copied() + } else { + None + }; + + // Apply the repeat state change with a single timer operation. + if next_repeat.is_some() || repeat_key_released { + self.repeat_key = next_repeat; + let (timer_type, trigger_time) = match next_repeat { + Some(_) => (EventTimerType::Relative, REPEAT_KEY_DELAY), + None => (EventTimerType::Cancel, 0), + }; + if let Err(status) = self.boot_services.set_timer(self.repeat_timer_event, timer_type, trigger_time) { + log::error!("receive_report: failed to set repeat timer, status: {:x?}", status); + } + } + } + + if result.should_signal_notify { + let _ = self.boot_services.signal_event(self.key_notify_event); + } + if let Err(e) = self.send_output_reports(hid_io, result.output_reports) { + log::error!("unexpected error sending output report: {:?}", e); + } + } +} + +impl Drop for KeyboardHidHandler { + fn drop(&mut self) { + // Close repeat timer first — its callback may reference protocol events. + if let Err(status) = self.uninstall_repeat_timer() { + log::error!("KeyboardHidHandler::drop: Failed to close repeat_timer: {:?}", status); + } + if let Some(key) = self.simple_text_in_key.take() + && let Err(status) = + simple_text_in::SimpleTextInFfi::::uninstall(self.boot_services, self.controller, key) + { + log::error!("KeyboardHidHandler::drop: Failed to uninstall simple_text_in: {:?}", status); + } + if let Some(key) = self.simple_text_in_ex_key.take() + && let Err(status) = + simple_text_in_ex::SimpleTextInExFfi::::uninstall(self.boot_services, self.controller, key) + { + log::error!("KeyboardHidHandler::drop: Failed to uninstall simple_text_in_ex: {:?}", status); + } + if let Err(status) = self.uninstall_layout_change_event() { + log::error!("KeyboardHidHandler::drop: Failed to close layout_change_event: {:?}", status); + } + } +} + +#[cfg(test)] +mod test { + use alloc::vec; + + use hidparser::{ + ReportDescriptor, ReportField, VariableField, + report_data_types::{ReportAttributes, Usage}, + }; + use r_efi::{efi, protocols}; + + use super::*; + + fn modifier_field(usage: u32, bit: u32) -> VariableField { + VariableField { + bits: bit..bit + 1, + usage: Usage::from(usage), + logical_minimum: 0.into(), + logical_maximum: 1.into(), + attributes: ReportAttributes::default(), + ..Default::default() + } + } + + fn led_field(usage: u32, bit: u32) -> VariableField { + VariableField { + bits: bit..bit + 1, + usage: Usage::from(usage), + logical_minimum: 0.into(), + logical_maximum: 1.into(), + attributes: ReportAttributes::default(), + ..Default::default() + } + } + + fn modifier_only_descriptor() -> ReportDescriptor { + ReportDescriptor { + input_reports: vec![hidparser::Report { + report_id: None, + size_in_bits: 8, + fields: vec![ + ReportField::Variable(modifier_field(KEYBOARD_MODIFIER_USAGE_MIN, 0)), + ReportField::Variable(modifier_field(KEYBOARD_MODIFIER_USAGE_MIN + 1, 1)), + ], + }], + output_reports: vec![], + bad_input_reports: vec![], + bad_output_reports: vec![], + features: vec![], + bad_features: vec![], + } + } + + fn keyboard_with_leds_descriptor() -> ReportDescriptor { + ReportDescriptor { + input_reports: vec![hidparser::Report { + report_id: None, + size_in_bits: 8, + fields: vec![ReportField::Variable(modifier_field(KEYBOARD_MODIFIER_USAGE_MIN, 0))], + }], + output_reports: vec![hidparser::Report { + report_id: None, + size_in_bits: 8, + fields: vec![ + ReportField::Variable(led_field(LED_USAGE_MIN, 0)), + ReportField::Variable(led_field(LED_USAGE_MIN + 1, 1)), + ReportField::Variable(led_field(LED_USAGE_MIN + 2, 2)), + ], + }], + bad_input_reports: vec![], + bad_output_reports: vec![], + features: vec![], + bad_features: vec![], + } + } + + fn empty_descriptor() -> ReportDescriptor { + ReportDescriptor { + input_reports: vec![hidparser::Report { report_id: None, size_in_bits: 0, fields: vec![] }], + output_reports: vec![], + bad_input_reports: vec![], + bad_output_reports: vec![], + features: vec![], + bad_features: vec![], + } + } + + fn led_only_descriptor() -> ReportDescriptor { + ReportDescriptor { + input_reports: vec![], + output_reports: vec![hidparser::Report { + report_id: None, + size_in_bits: 8, + fields: vec![ReportField::Variable(led_field(LED_USAGE_MIN, 0))], + }], + bad_input_reports: vec![], + bad_output_reports: vec![], + features: vec![], + bad_features: vec![], + } + } + + // --- processor defaults --- + + #[test] + fn new_processor_has_defaults() { + let processor = KeyboardProcessor::new(); + assert!(processor.input_reports.is_empty()); + assert!(processor.output_builders.is_empty()); + assert!(processor.last_keys.is_empty()); + assert!(processor.current_keys.is_empty()); + assert!(processor.led_state.is_empty()); + assert!(processor.notification_callbacks.is_empty()); + assert_eq!(processor.next_notify_handle, 0); + } + + // --- process_descriptor --- + + #[test] + fn process_descriptor_with_modifier_fields_succeeds() { + let mut processor = KeyboardProcessor::new(); + assert_eq!(processor.process_descriptor(modifier_only_descriptor()), Ok(())); + assert_eq!(processor.input_reports.len(), 1); + let report_data = processor.input_reports.values().next().unwrap(); + assert_eq!(report_data.relevant_variable_fields.len(), 2); + assert!(report_data.relevant_array_fields.is_empty()); + } + + #[test] + fn process_descriptor_with_leds_creates_output_builders() { + let mut processor = KeyboardProcessor::new(); + assert_eq!(processor.process_descriptor(keyboard_with_leds_descriptor()), Ok(())); + assert_eq!(processor.input_reports.len(), 1); + assert_eq!(processor.output_builders.len(), 1); + assert_eq!(processor.output_builders[0].relevant_variable_fields.len(), 3); + } + + #[test] + fn process_descriptor_with_no_relevant_fields_returns_unsupported() { + let mut processor = KeyboardProcessor::new(); + assert_eq!(processor.process_descriptor(empty_descriptor()), Err(efi::Status::UNSUPPORTED)); + } + + #[test] + fn process_descriptor_with_only_led_output_fields_succeeds() { + let mut processor = KeyboardProcessor::new(); + assert_eq!(processor.process_descriptor(led_only_descriptor()), Ok(())); + assert!(processor.input_reports.is_empty()); + assert_eq!(processor.output_builders.len(), 1); + } + + // --- reset --- + + #[test] + fn reset_clears_keys_and_state() { + let mut processor = KeyboardProcessor::new(); + let mut kq = key_queue::KeyQueue::default(); + processor.last_keys.insert(Usage::from(0x00070004)); + processor.current_keys.insert(Usage::from(0x00070005)); + processor.led_state.insert(Usage::from(LED_USAGE_MIN)); + + processor.reset(&mut kq, true); + + assert!(processor.last_keys.is_empty()); + assert!(processor.current_keys.is_empty()); + assert!(processor.led_state.is_empty()); + } + + #[test] + fn reset_non_extended_preserves_led_state() { + let mut processor = KeyboardProcessor::new(); + let mut kq = key_queue::KeyQueue::default(); + processor.led_state.insert(Usage::from(LED_USAGE_MIN)); + + processor.reset(&mut kq, false); + + assert!(!processor.led_state.is_empty()); + } + + // --- key queue --- + + #[test] + fn peek_key_returns_none_on_empty_queue() { + let kq = key_queue::KeyQueue::default(); + assert!(kq.peek_key().is_none()); + } + + #[test] + fn pop_key_returns_none_on_empty_queue() { + let mut kq = key_queue::KeyQueue::default(); + assert!(kq.pop_key().is_none()); + } + + #[test] + fn get_key_state_returns_initial_state() { + let kq = key_queue::KeyQueue::default(); + let state = kq.init_key_state(); + assert_eq!(state.key_shift_state, protocols::simple_text_input_ex::SHIFT_STATE_VALID); + assert_eq!(state.key_toggle_state, protocols::simple_text_input_ex::TOGGLE_STATE_VALID); + } + + #[test] + fn set_key_toggle_state_updates_state() { + let mut kq = key_queue::KeyQueue::default(); + let toggle = + protocols::simple_text_input_ex::TOGGLE_STATE_VALID | protocols::simple_text_input_ex::CAPS_LOCK_ACTIVE; + kq.set_key_toggle_state(toggle); + let state = kq.init_key_state(); + assert_ne!(state.key_toggle_state & protocols::simple_text_input_ex::CAPS_LOCK_ACTIVE, 0); + } + + // --- key notify callbacks --- + + extern "efiapi" fn dummy_callback(_key: *mut protocols::simple_text_input_ex::KeyData) -> efi::Status { + efi::Status::SUCCESS + } + + #[test] + fn insert_key_notify_returns_handle() { + let mut processor = KeyboardProcessor::new(); + let mut kq = key_queue::KeyQueue::default(); + let key_data: protocols::simple_text_input_ex::KeyData = Default::default(); + let handle = processor.insert_key_notify_callback(key_data, dummy_callback, &mut kq); + assert!(handle > 0); + assert_eq!(processor.notification_callbacks.len(), 1); + } + + #[test] + fn insert_duplicate_notify_returns_same_handle() { + let mut processor = KeyboardProcessor::new(); + let mut kq = key_queue::KeyQueue::default(); + let key_data: protocols::simple_text_input_ex::KeyData = Default::default(); + let handle1 = processor.insert_key_notify_callback(key_data, dummy_callback, &mut kq); + let handle2 = processor.insert_key_notify_callback(key_data, dummy_callback, &mut kq); + assert_eq!(handle1, handle2); + assert_eq!(processor.notification_callbacks.len(), 1); + } + + #[test] + fn remove_key_notify_succeeds() { + let mut processor = KeyboardProcessor::new(); + let mut kq = key_queue::KeyQueue::default(); + let key_data: protocols::simple_text_input_ex::KeyData = Default::default(); + let handle = processor.insert_key_notify_callback(key_data, dummy_callback, &mut kq); + assert_eq!(processor.remove_key_notify_callback(handle, &mut kq), Ok(())); + assert!(processor.notification_callbacks.is_empty()); + } + + #[test] + fn remove_key_notify_invalid_returns_error() { + let mut processor = KeyboardProcessor::new(); + let mut kq = key_queue::KeyQueue::default(); + assert_eq!(processor.remove_key_notify_callback(42, &mut kq), Err(efi::Status::INVALID_PARAMETER)); + } + + #[test] + fn pending_callbacks_returns_none_when_empty() { + let processor = KeyboardProcessor::new(); + let mut kq = key_queue::KeyQueue::default(); + let (key, callbacks) = processor.pending_callbacks(&mut kq); + assert!(key.is_none()); + assert!(callbacks.is_empty()); + } + + // --- LED reports --- + + #[test] + fn build_led_output_reports_reflects_led_state() { + let mut processor = KeyboardProcessor::new(); + processor.process_descriptor(keyboard_with_leds_descriptor()).unwrap(); + + // Set num_lock LED active + processor.led_state.insert(Usage::from(LED_USAGE_MIN)); + + let reports = processor.build_led_output_reports(); + assert_eq!(reports.len(), 1); + let (id, report) = &reports[0]; + assert!(id.is_none()); + assert_eq!(report.len(), 1); + assert_eq!(report[0] & 0x01, 0x01); // bit 0 = num lock + } + + #[test] + fn build_led_output_reports_empty_without_descriptor() { + let processor = KeyboardProcessor::new(); + let reports = processor.build_led_output_reports(); + assert!(reports.is_empty()); + } + + #[test] + fn process_report_skips_led_output_when_led_state_unchanged() { + let mut processor = KeyboardProcessor::new(); + let mut kq = key_queue::KeyQueue::default(); + processor.process_descriptor(keyboard_with_leds_descriptor()).unwrap(); + + // Press modifier (key state changes, but LED state stays empty). + let result = processor.process_report(&[0x01], &mut kq); + assert!(result.output_reports.is_empty()); + + // Release modifier (key state changes again, LED state still empty). + let result = processor.process_report(&[0x00], &mut kq); + assert!(result.output_reports.is_empty()); + } + + #[test] + fn process_report_sends_led_output_when_led_state_changes() { + let mut processor = KeyboardProcessor::new(); + let mut kq = key_queue::KeyQueue::default(); + processor.process_descriptor(keyboard_with_leds_descriptor()).unwrap(); + + // Simulate a prior LED state (e.g. num lock was on). + processor.led_state.insert(Usage::from(LED_USAGE_MIN)); + + // Press modifier — key state changes and kq reports no active LEDs, + // which differs from the pre-set led_state, so output reports are sent. + let result = processor.process_report(&[0x01], &mut kq); + assert!(!result.output_reports.is_empty()); + + // LED state is now synchronized; same key change should not resend. + let result = processor.process_report(&[0x00], &mut kq); + assert!(result.output_reports.is_empty()); + } + + // --- process_descriptor edge cases --- + + #[test] + fn process_descriptor_multiple_input_reports_without_ids_returns_error() { + let mut processor = KeyboardProcessor::new(); + let descriptor = ReportDescriptor { + input_reports: vec![ + hidparser::Report { + report_id: None, + size_in_bits: 8, + fields: vec![ReportField::Variable(modifier_field(KEYBOARD_MODIFIER_USAGE_MIN, 0))], + }, + hidparser::Report { + report_id: None, + size_in_bits: 8, + fields: vec![ReportField::Variable(modifier_field(KEYBOARD_MODIFIER_USAGE_MIN + 1, 0))], + }, + ], + output_reports: vec![], + bad_input_reports: vec![], + bad_output_reports: vec![], + features: vec![], + bad_features: vec![], + }; + assert_eq!(processor.process_descriptor(descriptor), Err(efi::Status::DEVICE_ERROR)); + } + + #[test] + fn process_descriptor_multiple_output_reports_without_ids_returns_error() { + let mut processor = KeyboardProcessor::new(); + let descriptor = ReportDescriptor { + input_reports: vec![hidparser::Report { + report_id: None, + size_in_bits: 8, + fields: vec![ReportField::Variable(modifier_field(KEYBOARD_MODIFIER_USAGE_MIN, 0))], + }], + output_reports: vec![ + hidparser::Report { + report_id: None, + size_in_bits: 8, + fields: vec![ReportField::Variable(led_field(LED_USAGE_MIN, 0))], + }, + hidparser::Report { + report_id: None, + size_in_bits: 8, + fields: vec![ReportField::Variable(led_field(LED_USAGE_MIN + 1, 0))], + }, + ], + bad_input_reports: vec![], + bad_output_reports: vec![], + features: vec![], + bad_features: vec![], + }; + assert_eq!(processor.process_descriptor(descriptor), Err(efi::Status::DEVICE_ERROR)); + } + + #[test] + fn process_descriptor_with_array_fields_succeeds() { + let mut processor = KeyboardProcessor::new(); + let descriptor = ReportDescriptor { + input_reports: vec![hidparser::Report { + report_id: None, + size_in_bits: 8, + fields: vec![ReportField::Array(hidparser::ArrayField { + bits: 0..8, + usage_list: vec![hidparser::report_data_types::UsageRange::from( + KEYBOARD_USAGE_MIN..=KEYBOARD_USAGE_MAX, + )], + logical_minimum: 0.into(), + logical_maximum: 0x65.into(), + ..Default::default() + })], + }], + output_reports: vec![], + bad_input_reports: vec![], + bad_output_reports: vec![], + features: vec![], + bad_features: vec![], + }; + assert_eq!(processor.process_descriptor(descriptor), Ok(())); + assert_eq!(processor.input_reports.len(), 1); + let report_data = processor.input_reports.values().next().unwrap(); + assert_eq!(report_data.relevant_array_fields.len(), 1); + } + + // --- process_report edge cases --- + + #[test] + fn process_report_empty_returns_no_output() { + let mut processor = KeyboardProcessor::new(); + let mut kq = key_queue::KeyQueue::default(); + processor.process_descriptor(modifier_only_descriptor()).unwrap(); + let result = processor.process_report(&[], &mut kq); + assert!(result.output_reports.is_empty()); + assert!(!result.should_signal_notify); + } + + #[test] + fn process_report_unregistered_report_id_is_ignored() { + let mut processor = KeyboardProcessor::new(); + let mut kq = key_queue::KeyQueue::default(); + let report_id = hidparser::report_data_types::ReportId::from(&[0x01][..]); + let descriptor = ReportDescriptor { + input_reports: vec![hidparser::Report { + report_id: Some(report_id), + size_in_bits: 8, + fields: vec![ReportField::Variable(modifier_field(KEYBOARD_MODIFIER_USAGE_MIN, 0))], + }], + output_reports: vec![], + bad_input_reports: vec![], + bad_output_reports: vec![], + features: vec![], + bad_features: vec![], + }; + processor.process_descriptor(descriptor).unwrap(); + // Send a report with a different report ID (0x02) + let result = processor.process_report(&[0x02, 0x00], &mut kq); + assert!(result.output_reports.is_empty()); + } + + #[test] + fn process_report_key_press_and_release_queues_keystrokes() { + let mut processor = KeyboardProcessor::new(); + let mut kq = key_queue::KeyQueue::default(); + kq.set_layout(Some(crate::keyboard::layout::get_default_keyboard_layout())); + processor.process_descriptor(modifier_only_descriptor()).unwrap(); + + // Press left ctrl (bit 0) + processor.process_report(&[0x01], &mut kq); + // Release left ctrl + processor.process_report(&[0x00], &mut kq); + + // Key state should reflect ctrl was pressed then released. + let state = kq.init_key_state(); + assert_eq!(state.key_shift_state, protocols::simple_text_input_ex::SHIFT_STATE_VALID); + } + + #[test] + fn process_report_same_report_twice_does_not_re_queue() { + let mut processor = KeyboardProcessor::new(); + let mut kq = key_queue::KeyQueue::default(); + kq.set_layout(Some(crate::keyboard::layout::get_default_keyboard_layout())); + processor.process_descriptor(modifier_only_descriptor()).unwrap(); + + // Press key + processor.process_report(&[0x01], &mut kq); + // Same report again — no change in state + let result = processor.process_report(&[0x01], &mut kq); + assert!(result.output_reports.is_empty()); + assert!(!result.should_signal_notify); + } + + #[test] + fn process_report_with_notify_key_sets_signal() { + let mut processor = KeyboardProcessor::new(); + let mut kq = key_queue::KeyQueue::default(); + kq.set_layout(Some(crate::keyboard::layout::get_default_keyboard_layout())); + processor.process_descriptor(modifier_only_descriptor()).unwrap(); + + // Enable partial key support so modifier-only keys get enqueued + kq.set_key_toggle_state( + protocols::simple_text_input_ex::TOGGLE_STATE_VALID | protocols::simple_text_input_ex::KEY_STATE_EXPOSED, + ); + + // Register a notification for left ctrl + let mut reg_key: protocols::simple_text_input_ex::KeyData = Default::default(); + reg_key.key_state.key_shift_state = + protocols::simple_text_input_ex::SHIFT_STATE_VALID | protocols::simple_text_input_ex::LEFT_CONTROL_PRESSED; + processor.insert_key_notify_callback(reg_key, dummy_callback, &mut kq); + + // Press left ctrl + let result = processor.process_report(&[0x01], &mut kq); + assert!(result.should_signal_notify); + } + + // --- pending_callbacks --- + + #[test] + fn pending_callbacks_returns_matching_callbacks() { + let mut processor = KeyboardProcessor::new(); + let mut kq = key_queue::KeyQueue::default(); + kq.set_layout(Some(crate::keyboard::layout::get_default_keyboard_layout())); + processor.process_descriptor(modifier_only_descriptor()).unwrap(); + + kq.set_key_toggle_state( + protocols::simple_text_input_ex::TOGGLE_STATE_VALID | protocols::simple_text_input_ex::KEY_STATE_EXPOSED, + ); + + let mut reg_key: protocols::simple_text_input_ex::KeyData = Default::default(); + reg_key.key_state.key_shift_state = + protocols::simple_text_input_ex::SHIFT_STATE_VALID | protocols::simple_text_input_ex::LEFT_CONTROL_PRESSED; + processor.insert_key_notify_callback(reg_key, dummy_callback, &mut kq); + + // Press left ctrl + processor.process_report(&[0x01], &mut kq); + + let (key, callbacks) = processor.pending_callbacks(&mut kq); + assert!(key.is_some()); + assert_eq!(callbacks.len(), 1); + } + + // --- remove_key_notify with shared key --- + + #[test] + fn remove_key_notify_keeps_key_if_other_callback_remains() { + let mut processor = KeyboardProcessor::new(); + let mut kq = key_queue::KeyQueue::default(); + let key_data: protocols::simple_text_input_ex::KeyData = Default::default(); + let handle1 = processor.insert_key_notify_callback(key_data, dummy_callback, &mut kq); + let handle2 = processor.insert_key_notify_callback(key_data, dummy_callback2, &mut kq); + assert_ne!(handle1, handle2); + // Remove first callback — key should remain since second still exists + assert_eq!(processor.remove_key_notify_callback(handle1, &mut kq), Ok(())); + assert_eq!(processor.notification_callbacks.len(), 1); + } + + extern "efiapi" fn dummy_callback2(_key: *mut protocols::simple_text_input_ex::KeyData) -> efi::Status { + efi::Status::SUCCESS + } + + // --- handle_variable_key tests --- + + #[test] + fn handle_variable_key_inserts_usage_when_nonzero() { + let mut keys = BTreeSet::new(); + let field = modifier_field(KEYBOARD_MODIFIER_USAGE_MIN, 0); + // Byte 0 bit 0 set → nonzero → usage inserted + handle_variable_key(&mut keys, &field, &[0x01]); + assert!(keys.contains(&Usage::from(KEYBOARD_MODIFIER_USAGE_MIN))); + } + + #[test] + fn handle_variable_key_does_nothing_when_zero() { + let mut keys = BTreeSet::new(); + let field = modifier_field(KEYBOARD_MODIFIER_USAGE_MIN, 0); + handle_variable_key(&mut keys, &field, &[0x00]); + assert!(keys.is_empty()); + } + + // --- handle_array_key tests --- + + #[test] + fn handle_array_key_inserts_usage_for_valid_index() { + let mut keys = BTreeSet::new(); + let field = hidparser::ArrayField { + bits: 0..8, + usage_list: vec![hidparser::report_data_types::UsageRange::from(KEYBOARD_USAGE_MIN..=KEYBOARD_USAGE_MAX)], + logical_minimum: 0.into(), + logical_maximum: 0x65.into(), + ..Default::default() + }; + // Report value 4 → index 4 from logical_minimum 0 → usage KEYBOARD_USAGE_MIN + 4 + handle_array_key(&mut keys, &field, &[0x04]); + assert_eq!(keys.len(), 1); + assert!(keys.contains(&Usage::from(KEYBOARD_USAGE_MIN + 4))); + } + + #[test] + fn handle_array_key_does_nothing_when_zero() { + let mut keys = BTreeSet::new(); + let field = hidparser::ArrayField { + bits: 0..8, + usage_list: vec![hidparser::report_data_types::UsageRange::from(KEYBOARD_USAGE_MIN..=KEYBOARD_USAGE_MAX)], + logical_minimum: 0.into(), + logical_maximum: 0x65.into(), + ..Default::default() + }; + handle_array_key(&mut keys, &field, &[0x00]); + assert!(keys.is_empty()); + } + + // --- build_led_report tests --- + + #[test] + fn build_led_report_sets_field_when_usage_present() { + let mut led_state = BTreeSet::new(); + led_state.insert(Usage::from(LED_USAGE_MIN)); + let field = led_field(LED_USAGE_MIN, 0); + let mut report = [0u8; 1]; + build_led_report(&led_state, &field, &mut report); + assert_ne!(report[0] & 0x01, 0); + } + + #[test] + fn build_led_report_clears_field_when_usage_absent() { + let led_state = BTreeSet::new(); + let field = led_field(LED_USAGE_MIN, 0); + let mut report = [0xFFu8; 1]; + build_led_report(&led_state, &field, &mut report); + assert_eq!(report[0] & 0x01, 0); + } + + // --- process_report with array fields --- + + #[test] + fn process_report_with_array_field_handles_key_press() { + let mut processor = KeyboardProcessor::new(); + let mut kq = key_queue::KeyQueue::default(); + kq.set_layout(Some(crate::keyboard::layout::get_default_keyboard_layout())); + + let descriptor = ReportDescriptor { + input_reports: vec![hidparser::Report { + report_id: None, + size_in_bits: 8, + fields: vec![ReportField::Array(hidparser::ArrayField { + bits: 0..8, + usage_list: vec![hidparser::report_data_types::UsageRange::from( + KEYBOARD_USAGE_MIN..=KEYBOARD_USAGE_MAX, + )], + logical_minimum: 0.into(), + logical_maximum: 0x65.into(), + ..Default::default() + })], + }], + output_reports: vec![], + bad_input_reports: vec![], + bad_output_reports: vec![], + features: vec![], + bad_features: vec![], + }; + processor.process_descriptor(descriptor).unwrap(); + + // Report value 0x04 → index 4 in usage range → produces a keystroke + processor.process_report(&[0x04], &mut kq); + let key_data = kq.pop_key().unwrap(); + assert_ne!(key_data.key.unicode_char, 0); + } + + #[test] + fn process_report_key_release_produces_key_up() { + let mut processor = KeyboardProcessor::new(); + let mut kq = key_queue::KeyQueue::default(); + kq.set_layout(Some(crate::keyboard::layout::get_default_keyboard_layout())); + processor.process_descriptor(modifier_only_descriptor()).unwrap(); + + // Press left shift + processor.process_report(&[0x02], &mut kq); + // Release it + processor.process_report(&[0x00], &mut kq); + + // After release, the init_key_state should show shift no longer pressed + let state = kq.init_key_state(); + assert_eq!(state.key_shift_state & protocols::simple_text_input_ex::LEFT_SHIFT_PRESSED, 0); + } + + #[test] + fn process_report_report_length_mismatch_still_processes() { + let mut processor = KeyboardProcessor::new(); + let mut kq = key_queue::KeyQueue::default(); + kq.set_layout(Some(crate::keyboard::layout::get_default_keyboard_layout())); + processor.process_descriptor(modifier_only_descriptor()).unwrap(); + + // Send 2 bytes when descriptor expects 1 — should still process without error + processor.process_report(&[0x01, 0x00], &mut kq); + } + + #[test] + fn process_report_with_report_id_strips_id_byte() { + let report_id = hidparser::report_data_types::ReportId::from(&[0x01][..]); + let mut processor = KeyboardProcessor::new(); + let mut kq = key_queue::KeyQueue::default(); + kq.set_layout(Some(crate::keyboard::layout::get_default_keyboard_layout())); + let descriptor = ReportDescriptor { + input_reports: vec![hidparser::Report { + report_id: Some(report_id), + size_in_bits: 8, + fields: vec![ReportField::Variable(modifier_field(KEYBOARD_MODIFIER_USAGE_MIN, 0))], + }], + output_reports: vec![], + bad_input_reports: vec![], + bad_output_reports: vec![], + features: vec![], + bad_features: vec![], + }; + processor.process_descriptor(descriptor).unwrap(); + + // Report with ID byte + modifier data + processor.process_report(&[0x01, 0x01], &mut kq); + let state = kq.init_key_state(); + assert_ne!(state.key_shift_state & protocols::simple_text_input_ex::LEFT_CONTROL_PRESSED, 0); + } + + #[test] + fn process_report_with_report_id_and_empty_data_is_no_op() { + let report_id = hidparser::report_data_types::ReportId::from(&[0x01][..]); + let mut processor = KeyboardProcessor::new(); + let mut kq = key_queue::KeyQueue::default(); + let descriptor = ReportDescriptor { + input_reports: vec![hidparser::Report { + report_id: Some(report_id), + size_in_bits: 8, + fields: vec![ReportField::Variable(modifier_field(KEYBOARD_MODIFIER_USAGE_MIN, 0))], + }], + output_reports: vec![], + bad_input_reports: vec![], + bad_output_reports: vec![], + features: vec![], + bad_features: vec![], + }; + processor.process_descriptor(descriptor).unwrap(); + + // Only report ID byte, no data + let result = processor.process_report(&[0x01], &mut kq); + assert!(!result.should_signal_notify); + } + + #[test] + fn process_descriptor_with_padding_field_ignores_it() { + let mut processor = KeyboardProcessor::new(); + let descriptor = ReportDescriptor { + input_reports: vec![hidparser::Report { + report_id: None, + size_in_bits: 16, + fields: vec![ + ReportField::Variable(modifier_field(KEYBOARD_MODIFIER_USAGE_MIN, 0)), + ReportField::Padding(hidparser::PaddingField { bits: 8..16 }), + ], + }], + output_reports: vec![], + bad_input_reports: vec![], + bad_output_reports: vec![], + features: vec![], + bad_features: vec![], + }; + assert_eq!(processor.process_descriptor(descriptor), Ok(())); + let report_data = processor.input_reports.values().next().unwrap(); + assert_eq!(report_data.relevant_variable_fields.len(), 1); + } + + #[test] + fn process_descriptor_output_array_field_ignored() { + let mut processor = KeyboardProcessor::new(); + let descriptor = ReportDescriptor { + input_reports: vec![hidparser::Report { + report_id: None, + size_in_bits: 8, + fields: vec![ReportField::Variable(modifier_field(KEYBOARD_MODIFIER_USAGE_MIN, 0))], + }], + output_reports: vec![hidparser::Report { + report_id: None, + size_in_bits: 8, + fields: vec![ReportField::Array(hidparser::ArrayField { bits: 0..8, ..Default::default() })], + }], + bad_input_reports: vec![], + bad_output_reports: vec![], + features: vec![], + bad_features: vec![], + }; + assert_eq!(processor.process_descriptor(descriptor), Ok(())); + assert!(processor.output_builders.is_empty()); + } + + #[test] + fn build_led_output_reports_produces_correct_output() { + let mut processor = KeyboardProcessor::new(); + processor.process_descriptor(keyboard_with_leds_descriptor()).unwrap(); + // Set LED state to include num lock LED + processor.led_state.insert(Usage::from(LED_USAGE_MIN)); + + let reports = processor.build_led_output_reports(); + assert!(!reports.is_empty()); + let (_, report_data) = &reports[0]; + // First bit should be set for the LED + assert_ne!(report_data[0] & 0x01, 0); + } + + #[cfg(feature = "ctrl-alt-del")] + mod ctrl_alt_del_tests { + use super::*; + use core::sync::atomic::{AtomicBool, Ordering}; + use r_efi::protocols::simple_text_input_ex::{ + KeyData, KeyState, LEFT_ALT_PRESSED, LEFT_CONTROL_PRESSED, RIGHT_ALT_PRESSED, RIGHT_CONTROL_PRESSED, + SHIFT_STATE_VALID, + }; + + static RESET_CALLED: AtomicBool = AtomicBool::new(false); + static TEST_LOCK: std::sync::Mutex<()> = std::sync::Mutex::new(()); + + extern "efiapi" fn mock_reset_system( + _reset_type: efi::ResetType, + _status: efi::Status, + _data_size: usize, + _data: *mut core::ffi::c_void, + ) { + RESET_CALLED.store(true, Ordering::SeqCst); + } + + fn make_key_data(scan_code: u16, shift_state: u32) -> KeyData { + KeyData { + key: protocols::simple_text_input::InputKey { scan_code, unicode_char: 0 }, + key_state: KeyState { key_shift_state: shift_state, key_toggle_state: 0 }, + } + } + + // Acquires the test lock, installs a mock RuntimeServices, runs the closure, + // and cleans up. Serializes access to the shared RUNTIME_SERVICES and RESET_CALLED globals. + fn with_mock_runtime_services(f: F) { + let _guard = TEST_LOCK.lock().unwrap(); + let rt = Box::leak(Box::new(core::mem::MaybeUninit::::zeroed())); + // SAFETY: Only reset_system is accessed by the code under test. + unsafe { + (*rt.as_mut_ptr()).reset_system = mock_reset_system; + } + crate::RUNTIME_SERVICES.store(rt.as_mut_ptr(), Ordering::SeqCst); + RESET_CALLED.store(false, Ordering::SeqCst); + f(); + } + + #[test] + fn reset_notification_returns_invalid_parameter_on_null() { + let status = reset_notification_function(core::ptr::null_mut()); + assert_eq!(status, efi::Status::INVALID_PARAMETER); + } + + #[test] + fn reset_notification_ignores_non_delete_scan_code() { + with_mock_runtime_services(|| { + let mut key_data = + make_key_data(key_queue::SCAN_F1, SHIFT_STATE_VALID | LEFT_CONTROL_PRESSED | LEFT_ALT_PRESSED); + let status = reset_notification_function(&mut key_data); + assert_eq!(status, efi::Status::SUCCESS); + assert!(!RESET_CALLED.load(Ordering::SeqCst)); + }); + } + + #[test] + fn reset_notification_ignores_delete_without_shift_state_valid() { + with_mock_runtime_services(|| { + let mut key_data = make_key_data(key_queue::SCAN_DELETE, LEFT_CONTROL_PRESSED | LEFT_ALT_PRESSED); + let status = reset_notification_function(&mut key_data); + assert_eq!(status, efi::Status::SUCCESS); + assert!(!RESET_CALLED.load(Ordering::SeqCst)); + }); + } + + #[test] + fn reset_notification_ignores_delete_without_ctrl() { + with_mock_runtime_services(|| { + let mut key_data = make_key_data(key_queue::SCAN_DELETE, SHIFT_STATE_VALID | LEFT_ALT_PRESSED); + let status = reset_notification_function(&mut key_data); + assert_eq!(status, efi::Status::SUCCESS); + assert!(!RESET_CALLED.load(Ordering::SeqCst)); + }); + } + + #[test] + fn reset_notification_ignores_delete_without_alt() { + with_mock_runtime_services(|| { + let mut key_data = make_key_data(key_queue::SCAN_DELETE, SHIFT_STATE_VALID | LEFT_CONTROL_PRESSED); + let status = reset_notification_function(&mut key_data); + assert_eq!(status, efi::Status::SUCCESS); + assert!(!RESET_CALLED.load(Ordering::SeqCst)); + }); + } + + #[test] + fn reset_notification_triggers_on_left_ctrl_left_alt_delete() { + with_mock_runtime_services(|| { + let mut key_data = + make_key_data(key_queue::SCAN_DELETE, SHIFT_STATE_VALID | LEFT_CONTROL_PRESSED | LEFT_ALT_PRESSED); + let status = reset_notification_function(&mut key_data); + assert_eq!(status, efi::Status::SUCCESS); + assert!(RESET_CALLED.load(Ordering::SeqCst)); + }); + } + + #[test] + fn reset_notification_triggers_on_right_ctrl_right_alt_delete() { + with_mock_runtime_services(|| { + let mut key_data = make_key_data( + key_queue::SCAN_DELETE, + SHIFT_STATE_VALID | RIGHT_CONTROL_PRESSED | RIGHT_ALT_PRESSED, + ); + let status = reset_notification_function(&mut key_data); + assert_eq!(status, efi::Status::SUCCESS); + assert!(RESET_CALLED.load(Ordering::SeqCst)); + }); + } + + #[test] + fn reset_notification_triggers_on_mixed_ctrl_alt_delete() { + with_mock_runtime_services(|| { + let mut key_data = + make_key_data(key_queue::SCAN_DELETE, SHIFT_STATE_VALID | LEFT_CONTROL_PRESSED | RIGHT_ALT_PRESSED); + let status = reset_notification_function(&mut key_data); + assert_eq!(status, efi::Status::SUCCESS); + assert!(RESET_CALLED.load(Ordering::SeqCst)); + }); + } + } + + // --- key repeat --- + + mod repeat_tests { + use super::*; + use crate::hid_io::MockHidIo; + use hidparser::report_data_types::UsageRange; + use patina::boot_services::MockBootServices; + + fn mock_boot_services_for_repeat() -> &'static MockBootServices { + let mut mock = MockBootServices::new(); + mock.expect_raise_tpl().returning(|_| Tpl::APPLICATION); + mock.expect_restore_tpl().returning(|_| ()); + mock.expect_set_timer().returning(|_, _, _| Ok(())); + mock.expect_signal_event().returning(|_| Ok(())); + mock.expect_close_event().returning(|_| Ok(())); + // SAFETY: Leaked to obtain 'static lifetime for test use; never freed. + unsafe { &*Box::into_raw(Box::new(mock)) } + } + + fn boot_keyboard_descriptor() -> ReportDescriptor { + // 8-byte report: 1 byte modifiers (8 variable bits) + 1 byte reserved + 6 array key slots + ReportDescriptor { + input_reports: vec![hidparser::Report { + report_id: None, + size_in_bits: 64, + fields: vec![ + ReportField::Variable(modifier_field(KEYBOARD_MODIFIER_USAGE_MIN, 0)), + ReportField::Variable(modifier_field(KEYBOARD_MODIFIER_USAGE_MIN + 1, 1)), + ReportField::Variable(modifier_field(KEYBOARD_MODIFIER_USAGE_MIN + 2, 2)), + ReportField::Variable(modifier_field(KEYBOARD_MODIFIER_USAGE_MIN + 3, 3)), + ReportField::Variable(modifier_field(KEYBOARD_MODIFIER_USAGE_MIN + 4, 4)), + ReportField::Variable(modifier_field(KEYBOARD_MODIFIER_USAGE_MIN + 5, 5)), + ReportField::Variable(modifier_field(KEYBOARD_MODIFIER_USAGE_MIN + 6, 6)), + ReportField::Variable(modifier_field(KEYBOARD_MODIFIER_USAGE_MIN + 7, 7)), + ReportField::Array(hidparser::ArrayField { + bits: 16..24, + usage_list: vec![UsageRange::from(KEYBOARD_USAGE_MIN..=KEYBOARD_USAGE_MAX)], + logical_minimum: 0.into(), + logical_maximum: 0x65.into(), + ..Default::default() + }), + ReportField::Array(hidparser::ArrayField { + bits: 24..32, + usage_list: vec![UsageRange::from(KEYBOARD_USAGE_MIN..=KEYBOARD_USAGE_MAX)], + logical_minimum: 0.into(), + logical_maximum: 0x65.into(), + ..Default::default() + }), + ReportField::Array(hidparser::ArrayField { + bits: 32..40, + usage_list: vec![UsageRange::from(KEYBOARD_USAGE_MIN..=KEYBOARD_USAGE_MAX)], + logical_minimum: 0.into(), + logical_maximum: 0x65.into(), + ..Default::default() + }), + ReportField::Array(hidparser::ArrayField { + bits: 40..48, + usage_list: vec![UsageRange::from(KEYBOARD_USAGE_MIN..=KEYBOARD_USAGE_MAX)], + logical_minimum: 0.into(), + logical_maximum: 0x65.into(), + ..Default::default() + }), + ReportField::Array(hidparser::ArrayField { + bits: 48..56, + usage_list: vec![UsageRange::from(KEYBOARD_USAGE_MIN..=KEYBOARD_USAGE_MAX)], + logical_minimum: 0.into(), + logical_maximum: 0x65.into(), + ..Default::default() + }), + ReportField::Array(hidparser::ArrayField { + bits: 56..64, + usage_list: vec![UsageRange::from(KEYBOARD_USAGE_MIN..=KEYBOARD_USAGE_MAX)], + logical_minimum: 0.into(), + logical_maximum: 0x65.into(), + ..Default::default() + }), + ], + }], + output_reports: vec![], + bad_input_reports: vec![], + bad_output_reports: vec![], + features: vec![], + bad_features: vec![], + } + } + + fn setup_handler() -> KeyboardHidHandler { + let boot_services = mock_boot_services_for_repeat(); + let mut handler = KeyboardHidHandler::new_for_test(boot_services); + handler.process_descriptor(boot_keyboard_descriptor()).unwrap(); + handler.set_layout(Some(layout::get_default_keyboard_layout())); + // Set a non-null timer event and a valid context so repeat logic is exercised. + handler.repeat_timer_event = 0x42 as efi::Event; + let context = Box::into_raw(Box::new(RepeatTimerContext { keyboard_handler: &mut handler as *mut _ })); + handler.repeat_context = context; + handler + } + + fn mock_hid_io() -> MockHidIo { + let mut hid_io = MockHidIo::new(); + hid_io.expect_set_output_report().returning(|_, _| Ok(())); + hid_io + } + + #[test] + fn pressing_repeatable_key_sets_repeat_key() { + let mut handler = setup_handler(); + let hid_io = mock_hid_io(); + + // Press key (report value 0x04 → usage KEYBOARD_USAGE_MIN + 4 = 0x00070005) + let report: &[u8] = &[0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x00, 0x00]; + handler.receive_report(report, &hid_io); + + assert!(handler.repeat_key.is_some()); + assert_eq!(handler.repeat_key.unwrap(), Usage::from(KEYBOARD_USAGE_MIN + 4)); + } + + #[test] + fn releasing_repeat_key_clears_repeat_key() { + let mut handler = setup_handler(); + let hid_io = mock_hid_io(); + + // Press key + let report: &[u8] = &[0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x00, 0x00]; + handler.receive_report(report, &hid_io); + assert!(handler.repeat_key.is_some()); + + // Release key + let report: &[u8] = &[0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00]; + handler.receive_report(report, &hid_io); + assert!(handler.repeat_key.is_none()); + } + + #[test] + fn modifier_key_does_not_set_repeat_key() { + let mut handler = setup_handler(); + let hid_io = mock_hid_io(); + + // Press left shift (modifier bit 1, usage 0xE1) + let report: &[u8] = &[0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00]; + handler.receive_report(report, &hid_io); + assert!(handler.repeat_key.is_none()); + } + + #[test] + fn releasing_repeat_key_hands_off_to_remaining_held_key() { + let mut handler = setup_handler(); + let hid_io = mock_hid_io(); + + // Press first key (report 0x04 → usage 0x00070005) + let report: &[u8] = &[0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x00, 0x00]; + handler.receive_report(report, &hid_io); + assert_eq!(handler.repeat_key, Some(Usage::from(KEYBOARD_USAGE_MIN + 4))); + + // Press second key (report 0x05 → usage 0x00070006) while first held + let report: &[u8] = &[0x00, 0x00, 0x04, 0x05, 0x00, 0x00, 0x00, 0x00]; + handler.receive_report(report, &hid_io); + // Second key is the newest pressed, becomes the repeat candidate + assert_eq!(handler.repeat_key, Some(Usage::from(KEYBOARD_USAGE_MIN + 5))); + + // Release second key, first still held → repeat should hand off to first + let report: &[u8] = &[0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x00, 0x00]; + handler.receive_report(report, &hid_io); + assert_eq!(handler.repeat_key, Some(Usage::from(KEYBOARD_USAGE_MIN + 4))); + } + + #[test] + fn new_key_replaces_repeat_key() { + let mut handler = setup_handler(); + let hid_io = mock_hid_io(); + + // Press first key + let report: &[u8] = &[0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x00, 0x00]; + handler.receive_report(report, &hid_io); + assert_eq!(handler.repeat_key, Some(Usage::from(KEYBOARD_USAGE_MIN + 4))); + + // Press second key while first still held + let report: &[u8] = &[0x00, 0x00, 0x04, 0x05, 0x00, 0x00, 0x00, 0x00]; + handler.receive_report(report, &hid_io); + assert_eq!(handler.repeat_key, Some(Usage::from(KEYBOARD_USAGE_MIN + 5))); + } + + #[test] + fn on_repeat_timer_enqueues_keystroke() { + let mut handler = setup_handler(); + let hid_io = mock_hid_io(); + + // Press key (report 0x04 → usage 0x00070005 = 'b') + let report: &[u8] = &[0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x00, 0x00]; + handler.receive_report(report, &hid_io); + + // Pop the initial keystroke + let _initial = handler.pop_key(); + + // Simulate timer callback + let mut context = RepeatTimerContext { keyboard_handler: &mut handler as *mut _ }; + KeyboardHidHandler::on_repeat_timer(ptr::null_mut(), &mut context); + + // Should have enqueued a repeat keystroke + let repeat_key = handler.pop_key(); + assert!(repeat_key.is_some()); + assert_eq!(repeat_key.unwrap().key.unicode_char, 'b' as u16); + } + + #[test] + fn reset_clears_repeat_key() { + let mut handler = setup_handler(); + let hid_io = mock_hid_io(); + + // Press 'a' + let report: &[u8] = &[0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x00, 0x00]; + handler.receive_report(report, &hid_io); + assert!(handler.repeat_key.is_some()); + + // Reset should cancel repeat + handler.reset(false); + assert!(handler.repeat_key.is_none()); + } + } +} diff --git a/uefi_hid/src/keyboard/simple_text_in.rs b/uefi_hid/src/keyboard/simple_text_in.rs new file mode 100644 index 0000000..c52a3ab --- /dev/null +++ b/uefi_hid/src/keyboard/simple_text_in.rs @@ -0,0 +1,576 @@ +//! Simple Text In Protocol FFI Support. +//! +//! This module manages the Simple Text In FFI. +//! +//! ## License +//! +//! Copyright (c) Microsoft Corporation. +//! +//! SPDX-License-Identifier: Apache-2.0 +//! + +use alloc::boxed::Box; +use core::ptr; + +use r_efi::{efi, protocols}; + +use patina::{ + boot_services::{ + BootServices, + c_ptr::PtrMetadata, + event::{EventNotifyCallback, EventType}, + tpl::Tpl, + }, + uefi_protocol::ProtocolInterface, +}; + +use super::KeyboardHidHandler; + +/// FFI context for SimpleTextInput protocol. +/// +/// A pointer to KeyboardHidHandler is included in the context so that it can be reclaimed in the simple_text_in API +/// implementations. Mutual exclusion on the KeyboardHidHandler is provided by the TplMutex on its key_queue state. +/// +/// The simple_text_in protocol element must be the first element in the structure so that the full structure can be +/// recovered by simple casting. +#[repr(C)] +pub(crate) struct SimpleTextInFfi { + simple_text_in: protocols::simple_text_input::Protocol, + pub(crate) boot_services: &'static T, + pub(crate) keyboard_handler: *mut KeyboardHidHandler, +} + +// SAFETY: SimpleTextInFfi is #[repr(C)] with protocols::simple_text_input::Protocol as its +// first field, so a pointer to SimpleTextInFfi is a valid pointer to Protocol per the +// first-field casting pattern. +unsafe impl ProtocolInterface for SimpleTextInFfi { + const PROTOCOL_GUID: patina::BinaryGuid = patina::BinaryGuid(protocols::simple_text_input::PROTOCOL_GUID); +} + +impl SimpleTextInFfi { + /// Installs the simple text in protocol. Returns the key required to uninstall. + pub(crate) fn install( + boot_services: &'static T, + controller: efi::Handle, + keyboard_handler: &mut KeyboardHidHandler, + ) -> Result>, efi::Status> { + let mut ctx = Box::new(SimpleTextInFfi { + simple_text_in: protocols::simple_text_input::Protocol { + reset: Self::simple_text_in_reset, + read_key_stroke: Self::simple_text_in_read_key_stroke, + wait_for_key: ptr::null_mut(), + }, + boot_services, + keyboard_handler: keyboard_handler as *mut KeyboardHidHandler, + }); + + let ctx_ptr: *mut Self = &mut *ctx; + + let wait_for_key_event = boot_services.create_event( + EventType::NOTIFY_WAIT, + Tpl::NOTIFY, + Some(Self::simple_text_in_wait_for_key as EventNotifyCallback<*mut Self>), + ctx_ptr, + )?; + + ctx.simple_text_in.wait_for_key = wait_for_key_event; + + match boot_services.install_protocol_interface(Some(controller), ctx) { + Ok((_handle, key)) => Ok(key), + Err(status) => { + // install_protocol_interface reconstructs and drops the Box on failure, + // but doesn't know about our event. + let _ = boot_services.close_event(wait_for_key_event); + Err(status) + } + } + } + + /// Uninstalls the simple text in protocol using the key from install. + pub(crate) fn uninstall( + boot_services: &'static T, + controller: efi::Handle, + key: PtrMetadata<'static, Box>, + ) -> Result<(), efi::Status> { + // Save the raw pointer before consuming key, in case uninstall fails. + let raw_ptr = key.ptr_value as *mut SimpleTextInFfi; + + let ctx = match boot_services.uninstall_protocol_interface(controller, key) { + Ok(ctx) => ctx, + Err(status) => { + log::error!("Failed to uninstall simple_text_in interface, status: {:x?}", status); + // Protocol is still installed. Null keyboard_handler so callbacks don't access a + // dropped KeyboardHidHandler. + // SAFETY: raw_ptr was saved from the protocol key before uninstall was attempted; + // since uninstall failed the protocol is still installed and the memory is valid. + unsafe { + if let Some(ctx) = raw_ptr.as_mut() { + ctx.keyboard_handler = ptr::null_mut(); + } + } + return Err(status); + } + }; + + if let Err(status) = boot_services.close_event(ctx.simple_text_in.wait_for_key) { + log::error!("Failed to close simple_text_in.wait_for_key event, status: {:x?}", status); + // Leak ctx so the still-live event callback doesn't use freed memory. + core::mem::forget(ctx); + return Err(status); + } + + // ctx drops here, freeing the SimpleTextInFfi allocation. + Ok(()) + } + + // Resets the keyboard state — part of the simple_text_in protocol interface. + extern "efiapi" fn simple_text_in_reset( + this: *mut protocols::simple_text_input::Protocol, + extended_verification: efi::Boolean, + ) -> efi::Status { + // SAFETY: `this` points to the first field of Self (#[repr(C)]), so the cast recovers + // the full SimpleTextInFfi context. Null is handled by the check below. + let context = unsafe { (this as *mut Self).as_mut() }; + let Some(context) = context else { return efi::Status::INVALID_PARAMETER }; + // SAFETY: keyboard_handler pointer validity is ensured by the protocol lifecycle; + // null is handled by the check below. + let keyboard_handler = unsafe { context.keyboard_handler.as_mut() }; + let Some(keyboard_handler) = keyboard_handler else { + log::error!("simple_text_in_reset invoked after keyboard dropped."); + return efi::Status::DEVICE_ERROR; + }; + keyboard_handler.reset(extended_verification.into()); + log::trace!("simple_text_in_reset: extended_verification={:?}", bool::from(extended_verification)); + efi::Status::SUCCESS + } + + // Reads a key stroke — part of the simple_text_in protocol interface. + extern "efiapi" fn simple_text_in_read_key_stroke( + this: *mut protocols::simple_text_input::Protocol, + key: *mut protocols::simple_text_input::InputKey, + ) -> efi::Status { + if this.is_null() || key.is_null() { + return efi::Status::INVALID_PARAMETER; + } + // SAFETY: `this` points to the first field of Self (#[repr(C)]), so the cast recovers + // the full SimpleTextInFfi context. Null is checked by the preceding guard. + let Some(context) = (unsafe { (this as *mut Self).as_mut() }) else { + return efi::Status::INVALID_PARAMETER; + }; + // SAFETY: keyboard_handler pointer validity is ensured by the protocol lifecycle; + // null is handled by the check below. + let keyboard_handler = unsafe { context.keyboard_handler.as_mut() }; + let Some(keyboard_handler) = keyboard_handler else { + log::error!("simple_text_in_read_key_stroke invoked after keyboard dropped."); + return efi::Status::DEVICE_ERROR; + }; + + let mut kq = keyboard_handler.state.lock(); + loop { + if let Some(mut key_data) = kq.pop_key() { + // skip partials + if key_data.key.unicode_char == 0 && key_data.key.scan_code == 0 { + continue; + } + const CONTROL_PRESSED: u32 = protocols::simple_text_input_ex::RIGHT_CONTROL_PRESSED + | protocols::simple_text_input_ex::LEFT_CONTROL_PRESSED; + const LOWERCASE_A: u16 = 0x0061; + const LOWERCASE_Z: u16 = 0x007a; + const UPPERCASE_A: u16 = 0x0041; + const UPPERCASE_Z: u16 = 0x005a; + if (key_data.key_state.key_shift_state & CONTROL_PRESSED) != 0 { + if key_data.key.unicode_char >= LOWERCASE_A && key_data.key.unicode_char <= LOWERCASE_Z { + key_data.key.unicode_char = (key_data.key.unicode_char - LOWERCASE_A) + 1; + } + if key_data.key.unicode_char >= UPPERCASE_A && key_data.key.unicode_char <= UPPERCASE_Z { + key_data.key.unicode_char = (key_data.key.unicode_char - UPPERCASE_A) + 1; + } + } + // SAFETY: key output pointer was null-checked at the top of this function; + // write_unaligned is used to avoid any alignment requirements. + unsafe { key.write_unaligned(key_data.key) } + log::trace!( + "simple_text_in_read_key_stroke: unicode=0x{:04X} scan=0x{:04X}", + key_data.key.unicode_char, + key_data.key.scan_code, + ); + return efi::Status::SUCCESS; + } else { + return efi::Status::NOT_READY; + } + } + } + + // Handles the wait_for_key event — part of the simple_text_in protocol interface. + extern "efiapi" fn simple_text_in_wait_for_key(event: efi::Event, context: *mut Self) { + // SAFETY: context pointer was provided by the UEFI event system from Box::into_raw; + // null is handled by the else branch below. + let Some(context) = (unsafe { context.as_mut() }) else { + log::error!("simple_text_in_wait_for_key invoked with invalid context"); + return; + }; + // SAFETY: keyboard_handler pointer validity is ensured by the protocol lifecycle; + // null is handled by the check below. + let keyboard_handler = unsafe { context.keyboard_handler.as_mut() }; + let Some(keyboard_handler) = keyboard_handler else { return }; + let mut kq = keyboard_handler.state.lock(); + while let Some(key_data) = kq.peek_key() { + if key_data.key.unicode_char == 0 && key_data.key.scan_code == 0 { + let _ = kq.pop_key(); + continue; + } else { + let _ = context.boot_services.signal_event(event); + break; + } + } + } +} + +#[cfg(test)] +mod test { + use core::ptr; + + use alloc::boxed::Box; + use r_efi::{efi, protocols}; + + use patina::boot_services::{ + MockBootServices, + c_ptr::{CPtr, PtrMetadata}, + }; + + use super::*; + + fn mock_boot_services() -> &'static mut MockBootServices { + let mut mock = MockBootServices::new(); + mock.expect_raise_tpl().returning(|_| patina::boot_services::tpl::Tpl::APPLICATION); + mock.expect_restore_tpl().returning(|_| ()); + // SAFETY: Leaked to obtain 'static lifetime for test use; never freed. + unsafe { Box::into_raw(Box::new(mock)).as_mut().unwrap() } + } + + type StiPtr = SimpleTextInFfi; + + fn test_context( + boot_services: &'static MockBootServices, + handler: &mut KeyboardHidHandler, + ) -> SimpleTextInFfi { + SimpleTextInFfi { + simple_text_in: protocols::simple_text_input::Protocol { + reset: SimpleTextInFfi::::simple_text_in_reset, + read_key_stroke: SimpleTextInFfi::::simple_text_in_read_key_stroke, + wait_for_key: ptr::null_mut(), + }, + boot_services, + keyboard_handler: handler as *mut KeyboardHidHandler, + } + } + + fn leaked_context_key( + boot_services: &'static MockBootServices, + handler: &mut KeyboardHidHandler, + ) -> PtrMetadata<'static, Box> { + let ctx = Box::new(SimpleTextInFfi { + simple_text_in: protocols::simple_text_input::Protocol { + reset: SimpleTextInFfi::::simple_text_in_reset, + read_key_stroke: SimpleTextInFfi::::simple_text_in_read_key_stroke, + wait_for_key: 0x42 as efi::Event, + }, + boot_services, + keyboard_handler: handler as *mut _, + }); + let key = ctx.metadata(); + let _ = Box::into_raw(ctx); + key + } + + // --- install/uninstall --- + + #[test] + fn install_succeeds() { + let boot_services = mock_boot_services(); + boot_services.expect_create_event::<*mut StiPtr>().times(1).returning(|_, _, _, _| Ok(0x42 as efi::Event)); + boot_services.expect_install_protocol_interface::>().times(1).returning( + |_, protocol_interface| { + let key = protocol_interface.metadata(); + let _ = Box::into_raw(protocol_interface); + Ok((0x1 as efi::Handle, key)) + }, + ); + + let mut handler = KeyboardHidHandler::new_for_test(boot_services); + let result = SimpleTextInFfi::install(boot_services, 0x2 as efi::Handle, &mut handler); + assert!(result.is_ok()); + + let key = result.unwrap(); + // SAFETY: Reclaiming the Box leaked by the mock install_protocol_interface. + drop(unsafe { Box::from_raw(key.ptr_value as *mut StiPtr) }); + } + + #[test] + fn install_returns_error_on_create_event_failure() { + let boot_services = mock_boot_services(); + boot_services + .expect_create_event::<*mut StiPtr>() + .times(1) + .returning(|_, _, _, _| Err(efi::Status::OUT_OF_RESOURCES)); + boot_services.expect_install_protocol_interface::>().never(); + + let mut handler = KeyboardHidHandler::new_for_test(boot_services); + let result = SimpleTextInFfi::install(boot_services, 0x2 as efi::Handle, &mut handler); + assert_eq!(result.err(), Some(efi::Status::OUT_OF_RESOURCES)); + } + + #[test] + fn install_cleans_up_event_on_install_protocol_failure() { + let boot_services = mock_boot_services(); + boot_services.expect_create_event::<*mut StiPtr>().times(1).returning(|_, _, _, _| Ok(0x42 as efi::Event)); + boot_services + .expect_install_protocol_interface::>() + .times(1) + .returning(|_, _| Err(efi::Status::ACCESS_DENIED)); + boot_services.expect_close_event().times(1).returning(|_| Ok(())); + + let mut handler = KeyboardHidHandler::new_for_test(boot_services); + let result = SimpleTextInFfi::install(boot_services, 0x2 as efi::Handle, &mut handler); + assert_eq!(result.err(), Some(efi::Status::ACCESS_DENIED)); + } + + #[test] + fn uninstall_succeeds() { + let boot_services = mock_boot_services(); + boot_services + .expect_uninstall_protocol_interface::>() + .times(1) + // SAFETY: Reclaiming the Box from the key, mirroring the real uninstall_protocol_interface. + .returning(|_, key| Ok(unsafe { Box::from_raw(key.ptr_value as *mut StiPtr) })); + boot_services.expect_close_event().times(1).returning(|_| Ok(())); + + let mut handler = KeyboardHidHandler::new_for_test(boot_services); + let key = leaked_context_key(boot_services, &mut handler); + let result = SimpleTextInFfi::uninstall(boot_services, 0x2 as efi::Handle, key); + assert_eq!(result, Ok(())); + } + + #[test] + fn uninstall_failure_nulls_keyboard_handler() { + let boot_services = mock_boot_services(); + boot_services + .expect_uninstall_protocol_interface::>() + .times(1) + .returning(|_, _| Err(efi::Status::ACCESS_DENIED)); + + let mut handler = KeyboardHidHandler::new_for_test(boot_services); + let key = leaked_context_key(boot_services, &mut handler); + let raw_ptr = key.ptr_value as *mut StiPtr; + + let result = SimpleTextInFfi::uninstall(boot_services, 0x2 as efi::Handle, key); + assert_eq!(result, Err(efi::Status::ACCESS_DENIED)); + + // SAFETY: raw_ptr points to the still-live leaked context (uninstall failed). + let ctx = unsafe { raw_ptr.as_ref() }.unwrap(); + assert!(ctx.keyboard_handler.is_null()); + + // SAFETY: Reclaiming the leaked context Box for cleanup. + drop(unsafe { Box::from_raw(raw_ptr) }); + } + + #[test] + fn uninstall_returns_error_on_close_event_failure() { + let boot_services = mock_boot_services(); + boot_services + .expect_uninstall_protocol_interface::>() + .times(1) + // SAFETY: Reclaiming the Box from the key, mirroring the real uninstall_protocol_interface. + .returning(|_, key| Ok(unsafe { Box::from_raw(key.ptr_value as *mut StiPtr) })); + boot_services.expect_close_event().times(1).returning(|_| Err(efi::Status::INVALID_PARAMETER)); + + let mut handler = KeyboardHidHandler::new_for_test(boot_services); + let key = leaked_context_key(boot_services, &mut handler); + let result = SimpleTextInFfi::uninstall(boot_services, 0x2 as efi::Handle, key); + assert_eq!(result, Err(efi::Status::INVALID_PARAMETER)); + } + + // --- FFI callbacks --- + + #[test] + fn reset_returns_invalid_parameter_on_null() { + let status = SimpleTextInFfi::::simple_text_in_reset(ptr::null_mut(), false.into()); + assert_eq!(status, efi::Status::INVALID_PARAMETER); + } + + #[test] + fn reset_returns_device_error_when_handler_null() { + let boot_services = mock_boot_services(); + let mut ctx = SimpleTextInFfi { + simple_text_in: protocols::simple_text_input::Protocol { + reset: SimpleTextInFfi::::simple_text_in_reset, + read_key_stroke: SimpleTextInFfi::::simple_text_in_read_key_stroke, + wait_for_key: ptr::null_mut(), + }, + boot_services, + keyboard_handler: ptr::null_mut(), + }; + let status = + SimpleTextInFfi::::simple_text_in_reset(&mut ctx.simple_text_in as *mut _, false.into()); + assert_eq!(status, efi::Status::DEVICE_ERROR); + } + + #[test] + fn read_key_stroke_returns_not_ready_on_empty_queue() { + let boot_services = mock_boot_services(); + let mut handler = KeyboardHidHandler::new_for_test(boot_services); + let mut ctx = test_context(boot_services, &mut handler); + let mut key = protocols::simple_text_input::InputKey::default(); + let status = SimpleTextInFfi::::simple_text_in_read_key_stroke( + &mut ctx.simple_text_in as *mut _, + &mut key, + ); + assert_eq!(status, efi::Status::NOT_READY); + } + + #[test] + fn read_key_stroke_returns_invalid_parameter_on_null() { + let status = + SimpleTextInFfi::::simple_text_in_read_key_stroke(ptr::null_mut(), ptr::null_mut()); + assert_eq!(status, efi::Status::INVALID_PARAMETER); + } + + #[test] + fn wait_for_key_handles_null_context() { + let event = 0x1234 as efi::Event; + // Should log an error but not panic. + SimpleTextInFfi::::simple_text_in_wait_for_key(event, ptr::null_mut()); + } + + #[test] + fn reset_succeeds_with_valid_handler() { + let boot_services = mock_boot_services(); + let mut handler = KeyboardHidHandler::new_for_test(boot_services); + let mut ctx = test_context(boot_services, &mut handler); + let status = + SimpleTextInFfi::::simple_text_in_reset(&mut ctx.simple_text_in as *mut _, false.into()); + assert_eq!(status, efi::Status::SUCCESS); + } + + #[test] + fn read_key_stroke_returns_success_when_key_available() { + use hidparser::report_data_types::Usage; + + let boot_services = mock_boot_services(); + let mut handler = KeyboardHidHandler::new_for_test(boot_services); + + { + let mut kq = handler.state.lock(); + kq.set_layout(Some(crate::keyboard::layout::get_default_keyboard_layout())); + kq.keystroke(Usage::from(0x00070028u32), super::super::key_queue::KeyAction::KeyDown); + } + + let mut ctx = test_context(boot_services, &mut handler); + let mut key = protocols::simple_text_input::InputKey::default(); + let status = SimpleTextInFfi::::simple_text_in_read_key_stroke( + &mut ctx.simple_text_in as *mut _, + &mut key, + ); + assert_eq!(status, efi::Status::SUCCESS); + assert!(key.scan_code != 0 || key.unicode_char != 0); + } + + #[test] + fn read_key_stroke_returns_device_error_when_handler_null() { + let boot_services = mock_boot_services(); + let mut ctx = SimpleTextInFfi { + simple_text_in: protocols::simple_text_input::Protocol { + reset: SimpleTextInFfi::::simple_text_in_reset, + read_key_stroke: SimpleTextInFfi::::simple_text_in_read_key_stroke, + wait_for_key: ptr::null_mut(), + }, + boot_services, + keyboard_handler: ptr::null_mut(), + }; + let mut key = protocols::simple_text_input::InputKey::default(); + let status = SimpleTextInFfi::::simple_text_in_read_key_stroke( + &mut ctx.simple_text_in as *mut _, + &mut key, + ); + assert_eq!(status, efi::Status::DEVICE_ERROR); + } + + #[test] + fn read_key_stroke_skips_partial_keys() { + use hidparser::report_data_types::Usage; + + let boot_services = mock_boot_services(); + let mut handler = KeyboardHidHandler::new_for_test(boot_services); + + { + let mut kq = handler.state.lock(); + kq.set_layout(Some(crate::keyboard::layout::get_default_keyboard_layout())); + // Enable partial key support so modifier keys are enqueued. + kq.set_key_toggle_state( + protocols::simple_text_input_ex::TOGGLE_STATE_VALID + | protocols::simple_text_input_ex::KEY_STATE_EXPOSED, + ); + // Press left ctrl — produces a partial key (scan=0, unicode=0). + kq.keystroke(Usage::from(0x000700E0u32), super::super::key_queue::KeyAction::KeyDown); + // Press Enter — produces a real key (unicode=0x0D). + kq.keystroke(Usage::from(0x00070028u32), super::super::key_queue::KeyAction::KeyDown); + } + + let mut ctx = test_context(boot_services, &mut handler); + let mut key = protocols::simple_text_input::InputKey::default(); + let status = SimpleTextInFfi::::simple_text_in_read_key_stroke( + &mut ctx.simple_text_in as *mut _, + &mut key, + ); + assert_eq!(status, efi::Status::SUCCESS); + assert!(key.scan_code != 0 || key.unicode_char != 0); + } + + #[test] + fn wait_for_key_signals_event_when_key_available() { + use hidparser::report_data_types::Usage; + + let boot_services = mock_boot_services(); + boot_services.expect_signal_event().times(1).returning(|_| Ok(())); + + let mut handler = KeyboardHidHandler::new_for_test(boot_services); + + { + let mut kq = handler.state.lock(); + kq.set_layout(Some(crate::keyboard::layout::get_default_keyboard_layout())); + kq.keystroke(Usage::from(0x00070028u32), super::super::key_queue::KeyAction::KeyDown); + } + + let mut ctx = test_context(boot_services, &mut handler); + let event = 0x1234 as efi::Event; + SimpleTextInFfi::::simple_text_in_wait_for_key(event, &mut ctx as *mut _); + } + + #[test] + fn wait_for_key_skips_partial_and_signals_for_real_key() { + use hidparser::report_data_types::Usage; + + let boot_services = mock_boot_services(); + boot_services.expect_signal_event().times(1).returning(|_| Ok(())); + + let mut handler = KeyboardHidHandler::new_for_test(boot_services); + + { + let mut kq = handler.state.lock(); + kq.set_layout(Some(crate::keyboard::layout::get_default_keyboard_layout())); + // Enable partial key support so modifier keys are enqueued. + kq.set_key_toggle_state( + protocols::simple_text_input_ex::TOGGLE_STATE_VALID + | protocols::simple_text_input_ex::KEY_STATE_EXPOSED, + ); + // Press left ctrl — produces a partial key (scan=0, unicode=0). + kq.keystroke(Usage::from(0x000700E0u32), super::super::key_queue::KeyAction::KeyDown); + // Press Enter — produces a real key (unicode=0x0D). + kq.keystroke(Usage::from(0x00070028u32), super::super::key_queue::KeyAction::KeyDown); + } + + let mut ctx = test_context(boot_services, &mut handler); + let event = 0x1234 as efi::Event; + SimpleTextInFfi::::simple_text_in_wait_for_key(event, &mut ctx as *mut _); + } +} diff --git a/uefi_hid/src/keyboard/simple_text_in_ex.rs b/uefi_hid/src/keyboard/simple_text_in_ex.rs new file mode 100644 index 0000000..3504d41 --- /dev/null +++ b/uefi_hid/src/keyboard/simple_text_in_ex.rs @@ -0,0 +1,826 @@ +//! Simple Text In Ex Protocol FFI Support. +//! +//! This module manages the Simple Text In Ex FFI. +//! +//! ## License +//! +//! Copyright (c) Microsoft Corporation. +//! +//! SPDX-License-Identifier: Apache-2.0 +//! + +use alloc::boxed::Box; +use core::{ffi::c_void, ptr}; + +use r_efi::{efi, protocols}; + +use patina::{ + boot_services::{ + BootServices, + c_ptr::PtrMetadata, + event::{EventNotifyCallback, EventType}, + tpl::Tpl, + }, + uefi_protocol::ProtocolInterface, +}; + +use super::KeyboardHidHandler; + +/// FFI context for SimpleTextInputEx protocol. +/// +/// A pointer to KeyboardHidHandler is included in the context so that it can be reclaimed in the simple_text_in_ex API +/// implementations. Mutual exclusion on the KeyboardHidHandler is provided by the TplMutex on its key_queue state. +/// +/// The simple_text_in_ex protocol element must be the first element in the structure so that the full structure can be +/// recovered by simple casting. +#[repr(C)] +pub(crate) struct SimpleTextInExFfi { + simple_text_in_ex: protocols::simple_text_input_ex::Protocol, + pub(crate) boot_services: &'static T, + pub(crate) keyboard_handler: *mut KeyboardHidHandler, +} + +// SAFETY: SimpleTextInExFfi is #[repr(C)] with protocols::simple_text_input_ex::Protocol as its +// first field, so a pointer to SimpleTextInExFfi is a valid pointer to Protocol per the +// first-field casting pattern. +unsafe impl ProtocolInterface for SimpleTextInExFfi { + const PROTOCOL_GUID: patina::BinaryGuid = patina::BinaryGuid(protocols::simple_text_input_ex::PROTOCOL_GUID); +} + +impl SimpleTextInExFfi { + /// Installs the simple text in ex protocol. Returns the key required to uninstall. + pub(crate) fn install( + boot_services: &'static T, + controller: efi::Handle, + keyboard_handler: &mut KeyboardHidHandler, + ) -> Result>, efi::Status> { + let mut ctx = Box::new(SimpleTextInExFfi { + simple_text_in_ex: protocols::simple_text_input_ex::Protocol { + reset: Self::simple_text_in_ex_reset, + read_key_stroke_ex: Self::simple_text_in_ex_read_key_stroke, + wait_for_key_ex: ptr::null_mut(), + set_state: Self::simple_text_in_ex_set_state, + register_key_notify: Self::simple_text_in_ex_register_key_notify, + unregister_key_notify: Self::simple_text_in_ex_unregister_key_notify, + }, + boot_services, + keyboard_handler: keyboard_handler as *mut KeyboardHidHandler, + }); + + let ctx_ptr: *mut Self = &mut *ctx; + + let wait_for_key_event = boot_services.create_event( + EventType::NOTIFY_WAIT, + Tpl::NOTIFY, + Some(Self::simple_text_in_ex_wait_for_key as EventNotifyCallback<*mut Self>), + ctx_ptr, + )?; + + ctx.simple_text_in_ex.wait_for_key_ex = wait_for_key_event; + + // Key notifies dispatch at TPL_CALLBACK per UEFI spec 2.10 section 12.2.5. + let key_notify_event = match boot_services.create_event( + EventType::NOTIFY_SIGNAL, + Tpl::CALLBACK, + Some(Self::process_key_notifies as EventNotifyCallback<*mut Self>), + ctx_ptr, + ) { + Ok(event) => event, + Err(status) => { + let _ = boot_services.close_event(wait_for_key_event); + return Err(status); + } + }; + + match boot_services.install_protocol_interface(Some(controller), ctx) { + Ok((_handle, key)) => { + keyboard_handler.key_notify_event = key_notify_event; + Ok(key) + } + Err(status) => { + // install_protocol_interface reconstructs and drops the Box on failure, + // but doesn't know about our events. + let _ = boot_services.close_event(wait_for_key_event); + let _ = boot_services.close_event(key_notify_event); + Err(status) + } + } + } + + /// Uninstalls the simple text in ex protocol using the key from install. + pub(crate) fn uninstall( + boot_services: &'static T, + controller: efi::Handle, + key: PtrMetadata<'static, Box>, + ) -> Result<(), efi::Status> { + // Save the raw pointer before consuming key, in case uninstall fails. + let raw_ptr = key.ptr_value as *mut SimpleTextInExFfi; + + let ctx = match boot_services.uninstall_protocol_interface(controller, key) { + Ok(ctx) => ctx, + Err(status) => { + log::error!("Failed to uninstall simple_text_in_ex interface, status: {:x?}", status); + // Protocol is still installed. Null keyboard_handler so callbacks don't access a + // dropped KeyboardHidHandler. + // SAFETY: raw_ptr was saved before consuming key; protocol is still installed so memory is valid. + unsafe { + if let Some(ctx) = raw_ptr.as_mut() { + ctx.keyboard_handler = ptr::null_mut(); + } + } + return Err(status); + } + }; + + if let Err(status) = boot_services.close_event(ctx.simple_text_in_ex.wait_for_key_ex) { + log::error!("Failed to close simple_text_in_ex.wait_for_key_ex event, status: {:x?}", status); + // Leak ctx so the still-live event callback doesn't use freed memory. + core::mem::forget(ctx); + return Err(status); + } + + // Close key_notify_event stored on keyboard_handler. + // SAFETY: dereferencing keyboard_handler pointer; valid while ctx exists because uninstall succeeded. + if let Some(keyboard_handler) = unsafe { ctx.keyboard_handler.as_mut() } { + let key_notify_event = keyboard_handler.key_notify_event; + keyboard_handler.key_notify_event = ptr::null_mut(); + if !key_notify_event.is_null() + && let Err(status) = boot_services.close_event(key_notify_event) + { + log::error!("Failed to close key_notify_event, status: {:x?}", status); + // Leak ctx so the still-live event callback doesn't use freed memory. + core::mem::forget(ctx); + return Err(status); + } + } + // ctx drops here, freeing the SimpleTextInExFfi allocation. + Ok(()) + } + + // Resets the keyboard state — part of the simple_text_in_ex protocol interface. + extern "efiapi" fn simple_text_in_ex_reset( + this: *mut protocols::simple_text_input_ex::Protocol, + extended_verification: efi::Boolean, + ) -> efi::Status { + // SAFETY: casting `this` (first field of #[repr(C)] struct) to recover full context; null handled below. + let context = unsafe { (this as *mut Self).as_mut() }; + let Some(context) = context else { return efi::Status::INVALID_PARAMETER }; + // SAFETY: dereferencing keyboard_handler raw pointer; validity ensured by protocol lifecycle. + let keyboard_handler = unsafe { context.keyboard_handler.as_mut() }; + let Some(keyboard_handler) = keyboard_handler else { + log::error!("simple_text_in_ex_reset invoked after keyboard dropped."); + return efi::Status::DEVICE_ERROR; + }; + keyboard_handler.reset(extended_verification.into()); + log::trace!("simple_text_in_ex_reset: extended_verification={:?}", bool::from(extended_verification)); + efi::Status::SUCCESS + } + + // Reads a key stroke — part of the simple_text_in_ex protocol interface. + extern "efiapi" fn simple_text_in_ex_read_key_stroke( + this: *mut protocols::simple_text_input_ex::Protocol, + key_data: *mut protocols::simple_text_input_ex::KeyData, + ) -> efi::Status { + if this.is_null() || key_data.is_null() { + return efi::Status::INVALID_PARAMETER; + } + // SAFETY: casting `this` (first field of #[repr(C)] struct) to recover full context; null handled below. + let Some(context) = (unsafe { (this as *mut Self).as_mut() }) else { + return efi::Status::INVALID_PARAMETER; + }; + // SAFETY: dereferencing keyboard_handler raw pointer; validity ensured by protocol lifecycle. + let keyboard_handler = unsafe { context.keyboard_handler.as_mut() }; + let Some(keyboard_handler) = keyboard_handler else { + log::error!("simple_text_in_ex_read_key_stroke invoked after keyboard dropped."); + return efi::Status::DEVICE_ERROR; + }; + + let mut kq = keyboard_handler.state.lock(); + if let Some(key) = kq.pop_key() { + log::trace!( + "simple_text_in_ex_read_key_stroke: unicode=0x{:04X} scan=0x{:04X}", + key.key.unicode_char, + key.key.scan_code, + ); + // SAFETY: writing through output pointer; null-checked at top of function. + unsafe { key_data.write_unaligned(key) } + efi::Status::SUCCESS + } else { + let key = protocols::simple_text_input_ex::KeyData { key_state: kq.init_key_state(), ..Default::default() }; + // SAFETY: writing through output pointer; null-checked at top of function. + unsafe { key_data.write_unaligned(key) }; + efi::Status::NOT_READY + } + } + + // Sets the keyboard toggle state — part of the simple_text_in_ex protocol interface. + extern "efiapi" fn simple_text_in_ex_set_state( + this: *mut protocols::simple_text_input_ex::Protocol, + key_toggle_state: *mut protocols::simple_text_input_ex::KeyToggleState, + ) -> efi::Status { + if this.is_null() || key_toggle_state.is_null() { + return efi::Status::INVALID_PARAMETER; + } + // SAFETY: casting `this` (first field of #[repr(C)] struct) to recover full context; null handled below. + let Some(context) = (unsafe { (this as *mut Self).as_mut() }) else { + return efi::Status::INVALID_PARAMETER; + }; + // SAFETY: dereferencing keyboard_handler raw pointer; validity ensured by protocol lifecycle. + let keyboard_handler = unsafe { context.keyboard_handler.as_mut() }; + let Some(keyboard_handler) = keyboard_handler else { + log::error!("simple_text_in_ex_set_state invoked after keyboard dropped."); + return efi::Status::DEVICE_ERROR; + }; + // SAFETY: reading through pointer; null-checked at top of function. + keyboard_handler.set_key_toggle_state(unsafe { key_toggle_state.read() }); + log::trace!("simple_text_in_ex_set_state: toggle state set"); + efi::Status::SUCCESS + } + + // Registers a key notification callback — part of the simple_text_in_ex protocol interface. + extern "efiapi" fn simple_text_in_ex_register_key_notify( + this: *mut protocols::simple_text_input_ex::Protocol, + key_data_ptr: *mut protocols::simple_text_input_ex::KeyData, + key_notification_function: protocols::simple_text_input_ex::KeyNotifyFunction, + notify_handle: *mut *mut c_void, + ) -> efi::Status { + if this.is_null() + || key_data_ptr.is_null() + || notify_handle.is_null() + || key_notification_function as usize == 0 + { + return efi::Status::INVALID_PARAMETER; + } + + // SAFETY: casting `this` (first field of #[repr(C)] struct) to recover full context; null handled below. + let Some(context) = (unsafe { (this as *mut Self).as_mut() }) else { + return efi::Status::INVALID_PARAMETER; + }; + // SAFETY: dereferencing keyboard_handler raw pointer; validity ensured by protocol lifecycle. + let keyboard_handler = unsafe { context.keyboard_handler.as_mut() }; + let Some(keyboard_handler) = keyboard_handler else { + log::error!("simple_text_in_ex_register_key_notify invoked after keyboard dropped."); + return efi::Status::DEVICE_ERROR; + }; + + // SAFETY: reading through pointer; null-checked at top of function. + let key_data = unsafe { key_data_ptr.read() }; + let handle = keyboard_handler.insert_key_notify_callback(key_data, key_notification_function); + log::trace!("simple_text_in_ex_register_key_notify: handle=0x{:X}", handle); + // SAFETY: writing through output pointer; null-checked at top of function. + unsafe { notify_handle.write(handle as *mut c_void) }; + efi::Status::SUCCESS + } + + // Unregisters a key notification callback — part of the simple_text_in_ex protocol interface. + extern "efiapi" fn simple_text_in_ex_unregister_key_notify( + this: *mut protocols::simple_text_input_ex::Protocol, + notification_handle: *mut c_void, + ) -> efi::Status { + if this.is_null() { + return efi::Status::INVALID_PARAMETER; + } + // SAFETY: casting `this` (first field of #[repr(C)] struct) to recover full context; null handled below. + let Some(context) = (unsafe { (this as *mut Self).as_mut() }) else { + return efi::Status::INVALID_PARAMETER; + }; + // SAFETY: dereferencing keyboard_handler raw pointer; validity ensured by protocol lifecycle. + let keyboard_handler = unsafe { context.keyboard_handler.as_mut() }; + let Some(keyboard_handler) = keyboard_handler else { + log::error!("simple_text_in_ex_unregister_key_notify invoked after keyboard dropped."); + return efi::Status::DEVICE_ERROR; + }; + match keyboard_handler.remove_key_notify_callback(notification_handle as usize) { + Ok(()) => { + log::trace!("simple_text_in_ex_unregister_key_notify: handle=0x{:X}", notification_handle as usize); + efi::Status::SUCCESS + } + Err(status) => status, + } + } + + // Handles the wait_for_key_ex event — part of the simple_text_in_ex protocol interface. + extern "efiapi" fn simple_text_in_ex_wait_for_key(event: efi::Event, context: *mut Self) { + // SAFETY: dereferencing context pointer from UEFI event; null handled by else branch. + let Some(context) = (unsafe { context.as_mut() }) else { + log::error!("simple_text_in_ex_wait_for_key invoked with invalid context"); + return; + }; + // SAFETY: dereferencing keyboard_handler raw pointer; validity ensured by protocol lifecycle. + let keyboard_handler = unsafe { context.keyboard_handler.as_mut() }; + let Some(keyboard_handler) = keyboard_handler else { return }; + let mut kq = keyboard_handler.state.lock(); + while let Some(key_data) = kq.peek_key() { + if key_data.key.unicode_char == 0 && key_data.key.scan_code == 0 { + let _ = kq.pop_key(); + continue; + } else { + let _ = context.boot_services.signal_event(event); + break; + } + } + } + + // Dispatches registered key notification callbacks — part of the simple_text_in_ex protocol interface. + extern "efiapi" fn process_key_notifies(_event: efi::Event, context: *mut Self) { + // SAFETY: dereferencing context pointer from UEFI event; null handled by else branch. + let Some(context) = (unsafe { context.as_mut() }) else { + return; + }; + loop { + // SAFETY: dereferencing keyboard_handler raw pointer; validity ensured by protocol lifecycle. + let keyboard_handler = unsafe { context.keyboard_handler.as_mut() }; + let Some(keyboard_handler) = keyboard_handler else { + log::error!("process_key_notifies event called without a valid keyboard_handler"); + return; + }; + let (pending_key, pending_callbacks) = keyboard_handler.pending_callbacks(); + if let Some(mut pending_key) = pending_key { + let key_ptr = &mut pending_key as *mut protocols::simple_text_input_ex::KeyData; + for callback in pending_callbacks { + let _ = callback(key_ptr); + } + } else { + break; + } + } + } +} + +#[cfg(test)] +mod test { + use core::ptr; + + use alloc::boxed::Box; + use r_efi::{efi, protocols}; + + use patina::boot_services::{ + MockBootServices, + c_ptr::{CPtr, PtrMetadata}, + }; + + use super::*; + + fn mock_boot_services() -> &'static mut MockBootServices { + let mut mock = MockBootServices::new(); + mock.expect_raise_tpl().returning(|_| patina::boot_services::tpl::Tpl::APPLICATION); + mock.expect_restore_tpl().returning(|_| ()); + // SAFETY: Leaked to obtain 'static lifetime for test use; never freed. + unsafe { Box::into_raw(Box::new(mock)).as_mut().unwrap() } + } + + type StiExPtr = SimpleTextInExFfi; + + fn test_context( + boot_services: &'static MockBootServices, + handler: &mut KeyboardHidHandler, + ) -> SimpleTextInExFfi { + SimpleTextInExFfi { + simple_text_in_ex: protocols::simple_text_input_ex::Protocol { + reset: SimpleTextInExFfi::::simple_text_in_ex_reset, + read_key_stroke_ex: SimpleTextInExFfi::::simple_text_in_ex_read_key_stroke, + wait_for_key_ex: ptr::null_mut(), + set_state: SimpleTextInExFfi::::simple_text_in_ex_set_state, + register_key_notify: SimpleTextInExFfi::::simple_text_in_ex_register_key_notify, + unregister_key_notify: SimpleTextInExFfi::::simple_text_in_ex_unregister_key_notify, + }, + boot_services, + keyboard_handler: handler as *mut KeyboardHidHandler, + } + } + + fn leaked_context_key( + boot_services: &'static MockBootServices, + handler: &mut KeyboardHidHandler, + ) -> PtrMetadata<'static, Box> { + let ctx = Box::new(SimpleTextInExFfi { + simple_text_in_ex: protocols::simple_text_input_ex::Protocol { + reset: SimpleTextInExFfi::::simple_text_in_ex_reset, + read_key_stroke_ex: SimpleTextInExFfi::::simple_text_in_ex_read_key_stroke, + wait_for_key_ex: 0x42 as efi::Event, + set_state: SimpleTextInExFfi::::simple_text_in_ex_set_state, + register_key_notify: SimpleTextInExFfi::::simple_text_in_ex_register_key_notify, + unregister_key_notify: SimpleTextInExFfi::::simple_text_in_ex_unregister_key_notify, + }, + boot_services, + keyboard_handler: handler as *mut _, + }); + let key = ctx.metadata(); + let _ = Box::into_raw(ctx); + key + } + + // --- install/uninstall --- + + #[test] + fn install_succeeds() { + let boot_services = mock_boot_services(); + // Two create_event calls: wait_for_key and key_notify + boot_services.expect_create_event::<*mut StiExPtr>().times(2).returning(|_, _, _, _| Ok(0x42 as efi::Event)); + boot_services.expect_install_protocol_interface::>().times(1).returning( + |_, protocol_interface| { + let key = protocol_interface.metadata(); + let _ = Box::into_raw(protocol_interface); + Ok((0x1 as efi::Handle, key)) + }, + ); + + let mut handler = KeyboardHidHandler::new_for_test(boot_services); + let result = SimpleTextInExFfi::install(boot_services, 0x2 as efi::Handle, &mut handler); + assert!(result.is_ok()); + assert!(!handler.key_notify_event.is_null()); + + let key = result.unwrap(); + // SAFETY: Reclaiming the Box leaked by the mock install_protocol_interface. + drop(unsafe { Box::from_raw(key.ptr_value as *mut StiExPtr) }); + } + + #[test] + fn install_returns_error_on_first_create_event_failure() { + let boot_services = mock_boot_services(); + boot_services + .expect_create_event::<*mut StiExPtr>() + .times(1) + .returning(|_, _, _, _| Err(efi::Status::OUT_OF_RESOURCES)); + boot_services.expect_install_protocol_interface::>().never(); + + let mut handler = KeyboardHidHandler::new_for_test(boot_services); + let result = SimpleTextInExFfi::install(boot_services, 0x2 as efi::Handle, &mut handler); + assert_eq!(result.err(), Some(efi::Status::OUT_OF_RESOURCES)); + } + + #[test] + fn install_cleans_up_events_on_install_protocol_failure() { + let boot_services = mock_boot_services(); + boot_services.expect_create_event::<*mut StiExPtr>().times(2).returning(|_, _, _, _| Ok(0x42 as efi::Event)); + boot_services + .expect_install_protocol_interface::>() + .times(1) + .returning(|_, _| Err(efi::Status::ACCESS_DENIED)); + // Both events cleaned up on failure. + boot_services.expect_close_event().times(2).returning(|_| Ok(())); + + let mut handler = KeyboardHidHandler::new_for_test(boot_services); + let result = SimpleTextInExFfi::install(boot_services, 0x2 as efi::Handle, &mut handler); + assert_eq!(result.err(), Some(efi::Status::ACCESS_DENIED)); + } + + #[test] + fn uninstall_succeeds() { + let boot_services = mock_boot_services(); + boot_services + .expect_uninstall_protocol_interface::>() + .times(1) + // SAFETY: Reclaiming the Box from the key, mirroring the real uninstall_protocol_interface. + .returning(|_, key| Ok(unsafe { Box::from_raw(key.ptr_value as *mut StiExPtr) })); + // Two close_event calls: wait_for_key_ex and key_notify_event. + boot_services.expect_close_event().times(2).returning(|_| Ok(())); + + let mut handler = KeyboardHidHandler::new_for_test(boot_services); + handler.key_notify_event = 0x99 as efi::Event; + let key = leaked_context_key(boot_services, &mut handler); + let result = SimpleTextInExFfi::uninstall(boot_services, 0x2 as efi::Handle, key); + assert_eq!(result, Ok(())); + assert!(handler.key_notify_event.is_null()); + } + + #[test] + fn uninstall_failure_nulls_keyboard_handler() { + let boot_services = mock_boot_services(); + boot_services + .expect_uninstall_protocol_interface::>() + .times(1) + .returning(|_, _| Err(efi::Status::ACCESS_DENIED)); + + let mut handler = KeyboardHidHandler::new_for_test(boot_services); + let key = leaked_context_key(boot_services, &mut handler); + let raw_ptr = key.ptr_value as *mut StiExPtr; + + let result = SimpleTextInExFfi::uninstall(boot_services, 0x2 as efi::Handle, key); + assert_eq!(result, Err(efi::Status::ACCESS_DENIED)); + + // SAFETY: raw_ptr points to the still-live leaked context (uninstall failed). + let ctx = unsafe { raw_ptr.as_ref() }.unwrap(); + assert!(ctx.keyboard_handler.is_null()); + + // SAFETY: Reclaiming the leaked context Box for cleanup. + drop(unsafe { Box::from_raw(raw_ptr) }); + } + + // --- FFI callbacks --- + + #[test] + fn reset_returns_invalid_parameter_on_null() { + let status = SimpleTextInExFfi::::simple_text_in_ex_reset(ptr::null_mut(), false.into()); + assert_eq!(status, efi::Status::INVALID_PARAMETER); + } + + #[test] + fn reset_returns_device_error_when_handler_null() { + let boot_services = mock_boot_services(); + let mut ctx = SimpleTextInExFfi { + simple_text_in_ex: protocols::simple_text_input_ex::Protocol { + reset: SimpleTextInExFfi::::simple_text_in_ex_reset, + read_key_stroke_ex: SimpleTextInExFfi::::simple_text_in_ex_read_key_stroke, + wait_for_key_ex: ptr::null_mut(), + set_state: SimpleTextInExFfi::::simple_text_in_ex_set_state, + register_key_notify: SimpleTextInExFfi::::simple_text_in_ex_register_key_notify, + unregister_key_notify: SimpleTextInExFfi::::simple_text_in_ex_unregister_key_notify, + }, + boot_services, + keyboard_handler: ptr::null_mut(), + }; + let status = SimpleTextInExFfi::::simple_text_in_ex_reset( + &mut ctx.simple_text_in_ex as *mut _, + false.into(), + ); + assert_eq!(status, efi::Status::DEVICE_ERROR); + } + + #[test] + fn read_key_stroke_returns_not_ready_on_empty_queue() { + let boot_services = mock_boot_services(); + let mut handler = KeyboardHidHandler::new_for_test(boot_services); + let mut ctx = test_context(boot_services, &mut handler); + let mut key_data = protocols::simple_text_input_ex::KeyData::default(); + let status = SimpleTextInExFfi::::simple_text_in_ex_read_key_stroke( + &mut ctx.simple_text_in_ex as *mut _, + &mut key_data, + ); + assert_eq!(status, efi::Status::NOT_READY); + } + + #[test] + fn read_key_stroke_returns_invalid_parameter_on_null() { + let status = + SimpleTextInExFfi::::simple_text_in_ex_read_key_stroke(ptr::null_mut(), ptr::null_mut()); + assert_eq!(status, efi::Status::INVALID_PARAMETER); + } + + #[test] + fn set_state_returns_invalid_parameter_on_null() { + let status = + SimpleTextInExFfi::::simple_text_in_ex_set_state(ptr::null_mut(), ptr::null_mut()); + assert_eq!(status, efi::Status::INVALID_PARAMETER); + } + + #[test] + fn register_key_notify_returns_invalid_parameter_on_null_this() { + // All parameters must be non-null. Testing with null `this` is sufficient + // since the first check short-circuits. + extern "efiapi" fn noop(_key: *mut protocols::simple_text_input_ex::KeyData) -> efi::Status { + efi::Status::SUCCESS + } + let mut key_data = protocols::simple_text_input_ex::KeyData::default(); + let mut handle: *mut c_void = ptr::null_mut(); + let status = SimpleTextInExFfi::::simple_text_in_ex_register_key_notify( + ptr::null_mut(), + &mut key_data, + noop, + &mut handle, + ); + assert_eq!(status, efi::Status::INVALID_PARAMETER); + } + + #[test] + fn unregister_key_notify_returns_invalid_parameter_on_null() { + let status = SimpleTextInExFfi::::simple_text_in_ex_unregister_key_notify( + ptr::null_mut(), + ptr::null_mut(), + ); + assert_eq!(status, efi::Status::INVALID_PARAMETER); + } + + #[test] + fn wait_for_key_handles_null_context() { + let event = 0x1234 as efi::Event; + // Should log an error but not panic. + SimpleTextInExFfi::::simple_text_in_ex_wait_for_key(event, ptr::null_mut()); + } + + #[test] + fn reset_succeeds_with_valid_handler() { + let boot_services = mock_boot_services(); + let mut handler = KeyboardHidHandler::new_for_test(boot_services); + let mut ctx = test_context(boot_services, &mut handler); + let status = SimpleTextInExFfi::::simple_text_in_ex_reset( + &mut ctx.simple_text_in_ex as *mut _, + false.into(), + ); + assert_eq!(status, efi::Status::SUCCESS); + } + + #[test] + fn read_key_stroke_returns_success_when_key_available() { + use hidparser::report_data_types::Usage; + + let boot_services = mock_boot_services(); + let mut handler = KeyboardHidHandler::new_for_test(boot_services); + handler.set_layout(Some(crate::keyboard::layout::get_default_keyboard_layout())); + // Push an Enter keystroke (usage 0x00070028). + handler.state.lock().keystroke(Usage::from(0x00070028u32), super::super::key_queue::KeyAction::KeyDown); + let mut ctx = test_context(boot_services, &mut handler); + let mut key_data = protocols::simple_text_input_ex::KeyData::default(); + let status = SimpleTextInExFfi::::simple_text_in_ex_read_key_stroke( + &mut ctx.simple_text_in_ex as *mut _, + &mut key_data, + ); + assert_eq!(status, efi::Status::SUCCESS); + assert!(key_data.key.unicode_char != 0 || key_data.key.scan_code != 0); + } + + #[test] + fn read_key_stroke_returns_device_error_when_handler_null() { + let boot_services = mock_boot_services(); + let mut ctx = SimpleTextInExFfi:: { + simple_text_in_ex: protocols::simple_text_input_ex::Protocol { + reset: SimpleTextInExFfi::::simple_text_in_ex_reset, + read_key_stroke_ex: SimpleTextInExFfi::::simple_text_in_ex_read_key_stroke, + wait_for_key_ex: ptr::null_mut(), + set_state: SimpleTextInExFfi::::simple_text_in_ex_set_state, + register_key_notify: SimpleTextInExFfi::::simple_text_in_ex_register_key_notify, + unregister_key_notify: SimpleTextInExFfi::::simple_text_in_ex_unregister_key_notify, + }, + boot_services, + keyboard_handler: ptr::null_mut(), + }; + let mut key_data = protocols::simple_text_input_ex::KeyData::default(); + let status = SimpleTextInExFfi::::simple_text_in_ex_read_key_stroke( + &mut ctx.simple_text_in_ex as *mut _, + &mut key_data, + ); + assert_eq!(status, efi::Status::DEVICE_ERROR); + } + + #[test] + fn set_state_succeeds_with_valid_handler() { + let boot_services = mock_boot_services(); + let mut handler = KeyboardHidHandler::new_for_test(boot_services); + let mut ctx = test_context(boot_services, &mut handler); + let mut toggle_state: protocols::simple_text_input_ex::KeyToggleState = + protocols::simple_text_input_ex::TOGGLE_STATE_VALID; + let status = SimpleTextInExFfi::::simple_text_in_ex_set_state( + &mut ctx.simple_text_in_ex as *mut _, + &mut toggle_state, + ); + assert_eq!(status, efi::Status::SUCCESS); + } + + #[test] + fn set_state_returns_device_error_when_handler_null() { + let boot_services = mock_boot_services(); + let mut ctx = SimpleTextInExFfi:: { + simple_text_in_ex: protocols::simple_text_input_ex::Protocol { + reset: SimpleTextInExFfi::::simple_text_in_ex_reset, + read_key_stroke_ex: SimpleTextInExFfi::::simple_text_in_ex_read_key_stroke, + wait_for_key_ex: ptr::null_mut(), + set_state: SimpleTextInExFfi::::simple_text_in_ex_set_state, + register_key_notify: SimpleTextInExFfi::::simple_text_in_ex_register_key_notify, + unregister_key_notify: SimpleTextInExFfi::::simple_text_in_ex_unregister_key_notify, + }, + boot_services, + keyboard_handler: ptr::null_mut(), + }; + let mut toggle_state: protocols::simple_text_input_ex::KeyToggleState = + protocols::simple_text_input_ex::TOGGLE_STATE_VALID; + let status = SimpleTextInExFfi::::simple_text_in_ex_set_state( + &mut ctx.simple_text_in_ex as *mut _, + &mut toggle_state, + ); + assert_eq!(status, efi::Status::DEVICE_ERROR); + } + + #[test] + fn register_key_notify_succeeds() { + extern "efiapi" fn noop(_key: *mut protocols::simple_text_input_ex::KeyData) -> efi::Status { + efi::Status::SUCCESS + } + + let boot_services = mock_boot_services(); + let mut handler = KeyboardHidHandler::new_for_test(boot_services); + let mut ctx = test_context(boot_services, &mut handler); + let mut key_data = protocols::simple_text_input_ex::KeyData::default(); + let mut handle: *mut c_void = ptr::null_mut(); + let status = SimpleTextInExFfi::::simple_text_in_ex_register_key_notify( + &mut ctx.simple_text_in_ex as *mut _, + &mut key_data, + noop, + &mut handle, + ); + assert_eq!(status, efi::Status::SUCCESS); + assert!(!handle.is_null()); + } + + #[test] + fn register_key_notify_returns_device_error_when_handler_null() { + extern "efiapi" fn noop(_key: *mut protocols::simple_text_input_ex::KeyData) -> efi::Status { + efi::Status::SUCCESS + } + + let boot_services = mock_boot_services(); + let mut ctx = SimpleTextInExFfi:: { + simple_text_in_ex: protocols::simple_text_input_ex::Protocol { + reset: SimpleTextInExFfi::::simple_text_in_ex_reset, + read_key_stroke_ex: SimpleTextInExFfi::::simple_text_in_ex_read_key_stroke, + wait_for_key_ex: ptr::null_mut(), + set_state: SimpleTextInExFfi::::simple_text_in_ex_set_state, + register_key_notify: SimpleTextInExFfi::::simple_text_in_ex_register_key_notify, + unregister_key_notify: SimpleTextInExFfi::::simple_text_in_ex_unregister_key_notify, + }, + boot_services, + keyboard_handler: ptr::null_mut(), + }; + let mut key_data = protocols::simple_text_input_ex::KeyData::default(); + let mut handle: *mut c_void = ptr::null_mut(); + let status = SimpleTextInExFfi::::simple_text_in_ex_register_key_notify( + &mut ctx.simple_text_in_ex as *mut _, + &mut key_data, + noop, + &mut handle, + ); + assert_eq!(status, efi::Status::DEVICE_ERROR); + } + + #[test] + fn unregister_key_notify_returns_device_error_when_handler_null() { + let boot_services = mock_boot_services(); + let mut ctx = SimpleTextInExFfi:: { + simple_text_in_ex: protocols::simple_text_input_ex::Protocol { + reset: SimpleTextInExFfi::::simple_text_in_ex_reset, + read_key_stroke_ex: SimpleTextInExFfi::::simple_text_in_ex_read_key_stroke, + wait_for_key_ex: ptr::null_mut(), + set_state: SimpleTextInExFfi::::simple_text_in_ex_set_state, + register_key_notify: SimpleTextInExFfi::::simple_text_in_ex_register_key_notify, + unregister_key_notify: SimpleTextInExFfi::::simple_text_in_ex_unregister_key_notify, + }, + boot_services, + keyboard_handler: ptr::null_mut(), + }; + let status = SimpleTextInExFfi::::simple_text_in_ex_unregister_key_notify( + &mut ctx.simple_text_in_ex as *mut _, + core::ptr::dangling_mut::(), + ); + assert_eq!(status, efi::Status::DEVICE_ERROR); + } + + #[test] + fn unregister_key_notify_succeeds_after_register() { + extern "efiapi" fn noop(_key: *mut protocols::simple_text_input_ex::KeyData) -> efi::Status { + efi::Status::SUCCESS + } + + let boot_services = mock_boot_services(); + let mut handler = KeyboardHidHandler::new_for_test(boot_services); + let mut ctx = test_context(boot_services, &mut handler); + let mut key_data = protocols::simple_text_input_ex::KeyData::default(); + let mut handle: *mut c_void = ptr::null_mut(); + let status = SimpleTextInExFfi::::simple_text_in_ex_register_key_notify( + &mut ctx.simple_text_in_ex as *mut _, + &mut key_data, + noop, + &mut handle, + ); + assert_eq!(status, efi::Status::SUCCESS); + + let status = SimpleTextInExFfi::::simple_text_in_ex_unregister_key_notify( + &mut ctx.simple_text_in_ex as *mut _, + handle, + ); + assert_eq!(status, efi::Status::SUCCESS); + } + + #[test] + fn wait_for_key_signals_event_when_key_available() { + use hidparser::report_data_types::Usage; + + let boot_services = mock_boot_services(); + boot_services.expect_signal_event().times(1).returning(|_| Ok(())); + + let mut handler = KeyboardHidHandler::new_for_test(boot_services); + handler.set_layout(Some(crate::keyboard::layout::get_default_keyboard_layout())); + handler.state.lock().keystroke(Usage::from(0x00070028u32), super::super::key_queue::KeyAction::KeyDown); + let mut ctx = test_context(boot_services, &mut handler); + let event = 0x1234 as efi::Event; + SimpleTextInExFfi::::simple_text_in_ex_wait_for_key(event, &mut ctx); + } + + #[test] + fn process_key_notifies_handles_null_handler() { + let boot_services = mock_boot_services(); + let mut ctx = SimpleTextInExFfi:: { + simple_text_in_ex: protocols::simple_text_input_ex::Protocol { + reset: SimpleTextInExFfi::::simple_text_in_ex_reset, + read_key_stroke_ex: SimpleTextInExFfi::::simple_text_in_ex_read_key_stroke, + wait_for_key_ex: ptr::null_mut(), + set_state: SimpleTextInExFfi::::simple_text_in_ex_set_state, + register_key_notify: SimpleTextInExFfi::::simple_text_in_ex_register_key_notify, + unregister_key_notify: SimpleTextInExFfi::::simple_text_in_ex_unregister_key_notify, + }, + boot_services, + keyboard_handler: ptr::null_mut(), + }; + let event = 0x1234 as efi::Event; + // Should return without panic when keyboard_handler is null. + SimpleTextInExFfi::::process_key_notifies(event, &mut ctx); + } +} diff --git a/uefi_hid/src/lib.rs b/uefi_hid/src/lib.rs new file mode 100644 index 0000000..bf65734 --- /dev/null +++ b/uefi_hid/src/lib.rs @@ -0,0 +1,180 @@ +//! UEFI HID - Human Interface Device support as a Patina component. +//! +//! This crate provides a Patina component that consumes the HidIo protocol and +//! produces UEFI input protocols (SimpleTextInput, SimpleTextInputEx, +//! AbsolutePointer) for keyboard and pointer HID devices. +//! +//! ## License +//! +//! Copyright (c) Microsoft Corporation. +//! +//! SPDX-License-Identifier: Apache-2.0 +//! +#![cfg_attr(not(test), no_std)] +#![feature(coverage_attribute)] + +extern crate alloc; + +pub mod hid; +pub mod hid_io; +pub mod keyboard; +pub mod pointer; + +#[cfg(test)] +pub(crate) mod test_stubs; + +use alloc::boxed::Box; + +use r_efi::efi; + +use patina::{ + BinaryGuid, + boot_services::{BootServices, StandardBootServices}, + component::{component, params}, + driver_binding::UefiDriverBinding, + error::Result, + uefi_protocol::ProtocolInterface, +}; + +#[cfg(feature = "ctrl-alt-del")] +use core::sync::atomic::AtomicPtr; + +#[cfg(feature = "ctrl-alt-del")] +use patina::runtime_services::StandardRuntimeServices; + +/// Global pointer to UEFI Runtime Services, used by the Ctrl-Alt-Del reset callback. +#[cfg(feature = "ctrl-alt-del")] +pub static RUNTIME_SERVICES: AtomicPtr = AtomicPtr::new(core::ptr::null_mut()); + +/// Zero-sized marker protocol used to create a dedicated driver binding handle. +#[repr(C)] +struct UefiHidMarker; + +// SAFETY: UefiHidMarker is a ZST whose GUID uniquely identifies this component. +unsafe impl ProtocolInterface for UefiHidMarker { + const PROTOCOL_GUID: BinaryGuid = BinaryGuid::from_string("122ffcfd-f8f8-46d6-81de-333e2419ebcb"); +} + +/// UEFI HID Patina component. +/// +/// When dispatched, installs a UEFI Driver Binding that consumes HidIo +/// protocol instances and produces keyboard and pointer input protocols. +pub struct UefiHidComponent; + +#[component] +impl UefiHidComponent { + #[cfg(feature = "ctrl-alt-del")] + fn entry_point( + self, + boot_services: StandardBootServices, + image_handle: params::Handle, + runtime_services: StandardRuntimeServices, + ) -> Result<()> { + RUNTIME_SERVICES.store(runtime_services.as_mut_ptr(), core::sync::atomic::Ordering::SeqCst); + let boot_services: &'static StandardBootServices = Box::leak(Box::new(boot_services)); + install_hid_driver_binding(boot_services, *image_handle) + } + + #[cfg(not(feature = "ctrl-alt-del"))] + fn entry_point(self, boot_services: StandardBootServices, image_handle: params::Handle) -> Result<()> { + let boot_services: &'static StandardBootServices = Box::leak(Box::new(boot_services)); + install_hid_driver_binding(boot_services, *image_handle) + } +} + +/// Installs the HID driver binding using the provided boot services. +/// +/// Separated from the component entry point to allow testing with +/// `MockBootServices`. +fn install_hid_driver_binding( + boot_services: &'static T, + image_handle: efi::Handle, +) -> Result<()> { + // Patina component model has a single image handle. Create a separate driver_binding handle for the driver binding + // to avoid conflict on the image handle. + let (driver_binding_handle, _marker_key) = + boot_services.install_protocol_interface(None, Box::new(UefiHidMarker))?; + + let driver_binding = hid::HidDriver::new(boot_services, driver_binding_handle); + + let mut driver_binding = + UefiDriverBinding::new_with_driver_handle(driver_binding, image_handle, driver_binding_handle, boot_services); + + driver_binding.install().map_err(patina::error::EfiError::from)?; + // driver_binding is intentionally leaked by install() — it lives forever. + + Ok(()) +} + +#[cfg(test)] +mod test { + use patina::boot_services::{MockBootServices, c_ptr::CPtr}; + + use super::*; + + #[test] + fn install_hid_binding_should_install_a_binding() { + let boot_services = MockBootServices::new(); + let boot_services = Box::leak(Box::new(boot_services)); + + boot_services.expect_install_protocol_interface::>().returning( + |handle, protocol_interface| { + assert_eq!(handle, None, "Expected no handle for marker protocol installation"); + Ok((0x5678 as efi::Handle, protocol_interface.metadata())) + }, + ); + + boot_services.expect_install_protocol_interface_unchecked().returning(|handle, protocol, interface| { + if protocol == &efi::protocols::driver_binding::PROTOCOL_GUID { + assert!( + handle.is_some_and(|handle| handle as usize == 0x5678), + "Expected correct handle for driver binding protocol" + ); + assert!(!interface.is_null(), "Expected non-null interface for driver binding protocol"); + return Ok(0x9abc as efi::Handle); + } + panic!("Unexpected protocol installation: {:?}", protocol); + }); + + let mock_image_handle = 0x1234 as efi::Handle; + install_hid_driver_binding(boot_services, mock_image_handle).expect("install should succeed"); + } + + #[test] + fn install_hid_binding_handles_failure() { + let boot_services = Box::leak(Box::new(MockBootServices::new())); + + boot_services + .expect_install_protocol_interface::>() + .returning(|_, _| Err(efi::Status::OUT_OF_RESOURCES)); + + let mock_image_handle = 0x1234 as efi::Handle; + assert_eq!( + install_hid_driver_binding(boot_services, mock_image_handle), + Err(efi::Status::OUT_OF_RESOURCES.into()) + ); + + let boot_services = Box::leak(Box::new(MockBootServices::new())); + + boot_services + .expect_install_protocol_interface::>() + .returning(|_, protocol_interface| Ok((0x5678 as efi::Handle, protocol_interface.metadata()))); + + boot_services.expect_install_protocol_interface_unchecked().returning(|handle, protocol, interface| { + if protocol == &efi::protocols::driver_binding::PROTOCOL_GUID { + assert!( + handle.is_some_and(|handle| handle as usize == 0x5678), + "Expected correct handle for driver binding protocol" + ); + assert!(!interface.is_null(), "Expected non-null interface for driver binding protocol"); + return Err(efi::Status::OUT_OF_RESOURCES); + } + panic!("Unexpected protocol installation: {:?}", protocol); + }); + + assert_eq!( + install_hid_driver_binding(boot_services, mock_image_handle), + Err(efi::Status::OUT_OF_RESOURCES.into()) + ); + } +} diff --git a/uefi_hid/src/pointer/absolute_pointer.rs b/uefi_hid/src/pointer/absolute_pointer.rs new file mode 100644 index 0000000..862bfa7 --- /dev/null +++ b/uefi_hid/src/pointer/absolute_pointer.rs @@ -0,0 +1,682 @@ +//! Absolute Pointer Protocol FFI Support. +//! +//! This module manages the Absolute Pointer Protocol FFI. +//! +//! ## License +//! +//! Copyright (c) Microsoft Corporation. +//! +//! SPDX-License-Identifier: Apache-2.0 +//! +use alloc::boxed::Box; +use core::ptr; + +use r_efi::{efi, protocols}; + +use hidparser::report_data_types::Usage; + +use patina::{ + boot_services::{ + BootServices, + c_ptr::PtrMetadata, + event::{EventNotifyCallback, EventType}, + tpl::Tpl, + }, + uefi_protocol::ProtocolInterface, +}; + +use super::{ + AXIS_RESOLUTION, BUTTON_MAX, BUTTON_MIN, DIGITIZER_SWITCH_MAX, DIGITIZER_SWITCH_MIN, GENERIC_DESKTOP_WHEEL, + GENERIC_DESKTOP_X, GENERIC_DESKTOP_Y, GENERIC_DESKTOP_Z, PointerHidHandler, +}; + +/// FFI context +/// # Safety +/// A pointer to PointerHidHandler is included in the context so that it can be reclaimed in the absolute_pointer +/// API implementation. Care must be taken to ensure that rust invariants are respected when accessing the +/// PointerHidHandler. In particular, the design must ensure mutual exclusion on the PointerHidHandler between +/// callbacks running at different TPL; this is accomplished by ensuring all access to the structure is at TPL_NOTIFY +/// once initialization is complete. +/// +/// The absolute_pointer element must be the first element in the structure so that the full structure can be +/// recovered by simple casting. +#[repr(C)] +pub(crate) struct AbsolutePointerFfi { + absolute_pointer: protocols::absolute_pointer::Protocol, + boot_services: &'static T, + pointer_handler: *mut PointerHidHandler, +} + +// SAFETY: AbsolutePointerFfi is #[repr(C)] with protocols::absolute_pointer::Protocol as its +// first field, so a pointer to AbsolutePointerFfi is a valid pointer to Protocol per the +// first-field casting pattern. +unsafe impl ProtocolInterface for AbsolutePointerFfi { + const PROTOCOL_GUID: patina::BinaryGuid = patina::BinaryGuid(protocols::absolute_pointer::PROTOCOL_GUID); +} + +impl Drop for AbsolutePointerFfi { + fn drop(&mut self) { + if !self.absolute_pointer.mode.is_null() { + // SAFETY: mode was created via Box::into_raw during install and is non-null (checked above). + drop(unsafe { Box::from_raw(self.absolute_pointer.mode) }); + } + } +} + +impl AbsolutePointerFfi { + /// Installs the absolute pointer protocol. If successful, returns the key required to uninstall. + pub(crate) fn install( + boot_services: &'static T, + controller: efi::Handle, + pointer_handler: &mut PointerHidHandler, + ) -> Result>, efi::Status> { + let mut pointer_ctx = Box::new(AbsolutePointerFfi { + absolute_pointer: protocols::absolute_pointer::Protocol { + reset: Self::absolute_pointer_reset, + get_state: Self::absolute_pointer_get_state, + mode: Box::into_raw(Box::new(Self::initialize_mode(pointer_handler))), + wait_for_input: ptr::null_mut(), + }, + boot_services, + pointer_handler: pointer_handler as *mut PointerHidHandler, + }); + + let ctx_ptr: *mut Self = &mut *pointer_ctx; + + let wait_for_input_event = boot_services.create_event( + EventType::NOTIFY_WAIT, + Tpl::NOTIFY, + Some(Self::absolute_pointer_wait_for_input as EventNotifyCallback<*mut Self>), + ctx_ptr, + )?; + + pointer_ctx.absolute_pointer.wait_for_input = wait_for_input_event; + + match boot_services.install_protocol_interface(Some(controller), pointer_ctx) { + Ok((_handle, key)) => Ok(key), + Err(status) => { + // install_protocol_interface reconstructs and drops the Box on failure (freeing mode + // via AbsolutePointerFfi::Drop), but doesn't know about our event. + let _ = boot_services.close_event(wait_for_input_event); + Err(status) + } + } + } + + // Initializes the absolute_pointer mode structure. + fn initialize_mode(pointer_handler: &PointerHidHandler) -> protocols::absolute_pointer::Mode { + let mut mode: protocols::absolute_pointer::Mode = Default::default(); + + if pointer_handler.processor.supported_usages.contains(&Usage::from(GENERIC_DESKTOP_X)) { + mode.absolute_max_x = AXIS_RESOLUTION; + mode.absolute_min_x = 0; + } else { + log::warn!("No x-axis usages found in the report descriptor."); + } + + if pointer_handler.processor.supported_usages.contains(&Usage::from(GENERIC_DESKTOP_Y)) { + mode.absolute_max_y = AXIS_RESOLUTION; + mode.absolute_min_y = 0; + } else { + log::warn!("No y-axis usages found in the report descriptor."); + } + + if pointer_handler.processor.supported_usages.contains(&Usage::from(GENERIC_DESKTOP_Z)) + || pointer_handler.processor.supported_usages.contains(&Usage::from(GENERIC_DESKTOP_WHEEL)) + { + mode.absolute_max_z = AXIS_RESOLUTION; + mode.absolute_min_z = 0; + } + + let has_multiple_buttons = pointer_handler + .processor + .supported_usages + .iter() + .filter(|x| matches!((**x).into(), BUTTON_MIN..=BUTTON_MAX | DIGITIZER_SWITCH_MIN..=DIGITIZER_SWITCH_MAX)) + .nth(1) + .is_some(); + + if has_multiple_buttons { + mode.attributes |= protocols::absolute_pointer::SUPPORTS_ALT_ACTIVE; + } + + mode + } + + /// Uninstalls the absolute pointer protocol + pub(crate) fn uninstall( + boot_services: &'static T, + controller: efi::Handle, + key: PtrMetadata<'static, Box>, + ) -> Result<(), efi::Status> { + // Save the raw pointer before consuming key, in case uninstall fails. + let raw_ptr = key.ptr_value as *mut AbsolutePointerFfi; + + let pointer_ctx = match boot_services.uninstall_protocol_interface(controller, key) { + Ok(ctx) => ctx, + Err(status) => { + log::error!("Failed to uninstall absolute_pointer interface, status: {:x?}", status); + // Protocol is still installed. Null pointer_handler so callbacks don't access a + // dropped PointerHidHandler. + // SAFETY: raw_ptr was saved before uninstall consumed the key; protocol is still installed so memory is valid. + unsafe { + if let Some(ctx) = raw_ptr.as_mut() { + ctx.pointer_handler = ptr::null_mut(); + } + } + return Err(status); + } + }; + + if let Err(status) = boot_services.close_event(pointer_ctx.absolute_pointer.wait_for_input) { + log::error!("Failed to close absolute_pointer.wait_for_input event, status: {:x?}", status); + // Leak pointer_ctx so the still-live event callback doesn't use freed memory. + core::mem::forget(pointer_ctx); + return Err(status); + } + + // pointer_ctx drops here, freeing the mode allocation via AbsolutePointerFfi::Drop. + Ok(()) + } + + // Handles the wait_for_input event — part of the absolute_pointer protocol interface. + extern "efiapi" fn absolute_pointer_wait_for_input(event: efi::Event, context: *mut Self) { + if context.is_null() { + log::error!("absolute_pointer_wait_for_input invoked with invalid context"); + return; + } + // SAFETY: context pointer is non-null (checked above) and valid for the lifetime of the UEFI event callback. + let context = unsafe { context.as_ref() }.expect("context pointer should not be null"); + // SAFETY: pointer_handler validity is ensured by the protocol lifecycle. + if let Some(pointer_handler) = unsafe { context.pointer_handler.as_ref() } { + if pointer_handler.state.lock().state_changed { + let _ = context.boot_services.signal_event(event); + } + } else { + log::error!("absolute_pointer_wait_for_input invoked after pointer dropped."); + } + } + + // Resets the pointer state — part of the absolute_pointer protocol interface. + extern "efiapi" fn absolute_pointer_reset( + this: *mut protocols::absolute_pointer::Protocol, + _extended_verification: bool, + ) -> efi::Status { + if this.is_null() { + return efi::Status::INVALID_PARAMETER; + } + // SAFETY: this is the first field of #[repr(C)] AbsolutePointerFfi, non-null (checked above). + let context = + unsafe { (this as *const AbsolutePointerFfi).as_ref() }.expect("context pointer should not be null"); + // SAFETY: pointer_handler validity is ensured by the protocol lifecycle. + if let Some(pointer_handler) = unsafe { context.pointer_handler.as_ref() } { + pointer_handler.state.lock().reset(); + efi::Status::SUCCESS + } else { + log::error!("absolute_pointer_reset invoked after pointer dropped."); + efi::Status::DEVICE_ERROR + } + } + + // Returns the current pointer state — part of the absolute_pointer protocol interface. + extern "efiapi" fn absolute_pointer_get_state( + this: *mut protocols::absolute_pointer::Protocol, + state: *mut protocols::absolute_pointer::State, + ) -> efi::Status { + if this.is_null() || state.is_null() { + return efi::Status::INVALID_PARAMETER; + } + + // SAFETY: this is the first field of #[repr(C)] AbsolutePointerFfi, non-null (checked above). + let context = + unsafe { (this as *const AbsolutePointerFfi).as_ref() }.expect("context pointer should not be null"); + // SAFETY: pointer_handler validity is ensured by the protocol lifecycle. + if let Some(pointer_handler) = unsafe { context.pointer_handler.as_ref() } { + let mut pointer_state = pointer_handler.state.lock(); + if pointer_state.state_changed { + // SAFETY: state is non-null (checked above), using write_unaligned to avoid any alignment issues. + unsafe { + state.write_unaligned(pointer_state.current_state); + } + pointer_state.state_changed = false; + efi::Status::SUCCESS + } else { + efi::Status::NOT_READY + } + } else { + log::error!("absolute_pointer_get_state invoked after pointer dropped."); + efi::Status::DEVICE_ERROR + } + } +} + +#[cfg(test)] +mod test { + use core::ptr; + + use alloc::boxed::Box; + use r_efi::{efi, protocols}; + + use hidparser::report_data_types::Usage; + use patina::boot_services::MockBootServices; + + use super::*; + use crate::pointer::{ + AXIS_RESOLUTION, BUTTON_MIN, CENTER, GENERIC_DESKTOP_WHEEL, GENERIC_DESKTOP_X, GENERIC_DESKTOP_Y, + GENERIC_DESKTOP_Z, PointerHidHandler, + }; + + fn mock_boot_services() -> &'static mut MockBootServices { + let mut mock = MockBootServices::new(); + mock.expect_raise_tpl().returning(|_| patina::boot_services::tpl::Tpl::APPLICATION); + mock.expect_restore_tpl().returning(|_| ()); + // SAFETY: Leaked mock for test use with 'static lifetime requirement. + unsafe { Box::into_raw(Box::new(mock)).as_mut().unwrap() } + } + + /// Builds a test AbsolutePointerFfi context wired to the given handler. The caller must ensure + /// `handler` outlives the returned context. The returned context has a null `wait_for_input` + /// event (not needed for callback unit tests) and a heap-allocated mode. + fn test_context( + boot_services: &'static MockBootServices, + handler: &mut PointerHidHandler, + ) -> AbsolutePointerFfi { + AbsolutePointerFfi { + absolute_pointer: protocols::absolute_pointer::Protocol { + reset: AbsolutePointerFfi::::absolute_pointer_reset, + get_state: AbsolutePointerFfi::::absolute_pointer_get_state, + mode: Box::into_raw(Box::new(AbsolutePointerFfi::initialize_mode(handler))), + wait_for_input: ptr::null_mut(), + }, + boot_services, + pointer_handler: handler as *mut PointerHidHandler, + } + } + + // --- initialize_mode tests --- + + #[test] + fn initialize_mode_sets_xy_axes() { + let boot_services = mock_boot_services(); + let mut handler = PointerHidHandler::new_for_test(boot_services); + handler.processor.supported_usages.insert(Usage::from(GENERIC_DESKTOP_X)); + handler.processor.supported_usages.insert(Usage::from(GENERIC_DESKTOP_Y)); + + let mode = AbsolutePointerFfi::initialize_mode(&handler); + + assert_eq!(mode.absolute_max_x, AXIS_RESOLUTION); + assert_eq!(mode.absolute_min_x, 0); + assert_eq!(mode.absolute_max_y, AXIS_RESOLUTION); + assert_eq!(mode.absolute_min_y, 0); + assert_eq!(mode.absolute_max_z, 0); + assert_eq!(mode.attributes, 0); + } + + #[test] + fn initialize_mode_sets_z_axis_for_z_usage() { + let boot_services = mock_boot_services(); + let mut handler = PointerHidHandler::new_for_test(boot_services); + handler.processor.supported_usages.insert(Usage::from(GENERIC_DESKTOP_Z)); + + let mode = AbsolutePointerFfi::initialize_mode(&handler); + + assert_eq!(mode.absolute_max_z, AXIS_RESOLUTION); + assert_eq!(mode.absolute_min_z, 0); + } + + #[test] + fn initialize_mode_sets_z_axis_for_wheel_usage() { + let boot_services = mock_boot_services(); + let mut handler = PointerHidHandler::new_for_test(boot_services); + handler.processor.supported_usages.insert(Usage::from(GENERIC_DESKTOP_WHEEL)); + + let mode = AbsolutePointerFfi::initialize_mode(&handler); + + assert_eq!(mode.absolute_max_z, AXIS_RESOLUTION); + assert_eq!(mode.absolute_min_z, 0); + } + + #[test] + fn initialize_mode_sets_alt_active_for_multiple_buttons() { + let boot_services = mock_boot_services(); + let mut handler = PointerHidHandler::new_for_test(boot_services); + handler.processor.supported_usages.insert(Usage::from(BUTTON_MIN)); + handler.processor.supported_usages.insert(Usage::from(BUTTON_MIN + 1)); + + let mode = AbsolutePointerFfi::initialize_mode(&handler); + + assert_ne!(mode.attributes & protocols::absolute_pointer::SUPPORTS_ALT_ACTIVE, 0); + } + + #[test] + fn initialize_mode_no_alt_active_for_single_button() { + let boot_services = mock_boot_services(); + let mut handler = PointerHidHandler::new_for_test(boot_services); + handler.processor.supported_usages.insert(Usage::from(BUTTON_MIN)); + + let mode = AbsolutePointerFfi::initialize_mode(&handler); + + assert_eq!(mode.attributes & protocols::absolute_pointer::SUPPORTS_ALT_ACTIVE, 0); + } + + #[test] + fn initialize_mode_no_usages_returns_zeroed_mode() { + let boot_services = mock_boot_services(); + let handler = PointerHidHandler::new_for_test(boot_services); + + let mode = AbsolutePointerFfi::initialize_mode(&handler); + + assert_eq!(mode.absolute_max_x, 0); + assert_eq!(mode.absolute_max_y, 0); + assert_eq!(mode.absolute_max_z, 0); + assert_eq!(mode.attributes, 0); + } + + // --- FFI callback tests --- + + #[test] + fn reset_returns_invalid_parameter_on_null() { + let status = AbsolutePointerFfi::::absolute_pointer_reset(ptr::null_mut(), false); + assert_eq!(status, efi::Status::INVALID_PARAMETER); + } + + #[test] + fn reset_clears_state() { + let boot_services = mock_boot_services(); + let mut handler = PointerHidHandler::new_for_test(boot_services); + // Set up some non-default state. + { + let mut state = handler.state.lock(); + state.current_state.current_x = 100; + state.state_changed = true; + } + + let mut ctx = test_context(boot_services, &mut handler); + let status = + AbsolutePointerFfi::::absolute_pointer_reset(&mut ctx.absolute_pointer as *mut _, false); + + assert_eq!(status, efi::Status::SUCCESS); + let state = handler.state.lock(); + assert_eq!(state.current_state.current_x, CENTER); + assert_eq!(state.current_state.current_y, CENTER); + assert!(!state.state_changed); + } + + #[test] + fn reset_returns_device_error_after_pointer_dropped() { + let boot_services = mock_boot_services(); + let mut ctx = AbsolutePointerFfi { + absolute_pointer: protocols::absolute_pointer::Protocol { + reset: AbsolutePointerFfi::::absolute_pointer_reset, + get_state: AbsolutePointerFfi::::absolute_pointer_get_state, + mode: ptr::null_mut(), + wait_for_input: ptr::null_mut(), + }, + boot_services, + pointer_handler: ptr::null_mut(), + }; + + let status = + AbsolutePointerFfi::::absolute_pointer_reset(&mut ctx.absolute_pointer as *mut _, false); + assert_eq!(status, efi::Status::DEVICE_ERROR); + } + + #[test] + fn get_state_returns_invalid_parameter_on_null_this() { + let status = AbsolutePointerFfi::::absolute_pointer_get_state( + ptr::null_mut(), + &mut protocols::absolute_pointer::State::default(), + ); + assert_eq!(status, efi::Status::INVALID_PARAMETER); + } + + #[test] + fn get_state_returns_invalid_parameter_on_null_state() { + let boot_services = mock_boot_services(); + let mut handler = PointerHidHandler::new_for_test(boot_services); + let mut ctx = test_context(boot_services, &mut handler); + + let status = AbsolutePointerFfi::::absolute_pointer_get_state( + &mut ctx.absolute_pointer as *mut _, + ptr::null_mut(), + ); + assert_eq!(status, efi::Status::INVALID_PARAMETER); + } + + #[test] + fn get_state_returns_not_ready_when_unchanged() { + let boot_services = mock_boot_services(); + let mut handler = PointerHidHandler::new_for_test(boot_services); + let mut ctx = test_context(boot_services, &mut handler); + let mut state = protocols::absolute_pointer::State::default(); + + let status = AbsolutePointerFfi::::absolute_pointer_get_state( + &mut ctx.absolute_pointer as *mut _, + &mut state, + ); + assert_eq!(status, efi::Status::NOT_READY); + } + + #[test] + fn get_state_returns_state_and_clears_flag() { + let boot_services = mock_boot_services(); + let mut handler = PointerHidHandler::new_for_test(boot_services); + { + let mut state = handler.state.lock(); + state.current_state.current_x = 200; + state.current_state.current_y = 300; + state.current_state.active_buttons = 1; + state.state_changed = true; + } + + let mut ctx = test_context(boot_services, &mut handler); + let mut out_state = protocols::absolute_pointer::State::default(); + + let status = AbsolutePointerFfi::::absolute_pointer_get_state( + &mut ctx.absolute_pointer as *mut _, + &mut out_state, + ); + + assert_eq!(status, efi::Status::SUCCESS); + assert_eq!(out_state.current_x, 200); + assert_eq!(out_state.current_y, 300); + assert_eq!(out_state.active_buttons, 1); + // Flag should be cleared after read. + assert!(!handler.state.lock().state_changed); + } + + #[test] + fn get_state_returns_device_error_after_pointer_dropped() { + let boot_services = mock_boot_services(); + let mut ctx = AbsolutePointerFfi { + absolute_pointer: protocols::absolute_pointer::Protocol { + reset: AbsolutePointerFfi::::absolute_pointer_reset, + get_state: AbsolutePointerFfi::::absolute_pointer_get_state, + mode: ptr::null_mut(), + wait_for_input: ptr::null_mut(), + }, + boot_services, + pointer_handler: ptr::null_mut(), + }; + + let status = AbsolutePointerFfi::::absolute_pointer_get_state( + &mut ctx.absolute_pointer as *mut _, + &mut protocols::absolute_pointer::State::default(), + ); + assert_eq!(status, efi::Status::DEVICE_ERROR); + } + + #[test] + fn wait_for_input_signals_event_when_state_changed() { + let boot_services = mock_boot_services(); + boot_services.expect_signal_event().times(1).returning(|_| Ok(())); + + let mut handler = PointerHidHandler::new_for_test(boot_services); + handler.state.lock().state_changed = true; + + let mut ctx = test_context(boot_services, &mut handler); + let event = 0x1234 as efi::Event; + + AbsolutePointerFfi::::absolute_pointer_wait_for_input(event, &mut ctx); + } + + #[test] + fn wait_for_input_does_not_signal_when_unchanged() { + let boot_services = mock_boot_services(); + boot_services.expect_signal_event().never(); + + let mut handler = PointerHidHandler::new_for_test(boot_services); + // state_changed defaults to false. + + let mut ctx = test_context(boot_services, &mut handler); + let event = 0x1234 as efi::Event; + + AbsolutePointerFfi::::absolute_pointer_wait_for_input(event, &mut ctx); + } + + #[test] + fn wait_for_input_handles_null_context() { + let event = 0x1234 as efi::Event; + // Should log an error but not panic. + AbsolutePointerFfi::::absolute_pointer_wait_for_input(event, ptr::null_mut()); + } + + // --- install/uninstall tests --- + + /// Creates a leaked `AbsolutePointerFfi` box and returns its `PtrMetadata` key, + /// simulating the state after a successful `install`. The caller is responsible for + /// ensuring the leaked allocation is cleaned up (either via `uninstall_protocol_interface` + /// reconstructing the Box, or manual `Box::from_raw`). + fn leaked_context_key( + boot_services: &'static MockBootServices, + handler: &mut PointerHidHandler, + ) -> PtrMetadata<'static, Box>> { + use patina::boot_services::c_ptr::CPtr; + + let ctx = Box::new(AbsolutePointerFfi { + absolute_pointer: protocols::absolute_pointer::Protocol { + reset: AbsolutePointerFfi::::absolute_pointer_reset, + get_state: AbsolutePointerFfi::::absolute_pointer_get_state, + mode: Box::into_raw(Box::new(protocols::absolute_pointer::Mode::default())), + wait_for_input: 0x42 as efi::Event, + }, + boot_services, + pointer_handler: handler as *mut _, + }); + let key = ctx.metadata(); + let _ = Box::into_raw(ctx); + key + } + + type AbsPtr = AbsolutePointerFfi; + + #[test] + fn install_succeeds() { + let boot_services = mock_boot_services(); + boot_services.expect_create_event::<*mut AbsPtr>().times(1).returning(|_, _, _, _| Ok(0x42 as efi::Event)); + boot_services.expect_install_protocol_interface::>().times(1).returning( + |_, protocol_interface| { + use patina::boot_services::c_ptr::CPtr; + let key = protocol_interface.metadata(); + let _ = Box::into_raw(protocol_interface); + Ok((0x1 as efi::Handle, key)) + }, + ); + + let mut handler = PointerHidHandler::new_for_test(boot_services); + let result = AbsolutePointerFfi::install(boot_services, 0x2 as efi::Handle, &mut handler); + assert!(result.is_ok()); + + // Clean up the leaked box. + let key = result.unwrap(); + // SAFETY: Reclaiming the Box leaked by the mock install_protocol_interface. + drop(unsafe { Box::from_raw(key.ptr_value as *mut AbsPtr) }); + } + + #[test] + fn install_returns_error_on_create_event_failure() { + let boot_services = mock_boot_services(); + boot_services + .expect_create_event::<*mut AbsPtr>() + .times(1) + .returning(|_, _, _, _| Err(efi::Status::OUT_OF_RESOURCES)); + boot_services.expect_install_protocol_interface::>().never(); + + let mut handler = PointerHidHandler::new_for_test(boot_services); + let result = AbsolutePointerFfi::install(boot_services, 0x2 as efi::Handle, &mut handler); + assert_eq!(result.err(), Some(efi::Status::OUT_OF_RESOURCES)); + } + + #[test] + fn install_cleans_up_event_on_install_protocol_failure() { + let boot_services = mock_boot_services(); + boot_services.expect_create_event::<*mut AbsPtr>().times(1).returning(|_, _, _, _| Ok(0x42 as efi::Event)); + boot_services + .expect_install_protocol_interface::>() + .times(1) + .returning(|_, _| Err(efi::Status::ACCESS_DENIED)); + // close_event must be called to clean up the event on install failure. + boot_services.expect_close_event().times(1).returning(|_| Ok(())); + + let mut handler = PointerHidHandler::new_for_test(boot_services); + let result = AbsolutePointerFfi::install(boot_services, 0x2 as efi::Handle, &mut handler); + assert_eq!(result.err(), Some(efi::Status::ACCESS_DENIED)); + } + + #[test] + fn uninstall_succeeds() { + let boot_services = mock_boot_services(); + boot_services.expect_uninstall_protocol_interface::>().times(1).returning(|_, key| { + // SAFETY: Reclaiming the Box from the key, mirroring the real uninstall_protocol_interface. + Ok(unsafe { Box::from_raw(key.ptr_value as *mut AbsPtr) }) + }); + boot_services.expect_close_event().times(1).returning(|_| Ok(())); + + let mut handler = PointerHidHandler::new_for_test(boot_services); + let key = leaked_context_key(boot_services, &mut handler); + + let result = AbsolutePointerFfi::uninstall(boot_services, 0x2 as efi::Handle, key); + assert_eq!(result, Ok(())); + } + + #[test] + fn uninstall_failure_nulls_pointer_handler() { + let boot_services = mock_boot_services(); + boot_services + .expect_uninstall_protocol_interface::>() + .times(1) + .returning(|_, _| Err(efi::Status::ACCESS_DENIED)); + + let mut handler = PointerHidHandler::new_for_test(boot_services); + let key = leaked_context_key(boot_services, &mut handler); + let raw_ptr = key.ptr_value as *mut AbsPtr; + + let result = AbsolutePointerFfi::uninstall(boot_services, 0x2 as efi::Handle, key); + assert_eq!(result, Err(efi::Status::ACCESS_DENIED)); + + // Verify pointer_handler was nulled as a safety measure. + // SAFETY: raw_ptr points to the still-live leaked context (uninstall failed). + let ctx = unsafe { raw_ptr.as_ref() }.unwrap(); + assert!(ctx.pointer_handler.is_null()); + + // SAFETY: Reclaiming the leaked context Box for cleanup. + drop(unsafe { Box::from_raw(raw_ptr) }); + } + + #[test] + fn uninstall_returns_error_on_close_event_failure() { + let boot_services = mock_boot_services(); + boot_services + .expect_uninstall_protocol_interface::>() + .times(1) + // SAFETY: Reclaiming the Box from the key, mirroring the real uninstall_protocol_interface. + .returning(|_, key| Ok(unsafe { Box::from_raw(key.ptr_value as *mut AbsPtr) })); + boot_services.expect_close_event().times(1).returning(|_| Err(efi::Status::INVALID_PARAMETER)); + + let mut handler = PointerHidHandler::new_for_test(boot_services); + let key = leaked_context_key(boot_services, &mut handler); + + let result = AbsolutePointerFfi::uninstall(boot_services, 0x2 as efi::Handle, key); + assert_eq!(result, Err(efi::Status::INVALID_PARAMETER)); + } +} diff --git a/uefi_hid/src/pointer/mod.rs b/uefi_hid/src/pointer/mod.rs new file mode 100644 index 0000000..7a4acef --- /dev/null +++ b/uefi_hid/src/pointer/mod.rs @@ -0,0 +1,794 @@ +//! Provides Pointer HID support. +//! +//! This module handles the core logic for processing pointer input from HID +//! devices. +//! +//! ## License +//! +//! Copyright (c) Microsoft Corporation. +//! +//! SPDX-License-Identifier: Apache-2.0 +//! +pub(crate) mod absolute_pointer; + +use alloc::{ + boxed::Box, + collections::{BTreeMap, BTreeSet}, + vec::Vec, +}; + +use r_efi::{efi, protocols}; + +use hidparser::{ + ReportDescriptor, ReportField, VariableField, + report_data_types::{ReportId, Usage}, +}; + +use patina::{ + boot_services::{BootServices, c_ptr::PtrMetadata, tpl::Tpl}, + tpl_mutex::TplMutex, +}; + +use crate::hid_io::{HidIo, HidReportReceiver}; + +use self::absolute_pointer::AbsolutePointerFfi; + +// Usages supported by this module. +const GENERIC_DESKTOP_X: u32 = 0x00010030; +const GENERIC_DESKTOP_Y: u32 = 0x00010031; +const GENERIC_DESKTOP_Z: u32 = 0x00010032; +const GENERIC_DESKTOP_WHEEL: u32 = 0x00010038; +const BUTTON_MIN: u32 = 0x00090001; +const BUTTON_MAX: u32 = 0x00090020; +const DIGITIZER_SWITCH_MIN: u32 = 0x000d0042; +const DIGITIZER_SWITCH_MAX: u32 = 0x000d0046; +const DIGITIZER_CONTACT_COUNT: u32 = 0x000d0054; + +// Number of points on the X/Y axis for this implementation. +const AXIS_RESOLUTION: u64 = 1024; +const CENTER: u64 = AXIS_RESOLUTION / 2; + +/// Mutable pointer state updated during report processing. +#[derive(Debug)] +pub(crate) struct PointerState { + pub(crate) state_changed: bool, + pub(crate) current_state: protocols::absolute_pointer::State, + contact_count: Option, +} + +impl PointerState { + // Creates a new PointerState centered at the default position. + fn new() -> Self { + let mut state = Self { state_changed: false, current_state: Default::default(), contact_count: None }; + state.reset(); + state + } + + /// Resets the pointer state to the default values. + pub(crate) fn reset(&mut self) { + self.current_state = Default::default(); + self.current_state.current_x = CENTER; + self.current_state.current_y = CENTER; + self.state_changed = false; + self.contact_count = None; + } + + // Helper routine that handles projecting relative and absolute axis reports onto the fixed + // absolute report axis that this driver produces. + fn resolve_axis(current_value: u64, field: &VariableField, report: &[u8]) -> Option { + if field.attributes.relative { + let new_value = current_value as i64 + field.field_value(report)?; + Some(new_value.clamp(0, AXIS_RESOLUTION as i64) as u64) + } else { + let mut new_value = field.field_value(report)?; + new_value = new_value.checked_sub(i32::from(field.logical_minimum) as i64)?; + new_value = (new_value * AXIS_RESOLUTION as i64 * 1000) / (field.field_range()? as i64 * 1000); + Some(new_value.clamp(0, AXIS_RESOLUTION as i64) as u64) + } + } + + // Updates the axis value from the given report field. + fn axis_handler(&mut self, field: &VariableField, report: &[u8]) { + let current_value = match field.usage.into() { + GENERIC_DESKTOP_X => &mut self.current_state.current_x, + GENERIC_DESKTOP_Y => &mut self.current_state.current_y, + GENERIC_DESKTOP_Z | GENERIC_DESKTOP_WHEEL => &mut self.current_state.current_z, + _ => return, + }; + if let Some(new_value) = Self::resolve_axis(*current_value, field, report) + && *current_value != new_value + { + *current_value = new_value; + self.state_changed = true; + } + } + + // Updates button state from the given report field. + fn button_handler(&mut self, field: &VariableField, report: &[u8]) { + let shift = match field.usage.into() { + x @ BUTTON_MIN..=BUTTON_MAX => x - BUTTON_MIN, + x @ DIGITIZER_SWITCH_MIN..=DIGITIZER_SWITCH_MAX => x - DIGITIZER_SWITCH_MIN, + _ => return, + }; + + if let Some(button_value) = field.field_value(report) { + let button_value = button_value as u32; + + if shift > u32::BITS { + return; + } + let button_value = button_value << shift; + + let new_buttons = self.current_state.active_buttons + & !(1 << shift) // zero the relevant bit in the button state field. + | button_value; // or in the current button state into that bit position. + + if new_buttons != self.current_state.active_buttons { + self.current_state.active_buttons = new_buttons; + self.state_changed = true; + } + } + } + + // Updates the contact count from the given report field. + fn contact_count_handler(&mut self, field: &VariableField, report: &[u8]) { + if let Some(contact_count) = field.field_value(report) { + if let Ok(contact_count) = usize::try_from(contact_count) { + self.contact_count = Some(contact_count); + } else { + log::debug!("Ignoring negative contact_count: {}", contact_count); + } + } + } +} + +// Function pointer type for per-field report processing. +type ReportHandler = fn(&mut PointerState, field: &VariableField, report: &[u8]); + +// Maps a given HID report field to a routine that handles input from it. +struct PointerInputFieldHandler { + field: VariableField, + report_handler: ReportHandler, +} + +// Defines a report and the fields of interest within it. +#[derive(Default)] +struct PointerInputReportSpec { + report_id: Option, + report_size: usize, + relevant_fields: Vec, +} + +// Defines counters for determining how many contact points need to be handled +#[derive(Default, Clone)] +struct UsageUpdateCounter { + x: usize, + y: usize, + z: usize, + button: usize, + switch: usize, +} + +// Core pointer data processing logic, independent of UEFI boot services. +pub(crate) struct PointerProcessor { + input_reports: BTreeMap, PointerInputReportSpec>, + pub(crate) supported_usages: BTreeSet, + report_id_present: bool, +} + +impl PointerProcessor { + // Creates a new processor with empty report maps. + fn new() -> Self { + Self { input_reports: BTreeMap::new(), supported_usages: BTreeSet::new(), report_id_present: false } + } + + // Parses a report descriptor and registers relevant field handlers. + fn process_descriptor(&mut self, descriptor: ReportDescriptor) -> Result<(), efi::Status> { + let multiple_reports = descriptor.input_reports.len() > 1; + log::trace!("pointer::process_descriptor: {:?} input report(s) in descriptor", descriptor.input_reports.len(),); + + for report in &descriptor.input_reports { + let mut report_data = PointerInputReportSpec { report_id: report.report_id, ..Default::default() }; + + self.report_id_present = report.report_id.is_some(); + + if multiple_reports && !self.report_id_present { + return Err(efi::Status::DEVICE_ERROR); + } + + report_data.report_size = report.size_in_bits.div_ceil(8); + + for field in &report.fields { + if let ReportField::Variable(field) = field { + let handler: Option<(ReportHandler, bool)> = match field.usage.into() { + // Contact count is processed first to ensure that if counters are present, we can enforce + // them when processing the rest of the fields in the report. + DIGITIZER_CONTACT_COUNT => Some((PointerState::contact_count_handler, true)), + GENERIC_DESKTOP_X | GENERIC_DESKTOP_Y | GENERIC_DESKTOP_Z | GENERIC_DESKTOP_WHEEL => { + Some((PointerState::axis_handler, false)) + } + BUTTON_MIN..=BUTTON_MAX => Some((PointerState::button_handler, false)), + DIGITIZER_SWITCH_MIN..=DIGITIZER_SWITCH_MAX => Some((PointerState::button_handler, false)), + _ => None, + }; + if let Some((report_handler, insert_first)) = handler { + let entry = PointerInputFieldHandler { field: field.clone(), report_handler }; + if insert_first { + report_data.relevant_fields.insert(0, entry); + } else { + report_data.relevant_fields.push(entry); + } + self.supported_usages.insert(field.usage); + } + } + } + + if !report_data.relevant_fields.is_empty() { + self.input_reports.insert(report_data.report_id, report_data); + } + } + if !self.input_reports.is_empty() { + log::trace!( + "pointer::process_descriptor: {:?} usable report(s) with {:?} supported usage(s)", + self.input_reports.len(), + self.supported_usages.len(), + ); + Ok(()) + } else { + Err(efi::Status::UNSUPPORTED) + } + } + + // Processes an incoming HID report and updates pointer state. + fn process_report(&self, report: &[u8], state: &mut PointerState) { + if report.is_empty() { + return; + } + + log::trace!("pointer::process_report: {:?} bytes", report.len()); + + let (report_id, report) = match self.report_id_present { + true => (Some(ReportId::from(&report[0..1])), &report[1..]), + false => (None, &report[0..]), + }; + + if report.is_empty() { + return; + } + + if let Some(report_data) = self.input_reports.get(&report_id) { + state.contact_count = None; + let mut counters = UsageUpdateCounter::default(); + + if report.len() != report_data.report_size { + log::trace!( + "receive_report: unexpected report length for report_id: {:?}. expected {:?}, actual {:?}", + report_id, + report_data.report_size, + report.len() + ); + } + + for field in &report_data.relevant_fields { + let counter = match field.field.usage.into() { + DIGITIZER_CONTACT_COUNT => None, + GENERIC_DESKTOP_X => Some(&mut counters.x), + GENERIC_DESKTOP_Y => Some(&mut counters.y), + GENERIC_DESKTOP_Z | GENERIC_DESKTOP_WHEEL => Some(&mut counters.z), + BUTTON_MIN..=BUTTON_MAX => Some(&mut counters.button), + DIGITIZER_SWITCH_MIN..=DIGITIZER_SWITCH_MAX => Some(&mut counters.switch), + _ => continue, + }; + + if let Some(counter) = counter { + if state.contact_count.is_some_and(|c| *counter >= c) { + continue; + } + *counter += 1; + } + + (field.report_handler)(state, &field.field, report); + } + } + } +} + +/// Pointer HID handler that processes reports and produces UEFI AbsolutePointer state. +pub struct PointerHidHandler { + boot_services: &'static T, + controller: efi::Handle, + absolute_pointer_key: Option>>>, + pub(crate) processor: PointerProcessor, + pub(crate) state: TplMutex, +} + +impl PointerHidHandler { + /// Creates a fully initialized Pointer HID handler for the given controller. + /// + /// Returns a boxed handler because protocol installation stores raw pointers to `self`. + /// Boxing first ensures those pointers remain valid (the handler is never moved after boxing). + pub fn new( + boot_services: &'static T, + controller: efi::Handle, + hid_io: &dyn HidIo, + ) -> Result, efi::Status> { + let mut processor = PointerProcessor::new(); + let descriptor = hid_io.get_report_descriptor()?; + processor.process_descriptor(descriptor)?; + + let mut handler = Box::new(Self { + boot_services, + controller, + absolute_pointer_key: None, + processor, + state: TplMutex::new((*boot_services).clone(), Tpl::NOTIFY, PointerState::new()), + }); + + let key = AbsolutePointerFfi::install(boot_services, controller, &mut *handler)?; + handler.absolute_pointer_key = Some(key); + Ok(handler) + } + + #[cfg(test)] + pub fn new_for_test(boot_services: &'static T) -> Self { + Self { + boot_services, + controller: core::ptr::null_mut(), + absolute_pointer_key: None, + processor: PointerProcessor::new(), + state: TplMutex::new((*boot_services).clone(), Tpl::NOTIFY, PointerState::new()), + } + } +} + +impl HidReportReceiver for PointerHidHandler { + fn receive_report(&mut self, report: &[u8], _hid_io: &dyn HidIo) { + log::trace!("pointer::receive_report: {:?} bytes", report.len()); + self.processor.process_report(report, &mut self.state.lock()); + } +} + +impl Drop for PointerHidHandler { + fn drop(&mut self) { + if let Some(key) = self.absolute_pointer_key.take() + && let Err(status) = AbsolutePointerFfi::::uninstall(self.boot_services, self.controller, key) + { + log::error!("PointerHidHandler::drop: Failed to uninstall absolute_pointer: {:?}", status); + } + } +} + +#[cfg(test)] +mod test { + use hidparser::{ + ReportDescriptor, ReportField, VariableField, + report_data_types::{ReportAttributes, Usage}, + }; + use r_efi::efi; + + use super::*; + + // Creates an absolute VariableField at the given bit range with the given usage. + fn absolute_field(usage: u32, bits: core::ops::Range, logical_max: i32) -> VariableField { + VariableField { + bits, + usage: Usage::from(usage), + logical_minimum: 0.into(), + logical_maximum: logical_max.into(), + attributes: ReportAttributes { relative: false, ..Default::default() }, + ..Default::default() + } + } + + // Creates a relative VariableField at the given bit range with the given usage. + fn relative_field(usage: u32, bits: core::ops::Range, logical_min: i32, logical_max: i32) -> VariableField { + VariableField { + bits, + usage: Usage::from(usage), + logical_minimum: logical_min.into(), + logical_maximum: logical_max.into(), + attributes: ReportAttributes { relative: true, ..Default::default() }, + ..Default::default() + } + } + + // Builds a minimal ReportDescriptor with the given fields, no report ID. + fn descriptor_with_fields(fields: Vec, size_in_bits: usize) -> ReportDescriptor { + ReportDescriptor { + input_reports: alloc::vec![hidparser::Report { report_id: None, size_in_bits, fields }], + output_reports: alloc::vec![], + bad_input_reports: alloc::vec![], + bad_output_reports: alloc::vec![], + features: alloc::vec![], + bad_features: alloc::vec![], + } + } + + // --- PointerState handler tests --- + + #[test] + fn reset_sets_center_and_clears_state() { + let mut state = PointerState::new(); + state.current_state.current_x = 100; + state.state_changed = true; + state.contact_count = Some(5); + + state.reset(); + + assert_eq!(state.current_state.current_x, CENTER); + assert_eq!(state.current_state.current_y, CENTER); + assert!(!state.state_changed); + assert_eq!(state.contact_count, None); + } + + #[test] + fn axis_handler_updates_absolute_x() { + let mut state = PointerState::new(); + // 8-bit field at bits 0..8, logical range 0..255 + let field = absolute_field(GENERIC_DESKTOP_X, 0..8, 255); + // Report value 128 → projects to 514 on 0..1024 axis (128/255 * 1024 ≈ 514) + let report = [128u8]; + + state.axis_handler(&field, &report); + + assert!(state.state_changed); + assert_eq!(state.current_state.current_x, 514); + } + + #[test] + fn axis_handler_updates_absolute_y() { + let mut state = PointerState::new(); + let field = absolute_field(GENERIC_DESKTOP_Y, 0..8, 255); + let report = [0u8]; + + state.axis_handler(&field, &report); + + assert!(state.state_changed); + assert_eq!(state.current_state.current_y, 0); + } + + #[test] + fn axis_handler_updates_relative_z() { + let mut state = PointerState::new(); + // 8-bit signed relative field, range -127..127 + let field = relative_field(GENERIC_DESKTOP_Z, 0..8, -127, 127); + // Report value 10 (relative) + let report = [10u8]; + + state.axis_handler(&field, &report); + + assert!(state.state_changed); + assert_eq!(state.current_state.current_z, 10); + } + + #[test] + fn same_value_does_not_set_state_changed() { + let mut state = PointerState::new(); + let field = absolute_field(GENERIC_DESKTOP_X, 0..8, 255); + // Value that maps to CENTER (512) + let report = [128u8]; + state.axis_handler(&field, &report); + assert!(state.state_changed); + + // Reset flag and send same value again. + state.state_changed = false; + state.axis_handler(&field, &report); + assert!(!state.state_changed); + } + + #[test] + fn button_handler_sets_button_bit() { + let mut state = PointerState::new(); + // Button 1 (usage 0x00090001): 1-bit field at bit 0 + let field = absolute_field(BUTTON_MIN, 0..1, 1); + let report = [0x01u8]; + + state.button_handler(&field, &report); + + assert!(state.state_changed); + assert_eq!(state.current_state.active_buttons, 1); + } + + #[test] + fn button_handler_clears_button_bit() { + let mut state = PointerState::new(); + state.current_state.active_buttons = 1; + let field = absolute_field(BUTTON_MIN, 0..1, 1); + let report = [0x00u8]; + + state.button_handler(&field, &report); + + assert!(state.state_changed); + assert_eq!(state.current_state.active_buttons, 0); + } + + #[test] + fn contact_count_handler_sets_contact_count() { + let mut state = PointerState::new(); + let field = absolute_field(DIGITIZER_CONTACT_COUNT, 0..8, 255); + let report = [3u8]; + + state.contact_count_handler(&field, &report); + + assert_eq!(state.contact_count, Some(3)); + } + + // --- process_descriptor tests --- + + #[test] + fn process_descriptor_with_x_and_y_succeeds() { + let mut processor = PointerProcessor::new(); + + let fields = alloc::vec![ + ReportField::Variable(absolute_field(GENERIC_DESKTOP_X, 0..8, 255)), + ReportField::Variable(absolute_field(GENERIC_DESKTOP_Y, 8..16, 255)), + ]; + let descriptor = descriptor_with_fields(fields, 16); + + assert_eq!(processor.process_descriptor(descriptor), Ok(())); + assert!(processor.supported_usages.contains(&Usage::from(GENERIC_DESKTOP_X))); + assert!(processor.supported_usages.contains(&Usage::from(GENERIC_DESKTOP_Y))); + assert_eq!(processor.input_reports.len(), 1); + } + + #[test] + fn process_descriptor_with_no_relevant_fields_returns_unsupported() { + let mut processor = PointerProcessor::new(); + + let descriptor = descriptor_with_fields(alloc::vec![], 0); + + assert_eq!(processor.process_descriptor(descriptor), Err(efi::Status::UNSUPPORTED)); + } + + #[test] + fn process_descriptor_places_contact_count_first() { + let mut processor = PointerProcessor::new(); + + let fields = alloc::vec![ + ReportField::Variable(absolute_field(GENERIC_DESKTOP_X, 0..8, 255)), + ReportField::Variable(absolute_field(DIGITIZER_CONTACT_COUNT, 8..16, 255)), + ]; + let descriptor = descriptor_with_fields(fields, 16); + + processor.process_descriptor(descriptor).unwrap(); + + let report_data = processor.input_reports.values().next().unwrap(); + let first_usage: u32 = report_data.relevant_fields[0].field.usage.into(); + assert_eq!(first_usage, DIGITIZER_CONTACT_COUNT); + } + + #[test] + fn process_descriptor_multiple_reports_without_ids_returns_error() { + let mut processor = PointerProcessor::new(); + let descriptor = ReportDescriptor { + input_reports: alloc::vec![ + hidparser::Report { + report_id: None, + size_in_bits: 8, + fields: alloc::vec![ReportField::Variable(absolute_field(GENERIC_DESKTOP_X, 0..8, 255))], + }, + hidparser::Report { + report_id: None, + size_in_bits: 8, + fields: alloc::vec![ReportField::Variable(absolute_field(GENERIC_DESKTOP_Y, 0..8, 255))], + }, + ], + output_reports: alloc::vec![], + bad_input_reports: alloc::vec![], + bad_output_reports: alloc::vec![], + features: alloc::vec![], + bad_features: alloc::vec![], + }; + assert_eq!(processor.process_descriptor(descriptor), Err(efi::Status::DEVICE_ERROR)); + } + + #[test] + fn process_descriptor_ignores_array_fields() { + let mut processor = PointerProcessor::new(); + let descriptor = ReportDescriptor { + input_reports: alloc::vec![hidparser::Report { + report_id: None, + size_in_bits: 8, + fields: alloc::vec![ReportField::Array(hidparser::ArrayField { bits: 0..8, ..Default::default() })], + }], + output_reports: alloc::vec![], + bad_input_reports: alloc::vec![], + bad_output_reports: alloc::vec![], + features: alloc::vec![], + bad_features: alloc::vec![], + }; + // No relevant variable fields → UNSUPPORTED + assert_eq!(processor.process_descriptor(descriptor), Err(efi::Status::UNSUPPORTED)); + } + + #[test] + fn process_descriptor_with_button_and_switch_fields() { + let mut processor = PointerProcessor::new(); + let fields = alloc::vec![ + ReportField::Variable(absolute_field(GENERIC_DESKTOP_X, 0..8, 255)), + ReportField::Variable(absolute_field(BUTTON_MIN, 8..9, 1)), + ReportField::Variable(absolute_field(DIGITIZER_SWITCH_MIN, 9..10, 1)), + ]; + let descriptor = descriptor_with_fields(fields, 10); + assert_eq!(processor.process_descriptor(descriptor), Ok(())); + assert!(processor.supported_usages.contains(&Usage::from(BUTTON_MIN))); + assert!(processor.supported_usages.contains(&Usage::from(DIGITIZER_SWITCH_MIN))); + } + + // --- process_report tests --- + + #[test] + fn process_report_empty_report_is_no_op() { + let mut processor = PointerProcessor::new(); + let fields = alloc::vec![ReportField::Variable(absolute_field(GENERIC_DESKTOP_X, 0..8, 255))]; + let descriptor = descriptor_with_fields(fields, 8); + processor.process_descriptor(descriptor).unwrap(); + + let mut state = PointerState::new(); + processor.process_report(&[], &mut state); + assert!(!state.state_changed); + } + + #[test] + fn process_report_updates_pointer_state() { + let mut processor = PointerProcessor::new(); + let fields = alloc::vec![ + ReportField::Variable(absolute_field(GENERIC_DESKTOP_X, 0..8, 255)), + ReportField::Variable(absolute_field(GENERIC_DESKTOP_Y, 8..16, 255)), + ]; + let descriptor = descriptor_with_fields(fields, 16); + processor.process_descriptor(descriptor).unwrap(); + + let mut state = PointerState::new(); + processor.process_report(&[128, 64], &mut state); + assert!(state.state_changed); + assert_ne!(state.current_state.current_x, CENTER); + } + + #[test] + fn process_report_unregistered_report_id_is_ignored() { + let report_id = hidparser::report_data_types::ReportId::from(&[0x01][..]); + let mut processor = PointerProcessor::new(); + let descriptor = ReportDescriptor { + input_reports: alloc::vec![hidparser::Report { + report_id: Some(report_id), + size_in_bits: 8, + fields: alloc::vec![ReportField::Variable(absolute_field(GENERIC_DESKTOP_X, 0..8, 255))], + }], + output_reports: alloc::vec![], + bad_input_reports: alloc::vec![], + bad_output_reports: alloc::vec![], + features: alloc::vec![], + bad_features: alloc::vec![], + }; + processor.process_descriptor(descriptor).unwrap(); + + let mut state = PointerState::new(); + // Report with wrong report ID (0x02) + processor.process_report(&[0x02, 128], &mut state); + assert!(!state.state_changed); + } + + #[test] + fn process_report_report_length_mismatch_still_processes() { + let mut processor = PointerProcessor::new(); + let fields = alloc::vec![ReportField::Variable(absolute_field(GENERIC_DESKTOP_X, 0..8, 255))]; + let descriptor = descriptor_with_fields(fields, 8); + processor.process_descriptor(descriptor).unwrap(); + + let mut state = PointerState::new(); + // Send 2 bytes when only 1 expected — should still process + processor.process_report(&[128, 0], &mut state); + assert!(state.state_changed); + } + + #[test] + fn process_report_contact_count_limits_axis_updates() { + let mut processor = PointerProcessor::new(); + // Two X fields and a contact count field + let fields = alloc::vec![ + ReportField::Variable(absolute_field(DIGITIZER_CONTACT_COUNT, 0..8, 255)), + ReportField::Variable(absolute_field(GENERIC_DESKTOP_X, 8..16, 255)), + ReportField::Variable(absolute_field(GENERIC_DESKTOP_X, 16..24, 255)), + ]; + let descriptor = descriptor_with_fields(fields, 24); + processor.process_descriptor(descriptor).unwrap(); + + let mut state = PointerState::new(); + // Contact count = 1 → only first X field should be processed + processor.process_report(&[1, 200, 50], &mut state); + assert!(state.state_changed); + } + + #[test] + fn process_report_with_report_id_strips_id_byte() { + let report_id = hidparser::report_data_types::ReportId::from(&[0x01][..]); + let mut processor = PointerProcessor::new(); + let descriptor = ReportDescriptor { + input_reports: alloc::vec![hidparser::Report { + report_id: Some(report_id), + size_in_bits: 8, + fields: alloc::vec![ReportField::Variable(absolute_field(GENERIC_DESKTOP_X, 0..8, 255))], + }], + output_reports: alloc::vec![], + bad_input_reports: alloc::vec![], + bad_output_reports: alloc::vec![], + features: alloc::vec![], + bad_features: alloc::vec![], + }; + processor.process_descriptor(descriptor).unwrap(); + + let mut state = PointerState::new(); + // Report with correct report ID (0x01) + data + processor.process_report(&[0x01, 200], &mut state); + assert!(state.state_changed); + } + + #[test] + fn process_report_with_report_id_and_empty_data_is_no_op() { + let report_id = hidparser::report_data_types::ReportId::from(&[0x01][..]); + let mut processor = PointerProcessor::new(); + let descriptor = ReportDescriptor { + input_reports: alloc::vec![hidparser::Report { + report_id: Some(report_id), + size_in_bits: 8, + fields: alloc::vec![ReportField::Variable(absolute_field(GENERIC_DESKTOP_X, 0..8, 255))], + }], + output_reports: alloc::vec![], + bad_input_reports: alloc::vec![], + bad_output_reports: alloc::vec![], + features: alloc::vec![], + bad_features: alloc::vec![], + }; + processor.process_descriptor(descriptor).unwrap(); + + let mut state = PointerState::new(); + // Only report ID byte, no data + processor.process_report(&[0x01], &mut state); + assert!(!state.state_changed); + } + + #[test] + fn axis_handler_ignores_unknown_usage() { + let mut state = PointerState::new(); + // Usage that's not X, Y, Z, or Wheel + let field = absolute_field(0x00010099, 0..8, 255); + state.axis_handler(&field, &[128]); + assert!(!state.state_changed); + } + + #[test] + fn axis_handler_wheel_updates_z() { + let mut state = PointerState::new(); + let field = relative_field(GENERIC_DESKTOP_WHEEL, 0..8, -127, 127); + state.axis_handler(&field, &[5]); + assert!(state.state_changed); + assert_eq!(state.current_state.current_z, 5); + } + + #[test] + fn button_handler_digitizer_switch_sets_bit() { + let mut state = PointerState::new(); + let field = absolute_field(DIGITIZER_SWITCH_MIN, 0..1, 1); + state.button_handler(&field, &[0x01]); + assert!(state.state_changed); + assert_eq!(state.current_state.active_buttons, 1); + } + + #[test] + fn button_handler_ignores_unknown_usage() { + let mut state = PointerState::new(); + let field = absolute_field(0x00990001, 0..1, 1); + state.button_handler(&field, &[0x01]); + assert!(!state.state_changed); + } + + #[test] + fn button_handler_shift_exceeding_u32_bits_is_ignored() { + let mut state = PointerState::new(); + // BUTTON_MAX has a shift of BUTTON_MAX - BUTTON_MIN which could be large + let field = absolute_field(BUTTON_MIN + 33, 0..1, 1); + state.button_handler(&field, &[0x01]); + assert!(!state.state_changed); + } +} diff --git a/uefi_hid/src/test_stubs.rs b/uefi_hid/src/test_stubs.rs new file mode 100644 index 0000000..c9604f5 --- /dev/null +++ b/uefi_hid/src/test_stubs.rs @@ -0,0 +1,71 @@ +//! Test stubs for protocol types whose `stub()` methods are not exposed +//! from the upstream patina crate (they are `#[cfg(test)]` internal to patina). +//! +//! ## License +//! +//! Copyright (c) Microsoft Corporation. +//! +//! SPDX-License-Identifier: Apache-2.0 +//! + +use alloc::boxed::Box; +use core::ffi::c_void; + +use r_efi::efi; + +use patina::vendor_protocols::hid_io::{HidIoProtocol, HidIoReportCallback, HidReportType}; + +/// Creates a stub `HidIoProtocol` with no-op function pointers for testing. +#[coverage(off)] +pub fn hid_io_stub() -> &'static mut HidIoProtocol { + unsafe extern "efiapi" fn get_report_descriptor( + _this: *const HidIoProtocol, + report_descriptor_size: *mut usize, + _report_descriptor_buffer: *mut c_void, + ) -> efi::Status { + // SAFETY: report_descriptor_size is a valid pointer provided by the caller in the test stub. + unsafe { *report_descriptor_size = 0 }; + efi::Status::BUFFER_TOO_SMALL + } + unsafe extern "efiapi" fn get_report( + _this: *const HidIoProtocol, + _report_id: u8, + _report_type: HidReportType, + _report_buffer_size: usize, + _report_buffer: *mut c_void, + ) -> efi::Status { + efi::Status::SUCCESS + } + unsafe extern "efiapi" fn set_report( + _this: *const HidIoProtocol, + _report_id: u8, + _report_type: HidReportType, + _report_buffer_size: usize, + _report_buffer: *mut c_void, + ) -> efi::Status { + efi::Status::SUCCESS + } + unsafe extern "efiapi" fn register_report_callback( + _this: *const HidIoProtocol, + _callback: HidIoReportCallback, + _context: *mut c_void, + ) -> efi::Status { + efi::Status::SUCCESS + } + unsafe extern "efiapi" fn unregister_report_callback( + _this: *const HidIoProtocol, + _callback: HidIoReportCallback, + ) -> efi::Status { + efi::Status::SUCCESS + } + + let protocol = HidIoProtocol { + get_report_descriptor, + get_report, + set_report, + register_report_callback, + unregister_report_callback, + }; + // SAFETY: Leaked for 'static lifetime in tests. + unsafe { Box::into_raw(Box::new(protocol)).as_mut().unwrap() } +} From 06fd93d40648269c628afb73e216d3dacda73087 Mon Sep 17 00:00:00 2001 From: John Schock Date: Mon, 4 May 2026 13:36:51 -0700 Subject: [PATCH 2/3] Add usb_hid component Add the USB HID component which consumes EFI_USB_IO_PROTOCOL on USB HID device controllers and produces the HidIo protocol for each managed device. This serves as the HidIo producer consumed by uefi_hid. USB HID class-specific constants and descriptor structures are defined locally in usb_hid_defs.rs rather than imported from patina. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- Cargo.toml | 2 +- usb_hid/Cargo.toml | 21 + usb_hid/README.md | 78 +++ usb_hid/src/control_transfers.rs | 376 ++++++++++++++ usb_hid/src/descriptors.rs | 711 ++++++++++++++++++++++++++ usb_hid/src/device.rs | 106 ++++ usb_hid/src/driver.rs | 261 ++++++++++ usb_hid/src/hid_io_impl.rs | 768 +++++++++++++++++++++++++++++ usb_hid/src/interrupt_transfers.rs | 626 +++++++++++++++++++++++ usb_hid/src/lib.rs | 129 +++++ usb_hid/src/test_stubs.rs | 141 ++++++ usb_hid/src/usb_hid_defs.rs | 71 +++ 12 files changed, 3289 insertions(+), 1 deletion(-) create mode 100644 usb_hid/Cargo.toml create mode 100644 usb_hid/README.md create mode 100644 usb_hid/src/control_transfers.rs create mode 100644 usb_hid/src/descriptors.rs create mode 100644 usb_hid/src/device.rs create mode 100644 usb_hid/src/driver.rs create mode 100644 usb_hid/src/hid_io_impl.rs create mode 100644 usb_hid/src/interrupt_transfers.rs create mode 100644 usb_hid/src/lib.rs create mode 100644 usb_hid/src/test_stubs.rs create mode 100644 usb_hid/src/usb_hid_defs.rs diff --git a/Cargo.toml b/Cargo.toml index 8918ae2..87a1d55 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [workspace] resolver = "3" -members = ["uefi_hid"] +members = ["uefi_hid", "usb_hid"] [workspace.package] version = "0.0.1" diff --git a/usb_hid/Cargo.toml b/usb_hid/Cargo.toml new file mode 100644 index 0000000..35792fa --- /dev/null +++ b/usb_hid/Cargo.toml @@ -0,0 +1,21 @@ +[package] +name = "usb_hid" +description = "USB HID driver that produces the HidIo protocol as a Patina component." +version.workspace = true +license.workspace = true +edition.workspace = true +rust-version.workspace = true +repository.workspace = true +publish = false + +[lints] +workspace = true + +[dependencies] +log = { workspace = true } +patina = { workspace = true } +r-efi = { workspace = true } + +[dev-dependencies] +mockall = { workspace = true } +patina = { workspace = true, features = ["mockall"] } diff --git a/usb_hid/README.md b/usb_hid/README.md new file mode 100644 index 0000000..c60e6a3 --- /dev/null +++ b/usb_hid/README.md @@ -0,0 +1,78 @@ + +# USB HID + +## Overview + +This Patina component provides USB Human Interface Device support for UEFI by consuming the +`EFI_USB_IO_PROTOCOL` on USB HID device controllers and producing the +[HidIo](https://github.com/microsoft/mu_plus/blob/release/202502/HidPkg/Include/Protocol/HidIo.h) protocol for +each managed device. + +The HidIo protocol is then consumed by downstream components (e.g. `uefi_hid`) to provide keyboard, pointer, and +other HID input support. + +## Architecture + +The component installs a UEFI Driver Binding that manages USB HID device instances. The driver follows the standard +UEFI Driver Model: + +1. **Supported** — checks if a controller has USB IO with HID interface class. +2. **Start** — reads USB descriptors, configures report protocol mode for boot devices, and installs the HidIo + protocol on the controller handle. +3. **Stop** — shuts down async transfers, uninstalls the protocol, and frees resources. + +Asynchronous input reports are delivered via USB interrupt-in transfers. A timer-based delayed recovery mechanism +handles USB transfer errors. + +## Modules + +| Module | Description | +| --- | --- | +| `control_transfers` | USB control transfer helpers for HID devices (set protocol, set/get report). | +| `descriptors` | USB descriptor reading for HID devices (HID descriptor, report descriptor). | +| `device` | Per-device state for USB HID devices. | +| `driver` | Driver binding implementation that manages USB HID device instances on controllers. | +| `hid_io_impl` | HidIoProtocol function pointer implementations — delegates to USB IO operations. | +| `interrupt_transfers` | Async interrupt transfer management and error recovery. | + +## Dependencies + +Key crate dependencies (see `Cargo.toml` for the full list): + +- [`patina`](https://crates.io/crates/patina) — Patina component SDK (boot services, driver binding, protocol interfaces). +- [`r-efi`](https://crates.io/crates/r-efi) — Rust UEFI type definitions. + +## Platform Integration + +To include `usb_hid` in a Patina binary, add the crate as a dependency and register the component +in the platform's `ComponentInfo` implementation. + +1. Add the dependency to the binary crate's `Cargo.toml`: + + ```toml + [dependencies] + usb_hid = { version = "20" } + ``` + +2. Register the component in the `components` function: + + ```rust + impl ComponentInfo for MyPlatform { + fn components(mut add: Add) { + // ...other components... + add.component(usb_hid::UsbHidComponent); + } + } + ``` + +The driver binding will automatically attach to any controller that exposes the USB IO protocol with a HID interface +class. A HidIo consumer (e.g. [`uefi_hid`](../uefi_hid)) must be present in the platform firmware for the produced +HidIo protocol to be functional. + +## Testing + +Unit tests use `mockall` and `patina`'s mock boot services: + +```sh +cargo test -p usb_hid +``` diff --git a/usb_hid/src/control_transfers.rs b/usb_hid/src/control_transfers.rs new file mode 100644 index 0000000..46d2647 --- /dev/null +++ b/usb_hid/src/control_transfers.rs @@ -0,0 +1,376 @@ +//! USB control transfer helpers for HID devices. +//! +//! ## License +//! +//! Copyright (c) Microsoft Corporation. +//! +//! SPDX-License-Identifier: Apache-2.0 +//! +use core::ffi::c_void; + +use r_efi::efi; + +use patina::uefi_protocol::usb_io::{EfiUsbIoProtocol, types::*}; + +use crate::usb_hid_defs::*; + +/// Sends a USB HID SET_PROTOCOL request to switch the device to report mode. +pub fn set_protocol_request(usb_io: &EfiUsbIoProtocol, interface_number: u8, protocol: u8) -> Result<(), efi::Status> { + let request = EfiUsbDeviceRequest { + request_type: USB_REQ_TYPE_CLASS_INTERFACE_OUT, + request: USB_HID_SET_PROTOCOL_REQUEST, + value: protocol as u16, + index: interface_number as u16, + length: 0, + }; + let mut transfer_status: u32 = 0; + // SAFETY: usb_io is valid; request and status pointers are valid. + let status = unsafe { + (usb_io.usb_control_transfer)( + usb_io as *const EfiUsbIoProtocol, + &request, + EfiUsbDataDirection::NoData, + USB_TRANSFER_TIMEOUT_MS, + core::ptr::null_mut(), + 0, + &mut transfer_status, + ) + }; + if status != efi::Status::SUCCESS { + return Err(status); + } + Ok(()) +} + +/// Sends a USB HID GET_REPORT class-specific request. +pub fn usb_get_report_request( + usb_io: &EfiUsbIoProtocol, + interface_number: u8, + report_id: u8, + report_type: u8, + report_len: u16, + report: *mut u8, +) -> Result<(), efi::Status> { + let request = EfiUsbDeviceRequest { + request_type: USB_REQ_TYPE_CLASS_INTERFACE_IN, + request: USB_HID_GET_REPORT_REQUEST, + value: (report_type as u16) << 8 | report_id as u16, + index: interface_number as u16, + length: report_len, + }; + let mut transfer_status: u32 = 0; + // SAFETY: usb_io is valid; request, report buffer, and status pointers are valid. + let status = unsafe { + (usb_io.usb_control_transfer)( + usb_io as *const EfiUsbIoProtocol, + &request, + EfiUsbDataDirection::DataIn, + USB_TRANSFER_TIMEOUT_MS, + report as *mut c_void, + report_len as usize, + &mut transfer_status, + ) + }; + if status != efi::Status::SUCCESS { + return Err(status); + } + Ok(()) +} + +/// Sends a USB HID SET_REPORT class-specific request. +pub fn usb_set_report_request( + usb_io: &EfiUsbIoProtocol, + interface_number: u8, + report_id: u8, + report_type: u8, + report_len: u16, + report: *const u8, +) -> Result<(), efi::Status> { + let request = EfiUsbDeviceRequest { + request_type: USB_REQ_TYPE_CLASS_INTERFACE_OUT, + request: USB_HID_SET_REPORT_REQUEST, + value: (report_type as u16) << 8 | report_id as u16, + index: interface_number as u16, + length: report_len, + }; + let mut transfer_status: u32 = 0; + // SAFETY: usb_io is valid; request, report buffer, and status pointers are valid. + let status = unsafe { + (usb_io.usb_control_transfer)( + usb_io as *const EfiUsbIoProtocol, + &request, + EfiUsbDataDirection::DataOut, + USB_TRANSFER_TIMEOUT_MS, + report as *mut c_void, + report_len as usize, + &mut transfer_status, + ) + }; + if status != efi::Status::SUCCESS { + return Err(status); + } + Ok(()) +} + +/// Sends a USB CLEAR_FEATURE(ENDPOINT_HALT) request. +pub fn usb_clear_endpoint_halt(usb_io: &EfiUsbIoProtocol, endpoint_address: u8) -> Result<(), efi::Status> { + let request = EfiUsbDeviceRequest { + request_type: USB_REQ_TYPE_STANDARD_ENDPOINT_OUT, + request: USB_REQ_CLEAR_FEATURE, + value: USB_FEATURE_ENDPOINT_HALT, + index: endpoint_address as u16, + length: 0, + }; + let mut transfer_status: u32 = 0; + // SAFETY: usb_io is valid; request and status pointers are valid. + let status = unsafe { + (usb_io.usb_control_transfer)( + usb_io as *const EfiUsbIoProtocol, + &request, + EfiUsbDataDirection::NoData, + USB_TRANSFER_TIMEOUT_MS, + core::ptr::null_mut(), + 0, + &mut transfer_status, + ) + }; + if status != efi::Status::SUCCESS { + return Err(status); + } + Ok(()) +} + +/// Reads the report descriptor from the device via GET_DESCRIPTOR. +pub fn usb_get_report_descriptor( + usb_io: &EfiUsbIoProtocol, + interface_number: u8, + descriptor_length: u16, + descriptor_buffer: *mut u8, +) -> Result<(), efi::Status> { + let request = EfiUsbDeviceRequest { + request_type: USB_REQ_TYPE_STANDARD_DEVICE_IN | 0x01, // Interface recipient + request: USB_REQ_GET_DESCRIPTOR, + value: (USB_DESC_TYPE_REPORT as u16) << 8, + index: interface_number as u16, + length: descriptor_length, + }; + let mut transfer_status: u32 = 0; + // SAFETY: usb_io is valid; request, descriptor buffer, and status pointers are valid. + let status = unsafe { + (usb_io.usb_control_transfer)( + usb_io as *const EfiUsbIoProtocol, + &request, + EfiUsbDataDirection::DataIn, + USB_TRANSFER_TIMEOUT_MS, + descriptor_buffer as *mut c_void, + descriptor_length as usize, + &mut transfer_status, + ) + }; + if status != efi::Status::SUCCESS { + return Err(status); + } + Ok(()) +} + +#[cfg(test)] +mod test { + use super::*; + use core::cell::Cell; + + // ---- Mock USB IO ---- + + /// Captured parameters from the most recent `usb_control_transfer` call. + #[derive(Clone, Copy, Default)] + struct CapturedRequest { + request_type: u8, + request: u8, + value: u16, + index: u16, + length: u16, + direction: u32, + data_length: usize, + } + + /// Mock USB IO context. `protocol` must be the first field so the extern + /// mock function can recover mock state from the `this` pointer. + #[repr(C)] + struct MockUsbIo { + protocol: EfiUsbIoProtocol, + status: efi::Status, + captured: Cell, + } + + impl MockUsbIo { + /// # Safety + /// `this` must point to the `protocol` field of a valid `MockUsbIo`. + unsafe fn from_this(this: *const EfiUsbIoProtocol) -> &'static Self { + // SAFETY: MockUsbIo is #[repr(C)] with protocol as first field. + unsafe { &*(this as *const MockUsbIo) } + } + } + + extern "efiapi" fn mock_control_transfer( + this: *const EfiUsbIoProtocol, + request: *const EfiUsbDeviceRequest, + direction: EfiUsbDataDirection, + _timeout: u32, + _data: *mut c_void, + data_length: usize, + _status: *mut u32, + ) -> efi::Status { + // SAFETY: this points to a valid MockUsbIo on the test stack. + let mock = unsafe { MockUsbIo::from_this(this) }; + // SAFETY: request is a valid pointer from the caller. + let req = unsafe { &*request }; + mock.captured.set(CapturedRequest { + request_type: req.request_type, + request: req.request, + value: req.value, + index: req.index, + length: req.length, + direction: direction as u32, + data_length, + }); + mock.status + } + + fn make_mock(status: efi::Status) -> MockUsbIo { + let mut protocol = crate::test_stubs::usb_io_stub(); + protocol.usb_control_transfer = mock_control_transfer; + MockUsbIo { protocol, status, captured: Cell::new(CapturedRequest::default()) } + } + + // ---- set_protocol_request tests ---- + + #[test] + fn set_protocol_request_builds_correct_request() { + let mock = make_mock(efi::Status::SUCCESS); + assert!(set_protocol_request(&mock.protocol, 2, REPORT_PROTOCOL).is_ok()); + let cap = mock.captured.get(); + assert_eq!(cap.request_type, USB_REQ_TYPE_CLASS_INTERFACE_OUT); + assert_eq!(cap.request, USB_HID_SET_PROTOCOL_REQUEST); + assert_eq!(cap.value, REPORT_PROTOCOL as u16); + assert_eq!(cap.index, 2); + assert_eq!(cap.length, 0); + assert_eq!(cap.direction, EfiUsbDataDirection::NoData as u32); + assert_eq!(cap.data_length, 0); + } + + #[test] + fn set_protocol_request_returns_error_on_failure() { + let mock = make_mock(efi::Status::DEVICE_ERROR); + assert_eq!(set_protocol_request(&mock.protocol, 0, 0), Err(efi::Status::DEVICE_ERROR)); + } + + // ---- usb_get_report_request tests ---- + + #[test] + fn get_report_request_builds_correct_request() { + let mock = make_mock(efi::Status::SUCCESS); + let mut buffer = [0u8; 16]; + assert!(usb_get_report_request(&mock.protocol, 1, 0x03, 0x01, 16, buffer.as_mut_ptr()).is_ok()); + let cap = mock.captured.get(); + assert_eq!(cap.request_type, USB_REQ_TYPE_CLASS_INTERFACE_IN); + assert_eq!(cap.request, USB_HID_GET_REPORT_REQUEST); + assert_eq!(cap.value, (0x01u16 << 8) | 0x03); + assert_eq!(cap.index, 1); + assert_eq!(cap.length, 16); + assert_eq!(cap.direction, EfiUsbDataDirection::DataIn as u32); + assert_eq!(cap.data_length, 16); + } + + #[test] + fn get_report_request_returns_error_on_failure() { + let mock = make_mock(efi::Status::DEVICE_ERROR); + let mut buffer = [0u8; 8]; + assert_eq!( + usb_get_report_request(&mock.protocol, 0, 0, 0x01, 8, buffer.as_mut_ptr()), + Err(efi::Status::DEVICE_ERROR), + ); + } + + // ---- usb_set_report_request tests ---- + + #[test] + fn set_report_request_builds_correct_request() { + let mock = make_mock(efi::Status::SUCCESS); + let report = [0xAAu8; 4]; + assert!(usb_set_report_request(&mock.protocol, 0, 0x01, 0x02, 4, report.as_ptr()).is_ok()); + let cap = mock.captured.get(); + assert_eq!(cap.request_type, USB_REQ_TYPE_CLASS_INTERFACE_OUT); + assert_eq!(cap.request, USB_HID_SET_REPORT_REQUEST); + assert_eq!(cap.value, (0x02u16 << 8) | 0x01); + assert_eq!(cap.index, 0); + assert_eq!(cap.length, 4); + assert_eq!(cap.direction, EfiUsbDataDirection::DataOut as u32); + assert_eq!(cap.data_length, 4); + } + + #[test] + fn set_report_request_returns_error_on_failure() { + let mock = make_mock(efi::Status::DEVICE_ERROR); + let report = [0u8; 4]; + assert_eq!( + usb_set_report_request(&mock.protocol, 0, 0, 0x02, 4, report.as_ptr()), + Err(efi::Status::DEVICE_ERROR), + ); + } + + // ---- usb_clear_endpoint_halt tests ---- + + #[test] + fn clear_endpoint_halt_builds_correct_request() { + let mock = make_mock(efi::Status::SUCCESS); + assert!(usb_clear_endpoint_halt(&mock.protocol, 0x81).is_ok()); + let cap = mock.captured.get(); + assert_eq!(cap.request_type, USB_REQ_TYPE_STANDARD_ENDPOINT_OUT); + assert_eq!(cap.request, USB_REQ_CLEAR_FEATURE); + assert_eq!(cap.value, USB_FEATURE_ENDPOINT_HALT); + assert_eq!(cap.index, 0x81); + assert_eq!(cap.length, 0); + assert_eq!(cap.direction, EfiUsbDataDirection::NoData as u32); + assert_eq!(cap.data_length, 0); + } + + #[test] + fn clear_endpoint_halt_returns_error_on_failure() { + let mock = make_mock(efi::Status::DEVICE_ERROR); + assert_eq!(usb_clear_endpoint_halt(&mock.protocol, 0x81), Err(efi::Status::DEVICE_ERROR)); + } + + // ---- usb_get_report_descriptor tests ---- + + #[test] + fn get_report_descriptor_builds_correct_request() { + let mock = make_mock(efi::Status::SUCCESS); + let mut buffer = [0u8; 64]; + assert!(usb_get_report_descriptor(&mock.protocol, 0, 64, buffer.as_mut_ptr()).is_ok()); + let cap = mock.captured.get(); + assert_eq!(cap.request_type, USB_REQ_TYPE_STANDARD_DEVICE_IN | 0x01); + assert_eq!(cap.request, USB_REQ_GET_DESCRIPTOR); + assert_eq!(cap.value, (USB_DESC_TYPE_REPORT as u16) << 8); + assert_eq!(cap.index, 0); + assert_eq!(cap.length, 64); + assert_eq!(cap.direction, EfiUsbDataDirection::DataIn as u32); + assert_eq!(cap.data_length, 64); + } + + #[test] + fn get_report_descriptor_uses_interface_number_as_index() { + let mock = make_mock(efi::Status::SUCCESS); + let mut buffer = [0u8; 32]; + assert!(usb_get_report_descriptor(&mock.protocol, 3, 32, buffer.as_mut_ptr()).is_ok()); + assert_eq!(mock.captured.get().index, 3); + } + + #[test] + fn get_report_descriptor_returns_error_on_failure() { + let mock = make_mock(efi::Status::DEVICE_ERROR); + let mut buffer = [0u8; 64]; + assert_eq!( + usb_get_report_descriptor(&mock.protocol, 0, 64, buffer.as_mut_ptr()), + Err(efi::Status::DEVICE_ERROR), + ); + } +} diff --git a/usb_hid/src/descriptors.rs b/usb_hid/src/descriptors.rs new file mode 100644 index 0000000..baa371a --- /dev/null +++ b/usb_hid/src/descriptors.rs @@ -0,0 +1,711 @@ +//! USB descriptor reading for HID devices. +//! +//! ## License +//! +//! Copyright (c) Microsoft Corporation. +//! +//! SPDX-License-Identifier: Apache-2.0 +//! +use alloc::vec; +use core::{ffi::c_void, mem::size_of}; + +use r_efi::efi; + +use crate::{control_transfers, device::UsbHidDescriptors}; +use patina::uefi_protocol::usb_io::{EfiUsbIoProtocol, types::*}; + +use crate::usb_hid_defs::*; + +/// Owned wrapper around the variable-length USB HID descriptor. +/// +/// The backing `Vec` holds the complete descriptor bytes (fixed header + +/// trailing `HidClassDescriptor` entries). Automatically freed on drop. +#[derive(Debug)] +struct HidDescriptor { + data: alloc::vec::Vec, +} + +impl HidDescriptor { + fn header(&self) -> &EfiUsbHidDescriptor { + // SAFETY: data was copied from a valid HID descriptor at least size_of::() bytes. + unsafe { &*(self.data.as_ptr() as *const EfiUsbHidDescriptor) } + } + + fn class_descriptors(&self) -> &[HidClassDescriptor] { + let count = self.header().num_descriptors as usize; + let available = + self.data.len().saturating_sub(size_of::()) / size_of::(); + let count = count.min(available); + // SAFETY: Construction validates data.len() >= header + at least `count` class descriptors. + unsafe { + let base = self.data.as_ptr().add(size_of::()) as *const HidClassDescriptor; + core::slice::from_raw_parts(base, count) + } + } +} + +/// Reads interface, endpoint, and HID descriptors from the device. +/// +/// Returns a [`UsbHidDescriptors`] containing the interface, interrupt-in +/// endpoint, and report descriptors. Returns an error if the interrupt-in +/// endpoint is not found. +pub fn read_descriptors(usb_io: &EfiUsbIoProtocol) -> Result { + let usb_io_ptr = usb_io as *const EfiUsbIoProtocol; + + let mut interface_descriptor = EfiUsbInterfaceDescriptor::default(); + // SAFETY: usb_io and interface_descriptor are valid. + let status = unsafe { (usb_io.usb_get_interface_descriptor)(usb_io_ptr, &mut interface_descriptor) }; + if status != efi::Status::SUCCESS { + return Err(status); + } + + log::trace!( + "USB HID: interface class: 0x{:x}, subclass: 0x{:x}, protocol: 0x{:x}", + interface_descriptor.interface_class, + interface_descriptor.interface_sub_class, + interface_descriptor.interface_protocol, + ); + + let mut int_in_endpoint_descriptor = EfiUsbEndpointDescriptor::default(); + for index in 0..interface_descriptor.num_endpoints { + let mut endpoint = EfiUsbEndpointDescriptor::default(); + // SAFETY: usb_io and endpoint descriptor are valid. + let status = unsafe { (usb_io.usb_get_endpoint_descriptor)(usb_io_ptr, index, &mut endpoint) }; + if status != efi::Status::SUCCESS { + return Err(status); + } + + if (endpoint.attributes & USB_ENDPOINT_XFER_TYPE_MASK) == USB_ENDPOINT_INTERRUPT + && (endpoint.endpoint_address & USB_ENDPOINT_DIR_IN) != 0 + { + int_in_endpoint_descriptor = endpoint; + break; + } + } + + // Interrupt-in endpoint must be found. + if int_in_endpoint_descriptor.length == 0 { + return Err(efi::Status::DEVICE_ERROR); + } + + let hid_descriptor = get_full_hid_descriptor(usb_io, &interface_descriptor)?; + let report_descriptor = read_report_descriptor(usb_io, &interface_descriptor, &hid_descriptor)?; + + Ok(UsbHidDescriptors { interface_descriptor, int_in_endpoint_descriptor, report_descriptor }) +} + +/// Reads the report descriptor from the device using the HID descriptor's +/// class descriptor entries to determine the length. +fn read_report_descriptor( + usb_io: &EfiUsbIoProtocol, + interface_descriptor: &EfiUsbInterfaceDescriptor, + hid_descriptor: &HidDescriptor, +) -> Result, efi::Status> { + let report_entry = hid_descriptor.class_descriptors().iter().find(|d| d.descriptor_type == USB_DESC_TYPE_REPORT); + let descriptor_length = match report_entry { + Some(entry) => entry.descriptor_length as usize, + None => return Err(efi::Status::NOT_FOUND), + }; + + let mut buffer = vec![0u8; descriptor_length]; + + control_transfers::usb_get_report_descriptor( + usb_io, + interface_descriptor.interface_number, + descriptor_length as u16, + buffer.as_mut_ptr(), + )?; + + Ok(buffer) +} + +/// Retrieves the full HID descriptor for the given interface by parsing the +/// configuration descriptor. +fn get_full_hid_descriptor( + usb_io: &EfiUsbIoProtocol, + interface_descriptor: &EfiUsbInterfaceDescriptor, +) -> Result { + let mut config_desc = EfiUsbConfigDescriptor::default(); + // SAFETY: usb_io and config_desc are valid. + let status = unsafe { (usb_io.usb_get_config_descriptor)(usb_io as *const EfiUsbIoProtocol, &mut config_desc) }; + if status != efi::Status::SUCCESS { + return Err(status); + } + + let total_length = config_desc.total_length as usize; + let mut buffer = vec![0u8; total_length]; + + // Read the full configuration descriptor using GET_DESCRIPTOR control transfer. + let descriptor_value = + (USB_DESC_TYPE_CONFIG as u16) << 8 | (config_desc.configuration_value.wrapping_sub(1)) as u16; + let request = EfiUsbDeviceRequest { + request_type: USB_REQ_TYPE_STANDARD_DEVICE_IN, + request: USB_REQ_GET_DESCRIPTOR, + value: descriptor_value, + index: 0, + length: total_length as u16, + }; + let mut transfer_status: u32 = 0; + // SAFETY: usb_io is valid; request, buffer, and status pointers are valid. + let status = unsafe { + (usb_io.usb_control_transfer)( + usb_io as *const EfiUsbIoProtocol, + &request, + EfiUsbDataDirection::DataIn, + USB_TRANSFER_TIMEOUT_MS, + buffer.as_mut_ptr() as *mut c_void, + total_length, + &mut transfer_status, + ) + }; + if status != efi::Status::SUCCESS { + return Err(status); + } + + find_hid_descriptor_in_config(&buffer, interface_descriptor) +} + +/// Searches the configuration descriptor buffer for the HID descriptor that +/// immediately follows the matching interface descriptor. +fn find_hid_descriptor_in_config( + buffer: &[u8], + interface_descriptor: &EfiUsbInterfaceDescriptor, +) -> Result { + let mut cursor: usize = 0; + + while cursor + size_of::() <= buffer.len() { + // SAFETY: bounds check above ensures UsbDescHead fits at cursor. + let header = unsafe { &*(buffer.as_ptr().add(cursor) as *const UsbDescHead) }; + + if header.len == 0 { + log::error!("USB HID: descriptor length is 0 at offset {cursor}"); + break; + } + + if header.desc_type == USB_DESC_TYPE_INTERFACE { + if cursor + size_of::() > buffer.len() { + break; + } + // SAFETY: bounds check above ensures EfiUsbInterfaceDescriptor fits at cursor. + let interface = unsafe { &*(buffer.as_ptr().add(cursor) as *const EfiUsbInterfaceDescriptor) }; + if interface.interface_number == interface_descriptor.interface_number + && interface.alternate_setting == interface_descriptor.alternate_setting + { + // The HID descriptor must immediately follow the interface descriptor. + let next_offset = cursor + header.len as usize; + if next_offset + size_of::() <= buffer.len() { + // SAFETY: bounds check above ensures UsbDescHead fits at next_offset. + let next_header = unsafe { &*(buffer.as_ptr().add(next_offset) as *const UsbDescHead) }; + if next_header.desc_type == USB_DESC_TYPE_HID { + let len = next_header.len as usize; + if next_offset + len > buffer.len() { + log::error!("USB HID: HID descriptor length overflows config buffer"); + return Err(efi::Status::DEVICE_ERROR); + } + let min_size = size_of::() + size_of::(); + if len < min_size { + log::error!("USB HID: HID descriptor too short for header + class descriptor"); + return Err(efi::Status::DEVICE_ERROR); + } + return Ok(HidDescriptor { data: buffer[next_offset..next_offset + len].to_vec() }); + } + } + // HID descriptor not found at expected position. + break; + } + } + + cursor += header.len as usize; + } + + Err(efi::Status::UNSUPPORTED) +} + +#[cfg(test)] +mod test { + use super::*; + use alloc::{vec, vec::Vec}; + use core::{cell::Cell, ffi::c_void}; + + // ---- Descriptor byte builders ---- + + fn interface_bytes(number: u8, alt: u8, num_endpoints: u8) -> Vec { + vec![9, USB_DESC_TYPE_INTERFACE, number, alt, num_endpoints, CLASS_HID, 0, 0, 0] + } + + fn hid_bytes(report_desc_len: u16) -> Vec { + let total = (size_of::() + size_of::()) as u8; + vec![ + total, + USB_DESC_TYPE_HID, + 0x11, + 0x01, // bcd_hid + 0, // country_code + 1, // num_descriptors + USB_DESC_TYPE_REPORT, + (report_desc_len & 0xFF) as u8, + (report_desc_len >> 8) as u8, + ] + } + + fn endpoint_bytes(address: u8, attributes: u8) -> Vec { + vec![7, 5, address, attributes, 8, 0, 10] + } + + fn config_header_bytes(total_length: u16) -> Vec { + vec![9, USB_DESC_TYPE_CONFIG, (total_length & 0xFF) as u8, (total_length >> 8) as u8, 1, 1, 0, 0x80, 50] + } + + /// Builds a complete config descriptor buffer with correct total_length header. + fn build_config_buffer(descs: &[&[u8]]) -> Vec { + let payload_len: usize = descs.iter().map(|d| d.len()).sum(); + let total_length = (9 + payload_len) as u16; + let mut buffer = config_header_bytes(total_length); + for desc in descs { + buffer.extend_from_slice(desc); + } + buffer + } + + fn concat(slices: &[&[u8]]) -> Vec { + slices.iter().flat_map(|s| s.iter().copied()).collect() + } + + fn make_interface(number: u8, alt: u8, num_endpoints: u8) -> EfiUsbInterfaceDescriptor { + EfiUsbInterfaceDescriptor { + length: 9, + descriptor_type: USB_DESC_TYPE_INTERFACE, + interface_number: number, + alternate_setting: alt, + num_endpoints, + interface_class: CLASS_HID, + ..Default::default() + } + } + + fn make_endpoint(address: u8, attributes: u8) -> EfiUsbEndpointDescriptor { + EfiUsbEndpointDescriptor { + length: 7, + descriptor_type: 5, + endpoint_address: address, + attributes, + max_packet_size: 8, + interval: 10, + } + } + + // ---- HidDescriptor tests ---- + + #[test] + fn hid_descriptor_header_returns_correct_fields() { + let data = hid_bytes(64); + let hid = HidDescriptor { data }; + let header = *hid.header(); + assert_eq!(header.descriptor_type, USB_DESC_TYPE_HID); + let bcd = header.bcd_hid; + assert_eq!(bcd, 0x0111); + assert_eq!(header.country_code, 0); + assert_eq!(header.num_descriptors, 1); + } + + #[test] + fn hid_descriptor_class_descriptors_returns_entries() { + let data = hid_bytes(256); + let hid = HidDescriptor { data }; + let class_descs = hid.class_descriptors(); + assert_eq!(class_descs.len(), 1); + let entry = class_descs[0]; + assert_eq!(entry.descriptor_type, USB_DESC_TYPE_REPORT); + let len = entry.descriptor_length; + assert_eq!(len, 256); + } + + #[test] + fn hid_descriptor_class_descriptors_clamps_inflated_count() { + // Build a HID descriptor that claims 4 class descriptors but only has space for 1. + let mut data = hid_bytes(64); + // Overwrite num_descriptors (offset 5 in the HID descriptor) to claim 4. + data[5] = 4; + let hid = HidDescriptor { data }; + // Should be clamped to the 1 entry that actually fits. + let class_descs = hid.class_descriptors(); + assert_eq!(class_descs.len(), 1); + } + + #[test] + fn find_hid_desc_fails_when_interface_truncated_in_buffer() { + let interface = make_interface(0, 0, 1); + // Buffer has the UsbDescHead for an interface (2 bytes match) but is too short + // for a full EfiUsbInterfaceDescriptor. + let buffer = vec![9, USB_DESC_TYPE_INTERFACE, 0, 0]; + assert_eq!(find_hid_descriptor_in_config(&buffer, &interface).unwrap_err(), efi::Status::UNSUPPORTED); + } + + // ---- find_hid_descriptor_in_config tests ---- + + #[test] + fn find_hid_desc_succeeds_for_matching_interface() { + let interface = make_interface(0, 0, 1); + let buffer = + concat(&[&interface_bytes(0, 0, 1), &hid_bytes(64), &endpoint_bytes(0x81, USB_ENDPOINT_INTERRUPT)]); + let result = find_hid_descriptor_in_config(&buffer, &interface).unwrap(); + let header = *result.header(); + assert_eq!(header.descriptor_type, USB_DESC_TYPE_HID); + assert_eq!(header.num_descriptors, 1); + let entry = result.class_descriptors()[0]; + assert_eq!(entry.descriptor_type, USB_DESC_TYPE_REPORT); + let len = entry.descriptor_length; + assert_eq!(len, 64); + } + + #[test] + fn find_hid_desc_fails_for_wrong_interface_number() { + let interface = make_interface(1, 0, 1); + let buffer = concat(&[&interface_bytes(0, 0, 1), &hid_bytes(64)]); + assert_eq!(find_hid_descriptor_in_config(&buffer, &interface).unwrap_err(), efi::Status::UNSUPPORTED); + } + + #[test] + fn find_hid_desc_fails_for_wrong_alternate_setting() { + let interface = make_interface(0, 1, 1); + let buffer = concat(&[&interface_bytes(0, 0, 1), &hid_bytes(64)]); + assert_eq!(find_hid_descriptor_in_config(&buffer, &interface).unwrap_err(), efi::Status::UNSUPPORTED); + } + + #[test] + fn find_hid_desc_fails_when_non_hid_follows_interface() { + let interface = make_interface(0, 0, 1); + let buffer = concat(&[&interface_bytes(0, 0, 1), &endpoint_bytes(0x81, USB_ENDPOINT_INTERRUPT)]); + assert_eq!(find_hid_descriptor_in_config(&buffer, &interface).unwrap_err(), efi::Status::UNSUPPORTED); + } + + #[test] + fn find_hid_desc_fails_on_empty_buffer() { + let interface = make_interface(0, 0, 1); + assert_eq!(find_hid_descriptor_in_config(&[], &interface).unwrap_err(), efi::Status::UNSUPPORTED); + } + + #[test] + fn find_hid_desc_handles_zero_length_descriptor() { + let interface = make_interface(0, 0, 1); + let buffer = vec![0, USB_DESC_TYPE_INTERFACE]; + assert_eq!(find_hid_descriptor_in_config(&buffer, &interface).unwrap_err(), efi::Status::UNSUPPORTED); + } + + #[test] + fn find_hid_desc_fails_when_buffer_truncated_after_interface() { + let interface = make_interface(0, 0, 1); + let buffer = interface_bytes(0, 0, 1); + assert_eq!(find_hid_descriptor_in_config(&buffer, &interface).unwrap_err(), efi::Status::UNSUPPORTED); + } + + #[test] + fn find_hid_desc_fails_when_hid_length_overflows_buffer() { + let interface = make_interface(0, 0, 1); + // HID descriptor claims length of 9 bytes but buffer only has 4 bytes after the interface. + let mut buffer = interface_bytes(0, 0, 1); + buffer.extend_from_slice(&[9, USB_DESC_TYPE_HID, 0x11, 0x01]); + assert_eq!(find_hid_descriptor_in_config(&buffer, &interface).unwrap_err(), efi::Status::DEVICE_ERROR); + } + + #[test] + fn find_hid_desc_fails_when_hid_descriptor_too_short() { + let interface = make_interface(0, 0, 1); + // HID descriptor with length smaller than the minimum (header + one class descriptor). + let min_size = size_of::() + size_of::(); + let short_len = (min_size - 1) as u8; + let mut buffer = interface_bytes(0, 0, 1); + // Pad to ensure the buffer is long enough that the bounds check passes, + // but the length field is too short for a valid HID descriptor. + let mut hid = vec![short_len, USB_DESC_TYPE_HID]; + hid.resize(short_len as usize, 0); + buffer.extend_from_slice(&hid); + // Pad buffer to avoid the overflow check triggering first. + buffer.resize(buffer.len() + 16, 0); + assert_eq!(find_hid_descriptor_in_config(&buffer, &interface).unwrap_err(), efi::Status::DEVICE_ERROR); + } + + #[test] + fn find_hid_desc_selects_correct_interface_among_multiple() { + let interface = make_interface(1, 0, 1); + let buffer = concat(&[ + &interface_bytes(0, 0, 1), + &hid_bytes(32), + &endpoint_bytes(0x81, USB_ENDPOINT_INTERRUPT), + &interface_bytes(1, 0, 1), + &hid_bytes(128), + &endpoint_bytes(0x82, USB_ENDPOINT_INTERRUPT), + ]); + let result = find_hid_descriptor_in_config(&buffer, &interface).unwrap(); + let entry = result.class_descriptors()[0]; + let len = entry.descriptor_length; + assert_eq!(len, 128); + } + + #[test] + fn find_hid_desc_buffer_too_small_for_header() { + let interface = make_interface(0, 0, 1); + let buffer = vec![9]; // 1 byte, can't fit UsbDescHead (2 bytes) + assert_eq!(find_hid_descriptor_in_config(&buffer, &interface).unwrap_err(), efi::Status::UNSUPPORTED); + } + + // ---- Mock EfiUsbIoProtocol ---- + + /// Test wrapper containing an `EfiUsbIoProtocol` as the first field so that + /// extern "efiapi" mock functions can recover the mock data via the `this` + /// pointer (same containing-record pattern used by production code). + #[repr(C)] + struct MockUsbIo { + protocol: EfiUsbIoProtocol, + interface_desc: EfiUsbInterfaceDescriptor, + interface_status: efi::Status, + endpoints: Vec, + config_desc: EfiUsbConfigDescriptor, + config_status: efi::Status, + config_buffer: Vec, + report_descriptor: Vec, + control_call_count: Cell, + control_statuses: Vec, + } + + impl MockUsbIo { + /// # Safety + /// `this` must point to the `protocol` field of a valid `MockUsbIo`. + unsafe fn from_this(this: *const EfiUsbIoProtocol) -> &'static Self { + // SAFETY: MockUsbIo is #[repr(C)] with protocol as first field. + unsafe { &*(this as *const MockUsbIo) } + } + } + + extern "efiapi" fn mock_get_interface_descriptor( + this: *const EfiUsbIoProtocol, + desc: *mut EfiUsbInterfaceDescriptor, + ) -> efi::Status { + // SAFETY: this points to a valid MockUsbIo on the test stack. + let mock = unsafe { MockUsbIo::from_this(this) }; + if mock.interface_status == efi::Status::SUCCESS { + // SAFETY: desc is a valid output pointer from the caller. + unsafe { + *desc = mock.interface_desc; + } + } + mock.interface_status + } + + extern "efiapi" fn mock_get_endpoint_descriptor( + this: *const EfiUsbIoProtocol, + index: u8, + desc: *mut EfiUsbEndpointDescriptor, + ) -> efi::Status { + // SAFETY: this points to a valid MockUsbIo on the test stack. + let mock = unsafe { MockUsbIo::from_this(this) }; + let idx = index as usize; + if idx < mock.endpoints.len() { + // SAFETY: desc is a valid output pointer from the caller. + unsafe { + *desc = mock.endpoints[idx]; + } + efi::Status::SUCCESS + } else { + efi::Status::INVALID_PARAMETER + } + } + + extern "efiapi" fn mock_get_config_descriptor( + this: *const EfiUsbIoProtocol, + desc: *mut EfiUsbConfigDescriptor, + ) -> efi::Status { + // SAFETY: this points to a valid MockUsbIo on the test stack. + let mock = unsafe { MockUsbIo::from_this(this) }; + if mock.config_status == efi::Status::SUCCESS { + // SAFETY: desc is a valid output pointer from the caller. + unsafe { + *desc = mock.config_desc; + } + } + mock.config_status + } + + extern "efiapi" fn mock_control_transfer( + this: *const EfiUsbIoProtocol, + _request: *const EfiUsbDeviceRequest, + _direction: EfiUsbDataDirection, + _timeout: u32, + data: *mut c_void, + data_length: usize, + _status: *mut u32, + ) -> efi::Status { + // SAFETY: this points to a valid MockUsbIo on the test stack. + let mock = unsafe { MockUsbIo::from_this(this) }; + let call_idx = mock.control_call_count.get(); + mock.control_call_count.set(call_idx + 1); + + let status = mock.control_statuses.get(call_idx).copied().unwrap_or(efi::Status::SUCCESS); + if status != efi::Status::SUCCESS { + return status; + } + + // Call 0: config descriptor read; Call 1: report descriptor read. + let source = if call_idx == 0 { &mock.config_buffer } else { &mock.report_descriptor }; + let copy_len = data_length.min(source.len()); + if copy_len > 0 && !data.is_null() { + // SAFETY: data is a valid buffer of data_length bytes, source is at least copy_len bytes. + unsafe { + core::ptr::copy_nonoverlapping(source.as_ptr(), data as *mut u8, copy_len); + } + } + efi::Status::SUCCESS + } + + fn build_mock( + interface: EfiUsbInterfaceDescriptor, + endpoints: Vec, + config_buffer: Vec, + report_descriptor: Vec, + ) -> MockUsbIo { + let total_length = config_buffer.len() as u16; + let mut protocol = crate::test_stubs::usb_io_stub(); + protocol.usb_control_transfer = mock_control_transfer; + protocol.usb_get_config_descriptor = mock_get_config_descriptor; + protocol.usb_get_interface_descriptor = mock_get_interface_descriptor; + protocol.usb_get_endpoint_descriptor = mock_get_endpoint_descriptor; + MockUsbIo { + protocol, + interface_desc: interface, + interface_status: efi::Status::SUCCESS, + endpoints, + config_desc: EfiUsbConfigDescriptor { + length: 9, + descriptor_type: USB_DESC_TYPE_CONFIG, + total_length, + num_interfaces: 1, + configuration_value: 1, + ..Default::default() + }, + config_status: efi::Status::SUCCESS, + config_buffer, + report_descriptor, + control_call_count: Cell::new(0), + control_statuses: vec![], + } + } + + /// Builds a standard config buffer with one HID interface + interrupt-in endpoint. + fn standard_config_buffer(interface_number: u8, report_desc_len: u16) -> Vec { + build_config_buffer(&[ + &interface_bytes(interface_number, 0, 1), + &hid_bytes(report_desc_len), + &endpoint_bytes(0x81, USB_ENDPOINT_INTERRUPT), + ]) + } + + // ---- read_descriptors integration tests ---- + + #[test] + fn read_descriptors_succeeds() { + let report_data = vec![0x05, 0x01, 0x09, 0x06]; + let mock = build_mock( + make_interface(0, 0, 1), + vec![make_endpoint(0x81, USB_ENDPOINT_INTERRUPT)], + standard_config_buffer(0, report_data.len() as u16), + report_data.clone(), + ); + + let result = read_descriptors(&mock.protocol).unwrap(); + assert_eq!(result.interface_descriptor.interface_class, CLASS_HID); + assert_eq!(result.int_in_endpoint_descriptor.endpoint_address, 0x81); + assert_eq!(result.report_descriptor, report_data); + } + + #[test] + fn read_descriptors_fails_on_interface_error() { + let mut mock = build_mock(make_interface(0, 0, 1), vec![], vec![], vec![]); + mock.interface_status = efi::Status::DEVICE_ERROR; + + assert_eq!(read_descriptors(&mock.protocol).unwrap_err(), efi::Status::DEVICE_ERROR); + } + + #[test] + fn read_descriptors_fails_when_no_interrupt_in_endpoint() { + // Endpoint is interrupt OUT, not IN. + let mock = + build_mock(make_interface(0, 0, 1), vec![make_endpoint(0x01, USB_ENDPOINT_INTERRUPT)], vec![], vec![]); + + assert_eq!(read_descriptors(&mock.protocol).unwrap_err(), efi::Status::DEVICE_ERROR); + } + + #[test] + fn read_descriptors_fails_when_no_endpoints() { + let mock = build_mock(make_interface(0, 0, 0), vec![], vec![], vec![]); + + assert_eq!(read_descriptors(&mock.protocol).unwrap_err(), efi::Status::DEVICE_ERROR); + } + + #[test] + fn read_descriptors_fails_on_config_descriptor_error() { + let mut mock = + build_mock(make_interface(0, 0, 1), vec![make_endpoint(0x81, USB_ENDPOINT_INTERRUPT)], vec![], vec![]); + mock.config_status = efi::Status::DEVICE_ERROR; + + assert_eq!(read_descriptors(&mock.protocol).unwrap_err(), efi::Status::DEVICE_ERROR); + } + + #[test] + fn read_descriptors_fails_on_config_transfer_error() { + let mut mock = build_mock( + make_interface(0, 0, 1), + vec![make_endpoint(0x81, USB_ENDPOINT_INTERRUPT)], + standard_config_buffer(0, 4), + vec![], + ); + mock.control_statuses = vec![efi::Status::DEVICE_ERROR]; + + assert_eq!(read_descriptors(&mock.protocol).unwrap_err(), efi::Status::DEVICE_ERROR); + } + + #[test] + fn read_descriptors_fails_on_report_descriptor_transfer_error() { + let mut mock = build_mock( + make_interface(0, 0, 1), + vec![make_endpoint(0x81, USB_ENDPOINT_INTERRUPT)], + standard_config_buffer(0, 4), + vec![], + ); + // First control transfer (config read) succeeds, second (report read) fails. + mock.control_statuses = vec![efi::Status::SUCCESS, efi::Status::DEVICE_ERROR]; + + assert_eq!(read_descriptors(&mock.protocol).unwrap_err(), efi::Status::DEVICE_ERROR); + } + + #[test] + fn read_descriptors_fails_when_hid_descriptor_not_in_config() { + // Config buffer has interface but no HID descriptor after it. + let config_buffer = + build_config_buffer(&[&interface_bytes(0, 0, 1), &endpoint_bytes(0x81, USB_ENDPOINT_INTERRUPT)]); + let mock = build_mock( + make_interface(0, 0, 1), + vec![make_endpoint(0x81, USB_ENDPOINT_INTERRUPT)], + config_buffer, + vec![], + ); + + assert_eq!(read_descriptors(&mock.protocol).unwrap_err(), efi::Status::UNSUPPORTED); + } + + #[test] + fn read_descriptors_skips_non_interrupt_endpoints() { + // First endpoint is bulk OUT, second is interrupt IN. + let mock = build_mock( + make_interface(0, 0, 2), + vec![ + make_endpoint(0x02, 0x02), // bulk OUT + make_endpoint(0x81, USB_ENDPOINT_INTERRUPT), // interrupt IN + ], + standard_config_buffer(0, 4), + vec![0x05, 0x01, 0x09, 0x06], + ); + + let result = read_descriptors(&mock.protocol).unwrap(); + assert_eq!(result.int_in_endpoint_descriptor.endpoint_address, 0x81); + } +} diff --git a/usb_hid/src/device.rs b/usb_hid/src/device.rs new file mode 100644 index 0000000..aa548e7 --- /dev/null +++ b/usb_hid/src/device.rs @@ -0,0 +1,106 @@ +//! Per-device state for USB HID devices. +//! +//! ## License +//! +//! Copyright (c) Microsoft Corporation. +//! +//! SPDX-License-Identifier: Apache-2.0 +//! +use alloc::vec::Vec; +use core::ffi::c_void; + +use r_efi::efi; + +use patina::vendor_protocols::hid_io::{HidIoProtocol, HidIoReportCallback}; + +use patina::uefi_protocol::usb_io::{ + EfiUsbIoProtocol, + types::{EfiUsbEndpointDescriptor, EfiUsbInterfaceDescriptor}, +}; + +use crate::interrupt_transfers::TransferRecoveryTimer; + +/// USB HID descriptor set read from the device during initialization. +#[derive(Debug)] +pub struct UsbHidDescriptors { + pub interface_descriptor: EfiUsbInterfaceDescriptor, + pub int_in_endpoint_descriptor: EfiUsbEndpointDescriptor, + pub report_descriptor: Vec, +} + +/// Registered callback state for asynchronous input report notifications. +#[derive(Default)] +pub struct ReportCallbackState { + pub callback: Option, + pub context: *mut c_void, +} + +/// Per-device context for a USB HID device managed by this driver. +/// +/// Allocated on the heap during `driver_binding_start` and freed during +/// `driver_binding_stop`. The `hid_io` field is installed as a protocol +/// interface on the controller handle. +#[repr(C)] +pub struct UsbHidDevice { + // Note: a direct cast is used to recover the UsbHidDevice pointer from the HidIoProtocol pointer, so hid_io must be + // the first field. + pub hid_io: HidIoProtocol, + pub usb_io: *const EfiUsbIoProtocol, + pub descriptors: UsbHidDescriptors, + pub report_callback: ReportCallbackState, + /// Boot services timer interface for delayed error recovery. + pub(crate) timer_services: &'static dyn TransferRecoveryTimer, + /// Timer event armed by the interrupt callback on transfer errors. The event's + /// notify function re-submits the async interrupt transfer after a delay. + pub(crate) recovery_event: efi::Event, +} + +impl UsbHidDevice { + /// Recovers a raw pointer to the `UsbHidDevice` from a pointer to its `hid_io` field. + /// + /// This is a pure pointer cast (no dereference) and is therefore safe. The + /// caller is responsible for ensuring the returned pointer is valid before + /// dereferencing it — `hid_io_ptr` must point to the `hid_io` field of a + /// heap-allocated `UsbHidDevice`, and the `hid_io` field must be the first + /// field in the `#[repr(C)]` layout. + pub fn from_hid_io_protocol(hid_io_ptr: *const HidIoProtocol) -> *mut Self { + hid_io_ptr as *mut UsbHidDevice + } +} + +#[cfg(test)] +mod test { + use super::*; + use crate::hid_io_impl; + + #[test] + fn from_hid_io_protocol_recovers_device() { + struct NoopTimer; + impl crate::interrupt_transfers::TransferRecoveryTimer for NoopTimer { + fn arm_recovery_timer(&self, _: efi::Event, _: u64) -> Result<(), efi::Status> { + Ok(()) + } + } + static NOOP: NoopTimer = NoopTimer; + + let device = Box::new(UsbHidDevice { + hid_io: hid_io_impl::new_hid_io_protocol(), + usb_io: core::ptr::null(), + descriptors: UsbHidDescriptors { + interface_descriptor: EfiUsbInterfaceDescriptor::default(), + int_in_endpoint_descriptor: EfiUsbEndpointDescriptor::default(), + report_descriptor: Vec::new(), + }, + report_callback: ReportCallbackState::default(), + timer_services: &NOOP, + recovery_event: core::ptr::null_mut(), + }); + + let hid_io_ptr = &device.hid_io as *const HidIoProtocol; + let recovered = UsbHidDevice::from_hid_io_protocol(hid_io_ptr); + assert_eq!(recovered as usize, &*device as *const _ as usize); + + // Prevent drop from double-freeing the leaked box. + core::mem::forget(device); + } +} diff --git a/usb_hid/src/driver.rs b/usb_hid/src/driver.rs new file mode 100644 index 0000000..6e81bb9 --- /dev/null +++ b/usb_hid/src/driver.rs @@ -0,0 +1,261 @@ +//! USB HID driver binding implementation. +//! +//! The [`UsbHidDriver`] implements [`patina::driver_binding::DriverBinding`] to +//! manage USB HID devices. It consumes USB IO and produces HidIo. +//! +//! ## License +//! +//! Copyright (c) Microsoft Corporation. +//! +//! SPDX-License-Identifier: Apache-2.0 +//! + +use alloc::boxed::Box; +use core::{ffi::c_void, ptr::NonNull}; + +use r_efi::{efi, protocols::device_path::Protocol as EfiDevicePathProtocol}; + +use patina::{boot_services::BootServices, driver_binding::DriverBinding}; + +use patina::{ + uefi_protocol::usb_io::{EfiUsbIoProtocol, USB_IO_PROTOCOL_GUID, types::*}, + vendor_protocols::hid_io, +}; + +use crate::{control_transfers, descriptors, device::UsbHidDevice, hid_io_impl, interrupt_transfers, usb_hid_defs::*}; +use patina::boot_services::event::EventTimerType; + +/// USB HID driver that implements [`DriverBinding`]. +pub struct UsbHidDriver { + agent: efi::Handle, +} + +impl UsbHidDriver { + /// Creates a new USB HID driver bound to the given agent handle. + pub fn new(agent: efi::Handle) -> Self { + Self { agent } + } +} + +/// Checks whether the controller has USB IO protocol with HID interface class. +fn is_usb_hid(usb_io: &EfiUsbIoProtocol) -> bool { + let mut interface_descriptor = EfiUsbInterfaceDescriptor::default(); + // SAFETY: usb_io and interface_descriptor are valid. + let status = + unsafe { (usb_io.usb_get_interface_descriptor)(usb_io as *const EfiUsbIoProtocol, &mut interface_descriptor) }; + if status != efi::Status::SUCCESS { + return false; + } + + interface_descriptor.interface_class == CLASS_HID +} + +// efi::Handle is an opaque *mut c_void that is never actually dereferenced as a pointer. +#[allow(clippy::not_unsafe_ptr_arg_deref)] +impl DriverBinding for UsbHidDriver { + /// Tests if the given controller has USB IO with HID interface class. + #[coverage(off)] + fn driver_binding_supported( + &self, + boot_services: &'static U, + controller: efi::Handle, + _remaining_device_path: Option>, + ) -> Result { + // SAFETY: EfiUsbIoProtocol layout matches the USB IO GUID. + let usb_io = match unsafe { + boot_services.open_protocol::( + controller, + self.agent, + controller, + efi::OPEN_PROTOCOL_BY_DRIVER, + ) + } { + Ok(usb_io) => usb_io, + Err(_) => return Ok(false), + }; + + let result = is_usb_hid(usb_io); + + boot_services.close_protocol(controller, USB_IO_PROTOCOL_GUID.as_efi_guid(), self.agent, controller).ok(); + + Ok(result) + } + + /// Starts USB HID support for the given controller. + fn driver_binding_start( + &mut self, + boot_services: &'static U, + controller: efi::Handle, + _remaining_device_path: Option>, + ) -> Result<(), efi::Status> { + log::trace!("USB HID: driver_binding_start on controller {:?}", controller); + + // Open USB IO BY_DRIVER for exclusive access. + // SAFETY: EfiUsbIoProtocol layout matches the USB IO GUID. + let usb_io = unsafe { + boot_services.open_protocol::( + controller, + self.agent, + controller, + efi::OPEN_PROTOCOL_BY_DRIVER, + ) + }?; + + // Read descriptors from the device. + let descriptors = match descriptors::read_descriptors(usb_io) { + Ok(d) => d, + Err(status) => { + log::error!("USB HID: failed to read descriptors: {status:x?}"); + self.close_usb_io(boot_services, controller); + return Err(status); + } + }; + + // Boot devices: explicitly set report protocol mode. + if descriptors.interface_descriptor.interface_sub_class == SUBCLASS_BOOT + && let Err(status) = control_transfers::set_protocol_request( + usb_io, + descriptors.interface_descriptor.interface_number, + REPORT_PROTOCOL, + ) + { + log::warn!("USB HID: failed to set report protocol: {status:x?}"); + } + + // Build the device context and leak it for UEFI protocol ownership. + let device_ptr = Box::into_raw(Box::new(UsbHidDevice { + hid_io: hid_io_impl::new_hid_io_protocol(), + usb_io: usb_io as *const EfiUsbIoProtocol, + descriptors, + report_callback: crate::device::ReportCallbackState::default(), + timer_services: boot_services as &'static dyn interrupt_transfers::TransferRecoveryTimer, + recovery_event: core::ptr::null_mut(), + })); + + // Create a recovery timer event for delayed re-submission on transfer errors. + // SAFETY: device_ptr is a valid heap-allocated UsbHidDevice that will outlive + // the event (closed in stop before the device is freed). + match unsafe { interrupt_transfers::create_recovery_event(boot_services, device_ptr) } { + Ok(event) => { + // SAFETY: device_ptr is valid; setting the pre-allocated field. + unsafe { (*device_ptr).recovery_event = event }; + } + Err(status) => { + log::error!("USB HID: failed to create recovery event: {status:x?}"); + // SAFETY: Reclaiming the Box we leaked above. + drop(unsafe { Box::from_raw(device_ptr) }); + self.close_usb_io(boot_services, controller); + return Err(status); + } + } + + // Install HidIo protocol on the controller. + // hid_io is the first field in #[repr(C)] UsbHidDevice, so device_ptr == hid_io_ptr. + // SAFETY: Installing HidIo protocol interface on the controller handle. + if let Err(status) = unsafe { + boot_services.install_protocol_interface_unchecked( + Some(controller), + &hid_io::HID_IO_PROTOCOL_GUID, + device_ptr as *mut c_void, + ) + } { + log::error!("USB HID: failed to install HidIo protocol: {status:x?}"); + // SAFETY: Reclaiming the Box we leaked above; close recovery event first. + let device = unsafe { Box::from_raw(device_ptr) }; + let _ = boot_services.close_event(device.recovery_event); + drop(device); + self.close_usb_io(boot_services, controller); + return Err(status); + } + + Ok(()) + } + + /// Stops USB HID support for the given controller. + fn driver_binding_stop( + &mut self, + boot_services: &'static U, + controller: efi::Handle, + _number_of_children: usize, + _child_handle_buffer: Option>, + ) -> Result<(), efi::Status> { + log::trace!("USB HID: driver_binding_stop on controller {:?}", controller); + + // Retrieve the HidIo protocol to recover the device. + // SAFETY: HidIo protocol was installed by start. + let hid_io_ptr = unsafe { + boot_services.open_protocol_unchecked( + controller, + &hid_io::HID_IO_PROTOCOL_GUID, + self.agent, + controller, + efi::OPEN_PROTOCOL_GET_PROTOCOL, + ) + }? as *const hid_io::HidIoProtocol; + + // SAFETY: hid_io_ptr points into a valid heap-allocated UsbHidDevice. + let device = unsafe { &mut *UsbHidDevice::from_hid_io_protocol(hid_io_ptr) }; + + // Shutdown async transfers first so that no further interrupt callbacks can + // fire and attempt to arm the recovery timer after we close it. + let _ = interrupt_transfers::shutdown_async_interrupt_input_transfers(device); + + // Now that no callbacks can fire, cancel and close the recovery timer event. + let _ = boot_services.set_timer(device.recovery_event, EventTimerType::Cancel, 0); + let _ = boot_services.close_event(device.recovery_event); + + // Uninstall HidIo protocol. + // SAFETY: Uninstalling the HidIo protocol interface. + if let Err(status) = unsafe { + boot_services.uninstall_protocol_interface_unchecked( + controller, + &hid_io::HID_IO_PROTOCOL_GUID, + hid_io_ptr as *mut c_void, + ) + } { + log::error!("USB HID: failed to uninstall HidIo: {status:x?}"); + return Err(status); + } + + // SAFETY: device was created via Box::into_raw in start. + drop(unsafe { Box::from_raw(device as *mut UsbHidDevice) }); + + // Close USB IO protocol. + self.close_usb_io(boot_services, controller); + + Ok(()) + } +} + +impl UsbHidDriver { + fn close_usb_io(&self, boot_services: &impl BootServices, controller: efi::Handle) { + if let Err(status) = + boot_services.close_protocol(controller, USB_IO_PROTOCOL_GUID.as_efi_guid(), self.agent, controller) + { + log::error!("USB HID: error closing USB IO protocol: {status:x?}"); + } + } +} + +#[cfg(test)] +mod test { + use super::*; + use patina::boot_services::MockBootServices; + + fn mock_boot_services() -> &'static mut MockBootServices { + let mut mock = MockBootServices::new(); + mock.expect_raise_tpl().returning(|_| patina::boot_services::tpl::Tpl::APPLICATION); + mock.expect_restore_tpl().returning(|_| ()); + // SAFETY: Leaked mock for test use with 'static lifetime requirement. + unsafe { Box::into_raw(Box::new(mock)).as_mut().unwrap() } + } + + #[test] + fn supported_returns_false_when_no_usb_io() { + let boot_services = mock_boot_services(); + boot_services.expect_open_protocol::().returning(|_, _, _, _| Err(efi::Status::NOT_FOUND)); + + let driver = UsbHidDriver::new(0x1 as efi::Handle); + assert_eq!(driver.driver_binding_supported(boot_services, 0x2 as efi::Handle, None), Ok(false)); + } +} diff --git a/usb_hid/src/hid_io_impl.rs b/usb_hid/src/hid_io_impl.rs new file mode 100644 index 0000000..732cca4 --- /dev/null +++ b/usb_hid/src/hid_io_impl.rs @@ -0,0 +1,768 @@ +//! HidIoProtocol function pointer implementations for USB HID devices. +//! +//! Each function recovers the `UsbHidDevice` from the protocol pointer, +//! then delegates to USB IO operations. +//! +//! ## License +//! +//! Copyright (c) Microsoft Corporation. +//! +//! SPDX-License-Identifier: Apache-2.0 +//! +use core::ffi::c_void; + +use r_efi::efi; + +use patina::vendor_protocols::hid_io::{HidIoProtocol, HidIoReportCallback, HidReportType}; + +use crate::{control_transfers, device::UsbHidDevice, interrupt_transfers}; + +/// Creates a new `HidIoProtocol` populated with this module's function pointers. +pub fn new_hid_io_protocol() -> HidIoProtocol { + HidIoProtocol { + get_report_descriptor: hid_get_report_descriptor, + get_report: hid_get_report, + set_report: hid_set_report, + register_report_callback: hid_register_report_callback, + unregister_report_callback: hid_unregister_report_callback, + } +} + +/// Retrieves the HID report descriptor from the device. +/// +/// # Safety +/// +/// `this` must point to the `hid_io` field of a valid, heap-allocated +/// [`UsbHidDevice`]. `report_descriptor_size` must be a valid pointer. +/// If the buffer is large enough, `report_descriptor_buffer` must be valid +/// for `*report_descriptor_size` bytes. +unsafe extern "efiapi" fn hid_get_report_descriptor( + this: *const HidIoProtocol, + report_descriptor_size: *mut usize, + report_descriptor_buffer: *mut c_void, +) -> efi::Status { + if this.is_null() || report_descriptor_size.is_null() { + return efi::Status::INVALID_PARAMETER; + } + + // SAFETY: this points to the hid_io field of a valid UsbHidDevice. + let device = unsafe { &mut *UsbHidDevice::from_hid_io_protocol(this) }; + + if device.descriptors.report_descriptor.is_empty() { + return efi::Status::NOT_FOUND; + } + + // SAFETY: report_descriptor_size is checked non-null above. + let requested_size = unsafe { *report_descriptor_size }; + if requested_size < device.descriptors.report_descriptor.len() { + // SAFETY: report_descriptor_size is checked non-null above. + unsafe { *report_descriptor_size = device.descriptors.report_descriptor.len() }; + return efi::Status::BUFFER_TOO_SMALL; + } + + if report_descriptor_buffer.is_null() { + return efi::Status::INVALID_PARAMETER; + } + + // SAFETY: Both pointers are valid for report_descriptor.len() bytes. + unsafe { + *report_descriptor_size = device.descriptors.report_descriptor.len(); + core::ptr::copy_nonoverlapping( + device.descriptors.report_descriptor.as_ptr(), + report_descriptor_buffer as *mut u8, + device.descriptors.report_descriptor.len(), + ); + } + + efi::Status::SUCCESS +} + +/// Retrieves a single report from the device via USB GET_REPORT request. +/// +/// # Safety +/// +/// `this` must point to the `hid_io` field of a valid, heap-allocated +/// [`UsbHidDevice`]. `report_buffer` must be valid for `report_buffer_size` +/// bytes. +unsafe extern "efiapi" fn hid_get_report( + this: *const HidIoProtocol, + report_id: u8, + report_type: HidReportType, + report_buffer_size: usize, + report_buffer: *mut c_void, +) -> efi::Status { + if this.is_null() || report_buffer_size == 0 || report_buffer_size > u16::MAX as usize || report_buffer.is_null() { + return efi::Status::INVALID_PARAMETER; + } + + // Only support Get_Report for Input or Feature reports. + if report_type != HidReportType::InputReport && report_type != HidReportType::Feature { + return efi::Status::INVALID_PARAMETER; + } + + // SAFETY: this points to the hid_io field of a valid UsbHidDevice. + let device = unsafe { &*UsbHidDevice::from_hid_io_protocol(this) }; + // SAFETY: usb_io is valid for the device's lifetime. + let usb_io = unsafe { &*device.usb_io }; + + match control_transfers::usb_get_report_request( + usb_io, + device.descriptors.interface_descriptor.interface_number, + report_id, + report_type as u8, + report_buffer_size as u16, + report_buffer as *mut u8, + ) { + Ok(()) => efi::Status::SUCCESS, + Err(status) => status, + } +} + +/// Sends a single report to the device via USB SET_REPORT request. +/// +/// # Safety +/// +/// `this` must point to the `hid_io` field of a valid, heap-allocated +/// [`UsbHidDevice`]. `report_buffer` must be valid for `report_buffer_size` +/// bytes. +unsafe extern "efiapi" fn hid_set_report( + this: *const HidIoProtocol, + report_id: u8, + report_type: HidReportType, + report_buffer_size: usize, + report_buffer: *mut c_void, +) -> efi::Status { + if this.is_null() || report_buffer_size == 0 || report_buffer_size > u16::MAX as usize || report_buffer.is_null() { + return efi::Status::INVALID_PARAMETER; + } + + // Only support Set_Report for Output or Feature reports. + if report_type != HidReportType::OutputReport && report_type != HidReportType::Feature { + return efi::Status::INVALID_PARAMETER; + } + + // SAFETY: this points to the hid_io field of a valid UsbHidDevice. + let device = unsafe { &*UsbHidDevice::from_hid_io_protocol(this) }; + // SAFETY: usb_io is valid for the device's lifetime. + let usb_io = unsafe { &*device.usb_io }; + + match control_transfers::usb_set_report_request( + usb_io, + device.descriptors.interface_descriptor.interface_number, + report_id, + report_type as u8, + report_buffer_size as u16, + report_buffer as *const u8, + ) { + Ok(()) => efi::Status::SUCCESS, + Err(status) => status, + } +} + +/// Registers a callback function to receive asynchronous input reports. +/// +/// # Safety +/// +/// `this` must point to the `hid_io` field of a valid, heap-allocated +/// [`UsbHidDevice`]. `context` must remain valid for the lifetime of the +/// registration (until the callback is unregistered). +unsafe extern "efiapi" fn hid_register_report_callback( + this: *const HidIoProtocol, + callback: HidIoReportCallback, + context: *mut c_void, +) -> efi::Status { + if this.is_null() { + return efi::Status::INVALID_PARAMETER; + } + + // SAFETY: this points to the hid_io field of a valid UsbHidDevice. + let device = unsafe { &mut *UsbHidDevice::from_hid_io_protocol(this) }; + + if device.report_callback.callback.is_some() { + return efi::Status::ALREADY_STARTED; + } + + device.report_callback.callback = Some(callback); + device.report_callback.context = context; + + match interrupt_transfers::initiate_async_interrupt_input_transfers(device) { + Ok(()) => efi::Status::SUCCESS, + Err(status) => { + device.report_callback.callback = None; + status + } + } +} + +/// Unregisters a previously registered callback function. +/// +/// # Safety +/// +/// `this` must point to the `hid_io` field of a valid, heap-allocated +/// [`UsbHidDevice`]. +unsafe extern "efiapi" fn hid_unregister_report_callback( + this: *const HidIoProtocol, + callback: HidIoReportCallback, +) -> efi::Status { + if this.is_null() { + return efi::Status::INVALID_PARAMETER; + } + + // SAFETY: this points to the hid_io field of a valid UsbHidDevice. + let device = unsafe { &mut *UsbHidDevice::from_hid_io_protocol(this) }; + + // Verify the callback matches the registered one. + match device.report_callback.callback { + Some(registered) if core::ptr::fn_addr_eq(registered, callback) => {} + _ => return efi::Status::NOT_STARTED, + } + + match interrupt_transfers::shutdown_async_interrupt_input_transfers(device) { + Ok(()) => {} + Err(status) => { + log::error!("USB HID: error shutting down transfers during unregister: {status:x?}"); + } + } + + device.report_callback.callback = None; + + efi::Status::SUCCESS +} + +#[cfg(test)] +mod test { + use super::*; + use alloc::{boxed::Box, vec, vec::Vec}; + use core::cell::Cell; + + use patina::uefi_protocol::usb_io::{EfiAsyncUsbTransferCallback, EfiUsbIoProtocol, types::*}; + + use crate::{ + device::{ReportCallbackState, UsbHidDescriptors, UsbHidDevice}, + interrupt_transfers::TransferRecoveryTimer, + }; + + struct NoopTransferRecoveryTimer; + impl TransferRecoveryTimer for NoopTransferRecoveryTimer { + fn arm_recovery_timer(&self, _event: efi::Event, _delay: u64) -> Result<(), efi::Status> { + Ok(()) + } + } + static NOOP_RECOVERY_TIMER: NoopTransferRecoveryTimer = NoopTransferRecoveryTimer; + + // ---- Mock USB IO ---- + + /// Mock USB IO context. `protocol` must be the first field so the extern + /// mock functions can recover the mock state from the `this` pointer. + #[repr(C)] + struct MockUsbIo { + protocol: EfiUsbIoProtocol, + control_transfer_status: efi::Status, + async_transfer_status: efi::Status, + control_call_count: Cell, + } + + impl MockUsbIo { + /// # Safety + /// `this` must point to the `protocol` field of a valid `MockUsbIo`. + unsafe fn from_this(this: *const EfiUsbIoProtocol) -> &'static Self { + // SAFETY: MockUsbIo is #[repr(C)] with protocol as first field. + unsafe { &*(this as *const MockUsbIo) } + } + } + + extern "efiapi" fn mock_control_transfer( + this: *const EfiUsbIoProtocol, + _request: *const EfiUsbDeviceRequest, + _direction: EfiUsbDataDirection, + _timeout: u32, + _data: *mut c_void, + _data_length: usize, + _status: *mut u32, + ) -> efi::Status { + // SAFETY: this points to a valid MockUsbIo on the test stack. + let mock = unsafe { MockUsbIo::from_this(this) }; + mock.control_call_count.set(mock.control_call_count.get() + 1); + mock.control_transfer_status + } + + extern "efiapi" fn mock_async_interrupt_transfer( + this: *const EfiUsbIoProtocol, + _endpoint: u8, + _is_new_transfer: efi::Boolean, + _polling_interval: usize, + _data_length: usize, + _callback: Option, + _context: *mut c_void, + ) -> efi::Status { + // SAFETY: this points to a valid MockUsbIo on the test stack. + let mock = unsafe { MockUsbIo::from_this(this) }; + mock.async_transfer_status + } + + fn make_mock_usb_io(control_status: efi::Status, async_status: efi::Status) -> MockUsbIo { + let mut protocol = crate::test_stubs::usb_io_stub(); + protocol.usb_control_transfer = mock_control_transfer; + protocol.usb_async_interrupt_transfer = mock_async_interrupt_transfer; + MockUsbIo { + protocol, + control_transfer_status: control_status, + async_transfer_status: async_status, + control_call_count: Cell::new(0), + } + } + + // ---- Test device builder ---- + + fn make_device(usb_io: &MockUsbIo, report_descriptor: Vec) -> Box { + Box::new(UsbHidDevice { + hid_io: new_hid_io_protocol(), + usb_io: &usb_io.protocol as *const EfiUsbIoProtocol, + descriptors: UsbHidDescriptors { + interface_descriptor: EfiUsbInterfaceDescriptor::default(), + int_in_endpoint_descriptor: EfiUsbEndpointDescriptor { + endpoint_address: 0x81, + interval: 10, + max_packet_size: 8, + ..Default::default() + }, + report_descriptor, + }, + report_callback: ReportCallbackState::default(), + timer_services: &NOOP_RECOVERY_TIMER, + recovery_event: core::ptr::null_mut(), + }) + } + + unsafe extern "efiapi" fn test_callback( + _report_buffer_size: u16, + _report_buffer: *mut c_void, + _context: *mut c_void, + ) { + } + + unsafe extern "efiapi" fn other_callback( + _report_buffer_size: u16, + _report_buffer: *mut c_void, + _context: *mut c_void, + ) { + } + + // ---- get_report_descriptor tests ---- + + #[test] + fn get_report_descriptor_null_this_returns_invalid_parameter() { + let mut size = 64usize; + let mut buffer = [0u8; 64]; + // SAFETY: testing null this handling. + let status = + unsafe { hid_get_report_descriptor(core::ptr::null(), &mut size, buffer.as_mut_ptr() as *mut c_void) }; + assert_eq!(status, efi::Status::INVALID_PARAMETER); + } + + #[test] + fn get_report_descriptor_null_size_returns_invalid_parameter() { + let mock_usb = make_mock_usb_io(efi::Status::SUCCESS, efi::Status::SUCCESS); + let device = make_device(&mock_usb, vec![0x05, 0x01]); + let hid_io_ptr = &device.hid_io as *const HidIoProtocol; + let mut buffer = [0u8; 64]; + // SAFETY: device is a valid UsbHidDevice; testing null size handling. + let status = + unsafe { hid_get_report_descriptor(hid_io_ptr, core::ptr::null_mut(), buffer.as_mut_ptr() as *mut c_void) }; + assert_eq!(status, efi::Status::INVALID_PARAMETER); + core::mem::forget(device); + } + + #[test] + fn get_report_descriptor_empty_descriptor_returns_not_found() { + let mock_usb = make_mock_usb_io(efi::Status::SUCCESS, efi::Status::SUCCESS); + let device = make_device(&mock_usb, vec![]); + let hid_io_ptr = &device.hid_io as *const HidIoProtocol; + let mut size = 64usize; + let mut buffer = [0u8; 64]; + // SAFETY: device is a valid UsbHidDevice. + let status = unsafe { hid_get_report_descriptor(hid_io_ptr, &mut size, buffer.as_mut_ptr() as *mut c_void) }; + assert_eq!(status, efi::Status::NOT_FOUND); + core::mem::forget(device); + } + + #[test] + fn get_report_descriptor_buffer_too_small_updates_size() { + let descriptor = vec![0x05, 0x01, 0x09, 0x06, 0xA1, 0x01]; + let mock_usb = make_mock_usb_io(efi::Status::SUCCESS, efi::Status::SUCCESS); + let device = make_device(&mock_usb, descriptor.clone()); + let hid_io_ptr = &device.hid_io as *const HidIoProtocol; + let mut size = 2usize; // smaller than descriptor + let mut buffer = [0u8; 2]; + // SAFETY: device is a valid UsbHidDevice. + let status = unsafe { hid_get_report_descriptor(hid_io_ptr, &mut size, buffer.as_mut_ptr() as *mut c_void) }; + assert_eq!(status, efi::Status::BUFFER_TOO_SMALL); + assert_eq!(size, descriptor.len()); + core::mem::forget(device); + } + + #[test] + fn get_report_descriptor_null_buffer_with_sufficient_size_returns_invalid_parameter() { + let descriptor = vec![0x05, 0x01]; + let mock_usb = make_mock_usb_io(efi::Status::SUCCESS, efi::Status::SUCCESS); + let device = make_device(&mock_usb, descriptor); + let hid_io_ptr = &device.hid_io as *const HidIoProtocol; + let mut size = 64usize; + // SAFETY: device is a valid UsbHidDevice; testing null buffer handling. + let status = unsafe { hid_get_report_descriptor(hid_io_ptr, &mut size, core::ptr::null_mut()) }; + assert_eq!(status, efi::Status::INVALID_PARAMETER); + core::mem::forget(device); + } + + #[test] + fn get_report_descriptor_succeeds() { + let descriptor = vec![0x05, 0x01, 0x09, 0x06]; + let mock_usb = make_mock_usb_io(efi::Status::SUCCESS, efi::Status::SUCCESS); + let device = make_device(&mock_usb, descriptor.clone()); + let hid_io_ptr = &device.hid_io as *const HidIoProtocol; + let mut size = 64usize; + let mut buffer = [0u8; 64]; + // SAFETY: device is a valid UsbHidDevice; buffer is large enough. + let status = unsafe { hid_get_report_descriptor(hid_io_ptr, &mut size, buffer.as_mut_ptr() as *mut c_void) }; + assert_eq!(status, efi::Status::SUCCESS); + assert_eq!(size, descriptor.len(), "size should be updated to actual descriptor length on success"); + assert_eq!(&buffer[..descriptor.len()], &descriptor); + core::mem::forget(device); + } + + #[test] + fn get_report_descriptor_exact_size_succeeds() { + let descriptor = vec![0xAA, 0xBB, 0xCC]; + let mock_usb = make_mock_usb_io(efi::Status::SUCCESS, efi::Status::SUCCESS); + let device = make_device(&mock_usb, descriptor.clone()); + let hid_io_ptr = &device.hid_io as *const HidIoProtocol; + let mut size = descriptor.len(); + let mut buffer = vec![0u8; descriptor.len()]; + // SAFETY: device is a valid UsbHidDevice; buffer is exact size. + let status = unsafe { hid_get_report_descriptor(hid_io_ptr, &mut size, buffer.as_mut_ptr() as *mut c_void) }; + assert_eq!(status, efi::Status::SUCCESS); + assert_eq!(buffer, descriptor); + core::mem::forget(device); + } + + // ---- get_report tests ---- + + #[test] + fn get_report_null_this_returns_invalid_parameter() { + let mut buffer = [0u8; 8]; + // SAFETY: testing null this handling. + let status = unsafe { + hid_get_report(core::ptr::null(), 0, HidReportType::InputReport, 8, buffer.as_mut_ptr() as *mut c_void) + }; + assert_eq!(status, efi::Status::INVALID_PARAMETER); + } + + #[test] + fn get_report_null_buffer_returns_invalid_parameter() { + let mock_usb = make_mock_usb_io(efi::Status::SUCCESS, efi::Status::SUCCESS); + let device = make_device(&mock_usb, vec![0x05]); + let hid_io_ptr = &device.hid_io as *const HidIoProtocol; + // SAFETY: device is a valid UsbHidDevice; testing null buffer. + let status = unsafe { hid_get_report(hid_io_ptr, 0, HidReportType::InputReport, 8, core::ptr::null_mut()) }; + assert_eq!(status, efi::Status::INVALID_PARAMETER); + core::mem::forget(device); + } + + #[test] + fn get_report_zero_size_returns_invalid_parameter() { + let mock_usb = make_mock_usb_io(efi::Status::SUCCESS, efi::Status::SUCCESS); + let device = make_device(&mock_usb, vec![0x05]); + let hid_io_ptr = &device.hid_io as *const HidIoProtocol; + let mut buffer = [0u8; 8]; + // SAFETY: device is a valid UsbHidDevice; testing zero size. + let status = + unsafe { hid_get_report(hid_io_ptr, 0, HidReportType::InputReport, 0, buffer.as_mut_ptr() as *mut c_void) }; + assert_eq!(status, efi::Status::INVALID_PARAMETER); + core::mem::forget(device); + } + + #[test] + fn get_report_size_exceeds_u16_max_returns_invalid_parameter() { + let mock_usb = make_mock_usb_io(efi::Status::SUCCESS, efi::Status::SUCCESS); + let device = make_device(&mock_usb, vec![0x05]); + let hid_io_ptr = &device.hid_io as *const HidIoProtocol; + let mut buffer = [0u8; 8]; + // SAFETY: device is a valid UsbHidDevice; testing oversized request. + let status = unsafe { + hid_get_report( + hid_io_ptr, + 0, + HidReportType::InputReport, + u16::MAX as usize + 1, + buffer.as_mut_ptr() as *mut c_void, + ) + }; + assert_eq!(status, efi::Status::INVALID_PARAMETER); + core::mem::forget(device); + } + + #[test] + fn get_report_output_type_returns_invalid_parameter() { + let mock_usb = make_mock_usb_io(efi::Status::SUCCESS, efi::Status::SUCCESS); + let device = make_device(&mock_usb, vec![0x05]); + let hid_io_ptr = &device.hid_io as *const HidIoProtocol; + let mut buffer = [0u8; 8]; + // SAFETY: device is a valid UsbHidDevice; testing invalid report type. + let status = unsafe { + hid_get_report(hid_io_ptr, 0, HidReportType::OutputReport, 8, buffer.as_mut_ptr() as *mut c_void) + }; + assert_eq!(status, efi::Status::INVALID_PARAMETER); + core::mem::forget(device); + } + + #[test] + fn get_report_input_type_succeeds() { + let mock_usb = make_mock_usb_io(efi::Status::SUCCESS, efi::Status::SUCCESS); + let device = make_device(&mock_usb, vec![0x05]); + let hid_io_ptr = &device.hid_io as *const HidIoProtocol; + let mut buffer = [0u8; 8]; + // SAFETY: device is a valid UsbHidDevice. + let status = + unsafe { hid_get_report(hid_io_ptr, 1, HidReportType::InputReport, 8, buffer.as_mut_ptr() as *mut c_void) }; + assert_eq!(status, efi::Status::SUCCESS); + assert_eq!(mock_usb.control_call_count.get(), 1); + core::mem::forget(device); + } + + #[test] + fn get_report_feature_type_succeeds() { + let mock_usb = make_mock_usb_io(efi::Status::SUCCESS, efi::Status::SUCCESS); + let device = make_device(&mock_usb, vec![0x05]); + let hid_io_ptr = &device.hid_io as *const HidIoProtocol; + let mut buffer = [0u8; 8]; + // SAFETY: device is a valid UsbHidDevice. + let status = + unsafe { hid_get_report(hid_io_ptr, 1, HidReportType::Feature, 8, buffer.as_mut_ptr() as *mut c_void) }; + assert_eq!(status, efi::Status::SUCCESS); + assert_eq!(mock_usb.control_call_count.get(), 1); + core::mem::forget(device); + } + + #[test] + fn get_report_usb_error_is_propagated() { + let mock_usb = make_mock_usb_io(efi::Status::DEVICE_ERROR, efi::Status::SUCCESS); + let device = make_device(&mock_usb, vec![0x05]); + let hid_io_ptr = &device.hid_io as *const HidIoProtocol; + let mut buffer = [0u8; 8]; + // SAFETY: device is a valid UsbHidDevice. + let status = + unsafe { hid_get_report(hid_io_ptr, 0, HidReportType::InputReport, 8, buffer.as_mut_ptr() as *mut c_void) }; + assert_eq!(status, efi::Status::DEVICE_ERROR); + core::mem::forget(device); + } + + // ---- set_report tests ---- + + #[test] + fn set_report_null_this_returns_invalid_parameter() { + let buffer = [0x01u8; 8]; + // SAFETY: testing null this handling. + let status = unsafe { + hid_set_report(core::ptr::null(), 0, HidReportType::OutputReport, 8, buffer.as_ptr() as *mut c_void) + }; + assert_eq!(status, efi::Status::INVALID_PARAMETER); + } + + #[test] + fn set_report_null_buffer_returns_invalid_parameter() { + let mock_usb = make_mock_usb_io(efi::Status::SUCCESS, efi::Status::SUCCESS); + let device = make_device(&mock_usb, vec![0x05]); + let hid_io_ptr = &device.hid_io as *const HidIoProtocol; + // SAFETY: device is a valid UsbHidDevice; testing null buffer. + let status = unsafe { hid_set_report(hid_io_ptr, 0, HidReportType::OutputReport, 8, core::ptr::null_mut()) }; + assert_eq!(status, efi::Status::INVALID_PARAMETER); + core::mem::forget(device); + } + + #[test] + fn set_report_zero_size_returns_invalid_parameter() { + let mock_usb = make_mock_usb_io(efi::Status::SUCCESS, efi::Status::SUCCESS); + let device = make_device(&mock_usb, vec![0x05]); + let hid_io_ptr = &device.hid_io as *const HidIoProtocol; + let buffer = [0x01u8; 8]; + // SAFETY: device is a valid UsbHidDevice; testing zero size. + let status = + unsafe { hid_set_report(hid_io_ptr, 0, HidReportType::OutputReport, 0, buffer.as_ptr() as *mut c_void) }; + assert_eq!(status, efi::Status::INVALID_PARAMETER); + core::mem::forget(device); + } + + #[test] + fn set_report_size_exceeds_u16_max_returns_invalid_parameter() { + let mock_usb = make_mock_usb_io(efi::Status::SUCCESS, efi::Status::SUCCESS); + let device = make_device(&mock_usb, vec![0x05]); + let hid_io_ptr = &device.hid_io as *const HidIoProtocol; + let buffer = [0x01u8; 8]; + // SAFETY: device is a valid UsbHidDevice; testing oversized request. + let status = unsafe { + hid_set_report( + hid_io_ptr, + 0, + HidReportType::OutputReport, + u16::MAX as usize + 1, + buffer.as_ptr() as *mut c_void, + ) + }; + assert_eq!(status, efi::Status::INVALID_PARAMETER); + core::mem::forget(device); + } + + #[test] + fn set_report_input_type_returns_invalid_parameter() { + let mock_usb = make_mock_usb_io(efi::Status::SUCCESS, efi::Status::SUCCESS); + let device = make_device(&mock_usb, vec![0x05]); + let hid_io_ptr = &device.hid_io as *const HidIoProtocol; + let buffer = [0x01u8; 8]; + // SAFETY: device is a valid UsbHidDevice; testing invalid report type. + let status = + unsafe { hid_set_report(hid_io_ptr, 0, HidReportType::InputReport, 8, buffer.as_ptr() as *mut c_void) }; + assert_eq!(status, efi::Status::INVALID_PARAMETER); + core::mem::forget(device); + } + + #[test] + fn set_report_output_type_succeeds() { + let mock_usb = make_mock_usb_io(efi::Status::SUCCESS, efi::Status::SUCCESS); + let device = make_device(&mock_usb, vec![0x05]); + let hid_io_ptr = &device.hid_io as *const HidIoProtocol; + let buffer = [0x01u8; 8]; + // SAFETY: device is a valid UsbHidDevice. + let status = + unsafe { hid_set_report(hid_io_ptr, 1, HidReportType::OutputReport, 8, buffer.as_ptr() as *mut c_void) }; + assert_eq!(status, efi::Status::SUCCESS); + assert_eq!(mock_usb.control_call_count.get(), 1); + core::mem::forget(device); + } + + #[test] + fn set_report_feature_type_succeeds() { + let mock_usb = make_mock_usb_io(efi::Status::SUCCESS, efi::Status::SUCCESS); + let device = make_device(&mock_usb, vec![0x05]); + let hid_io_ptr = &device.hid_io as *const HidIoProtocol; + let buffer = [0x01u8; 8]; + // SAFETY: device is a valid UsbHidDevice. + let status = + unsafe { hid_set_report(hid_io_ptr, 1, HidReportType::Feature, 8, buffer.as_ptr() as *mut c_void) }; + assert_eq!(status, efi::Status::SUCCESS); + assert_eq!(mock_usb.control_call_count.get(), 1); + core::mem::forget(device); + } + + #[test] + fn set_report_usb_error_is_propagated() { + let mock_usb = make_mock_usb_io(efi::Status::DEVICE_ERROR, efi::Status::SUCCESS); + let device = make_device(&mock_usb, vec![0x05]); + let hid_io_ptr = &device.hid_io as *const HidIoProtocol; + let buffer = [0x01u8; 8]; + // SAFETY: device is a valid UsbHidDevice. + let status = + unsafe { hid_set_report(hid_io_ptr, 0, HidReportType::OutputReport, 8, buffer.as_ptr() as *mut c_void) }; + assert_eq!(status, efi::Status::DEVICE_ERROR); + core::mem::forget(device); + } + + // ---- register_report_callback tests ---- + + #[test] + fn register_callback_null_this_returns_invalid_parameter() { + // SAFETY: testing null this handling. + let status = unsafe { hid_register_report_callback(core::ptr::null(), test_callback, core::ptr::null_mut()) }; + assert_eq!(status, efi::Status::INVALID_PARAMETER); + } + + #[test] + fn register_callback_succeeds() { + let mock_usb = make_mock_usb_io(efi::Status::SUCCESS, efi::Status::SUCCESS); + let device = make_device(&mock_usb, vec![0x05]); + let hid_io_ptr = &device.hid_io as *const HidIoProtocol; + // SAFETY: device is a valid UsbHidDevice. + let status = unsafe { hid_register_report_callback(hid_io_ptr, test_callback, core::ptr::null_mut()) }; + assert_eq!(status, efi::Status::SUCCESS); + assert!(device.report_callback.callback.is_some()); + core::mem::forget(device); + } + + #[test] + fn register_callback_already_started() { + let mock_usb = make_mock_usb_io(efi::Status::SUCCESS, efi::Status::SUCCESS); + let device = make_device(&mock_usb, vec![0x05]); + let hid_io_ptr = &device.hid_io as *const HidIoProtocol; + // Register first callback. + // SAFETY: device is a valid UsbHidDevice. + let status = unsafe { hid_register_report_callback(hid_io_ptr, test_callback, core::ptr::null_mut()) }; + assert_eq!(status, efi::Status::SUCCESS); + // Second registration should fail. + // SAFETY: device is a valid UsbHidDevice. + let status = unsafe { hid_register_report_callback(hid_io_ptr, test_callback, core::ptr::null_mut()) }; + assert_eq!(status, efi::Status::ALREADY_STARTED); + core::mem::forget(device); + } + + #[test] + fn register_callback_clears_callback_on_transfer_failure() { + let mock_usb = make_mock_usb_io(efi::Status::SUCCESS, efi::Status::DEVICE_ERROR); + let device = make_device(&mock_usb, vec![0x05]); + let hid_io_ptr = &device.hid_io as *const HidIoProtocol; + // SAFETY: device is a valid UsbHidDevice. + let status = unsafe { hid_register_report_callback(hid_io_ptr, test_callback, core::ptr::null_mut()) }; + assert_eq!(status, efi::Status::DEVICE_ERROR); + assert!(device.report_callback.callback.is_none()); + core::mem::forget(device); + } + + // ---- unregister_report_callback tests ---- + + #[test] + fn unregister_callback_null_this_returns_invalid_parameter() { + // SAFETY: testing null this handling. + let status = unsafe { hid_unregister_report_callback(core::ptr::null(), test_callback) }; + assert_eq!(status, efi::Status::INVALID_PARAMETER); + } + + #[test] + fn unregister_callback_not_started() { + let mock_usb = make_mock_usb_io(efi::Status::SUCCESS, efi::Status::SUCCESS); + let device = make_device(&mock_usb, vec![0x05]); + let hid_io_ptr = &device.hid_io as *const HidIoProtocol; + // SAFETY: device is a valid UsbHidDevice; no callback registered. + let status = unsafe { hid_unregister_report_callback(hid_io_ptr, test_callback) }; + assert_eq!(status, efi::Status::NOT_STARTED); + core::mem::forget(device); + } + + #[test] + fn unregister_callback_wrong_callback_returns_not_started() { + let mock_usb = make_mock_usb_io(efi::Status::SUCCESS, efi::Status::SUCCESS); + let device = make_device(&mock_usb, vec![0x05]); + let hid_io_ptr = &device.hid_io as *const HidIoProtocol; + // Register one callback, then try to unregister a different one. + // SAFETY: device is a valid UsbHidDevice. + let status = unsafe { hid_register_report_callback(hid_io_ptr, test_callback, core::ptr::null_mut()) }; + assert_eq!(status, efi::Status::SUCCESS); + // SAFETY: device is a valid UsbHidDevice. + let status = unsafe { hid_unregister_report_callback(hid_io_ptr, other_callback) }; + assert_eq!(status, efi::Status::NOT_STARTED); + assert!(device.report_callback.callback.is_some()); + core::mem::forget(device); + } + + #[test] + fn unregister_callback_succeeds() { + let mock_usb = make_mock_usb_io(efi::Status::SUCCESS, efi::Status::SUCCESS); + let device = make_device(&mock_usb, vec![0x05]); + let hid_io_ptr = &device.hid_io as *const HidIoProtocol; + // SAFETY: device is a valid UsbHidDevice. + let status = unsafe { hid_register_report_callback(hid_io_ptr, test_callback, core::ptr::null_mut()) }; + assert_eq!(status, efi::Status::SUCCESS); + // SAFETY: device is a valid UsbHidDevice. + let status = unsafe { hid_unregister_report_callback(hid_io_ptr, test_callback) }; + assert_eq!(status, efi::Status::SUCCESS); + assert!(device.report_callback.callback.is_none()); + core::mem::forget(device); + } +} diff --git a/usb_hid/src/interrupt_transfers.rs b/usb_hid/src/interrupt_transfers.rs new file mode 100644 index 0000000..926c5f1 --- /dev/null +++ b/usb_hid/src/interrupt_transfers.rs @@ -0,0 +1,626 @@ +//! Async interrupt transfer management for USB HID. +//! +//! ## License +//! +//! Copyright (c) Microsoft Corporation. +//! +//! SPDX-License-Identifier: Apache-2.0 +//! +use core::ffi::c_void; + +use r_efi::efi; + +use patina::boot_services::{ + BootServices, + event::{EventTimerType, EventType}, + tpl::Tpl, +}; + +use crate::{control_transfers, device::UsbHidDevice}; +use patina::uefi_protocol::usb_io::types::*; + +/// Delay in 100ns units before re-submitting after a transfer error. +/// 100ms matches the standard EDKII `EFI_USB_INTERRUPT_DELAY`. +const RECOVERY_DELAY_100NS: u64 = 1_000_000; + +/// Object-safe subset of [`BootServices`] for timer operations. +/// +/// The interrupt completion callback only has access to [`UsbHidDevice`] via a +/// raw context pointer. This trait provides a narrow, object-safe interface so +/// the callback can arm a recovery timer without requiring the full (non-object-safe) +/// `BootServices` trait. +pub(crate) trait TransferRecoveryTimer { + /// Arms a one-shot timer event to fire after `delay_100ns` units of 100ns. + fn arm_recovery_timer(&self, event: efi::Event, delay_100ns: u64) -> Result<(), efi::Status>; +} + +impl TransferRecoveryTimer for T { + fn arm_recovery_timer(&self, event: efi::Event, delay_100ns: u64) -> Result<(), efi::Status> { + self.set_timer(event, EventTimerType::Relative, delay_100ns) + } +} + +/// Creates a recovery timer event whose notify function re-submits the async +/// interrupt transfer on the given device. +/// +/// # Safety +/// +/// `device_ptr` must point to a valid, heap-allocated [`UsbHidDevice`] that will +/// remain valid for the lifetime of the returned event. The caller is responsible +/// for closing the event (via [`BootServices::close_event`]) before the device is +/// freed. +pub unsafe fn create_recovery_event( + boot_services: &U, + device_ptr: *mut UsbHidDevice, +) -> Result { + // SAFETY: Caller guarantees device_ptr is valid for the event's lifetime. + // The event fires at TPL_CALLBACK, below the typical interrupt transfer TPL. + unsafe { + boot_services.create_event_unchecked::( + EventType::TIMER | EventType::NOTIFY_SIGNAL, + Tpl::CALLBACK, + Some(recovery_timer_notify), + device_ptr, + ) + } +} + +/// Timer notify function invoked after the recovery delay to re-submit async +/// interrupt transfers following an error. +extern "efiapi" fn recovery_timer_notify(_event: efi::Event, context: *mut UsbHidDevice) { + if context.is_null() { + return; + } + // SAFETY: context is a valid UsbHidDevice pointer set during event creation. + let device = unsafe { &mut *context }; + + // If the callback was unregistered while the timer was armed, do not re-submit. + if device.report_callback.callback.is_none() { + return; + } + + if let Err(status) = initiate_async_interrupt_input_transfers(device) { + log::warn!("USB HID: recovery re-submit failed: {status:x?}"); + } +} + +/// Initiates input reports from the endpoint by scheduling an async interrupt +/// transfer to poll the device. +pub fn initiate_async_interrupt_input_transfers(device: &mut UsbHidDevice) -> Result<(), efi::Status> { + // SAFETY: usb_io was opened BY_DRIVER and is valid for the device's lifetime. + let usb_io = unsafe { &*device.usb_io }; + + // SAFETY: usb_io was opened BY_DRIVER and is valid; transfer parameters are valid. + let status = unsafe { + (usb_io.usb_async_interrupt_transfer)( + device.usb_io, + device.descriptors.int_in_endpoint_descriptor.endpoint_address, + true.into(), + device.descriptors.int_in_endpoint_descriptor.interval as usize, + device.descriptors.int_in_endpoint_descriptor.max_packet_size as usize, + Some(on_report_interrupt_complete), + device as *mut UsbHidDevice as *mut c_void, + ) + }; + if status != efi::Status::SUCCESS { + return Err(status); + } + + Ok(()) +} + +/// Shuts down async interrupt transfers. +pub fn shutdown_async_interrupt_input_transfers(device: &mut UsbHidDevice) -> Result<(), efi::Status> { + // SAFETY: usb_io is valid for the device's lifetime. + let usb_io = unsafe { &*device.usb_io }; + + // Cancel the async interrupt transfer. + // SAFETY: usb_io is valid; cancellation parameters are valid. + let status = unsafe { + (usb_io.usb_async_interrupt_transfer)( + device.usb_io, + device.descriptors.int_in_endpoint_descriptor.endpoint_address, + false.into(), + 0, + 0, + None, + core::ptr::null_mut(), + ) + }; + if status != efi::Status::SUCCESS && status != efi::Status::NOT_FOUND { + log::warn!("USB HID: unexpected error shutting down async transfer: {status:x?}"); + } + + Ok(()) +} + +/// Interrupt completion callback. Invoked by the USB bus driver when data +/// arrives on the interrupt-in endpoint (or on error). +/// +/// # Safety +/// +/// `context` must be a valid pointer to a [`UsbHidDevice`] that was set +/// during transfer initiation. `data` must be valid for `data_length` bytes +/// when `result` indicates success. +unsafe extern "efiapi" fn on_report_interrupt_complete( + data: *mut c_void, + data_length: usize, + context: *mut c_void, + result: u32, +) -> efi::Status { + if context.is_null() { + return efi::Status::INVALID_PARAMETER; + } + + // SAFETY: context is a pointer to UsbHidDevice set during transfer initiation. + let device = unsafe { &mut *(context as *mut UsbHidDevice) }; + + if result != EFI_USB_NOERROR { + // Handle stall by clearing the endpoint halt. + if (result & EFI_USB_ERR_STALL) != 0 { + // SAFETY: usb_io is valid for the device's lifetime. + let usb_io = unsafe { &*device.usb_io }; + let _ = control_transfers::usb_clear_endpoint_halt( + usb_io, + device.descriptors.int_in_endpoint_descriptor.endpoint_address, + ); + } + + // Cancel the current async transfer and re-submit. + // SAFETY: usb_io is valid for the device's lifetime. + let usb_io = unsafe { &*device.usb_io }; + // SAFETY: usb_io is valid; cancelling the current async transfer. + let _ = unsafe { + (usb_io.usb_async_interrupt_transfer)( + device.usb_io, + device.descriptors.int_in_endpoint_descriptor.endpoint_address, + false.into(), + 0, + 0, + None, + core::ptr::null_mut(), + ) + }; + + // Arm the recovery timer for delayed re-submission. + let _ = device.timer_services.arm_recovery_timer(device.recovery_event, RECOVERY_DELAY_100NS); + + return efi::Status::DEVICE_ERROR; + } + + if data_length > u16::MAX as usize { + return efi::Status::DEVICE_ERROR; + } + + if data_length == 0 || data.is_null() { + return efi::Status::SUCCESS; + } + + if let Some(callback) = device.report_callback.callback { + // SAFETY: callback was registered via HidIo protocol; data and context are valid. + unsafe { callback(data_length as u16, data, device.report_callback.context) }; + } + + efi::Status::SUCCESS +} + +#[cfg(test)] +mod test { + use super::*; + use alloc::vec; + use core::{ + cell::Cell, + sync::atomic::{AtomicU16, AtomicUsize, Ordering}, + }; + + use patina::uefi_protocol::usb_io::{ + EfiAsyncUsbTransferCallback, EfiUsbIoProtocol, + types::{EfiUsbDataDirection, EfiUsbDeviceRequest}, + }; + + use crate::{ + device::{ReportCallbackState, UsbHidDescriptors, UsbHidDevice}, + hid_io_impl, + }; + + // ---- Mock USB IO ---- + + /// Mock USB IO context. `protocol` must be the first field so extern mock + /// functions can recover mock state from the `this` pointer. + #[repr(C)] + struct MockUsbIo { + protocol: EfiUsbIoProtocol, + control_transfer_status: efi::Status, + async_transfer_status: efi::Status, + async_call_count: Cell, + control_call_count: Cell, + } + + impl MockUsbIo { + /// # Safety + /// `this` must point to the `protocol` field of a valid `MockUsbIo`. + unsafe fn from_this(this: *const EfiUsbIoProtocol) -> &'static Self { + // SAFETY: MockUsbIo is #[repr(C)] with protocol as first field. + unsafe { &*(this as *const MockUsbIo) } + } + } + + extern "efiapi" fn mock_control_transfer( + this: *const EfiUsbIoProtocol, + _request: *const EfiUsbDeviceRequest, + _direction: EfiUsbDataDirection, + _timeout: u32, + _data: *mut c_void, + _data_length: usize, + _status: *mut u32, + ) -> efi::Status { + // SAFETY: this points to a valid MockUsbIo on the test stack. + let mock = unsafe { MockUsbIo::from_this(this) }; + mock.control_call_count.set(mock.control_call_count.get() + 1); + mock.control_transfer_status + } + + extern "efiapi" fn mock_async_interrupt_transfer( + this: *const EfiUsbIoProtocol, + _endpoint: u8, + _is_new_transfer: efi::Boolean, + _polling_interval: usize, + _data_length: usize, + _callback: Option, + _context: *mut c_void, + ) -> efi::Status { + // SAFETY: this points to a valid MockUsbIo on the test stack. + let mock = unsafe { MockUsbIo::from_this(this) }; + mock.async_call_count.set(mock.async_call_count.get() + 1); + mock.async_transfer_status + } + + fn make_mock_usb_io(control_status: efi::Status, async_status: efi::Status) -> MockUsbIo { + let mut protocol = crate::test_stubs::usb_io_stub(); + protocol.usb_control_transfer = mock_control_transfer; + protocol.usb_async_interrupt_transfer = mock_async_interrupt_transfer; + MockUsbIo { + protocol, + control_transfer_status: control_status, + async_transfer_status: async_status, + async_call_count: Cell::new(0), + control_call_count: Cell::new(0), + } + } + + // ---- No-op timer for tests that don't exercise recovery ---- + + struct NoopTransferRecoveryTimer; + impl TransferRecoveryTimer for NoopTransferRecoveryTimer { + fn arm_recovery_timer(&self, _event: efi::Event, _delay_100ns: u64) -> Result<(), efi::Status> { + Ok(()) + } + } + static NOOP_RECOVERY_TIMER: NoopTransferRecoveryTimer = NoopTransferRecoveryTimer; + + fn make_device(usb_io: &MockUsbIo) -> Box { + Box::new(UsbHidDevice { + hid_io: hid_io_impl::new_hid_io_protocol(), + usb_io: &usb_io.protocol as *const EfiUsbIoProtocol, + descriptors: UsbHidDescriptors { + interface_descriptor: EfiUsbInterfaceDescriptor::default(), + int_in_endpoint_descriptor: EfiUsbEndpointDescriptor { + endpoint_address: 0x81, + interval: 10, + max_packet_size: 8, + ..Default::default() + }, + report_descriptor: vec![0x05, 0x01], + }, + report_callback: ReportCallbackState::default(), + timer_services: &NOOP_RECOVERY_TIMER, + recovery_event: core::ptr::null_mut(), + }) + } + + // ---- initiate tests ---- + + #[test] + fn initiate_transfers_succeeds() { + let mock_usb = make_mock_usb_io(efi::Status::SUCCESS, efi::Status::SUCCESS); + let mut device = make_device(&mock_usb); + assert!(initiate_async_interrupt_input_transfers(&mut device).is_ok()); + assert_eq!(mock_usb.async_call_count.get(), 1); + core::mem::forget(device); + } + + #[test] + fn initiate_transfers_returns_error_on_failure() { + let mock_usb = make_mock_usb_io(efi::Status::SUCCESS, efi::Status::DEVICE_ERROR); + let mut device = make_device(&mock_usb); + assert_eq!(initiate_async_interrupt_input_transfers(&mut device), Err(efi::Status::DEVICE_ERROR),); + core::mem::forget(device); + } + + // ---- shutdown tests ---- + + #[test] + fn shutdown_transfers_succeeds() { + let mock_usb = make_mock_usb_io(efi::Status::SUCCESS, efi::Status::SUCCESS); + let mut device = make_device(&mock_usb); + assert!(shutdown_async_interrupt_input_transfers(&mut device).is_ok()); + assert_eq!(mock_usb.async_call_count.get(), 1); + core::mem::forget(device); + } + + #[test] + fn shutdown_transfers_tolerates_not_found() { + let mock_usb = make_mock_usb_io(efi::Status::SUCCESS, efi::Status::NOT_FOUND); + let mut device = make_device(&mock_usb); + assert!(shutdown_async_interrupt_input_transfers(&mut device).is_ok()); + core::mem::forget(device); + } + + #[test] + fn shutdown_transfers_tolerates_unexpected_error() { + let mock_usb = make_mock_usb_io(efi::Status::SUCCESS, efi::Status::DEVICE_ERROR); + let mut device = make_device(&mock_usb); + // shutdown always returns Ok even on unexpected errors (it just logs a warning). + assert!(shutdown_async_interrupt_input_transfers(&mut device).is_ok()); + core::mem::forget(device); + } + + // ---- on_report_interrupt_complete tests ---- + + #[test] + fn callback_null_context_returns_invalid_parameter() { + let report = [0u8; 8]; + // SAFETY: testing null context handling. + let status = unsafe { + on_report_interrupt_complete( + report.as_ptr() as *mut c_void, + report.len(), + core::ptr::null_mut(), + EFI_USB_NOERROR, + ) + }; + assert_eq!(status, efi::Status::INVALID_PARAMETER); + } + + #[test] + fn callback_usb_error_returns_device_error_and_arms_timer() { + let mock_usb = make_mock_usb_io(efi::Status::SUCCESS, efi::Status::SUCCESS); + let mut device = make_device(&mock_usb); + let device_ptr = &mut *device as *mut UsbHidDevice; + // SAFETY: device_ptr is a valid UsbHidDevice. + let status = unsafe { + on_report_interrupt_complete( + core::ptr::null_mut(), + 0, + device_ptr as *mut c_void, + 0x01, // non-zero = error, but not stall + ) + }; + assert_eq!(status, efi::Status::DEVICE_ERROR); + // Should have called async transfer once (cancel only; re-submit is via recovery timer). + assert_eq!(mock_usb.async_call_count.get(), 1); + // Should not have called control transfer (no stall). + assert_eq!(mock_usb.control_call_count.get(), 0); + core::mem::forget(device); + } + + #[test] + fn callback_stall_error_clears_halt_and_arms_timer() { + let mock_usb = make_mock_usb_io(efi::Status::SUCCESS, efi::Status::SUCCESS); + let mut device = make_device(&mock_usb); + let device_ptr = &mut *device as *mut UsbHidDevice; + // SAFETY: device_ptr is a valid UsbHidDevice. + let status = unsafe { + on_report_interrupt_complete(core::ptr::null_mut(), 0, device_ptr as *mut c_void, EFI_USB_ERR_STALL) + }; + assert_eq!(status, efi::Status::DEVICE_ERROR); + // Should have called control transfer to clear endpoint halt. + assert_eq!(mock_usb.control_call_count.get(), 1); + // Should have called async transfer once (cancel only; re-submit is via recovery timer). + assert_eq!(mock_usb.async_call_count.get(), 1); + core::mem::forget(device); + } + + #[test] + fn callback_data_length_exceeds_u16_max_returns_device_error() { + let mock_usb = make_mock_usb_io(efi::Status::SUCCESS, efi::Status::SUCCESS); + let mut device = make_device(&mock_usb); + let device_ptr = &mut *device as *mut UsbHidDevice; + // SAFETY: device_ptr is a valid UsbHidDevice. + let status = unsafe { + on_report_interrupt_complete( + 0x1000 as *mut c_void, // non-null + u16::MAX as usize + 1, + device_ptr as *mut c_void, + EFI_USB_NOERROR, + ) + }; + assert_eq!(status, efi::Status::DEVICE_ERROR); + core::mem::forget(device); + } + + #[test] + fn callback_zero_length_returns_success() { + let mock_usb = make_mock_usb_io(efi::Status::SUCCESS, efi::Status::SUCCESS); + let mut device = make_device(&mock_usb); + let device_ptr = &mut *device as *mut UsbHidDevice; + // SAFETY: device_ptr is a valid UsbHidDevice. + let status = unsafe { + on_report_interrupt_complete(core::ptr::null_mut(), 0, device_ptr as *mut c_void, EFI_USB_NOERROR) + }; + assert_eq!(status, efi::Status::SUCCESS); + core::mem::forget(device); + } + + #[test] + fn callback_null_data_returns_success() { + let mock_usb = make_mock_usb_io(efi::Status::SUCCESS, efi::Status::SUCCESS); + let mut device = make_device(&mock_usb); + let device_ptr = &mut *device as *mut UsbHidDevice; + // SAFETY: device_ptr is a valid UsbHidDevice. + let status = unsafe { + on_report_interrupt_complete(core::ptr::null_mut(), 8, device_ptr as *mut c_void, EFI_USB_NOERROR) + }; + assert_eq!(status, efi::Status::SUCCESS); + core::mem::forget(device); + } + + #[test] + fn callback_no_registered_callback_returns_success() { + let mock_usb = make_mock_usb_io(efi::Status::SUCCESS, efi::Status::SUCCESS); + let mut device = make_device(&mock_usb); + let device_ptr = &mut *device as *mut UsbHidDevice; + let report = [0xAAu8; 4]; + // No callback registered — should silently succeed. + // SAFETY: device_ptr is a valid UsbHidDevice. + let status = unsafe { + on_report_interrupt_complete( + report.as_ptr() as *mut c_void, + report.len(), + device_ptr as *mut c_void, + EFI_USB_NOERROR, + ) + }; + assert_eq!(status, efi::Status::SUCCESS); + core::mem::forget(device); + } + + // Shared atomic counters for the test callback below. + static CALLBACK_INVOCATIONS: AtomicUsize = AtomicUsize::new(0); + static CALLBACK_REPORT_SIZE: AtomicU16 = AtomicU16::new(0); + + unsafe extern "efiapi" fn counting_callback( + report_buffer_size: u16, + _report_buffer: *mut c_void, + _context: *mut c_void, + ) { + CALLBACK_INVOCATIONS.fetch_add(1, Ordering::SeqCst); + CALLBACK_REPORT_SIZE.store(report_buffer_size, Ordering::SeqCst); + } + + #[test] + fn callback_invokes_registered_callback_with_report() { + CALLBACK_INVOCATIONS.store(0, Ordering::SeqCst); + CALLBACK_REPORT_SIZE.store(0, Ordering::SeqCst); + + let mock_usb = make_mock_usb_io(efi::Status::SUCCESS, efi::Status::SUCCESS); + let mut device = make_device(&mock_usb); + device.report_callback.callback = Some(counting_callback); + let device_ptr = &mut *device as *mut UsbHidDevice; + let report = [0x10u8, 0x20, 0x30]; + // SAFETY: device_ptr is a valid UsbHidDevice; report is valid. + let status = unsafe { + on_report_interrupt_complete( + report.as_ptr() as *mut c_void, + report.len(), + device_ptr as *mut c_void, + EFI_USB_NOERROR, + ) + }; + assert_eq!(status, efi::Status::SUCCESS); + assert_eq!(CALLBACK_INVOCATIONS.load(Ordering::SeqCst), 1); + assert_eq!(CALLBACK_REPORT_SIZE.load(Ordering::SeqCst), 3); + core::mem::forget(device); + } + + // ---- Timer-based recovery tests ---- + + /// Mock timer services for testing delayed recovery. + struct MockTransferRecoveryTimer { + arm_called: Cell, + arm_event: Cell>, + arm_delay: Cell, + } + + impl MockTransferRecoveryTimer { + fn new() -> Self { + Self { arm_called: Cell::new(false), arm_event: Cell::new(None), arm_delay: Cell::new(0) } + } + } + + impl TransferRecoveryTimer for MockTransferRecoveryTimer { + fn arm_recovery_timer(&self, event: efi::Event, delay_100ns: u64) -> Result<(), efi::Status> { + self.arm_called.set(true); + self.arm_event.set(Some(event)); + self.arm_delay.set(delay_100ns); + Ok(()) + } + } + + fn make_device_with_timer(usb_io: &MockUsbIo, timer_services: &MockTransferRecoveryTimer) -> Box { + // SAFETY: The timer_services reference is transmuted to 'static for storage in + // UsbHidDevice. The test ensures the MockTransferRecoveryTimer outlives the device. + let timer_ref: &'static dyn TransferRecoveryTimer = + unsafe { core::mem::transmute(timer_services as &dyn TransferRecoveryTimer) }; + let sentinel_event = 0xBEEF as efi::Event; + Box::new(UsbHidDevice { + hid_io: hid_io_impl::new_hid_io_protocol(), + usb_io: &usb_io.protocol as *const EfiUsbIoProtocol, + descriptors: UsbHidDescriptors { + interface_descriptor: EfiUsbInterfaceDescriptor::default(), + int_in_endpoint_descriptor: EfiUsbEndpointDescriptor { + endpoint_address: 0x81, + interval: 10, + max_packet_size: 8, + ..Default::default() + }, + report_descriptor: vec![0x05, 0x01], + }, + report_callback: ReportCallbackState::default(), + timer_services: timer_ref, + recovery_event: sentinel_event, + }) + } + + #[test] + fn callback_error_arms_recovery_timer_instead_of_immediate_resubmit() { + let mock_usb = make_mock_usb_io(efi::Status::SUCCESS, efi::Status::SUCCESS); + let mock_timer = MockTransferRecoveryTimer::new(); + let mut device = make_device_with_timer(&mock_usb, &mock_timer); + let device_ptr = &mut *device as *mut UsbHidDevice; + // SAFETY: device_ptr is a valid UsbHidDevice. + let status = unsafe { on_report_interrupt_complete(core::ptr::null_mut(), 0, device_ptr as *mut c_void, 0x01) }; + assert_eq!(status, efi::Status::DEVICE_ERROR); + // Should have called async transfer once (cancel only, no immediate re-submit). + assert_eq!(mock_usb.async_call_count.get(), 1); + // Recovery timer should have been armed. + assert!(mock_timer.arm_called.get()); + assert_eq!(mock_timer.arm_event.get(), Some(0xBEEF as efi::Event)); + assert_eq!(mock_timer.arm_delay.get(), RECOVERY_DELAY_100NS); + core::mem::forget(device); + } + + #[test] + fn callback_stall_error_with_timer_clears_halt_and_arms_timer() { + let mock_usb = make_mock_usb_io(efi::Status::SUCCESS, efi::Status::SUCCESS); + let mock_timer = MockTransferRecoveryTimer::new(); + let mut device = make_device_with_timer(&mock_usb, &mock_timer); + let device_ptr = &mut *device as *mut UsbHidDevice; + // SAFETY: device_ptr is a valid UsbHidDevice. + let status = unsafe { + on_report_interrupt_complete(core::ptr::null_mut(), 0, device_ptr as *mut c_void, EFI_USB_ERR_STALL) + }; + assert_eq!(status, efi::Status::DEVICE_ERROR); + // Should have cleared endpoint halt via control transfer. + assert_eq!(mock_usb.control_call_count.get(), 1); + // Should have called async transfer once (cancel only). + assert_eq!(mock_usb.async_call_count.get(), 1); + // Recovery timer should have been armed. + assert!(mock_timer.arm_called.get()); + core::mem::forget(device); + } + + #[test] + fn recovery_timer_does_not_resubmit_after_callback_unregistered() { + let mock_usb = make_mock_usb_io(efi::Status::SUCCESS, efi::Status::SUCCESS); + let mut device = make_device(&mock_usb); + // No callback registered — simulates unregister having cleared it. + assert!(device.report_callback.callback.is_none()); + let device_ptr = &mut *device as *mut UsbHidDevice; + // Simulate the recovery timer firing. + recovery_timer_notify(core::ptr::null_mut(), device_ptr); + // Should NOT have attempted to re-submit async transfers. + assert_eq!(mock_usb.async_call_count.get(), 0); + core::mem::forget(device); + } +} diff --git a/usb_hid/src/lib.rs b/usb_hid/src/lib.rs new file mode 100644 index 0000000..2da6003 --- /dev/null +++ b/usb_hid/src/lib.rs @@ -0,0 +1,129 @@ +//! USB HID Driver — produces the HidIo protocol on USB HID devices. +//! +//! This crate implements a UEFI Driver Binding that consumes the USB IO protocol +//! and produces the HidIo protocol for each USB HID device it manages. +//! +//! ## License +//! +//! Copyright (c) Microsoft Corporation. +//! +//! SPDX-License-Identifier: Apache-2.0 +//! +#![cfg_attr(not(test), no_std)] +#![feature(coverage_attribute)] + +extern crate alloc; + +pub(crate) mod control_transfers; +pub(crate) mod descriptors; +pub(crate) mod device; +pub(crate) mod driver; +pub(crate) mod hid_io_impl; +pub(crate) mod interrupt_transfers; +pub(crate) mod usb_hid_defs; + +#[cfg(test)] +pub(crate) mod test_stubs; + +use alloc::boxed::Box; + +use r_efi::efi; + +use patina::{ + BinaryGuid, + boot_services::{BootServices, StandardBootServices}, + component::{component, params}, + driver_binding::UefiDriverBinding, + error::Result, + uefi_protocol::ProtocolInterface, +}; + +/// Zero-sized marker protocol used to create a dedicated driver binding handle. +#[repr(C)] +struct UsbHidMarker; + +// SAFETY: UsbHidMarker is a ZST whose GUID uniquely identifies this component. +unsafe impl ProtocolInterface for UsbHidMarker { + const PROTOCOL_GUID: BinaryGuid = BinaryGuid::from_string("a7f36d52-8e3b-4f1a-9c5d-7b2e4a6f8d01"); +} + +/// USB HID Patina component. +/// +/// When dispatched, installs a UEFI Driver Binding that consumes USB IO +/// protocol instances on HID devices and produces the HidIo protocol. +pub struct UsbHidComponent; + +#[component] +impl UsbHidComponent { + fn entry_point(self, boot_services: StandardBootServices, image_handle: params::Handle) -> Result<()> { + let boot_services: &'static StandardBootServices = Box::leak(Box::new(boot_services)); + install_usb_hid_driver_binding(boot_services, *image_handle) + } +} + +/// Installs the USB HID driver binding using the provided boot services. +fn install_usb_hid_driver_binding( + boot_services: &'static T, + image_handle: efi::Handle, +) -> Result<()> { + let (driver_binding_handle, _marker_key) = + boot_services.install_protocol_interface(None, Box::new(UsbHidMarker))?; + + let driver = driver::UsbHidDriver::new(driver_binding_handle); + + let mut driver_binding = + UefiDriverBinding::new_with_driver_handle(driver, image_handle, driver_binding_handle, boot_services); + + driver_binding.install().map_err(patina::error::EfiError::from)?; + + Ok(()) +} + +#[cfg(test)] +mod test { + use patina::boot_services::{MockBootServices, c_ptr::CPtr}; + + use super::*; + + #[test] + fn install_usb_hid_binding_should_install_a_binding() { + let boot_services = Box::leak(Box::new(MockBootServices::new())); + + boot_services.expect_install_protocol_interface::>().returning( + |handle, protocol_interface| { + assert_eq!(handle, None, "Expected no handle for marker protocol installation"); + Ok((0x5678 as efi::Handle, protocol_interface.metadata())) + }, + ); + + boot_services.expect_install_protocol_interface_unchecked().returning(|handle, protocol, interface| { + if protocol == &efi::protocols::driver_binding::PROTOCOL_GUID { + assert!( + handle.is_some_and(|handle| handle as usize == 0x5678), + "Expected correct handle for driver binding protocol" + ); + assert!(!interface.is_null(), "Expected non-null interface for driver binding protocol"); + return Ok(0x9abc as efi::Handle); + } + panic!("Unexpected protocol installation: {:?}", protocol); + }); + + let mock_image_handle = 0x1234 as efi::Handle; + install_usb_hid_driver_binding(boot_services, mock_image_handle).expect("install should succeed"); + } + + #[test] + fn install_usb_hid_binding_handles_marker_failure() { + let boot_services = Box::leak(Box::new(MockBootServices::new())); + + boot_services + .expect_install_protocol_interface::>() + .returning(|_, _| Err(efi::Status::OUT_OF_RESOURCES)); + + let mock_image_handle = 0x1234 as efi::Handle; + assert_eq!( + install_usb_hid_driver_binding(boot_services, mock_image_handle), + Err(efi::Status::OUT_OF_RESOURCES.into()) + ); + } +} diff --git a/usb_hid/src/test_stubs.rs b/usb_hid/src/test_stubs.rs new file mode 100644 index 0000000..0f5dfc7 --- /dev/null +++ b/usb_hid/src/test_stubs.rs @@ -0,0 +1,141 @@ +//! Test stubs for protocol types whose `stub()` methods are not exposed +//! from the upstream patina crate. +//! +//! ## License +//! +//! Copyright (c) Microsoft Corporation. +//! +//! SPDX-License-Identifier: Apache-2.0 +//! + +use core::ffi::c_void; + +use r_efi::efi; + +use patina::uefi_protocol::usb_io::{EfiAsyncUsbTransferCallback, EfiUsbIoProtocol, types::*}; + +/// Creates a stub `EfiUsbIoProtocol` with panicking function pointers. +/// +/// Callers should replace the function pointers they need before use. +#[coverage(off)] +pub fn usb_io_stub() -> EfiUsbIoProtocol { + unsafe extern "efiapi" fn stub_control_transfer( + _this: *const EfiUsbIoProtocol, + _request: *const EfiUsbDeviceRequest, + _direction: EfiUsbDataDirection, + _timeout: u32, + _data: *mut c_void, + _data_length: usize, + _status: *mut u32, + ) -> efi::Status { + panic!("unexpected call to usb_control_transfer") + } + unsafe extern "efiapi" fn stub_bulk_transfer( + _this: *const EfiUsbIoProtocol, + _device_endpoint: u8, + _data: *mut c_void, + _data_length: *mut usize, + _timeout: usize, + _status: *mut u32, + ) -> efi::Status { + panic!("unexpected call to usb_bulk_transfer") + } + unsafe extern "efiapi" fn stub_async_interrupt_transfer( + _this: *const EfiUsbIoProtocol, + _device_endpoint: u8, + _is_new_transfer: efi::Boolean, + _polling_interval: usize, + _data_length: usize, + _callback: Option, + _context: *mut c_void, + ) -> efi::Status { + panic!("unexpected call to usb_async_interrupt_transfer") + } + unsafe extern "efiapi" fn stub_sync_interrupt_transfer( + _this: *const EfiUsbIoProtocol, + _device_endpoint: u8, + _data: *mut c_void, + _data_length: *mut usize, + _timeout: usize, + _status: *mut u32, + ) -> efi::Status { + panic!("unexpected call to usb_sync_interrupt_transfer") + } + unsafe extern "efiapi" fn stub_isochronous_transfer( + _this: *const EfiUsbIoProtocol, + _device_endpoint: u8, + _data: *mut c_void, + _data_length: usize, + _status: *mut u32, + ) -> efi::Status { + panic!("unexpected call to usb_isochronous_transfer") + } + unsafe extern "efiapi" fn stub_async_isochronous_transfer( + _this: *const EfiUsbIoProtocol, + _device_endpoint: u8, + _data: *mut c_void, + _data_length: usize, + _isochronous_callback: EfiAsyncUsbTransferCallback, + _context: *mut c_void, + ) -> efi::Status { + panic!("unexpected call to usb_async_isochronous_transfer") + } + unsafe extern "efiapi" fn stub_get_device_descriptor( + _this: *const EfiUsbIoProtocol, + _device_descriptor: *mut EfiUsbDeviceDescriptor, + ) -> efi::Status { + panic!("unexpected call to usb_get_device_descriptor") + } + unsafe extern "efiapi" fn stub_get_config_descriptor( + _this: *const EfiUsbIoProtocol, + _config_descriptor: *mut EfiUsbConfigDescriptor, + ) -> efi::Status { + panic!("unexpected call to usb_get_config_descriptor") + } + unsafe extern "efiapi" fn stub_get_interface_descriptor( + _this: *const EfiUsbIoProtocol, + _interface_descriptor: *mut EfiUsbInterfaceDescriptor, + ) -> efi::Status { + panic!("unexpected call to usb_get_interface_descriptor") + } + unsafe extern "efiapi" fn stub_get_endpoint_descriptor( + _this: *const EfiUsbIoProtocol, + _endpoint_index: u8, + _endpoint_descriptor: *mut EfiUsbEndpointDescriptor, + ) -> efi::Status { + panic!("unexpected call to usb_get_endpoint_descriptor") + } + unsafe extern "efiapi" fn stub_get_string_descriptor( + _this: *const EfiUsbIoProtocol, + _lang_id: u16, + _string_id: u8, + _string: *mut *mut u16, + ) -> efi::Status { + panic!("unexpected call to usb_get_string_descriptor") + } + unsafe extern "efiapi" fn stub_get_supported_languages( + _this: *const EfiUsbIoProtocol, + _lang_id_table: *mut *mut u16, + _table_size: *mut u16, + ) -> efi::Status { + panic!("unexpected call to usb_get_supported_languages") + } + unsafe extern "efiapi" fn stub_port_reset(_this: *const EfiUsbIoProtocol) -> efi::Status { + panic!("unexpected call to usb_port_reset") + } + EfiUsbIoProtocol { + usb_control_transfer: stub_control_transfer, + usb_bulk_transfer: stub_bulk_transfer, + usb_async_interrupt_transfer: stub_async_interrupt_transfer, + usb_sync_interrupt_transfer: stub_sync_interrupt_transfer, + usb_isochronous_transfer: stub_isochronous_transfer, + usb_async_isochronous_transfer: stub_async_isochronous_transfer, + usb_get_device_descriptor: stub_get_device_descriptor, + usb_get_config_descriptor: stub_get_config_descriptor, + usb_get_interface_descriptor: stub_get_interface_descriptor, + usb_get_endpoint_descriptor: stub_get_endpoint_descriptor, + usb_get_string_descriptor: stub_get_string_descriptor, + usb_get_supported_languages: stub_get_supported_languages, + usb_port_reset: stub_port_reset, + } +} diff --git a/usb_hid/src/usb_hid_defs.rs b/usb_hid/src/usb_hid_defs.rs new file mode 100644 index 0000000..3a08fdf --- /dev/null +++ b/usb_hid/src/usb_hid_defs.rs @@ -0,0 +1,71 @@ +//! USB HID class-specific constants and descriptor structures. +//! +//! These definitions are specific to the USB HID class and are used by this +//! component to communicate with HID devices via the USB IO protocol. +//! +//! ## License +//! +//! Copyright (c) Microsoft Corporation. +//! +//! SPDX-License-Identifier: Apache-2.0 +//! + +/// USB interface class for HID devices. +pub const CLASS_HID: u8 = 3; +/// USB interface subclass for boot devices. +pub const SUBCLASS_BOOT: u8 = 1; + +/// HID report protocol mode. +pub const REPORT_PROTOCOL: u8 = 1; + +/// USB descriptor type for HID. +pub const USB_DESC_TYPE_HID: u8 = 0x21; +/// USB descriptor type for HID report. +pub const USB_DESC_TYPE_REPORT: u8 = 0x22; + +/// USB HID class-specific request: GET_REPORT. +pub const USB_HID_GET_REPORT_REQUEST: u8 = 0x01; +/// USB HID class-specific request: SET_REPORT. +pub const USB_HID_SET_REPORT_REQUEST: u8 = 0x09; +/// USB HID class-specific request: SET_PROTOCOL. +pub const USB_HID_SET_PROTOCOL_REQUEST: u8 = 0x0B; + +/// USB request type: class, interface, host-to-device. +pub const USB_REQ_TYPE_CLASS_INTERFACE_OUT: u8 = 0x21; +/// USB request type: class, interface, device-to-host. +pub const USB_REQ_TYPE_CLASS_INTERFACE_IN: u8 = 0xA1; +/// USB request type: standard, endpoint, host-to-device. +pub const USB_REQ_TYPE_STANDARD_ENDPOINT_OUT: u8 = 0x02; +/// USB request type: standard, device, device-to-host. +pub const USB_REQ_TYPE_STANDARD_DEVICE_IN: u8 = 0x80; + +/// USB standard request: CLEAR_FEATURE. +pub const USB_REQ_CLEAR_FEATURE: u8 = 0x01; +/// USB feature selector: ENDPOINT_HALT. +pub const USB_FEATURE_ENDPOINT_HALT: u16 = 0; + +/// USB standard request: GET_DESCRIPTOR. +pub const USB_REQ_GET_DESCRIPTOR: u8 = 0x06; + +/// Timeout for USB control transfers (in milliseconds). +pub const USB_TRANSFER_TIMEOUT_MS: u32 = 3000; + +/// HID class descriptor entry (type + length pair). +#[derive(Debug, Clone, Copy, Default)] +#[repr(C, packed)] +pub struct HidClassDescriptor { + pub descriptor_type: u8, + pub descriptor_length: u16, +} + +/// USB HID descriptor. +#[derive(Debug, Clone, Copy)] +#[repr(C, packed)] +pub struct EfiUsbHidDescriptor { + pub length: u8, + pub descriptor_type: u8, + pub bcd_hid: u16, + pub country_code: u8, + pub num_descriptors: u8, + // Followed by variable-length array of HidClassDescriptor. +} From e32b8784f7729c4ca3d16e57605c7c90a1e8d0f5 Mon Sep 17 00:00:00 2001 From: John Schock Date: Mon, 4 May 2026 17:47:28 -0700 Subject: [PATCH 3/3] fix: add legitimate technical terms to cspell dictionary Add words from scroll crate (Pread/Pwrite/gread/gwrite/pread/pwrite), USB/UEFI terms (EDKII, NOERROR), keyboard terms (numpad, numlock, lctrl, lshift, unshifted), data structures (Deque), mockall (withf), and common abbreviations (descs, Unregisters). Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- cspell.yml | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/cspell.yml b/cspell.yml index 9c988a4..7efee5b 100644 --- a/cspell.yml +++ b/cspell.yml @@ -26,8 +26,23 @@ caseSensitive: false allowCompoundWords: true words: - "Depex" + - "Deque" + - "descs" + - "EDKII" - "efiapi" + - "gread" - "guids" + - "gwrite" + - "lctrl" + - "lshift" - "mdbook" + - "NOERROR" + - "numlock" + - "numpad" + - "Pread" + - "pwrite" - "sysreg" + - "unshifted" + - "Unregisters" - "unsignaled" + - "withf"