Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ authors = [
{name = "Otto Vintola", email="hello@ottovintola.com" },
]
description = "Rust implementation of Bayesian Additive Regression Trees for Probabilistic programming with PyMC"
requires-python = ">=3.12, <3.14"
requires-python = ">=3.12, <3.15"
classifiers = [
"Programming Language :: Rust",
"Programming Language :: Python :: Implementation :: CPython",
Expand Down
44 changes: 30 additions & 14 deletions python/bartrs/compile_pymc.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,17 @@
import pandas as pd
import pymc as pm
import pytensor
import pytensor.tensor as pt

from pymc.pytensorf import (
compile,
inputvars,
join_nonshared_inputs,
make_shared_replacements,
)
from numba import carray, cfunc, extending, float64, types, njit
from numba import carray, cfunc, extending, float64, types, njit, int32
from numba.core import cgutils
from pytensor.assumptions.specify import assume


@numba.extending.intrinsic
Expand Down Expand Up @@ -187,15 +189,24 @@ def _make_functions(self, model, vars):
#
shared = make_shared_replacements(initial_values, value_vars, model)

out_vars = [model.datalogp]
idx = pt.ivector("idx")
idx_unique = assume(idx, unique_indices=True)

obs_logps = []
for obs_rv in model.observed_RVs:
rv_logp_list = model.logp(obs_rv, sum=False)
rv_logp = rv_logp_list[0]
subset_logp = rv_logp[idx_unique]
obs_logps.append(pt.sum(subset_logp))
out_vars = pt.sum(obs_logps)

# Join non-shared inputs and prepare for compilation
# This separates model parameters from shared/observed data
new_out, new_joined_inputs = join_nonshared_inputs(
initial_values, out_vars, value_vars, shared
initial_values, [out_vars], value_vars, shared
)

logp_fn = compile(inputs=[new_joined_inputs], outputs=new_out[0], mode="NUMBA")
logp_fn = compile(inputs=[new_joined_inputs, idx], outputs=new_out[0], mode="NUMBA")
logp_fn.trust_input = True

return shape, logp_fn, shared
Expand Down Expand Up @@ -231,8 +242,8 @@ def _make_persistent_arrays(self):
All arrays are ensured to be float64 for consistency with
the compiled function interface.
"""
arrays = [item.storage[0].copy() for item in self.logp_fn_ptr.input_storage[1:]]
assert all(arr.dtype == np.float64 for arr in arrays)
arrays = [item.storage[0].copy() for item in self.logp_fn_ptr.input_storage[2:]]
# assert all(arr.dtype == np.float64 for arr in arrays)
return arrays

# TODO: fast update for shared arrays
Expand All @@ -258,7 +269,7 @@ def update_shared_arrays(self):
Observed data (e.g., design matrix `X` and targets `y`) are stored in
input storage.
"""
for array, storage in zip(self.logp_args, self.logp_fn_ptr.input_storage[1:]):
for array, storage in zip(self.logp_args, self.logp_fn_ptr.input_storage[2:]):
# NOTE: np.copyto is the old implementation
np.copyto(array, storage.storage[0])
# self._fast_update(array, storage.storage[0])
Expand All @@ -269,7 +280,7 @@ def _generate_logp_function(self):
This method creates a Numba-compiled C function that can be called
from external code (e.g., Rust via FFI). The function signature is:

`double logp(double* ptr, int size)`
`double logp(double* ptr, int size, int* idx_ptr, int idx_size)`

The generated function:
1. Wraps the input pointer as a NumPy array
Expand All @@ -286,8 +297,9 @@ def _generate_logp_function(self):
shared_arrays = self.logp_args

code = [
"def _logp(ptr, size):",
"def _logp(ptr, size, idx_ptr, idx_size):",
" data = carray(ptr, (size, ), dtype=float64)",
" indexes = carray(idx_ptr, (idx_size, ), dtype=int32)",
]

for i, array in enumerate(shared_arrays):
Expand All @@ -298,7 +310,7 @@ def _generate_logp_function(self):
)
code.append(line)

ret = f" return logp_fn(data, {', '.join(f'arg{i}' for i in range(len(shared_arrays)))})[0].item()"
ret = f" return logp_fn(data, indexes, {', '.join(f'arg{i}' for i in range(len(shared_arrays)))})[0].item()"
code.append(ret)
source = "\n".join(code)

