diff --git a/tests/sparseml/transformers/compression/test_fp8.py b/tests/sparseml/transformers/compression/test_fp8.py index 8125afbabfe..aa98b2e39f5 100644 --- a/tests/sparseml/transformers/compression/test_fp8.py +++ b/tests/sparseml/transformers/compression/test_fp8.py @@ -49,7 +49,7 @@ class TestQuantizationMatches(unittest.TestCase): dataset = "ultrachat-200k" output = "tiny_llama_out" max_seq_length = 512 - weight_dtype = torch.bfloat16 + weight_dtype = torch.float16 num_eval = 64 @classmethod @@ -127,7 +127,7 @@ def test_quantization_reload(self): n_scale, n_zp, n_weight = reloaded_weights[name] assert o_scale.dtype == n_scale.dtype == self.weight_dtype assert torch.equal(o_scale, n_scale) - assert o_zp.dtype == n_zp.dtype == self.weight_dtype + assert o_zp.dtype == n_zp.dtype == torch.float8_e4m3fn assert torch.equal(o_zp, n_zp) # we don't expect an exact match here because o_weight still has the @@ -138,7 +138,7 @@ def test_quantization_reload(self): n_scale, n_zp = reloaded_inputs[name] assert o_scale.dtype == n_scale.dtype == self.weight_dtype assert torch.equal(o_scale, n_scale) - assert o_zp.dtype == n_zp.dtype == self.weight_dtype + assert o_zp.dtype == n_zp.dtype == torch.float8_e4m3fn assert torch.equal(o_zp, n_zp) def _get_dataloader(self, data_args, tokenizer):