#' Sample from the ensemble posterior distribution
#'
#' Given model averaging weights (e.g., from Bayesian model averaging (BMA), pseudo-BMA, or stacking) and a matrix of
#' posterior samples from the candidate models, this function draws samples from the model-averaged posterior distribution.
#' Here, each "model" refers to a unique combination of an outcome model and its associated priors. Posterior draws
#' are randomly selected from the candidate models in proportion to their specified weights, producing samples from
#' the ensemble of posterior distributions.
#'
#' This function is typically used in combination with [compute.ensemble.weights()], which computes model averaging
#' weights using methods such as Bayesian model averaging (BMA), pseudo-BMA, pseudo-BMA with the Bayesian bootstrap,
#' or stacking. The input matrix of posterior samples should have one column per candidate model, with each column
#' containing posterior draws from that model.
#'
#' @export
#'
#' @param wts               a numeric vector of non-negative model averaging weights (e.g., from [compute.ensemble.weights()]).
#'                          The weights are passed to [base::sample()] as model selection probabilities, which rescales
#'                          them, so they need not sum exactly to one. The length of `wts` must match the number of columns
#'                          in `samples.mtx`.
#' @param samples.mtx       a matrix of posterior samples. Each column corresponds to samples from a different model, and each
#'                          row is one posterior draw (e.g., from Markov chain Monte Carlo (MCMC) sampling). All columns must
#'                          have the same number of samples.
#'
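#' @return
#'  The function returns a numeric vector of ensemble posterior draws, sampled proportionally to the provided model weights.
#'  The returned vector has the same length as the number of rows in `samples.mtx`. The number of draws taken from each
#'  model follows a multinomial distribution with probabilities proportional to `wts`, and the draws taken from a given
#'  model are a random subset (selected without replacement) of that model's column. The draws are grouped by model, in
#'  the order in which each model is first selected, rather than interleaved; shuffle the result (e.g., with `sample()`)
#'  if the order of the draws matters.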
#'
#' @seealso [compute.ensemble.weights()]
#'
#' @examples
#' if (instantiate::stan_cmdstan_exists()) {
#'   if(requireNamespace("survival")){
#'     library(survival)
#'     data(E1684)
#'     data(E1690)
#'     ## replace 0 failure times with 0.50 days
#'     E1684$failtime[E1684$failtime == 0] = 0.50/365.25
#'     E1690$failtime[E1690$failtime == 0] = 0.50/365.25
#'     E1684$cage = as.numeric(scale(E1684$age))
#'     E1690$cage = as.numeric(scale(E1690$age))
#'     data_list = list(currdata = E1690, histdata = E1684)
#'     nbreaks = 3
#'     probs   = 1:nbreaks / nbreaks
#'     breaks  = as.numeric(
#'       quantile(E1690[E1690$failcens==1, ]$failtime, probs = probs)
#'     )
#'     breaks  = c(0, breaks)
#'     breaks[length(breaks)] = max(10000, 1000 * breaks[length(breaks)])
#'     fit.pwe.pp = pwe.pp(
#'       formula = survival::Surv(failtime, failcens) ~ treatment + sex + cage + node_bin,
#'       data.list = data_list,
#'       breaks = breaks,
#'       a0 = 0.5,
#'       get.loglik = TRUE,
#'       chains = 1, iter_warmup = 1000, iter_sampling = 2000
#'     )
#'     fit.pwe.post = pwe.post(
#'       formula = survival::Surv(failtime, failcens) ~ treatment + sex + cage + node_bin,
#'       data.list = data_list,
#'       breaks = breaks,
#'       get.loglik = TRUE,
#'       chains = 1, iter_warmup = 1000, iter_sampling = 2000
#'     )
#'     fit.pwe.commensurate = pwe.commensurate(
#'       formula = survival::Surv(failtime, failcens) ~ treatment + sex + cage + node_bin,
#'       data.list = data_list,
#'       breaks = breaks,
#'       p.spike = 0.1,
#'       get.loglik = TRUE,
#'       chains = 1, iter_warmup = 1000, iter_sampling = 2000
#'     )
#'     fit.list = list(fit.pwe.post, fit.pwe.pp, fit.pwe.commensurate)
#'     samples.mtx = do.call(
#'      cbind, lapply(fit.list, function(d){
#'        as.numeric( d[["treatment"]] )
#'      })
#'     )
#'     wts = compute.ensemble.weights(
#'       fit.list = fit.list,
#'       type = "pseudobma+"
#'     )$weights
#'     sample.ensemble(
#'      wts = wts, samples.mtx = samples.mtx
#'     )
#'   }
#' }
sample.ensemble = function(
    wts,
    samples.mtx
) {
  ## Coerce to a matrix so that a plain vector (a single model) or a data frame
  ## is handled exactly like a matrix with one column per model.
  samples.mtx = as.matrix(samples.mtx)
  wts         = as.numeric(wts)

  if( length(wts) == 0 ){
    stop("wts must contain at least one model weight.")
  }
  if( ncol(samples.mtx) != length(wts) ){
    stop("The number of columns in samples.mtx must match the length of wts.")
  }

  n = nrow(samples.mtx)
  if( n == 0 ){
    ## Nothing to draw: return an empty vector of the matrix's storage type
    ## (rather than the NULL that unlist(list()) would give).
    return( as.vector(samples.mtx) )
  }

  ## Shuffle the rows once. `drop = FALSE` keeps a one-column input a matrix so
  ## the column indexing below still works. Taking the first `nsample` shuffled
  ## rows of a column then yields a random subset of that model's draws,
  ## selected without replacement.
  samples.mtx.permuted = samples.mtx[sample.int(n), , drop = FALSE]

  ## Draw n i.i.d. model indicators (c0) from a categorical distribution with
  ## probabilities proportional to `wts`; the per-model counts are therefore
  ## multinomial. (sample.int consumes the RNG exactly like the equivalent
  ## sample(seq_len(.)) call, so seeded results are reproducible.)
  c0     = sample.int(length(wts), size = n, replace = TRUE, prob = wts)
  models = unique(c0)

  ## Collect the draws model by model, in order of first selection; the output
  ## is grouped by model rather than interleaved.
  res.samples = lapply(models, function(j){
    nsample = sum(c0 == j)
    samples.mtx.permuted[seq_len(nsample), j]
  })
  unlist(res.samples)
}
