@@ -31,6 +31,8 @@ pub enum Utf8ValidationError {
3131 } ,
3232 /// Incomplete UTF-8 sequence at end of stream
3333 IncompleteSequence ,
34+ /// Non-UTF-8 encoding detected at start of stream
35+ NonUtf8EncodingDetected ( DetectedEncoding ) ,
3436}
3537
3638impl From < Utf8Error > for Utf8ValidationError {
@@ -50,6 +52,13 @@ impl std::fmt::Display for Utf8ValidationError {
5052 Self :: IncompleteSequence => {
5153 write ! ( f, "incomplete UTF-8 sequence at end of stream" )
5254 }
55+ Self :: NonUtf8EncodingDetected ( detected) => {
56+ write ! (
57+ f,
58+ "non-UTF-8 encoding detected at start of stream: {:?}" ,
59+ detected
60+ )
61+ }
5362 }
5463 }
5564}
@@ -323,6 +332,7 @@ pub fn detect_encoding(bytes: &[u8]) -> Option<DetectedEncoding> {
323332/// Possible scenarios for start-of-xml detection of encoding
324333///
325334/// See the documentation of [`detect_encoding`]
335+ #[ derive( Clone , Debug , PartialEq , Eq ) ]
326336pub enum DetectedEncoding {
327337 /// Matches UTF-8 or some other ascii-compatible encoding
328338 AsciiCompatible ,
@@ -417,6 +427,10 @@ impl<R: io::Read> io::BufRead for Utf8BytesReader<R> {
417427/// that only valid UTF-8 bytes are written to the output buffer. Incomplete UTF-8
418428/// sequences at read boundaries are buffered and combined with subsequent reads.
419429///
/// Additionally, this reader inspects the very beginning of the stream for encoding
/// signatures (BOMs or XML declaration patterns): a leading UTF-8 BOM is stripped,
/// and streams that appear to be encoded in UTF-16 are rejected with an error.
433+ ///
420434/// # Examples
421435///
422436/// ```
@@ -434,6 +448,8 @@ pub struct Utf8ValidatingReader<R> {
434448 inner : R ,
435449 /// Buffer to hold incomplete UTF-8 sequences from previous reads (max 3 bytes)
436450 buffer : Vec < u8 > ,
451+ /// Whether we've checked for encoding at the start of the stream
452+ encoding_checked : bool ,
437453}
438454
439455impl < R > Utf8ValidatingReader < R > {
@@ -442,6 +458,7 @@ impl<R> Utf8ValidatingReader<R> {
442458 Self {
443459 inner,
444460 buffer : Vec :: with_capacity ( 4 ) ,
461+ encoding_checked : false ,
445462 }
446463 }
447464
@@ -467,6 +484,49 @@ impl<R: Read> Read for Utf8ValidatingReader<R> {
467484 return Ok ( 0 ) ;
468485 }
469486
487+ // Check for encoding at the start of the stream
488+ if !self . encoding_checked {
489+ self . encoding_checked = true ;
490+
491+ // Read initial data to detect encoding
492+ // Read enough for encoding detection (4 bytes) plus fill up to caller's buffer size
493+ let read_size = buf. len ( ) . max ( 64 ) ; // Read at least 64 bytes for efficiency
494+ let mut temp = vec ! [ 0u8 ; read_size] ;
495+ let n = self . inner . read ( & mut temp) ?;
496+
497+ if n > 0 {
498+ self . buffer . extend_from_slice ( & temp[ ..n] ) ;
499+
500+ // Try to detect encoding if we have at least 4 bytes
501+ if self . buffer . len ( ) >= 4 {
502+ if let Some ( detected) = detect_encoding ( & self . buffer ) {
503+ match detected {
504+ DetectedEncoding :: Utf8Bom | DetectedEncoding :: AsciiCompatible => {
505+ // Strip BOM if present
506+ let bom_len = detected. bom_len ( ) ;
507+ if bom_len > 0 {
508+ self . buffer . drain ( ..bom_len) ;
509+ }
510+ }
511+ DetectedEncoding :: Utf16Le
512+ | DetectedEncoding :: Utf16LeBom
513+ | DetectedEncoding :: Utf16Be
514+ | DetectedEncoding :: Utf16BeBom => {
515+ // Reject UTF-16 encodings
516+ return Err ( io:: Error :: new (
517+ io:: ErrorKind :: InvalidData ,
518+ EncodingError :: Utf8 (
519+ Utf8ValidationError :: NonUtf8EncodingDetected ( detected) ,
520+ ) ,
521+ ) ) ;
522+ }
523+ }
524+ }
525+ }
526+ }
527+ // If we read 0 bytes or less than 4 bytes, assume UTF-8 and continue
528+ }
529+
470530 loop {
471531 // If we have buffered data, check if it's complete UTF-8
472532 if !self . buffer . is_empty ( ) {
@@ -775,7 +835,21 @@ mod utf8_validating_reader_tests {
775835 // Second read should fail because incomplete sequence at EOF
776836 let result = reader. read ( & mut buf) ;
777837 assert ! ( result. is_err( ) ) ;
778- assert_eq ! ( result. unwrap_err( ) . kind( ) , io:: ErrorKind :: InvalidData ) ;
838+ let err = result. unwrap_err ( ) ;
839+ assert_eq ! ( err. kind( ) , io:: ErrorKind :: InvalidData ) ;
840+
841+ // Verify the error can be downcast to EncodingError
842+ let encoding_err = err
843+ . get_ref ( )
844+ . unwrap ( )
845+ . downcast_ref :: < EncodingError > ( )
846+ . expect ( "Error should downcast to EncodingError" ) ;
847+
848+ // Verify it's the IncompleteSequence error
849+ match encoding_err {
850+ EncodingError :: Utf8 ( Utf8ValidationError :: IncompleteSequence ) => { }
851+ other => panic ! ( "Expected IncompleteSequence error, got: {:?}" , other) ,
852+ }
779853 }
780854
781855 #[ test]
@@ -786,7 +860,23 @@ mod utf8_validating_reader_tests {
786860 let mut buf = [ 0u8 ; 10 ] ;
787861 let result = reader. read ( & mut buf) ;
788862 assert ! ( result. is_err( ) ) ;
789- assert_eq ! ( result. unwrap_err( ) . kind( ) , io:: ErrorKind :: InvalidData ) ;
863+ let err = result. unwrap_err ( ) ;
864+ assert_eq ! ( err. kind( ) , io:: ErrorKind :: InvalidData ) ;
865+
866+ // Verify the error can be downcast to EncodingError
867+ let encoding_err = err
868+ . get_ref ( )
869+ . unwrap ( )
870+ . downcast_ref :: < EncodingError > ( )
871+ . expect ( "Error should downcast to EncodingError" ) ;
872+
873+ // Verify it's an InvalidSequence error
874+ match encoding_err {
875+ EncodingError :: Utf8 ( Utf8ValidationError :: InvalidSequence { error_len } ) => {
876+ assert_eq ! ( * error_len, 1 , "Expected 1-byte invalid sequence (0xFF)" ) ;
877+ }
878+ other => panic ! ( "Expected InvalidSequence error, got: {:?}" , other) ,
879+ }
790880 }
791881
792882 #[ test]
@@ -797,7 +887,23 @@ mod utf8_validating_reader_tests {
797887 let mut buf = [ 0u8 ; 10 ] ;
798888 let result = reader. read ( & mut buf) ;
799889 assert ! ( result. is_err( ) ) ;
800- assert_eq ! ( result. unwrap_err( ) . kind( ) , io:: ErrorKind :: InvalidData ) ;
890+ let err = result. unwrap_err ( ) ;
891+ assert_eq ! ( err. kind( ) , io:: ErrorKind :: InvalidData ) ;
892+
893+ // Verify the error can be downcast to EncodingError
894+ let encoding_err = err
895+ . get_ref ( )
896+ . unwrap ( )
897+ . downcast_ref :: < EncodingError > ( )
898+ . expect ( "Error should downcast to EncodingError" ) ;
899+
900+ // Verify it's an InvalidSequence error
901+ match encoding_err {
902+ EncodingError :: Utf8 ( Utf8ValidationError :: InvalidSequence { error_len } ) => {
903+ assert_eq ! ( * error_len, 1 , "Expected 1-byte invalid sequence" ) ;
904+ }
905+ other => panic ! ( "Expected InvalidSequence error, got: {:?}" , other) ,
906+ }
801907 }
802908
803909 #[ test]
@@ -942,7 +1048,23 @@ mod utf8_validating_reader_tests {
9421048
9431049 let result = reader. read ( & mut buf) ;
9441050 assert ! ( result. is_err( ) ) ;
945- assert_eq ! ( result. unwrap_err( ) . kind( ) , io:: ErrorKind :: InvalidData ) ;
1051+ let err = result. unwrap_err ( ) ;
1052+ assert_eq ! ( err. kind( ) , io:: ErrorKind :: InvalidData ) ;
1053+
1054+ // Verify the error can be downcast to EncodingError
1055+ let encoding_err = err
1056+ . get_ref ( )
1057+ . unwrap ( )
1058+ . downcast_ref :: < EncodingError > ( )
1059+ . expect ( "Error should downcast to EncodingError" ) ;
1060+
1061+ // Verify it's the expected error variant
1062+ match encoding_err {
1063+ EncodingError :: Utf8 ( Utf8ValidationError :: InvalidSequence { error_len } ) => {
1064+ assert_eq ! ( * error_len, 1 , "Expected 1-byte invalid sequence (0xFF)" ) ;
1065+ }
1066+ other => panic ! ( "Expected InvalidSequence error, got: {:?}" , other) ,
1067+ }
9461068 }
9471069
9481070 #[ test]
@@ -1013,7 +1135,21 @@ mod utf8_validating_reader_tests {
10131135
10141136 let result = reader. read ( & mut buf) ;
10151137 assert ! ( result. is_err( ) ) ;
1016- assert_eq ! ( result. unwrap_err( ) . kind( ) , io:: ErrorKind :: InvalidData ) ;
1138+ let err = result. unwrap_err ( ) ;
1139+ assert_eq ! ( err. kind( ) , io:: ErrorKind :: InvalidData ) ;
1140+
1141+ // Verify the error can be downcast to EncodingError
1142+ let encoding_err = err
1143+ . get_ref ( )
1144+ . unwrap ( )
1145+ . downcast_ref :: < EncodingError > ( )
1146+ . expect ( "Error should downcast to EncodingError" ) ;
1147+
1148+ // Verify it's the IncompleteSequence error
1149+ match encoding_err {
1150+ EncodingError :: Utf8 ( Utf8ValidationError :: IncompleteSequence ) => { }
1151+ other => panic ! ( "Expected IncompleteSequence error, got: {:?}" , other) ,
1152+ }
10171153 }
10181154
10191155 #[ test]
@@ -1028,7 +1164,21 @@ mod utf8_validating_reader_tests {
10281164
10291165 let result = reader. read ( & mut buf) ;
10301166 assert ! ( result. is_err( ) ) ;
1031- assert_eq ! ( result. unwrap_err( ) . kind( ) , io:: ErrorKind :: InvalidData ) ;
1167+ let err = result. unwrap_err ( ) ;
1168+ assert_eq ! ( err. kind( ) , io:: ErrorKind :: InvalidData ) ;
1169+
1170+ // Verify the error can be downcast to EncodingError
1171+ let encoding_err = err
1172+ . get_ref ( )
1173+ . unwrap ( )
1174+ . downcast_ref :: < EncodingError > ( )
1175+ . expect ( "Error should downcast to EncodingError" ) ;
1176+
1177+ // Verify it's the IncompleteSequence error
1178+ match encoding_err {
1179+ EncodingError :: Utf8 ( Utf8ValidationError :: IncompleteSequence ) => { }
1180+ other => panic ! ( "Expected IncompleteSequence error, got: {:?}" , other) ,
1181+ }
10321182 }
10331183
10341184 #[ test]
@@ -1042,6 +1192,26 @@ mod utf8_validating_reader_tests {
10421192
10431193 let result = reader. read ( & mut buf) ;
10441194 assert ! ( result. is_err( ) ) ;
1195+ let err = result. unwrap_err ( ) ;
1196+ assert_eq ! ( err. kind( ) , io:: ErrorKind :: InvalidData ) ;
1197+
1198+ // Verify the error can be downcast to EncodingError
1199+ let encoding_err = err
1200+ . get_ref ( )
1201+ . unwrap ( )
1202+ . downcast_ref :: < EncodingError > ( )
1203+ . expect ( "Error should downcast to EncodingError" ) ;
1204+
1205+ // Verify it's an InvalidSequence error
1206+ match encoding_err {
1207+ EncodingError :: Utf8 ( Utf8ValidationError :: InvalidSequence { error_len } ) => {
1208+ assert_eq ! (
1209+ * error_len, 1 ,
1210+ "Expected 1-byte invalid sequence (0xC0 is invalid)"
1211+ ) ;
1212+ }
1213+ other => panic ! ( "Expected InvalidSequence error, got: {:?}" , other) ,
1214+ }
10451215 }
10461216
10471217 #[ test]
@@ -1054,6 +1224,26 @@ mod utf8_validating_reader_tests {
10541224
10551225 let result = reader. read ( & mut buf) ;
10561226 assert ! ( result. is_err( ) ) ;
1227+ let err = result. unwrap_err ( ) ;
1228+ assert_eq ! ( err. kind( ) , io:: ErrorKind :: InvalidData ) ;
1229+
1230+ // Verify the error can be downcast to EncodingError
1231+ let encoding_err = err
1232+ . get_ref ( )
1233+ . unwrap ( )
1234+ . downcast_ref :: < EncodingError > ( )
1235+ . expect ( "Error should downcast to EncodingError" ) ;
1236+
1237+ // Verify it's an InvalidSequence error
1238+ match encoding_err {
1239+ EncodingError :: Utf8 ( Utf8ValidationError :: InvalidSequence { error_len } ) => {
1240+ assert_eq ! (
1241+ * error_len, 1 ,
1242+ "Expected 1-byte invalid sequence (0xED starts surrogate)"
1243+ ) ;
1244+ }
1245+ other => panic ! ( "Expected InvalidSequence error, got: {:?}" , other) ,
1246+ }
10571247 }
10581248
10591249 #[ test]
@@ -1078,4 +1268,84 @@ mod utf8_validating_reader_tests {
10781268 assert_eq ! ( n, 5 ) ;
10791269 assert_eq ! ( & buf[ ..n] , data) ;
10801270 }
1271+
/// A leading UTF-8 BOM must be consumed by the reader and never reach the caller.
#[test]
fn test_utf8_bom_stripped() {
    // UTF-8 BOM (0xEF 0xBB 0xBF) immediately followed by "Hello".
    // The escapes must be contiguous: any byte between them would make the
    // input an invalid UTF-8 sequence instead of a BOM.
    let data = b"\xEF\xBB\xBFHello";
    let mut reader = Utf8ValidatingReader::new(&data[..]);
    let mut buf = [0u8; 20];
    let n = reader.read(&mut buf).unwrap();

    // BOM should be stripped, only "Hello" should be returned
    assert_eq!(&buf[..n], b"Hello");
    assert_eq!(std::str::from_utf8(&buf[..n]).unwrap(), "Hello");
}
1284+
/// A stream starting with the UTF-16 LE BOM must be rejected with
/// `NonUtf8EncodingDetected(Utf16LeBom)`.
#[test]
fn test_utf16le_bom_rejected() {
    // UTF-16 LE BOM (0xFF 0xFE) as the first two bytes, followed by enough
    // payload to reach the 4 bytes the reader needs before it attempts detection.
    let data = b"\xFF\xFE<?xml";
    let mut reader = Utf8ValidatingReader::new(&data[..]);
    let mut buf = [0u8; 20];

    let err = reader
        .read(&mut buf)
        .expect_err("UTF-16 LE input should be rejected");
    assert_eq!(err.kind(), io::ErrorKind::InvalidData);

    // Verify the error can be downcast to EncodingError
    let encoding_err = err
        .get_ref()
        .expect("Error should carry an inner error")
        .downcast_ref::<EncodingError>()
        .expect("Error should downcast to EncodingError");

    // Verify it's the NonUtf8EncodingDetected error with the correct encoding
    match encoding_err {
        EncodingError::Utf8(Utf8ValidationError::NonUtf8EncodingDetected(detected)) => {
            assert_eq!(*detected, DetectedEncoding::Utf16LeBom);
        }
        other => panic!("Expected NonUtf8EncodingDetected error, got: {:?}", other),
    }
}
1312+
/// A stream starting with the UTF-16 BE BOM must be rejected with
/// `NonUtf8EncodingDetected(Utf16BeBom)`.
#[test]
fn test_utf16be_bom_rejected() {
    // UTF-16 BE BOM (0xFE 0xFF) as the first two bytes, followed by "<?"
    // encoded as big-endian UTF-16 code units (0x00 '<', 0x00 '?').
    let data = b"\xFE\xFF\x00<\x00?";
    let mut reader = Utf8ValidatingReader::new(&data[..]);
    let mut buf = [0u8; 20];

    let err = reader
        .read(&mut buf)
        .expect_err("UTF-16 BE input should be rejected");
    assert_eq!(err.kind(), io::ErrorKind::InvalidData);

    // Verify the error can be downcast to EncodingError
    let encoding_err = err
        .get_ref()
        .expect("Error should carry an inner error")
        .downcast_ref::<EncodingError>()
        .expect("Error should downcast to EncodingError");

    // Verify it's the NonUtf8EncodingDetected error with the correct encoding
    match encoding_err {
        EncodingError::Utf8(Utf8ValidationError::NonUtf8EncodingDetected(detected)) => {
            assert_eq!(*detected, DetectedEncoding::Utf16BeBom);
        }
        other => panic!("Expected NonUtf8EncodingDetected error, got: {:?}", other),
    }
}
1340+
/// Plain ASCII-compatible input without a BOM must pass through unchanged.
#[test]
fn test_ascii_compatible_encoding_allowed() {
    // ASCII-compatible XML declaration (no BOM)
    let data = b"<?xml version=\"1.0\"?><root/>";
    let mut reader = Utf8ValidatingReader::new(&data[..]);
    let mut buf = [0u8; 50];

    let n = reader.read(&mut buf).unwrap();
    // Nothing may be stripped or altered: same length and same bytes.
    assert_eq!(n, data.len());
    assert_eq!(&buf[..n], data);
}
10811351}
0 commit comments