Skip to content

Commit

Permalink
Merge pull request #268 from tlverse/loss-varimp
Browse files Browse the repository at this point in the history
Loss function arguments for varimp
  • Loading branch information
nhejazi authored Feb 28, 2020
2 parents 94d43f2 + 74c668e commit 63d0800
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions R/varimp.R
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ varimp <- function(fit, loss, fold_number = "validation") {
Y <- task$Y

preds <- fit$predict_fold(task, fold_number = fold_number)
risk <- mean(loss(Y, preds))
risk <- mean(loss(preds, Y))


X <- task$nodes$covariates
Expand All @@ -45,7 +45,7 @@ varimp <- function(fit, loss, fold_number = "validation") {
scrambled_sl_preds <- fit$predict_fold(scrambled_col_task, fold_number)

# risk on scrambled col task
risk_scrambled <- mean(loss(Y, scrambled_sl_preds))
risk_scrambled <- mean(loss(scrambled_sl_preds, Y))

# calculate risk difference
rd <- risk_scrambled - risk
Expand Down

0 comments on commit 63d0800

Please sign in to comment.