diff --git a/zerompk/src/read.rs b/zerompk/src/read.rs index c25c2a8..4cc2981 100644 --- a/zerompk/src/read.rs +++ b/zerompk/src/read.rs @@ -1,3 +1,5 @@ +use core::hint::cold_path; + #[cfg(feature = "std")] use alloc::vec; @@ -38,7 +40,7 @@ pub trait Read<'de> { /// /// impl<'de> FromMessagePack<'de> for Outer { /// fn read>(reader: &mut R) -> Result { - /// reader.increment_depth()?; + /// reader.increment_depth()?; /// let inner = Inner::read(reader)?; /// reader.decrement_depth(); /// Ok(Self { inner }) @@ -138,6 +140,7 @@ pub trait Read<'de> { if actual == expected { Ok(()) } else { + cold_path(); Err(Error::ArrayLengthMismatch { expected, actual }) } } @@ -149,6 +152,7 @@ pub trait Read<'de> { if actual == expected { Ok(()) } else { + cold_path(); Err(Error::MapLengthMismatch { expected, actual }) } } @@ -174,6 +178,7 @@ impl<'de> SliceReader<'de> { if self.pos < self.data.len() { Ok(self.data[self.pos]) } else { + cold_path(); Err(Error::BufferTooSmall) } } @@ -183,6 +188,7 @@ impl<'de> SliceReader<'de> { if self.pos + len <= self.data.len() { unsafe { Ok(self.data.get_unchecked(self.pos..(self.pos + len))) } } else { + cold_path(); Err(Error::BufferTooSmall) } } @@ -213,6 +219,7 @@ impl<'de> Read<'de> for SliceReader<'de> { #[inline(always)] fn increment_depth(&mut self) -> Result<()> { if self.depth >= MAX_DEPTH { + cold_path(); Err(Error::DepthLimitExceeded { max: MAX_DEPTH }) } else { self.depth += 1; @@ -224,6 +231,8 @@ impl<'de> Read<'de> for SliceReader<'de> { fn decrement_depth(&mut self) { if self.depth > 0 { self.depth -= 1; + } else { + cold_path(); } } @@ -234,267 +243,242 @@ impl<'de> Read<'de> for SliceReader<'de> { self.pos += 1; Ok(()) } else { + cold_path(); Err(Error::InvalidMarker(byte)) } } #[inline(always)] fn read_boolean(&mut self) -> Result { - let byte = self.peek_byte()?; + let byte = self.take_byte()?; match byte { - FALSE_MARKER => { - self.pos += 1; - Ok(false) - } - TRUE_MARKER => { - self.pos += 1; - Ok(true) + FALSE_MARKER => Ok(false), + TRUE_MARKER => Ok(true), + _ => { + cold_path(); + self.pos -= 1; + Err(Error::InvalidMarker(byte)) } - _ => Err(Error::InvalidMarker(byte)), } } #[inline(always)] fn read_u8(&mut self) -> Result { - let byte = self.peek_byte()?; + let byte = self.take_byte()?; match byte { // Positive FixInt - POS_FIXINT_START..=POS_FIXINT_END => { - self.pos += 1; - Ok(byte) - } + POS_FIXINT_START..=POS_FIXINT_END => Ok(byte), // uint 8 UINT8_MARKER => { - self.pos += 1; let byte = self.take_byte()?; Ok(byte) } - _ => Err(Error::InvalidMarker(byte)), + _ => { + cold_path(); + self.pos -= 1; + Err(Error::InvalidMarker(byte)) + } } } #[inline(always)] fn read_u16(&mut self) -> Result { - let byte = self.peek_byte()?; + let byte = self.take_byte()?; match byte { // Positive FixInt - POS_FIXINT_START..=POS_FIXINT_END => { - self.pos += 1; - Ok(byte as u16) - } + POS_FIXINT_START..=POS_FIXINT_END => Ok(byte as u16), // uint 8 UINT8_MARKER => { - self.pos += 1; let byte = self.take_byte()?; Ok(byte as u16) } // uint 16 UINT16_MARKER => { - self.pos += 1; let bytes = self.take_array::<2>()?; Ok(u16::from_be_bytes(*bytes)) } - _ => Err(Error::InvalidMarker(byte)), + _ => { + cold_path(); + self.pos -= 1; + Err(Error::InvalidMarker(byte)) + } } } #[inline(always)] fn read_u32(&mut self) -> Result { - let byte = self.peek_byte()?; + let byte = self.take_byte()?; match byte { // Positive FixInt - POS_FIXINT_START..=POS_FIXINT_END => { - self.pos += 1; - Ok(byte as u32) - } + POS_FIXINT_START..=POS_FIXINT_END => Ok(byte as u32), // uint 8 UINT8_MARKER => { - self.pos += 1; let byte = self.take_byte()?; Ok(byte as u32) } // uint 16 UINT16_MARKER => { - self.pos += 1; let bytes = self.take_array::<2>()?; Ok(u16::from_be_bytes(*bytes) as u32) } // uint 32 UINT32_MARKER => { - self.pos += 1; let bytes = self.take_array::<4>()?; Ok(u32::from_be_bytes(*bytes)) } - _ => Err(Error::InvalidMarker(byte)), + _ => { + cold_path(); + self.pos -= 1; + Err(Error::InvalidMarker(byte)) + } } } #[inline(always)] fn read_u64(&mut self) -> Result { - let byte = self.peek_byte()?; + let byte = self.take_byte()?; match byte { // Positive FixInt - POS_FIXINT_START..=POS_FIXINT_END => { - self.pos += 1; - Ok(byte as u64) - } + POS_FIXINT_START..=POS_FIXINT_END => Ok(byte as u64), // uint 8 UINT8_MARKER => { - self.pos += 1; let byte = self.take_byte()?; Ok(byte as u64) } // uint 16 UINT16_MARKER => { - self.pos += 1; let bytes = self.take_array::<2>()?; Ok(u16::from_be_bytes(*bytes) as u64) } // uint 32 UINT32_MARKER => { - self.pos += 1; let bytes = self.take_array::<4>()?; Ok(u32::from_be_bytes(*bytes) as u64) } // uint 64 UINT64_MARKER => { - self.pos += 1; let bytes = self.take_array::<8>()?; Ok(u64::from_be_bytes(*bytes)) } - _ => Err(Error::InvalidMarker(byte)), + _ => { + cold_path(); + self.pos -= 1; + Err(Error::InvalidMarker(byte)) + } } } #[inline(always)] fn read_i8(&mut self) -> Result { - let byte = self.peek_byte()?; + let byte = self.take_byte()?; match byte { // Positive FixInt - POS_FIXINT_START..=POS_FIXINT_END => { - self.pos += 1; - Ok(byte as i8) - } + POS_FIXINT_START..=POS_FIXINT_END => Ok(byte as i8), // Negative FixInt - NEG_FIXINT_START..=NEG_FIXINT_END => { - self.pos += 1; - Ok(byte as i8) - } + NEG_FIXINT_START..=NEG_FIXINT_END => Ok(byte as i8), // int 8 INT8_MARKER => { - self.pos += 1; let byte = self.take_byte()?; Ok(byte as i8) } - _ => Err(Error::InvalidMarker(byte)), + _ => { + cold_path(); + self.pos -= 1; + Err(Error::InvalidMarker(byte)) + } } } #[inline(always)] fn read_i16(&mut self) -> Result { - let byte = self.peek_byte()?; + let byte = self.take_byte()?; match byte { // Positive FixInt - POS_FIXINT_START..=POS_FIXINT_END => { - self.pos += 1; - Ok(byte as i16) - } + POS_FIXINT_START..=POS_FIXINT_END => Ok(byte as i16), // Negative FixInt - NEG_FIXINT_START..=NEG_FIXINT_END => { - self.pos += 1; - Ok((byte as i8) as i16) - } + NEG_FIXINT_START..=NEG_FIXINT_END => Ok((byte as i8) as i16), // int 8 INT8_MARKER => { - self.pos += 1; let byte = self.take_byte()?; Ok(byte as i8 as i16) } // int 16 INT16_MARKER => { - self.pos += 1; let bytes = self.take_array::<2>()?; Ok(i16::from_be_bytes(*bytes)) } - _ => Err(Error::InvalidMarker(byte)), + _ => { + cold_path(); + self.pos -= 1; + Err(Error::InvalidMarker(byte)) + } } } #[inline(always)] fn read_i32(&mut self) -> Result { - let byte = self.peek_byte()?; + let byte = self.take_byte()?; match byte { // Positive FixInt - POS_FIXINT_START..=POS_FIXINT_END => { - self.pos += 1; - Ok(byte as i32) - } + POS_FIXINT_START..=POS_FIXINT_END => Ok(byte as i32), // Negative FixInt - NEG_FIXINT_START..=NEG_FIXINT_END => { - self.pos += 1; - Ok((byte as i8) as i32) - } + NEG_FIXINT_START..=NEG_FIXINT_END => Ok((byte as i8) as i32), // int 8 INT8_MARKER => { - self.pos += 1; let byte = self.take_byte()?; Ok(byte as i8 as i32) } // int 16 INT16_MARKER => { - self.pos += 1; let bytes = self.take_array::<2>()?; Ok(i16::from_be_bytes(*bytes) as i32) } // int 32 INT32_MARKER => { - self.pos += 1; let bytes = self.take_array::<4>()?; Ok(i32::from_be_bytes(*bytes)) } - _ => Err(Error::InvalidMarker(byte)), + _ => { + cold_path(); + self.pos -= 1; + Err(Error::InvalidMarker(byte)) + } } } #[inline(always)] fn read_i64(&mut self) -> Result { - let byte = self.peek_byte()?; + let byte = self.take_byte()?; match byte { // Positive FixInt - POS_FIXINT_START..=POS_FIXINT_END => { - self.pos += 1; - Ok(byte as i64) - } + POS_FIXINT_START..=POS_FIXINT_END => Ok(byte as i64), // Negative FixInt - NEG_FIXINT_START..=NEG_FIXINT_END => { - self.pos += 1; - Ok((byte as i8) as i64) - } + NEG_FIXINT_START..=NEG_FIXINT_END => Ok((byte as i8) as i64), // int 8 INT8_MARKER => { - self.pos += 1; let byte = self.take_byte()?; Ok(byte as i8 as i64) } // int 16 INT16_MARKER => { - self.pos += 1; let bytes = self.take_array::<2>()?; Ok(i16::from_be_bytes(*bytes) as i64) } // int 32 INT32_MARKER => { - self.pos += 1; let bytes = self.take_array::<4>()?; Ok(i32::from_be_bytes(*bytes) as i64) } // int 64 INT64_MARKER => { - self.pos += 1; let bytes = self.take_array::<8>()?; Ok(i64::from_be_bytes(*bytes)) } - _ => Err(Error::InvalidMarker(byte)), + _ => { + cold_path(); + self.pos -= 1; + Err(Error::InvalidMarker(byte)) + } } } @@ -508,7 +492,10 @@ impl<'de> Read<'de> for SliceReader<'de> { let bytes = self.take_array::<4>()?; Ok(f32::from_bits(u32::from_be_bytes(*bytes))) } - _ => Err(Error::InvalidMarker(byte)), + _ => { + cold_path(); + Err(Error::InvalidMarker(byte)) + } } } @@ -522,38 +509,39 @@ impl<'de> Read<'de> for SliceReader<'de> { let bytes = self.take_array::<8>()?; Ok(f64::from_bits(u64::from_be_bytes(*bytes))) } - _ => Err(Error::InvalidMarker(byte)), + _ => { + cold_path(); + Err(Error::InvalidMarker(byte)) + } } } #[inline(always)] fn read_string(&mut self) -> Result> { - let byte = self.peek_byte()?; + let byte = self.take_byte()?; let len = match byte { // fixstr - FIXSTR_START..=FIXSTR_END => { - self.pos += 1; - (byte - FIXSTR_START) as usize - } + FIXSTR_START..=FIXSTR_END => (byte - FIXSTR_START) as usize, // str 8 STR8_MARKER => { - self.pos += 1; let byte = self.take_byte()?; byte as usize } // str 16 STR16_MARKER => { - self.pos += 1; let bytes = self.take_array::<2>()?; u16::from_be_bytes(*bytes) as usize } // str 32 STR32_MARKER => { - self.pos += 1; let bytes = self.take_array::<4>()?; u32::from_be_bytes(*bytes) as usize } - _ => return Err(Error::InvalidMarker(byte)), + _ => { + cold_path(); + self.pos -= 1; + return Err(Error::InvalidMarker(byte)); + } }; let bytes = self.take_slice(len)?; match core::str::from_utf8(bytes) { @@ -564,32 +552,30 @@ impl<'de> Read<'de> for SliceReader<'de> { #[inline(always)] fn read_string_bytes(&mut self) -> Result> { - let byte = self.peek_byte()?; + let byte = self.take_byte()?; let len = match byte { // fixstr - FIXSTR_START..=FIXSTR_END => { - self.pos += 1; - (byte - FIXSTR_START) as usize - } + FIXSTR_START..=FIXSTR_END => (byte - FIXSTR_START) as usize, // str 8 STR8_MARKER => { - self.pos += 1; let byte = self.take_byte()?; byte as usize } // str 16 STR16_MARKER => { - self.pos += 1; let bytes = self.take_array::<2>()?; u16::from_be_bytes(*bytes) as usize } // str 32 STR32_MARKER => { - self.pos += 1; let bytes = self.take_array::<4>()?; u32::from_be_bytes(*bytes) as usize } - _ => return Err(Error::InvalidMarker(byte)), + _ => { + cold_path(); + self.pos -= 1; + return Err(Error::InvalidMarker(byte)); + } }; let bytes = self.take_slice(len)?; Ok(alloc::borrow::Cow::Borrowed(bytes)) @@ -597,27 +583,28 @@ impl<'de> Read<'de> for SliceReader<'de> { #[inline(always)] fn read_binary(&mut self) -> Result> { - let byte = self.peek_byte()?; + let byte = self.take_byte()?; let len = match byte { // bin 8 BIN8_MARKER => { - self.pos += 1; let byte = self.take_byte()?; byte as usize } // bin 16 BIN16_MARKER => { - self.pos += 1; let bytes = self.take_array::<2>()?; u16::from_be_bytes(*bytes) as usize } // bin 32 BIN32_MARKER => { - self.pos += 1; let bytes = self.take_array::<4>()?; u32::from_be_bytes(*bytes) as usize } - _ => return Err(Error::InvalidMarker(byte)), + _ => { + cold_path(); + self.pos -= 1; + return Err(Error::InvalidMarker(byte)); + } }; let bytes = self.take_slice(len)?; Ok(alloc::borrow::Cow::Borrowed(bytes)) @@ -625,11 +612,10 @@ impl<'de> Read<'de> for SliceReader<'de> { #[inline(always)] fn read_timestamp(&mut self) -> Result<(i64, u32)> { - let byte = self.peek_byte()?; + let byte = self.take_byte()?; match byte { // fixext 4 with type -1 TIMESTAMP32_MARKER => { - self.pos += 1; let ext_info = self.take_array::<5>()?; let [ext, tail @ ..] = *ext_info; if ext as i8 != TIMESTAMP_EXT_TYPE { @@ -641,7 +627,6 @@ impl<'de> Read<'de> for SliceReader<'de> { } // fixext 8 with type -1 TIMESTAMP64_MARKER => { - self.pos += 1; let ext_info = self.take_array::<9>()?; let [ext, tail @ ..] = *ext_info; if ext as i8 != TIMESTAMP_EXT_TYPE { @@ -658,7 +643,6 @@ impl<'de> Read<'de> for SliceReader<'de> { } // ext8(12) with type -1 TIMESTAMP96_MARKER => { - self.pos += 1; let len = self.take_byte()? as usize; if len != 12 { return Err(Error::InvalidMarker(len as u8)); @@ -679,108 +663,96 @@ impl<'de> Read<'de> for SliceReader<'de> { } Ok((seconds, nanoseconds)) } - _ => Err(Error::InvalidMarker(byte)), + _ => { + cold_path(); + self.pos -= 1; + Err(Error::InvalidMarker(byte)) + } } } #[inline(always)] fn read_array_len(&mut self) -> Result { - let byte = self.peek_byte()?; + let byte = self.take_byte()?; match byte { // fixarray - FIXARRAY_START..=FIXARRAY_END => { - self.pos += 1; - Ok((byte - FIXARRAY_START) as usize) - } + FIXARRAY_START..=FIXARRAY_END => Ok((byte - FIXARRAY_START) as usize), // array 16 ARRAY16_MARKER => { - self.pos += 1; let bytes = self.take_array::<2>()?; Ok(u16::from_be_bytes(*bytes) as usize) } // array 32 ARRAY32_MARKER => { - self.pos += 1; let bytes = self.take_array::<4>()?; Ok(u32::from_be_bytes(*bytes) as usize) } - _ => Err(Error::InvalidMarker(byte)), + _ => { + cold_path(); + self.pos -= 1; + Err(Error::InvalidMarker(byte)) + } } } #[inline(always)] fn read_map_len(&mut self) -> Result { - let byte = self.peek_byte()?; + let byte = self.take_byte()?; match byte { // fixmap - FIXMAP_START..=FIXMAP_END => { - self.pos += 1; - Ok((byte - FIXMAP_START) as usize) - } + FIXMAP_START..=FIXMAP_END => Ok((byte - FIXMAP_START) as usize), // map 16 MAP16_MARKER => { - self.pos += 1; let bytes = self.take_array::<2>()?; Ok(u16::from_be_bytes(*bytes) as usize) } // map 32 MAP32_MARKER => { - self.pos += 1; let bytes = self.take_array::<4>()?; Ok(u32::from_be_bytes(*bytes) as usize) } - _ => Err(Error::InvalidMarker(byte)), + _ => { + cold_path(); + self.pos -= 1; + Err(Error::InvalidMarker(byte)) + } } } #[inline(always)] fn read_ext_len(&mut self) -> Result<(i8, usize)> { - let byte = self.peek_byte()?; + let byte = self.take_byte()?; let len = match byte { // fixext 1 - FIXEXT1_MARKER => { - self.pos += 1; - 1 - } + FIXEXT1_MARKER => 1, // fixext 2 - FIXEXT2_MARKER => { - self.pos += 1; - 2 - } + FIXEXT2_MARKER => 2, // fixext 4 - FIXEXT4_MARKER => { - self.pos += 1; - 4 - } + FIXEXT4_MARKER => 4, // fixext 8 - FIXEXT8_MARKER => { - self.pos += 1; - 8 - } + FIXEXT8_MARKER => 8, // fixext 16 - FIXEXT16_MARKER => { - self.pos += 1; - 16 - } + FIXEXT16_MARKER => 16, // ext 8 EXT8_MARKER => { - self.pos += 1; let byte = self.take_byte()?; byte as usize } // ext 16 EXT16_MARKER => { - self.pos += 1; let bytes = self.take_array::<2>()?; u16::from_be_bytes(*bytes) as usize } // ext 32 EXT32_MARKER => { - self.pos += 1; let bytes = self.take_array::<4>()?; u32::from_be_bytes(*bytes) as usize } - _ => return Err(Error::InvalidMarker(byte)), + _ => { + cold_path(); + self.pos -= 1; + return Err(Error::InvalidMarker(byte)); + } }; let ext_type = self.take_byte()? as i8; Ok((ext_type, len)) @@ -798,6 +770,7 @@ impl<'de> Read<'de> for SliceReader<'de> { // This is intended to prevent pre-allocation of memory, // which can be used in attacks that exploit abnormal sizes. if self.data.len() - self.pos < len { + cold_path(); return Err(Error::BufferTooSmall); } @@ -830,34 +803,26 @@ impl<'de> Read<'de> for SliceReader<'de> { #[inline(always)] fn read_tag(&mut self) -> Result> { - let byte = self.peek_byte()?; + let byte = self.take_byte()?; match byte { - POS_FIXINT_START..=POS_FIXINT_END => { - self.pos += 1; - Ok(Tag::Int(byte as u64)) - } + POS_FIXINT_START..=POS_FIXINT_END => Ok(Tag::Int(byte as u64)), UINT8_MARKER => { - self.pos += 1; let byte = self.take_byte()?; Ok(Tag::Int(byte as u64)) } UINT16_MARKER => { - self.pos += 1; let bytes = self.take_array::<2>()?; Ok(Tag::Int(u16::from_be_bytes(*bytes) as u64)) } UINT32_MARKER => { - self.pos += 1; let bytes = self.take_array::<4>()?; Ok(Tag::Int(u32::from_be_bytes(*bytes) as u64)) } UINT64_MARKER => { - self.pos += 1; let bytes = self.take_array::<8>()?; Ok(Tag::Int(u64::from_be_bytes(*bytes))) } FIXSTR_START..=FIXSTR_END => { - self.pos += 1; let len = (byte - FIXSTR_START) as usize; let bytes = self.take_slice(len)?; match core::str::from_utf8(bytes) { @@ -866,7 +831,6 @@ impl<'de> Read<'de> for SliceReader<'de> { } } STR8_MARKER => { - self.pos += 1; let byte = self.take_byte()?; let len = byte as usize; let bytes = self.take_slice(len)?; @@ -876,7 +840,6 @@ impl<'de> Read<'de> for SliceReader<'de> { } } STR16_MARKER => { - self.pos += 1; let bytes = self.take_array::<2>()?; let len = u16::from_be_bytes(*bytes) as usize; @@ -887,7 +850,6 @@ impl<'de> Read<'de> for SliceReader<'de> { } } STR32_MARKER => { - self.pos += 1; let bytes = self.take_array::<4>()?; let len = u32::from_be_bytes(*bytes) as usize; @@ -897,7 +859,11 @@ impl<'de> Read<'de> for SliceReader<'de> { Err(err) => Err(Error::InvalidUtf8(err)), } } - _ => Err(Error::InvalidMarker(byte)), + _ => { + cold_path(); + self.pos -= 1; + Err(Error::InvalidMarker(byte)) + } } } } @@ -979,6 +945,7 @@ impl<'de, R: std::io::Read> Read<'de> for IOReader { #[inline(always)] fn increment_depth(&mut self) -> Result<()> { if self.depth >= MAX_DEPTH { + cold_path(); Err(Error::DepthLimitExceeded { max: MAX_DEPTH }) } else { self.depth += 1; @@ -988,7 +955,11 @@ impl<'de, R: std::io::Read> Read<'de> for IOReader { #[inline(always)] fn decrement_depth(&mut self) { - self.depth -= 1; + if self.depth > 0 { + self.depth -= 1; + } else { + cold_path(); + } } #[inline(always)] diff --git a/zerompk/src/write.rs b/zerompk/src/write.rs index c6de0e2..97eabaa 100644 --- a/zerompk/src/write.rs +++ b/zerompk/src/write.rs @@ -1,3 +1,5 @@ +use core::hint::cold_path; + use alloc::vec::Vec; use crate::{Error, Result, consts::*}; @@ -17,7 +19,7 @@ use crate::{Error, Result, consts::*}; /// impl ToMessagePack for Point { /// fn write(&self, writer: &mut W) -> Result<()> { /// writer.write_array_len(2)?; -/// writer.write_i32(self.x)?; +/// writer.write_i32(self.x)?; /// writer.write_i32(self.y)?; /// Ok(()) /// } @@ -92,6 +94,7 @@ impl<'a> SliceWriter<'a> { #[inline(always)] fn take_array(&mut self) -> Result<&mut [u8; N]> { if self.pos + N > self.buffer.len() { + cold_path(); return Err(Error::BufferTooSmall); } let array: &mut [u8; N] = @@ -103,6 +106,7 @@ impl<'a> SliceWriter<'a> { #[inline(always)] fn take_slice(&mut self, len: usize) -> Result<&mut [u8]> { if self.pos + len > self.buffer.len() { + cold_path(); return Err(Error::BufferTooSmall); }