diff --git a/DESCRIPTION b/DESCRIPTION
index a6621fb2..a57654b8 100644
--- a/DESCRIPTION
+++ b/DESCRIPTION
@@ -1,6 +1,6 @@
 Package: hal9001
 Title: The Scalable Highly Adaptive Lasso
-Version: 0.4.3
+Version: 0.4.6
 Authors@R: c(
     person("Jeremy", "Coyle", email = "jeremyrcoyle@gmail.com",
            role = c("aut", "cre"),
@@ -68,5 +68,5 @@ LinkingTo:
     Rcpp,
     RcppEigen
 VignetteBuilder: knitr
-RoxygenNote: 7.2.0
+RoxygenNote: 7.2.3
 Roxygen: list(markdown = TRUE)
diff --git a/NAMESPACE b/NAMESPACE
index 495ea7e5..3093718a 100644
--- a/NAMESPACE
+++ b/NAMESPACE
@@ -26,18 +26,15 @@ importFrom(data.table,setorder)
 importFrom(glmnet,cv.glmnet)
 importFrom(glmnet,glmnet)
 importFrom(methods,is)
-importFrom(origami,cross_validate)
 importFrom(origami,folds2foldvec)
 importFrom(origami,make_folds)
 importFrom(stats,aggregate)
 importFrom(stats,as.formula)
 importFrom(stats,coef)
-importFrom(stats,gaussian)
 importFrom(stats,median)
 importFrom(stats,plogis)
 importFrom(stats,predict)
 importFrom(stats,quantile)
-importFrom(stats,sd)
 importFrom(stringr,str_detect)
 importFrom(stringr,str_extract)
 importFrom(stringr,str_match)
diff --git a/NEWS.md b/NEWS.md
index 7ba8c7b0..c04ae0c2 100644
--- a/NEWS.md
+++ b/NEWS.md
@@ -1,3 +1,24 @@
+# hal9001 0.4.6
+* Fixed the predict method to address changes required by Matrix 1.6.2.
+
+# hal9001 0.4.5
+* Added multivariate outcome prediction.
+
+# hal9001 0.4.4
+* Fixed a bug with `prediction_bounds` (a `fit_hal` argument in the
+  `fit_control` list), which would error when it was specified as a numeric
+  vector. Also added a check to assert this argument is correctly specified,
+  and tests to ensure a numeric vector of bounds is provided.
+* Simplified the `fit_control` list arguments in `fit_hal`. Users can still
+  specify additional arguments to `cv.glmnet` and `glmnet` in this list.
+* Defined `weights` as a formal argument in `fit_hal`, as opposed to an
+  optional argument in `fit_control`, to facilitate specification and avoid
+  confusion. This also increases flexibility with the SuperLearner wrapper
+  `SL.hal9001`; `fit_control` can now be customized with `SL.hal9001`.
+
+# hal9001 0.4.3
+* Version bump for CRAN resubmission following archiving.
+
 # hal9001 0.4.2
 * Version bump for CRAN resubmission following archiving.
diff --git a/R/cv_lasso.R b/R/cv_lasso.R
deleted file mode 100644
index a7fece86..00000000
--- a/R/cv_lasso.R
+++ /dev/null
@@ -1,72 +0,0 @@
-#' Cross-validated Lasso on Indicator Bases
-#'
-#' Fits Lasso regression using a customized procedure, with cross-validation
-#' based on \pkg{origami}
-#'
-#' @param x_basis A \code{dgCMatrix} object corresponding to a sparse matrix of
-#'  the basis functions generated for the HAL algorithm.
-#' @param y A \code{numeric} vector of the observed outcome variable values.
-#' @param n_lambda A \code{numeric} scalar indicating the number of values of
-#'  the L1 regularization parameter (lambda) to be obtained from fitting the
-#'  Lasso to the full data. Cross-validation is used to select an optimal
-#'  lambda (that minimizes the risk) from among these.
-#' @param n_folds A \code{numeric} scalar for the number of folds to be used in
-#'  the cross-validation procedure to select an optimal value of lambda.
-#' @param center binary. If \code{TRUE}, covariates are centered. This is much
-#'  slower, but matches the \code{glmnet} implementation. Default \code{FALSE}.
-#'
-#' @importFrom origami make_folds cross_validate
-#' @importFrom stats sd
-cv_lasso <- function(x_basis, y, n_lambda = 100, n_folds = 10,
-                     center = FALSE) {
-  # first, need to run lasso on the full data to get a sequence of lambdas
-  lasso_init <- lassi(y = y, x = x_basis, nlambda = n_lambda)
-  lambdas_init <- lasso_init$lambdas
-
-  # next, set up a cross-validated lasso using the sequence of lambdas
-  full_data_mat <- cbind(y, x_basis)
-  folds <- origami::make_folds(full_data_mat, V = n_folds)
-
-  # run the cross-validated lasso procedure to find the optimal lambda
-  cv_lasso_out <- origami::cross_validate(
-    cv_fun = lassi_origami,
-    folds = folds,
-    data = full_data_mat,
-    lambdas = lambdas_init,
-    center = center
-  )
-
-  # compute cv-mean of MSEs for each lambda
-  lambdas_cvmse <- colMeans(cv_lasso_out$mses)
-
-  # find the lambda that minimizes the MSE
-  lambda_optim_index <- which.min(lambdas_cvmse)
-  lambda_minmse <- lambdas_init[lambda_optim_index]
-
-  # also need the adjusted CV standard deviation for each lambda
-  lambdas_cvsd <- apply(cv_lasso_out$mses, 2, sd) / sqrt(n_folds)
-
-  # find the maximum lambda among those 1 standard error above the minimum
-  lambda_min_1se <- (lambdas_cvmse + lambdas_cvsd)[lambda_optim_index - 1]
-  lambda_1se <- max(lambdas_init[lambdas_cvmse <= lambda_min_1se],
-    na.rm = TRUE
-  )
-  lambda_1se_index <- which.min(abs(lambdas_init - lambda_1se))
-
-  # create output object
-  get_lambda_indices <- c(lambda_1se_index, lambda_optim_index)
-  betas_out <- lasso_init$beta_mat[, get_lambda_indices]
-  colnames(betas_out) <- c("lambda_1se", "lambda_min")
-
-  # add in intercept term to coefs matrix and convert to sparse matrix output
-  betas_out <- rbind(rep(mean(y), ncol(betas_out)), betas_out)
-  betas_out <- as_dgCMatrix(betas_out * 1.0)
-
-  # create output object
-  cv_lasso_out <- list(betas_out, lambda_minmse, lambda_1se, lambdas_cvmse)
-  names(cv_lasso_out) <- c(
-    "betas_mat", "lambda_min", "lambda_1se",
-    "lambdas_cvmse"
-  )
-  return(cv_lasso_out)
-}
diff --git a/R/formula_hal9001.R b/R/formula_hal9001.R
index 019260df..fe3b85a2 100644
--- a/R/formula_hal9001.R
+++ b/R/formula_hal9001.R
@@ -143,7 +143,7 @@ h <- function(..., k = NULL, s = NULL, pf = 1,
     var_names <- unlist(list(...))
   }
   formula_term <- paste0("h(", paste0(var_names, collapse = ", "), ")")
-  
+
   if (is.null(k)) {
     k <- get("num_knots", envir = parent.frame())
     k <- suppressWarnings(k + rep(0, length(var_names))) # recycle
@@ -152,8 +152,8 @@ h <- function(..., k = NULL, s = NULL, pf = 1,
   if (is.null(s)) {
     s <- get("smoothness_orders", envir = parent.frame())[1]
   }
-  
-  
+
+
   if ("." %in% var_names) {
     var_names_filled <- fill_dots(var_names, . = .)
@@ -186,7 +186,7 @@ h <- function(..., k = NULL, s = NULL, pf = 1,
       return(all_items)
     }
-    
+
     # Get corresponding column indices
@@ -240,6 +240,7 @@ h <- function(..., k = NULL, s = NULL, pf = 1,
   return(out)
 }
+
 #' Print formula_hal9001 object
 #'
 #' @param x A formula_hal9001 object.
@@ -250,16 +251,18 @@ print.formula_hal9001 <- function(x, ...) {
   cat(paste0("A hal9001 formula object of the form: ~ ", x$formula_term))
 }

-#' Formula Helpers
 #'
-#' @param var_names A \code{character} vector of variable names.
+#' @param var_names A \code{character} vector of variable names representing a
+#'  single type of interaction (e.g., var_names = c("W1", "W2", "W3") encodes
+#'  three-way interactions between W1, W2, and W3). var_names may include the
+#'  wildcard variable "." in which case the argument `.` must be specified so
+#'  that all interactions matching the form of var_names are generated.
 #' @param . Specification of variables for use in the formula.
-#'
-#' @name formula_helpers
-NULL
-
+#' This function takes a character vector `var_names` of the form c(name1,
+#' name2, ".", name3, ".") with any number of name{int} variables and any
+#' number of wildcard variables ".". It returns a list of character vectors of
+#' the form c(name1, name2, wildcard1, name3, wildcard2), where wildcard1 and
+#' wildcard2 are iterated over all possible character names given in the
+#' argument `.`.
 #' @rdname formula_helpers
-fill_dots_helper <- function(var_names, .) {
+fill_dots <- function(var_names, .) {
   index <- which(var_names == ".")
   if (length(index) == 0) {
     return(sort(var_names))
@@ -269,26 +272,23 @@ fill_dots_helper <- function(var_names, .) {
   all_items <- lapply(., function(var) {
     new_var_names <- var_names
     new_var_names[index] <- var
-    out <- fill_dots_helper(new_var_names, .)
-
-    if (is.list(out[[1]])) {
-      out <- unlist(out, recursive = FALSE)
-    }
+    out <- fill_dots(new_var_names, .)
     return(out)
   })
-
-
-  return(unique(all_items))
-}
-
-#' @rdname formula_helpers
-fill_dots <- function(var_names, .) {
-  x <- unique(unlist(fill_dots_helper(var_names, . = .), recursive = FALSE))
-  keep <- sapply(x, function(item) {
+  is_nested <- is.list(all_items[[1]])
+  while (is_nested) {
+    all_items <- unlist(all_items, recursive = FALSE)
+    is_nested <- is.list(all_items[[1]])
+  }
+  # Remove combinations of variable names that have duplicates.
+  # This removes generated interactions that include two of the same variable.
+  keep <- sapply(all_items, function(item) {
    if (any(duplicated(item))) {
      return(FALSE)
    }
    return(TRUE)
  })
-  return(x[keep])
+  all_items <- all_items[keep]
+
+  return(unique(all_items))
 }
diff --git a/R/hal.R b/R/hal.R
index e7895761..e73d6d90 100644
--- a/R/hal.R
+++ b/R/hal.R
@@ -59,7 +59,9 @@
 #' @param X An input \code{matrix} with dimensions number of observations -by-
 #'  number of covariates that will be used to derive the design matrix of basis
 #'  functions.
-#' @param Y A \code{numeric} vector of observations of the outcome variable.
+#' @param Y A \code{numeric} vector of observations of the outcome variable.
+#'  For \code{family = "mgaussian"}, \code{Y} is a matrix of observations of
+#'  the outcome variables.
 #' @param formula A character string formula to be used in
 #'  \code{\link{formula_hal}}. See its documentation for details.
 #' @param X_unpenalized An input \code{matrix} with the same number of rows as
@@ -81,18 +83,20 @@
 #'  combinatorial explosions in the number of higher-degree and higher-order
 #'  basis functions generated. This allows the complexity of the optimization
 #'  problem to grow scalably. See details of \code{num_knots} more information.
-#' @param reduce_basis A \code{numeric} value bounded in the open unit interval
-#'  indicating the minimum proportion of 1's in a basis function column needed
-#'  for the basis function to be included in the procedure to fit the lasso.
-#'  Any basis functions with a lower proportion of 1's than the cutoff will be
-#'  removed. When \code{reduce_basis} is set to \code{NULL}, all basis
-#'  functions are used in the lasso-fitting stage of \code{fit_hal}.
+#' @param reduce_basis An optional \code{numeric} value bounded in the open
+#'  unit interval indicating the minimum proportion of 1's in a basis function
+#'  column needed for the basis function to be included in the procedure to
+#'  fit the lasso.
+#'  Any basis functions with a lower proportion of 1's than the cutoff will
+#'  be removed. Defaults to 1 over the square root of the number of
+#'  observations. Only applicable for models fit with zero-order splines,
+#'  i.e. \code{smoothness_orders = 0}.
 #' @param family A \code{character} or a \code{\link[stats]{family}} object
 #'  (supported by \code{\link[glmnet]{glmnet}}) specifying the error/link
 #'  family for a generalized linear model. \code{character} options are limited
 #'  to "gaussian" for fitting a standard penalized linear model, "binomial" for
 #'  penalized logistic regression, "poisson" for penalized Poisson regression,
-#'  and "cox" for a penalized proportional hazards model. Note that passing in
+#'  "cox" for a penalized proportional hazards model, and "mgaussian" for a
+#'  multivariate penalized linear model. Note that passing in
 #'  family objects leads to slower performance relative to passing in a
 #'  character family (if supported). For example, one should set
 #'  \code{family = "binomial"} instead of \code{family = binomial()} when
@@ -108,38 +112,34 @@
 #' @param id A vector of ID values that is used to generate cross-validation
 #'  folds for \code{\link[glmnet]{cv.glmnet}}. This argument is ignored when
 #'  \code{fit_control}'s \code{cv_select} argument is \code{FALSE}.
+#' @param weights Observation weights; defaults to 1 per observation.
 #' @param offset a vector of offset values, used in fitting.
-#' @param fit_control List of arguments for fitting. Includes the following
-#'  arguments, and any others to be passed to \code{\link[glmnet]{cv.glmnet}}
-#'  or \code{\link[glmnet]{glmnet}}.
+#' @param fit_control List of arguments, including the following, and any
+#'  others to be passed to \code{\link[glmnet]{cv.glmnet}} or
+#'  \code{\link[glmnet]{glmnet}}.
 #'  - \code{cv_select}: A \code{logical} specifying if the sequence of
 #'    specified \code{lambda} values should be passed to
 #'    \code{\link[glmnet]{cv.glmnet}} in order for a single, optimal value of
 #'    \code{lambda} to be selected according to cross-validation. When
 #'    \code{cv_select = FALSE}, a \code{\link[glmnet]{glmnet}} model will be
 #'    used to fit the sequence of (or single) \code{lambda}.
-#'  - \code{n_folds}: Integer for the number of folds to be used when splitting
-#'    the data for V-fold cross-validation. Only used when
+#'  - \code{use_min}: Specify the choice of lambda to be selected by
+#'    \code{\link[glmnet]{cv.glmnet}}. When \code{TRUE}, \code{"lambda.min"} is
+#'    used; otherwise, \code{"lambda.1se"}. Only used when
 #'    \code{cv_select = TRUE}.
-#'  - \code{foldid}: An optional \code{numeric} containing values between 1 and
-#'    \code{n_folds}, identifying the fold to which each observation is
-#'    assigned. If supplied, \code{n_folds} can be missing. In such a case,
-#'    this vector is passed directly to \code{\link[glmnet]{cv.glmnet}}. Only
-#'    used when \code{cv_select = TRUE}.
-#'  - \code{use_min}: Specify the choice of lambda to be selected by
-#'    \code{\link[glmnet]{cv.glmnet}}. When \code{TRUE}, \code{"lambda.min"} is
-#'    used; otherwise, \code{"lambda.1se"}. Only used when
-#'    \code{cv_select = TRUE}.
-#'  - \code{lambda.min.ratio}: A \code{\link[glmnet]{glmnet}} argument specifying
-#'    the smallest value for \code{lambda}, as a fraction of \code{lambda.max},
-#'    the (data derived) entry value (i.e. the smallest value for which all
-#'    coefficients are zero). We've seen that not setting \code{lambda.min.ratio}
-#'    can lead to no \code{lambda} values that fit the data sufficiently well.
-#'  - \code{prediction_bounds}: A vector of size two that provides the lower and
-#'    upper bounds for predictions. When \code{prediction_bounds = "default"},
-#'    the predictions are bounded between \code{min(Y) - sd(Y)} and
-#'    \code{max(Y) + sd(Y)}. Bounding ensures that there is no extrapolation,
-#'    and it is necessary for cross-validation selection and/or Super Learning.
+#'  - \code{lambda.min.ratio}: A \code{\link[glmnet]{glmnet}} argument
+#'    specifying the smallest value for \code{lambda}, as a fraction of
+#'    \code{lambda.max}, the (data derived) entry value (i.e. the smallest
+#'    value for which all coefficients are zero). We've seen that not setting
+#'    \code{lambda.min.ratio} can lead to no \code{lambda} values that fit the
+#'    data sufficiently well.
+#'  - \code{prediction_bounds}: An optional vector of size two that provides
+#'    the lower and upper bounds for predictions; not used when
+#'    \code{family = "cox"}. When \code{prediction_bounds = "default"}, the
+#'    predictions are bounded between \code{min(Y) - sd(Y)} and
+#'    \code{max(Y) + sd(Y)} for each outcome (when \code{family = "mgaussian"},
+#'    each outcome can have different bounds). Bounding ensures that there is
+#'    no extrapolation.
 #' @param basis_list The full set of basis functions generated from \code{X}.
 #' @param return_lasso A \code{logical} indicating whether or not to return
 #'  the \code{\link[glmnet]{glmnet}} fit object of the lasso model.
@@ -182,15 +182,14 @@ fit_hal <- function(X,
                       base_num_knots_0 = 200,
                       base_num_knots_1 = 50
                     ),
-                    reduce_basis = 1 / sqrt(length(Y)),
-                    family = c("gaussian", "binomial", "poisson", "cox"),
+                    reduce_basis = NULL,
+                    family = c("gaussian", "binomial", "poisson", "cox", "mgaussian"),
                     lambda = NULL,
                     id = NULL,
+                    weights = NULL,
                     offset = NULL,
                     fit_control = list(
                       cv_select = TRUE,
-                      n_folds = 10,
-                      foldid = NULL,
                       use_min = TRUE,
                       lambda.min.ratio = 1e-4,
                       prediction_bounds = "default"
@@ -202,22 +201,33 @@ fit_hal <- function(X,
   if (!inherits(family, "family")) {
     family <- match.arg(family)
   }
+  fam <- ifelse(inherits(family, "family"), family$family, family)

   # errors when a supplied control list is missing arguments
   defaults <- list(
-    cv_select = TRUE, n_folds = 10, foldid = NULL, use_min = TRUE,
-    lambda.min.ratio = 1e-4, prediction_bounds = "default"
+    cv_select = TRUE, use_min = TRUE, lambda.min.ratio = 1e-4,
+    prediction_bounds = "default"
   )
   if (any(!names(defaults) %in% names(fit_control))) {
     fit_control <- c(
-      defaults[which(!names(defaults) %in% names(fit_control))], fit_control
+      defaults[!names(defaults) %in% names(fit_control)], fit_control
     )
   }
-  # errors when a supplied control list is missing arguments
-  defaults <- list(
-    exclusive_dot = FALSE, custom_group = NULL
-  )
-
+  # check fit_control names (excluding defaults) are glmnet/cv.glmnet formals
+  glmnet_formals <- unique(c(
+    names(formals(glmnet::cv.glmnet)),
+    names(formals(glmnet::glmnet)),
+    names(formals(glmnet::relax.glmnet)) # extra allowed args to glmnet
+  ))
+  control_names <- names(fit_control[!names(fit_control) %in% names(defaults)])
+  if (any(!control_names %in% glmnet_formals)) {
+    bad_args <- control_names[(!control_names %in% glmnet_formals)]
+    warning(sprintf(
+      "Some fit_control arguments are neither default nor glmnet/cv.glmnet arguments: %s; \nThey will be removed from fit_control",
+      paste0(bad_args, collapse = ", ")
+    ))
+    fit_control <- fit_control[!names(fit_control) %in% bad_args]
+  }

   if (!is.matrix(X)) X <- as.matrix(X)

@@ -230,9 +240,11 @@
     all(!is.na(Y)),
     msg = "NA
detected in `Y`, missingness in `Y` is not supported" ) + + n_Y <- ifelse(is.matrix(Y), nrow(Y), length(Y)) assertthat::assert_that( - nrow(X) == length(Y), - msg = "Number of rows in `X` and length of `Y` must be equal" + nrow(X) == n_Y, + msg = "Number of rows in `X` and `Y` must be equal" ) if (!is.null(X_unpenalized)) { @@ -252,6 +264,24 @@ fit_hal <- function(X, ) } + if (!is.character(fit_control$prediction_bounds)) { + if (fam == "mgaussian") { + assertthat::assert_that( + is.list(fit_control$prediction_bounds) & + length(fit_control$prediction_bounds) == ncol(Y), + msg = "prediction_bounds must be 'default' or list of numeric (lower, upper) bounds for each outcome" + ) + } else { + assertthat::assert_that( + is.numeric(fit_control$prediction_bounds) & + length(fit_control$prediction_bounds) == 2, + msg = "prediction_bounds must be 'default' or numeric (lower, upper) bounds" + ) + } + } + + + if (!is.null(formula)) { # formula <- formula_hal( # formula = formula, X = X, smoothness_orders = smoothness_orders, @@ -260,7 +290,11 @@ fit_hal <- function(X, # ) if (!inherits(formula, "formula_hal")) { - formula <- formula_hal(formula, X = X, smoothness_orders = smoothness_orders, num_knots = num_knots) + formula <- formula_hal( + formula, + X = X, smoothness_orders = smoothness_orders, + num_knots = num_knots + ) } basis_list <- formula$basis_list fit_control$upper.limits <- formula$upper.limits @@ -275,10 +309,11 @@ fit_hal <- function(X, # Generate fold_ids that respect id if (is.null(fit_control$foldid)) { + if (is.null(fit_control$nfolds)) fit_control$nfolds <- 10 folds <- origami::make_folds( - n = length(Y), V = fit_control$n_folds, cluster_ids = id + n = n_Y, V = fit_control$nfolds, cluster_ids = id ) - foldid <- origami::folds2foldvec(folds) + fit_control$foldid <- origami::folds2foldvec(folds) } # bookkeeping: get start time of enumerate basis procedure @@ -306,12 +341,19 @@ fit_hal <- function(X, time_design_matrix <- proc.time() # NOTE: keep only basis functions with some (or higher) proportion of 1's - if (!is.null(reduce_basis) && is.numeric(reduce_basis) && - all(smoothness_orders == 0)) { + if (all(smoothness_orders == 0)) { + if (is.null(reduce_basis)) { + reduce_basis <- 1 / sqrt(n_Y) + } reduced_basis_map <- make_reduced_basis_map(x_basis, reduce_basis) x_basis <- x_basis[, reduced_basis_map] basis_list <- basis_list[reduced_basis_map] + } else { + if (!is.null(reduce_basis)) { + warning("Dropping reduce_basis; only applies if smoothness_orders = 0") + } } + time_reduce_basis <- proc.time() # catalog and eliminate duplicates @@ -364,7 +406,7 @@ fit_hal <- function(X, # x_basis <- as.matrix(x_basis) # } - if (!inherits(family, "family") && family == "cox") { + if (fam == "cox") { x_basis <- as.matrix(x_basis) } @@ -388,6 +430,7 @@ fit_hal <- function(X, fit_control$lambda <- lambda fit_control$penalty.factor <- penalty_factor fit_control$offset <- offset + fit_control$weights <- weights if (!fit_control$cv_select) { hal_lasso <- do.call(glmnet::glmnet, fit_control) @@ -422,13 +465,19 @@ fit_hal <- function(X, ) # Bounds for prediction on new data (to prevent extrapolation for linear HAL) - if (!inherits(Y, "Surv") & fit_control$prediction_bounds == "default") { - # This would break if Y was a survival object as in coxnet - fit_control$prediction_bounds <- c( - min(Y) - 2 * stats::sd(Y), max(Y) + 2 * stats::sd(Y) - ) - } else if (inherits(Y, "Surv") & fit_control$prediction_bounds == "default") { - fit_control$prediction_bounds <- NULL + if 
(is.character(fit_control$prediction_bounds) && + fit_control$prediction_bounds == "default") { + if (fam == "mgaussian") { + fit_control$prediction_bounds <- lapply(seq(ncol(Y)), function(i) { + c(min(Y[, i]) - 2 * stats::sd(Y[, i]), max(Y[, i]) + 2 * stats::sd(Y[, i])) + }) + } else if (fam == "cox") { + fit_control$prediction_bounds <- NULL + } else { + fit_control$prediction_bounds <- c( + min(Y) - 2 * stats::sd(Y), max(Y) + 2 * stats::sd(Y) + ) + } } # construct output object via lazy S3 list diff --git a/R/make_basis.R b/R/make_basis.R index c68f783a..d8cd1de2 100644 --- a/R/make_basis.R +++ b/R/make_basis.R @@ -327,8 +327,9 @@ quantizer <- function(X, bins) { p <- max(1 - (20 / nrow(X)), 0.98) quants <- seq(0, p, length.out = bins) - q <- stats::quantile(x, quants) - nearest <- findInterval(x, q) + q <- unique(stats::quantile(x, quants, type = 1)) + # NOTE: all.inside must be FALSE or else all binary variables are mapped to zero. + nearest <- findInterval(x, q, all.inside = FALSE) x <- q[nearest] return(x) } diff --git a/R/predict.R b/R/predict.R index 7a19681d..41895867 100644 --- a/R/predict.R +++ b/R/predict.R @@ -16,10 +16,6 @@ #' @param offset A vector of offsets. Must be provided if provided at training. #' @param type Either "response" for predictions of the response, or "link" for #' un-transformed predictions (on the scale of the link function). -#' @param p_reserve Sparse matrix pre-allocation proportion, which is the -#' anticipated proportion of 1's in the design matrix. Default value is -#' recommended in most settings. If a dense design matrix is expected, it -#' would be useful to set \code{p_reserve} to a higher value. #' @param ... Additional arguments passed to \code{predict} as necessary. #' #' @importFrom Matrix tcrossprod @@ -44,10 +40,10 @@ predict.hal9001 <- function(object, new_X_unpenalized = NULL, offset = NULL, type = c("response", "link"), - p_reserve = 0.75, ...) { + family <- ifelse(inherits(object$family, "family"), object$family$family, object$family) + type <- match.arg(type) - p_reserve <- pmax(pmin(p_reserve, 1), 0) # cast new data to matrix if not so already if (!is.matrix(new_data)) new_data <- as.matrix(new_data) @@ -56,9 +52,7 @@ predict.hal9001 <- function(object, } # generate design matrix - pred_x_basis <- make_design_matrix(new_data, object$basis_list, - p_reserve = p_reserve - ) + pred_x_basis <- make_design_matrix(new_data, object$basis_list) # reduce matrix of basis functions # pred_x_basis <- apply_copy_map(pred_x_basis, object$copy_map) @@ -83,35 +77,31 @@ predict.hal9001 <- function(object, } # generate predictions - if (inherits(object$family, "family") || object$family != "cox") { + if (!family %in% c("cox", "mgaussian")) { if (ncol(object$coefs) > 1) { - preds <- apply(object$coefs, 2, function(hal_coefs) { - as.vector(Matrix::tcrossprod( - x = pred_x_basis, - y = hal_coefs[-1] - ) + hal_coefs[1]) - }) + preds <- pred_x_basis %*% object$coefs[-1, ] + + matrix(object$coefs[1, ], + nrow = nrow(pred_x_basis), + ncol = ncol(object$coefs), byrow = TRUE + ) } else { preds <- as.vector(Matrix::tcrossprod( x = pred_x_basis, - y = object$coefs[-1] + y = matrix(object$coefs[-1], nrow = 1) ) + object$coefs[1]) } } else { - # Note: there is no intercept in the Cox model (built into the baseline - # hazard and would cancel in the partial likelihood). 
-    if (ncol(object$coefs) > 1) {
-      preds <- apply(object$coefs, 2, function(hal_coefs) {
-        as.vector(Matrix::tcrossprod(
-          x = pred_x_basis,
-          y = hal_coefs
-        ))
-      })
-    } else {
-      preds <- as.vector(Matrix::tcrossprod(
-        x = pred_x_basis,
-        y = as.vector(object$coefs)
-      ))
+    if (family == "cox") {
+      # Note: there is no intercept in the Cox model (built into the baseline
+      # hazard and would cancel in the partial likelihood).
+      preds <- pred_x_basis %*% object$coefs
+    } else if (family == "mgaussian") {
+      preds <- stats::predict(
+        object$lasso_fit,
+        newx = pred_x_basis, s = object$lambda_star
+      )
     }
   }

@@ -130,22 +120,46 @@ predict.hal9001 <- function(object,
   if (inherits(object$family, "family")) {
     inverse_link_fun <- object$family$linkinv
     preds <- inverse_link_fun(preds)
-  } else if (object$family == "binomial") {
-    preds <- stats::plogis(preds)
-  } else if (object$family %in% c("poisson", "cox")) {
-    preds <- exp(preds)
+  } else {
+    if (family == "binomial") {
+      transform <- stats::plogis
+    } else if (family %in% c("poisson", "cox")) {
+      transform <- exp
+    } else if (family %in% c("gaussian", "mgaussian")) {
+      transform <- identity
+    } else {
+      stop("unsupported family")
+    }
+
+    if (length(ncol(preds))) {
+      # apply along only the first dimension (to handle n-d arrays)
+      margin <- seq(length(dim(preds)))[-1]
+      preds <- apply(preds, margin, transform)
+    } else {
+      preds <- transform(preds)
+    }
   }

   # bound predictions within observed outcome bounds if on response scale
-  bounds <- object$prediction_bounds
-  if (!is.null(bounds)) {
-    bounds <- sort(bounds)
-    if (is.matrix(preds)) {
-      preds <- apply(preds, 2, pmax, bounds[1])
-      preds <- apply(preds, 2, pmin, bounds[2])
+  if (!is.null(object$prediction_bounds)) {
+    bounds <- object$prediction_bounds
+    if (family == "mgaussian") {
+      preds <- do.call(cbind, lapply(seq(ncol(preds)), function(i) {
+        bounds_y <- sort(bounds[[i]])
+        preds_y <- preds[, i, ]
+        preds_y <- pmax(bounds_y[1], preds_y)
+        preds_y <- pmin(preds_y, bounds_y[2])
+        return(preds_y)
+      }))
     } else {
-      preds <- pmax(bounds[1], preds)
-      preds <- pmin(preds, bounds[2])
+      bounds <- sort(bounds)
+      if (is.matrix(preds)) {
+        preds <- apply(preds, 2, pmax, bounds[1])
+        preds <- apply(preds, 2, pmin, bounds[2])
+      } else {
+        preds <- pmax(bounds[1], preds)
+        preds <- pmin(preds, bounds[2])
+      }
     }
   }
diff --git a/R/sl_hal9001.R b/R/sl_hal9001.R
index d8537d65..0fe18b08 100644
--- a/R/sl_hal9001.R
+++ b/R/sl_hal9001.R
@@ -22,19 +22,9 @@
 #'  specifying the maximum number of knot points (i.e., bins) for each
 #'  covariate for generating basis functions. See \code{num_knots} argument in
 #'  \code{\link{fit_hal}} for more information.
-#' @param reduce_basis A \code{numeric} value bounded in the open unit interval
-#'  indicating the minimum proportion of 1's in a basis function column needed
-#'  for the basis function to be included in the procedure to fit the lasso.
-#'  Any basis functions with a lower proportion of 1's than the cutoff will be
-#'  removed.
-#' @param lambda A user-specified sequence of values of the regularization
-#'  parameter for the lasso L1 regression. If \code{NULL}, the default sequence
-#'  in \code{\link[glmnet]{cv.glmnet}} will be used. The cross-validated
-#'  optimal value of this regularization parameter will be selected with
-#'  \code{\link[glmnet]{cv.glmnet}}.
-#' @param ... Not used.
+#' @param ... Additional arguments to \code{\link{fit_hal}}.
#' -#' @importFrom stats predict gaussian +#' @importFrom stats predict #' #' @export #' @@ -42,27 +32,23 @@ #' object and corresponding predictions based on the input data. SL.hal9001 <- function(Y, X, - newX = NULL, - family = stats::gaussian(), - obsWeights = rep(1, length(Y)), - id = NULL, - max_degree = ifelse(ncol(X) >= 20, 2, 3), + newX, + family, + obsWeights, + id, + max_degree = 2, smoothness_orders = 1, - num_knots = ifelse(smoothness_orders >= 1, 25, 50), - reduce_basis = 1 / sqrt(length(Y)), - lambda = NULL, + num_knots = 5, ...) { - # create matrix version of X and newX for use with hal9001::fit_hal if (!is.matrix(X)) X <- as.matrix(X) if (!is.null(newX) & !is.matrix(newX)) newX <- as.matrix(newX) # fit hal - hal_fit <- fit_hal( - X = X, Y = Y, family = family$family, - fit_control = list(weights = obsWeights), id = id, max_degree = max_degree, - smoothness_orders = smoothness_orders, num_knots = num_knots, reduce_basis - = reduce_basis, lambda = lambda + hal_fit <- hal9001::fit_hal( + Y = Y, X = X, family = family$family, weights = obsWeights, id = id, + max_degree = max_degree, smoothness_orders = smoothness_orders, + num_knots = num_knots, ... ) # compute predictions based on `newX` or input `X` diff --git a/R/summary.R b/R/summary.R index 5878dae1..5193e7cd 100644 --- a/R/summary.R +++ b/R/summary.R @@ -52,6 +52,7 @@ summary.hal9001 <- function(object, include_redundant_terms = FALSE, round_cutoffs = 3, ...) { + family <- ifelse(inherits(object$family, "family"), object$family$family, object$family) abs_coef <- basis_list_idx <- coef_idx <- dup <- NULL # retain coefficients corresponding to lambda @@ -71,12 +72,20 @@ summary.hal9001 <- function(object, stop("Coefficients for the specified lambda do not exist.") } else { lambda_idx <- which(object$lasso_fit$lambda == lambda) - coefs <- object$lasso_fit$glmnet.fit$beta[, lambda_idx] + if (family != "mgaussian") { + coefs <- object$lasso_fit$glmnet.fit$beta[, lambda_idx] + } else { + coefs <- lapply(object$lasso_fit$glmnet.fit$beta, function(x) x[, lambda_idx]) + } } } } else { lambda_idx <- which(object$lambda_star == lambda) - coefs <- object$coefs[, lambda_idx] + if (family != "mgaussian") { + coefs <- object$coefs[, lambda_idx] + } else { + coefs <- lapply(object$coefs, function(x) x[, lambda_idx]) + } } } @@ -89,169 +98,210 @@ summary.hal9001 <- function(object, "Summarizing coefficients corresponding to minimum lambda." 
) lambda_idx <- which.min(lambda) - coefs <- object$coefs[, lambda_idx] + if (family != "mgaussian") { + coefs <- object$coefs[, lambda_idx] + } else { + coefs <- lapply(object$coefs, function(x) x[, lambda_idx]) + } } } # cox model has no intercept - if (object$family != "cox") { - coefs_no_intercept <- coefs[-1] - } else { + if (family == "cox") { coefs_no_intercept <- coefs + } else if (family == "mgaussian") { + coefs_no_intercept <- lapply(coefs, function(x) x[-1]) + } else { + coefs_no_intercept <- coefs[-1] } # subset to non-zero coefficients if (only_nonzero_coefs) { - coef_idxs <- which(coefs_no_intercept != 0) + if (family == "mgaussian") { + coef_idxs <- lapply(coefs_no_intercept, function(x) which(x != 0)) + } else { + coef_idxs <- which(coefs_no_intercept != 0) + } } else { - coef_idxs <- seq_along(coefs_no_intercept) + if (family == "mgaussian") { + coef_idxs <- lapply(coefs_no_intercept, function(x) seq_along(x)) + } else { + coef_idxs <- seq_along(coefs_no_intercept) + } } - copy_map <- object$copy_map[coef_idxs] - - # summarize coefficients with respect to basis list - coefs_summ <- data.table::rbindlist( - lapply(seq_along(copy_map), function(map_idx) { - coef_idx <- coef_idxs[map_idx] - coef <- coefs_no_intercept[coef_idx] - - basis_list_idxs <- copy_map[[map_idx]] # indices of duplicates - basis_dups <- object$basis_list[basis_list_idxs] - - data.table::rbindlist( - lapply(seq_along(basis_dups), function(i) { - coef_idx <- ifelse(object$family != "cox", coef_idx + 1, coef_idx) - dt <- data.table::data.table( - coef_idx = coef_idx, # coefficient index - coef, # coefficient - basis_list_idx = basis_list_idxs[i], # basis list index - col_idx = basis_dups[[i]]$cols, # column idx in X - col_cutoff = basis_dups[[i]]$cutoffs, # cutoff - col_order = basis_dups[[i]]$orders # smoothness order - ) - return(dt) - }) - ) - }) - ) - - if (!include_redundant_terms) { - coef_idxs <- unique(coefs_summ$coef_idx) - coefs_summ <- data.table::rbindlist(lapply(coef_idxs, function(idx) { - # subset to matching coefficient index - coef_summ <- coefs_summ[coef_idx == idx] - - # label duplicates (i.e. 
basis functions with identical col & cutoff) - dups_tbl <- coef_summ[, c("col_idx", "col_cutoff", "col_order")] - if (!anyDuplicated(dups_tbl)) { - return(coef_summ) - } else { - # add col indicating whether or not there is a duplicate - coef_summ[, dup := (duplicated(dups_tbl) | - duplicated(dups_tbl, fromLast = TRUE))] - - # if basis_list_idx contains redundant duplicates, remove them - redundant_dups <- coef_summ[dup == TRUE, "basis_list_idx"] - if (nrow(redundant_dups) > 1) { - # keep the redundant duplicate term that has the shortest length - retain_idx <- which.min(apply(redundant_dups, 1, function(idx) { - nrow(coef_summ[basis_list_idx == idx]) - })) - idx_keep <- unname(unlist(redundant_dups[retain_idx])) - coef_summ <- coef_summ[basis_list_idx == idx_keep] - } - return(coef_summ[, -"dup"]) - } - })) + + if (family == "mgaussian") { + copy_map <- lapply(coef_idxs, function(x) object$copy_map[x]) + } else { + copy_map <- object$copy_map[coef_idxs] } - # summarize with respect to x column names: - x_names <- data.table::data.table( - col_idx = 1:length(object$X_colnames), - col_names = object$X_colnames - ) - summ <- merge(coefs_summ, x_names, by = "col_idx", all.x = TRUE) - - # combine name, cutoff into 0-order basis function (may include interaction) - summ$zero_term <- paste0( - "I(", summ$col_names, " >= ", round(summ$col_cutoff, round_cutoffs), ")" - ) - summ$higher_term <- ifelse( - summ$col_order == 0, "", - paste0( - "(", summ$col_names, " - ", - round(summ$col_cutoff, round_cutoffs), ")" + # ============================================================================ + # utility function to summarize HAL fit which can be used for multiple outcomes + summarize_coefs <- function(copy_map, coef_idxs, coefs_no_intercept, coefs) { + # summarize coefficients with respect to basis list + coefs_summ <- data.table::rbindlist( + lapply(seq_along(copy_map), function(map_idx) { + coef_idx <- coef_idxs[map_idx] + coef <- coefs_no_intercept[coef_idx] + + basis_list_idxs <- copy_map[[map_idx]] # indices of duplicates + basis_dups <- object$basis_list[basis_list_idxs] + + data.table::rbindlist( + lapply(seq_along(basis_dups), function(i) { + coef_idx <- ifelse(family != "cox", coef_idx + 1, coef_idx) + dt <- data.table::data.table( + coef_idx = coef_idx, # coefficient index + coef, # coefficient + basis_list_idx = basis_list_idxs[i], # basis list index + col_idx = basis_dups[[i]]$cols, # column idx in X + col_cutoff = basis_dups[[i]]$cutoffs, # cutoff + col_order = basis_dups[[i]]$orders # smoothness order + ) + return(dt) + }) + ) + }) + ) + + if (!include_redundant_terms) { + coef_idxs <- unique(coefs_summ$coef_idx) + coefs_summ <- data.table::rbindlist(lapply(coef_idxs, function(idx) { + # subset to matching coefficient index + coef_summ <- coefs_summ[coef_idx == idx] + + # label duplicates (i.e. 
basis functions with identical col & cutoff) + dups_tbl <- coef_summ[, c("col_idx", "col_cutoff", "col_order")] + if (!anyDuplicated(dups_tbl)) { + return(coef_summ) + } else { + # add col indicating whether or not there is a duplicate + coef_summ[, dup := (duplicated(dups_tbl) | + duplicated(dups_tbl, fromLast = TRUE))] + + # if basis_list_idx contains redundant duplicates, remove them + redundant_dups <- coef_summ[dup == TRUE, "basis_list_idx"] + if (nrow(redundant_dups) > 1) { + # keep the redundant duplicate term that has the shortest length + retain_idx <- which.min(apply(redundant_dups, 1, function(idx) { + nrow(coef_summ[basis_list_idx == idx]) + })) + idx_keep <- unname(unlist(redundant_dups[retain_idx])) + coef_summ <- coef_summ[basis_list_idx == idx_keep] + } + return(coef_summ[, -"dup"]) + } + })) + } + + # summarize with respect to x column names: + x_names <- data.table::data.table( + col_idx = 1:length(object$X_colnames), + col_names = object$X_colnames + ) + summ <- merge(coefs_summ, x_names, by = "col_idx", all.x = TRUE) + + # combine name, cutoff into 0-order basis function (may include interaction) + summ$zero_term <- paste0( + "I(", summ$col_names, " >= ", round(summ$col_cutoff, round_cutoffs), ")" ) - ) - summ$higher_term <- ifelse( - summ$col_order < 1, summ$higher_term, - paste0(summ$higher_term, "^", summ$col_order) - ) - summ$term <- ifelse( - summ$col_order == 0, - paste0("[ ", summ$zero_term, " ]"), - paste0("[ ", summ$zero_term, "*", summ$higher_term, " ]") - ) - - term_tbl <- data.table::as.data.table(stats::aggregate( - term ~ basis_list_idx, - data = summ, paste, collapse = " * " - )) - - # no longer need the columns or rows that were incorporated in the term - redundant <- c( - "term", "col_cutoff", "col_names", "col_idx", "col_order", "zero_term", - "higher_term" - ) - summ <- summ[, -..redundant] - summ_unique <- unique(summ) - summ <- merge( - term_tbl, summ_unique, - by = "basis_list_idx", all.x = TRUE, all.y = FALSE - ) - - # summarize in a list - coefs_list <- lapply(unique(summ$coef_idx), function(this_coef_idx) { - coef_terms <- summ[coef_idx == this_coef_idx] - list(coef = unique(coef_terms$coef), term = t(coef_terms$term)) - }) - - # summarize in a table - coefs_tbl <- data.table::as.data.table(stats::aggregate( - term ~ coef_idx, - data = summ, FUN = paste, collapse = " OR " - )) - redundant <- c("term", "basis_list_idx") - summ_unique_coefs <- unique(summ[, -..redundant]) - coefs_tbl <- data.table::data.table(merge( - summ_unique_coefs, coefs_tbl, - by = "coef_idx", all = TRUE - )) - coefs_tbl[, "abs_coef" := abs(coef)] - coefs_tbl <- data.table::setorder(coefs_tbl[, -"coef_idx"], -abs_coef) - coefs_tbl <- coefs_tbl[, -"abs_coef", with = FALSE] - - # incorporate intercept - if (object$family != "cox") { - intercept <- list(data.table::data.table( - coef = coefs[1], term = "(Intercept)" + summ$higher_term <- ifelse( + summ$col_order == 0, "", + paste0( + "(", summ$col_names, " - ", + round(summ$col_cutoff, round_cutoffs), ")" + ) + ) + summ$higher_term <- ifelse( + summ$col_order < 1, summ$higher_term, + paste0(summ$higher_term, "^", summ$col_order) + ) + summ$term <- ifelse( + summ$col_order == 0, + paste0("[ ", summ$zero_term, " ]"), + paste0("[ ", summ$zero_term, "*", summ$higher_term, " ]") + ) + + term_tbl <- data.table::as.data.table(stats::aggregate( + term ~ basis_list_idx, + data = summ, paste, collapse = " * " )) - coefs_tbl <- data.table::rbindlist( - c(intercept, list(coefs_tbl)), - fill = TRUE + + + # no longer need the columns or rows 
that were incorporated in the term + redundant <- c( + "term", "col_cutoff", "col_names", "col_idx", "col_order", "zero_term", + "higher_term" + ) + summ <- summ[, -..redundant] + summ_unique <- unique(summ) + summ <- merge( + term_tbl, summ_unique, + by = "basis_list_idx", all.x = TRUE, all.y = FALSE + ) + + # generate input for rules summary + rules_tbl <- generate_all_rules( + object$basis_list[summ$basis_list_idx], summ$coef, object$X_colnames ) - intercept <- list(coef = coefs[1], term = "(Intercept)") - coefs_list <- c(list(intercept), coefs_list) + + # summarize in a list + coefs_list <- lapply(unique(summ$coef_idx), function(this_coef_idx) { + coef_terms <- summ[coef_idx == this_coef_idx] + list(coef = unique(coef_terms$coef), term = t(coef_terms$term)) + }) + + # summarize in a table + coefs_tbl <- data.table::as.data.table(stats::aggregate( + term ~ coef_idx, + data = summ, FUN = paste, collapse = " OR " + )) + redundant <- c("term", "basis_list_idx") + summ_unique_coefs <- unique(summ[, -..redundant]) + coefs_tbl <- data.table::data.table(merge( + summ_unique_coefs, coefs_tbl, + by = "coef_idx", all = TRUE + )) + coefs_tbl[, "abs_coef" := abs(coef)] + coefs_tbl <- data.table::setorder(coefs_tbl[, -"coef_idx"], -abs_coef) + coefs_tbl <- coefs_tbl[, -"abs_coef", with = FALSE] + + # incorporate intercept + if (family != "cox") { + intercept <- list(data.table::data.table( + coef = coefs[1], term = "(Intercept)" + )) + coefs_tbl <- data.table::rbindlist( + c(intercept, list(coefs_tbl)), + fill = TRUE + ) + intercept <- list(coef = coefs[1], term = "(Intercept)") + coefs_list <- c(list(intercept), coefs_list) + } + out <- list( + table = coefs_tbl, + list = coefs_list, + lambda = lambda, + only_nonzero_coefs = only_nonzero_coefs, + family = family, + rules = rules_tbl + ) + class(out) <- "summary.hal9001" + return(out) } - out <- list( - table = coefs_tbl, - list = coefs_list, - lambda = lambda, - only_nonzero_coefs = only_nonzero_coefs - ) - class(out) <- "summary.hal9001" - return(out) -} + # ============================================================================ + if (family == "mgaussian") { + return_obj <- lapply(seq_along(copy_map), function(i) { + summarize_coefs(copy_map[[i]], coef_idxs[[i]], coefs_no_intercept[[i]], coefs[[i]]) + }) + class(return_obj) <- "summary.hal9001" + } else { + return_obj <- summarize_coefs(copy_map, coef_idxs, coefs_no_intercept, coefs) + } + return(return_obj) +} ############################################################################### #' Print Method for Summary Class of HAL fits @@ -262,28 +312,166 @@ summary.hal9001 <- function(object, #' #' @export print.summary.hal9001 <- function(x, length = NULL, ...) 
{ - if (x$only_nonzero_coefs & is.null(length)) { - cat( - "\n\nSummary of non-zero coefficients is based on lambda of", - x$lambda, "\n\n" - ) - } else if (!x$only_nonzero_coefs & is.null(length)) { - cat("\nSummary of coefficients is based on lambda of", x$lambda, "\n\n") - } else if (!x$only_nonzero_coefs & !is.null(length)) { - cat( - "\nSummary of top", length, - "coefficients is based on lambda of", x$lambda, "\n\n" - ) - } else if (x$only_nonzero_coefs & !is.null(length)) { - cat( - "\nSummary of top", length, - "non-zero coefficients is based on lambda of", x$lambda, "\n\n" - ) + if (x$family != "mgaussian" && !is.null(x$family)) { + if (x$only_nonzero_coefs & is.null(length)) { + cat( + "\n\nSummary of non-zero coefficients is based on lambda of", + x$lambda, "\n\n" + ) + } else if (!x$only_nonzero_coefs & is.null(length)) { + cat("\nSummary of coefficients is based on lambda of", x$lambda, "\n\n") + } else if (!x$only_nonzero_coefs & !is.null(length)) { + cat( + "\nSummary of top", length, + "coefficients is based on lambda of", x$lambda, "\n\n" + ) + } else if (x$only_nonzero_coefs & !is.null(length)) { + cat( + "\nSummary of top", length, + "non-zero coefficients is based on lambda of", x$lambda, "\n\n" + ) + } + + if (is.null(length)) { + print(x$table, row.names = FALSE) + } else { + print(utils::head(x$table, length), row.names = FALSE) + } + cat("\n\n Summary of aggregated marginal and interaction regions: \n\n") + print(x$rules, row.names = FALSE) + } else { + for (i in 1:length(x)) { + if (x[[i]]$only_nonzero_coefs & is.null(length)) { + cat( + "\n\nSummary of non-zero coefficients for each outcome is based on lambda of", + x[[i]]$lambda, "\n\n" + ) + } else if (!x[[i]]$only_nonzero_coefs & is.null(length)) { + cat( + "\nSummary of coefficients for each outcome is based on lambda of", + x[[i]]$lambda, "\n\n" + ) + } else if (!x[[i]]$only_nonzero_coefs & !is.null(length)) { + cat( + "\nSummary of top", length, + "coefficients for each outcome is based on lambda of", x[[i]]$lambda, "\n\n" + ) + } else if (x[[i]]$only_nonzero_coefs & !is.null(length)) { + cat( + "\nSummary of top", length, + "non-zero coefficients for each outcome is based on lambda of", + x[[i]]$lambda, "\n\n" + ) + } + if (is.null(length)) { + print(x[[i]]$table, row.names = FALSE) + } else { + print(utils::head(x[[i]]$table, length), row.names = FALSE) + } + cat("\n\n Summary of aggregated marginal and interaction regions: \n\n") + print(x[[i]]$rules, row.names = FALSE) + } + } +} + +#' Generates rules based on knot points of the fitted HAL basis functions with +#' non-zero coefficients. 
+#'
+#' @keywords internal
+generate_all_rules <- function(basis_list, coefs, X_colnames) {
+  # Convert coefficients to a matrix (the intercept is not included here)
+  coefs_mat <- as.matrix(coefs)
+
+  # Identify indices where coefficients are non-zero
+  # (i.e., relevant to the final model)
+  relevant_indices <- which(coefs_mat != 0)
+
+  # Initialize a list to store cutoffs for each feature
+  cutoffs_list <- vector("list", length(X_colnames))
+  names(cutoffs_list) <- X_colnames
+
+  # Initialize lists to store interaction rules and cumulative coefficients
+  interaction_rules <- list()
+  interaction_coefs <- list()
+
+  # Loop over each basis function that has a non-zero coefficient
+  for (i in relevant_indices) {
+    basis <- basis_list[[i]]
+    coef_val <- coefs_mat[i, ]
+
+    # For marginal basis functions (no interactions)
+    if (length(basis$cols) == 1) {
+      colname <- X_colnames[basis$cols[1]]
+      # Add unique cutoffs to the cutoffs_list
+      cutoffs_list[[colname]] <- unique(c(cutoffs_list[[colname]], basis$cutoffs[1]))
+    }
+
+    # For interaction basis functions
+    if (length(basis$cols) > 1) {
+      interaction_name <- paste(X_colnames[basis$cols], collapse = "-")
+      if (!interaction_name %in% names(interaction_rules)) {
+        interaction_rules[[interaction_name]] <- list()
+        interaction_coefs[[interaction_name]] <- 0
+        for (j in basis$cols) {
+          interaction_rules[[interaction_name]][[X_colnames[j]]] <- c()
+        }
+      }
+      for (j in seq_along(basis$cols)) {
+        interaction_rules[[interaction_name]][[X_colnames[basis$cols[j]]]] <-
+          c(interaction_rules[[interaction_name]][[X_colnames[basis$cols[j]]]], basis$cutoffs[j])
+      }
+      interaction_coefs[[interaction_name]] <- interaction_coefs[[interaction_name]] + coef_val
+    }
+  }
+
+  # check if there are any marginal rules
+  # (i.e., any non-interaction basis functions with non-zero coefficients)
+  # NOTE: guard the subsetting, since x[-integer(0)] would drop every element
+  null_cutoffs <- which(sapply(cutoffs_list, is.null))
+  if (length(null_cutoffs) > 0) {
+    cutoffs_list <- cutoffs_list[-null_cutoffs]
+  }
+  if (length(cutoffs_list) > 0) {
+    # for each feature, identify the min cutoff and form rule
+    min_cutoffs <- sapply(cutoffs_list, min, na.rm = TRUE)
+    # NOTE: index into names(min_cutoffs), not X_colnames, since features
+    # without marginal bases were dropped from cutoffs_list above
+    marginal_rules <- sapply(seq_along(min_cutoffs), function(i) {
+      paste0(names(min_cutoffs)[i], " >= ", min_cutoffs[i])
+    }, USE.NAMES = TRUE)
+    names(marginal_rules) <- names(min_cutoffs)
+  } else {
+    # instantiate empty marginal rules if there are none
+    marginal_rules <- c()
+    min_cutoffs <- c()
+  }
+
+  # check if there are any interaction rules
+  # (i.e., any interaction basis functions with non-zero coefficients)
+  if (length(interaction_rules) > 0) {
+    # create bounding box rules for interactions
+    bounding_rules <- list()
+    for (interaction in names(interaction_rules)) {
+      rules <- c()
+      for (var in names(interaction_rules[[interaction]])) {
+        min_val <- min(interaction_rules[[interaction]][[var]], na.rm = TRUE)
+        rules <- c(rules, paste0(var, " >= ", min_val))
+      }
+      bounding_rules[[interaction]] <- paste(rules, collapse = " & ")
+    }
+    # combine all rules
+    all_rules <- c(marginal_rules, unlist(bounding_rules))
+    all_coefs <- c(min_cutoffs, unlist(interaction_coefs))
+  } else {
+    # all rules are only comprised of the marginals
+    all_rules <- marginal_rules
+    all_coefs <- min_cutoffs
+  }
+
+  if (is.null(all_rules) | is.null(all_coefs)) {
+    # there are no rules!
+    rules_df <- NULL
   } else {
-    print(utils::head(x$table, length), row.names = FALSE)
+    # convert rules into a data table for easy viewing and interpretation
+    rules_df <- data.table::data.table(
+      variables = names(all_rules),
+      rule = all_rules,
+      cumulative_coefficient = all_coefs
+    )
   }
+  return(rules_df)
 }
diff --git a/README.md b/README.md
index e8ed0d6a..728a247b 100644
--- a/README.md
+++ b/README.md
@@ -100,12 +100,12 @@ hal_fit <- fit_hal(X = x, Y = y, yolo = TRUE)
 #> [1] "I'm sorry, Dave. I'm afraid I can't do that."
 hal_fit$times
 #> user.self sys.self elapsed user.child sys.child
-#> enumerate_basis 0.008 0.001 0.009 0 0
-#> design_matrix 0.003 0.000 0.003 0 0
 #> reduce_basis 0.000 0.000 0.000 0 0
-#> remove_duplicates 0.000 0.000 0.001 0 0
-#> lasso 1.690 0.036 1.764 0 0
-#> total 1.702 0.037 1.778 0 0
+#> enumerate_basis 0.014 0.003 0.059 0 0
+#> design_matrix 0.004 0.001 0.005 0 0
+#> remove_duplicates 0.000 0.000 0.000 0 0
+#> lasso 2.684 0.343 6.583 0 0
+#> total 2.703 0.348 6.655 0 0
 
 # training sample prediction
 preds <- predict(hal_fit, new_data = x)
diff --git a/docs/404.html b/docs/404.html
index 2fa8a621..f0d9df35 100644
--- a/docs/404.html
+++ b/docs/404.html
@@ -38,7 +38,7 @@
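Taken together, the changes above add multivariate-outcome support (`family = "mgaussian"`), promote `weights` to a formal `fit_hal` argument, and validate extra `fit_control` entries against the `glmnet`/`cv.glmnet` formals. The following is a minimal sketch of how these pieces fit together, not an example from the package itself: the data and all object names (`W`, `Y_multi`, `mv_fit`, `mv_preds`) are hypothetical, and it assumes a build of hal9001 that includes this patch.

library(hal9001)
set.seed(67391)
n_obs <- 200
W <- matrix(rnorm(n_obs * 3), ncol = 3)
colnames(W) <- c("W1", "W2", "W3")
# for family = "mgaussian", Y is a matrix with one column per outcome
Y_multi <- cbind(
  sin(W[, 1]) + rnorm(n_obs, sd = 0.2),
  W[, 2]^2 + rnorm(n_obs, sd = 0.2)
)
mv_fit <- fit_hal(
  X = W, Y = Y_multi, family = "mgaussian",
  weights = rep(1, n_obs), # now a formal argument, not a fit_control entry
  fit_control = list(
    cv_select = TRUE,
    nfolds = 5, # passes through, since nfolds is a cv.glmnet formal
    prediction_bounds = list(c(-3, 3), c(-2, 15)) # per-outcome (lower, upper)
  )
)
# predictions return one column per outcome, bounded as requested
mv_preds <- predict(mv_fit, new_data = W)

Note the design choice reflected here: unrecognized `fit_control` names now trigger a warning and are dropped rather than silently forwarded, which surfaces typos early.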
diff --git a/docs/CONTRIBUTING.html b/docs/CONTRIBUTING.html
index c278dd85..f0dcfa9c 100644
--- a/docs/CONTRIBUTING.html
+++ b/docs/CONTRIBUTING.html
@@ -23,7 +23,7 @@
diff --git a/docs/LICENSE-text.html b/docs/LICENSE-text.html
index 210dedbf..a4e240cc 100644
--- a/docs/LICENSE-text.html
+++ b/docs/LICENSE-text.html
@@ -23,7 +23,7 @@
diff --git a/docs/articles/index.html b/docs/articles/index.html
index e707ca32..564c15c4 100644
--- a/docs/articles/index.html
+++ b/docs/articles/index.html
@@ -23,7 +23,7 @@
diff --git a/docs/articles/intro_hal9001.html b/docs/articles/intro_hal9001.html
index 21f17a12..d5262436 100644
--- a/docs/articles/intro_hal9001.html
+++ b/docs/articles/intro_hal9001.html
@@ -39,7 +39,7 @@
@@ -85,7 +85,11 @@
vignettes/intro_hal9001.Rmd
intro_hal9001.Rmd
library(hal9001)
## Loading required package: Rcpp
-## hal9001 v0.4.3: The Scalable Highly Adaptive Lasso
+## hal9001 v0.4.5: The Scalable Highly Adaptive Lasso
## note: fit_hal defaults have changed. See ?fit_hal for details
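Since `SL.hal9001` (see R/sl_hal9001.R above) now forwards `...` to `fit_hal`, the wrapper can be customized directly from SuperLearner. A hedged sketch of that usage, assuming the SuperLearner package is installed and this version of hal9001 is loaded; the data objects (`x_df`, `y`, `sl_fit`) are illustrative names only:

library(SuperLearner)
set.seed(582)
n <- 100
x_df <- data.frame(x1 = rnorm(n), x2 = rnorm(n))
y <- x_df$x1 + rnorm(n, sd = 0.3)
# SL.hal9001 now calls hal9001::fit_hal(Y, X, family, weights = obsWeights, ...)
sl_fit <- SuperLearner(
  Y = y, X = x_df, family = gaussian(),
  SL.library = "SL.hal9001"
)
sl_fit$coef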
Fitting the model
@@ -165,12 +169,12 @@ Fitting the model
hal_fit <- fit_hal(X = x, Y = y)
hal_fit$times
## user.self sys.self elapsed user.child sys.child
## enumerate_basis 0.017 0.000 0.018 0 0
## design_matrix 0.082 0.003 0.086 0 0
## reduce_basis 0.000 0.000 0.000 0 0
## remove_duplicates 0.000 0.000 0.000 0 0
## lasso 2.284 0.072 2.399 0 0
## total 2.384 0.075 2.503 0 0
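The summary block reproduced next is computed at the cross-validation-selected lambda. Per the R/summary.R changes above, a specific value from the fitted sequence can also be requested. A small sketch, using the `hal_fit` object from this vignette:

# default: summarize coefficients at the CV-selected lambda
summ <- summary(hal_fit)
# or request a particular value from the fitted sequence
summ_star <- summary(hal_fit, lambda = hal_fit$lambda_star)
head(summ$table)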
##
##
-## Summary of non-zero coefficients is based on lambda of 0.00487
+## Summary of non-zero coefficients is based on lambda of 0.002108107
##
## coef
-## -1.424003e+00
-## 3.654157e-01
-## -2.358064e-01
-## 2.182822e-01
-## -1.713943e-01
-## 1.648782e-01
-## 1.622979e-01
-## -1.205285e-01
-## -9.526696e-02
-## -9.382067e-02
-## 5.468026e-02
-## 5.273315e-02
-## -5.056465e-02
-## 4.527346e-02
-## -3.735277e-02
-## -3.543529e-02
-## 2.380514e-02
-## -2.255508e-02
-## -2.176114e-02
-## -1.612916e-02
-## 1.531317e-02
-## -1.233111e-02
-## 1.172920e-02
-## 8.162029e-03
-## -7.160000e-03
-## -5.500155e-03
-## 5.083293e-03
-## -3.625005e-03
-## 1.488567e-03
-## 1.350348e-03
-## 3.224244e-04
-## -3.021403e-04
-## -2.126168e-04
-## -6.625025e-05
-## 1.927162e-06
+## -1.435851e+00
+## 4.391709e-01
+## -2.180325e-01
+## -1.908016e-01
+## 1.837943e-01
+## 1.617427e-01
+## -1.307707e-01
+## 1.205775e-01
+## -1.203965e-01
+## 1.175336e-01
+## -1.166074e-01
+## -1.018822e-01
+## 8.214302e-02
+## 7.525308e-02
+## 7.518641e-02
+## 7.328934e-02
+## 7.066728e-02
+## 6.364869e-02
+## -4.686124e-02
+## -4.672286e-02
+## -4.499741e-02
+## -4.377227e-02
+## -3.830315e-02
+## 3.779762e-02
+## -3.744479e-02
+## 3.386721e-02
+## 3.286990e-02
+## 3.254816e-02
+## -3.203164e-02
+## 3.041717e-02
+## -1.901118e-02
+## 1.170430e-02
+## -1.147950e-02
+## -1.053684e-02
+## 9.934522e-03
+## 9.600888e-03
+## -7.160528e-03
+## -6.499773e-03
+## -5.794305e-03
+## -5.714597e-03
+## -5.698275e-03
+## -5.424112e-03
+## 5.170208e-03
+## 4.516979e-03
+## 4.245836e-03
+## -4.125254e-03
+## 2.495093e-03
+## -1.063492e-04
+## -7.186831e-05
+## 2.188558e-05
+## 1.890986e-05
+## -1.757825e-05
+## 1.024909e-05
## coef
## term
## (Intercept)
## [ I(x2 >= -1.583)*(x2 - -1.583)^1 ]
## [ I(x2 >= 1.595)*(x2 - 1.595)^1 ] * [ I(x3 >= -3.289)*(x3 - -3.289)^1 ]
-## [ I(x1 >= -0.962)*(x1 - -0.962)^1 ] * [ I(x3 >= -3.289)*(x3 - -3.289)^1 ]
## [ I(x1 >= 1.606)*(x1 - 1.606)^1 ] * [ I(x2 >= -3.038)*(x2 - -3.038)^1 ]
-## [ I(x2 >= -1.11)*(x2 - -1.11)^1 ] * [ I(x3 >= -3.289)*(x3 - -3.289)^1 ]
+## [ I(x1 >= -0.962)*(x1 - -0.962)^1 ] * [ I(x3 >= -3.289)*(x3 - -3.289)^1 ]
## [ I(x1 >= -1.403)*(x1 - -1.403)^1 ] * [ I(x2 >= -3.038)*(x2 - -3.038)^1 ]
## [ I(x1 >= 0.941)*(x1 - 0.941)^1 ] * [ I(x3 >= -3.289)*(x3 - -3.289)^1 ]
-## [ I(x2 >= 1.017)*(x2 - 1.017)^1 ] * [ I(x3 >= -3.289)*(x3 - -3.289)^1 ]
+## [ I(x2 >= -1.11)*(x2 - -1.11)^1 ] * [ I(x3 >= -3.289)*(x3 - -3.289)^1 ]
## [ I(x1 >= 1.368)*(x1 - 1.368)^1 ] * [ I(x3 >= -3.289)*(x3 - -3.289)^1 ]
-## [ I(x2 >= -1.565)*(x2 - -1.565)^1 ] * [ I(x3 >= -3.289)*(x3 - -3.289)^1 ]
-## [ I(x1 >= -1.844)*(x1 - -1.844)^1 ]
+## [ I(x1 >= -1.593)*(x1 - -1.593)^1 ]
## [ I(x2 >= 0.696)*(x2 - 0.696)^1 ] * [ I(x3 >= -3.289)*(x3 - -3.289)^1 ]
+## [ I(x1 >= -3.224)*(x1 - -3.224)^1 ] * [ I(x2 >= 1.595)*(x2 - 1.595)^1 ]
+## [ I(x1 >= -1.844)*(x1 - -1.844)^1 ]
## [ I(x1 >= -3.224)*(x1 - -3.224)^1 ] * [ I(x2 >= 1.156)*(x2 - 1.156)^1 ] * [ I(x3 >= -0.497)*(x3 - -0.497)^1 ]
-## [ I(x1 >= -3.224)*(x1 - -3.224)^1 ] * [ I(x2 >= -3.038)*(x2 - -3.038)^1 ] * [ I(x3 >= -3.289)*(x3 - -3.289)^1 ]
+## [ I(x1 >= 0.9)*(x1 - 0.9)^1 ] * [ I(x2 >= 0.118)*(x2 - 0.118)^1 ] * [ I(x3 >= -3.289)*(x3 - -3.289)^1 ]
+## [ I(x1 >= 1.347)*(x1 - 1.347)^1 ] * [ I(x2 >= -3.038)*(x2 - -3.038)^1 ] * [ I(x3 >= -0.047)*(x3 - -0.047)^1 ]
+## [ I(x2 >= -1.565)*(x2 - -1.565)^1 ] * [ I(x3 >= -3.289)*(x3 - -3.289)^1 ]
+## [ I(x2 >= -1.322)*(x2 - -1.322)^1 ]
+## [ I(x1 >= -3.224)*(x1 - -3.224)^1 ] * [ I(x2 >= -3.038)*(x2 - -3.038)^1 ]
## [ I(x2 >= -3.038)*(x2 - -3.038)^1 ] * [ I(x3 >= -3.289)*(x3 - -3.289)^1 ]
+## [ I(x2 >= 1.017)*(x2 - 1.017)^1 ] * [ I(x3 >= -3.289)*(x3 - -3.289)^1 ]
+## [ I(x1 >= 0.521)*(x1 - 0.521)^1 ] * [ I(x2 >= -0.375)*(x2 - -0.375)^1 ] * [ I(x3 >= -3.289)*(x3 - -3.289)^1 ]
+## [ I(x1 >= -1.092)*(x1 - -1.092)^1 ] * [ I(x3 >= -0.046)*(x3 - -0.046)^1 ]
+## [ I(x2 >= -3.038)*(x2 - -3.038)^1 ] * [ I(x3 >= 0.45)*(x3 - 0.45)^1 ]
+## [ I(x1 >= 0.594)*(x1 - 0.594)^1 ] * [ I(x3 >= -3.289)*(x3 - -3.289)^1 ]
## [ I(x1 >= -3.224)*(x1 - -3.224)^1 ] * [ I(x2 >= -0.699)*(x2 - -0.699)^1 ]
-## [ I(x1 >= -3.224)*(x1 - -3.224)^1 ] * [ I(x2 >= 1.595)*(x2 - 1.595)^1 ]
-## [ I(x1 >= 1.135)*(x1 - 1.135)^1 ] * [ I(x2 >= -3.038)*(x2 - -3.038)^1 ]
-## [ I(x1 >= 0.307)*(x1 - 0.307)^1 ] * [ I(x2 >= -3.038)*(x2 - -3.038)^1 ] * [ I(x3 >= -3.289)*(x3 - -3.289)^1 ]
## [ I(x1 >= -0.422)*(x1 - -0.422)^1 ] * [ I(x3 >= -3.289)*(x3 - -3.289)^1 ]
-## [ I(x1 >= -0.901)*(x1 - -0.901)^1 ] * [ I(x2 >= -3.038)*(x2 - -3.038)^1 ] * [ I(x3 >= -0.047)*(x3 - -0.047)^1 ]
+## [ I(x2 >= -0.916)*(x2 - -0.916)^1 ]
+## [ I(x1 >= -3.224)*(x1 - -3.224)^1 ] * [ I(x2 >= -3.038)*(x2 - -3.038)^1 ] * [ I(x3 >= -3.289)*(x3 - -3.289)^1 ]
## [ I(x1 >= -0.313)*(x1 - -0.313)^1 ] * [ I(x2 >= 0.118)*(x2 - 0.118)^1 ] * [ I(x3 >= -3.289)*(x3 - -3.289)^1 ]
-## [ I(x1 >= -3.224)*(x1 - -3.224)^1 ] * [ I(x2 >= 0.489)*(x2 - 0.489)^1 ] * [ I(x3 >= -0.257)*(x3 - -0.257)^1 ]
-## [ I(x1 >= -3.224)*(x1 - -3.224)^1 ] * [ I(x2 >= 0.489)*(x2 - 0.489)^1 ] * [ I(x3 >= -3.289)*(x3 - -3.289)^1 ]
-## [ I(x3 >= -3.289)*(x3 - -3.289)^1 ]
-## [ I(x1 >= -3.224)*(x1 - -3.224)^1 ] * [ I(x2 >= -3.038)*(x2 - -3.038)^1 ] * [ I(x3 >= 0.595)*(x3 - 0.595)^1 ]
## [ I(x1 >= -0.901)*(x1 - -0.901)^1 ] * [ I(x2 >= -0.375)*(x2 - -0.375)^1 ] * [ I(x3 >= -3.289)*(x3 - -3.289)^1 ]
-## [ I(x1 >= -3.224)*(x1 - -3.224)^1 ] * [ I(x2 >= -3.038)*(x2 - -3.038)^1 ] * [ I(x3 >= 0.374)*(x3 - 0.374)^1 ]
+## [ I(x1 >= -1.313)*(x1 - -1.313)^1 ] * [ I(x2 >= -3.038)*(x2 - -3.038)^1 ] * [ I(x3 >= -0.781)*(x3 - -0.781)^1 ]
+## [ I(x1 >= -0.901)*(x1 - -0.901)^1 ] * [ I(x2 >= -3.038)*(x2 - -3.038)^1 ] * [ I(x3 >= -0.047)*(x3 - -0.047)^1 ]
+## [ I(x1 >= -0.313)*(x1 - -0.313)^1 ] * [ I(x2 >= 0.772)*(x2 - 0.772)^1 ] * [ I(x3 >= -3.289)*(x3 - -3.289)^1 ]
## [ I(x1 >= -3.224)*(x1 - -3.224)^1 ] * [ I(x2 >= -0.375)*(x2 - -0.375)^1 ] * [ I(x3 >= -3.289)*(x3 - -3.289)^1 ]
-## [ I(x2 >= -0.699)*(x2 - -0.699)^1 ] * [ I(x3 >= -3.289)*(x3 - -3.289)^1 ]
+## [ I(x2 >= -3.038)*(x2 - -3.038)^1 ] * [ I(x3 >= -0.317)*(x3 - -0.317)^1 ]
+## [ I(x1 >= -0.313)*(x1 - -0.313)^1 ] * [ I(x2 >= -0.624)*(x2 - -0.624)^1 ] * [ I(x3 >= -3.289)*(x3 - -3.289)^1 ]
+## [ I(x1 >= -3.224)*(x1 - -3.224)^1 ] * [ I(x2 >= -3.038)*(x2 - -3.038)^1 ] * [ I(x3 >= 1.202)*(x3 - 1.202)^1 ]
+## [ I(x1 >= -3.224)*(x1 - -3.224)^1 ] * [ I(x2 >= -0.624)*(x2 - -0.624)^1 ] * [ I(x3 >= 0.174)*(x3 - 0.174)^1 ]
+## [ I(x1 >= -3.224)*(x1 - -3.224)^1 ] * [ I(x2 >= -3.038)*(x2 - -3.038)^1 ] * [ I(x3 >= -0.047)*(x3 - -0.047)^1 ]
+## [ I(x1 >= -3.224)*(x1 - -3.224)^1 ] * [ I(x2 >= 0.528)*(x2 - 0.528)^1 ]
+## [ I(x1 >= 0.307)*(x1 - 0.307)^1 ] * [ I(x2 >= -3.038)*(x2 - -3.038)^1 ] * [ I(x3 >= -3.289)*(x3 - -3.289)^1 ]
+## [ I(x1 >= -0.555)*(x1 - -0.555)^1 ] * [ I(x2 >= -3.038)*(x2 - -3.038)^1 ]
+## [ I(x1 >= -3.224)*(x1 - -3.224)^1 ] * [ I(x2 >= -3.038)*(x2 - -3.038)^1 ] * [ I(x3 >= 0.595)*(x3 - 0.595)^1 ]
+## [ I(x2 >= -0.816)*(x2 - -0.816)^1 ]
+## [ I(x1 >= -3.224)*(x1 - -3.224)^1 ] * [ I(x2 >= -0.375)*(x2 - -0.375)^1 ] * [ I(x3 >= -1.331)*(x3 - -1.331)^1 ]
+## [ I(x1 >= -3.224)*(x1 - -3.224)^1 ] * [ I(x2 >= -0.441)*(x2 - -0.441)^1 ]
+## [ I(x1 >= 0.307)*(x1 - 0.307)^1 ] * [ I(x2 >= -0.624)*(x2 - -0.624)^1 ] * [ I(x3 >= -0.781)*(x3 - -0.781)^1 ]
## [ I(x1 >= -0.901)*(x1 - -0.901)^1 ] * [ I(x2 >= -3.038)*(x2 - -3.038)^1 ] * [ I(x3 >= -3.289)*(x3 - -3.289)^1 ]
-## [ I(x1 >= 0.739)*(x1 - 0.739)^1 ] * [ I(x3 >= -3.289)*(x3 - -3.289)^1 ]
-## [ I(x1 >= 0.594)*(x1 - 0.594)^1 ] * [ I(x3 >= -3.289)*(x3 - -3.289)^1 ]
## [ I(x1 >= -0.685)*(x1 - -0.685)^1 ] * [ I(x3 >= -3.289)*(x3 - -3.289)^1 ]
+## [ I(x1 >= -3.224)*(x1 - -3.224)^1 ] * [ I(x2 >= -0.624)*(x2 - -0.624)^1 ] * [ I(x3 >= -3.289)*(x3 - -3.289)^1 ]
+## [ I(x1 >= 1.135)*(x1 - 1.135)^1 ] * [ I(x2 >= -3.038)*(x2 - -3.038)^1 ]
+## [ I(x2 >= -0.699)*(x2 - -0.699)^1 ] * [ I(x3 >= -3.289)*(x3 - -3.289)^1 ]
## term
Note the length and width of these tables! The R environment might not
be the optimal location to view the summary. Tip: Tables can be
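Since the tip above is cut off at the hunk boundary, one possible workflow is
sketched here. This is only an illustration, not text from the article: it
assumes the fitted object is hal_fit, as elsewhere in this vignette, and the
output file name is arbitrary.
# Extract the coefficient table from the fitted HAL object and write it out,
# so the (very wide) summary can be inspected outside the R console.
hal_summary <- summary(hal_fit)$table # a data.table of coefs and terms
data.table::fwrite(hal_summary, "hal_fit_summary.csv") # illustrative path
head(hal_summary, 10) # or just peek at the largest-magnitude coefficients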
@@ -280,12 +329,12 @@
-## [1] 0.04967736
+## [1] 0.04426571
oob_hal <- predict(object = hal_fit, new_data = test_x)
oob_hal_mse <- mse(preds = oob_hal, y = test_y)
oob_hal_mse
-## [1] 1.778584
+## [1] 1.803562
fit_hal function:
-hal_fit_reduced <- fit_hal(X = x, Y = y, reduce_basis = 0.1)
-hal_fit_reduced$times
+hal_fit_reduced <- fit_hal(X = x, Y = y, reduce_basis = 0.1)
+## Warning in fit_hal(X = x, Y = y, reduce_basis = 0.1): Dropping reduce_basis;
+## only applies if smoothness_orders = 0
+
+hal_fit_reduced$times
## user.self sys.self elapsed user.child sys.child
+## enumerate_basis 0.028 0.007 0.122 0 0
+## design_matrix 0.131 0.017 0.463 0 0
+## reduce_basis 0.000 0.000 0.000 0 0
+## remove_duplicates 0.000 0.000 0.000 0 0
+## lasso 3.667 0.733 16.867 0 0
+## total 3.826 0.757 17.453 0 0
+In the above, all basis functions with fewer than 10% of observations
+meeting the criterion imposed are automatically removed prior to the
+Lasso step of fitting the HAL regression. The results appear below.
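Conceptually, the screening rule amounts to a simple column filter. A minimal
sketch of the idea follows; this is not the package's internal code, and
x_basis stands for the sparse basis design matrix.
# Keep only basis columns in which at least 10% of observations are nonzero;
# this mirrors the reduce_basis = 0.1 criterion conceptually.
prop_nonzero <- Matrix::colMeans(x_basis != 0)
x_basis_reduced <- x_basis[, prop_nonzero >= 0.1, drop = FALSE]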
@@ -316,43 +380,44 @@ Reducing basis functions
summary(hal_fit_reduced)$table
##            coef
-##  1: -1.4371363869
-##  2:  0.3753894167
-##  3: -0.2355902545
-##  4:  0.2166287283
-##  5: -0.1851691193
-##  6:  0.1631844732
-##  7:  0.1597858903
-##  8: -0.1234431109
-##  9: -0.0946988403
-## 10: -0.0884063929
-## 11:  0.0584079820
-## 12:  0.0563691217
-## 13: -0.0551713872
-## 14:  0.0533243296
-## 15: -0.0382683652
-## 16: -0.0372833807
-## 17: -0.0309200484
-## 18:  0.0230878908
-## 19:  0.0179840202
-## 20: -0.0156428631
-## 21:  0.0149240584
-## 22: -0.0144903662
-## 23: -0.0132097499
-## 24: -0.0065506011
-## 25: -0.0062280674
-## 26:  0.0051538711
-## 27:  0.0043861580
-## 28: -0.0024858177
-## 29:  0.0023100499
-## 30: -0.0022703275
-## 31:  0.0021266202
-## 32:  0.0017535038
-## 33:  0.0016041990
-## 34: -0.0010032801
-## 35: -0.0002829866
+##  1: -1.424003e+00
+##  2:  3.654157e-01
+##  3: -2.358064e-01
+##  4:  2.182822e-01
+##  5: -1.713943e-01
+##  6:  1.648782e-01
+##  7:  1.622979e-01
+##  8: -1.205285e-01
+##  9: -9.526696e-02
+## 10: -9.382067e-02
+## 11:  5.468026e-02
+## 12:  5.273315e-02
+## 13: -5.056465e-02
+## 14:  4.527346e-02
+## 15: -3.735277e-02
+## 16: -3.543529e-02
+## 17:  2.380514e-02
+## 18: -2.255508e-02
+## 19: -2.176114e-02
+## 20: -1.612916e-02
+## 21:  1.531317e-02
+## 22: -1.233111e-02
+## 23:  1.172920e-02
+## 24:  8.162029e-03
+## 25: -7.160000e-03
+## 26: -5.500155e-03
+## 27:  5.083293e-03
+## 28: -3.625005e-03
+## 29:  1.488567e-03
+## 30:  1.350348e-03
+## 31:  3.224244e-04
+## 32: -3.021403e-04
+## 33: -2.126168e-04
+## 34: -6.625025e-05
+## 35:  1.927162e-06
##            coef
##                                                              term
##  1: (Intercept)
@@ -365,31 +430,31 @@ Reducing basis functions
##  8: [ I(x1 >= 0.941)*(x1 - 0.941)^1 ] * [ I(x3 >= -3.289)*(x3 - -3.289)^1 ]
##  9: [ I(x2 >= 1.017)*(x2 - 1.017)^1 ] * [ I(x3 >= -3.289)*(x3 - -3.289)^1 ]
## 10: [ I(x1 >= 1.368)*(x1 - 1.368)^1 ] * [ I(x3 >= -3.289)*(x3 - -3.289)^1 ]
-## 11: [ I(x1 >= -1.844)*(x1 - -1.844)^1 ]
-## 12: [ I(x2 >= -1.565)*(x2 - -1.565)^1 ] * [ I(x3 >= -3.289)*(x3 - -3.289)^1 ]
+## 11: [ I(x2 >= -1.565)*(x2 - -1.565)^1 ] * [ I(x3 >= -3.289)*(x3 - -3.289)^1 ]
+## 12: [ I(x1 >= -1.844)*(x1 - -1.844)^1 ]
## 13: [ I(x2 >= 0.696)*(x2 - 0.696)^1 ] * [ I(x3 >= -3.289)*(x3 - -3.289)^1 ]
## 14: [ I(x1 >= -3.224)*(x1 - -3.224)^1 ] * [ I(x2 >= 1.156)*(x2 - 1.156)^1 ] * [ I(x3 >= -0.497)*(x3 - -0.497)^1 ]
-## 15: [ I(x2 >= -3.038)*(x2 - -3.038)^1 ] * [ I(x3 >= -3.289)*(x3 - -3.289)^1 ]
-## 16: [ I(x1 >= -3.224)*(x1 - -3.224)^1 ] * [ I(x2 >= -3.038)*(x2 - -3.038)^1 ] * [ I(x3 >= -3.289)*(x3 - -3.289)^1 ]
-## 17: [ I(x1 >= -3.224)*(x1 - -3.224)^1 ] * [ I(x2 >= 1.595)*(x2 - 1.595)^1 ]
-## 18: [ I(x1 >= -3.224)*(x1 - -3.224)^1 ] * [ I(x2 >= -0.699)*(x2 - -0.699)^1 ]
-## 19: [ I(x1 >= -0.422)*(x1 - -0.422)^1 ] * [ I(x3 >= -3.289)*(x3 - -3.289)^1 ]
+## 15: [ I(x1 >= -3.224)*(x1 - -3.224)^1 ] * [ I(x2 >= -3.038)*(x2 - -3.038)^1 ] * [ I(x3 >= -3.289)*(x3 - -3.289)^1 ]
+## 16: [ I(x2 >= -3.038)*(x2 - -3.038)^1 ] * [ I(x3 >= -3.289)*(x3 - -3.289)^1 ]
+## 17: [ I(x1 >= -3.224)*(x1 - -3.224)^1 ] * [ I(x2 >= -0.699)*(x2 - -0.699)^1 ]
+## 18: [ I(x1 >= -3.224)*(x1 - -3.224)^1 ] * [ I(x2 >= 1.595)*(x2 - 1.595)^1 ]
+## 19: [ I(x1 >= 1.135)*(x1 - 1.135)^1 ] * [ I(x2 >= -3.038)*(x2 - -3.038)^1 ]
## 20: [ I(x1 >= 0.307)*(x1 - 0.307)^1 ] * [ I(x2 >= -3.038)*(x2 - -3.038)^1 ] * [ I(x3 >= -3.289)*(x3 - -3.289)^1 ]
-## 21: [ I(x1 >= -0.313)*(x1 - -0.313)^1 ] * [ I(x2 >= 0.118)*(x2 - 0.118)^1 ] * [ I(x3 >= -3.289)*(x3 - -3.289)^1 ]
-## 22: [ I(x1 >= 1.135)*(x1 - 1.135)^1 ] * [ I(x2 >= -3.038)*(x2 - -3.038)^1 ]
-## 23: [ I(x1 >= -0.901)*(x1 - -0.901)^1 ] * [ I(x2 >= -3.038)*(x2 - -3.038)^1 ] * [ I(x3 >= -0.047)*(x3 - -0.047)^1 ]
-## 24: [ I(x1 >= -0.901)*(x1 - -0.901)^1 ] * [ I(x2 >= -0.375)*(x2 - -0.375)^1 ] * [ I(x3 >= -3.289)*(x3 - -3.289)^1 ]
+## 21: [ I(x1 >= -0.422)*(x1 - -0.422)^1 ] * [ I(x3 >= -3.289)*(x3 - -3.289)^1 ]
+## 22: [ I(x1 >= -0.901)*(x1 - -0.901)^1 ] * [ I(x2 >= -3.038)*(x2 - -3.038)^1 ] * [ I(x3 >= -0.047)*(x3 - -0.047)^1 ]
+## 23: [ I(x1 >= -0.313)*(x1 - -0.313)^1 ] * [ I(x2 >= 0.118)*(x2 - 0.118)^1 ] * [ I(x3 >= -3.289)*(x3 - -3.289)^1 ]
+## 24: [ I(x1 >= -3.224)*(x1 - -3.224)^1 ] * [ I(x2 >= 0.489)*(x2 - 0.489)^1 ] * [ I(x3 >= -0.257)*(x3 - -0.257)^1 ]
## 25: [ I(x1 >= -3.224)*(x1 - -3.224)^1 ] * [ I(x2 >= 0.489)*(x2 - 0.489)^1 ] * [ I(x3 >= -3.289)*(x3 - -3.289)^1 ]
-## 26: [ I(x1 >= -3.224)*(x1 - -3.224)^1 ] * [ I(x2 >= -3.038)*(x2 - -3.038)^1 ] * [ I(x3 >= 0.595)*(x3 - 0.595)^1 ]
-## 27: [ I(x1 >= -3.224)*(x1 - -3.224)^1 ] * [ I(x2 >= 0.489)*(x2 - 0.489)^1 ] * [ I(x3 >= -0.257)*(x3 - -0.257)^1 ]
-## 28: [ I(x1 >= 0.739)*(x1 - 0.739)^1 ] * [ I(x3 >= -3.289)*(x3 - -3.289)^1 ]
-## 29: [ I(x2 >= -0.699)*(x2 - -0.699)^1 ] * [ I(x3 >= -3.289)*(x3 - -3.289)^1 ]
-## 30: [ I(x1 >= 0.594)*(x1 - 0.594)^1 ] * [ I(x3 >= -3.289)*(x3 - -3.289)^1 ]
-## 31: [ I(x1 >= -3.224)*(x1 - -3.224)^1 ] * [ I(x2 >= -3.038)*(x2 - -3.038)^1 ] * [ I(x3 >= 0.374)*(x3 - 0.374)^1 ]
-## 32: [ I(x1 >= -3.224)*(x1 - -3.224)^1 ] * [ I(x2 >= -0.375)*(x2 - -0.375)^1 ] * [ I(x3 >= -3.289)*(x3 - -3.289)^1 ]
-## 33: [ I(x1 >= -1.593)*(x1 - -1.593)^1 ]
-## 34: [ I(x3 >= -3.289)*(x3 - -3.289)^1 ]
-## 35: [ I(x1 >= -3.224)*(x1 - -3.224)^1 ] * [ I(x2 >= -3.038)*(x2 - -3.038)^1 ] * [ I(x3 >= 1.202)*(x3 - 1.202)^1 ]
+## 26: [ I(x3 >= -3.289)*(x3 - -3.289)^1 ]
+## 27: [ I(x1 >= -3.224)*(x1 - -3.224)^1 ] * [ I(x2 >= -3.038)*(x2 - -3.038)^1 ] * [ I(x3 >= 0.595)*(x3 - 0.595)^1 ]
+## 28: [ I(x1 >= -0.901)*(x1 - -0.901)^1 ] * [ I(x2 >= -0.375)*(x2 - -0.375)^1 ] * [ I(x3 >= -3.289)*(x3 - -3.289)^1 ]
+## 29: [ I(x1 >= -3.224)*(x1 - -3.224)^1 ] * [ I(x2 >= -3.038)*(x2 - -3.038)^1 ] * [ I(x3 >= 0.374)*(x3 - 0.374)^1 ]
+## 30: [ I(x1 >= -3.224)*(x1 - -3.224)^1 ] * [ I(x2 >= -0.375)*(x2 - -0.375)^1 ] * [ I(x3 >= -3.289)*(x3 - -3.289)^1 ]
+## 31: [ I(x2 >= -0.699)*(x2 - -0.699)^1 ] * [ I(x3 >= -3.289)*(x3 - -3.289)^1 ]
+## 32: [ I(x1 >= -0.901)*(x1 - -0.901)^1 ] * [ I(x2 >= -3.038)*(x2 - -3.038)^1 ] * [ I(x3 >= -3.289)*(x3 - -3.289)^1 ]
+## 33: [ I(x1 >= 0.739)*(x1 - 0.739)^1 ] * [ I(x3 >= -3.289)*(x3 - -3.289)^1 ]
+## 34: [ I(x1 >= 0.594)*(x1 - 0.594)^1 ] * [ I(x3 >= -3.289)*(x3 - -3.289)^1 ]
+## 35: [ I(x1 >= -0.685)*(x1 - -0.685)^1 ] * [ I(x3 >= -3.289)*(x3 - -3.289)^1 ]
##                                                              term
Other approaches exist for reducing the set of basis functions before
they are actually created, which is essential for most
@@ -428,7 +493,7 @@ Specifying smoothness of the HAL
under the constraint that the total variation of the function's
\(k^{\text{th}}\) derivative is bounded by some constant, which is
selected with cross-validation.
Let’s see this in action.
set.seed(98109)
num_knots <- 100 # Try changing this value to see what happens.
n_covars <- 1
@@ -474,16 +539,20 @@ Specifying smoothness of the HAL
for a near-optimal fit. Therefore, one can safely pass a smaller value to
num_knots for a big decrease in runtime without sacrificing performance.
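For instance, one might cap the knot budget per interaction degree. This is
only a sketch: the values are illustrative, and num_knots also accepts a
single scalar, as described in ?fit_hal.
# 50 knots for main terms, 25 for two-way interactions (illustrative values)
hal_fit_fast <- fit_hal(X = x, Y = y, num_knots = c(50, 25))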
mean((pred_0 - ytrue)^2)
-## [1] 0.00732315
-mean((pred_smooth_1- ytrue)^2)
-## [1] 0.002432486
-mean((pred_smooth_2 - ytrue)^2)
+## [1] 0.001848927
+mean((pred_smooth_1- ytrue)^2)
+## [1] 0.003458005
+mean((pred_smooth_2 - ytrue)^2)
+## [1] 0.001834611
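For reference, a self-contained sketch of how the three fits compared above
are produced; the object names follow this article, while the vignette's
exact num_knots settings are assumed to be set elsewhere.
# Fit HAL under zero-, first-, and second-order smoothness on the same data.
hal_fit_0 <- fit_hal(X = x, Y = y, smoothness_orders = 0)
hal_fit_smooth_1 <- fit_hal(X = x, Y = y, smoothness_orders = 1)
hal_fit_smooth_2 <- fit_hal(X = x, Y = y, smoothness_orders = 2)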
@@ -74,13 +74,14 @@
plot(x, pred_0, main = "Zero order smoothness fit")
plot(x, pred_smooth_1, main = "First order smoothness fit")
plot(x, pred_smooth_2, main = "Second order smoothness fit")
In general, if the basis functions are not coarse, then the
@@ -513,7 +582,11 @@ Specifying smoothness of the HAL
Comparing the following simulation and the previous one, the HAL with
second-order smoothness performed better when there were fewer knot
points.
set.seed(98109)
n_covars <- 1
n_obs <- 250
@@ -535,22 +608,22 @@ Specifying smoothness of the HAL
pred_0 <- predict(hal_fit_0, new_data = x)
pred_smooth_1 <- predict(hal_fit_smooth_1, new_data = x)
pred_smooth_2 <- predict(hal_fit_smooth_2, new_data = x)
mean((pred_0 - ytrue)^2)
-## [1] 0.00732315
-mean((pred_smooth_1- ytrue)^2)
-## [1] 0.002432486
-mean((pred_smooth_2 - ytrue)^2)
+## [1] 0.001834611
+mean((pred_smooth_1- ytrue)^2)
+## [1] 0.003458005
+mean((pred_smooth_2 - ytrue)^2)
+## [1] 0.001848927
plot(x, pred_0, main = "Zero order smoothness fit")
plot(x, pred_smooth_1, main = "First order smoothness fit")
@@ -570,7 +643,11 @@ Formula interface
plot(x, pred_smooth_2, main = "Second order smoothness fit")
set.seed(98109)
num_knots <- 100
@@ -588,7 +665,11 @@ Formula interface
# The `h` function is used to specify the basis functions for a given term
# h(x1) generates one-way basis functions for the variable x1
# This is an additive model:
@@ -607,7 +688,7 @@ Formula interface
##
## $orders
## [1] 0
# We don't need the variables in the parent environment if we specify them directly:
rm(smoothness_orders)
rm(num_knots)
@@ -616,7 +697,7 @@ Formula interface
# They are the same!
length(form_term_new$basis_list) == length(form_term$basis_list)
## [1] TRUE
# To evaluate an unevaluated formula object like:
formula <- ~ h(x1) + h(x2) + h(A)
# we can use the formula_hal function:
@@ -633,13 +714,17 @@ Formula interface
smoothness_orders <- 1
num_knots <- 5
# An additive model
colnames(X)
-## [1] "x1" "x2" "A"
# Shortcut:
formula1 <- h(.)
# Longcut:
@@ -647,7 +732,7 @@ Formula interface
# Same number of basis functions
length(formula1$basis_list) == length(formula2$basis_list)
## [1] TRUE
# Maybe we only want an additive model for x1 and x2
# Use the `.` argument
formula1 <- h(., . = c("x1", "x2"))
@@ -655,19 +740,19 @@ Formula interface
length(formula1$basis_list) == length(formula2$basis_list)
## [1] TRUE
We can specify interactions as follows.
# Two way interactions
formula1 <- h(x1) + h(x2) + h(A) + h(x1, x2)
formula2 <- h(.) + h(x1, x2)
length(formula1$basis_list) == length(formula2$basis_list)
## [1] TRUE
formula1 <- h(.) + h(x1, x2) + h(x1, A) + h(x2, A)
formula2 <- h(.) + h(., .)
length(formula1$basis_list) == length(formula2$basis_list)
## [1] TRUE
# Three way interactions
formula1 <- h(.) + h(., .) + h(x1, A, x2)
formula2 <- h(.) + h(., .) + h(., ., .)
@@ -677,7 +762,7 @@ Formula interface
# Write it all out
formula <- h(x1) + h(x2) + h(A) + h(A, x1) + h(A, x2)
@@ -706,7 +795,11 @@ Formula interface
# An additive monotone increasing model
formula <- formula_hal(
  y ~ h(., monotone = "i"), X, smoothness_orders = 0, num_knots = 100
)
@@ -729,7 +822,7 @@ Formula interface
  ~ h(., monotone = "d") + h(.,., monotone = "d"), X, smoothness_orders = 1, num_knots = 100
)
The penalization feature can be used to reproduce a glm.
# Additive glm
# One knot (at the origin) and first order smoothness
formula <- h(., s = 1, k = 1, pf = 0)
@@ -741,27 +834,31 @@ Formula interface
# get formula object
fit <- fit_hal(
  X = X, Y = Y, formula = ~ h(.), smoothness_orders = 1, num_knots = 100
)
print(summary(fit), 10) # prints top 10 rows, i.e., highest absolute coefs
-##
-## Summary of top 10 non-zero coefficients is based on lambda of 0.0005900481
+## Summary of top 10 non-zero coefficients is based on lambda of 0.0005376299
##
##       coef                                  term
-##  0.7395438                           (Intercept)
-## -0.3026891     [ I(A >= -3.978)*(A - -3.978)^1 ]
-## -0.2948108   [ I(x2 >= -3.992)*(x2 - -3.992)^1 ]
-## -0.2763138     [ I(x1 >= 1.172)*(x1 - 1.172)^1 ]
-## -0.2493828     [ I(x2 >= 1.638)*(x2 - 1.638)^1 ]
-## -0.2475319   [ I(x1 >= -3.972)*(x1 - -3.972)^1 ]
-##  0.2098793     [ I(x1 >= -1.45)*(x1 - -1.45)^1 ]
-##  0.2079505     [ I(A >= -1.356)*(A - -1.356)^1 ]
-## -0.2075157       [ I(A >= 1.384)*(A - 1.384)^1 ]
-##  0.2035278   [ I(x2 >= -1.293)*(x2 - -1.293)^1 ]
+##  0.7473715                           (Intercept)
+## -0.3081866     [ I(A >= -3.978)*(A - -3.978)^1 ]
+## -0.2894577   [ I(x2 >= -3.992)*(x2 - -3.992)^1 ]
+## -0.2811616     [ I(x1 >= 1.172)*(x1 - 1.172)^1 ]
+## -0.2463477   [ I(x1 >= -3.972)*(x1 - -3.972)^1 ]
+## -0.2405977     [ I(x2 >= 1.638)*(x2 - 1.638)^1 ]
+##  0.2092225     [ I(x1 >= -1.45)*(x1 - -1.45)^1 ]
+##  0.2078703   [ I(x2 >= -1.293)*(x2 - -1.293)^1 ]
+## -0.2039605       [ I(A >= 1.384)*(A - 1.384)^1 ]
+##  0.1993688     [ I(A >= -1.356)*(A - -1.356)^1 ]
diff --git a/docs/articles/intro_hal9001_files/figure-html/unnamed-chunk-13-1.png b/docs/articles/intro_hal9001_files/figure-html/unnamed-chunk-13-1.png
index 32553fc6..1348298c 100644
Binary files a/docs/articles/intro_hal9001_files/figure-html/unnamed-chunk-13-1.png and b/docs/articles/intro_hal9001_files/figure-html/unnamed-chunk-13-1.png differ
diff --git a/docs/articles/intro_hal9001_files/figure-html/unnamed-chunk-3-1.png b/docs/articles/intro_hal9001_files/figure-html/unnamed-chunk-3-1.png
index d12920ac..ee2008b3 100644
Binary files a/docs/articles/intro_hal9001_files/figure-html/unnamed-chunk-3-1.png and b/docs/articles/intro_hal9001_files/figure-html/unnamed-chunk-3-1.png differ
diff --git a/docs/articles/intro_hal9001_files/figure-html/unnamed-chunk-3-3.png b/docs/articles/intro_hal9001_files/figure-html/unnamed-chunk-3-3.png
index 6b1c3032..b4891964 100644
Binary files a/docs/articles/intro_hal9001_files/figure-html/unnamed-chunk-3-3.png and b/docs/articles/intro_hal9001_files/figure-html/unnamed-chunk-3-3.png differ
diff --git a/docs/articles/intro_hal9001_files/figure-html/unnamed-chunk-3-4.png b/docs/articles/intro_hal9001_files/figure-html/unnamed-chunk-3-4.png
index 957b9b53..f2a89afe 100644
Binary files a/docs/articles/intro_hal9001_files/figure-html/unnamed-chunk-3-4.png and b/docs/articles/intro_hal9001_files/figure-html/unnamed-chunk-3-4.png differ
diff --git a/docs/articles/intro_hal9001_files/figure-html/unnamed-chunk-5-2.png b/docs/articles/intro_hal9001_files/figure-html/unnamed-chunk-5-2.png
index 6b1c3032..b4891964 100644
Binary files a/docs/articles/intro_hal9001_files/figure-html/unnamed-chunk-5-2.png and b/docs/articles/intro_hal9001_files/figure-html/unnamed-chunk-5-2.png differ
diff --git a/docs/articles/intro_hal9001_files/figure-html/unnamed-chunk-5-3.png b/docs/articles/intro_hal9001_files/figure-html/unnamed-chunk-5-3.png
index f2a89afe..957b9b53 100644
Binary files a/docs/articles/intro_hal9001_files/figure-html/unnamed-chunk-5-3.png and b/docs/articles/intro_hal9001_files/figure-html/unnamed-chunk-5-3.png differ
diff --git a/docs/authors.html b/docs/authors.html
index b4c87ab8..9f3e40fb 100644
--- a/docs/authors.html
+++ b/docs/authors.html
@@ -23,7 +23,7 @@ Citation
Coyle J, Hejazi N, Phillips R, van der Laan L, van der Laan M (2022).
hal9001: The scalable highly adaptive lasso.
doi:10.5281/zenodo.3558313, R package version 0.4.5,
https://github.com/tlverse/hal9001.
@Manual{,
  title = {{hal9001}: The scalable highly adaptive lasso},
  author = {Jeremy R Coyle and Nima S Hejazi and Rachael V Phillips and Lars WP {van der Laan} and Mark J {van der Laan}},
  year = {2022},
-  note = {R package version 0.4.3},
+  note = {R package version 0.4.5},
  doi = {10.5281/zenodo.3558313},
  url = {https://github.com/tlverse/hal9001},
}
diff --git a/docs/index.html b/docs/index.html
index 4faf8401..b68b1ec1 100644
--- a/docs/index.html
+++ b/docs/index.html
@@ -49,7 +49,7 @@ Example
# load the package and set a seed
library(hal9001)
#> Loading required package: Rcpp
#> hal9001 v0.4.4: The Scalable Highly Adaptive Lasso
#> note: fit_hal defaults have changed. See ?fit_hal for details
set.seed(385971)
@@ -143,12 +147,21 @@ Example
#> [1] "I'm sorry, Dave. I'm afraid I can't do that."
hal_fit$times
#>                   user.self sys.self elapsed user.child sys.child
#> enumerate_basis       0.014    0.003   0.059          0         0
#> design_matrix         0.004    0.001   0.005          0         0
#> reduce_basis          0.000    0.000   0.000          0         0
#> remove_duplicates     0.000    0.000   0.000          0         0
#> lasso                 2.684    0.343   6.583          0         0
#> total                 2.703    0.348   6.655          0         0
# training sample prediction
preds <- predict(hal_fit, new_data = x)
diff --git a/docs/news/index.html b/docs/news/index.html
index 880a8638..ae051c98 100644
--- a/docs/news/index.html
+++ b/docs/news/index.html
@@ -23,7 +23,7 @@ Changelog
Source: NEWS.md
+hal9001 0.4.4
+
+- Fixed bug with prediction_bounds (a fit_hal argument in the fit_control
+  list), which would error when it was specified as a numeric vector.
+  Also, added a check to assert this argument is correctly specified, and
+  tests to ensure a numeric vector of bounds is provided.
+- Simplified fit_control list arguments in fit_hal. Users can still
+  specify additional arguments to cv.glmnet and glmnet in this list.
+- Defined weights as a formal argument in fit_hal, opposed to an optional
+  argument in fit_control, to facilitate specification and avoid
+  confusion. This increases flexibility with SuperLearner wrapper
+  SL.hal9001 as well; fit_control can now be customized with SL.hal9001.
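A sketch of how this customization might look in practice; the custom
wrapper name and the fit_control values are illustrative, not from the
package, while the SuperLearner call pattern is the standard one.
library(SuperLearner)
# Custom wrapper: extra arguments flow through SL.hal9001's `...` to fit_hal,
# so fit_control can be tailored (use_min = FALSE selects lambda.1se).
SL.hal9001.custom <- function(...) {
  SL.hal9001(..., fit_control = list(use_min = FALSE))
}
sl_fit <- SuperLearner(Y = y, X = data.frame(x),
                       SL.library = "SL.hal9001.custom")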
@@ -68,15 +68,13 @@ hal9001 0.4.2 2022-01-26
- Version bump for CRAN resubmission following archiving.
diff --git a/docs/pkgdown.yml b/docs/pkgdown.yml
index ce08e0d7..57535cfb 100644
--- a/docs/pkgdown.yml
+++ b/docs/pkgdown.yml
@@ -3,7 +3,11 @@
 pkgdown: 2.0.3
 pkgdown_sha: ~
 articles:
   intro_hal9001: intro_hal9001.html
-last_built: 2022-11-04T20:06Z
+last_built: 2022-07-13T01:48Z
 urls:
   reference: https://tlverse.org/hal9001/reference
   article: https://tlverse.org/hal9001/articles
diff --git a/docs/reference/SL.hal9001.html b/docs/reference/SL.hal9001.html
index d10342df..9b649b08 100644
--- a/docs/reference/SL.hal9001.html
+++ b/docs/reference/SL.hal9001.html
@@ -23,7 +23,7 @@ Wrapper for Classic SuperLearner
SL.hal9001(
  Y,
  X,
-  newX = NULL,
-  family = stats::gaussian(),
-  obsWeights = rep(1, length(Y)),
-  id = NULL,
-  max_degree = ifelse(ncol(X) >= 20, 2, 3),
+  newX,
+  family,
+  obsWeights,
+  id,
+  max_degree = 2,
  smoothness_orders = 1,
-  num_knots = ifelse(smoothness_orders >= 1, 25, 50),
-  reduce_basis = 1/sqrt(length(Y)),
-  lambda = NULL,
+  num_knots = 5,
  ...
)
Arguments
specifying the maximum number of knot points (i.e., bins) for each
covariate for generating basis functions. See the num_knots argument in
fit_hal for more information.
-reduce_basis
-A numeric value bounded in the open unit interval indicating the minimum
-proportion of 1's in a basis function column needed for the basis function
-to be included in the procedure to fit the lasso. Any basis functions with
-a lower proportion of 1's than the cutoff will be removed.
-lambda
-A user-specified sequence of values of the regularization parameter for
-the lasso L1 regression. If NULL, the default sequence in cv.glmnet will
-be used. The cross-validated optimal value of this regularization
-parameter will be selected with cv.glmnet.
...
-Not used.
+Additional arguments to fit_hal.
diff --git a/docs/reference/fit_hal.html b/docs/reference/fit_hal.html
index e66ae317..8aac31ee 100644
--- a/docs/reference/fit_hal.html
+++ b/docs/reference/fit_hal.html
@@ -23,7 +23,7 @@
diff --git a/docs/reference/apply_copy_map.html b/docs/reference/apply_copy_map.html index 609faab4..c7bd4772 100644 --- a/docs/reference/apply_copy_map.html +++ b/docs/reference/apply_copy_map.html @@ -23,7 +23,7 @@Arguments
HAL: The Highly Adaptive Lasso
  smoothness_orders = 1,
  num_knots = num_knots_generator(max_degree = max_degree,
    smoothness_orders = smoothness_orders, base_num_knots_0 = 200,
    base_num_knots_1 = 50),
-  reduce_basis = 1/sqrt(length(Y)),
-  family = c("gaussian", "binomial", "poisson", "cox"),
+  reduce_basis = NULL,
+  family = c("gaussian", "binomial", "poisson", "cox", "mgaussian"),
  lambda = NULL,
  id = NULL,
+  weights = NULL,
  offset = NULL,
-  fit_control = list(cv_select = TRUE, n_folds = 10, foldid = NULL,
-    use_min = TRUE, lambda.min.ratio = 1e-04, prediction_bounds = "default"),
+  fit_control = list(cv_select = TRUE, use_min = TRUE,
+    lambda.min.ratio = 1e-04, prediction_bounds = "default"),
  basis_list = NULL,
  return_lasso = TRUE,
  return_x_basis = FALSE,
@@ -95,7 +96,9 @@ Arguments
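The new family = "mgaussian" option and the weights formal argument shown in
this signature can be exercised as follows. This is a sketch only: the data
objects are simulated here purely for illustration.
set.seed(1)
x_mat <- matrix(rnorm(300), nrow = 100)
y_mat <- cbind(y1 = rnorm(100), y2 = rnorm(100)) # multivariate outcome
fit_mv <- fit_hal(X = x_mat, Y = y_mat, family = "mgaussian",
                  weights = rep(1, 100))
preds_mv <- predict(fit_mv, new_data = x_mat) # one column per outcome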
number of covariates that will be used to derive the design matrix of basis functions.Y -+ A
numeric
vector of observations of the outcome variable.A
numeric
vector of observations of the outcome variable. For +family="mgaussian"
,Y
is a matrix of observations of the +outcome variables.formula @@ -123,19 +126,21 @@ A character string formula to be used in
formula_hal
. See its documentation for details.Arguments
basis functions generated. This allows the complexity of the optimization problem to grow scalably. See details ofnum_knots
more information.reduce_basis -+ A
numeric
value bounded in the open unit interval -indicating the minimum proportion of 1's in a basis function column needed -for the basis function to be included in the procedure to fit the lasso. -Any basis functions with a lower proportion of 1's than the cutoff will be -removed. Whenreduce_basis
is set toNULL
, all basis -functions are used in the lasso-fitting stage offit_hal
.Am optional
numeric
value bounded in the open +unit interval indicating the minimum proportion of 1's in a basis function +column needed for the basis function to be included in the procedure to fit +the lasso. Any basis functions with a lower proportion of 1's than the +cutoff will be removed. Defaults to 1 over the square root of the number of +observations. Only applicable for models fit with zero-order splines, i.e. +smoothness_orders = 0
.family A
character
or afamily
object (supported byglmnet
) specifying the error/link family for a generalized linear model.character
options are limited to "gaussian" for fitting a standard penalized linear model, "binomial" for penalized logistic regression, "poisson" for penalized Poisson regression, -and "cox" for a penalized proportional hazards model. Note that passing in +"cox" for a penalized proportional hazards model, and "mgaussian" for +multivariate penalized linear model. Note that passing in family objects leads to slower performance relative to passing in a character family (if supported). For example, one should setfamily = "binomial"
instead offamily = binomial()
when @@ -153,39 +158,36 @@Arguments
+ A vector of ID values that is used to generate cross-validation folds for
cv.glmnet
. This argument is ignored whenfit_control
'scv_select
argument isFALSE
.weights +observation weights; defaults to 1 per observation.
offset a vector of offset values, used in fitting.
fit_control -List of arguments for fitting. Includes the following -arguments, and any others to be passed to
cv.glmnet
-orglmnet
.
cv_select
: Alogical
specifying if the sequence of +List of arguments, including the following, and any +others to be passed to
cv.glmnet
or +glmnet
.
- -
cv_select
: Alogical
specifying if the sequence of specifiedlambda
values should be passed tocv.glmnet
in order for a single, optimal value oflambda
to be selected according to cross-validation. Whencv_select = FALSE
, aglmnet
model will be used to fit the sequence of (or single)lambda
.- -
n_folds
: Integer for the number of folds to be used when splitting -the data for V-fold cross-validation. Only used when -cv_select = TRUE
.
foldid
: An optionalnumeric
containing values between 1 and -n_folds
, identifying the fold to which each observation is -assigned. If supplied,n_folds
can be missing. In such a case, -this vector is passed directly tocv.glmnet
. Only -used whencv_select = TRUE
.- -
use_min
: Specify the choice of lambda to be selected bycv.glmnet
. WhenTRUE
,"lambda.min"
is used; otherwise,"lambda.1se"
. Only used whencv_select = TRUE
.- -
lambda.min.ratio
: Aglmnet
argument specifying -the smallest value forlambda
, as a fraction oflambda.max
, -the (data derived) entry value (i.e. the smallest value for which all -coefficients are zero). We've seen that not settinglambda.min.ratio
-can lead to nolambda
values that fit the data sufficiently well.- +
prediction_bounds
: A vector of size two that provides the lower and -upper bounds for predictions. Whenprediction_bounds = "default"
, -the predictions are bounded betweenmin(Y) - sd(Y)
and -max(Y) + sd(Y)
. Bounding ensures that there is no extrapolation, -and it is necessary for cross-validation selection and/or Super Learning.- +
lambda.min.ratio
: Aglmnet
argument +specifying the smallest value forlambda
, as a fraction of +lambda.max
, the (data derived) entry value (i.e. the smallest value +for which all coefficients are zero). We've seen that not setting +lambda.min.ratio
can lead to nolambda
values that fit the +data sufficiently well.
+prediction_bounds: An optional vector of size two that provides the
+lower and upper bounds for predictions; not used when family = "cox".
+When prediction_bounds = "default", the predictions are bounded between
+min(Y) - sd(Y) and max(Y) + sd(Y) for each outcome (when
+family = "mgaussian", each outcome can have different bounds). Bounding
+ensures that there is no extrapolation.
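For example, explicit bounds may now be supplied as a numeric vector of
length two, per the bug fix noted in the changelog. A sketch; x_mat, y_vec,
and the bound values are illustrative stand-ins.
# Bound all predictions to [-4, 4] via the fit_control list.
fit_bounded <- fit_hal(
  X = x_mat, Y = y_vec,
  fit_control = list(prediction_bounds = c(-4, 4))
)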
- diff --git a/docs/reference/formula_hal.html b/docs/reference/formula_hal.html index e8230126..72f5de0d 100644 --- a/docs/reference/formula_hal.html +++ b/docs/reference/formula_hal.html @@ -23,7 +23,7 @@
The full set of basis functions generated from
X
.
Either "response" for predictions of the response, or "link" for un-transformed predictions (on the scale of the link function).
-Sparse matrix pre-allocation proportion, which is the anticipated
-proportion of 1's in the design matrix. Default value is recommended in
-most settings. If a dense design matrix is expected, it would be useful
-to set p_reserve to a higher value.
Additional arguments passed to predict as necessary.
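As an illustration of the documented type argument (a sketch; fit and new_x
are assumed to exist from an earlier fit_hal call):
preds_link <- predict(fit, new_data = new_x, type = "link") # link scale
preds_resp <- predict(fit, new_data = new_x, type = "response")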