-
Notifications
You must be signed in to change notification settings - Fork 17
Add fast-path to rtnorm #73
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Conversation
|
Twice faster for the case when all arguments have length 1. |
|
Code to benchmark (note that the benchmark also checks that the output is identical): library(msm)
rtnorm_new <- function (n, mean = 0, sd = 1, lower = -Inf, upper = Inf) {
if (length(n) > 1)
n <- length(n)
# sd <- vapply(sd, max, numeric(1), 1e-15, USE.NAMES=FALSE) # Small values of sd break the function.
# Fast-path for frequent case.
if (length(mean) == 1L && length(sd) == 1L && length(lower) == 1L && length(upper) == 1L) {
lower <- (lower - mean) / sd ## Algorithm works on mean 0, sd 1 scale
upper <- (upper - mean) / sd
nas <- is.na(mean) | is.na(sd) | is.na(lower) | is.na(upper)
if (any(nas)) warning("NAs produced")
alg <- if ((lower > upper) && nas) -1L # return NaN
else if ((lower < 0 && upper == Inf) ||
(lower == -Inf && upper > 0) ||
(is.finite(lower) && is.finite(upper) && (lower < 0) && (upper > 0) && (upper - lower > sqrt(2*pi))))
0L # standard "simulate from normal and reject if outside limits" method. Use if bounds are wide.
else if (lower >= 0 && (upper > lower + 2*sqrt(exp(1)) /
(lower + sqrt(lower^2 + 4)) * exp((lower*2 - lower*sqrt(lower^2 + 4)) / 4)))
1L # rejection sampling with exponential proposal. Use if lower >> mean
else if (upper <= 0 && (-lower > -upper + 2*sqrt(exp(1)) /
(-upper + sqrt(upper^2 + 4)) * exp((upper*2 - -upper*sqrt(upper^2 + 4)) / 4)))
2L # rejection sampling with exponential proposal. Use if upper << mean.
else 3L # rejection sampling with uniform proposal. Use if bounds are narrow and central.
ret <- rep_len(NaN, n)
if (alg == -1L) {
return(ret)
} else if (alg == 0L) {
ind.no <- seq_len(n)
while (length(ind.no) > 0) {
y <- rnorm(length(ind.no))
done <- which(y >= lower & y <= upper)
ret[ind.no[done]] <- y[done]
ind.no <- setdiff(ind.no, ind.no[done])
}
} else if (alg == 1L) {
ind.expl <- seq_len(n)
a <- (lower + sqrt(lower^2 + 4)) / 2
while (length(ind.expl) > 0) {
z <- rexp(length(ind.expl), a) + lower
u <- runif(length(ind.expl))
done <- which((u <= exp(-(z - a)^2 / 2)) & (z <= upper))
ret[ind.expl[done]] <- z[done]
ind.expl <- setdiff(ind.expl, ind.expl[done])
}
} else if (alg == 2L) {
ind.expu <- seq_len(n)
a <- (-upper + sqrt(upper^2 +4)) / 2
while (length(ind.expu) > 0) {
z <- rexp(length(ind.expu), a) - upper
u <- runif(length(ind.expu))
done <- which((u <= exp(-(z - a)^2 / 2)) & (z <= -lower))
ret[ind.expu[done]] <- -z[done]
ind.expu <- setdiff(ind.expu, ind.expu[done])
}
} else {
ind.u <- seq_len(n)
K <- if (lower > 0) lower^2 else if (upper < 0) upper^2 else 0
while (length(ind.u) > 0) {
z <- runif(length(ind.u), lower, upper)
rho <- exp((K - z^2) / 2)
u <- runif(length(ind.u))
done <- which(u <= rho)
ret[ind.u[done]] <- z[done]
ind.u <- setdiff(ind.u, ind.u[done])
}
}
} else {
mean <- rep(mean, length=n)
sd <- rep(sd, length=n)
lower <- rep(lower, length=n)
upper <- rep(upper, length=n)
lower <- (lower - mean) / sd ## Algorithm works on mean 0, sd 1 scale
upper <- (upper - mean) / sd
ind <- seq(length.out=n)
ret <- numeric(n)
nas <- is.na(mean) | is.na(sd) | is.na(lower) | is.na(upper)
if (any(nas)) warning("NAs produced")
## Different algorithms depending on where upper/lower limits lie.
alg <- ifelse(
((lower > upper) | nas),
-1,# return NaN
ifelse(
((lower < 0 & upper == Inf) |
(lower == -Inf & upper > 0) |
(is.finite(lower) & is.finite(upper) & (lower < 0) & (upper > 0) & (upper-lower > sqrt(2*pi)))
),
0, # standard "simulate from normal and reject if outside limits" method. Use if bounds are wide.
ifelse(
(lower >= 0 & (upper > lower + 2*sqrt(exp(1)) /
(lower + sqrt(lower^2 + 4)) * exp((lower*2 - lower*sqrt(lower^2 + 4)) / 4))),
1, # rejection sampling with exponential proposal. Use if lower >> mean
ifelse(upper <= 0 & (-lower > -upper + 2*sqrt(exp(1)) /
(-upper + sqrt(upper^2 + 4)) * exp((upper*2 - -upper*sqrt(upper^2 + 4)) / 4)),
2, # rejection sampling with exponential proposal. Use if upper << mean.
3)))) # rejection sampling with uniform proposal. Use if bounds are narrow and central.
ind.nan <- ind[alg==-1]; ind.no <- ind[alg==0]; ind.expl <- ind[alg==1]; ind.expu <- ind[alg==2]; ind.u <- ind[alg==3]
ret[ind.nan] <- NaN
while (length(ind.no) > 0) {
y <- rnorm(length(ind.no))
done <- which(y >= lower[ind.no] & y <= upper[ind.no])
ret[ind.no[done]] <- y[done]
ind.no <- setdiff(ind.no, ind.no[done])
}
stopifnot(length(ind.no) == 0)
while (length(ind.expl) > 0) {
a <- (lower[ind.expl] + sqrt(lower[ind.expl]^2 + 4)) / 2
z <- rexp(length(ind.expl), a) + lower[ind.expl]
u <- runif(length(ind.expl))
done <- which((u <= exp(-(z - a)^2 / 2)) & (z <= upper[ind.expl]))
ret[ind.expl[done]] <- z[done]
ind.expl <- setdiff(ind.expl, ind.expl[done])
}
stopifnot(length(ind.expl) == 0)
while (length(ind.expu) > 0) {
a <- (-upper[ind.expu] + sqrt(upper[ind.expu]^2 +4)) / 2
z <- rexp(length(ind.expu), a) - upper[ind.expu]
u <- runif(length(ind.expu))
done <- which((u <= exp(-(z - a)^2 / 2)) & (z <= -lower[ind.expu]))
ret[ind.expu[done]] <- -z[done]
ind.expu <- setdiff(ind.expu, ind.expu[done])
}
stopifnot(length(ind.expu) == 0)
while (length(ind.u) > 0) {
z <- runif(length(ind.u), lower[ind.u], upper[ind.u])
rho <- ifelse(lower[ind.u] > 0,
exp((lower[ind.u]^2 - z^2) / 2), ifelse(upper[ind.u] < 0,
exp((upper[ind.u]^2 - z^2) / 2),
exp(-z^2/2)))
u <- runif(length(ind.u))
done <- which(u <= rho)
ret[ind.u[done]] <- z[done]
ind.u <- setdiff(ind.u, ind.u[done])
}
stopifnot(length(ind.u) == 0)
}
ret*sd + mean
}
test <- function(f) {
set.seed(42)
unlist(mapply(f, n = sample(c(1,1,10,100), 1000, replace=TRUE), mean = 2*runif(1000), sd = 0.001+runif(1000), lower=0, upper=1))
}
library(bench)
(x <- bench::mark(
test(rtnorm_new),
test(msm::rtnorm),
check=TRUE, min_time=5))
plot(x) |
|
Hi Manuel - Thanks for this work. I'm unsure about putting it in though, because it results in a lot of duplicated code, and it makes the resulting function very long. I was curious what exactly it is it that makes such a difference to the efficiency? Presumably something inside Also I get the impression that there are several other implementations of the truncated normal distribution in other R packages - I haven't investigated these, but I'm curious what they all do differently. If there are more advanced/efficient ones out there (C++-based implementations perhaps??), then there probably isn't much point refining the one in |
|
There is this one: https://github.com/olafmersmann/truncnorm |

No description provided.