Skip to content

Commit a6fe9e5

Browse files
committed
Utf8ValidatingReader detects encodings and strips BOMs automtically
In cases where the input is sufficiently short and doesn't contain invalid sequences, Utf8ValidatingReader was unable to detect the input as being not-UTF-8 We now call detect_encoding() during the first read() so that it can more effectively raise the appropriate errors. Doing this (and BOM stripping) upstream of the parser makes it possible to eliminate this responsibility from the parser, once it can be relied upon on all code paths.
1 parent 83089e1 commit a6fe9e5

3 files changed

Lines changed: 380 additions & 32 deletions

File tree

src/encoding.rs

Lines changed: 276 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@ pub enum Utf8ValidationError {
3131
},
3232
/// Incomplete UTF-8 sequence at end of stream
3333
IncompleteSequence,
34+
/// Non-UTF-8 encoding detected at start of stream
35+
NonUtf8EncodingDetected(DetectedEncoding),
3436
}
3537

3638
impl From<Utf8Error> for Utf8ValidationError {
@@ -50,6 +52,13 @@ impl std::fmt::Display for Utf8ValidationError {
5052
Self::IncompleteSequence => {
5153
write!(f, "incomplete UTF-8 sequence at end of stream")
5254
}
55+
Self::NonUtf8EncodingDetected(detected) => {
56+
write!(
57+
f,
58+
"non-UTF-8 encoding detected at start of stream: {:?}",
59+
detected
60+
)
61+
}
5362
}
5463
}
5564
}
@@ -323,6 +332,7 @@ pub fn detect_encoding(bytes: &[u8]) -> Option<DetectedEncoding> {
323332
/// Possible scenarios for start-of-xml detection of encoding
324333
///
325334
/// See the documentation of [`detect_encoding`]
335+
#[derive(Clone, Debug, PartialEq, Eq)]
326336
pub enum DetectedEncoding {
327337
/// Matches UTF-8 or some other ascii-compatible encoding
328338
AsciiCompatible,
@@ -417,6 +427,10 @@ impl<R: io::Read> io::BufRead for Utf8BytesReader<R> {
417427
/// that only valid UTF-8 bytes are written to the output buffer. Incomplete UTF-8
418428
/// sequences at read boundaries are buffered and combined with subsequent reads.
419429
///
430+
/// Additionally, this reader checks the very beginning of the stream for encoding
431+
/// signatures (BOMs or XML declaration patterns) and rejects streams that appear to
432+
/// be encoded in UTF-16 or other non-UTF-8 encodings.
433+
///
420434
/// # Examples
421435
///
422436
/// ```
@@ -434,6 +448,8 @@ pub struct Utf8ValidatingReader<R> {
434448
inner: R,
435449
/// Buffer to hold incomplete UTF-8 sequences from previous reads (max 3 bytes)
436450
buffer: Vec<u8>,
451+
/// Whether we've checked for encoding at the start of the stream
452+
encoding_checked: bool,
437453
}
438454

439455
impl<R> Utf8ValidatingReader<R> {
@@ -442,6 +458,7 @@ impl<R> Utf8ValidatingReader<R> {
442458
Self {
443459
inner,
444460
buffer: Vec::with_capacity(4),
461+
encoding_checked: false,
445462
}
446463
}
447464

@@ -467,6 +484,49 @@ impl<R: Read> Read for Utf8ValidatingReader<R> {
467484
return Ok(0);
468485
}
469486

487+
// Check for encoding at the start of the stream
488+
if !self.encoding_checked {
489+
self.encoding_checked = true;
490+
491+
// Read initial data to detect encoding
492+
// Read enough for encoding detection (4 bytes) plus fill up to caller's buffer size
493+
let read_size = buf.len().max(64); // Read at least 64 bytes for efficiency
494+
let mut temp = vec![0u8; read_size];
495+
let n = self.inner.read(&mut temp)?;
496+
497+
if n > 0 {
498+
self.buffer.extend_from_slice(&temp[..n]);
499+
500+
// Try to detect encoding if we have at least 4 bytes
501+
if self.buffer.len() >= 4 {
502+
if let Some(detected) = detect_encoding(&self.buffer) {
503+
match detected {
504+
DetectedEncoding::Utf8Bom | DetectedEncoding::AsciiCompatible => {
505+
// Strip BOM if present
506+
let bom_len = detected.bom_len();
507+
if bom_len > 0 {
508+
self.buffer.drain(..bom_len);
509+
}
510+
}
511+
DetectedEncoding::Utf16Le
512+
| DetectedEncoding::Utf16LeBom
513+
| DetectedEncoding::Utf16Be
514+
| DetectedEncoding::Utf16BeBom => {
515+
// Reject UTF-16 encodings
516+
return Err(io::Error::new(
517+
io::ErrorKind::InvalidData,
518+
EncodingError::Utf8(
519+
Utf8ValidationError::NonUtf8EncodingDetected(detected),
520+
),
521+
));
522+
}
523+
}
524+
}
525+
}
526+
}
527+
// If we read 0 bytes or less than 4 bytes, assume UTF-8 and continue
528+
}
529+
470530
loop {
471531
// If we have buffered data, check if it's complete UTF-8
472532
if !self.buffer.is_empty() {
@@ -775,7 +835,21 @@ mod utf8_validating_reader_tests {
775835
// Second read should fail because incomplete sequence at EOF
776836
let result = reader.read(&mut buf);
777837
assert!(result.is_err());
778-
assert_eq!(result.unwrap_err().kind(), io::ErrorKind::InvalidData);
838+
let err = result.unwrap_err();
839+
assert_eq!(err.kind(), io::ErrorKind::InvalidData);
840+
841+
// Verify the error can be downcast to EncodingError
842+
let encoding_err = err
843+
.get_ref()
844+
.unwrap()
845+
.downcast_ref::<EncodingError>()
846+
.expect("Error should downcast to EncodingError");
847+
848+
// Verify it's the IncompleteSequence error
849+
match encoding_err {
850+
EncodingError::Utf8(Utf8ValidationError::IncompleteSequence) => {}
851+
other => panic!("Expected IncompleteSequence error, got: {:?}", other),
852+
}
779853
}
780854

781855
#[test]
@@ -786,7 +860,23 @@ mod utf8_validating_reader_tests {
786860
let mut buf = [0u8; 10];
787861
let result = reader.read(&mut buf);
788862
assert!(result.is_err());
789-
assert_eq!(result.unwrap_err().kind(), io::ErrorKind::InvalidData);
863+
let err = result.unwrap_err();
864+
assert_eq!(err.kind(), io::ErrorKind::InvalidData);
865+
866+
// Verify the error can be downcast to EncodingError
867+
let encoding_err = err
868+
.get_ref()
869+
.unwrap()
870+
.downcast_ref::<EncodingError>()
871+
.expect("Error should downcast to EncodingError");
872+
873+
// Verify it's an InvalidSequence error
874+
match encoding_err {
875+
EncodingError::Utf8(Utf8ValidationError::InvalidSequence { error_len }) => {
876+
assert_eq!(*error_len, 1, "Expected 1-byte invalid sequence (0xFF)");
877+
}
878+
other => panic!("Expected InvalidSequence error, got: {:?}", other),
879+
}
790880
}
791881

792882
#[test]
@@ -797,7 +887,23 @@ mod utf8_validating_reader_tests {
797887
let mut buf = [0u8; 10];
798888
let result = reader.read(&mut buf);
799889
assert!(result.is_err());
800-
assert_eq!(result.unwrap_err().kind(), io::ErrorKind::InvalidData);
890+
let err = result.unwrap_err();
891+
assert_eq!(err.kind(), io::ErrorKind::InvalidData);
892+
893+
// Verify the error can be downcast to EncodingError
894+
let encoding_err = err
895+
.get_ref()
896+
.unwrap()
897+
.downcast_ref::<EncodingError>()
898+
.expect("Error should downcast to EncodingError");
899+
900+
// Verify it's an InvalidSequence error
901+
match encoding_err {
902+
EncodingError::Utf8(Utf8ValidationError::InvalidSequence { error_len }) => {
903+
assert_eq!(*error_len, 1, "Expected 1-byte invalid sequence");
904+
}
905+
other => panic!("Expected InvalidSequence error, got: {:?}", other),
906+
}
801907
}
802908

803909
#[test]
@@ -942,7 +1048,23 @@ mod utf8_validating_reader_tests {
9421048

9431049
let result = reader.read(&mut buf);
9441050
assert!(result.is_err());
945-
assert_eq!(result.unwrap_err().kind(), io::ErrorKind::InvalidData);
1051+
let err = result.unwrap_err();
1052+
assert_eq!(err.kind(), io::ErrorKind::InvalidData);
1053+
1054+
// Verify the error can be downcast to EncodingError
1055+
let encoding_err = err
1056+
.get_ref()
1057+
.unwrap()
1058+
.downcast_ref::<EncodingError>()
1059+
.expect("Error should downcast to EncodingError");
1060+
1061+
// Verify it's the expected error variant
1062+
match encoding_err {
1063+
EncodingError::Utf8(Utf8ValidationError::InvalidSequence { error_len }) => {
1064+
assert_eq!(*error_len, 1, "Expected 1-byte invalid sequence (0xFF)");
1065+
}
1066+
other => panic!("Expected InvalidSequence error, got: {:?}", other),
1067+
}
9461068
}
9471069

9481070
#[test]
@@ -1013,7 +1135,21 @@ mod utf8_validating_reader_tests {
10131135

10141136
let result = reader.read(&mut buf);
10151137
assert!(result.is_err());
1016-
assert_eq!(result.unwrap_err().kind(), io::ErrorKind::InvalidData);
1138+
let err = result.unwrap_err();
1139+
assert_eq!(err.kind(), io::ErrorKind::InvalidData);
1140+
1141+
// Verify the error can be downcast to EncodingError
1142+
let encoding_err = err
1143+
.get_ref()
1144+
.unwrap()
1145+
.downcast_ref::<EncodingError>()
1146+
.expect("Error should downcast to EncodingError");
1147+
1148+
// Verify it's the IncompleteSequence error
1149+
match encoding_err {
1150+
EncodingError::Utf8(Utf8ValidationError::IncompleteSequence) => {}
1151+
other => panic!("Expected IncompleteSequence error, got: {:?}", other),
1152+
}
10171153
}
10181154

10191155
#[test]
@@ -1028,7 +1164,21 @@ mod utf8_validating_reader_tests {
10281164

10291165
let result = reader.read(&mut buf);
10301166
assert!(result.is_err());
1031-
assert_eq!(result.unwrap_err().kind(), io::ErrorKind::InvalidData);
1167+
let err = result.unwrap_err();
1168+
assert_eq!(err.kind(), io::ErrorKind::InvalidData);
1169+
1170+
// Verify the error can be downcast to EncodingError
1171+
let encoding_err = err
1172+
.get_ref()
1173+
.unwrap()
1174+
.downcast_ref::<EncodingError>()
1175+
.expect("Error should downcast to EncodingError");
1176+
1177+
// Verify it's the IncompleteSequence error
1178+
match encoding_err {
1179+
EncodingError::Utf8(Utf8ValidationError::IncompleteSequence) => {}
1180+
other => panic!("Expected IncompleteSequence error, got: {:?}", other),
1181+
}
10321182
}
10331183

10341184
#[test]
@@ -1042,6 +1192,26 @@ mod utf8_validating_reader_tests {
10421192

10431193
let result = reader.read(&mut buf);
10441194
assert!(result.is_err());
1195+
let err = result.unwrap_err();
1196+
assert_eq!(err.kind(), io::ErrorKind::InvalidData);
1197+
1198+
// Verify the error can be downcast to EncodingError
1199+
let encoding_err = err
1200+
.get_ref()
1201+
.unwrap()
1202+
.downcast_ref::<EncodingError>()
1203+
.expect("Error should downcast to EncodingError");
1204+
1205+
// Verify it's an InvalidSequence error
1206+
match encoding_err {
1207+
EncodingError::Utf8(Utf8ValidationError::InvalidSequence { error_len }) => {
1208+
assert_eq!(
1209+
*error_len, 1,
1210+
"Expected 1-byte invalid sequence (0xC0 is invalid)"
1211+
);
1212+
}
1213+
other => panic!("Expected InvalidSequence error, got: {:?}", other),
1214+
}
10451215
}
10461216

10471217
#[test]
@@ -1054,6 +1224,26 @@ mod utf8_validating_reader_tests {
10541224

10551225
let result = reader.read(&mut buf);
10561226
assert!(result.is_err());
1227+
let err = result.unwrap_err();
1228+
assert_eq!(err.kind(), io::ErrorKind::InvalidData);
1229+
1230+
// Verify the error can be downcast to EncodingError
1231+
let encoding_err = err
1232+
.get_ref()
1233+
.unwrap()
1234+
.downcast_ref::<EncodingError>()
1235+
.expect("Error should downcast to EncodingError");
1236+
1237+
// Verify it's an InvalidSequence error
1238+
match encoding_err {
1239+
EncodingError::Utf8(Utf8ValidationError::InvalidSequence { error_len }) => {
1240+
assert_eq!(
1241+
*error_len, 1,
1242+
"Expected 1-byte invalid sequence (0xED starts surrogate)"
1243+
);
1244+
}
1245+
other => panic!("Expected InvalidSequence error, got: {:?}", other),
1246+
}
10571247
}
10581248

10591249
#[test]
@@ -1078,4 +1268,84 @@ mod utf8_validating_reader_tests {
10781268
assert_eq!(n, 5);
10791269
assert_eq!(&buf[..n], data);
10801270
}
1271+
1272+
#[test]
1273+
fn test_utf8_bom_stripped() {
1274+
// UTF-8 BOM (0xEF 0xBB 0xBF) followed by "Hello"
1275+
let data = b"\xEF\xBB\xBFHello";
1276+
let mut reader = Utf8ValidatingReader::new(&data[..]);
1277+
let mut buf = [0u8; 20];
1278+
let n = reader.read(&mut buf).unwrap();
1279+
1280+
// BOM should be stripped, only "Hello" should be returned
1281+
assert_eq!(&buf[..n], b"Hello");
1282+
assert_eq!(std::str::from_utf8(&buf[..n]).unwrap(), "Hello");
1283+
}
1284+
1285+
#[test]
1286+
fn test_utf16le_bom_rejected() {
1287+
// UTF-16 LE BOM (0xFF 0xFE)
1288+
let data = b"\xFF\xFE<?xml";
1289+
let mut reader = Utf8ValidatingReader::new(&data[..]);
1290+
let mut buf = [0u8; 20];
1291+
1292+
let result = reader.read(&mut buf);
1293+
assert!(result.is_err());
1294+
let err = result.unwrap_err();
1295+
assert_eq!(err.kind(), io::ErrorKind::InvalidData);
1296+
1297+
// Verify the error can be downcast to EncodingError
1298+
let encoding_err = err
1299+
.get_ref()
1300+
.unwrap()
1301+
.downcast_ref::<EncodingError>()
1302+
.expect("Error should downcast to EncodingError");
1303+
1304+
// Verify it's the NonUtf8EncodingDetected error with the correct encoding
1305+
match encoding_err {
1306+
EncodingError::Utf8(Utf8ValidationError::NonUtf8EncodingDetected(detected)) => {
1307+
assert_eq!(*detected, DetectedEncoding::Utf16LeBom);
1308+
}
1309+
other => panic!("Expected NonUtf8EncodingDetected error, got: {:?}", other),
1310+
}
1311+
}
1312+
1313+
#[test]
1314+
fn test_utf16be_bom_rejected() {
1315+
// UTF-16 BE BOM (0xFE 0xFF)
1316+
let data = b"\xFE\xFF\x00<\x00?";
1317+
let mut reader = Utf8ValidatingReader::new(&data[..]);
1318+
let mut buf = [0u8; 20];
1319+
1320+
let result = reader.read(&mut buf);
1321+
assert!(result.is_err());
1322+
let err = result.unwrap_err();
1323+
assert_eq!(err.kind(), io::ErrorKind::InvalidData);
1324+
1325+
// Verify the error can be downcast to EncodingError
1326+
let encoding_err = err
1327+
.get_ref()
1328+
.unwrap()
1329+
.downcast_ref::<EncodingError>()
1330+
.expect("Error should downcast to EncodingError");
1331+
1332+
// Verify it's the NonUtf8EncodingDetected error with the correct encoding
1333+
match encoding_err {
1334+
EncodingError::Utf8(Utf8ValidationError::NonUtf8EncodingDetected(detected)) => {
1335+
assert_eq!(*detected, DetectedEncoding::Utf16BeBom);
1336+
}
1337+
other => panic!("Expected NonUtf8EncodingDetected error, got: {:?}", other),
1338+
}
1339+
}
1340+
1341+
#[test]
1342+
fn test_ascii_compatible_encoding_allowed() {
1343+
// ASCII-compatible XML declaration (no BOM)
1344+
let data = b"<?xml version=\"1.0\"?><root/>";
1345+
let mut reader = Utf8ValidatingReader::new(&data[..]);
1346+
let mut buf = [0u8; 50];
1347+
1348+
let n = reader.read(&mut buf).unwrap();
1349+
assert_eq!(&buf[..n], data);
1350+
}
10811351
}

0 commit comments

Comments
 (0)