diff --git a/Cargo.lock b/Cargo.lock index cd2d69a3..4f5c1a89 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -404,9 +404,9 @@ dependencies = [ [[package]] name = "crc" -version = "3.3.0" +version = "3.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9710d3b3739c2e349eb44fe848ad0b7c8cb1e42bd87ee49371df2f7acaf3e675" +checksum = "5eb8a2a1cd12ab0d987a5d5e825195d372001a4094a0376319d5a0ad71c1ba0d" dependencies = [ "crc-catalog", ] @@ -1302,6 +1302,7 @@ name = "mctp-rs" version = "0.1.0" dependencies = [ "bit-register", + "crc", "defmt 0.3.100", "embedded-batteries", "espi-device", diff --git a/examples/pico-de-gallo/Cargo.lock b/examples/pico-de-gallo/Cargo.lock index 8ac8715f..375ba9bb 100644 --- a/examples/pico-de-gallo/Cargo.lock +++ b/examples/pico-de-gallo/Cargo.lock @@ -619,7 +619,7 @@ dependencies = [ [[package]] name = "espi-device" version = "0.1.0" -source = "git+https://github.com/OpenDevicePartnership/haf-ec-service#0d21aa34fd8691c6544533e6c72fe41824dc2fa8" +source = "git+https://github.com/OpenDevicePartnership/haf-ec-service#09eda26a729738adbd177231600acdb981690375" dependencies = [ "bit-register", "bitflags", diff --git a/examples/rt685s-evk/Cargo.lock b/examples/rt685s-evk/Cargo.lock index ab677628..342fa1dc 100644 --- a/examples/rt685s-evk/Cargo.lock +++ b/examples/rt685s-evk/Cargo.lock @@ -688,7 +688,7 @@ dependencies = [ [[package]] name = "espi-device" version = "0.1.0" -source = "git+https://github.com/OpenDevicePartnership/haf-ec-service#0d21aa34fd8691c6544533e6c72fe41824dc2fa8" +source = "git+https://github.com/OpenDevicePartnership/haf-ec-service#09eda26a729738adbd177231600acdb981690375" dependencies = [ "bit-register", "bitflags 2.11.1", diff --git a/examples/std/Cargo.lock b/examples/std/Cargo.lock index 16c0a12b..5d622782 100644 --- a/examples/std/Cargo.lock +++ b/examples/std/Cargo.lock @@ -675,7 +675,7 @@ dependencies = [ [[package]] name = "espi-device" version = "0.1.0" -source = "git+https://github.com/OpenDevicePartnership/haf-ec-service#0d21aa34fd8691c6544533e6c72fe41824dc2fa8" +source = "git+https://github.com/OpenDevicePartnership/haf-ec-service#09eda26a729738adbd177231600acdb981690375" dependencies = [ "bit-register", "bitflags 2.11.1", diff --git a/mctp-rs/Cargo.toml b/mctp-rs/Cargo.toml index 7e5d382b..7d16b1c5 100644 --- a/mctp-rs/Cargo.toml +++ b/mctp-rs/Cargo.toml @@ -6,16 +6,18 @@ edition = "2024" [package.metadata.cargo-machete] # Optional deps gated by features — cargo-machete sees them as unused at # default features but they ARE consumed when the relevant feature is on. -ignored = ["embedded-batteries", "espi-device", "uuid"] +ignored = ["embedded-batteries", "espi-device", "uuid", "crc"] [features] default = [] espi = ["dep:espi-device"] defmt = ["dep:defmt", "embedded-batteries/defmt"] +serial = ["dep:crc"] [dependencies] espi-device = { git = "https://github.com/OpenDevicePartnership/haf-ec-service", optional = true } bit-register = { git = "https://github.com/OpenDevicePartnership/odp-utilities", package = "bit-register" } +crc = { version = "3.4", default-features = false, optional = true } num_enum = { version = "0.7.4", default-features = false } smbus-pec = "1.0.1" thiserror = { version = "2.0.16", default-features = false } diff --git a/mctp-rs/src/buffer_encoding.rs b/mctp-rs/src/buffer_encoding.rs new file mode 100644 index 00000000..96908e5a --- /dev/null +++ b/mctp-rs/src/buffer_encoding.rs @@ -0,0 +1,248 @@ +//! Stateless byte-level buffer-encoding transform for MCTP media. +//! +//! Most media (SMBus/eSPI) ship MCTP packets verbatim — wire bytes ARE +//! payload bytes. Some media (DSP0253 serial) need byte-stuffing: an +//! escape character expands certain payload bytes into 2-byte sequences +//! on the wire, and decode reverses that transform. +//! +//! [`BufferEncoding`] is the byte-stuffing layer ONLY. It is stateless: +//! [`write_byte`](BufferEncoding::write_byte) and +//! [`read_byte`](BufferEncoding::read_byte) are associated functions with +//! no `self` and no struct state. Higher-level framing concerns +//! (start/end delimiters, FCS / CRC) live on the medium type, not here. + +use core::marker::PhantomData; + +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub enum EncodeError { + /// `wire_buf` did not have room for the encoded bytes (1 for plain, + /// up to 2 for an escape sequence). The caller should advance no + /// cursors and treat the encode as failed. + BufferFull, +} + +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub enum DecodeError { + /// `wire_buf` was empty or ended mid-escape-sequence. Indicates the + /// caller asked to decode past the end of valid wire data. + PrematureEnd, + /// An escape byte was followed by a byte not in the medium's + /// accept-list (strict-XOR rule per RFC1662 §4.2 / DSP0253 §6.4). + /// The caller should reject the entire frame. Reachable via + /// `SerialEncoding` when the byte following an escape (`0x7D`) is + /// neither `0x5E` nor `0x5D`. + InvalidEscape, +} + +/// Stateless byte-stuffing transform. Implementors define how a single +/// logical (payload) byte maps to one or more wire bytes (encode) and +/// how a wire-byte prefix maps back to a single payload byte (decode). +/// +/// All methods are associated functions — there is no `self` and no +/// struct state. Callers own the buffers and the read/write cursors. +pub trait BufferEncoding { + /// Encode one logical payload byte into `wire_buf` starting at + /// index 0. Returns the number of wire bytes written (1 for plain, + /// 2 for an escape sequence). The caller advances their write + /// cursor by the returned count. + fn write_byte(wire_buf: &mut [u8], byte: u8) -> Result; + + /// Decode the next logical payload byte from `wire_buf` starting at + /// index 0. Returns `(decoded_byte, wire_bytes_consumed)`. The + /// caller advances their read cursor by `wire_bytes_consumed`. + fn read_byte(wire_buf: &[u8]) -> Result<(u8, usize), DecodeError>; + + /// Wire-byte footprint of `decoded` under this encoding. Must equal + /// the sum of `write_byte(_, b)` lengths for each `b` in `decoded`. + /// NO default impl: every encoding declares its sizing rule + /// explicitly. + fn wire_size_of(decoded: &[u8]) -> usize; +} + +/// No-op encoding: wire bytes ARE payload bytes. Used by media that do +/// not byte-stuff (SMBus/eSPI, test fixtures). +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub struct PassthroughEncoding; + +impl BufferEncoding for PassthroughEncoding { + fn write_byte(wire_buf: &mut [u8], byte: u8) -> Result { + match wire_buf.first_mut() { + Some(slot) => { + *slot = byte; + Ok(1) + } + None => Err(EncodeError::BufferFull), + } + } + + fn read_byte(wire_buf: &[u8]) -> Result<(u8, usize), DecodeError> { + match wire_buf.first() { + Some(&byte) => Ok((byte, 1)), + None => Err(DecodeError::PrematureEnd), + } + } + + fn wire_size_of(decoded: &[u8]) -> usize { + decoded.len() + } +} + +/// Stateful cursor over a `&[u8]` wire buffer that reads decoded bytes +/// through `E: BufferEncoding`. Constructed by [`MctpMedium::deserialize`] +/// and handed to higher layers so they cannot bypass the encoding by +/// slicing the underlying buffer directly. +/// +/// [`MctpMedium::deserialize`]: crate::medium::MctpMedium::deserialize +pub struct EncodingDecoder<'buf, E: BufferEncoding> { + buf: &'buf [u8], + wire_pos: usize, + _phantom: PhantomData, +} + +impl<'buf, E: BufferEncoding> EncodingDecoder<'buf, E> { + /// Wrap a wire-byte buffer for stateful encoding-mediated reads. + pub fn new(buf: &'buf [u8]) -> Self { + Self { + buf, + wire_pos: 0, + _phantom: PhantomData, + } + } + + /// Read one decoded byte. Advances the wire cursor by the encoding's + /// per-byte wire footprint. Returns `DecodeError::PrematureEnd` when + /// the wire buffer is exhausted (or ends mid-escape) and + /// `DecodeError::InvalidEscape` for malformed escape sequences. + pub fn read(&mut self) -> Result { + let (byte, n) = E::read_byte(&self.buf[self.wire_pos..])?; + self.wire_pos += n; + Ok(byte) + } +} + +/// Stateful cursor over a `&mut [u8]` wire buffer that writes decoded +/// bytes through `E: BufferEncoding`. Constructed by +/// [`MctpMedium::serialize`] and handed to the caller's `message_writer` +/// closure so the closure cannot bypass the encoding. +/// +/// [`MctpMedium::serialize`]: crate::medium::MctpMedium::serialize +pub struct EncodingEncoder<'buf, E: BufferEncoding> { + buf: &'buf mut [u8], + wire_pos: usize, + _phantom: PhantomData, +} + +impl<'buf, E: BufferEncoding> EncodingEncoder<'buf, E> { + /// Wrap a wire-byte buffer for stateful encoding-mediated writes. + pub fn new(buf: &'buf mut [u8]) -> Self { + Self { + buf, + wire_pos: 0, + _phantom: PhantomData, + } + } + + /// Write one decoded byte. Advances the wire cursor by the encoding's + /// per-byte wire footprint. Returns `EncodeError::BufferFull` when + /// the underlying wire buffer cannot fit the encoded representation. + pub fn write(&mut self, byte: u8) -> Result<(), EncodeError> { + let n = E::write_byte(&mut self.buf[self.wire_pos..], byte)?; + self.wire_pos += n; + Ok(()) + } + + /// Write a contiguous slice of decoded bytes; aborts on the first + /// encode error. Equivalent to a `for &b in bytes { self.write(b)? }` + /// loop, but more concise at call sites that just splat a byte slice. + pub fn write_all(&mut self, bytes: &[u8]) -> Result<(), EncodeError> { + for &b in bytes { + self.write(b)?; + } + Ok(()) + } + + /// Wire bytes written so far (the size of the produced wire frame). + pub fn wire_position(&self) -> usize { + self.wire_pos + } + + /// Wire bytes remaining in the underlying buffer. + pub fn remaining_wire(&self) -> usize { + self.buf.len() - self.wire_pos + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn passthrough_write_byte_writes_one_byte() { + let mut buf = [0u8; 4]; + let n = PassthroughEncoding::write_byte(&mut buf, 0xAB).unwrap(); + assert_eq!(n, 1); + assert_eq!(buf, [0xAB, 0, 0, 0]); + } + + #[test] + fn passthrough_write_byte_full_buffer() { + let mut buf = []; + let err = PassthroughEncoding::write_byte(&mut buf, 0xAB).unwrap_err(); + assert_eq!(err, EncodeError::BufferFull); + } + + #[test] + fn passthrough_read_byte_reads_one_byte() { + let buf = [0xAB, 0xCD]; + let (b, n) = PassthroughEncoding::read_byte(&buf).unwrap(); + assert_eq!(b, 0xAB); + assert_eq!(n, 1); + } + + #[test] + fn passthrough_read_byte_premature_end() { + let buf = []; + let err = PassthroughEncoding::read_byte(&buf).unwrap_err(); + assert_eq!(err, DecodeError::PrematureEnd); + } + + #[test] + fn decoder_reads_all_bytes_via_passthrough() { + let buf = [0xAA, 0xBB, 0xCC, 0xDD]; + let mut decoder = EncodingDecoder::::new(&buf); + assert_eq!(decoder.read().unwrap(), 0xAA); + assert_eq!(decoder.read().unwrap(), 0xBB); + assert_eq!(decoder.read().unwrap(), 0xCC); + assert_eq!(decoder.read().unwrap(), 0xDD); + assert_eq!(decoder.read().unwrap_err(), DecodeError::PrematureEnd); + } + + #[test] + fn encoder_writes_all_bytes_via_passthrough() { + let mut buf = [0u8; 4]; + { + let mut encoder = EncodingEncoder::::new(&mut buf); + assert_eq!(encoder.wire_position(), 0); + assert_eq!(encoder.remaining_wire(), 4); + encoder.write(0x11).unwrap(); + encoder.write(0x22).unwrap(); + encoder.write(0x33).unwrap(); + encoder.write(0x44).unwrap(); + assert_eq!(encoder.wire_position(), 4); + assert_eq!(encoder.remaining_wire(), 0); + assert_eq!(encoder.write(0x55).unwrap_err(), EncodeError::BufferFull); + } + assert_eq!(buf, [0x11, 0x22, 0x33, 0x44]); + } + + #[test] + fn passthrough_wire_size_of_returns_input_len() { + assert_eq!(PassthroughEncoding::wire_size_of(&[]), 0); + assert_eq!(PassthroughEncoding::wire_size_of(&[0xAB]), 1); + let buf = [0u8; 64]; + assert_eq!(PassthroughEncoding::wire_size_of(&buf), 64); + } +} diff --git a/mctp-rs/src/deserialize.rs b/mctp-rs/src/deserialize.rs index 0a0ddd63..f1e0b3f5 100644 --- a/mctp-rs/src/deserialize.rs +++ b/mctp-rs/src/deserialize.rs @@ -1,25 +1,45 @@ use crate::{ - MctpMessageBuffer, MctpPacketError, error::MctpPacketResult, mctp_transport_header::MctpTransportHeader, + MctpMessageBuffer, MctpPacketError, + buffer_encoding::{DecodeError, EncodingDecoder}, + error::MctpPacketResult, + mctp_transport_header::MctpTransportHeader, medium::MctpMedium, }; +pub(crate) fn map_decode_err( + e: DecodeError, + on_premature: &'static str, + on_escape: &'static str, +) -> MctpPacketError { + match e { + DecodeError::PrematureEnd => MctpPacketError::HeaderParseError(on_premature), + DecodeError::InvalidEscape => MctpPacketError::HeaderParseError(on_escape), + } +} + pub(crate) fn parse_transport_header( - packet: &[u8], -) -> MctpPacketResult<(MctpTransportHeader, &[u8]), M> { - if packet.len() < 4 { - return Err(MctpPacketError::HeaderParseError( - "Packet is too small, cannot parse transport header", - )); + decoder: &mut EncodingDecoder<'_, M::Encoding>, +) -> MctpPacketResult { + // Read 4 decoded bytes through the encoding-aware decoder. We do NOT + // pre-check `decoder.remaining_wire() < 4` because for stuffing + // encodings wire length is not decoded length; PrematureEnd from + // `read()` is the canonical "ran out of bytes while decoding the + // header" signal — it correctly handles BOTH the Passthrough case + // (wire < 4) AND the stuffing case (wire >= 4 but yields < 4 decoded + // bytes). + let mut header_bytes = [0u8; 4]; + for slot in header_bytes.iter_mut() { + *slot = decoder.read().map_err(|e| { + map_decode_err::( + e, + "Packet is too small, cannot parse transport header", + "Invalid encoding escape sequence in transport header", + ) + })?; } - let transport_header_value = u32::from_be_bytes( - packet[0..4] - .try_into() - .map_err(|_| MctpPacketError::HeaderParseError("Packet is too small, cannot parse transport header"))?, - ); - let transport_header = MctpTransportHeader::try_from(transport_header_value) - .map_err(|_| MctpPacketError::HeaderParseError("Invalid transport header"))?; - let packet = &packet[4..]; - Ok((transport_header, packet)) + let transport_header_value = u32::from_be_bytes(header_bytes); + MctpTransportHeader::try_from(transport_header_value) + .map_err(|_| MctpPacketError::HeaderParseError("Invalid transport header")) } pub(crate) fn parse_message_body( diff --git a/mctp-rs/src/lib.rs b/mctp-rs/src/lib.rs index 718fc22e..7fa1809b 100644 --- a/mctp-rs/src/lib.rs +++ b/mctp-rs/src/lib.rs @@ -14,10 +14,10 @@ //! # use mctp_rs::*; //! # #[derive(Debug, Clone, Copy)] struct MyMedium { mtu: usize } //! # #[derive(Debug, Clone, Copy)] struct MyMediumFrame { packet_size: usize } -//! # impl MctpMedium for MyMedium { type Frame=MyMediumFrame; type Error=&'static str; type ReplyContext=(); +//! # impl MctpMedium for MyMedium { type Frame=MyMediumFrame; type Error=&'static str; type ReplyContext=(); type Encoding=PassthroughEncoding; //! # fn max_message_body_size(&self)->usize{self.mtu} -//! # fn deserialize<'b>(&self,p:&'b [u8])->MctpPacketResult<(Self::Frame,&'b [u8]),Self>{Ok((MyMediumFrame{packet_size:p.len()},p))} -//! # fn serialize<'b,F>(&self,_:Self::ReplyContext,b:&'b mut [u8],w:F)->MctpPacketResult<&'b [u8],Self> where F: for<'a> FnOnce(&'a mut [u8])->MctpPacketResult{let n=w(b)?;Ok(&b[..n])}} +//! # fn deserialize<'b>(&self,p:&'b [u8])->MctpPacketResult<(Self::Frame,EncodingDecoder<'b,Self::Encoding>),Self>{Ok((MyMediumFrame{packet_size:p.len()},EncodingDecoder::new(p)))} +//! # fn serialize<'b,F>(&self,_:Self::ReplyContext,b:&'b mut [u8],w:F)->MctpPacketResult<&'b [u8],Self> where F: for<'a> FnOnce(&mut EncodingEncoder<'a,Self::Encoding>)->MctpPacketResult<(),Self>{let n={let mut e=EncodingEncoder::::new(b);w(&mut e)?;e.wire_position()};Ok(&b[..n])}} //! # impl MctpMediumFrame for MyMediumFrame { fn packet_size(&self)->usize{self.packet_size} fn reply_context(&self)->(){()}} //! let mut assembly_buffer = [0u8; 1024]; //! let medium = MyMedium { mtu: 256 }; @@ -59,9 +59,9 @@ //! # use mctp_rs::*; //! # #[derive(Debug, Clone, Copy)] struct MyMedium { mtu: usize } //! # #[derive(Debug, Clone, Copy)] struct MyMediumFrame { packet_size: usize } -//! # impl MctpMedium for MyMedium { type Frame=MyMediumFrame; type Error=&'static str; type ReplyContext=(); fn max_message_body_size(&self)->usize{self.mtu} -//! # fn deserialize<'b>(&self,p:&'b [u8])->MctpPacketResult<(Self::Frame,&'b [u8]),Self>{Ok((MyMediumFrame{packet_size:p.len()},p))} -//! # fn serialize<'b,F>(&self,_:Self::ReplyContext,b:&'b mut [u8],w:F)->MctpPacketResult<&'b [u8],Self> where F: for<'a> FnOnce(&'a mut [u8])->MctpPacketResult{let n=w(b)?;Ok(&b[..n])}} +//! # impl MctpMedium for MyMedium { type Frame=MyMediumFrame; type Error=&'static str; type ReplyContext=(); type Encoding=PassthroughEncoding; fn max_message_body_size(&self)->usize{self.mtu} +//! # fn deserialize<'b>(&self,p:&'b [u8])->MctpPacketResult<(Self::Frame,EncodingDecoder<'b,Self::Encoding>),Self>{Ok((MyMediumFrame{packet_size:p.len()},EncodingDecoder::new(p)))} +//! # fn serialize<'b,F>(&self,_:Self::ReplyContext,b:&'b mut [u8],w:F)->MctpPacketResult<&'b [u8],Self> where F: for<'a> FnOnce(&mut EncodingEncoder<'a,Self::Encoding>)->MctpPacketResult<(),Self>{let n={let mut e=EncodingEncoder::::new(b);w(&mut e)?;e.wire_position()};Ok(&b[..n])}} //! # impl MctpMediumFrame for MyMediumFrame { fn packet_size(&self)->usize{self.packet_size} fn reply_context(&self)->(){()}} //! let mut buf = [0u8; 1024]; //! let mut ctx = MctpPacketContext::new(MyMedium { mtu: 64 }, &mut buf); @@ -109,6 +109,7 @@ //! type Frame = MyMediumFrame; //! type Error = &'static str; //! type ReplyContext = (); +//! type Encoding = PassthroughEncoding; //! //! fn max_message_body_size(&self) -> usize { //! self.mtu @@ -117,13 +118,13 @@ //! fn deserialize<'buf>( //! &self, //! packet: &'buf [u8], -//! ) -> MctpPacketResult<(Self::Frame, &'buf [u8]), Self> { +//! ) -> MctpPacketResult<(Self::Frame, EncodingDecoder<'buf, Self::Encoding>), Self> { //! // Strip/validate transport headers as needed for your bus and return MCTP payload slice //! Ok(( //! MyMediumFrame { //! packet_size: packet.len(), //! }, -//! packet, +//! EncodingDecoder::new(packet), //! )) //! } //! @@ -134,11 +135,17 @@ //! message_writer: F, //! ) -> MctpPacketResult<&'buf [u8], Self> //! where -//! F: for<'a> FnOnce(&'a mut [u8]) -> MctpPacketResult, +//! F: for<'a> FnOnce( +//! &mut EncodingEncoder<'a, Self::Encoding>, +//! ) -> MctpPacketResult<(), Self>, //! { //! // Prepend transport headers as needed, then ask the writer to write MCTP payload -//! let message_len = message_writer(buffer)?; -//! Ok(&buffer[..message_len]) +//! let written = { +//! let mut encoder = EncodingEncoder::::new(buffer); +//! message_writer(&mut encoder)?; +//! encoder.wire_position() +//! }; +//! Ok(&buffer[..written]) //! } //! } //! @@ -152,6 +159,7 @@ //! } //! ``` +mod buffer_encoding; mod deserialize; mod endpoint_id; pub mod error; @@ -167,11 +175,16 @@ mod serialize; #[cfg(test)] mod test_util; +pub use buffer_encoding::{ + BufferEncoding, DecodeError, EncodeError, EncodingDecoder, EncodingEncoder, PassthroughEncoding, +}; pub use endpoint_id::EndpointId; pub use error::{MctpPacketError, MctpPacketResult}; pub use mctp_message_tag::MctpMessageTag; pub use mctp_packet_context::{MctpPacketContext, MctpReplyContext}; pub use mctp_sequence_number::MctpSequenceNumber; +#[cfg(feature = "serial")] +pub use medium::serial::{CONST_MTU, EC_EID, MctpSerialMedium, MctpSerialMediumFrame, SP_EID, SerialEncoding}; pub use medium::*; pub use message_type::*; diff --git a/mctp-rs/src/mctp_packet_context.rs b/mctp-rs/src/mctp_packet_context.rs index f22af6fd..893b333c 100644 --- a/mctp-rs/src/mctp_packet_context.rs +++ b/mctp-rs/src/mctp_packet_context.rs @@ -1,6 +1,6 @@ use crate::{ MctpMessage, MctpMessageHeaderTrait, MctpMessageTrait, MctpPacketError, - deserialize::{parse_message_body, parse_transport_header}, + deserialize::{map_decode_err, parse_message_body, parse_transport_header}, endpoint_id::EndpointId, error::{MctpPacketResult, ProtocolError}, mctp_message_tag::MctpMessageTag, @@ -40,8 +40,8 @@ impl<'buf, M: MctpMedium> MctpPacketContext<'buf, M> { } pub fn deserialize_packet(&mut self, packet: &[u8]) -> MctpPacketResult>, M> { - let (medium_frame, packet) = self.medium.deserialize(packet)?; - let (transport_header, packet) = parse_transport_header::(packet)?; + let (medium_frame, mut decoder) = self.medium.deserialize(packet)?; + let transport_header = parse_transport_header::(&mut decoder)?; let mut state = match self.assembly_state { AssemblyState::Idle => { @@ -100,16 +100,28 @@ impl<'buf, M: MctpMedium> MctpPacketContext<'buf, M> { )); } let packet_size = packet_size - 4; // to account for the transport header - if packet.len() < packet_size { - return Err(MctpPacketError::HeaderParseError("packet.len() < packet_size")); - } - // Check bounds to prevent buffer overflow + // Check assembly buffer bounds (decoded bytes destination) if buffer_idx + packet_size > self.packet_assembly_buffer.len() { return Err(MctpPacketError::HeaderParseError( "packet assembly buffer overflow - insufficient space", )); } - self.packet_assembly_buffer[buffer_idx..buffer_idx + packet_size].copy_from_slice(&packet[..packet_size]); + // Decode `packet_size` payload bytes from the (possibly stuffed) wire + // buffer into the assembly buffer one byte at a time via the + // medium-supplied decoder. We do NOT pre-check + // `decoder.remaining_wire() < packet_size` because for stuffing + // encodings wire length is not decoded length; PrematureEnd from + // `read()` is the canonical "ran out of bytes while decoding the + // body" signal. + for i in 0..packet_size { + self.packet_assembly_buffer[buffer_idx + i] = decoder.read().map_err(|e| { + map_decode_err::( + e, + "packet body too short to extract expected decoded bytes", + "Invalid encoding escape sequence in packet body", + ) + })?; + } state.packet_assembly_buffer_index += packet_size; let message = if transport_header.end_of_message == 1 { diff --git a/mctp-rs/src/medium/mod.rs b/mctp-rs/src/medium/mod.rs index fb49a9ad..693deda4 100644 --- a/mctp-rs/src/medium/mod.rs +++ b/mctp-rs/src/medium/mod.rs @@ -1,8 +1,14 @@ -use crate::error::MctpPacketResult; +use crate::{ + buffer_encoding::{BufferEncoding, EncodingDecoder, EncodingEncoder}, + error::MctpPacketResult, +}; pub mod smbus_espi; mod util; +#[cfg(feature = "serial")] +pub mod serial; + pub trait MctpMedium: Sized { /// the medium specific header and trailer for the packet type Frame: MctpMediumFrame; @@ -13,14 +19,34 @@ pub trait MctpMedium: Sized { // the type used for the data needed to send a reply to a request type ReplyContext: core::fmt::Debug + Copy + Clone + PartialEq + Eq; + /// the byte-stuffing transform used by this medium when (de)serializing + /// wire bytes. Stateless — see [`crate::buffer_encoding`]. Most media + /// use [`PassthroughEncoding`](crate::buffer_encoding::PassthroughEncoding) + /// (no transform); media that need byte-stuffing (e.g., DSP0253 serial) + /// supply their own impl. + type Encoding: BufferEncoding; + /// the maximum transmission unit for the medium fn max_message_body_size(&self) -> usize; - /// deserialize the packet into the medium specific header and remainder of the packet - - /// this includes the mctp transport header, and mctp packet payload - fn deserialize<'buf>(&self, packet: &'buf [u8]) -> MctpPacketResult<(Self::Frame, &'buf [u8]), Self>; - - /// serialize the packet into the medium specific header and the payload + /// Deserialize a packet into the medium-specific header (frame) and an + /// [`EncodingDecoder`] that wraps the inner stuffed-region bytes. + /// Higher layers (e.g., `parse_transport_header`, the payload copy + /// loop in `MctpPacketContext`) read decoded bytes through the + /// returned decoder and physically cannot bypass the medium's + /// encoding by slicing the underlying buffer directly. + fn deserialize<'buf>( + &self, + packet: &'buf [u8], + ) -> MctpPacketResult<(Self::Frame, EncodingDecoder<'buf, Self::Encoding>), Self>; + + /// Serialize a packet by allowing the caller's `message_writer` + /// closure to write decoded bytes into the medium's stuffed region + /// through an [`EncodingEncoder`]. The medium owns its outer framing + /// (e.g., SMBus header + PEC, DSP0253 start/end flags + FCS) and + /// inspects the encoder's + /// [`wire_position`](EncodingEncoder::wire_position) after the + /// closure returns to size headers/trailers and compute checksums. fn serialize<'buf, F>( &self, reply_context: Self::ReplyContext, @@ -28,7 +54,7 @@ pub trait MctpMedium: Sized { message_writer: F, ) -> MctpPacketResult<&'buf [u8], Self> where - F: for<'a> FnOnce(&'a mut [u8]) -> MctpPacketResult; + F: for<'a> FnOnce(&mut EncodingEncoder<'a, Self::Encoding>) -> MctpPacketResult<(), Self>; } pub trait MctpMediumFrame: Clone + Copy { diff --git a/mctp-rs/src/medium/serial.rs b/mctp-rs/src/medium/serial.rs new file mode 100644 index 00000000..9075ba1f --- /dev/null +++ b/mctp-rs/src/medium/serial.rs @@ -0,0 +1,716 @@ +//! DSP0253 byte-stuffed serial medium for MCTP. +//! +//! Two-layer split: +//! - [`SerialEncoding`]: stateless byte-stuffing (0x7E, 0x7D escape pair). +//! - [`MctpSerialMedium`]: framing (revision byte, byte_count, body, FCS-16, end-flag). +//! +//! Both layers are gated behind the `serial` cargo feature. + +use crate::{ + MctpPacketError, + buffer_encoding::{BufferEncoding, DecodeError, EncodeError, EncodingDecoder, EncodingEncoder}, + error::MctpPacketResult, + medium::{MctpMedium, MctpMediumFrame}, +}; + +/// DSP0253 byte-stuffing transform. Stateless ZST. +/// +/// Encode: `0x7E -> [0x7D, 0x5E]`, `0x7D -> [0x7D, 0x5D]`, any other +/// byte -> `[b]`. +/// Decode: `0x7D 0x5E -> 0x7E`, `0x7D 0x5D -> 0x7D`, `0x7D ` -> +/// `InvalidEscape`. +/// +/// Raw `0x7E` in the wire stream is NOT rejected here — that's a +/// framing concern owned by `MctpSerialMedium::deserialize`, which +/// checks the body region for stray flags. +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub struct SerialEncoding; + +impl BufferEncoding for SerialEncoding { + fn write_byte(wire_buf: &mut [u8], byte: u8) -> Result { + match byte { + 0x7E => { + if wire_buf.len() < 2 { + return Err(EncodeError::BufferFull); + } + wire_buf[0] = 0x7D; + wire_buf[1] = 0x5E; + Ok(2) + } + 0x7D => { + if wire_buf.len() < 2 { + return Err(EncodeError::BufferFull); + } + wire_buf[0] = 0x7D; + wire_buf[1] = 0x5D; + Ok(2) + } + b => match wire_buf.first_mut() { + Some(slot) => { + *slot = b; + Ok(1) + } + None => Err(EncodeError::BufferFull), + }, + } + } + + fn read_byte(wire_buf: &[u8]) -> Result<(u8, usize), DecodeError> { + match wire_buf.first().copied() { + None => Err(DecodeError::PrematureEnd), + Some(0x7D) => match wire_buf.get(1).copied() { + None => Err(DecodeError::PrematureEnd), + Some(0x5E) => Ok((0x7E, 2)), + Some(0x5D) => Ok((0x7D, 2)), + Some(_) => Err(DecodeError::InvalidEscape), + }, + // Raw 0x7E falls through here as a 1-byte read; the framing + // layer (`MctpSerialMedium::deserialize`) rejects bare + // 0x7E inside the body region. + Some(b) => Ok((b, 1)), + } + } + + fn wire_size_of(decoded: &[u8]) -> usize { + decoded + .iter() + .map(|&b| if b == 0x7E || b == 0x7D { 2 } else { 1 }) + .sum() + } +} + +/// SP MCTP endpoint id per CONTEXT D-D-06. +pub const SP_EID: crate::endpoint_id::EndpointId = crate::endpoint_id::EndpointId::Id(0x08); +/// EC MCTP endpoint id per CONTEXT D-D-06. +pub const EC_EID: crate::endpoint_id::EndpointId = crate::endpoint_id::EndpointId::Id(0x0A); +/// Maximum DSP0253 packet body size (DECODED bytes, before stuffing). +pub const CONST_MTU: usize = 251; + +const SERIAL_REVISION: u8 = 0x01; +const END_FLAG: u8 = 0x7E; +/// Header bytes: revision + byte_count (decoded body byte count). +const HEADER_LEN: usize = 2; +/// Worst-case trailer wire bytes: 2 stuffed FCS bytes (each may +/// expand 1 -> 2) + 1 end-flag. +const MAX_TRAILER_WIRE: usize = 5; + +// CRC-16/X-25 per DSP0253 §8 (poly 0x1021, init 0xFFFF, refin/refout, +// xorout 0xFFFF). Algorithm catalog entry locked in CONTEXT D-D-02. +// FCS bytes on the wire are MSB-first per DSP0253 §5.2 (overrides +// RFC1662's LSB-first PPP convention). +const FCS_ALGO: crc::Crc = crc::Crc::::new(&crc::CRC_16_IBM_SDLC); + +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub struct MctpSerialMediumFrame { + pub revision: u8, + /// DECODED body byte count per DSP0253 §6.2 (NOT the wire byte + /// count). Cap = `CONST_MTU` = 251; max u8 = 255, fits comfortably. + pub byte_count: u8, + pub fcs: u16, +} + +impl MctpMediumFrame for MctpSerialMediumFrame { + fn packet_size(&self) -> usize { + // packet_size is the DECODED body byte count — the contract + // used by `MctpPacketContext::deserialize_packet`, which then + // subtracts 4 for the transport header. + self.byte_count as usize + } + + fn reply_context(&self) {} +} + +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub struct MctpSerialMedium; + +impl MctpMedium for MctpSerialMedium { + type Frame = MctpSerialMediumFrame; + type Error = &'static str; + type ReplyContext = (); + type Encoding = SerialEncoding; + + fn max_message_body_size(&self) -> usize { + CONST_MTU + } + + fn deserialize<'buf>( + &self, + packet: &'buf [u8], + ) -> MctpPacketResult<(Self::Frame, EncodingDecoder<'buf, Self::Encoding>), Self> { + // Minimum frame: 2 header + 0 body + 2 FCS (unstuffed) + 1 end-flag = 5 bytes. + if packet.len() < HEADER_LEN + 3 { + return Err(MctpPacketError::MediumError("packet too short for serial frame")); + } + let revision = packet[0]; + if revision != SERIAL_REVISION { + return Err(MctpPacketError::MediumError("unsupported serial revision")); + } + let byte_count = packet[1]; + if (byte_count as usize) > CONST_MTU { + return Err(MctpPacketError::MediumError("byte_count exceeds MTU")); + } + + // Single forward walk: un-stuff body bytes (count must equal + // `byte_count`), un-stuff 2 FCS bytes, expect end-flag, compare + // CRC. + let body_wire_start = HEADER_LEN; + let mut decoded = [0u8; CONST_MTU]; + let mut decoded_len = 0usize; + let mut wire_pos = 0usize; // offset from body_wire_start + + while decoded_len < byte_count as usize { + let (b, n) = SerialEncoding::read_byte(&packet[body_wire_start + wire_pos..]).map_err(|e| match e { + DecodeError::PrematureEnd => MctpPacketError::MediumError("premature end in body"), + DecodeError::InvalidEscape => MctpPacketError::MediumError("invalid escape in body"), + })?; + if b == END_FLAG && n == 1 { + // Bare (unstuffed) 0x7E inside the body region is a + // protocol error (MEDIUM-05). A decoded 0x7E whose wire + // representation was the stuffed pair `0x7D 0x5E` + // (n==2) is a legitimate payload byte and is kept. + return Err(MctpPacketError::MediumError("unexpected 0x7E in body")); + } + decoded[decoded_len] = b; + decoded_len += 1; + wire_pos += n; + } + let body_wire_end = body_wire_start + wire_pos; + + // Un-stuff 2 FCS bytes (DSP0253 §7.1 stuffing applies to FCS). + let (fcs_msb, n_msb) = SerialEncoding::read_byte(&packet[body_wire_end..]) + .map_err(|_| MctpPacketError::MediumError("invalid escape in fcs"))?; + let (fcs_lsb, n_lsb) = SerialEncoding::read_byte(&packet[body_wire_end + n_msb..]) + .map_err(|_| MctpPacketError::MediumError("invalid escape in fcs"))?; + let trailer_pos = body_wire_end + n_msb + n_lsb; + + if trailer_pos >= packet.len() || packet[trailer_pos] != END_FLAG { + return Err(MctpPacketError::MediumError("missing end flag")); + } + if trailer_pos + 1 != packet.len() { + return Err(MctpPacketError::MediumError("trailing bytes after end flag")); + } + + // FCS-16/X-25 over un-stuffed (revision || byte_count || decoded body). + let mut digest = FCS_ALGO.digest(); + digest.update(&[revision, byte_count]); + digest.update(&decoded[..decoded_len]); + let computed_fcs = digest.finalize(); + // DSP0253 §5.2: MSB first on wire. + let wire_fcs = u16::from_be_bytes([fcs_msb, fcs_lsb]); + if wire_fcs != computed_fcs { + return Err(MctpPacketError::MediumError("fcs mismatch")); + } + + Ok(( + MctpSerialMediumFrame { + revision, + byte_count, + fcs: wire_fcs, + }, + EncodingDecoder::::new(&packet[body_wire_start..body_wire_end]), + )) + } + + fn serialize<'buf, F>( + &self, + _reply_context: Self::ReplyContext, + buffer: &'buf mut [u8], + message_writer: F, + ) -> MctpPacketResult<&'buf [u8], Self> + where + F: for<'a> FnOnce(&mut EncodingEncoder<'a, Self::Encoding>) -> MctpPacketResult<(), Self>, + { + if buffer.len() < HEADER_LEN + MAX_TRAILER_WIRE { + return Err(MctpPacketError::MediumError("buffer too small for serial frame")); + } + let buffer_len = buffer.len(); + + // Run closure over body region (reserve worst-case 5-byte + // trailer). The encoder stuffs body bytes via + // `SerialEncoding::write_byte` automatically. + let body_wire_len = { + let body_buf = &mut buffer[HEADER_LEN..buffer_len - MAX_TRAILER_WIRE]; + let mut encoder = EncodingEncoder::::new(body_buf); + message_writer(&mut encoder)?; + encoder.wire_position() + }; + + // Re-decode body to recover DECODED bytes + decoded count for + // `byte_count` and FCS. CONTEXT D-B-02 acknowledges the + // double-walk; ~250 bytes max, no_std, cheap. + let mut decoded = [0u8; CONST_MTU]; + let mut decoded_len = 0usize; + let mut wire_pos = 0usize; + while wire_pos < body_wire_len { + let (b, n) = SerialEncoding::read_byte(&buffer[HEADER_LEN + wire_pos..HEADER_LEN + body_wire_len]) + .map_err(|_| MctpPacketError::MediumError("internal: failed to re-decode body"))?; + if decoded_len >= CONST_MTU { + return Err(MctpPacketError::MediumError("body exceeds MTU")); + } + decoded[decoded_len] = b; + decoded_len += 1; + wire_pos += n; + } + // Should not fire — `EncodingEncoder::write` returns + // `BufferFull` long before decoded_len could exceed 251. + if decoded_len > u8::MAX as usize { + return Err(MctpPacketError::MediumError("body exceeds byte_count u8 cap")); + } + let byte_count = decoded_len as u8; + + // FCS-16/X-25 over un-stuffed (revision || byte_count || decoded body). + let mut digest = FCS_ALGO.digest(); + digest.update(&[SERIAL_REVISION, byte_count]); + digest.update(&decoded[..decoded_len]); + let fcs = digest.finalize(); + // DSP0253 §5.2: MSB first on wire. + let [fcs_msb, fcs_lsb] = fcs.to_be_bytes(); + + // Header: revision + byte_count emitted directly (NOT stuffed), + // matching `SmbusEspiMedium`'s header pattern. See PLAN + // note for the conformance caveat when byte_count + // happens to equal 0x7E or 0x7D — round-trips cleanly through + // this implementation's deserialize. + buffer[0] = SERIAL_REVISION; + buffer[1] = byte_count; + + // Stuff and write FCS bytes via SerialEncoding (DSP0253 §7.1 + + // CONTEXT D-B-02 — deserialize un-stuffs FCS, so serialize + // must stuff). + let fcs_start = HEADER_LEN + body_wire_len; + let n_msb = SerialEncoding::write_byte(&mut buffer[fcs_start..], fcs_msb) + .map_err(|_| MctpPacketError::MediumError("internal: failed to encode fcs"))?; + let n_lsb = SerialEncoding::write_byte(&mut buffer[fcs_start + n_msb..], fcs_lsb) + .map_err(|_| MctpPacketError::MediumError("internal: failed to encode fcs"))?; + let end_pos = fcs_start + n_msb + n_lsb; + + // End-flag is written directly (flags are NOT stuffed by + // definition). + buffer[end_pos] = END_FLAG; + + Ok(&buffer[..end_pos + 1]) + } +} + +#[cfg(test)] +mod encoding_tests { + use super::*; + use crate::buffer_encoding::EncodingDecoder; + + #[test] + fn write_byte_stuffs_7e() { + let mut buf = [0u8; 4]; + let n = SerialEncoding::write_byte(&mut buf, 0x7E).unwrap(); + assert_eq!(n, 2); + assert_eq!(&buf[..2], &[0x7D, 0x5E]); + } + + #[test] + fn write_byte_stuffs_7d() { + let mut buf = [0u8; 4]; + let n = SerialEncoding::write_byte(&mut buf, 0x7D).unwrap(); + assert_eq!(n, 2); + assert_eq!(&buf[..2], &[0x7D, 0x5D]); + } + + #[test] + fn write_byte_passthrough_plain() { + let mut buf = [0u8; 1]; + let n = SerialEncoding::write_byte(&mut buf, 0x41).unwrap(); + assert_eq!(n, 1); + assert_eq!(buf, [0x41]); + } + + #[test] + fn write_byte_full_buffer_plain() { + let mut buf = []; + assert_eq!( + SerialEncoding::write_byte(&mut buf, 0x41).unwrap_err(), + EncodeError::BufferFull + ); + } + + #[test] + fn write_byte_full_buffer_escape() { + let mut buf = [0u8; 1]; + assert_eq!( + SerialEncoding::write_byte(&mut buf, 0x7E).unwrap_err(), + EncodeError::BufferFull + ); + } + + #[test] + fn read_byte_unstuffs_7e() { + assert_eq!(SerialEncoding::read_byte(&[0x7D, 0x5E]).unwrap(), (0x7E, 2)); + } + + #[test] + fn read_byte_unstuffs_7d() { + assert_eq!(SerialEncoding::read_byte(&[0x7D, 0x5D]).unwrap(), (0x7D, 2)); + } + + #[test] + fn read_byte_passthrough_plain() { + assert_eq!(SerialEncoding::read_byte(&[0x41]).unwrap(), (0x41, 1)); + } + + #[test] + fn read_byte_raw_7e_passes_through() { + // Raw 0x7E is NOT rejected at the encoding layer — framing is + // the framing layer's concern. + assert_eq!(SerialEncoding::read_byte(&[0x7E]).unwrap(), (0x7E, 1)); + } + + #[test] + fn read_byte_premature_end_empty() { + assert_eq!(SerialEncoding::read_byte(&[]).unwrap_err(), DecodeError::PrematureEnd); + } + + #[test] + fn read_byte_premature_end_after_escape() { + assert_eq!( + SerialEncoding::read_byte(&[0x7D]).unwrap_err(), + DecodeError::PrematureEnd + ); + } + + #[test] + fn read_byte_invalid_escape() { + assert_eq!( + SerialEncoding::read_byte(&[0x7D, 0xAA]).unwrap_err(), + DecodeError::InvalidEscape + ); + } + + #[test] + fn wire_size_of_mixed() { + assert_eq!(SerialEncoding::wire_size_of(&[0x41, 0x7E, 0x42, 0x7D, 0x43]), 7); + } + + #[test] + fn wire_size_of_empty() { + assert_eq!(SerialEncoding::wire_size_of(&[]), 0); + } + + #[test] + fn roundtrip_all_byte_values() { + // 256-byte payload of every byte value, encoded into a 512-byte + // wire buffer (worst case is 2x expansion if every byte stuffs; + // actual expansion here is 256 + 2 = 258 wire bytes). + let mut decoded = [0u8; 256]; + for (i, slot) in decoded.iter_mut().enumerate() { + *slot = i as u8; + } + let mut wire = [0u8; 512]; + let mut wpos = 0usize; + for &b in &decoded { + wpos += SerialEncoding::write_byte(&mut wire[wpos..], b).unwrap(); + } + assert_eq!(wpos, SerialEncoding::wire_size_of(&decoded)); + let mut dec = EncodingDecoder::::new(&wire[..wpos]); + for &expected in &decoded { + assert_eq!(dec.read().unwrap(), expected); + } + assert_eq!(dec.read().unwrap_err(), DecodeError::PrematureEnd); + } +} + +#[cfg(test)] +mod fixtures { + //! Hand-authored DSP0253 serial frame fixtures (golden vectors). + //! + //! Layout per fixture (no leading flag — this implementation omits + //! the open `0x7E` per CONTEXT D-D-01; upstream UART layer supplies + //! it in Phase 27): + //! + //! `[REVISION=0x01, byte_count, ...stuffed body..., ...stuffed FCS-MSB..., ...stuffed + //! FCS-LSB..., 0x7E]` + //! + //! - Header bytes (REVISION, byte_count) are NOT stuffed (matches production serialize). + //! - Body bytes are stuffed per `SerialEncoding`. + //! - FCS-16/X-25 computed over un-stuffed `[REVISION, byte_count, ...decoded body...]`, emitted + //! MSB-first on wire (DSP0253 §5.2), each FCS byte then stuffed if equal to 0x7E or 0x7D. + //! - Trailing `0x7E` is the end-flag (not stuffed by definition). + + pub(crate) const FIXTURE_BASIC_RX: &[u8] = &[0x01, 0x04, 0xAA, 0xBB, 0xCC, 0xDD, 0x6D, 0xA1, 0x7E]; + + pub(crate) const FIXTURE_PAYLOAD_CONTAINS_7E: &[u8] = &[0x01, 0x03, 0xAA, 0x7D, 0x5E, 0xCC, 0xFB, 0xE7, 0x7E]; + + pub(crate) const FIXTURE_PAYLOAD_CONTAINS_7D: &[u8] = &[0x01, 0x03, 0xAA, 0x7D, 0x5D, 0xCC, 0xD1, 0x8F, 0x7E]; + + pub(crate) const FIXTURE_PAYLOAD_CONTAINS_BOTH: &[u8] = + &[0x01, 0x03, 0x7D, 0x5E, 0x7D, 0x5D, 0x42, 0x50, 0x97, 0x7E]; + + /// 251-byte body `(0..251)` decoded; wire = 258 bytes after stuffing + /// the lone 0x7D (idx 125) and 0x7E (idx 126) inside the body. + pub(crate) const FIXTURE_MAX_MTU_FRAME: &[u8] = &[ + 0x01, 0xFB, 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F, + 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x1A, 0x1B, 0x1C, 0x1D, 0x1E, 0x1F, 0x20, 0x21, + 0x22, 0x23, 0x24, 0x25, 0x26, 0x27, 0x28, 0x29, 0x2A, 0x2B, 0x2C, 0x2D, 0x2E, 0x2F, 0x30, 0x31, 0x32, 0x33, + 0x34, 0x35, 0x36, 0x37, 0x38, 0x39, 0x3A, 0x3B, 0x3C, 0x3D, 0x3E, 0x3F, 0x40, 0x41, 0x42, 0x43, 0x44, 0x45, + 0x46, 0x47, 0x48, 0x49, 0x4A, 0x4B, 0x4C, 0x4D, 0x4E, 0x4F, 0x50, 0x51, 0x52, 0x53, 0x54, 0x55, 0x56, 0x57, + 0x58, 0x59, 0x5A, 0x5B, 0x5C, 0x5D, 0x5E, 0x5F, 0x60, 0x61, 0x62, 0x63, 0x64, 0x65, 0x66, 0x67, 0x68, 0x69, + 0x6A, 0x6B, 0x6C, 0x6D, 0x6E, 0x6F, 0x70, 0x71, 0x72, 0x73, 0x74, 0x75, 0x76, 0x77, 0x78, 0x79, 0x7A, 0x7B, + 0x7C, 0x7D, 0x5D, 0x7D, 0x5E, 0x7F, 0x80, 0x81, 0x82, 0x83, 0x84, 0x85, 0x86, 0x87, 0x88, 0x89, 0x8A, 0x8B, + 0x8C, 0x8D, 0x8E, 0x8F, 0x90, 0x91, 0x92, 0x93, 0x94, 0x95, 0x96, 0x97, 0x98, 0x99, 0x9A, 0x9B, 0x9C, 0x9D, + 0x9E, 0x9F, 0xA0, 0xA1, 0xA2, 0xA3, 0xA4, 0xA5, 0xA6, 0xA7, 0xA8, 0xA9, 0xAA, 0xAB, 0xAC, 0xAD, 0xAE, 0xAF, + 0xB0, 0xB1, 0xB2, 0xB3, 0xB4, 0xB5, 0xB6, 0xB7, 0xB8, 0xB9, 0xBA, 0xBB, 0xBC, 0xBD, 0xBE, 0xBF, 0xC0, 0xC1, + 0xC2, 0xC3, 0xC4, 0xC5, 0xC6, 0xC7, 0xC8, 0xC9, 0xCA, 0xCB, 0xCC, 0xCD, 0xCE, 0xCF, 0xD0, 0xD1, 0xD2, 0xD3, + 0xD4, 0xD5, 0xD6, 0xD7, 0xD8, 0xD9, 0xDA, 0xDB, 0xDC, 0xDD, 0xDE, 0xDF, 0xE0, 0xE1, 0xE2, 0xE3, 0xE4, 0xE5, + 0xE6, 0xE7, 0xE8, 0xE9, 0xEA, 0xEB, 0xEC, 0xED, 0xEE, 0xEF, 0xF0, 0xF1, 0xF2, 0xF3, 0xF4, 0xF5, 0xF6, 0xF7, + 0xF8, 0xF9, 0xFA, 0xF6, 0x07, 0x7E, + ]; + + pub(crate) const FIXTURE_EMPTY_PAYLOAD: &[u8] = &[0x01, 0x00, 0x16, 0x9F, 0x7E]; + + pub(crate) const FIXTURE_FCS_VALID: &[u8] = &[0x01, 0x03, 0x10, 0x20, 0x30, 0x76, 0xDB, 0x7E]; + + /// Same body as FCS_VALID but FCS-MSB byte XOR 0xFF (0x76 -> 0x89). + pub(crate) const FIXTURE_FCS_INVALID: &[u8] = &[0x01, 0x03, 0x10, 0x20, 0x30, 0x89, 0xDB, 0x7E]; + + /// byte_count=2 claims 2 decoded bytes; first body wire byte is + /// `0x7D 0xAA` (escape followed by non-`{0x5E,0x5D}`) -> rejected + /// as "invalid escape in body" before reaching FCS. + pub(crate) const FIXTURE_INVALID_ESCAPE: &[u8] = &[0x01, 0x02, 0x7D, 0xAA, 0x00, 0x00, 0x7E]; + + /// byte_count=3, body wire region is `[0xAA, 0x7E, 0xCC]` — the + /// raw 0x7E inside the body region is rejected before FCS. + pub(crate) const FIXTURE_PREMATURE_END_FLAG: &[u8] = &[0x01, 0x03, 0xAA, 0x7E, 0xCC, 0x00, 0x00, 0x7E]; +} + +#[cfg(test)] +mod medium_tests { + use super::{fixtures::*, *}; + + fn drain_decoder(mut dec: EncodingDecoder<'_, SerialEncoding>) -> ([u8; CONST_MTU], usize) { + let mut out = [0u8; CONST_MTU]; + let mut n = 0; + while let Ok(b) = dec.read() { + out[n] = b; + n += 1; + } + (out, n) + } + + #[test] + fn decode_basic_rx_succeeds() { + let (frame, dec) = MctpSerialMedium.deserialize(FIXTURE_BASIC_RX).unwrap(); + assert_eq!(frame.revision, 0x01); + assert_eq!(frame.byte_count, 4); + assert_eq!(frame.fcs, 0x6DA1); + let (decoded, n) = drain_decoder(dec); + assert_eq!(&decoded[..n], &[0xAA, 0xBB, 0xCC, 0xDD]); + } + + #[test] + fn decode_payload_contains_7e() { + let (frame, dec) = MctpSerialMedium.deserialize(FIXTURE_PAYLOAD_CONTAINS_7E).unwrap(); + assert_eq!(frame.byte_count, 3); + let (decoded, n) = drain_decoder(dec); + assert_eq!(&decoded[..n], &[0xAA, 0x7E, 0xCC]); + } + + #[test] + fn decode_payload_contains_7d() { + let (frame, dec) = MctpSerialMedium.deserialize(FIXTURE_PAYLOAD_CONTAINS_7D).unwrap(); + assert_eq!(frame.byte_count, 3); + let (decoded, n) = drain_decoder(dec); + assert_eq!(&decoded[..n], &[0xAA, 0x7D, 0xCC]); + } + + #[test] + fn decode_payload_contains_both() { + let (frame, dec) = MctpSerialMedium.deserialize(FIXTURE_PAYLOAD_CONTAINS_BOTH).unwrap(); + assert_eq!(frame.byte_count, 3); + let (decoded, n) = drain_decoder(dec); + assert_eq!(&decoded[..n], &[0x7E, 0x7D, 0x42]); + } + + #[test] + fn decode_max_mtu_frame() { + let (frame, dec) = MctpSerialMedium.deserialize(FIXTURE_MAX_MTU_FRAME).unwrap(); + assert_eq!(frame.byte_count as usize, CONST_MTU); + let (decoded, n) = drain_decoder(dec); + assert_eq!(n, CONST_MTU); + for (i, &b) in decoded[..n].iter().enumerate() { + assert_eq!(b, i as u8, "mismatch at idx {i}"); + } + } + + #[test] + fn decode_empty_payload() { + let (frame, dec) = MctpSerialMedium.deserialize(FIXTURE_EMPTY_PAYLOAD).unwrap(); + assert_eq!(frame.byte_count, 0); + let (_, n) = drain_decoder(dec); + assert_eq!(n, 0); + } + + #[test] + fn decode_fcs_valid() { + assert!(MctpSerialMedium.deserialize(FIXTURE_FCS_VALID).is_ok()); + } + + #[test] + fn decode_fcs_invalid_rejects() { + match MctpSerialMedium.deserialize(FIXTURE_FCS_INVALID) { + Err(crate::MctpPacketError::MediumError("fcs mismatch")) => {} + other => panic!("expected MediumError(\"fcs mismatch\"), got {:?}", other.err()), + } + } + + #[test] + fn decode_invalid_escape_rejects() { + match MctpSerialMedium.deserialize(FIXTURE_INVALID_ESCAPE) { + Err(crate::MctpPacketError::MediumError("invalid escape in body")) => {} + other => panic!( + "expected MediumError(\"invalid escape in body\"), got {:?}", + other.err() + ), + } + } + + #[test] + fn decode_premature_end_flag_rejects() { + match MctpSerialMedium.deserialize(FIXTURE_PREMATURE_END_FLAG) { + Err(crate::MctpPacketError::MediumError("unexpected 0x7E in body")) => {} + other => panic!( + "expected MediumError(\"unexpected 0x7E in body\"), got {:?}", + other.err() + ), + } + } + + fn fixture_roundtrip(wire: &[u8]) { + let m = MctpSerialMedium; + let (_frame, dec) = m.deserialize(wire).unwrap(); + let (decoded, n) = drain_decoder(dec); + let mut out = [0u8; 1024]; + let serialized = m + .serialize((), &mut out, |e| { + e.write_all(&decoded[..n]) + .map_err(|_| MctpPacketError::MediumError("write failed")) + }) + .unwrap(); + assert_eq!(serialized, wire); + } + + #[test] + fn fixture_roundtrip_basic_rx() { + fixture_roundtrip(FIXTURE_BASIC_RX); + } + + #[test] + fn fixture_roundtrip_payload_contains_7e() { + fixture_roundtrip(FIXTURE_PAYLOAD_CONTAINS_7E); + } + + #[test] + fn fixture_roundtrip_payload_contains_7d() { + fixture_roundtrip(FIXTURE_PAYLOAD_CONTAINS_7D); + } + + #[test] + fn fixture_roundtrip_payload_contains_both() { + fixture_roundtrip(FIXTURE_PAYLOAD_CONTAINS_BOTH); + } + + #[test] + fn fixture_roundtrip_max_mtu_frame() { + fixture_roundtrip(FIXTURE_MAX_MTU_FRAME); + } + + #[test] + fn fixture_roundtrip_empty_payload() { + fixture_roundtrip(FIXTURE_EMPTY_PAYLOAD); + } + + #[test] + fn fixture_roundtrip_fcs_valid() { + fixture_roundtrip(FIXTURE_FCS_VALID); + } + + #[test] + fn public_api_smoke() { + let _: crate::MctpSerialMedium = crate::MctpSerialMedium; + let _: crate::SerialEncoding = crate::SerialEncoding; + assert_eq!(crate::CONST_MTU, 251); + assert_eq!(crate::SP_EID, crate::EndpointId::Id(0x08)); + assert_eq!(crate::EC_EID, crate::EndpointId::Id(0x0A)); + } + + #[test] + fn packetize_with_stuffing_respects_mtu() { + // 251-byte payload of all 0x7E. Each byte stuffs to 2 wire + // bytes (0x7D 0x5E), so encoded body footprint per packet is + // 2x decoded length. The packet body MTU is 251 wire bytes; + // each MCTP packet also carries a 4-byte transport header + // which itself is `wire_size_of`-measured. Expect the message + // to split across multiple packets and no body region to + // exceed CONST_MTU wire bytes. + use crate::{ + endpoint_id::EndpointId, mctp_message_tag::MctpMessageTag, mctp_packet_context::MctpReplyContext, + mctp_sequence_number::MctpSequenceNumber, serialize::SerializePacketState, + }; + + let payload = [0x7E_u8; 251]; + let mut assembly = [0u8; 1024]; + let medium = MctpSerialMedium; + let reply_context = MctpReplyContext:: { + destination_endpoint_id: EndpointId::Id(0x0A), + source_endpoint_id: EndpointId::Id(0x08), + packet_sequence_number: MctpSequenceNumber::new(0), + message_tag: MctpMessageTag::default(), + medium_context: (), + }; + let mut state = SerializePacketState { + medium: &medium, + reply_context, + current_packet_num: 0, + serialized_message_header: false, + message_buffer: &payload[..], + assembly_buffer: &mut assembly[..], + }; + + let mut total_decoded_body = 0usize; + let mut packet_count = 0usize; + loop { + // We cannot iterate `state.next()` more than once because + // `next` mutably borrows the assembly buffer for each + // returned slice. Take one packet, process it, then break. + let pkt = match state.next() { + Some(Ok(pkt)) => { + let mut tmp = [0u8; 1024]; + tmp[..pkt.len()].copy_from_slice(pkt); + (tmp, pkt.len()) + } + Some(Err(e)) => panic!("serialize error: {e:?}"), + None => break, + }; + packet_count += 1; + // Deserialize the packet to recover the wire body length + // and the decoded body byte count. + let (frame, dec) = medium.deserialize(&pkt.0[..pkt.1]).unwrap(); + // Decoded body byte count INCLUDES the 4 transport-header + // bytes — subtract to get the actual payload bytes. + assert!(frame.byte_count as usize >= 4); + let payload_decoded = frame.byte_count as usize - 4; + total_decoded_body += payload_decoded; + // Wire body region (between header and FCS) MUST be <= + // CONST_MTU under MEDIUM-08 chunk-sizing. + let _ = dec; // decoder discard + let wire_body_len = pkt.1 - 2 /* hdr */ - 1 /* end-flag */; + // Subtract the (possibly stuffed) FCS bytes — they are 2 + // FCS bytes but each may stuff to 2 wire bytes. Worst case + // 4 bytes; lower bound on body wire = wire_body_len - 4. + assert!( + wire_body_len <= CONST_MTU + 4, + "packet {packet_count} body exceeds MTU + worst-case FCS: {wire_body_len}" + ); + } + assert!(packet_count >= 2, "expected multi-packet split, got {packet_count}"); + assert_eq!(total_decoded_body, payload.len()); + } +} diff --git a/mctp-rs/src/medium/smbus_espi.rs b/mctp-rs/src/medium/smbus_espi.rs index a54b9378..eeeb5979 100644 --- a/mctp-rs/src/medium/smbus_espi.rs +++ b/mctp-rs/src/medium/smbus_espi.rs @@ -2,6 +2,7 @@ use bit_register::{NumBytes, TryFromBits, TryIntoBits, bit_register}; use crate::{ MctpPacketError, + buffer_encoding::{EncodingDecoder, EncodingEncoder, PassthroughEncoding}, error::MctpPacketResult, medium::{ MctpMedium, MctpMediumFrame, @@ -24,8 +25,12 @@ impl MctpMedium for SmbusEspiMedium { type Frame = SmbusEspiMediumFrame; type Error = &'static str; type ReplyContext = SmbusEspiReplyContext; + type Encoding = PassthroughEncoding; - fn deserialize<'buf>(&self, packet: &'buf [u8]) -> MctpPacketResult<(Self::Frame, &'buf [u8]), Self> { + fn deserialize<'buf>( + &self, + packet: &'buf [u8], + ) -> MctpPacketResult<(Self::Frame, EncodingDecoder<'buf, Self::Encoding>), Self> { // Check if packet has enough bytes for header if packet.len() < 4 { return Err(MctpPacketError::MediumError("Packet too short to parse smbus header")); @@ -46,9 +51,9 @@ impl MctpMedium for SmbusEspiMedium { )); } let pec = packet[header.byte_count as usize]; - // strip off the PEC byte - let packet = &packet[..header.byte_count as usize]; - Ok((SmbusEspiMediumFrame { header, pec }, packet)) + // strip off the PEC byte; the inner stuffed region is the body bytes + let inner = &packet[..header.byte_count as usize]; + Ok((SmbusEspiMediumFrame { header, pec }, EncodingDecoder::new(inner))) } fn serialize<'buf, F>( @@ -58,40 +63,43 @@ impl MctpMedium for SmbusEspiMedium { message_writer: F, ) -> MctpPacketResult<&'buf [u8], Self> where - F: for<'a> FnOnce(&'a mut [u8]) -> MctpPacketResult, + F: for<'a> FnOnce(&mut EncodingEncoder<'a, Self::Encoding>) -> MctpPacketResult<(), Self>, { // Reserve space for header (4 bytes) and PEC (1 byte) if buffer.len() < 5 { return Err(MctpPacketError::MediumError("Buffer too small for smbus frame")); } + let buffer_len = buffer.len(); + + // Write the body first via an encoder over the body region (reserve + // 4 leading header bytes and 1 trailing PEC byte). + let body_wire_len = { + let body_buf = &mut buffer[4..buffer_len - 1]; + let mut encoder = EncodingEncoder::::new(body_buf); + message_writer(&mut encoder)?; + encoder.wire_position() + }; - // split off a buffer where we will write the header, the rest is for body + PEC - let (header_slice, body) = buffer.split_at_mut(4); - - // Write the body first, but ensure we leave space for PEC - if body.is_empty() { - return Err(MctpPacketError::MediumError("No space for PEC byte")); - } - let available_body_len = body.len() - 1; // Reserve 1 byte for PEC - let body_len = message_writer(&mut body[..available_body_len])?; - - // with the body has been written, construct the header + // with the body has been written, construct the header. byte_count + // is the number of wire bytes that follow on the line per SMBus + // (PassthroughEncoding pairing means wire byte count == decoded + // byte count for SMBus today). let header = SmbusEspiMediumHeader { destination_slave_address: reply_context.source_slave_address, source_slave_address: reply_context.destination_slave_address, - byte_count: body_len as u8, + byte_count: body_wire_len as u8, command_code: SmbusCommandCode::Mctp, ..Default::default() }; let header_value = TryInto::::try_into(header).map_err(MctpPacketError::MediumError)?; - header_slice.copy_from_slice(&header_value.to_be_bytes()); + buffer[0..4].copy_from_slice(&header_value.to_be_bytes()); // with the header written, compute the PEC byte - let pec_value = smbus_pec::pec(&buffer[0..4 + body_len]); - buffer[4 + body_len] = pec_value; + let pec_value = smbus_pec::pec(&buffer[0..4 + body_wire_len]); + buffer[4 + body_wire_len] = pec_value; // add 4 for frame header, add 1 for PEC byte - Ok(&buffer[0..4 + body_len + 1]) + Ok(&buffer[0..4 + body_wire_len + 1]) } // TODO - this is a guess, need to find the actual value from spec @@ -170,7 +178,20 @@ impl MctpMediumFrame for SmbusEspiMediumFrame { #[cfg(test)] mod tests { extern crate std; + use std::vec::Vec; + use super::*; + use crate::buffer_encoding::DecodeError; + + /// Test-only helper: drain an `EncodingDecoder` to a `Vec` for + /// content assertions. Stops at the first error (e.g., `PrematureEnd`). + fn drain_to_vec(decoder: &mut EncodingDecoder<'_, PassthroughEncoding>) -> Vec { + let mut out = Vec::new(); + while let Ok(b) = decoder.read() { + out.push(b); + } + out + } #[test] fn test_deserialize_valid_packet() { @@ -200,14 +221,15 @@ mod tests { packet[8] = pec; let result = medium.deserialize(&packet).unwrap(); - let (frame, body) = result; + let (frame, mut decoder) = result; + let body = drain_to_vec(&mut decoder); assert_eq!(frame.header.destination_slave_address, 0x20); assert_eq!(frame.header.source_slave_address, 0x10); assert_eq!(frame.header.command_code, SmbusCommandCode::Mctp); assert_eq!(frame.header.byte_count, 4); assert_eq!(frame.pec, pec); - assert_eq!(body, &payload); + assert_eq!(body, payload); } #[test] @@ -215,10 +237,10 @@ mod tests { let medium = SmbusEspiMedium; let short_packet = [0x01, 0x02]; // Only 2 bytes, need at least 4 for header - let result = medium.deserialize(&short_packet); + let err = medium.deserialize(&short_packet).err().unwrap(); assert_eq!( - result, - Err(MctpPacketError::MediumError("Packet too short to parse smbus header")) + err, + MctpPacketError::MediumError("Packet too short to parse smbus header") ); } @@ -240,12 +262,10 @@ mod tests { packet[0..4].copy_from_slice(&header_bytes); packet[4..6].copy_from_slice(&short_payload); - let result = medium.deserialize(&packet); + let err = medium.deserialize(&packet).err().unwrap(); assert_eq!( - result, - Err(MctpPacketError::MediumError( - "Packet too short to parse smbus body and PEC" - )) + err, + MctpPacketError::MediumError("Packet too short to parse smbus body and PEC") ); } @@ -269,8 +289,8 @@ mod tests { packet[4..8].copy_from_slice(&payload); packet[8] = pec; - let result = medium.deserialize(&packet); - assert_eq!(result, Err(MctpPacketError::MediumError("Invalid smbus header"))); + let err = medium.deserialize(&packet).err().unwrap(); + assert_eq!(err, MctpPacketError::MediumError("Invalid smbus header")); } #[test] @@ -291,11 +311,11 @@ mod tests { packet[4] = pec; let result = medium.deserialize(&packet).unwrap(); - let (frame, body) = result; + let (frame, mut decoder) = result; assert_eq!(frame.header.byte_count, 0); assert_eq!(frame.pec, pec); - assert_eq!(body.len(), 0); + assert_eq!(decoder.read().unwrap_err(), DecodeError::PrematureEnd); } #[test] @@ -310,9 +330,10 @@ mod tests { let test_payload = [0xAA, 0xBB, 0xCC, 0xDD]; let result = medium - .serialize(reply_context, &mut buffer, |buf| { - buf[..test_payload.len()].copy_from_slice(&test_payload); - Ok(test_payload.len()) + .serialize(reply_context, &mut buffer, |encoder| { + encoder + .write_all(&test_payload) + .map_err(|_| MctpPacketError::SerializeError("encode error")) }) .unwrap(); @@ -348,12 +369,12 @@ mod tests { let mut small_buffer = [0u8; 4]; // Only 4 bytes, need at least 5 (header + PEC) - let result = medium.serialize(reply_context, &mut small_buffer, |_| Ok(0)); + let err = medium + .serialize(reply_context, &mut small_buffer, |_| Ok(())) + .err() + .unwrap(); - assert_eq!( - result, - Err(MctpPacketError::MediumError("Buffer too small for smbus frame")) - ); + assert_eq!(err, MctpPacketError::MediumError("Buffer too small for smbus frame")); } #[test] @@ -370,7 +391,7 @@ mod tests { .serialize( reply_context, &mut minimal_buffer, - |_| Ok(0), // No payload data + |_| Ok(()), // No payload data ) .unwrap(); @@ -399,10 +420,10 @@ mod tests { let mut buffer = [0u8; 260]; // 4 + 255 + 1 = header + max payload + PEC let result = medium - .serialize(reply_context, &mut buffer, |buf| { - let copy_len = max_payload.len().min(buf.len()); - buf[..copy_len].copy_from_slice(&max_payload[..copy_len]); - Ok(copy_len) + .serialize(reply_context, &mut buffer, |encoder| { + encoder + .write_all(&max_payload) + .map_err(|_| MctpPacketError::SerializeError("encode error")) }) .unwrap(); @@ -451,17 +472,19 @@ mod tests { // Serialize let serialized = medium - .serialize(original_context, &mut buffer, |buf| { - buf[..original_payload.len()].copy_from_slice(&original_payload); - Ok(original_payload.len()) + .serialize(original_context, &mut buffer, |encoder| { + encoder + .write_all(&original_payload) + .map_err(|_| MctpPacketError::SerializeError("encode error")) }) .unwrap(); // Deserialize - let (frame, deserialized_payload) = medium.deserialize(serialized).unwrap(); + let (frame, mut decoder) = medium.deserialize(serialized).unwrap(); + let deserialized_payload = drain_to_vec(&mut decoder); // Verify roundtrip correctness - assert_eq!(deserialized_payload, &original_payload); + assert_eq!(deserialized_payload, original_payload); assert_eq!(frame.header.destination_slave_address, 0x24); // swapped assert_eq!(frame.header.source_slave_address, 0x42); // swapped assert_eq!(frame.header.command_code, SmbusCommandCode::Mctp); @@ -553,9 +576,10 @@ mod tests { let mut buffer = [0u8; 32]; let result = medium - .serialize(reply_context, &mut buffer, |buf| { - buf[..test_data.len()].copy_from_slice(&test_data); - Ok(test_data.len()) + .serialize(reply_context, &mut buffer, |encoder| { + encoder + .write_all(&test_data) + .map_err(|_| MctpPacketError::SerializeError("encode error")) }) .unwrap(); @@ -581,7 +605,7 @@ mod tests { .serialize( reply_context, &mut buffer, - |_| Ok(0), // Empty payload + |_| Ok(()), // Empty payload ) .unwrap(); @@ -625,7 +649,7 @@ mod tests { let medium = SmbusEspiMedium; let mut buffer = [0u8; 16]; - let result = medium.serialize(reply_context, &mut buffer, |_| Ok(0)).unwrap(); + let result = medium.serialize(reply_context, &mut buffer, |_| Ok(())).unwrap(); let header_value = u32::from_be_bytes([result[0], result[1], result[2], result[3]]); let response_header = SmbusEspiMediumHeader::try_from(header_value).unwrap(); @@ -662,7 +686,8 @@ mod tests { let packet_slice = &packet[0..4 + byte_count as usize + 1]; let result = medium.deserialize(packet_slice).unwrap(); - let (frame, body) = result; + let (frame, mut decoder) = result; + let body = drain_to_vec(&mut decoder); assert_eq!(frame.header.byte_count, byte_count); assert_eq!(body.len(), byte_count as usize); @@ -689,12 +714,10 @@ mod tests { packet[4..6].copy_from_slice(&short_payload); packet[6] = 0x00; // PEC (doesn't matter for this test) - let result = medium.deserialize(&packet); + let err = medium.deserialize(&packet).err().unwrap(); assert_eq!( - result, - Err(MctpPacketError::MediumError( - "Packet too short to parse smbus body and PEC" - )) + err, + MctpPacketError::MediumError("Packet too short to parse smbus body and PEC") ); } @@ -709,14 +732,14 @@ mod tests { // Test with buffer smaller than minimum required (4 header + 1 PEC = 5 bytes) let mut tiny_buffer = [0u8; 4]; // Only 4 bytes, need at least 5 - let result = medium.serialize(reply_context, &mut tiny_buffer, |_| { - Ok(0) // No payload - }); + let err = medium + .serialize(reply_context, &mut tiny_buffer, |_| { + Ok(()) // No payload + }) + .err() + .unwrap(); - assert_eq!( - result, - Err(MctpPacketError::MediumError("Buffer too small for smbus frame")) - ); + assert_eq!(err, MctpPacketError::MediumError("Buffer too small for smbus frame")); } #[test] @@ -726,10 +749,10 @@ mod tests { // Test with packet shorter than header size (4 bytes) for packet_size in 0..4 { let short_packet = [0u8; 4]; - let result = medium.deserialize(&short_packet[..packet_size]); + let err = medium.deserialize(&short_packet[..packet_size]).err().unwrap(); assert_eq!( - result, - Err(MctpPacketError::MediumError("Packet too short to parse smbus header")) + err, + MctpPacketError::MediumError("Packet too short to parse smbus header") ); } } @@ -752,12 +775,10 @@ mod tests { packet[0..4].copy_from_slice(&header_bytes); packet[4..9].copy_from_slice(&payload); - let result = medium.deserialize(&packet); + let err = medium.deserialize(&packet).err().unwrap(); assert_eq!( - result, - Err(MctpPacketError::MediumError( - "Packet too short to parse smbus body and PEC" - )) + err, + MctpPacketError::MediumError("Packet too short to parse smbus body and PEC") ); } @@ -777,12 +798,10 @@ mod tests { let mut short_packet = [0u8; 4]; // Only header, no PEC short_packet.copy_from_slice(&header_bytes); - let result = medium.deserialize(&short_packet); + let err = medium.deserialize(&short_packet).err().unwrap(); assert_eq!( - result, - Err(MctpPacketError::MediumError( - "Packet too short to parse smbus body and PEC" - )) + err, + MctpPacketError::MediumError("Packet too short to parse smbus body and PEC") ); } @@ -799,28 +818,31 @@ mod tests { let max_payload = [0x55u8; 255]; let mut buffer = [0u8; 260]; // 4 + 255 + 1 = exactly enough - let result = medium.serialize(reply_context, &mut buffer, |buf| { - let copy_len = max_payload.len().min(buf.len()); - buf[..copy_len].copy_from_slice(&max_payload[..copy_len]); - Ok(copy_len) + let result = medium.serialize(reply_context, &mut buffer, |encoder| { + encoder + .write_all(&max_payload) + .map_err(|_| MctpPacketError::SerializeError("encode error")) }); assert!(result.is_ok()); let serialized = result.unwrap(); assert_eq!(serialized.len(), 260); // Should use exactly all available space - // Test with buffer one byte too small for maximum payload + // Test with buffer one byte too small for maximum payload. + // The encoder will hit BufferFull when trying to write the + // 255th payload byte (only 254 fit after header reservation), + // so this serialize call now returns an error rather than + // silently truncating. let mut small_buffer = [0u8; 259]; // One byte short for max payload - let result_small = medium.serialize(reply_context, &mut small_buffer, |buf| { - // Try to write max payload but buffer is too small - let copy_len = max_payload.len().min(buf.len()); - buf[..copy_len].copy_from_slice(&max_payload[..copy_len]); - Ok(copy_len) + let result_small = medium.serialize(reply_context, &mut small_buffer, |encoder| { + encoder + .write_all(&max_payload) + .map_err(|_| MctpPacketError::SerializeError("encode error")) }); - // Should still work but with truncated payload (254 bytes payload + 4 header + 1 PEC = 259) - assert!(result_small.is_ok()); - let serialized_small = result_small.unwrap(); - assert_eq!(serialized_small.len(), 259); // Uses all available space + assert_eq!( + result_small.err().unwrap(), + MctpPacketError::SerializeError("encode error") + ); } } diff --git a/mctp-rs/src/serialize.rs b/mctp-rs/src/serialize.rs index 0f2d3861..69fac777 100644 --- a/mctp-rs/src/serialize.rs +++ b/mctp-rs/src/serialize.rs @@ -1,6 +1,10 @@ use crate::{ - MctpPacketError, error::MctpPacketResult, mctp_packet_context::MctpReplyContext, - mctp_transport_header::MctpTransportHeader, medium::MctpMedium, + MctpPacketError, + buffer_encoding::{BufferEncoding, EncodeError, EncodingEncoder}, + error::MctpPacketResult, + mctp_packet_context::MctpReplyContext, + mctp_transport_header::MctpTransportHeader, + medium::MctpMedium, }; #[derive(Debug, PartialEq, Eq)] @@ -25,15 +29,53 @@ impl<'buf, M: MctpMedium> SerializePacketState<'buf, M> { let packet = self.medium.serialize( self.reply_context.medium_context, self.assembly_buffer, - |buffer: &mut [u8]| { - let max_packet_size = self.medium.max_message_body_size().min(buffer.len()); - if max_packet_size < TRANSPORT_HEADER_SIZE { + |encoder: &mut EncodingEncoder<'_, M::Encoding>| { + let max_wire = self.medium.max_message_body_size().min(encoder.remaining_wire()); + + // Build the transport header first (with end_of_message + // tentatively 0) so we can measure its wire footprint + // under the medium's encoding before chunking the body. + let start_of_message = if self.current_packet_num == 0 { 1 } else { 0 }; + let packet_sequence_number = self.reply_context.packet_sequence_number.inc(); + let mut transport_header_value: u32 = MctpTransportHeader { + reserved: 0, + header_version: 1, + start_of_message, + end_of_message: 0, + packet_sequence_number, + tag_owner: 0, + message_tag: self.reply_context.message_tag, + source_endpoint_id: self.reply_context.destination_endpoint_id, + destination_endpoint_id: self.reply_context.source_endpoint_id, + } + .try_into() + .map_err(MctpPacketError::SerializeError)?; + let mut header_bytes = transport_header_value.to_be_bytes(); + let header_wire_cost = M::Encoding::wire_size_of(&header_bytes); + if header_wire_cost > max_wire { return Err(MctpPacketError::SerializeError( "assembly buffer too small for mctp transport header", )); } - let message_size = (max_packet_size - TRANSPORT_HEADER_SIZE).min(self.message_buffer.len()); + // Walk decoded body bytes one at a time, accumulating + // their per-byte wire footprint via + // `M::Encoding::wire_size_of`. Stop when adding the + // next byte would exceed the wire budget. Correct for + // both passthrough and stuffing encodings (both shipped + // encodings are byte-additive — `wire_size_of(a ++ b) + // == wire_size_of(a) + wire_size_of(b)`). + let body_wire_budget = max_wire - header_wire_cost; + let mut consumed_wire = 0usize; + let mut message_size = 0usize; + for &b in self.message_buffer.iter() { + let cost = M::Encoding::wire_size_of(&[b]); + if consumed_wire + cost > body_wire_budget { + break; + } + consumed_wire += cost; + message_size += 1; + } // if there is no room for any of the body, and the body is not empty, // then return an error, otherwise we infinate loop sending packets with headers and @@ -47,30 +89,43 @@ impl<'buf, M: MctpMedium> SerializePacketState<'buf, M> { let body = &self.message_buffer[..message_size]; self.message_buffer = &self.message_buffer[message_size..]; - let start_of_message = if self.current_packet_num == 0 { 1 } else { 0 }; + // Now that we know whether this is the final chunk, + // rebuild the transport header if `end_of_message` + // flips to 1. Re-measure the wire cost — none of the + // EOM-bit bytes hit 0x7E or 0x7D under either shipped + // encoding in practice, but do not assume. let end_of_message = if self.message_buffer.is_empty() { 1 } else { 0 }; - let packet_sequence_number = self.reply_context.packet_sequence_number.inc(); - let transport_header: u32 = MctpTransportHeader { - reserved: 0, - header_version: 1, - start_of_message, - end_of_message, - packet_sequence_number, - tag_owner: 0, - message_tag: self.reply_context.message_tag, - source_endpoint_id: self.reply_context.destination_endpoint_id, - destination_endpoint_id: self.reply_context.source_endpoint_id, + if end_of_message == 1 { + transport_header_value = MctpTransportHeader { + reserved: 0, + header_version: 1, + start_of_message, + end_of_message, + packet_sequence_number, + tag_owner: 0, + message_tag: self.reply_context.message_tag, + source_endpoint_id: self.reply_context.destination_endpoint_id, + destination_endpoint_id: self.reply_context.source_endpoint_id, + } + .try_into() + .map_err(MctpPacketError::SerializeError)?; + header_bytes = transport_header_value.to_be_bytes(); + let rebuilt_header_wire_cost = M::Encoding::wire_size_of(&header_bytes); + if rebuilt_header_wire_cost + consumed_wire > max_wire { + return Err(MctpPacketError::SerializeError( + "assembly buffer too small after EOM bit set", + )); + } } - .try_into() - .map_err(MctpPacketError::SerializeError)?; - // write the transport header and message body - let mut cursor = 0; - buffer[cursor..cursor + TRANSPORT_HEADER_SIZE].copy_from_slice(&transport_header.to_be_bytes()); - cursor += TRANSPORT_HEADER_SIZE; - // message body is the rest of the buffer, up to the packet size - buffer[cursor..cursor + body.len()].copy_from_slice(body); - Ok(cursor + body.len()) + // write the transport header and message body via the + // medium-supplied encoder. + let map_encode_err = |e: EncodeError| match e { + EncodeError::BufferFull => MctpPacketError::SerializeError("encoding: buffer full"), + }; + encoder.write_all(&header_bytes).map_err(map_encode_err)?; + encoder.write_all(body).map_err(map_encode_err)?; + Ok(()) }, ); diff --git a/mctp-rs/src/test_util.rs b/mctp-rs/src/test_util.rs index f367c670..e328a0c9 100644 --- a/mctp-rs/src/test_util.rs +++ b/mctp-rs/src/test_util.rs @@ -1,5 +1,6 @@ use crate::{ MctpPacketError, + buffer_encoding::{EncodingDecoder, EncodingEncoder, PassthroughEncoding}, error::MctpPacketResult, medium::{MctpMedium, MctpMediumFrame}, }; @@ -36,8 +37,12 @@ impl MctpMedium for TestMedium { type Frame = TestMediumFrame; type Error = &'static str; type ReplyContext = (); + type Encoding = PassthroughEncoding; - fn deserialize<'buf>(&self, packet: &'buf [u8]) -> MctpPacketResult<(Self::Frame, &'buf [u8]), Self> { + fn deserialize<'buf>( + &self, + packet: &'buf [u8], + ) -> MctpPacketResult<(Self::Frame, EncodingDecoder<'buf, Self::Encoding>), Self> { let packet_len = packet.len(); // check that header / trailer is present and correct @@ -51,8 +56,8 @@ impl MctpMedium for TestMedium { return Err(MctpPacketError::MediumError("trailer mismatch")); } - let packet = &packet[self.header.len()..packet_len - self.trailer.len()]; - Ok((TestMediumFrame(packet_len), packet)) + let inner = &packet[self.header.len()..packet_len - self.trailer.len()]; + Ok((TestMediumFrame(packet_len), EncodingDecoder::new(inner))) } fn max_message_body_size(&self) -> usize { self.mtu @@ -64,7 +69,7 @@ impl MctpMedium for TestMedium { message_writer: F, ) -> MctpPacketResult<&'buf [u8], Self> where - F: for<'a> FnOnce(&'a mut [u8]) -> MctpPacketResult, + F: for<'a> FnOnce(&mut EncodingEncoder<'a, Self::Encoding>) -> MctpPacketResult<(), Self>, { let header_len = self.header.len(); let trailer_len = self.trailer.len(); @@ -82,8 +87,15 @@ impl MctpMedium for TestMedium { let max_message_size = max_packet_size - header_len - trailer_len; buffer[0..header_len].copy_from_slice(self.header); - let size = message_writer(&mut buffer[header_len..header_len + max_message_size])?; - let len = header_len + size; + + let body_wire_len = { + let body_buf = &mut buffer[header_len..header_len + max_message_size]; + let mut encoder = EncodingEncoder::::new(body_buf); + message_writer(&mut encoder)?; + encoder.wire_position() + }; + + let len = header_len + body_wire_len; buffer[len..len + trailer_len].copy_from_slice(self.trailer); Ok(&buffer[..len + trailer_len]) } diff --git a/supply-chain/config.toml b/supply-chain/config.toml index 52d64d20..659f35ae 100644 --- a/supply-chain/config.toml +++ b/supply-chain/config.toml @@ -22,6 +22,10 @@ audit-as-crates-io = false [policy.keyberon] audit-as-crates-io = false +[[exemptions.crc]] +version = "3.4.0" +criteria = "safe-to-deploy" + [[exemptions.diff]] version = "0.1.13" criteria = "safe-to-run"