
feat(tensor): share int32 unpacking code
dacorvo committed Sep 20, 2024
1 parent 3a3465e commit 45a9501
Showing 2 changed files with 45 additions and 12 deletions.
15 changes: 3 additions & 12 deletions optimum/quanto/tensor/weights/awq/packed.py
@@ -20,6 +20,8 @@
 import torch
 from torch.utils import _pytree as pytree
 
+from ..packing import unpack_int32_to_uint8
+
 
 __all__ = ["AWQPackedTensor", "AWQPacking"]
 
@@ -89,20 +91,9 @@ def unpack(packed: torch.Tensor, reorder=False):
     Returns:
         An unpacked uint8 `torch.Tensor` expanded along the second dimension.
     """
-    bits = 4
-    shifts = torch.arange(0, 32, bits, device=packed.device)
-
-    # Unpack column-wise
-    unpacked = torch.bitwise_right_shift(packed[:, :, None], shifts[None, None, :]).to(
-        torch.int8  # smallest dtype available
-    )
-    unpacked = unpacked.view(unpacked.shape[0], -1)
+    unpacked = unpack_int32_to_uint8(packed, bits=4)
     if reorder:
         unpacked = reverse_awq_order(unpacked)
-
-    # Convert to unsigned
-    unpacked = torch.bitwise_and(unpacked, (2**bits) - 1)
-
     return unpacked


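For reference, the refactored reorder=False path should behave exactly like the removed inline logic. The sketch below is not part of the commit (it assumes an optimum-quanto install that already contains the new packing module) and simply checks that equivalence on a random packed tensor:

import torch

from optimum.quanto.tensor.weights.packing import unpack_int32_to_uint8

# Random packed weights: each int32 element holds eight 4-bit values.
packed = torch.randint(-2**31, 2**31 - 1, (4, 2), dtype=torch.int32)

# Previous inline logic (reorder=False path), reproduced from the removed lines.
bits = 4
shifts = torch.arange(0, 32, bits, device=packed.device)
old = torch.bitwise_right_shift(packed[:, :, None], shifts[None, None, :]).to(torch.int8)
old = old.view(old.shape[0], -1)
old = torch.bitwise_and(old, (2**bits) - 1)

# Shared helper introduced by this commit.
new = unpack_int32_to_uint8(packed, bits=4)

assert torch.equal(old.to(torch.uint8), new)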
42 changes: 42 additions & 0 deletions optimum/quanto/tensor/weights/packing.py
@@ -0,0 +1,42 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch


def unpack_int32_to_uint8(packed: torch.Tensor, bits: int):
    """Unpack a packed int32 tensor to a larger uint8 tensor

    Args:
        packed (`torch.Tensor`):
            The packed integer tensor
        bits (`int`):
            The number of bits of each packed value.

    Returns:
        An unpacked uint8 `torch.Tensor` expanded along the last dimension.
    """
    total_bits = 32
    shifts = torch.arange(0, total_bits, bits, device=packed.device)

    # Unpack column-wise
    unpacked = torch.bitwise_right_shift(packed[:, :, None], shifts[None, None, :]).to(
        torch.int8  # smallest dtype available
    )
    unpacked = unpacked.view(unpacked.shape[0], -1)

    # Convert to unsigned
    unpacked = torch.bitwise_and(unpacked, (2**bits) - 1)

    return unpacked.to(torch.uint8)
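
A quick usage example for the new helper (not part of the commit; the input values are made up for illustration). Each int32 word stores 32 // bits values, so a (rows, cols) int32 tensor unpacks to a (rows, cols * 32 // bits) uint8 tensor:

import torch

from optimum.quanto.tensor.weights.packing import unpack_int32_to_uint8

# Pack eight arbitrary 4-bit values into a single int32 word by hand.
# Keep the top nibble <= 7 so the packed word still fits in a signed int32.
nibbles = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 0]], dtype=torch.int64)
shifts = torch.arange(0, 32, 4)
packed = (nibbles << shifts).sum(dim=-1, keepdim=True).to(torch.int32)  # shape (1, 1)

unpacked = unpack_int32_to_uint8(packed, bits=4)
print(unpacked.shape)  # torch.Size([1, 8]) -> 32 // 4 values per int32 word
print(unpacked)        # tensor([[1, 2, 3, 4, 5, 6, 7, 0]], dtype=torch.uint8)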
