diff --git a/test/quantize/test_requantize.py b/test/quantize/test_requantize.py index e1cd1709..3f0f20b3 100644 --- a/test/quantize/test_requantize.py +++ b/test/quantize/test_requantize.py @@ -44,7 +44,7 @@ def save_and_reload_state_dict(state_dict, serialization): ids=["small", "large"], ) @pytest.mark.parametrize("weights", [qint4, qint8], ids=["w-qint4", "w-qint8"]) -@pytest.mark.parametrize("dtype", [torch.float16, torch.float32], ids=["fp16", "fp32"]) +@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32], ids=["bf16", "fp16", "fp32"]) @pytest.mark.parametrize("serialization", ["weights_only", "pickle", "safetensors"]) @pytest.mark.parametrize("activations", [None, qint8], ids=["a-none", "a-qint8"]) def test_requantize_serialized_model(