Skip to content

Commit 25b75d0

Browse files
authored
Merge pull request #26 from asgr/copilot/add-rcpp-dispersion-speedup
Add Rcpp speedup for `disp_stars` convolution in `SFHfunc()` and `SFHburst()`
2 parents ec8f79b + 38bfb2b commit 25b75d0

5 files changed

Lines changed: 190 additions & 29 deletions

File tree

R/RcppExports.R

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
11
# Generated by using Rcpp::compileAttributes() -> do not edit by hand
22
# Generator token: 10BE3573-1514-4C36-9D1C-5A225CD40393
33

4+
.disp_stars_cpp <- function(wave_log, lum_log, z_disp, grid, weights, res) {
5+
.Call(`_ProSpect_disp_stars_cpp`, wave_log, lum_log, z_disp, grid, weights, res)
6+
}
7+
48
.colSums_wt_cpp <- function(mat, vec_wt = 1L) {
59
.Call(`_ProSpect_colSums_wt_cpp`, mat, vec_wt)
610
}

R/SFH.R

Lines changed: 3 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -349,20 +349,7 @@ SFHfunc = function(massfunc = massfunc_b5,
349349

350350
z_disp = sqrt(veldisp^2 + vel_LSF^2)/(.c_to_mps/1000) #this will be a vector of length wave_lum
351351

352-
lum_conv = numeric(length(lum))
353-
354-
for(i in seq_along(grid)){
355-
z_seq_log = log10(1 + grid[i]*z_disp)
356-
new_wave_log = wave_lum_log + z_seq_log
357-
new_lum = lum_log - z_seq_log
358-
new_lum = 10^approx(x=new_wave_log, y=new_lum, xout=wave_lum_log, rule=2, yleft=new_lum[1], yright=new_lum[length(new_lum)])$y
359-
new_lum = new_lum*weights[i]
360-
#lum_conv = lum_conv + new_lum
361-
.vec_add_cpp(lum_conv, new_lum)
362-
}
363-
364-
lum = lum_conv*res
365-
rm(lum_conv)
352+
lum = .disp_stars_cpp(wave_lum_log, lum_log, z_disp, grid, weights, res)
366353
}
367354

