diff --git a/python/cuml/cuml/experimental/accel/__main__.py b/python/cuml/cuml/experimental/accel/__main__.py
index e4c4af576b..86c6c0cb41 100644
--- a/python/cuml/cuml/experimental/accel/__main__.py
+++ b/python/cuml/cuml/experimental/accel/__main__.py
@@ -1,5 +1,5 @@
 #
-# Copyright (c) 2024, NVIDIA CORPORATION.
+# Copyright (c) 2024-2025, NVIDIA CORPORATION.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -16,6 +16,8 @@
 
 import click
 import code
+import joblib
+import pickle
 import os
 import runpy
 import sys
@@ -31,14 +33,51 @@
     default=False,
     help="Turn strict mode for hyperparameters on.",
 )
+@click.option(
+    "--convert-to-sklearn",
+    type=click.Path(exists=True),
+    required=False,
+    help="Path to a pickled accelerated estimator to convert to a sklearn estimator.",
+)
+@click.option(
+    "--format",
+    "format",
+    type=click.Choice(["pickle", "joblib"], case_sensitive=False),
+    default="pickle",
+    help="Serialization format used to load the input and save the converted estimator.",
+)
+@click.option(
+    "--output",
+    type=click.Path(writable=True),
+    default="converted_sklearn_model.pkl",
+    help="Output path for the converted sklearn estimator file.",
+)
 @click.argument("args", nargs=-1)
-def main(module, strict, args):
+def main(module, strict, convert_to_sklearn, format, output, args):
 
     if strict:
         os.environ["CUML_ACCEL_STRICT_MODE"] = "ON"
 
     install()
 
+    # If the user requested a conversion, handle it and exit
+    if convert_to_sklearn:
+        # Resolve the serializer up front: the same module must both load
+        # the input and dump the output. Lower-case for case-insensitivity.
+        serializer = {"pickle": pickle, "joblib": joblib}[format.lower()]
+
+        # NOTE: both pickle and joblib can execute arbitrary code while
+        # loading, so only convert files that come from a trusted source.
+        with open(convert_to_sklearn, "rb") as f:
+            accelerated_estimator = serializer.load(f)
+
+        sklearn_estimator = accelerated_estimator.as_sklearn()
+
+        with open(output, "wb") as f:
+            serializer.dump(sklearn_estimator, f)
+
+        sys.exit()
+
     if module:
         (module,) = module
         # run the module passing the remaining arguments
diff --git a/python/cuml/cuml/internals/base.pyx b/python/cuml/cuml/internals/base.pyx
index a2a7374a1f..32b2cad908 100644
--- a/python/cuml/cuml/internals/base.pyx
+++ b/python/cuml/cuml/internals/base.pyx
@@ -16,6 +16,7 @@
 
 # distutils: language = c++
 
+import copy
 import os
 import inspect
 import numbers
@@ -24,7 +25,7 @@ from cuml.internals.device_support import GPU_ENABLED
 from cuml.internals.safe_imports import (
     cpu_only_import,
     gpu_only_import_from,
-    null_decorator
+    null_decorator,
 )
 np = cpu_only_import('numpy')
 nvtx_annotate = gpu_only_import_from("nvtx", "annotate", alt=null_decorator)
@@ -910,3 +911,80 @@ class UniversalBase(Base):
                 raise ex
 
             raise ex