Expand All @@ -311,6 +323,8 @@ def _generate_logp_function(self):
sig = types.float64(
types.CPointer(types.float64),
types.intc,
types.CPointer(types.int32),
types.intc,
)
try:
return cfunc(sig)(logp)
Expand All @@ -324,18 +338,20 @@ def _generate_logp_function(self):
def _generate_ctypes_logp_function(self, logp_fn, shared_arrays):
"""Fallback log-probability callback implemented with ctypes."""

callback_type = ctypes.CFUNCTYPE(ctypes.c_double, ctypes.POINTER(ctypes.c_double), ctypes.c_int)
callback_type = ctypes.CFUNCTYPE(ctypes.c_double, ctypes.POINTER(ctypes.c_double), ctypes.c_int,
ctypes.POINTER(ctypes.c_int32), ctypes.c_int)

def _logp(ptr, size):
def _logp(ptr, size, idx_ptr, idx_size):
data = np.ctypeslib.as_array(ptr, shape=(size,))
indexes = np.ctypeslib.as_array(idx_ptr, shape=(idx_size,)).astype(np.int32)

args = []
for array in shared_arrays:
args.append(array)

result = logp_fn(data, *args)
result = logp_fn(data, indexes, *args)
return float(np.asarray(result).reshape(-1)[0])

callback = callback_type(_logp)
self._ctypes_logp_callback = callback
return callback
return callback
5 changes: 2 additions & 3 deletions python/bartrs/pgbart.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@

