Add: Support for Sparse24Bitmask Compressed Models #12097
base: main
Conversation
👋 Hi! Thank you for contributing to the vLLM project. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.
Signed-off-by: Rahul Tuli <[email protected]>

Force-pushed from ab892d2 to 02ff821
Add a test file with an 8B 2of4 compressed model for lm_eval_harness in buildkite
@@ -481,6 +495,19 @@ def supports_cutlass_24(

    return weight_quant.num_bits == input_quant.num_bits == 8

def _get_model_compression_config(
seems like an unnecessary function break out
assert all(
    partition_size % 8 == 0
    for partition_size in output_partition_sizes
), "All partitions must be divisible by 8 for 2:4 compressed models"
maybe "for a 2:4 sparse compressed model"?
shape = BitMaskShapeParameter(data=torch.empty(
    2 * len(output_partition_sizes), 1, dtype=torch.uint64),
    weight_loader=weight_loader)
compressed = ModelWeightParameter(data=torch.empty(
nit: parameter name
new_tensor = tensor.view(-1, 4)
zero_counts = (new_tensor == 0).sum(dim=1)
return (zero_counts >= 2).all().item()

def _decompress_bitmask_compressed_weight(
docstring
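A self-contained version of the 2:4 validity check excerpted above, with a docstring, might look roughly like this (the function name is an assumption, since the excerpt only shows the body):

import torch

def _is_valid_24_sparsity(tensor: torch.Tensor) -> bool:
    """Check whether ``tensor`` follows the 2:4 structured-sparsity pattern.

    The tensor is viewed as groups of 4 contiguous elements; the pattern
    holds if every group contains at least 2 zeros.
    """
    groups = tensor.view(-1, 4)
    zero_counts = (groups == 0).sum(dim=1)
    return bool((zero_counts >= 2).all().item())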
    for partition_size in output_partition_sizes
), "All partitions must be divisible by 8 for 2:4 compressed models"

shape = BitMaskShapeParameter(data=torch.empty(
We don't need to shard the shape
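For reference, the shape parameter in question is only a tiny per-partition metadata tensor (two entries per output partition, assumed here to be the rows/cols of each partition's compressed weight), which is why sharding it along the output dimension is unnecessary. A rough sketch under that assumption:

import torch

# Rough sketch of the shape metadata being discussed; the meaning of the two
# entries per partition (assumed here to be rows/cols of the compressed
# weight) is not shown in the excerpt.
output_partition_sizes = [4096, 4096, 1024]   # example values
shape = torch.empty(2 * len(output_partition_sizes), 1, dtype=torch.uint64)
print(shape.shape)  # torch.Size([6, 1]) -- tiny, per-partition metadata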
This PR adds support for models compressed using Sparse24BitMaskCompressor to use the CUTLASS 2:4 kernels, introducing a BitMaskShapeParameter for the bitmask shape metadata.
This diff was manually tested on the following checkpoints (a minimal loading sketch follows the list):
nm-testing/SparseLlama-3.1-8B-gsm8k-pruned.2of4-chnl_wts_per_tok_dyn_act_fp8-BitM
nm-testing/SparseLlama-3.1-8B-gsm8k-pruned.2of4-chnl_wts_per_tok_dyn_act_int8-BitM
nm-testing/SparseLlama-3.1-8B-gsm8k-pruned.2of4-tensor_wts_tensor_act_fp8-BitM
nm-testing/SparseLlama-3.1-8B-gsm8k-pruned.2of4-tensor_wts_tensor_act_int8-BitM
nm-testing/SparseLlama-3.1-8B-gsm8k-pruned.2of4-tensor_wts_per_tok_dyn_act_fp8-BitM
nm-testing/SparseLlama-3.1-8B-gsm8k-pruned.2of4-tensor_wts_per_tok_dyn_act_int8-BitM
nm-testing/SparseLlama-3.1-8B-gsm8k-pruned.2of4-chnl_wts_tensor_act_fp8-BitM
nm-testing/SparseLlama-3.1-8B-gsm8k-pruned.2of4-chnl_wts_tensor_act_int8-BitM
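As a quick sanity check, any of the checkpoints above should load through the standard vLLM entry point; a minimal sketch (model name taken from the list above, assuming the checkpoint's compressed-tensors config is enough for vLLM to select the 2:4 path without extra flags):

from vllm import LLM, SamplingParams

# Sketch: load one of the 2:4 bitmask-compressed checkpoints listed above.
# Assumes the compressed-tensors config in the checkpoint lets vLLM pick the
# CUTLASS 2:4 path automatically, as this PR intends.
llm = LLM(model="nm-testing/SparseLlama-3.1-8B-gsm8k-pruned.2of4-chnl_wts_per_tok_dyn_act_fp8-BitM")
outputs = llm.generate(["A quick sanity-check prompt:"],
                       SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)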
Also added unit tests for the compressed cases!!
Needs the following compressed-tensors PR to land:
Notion Doc: https://www.notion.so/SparseBitMask-24-work-15e863ebf65c80dcbc70e6317d552987