Shap loss values #347

Open
rawanmahdi opened this issue Jun 22, 2023 · 4 comments

Comments

@rawanmahdi
Contributor

rawanmahdi commented Jun 22, 2023

For debugging black-box models, it would be nice to get Shapley feature importance values that relate to the model's loss rather than to its prediction. I've seen this implemented by the original creators of SHAP, using TreeExplainer, under the assumption that features are independent. This Medium article goes into more depth about the implementation.

I'm wondering: would it be possible to obtain model-agnostic SHAP loss values for dependent features, similar to how shapr handles dependence for the predictions?

@martinju
Member

I agree that this would be of interest. I think SAGE is the method you are looking for here. Take a look at this nice presentation: https://iancovert.com/blog/understanding-shap-sage/
As far as I know, the SAGE implementation ignores feature dependence, so it would be nice to implement it using a proper conditioning scheme like the ones in shapr. I certainly think it is doable, but it is not currently on our TODO list.
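For reference, the SAGE game described in the linked presentation replaces the model prediction with the expected loss. In the notation below, p is the number of features, N = {1, ..., p}, ℓ is the loss, and f_∅ is the prediction with no features (the role played by phi0 in shapr). Estimating the conditional expectation f_S is exactly the step where a dependence-aware approach, like the ones in shapr, would come in.

```latex
f_S(x_S) = \mathbb{E}\left[ f(X) \mid X_S = x_S \right]

v(S) = \mathbb{E}\left[ \ell\big(f_\emptyset, Y\big) \right] - \mathbb{E}\left[ \ell\big(f_S(X_S), Y\big) \right]

\phi_j = \sum_{S \subseteq N \setminus \{j\}} \frac{|S|!\,(p - |S| - 1)!}{p!} \Big( v(S \cup \{j\}) - v(S) \Big)

\sum_{j=1}^{p} \phi_j = v(N) - v(\emptyset) = \mathbb{E}\left[ \ell\big(f_\emptyset, Y\big) \right] - \mathbb{E}\left[ \ell\big(f(X), Y\big) \right]
```

Because Shapley values are linear in v, averaging per-observation loss explanations with v_i(S) = ℓ(f_∅, y_i) − ℓ(f_S(x_{i,S}), y_i) over the data gives the same values as this averaged (SAGE) game.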

@rawanmahdi
Contributor Author

Interesting! From what I understand, SAGE should be relatively easy to implement with the current code, mainly by altering the compute_vS functions to compute a loss instead of a prediction. I may be free to work on this in a few weeks. Do you have any preferences on how it should be organized in this repo?
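The idea of swapping the prediction-based v(S) for a loss can be sketched as a post-processing step on the coalition predictions. This is a minimal base-R sketch, assuming a matrix with one row per coalition (row 1 being the empty coalition) and one column per observation; loss_vS() is a hypothetical helper, not part of shapr. Since the loss is an ordinary R function argument, any elementwise loss (squared error, absolute error, ...) can be plugged in.

```r
# Hypothetical helper (not part of shapr): turn conditional predictions into a
# loss-based value function for SAGE-style explanations.
#   vS_pred:   matrix, one row per coalition S, one column per observation;
#              row k holds the predictions f_S(x_S) for coalition k
#   y:         observed responses, one per column of vS_pred
#   loss:      function(pred, y) returning elementwise losses
#   empty_row: row index of the empty coalition (its mean loss is the reference)
loss_vS <- function(vS_pred, y, loss, empty_row = 1) {
  stopifnot(ncol(vS_pred) == length(y))
  # Mean loss over observations, computed separately for each coalition
  mean_loss <- apply(vS_pred, 1, function(pred) mean(loss(pred, y)))
  # v(S): reduction in mean loss relative to the empty coalition
  mean_loss[empty_row] - mean_loss
}

squared_error  <- function(pred, y) (pred - y)^2
absolute_error <- function(pred, y) abs(pred - y)

# Toy example: 2 observations, coalitions ordered (empty, {1}, {2}, {1, 2})
y <- c(1, 3)
vS_pred <- rbind(
  c(2.0, 2.0),  # empty coalition: constant prediction
  c(1.5, 2.5),  # {1}
  c(2.0, 2.5),  # {2}
  c(1.0, 3.0)   # {1, 2}: the full model
)

loss_vS(vS_pred, y, squared_error)   # c(0, 0.75, 0.375, 1)
loss_vS(vS_pred, y, absolute_error)  # c(0, 0.5, 0.25, 1)
```

The resulting vector can then be fed to the usual Shapley weighting step; only v(S) changes, not the weighting.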

@wbound90

Hi @martinju and @rawanmahdi,

I find this interesting and would like to know whether there are any updates. I would like to see the contribution of each feature to the loss of an MLPRegressor model. I am quite new to this field, so I would appreciate any pointers.

My current understanding is that only models with a single output variable are supported, i.e., y = f(x1, x2, x3, ..., xn). In my case, the model has more than one output, e.g., (y1, y2) = f(x1, x2, x3, ..., xn), so I am trying to avoid running into the issue discussed in #323. Please correct me if I am wrong.

Is it possible to pass a user-defined Python function f that returns a loss directly into explain(), to obtain feature contributions that explain the model's loss?

@martinju
Member

martinju commented Nov 7, 2024

Sorry for the very late reply; I was just reminded of this.
It is not supported directly in shapr, but it is straightforward to compute from shapr's output. I am considering adding it properly to the package. The basic work-around script is below (it requires the GitHub version of the package).

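```r
library(xgboost)
library(shapr) # remotes::install_github("NorskRegnesentral/shapr")

data("airquality")
data <- data.table::as.data.table(airquality)
data <- data[complete.cases(data), ]

x_var <- c("Solar.R", "Wind", "Temp", "Month")
y_var <- "Ozone"

x_train <- data[, ..x_var]
y_train <- data[, get(y_var)]

# Fitting a basic xgboost model to the training data
model <- xgboost(
  data = as.matrix(x_train),
  label = y_train,
  nround = 20,
  verbose = FALSE
)

# Prediction used for the empty coalition: the mean response
p0 <- mean(y_train)

# Explain every training observation, so the coalition values v(S) are
# available for the same data the loss is averaged over.
# The "gaussian" approach accounts for feature dependence when estimating v(S).
explanation <- explain(
  model = model,
  x_explain = x_train,
  x_train = x_train,
  approach = "gaussian",
  phi0 = p0
)

#### SAGE ####

# Mean squared error using all features (the model's own predictions)
full_loss <- mean((explanation$pred_explain - y_train)^2)
# Mean squared error using no features (always predicting p0)
zero_loss <- mean((p0 - y_train)^2)

# Decompose the difference between the zero and full loss:
zero_loss - full_loss

# Estimated predictions v(S): one row per coalition, one column per
# explained observation (the first column, id_coalition, is dropped)
vS_SHAP <- explanation$internal$output$dt_vS[, -1]

# Loss-based value function: reduction in MSE relative to the empty coalition
vS_SAGE <- zero_loss - colMeans((t(vS_SHAP) - y_train)^2)

# Weight matrix shapr uses to map v(S) to Shapley values
W <- explanation$internal$objects$W

dt_SAGE <- data.table::as.data.table(t(W %*% as.matrix(vS_SAGE)))
colnames(dt_SAGE) <- c("none", x_var)

# The SAGE values
dt_SAGE[, -1]

# Should equal zero_loss - full_loss
sum(dt_SAGE)
```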
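The last steps of the script only use v(S): W turns any value function into Shapley values, which is why replacing the prediction-based v(S) with a loss-based one is enough. As a self-contained check of that idea, independent of shapr, here is a minimal base-R sketch of the classical exact Shapley formula applied to two toy games. exact_shapley() and the toy games are hypothetical illustrations, not shapr functions.

```r
# Exact Shapley values from a value function v given for all 2^p coalitions.
# Coalitions are bitmasks: bit (j - 1) of 'mask' is set when feature j is in S,
# and v[mask + 1] holds v(S).
exact_shapley <- function(v, p) {
  stopifnot(length(v) == 2^p)
  phi <- numeric(p)
  for (j in seq_len(p)) {
    bit <- bitwShiftL(1L, j - 1L)
    for (mask in 0:(2^p - 1)) {
      if (bitwAnd(mask, bit) != 0L) next                 # only S without feature j
      s <- sum(as.integer(intToBits(mask))[seq_len(p)])  # |S|
      w <- factorial(s) * factorial(p - s - 1) / factorial(p)
      phi[j] <- phi[j] + w * (v[mask + bit + 1] - v[mask + 1])
    }
  }
  phi
}

p <- 3
masks <- 0:(2^p - 1)
# Logical vector marking which features are in the coalition encoded by 'mask'
in_S <- function(mask) as.integer(intToBits(mask))[seq_len(p)] == 1

# Additive game v(S) = sum of a_j over j in S: the Shapley values recover a
a <- c(2, 5, -1)
v_add <- sapply(masks, function(m) sum(a[in_S(m)]))
phi_add <- exact_shapley(v_add, p)  # approximately c(2, 5, -1)

# Pure interaction game: v(S) = 1 only when features 1 and 2 are both in S
v_int <- sapply(masks, function(m) as.numeric(all(in_S(m)[1:2])))
phi_int <- exact_shapley(v_int, p)  # approximately c(0.5, 0.5, 0)

# Efficiency: the values sum to v(full) - v(empty)
all.equal(sum(phi_int), v_int[2^p] - v_int[1])
```

With all coalitions enumerated, W %*% vS in the script should agree with this formula; the difference is that W is precomputed once and can be reused for any value function, such as vS_SAGE.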
