From 74dd7fd6e91ddf1ed6648d14aab18f85d7efbb29 Mon Sep 17 00:00:00 2001 From: Paul-B98 <115164840+Paul-B98@users.noreply.github.com> Date: Tue, 17 Dec 2024 20:27:35 +0100 Subject: [PATCH] add support for the sklearn set_output api (SLEP018) --- photonai/base/photon_elements.py | 38 +++++++++++++++++++++++++++----- 1 file changed, 33 insertions(+), 5 deletions(-) diff --git a/photonai/base/photon_elements.py b/photonai/base/photon_elements.py index 77803f4c..885b4715 100644 --- a/photonai/base/photon_elements.py +++ b/photonai/base/photon_elements.py @@ -1,19 +1,25 @@ +from __future__ import annotations + import importlib import importlib.util import inspect +import warnings from copy import deepcopy +from typing import List, Union + import dask -from dask.distributed import Client import numpy as np -import warnings +from dask.distributed import Client from sklearn.base import BaseEstimator from sklearn.model_selection._search import ParameterGrid -from typing import List, Union from photonai.base.photon_pipeline import PhotonPipeline from photonai.base.registry.registry import PhotonRegistry from photonai.helper.helper import PhotonDataHelper -from photonai.optimization.config_grid import create_global_config_grid, create_global_config_dict +from photonai.optimization.config_grid import ( + create_global_config_dict, + create_global_config_grid, +) from photonai.photonlogger.logger import logger @@ -99,7 +105,7 @@ def __init__(self, name: str, hyperparameters: dict = None, test_disabled: bool imported_module = importlib.import_module(desired_class_home) desired_class = getattr(imported_module, desired_class_name) self.base_element = desired_class(**kwargs) - except AttributeError as ae: + except AttributeError: logger.error('ValueError: Could not find according class:' + str(PhotonRegistry.ELEMENT_DICTIONARY[name])) raise ValueError('Could not find according class:', PhotonRegistry.ELEMENT_DICTIONARY[name]) @@ -398,6 +404,28 @@ def set_params(self, **kwargs): del kwargs['disabled'] self.base_element.set_params(**kwargs) return self + + def set_output(self, *, transform: None | str = None) -> PipelineElement: + """ + Calls set_output on the base element if it is implemented. + + For more information see the documentation. + - https://scikit-learn-enhancement-proposals.readthedocs.io/en/latest/slep018/proposal.html + - https://scikit-learn.org/1.5/auto_examples/miscellaneous/plot_set_output.html + + Parameters: + transform: + The name of the output transformation. + If None, the base element is set to the default output transformation. + + Returns: + self + """ + if hasattr(self.base_element, 'set_output'): + self.base_element.set_output(transform=transform) + else: + logger.warning(f"set_output is not implemented for {self._name}") + return self def fit(self, X: np.ndarray, y: np.ndarray = None, **kwargs): """