Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

subsampling LOO estimates with diff-est-srs-wor start #496

Open
wants to merge 128 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
128 commits
Select commit Hold shift + click to select a range
bba6adf
subsampling LOO estimates with diff-est-srs-wor start
avehtari Apr 10, 2024
781e331
put back unnecessarily removed weights back in kfold
avehtari Apr 10, 2024
d50f4bd
Revert "put back unnecessarily removed weights back in kfold"
avehtari Apr 10, 2024
4de50b7
subsampling LOO for acc and pctcor
avehtari Apr 10, 2024
c50bdbf
ignore nloo if validate_search=FALSE
avehtari Apr 10, 2024
649f0ad
fix tests
avehtari Apr 11, 2024
aba8670
fix mse interval for delta=TRUE
avehtari Apr 12, 2024
81b5dc3
don't stop due to repeated arguments
avehtari Apr 12, 2024
0ed8391
normal approximation for mse, rmse, and R2
avehtari Apr 15, 2024
a36a3d8
rename internal function select -> .select
avehtari Apr 16, 2024
2c846a3
with delta and mse/rmse/R2/acc/pctcorr/auc, plot values in orig scale
avehtari Apr 16, 2024
fbda70e
don't warn about subsampling
avehtari Apr 16, 2024
1fa7fcd
improve messages
avehtari Apr 17, 2024
7e7fc7a
if available, use progressr for parallel progress bar
avehtari Apr 17, 2024
09223c6
verbosity improvements
avehtari Apr 25, 2024
45e22a6
fix
avehtari Apr 25, 2024
1cb78fc
use do_call instead of do.call
avehtari Apr 25, 2024
98757a9
add progress and progressr to Suggests
avehtari Apr 25, 2024
bf5af29
Merge branch 'master' into fix-subsampling
avehtari Jun 11, 2024
01f3bed
remove unneeded code
avehtari Jun 27, 2024
a5c5103
remove unnecessary sum
avehtari Jun 27, 2024
a223725
revert the addition of correct_baseline
avehtari Jun 28, 2024
24fc370
remove unneeded code
avehtari Jun 28, 2024
36f3543
document deltas=TRUE change
avehtari Jun 28, 2024
eeef49a
wcv -> wobs in summary_funs
avehtari Jun 28, 2024
1fc4669
newline in startup message to make it more readable
avehtari Jun 28, 2024
9139ce0
Merge remote-tracking branch 'upstream/master' into fix-subsampling
fweber144 Jun 30, 2024
10ab731
rename remaining occurrences of `wcv` to `wobs`
fweber144 Jun 30, 2024
c23fa78
re-add a comment
fweber144 Jun 30, 2024
b81e401
Merge remote-tracking branch 'upstream/master' into fix-subsampling
fweber144 Jul 4, 2024
71a2198
progressr: remove code that is part of the end-user's API (see
fweber144 Jul 11, 2024
b41ad1e
use `use_progressr` consistently
fweber144 Jul 11, 2024
81b4bf0
package `progress` is no longer needed in the "Suggests" dependencies
fweber144 Jul 11, 2024
497a245
add function `get_use_progressr()` to avoid redundancies;
fweber144 Jul 11, 2024
5ac4041
rename `p` to `progressor_obj` to identify it more clearly and
fweber144 Jul 11, 2024
825fc3d
use argument `steps` of `progressr::progressor()` explicitly and
fweber144 Jul 11, 2024
1655c84
remove unnecessary `""` in the `progressor_obj()` call
fweber144 Jul 11, 2024
1633137
use a simpler solution for identifying whether `progressr` should be …
fweber144 Jul 11, 2024
0996e7c
add the possibility to use `progressr` at the remaining occurrences o…
fweber144 Jul 11, 2024
8488e3e
fix a bug (`could not find function "do_call"`) when using the `doFut…
fweber144 Jul 12, 2024
118838c
remove `.select <- .select` (the issue does not occur when installing
fweber144 Jul 15, 2024
d663df8
fix a bug when checking arguments in `cv_varsel.vsel()`
fweber144 Jul 17, 2024
b6fe949
minor cleaning for consistency
fweber144 Jul 17, 2024
d9fdbaf
don't include the argument content in the message as the argument con…
fweber144 Jul 17, 2024
5b2f283
add functionality for option deltas='mixed'
avehtari Aug 11, 2024
17be148
remove option `baseline = "best"`
avehtari Aug 11, 2024
6b6520a
attempt to fix vsel.summary
avehtari Aug 11, 2024
fc8c665
move code for deltas='mixed' to plot.vsel
avehtari Aug 12, 2024
99e9c79
fixes
avehtari Aug 13, 2024
cd4b248
docs: fix minor typos
fweber144 Jul 18, 2024
7176dfe
avoid `object` within `cv_varsel.refmodel()` (for consistency; I don'…
fweber144 Jul 18, 2024
4109aae
fix a verbose message (at `?projpred::cv_varsel`, the documentation for
fweber144 Jul 18, 2024
7616128
mention thinning in the verbose message which gives information about
fweber144 Jul 18, 2024
48a6af1
minor cleaning
fweber144 Aug 18, 2024
c0593c6
fix usage of argument `summaries_fast` (at that place, `sel_cv$summar…
fweber144 Aug 18, 2024
0c85eb6
use argument `summaries_fast` as it was probably intended to
fweber144 Aug 18, 2024
4f893f9
fixup! use argument `summaries_fast` as it was probably intended to
fweber144 Aug 18, 2024
d40acd8
fixup! fixup! use argument `summaries_fast` as it was probably intend…
fweber144 Aug 18, 2024
d1d37bd
fix input for argument `search_path_fulldata` when running fast LOO-C…
fweber144 Aug 18, 2024
b9f8368
for argument `verbose`, default to a new global option:
fweber144 Aug 19, 2024
a9ce55f
argument `summaries_fast` should not change either (when calling `cv_…
fweber144 Aug 20, 2024
23b4a2a
remove unused object `n_arg_nms_internal_used`
fweber144 Aug 20, 2024
cc26cf7
define `arg_nms_internal_used` more straightforwardly
fweber144 Aug 20, 2024
66f2e34
minor enhancements
fweber144 Aug 20, 2024
e61f916
fix a verbose message (at `?projpred::cv_varsel`, the documentation for
fweber144 Aug 20, 2024
750c81a
fix a message when using standard importance sampling (SIS)
fweber144 Aug 20, 2024
f84fe55
remove fragment `verb_txt_start <-`
fweber144 Aug 22, 2024
8dde0bc
fix verbose message
fweber144 Aug 22, 2024
9c4d1a4
docs: abbreviate the performance statistics appropriately
fweber144 Aug 23, 2024
ea18e9a
UNFINISHED: move out the new "mixed deltas" variant of `plot.vsel()`,…
fweber144 Aug 23, 2024
aea7f08
Revert "UNFINISHED: move out the new "mixed deltas" variant of `plot.…
fweber144 Aug 23, 2024
867f29f
in `.onAttach()`, keep the temporary "NOTE" in separate lines (to
fweber144 Aug 23, 2024
2d9a652
add comments in `summary_funs.R`
fweber144 Aug 23, 2024
35cc542
simplify `summaries_fast_sub <- varsel$summaries_fast$sub` and `summa…
fweber144 Aug 23, 2024
19948e3
`loo_inds` as stored in `vsel` objects was unused so far
fweber144 Aug 23, 2024
f26e4fb
in `get_stat()`, the `is.null(summaries_fast)` checks are not necessary
fweber144 Aug 23, 2024
319c2b7
avoid object name `n` at more places
fweber144 Aug 23, 2024
dfcc58c
the definition of `loo_ref_oscale` does not make sense to be placed a…
fweber144 Aug 23, 2024
f53bc48
simplify an SRS-WOR `value` computation (if `mu_baseline` is `NULL`, …
fweber144 Aug 23, 2024
59afc97
simplify initialization of `est_list`
fweber144 Aug 23, 2024
2727cff
avoid redundant computations by moving `sqrt(srs_diffe$v_y_hat + srs_…
fweber144 Aug 23, 2024
52fd9c3
add an early error for `!validate_search && nloo < refmodel[["nobs"]]`
fweber144 Aug 23, 2024
81ec495
add a comment and a check in `loo_varsel()` for `!validate_search && …
fweber144 Aug 23, 2024
348827c
Revert "avoid redundant computations by moving `sqrt(srs_diffe$v_y_ha…
fweber144 Aug 26, 2024
f4f9760
simplify definitions of `mu_baseline` (possible because
fweber144 Aug 26, 2024
519dac2
fixup! `loo_inds` as stored in `vsel` objects was unused so far
fweber144 Aug 26, 2024
a7458b9
move out the new "mixed deltas" variant of `plot.vsel()`, the
fweber144 Aug 23, 2024
43878e1
use a consistent order of the `if` cases differentiating between
fweber144 Aug 29, 2024
925f2cd
remove unused `var_mse_e` definition
fweber144 Aug 29, 2024
040d05e
there was only one use of `var_mse_e` and since `value_se`
fweber144 Aug 29, 2024
0d73c8e
remove unused `mu_baseline` (only unused in case of
fweber144 Aug 29, 2024
288d948
Merge remote-tracking branch 'upstream/master' into fix-subsampling
fweber144 Sep 15, 2024
6a85b41
re-document
fweber144 Sep 18, 2024
ef33da6
add a placeholder for the documentation of argument `summaries_fast` …
fweber144 Sep 18, 2024
cdaf384
avoid partial argument matching of 'w' to 'wobs' in `.srs_diff_est_w(…
fweber144 Sep 18, 2024
8aacd5d
fixup! remove unused `mu_baseline` (only unused in case of
fweber144 Sep 18, 2024
9e79883
Merge remote-tracking branch 'upstream/master' into fix-subsampling
fweber144 Sep 25, 2024
1127412
`vsel_obj$nloo` can be `NULL` (for `vsel_obj` created by `varsel()`), so
fweber144 Sep 25, 2024
c331cb4
fix the `get_stat()` call for the reference model statistics (`loo_inds`
fweber144 Sep 25, 2024
8389d39
Tests: Subsampled PSIS-LOO-CV is not supported for `validate_search =…
fweber144 Sep 25, 2024
a246dd3
fix the `stat %in% c("acc", "pctcorr", "auc")` case in `get_stat()`
fweber144 Sep 26, 2024
2360630
fix `.tabulate_stats()` (`catmaxprb()` also needs to be
fweber144 Sep 26, 2024
7469aea
fix `.tabulate_stats()` (several steps in case of the latent projection
fweber144 Sep 26, 2024
0dcfcf2
Since `summaries_fast` is created by a call to `loo_varsel()` with
fweber144 Oct 9, 2024
6151b0c
Revert changes that are unrelated to subsampled LOO-CV (to find out
fweber144 Oct 16, 2024
360354a
Adapt the existing tests to work with the new implementation of subsa…
fweber144 Sep 26, 2024
0bd3830
fix the early check for all-`NA`s in `get_stat()`
fweber144 Nov 4, 2024
3923005
Revert "fix the early check for all-`NA`s in `get_stat()`"
fweber144 Nov 4, 2024
da4272a
avoid the early check for all-`NA`s in `get_stat()`
fweber144 Nov 4, 2024
fbabed5
replace `n_full <- sum(!is.na(mu))` with `n_full <- length(mu)` because
fweber144 Nov 17, 2024
8666643
tests: `NA`s in summaries should be `NA_real_`s now
fweber144 Nov 17, 2024
0a6b8ba
divide by `(n_full - 1)` instead of `n_full` where necessary, see
fweber144 Dec 22, 2024
e8f17e8
add test `".srs_diff_est_w() works as expected"` (copied from 'loo', see
fweber144 Dec 22, 2024
df17c08
Merge remote-tracking branch 'upstream/master' into fix-subsampling
fweber144 Dec 29, 2024
caeab04
Merge remote-tracking branch 'upstream/master' into fix-subsampling
fweber144 Dec 30, 2024
f173ef4
explain "fast LOO" and "full LOO" in the docs, see
fweber144 Jan 4, 2025
8e563b4
make computation of `var_e_i` in `.srs_diff_est_w()` numerically more…
fweber144 Jan 5, 2025
de1cf4d
fix a comment, see
fweber144 Jan 5, 2025
cac6dee
minor cleaning
fweber144 Jan 6, 2025
1184a4e
fix a comment, see
fweber144 Jan 6, 2025
a6b7988
replace `mean(y)` with `mean(wobs * y)`, see
fweber144 Jan 6, 2025
e300620
avoid redundant computations by introducing object `y_mean_w`
fweber144 Jan 6, 2025
8d25990
use `log1p()` for numerical stability (in `get_stat()`), see
fweber144 Jan 6, 2025
7e61368
drop quotes from "exact" in a comment, see
fweber144 Jan 12, 2025
9260fd0
fix docs for RMSE and R2
fweber144 Jan 12, 2025
1e3544a
add `stat = "R2"` to the tests;
fweber144 Jan 12, 2025
f132e70
set a negative squared standard error which is numerically equal to zero
fweber144 Jan 16, 2025
f2992dd
fix first-order Taylor approximation of the variance (delta method)
fweber144 Jan 12, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
152 changes: 57 additions & 95 deletions R/cv_varsel.R
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,16 @@
#' is performed, which avoids refitting the reference model `nloo` times (in
#' contrast to a standard LOO-CV). In the `"kfold"` case, a \eqn{K}-fold-CV is
#' performed. See also section "Note" below.
#' @param nloo **Caution:** Still experimental. Only relevant if `cv_method =
#' "LOO"`. If `nloo` is smaller than the number of all observations,
#' approximate full LOO-CV using probability-proportional-to-size-sampling
#' (PPS) to make accurate computation only for `nloo` (anything from 1 to the
#' number of all observations) leave-one-out folds (Magnusson et al., 2019).
#' Smaller values lead to faster computation but higher uncertainty in the
#' evaluation part. If `NULL`, all observations are used (as by default).
#' @param nloo Only relevant if `cv_method = "LOO"` and `validate_search =
#' TRUE`. If `nloo > 0` is smaller than the number of all observations, full
#' LOO-CV (i.e., PSIS-LOO CV with `validate_search = TRUE` and with `nloo =
#' n` where `n` denotes the number of all observations) is approximated by
#' combining the fast (i.e., `validate_search = FALSE`) LOO result for the
#' selected models and `nloo` leave-one-out searches using the difference
#' estimator with simple random sampling (SRS) without replacement (WOR)
#' (Magnusson et al., 2020). Smaller `nloo` values lead to faster computation,
#' but higher uncertainty in the evaluation part. If `NULL`, all observations
#' are used (as by default).
#' @param K Only relevant if `cv_method = "kfold"` and if `cvfits` is `NULL`
#' (which is the case for reference model objects created by
#' [get_refmodel.stanreg()] or [brms::get_refmodel.brmsfit()]). Number of
Expand All @@ -36,14 +39,11 @@
#' [run_cvfun()] can be inserted here straightforwardly.
#' @param validate_search A single logical value indicating whether to
#' cross-validate also the search part, i.e., whether to run the search
#' separately for each CV-fold (`TRUE`) or not (`FALSE`). We strongly do not
#' recommend setting this to `FALSE`, because this is known to bias the
#' predictive performance estimates of the selected submodels. However,
#' setting this to `FALSE` can sometimes be useful because comparing the
#' results to the case where this argument is `TRUE` gives an idea of how
#' strongly the search is (over-)fitted to the data (the difference
#' corresponds to the search degrees of freedom or the effective number of
#' parameters introduced by the search).
#' separately for each CV-fold (`TRUE`) or not (`FALSE`). With `FALSE`
#' the computation is faster, but the predictive performance estimates
#' of the selected submodels are optimistically biased. However, these fast
#' biased estimated can be useful to obtain initial information on the
#' usefulness of projection predictive variable selection.
#' @param seed Pseudorandom number generation (PRNG) seed by which the same
#' results can be obtained again if needed. Passed to argument `seed` of
#' [set.seed()], but can also be `NA` to not call [set.seed()] at all. If not
Expand Down Expand Up @@ -370,6 +370,7 @@ cv_varsel.refmodel <- function(
search_out_rks <- NULL
}

summaries_fast <- NULL
if (cv_method == "LOO") {
sel_cv <- loo_varsel(
refmodel = refmodel, method = method, nterms_max = nterms_max,
Expand All @@ -388,6 +389,21 @@ cv_varsel.refmodel <- function(
search_terms_was_null = search_terms_was_null,
search_out_rks = search_out_rks, parallel = parallel, ...
)
if (validate_search && nloo < refmodel$nobs) {
# Run fast LOO-CV to be used in subsampling difference estimator
summaries_fast <- loo_varsel(
refmodel = refmodel, method = method, nterms_max = nterms_max,
ndraws = ndraws, nclusters = nclusters, ndraws_pred = ndraws_pred,
nclusters_pred = nclusters_pred, refit_prj = refit_prj, penalty = penalty,
verbose = verbose, search_control = search_control,
nloo = refmodel$nobs, # fast LOO-CV (using all observations)
validate_search = FALSE, # fast LOO-CV (using all observations)
search_path_fulldata = search_path_fulldata,
search_terms = search_terms,
search_terms_was_null = search_terms_was_null,
search_out_rks = search_out_rks, parallel = parallel, ...
)[["summaries"]]
}
} else if (cv_method == "kfold") {
sel_cv <- kfold_varsel(
refmodel = refmodel, method = method, nterms_max = nterms_max,
Expand Down Expand Up @@ -440,11 +456,13 @@ cv_varsel.refmodel <- function(
y_wobs_test,
nobs_test = nrow(y_wobs_test),
summaries = sel_cv$summaries,
summaries_fast,
nterms_all,
nterms_max,
method,
cv_method,
nloo,
loo_inds = sel_cv$inds,
K,
validate_search,
cvfits,
Expand Down Expand Up @@ -528,9 +546,10 @@ parse_args_cv_varsel <- function(refmodel, cv_method, nloo, K, cvfits,
nloo <- min(nloo, refmodel[["nobs"]])
if (nloo < 1) {
stop("nloo must be at least 1")
} else if (nloo < refmodel[["nobs"]] &&
getOption("projpred.warn_subsampled_loo", TRUE)) {
warning("Subsampled PSIS-LOO-CV is still experimental.")
}
if (!validate_search && nloo < refmodel[["nobs"]]) {
stop("Subsampled PSIS-LOO-CV is not supported for ",
"`validate_search = FALSE`.")
}
}

Expand Down Expand Up @@ -674,11 +693,15 @@ loo_varsel <- function(refmodel, method, nterms_max, ndraws,
refmodel$y <- y_lat_E$value
}

# LOO PPS subsampling (by default, don't subsample, but use all observations):
# validset <- loo_subsample(n, nloo, pareto_k)
loo_ref_oscale <- apply(loglik_forPSIS + lw, 2, log_sum_exp)
validset <- loo_subsample_pps(nloo, loo_ref_oscale)
inds <- validset$inds

if (validate_search && nloo < n) {
# Select which LOO-folds get more accurate computation using simple
# random sampling without resampling (Magnusson et al., 2020)
inds <- sample.int(n, size = nloo, replace = FALSE)
} else {
inds <- seq_len(n)
}

# Initialize objects where to store the results:
loo_sub <- replicate(nterms_max + 1L, rep(NA, n), simplify = FALSE)
Expand Down Expand Up @@ -711,6 +734,15 @@ loo_varsel <- function(refmodel, method, nterms_max, ndraws,
if (!validate_search) {
## Case `validate_search = FALSE` -----------------------------------------

# NOTE: The case where `inds` is an actual subset of the set of all
# observation indices should never occur here in the
# `validate_search = FALSE` case. Thus, in principle, the code could be
# simplified here, but keeping `inds` in case this might be helpful in the
# future.
if (nloo < n) {
stop("`nloo < n` is unexpected if `validate_search = FALSE`")
}

# "Run" the performance evaluation for the submodels along the predictor
# ranking (in fact, we only prepare the performance evaluation by computing
# precursor quantities, but for users, this difference is not perceivable):
Expand Down Expand Up @@ -1072,10 +1104,9 @@ loo_varsel <- function(refmodel, method, nterms_max, ndraws,

# Submodel predictive performance:
summ_sub <- lapply(seq_len(prv_len_rk + 1L), function(k) {
summ_k <- list(lppd = loo_sub[[k]], mu = mu_sub[[k]], wcv = validset$wcv)
summ_k <- list(lppd = loo_sub[[k]], mu = mu_sub[[k]])
if (refmodel$family$for_latent) {
summ_k$oscale <- list(lppd = loo_sub_oscale[[k]], mu = mu_sub_oscale[[k]],
wcv = validset$wcv)
summ_k$oscale <- list(lppd = loo_sub_oscale[[k]], mu = mu_sub_oscale[[k]])
}
return(summ_k)
})
Expand Down Expand Up @@ -1192,7 +1223,7 @@ loo_varsel <- function(refmodel, method, nterms_max, ndraws,
out_list <- c(out_list,
nlist(summaries,
y_wobs_test = as.data.frame(refmodel[nms_y_wobs_test()]),
clust_used_eval, nprjdraws_eval))
clust_used_eval, nprjdraws_eval, inds))
return(out_list)
}

Expand Down Expand Up @@ -1371,15 +1402,9 @@ kfold_varsel <- function(refmodel, method, nterms_max, ndraws, nclusters,
summ$mu <- summ$mu[order(idxs_sorted_by_fold_flx)]
summ$lppd <- summ$lppd[order(idxs_sorted_by_fold)]

# Add fold-specific weights (see the discussion at GitHub issue #94 for why
# this might have to be changed):
summ$wcv <- rep(1, length(summ$lppd))
summ$wcv <- summ$wcv / sum(summ$wcv)

if (!is.null(summ$oscale)) {
summ$oscale$mu <- summ$oscale$mu[order(idxs_sorted_by_fold_aug)]
summ$oscale$lppd <- summ$oscale$lppd[order(idxs_sorted_by_fold)]
summ$oscale$wcv <- summ$wcv
}
return(summ)
})
Expand Down Expand Up @@ -1577,69 +1602,6 @@ run_cvfun.refmodel <- function(object,
return(structure(cvfits, folds = folds))
}

# PSIS-LOO-CV helpers -----------------------------------------------------

# ## decide which points to go through in the validation (i.e., which points
# ## belong to the semi random subsample of validation points)
# loo_subsample <- function(n, nloo, pareto_k) {
# # Note: A seed is not set here because this function is not exported and has
# # a calling stack at the beginning of which a seed is set.
#
# resample <- function(x, ...) x[sample.int(length(x), ...)]
#
# if (nloo < n) {
# bad <- which(pareto_k > 0.7)
# ok <- which(pareto_k <= 0.7 & pareto_k > 0.5)
# good <- which(pareto_k <= 0.5)
# inds <- resample(bad, min(length(bad), floor(nloo / 3)))
# inds <- c(inds, resample(ok, min(length(ok), floor(nloo / 3))))
# inds <- c(inds, resample(good, min(length(good), floor(nloo / 3))))
# if (length(inds) < nloo) {
# ## not enough points selected, so choose randomly among the rest
# inds <- c(inds, resample(setdiff(seq_len(n), inds), nloo - length(inds)))
# }
#
# ## assign the weights corresponding to this stratification (for example,
# ## the 'bad' values are likely to be overpresented in the sample)
# wcv <- rep(0, n)
# wcv[inds[inds %in% bad]] <- length(bad) / sum(inds %in% bad)
# wcv[inds[inds %in% ok]] <- length(ok) / sum(inds %in% ok)
# wcv[inds[inds %in% good]] <- length(good) / sum(inds %in% good)
# } else {
# ## all points used
# inds <- seq_len(n)
# wcv <- rep(1, n)
# }
#
# ## ensure weights are normalized
# wcv <- wcv / sum(wcv)
#
# return(nlist(inds, wcv))
# }

## Select which points to go through in the validation based on
## proportional-to-size subsampling (PPS) as proposed by Magnusson, M.,
## Andersen, M. R., Jonasson, J. and Vehtari, A. (2019). Leave-One-Out
## Cross-Validation for Large Data. In *Proceedings of
## the 36th International Conference on Machine Learning*, edited by Kamalika
## Chaudhuri and Ruslan Salakhutdinov, 97:4244--53. Proceedings of Machine
## Learning Research. PMLR. <https://proceedings.mlr.press/v97/magnusson19a.html>.
loo_subsample_pps <- function(nloo, lppd) {
# Note: A seed is not set here because this function is not exported and has a
# calling stack at the beginning of which a seed is set.

if (nloo == length(lppd)) {
inds <- seq_len(nloo)
wcv <- rep(1, nloo)
} else if (nloo < length(lppd)) {
wcv <- exp(lppd - max(lppd))
inds <- sample(seq_along(lppd), size = nloo, prob = wcv)
}
wcv <- wcv / sum(wcv)

return(nlist(inds, wcv))
}

#' Pareto-smoothing k-hat threshold
#'
#' Copied from loo package. Remove after loo package exposes this.
Expand Down
32 changes: 18 additions & 14 deletions R/methods.R
Original file line number Diff line number Diff line change
Expand Up @@ -722,7 +722,7 @@ plot.vsel <- function(
# Parse input:
object <- x
validate_vsel_object_stats(object, stats, resp_oscale = resp_oscale)
baseline <- validate_baseline(object$refmodel, baseline, deltas)
baseline <- validate_baseline(object, baseline, deltas)
if (!is.null(ranking_repel) && !requireNamespace("ggrepel", quietly = TRUE)) {
warning("Package 'ggrepel' is needed for a non-`NULL` argument ",
"`ranking_repel`, but could not be found. Setting `ranking_repel` ",
Expand Down Expand Up @@ -1065,11 +1065,11 @@ plot.vsel <- function(
# direction = 1)
###
}
if (all(stats %in% c("rmse", "auc"))) {
if (all(stats %in% c("auc"))) {
ci_type <- "bootstrap "
} else if (all(stats %in% c("gmpd"))) {
ci_type <- "exponentiated normal-approximation "
} else if (all(!stats %in% c("rmse", "auc", "gmpd"))) {
} else if (all(!stats %in% c("auc", "gmpd"))) {
ci_type <- "normal-approximation "
} else {
ci_type <- ""
Expand Down Expand Up @@ -1158,23 +1158,26 @@ plot.vsel <- function(
#' are again all observations because the test set is the same as the training
#' set). Available statistics are:
#' * `"elpd"`: expected log (pointwise) predictive density (for a new
#' dataset). Estimated by the sum of the observation-specific log predictive
#' density values (with each of these predictive density values being
#' a---possibly weighted---average across the parameter draws).
#' * `"mlpd"`: mean log predictive density, that is, `"elpd"` divided by the
#' number of observations.
#' dataset) (ELPD). Estimated by the sum of the observation-specific log
#' predictive density values (with each of these predictive density values
#' being a---possibly weighted---average across the parameter draws).
#' * `"mlpd"`: mean log predictive density (MLPD), that is, the ELPD divided
#' by the number of observations.
#' * `"gmpd"`: geometric mean predictive density (GMPD), that is, [exp()] of
#' `"mlpd"`. The GMPD is especially helpful for discrete response families
#' the MLPD. The GMPD is especially helpful for discrete response families
#' (because there, the GMPD is bounded by zero and one). For the corresponding
#' standard error, the delta method is used. The corresponding confidence
#' interval type is "exponentiated normal approximation" because the
#' confidence interval bounds are the exponentiated confidence interval bounds
#' of the `"mlpd"`.
#' of the MLPD.
#' * `"mse"`: mean squared error (only available in the situations mentioned
#' in section "Details" below).
#' * `"rmse"`: root mean squared error (only available in the situations
#' mentioned in section "Details" below). For the corresponding standard error
#' and lower and upper confidence interval bounds, bootstrapping is used.
#' mentioned in section "Details" below). For the corresponding standard
#' error, the delta method is used.
#' * `"R2"`: R-squared, i.e., coefficient of determination (only available in
#' the situations mentioned in section "Details" below). For the corresponding
#' standard error, the delta method is used.
#' * `"acc"` (or its alias, `"pctcorr"`): classification accuracy (only
#' available in the situations mentioned in section "Details" below). By
#' "classification accuracy", we mean the proportion of correctly classified
Expand Down Expand Up @@ -1222,7 +1225,8 @@ plot.vsel <- function(
#' and `seed` (see [set.seed()], but defaulting to `NA` so that [set.seed()]
#' is not called within that function at all).
#'
#' @details The `stats` options `"mse"` and `"rmse"` are only available for:
#' @details The `stats` options `"mse"`, `"rmse"`, and `"R2"` are only available
#' for:
#' * the traditional projection,
#' * the latent projection with `resp_oscale = FALSE`,
#' * the latent projection with `resp_oscale = TRUE` in combination with
Expand Down Expand Up @@ -1283,7 +1287,7 @@ summary.vsel <- function(
...
) {
validate_vsel_object_stats(object, stats, resp_oscale = resp_oscale)
baseline <- validate_baseline(object$refmodel, baseline, deltas)
baseline <- validate_baseline(object, baseline, deltas)

# Initialize output:
out <- c(
Expand Down
Loading