diff --git a/molpipeline/explainability/visualization/heatmaps.py b/molpipeline/explainability/visualization/heatmaps.py index a1fda193..91a110f1 100644 --- a/molpipeline/explainability/visualization/heatmaps.py +++ b/molpipeline/explainability/visualization/heatmaps.py @@ -6,7 +6,7 @@ """ import abc -from typing import Sequence, Callable +from typing import Callable, Sequence import numpy as np import numpy.typing as npt diff --git a/molpipeline/explainability/visualization/utils.py b/molpipeline/explainability/visualization/utils.py new file mode 100644 index 00000000..b2f1a72d --- /dev/null +++ b/molpipeline/explainability/visualization/utils.py @@ -0,0 +1,173 @@ +"""Utility functions for visualization of molecules and their explanations.""" + +import io +from typing import Sequence + +import numpy as np +import numpy.typing as npt +from matplotlib import pyplot as plt +from matplotlib.colors import Colormap, ListedColormap +from matplotlib.pyplot import get_cmap +from PIL import Image +from rdkit import Chem + +# red green blue alpha tuple +RGBAtuple = tuple[float, float, float, float] + + +def get_mol_lims(mol: Chem.Mol) -> tuple[tuple[float, float], tuple[float, float]]: + """Return the extent of the molecule. + + x- and y-coordinates of all atoms in the molecule are accessed, returning min- and max-values for both axes. + + Parameters + ---------- + mol: Chem.Mol + RDKit Molecule object of which the limits are determined. + + Returns + ------- + tuple[tuple[float, float], tuple[float, float]] + Limits of the molecule. + """ + coords_list = [] + conf = mol.GetConformer(0) + for i, _ in enumerate(mol.GetAtoms()): + pos = conf.GetAtomPosition(i) + coords_list.append((pos.x, pos.y)) + coords: npt.NDArray[np.float64] = np.array(coords_list) + min_p = np.min(coords, axis=0) + max_p = np.max(coords, axis=0) + x_lim = min_p[0], max_p[0] + y_lim = min_p[1], max_p[1] + return x_lim, y_lim + + +def pad( + lim: Sequence[float] | npt.NDArray[np.float64], ratio: float +) -> tuple[float, float]: + """Take a 2-dimensional vector and adds len(vector) * ratio / 2 to each side and returns obtained vector. + + Parameters + ---------- + lim: Sequence[float] | npt.NDArray[np.float64] + Limits which are extended. + ratio: float + factor by which the limits are extended. + + Returns + ------- + List[float, float] + Extended limits + """ + diff = max(lim) - min(lim) + diff *= ratio / 2 + return lim[0] - diff, lim[1] + diff + + +def get_color_map_from_input( + color: str | Colormap | tuple[RGBAtuple, RGBAtuple, RGBAtuple] | None +) -> Colormap: + """Get a colormap from a user defined color scheme. + + Parameters + ---------- + color: str | Colormap | tuple[RGBAtuple, RGBAtuple, RGBAtuple] | None + The color scheme. + + Returns + ------- + Colormap + The colormap. + """ + # read user definer color scheme as ColorMap + if color is None: + coolwarm = ( + (0.017, 0.50, 0.850, 0.5), + (1.0, 1.0, 1.0, 0.5), + (1.0, 0.25, 0.0, 0.5), + ) + coolwarm = (coolwarm[2], coolwarm[1], coolwarm[0]) + color = coolwarm + if isinstance(color, Colormap): + color_map = color + elif isinstance(color, tuple): + color_map = color_tuple_to_colormap(color) # type: ignore + elif isinstance(color, str): + color_map = get_cmap(color) + else: + raise ValueError("Color must be a tuple, string or ColorMap.") + return color_map + + +def color_tuple_to_colormap( + color_tuple: tuple[RGBAtuple, RGBAtuple, RGBAtuple] +) -> Colormap: + """Convert a color tuple to a colormap. + + Parameters + ---------- + color_tuple: tuple[RGBAtuple, RGBAtuple, RGBAtuple] + The color tuple. + + Returns + ------- + Colormap + The colormap (a matplotlib data structure). + """ + if len(color_tuple) != 3: + raise ValueError("Color tuple must have 3 elements") + + # Definition of color + col1, col2, col3 = map(np.array, color_tuple) + + # Creating linear gradient for color mixing + linspace = np.linspace(0, 1, int(128)) + linspace4d = np.vstack([linspace] * 4).T + + # interpolating values for 0 to 0.5 by mixing purple and white + zero_to_half = linspace4d * col2 + (1 - linspace4d) * col3 + # interpolating values for 0.5 to 1 by mixing white and yellow + half_to_one = col1 * linspace4d + col2 * (1 - linspace4d) + + # Creating new colormap from + color_map = ListedColormap(np.vstack([zero_to_half, half_to_one])) + return color_map + + +def to_png(data: bytes) -> Image.Image: + """Show a PNG image from a byte stream. + + Parameters + ---------- + data: bytes + The image data. + + Returns + ------- + Image + The image. + """ + bio = io.BytesIO(data) + img = Image.open(bio) + return img + + +def plt_to_pil(figure: plt.Figure) -> Image.Image: + """Convert a matplotlib figure to a PIL image. + + Parameters + ---------- + figure: plt.Figure + The figure. + + Returns + ------- + Image + The image. + """ + bio = io.BytesIO() + figure.savefig(bio, format="png") + bio.seek(0) + img = Image.open(bio) + return img diff --git a/molpipeline/explainability/visualization/visualization.py b/molpipeline/explainability/visualization/visualization.py index 86baeb09..437663d6 100644 --- a/molpipeline/explainability/visualization/visualization.py +++ b/molpipeline/explainability/visualization/visualization.py @@ -7,15 +7,14 @@ from __future__ import annotations -import io from typing import Sequence import numpy as np import numpy.typing as npt +from matplotlib import colors +from matplotlib import pyplot as plt +from matplotlib.colors import Colormap from PIL import Image -from matplotlib import pyplot as plt, colors -from matplotlib.colors import Colormap, ListedColormap -from matplotlib.pyplot import get_cmap from rdkit import Chem from rdkit.Chem import Draw from rdkit.Chem.Draw import rdMolDraw2D @@ -24,97 +23,18 @@ from molpipeline.explainability.explanation import SHAPExplanation from molpipeline.explainability.visualization.gauss import GaussFunctor2D from molpipeline.explainability.visualization.heatmaps import ( - color_canvas, ValueGrid, + color_canvas, get_color_normalizer_from_data, ) - -RGBAtuple = tuple[float, float, float, float] - - -def get_mol_lims(mol: Chem.Mol) -> tuple[tuple[float, float], tuple[float, float]]: - """Return the extent of the molecule. - - x- and y-coordinates of all atoms in the molecule are accessed, returning min- and max-values for both axes. - - Parameters - ---------- - mol: Chem.Mol - RDKit Molecule object of which the limits are determined. - - Returns - ------- - tuple[tuple[float, float], tuple[float, float]] - Limits of the molecule. - """ - coords_list = [] - conf = mol.GetConformer(0) - for i, _ in enumerate(mol.GetAtoms()): - pos = conf.GetAtomPosition(i) - coords_list.append((pos.x, pos.y)) - coords: npt.NDArray[np.float64] = np.array(coords_list) - min_p = np.min(coords, axis=0) - max_p = np.max(coords, axis=0) - x_lim = min_p[0], max_p[0] - y_lim = min_p[1], max_p[1] - return x_lim, y_lim - - -def pad( - lim: Sequence[float] | npt.NDArray[np.float64], ratio: float -) -> tuple[float, float]: - """Take a 2-dimensional vector and adds len(vector) * ratio / 2 to each side and returns obtained vector. - - Parameters - ---------- - lim: Sequence[float] | npt.NDArray[np.float64] - Limits which are extended. - ratio: float - factor by which the limits are extended. - - Returns - ------- - List[float, float] - Extended limits - """ - diff = max(lim) - min(lim) - diff *= ratio / 2 - return lim[0] - diff, lim[1] + diff - - -def color_tuple_to_colormap( - color_tuple: tuple[RGBAtuple, RGBAtuple, RGBAtuple] -) -> Colormap: - """Convert a color tuple to a colormap. - - Parameters - ---------- - color_tuple: tuple[RGBAtuple, RGBAtuple, RGBAtuple] - The color tuple. - - Returns - ------- - Colormap - The colormap (a matplotlib data structure). - """ - if len(color_tuple) != 3: - raise ValueError("Color tuple must have 3 elements") - - # Definition of color - col1, col2, col3 = map(np.array, color_tuple) - - # Creating linear gradient for color mixing - linspace = np.linspace(0, 1, int(128)) - linspace4d = np.vstack([linspace] * 4).T - - # interpolating values for 0 to 0.5 by mixing purple and white - zero_to_half = linspace4d * col2 + (1 - linspace4d) * col3 - # interpolating values for 0.5 to 1 by mixing white and yellow - half_to_one = col1 * linspace4d + col2 * (1 - linspace4d) - - # Creating new colormap from - newcmp = ListedColormap(np.vstack([zero_to_half, half_to_one])) - return newcmp +from molpipeline.explainability.visualization.utils import ( + RGBAtuple, + get_color_map_from_input, + get_mol_lims, + pad, + plt_to_pil, + to_png, +) def _make_grid_from_mol( @@ -348,29 +268,6 @@ def make_sum_of_gaussians_grid( return value_grid -def get_color_map_from_input( - color: str | Colormap | tuple[RGBAtuple, RGBAtuple, RGBAtuple] | None -) -> Colormap: - # read user definer color scheme as ColorMap - if color is None: - coolwarm = ( - (0.017, 0.50, 0.850, 0.5), - (1.0, 1.0, 1.0, 0.5), - (1.0, 0.25, 0.0, 0.5), - ) - coolwarm = (coolwarm[2], coolwarm[1], coolwarm[0]) - color = coolwarm - if isinstance(color, Colormap): - color_map = color - elif isinstance(color, tuple): - color_map = color_tuple_to_colormap(color) # type: ignore - elif isinstance(color, str): - color_map = get_cmap(color) - else: - raise ValueError("Color must be a tuple, string or ColorMap.") - return color_map - - def _structure_heatmap( mol: RDKitMol, atom_weights: npt.NDArray[np.float64], @@ -378,8 +275,8 @@ def _structure_heatmap( width: int = 600, height: int = 600, color_limits: tuple[float, float] | None = None, -) -> Draw.MolDraw2D: - """Create a Gaussian plot on the molecular structure, highlight atoms with weighted Gaussians. +) -> tuple[Draw.MolDraw2D, ValueGrid, ValueGrid, colors.Normalize, Colormap]: + """Create a heatmap of the molecular structure, highlighting atoms with weighted Gaussian's. Parameters ---------- @@ -396,8 +293,9 @@ def _structure_heatmap( Returns ------- - Draw.MolDraw2D - The configured drawer. + Draw.MolDraw2D, ValueGrid, ColorGrid, colors.Normalize, Colormap + The configured drawer, the value grid, the color grid, the normalizer, and the + color map. """ drawer = Draw.MolDraw2DCairo(width, height) # Coloring atoms of element 0 to 100 black @@ -447,11 +345,33 @@ def structure_heatmap( width: int = 600, height: int = 600, color_limits: tuple[float, float] | None = None, -) -> Draw.MolDraw2D: +) -> Image.Image: + """Create a Gaussian plot on the molecular structure, highlight atoms with weighted Gaussians. + + Parameters + ---------- + mol: RDKitMol + The molecule. + atom_weights: npt.NDArray[np.float64] + The atom weights. + color: str | Colormap | tuple[RGBAtuple, RGBAtuple, RGBAtuple] | None + The color map. + width: int + The width of the image in number of pixels. + height: int + The height of the image in number of pixels. + + Returns + ------- + Image + The image as PNG. + """ drawer, *_ = _structure_heatmap( mol, atom_weights, color, width, height, color_limits ) - return drawer + figure_bytes = drawer.GetDrawingText() + image = to_png(figure_bytes) + return image def structure_heatmap_shap_explanation( @@ -460,8 +380,27 @@ def structure_heatmap_shap_explanation( width: int = 600, height: int = 600, color_limits: tuple[float, float] | None = None, -) -> Draw.MolDraw2D: - # TODO this should only work if the feature vector is binary. Maybe raise an error otherwise? Or do something else? +) -> Image.Image: + """Create a heatmap of the molecular structure and display SHAP prediction composition. + + Parameters + ---------- + explanation: SHAPExplanation + The SHAP explanation. + color: str | Colormap | tuple[RGBAtuple, RGBAtuple, RGBAtuple] | None + The color map. + width: int + The width of the image in number of pixels. + height: int + The height of the image in number of pixels. + color_limits: tuple[float, float] | None + The color limits. + + Returns + ------- + Image + The image as PNG. + """ present_shap = explanation.feature_weights[:, 1] * explanation.feature_vector absent_shap = explanation.feature_weights[:, 1] * (1 - explanation.feature_vector) sum_present_shap = sum(present_shap) @@ -476,7 +415,7 @@ def structure_heatmap_shap_explanation( color_limits=color_limits, ) figure_bytes = drawer.GetDrawingText() - image = show_png(figure_bytes) + image = to_png(figure_bytes) image_array = np.array(image) fig, ax = plt.subplots(figsize=(8, 8)) @@ -504,22 +443,11 @@ def structure_heatmap_shap_explanation( f"$features_{{absent}}={sum_absent_shap:.2f}$" ) fig.text(0.5, 0.18, text, ha="center") - return fig + image = plt_to_pil(fig) + # clear the figure and memory + plt.close() + plt.clf() + plt.cla() -def show_png(data: bytes) -> Image.Image: - """Show a PNG image from a byte stream. - - Parameters - ---------- - data: bytes - The image data. - - Returns - ------- - Image - The image. - """ - bio = io.BytesIO(data) - img = Image.open(bio) - return img + return image diff --git a/tests/test_explainability/test_visualization/test_visualization.py b/tests/test_explainability/test_visualization/test_visualization.py index 939197bb..db9ef727 100644 --- a/tests/test_explainability/test_visualization/test_visualization.py +++ b/tests/test_explainability/test_visualization/test_visualization.py @@ -9,11 +9,13 @@ from molpipeline import Pipeline from molpipeline.any2mol import SmilesToMol -from molpipeline.explainability import SHAPTreeExplainer, Explanation +from molpipeline.explainability import Explanation, SHAPTreeExplainer +from molpipeline.explainability.explanation import SHAPExplanation from molpipeline.explainability.visualization.visualization import ( - structure_heatmap, - show_png, make_sum_of_gaussians_grid, + structure_heatmap, + structure_heatmap_shap_explanation, + to_png, ) from molpipeline.mol2any import MolToMorganFP @@ -23,7 +25,7 @@ _RANDOM_STATE = 67056 -def _get_test_explanations() -> list[Explanation]: +def _get_test_shap_explanations() -> list[SHAPExplanation]: """Get test explanations.""" pipeline = Pipeline( [ @@ -48,23 +50,33 @@ class TestExplainabilityVisualization(unittest.TestCase): @classmethod def setUpClass(cls) -> None: """Set up the tests.""" - cls.explanations = _get_test_explanations() - - def test_fingerprint_based_atom_coloring(self) -> None: - """Test fingerprint-based atom coloring.""" + cls.explanations = _get_test_shap_explanations() + def test_structure_heatmap_fingerprint_based_atom_coloring(self) -> None: + """Test structure heatmap fingerprint-based atom coloring.""" for explanation in self.explanations: self.assertTrue(explanation.is_valid()) self.assertIsInstance(explanation.atom_weights, np.ndarray) - drawer = structure_heatmap( + image = structure_heatmap( explanation.molecule, explanation.atom_weights, # type: ignore[arg-type] width=128, height=128, ) # type: ignore[union-attr] - self.assertIsNotNone(drawer) - figure_bytes = drawer.GetDrawingText() - image = show_png(figure_bytes) + self.assertIsNotNone(image) + self.assertEqual(image.format, "PNG") + + def test_structure_heatmap_shap_explanation(self) -> None: + """Test structure heatmap SHAP explanation.""" + for explanation in self.explanations: + self.assertTrue(explanation.is_valid()) + self.assertIsInstance(explanation.atom_weights, np.ndarray) + image = structure_heatmap_shap_explanation( + explanation=explanation, + width=128, + height=128, + ) # type: ignore[union-attr] + self.assertIsNotNone(image) self.assertEqual(image.format, "PNG") @@ -74,11 +86,10 @@ class TestSumOfGaussiansGrid(unittest.TestCase): @classmethod def setUpClass(cls) -> None: """Set up the tests.""" - cls.explanations = _get_test_explanations() + cls.explanations = _get_test_shap_explanations() def test_grid_with_shap_atom_weights(self) -> None: """Test grid with SHAP atom weights.""" - for explanation in self.explanations: self.assertTrue(explanation.is_valid()) self.assertIsInstance(explanation.atom_weights, np.ndarray) @@ -98,21 +109,3 @@ def test_grid_with_shap_atom_weights(self) -> None: # test that the range of summed gaussian values is as expected for SHAP self.assertTrue(value_grid.values.min() >= -1) self.assertTrue(value_grid.values.max() <= 1) - - # def test_color_limits(self) -> None: - # """Test color limits.""" - # - # for explanation in self.explanations: - # self.assertTrue(explanation.is_valid()) - # self.assertIsInstance(explanation.atom_weights, np.ndarray) - # drawer = structure_heatmap( - # explanation.molecule, - # explanation.atom_weights, # type: ignore[arg-type] - # width=128, - # height=128, - # color_limits=(-1, 1), - # ) - # self.assertIsNotNone(drawer) - # figure_bytes = drawer.GetDrawingText() - # image = show_png(figure_bytes) - # self.assertEqual(image.format, "PNG")