diff --git a/DESCRIPTION b/DESCRIPTION
index a6621fb2..a57654b8 100644
--- a/DESCRIPTION
+++ b/DESCRIPTION
@@ -1,6 +1,6 @@
 Package: hal9001
 Title: The Scalable Highly Adaptive Lasso
-Version: 0.4.3
+Version: 0.4.6
 Authors@R: c(
   person("Jeremy", "Coyle", email = "jeremyrcoyle@gmail.com",
          role = c("aut", "cre"),
@@ -68,5 +68,5 @@ LinkingTo:
   Rcpp,
   RcppEigen
 VignetteBuilder: knitr
-RoxygenNote: 7.2.0
+RoxygenNote: 7.2.3
 Roxygen: list(markdown = TRUE)
diff --git a/NAMESPACE b/NAMESPACE
index 495ea7e5..3093718a 100644
--- a/NAMESPACE
+++ b/NAMESPACE
@@ -26,18 +26,15 @@ importFrom(data.table,setorder)
 importFrom(glmnet,cv.glmnet)
 importFrom(glmnet,glmnet)
 importFrom(methods,is)
-importFrom(origami,cross_validate)
 importFrom(origami,folds2foldvec)
 importFrom(origami,make_folds)
 importFrom(stats,aggregate)
 importFrom(stats,as.formula)
 importFrom(stats,coef)
-importFrom(stats,gaussian)
 importFrom(stats,median)
 importFrom(stats,plogis)
 importFrom(stats,predict)
 importFrom(stats,quantile)
-importFrom(stats,sd)
 importFrom(stringr,str_detect)
 importFrom(stringr,str_extract)
 importFrom(stringr,str_match)
diff --git a/NEWS.md b/NEWS.md
index 7ba8c7b0..c04ae0c2 100644
--- a/NEWS.md
+++ b/NEWS.md
@@ -1,3 +1,24 @@
+# hal9001 0.4.6
+* Fixed predict method to address changes required by Matrix 1.6.2.
+
+# hal9001 0.4.5
+* Added multivariate outcome prediction.
+
+# hal9001 0.4.4
+* Fixed bug with `prediction_bounds` (a `fit_hal` argument in the
+  `fit_control` list), which would error when it was specified as a numeric
+  vector. Also added a check to assert this argument is correctly specified,
+  and tests to ensure a numeric vector of bounds is provided.
+* Simplified `fit_control` list arguments in `fit_hal`. Users can still
+  specify additional arguments to `cv.glmnet` and `glmnet` in this list.
+* Defined `weights` as a formal argument in `fit_hal`, as opposed to an
+  optional argument in `fit_control`, to facilitate specification and avoid
+  confusion. This increases flexibility with the SuperLearner wrapper
+  `SL.hal9001` as well; `fit_control` can now be customized with `SL.hal9001`.
+
+# hal9001 0.4.3
+* Version bump for CRAN resubmission following archiving.
+
 # hal9001 0.4.2
 * Version bump for CRAN resubmission following archiving.
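A minimal sketch of the user-facing changes called out in NEWS.md above — `weights` as a formal argument, numeric `prediction_bounds`, and `family = "mgaussian"`. The simulated data and argument values are illustrative only:

```r
library(hal9001)

set.seed(27158)
n <- 200
x <- matrix(rnorm(n * 3), n, 3)
y <- sin(x[, 1]) + rnorm(n, sd = 0.2)
w <- runif(n) # observation weights, now a formal argument of fit_hal()

# single outcome, with weights and explicit numeric (lower, upper) bounds
fit <- fit_hal(
  X = x, Y = y, weights = w,
  fit_control = list(prediction_bounds = c(-2, 2))
)

# multivariate outcomes ("mgaussian"): Y is a matrix, one column per outcome
y_mat <- cbind(y1 = y, y2 = cos(x[, 2]) + rnorm(n, sd = 0.2))
fit_mv <- fit_hal(X = x, Y = y_mat, family = "mgaussian")
preds_mv <- predict(fit_mv, new_data = x) # matrix: one column per outcome
```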
diff --git a/R/cv_lasso.R b/R/cv_lasso.R
deleted file mode 100644
index a7fece86..00000000
--- a/R/cv_lasso.R
+++ /dev/null
@@ -1,72 +0,0 @@
-#' Cross-validated Lasso on Indicator Bases
-#'
-#' Fits Lasso regression using a customized procedure, with cross-validation
-#' based on \pkg{origami}
-#'
-#' @param x_basis A \code{dgCMatrix} object corresponding to a sparse matrix of
-#'  the basis functions generated for the HAL algorithm.
-#' @param y A \code{numeric} vector of the observed outcome variable values.
-#' @param n_lambda A \code{numeric} scalar indicating the number of values of
-#'  the L1 regularization parameter (lambda) to be obtained from fitting the
-#'  Lasso to the full data. Cross-validation is used to select an optimal
-#'  lambda (that minimizes the risk) from among these.
-#' @param n_folds A \code{numeric} scalar for the number of folds to be used in
-#'  the cross-validation procedure to select an optimal value of lambda.
-#' @param center binary. If \code{TRUE}, covariates are centered. This is much
-#'  slower, but matches the \code{glmnet} implementation. Default \code{FALSE}.
-#'
-#' @importFrom origami make_folds cross_validate
-#' @importFrom stats sd
-cv_lasso <- function(x_basis, y, n_lambda = 100, n_folds = 10,
-                     center = FALSE) {
-  # first, need to run lasso on the full data to get a sequence of lambdas
-  lasso_init <- lassi(y = y, x = x_basis, nlambda = n_lambda)
-  lambdas_init <- lasso_init$lambdas
-
-  # next, set up a cross-validated lasso using the sequence of lambdas
-  full_data_mat <- cbind(y, x_basis)
-  folds <- origami::make_folds(full_data_mat, V = n_folds)
-
-  # run the cross-validated lasso procedure to find the optimal lambda
-  cv_lasso_out <- origami::cross_validate(
-    cv_fun = lassi_origami,
-    folds = folds,
-    data = full_data_mat,
-    lambdas = lambdas_init,
-    center = center
-  )
-
-  # compute cv-mean of MSEs for each lambda
-  lambdas_cvmse <- colMeans(cv_lasso_out$mses)
-
-  # find the lambda that minimizes the MSE
-  lambda_optim_index <- which.min(lambdas_cvmse)
-  lambda_minmse <- lambdas_init[lambda_optim_index]
-
-  # also need the adjusted CV standard deviation for each lambda
-  lambdas_cvsd <- apply(cv_lasso_out$mses, 2, sd) / sqrt(n_folds)
-
-  # find the maximum lambda among those 1 standard error above the minimum
-  lambda_min_1se <- (lambdas_cvmse + lambdas_cvsd)[lambda_optim_index - 1]
-  lambda_1se <- max(lambdas_init[lambdas_cvmse <= lambda_min_1se],
-    na.rm = TRUE
-  )
-  lambda_1se_index <- which.min(abs(lambdas_init - lambda_1se))
-
-  # create output object
-  get_lambda_indices <- c(lambda_1se_index, lambda_optim_index)
-  betas_out <- lasso_init$beta_mat[, get_lambda_indices]
-  colnames(betas_out) <- c("lambda_1se", "lambda_min")
-
-  # add in intercept term to coefs matrix and convert to sparse matrix output
-  betas_out <- rbind(rep(mean(y), ncol(betas_out)), betas_out)
-  betas_out <- as_dgCMatrix(betas_out * 1.0)
-
-  # create output object
-  cv_lasso_out <- list(betas_out, lambda_minmse, lambda_1se, lambdas_cvmse)
-  names(cv_lasso_out) <- c(
-    "betas_mat", "lambda_min", "lambda_1se",
-    "lambdas_cvmse"
-  )
-  return(cv_lasso_out)
-}
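With `cv_lasso()` and its origami-based `cross_validate()` loop deleted, lambda selection now runs entirely through `glmnet::cv.glmnet()`; the ID-respecting fold construction survives inside `fit_hal()` (see the R/hal.R diff below). A minimal sketch of that replacement path, with `x_basis`, `y`, and `id` as illustrative stand-ins:

```r
library(glmnet)
library(origami)

# folds respect any cluster structure in id, then are flattened to the
# per-observation fold vector that cv.glmnet expects
folds <- origami::make_folds(n = length(y), V = 10, cluster_ids = id)
foldid <- origami::folds2foldvec(folds)

# cv.glmnet selects lambda by cross-validated risk over the supplied folds
cv_fit <- glmnet::cv.glmnet(x = x_basis, y = y, foldid = foldid)
coef(cv_fit, s = "lambda.min") # or s = "lambda.1se"
```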
diff --git a/R/formula_hal9001.R b/R/formula_hal9001.R
index 019260df..fe3b85a2 100644
--- a/R/formula_hal9001.R
+++ b/R/formula_hal9001.R
@@ -143,7 +143,7 @@ h <- function(..., k = NULL, s = NULL, pf = 1,
     var_names <- unlist(list(...))
   }
   formula_term <- paste0("h(", paste0(var_names, collapse = ", "), ")")
-
+
   if (is.null(k)) {
     k <- get("num_knots", envir = parent.frame())
     k <- suppressWarnings(k + rep(0, length(var_names))) # recycle
@@ -152,8 +152,8 @@ h <- function(..., k = NULL, s = NULL, pf = 1,
   if (is.null(s)) {
     s <- get("smoothness_orders", envir = parent.frame())[1]
   }
-
-
+
+
   if ("." %in% var_names) {
     var_names_filled <- fill_dots(var_names, . = .)
@@ -186,7 +186,7 @@ h <- function(..., k = NULL, s = NULL, pf = 1,
     return(all_items)
   }
-
+
   # Get corresponding column indices
@@ -240,6 +240,7 @@ h <- function(..., k = NULL, s = NULL, pf = 1,
   return(out)
 }
+
 #' Print formula_hal9001 object
 #'
 #' @param x A formula_hal9001 object.
@@ -250,16 +251,18 @@
 print.formula_hal9001 <- function(x, ...) {
   cat(paste0("A hal9001 formula object of the form: ~ ", x$formula_term))
 }
 
-#' Formula Helpers
 #'
-#' @param var_names A \code{character} vector of variable names.
+#' @param var_names A \code{character} vector of variable names representing a
+#'  single type of interaction (e.g., var_names = c("W1", "W2", "W3") encodes
+#'  three-way interactions between W1, W2, and W3). var_names may include the
+#'  wildcard variable "." in which case the argument `.` must be specified
+#'  so that all interactions matching the form of var_names are generated.
 #' @param . Specification of variables for use in the formula.
-#'
-#' @name formula_helpers
-NULL
-
+#'  This function takes a character vector `var_names` of the form
+#'  c(name1, name2, ".", name3, ".") with any number of name{int} variables and
+#'  any number of wildcard variables ".". It returns a list of character
+#'  vectors of the form c(name1, name2, wildcard1, name3, wildcard2) where
+#'  wildcard1 and wildcard2 are iterated over all possible character names
+#'  given in the argument `.`.
 #' @rdname formula_helpers
-fill_dots_helper <- function(var_names, .) {
+fill_dots <- function(var_names, .) {
   index <- which(var_names == ".")
   if (length(index) == 0) {
     return(sort(var_names))
@@ -269,26 +272,23 @@
   all_items <- lapply(., function(var) {
     new_var_names <- var_names
     new_var_names[index] <- var
-    out <- fill_dots_helper(new_var_names, .)
-
-    if (is.list(out[[1]])) {
-      out <- unlist(out, recursive = FALSE)
-    }
+    out <- fill_dots(new_var_names, .)
     return(out)
   })
-
-
-  return(unique(all_items))
-}
-
-#' @rdname formula_helpers
-fill_dots <- function(var_names, .) {
-  x <- unique(unlist(fill_dots_helper(var_names, . = .), recursive = FALSE))
-  keep <- sapply(x, function(item) {
+  is_nested <- is.list(all_items[[1]])
+  while (is_nested) {
+    all_items <- unlist(all_items, recursive = FALSE)
+    is_nested <- is.list(all_items[[1]])
+  }
+  # Remove combinations of variable names that have duplicates.
+  # This removes generated interactions that include two of the same variable.
+  keep <- sapply(all_items, function(item) {
    if (any(duplicated(item))) {
      return(FALSE)
    }
    return(TRUE)
  })
-  return(x[keep])
+  all_items <- all_items[keep]
+
+  return(unique(all_items))
 }
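To make the rewritten wildcard expansion concrete, a small illustration of what the (internal) `fill_dots()` returns; the triple-colon access and variable names are for demonstration only:

```r
covariates <- c("W1", "W2", "W3")

# one wildcard: expands to all two-way interactions involving W1
hal9001:::fill_dots(c("W1", "."), . = covariates)
#> list(c("W1", "W2"), c("W1", "W3"))
# c("W1", "W1") is generated but dropped, since an interaction may not repeat
# a variable; duplicate combinations are removed by the final unique()
```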
diff --git a/R/hal.R b/R/hal.R
index e7895761..e73d6d90 100644
--- a/R/hal.R
+++ b/R/hal.R
@@ -59,7 +59,9 @@
 #' @param X An input \code{matrix} with dimensions number of observations -by-
 #' number of covariates that will be used to derive the design matrix of basis
 #' functions.
-#' @param Y A \code{numeric} vector of observations of the outcome variable.
+#' @param Y A \code{numeric} vector of observations of the outcome variable. For
+#' \code{family="mgaussian"}, \code{Y} is a matrix of observations of the
+#' outcome variables.
 #' @param formula A character string formula to be used in
 #' \code{\link{formula_hal}}. See its documentation for details.
 #' @param X_unpenalized An input \code{matrix} with the same number of rows as
@@ -81,18 +83,20 @@
 #' combinatorial explosions in the number of higher-degree and higher-order
 #' basis functions generated. This allows the complexity of the optimization
 #' problem to grow scalably. See details of \code{num_knots} more information.
-#' @param reduce_basis A \code{numeric} value bounded in the open unit interval
-#' indicating the minimum proportion of 1's in a basis function column needed
-#' for the basis function to be included in the procedure to fit the lasso.
-#' Any basis functions with a lower proportion of 1's than the cutoff will be
-#' removed. When \code{reduce_basis} is set to \code{NULL}, all basis
-#' functions are used in the lasso-fitting stage of \code{fit_hal}.
+#' @param reduce_basis An optional \code{numeric} value bounded in the open
+#' unit interval indicating the minimum proportion of 1's in a basis function
+#' column needed for the basis function to be included in the procedure to fit
+#' the lasso. Any basis functions with a lower proportion of 1's than the
+#' cutoff will be removed. Defaults to 1 over the square root of the number of
+#' observations. Only applicable for models fit with zero-order splines, i.e.
+#' \code{smoothness_orders = 0}.
 #' @param family A \code{character} or a \code{\link[stats]{family}} object
 #' (supported by \code{\link[glmnet]{glmnet}}) specifying the error/link
 #' family for a generalized linear model. \code{character} options are limited
 #' to "gaussian" for fitting a standard penalized linear model, "binomial" for
 #' penalized logistic regression, "poisson" for penalized Poisson regression,
-#' and "cox" for a penalized proportional hazards model. Note that passing in
+#' "cox" for a penalized proportional hazards model, and "mgaussian" for a
+#' multivariate penalized linear model. Note that passing in
 #' family objects leads to slower performance relative to passing in a
 #' character family (if supported). For example, one should set
 #' \code{family = "binomial"} instead of \code{family = binomial()} when
@@ -108,38 +112,34 @@
 #' @param id A vector of ID values that is used to generate cross-validation
 #' folds for \code{\link[glmnet]{cv.glmnet}}. This argument is ignored when
 #' \code{fit_control}'s \code{cv_select} argument is \code{FALSE}.
+#' @param weights observation weights; defaults to 1 per observation.
 #' @param offset a vector of offset values, used in fitting.
-#' @param fit_control List of arguments for fitting. Includes the following
-#' arguments, and any others to be passed to \code{\link[glmnet]{cv.glmnet}}
-#' or \code{\link[glmnet]{glmnet}}.
+#' @param fit_control List of arguments, including the following, and any
+#' others to be passed to \code{\link[glmnet]{cv.glmnet}} or
+#' \code{\link[glmnet]{glmnet}}.
 #' - \code{cv_select}: A \code{logical} specifying if the sequence of
 #'   specified \code{lambda} values should be passed to
 #'   \code{\link[glmnet]{cv.glmnet}} in order for a single, optimal value of
 #'   \code{lambda} to be selected according to cross-validation. When
 #'   \code{cv_select = FALSE}, a \code{\link[glmnet]{glmnet}} model will be
 #'   used to fit the sequence of (or single) \code{lambda}.
-#' - \code{n_folds}: Integer for the number of folds to be used when splitting
-#'   the data for V-fold cross-validation. Only used when
+#' - \code{use_min}: Specify the choice of lambda to be selected by
+#'   \code{\link[glmnet]{cv.glmnet}}. When \code{TRUE}, \code{"lambda.min"} is
+#'   used; otherwise, \code{"lambda.1se"}. Only used when
 #'   \code{cv_select = TRUE}.
-#' - \code{foldid}: An optional \code{numeric} containing values between 1 and
-#'   \code{n_folds}, identifying the fold to which each observation is
-#'   assigned. If supplied, \code{n_folds} can be missing. In such a case,
-#'   this vector is passed directly to \code{\link[glmnet]{cv.glmnet}}. Only
-#'   used when \code{cv_select = TRUE}.
-#' - \code{use_min}: Specify the choice of lambda to be selected by
-#'   \code{\link[glmnet]{cv.glmnet}}. When \code{TRUE}, \code{"lambda.min"} is
-#'   used; otherwise, \code{"lambda.1se"}. Only used when
-#'   \code{cv_select = TRUE}.
-#' - \code{lambda.min.ratio}: A \code{\link[glmnet]{glmnet}} argument specifying
-#'   the smallest value for \code{lambda}, as a fraction of \code{lambda.max},
-#'   the (data derived) entry value (i.e. the smallest value for which all
-#'   coefficients are zero). We've seen that not setting \code{lambda.min.ratio}
-#'   can lead to no \code{lambda} values that fit the data sufficiently well.
-#' - \code{prediction_bounds}: A vector of size two that provides the lower and
-#'   upper bounds for predictions. When \code{prediction_bounds = "default"},
-#'   the predictions are bounded between \code{min(Y) - sd(Y)} and
-#'   \code{max(Y) + sd(Y)}. Bounding ensures that there is no extrapolation,
-#'   and it is necessary for cross-validation selection and/or Super Learning.
+#' - \code{lambda.min.ratio}: A \code{\link[glmnet]{glmnet}} argument
+#'   specifying the smallest value for \code{lambda}, as a fraction of
+#'   \code{lambda.max}, the (data derived) entry value (i.e. the smallest value
+#'   for which all coefficients are zero). We've seen that not setting
+#'   \code{lambda.min.ratio} can lead to no \code{lambda} values that fit the
+#'   data sufficiently well.
+#' - \code{prediction_bounds}: An optional vector of size two that provides
+#'   the lower and upper bounds for predictions; not used when
+#'   \code{family = "cox"}. When \code{prediction_bounds = "default"}, the
+#'   predictions are bounded between \code{min(Y) - sd(Y)} and
+#'   \code{max(Y) + sd(Y)} for each outcome (when \code{family = "mgaussian"},
+#'   each outcome can have different bounds). Bounding ensures that there is
+#'   no extrapolation.
 #' @param basis_list The full set of basis functions generated from \code{X}.
 #' @param return_lasso A \code{logical} indicating whether or not to return
 #' the \code{\link[glmnet]{glmnet}} fit object of the lasso model.
@@ -182,15 +182,14 @@ fit_hal <- function(X,
                       base_num_knots_0 = 200,
                       base_num_knots_1 = 50
                     ),
-                    reduce_basis = 1 / sqrt(length(Y)),
-                    family = c("gaussian", "binomial", "poisson", "cox"),
+                    reduce_basis = NULL,
+                    family = c("gaussian", "binomial", "poisson", "cox", "mgaussian"),
                     lambda = NULL,
                     id = NULL,
+                    weights = NULL,
                     offset = NULL,
                     fit_control = list(
                       cv_select = TRUE,
-                      n_folds = 10,
-                      foldid = NULL,
                       use_min = TRUE,
                       lambda.min.ratio = 1e-4,
                       prediction_bounds = "default"
@@ -202,22 +201,33 @@
   if (!inherits(family, "family")) {
     family <- match.arg(family)
   }
+  fam <- ifelse(inherits(family, "family"), family$family, family)
 
   # errors when a supplied control list is missing arguments
   defaults <- list(
-    cv_select = TRUE, n_folds = 10, foldid = NULL, use_min = TRUE,
-    lambda.min.ratio = 1e-4, prediction_bounds = "default"
+    cv_select = TRUE, use_min = TRUE, lambda.min.ratio = 1e-4,
+    prediction_bounds = "default"
   )
   if (any(!names(defaults) %in% names(fit_control))) {
     fit_control <- c(
-      defaults[which(!names(defaults) %in% names(fit_control))], fit_control
+      defaults[!names(defaults) %in% names(fit_control)], fit_control
     )
   }
-  # errors when a supplied control list is missing arguments
-  defaults <- list(
-    exclusive_dot = FALSE, custom_group = NULL
-  )
-
+  # check fit_control names (excluding defaults) are glmnet/cv.glmnet formals
+  glmnet_formals <- unique(c(
+    names(formals(glmnet::cv.glmnet)),
+    names(formals(glmnet::glmnet)),
+    names(formals(glmnet::relax.glmnet)) # extra allowed args to glmnet
+  ))
+  control_names <- names(fit_control[!names(fit_control) %in% names(defaults)])
+  if (any(!control_names %in% glmnet_formals)) {
+    bad_args <- control_names[(!control_names %in% glmnet_formals)]
+    warning(sprintf(
+      "Some fit_control arguments are neither default nor glmnet/cv.glmnet arguments: %s; \nThey will be removed from fit_control",
+      paste0(bad_args, collapse = ", ")
+    ))
+    fit_control <- fit_control[!names(fit_control) %in% bad_args]
+  }
 
   if (!is.matrix(X)) X <- as.matrix(X)
@@ -230,9 +240,11 @@
     all(!is.na(Y)),
     msg = "NA detected in `Y`, missingness in `Y` is not supported"
   )
+
+  n_Y <- ifelse(is.matrix(Y), nrow(Y), length(Y))
   assertthat::assert_that(
-    nrow(X) == length(Y),
-    msg = "Number of rows in `X` and length of `Y` must be equal"
+    nrow(X) == n_Y,
+    msg = "Number of rows in `X` and `Y` must be equal"
   )
 
   if (!is.null(X_unpenalized)) {
@@ -252,6 +264,24 @@
     )
   }
 
+  if (!is.character(fit_control$prediction_bounds)) {
+    if (fam == "mgaussian") {
+      assertthat::assert_that(
+        is.list(fit_control$prediction_bounds) &
+          length(fit_control$prediction_bounds) == ncol(Y),
+        msg = "prediction_bounds must be 'default' or list of numeric (lower, upper) bounds for each outcome"
+      )
+    } else {
+      assertthat::assert_that(
+        is.numeric(fit_control$prediction_bounds) &
+          length(fit_control$prediction_bounds) == 2,
+        msg = "prediction_bounds must be 'default' or numeric (lower, upper) bounds"
+      )
+    }
+  }
+
+
   if (!is.null(formula)) {
     # formula <- formula_hal(
     #   formula = formula, X = X, smoothness_orders = smoothness_orders,
@@ -260,7 +290,11 @@
     # )
 
     if (!inherits(formula, "formula_hal")) {
-      formula <- formula_hal(formula, X = X, smoothness_orders = smoothness_orders, num_knots = num_knots)
+      formula <- formula_hal(
+        formula,
+        X = X, smoothness_orders = smoothness_orders,
+        num_knots = num_knots
+      )
     }
     basis_list <- formula$basis_list
     fit_control$upper.limits <- formula$upper.limits
@@ -275,10 +309,11 @@
   # Generate fold_ids that respect id
   if (is.null(fit_control$foldid)) {
+    if (is.null(fit_control$nfolds)) fit_control$nfolds <- 10
     folds <- origami::make_folds(
-      n = length(Y), V = fit_control$n_folds, cluster_ids = id
+      n = n_Y, V = fit_control$nfolds, cluster_ids = id
    )
-    foldid <- origami::folds2foldvec(folds)
+    fit_control$foldid <- origami::folds2foldvec(folds)
  }
 
  # bookkeeping: get start time of enumerate basis procedure
@@ -306,12 +341,19 @@
   time_design_matrix <- proc.time()
 
   # NOTE: keep only basis functions with some (or higher) proportion of 1's
-  if (!is.null(reduce_basis) && is.numeric(reduce_basis) &&
-    all(smoothness_orders == 0)) {
+  if (all(smoothness_orders == 0)) {
+    if (is.null(reduce_basis)) {
+      reduce_basis <- 1 / sqrt(n_Y)
+    }
     reduced_basis_map <- make_reduced_basis_map(x_basis, reduce_basis)
     x_basis <- x_basis[, reduced_basis_map]
     basis_list <- basis_list[reduced_basis_map]
+  } else {
+    if (!is.null(reduce_basis)) {
+      warning("Dropping reduce_basis; only applies if smoothness_orders = 0")
+    }
   }
+
   time_reduce_basis <- proc.time()
 
   # catalog and eliminate duplicates
@@ -364,7 +406,7 @@
   #   x_basis <- as.matrix(x_basis)
   # }
 
-  if (!inherits(family, "family") && family == "cox") {
+  if (fam == "cox") {
     x_basis <- as.matrix(x_basis)
   }
@@ -388,6 +430,7 @@
   fit_control$lambda <- lambda
   fit_control$penalty.factor <- penalty_factor
   fit_control$offset <- offset
+  fit_control$weights <- weights
 
   if (!fit_control$cv_select) {
     hal_lasso <- do.call(glmnet::glmnet, fit_control)
@@ -422,13 +465,19 @@
   )
 
   # Bounds for prediction on new data (to prevent extrapolation for linear HAL)
-  if (!inherits(Y, "Surv") & fit_control$prediction_bounds == "default") {
-    # This would break if Y was a survival object as in coxnet
-    fit_control$prediction_bounds <- c(
-      min(Y) - 2 * stats::sd(Y), max(Y) + 2 * stats::sd(Y)
-    )
-  } else if (inherits(Y, "Surv") & fit_control$prediction_bounds == "default") {
-    fit_control$prediction_bounds <- NULL
+  if (is.character(fit_control$prediction_bounds) &&
+    fit_control$prediction_bounds == "default") {
+    if (fam == "mgaussian") {
+      fit_control$prediction_bounds <- lapply(seq(ncol(Y)), function(i) {
+        c(min(Y[, i]) - 2 * stats::sd(Y[, i]), max(Y[, i]) + 2 * stats::sd(Y[, i]))
+      })
+    } else if (fam == "cox") {
+      fit_control$prediction_bounds <- NULL
+    } else {
+      fit_control$prediction_bounds <- c(
+        min(Y) - 2 * stats::sd(Y), max(Y) + 2 * stats::sd(Y)
+      )
+    }
   }
 
   # construct output object via lazy S3 list
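Since `fit_control` is now validated against glmnet/cv.glmnet formals, extra entries pass straight through while unknown names are dropped with the warning above. A sketch with illustrative values (`nfolds` and `standardize` are genuine cv.glmnet/glmnet arguments):

```r
fit <- fit_hal(
  X = x, Y = y,
  fit_control = list(
    cv_select = TRUE,
    use_min = FALSE,    # select "lambda.1se" rather than "lambda.min"
    nfolds = 5,         # passed through to cv.glmnet
    standardize = FALSE # passed through to glmnet
  )
)
```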
diff --git a/R/make_basis.R b/R/make_basis.R
index c68f783a..d8cd1de2 100644
--- a/R/make_basis.R
+++ b/R/make_basis.R
@@ -327,8 +327,9 @@ quantizer <- function(X, bins) {
     p <- max(1 - (20 / nrow(X)), 0.98)
     quants <- seq(0, p, length.out = bins)
 
-    q <- stats::quantile(x, quants)
-    nearest <- findInterval(x, q)
+    q <- unique(stats::quantile(x, quants, type = 1))
+    # NOTE: all.inside must be FALSE or else all binary variables are mapped to zero.
+    nearest <- findInterval(x, q, all.inside = FALSE)
     x <- q[nearest]
     return(x)
   }
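The `all.inside` note in `quantizer()` is easy to verify on a binary column; with `all.inside = TRUE` the interval index is clamped to 1, collapsing both levels to zero. A small check, with illustrative values:

```r
x <- c(0, 0, 1, 1)
q <- unique(stats::quantile(x, seq(0, 0.98, length.out = 50), type = 1))
q
#> [1] 0 1
findInterval(x, q, all.inside = TRUE) # clamped to 1, so q[nearest] is all 0
#> [1] 1 1 1 1
findInterval(x, q, all.inside = FALSE) # both levels preserved
#> [1] 1 1 2 2
```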
diff --git a/R/predict.R b/R/predict.R
index 7a19681d..41895867 100644
--- a/R/predict.R
+++ b/R/predict.R
@@ -16,10 +16,6 @@
 #' @param offset A vector of offsets. Must be provided if provided at training.
 #' @param type Either "response" for predictions of the response, or "link" for
 #' un-transformed predictions (on the scale of the link function).
-#' @param p_reserve Sparse matrix pre-allocation proportion, which is the
-#' anticipated proportion of 1's in the design matrix. Default value is
-#' recommended in most settings. If a dense design matrix is expected, it
-#' would be useful to set \code{p_reserve} to a higher value.
 #' @param ... Additional arguments passed to \code{predict} as necessary.
 #'
 #' @importFrom Matrix tcrossprod
@@ -44,10 +40,10 @@ predict.hal9001 <- function(object,
                             new_X_unpenalized = NULL,
                             offset = NULL,
                             type = c("response", "link"),
-                            p_reserve = 0.75,
                             ...) {
+  family <- ifelse(inherits(object$family, "family"), object$family$family, object$family)
+
   type <- match.arg(type)
-  p_reserve <- pmax(pmin(p_reserve, 1), 0)
 
   # cast new data to matrix if not so already
   if (!is.matrix(new_data)) new_data <- as.matrix(new_data)
@@ -56,9 +52,7 @@ predict.hal9001 <- function(object,
   }
 
   # generate design matrix
-  pred_x_basis <- make_design_matrix(new_data, object$basis_list,
-    p_reserve = p_reserve
-  )
+  pred_x_basis <- make_design_matrix(new_data, object$basis_list)
 
   # reduce matrix of basis functions
   # pred_x_basis <- apply_copy_map(pred_x_basis, object$copy_map)
@@ -83,35 +77,31 @@ predict.hal9001 <- function(object,
   }
 
   # generate predictions
-  if (inherits(object$family, "family") || object$family != "cox") {
+  if (!family %in% c("cox", "mgaussian")) {
     if (ncol(object$coefs) > 1) {
-      preds <- apply(object$coefs, 2, function(hal_coefs) {
-        as.vector(Matrix::tcrossprod(
-          x = pred_x_basis,
-          y = hal_coefs[-1]
-        ) + hal_coefs[1])
-      })
+      preds <- pred_x_basis %*% object$coefs[-1, ] +
+        matrix(object$coefs[1, ],
+          nrow = nrow(pred_x_basis),
+          ncol = ncol(object$coefs), byrow = TRUE
+        )
     } else {
       preds <- as.vector(Matrix::tcrossprod(
         x = pred_x_basis,
-        y = object$coefs[-1]
+        y = matrix(object$coefs[-1], nrow = 1)
       ) + object$coefs[1])
     }
   } else {
-    # Note: there is no intercept in the Cox model (built into the baseline
-    # hazard and would cancel in the partial likelihood).
-    if (ncol(object$coefs) > 1) {
-      preds <- apply(object$coefs, 2, function(hal_coefs) {
-        as.vector(Matrix::tcrossprod(
-          x = pred_x_basis,
-          y = hal_coefs
-        ))
-      })
-    } else {
-      preds <- as.vector(Matrix::tcrossprod(
-        x = pred_x_basis,
-        y = as.vector(object$coefs)
-      ))
+    if (family == "cox") {
+      # Note: there is no intercept in the Cox model (built into the baseline
+      # hazard and would cancel in the partial likelihood).
+      preds <- pred_x_basis %*% object$coefs
+    } else if (family == "mgaussian") {
+      preds <- stats::predict(
+        object$lasso_fit,
+        newx = pred_x_basis, s = object$lambda_star
+      )
     }
   }
 
@@ -130,22 +120,46 @@ predict.hal9001 <- function(object,
   if (inherits(object$family, "family")) {
     inverse_link_fun <- object$family$linkinv
     preds <- inverse_link_fun(preds)
-  } else if (object$family == "binomial") {
-    preds <- stats::plogis(preds)
-  } else if (object$family %in% c("poisson", "cox")) {
-    preds <- exp(preds)
+  } else {
+    if (family == "binomial") {
+      transform <- stats::plogis
+    } else if (family %in% c("poisson", "cox")) {
+      transform <- exp
+    } else if (family %in% c("gaussian", "mgaussian")) {
+      transform <- identity
+    } else {
+      stop("unsupported family")
+    }
+
+    if (length(ncol(preds))) {
+      # apply along only the first dimension (to handle n-d arrays)
+      margin <- seq(length(dim(preds)))[-1]
+      preds <- apply(preds, margin, transform)
+    } else {
+      preds <- transform(preds)
+    }
   }
 
   # bound predictions within observed outcome bounds if on response scale
-  bounds <- object$prediction_bounds
-  if (!is.null(bounds)) {
-    bounds <- sort(bounds)
-    if (is.matrix(preds)) {
-      preds <- apply(preds, 2, pmax, bounds[1])
-      preds <- apply(preds, 2, pmin, bounds[2])
+  if (!is.null(object$prediction_bounds)) {
+    bounds <- object$prediction_bounds
+    if (family == "mgaussian") {
+      preds <- do.call(cbind, lapply(seq(ncol(preds)), function(i) {
+        bounds_y <- sort(bounds[[i]])
+        preds_y <- preds[, i, ]
+        preds_y <- pmax(bounds_y[1], preds_y)
+        preds_y <- pmin(preds_y, bounds_y[2])
+        return(preds_y)
+      }))
     } else {
-      preds <- pmax(bounds[1], preds)
-      preds <- pmin(preds, bounds[2])
+      bounds <- sort(bounds)
+      if (is.matrix(preds)) {
+        preds <- apply(preds, 2, pmax, bounds[1])
+        preds <- apply(preds, 2, pmin, bounds[2])
+      } else {
+        preds <- pmax(bounds[1], preds)
+        preds <- pmin(preds, bounds[2])
+      }
     }
   }
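On the response scale, the refactored predict method now routes every non-family-object fit through a single transform. For a binomial fit, the two `type` options relate as sketched below (objects illustrative; the response-scale values may additionally be clipped by `prediction_bounds`):

```r
fit_bin <- fit_hal(
  X = x, Y = rbinom(nrow(x), 1, plogis(x[, 1])), family = "binomial"
)
eta <- predict(fit_bin, new_data = x, type = "link") # linear predictor
p_hat <- predict(fit_bin, new_data = x, type = "response") # plogis(eta), bounded
```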
diff --git a/R/sl_hal9001.R b/R/sl_hal9001.R
index d8537d65..0fe18b08 100644
--- a/R/sl_hal9001.R
+++ b/R/sl_hal9001.R
@@ -22,19 +22,9 @@
 #' specifying the maximum number of knot points (i.e., bins) for each
 #' covariate for generating basis functions. See \code{num_knots} argument in
 #' \code{\link{fit_hal}} for more information.
-#' @param reduce_basis A \code{numeric} value bounded in the open unit interval
-#' indicating the minimum proportion of 1's in a basis function column needed
-#' for the basis function to be included in the procedure to fit the lasso.
-#' Any basis functions with a lower proportion of 1's than the cutoff will be
-#' removed.
-#' @param lambda A user-specified sequence of values of the regularization
-#' parameter for the lasso L1 regression. If \code{NULL}, the default sequence
-#' in \code{\link[glmnet]{cv.glmnet}} will be used. The cross-validated
-#' optimal value of this regularization parameter will be selected with
-#' \code{\link[glmnet]{cv.glmnet}}.
-#' @param ... Not used.
+#' @param ... Additional arguments to \code{\link{fit_hal}}.
 #'
-#' @importFrom stats predict gaussian
+#' @importFrom stats predict
 #'
 #' @export
 #'
@@ -42,27 +32,23 @@
 #' object and corresponding predictions based on the input data.
 SL.hal9001 <- function(Y,
                        X,
-                       newX = NULL,
-                       family = stats::gaussian(),
-                       obsWeights = rep(1, length(Y)),
-                       id = NULL,
-                       max_degree = ifelse(ncol(X) >= 20, 2, 3),
+                       newX,
+                       family,
+                       obsWeights,
+                       id,
+                       max_degree = 2,
                        smoothness_orders = 1,
-                       num_knots = ifelse(smoothness_orders >= 1, 25, 50),
-                       reduce_basis = 1 / sqrt(length(Y)),
-                       lambda = NULL,
+                       num_knots = 5,
                        ...) {
-  # create matrix version of X and newX for use with hal9001::fit_hal
   if (!is.matrix(X)) X <- as.matrix(X)
   if (!is.null(newX) & !is.matrix(newX)) newX <- as.matrix(newX)
 
   # fit hal
-  hal_fit <- fit_hal(
-    X = X, Y = Y, family = family$family,
-    fit_control = list(weights = obsWeights), id = id, max_degree = max_degree,
-    smoothness_orders = smoothness_orders, num_knots = num_knots, reduce_basis
-    = reduce_basis, lambda = lambda
+  hal_fit <- hal9001::fit_hal(
+    Y = Y, X = X, family = family$family, weights = obsWeights, id = id,
+    max_degree = max_degree, smoothness_orders = smoothness_orders,
+    num_knots = num_knots, ...
   )
 
   # compute predictions based on `newX` or input `X`
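Because `SL.hal9001` now forwards `...` to `fit_hal()`, wrapper behavior can be customized from SuperLearner with a thin wrapper. A sketch using the standard SuperLearner custom-wrapper pattern (the wrapper name is illustrative):

```r
library(SuperLearner)

# forwards a custom fit_control through ... to fit_hal()
SL.hal9001.custom <- function(...) {
  SL.hal9001(..., fit_control = list(cv_select = TRUE, use_min = FALSE))
}

sl_fit <- SuperLearner(
  Y = y, X = data.frame(x),
  SL.library = c("SL.hal9001", "SL.hal9001.custom"),
  family = gaussian()
)
```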
diff --git a/R/summary.R b/R/summary.R
index 5878dae1..5193e7cd 100644
--- a/R/summary.R
+++ b/R/summary.R
@@ -52,6 +52,7 @@ summary.hal9001 <- function(object,
                             include_redundant_terms = FALSE,
                             round_cutoffs = 3,
                             ...) {
+  family <- ifelse(inherits(object$family, "family"), object$family$family, object$family)
   abs_coef <- basis_list_idx <- coef_idx <- dup <- NULL
 
   # retain coefficients corresponding to lambda
@@ -71,12 +72,20 @@ summary.hal9001 <- function(object,
       stop("Coefficients for the specified lambda do not exist.")
     } else {
       lambda_idx <- which(object$lasso_fit$lambda == lambda)
-      coefs <- object$lasso_fit$glmnet.fit$beta[, lambda_idx]
+      if (family != "mgaussian") {
+        coefs <- object$lasso_fit$glmnet.fit$beta[, lambda_idx]
+      } else {
+        coefs <- lapply(object$lasso_fit$glmnet.fit$beta, function(x) x[, lambda_idx])
+      }
     }
   }
 } else {
   lambda_idx <- which(object$lambda_star == lambda)
-  coefs <- object$coefs[, lambda_idx]
+  if (family != "mgaussian") {
+    coefs <- object$coefs[, lambda_idx]
+  } else {
+    coefs <- lapply(object$coefs, function(x) x[, lambda_idx])
+  }
 }
@@ -89,169 +98,210 @@ summary.hal9001 <- function(object,
       "Summarizing coefficients corresponding to minimum lambda."
     )
     lambda_idx <- which.min(lambda)
-    coefs <- object$coefs[, lambda_idx]
+    if (family != "mgaussian") {
+      coefs <- object$coefs[, lambda_idx]
+    } else {
+      coefs <- lapply(object$coefs, function(x) x[, lambda_idx])
+    }
   }
 }
 
 # cox model has no intercept
-if (object$family != "cox") {
-  coefs_no_intercept <- coefs[-1]
-} else {
+if (family == "cox") {
   coefs_no_intercept <- coefs
+} else if (family == "mgaussian") {
+  coefs_no_intercept <- lapply(coefs, function(x) x[-1])
+} else {
+  coefs_no_intercept <- coefs[-1]
 }
 
 # subset to non-zero coefficients
 if (only_nonzero_coefs) {
-  coef_idxs <- which(coefs_no_intercept != 0)
+  if (family == "mgaussian") {
+    coef_idxs <- lapply(coefs_no_intercept, function(x) which(x != 0))
+  } else {
+    coef_idxs <- which(coefs_no_intercept != 0)
+  }
 } else {
-  coef_idxs <- seq_along(coefs_no_intercept)
+  if (family == "mgaussian") {
+    coef_idxs <- lapply(coefs_no_intercept, function(x) seq_along(x))
+  } else {
+    coef_idxs <- seq_along(coefs_no_intercept)
+  }
 }
-copy_map <- object$copy_map[coef_idxs]
-
-# summarize coefficients with respect to basis list
-coefs_summ <- data.table::rbindlist(
-  lapply(seq_along(copy_map), function(map_idx) {
-    coef_idx <- coef_idxs[map_idx]
-    coef <- coefs_no_intercept[coef_idx]
-
-    basis_list_idxs <- copy_map[[map_idx]] # indices of duplicates
-    basis_dups <- object$basis_list[basis_list_idxs]
-
-    data.table::rbindlist(
-      lapply(seq_along(basis_dups), function(i) {
-        coef_idx <- ifelse(object$family != "cox", coef_idx + 1, coef_idx)
-        dt <- data.table::data.table(
-          coef_idx = coef_idx, # coefficient index
-          coef, # coefficient
-          basis_list_idx = basis_list_idxs[i], # basis list index
-          col_idx = basis_dups[[i]]$cols, # column idx in X
-          col_cutoff = basis_dups[[i]]$cutoffs, # cutoff
-          col_order = basis_dups[[i]]$orders # smoothness order
-        )
-        return(dt)
-      })
-    )
-  })
-)
-
-if (!include_redundant_terms) {
-  coef_idxs <- unique(coefs_summ$coef_idx)
-  coefs_summ <- data.table::rbindlist(lapply(coef_idxs, function(idx) {
-    # subset to matching coefficient index
-    coef_summ <- coefs_summ[coef_idx == idx]
-
-    # label duplicates (i.e. basis functions with identical col & cutoff)
-    dups_tbl <- coef_summ[, c("col_idx", "col_cutoff", "col_order")]
-    if (!anyDuplicated(dups_tbl)) {
-      return(coef_summ)
-    } else {
-      # add col indicating whether or not there is a duplicate
-      coef_summ[, dup := (duplicated(dups_tbl) |
-        duplicated(dups_tbl, fromLast = TRUE))]
-
-      # if basis_list_idx contains redundant duplicates, remove them
-      redundant_dups <- coef_summ[dup == TRUE, "basis_list_idx"]
-      if (nrow(redundant_dups) > 1) {
-        # keep the redundant duplicate term that has the shortest length
-        retain_idx <- which.min(apply(redundant_dups, 1, function(idx) {
-          nrow(coef_summ[basis_list_idx == idx])
-        }))
-        idx_keep <- unname(unlist(redundant_dups[retain_idx]))
-        coef_summ <- coef_summ[basis_list_idx == idx_keep]
-      }
-      return(coef_summ[, -"dup"])
-    }
-  }))
-}
-
-# summarize with respect to x column names:
-x_names <- data.table::data.table(
-  col_idx = 1:length(object$X_colnames),
-  col_names = object$X_colnames
-)
-summ <- merge(coefs_summ, x_names, by = "col_idx", all.x = TRUE)
-
-# combine name, cutoff into 0-order basis function (may include interaction)
-summ$zero_term <- paste0(
-  "I(", summ$col_names, " >= ", round(summ$col_cutoff, round_cutoffs), ")"
-)
-summ$higher_term <- ifelse(
-  summ$col_order == 0, "",
-  paste0(
-    "(", summ$col_names, " - ",
-    round(summ$col_cutoff, round_cutoffs), ")"
-  )
-)
-summ$higher_term <- ifelse(
-  summ$col_order < 1, summ$higher_term,
-  paste0(summ$higher_term, "^", summ$col_order)
-)
-summ$term <- ifelse(
-  summ$col_order == 0,
-  paste0("[ ", summ$zero_term, " ]"),
-  paste0("[ ", summ$zero_term, "*", summ$higher_term, " ]")
-)
-
-term_tbl <- data.table::as.data.table(stats::aggregate(
-  term ~ basis_list_idx,
-  data = summ, paste, collapse = " * "
-))
-
-# no longer need the columns or rows that were incorporated in the term
-redundant <- c(
-  "term", "col_cutoff", "col_names", "col_idx", "col_order", "zero_term",
-  "higher_term"
-)
-summ <- summ[, -..redundant]
-summ_unique <- unique(summ)
-summ <- merge(
-  term_tbl, summ_unique,
-  by = "basis_list_idx", all.x = TRUE, all.y = FALSE
-)
-
-# summarize in a list
-coefs_list <- lapply(unique(summ$coef_idx), function(this_coef_idx) {
-  coef_terms <- summ[coef_idx == this_coef_idx]
-  list(coef = unique(coef_terms$coef), term = t(coef_terms$term))
-})
-
-# summarize in a table
-coefs_tbl <- data.table::as.data.table(stats::aggregate(
-  term ~ coef_idx,
-  data = summ, FUN = paste, collapse = " OR "
-))
-redundant <- c("term", "basis_list_idx")
-summ_unique_coefs <- unique(summ[, -..redundant])
-coefs_tbl <- data.table::data.table(merge(
-  summ_unique_coefs, coefs_tbl,
-  by = "coef_idx", all = TRUE
-))
-coefs_tbl[, "abs_coef" := abs(coef)]
-coefs_tbl <- data.table::setorder(coefs_tbl[, -"coef_idx"], -abs_coef)
-coefs_tbl <- coefs_tbl[, -"abs_coef", with = FALSE]
-
-# incorporate intercept
-if (object$family != "cox") {
-  intercept <- list(data.table::data.table(
-    coef = coefs[1], term = "(Intercept)"
-  ))
-  coefs_tbl <- data.table::rbindlist(
-    c(intercept, list(coefs_tbl)),
-    fill = TRUE
-  )
-  intercept <- list(coef = coefs[1], term = "(Intercept)")
-  coefs_list <- c(list(intercept), coefs_list)
-}
-out <- list(
-  table = coefs_tbl,
-  list = coefs_list,
-  lambda = lambda,
-  only_nonzero_coefs = only_nonzero_coefs
-)
-class(out) <- "summary.hal9001"
-return(out)
-}
+
+if (family == "mgaussian") {
+  copy_map <- lapply(coef_idxs, function(x) object$copy_map[x])
+} else {
+  copy_map <- object$copy_map[coef_idxs]
+}
+
+# ============================================================================
+# utility function to summarize HAL fit which can be used for multiple outcomes
+summarize_coefs <- function(copy_map, coef_idxs, coefs_no_intercept, coefs) {
+  # summarize coefficients with respect to basis list
+  coefs_summ <- data.table::rbindlist(
+    lapply(seq_along(copy_map), function(map_idx) {
+      coef_idx <- coef_idxs[map_idx]
+      coef <- coefs_no_intercept[coef_idx]
+
+      basis_list_idxs <- copy_map[[map_idx]] # indices of duplicates
+      basis_dups <- object$basis_list[basis_list_idxs]
+
+      data.table::rbindlist(
+        lapply(seq_along(basis_dups), function(i) {
+          coef_idx <- ifelse(family != "cox", coef_idx + 1, coef_idx)
+          dt <- data.table::data.table(
+            coef_idx = coef_idx, # coefficient index
+            coef, # coefficient
+            basis_list_idx = basis_list_idxs[i], # basis list index
+            col_idx = basis_dups[[i]]$cols, # column idx in X
+            col_cutoff = basis_dups[[i]]$cutoffs, # cutoff
+            col_order = basis_dups[[i]]$orders # smoothness order
+          )
+          return(dt)
+        })
+      )
+    })
+  )
+
+  if (!include_redundant_terms) {
+    coef_idxs <- unique(coefs_summ$coef_idx)
+    coefs_summ <- data.table::rbindlist(lapply(coef_idxs, function(idx) {
+      # subset to matching coefficient index
+      coef_summ <- coefs_summ[coef_idx == idx]
+
+      # label duplicates (i.e. basis functions with identical col & cutoff)
+      dups_tbl <- coef_summ[, c("col_idx", "col_cutoff", "col_order")]
+      if (!anyDuplicated(dups_tbl)) {
+        return(coef_summ)
+      } else {
+        # add col indicating whether or not there is a duplicate
+        coef_summ[, dup := (duplicated(dups_tbl) |
+          duplicated(dups_tbl, fromLast = TRUE))]
+
+        # if basis_list_idx contains redundant duplicates, remove them
+        redundant_dups <- coef_summ[dup == TRUE, "basis_list_idx"]
+        if (nrow(redundant_dups) > 1) {
+          # keep the redundant duplicate term that has the shortest length
+          retain_idx <- which.min(apply(redundant_dups, 1, function(idx) {
+            nrow(coef_summ[basis_list_idx == idx])
+          }))
+          idx_keep <- unname(unlist(redundant_dups[retain_idx]))
+          coef_summ <- coef_summ[basis_list_idx == idx_keep]
+        }
+        return(coef_summ[, -"dup"])
+      }
+    }))
+  }
+
+  # summarize with respect to x column names:
+  x_names <- data.table::data.table(
+    col_idx = 1:length(object$X_colnames),
+    col_names = object$X_colnames
+  )
+  summ <- merge(coefs_summ, x_names, by = "col_idx", all.x = TRUE)
+
+  # combine name, cutoff into 0-order basis function (may include interaction)
+  summ$zero_term <- paste0(
+    "I(", summ$col_names, " >= ", round(summ$col_cutoff, round_cutoffs), ")"
+  )
+  summ$higher_term <- ifelse(
+    summ$col_order == 0, "",
+    paste0(
+      "(", summ$col_names, " - ",
+      round(summ$col_cutoff, round_cutoffs), ")"
+    )
+  )
+  summ$higher_term <- ifelse(
+    summ$col_order < 1, summ$higher_term,
+    paste0(summ$higher_term, "^", summ$col_order)
+  )
+  summ$term <- ifelse(
+    summ$col_order == 0,
+    paste0("[ ", summ$zero_term, " ]"),
+    paste0("[ ", summ$zero_term, "*", summ$higher_term, " ]")
+  )
+
+  term_tbl <- data.table::as.data.table(stats::aggregate(
+    term ~ basis_list_idx,
+    data = summ, paste, collapse = " * "
+  ))
+
+  # no longer need the columns or rows that were incorporated in the term
+  redundant <- c(
+    "term", "col_cutoff", "col_names", "col_idx", "col_order", "zero_term",
+    "higher_term"
+  )
+  summ <- summ[, -..redundant]
+  summ_unique <- unique(summ)
+  summ <- merge(
+    term_tbl, summ_unique,
+    by = "basis_list_idx", all.x = TRUE, all.y = FALSE
+  )
+
+  # generate input for rules summary
+  rules_tbl <- generate_all_rules(
+    object$basis_list[summ$basis_list_idx], summ$coef, object$X_colnames
+  )
+
+  # summarize in a list
+  coefs_list <- lapply(unique(summ$coef_idx), function(this_coef_idx) {
+    coef_terms <- summ[coef_idx == this_coef_idx]
+    list(coef = unique(coef_terms$coef), term = t(coef_terms$term))
+  })
+
+  # summarize in a table
+  coefs_tbl <- data.table::as.data.table(stats::aggregate(
+    term ~ coef_idx,
+    data = summ, FUN = paste, collapse = " OR "
+  ))
+  redundant <- c("term", "basis_list_idx")
+  summ_unique_coefs <- unique(summ[, -..redundant])
+  coefs_tbl <- data.table::data.table(merge(
+    summ_unique_coefs, coefs_tbl,
+    by = "coef_idx", all = TRUE
+  ))
+  coefs_tbl[, "abs_coef" := abs(coef)]
+  coefs_tbl <- data.table::setorder(coefs_tbl[, -"coef_idx"], -abs_coef)
+  coefs_tbl <- coefs_tbl[, -"abs_coef", with = FALSE]
+
+  # incorporate intercept
+  if (family != "cox") {
+    intercept <- list(data.table::data.table(
+      coef = coefs[1], term = "(Intercept)"
+    ))
+    coefs_tbl <- data.table::rbindlist(
+      c(intercept, list(coefs_tbl)),
+      fill = TRUE
+    )
+    intercept <- list(coef = coefs[1], term = "(Intercept)")
+    coefs_list <- c(list(intercept), coefs_list)
+  }
+  out <- list(
+    table = coefs_tbl,
+    list = coefs_list,
+    lambda = lambda,
+    only_nonzero_coefs = only_nonzero_coefs,
+    family = family,
+    rules = rules_tbl
+  )
+  class(out) <- "summary.hal9001"
+  return(out)
+}
+# ============================================================================
+if (family == "mgaussian") {
+  return_obj <- lapply(seq_along(copy_map), function(i) {
+    summarize_coefs(copy_map[[i]], coef_idxs[[i]], coefs_no_intercept[[i]], coefs[[i]])
+  })
+  class(return_obj) <- "summary.hal9001"
+} else {
+  return_obj <- summarize_coefs(copy_map, coef_idxs, coefs_no_intercept, coefs)
+}
+return(return_obj)
+}
@@ -262,28 +312,166 @@
 ###############################################################################
 #' Print Method for Summary Class of HAL fits
 #'
 #' @param x A summary.hal9001 object.
 #'
 #' @export
 print.summary.hal9001 <- function(x, length = NULL, ...) {
-  if (x$only_nonzero_coefs & is.null(length)) {
-    cat(
-      "\n\nSummary of non-zero coefficients is based on lambda of",
-      x$lambda, "\n\n"
-    )
-  } else if (!x$only_nonzero_coefs & is.null(length)) {
-    cat("\nSummary of coefficients is based on lambda of", x$lambda, "\n\n")
-  } else if (!x$only_nonzero_coefs & !is.null(length)) {
-    cat(
-      "\nSummary of top", length,
-      "coefficients is based on lambda of", x$lambda, "\n\n"
-    )
-  } else if (x$only_nonzero_coefs & !is.null(length)) {
-    cat(
-      "\nSummary of top", length,
-      "non-zero coefficients is based on lambda of", x$lambda, "\n\n"
-    )
-  }
-
-  if (is.null(length)) {
-    print(x$table, row.names = FALSE)
-  } else {
-    print(utils::head(x$table, length), row.names = FALSE)
-  }
-}
+  if (x$family != "mgaussian" && !is.null(x$family)) {
+    if (x$only_nonzero_coefs & is.null(length)) {
+      cat(
+        "\n\nSummary of non-zero coefficients is based on lambda of",
+        x$lambda, "\n\n"
+      )
+    } else if (!x$only_nonzero_coefs & is.null(length)) {
+      cat("\nSummary of coefficients is based on lambda of", x$lambda, "\n\n")
+    } else if (!x$only_nonzero_coefs & !is.null(length)) {
+      cat(
+        "\nSummary of top", length,
+        "coefficients is based on lambda of", x$lambda, "\n\n"
+      )
+    } else if (x$only_nonzero_coefs & !is.null(length)) {
+      cat(
+        "\nSummary of top", length,
+        "non-zero coefficients is based on lambda of", x$lambda, "\n\n"
+      )
+    }
+
+    if (is.null(length)) {
+      print(x$table, row.names = FALSE)
+    } else {
+      print(utils::head(x$table, length), row.names = FALSE)
+    }
+    cat("\n\n Summary of aggregated marginal and interaction regions: \n\n")
+    print(x$rules, row.names = FALSE)
+  } else {
+    for (i in 1:length(x)) {
+      if (x[[i]]$only_nonzero_coefs & is.null(length)) {
+        cat(
+          "\n\nSummary of non-zero coefficients for each outcome is based on lambda of",
+          x[[i]]$lambda, "\n\n"
+        )
+      } else if (!x[[i]]$only_nonzero_coefs & is.null(length)) {
+        cat(
+          "\nSummary of coefficients for each outcome is based on lambda of",
+          x[[i]]$lambda, "\n\n"
+        )
+      } else if (!x[[i]]$only_nonzero_coefs & !is.null(length)) {
+        cat(
+          "\nSummary of top", length,
+          "coefficients for each outcome is based on lambda of", x[[i]]$lambda, "\n\n"
+        )
+      } else if (x[[i]]$only_nonzero_coefs & !is.null(length)) {
+        cat(
+          "\nSummary of top", length,
+          "non-zero coefficients for each outcome is based on lambda of",
+          x[[i]]$lambda, "\n\n"
+        )
+      }
+      if (is.null(length)) {
+        print(x[[i]]$table, row.names = FALSE)
+      } else {
+        print(utils::head(x[[i]]$table, length), row.names = FALSE)
+      }
+      cat("\n\n Summary of aggregated marginal and interaction regions: \n\n")
+      print(x[[i]]$rules, row.names = FALSE)
+    }
+  }
+}
+
+#' Generates rules based on knot points of the fitted HAL basis functions with
+#' non-zero coefficients.
+#'
+#' @keywords internal
+generate_all_rules <- function(basis_list, coefs, X_colnames) {
+  # Convert coefficients to matrix and filter out the intercept
+  coefs_mat <- as.matrix(coefs)
+
+  # Identify indices where coefficients are non-zero
+  # (i.e., relevant to the final model)
+  relevant_indices <- which(coefs_mat != 0)
+
+  # Initialize a list to store cutoffs for each feature
+  cutoffs_list <- vector("list", length(X_colnames))
+  names(cutoffs_list) <- X_colnames
+
+  # Initialize list to store interaction rules and their cumulative coefficients
+  interaction_rules <- list()
+  interaction_coefs <- list()
+
+  # Loop over each basis function that has a non-zero coefficient
+  for (i in relevant_indices) {
+    basis <- basis_list[[i]]
+    coef_val <- coefs_mat[i, ]
+
+    # For marginal basis functions (no interactions)
+    if (length(basis$cols) == 1) {
+      colname <- X_colnames[basis$cols[1]]
+      # Add unique cutoffs to the cutoffs_list
+      cutoffs_list[[colname]] <- unique(c(cutoffs_list[[colname]], basis$cutoffs[1]))
+    }
+
+    # For interaction basis functions
+    if (length(basis$cols) > 1) {
+      interaction_name <- paste(X_colnames[basis$cols], collapse = "-")
+      if (!interaction_name %in% names(interaction_rules)) {
+        interaction_rules[[interaction_name]] <- list()
+        interaction_coefs[[interaction_name]] <- 0
+        for (j in basis$cols) {
+          interaction_rules[[interaction_name]][[X_colnames[j]]] <- c()
+        }
+      }
+      for (j in seq_along(basis$cols)) {
+        interaction_rules[[interaction_name]][[X_colnames[basis$cols[j]]]] <-
+          c(interaction_rules[[interaction_name]][[X_colnames[basis$cols[j]]]], basis$cutoffs[j])
+      }
+      interaction_coefs[[interaction_name]] <- interaction_coefs[[interaction_name]] + coef_val
+    }
+  }
+
+  # check if there are any marginal rules
+  # (i.e., any non-interaction basis functions with non-zero coefficients)
+  null_cutoffs <- which(sapply(cutoffs_list, is.null))
+  if (length(null_cutoffs) > 0) {
+    cutoffs_list <- cutoffs_list[-null_cutoffs]
+  }
+  if (length(cutoffs_list) > 0) {
+    # for each feature, identify the min cutoff and form rule
+    min_cutoffs <- sapply(cutoffs_list, min, na.rm = TRUE)
+    marginal_rules <- sapply(seq_along(min_cutoffs), function(i) {
+      paste0(names(min_cutoffs)[i], " >= ", min_cutoffs[i])
+    }, USE.NAMES = TRUE)
+    names(marginal_rules) <- names(min_cutoffs)
+  } else {
+    # instantiate empty marginal rules if there are none
+    marginal_rules <- c()
+    min_cutoffs <- c()
+  }
+
+  # check if there are any interaction rules
+  # (i.e., any interaction basis functions with non-zero coefficients)
+  if (length(interaction_rules) > 0) {
+    # create bounding box rules for interactions
+    bounding_rules <- list()
+    for (interaction in names(interaction_rules)) {
+      rules <- c()
+      for (var in names(interaction_rules[[interaction]])) {
+        min_val <- min(interaction_rules[[interaction]][[var]], na.rm = TRUE)
+        rules <- c(rules, paste0(var, " >= ", min_val))
+      }
+      bounding_rules[[interaction]] <- paste(rules, collapse = " & ")
+    }
+    # combine all rules
+    all_rules <- c(marginal_rules, unlist(bounding_rules))
+    all_coefs <- c(min_cutoffs, unlist(interaction_coefs))
+  } else {
+    # all rules are only comprised of the marginals
+    all_rules <- marginal_rules
+    all_coefs <- min_cutoffs
+  }
+
+  if (is.null(all_rules) | is.null(all_coefs)) {
+    # there are no rules!
+    rules_df <- NULL
+  } else {
+    # convert rules into a data table for easy viewing and interpretation
+    rules_df <- data.table::data.table(
+      variables = names(all_rules),
+      rule = all_rules,
+      cumulative_coefficient = all_coefs
+    )
+  }
+  return(rules_df)
+}
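The rules table produced by `generate_all_rules()` is attached to the summary object and printed after the coefficient table; a short usage sketch, continuing the illustrative `fit` from earlier:

```r
fit_summary <- summary(fit)
fit_summary$table # coefficients with readable basis-function terms
fit_summary$rules # aggregated marginal and interaction regions
print(fit_summary, length = 10) # top 10 terms by absolute coefficient
```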
diff --git a/README.md b/README.md
index e8ed0d6a..728a247b 100644
--- a/README.md
+++ b/README.md
@@ -100,12 +100,12 @@ predictions via Highly Adaptive Lasso regression:
 hal_fit <- fit_hal(X = x, Y = y, yolo = TRUE)
 #> [1] "I'm sorry, Dave. I'm afraid I can't do that."
 hal_fit$times
 #>                   user.self sys.self elapsed user.child sys.child
-#> enumerate_basis       0.008    0.001   0.009          0         0
-#> design_matrix         0.003    0.000   0.003          0         0
+#> enumerate_basis       0.014    0.003   0.059          0         0
+#> design_matrix         0.004    0.001   0.005          0         0
 #> reduce_basis          0.000    0.000   0.000          0         0
-#> remove_duplicates     0.000    0.000   0.001          0         0
-#> lasso                 1.690    0.036   1.764          0         0
-#> total                 1.702    0.037   1.778          0         0
+#> remove_duplicates     0.000    0.000   0.000          0         0
+#> lasso                 2.684    0.343   6.583          0         0
+#> total                 2.703    0.348   6.655          0         0
 
 # training sample prediction
 preds <- predict(hal_fit, new_data = x)
diff --git a/docs/404.html b/docs/404.html
index 2fa8a621..f0d9df35 100644
--- a/docs/404.html
+++ b/docs/404.html
@@ -38,7 +38,7 @@
       hal9001
-      0.4.3
+      0.4.5
diff --git a/docs/CONTRIBUTING.html b/docs/CONTRIBUTING.html
index c278dd85..f0dcfa9c 100644
--- a/docs/CONTRIBUTING.html
+++ b/docs/CONTRIBUTING.html
@@ -23,7 +23,7 @@
       hal9001
-      0.4.3
+      0.4.5
diff --git a/docs/LICENSE-text.html b/docs/LICENSE-text.html
index 210dedbf..a4e240cc 100644
--- a/docs/LICENSE-text.html
+++ b/docs/LICENSE-text.html
@@ -23,7 +23,7 @@
       hal9001
-      0.4.3
+      0.4.5
diff --git a/docs/articles/index.html b/docs/articles/index.html
index e707ca32..564c15c4 100644
--- a/docs/articles/index.html
+++ b/docs/articles/index.html
@@ -23,7 +23,7 @@
       hal9001
-      0.4.3
+      0.4.5
diff --git a/docs/articles/intro_hal9001.html b/docs/articles/intro_hal9001.html
index 21f17a12..d5262436 100644
--- a/docs/articles/intro_hal9001.html
+++ b/docs/articles/intro_hal9001.html
@@ -39,7 +39,7 @@
       hal9001
-      0.4.3
+      0.4.5
@@ -85,7 +85,11 @@

Nima Hejazi, Jeremy Coyle, Rachael Phillips, Lars van der Laan


2022-11-04

Source: vignettes/intro_hal9001.Rmd
@@ -154,7 +158,7 @@

Using the Highly Adaptive Lasso
 library(hal9001)
## Loading required package: Rcpp
-## hal9001 v0.4.3: The Scalable Highly Adaptive Lasso
+## hal9001 v0.4.5: The Scalable Highly Adaptive Lasso
 ## note: fit_hal defaults have changed. See ?fit_hal for details

Fitting the model
@@ -165,12 +169,21 @@

Fitting the model
hal_fit <- fit_hal(X = x, Y = y)
hal_fit$times

##                   user.self sys.self elapsed user.child sys.child
## enumerate_basis       0.017    0.000   0.018          0         0
## design_matrix         0.082    0.003   0.086          0         0
## reduce_basis          0.000    0.000   0.000          0         0
## remove_duplicates     0.000    0.000   0.000          0         0
## lasso                 2.284    0.072   2.399          0         0
## total                 2.384    0.075   2.503          0         0

Summarizing the model
@@ -186,81 +199,117 @@

Summarizing the model
print(summary(hal_fit))

## 
 ## 
-## Summary of non-zero coefficients is based on lambda of 0.00487 
+## Summary of non-zero coefficients is based on lambda of 0.002108107 
 ## 
 ##           coef
-##  -1.424003e+00
-##   3.654157e-01
-##  -2.358064e-01
-##   2.182822e-01
-##  -1.713943e-01
-##   1.648782e-01
-##   1.622979e-01
-##  -1.205285e-01
-##  -9.526696e-02
-##  -9.382067e-02
-##   5.468026e-02
-##   5.273315e-02
-##  -5.056465e-02
-##   4.527346e-02
-##  -3.735277e-02
-##  -3.543529e-02
-##   2.380514e-02
-##  -2.255508e-02
-##  -2.176114e-02
-##  -1.612916e-02
-##   1.531317e-02
-##  -1.233111e-02
-##   1.172920e-02
-##   8.162029e-03
-##  -7.160000e-03
-##  -5.500155e-03
-##   5.083293e-03
-##  -3.625005e-03
-##   1.488567e-03
-##   1.350348e-03
-##   3.224244e-04
-##  -3.021403e-04
-##  -2.126168e-04
-##  -6.625025e-05
-##   1.927162e-06
+##  -1.435851e+00
+##   4.391709e-01
+##  -2.180325e-01
+##  -1.908016e-01
+##   1.837943e-01
+##   1.617427e-01
+##  -1.307707e-01
+##   1.205775e-01
+##  -1.203965e-01
+##   1.175336e-01
+##  -1.166074e-01
+##  -1.018822e-01
+##   8.214302e-02
+##   7.525308e-02
+##   7.518641e-02
+##   7.328934e-02
+##   7.066728e-02
+##   6.364869e-02
+##  -4.686124e-02
+##  -4.672286e-02
+##  -4.499741e-02
+##  -4.377227e-02
+##  -3.830315e-02
+##   3.779762e-02
+##  -3.744479e-02
+##   3.386721e-02
+##   3.286990e-02
+##   3.254816e-02
+##  -3.203164e-02
+##   3.041717e-02
+##  -1.901118e-02
+##   1.170430e-02
+##  -1.147950e-02
+##  -1.053684e-02
+##   9.934522e-03
+##   9.600888e-03
+##  -7.160528e-03
+##  -6.499773e-03
+##  -5.794305e-03
+##  -5.714597e-03
+##  -5.698275e-03
+##  -5.424112e-03
+##   5.170208e-03
+##   4.516979e-03
+##   4.245836e-03
+##  -4.125254e-03
+##   2.495093e-03
+##  -1.063492e-04
+##  -7.186831e-05
+##   2.188558e-05
+##   1.890986e-05
+##  -1.757825e-05
+##   1.024909e-05
 ##           coef
 ##                                                                                                             term
 ##                                                                                                      (Intercept)
 ##                                                                              [ I(x2 >= -1.583)*(x2 - -1.583)^1 ]
 ##                                          [ I(x2 >= 1.595)*(x2 - 1.595)^1 ] * [ I(x3 >= -3.289)*(x3 - -3.289)^1 ]
-##                                        [ I(x1 >= -0.962)*(x1 - -0.962)^1 ] * [ I(x3 >= -3.289)*(x3 - -3.289)^1 ]
 ##                                          [ I(x1 >= 1.606)*(x1 - 1.606)^1 ] * [ I(x2 >= -3.038)*(x2 - -3.038)^1 ]
-##                                          [ I(x2 >= -1.11)*(x2 - -1.11)^1 ] * [ I(x3 >= -3.289)*(x3 - -3.289)^1 ]
+##                                        [ I(x1 >= -0.962)*(x1 - -0.962)^1 ] * [ I(x3 >= -3.289)*(x3 - -3.289)^1 ]
 ##                                        [ I(x1 >= -1.403)*(x1 - -1.403)^1 ] * [ I(x2 >= -3.038)*(x2 - -3.038)^1 ]
 ##                                          [ I(x1 >= 0.941)*(x1 - 0.941)^1 ] * [ I(x3 >= -3.289)*(x3 - -3.289)^1 ]
-##                                          [ I(x2 >= 1.017)*(x2 - 1.017)^1 ] * [ I(x3 >= -3.289)*(x3 - -3.289)^1 ]
+##                                          [ I(x2 >= -1.11)*(x2 - -1.11)^1 ] * [ I(x3 >= -3.289)*(x3 - -3.289)^1 ]
 ##                                          [ I(x1 >= 1.368)*(x1 - 1.368)^1 ] * [ I(x3 >= -3.289)*(x3 - -3.289)^1 ]
-##                                        [ I(x2 >= -1.565)*(x2 - -1.565)^1 ] * [ I(x3 >= -3.289)*(x3 - -3.289)^1 ]
-##                                                                              [ I(x1 >= -1.844)*(x1 - -1.844)^1 ]
+##                                                                              [ I(x1 >= -1.593)*(x1 - -1.593)^1 ]
 ##                                          [ I(x2 >= 0.696)*(x2 - 0.696)^1 ] * [ I(x3 >= -3.289)*(x3 - -3.289)^1 ]
+##                                          [ I(x1 >= -3.224)*(x1 - -3.224)^1 ] * [ I(x2 >= 1.595)*(x2 - 1.595)^1 ]
+##                                                                              [ I(x1 >= -1.844)*(x1 - -1.844)^1 ]
 ##    [ I(x1 >= -3.224)*(x1 - -3.224)^1 ] * [ I(x2 >= 1.156)*(x2 - 1.156)^1 ] * [ I(x3 >= -0.497)*(x3 - -0.497)^1 ]
-##  [ I(x1 >= -3.224)*(x1 - -3.224)^1 ] * [ I(x2 >= -3.038)*(x2 - -3.038)^1 ] * [ I(x3 >= -3.289)*(x3 - -3.289)^1 ]
+##          [ I(x1 >= 0.9)*(x1 - 0.9)^1 ] * [ I(x2 >= 0.118)*(x2 - 0.118)^1 ] * [ I(x3 >= -3.289)*(x3 - -3.289)^1 ]
+##    [ I(x1 >= 1.347)*(x1 - 1.347)^1 ] * [ I(x2 >= -3.038)*(x2 - -3.038)^1 ] * [ I(x3 >= -0.047)*(x3 - -0.047)^1 ]
+##                                        [ I(x2 >= -1.565)*(x2 - -1.565)^1 ] * [ I(x3 >= -3.289)*(x3 - -3.289)^1 ]
+##                                                                              [ I(x2 >= -1.322)*(x2 - -1.322)^1 ]
+##                                        [ I(x1 >= -3.224)*(x1 - -3.224)^1 ] * [ I(x2 >= -3.038)*(x2 - -3.038)^1 ]
 ##                                        [ I(x2 >= -3.038)*(x2 - -3.038)^1 ] * [ I(x3 >= -3.289)*(x3 - -3.289)^1 ]
+##                                          [ I(x2 >= 1.017)*(x2 - 1.017)^1 ] * [ I(x3 >= -3.289)*(x3 - -3.289)^1 ]
+##    [ I(x1 >= 0.521)*(x1 - 0.521)^1 ] * [ I(x2 >= -0.375)*(x2 - -0.375)^1 ] * [ I(x3 >= -3.289)*(x3 - -3.289)^1 ]
+##                                        [ I(x1 >= -1.092)*(x1 - -1.092)^1 ] * [ I(x3 >= -0.046)*(x3 - -0.046)^1 ]
+##                                            [ I(x2 >= -3.038)*(x2 - -3.038)^1 ] * [ I(x3 >= 0.45)*(x3 - 0.45)^1 ]
+##                                          [ I(x1 >= 0.594)*(x1 - 0.594)^1 ] * [ I(x3 >= -3.289)*(x3 - -3.289)^1 ]
 ##                                        [ I(x1 >= -3.224)*(x1 - -3.224)^1 ] * [ I(x2 >= -0.699)*(x2 - -0.699)^1 ]
-##                                          [ I(x1 >= -3.224)*(x1 - -3.224)^1 ] * [ I(x2 >= 1.595)*(x2 - 1.595)^1 ]
-##                                          [ I(x1 >= 1.135)*(x1 - 1.135)^1 ] * [ I(x2 >= -3.038)*(x2 - -3.038)^1 ]
-##    [ I(x1 >= 0.307)*(x1 - 0.307)^1 ] * [ I(x2 >= -3.038)*(x2 - -3.038)^1 ] * [ I(x3 >= -3.289)*(x3 - -3.289)^1 ]
 ##                                        [ I(x1 >= -0.422)*(x1 - -0.422)^1 ] * [ I(x3 >= -3.289)*(x3 - -3.289)^1 ]
-##  [ I(x1 >= -0.901)*(x1 - -0.901)^1 ] * [ I(x2 >= -3.038)*(x2 - -3.038)^1 ] * [ I(x3 >= -0.047)*(x3 - -0.047)^1 ]
+##                                                                              [ I(x2 >= -0.916)*(x2 - -0.916)^1 ]
+##  [ I(x1 >= -3.224)*(x1 - -3.224)^1 ] * [ I(x2 >= -3.038)*(x2 - -3.038)^1 ] * [ I(x3 >= -3.289)*(x3 - -3.289)^1 ]
 ##    [ I(x1 >= -0.313)*(x1 - -0.313)^1 ] * [ I(x2 >= 0.118)*(x2 - 0.118)^1 ] * [ I(x3 >= -3.289)*(x3 - -3.289)^1 ]
-##    [ I(x1 >= -3.224)*(x1 - -3.224)^1 ] * [ I(x2 >= 0.489)*(x2 - 0.489)^1 ] * [ I(x3 >= -0.257)*(x3 - -0.257)^1 ]
-##    [ I(x1 >= -3.224)*(x1 - -3.224)^1 ] * [ I(x2 >= 0.489)*(x2 - 0.489)^1 ] * [ I(x3 >= -3.289)*(x3 - -3.289)^1 ]
-##                                                                              [ I(x3 >= -3.289)*(x3 - -3.289)^1 ]
-##    [ I(x1 >= -3.224)*(x1 - -3.224)^1 ] * [ I(x2 >= -3.038)*(x2 - -3.038)^1 ] * [ I(x3 >= 0.595)*(x3 - 0.595)^1 ]
 ##  [ I(x1 >= -0.901)*(x1 - -0.901)^1 ] * [ I(x2 >= -0.375)*(x2 - -0.375)^1 ] * [ I(x3 >= -3.289)*(x3 - -3.289)^1 ]
-##    [ I(x1 >= -3.224)*(x1 - -3.224)^1 ] * [ I(x2 >= -3.038)*(x2 - -3.038)^1 ] * [ I(x3 >= 0.374)*(x3 - 0.374)^1 ]
+##  [ I(x1 >= -1.313)*(x1 - -1.313)^1 ] * [ I(x2 >= -3.038)*(x2 - -3.038)^1 ] * [ I(x3 >= -0.781)*(x3 - -0.781)^1 ]
+##  [ I(x1 >= -0.901)*(x1 - -0.901)^1 ] * [ I(x2 >= -3.038)*(x2 - -3.038)^1 ] * [ I(x3 >= -0.047)*(x3 - -0.047)^1 ]
+##    [ I(x1 >= -0.313)*(x1 - -0.313)^1 ] * [ I(x2 >= 0.772)*(x2 - 0.772)^1 ] * [ I(x3 >= -3.289)*(x3 - -3.289)^1 ]
 ##  [ I(x1 >= -3.224)*(x1 - -3.224)^1 ] * [ I(x2 >= -0.375)*(x2 - -0.375)^1 ] * [ I(x3 >= -3.289)*(x3 - -3.289)^1 ]
-##                                        [ I(x2 >= -0.699)*(x2 - -0.699)^1 ] * [ I(x3 >= -3.289)*(x3 - -3.289)^1 ]
+##                                        [ I(x2 >= -3.038)*(x2 - -3.038)^1 ] * [ I(x3 >= -0.317)*(x3 - -0.317)^1 ]
+##  [ I(x1 >= -0.313)*(x1 - -0.313)^1 ] * [ I(x2 >= -0.624)*(x2 - -0.624)^1 ] * [ I(x3 >= -3.289)*(x3 - -3.289)^1 ]
+##    [ I(x1 >= -3.224)*(x1 - -3.224)^1 ] * [ I(x2 >= -3.038)*(x2 - -3.038)^1 ] * [ I(x3 >= 1.202)*(x3 - 1.202)^1 ]
+##    [ I(x1 >= -3.224)*(x1 - -3.224)^1 ] * [ I(x2 >= -0.624)*(x2 - -0.624)^1 ] * [ I(x3 >= 0.174)*(x3 - 0.174)^1 ]
+##  [ I(x1 >= -3.224)*(x1 - -3.224)^1 ] * [ I(x2 >= -3.038)*(x2 - -3.038)^1 ] * [ I(x3 >= -0.047)*(x3 - -0.047)^1 ]
+##                                          [ I(x1 >= -3.224)*(x1 - -3.224)^1 ] * [ I(x2 >= 0.528)*(x2 - 0.528)^1 ]
+##    [ I(x1 >= 0.307)*(x1 - 0.307)^1 ] * [ I(x2 >= -3.038)*(x2 - -3.038)^1 ] * [ I(x3 >= -3.289)*(x3 - -3.289)^1 ]
+##                                        [ I(x1 >= -0.555)*(x1 - -0.555)^1 ] * [ I(x2 >= -3.038)*(x2 - -3.038)^1 ]
+##    [ I(x1 >= -3.224)*(x1 - -3.224)^1 ] * [ I(x2 >= -3.038)*(x2 - -3.038)^1 ] * [ I(x3 >= 0.595)*(x3 - 0.595)^1 ]
+##                                                                              [ I(x2 >= -0.816)*(x2 - -0.816)^1 ]
+##  [ I(x1 >= -3.224)*(x1 - -3.224)^1 ] * [ I(x2 >= -0.375)*(x2 - -0.375)^1 ] * [ I(x3 >= -1.331)*(x3 - -1.331)^1 ]
+##                                        [ I(x1 >= -3.224)*(x1 - -3.224)^1 ] * [ I(x2 >= -0.441)*(x2 - -0.441)^1 ]
+##    [ I(x1 >= 0.307)*(x1 - 0.307)^1 ] * [ I(x2 >= -0.624)*(x2 - -0.624)^1 ] * [ I(x3 >= -0.781)*(x3 - -0.781)^1 ]
 ##  [ I(x1 >= -0.901)*(x1 - -0.901)^1 ] * [ I(x2 >= -3.038)*(x2 - -3.038)^1 ] * [ I(x3 >= -3.289)*(x3 - -3.289)^1 ]
-##                                          [ I(x1 >= 0.739)*(x1 - 0.739)^1 ] * [ I(x3 >= -3.289)*(x3 - -3.289)^1 ]
-##                                          [ I(x1 >= 0.594)*(x1 - 0.594)^1 ] * [ I(x3 >= -3.289)*(x3 - -3.289)^1 ]
 ##                                        [ I(x1 >= -0.685)*(x1 - -0.685)^1 ] * [ I(x3 >= -3.289)*(x3 - -3.289)^1 ]
+##  [ I(x1 >= -3.224)*(x1 - -3.224)^1 ] * [ I(x2 >= -0.624)*(x2 - -0.624)^1 ] * [ I(x3 >= -3.289)*(x3 - -3.289)^1 ]
+##                                          [ I(x1 >= 1.135)*(x1 - 1.135)^1 ] * [ I(x2 >= -3.038)*(x2 - -3.038)^1 ]
+##                                        [ I(x2 >= -0.699)*(x2 - -0.699)^1 ] * [ I(x3 >= -3.289)*(x3 - -3.289)^1 ]
 ##                                                                                                             term

Note the length and width of these tables! The R environment might not be the optimal location to view the summary. Tip: Tables can be exported from R to LaTeX with the xtable R package.

@@ -280,12 +329,12 @@
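For instance, a minimal sketch of that export, assuming the coefficient table inside the summary object (accessed here as summary(hal_fit)$table) coerces cleanly to a data.frame:

library(xtable)
coef_tab <- as.data.frame(summary(hal_fit)$table)
# render the non-zero coefficient summary as a LaTeX table
print(xtable(coef_tab, caption = "Non-zero HAL coefficients"),
      include.rownames = FALSE)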

Obtaining model predictions

preds_hal <- predict(object = hal_fit, new_data = x)
mse_hal <- mse(preds = preds_hal, y = y)
mse_hal

-## [1] 0.04967736
+## [1] 0.04426571

oob_hal <- predict(object = hal_fit, new_data = test_x)
oob_hal_mse <- mse(preds = oob_hal, y = test_y)
oob_hal_mse

-## [1] 1.778584
+## [1] 1.803562

Reducing basis functions

@@ -303,9 +352,24 @@

This is done by passing the reduce_basis argument to the fit_hal function:

-hal_fit_reduced <- fit_hal(X = x, Y = y, reduce_basis = 0.1)
-hal_fit_reduced$times
+hal_fit_reduced <- fit_hal(X = x, Y = y, reduce_basis = 0.1)

+## Warning in fit_hal(X = x, Y = y, reduce_basis = 0.1): Dropping reduce_basis;
+## only applies if smoothness_orders = 0

+hal_fit_reduced$times
##                   user.self sys.self elapsed user.child sys.child
+## enumerate_basis       0.028    0.007   0.122          0         0
+## design_matrix         0.131    0.017   0.463          0         0
+## reduce_basis          0.000    0.000   0.000          0         0
+## remove_duplicates     0.000    0.000   0.000          0         0
+## lasso                 3.667    0.733  16.867          0         0
+## total                 3.826    0.757  17.453          0         0

In the above, all basis functions with fewer than 10% of observations meeting the criterion imposed are automatically removed prior to the Lasso step of fitting the HAL regression. The results appear below.
@@ -316,43 +380,44 @@ 

##              coef
-##  1: -1.4371363869
-##  2:  0.3753894167
-##  3: -0.2355902545
-##  4:  0.2166287283
-##  5: -0.1851691193
-##  6:  0.1631844732
-##  7:  0.1597858903
-##  8: -0.1234431109
-##  9: -0.0946988403
-## 10: -0.0884063929
-## 11:  0.0584079820
-## 12:  0.0563691217
-## 13: -0.0551713872
-## 14:  0.0533243296
-## 15: -0.0382683652
-## 16: -0.0372833807
-## 17: -0.0309200484
-## 18:  0.0230878908
-## 19:  0.0179840202
-## 20: -0.0156428631
-## 21:  0.0149240584
-## 22: -0.0144903662
-## 23: -0.0132097499
-## 24: -0.0065506011
-## 25: -0.0062280674
-## 26:  0.0051538711
-## 27:  0.0043861580
-## 28: -0.0024858177
-## 29:  0.0023100499
-## 30: -0.0022703275
-## 31:  0.0021266202
-## 32:  0.0017535038
-## 33:  0.0016041990
-## 34: -0.0010032801
-## 35: -0.0002829866
+##  1: -1.424003e+00
+##  2:  3.654157e-01
+##  3: -2.358064e-01
+##  4:  2.182822e-01
+##  5: -1.713943e-01
+##  6:  1.648782e-01
+##  7:  1.622979e-01
+##  8: -1.205285e-01
+##  9: -9.526696e-02
+## 10: -9.382067e-02
+## 11:  5.468026e-02
+## 12:  5.273315e-02
+## 13: -5.056465e-02
+## 14:  4.527346e-02
+## 15: -3.735277e-02
+## 16: -3.543529e-02
+## 17:  2.380514e-02
+## 18: -2.255508e-02
+## 19: -2.176114e-02
+## 20: -1.612916e-02
+## 21:  1.531317e-02
+## 22: -1.233111e-02
+## 23:  1.172920e-02
+## 24:  8.162029e-03
+## 25: -7.160000e-03
+## 26: -5.500155e-03
+## 27:  5.083293e-03
+## 28: -3.625005e-03
+## 29:  1.488567e-03
+## 30:  1.350348e-03
+## 31:  3.224244e-04
+## 32: -3.021403e-04
+## 33: -2.126168e-04
+## 34: -6.625025e-05
+## 35:  1.927162e-06
 ##              coef
 ##                                                                                                                term
 ##  1:                                                                                                     (Intercept)
@@ -365,31 +430,31 @@ 

##  8: [ I(x1 >= 0.941)*(x1 - 0.941)^1 ] * [ I(x3 >= -3.289)*(x3 - -3.289)^1 ]
##  9: [ I(x2 >= 1.017)*(x2 - 1.017)^1 ] * [ I(x3 >= -3.289)*(x3 - -3.289)^1 ]
## 10: [ I(x1 >= 1.368)*(x1 - 1.368)^1 ] * [ I(x3 >= -3.289)*(x3 - -3.289)^1 ]
-## 11: [ I(x1 >= -1.844)*(x1 - -1.844)^1 ]
-## 12: [ I(x2 >= -1.565)*(x2 - -1.565)^1 ] * [ I(x3 >= -3.289)*(x3 - -3.289)^1 ]
+## 11: [ I(x2 >= -1.565)*(x2 - -1.565)^1 ] * [ I(x3 >= -3.289)*(x3 - -3.289)^1 ]
+## 12: [ I(x1 >= -1.844)*(x1 - -1.844)^1 ]
## 13: [ I(x2 >= 0.696)*(x2 - 0.696)^1 ] * [ I(x3 >= -3.289)*(x3 - -3.289)^1 ]
## 14: [ I(x1 >= -3.224)*(x1 - -3.224)^1 ] * [ I(x2 >= 1.156)*(x2 - 1.156)^1 ] * [ I(x3 >= -0.497)*(x3 - -0.497)^1 ]
-## 15: [ I(x2 >= -3.038)*(x2 - -3.038)^1 ] * [ I(x3 >= -3.289)*(x3 - -3.289)^1 ]
-## 16: [ I(x1 >= -3.224)*(x1 - -3.224)^1 ] * [ I(x2 >= -3.038)*(x2 - -3.038)^1 ] * [ I(x3 >= -3.289)*(x3 - -3.289)^1 ]
-## 17: [ I(x1 >= -3.224)*(x1 - -3.224)^1 ] * [ I(x2 >= 1.595)*(x2 - 1.595)^1 ]
-## 18: [ I(x1 >= -3.224)*(x1 - -3.224)^1 ] * [ I(x2 >= -0.699)*(x2 - -0.699)^1 ]
-## 19: [ I(x1 >= -0.422)*(x1 - -0.422)^1 ] * [ I(x3 >= -3.289)*(x3 - -3.289)^1 ]
+## 15: [ I(x1 >= -3.224)*(x1 - -3.224)^1 ] * [ I(x2 >= -3.038)*(x2 - -3.038)^1 ] * [ I(x3 >= -3.289)*(x3 - -3.289)^1 ]
+## 16: [ I(x2 >= -3.038)*(x2 - -3.038)^1 ] * [ I(x3 >= -3.289)*(x3 - -3.289)^1 ]
+## 17: [ I(x1 >= -3.224)*(x1 - -3.224)^1 ] * [ I(x2 >= -0.699)*(x2 - -0.699)^1 ]
+## 18: [ I(x1 >= -3.224)*(x1 - -3.224)^1 ] * [ I(x2 >= 1.595)*(x2 - 1.595)^1 ]
+## 19: [ I(x1 >= 1.135)*(x1 - 1.135)^1 ] * [ I(x2 >= -3.038)*(x2 - -3.038)^1 ]
## 20: [ I(x1 >= 0.307)*(x1 - 0.307)^1 ] * [ I(x2 >= -3.038)*(x2 - -3.038)^1 ] * [ I(x3 >= -3.289)*(x3 - -3.289)^1 ]
-## 21: [ I(x1 >= -0.313)*(x1 - -0.313)^1 ] * [ I(x2 >= 0.118)*(x2 - 0.118)^1 ] * [ I(x3 >= -3.289)*(x3 - -3.289)^1 ]
-## 22: [ I(x1 >= 1.135)*(x1 - 1.135)^1 ] * [ I(x2 >= -3.038)*(x2 - -3.038)^1 ]
-## 23: [ I(x1 >= -0.901)*(x1 - -0.901)^1 ] * [ I(x2 >= -3.038)*(x2 - -3.038)^1 ] * [ I(x3 >= -0.047)*(x3 - -0.047)^1 ]
-## 24: [ I(x1 >= -0.901)*(x1 - -0.901)^1 ] * [ I(x2 >= -0.375)*(x2 - -0.375)^1 ] * [ I(x3 >= -3.289)*(x3 - -3.289)^1 ]
+## 21: [ I(x1 >= -0.422)*(x1 - -0.422)^1 ] * [ I(x3 >= -3.289)*(x3 - -3.289)^1 ]
+## 22: [ I(x1 >= -0.901)*(x1 - -0.901)^1 ] * [ I(x2 >= -3.038)*(x2 - -3.038)^1 ] * [ I(x3 >= -0.047)*(x3 - -0.047)^1 ]
+## 23: [ I(x1 >= -0.313)*(x1 - -0.313)^1 ] * [ I(x2 >= 0.118)*(x2 - 0.118)^1 ] * [ I(x3 >= -3.289)*(x3 - -3.289)^1 ]
+## 24: [ I(x1 >= -3.224)*(x1 - -3.224)^1 ] * [ I(x2 >= 0.489)*(x2 - 0.489)^1 ] * [ I(x3 >= -0.257)*(x3 - -0.257)^1 ]
## 25: [ I(x1 >= -3.224)*(x1 - -3.224)^1 ] * [ I(x2 >= 0.489)*(x2 - 0.489)^1 ] * [ I(x3 >= -3.289)*(x3 - -3.289)^1 ]
-## 26: [ I(x1 >= -3.224)*(x1 - -3.224)^1 ] * [ I(x2 >= -3.038)*(x2 - -3.038)^1 ] * [ I(x3 >= 0.595)*(x3 - 0.595)^1 ]
-## 27: [ I(x1 >= -3.224)*(x1 - -3.224)^1 ] * [ I(x2 >= 0.489)*(x2 - 0.489)^1 ] * [ I(x3 >= -0.257)*(x3 - -0.257)^1 ]
-## 28: [ I(x1 >= 0.739)*(x1 - 0.739)^1 ] * [ I(x3 >= -3.289)*(x3 - -3.289)^1 ]
-## 29: [ I(x2 >= -0.699)*(x2 - -0.699)^1 ] * [ I(x3 >= -3.289)*(x3 - -3.289)^1 ]
-## 30: [ I(x1 >= 0.594)*(x1 - 0.594)^1 ] * [ I(x3 >= -3.289)*(x3 - -3.289)^1 ]
-## 31: [ I(x1 >= -3.224)*(x1 - -3.224)^1 ] * [ I(x2 >= -3.038)*(x2 - -3.038)^1 ] * [ I(x3 >= 0.374)*(x3 - 0.374)^1 ]
-## 32: [ I(x1 >= -3.224)*(x1 - -3.224)^1 ] * [ I(x2 >= -0.375)*(x2 - -0.375)^1 ] * [ I(x3 >= -3.289)*(x3 - -3.289)^1 ]
-## 33: [ I(x1 >= -1.593)*(x1 - -1.593)^1 ]
-## 34: [ I(x3 >= -3.289)*(x3 - -3.289)^1 ]
-## 35: [ I(x1 >= -3.224)*(x1 - -3.224)^1 ] * [ I(x2 >= -3.038)*(x2 - -3.038)^1 ] * [ I(x3 >= 1.202)*(x3 - 1.202)^1 ]
+## 26: [ I(x3 >= -3.289)*(x3 - -3.289)^1 ]
+## 27: [ I(x1 >= -3.224)*(x1 - -3.224)^1 ] * [ I(x2 >= -3.038)*(x2 - -3.038)^1 ] * [ I(x3 >= 0.595)*(x3 - 0.595)^1 ]
+## 28: [ I(x1 >= -0.901)*(x1 - -0.901)^1 ] * [ I(x2 >= -0.375)*(x2 - -0.375)^1 ] * [ I(x3 >= -3.289)*(x3 - -3.289)^1 ]
+## 29: [ I(x1 >= -3.224)*(x1 - -3.224)^1 ] * [ I(x2 >= -3.038)*(x2 - -3.038)^1 ] * [ I(x3 >= 0.374)*(x3 - 0.374)^1 ]
+## 30: [ I(x1 >= -3.224)*(x1 - -3.224)^1 ] * [ I(x2 >= -0.375)*(x2 - -0.375)^1 ] * [ I(x3 >= -3.289)*(x3 - -3.289)^1 ]
+## 31: [ I(x2 >= -0.699)*(x2 - -0.699)^1 ] * [ I(x3 >= -3.289)*(x3 - -3.289)^1 ]
+## 32: [ I(x1 >= -0.901)*(x1 - -0.901)^1 ] * [ I(x2 >= -3.038)*(x2 - -3.038)^1 ] * [ I(x3 >= -3.289)*(x3 - -3.289)^1 ]
+## 33: [ I(x1 >= 0.739)*(x1 - 0.739)^1 ] * [ I(x3 >= -3.289)*(x3 - -3.289)^1 ]
+## 34: [ I(x1 >= 0.594)*(x1 - 0.594)^1 ] * [ I(x3 >= -3.289)*(x3 - -3.289)^1 ]
+## 35: [ I(x1 >= -0.685)*(x1 - -0.685)^1 ] * [ I(x3 >= -3.289)*(x3 - -3.289)^1 ]
##                                                                                                                term

Other approaches exist for reducing the set of basis functions before they are actually created, which is essential for most realistic applications.

@@ -428,7 +493,7 @@
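As the warning earlier in this section notes, the reduce_basis filter only applies with zero-order splines; a hedged sketch of the case where it does take effect, reusing the x and y simulated above:

# assumption: zero-order smoothness re-enables the reduce_basis filter
hal_fit_zero <- fit_hal(
  X = x, Y = y, smoothness_orders = 0, reduce_basis = 0.1
)
hal_fit_zero$times["reduce_basis", ]  # the filtering step now actually runs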

Specifying smoothness of the HAL

HAL can be fit under the constraint that the total variation of the function’s \(k^{\text{th}}\) derivative is bounded by some constant, which is selected with cross-validation.
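In symbols, this is the constraint being described (the notation here is ours, not the package's):

\[
  \hat{f} = \underset{f}{\arg\min} \; \frac{1}{n} \sum_{i=1}^{n}
    L\big(f(X_i), Y_i\big)
  \quad \text{subject to} \quad
  \mathrm{TV}\big(f^{(k)}\big) \le C,
\]

where \(k\) is the smoothness order, \(\mathrm{TV}\) denotes total variation, and the bound \(C\) (equivalently, the lasso penalty) is chosen by cross-validation.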

Let’s see this in action.

 set.seed(98109)
 num_knots <- 100 # Try changing this value to see what happens.
 n_covars <- 1
@@ -474,16 +539,20 @@ 

A moderate number of knot points is typically sufficient for a near-optimal fit. Therefore, one can safely pass a smaller value to num_knots for a big decrease in runtime without sacrificing performance.
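A hedged sketch of that advice, assuming the one-covariate x, y, and ytrue simulated above:

# fewer knots: a large runtime reduction for little (if any) loss in fit
hal_fit_small <- fit_hal(X = x, Y = y, smoothness_orders = 1, num_knots = 20)
pred_small <- predict(hal_fit_small, new_data = x)
mean((pred_small - ytrue)^2)  # typically close to the num_knots = 100 fit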

 mean((pred_0 - ytrue)^2)
## [1] 0.00732315
-mean((pred_smooth_1- ytrue)^2)
-## [1] 0.002432486
-mean((pred_smooth_2 - ytrue)^2)
-## [1] 0.001848927
+mean((pred_smooth_1- ytrue)^2)
+## [1] 0.003458005
+mean((pred_smooth_2 - ytrue)^2)
+## [1] 0.001834611
 dt <- data.table(x = as.vector(x),
                  ytrue = ytrue,
                  y = y,
@@ -493,13 +562,13 @@ 

long <- melt(dt, id = "x")
ggplot(long, aes(x = x, y = value, color = variable)) +
  geom_line()

 plot(x, pred_0, main = "Zero order smoothness fit")

 plot(x, pred_smooth_1, main = "First order smoothness fit")

 plot(x, pred_smooth_2, main = "Second order smoothness fit")

In general, if the basis functions are not coarse, then the …

@@ -513,7 +582,11 @@

Comparing the following simulation and the previous one, the HAL with second-order smoothness performed better when there were fewer knot points.

 set.seed(98109)
 n_covars <- 1
 n_obs <- 250
@@ -535,22 +608,22 @@ 

pred_0 <- predict(hal_fit_0, new_data = x)
pred_smooth_1 <- predict(hal_fit_smooth_1, new_data = x)
pred_smooth_2 <- predict(hal_fit_smooth_2, new_data = x)

mean((pred_0 - ytrue)^2)
## [1] 0.00732315
-mean((pred_smooth_1- ytrue)^2)
-## [1] 0.002432486
-mean((pred_smooth_2 - ytrue)^2)
-## [1] 0.001834611
+mean((pred_smooth_1- ytrue)^2)
+## [1] 0.003458005
+mean((pred_smooth_2 - ytrue)^2)
+## [1] 0.001848927
 plot(x, pred_0, main = "Zero order smoothness fit")

 plot(x, pred_smooth_1, main = "First order smoothness fit")

 plot(x, pred_smooth_2, main = "Second order smoothness fit")

@@ -570,7 +643,11 @@

Formula interface
 set.seed(98109)
 num_knots <- 100
 
@@ -588,7 +665,11 @@ 

 # The `h` function is used to specify the basis functions for a given term
 # h(x1) generates one-way basis functions for the variable x1
 # This is an additive model:
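The assignment these comments introduce is elided by this hunk; judging from the later references to form_term, it presumably resembles the following (a reconstruction, not the hunk's own code):

form_term <- h(x1) + h(x2) + h(A)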
@@ -607,7 +688,7 @@ 

##
## $orders
## [1] 0

 # We don't need the variables in the parent environment if we specify them directly:
 rm(smoothness_orders)
 rm(num_knots)
@@ -616,7 +697,7 @@ 

# They are the same!
length(form_term_new$basis_list) == length(form_term$basis_list)

## [1] TRUE
# To evaluate an unevaluated formula object like:
 formula <- ~h(x1) + h(x2) + h(A)
 # we can use the formula_hal function:
@@ -633,13 +714,17 @@ 

 
 
## [1] "x1" "x2" "A"
 # Shortcut:
 formula1 <- h(.)
 # Longcut:
@@ -647,7 +732,7 @@ 

# Same number of basis functions
length(formula1$basis_list) == length(formula2$basis_list)

## [1] TRUE
 # Maybe we only want an additive model for x1 and x2
 # Use the `.` argument
 formula1 <- h(., . = c("x1", "x2"))
@@ -655,19 +740,19 @@ 

length(formula1$basis_list) == length(formula2$basis_list)

## [1] TRUE

We can specify interactions as follows.

# Two-way interactions
formula1 <- h(x1) + h(x2) + h(A) + h(x1, x2)
formula2 <- h(.) + h(x1, x2)
length(formula1$basis_list) == length(formula2$basis_list)
## [1] TRUE

# All one-way terms plus all two-way interactions
formula1 <- h(.) + h(x1, x2) + h(x1, A) + h(x2, A)
formula2 <- h(.) + h(., .)
length(formula1$basis_list) == length(formula2$basis_list)
## [1] TRUE

# Three-way interactions
formula1 <- h(.) + h(., .) + h(x1, A, x2)
formula2 <- h(.) + h(., .) + h(., ., .)
@@ -677,7 +762,11 @@ 

# Write it all out
formula <- h(x1) + h(x2) + h(A) + h(A, x1) + h(A, x2)
@@ -706,7 +795,11 @@ 

 # An additive monotone increasing model
 formula <- formula_hal(
   y ~ h(., monotone = "i"), X, smoothness_orders = 0, num_knots = 100
@@ -729,7 +822,7 @@ 

  ~ h(., monotone = "d") + h(., ., monotone = "d"), X,
  smoothness_orders = 1, num_knots = 100
)

The penalization feature can be used to reproduce glm.
 # Additive glm
 # One knot (at the origin) and first order smoothness
 formula <- h(., s = 1, k = 1, pf = 0)
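A hedged sketch of the equivalence being set up here, assuming the X and Y used throughout this section: with a single knot, first-order smoothness, and penalty factor zero, each term reduces to an unpenalized linear main effect, so the fit should closely match an ordinary glm.

# sketch under the assumptions above; X and Y come from this section
hal_glm <- fit_hal(
  X = X, Y = Y, formula = ~ h(., s = 1, k = 1, pf = 0),
  fit_control = list(cv_select = FALSE), lambda = 1e-5
)
ols <- glm(Y ~ ., data = data.frame(X), family = gaussian())
# the two prediction vectors should agree closely
max(abs(predict(hal_glm, new_data = X) - unname(predict(ols))))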
@@ -741,27 +834,31 @@ 

 # get formula object
 fit <- fit_hal(
   X = X, Y = Y, formula = ~ h(.), smoothness_orders = 1, num_knots = 100
 )
 print(summary(fit), 10) # prints top 10 rows, i.e., highest absolute coefs
## 
-## Summary of top 10 non-zero coefficients is based on lambda of 0.0005900481 
+## Summary of top 10 non-zero coefficients is based on lambda of 0.0005376299 
 ## 
 ##        coef                                term
-##   0.7395438                         (Intercept)
-##  -0.3026891   [ I(A >= -3.978)*(A - -3.978)^1 ]
-##  -0.2948108 [ I(x2 >= -3.992)*(x2 - -3.992)^1 ]
-##  -0.2763138   [ I(x1 >= 1.172)*(x1 - 1.172)^1 ]
-##  -0.2493828   [ I(x2 >= 1.638)*(x2 - 1.638)^1 ]
-##  -0.2475319 [ I(x1 >= -3.972)*(x1 - -3.972)^1 ]
-##   0.2098793   [ I(x1 >= -1.45)*(x1 - -1.45)^1 ]
-##   0.2079505   [ I(A >= -1.356)*(A - -1.356)^1 ]
-##  -0.2075157     [ I(A >= 1.384)*(A - 1.384)^1 ]
-##   0.2035278 [ I(x2 >= -1.293)*(x2 - -1.293)^1 ]
+##   0.7473715                         (Intercept)
+##  -0.3081866   [ I(A >= -3.978)*(A - -3.978)^1 ]
+##  -0.2894577 [ I(x2 >= -3.992)*(x2 - -3.992)^1 ]
+##  -0.2811616   [ I(x1 >= 1.172)*(x1 - 1.172)^1 ]
+##  -0.2463477 [ I(x1 >= -3.972)*(x1 - -3.972)^1 ]
+##  -0.2405977   [ I(x2 >= 1.638)*(x2 - 1.638)^1 ]
+##   0.2092225   [ I(x1 >= -1.45)*(x1 - -1.45)^1 ]
+##   0.2078703 [ I(x2 >= -1.293)*(x2 - -1.293)^1 ]
+##  -0.2039605     [ I(A >= 1.384)*(A - 1.384)^1 ]
+##   0.1993688   [ I(A >= -1.356)*(A - -1.356)^1 ]
 plot(predict(fit, new_data = X), Y)

diff --git a/docs/articles/intro_hal9001_files/figure-html/unnamed-chunk-13-1.png b/docs/articles/intro_hal9001_files/figure-html/unnamed-chunk-13-1.png
index 32553fc6..1348298c 100644
Binary files a/docs/articles/intro_hal9001_files/figure-html/unnamed-chunk-13-1.png and b/docs/articles/intro_hal9001_files/figure-html/unnamed-chunk-13-1.png differ
diff --git a/docs/articles/intro_hal9001_files/figure-html/unnamed-chunk-3-1.png b/docs/articles/intro_hal9001_files/figure-html/unnamed-chunk-3-1.png
index d12920ac..ee2008b3 100644
Binary files a/docs/articles/intro_hal9001_files/figure-html/unnamed-chunk-3-1.png and b/docs/articles/intro_hal9001_files/figure-html/unnamed-chunk-3-1.png differ
diff --git a/docs/articles/intro_hal9001_files/figure-html/unnamed-chunk-3-3.png b/docs/articles/intro_hal9001_files/figure-html/unnamed-chunk-3-3.png
index 6b1c3032..b4891964 100644
Binary files a/docs/articles/intro_hal9001_files/figure-html/unnamed-chunk-3-3.png and b/docs/articles/intro_hal9001_files/figure-html/unnamed-chunk-3-3.png differ
diff --git a/docs/articles/intro_hal9001_files/figure-html/unnamed-chunk-3-4.png b/docs/articles/intro_hal9001_files/figure-html/unnamed-chunk-3-4.png
index 957b9b53..f2a89afe 100644
Binary files a/docs/articles/intro_hal9001_files/figure-html/unnamed-chunk-3-4.png and b/docs/articles/intro_hal9001_files/figure-html/unnamed-chunk-3-4.png differ
diff --git a/docs/articles/intro_hal9001_files/figure-html/unnamed-chunk-5-2.png b/docs/articles/intro_hal9001_files/figure-html/unnamed-chunk-5-2.png
index 6b1c3032..b4891964 100644
Binary files a/docs/articles/intro_hal9001_files/figure-html/unnamed-chunk-5-2.png and b/docs/articles/intro_hal9001_files/figure-html/unnamed-chunk-5-2.png differ
diff --git a/docs/articles/intro_hal9001_files/figure-html/unnamed-chunk-5-3.png b/docs/articles/intro_hal9001_files/figure-html/unnamed-chunk-5-3.png
index f2a89afe..957b9b53 100644
Binary files a/docs/articles/intro_hal9001_files/figure-html/unnamed-chunk-5-3.png and b/docs/articles/intro_hal9001_files/figure-html/unnamed-chunk-5-3.png differ
diff --git a/docs/authors.html b/docs/authors.html
index b4c87ab8..9f3e40fb 100644
--- a/docs/authors.html
+++ b/docs/authors.html
@@ -23,7 +23,7 @@
 hal9001
-      0.4.3
+      0.4.5

@@ -103,13 +103,17 @@

Citation

Coyle J, Hejazi N, Phillips R, van der Laan L, van der Laan M (2022). hal9001: The scalable highly adaptive lasso.
-doi:10.5281/zenodo.3558313, R package version 0.4.3, https://github.com/tlverse/hal9001.
+doi:10.5281/zenodo.3558313, R package version 0.4.5, https://github.com/tlverse/hal9001.

@Manual{,
   title = {{hal9001}: The scalable highly adaptive lasso},
   author = {Jeremy R Coyle and Nima S Hejazi and Rachael V Phillips and Lars WP {van der Laan} and Mark J {van der Laan}},
   year = {2022},
-  note = {R package version 0.4.3},
+  note = {R package version 0.4.5},
   doi = {10.5281/zenodo.3558313},
   url = {https://github.com/tlverse/hal9001},
 }
diff --git a/docs/index.html b/docs/index.html
index 4faf8401..b68b1ec1 100644
--- a/docs/index.html
+++ b/docs/index.html
@@ -49,7 +49,7 @@
 hal9001
-      0.4.3
+      0.4.5
@@ -128,7 +128,11 @@

Example

# load the package and set a seed
library(hal9001)
#> Loading required package: Rcpp
#> hal9001 v0.4.5: The Scalable Highly Adaptive Lasso
#> note: fit_hal defaults have changed. See ?fit_hal for details

set.seed(385971)

@@ -143,12 +147,21 @@

#> [1] "I'm sorry, Dave. I'm afraid I can't do that."
hal_fit$times
#>                   user.self sys.self elapsed user.child sys.child
#> enumerate_basis       0.008    0.001   0.009          0         0
#> design_matrix         0.003    0.000   0.003          0         0
#> reduce_basis          0.000    0.000   0.000          0         0
#> remove_duplicates     0.000    0.000   0.001          0         0
#> lasso                 1.690    0.036   1.764          0         0
#> total                 1.702    0.037   1.778          0         0

# training sample prediction
preds <- predict(hal_fit, new_data = x)

diff --git a/docs/news/index.html b/docs/news/index.html
index 880a8638..ae051c98 100644
--- a/docs/news/index.html
+++ b/docs/news/index.html
@@ -23,7 +23,7 @@
 hal9001
-      0.4.3
+      0.4.5

@@ -59,6 +59,20 @@

Changelog

Source: NEWS.md

+hal9001 0.4.5
+  • Added multivariate outcome prediction
+
+hal9001 0.4.4
+  • Fixed bug with prediction_bounds (a fit_hal argument in fit_control list), which would error when it was specified as a numeric vector. Also, added a check to assert this argument is correctly specified, and tests to ensure a numeric vector of bounds is provided.
+  • Simplified fit_control list arguments in fit_hal. Users can still specify additional arguments to cv.glmnet and glmnet in this list.
+  • Defined weights as a formal argument in fit_hal, as opposed to an optional argument in fit_control, to facilitate specification and avoid confusion. This increases flexibility with SuperLearner wrapper SL.hal9001 as well; fit_control can now be customized with SL.hal9001.
+
+hal9001 0.4.3
+  • Version bump for CRAN resubmission following archiving.
 hal9001 0.4.2
  • Version bump for CRAN resubmission following archiving.
diff --git a/docs/pkgdown.yml b/docs/pkgdown.yml
index ce08e0d7..57535cfb 100644
--- a/docs/pkgdown.yml
+++ b/docs/pkgdown.yml
@@ -3,7 +3,11 @@
 pkgdown: 2.0.3
 pkgdown_sha: ~
 articles:
   intro_hal9001: intro_hal9001.html
-last_built: 2022-07-13T01:48Z
+last_built: 2022-11-04T20:06Z
 urls:
   reference: https://tlverse.org/hal9001/reference
   article: https://tlverse.org/hal9001/articles

diff --git a/docs/reference/SL.hal9001.html b/docs/reference/SL.hal9001.html
index d10342df..9b649b08 100644
--- a/docs/reference/SL.hal9001.html
+++ b/docs/reference/SL.hal9001.html
@@ -23,7 +23,7 @@
 hal9001
-      0.4.3
+      0.4.5
@@ -68,15 +68,13 @@

Wrapper for Classic SuperLearner

SL.hal9001(
   Y,
   X,
-  newX = NULL,
-  family = stats::gaussian(),
-  obsWeights = rep(1, length(Y)),
-  id = NULL,
-  max_degree = ifelse(ncol(X) >= 20, 2, 3),
+  newX,
+  family,
+  obsWeights,
+  id,
+  max_degree = 2,
   smoothness_orders = 1,
-  num_knots = ifelse(smoothness_orders >= 1, 25, 50),
-  reduce_basis = 1/sqrt(length(Y)),
-  lambda = NULL,
+  num_knots = 5,
   ...
 )
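For reference, a sketch of calling the revised wrapper directly; this mirrors the updated usage in test-sl_ecpolley.R later in this diff (y and x are assumed to be a numeric outcome vector and covariate matrix):

library(SuperLearner)
sl_hal_fit <- SL.hal9001(
  Y = y, X = x, newX = NULL, family = stats::gaussian(),
  obsWeights = rep(1, length(y)), id = seq_along(y)
)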
@@ -112,20 +110,8 @@

Arguments

specifying the maximum number of knot points (i.e., bins) for each covariate for generating basis functions. See num_knots argument in fit_hal for more information.

-reduce_basis
-A numeric value bounded in the open unit interval indicating the minimum
-proportion of 1's in a basis function column needed for the basis function
-to be included in the procedure to fit the lasso. Any basis functions with
-a lower proportion of 1's than the cutoff will be removed.

-lambda
-A user-specified sequence of values of the regularization parameter for the
-lasso L1 regression. If NULL, the default sequence in cv.glmnet will be
-used. The cross-validated optimal value of this regularization parameter
-will be selected with cv.glmnet.

 ...
-Not used.
+Additional arguments to fit_hal.

Value

diff --git a/docs/reference/apply_copy_map.html b/docs/reference/apply_copy_map.html index 609faab4..c7bd4772 100644 --- a/docs/reference/apply_copy_map.html +++ b/docs/reference/apply_copy_map.html @@ -23,7 +23,7 @@ hal9001 - 0.4.3 + 0.4.5
diff --git a/docs/reference/as_dgCMatrix.html b/docs/reference/as_dgCMatrix.html index 19770364..c5864c12 100644 --- a/docs/reference/as_dgCMatrix.html +++ b/docs/reference/as_dgCMatrix.html @@ -25,7 +25,7 @@ hal9001 - 0.4.3 + 0.4.5
diff --git a/docs/reference/basis_list_cols.html b/docs/reference/basis_list_cols.html index 25e30700..4d8fe72b 100644 --- a/docs/reference/basis_list_cols.html +++ b/docs/reference/basis_list_cols.html @@ -23,7 +23,7 @@ hal9001 - 0.4.3 + 0.4.5
diff --git a/docs/reference/basis_of_degree.html b/docs/reference/basis_of_degree.html index 9e6d0da5..db22bf40 100644 --- a/docs/reference/basis_of_degree.html +++ b/docs/reference/basis_of_degree.html @@ -23,7 +23,7 @@ hal9001 - 0.4.3 + 0.4.5
diff --git a/docs/reference/calc_pnz.html b/docs/reference/calc_pnz.html index 1d5a5b48..60788892 100644 --- a/docs/reference/calc_pnz.html +++ b/docs/reference/calc_pnz.html @@ -23,7 +23,7 @@ hal9001 - 0.4.3 + 0.4.5

diff --git a/docs/reference/calc_xscale.html b/docs/reference/calc_xscale.html index b00b4c64..a1ea2570 100644 --- a/docs/reference/calc_xscale.html +++ b/docs/reference/calc_xscale.html @@ -23,7 +23,7 @@ hal9001 - 0.4.3 + 0.4.5
diff --git a/docs/reference/cv_lasso.html b/docs/reference/cv_lasso.html index 46721e99..0fdd5e63 100644 --- a/docs/reference/cv_lasso.html +++ b/docs/reference/cv_lasso.html @@ -24,7 +24,7 @@ hal9001 - 0.4.3 + 0.4.5
diff --git a/docs/reference/cv_lasso_early_stopping.html b/docs/reference/cv_lasso_early_stopping.html index 4bf959c2..395f3dec 100644 --- a/docs/reference/cv_lasso_early_stopping.html +++ b/docs/reference/cv_lasso_early_stopping.html @@ -24,7 +24,7 @@ hal9001 - 0.4.3 + 0.4.5
@@ -99,7 +99,7 @@

Arguments

-Site built with pkgdown 2.0.2.
+Site built with pkgdown 2.0.3.

diff --git a/docs/reference/enumerate_basis.html b/docs/reference/enumerate_basis.html index f703cd4f..390c6af6 100644 --- a/docs/reference/enumerate_basis.html +++ b/docs/reference/enumerate_basis.html @@ -24,7 +24,7 @@ hal9001 - 0.4.3 + 0.4.5
diff --git a/docs/reference/enumerate_edge_basis.html b/docs/reference/enumerate_edge_basis.html index e2fc5b7a..c38f93fc 100644 --- a/docs/reference/enumerate_edge_basis.html +++ b/docs/reference/enumerate_edge_basis.html @@ -28,7 +28,7 @@ hal9001 - 0.4.3 + 0.4.5

diff --git a/docs/reference/evaluate_basis.html b/docs/reference/evaluate_basis.html index 535d5286..896ceb95 100644 --- a/docs/reference/evaluate_basis.html +++ b/docs/reference/evaluate_basis.html @@ -23,7 +23,7 @@ hal9001 - 0.4.3 + 0.4.5
diff --git a/docs/reference/fit_hal.html b/docs/reference/fit_hal.html index e66ae317..8aac31ee 100644 --- a/docs/reference/fit_hal.html +++ b/docs/reference/fit_hal.html @@ -23,7 +23,7 @@ hal9001 - 0.4.3 + 0.4.5
@@ -74,13 +74,14 @@

HAL: The Highly Adaptive Lasso

  smoothness_orders = 1,
  num_knots = num_knots_generator(max_degree = max_degree, smoothness_orders =
    smoothness_orders, base_num_knots_0 = 200, base_num_knots_1 = 50),
-  reduce_basis = 1/sqrt(length(Y)),
-  family = c("gaussian", "binomial", "poisson", "cox"),
+  reduce_basis = NULL,
+  family = c("gaussian", "binomial", "poisson", "cox", "mgaussian"),
  lambda = NULL,
  id = NULL,
+  weights = NULL,
  offset = NULL,
-  fit_control = list(cv_select = TRUE, n_folds = 10, foldid = NULL, use_min = TRUE,
-    lambda.min.ratio = 1e-04, prediction_bounds = "default"),
+  fit_control = list(cv_select = TRUE, use_min = TRUE, lambda.min.ratio = 1e-04,
+    prediction_bounds = "default"),
  basis_list = NULL,
  return_lasso = TRUE,
  return_x_basis = FALSE,

@@ -95,7 +96,9 @@

Arguments

number of covariates that will be used to derive the design matrix of basis functions.

Y
-A numeric vector of observations of the outcome variable.
+A numeric vector of observations of the outcome variable. For
+family="mgaussian", Y is a matrix of observations of the
+outcome variables.

formula

A character string formula to be used in formula_hal. See its documentation for details.

@@ -123,19 +126,21 @@

Arguments

basis functions generated. This allows the complexity of the optimization problem to grow scalably. See details of num_knots more information.

reduce_basis
-A numeric value bounded in the open unit interval indicating the minimum
-proportion of 1's in a basis function column needed for the basis function
-to be included in the procedure to fit the lasso. Any basis functions with
-a lower proportion of 1's than the cutoff will be removed. When
-reduce_basis is set to NULL, all basis functions are used in the
-lasso-fitting stage of fit_hal.
+An optional numeric value bounded in the open unit interval indicating the
+minimum proportion of 1's in a basis function column needed for the basis
+function to be included in the procedure to fit the lasso. Any basis
+functions with a lower proportion of 1's than the cutoff will be removed.
+Defaults to 1 over the square root of the number of observations. Only
+applicable for models fit with zero-order splines, i.e.
+smoothness_orders = 0.

family

A character or a family object (supported by glmnet) specifying the
error/link family for a generalized linear model. character options are
limited to "gaussian" for fitting a standard penalized linear model,
"binomial" for penalized logistic regression, "poisson" for penalized
Poisson regression,
-and "cox" for a penalized proportional hazards model. Note that passing in
+"cox" for a penalized proportional hazards model, and "mgaussian" for
+a multivariate penalized linear model. Note that passing in
family objects leads to slower performance relative to passing in a
character family (if supported). For example, one should set
family = "binomial" instead of family = binomial() when

@@ -153,39 +158,36 @@

Arguments

A vector of ID values that is used to generate cross-validation folds for cv.glmnet. This argument is ignored when fit_control's cv_select argument is FALSE.

+weights
+observation weights; defaults to 1 per observation.

offset

a vector of offset values, used in fitting.

fit_control

-List of arguments for fitting. Includes the following arguments, and any
-others to be passed to cv.glmnet or glmnet.
+List of arguments, including the following, and any others to be passed to
+cv.glmnet or glmnet.

  • cv_select: A logical specifying if the sequence of specified lambda values should be passed to cv.glmnet in order for a single, optimal value of lambda to be selected according to cross-validation. When cv_select = FALSE, a glmnet model will be used to fit the sequence of (or single) lambda.
-  • n_folds: Integer for the number of folds to be used when splitting the data for V-fold cross-validation. Only used when cv_select = TRUE.
-  • foldid: An optional numeric containing values between 1 and n_folds, identifying the fold to which each observation is assigned. If supplied, n_folds can be missing. In such a case, this vector is passed directly to cv.glmnet. Only used when cv_select = TRUE.
  • use_min: Specify the choice of lambda to be selected by cv.glmnet. When TRUE, "lambda.min" is used; otherwise, "lambda.1se". Only used when cv_select = TRUE.
-  • lambda.min.ratio: A glmnet argument specifying the smallest value for lambda, as a fraction of lambda.max, the (data derived) entry value (i.e. the smallest value for which all coefficients are zero). We've seen that not setting lambda.min.ratio can lead to no lambda values that fit the data sufficiently well.
-  • prediction_bounds: A vector of size two that provides the lower and upper bounds for predictions. When prediction_bounds = "default", the predictions are bounded between min(Y) - sd(Y) and max(Y) + sd(Y). Bounding ensures that there is no extrapolation, and it is necessary for cross-validation selection and/or Super Learning.
+  • lambda.min.ratio: A glmnet argument specifying the smallest value for lambda, as a fraction of lambda.max, the (data derived) entry value (i.e. the smallest value for which all coefficients are zero). We've seen that not setting lambda.min.ratio can lead to no lambda values that fit the data sufficiently well.
+  • prediction_bounds: An optional vector of size two that provides the lower and upper bounds for predictions; not used when family = "cox". When prediction_bounds = "default", the predictions are bounded between min(Y) - sd(Y) and max(Y) + sd(Y) for each outcome (when family = "mgaussian", each outcome can have different bounds). Bounding ensures that there is no extrapolation.

    basis_list

    The full set of basis functions generated from X.
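Pulling the updated arguments together, a hedged sketch of one call exercising the revised interface (simulated data; all values illustrative):

library(hal9001)
set.seed(27158)
n <- 200
x <- matrix(rnorm(n * 3), n, 3)
y <- cbind(y1 = sin(x[, 1]) + rnorm(n, sd = 0.2),
           y2 = cos(x[, 2]) + rnorm(n, sd = 0.2))
fit <- fit_hal(
  X = x, Y = y, family = "mgaussian",  # multivariate outcome support
  weights = rep(1, n),                 # weights is now a formal argument
  fit_control = list(
    cv_select = TRUE, use_min = TRUE, lambda.min.ratio = 1e-4,
    prediction_bounds = "default"      # or a numeric vector of length two
  )
)
preds <- predict(fit, new_data = x)    # one column of predictions per outcome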

    diff --git a/docs/reference/formula_hal.html b/docs/reference/formula_hal.html index e8230126..72f5de0d 100644 --- a/docs/reference/formula_hal.html +++ b/docs/reference/formula_hal.html @@ -23,7 +23,7 @@ hal9001 - 0.4.3 + 0.4.5

diff --git a/docs/reference/formula_helpers.html b/docs/reference/formula_helpers.html index a7d761fb..a44a2397 100644 --- a/docs/reference/formula_helpers.html +++ b/docs/reference/formula_helpers.html @@ -23,7 +23,7 @@ hal9001 - 0.4.3 + 0.4.5
diff --git a/docs/reference/h.html b/docs/reference/h.html index dde8a391..ee946331 100644 --- a/docs/reference/h.html +++ b/docs/reference/h.html @@ -23,7 +23,7 @@ hal9001 - 0.4.3 + 0.4.5

diff --git a/docs/reference/hal9000.html b/docs/reference/hal9000.html index d2b2a1e0..10fb358f 100644 --- a/docs/reference/hal9000.html +++ b/docs/reference/hal9000.html @@ -23,7 +23,7 @@ hal9001 - 0.4.3 + 0.4.5 diff --git a/docs/reference/hal9001.html b/docs/reference/hal9001.html index 9721fdbb..082d05ed 100644 --- a/docs/reference/hal9001.html +++ b/docs/reference/hal9001.html @@ -23,7 +23,7 @@ hal9001 - 0.4.3 + 0.4.5 diff --git a/docs/reference/hal_quotes.html b/docs/reference/hal_quotes.html index 8c94ebad..7f14d9b4 100644 --- a/docs/reference/hal_quotes.html +++ b/docs/reference/hal_quotes.html @@ -24,7 +24,7 @@ hal9001 - 0.4.3 + 0.4.5 diff --git a/docs/reference/index.html b/docs/reference/index.html index 0c29898d..ad92e617 100644 --- a/docs/reference/index.html +++ b/docs/reference/index.html @@ -23,7 +23,7 @@ hal9001 - 0.4.3 + 0.4.5 diff --git a/docs/reference/index_first_copy.html b/docs/reference/index_first_copy.html index ebb7fa38..9129c585 100644 --- a/docs/reference/index_first_copy.html +++ b/docs/reference/index_first_copy.html @@ -24,7 +24,7 @@ hal9001 - 0.4.3 + 0.4.5 diff --git a/docs/reference/lassi.html b/docs/reference/lassi.html index b7c8f2c4..0e3db014 100644 --- a/docs/reference/lassi.html +++ b/docs/reference/lassi.html @@ -23,7 +23,7 @@ hal9001 - 0.4.3 + 0.4.5 @@ -104,7 +104,7 @@

Arguments

-Site built with pkgdown 2.0.2.
+Site built with pkgdown 2.0.3.

diff --git a/docs/reference/lassi_fit_module.html b/docs/reference/lassi_fit_module.html index 52199f43..d563a3b9 100644 --- a/docs/reference/lassi_fit_module.html +++ b/docs/reference/lassi_fit_module.html @@ -23,7 +23,7 @@ hal9001 - 0.4.3 + 0.4.5 @@ -78,7 +78,7 @@

Rcpp module: lassi_fit_module

-Site built with pkgdown 2.0.2.
+Site built with pkgdown 2.0.3.

diff --git a/docs/reference/lassi_origami.html b/docs/reference/lassi_origami.html index cf65496f..be67aded 100644 --- a/docs/reference/lassi_origami.html +++ b/docs/reference/lassi_origami.html @@ -26,7 +26,7 @@ hal9001 - 0.4.3 + 0.4.5 @@ -103,7 +103,7 @@

Arguments

-Site built with pkgdown 2.0.2.
+Site built with pkgdown 2.0.3.

diff --git a/docs/reference/lassi_predict.html b/docs/reference/lassi_predict.html index eb63fe65..6527eab9 100644 --- a/docs/reference/lassi_predict.html +++ b/docs/reference/lassi_predict.html @@ -23,7 +23,7 @@ hal9001 - 0.4.3 + 0.4.5 @@ -90,7 +90,7 @@

Arguments

-Site built with pkgdown 2.0.2.
+Site built with pkgdown 2.0.3.

diff --git a/docs/reference/make_basis_list.html b/docs/reference/make_basis_list.html index 7c265ce1..a57979d7 100644 --- a/docs/reference/make_basis_list.html +++ b/docs/reference/make_basis_list.html @@ -24,7 +24,7 @@ hal9001 - 0.4.3 + 0.4.5 diff --git a/docs/reference/make_copy_map.html b/docs/reference/make_copy_map.html index 6a02ad48..77966b3e 100644 --- a/docs/reference/make_copy_map.html +++ b/docs/reference/make_copy_map.html @@ -23,7 +23,7 @@ hal9001 - 0.4.3 + 0.4.5 diff --git a/docs/reference/make_design_matrix.html b/docs/reference/make_design_matrix.html index 9ed0bebd..4a27c0ca 100644 --- a/docs/reference/make_design_matrix.html +++ b/docs/reference/make_design_matrix.html @@ -24,7 +24,7 @@ hal9001 - 0.4.3 + 0.4.5 diff --git a/docs/reference/make_reduced_basis_map.html b/docs/reference/make_reduced_basis_map.html index 10805fb4..a435cdf2 100644 --- a/docs/reference/make_reduced_basis_map.html +++ b/docs/reference/make_reduced_basis_map.html @@ -25,7 +25,7 @@ hal9001 - 0.4.3 + 0.4.5 diff --git a/docs/reference/meets_basis.html b/docs/reference/meets_basis.html index 1eea6fdb..05c915df 100644 --- a/docs/reference/meets_basis.html +++ b/docs/reference/meets_basis.html @@ -24,7 +24,7 @@ hal9001 - 0.4.3 + 0.4.5 diff --git a/docs/reference/num_knots_generator.html b/docs/reference/num_knots_generator.html index 75235939..83df7597 100644 --- a/docs/reference/num_knots_generator.html +++ b/docs/reference/num_knots_generator.html @@ -26,7 +26,7 @@ hal9001 - 0.4.3 + 0.4.5 diff --git a/docs/reference/plus-.formula_hal9001.html b/docs/reference/plus-.formula_hal9001.html index 6559f4fe..de21f362 100644 --- a/docs/reference/plus-.formula_hal9001.html +++ b/docs/reference/plus-.formula_hal9001.html @@ -26,7 +26,7 @@ hal9001 - 0.4.3 + 0.4.5 diff --git a/docs/reference/predict.SL.hal9001.html b/docs/reference/predict.SL.hal9001.html index 9f2d224f..2d076d12 100644 --- a/docs/reference/predict.SL.hal9001.html +++ b/docs/reference/predict.SL.hal9001.html @@ -23,7 +23,7 @@ hal9001 - 0.4.3 + 0.4.5 diff --git a/docs/reference/predict.hal9001.html b/docs/reference/predict.hal9001.html index d94608e6..093aa978 100644 --- a/docs/reference/predict.hal9001.html +++ b/docs/reference/predict.hal9001.html @@ -23,7 +23,7 @@ hal9001 - 0.4.3 + 0.4.5 @@ -72,7 +72,6 @@

Prediction from HAL fits

  new_X_unpenalized = NULL,
  offset = NULL,
  type = c("response", "link"),
-  p_reserve = 0.75,
  ...
)
@@ -96,11 +95,6 @@

Arguments

type

Either "response" for predictions of the response, or "link" for un-transformed predictions (on the scale of the link function).

-p_reserve
-Sparse matrix pre-allocation proportion, which is the anticipated proportion
-of 1's in the design matrix. Default value is recommended in most settings.
-If a dense design matrix is expected, it would be useful to set p_reserve
-to a higher value.

...

Additional arguments passed to predict as necessary.

diff --git a/docs/reference/predict.lassi.html b/docs/reference/predict.lassi.html index c63ed016..2b3d71a4 100644 --- a/docs/reference/predict.lassi.html +++ b/docs/reference/predict.lassi.html @@ -23,7 +23,7 @@ hal9001 - 0.4.3 + 0.4.5 @@ -91,7 +91,7 @@

Arguments

-

Site built with pkgdown 2.0.2.

+

Site built with pkgdown 2.0.3.

diff --git a/docs/reference/print.formula_hal9001.html b/docs/reference/print.formula_hal9001.html index d4228a34..42cb6fe6 100644 --- a/docs/reference/print.formula_hal9001.html +++ b/docs/reference/print.formula_hal9001.html @@ -23,7 +23,7 @@ hal9001 - 0.4.3 + 0.4.5 diff --git a/docs/reference/print.summary.hal9001.html b/docs/reference/print.summary.hal9001.html index 6ff49f7e..71c0d652 100644 --- a/docs/reference/print.summary.hal9001.html +++ b/docs/reference/print.summary.hal9001.html @@ -23,7 +23,7 @@ hal9001 - 0.4.3 + 0.4.5 diff --git a/docs/reference/quantizer.html b/docs/reference/quantizer.html index 7a710d60..af453232 100644 --- a/docs/reference/quantizer.html +++ b/docs/reference/quantizer.html @@ -23,7 +23,7 @@ hal9001 - 0.4.3 + 0.4.5 diff --git a/docs/reference/squash_hal_fit.html b/docs/reference/squash_hal_fit.html index e6c72b5c..0ff4cac9 100644 --- a/docs/reference/squash_hal_fit.html +++ b/docs/reference/squash_hal_fit.html @@ -23,7 +23,7 @@ hal9001 - 0.4.3 + 0.4.5 diff --git a/docs/reference/summary.hal9001.html b/docs/reference/summary.hal9001.html index 795a34c1..b4ec1b16 100644 --- a/docs/reference/summary.hal9001.html +++ b/docs/reference/summary.hal9001.html @@ -23,7 +23,7 @@ hal9001 - 0.4.3 + 0.4.5 diff --git a/man/SL.hal9001.Rd b/man/SL.hal9001.Rd index 974d3fb2..9cff7396 100644 --- a/man/SL.hal9001.Rd +++ b/man/SL.hal9001.Rd @@ -7,15 +7,13 @@ SL.hal9001( Y, X, - newX = NULL, - family = stats::gaussian(), - obsWeights = rep(1, length(Y)), - id = NULL, - max_degree = ifelse(ncol(X) >= 20, 2, 3), + newX, + family, + obsWeights, + id, + max_degree = 2, smoothness_orders = 1, - num_knots = ifelse(smoothness_orders >= 1, 25, 50), - reduce_basis = 1/sqrt(length(Y)), - lambda = NULL, + num_knots = 5, ... ) } @@ -49,19 +47,7 @@ specifying the maximum number of knot points (i.e., bins) for each covariate for generating basis functions. See \code{num_knots} argument in \code{\link{fit_hal}} for more information.} -\item{reduce_basis}{A \code{numeric} value bounded in the open unit interval -indicating the minimum proportion of 1's in a basis function column needed -for the basis function to be included in the procedure to fit the lasso. -Any basis functions with a lower proportion of 1's than the cutoff will be -removed.} - -\item{lambda}{A user-specified sequence of values of the regularization -parameter for the lasso L1 regression. If \code{NULL}, the default sequence -in \code{\link[glmnet]{cv.glmnet}} will be used. 
The cross-validated -optimal value of this regularization parameter will be selected with -\code{\link[glmnet]{cv.glmnet}}.} - -\item{...}{Not used.} +\item{...}{Additional arguments to \code{\link{fit_hal}}.} } \value{ An object of class \code{SL.hal9001} with a fitted \code{hal9001} diff --git a/man/cv_lasso.Rd b/man/cv_lasso.Rd deleted file mode 100644 index 26aa3dff..00000000 --- a/man/cv_lasso.Rd +++ /dev/null @@ -1,29 +0,0 @@ -% Generated by roxygen2: do not edit by hand -% Please edit documentation in R/cv_lasso.R -\name{cv_lasso} -\alias{cv_lasso} -\title{Cross-validated Lasso on Indicator Bases} -\usage{ -cv_lasso(x_basis, y, n_lambda = 100, n_folds = 10, center = FALSE) -} -\arguments{ -\item{x_basis}{A \code{dgCMatrix} object corresponding to a sparse matrix of -the basis functions generated for the HAL algorithm.} - -\item{y}{A \code{numeric} vector of the observed outcome variable values.} - -\item{n_lambda}{A \code{numeric} scalar indicating the number of values of -the L1 regularization parameter (lambda) to be obtained from fitting the -Lasso to the full data. Cross-validation is used to select an optimal -lambda (that minimizes the risk) from among these.} - -\item{n_folds}{A \code{numeric} scalar for the number of folds to be used in -the cross-validation procedure to select an optimal value of lambda.} - -\item{center}{binary. If \code{TRUE}, covariates are centered. This is much -slower, but matches the \code{glmnet} implementation. Default \code{FALSE}.} -} -\description{ -Fits Lasso regression using a customized procedure, with cross-validation -based on \pkg{origami} -} diff --git a/man/fit_hal.Rd b/man/fit_hal.Rd index f2ac0a21..8f67fdf1 100644 --- a/man/fit_hal.Rd +++ b/man/fit_hal.Rd @@ -13,13 +13,14 @@ fit_hal( smoothness_orders = 1, num_knots = num_knots_generator(max_degree = max_degree, smoothness_orders = smoothness_orders, base_num_knots_0 = 200, base_num_knots_1 = 50), - reduce_basis = 1/sqrt(length(Y)), - family = c("gaussian", "binomial", "poisson", "cox"), + reduce_basis = NULL, + family = c("gaussian", "binomial", "poisson", "cox", "mgaussian"), lambda = NULL, id = NULL, + weights = NULL, offset = NULL, - fit_control = list(cv_select = TRUE, n_folds = 10, foldid = NULL, use_min = TRUE, - lambda.min.ratio = 1e-04, prediction_bounds = "default"), + fit_control = list(cv_select = TRUE, use_min = TRUE, lambda.min.ratio = 1e-04, + prediction_bounds = "default"), basis_list = NULL, return_lasso = TRUE, return_x_basis = FALSE, @@ -31,7 +32,9 @@ fit_hal( number of covariates that will be used to derive the design matrix of basis functions.} -\item{Y}{A \code{numeric} vector of observations of the outcome variable.} +\item{Y}{A \code{numeric} vector of observations of the outcome variable. For +\code{family="mgaussian"}, \code{Y} is a matrix of observations of the +outcome variables.} \item{formula}{A character string formula to be used in \code{\link{formula_hal}}. See its documentation for details.} @@ -59,19 +62,21 @@ combinatorial explosions in the number of higher-degree and higher-order basis functions generated. This allows the complexity of the optimization problem to grow scalably. See details of \code{num_knots} more information.} -\item{reduce_basis}{A \code{numeric} value bounded in the open unit interval -indicating the minimum proportion of 1's in a basis function column needed -for the basis function to be included in the procedure to fit the lasso. -Any basis functions with a lower proportion of 1's than the cutoff will be -removed. 
When \code{reduce_basis} is set to \code{NULL}, all basis -functions are used in the lasso-fitting stage of \code{fit_hal}.} +\item{reduce_basis}{Am optional \code{numeric} value bounded in the open +unit interval indicating the minimum proportion of 1's in a basis function +column needed for the basis function to be included in the procedure to fit +the lasso. Any basis functions with a lower proportion of 1's than the +cutoff will be removed. Defaults to 1 over the square root of the number of +observations. Only applicable for models fit with zero-order splines, i.e. +\code{smoothness_orders = 0}.} \item{family}{A \code{character} or a \code{\link[stats]{family}} object (supported by \code{\link[glmnet]{glmnet}}) specifying the error/link family for a generalized linear model. \code{character} options are limited to "gaussian" for fitting a standard penalized linear model, "binomial" for penalized logistic regression, "poisson" for penalized Poisson regression, -and "cox" for a penalized proportional hazards model. Note that passing in +"cox" for a penalized proportional hazards model, and "mgaussian" for +multivariate penalized linear model. Note that passing in family objects leads to slower performance relative to passing in a character family (if supported). For example, one should set \code{family = "binomial"} instead of \code{family = binomial()} when @@ -90,11 +95,13 @@ lambda in the input array will be returned.} folds for \code{\link[glmnet]{cv.glmnet}}. This argument is ignored when \code{fit_control}'s \code{cv_select} argument is \code{FALSE}.} +\item{weights}{observation weights; defaults to 1 per observation.} + \item{offset}{a vector of offset values, used in fitting.} -\item{fit_control}{List of arguments for fitting. Includes the following -arguments, and any others to be passed to \code{\link[glmnet]{cv.glmnet}} -or \code{\link[glmnet]{glmnet}}. +\item{fit_control}{List of arguments, including the following, and any +others to be passed to \code{\link[glmnet]{cv.glmnet}} or +\code{\link[glmnet]{glmnet}}. \itemize{ \item \code{cv_select}: A \code{logical} specifying if the sequence of specified \code{lambda} values should be passed to @@ -102,28 +109,23 @@ specified \code{lambda} values should be passed to \code{lambda} to be selected according to cross-validation. When \code{cv_select = FALSE}, a \code{\link[glmnet]{glmnet}} model will be used to fit the sequence of (or single) \code{lambda}. -\item \code{n_folds}: Integer for the number of folds to be used when splitting -the data for V-fold cross-validation. Only used when -\code{cv_select = TRUE}. -\item \code{foldid}: An optional \code{numeric} containing values between 1 and -\code{n_folds}, identifying the fold to which each observation is -assigned. If supplied, \code{n_folds} can be missing. In such a case, -this vector is passed directly to \code{\link[glmnet]{cv.glmnet}}. Only -used when \code{cv_select = TRUE}. \item \code{use_min}: Specify the choice of lambda to be selected by \code{\link[glmnet]{cv.glmnet}}. When \code{TRUE}, \code{"lambda.min"} is used; otherwise, \code{"lambda.1se"}. Only used when \code{cv_select = TRUE}. -\item \code{lambda.min.ratio}: A \code{\link[glmnet]{glmnet}} argument specifying -the smallest value for \code{lambda}, as a fraction of \code{lambda.max}, -the (data derived) entry value (i.e. the smallest value for which all -coefficients are zero). We've seen that not setting \code{lambda.min.ratio} -can lead to no \code{lambda} values that fit the data sufficiently well. 
-\item \code{prediction_bounds}: A vector of size two that provides the lower and -upper bounds for predictions. When \code{prediction_bounds = "default"}, -the predictions are bounded between \code{min(Y) - sd(Y)} and -\code{max(Y) + sd(Y)}. Bounding ensures that there is no extrapolation, -and it is necessary for cross-validation selection and/or Super Learning. +\item \code{lambda.min.ratio}: A \code{\link[glmnet]{glmnet}} argument +specifying the smallest value for \code{lambda}, as a fraction of +\code{lambda.max}, the (data derived) entry value (i.e. the smallest value +for which all coefficients are zero). We've seen that not setting +\code{lambda.min.ratio} can lead to no \code{lambda} values that fit the +data sufficiently well. +\item \code{prediction_bounds}: An optional vector of size two that provides +the lower and upper bounds predictions; not used when +\code{family = "cox"}. When \code{prediction_bounds = "default"}, the +predictions are bounded between \code{min(Y) - sd(Y)} and +\code{max(Y) + sd(Y)} for each outcome (when \code{family = "mgaussian"}, +each outcome can have different bounds). Bounding ensures that there is +no extrapolation. }} \item{basis_list}{The full set of basis functions generated from \code{X}.} diff --git a/man/formula_helpers.Rd b/man/formula_helpers.Rd deleted file mode 100644 index 893a2de1..00000000 --- a/man/formula_helpers.Rd +++ /dev/null @@ -1,20 +0,0 @@ -% Generated by roxygen2: do not edit by hand -% Please edit documentation in R/formula_hal9001.R -\name{formula_helpers} -\alias{formula_helpers} -\alias{fill_dots_helper} -\alias{fill_dots} -\title{Formula Helpers} -\usage{ -fill_dots_helper(var_names, .) - -fill_dots(var_names, .) -} -\arguments{ -\item{var_names}{A \code{character} vector of variable names.} - -\item{.}{Specification of variables for use in the formula.} -} -\description{ -Formula Helpers -} diff --git a/man/generate_all_rules.Rd b/man/generate_all_rules.Rd new file mode 100644 index 00000000..6fc2967a --- /dev/null +++ b/man/generate_all_rules.Rd @@ -0,0 +1,14 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/summary.R +\name{generate_all_rules} +\alias{generate_all_rules} +\title{Generates rules based on knot points of the fitted HAL basis functions with +non-zero coefficients.} +\usage{ +generate_all_rules(basis_list, coefs, X_colnames) +} +\description{ +Generates rules based on knot points of the fitted HAL basis functions with +non-zero coefficients. +} +\keyword{internal} diff --git a/man/predict.hal9001.Rd b/man/predict.hal9001.Rd index f2e174c8..43657035 100644 --- a/man/predict.hal9001.Rd +++ b/man/predict.hal9001.Rd @@ -10,7 +10,6 @@ new_X_unpenalized = NULL, offset = NULL, type = c("response", "link"), - p_reserve = 0.75, ... ) } @@ -32,11 +31,6 @@ observations as \code{new_data}.} \item{type}{Either "response" for predictions of the response, or "link" for un-transformed predictions (on the scale of the link function).} -\item{p_reserve}{Sparse matrix pre-allocation proportion, which is the -anticipated proportion of 1's in the design matrix. Default value is -recommended in most settings. 
If a dense design matrix is expected, it -would be useful to set \code{p_reserve} to a higher value.} - \item{...}{Additional arguments passed to \code{predict} as necessary.} } \value{ diff --git a/tests/testthat/test-formula.R b/tests/testthat/test-formula.R index aeff975d..ba901a94 100644 --- a/tests/testthat/test-formula.R +++ b/tests/testthat/test-formula.R @@ -1,4 +1,3 @@ - context("check formula function") @@ -7,11 +6,11 @@ p <- 3 X <- xmat <- matrix(rnorm(n * p), n, p) colnames(X) <- c("X1", "X2", "X3") - + test_that("Check formula", { smoothness_orders <- 1 -num_knots <- 3 + num_knots <- 3 expect_true(length(h(X1)$basis_list) == num_knots) expect_true(h(X1)$basis_list[[1]]$orders == 1) expect_true(all(h(X1)$penalty.factors == 1)) @@ -29,9 +28,9 @@ num_knots <- 3 expect_true(length(setdiff(formula_hal(formula)$basis_list, (h(X1) + h(X2))$basis_list)) == 0) expect_true(length(formula_hal(formula, num_knots = 3)$basis_list) == length(formula_hal(formula)$basis_list)) expect_true(length(formula_hal(formula, num_knots = 10)$basis_list) != length(formula_hal(formula)$basis_list)) - formula <- h(., k =2)$basis_list + formula <- h(., k = 2)$basis_list expect_true(length(formula[[1]]$cols) == 1) - formula <- h(.,., k =2)$basis_list + formula <- h(., ., k = 2)$basis_list expect_true(length(formula[[1]]$cols) == 2) }) diff --git a/tests/testthat/test-general_families.R b/tests/testthat/test-general_families.R index 712d3239..1f4a2efc 100644 --- a/tests/testthat/test-general_families.R +++ b/tests/testthat/test-general_families.R @@ -17,9 +17,11 @@ test_n <- 100 test_x <- matrix(rnorm(test_n * p), test_n, p) test_y_prob <- plogis(3 * sin(test_x[, 1]) + sin(test_x[, 2])) test_y <- rbinom(n = test_n, size = 1, prob = y_prob) - +fit_control <- list(prediction_bounds = c(0.01, 0.99)) # ml implementation -ml_hal_fit <- suppressWarnings(fit_hal(X = x, Y = y, family = "binomial")) +ml_hal_fit <- suppressWarnings( + fit_hal(X = x, Y = y, family = "binomial", fit_control = fit_control) +) ml_hal_fit$times x_basis <- make_design_matrix(x, ml_hal_fit$basis_list) @@ -28,7 +30,9 @@ preds <- predict(ml_hal_fit, new_data = x) ml_hal_mse1 <- mse(preds, y_prob) set.seed(45791) -ml_hal_fit <- suppressWarnings(fit_hal(X = x, Y = y, family = binomial())) +ml_hal_fit <- suppressWarnings( + fit_hal(X = x, Y = y, family = binomial(), fit_control = fit_control) +) ml_hal_fit$times x_basis <- make_design_matrix(x, ml_hal_fit$basis_list) @@ -61,3 +65,17 @@ ml_hal_mse2 <- mse(preds, y_prob) test_that("MSE for logistic regression close to logistic family object pred", { expect_true(abs(ml_hal_mse1 - ml_hal_mse2) < 0.01) }) + +test_that("Error when prediction_bounds is incorrectly formatted", { + fit_control <- list(prediction_bounds = 9) + expect_error(fit_hal(X = x, Y = y, fit_control = fit_control)) +}) + +test_that("Message when standardize set to TRUE", { + fit_control <- list(standardize = TRUE) + expect_message(fit_hal(X = x, Y = y, fit_control = fit_control)) +}) + +test_that("Warning when reduce_basis without zero-order smoothness", { + expect_warning(fit_hal(X = x, Y = y, reduce_basis = 0.95)) +}) diff --git a/tests/testthat/test-hal_binomial.R b/tests/testthat/test-hal_binomial.R index 3636066e..cd5d41b2 100644 --- a/tests/testthat/test-hal_binomial.R +++ b/tests/testthat/test-hal_binomial.R @@ -39,3 +39,27 @@ oob_ml_hal_mse <- mse(oob_preds, y = test_y_prob) test_that("MSE for logistic regression on test set is less than for nulll", { expect_lt(oob_ml_hal_mse, mse(rep(mean(y), test_n), test_y_prob)) }) + 
+test_that("Prediction bounds respected when numeric vector supplied", { + ml_hal_fit <- fit_hal( + X = x, Y = y, family = "binomial", + fit_control = list(prediction_bounds = c(0.4, 0.7)) + ) + preds <- predict(ml_hal_fit, new_data = x) + expect_true(min(preds) >= 0.4) + expect_true(max(preds) <= 0.7) +}) + +test_that("Check of prediction_bounds formatting errors", { + kitty_fit_control <- list(prediction_bounds = 9) + expect_error( + fit_hal(X = x, Y = y, family = "binomial", fit_control = kitty_fit_control) + ) +}) + +test_that("Check of fit_control formatting errors", { + kitty_fit_control <- list("kitty" = TRUE) + expect_warning( + fit_hal(X = x, Y = y, family = "binomial", fit_control = kitty_fit_control) + ) +}) diff --git a/tests/testthat/test-hal_hazards.R b/tests/testthat/test-hal_hazards.R index 41d1e154..3933b099 100644 --- a/tests/testthat/test-hal_hazards.R +++ b/tests/testthat/test-hal_hazards.R @@ -5,7 +5,7 @@ library(glmnet) library(survival) # create survival data structures -data(kidney) +data(cancer, package = "survival") y_surv <- Surv(kidney$time, kidney$status) x_surv <- kidney[, c("age", "sex", "disease", "frail")] x_surv$disease <- as.numeric(x_surv$disease) diff --git a/tests/testthat/test-hal_multivariate.R b/tests/testthat/test-hal_multivariate.R new file mode 100644 index 00000000..2593466b --- /dev/null +++ b/tests/testthat/test-hal_multivariate.R @@ -0,0 +1,55 @@ +context("Multivariate outcome prediction with HAL") + +library(glmnet) +data(MultiGaussianExample) + +# get hal fit +set.seed(74296) +hal_fit <- fit_hal( + X = MultiGaussianExample$x, Y = MultiGaussianExample$y, family = "mgaussian", + return_x_basis = TRUE +) +hal_summary <- summary(hal_fit) + +test_that("HAL and glmnet predictions match for multivariate outcome", { + # get hal preds + hal_pred <- predict(hal_fit, new_data = MultiGaussianExample$x) + # get glmnet preds + set.seed(74296) + glmnet_fit <- cv.glmnet( + x = hal_fit$x_basis, y = MultiGaussianExample$y, + family = "mgaussian", standardize = FALSE, + lambda.min.ratio = 1e-4 + ) + glmnet_pred <- predict(glmnet_fit, hal_fit$x_basis, s = hal_fit$lambda_star)[, , 1] + # test equivalence + colnames(glmnet_pred) <- colnames(hal_pred) + expect_equivalent(glmnet_pred, hal_pred) +}) + +test_that("HAL summarizes coefs for each multivariate outcome prediction", { + expect_equal(ncol(MultiGaussianExample$y), length(hal_summary)) +}) + +test_that("HAL summarizes coefs appropriately for multivariate outcome", { + # this checks intercept matches + lapply(seq_along(hal_summary), function(i) { + expect_equal(hal_fit$coefs[[i]][1, ], as.numeric(hal_summary[[i]]$table[1, 1])) + }) +}) + +test_that("Error when prediction_bounds is incorrectly formatted", { + fit_control <- list(prediction_bounds = 9) + expect_error(fit_hal( + X = MultiGaussianExample$x, Y = MultiGaussianExample$y, + family = "mgaussian", fit_control = fit_control + )) +}) + +test_that("HAL summary for multivariate outcome predictions prints", { + hal_summary2 <- summary(hal_fit, only_nonzero_coefs = FALSE) + expect_output(print(hal_summary, length = 2)) + expect_output(print(hal_summary)) + expect_output(print(hal_summary2, length = 2)) + expect_output(print(hal_summary2)) +}) diff --git a/tests/testthat/test-reduce_basis_filter.R b/tests/testthat/test-reduce_basis_filter.R index 71c5befc..07c5ed53 100644 --- a/tests/testthat/test-reduce_basis_filter.R +++ b/tests/testthat/test-reduce_basis_filter.R @@ -50,7 +50,7 @@ system.time({ mse <- mean(se) se[c(current_i, new_i)] <- 0 new_i <- 
which.max(se)
-        print(sprintf("%f, %f", old_mse, mse))
+        # print(sprintf("%f, %f", old_mse, mse))
        continue <- mse < 1.1 * old_mse
        if (mse < old_mse) {
          good_i <- unique(c(good_i, new_i))
@@ -58,7 +58,7 @@ system.time({
      old_mse <- mse
      coefs <- as.vector(coef(screen_glmnet, s = "lambda.min"))[-1]
      # old_basis <- union(old_basis,c(old_basis,b1)[which(coefs!=0)])
-      print(length(old_basis))
+      # print(length(old_basis))
      old_basis <- unique(c(old_basis, b1))
    }

@@ -72,7 +72,7 @@ system.time({
    if (is.na(rate)) {
      rate <- -Inf
    }
-    print(rate)
+    # print(rate)
    continue <- (-1 * rate) > 1e-4
    continue <- TRUE
    continue <- length(current_i) < n
@@ -108,13 +108,13 @@ b1 <- coef(fit)

fit <- glmnet(
  x = x_basis, y = y, family = "gaussian", offset = offset,
-  intercept = FALSE, maxit = 1, thresh = 1, lambda = 0.03
+  intercept = FALSE, maxit = 2, thresh = 1, lambda = 0.03
)
b2 <- coef(fit)

fit <- glmnet(
  x = x_basis, y = y, family = "gaussian", offset = offset,
-  intercept = FALSE, maxit = 1, thresh = 1, lambda = 0.03
+  intercept = FALSE, maxit = 2, thresh = 1, lambda = 0.03
)
b3 <- coef(fit)
diff --git a/tests/testthat/test-sl_ecpolley.R b/tests/testthat/test-sl_ecpolley.R
index 4ad422ea..c0d2a42e 100644
--- a/tests/testthat/test-sl_ecpolley.R
+++ b/tests/testthat/test-sl_ecpolley.R
@@ -1,4 +1,4 @@
-context("Fits and prediction of classic Super Learner with HAL.")
+context("Fits and predictions with the SuperLearner package.")
 library(SuperLearner)

 # easily compute MSE
@@ -25,7 +25,12 @@ pred_hal_test <- predict(hal, new_data = test_x)

 # run SL-classic with glmnet and get predictions
 hal_sl <- SuperLearner(Y = y, X = x, SL.lib = "SL.hal9001")
-sl_hal_fit <- SL.hal9001(Y = y, X = x)
+sl_hal_fit <- SL.hal9001(
+  Y = y, X = x, newX = NULL,
+  family = stats::gaussian(),
+  obsWeights = rep(1, length(y)),
+  id = seq_along(y)
+)
 # hal9001:::predict.SL.hal9001(sl_hal_fit$fit,newX=x,newdata=x)
 pred_hal_sl_train <- as.numeric(predict(hal_sl, newX = x)$pred)
 pred_hal_sl_test <- as.numeric(predict(hal_sl, newX = test_x)$pred)
@@ -35,8 +40,6 @@ sl <- SuperLearner(
   Y = y, X = x, SL.lib = c("SL.mean", "SL.hal9001"),
   cvControl = list(validRows = hal_sl$validRows)
 )
-pred_sl_train <- as.numeric(predict(sl, newX = x)$pred)
-pred_sl_test <- as.numeric(predict(sl, newX = test_x)$pred)

 # test for HAL vs. SL-HAL: outputs are the same length
 test_that("HAL and SuperLearner-HAL produce results of same shape", {
@@ -47,6 +50,7 @@ test_that("HAL and SuperLearner-HAL produce results of same shape", {
 # test of MSEs being close: SL-HAL and SL dominated by HAL should be very close
 # (hence the rather low tolerance, esp. given an additive scale)
 test_that("HAL dominates other algorithms when used in SuperLearner", {
+  pred_sl_test <- as.numeric(predict(sl, newX = test_x)$pred)
   expect_equal(
     mse(pred_sl_test, test_y),
     expected = mse(pred_hal_sl_test, test_y),
diff --git a/vignettes/intro_hal9001.Rmd b/vignettes/intro_hal9001.Rmd
index a7a1bc7a..46b2c3ee 100644
--- a/vignettes/intro_hal9001.Rmd
+++ b/vignettes/intro_hal9001.Rmd
@@ -94,7 +94,7 @@ LaTeX with the `xtable` R package. 
Here's an example:
```{r eval-mse}
# training sample prediction for HAL vs HAL9000
mse <- function(preds, y) {
-    mean((preds - y)^2)
+  mean((preds - y)^2)
}

preds_hal <- predict(object = hal_fit, new_data = x)
@@ -188,7 +188,7 @@ hal_fit_smooth_1 <- fit_hal(
)

hal_fit_smooth_2_all <- fit_hal(
-  X = x, Y = y, smoothness_orders = 2, num_knots = num_knots, 
+  X = x, Y = y, smoothness_orders = 2, num_knots = num_knots,
  fit_control = list(cv_select = FALSE)
)

@@ -204,7 +204,8 @@ pred_smooth_2_all <- predict(hal_fit_smooth_2_all, new_data = x)
dt <- data.table(x = as.vector(x))
dt <- cbind(dt, pred_smooth_2_all)
long <- melt(dt, id = "x")
-ggplot(long, aes(x = x, y = value, group = variable)) + geom_line()
+ggplot(long, aes(x = x, y = value, group = variable)) +
+  geom_line()
```

Comparing the mean squared error (MSE) between the predictions and the true
@@ -218,17 +219,20 @@ to `num_knots` for a big decrease in runtime without sacrificing performance.

```{r}
mean((pred_0 - ytrue)^2)
-mean((pred_smooth_1- ytrue)^2)
+mean((pred_smooth_1 - ytrue)^2)
mean((pred_smooth_2 - ytrue)^2)

-dt <- data.table(x = as.vector(x),
-                 ytrue = ytrue,
-                 y = y,
-                 pred0 = pred_0,
-                 pred1 = pred_smooth_1,
-                 pred2 = pred_smooth_2)
+dt <- data.table(
+  x = as.vector(x),
+  ytrue = ytrue,
+  y = y,
+  pred0 = pred_0,
+  pred1 = pred_smooth_1,
+  pred2 = pred_smooth_2
+)
long <- melt(dt, id = "x")
-ggplot(long, aes(x = x, y = value, color = variable)) + geom_line()
+ggplot(long, aes(x = x, y = value, color = variable)) +
+  geom_line()
plot(x, pred_0, main = "Zero order smoothness fit")
plot(x, pred_smooth_1, main = "First order smoothness fit")
plot(x, pred_smooth_2, main = "Second order smoothness fit")
@@ -270,7 +274,7 @@ pred_smooth_2 <- predict(hal_fit_smooth_2, new_data = x)

```{r}
mean((pred_0 - ytrue)^2)
-mean((pred_smooth_1- ytrue)^2)
+mean((pred_smooth_1 - ytrue)^2)
mean((pred_smooth_2 - ytrue)^2)
plot(x, pred_0, main = "Zero order smoothness fit")
plot(x, pred_smooth_1, main = "First order smoothness fit")
@@ -295,7 +299,7 @@ set.seed(98109)
num_knots <- 100

n_obs <- 500
-x1 <- runif(n_obs, min = -4, max = 4) 
+x1 <- runif(n_obs, min = -4, max = 4)
x2 <- runif(n_obs, min = -4, max = 4)
A <- runif(n_obs, min = -4, max = 4)
X <- data.frame(x1 = x1, x2 = x2, A = A)
@@ -314,8 +318,8 @@ and without "y" in the character string.
# The `h` function is used to specify the basis functions for a given term
# h(x1) generates one-way basis functions for the variable x1
# This is an additive model:
-formula <- ~h(x1) + h(x2) + h(A)
-#We can actually evaluate the h function as well. We need to specify some tuning parameters in the current environment:
+formula <- ~ h(x1) + h(x2) + h(A)
+# We can actually evaluate the h function as well. We need to specify some tuning parameters in the current environment:
smoothness_orders <- 0
num_knots <- 10
# It will look in the parent environment for `X` and the above tuning parameters
@@ -324,25 +328,23 @@ form_term$basis_list[[1]]
# We don't need the variables in the parent environment if we specify them directly:
rm(smoothness_orders)
rm(num_knots)
-# `h` excepts the arguments `s` and `k`. `s` stands for smoothness and is equivalent to smoothness_orders in use. `k` specifies the number of knots. `
+# `h` accepts the arguments `s` and `k`. `s` stands for smoothness and is equivalent to smoothness_orders in use. `k` specifies the number of knots.
form_term_new <- h(x1, s = 0, k = 10) + h(x2, s = 0, k = 10) + h(A, s = 0, k = 10)
# They are the same!
-length( form_term_new$basis_list) == length(form_term$basis_list)
+length(form_term_new$basis_list) == length(form_term$basis_list)
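Tying the pieces above together: a short, hedged sketch that assumes the vignette's simulated `X` (columns `x1`, `x2`, `A`) and invents a placeholder outcome `y`; it relies on `fit_hal`'s `basis_list` argument to reuse the basis functions carried by an evaluated `h` term.

```r
# Placeholder outcome, assumed here only so the sketch can fit a model
y <- rnorm(n_obs)
# An evaluated term carries its basis functions in $basis_list ...
form_term <- h(x1, s = 0, k = 10) + h(x2, s = 0, k = 10) + h(A, s = 0, k = 10)
# ... which fit_hal can consume directly, skipping basis generation
fit <- fit_hal(X = X, Y = y, basis_list = form_term$basis_list)
preds <- predict(fit, new_data = X)
```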

-#To evaluate a unevaluated formula object like:
-formula <- ~h(x1) + h(x2) + h(A)
+# To evaluate an unevaluated formula object like:
+formula <- ~ h(x1) + h(x2) + h(A)
# we can use the formula_hal function:
formula <- formula_hal(
-  ~ h(x1) + h(x2) + h(A), X = X, smoothness_orders = 1, num_knots = 10
+  ~ h(x1) + h(x2) + h(A),
+  X = X, smoothness_orders = 1, num_knots = 10
)
# Note that the arguments smoothness_orders and/or num_knots will not be used if `s` and/or `k` are specified in `h`.
formula <- formula_hal(
-  Y ~ h(x1, k=1) + h(x2, k=1) + h(A, k=1), X = X, smoothness_orders = 1, num_knots = 10
+  Y ~ h(x1, k = 1) + h(x2, k = 1) + h(A, k = 1),
+  X = X, smoothness_orders = 1, num_knots = 10
)
-
-
-
-
```

The `.` argument. We can generate an additive model for all or a subset of variables using the `.` variable and `.` argument of `h`. By default, `.` in `h(.)` is treated as a wildcard and basis functions are generated by replacing the `.` with all variables in `X`.

@@ -356,33 +358,30 @@ formula1 <- h(.)
# Longcut:
formula2 <- h(x1) + h(x2) + h(A)
# Same number of basis functions
-length(formula1$basis_list ) == length(formula2$basis_list)
+length(formula1$basis_list) == length(formula2$basis_list)

# Maybe we only want an additive model for x1 and x2
# Use the `.` argument
formula1 <- h(., . = c("x1", "x2"))
-formula2 <- h(x1) + h(x2) 
-length(formula1$basis_list ) == length(formula2$basis_list)
-
+formula2 <- h(x1) + h(x2)
+length(formula1$basis_list) == length(formula2$basis_list)
```

We can specify interactions as follows.
```{r}
# Two way interactions
-formula1 <- h(x1) + h(x2) + h(A) + h(x1, x2) 
-formula2 <- h(.) + h(x1, x2) 
-length(formula1$basis_list ) == length(formula2$basis_list)
+formula1 <- h(x1) + h(x2) + h(A) + h(x1, x2)
+formula2 <- h(.) + h(x1, x2)
+length(formula1$basis_list) == length(formula2$basis_list)
#
-formula1 <- h(.) + h(x1, x2) + h(x1,A) + h(x2,A)
-formula2 <- h(.) + h(., .) 
-length(formula1$basis_list ) == length(formula2$basis_list)
+formula1 <- h(.) + h(x1, x2) + h(x1, A) + h(x2, A)
+formula2 <- h(.) + h(., .)
+length(formula1$basis_list) == length(formula2$basis_list)

# Three way interactions
-formula1 <- h(.) + h(.,.) + h(x1,A,x2)
-formula2 <- h(.) + h(., .)+ h(.,.,.)
-length(formula1$basis_list ) == length(formula2$basis_list)
-
-
+formula1 <- h(.) + h(., .) + h(x1, A, x2)
+formula2 <- h(.) + h(., .) + h(., ., .)
+length(formula1$basis_list) == length(formula2$basis_list)
```

Sometimes, one might want to build an additive model, but include all two-way
@@ -391,23 +390,23 @@ variety of ways. The `.` argument allows you to specify a subset of variables.

```{r}
# Write it all out
-formula <- h(x1) + h(x2) + h(A) + h(A, x1) + h(A,x2)
-
+formula <- h(x1) + h(x2) + h(A) + h(A, x1) + h(A, x2)
+
# Use the "h(.)" which stands for add all additive terms and then manually add
# interactions
-formula <- y ~ h(.) + h(A,x1) + h(A,x2)
+formula <- y ~ h(.) + h(A, x1) + h(A, x2)
+
-
# Use the "wildcard" feature for when "." is included in the "h()" term. This
# useful when you have many variables and do not want to write out every term.
-formula <- h(.) + h(A,.)
+formula <- h(.) + h(A, .)

-formula1 <- h(A,x1)
-formula2 <- h(A,., . = c("x1"))
- length(formula1$basis_list) == length(formula2$basis_list)
+formula1 <- h(A, x1)
+formula2 <- h(A, ., . 
= c("x1")) +length(formula1$basis_list) == length(formula2$basis_list) ``` @@ -419,29 +418,29 @@ these constraints is achieved by specifying the `monotone` argument of `h`. Note ```{r} # An additive monotone increasing model formula <- formula_hal( - y ~ h(., monotone = "i"), X, smoothness_orders = 0, num_knots = 100 + y ~ h(., monotone = "i"), X, + smoothness_orders = 0, num_knots = 100 ) # An additive unpenalized monotone increasing model (NPMLE isotonic regressio) # Set the penalty factor argument `pf` to remove L1 penalization formula <- formula_hal( - y ~ h(., monotone = "i", pf = 0), X, smoothness_orders = 0, num_knots = 100 + y ~ h(., monotone = "i", pf = 0), X, + smoothness_orders = 0, num_knots = 100 ) # An additive unpenalized convex model (NPMLE convex regressio) # Set the penalty factor argument `pf` to remove L1 penalization # Note the second term is equivalent to adding unpenalized and unconstrained main-terms (e.g. main-term glm) formula <- formula_hal( - ~ h(., monotone = "i", pf = 0, k=200, s=1) + h(., monotone = "none", pf = 0, k=1, s=1), X) - + ~ h(., monotone = "i", pf = 0, k = 200, s = 1) + h(., monotone = "none", pf = 0, k = 1, s = 1), X +) + # A bi-additive monotone decreasing model formula <- formula_hal( - ~ h(., monotone = "d") + h(.,., monotone = "d"), X, smoothness_orders = 1, num_knots = 100 + ~ h(., monotone = "d") + h(., ., monotone = "d"), X, + smoothness_orders = 1, num_knots = 100 ) - - - - ``` @@ -456,7 +455,6 @@ formula <- h(., s = 1, k = 1, pf = 0) # intraction glm formula <- h(., ., s = 1, k = 1, pf = 0) + h(., s = 1, k = 1, pf = 0) # Running HAL with this formula will be equivalent to running glm with the formula Y ~ .^2 - ```
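To make that last equivalence concrete, a rough check under stated assumptions: simulated data, `fit_hal`'s formula interface as used earlier in this vignette, and agreement only up to optimizer tolerance (glmnet with a tiny `lambda` approximates, but does not exactly reproduce, OLS).

```r
# Hedged check of the glm(y ~ .^2) equivalence sketched above (illustrative)
set.seed(98109)
n_obs <- 500
X <- data.frame(
  x1 = runif(n_obs, min = -4, max = 4),
  x2 = runif(n_obs, min = -4, max = 4),
  A = runif(n_obs, min = -4, max = 4)
)
y <- 2 * X$x1 - X$x2 * X$A + rnorm(n_obs)

# Unpenalized (pf = 0) first-order, single-knot bases span the same space as
# linear main effects plus all two-way interactions
hal_glm <- fit_hal(
  X = X, Y = y,
  formula = ~ h(., s = 1, k = 1, pf = 0) + h(., ., s = 1, k = 1, pf = 0),
  fit_control = list(cv_select = FALSE), lambda = 1e-8
)
glm_fit <- glm(y ~ .^2, data = data.frame(X, y = y))
max(abs(predict(hal_glm, new_data = X) - predict(glm_fit))) # expected near 0
```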