Skip to content

Commit

Permalink
update tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Sara Adkins committed May 29, 2024
1 parent 0f1a839 commit bb625d4
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions tests/sparseml/transformers/compression/test_fp8.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ class TestQuantizationMatches(unittest.TestCase):
dataset = "ultrachat-200k"
output = "tiny_llama_out"
max_seq_length = 512
weight_dtype = torch.bfloat16
weight_dtype = torch.float16
num_eval = 64

@classmethod
Expand Down Expand Up @@ -127,7 +127,7 @@ def test_quantization_reload(self):
n_scale, n_zp, n_weight = reloaded_weights[name]
assert o_scale.dtype == n_scale.dtype == self.weight_dtype
assert torch.equal(o_scale, n_scale)
assert o_zp.dtype == n_zp.dtype == self.weight_dtype
assert o_zp.dtype == n_zp.dtype == torch.float8_e4m3fn
assert torch.equal(o_zp, n_zp)

# we don't expect an exact match here because o_weight still has the
Expand All @@ -138,7 +138,7 @@ def test_quantization_reload(self):
n_scale, n_zp = reloaded_inputs[name]
assert o_scale.dtype == n_scale.dtype == self.weight_dtype
assert torch.equal(o_scale, n_scale)
assert o_zp.dtype == n_zp.dtype == self.weight_dtype
assert o_zp.dtype == n_zp.dtype == torch.float8_e4m3fn
assert torch.equal(o_zp, n_zp)

def _get_dataloader(self, data_args, tokenizer):
Expand Down

0 comments on commit bb625d4

Please sign in to comment.