Skip to content

Commit

Permalink
Merge pull request #100 from Paul-B98/main
Browse files Browse the repository at this point in the history
Add support for the sklearn set_output API (SLEP018)
  • Loading branch information
jernsting authored Dec 20, 2024
2 parents c774b70 + 74dd7fd commit 563b7c9
Showing 1 changed file with 33 additions and 5 deletions.
38 changes: 33 additions & 5 deletions photonai/base/photon_elements.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,25 @@
from __future__ import annotations

import importlib
import importlib.util
import inspect
import warnings
from copy import deepcopy
from typing import List, Union

import dask
from dask.distributed import Client
import numpy as np
import warnings
from dask.distributed import Client
from sklearn.base import BaseEstimator
from sklearn.model_selection._search import ParameterGrid
from typing import List, Union

from photonai.base.photon_pipeline import PhotonPipeline
from photonai.base.registry.registry import PhotonRegistry
from photonai.helper.helper import PhotonDataHelper
from photonai.optimization.config_grid import create_global_config_grid, create_global_config_dict
from photonai.optimization.config_grid import (
create_global_config_dict,
create_global_config_grid,
)
from photonai.photonlogger.logger import logger


Expand Down Expand Up @@ -99,7 +105,7 @@ def __init__(self, name: str, hyperparameters: dict = None, test_disabled: bool
imported_module = importlib.import_module(desired_class_home)
desired_class = getattr(imported_module, desired_class_name)
self.base_element = desired_class(**kwargs)
except AttributeError as ae:
except AttributeError:
logger.error('ValueError: Could not find according class:'
+ str(PhotonRegistry.ELEMENT_DICTIONARY[name]))
raise ValueError('Could not find according class:', PhotonRegistry.ELEMENT_DICTIONARY[name])
Expand Down Expand Up @@ -398,6 +404,28 @@ def set_params(self, **kwargs):
del kwargs['disabled']
self.base_element.set_params(**kwargs)
return self

def set_output(self, *, transform: None | str = None) -> PipelineElement:
"""
Calls set_output on the base element if it is implemented.
For more information see the documentation.
- https://scikit-learn-enhancement-proposals.readthedocs.io/en/latest/slep018/proposal.html
- https://scikit-learn.org/1.5/auto_examples/miscellaneous/plot_set_output.html
Parameters:
transform:
The name of the output transformation.
If None, the base element is set to the default output transformation.
Returns:
self
"""
if hasattr(self.base_element, 'set_output'):
self.base_element.set_output(transform=transform)
else:
logger.warning(f"set_output is not implemented for {self._name}")
return self

def fit(self, X: np.ndarray, y: np.ndarray = None, **kwargs):
"""
Expand Down

0 comments on commit 563b7c9

Please sign in to comment.