diff --git a/.cargo/config.toml b/.cargo/config.toml new file mode 100644 index 0000000..ddff440 --- /dev/null +++ b/.cargo/config.toml @@ -0,0 +1,2 @@ +[build] +rustflags = ["-C", "target-cpu=native"] diff --git a/Cargo.lock b/Cargo.lock index c023ae2..3cc4040 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -39,14 +39,22 @@ dependencies = [ "rand", "rand_chacha", "rand_distr", + "wide", ] +[[package]] +name = "bytemuck" +version = "1.25.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c8efb64bd706a16a1bdde310ae86b351e4d21550d98d056f22f8a7f7a2183fec" + [[package]] name = "capt" version = "3.0.2" dependencies = [ "aligned-vec", "rand", + "wide", ] [[package]] @@ -244,6 +252,15 @@ version = "1.0.22" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b39cdef0fa800fc44525c84ccb54a029961a8215f9619753635a9c0d2538d46d" +[[package]] +name = "safe_arch" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1f7caad094bd561859bcd467734a720c3c1f5d1f338995351fefe2190c45efed" +dependencies = [ + "bytemuck", +] + [[package]] name = "shlex" version = "1.3.0" @@ -273,6 +290,16 @@ version = "1.0.22" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9312f7c4f6ff9069b165498234ce8be658059c6728633667c526e27dc2cf1df5" +[[package]] +name = "wide" +version = "1.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ac11b009ebeae802ed758530b6496784ebfee7a87b9abfbcaf3bbe25b814eb25" +dependencies = [ + "bytemuck", + "safe_arch", +] + [[package]] name = "windows-link" version = "0.2.1" diff --git a/bench/Cargo.toml b/bench/Cargo.toml index 3a02ef7..3b2cb3b 100644 --- a/bench/Cargo.toml +++ b/bench/Cargo.toml @@ -7,6 +7,7 @@ publish = false [dependencies] capt = { path = "../capt", features = ["simd"] } +wide = { version = "1.1.1", default-features = false } morton_filter = { path = "../morton_filter" } kiddo = { version = "5.2.2", features = ["simd"], default-features = false } rand = { version = "0.9.1", default-features = false } diff --git a/bench/src/bin/correctness.rs b/bench/src/bin/correctness.rs index 843165b..35b3709 100644 --- a/bench/src/bin/correctness.rs +++ b/bench/src/bin/correctness.rs @@ -1,8 +1,5 @@ -#![feature(portable_simd)] - -use std::simd::Simd; - use bench::{dist, kdt::PkdTree, parse_pointcloud_csv, parse_trace_csv, trace_r_range}; +use wide::f32x8; use capt::Capt; use kiddo::SquaredEuclidean; use rand::{seq::SliceRandom, Rng, SeedableRng}; @@ -65,19 +62,19 @@ fn main() -> Result<(), Box> { let exact_dist = dist(kdt.get_point(kdt.query1_exact(*center)), *center); assert_eq!(exact_dist, exact_kiddo_dist); - let simd_center: [Simd; 3] = [ - Simd::splat(center[0]), - Simd::splat(center[1]), - Simd::splat(center[2]), + let simd_center: [f32x8; 3] = [ + f32x8::splat(center[0]), + f32x8::splat(center[1]), + f32x8::splat(center[2]), ]; if exact_dist <= *r { println!("iter {i}: {:?} (collides)", (center, r)); assert!(aff_tree.collides(center, *r)); - assert!(aff_tree.collides_simd(&simd_center, Simd::splat(*r))) + assert!(aff_tree.collides_simd(&simd_center, f32x8::splat(*r))) } else { println!("iter {i}: {:?} (no collides)", (center, r)); assert!(!aff_tree.collides(center, *r)); - assert!(!aff_tree.collides_simd(&simd_center, Simd::splat(*r))) + assert!(!aff_tree.collides_simd(&simd_center, f32x8::splat(*r))) } } diff --git a/bench/src/bin/error.rs b/bench/src/bin/error.rs index 7a2b8ae..ff8855d 100644 --- a/bench/src/bin/error.rs +++ b/bench/src/bin/error.rs @@ -1,22 +1,19 @@ -#![feature(portable_simd)] - use bench::{dist, fuzz_pointcloud, get_points, kdt::PkdTree, make_needles}; use kiddo::SquaredEuclidean; use rand::{Rng, SeedableRng}; use rand_chacha::ChaCha20Rng; const N: usize = 1 << 16; -const L: usize = 16; const D: usize = 3; fn main() { let mut rng = ChaCha20Rng::seed_from_u64(2707); let mut starting_points = get_points(N); fuzz_pointcloud(&mut starting_points, 0.001, &mut rng); - measure_error::(&starting_points, &mut rng, 1 << 16) + measure_error::(&starting_points, &mut rng, 1 << 16) } -pub fn measure_error( +pub fn measure_error( points: &[[f32; D]], rng: &mut impl Rng, n_trials: usize, @@ -27,7 +24,7 @@ pub fn measure_error( kiddo_kdt.add(pt, 0); } - let (seq_needles, _) = make_needles::(rng, n_trials); + let (seq_needles, _) = make_needles::(rng, n_trials); for seq_needle in seq_needles { let exact_kiddo_dist = kiddo_kdt diff --git a/bench/src/bin/forest_error.rs b/bench/src/bin/forest_error.rs index c5bb3b9..fffac97 100644 --- a/bench/src/bin/forest_error.rs +++ b/bench/src/bin/forest_error.rs @@ -30,7 +30,7 @@ fn err_forest(points: &[[f32; 3]], rng: &mut impl Rng) { kiddo_kdt.add(pt, 0); } - let (seq_needles, _) = make_needles::<3, 1>(rng, 10_000); + let (seq_needles, _) = make_needles::<3>(rng, 10_000); let mut total_err = 0.0; for &needle in &seq_needles { diff --git a/bench/src/bin/perf_plots.rs b/bench/src/bin/perf_plots.rs index d4174bf..f9a30f2 100644 --- a/bench/src/bin/perf_plots.rs +++ b/bench/src/bin/perf_plots.rs @@ -1,14 +1,10 @@ -#![feature(portable_simd)] - -use std::{ - cmp::min, env::args, error::Error, fs::File, hint::black_box, io::Write, simd::f32x8, - time::Duration, -}; +use std::{cmp::min, env::args, error::Error, fs::File, hint::black_box, io::Write, time::Duration}; use bench::{ forest::PkdForest, fuzz_pointcloud, kdt::PkdTree, parse_pointcloud_csv, parse_trace_csv, simd_trace_new, stopwatch, SimdTrace, Trace, }; +use wide::f32x8; use capt::Capt; #[allow(unused_imports)] use kiddo::SquaredEuclidean; @@ -17,13 +13,12 @@ use rand::{seq::SliceRandom, Rng}; use rand_chacha::{rand_core::SeedableRng, ChaCha20Rng}; const N_TRIALS: usize = 100_000; -const L: usize = 8; const QUERY_RADIUS: f32 = 0.05; struct Benchmark<'a> { seq: &'a Trace, - simd: &'a SimdTrace, + simd: &'a SimdTrace, f_query: File, } @@ -89,7 +84,7 @@ fn main() -> Result<(), Box> { println!("number of tests: {}", all_trace.len()); println!("radius range: {r_range:?}"); - let captree = Capt::<3>::new(&points, r_range, L); + let captree = Capt::<3>::new(&points, r_range, 8); let collide_trace: Box = all_trace .iter() @@ -156,7 +151,7 @@ fn do_row( let (pkdt, pkdt_time) = stopwatch(|| PkdTree::new(points)); - let (captree, captree_time) = stopwatch(|| Capt::<3, f32, u32>::new(points, r_range, L)); + let (captree, captree_time) = stopwatch(|| Capt::<3, f32, u32>::new(points, r_range, 8)); let (f1, f1_time) = stopwatch(|| PkdForest::<3, 1>::new(points)); let (f2, f2_time) = stopwatch(|| PkdForest::<3, 2>::new(points)); @@ -224,7 +219,7 @@ fn do_row( }); let (_, pkdt_total_simd_q_time) = stopwatch(|| { for (centers, radii) in simd_trace.iter() { - black_box(pkdt.might_collide_simd(centers, radii * radii)); + black_box(pkdt.might_collide_simd(centers, *radii * *radii)); } }); let (_, captree_total_seq_q_time) = stopwatch(|| { @@ -234,7 +229,7 @@ fn do_row( }); let (_, captree_total_simd_q_time) = stopwatch(|| { for (centers, radii) in simd_trace.iter() { - black_box(captree.collides_simd(centers, radii * radii)); + black_box(captree.collides_simd(centers, *radii * *radii)); } }); @@ -280,7 +275,7 @@ fn bench_forest( ) -> Duration { stopwatch(|| { for (centers, radii) in simd_trace { - black_box(forest.might_collide_simd(centers, radii * radii)); + black_box(forest.might_collide_simd(centers, *radii * *radii)); } }) .1 diff --git a/bench/src/forest.rs b/bench/src/forest.rs index 92dccb3..8d3ee6f 100644 --- a/bench/src/forest.rs +++ b/bench/src/forest.rs @@ -1,8 +1,7 @@ //! Power-of-two k-d forests. -use std::simd::{cmp::SimdPartialOrd, ptr::SimdConstPtr, Mask, Simd}; - use crate::{distsq, median_partition}; +use wide::{f32x8, i32x8, CmpGe}; #[derive(Clone, Debug)] struct RandomizedTree { @@ -53,26 +52,25 @@ impl PkdForest { } #[must_use] - pub fn might_collide_simd( - &self, - needles: &[Simd; K], - radii_squared: Simd, - ) -> bool { - let mut not_yet_collided = Mask::splat(true); + #[allow(clippy::cast_sign_loss)] + pub fn might_collide_simd(&self, needles: &[f32x8; K], radii_squared: f32x8) -> bool { + // all_true: f32x8 bitmask where all lanes are "not yet collided" + let all_true: f32x8 = + unsafe { core::mem::transmute::(i32x8::splat(-1_i32)) }; + let mut not_yet_collided = all_true; for tree in &self.test_seqs { - let indices = tree.mask_query(needles, not_yet_collided); - let mut dists_sq = Simd::splat(0.0); - let mut ptrs = Simd::splat(tree.points.as_ptr().cast()).wrapping_offset(indices); - for needle_set in needles { - let diffs = - unsafe { Simd::gather_select_ptr(ptrs, not_yet_collided, Simd::splat(0.0)) } - - needle_set; - dists_sq += diffs * diffs; - ptrs = ptrs.wrapping_add(Simd::splat(1)); + let indices = tree.forward_pass_wide(needles); + let idx_arr = indices.to_array(); + let mut dists_sq = f32x8::ZERO; + for (k, needle_values) in needles.iter().enumerate() { + let vals = f32x8::new(idx_arr.map(|i| tree.points[i as usize][k])); + let diffs = vals - needle_values; + dists_sq = dists_sq + diffs * diffs; } - not_yet_collided &= radii_squared.simd_lt(dists_sq).cast(); + // lanes where dists_sq >= radii_squared have not (yet) collided + not_yet_collided = not_yet_collided & dists_sq.simd_ge(radii_squared); if !not_yet_collided.all() { // at least one has collided - can return quickly @@ -145,37 +143,24 @@ impl RandomizedTree { test_idx - self.tests.len() } - #[allow(clippy::cast_possible_truncation, clippy::cast_possible_wrap)] - /// Perform a masked SIMD query of this tree, only determining the location of the nearest - /// neighbors for points in `mask`. - fn mask_query( - &self, - needles: &[Simd; K], - mask: Mask, - ) -> Simd { - let mut test_idxs: Simd = Simd::splat(0); + #[allow(clippy::cast_sign_loss)] + fn forward_pass_wide(&self, needles: &[f32x8; K]) -> i32x8 { + let mut test_idxs = i32x8::splat(0_i32); let mut state = self.seed; - // Advance the tests forward for _ in 0..self.tests.len().trailing_ones() { - let relevant_tests: Simd = unsafe { - Simd::gather_select_ptr( - Simd::splat(self.tests.as_ptr().cast()).wrapping_offset(test_idxs), - mask, - Simd::splat(f32::NAN), - ) - }; + let idx_arr = test_idxs.to_array(); + let relevant_tests = + f32x8::new(idx_arr.map(|i| unsafe { *self.tests.get_unchecked(i as usize) })); let d = state as usize % K; - let cmp_results: Mask = (needles[d].simd_ge(relevant_tests)).into(); - - // TODO is there a faster way than using a conditional select? - test_idxs <<= Simd::splat(1); - test_idxs += Simd::splat(1); - test_idxs += cmp_results.to_simd() & Simd::splat(1); + let cmp_f = needles[d].simd_ge(relevant_tests); + let cmp_bit: i32x8 = + unsafe { core::mem::transmute::(cmp_f) } & i32x8::splat(1); + test_idxs = (test_idxs << 1_i32) + 1_i32 + cmp_bit; state = xorshift(state); } - test_idxs - Simd::splat(self.tests.len() as isize) + test_idxs - i32x8::splat(self.tests.len() as i32) } } diff --git a/bench/src/kdt.rs b/bench/src/kdt.rs index cb05fc1..3eed0e8 100644 --- a/bench/src/kdt.rs +++ b/bench/src/kdt.rs @@ -1,10 +1,8 @@ -use std::{ - mem::size_of, - simd::{cmp::SimdPartialOrd, num::SimdInt, ptr::SimdConstPtr, Mask, Simd}, -}; +use std::mem::size_of; use crate::{distsq, forward_pass, median_partition}; -use capt::{Aabb, AxisSimd, AxisSimdElement}; +use capt::Aabb; +use wide::{f32x8, i32x8, CmpGe, CmpLt}; #[derive(Clone, Debug, PartialEq)] /// A power-of-two KD-tree. @@ -81,27 +79,21 @@ impl PkdTree { } #[must_use] - #[allow(clippy::cast_possible_wrap)] - pub fn might_collide_simd( - &self, - needles: &[Simd; K], - radii_squared: Simd, - ) -> bool { - let indices = forward_pass_simd(&self.tests, needles); - let mut dists_squared = Simd::splat(0.0); - let mut ptrs = - Simd::splat(self.points.as_ptr().cast()).wrapping_add(indices * Simd::splat(K)); - for needle_values in needles { - let deltas = unsafe { Simd::gather_ptr(ptrs) } - needle_values; - dists_squared += deltas * deltas; - ptrs = ptrs.wrapping_add(Simd::splat(1)); + #[allow(clippy::cast_sign_loss)] + pub fn might_collide_simd(&self, needles: &[f32x8; K], radii_squared: f32x8) -> bool { + let indices = forward_pass_wide(&self.tests, needles); + let idx_arr = indices.to_array(); + let mut dists_squared = f32x8::ZERO; + for (k, needle_values) in needles.iter().enumerate() { + let vals = f32x8::new(idx_arr.map(|i| self.points[i as usize][k])); + let deltas = vals - needle_values; + dists_squared = dists_squared + deltas * deltas; } dists_squared.simd_lt(radii_squared).any() } #[must_use] - #[allow(clippy::cast_possible_wrap, clippy::cast_sign_loss)] - /// Query for one point in this tree, returning an exact answer. + #[allow(clippy::cast_possible_wrap, clippy::missing_panics_doc)] pub fn query1_exact(&self, needle: [f32; K]) -> usize { let mut id = usize::MAX; let mut best_distsq = f32::INFINITY; @@ -191,40 +183,33 @@ impl PkdTree { #[must_use] #[allow(clippy::missing_panics_doc)] - pub const fn get_point(&self, id: usize) -> [f32; K] { + pub fn get_point(&self, id: usize) -> [f32; K] { self.points[id] } #[must_use] /// Return the total memory used (stack + heap) by this structure. - pub const fn memory_used(&self) -> usize { + pub fn memory_used(&self) -> usize { size_of::() + (self.points.len() * K + self.tests.len()) * size_of::() } } #[inline] -#[allow(clippy::cast_possible_wrap)] -fn forward_pass_simd( - tests: &[A], - centers: &[Simd; K], -) -> Simd -where - Simd: AxisSimd, - A: AxisSimdElement, -{ - let mut i: Simd = Simd::splat(0); +#[allow(clippy::cast_sign_loss)] +fn forward_pass_wide(tests: &[f32], centers: &[f32x8; K]) -> i32x8 { + let mut test_idxs = i32x8::splat(0_i32); let mut k = 0; for _ in 0..tests.len().trailing_ones() { - let test_ptrs = Simd::splat(tests.as_ptr()).wrapping_add(i); - let relevant_tests = unsafe { Simd::gather_ptr(test_ptrs) }; - let cmp: Mask = Simd::::cast_mask(centers[k].simd_ge(relevant_tests)); - - let one = Simd::splat(1); - i = (i << one) + one + (cmp.to_simd().cast() & one); + let idx_arr = test_idxs.to_array(); + let relevant_tests = + f32x8::new(idx_arr.map(|i| unsafe { *tests.get_unchecked(i as usize) })); + let cmp_f = centers[k % K].simd_ge(relevant_tests); + let cmp_bit: i32x8 = + unsafe { core::mem::transmute::(cmp_f) } & i32x8::splat(1); + test_idxs = (test_idxs << 1_i32) + 1_i32 + cmp_bit; k = (k + 1) % K; } - - i - Simd::splat(tests.len()) + test_idxs - i32x8::splat(tests.len() as i32) } #[cfg(test)] @@ -274,11 +259,13 @@ mod tests { ]; let kdt = PkdTree::new(&points); - let needles = [Simd::from_array([-1.0, 2.0]), Simd::from_array([-1.0, 2.0])]; - assert_eq!( - forward_pass_simd(&kdt.tests, &needles), - Simd::from_array([0, points.len() - 1]) - ); + let needles = [ + f32x8::new([-1.0, 2.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]), + f32x8::new([-1.0, 2.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]), + ]; + let result = forward_pass_wide(&kdt.tests, &needles).to_array(); + assert_eq!(result[0], 0); + assert_eq!(result[1], (points.len() - 1) as i32); } #[test] diff --git a/bench/src/lib.rs b/bench/src/lib.rs index 72ca390..b445903 100644 --- a/bench/src/lib.rs +++ b/bench/src/lib.rs @@ -1,16 +1,14 @@ -#![feature(portable_simd)] - use std::{ env, error::Error, path::Path, - simd::Simd, time::{Duration, Instant}, }; use capt::Axis; use rand::{Rng, SeedableRng}; use rand_chacha::ChaCha20Rng; +use wide::f32x8; use rand_distr::{Distribution, Normal}; @@ -44,34 +42,33 @@ pub fn get_points(n_points_if_no_cloud: usize) -> Box<[[f32; 3]]> { /// # Generic parameters /// /// - `D`: the dimension of the space -/// - `L`: the number of SIMD lanes /// /// # Returns /// /// Returns a pair `(seq_needles, simd_needles)`, where `seq_needles` is correctly shaped for -/// sequential querying and `simd_needles` is correctly shaped for SIMD querying. -pub fn make_needles( +/// sequential querying and `simd_needles` is correctly shaped for SIMD querying (8 lanes). +pub fn make_needles( rng: &mut impl Rng, n_trials: usize, -) -> (Vec<[f32; D]>, Vec<[Simd; D]>) { +) -> (Vec<[f32; D]>, Vec<[f32x8; D]>) { let mut seq_needles = Vec::new(); let mut simd_needles = Vec::new(); - for _ in 0..n_trials / L { - let mut simd_pts = [Simd::splat(0.0); D]; - for l in 0..L { + for _ in 0..n_trials / 8 { + let mut simd_pts = [[0.0_f32; 8]; D]; + for l in 0..8 { let mut seq_needle = [0.0; D]; - for d in 0..3 { + for d in 0..D { let value = rng.random_range::(0.0..1.0); seq_needle[d] = value; - simd_pts[d].as_mut_array()[l] = value; + simd_pts[d][l] = value; } seq_needles.push(seq_needle); } - simd_needles.push(simd_pts); + simd_needles.push(simd_pts.map(f32x8::new)); } - assert_eq!(seq_needles.len(), simd_needles.len() * L); + assert_eq!(seq_needles.len(), simd_needles.len() * 8); (seq_needles, simd_needles) } @@ -82,39 +79,38 @@ pub fn make_needles( /// # Generic parameters /// /// - `D`: the dimension of the space -/// - `L`: the number of SIMD lanes /// /// # Returns /// /// Returns a pair `(seq_needles, simd_needles)`, where `seq_needles` is correctly shaped for -/// sequential querying and `simd_needles` is correctly shaped for SIMD querying. +/// sequential querying and `simd_needles` is correctly shaped for SIMD querying (8 lanes). /// Additionally, each element of each element of `simd_needles` will be relatively close in space. -pub fn make_correlated_needles( +pub fn make_correlated_needles( rng: &mut impl Rng, n_trials: usize, -) -> (Vec<[f32; D]>, Vec<[Simd; D]>) { +) -> (Vec<[f32; D]>, Vec<[f32x8; D]>) { let mut seq_needles = Vec::new(); let mut simd_needles = Vec::new(); - for _ in 0..n_trials / L { + for _ in 0..n_trials / 8 { let mut start_pt = [0.0; D]; for v in start_pt.iter_mut() { *v = rng.random_range::(0.0..1.0); } - let mut simd_pts = [Simd::splat(0.0); D]; - for l in 0..L { + let mut simd_pts = [[0.0_f32; 8]; D]; + for l in 0..8 { let mut seq_needle = [0.0; D]; for d in 0..D { let value = start_pt[d] + rng.random_range::(-0.02..0.02); seq_needle[d] = value; - simd_pts[d].as_mut_array()[l] = value; + simd_pts[d][l] = value; } seq_needles.push(seq_needle); } - simd_needles.push(simd_pts); + simd_needles.push(simd_pts.map(f32x8::new)); } - assert_eq!(seq_needles.len(), simd_needles.len() * L); + assert_eq!(seq_needles.len(), simd_needles.len() * 8); (seq_needles, simd_needles) } @@ -168,21 +164,21 @@ pub fn parse_trace_csv(p: impl AsRef) -> Result, Box = [([Simd; 3], Simd)]; +pub type SimdTrace = [([f32x8; 3], f32x8)]; -pub fn simd_trace_new(trace: &Trace) -> Box> { +pub fn simd_trace_new(trace: &Trace) -> Box { trace - .chunks(L) + .chunks(8) .map(|w| { - let mut centers = [[0.0; L]; 3]; - let mut radii = [0.0; L]; + let mut centers = [[0.0_f32; 8]; 3]; + let mut radii = [0.0_f32; 8]; for (l, ([x, y, z], r)) in w.iter().copied().enumerate() { centers[0][l] = x; centers[1][l] = y; centers[2][l] = z; radii[l] = r; } - (centers.map(Simd::from_array), Simd::from_array(radii)) + (centers.map(f32x8::new), f32x8::new(radii)) }) .collect() } diff --git a/capt/Cargo.toml b/capt/Cargo.toml index 7fc666d..92e0b8c 100644 --- a/capt/Cargo.toml +++ b/capt/Cargo.toml @@ -19,9 +19,9 @@ categories = ["no-std", "science", "science::robotics", "algorithms"] rand = { version = "0.9.1", default-features = false, features = ["small_rng"] } [features] -simd = [] +simd = ["wide"] [dependencies] aligned-vec = { version = "0.6.4", default-features = false } +wide = { version = "1.1.1", default-features = false, optional = true} -# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html diff --git a/capt/src/lib.rs b/capt/src/lib.rs index 555dcbd..4a722be 100644 --- a/capt/src/lib.rs +++ b/capt/src/lib.rs @@ -60,14 +60,12 @@ //! ## Optional features //! //! This crate exposes one feature, `simd`, which enables a SIMD-parallel interface for querying -//! `Capt`s. The `simd` feature requires nightly Rust and therefore should be considered unstable. -//! This enables the function `Capt::collides_simd`, a parallel collision checker for batches of -//! search queries. +//! `Capt`s. This enables the function `Capt::collides_simd`, a parallel collision checker for +//! batches of 8 f32 search queries using the `wide` crate. //! //! ## License //! //! This work is licensed to you under the Apache 2.0 license. -#![cfg_attr(feature = "simd", feature(portable_simd))] #![cfg_attr(not(test), no_std)] #![warn(clippy::pedantic, clippy::cargo, clippy::nursery, missing_docs)] @@ -85,14 +83,7 @@ use core::{ }; #[cfg(feature = "simd")] -use core::{ - ops::{AddAssign, Mul, SubAssign}, - simd::{ - Mask, Simd, SimdElement, - cmp::{SimdPartialEq, SimdPartialOrd}, - ptr::SimdConstPtr, - }, -}; +use wide::{f32x8, i32x8, CmpGe, CmpLe}; /// A generic trait representing values which may be used as an "axis;" that is, elements of a /// vector representing a point. @@ -192,30 +183,6 @@ pub trait Axis: PartialOrd + Copy + Sub + Add { fn square(self) -> Self; } -#[cfg(feature = "simd")] -/// A trait used for SIMD elements. -pub trait AxisSimdElement: SimdElement + Default + Axis {} - -#[cfg(feature = "simd")] -/// A trait used for masks over SIMD vectors, used for parallel querying on [`Capt`]s. -/// -/// The interface for this trait should be considered unstable since the standard SIMD API may -/// change with Rust versions. -pub trait AxisSimd: - Sized - + SimdPartialOrd - + Add - + AddAssign - + Sub - + SubAssign - + Mul -{ - /// Cast a mask for a SIMD vector into a mask of `isize`s. - fn cast_mask(mask: ::Mask) -> Mask; - /// Determine whether a mask contains any true elements. - fn mask_any(mask: ::Mask) -> bool; -} - /// An index type used for lookups into and out of arrays. /// /// This is implemented so that [`Capt`]s can use smaller index sizes (such as [`u32`] or [`u16`]) @@ -225,23 +192,6 @@ pub trait Index: TryFrom + TryInto + Copy { const ZERO: Self; } -#[cfg(feature = "simd")] -/// A SIMD parallel version of [`Index`]. -/// -/// This is used for implementing SIMD lookups in a [`Capt`]. -/// The interface for this trait should be considered unstable since the standard SIMD API may -/// change with Rust versions. -pub trait IndexSimd: SimdElement + Default { - #[must_use] - /// Convert a SIMD array of `Self` to a SIMD array of `usize`, without checking that each - /// element is valid. - /// - /// # Safety - /// - /// This function is only safe if all values of `x` are valid when converted to a `usize`. - unsafe fn to_simd_usize_unchecked(x: Simd) -> Simd; -} - macro_rules! impl_axis { ($t: ty, $tm: ty) => { impl Axis for $t { @@ -261,18 +211,6 @@ macro_rules! impl_axis { } } - #[cfg(feature = "simd")] - impl AxisSimdElement for $t {} - - #[cfg(feature = "simd")] - impl AxisSimd for Simd<$t, L> { - fn cast_mask(mask: ::Mask) -> Mask { - mask.into() - } - fn mask_any(mask: ::Mask) -> bool { - mask.any() - } - } }; } @@ -281,13 +219,6 @@ macro_rules! impl_idx { impl Index for $t { const ZERO: Self = 0; } - - #[cfg(feature = "simd")] - impl IndexSimd for $t { - unsafe fn to_simd_usize_unchecked(x: Simd) -> Simd { - unsafe { x.to_array().map(|a| a.try_into().unwrap_unchecked()).into() } - } - } }; } @@ -312,30 +243,22 @@ fn clamp(x: A, min: A, max: A) -> A { } #[inline] -#[allow(clippy::cast_possible_wrap)] +#[allow(clippy::cast_sign_loss)] #[cfg(feature = "simd")] -fn forward_pass_simd( - tests: &[A], - centers: &[Simd; K], -) -> Simd -where - Simd: AxisSimd, - A: AxisSimdElement, -{ - let mut test_idxs: Simd = Simd::splat(0); +fn forward_pass_wide(tests: &[f32], centers: &[f32x8; K]) -> i32x8 { + let mut test_idxs = i32x8::splat(0_i32); let mut k = 0; for _ in 0..tests.len().trailing_ones() { - let test_ptrs = Simd::splat(tests.as_ptr()).wrapping_offset(test_idxs); - let relevant_tests: Simd = unsafe { Simd::gather_ptr(test_ptrs) }; - let cmp_results: Mask = - Simd::::cast_mask(centers[k % K].simd_ge(relevant_tests)); - - let one = Simd::splat(1); - test_idxs = (test_idxs << one) + one + (cmp_results.to_simd() & Simd::splat(1)); + let idx_arr = test_idxs.to_array(); + let relevant_tests = + f32x8::new(idx_arr.map(|i| unsafe { *tests.get_unchecked(i as usize) })); + let cmp_f = centers[k % K].simd_ge(relevant_tests); + let cmp_bit: i32x8 = + unsafe { core::mem::transmute::(cmp_f) } & i32x8::splat(1); + test_idxs = (test_idxs << 1_i32) + 1_i32 + cmp_bit; k = (k + 1) % K; } - - test_idxs - Simd::splat(tests.len() as isize) + test_idxs - i32x8::splat(tests.len() as i32) } #[derive(Clone, Debug, PartialEq, Eq)] @@ -871,116 +794,112 @@ where } } -#[allow(clippy::mismatching_type_param_order)] #[cfg(feature = "simd")] -impl Capt +impl Capt where - I: IndexSimd, - A: Mul, + I: Index, { #[must_use] /// Determine whether any sphere in the list of provided spheres intersects a point in this /// tree. /// + /// Each element of `centers` is an [`f32x8`] holding the coordinate for that dimension across + /// 8 parallel query spheres. `radii` holds the radius for each of the 8 queries. + /// /// # Panics /// - /// This function will panic if the `Capt` was not constructed with a large enough lane count - /// for the query. + /// This function will panic if the `Capt` was not constructed with a lane count of at least 8. /// /// # Examples /// /// ``` - /// #![feature(portable_simd)] - /// use std::simd::Simd; + /// use wide::f32x8; /// /// let points = [[1.0, 2.0], [1.1, 1.1]]; /// /// let centers = [ - /// Simd::from_array([1.0, 1.1, 1.2, 1.3]), // x-positions - /// Simd::from_array([1.0, 1.1, 1.2, 1.3]), // y-positions + /// f32x8::new([1.0, 1.1, 1.2, 1.3, 0.0, 0.0, 0.0, 0.0]), // x-positions + /// f32x8::new([1.0, 1.1, 1.2, 1.3, 0.0, 0.0, 0.0, 0.0]), // y-positions /// ]; - /// let radii = Simd::splat(0.05); + /// let radii = f32x8::splat(0.05); /// - /// let tree = capt::Capt::<2, f32, u32>::new(&points, (0.0, 0.1), 4); + /// let tree = capt::Capt::<2, f32, u32>::new(&points, (0.0, 0.1), 8); /// /// println!("{tree:?}"); /// /// assert!(tree.collides_simd(¢ers, radii)); /// ``` - pub fn collides_simd( - &self, - centers: &[Simd; K], - mut radii: Simd, - ) -> bool - where - Simd: AxisSimd, - A: AxisSimdElement, - { - assert!(L.is_power_of_two(), "lane count must be power of two"); + #[allow(clippy::cast_sign_loss, clippy::cast_possible_truncation)] + pub fn collides_simd(&self, centers: &[f32x8; K], mut radii: f32x8) -> bool { assert!( - 1 << self.lanes_log2 >= L, - "lane count of query must be lower than lane count of CAPT" + 1 << self.lanes_log2 >= 8, + "CAPT must be constructed with at least 8 lanes for f32x8 queries" ); - radii += Simd::splat(self.r_point); - let zs = forward_pass_simd(&self.tests, centers); - - let mut inbounds = Mask::splat(true); - - let mut aabb_ptrs = Simd::splat(self.aabbs.as_ptr()).wrapping_offset(zs).cast(); + radii = radii + f32x8::splat(self.r_point); + let zs = forward_pass_wide(&self.tests, centers); + let zs_arr = zs.to_array(); - unsafe { - for center in centers { - inbounds &= Simd::::cast_mask( - (Simd::gather_select_ptr(aabb_ptrs, inbounds, Simd::splat(A::NEG_INFINITY)) - - radii) - .simd_le(*center), - ); - aabb_ptrs = aabb_ptrs.wrapping_add(Simd::splat(1)); - } - for center in centers { - inbounds &= Simd::::cast_mask( - Simd::gather_select_ptr(aabb_ptrs, inbounds, Simd::splat(A::NEG_INFINITY)) - .simd_ge(*center - radii), - ); - aabb_ptrs = aabb_ptrs.wrapping_add(Simd::splat(1)); - } + // AABB inbounds check: flat f32 view of aabb array (lo[0..K] then hi[0..K] per entry) + let aabb_f32 = unsafe { + core::slice::from_raw_parts(self.aabbs.as_ptr().cast::(), self.aabbs.len() * 2 * K) + }; + let all_true: f32x8 = + unsafe { core::mem::transmute::(i32x8::splat(-1_i32)) }; + let mut inbounds = all_true; + for k in 0..K { + let lo_vals = f32x8::new(core::array::from_fn(|j| unsafe { + *aabb_f32.get_unchecked(zs_arr[j] as usize * (2 * K) + k) + })); + inbounds = inbounds & (lo_vals - radii).simd_le(centers[k]); + } + for k in 0..K { + let hi_vals = f32x8::new(core::array::from_fn(|j| unsafe { + *aabb_f32.get_unchecked(zs_arr[j] as usize * (2 * K) + K + k) + })); + inbounds = inbounds & hi_vals.simd_ge(centers[k] - radii); } if !inbounds.any() { return false; } - // retrieve start/end pointers for the affordance buffer - let start_ptrs = Simd::splat(self.starts.as_ptr()).wrapping_offset(zs); - let starts = unsafe { I::to_simd_usize_unchecked(Simd::gather_ptr(start_ptrs)) }.to_array(); - let ends = unsafe { - I::to_simd_usize_unchecked(Simd::gather_ptr(start_ptrs.wrapping_add(Simd::splat(1)))) - } - .to_array(); - - starts - .into_iter() - .zip(ends) - .zip(inbounds.to_array()) - .enumerate() - .filter_map(|(j, r)| r.1.then_some((j, r.0))) - .any(|(j, (start, end))| { - let mut n_center = [Simd::splat(A::ZERO); K]; + let inbounds_arr: [i32; 8] = + unsafe { core::mem::transmute::(inbounds) }.to_array(); + let starts: [usize; 8] = core::array::from_fn(|j| unsafe { + self.starts[zs_arr[j] as usize].try_into().unwrap_unchecked() + }); + let ends: [usize; 8] = core::array::from_fn(|j| unsafe { + self.starts[zs_arr[j] as usize + 1].try_into().unwrap_unchecked() + }); + + let centers_arr: [[f32; 8]; K] = core::array::from_fn(|k| centers[k].to_array()); + let radii_arr = radii.to_array(); + + for j in 0..8 { + if inbounds_arr[j] == 0 { + continue; + } + let start = starts[j]; + let end = ends[j]; + let n_center: [f32x8; K] = + core::array::from_fn(|k| f32x8::splat(centers_arr[k][j])); + let rs = f32x8::splat(radii_arr[j]); + let rs_sq = rs * rs; + let mut i = start; + while i < end { + let mut dists_sq = f32x8::ZERO; + #[allow(clippy::needless_range_loop)] for k in 0..K { - n_center[k] = Simd::splat(centers[k][j]); + let vals: f32x8 = unsafe { *self.afforded[k].as_ptr().add(i).cast() }; + let diff = vals - n_center[k]; + dists_sq = dists_sq + diff * diff; } - let rs = Simd::splat(radii[j]); - let rs_sq = rs * rs; - (start..end).step_by(L).any(|i| { - let mut dists_sq = Simd::splat(A::ZERO); - #[allow(clippy::needless_range_loop)] - for k in 0..K { - let vals: Simd = unsafe { *self.afforded[k].as_ptr().add(i).cast() }; - let diff = vals - n_center[k]; - dists_sq += diff * diff; - } - Simd::::mask_any(dists_sq.simd_le(rs_sq)) - }) - }) + if dists_sq.simd_le(rs_sq).any() { + return true; + } + i += 8; + } + } + false } } @@ -1074,6 +993,22 @@ unsafe fn median_partition(points: &mut [[A; K]], k: us } } +#[cfg(all(test, feature = "simd"))] +#[test] +fn simd_instruction_set() { + if cfg!(target_feature = "avx2") { + println!("SIMD: AVX2"); + } else if cfg!(target_feature = "avx") { + println!("SIMD: AVX"); + } else if cfg!(target_feature = "sse4.1") { + println!("SIMD: SSE4.1"); + } else if cfg!(target_feature = "sse2") { + println!("SIMD: SSE2"); + } else { + println!("SIMD: scalar fallback (no SSE/AVX detected)"); + } +} + #[cfg(test)] mod tests { use rand::{Rng, SeedableRng, rngs::SmallRng}; diff --git a/rust-toolchain.toml b/rust-toolchain.toml deleted file mode 100644 index 271800c..0000000 --- a/rust-toolchain.toml +++ /dev/null @@ -1,2 +0,0 @@ -[toolchain] -channel = "nightly" \ No newline at end of file diff --git a/rustfmt.toml b/rustfmt.toml index c886414..0b84359 100644 --- a/rustfmt.toml +++ b/rustfmt.toml @@ -2,4 +2,4 @@ newline_style = "Unix" wrap_comments = true comment_width = 100 format_code_in_doc_comments = true -imports_granularity = "crate" \ No newline at end of file +imports_granularity = "Crate" \ No newline at end of file