diff --git a/R/varimp.R b/R/varimp.R index 1daae835..86d76c8b 100644 --- a/R/varimp.R +++ b/R/varimp.R @@ -23,7 +23,7 @@ varimp <- function(fit, loss, fold_number = "validation") { Y <- task$Y preds <- fit$predict_fold(task, fold_number = fold_number) - risk <- mean(loss(Y, preds)) + risk <- mean(loss(preds, Y)) X <- task$nodes$covariates @@ -45,7 +45,7 @@ varimp <- function(fit, loss, fold_number = "validation") { scrambled_sl_preds <- fit$predict_fold(scrambled_col_task, fold_number) # risk on scrambled col task - risk_scrambled <- mean(loss(Y, scrambled_sl_preds)) + risk_scrambled <- mean(loss(scrambled_sl_preds, Y)) # calculate risk difference rd <- risk_scrambled - risk