Move Session Management to Top Level (#2261)
* top level import

* fix tests

* fix unit tests

* fix import

* change recommended import
Sara Adkins authored May 8, 2024
1 parent a157dfd commit 214873b
Showing 23 changed files with 90 additions and 130 deletions.
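
In practice, this change means session helpers are imported straight from the top-level `sparseml` package instead of from `sparseml.core.session`. A minimal before/after sketch of the recommended usage, based on the diffs below:

# before this commit: session management lived in a submodule
# import sparseml.core.session as session_manager
# session_manager.create_session()
# session = session_manager.active_session()

# after this commit: import from the top level; active_session() returns
# the thread-local session if one exists, otherwise the global default
from sparseml import active_session

session = active_session()
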
@@ -24,7 +24,7 @@ def main():
from torch.utils.data import DataLoader
from torchvision import transforms

-import sparseml.core.session as session_manager
+from sparseml import active_session
from sparseml.core.event import EventType
from sparseml.core.framework import Framework
from sparseml.pytorch.utils import (
@@ -40,8 +40,7 @@ def main():
device = "cuda:0"

# set up SparseML session
-session_manager.create_session()
-session = session_manager.active_session()
+session = active_session()

# download model
model = torchvision.models.mobilenet_v2(
8 changes: 2 additions & 6 deletions src/sparseml/__init__.py
@@ -24,12 +24,8 @@
from .log import *
from .version import *

-from .base import (
-    Framework,
-    check_version,
-    detect_framework,
-    execute_in_sparseml_framework,
-)
+from .core import *
+from .base import check_version, detect_framework, execute_in_sparseml_framework
from .framework import (
FrameworkInferenceProviderInfo,
FrameworkInfo,
1 change: 1 addition & 0 deletions src/sparseml/core/__init__.py
@@ -25,4 +25,5 @@
from .modifier import *
from .optimizer import *
from .recipe import *
+from .session import *
from .state import *
9 changes: 9 additions & 0 deletions src/sparseml/core/session.py
@@ -30,6 +30,7 @@
"SparseSession",
"create_session",
"active_session",
"reset_session",
"pre_initialize_structure",
"initialize",
"finalize",
@@ -388,6 +389,14 @@ def active_session() -> SparseSession:
return getattr(_local_storage, "session", _global_session)


+def reset_session():
+    """
+    Reset the currently active session to its initial state
+    """
+    session = active_session()
+    session._lifecycle.reset()


def pre_initialize_structure(**kwargs):
"""
A method to pre-initialize the structure of the model for the active session
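
The new `reset_session` helper wraps resetting the active session's lifecycle in a single call. A short usage sketch, assuming the top-level re-export added in this commit:

from sparseml import reset_session

reset_session()  # restores the active session to its initial state in place
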
15 changes: 0 additions & 15 deletions src/sparseml/core/utils/__init__.py

This file was deleted.

31 changes: 0 additions & 31 deletions src/sparseml/core/utils/session_helpers.py

This file was deleted.

4 changes: 2 additions & 2 deletions src/sparseml/evaluation/evaluator.py
@@ -15,7 +15,7 @@

from typing import Optional

-from sparseml.core.utils import session_context_manager
+from sparseml.core.session import create_session
from sparseml.evaluation.registry import SparseMLEvaluationRegistry
from sparsezoo.evaluation.results import Result

@@ -44,7 +44,7 @@ def evaluate(
:param batch_size: The batch size to use for evals, defaults to 1
:return: The evaluation result as a Result object
"""
-with session_context_manager():
+with create_session():
eval_integration = SparseMLEvaluationRegistry.resolve(
name=integration, datasets=datasets
)
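
As the evaluator change shows, `create_session` can be used as a context manager to scope work to a fresh session rather than the global one. A minimal sketch of that pattern (the body here is illustrative):

from sparseml import create_session

with create_session():
    # session-scoped work (evaluation, calibration, ...) goes here,
    # isolated from any state carried by the global session
    pass
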
7 changes: 3 additions & 4 deletions src/sparseml/export/export.py
@@ -70,7 +70,7 @@
import numpy

import click
-import sparseml.core.session as session_manager
+from sparseml.core.session import reset_session
from sparseml.export.helpers import (
AVAILABLE_DEPLOYMENT_TARGETS,
ONNX_MODEL_NAME,
@@ -192,8 +192,7 @@ def export(
opset = opset or TORCH_DEFAULT_ONNX_OPSET

# start a new SparseSession for potential recipe application
-session_manager.create_session()
-session_manager.active_session().reset()
+reset_session()

if source_path is not None and model is not None:
raise ValueError(
@@ -269,7 +268,7 @@ def export(

# once model is loaded we can clear the SparseSession, it was only needed for
# adding structural changes (ie quantization) to the model
-session_manager.active_session().reset()
+reset_session()

_LOGGER.info("Creating data loader for the export...")
if tokenizer is not None:
14 changes: 7 additions & 7 deletions src/sparseml/pytorch/model_load/helpers.py
@@ -21,7 +21,7 @@
import torch
from torch.nn import Module

-import sparseml.core.session as session_manager
+import sparseml
from safetensors import safe_open
from sparseml.core.framework import Framework
from sparseml.pytorch.sparsification.quantization.helpers import (
@@ -101,17 +101,17 @@ def apply_recipe_structure_to_model(model: Module, recipe_path: str, model_path:
orig_state_dict = model.state_dict()

# apply structural changes to the model
-if not session_manager.active_session():
-session_manager.create_session()
-session_manager.pre_initialize_structure(
+if not sparseml.active_session():
+sparseml.create_session()
+sparseml.pre_initialize_structure(
model=model, recipe=recipe_path, framework=Framework.pytorch
)

# no need to reload if no recipe was applied
if recipe_path is None:
return

-session = session_manager.active_session()
+session = sparseml.active_session()
num_stages = len(session.lifecycle.recipe_container.compiled_recipe.stages)
msg = (
"an unstaged recipe"
@@ -260,7 +260,7 @@ def save_model_and_recipe(
_LOGGER.info("Saving output to {}".format(os.path.abspath(save_path)))

recipe_path = os.path.join(save_path, RECIPE_FILE_NAME)
-session = session_manager.active_session()
+session = sparseml.active_session()
recipe_yaml_str = session.get_serialized_recipe()
with open(recipe_path, "w") as fp:
fp.write(recipe_yaml_str)
@@ -303,7 +303,7 @@ def get_session_model() -> Module:
:return: pytorch module stored by the active SparseSession, or None if no session
is active
"""
-session = session_manager.active_session()
+session = sparseml.active_session()
if not session:
return None

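
These helper changes also show the top-level lifecycle API end to end: check for an active session, create one if needed, then pre-initialize recipe structure on a model. A condensed sketch of the pattern from `apply_recipe_structure_to_model` above (`model` and `recipe_path` are placeholders):

import sparseml
import torch
from sparseml.core.framework import Framework

model = torch.nn.Linear(8, 8)  # placeholder module standing in for a real model
recipe_path = "recipe.yaml"    # placeholder path to a recipe file

if not sparseml.active_session():
    sparseml.create_session()
sparseml.pre_initialize_structure(
    model=model, recipe=recipe_path, framework=Framework.pytorch
)
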
3 changes: 1 addition & 2 deletions src/sparseml/sparsification/oracle.py
@@ -20,8 +20,7 @@
import logging
from typing import Any, Dict, List, Optional, Type

-from sparseml import Framework, execute_in_sparseml_framework
-from sparseml.base import detect_frameworks
+from sparseml.base import Framework, detect_frameworks, execute_in_sparseml_framework
from sparseml.sparsification.analyzer import Analyzer
from sparseml.sparsification.recipe_builder import PruningRecipeBuilder
from sparseml.sparsification.recipe_editor import run_avaialble_recipe_editors
4 changes: 2 additions & 2 deletions src/sparseml/transformers/compression/sparsity_config.py
@@ -17,7 +17,7 @@
from torch import Tensor
from torch.nn import Module

-import sparseml.core.session as session_manager
+import sparseml
from compressed_tensors import CompressionConfig
from sparseml.pytorch.utils import ModuleSparsificationInfo

@@ -53,7 +53,7 @@ def infer_sparsity_structure() -> str:
:return: sparsity structure as a string
"""
-current_session = session_manager.active_session()
+current_session = sparseml.active_session()
stage_modifiers = current_session.lifecycle.modifiers
sparsity_structure = "unstructured"

5 changes: 3 additions & 2 deletions src/sparseml/transformers/finetune/callbacks.py
@@ -19,6 +19,7 @@
from transformers import TrainerCallback, TrainerControl, TrainingArguments
from transformers.trainer_callback import TrainerState

+import sparseml
import sparseml.core.session as session_manager


@@ -57,7 +58,7 @@ def on_train_begin(
model, as it will have changed to a wrapper if FSDP is enabled
"""
super().on_train_begin(args, state, control, **kwargs)
-session = session_manager.active_session()
+session = sparseml.active_session()
session.state.model.model = self.trainer.model

def on_step_end(
@@ -113,7 +114,7 @@ def qat_active(self) -> bool:
"""
:return: True if a quantization modifier is active in the current session
"""
-session = session_manager.active_session()
+session = sparseml.active_session()
return session.state.model.qat_active()

def on_epoch_begin(
4 changes: 2 additions & 2 deletions src/sparseml/transformers/finetune/runner.py
@@ -22,7 +22,7 @@
from torch.utils.data import Dataset
from transformers import AutoTokenizer

-import sparseml.core.session as session_manager
+import sparseml
from sparseml.core.recipe import Recipe, StageRunType
from sparseml.pytorch.model_load.helpers import (
get_completed_stages,
@@ -302,7 +302,7 @@ def run_sequential_stages(self, checkpoint: Optional[str] = None):
save_completed_stages(self._output_dir, completed_stages)

# setup for next stage
-session = session_manager.active_session()
+session = sparseml.active_session()
session.reset_stage()

# synchronize and clean up memory
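
Between sequential finetuning stages, the runner clears per-stage state through the same top-level entry point. A hedged sketch of that step:

import sparseml

session = sparseml.active_session()
session.reset_stage()  # reset stage-level lifecycle state before the next stage
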
26 changes: 13 additions & 13 deletions src/sparseml/transformers/finetune/session_mixin.py
@@ -25,7 +25,7 @@
from transformers.trainer_callback import TrainerState
from transformers.trainer_utils import get_last_checkpoint

-import sparseml.core.session as session_manager
+import sparseml
from sparseml.core.framework import Framework
from sparseml.core.session import callbacks
from sparseml.pytorch.model_load.helpers import RECIPE_FILE_NAME, get_session_model
@@ -91,7 +91,7 @@ def __init__(

# setup logger and session
self.logger_manager = LoggerManager(log_python=False)
-session_manager.create_session()
+sparseml.create_session()

# call Trainer initialization
super().__init__(**kwargs)
@@ -131,15 +131,15 @@ def initialize_session(
:param checkpoint: Optional checkpoint to initialize from to continue training
:param stage: Optional stage of recipe to run, or None to run all stages
"""
-session = session_manager.active_session()
+session = sparseml.active_session()
if session.lifecycle.initialized_ or session.lifecycle.finalized:
return False

train_data = self.get_train_dataloader()

self.accelerator.wait_for_everyone()
with summon_full_params_context(self.model, offload_to_cpu=True):
-session_manager.initialize(
+sparseml.initialize(
model=self.model,
teacher_model=self.teacher, # TODO: what about for self/disable?
recipe=self.recipe,
@@ -172,11 +172,11 @@ def initialize_structure(self, stage: Optional[str] = None):
:param stage: Optional stage of recipe to run, or None to run all stages
"""
-session = session_manager.active_session()
+session = sparseml.active_session()
if session.lifecycle.initialized_:
return False

-session_manager.pre_initialize_structure(
+sparseml.pre_initialize_structure(
model=self.model,
recipe=self.recipe,
recipe_stage=stage,
@@ -190,13 +190,13 @@ def finalize_session(self):
"""
Wrap up training by finalizing all modifiers initialized in the current session
"""
-session = session_manager.active_session()
+session = sparseml.active_session()
if not session.lifecycle.initialized_ or session.lifecycle.finalized:
return False

with summon_full_params_context(self.model, offload_to_cpu=True):
# in order to update each layer we need to gathers all its parameters
-session_manager.finalize()
+sparseml.finalize()
_LOGGER.info("Finalized SparseML session")
model = get_session_model()
self.model = model
@@ -232,7 +232,7 @@ def create_optimizer(self):
len(self.train_dataset) / total_batch_size
)

-session_manager.initialize(
+sparseml.initialize(
optimizer=self.optimizer, steps_per_epoch=self.total_steps_per_epoch
)

@@ -304,7 +304,7 @@ def compute_loss(
log["step_loss"] = loss.item()
log["perplexity"] = torch.exp(loss).item()

-if session_manager.active_session().lifecycle.initialized_:
+if sparseml.active_session().lifecycle.initialized_:
state = callbacks.loss_calculated(loss=loss)
if state and state.loss is not None:
loss = state.loss
@@ -406,7 +406,7 @@ def one_shot(self, calib_data: DataLoader, stage: Optional[str] = None):
:param stage: which stage of the recipe to run, or None to run whole recipe
:param calib_data: dataloader of calibration data
"""
-session_manager.apply(
+sparseml.apply(
framework=Framework.pytorch,
recipe=self.recipe,
recipe_stage=stage,
@@ -432,7 +432,7 @@ def save_model(
:param output_dir: the path to save the recipes into
"""
-if session_manager.active_session() is None:
+if sparseml.active_session() is None:
return # nothing to save

if output_dir is None:
@@ -464,7 +464,7 @@ def save_model(
# save recipe, will contain modifiers from the model's original recipe as
# well as those added from self.recipe
recipe_path = os.path.join(output_dir, RECIPE_FILE_NAME)
-session = session_manager.active_session()
+session = sparseml.active_session()
recipe_yaml_str = session.get_serialized_recipe()
with open(recipe_path, "w") as fp:
fp.write(recipe_yaml_str)
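
Finally, the `one_shot` method illustrates the top-level `sparseml.apply(...)` entry point, which applies a recipe against the active session in one pass. A hedged sketch mirroring the call above (argument values are placeholders; the diff truncates the remaining keyword arguments, which are passed through to the session lifecycle):

import sparseml
from sparseml.core.framework import Framework

sparseml.apply(
    framework=Framework.pytorch,
    recipe="recipe.yaml",  # placeholder recipe path or string
    recipe_stage=None,     # or a specific stage name to run on its own
)
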