class PGBART(ArrayStepShared):
"""
Particle Gibss BART sampling step.
Particle Gibbs BART sampling step.

Parameters
----------
Expand Down Expand Up @@ -101,7 +101,6 @@ def __init__( # noqa: PLR0915

self.shape = 1 if len(shape) == 1 else shape[0]


# Set trees_shape (dim for separate tree structures)
# and leaves_shape (dim for leaf node values)
# One of the two is always one, the other equal to self.shape
Expand Down Expand Up @@ -244,7 +243,7 @@ def astep(self, _):

sum_trees, trees, variable_inclusion = self.pg_bart.step(self.tune)
if not self.tune:
self.bart.all_trees.append(trees) # this doubles runtime
self.bart.all_trees.append(trees) # this is slow, I think
t1 = perf_counter()

stats = {
Expand Down
2 changes: 1 addition & 1 deletion python/tests/test_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
NUM_DRAWS = 600
NUM_CHAINS = 4
BATCH_SIZE = (0.1, 0.1)
NUM_TREES = 50
NUM_TREES = 10
NUM_PARTICLES = 10
RANDOM_SEED = 42

Expand Down
2 changes: 1 addition & 1 deletion src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ impl Default for BartConfig {
batch_tune: 0.1,
batch_post: 0.1,
n_outputs: 1,
response: "gaussian".to_string(),
response: "constant".to_string(),
}
}
}
17 changes: 17 additions & 0 deletions src/data.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,3 +67,20 @@ impl OwnedData {
}

}


pub trait NotNan {
fn is_valid(&self) -> bool;
}

impl NotNan for f64 {
fn is_valid(&self) -> bool {
!self.is_nan()
}
}

impl NotNan for i32 {
fn is_valid(&self) -> bool {
true
}
}
2 changes: 1 addition & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ use rand::SeedableRng;
use rand::rngs::SmallRng;
use rand_distr::StandardNormal;

type LogpFunc = unsafe extern "C" fn(*const f64, usize) -> c_double;
type LogpFunc = unsafe extern "C" fn(*const f64, i32, *const i32, i32) -> c_double;

#[pyclass]
#[derive(Clone, Debug)]
Expand Down
5 changes: 4 additions & 1 deletion src/particle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ pub struct Particle {
pub expandable_nodes: VecDeque<u32>,
pub sample_map: LeafSamplesFlat,
pub log_weight: f64,
pub ll_pointwise: Vec<f64>,
}

impl Particle {
Expand All @@ -67,6 +68,7 @@ impl Particle {
expandable_nodes: VecDeque::from([0]),
sample_map: LeafSamplesFlat::new(n_samples, max_depth),
log_weight: 0.0,
ll_pointwise: vec![0.0; n_samples],
}
}

Expand All @@ -77,6 +79,7 @@ impl Particle {
expandable_nodes: VecDeque::new(),
sample_map: LeafSamplesFlat::new(n_samples, max_depth),
log_weight: 0.0,
ll_pointwise: vec![0.0; n_samples],
}
}

Expand Down Expand Up @@ -113,7 +116,7 @@ impl Particle {
}
}

Self { tree: Arc::new(tree), expandable_nodes, sample_map, log_weight: 0.0 }
Self { tree: Arc::new(tree), expandable_nodes, sample_map, log_weight: 0.0, ll_pointwise: vec![0.0; n_samples] }
}

pub fn has_expandable_nodes(&self) -> bool {
Expand Down
22 changes: 22 additions & 0 deletions src/resampling.rs
Original file line number Diff line number Diff line change
Expand Up @@ -130,3 +130,25 @@ impl ResamplingStrategy for ResamplingStrategies {
}
}
}


#[cfg(test)]
mod tests {
use super::*;

#[test]
fn test_systematic_resample() {
let mut weights = &[0.0, 0.25, 0.75];
let mut rng = rand::rng();
let mut out = Vec::new();

ResamplingStrategies::Systematic(SystematicResampling).resample_into(&mut rng, weights, &mut out);
assert!(out.iter().all( |&index| index >= 1 && index < weights.len()));

weights = &[0.5, 0.3, 0.2];
out = Vec::new();
ResamplingStrategies::Systematic(SystematicResampling).resample_into(&mut rng, weights, &mut out);
assert!(out.iter().all( |&index| index >= 0 as usize && index < weights.len()));
}

}
32 changes: 31 additions & 1 deletion src/response.rs
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,11 @@ impl ResponseStrategy for LinearStrategy {
for &s in node_samples {
let idx = s as usize;
let v = unsafe { *col.uget(idx) };

if v.is_nan() {
continue;
}

if v <= split_val {
left_idx.push(idx);
} else {
Expand Down Expand Up @@ -299,7 +304,7 @@ pub enum ResponseStrategies {
impl ResponseStrategies {
pub fn from_name(name: &str) -> Result<Self, String> {
match name.to_lowercase().as_str() {
"gaussian" => Ok(ResponseStrategies::Gaussian(GaussianResponseStrategy)),
"constant" => Ok(ResponseStrategies::Gaussian(GaussianResponseStrategy)),
"linear" => Ok(ResponseStrategies::Linear(LinearStrategy)),
"motr" => Ok(ResponseStrategies::Motr(MotrStrategy)),
_ => Err(format!(
Expand Down Expand Up @@ -372,6 +377,7 @@ impl ResponseStrategy for ResponseStrategies {
mod tests {
use super::*;

// 1d-intercept
#[test]
fn test_fit_linear_1d_intercept() {
let x = &[1.0, 2.0, 3.0, 4.0, 5.0];
Expand All @@ -389,6 +395,7 @@ mod tests {
assert!((a - 0.0).abs() < 1e-6, "Expected intercept ~0.0, got {}", a);
}

// 1d-slope
#[test]
fn test_fit_linear_1d_slope() {
let x = &[1.0, 2.0, 3.0, 4.0, 5.0];
Expand All @@ -406,5 +413,28 @@ mod tests {
assert!((b - 2.0).abs() < 1e-6, "Expected intercept ~2.0, got {}", b);

}


// 1d-intercept + slope
#[test]
fn test_linear_fit() {

let x = &[1.0, 2.0, 3.0, 4.0, 5.0];
let y = &[3.0, 5.0, 7.0, 9.0, 11.0];

let noise = 0.0f64;

let n_trees = 1;

let Some((intercept, slope)) = LinearStrategy::fit_linear_1d(x, y, noise, n_trees) else {
panic!("Got None when was expecting intercept and slope");
};

let a = intercept;
let b = slope;

assert!((a - 1.0).abs() < 1e-6, "Expected intercept ~1.0, got {}", a);
assert!((b - 2.0).abs() < 1e-6, "Expected slope ~2.0, got {}", b);
}

}
Loading