-
Notifications
You must be signed in to change notification settings - Fork 9
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
63 include custom filter in experimental (#64)
* first version custom filter * linting * fix moltobool in pipelines * linting * formatting * christians review * make test more minimal * remove .vscode from ide settings * review * Jochen review
- Loading branch information
1 parent
5b52a3b
commit b18b78b
Showing
7 changed files
with
192 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
"""Initialize module for experimental classes and functions.""" | ||
|
||
from molpipeline.experimental.custom_filter import CustomFilter | ||
|
||
__all__ = [ | ||
"CustomFilter", | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,105 @@ | ||
"""Module for custom filter functionality.""" | ||
|
||
from __future__ import annotations | ||
|
||
from typing import Any, Callable, Optional | ||
|
||
try: | ||
from typing import Self # type: ignore[attr-defined] | ||
except ImportError: | ||
from typing_extensions import Self | ||
|
||
from molpipeline.abstract_pipeline_elements.core import InvalidInstance | ||
from molpipeline.abstract_pipeline_elements.core import ( | ||
MolToMolPipelineElement as _MolToMolPipelineElement, | ||
) | ||
from molpipeline.utils.molpipeline_types import OptionalMol, RDKitMol | ||
|
||
|
||
class CustomFilter(_MolToMolPipelineElement): | ||
"""Filters molecules based on a custom boolean function. Elements not passing the filter will be set to InvalidInstances.""" | ||
|
||
def __init__( | ||
self, | ||
func: Callable[[RDKitMol], bool], | ||
name: str = "CustomFilter", | ||
n_jobs: int = 1, | ||
uuid: Optional[str] = None, | ||
) -> None: | ||
"""Initialize CustomFilter. | ||
Parameters | ||
---------- | ||
func : Callable[[RDKitMol], bool] | ||
custom function to filter molecules | ||
name : str, optional | ||
name of the element, by default "CustomFilter" | ||
n_jobs : int, optional | ||
number of jobs to use, by default 1 | ||
uuid : str, optional | ||
uuid of the element, by default None | ||
""" | ||
super().__init__(name=name, n_jobs=n_jobs, uuid=uuid) | ||
self.func = func | ||
|
||
def pretransform_single(self, value: RDKitMol) -> OptionalMol: | ||
"""Pretransform single value. | ||
Applies the custom boolean function to the molecule. | ||
Parameters | ||
---------- | ||
value : RDKitMol | ||
input value | ||
Returns | ||
------- | ||
OptionalMol | ||
output value | ||
""" | ||
if self.func(value): | ||
return value | ||
return InvalidInstance( | ||
self.uuid, | ||
f"Molecule does not match filter from {self.name}", | ||
self.name, | ||
) | ||
|
||
def get_params(self, deep: bool = True) -> dict[str, Any]: | ||
"""Get parameters of CustomFilter. | ||
Parameters | ||
---------- | ||
deep: bool, optional (default: True) | ||
If True, return the parameters of all subobjects that are PipelineElements. | ||
Returns | ||
------- | ||
dict[str, Any] | ||
Parameters of CustomFilter. | ||
""" | ||
params = super().get_params(deep=deep) | ||
if deep: | ||
params["func"] = self.func | ||
else: | ||
params["func"] = self.func | ||
return params | ||
|
||
def set_params(self, **parameters: dict[str, Any]) -> Self: | ||
"""Set parameters of CustomFilter. | ||
Parameters | ||
---------- | ||
parameters: dict[str, Any] | ||
Parameters to set. | ||
Returns | ||
------- | ||
Self | ||
Self. | ||
""" | ||
parameter_copy = dict(parameters) | ||
if "func" in parameter_copy: | ||
self.func = parameter_copy.pop("func") # type: ignore | ||
super().set_params(**parameter_copy) | ||
return self |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
"""Initialize the test module for experimental classes and functions.""" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
"""Test the custom filter element.""" | ||
|
||
import unittest | ||
|
||
from molpipeline import Pipeline | ||
from molpipeline.any2mol import AutoToMol | ||
from molpipeline.experimental import CustomFilter | ||
from molpipeline.mol2any import MolToBool | ||
|
||
|
||
class TestCustomFilter(unittest.TestCase): | ||
"""Test the custom filter element.""" | ||
|
||
smiles_list = [ | ||
"CC", | ||
"CCC", | ||
"CCCC", | ||
"CO", | ||
] | ||
|
||
def test_transform(self) -> None: | ||
"""Test the custom filter.""" | ||
mol_list = AutoToMol().transform(self.smiles_list) | ||
res_filter = CustomFilter(lambda x: x.GetNumAtoms() == 2).transform(mol_list) | ||
res_bool = MolToBool().transform(res_filter) | ||
self.assertEqual(res_bool, [True, False, False, True]) | ||
|
||
def test_pipeline(self) -> None: | ||
"""Test the custom filter in pipeline.""" | ||
pipeline = Pipeline( | ||
[ | ||
("auto_to_mol", AutoToMol()), | ||
("custom_filter", CustomFilter(lambda x: x.GetNumAtoms() == 2)), | ||
("mol_to_bool", MolToBool()), | ||
] | ||
) | ||
self.assertEqual( | ||
pipeline.transform(self.smiles_list), [True, False, False, True] | ||
) |