From 2a6e0cab43d28a623fd11f7137614ea7162f6ff7 Mon Sep 17 00:00:00 2001 From: Ionizing Date: Sun, 23 Nov 2025 18:04:05 +0800 Subject: [PATCH] Add Accelerate as device --- Cargo.toml | 3 + crates-device/rstsr-accelerate/Cargo.toml | 35 + crates-device/rstsr-accelerate/build.rs | 12 + crates-device/rstsr-accelerate/readme.md | 40 + .../rstsr-accelerate/src/conversion.rs | 79 ++ .../rstsr-accelerate/src/creation.rs | 134 +++ crates-device/rstsr-accelerate/src/device.rs | 133 +++ .../src/driver_impl/cblas/blas3/gemm.rs | 88 ++ .../src/driver_impl/cblas/blas3/mod.rs | 3 + .../src/driver_impl/cblas/blas3/syhemm.rs | 86 ++ .../src/driver_impl/cblas/blas3/trsm.rs | 80 ++ .../src/driver_impl/cblas/mod.rs | 1 + .../src/driver_impl/lapack/eigh/mod.rs | 4 + .../src/driver_impl/lapack/eigh/syev.rs | 192 +++++ .../src/driver_impl/lapack/eigh/syevd.rs | 231 ++++++ .../src/driver_impl/lapack/eigh/sygv.rs | 246 ++++++ .../src/driver_impl/lapack/eigh/sygvd.rs | 275 +++++++ .../src/driver_impl/lapack/mod.rs | 3 + .../src/driver_impl/lapack/solve/gesv.rs | 141 ++++ .../src/driver_impl/lapack/solve/getrf.rs | 101 +++ .../src/driver_impl/lapack/solve/getri.rs | 124 +++ .../src/driver_impl/lapack/solve/mod.rs | 5 + .../src/driver_impl/lapack/solve/potrf.rs | 87 ++ .../src/driver_impl/lapack/solve/sysv.rs | 230 ++++++ .../src/driver_impl/lapack/svd/gesdd.rs | 359 ++++++++ .../src/driver_impl/lapack/svd/gesvd.rs | 363 +++++++++ .../src/driver_impl/lapack/svd/mod.rs | 2 + .../rstsr-accelerate/src/driver_impl/mod.rs | 16 + crates-device/rstsr-accelerate/src/lib.rs | 30 + .../src/linalg_auto_impl/cholesky.rs | 90 +++ .../src/linalg_auto_impl/det.rs | 47 ++ .../src/linalg_auto_impl/eigh.rs | 216 +++++ .../src/linalg_auto_impl/eigvalsh.rs | 214 +++++ .../src/linalg_auto_impl/inv.rs | 46 ++ .../src/linalg_auto_impl/mod.rs | 12 + .../src/linalg_auto_impl/pinv.rs | 85 ++ .../src/linalg_auto_impl/slogdet.rs | 46 ++ .../src/linalg_auto_impl/solve_general.rs | 134 +++ .../src/linalg_auto_impl/solve_symmetric.rs | 160 ++++ .../src/linalg_auto_impl/solve_triangular.rs | 158 ++++ .../src/linalg_auto_impl/svd.rs | 109 +++ .../src/linalg_auto_impl/svdvals.rs | 80 ++ crates-device/rstsr-accelerate/src/matmul.rs | 479 +++++++++++ .../rstsr-accelerate/src/matmul_impl.rs | 583 +++++++++++++ .../rstsr-accelerate/src/prelude_dev.rs | 4 + .../src/rayon_auto_impl/adv_indexing.rs | 21 + .../src/rayon_auto_impl/assignment.rs | 49 ++ .../src/rayon_auto_impl/mod.rs | 9 + .../rayon_auto_impl/op_binary_arithmetic.rs | 119 +++ .../src/rayon_auto_impl/op_binary_common.rs | 240 ++++++ .../rayon_auto_impl/op_ternary_arithmetic.rs | 56 ++ .../src/rayon_auto_impl/op_ternary_common.rs | 190 +++++ .../src/rayon_auto_impl/op_tri.rs | 55 ++ .../src/rayon_auto_impl/op_with_func.rs | 117 +++ .../src/rayon_auto_impl/reduction.rs | 763 ++++++++++++++++++ .../src/sci_auto_impl/distance_auto_impl.rs | 120 +++ .../src/sci_auto_impl/integrate_auto_impl.rs | 13 + .../rstsr-accelerate/src/sci_auto_impl/mod.rs | 2 + .../rstsr-accelerate/src/threading.rs | 65 ++ .../rstsr-accelerate/tests/issues/issue_45.rs | 14 + .../rstsr-accelerate/tests/issues/mod.rs | 1 + crates-device/rstsr-accelerate/tests/mod.rs | 5 + .../test_driver_impl/driver_validation_f64.py | 141 ++++ .../tests/test_driver_impl/lapack_eigh_f64.rs | 148 ++++ .../test_driver_impl/lapack_solve_f64.rs | 95 +++ .../tests/test_driver_impl/lapack_svd_f64.rs | 80 ++ .../tests/test_driver_impl/mod.rs | 3 + .../tests/test_linalg_func/func_c64.rs | 319 ++++++++ .../tests/test_linalg_func/func_f64.rs | 260 ++++++ 
.../test_linalg_func/func_validation_c64.py | 235 ++++++ .../test_linalg_func/func_validation_f64.py | 173 ++++ .../tests/test_linalg_func/mod.rs | 2 + .../rstsr-accelerate/tests/test_workable.rs | 35 + rstsr/Cargo.toml | 12 +- 74 files changed, 8903 insertions(+), 5 deletions(-) create mode 100644 crates-device/rstsr-accelerate/Cargo.toml create mode 100644 crates-device/rstsr-accelerate/build.rs create mode 100644 crates-device/rstsr-accelerate/readme.md create mode 100644 crates-device/rstsr-accelerate/src/conversion.rs create mode 100644 crates-device/rstsr-accelerate/src/creation.rs create mode 100644 crates-device/rstsr-accelerate/src/device.rs create mode 100644 crates-device/rstsr-accelerate/src/driver_impl/cblas/blas3/gemm.rs create mode 100644 crates-device/rstsr-accelerate/src/driver_impl/cblas/blas3/mod.rs create mode 100644 crates-device/rstsr-accelerate/src/driver_impl/cblas/blas3/syhemm.rs create mode 100644 crates-device/rstsr-accelerate/src/driver_impl/cblas/blas3/trsm.rs create mode 100644 crates-device/rstsr-accelerate/src/driver_impl/cblas/mod.rs create mode 100644 crates-device/rstsr-accelerate/src/driver_impl/lapack/eigh/mod.rs create mode 100644 crates-device/rstsr-accelerate/src/driver_impl/lapack/eigh/syev.rs create mode 100644 crates-device/rstsr-accelerate/src/driver_impl/lapack/eigh/syevd.rs create mode 100644 crates-device/rstsr-accelerate/src/driver_impl/lapack/eigh/sygv.rs create mode 100644 crates-device/rstsr-accelerate/src/driver_impl/lapack/eigh/sygvd.rs create mode 100644 crates-device/rstsr-accelerate/src/driver_impl/lapack/mod.rs create mode 100644 crates-device/rstsr-accelerate/src/driver_impl/lapack/solve/gesv.rs create mode 100644 crates-device/rstsr-accelerate/src/driver_impl/lapack/solve/getrf.rs create mode 100644 crates-device/rstsr-accelerate/src/driver_impl/lapack/solve/getri.rs create mode 100644 crates-device/rstsr-accelerate/src/driver_impl/lapack/solve/mod.rs create mode 100644 crates-device/rstsr-accelerate/src/driver_impl/lapack/solve/potrf.rs create mode 100644 crates-device/rstsr-accelerate/src/driver_impl/lapack/solve/sysv.rs create mode 100644 crates-device/rstsr-accelerate/src/driver_impl/lapack/svd/gesdd.rs create mode 100644 crates-device/rstsr-accelerate/src/driver_impl/lapack/svd/gesvd.rs create mode 100644 crates-device/rstsr-accelerate/src/driver_impl/lapack/svd/mod.rs create mode 100644 crates-device/rstsr-accelerate/src/driver_impl/mod.rs create mode 100644 crates-device/rstsr-accelerate/src/lib.rs create mode 100644 crates-device/rstsr-accelerate/src/linalg_auto_impl/cholesky.rs create mode 100644 crates-device/rstsr-accelerate/src/linalg_auto_impl/det.rs create mode 100644 crates-device/rstsr-accelerate/src/linalg_auto_impl/eigh.rs create mode 100644 crates-device/rstsr-accelerate/src/linalg_auto_impl/eigvalsh.rs create mode 100644 crates-device/rstsr-accelerate/src/linalg_auto_impl/inv.rs create mode 100644 crates-device/rstsr-accelerate/src/linalg_auto_impl/mod.rs create mode 100644 crates-device/rstsr-accelerate/src/linalg_auto_impl/pinv.rs create mode 100644 crates-device/rstsr-accelerate/src/linalg_auto_impl/slogdet.rs create mode 100644 crates-device/rstsr-accelerate/src/linalg_auto_impl/solve_general.rs create mode 100644 crates-device/rstsr-accelerate/src/linalg_auto_impl/solve_symmetric.rs create mode 100644 crates-device/rstsr-accelerate/src/linalg_auto_impl/solve_triangular.rs create mode 100644 crates-device/rstsr-accelerate/src/linalg_auto_impl/svd.rs create mode 100644 
crates-device/rstsr-accelerate/src/linalg_auto_impl/svdvals.rs create mode 100644 crates-device/rstsr-accelerate/src/matmul.rs create mode 100644 crates-device/rstsr-accelerate/src/matmul_impl.rs create mode 100644 crates-device/rstsr-accelerate/src/prelude_dev.rs create mode 100644 crates-device/rstsr-accelerate/src/rayon_auto_impl/adv_indexing.rs create mode 100644 crates-device/rstsr-accelerate/src/rayon_auto_impl/assignment.rs create mode 100644 crates-device/rstsr-accelerate/src/rayon_auto_impl/mod.rs create mode 100644 crates-device/rstsr-accelerate/src/rayon_auto_impl/op_binary_arithmetic.rs create mode 100644 crates-device/rstsr-accelerate/src/rayon_auto_impl/op_binary_common.rs create mode 100644 crates-device/rstsr-accelerate/src/rayon_auto_impl/op_ternary_arithmetic.rs create mode 100644 crates-device/rstsr-accelerate/src/rayon_auto_impl/op_ternary_common.rs create mode 100644 crates-device/rstsr-accelerate/src/rayon_auto_impl/op_tri.rs create mode 100644 crates-device/rstsr-accelerate/src/rayon_auto_impl/op_with_func.rs create mode 100644 crates-device/rstsr-accelerate/src/rayon_auto_impl/reduction.rs create mode 100644 crates-device/rstsr-accelerate/src/sci_auto_impl/distance_auto_impl.rs create mode 100644 crates-device/rstsr-accelerate/src/sci_auto_impl/integrate_auto_impl.rs create mode 100644 crates-device/rstsr-accelerate/src/sci_auto_impl/mod.rs create mode 100644 crates-device/rstsr-accelerate/src/threading.rs create mode 100644 crates-device/rstsr-accelerate/tests/issues/issue_45.rs create mode 100644 crates-device/rstsr-accelerate/tests/issues/mod.rs create mode 100644 crates-device/rstsr-accelerate/tests/mod.rs create mode 100644 crates-device/rstsr-accelerate/tests/test_driver_impl/driver_validation_f64.py create mode 100644 crates-device/rstsr-accelerate/tests/test_driver_impl/lapack_eigh_f64.rs create mode 100644 crates-device/rstsr-accelerate/tests/test_driver_impl/lapack_solve_f64.rs create mode 100644 crates-device/rstsr-accelerate/tests/test_driver_impl/lapack_svd_f64.rs create mode 100644 crates-device/rstsr-accelerate/tests/test_driver_impl/mod.rs create mode 100644 crates-device/rstsr-accelerate/tests/test_linalg_func/func_c64.rs create mode 100644 crates-device/rstsr-accelerate/tests/test_linalg_func/func_f64.rs create mode 100644 crates-device/rstsr-accelerate/tests/test_linalg_func/func_validation_c64.py create mode 100644 crates-device/rstsr-accelerate/tests/test_linalg_func/func_validation_f64.py create mode 100644 crates-device/rstsr-accelerate/tests/test_linalg_func/mod.rs create mode 100644 crates-device/rstsr-accelerate/tests/test_workable.rs diff --git a/Cargo.toml b/Cargo.toml index 00872e46..9dc1fbb4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,6 +10,7 @@ members = [ "rstsr-native-impl", "rstsr-sci-traits", "crates-device/rstsr-openblas", + "crates-device/rstsr-accelerate", "crates-device/rstsr-mkl", "crates-device/rstsr-blis", "crates-device/rstsr-aocl", @@ -39,6 +40,7 @@ rstsr-linalg-traits = { path = "./rstsr-linalg-traits", default-features = false rstsr-sci-traits = { path = "./rstsr-sci-traits", default-features = false, version = "0.6.0" } # members (device) rstsr-openblas = { path = "./crates-device/rstsr-openblas", default-features = false, version = "0.6.0" } +rstsr-accelerate = { path = "./crates-device/rstsr-accelerate", default-features = false, version = "0.6.0" } rstsr-mkl = { path = "./crates-device/rstsr-mkl", default-features = false, version = "0.6.0" } rstsr-blis = { path = "./crates-device/rstsr-blis", default-features = 
false, version = "0.6.0" } rstsr-aocl = { path = "./crates-device/rstsr-aocl", default-features = false, version = "0.6.0" } @@ -50,6 +52,7 @@ rstsr-test-manifest = { path = "./rstsr-test-manifest", default-features = false # ffi dependencies rstsr-cblas-base = { version = "0.1" } rstsr-openblas-ffi = { version = "0.5", default-features = false, features = ["blas", "cblas", "lapack"] } +rstsr-lapack-ffi = { version = "0.5", default-features = false, features = ["blas", "cblas", "lapack"] } rstsr-mkl-ffi = { version = "0.2", default-features = false, features = ["blas", "cblas", "lapack"] } rstsr-blis-ffi = { version = "0.2", default-features = false, features = ["lapack"] } rstsr-aocl-ffi = { version = "0.2", default-features = false, features = ["blis", "lapack"] } diff --git a/crates-device/rstsr-accelerate/Cargo.toml b/crates-device/rstsr-accelerate/Cargo.toml new file mode 100644 index 00000000..c5e6edc3 --- /dev/null +++ b/crates-device/rstsr-accelerate/Cargo.toml @@ -0,0 +1,35 @@ +[package] +name = "rstsr-accelerate" +version.workspace = true +edition.workspace = true +description.workspace = true +repository.workspace = true +keywords.workspace = true +categories.workspace = true +license.workspace = true +readme = "readme.md" + +[dependencies] +rayon = { workspace = true } +num = { workspace = true } +duplicate = { workspace = true } +rstsr-lapack-ffi = { workspace = true } +rstsr-native-impl = { workspace = true, features = ["rayon"] } +rstsr-core = { workspace = true, features = ["rayon"] } +rstsr-common = { workspace = true, features = ["rayon"] } +rstsr-dtype-traits = { workspace = true, features = ["half"] } +rstsr-blas-traits = { workspace = true } +rstsr-linalg-traits = { workspace = true, optional = true } +rstsr-sci-traits = { workspace = true, optional = true } + +[dev-dependencies] +rstsr = { path = "../../rstsr", default-features = false, features = ["linalg"] } +rstsr-test-manifest = { workspace = true } + +[features] +default = ["linalg"] +dynamic_loading = ["rstsr-lapack-ffi/dynamic_loading"] +faer = ["rstsr-core/faer"] +ilp64 = ["rstsr-lapack-ffi/ilp64", "rstsr-blas-traits/ilp64"] +linalg = ["dep:rstsr-linalg-traits"] +sci = ["dep:rstsr-sci-traits"] diff --git a/crates-device/rstsr-accelerate/build.rs b/crates-device/rstsr-accelerate/build.rs new file mode 100644 index 00000000..3f7ee3a8 --- /dev/null +++ b/crates-device/rstsr-accelerate/build.rs @@ -0,0 +1,12 @@ +fn main() { + #[cfg(target_os = "macos")] + { + println!("cargo:rustc-link-lib=framework=Accelerate"); + } + + + #[cfg(not(target_os = "macos"))] + { + panic!("'accelerate' feature is only available for macOS target."); + } +} diff --git a/crates-device/rstsr-accelerate/readme.md b/crates-device/rstsr-accelerate/readme.md new file mode 100644 index 00000000..4556a2c6 --- /dev/null +++ b/crates-device/rstsr-accelerate/readme.md @@ -0,0 +1,40 @@ +# RSTSR OpenBLAS device + +This crate enables OpenBLAS device. + +For more information of OpenBLAS and its usage, we refer to [document of rstsr-openblas-ffi](https://docs.rs/rstsr-openblas-ffi/). 
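+ +## Installation + +A minimal sketch of adding this crate to a downstream project; the version below is assumed to track the workspace version (`0.6.0`) declared earlier in this patch, and `linalg` is listed explicitly here even though it is already a default feature of this crate: + +```toml +[dependencies] +rstsr-accelerate = { version = "0.6", features = ["linalg"] } +```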
+ +## Usage + +```rust +use rstsr_core::prelude::*; +use rstsr_accelerate::DeviceBLAS; + +// specify the number of threads (16 here) +let device = DeviceBLAS::new(16); +// if you want to use the default number of threads, use the following line +// let device = DeviceBLAS::default(); + +let a = rt::linspace((0.0, 1.0, 1048576, &device)).into_shape([16, 256, 256]); +let b = rt::linspace((1.0, 2.0, 1048576, &device)).into_shape([16, 256, 256]); + +// with an optimized BLAS backend, this batched matmul (the `%` operator) is very fast +let c = &a % &b; + +// mean of all elements is also computed in parallel +let c_mean = c.mean_all(); + +println!("{:?}", c_mean); +assert!((c_mean - 213.2503660477036).abs() < 1e-6); +``` + +## Important Notes + +- Linkage is handled automatically: this crate's build.rs emits `cargo:rustc-link-lib=framework=Accelerate`, so no extra `RUSTFLAGS` or manual linking is required. +- This crate is only available on macOS targets; building for any other target fails in build.rs. +- The `linalg` feature is enabled by default; `sci`, `ilp64`, `dynamic_loading`, and `faer` are optional (see this crate's Cargo.toml). \ No newline at end of file diff --git a/crates-device/rstsr-accelerate/src/conversion.rs b/crates-device/rstsr-accelerate/src/conversion.rs new file mode 100644 index 00000000..b495361a --- /dev/null +++ b/crates-device/rstsr-accelerate/src/conversion.rs @@ -0,0 +1,79 @@ +use crate::prelude_dev::*; + +macro_rules!
impl_change_device { + ($DevA: ty, $DevB: ty) => { + impl<'a, R, T, D> DeviceChangeAPI<'a, $DevB, R, T, D> for $DevA + where + T: Clone + Send + Sync + 'a, + D: DimAPI, + R: DataCloneAPI>, + { + type Repr = R; + type ReprTo = DataRef<'a, Vec>; + + fn change_device( + tensor: TensorAny, + device: &$DevB, + ) -> Result> { + let (storage, layout) = tensor.into_raw_parts(); + let (data, _) = storage.into_raw_parts(); + let storage = Storage::new(data, device.clone()); + let tensor = TensorAny::new(storage, layout); + Ok(tensor) + } + + fn into_device( + tensor: TensorAny, + device: &$DevB, + ) -> Result>, T, $DevB, D>> { + let tensor = tensor.into_owned(); + DeviceChangeAPI::change_device(tensor, device) + } + + fn to_device(tensor: &'a TensorAny, device: &$DevB) -> Result> { + let view = tensor.view(); + DeviceChangeAPI::change_device(view, device) + } + } + }; +} + +impl_change_device!(DeviceCpuSerial, DeviceBLAS); +impl_change_device!(DeviceBLAS, DeviceCpuSerial); +impl_change_device!(DeviceBLAS, DeviceBLAS); +#[cfg(feature = "faer")] +impl_change_device!(DeviceFaer, DeviceBLAS); +#[cfg(feature = "faer")] +impl_change_device!(DeviceBLAS, DeviceFaer); + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_device_conversion_cpu_serial() { + let device_serial = DeviceCpuSerial::default(); + let device = DeviceBLAS::new(0); + let a = linspace((1.0, 5.0, 5, &device)); + let b = a.to_device(&device_serial); + println!("{b:?}"); + let a = linspace((1.0, 5.0, 5, &device_serial)); + let a_view = a.view(); + let b = a_view.to_device(&device); + println!("{b:?}"); + } + + #[test] + #[cfg(feature = "faer")] + fn test_device_conversion_faer() { + let device_faer = DeviceFaer::new(0); + let device = DeviceBLAS::new(0); + let a = linspace((1.0, 5.0, 5, &device)); + let b = a.to_device(&device_faer); + println!("{b:?}"); + let a = linspace((1.0, 5.0, 5, &device_faer)); + let a_view = a.view(); + let b = a_view.to_device(&device); + println!("{b:?}"); + } +} diff --git a/crates-device/rstsr-accelerate/src/creation.rs b/crates-device/rstsr-accelerate/src/creation.rs new file mode 100644 index 00000000..9feebf14 --- /dev/null +++ b/crates-device/rstsr-accelerate/src/creation.rs @@ -0,0 +1,134 @@ +use crate::prelude_dev::*; +use num::{complex::ComplexFloat, Num}; + +// for creation, we use most of the functions from DeviceCpuSerial +impl DeviceCreationAnyAPI for DeviceBLAS +where + Self: DeviceRawAPI> + DeviceRawAPI, Raw = Vec>>, +{ + unsafe fn empty_impl(&self, len: usize) -> Result>, T, Self>> { + let storage = DeviceCpuSerial::default().empty_impl(len)?; + let (data, _) = storage.into_raw_parts(); + Ok(Storage::new(data, self.clone())) + } + + fn full_impl(&self, len: usize, fill: T) -> Result>, T, Self>> + where + T: Clone, + { + let storage = DeviceCpuSerial::default().full_impl(len, fill)?; + let (data, _) = storage.into_raw_parts(); + Ok(Storage::new(data, self.clone())) + } + + fn outof_cpu_vec(&self, vec: Vec) -> Result>, T, Self>> { + Ok(Storage::new(DataOwned::from(vec), self.clone())) + } + + fn from_cpu_vec(&self, vec: &[T]) -> Result>, T, Self>> + where + T: Clone, + { + let raw = vec.to_vec(); + Ok(Storage::new(DataOwned::from(raw), self.clone())) + } + + fn uninit_impl(&self, len: usize) -> Result>>, MaybeUninit, Self>> { + let raw = unsafe { uninitialized_vec(len) }?; + Ok(Storage::new(raw.into(), self.clone())) + } + + unsafe fn assume_init_impl( + storage: Storage>>, MaybeUninit, Self>, + ) -> Result>, T, Self>> + where + Self: DeviceRawAPI>, + { + let (data, device) = 
storage.into_raw_parts(); + let vec = data.into_raw(); + // transmute `Vec>` to `Vec` + let vec = core::mem::transmute::>, Vec>(vec); + let data = vec.into(); + Ok(Storage::new(data, device)) + } +} + +impl DeviceCreationNumAPI for DeviceBLAS +where + T: Num + Clone, + Self: DeviceRawAPI>, +{ + fn zeros_impl(&self, len: usize) -> Result>, T, Self>> { + let storage = DeviceCpuSerial::default().zeros_impl(len)?; + let (data, _) = storage.into_raw_parts(); + Ok(Storage::new(data, self.clone())) + } + + fn ones_impl(&self, len: usize) -> Result>, T, Self>> { + let storage = DeviceCpuSerial::default().ones_impl(len)?; + let (data, _) = storage.into_raw_parts(); + Ok(Storage::new(data, self.clone())) + } + + fn arange_int_impl(&self, len: usize) -> Result>, T, Self>> { + let storage = DeviceCpuSerial::default().arange_int_impl(len)?; + let (data, _) = storage.into_raw_parts(); + Ok(Storage::new(data, self.clone())) + } +} + +impl DeviceCreationPartialOrdNumAPI for DeviceBLAS +where + T: Num + PartialOrd + Clone, + Self: DeviceRawAPI>, +{ + fn arange_impl(&self, start: T, end: T, step: T) -> Result>, T, Self>> { + let storage = DeviceCpuSerial::default().arange_impl(start, end, step)?; + let (data, _) = storage.into_raw_parts(); + Ok(Storage::new(data, self.clone())) + } +} + +impl DeviceCreationComplexFloatAPI for DeviceBLAS +where + T: ComplexFloat + Clone + Send + Sync, + Self: DeviceRawAPI>, +{ + fn linspace_impl(&self, start: T, end: T, n: usize, endpoint: bool) -> Result>, T, Self>> { + let storage = DeviceCpuSerial::default().linspace_impl(start, end, n, endpoint)?; + let (data, _) = storage.into_raw_parts(); + Ok(Storage::new(data, self.clone())) + } +} + +impl DeviceCreationTriAPI for DeviceBLAS +where + T: Num + Clone, + Self: DeviceRawAPI>, +{ + fn tril_impl(&self, raw: &mut Self::Raw, layout: &Layout, k: isize) -> Result<()> + where + D: DimAPI, + { + DeviceCpuSerial::default().tril_impl(raw, layout, k) + } + + fn triu_impl(&self, raw: &mut Self::Raw, layout: &Layout, k: isize) -> Result<()> + where + D: DimAPI, + { + DeviceCpuSerial::default().triu_impl(raw, layout, k) + } +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_linspace() { + let device = DeviceBLAS::default(); + let a = linspace((1.0, 5.0, 5, &device)); + assert_eq!(a.raw(), &vec![1., 2., 3., 4., 5.]); + } +} diff --git a/crates-device/rstsr-accelerate/src/device.rs b/crates-device/rstsr-accelerate/src/device.rs new file mode 100644 index 00000000..ff8cec76 --- /dev/null +++ b/crates-device/rstsr-accelerate/src/device.rs @@ -0,0 +1,133 @@ +use crate::prelude_dev::*; +use num::{complex::ComplexFloat, Num}; +use rstsr_dtype_traits::DTypeIntoFloatAPI; + +impl DeviceBLAS { + pub fn new(num_threads: usize) -> Self { + DeviceBLAS { base: DeviceCpuRayon::new(num_threads) } + } +} + +impl DeviceRayonAPI for DeviceBLAS { + #[inline] + fn set_num_threads(&mut self, num_threads: usize) { + self.base.set_num_threads(num_threads); + } + + #[inline] + fn get_num_threads(&self) -> usize { + self.base.get_num_threads() + } + + #[inline] + fn get_pool(&self) -> &ThreadPool { + self.base.get_pool() + } + + #[inline] + fn get_current_pool(&self) -> Option<&ThreadPool> { + self.base.get_current_pool() + } +} + +impl Default for DeviceBLAS { + fn default() -> Self { + DeviceBLAS::new(0) + } +} + +impl DeviceBaseAPI for DeviceBLAS { + fn same_device(&self, other: &Self) -> bool { + let same_num_threads = self.get_num_threads() == other.get_num_threads(); + let same_default_order = self.default_order() == 
other.default_order(); + same_num_threads && same_default_order + } + + fn default_order(&self) -> FlagOrder { + self.base.default_order() + } + + fn set_default_order(&mut self, order: FlagOrder) { + self.base.set_default_order(order); + } +} + +impl DeviceRawAPI for DeviceBLAS { + type Raw = Vec; +} + +impl DeviceStorageAPI for DeviceBLAS { + fn len(storage: &Storage) -> usize + where + R: DataAPI, + { + storage.raw().len() + } + + fn to_cpu_vec(storage: &Storage) -> Result> + where + Self::Raw: Clone, + R: DataAPI, + { + Ok(storage.raw().clone()) + } + + fn into_cpu_vec(storage: Storage) -> Result> + where + Self::Raw: Clone, + R: DataCloneAPI, + { + let (raw, _) = storage.into_raw_parts(); + Ok(raw.into_owned().into_raw()) + } + + #[inline] + fn get_index(storage: &Storage, index: usize) -> T + where + T: Clone, + R: DataAPI, + { + storage.raw()[index].clone() + } + + #[inline] + fn get_index_ptr(storage: &Storage, index: usize) -> *const T + where + R: DataAPI, + { + &storage.raw()[index] as *const T + } + + #[inline] + fn get_index_mut_ptr(storage: &mut Storage, index: usize) -> *mut T + where + R: DataMutAPI, + { + storage.raw_mut().get_mut(index).unwrap() as *mut T + } + + #[inline] + fn set_index(storage: &mut Storage, index: usize, value: T) + where + R: DataMutAPI, + { + storage.raw_mut()[index] = value; + } +} + +impl DeviceAPI for DeviceBLAS {} + +impl DeviceComplexFloatAPI for DeviceBLAS +where + T: ComplexFloat + DTypeIntoFloatAPI + Send + Sync, + T::Real: DTypeIntoFloatAPI + Send + Sync, + D: DimAPI, +{ +} + +impl DeviceNumAPI for DeviceBLAS +where + T: Clone + Num + Send + Sync, + D: DimAPI, +{ +} diff --git a/crates-device/rstsr-accelerate/src/driver_impl/cblas/blas3/gemm.rs b/crates-device/rstsr-accelerate/src/driver_impl/cblas/blas3/gemm.rs new file mode 100644 index 00000000..f99cba77 --- /dev/null +++ b/crates-device/rstsr-accelerate/src/driver_impl/cblas/blas3/gemm.rs @@ -0,0 +1,88 @@ +use crate::lapack_ffi; +use crate::DeviceBLAS; +use duplicate::duplicate_item; +use num::Complex; +use rstsr_blas_traits::blas3::gemm::*; +use rstsr_common::prelude::*; + +#[duplicate_item( + T cblas_func ; + [f32] [cblas_sgemm]; + [f64] [cblas_dgemm]; +)] +impl GEMMDriverAPI for DeviceBLAS { + unsafe fn driver_gemm( + order: FlagOrder, + transa: FlagTrans, + transb: FlagTrans, + m: usize, + n: usize, + k: usize, + alpha: T, + a: *const T, + lda: usize, + b: *const T, + ldb: usize, + beta: T, + c: *mut T, + ldc: usize, + ) { + lapack_ffi::cblas::cblas_func( + order.into(), + transa.into(), + transb.into(), + m as _, + n as _, + k as _, + alpha, + a, + lda as _, + b, + ldb as _, + beta, + c, + ldc as _, + ); + } +} + +#[duplicate_item( + T cblas_func ; + [Complex] [cblas_cgemm]; + [Complex] [cblas_zgemm]; +)] +impl GEMMDriverAPI for DeviceBLAS { + unsafe fn driver_gemm( + order: FlagOrder, + transa: FlagTrans, + transb: FlagTrans, + m: usize, + n: usize, + k: usize, + alpha: T, + a: *const T, + lda: usize, + b: *const T, + ldb: usize, + beta: T, + c: *mut T, + ldc: usize, + ) { + lapack_ffi::cblas::cblas_func( + order.into(), + transa.into(), + transb.into(), + m as _, + n as _, + k as _, + &alpha as *const _ as *const _, + a as *const _, + lda as _, + b as *const _, + ldb as _, + &beta as *const _ as *const _, + c as *mut _, + ldc as _, + ); + } +} diff --git a/crates-device/rstsr-accelerate/src/driver_impl/cblas/blas3/mod.rs b/crates-device/rstsr-accelerate/src/driver_impl/cblas/blas3/mod.rs new file mode 100644 index 00000000..460bea1b --- /dev/null +++ 
b/crates-device/rstsr-accelerate/src/driver_impl/cblas/blas3/mod.rs @@ -0,0 +1,3 @@ +pub mod gemm; +pub mod syhemm; +pub mod trsm; diff --git a/crates-device/rstsr-accelerate/src/driver_impl/cblas/blas3/syhemm.rs b/crates-device/rstsr-accelerate/src/driver_impl/cblas/blas3/syhemm.rs new file mode 100644 index 00000000..a59729e4 --- /dev/null +++ b/crates-device/rstsr-accelerate/src/driver_impl/cblas/blas3/syhemm.rs @@ -0,0 +1,86 @@ +use crate::lapack_ffi; +use crate::DeviceBLAS; +use duplicate::duplicate_item; +use num::Complex; +use rstsr_blas_traits::blas3::syhemm::*; +use rstsr_common::prelude::*; + +#[duplicate_item( + T cblas_func ; + [f32] [cblas_ssymm]; + [f64] [cblas_dsymm]; +)] +impl SYHEMMDriverAPI for DeviceBLAS { + unsafe fn driver_syhemm( + order: FlagOrder, + side: FlagSide, + uplo: FlagUpLo, + m: usize, + n: usize, + alpha: T, + a: *const T, + lda: usize, + b: *const T, + ldb: usize, + beta: T, + c: *mut T, + ldc: usize, + ) { + lapack_ffi::cblas::cblas_func( + order.into(), + side.into(), + uplo.into(), + m as _, + n as _, + alpha, + a, + lda as _, + b, + ldb as _, + beta, + c, + ldc as _, + ); + } +} + +#[duplicate_item( + T cblas_func HERMI ; + [Complex] [cblas_csymm] [false]; + [Complex] [cblas_chemm] [true ]; + [Complex] [cblas_zsymm] [false]; + [Complex] [cblas_zhemm] [true ]; +)] +impl SYHEMMDriverAPI for DeviceBLAS { + unsafe fn driver_syhemm( + order: FlagOrder, + side: FlagSide, + uplo: FlagUpLo, + m: usize, + n: usize, + alpha: T, + a: *const T, + lda: usize, + b: *const T, + ldb: usize, + beta: T, + c: *mut T, + ldc: usize, + ) { + lapack_ffi::cblas::cblas_func( + order.into(), + side.into(), + uplo.into(), + m as _, + n as _, + &alpha as *const _ as *const _, + a as *const _ as *const _, + lda as _, + b as *const _ as *const _, + ldb as _, + &beta as *const _ as *const _, + c as *mut _ as *mut _, + ldc as _, + ); + } +} diff --git a/crates-device/rstsr-accelerate/src/driver_impl/cblas/blas3/trsm.rs b/crates-device/rstsr-accelerate/src/driver_impl/cblas/blas3/trsm.rs new file mode 100644 index 00000000..1e86d48f --- /dev/null +++ b/crates-device/rstsr-accelerate/src/driver_impl/cblas/blas3/trsm.rs @@ -0,0 +1,80 @@ +use crate::lapack_ffi; +use crate::DeviceBLAS; +use duplicate::duplicate_item; +use num::Complex; +use rstsr_blas_traits::blas3::trsm::*; +use rstsr_common::prelude::*; + +#[duplicate_item( + T cblas_func ; + [f32] [cblas_strsm]; + [f64] [cblas_dtrsm]; +)] +impl TRSMDriverAPI for DeviceBLAS { + unsafe fn driver_trsm( + order: FlagOrder, + side: FlagSide, + uplo: FlagUpLo, + transa: FlagTrans, + diag: FlagDiag, + m: usize, + n: usize, + alpha: T, + a: *const T, + lda: usize, + b: *mut T, + ldb: usize, + ) { + lapack_ffi::cblas::cblas_func( + order.into(), + side.into(), + uplo.into(), + transa.into(), + diag.into(), + m as _, + n as _, + alpha, + a, + lda as _, + b, + ldb as _, + ); + } +} + +#[duplicate_item( + T cblas_func ; + [Complex] [cblas_ctrsm]; + [Complex] [cblas_ztrsm]; +)] +impl TRSMDriverAPI for DeviceBLAS { + unsafe fn driver_trsm( + order: FlagOrder, + side: FlagSide, + uplo: FlagUpLo, + transa: FlagTrans, + diag: FlagDiag, + m: usize, + n: usize, + alpha: T, + a: *const T, + lda: usize, + b: *mut T, + ldb: usize, + ) { + lapack_ffi::cblas::cblas_func( + order.into(), + side.into(), + uplo.into(), + transa.into(), + diag.into(), + m as _, + n as _, + &alpha as *const _ as *const _, + a as *const _, + lda as _, + b as *mut _, + ldb as _, + ); + } +} diff --git a/crates-device/rstsr-accelerate/src/driver_impl/cblas/mod.rs 
b/crates-device/rstsr-accelerate/src/driver_impl/cblas/mod.rs new file mode 100644 index 00000000..a76b4220 --- /dev/null +++ b/crates-device/rstsr-accelerate/src/driver_impl/cblas/mod.rs @@ -0,0 +1 @@ +pub mod blas3; diff --git a/crates-device/rstsr-accelerate/src/driver_impl/lapack/eigh/mod.rs b/crates-device/rstsr-accelerate/src/driver_impl/lapack/eigh/mod.rs new file mode 100644 index 00000000..4ee9d5af --- /dev/null +++ b/crates-device/rstsr-accelerate/src/driver_impl/lapack/eigh/mod.rs @@ -0,0 +1,4 @@ +pub mod syev; +pub mod syevd; +pub mod sygv; +pub mod sygvd; diff --git a/crates-device/rstsr-accelerate/src/driver_impl/lapack/eigh/syev.rs b/crates-device/rstsr-accelerate/src/driver_impl/lapack/eigh/syev.rs new file mode 100644 index 00000000..4d40cf5b --- /dev/null +++ b/crates-device/rstsr-accelerate/src/driver_impl/lapack/eigh/syev.rs @@ -0,0 +1,192 @@ +use crate::lapack_ffi; +use crate::DeviceBLAS; +use num::complex::ComplexFloat; +use num::Complex; +use rstsr_blas_traits::prelude::*; +use rstsr_common::prelude_dev::*; +use rstsr_native_impl::prelude_dev::*; +use std::slice::from_raw_parts_mut; + +#[duplicate_item( + T func_ ; + [f32] [ssyev_]; + [f64] [dsyev_]; +)] +impl SYEVDriverAPI for DeviceBLAS { + unsafe fn driver_syev( + order: FlagOrder, + jobz: char, + uplo: FlagUpLo, + n: usize, + a: *mut T, + lda: usize, + w: *mut T, + ) -> blas_int { + use lapack_ffi::lapack::func_; + + // Query optimal working array(s) size + let mut info = 0; + let lwork = -1; + let mut work_query = 0.0; + func_(&(jobz as _), &uplo.into(), &(n as _), a, &(lda as _), w, &mut work_query, &lwork, &mut info); + if info != 0 { + return info; + } + let lwork = work_query as usize; + + // Allocate memory for work arrays + let mut work: Vec = match uninitialized_vec(lwork) { + Ok(work) => work, + Err(_) => return -1010, + }; + + if order == ColMajor { + // Call LAPACK function and adjust info + func_( + &(jobz as _), + &uplo.into(), + &(n as _), + a, + &(lda as _), + w, + work.as_mut_ptr(), + &(lwork as _), + &mut info, + ); + if info != 0 { + return info; + } + } else { + let lda_t = n.max(1); + // Transpose input matrices + let mut a_t: Vec = match uninitialized_vec(n * n) { + Ok(a_t) => a_t, + Err(_) => return -1011, + }; + let a_slice = from_raw_parts_mut(a, n * lda); + let la = Layout::new_unchecked([n, n], [lda as isize, 1], 0); + let la_t = Layout::new_unchecked([n, n], [1, lda_t as isize], 0); + orderchange_out_r2c_ix2_cpu_serial(&mut a_t, &la_t, a_slice, &la).unwrap(); + // Call LAPACK function and adjust info + func_( + &(jobz as _), + &uplo.into(), + &(n as _), + a_t.as_mut_ptr(), + &(lda_t as _), + w, + work.as_mut_ptr(), + &(lwork as _), + &mut info, + ); + if info != 0 { + return info; + } + // Transpose output matrices + orderchange_out_c2r_ix2_cpu_serial(a_slice, &la, &a_t, &la_t).unwrap(); + } + return info; + } +} + +#[duplicate_item( + T func_ ; + [Complex] [cheev_]; + [Complex] [zheev_]; +)] +impl SYEVDriverAPI for DeviceBLAS { + unsafe fn driver_syev( + order: FlagOrder, + jobz: char, + uplo: FlagUpLo, + n: usize, + a: *mut T, + lda: usize, + w: *mut ::Real, + ) -> blas_int { + use lapack_ffi::lapack::func_; + + // Allocate memory for working array(s) + let rwork_len = (3 * n - 2).max(1); + let mut rwork: Vec<::Real> = match uninitialized_vec(rwork_len) { + Ok(rwork) => rwork, + Err(_) => return -1010, + }; + + // Query optimal working array(s) size + let mut info = 0; + let lwork = -1; + let mut work_query = 0.0; + func_( + &(jobz as _), + &uplo.into(), + &(n as _), + a as *mut _, + 
&(lda as _), + w as *mut _, + &mut work_query as *mut _ as *mut _, + &lwork, + rwork.as_mut_ptr() as *mut _, + &mut info, + ); + if info != 0 { + return info; + } + let lwork = work_query as usize; + + // Allocate memory for work arrays + let mut work: Vec = match uninitialized_vec(lwork) { + Ok(work) => work, + Err(_) => return -1010, + }; + + if order == ColMajor { + // Call LAPACK function and adjust info + func_( + &(jobz as _), + &uplo.into(), + &(n as _), + a as *mut _, + &(lda as _), + w as *mut _, + work.as_mut_ptr() as *mut _, + &(lwork as _), + rwork.as_mut_ptr() as *mut _, + &mut info, + ); + if info != 0 { + return info; + } + } else { + let lda_t = n.max(1); + // Transpose input matrices + let mut a_t: Vec = match uninitialized_vec(n * n) { + Ok(a_t) => a_t, + Err(_) => return -1011, + }; + let a_slice = from_raw_parts_mut(a, n * lda); + let la = Layout::new_unchecked([n, n], [lda as isize, 1], 0); + let la_t = Layout::new_unchecked([n, n], [1, lda_t as isize], 0); + orderchange_out_r2c_ix2_cpu_serial(&mut a_t, &la_t, a_slice, &la).unwrap(); + // Call LAPACK function and adjust info + func_( + &(jobz as _), + &uplo.into(), + &(n as _), + a_t.as_mut_ptr() as *mut _, + &(lda_t as _), + w as *mut _, + work.as_mut_ptr() as *mut _, + &(lwork as _), + rwork.as_mut_ptr() as *mut _, + &mut info, + ); + if info != 0 { + return info; + } + // Transpose output matrices + orderchange_out_c2r_ix2_cpu_serial(a_slice, &la, &a_t, &la_t).unwrap(); + } + return info; + } +} diff --git a/crates-device/rstsr-accelerate/src/driver_impl/lapack/eigh/syevd.rs b/crates-device/rstsr-accelerate/src/driver_impl/lapack/eigh/syevd.rs new file mode 100644 index 00000000..8801f88a --- /dev/null +++ b/crates-device/rstsr-accelerate/src/driver_impl/lapack/eigh/syevd.rs @@ -0,0 +1,231 @@ +use crate::lapack_ffi; +use crate::DeviceBLAS; +use num::complex::ComplexFloat; +use num::Complex; +use rstsr_blas_traits::prelude::*; +use rstsr_common::prelude_dev::*; +use rstsr_native_impl::prelude_dev::*; +use std::slice::from_raw_parts_mut; + +#[duplicate_item( + T func_ ; + [f32] [ssyevd_]; + [f64] [dsyevd_]; +)] +impl SYEVDDriverAPI for DeviceBLAS { + unsafe fn driver_syevd( + order: FlagOrder, + jobz: char, + uplo: FlagUpLo, + n: usize, + a: *mut T, + lda: usize, + w: *mut T, + ) -> blas_int { + use lapack_ffi::lapack::func_; + + // Query optimal working array(s) size + let mut info = 0; + let lwork = -1; + let liwork = -1; + let mut work_query = 0.0; + let mut iwork_query = 0; + func_( + &(jobz as _), + &uplo.into(), + &(n as _), + a, + &(lda as _), + w, + &mut work_query, + &lwork, + &mut iwork_query, + &liwork, + &mut info, + ); + if info != 0 { + return info; + } + let lwork = work_query as usize; + let liwork = iwork_query as usize; + + // Allocate memory for temporary array(s) + let mut work: Vec = match uninitialized_vec(lwork) { + Ok(work) => work, + Err(_) => return -1010, + }; + let mut iwork: Vec = match uninitialized_vec(liwork) { + Ok(iwork) => iwork, + Err(_) => return -1010, + }; + + if order == ColMajor { + // Call LAPACK function and adjust info + func_( + &(jobz as _), + &uplo.into(), + &(n as _), + a, + &(lda as _), + w, + work.as_mut_ptr(), + &(lwork as _), + iwork.as_mut_ptr(), + &(liwork as _), + &mut info, + ); + if info != 0 { + return info; + } + } else { + let lda_t = n.max(1); + // Transpose input matrices + let mut a_t: Vec = match uninitialized_vec(n * n) { + Ok(a_t) => a_t, + Err(_) => return -1011, + }; + let a_slice = from_raw_parts_mut(a, n * lda); + let la = Layout::new_unchecked([n, 
n], [lda as isize, 1], 0); + let la_t = Layout::new_unchecked([n, n], [1, lda_t as isize], 0); + orderchange_out_r2c_ix2_cpu_serial(&mut a_t, &la_t, a_slice, &la).unwrap(); + // Call LAPACK function and adjust info + func_( + &(jobz as _), + &uplo.into(), + &(n as _), + a_t.as_mut_ptr(), + &(lda_t as _), + w, + work.as_mut_ptr(), + &(lwork as _), + iwork.as_mut_ptr(), + &(liwork as _), + &mut info, + ); + if info != 0 { + return info; + } + // Transpose output matrices + orderchange_out_c2r_ix2_cpu_serial(a_slice, &la, &a_t, &la_t).unwrap(); + } + return info; + } +} + +#[duplicate_item( + T func_ ; + [Complex] [cheevd_]; + [Complex] [zheevd_]; +)] +impl SYEVDDriverAPI for DeviceBLAS { + unsafe fn driver_syevd( + order: FlagOrder, + jobz: char, + uplo: FlagUpLo, + n: usize, + a: *mut T, + lda: usize, + w: *mut ::Real, + ) -> blas_int { + use lapack_ffi::lapack::func_; + + // Query optimal working array(s) size + let mut info = 0; + let lwork = -1; + let lrwork = -1; + let liwork = -1; + let mut work_query = 0.0; + let mut rwork_query = 0.0; + let mut iwork_query = 0; + func_( + &(jobz as _), + &uplo.into(), + &(n as _), + a as *mut _, + &(lda as _), + w as *mut _, + &mut work_query as *mut _ as *mut _, + &lwork, + &mut rwork_query as *mut _ as *mut _, + &lrwork, + &mut iwork_query, + &liwork, + &mut info, + ); + if info != 0 { + return info; + } + let lwork = work_query as usize; + let lrwork = rwork_query as usize; + let liwork = iwork_query as usize; + + // Allocate memory for temporary array(s) + let mut work: Vec = match uninitialized_vec(lwork) { + Ok(work) => work, + Err(_) => return -1010, + }; + let mut rwork: Vec<::Real> = match uninitialized_vec(lrwork) { + Ok(rwork) => rwork, + Err(_) => return -1010, + }; + let mut iwork: Vec = match uninitialized_vec(liwork) { + Ok(iwork) => iwork, + Err(_) => return -1010, + }; + + if order == ColMajor { + // Call LAPACK function and adjust info + func_( + &(jobz as _), + &uplo.into(), + &(n as _), + a as *mut _, + &(lda as _), + w as *mut _, + work.as_mut_ptr() as *mut _, + &(lwork as _), + rwork.as_mut_ptr() as *mut _, + &(lrwork as _), + iwork.as_mut_ptr() as *mut _, + &(liwork as _), + &mut info, + ); + if info != 0 { + return info; + } + } else { + let lda_t = n.max(1); + // Transpose input matrices + let mut a_t: Vec = match uninitialized_vec(n * n) { + Ok(a_t) => a_t, + Err(_) => return -1011, + }; + let a_slice = from_raw_parts_mut(a, n * lda); + let la = Layout::new_unchecked([n, n], [lda as isize, 1], 0); + let la_t = Layout::new_unchecked([n, n], [1, lda_t as isize], 0); + orderchange_out_r2c_ix2_cpu_serial(&mut a_t, &la_t, a_slice, &la).unwrap(); + // Call LAPACK function and adjust info + func_( + &(jobz as _), + &uplo.into(), + &(n as _), + a_t.as_mut_ptr() as *mut _, + &(lda_t as _), + w as *mut _, + work.as_mut_ptr() as *mut _, + &(lwork as _), + rwork.as_mut_ptr() as *mut _, + &(lrwork as _), + iwork.as_mut_ptr(), + &(liwork as _), + &mut info, + ); + if info != 0 { + return info; + } + // Transpose output matrices + orderchange_out_c2r_ix2_cpu_serial(a_slice, &la, &a_t, &la_t).unwrap(); + } + return info; + } +} diff --git a/crates-device/rstsr-accelerate/src/driver_impl/lapack/eigh/sygv.rs b/crates-device/rstsr-accelerate/src/driver_impl/lapack/eigh/sygv.rs new file mode 100644 index 00000000..c565098b --- /dev/null +++ b/crates-device/rstsr-accelerate/src/driver_impl/lapack/eigh/sygv.rs @@ -0,0 +1,246 @@ +use crate::lapack_ffi; +use crate::DeviceBLAS; +use num::complex::ComplexFloat; +use num::Complex; +use 
rstsr_blas_traits::prelude::*; +use rstsr_common::prelude_dev::*; +use rstsr_native_impl::prelude_dev::*; +use std::slice::from_raw_parts_mut; + +#[duplicate_item( + T func_ ; + [f32] [ssygv_]; + [f64] [dsygv_]; +)] +impl SYGVDriverAPI for DeviceBLAS { + unsafe fn driver_sygv( + order: FlagOrder, + itype: blas_int, + jobz: char, + uplo: FlagUpLo, + n: usize, + a: *mut T, + lda: usize, + b: *mut T, + ldb: usize, + w: *mut T, + ) -> blas_int { + use lapack_ffi::lapack::func_; + + // Query optimal working array(s) size + let mut info = 0; + let lwork = -1; + let mut work_query = 0.0; + func_( + &itype, + &(jobz as _), + &uplo.into(), + &(n as _), + a, + &(lda as _), + b, + &(ldb as _), + w, + &mut work_query, + &lwork, + &mut info, + ); + if info != 0 { + return info; + } + let lwork = work_query as usize; + + // Allocate memory for temporary array(s) + let mut work: Vec = match uninitialized_vec(lwork) { + Ok(work) => work, + Err(_) => return -1010, + }; + + if order == ColMajor { + // Call LAPACK function and adjust info + func_( + &itype, + &(jobz as _), + &uplo.into(), + &(n as _), + a, + &(lda as _), + b, + &(ldb as _), + w, + work.as_mut_ptr(), + &(lwork as _), + &mut info, + ); + if info != 0 { + return info; + } + } else { + let lda_t = n.max(1); + let ldb_t = n.max(1); + // Transpose input matrices + let mut a_t: Vec = match uninitialized_vec(n * n) { + Ok(a_t) => a_t, + Err(_) => return -1011, + }; + let mut b_t: Vec = match uninitialized_vec(n * n) { + Ok(b_t) => b_t, + Err(_) => return -1011, + }; + let a_slice = from_raw_parts_mut(a, n * lda); + let b_slice = from_raw_parts_mut(b, n * ldb); + let la = Layout::new_unchecked([n, n], [lda as isize, 1], 0); + let la_t = Layout::new_unchecked([n, n], [1, lda_t as isize], 0); + let lb = Layout::new_unchecked([n, n], [ldb as isize, 1], 0); + let lb_t = Layout::new_unchecked([n, n], [1, ldb_t as isize], 0); + orderchange_out_r2c_ix2_cpu_serial(&mut a_t, &la_t, a_slice, &la).unwrap(); + orderchange_out_r2c_ix2_cpu_serial(&mut b_t, &lb_t, b_slice, &lb).unwrap(); + // Call LAPACK function and adjust info + func_( + &itype, + &(jobz as _), + &uplo.into(), + &(n as _), + a_t.as_mut_ptr(), + &(lda_t as _), + b_t.as_mut_ptr(), + &(ldb_t as _), + w, + work.as_mut_ptr(), + &(lwork as _), + &mut info, + ); + if info != 0 { + return info; + } + // Transpose output matrices + orderchange_out_c2r_ix2_cpu_serial(a_slice, &la, &a_t, &la_t).unwrap(); + orderchange_out_c2r_ix2_cpu_serial(b_slice, &lb, &b_t, &lb_t).unwrap(); + } + return info; + } +} + +#[duplicate_item( + T func_ ; + [Complex] [chegv_]; + [Complex] [zhegv_]; +)] +impl SYGVDriverAPI for DeviceBLAS { + unsafe fn driver_sygv( + order: FlagOrder, + itype: blas_int, + jobz: char, + uplo: FlagUpLo, + n: usize, + a: *mut T, + lda: usize, + b: *mut T, + ldb: usize, + w: *mut ::Real, + ) -> blas_int { + use lapack_ffi::lapack::func_; + + // Allocate memory for working array(s) + let rwork_len = (3 * n - 2).max(1); + let mut rwork: Vec<::Real> = match uninitialized_vec(rwork_len) { + Ok(rwork) => rwork, + Err(_) => return -1010, + }; + + // Query optimal working array(s) size + let mut info = 0; + let lwork = -1; + let mut work_query = 0.0; + func_( + &itype, + &(jobz as _), + &uplo.into(), + &(n as _), + a as *mut _, + &(lda as _), + b as *mut _, + &(ldb as _), + w as *mut _, + &mut work_query as *mut _ as *mut _, + &lwork, + rwork.as_mut_ptr() as *mut _, + &mut info, + ); + if info != 0 { + return info; + } + let lwork = work_query as usize; + + // Allocate memory for work arrays + let mut work: 
Vec = match uninitialized_vec(lwork) { + Ok(work) => work, + Err(_) => return -1010, + }; + + if order == ColMajor { + // Call LAPACK function and adjust info + func_( + &itype, + &(jobz as _), + &uplo.into(), + &(n as _), + a as *mut _, + &(lda as _), + b as *mut _, + &(ldb as _), + w as *mut _, + work.as_mut_ptr() as *mut _, + &(lwork as _), + rwork.as_mut_ptr() as *mut _, + &mut info, + ); + if info != 0 { + return info; + } + } else { + let lda_t = n.max(1); + let ldb_t = n.max(1); + // Transpose input matrices + let mut a_t: Vec = match uninitialized_vec(n * n) { + Ok(a_t) => a_t, + Err(_) => return -1011, + }; + let mut b_t: Vec = match uninitialized_vec(n * n) { + Ok(b_t) => b_t, + Err(_) => return -1011, + }; + let a_slice = from_raw_parts_mut(a, n * lda); + let b_slice = from_raw_parts_mut(b, n * ldb); + let la = Layout::new_unchecked([n, n], [lda as isize, 1], 0); + let la_t = Layout::new_unchecked([n, n], [1, lda_t as isize], 0); + let lb = Layout::new_unchecked([n, n], [ldb as isize, 1], 0); + let lb_t = Layout::new_unchecked([n, n], [1, ldb_t as isize], 0); + orderchange_out_r2c_ix2_cpu_serial(&mut a_t, &la_t, a_slice, &la).unwrap(); + orderchange_out_r2c_ix2_cpu_serial(&mut b_t, &lb_t, b_slice, &lb).unwrap(); + // Call LAPACK function and adjust info + func_( + &itype, + &(jobz as _), + &uplo.into(), + &(n as _), + a_t.as_mut_ptr() as *mut _, + &(lda_t as _), + b_t.as_mut_ptr() as *mut _, + &(ldb_t as _), + w as *mut _, + work.as_mut_ptr() as *mut _, + &(lwork as _), + rwork.as_mut_ptr() as *mut _, + &mut info, + ); + if info != 0 { + return info; + } + // Transpose output matrices + orderchange_out_c2r_ix2_cpu_serial(a_slice, &la, &a_t, &la_t).unwrap(); + orderchange_out_c2r_ix2_cpu_serial(b_slice, &lb, &b_t, &lb_t).unwrap(); + } + return info; + } +} diff --git a/crates-device/rstsr-accelerate/src/driver_impl/lapack/eigh/sygvd.rs b/crates-device/rstsr-accelerate/src/driver_impl/lapack/eigh/sygvd.rs new file mode 100644 index 00000000..c33b793d --- /dev/null +++ b/crates-device/rstsr-accelerate/src/driver_impl/lapack/eigh/sygvd.rs @@ -0,0 +1,275 @@ +use crate::lapack_ffi; +use crate::DeviceBLAS; +use num::complex::ComplexFloat; +use num::Complex; +use rstsr_blas_traits::prelude::*; +use rstsr_common::prelude_dev::*; +use rstsr_native_impl::prelude_dev::*; +use std::slice::from_raw_parts_mut; + +#[duplicate_item( + T func_ ; + [f32] [ssygvd_]; + [f64] [dsygvd_]; +)] +impl SYGVDDriverAPI for DeviceBLAS { + unsafe fn driver_sygvd( + order: FlagOrder, + itype: blas_int, + jobz: char, + uplo: FlagUpLo, + n: usize, + a: *mut T, + lda: usize, + b: *mut T, + ldb: usize, + w: *mut T, + ) -> blas_int { + use lapack_ffi::lapack::func_; + + // Query optimal working array(s) size + let mut info = 0; + let lwork = -1; + let liwork = -1; + let mut work_query = 0.0; + let mut iwork_query = 0; + func_( + &itype, + &(jobz as _), + &uplo.into(), + &(n as _), + a, + &(lda as _), + b, + &(ldb as _), + w, + &mut work_query, + &lwork, + &mut iwork_query, + &liwork, + &mut info, + ); + if info != 0 { + return info; + } + let lwork = work_query as usize; + let liwork = iwork_query as usize; + + // Allocate memory for temporary array(s) + let mut work: Vec = match uninitialized_vec(lwork) { + Ok(work) => work, + Err(_) => return -1010, + }; + let mut iwork: Vec = match uninitialized_vec(liwork) { + Ok(iwork) => iwork, + Err(_) => return -1010, + }; + + if order == ColMajor { + // Call LAPACK function and adjust info + func_( + &itype, + &(jobz as _), + &uplo.into(), + &(n as _), + a, + &(lda as _), + 
b, + &(ldb as _), + w, + work.as_mut_ptr(), + &(lwork as _), + iwork.as_mut_ptr(), + &(liwork as _), + &mut info, + ); + if info != 0 { + return info; + } + } else { + let lda_t = n.max(1); + let ldb_t = n.max(1); + // Transpose input matrices + let mut a_t: Vec = match uninitialized_vec(n * n) { + Ok(a_t) => a_t, + Err(_) => return -1011, + }; + let mut b_t: Vec = match uninitialized_vec(n * n) { + Ok(b_t) => b_t, + Err(_) => return -1011, + }; + let a_slice = from_raw_parts_mut(a, n * lda); + let b_slice = from_raw_parts_mut(b, n * ldb); + let la = Layout::new_unchecked([n, n], [lda as isize, 1], 0); + let la_t = Layout::new_unchecked([n, n], [1, lda_t as isize], 0); + let lb = Layout::new_unchecked([n, n], [ldb as isize, 1], 0); + let lb_t = Layout::new_unchecked([n, n], [1, ldb_t as isize], 0); + orderchange_out_r2c_ix2_cpu_serial(&mut a_t, &la_t, a_slice, &la).unwrap(); + orderchange_out_r2c_ix2_cpu_serial(&mut b_t, &lb_t, b_slice, &lb).unwrap(); + // Call LAPACK function and adjust info + func_( + &itype, + &(jobz as _), + &uplo.into(), + &(n as _), + a_t.as_mut_ptr(), + &(lda_t as _), + b_t.as_mut_ptr(), + &(ldb_t as _), + w, + work.as_mut_ptr(), + &(lwork as _), + iwork.as_mut_ptr(), + &(liwork as _), + &mut info, + ); + if info != 0 { + return info; + } + // Transpose output matrices + orderchange_out_c2r_ix2_cpu_serial(a_slice, &la, &a_t, &la_t).unwrap(); + orderchange_out_c2r_ix2_cpu_serial(b_slice, &lb, &b_t, &lb_t).unwrap(); + } + return info; + } +} + +#[duplicate_item( + T func_ ; + [Complex] [chegvd_]; + [Complex] [zhegvd_]; +)] +impl SYGVDDriverAPI for DeviceBLAS { + unsafe fn driver_sygvd( + order: FlagOrder, + itype: blas_int, + jobz: char, + uplo: FlagUpLo, + n: usize, + a: *mut T, + lda: usize, + b: *mut T, + ldb: usize, + w: *mut ::Real, + ) -> blas_int { + use lapack_ffi::lapack::func_; + + // Query optimal working array(s) size + let mut info = 0; + let lwork = -1; + let lrwork = -1; + let liwork = -1; + let mut work_query = 0.0; + let mut rwork_query = 0.0; + let mut iwork_query = 0; + func_( + &itype, + &(jobz as _), + &uplo.into(), + &(n as _), + a as *mut _, + &(lda as _), + b as *mut _, + &(ldb as _), + w as *mut _, + &mut work_query as *mut _ as *mut _, + &lwork, + &mut rwork_query as *mut _ as *mut _, + &lrwork, + &mut iwork_query, + &liwork, + &mut info, + ); + if info != 0 { + return info; + } + let lwork = work_query as usize; + let lrwork = rwork_query as usize; + let liwork = iwork_query as usize; + + // Allocate memory for temporary array(s) + let mut work: Vec = match uninitialized_vec(lwork) { + Ok(work) => work, + Err(_) => return -1010, + }; + let mut rwork: Vec<::Real> = match uninitialized_vec(lrwork) { + Ok(rwork) => rwork, + Err(_) => return -1010, + }; + let mut iwork: Vec = match uninitialized_vec(liwork) { + Ok(iwork) => iwork, + Err(_) => return -1010, + }; + + if order == ColMajor { + // Call LAPACK function and adjust info + func_( + &itype, + &(jobz as _), + &uplo.into(), + &(n as _), + a as *mut _, + &(lda as _), + b as *mut _, + &(ldb as _), + w as *mut _, + work.as_mut_ptr() as *mut _, + &(lwork as _), + rwork.as_mut_ptr() as *mut _, + &(lrwork as _), + iwork.as_mut_ptr() as *mut _, + &(liwork as _), + &mut info, + ); + if info != 0 { + return info; + } + } else { + let lda_t = n.max(1); + let ldb_t = n.max(1); + // Transpose input matrices + let mut a_t: Vec = match uninitialized_vec(n * n) { + Ok(a_t) => a_t, + Err(_) => return -1011, + }; + let mut b_t: Vec = match uninitialized_vec(n * n) { + Ok(b_t) => b_t, + Err(_) => return 
-1011, + }; + let a_slice = from_raw_parts_mut(a, n * lda); + let b_slice = from_raw_parts_mut(b, n * ldb); + let la = Layout::new_unchecked([n, n], [lda as isize, 1], 0); + let la_t = Layout::new_unchecked([n, n], [1, lda_t as isize], 0); + let lb = Layout::new_unchecked([n, n], [ldb as isize, 1], 0); + let lb_t = Layout::new_unchecked([n, n], [1, ldb_t as isize], 0); + orderchange_out_r2c_ix2_cpu_serial(&mut a_t, &la_t, a_slice, &la).unwrap(); + orderchange_out_r2c_ix2_cpu_serial(&mut b_t, &lb_t, b_slice, &lb).unwrap(); + // Call LAPACK function and adjust info + func_( + &itype, + &(jobz as _), + &uplo.into(), + &(n as _), + a_t.as_mut_ptr() as *mut _, + &(lda_t as _), + b_t.as_mut_ptr() as *mut _, + &(ldb_t as _), + w as *mut _, + work.as_mut_ptr() as *mut _, + &(lwork as _), + rwork.as_mut_ptr() as *mut _, + &(lrwork as _), + iwork.as_mut_ptr(), + &(liwork as _), + &mut info, + ); + if info != 0 { + return info; + } + // Transpose output matrices + orderchange_out_c2r_ix2_cpu_serial(a_slice, &la, &a_t, &la_t).unwrap(); + orderchange_out_c2r_ix2_cpu_serial(b_slice, &lb, &b_t, &lb_t).unwrap(); + } + return info; + } +} diff --git a/crates-device/rstsr-accelerate/src/driver_impl/lapack/mod.rs b/crates-device/rstsr-accelerate/src/driver_impl/lapack/mod.rs new file mode 100644 index 00000000..6eac0b5f --- /dev/null +++ b/crates-device/rstsr-accelerate/src/driver_impl/lapack/mod.rs @@ -0,0 +1,3 @@ +pub mod eigh; +pub mod solve; +pub mod svd; diff --git a/crates-device/rstsr-accelerate/src/driver_impl/lapack/solve/gesv.rs b/crates-device/rstsr-accelerate/src/driver_impl/lapack/solve/gesv.rs new file mode 100644 index 00000000..05d876d8 --- /dev/null +++ b/crates-device/rstsr-accelerate/src/driver_impl/lapack/solve/gesv.rs @@ -0,0 +1,141 @@ +use crate::lapack_ffi; +use crate::DeviceBLAS; +use num::Complex; +use rstsr_blas_traits::prelude::*; +use rstsr_common::prelude_dev::*; +use rstsr_native_impl::prelude_dev::*; +use std::slice::from_raw_parts_mut; + +#[duplicate_item( + T func_ ; + [f32] [sgesv_]; + [f64] [dgesv_]; +)] +impl GESVDriverAPI for DeviceBLAS { + unsafe fn driver_gesv( + order: FlagOrder, + n: usize, + nrhs: usize, + a: *mut T, + lda: usize, + ipiv: *mut blas_int, + b: *mut T, + ldb: usize, + ) -> blas_int { + use lapack_ffi::lapack::func_; + + let mut info = 0; + + if order == ColMajor { + func_(&(n as _), &(nrhs as _), a, &(lda as _), ipiv, b, &(ldb as _), &mut info); + if info != 0 { + return info; + } + } else { + let lda_t = n.max(1); + let ldb_t = n.max(1); + // Transpose input matrices + let mut a_t: Vec = match uninitialized_vec(n * n) { + Ok(a_t) => a_t, + Err(_) => return -1011, + }; + let mut b_t: Vec = match uninitialized_vec(n * nrhs) { + Ok(b_t) => b_t, + Err(_) => return -1011, + }; + let a_slice = from_raw_parts_mut(a, n * lda); + let b_slice = from_raw_parts_mut(b, n * ldb); + let la = Layout::new_unchecked([n, n], [lda as isize, 1], 0); + let lb = Layout::new_unchecked([n, nrhs], [ldb as isize, 1], 0); + let la_t = Layout::new_unchecked([n, n], [1, lda_t as isize], 0); + let lb_t = Layout::new_unchecked([n, nrhs], [1, ldb_t as isize], 0); + orderchange_out_r2c_ix2_cpu_serial(&mut a_t, &la_t, a_slice, &la).unwrap(); + orderchange_out_r2c_ix2_cpu_serial(&mut b_t, &lb_t, b_slice, &lb).unwrap(); + // Call LAPACK function + func_( + &(n as _), + &(nrhs as _), + a_t.as_mut_ptr(), + &(lda_t as _), + ipiv, + b_t.as_mut_ptr(), + &(ldb_t as _), + &mut info, + ); + if info != 0 { + return info; + } + // Transpose output matrices + 
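+ // (copy the column-major LU factors in a_t and the solution in b_t back into the caller's row-major buffers)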
orderchange_out_c2r_ix2_cpu_serial(a_slice, &la, &a_t, &la_t).unwrap(); + orderchange_out_c2r_ix2_cpu_serial(b_slice, &lb, &b_t, &lb_t).unwrap(); + } + return info; + } +} + +#[duplicate_item( + T func_ ; + [Complex] [cgesv_]; + [Complex] [zgesv_]; +)] +impl GESVDriverAPI for DeviceBLAS { + unsafe fn driver_gesv( + order: FlagOrder, + n: usize, + nrhs: usize, + a: *mut T, + lda: usize, + ipiv: *mut blas_int, + b: *mut T, + ldb: usize, + ) -> blas_int { + use lapack_ffi::lapack::func_; + + let mut info = 0; + + if order == ColMajor { + func_(&(n as _), &(nrhs as _), a as *mut _, &(lda as _), ipiv, b as *mut _, &(ldb as _), &mut info); + if info != 0 { + return info; + } + } else { + let lda_t = n.max(1); + let ldb_t = n.max(1); + // Transpose input matrices + let mut a_t: Vec = match uninitialized_vec(n * n) { + Ok(a_t) => a_t, + Err(_) => return -1011, + }; + let mut b_t: Vec = match uninitialized_vec(n * nrhs) { + Ok(b_t) => b_t, + Err(_) => return -1011, + }; + let a_slice = from_raw_parts_mut(a, n * lda); + let b_slice = from_raw_parts_mut(b, n * ldb); + let la = Layout::new_unchecked([n, n], [lda as isize, 1], 0); + let lb = Layout::new_unchecked([n, nrhs], [ldb as isize, 1], 0); + let la_t = Layout::new_unchecked([n, n], [1, lda_t as isize], 0); + let lb_t = Layout::new_unchecked([n, nrhs], [1, ldb_t as isize], 0); + orderchange_out_r2c_ix2_cpu_serial(&mut a_t, &la_t, a_slice, &la).unwrap(); + orderchange_out_r2c_ix2_cpu_serial(&mut b_t, &lb_t, b_slice, &lb).unwrap(); + // Call LAPACK function + func_( + &(n as _), + &(nrhs as _), + a_t.as_mut_ptr() as *mut _, + &(lda_t as _), + ipiv, + b_t.as_mut_ptr() as *mut _, + &(ldb_t as _), + &mut info, + ); + if info != 0 { + return info; + } + // Transpose output matrices + orderchange_out_c2r_ix2_cpu_serial(a_slice, &la, &a_t, &la_t).unwrap(); + orderchange_out_c2r_ix2_cpu_serial(b_slice, &lb, &b_t, &lb_t).unwrap(); + } + return info; + } +} diff --git a/crates-device/rstsr-accelerate/src/driver_impl/lapack/solve/getrf.rs b/crates-device/rstsr-accelerate/src/driver_impl/lapack/solve/getrf.rs new file mode 100644 index 00000000..3b5683fc --- /dev/null +++ b/crates-device/rstsr-accelerate/src/driver_impl/lapack/solve/getrf.rs @@ -0,0 +1,101 @@ +use crate::lapack_ffi; +use crate::DeviceBLAS; +use num::Complex; +use rstsr_blas_traits::prelude::*; +use rstsr_common::prelude_dev::*; +use rstsr_native_impl::prelude_dev::*; +use std::slice::from_raw_parts_mut; + +#[duplicate_item( + T func_ ; + [f32] [sgetrf_]; + [f64] [dgetrf_]; +)] +impl GETRFDriverAPI for DeviceBLAS { + unsafe fn driver_getrf( + order: FlagOrder, + m: usize, + n: usize, + a: *mut T, + lda: usize, + ipiv: *mut blas_int, + ) -> blas_int { + use lapack_ffi::lapack::func_; + + let mut info = 0; + + if order == ColMajor { + // Call LAPACK function and adjust info + func_(&(m as _), &(n as _), a, &(lda as _), ipiv, &mut info); + if info != 0 { + return info; + } + } else { + let lda_t = m.max(1); + // Transpose input matrices + let mut a_t: Vec = match uninitialized_vec(m * n) { + Ok(a_t) => a_t, + Err(_) => return -1011, + }; + let a_slice = from_raw_parts_mut(a, m * lda); + let la = Layout::new_unchecked([m, n], [lda as isize, 1], 0); + let la_t = Layout::new_unchecked([m, n], [1, lda_t as isize], 0); + orderchange_out_r2c_ix2_cpu_serial(&mut a_t, &la_t, a_slice, &la).unwrap(); + // Call LAPACK function and adjust info + func_(&(m as _), &(n as _), a_t.as_mut_ptr(), &(lda_t as _), ipiv, &mut info); + if info != 0 { + return info; + } + // Transpose output matrices + 
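+ // (copy the column-major LU factors in a_t back into the caller's row-major buffer)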
orderchange_out_c2r_ix2_cpu_serial(a_slice, &la, &a_t, &la_t).unwrap(); + } + return info; + } +} + +#[duplicate_item( + T func_ ; + [Complex] [cgetrf_]; + [Complex] [zgetrf_]; +)] +impl GETRFDriverAPI for DeviceBLAS { + unsafe fn driver_getrf( + order: FlagOrder, + m: usize, + n: usize, + a: *mut T, + lda: usize, + ipiv: *mut blas_int, + ) -> blas_int { + use lapack_ffi::lapack::func_; + + let mut info = 0; + + if order == ColMajor { + // Call LAPACK function and adjust info + func_(&(m as _), &(n as _), a as *mut _, &(lda as _), ipiv, &mut info); + if info != 0 { + return info; + } + } else { + let lda_t = m.max(1); + // Transpose input matrices + let mut a_t: Vec = match uninitialized_vec(m * n) { + Ok(a_t) => a_t, + Err(_) => return -1011, + }; + let a_slice = from_raw_parts_mut(a, m * lda); + let la = Layout::new_unchecked([m, n], [lda as isize, 1], 0); + let la_t = Layout::new_unchecked([m, n], [1, lda_t as isize], 0); + orderchange_out_r2c_ix2_cpu_serial(&mut a_t, &la_t, a_slice, &la).unwrap(); + // Call LAPACK function and adjust info + func_(&(m as _), &(n as _), a_t.as_mut_ptr() as *mut _, &(lda_t as _), ipiv, &mut info); + if info != 0 { + return info; + } + // Transpose output matrices + orderchange_out_c2r_ix2_cpu_serial(a_slice, &la, &a_t, &la_t).unwrap(); + } + return info; + } +} diff --git a/crates-device/rstsr-accelerate/src/driver_impl/lapack/solve/getri.rs b/crates-device/rstsr-accelerate/src/driver_impl/lapack/solve/getri.rs new file mode 100644 index 00000000..a854e664 --- /dev/null +++ b/crates-device/rstsr-accelerate/src/driver_impl/lapack/solve/getri.rs @@ -0,0 +1,124 @@ +use crate::lapack_ffi; +use crate::DeviceBLAS; +use num::Complex; +use rstsr_blas_traits::prelude::*; +use rstsr_common::prelude_dev::*; + +use rstsr_native_impl::prelude_dev::*; +use std::slice::from_raw_parts_mut; + +#[duplicate_item( + T func_ ; + [f32] [sgetri_]; + [f64] [dgetri_]; +)] +impl GETRIDriverAPI for DeviceBLAS { + unsafe fn driver_getri(order: FlagOrder, n: usize, a: *mut T, lda: usize, ipiv: *mut blas_int) -> blas_int { + use lapack_ffi::lapack::func_; + + // Query optimal working array(s) size + let mut info = 0; + let lwork = -1; + let mut work_query = 0.0; + func_(&(n as _), a, &(lda as _), ipiv, &mut work_query, &lwork, &mut info); + if info != 0 { + return info; + } + let lwork = work_query as usize; + + // Allocate memory for work arrays + let mut work: Vec = match uninitialized_vec(lwork) { + Ok(work) => work, + Err(_) => return -1010, + }; + + if order == ColMajor { + // Call LAPACK function and adjust info + func_(&(n as _), a, &(lda as _), ipiv, work.as_mut_ptr(), &(lwork as _), &mut info); + if info != 0 { + return info; + } + } else { + let lda_t = n.max(1); + // Transpose input matrices + let mut a_t: Vec = match uninitialized_vec(n * n) { + Ok(a_t) => a_t, + Err(_) => return -1011, + }; + let a_slice = from_raw_parts_mut(a, n * lda); + let la = Layout::new_unchecked([n, n], [lda as isize, 1], 0); + let la_t = Layout::new_unchecked([n, n], [1, lda_t as isize], 0); + orderchange_out_r2c_ix2_cpu_serial(&mut a_t, &la_t, a_slice, &la).unwrap(); + // Call LAPACK function and adjust info + func_(&(n as _), a_t.as_mut_ptr(), &(lda_t as _), ipiv, work.as_mut_ptr(), &(lwork as _), &mut info); + if info != 0 { + return info; + } + // Transpose output matrices + orderchange_out_c2r_ix2_cpu_serial(a_slice, &la, &a_t, &la_t).unwrap(); + } + return info; + } +} + +#[duplicate_item( + T func_ ; + [Complex] [cgetri_]; + [Complex] [zgetri_]; +)] +impl GETRIDriverAPI for DeviceBLAS { + 
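+    // The complex drivers repeat the workspace-query pattern of the real-typed
+    // impl above; the optimal lwork is read from the real part of the returned
+    // work element.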
unsafe fn driver_getri(order: FlagOrder, n: usize, a: *mut T, lda: usize, ipiv: *mut blas_int) -> blas_int { + use lapack_ffi::lapack::func_; + + // Query optimal working array(s) size + let mut info = 0; + let lwork = -1; + let mut work_query: T = num::zero(); + func_(&(n as _), a as *mut _, &(lda as _), ipiv, &mut work_query as *mut _ as *mut _, &lwork, &mut info); + if info != 0 { + return info; + } + let lwork = work_query.re as usize; + + // Allocate memory for work arrays + let mut work: Vec = match uninitialized_vec(lwork) { + Ok(work) => work, + Err(_) => return -1010, + }; + + if order == ColMajor { + // Call LAPACK function and adjust info + func_(&(n as _), a as *mut _, &(lda as _), ipiv, work.as_mut_ptr() as *mut _, &(lwork as _), &mut info); + if info != 0 { + return info; + } + } else { + let lda_t = n.max(1); + // Transpose input matrices + let mut a_t: Vec = match uninitialized_vec(n * n) { + Ok(a_t) => a_t, + Err(_) => return -1011, + }; + let a_slice = from_raw_parts_mut(a, n * lda); + let la = Layout::new_unchecked([n, n], [lda as isize, 1], 0); + let la_t = Layout::new_unchecked([n, n], [1, lda_t as isize], 0); + orderchange_out_r2c_ix2_cpu_serial(&mut a_t, &la_t, a_slice, &la).unwrap(); + // Call LAPACK function and adjust info + func_( + &(n as _), + a_t.as_mut_ptr() as *mut _, + &(lda_t as _), + ipiv, + work.as_mut_ptr() as *mut _, + &(lwork as _), + &mut info, + ); + if info != 0 { + return info; + } + // Transpose output matrices + orderchange_out_c2r_ix2_cpu_serial(a_slice, &la, &a_t, &la_t).unwrap(); + } + return info; + } +} diff --git a/crates-device/rstsr-accelerate/src/driver_impl/lapack/solve/mod.rs b/crates-device/rstsr-accelerate/src/driver_impl/lapack/solve/mod.rs new file mode 100644 index 00000000..4c83e6b9 --- /dev/null +++ b/crates-device/rstsr-accelerate/src/driver_impl/lapack/solve/mod.rs @@ -0,0 +1,5 @@ +pub mod gesv; +pub mod getrf; +pub mod getri; +pub mod potrf; +pub mod sysv; diff --git a/crates-device/rstsr-accelerate/src/driver_impl/lapack/solve/potrf.rs b/crates-device/rstsr-accelerate/src/driver_impl/lapack/solve/potrf.rs new file mode 100644 index 00000000..9c31b212 --- /dev/null +++ b/crates-device/rstsr-accelerate/src/driver_impl/lapack/solve/potrf.rs @@ -0,0 +1,87 @@ +use crate::lapack_ffi; +use crate::DeviceBLAS; +use num::Complex; +use rstsr_blas_traits::prelude::*; +use rstsr_common::prelude_dev::*; +use rstsr_native_impl::prelude_dev::*; +use std::slice::from_raw_parts_mut; + +#[duplicate_item( + T func_ ; + [f32] [spotrf_]; + [f64] [dpotrf_]; +)] +impl POTRFDriverAPI for DeviceBLAS { + unsafe fn driver_potrf(order: FlagOrder, uplo: FlagUpLo, n: usize, a: *mut T, lda: usize) -> blas_int { + use lapack_ffi::lapack::func_; + + let mut info = 0; + + if order == ColMajor { + // Call LAPACK function and adjust info + func_(&uplo.into(), &(n as _), a, &(lda as _), &mut info); + if info != 0 { + return info; + } + } else { + let lda_t = n.max(1); + // Transpose input matrices + let mut a_t: Vec = match uninitialized_vec(n * n) { + Ok(a_t) => a_t, + Err(_) => return -1011, + }; + let a_slice = from_raw_parts_mut(a, n * lda); + let la = Layout::new_unchecked([n, n], [lda as isize, 1], 0); + let la_t = Layout::new_unchecked([n, n], [1, lda_t as isize], 0); + orderchange_out_r2c_ix2_cpu_serial(&mut a_t, &la_t, a_slice, &la).unwrap(); + // Call LAPACK function and adjust info + func_(&uplo.into(), &(n as _), a_t.as_mut_ptr(), &(lda_t as _), &mut info); + if info != 0 { + return info; + } + // Transpose output matrices + 
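+            // potrf only touches the `uplo` triangle; the copy-back still moves
+            // the full n x n buffer, so the untouched triangle keeps its original
+            // values from the input matrix.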
orderchange_out_c2r_ix2_cpu_serial(a_slice, &la, &a_t, &la_t).unwrap(); + } + return info; + } +} + +#[duplicate_item( + T func_ ; + [Complex] [cpotrf_]; + [Complex] [zpotrf_]; +)] +impl POTRFDriverAPI for DeviceBLAS { + unsafe fn driver_potrf(order: FlagOrder, uplo: FlagUpLo, n: usize, a: *mut T, lda: usize) -> blas_int { + use lapack_ffi::lapack::func_; + + let mut info = 0; + + if order == ColMajor { + // Call LAPACK function and adjust info + func_(&uplo.into(), &(n as _), a as *mut _, &(lda as _), &mut info); + if info != 0 { + return info; + } + } else { + let lda_t = n.max(1); + // Transpose input matrices + let mut a_t: Vec = match uninitialized_vec(n * n) { + Ok(a_t) => a_t, + Err(_) => return -1011, + }; + let a_slice = from_raw_parts_mut(a, n * lda); + let la = Layout::new_unchecked([n, n], [lda as isize, 1], 0); + let la_t = Layout::new_unchecked([n, n], [1, lda_t as isize], 0); + orderchange_out_r2c_ix2_cpu_serial(&mut a_t, &la_t, a_slice, &la).unwrap(); + // Call LAPACK function and adjust info + func_(&uplo.into(), &(n as _), a_t.as_mut_ptr() as *mut _, &(lda_t as _), &mut info); + if info != 0 { + return info; + } + // Transpose output matrices + orderchange_out_c2r_ix2_cpu_serial(a_slice, &la, &a_t, &la_t).unwrap(); + } + return info; + } +} diff --git a/crates-device/rstsr-accelerate/src/driver_impl/lapack/solve/sysv.rs b/crates-device/rstsr-accelerate/src/driver_impl/lapack/solve/sysv.rs new file mode 100644 index 00000000..5092db47 --- /dev/null +++ b/crates-device/rstsr-accelerate/src/driver_impl/lapack/solve/sysv.rs @@ -0,0 +1,230 @@ +use crate::lapack_ffi; +use crate::DeviceBLAS; +use num::Complex; +use rstsr_blas_traits::prelude::*; +use rstsr_common::prelude_dev::*; + +use rstsr_native_impl::prelude_dev::*; +use std::slice::from_raw_parts_mut; + +#[duplicate_item( + T func_ ; + [f32] [ssysv_]; + [f64] [dsysv_]; +)] +impl SYSVDriverAPI for DeviceBLAS { + unsafe fn driver_sysv( + order: FlagOrder, + uplo: FlagUpLo, + n: usize, + nrhs: usize, + a: *mut T, + lda: usize, + ipiv: *mut blas_int, + b: *mut T, + ldb: usize, + ) -> blas_int { + use lapack_ffi::lapack::func_; + + // Query optimal working array(s) size + let mut info = 0; + let lwork = -1; + let mut work_query = 0.0; + func_( + &uplo.into(), + &(n as _), + &(nrhs as _), + a, + &(n as _), + ipiv, + b, + &(n as _), + &mut work_query, + &lwork, + &mut info, + ); + if info != 0 { + return info; + } + let lwork = work_query as usize; + + // Allocate memory for work arrays + let mut work: Vec = match uninitialized_vec(lwork) { + Ok(work) => work, + Err(_) => return -1010, + }; + + if order == ColMajor { + // Call LAPACK function and adjust info + func_( + &uplo.into(), + &(n as _), + &(nrhs as _), + a, + &(lda as _), + ipiv, + b, + &(ldb as _), + work.as_mut_ptr(), + &(lwork as _), + &mut info, + ); + if info != 0 { + return info; + } + } else { + let lda_t = n.max(1); + let ldb_t = n.max(1); + // Transpose input matrices + let mut a_t: Vec = match uninitialized_vec(n * n) { + Ok(a_t) => a_t, + Err(_) => return -1011, + }; + let mut b_t: Vec = match uninitialized_vec(n * nrhs) { + Ok(b_t) => b_t, + Err(_) => return -1011, + }; + let a_slice = from_raw_parts_mut(a, n * lda); + let b_slice = from_raw_parts_mut(b, n * ldb); + let la = Layout::new_unchecked([n, n], [lda as isize, 1], 0); + let la_t = Layout::new_unchecked([n, n], [1, lda_t as isize], 0); + let lb = Layout::new_unchecked([n, nrhs], [ldb as isize, 1], 0); + let lb_t = Layout::new_unchecked([n, nrhs], [1, ldb_t as isize], 0); + 
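+            // Pack the row-major A (n x n) and B (n x nrhs) into the freshly
+            // allocated column-major scratch buffers before the Fortran call.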
orderchange_out_r2c_ix2_cpu_serial(&mut a_t, &la_t, a_slice, &la).unwrap(); + orderchange_out_r2c_ix2_cpu_serial(&mut b_t, &lb_t, b_slice, &lb).unwrap(); + // Call LAPACK function and adjust info + func_( + &uplo.into(), + &(n as _), + &(nrhs as _), + a_t.as_mut_ptr(), + &(lda_t as _), + ipiv, + b_t.as_mut_ptr(), + &(ldb_t as _), + work.as_mut_ptr(), + &(lwork as _), + &mut info, + ); + if info != 0 { + return info; + } + // Transpose output matrices + orderchange_out_c2r_ix2_cpu_serial(a_slice, &la, &a_t, &la_t).unwrap(); + orderchange_out_c2r_ix2_cpu_serial(b_slice, &lb, &b_t, &lb_t).unwrap(); + } + return info; + } +} + +#[duplicate_item( + T func_ HERMI ; + [Complex] [csysv_] [false]; + [Complex] [chesv_] [true ]; + [Complex] [zsysv_] [false]; + [Complex] [zhesv_] [true ]; +)] +impl SYSVDriverAPI for DeviceBLAS { + unsafe fn driver_sysv( + order: FlagOrder, + uplo: FlagUpLo, + n: usize, + nrhs: usize, + a: *mut T, + lda: usize, + ipiv: *mut blas_int, + b: *mut T, + ldb: usize, + ) -> blas_int { + use lapack_ffi::lapack::func_; + + // Query optimal working array(s) size + let mut info = 0; + let lwork = -1; + let mut work_query = 0.0; + func_( + &uplo.into(), + &(n as _), + &(nrhs as _), + a as *mut _, + &(n as _), + ipiv, + b as *mut _, + &(n as _), + &mut work_query as *mut _ as *mut _, + &lwork, + &mut info, + ); + if info != 0 { + return info; + } + let lwork = work_query as usize; + + // Allocate memory for work arrays + let mut work: Vec = match uninitialized_vec(lwork) { + Ok(work) => work, + Err(_) => return -1010, + }; + + if order == ColMajor { + // Call LAPACK function and adjust info + func_( + &uplo.into(), + &(n as _), + &(nrhs as _), + a as *mut _, + &(lda as _), + ipiv, + b as *mut _, + &(ldb as _), + work.as_mut_ptr() as *mut _, + &(lwork as _), + &mut info, + ); + if info != 0 { + return info; + } + } else { + let lda_t = n.max(1); + let ldb_t = n.max(1); + // Transpose input matrices + let mut a_t: Vec = match uninitialized_vec(n * n) { + Ok(a_t) => a_t, + Err(_) => return -1011, + }; + let mut b_t: Vec = match uninitialized_vec(n * nrhs) { + Ok(b_t) => b_t, + Err(_) => return -1011, + }; + let a_slice = from_raw_parts_mut(a, n * lda); + let b_slice = from_raw_parts_mut(b, n * ldb); + let la = Layout::new_unchecked([n, n], [lda as isize, 1], 0); + let la_t = Layout::new_unchecked([n, n], [1, lda_t as isize], 0); + let lb = Layout::new_unchecked([n, nrhs], [ldb as isize, 1], 0); + let lb_t = Layout::new_unchecked([n, nrhs], [1, ldb_t as isize], 0); + orderchange_out_r2c_ix2_cpu_serial(&mut a_t, &la_t, a_slice, &la).unwrap(); + orderchange_out_r2c_ix2_cpu_serial(&mut b_t, &lb_t, b_slice, &lb).unwrap(); + // Call LAPACK function and adjust info + func_( + &uplo.into(), + &(n as _), + &(nrhs as _), + a_t.as_mut_ptr() as *mut _, + &(lda_t as _), + ipiv, + b_t.as_mut_ptr() as *mut _, + &(ldb_t as _), + work.as_mut_ptr() as *mut _, + &(lwork as _), + &mut info, + ); + if info != 0 { + return info; + } + // Transpose output matrices + orderchange_out_c2r_ix2_cpu_serial(a_slice, &la, &a_t, &la_t).unwrap(); + orderchange_out_c2r_ix2_cpu_serial(b_slice, &lb, &b_t, &lb_t).unwrap(); + } + return info; + } +} diff --git a/crates-device/rstsr-accelerate/src/driver_impl/lapack/svd/gesdd.rs b/crates-device/rstsr-accelerate/src/driver_impl/lapack/svd/gesdd.rs new file mode 100644 index 00000000..4c7d66e9 --- /dev/null +++ b/crates-device/rstsr-accelerate/src/driver_impl/lapack/svd/gesdd.rs @@ -0,0 +1,359 @@ +use crate::lapack_ffi; +use crate::DeviceBLAS; +use 
num::complex::ComplexFloat; +use num::Complex; +use rstsr_blas_traits::prelude::*; +use rstsr_common::prelude_dev::*; + +use rstsr_native_impl::prelude_dev::*; +use std::slice::from_raw_parts_mut; + +#[duplicate_item( + T func_ ; + [f32] [sgesdd_]; + [f64] [dgesdd_]; +)] +impl GESDDDriverAPI for DeviceBLAS { + unsafe fn driver_gesdd( + order: FlagOrder, + jobz: char, + m: usize, + n: usize, + a: *mut T, + lda: usize, + s: *mut T, + u: *mut T, + ldu: usize, + vt: *mut T, + ldvt: usize, + ) -> blas_int { + use lapack_ffi::lapack::func_; + + // Allocate memory for temporary array(s) + let liwork = 8 * m.min(n); + let mut iwork: Vec = match uninitialized_vec(liwork) { + Ok(iwork) => iwork, + Err(_) => return -1010, + }; + + // Query optimal working array(s) size + let mut info = 0; + let lwork = -1; + let mut work_query = 0.0; + func_( + &(jobz as _), + &(m as _), + &(n as _), + a, + &(m.max(n) as _), + s, + u, + &(m.max(n) as _), + vt, + &(m.max(n) as _), + &mut work_query, + &lwork, + iwork.as_mut_ptr(), + &mut info, + ); + if info != 0 { + return info; + } + let lwork = work_query as usize; + + // Allocate memory for temporary array(s) + let mut work: Vec = match uninitialized_vec(lwork) { + Ok(work) => work, + Err(_) => return -1010, + }; + + if order == ColMajor { + // Call LAPACK function and adjust info + func_( + &(jobz as _), + &(m as _), + &(n as _), + a, + &(lda as _), + s, + u, + &(ldu as _), + vt, + &(ldvt as _), + work.as_mut_ptr(), + &(lwork as _), + iwork.as_mut_ptr(), + &mut info, + ); + if info != 0 { + return info; + } + } else { + let lda_t = m.max(1); + let nrows_u = if jobz == 'A' || jobz == 'S' || (jobz == 'O' && m < n) { m } else { 1 }; + let ncols_u = if jobz == 'A' || (jobz == 'O' && m < n) { + m + } else if jobz == 'S' { + m.min(n) + } else { + 1 + }; + let nrows_vt = if jobz == 'A' || (jobz == 'O' && m >= n) { + n + } else if jobz == 'S' { + m.min(n) + } else { + 1 + }; + let ldu_t = nrows_u.max(1); + let ldvt_t = nrows_vt.max(1); + + // Transpose input matrices + let mut a_t: Vec = match uninitialized_vec(m * n) { + Ok(a_t) => a_t, + Err(_) => return -1011, + }; + let a_slice = from_raw_parts_mut(a, m * lda); + let la = Layout::new_unchecked([m, n], [lda as isize, 1], 0); + let la_t = Layout::new_unchecked([m, n], [1, lda_t as isize], 0); + orderchange_out_r2c_ix2_cpu_serial(&mut a_t, &la_t, a_slice, &la).unwrap(); + + let mut u_t = if jobz == 'A' || jobz == 'S' || (jobz == 'O' && m < n) { + match uninitialized_vec(nrows_u * ncols_u) { + Ok(u_t) => Some(u_t), + Err(_) => return -1011, + } + } else { + None + }; + + let mut vt_t = if jobz == 'A' || jobz == 'S' || (jobz == 'O' && m >= n) { + match uninitialized_vec(nrows_vt * n) { + Ok(vt_t) => Some(vt_t), + Err(_) => return -1011, + } + } else { + None + }; + + // Call LAPACK function and adjust info + func_( + &(jobz as _), + &(m as _), + &(n as _), + a_t.as_mut_ptr(), + &(lda_t as _), + s, + u_t.as_mut().map_or(std::ptr::null_mut(), |v| v.as_mut_ptr()), + &(ldu_t as _), + vt_t.as_mut().map_or(std::ptr::null_mut(), |v| v.as_mut_ptr()), + &(ldvt_t as _), + work.as_mut_ptr(), + &(lwork as _), + iwork.as_mut_ptr(), + &mut info, + ); + if info != 0 { + return info; + } + + // Transpose output matrices + orderchange_out_c2r_ix2_cpu_serial(a_slice, &la, &a_t, &la_t).unwrap(); + + if let Some(u_t) = u_t { + let u_slice = from_raw_parts_mut(u, nrows_u * ldu); + let lu = Layout::new_unchecked([nrows_u, ncols_u], [ldu as isize, 1], 0); + let lu_t = Layout::new_unchecked([nrows_u, ncols_u], [1, ldu_t as isize], 0); + 
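+                // gesdd produced u_t in column-major order; copy it into the
+                // caller's row-major `u` buffer using the layouts built above.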
orderchange_out_c2r_ix2_cpu_serial(u_slice, &lu, &u_t, &lu_t).unwrap(); + } + + if let Some(vt_t) = vt_t { + let vt_slice = from_raw_parts_mut(vt, nrows_vt * ldvt); + let lvt = Layout::new_unchecked([nrows_vt, n], [ldvt as isize, 1], 0); + let lvt_t = Layout::new_unchecked([nrows_vt, n], [1, ldvt_t as isize], 0); + orderchange_out_c2r_ix2_cpu_serial(vt_slice, &lvt, &vt_t, &lvt_t).unwrap(); + } + } + return info; + } +} + +#[duplicate_item( + T func_ ; + [Complex] [cgesdd_]; + [Complex] [zgesdd_]; +)] +impl GESDDDriverAPI for DeviceBLAS { + unsafe fn driver_gesdd( + order: FlagOrder, + jobz: char, + m: usize, + n: usize, + a: *mut T, + lda: usize, + s: *mut ::Real, + u: *mut T, + ldu: usize, + vt: *mut T, + ldvt: usize, + ) -> blas_int { + use lapack_ffi::lapack::func_; + + // Allocate memory for temporary array(s) + let liwork = 8 * m.min(n); + let mut iwork: Vec = match uninitialized_vec(liwork) { + Ok(iwork) => iwork, + Err(_) => return -1010, + }; + + // Query optimal working array(s) size + let mut info = 0; + let lwork = -1; + let lrwork = + if jobz == 'N' { 7 * m.min(n) } else { m.min(n) * (5 * m.min(n) + 7).max(2 * m.max(n) + 2 * m.min(n) + 1) }; + let mut work_query = Complex::new(0.0, 0.0); + let mut rwork_query = 0.0; + func_( + &(jobz as _), + &(m as _), + &(n as _), + a as *mut _, + &(m.max(n) as _), + s as *mut _, + u as *mut _, + &(m.max(n) as _), + vt as *mut _, + &(m.max(n) as _), + &mut work_query as *mut _ as *mut _, + &lwork, + &mut rwork_query as *mut _ as *mut _, + iwork.as_mut_ptr(), + &mut info, + ); + if info != 0 { + return info; + } + let lwork = work_query.re as usize; + + // Allocate memory for temporary array(s) + let mut work: Vec = match uninitialized_vec(lwork) { + Ok(work) => work, + Err(_) => return -1010, + }; + let mut rwork: Vec<::Real> = match uninitialized_vec(lrwork) { + Ok(rwork) => rwork, + Err(_) => return -1010, + }; + + if order == ColMajor { + // Call LAPACK function and adjust info + func_( + &(jobz as _), + &(m as _), + &(n as _), + a as *mut _, + &(lda as _), + s as *mut _, + u as *mut _, + &(ldu as _), + vt as *mut _, + &(ldvt as _), + work.as_mut_ptr() as *mut _, + &(lwork as _), + rwork.as_mut_ptr() as *mut _, + iwork.as_mut_ptr(), + &mut info, + ); + if info != 0 { + return info; + } + } else { + let lda_t = m.max(1); + let nrows_u = if jobz == 'A' || jobz == 'S' || (jobz == 'O' && m < n) { m } else { 1 }; + let ncols_u = if jobz == 'A' || (jobz == 'O' && m < n) { + m + } else if jobz == 'S' { + m.min(n) + } else { + 1 + }; + let nrows_vt = if jobz == 'A' || (jobz == 'O' && m >= n) { + n + } else if jobz == 'S' { + m.min(n) + } else { + 1 + }; + let ldu_t = nrows_u.max(1); + let ldvt_t = nrows_vt.max(1); + + // Transpose input matrices + let mut a_t: Vec = match uninitialized_vec(m * n) { + Ok(a_t) => a_t, + Err(_) => return -1011, + }; + let a_slice = from_raw_parts_mut(a, m * lda); + let la = Layout::new_unchecked([m, n], [lda as isize, 1], 0); + let la_t = Layout::new_unchecked([m, n], [1, lda_t as isize], 0); + orderchange_out_r2c_ix2_cpu_serial(&mut a_t, &la_t, a_slice, &la).unwrap(); + + let mut u_t = if jobz == 'A' || jobz == 'S' || (jobz == 'O' && m < n) { + match uninitialized_vec(nrows_u * ncols_u) { + Ok(u_t) => Some(u_t), + Err(_) => return -1011, + } + } else { + None + }; + + let mut vt_t = if jobz == 'A' || jobz == 'S' || (jobz == 'O' && m >= n) { + match uninitialized_vec(nrows_vt * n) { + Ok(vt_t) => Some(vt_t), + Err(_) => return -1011, + } + } else { + None + }; + + // Call LAPACK function and adjust info + func_( + 
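+                // Same argument order as the real-typed driver above, plus the
+                // extra real-valued rwork buffer that complex gesdd requires;
+                // pointers are cast to the FFI's expected C types.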
&(jobz as _), + &(m as _), + &(n as _), + a_t.as_mut_ptr() as *mut _, + &(lda_t as _), + s as *mut _, + u_t.as_mut().map_or(std::ptr::null_mut(), |v| v.as_mut_ptr()) as *mut _, + &(ldu_t as _), + vt_t.as_mut().map_or(std::ptr::null_mut(), |v| v.as_mut_ptr()) as *mut _, + &(ldvt_t as _), + work.as_mut_ptr() as *mut _, + &(lwork as _), + rwork.as_mut_ptr() as *mut _, + iwork.as_mut_ptr(), + &mut info, + ); + if info != 0 { + return info; + } + + // Transpose output matrices + orderchange_out_c2r_ix2_cpu_serial(a_slice, &la, &a_t, &la_t).unwrap(); + + if let Some(u_t) = u_t { + let u_slice = from_raw_parts_mut(u, nrows_u * ldu); + let lu = Layout::new_unchecked([nrows_u, ncols_u], [ldu as isize, 1], 0); + let lu_t = Layout::new_unchecked([nrows_u, ncols_u], [1, ldu_t as isize], 0); + orderchange_out_c2r_ix2_cpu_serial(u_slice, &lu, &u_t, &lu_t).unwrap(); + } + + if let Some(vt_t) = vt_t { + let vt_slice = from_raw_parts_mut(vt, nrows_vt * ldvt); + let lvt = Layout::new_unchecked([nrows_vt, n], [ldvt as isize, 1], 0); + let lvt_t = Layout::new_unchecked([nrows_vt, n], [1, ldvt_t as isize], 0); + orderchange_out_c2r_ix2_cpu_serial(vt_slice, &lvt, &vt_t, &lvt_t).unwrap(); + } + } + return info; + } +} diff --git a/crates-device/rstsr-accelerate/src/driver_impl/lapack/svd/gesvd.rs b/crates-device/rstsr-accelerate/src/driver_impl/lapack/svd/gesvd.rs new file mode 100644 index 00000000..c7274128 --- /dev/null +++ b/crates-device/rstsr-accelerate/src/driver_impl/lapack/svd/gesvd.rs @@ -0,0 +1,363 @@ +use crate::lapack_ffi; +use crate::DeviceBLAS; +use num::complex::ComplexFloat; +use num::{Complex, Zero}; +use rstsr_blas_traits::prelude::*; +use rstsr_common::prelude_dev::*; + +use rstsr_native_impl::prelude_dev::*; +use std::slice::from_raw_parts_mut; + +#[duplicate_item( + T func_ ; + [f32] [sgesvd_]; + [f64] [dgesvd_]; +)] +impl GESVDDriverAPI for DeviceBLAS { + unsafe fn driver_gesvd( + order: FlagOrder, + jobu: char, + jobvt: char, + m: usize, + n: usize, + a: *mut T, + lda: usize, + s: *mut T, + u: *mut T, + ldu: usize, + vt: *mut T, + ldvt: usize, + superb: *mut T, + ) -> blas_int { + use lapack_ffi::lapack::func_; + + // Query optimal working array size + let mut info = 0; + let lwork = -1; + let mut work_query = 0.0; + func_( + &(jobu as _), + &(jobvt as _), + &(m as _), + &(n as _), + a, + &(lda as _), + s, + u, + &(ldu as _), + vt, + &(ldvt as _), + &mut work_query, + &lwork, + &mut info, + ); + if info != 0 { + return info; + } + let lwork = work_query as usize; + + // Allocate memory for work array + let mut work: Vec = match uninitialized_vec(lwork) { + Ok(work) => work, + Err(_) => return -1010, + }; + + if order == ColMajor { + // Call LAPACK function + func_( + &(jobu as _), + &(jobvt as _), + &(m as _), + &(n as _), + a, + &(lda as _), + s, + u, + &(ldu as _), + vt, + &(ldvt as _), + work.as_mut_ptr(), + &(lwork as _), + &mut info, + ); + if info != 0 { + return info; + } + } else { + let lda_t = m.max(1); + let nrows_u = if jobu == 'A' || jobu == 'S' { m } else { 1 }; + let ncols_u = if jobu == 'A' { + m + } else if jobu == 'S' { + m.min(n) + } else { + 1 + }; + let nrows_vt = if jobvt == 'A' { + n + } else if jobvt == 'S' { + m.min(n) + } else { + 1 + }; + let ldu_t = nrows_u.max(1); + let ldvt_t = nrows_vt.max(1); + + // Transpose input matrices + let mut a_t: Vec = match uninitialized_vec(m * n) { + Ok(a_t) => a_t, + Err(_) => return -1011, + }; + let a_slice = from_raw_parts_mut(a, m * lda); + let la = Layout::new_unchecked([m, n], [lda as isize, 1], 0); + let la_t = 
Layout::new_unchecked([m, n], [1, lda_t as isize], 0); + orderchange_out_r2c_ix2_cpu_serial(&mut a_t, &la_t, a_slice, &la).unwrap(); + + let mut u_t = if jobu == 'A' || jobu == 'S' { + match uninitialized_vec(nrows_u * ncols_u) { + Ok(u_t) => u_t, + Err(_) => return -1011, + } + } else { + Vec::new() + }; + + let mut vt_t = if jobvt == 'A' || jobvt == 'S' { + match uninitialized_vec(nrows_vt * n) { + Ok(vt_t) => vt_t, + Err(_) => return -1011, + } + } else { + Vec::new() + }; + + // Call LAPACK function + func_( + &(jobu as _), + &(jobvt as _), + &(m as _), + &(n as _), + a_t.as_mut_ptr(), + &(lda_t as _), + s, + if jobu == 'A' || jobu == 'S' { u_t.as_mut_ptr() } else { u }, + &(ldu_t as _), + if jobvt == 'A' || jobvt == 'S' { vt_t.as_mut_ptr() } else { vt }, + &(ldvt_t as _), + work.as_mut_ptr(), + &(lwork as _), + &mut info, + ); + if info != 0 { + return info; + } + + // Transpose output matrices + orderchange_out_r2c_ix2_cpu_serial(a_slice, &la, &a_t, &la_t).unwrap(); + + if jobu == 'A' || jobu == 'S' { + let u_slice = from_raw_parts_mut(u, nrows_u * ldu); + let lu = Layout::new_unchecked([nrows_u, ncols_u], [ldu as isize, 1], 0); + let lu_t = Layout::new_unchecked([nrows_u, ncols_u], [1, ldu_t as isize], 0); + orderchange_out_r2c_ix2_cpu_serial(u_slice, &lu, &u_t, &lu_t).unwrap(); + } + + if jobvt == 'A' || jobvt == 'S' { + let vt_slice = from_raw_parts_mut(vt, nrows_vt * ldvt); + let lvt = Layout::new_unchecked([nrows_vt, n], [ldvt as isize, 1], 0); + let lvt_t = Layout::new_unchecked([nrows_vt, n], [1, ldvt_t as isize], 0); + orderchange_out_r2c_ix2_cpu_serial(vt_slice, &lvt, &vt_t, &lvt_t).unwrap(); + } + } + + // Backup superb data + let min_mn = m.min(n); + for i in 0..min_mn - 1 { + superb.add(i).write(work[i + 1]); + } + + return info; + } +} + +#[duplicate_item( + T func_ ; + [Complex] [cgesvd_]; + [Complex] [zgesvd_]; +)] +impl GESVDDriverAPI for DeviceBLAS { + unsafe fn driver_gesvd( + order: FlagOrder, + jobu: char, + jobvt: char, + m: usize, + n: usize, + a: *mut T, + lda: usize, + s: *mut ::Real, + u: *mut T, + ldu: usize, + vt: *mut T, + ldvt: usize, + superb: *mut ::Real, + ) -> blas_int { + use lapack_ffi::lapack::func_; + + // Allocate rwork + let min_mn = m.min(n); + let mut rwork: Vec<::Real> = match uninitialized_vec(5 * min_mn) { + Ok(rwork) => rwork, + Err(_) => return -1010, + }; + + // Query optimal working array size + let mut info = 0; + let lwork = -1; + let mut work_query = ::zero(); + func_( + &(jobu as _), + &(jobvt as _), + &(m as _), + &(n as _), + a as *mut _, + &(lda as _), + s as *mut _, + u as *mut _, + &(ldu as _), + vt as *mut _, + &(ldvt as _), + &mut work_query as *mut _ as *mut _, + &lwork, + rwork.as_mut_ptr() as *mut _, + &mut info, + ); + if info != 0 { + return info; + } + let lwork = work_query.re() as usize; + + // Allocate memory for work array + let mut work: Vec = match uninitialized_vec(lwork) { + Ok(work) => work, + Err(_) => return -1010, + }; + + if order == ColMajor { + // Call LAPACK function + func_( + &(jobu as _), + &(jobvt as _), + &(m as _), + &(n as _), + a as *mut _, + &(lda as _), + s as *mut _, + u as *mut _, + &(ldu as _), + vt as *mut _, + &(ldvt as _), + work.as_mut_ptr() as *mut _, + &(lwork as _), + rwork.as_mut_ptr() as *mut _, + &mut info, + ); + if info != 0 { + return info; + } + } else { + let lda_t = m.max(1); + let nrows_u = if jobu == 'A' || jobu == 'S' { m } else { 1 }; + let ncols_u = if jobu == 'A' { + m + } else if jobu == 'S' { + m.min(n) + } else { + 1 + }; + let nrows_vt = if jobvt == 'A' { + n + } 
else if jobvt == 'S' { + m.min(n) + } else { + 1 + }; + let ldu_t = nrows_u.max(1); + let ldvt_t = nrows_vt.max(1); + + // Transpose input matrices + let mut a_t: Vec = match uninitialized_vec(m * n) { + Ok(a_t) => a_t, + Err(_) => return -1011, + }; + let a_slice = from_raw_parts_mut(a, m * lda); + let la = Layout::new_unchecked([m, n], [lda as isize, 1], 0); + let la_t = Layout::new_unchecked([m, n], [1, lda_t as isize], 0); + orderchange_out_r2c_ix2_cpu_serial(&mut a_t, &la_t, a_slice, &la).unwrap(); + + let mut u_t = if jobu == 'A' || jobu == 'S' { + match uninitialized_vec(nrows_u * ncols_u) { + Ok(u_t) => u_t, + Err(_) => return -1011, + } + } else { + Vec::new() + }; + + let mut vt_t = if jobvt == 'A' || jobvt == 'S' { + match uninitialized_vec(nrows_vt * n) { + Ok(vt_t) => vt_t, + Err(_) => return -1011, + } + } else { + Vec::new() + }; + + // Call LAPACK function + func_( + &(jobu as _), + &(jobvt as _), + &(m as _), + &(n as _), + a_t.as_mut_ptr() as *mut _, + &(lda_t as _), + s as *mut _, + if jobu == 'A' || jobu == 'S' { u_t.as_mut_ptr() as *mut _ } else { u as *mut _ }, + &(ldu_t as _), + if jobvt == 'A' || jobvt == 'S' { vt_t.as_mut_ptr() as *mut _ } else { vt as *mut _ }, + &(ldvt_t as _), + work.as_mut_ptr() as *mut _, + &(lwork as _), + rwork.as_mut_ptr() as *mut _, + &mut info, + ); + if info != 0 { + return info; + } + + // Transpose output matrices + orderchange_out_r2c_ix2_cpu_serial(a_slice, &la, &a_t, &la_t).unwrap(); + + if jobu == 'A' || jobu == 'S' { + let u_slice = from_raw_parts_mut(u, nrows_u * ldu); + let lu = Layout::new_unchecked([nrows_u, ncols_u], [ldu as isize, 1], 0); + let lu_t = Layout::new_unchecked([nrows_u, ncols_u], [1, ldu_t as isize], 0); + orderchange_out_r2c_ix2_cpu_serial(u_slice, &lu, &u_t, &lu_t).unwrap(); + } + + if jobvt == 'A' || jobvt == 'S' { + let vt_slice = from_raw_parts_mut(vt, nrows_vt * ldvt); + let lvt = Layout::new_unchecked([nrows_vt, n], [ldvt as isize, 1], 0); + let lvt_t = Layout::new_unchecked([nrows_vt, n], [1, ldvt_t as isize], 0); + orderchange_out_r2c_ix2_cpu_serial(vt_slice, &lvt, &vt_t, &lvt_t).unwrap(); + } + } + + // Backup superb data + #[allow(clippy::needless_range_loop)] + for i in 0..min_mn - 1 { + superb.add(i).write(rwork[i]); + } + + return info; + } +} diff --git a/crates-device/rstsr-accelerate/src/driver_impl/lapack/svd/mod.rs b/crates-device/rstsr-accelerate/src/driver_impl/lapack/svd/mod.rs new file mode 100644 index 00000000..49f55705 --- /dev/null +++ b/crates-device/rstsr-accelerate/src/driver_impl/lapack/svd/mod.rs @@ -0,0 +1,2 @@ +pub mod gesdd; +pub mod gesvd; diff --git a/crates-device/rstsr-accelerate/src/driver_impl/mod.rs b/crates-device/rstsr-accelerate/src/driver_impl/mod.rs new file mode 100644 index 00000000..dfdfae86 --- /dev/null +++ b/crates-device/rstsr-accelerate/src/driver_impl/mod.rs @@ -0,0 +1,16 @@ +pub mod cblas; + +pub mod lapack; + +use crate::DeviceBLAS; +use duplicate::duplicate_item; +use num::Complex; +use rstsr_blas_traits::prelude::*; + +impl BlasDriverBaseAPI for DeviceBLAS where T: BlasFloat {} + +#[duplicate_item(T; [f32]; [f64]; [Complex]; [Complex])] +impl BlasDriverAPI for DeviceBLAS {} + +#[duplicate_item(T; [f32]; [f64]; [Complex]; [Complex])] +impl LapackDriverAPI for DeviceBLAS {} diff --git a/crates-device/rstsr-accelerate/src/lib.rs b/crates-device/rstsr-accelerate/src/lib.rs new file mode 100644 index 00000000..2473ce20 --- /dev/null +++ b/crates-device/rstsr-accelerate/src/lib.rs @@ -0,0 +1,30 @@ +#![allow(clippy::needless_return)] 
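+// Module layout: `driver_impl` binds Accelerate's CBLAS/LAPACK entry points,
+// `linalg_auto_impl` wires those drivers into the high-level linalg traits,
+// and `rayon_auto_impl` supplies the threaded elementwise/reduction kernels.
+//
+// Minimal usage sketch (hedged: `Default` on the device and the exact tensor
+// construction API are assumptions for illustration, not guaranteed by this
+// crate alone):
+//
+//     use rstsr_accelerate::DeviceAccelerate;
+//     let device = DeviceAccelerate::default();
+//     // construct tensors on `device`, then call routines such as eigh or
+//     // solve_general through the traits implemented in `linalg_auto_impl`.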
+#![allow(non_camel_case_types)] +#![doc = include_str!("../readme.md")] + +pub mod conversion; +pub mod creation; +pub mod device; +pub mod matmul; +pub mod matmul_impl; +pub mod prelude_dev; +pub mod rayon_auto_impl; +pub mod threading; + +pub mod driver_impl; +#[cfg(feature = "linalg")] +pub mod linalg_auto_impl; + +#[cfg(feature = "sci")] +pub mod sci_auto_impl; + +use rstsr_core::prelude_dev::DeviceCpuRayon; + +#[derive(Clone, Debug)] +pub struct DeviceAccelerate { + base: DeviceCpuRayon, +} + +pub(crate) use rstsr_lapack_ffi as lapack_ffi; +pub(crate) use DeviceAccelerate as DeviceBLAS; +pub(crate) use DeviceAccelerate as DeviceRayonAutoImpl; diff --git a/crates-device/rstsr-accelerate/src/linalg_auto_impl/cholesky.rs b/crates-device/rstsr-accelerate/src/linalg_auto_impl/cholesky.rs new file mode 100644 index 00000000..62e22eb5 --- /dev/null +++ b/crates-device/rstsr-accelerate/src/linalg_auto_impl/cholesky.rs @@ -0,0 +1,90 @@ +use crate::DeviceBLAS; +use rstsr_blas_traits::prelude::*; +use rstsr_core::prelude_dev::*; +use rstsr_linalg_traits::prelude_dev::*; + +/* #region full-args */ + +#[duplicate_item( + ImplType Tr ; + [T, D, R: DataAPI>] [&TensorAny ]; + [T, D ] [TensorView<'_, T, DeviceBLAS, D>]; +)] +impl CholeskyAPI for (Tr, Option) +where + T: BlasFloat, + D: DimAPI, + DeviceBLAS: LapackDriverAPI, +{ + type Out = Tensor; + fn cholesky_f(self) -> Result { + let (a, uplo) = self; + rstsr_assert_eq!(a.ndim(), 2, InvalidLayout, "Currently we can only handle 2-D matrix.")?; + let a = a.view().into_dim::(); + let result = ref_impl_cholesky_f(a.view().into(), uplo)?.into_owned(); + Ok(result.into_dim::().into_dim::()) + } +} + +#[duplicate_item( + ImplType Tr ; + ['a, T, D] [TensorMut<'a, T, DeviceBLAS, D>]; + [ T, D] [Tensor ]; +)] +impl CholeskyAPI for (Tr, Option) +where + T: BlasFloat, + D: DimAPI, + DeviceBLAS: LapackDriverAPI, +{ + type Out = Tr; + fn cholesky_f(self) -> Result { + let (mut a, uplo) = self; + rstsr_assert_eq!(a.ndim(), 2, InvalidLayout, "Currently we can only handle 2-D matrix.")?; + let a_ix2 = a.view_mut().into_dim::(); + let result = ref_impl_cholesky_f(a_ix2.into(), uplo)?; + result.clone_to_mut(); + Ok(a) + } +} + +/* #endregion */ + +/* #region sub-args */ + +#[duplicate_item( + ImplStruct args_tuple internal_tuple ; + [(Tr, FlagUpLo)] [(a, uplo)] [(a, Some(uplo))]; +)] +impl CholeskyAPI for ImplStruct +where + (Tr, Option): CholeskyAPI, +{ + type Out = <(Tr, Option) as CholeskyAPI>::Out; + fn cholesky_f(self) -> Result { + let args_tuple = self; + CholeskyAPI::::cholesky_f(internal_tuple) + } +} + +#[duplicate_item( + ImplType Tr; + ['a, T, D, R: DataAPI>] [&'a TensorAny]; + ['a, T, D, ] [TensorView<'a, T, DeviceBLAS, D> ]; + [ T, D ] [Tensor ]; + ['a, T, D ] [TensorMut<'a, T, DeviceBLAS, D> ]; +)] +impl CholeskyAPI for Tr +where + T: BlasFloat, + D: DimAPI, + (Tr, Option): CholeskyAPI, +{ + type Out = <(Tr, Option) as CholeskyAPI>::Out; + fn cholesky_f(self) -> Result { + let a = self; + CholeskyAPI::::cholesky_f((a, None)) + } +} + +/* #endregion */ diff --git a/crates-device/rstsr-accelerate/src/linalg_auto_impl/det.rs b/crates-device/rstsr-accelerate/src/linalg_auto_impl/det.rs new file mode 100644 index 00000000..1313709e --- /dev/null +++ b/crates-device/rstsr-accelerate/src/linalg_auto_impl/det.rs @@ -0,0 +1,47 @@ +use crate::DeviceBLAS; +use num::complex::ComplexFloat; +use rstsr_blas_traits::prelude::*; +use rstsr_core::prelude_dev::*; +use rstsr_linalg_traits::prelude_dev::*; + +#[duplicate_item( + ImplType Tr ; + [T, D, R: DataAPI>] 
[&TensorAny ]; + [T, D ] [TensorView<'_, T, DeviceBLAS, D>]; +)] +impl DetAPI for Tr +where + T: BlasFloat, + D: DimAPI, + DeviceBLAS: LapackDriverAPI, +{ + type Out = T; + fn det_f(self) -> Result { + rstsr_assert_eq!(self.ndim(), 2, InvalidLayout, "Currently we can only handle 2-D matrix.")?; + let a = self; + let a_view = a.view().into_dim::(); + let (sign, logabsdet) = ref_impl_slogdet_f(a_view.into())?; + Ok(sign * logabsdet.exp()) + } +} + +#[duplicate_item( + ImplType Tr ; + ['a, T, D] [TensorMut<'a, T, DeviceBLAS, D>]; + [ T, D] [Tensor ]; +)] +impl DetAPI for Tr +where + T: BlasFloat, + D: DimAPI, + DeviceBLAS: LapackDriverAPI, +{ + type Out = T; + fn det_f(self) -> Result { + rstsr_assert_eq!(self.ndim(), 2, InvalidLayout, "Currently we can only handle 2-D matrix.")?; + let mut a = self; + let a_view = a.view_mut().into_dim::(); + let (sign, logabsdet) = ref_impl_slogdet_f(a_view.into())?; + Ok(sign * logabsdet.exp()) + } +} diff --git a/crates-device/rstsr-accelerate/src/linalg_auto_impl/eigh.rs b/crates-device/rstsr-accelerate/src/linalg_auto_impl/eigh.rs new file mode 100644 index 00000000..228c6671 --- /dev/null +++ b/crates-device/rstsr-accelerate/src/linalg_auto_impl/eigh.rs @@ -0,0 +1,216 @@ +use crate::DeviceBLAS; +use rstsr_blas_traits::prelude::*; +use rstsr_core::prelude_dev::*; +use rstsr_linalg_traits::prelude_dev::*; + +/* #region simple eigh */ + +#[duplicate_item( + ImplType Tr ; + [T, D, R: DataAPI>] [&TensorAny ]; + [T, D ] [TensorView<'_, T, DeviceBLAS, D>]; +)] +impl EighAPI for (Tr, FlagUpLo) +where + T: BlasFloat, + D: DimAPI + DimSmallerOneAPI, + D::SmallerOne: DimAPI, + DeviceBLAS: LapackDriverAPI, +{ + type Out = EighResult, Tensor>; + fn eigh_f(self) -> Result { + let (a, uplo) = self; + rstsr_assert_eq!(a.ndim(), 2, InvalidLayout, "Currently we can only handle 2-D matrix.")?; + let a_view = a.view().into_dim::(); + let eigh_args = EighArgs::default().a(a_view).uplo(uplo).build()?; + let (vals, vecs) = ref_impl_eigh_simple_f(eigh_args)?; + let vals = vals.into_dim::().into_dim::(); + let vecs = vecs.unwrap().into_owned().into_dim::().into_dim::(); + return Ok(EighResult { eigenvalues: vals, eigenvectors: vecs }); + } +} + +#[duplicate_item( + ImplType Tr ; + [T, D, R: DataAPI>] [&TensorAny ]; + [T, D ] [TensorView<'_, T, DeviceBLAS, D>]; +)] +impl EighAPI for Tr +where + T: BlasFloat, + D: DimAPI + DimSmallerOneAPI, + D::SmallerOne: DimAPI, + DeviceBLAS: LapackDriverAPI, +{ + type Out = EighResult, Tensor>; + fn eigh_f(self) -> Result { + let a = self; + let uplo = match a.device().default_order() { + RowMajor => Lower, + ColMajor => Upper, + }; + EighAPI::::eigh_f((a, uplo)) + } +} + +#[duplicate_item( + ImplType Tr ; + ['a, T, D] [TensorMut<'a, T, DeviceBLAS, D>]; + [ T, D] [Tensor ]; +)] +impl EighAPI for (Tr, FlagUpLo) +where + T: BlasFloat, + D: DimAPI + DimSmallerOneAPI, + D::SmallerOne: DimAPI, + DeviceBLAS: LapackDriverAPI, +{ + type Out = EighResult, Tr>; + fn eigh_f(self) -> Result { + let (mut a, uplo) = self; + rstsr_assert_eq!(a.ndim(), 2, InvalidLayout, "Currently we can only handle 2-D matrix.")?; + let a_view = a.view_mut().into_dim::(); + let eigh_args = EighArgs::default().a(a_view).uplo(uplo).build()?; + let (vals, vecs) = ref_impl_eigh_simple_f(eigh_args)?; + let vals = vals.into_dim::().into_dim::(); + vecs.unwrap().clone_to_mut(); + return Ok(EighResult { eigenvalues: vals, eigenvectors: a }); + } +} + +#[duplicate_item( + ImplType Tr ; + ['a, T, D] [TensorMut<'a, T, DeviceBLAS, D>]; + [ T, D] [Tensor ]; +)] +impl EighAPI for Tr 
+where + T: BlasFloat, + D: DimAPI + DimSmallerOneAPI, + D::SmallerOne: DimAPI, + DeviceBLAS: LapackDriverAPI, +{ + type Out = EighResult, Tr>; + fn eigh_f(self) -> Result { + let a = self; + let uplo = match a.device().default_order() { + RowMajor => Lower, + ColMajor => Upper, + }; + EighAPI::::eigh_f((a, uplo)) + } +} + +/* #endregion */ + +/* #region general eigh */ + +#[duplicate_item( + ImplType TrA TrB ; + [T, D, Ra: DataAPI>, Rb: DataAPI>] [&TensorAny] [&TensorAny]; + [T, D, R: DataAPI> ] [&TensorAny ] [TensorView<'_, T, DeviceBLAS, D>]; + [T, D, R: DataAPI> ] [TensorView<'_, T, DeviceBLAS, D>] [&TensorAny ]; + [T, D, ] [TensorView<'_, T, DeviceBLAS, D>] [TensorView<'_, T, DeviceBLAS, D>]; +)] +impl EighAPI for (TrA, TrB, FlagUpLo, i32) +where + T: BlasFloat, + D: DimAPI + DimSmallerOneAPI, + D::SmallerOne: DimAPI, + DeviceBLAS: LapackDriverAPI, +{ + type Out = EighResult, Tensor>; + fn eigh_f(self) -> Result { + let (a, b, uplo, eig_type) = self; + rstsr_assert_eq!(a.ndim(), 2, InvalidLayout, "Currently we can only handle 2-D matrix.")?; + rstsr_assert_eq!(b.ndim(), 2, InvalidLayout, "Currently we can only handle 2-D matrix.")?; + rstsr_pattern!(eig_type, 1..=3, InvalidLayout, "Only eig_type = 1, 2, or 3 allowed.")?; + let a_view = a.view().into_dim::(); + let b_view = b.view().into_dim::(); + let eigh_args = EighArgs::default().a(a_view).b(b_view).uplo(uplo).eig_type(eig_type).build()?; + let (vals, vecs) = ref_impl_eigh_simple_f(eigh_args)?; + let vals = vals.into_dim::().into_dim::(); + let vecs = vecs.unwrap().into_owned().into_dim::().into_dim::(); + return Ok(EighResult { eigenvalues: vals, eigenvectors: vecs }); + } +} + +#[duplicate_item( + ImplType TrA TrB ; + [T, D, Ra: DataAPI>, Rb: DataAPI>] [&TensorAny] [&TensorAny]; + [T, D, R: DataAPI> ] [&TensorAny ] [TensorView<'_, T, DeviceBLAS, D>]; + [T, D, R: DataAPI> ] [TensorView<'_, T, DeviceBLAS, D>] [&TensorAny ]; + [T, D, ] [TensorView<'_, T, DeviceBLAS, D>] [TensorView<'_, T, DeviceBLAS, D>]; +)] +impl EighAPI for (TrA, TrB, FlagUpLo) +where + T: BlasFloat, + D: DimAPI + DimSmallerOneAPI, + D::SmallerOne: DimAPI, + DeviceBLAS: LapackDriverAPI, +{ + type Out = EighResult, Tensor>; + fn eigh_f(self) -> Result { + let (a, b, uplo) = self; + EighAPI::::eigh_f((a, b, uplo, 1)) + } +} + +#[duplicate_item( + ImplType TrA TrB ; + [T, D, Ra: DataAPI>, Rb: DataAPI>] [&TensorAny] [&TensorAny]; + [T, D, R: DataAPI> ] [&TensorAny ] [TensorView<'_, T, DeviceBLAS, D>]; + [T, D, R: DataAPI> ] [TensorView<'_, T, DeviceBLAS, D>] [&TensorAny ]; + [T, D, ] [TensorView<'_, T, DeviceBLAS, D>] [TensorView<'_, T, DeviceBLAS, D>]; +)] +impl EighAPI for (TrA, TrB) +where + T: BlasFloat, + D: DimAPI + DimSmallerOneAPI, + D::SmallerOne: DimAPI, + DeviceBLAS: LapackDriverAPI, +{ + type Out = EighResult, Tensor>; + fn eigh_f(self) -> Result { + let (a, b) = self; + let uplo = match a.device().default_order() { + RowMajor => Lower, + ColMajor => Upper, + }; + EighAPI::::eigh_f((a, b, uplo, 1)) + } +} + +/* #endregion */ + +/* #region EighArgs implementation */ + +impl<'a, 'b, T> EighAPI for EighArgs<'a, 'b, DeviceBLAS, T> +where + T: BlasFloat, + DeviceBLAS: LapackDriverAPI, +{ + type Out = EighResult, TensorMutable<'a, T, DeviceBLAS, Ix2>>; + fn eigh_f(self) -> Result { + let args = self.build()?; + rstsr_assert!(!args.eigvals_only, InvalidValue, "Eigh only supports eigvals_only = false.")?; + let (vals, vecs) = ref_impl_eigh_simple_f(args)?; + Ok(EighResult { eigenvalues: vals, eigenvectors: vecs.unwrap() }) + } +} + +impl<'a, 'b, T> EighAPI for 
EighArgs_<'a, 'b, DeviceBLAS, T> +where + T: BlasFloat, + DeviceBLAS: LapackDriverAPI, +{ + type Out = EighResult, TensorMutable<'a, T, DeviceBLAS, Ix2>>; + fn eigh_f(self) -> Result { + let args = self; + rstsr_assert!(!args.eigvals_only, InvalidValue, "Eigh only supports eigvals_only = false.")?; + let (vals, vecs) = ref_impl_eigh_simple_f(args)?; + Ok(EighResult { eigenvalues: vals, eigenvectors: vecs.unwrap() }) + } +} + +/* #endregion */ diff --git a/crates-device/rstsr-accelerate/src/linalg_auto_impl/eigvalsh.rs b/crates-device/rstsr-accelerate/src/linalg_auto_impl/eigvalsh.rs new file mode 100644 index 00000000..ecfc776a --- /dev/null +++ b/crates-device/rstsr-accelerate/src/linalg_auto_impl/eigvalsh.rs @@ -0,0 +1,214 @@ +use crate::DeviceBLAS; +use rstsr_blas_traits::prelude::*; +use rstsr_core::prelude_dev::*; +use rstsr_linalg_traits::prelude_dev::*; + +/* #region simple eigh */ + +#[duplicate_item( + ImplType Tr ; + [T, D, R: DataAPI>] [&TensorAny ]; + [T, D ] [TensorView<'_, T, DeviceBLAS, D>]; +)] +impl EigvalshAPI for (Tr, FlagUpLo) +where + T: BlasFloat, + D: DimAPI + DimSmallerOneAPI, + D::SmallerOne: DimAPI, + DeviceBLAS: LapackDriverAPI, +{ + type Out = Tensor; + fn eigvalsh_f(self) -> Result { + let (a, uplo) = self; + rstsr_assert_eq!(a.ndim(), 2, InvalidLayout, "Currently we can only handle 2-D matrix.")?; + let a_view = a.view().into_dim::(); + let eigh_args = EighArgs::default().a(a_view).uplo(uplo).eigvals_only(true).build()?; + let (vals, _) = ref_impl_eigh_simple_f(eigh_args)?; + let vals = vals.into_dim::().into_dim::(); + return Ok(vals); + } +} + +#[duplicate_item( + ImplType Tr ; + [T, D, R: DataAPI>] [&TensorAny ]; + [T, D ] [TensorView<'_, T, DeviceBLAS, D>]; +)] +impl EigvalshAPI for Tr +where + T: BlasFloat, + D: DimAPI + DimSmallerOneAPI, + D::SmallerOne: DimAPI, + DeviceBLAS: LapackDriverAPI, +{ + type Out = Tensor; + fn eigvalsh_f(self) -> Result { + let a = self; + let uplo = match a.device().default_order() { + RowMajor => Lower, + ColMajor => Upper, + }; + EigvalshAPI::::eigvalsh_f((a, uplo)) + } +} + +#[duplicate_item( + ImplType Tr ; + ['a, T, D] [TensorMut<'a, T, DeviceBLAS, D>]; + [ T, D] [Tensor ]; +)] +impl EigvalshAPI for (Tr, FlagUpLo) +where + T: BlasFloat, + D: DimAPI + DimSmallerOneAPI, + D::SmallerOne: DimAPI, + DeviceBLAS: LapackDriverAPI, +{ + type Out = Tensor; + fn eigvalsh_f(self) -> Result { + let (mut a, uplo) = self; + rstsr_assert_eq!(a.ndim(), 2, InvalidLayout, "Currently we can only handle 2-D matrix.")?; + let a_view = a.view_mut().into_dim::(); + let eigh_args = EighArgs::default().a(a_view).uplo(uplo).eigvals_only(true).build()?; + let (vals, _) = ref_impl_eigh_simple_f(eigh_args)?; + let vals = vals.into_dim::().into_dim::(); + return Ok(vals); + } +} + +#[duplicate_item( + ImplType Tr ; + ['a, T, D] [TensorMut<'a, T, DeviceBLAS, D>]; + [ T, D] [Tensor ]; +)] +impl EigvalshAPI for Tr +where + T: BlasFloat, + D: DimAPI + DimSmallerOneAPI, + D::SmallerOne: DimAPI, + DeviceBLAS: LapackDriverAPI, +{ + type Out = Tensor; + fn eigvalsh_f(self) -> Result { + let a = self; + let uplo = match a.device().default_order() { + RowMajor => Lower, + ColMajor => Upper, + }; + EigvalshAPI::::eigvalsh_f((a, uplo)) + } +} + +/* #endregion */ + +/* #region general eigh */ + +#[duplicate_item( + ImplType TrA TrB ; + [T, D, Ra: DataAPI>, Rb: DataAPI>] [&TensorAny] [&TensorAny]; + [T, D, R: DataAPI> ] [&TensorAny ] [TensorView<'_, T, DeviceBLAS, D>]; + [T, D, R: DataAPI> ] [TensorView<'_, T, DeviceBLAS, D>] [&TensorAny ]; + [T, D, ] 
[TensorView<'_, T, DeviceBLAS, D>] [TensorView<'_, T, DeviceBLAS, D>]; +)] +impl EigvalshAPI for (TrA, TrB, FlagUpLo, i32) +where + T: BlasFloat, + D: DimAPI + DimSmallerOneAPI, + D::SmallerOne: DimAPI, + DeviceBLAS: LapackDriverAPI, +{ + type Out = Tensor; + fn eigvalsh_f(self) -> Result { + let (a, b, uplo, eig_type) = self; + rstsr_assert_eq!(a.ndim(), 2, InvalidLayout, "Currently we can only handle 2-D matrix.")?; + rstsr_assert_eq!(b.ndim(), 2, InvalidLayout, "Currently we can only handle 2-D matrix.")?; + rstsr_pattern!(eig_type, 1..=3, InvalidLayout, "Only eig_type = 1, 2, or 3 allowed.")?; + let a_view = a.view().into_dim::(); + let b_view = b.view().into_dim::(); + let eigh_args = + EighArgs::default().a(a_view).b(b_view).uplo(uplo).eig_type(eig_type).eigvals_only(true).build()?; + let (vals, _) = ref_impl_eigh_simple_f(eigh_args)?; + let vals = vals.into_dim::().into_dim::(); + return Ok(vals); + } +} + +#[duplicate_item( + ImplType TrA TrB ; + [T, D, Ra: DataAPI>, Rb: DataAPI>] [&TensorAny] [&TensorAny]; + [T, D, R: DataAPI> ] [&TensorAny ] [TensorView<'_, T, DeviceBLAS, D>]; + [T, D, R: DataAPI> ] [TensorView<'_, T, DeviceBLAS, D>] [&TensorAny ]; + [T, D, ] [TensorView<'_, T, DeviceBLAS, D>] [TensorView<'_, T, DeviceBLAS, D>]; +)] +impl EigvalshAPI for (TrA, TrB, FlagUpLo) +where + T: BlasFloat, + D: DimAPI + DimSmallerOneAPI, + D::SmallerOne: DimAPI, + DeviceBLAS: LapackDriverAPI, +{ + type Out = Tensor; + fn eigvalsh_f(self) -> Result { + let (a, b, uplo) = self; + EigvalshAPI::::eigvalsh_f((a, b, uplo, 1)) + } +} + +#[duplicate_item( + ImplType TrA TrB ; + [T, D, Ra: DataAPI>, Rb: DataAPI>] [&TensorAny] [&TensorAny]; + [T, D, R: DataAPI> ] [&TensorAny ] [TensorView<'_, T, DeviceBLAS, D>]; + [T, D, R: DataAPI> ] [TensorView<'_, T, DeviceBLAS, D>] [&TensorAny ]; + [T, D, ] [TensorView<'_, T, DeviceBLAS, D>] [TensorView<'_, T, DeviceBLAS, D>]; +)] +impl EigvalshAPI for (TrA, TrB) +where + T: BlasFloat, + D: DimAPI + DimSmallerOneAPI, + D::SmallerOne: DimAPI, + DeviceBLAS: LapackDriverAPI, +{ + type Out = Tensor; + fn eigvalsh_f(self) -> Result { + let (a, b) = self; + let uplo = match a.device().default_order() { + RowMajor => Lower, + ColMajor => Upper, + }; + EigvalshAPI::::eigvalsh_f((a, b, uplo, 1)) + } +} + +/* #endregion */ + +/* #region EighArgs implementation */ + +impl<'a, 'b, T> EigvalshAPI for EighArgs<'a, 'b, DeviceBLAS, T> +where + T: BlasFloat, + DeviceBLAS: LapackDriverAPI, +{ + type Out = Tensor; + fn eigvalsh_f(self) -> Result { + let args = self.build()?; + rstsr_assert!(args.eigvals_only, InvalidValue, "Eigvalsh only supports eigvals_only = true.")?; + let (vals, _) = ref_impl_eigh_simple_f(args)?; + Ok(vals) + } +} + +impl<'a, 'b, T> EigvalshAPI for EighArgs_<'a, 'b, DeviceBLAS, T> +where + T: BlasFloat, + DeviceBLAS: LapackDriverAPI, +{ + type Out = Tensor; + fn eigvalsh_f(self) -> Result { + let args = self; + rstsr_assert!(args.eigvals_only, InvalidValue, "Eigvalsh only supports eigvals_only = true.")?; + let (vals, _) = ref_impl_eigh_simple_f(args)?; + Ok(vals) + } +} + +/* #endregion */ diff --git a/crates-device/rstsr-accelerate/src/linalg_auto_impl/inv.rs b/crates-device/rstsr-accelerate/src/linalg_auto_impl/inv.rs new file mode 100644 index 00000000..cb551384 --- /dev/null +++ b/crates-device/rstsr-accelerate/src/linalg_auto_impl/inv.rs @@ -0,0 +1,46 @@ +use crate::DeviceBLAS; +use rstsr_blas_traits::prelude::*; +use rstsr_core::prelude_dev::*; +use rstsr_linalg_traits::prelude_dev::*; + +#[duplicate_item( + ImplType Tr ; + [T, D, R: DataAPI>] 
[&TensorAny ]; + [T, D ] [TensorView<'_, T, DeviceBLAS, D>]; +)] +impl InvAPI for Tr +where + T: BlasFloat, + D: DimAPI, + DeviceBLAS: LapackDriverAPI, +{ + type Out = Tensor; + fn inv_f(self) -> Result { + rstsr_assert_eq!(self.ndim(), 2, InvalidLayout, "Currently we can only handle 2-D matrix.")?; + let a = self.view().into_dim::(); + let result = ref_impl_inv_f(a.into())?.into_owned(); + Ok(result.into_dim::().into_dim::()) + } +} + +#[duplicate_item( + ImplType Tr ; + ['a, T, D] [TensorMut<'a, T, DeviceBLAS, D>]; + [ T, D] [Tensor ]; +)] +impl InvAPI for Tr +where + T: BlasFloat, + D: DimAPI, + DeviceBLAS: LapackDriverAPI, +{ + type Out = Tr; + fn inv_f(self) -> Result { + rstsr_assert_eq!(self.ndim(), 2, InvalidLayout, "Currently we can only handle 2-D matrix.")?; + let mut a = self; + let a_view = a.view_mut().into_dim::(); + let result = ref_impl_inv_f(a_view.into())?; + result.clone_to_mut(); + Ok(a) + } +} diff --git a/crates-device/rstsr-accelerate/src/linalg_auto_impl/mod.rs b/crates-device/rstsr-accelerate/src/linalg_auto_impl/mod.rs new file mode 100644 index 00000000..0d2e0e77 --- /dev/null +++ b/crates-device/rstsr-accelerate/src/linalg_auto_impl/mod.rs @@ -0,0 +1,12 @@ +pub mod cholesky; +pub mod det; +pub mod eigh; +pub mod eigvalsh; +pub mod inv; +pub mod pinv; +pub mod slogdet; +pub mod solve_general; +pub mod solve_symmetric; +pub mod solve_triangular; +pub mod svd; +pub mod svdvals; diff --git a/crates-device/rstsr-accelerate/src/linalg_auto_impl/pinv.rs b/crates-device/rstsr-accelerate/src/linalg_auto_impl/pinv.rs new file mode 100644 index 00000000..1b369f42 --- /dev/null +++ b/crates-device/rstsr-accelerate/src/linalg_auto_impl/pinv.rs @@ -0,0 +1,85 @@ +use crate::DeviceBLAS; +use num::FromPrimitive; +use rstsr_blas_traits::prelude::*; +use rstsr_core::prelude_dev::*; +use rstsr_linalg_traits::prelude_dev::*; + +impl PinvAPI for (&TensorAny, T::Real, T::Real) +where + R: DataAPI>, + T: BlasFloat, + T::Real: FromPrimitive, + D: DimAPI + DimSmallerOneAPI, + D::SmallerOne: DimAPI, + DeviceBLAS: LapackDriverAPI, +{ + type Out = PinvResult>; + fn pinv_f(self) -> Result { + let (a, atol, rtol) = self; + rstsr_assert_eq!(a.ndim(), 2, InvalidLayout, "Currently we can only handle 2-D matrix.")?; + let a = a.view().into_dim::(); + let (b, rank) = ref_impl_pinv_f(a, Some(atol), Some(rtol))?.into(); + let b = b.into_dim::().into_dim::(); + return Ok(PinvResult { pinv: b, rank }); + } +} + +impl PinvAPI for &TensorAny +where + R: DataAPI>, + T: BlasFloat, + T::Real: FromPrimitive, + D: DimAPI + DimSmallerOneAPI, + D::SmallerOne: DimAPI, + DeviceBLAS: LapackDriverAPI, +{ + type Out = PinvResult>; + fn pinv_f(self) -> Result { + let a = self; + rstsr_assert_eq!(a.ndim(), 2, InvalidLayout, "Currently we can only handle 2-D matrix.")?; + let a = a.view().into_dim::(); + let (pinv, rank) = ref_impl_pinv_f(a, None, None)?.into(); + let pinv = pinv.into_dim::().into_dim::(); + return Ok(PinvResult { pinv, rank }); + } +} + +#[duplicate_item( + Tr ; + [Tensor ]; + [TensorView<'_, T, DeviceBLAS, D>]; +)] +impl PinvAPI for (Tr, T::Real, T::Real) +where + T: BlasFloat, + T::Real: FromPrimitive, + D: DimAPI + DimSmallerOneAPI, + D::SmallerOne: DimAPI, + DeviceBLAS: LapackDriverAPI, +{ + type Out = PinvResult>; + fn pinv_f(self) -> Result { + let (a, atol, rtol) = self; + PinvAPI::::pinv_f((&a, atol, rtol)) + } +} + +#[duplicate_item( + Tr ; + [Tensor ]; + [TensorView<'_, T, DeviceBLAS, D>]; +)] +impl PinvAPI for Tr +where + T: BlasFloat, + T::Real: FromPrimitive, + D: DimAPI + 
DimSmallerOneAPI, + D::SmallerOne: DimAPI, + DeviceBLAS: LapackDriverAPI, +{ + type Out = PinvResult>; + fn pinv_f(self) -> Result { + let a = self; + PinvAPI::::pinv_f(&a) + } +} diff --git a/crates-device/rstsr-accelerate/src/linalg_auto_impl/slogdet.rs b/crates-device/rstsr-accelerate/src/linalg_auto_impl/slogdet.rs new file mode 100644 index 00000000..931453cc --- /dev/null +++ b/crates-device/rstsr-accelerate/src/linalg_auto_impl/slogdet.rs @@ -0,0 +1,46 @@ +use crate::DeviceBLAS; +use rstsr_blas_traits::prelude::*; +use rstsr_core::prelude_dev::*; +use rstsr_linalg_traits::prelude_dev::*; + +#[duplicate_item( + ImplType Tr ; + [T, D, R: DataAPI>] [&TensorAny ]; + [T, D ] [TensorView<'_, T, DeviceBLAS, D>]; +)] +impl SLogDetAPI for Tr +where + T: BlasFloat, + D: DimAPI, + DeviceBLAS: LapackDriverAPI, +{ + type Out = SLogDetResult; + fn slogdet_f(self) -> Result { + rstsr_assert_eq!(self.ndim(), 2, InvalidLayout, "Currently we can only handle 2-D matrix.")?; + let a = self; + let a_view = a.view().into_dim::(); + let (sign, logabsdet) = ref_impl_slogdet_f(a_view.into())?; + Ok(SLogDetResult { sign, logabsdet }) + } +} + +#[duplicate_item( + ImplType Tr ; + ['a, T, D] [TensorMut<'a, T, DeviceBLAS, D>]; + [ T, D] [Tensor ]; +)] +impl SLogDetAPI for Tr +where + T: BlasFloat, + D: DimAPI, + DeviceBLAS: LapackDriverAPI, +{ + type Out = SLogDetResult; + fn slogdet_f(self) -> Result { + rstsr_assert_eq!(self.ndim(), 2, InvalidLayout, "Currently we can only handle 2-D matrix.")?; + let mut a = self; + let a_view = a.view_mut().into_dim::(); + let (sign, logabsdet) = ref_impl_slogdet_f(a_view.into())?; + Ok(SLogDetResult { sign, logabsdet }) + } +} diff --git a/crates-device/rstsr-accelerate/src/linalg_auto_impl/solve_general.rs b/crates-device/rstsr-accelerate/src/linalg_auto_impl/solve_general.rs new file mode 100644 index 00000000..8653da91 --- /dev/null +++ b/crates-device/rstsr-accelerate/src/linalg_auto_impl/solve_general.rs @@ -0,0 +1,134 @@ +use crate::DeviceBLAS; +use rstsr_blas_traits::prelude::*; +use rstsr_core::prelude_dev::*; +use rstsr_linalg_traits::prelude_dev::*; + +#[duplicate_item( + ImplType TrA TrB ; + [T, DA, DB, Ra: DataAPI>, Rb: DataAPI>] [&TensorAny] [&TensorAny]; + [T, DA, DB, R: DataAPI> ] [&TensorAny ] [TensorView<'_, T, DeviceBLAS, DB>]; + [T, DA, DB, R: DataAPI> ] [TensorView<'_, T, DeviceBLAS, DA>] [&TensorAny ]; + [T, DA, DB, ] [TensorView<'_, T, DeviceBLAS, DA>] [TensorView<'_, T, DeviceBLAS, DB>]; +)] +impl SolveGeneralAPI for (TrA, TrB) +where + T: BlasFloat, + DA: DimAPI, + DB: DimAPI, + DeviceBLAS: LapackDriverAPI, +{ + type Out = Tensor; + fn solve_general_f(self) -> Result { + let (a, b) = self; + rstsr_assert_eq!(a.ndim(), 2, InvalidLayout, "Currently we can only handle 2-D matrix.")?; + rstsr_pattern!(b.ndim(), 1..=2, InvalidLayout, "Currently we can only handle 1/2-D matrix.")?; + let is_b_vec = b.ndim() == 1; + let a_view = a.view().into_dim::(); + let b_view = match is_b_vec { + true => b.i((.., None)).into_dim::(), + false => b.view().into_dim::(), + }; + let result = ref_impl_solve_general_f(a_view.into(), b_view.into())?; + let result = result.into_owned().into_dim::(); + match is_b_vec { + true => Ok(result.into_shape(-1).into_dim::()), + false => Ok(result.into_dim::()), + } + } +} + +#[duplicate_item( + ImplType TrA TrB ; + ['b, T, DA, DB, R: DataAPI>] [&TensorAny ] [TensorMut<'b, T, DeviceBLAS, DB>]; + ['b, T, DA, DB, ] [TensorView<'_, T, DeviceBLAS, DA>] [TensorMut<'b, T, DeviceBLAS, DB>]; + [ T, DA, DB, R: DataAPI>] [&TensorAny ] [Tensor ]; 
+ [ T, DA, DB, ] [TensorView<'_, T, DeviceBLAS, DA>] [Tensor ]; +)] +impl SolveGeneralAPI for (TrA, TrB) +where + T: BlasFloat, + DA: DimAPI, + DB: DimAPI, + DeviceBLAS: LapackDriverAPI, +{ + type Out = TrB; + fn solve_general_f(self) -> Result { + let (a, mut b) = self; + rstsr_assert_eq!(a.ndim(), 2, InvalidLayout, "Currently we can only handle 2-D matrix.")?; + rstsr_pattern!(b.ndim(), 1..=2, InvalidLayout, "Currently we can only handle 1/2-D matrix.")?; + let is_b_vec = b.ndim() == 1; + let a_view = a.view().into_dim::(); + let b_view = match is_b_vec { + true => b.i_mut((.., None)).into_dim::(), + false => b.view_mut().into_dim::(), + }; + let result = ref_impl_solve_general_f(a_view.into(), b_view.into())?; + result.clone_to_mut(); + Ok(b) + } +} + +#[duplicate_item( + ImplType TrA TrB ; + [T, DA, DB, R: DataAPI>] [TensorMut<'_, T, DeviceBLAS, DA>] [&TensorAny ]; + [T, DA, DB, ] [TensorMut<'_, T, DeviceBLAS, DA>] [TensorView<'_, T, DeviceBLAS, DB>]; + [T, DA, DB, R: DataAPI>] [Tensor ] [&TensorAny ]; + [T, DA, DB, ] [Tensor ] [TensorView<'_, T, DeviceBLAS, DB>]; +)] +impl SolveGeneralAPI for (TrA, TrB) +where + T: BlasFloat, + DA: DimAPI, + DB: DimAPI, + DeviceBLAS: LapackDriverAPI, +{ + type Out = Tensor; + fn solve_general_f(self) -> Result { + let (mut a, b) = self; + rstsr_assert_eq!(a.ndim(), 2, InvalidLayout, "Currently we can only handle 2-D matrix.")?; + rstsr_pattern!(b.ndim(), 1..=2, InvalidLayout, "Currently we can only handle 1/2-D matrix.")?; + let is_b_vec = b.ndim() == 1; + let a_view = a.view_mut().into_dim::(); + let b_view = match is_b_vec { + true => b.i((.., None)).into_dim::(), + false => b.view().into_dim::(), + }; + let result = ref_impl_solve_general_f(a_view.into(), b_view.into())?; + let result = result.into_owned().into_dim::(); + match is_b_vec { + true => Ok(result.into_shape(-1).into_dim::()), + false => Ok(result.into_dim::()), + } + } +} + +#[duplicate_item( + ImplType TrA TrB ; + ['b, T, DA, DB] [TensorMut<'_, T, DeviceBLAS, DA>] [TensorMut<'b, T, DeviceBLAS, DB>]; + [ T, DA, DB] [TensorMut<'_, T, DeviceBLAS, DA>] [Tensor ]; + ['b, T, DA, DB] [Tensor ] [TensorMut<'b, T, DeviceBLAS, DB>]; + [ T, DA, DB] [Tensor ] [Tensor ]; +)] +impl SolveGeneralAPI for (TrA, TrB) +where + T: BlasFloat, + DA: DimAPI, + DB: DimAPI, + DeviceBLAS: LapackDriverAPI, +{ + type Out = TrB; + fn solve_general_f(self) -> Result { + let (mut a, mut b) = self; + rstsr_assert_eq!(a.ndim(), 2, InvalidLayout, "Currently we can only handle 2-D matrix.")?; + rstsr_pattern!(b.ndim(), 1..=2, InvalidLayout, "Currently we can only handle 1/2-D matrix.")?; + let is_b_vec = b.ndim() == 1; + let a_view = a.view_mut().into_dim::(); + let b_view = match is_b_vec { + true => b.i_mut((.., None)).into_dim::(), + false => b.view_mut().into_dim::(), + }; + let result = ref_impl_solve_general_f(a_view.into(), b_view.into())?; + result.clone_to_mut(); + Ok(b) + } +} diff --git a/crates-device/rstsr-accelerate/src/linalg_auto_impl/solve_symmetric.rs b/crates-device/rstsr-accelerate/src/linalg_auto_impl/solve_symmetric.rs new file mode 100644 index 00000000..d61db21e --- /dev/null +++ b/crates-device/rstsr-accelerate/src/linalg_auto_impl/solve_symmetric.rs @@ -0,0 +1,160 @@ +use crate::DeviceBLAS; +use rstsr_blas_traits::prelude::*; +use rstsr_core::prelude_dev::*; +use rstsr_linalg_traits::prelude_dev::*; + +/* #region full-args */ + +#[duplicate_item( + ImplType TrA TrB ; + [T, DA, DB, Ra: DataAPI>, Rb: DataAPI>] [&TensorAny] [&TensorAny]; + [T, DA, DB, R: DataAPI> ] [&TensorAny ] [TensorView<'_, T, 
DeviceBLAS, DB>]; + [T, DA, DB, R: DataAPI> ] [TensorView<'_, T, DeviceBLAS, DA>] [&TensorAny ]; + [T, DA, DB, ] [TensorView<'_, T, DeviceBLAS, DA>] [TensorView<'_, T, DeviceBLAS, DB>]; +)] +impl SolveSymmetricAPI for (TrA, TrB, bool, Option) +where + T: BlasFloat, + DA: DimAPI, + DB: DimAPI, + DeviceBLAS: LapackDriverAPI, +{ + type Out = Tensor; + fn solve_symmetric_f(self) -> Result { + let (a, b, hermi, uplo) = self; + rstsr_assert_eq!(a.ndim(), 2, InvalidLayout, "Currently we can only handle 2-D matrix.")?; + rstsr_pattern!(b.ndim(), 1..=2, InvalidLayout, "Currently we can only handle 1/2-D matrix.")?; + let is_b_vec = b.ndim() == 1; + let a_view = a.view().into_dim::(); + let b_view = match is_b_vec { + true => b.i((.., None)).into_dim::(), + false => b.view().into_dim::(), + }; + let result = ref_impl_solve_symmetric_f(a_view.into(), b_view.into(), hermi, uplo)?; + let result = result.into_owned().into_dim::(); + match is_b_vec { + true => Ok(result.into_shape(-1).into_dim::()), + false => Ok(result.into_dim::()), + } + } +} + +#[duplicate_item( + ImplType TrA TrB ; + ['b, T, DA, DB, R: DataAPI>] [&TensorAny ] [TensorMut<'b, T, DeviceBLAS, DB>]; + ['b, T, DA, DB, ] [TensorView<'_, T, DeviceBLAS, DA>] [TensorMut<'b, T, DeviceBLAS, DB>]; + [ T, DA, DB, R: DataAPI>] [&TensorAny ] [Tensor ]; + [ T, DA, DB, ] [TensorView<'_, T, DeviceBLAS, DA>] [Tensor ]; +)] +impl SolveSymmetricAPI for (TrA, TrB, bool, Option) +where + T: BlasFloat, + DA: DimAPI, + DB: DimAPI, + DeviceBLAS: LapackDriverAPI, +{ + type Out = TrB; + fn solve_symmetric_f(self) -> Result { + let (a, mut b, hermi, uplo) = self; + rstsr_assert_eq!(a.ndim(), 2, InvalidLayout, "Currently we can only handle 2-D matrix.")?; + rstsr_pattern!(b.ndim(), 1..=2, InvalidLayout, "Currently we can only handle 1/2-D matrix.")?; + let is_b_vec = b.ndim() == 1; + let a_view = a.view().into_dim::(); + let b_view = match is_b_vec { + true => b.i_mut((.., None)).into_dim::(), + false => b.view_mut().into_dim::(), + }; + let result = ref_impl_solve_symmetric_f(a_view.into(), b_view.into(), hermi, uplo)?; + result.clone_to_mut(); + Ok(b) + } +} + +#[duplicate_item( + ImplType TrA TrB ; + [T, DA, DB, R: DataAPI>] [TensorMut<'_, T, DeviceBLAS, DA>] [&TensorAny ]; + [T, DA, DB, ] [TensorMut<'_, T, DeviceBLAS, DA>] [TensorView<'_, T, DeviceBLAS, DB>]; + [T, DA, DB, R: DataAPI>] [Tensor ] [&TensorAny ]; + [T, DA, DB, ] [Tensor ] [TensorView<'_, T, DeviceBLAS, DB>]; +)] +impl SolveSymmetricAPI for (TrA, TrB, bool, Option) +where + T: BlasFloat, + DA: DimAPI, + DB: DimAPI, + DeviceBLAS: LapackDriverAPI, +{ + type Out = Tensor; + fn solve_symmetric_f(self) -> Result { + let (mut a, b, hermi, uplo) = self; + rstsr_assert_eq!(a.ndim(), 2, InvalidLayout, "Currently we can only handle 2-D matrix.")?; + rstsr_pattern!(b.ndim(), 1..=2, InvalidLayout, "Currently we can only handle 1/2-D matrix.")?; + let is_b_vec = b.ndim() == 1; + let a_view = a.view_mut().into_dim::(); + let b_view = match is_b_vec { + true => b.i((.., None)).into_dim::(), + false => b.view().into_dim::(), + }; + let result = ref_impl_solve_symmetric_f(a_view.into(), b_view.into(), hermi, uplo)?; + let result = result.into_owned().into_dim::(); + match is_b_vec { + true => Ok(result.into_shape(-1).into_dim::()), + false => Ok(result.into_dim::()), + } + } +} + +#[duplicate_item( + ImplType TrA TrB ; + ['b, T, DA, DB] [TensorMut<'_, T, DeviceBLAS, DA>] [TensorMut<'b, T, DeviceBLAS, DB>]; + [ T, DA, DB] [TensorMut<'_, T, DeviceBLAS, DA>] [Tensor ]; + ['b, T, DA, DB] [Tensor ] [TensorMut<'b, T, 
DeviceBLAS, DB>]; + [ T, DA, DB] [Tensor ] [Tensor ]; +)] +impl SolveSymmetricAPI for (TrA, TrB, bool, Option) +where + T: BlasFloat, + DA: DimAPI, + DB: DimAPI, + DeviceBLAS: LapackDriverAPI, +{ + type Out = TrB; + fn solve_symmetric_f(self) -> Result { + let (mut a, mut b, hermi, uplo) = self; + rstsr_assert_eq!(a.ndim(), 2, InvalidLayout, "Currently we can only handle 2-D matrix.")?; + rstsr_pattern!(b.ndim(), 1..=2, InvalidLayout, "Currently we can only handle 1/2-D matrix.")?; + let is_b_vec = b.ndim() == 1; + let a_view = a.view_mut().into_dim::(); + let b_view = match is_b_vec { + true => b.i_mut((.., None)).into_dim::(), + false => b.view_mut().into_dim::(), + }; + let result = ref_impl_solve_symmetric_f(a_view.into(), b_view.into(), hermi, uplo)?; + result.clone_to_mut(); + Ok(b) + } +} + +/* #endregion */ + +/* #region sub-args */ + +#[duplicate_item( + ImplStruct args_tuple internal_tuple ; + [(TrA, TrB, bool, FlagUpLo)] [(a, b, hermi, uplo)] [(a, b, hermi, Some(uplo))]; + [(TrA, TrB, bool, )] [(a, b, hermi, )] [(a, b, hermi, None )]; + [(TrA, TrB, FlagUpLo)] [(a, b, uplo)] [(a, b, true , Some(uplo))]; + [(TrA, TrB, )] [(a, b, )] [(a, b, true , None )]; +)] +impl SolveSymmetricAPI for ImplStruct +where + (TrA, TrB, bool, Option): SolveSymmetricAPI, +{ + type Out = <(TrA, TrB, bool, Option) as SolveSymmetricAPI>::Out; + fn solve_symmetric_f(self) -> Result { + let args_tuple = self; + SolveSymmetricAPI::::solve_symmetric_f(internal_tuple) + } +} + +/* #endregion */ diff --git a/crates-device/rstsr-accelerate/src/linalg_auto_impl/solve_triangular.rs b/crates-device/rstsr-accelerate/src/linalg_auto_impl/solve_triangular.rs new file mode 100644 index 00000000..bbb27cc9 --- /dev/null +++ b/crates-device/rstsr-accelerate/src/linalg_auto_impl/solve_triangular.rs @@ -0,0 +1,158 @@ +use crate::DeviceBLAS; +use rstsr_blas_traits::prelude::*; +use rstsr_core::prelude_dev::*; +use rstsr_linalg_traits::prelude_dev::*; + +/* #region full-args */ + +#[duplicate_item( + ImplType TrA TrB ; + [T, DA, DB, Ra: DataAPI>, Rb: DataAPI>] [&TensorAny] [&TensorAny]; + [T, DA, DB, R: DataAPI> ] [&TensorAny ] [TensorView<'_, T, DeviceBLAS, DB>]; + [T, DA, DB, R: DataAPI> ] [TensorView<'_, T, DeviceBLAS, DA>] [&TensorAny ]; + [T, DA, DB, ] [TensorView<'_, T, DeviceBLAS, DA>] [TensorView<'_, T, DeviceBLAS, DB>]; +)] +impl SolveTriangularAPI for (TrA, TrB, Option) +where + T: BlasFloat, + DA: DimAPI, + DB: DimAPI, + DeviceBLAS: LapackDriverAPI, +{ + type Out = Tensor; + fn solve_triangular_f(self) -> Result { + let (a, b, uplo) = self; + rstsr_assert_eq!(a.ndim(), 2, InvalidLayout, "Currently we can only handle 2-D matrix.")?; + rstsr_pattern!(b.ndim(), 1..=2, InvalidLayout, "Currently we can only handle 1/2-D matrix.")?; + let is_b_vec = b.ndim() == 1; + let a_view = a.view().into_dim::(); + let b_view = match is_b_vec { + true => b.i((.., None)).into_dim::(), + false => b.view().into_dim::(), + }; + let result = ref_impl_solve_triangular_f(a_view.into(), b_view.into(), uplo)?; + let result = result.into_owned().into_dim::(); + match is_b_vec { + true => Ok(result.into_shape(-1).into_dim::()), + false => Ok(result.into_dim::()), + } + } +} + +#[duplicate_item( + ImplType TrA TrB ; + ['b, T, DA, DB, R: DataAPI>] [&TensorAny ] [TensorMut<'b, T, DeviceBLAS, DB>]; + ['b, T, DA, DB, ] [TensorView<'_, T, DeviceBLAS, DA>] [TensorMut<'b, T, DeviceBLAS, DB>]; + [ T, DA, DB, R: DataAPI>] [&TensorAny ] [Tensor ]; + [ T, DA, DB, ] [TensorView<'_, T, DeviceBLAS, DA>] [Tensor ]; +)] +impl SolveTriangularAPI for (TrA, TrB, 
Option) +where + T: BlasFloat, + DA: DimAPI, + DB: DimAPI, + DeviceBLAS: LapackDriverAPI, +{ + type Out = TrB; + fn solve_triangular_f(self) -> Result { + let (a, mut b, uplo) = self; + rstsr_assert_eq!(a.ndim(), 2, InvalidLayout, "Currently we can only handle 2-D matrix.")?; + rstsr_pattern!(b.ndim(), 1..=2, InvalidLayout, "Currently we can only handle 1/2-D matrix.")?; + let is_b_vec = b.ndim() == 1; + let a_view = a.view().into_dim::(); + let b_view = match is_b_vec { + true => b.i_mut((.., None)).into_dim::(), + false => b.view_mut().into_dim::(), + }; + let result = ref_impl_solve_triangular_f(a_view.into(), b_view.into(), uplo)?; + result.clone_to_mut(); + Ok(b) + } +} + +#[duplicate_item( + ImplType TrA TrB ; + [T, DA, DB, R: DataAPI>] [TensorMut<'_, T, DeviceBLAS, DA>] [&TensorAny ]; + [T, DA, DB, ] [TensorMut<'_, T, DeviceBLAS, DA>] [TensorView<'_, T, DeviceBLAS, DB>]; + [T, DA, DB, R: DataAPI>] [Tensor ] [&TensorAny ]; + [T, DA, DB, ] [Tensor ] [TensorView<'_, T, DeviceBLAS, DB>]; +)] +impl SolveTriangularAPI for (TrA, TrB, Option) +where + T: BlasFloat, + DA: DimAPI, + DB: DimAPI, + DeviceBLAS: LapackDriverAPI, +{ + type Out = Tensor; + fn solve_triangular_f(self) -> Result { + let (mut a, b, uplo) = self; + rstsr_assert_eq!(a.ndim(), 2, InvalidLayout, "Currently we can only handle 2-D matrix.")?; + rstsr_pattern!(b.ndim(), 1..=2, InvalidLayout, "Currently we can only handle 1/2-D matrix.")?; + let is_b_vec = b.ndim() == 1; + let a_view = a.view_mut().into_dim::(); + let b_view = match is_b_vec { + true => b.i((.., None)).into_dim::(), + false => b.view().into_dim::(), + }; + let result = ref_impl_solve_triangular_f(a_view.into(), b_view.into(), uplo)?; + let result = result.into_owned().into_dim::(); + match is_b_vec { + true => Ok(result.into_shape(-1).into_dim::()), + false => Ok(result.into_dim::()), + } + } +} + +#[duplicate_item( + ImplType TrA TrB ; + ['b, T, DA, DB] [TensorMut<'_, T, DeviceBLAS, DA>] [TensorMut<'b, T, DeviceBLAS, DB>]; + [ T, DA, DB] [TensorMut<'_, T, DeviceBLAS, DA>] [Tensor ]; + ['b, T, DA, DB] [Tensor ] [TensorMut<'b, T, DeviceBLAS, DB>]; + [ T, DA, DB] [Tensor ] [Tensor ]; +)] +impl SolveTriangularAPI for (TrA, TrB, Option) +where + T: BlasFloat, + DA: DimAPI, + DB: DimAPI, + DeviceBLAS: LapackDriverAPI, +{ + type Out = TrB; + fn solve_triangular_f(self) -> Result { + let (mut a, mut b, uplo) = self; + rstsr_assert_eq!(a.ndim(), 2, InvalidLayout, "Currently we can only handle 2-D matrix.")?; + rstsr_pattern!(b.ndim(), 1..=2, InvalidLayout, "Currently we can only handle 1/2-D matrix.")?; + let is_b_vec = b.ndim() == 1; + let a_view = a.view_mut().into_dim::(); + let b_view = match is_b_vec { + true => b.i_mut((.., None)).into_dim::(), + false => b.view_mut().into_dim::(), + }; + let result = ref_impl_solve_triangular_f(a_view.into(), b_view.into(), uplo)?; + result.clone_to_mut(); + Ok(b) + } +} + +/* #endregion */ + +/* #region sub-args */ + +#[duplicate_item( + ImplStruct args_tuple internal_tuple ; + [(TrA, TrB, FlagUpLo)] [(a, b, uplo)] [(a, b, Some(uplo))]; + [(TrA, TrB, )] [(a, b, )] [(a, b, None )]; +)] +impl SolveTriangularAPI for ImplStruct +where + (TrA, TrB, Option): SolveTriangularAPI, +{ + type Out = <(TrA, TrB, Option) as SolveTriangularAPI>::Out; + fn solve_triangular_f(self) -> Result { + let args_tuple = self; + SolveTriangularAPI::::solve_triangular_f(internal_tuple) + } +} + +/* #endregion */ diff --git a/crates-device/rstsr-accelerate/src/linalg_auto_impl/svd.rs b/crates-device/rstsr-accelerate/src/linalg_auto_impl/svd.rs new file 
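
The sub-args region above forwards the shorter argument tuples to the full-args implementation by filling in the omitted `uplo` as `Some(uplo)` or `None`. Below is a minimal, self-contained sketch of that tuple-overloading pattern with the impls expanded by hand instead of via `#[duplicate_item]`; `SolveAPI`, `Matrix` and `UpLo` are placeholder names for illustration only, not rstsr types.

// Placeholder trait and types; not rstsr's actual API.
#[derive(Clone, Copy, Debug)]
enum UpLo { Upper, Lower }

struct Matrix;

trait SolveAPI {
    type Out;
    fn solve_f(self) -> Self::Out;
}

// full-args implementation: (a, b, Option<UpLo>)
impl SolveAPI for (Matrix, Matrix, Option<UpLo>) {
    type Out = String;
    fn solve_f(self) -> String {
        let (_a, _b, uplo) = self;
        format!("solve with uplo = {uplo:?}")
    }
}

// sub-args implementations delegate by filling in the omitted argument
impl SolveAPI for (Matrix, Matrix, UpLo) {
    type Out = <(Matrix, Matrix, Option<UpLo>) as SolveAPI>::Out;
    fn solve_f(self) -> Self::Out {
        let (a, b, uplo) = self;
        (a, b, Some(uplo)).solve_f()
    }
}

impl SolveAPI for (Matrix, Matrix) {
    type Out = <(Matrix, Matrix, Option<UpLo>) as SolveAPI>::Out;
    fn solve_f(self) -> Self::Out {
        let (a, b) = self;
        (a, b, None).solve_f()
    }
}

fn main() {
    println!("{}", (Matrix, Matrix, UpLo::Lower).solve_f());
    println!("{}", (Matrix, Matrix, UpLo::Upper).solve_f());
    println!("{}", (Matrix, Matrix).solve_f());
}
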
mode 100644 index 00000000..9e10e916 --- /dev/null +++ b/crates-device/rstsr-accelerate/src/linalg_auto_impl/svd.rs @@ -0,0 +1,109 @@ +use crate::DeviceBLAS; +use rstsr_blas_traits::prelude::*; +use rstsr_core::prelude_dev::*; +use rstsr_linalg_traits::prelude_dev::*; + +/* #region full-args */ + +impl SVDAPI for (&TensorAny, bool) +where + R: DataAPI>, + T: BlasFloat, + D: DimAPI + DimSmallerOneAPI, + D::SmallerOne: DimAPI, + DeviceBLAS: LapackDriverAPI, +{ + type Out = + SVDResult, Tensor, Tensor>; + fn svd_f(self) -> Result { + let (a, full_matrices) = self; + rstsr_assert_eq!(a.ndim(), 2, InvalidLayout, "Currently we can only handle 2-D matrix.")?; + let a = a.view().into_dim::(); + let svd_args = SVDArgs::default().a(a).full_matrices(full_matrices).build()?; + let (u, s, vt) = ref_impl_svd_simple_f(svd_args)?; + // convert dimensions + let u = u.unwrap().into_dim::().into_dim::(); + let vt = vt.unwrap().into_dim::().into_dim::(); + let s = s.into_dim::().into_dim::(); + Ok(SVDResult { u, s, vt }) + } +} + +#[duplicate_item( + Tr; [Tensor]; [TensorView<'_, T, DeviceBLAS, D>]; +)] +impl SVDAPI for (Tr, bool) +where + T: BlasFloat, + D: DimAPI + DimSmallerOneAPI, + D::SmallerOne: DimAPI, + DeviceBLAS: LapackDriverAPI, +{ + type Out = + SVDResult, Tensor, Tensor>; + fn svd_f(self) -> Result { + let (a, full_matrices) = self; + SVDAPI::::svd_f((&a, full_matrices)) + } +} + +/* #endregion */ + +/* #region sub-args */ + +#[duplicate_item( + ImplType Tr; + ['a, T, D, R: DataAPI>] [&'a TensorAny]; + ['a, T, D, ] [TensorView<'a, T, DeviceBLAS, D> ]; + [ T, D ] [Tensor ]; +)] +impl SVDAPI for Tr +where + T: BlasFloat, + D: DimAPI, + (Tr, bool): SVDAPI, +{ + type Out = <(Tr, bool) as SVDAPI>::Out; + fn svd_f(self) -> Result { + let a = self; + SVDAPI::::svd_f((a, true)) + } +} + +/* #endregion */ + +/* #region SVDArgs implementation */ + +impl<'a, T> SVDAPI for SVDArgs<'a, DeviceBLAS, T> +where + T: BlasFloat, + DeviceBLAS: LapackDriverAPI, +{ + type Out = SVDResult, Tensor, Tensor>; + fn svd_f(self) -> Result { + SVDAPI::::svd_f(self.build()?) + } +} + +impl<'a, T> SVDAPI for SVDArgs_<'a, DeviceBLAS, T> +where + T: BlasFloat, + DeviceBLAS: LapackDriverAPI, +{ + type Out = SVDResult, Tensor, Tensor>; + fn svd_f(self) -> Result { + let args = self; + rstsr_assert!( + args.full_matrices.is_some(), + InvalidValue, + "`svd` must compute UV. Refer to `svdvals` if UV is not required." 
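
The `full_matrices` flag handled above only changes the shapes of `u` and `vt`: square factors for the full decomposition, factors truncated to `k = min(m, n)` for the economy one. A shape-only sketch of the expected outputs for an `(m, n)` input, independent of the actual LAPACK driver:

// Shape-only sketch: output shapes of svd for an (m, n) input.
fn svd_shapes(m: usize, n: usize, full_matrices: bool) -> ((usize, usize), usize, (usize, usize)) {
    let k = m.min(n);
    if full_matrices {
        // u: (m, m), s: (k,), vt: (n, n)
        ((m, m), k, (n, n))
    } else {
        // economy SVD: u: (m, k), s: (k,), vt: (k, n)
        ((m, k), k, (k, n))
    }
}

fn main() {
    assert_eq!(svd_shapes(5, 3, true), ((5, 5), 3, (3, 3)));
    assert_eq!(svd_shapes(5, 3, false), ((5, 3), 3, (3, 3)));
}
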
+ )?; + let (u, s, vt) = ref_impl_svd_simple_f(args)?; + let u = u.unwrap().into_dim::().into_dim::(); + let vt = vt.unwrap().into_dim::().into_dim::(); + let s = s.into_dim::().into_dim::(); + Ok(SVDResult { u, s, vt }) + } +} + +/* #endregion */ diff --git a/crates-device/rstsr-accelerate/src/linalg_auto_impl/svdvals.rs b/crates-device/rstsr-accelerate/src/linalg_auto_impl/svdvals.rs new file mode 100644 index 00000000..3843b176 --- /dev/null +++ b/crates-device/rstsr-accelerate/src/linalg_auto_impl/svdvals.rs @@ -0,0 +1,80 @@ +use crate::DeviceBLAS; +use rstsr_blas_traits::prelude::*; +use rstsr_core::prelude_dev::*; +use rstsr_linalg_traits::prelude_dev::*; + +/* #region full-args */ + +impl SVDvalsAPI for &TensorAny +where + R: DataAPI>, + T: BlasFloat, + D: DimAPI + DimSmallerOneAPI, + D::SmallerOne: DimAPI, + DeviceBLAS: LapackDriverAPI, +{ + type Out = Tensor; + fn svdvals_f(self) -> Result { + let a = self; + rstsr_assert_eq!(a.ndim(), 2, InvalidLayout, "Currently we can only handle 2-D matrix.")?; + let a = a.view().into_dim::(); + let svd_args = SVDArgs::default().a(a).full_matrices(None).build()?; + let (_, s, _) = ref_impl_svd_simple_f(svd_args)?; + // convert dimensions + let s = s.into_dim::().into_dim::(); + Ok(s) + } +} + +#[duplicate_item( + Tr; [Tensor]; [TensorView<'_, T, DeviceBLAS, D>]; +)] +impl SVDvalsAPI for Tr +where + T: BlasFloat, + D: DimAPI + DimSmallerOneAPI, + D::SmallerOne: DimAPI, + DeviceBLAS: LapackDriverAPI, +{ + type Out = Tensor; + fn svdvals_f(self) -> Result { + let a = self; + SVDvalsAPI::::svdvals_f(&a) + } +} + +/* #endregion */ + +/* #region SVDArgs implementation */ + +impl<'a, T> SVDvalsAPI for SVDArgs<'a, DeviceBLAS, T> +where + T: BlasFloat, + DeviceBLAS: LapackDriverAPI, +{ + type Out = Tensor; + fn svdvals_f(self) -> Result { + SVDvalsAPI::::svdvals_f(self.build()?) + } +} + +impl<'a, T> SVDvalsAPI for SVDArgs_<'a, DeviceBLAS, T> +where + T: BlasFloat, + DeviceBLAS: LapackDriverAPI, +{ + type Out = Tensor; + fn svdvals_f(self) -> Result { + let args = self; + rstsr_assert!( + args.full_matrices.is_none(), + InvalidValue, + "`svdvals` must not compute UV. Refer to `svd` if UV is required." 
+ )?; + let (_, s, _) = ref_impl_svd_simple_f(args)?; + let s = s.into_dim::().into_dim::(); + Ok(s) + } +} + +/* #endregion */ diff --git a/crates-device/rstsr-accelerate/src/matmul.rs b/crates-device/rstsr-accelerate/src/matmul.rs new file mode 100644 index 00000000..03d4042a --- /dev/null +++ b/crates-device/rstsr-accelerate/src/matmul.rs @@ -0,0 +1,479 @@ +use crate::matmul_impl::*; +use crate::prelude_dev::*; +use crate::threading::with_num_threads; +use core::any::TypeId; +use core::ops::{Add, Mul}; +use core::slice::{from_raw_parts, from_raw_parts_mut}; +use num::{Complex, Zero}; +use rayon::prelude::*; + +// code from ndarray +fn same_type() -> bool { + TypeId::of::() == TypeId::of::() +} + +#[allow(clippy::too_many_arguments)] +pub fn gemm_blas_ix2_no_conj_dispatch( + c: &mut [TC], + lc: &Layout, + a: &[TA], + la: &Layout, + b: &[TB], + lb: &Layout, + alpha: TC, + beta: TC, + pool: Option<&ThreadPool>, +) -> Result<()> +where + TA: Clone + Send + Sync + 'static, + TB: Clone + Send + Sync + 'static, + TC: Clone + Send + Sync + 'static, + TA: Mul, + TC: Mul + Add + Zero + PartialEq, +{ + // check if syrk could be applicable + let able_syrk = beta == TC::zero() + && same_type::() + && same_type::() + && unsafe { + let a_ptr = a.as_ptr().add(la.offset()) as *const TC; + let b_ptr = b.as_ptr().add(lb.offset()) as *const TC; + let equal_ptr = core::ptr::eq(a_ptr, b_ptr); + let equal_shape = la.shape() == lb.reverse_axes().shape(); + let equal_stride = la.stride() == lb.reverse_axes().stride(); + equal_ptr && equal_shape && equal_stride + }; + + // type check and dispatch + macro_rules! impl_gemm_dispatch { + ($ty: ty, $fn_gemm_name: ident, $fn_syrk_name: ident) => { + if (same_type::() && same_type::() && same_type::()) { + let a_slice = unsafe { from_raw_parts(a.as_ptr() as *const $ty, a.len()) }; + let b_slice = unsafe { from_raw_parts(b.as_ptr() as *const $ty, b.len()) }; + let c_slice = unsafe { from_raw_parts_mut(c.as_mut_ptr() as *mut $ty, c.len()) }; + let alpha = unsafe { *(&alpha as *const TC as *const $ty) }; + let beta = unsafe { *(&beta as *const TC as *const $ty) }; + if able_syrk { + $fn_syrk_name(c_slice, lc, a_slice, la, alpha, pool)?; + } else { + $fn_gemm_name(c_slice, lc, a_slice, la, b_slice, lb, alpha, beta, pool)?; + } + return Ok(()); + } + }; + } + + impl_gemm_dispatch!(f32, gemm_blas_no_conj_f32, syrk_blas_no_conj_f32); + impl_gemm_dispatch!(f64, gemm_blas_no_conj_f64, syrk_blas_no_conj_f64); + impl_gemm_dispatch!(Complex, gemm_blas_no_conj_c32, syrk_blas_no_conj_c32); + impl_gemm_dispatch!(Complex, gemm_blas_no_conj_c64, syrk_blas_no_conj_c64); + + // not able to be accelarated by blas_no_conj + // fallback to naive implementation + let c_slice = c; + let a_slice = a; + let b_slice = b; + return gemm_ix2_naive_cpu_rayon(c_slice, lc, a_slice, la, b_slice, lb, alpha, beta, pool); +} + +#[allow(clippy::too_many_arguments)] +pub fn matmul_row_major_blas( + c: &mut [TC], + lc: &Layout, + a: &[TA], + la: &Layout, + b: &[TB], + lb: &Layout, + alpha: TC, + beta: TC, + pool: Option<&ThreadPool>, +) -> Result<()> +where + TA: Clone + Send + Sync + 'static, + TB: Clone + Send + Sync + 'static, + TC: Clone + Send + Sync + 'static, + DA: DimAPI, + DB: DimAPI, + DC: DimAPI, + TA: Mul, + TC: Mul + Add + Zero + PartialEq, +{ + // NOTE: this only works for row-major layout + // for column-major layout, we need to transpose the input: + // C = A * B => C^T = B^T * A^T + + // quick return for empty matrix + // in this case, we do not check the shape of a, b, c + if lc.size() == 0 
{ + return Ok(()); + } + + let nthreads = match pool { + Some(pool) => pool.current_num_threads(), + None => 1, + }; + + // handle special cases + match (la.ndim(), lb.ndim(), lc.ndim()) { + (1, 1, 0) => { + // rule 1: vector inner dot + let la = &la.clone().into_dim::().unwrap(); + let lb = &lb.clone().into_dim::().unwrap(); + let lc = &lc.clone().into_dim::().unwrap(); + let c_num = &mut c[lc.offset()]; + return with_num_threads(nthreads, || inner_dot_naive_cpu_rayon(c_num, a, la, b, lb, alpha, beta, pool)); + }, + (2, 2, 2) => { + // rule 2: matrix multiplication + let la = &la.clone().into_dim::().unwrap(); + let lb = &lb.clone().into_dim::().unwrap(); + let lc = &lc.clone().into_dim::().unwrap(); + return with_num_threads(nthreads, || { + gemm_blas_ix2_no_conj_dispatch(c, lc, a, la, b, lb, alpha, beta, pool) + }); + }, + _ => (), + }; + + // handle broadcasted cases + // temporary variables + let la_matmul; + let lb_matmul; + let lc_matmul; + let la_rest; + let lb_rest; + let lc_rest; + + match (la.ndim(), lb.ndim(), lc.ndim()) { + // we have already handled these cases + (1, 1, 0) | (2, 2, 2) => unreachable!(), + (1, 2.., _) => { + // rule 3: | ` K` | `..., K, N` | ` ..., N` | + rstsr_assert_eq!(lb.ndim(), lc.ndim() + 1, InvalidLayout)?; + let (la_r, la_m) = la.dim_split_at(-1)?; + let (lb_r, lb_m) = lb.dim_split_at(-2)?; + let (lc_r, lc_m) = lc.dim_split_at(-1)?; + la_rest = broadcast_layout_to_first(&lc_r, &la_r, RowMajor)?.1; + lb_rest = lb_r; + lc_rest = lc_r; + la_matmul = la_m.dim_insert(0)?.into_dim::()?; + lb_matmul = lb_m.into_dim::()?; + lc_matmul = lc_m.dim_insert(0)?.into_dim::()?; + }, + (2.., 1, _) => { + // rule 4: | `..., M, K` | ` K` | ` ..., M` | + rstsr_assert_eq!(la.ndim(), lc.ndim() + 1, InvalidLayout)?; + let (la_r, la_m) = la.dim_split_at(-2)?; + let (lb_r, lb_m) = lb.dim_split_at(-1)?; + let (lc_r, lc_m) = lc.dim_split_at(-1)?; + la_rest = la_r; + lb_rest = broadcast_layout_to_first(&lc_r, &lb_r, RowMajor)?.1; + lc_rest = lc_r; + la_matmul = la_m.into_dim::()?; + lb_matmul = lb_m.dim_insert(1)?.into_dim::()?; + lc_matmul = lc_m.dim_insert(1)?.into_dim::()?; + }, + (2, 3.., _) => { + // rule 5: | ` M, K` | `..., K, N` | `..., M, N` | + rstsr_assert_eq!(lb.ndim(), lc.ndim(), InvalidLayout)?; + let (la_r, la_m) = la.dim_split_at(-2)?; + let (lb_r, lb_m) = lb.dim_split_at(-2)?; + let (lc_r, lc_m) = lc.dim_split_at(-2)?; + la_rest = broadcast_layout_to_first(&lc_r, &la_r, RowMajor)?.1; + lb_rest = lb_r; + lc_rest = lc_r; + la_matmul = la_m.into_dim::()?; + lb_matmul = lb_m.into_dim::()?; + lc_matmul = lc_m.into_dim::()?; + }, + (3.., 2, _) => { + // rule 6: | `..., M, K` | ` K, N` | `..., M, N` | + rstsr_assert_eq!(la.ndim(), lc.ndim(), InvalidLayout)?; + let (la_r, la_m) = la.dim_split_at(-2)?; + let (lb_r, lb_m) = lb.dim_split_at(-2)?; + let (lc_r, lc_m) = lc.dim_split_at(-2)?; + la_rest = la_r; + lb_rest = broadcast_layout_to_first(&lc_r, &lb_r, RowMajor)?.1; + lc_rest = lc_r; + la_matmul = la_m.into_dim::()?; + lb_matmul = lb_m.into_dim::()?; + lc_matmul = lc_m.into_dim::()?; + }, + (3.., 3.., _) => { + // rule 7: | `..., M, K` | `..., K, N` | `..., M, N` | + rstsr_assert_eq!(la.ndim(), lc.ndim(), InvalidLayout)?; + rstsr_assert_eq!(lb.ndim(), lc.ndim(), InvalidLayout)?; + let (la_r, la_m) = la.dim_split_at(-2)?; + let (lb_r, lb_m) = lb.dim_split_at(-2)?; + let (lc_r, lc_m) = lc.dim_split_at(-2)?; + la_rest = la_r; + lb_rest = lb_r; + lc_rest = lc_r; + la_matmul = la_m.into_dim::()?; + lb_matmul = lb_m.into_dim::()?; + lc_matmul = lc_m.into_dim::()?; + }, + _ 
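
Rules 1-7 above follow NumPy-style matmul broadcasting: the last two axes take part in the matrix product and the leading axes are broadcast. A shape-only sketch of those rules, independent of rstsr layouts, with the batch-dimension broadcast simplified to taking the longer prefix:

// Shape-only sketch of the broadcasting rules above (NumPy-style matmul
// semantics). Batch-dimension broadcasting is simplified here: the shorter
// batch prefix is assumed to broadcast into the longer one.
fn matmul_out_shape(a: &[usize], b: &[usize]) -> Vec<usize> {
    match (a.len(), b.len()) {
        (0, _) | (_, 0) => panic!("scalars are not valid matmul operands"),
        // rule 1: (K,) x (K,) -> scalar
        (1, 1) => {
            assert_eq!(a[0], b[0]);
            vec![]
        },
        // rule 3: (K,) x (..., K, N) -> (..., N)
        (1, _) => {
            let (rest, mat) = b.split_at(b.len() - 2);
            assert_eq!(a[0], mat[0]);
            let mut out = rest.to_vec();
            out.push(mat[1]);
            out
        },
        // rule 4: (..., M, K) x (K,) -> (..., M)
        (_, 1) => {
            let (rest, mat) = a.split_at(a.len() - 2);
            assert_eq!(mat[1], b[0]);
            let mut out = rest.to_vec();
            out.push(mat[0]);
            out
        },
        // rules 2, 5, 6, 7: (..., M, K) x (..., K, N) -> (..., M, N)
        _ => {
            let (ra, ma) = a.split_at(a.len() - 2);
            let (rb, mb) = b.split_at(b.len() - 2);
            assert_eq!(ma[1], mb[0]);
            let rest = if ra.len() >= rb.len() { ra } else { rb };
            let mut out = rest.to_vec();
            out.extend_from_slice(&[ma[0], mb[1]]);
            out
        },
    }
}

fn main() {
    assert_eq!(matmul_out_shape(&[3], &[3]), Vec::<usize>::new());
    assert_eq!(matmul_out_shape(&[3], &[2, 3, 5]), vec![2, 5]);
    assert_eq!(matmul_out_shape(&[2, 3, 5], &[5]), vec![2, 3]);
    assert_eq!(matmul_out_shape(&[2, 3, 5], &[5, 4]), vec![2, 3, 4]);
}
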
=> { + rstsr_raise!(InvalidLayout, "This is not valid layout for matmul broadcasting.")?; + unreachable!() + }, + } + // now, lx_rest should have the same shape, while lx_matmul + // should be matmulable + // only parallel matmul when lx_rest is small (larger than + // 2*nthreads), otherwise parallel matmul anyway + rstsr_assert_eq!(la_rest.shape(), lb_rest.shape(), InvalidLayout)?; + rstsr_assert_eq!(lb_rest.shape(), lc_rest.shape(), InvalidLayout)?; + let n_task = la_rest.size(); + let ita_rest = IterLayoutColMajor::new(&la_rest)?; + let itb_rest = IterLayoutColMajor::new(&lb_rest)?; + let itc_rest = IterLayoutColMajor::new(&lc_rest)?; + if n_task >= 4 * nthreads { + // parallel outer, sequential matmul + with_num_threads(1, || { + let task = || { + ita_rest.into_par_iter().zip(itb_rest).zip(itc_rest).try_for_each( + |((ia_rest, ib_rest), ic_rest)| -> Result<()> { + // prepare layout + let mut la_m = la_matmul.clone(); + let mut lb_m = lb_matmul.clone(); + let mut lc_m = lc_matmul.clone(); + unsafe { + la_m.set_offset(ia_rest); + lb_m.set_offset(ib_rest); + lc_m.set_offset(ic_rest); + } + // move mutable reference into parallel closure + let c = unsafe { + let c_ptr = c.as_ptr() as *mut TC; + let c_len = c.len(); + from_raw_parts_mut(c_ptr, c_len) + }; + // clone alpha and beta + let alpha = alpha.clone(); + let beta = beta.clone(); + gemm_blas_ix2_no_conj_dispatch(c, &lc_m, a, &la_m, b, &lb_m, alpha, beta, None) + }, + ) + }; + match pool { + Some(pool) => pool.install(task), + None => task(), + } + }) + } else { + // sequential outer, parallel matmul + with_num_threads(nthreads, || -> Result<()> { + izip!(ita_rest, itb_rest, itc_rest).try_for_each(|(ia_rest, ib_rest, ic_rest)| { + // prepare layout + let mut la_m = la_matmul.clone(); + let mut lb_m = lb_matmul.clone(); + let mut lc_m = lc_matmul.clone(); + unsafe { + la_m.set_offset(ia_rest); + lb_m.set_offset(ib_rest); + lc_m.set_offset(ic_rest); + } + // clone alpha and beta + let alpha = alpha.clone(); + let beta = beta.clone(); + gemm_blas_ix2_no_conj_dispatch(c, &lc_m, a, &la_m, b, &lb_m, alpha, beta, pool) + }) + }) + } +} + +#[allow(clippy::too_many_arguments)] +impl DeviceMatMulAPI for DeviceBLAS +where + TA: Clone + Send + Sync + 'static, + TB: Clone + Send + Sync + 'static, + TC: Clone + Send + Sync + 'static, + DA: DimAPI, + DB: DimAPI, + DC: DimAPI, + TA: Mul, + TB: Mul, + TC: Mul + Add + Zero + PartialEq, +{ + fn matmul( + &self, + c: &mut Vec, + lc: &Layout, + a: &Vec, + la: &Layout, + b: &Vec, + lb: &Layout, + alpha: TC, + beta: TC, + ) -> Result<()> { + let default_order = self.default_order(); + let pool = self.get_current_pool(); + match default_order { + RowMajor => matmul_row_major_blas(c, lc, a, la, b, lb, alpha, beta, pool), + ColMajor => { + let la = la.reverse_axes(); + let lb = lb.reverse_axes(); + let lc = lc.reverse_axes(); + matmul_row_major_blas(c, &lc, b, &lb, a, &la, alpha, beta, pool) + }, + } + } +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_matmul() { + let device = DeviceBLAS::default(); + let a = linspace((0.0, 14.0, 15, &device)).into_shape([3, 5]); + let b = linspace((0.0, 14.0, 15, &device)).into_shape([5, 3]); + println!("{:}", &a % &b); + + let a = linspace((0.0, 14.0, 15, &device)); + let b = linspace((0.0, 14.0, 15, &device)); + println!("{:}", &a % &b); + + let a = linspace((0.0, 2.0, 3, &device)); + let b = linspace((0.0, 29.0, 30, &device)).into_shape([2, 3, 5]); + println!("{:}", &a % &b); + + let a = linspace((0.0, 29.0, 30, &device)).into_shape([2, 3, 5]); + let b 
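
The `ColMajor` branch of the `DeviceMatMulAPI` impl above relies on the identity C = A·B ⇔ Cᵀ = Bᵀ·Aᵀ, so a column-major product can be evaluated by the row-major path with the operands swapped and all layouts reversed. A small numeric check of that identity using a naive row-major triple loop, no BLAS involved:

// Naive row-major gemm and an explicit transpose; plain f64 buffers.
fn gemm_row_major(m: usize, n: usize, k: usize, a: &[f64], b: &[f64]) -> Vec<f64> {
    let mut c = vec![0.0; m * n];
    for i in 0..m {
        for j in 0..n {
            for p in 0..k {
                c[i * n + j] += a[i * k + p] * b[p * n + j];
            }
        }
    }
    c
}

fn transpose(rows: usize, cols: usize, x: &[f64]) -> Vec<f64> {
    let mut t = vec![0.0; rows * cols];
    for i in 0..rows {
        for j in 0..cols {
            t[j * rows + i] = x[i * cols + j];
        }
    }
    t
}

fn main() {
    let (m, n, k) = (2, 4, 3);
    let a: Vec<f64> = (0..m * k).map(|x| x as f64).collect();
    let b: Vec<f64> = (0..k * n).map(|x| x as f64).collect();
    // C computed directly ...
    let c = gemm_row_major(m, n, k, &a, &b);
    // ... must equal (B^T * A^T)^T computed with operands swapped and transposed
    let ct = gemm_row_major(n, m, k, &transpose(k, n, &b), &transpose(m, k, &a));
    assert_eq!(c, transpose(n, m, &ct));
}
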
= linspace((0.0, 4.0, 5, &device)); + println!("{:}", &a % &b); + + let a = linspace((0.0, 14.0, 15, &device)).into_shape([5, 3]); + let b = linspace((0.0, 29.0, 30, &device)).into_shape([2, 3, 5]); + println!("{:}", &a % &b); + + let a = linspace((0.0, 29.0, 30, &device)).into_shape([2, 3, 5]); + let b = linspace((0.0, 14.0, 15, &device)).into_shape([5, 3]); + println!("{:}", &a % &b); + } + + #[test] + #[ignore] + fn parallel_test_full() { + let device = DeviceBLAS::default(); + let a = linspace((0.0, 1.0, 8192 * 8192, &device)).into_shape([8192, 8192]); + let b = linspace((0.0, 1.0, 8192 * 8192, &device)).into_shape([8192, 8192]); + for _ in 0..10 { + let start = std::time::Instant::now(); + let _ = &a % &b; + println!("time: {:?}", start.elapsed()); + } + } + + #[test] + #[ignore] + fn parallel_test_full_512() { + let device = DeviceBLAS::new(1); + let a = linspace((0.0, 1.0, 512 * 512, &device)).into_shape([512, 512]); + let b = linspace((0.0, 1.0, 512 * 512, &device)).into_shape([512, 512]); + for _ in 0..1000 { + let start = std::time::Instant::now(); + let c = &a % &b; + println!("{:?}", c.device()); + println!("time: {:?}", start.elapsed()); + } + } + + #[test] + #[ignore] + fn parallel_test_par_rule7() { + let device = DeviceBLAS::default(); + let a = linspace((0.0, 1.0, 8192 * 8192, &device)).into_shape([256, 512, 512]); + let b = linspace((0.0, 1.0, 8192 * 8192, &device)).into_shape([256, 512, 512]); + for i in 0..10 { + let start = std::time::Instant::now(); + let c = &a % &b; + println!("{:?}", c.layout()); + println!("time: {:?}", start.elapsed()); + if i == 0 { + println!("{c:?}"); + } + } + } + + #[test] + #[ignore] + fn parallel_test_par_rule6() { + let device = DeviceBLAS::default(); + let a = linspace((0.0, 1.0, 8192 * 8192, &device)).into_shape([256, 512, 512]); + let b = linspace((0.0, 1.0, 512 * 512, &device)).into_shape([512, 512]); + for i in 0..10 { + let start = std::time::Instant::now(); + let c = &a % &b; + println!("{:?}", c.layout()); + println!("time: {:?}", start.elapsed()); + if i == 0 { + println!("{c:?}"); + } + } + } + + #[test] + #[ignore] + fn parallel_test_par_rule6_fprefer() { + let device = DeviceBLAS::default(); + let a = linspace((0.0, 1.0, 8192 * 8192, &device)).into_shape([512, 512, 256]).into_reverse_axes(); + let b = linspace((0.0, 1.0, 512 * 512, &device)).into_shape([512, 512]); + for i in 0..10 { + let start = std::time::Instant::now(); + let c = &a % &b; + println!("{:?}", c.layout()); + println!("time: {:?}", start.elapsed()); + if i == 0 { + println!("{c:?}"); + } + } + } + + #[test] + fn syrk_correctness() { + let device = DeviceBLAS::default(); + let a = linspace((0.0, 1.0, 512 * 512, &device)).into_shape([512, 512]); + let b = linspace((0.0, 1.0, 512 * 512, &device)).into_shape([512, 512]); + let c = &a % &a.t(); + let d = &a % &b.t(); + assert!(allclose_f64(&c, &d)); + + let device = DeviceBLAS::default(); + let a = linspace((0.0, 1.0, 1024 * 1024, &device)).into_shape([4, 512, 512]); + let b = linspace((0.0, 1.0, 1024 * 1024, &device)).into_shape([4, 512, 512]); + let c = &a % &a.swapaxes(-1, -2); + let d = &a % &b.swapaxes(-1, -2); + assert!(allclose_f64(&c, &d)); + } + + #[test] + #[ignore] + fn syrk_efficiency() { + use std::hint::black_box; + let device = DeviceBLAS::default(); + let a = linspace((0.0, 1.0, 8192 * 8192, &device)).into_shape([256, 512, 512]); + let b = linspace((0.0, 1.0, 8192 * 8192, &device)).into_shape([256, 512, 512]); + for _ in 0..10 { + let start = std::time::Instant::now(); + black_box(&a % 
&a.swapaxes(-1, -2)); + println!("syrk time: {:?}", start.elapsed()); + let start = std::time::Instant::now(); + black_box(&a % &b.swapaxes(-1, -2)); + println!("gemm time: {:?}", start.elapsed()); + } + + println!("---------------------"); + let a = linspace((0.0, 1.0, 8192 * 8192, &device)).into_shape([8192, 8192]); + let b = linspace((0.0, 1.0, 8192 * 8192, &device)).into_shape([8192, 8192]); + for _ in 0..10 { + let start = std::time::Instant::now(); + black_box(&a % &a.swapaxes(-1, -2)); + println!("syrk time: {:?}", start.elapsed()); + let start = std::time::Instant::now(); + black_box(&a % &b.swapaxes(-1, -2)); + println!("gemm time: {:?}", start.elapsed()); + } + } +} diff --git a/crates-device/rstsr-accelerate/src/matmul_impl.rs b/crates-device/rstsr-accelerate/src/matmul_impl.rs new file mode 100644 index 00000000..aaeae1bd --- /dev/null +++ b/crates-device/rstsr-accelerate/src/matmul_impl.rs @@ -0,0 +1,583 @@ +#![allow(non_camel_case_types)] + +use crate::prelude_dev::*; +use lapack_ffi::cblas; +use num::complex::Complex; +use num::traits::ConstZero; +use rayon::prelude::*; +use rstsr_core::prelude_dev::uninitialized_vec; +use std::ffi::c_void; + +type c32 = Complex; +type c64 = Complex; + +use cblas::CBLAS_LAYOUT::CblasColMajor as ColMajor; +use cblas::CBLAS_TRANSPOSE::CblasNoTrans as NoTrans; +use cblas::CBLAS_TRANSPOSE::CblasTrans as Trans; +use cblas::CBLAS_UPLO::CblasUpper as Upper; + +/* #region gemm */ + +#[duplicate_item( + ty fn_name cblas_wrap ; + [f32] [gemm_blas_no_conj_f32] [cblas_sgemm_wrap]; + [f64] [gemm_blas_no_conj_f64] [cblas_dgemm_wrap]; + [c32] [gemm_blas_no_conj_c32] [cblas_cgemm_wrap]; + [c64] [gemm_blas_no_conj_c64] [cblas_zgemm_wrap]; +)] +#[allow(clippy::too_many_arguments)] +pub fn fn_name( + c: &mut [ty], + lc: &Layout, + a: &[ty], + la: &Layout, + b: &[ty], + lb: &Layout, + alpha: ty, + beta: ty, + pool: Option<&ThreadPool>, +) -> Result<()> { + // nthreads is only used for `assign_cpu_rayon`. + // the threading should be handled outside this function. + + // check layout of output + if !lc.f_prefer() { + // change to f-contig anyway + // we do not handle conj, so this can be done easily + if lc.c_prefer() { + // c-prefer, transpose and run + return fn_name(c, &lc.reverse_axes(), b, &lb.reverse_axes(), a, &la.reverse_axes(), alpha, beta, pool); + } else { + // not c-prefer, allocate new buffer and copy back + let lc_new = lc.shape().new_f_contig(None); + let mut c_new = unsafe { uninitialized_vec(lc_new.size())? 
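
The layout handling in this routine follows one rule: hand BLAS a column-major problem. A c-preferred output is handled by transposing the whole product, any other non-f-preferred output is staged through a fresh f-contiguous buffer, and each operand is later mapped to `NoTrans`/`Trans` or packed into a new buffer. A simplified sketch of that per-operand decision; it uses strict contiguity where rstsr's `f_prefer`/`c_prefer` checks are more permissive, but the decision structure is the same.

// Simplified per-operand layout dispatch; strides are in elements.
#[derive(Debug, PartialEq)]
enum OperandDispatch {
    NoTrans,  // already column-major: pass as-is
    Trans,    // row-major: pass the reversed (transposed) layout to BLAS
    PackCopy, // neither: pack into a fresh f-contiguous buffer first
}

fn gemm_operand_dispatch(shape: [usize; 2], stride: [isize; 2]) -> OperandDispatch {
    let [m, n] = shape;
    let f_contig = stride[0] == 1 && stride[1] == m as isize;
    let c_contig = stride[1] == 1 && stride[0] == n as isize;
    if f_contig {
        OperandDispatch::NoTrans
    } else if c_contig {
        OperandDispatch::Trans
    } else {
        OperandDispatch::PackCopy
    }
}

fn main() {
    assert_eq!(gemm_operand_dispatch([3, 4], [1, 3]), OperandDispatch::NoTrans);
    assert_eq!(gemm_operand_dispatch([3, 4], [4, 1]), OperandDispatch::Trans);
    assert_eq!(gemm_operand_dispatch([3, 4], [8, 2]), OperandDispatch::PackCopy);
}
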
}; + if beta == ::ZERO { + fill_cpu_rayon(&mut c_new, &lc_new, ::ZERO, pool)?; + } else { + assign_cpu_rayon(&mut c_new, &lc_new, c, lc, pool)?; + } + fn_name(&mut c_new, &lc_new, a, la, b, lb, alpha, ::ZERO, pool)?; + assign_cpu_rayon(c, lc, &c_new, &lc_new, pool)?; + return Ok(()); + } + } + + // we assume that the layout is correct + let sc = lc.shape(); + let sa = la.shape(); + let sb = lb.shape(); + rstsr_assert_eq!(sc[0], sa[0], InvalidLayout)?; + rstsr_assert_eq!(sa[1], sb[0], InvalidLayout)?; + rstsr_assert_eq!(sc[1], sb[1], InvalidLayout)?; + + let m = sc[0]; + let n = sc[1]; + let k = sa[1]; + + // handle the special case that k is zero-dimensional + if k == 0 { + // if k is zero, the result is a zero matrix + return fill_cpu_rayon(c, lc, ::ZERO, pool); + } + + // handle the special case that n/m is zero-dimensional + if n == 0 || m == 0 { + // if n or m is zero, the result matrix size is zero, and nothing to do + return Ok(()); + } + + // determine trans/layout and clone data if necessary + let mut a_data: Option> = None; + let mut b_data: Option> = None; + let (a_trans, la) = if la.f_prefer() { + (NoTrans, la.clone()) + } else if la.c_prefer() { + (Trans, la.reverse_axes()) + } else { + let len = la.size(); + a_data = unsafe { Some(uninitialized_vec(len)?) }; + let la_data = la.shape().new_f_contig(None); + assign_cpu_rayon(a_data.as_mut().unwrap(), &la_data, a, la, pool)?; + (NoTrans, la_data) + }; + let (b_trans, lb) = if lb.f_prefer() { + (NoTrans, lb.clone()) + } else if lb.c_prefer() { + (Trans, lb.reverse_axes()) + } else { + let len = lb.size(); + b_data = unsafe { Some(uninitialized_vec(len)?) }; + let lb_data = lb.shape().new_f_contig(None); + assign_cpu_rayon(b_data.as_mut().unwrap(), &lb_data, b, lb, pool)?; + (NoTrans, lb_data) + }; + + // final configuration + // shape may be broadcasted for one-dimension case, so make this check + let lda = if la.shape()[1] != 1 { la.stride()[1] } else { la.shape()[0] as isize }; + let ldb = if lb.shape()[1] != 1 { lb.stride()[1] } else { lb.shape()[0] as isize }; + let ldc = if lc.shape()[1] != 1 { lc.stride()[1] } else { lc.shape()[0] as isize }; + + let ptr_c = unsafe { c.as_mut_ptr().add(lc.offset()) }; + let ptr_a = + if let Some(a_data) = a_data.as_ref() { a_data.as_ptr() } else { unsafe { a.as_ptr().add(la.offset()) } }; + let ptr_b = + if let Some(b_data) = b_data.as_ref() { b_data.as_ptr() } else { unsafe { b.as_ptr().add(lb.offset()) } }; + + // actual computation + unsafe { + cblas_wrap(ColMajor, a_trans, b_trans, m, n, k, alpha, ptr_a, lda, ptr_b, ldb, beta, ptr_c, ldc); + } + Ok(()) +} + +#[allow(clippy::too_many_arguments)] +unsafe fn cblas_sgemm_wrap( + order: cblas::CBLAS_LAYOUT, + a_trans: cblas::CBLAS_TRANSPOSE, + b_trans: cblas::CBLAS_TRANSPOSE, + m: usize, + n: usize, + k: usize, + alpha: f32, + ptr_a: *const f32, + lda: isize, + ptr_b: *const f32, + ldb: isize, + beta: f32, + ptr_c: *mut f32, + ldc: isize, +) { + unsafe { + cblas::cblas_sgemm( + order as cblas::CBLAS_LAYOUT, + a_trans as cblas::CBLAS_TRANSPOSE, + b_trans as cblas::CBLAS_TRANSPOSE, + m as cblas::blas_int, + n as cblas::blas_int, + k as cblas::blas_int, + alpha, + ptr_a, + lda as cblas::blas_int, + ptr_b, + ldb as cblas::blas_int, + beta, + ptr_c, + ldc as cblas::blas_int, + ); + } +} + +#[allow(clippy::too_many_arguments)] +unsafe fn cblas_dgemm_wrap( + order: cblas::CBLAS_LAYOUT, + a_trans: cblas::CBLAS_TRANSPOSE, + b_trans: cblas::CBLAS_TRANSPOSE, + m: usize, + n: usize, + k: usize, + alpha: f64, + ptr_a: *const f64, + lda: isize, + ptr_b: 
*const f64, + ldb: isize, + beta: f64, + ptr_c: *mut f64, + ldc: isize, +) { + unsafe { + cblas::cblas_dgemm( + order as cblas::CBLAS_LAYOUT, + a_trans as cblas::CBLAS_TRANSPOSE, + b_trans as cblas::CBLAS_TRANSPOSE, + m as cblas::blas_int, + n as cblas::blas_int, + k as cblas::blas_int, + alpha, + ptr_a, + lda as cblas::blas_int, + ptr_b, + ldb as cblas::blas_int, + beta, + ptr_c, + ldc as cblas::blas_int, + ); + } +} + +#[allow(clippy::too_many_arguments)] +unsafe fn cblas_cgemm_wrap( + order: cblas::CBLAS_LAYOUT, + a_trans: cblas::CBLAS_TRANSPOSE, + b_trans: cblas::CBLAS_TRANSPOSE, + m: usize, + n: usize, + k: usize, + alpha: c32, + ptr_a: *const c32, + lda: isize, + ptr_b: *const c32, + ldb: isize, + beta: c32, + ptr_c: *mut c32, + ldc: isize, +) { + unsafe { + cblas::cblas_cgemm( + order as cblas::CBLAS_LAYOUT, + a_trans as cblas::CBLAS_TRANSPOSE, + b_trans as cblas::CBLAS_TRANSPOSE, + m as cblas::blas_int, + n as cblas::blas_int, + k as cblas::blas_int, + &alpha as *const _ as *const c_void, + ptr_a as *const c_void, + lda as cblas::blas_int, + ptr_b as *const c_void, + ldb as cblas::blas_int, + &beta as *const _ as *const c_void, + ptr_c as *mut c_void, + ldc as cblas::blas_int, + ); + } +} + +#[allow(clippy::too_many_arguments)] +unsafe fn cblas_zgemm_wrap( + order: cblas::CBLAS_LAYOUT, + a_trans: cblas::CBLAS_TRANSPOSE, + b_trans: cblas::CBLAS_TRANSPOSE, + m: usize, + n: usize, + k: usize, + alpha: c64, + ptr_a: *const c64, + lda: isize, + ptr_b: *const c64, + ldb: isize, + beta: c64, + ptr_c: *mut c64, + ldc: isize, +) { + unsafe { + cblas::cblas_zgemm( + order as cblas::CBLAS_LAYOUT, + a_trans as cblas::CBLAS_TRANSPOSE, + b_trans as cblas::CBLAS_TRANSPOSE, + m as cblas::blas_int, + n as cblas::blas_int, + k as cblas::blas_int, + &alpha as *const _ as *const c_void, + ptr_a as *const c_void, + lda as cblas::blas_int, + ptr_b as *const c_void, + ldb as cblas::blas_int, + &beta as *const _ as *const c_void, + ptr_c as *mut c_void, + ldc as cblas::blas_int, + ); + } +} + +/* #endregion */ + +/* #region syrk */ + +#[duplicate_item( + ty fn_name cblas_wrap ; + [f32] [syrk_blas_no_conj_f32] [cblas_ssyrk_wrap]; + [f64] [syrk_blas_no_conj_f64] [cblas_dsyrk_wrap]; + [c32] [syrk_blas_no_conj_c32] [cblas_csyrk_wrap]; + [c64] [syrk_blas_no_conj_c64] [cblas_zsyrk_wrap]; +)] +pub fn fn_name( + c: &mut [ty], + lc: &Layout, + a: &[ty], + la: &Layout, + alpha: ty, + pool: Option<&ThreadPool>, +) -> Result<()> { + // beta is assumed to be zero, and not passed as argument. + + // check layout of output + if !lc.f_prefer() { + // change to f-contig anyway + // we do not handle conj, so this can be done easily + if lc.c_prefer() { + // c-prefer, transpose and run + return fn_name(c, &lc.reverse_axes(), a, la, alpha, pool); + } else { + // not c-prefer, allocate new buffer and copy back + let lc_new = lc.shape().new_f_contig(None); + let mut c_new = unsafe { uninitialized_vec(lc_new.size())? 
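
Since the `syrk` call in this routine only requests the upper triangle (`CblasUpper`), the strict lower triangle of the f-contiguous result is filled in afterwards by mirroring, as done a little further below. A minimal sketch of that write-back on a plain column-major buffer:

// Mirror the upper triangle into the strict lower triangle of a column-major
// n x n buffer with ldc = n; element (i, j) lives at c[j * n + i].
fn mirror_upper_to_lower(c: &mut [f64], n: usize) {
    for j in 0..n {
        for i in (j + 1)..n {
            c[j * n + i] = c[i * n + j];
        }
    }
}

fn main() {
    // upper triangle of a 3 x 3 symmetric matrix; lower part still zero
    let mut c = vec![
        1.0, 0.0, 0.0, // column 0
        2.0, 4.0, 0.0, // column 1
        3.0, 5.0, 6.0, // column 2
    ];
    mirror_upper_to_lower(&mut c, 3);
    assert_eq!(c, vec![1.0, 2.0, 3.0, 2.0, 4.0, 5.0, 3.0, 5.0, 6.0]);
}
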
}; + fill_cpu_rayon(&mut c_new, &lc_new, ::ZERO, pool)?; + fn_name(&mut c_new, &lc_new, a, la, alpha, pool)?; + assign_cpu_rayon(c, lc, &c_new, &lc_new, pool)?; + return Ok(()); + } + } + + // we assume that the layout is correct + let sc = lc.shape(); + let sa = la.shape(); + rstsr_assert_eq!(sc[0], sa[0], InvalidLayout)?; + rstsr_assert_eq!(sc[1], sc[0], InvalidLayout)?; + + let n = sc[0]; + let k = sa[1]; + + // handle the special case that k is zero-dimensional + if k == 0 { + // if k is zero, the result is a zero matrix + return fill_cpu_rayon(c, lc, ::ZERO, pool); + } + + // handle the special case that n is zero-dimensional + if n == 0 { + // if n is zero, the result matrix size is zero, and nothing to do + return Ok(()); + } + + // determine trans/layout and clone data if necessary + let mut a_data: Option> = None; + let (a_trans, la) = if la.f_prefer() { + (NoTrans, la.clone()) + } else if la.c_prefer() { + (Trans, la.reverse_axes()) + } else { + let len = la.size(); + a_data = unsafe { Some(uninitialized_vec(len)?) }; + let la_data = la.shape().new_f_contig(None); + assign_cpu_rayon(a_data.as_mut().unwrap(), &la_data, a, la, pool)?; + (NoTrans, la_data) + }; + + // final configuration + // shape may be broadcasted for one-dimension case, so make this check + let lda = if la.shape()[1] != 1 { la.stride()[1] } else { la.shape()[0] as isize }; + let ldc = if lc.shape()[1] != 1 { lc.stride()[1] } else { lc.shape()[0] as isize }; + + let ptr_c = unsafe { c.as_mut_ptr().add(lc.offset()) }; + let ptr_a = + if let Some(a_data) = a_data.as_ref() { a_data.as_ptr() } else { unsafe { a.as_ptr().add(la.offset()) } }; + + // actual computation + unsafe { + cblas_wrap(ColMajor, Upper, a_trans, n, k, alpha, ptr_a, lda, ::ZERO, ptr_c, ldc); + } + + // write back to lower triangle + let n = sc[0]; + let ldc = lc.stride()[1]; + let offset = lc.offset() as isize; + let task = || { + (0..(n as isize)).into_par_iter().for_each(|j| { + ((j + 1)..(n as isize)).for_each(|i| unsafe { + let idx_ij = (offset + j * ldc + i) as usize; + let idx_ji = (offset + i * ldc + j) as usize; + let c_ptr_ij = c.as_ptr().add(idx_ij) as *mut ty; + *c_ptr_ij = c[idx_ji]; + }); + }); + }; + pool.map_or_else(task, |pool| pool.install(task)); + Ok(()) +} + +#[allow(clippy::too_many_arguments)] +unsafe fn cblas_ssyrk_wrap( + order: cblas::CBLAS_LAYOUT, + uplo: cblas::CBLAS_UPLO, + a_trans: cblas::CBLAS_TRANSPOSE, + n: usize, + k: usize, + alpha: f32, + ptr_a: *const f32, + lda: isize, + beta: f32, + ptr_c: *mut f32, + ldc: isize, +) { + unsafe { + cblas::cblas_ssyrk( + order as cblas::CBLAS_LAYOUT, + uplo as cblas::CBLAS_UPLO, + a_trans as cblas::CBLAS_TRANSPOSE, + n as cblas::blas_int, + k as cblas::blas_int, + alpha, + ptr_a, + lda as cblas::blas_int, + beta, + ptr_c, + ldc as cblas::blas_int, + ); + } +} + +#[allow(clippy::too_many_arguments)] +unsafe fn cblas_dsyrk_wrap( + order: cblas::CBLAS_LAYOUT, + uplo: cblas::CBLAS_UPLO, + a_trans: cblas::CBLAS_TRANSPOSE, + n: usize, + k: usize, + alpha: f64, + ptr_a: *const f64, + lda: isize, + beta: f64, + ptr_c: *mut f64, + ldc: isize, +) { + unsafe { + cblas::cblas_dsyrk( + order as cblas::CBLAS_LAYOUT, + uplo as cblas::CBLAS_UPLO, + a_trans as cblas::CBLAS_TRANSPOSE, + n as cblas::blas_int, + k as cblas::blas_int, + alpha, + ptr_a, + lda as cblas::blas_int, + beta, + ptr_c, + ldc as cblas::blas_int, + ); + } +} + +#[allow(clippy::too_many_arguments)] +unsafe fn cblas_csyrk_wrap( + order: cblas::CBLAS_LAYOUT, + uplo: cblas::CBLAS_UPLO, + a_trans: cblas::CBLAS_TRANSPOSE, + n: 
usize, + k: usize, + alpha: c32, + ptr_a: *const c32, + lda: isize, + beta: c32, + ptr_c: *mut c32, + ldc: isize, +) { + unsafe { + cblas::cblas_csyrk( + order as cblas::CBLAS_LAYOUT, + uplo as cblas::CBLAS_UPLO, + a_trans as cblas::CBLAS_TRANSPOSE, + n as cblas::blas_int, + k as cblas::blas_int, + &alpha as *const _ as *const c_void, + ptr_a as *const c_void, + lda as cblas::blas_int, + &beta as *const _ as *const c_void, + ptr_c as *mut c_void, + ldc as cblas::blas_int, + ); + } +} + +#[allow(clippy::too_many_arguments)] +unsafe fn cblas_zsyrk_wrap( + order: cblas::CBLAS_LAYOUT, + uplo: cblas::CBLAS_UPLO, + a_trans: cblas::CBLAS_TRANSPOSE, + n: usize, + k: usize, + alpha: c64, + ptr_a: *const c64, + lda: isize, + beta: c64, + ptr_c: *mut c64, + ldc: isize, +) { + unsafe { + cblas::cblas_zsyrk( + order as cblas::CBLAS_LAYOUT, + uplo as cblas::CBLAS_UPLO, + a_trans as cblas::CBLAS_TRANSPOSE, + n as cblas::blas_int, + k as cblas::blas_int, + &alpha as *const _ as *const c_void, + ptr_a as *const c_void, + lda as cblas::blas_int, + &beta as *const _ as *const c_void, + ptr_c as *mut c_void, + ldc as cblas::blas_int, + ); + } +} + +/* #endregion */ + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_f32() { + let a = vec![1., 2., 3., 4., 5., 6.]; + let b = vec![1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12.]; + let mut c = vec![0.0; 16]; + + let la = [2, 3].c(); + let lb = [3, 4].c(); + let lc = [2, 4].c(); + let pool = rayon::ThreadPoolBuilder::new().num_threads(16).build().unwrap(); + let pool = Some(&pool); + gemm_blas_no_conj_f32(&mut c, &lc, &a, &la, &b, &lb, 1.0, 0.0, pool).unwrap(); + let c_tsr = TensorView::new(asarray(&c).into_raw_parts().0, lc); + println!("{c_tsr:}"); + println!("{:}", c_tsr.reshape([8])); + let c_ref = asarray(vec![38., 44., 50., 56., 83., 98., 113., 128.]); + assert!(allclose_f64(&c_tsr.map(|v| *v as f64), &c_ref)); + + let la = [2, 3].c(); + let lb = [3, 4].c(); + let lc = [2, 4].f(); + gemm_blas_no_conj_f32(&mut c, &lc, &a, &la, &b, &lb, 1.0, 0.0, pool).unwrap(); + let c_tsr = TensorView::new(asarray(&c).into_raw_parts().0, lc); + println!("{c_tsr:}"); + println!("{:}", c_tsr.reshape([8])); + let c_ref = asarray(vec![38., 44., 50., 56., 83., 98., 113., 128.]); + assert!(allclose_f64(&c_tsr.map(|v| *v as f64), &c_ref)); + + let la = [2, 3].f(); + let lb = [3, 4].c(); + let lc = [2, 4].c(); + gemm_blas_no_conj_f32(&mut c, &lc, &a, &la, &b, &lb, 1.0, 0.0, pool).unwrap(); + let c_tsr = TensorView::new(asarray(&c).into_raw_parts().0, lc); + println!("{c_tsr:}"); + println!("{:}", c_tsr.reshape([8])); + let c_ref = asarray(vec![61., 70., 79., 88., 76., 88., 100., 112.]); + assert!(allclose_f64(&c_tsr.map(|v| *v as f64), &c_ref)); + + let la = [2, 3].f(); + let lb = [3, 4].c(); + let lc = [2, 4].f(); + gemm_blas_no_conj_f32(&mut c, &lc, &a, &la, &b, &lb, 2.0, 0.0, pool).unwrap(); + let c_tsr = TensorView::new(asarray(&c).into_raw_parts().0, lc); + println!("{c_tsr:}"); + println!("{:}", c_tsr.reshape([8])); + let c_ref = 2 * asarray(vec![61., 70., 79., 88., 76., 88., 100., 112.]); + assert!(allclose_f64(&c_tsr.map(|v| *v as f64), &c_ref)); + } + + #[test] + fn test_c32() { + let a = linspace((c32::new(1., 1.), c32::new(6., 6.), 6)).into_vec(); + let b = linspace((c32::new(1., 1.), c32::new(12., 12.), 12)).into_vec(); + let mut c = vec![c32::ZERO; 16]; + + let la = [2, 3].c(); + let lb = [3, 4].c(); + let lc = [2, 4].c(); + let pool = rayon::ThreadPoolBuilder::new().num_threads(16).build().unwrap(); + let pool = Some(&pool); + 
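
These drivers take the thread pool as `Option<&rayon::ThreadPool>`: work runs via `install` when a pool is supplied, and inline on the caller's thread otherwise. A minimal sketch of that convention, assuming only the `rayon` crate:

use rayon::prelude::*;
use rayon::ThreadPool;

// Run the task on the supplied pool if any, otherwise inline.
fn parallel_sum(data: &[u64], pool: Option<&ThreadPool>) -> u64 {
    let task = || data.par_iter().sum::<u64>();
    match pool {
        Some(pool) => pool.install(task),
        None => task(),
    }
}

fn main() {
    let data: Vec<u64> = (0..1000).collect();
    let pool = rayon::ThreadPoolBuilder::new().num_threads(2).build().unwrap();
    assert_eq!(parallel_sum(&data, Some(&pool)), 499_500);
    assert_eq!(parallel_sum(&data, None), 499_500);
}
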
gemm_blas_no_conj_c32(&mut c, &lc, &a, &la, &b, &lb, c32::ONE, c32::ZERO, pool).unwrap(); + let c_tsr = TensorView::new(asarray(&c).into_raw_parts().0, lc); + println!("{c_tsr:}"); + println!("{:}", c_tsr.reshape([8])); + } +} diff --git a/crates-device/rstsr-accelerate/src/prelude_dev.rs b/crates-device/rstsr-accelerate/src/prelude_dev.rs new file mode 100644 index 00000000..8d23e755 --- /dev/null +++ b/crates-device/rstsr-accelerate/src/prelude_dev.rs @@ -0,0 +1,4 @@ +pub(crate) use crate::DeviceBLAS; +pub(crate) use crate::DeviceRayonAutoImpl; +pub use rstsr_core::prelude_dev::*; +pub(crate) use rstsr_lapack_ffi as lapack_ffi; diff --git a/crates-device/rstsr-accelerate/src/rayon_auto_impl/adv_indexing.rs b/crates-device/rstsr-accelerate/src/rayon_auto_impl/adv_indexing.rs new file mode 100644 index 00000000..e30b6e1b --- /dev/null +++ b/crates-device/rstsr-accelerate/src/rayon_auto_impl/adv_indexing.rs @@ -0,0 +1,21 @@ +use crate::prelude_dev::*; + +impl DeviceIndexSelectAPI for DeviceRayonAutoImpl +where + T: Clone + Send + Sync, + D: DimAPI + DimSmallerOneAPI, + D::SmallerOne: DimAPI, +{ + fn index_select( + &self, + c: &mut Vec>, + lc: &Layout, + a: &Vec, + la: &Layout, + axis: usize, + indices: &[usize], + ) -> Result<()> { + let pool = self.get_current_pool(); + index_select_cpu_rayon(c, lc, a, la, axis, indices, pool) + } +} diff --git a/crates-device/rstsr-accelerate/src/rayon_auto_impl/assignment.rs b/crates-device/rstsr-accelerate/src/rayon_auto_impl/assignment.rs new file mode 100644 index 00000000..c8303cfc --- /dev/null +++ b/crates-device/rstsr-accelerate/src/rayon_auto_impl/assignment.rs @@ -0,0 +1,49 @@ +use crate::prelude_dev::*; + +impl OpAssignArbitaryAPI for DeviceRayonAutoImpl +where + TC: Clone + Send + Sync, + TA: Clone + Send + Sync + DTypeCastAPI, + DC: DimAPI, + DA: DimAPI, +{ + fn assign_arbitary(&self, c: &mut Vec, lc: &Layout, a: &Vec, la: &Layout) -> Result<()> { + let pool = self.get_current_pool(); + let default_order = self.default_order(); + assign_arbitary_promote_cpu_rayon(c, lc, a, la, default_order, pool) + } + + fn assign_arbitary_uninit( + &self, + c: &mut Vec>, + lc: &Layout, + a: &Vec, + la: &Layout, + ) -> Result<()> { + let pool = self.get_current_pool(); + let default_order = self.default_order(); + return assign_arbitary_uninit_promote_cpu_rayon(c, lc, a, la, default_order, pool); + } +} + +impl OpAssignAPI for DeviceRayonAutoImpl +where + TC: Clone + Send + Sync, + TA: Clone + Send + Sync + DTypeCastAPI, + D: DimAPI, +{ + fn assign(&self, c: &mut Vec, lc: &Layout, a: &Vec, la: &Layout) -> Result<()> { + let pool = self.get_current_pool(); + assign_promote_cpu_rayon(c, lc, a, la, pool) + } + + fn assign_uninit(&self, c: &mut Vec>, lc: &Layout, a: &Vec, la: &Layout) -> Result<()> { + let pool = self.get_current_pool(); + return assign_uninit_promote_cpu_rayon(c, lc, a, la, pool); + } + + fn fill(&self, c: &mut Vec, lc: &Layout, fill: TA) -> Result<()> { + let pool = self.get_current_pool(); + fill_promote_cpu_rayon(c, lc, fill, pool) + } +} diff --git a/crates-device/rstsr-accelerate/src/rayon_auto_impl/mod.rs b/crates-device/rstsr-accelerate/src/rayon_auto_impl/mod.rs new file mode 100644 index 00000000..6435e5d8 --- /dev/null +++ b/crates-device/rstsr-accelerate/src/rayon_auto_impl/mod.rs @@ -0,0 +1,9 @@ +pub mod adv_indexing; +pub mod assignment; +pub mod op_binary_arithmetic; +pub mod op_binary_common; +pub mod op_ternary_arithmetic; +pub mod op_ternary_common; +pub mod op_tri; +pub mod op_with_func; +pub mod reduction; diff --git 
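
The `op_*` modules that follow stamp out one impl per operator with `#[duplicate_item]` from the `duplicate` crate: each bracketed row is substituted into the annotated item. A minimal sketch of the mechanism with plain functions instead of device impls, assuming `duplicate` as a dependency:

use duplicate::duplicate_item;

// Each bracketed row is substituted into the item below, yielding three functions.
#[duplicate_item(
    fn_name   body;
    [add_one] [x + 1];
    [sub_one] [x - 1];
    [mul_two] [x * 2];
)]
fn fn_name(x: i64) -> i64 {
    body
}

fn main() {
    assert_eq!(add_one(5), 6);
    assert_eq!(sub_one(5), 4);
    assert_eq!(mul_two(5), 10);
}
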
a/crates-device/rstsr-accelerate/src/rayon_auto_impl/op_binary_arithmetic.rs b/crates-device/rstsr-accelerate/src/rayon_auto_impl/op_binary_arithmetic.rs new file mode 100644 index 00000000..b4bed011 --- /dev/null +++ b/crates-device/rstsr-accelerate/src/rayon_auto_impl/op_binary_arithmetic.rs @@ -0,0 +1,119 @@ +use crate::prelude_dev::*; +use core::mem::transmute; + +#[duplicate_item( + DeviceOpAPI Op func ; + [DeviceAddAssignAPI ] [AddAssign ] [|a, b| unsafe { *a.assume_init_mut() += b.clone() }]; + [DeviceSubAssignAPI ] [SubAssign ] [|a, b| unsafe { *a.assume_init_mut() -= b.clone() }]; + [DeviceMulAssignAPI ] [MulAssign ] [|a, b| unsafe { *a.assume_init_mut() *= b.clone() }]; + [DeviceDivAssignAPI ] [DivAssign ] [|a, b| unsafe { *a.assume_init_mut() /= b.clone() }]; + [DeviceRemAssignAPI ] [RemAssign ] [|a, b| unsafe { *a.assume_init_mut() %= b.clone() }]; + [DeviceBitOrAssignAPI ] [BitOrAssign ] [|a, b| unsafe { *a.assume_init_mut() |= b.clone() }]; + [DeviceBitAndAssignAPI] [BitAndAssign] [|a, b| unsafe { *a.assume_init_mut() &= b.clone() }]; + [DeviceBitXorAssignAPI] [BitXorAssign] [|a, b| unsafe { *a.assume_init_mut() ^= b.clone() }]; + [DeviceShlAssignAPI ] [ShlAssign ] [|a, b| unsafe { *a.assume_init_mut() <<= b.clone() }]; + [DeviceShrAssignAPI ] [ShrAssign ] [|a, b| unsafe { *a.assume_init_mut() >>= b.clone() }]; +)] +impl DeviceOpAPI for DeviceRayonAutoImpl +where + TA: Clone + Send + Sync + Op, + TB: Clone + Send + Sync, + D: DimAPI, +{ + fn op_muta_refb(&self, a: &mut Vec, la: &Layout, b: &Vec, lb: &Layout) -> Result<()> { + let a = unsafe { transmute::<&mut Vec, &mut Vec>>(a) }; + self.op_muta_refb_func(a, la, b, lb, &mut func) + } + + fn op_muta_numb(&self, a: &mut Vec, la: &Layout, b: TB) -> Result<()> { + let a = unsafe { transmute::<&mut Vec, &mut Vec>>(a) }; + self.op_muta_numb_func(a, la, b, &mut func) + } +} + +#[duplicate_item( + DeviceOpAPI Op func ; + [DeviceLConsumeAddAPI ] [Add ] [|a, b| unsafe { a.write(a.assume_init_read() + b.clone()); }]; + [DeviceLConsumeSubAPI ] [Sub ] [|a, b| unsafe { a.write(a.assume_init_read() - b.clone()); }]; + [DeviceLConsumeMulAPI ] [Mul ] [|a, b| unsafe { a.write(a.assume_init_read() * b.clone()); }]; + [DeviceLConsumeDivAPI ] [Div ] [|a, b| unsafe { a.write(a.assume_init_read() / b.clone()); }]; + [DeviceLConsumeRemAPI ] [Rem ] [|a, b| unsafe { a.write(a.assume_init_read() % b.clone()); }]; + [DeviceLConsumeBitOrAPI ] [BitOr ] [|a, b| unsafe { a.write(a.assume_init_read() | b.clone()); }]; + [DeviceLConsumeBitAndAPI] [BitAnd] [|a, b| unsafe { a.write(a.assume_init_read() & b.clone()); }]; + [DeviceLConsumeBitXorAPI] [BitXor] [|a, b| unsafe { a.write(a.assume_init_read() ^ b.clone()); }]; + [DeviceLConsumeShlAPI ] [Shl ] [|a, b| unsafe { a.write(a.assume_init_read() << b.clone()); }]; + [DeviceLConsumeShrAPI ] [Shr ] [|a, b| unsafe { a.write(a.assume_init_read() >> b.clone()); }]; +)] +impl DeviceOpAPI for DeviceRayonAutoImpl +where + TA: Clone + Send + Sync + Op, + TB: Clone + Send + Sync, + D: DimAPI, +{ + fn op_muta_refb(&self, a: &mut Vec, la: &Layout, b: &Vec, lb: &Layout) -> Result<()> { + let a = unsafe { transmute::<&mut Vec, &mut Vec>>(a) }; + self.op_muta_refb_func(a, la, b, lb, &mut func) + } + + fn op_muta_numb(&self, a: &mut Vec, la: &Layout, b: TB) -> Result<()> { + let a = unsafe { transmute::<&mut Vec, &mut Vec>>(a) }; + self.op_muta_numb_func(a, la, b, &mut func) + } +} + +#[duplicate_item( + DeviceOpAPI Op func ; + [DeviceRConsumeAddAPI ] [Add ] [|a, b| unsafe { a.write(b.clone() + a.assume_init_read()); }]; + 
[DeviceRConsumeSubAPI ] [Sub ] [|a, b| unsafe { a.write(b.clone() - a.assume_init_read()); }]; + [DeviceRConsumeMulAPI ] [Mul ] [|a, b| unsafe { a.write(b.clone() * a.assume_init_read()); }]; + [DeviceRConsumeDivAPI ] [Div ] [|a, b| unsafe { a.write(b.clone() / a.assume_init_read()); }]; + [DeviceRConsumeRemAPI ] [Rem ] [|a, b| unsafe { a.write(b.clone() % a.assume_init_read()); }]; + [DeviceRConsumeBitOrAPI ] [BitOr ] [|a, b| unsafe { a.write(b.clone() | a.assume_init_read()); }]; + [DeviceRConsumeBitAndAPI] [BitAnd] [|a, b| unsafe { a.write(b.clone() & a.assume_init_read()); }]; + [DeviceRConsumeBitXorAPI] [BitXor] [|a, b| unsafe { a.write(b.clone() ^ a.assume_init_read()); }]; + [DeviceRConsumeShlAPI ] [Shl ] [|a, b| unsafe { a.write(b.clone() << a.assume_init_read()); }]; + [DeviceRConsumeShrAPI ] [Shr ] [|a, b| unsafe { a.write(b.clone() >> a.assume_init_read()); }]; +)] +impl DeviceOpAPI for DeviceRayonAutoImpl +where + TA: Clone + Send + Sync + Op, + TB: Clone + Send + Sync, + D: DimAPI, +{ + fn op_muta_refb(&self, b: &mut Vec, lb: &Layout, a: &Vec, la: &Layout) -> Result<()> { + let b = unsafe { transmute::<&mut Vec, &mut Vec>>(b) }; + self.op_muta_refb_func(b, lb, a, la, &mut func) + } + + fn op_muta_numb(&self, b: &mut Vec, lb: &Layout, a: TA) -> Result<()> { + let b = unsafe { transmute::<&mut Vec, &mut Vec>>(b) }; + self.op_muta_numb_func(b, lb, a, &mut func) + } +} + +#[duplicate_item( + DeviceOpAPI Op func func_inplace ; + [DeviceNegAPI] [Neg] [|a, b| { a.write(-b.clone()); }] [|a| unsafe { a.write(-a.assume_init_read()); }]; + [DeviceNotAPI] [Not] [|a, b| { a.write(!b.clone()); }] [|a| unsafe { a.write(!a.assume_init_read()); }]; +)] +impl DeviceOpAPI for DeviceRayonAutoImpl +where + TA: Clone + Send + Sync, + TB: Clone + Send + Sync, + D: DimAPI, +{ + fn op_muta_refb(&self, a: &mut Vec>, la: &Layout, b: &Vec, lb: &Layout) -> Result<()> + where + TB: Op, + { + self.op_muta_refb_func(a, la, b, lb, &mut func) + } + + fn op_muta(&self, a: &mut Vec, la: &Layout) -> Result<()> + where + TA: Op, + { + let a = unsafe { transmute::<&mut Vec, &mut Vec>>(a) }; + self.op_muta_func(a, la, &mut func_inplace) + } +} diff --git a/crates-device/rstsr-accelerate/src/rayon_auto_impl/op_binary_common.rs b/crates-device/rstsr-accelerate/src/rayon_auto_impl/op_binary_common.rs new file mode 100644 index 00000000..d293734f --- /dev/null +++ b/crates-device/rstsr-accelerate/src/rayon_auto_impl/op_binary_common.rs @@ -0,0 +1,240 @@ +use crate::prelude_dev::*; +use core::ops::Div; +use num::complex::ComplexFloat; +use num::{Float, Signed}; +use rstsr_dtype_traits::{DTypeIntoFloatAPI, ExtNum}; + +// TODO: log1p + +/* #region same type */ + +#[duplicate_item( + DeviceOpAPI NumTrait func_inner; + [DeviceAcosAPI ] [ComplexFloat] [b.acos() ]; + [DeviceAcoshAPI ] [ComplexFloat] [b.acosh() ]; + [DeviceAsinAPI ] [ComplexFloat] [b.asin() ]; + [DeviceAsinhAPI ] [ComplexFloat] [b.asinh() ]; + [DeviceAtanAPI ] [ComplexFloat] [b.atan() ]; + [DeviceAtanhAPI ] [ComplexFloat] [b.atanh() ]; + [DeviceCeilAPI ] [Float ] [b.ceil() ]; + [DeviceConjAPI ] [ComplexFloat] [b.conj() ]; + [DeviceCosAPI ] [ComplexFloat] [b.cos() ]; + [DeviceCoshAPI ] [ComplexFloat] [b.cosh() ]; + [DeviceExpAPI ] [ComplexFloat] [b.exp() ]; + [DeviceExpm1API ] [Float ] [b.exp_m1()]; + [DeviceFloorAPI ] [Float ] [b.floor() ]; + [DeviceInvAPI ] [ComplexFloat] [b.recip() ]; + [DeviceLogAPI ] [ComplexFloat] [b.ln() ]; + [DeviceLog2API ] [ComplexFloat] [b.log2() ]; + [DeviceLog10API ] [ComplexFloat] [b.log10() ]; + [DeviceReciprocalAPI] 
[ComplexFloat] [b.recip() ]; + [DeviceRoundAPI ] [Float ] [b.round() ]; + [DeviceSinAPI ] [ComplexFloat] [b.sin() ]; + [DeviceSinhAPI ] [ComplexFloat] [b.sinh() ]; + [DeviceSqrtAPI ] [ComplexFloat] [b.sqrt() ]; + [DeviceTanAPI ] [ComplexFloat] [b.tan() ]; + [DeviceTanhAPI ] [ComplexFloat] [b.tanh() ]; + [DeviceTruncAPI ] [Float ] [b.trunc() ]; +)] +impl DeviceOpAPI for DeviceRayonAutoImpl +where + T: Clone + Send + Sync + DTypeIntoFloatAPI, + D: DimAPI, +{ + type TOut = T::FloatType; + + fn op_muta_refb( + &self, + a: &mut Vec>, + la: &Layout, + b: &Vec, + lb: &Layout, + ) -> Result<()> { + let mut func = |a: &mut MaybeUninit, b: &T| { + let b = b.clone().into_float(); + a.write(func_inner); + }; + self.op_muta_refb_func(a, la, b, lb, &mut func) + } + + fn op_muta(&self, a: &mut Vec>, la: &Layout) -> Result<()> { + let mut func = |a: &mut MaybeUninit| { + let b = unsafe { a.assume_init_read() }; + a.write(func_inner); + }; + self.op_muta_func(a, la, &mut func) + } +} + +impl DeviceSquareAPI for DeviceRayonAutoImpl +where + T: Clone + Send + Sync + Mul, + D: DimAPI, +{ + type TOut = T; + + fn op_muta_refb(&self, a: &mut Vec>, la: &Layout, b: &Vec, lb: &Layout) -> Result<()> { + let mut func = |a: &mut MaybeUninit, b: &T| { + a.write(b.clone() * b.clone()); + }; + self.op_muta_refb_func(a, la, b, lb, &mut func) + } + + fn op_muta(&self, a: &mut Vec>, la: &Layout) -> Result<()> { + let mut func = |a: &mut MaybeUninit| { + let b = unsafe { a.assume_init_read() }; + a.write(b.clone() * b); + }; + self.op_muta_func(a, la, &mut func) + } +} + +/* #endregion */ + +/* #region boolean output */ + +#[duplicate_item( + DeviceOpAPI NumTrait func ; + [DeviceSignBitAPI ] [Signed ] [|a, b| { a.write(b.is_positive()); } ]; + [DeviceIsFiniteAPI] [ComplexFloat] [|a, b| { a.write(b.is_finite() ); } ]; + [DeviceIsInfAPI ] [ComplexFloat] [|a, b| { a.write(b.is_infinite()); } ]; + [DeviceIsNanAPI ] [ComplexFloat] [|a, b| { a.write(b.is_nan() ); } ]; +)] +impl DeviceOpAPI for DeviceRayonAutoImpl +where + T: Clone + NumTrait + Send + Sync, + D: DimAPI, +{ + type TOut = bool; + + fn op_muta_refb(&self, a: &mut Vec>, la: &Layout, b: &Vec, lb: &Layout) -> Result<()> { + self.op_muta_refb_func(a, la, b, lb, &mut func) + } + + fn op_muta(&self, _a: &mut Vec>, _la: &Layout) -> Result<()> { + let type_b = core::any::type_name::(); + unreachable!("{:?} is not supported in this function.", type_b); + } +} + +/* #endregion */ + +/* #region complex specific implementation */ + +impl DeviceAbsAPI for DeviceRayonAutoImpl +where + T: ExtNum + Send + Sync, + T::AbsOut: Send + Sync, + D: DimAPI, +{ + type TOut = T::AbsOut; + + fn op_muta_refb( + &self, + a: &mut Vec>, + la: &Layout, + b: &Vec, + lb: &Layout, + ) -> Result<()> { + self.op_muta_refb_func(a, la, b, lb, &mut |a, b| { + a.write(b.clone().ext_abs()); + }) + } + + fn op_muta(&self, a: &mut Vec>, la: &Layout) -> Result<()> { + if T::ABS_UNCHANGED { + return Ok(()); + } else if T::ABS_SAME_TYPE { + return self.op_muta_func(a, la, &mut |a| unsafe { + a.write(a.assume_init_read().ext_abs()); + }); + } else { + let type_b = core::any::type_name::(); + unreachable!("{:?} is not supported in this function.", type_b); + } + } +} + +impl DeviceImagAPI for DeviceRayonAutoImpl +where + T: ExtNum + Send + Sync, + T::AbsOut: Send + Sync, + D: DimAPI, +{ + type TOut = T::AbsOut; + + fn op_muta_refb( + &self, + a: &mut Vec>, + la: &Layout, + b: &Vec, + lb: &Layout, + ) -> Result<()> { + self.op_muta_refb_func(a, la, b, lb, &mut |a, b| { + a.write(b.clone().ext_imag()); + }) + } + + fn 
op_muta(&self, a: &mut Vec>, la: &Layout) -> Result<()> { + if T::ABS_SAME_TYPE { + return self.op_muta_func(a, la, &mut |a| unsafe { + a.write(a.assume_init_read().ext_imag()); + }); + } else { + let type_b = core::any::type_name::(); + unreachable!("{:?} is not supported in this function.", type_b); + } + } +} + +impl DeviceRealAPI for DeviceRayonAutoImpl +where + T: ExtNum + Send + Sync, + T::AbsOut: Send + Sync, + D: DimAPI, +{ + type TOut = T::AbsOut; + + fn op_muta_refb( + &self, + a: &mut Vec>, + la: &Layout, + b: &Vec, + lb: &Layout, + ) -> Result<()> { + self.op_muta_refb_func(a, la, b, lb, &mut |a, b| { + a.write(b.clone().ext_real()); + }) + } + + fn op_muta(&self, _a: &mut Vec>, _la: &Layout) -> Result<()> { + if T::ABS_SAME_TYPE { + return Ok(()); + } else { + let type_b = core::any::type_name::(); + unreachable!("{:?} is not supported in this function.", type_b); + } + } +} + +impl DeviceSignAPI for DeviceRayonAutoImpl +where + T: Clone + Send + Sync + ComplexFloat + Div, + D: DimAPI, +{ + type TOut = T; + + fn op_muta_refb(&self, a: &mut Vec>, la: &Layout, b: &Vec, lb: &Layout) -> Result<()> { + self.op_muta_refb_func(a, la, b, lb, &mut |a, b| { + a.write(*b / b.abs()); + }) + } + + fn op_muta(&self, a: &mut Vec>, la: &Layout) -> Result<()> { + self.op_muta_func(a, la, &mut |a| unsafe { + a.write(a.assume_init_read() / a.assume_init_read().abs()); + }) + } +} + +/* #endregion */ diff --git a/crates-device/rstsr-accelerate/src/rayon_auto_impl/op_ternary_arithmetic.rs b/crates-device/rstsr-accelerate/src/rayon_auto_impl/op_ternary_arithmetic.rs new file mode 100644 index 00000000..9494cca7 --- /dev/null +++ b/crates-device/rstsr-accelerate/src/rayon_auto_impl/op_ternary_arithmetic.rs @@ -0,0 +1,56 @@ +use crate::prelude_dev::*; + +#[duplicate_item( + DeviceOpAPI Op func ; + [DeviceAddAPI ] [Add ] [|c, a, b| { c.write(a.clone() + b.clone()); }]; + [DeviceSubAPI ] [Sub ] [|c, a, b| { c.write(a.clone() - b.clone()); }]; + [DeviceMulAPI ] [Mul ] [|c, a, b| { c.write(a.clone() * b.clone()); }]; + [DeviceDivAPI ] [Div ] [|c, a, b| { c.write(a.clone() / b.clone()); }]; + [DeviceRemAPI ] [Rem ] [|c, a, b| { c.write(a.clone() % b.clone()); }]; + [DeviceBitOrAPI ] [BitOr ] [|c, a, b| { c.write(a.clone() | b.clone()); }]; + [DeviceBitAndAPI] [BitAnd] [|c, a, b| { c.write(a.clone() & b.clone()); }]; + [DeviceBitXorAPI] [BitXor] [|c, a, b| { c.write(a.clone() ^ b.clone()); }]; + [DeviceShlAPI ] [Shl ] [|c, a, b| { c.write(a.clone() << b.clone()); }]; + [DeviceShrAPI ] [Shr ] [|c, a, b| { c.write(a.clone() >> b.clone()); }]; +)] +impl DeviceOpAPI for DeviceRayonAutoImpl +where + TA: Clone + Send + Sync + Op, + TB: Clone + Send + Sync, + TC: Clone + Send + Sync, + D: DimAPI, +{ + fn op_mutc_refa_refb( + &self, + c: &mut Vec>, + lc: &Layout, + a: &Vec, + la: &Layout, + b: &Vec, + lb: &Layout, + ) -> Result<()> { + self.op_mutc_refa_refb_func(c, lc, a, la, b, lb, &mut func) + } + + fn op_mutc_refa_numb( + &self, + c: &mut Vec>, + lc: &Layout, + a: &Vec, + la: &Layout, + b: TB, + ) -> Result<()> { + self.op_mutc_refa_numb_func(c, lc, a, la, b, &mut func) + } + + fn op_mutc_numa_refb( + &self, + c: &mut Vec>, + lc: &Layout, + a: TA, + b: &Vec, + lb: &Layout, + ) -> Result<()> { + self.op_mutc_numa_refb_func(c, lc, a, b, lb, &mut func) + } +} diff --git a/crates-device/rstsr-accelerate/src/rayon_auto_impl/op_ternary_common.rs b/crates-device/rstsr-accelerate/src/rayon_auto_impl/op_ternary_common.rs new file mode 100644 index 00000000..204efca7 --- /dev/null +++ 
b/crates-device/rstsr-accelerate/src/rayon_auto_impl/op_ternary_common.rs @@ -0,0 +1,190 @@ +use crate::prelude_dev::*; +use num::complex::ComplexFloat; +use num::{pow::Pow, Float}; +use rstsr_dtype_traits::{DTypeIntoFloatAPI, DTypePromoteAPI, ExtFloat, ExtReal}; + +// output with special promotion +#[duplicate_item( + DeviceOpAPI TraitT func_inner; + [DeviceATan2API ] [Float ] [Float::atan2(a, b) ]; + [DeviceCopySignAPI ] [Float ] [Float::copysign(a, b) ]; + [DeviceHypotAPI ] [Float ] [Float::hypot(a, b) ]; + [DeviceNextAfterAPI ] [ExtFloat ] [ExtFloat::ext_nextafter(a, b) ]; + [DeviceLogAddExpAPI ] [ComplexFloat ] [(a.exp() + b.exp()).ln() ]; +)] +impl DeviceOpAPI for DeviceRayonAutoImpl +where + TA: Clone + Send + Sync + DTypePromoteAPI>, + TB: Clone + Send + Sync, + D: DimAPI, +{ + type TOut = ::FloatType; + + fn op_mutc_refa_refb( + &self, + c: &mut Vec>, + lc: &Layout, + a: &Vec, + la: &Layout, + b: &Vec, + lb: &Layout, + ) -> Result<()> { + let mut func = |c: &mut MaybeUninit, a: &TA, b: &TB| { + let (a, b) = TA::promote_pair(a.clone(), b.clone()); + let (a, b) = (a.into_float(), b.into_float()); + c.write(func_inner); + }; + self.op_mutc_refa_refb_func(c, lc, a, la, b, lb, &mut func) + } + + fn op_mutc_refa_numb( + &self, + c: &mut Vec>, + lc: &Layout, + a: &Vec, + la: &Layout, + b: TB, + ) -> Result<()> { + let mut func = |c: &mut MaybeUninit, a: &TA, b: &TB| { + let (a, b) = TA::promote_pair(a.clone(), b.clone()); + let (a, b) = (a.into_float(), b.into_float()); + c.write(func_inner); + }; + self.op_mutc_refa_numb_func(c, lc, a, la, b, &mut func) + } + + fn op_mutc_numa_refb( + &self, + c: &mut Vec>, + lc: &Layout, + a: TA, + b: &Vec, + lb: &Layout, + ) -> Result<()> { + let mut func = |c: &mut MaybeUninit, a: &TA, b: &TB| { + let (a, b) = TA::promote_pair(a.clone(), b.clone()); + let (a, b) = (a.into_float(), b.into_float()); + c.write(func_inner); + }; + self.op_mutc_numa_refb_func(c, lc, a, b, lb, &mut func) + } +} + +// general promotion +#[duplicate_item( + DeviceOpAPI TO TraitT func_inner; + [DeviceMaximumAPI ] [TA::Res] [ExtReal ] [ExtReal::ext_max(a, b) ]; + [DeviceMinimumAPI ] [TA::Res] [ExtReal ] [ExtReal::ext_min(a, b) ]; + [DeviceFloorDivideAPI ] [TA::Res] [ExtReal ] [ExtReal::ext_floor_divide(a, b)]; + [DeviceEqualAPI ] [bool ] [PartialEq ] [a == b ]; + [DeviceGreaterAPI ] [bool ] [PartialOrd ] [a > b ]; + [DeviceGreaterEqualAPI] [bool ] [PartialOrd ] [a >= b ]; + [DeviceLessAPI ] [bool ] [PartialOrd ] [a < b ]; + [DeviceLessEqualAPI ] [bool ] [PartialOrd ] [a <= b ]; +)] +impl DeviceOpAPI for DeviceRayonAutoImpl +where + TA: Clone + Send + Sync + DTypePromoteAPI, + TB: Clone + Send + Sync, + D: DimAPI, +{ + type TOut = TO; + + fn op_mutc_refa_refb( + &self, + c: &mut Vec>, + lc: &Layout, + a: &Vec, + la: &Layout, + b: &Vec, + lb: &Layout, + ) -> Result<()> { + let mut func = |c: &mut MaybeUninit, a: &TA, b: &TB| { + let (a, b) = TA::promote_pair(a.clone(), b.clone()); + c.write(func_inner); + }; + self.op_mutc_refa_refb_func(c, lc, a, la, b, lb, &mut func) + } + + fn op_mutc_refa_numb( + &self, + c: &mut Vec>, + lc: &Layout, + a: &Vec, + la: &Layout, + b: TB, + ) -> Result<()> { + let mut func = |c: &mut MaybeUninit, a: &TA, b: &TB| { + let (a, b) = TA::promote_pair(a.clone(), b.clone()); + c.write(func_inner); + }; + self.op_mutc_refa_numb_func(c, lc, a, la, b, &mut func) + } + + fn op_mutc_numa_refb( + &self, + c: &mut Vec>, + lc: &Layout, + a: TA, + b: &Vec, + lb: &Layout, + ) -> Result<()> { + let mut func = |c: &mut MaybeUninit, a: &TA, b: &TB| { + let (a, b) = 
TA::promote_pair(a.clone(), b.clone()); + c.write(func_inner); + }; + self.op_mutc_numa_refb_func(c, lc, a, b, lb, &mut func) + } +} + +// Special case for pow +impl DevicePowAPI for DeviceRayonAutoImpl +where + TA: Clone + Send + Sync, + TB: Clone + Send + Sync, + TA: Pow, + TA::Output: Clone + Send + Sync, + D: DimAPI, +{ + type TOut = TA::Output; + + fn op_mutc_refa_refb( + &self, + c: &mut Vec>, + lc: &Layout, + a: &Vec, + la: &Layout, + b: &Vec, + lb: &Layout, + ) -> Result<()> { + self.op_mutc_refa_refb_func(c, lc, a, la, b, lb, &mut |c, a, b| { + c.write(a.clone().pow(b.clone())); + }) + } + + fn op_mutc_refa_numb( + &self, + c: &mut >>::Raw, + lc: &Layout, + a: &>::Raw, + la: &Layout, + b: TB, + ) -> Result<()> { + self.op_mutc_refa_numb_func(c, lc, a, la, b, &mut |c, a, b| { + c.write(a.clone().pow(b.clone())); + }) + } + + fn op_mutc_numa_refb( + &self, + c: &mut >>::Raw, + lc: &Layout, + a: TA, + b: &>::Raw, + lb: &Layout, + ) -> Result<()> { + self.op_mutc_numa_refb_func(c, lc, a, b, lb, &mut |c, a, b| { + c.write(a.clone().pow(b.clone())); + }) + } +} diff --git a/crates-device/rstsr-accelerate/src/rayon_auto_impl/op_tri.rs b/crates-device/rstsr-accelerate/src/rayon_auto_impl/op_tri.rs new file mode 100644 index 00000000..71b99a23 --- /dev/null +++ b/crates-device/rstsr-accelerate/src/rayon_auto_impl/op_tri.rs @@ -0,0 +1,55 @@ +use crate::prelude_dev::*; +use num::complex::ComplexFloat; + +impl DeviceOpPackTriAPI for DeviceRayonAutoImpl +where + T: Clone + Send + Sync, +{ + fn pack_tri( + &self, + a: &mut Vec>, + la: &Layout, + b: &Vec, + lb: &Layout, + uplo: FlagUpLo, + ) -> Result<()> { + let pool = self.get_current_pool(); + let default_order = self.default_order(); + match default_order { + RowMajor => pack_tri_cpu_rayon(a, la, b, lb, uplo, pool), + ColMajor => { + let la = la.reverse_axes(); + let lb = lb.reverse_axes(); + let uplo = uplo.flip()?; + pack_tri_cpu_rayon(a, &la, b, &lb, uplo, pool) + }, + } + } +} + +impl DeviceOpUnpackTriAPI for DeviceRayonAutoImpl +where + T: ComplexFloat + Send + Sync, +{ + fn unpack_tri( + &self, + a: &mut Vec>, + la: &Layout, + b: &Vec, + lb: &Layout, + uplo: FlagUpLo, + symm: FlagSymm, + ) -> Result<()> { + let pool = self.get_current_pool(); + let default_order = self.default_order(); + match default_order { + RowMajor => unpack_tri_cpu_rayon(a, la, b, lb, uplo, symm, pool), + ColMajor => { + let la = la.reverse_axes(); + let lb = lb.reverse_axes(); + let uplo = uplo.flip()?; + unpack_tri_cpu_rayon(a, &la, b, &lb, uplo, symm, pool) + }, + } + } +} diff --git a/crates-device/rstsr-accelerate/src/rayon_auto_impl/op_with_func.rs b/crates-device/rstsr-accelerate/src/rayon_auto_impl/op_with_func.rs new file mode 100644 index 00000000..0c8df24d --- /dev/null +++ b/crates-device/rstsr-accelerate/src/rayon_auto_impl/op_with_func.rs @@ -0,0 +1,117 @@ +use crate::prelude_dev::*; + +/* #region impl op_func for DeviceRayonAutoImpl */ + +impl DeviceOp_MutC_RefA_RefB_API for DeviceRayonAutoImpl +where + TA: Clone + Send + Sync, + TB: Clone + Send + Sync, + TC: Clone + Send + Sync, + D: DimAPI, + F: Fn(&mut MaybeUninit, &TA, &TB) + ?Sized + Send + Sync, +{ + fn op_mutc_refa_refb_func( + &self, + c: &mut Vec>, + lc: &Layout, + a: &Vec, + la: &Layout, + b: &Vec, + lb: &Layout, + f: &mut F, + ) -> Result<()> { + let pool = self.get_current_pool(); + op_mutc_refa_refb_func_cpu_rayon(c, lc, a, la, b, lb, f, pool) + } +} + +impl DeviceOp_MutC_RefA_NumB_API for DeviceRayonAutoImpl +where + TA: Clone + Send + Sync, + TB: Clone + Send + Sync, + TC: Clone + 
Send + Sync, + D: DimAPI, + F: Fn(&mut MaybeUninit, &TA, &TB) + ?Sized + Send + Sync, +{ + fn op_mutc_refa_numb_func( + &self, + c: &mut Vec>, + lc: &Layout, + a: &Vec, + la: &Layout, + b: TB, + f: &mut F, + ) -> Result<()> { + let pool = self.get_current_pool(); + op_mutc_refa_numb_func_cpu_rayon(c, lc, a, la, b, f, pool) + } +} + +impl DeviceOp_MutC_NumA_RefB_API for DeviceRayonAutoImpl +where + TA: Clone + Send + Sync, + TB: Clone + Send + Sync, + TC: Clone + Send + Sync, + D: DimAPI, + F: Fn(&mut MaybeUninit, &TA, &TB) + ?Sized + Send + Sync, +{ + fn op_mutc_numa_refb_func( + &self, + c: &mut Vec>, + lc: &Layout, + a: TA, + b: &Vec, + lb: &Layout, + f: &mut F, + ) -> Result<()> { + let pool = self.get_current_pool(); + op_mutc_numa_refb_func_cpu_rayon(c, lc, a, b, lb, f, pool) + } +} + +impl DeviceOp_MutA_RefB_API for DeviceRayonAutoImpl +where + TA: Clone + Send + Sync, + TB: Clone + Send + Sync, + D: DimAPI, + F: Fn(&mut MaybeUninit, &TB) + ?Sized + Send + Sync, +{ + fn op_muta_refb_func( + &self, + a: &mut Vec>, + la: &Layout, + b: &Vec, + lb: &Layout, + f: &mut F, + ) -> Result<()> { + let pool = self.get_current_pool(); + op_muta_refb_func_cpu_rayon(a, la, b, lb, f, pool) + } +} + +impl DeviceOp_MutA_NumB_API for DeviceRayonAutoImpl +where + TA: Clone + Send + Sync, + TB: Clone + Send + Sync, + D: DimAPI, + F: Fn(&mut MaybeUninit, &TB) + ?Sized + Send + Sync, +{ + fn op_muta_numb_func(&self, a: &mut Vec>, la: &Layout, b: TB, f: &mut F) -> Result<()> { + let pool = self.get_current_pool(); + op_muta_numb_func_cpu_rayon(a, la, b, f, pool) + } +} + +impl DeviceOp_MutA_API for DeviceRayonAutoImpl +where + T: Clone + Send + Sync, + D: DimAPI, + F: Fn(&mut MaybeUninit) + ?Sized + Send + Sync, +{ + fn op_muta_func(&self, a: &mut Vec>, la: &Layout, f: &mut F) -> Result<()> { + let pool = self.get_current_pool(); + op_muta_func_cpu_rayon(a, la, f, pool) + } +} + +/* #endregion */ diff --git a/crates-device/rstsr-accelerate/src/rayon_auto_impl/reduction.rs b/crates-device/rstsr-accelerate/src/rayon_auto_impl/reduction.rs new file mode 100644 index 00000000..5d17d8f7 --- /dev/null +++ b/crates-device/rstsr-accelerate/src/rayon_auto_impl/reduction.rs @@ -0,0 +1,763 @@ +use crate::prelude_dev::*; +use core::ops::{Add, Mul}; +use num::complex::ComplexFloat; +use num::{FromPrimitive, One, Zero}; +use rstsr_dtype_traits::ExtReal; + +impl OpSumAPI for DeviceRayonAutoImpl +where + T: Clone + Send + Sync + Zero + Add, + D: DimAPI, +{ + type TOut = T; + + fn sum_all(&self, a: &Vec, la: &Layout) -> Result { + let pool = self.get_current_pool(); + + let f_init = T::zero; + let f = |acc, x| acc + x; + let f_sum = |acc1, acc2| acc1 + acc2; + let f_out = |acc| acc; + + reduce_all_cpu_rayon(a, la, f_init, f, f_sum, f_out, pool) + } + + fn sum_axes( + &self, + a: &Vec, + la: &Layout, + axes: &[isize], + ) -> Result<(Storage>, T, Self>, Layout)> { + let pool = self.get_current_pool(); + + let f_init = T::zero; + let f = |acc, x| acc + x; + let f_sum = |acc1, acc2| acc1 + acc2; + let f_out = |acc| acc; + + let (out, layout_out) = reduce_axes_cpu_rayon(a, &la.to_dim()?, axes, f_init, f, f_sum, f_out, pool)?; + Ok((Storage::new(out.into(), self.clone()), layout_out)) + } +} + +impl OpMinAPI for DeviceRayonAutoImpl +where + T: ExtReal + Send + Sync, + D: DimAPI, +{ + type TOut = T; + + fn min_all(&self, a: &Vec, la: &Layout) -> Result { + if la.size() == 0 { + rstsr_raise!(InvalidValue, "zero-size array is not supported for min")?; + } + + let pool = self.get_current_pool(); + + let f_init = T::ext_max_value; + 
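+        // a minimal descriptive note (assumption: ExtReal::ext_min has the usual "smaller of the two" semantics):
+        // start each chunk at the type's maximum value, fold elements with ext_min, and merge per-chunk partial minima with ext_min as well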
let f = |acc: T, x: T| acc.ext_min(x); + let f_sum = |acc1: T, acc2: T| acc1.ext_min(acc2); + let f_out = |acc| acc; + + reduce_all_cpu_rayon(a, la, f_init, f, f_sum, f_out, pool) + } + + fn min_axes( + &self, + a: &Vec, + la: &Layout, + axes: &[isize], + ) -> Result<(Storage>, T, Self>, Layout)> { + if la.size() == 0 { + rstsr_raise!(InvalidValue, "zero-size array is not supported for min")?; + } + + let pool = self.get_current_pool(); + + let f_init = T::ext_max_value; + let f = |acc: T, x: T| acc.ext_min(x); + let f_sum = |acc1: T, acc2: T| acc1.ext_min(acc2); + let f_out = |acc| acc; + + let (out, layout_out) = reduce_axes_cpu_rayon(a, &la.to_dim()?, axes, f_init, f, f_sum, f_out, pool)?; + Ok((Storage::new(out.into(), self.clone()), layout_out)) + } +} + +impl OpMaxAPI for DeviceRayonAutoImpl +where + T: ExtReal + Send + Sync, + D: DimAPI, +{ + type TOut = T; + + fn max_all(&self, a: &Vec, la: &Layout) -> Result { + if la.size() == 0 { + rstsr_raise!(InvalidValue, "zero-size array is not supported for max")?; + } + + let pool = self.get_current_pool(); + + let f_init = T::ext_min_value; + let f = |acc: T, x: T| acc.ext_max(x); + let f_sum = |acc1: T, acc2: T| acc1.ext_max(acc2); + let f_out = |acc| acc; + + reduce_all_cpu_rayon(a, la, f_init, f, f_sum, f_out, pool) + } + + fn max_axes( + &self, + a: &Vec, + la: &Layout, + axes: &[isize], + ) -> Result<(Storage>, T, Self>, Layout)> { + if la.size() == 0 { + rstsr_raise!(InvalidValue, "zero-size array is not supported for max")?; + } + + let pool = self.get_current_pool(); + + let f_init = T::ext_min_value; + let f = |acc: T, x: T| acc.ext_max(x); + let f_sum = |acc1: T, acc2: T| acc1.ext_max(acc2); + let f_out = |acc| acc; + + let (out, layout_out) = reduce_axes_cpu_rayon(a, &la.to_dim()?, axes, f_init, f, f_sum, f_out, pool)?; + Ok((Storage::new(out.into(), self.clone()), layout_out)) + } +} + +impl OpProdAPI for DeviceRayonAutoImpl +where + T: Clone + Send + Sync + One + Mul, + D: DimAPI, +{ + type TOut = T; + + fn prod_all(&self, a: &Vec, la: &Layout) -> Result { + let pool = self.get_current_pool(); + + let f_init = T::one; + let f = |acc, x| acc * x; + let f_sum = |acc1, acc2| acc1 * acc2; + let f_out = |acc| acc; + + reduce_all_cpu_rayon(a, la, f_init, f, f_sum, f_out, pool) + } + + fn prod_axes( + &self, + a: &Vec, + la: &Layout, + axes: &[isize], + ) -> Result<(Storage>, T, Self>, Layout)> { + let pool = self.get_current_pool(); + + let f_init = T::one; + let f = |acc, x| acc * x; + let f_sum = |acc1, acc2| acc1 * acc2; + let f_out = |acc| acc; + + let (out, layout_out) = reduce_axes_cpu_rayon(a, &la.to_dim()?, axes, f_init, f, f_sum, f_out, pool)?; + Ok((Storage::new(out.into(), self.clone()), layout_out)) + } +} + +impl OpMeanAPI for DeviceRayonAutoImpl +where + T: Clone + Send + Sync + ComplexFloat + FromPrimitive, + D: DimAPI, +{ + type TOut = T; + + fn mean_all(&self, a: &Vec, la: &Layout) -> Result { + let pool = self.get_current_pool(); + + let size = la.size(); + let f_init = T::zero; + let f = |acc, x| acc + x; + let f_sum = |acc, x| acc + x; + let f_out = |acc| acc / T::from_usize(size).unwrap(); + + let sum = reduce_all_cpu_rayon(a, la, f_init, f, f_sum, f_out, pool)?; + Ok(sum) + } + + fn mean_axes( + &self, + a: &Vec, + la: &Layout, + axes: &[isize], + ) -> Result<(Storage>, T, Self>, Layout)> { + let pool = self.get_current_pool(); + + let (layout_axes, _) = la.dim_split_axes(axes)?; + let size = layout_axes.size(); + let f_init = T::zero; + let f = |acc, x| acc + x; + let f_sum = |acc, x| acc + x; + let f_out = 
|acc| acc / T::from_usize(size).unwrap(); + + let (out, layout_out) = reduce_axes_cpu_rayon(a, &la.to_dim()?, axes, f_init, f, f_sum, f_out, pool)?; + Ok((Storage::new(out.into(), self.clone()), layout_out)) + } +} + +impl OpVarAPI for DeviceRayonAutoImpl +where + T: Clone + Send + Sync + ComplexFloat + FromPrimitive, + T::Real: Clone + Send + Sync + ComplexFloat + FromPrimitive, + D: DimAPI, +{ + type TOut = T::Real; + + fn var_all(&self, a: &Vec, la: &Layout) -> Result { + let pool = self.get_current_pool(); + + let size = la.size(); + + let f_init = || (T::zero(), T::Real::zero()); + let f = |(acc_1, acc_2): (T, T::Real), x: T| (acc_1 + x, acc_2 + (x * x.conj()).re()); + let f_sum = |(acc_1, acc_2): (T, T::Real), (x_1, x_2)| (acc_1 + x_1, acc_2 + x_2); + let f_out = |(acc_1, acc_2): (T, T::Real)| { + let size_1 = T::from_usize(size).unwrap(); + let size_2 = T::Real::from_usize(size).unwrap(); + let mean = acc_1 / size_1; + acc_2 / size_2 - (mean * mean.conj()).re() + }; + + let result = reduce_all_cpu_rayon(a, la, f_init, f, f_sum, f_out, pool)?; + Ok(result) + } + + fn var_axes( + &self, + a: &Vec, + la: &Layout, + axes: &[isize], + ) -> Result<(Storage>, T::Real, Self>, Layout)> { + let pool = self.get_current_pool(); + + let (layout_axes, _) = la.dim_split_axes(axes)?; + let size = layout_axes.size(); + + let f_init = || (T::zero(), T::Real::zero()); + let f = |(acc_1, acc_2): (T, T::Real), x: T| (acc_1 + x, acc_2 + (x * x.conj()).re()); + let f_sum = |(acc_1, acc_2): (T, T::Real), (x_1, x_2)| (acc_1 + x_1, acc_2 + x_2); + let f_out = |(acc_1, acc_2): (T, T::Real)| { + let size_1 = T::from_usize(size).unwrap(); + let size_2 = T::Real::from_usize(size).unwrap(); + let mean = acc_1 / size_1; + acc_2 / size_2 - (mean * mean.conj()).re() + }; + + let (out, layout_out) = reduce_axes_difftype_cpu_rayon(a, &la.to_dim()?, axes, f_init, f, f_sum, f_out, pool)?; + + Ok((Storage::new(out.into(), self.clone()), layout_out)) + } +} + +impl OpStdAPI for DeviceRayonAutoImpl +where + T: Clone + Send + Sync + ComplexFloat + FromPrimitive, + T::Real: Clone + Send + Sync + ComplexFloat + FromPrimitive, + D: DimAPI, +{ + type TOut = T::Real; + + fn std_all(&self, a: &Vec, la: &Layout) -> Result { + let pool = self.get_current_pool(); + + let size = la.size(); + + let f_init = || (T::zero(), T::Real::zero()); + let f = |(acc_1, acc_2): (T, T::Real), x: T| (acc_1 + x, acc_2 + (x * x.conj()).re()); + let f_sum = |(acc_1, acc_2): (T, T::Real), (x_1, x_2)| (acc_1 + x_1, acc_2 + x_2); + let f_out = |(acc_1, acc_2): (T, T::Real)| { + let size_1 = T::from_usize(size).unwrap(); + let size_2 = T::Real::from_usize(size).unwrap(); + let mean = acc_1 / size_1; + let var = acc_2 / size_2 - (mean * mean.conj()).re(); + var.sqrt() + }; + + let result = reduce_all_cpu_rayon(a, la, f_init, f, f_sum, f_out, pool)?; + Ok(result) + } + + fn std_axes( + &self, + a: &Vec, + la: &Layout, + axes: &[isize], + ) -> Result<(Storage>, T::Real, Self>, Layout)> { + let pool = self.get_current_pool(); + + let (layout_axes, _) = la.dim_split_axes(axes)?; + let size = layout_axes.size(); + + let f_init = || (T::zero(), T::Real::zero()); + let f = |(acc_1, acc_2): (T, T::Real), x: T| (acc_1 + x, acc_2 + (x * x.conj()).re()); + let f_sum = |(acc_1, acc_2): (T, T::Real), (x_1, x_2)| (acc_1 + x_1, acc_2 + x_2); + let f_out = |(acc_1, acc_2): (T, T::Real)| { + let size_1 = T::from_usize(size).unwrap(); + let size_2 = T::Real::from_usize(size).unwrap(); + let mean = acc_1 / size_1; + let var = acc_2 / size_2 - (mean * mean.conj()).re(); + 
var.sqrt() + }; + + let (out, layout_out) = reduce_axes_difftype_cpu_rayon(a, &la.to_dim()?, axes, f_init, f, f_sum, f_out, pool)?; + + Ok((Storage::new(out.into(), self.clone()), layout_out)) + } +} + +impl OpL2NormAPI for DeviceRayonAutoImpl +where + T: Clone + Send + Sync + ComplexFloat + FromPrimitive, + T::Real: Clone + Send + Sync + ComplexFloat + FromPrimitive, + D: DimAPI, +{ + type TOut = T::Real; + + fn l2_norm_all(&self, a: &Vec, la: &Layout) -> Result { + let pool = self.get_current_pool(); + + let f_init = || T::Real::zero(); + let f = |acc: T::Real, x: T| acc + (x * x.conj()).re(); + let f_sum = |acc: T::Real, x: T::Real| acc + x; + let f_out = |acc: T::Real| acc.sqrt(); + + let result = reduce_all_cpu_rayon(a, la, f_init, f, f_sum, f_out, pool)?; + Ok(result) + } + + fn l2_norm_axes( + &self, + a: &Vec, + la: &Layout, + axes: &[isize], + ) -> Result<(Storage>, T::Real, Self>, Layout)> { + let pool = self.get_current_pool(); + + let f_init = || T::Real::zero(); + let f = |acc: T::Real, x: T| acc + (x * x.conj()).re(); + let f_sum = |acc: T::Real, x: T::Real| acc + x; + let f_out = |acc: T::Real| acc.sqrt(); + + let (out, layout_out) = reduce_axes_cpu_rayon(a, &la.to_dim()?, axes, f_init, f, f_sum, f_out, pool)?; + + Ok((Storage::new(out.into(), self.clone()), layout_out)) + } +} + +impl OpArgMinAPI for DeviceRayonAutoImpl +where + T: Clone + PartialOrd + Send + Sync, + D: DimAPI, +{ + type TOut = usize; + + fn argmin_axes( + &self, + a: &Vec, + la: &Layout, + axes: &[isize], + ) -> Result<(Storage>, Self::TOut, Self>, Layout)> { + let pool = self.get_current_pool(); + + let f_comp = |x: Option, y: T| -> Option { + if let Some(x) = x { + Some(y < x) + } else { + Some(true) + } + }; + let f_eq = |x: Option, y: T| -> Option { + if let Some(x) = x { + Some(y == x) + } else { + Some(false) + } + }; + let (out, layout_out) = reduce_axes_arg_cpu_rayon(a, la, axes, f_comp, f_eq, RowMajor, pool)?; + Ok((Storage::new(out.into(), self.clone()), layout_out)) + } + + fn argmin_all(&self, a: &Vec, la: &Layout) -> Result { + let pool = self.get_current_pool(); + + let f_comp = |x: Option, y: T| -> Option { + if let Some(x) = x { + Some(y < x) + } else { + Some(true) + } + }; + let f_eq = |x: Option, y: T| -> Option { + if let Some(x) = x { + Some(y == x) + } else { + Some(false) + } + }; + let result = reduce_all_arg_cpu_rayon(a, la, f_comp, f_eq, RowMajor, pool)?; + Ok(result) + } +} + +impl OpArgMaxAPI for DeviceRayonAutoImpl +where + T: Clone + PartialOrd + Send + Sync, + D: DimAPI, +{ + type TOut = usize; + + fn argmax_axes( + &self, + a: &Vec, + la: &Layout, + axes: &[isize], + ) -> Result<(Storage>, Self::TOut, Self>, Layout)> { + let pool = self.get_current_pool(); + + let f_comp = |x: Option, y: T| -> Option { + if let Some(x) = x { + Some(y > x) + } else { + Some(true) + } + }; + let f_eq = |x: Option, y: T| -> Option { + if let Some(x) = x { + Some(y == x) + } else { + Some(false) + } + }; + let (out, layout_out) = reduce_axes_arg_cpu_rayon(a, la, axes, f_comp, f_eq, RowMajor, pool)?; + Ok((Storage::new(out.into(), self.clone()), layout_out)) + } + + fn argmax_all(&self, a: &Vec, la: &Layout) -> Result { + let pool = self.get_current_pool(); + + let f_comp = |x: Option, y: T| -> Option { + if let Some(x) = x { + Some(y > x) + } else { + Some(true) + } + }; + let f_eq = |x: Option, y: T| -> Option { + if let Some(x) = x { + Some(y == x) + } else { + Some(false) + } + }; + let result = reduce_all_arg_cpu_rayon(a, la, f_comp, f_eq, RowMajor, pool)?; + Ok(result) + } +} + +impl 
OpAllAPI for DeviceRayonAutoImpl +where + D: DimAPI, +{ + type TOut = bool; + + fn all_all(&self, a: &Vec, la: &Layout) -> Result { + let pool = self.get_current_pool(); + + let f_init = || true; + let f = |acc, x| acc && x; + let f_sum = |acc1, acc2| acc1 && acc2; + let f_out = |acc| acc; + + reduce_all_cpu_rayon(a, la, f_init, f, f_sum, f_out, pool) + } + + fn all_axes( + &self, + a: &Vec, + la: &Layout, + axes: &[isize], + ) -> Result<(Storage>, bool, Self>, Layout)> { + let pool = self.get_current_pool(); + + let f_init = || true; + let f = |acc, x| acc && x; + let f_sum = |acc1, acc2| acc1 && acc2; + let f_out = |acc| acc; + + let (out, layout_out) = reduce_axes_cpu_rayon(a, &la.to_dim()?, axes, f_init, f, f_sum, f_out, pool)?; + Ok((Storage::new(out.into(), self.clone()), layout_out)) + } +} + +impl OpAnyAPI for DeviceRayonAutoImpl +where + D: DimAPI, +{ + type TOut = bool; + + fn any_all(&self, a: &Vec, la: &Layout) -> Result { + let pool = self.get_current_pool(); + + let f_init = || false; + let f = |acc, x| acc || x; + let f_sum = |acc1, acc2| acc1 || acc2; + let f_out = |acc| acc; + + reduce_all_cpu_rayon(a, la, f_init, f, f_sum, f_out, pool) + } + + fn any_axes( + &self, + a: &Vec, + la: &Layout, + axes: &[isize], + ) -> Result<(Storage>, bool, Self>, Layout)> { + let pool = self.get_current_pool(); + + let f_init = || false; + let f = |acc, x| acc || x; + let f_sum = |acc1, acc2| acc1 || acc2; + let f_out = |acc| acc; + + let (out, layout_out) = reduce_axes_cpu_rayon(a, &la.to_dim()?, axes, f_init, f, f_sum, f_out, pool)?; + Ok((Storage::new(out.into(), self.clone()), layout_out)) + } +} + +impl OpCountNonZeroAPI for DeviceRayonAutoImpl +where + T: Clone + PartialEq + Zero + Send + Sync, + D: DimAPI, +{ + type TOut = usize; + + fn count_nonzero_all(&self, a: &Vec, la: &Layout) -> Result { + let pool = self.get_current_pool(); + + let f_init = || 0; + let f = |acc, x| if x != T::zero() { acc + 1 } else { acc }; + let f_sum = |acc1, acc2| acc1 + acc2; + let f_out = |acc| acc; + + reduce_all_cpu_rayon(a, la, f_init, f, f_sum, f_out, pool) + } + + fn count_nonzero_axes( + &self, + a: &Vec, + la: &Layout, + axes: &[isize], + ) -> Result<(Storage>, usize, Self>, Layout)> { + let pool = self.get_current_pool(); + + let f_init = || 0; + let f = |acc, x| if x != T::zero() { acc + 1 } else { acc }; + let f_sum = |acc1, acc2| acc1 + acc2; + let f_out = |acc| acc; + + let (out, layout_out) = reduce_axes_cpu_rayon(a, &la.to_dim()?, axes, f_init, f, f_sum, f_out, pool)?; + Ok((Storage::new(out.into(), self.clone()), layout_out)) + } +} + +impl OpUnraveledArgMinAPI for DeviceRayonAutoImpl +where + T: Clone + PartialOrd + Send + Sync, + D: DimAPI, +{ + fn unraveled_argmin_axes( + &self, + a: &Vec, + la: &Layout, + axes: &[isize], + ) -> Result<(Storage>, IxD, Self>, Layout)> { + let pool = self.get_current_pool(); + + let f_comp = |x: Option, y: T| -> Option { + if let Some(x) = x { + Some(y < x) + } else { + Some(true) + } + }; + let f_eq = |x: Option, y: T| -> Option { + if let Some(x) = x { + Some(y == x) + } else { + Some(false) + } + }; + let (out, layout_out) = reduce_axes_unraveled_arg_cpu_rayon(a, la, axes, f_comp, f_eq, pool)?; + Ok((Storage::new(out.into(), self.clone()), layout_out)) + } + + fn unraveled_argmin_all(&self, a: &>::Raw, la: &Layout) -> Result { + let pool = self.get_current_pool(); + + let f_comp = |x: Option, y: T| -> Option { + if let Some(x) = x { + Some(y < x) + } else { + Some(true) + } + }; + let f_eq = |x: Option, y: T| -> Option { + if let Some(x) = x { + 
Some(y == x) + } else { + Some(false) + } + }; + let result = reduce_all_unraveled_arg_cpu_rayon(a, la, f_comp, f_eq, pool)?; + Ok(result) + } +} + +impl OpUnraveledArgMaxAPI for DeviceRayonAutoImpl +where + T: Clone + PartialOrd + Send + Sync, + D: DimAPI, +{ + fn unraveled_argmax_axes( + &self, + a: &Vec, + la: &Layout, + axes: &[isize], + ) -> Result<(Storage>, IxD, Self>, Layout)> { + let pool = self.get_current_pool(); + + let f_comp = |x: Option, y: T| -> Option { + if let Some(x) = x { + Some(y > x) + } else { + Some(true) + } + }; + let f_eq = |x: Option, y: T| -> Option { + if let Some(x) = x { + Some(y == x) + } else { + Some(false) + } + }; + let (out, layout_out) = reduce_axes_unraveled_arg_cpu_rayon(a, la, axes, f_comp, f_eq, pool)?; + Ok((Storage::new(out.into(), self.clone()), layout_out)) + } + + fn unraveled_argmax_all(&self, a: &>::Raw, la: &Layout) -> Result { + let pool = self.get_current_pool(); + + let f_comp = |x: Option, y: T| -> Option { + if let Some(x) = x { + Some(y > x) + } else { + Some(true) + } + }; + let f_eq = |x: Option, y: T| -> Option { + if let Some(x) = x { + Some(y == x) + } else { + Some(false) + } + }; + let result = reduce_all_unraveled_arg_cpu_rayon(a, la, f_comp, f_eq, pool)?; + Ok(result) + } +} + +impl OpSumBoolAPI for DeviceRayonAutoImpl +where + D: DimAPI, +{ + fn sum_all(&self, a: &Vec, la: &Layout) -> Result { + let pool = self.get_current_pool(); + + let f_init = || 0; + let f = |acc, x| match x { + true => acc + 1, + false => acc, + }; + let f_sum = |acc1, acc2| acc1 + acc2; + let f_out = |acc| acc; + + reduce_all_cpu_rayon(a, la, f_init, f, f_sum, f_out, pool) + } + + fn sum_axes( + &self, + a: &Vec, + la: &Layout, + axes: &[isize], + ) -> Result<(Storage>, usize, Self>, Layout)> { + let pool = self.get_current_pool(); + + let f_init = || 0; + let f = |acc, x| match x { + true => acc + 1, + false => acc, + }; + let f_sum = |acc1, acc2| acc1 + acc2; + let f_out = |acc| acc; + + let (out, layout_out) = reduce_axes_cpu_rayon(a, &la.to_dim()?, axes, f_init, f, f_sum, f_out, pool)?; + Ok((Storage::new(out.into(), self.clone()), layout_out)) + } +} + +impl OpAllCloseAPI for DeviceRayonAutoImpl +where + TA: Clone + Send + Sync + DTypePromoteAPI, + TB: Clone + Send + Sync, + >::Res: ExtNum>, + TE: ExtFloat + Add + Mul + PartialOrd + Clone + Send + Sync, + D: DimAPI, +{ + fn allclose_all( + &self, + a: &>::Raw, + la: &Layout, + b: &>::Raw, + lb: &Layout, + isclose_args: &IsCloseArgs, + ) -> Result { + use rstsr_dtype_traits::isclose; + + let pool = self.get_current_pool(); + + if la.size() == 0 || lb.size() == 0 { + rstsr_raise!(InvalidValue, "zero-size array is not supported for allclose")?; + } + + let f_init = || true; + let f = |acc: bool, (a_elem, b_elem): (TA, TB)| { + let result = isclose(&a_elem, &b_elem, isclose_args); + acc && result + }; + let f_sum = |acc1: bool, acc2: bool| acc1 && acc2; + let f_out = |acc: bool| acc; + + reduce_all_binary_cpu_rayon(a, la, b, lb, f_init, f, f_sum, f_out, pool) + } + + fn allclose_axes( + &self, + _a: &>::Raw, + _la: &Layout, + _b: &>::Raw, + _lb: &Layout, + _axes: &[isize], + _isclose_args: &IsCloseArgs, + ) -> Result<(Storage>::Raw>, bool, Self>, Layout)> { + unimplemented!("This function (`allclose_axes`) is not planned to be implemented yet."); + } +} diff --git a/crates-device/rstsr-accelerate/src/sci_auto_impl/distance_auto_impl.rs b/crates-device/rstsr-accelerate/src/sci_auto_impl/distance_auto_impl.rs new file mode 100644 index 00000000..2cee23b7 --- /dev/null +++ 
b/crates-device/rstsr-accelerate/src/sci_auto_impl/distance_auto_impl.rs @@ -0,0 +1,120 @@ +use crate::prelude_dev::*; + +use num::Float; +use rstsr_sci_traits::distance::metric::{MetricDistAPI, MetricDistWeightedAPI, MetricEuclidean}; +use rstsr_sci_traits::distance::native_impl::{cdist_rayon, cdist_weighted_rayon}; +use rstsr_sci_traits::distance::traits::CDistAPI; + +impl CDistAPI + for ( + TensorView<'_, T, DeviceRayonAutoImpl, D>, + TensorView<'_, T, DeviceRayonAutoImpl, D>, + M, + TensorView<'_, TW, DeviceRayonAutoImpl, DW>, + ) +where + M: MetricDistWeightedAPI, Weight = Vec, Out = TW> + Send + Sync, + T: Send + Sync, + TW: Float + Send + Sync, + M::Out: Send + Sync, + DeviceRayonAutoImpl: DeviceAPI> + + DeviceAPI + + DeviceAPI> + + DeviceCreationAnyAPI + + DeviceCreationAnyAPI + + OpAssignArbitaryAPI + + OpAssignAPI, + D: DimAPI + DimIntoAPI, + DW: DimAPI + DimIntoAPI, +{ + type Out = Tensor; + + fn cdist_f(self) -> Result { + let (xa, xb, kernel, weight) = self; + rstsr_assert_eq!(xa.ndim(), 2, InvalidLayout, "xa must be a 2D tensor")?; + rstsr_assert_eq!(xb.ndim(), 2, InvalidLayout, "xb must be a 2D tensor")?; + rstsr_assert_eq!(weight.ndim(), 1, InvalidLayout, "weight must be a 1D tensor")?; + rstsr_assert!(xa.device().same_device(xb.device()), DeviceMismatch)?; + rstsr_assert!(xa.device().same_device(weight.device()), DeviceMismatch)?; + let la = xa.layout().to_dim::()?; + let lb = xb.layout().to_dim::()?; + let device = xa.device().clone(); + let order = device.default_order(); + let weight = weight.into_contig_f(RowMajor)?; + let pool = device.get_current_pool(); + let dist = cdist_weighted_rayon(xa.raw(), xb.raw(), &la, &lb, weight.raw(), kernel, order, pool)?; + + let m = la.shape()[0]; + let n = lb.shape()[0]; + asarray_f((dist, [m, n], &device))?.into_dim_f::() + } +} + +impl CDistAPI + for (TensorView<'_, T, DeviceRayonAutoImpl, D>, TensorView<'_, T, DeviceRayonAutoImpl, D>, M) +where + M: MetricDistAPI> + Send + Sync, + T: Send + Sync, + M::Out: Send + Sync, + DeviceRayonAutoImpl: + DeviceAPI> + DeviceAPI> + DeviceCreationAnyAPI, + D: DimAPI + DimIntoAPI, +{ + type Out = Tensor; + + fn cdist_f(self) -> Result { + let (xa, xb, kernel) = self; + rstsr_assert_eq!(xa.ndim(), 2, InvalidLayout, "xa must be a 2D tensor")?; + rstsr_assert_eq!(xb.ndim(), 2, InvalidLayout, "xb must be a 2D tensor")?; + rstsr_assert!(xa.device().same_device(xb.device()), DeviceMismatch)?; + let la = xa.layout().to_dim::()?; + let lb = xb.layout().to_dim::()?; + let device = xa.device().clone(); + let order = device.default_order(); + let pool = device.get_current_pool(); + let dist = cdist_rayon(xa.raw(), xb.raw(), &la, &lb, kernel, order, pool)?; + + let m = la.shape()[0]; + let n = lb.shape()[0]; + asarray_f((dist, [m, n], &device))?.into_dim_f::() + } +} + +impl CDistAPI + for (TensorView<'_, T, DeviceRayonAutoImpl, D>, TensorView<'_, T, DeviceRayonAutoImpl, D>) +where + T: Float + Send + Sync, + DeviceRayonAutoImpl: DeviceAPI> + DeviceCreationAnyAPI, + D: DimAPI + DimIntoAPI, +{ + type Out = Tensor; + + fn cdist_f(self) -> Result { + let (xa, xb) = self; + CDistAPI::::cdist_f((xa, xb, MetricEuclidean)) + } +} + +#[cfg(test)] +mod test { + use super::*; + use rstsr_sci_traits::distance::metric::MetricEuclidean; + use rstsr_sci_traits::distance::traits::cdist; + + #[test] + fn playground() { + let device = DeviceRayonAutoImpl::default(); + let a = linspace((0., 1., 6400, &device)).into_shape((1600, 4)); + let b = linspace((0., 1., 8000, &device)).into_shape((2000, 4)).into_flip(-1); + + let d = 
cdist((a.view(), b.view(), MetricEuclidean)); + println!("{d:16.8?}"); + + let d = cdist((a.view(), b.view())); + println!("{d:16.8?}"); + + let w = asarray((vec![1.5, 1.2, 0.7, 1.3], &device)); + let d_w = cdist((a.view(), b.view(), MetricEuclidean, w.view())); + println!("{d_w:16.8?}"); + } +} diff --git a/crates-device/rstsr-accelerate/src/sci_auto_impl/integrate_auto_impl.rs b/crates-device/rstsr-accelerate/src/sci_auto_impl/integrate_auto_impl.rs new file mode 100644 index 00000000..31b9456a --- /dev/null +++ b/crates-device/rstsr-accelerate/src/sci_auto_impl/integrate_auto_impl.rs @@ -0,0 +1,13 @@ +use crate::prelude_dev::*; +use rstsr_sci_traits::integrate::lebedev::*; + +impl LebedevRuleAPI for DeviceRayonAutoImpl { + fn lebedev_rule_f(&self, n: usize) -> Result> { + let degree = lebedev_order_to_degree(n).map_err(|_| rstsr_error!(InvalidValue, "Invalid Lebedev order {n}"))?; + let (quads, weights) = lebedev_make_angular_grid(degree)?; + let ngrids = weights.len(); + let quads = asarray((quads, [ngrids, 3].c(), self)); + let weights = asarray((weights, [ngrids].c(), self)); + Ok(LebedevQuad { quads, weights }) + } +} diff --git a/crates-device/rstsr-accelerate/src/sci_auto_impl/mod.rs b/crates-device/rstsr-accelerate/src/sci_auto_impl/mod.rs new file mode 100644 index 00000000..74e245ea --- /dev/null +++ b/crates-device/rstsr-accelerate/src/sci_auto_impl/mod.rs @@ -0,0 +1,2 @@ +pub mod distance_auto_impl; +pub mod integrate_auto_impl; diff --git a/crates-device/rstsr-accelerate/src/threading.rs b/crates-device/rstsr-accelerate/src/threading.rs new file mode 100644 index 00000000..7258f7d1 --- /dev/null +++ b/crates-device/rstsr-accelerate/src/threading.rs @@ -0,0 +1,65 @@ +//! Apple Accelerate threading, now only single-threading supported. + +use crate::prelude_dev::*; +use rstsr_blas_traits::prelude_dev::*; + +/* #region threading number control */ + +struct AccelerateConfig; + +impl AccelerateConfig { + #[allow(dead_code)] + fn set_num_threads(&mut self, _n: usize) { + // Direct multi-thread control only available after macOS 15 + } + + fn get_num_threads(&mut self) -> usize { + // Direct multi-thread control only available after macOS 15 + return 1; + } +} + +/// Set number of threads for Apple Accelerate. +/// +/// This function should be safe to call from multiple threads. +pub fn set_num_threads(_n: usize) { + // Direct multi-thread control only available after macOS 15 +} + +/// Get the number of threads currently set for Apple Accelerate. +/// +/// This function should be safe to call from multiple threads. 
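+///
+/// Direct thread-count control in Accelerate is only available on macOS 15 and later, so this currently always reports 1.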
+pub fn get_num_threads() -> usize { + AccelerateConfig.get_num_threads() +} + +pub fn with_num_threads(nthreads: usize, f: F) -> R +where + F: FnOnce() -> R, +{ + let n = get_num_threads(); + set_num_threads(nthreads); + let r = f(); + set_num_threads(n); + return r; +} + +/* #endregion */ + +/* #region trait impl */ + +impl BlasThreadAPI for DeviceBLAS { + fn get_blas_num_threads(&self) -> usize { + crate::threading::get_num_threads() + } + + fn set_blas_num_threads(&self, nthreads: usize) { + crate::threading::set_num_threads(nthreads); + } + + fn with_blas_num_threads(&self, nthreads: usize, f: impl FnOnce() -> T) -> T { + crate::threading::with_num_threads(nthreads, f) + } +} + +/* #endregion */ diff --git a/crates-device/rstsr-accelerate/tests/issues/issue_45.rs b/crates-device/rstsr-accelerate/tests/issues/issue_45.rs new file mode 100644 index 00000000..b2bdb8fa --- /dev/null +++ b/crates-device/rstsr-accelerate/tests/issues/issue_45.rs @@ -0,0 +1,14 @@ +#[test] +fn issue_45() { + use rstsr::prelude::*; + use rstsr_accelerate::DeviceAccelerate; + let device = DeviceAccelerate::default(); + let a: Tensor = rt::asarray((vec![], [1024, 0], &device)); + let b: Tensor = rt::asarray((vec![], [1000, 0], &device)); + let c = &a % b.t(); + println!("{:?}", c.shape()); + assert!(c.abs().sum() < 1e-10); + + let c = &a % a.t(); + assert!(c.abs().sum() < 1e-10); +} diff --git a/crates-device/rstsr-accelerate/tests/issues/mod.rs b/crates-device/rstsr-accelerate/tests/issues/mod.rs new file mode 100644 index 00000000..0decdedb --- /dev/null +++ b/crates-device/rstsr-accelerate/tests/issues/mod.rs @@ -0,0 +1 @@ +mod issue_45; diff --git a/crates-device/rstsr-accelerate/tests/mod.rs b/crates-device/rstsr-accelerate/tests/mod.rs new file mode 100644 index 00000000..d47b531f --- /dev/null +++ b/crates-device/rstsr-accelerate/tests/mod.rs @@ -0,0 +1,5 @@ +mod test_driver_impl; +#[cfg(feature = "linalg")] +mod test_linalg_func; + +mod issues; diff --git a/crates-device/rstsr-accelerate/tests/test_driver_impl/driver_validation_f64.py b/crates-device/rstsr-accelerate/tests/test_driver_impl/driver_validation_f64.py new file mode 100644 index 00000000..6cc747a3 --- /dev/null +++ b/crates-device/rstsr-accelerate/tests/test_driver_impl/driver_validation_f64.py @@ -0,0 +1,141 @@ +# # Driver tests in Python + +import numpy as np +import scipy + + +def fingerprint(a): + return np.dot(np.cos(np.arange(a.size)), np.asarray(a, order="C").ravel()) + + +# Path of npy files in rstsr-test-manifest + +root = "../../../../rstsr-test-manifest/resources/" + +# ## make sure of random generation + +a_raw = np.load(f"{root}/a-f64.npy") +b_raw = np.load(f"{root}/b-f64.npy") + +assert np.isclose(fingerprint(a_raw), 191.28900005103065) +assert np.isclose(fingerprint(b_raw), -51.11100342180723) + +# ## eigh driver tests + +# ### dsyev* + +a = a_raw.copy().reshape(1024, 1024) +w, v, _ = scipy.linalg.lapack.dsyevd(a, lower=True) +assert np.isclose(fingerprint(w), -71.4747209499407) +assert np.isclose(fingerprint(np.abs(v)), -9.903934930318247) +fingerprint(w), fingerprint(np.abs(v)) + +a = a_raw.copy().reshape(1024, 1024) +w, v, _ = scipy.linalg.lapack.dsyevd(a, lower=False) +assert np.isclose(fingerprint(w), -71.4902453763506) +assert np.isclose(fingerprint(np.abs(v)), 6.973792268793419) +fingerprint(w), fingerprint(np.abs(v)) + +# ### dsygv* + +a = a_raw.copy().reshape(1024, 1024) +b = b_raw.copy().reshape(1024, 1024) +w, v, _ = scipy.linalg.lapack.dsygvd(a, b, uplo='L') +assert np.isclose(fingerprint(w), -89.60433120129908) 
+assert np.isclose(fingerprint(np.abs(v)), -5.243112559130817) +fingerprint(w), fingerprint(np.abs(v)) + +a = a_raw.copy().reshape(1024, 1024) +b = b_raw.copy().reshape(1024, 1024) +w, v, _ = scipy.linalg.lapack.dsygvd(a, b, uplo='U') +assert np.isclose(fingerprint(w), -65.27252612342873) +assert np.isclose(fingerprint(np.abs(v)), -7.0849504857534535) +fingerprint(w), fingerprint(np.abs(v)) + +a = a_raw.copy().reshape(1024, 1024) +b = b_raw.copy().reshape(1024, 1024) +w, v, _ = scipy.linalg.lapack.dsygvd(a, b, uplo='L', itype=2) +assert np.isclose(fingerprint(w), -2437.094304861363) +assert np.isclose(fingerprint(np.abs(v)), -4.108281604767547) +fingerprint(w), fingerprint(np.abs(v)) + +a = a_raw.copy().reshape(1024, 1024) +b = b_raw.copy().reshape(1024, 1024) +w, v, _ = scipy.linalg.lapack.dsygvd(a, b, uplo='L', itype=3) +assert np.isclose(fingerprint(w), -2437.094304861363) +assert np.isclose(fingerprint(np.abs(v)), 30.756098926747757) +fingerprint(w), fingerprint(np.abs(v)) + +# ## solve driver tests + +# ### gesv + +a = a_raw.copy().reshape(1024, 1024) +b = b_raw.copy()[:1024*512].reshape(1024, 512) +lu, piv, x, _ = scipy.linalg.lapack.dgesv(a, b) +assert np.isclose(fingerprint(lu), 5397.198541468395) +assert np.isclose(fingerprint(piv), -14.694714160751573) +assert np.isclose(fingerprint(x), -1951.253447757597) +fingerprint(lu), fingerprint(piv), fingerprint(x) + +# ### getrf, getri + +a = a_raw.copy().reshape(1024, 1024) +lu, piv, _ = scipy.linalg.lapack.dgetrf(a) +assert np.isclose(fingerprint(lu), 5397.198541468395) +assert np.isclose(fingerprint(piv), -14.694714160751573) +fingerprint(lu), fingerprint(piv) + +inv_a, _ = scipy.linalg.lapack.dgetri(lu, piv) +assert np.isclose(fingerprint(inv_a), 143.3900557703788) +fingerprint(inv_a) + +# ### potrf + +# Please note that driver implementation in rust does not clean upper/lower triangular. 
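+# Passing clean=0 to scipy keeps the untouched triangle of the input, so the fingerprints below include that triangle and stay comparable to the Rust driver tests.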
+ +b = b_raw.copy().reshape(1024, 1024) +c, _ = scipy.linalg.lapack.dpotrf(b, lower=True, clean=0) +assert np.isclose(fingerprint(c), 35.17266259472725) +fingerprint(c) + +b = b_raw.copy().reshape(1024, 1024) +c, _ = scipy.linalg.lapack.dpotrf(b, lower=False, clean=0) +assert np.isclose(fingerprint(c), -53.53353704132017) +fingerprint(c) + +# ### sysv + +a = a_raw.copy().reshape(1024, 1024) +b = b_raw.copy()[:1024*512].reshape(1024, 512) +udut, piv, x, _ = scipy.linalg.lapack.dsysv(a, b, lower=True) +assert np.isclose(fingerprint(udut), -1201.6472395568974) +assert np.isclose(fingerprint(piv), -16668.7094872639) +assert np.isclose(fingerprint(x), -397.12032355166446) +fingerprint(udut), fingerprint(piv), fingerprint(x) + +a = a_raw.copy().reshape(1024, 1024) +b = b_raw.copy()[:1024*512].reshape(1024, 512) +udut, piv, x, _ = scipy.linalg.lapack.dsysv(a, b, lower=False) +assert np.isclose(fingerprint(udut), 1182.7836118324408) +assert np.isclose(fingerprint(piv), 11905.503011559245) +assert np.isclose(fingerprint(x), -314.4502289190444) +fingerprint(udut), fingerprint(piv), fingerprint(x) + +# ## svd driver tests + +a = a_raw.copy()[:1024*512].reshape(1024, 512) +u, s, vt, _ = scipy.linalg.lapack.dgesvd(a) +assert np.isclose(fingerprint(np.abs(u)), -1.9368850983570982) +assert np.isclose(fingerprint(s), 33.969339071043095) +assert np.isclose(fingerprint(np.abs(vt)), 13.465522484136157) +fingerprint(np.abs(u)), fingerprint(s), fingerprint(np.abs(vt)) + +a = a_raw.copy()[:1024*512].reshape(1024, 512) +u, s, vt, _ = scipy.linalg.lapack.dgesvd(a, full_matrices=False) +assert np.isclose(fingerprint(np.abs(u)), -9.144981428076894) +assert np.isclose(fingerprint(s), 33.969339071043095) +assert np.isclose(fingerprint(np.abs(vt)), 13.465522484136157) +fingerprint(np.abs(u)), fingerprint(s), fingerprint(np.abs(vt)) + + diff --git a/crates-device/rstsr-accelerate/tests/test_driver_impl/lapack_eigh_f64.rs b/crates-device/rstsr-accelerate/tests/test_driver_impl/lapack_eigh_f64.rs new file mode 100644 index 00000000..92ffc77a --- /dev/null +++ b/crates-device/rstsr-accelerate/tests/test_driver_impl/lapack_eigh_f64.rs @@ -0,0 +1,148 @@ +use rstsr_blas_traits::lapack_eigh::*; +use rstsr_core::prelude::*; +use rstsr_core::prelude_dev::fingerprint; +use rstsr_accelerate::DeviceAccelerate as DeviceBLAS; +use rstsr_test_manifest::get_vec; + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_dsyevd() { + let device = DeviceBLAS::default(); + let a = rt::asarray((get_vec::('a'), [1024, 1024].c(), &device)).into_dim::(); + + // default + let driver = DSYEVD::default().a(a.view()).build().unwrap(); + let (w, v) = driver.run().unwrap(); + let v = v.into_owned(); + assert!((fingerprint(&w) - -71.4747209499407).abs() < 1e-8); + assert!((fingerprint(&v.abs()) - -9.903934930318247).abs() < 1e-8); + + // upper for c-contiguous + let driver = DSYEVD::default().a(a.view()).uplo(Upper).build().unwrap(); + let (w, v) = driver.run().unwrap(); + let v = v.into_owned(); + assert!((fingerprint(&w) - -71.4902453763506).abs() < 1e-8); + assert!((fingerprint(&v.abs()) - 6.973792268793419).abs() < 1e-8); + + // transpose upper for c-contiguous + let driver = DSYEVD::default().a(a.t()).uplo(Upper).build().unwrap(); + let (w, v) = driver.run().unwrap(); + let v = v.into_owned(); + assert!((fingerprint(&w) - -71.4747209499407).abs() < 1e-8); + assert!((fingerprint(&v.abs()) - -9.903934930318247).abs() < 1e-8); + } + + #[test] + fn test_dsyev() { + let device = DeviceBLAS::default(); + let a = rt::asarray((get_vec::('a'), 
[1024, 1024].c(), &device)).into_dim::(); + + // default + let driver = DSYEV::default().a(a.view()).build().unwrap(); + let (w, v) = driver.run().unwrap(); + let v = v.into_owned(); + assert!((fingerprint(&w) - -71.4747209499407).abs() < 1e-8); + assert!((fingerprint(&v.abs()) - -9.903934930318247).abs() < 1e-8); + + // upper for c-contiguous + let driver = DSYEV::default().a(a.view()).uplo(Upper).build().unwrap(); + let (w, v) = driver.run().unwrap(); + let v = v.into_owned(); + assert!((fingerprint(&w) - -71.4902453763506).abs() < 1e-8); + assert!((fingerprint(&v.abs()) - 6.973792268793419).abs() < 1e-8); + + // transpose upper for c-contiguous + let driver = DSYEV::default().a(a.t()).uplo(Upper).build().unwrap(); + let (w, v) = driver.run().unwrap(); + let v = v.into_owned(); + assert!((fingerprint(&w) - -71.4747209499407).abs() < 1e-8); + assert!((fingerprint(&v.abs()) - -9.903934930318247).abs() < 1e-8); + } + + #[test] + fn test_dsygvd() { + let device = DeviceBLAS::default(); + let a = rt::asarray((get_vec::('a'), [1024, 1024].c(), &device)).into_dim::(); + let b = rt::asarray((get_vec::('b'), [1024, 1024].c(), &device)).into_dim::(); + + // default + let driver = DSYGVD::default().a(a.view()).b(b.view()).build().unwrap(); + let (w, v) = driver.run().unwrap(); + let v = v.into_owned(); + assert!((fingerprint(&w) - -89.60433120129908).abs() < 1e-8); + assert!((fingerprint(&v.abs()) - -5.243112559130817).abs() < 1e-8); + + // upper for c-contiguous + let driver = DSYGVD::default().a(a.view()).b(b.view()).uplo(Upper).build().unwrap(); + let (w, v) = driver.run().unwrap(); + let v = v.into_owned(); + assert!((fingerprint(&w) - -65.27252612342873).abs() < 1e-8); + assert!((fingerprint(&v.abs()) - -7.0849504857534535).abs() < 1e-8); + + // transpose upper for c-contiguous + let driver = DSYGVD::default().a(a.t()).b(b.t()).uplo(Upper).build().unwrap(); + let (w, v) = driver.run().unwrap(); + let v = v.into_owned(); + assert!((fingerprint(&w) - -89.60433120129908).abs() < 1e-8); + assert!((fingerprint(&v.abs()) - -5.243112559130817).abs() < 1e-8); + + // itype 2 + let driver = DSYGVD::default().a(a.view()).b(b.view()).itype(2).build().unwrap(); + let (w, v) = driver.run().unwrap(); + let v = v.into_owned(); + assert!((fingerprint(&w) - -2437.094304861363).abs() < 1e-8); + assert!((fingerprint(&v.abs()) - -4.108281604767547).abs() < 1e-8); + + // itype 3 + let driver = DSYGVD::default().a(a.view()).b(b.view()).itype(3).build().unwrap(); + let (w, v) = driver.run().unwrap(); + let v = v.into_owned(); + assert!((fingerprint(&w) - -2437.094304861363).abs() < 1e-8); + assert!((fingerprint(&v.abs()) - 30.756098926747757).abs() < 1e-8); + } + + #[test] + fn test_dsygv() { + let device = DeviceBLAS::default(); + let a = rt::asarray((get_vec::('a'), [1024, 1024].c(), &device)).into_dim::(); + let b = rt::asarray((get_vec::('b'), [1024, 1024].c(), &device)).into_dim::(); + + // default + let driver = DSYGV::default().a(a.view()).b(b.view()).build().unwrap(); + let (w, v) = driver.run().unwrap(); + let v = v.into_owned(); + assert!((fingerprint(&w) - -89.60433120129908).abs() < 1e-8); + assert!((fingerprint(&v.abs()) - -5.243112559130817).abs() < 1e-8); + + // upper for c-contiguous + let driver = DSYGV::default().a(a.view()).b(b.view()).uplo(Upper).build().unwrap(); + let (w, v) = driver.run().unwrap(); + let v = v.into_owned(); + assert!((fingerprint(&w) - -65.27252612342873).abs() < 1e-8); + assert!((fingerprint(&v.abs()) - -7.0849504857534535).abs() < 1e-8); + + // transpose upper for 
c-contiguous + let driver = DSYGV::default().a(a.t()).b(b.t()).uplo(Upper).build().unwrap(); + let (w, v) = driver.run().unwrap(); + let v = v.into_owned(); + assert!((fingerprint(&w) - -89.60433120129908).abs() < 1e-8); + assert!((fingerprint(&v.abs()) - -5.243112559130817).abs() < 1e-8); + + // itype 2 + let driver = DSYGV::default().a(a.view()).b(b.view()).itype(2).build().unwrap(); + let (w, v) = driver.run().unwrap(); + let v = v.into_owned(); + assert!((fingerprint(&w) - -2437.094304861363).abs() < 1e-8); + assert!((fingerprint(&v.abs()) - -4.108281604767547).abs() < 1e-8); + + // itype 3 + let driver = DSYGV::default().a(a.view()).b(b.view()).itype(3).build().unwrap(); + let (w, v) = driver.run().unwrap(); + let v = v.into_owned(); + assert!((fingerprint(&w) - -2437.094304861363).abs() < 1e-8); + assert!((fingerprint(&v.abs()) - 30.756098926747757).abs() < 1e-8); + } +} diff --git a/crates-device/rstsr-accelerate/tests/test_driver_impl/lapack_solve_f64.rs b/crates-device/rstsr-accelerate/tests/test_driver_impl/lapack_solve_f64.rs new file mode 100644 index 00000000..2d2a0667 --- /dev/null +++ b/crates-device/rstsr-accelerate/tests/test_driver_impl/lapack_solve_f64.rs @@ -0,0 +1,95 @@ +use rstsr_blas_traits::lapack_solve::*; +use rstsr_core::prelude::*; +use rstsr_core::prelude_dev::fingerprint; +use rstsr_accelerate::DeviceAccelerate as DeviceBLAS; +use rstsr_test_manifest::get_vec; + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_dgesv() { + let device = DeviceBLAS::default(); + let a = rt::asarray((get_vec::('a'), [1024, 1024].c(), &device)).into_dim::(); + let b_vec = &get_vec::('b')[..1024 * 512]; + let b = rt::asarray((b_vec, [1024, 512].c(), &device)).into_dim::(); + + // default + let driver = DGESV::default().a(a.view()).b(b.view()).build().unwrap(); + let (lu, piv, x) = driver.run().unwrap(); + let lu = lu.into_owned(); + let x = x.into_owned(); + let fpiv = piv.map(|&v| v as f64); + assert!((fingerprint(&lu) - 5397.198541468395).abs() < 1e-8); + assert!((fingerprint(&fpiv) - -14.694714160751573).abs() < 1e-8); + assert!((fingerprint(&x) - -1951.253447757597).abs() < 1e-8); + } + + #[test] + fn test_dgetrf_dgetri() { + let device = DeviceBLAS::default(); + let a = rt::asarray((get_vec::('a'), [1024, 1024].c(), &device)).into_dim::(); + + // default + let driver = DGETRF::default().a(a.view()).build().unwrap(); + let (lu, piv) = driver.run().unwrap(); + let lu = lu.into_owned(); + let fpiv = piv.map(|&v| v as f64); + assert!((fingerprint(&lu) - 5397.198541468395).abs() < 1e-8); + assert!((fingerprint(&fpiv) - -14.694714160751573).abs() < 1e-8); + + let driver = DGETRI::default().a(lu.view()).ipiv(piv.view()).build().unwrap(); + let inv_a = driver.run().unwrap(); + let inv_a = inv_a.into_owned(); + assert!((fingerprint(&inv_a) - 143.3900557703788).abs() < 1e-8); + } + + #[test] + fn test_dpotrf() { + let device = DeviceBLAS::default(); + let b = rt::asarray((get_vec::('b'), [1024, 1024].c(), &device)).into_dim::(); + + // default + let driver = DPOTRF::default().a(b.view()).build().unwrap(); + let c = driver.run().unwrap(); + let c = c.into_owned(); + println!("fingerprint {:?}", fingerprint(&c)); + assert!((fingerprint(&c) - 35.17266259472725).abs() < 1e-8); + + // upper + let driver = DPOTRF::default().a(b.view()).uplo(Upper).build().unwrap(); + let c = driver.run().unwrap(); + let c = c.into_owned(); + println!("fingerprint {:?}", fingerprint(&c)); + assert!((fingerprint(&c) - -53.53353704132017).abs() < 1e-8); + } + + #[test] + fn test_dsysv() { + let 
device = DeviceBLAS::default(); + let a = rt::asarray((get_vec::('a'), [1024, 1024].c(), &device)).into_dim::(); + let b_vec = &get_vec::('b')[..1024 * 512]; + let b = rt::asarray((b_vec, [1024, 512].c(), &device)).into_dim::(); + + // default + let driver = DSYSV::default().a(a.view()).b(b.view()).build().unwrap(); + let (udut, piv, x) = driver.run().unwrap(); + let udut = udut.into_owned(); + let x = x.into_owned(); + let fpiv = piv.map(|&v| v as f64); + assert!((fingerprint(&udut) - -1201.6472395568974).abs() < 1e-8); + assert!((fingerprint(&fpiv) - -16668.7094872639).abs() < 1e-8); + assert!((fingerprint(&x) - -397.12032355166446).abs() < 1e-8); + + // upper + let driver = DSYSV::default().a(a.view()).b(b.view()).uplo(Upper).build().unwrap(); + let (udut, piv, x) = driver.run().unwrap(); + let udut = udut.into_owned(); + let x = x.into_owned(); + let fpiv = piv.map(|&v| v as f64); + assert!((fingerprint(&udut) - 1182.7836118324408).abs() < 1e-8); + assert!((fingerprint(&fpiv) - 11905.503011559245).abs() < 1e-8); + assert!((fingerprint(&x) - -314.4502289190444).abs() < 1e-8); + } +} diff --git a/crates-device/rstsr-accelerate/tests/test_driver_impl/lapack_svd_f64.rs b/crates-device/rstsr-accelerate/tests/test_driver_impl/lapack_svd_f64.rs new file mode 100644 index 00000000..0a4e5bc7 --- /dev/null +++ b/crates-device/rstsr-accelerate/tests/test_driver_impl/lapack_svd_f64.rs @@ -0,0 +1,80 @@ +use rstsr_blas_traits::lapack_svd::*; +use rstsr_core::prelude::*; +use rstsr_core::prelude_dev::fingerprint; +use rstsr_accelerate::DeviceAccelerate as DeviceBLAS; +use rstsr_test_manifest::get_vec; + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_dgesvd() { + let device = DeviceBLAS::default(); + let a_vec = &get_vec::('a')[..1024 * 512]; + let a = rt::asarray((a_vec, [1024, 512].c(), &device)).into_dim::(); + + // default + let driver = DGESVD::default().a(a.view()).build().unwrap(); + if let (s, Some(u), Some(vt), _) = driver.run().unwrap() { + assert!((fingerprint(&s) - 33.969339071043095).abs() < 1e-8); + assert!((fingerprint(&u.abs()) - -1.9368850983570982).abs() < 1e-8); + assert!((fingerprint(&vt.abs()) - 13.465522484136157).abs() < 1e-8); + } else { + panic!("DGESVD did not return expected output"); + } + + // full_matrices = false + let driver = DGESVD::default().a(a.view()).full_matrices(false).build().unwrap(); + if let (s, Some(u), Some(vt), _) = driver.run().unwrap() { + assert!((fingerprint(&s) - 33.969339071043095).abs() < 1e-8); + assert!((fingerprint(&u.abs()) - -9.144981428076894).abs() < 1e-8); + assert!((fingerprint(&vt.abs()) - 13.465522484136157).abs() < 1e-8); + } else { + panic!("DGESVD did not return expected output"); + } + + // full_matrices = false, compute_uv = false + let driver = DGESVD::default().a(a.view()).full_matrices(false).compute_uv(false).build().unwrap(); + if let (s, None, None, _) = driver.run().unwrap() { + assert!((fingerprint(&s) - 33.969339071043095).abs() < 1e-8); + } else { + panic!("DGESVD did not return expected output"); + } + } + + #[test] + fn test_dgesdd() { + let device = DeviceBLAS::default(); + let a_vec = &get_vec::('a')[..1024 * 512]; + let a = rt::asarray((a_vec, [1024, 512].c(), &device)).into_dim::(); + + // default + let driver = DGESDD::default().a(a.view()).build().unwrap(); + if let (s, Some(u), Some(vt)) = driver.run().unwrap() { + assert!((fingerprint(&s) - 33.969339071043095).abs() < 1e-8); + assert!((fingerprint(&u.abs()) - -1.9368850983570982).abs() < 1e-8); + assert!((fingerprint(&vt.abs()) - 
+            assert!((fingerprint(&vt.abs()) - 13.465522484136157).abs() < 1e-8);
+        } else {
+            panic!("DGESDD did not return expected output");
+        }
+
+        // full_matrices = false
+        let driver = DGESDD::default().a(a.view()).full_matrices(false).build().unwrap();
+        if let (s, Some(u), Some(vt)) = driver.run().unwrap() {
+            assert!((fingerprint(&s) - 33.969339071043095).abs() < 1e-8);
+            assert!((fingerprint(&u.abs()) - -9.144981428076894).abs() < 1e-8);
+            assert!((fingerprint(&vt.abs()) - 13.465522484136157).abs() < 1e-8);
+        } else {
+            panic!("DGESDD did not return expected output");
+        }
+
+        // full_matrices = false, compute_uv = false
+        let driver = DGESDD::default().a(a.view()).full_matrices(false).compute_uv(false).build().unwrap();
+        if let (s, None, None) = driver.run().unwrap() {
+            assert!((fingerprint(&s) - 33.969339071043095).abs() < 1e-8);
+        } else {
+            panic!("DGESDD did not return expected output");
+        }
+    }
+}
diff --git a/crates-device/rstsr-accelerate/tests/test_driver_impl/mod.rs b/crates-device/rstsr-accelerate/tests/test_driver_impl/mod.rs
new file mode 100644
index 00000000..76ddc142
--- /dev/null
+++ b/crates-device/rstsr-accelerate/tests/test_driver_impl/mod.rs
@@ -0,0 +1,3 @@
+mod lapack_eigh_f64;
+mod lapack_solve_f64;
+mod lapack_svd_f64;
diff --git a/crates-device/rstsr-accelerate/tests/test_linalg_func/func_c64.rs b/crates-device/rstsr-accelerate/tests/test_linalg_func/func_c64.rs
new file mode 100644
index 00000000..65bc2248
--- /dev/null
+++ b/crates-device/rstsr-accelerate/tests/test_linalg_func/func_c64.rs
@@ -0,0 +1,319 @@
+use rstsr::prelude::*;
+use rstsr_core::prelude_dev::fingerprint;
+use rstsr_accelerate::DeviceAccelerate as DeviceBLAS;
+use rstsr_test_manifest::get_vec;
+
+#[allow(non_camel_case_types)]
+type c64 = num::Complex<f64>;
+
+macro_rules! c64 {
+    ($real:expr, $imag:expr) => {
+        c64::new($real, $imag)
+    };
+    ($real:expr) => {
+        c64::new($real, 0.0)
+    };
+}
+
+#[cfg(test)]
+mod test {
+    use super::*;
+
+    #[test]
+    fn test_cholesky() {
+        let device = DeviceBLAS::default();
+        let mut b = rt::asarray((get_vec::<c64>('b'), [1024, 1024].c(), &device));
+
+        // default
+        let c = rt::linalg::cholesky(b.view());
+        assert!((fingerprint(&c) - c64!(62.89494065393874, -73.47055443374522)).norm() < 1e-8);
+
+        // upper
+        let c = rt::linalg::cholesky((b.view(), Upper));
+        assert!((fingerprint(&c) - c64!(13.720509103165073, -1.8066465348490963)).norm() < 1e-8);
+
+        // mutable changes itself
+        rt::linalg::cholesky((b.view_mut(), Upper));
+        assert!((fingerprint(&b) - c64!(13.720509103165073, -1.8066465348490963)).norm() < 1e-8);
+    }
+
+    #[test]
+    fn test_det() {
+        let device = DeviceBLAS::default();
+        let a_vec = get_vec::<c64>('a')[..5 * 5].to_vec();
+        let mut a = rt::asarray((a_vec, [5, 5].c(), &device));
+
+        let det = rt::linalg::det(a.view_mut());
+        assert!((det - c64!(-24.808965756481086, 11.800248863799464)).norm() < 1e-8);
+    }
+
+    #[test]
+    fn test_eigh() {
+        let device = DeviceBLAS::default();
+        let mut a = rt::asarray((get_vec::<c64>('a'), [1024, 1024].c(), &device));
+        let b = rt::asarray((get_vec::<c64>('b'), [1024, 1024].c(), &device));
+
+        // default, a
+        let (w, v) = rt::linalg::eigh(a.view()).into();
+        assert!((fingerprint(&w) - -100.79793355894122).abs() < 1e-8);
+        assert!((fingerprint(&v.abs()) - -7.450761195788254).abs() < 1e-8);
+
+        // upper, a
+        let (w, v) = rt::linalg::eigh((a.view(), Upper)).into();
+        assert!((fingerprint(&w) - -103.99103522434956).abs() < 1e-8);
+        assert!((fingerprint(&v.abs()) - -12.184946930165328).abs() < 1e-8);
+
+        // default, a b
+        let (w, v) = rt::linalg::eigh((a.view(), b.view())).into();
+        assert!((fingerprint(&w) - -97.43376763322635).abs() < 1e-8);
+        assert!((fingerprint(&v.abs()) - -4.3181177983574255).abs() < 1e-8);
+
+        // upper, a b, itype=3
+        let (w, v) = rt::linalg::eigh((a.view(), b.view(), Upper, 3)).into();
+        assert!((fingerprint(&w) - -4656.824753078057).abs() < 1e-8);
+        assert!((fingerprint(&v.abs()) - -0.15861903557045487).abs() < 1e-8);
+
+        // mutable changes a
+        let (w, _) = rt::linalg::eigh(a.view_mut()).into();
+        assert!((fingerprint(&w) - -100.79793355894122).abs() < 1e-8);
+        assert!((fingerprint(&a.abs()) - -7.450761195788254).abs() < 1e-8);
+    }
+
+    #[test]
+    fn test_eigvalsh() {
+        let device = DeviceBLAS::default();
+        let mut a = rt::asarray((get_vec::<c64>('a'), [1024, 1024].c(), &device));
+        let b = rt::asarray((get_vec::<c64>('b'), [1024, 1024].c(), &device));
+
+        // default, a
+        let w = rt::linalg::eigvalsh(a.view());
+        assert!((fingerprint(&w) - -100.79793355894122).abs() < 1e-8);
+
+        // upper, a
+        let w = rt::linalg::eigvalsh((a.view(), Upper));
+        assert!((fingerprint(&w) - -103.99103522434956).abs() < 1e-8);
+
+        // default, a b
+        let w = rt::linalg::eigvalsh((a.view(), b.view()));
+        assert!((fingerprint(&w) - -97.43376763322635).abs() < 1e-8);
+
+        // upper, a b, itype=3
+        let w = rt::linalg::eigvalsh((a.view(), b.view(), Upper, 3));
+        assert!((fingerprint(&w) - -4656.824753078057).abs() < 1e-8);
+
+        // mutable changes a
+        let w = rt::linalg::eigvalsh(a.view_mut());
+        assert!((fingerprint(&w) - -100.79793355894122).abs() < 1e-8);
+    }
+
+    #[test]
+    fn test_inv() {
+        let device = DeviceBLAS::default();
+        let mut a = rt::asarray((get_vec::<c64>('a'), [1024, 1024].c(), &device));
+
+        // immutable
+        let a_inv = rt::linalg::inv(a.view());
+        assert!((fingerprint(&a_inv) - c64!(-11.836382515156183, 8.250167298349842)).norm() < 1e-8);
+
+        // mutable
+        rt::linalg::inv(a.view_mut());
+        assert!((fingerprint(&a) - c64!(-11.836382515156183, 8.250167298349842)).norm() < 1e-8);
+    }
+
+    #[test]
+    fn test_pinv() {
+        let device = DeviceBLAS::default();
+
+        // 1024 x 512
+        let a_vec = get_vec::<c64>('a')[..1024 * 512].to_vec();
+        let a = rt::asarray((a_vec, [1024, 512].c(), &device)).into_dim::<Ix2>();
+
+        let (a_pinv, rank) = rt::linalg::pinv((a.view(), 20.0, 0.3)).into();
+        assert!((fingerprint(&a_pinv) - c64!(-0.03454885412959018, -0.023651876085623254)).norm() < 1e-8);
+        assert_eq!(rank, 240);
+
+        // 512 x 1024
+        let a_vec = get_vec::<c64>('a')[..1024 * 512].to_vec();
+        let a = rt::asarray((a_vec, [512, 1024].c(), &device)).into_dim::<Ix2>();
+
+        let (a_pinv, rank) = rt::linalg::pinv((a.view(), 20.0, 0.3)).into();
+        assert!((fingerprint(&a_pinv) - c64!(-0.2814806469687325, -0.15198888300458474)).norm() < 1e-8);
+        assert_eq!(rank, 240);
+    }
+
+    #[test]
+    fn test_slogdet() {
+        let device = DeviceBLAS::default();
+        let mut a = rt::asarray((get_vec::<c64>('a'), [1024, 1024].c(), &device));
+
+        let (sign, logabsdet) = rt::linalg::slogdet(a.view_mut()).into();
+        assert!((sign - c64!(-0.44606842323663365, 0.8949988613351316)).norm() < 1e-8);
+        assert!(logabsdet - 3393.6720579594585 < 1e-8);
+    }
+
+    #[test]
+    fn test_solve_general() {
+        let device = DeviceBLAS::default();
+        let mut a = rt::asarray((get_vec::<c64>('a'), [1024, 1024].c(), &device)).into_dim::<Ix2>();
+        let b_vec = get_vec::<c64>('b')[..1024 * 512].to_vec();
+        let mut b = rt::asarray((b_vec, [1024, 512].c(), &device)).into_dim::<Ix2>();
+
+        // default
+        let x = rt::linalg::solve_general((a.view(), b.view()));
+        assert!((fingerprint(&x) - c64!(404.1900761036138, -258.5602505551204)).norm() < 1e-8);
+
+        // mutable changes itself
+        rt::linalg::solve_general((a.view_mut(), b.view_mut()));
+        assert!((fingerprint(&b) - c64!(404.1900761036138, -258.5602505551204)).norm() < 1e-8);
+    }
+
+    #[test]
+    fn test_solve_general_for_vec() {
+        let device = DeviceBLAS::default();
+        let mut a = rt::asarray((get_vec::<c64>('a'), [1024, 1024].c(), &device)).into_dim::<Ix2>();
+        let b_vec = get_vec::<c64>('b')[..1024].to_vec();
+        let mut b = rt::asarray((b_vec, [1024].c(), &device)).into_dim::<Ix1>();
+
+        // default
+        let x = rt::linalg::solve_general((a.view(), b.view()));
+        assert!((fingerprint(&x) - c64!(-15.070310793269726, -1.987917054716041)).norm() < 1e-8);
+
+        // mutable changes itself
+        rt::linalg::solve_general((a.view_mut(), b.view_mut()));
+        assert!((fingerprint(&b) - c64!(-15.070310793269726, -1.987917054716041)).norm() < 1e-8);
+    }
+
+    #[test]
+    fn test_solve_symmetric() {
+        let device = DeviceBLAS::default();
+        let a = rt::asarray((get_vec::<c64>('a'), [1024, 1024].c(), &device)).into_dim::<Ix2>();
+        let b_vec = get_vec::<c64>('b')[..1024 * 512].to_vec();
+        let mut b = rt::asarray((b_vec, [1024, 512].c(), &device)).into_dim::<Ix2>();
+
+        // default (hermi)
+        let x = rt::linalg::solve_symmetric((a.view(), b.view()));
+        assert!((fingerprint(&x) - c64!(-1053.7242100144504, -559.2846004618166)).norm() < 1e-8);
+
+        // upper (hermi)
+        let x = rt::linalg::solve_symmetric((a.view(), b.view(), Upper));
+        assert!((fingerprint(&x) - c64!(674.2725854112028, -68.55236080351166)).norm() < 1e-8);
+
+        // default (symm)
+        let x = rt::linalg::solve_symmetric((a.view(), b.view(), false));
+        assert!((fingerprint(&x) - c64!(401.05642312535775, -805.8028453625365)).norm() < 1e-8);
+
+        // upper, mutable changes b (symm)
+        rt::linalg::solve_symmetric((a.view(), b.view_mut(), false, Upper));
+        assert!((fingerprint(&b) - c64!(141.70122084637046, -829.609691493499)).norm() < 1e-8);
+    }
+
+    #[test]
+    fn test_solve_triangular() {
+        let device = DeviceBLAS::default();
+        let a_vec = get_vec::<c64>('a')[..1024 * 512].to_vec();
+        let mut a = rt::asarray((a_vec, [1024, 512].c(), &device)).into_dim::<Ix2>();
+        let b = rt::asarray((get_vec::<c64>('b'), [1024, 1024].c(), &device)).into_dim::<Ix2>();
+
+        // default
+        let x = rt::linalg::solve_triangular((b.view(), a.view()));
+        assert!((fingerprint(&x) - c64!(-8.433708003916948, 20.578272827017052)).norm() < 1e-8);
+
+        // upper, mutable changes a
+        rt::linalg::solve_triangular((b.view(), a.view_mut(), Upper));
+        assert!((fingerprint(&a) - c64!(0.1778922244846507, 11.42463765128442)).norm() < 1e-8);
+    }
+
+    #[test]
+    fn test_svd() {
+        let device = DeviceBLAS::default();
+        let a_vec = get_vec::<c64>('a')[..1024 * 512].to_vec();
+        let a = rt::asarray((a_vec, [1024, 512].c(), &device)).into_dim::<Ix2>();
+
+        // default
+        let (u, s, vt) = rt::linalg::svd(a.view()).into();
+        assert!((fingerprint(&s) - 46.60343405921802).abs() < 1e-8);
+        assert!((fingerprint(&u.abs()) - -15.44133470545584).abs() < 1e-8);
+        assert!((fingerprint(&vt.abs()) - 2.1605324161714172).abs() < 1e-8);
+
+        // full_matrices = false
+        let (u, s, vt) = rt::linalg::svd((a.view(), false)).into();
+        assert!((fingerprint(&s) - 46.60343405921802).abs() < 1e-8);
+        assert!((fingerprint(&u.abs()) - -1.9516528722381659).abs() < 1e-8);
+        assert!((fingerprint(&vt.abs()) - 2.1605324161714172).abs() < 1e-8);
+
+        // m < n, full_matrices = false
+        let a_vec = get_vec::<c64>('a')[..1024 * 512].to_vec();
+        let a = rt::asarray((a_vec, [512, 1024].c(), &device)).into_dim::<Ix2>();
+        let (u, s, vt) = rt::linalg::svd((a.view(), false)).into();
+        assert!((fingerprint(&s) - 47.599274835886646).abs() < 1e-8);
+        assert!((fingerprint(&u.abs()) - 4.636614351700778).abs() < 1e-8);
+        assert!((fingerprint(&vt.abs()) - 1.4497879458575658).abs() < 1e-8);
+    }
+
+    #[test]
+    fn test_svdvals() {
+        let device = DeviceBLAS::default();
+        let a_vec = get_vec::<c64>('a')[..1024 * 512].to_vec();
+        let a = rt::asarray((a_vec, [1024, 512].c(), &device)).into_dim::<Ix2>();
+
+        // default
+        let s = rt::linalg::svdvals(a.view());
+        assert!((fingerprint(&s) - 46.60343405921802).abs() < 1e-8);
+
+        // m < n, full_matrices = false
+        let a_vec = get_vec::<c64>('a')[..1024 * 512].to_vec();
+        let a = rt::asarray((a_vec, [512, 1024].c(), &device)).into_dim::<Ix2>();
+        let s = rt::linalg::svdvals(a.view());
+        assert!((fingerprint(&s) - 47.599274835886646).abs() < 1e-8);
+    }
+}
+
+#[cfg(test)]
+mod test_generalized_eigh {
+    use super::*;
+
+    #[test]
+    fn test_generalized_eigh() {
+        let device = DeviceBLAS::default();
+        let a_vec = get_vec::<c64>('a')[..1024 * 1024].to_vec();
+        let b_vec = get_vec::<c64>('b')[..1024 * 1024].to_vec();
+        let a = rt::asarray((a_vec, [1024, 1024].c(), &device)).into_dim::<Ix2>();
+        let b = rt::asarray((b_vec, [1024, 1024].c(), &device)).into_dim::<Ix2>();
+
+        // 1, lower
+        let (w, v) = rt::linalg::eigh((a.view(), b.view(), Lower, 1)).into();
+        println!("w: {:?}, v: {:?}", fingerprint(&w), fingerprint(&v.abs()));
+        assert!((fingerprint(&w) - -97.43376763322635).abs() < 1e-8);
+        assert!((fingerprint(&v.abs()) - -4.3181177983574255).abs() < 1e-8);
+
+        // 1, upper
+        let (w, v) = rt::linalg::eigh((a.view(), b.view(), Upper, 1)).into();
+        println!("w: {:?}, v: {:?}", fingerprint(&w), fingerprint(&v.abs()));
+        assert!((fingerprint(&w) - -54.81859256480441).abs() < 1e-8);
+        assert!((fingerprint(&v.abs()) - -1.4841788446757156).abs() < 1e-8);
+
+        // 2, lower
+        let (w, v) = rt::linalg::eigh((a.view(), b.view(), Lower, 2)).into();
+        println!("w: {:?}, v: {:?}", fingerprint(&w), fingerprint(&v.abs()));
+        assert!((fingerprint(&w) - -4967.627482507203).abs() < 1e-8);
+        assert!((fingerprint(&v.abs()) - 5.541034627252399).abs() < 1e-8);
+
+        // 2, upper
+        let (w, v) = rt::linalg::eigh((a.view(), b.view(), Upper, 2)).into();
+        println!("w: {:?}, v: {:?}", fingerprint(&w), fingerprint(&v.abs()));
+        assert!((fingerprint(&w) - -4656.824753078057).abs() < 1e-8);
+        assert!((fingerprint(&v.abs()) - 1.0609263552377188).abs() < 1e-8);
+
+        // 3, lower
+        let (w, v) = rt::linalg::eigh((a.view(), b.view(), Lower, 3)).into();
+        println!("w: {:?}, v: {:?}", fingerprint(&w), fingerprint(&v.abs()));
+        assert!((fingerprint(&w) - -4967.627482507203).abs() < 1e-8);
+        assert!((fingerprint(&v.abs()) - 118.76501084045631).abs() < 1e-8);
+
+        // 3, upper
+        let (w, v) = rt::linalg::eigh((a.view(), b.view(), Upper, 3)).into();
+        println!("w: {:?}, v: {:?}", fingerprint(&w), fingerprint(&v.abs()));
+        assert!((fingerprint(&w) - -4656.824753078057).abs() < 1e-8);
+        assert!((fingerprint(&v.abs()) - -0.15861903557045487).abs() < 1e-8);
+    }
+}
diff --git a/crates-device/rstsr-accelerate/tests/test_linalg_func/func_f64.rs b/crates-device/rstsr-accelerate/tests/test_linalg_func/func_f64.rs
new file mode 100644
index 00000000..5a8567c2
--- /dev/null
+++ b/crates-device/rstsr-accelerate/tests/test_linalg_func/func_f64.rs
@@ -0,0 +1,260 @@
+use rstsr::prelude::*;
+use rstsr_core::prelude_dev::fingerprint;
+use rstsr_accelerate::DeviceAccelerate as DeviceBLAS;
+use rstsr_test_manifest::get_vec;
+
+#[cfg(test)]
+mod test {
+    use super::*;
+
+    #[test]
+    fn test_cholesky() {
+        let device = DeviceBLAS::default();
+        let mut b = rt::asarray((get_vec::<f64>('b'), [1024, 1024].c(), &device));
+
+        // default
+        let c = rt::linalg::cholesky(b.view());
+        assert!((fingerprint(&c) - 43.21904478556176).abs() < 1e-8);
+
+        // upper
+        let c = rt::linalg::cholesky((b.view(), Upper));
+        assert!((fingerprint(&c) - -25.925655124816647).abs() < 1e-8);
+
+        // mutable changes itself
+        rt::linalg::cholesky((b.view_mut(), Upper));
+        assert!((fingerprint(&b) - -25.925655124816647).abs() < 1e-8);
+    }
+
+    #[test]
+    fn test_cholesky_submatrix() {
+        let device = DeviceBLAS::default();
+        let vec_b: Vec<f64> = vec![0.0, 1.0, 2.0, 1.0, 5.0, 1.5, 2.0, 1.5, 8.0];
+        let b = rt::asarray((vec_b, [3, 3].c(), &device));
+
+        let b_view = b.i((1..3, 1..3));
+        let c = rt::linalg::cholesky(b_view);
+        assert!((fingerprint(&c) - -0.7633202592326889).abs() < 1e-8);
+    }
+
+    #[test]
+    fn test_det() {
+        let device = DeviceBLAS::default();
+        let a_vec = get_vec::<f64>('a')[..5 * 5].to_vec();
+        let mut a = rt::asarray((a_vec, [5, 5].c(), &device));
+
+        let det = rt::linalg::det(a.view_mut());
+        assert!((det - 3.9699917597338046).abs() < 1e-8);
+    }
+
+    #[test]
+    fn test_eigh() {
+        let device = DeviceBLAS::default();
+        let mut a = rt::asarray((get_vec::<f64>('a'), [1024, 1024].c(), &device));
+        let b = rt::asarray((get_vec::<f64>('b'), [1024, 1024].c(), &device));
+
+        // default, a
+        let (w, v) = rt::linalg::eigh(a.view()).into();
+        assert!((fingerprint(&w) - -71.4747209499407).abs() < 1e-8);
+        assert!((fingerprint(&v.abs()) - -9.903934930318247).abs() < 1e-8);
+
+        // upper, a
+        let (w, v) = rt::linalg::eigh((a.view(), Upper)).into();
+        assert!((fingerprint(&w) - -71.4902453763506).abs() < 1e-8);
+        assert!((fingerprint(&v.abs()) - 6.973792268793419).abs() < 1e-8);
+
+        // default, a b
+        let (w, v) = rt::linalg::eigh((a.view(), b.view())).into();
+        assert!((fingerprint(&w) - -89.60433120129908).abs() < 1e-8);
+        assert!((fingerprint(&v.abs()) - -5.243112559130817).abs() < 1e-8);
+
+        // upper, a b, itype=3
+        let (w, v) = rt::linalg::eigh((a.view(), b.view(), Upper, 3)).into();
+        assert!((fingerprint(&w) - -2503.84161931662).abs() < 1e-8);
+        assert!((fingerprint(&v.abs()) - 152.17700520642055).abs() < 1e-8);
+
+        // mutable changes a
+        let (w, _) = rt::linalg::eigh(a.view_mut()).into();
+        assert!((fingerprint(&w) - -71.4747209499407).abs() < 1e-8);
+        assert!((fingerprint(&a.abs()) - -9.903934930318247).abs() < 1e-8);
+    }
+
+    #[test]
+    fn test_eigvalsh() {
+        let device = DeviceBLAS::default();
+        let mut a = rt::asarray((get_vec::<f64>('a'), [1024, 1024].c(), &device));
+        let b = rt::asarray((get_vec::<f64>('b'), [1024, 1024].c(), &device));
+
+        // default, a
+        let w = rt::linalg::eigvalsh(a.view());
+        assert!((fingerprint(&w) - -71.4747209499407).abs() < 1e-8);
+
+        // upper, a
+        let w = rt::linalg::eigvalsh((a.view(), Upper));
+        assert!((fingerprint(&w) - -71.4902453763506).abs() < 1e-8);
+
+        // default, a b
+        let w = rt::linalg::eigvalsh((a.view(), b.view()));
+        assert!((fingerprint(&w) - -89.60433120129908).abs() < 1e-8);
+
+        // upper, a b, itype=3
+        let w = rt::linalg::eigvalsh((a.view(), b.view(), Upper, 3));
+        assert!((fingerprint(&w) - -2503.84161931662).abs() < 1e-8);
+
+        // mutable changes a
+        let w = rt::linalg::eigvalsh(a.view_mut());
+        assert!((fingerprint(&w) - -71.4747209499407).abs() < 1e-8);
+    }
+
+    #[test]
+    fn test_inv() {
+        let device = DeviceBLAS::default();
+        let mut a = rt::asarray((get_vec::<f64>('a'), [1024, 1024].c(), &device));
+
+        // immutable
+        let a_inv = rt::linalg::inv(a.view());
+        assert!((fingerprint(&a_inv) - 143.39005577037764).abs() < 1e-8);
+
+        // mutable
+        rt::linalg::inv(a.view_mut());
+        assert!((fingerprint(&a) - 143.39005577037764).abs() < 1e-8);
+    }
+
+    #[test]
+    fn test_pinv() {
+        let device = DeviceBLAS::default();
+
+        // 1024 x 512
+        let a_vec = get_vec::<f64>('a')[..1024 * 512].to_vec();
+        let a = rt::asarray((a_vec, [1024, 512].c(), &device)).into_dim::<Ix2>();
+
+        let (a_pinv, rank) = rt::linalg::pinv((a.view(), 20.0, 0.3)).into();
+        assert!((fingerprint(&a_pinv) - 0.0878262837784408).abs() < 1e-8);
+        assert_eq!(rank, 163);
+
+        // 512 x 1024
+        let a_vec = get_vec::<f64>('a')[..1024 * 512].to_vec();
+        let a = rt::asarray((a_vec, [512, 1024].c(), &device)).into_dim::<Ix2>();
+
+        let (a_pinv, rank) = rt::linalg::pinv((a.view(), 20.0, 0.3)).into();
+        assert!((fingerprint(&a_pinv) - -0.3244041253699862).abs() < 1e-8);
+        assert_eq!(rank, 161);
+    }
+
+    #[test]
+    fn test_slogdet() {
+        let device = DeviceBLAS::default();
+        let mut a = rt::asarray((get_vec::<f64>('a'), [1024, 1024].c(), &device));
+
+        let (sign, logabsdet) = rt::linalg::slogdet(a.view_mut()).into();
+        assert!(sign - -1.0 < 1e-8);
+        assert!(logabsdet - 3031.1259211802403 < 1e-8);
+    }
+
+    #[test]
+    fn test_solve_general() {
+        let device = DeviceBLAS::default();
+        let mut a = rt::asarray((get_vec::<f64>('a'), [1024, 1024].c(), &device)).into_dim::<Ix2>();
+        let b_vec = get_vec::<f64>('b')[..1024 * 512].to_vec();
+        let mut b = rt::asarray((b_vec, [1024, 512].c(), &device)).into_dim::<Ix2>();
+
+        // default
+        let x = rt::linalg::solve_general((a.view(), b.view()));
+        assert!((fingerprint(&x) - -1951.253447757597).abs() < 1e-8);
+
+        // mutable changes itself
+        rt::linalg::solve_general((a.view_mut(), b.view_mut()));
+        assert!((fingerprint(&b) - -1951.253447757597).abs() < 1e-8);
+    }
+
+    #[test]
+    fn test_solve_general_for_vec() {
+        let device = DeviceBLAS::default();
+        let mut a = rt::asarray((get_vec::<f64>('a'), [1024, 1024].c(), &device)).into_dim::<Ix2>();
+        let b_vec = get_vec::<f64>('b')[..1024].to_vec();
+        let mut b = rt::asarray((b_vec, [1024].c(), &device)).into_dim::<Ix1>();
+
+        // default
+        let x = rt::linalg::solve_general((a.view(), b.view()));
+        assert!((fingerprint(&x) - -9.120066438800688).abs() < 1e-8);
+
+        // mutable changes itself
+        rt::linalg::solve_general((a.view_mut(), b.view_mut()));
+        assert!((fingerprint(&b) - -9.120066438800688).abs() < 1e-8);
+    }
+
+    #[test]
+    fn test_solve_symmetric() {
+        let device = DeviceBLAS::default();
+        let a = rt::asarray((get_vec::<f64>('a'), [1024, 1024].c(), &device)).into_dim::<Ix2>();
+        let b_vec = get_vec::<f64>('b')[..1024 * 512].to_vec();
+        let mut b = rt::asarray((b_vec, [1024, 512].c(), &device)).into_dim::<Ix2>();
+
+        // default
+        let x = rt::linalg::solve_symmetric((a.view(), b.view()));
+        assert!((fingerprint(&x) - -397.1203235513806).abs() < 1e-8);
+
+        // upper, mutable changes b
+        rt::linalg::solve_symmetric((a.view(), b.view_mut(), Upper));
+        assert!((fingerprint(&b) - -314.45022891879034).abs() < 1e-8);
+    }
+
+    #[test]
+    fn test_solve_triangular() {
+        let device = DeviceBLAS::default();
+        let a_vec = get_vec::<f64>('a')[..1024 * 512].to_vec();
+        let mut a = rt::asarray((a_vec, [1024, 512].c(), &device)).into_dim::<Ix2>();
+        let b = rt::asarray((get_vec::<f64>('b'), [1024, 1024].c(), &device)).into_dim::<Ix2>();
+
+        // default
+        let x = rt::linalg::solve_triangular((b.view(), a.view()));
+        assert!((fingerprint(&x) - -2.6133848012216587).abs() < 1e-8);
+
+        // upper, mutable changes a
+        rt::linalg::solve_triangular((b.view(), a.view_mut(), Upper));
+        assert!((fingerprint(&a) - 5.112256818100785).abs() < 1e-8);
+    }
+
+    #[test]
+    fn test_svd() {
+        let device = DeviceBLAS::default();
+        let a_vec = get_vec::<f64>('a')[..1024 * 512].to_vec();
+        let a = rt::asarray((a_vec, [1024, 512].c(), &device)).into_dim::<Ix2>();
+
+        // default
+        let (u, s, vt) = rt::linalg::svd(a.view()).into();
+        assert!((fingerprint(&s) - 33.969339071043095).abs() < 1e-8);
+        assert!((fingerprint(&u.abs()) - -1.9368850983570982).abs() < 1e-8);
+        assert!((fingerprint(&vt.abs()) - 13.465522484136157).abs() < 1e-8);
+
+        // full_matrices = false
+        let (u, s, vt) = rt::linalg::svd((a.view(), false)).into();
+        assert!((fingerprint(&s) - 33.969339071043095).abs() < 1e-8);
+        assert!((fingerprint(&u.abs()) - -9.144981428076894).abs() < 1e-8);
+        assert!((fingerprint(&vt.abs()) - 13.465522484136157).abs() < 1e-8);
+
+        // m < n, full_matrices = false
+        let a_vec = get_vec::<f64>('a')[..1024 * 512].to_vec();
+        let a = rt::asarray((a_vec, [512, 1024].c(), &device)).into_dim::<Ix2>();
+        let (u, s, vt) = rt::linalg::svd((a.view(), false)).into();
+        assert!((fingerprint(&s) - 32.27742168207757).abs() < 1e-8);
+        assert!((fingerprint(&u.abs()) - -3.716931052161584).abs() < 1e-8);
+        assert!((fingerprint(&vt.abs()) - -0.32301437281530243).abs() < 1e-8);
+    }
+
+    #[test]
+    fn test_svdvals() {
+        let device = DeviceBLAS::default();
+        let a_vec = get_vec::<f64>('a')[..1024 * 512].to_vec();
+        let a = rt::asarray((a_vec, [1024, 512].c(), &device)).into_dim::<Ix2>();
+
+        // default
+        let s = rt::linalg::svdvals(a.view());
+        assert!((fingerprint(&s) - 33.969339071043095).abs() < 1e-8);
+
+        // m < n, full_matrices = false
+        let a_vec = get_vec::<f64>('a')[..1024 * 512].to_vec();
+        let a = rt::asarray((a_vec, [512, 1024].c(), &device)).into_dim::<Ix2>();
+        let s = rt::linalg::svdvals(a.view());
+        assert!((fingerprint(&s) - 32.27742168207757).abs() < 1e-8);
+    }
+}
diff --git a/crates-device/rstsr-accelerate/tests/test_linalg_func/func_validation_c64.py b/crates-device/rstsr-accelerate/tests/test_linalg_func/func_validation_c64.py
new file mode 100644
index 00000000..5ce42e7d
--- /dev/null
+++ b/crates-device/rstsr-accelerate/tests/test_linalg_func/func_validation_c64.py
@@ -0,0 +1,235 @@
+# # Driver tests in Python
+
+import numpy as np
+import scipy
+
+
+def fingerprint(a):
+    return np.dot(np.cos(np.arange(a.size)), np.asarray(a, order="C").ravel())
+
+
+# Path of npy files in rstsr-test-manifest
+
+root = "../../../../rstsr-test-manifest/resources/"
+
+# ## make sure of random generation
+
+a_raw = np.load(f"{root}/a-c64.npy")
+b_raw = np.load(f"{root}/b-c64.npy")
+
+assert np.isclose(fingerprint(a_raw), 191.28900005102915+217.50386287824938j)
+assert np.isclose(fingerprint(b_raw), 267.6279081341384-641.4397224458443j)
+
+# ## tests
+
+# ### cholesky
+
+b = b_raw.copy().reshape(1024, 1024)
+c = np.linalg.cholesky(b)
+assert np.isclose(fingerprint(c), 62.89494065393874-73.47055443374522j)
+fingerprint(c)
+
+b = b_raw.copy().reshape(1024, 1024)
+c = scipy.linalg.cholesky(b, lower=False)
+assert np.isclose(fingerprint(c), 13.720509103165073-1.8066465348490963j)
+fingerprint(c)
+
+# ### eigh
+
+a = a_raw.copy().reshape(1024, 1024)
+b = b_raw.copy().reshape(1024, 1024)
+w, v = np.linalg.eigh(a)
+assert np.isclose(fingerprint(w), -100.79793355894122)
+assert np.isclose(fingerprint(np.abs(v)), -7.450761195788254)
+fingerprint(w), fingerprint(np.abs(v))
+
+a = a_raw.copy().reshape(1024, 1024)
+b = b_raw.copy().reshape(1024, 1024)
+w, v = np.linalg.eigh(a, UPLO="U")
+assert np.isclose(fingerprint(w), -103.99103522434956)
+assert np.isclose(fingerprint(np.abs(v)), -12.184946930165328)
+fingerprint(w), fingerprint(np.abs(v))
+
+a = a_raw.copy().reshape(1024, 1024)
+b = b_raw.copy().reshape(1024, 1024)
+w, v = scipy.linalg.eigh(a, b, lower=True)
+assert np.isclose(fingerprint(w), -97.43376763322635)
+assert 
np.isclose(fingerprint(np.abs(v)), -4.3181177983574255) +fingerprint(w), fingerprint(np.abs(v)) + +a = a_raw.copy().reshape(1024, 1024) +b = b_raw.copy().reshape(1024, 1024) +w, v = scipy.linalg.eigh(a, b, lower=False, type=3) +assert np.isclose(fingerprint(w), -4656.824753078057) +assert np.isclose(fingerprint(np.abs(v)), -0.15861903557045487) +fingerprint(w), fingerprint(np.abs(v)) + +# ### Tests of eigh (itype, lower) + +# 1, lower +a = a_raw.copy().reshape(1024, 1024) +b = b_raw.copy().reshape(1024, 1024) +w, v = scipy.linalg.eigh(a, b, type=1, lower=True) +assert np.isclose(fingerprint(w), -97.43376763322635) +assert np.isclose(fingerprint(np.abs(v)), -4.3181177983574255) +fingerprint(w), fingerprint(np.abs(v)) + +# 1, upper +a = a_raw.copy().reshape(1024, 1024) +b = b_raw.copy().reshape(1024, 1024) +w, v = scipy.linalg.eigh(a, b, type=1, lower=False) +assert np.isclose(fingerprint(w), -54.81859256480441) +assert np.isclose(fingerprint(np.abs(v)), -1.4841788446757156) +fingerprint(w), fingerprint(np.abs(v)) + +# 2, lower +a = a_raw.copy().reshape(1024, 1024) +b = b_raw.copy().reshape(1024, 1024) +w, v = scipy.linalg.eigh(a, b, type=2, lower=True) +assert np.isclose(fingerprint(w), -4967.627482507203) +assert np.isclose(fingerprint(np.abs(v)), 5.541034627252399) +fingerprint(w), fingerprint(np.abs(v)) + +# 2, upper +a = a_raw.copy().reshape(1024, 1024) +b = b_raw.copy().reshape(1024, 1024) +w, v = scipy.linalg.eigh(a, b, type=2, lower=False) +assert np.isclose(fingerprint(w), -4656.824753078057) +assert np.isclose(fingerprint(np.abs(v)), 1.0609263552377188) +fingerprint(w), fingerprint(np.abs(v)) + +# 3, lower +a = a_raw.copy().reshape(1024, 1024) +b = b_raw.copy().reshape(1024, 1024) +w, v = scipy.linalg.eigh(a, b, type=3, lower=True) +assert np.isclose(fingerprint(w), -4967.627482507203) +assert np.isclose(fingerprint(np.abs(v)), 118.76501084045631) +fingerprint(w), fingerprint(np.abs(v)) + +# 3, upper +a = a_raw.copy().reshape(1024, 1024) +b = b_raw.copy().reshape(1024, 1024) +w, v = scipy.linalg.eigh(a, b, type=3, lower=False) +assert np.isclose(fingerprint(w), -4656.824753078057) +assert np.isclose(fingerprint(np.abs(v)), -0.15861903557045487) +fingerprint(w), fingerprint(np.abs(v)) + +# ### inv + +a = a_raw.copy().reshape(1024, 1024) +a_inv = np.linalg.inv(a) +assert np.isclose(fingerprint(a_inv), -11.836382515156183+8.250167298349842j) +fingerprint(a_inv) + +# ### solve_general + +a = a_raw.copy().reshape(1024, 1024) +b = b_raw[:1024*512].copy().reshape(1024, 512) +x = np.linalg.solve(a, b) +assert np.isclose(fingerprint(x), 404.1900761036138-258.5602505551204j) +fingerprint(x) + +a = a_raw.copy().reshape(1024, 1024) +b = b_raw[:1024].copy().reshape(1024) +x = np.linalg.solve(a, b) +assert np.isclose(fingerprint(x), -15.070310793269726-1.987917054716041j) +fingerprint(x) + +# ### sovle_symmetric + +a = a_raw.copy().reshape(1024, 1024) +b = b_raw[:1024*512].copy().reshape(1024, 512) +x = scipy.linalg.solve(a, b, assume_a="sym", lower=True) +assert np.isclose(fingerprint(x), 401.05642312535775-805.8028453625365j) +fingerprint(x) + +a = a_raw.copy().reshape(1024, 1024) +b = b_raw[:1024*512].copy().reshape(1024, 512) +x = scipy.linalg.solve(a, b, assume_a="sym", lower=False) +assert np.isclose(fingerprint(x), 141.70122084637046-829.609691493499j) +fingerprint(x) + +a = a_raw.copy().reshape(1024, 1024) +b = b_raw[:1024*512].copy().reshape(1024, 512) +x = scipy.linalg.solve(a, b, assume_a="her", lower=True) +assert np.isclose(fingerprint(x), -1053.7242100144504-559.2846004618166j) 
+fingerprint(x) + +a = a_raw.copy().reshape(1024, 1024) +b = b_raw[:1024*512].copy().reshape(1024, 512) +x = scipy.linalg.solve(a, b, assume_a="her", lower=False) +assert np.isclose(fingerprint(x), 674.2725854112028-68.55236080351166j) +fingerprint(x) + +# ### sovle_triangular + +a = a_raw[:1024*512].copy().reshape(1024, 512) +b = b_raw.copy().reshape(1024, 1024) +x = scipy.linalg.solve(b, a, assume_a="lower triangular") +assert np.isclose(fingerprint(x), -8.433708003916948+20.578272827017052j) +fingerprint(x) + +a = a_raw[:1024*512].copy().reshape(1024, 512) +b = b_raw.copy().reshape(1024, 1024) +x = scipy.linalg.solve(b, a, assume_a="upper triangular") +assert np.isclose(fingerprint(x), 0.1778922244846507+11.42463765128442j) +fingerprint(x) + +# ### slogdot + +a = a_raw.copy().reshape(1024, 1024) +sgn, logabsdet = np.linalg.slogdet(a) +assert np.isclose(sgn, -0.44606842323663365+0.8949988613351316j) +assert np.isclose(logabsdet, 3393.6720579594585) +sgn, logabsdet + +# ### det + +a = a_raw[:25].copy().reshape(5, 5) +det = np.linalg.det(a) +assert np.isclose(det, -24.808965756481086+11.800248863799464j) +det + +# ### svd + +a = a_raw[:1024*512].copy().reshape(1024, 512) +(u, s, vt) = scipy.linalg.svd(a) +assert np.isclose(fingerprint(np.abs(u)), -15.44133470545584) +assert np.isclose(fingerprint(s), 46.60343405921802) +assert np.isclose(fingerprint(np.abs(vt)), 2.1605324161714172) +fingerprint(np.abs(u)), fingerprint(s), fingerprint(np.abs(vt)) + +a = a_raw[:1024*512].copy().reshape(1024, 512) +(u, s, vt) = scipy.linalg.svd(a, full_matrices=False) +assert np.isclose(fingerprint(np.abs(u)), -1.9516528722381659) +assert np.isclose(fingerprint(s), 46.60343405921802) +assert np.isclose(fingerprint(np.abs(vt)), 2.1605324161714172) +fingerprint(np.abs(u)), fingerprint(s), fingerprint(np.abs(vt)) + +a = a_raw[:1024*512].copy().reshape(1024, 512) +s = scipy.linalg.svd(a, compute_uv=False) +assert np.isclose(fingerprint(s), 46.60343405921802) +fingerprint(s) + +a = a_raw[:1024*512].copy().reshape(512, 1024) +(u, s, vt) = scipy.linalg.svd(a, full_matrices=False) +assert np.isclose(fingerprint(np.abs(u)), 4.636614351700778) +assert np.isclose(fingerprint(s), 47.599274835886646) +assert np.isclose(fingerprint(np.abs(vt)), 1.4497879458575658) +fingerprint(np.abs(u)), fingerprint(s), fingerprint(np.abs(vt)) + +# ### pinv + +a = a_raw[:1024*512].copy().reshape(1024, 512) +a_pinv, rank = scipy.linalg.pinv(a, return_rank=True, atol=20, rtol=0.3) +assert np.isclose(fingerprint(a_pinv), -0.03454885412959018-0.023651876085623254j) +assert rank == 240 +fingerprint(a_pinv), rank + +a = a_raw[:1024*512].copy().reshape(512, 1024) +a_pinv, rank = scipy.linalg.pinv(a, return_rank=True, atol=20, rtol=0.3) +assert np.isclose(fingerprint(a_pinv), -0.2814806469687325-0.15198888300458474j) +assert rank == 240 +fingerprint(a_pinv), rank + + diff --git a/crates-device/rstsr-accelerate/tests/test_linalg_func/func_validation_f64.py b/crates-device/rstsr-accelerate/tests/test_linalg_func/func_validation_f64.py new file mode 100644 index 00000000..79d88ef4 --- /dev/null +++ b/crates-device/rstsr-accelerate/tests/test_linalg_func/func_validation_f64.py @@ -0,0 +1,173 @@ +# # Driver tests in Python + +import numpy as np +import scipy + + +def fingerprint(a): + return np.dot(np.cos(np.arange(a.size)), np.asarray(a, order="C").ravel()) + + +# Path of npy files in rstsr-test-manifest + +root = "../../../../rstsr-test-manifest/resources/" + +# ## make sure of random generation + +a_raw = np.load(f"{root}/a-f64.npy") +b_raw = 
np.load(f"{root}/b-f64.npy") + +assert np.isclose(fingerprint(a_raw), 191.28900005103065) +assert np.isclose(fingerprint(b_raw), -51.11100342180723) + +# ## tests + +# ### cholesky + +b = b_raw.copy().reshape(1024, 1024) +c = np.linalg.cholesky(b) +assert np.isclose(fingerprint(c), 43.21904478556176) +fingerprint(c) + +b = b_raw.copy().reshape(1024, 1024) +c = scipy.linalg.cholesky(b, lower=False) +assert np.isclose(fingerprint(c), -25.925655124816647) +fingerprint(c) + +# ### eigh + +a = a_raw.copy().reshape(1024, 1024) +b = b_raw.copy().reshape(1024, 1024) +w, v = np.linalg.eigh(a) +assert np.isclose(fingerprint(w), -71.4747209499407) +assert np.isclose(fingerprint(np.abs(v)), -9.903934930318247) +fingerprint(w), fingerprint(np.abs(v)) + +a = a_raw.copy().reshape(1024, 1024) +b = b_raw.copy().reshape(1024, 1024) +w, v = np.linalg.eigh(a, UPLO="U") +assert np.isclose(fingerprint(w), -71.4902453763506) +assert np.isclose(fingerprint(np.abs(v)), 6.973792268793419) +fingerprint(w), fingerprint(np.abs(v)) + +a = a_raw.copy().reshape(1024, 1024) +b = b_raw.copy().reshape(1024, 1024) +w, v = scipy.linalg.eigh(a, b, lower=True) +assert np.isclose(fingerprint(w), -89.60433120129908) +assert np.isclose(fingerprint(np.abs(v)), -5.243112559130817) +fingerprint(w), fingerprint(np.abs(v)) + +a = a_raw.copy().reshape(1024, 1024) +b = b_raw.copy().reshape(1024, 1024) +w, v = scipy.linalg.eigh(a, b, lower=False, type=3) +assert np.isclose(fingerprint(w), -2503.84161931662) +assert np.isclose(fingerprint(np.abs(v)), 152.17700520642055) +fingerprint(w), fingerprint(np.abs(v)) + +# ### inv + +a = a_raw.copy().reshape(1024, 1024) +a_inv = np.linalg.inv(a) +assert np.isclose(fingerprint(a_inv), 143.39005577037764) +fingerprint(a_inv) + +# ### solve_general + +a = a_raw.copy().reshape(1024, 1024) +b = b_raw[:1024*512].copy().reshape(1024, 512) +x = np.linalg.solve(a, b) +fingerprint(x) + +a = a_raw.copy().reshape(1024, 1024) +b = b_raw[:1024].copy().reshape(1024) +x = np.linalg.solve(a, b) +fingerprint(x) + +# ### sovle_symmetric + +a = a_raw.copy().reshape(1024, 1024) +b = b_raw[:1024*512].copy().reshape(1024, 512) +x = scipy.linalg.solve(a, b, assume_a="sym", lower=True) +assert np.isclose(fingerprint(x), -397.1203235513806) +fingerprint(x) + +a = a_raw.copy().reshape(1024, 1024) +b = b_raw[:1024*512].copy().reshape(1024, 512) +x = scipy.linalg.solve(a, b, assume_a="sym", lower=False) +assert np.isclose(fingerprint(x), -314.45022891879034) +fingerprint(x) + +# ### sovle_triangular + +a = a_raw[:1024*512].copy().reshape(1024, 512) +b = b_raw.copy().reshape(1024, 1024) +x = scipy.linalg.solve(b, a, assume_a="lower triangular") +assert np.isclose(fingerprint(x), -2.6133848012216587) +fingerprint(x) + +a = a_raw[:1024*512].copy().reshape(1024, 512) +b = b_raw.copy().reshape(1024, 1024) +x = scipy.linalg.solve(b, a, assume_a="upper triangular") +assert np.isclose(fingerprint(x), 5.112256818100785) +fingerprint(x) + +# ### slogdot + +a = a_raw.copy().reshape(1024, 1024) +sgn, logabsdet = np.linalg.slogdet(a) +assert np.isclose(sgn, -1) +assert np.isclose(logabsdet, 3031.1259211802403) +sgn, logabsdet + +# ### det + +a = a_raw[:25].copy().reshape(5, 5) +det = np.linalg.det(a) +assert np.isclose(det, 3.9699917597338046) +det + +np.linalg.slogdet(a) + +# ### svd + +a = a_raw[:1024*512].copy().reshape(1024, 512) +(u, s, vt) = scipy.linalg.svd(a) +assert np.isclose(fingerprint(np.abs(u)), -1.9368850983570982) +assert np.isclose(fingerprint(s), 33.969339071043095) +assert np.isclose(fingerprint(np.abs(vt)), 
13.465522484136157) +fingerprint(np.abs(u)), fingerprint(s), fingerprint(np.abs(vt)) + +a = a_raw[:1024*512].copy().reshape(1024, 512) +(u, s, vt) = scipy.linalg.svd(a, full_matrices=False) +assert np.isclose(fingerprint(np.abs(u)), -9.144981428076894) +assert np.isclose(fingerprint(s), 33.969339071043095) +assert np.isclose(fingerprint(np.abs(vt)), 13.465522484136157) +fingerprint(np.abs(u)), fingerprint(s), fingerprint(np.abs(vt)) + +a = a_raw[:1024*512].copy().reshape(1024, 512) +s = scipy.linalg.svd(a, compute_uv=False) +assert np.isclose(fingerprint(s), 33.969339071043095) +fingerprint(s) + +a = a_raw[:1024*512].copy().reshape(512, 1024) +(u, s, vt) = scipy.linalg.svd(a, full_matrices=False) +assert np.isclose(fingerprint(np.abs(u)), -3.716931052161584) +assert np.isclose(fingerprint(s), 32.27742168207757) +assert np.isclose(fingerprint(np.abs(vt)), -0.32301437281530243) +fingerprint(np.abs(u)), fingerprint(s), fingerprint(np.abs(vt)) + +# ### pinv + +a = a_raw[:1024*512].copy().reshape(1024, 512) +a_pinv, rank = scipy.linalg.pinv(a, return_rank=True, atol=20, rtol=0.3) +assert np.isclose(fingerprint(a_pinv), 0.0878262837784408) +assert rank == 163 +fingerprint(a_pinv), rank + +a = a_raw[:1024*512].copy().reshape(512, 1024) +a_pinv, rank = scipy.linalg.pinv(a, return_rank=True, atol=20, rtol=0.3) +assert np.isclose(fingerprint(a_pinv), -0.3244041253699862) +assert rank == 161 +fingerprint(a_pinv), rank + + diff --git a/crates-device/rstsr-accelerate/tests/test_linalg_func/mod.rs b/crates-device/rstsr-accelerate/tests/test_linalg_func/mod.rs new file mode 100644 index 00000000..b8f3f0f7 --- /dev/null +++ b/crates-device/rstsr-accelerate/tests/test_linalg_func/mod.rs @@ -0,0 +1,2 @@ +mod func_c64; +mod func_f64; diff --git a/crates-device/rstsr-accelerate/tests/test_workable.rs b/crates-device/rstsr-accelerate/tests/test_workable.rs new file mode 100644 index 00000000..21a6ee54 --- /dev/null +++ b/crates-device/rstsr-accelerate/tests/test_workable.rs @@ -0,0 +1,35 @@ +#[cfg(test)] +mod test { + + #[test] + fn workable() { + use rstsr_core::prelude::*; + use rstsr_accelerate::DeviceAccelerate; + + // specify the number of threads of 16 + let device = DeviceAccelerate::new(16); + // if you want to use the default number of threads, use the following line + // let device = DeviceAccelerate::default(); + + let a = rt::linspace((0.0, 1.0, 1048576, &device)).into_shape([16, 256, 256]); + let b = rt::linspace((1.0, 2.0, 1048576, &device)).into_shape([16, 256, 256]); + + // by optimized BLAS, the following operation is very fast + let c = &a % &b; + + // mean of all elements is also performed in parallel + let c_mean = c.mean_all(); + println!("{c_mean:?}"); + assert!((c_mean - 213.2503660477036) < 1e-6); + + let c_std = c.std_all(); + println!("{c_std:?}"); + assert!((c_std - 148.88523481701804) < 1e-6); + + let c_std_1 = c.std_axes((0, 1)); + println!("{c_std_1}"); + + let c_std_2 = c.std_axes((1, 2)); + println!("{c_std_2}"); + } +} diff --git a/rstsr/Cargo.toml b/rstsr/Cargo.toml index ae07b4ed..284cb525 100644 --- a/rstsr/Cargo.toml +++ b/rstsr/Cargo.toml @@ -16,6 +16,7 @@ rstsr-linalg-traits = { workspace = true, optional = true } rstsr-sci-traits = { workspace = true, optional = true } # device dependencies rstsr-openblas = { workspace = true, optional = true } +rstsr-accelerate = { workspace = true, optional = true } rstsr-mkl = { workspace = true, optional = true } rstsr-blis = { workspace = true, optional = true } rstsr-aocl = { workspace = true, optional = true } @@ -32,7 +33,7 @@ 
default = ["std", "backtrace", "rstsr-core/default", "faer", "faer_as_default"] std = ["rstsr-core/std"] backtrace = ["rstsr-core/backtrace"] rayon = ["rstsr-core/rayon"] -faer = ["rstsr-core/faer", "rstsr-linalg-traits?/faer", "rstsr-sci-traits?/faer", "rstsr-openblas?/faer", "rstsr-mkl?/faer", "rstsr-blis?/faer", "rstsr-aocl?/faer", "rstsr-kml?/faer"] +faer = ["rstsr-core/faer", "rstsr-linalg-traits?/faer", "rstsr-sci-traits?/faer", "rstsr-openblas?/faer", "rstsr-accelerate?/faer", "rstsr-mkl?/faer", "rstsr-blis?/faer", "rstsr-aocl?/faer", "rstsr-kml?/faer"] faer_as_default = ["rstsr-core/faer_as_default", "faer"] row_major = ["rstsr-core/row_major"] col_major = ["rstsr-core/col_major"] @@ -41,19 +42,20 @@ dispatch_dim_layout_iter = ["rstsr-core/dispatch_dim_layout_iter"] # rstsr BLAS device features openblas = ["dep:rstsr-openblas"] +accelerate = ["dep:rstsr-accelerate"] mkl = ["dep:rstsr-mkl"] blis = ["dep:rstsr-blis"] aocl = ["dep:rstsr-aocl"] kml = ["dep:rstsr-kml"] # dependencies specification -linalg = ["dep:rstsr-linalg-traits", "rstsr-openblas?/linalg", "rstsr-mkl?/linalg", "rstsr-blis?/linalg", "rstsr-aocl?/linalg", "rstsr-kml?/linalg"] -sci = ["dep:rstsr-sci-traits", "rstsr-openblas?/sci", "rstsr-mkl?/sci", "rstsr-blis?/sci", "rstsr-aocl?/sci", "rstsr-kml?/sci"] +linalg = ["dep:rstsr-linalg-traits", "rstsr-openblas?/linalg", "rstsr-accelerate?/linalg", "rstsr-mkl?/linalg", "rstsr-blis?/linalg", "rstsr-aocl?/linalg", "rstsr-kml?/linalg"] +sci = ["dep:rstsr-sci-traits", "rstsr-openblas?/sci", "rstsr-accelerate?/sci", "rstsr-mkl?/sci", "rstsr-blis?/sci", "rstsr-aocl?/sci", "rstsr-kml?/sci"] tblis = ["dep:rstsr-tblis"] # BLAS configurations -dynamic_loading = ["rstsr-openblas?/dynamic_loading", "rstsr-mkl?/dynamic_loading", "rstsr-blis?/dynamic_loading", "rstsr-aocl?/dynamic_loading", "rstsr-kml?/dynamic_loading"] -ilp64 = ["rstsr-openblas?/ilp64", "rstsr-mkl?/ilp64", "rstsr-blis?/ilp64", "rstsr-aocl?/ilp64", "rstsr-kml?/ilp64"] +dynamic_loading = ["rstsr-openblas?/dynamic_loading", "rstsr-accelerate?/dynamic_loading", "rstsr-mkl?/dynamic_loading", "rstsr-blis?/dynamic_loading", "rstsr-aocl?/dynamic_loading", "rstsr-kml?/dynamic_loading"] +ilp64 = ["rstsr-openblas?/ilp64", "rstsr-accelerate?/ilp64", "rstsr-mkl?/ilp64", "rstsr-blis?/ilp64", "rstsr-aocl?/ilp64", "rstsr-kml?/ilp64"] [package.metadata.docs.rs] features = ["default", "openblas", "linalg", "sci", "tblis"]