
Support for Decompressing Models from HF Hub #2212

Merged: 5 commits into main on Apr 4, 2024
Conversation

@Satrat Satrat commented Apr 2, 2024

The original compression PR only tested local compression loads and never exercised loading a compressed model from the Hub. This PR adds support for loading a compressed model from the Hugging Face Hub via SparseAutoModelForCausalLM.from_pretrained.

Test

Loads the compressed model from the Hub and checks that the decompressed weights match the original model's weights.

from sparseml.transformers import SparseAutoModelForCausalLM, SparseAutoTokenizer
import torch

ORIG_MODEL_PATH = "neuralmagic/TinyLlama-1.1B-Chat-v1.0-pruned2.4"
COMPRESSED_MODEL_PATH = "mgoin/TinyLlama-1.1B-Chat-v1.0-pruned2.4-compressed"

# Load the original (dense-format) model and its tokenizer
model = SparseAutoModelForCausalLM.from_pretrained(ORIG_MODEL_PATH, device_map="auto", torch_dtype="auto")
tokenizer = SparseAutoTokenizer.from_pretrained(ORIG_MODEL_PATH)

# Load the compressed checkpoint from the Hub; weights are decompressed on load
model_compressed = SparseAutoModelForCausalLM.from_pretrained(COMPRESSED_MODEL_PATH, device_map="auto", torch_dtype="auto")

# Verify the reconstructed state dict matches the original exactly
og_state_dict = model.state_dict()
reconstructed_state_dict = model_compressed.state_dict()
assert len(og_state_dict) == len(reconstructed_state_dict)
for key in og_state_dict:
    dense_tensor = og_state_dict[key]
    reconstructed_tensor = reconstructed_state_dict[key]
    assert torch.equal(dense_tensor.cpu(), reconstructed_tensor.cpu())
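
The lossless round trip the test above verifies can be sketched in miniature. This is a toy illustration of sparse compress/decompress, not SparseML's actual on-disk format: it stores only the nonzero weights plus a boolean mask, then scatters the values back to reconstruct the dense tensor, which must match the original exactly.

```python
# Toy sparse-compression round trip (hypothetical, NOT SparseML's real format):
# keep only nonzero values plus a position mask, then rebuild the dense list.

def compress(dense):
    """Return (values, mask): nonzero values and a flag per position."""
    mask = [x != 0 for x in dense]
    values = [x for x in dense if x != 0]
    return values, mask

def decompress(values, mask):
    """Scatter values back into the positions flagged by the mask."""
    out, it = [], iter(values)
    for keep in mask:
        out.append(next(it) if keep else 0)
    return out

# 2:4-style pattern: two nonzeros in every group of four weights
dense = [0.5, 0, -1.2, 0, 0, 3.3, 0, 0.7]
values, mask = compress(dense)
assert decompress(values, mask) == dense  # round trip is lossless
```

The test in this PR applies the same idea at model scale: because decompression is exact reconstruction, torch.equal (bitwise equality), rather than an approximate allclose check, is the right comparison.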

dbogunowicz previously approved these changes on Apr 3, 2024
src/sparseml/export/validators.py: review comment (outdated, resolved)
@Satrat Satrat requested review from mgoin and dbogunowicz April 3, 2024 14:11
@mgoin mgoin merged commit 5ac1e15 into main Apr 4, 2024
13 of 15 checks passed
@mgoin mgoin deleted the sa/remote_compress_load branch April 4, 2024 11:59