Support for Decompressing Models from HF Hub (#2212)
Sara Adkins authored Apr 4, 2024
1 parent 3b813b6 commit 5ac1e15
Showing 3 changed files with 61 additions and 9 deletions.
9 changes: 4 additions & 5 deletions src/sparseml/transformers/compression/compressors/base.py
@@ -70,15 +70,14 @@ def replace_layer(param_name: str, data: Tensor, model: Module):
         model_device = operator.attrgetter(param_name)(model).device
         set_layer(param_name, Parameter(data.to(model_device)), model)

-    def overwrite_weights(self, pretrained_model_name_or_path: str, model: Module):
+    def overwrite_weights(self, model_path: str, model: Module):
         """
-        Overwrites the weights in model with weights decompressed from
-        pretrained_model_name_or_path
+        Overwrites the weights in model with weights decompressed from model_path

-        :param pretrained_model_name_or_path: path to compressed weights
+        :param model_path: path to compressed weights
         :param model: pytorch model to load decompressed weights into
         """
-        dense_gen = self.decompress(pretrained_model_name_or_path)
+        dense_gen = self.decompress(model_path)
         for name, data in tqdm(dense_gen, desc="Decompressing model"):
             ModelCompressor.replace_layer(name, data, model)
         setattr(model, SPARSITY_CONFIG_NAME, self.config)
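For orientation, here is a minimal usage sketch of the renamed entrypoint. Only the overwrite_weights(model_path=..., model=...) signature comes from this diff; the single-argument call to infer_compressor_from_model_config and the example path are assumptions for illustration, not part of the commit.

from transformers import AutoConfig, AutoModelForCausalLM

from sparseml.transformers.compression.utils import infer_compressor_from_model_config

# hypothetical local folder containing config.json plus compressed safetensors
model_path = "/path/to/compressed-model"

# assumed to return a ModelCompressor for compressed checkpoints, None otherwise
compressor = infer_compressor_from_model_config(model_path)

# build a dense skeleton of the architecture; its weights are overwritten below
model = AutoModelForCausalLM.from_config(AutoConfig.from_pretrained(model_path))

if compressor is not None:
    # iterates the decompressed weight generator and writes each tensor into the model
    compressor.overwrite_weights(model_path=model_path, model=model)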
51 changes: 49 additions & 2 deletions src/sparseml/transformers/compression/utils/safetensors_load.py
@@ -16,12 +16,13 @@
 import os
 import re
 import struct
-from typing import Dict, List
+from typing import Dict, List, Optional

-from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME
+from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, cached_file


 __all__ = [
+    "get_safetensors_folder",
     "get_safetensors_header",
     "match_param_name",
     "merge_names",
@@ -30,6 +31,48 @@
 ]


+def get_safetensors_folder(
+    pretrained_model_name_or_path: str, cache_dir: Optional[str] = None
+) -> str:
+    """
+    Given a Hugging Face stub or a local path, return the folder containing the
+    safetensors weight files
+
+    :param pretrained_model_name_or_path: local path to model or HF stub
+    :param cache_dir: optional cache dir to search through, if none is specified the
+        model will be searched for in the default TRANSFORMERS_CACHE
+    :return: local folder containing model data
+    """
+    if os.path.exists(pretrained_model_name_or_path):
+        # argument is a path to a local folder
+        return pretrained_model_name_or_path
+
+    safetensors_path = cached_file(
+        pretrained_model_name_or_path,
+        SAFE_WEIGHTS_NAME,
+        cache_dir=cache_dir,
+        _raise_exceptions_for_missing_entries=False,
+    )
+    index_path = cached_file(
+        pretrained_model_name_or_path,
+        SAFE_WEIGHTS_INDEX_NAME,
+        cache_dir=cache_dir,
+        _raise_exceptions_for_missing_entries=False,
+    )
+    if safetensors_path is not None:
+        # found a single cached safetensors file
+        return os.path.split(safetensors_path)[0]
+    if index_path is not None:
+        # found a cached safetensors weight index file
+        return os.path.split(index_path)[0]
+
+    # model weights could not be found locally or cached from HF Hub
+    raise ValueError(
+        "Could not locate safetensors weight or index file from "
+        f"{pretrained_model_name_or_path}."
+    )
+
+
 def get_safetensors_header(safetensors_path: str) -> Dict[str, str]:
     """
     Extracts the metadata from a safetensors file as JSON
@@ -105,6 +148,10 @@ def get_weight_mappings(model_path: str) -> Dict[str, str]:
         with open(index_path, "r", encoding="utf-8") as f:
             index = json.load(f)
         header = index["weight_map"]
+    else:
+        raise ValueError(
+            f"Could not find a safetensors weight or index file at {model_path}"
+        )

     # convert weight locations to full paths
     for key, value in header.items():
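A quick sketch of how the new helper resolves a checkpoint location. The import path is the one used in sparse_model.py below; the Hub stub is hypothetical.

from sparseml.transformers.compression.utils import get_safetensors_folder

# a local directory is returned unchanged
local_folder = get_safetensors_folder("/path/to/local/model")

# a Hub stub is resolved through the transformers cache to the folder holding
# model.safetensors or model.safetensors.index.json; if neither file can be
# resolved, a ValueError is raised
cached_folder = get_safetensors_folder("org/compressed-model", cache_dir=None)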
10 changes: 8 additions & 2 deletions src/sparseml/transformers/sparsification/sparse_model.py
@@ -35,6 +35,7 @@
     log_model_load,
 )
 from sparseml.transformers.compression.utils import (
+    get_safetensors_folder,
     infer_compressor_from_model_config,
     modify_save_pretrained,
 )
@@ -128,9 +129,14 @@ def skip(*args, **kwargs):

         # If model is compressed on disk, decompress and load the weights
         if compressor is not None:
-            compressor.overwrite_weights(
-                pretrained_model_name_or_path=pretrained_model_name_or_path, model=model
+            # if we loaded from a HF stub, find the cached model
+            model_path = get_safetensors_folder(
+                pretrained_model_name_or_path, cache_dir=kwargs.get("cache_dir", None)
             )
+
+            # decompress weights
+            compressor.overwrite_weights(model_path=model_path, model=model)
+
         recipe = resolve_recipe(recipe=recipe, model_path=pretrained_model_name_or_path)
         if recipe:
             apply_recipe_structure_to_model(
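Putting the three changes together, loading a compressed checkpoint directly from the Hugging Face Hub should now look roughly like this. The SparseAutoModelForCausalLM entrypoint name and the stub are assumptions for illustration; the commit itself only changes the decompression path inside from_pretrained.

from sparseml.transformers import SparseAutoModelForCausalLM

# from_pretrained infers the compressor from the model config, resolves the
# cached safetensors folder via get_safetensors_folder (forwarding cache_dir),
# and decompresses the weights into the freshly loaded model
model = SparseAutoModelForCausalLM.from_pretrained(
    "org/compressed-model-stub",  # hypothetical HF Hub stub
    cache_dir=None,
)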
