diff --git a/Cargo.toml b/Cargo.toml index ccf0be9..88e059a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -23,6 +23,7 @@ default = ["std"] std = [] serde = ["serde/alloc"] bytes = [] +simd = [] [dev-dependencies] criterion = { version = "0.5", features = ["html_reports"] } @@ -30,4 +31,8 @@ criterion = { version = "0.5", features = ["html_reports"] } [[bench]] name = "cheetah" +harness = false + +[[bench]] +name = "simd" harness = false \ No newline at end of file diff --git a/README.md b/README.md index fc541cd..ceb6958 100644 --- a/README.md +++ b/README.md @@ -31,6 +31,7 @@ CheetahString is a versatile string type that goes beyond the standard library's - **⚡ Performance Focused** - Optimized for common string operations - Reduced memory allocations via intelligent internal representation + - Optional SIMD acceleration for string matching operations (x86_64 SSE2) - Benchmarked against standard library types - **🛡️ Safe & Correct** @@ -44,20 +45,21 @@ Add this to your `Cargo.toml`: ```toml [dependencies] -cheetah-string = "0.1" +cheetah-string = "1.0.0" ``` ### Optional Features ```toml [dependencies] -cheetah-string = { version = "0.1", features = ["bytes", "serde"] } +cheetah-string = { version = "1.0.0", features = ["bytes", "serde", "simd"] } ``` Available features: - `std` (default): Enable standard library support - `bytes`: Integration with the `bytes` crate - `serde`: Serialization support via serde +- `simd`: SIMD-accelerated string operations (x86_64 SSE2) ## 🚀 Quick Start @@ -74,8 +76,10 @@ let small = CheetahString::from("short"); // Stored inline! 
// String operations let s = CheetahString::from("Hello, World!"); -assert!(s.starts_with("Hello")); +assert!(s.starts_with("Hello")); // Supports &str +assert!(s.starts_with('H')); // Also supports char assert!(s.contains("World")); +assert!(s.contains('W')); assert_eq!(s.to_lowercase(), "hello, world!"); // Concatenation @@ -101,10 +105,14 @@ CheetahString is designed with performance in mind: - **Small String Optimization (SSO)**: Strings up to 23 bytes are stored inline without heap allocation - **Efficient Sharing**: Large strings use `Arc` for cheap cloning - **Optimized Operations**: Common operations like concatenation have fast-path implementations +- **SIMD Acceleration** (with `simd` feature): String matching operations (`starts_with`, `ends_with`, `contains`, `find`, equality comparisons) are accelerated using SSE2 SIMD instructions on x86_64 platforms. The implementation automatically falls back to scalar code for small inputs or when SIMD is not available. Run benchmarks: ```bash cargo bench + +# With SIMD feature +cargo bench --features simd ``` ## 🔍 Internal Representation @@ -131,7 +139,7 @@ CheetahString intelligently chooses the most efficient storage: ### Query Methods - `len()`, `is_empty()`, `as_str()`, `as_bytes()` -- `starts_with()`, `ends_with()`, `contains()` +- `starts_with()`, `ends_with()`, `contains()` - Support both `&str` and `char` patterns - `find()`, `rfind()` ### Transformation @@ -141,8 +149,8 @@ CheetahString intelligently chooses the most efficient storage: - `substring()`, `repeat()` ### Iteration -- `chars()` - Iterate over characters -- `split()` - Split by pattern +- `chars()` - Iterate over characters (double-ended iterator) +- `split()` - Split by pattern (supports `&str` and `char`) - `lines()` - Iterate over lines ### Mutation diff --git a/benches/simd.rs b/benches/simd.rs new file mode 100644 index 0000000..453c1f2 --- /dev/null +++ b/benches/simd.rs @@ -0,0 +1,164 @@ +use cheetah_string::CheetahString; +use 
criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion, Throughput}; + +fn bench_equality(c: &mut Criterion) { + let mut group = c.benchmark_group("equality"); + + for size in [16, 32, 64, 128, 256, 512, 1024, 4096] { + let s1 = CheetahString::from("a".repeat(size)); + let s2 = CheetahString::from("a".repeat(size)); + let s3 = CheetahString::from(format!("{}b", "a".repeat(size - 1))); + + group.throughput(Throughput::Bytes(size as u64)); + + group.bench_with_input(BenchmarkId::new("equal", size), &size, |b, _| { + b.iter(|| black_box(&s1) == black_box(&s2)) + }); + + group.bench_with_input(BenchmarkId::new("not_equal", size), &size, |b, _| { + b.iter(|| black_box(&s1) == black_box(&s3)) + }); + } + + group.finish(); +} + +fn bench_starts_with(c: &mut Criterion) { + let mut group = c.benchmark_group("starts_with"); + + for size in [16, 32, 64, 128, 256, 512, 1024, 4096] { + let haystack = CheetahString::from("a".repeat(size)); + let needle_match = "a".repeat(size / 2); + let needle_no_match = "b".repeat(size / 2); + + group.throughput(Throughput::Bytes(size as u64)); + + group.bench_with_input(BenchmarkId::new("match", size), &size, |b, _| { + b.iter(|| black_box(&haystack).starts_with(black_box(&needle_match))) + }); + + group.bench_with_input(BenchmarkId::new("no_match", size), &size, |b, _| { + b.iter(|| black_box(&haystack).starts_with(black_box(&needle_no_match))) + }); + } + + group.finish(); +} + +fn bench_ends_with(c: &mut Criterion) { + let mut group = c.benchmark_group("ends_with"); + + for size in [16, 32, 64, 128, 256, 512, 1024, 4096] { + let haystack = CheetahString::from("a".repeat(size)); + let needle_match = "a".repeat(size / 2); + let needle_no_match = "b".repeat(size / 2); + + group.throughput(Throughput::Bytes(size as u64)); + + group.bench_with_input(BenchmarkId::new("match", size), &size, |b, _| { + b.iter(|| black_box(&haystack).ends_with(black_box(&needle_match))) + }); + + 
group.bench_with_input(BenchmarkId::new("no_match", size), &size, |b, _| { + b.iter(|| black_box(&haystack).ends_with(black_box(&needle_no_match))) + }); + } + + group.finish(); +} + +fn bench_contains(c: &mut Criterion) { + let mut group = c.benchmark_group("contains"); + + for size in [16, 32, 64, 128, 256, 512, 1024, 4096] { + let haystack = + CheetahString::from(format!("{}x{}", "a".repeat(size / 2), "a".repeat(size / 2))); + let needle_match = "x"; + let needle_no_match = "z"; + + group.throughput(Throughput::Bytes(size as u64)); + + group.bench_with_input(BenchmarkId::new("match", size), &size, |b, _| { + b.iter(|| black_box(&haystack).contains(black_box(needle_match))) + }); + + group.bench_with_input(BenchmarkId::new("no_match", size), &size, |b, _| { + b.iter(|| black_box(&haystack).contains(black_box(needle_no_match))) + }); + } + + group.finish(); +} + +fn bench_find(c: &mut Criterion) { + let mut group = c.benchmark_group("find"); + + for size in [16, 32, 64, 128, 256, 512, 1024, 4096] { + let haystack = + CheetahString::from(format!("{}x{}", "a".repeat(size / 2), "a".repeat(size / 2))); + let needle_match = "x"; + let needle_no_match = "z"; + + group.throughput(Throughput::Bytes(size as u64)); + + group.bench_with_input(BenchmarkId::new("match", size), &size, |b, _| { + b.iter(|| black_box(&haystack).find(black_box(needle_match))) + }); + + group.bench_with_input(BenchmarkId::new("no_match", size), &size, |b, _| { + b.iter(|| black_box(&haystack).find(black_box(needle_no_match))) + }); + } + + group.finish(); +} + +fn bench_realistic_workload(c: &mut Criterion) { + let mut group = c.benchmark_group("realistic"); + + // Simulate URL parsing + let url = CheetahString::from("https://api.example.com/v1/users/12345?filter=active&sort=name"); + + group.bench_function("url_parsing", |b| { + b.iter(|| { + black_box(&url).starts_with("https://") + && black_box(&url).contains("api") + && black_box(&url).contains("users") + }) + }); + + // Simulate log filtering 
+ let log = + CheetahString::from("[2024-01-01 12:00:00] INFO: Processing request for user_id=12345"); + + group.bench_function("log_filtering", |b| { + b.iter(|| { + black_box(&log).starts_with("[2024") + && black_box(&log).contains("INFO") + && black_box(&log).contains("user_id") + }) + }); + + // Simulate content type checking + let content_type = CheetahString::from("application/json; charset=utf-8"); + + group.bench_function("content_type_check", |b| { + b.iter(|| { + black_box(&content_type).starts_with("application/") + && black_box(&content_type).contains("json") + }) + }); + + group.finish(); +} + +criterion_group!( + benches, + bench_equality, + bench_starts_with, + bench_ends_with, + bench_contains, + bench_find, + bench_realistic_workload +); +criterion_main!(benches); diff --git a/examples/simd_demo.rs b/examples/simd_demo.rs new file mode 100644 index 0000000..c2f2bc2 --- /dev/null +++ b/examples/simd_demo.rs @@ -0,0 +1,102 @@ +// Example demonstrating SIMD-accelerated string operations in CheetahString +// Run with: cargo run --example simd_demo --features simd + +use cheetah_string::CheetahString; + +fn main() { + println!("CheetahString SIMD Demo"); + println!("=======================\n"); + + // Example 1: Equality comparison + println!("1. Equality Comparison:"); + let s1 = CheetahString::from("Hello, World! This is a SIMD-accelerated string comparison."); + let s2 = CheetahString::from("Hello, World! This is a SIMD-accelerated string comparison."); + let s3 = CheetahString::from("Hello, World! This is a different string."); + + println!(" s1 == s2: {}", s1 == s2); // true (uses SIMD for long strings) + println!(" s1 == s3: {}\n", s1 == s3); // false + + // Example 2: starts_with + println!("2. 
String Prefix Matching:"); + let url = CheetahString::from("https://api.example.com/v1/users/12345?filter=active&sort=name"); + println!(" URL: {}", url); + println!(" Starts with 'https://': {}", url.starts_with("https://")); + println!(" Starts with 'http://': {}\n", url.starts_with("http://")); + + // Example 3: ends_with + println!("3. String Suffix Matching:"); + let filename = CheetahString::from("document.pdf"); + println!(" Filename: {}", filename); + println!(" Ends with '.pdf': {}", filename.ends_with(".pdf")); + println!(" Ends with '.txt': {}\n", filename.ends_with(".txt")); + + // Example 4: contains + println!("4. Substring Search:"); + let log = CheetahString::from( + "[2024-01-01 12:00:00] INFO: Processing request for user_id=12345 from ip=192.168.1.100", + ); + println!(" Log entry: {}", log); + println!(" Contains 'INFO': {}", log.contains("INFO")); + println!(" Contains 'ERROR': {}", log.contains("ERROR")); + println!(" Contains 'user_id': {}\n", log.contains("user_id")); + + // Example 5: find + println!("5. Pattern Finding:"); + let text = CheetahString::from("The quick brown fox jumps over the lazy dog"); + println!(" Text: {}", text); + if let Some(pos) = text.find("fox") { + println!(" Found 'fox' at position: {}", pos); + } + if let Some(pos) = text.find("lazy") { + println!(" Found 'lazy' at position: {}", pos); + } + if text.find("cat").is_none() { + println!(" 'cat' not found\n"); + } + + // Example 6: Real-world use case - URL validation + println!("6. 
Real-world Use Case - URL Validation:"); + let urls = vec![ + "https://secure.example.com/api/v1/data", + "http://insecure.example.com/page", + "ftp://files.example.com/download", + ]; + + for url in urls { + let url_str = CheetahString::from(url); + let is_secure = url_str.starts_with("https://"); + let is_api = url_str.contains("/api/"); + println!( + " URL: {} - Secure: {}, API endpoint: {}", + url, is_secure, is_api + ); + } + println!(); + + // Example 7: Performance-sensitive pattern matching + println!("7. Log Processing Example:"); + let logs = vec![ + "[2024-01-01 10:00:00] ERROR: Database connection failed", + "[2024-01-01 10:01:00] INFO: Retrying connection...", + "[2024-01-01 10:02:00] INFO: Connection established", + "[2024-01-01 10:03:00] WARN: High memory usage detected", + ]; + + let mut errors = 0; + let mut warnings = 0; + + for log in logs { + let log_str = CheetahString::from(log); + if log_str.contains("ERROR") { + errors += 1; + println!(" Error found: {}", log); + } else if log_str.contains("WARN") { + warnings += 1; + println!(" Warning found: {}", log); + } + } + + println!("\n Summary: {} errors, {} warnings", errors, warnings); + println!("\nNote: When compiled with --features simd, these operations use SSE2 SIMD"); + println!(" instructions for improved performance on longer strings (>= 16 bytes)."); +} diff --git a/src/cheetah_string.rs b/src/cheetah_string.rs index a0d5a2a..dc07874 100644 --- a/src/cheetah_string.rs +++ b/src/cheetah_string.rs @@ -454,6 +454,9 @@ impl CheetahString { /// Returns `true` if the string starts with the given pattern. /// + /// When the `simd` feature is enabled, this method uses SIMD instructions + /// for improved performance on longer patterns. 
+ /// /// # Examples /// /// ``` @@ -468,7 +471,16 @@ impl CheetahString { pub fn starts_with(&self, pat: P) -> bool { match pat.as_str_pattern() { StrPatternImpl::Char(c) => self.as_str().starts_with(c), - StrPatternImpl::Str(s) => self.as_str().starts_with(s), + StrPatternImpl::Str(s) => { + #[cfg(feature = "simd")] + { + crate::simd::starts_with_bytes(self.as_bytes(), s.as_bytes()) + } + #[cfg(not(feature = "simd"))] + { + self.as_str().starts_with(s) + } + } } } @@ -490,6 +502,9 @@ impl CheetahString { /// Returns `true` if the string ends with the given pattern. /// + /// When the `simd` feature is enabled, this method uses SIMD instructions + /// for improved performance on longer patterns. + /// /// # Examples /// /// ``` @@ -504,7 +519,16 @@ impl CheetahString { pub fn ends_with(&self, pat: P) -> bool { match pat.as_str_pattern() { StrPatternImpl::Char(c) => self.as_str().ends_with(c), - StrPatternImpl::Str(s) => self.as_str().ends_with(s), + StrPatternImpl::Str(s) => { + #[cfg(feature = "simd")] + { + crate::simd::ends_with_bytes(self.as_bytes(), s.as_bytes()) + } + #[cfg(not(feature = "simd"))] + { + self.as_str().ends_with(s) + } + } } } @@ -526,6 +550,9 @@ impl CheetahString { /// Returns `true` if the string contains the given pattern. /// + /// When the `simd` feature is enabled, this method uses SIMD instructions + /// for improved performance on longer patterns. + /// /// # Examples /// /// ``` @@ -540,7 +567,16 @@ impl CheetahString { pub fn contains(&self, pat: P) -> bool { match pat.as_str_pattern() { StrPatternImpl::Char(c) => self.as_str().contains(c), - StrPatternImpl::Str(s) => self.as_str().contains(s), + StrPatternImpl::Str(s) => { + #[cfg(feature = "simd")] + { + crate::simd::find_bytes(self.as_bytes(), s.as_bytes()).is_some() + } + #[cfg(not(feature = "simd"))] + { + self.as_str().contains(s) + } + } } } @@ -562,6 +598,9 @@ impl CheetahString { /// Returns the byte index of the first occurrence of the pattern, or `None` if not found. 
/// + /// When the `simd` feature is enabled, this method uses SIMD instructions + /// for improved performance on longer patterns. + /// /// # Examples /// /// ``` @@ -573,7 +612,15 @@ impl CheetahString { /// ``` #[inline] pub fn find>(&self, pat: P) -> Option { - self.as_str().find(pat.as_ref()) + let pat = pat.as_ref(); + #[cfg(feature = "simd")] + { + crate::simd::find_bytes(self.as_bytes(), pat.as_bytes()) + } + #[cfg(not(feature = "simd"))] + { + self.as_str().find(pat) + } } /// Returns the byte index of the last occurrence of the pattern, or `None` if not found. @@ -903,21 +950,42 @@ impl CheetahString { impl PartialEq for CheetahString { #[inline] fn eq(&self, other: &Self) -> bool { - self.as_str() == other.as_str() + #[cfg(feature = "simd")] + { + crate::simd::eq_bytes(self.as_bytes(), other.as_bytes()) + } + #[cfg(not(feature = "simd"))] + { + self.as_str() == other.as_str() + } } } impl PartialEq for CheetahString { #[inline] fn eq(&self, other: &str) -> bool { - self.as_str() == other + #[cfg(feature = "simd")] + { + crate::simd::eq_bytes(self.as_bytes(), other.as_bytes()) + } + #[cfg(not(feature = "simd"))] + { + self.as_str() == other + } } } impl PartialEq for CheetahString { #[inline] fn eq(&self, other: &String) -> bool { - self.as_str() == other.as_str() + #[cfg(feature = "simd")] + { + crate::simd::eq_bytes(self.as_bytes(), other.as_bytes()) + } + #[cfg(not(feature = "simd"))] + { + self.as_str() == other.as_str() + } } } diff --git a/src/lib.rs b/src/lib.rs index 28e4f89..5a5e6de 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -4,7 +4,29 @@ //! It is usable in both `std` and `no_std` environments. Additionally, CheetahString supports serde for serialization and deserialization. //! CheetahString also supports the `bytes` feature, allowing conversion to the `bytes::Bytes` type. //! This reduces memory allocations during cloning, enhancing performance. -//! example: +//! +//! # SIMD Acceleration +//! +//! 
When compiled with the `simd` feature flag, CheetahString uses SIMD (Single Instruction, Multiple Data) +//! instructions to accelerate string matching operations on x86_64 platforms with SSE2 support. +//! SIMD acceleration is applied to: +//! - `starts_with()` - Pattern prefix matching +//! - `ends_with()` - Pattern suffix matching +//! - `contains()` / `find()` - Substring search +//! - Equality comparisons (`==`, `!=`) +//! +//! The implementation automatically uses SIMD for strings >= 16 bytes and falls back to scalar operations +//! for smaller inputs or when SIMD is not available. +//! +//! To enable SIMD acceleration: +//! ```toml +//! [dependencies] +//! cheetah-string = { version = "1.0.0", features = ["simd"] } +//! ``` +//! +//! # Examples +//! +//! Basic usage: //! ```rust //! use cheetah_string::CheetahString; //! @@ -18,11 +40,31 @@ //! //! ``` //! +//! Using SIMD-accelerated operations (when `simd` feature is enabled): +//! ```rust +//! use cheetah_string::CheetahString; +//! +//! let url = CheetahString::from("https://api.example.com/v1/users"); +//! +//! // SIMD only helps on longer inputs; short patterns like these may take the scalar path +//! if url.starts_with("https://") { +//! println!("Secure connection"); +//! } +//! +//! if url.contains("api") { +//! println!("API endpoint"); +//! } +//! ``` +//! mod cheetah_string; mod error; #[cfg(feature = "serde")] mod serde; +// Gated on the feature alone: callers use `cfg(feature = "simd")`, and simd.rs has scalar fallbacks. +#[cfg(feature = "simd")] +mod simd; + pub use cheetah_string::{CheetahString, SplitPattern, SplitStr, SplitWrapper, StrPattern}; pub use error::{Error, Result}; diff --git a/src/simd.rs b/src/simd.rs new file mode 100644 index 0000000..3381b78 --- /dev/null +++ b/src/simd.rs @@ -0,0 +1,278 @@ +//! SIMD-accelerated string operations +//! +//! This module provides SIMD implementations for common string operations +//! when the `simd` feature is enabled. It automatically falls back to +//! scalar implementations when SIMD is not available or for small inputs. 
// =============================================================================
// src/simd.rs — module body
// =============================================================================
//
// This module is only compiled when the `simd` cargo feature is on (see the
// `mod simd` declaration in lib.rs), so the items below gate on the *target*
// only. The SSE2 code is selected at compile time: SSE2 is part of the x86_64
// baseline, so `target_feature = "sse2"` is set on ordinary x86_64 targets and
// no runtime probe is needed. (`is_x86_feature_detected!` exists only in `std`
// and would break the crate's advertised `no_std` support.)

#[cfg(all(target_arch = "x86_64", target_feature = "sse2"))]
use core::arch::x86_64::{
    __m128i, _mm_cmpeq_epi8, _mm_loadu_si128, _mm_movemask_epi8, _mm_set1_epi8,
};

/// Width of one SSE2 register in bytes; shorter inputs stay on the scalar path.
#[cfg(all(target_arch = "x86_64", target_feature = "sse2"))]
const SIMD_THRESHOLD: usize = 16;

/// Returns `true` if `a` and `b` hold exactly the same bytes.
///
/// Slices of at least 16 bytes are compared with SSE2 when it is available at
/// compile time; everything else uses the standard slice comparison.
#[inline]
pub(crate) fn eq_bytes(a: &[u8], b: &[u8]) -> bool {
    if a.len() != b.len() {
        return false;
    }

    #[cfg(all(target_arch = "x86_64", target_feature = "sse2"))]
    {
        if a.len() >= SIMD_THRESHOLD {
            // SAFETY: SSE2 is statically enabled for this build (see the cfg gate).
            return unsafe { eq_bytes_sse2(a, b) };
        }
    }

    a == b
}

/// Returns `true` if `haystack` begins with `needle`.
///
/// An empty `needle` always matches; a `needle` longer than `haystack` never does.
#[inline]
pub(crate) fn starts_with_bytes(haystack: &[u8], needle: &[u8]) -> bool {
    // `get` yields `None` when the needle is longer than the haystack.
    haystack
        .get(..needle.len())
        .map_or(false, |prefix| eq_bytes(prefix, needle))
}

/// Returns `true` if `haystack` ends with `needle`.
///
/// An empty `needle` always matches; a `needle` longer than `haystack` never does.
#[inline]
pub(crate) fn ends_with_bytes(haystack: &[u8], needle: &[u8]) -> bool {
    // `checked_sub` yields `None` when the needle is longer than the haystack.
    haystack
        .len()
        .checked_sub(needle.len())
        .map_or(false, |start| eq_bytes(&haystack[start..], needle))
}

/// Returns the byte index of the first occurrence of `needle` in `haystack`.
///
/// An empty `needle` matches at index 0. When SSE2 is available, any haystack of
/// at least 16 bytes is searched with a vectorised scan for the needle's first
/// byte, regardless of the needle's length; shorter haystacks use a scalar scan.
#[inline]
pub(crate) fn find_bytes(haystack: &[u8], needle: &[u8]) -> Option<usize> {
    if needle.is_empty() {
        return Some(0);
    }

    if needle.len() > haystack.len() {
        return None;
    }

    #[cfg(all(target_arch = "x86_64", target_feature = "sse2"))]
    {
        // Gate on the haystack only: the win comes from skipping through the
        // haystack 16 bytes at a time, which helps short needles just as much.
        if haystack.len() >= SIMD_THRESHOLD {
            // SAFETY: SSE2 is statically enabled for this build (see the cfg gate).
            return unsafe { find_bytes_sse2(haystack, needle) };
        }
    }

    haystack
        .windows(needle.len())
        .position(|window| window == needle)
}

/// SSE2 equality check for two byte slices.
///
/// Returns `false` when the lengths differ, so every load below stays in bounds.
///
/// # Safety
///
/// The CPU must support SSE2.
#[cfg(all(target_arch = "x86_64", target_feature = "sse2"))]
#[target_feature(enable = "sse2")]
#[inline]
unsafe fn eq_bytes_sse2(a: &[u8], b: &[u8]) -> bool {
    if a.len() != b.len() {
        return false;
    }

    let len = a.len();
    let (a_ptr, b_ptr) = (a.as_ptr(), b.as_ptr());
    let mut offset = 0;

    // Compare 16 bytes per iteration.
    while offset + 16 <= len {
        // SAFETY: `offset + 16 <= len` and both slices are `len` bytes long, so
        // both unaligned 16-byte loads are in bounds.
        let (a_vec, b_vec) = unsafe {
            (
                _mm_loadu_si128(a_ptr.add(offset) as *const __m128i),
                _mm_loadu_si128(b_ptr.add(offset) as *const __m128i),
            )
        };

        // One mask bit per byte lane; all 16 set means the chunks are identical.
        if _mm_movemask_epi8(_mm_cmpeq_epi8(a_vec, b_vec)) != 0xFFFF {
            return false;
        }

        offset += 16;
    }

    // Fewer than 16 bytes remain.
    a[offset..] == b[offset..]
}

/// SSE2 substring search: scan for the needle's first byte, then verify the rest.
///
/// Returns `Some(0)` for an empty needle and `None` when the needle is longer
/// than the haystack.
///
/// # Safety
///
/// The CPU must support SSE2.
#[cfg(all(target_arch = "x86_64", target_feature = "sse2"))]
#[target_feature(enable = "sse2")]
#[inline]
unsafe fn find_bytes_sse2(haystack: &[u8], needle: &[u8]) -> Option<usize> {
    let (first, rest) = match needle.split_first() {
        Some((&first, rest)) => (first, rest),
        None => return Some(0),
    };

    if needle.len() > haystack.len() {
        return None;
    }

    // Last index at which a full match can still start.
    let last_start = haystack.len() - needle.len();
    let mut pos = 0;

    while pos <= last_start {
        // Only scan positions where the whole needle still fits, so every hit is
        // a valid candidate and needs no further bounds check.
        // SAFETY: this function's own contract guarantees SSE2 is available.
        let hit = unsafe { find_byte_sse2(&haystack[pos..=last_start], first) };
        let candidate = match hit {
            Some(found) => pos + found,
            None => return None,
        };

        // The first byte already matched; compare the remainder (empty for a
        // one-byte needle, which therefore matches immediately).
        if eq_bytes(&haystack[candidate + 1..candidate + needle.len()], rest) {
            return Some(candidate);
        }

        pos = candidate + 1;
    }

    None
}

/// SSE2 search for the first occurrence of a single byte.
///
/// # Safety
///
/// The CPU must support SSE2.
#[cfg(all(target_arch = "x86_64", target_feature = "sse2"))]
#[target_feature(enable = "sse2")]
#[inline]
unsafe fn find_byte_sse2(haystack: &[u8], needle: u8) -> Option<usize> {
    let len = haystack.len();
    let base = haystack.as_ptr();
    let mut offset = 0;

    // Broadcast the needle to all 16 lanes. `as i8` only reinterprets the bits,
    // which is all the byte-wise equality compare looks at.
    let needle_vec = _mm_set1_epi8(needle as i8);

    while offset + 16 <= len {
        // SAFETY: `offset + 16 <= len` keeps the unaligned 16-byte load in bounds.
        let chunk = unsafe { _mm_loadu_si128(base.add(offset) as *const __m128i) };
        let mask = _mm_movemask_epi8(_mm_cmpeq_epi8(chunk, needle_vec));

        if mask != 0 {
            // The lowest set bit is the first matching byte in this chunk.
            return Some(offset + mask.trailing_zeros() as usize);
        }

        offset += 16;
    }

    // Fewer than 16 bytes remain.
    haystack[offset..]
        .iter()
        .position(|&byte| byte == needle)
        .map(|found| offset + found)
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_eq_bytes() {
        let a = b"hello world, this is a test";
        let b = b"hello world, this is a test";
        let c = b"hello world, this is b test";

        assert!(eq_bytes(a, b));
        assert!(!eq_bytes(a, c));
        assert!(!eq_bytes(&a[..10], a));
    }

    #[test]
    fn test_starts_with_bytes() {
        let haystack = b"hello world, this is a test";
        assert!(starts_with_bytes(haystack, b"hello"));
        assert!(starts_with_bytes(haystack, b"hello world"));
        assert!(!starts_with_bytes(haystack, b"world"));
        assert!(starts_with_bytes(haystack, b""));
    }

    #[test]
    fn test_ends_with_bytes() {
        let haystack = b"hello world, this is a test";
        assert!(ends_with_bytes(haystack, b"test"));
        assert!(ends_with_bytes(haystack, b"a test"));
        assert!(!ends_with_bytes(haystack, b"hello"));
        assert!(ends_with_bytes(haystack, b""));
    }

    #[test]
    fn test_find_bytes() {
        let haystack = b"hello world, this is a test";
        assert_eq!(find_bytes(haystack, b"world"), Some(6));
        assert_eq!(find_bytes(haystack, b"test"), Some(23));
        assert_eq!(find_bytes(haystack, b"xyz"), None);
        assert_eq!(find_bytes(haystack, b""), Some(0));
    }

    #[test]
    fn test_find_short_needle_in_long_haystack() {
        // Short needles on long haystacks go through the first-byte SIMD scan.
        let mut haystack = [b'a'; 64];
        haystack[61] = b'x';
        haystack[62] = b'y';
        haystack[63] = b'z';
        assert_eq!(find_bytes(&haystack, b"x"), Some(61));
        assert_eq!(find_bytes(&haystack, b"xyz"), Some(61));
        assert_eq!(find_bytes(&haystack, b"yz"), Some(62));
        assert_eq!(find_bytes(&haystack, b"ab"), None);
    }

    // `find_byte_sse2` only exists where the SSE2 path is compiled in.
    #[cfg(all(target_arch = "x86_64", target_feature = "sse2"))]
    #[test]
    fn test_find_byte() {
        let haystack = b"hello world";
        // SAFETY: SSE2 is statically enabled for this build (see the cfg gate).
        unsafe {
            assert_eq!(find_byte_sse2(haystack, b'w'), Some(6));
            assert_eq!(find_byte_sse2(haystack, b'h'), Some(0));
            assert_eq!(find_byte_sse2(haystack, b'd'), Some(10));
            assert_eq!(find_byte_sse2(haystack, b'x'), None);
        }
    }

    #[test]
    fn test_simd_threshold() {
        // Test with strings below SIMD threshold
        let small_a = b"hello";
        let small_b = b"hello";
        assert!(eq_bytes(small_a, small_b));

        // Test with strings above SIMD threshold
        let large_a = b"this is a longer string that exceeds the SIMD threshold";
        let large_b = b"this is a longer string that exceeds the SIMD threshold";
        assert!(eq_bytes(large_a, large_b));
    }
}

// =============================================================================
// tests/simd.rs — integration tests through the public `CheetahString` API.
//
// In the patch this is a separate file whose first line is the file-level gate
// `#![cfg(all(feature = "simd", target_arch = "x86_64"))]`. It is wrapped in a
// module here only so that gate can be written as an outer attribute.
// =============================================================================
#[cfg(all(test, feature = "simd", target_arch = "x86_64"))]
mod integration_tests {
    use cheetah_string::CheetahString;

    #[test]
    fn test_simd_equality() {
        // Short strings (below SIMD threshold)
        let s1 = CheetahString::from("hello");
        let s2 = CheetahString::from("hello");
        let s3 = CheetahString::from("world");
        assert_eq!(s1, s2);
        assert_ne!(s1, s3);

        // Long strings (above SIMD threshold)
        let long1 = CheetahString::from("a".repeat(1024));
        let long2 = CheetahString::from("a".repeat(1024));
        let long3 = CheetahString::from(format!("{}b", "a".repeat(1023)));
        assert_eq!(long1, long2);
        assert_ne!(long1, long3);

        // Equality with str
        assert_eq!(s1, "hello");
        assert_ne!(s1, "world");

        // Equality with String
        assert_eq!(s1, String::from("hello"));
        assert_ne!(s1, String::from("world"));
    }

    #[test]
    fn test_simd_starts_with() {
        let s = CheetahString::from("hello world, this is a test");

        // Short patterns
        assert!(s.starts_with("hello"));
        assert!(s.starts_with("hello world"));
        assert!(!s.starts_with("world"));

        // Long patterns (above SIMD threshold)
        let long = CheetahString::from("a".repeat(1024));
        assert!(long.starts_with(&"a".repeat(100)));
        assert!(long.starts_with(&"a".repeat(500)));
        assert!(!long.starts_with(&"b".repeat(100)));

        // Edge cases
        assert!(s.starts_with(""));
        let empty = CheetahString::from("");
        assert!(empty.starts_with(""));
        assert!(!empty.starts_with("a"));
    }

    #[test]
    fn test_simd_ends_with() {
        let s = CheetahString::from("hello world, this is a test");

        // Short patterns
        assert!(s.ends_with("test"));
        assert!(s.ends_with("a test"));
        assert!(!s.ends_with("hello"));

        // Long patterns (above SIMD threshold)
        let long = CheetahString::from("a".repeat(1024));
        assert!(long.ends_with(&"a".repeat(100)));
        assert!(long.ends_with(&"a".repeat(500)));
        assert!(!long.ends_with(&"b".repeat(100)));

        // Edge cases
        assert!(s.ends_with(""));
        let empty = CheetahString::from("");
        assert!(empty.ends_with(""));
        assert!(!empty.ends_with("a"));
    }

    #[test]
    fn test_simd_contains() {
        let s = CheetahString::from("hello world, this is a test");

        // Short patterns
        assert!(s.contains("world"));
        assert!(s.contains("this"));
        assert!(!s.contains("xyz"));

        // Long strings
        let long = CheetahString::from(format!("{}needle{}", "a".repeat(500), "a".repeat(500)));
        assert!(long.contains("needle"));
        assert!(!long.contains("haystack"));

        // Edge cases
        assert!(s.contains(""));
        let empty = CheetahString::from("");
        assert!(empty.contains(""));
        assert!(!empty.contains("a"));
    }

    #[test]
    fn test_simd_find() {
        let s = CheetahString::from("hello world, this is a test");

        // Short patterns
        assert_eq!(s.find("world"), Some(6));
        assert_eq!(s.find("this"), Some(13));
        assert_eq!(s.find("test"), Some(23));
        assert_eq!(s.find("xyz"), None);

        // Long strings
        let long = CheetahString::from(format!("{}needle{}", "a".repeat(500), "a".repeat(500)));
        assert_eq!(long.find("needle"), Some(500));
        assert_eq!(long.find("haystack"), None);

        // Edge cases
        assert_eq!(s.find(""), Some(0));
        let empty = CheetahString::from("");
        assert_eq!(empty.find(""), Some(0));
        assert_eq!(empty.find("a"), None);

        // First character
        assert_eq!(s.find("h"), Some(0));
        // First 't' character appears in "this"
        assert_eq!(s.find("t"), Some(13));
    }

    #[test]
    fn test_simd_unicode() {
        // Test with unicode strings
        let s = CheetahString::from("Hello 世界! 🌍");

        assert!(s.starts_with("Hello"));
        assert!(s.contains("世界"));
        assert!(s.ends_with("🌍"));
        assert_eq!(s.find("世界"), Some(6));

        // Test equality with unicode
        let s1 = CheetahString::from("世界");
        let s2 = CheetahString::from("世界");
        assert_eq!(s1, s2);
    }

    #[test]
    fn test_simd_aligned_and_unaligned() {
        // Test with different alignment scenarios
        for offset in 0..16 {
            let prefix = "x".repeat(offset);
            let content = "a".repeat(100);
            let s = CheetahString::from(format!("{}{}", prefix, content));

            assert!(s.starts_with(&prefix));
            assert!(s.contains(&content));
            assert!(s.ends_with(&content));
        }
    }

    #[test]
    fn test_simd_boundary_conditions() {
        // Test strings of exactly SIMD_THRESHOLD length (16 bytes)
        let s16 = CheetahString::from("0123456789abcdef"); // exactly 16 bytes
        assert!(s16.starts_with("0123456789abcdef"));
        assert!(s16.ends_with("0123456789abcdef"));
        assert!(s16.contains("0123456789abcdef"));
        assert_eq!(s16, "0123456789abcdef");

        // Test strings just below and above threshold
        let s15 = CheetahString::from("0123456789abcde"); // 15 bytes
        let s17 = CheetahString::from("0123456789abcdefg"); // 17 bytes

        assert!(s15.starts_with("0123456789abcde"));
        assert!(s17.starts_with("0123456789abcdefg"));
    }

    #[test]
    fn test_simd_pattern_at_end() {
        // Test finding pattern at the very end
        let s = CheetahString::from("aaaaaaaaaaaaaaab");
        assert_eq!(s.find("b"), Some(15));
        assert!(s.ends_with("b"));

        // Test with longer strings
        let long = CheetahString::from(format!("{}end", "a".repeat(1000)));
        assert_eq!(long.find("end"), Some(1000));
        assert!(long.ends_with("end"));
    }

    #[test]
    fn test_simd_multiple_occurrences() {
        // Test that find returns the first occurrence
        let s = CheetahString::from("abcabcabc");
        assert_eq!(s.find("abc"), Some(0));
        assert_eq!(s.find("bc"), Some(1));
        assert_eq!(s.find("ca"), Some(2));
    }

    #[test]
    fn test_simd_inline_storage() {
        // Test with inline-stored strings (≤ 23 bytes)
        let inline = CheetahString::from("short string");
        assert!(inline.starts_with("short"));
        assert!(inline.contains("string"));
        assert!(inline.ends_with("string"));
        assert_eq!(inline.find("string"), Some(6));

        let inline2 = CheetahString::from("short string");
        assert_eq!(inline, inline2);
    }
}