Skip to content

Commit

Permalink
63 include custom filter in experimental (#64)
Browse files Browse the repository at this point in the history
* first version custom filter

* linting

* fix moltobool in pipelines

* linting

* formatting

* christians review

* make test more minimal

* remove .vscode from ide settings

* review

* Jochen review
  • Loading branch information
frederik-sandfort1 authored Aug 22, 2024
1 parent 5b52a3b commit b18b78b
Show file tree
Hide file tree
Showing 7 changed files with 192 additions and 3 deletions.
7 changes: 7 additions & 0 deletions molpipeline/experimental/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
"""Initialize module for experimental classes and functions."""

from molpipeline.experimental.custom_filter import CustomFilter

__all__ = [
"CustomFilter",
]
105 changes: 105 additions & 0 deletions molpipeline/experimental/custom_filter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
"""Module for custom filter functionality."""

from __future__ import annotations

from typing import Any, Callable, Optional

try:
from typing import Self # type: ignore[attr-defined]
except ImportError:
from typing_extensions import Self

from molpipeline.abstract_pipeline_elements.core import InvalidInstance
from molpipeline.abstract_pipeline_elements.core import (
MolToMolPipelineElement as _MolToMolPipelineElement,
)
from molpipeline.utils.molpipeline_types import OptionalMol, RDKitMol


class CustomFilter(_MolToMolPipelineElement):
"""Filters molecules based on a custom boolean function. Elements not passing the filter will be set to InvalidInstances."""

def __init__(
self,
func: Callable[[RDKitMol], bool],
name: str = "CustomFilter",
n_jobs: int = 1,
uuid: Optional[str] = None,
) -> None:
"""Initialize CustomFilter.
Parameters
----------
func : Callable[[RDKitMol], bool]
custom function to filter molecules
name : str, optional
name of the element, by default "CustomFilter"
n_jobs : int, optional
number of jobs to use, by default 1
uuid : str, optional
uuid of the element, by default None
"""
super().__init__(name=name, n_jobs=n_jobs, uuid=uuid)
self.func = func

def pretransform_single(self, value: RDKitMol) -> OptionalMol:
"""Pretransform single value.
Applies the custom boolean function to the molecule.
Parameters
----------
value : RDKitMol
input value
Returns
-------
OptionalMol
output value
"""
if self.func(value):
return value
return InvalidInstance(
self.uuid,
f"Molecule does not match filter from {self.name}",
self.name,
)

def get_params(self, deep: bool = True) -> dict[str, Any]:
"""Get parameters of CustomFilter.
Parameters
----------
deep: bool, optional (default: True)
If True, return the parameters of all subobjects that are PipelineElements.
Returns
-------
dict[str, Any]
Parameters of CustomFilter.
"""
params = super().get_params(deep=deep)
if deep:
params["func"] = self.func
else:
params["func"] = self.func
return params

def set_params(self, **parameters: dict[str, Any]) -> Self:
"""Set parameters of CustomFilter.
Parameters
----------
parameters: dict[str, Any]
Parameters to set.
Returns
-------
Self
Self.
"""
parameter_copy = dict(parameters)
if "func" in parameter_copy:
self.func = parameter_copy.pop("func") # type: ignore
super().set_params(**parameter_copy)
return self
2 changes: 2 additions & 0 deletions molpipeline/mol2any/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Init the module for mol2any pipeline elements."""

from molpipeline.mol2any.mol2bin import MolToBinary
from molpipeline.mol2any.mol2bool import MolToBool
from molpipeline.mol2any.mol2concatinated_vector import MolToConcatenatedVector
from molpipeline.mol2any.mol2inchi import MolToInchi, MolToInchiKey
from molpipeline.mol2any.mol2maccs_key_fingerprint import MolToMACCSFP
Expand All @@ -21,6 +22,7 @@
"MolToInchi",
"MolToInchiKey",
"MolToRDKitPhysChem",
"MolToBool",
]

try:
Expand Down
25 changes: 24 additions & 1 deletion molpipeline/mol2any/mol2bool.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,11 @@


class MolToBool(MolToAnyPipelineElement):
"""Element to generate a bool array from input."""
"""
Element to generate a bool array from input.
Valid molecules are passed as True, InvalidInstances are passed as False.
"""

def pretransform_single(self, value: Any) -> bool:
"""Transform a value to a bool representation.
Expand All @@ -27,3 +31,22 @@ def pretransform_single(self, value: Any) -> bool:
if isinstance(value, InvalidInstance):
return False
return True

def transform_single(self, value: Any) -> Any:
"""Transform a single molecule to a bool representation.
Valid molecules are passed as True, InvalidInstances are passed as False.
RemovedMolecule objects are passed without change, as no transformations are applicable.
Parameters
----------
value: Any
Current representation of the molecule. (Eg. SMILES, RDKit Mol, ...)
Returns
-------
Any
Bool representation of the molecule.
"""
pre_value = self.pretransform_single(value)
return self.finalize_single(pre_value)
16 changes: 14 additions & 2 deletions tests/test_elements/test_mol2any/test_mol2bool.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,17 @@

import unittest

from molpipeline import Pipeline
from molpipeline.abstract_pipeline_elements.core import InvalidInstance
from molpipeline.mol2any.mol2bool import MolToBool
from molpipeline.any2mol import AutoToMol
from molpipeline.mol2any import MolToBool


class TestMolToBool(unittest.TestCase):
"""Unittest for MolToBool."""

def test_bool_conversion(self) -> None:
"""Test if the invalid instances are converted to bool."""

mol2bool = MolToBool()
result = mol2bool.transform(
[
Expand All @@ -22,3 +23,14 @@ def test_bool_conversion(self) -> None:
]
)
self.assertEqual(result, [True, True, False, True])

def test_bool_conversion_pipeline(self) -> None:
"""Test if the invalid instances are converted to bool in pipeline."""
pipeline = Pipeline(
[
("auto_to_mol", AutoToMol()),
("mol2bool", MolToBool()),
]
)
result = pipeline.transform(["CC", "CCC", "no%valid~smiles"])
self.assertEqual(result, [True, True, False])
1 change: 1 addition & 0 deletions tests/test_experimental/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Initialize the test module for experimental classes and functions."""
39 changes: 39 additions & 0 deletions tests/test_experimental/test_custom_filter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
"""Test the custom filter element."""

import unittest

from molpipeline import Pipeline
from molpipeline.any2mol import AutoToMol
from molpipeline.experimental import CustomFilter
from molpipeline.mol2any import MolToBool


class TestCustomFilter(unittest.TestCase):
"""Test the custom filter element."""

smiles_list = [
"CC",
"CCC",
"CCCC",
"CO",
]

def test_transform(self) -> None:
"""Test the custom filter."""
mol_list = AutoToMol().transform(self.smiles_list)
res_filter = CustomFilter(lambda x: x.GetNumAtoms() == 2).transform(mol_list)
res_bool = MolToBool().transform(res_filter)
self.assertEqual(res_bool, [True, False, False, True])

def test_pipeline(self) -> None:
"""Test the custom filter in pipeline."""
pipeline = Pipeline(
[
("auto_to_mol", AutoToMol()),
("custom_filter", CustomFilter(lambda x: x.GetNumAtoms() == 2)),
("mol_to_bool", MolToBool()),
]
)
self.assertEqual(
pipeline.transform(self.smiles_list), [True, False, False, True]
)

0 comments on commit b18b78b

Please sign in to comment.