Skip to content

Commit

Permalink
Merge pull request #110 from tlverse/fix-matrix
Browse files Browse the repository at this point in the history
Fix matrix
  • Loading branch information
jeremyrcoyle authored Nov 8, 2023
2 parents 81093a5 + 08bfb78 commit 35009e6
Show file tree
Hide file tree
Showing 80 changed files with 1,204 additions and 857 deletions.
4 changes: 2 additions & 2 deletions DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: hal9001
Title: The Scalable Highly Adaptive Lasso
Version: 0.4.3
Version: 0.4.6
Authors@R: c(
person("Jeremy", "Coyle", email = "[email protected]",
role = c("aut", "cre"),
Expand Down Expand Up @@ -68,5 +68,5 @@ LinkingTo:
Rcpp,
RcppEigen
VignetteBuilder: knitr
RoxygenNote: 7.2.0
RoxygenNote: 7.2.3
Roxygen: list(markdown = TRUE)
3 changes: 0 additions & 3 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -26,18 +26,15 @@ importFrom(data.table,setorder)
importFrom(glmnet,cv.glmnet)
importFrom(glmnet,glmnet)
importFrom(methods,is)
importFrom(origami,cross_validate)
importFrom(origami,folds2foldvec)
importFrom(origami,make_folds)
importFrom(stats,aggregate)
importFrom(stats,as.formula)
importFrom(stats,coef)
importFrom(stats,gaussian)
importFrom(stats,median)
importFrom(stats,plogis)
importFrom(stats,predict)
importFrom(stats,quantile)
importFrom(stats,sd)
importFrom(stringr,str_detect)
importFrom(stringr,str_extract)
importFrom(stringr,str_match)
Expand Down
21 changes: 21 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,24 @@
# hal9001 0.4.6
* Fixed predict method to address changes required by Matrix 1.6.2

# hal9001 0.4.5
* Added multivariate outcome prediction

# hal9001 0.4.4
* Fixed bug with `prediction_bounds` (a `fit_hal` argument in `fit_control`
list), which would error when it was specified as a numeric vector. Also,
added a check to assert this argument is correctly specified, and tests
to ensure a numeric vector of bounds is provided.
* Simplified `fit_control` list arguments in `fit_hal`. Users can still specify
additional arguments to `cv.glmnet` and `glmnet` in this list.
* Defined `weights` as a formal argument in `fit_hal`, opposed to an optional
argument in `fit_control`, to facilitate specification and avoid confusion.
This increases flexibility with SuperLearner wrapper `SL.hal9001` as well;
`fit_control` can now be customized with `SL.hal9001`.

# hal9001 0.4.3
* Version bump for CRAN resubmission following archiving.

# hal9001 0.4.2
* Version bump for CRAN resubmission following archiving.

Expand Down
72 changes: 0 additions & 72 deletions R/cv_lasso.R

This file was deleted.

52 changes: 26 additions & 26 deletions R/formula_hal9001.R
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ h <- function(..., k = NULL, s = NULL, pf = 1,
var_names <- unlist(list(...))
}
formula_term <- paste0("h(", paste0(var_names, collapse = ", "), ")")

if (is.null(k)) {
k <- get("num_knots", envir = parent.frame())
k <- suppressWarnings(k + rep(0, length(var_names))) # recycle
Expand All @@ -152,8 +152,8 @@ h <- function(..., k = NULL, s = NULL, pf = 1,
if (is.null(s)) {
s <- get("smoothness_orders", envir = parent.frame())[1]
}


if ("." %in% var_names) {
var_names_filled <- fill_dots(var_names, . = .)

Expand Down Expand Up @@ -186,7 +186,7 @@ h <- function(..., k = NULL, s = NULL, pf = 1,
return(all_items)
}




# Get corresponding column indices
Expand Down Expand Up @@ -240,6 +240,7 @@ h <- function(..., k = NULL, s = NULL, pf = 1,
return(out)
}


#' Print formula_hal9001 object
#'
#' @param x A formula_hal9001 object.
Expand All @@ -250,16 +251,18 @@ print.formula_hal9001 <- function(x, ...) {
cat(paste0("A hal9001 formula object of the form: ~ ", x$formula_term))
}

#' Formula Helpers
#'
#' @param var_names A \code{character} vector of variable names.
#' @param var_names A \code{character} vector of variable names representing a single type of interaction
# " (e.g. var_names = c("W1", "W2", "W3") encodes three way interactions between W1, W2 and W3.
#' var_names may include the wildcard variable "." in which case the argument `.` must be specified
#' so that all interactions matching the form of var_names are generated.
#' @param . Specification of variables for use in the formula.
#'
#' @name formula_helpers
NULL

#' This function takes a character vector `var_names` of the form c(name1, name2, ".", name3, ".")
#' with any number of name{int} variables and any number of wild card variables ".".
#' It returns a list of character vectors of the form c(name1, name2, wildcard1, name3, wildcard2)
#' where wildcard1 and wildcard2 are iterated over all possible character names given in the argument `.`.
#' @rdname formula_helpers
fill_dots_helper <- function(var_names, .) {
fill_dots <- function(var_names, .) {
index <- which(var_names == ".")
if (length(index) == 0) {
return(sort(var_names))
Expand All @@ -269,26 +272,23 @@ fill_dots_helper <- function(var_names, .) {
all_items <- lapply(., function(var) {
new_var_names <- var_names
new_var_names[index] <- var
out <- fill_dots_helper(new_var_names, .)

if (is.list(out[[1]])) {
out <- unlist(out, recursive = FALSE)
}
out <- fill_dots(new_var_names, .)
return(out)
})


return(unique(all_items))
}

#' @rdname formula_helpers
fill_dots <- function(var_names, .) {
x <- unique(unlist(fill_dots_helper(var_names, . = .), recursive = FALSE))
keep <- sapply(x, function(item) {
is_nested <- is.list(all_items[[1]])
while (is_nested) {
all_items <- unlist(all_items, recursive = FALSE)
is_nested <- is.list(all_items[[1]])
}
# Remove combinations of variable names that have duplicates.
# This removes generated interactions that include two of the same variable.
keep <- sapply(all_items, function(item) {
if (any(duplicated(item))) {
return(FALSE)
}
return(TRUE)
})
return(x[keep])
all_items <- all_items[keep]

return(unique(all_items))
}
Loading

0 comments on commit 35009e6

Please sign in to comment.