+
+    def as_sklearn(self, deepcopy=False):
+        """
+        Convert the current GPU-accelerated estimator into a scikit-learn estimator.
+
+        This method imports and builds an equivalent CPU-backed scikit-learn model,
+        transferring all necessary parameters from the GPU representation to the
+        CPU model. After this conversion, the returned object should be a fully
+        compatible scikit-learn estimator, allowing you to use it in standard
+        scikit-learn pipelines and workflows.
+
+        Parameters
+        ----------
+        deepcopy : boolean (default=False)
+            Whether to return a deep copy of the CPU estimator that the cuML
+            model keeps internally. Without a copy, the returned object is
+            that internal estimator itself, so later changes made through
+            either object are visible to the other. Set this to True if you
+            intend to keep using both the cuML and scikit-learn estimators.
+
+        Returns
+        -------
+        sklearn.base.BaseEstimator
+            A scikit-learn compatible estimator instance that mirrors the trained
+            state of the current GPU-accelerated estimator.
+
+        """
+        self.import_cpu_model()
+        self.build_cpu_model()
+        self.gpu_to_cpu()
+        if deepcopy:
+            return copy.deepcopy(self._cpu_model)
+        else:
+            return self._cpu_model
+
+    @classmethod
+    def from_sklearn(cls, model):
+        """
+        Create a GPU-accelerated estimator from a scikit-learn estimator.
+
+        This class method takes an existing scikit-learn estimator and converts it
+        into the corresponding GPU-backed estimator. It imports any required CPU
+        model definitions, stores the given scikit-learn model internally, and then
+        transfers the model parameters and state onto the GPU.
+
+        Parameters
+        ----------
+        model : sklearn.base.BaseEstimator
+            A fitted scikit-learn estimator from which to create the GPU-accelerated
+            version.
+
+        Returns
+        -------
+        cls
+            A new instance of the GPU-accelerated estimator class that mirrors the
+            state of the input scikit-learn estimator.
+
+        Notes
+        -----
+        - `output_type` of the returned estimator is set to "numpy" (with
+          host memory), because it cannot be inferred from training
+          arguments. If something different is required, please use cuML's
+          output_type configuration utilities.
+        """
+        estimator = cls()
+        estimator.import_cpu_model()
+        estimator._cpu_model = model
+        estimator.cpu_to_gpu()
+
+        # An output type must be set explicitly here because there are no
+        # training arguments to infer it from. numpy (host memory) is used
+        # as the default since it mirrors the array type that scikit-learn
+        # estimators, such as the one passed in, work with.
+        estimator.output_type = "numpy"
+        estimator.output_mem_type = MemoryType.host
+
+        return estimator
diff --git a/python/cuml/cuml/manifold/t_sne.pyx b/python/cuml/cuml/manifold/t_sne.pyx
index 7ff8702a2c..b984d47818 100644
--- a/python/cuml/cuml/manifold/t_sne.pyx
+++ b/python/cuml/cuml/manifold/t_sne.pyx
@@ -728,4 +728,4 @@ class TSNE(UniversalBase,
     def get_attr_names(self):
         return ["embedding", "kl_divergence_",
                 "n_features_in_", "learning_rate_",
-                "n_iter_"]
+                "n_iter_", "embedding_"]
diff --git a/python/cuml/cuml/tests/test_sklearn_import_export.py b/python/cuml/cuml/tests/test_sklearn_import_export.py
new file mode 100644
index 0000000000..19e4fb2e5c
--- /dev/null
+++ b/python/cuml/cuml/tests/test_sklearn_import_export.py
@@ -0,0 +1,226 @@
+# Copyright (c) 2024-2025, NVIDIA CORPORATION.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import pytest
+import numpy as np
+
+from cuml.cluster import KMeans, DBSCAN
+from cuml.decomposition import PCA, TruncatedSVD
+from cuml.linear_model import (
+    LinearRegression,
+    LogisticRegression,
+    ElasticNet,
+    Ridge,
+    Lasso,
+)
+from cuml.manifold import TSNE
+from cuml.neighbors import NearestNeighbors
+
+from cuml.testing.utils import array_equal
+
+from numpy.testing import assert_allclose
+
+from sklearn.datasets import make_blobs, make_classification, make_regression
+from sklearn.utils.validation import check_is_fitted
+from sklearn.cluster import KMeans as SkKMeans, DBSCAN as SkDBSCAN
+from sklearn.decomposition import PCA as SkPCA, TruncatedSVD as SkTruncatedSVD
+from sklearn.linear_model import (
+    LinearRegression as SkLinearRegression,
+    LogisticRegression as SkLogisticRegression,
+    ElasticNet as SkElasticNet,
+    Ridge as SkRidge,
+    Lasso as SkLasso,
+)
+from sklearn.manifold import TSNE as SkTSNE
+from sklearn.neighbors import NearestNeighbors as SkNearestNeighbors
+
+###############################################################################
+#                              Helper functions                               #
+###############################################################################
+
+
+@pytest.fixture
+def random_state():
+    return 42
+
+
+def assert_estimator_roundtrip(
+    cuml_model, sklearn_class, X, y=None, transform=False
+):
+    """
+    Fit ``cuml_model`` and check that it survives a sklearn round trip.
+
+    Steps: fit the cuML model, convert it with ``as_sklearn``, convert the
+    result back with ``from_sklearn``, then assert that the original and
+    the round-tripped cuML models produce equal outputs on ``X``.
+
+    Parameters
+    ----------
+    cuml_model : unfitted cuML estimator (fitted in place by this helper)
+    sklearn_class : class the ``as_sklearn`` result must be an instance of
+    X : training data, also used as the comparison input
+    y : optional targets; when given, ``fit(X, y)`` is used
+    transform : compare ``transform(X)`` outputs instead of ``predict(X)``
+        (or ``labels_`` when the model has no ``predict``)
+    """
+    # Fit original model
+    if y is not None:
+        cuml_model.fit(X, y)
+    else:
+        cuml_model.fit(X)
+
+    # Convert to sklearn model
+    sklearn_model = cuml_model.as_sklearn()
+    check_is_fitted(sklearn_model)
+
+    assert isinstance(sklearn_model, sklearn_class)
+
+    # Convert back
+    roundtrip_model = type(cuml_model).from_sklearn(sklearn_model)
+
+    # Ensure roundtrip model is fitted
+    check_is_fitted(roundtrip_model)
+
+    # Compare predictions or transforms. array_equal only returns the
+    # comparison result, so every call has to be asserted to have an effect.
+    if transform:
+        original_output = cuml_model.transform(X)
+        roundtrip_output = roundtrip_model.transform(X)
+        assert array_equal(original_output, roundtrip_output)
+    else:
+        # For predict methods
+        if hasattr(cuml_model, "predict"):
+            original_pred = cuml_model.predict(X)
+            roundtrip_pred = roundtrip_model.predict(X)
+            assert array_equal(original_pred, roundtrip_pred)
+        # For models that only expose labels_ (e.g., clustering)
+        elif hasattr(cuml_model, "labels_"):
+            assert array_equal(cuml_model.labels_, roundtrip_model.labels_)
+        else:
+            # If we get here, need a custom handling for that type
+            raise NotImplementedError(
+                "No known method to compare outputs of this model."
+            )
+
+
+###############################################################################
+#                                    Tests                                    #
+###############################################################################
+
+
+def test_kmeans(random_state):
+    # sklearn is only used to generate the blobs; the model under test is cuML's
+    X, _ = make_blobs(
+        n_samples=50, n_features=2, centers=3, random_state=random_state
+    )
+    original = KMeans(n_clusters=3, random_state=random_state)
+    assert_estimator_roundtrip(original, SkKMeans, X)
+
+
+def test_dbscan(random_state):
+    X, _ = make_blobs(
+        n_samples=50, n_features=2, centers=3, random_state=random_state
+    )
+    original = DBSCAN(eps=0.5, min_samples=5)
+    # DBSCAN assigns labels_ after fit
+    original.fit(X)
+    sklearn_model = original.as_sklearn()
+    assert isinstance(sklearn_model, SkDBSCAN)
+    roundtrip_model = DBSCAN.from_sklearn(sklearn_model)
+    assert array_equal(original.labels_, roundtrip_model.labels_)
+
+
+def test_pca(random_state):
+    X = np.random.RandomState(random_state).rand(50, 5)
+    original = PCA(n_components=2, random_state=random_state)
+    assert_estimator_roundtrip(original, SkPCA, X, transform=True)
+
+
+def test_truncated_svd(random_state):
+    X = np.random.RandomState(random_state).rand(50, 5)
+    original = TruncatedSVD(n_components=2, random_state=random_state)
+    assert_estimator_roundtrip(original, SkTruncatedSVD, X, transform=True)
+
+
+def test_linear_regression(random_state):
+    X, y = make_regression(
+        n_samples=50, n_features=5, noise=0.1, random_state=random_state
+    )
+    original = LinearRegression()
+    assert_estimator_roundtrip(original, SkLinearRegression, X, y)
+
+
+def test_logistic_regression(random_state):
+    X, y = make_classification(
+        n_samples=50, n_features=5, n_informative=3, random_state=random_state
+    )
+    original = LogisticRegression(random_state=random_state, max_iter=500)
+    assert_estimator_roundtrip(original, SkLogisticRegression, X, y)
+
+
+def test_elasticnet(random_state):
+    X, y = make_regression(
+        n_samples=50, n_features=5, noise=0.1, random_state=random_state
+    )
+    original = ElasticNet(random_state=random_state)
+    assert_estimator_roundtrip(original, SkElasticNet, X, y)
+
+
+def test_ridge(random_state):
+    X, y = make_regression(
+        n_samples=50, n_features=5, noise=0.1, random_state=random_state
+    )
+    original = Ridge(alpha=1.0, random_state=random_state)
+    assert_estimator_roundtrip(original, SkRidge, X, y)
+
+
+def test_lasso(random_state):
+    X, y = make_regression(
+        n_samples=50, n_features=5, noise=0.1, random_state=random_state
+    )
+    original = Lasso(alpha=0.1, random_state=random_state)
+    assert_estimator_roundtrip(original, SkLasso, X, y)
+
+
+def test_tsne(random_state):
+    # TSNE is non-deterministic, so avoid re-fitting and compare stored state:
+    X = np.random.RandomState(random_state).rand(50, 5)
+    original = TSNE(n_components=2, random_state=random_state)
+    original.fit(X)
+    sklearn_model = original.as_sklearn()
+    assert isinstance(sklearn_model, SkTSNE)
+    roundtrip_model = TSNE.from_sklearn(sklearn_model)
+    # The conversions are expected to copy fitted attributes rather than
+    # re-fit, so the embedding from the original fit must come through as is.
+    original_embedding = original.embedding_
+    sklearn_embedding = sklearn_model.embedding_
+    roundtrip_embedding = roundtrip_model.embedding_
+
+    assert array_equal(original_embedding, sklearn_embedding)
+    assert array_equal(original_embedding, roundtrip_embedding)
+
+
+def test_nearest_neighbors(random_state):
+    X = np.random.RandomState(random_state).rand(50, 5)
+    original = NearestNeighbors(n_neighbors=5)
+    original.fit(X)
+    sklearn_model = original.as_sklearn()
+    assert isinstance(sklearn_model, SkNearestNeighbors)
+    roundtrip_model = NearestNeighbors.from_sklearn(sklearn_model)
+    # Check that the kneighbors results are the same
+    dist_original, ind_original = original.kneighbors(X)
+    dist_roundtrip, ind_roundtrip = roundtrip_model.kneighbors(X)
+    assert_allclose(dist_original, dist_roundtrip)
+    assert_allclose(ind_original, ind_roundtrip)