368355
if (emission) {
@@ -973,7 +960,7 @@ SFHburst = function(burstmass = 1e8,
973960
vel_LSF = LSF(wave_lum*(1 + z)) #to get LSF dispersion in km/s into z in the oberved frame
974961
}else if(is.matrix(LSF) | is.data.frame(LSF)){
975962
vel_LSF = approx(x=log10(LSF[,1]), y=LSF[,2], xout=log10(wave_lum*(1 + z)), rule=2)$y
976-
}else if(length(LSF == 1)){
963+
}else if(is.numeric(LSF) && length(LSF) == 1){
977964
vel_LSF = rep(LSF, length(wave_lum))
978965
}else{
979966
stop('LSF is in the wrong format!')
@@ -984,20 +971,7 @@ SFHburst = function(burstmass = 1e8,
984971

985972
z_disp = sqrt(veldisp^2 + vel_LSF^2)/(.c_to_mps/1000) #this will be a vector of length wave_lum
986973

987-
lum_conv = numeric(length(lum))
988-
989-
for(i in seq_along(grid)){
990-
z_seq = grid[i]*z_disp
991-
new_wave_log = wave_lum_log + log10(1 + z_seq)
992-
new_lum = lum_log - log10(1 + z_seq)
993-
new_lum = 10^approx(x=new_wave_log, y=new_lum, xout=wave_lum_log, rule=2, yleft=new_lum[1], yright=new_lum[length(new_lum)])$y
994-
new_lum = new_lum*weights[i]
995-
#lum_conv = lum_conv + new_lum
996-
.vec_add_cpp(lum_conv, new_lum)
997-
}
998-
999-
lum = lum_conv*res
1000-
rm(lum_conv)
974+
lum = .disp_stars_cpp(wave_lum_log, lum_log, z_disp, grid, weights, res)
1001975
}
1002976

1003977
if (emission & burstage < 1e7) {

src/RcppExports.cpp

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,22 @@ Rcpp::Rostream<true>& Rcpp::Rcout = Rcpp::Rcpp_cout_get();
1010
Rcpp::Rostream<false>& Rcpp::Rcerr = Rcpp::Rcpp_cerr_get();
1111
#endif
1212

13+
// disp_stars_cpp
14+
NumericVector disp_stars_cpp(NumericVector wave_log, NumericVector lum_log, NumericVector z_disp, NumericVector grid, NumericVector weights, double res);
15+
RcppExport SEXP _ProSpect_disp_stars_cpp(SEXP wave_logSEXP, SEXP lum_logSEXP, SEXP z_dispSEXP, SEXP gridSEXP, SEXP weightsSEXP, SEXP resSEXP) {
16+
BEGIN_RCPP
17+
Rcpp::RObject rcpp_result_gen;
18+
Rcpp::RNGScope rcpp_rngScope_gen;
19+
Rcpp::traits::input_parameter< NumericVector >::type wave_log(wave_logSEXP);
20+
Rcpp::traits::input_parameter< NumericVector >::type lum_log(lum_logSEXP);
21+
Rcpp::traits::input_parameter< NumericVector >::type z_disp(z_dispSEXP);
22+
Rcpp::traits::input_parameter< NumericVector >::type grid(gridSEXP);
23+
Rcpp::traits::input_parameter< NumericVector >::type weights(weightsSEXP);
24+
Rcpp::traits::input_parameter< double >::type res(resSEXP);
25+
rcpp_result_gen = Rcpp::wrap(disp_stars_cpp(wave_log, lum_log, z_disp, grid, weights, res));
26+
return rcpp_result_gen;
27+
END_RCPP
28+
}
1329
// colSums_wt_cpp
1430
NumericVector colSums_wt_cpp(NumericMatrix mat, NumericVector vec_wt);
1531
RcppExport SEXP _ProSpect_colSums_wt_cpp(SEXP matSEXP, SEXP vec_wtSEXP) {
@@ -101,6 +117,7 @@ END_RCPP
101117
}
102118

103119
static const R_CallMethodDef CallEntries[] = {
120+
{"_ProSpect_disp_stars_cpp", (DL_FUNC) &_ProSpect_disp_stars_cpp, 6},
104121
{"_ProSpect_colSums_wt_cpp", (DL_FUNC) &_ProSpect_colSums_wt_cpp, 2},
105122
{"_ProSpect_mat_vec_mult_col", (DL_FUNC) &_ProSpect_mat_vec_mult_col, 3},
106123
{"_ProSpect_mat_vec_mult_row", (DL_FUNC) &_ProSpect_mat_vec_mult_row, 3},

src/disp_stars.cpp

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
#include <Rcpp.h>
2+
#include <cmath>
3+
using namespace Rcpp;
4+
5+
// disp_stars_cpp: Rcpp implementation of the stellar dispersion convolution.
6+
//
7+
// Equivalent to the R loop in SFHfunc()/SFHburst() when disp_stars=TRUE:
8+
// for(i in seq_along(grid)){
9+
// z_seq_log = log10(1 + grid[i]*z_disp)
10+
// new_wave_log = wave_log + z_seq_log
11+
// new_lum = lum_log - z_seq_log
12+
// new_lum = 10^approx(x=new_wave_log, y=new_lum, xout=wave_log,
13+
// rule=2, yleft=new_lum[1], yright=new_lum[n])$y
14+
// lum_conv = lum_conv + weights[i] * new_lum
15+
// }
16+
// return(lum_conv * res)
17+
//
18+
// Inputs:
19+
// wave_log - log10 wavelengths (monotone increasing, length n)
20+
// lum_log - log10 luminosity (length n)
21+
// z_disp - wavelength-dependent velocity dispersion in z units (length n)
22+
// grid - Gaussian quadrature grid points, e.g. seq(-range, range, by=res)
23+
// weights - dnorm(grid)
24+
// res - grid spacing (multiplied into final result)
25+
//
26+
// [[Rcpp::export(".disp_stars_cpp")]]
27+
NumericVector disp_stars_cpp(NumericVector wave_log,
28+
NumericVector lum_log,
29+
NumericVector z_disp,
30+
NumericVector grid,
31+
NumericVector weights,
32+
double res) {
33+
int n = wave_log.size();
34+
int ng = grid.size();
35+
36+
if (lum_log.size() != n) stop("lum_log must have same length as wave_log");
37+
if (z_disp.size() != n) stop("z_disp must have same length as wave_log");
38+
if (weights.size() != ng) stop("weights must have same length as grid");
39+
40+
NumericVector out(n, 0.0);
41+
// Reusable buffers for the shifted wavelength / luminosity grids
42+
NumericVector x_src(n);
43+
NumericVector y_src(n);
44+
45+
for (int gi = 0; gi < ng; gi++) {
46+
double g = grid[gi];
47+
double w = weights[gi];
48+
49+
// Build shifted grids: x_src = wave_log + log10(1 + g*z_disp)
50+
// y_src = lum_log - log10(1 + g*z_disp)
51+
for (int j = 0; j < n; j++) {
52+
double val = 1.0 + g * z_disp[j];
53+
if (val <= 0.0) {
54+
stop("1 + grid[i]*z_disp <= 0: cannot take log10. "
55+
"Reduce veldisp or the grid range.");
56+
}
57+
double lv = std::log10(val);
58+
x_src[j] = wave_log[j] + lv;
59+
y_src[j] = lum_log[j] - lv;
60+
}
61+
62+
// Linear interpolation of y_src(x_src) onto wave_log.
63+
// Uses rule=2: extrapolate with boundary values y_src[0] / y_src[n-1].
64+
// Two-pointer sweep works because wave_log is monotone increasing and
65+
// x_src remains monotone for the small shifts typical of veldisp.
66+
// lo is intentionally reset to 0 for each grid point gi: since wave_log
67+
// is scanned left-to-right each iteration, lo must start fresh.
68+
int lo = 0;
69+
for (int j = 0; j < n; j++) {
70+
double xq = wave_log[j];
71+
double y_interp;
72+
73+
if (xq <= x_src[0]) {
74+
y_interp = y_src[0];
75+
} else if (xq >= x_src[n - 1]) {
76+
y_interp = y_src[n - 1];
77+
} else {
78+
// Advance lo so that x_src[lo] <= xq < x_src[lo+1]
79+
while (lo < n - 2 && x_src[lo + 1] <= xq) {
80+
lo++;
81+
}
82+
double t = (xq - x_src[lo]) / (x_src[lo + 1] - x_src[lo]);
83+
y_interp = y_src[lo] + t * (y_src[lo + 1] - y_src[lo]);
84+
}
85+
86+
out[j] += w * std::pow(10.0, y_interp);
87+
}
88+
}
89+
90+
// Scale by the grid spacing
91+
for (int j = 0; j < n; j++) {
92+
out[j] *= res;
93+
}
94+
95+
return out;
96+
}

tests/test_disp_stars_cpp.R

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
# Regression test: verify that .disp_stars_cpp() matches the original R loop
2+
# to within a tight numerical tolerance (relative error < 1e-10).
3+
#
4+
# Run via: Rscript tests/test_disp_stars_cpp.R
5+
# (after the package has been installed or loaded with devtools::load_all())
6+
7+
library(ProSpect)
8+
9+
# R reference implementation of the disp_stars loop (identical to pre-Rcpp SFH.R)
10+
disp_stars_R <- function(wave_lum_log, lum_log, z_disp, grid, weights, res) {
11+
lum_conv <- numeric(length(lum_log))
12+
for (i in seq_along(grid)) {
13+
z_seq_log <- log10(1 + grid[i] * z_disp)
14+
new_wave <- wave_lum_log + z_seq_log
15+
new_lum <- lum_log - z_seq_log
16+
new_lum <- 10^approx(x = new_wave, y = new_lum,
17+
xout = wave_lum_log,
18+
rule = 2,
19+
yleft = new_lum[1],
20+
yright = new_lum[length(new_lum)])$y
21+
lum_conv <- lum_conv + weights[i] * new_lum
22+
}
23+
lum_conv * res
24+
}
25+
26+
set.seed(42)
27+
n <- 500
28+
wave <- seq(3000, 10000, length.out = n)
29+
wave_log <- log10(wave)
30+
lum <- runif(n, 1e3, 1e8)
31+
lum_log <- log10(lum)
32+
33+
# Several test cases with different dispersion and grid settings
34+
test_cases <- list(
35+
list(veldisp = 100, range = 3, res = 0.1),
36+
list(veldisp = 250, range = 3, res = 0.1),
37+
list(veldisp = 50, range = 5, res = 0.05),
38+
list(veldisp = 300, range = 3, res = 0.2)
39+
)
40+
41+
c_to_mps <- 299792458 # speed of light in m/s (matches ProSpect:::.c_to_mps)
42+
43+
all_passed <- TRUE
44+
45+
for (tc in test_cases) {
46+
veldisp <- tc$veldisp
47+
range <- tc$range
48+
res <- tc$res
49+
50+
grid <- seq(-range, range, by = res)
51+
weights <- dnorm(grid)
52+
z_disp <- rep(veldisp / (c_to_mps / 1000), n) # constant z_disp for simplicity
53+
54+
ref <- disp_stars_R(wave_log, lum_log, z_disp, grid, weights, res)
55+
got <- .disp_stars_cpp(wave_log, lum_log, z_disp, grid, weights, res)
56+
57+
rel_err <- max(abs(ref - got) / pmax(abs(ref), 1e-300))
58+
59+
if (rel_err > 1e-10) {
60+
cat(sprintf("FAIL: veldisp=%g range=%g res=%g => max rel error = %e\n",
61+
veldisp, range, res, rel_err))
62+
all_passed <- FALSE
63+
} else {
64+
cat(sprintf("PASS: veldisp=%g range=%g res=%g => max rel error = %e\n",
65+
veldisp, range, res, rel_err))
66+
}
67+
}
68+
69+
if (!all_passed) stop("One or more disp_stars_cpp regression tests failed!")
70+
cat("All disp_stars_cpp regression tests passed.\n")

0 commit comments

Comments
 (0)