chore: fix sklearn 1.5 compatibility (#992)
jfrery authored Jan 14, 2025
1 parent 1fdae99 commit 06c92ff
Showing 5 changed files with 10 additions and 10 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
@@ -172,6 +172,7 @@ filterwarnings = [
     "ignore:read_text is deprecated.*:DeprecationWarning",
     "ignore:open_binary is deprecated.*:DeprecationWarning",
     "ignore:pkg_resources is deprecated as an API.*:DeprecationWarning",
+    "ignore:'multi_class' was deprecated in version 1.5 and will be removed in 1.7.*:FutureWarning",
 ]

 [tool.semantic_release]
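For reference, the entry added above targets the FutureWarning that scikit-learn 1.5 emits when `multi_class` is passed explicitly to `LogisticRegression`. A minimal sketch of the same suppression done with the standard `warnings` module instead of pytest configuration (assumes scikit-learn 1.5 or 1.6 is installed; the dataset and model settings are illustrative):

```python
import warnings

from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression

X, y = make_classification(n_samples=100, n_features=4, random_state=0)

with warnings.catch_warnings():
    # Same pattern as the pyproject.toml filter: ignore the sklearn 1.5
    # deprecation notice for 'multi_class' until the parameter is removed in 1.7.
    warnings.filterwarnings(
        "ignore",
        message="'multi_class' was deprecated in version 1.5",
        category=FutureWarning,
    )
    model = LogisticRegression(multi_class="auto").fit(X, y)

print(model.score(X, y))
```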
11 changes: 3 additions & 8 deletions src/concrete/ml/sklearn/base.py
@@ -1691,16 +1691,11 @@ def from_sklearn_model(
         # Extract scikit-learn's initialization parameters
         init_params = sklearn_model.get_params()

-        # Remove deprecated parameters
-        deprecated = "deprecated"
-        init_params = {k: v for k, v in init_params.items() if v != deprecated}
-
-        # Ensure compatibility for both sklearn 1.1 and >=1.4
-        # This parameter was removed in 1.4. If this package is installed
+        # Ensure compatibility for both sklearn 1.1 and >=1.5
+        # This parameter was removed in 1.5. If this package is installed
         # with sklearn 1.1 which has it, then remove it when
-        # instantiating the 1.4 API compatible Concrete ML model
+        # instantiating the 1.5 API compatible Concrete ML model
         init_params.pop("normalize", None)
-        init_params.pop("multi_class", None)

         # Instantiate the Concrete ML model and update initialization parameters
         # This update is necessary as we currently store scikit-learn attributes in Concrete ML
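The change above relies on the usual `get_params()` round trip: read the scikit-learn model's constructor parameters, drop any that the target class does not accept, then instantiate the target. A standalone sketch of that pattern (the helper name and `removed_params` argument are illustrative, not Concrete ML's actual API):

```python
from sklearn.linear_model import LogisticRegression


def clone_params_for_target(sklearn_model, target_cls, removed_params=("normalize",)):
    """Illustrative helper: rebuild a model of `target_cls` from an sklearn model,
    dropping constructor parameters the target class does not accept."""
    init_params = sklearn_model.get_params()
    for name in removed_params:
        # Drop parameters that exist in one sklearn version but not the other,
        # mirroring the init_params.pop(...) call in from_sklearn_model
        init_params.pop(name, None)
    return target_cls(**init_params)


# Usage: round-trip a LogisticRegression through the helper
model = LogisticRegression(C=0.5)
rebuilt = clone_params_for_target(model, LogisticRegression, removed_params=())
print(rebuilt.get_params()["C"])  # 0.5
```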
2 changes: 1 addition & 1 deletion src/concrete/ml/sklearn/glm.py
@@ -144,7 +144,7 @@ def get_sklearn_params(self, deep: bool = True) -> dict:

         # Remove the parameters added by Concrete ML
         params.pop("n_bits", None)
-        # Remove sklearn 1.4 parameter when using sklearn 1.1
+        # Remove sklearn 1.5 parameter when using sklearn 1.1
         if "1.1." in sklearn.__version__:
             params.pop("solver", None)  # pragma: no cover

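The guard itself is a plain substring test on `sklearn.__version__`. A self-contained sketch of the same version-gated parameter removal (the helper is illustrative; `solver` is the GLM parameter that only exists in newer scikit-learn releases):

```python
import sklearn


def strip_newer_params(params: dict) -> dict:
    """Illustrative: drop parameters that only exist in newer sklearn releases
    when an older sklearn (1.1.x) is installed, mirroring get_sklearn_params."""
    params = dict(params)
    if "1.1." in sklearn.__version__:
        # 'solver' on the GLM regressors is not part of the sklearn 1.1 API
        params.pop("solver", None)
    return params


print(strip_newer_params({"alpha": 1.0, "solver": "lbfgs"}))
```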
4 changes: 4 additions & 0 deletions src/concrete/ml/sklearn/linear_model.py
@@ -1637,6 +1637,7 @@ def __init__(
         C=1.0,
         fit_intercept=True,
         intercept_scaling=1,
+        multi_class="auto",
         class_weight=None,
         random_state=None,
         solver="lbfgs",
@@ -1655,6 +1656,7 @@ def __init__(
         self.C = C
         self.fit_intercept = fit_intercept
         self.intercept_scaling = intercept_scaling
+        self.multi_class = multi_class
         self.class_weight = class_weight
         self.random_state = random_state
         self.solver = solver
@@ -1689,6 +1691,7 @@ def dump_dict(self) -> Dict[str, Any]:
         metadata["C"] = self.C
         metadata["fit_intercept"] = self.fit_intercept
         metadata["intercept_scaling"] = self.intercept_scaling
+        metadata["multi_class"] = self.multi_class
         metadata["class_weight"] = self.class_weight
         metadata["random_state"] = self.random_state
         metadata["solver"] = self.solver
@@ -1725,6 +1728,7 @@ def load_dict(cls, metadata: Dict):
         obj.C = metadata["C"]
         obj.fit_intercept = metadata["fit_intercept"]
         obj.intercept_scaling = metadata["intercept_scaling"]
+        obj.multi_class = metadata["multi_class"]
         obj.class_weight = metadata["class_weight"]
         obj.random_state = metadata["random_state"]
         obj.solver = metadata["solver"]
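These additions keep the constructor, `dump_dict` and `load_dict` in sync: a parameter stored at construction time must also be written out and read back, or it is silently lost on reload. A minimal sketch of that contract (a toy class, not Concrete ML's actual serialization code):

```python
from typing import Any, Dict


class TinyModel:
    """Illustrative stand-in showing the __init__ / dump_dict / load_dict contract."""

    def __init__(self, C=1.0, multi_class="auto"):
        self.C = C
        self.multi_class = multi_class

    def dump_dict(self) -> Dict[str, Any]:
        # Every constructor attribute must be written out...
        return {"C": self.C, "multi_class": self.multi_class}

    @classmethod
    def load_dict(cls, metadata: Dict) -> "TinyModel":
        # ...and read back, otherwise loading drops the parameter
        obj = cls()
        obj.C = metadata["C"]
        obj.multi_class = metadata["multi_class"]
        return obj


restored = TinyModel.load_dict(TinyModel(C=0.5, multi_class="auto").dump_dict())
print(restored.C, restored.multi_class)  # 0.5 auto
```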
2 changes: 1 addition & 1 deletion tests/sklearn/test_sklearn_models.py
@@ -2244,7 +2244,7 @@ def test_initialization_variables_and_defaults_match(
     """
     if "1.1." in sklearn.__version__:
         pytest.skip(
-            "Concrete ML currently implements sklearn 1.4 API"
+            "Concrete ML currently implements sklearn 1.5 API"
             f" skipping this test on version {sklearn.__version__}"
         )

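The skip guard is ordinary pytest usage: bail out early on the older supported scikit-learn, where the 1.5 API is not available. A self-contained sketch of the same pattern (the test name and body are illustrative):

```python
import pytest
import sklearn


def test_api_matches_sklearn_1_5():
    # Skip on the older supported sklearn, where the 1.5 API is not available
    if "1.1." in sklearn.__version__:
        pytest.skip(
            "Concrete ML currently implements sklearn 1.5 API,"
            f" skipping this test on version {sklearn.__version__}"
        )
    assert True  # placeholder for the real parameter-matching checks
```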
