[Fix] Fully functional FSDP one-shot process #2305

Merged 7 commits into main from damian/fix_fsdp on May 28, 2024
Conversation

@dbogunowicz (Contributor) commented May 28, 2024

Note: This PR should land together with neuralmagic/compressed-tensors#58.

Feature Description

A small set of targeted fixes to enable FSDP one-shot. The fixes mostly focus on correctly undoing the naming changes introduced by the wrapping FSDP module (see the sketch below).
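
For context, FSDP wraps each sharded submodule and inserts a "_fsdp_wrapped_module" segment into fully qualified module and parameter names, so any logic keyed on the original names has to strip that segment first. Below is a minimal sketch of this kind of clean-up, assuming the standard FSDP prefix; the helper name clean_fsdp_name is hypothetical and is not the code added in this PR.

# Hypothetical sketch: remove the name segments added by FSDP wrapping so that
# keys match the original (unwrapped) model's module/parameter names.
FSDP_WRAPPED_MODULE = "_fsdp_wrapped_module."

def clean_fsdp_name(name: str) -> str:
    """Strip FSDP wrapper segments from a fully qualified module/parameter name."""
    return name.replace(FSDP_WRAPPED_MODULE, "")

# e.g. "model.layers.0._fsdp_wrapped_module.self_attn.q_proj.weight"
# ->   "model.layers.0.self_attn.q_proj.weight"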

Testing

Note: The FSDP process was run with num_processes: 1, as well as num_processes: 2. Both runs yielded similar perplexities.

Model generation script

import torch

from sparseml.transformers import SparseAutoModelForCausalLM, oneshot

recipe = """
quant_stage:
    quant_modifiers:
        GPTQModifier:
            sequential_update: false
            ignore: ["lm_head"]
            config_groups:
                group_0:
                    weights:
                        num_bits: 4
                        type: "int"
                        symmetric: true
                        strategy: "channel"
                    targets: ["Linear"]
"""
model_stub = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
model = SparseAutoModelForCausalLM.from_pretrained(
    model_stub,
    torch_dtype=torch.bfloat16,
)

dataset = "open-platypus"
output_dir = "./model"
splits = {"calibration": "train[:5%]"}
max_seq_length = 512
pad_to_max_length = False
num_calibration_samples = 512
# run one-shot calibration/quantization and save the compressed model to output_dir
oneshot(
    model=model,
    dataset=dataset,
    recipe=recipe,
    output_dir=output_dir,
    splits=splits,
    max_seq_length=max_seq_length,
    pad_to_max_length=pad_to_max_length,
    num_calibration_samples=num_calibration_samples,
    save_compressed=True,
)

To run the one-shot script under FSDP:

accelerate launch --config_file integrations/huggingface-transformers/finetuning/example_fsdp_config.yaml model_generation_script.py 

Model testing script

from sparseml import evaluate

print(evaluate("model", limit=100, integration="perplexity", datasets="garage-bAInd/Open-Platypus", text_column_name="instruction"))
print(evaluate("model_fsdp", limit=100, integration="perplexity", datasets="garage-bAInd/Open-Platypus", text_column_name="instruction"))

Result

The model produced by the FSDP one-shot run has comparable perplexity and weight sparsity to its non-FSDP counterpart:

# eval for non-fsdp model (compressed=False or True yields the same perplexity)

formatted=[Evaluation(task='text-generation', dataset=Dataset(type='text-generation', name='garage-bAInd/Open-Platypus', config=None, split=None), metrics=[Metric(name='perplexity', value=17.98309205532074)], samples=None)] raw={'mean_perplexity': 17.98309205532074}

# eval for fsdp model (compressed=False)

formatted=[Evaluation(task='text-generation', dataset=Dataset(type='text-generation', name='garage-bAInd/Open-Platypus', config=None, split=None), metrics=[Metric(name='perplexity', value=17.03630661010742)], samples=None)] raw={'mean_perplexity': 17.03630661010742}

@dbogunowicz requested review from Satrat and bfineran on May 28, 2024 at 11:21

@Satrat left a comment


Nice work tracking these issues down!

@bfineran merged commit 451d838 into main on May 28, 2024
16 of 17 checks passed
@bfineran deleted the damian/fix_fsdp branch on May 28, 2024 at 18:24