Skip to content

Commit

Permalink
explainability: test explicit/implicit hydrogens
Browse files Browse the repository at this point in the history
  • Loading branch information
JochenSiegWork committed Nov 15, 2024
1 parent 667df3a commit a568de4
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 17 deletions.
2 changes: 1 addition & 1 deletion molpipeline/explainability/visualization/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def _make_grid_from_mol(
mol_height = yl[1] - yl[0]
mol_width = xl[1] - xl[0]

height_to_width_ratio_mol = mol_height / mol_width
height_to_width_ratio_mol = mol_height / (1e-16 + mol_width)
# the grids height / weight is the canvas height / width
height_to_width_ratio_canvas = grid_resolution[1] / grid_resolution[0]

Expand Down
71 changes: 55 additions & 16 deletions tests/test_explainability/test_visualization/test_visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,13 @@
_RANDOM_STATE = 67056


def _get_test_shap_explanations() -> list[SHAPExplanation]:
"""Get test explanations.
def _get_test_morgan_rf_pipeline() -> Pipeline:
"""Get a test pipeline with Morgan fingerprints and a random forest classifier.
Returns
-------
list[SHAPExplanation]
List of SHAP explanations.
Pipeline
Pipeline with Morgan fingerprints and a random forest classifier.
"""
pipeline = Pipeline(
[
Expand All @@ -45,26 +45,27 @@ def _get_test_shap_explanations() -> list[SHAPExplanation]:
),
]
)
pipeline.fit(TEST_SMILES, CONTAINS_OX)

explainer = SHAPTreeExplainer(pipeline)
explanations = explainer.explain(TEST_SMILES)
return explanations
return pipeline


class TestExplainabilityVisualization(unittest.TestCase):
"""Test the public interface of the visualization methods for explanations."""

explanations: ClassVar[list[SHAPExplanation]]
test_pipeline: ClassVar[Pipeline]
test_explainer: ClassVar[SHAPTreeExplainer]
test_explanations: ClassVar[list[SHAPExplanation]]

@classmethod
def setUpClass(cls) -> None:
"""Set up the tests."""
cls.explanations = _get_test_shap_explanations()
cls.test_pipeline = _get_test_morgan_rf_pipeline()
cls.test_pipeline.fit(TEST_SMILES, CONTAINS_OX)
cls.test_explainer = SHAPTreeExplainer(cls.test_pipeline)
cls.test_explanations = cls.test_explainer.explain(TEST_SMILES)

def test_structure_heatmap_fingerprint_based_atom_coloring(self) -> None:
"""Test structure heatmap fingerprint-based atom coloring."""
for explanation in self.explanations:
for explanation in self.test_explanations:
self.assertTrue(explanation.is_valid())
self.assertIsInstance(explanation.atom_weights, np.ndarray)
image = structure_heatmap(
Expand All @@ -78,7 +79,7 @@ def test_structure_heatmap_fingerprint_based_atom_coloring(self) -> None:

def test_structure_heatmap_shap_explanation(self) -> None:
"""Test structure heatmap SHAP explanation."""
for explanation in self.explanations:
for explanation in self.test_explanations:
self.assertTrue(explanation.is_valid())
self.assertIsInstance(explanation.atom_weights, np.ndarray)
image = structure_heatmap_shap(
Expand All @@ -89,20 +90,58 @@ def test_structure_heatmap_shap_explanation(self) -> None:
self.assertIsNotNone(image)
self.assertEqual(image.format, "PNG")

def test_explicit_hydrogens(self):
"""Test that the visualization methods work with explicit hydrogens."""
mol_implicit_Hs = Chem.MolFromSmiles("C")
explanations1 = self.test_explainer.explain([Chem.MolToSmiles(mol_implicit_Hs)])
mol_added_Hs = Chem.AddHs(mol_implicit_Hs)
explanations2 = self.test_explainer.explain([Chem.MolToSmiles(mol_added_Hs)])
mol_explicit_Hs = Chem.MolFromSmiles("[H]C([H])([H])[H]")
explanations3 = self.test_explainer.explain([Chem.MolToSmiles(mol_explicit_Hs)])

# test explanations' atom weights
self.assertEqual(len(explanations1), 1)
self.assertEqual(len(explanations2), 1)
self.assertEqual(len(explanations3), 1)
self.assertIsNotNone(explanations1[0].atom_weights)
self.assertIsNotNone(explanations2[0].atom_weights)
self.assertIsNotNone(explanations3[0].atom_weights)
self.assertEqual(len(explanations1[0].atom_weights), 1)
self.assertEqual(len(explanations2[0].atom_weights), 1)
self.assertEqual(len(explanations3[0].atom_weights), 1)

# test visualization
all_explanations = explanations1 + explanations2 + explanations3
for explanation in all_explanations:
self.assertTrue(explanation.is_valid())
image = structure_heatmap(
explanation.molecule,
explanation.atom_weights, # type: ignore[arg-type]
width=128,
height=128,
) # type: ignore[union-attr]
self.assertIsNotNone(image)
self.assertEqual(image.format, "PNG")


class TestSumOfGaussiansGrid(unittest.TestCase):
"""Test visualization methods for explanations."""

explanations: ClassVar[list[SHAPExplanation]]
test_pipeline: ClassVar[Pipeline]
test_explainer: ClassVar[SHAPTreeExplainer]
test_explanations: ClassVar[list[SHAPExplanation]]

@classmethod
def setUpClass(cls) -> None:
"""Set up the tests."""
cls.explanations = _get_test_shap_explanations()
cls.test_pipeline = _get_test_morgan_rf_pipeline()
cls.test_pipeline.fit(TEST_SMILES, CONTAINS_OX)
cls.test_explainer = SHAPTreeExplainer(cls.test_pipeline)
cls.test_explanations = cls.test_explainer.explain(TEST_SMILES)

def test_grid_with_shap_atom_weights(self) -> None:
"""Test grid with SHAP atom weights."""
for explanation in self.explanations:
for explanation in self.test_explanations:
self.assertTrue(explanation.is_valid())
self.assertIsInstance(explanation.atom_weights, np.ndarray)

Expand Down

0 comments on commit a568de4

Please sign in to comment.