Skip to content

Commit

Permalink
test(tinygemm): add linear test
Browse files Browse the repository at this point in the history
  • Loading branch information
dacorvo committed Sep 26, 2024
1 parent 6c0c4f8 commit bd74833
Showing 1 changed file with 33 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

import pytest
import torch
from helpers import assert_similar, device_eq, random_qweight
from helpers import assert_similar, device_eq, random_qweight, random_tensor
from packaging import version

from optimum.quanto import qint4
Expand Down Expand Up @@ -89,3 +89,35 @@ def test_tinygemm_weight_qbits_tensor_move(device):
assert tgqbt.qtype == tgqbt_cpu.qtype
assert tgqbt.shape == tgqbt_cpu.shape
assert torch.equal(tgqbt.dequantize().cpu(), tgqbt_cpu.dequantize())


@pytest.mark.skip_device("mps")  # Only available with pytorch 2.4
@pytest.mark.parametrize("batch_size", [1, 2])
@pytest.mark.parametrize("tokens", [256, 512])
@pytest.mark.parametrize("embeddings", [256, 512, 1024, 4096])
@pytest.mark.parametrize("use_bias", [True, False], ids=["bias", "no-bias"])
def test_tinygemm_weight_qbits_tensor_linear(batch_size, tokens, embeddings, use_bias, device):
    """Compare a linear op on TinyGemm-packed weights with the dequantized reference.

    A random int4 quantized weight is repacked into a
    ``TinyGemmWeightQBitsTensor``; ``torch.nn.functional.linear`` is then run
    once with the packed weight and once with the dequantized source weight,
    and both outputs are compared with ``assert_similar``.

    On CUDA devices the test is skipped when the CUDA runtime is older than
    12.1 or when the device major compute capability is below 8.
    """
    running_on_cuda = device.type == "cuda"
    # pytest.skip raises, so the second guard is only reached when the first passes.
    if running_on_cuda and version.parse(torch.version.cuda).release < (12, 1):
        pytest.skip(reason="CUDA runtime must be at least 12.1")
    if running_on_cuda and torch.cuda.get_device_capability()[0] < 8:
        pytest.skip(reason="CUDA device >= sm80 not available")

    weight_qtype = qint4
    compute_dtype = torch.bfloat16
    quant_group_size = 128

    # Random tensors are drawn in the order: inputs, weight, then bias.
    input_shape = (batch_size, tokens, embeddings)
    inputs = torch.rand(input_shape, dtype=compute_dtype, device=device)

    # The weight is (tokens, embeddings): `embeddings` is the input feature
    # dimension and `tokens` doubles as the output feature dimension.
    reference_qweight = random_qweight(
        (tokens, embeddings),
        weight_qtype,
        compute_dtype,
        group_size=quant_group_size,
        device=device,
    )

    # Create a TinyGemmWeightQBitsTensor from a QBitsTensor
    packed_qweight = TinyGemmWeightQBitsTensor(
        qtype=reference_qweight.qtype,
        axis=reference_qweight.axis,
        group_size=reference_qweight._group_size,
        size=reference_qweight.size(),
        stride=reference_qweight.stride(),
        data=reference_qweight._data.unpack(),
        scale_shift=(reference_qweight._scale, reference_qweight._shift),
    )

    if use_bias:
        bias = random_tensor((tokens,), dtype=compute_dtype).to(device)
    else:
        bias = None

    linear = torch.nn.functional.linear
    packed_output = linear(inputs, packed_qweight, bias)
    reference_output = linear(inputs, reference_qweight.dequantize(), bias)
    assert_similar(reference_output, packed_output)

0 comments on commit bd74833

Please sign in to comment.