Skip to content

Conversation

@MLopez-Ibanez
Copy link

No description provided.

@MLopez-Ibanez
Copy link
Author

Twice as fast in the case where all arguments have length 1.

image

  expression     min  median `itr/sec` mem_alloc `gc/sec` n_itr  n_gc total_time
  <bch:expr> <bch:t> <bch:t>     <dbl> <bch:byt>    <dbl> <int> <dbl>   <bch:tm>
1 test(rtno…  43.8ms  47.3ms     18.6     13.2MB     6.75    58    21      3.11s
2 test(msm:… 107.3ms 114.4ms      8.16    23.2MB    13.1     15    24      1.84s

@MLopez-Ibanez
Copy link
Author

Code to benchmark (note that the benchmark also checks that the output is identical):

library(msm)

# Draw random values from a normal distribution truncated to [lower, upper].
#
# Arguments:
#   n             Number of draws. If `length(n) > 1`, `length(n)` draws are made.
#   mean, sd      Mean and standard deviation of the untruncated normal.
#   lower, upper  Truncation limits, on the original (unstandardised) scale.
#
# Returns a numeric vector of length `n`. Elements for which any parameter is
# NA, or for which the standardised `lower` exceeds `upper`, are NaN/NA (a
# warning "NAs produced" is raised when NA parameters are present).
#
# When mean, sd, lower and upper all have length 1 a scalar fast path is used
# that avoids the vectorised `ifelse()` machinery. It consumes random numbers
# in the same order as the general path.
rtnorm_new <- function (n, mean = 0, sd = 1, lower = -Inf, upper = Inf) {
    if (length(n) > 1)
        n <- length(n)
    # sd <- vapply(sd, max, numeric(1), 1e-15, USE.NAMES=FALSE) # Small values of sd break the function.
    # Fast-path for frequent case.
    if (length(mean) == 1L && length(sd) == 1L && length(lower) == 1L && length(upper) == 1L) {
        lower <- (lower - mean) / sd ## Algorithm works on mean 0, sd 1 scale
        upper <- (upper - mean) / sd
        nas <- is.na(mean) | is.na(sd) | is.na(lower) | is.na(upper)
        if (any(nas)) warning("NAs produced")
        # `nas` is tested first so that the comparison (which would be NA) is
        # never evaluated when a parameter is missing. This mirrors the
        # `(lower > upper) | nas` rule of the general path below.
        alg <- if (nas || (lower > upper)) -1L # return NaN
        else if ((lower < 0 && upper == Inf) ||
                 (lower == -Inf && upper > 0) ||
                 (is.finite(lower) && is.finite(upper) && (lower < 0) && (upper > 0) && (upper - lower > sqrt(2*pi))))
            0L # standard "simulate from normal and reject if outside limits" method. Use if bounds are wide.
        else if (lower >= 0 && (upper > lower + 2*sqrt(exp(1)) /
                                (lower + sqrt(lower^2 + 4)) * exp((lower*2 - lower*sqrt(lower^2 + 4)) / 4)))
            1L # rejection sampling with exponential proposal. Use if lower >> mean
        else if (upper <= 0 && (-lower > -upper + 2*sqrt(exp(1)) /
                                (-upper + sqrt(upper^2 + 4)) * exp((upper*2 - -upper*sqrt(upper^2 + 4)) / 4)))
            2L # rejection sampling with exponential proposal. Use if upper << mean.
        else 3L # rejection sampling with uniform proposal. Use if bounds are narrow and central.

        # Pre-filled with NaN: for alg == -1L no sampler runs and the NaN
        # values flow through the final rescaling, as in the general path.
        ret <- rep_len(NaN, n)
        if (alg == 0L) {
            ind.no <- seq_len(n)
            while (length(ind.no) > 0) {
                y <- rnorm(length(ind.no))
                done <- which(y >= lower & y <= upper)
                ret[ind.no[done]] <- y[done]
                # setdiff() rather than ind.no[-done]: the latter would drop
                # everything when `done` is empty.
                ind.no <- setdiff(ind.no, ind.no[done])
            }
        } else if (alg == 1L) {
            ind.expl <- seq_len(n)
            a <- (lower + sqrt(lower^2 + 4)) / 2
            while (length(ind.expl) > 0) {
                z <- rexp(length(ind.expl), a) + lower
                u <- runif(length(ind.expl))
                done <- which((u <= exp(-(z - a)^2 / 2)) & (z <= upper))
                ret[ind.expl[done]] <- z[done]
                ind.expl <- setdiff(ind.expl, ind.expl[done])
            }
        } else if (alg == 2L) {
            # Mirror image of alg 1: sample on the reflected scale, then negate.
            ind.expu <- seq_len(n)
            a <- (-upper + sqrt(upper^2 +4)) / 2
            while (length(ind.expu) > 0) {
                z <- rexp(length(ind.expu), a) - upper
                u <- runif(length(ind.expu))
                done <- which((u <= exp(-(z - a)^2 / 2)) & (z <= -lower))
                ret[ind.expu[done]] <- -z[done]
                ind.expu <- setdiff(ind.expu, ind.expu[done])
            }
        } else if (alg == 3L) {
            ind.u <- seq_len(n)
            K <- if (lower > 0) lower^2 else if (upper < 0) upper^2 else 0
            while (length(ind.u) > 0) {
                z <- runif(length(ind.u), lower, upper)
                rho <- exp((K - z^2) / 2)
                u <- runif(length(ind.u))
                done <- which(u <= rho)
                ret[ind.u[done]] <- z[done]
                ind.u <- setdiff(ind.u, ind.u[done])
            }
        }
    } else {
        # General path: recycle every parameter to length n and choose an
        # algorithm per element.
        mean <- rep(mean, length.out = n)
        sd <- rep(sd, length.out = n)
        lower <- rep(lower, length.out = n)
        upper <- rep(upper, length.out = n)
        lower <- (lower - mean) / sd ## Algorithm works on mean 0, sd 1 scale
        upper <- (upper - mean) / sd
        ind <- seq_len(n)
        ret <- numeric(n)
        nas <- is.na(mean) | is.na(sd) | is.na(lower) | is.na(upper)
        if (any(nas)) warning("NAs produced")
        ## Different algorithms depending on where upper/lower limits lie.
        alg <- ifelse(
                      ((lower > upper) | nas),
                      -1,# return NaN
                      ifelse(
                             ((lower < 0 & upper == Inf) |
                              (lower == -Inf & upper > 0) |
                              (is.finite(lower) & is.finite(upper) & (lower < 0) & (upper > 0) & (upper-lower > sqrt(2*pi)))
                              ),
                             0, # standard "simulate from normal and reject if outside limits" method. Use if bounds are wide.
                             ifelse(
                                    (lower >= 0 & (upper > lower + 2*sqrt(exp(1)) /
                                     (lower + sqrt(lower^2 + 4)) * exp((lower*2 - lower*sqrt(lower^2 + 4)) / 4))),
                                    1, # rejection sampling with exponential proposal. Use if lower >> mean
                                    ifelse(upper <= 0 & (-lower > -upper + 2*sqrt(exp(1)) /
                                           (-upper + sqrt(upper^2 + 4)) * exp((upper*2 - -upper*sqrt(upper^2 + 4)) / 4)),
                                           2, # rejection sampling with exponential proposal. Use if upper << mean.
                                           3)))) # rejection sampling with uniform proposal. Use if bounds are narrow and central.

        ind.nan <- ind[alg==-1]; ind.no <- ind[alg==0]; ind.expl <- ind[alg==1]; ind.expu <- ind[alg==2]; ind.u <- ind[alg==3]
        ret[ind.nan] <- NaN
        # Each loop keeps resampling only the elements not yet accepted.
        while (length(ind.no) > 0) {
            y <- rnorm(length(ind.no))
            done <- which(y >= lower[ind.no] & y <= upper[ind.no])
            ret[ind.no[done]] <- y[done]
            ind.no <- setdiff(ind.no, ind.no[done])
        }
        while (length(ind.expl) > 0) {
            a <- (lower[ind.expl] + sqrt(lower[ind.expl]^2 + 4)) / 2
            z <- rexp(length(ind.expl), a) + lower[ind.expl]
            u <- runif(length(ind.expl))
            done <- which((u <= exp(-(z - a)^2 / 2)) & (z <= upper[ind.expl]))
            ret[ind.expl[done]] <- z[done]
            ind.expl <- setdiff(ind.expl, ind.expl[done])
        }
        while (length(ind.expu) > 0) {
            a <- (-upper[ind.expu] + sqrt(upper[ind.expu]^2 +4)) / 2
            z <- rexp(length(ind.expu), a) - upper[ind.expu]
            u <- runif(length(ind.expu))
            done <- which((u <= exp(-(z - a)^2 / 2)) & (z <= -lower[ind.expu]))
            ret[ind.expu[done]] <- -z[done]
            ind.expu <- setdiff(ind.expu, ind.expu[done])
        }
        while (length(ind.u) > 0) {
            z <- runif(length(ind.u), lower[ind.u], upper[ind.u])
            rho <- ifelse(lower[ind.u] > 0,
                          exp((lower[ind.u]^2 - z^2) / 2), ifelse(upper[ind.u] < 0,
                                                                exp((upper[ind.u]^2 - z^2) / 2),
                                                                exp(-z^2/2)))
            u <- runif(length(ind.u))
            done <- which(u <= rho)
            ret[ind.u[done]] <- z[done]
            ind.u <- setdiff(ind.u, ind.u[done])
        }
    }
    # Map the standardised draws back to the requested mean and sd.
    ret*sd + mean
}

# Run sampler `f` on a fixed, seeded set of 1000 parameter combinations and
# return all draws as one flat vector. `f` is called once per combination with
# the named arguments n, mean, sd, lower and upper.
test <- function(f) {
  set.seed(42)
  n_calls <- 1000
  # The three vectors are generated in this order so that the random number
  # stream matches a single inline mapply() call.
  sizes <- sample(c(1, 1, 10, 100), n_calls, replace = TRUE)
  means <- 2 * runif(n_calls)
  sds <- 0.001 + runif(n_calls)
  draws <- mapply(f, n = sizes, mean = means, sd = sds, lower = 0, upper = 1)
  unlist(draws)
}

# Benchmark the new implementation against msm::rtnorm on the same seeded
# workload produced by test().
library(bench)
# check=TRUE makes bench::mark() verify that both expressions return equal
# results; min_time=5 sets the minimum time (in seconds) spent on each
# expression. The outer parentheses print the result when run at top level.
(x <- bench::mark(
               test(rtnorm_new),
               test(msm::rtnorm),
               check=TRUE, min_time=5))
# Plot the timings of the two expressions.
plot(x)

@chjackson
Copy link
Owner

Hi Manuel - Thanks for this work. I'm unsure about putting it in though, because it results in a lot of duplicated code, and it makes the resulting function very long. I was curious what exactly it is that makes such a difference to the efficiency? Presumably something inside ifelse() is doing a lot of unnecessary work, and if so, is that a clue to making a version that would handle both cases more cleanly?

Also I get the impression that there are several other implementations of the truncated normal distribution in other R packages - I haven't investigated these, but I'm curious what they all do differently. If there are more advanced/efficient ones out there (C++-based implementations perhaps??), then there probably isn't much point refining the one in msm. After all it's not really anything to do with multistate models.

@MLopez-Ibanez
Copy link
Author

There is this one: https://github.com/olafmersmann/truncnorm
But I haven't checked it for correctness.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants