Skip to content

Commit

Permalink
style
Browse files Browse the repository at this point in the history
  • Loading branch information
jeremyrcoyle committed Nov 1, 2023
1 parent 47da2b5 commit 08bfb78
Show file tree
Hide file tree
Showing 12 changed files with 97 additions and 200 deletions.
2 changes: 0 additions & 2 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ importFrom(data.table,setorder)
importFrom(glmnet,cv.glmnet)
importFrom(glmnet,glmnet)
importFrom(methods,is)
importFrom(origami,cross_validate)
importFrom(origami,folds2foldvec)
importFrom(origami,make_folds)
importFrom(stats,aggregate)
Expand All @@ -36,7 +35,6 @@ importFrom(stats,median)
importFrom(stats,plogis)
importFrom(stats,predict)
importFrom(stats,quantile)
importFrom(stats,sd)
importFrom(stringr,str_detect)
importFrom(stringr,str_extract)
importFrom(stringr,str_match)
Expand Down
72 changes: 0 additions & 72 deletions R/cv_lasso.R

This file was deleted.

5 changes: 2 additions & 3 deletions R/formula_hal9001.R
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,7 @@ print.formula_hal9001 <- function(x, ...) {

#'
#' @param var_names A \code{character} vector of variable names representing a single type of interaction
#" (e.g. var_names = c("W1", "W2", "W3") encodes three way interactions between W1, W2 and W3.
#' (e.g. var_names = c("W1", "W2", "W3") encodes three-way interactions between W1, W2 and W3).
#' var_names may include the wildcard variable "." in which case the argument `.` must be specified
#' so that all interactions matching the form of var_names are generated.
#' @param . Specification of variables for use in the formula.
Expand All @@ -276,7 +276,7 @@ fill_dots <- function(var_names, .) {
return(out)
})
is_nested <- is.list(all_items[[1]])
while(is_nested) {
while (is_nested) {
all_items <- unlist(all_items, recursive = FALSE)
is_nested <- is.list(all_items[[1]])
}
Expand All @@ -292,4 +292,3 @@ fill_dots <- function(var_names, .) {

return(unique(all_items))
}

14 changes: 7 additions & 7 deletions R/hal.R
Original file line number Diff line number Diff line change
Expand Up @@ -264,8 +264,8 @@ fit_hal <- function(X,
)
}

if(!is.character(fit_control$prediction_bounds)){
if(fam == "mgaussian"){
if (!is.character(fit_control$prediction_bounds)) {
if (fam == "mgaussian") {
assertthat::assert_that(
is.list(fit_control$prediction_bounds) &
length(fit_control$prediction_bounds) == ncol(Y),
Expand Down Expand Up @@ -342,14 +342,14 @@ fit_hal <- function(X,

# NOTE: keep only basis functions with some (or higher) proportion of 1's
if (all(smoothness_orders == 0)) {
if(is.null(reduce_basis)){
if (is.null(reduce_basis)) {
reduce_basis <- 1 / sqrt(n_Y)
}
reduced_basis_map <- make_reduced_basis_map(x_basis, reduce_basis)
x_basis <- x_basis[, reduced_basis_map]
basis_list <- basis_list[reduced_basis_map]
} else {
if(!is.null(reduce_basis)){
if (!is.null(reduce_basis)) {
warning("Dropping reduce_basis; only applies if smoothness_orders = 0")
}
}
Expand Down Expand Up @@ -467,9 +467,9 @@ fit_hal <- function(X,
# Bounds for prediction on new data (to prevent extrapolation for linear HAL)
if (is.character(fit_control$prediction_bounds) &&
fit_control$prediction_bounds == "default") {
if(fam == "mgaussian") {
fit_control$prediction_bounds <- lapply(seq(ncol(Y)), function(i){
c(min(Y[,i]) - 2 * stats::sd(Y[,i]), max(Y[,i]) + 2 * stats::sd(Y[,i]))
if (fam == "mgaussian") {
fit_control$prediction_bounds <- lapply(seq(ncol(Y)), function(i) {
c(min(Y[, i]) - 2 * stats::sd(Y[, i]), max(Y[, i]) + 2 * stats::sd(Y[, i]))
})
} else if (fam == "cox") {
fit_control$prediction_bounds <- NULL
Expand Down
32 changes: 17 additions & 15 deletions R/predict.R
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@ predict.hal9001 <- function(object,
offset = NULL,
type = c("response", "link"),
...) {

family <- ifelse(inherits(object$family, "family"), object$family$family, object$family)

type <- match.arg(type)
Expand Down Expand Up @@ -80,25 +79,28 @@ predict.hal9001 <- function(object,
# generate predictions
if (!family %in% c("cox", "mgaussian")) {
if (ncol(object$coefs) > 1) {
preds <- pred_x_basis%*%object$coefs[-1,]+
matrix(object$coefs[1,], nrow=nrow(pred_x_basis),
ncol=ncol(object$coefs), byrow = TRUE)
preds <- pred_x_basis %*% object$coefs[-1, ] +
matrix(object$coefs[1, ],
nrow = nrow(pred_x_basis),
ncol = ncol(object$coefs), byrow = TRUE
)
} else {
preds <- as.vector(Matrix::tcrossprod(
x = pred_x_basis,
y = matrix(object$coefs[-1],nrow=1)
y = matrix(object$coefs[-1], nrow = 1)
) + object$coefs[1])
}
} else {
if(family == "cox") {
if (family == "cox") {
# Note: there is no intercept in the Cox model (built into the baseline
# hazard and would cancel in the partial likelihood).
# Note: there is no intercept in the Cox model (built into the baseline
# hazard and would cancel in the partial likelihood).
preds <- pred_x_basis%*%object$coefs
preds <- pred_x_basis %*% object$coefs
} else if (family == "mgaussian") {
preds <- stats::predict(
object$lasso_fit, newx = pred_x_basis, s = object$lambda_star
object$lasso_fit,
newx = pred_x_basis, s = object$lambda_star
)
}
}
Expand All @@ -123,13 +125,13 @@ predict.hal9001 <- function(object,
transform <- stats::plogis
} else if (family %in% c("poisson", "cox")) {
transform <- exp
} else if(family%in%c("gaussian","mgaussian")){
} else if (family %in% c("gaussian", "mgaussian")) {
transform <- identity
} else{
} else {
stop("unsupported family")
}
if(length(ncol(preds))){

if (length(ncol(preds))) {
# apply along only the first dimension (to handle n-d arrays)
margin <- seq(length(dim(preds)))[-1]
preds <- apply(preds, margin, transform)
Expand All @@ -141,10 +143,10 @@ predict.hal9001 <- function(object,
# bound predictions within observed outcome bounds if on response scale
if (!is.null(object$prediction_bounds)) {
bounds <- object$prediction_bounds
if(family == "mgaussian") {
preds <- do.call(cbind, lapply(seq(ncol(preds)), function(i){
if (family == "mgaussian") {
preds <- do.call(cbind, lapply(seq(ncol(preds)), function(i) {
bounds_y <- sort(bounds[[i]])
preds_y <- preds[,i,]
preds_y <- preds[, i, ]
preds_y <- pmax(bounds_y[1], preds_y)
preds_y <- pmin(preds_y, bounds_y[2])
return(preds_y)
Expand Down
1 change: 0 additions & 1 deletion R/sl_hal9001.R
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@ SL.hal9001 <- function(Y,
smoothness_orders = 1,
num_knots = 5,
...) {

# create matrix version of X and newX for use with hal9001::fit_hal
if (!is.matrix(X)) X <- as.matrix(X)
if (!is.null(newX) & !is.matrix(newX)) newX <- as.matrix(newX)
Expand Down
2 changes: 1 addition & 1 deletion R/summary.R
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ summary.hal9001 <- function(object,
} else {
# add col indicating whether or not there is a duplicate
coef_summ[, dup := (duplicated(dups_tbl) |
duplicated(dups_tbl, fromLast = TRUE))]
duplicated(dups_tbl, fromLast = TRUE))]

# if basis_list_idx contains redundant duplicates, remove them
redundant_dups <- coef_summ[dup == TRUE, "basis_list_idx"]
Expand Down
29 changes: 0 additions & 29 deletions man/cv_lasso.Rd

This file was deleted.

1 change: 0 additions & 1 deletion tests/testthat/test-formula.R
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@

context("check formula function")


Expand Down
20 changes: 12 additions & 8 deletions tests/testthat/test-hal_multivariate.R
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,12 @@ test_that("HAL and glmnet predictions match for multivariate outcome", {
hal_pred <- predict(hal_fit, new_data = MultiGaussianExample$x)
# get glmnet preds
set.seed(74296)
glmnet_fit <- cv.glmnet(x = hal_fit$x_basis, y = MultiGaussianExample$y,
family = "mgaussian", standardize = FALSE,
lambda.min.ratio = 1e-4)
glmnet_pred <- predict(glmnet_fit, hal_fit$x_basis, s = hal_fit$lambda_star)[,,1]
glmnet_fit <- cv.glmnet(
x = hal_fit$x_basis, y = MultiGaussianExample$y,
family = "mgaussian", standardize = FALSE,
lambda.min.ratio = 1e-4
)
glmnet_pred <- predict(glmnet_fit, hal_fit$x_basis, s = hal_fit$lambda_star)[, , 1]
# test equivalence
colnames(glmnet_pred) <- colnames(hal_pred)
expect_equivalent(glmnet_pred, hal_pred)
Expand All @@ -31,15 +33,17 @@ test_that("HAL summarizes coefs for each multivariate outcome prediction", {

test_that("HAL summarizes coefs appropriately for multivariate outcome", {
# this checks intercept matches
lapply(seq_along(hal_summary), function(i){
expect_equal(hal_fit$coefs[[i]][1,], as.numeric(hal_summary[[i]]$table[1,1]))
lapply(seq_along(hal_summary), function(i) {
expect_equal(hal_fit$coefs[[i]][1, ], as.numeric(hal_summary[[i]]$table[1, 1]))
})
})

test_that("Error when prediction_bounds is incorrectly formatted", {
fit_control <- list(prediction_bounds = 9)
expect_error(fit_hal(X = MultiGaussianExample$x, Y = MultiGaussianExample$y,
family = "mgaussian", fit_control = fit_control))
expect_error(fit_hal(
X = MultiGaussianExample$x, Y = MultiGaussianExample$y,
family = "mgaussian", fit_control = fit_control
))
})

test_that("HAL summary for multivariate outcome predictions prints", {
Expand Down
5 changes: 2 additions & 3 deletions tests/testthat/test-reduce_basis_filter.R
Original file line number Diff line number Diff line change
Expand Up @@ -50,15 +50,15 @@ system.time({
mse <- mean(se)
se[c(current_i, new_i)] <- 0
new_i <- which.max(se)
#print(sprintf("%f, %f", old_mse, mse))
# print(sprintf("%f, %f", old_mse, mse))
continue <- mse < 1.1 * old_mse
if (mse < old_mse) {
good_i <- unique(c(good_i, new_i))
offset <- preds
old_mse <- mse
coefs <- as.vector(coef(screen_glmnet, s = "lambda.min"))[-1]
# old_basis <- union(old_basis,c(old_basis,b1)[which(coefs!=0)])
#print(length(old_basis))
# print(length(old_basis))
old_basis <- unique(c(old_basis, b1))
}

Expand Down Expand Up @@ -164,4 +164,3 @@ test_that("Predictions are not too different when reducing basis functions", {
# ensure hal fit with reduce_basis works with new data for prediction
newx <- matrix(rnorm(n * p), n, p)
hal_pred_reduced_newx <- predict(hal_fit_reduced, new_data = newx)

Loading

0 comments on commit 08bfb78

Please sign in to comment.