diff --git a/molpipeline/explainability/visualization/visualization.py b/molpipeline/explainability/visualization/visualization.py index a5d436bf..b4bd4f15 100644 --- a/molpipeline/explainability/visualization/visualization.py +++ b/molpipeline/explainability/visualization/visualization.py @@ -67,7 +67,7 @@ def _make_grid_from_mol( mol_height = yl[1] - yl[0] mol_width = xl[1] - xl[0] - height_to_width_ratio_mol = mol_height / mol_width + height_to_width_ratio_mol = mol_height / (1e-16 + mol_width) # the grids height / weight is the canvas height / width height_to_width_ratio_canvas = grid_resolution[1] / grid_resolution[0] diff --git a/tests/test_explainability/test_visualization/test_visualization.py b/tests/test_explainability/test_visualization/test_visualization.py index 880e8aa0..3a001f5a 100644 --- a/tests/test_explainability/test_visualization/test_visualization.py +++ b/tests/test_explainability/test_visualization/test_visualization.py @@ -27,13 +27,13 @@ _RANDOM_STATE = 67056 -def _get_test_shap_explanations() -> list[SHAPExplanation]: - """Get test explanations. +def _get_test_morgan_rf_pipeline() -> Pipeline: + """Get a test pipeline with Morgan fingerprints and a random forest classifier. Returns ------- - list[SHAPExplanation] - List of SHAP explanations. + Pipeline + Pipeline with Morgan fingerprints and a random forest classifier. """ pipeline = Pipeline( [ @@ -45,26 +45,27 @@ def _get_test_shap_explanations() -> list[SHAPExplanation]: ), ] ) - pipeline.fit(TEST_SMILES, CONTAINS_OX) - - explainer = SHAPTreeExplainer(pipeline) - explanations = explainer.explain(TEST_SMILES) - return explanations + return pipeline class TestExplainabilityVisualization(unittest.TestCase): """Test the public interface of the visualization methods for explanations.""" - explanations: ClassVar[list[SHAPExplanation]] + test_pipeline: ClassVar[Pipeline] + test_explainer: ClassVar[SHAPTreeExplainer] + test_explanations: ClassVar[list[SHAPExplanation]] @classmethod def setUpClass(cls) -> None: """Set up the tests.""" - cls.explanations = _get_test_shap_explanations() + cls.test_pipeline = _get_test_morgan_rf_pipeline() + cls.test_pipeline.fit(TEST_SMILES, CONTAINS_OX) + cls.test_explainer = SHAPTreeExplainer(cls.test_pipeline) + cls.test_explanations = cls.test_explainer.explain(TEST_SMILES) def test_structure_heatmap_fingerprint_based_atom_coloring(self) -> None: """Test structure heatmap fingerprint-based atom coloring.""" - for explanation in self.explanations: + for explanation in self.test_explanations: self.assertTrue(explanation.is_valid()) self.assertIsInstance(explanation.atom_weights, np.ndarray) image = structure_heatmap( @@ -78,7 +79,7 @@ def test_structure_heatmap_fingerprint_based_atom_coloring(self) -> None: def test_structure_heatmap_shap_explanation(self) -> None: """Test structure heatmap SHAP explanation.""" - for explanation in self.explanations: + for explanation in self.test_explanations: self.assertTrue(explanation.is_valid()) self.assertIsInstance(explanation.atom_weights, np.ndarray) image = structure_heatmap_shap( @@ -89,20 +90,58 @@ def test_structure_heatmap_shap_explanation(self) -> None: self.assertIsNotNone(image) self.assertEqual(image.format, "PNG") + def test_explicit_hydrogens(self): + """Test that the visualization methods work with explicit hydrogens.""" + mol_implicit_Hs = Chem.MolFromSmiles("C") + explanations1 = self.test_explainer.explain([Chem.MolToSmiles(mol_implicit_Hs)]) + mol_added_Hs = Chem.AddHs(mol_implicit_Hs) + explanations2 = self.test_explainer.explain([Chem.MolToSmiles(mol_added_Hs)]) + mol_explicit_Hs = Chem.MolFromSmiles("[H]C([H])([H])[H]") + explanations3 = self.test_explainer.explain([Chem.MolToSmiles(mol_explicit_Hs)]) + + # test explanations' atom weights + self.assertEqual(len(explanations1), 1) + self.assertEqual(len(explanations2), 1) + self.assertEqual(len(explanations3), 1) + self.assertIsNotNone(explanations1[0].atom_weights) + self.assertIsNotNone(explanations2[0].atom_weights) + self.assertIsNotNone(explanations3[0].atom_weights) + self.assertEqual(len(explanations1[0].atom_weights), 1) + self.assertEqual(len(explanations2[0].atom_weights), 1) + self.assertEqual(len(explanations3[0].atom_weights), 1) + + # test visualization + all_explanations = explanations1 + explanations2 + explanations3 + for explanation in all_explanations: + self.assertTrue(explanation.is_valid()) + image = structure_heatmap( + explanation.molecule, + explanation.atom_weights, # type: ignore[arg-type] + width=128, + height=128, + ) # type: ignore[union-attr] + self.assertIsNotNone(image) + self.assertEqual(image.format, "PNG") + class TestSumOfGaussiansGrid(unittest.TestCase): """Test visualization methods for explanations.""" - explanations: ClassVar[list[SHAPExplanation]] + test_pipeline: ClassVar[Pipeline] + test_explainer: ClassVar[SHAPTreeExplainer] + test_explanations: ClassVar[list[SHAPExplanation]] @classmethod def setUpClass(cls) -> None: """Set up the tests.""" - cls.explanations = _get_test_shap_explanations() + cls.test_pipeline = _get_test_morgan_rf_pipeline() + cls.test_pipeline.fit(TEST_SMILES, CONTAINS_OX) + cls.test_explainer = SHAPTreeExplainer(cls.test_pipeline) + cls.test_explanations = cls.test_explainer.explain(TEST_SMILES) def test_grid_with_shap_atom_weights(self) -> None: """Test grid with SHAP atom weights.""" - for explanation in self.explanations: + for explanation in self.test_explanations: self.assertTrue(explanation.is_valid()) self.assertIsInstance(explanation.atom_weights, np.ndarray)