diff --git a/Cargo.lock b/Cargo.lock index 737bdab..beab318 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -31,7 +31,7 @@ checksum = "ace50bade8e6234aa140d9a2f552bbee1db4d353f69b8217bc503490fc1a9f26" [[package]] name = "bartrs" -version = "0.2.0" +version = "0.3.0" dependencies = [ "criterion", "numpy", diff --git a/pyproject.toml b/pyproject.toml index b67a595..206ca11 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,7 +9,7 @@ authors = [ {name = "Otto Vintola", email="hello@ottovintola.com" }, ] description = "Rust implementation of Bayesian Additive Regression Trees for Probabilistic programming with PyMC" -requires-python = ">=3.12, <3.14" +requires-python = ">=3.12, <3.15" classifiers = [ "Programming Language :: Rust", "Programming Language :: Python :: Implementation :: CPython", diff --git a/python/bartrs/compile_pymc.py b/python/bartrs/compile_pymc.py index f94246a..8af72d8 100644 --- a/python/bartrs/compile_pymc.py +++ b/python/bartrs/compile_pymc.py @@ -11,6 +11,7 @@ import pandas as pd import pymc as pm import pytensor +import pytensor.tensor as pt from pymc.pytensorf import ( compile, @@ -18,8 +19,9 @@ join_nonshared_inputs, make_shared_replacements, ) -from numba import carray, cfunc, extending, float64, types, njit +from numba import carray, cfunc, extending, float64, types, njit, int32 from numba.core import cgutils +from pytensor.assumptions.specify import assume @numba.extending.intrinsic @@ -187,15 +189,24 @@ def _make_functions(self, model, vars): # shared = make_shared_replacements(initial_values, value_vars, model) - out_vars = [model.datalogp] + idx = pt.ivector("idx") + idx_unique = assume(idx, unique_indices=True) + + obs_logps = [] + for obs_rv in model.observed_RVs: + rv_logp_list = model.logp(obs_rv, sum=False) + rv_logp = rv_logp_list[0] + subset_logp = rv_logp[idx_unique] + obs_logps.append(pt.sum(subset_logp)) + out_vars = pt.sum(obs_logps) # Join non-shared inputs and prepare for compilation # This separates model parameters from shared/observed data new_out, new_joined_inputs = join_nonshared_inputs( - initial_values, out_vars, value_vars, shared + initial_values, [out_vars], value_vars, shared ) - logp_fn = compile(inputs=[new_joined_inputs], outputs=new_out[0], mode="NUMBA") + logp_fn = compile(inputs=[new_joined_inputs, idx], outputs=new_out[0], mode="NUMBA") logp_fn.trust_input = True return shape, logp_fn, shared @@ -231,8 +242,8 @@ def _make_persistent_arrays(self): All arrays are ensured to be float64 for consistency with the compiled function interface. """ - arrays = [item.storage[0].copy() for item in self.logp_fn_ptr.input_storage[1:]] - assert all(arr.dtype == np.float64 for arr in arrays) + arrays = [item.storage[0].copy() for item in self.logp_fn_ptr.input_storage[2:]] + # assert all(arr.dtype == np.float64 for arr in arrays) return arrays # TODO: fast update for shared arrays @@ -258,7 +269,7 @@ def update_shared_arrays(self): Observed data (e.g., design matrix `X` and targets `y`) are stored in input storage. """ - for array, storage in zip(self.logp_args, self.logp_fn_ptr.input_storage[1:]): + for array, storage in zip(self.logp_args, self.logp_fn_ptr.input_storage[2:]): # NOTE: np.copyto is the old implementation np.copyto(array, storage.storage[0]) # self._fast_update(array, storage.storage[0]) @@ -269,7 +280,7 @@ def _generate_logp_function(self): This method creates a Numba-compiled C function that can be called from external code (e.g., Rust via FFI). The function signature is: - `double logp(double* ptr, int size)` + `double logp(double* ptr, int size, int* idx_ptr, int idx_size)` The generated function: 1. Wraps the input pointer as a NumPy array @@ -286,8 +297,9 @@ def _generate_logp_function(self): shared_arrays = self.logp_args code = [ - "def _logp(ptr, size):", + "def _logp(ptr, size, idx_ptr, idx_size):", " data = carray(ptr, (size, ), dtype=float64)", + " indexes = carray(idx_ptr, (idx_size, ), dtype=int32)", ] for i, array in enumerate(shared_arrays): @@ -298,7 +310,7 @@ def _generate_logp_function(self): ) code.append(line) - ret = f" return logp_fn(data, {', '.join(f'arg{i}' for i in range(len(shared_arrays)))})[0].item()" + ret = f" return logp_fn(data, indexes, {', '.join(f'arg{i}' for i in range(len(shared_arrays)))})[0].item()" code.append(ret) source = "\n".join(code) @@ -311,6 +323,8 @@ def _generate_logp_function(self): sig = types.float64( types.CPointer(types.float64), types.intc, + types.CPointer(types.int32), + types.intc, ) try: return cfunc(sig)(logp) @@ -324,18 +338,20 @@ def _generate_logp_function(self): def _generate_ctypes_logp_function(self, logp_fn, shared_arrays): """Fallback log-probability callback implemented with ctypes.""" - callback_type = ctypes.CFUNCTYPE(ctypes.c_double, ctypes.POINTER(ctypes.c_double), ctypes.c_int) + callback_type = ctypes.CFUNCTYPE(ctypes.c_double, ctypes.POINTER(ctypes.c_double), ctypes.c_int, + ctypes.POINTER(ctypes.c_int32), ctypes.c_int) - def _logp(ptr, size): + def _logp(ptr, size, idx_ptr, idx_size): data = np.ctypeslib.as_array(ptr, shape=(size,)) + indexes = np.ctypeslib.as_array(idx_ptr, shape=(idx_size,)).astype(np.int32) args = [] for array in shared_arrays: args.append(array) - result = logp_fn(data, *args) + result = logp_fn(data, indexes, *args) return float(np.asarray(result).reshape(-1)[0]) callback = callback_type(_logp) self._ctypes_logp_callback = callback - return callback + return callback \ No newline at end of file diff --git a/python/bartrs/pgbart.py b/python/bartrs/pgbart.py index 104cb78..c10d73d 100644 --- a/python/bartrs/pgbart.py +++ b/python/bartrs/pgbart.py @@ -34,7 +34,7 @@ class PGBART(ArrayStepShared): """ - Particle Gibss BART sampling step. + Particle Gibbs BART sampling step. Parameters ---------- @@ -101,7 +101,6 @@ def __init__( # noqa: PLR0915 self.shape = 1 if len(shape) == 1 else shape[0] - # Set trees_shape (dim for separate tree structures) # and leaves_shape (dim for leaf node values) # One of the two is always one, the other equal to self.shape @@ -244,7 +243,7 @@ def astep(self, _): sum_trees, trees, variable_inclusion = self.pg_bart.step(self.tune) if not self.tune: - self.bart.all_trees.append(trees) # this doubles runtime + self.bart.all_trees.append(trees) # this is slow, I think t1 = perf_counter() stats = { diff --git a/python/tests/test_sampler.py b/python/tests/test_sampler.py index 61885e7..59f1e28 100644 --- a/python/tests/test_sampler.py +++ b/python/tests/test_sampler.py @@ -8,7 +8,7 @@ NUM_DRAWS = 600 NUM_CHAINS = 4 BATCH_SIZE = (0.1, 0.1) -NUM_TREES = 50 +NUM_TREES = 10 NUM_PARTICLES = 10 RANDOM_SEED = 42 diff --git a/src/config.rs b/src/config.rs index 57be545..bdf490e 100644 --- a/src/config.rs +++ b/src/config.rs @@ -34,7 +34,7 @@ impl Default for BartConfig { batch_tune: 0.1, batch_post: 0.1, n_outputs: 1, - response: "gaussian".to_string(), + response: "constant".to_string(), } } } diff --git a/src/data.rs b/src/data.rs index d8889e0..132dd2c 100644 --- a/src/data.rs +++ b/src/data.rs @@ -67,3 +67,20 @@ impl OwnedData { } } + + +pub trait NotNan { + fn is_valid(&self) -> bool; +} + +impl NotNan for f64 { + fn is_valid(&self) -> bool { + !self.is_nan() + } +} + +impl NotNan for i32 { + fn is_valid(&self) -> bool { + true + } +} \ No newline at end of file diff --git a/src/lib.rs b/src/lib.rs index 4ffe28e..635d41b 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -33,7 +33,7 @@ use rand::SeedableRng; use rand::rngs::SmallRng; use rand_distr::StandardNormal; -type LogpFunc = unsafe extern "C" fn(*const f64, usize) -> c_double; +type LogpFunc = unsafe extern "C" fn(*const f64, i32, *const i32, i32) -> c_double; #[pyclass] #[derive(Clone, Debug)] diff --git a/src/particle.rs b/src/particle.rs index 03ba9de..fecea2d 100644 --- a/src/particle.rs +++ b/src/particle.rs @@ -57,6 +57,7 @@ pub struct Particle { pub expandable_nodes: VecDeque, pub sample_map: LeafSamplesFlat, pub log_weight: f64, + pub ll_pointwise: Vec, } impl Particle { @@ -67,6 +68,7 @@ impl Particle { expandable_nodes: VecDeque::from([0]), sample_map: LeafSamplesFlat::new(n_samples, max_depth), log_weight: 0.0, + ll_pointwise: vec![0.0; n_samples], } } @@ -77,6 +79,7 @@ impl Particle { expandable_nodes: VecDeque::new(), sample_map: LeafSamplesFlat::new(n_samples, max_depth), log_weight: 0.0, + ll_pointwise: vec![0.0; n_samples], } } @@ -113,7 +116,7 @@ impl Particle { } } - Self { tree: Arc::new(tree), expandable_nodes, sample_map, log_weight: 0.0 } + Self { tree: Arc::new(tree), expandable_nodes, sample_map, log_weight: 0.0, ll_pointwise: vec![0.0; n_samples] } } pub fn has_expandable_nodes(&self) -> bool { diff --git a/src/resampling.rs b/src/resampling.rs index 6d3c2a4..52af0b8 100644 --- a/src/resampling.rs +++ b/src/resampling.rs @@ -130,3 +130,25 @@ impl ResamplingStrategy for ResamplingStrategies { } } } + + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_systematic_resample() { + let mut weights = &[0.0, 0.25, 0.75]; + let mut rng = rand::rng(); + let mut out = Vec::new(); + + ResamplingStrategies::Systematic(SystematicResampling).resample_into(&mut rng, weights, &mut out); + assert!(out.iter().all( |&index| index >= 1 && index < weights.len())); + + weights = &[0.5, 0.3, 0.2]; + out = Vec::new(); + ResamplingStrategies::Systematic(SystematicResampling).resample_into(&mut rng, weights, &mut out); + assert!(out.iter().all( |&index| index >= 0 as usize && index < weights.len())); + } + +} \ No newline at end of file diff --git a/src/response.rs b/src/response.rs index c3311c3..a944f71 100644 --- a/src/response.rs +++ b/src/response.rs @@ -211,6 +211,11 @@ impl ResponseStrategy for LinearStrategy { for &s in node_samples { let idx = s as usize; let v = unsafe { *col.uget(idx) }; + + if v.is_nan() { + continue; + } + if v <= split_val { left_idx.push(idx); } else { @@ -299,7 +304,7 @@ pub enum ResponseStrategies { impl ResponseStrategies { pub fn from_name(name: &str) -> Result { match name.to_lowercase().as_str() { - "gaussian" => Ok(ResponseStrategies::Gaussian(GaussianResponseStrategy)), + "constant" => Ok(ResponseStrategies::Gaussian(GaussianResponseStrategy)), "linear" => Ok(ResponseStrategies::Linear(LinearStrategy)), "motr" => Ok(ResponseStrategies::Motr(MotrStrategy)), _ => Err(format!( @@ -372,6 +377,7 @@ impl ResponseStrategy for ResponseStrategies { mod tests { use super::*; + // 1d-intercept #[test] fn test_fit_linear_1d_intercept() { let x = &[1.0, 2.0, 3.0, 4.0, 5.0]; @@ -389,6 +395,7 @@ mod tests { assert!((a - 0.0).abs() < 1e-6, "Expected intercept ~0.0, got {}", a); } + // 1d-slope #[test] fn test_fit_linear_1d_slope() { let x = &[1.0, 2.0, 3.0, 4.0, 5.0]; @@ -406,5 +413,28 @@ mod tests { assert!((b - 2.0).abs() < 1e-6, "Expected intercept ~2.0, got {}", b); } + + + // 1d-intercept + slope + #[test] + fn test_linear_fit() { + + let x = &[1.0, 2.0, 3.0, 4.0, 5.0]; + let y = &[3.0, 5.0, 7.0, 9.0, 11.0]; + + let noise = 0.0f64; + + let n_trees = 1; + + let Some((intercept, slope)) = LinearStrategy::fit_linear_1d(x, y, noise, n_trees) else { + panic!("Got None when was expecting intercept and slope"); + }; + + let a = intercept; + let b = slope; + + assert!((a - 1.0).abs() < 1e-6, "Expected intercept ~1.0, got {}", a); + assert!((b - 2.0).abs() < 1e-6, "Expected slope ~2.0, got {}", b); + } } \ No newline at end of file diff --git a/src/smc.rs b/src/smc.rs index 81bd46e..75f8a56 100644 --- a/src/smc.rs +++ b/src/smc.rs @@ -8,7 +8,7 @@ use rand::distr::weighted::WeightedIndex; use rand_distr::{Distribution}; use crate::config::BartConfig; -use crate::data::DataView; +use crate::data::{DataView, NotNan}; use crate::particle::{Particle}; use crate::resampling::ResamplingStrategy; use crate::splitting::SplitRules; @@ -63,6 +63,9 @@ where current_tree.predict_training_into_multi(&mut current_tree_pred, Some(data.x)); sum_trees_noi.assign(sum_trees); sum_trees_noi -= ¤t_tree_pred; + + let storage: Vec = (0..n_samples as i32).collect(); + let base_index: &[i32] = &storage; let mut predictions_buf = Array::zeros((config.n_outputs, n_samples)); let mut ancestors_buf: Vec = Vec::with_capacity(n_non_ref); @@ -71,7 +74,7 @@ where while particles[1..].iter().any(|p| p.has_expandable_nodes()) { mutated.iter_mut().for_each(|m| *m = false); - for (i, particle) in particles[1..].iter_mut().enumerate() { + for (_i, particle) in particles[1..].iter_mut().enumerate() { if let Some(node_idx) = particle.peek_next_expandable() { let node_idx = node_idx as usize; @@ -87,10 +90,19 @@ where response, ) { MutationDecision::Accept(proposal) => { + + let active_indices: Vec = particle.leaf_samples(node_idx).iter().map( |&i| i as i32).collect(); + let old_cont: f64 = active_indices.iter().map( |i| particle.ll_pointwise[*i as usize]).sum(); + particle.pop_next_expandable(); particle.apply_mutation(&proposal, data.x); acceptance_count += 1; - mutated[i] = true; + // mutated[i] = true; + + particle.tree.predict_training_into_multi(&mut predictions_buf, Some(data.x)); + let flat: Vec = predictions_buf.iter().copied().collect(); + predictions_buf += &sum_trees_noi; + particle.log_weight -= old_cont + weight_fn.log_weight(&flat, &active_indices) } MutationDecision::Reject => { particle.pop_next_expandable(); @@ -99,15 +111,14 @@ where } } - for (i, particle) in particles[1..].iter_mut().enumerate() { - if mutated[i] { - predictions_buf.fill(0.0); - particle.tree.predict_training_into_multi(&mut predictions_buf, Some(data.x)); - predictions_buf += &sum_trees_noi; - let flat: Vec = predictions_buf.iter().copied().collect(); - particle.log_weight = weight_fn.log_weight(&flat); - } - } + // for (i, particle) in particles[1..].iter_mut().enumerate() { + // if mutated[i] { + // particle.tree.predict_training_into_multi(&mut predictions_buf, Some(data.x)); + // predictions_buf += &sum_trees_noi; + // let flat: Vec = predictions_buf.iter().copied().collect(); + // particle.log_weight = weight_fn.log_weight(&flat); + // } + // } inner_weights.copy_from_slice(&particles[1..].iter().map(|p| p.log_weight).collect::>()); @@ -122,11 +133,10 @@ where let mut log_weights = vec![0.0f64; config.n_particles]; for (i, particle) in particles.iter().enumerate() { - predictions_buf.fill(0.0); particle.tree.predict_training_into_multi(&mut predictions_buf, Some(data.x)); predictions_buf += &sum_trees_noi; let flat: Vec = predictions_buf.iter().copied().collect(); - log_weights[i] = weight_fn.log_weight(&flat); + log_weights[i] = weight_fn.log_weight(&flat, &base_index); } let mut weights = log_weights.clone(); @@ -167,6 +177,11 @@ fn propose_mutation( response: &dyn ResponseStrategy, ) -> MutationDecision { let depth = particle.tree.get_depth(node_idx); + + if depth >= config.max_depth as usize { + return MutationDecision::Reject; + } + if depth == 0 { // continue; } else { @@ -190,7 +205,9 @@ fn propose_mutation( let col = data.x.column(split_var); let feature_values = node_samples .iter() - .map(|&s| unsafe { *col.uget(s as usize) }); + .map(|&s| unsafe { *col.uget(s as usize) }) + .filter( |&v| v.is_valid()); + let split_strategy = &split_rules[split_var]; let split_val = match split_strategy.sample_split_value(rng, feature_values) { @@ -268,4 +285,4 @@ pub fn normalize_weights_inplace(weights: &mut [f64]) { for w in weights.iter_mut() { *w /= sum; } -} +} \ No newline at end of file diff --git a/src/splitting.rs b/src/splitting.rs index 12d0b32..f77db6e 100644 --- a/src/splitting.rs +++ b/src/splitting.rs @@ -113,7 +113,7 @@ impl SplitRule for OneHotSplit { where I: Iterator, { - data_indices.partition(|&idx| (data[[idx, feature_idx]] as i32) == threshold) + data_indices.partition(|&idx| (data[[idx, feature_idx]] as i32) <= threshold) } } @@ -170,3 +170,129 @@ impl SplitRules { } } } + + +#[cfg(test)] +mod tests { + use super::*; + use numpy::ndarray::array; + use rand::{rngs::StdRng, SeedableRng}; + + fn assert_partition_consistent( + data: &Array, + feature_idx: usize, + threshold: f64, + left: &[usize], + right: &[usize], + ) { + assert_eq!(left.len() + right.len(), data.nrows()); + + for &idx in left { + assert!(data[[idx, feature_idx]] <= threshold); + } + + for &idx in right { + assert!(data[[idx, feature_idx]] > threshold); + } + } + + #[test] + fn test_continuous_split_rule() { + let rule = ContinuousSplit; + + let mut rng = StdRng::seed_from_u64(42); + assert_eq!(rule.sample_split_value(&mut rng, vec![0.0].into_iter()), None); + + let available_values: Vec = (0..10).map(|x| x as f64).collect(); + let sv = rule + .sample_split_value(&mut rng, available_values.clone().into_iter()) + .expect("expected a split value"); + + let data = array![[0.0], [1.0], [2.0], [3.0], [4.0], [5.0], [6.0], [7.0], [8.0], [9.0]]; + + let (left, right) = rule.split_data_indices(&data, 0, sv, 0..data.nrows()); + + assert_partition_consistent(&data, 0, sv, &left, &right); + + let (left_repeated, right_repeated) = rule.split_data_indices(&data, 0, sv, 0..data.nrows()); + assert_eq!(left, left_repeated); + assert_eq!(right, right_repeated); + + let probs = (0..10_000) + .map(|_| { + let split_value = rule + .sample_split_value(&mut rng, available_values.clone().into_iter()) + .unwrap(); + let (left, _) = rule.split_data_indices(&data, 0, split_value, 0..data.nrows()); + let mut mask = vec![false; data.nrows()]; + for idx in left { + mask[idx] = true; + } + mask + }) + .fold(vec![0usize; data.nrows()], |mut acc, mask| { + for (i, b) in mask.into_iter().enumerate() { + if b { + acc[i] += 1; + } + } + acc + }) + .into_iter() + .map(|count| count as f64 / 10_000.0) + .collect::>(); + + assert!(probs.iter().filter(|&&p| p > 0.01).count() >= data.nrows() - 1); + assert!(probs.iter().filter(|&&p| p < 0.99).count() >= data.nrows() - 1); + } + + #[test] + fn test_one_hot_split_rule() { + let rule = OneHotSplit; + + let mut rng = StdRng::seed_from_u64(42); + assert_eq!(rule.sample_split_value(&mut rng, vec![0].into_iter()), None); + + let available_values: Vec = (0..10).collect(); + let sv = rule + .sample_split_value(&mut rng, available_values.clone().into_iter()) + .expect("expected a split value"); + + let data = array![[0.0], [1.0], [2.0], [3.0], [4.0], [5.0], [6.0], [7.0], [8.0], [9.0]]; + + let (left, right) = rule.split_data_indices(&data, 0, sv, 0..data.nrows()); + + assert_partition_consistent(&data, 0, sv as f64, &left, &right); + + let (left_repeated, right_repeated) = rule.split_data_indices(&data, 0, sv, 0..data.nrows()); + assert_eq!(left, left_repeated); + assert_eq!(right, right_repeated); + + let probs = (0..10_000) + .map(|_| { + let split_value = rule + .sample_split_value(&mut rng, available_values.clone().into_iter()) + .unwrap(); + let (left, _) = rule.split_data_indices(&data, 0, split_value, 0..data.nrows()); + let mut mask = vec![false; data.nrows()]; + for idx in left { + mask[idx] = true; + } + mask + }) + .fold(vec![0usize; data.nrows()], |mut acc, mask| { + for (i, b) in mask.into_iter().enumerate() { + if b { + acc[i] += 1; + } + } + acc + }) + .into_iter() + .map(|count| count as f64 / 10_000.0) + .collect::>(); + + assert!(probs.iter().filter(|&&p| p > 0.01).count() >= data.nrows() - 1); + assert!(probs.iter().filter(|&&p| p < 0.99).count() >= data.nrows() - 1); + } +} \ No newline at end of file diff --git a/src/tree.rs b/src/tree.rs index 9e546cb..6de578c 100644 --- a/src/tree.rs +++ b/src/tree.rs @@ -4,14 +4,7 @@ use pyo3::prelude::*; use pyo3::types::{PyDict, PyList}; use crate::response::{LeafKind, LeafPayload, LeafProposal}; - -// 1. DONE: remove pgbart.py from pymc-bart and use this sampler -// 2. DONE: linear terms -// 3. CURRENT: logp for samples that only affect the tree -// 4. monotonic response -// 5. Hawks example with separate BARTRVs -// 6. Reseaaaaarch - +use crate::data::NotNan; /// Bartz-style heap-indexed tree with separate internal/leaf arrays. /// @@ -387,14 +380,20 @@ impl TreeArrays { let intercept = self.linear_intercept[param_idx].get(out_idx).copied().unwrap_or(0.0); let slope = self.linear_slope[param_idx].get(out_idx).copied().unwrap_or(0.0); let mut contrib = data.column(var).to_owned(); - contrib.mapv_inplace(|x| intercept + slope * x); + contrib.mapv_inplace(|x| { + if x.is_valid() { + intercept + slope * x + } else { + 0.0 + } + }); row += &(weights.clone() * contrib); } } } } - fn fill_training_leaf_value( + pub fn fill_training_leaf_value( &self, leaf_idx: usize, sample_idx: usize, @@ -422,7 +421,11 @@ impl TreeArrays { for out_idx in 0..n_outputs { let intercept = self.linear_intercept[param_idx].get(out_idx).copied().unwrap_or(0.0); let slope = self.linear_slope[param_idx].get(out_idx).copied().unwrap_or(0.0); - out[[out_idx, sample_idx]] = intercept + slope * x; + out[[out_idx, sample_idx]] = if x.is_nan() { + intercept + } else { + intercept + slope * x + }; } } } diff --git a/src/weight.rs b/src/weight.rs index 4a3919d..cc1a5b7 100644 --- a/src/weight.rs +++ b/src/weight.rs @@ -2,14 +2,14 @@ use std::ffi::c_double; /// Safe trait for computing log-weights from predictions. pub trait WeightFn { - fn log_weight(&self, predictions: &[f64]) -> f64; + fn log_weight(&self, predictions: &[f64], indices: &[i32]) -> f64; } /// Weight function backed by a C function pointer from PyMC. /// /// The unsafe FFI call is isolated behind this safe trait implementation. pub struct PyMCWeightFn { - func_ptr: unsafe extern "C" fn(*const f64, usize) -> c_double, + func_ptr: unsafe extern "C" fn(*const f64, i32, *const i32, i32) -> c_double, } impl PyMCWeightFn { @@ -19,13 +19,13 @@ impl PyMCWeightFn { /// The caller must ensure the function pointer remains valid for /// the lifetime of this struct and that it correctly interprets /// a (pointer, length) pair as a slice of f64 values. - pub unsafe fn from_raw(ptr: unsafe extern "C" fn(*const f64, usize) -> c_double) -> Self { + pub unsafe fn from_raw(ptr: unsafe extern "C" fn(*const f64, i32, *const i32, i32) -> c_double) -> Self { Self { func_ptr: ptr } } } impl WeightFn for PyMCWeightFn { - fn log_weight(&self, predictions: &[f64]) -> f64 { - unsafe { (self.func_ptr)(predictions.as_ptr(), predictions.len()) } + fn log_weight(&self, predictions: &[f64], indices: &[i32]) -> f64 { + unsafe { (self.func_ptr)(predictions.as_ptr(), predictions.len() as i32, indices.as_ptr(), indices.len() as i32) } } }