Skip to content

Commit

Permalink
Fix: mypy errors
Browse files Browse the repository at this point in the history
  • Loading branch information
rahul-tuli committed Jan 15, 2025
1 parent 4e65cde commit ab892d2
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -450,9 +450,13 @@ def supports_cutlass_24(
:return: True if the layer is supported by the Cutlass 2:4 Kernel
False otherwise
"""
is_valid_sparsity_structure = (sparsity_scheme is not None
and sparsity_scheme.sparsity_structure
== SparsityStructure.TWO_FOUR.value)
if sparsity_scheme is None:
return False

is_valid_sparsity_structure: bool = (
sparsity_scheme.sparsity_structure ==
SparsityStructure.TWO_FOUR.value)

valid_compressors = {
CompressionFormat.dense.value,
CompressionFormat.sparse_24_bitmask.value
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Callable, Dict, List, Optional
from typing import Any, Callable, Dict, List, Optional, Tuple

import torch
from compressed_tensors import CompressionFormat, ModelCompressor
Expand Down Expand Up @@ -268,17 +268,17 @@ def _process_split(bitmask_compressed_weight: torch.Tensor, shape,
)
return sparsity_compressor.decompress_weight(weight_data)

split_weights = None
split_bitmask = None
split_shape = None
split_weights: List[torch.Tensor] = []
split_bitmask: List[torch.Tensor] = []
split_shape: List[Tuple[int, int]] = []

if isinstance(layer, (QKVParallelLinear, MergedColumnParallelLinear)):
split_weights = torch.split(compressed, layer.logical_widths)
split_bitmask = torch.split(bitmask, layer.logical_widths)
split_shape = [(out, layer.input_size_per_partition)
for out in layer.logical_widths]

if split_weights is not None:
if split_weights:
decompressed_shards = [
_process_split(compressed_weight, shape, bitmask)
for compressed_weight, shape, bitmask in zip(
Expand Down

0 comments on commit ab892d2

Please sign in to comment.