Skip to content

Commit

Permalink
sync tests
Browse files Browse the repository at this point in the history
  • Loading branch information
DaniSanchezSantolaya committed Sep 15, 2023
1 parent b728d08 commit 70640e1
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 6 deletions.
2 changes: 1 addition & 1 deletion tests/explainability/test_anchors_extended.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def test_anchors_fit_and_explain_coverage(model_and_data):
explanation = anchorsExtendedExplainer.explain(
explain_data.head(1).values, threshold=0.95
)
assert explanation.data['coverage'] == pytest.approx(0.42, 0.01)
assert explanation.data['precision'] > 0.95

def test_anchors_feature_importance_obtention(model_and_data):
"""
Expand Down
72 changes: 67 additions & 5 deletions tests/explainability/test_serializer.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
from mercury.explainability.explainers import (
AnchorsWithImportanceExplainer,
ALEExplainer,
MercuryExplainer
)
import pytest
import pickle
import pandas as pd
import numpy as np
import os


@pytest.fixture(scope="session")
def model_and_data():
def model_and_data_anchors():
logRegModel = pickle.load(open('tests/explainability/model_and_data/FICO_lr_model.pkl', 'rb'))
fit_data = pd.read_csv('tests/explainability/model_and_data/fit_data_red.csv', index_col=0)
explain_data = pd.read_csv('tests/explainability/model_and_data/explain_data.csv', index_col=0)
Expand All @@ -19,17 +21,25 @@ def model_and_data():
'explain_data': explain_data
}

@pytest.fixture(scope='session')
def model_and_data_ale():
    """
    Session-scoped fixture for the ALE explainer serialization test.

    Loads the model unpickled from ``FICO_lr_model.pkl`` and the CSV
    ``data_ale.csv`` (first column used as index).

    Returns
    -------
    dict
        ``{'logRegModel': <unpickled model>, 'data': <pandas DataFrame>}``
    """
    # Open the pickle inside a context manager so the file handle is closed
    # deterministically instead of being left to the garbage collector.
    with open('./tests/explainability/model_and_data/FICO_lr_model.pkl', 'rb') as model_file:
        logRegModel = pickle.load(model_file)
    fit_data = pd.read_csv('./tests/explainability/model_and_data_pyspark/data_ale.csv', index_col=0)
    return {
        'logRegModel': logRegModel,
        'data': fit_data,
    }

# Module-level mark: requests the fixture named "model_and_data" for every
# test in this module.
# NOTE(review): the fixtures defined in this module appear to be named
# `model_and_data_anchors` and `model_and_data_ale`; confirm that a fixture
# called "model_and_data" is still available (e.g. from a conftest.py),
# otherwise pytest will report "fixture 'model_and_data' not found".
pytestmark = pytest.mark.usefixtures("model_and_data")

def test_serializer_explainer(model_and_data):
def test_serializer_explainer(model_and_data_anchors):
"""
Testing out that the explainers are properly saved and then loaded
back.
"""
logRegModel = model_and_data['logRegModel']
fit_data = model_and_data['fit_data']
explain_data = model_and_data['explain_data']
logRegModel = model_and_data_anchors['logRegModel']
fit_data = model_and_data_anchors['fit_data']
explain_data = model_and_data_anchors['explain_data']
feature_names = list(explain_data.columns)

TEST_FILE = "/tmp/explainer.pkl"
Expand All @@ -46,4 +56,56 @@ def test_serializer_explainer(model_and_data):
assert type(anchorsExtendedExplainer_recovered) ==\
AnchorsWithImportanceExplainer, "Bad load"

os.remove(TEST_FILE)

def test_serializer_anchors_with_importance_explainer(model_and_data_anchors):
    """
    Check that an AnchorsWithImportanceExplainer survives a save/load
    round trip: the object loaded through MercuryExplainer.load has the
    same type, `params` and `feature_values` as the saved one, and its
    `explain` method can still be called.
    """
    logRegModel = model_and_data_anchors['logRegModel']
    fit_data = model_and_data_anchors['fit_data']
    explain_data = model_and_data_anchors['explain_data']
    feature_names = list(explain_data.columns)

    TEST_FILE = "/tmp/explainer_anchor.pkl"

    explainer = AnchorsWithImportanceExplainer(
        train_data=fit_data,
        predict_fn=logRegModel.predict_proba,
        feature_names=feature_names
    )

    try:
        explainer.save(TEST_FILE)

        explainer_loaded = MercuryExplainer.load(TEST_FILE)
        assert isinstance(explainer_loaded, AnchorsWithImportanceExplainer)
        assert explainer.params == explainer_loaded.params
        assert explainer.feature_values == explainer_loaded.feature_values

        # We are able to execute explain (the call itself is the check; the
        # returned explanation is not inspected).
        explainer_loaded.explain(fit_data.values[0])
    finally:
        # Always clean up the serialized file, even when an assertion above
        # fails, so a failed run does not leave stale state in /tmp. The
        # existence guard avoids masking a failure raised by save() itself.
        if os.path.exists(TEST_FILE):
            os.remove(TEST_FILE)

def test_serializer_ale_explainer(model_and_data_ale):
    """
    Check that an ALEExplainer survives a save/load round trip: the object
    loaded through MercuryExplainer.load is an ALEExplainer and produces
    the same 'ale_values' as the original explainer on the same data.
    """
    model = model_and_data_ale['logRegModel']
    data_pd = model_and_data_ale['data']
    # Every column except the target column is used as an input feature.
    features = [c for c in data_pd.columns if c != 'label']

    TEST_FILE = "/tmp/explainer_ale.pkl"

    explainer = ALEExplainer(
        lambda x: model.predict_proba(x),
        target_names="label"
    )

    try:
        explainer.save(TEST_FILE)

        explainer_loaded = MercuryExplainer.load(TEST_FILE)
        # The original statement lacked `assert`, so the type check was a no-op.
        assert isinstance(explainer_loaded, ALEExplainer)

        # Check explanations
        explanation = explainer.explain(data_pd[features])
        explanation_loaded = explainer_loaded.explain(data_pd[features])

        ale_values = explanation.data['ale_values']
        ale_values_loaded = explanation_loaded.data['ale_values']

        # Compare every entry of 'ale_values'. The previous loop was bounded
        # by len(explanation.data), i.e. the size of the `data` container
        # rather than the number of 'ale_values' entries.
        assert len(ale_values) == len(ale_values_loaded)
        for ale, ale_loaded in zip(ale_values, ale_values_loaded):
            assert np.all(ale == ale_loaded)
    finally:
        # Always clean up the serialized file, even when an assertion above
        # fails. The existence guard avoids masking a failure raised by
        # save() itself.
        if os.path.exists(TEST_FILE):
            os.remove(TEST_FILE)

0 comments on commit 70640e1

Please sign in to comment.