Skip to content

Commit

Permalink
review Christian
Browse files Browse the repository at this point in the history
  • Loading branch information
frederik-sandfort1 committed Oct 1, 2024
1 parent 235c8f8 commit 47d4d90
Show file tree
Hide file tree
Showing 5 changed files with 201 additions and 148 deletions.
30 changes: 12 additions & 18 deletions molpipeline/abstract_pipeline_elements/mol2mol/filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,12 @@
OptionalMol,
RDKitMol,
)
from molpipeline.utils.value_conversions import (
from molpipeline.utils.molpipeline_types import (
FloatCountRange,
IntCountRange,
IntOrIntCountRange,
count_value_to_tuple,
)
from molpipeline.utils.value_conversions import count_value_to_tuple

# possible mode types for a KeepMatchesFilter:
# - "any" means one match is enough
Expand All @@ -28,7 +28,7 @@


def _within_boundaries(
lower_bound: Optional[float], upper_bound: Optional[float], value: float
lower_bound: Optional[float], upper_bound: Optional[float], property: float
) -> bool:
"""Check if a value is within the specified boundaries.
Expand All @@ -40,17 +40,17 @@ def _within_boundaries(
Lower boundary.
upper_bound: Optional[float]
Upper boundary.
value: float
Value to check.
property: float
Property to check.
Returns
-------
bool
True if the value is within the boundaries, else False.
"""
if lower_bound is not None and value < lower_bound:
if lower_bound is not None and property < lower_bound:
return False
if upper_bound is not None and value > upper_bound:
if upper_bound is not None and property > upper_bound:
return False
return True

Expand Down Expand Up @@ -167,13 +167,7 @@ def get_params(self, deep: bool = True) -> dict[str, Any]:
params = super().get_params(deep=deep)
params["keep_matches"] = self.keep_matches
params["mode"] = self.mode
if deep:
params["filter_elements"] = {
element: (count_tuple[0], count_tuple[1])
for element, count_tuple in self.filter_elements.items()
}
else:
params["filter_elements"] = self.filter_elements
params["filter_elements"] = self.filter_elements
return params

def pretransform_single(self, value: RDKitMol) -> OptionalMol:
Expand All @@ -195,9 +189,9 @@ def pretransform_single(self, value: RDKitMol) -> OptionalMol:
OptionalMol
Molecule that matches defined filter elements, else InvalidInstance.
"""
for filter_element, (min_count, max_count) in self.filter_elements.items():
count = self._calculate_single_element_value(filter_element, value)
if _within_boundaries(min_count, max_count, count):
for filter_element, (lower_limit, upper_limit) in self.filter_elements.items():
property = self._calculate_single_element_value(filter_element, value)
if _within_boundaries(lower_limit, upper_limit, property):
# For "any" mode we can return early if a match is found
if self.mode == "any":
if not self.keep_matches:
Expand Down Expand Up @@ -265,7 +259,7 @@ def _calculate_single_element_value(
class BasePatternsFilter(BaseKeepMatchesFilter, abc.ABC):
"""Filter to keep or remove molecules based on patterns.
Parameters
Attributes
----------
filter_elements: Union[Sequence[str], Mapping[str, IntOrIntCountRange]]
List of patterns to allow in molecules.
Expand Down
64 changes: 50 additions & 14 deletions molpipeline/mol2mol/filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,14 @@
from collections import Counter
from typing import Any, Mapping, Optional, Sequence, Union

from molpipeline.abstract_pipeline_elements.mol2mol.filter import _within_boundaries

try:
from typing import Self # type: ignore[attr-defined]
except ImportError:
from typing_extensions import Self

from loguru import logger
from rdkit import Chem
from rdkit.Chem import Descriptors

Expand All @@ -23,13 +26,14 @@
from molpipeline.abstract_pipeline_elements.mol2mol import (
BasePatternsFilter as _BasePatternsFilter,
)
from molpipeline.utils.molpipeline_types import OptionalMol, RDKitMol
from molpipeline.utils.value_conversions import (
from molpipeline.utils.molpipeline_types import (
FloatCountRange,
IntCountRange,
IntOrIntCountRange,
count_value_to_tuple,
OptionalMol,
RDKitMol,
)
from molpipeline.utils.value_conversions import count_value_to_tuple


class ElementFilter(_MolToMolPipelineElement):
Expand Down Expand Up @@ -60,6 +64,7 @@ def __init__(
allowed_element_numbers: Optional[
Union[list[int], dict[int, IntOrIntCountRange]]
] = None,
add_hydrogens: bool = True,
name: str = "ElementFilter",
n_jobs: int = 1,
uuid: Optional[str] = None,
Expand All @@ -72,6 +77,8 @@ def __init__(
List of atomic numbers of elements allowed in molecules. Per default allowed elements are:
H, B, C, N, O, F, Si, P, S, Cl, Se, Br, I.
Alternatively, a dictionary can be passed with atomic numbers as keys and, as values, either an int for an exact count or a tuple of minimum and maximum count.
add_hydrogens: bool, optional (default: True)
If True and hydrogen (atomic number 1) is in allowed_element_numbers, add hydrogens to the molecule before filtering.
name: str, optional (default: "ElementFilter")
Name of the pipeline element.
n_jobs: int, optional (default: 1)
Expand All @@ -81,6 +88,32 @@ def __init__(
"""
super().__init__(name=name, n_jobs=n_jobs, uuid=uuid)
self.allowed_element_numbers = allowed_element_numbers # type: ignore
self.add_hydrogens = add_hydrogens

@property
def add_hydrogens(self) -> bool:
    """Return the flag controlling whether hydrogens are added before filtering."""
    return self._add_hydrogens

@add_hydrogens.setter
def add_hydrogens(self, add_hydrogens: bool) -> None:
    """Store add_hydrogens and derive process_hydrogens from it.

    process_hydrogens is set to True only if add_hydrogens is truthy and
    hydrogen (atomic number 1) is a key of allowed_element_numbers. If hydrogen
    is allowed but add_hydrogens is falsy, a warning is logged.

    NOTE(review): process_hydrogens is computed here from the value of
    allowed_element_numbers at call time; confirm that it is refreshed when
    allowed_element_numbers is reassigned afterwards.

    Parameters
    ----------
    add_hydrogens: bool
        If True and hydrogen (atomic number 1) is in allowed_element_numbers,
        hydrogens are added to the molecule before filtering.
    """
    self._add_hydrogens = add_hydrogens
    hydrogen_allowed = 1 in self.allowed_element_numbers
    adding_requested = bool(self.add_hydrogens)
    # Hydrogen counts are easy to get wrong on molecules with implicit
    # hydrogens, so tell the user when they opted out of adding them.
    if hydrogen_allowed and not adding_requested:
        logger.warning(
            "Hydrogens are included in allowed_element_numbers, but add_hydrogens is set to False. "
            "Thus hydrogens are NOT added before filtering. You might receive unexpected results."
        )
    self.process_hydrogens = hydrogen_allowed and adding_requested

@property
def allowed_element_numbers(self) -> dict[int, IntCountRange]:
Expand Down Expand Up @@ -135,6 +168,7 @@ def get_params(self, deep: bool = True) -> dict[str, Any]:
}
else:
params["allowed_element_numbers"] = self.allowed_element_numbers
params["add_hydrogens"] = self.add_hydrogens
return params

def set_params(self, **parameters: Any) -> Self:
Expand All @@ -153,6 +187,8 @@ def set_params(self, **parameters: Any) -> Self:
parameter_copy = dict(parameters)
if "allowed_element_numbers" in parameter_copy:
self.allowed_element_numbers = parameter_copy.pop("allowed_element_numbers")
if "add_hydrogens" in parameter_copy:
self.add_hydrogens = parameter_copy.pop("add_hydrogens")
super().set_params(**parameter_copy)
return self

Expand All @@ -169,10 +205,7 @@ def pretransform_single(self, value: RDKitMol) -> OptionalMol:
OptionalMol
Molecule if it contains only allowed elements, else InvalidInstance.
"""
to_process_value = (
Chem.AddHs(value) if 1 in self.allowed_element_numbers else value
)

to_process_value = Chem.AddHs(value) if self.process_hydrogens else value
elements_list = [atom.GetAtomicNum() for atom in to_process_value.GetAtoms()]
elements_counter = Counter(elements_list)
if any(
Expand All @@ -181,11 +214,9 @@ def pretransform_single(self, value: RDKitMol) -> OptionalMol:
return InvalidInstance(
self.uuid, "Molecule contains forbidden chemical element.", self.name
)
for element, (min_count, max_count) in self.allowed_element_numbers.items():
for element, (lower_limit, upper_limit) in self.allowed_element_numbers.items():
count = elements_counter[element]
if (min_count is not None and count < min_count) or (
max_count is not None and count > max_count
):
if not _within_boundaries(lower_limit, upper_limit, count):
return InvalidInstance(
self.uuid,
f"Molecule contains forbidden number of element {element}.",
Expand Down Expand Up @@ -225,6 +256,11 @@ def _pattern_to_mol(self, pattern: str) -> RDKitMol:
class SmilesFilter(_BasePatternsFilter):
"""Filter to keep or remove molecules based on SMILES patterns.
In contrast to the SmartsFilter, which can also match SMILES, the SmilesFilter
sanitizes the molecules and, e.g., checks kekulized bonds for aromaticity and
then sets them to aromatic, while the SmartsFilter detects alternating single and
double bonds.
Notes
-----
There are four possible scenarios:
Expand Down Expand Up @@ -253,7 +289,7 @@ def _pattern_to_mol(self, pattern: str) -> RDKitMol:
class ComplexFilter(_BaseKeepMatchesFilter):
"""Filter to keep or remove molecules based on multiple filter elements.
Parameters
Attributes
----------
filter_elements: Sequence[_MolToMolPipelineElement]
MolToMol elements to use as filters.
Expand Down Expand Up @@ -317,7 +353,7 @@ def _calculate_single_element_value(
class RDKitDescriptorsFilter(_BaseKeepMatchesFilter):
"""Filter to keep or remove molecules based on RDKit descriptors.
Parameters
Attributes
----------
filter_elements: dict[str, FloatCountRange]
Dictionary of RDKit descriptors to filter by.
Expand Down Expand Up @@ -347,11 +383,11 @@ def filter_elements(self, descriptors: dict[str, FloatCountRange]) -> None:
descriptors: dict[str, FloatCountRange]
Dictionary of RDKit descriptors to filter by.
"""
self._filter_elements = descriptors
if not all(hasattr(Descriptors, descriptor) for descriptor in descriptors):
raise ValueError(
"You are trying to use an invalid descriptor. Use RDKit Descriptors module."
)
self._filter_elements = descriptors

def _calculate_single_element_value(
self, filter_element: Any, value: RDKitMol
Expand Down
22 changes: 21 additions & 1 deletion molpipeline/utils/molpipeline_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,17 @@
from __future__ import annotations

from numbers import Number
from typing import Any, List, Literal, Optional, Protocol, Tuple, TypeVar, Union
from typing import (
Any,
List,
Literal,
Optional,
Protocol,
Tuple,
TypeAlias,
TypeVar,
Union,
)

try:
from typing import Self # type: ignore[attr-defined]
Expand Down Expand Up @@ -47,6 +57,16 @@

TypeConserverdIterable = TypeVar("TypeConserverdIterable", List[_T], npt.NDArray[_T])

# FloatCountRange for typing of float-valued ranges
# - a tuple with a lower and an upper bound
# - both limits are optional (None)
FloatCountRange: TypeAlias = tuple[Optional[float], Optional[float]]

# IntCountRange for typing of integer-valued count ranges
# - a tuple with a lower and an upper bound
# - both limits are optional (None)
IntCountRange: TypeAlias = tuple[Optional[int], Optional[int]]

# IntOrIntCountRange for Typing of count ranges
# - a single int for an exact value match
# - a range given as a tuple with a lower and upper bound
# - both limits are optional
IntOrIntCountRange: TypeAlias = Union[int, IntCountRange]


class AnySklearnEstimator(Protocol):
"""Protocol for sklearn estimators."""
Expand Down
12 changes: 2 additions & 10 deletions molpipeline/utils/value_conversions.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,8 @@
"""Module for utilities converting values."""

from typing import Optional, Sequence, TypeAlias, Union
from typing import Sequence

FloatCountRange: TypeAlias = tuple[Optional[float], Optional[float]]

IntCountRange: TypeAlias = tuple[Optional[int], Optional[int]]

# IntOrIntCountRange for Typing of count ranges
# - a single int for an exact value match
# - a range given as a tuple with a lower and upper bound
# - both limits are optional
IntOrIntCountRange: TypeAlias = Union[int, IntCountRange]
from molpipeline.utils.molpipeline_types import IntCountRange, IntOrIntCountRange


def count_value_to_tuple(count: IntOrIntCountRange) -> IntCountRange:
Expand Down
Loading

0 comments on commit 47d4d90

Please sign in to comment.