diff --git a/src/sparseml/core/utils/__init__.py b/src/sparseml/core/utils/__init__.py new file mode 100644 index 00000000000..57a02bd65d6 --- /dev/null +++ b/src/sparseml/core/utils/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# flake8: noqa +from .session_helpers import * diff --git a/src/sparseml/core/utils/session_helpers.py b/src/sparseml/core/utils/session_helpers.py new file mode 100644 index 00000000000..f614b21de6c --- /dev/null +++ b/src/sparseml/core/utils/session_helpers.py @@ -0,0 +1,31 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from contextlib import contextmanager + +import sparseml.core.session as session_manager + + +@contextmanager +def session_context_manager(): + """ + A context manager to setup a fresh session and reset it after the context + is exited. + """ + + active_session = session_manager.active_session() + active_session.reset() + yield + # reset the session after each context + active_session.reset() diff --git a/src/sparseml/evaluation/evaluator.py b/src/sparseml/evaluation/evaluator.py index 9eb72cff6df..88c90f33650 100644 --- a/src/sparseml/evaluation/evaluator.py +++ b/src/sparseml/evaluation/evaluator.py @@ -15,6 +15,7 @@ from typing import Optional +from sparseml.core.utils import session_context_manager from sparseml.evaluation.registry import SparseMLEvaluationRegistry from sparsezoo.evaluation.results import Result @@ -43,15 +44,17 @@ def evaluate( :param batch_size: The batch size to use for evals, defaults to 1 :return: The evaluation result as a Result object """ - - eval_integration = SparseMLEvaluationRegistry.resolve( - name=integration, datasets=datasets - ) - - if datasets is None: - # let the integration handle the default dataset - return eval_integration(model_path=model_path, batch_size=batch_size, **kwargs) - - return eval_integration( - model_path=model_path, datasets=datasets, batch_size=batch_size, **kwargs - ) + with session_context_manager(): + eval_integration = SparseMLEvaluationRegistry.resolve( + name=integration, datasets=datasets + ) + + if datasets is None: + # let the integration handle the default dataset + return eval_integration( + model_path=model_path, batch_size=batch_size, **kwargs + ) + + return eval_integration( + model_path=model_path, datasets=datasets, batch_size=batch_size, **kwargs + )