Skip to content

Commit

Permalink
test: use QTensor.equal
Browse files Browse the repository at this point in the history
  • Loading branch information
dacorvo committed Aug 24, 2024
1 parent 634a980 commit 7e73673
Show file tree
Hide file tree
Showing 3 changed files with 3 additions and 7 deletions.
3 changes: 1 addition & 2 deletions test/models/test_quantized_model_for_causal_lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,7 @@ def compare_models(a_model, b_model):
if isinstance(b_m, QModuleMixin):
assert isinstance(a_m, QModuleMixin)
if isinstance(a_m, QModuleMixin):
assert torch.equal(a_m.weight._data, b_m.weight._data)
assert torch.equal(a_m.weight._scale, b_m.weight._scale)
assert torch.equal(a_m.weight, b_m.weight)
for (a_p_name, a_p), (b_p_name, b_p) in zip(a_m.named_parameters(), b_m.named_parameters()):
assert a_p_name == b_p_name
assert isinstance(a_p, torch.Tensor)
Expand Down
3 changes: 1 addition & 2 deletions test/models/test_quantized_model_for_pixart.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,7 @@ def compare_models(a_model, b_model):
if isinstance(b_m, QModuleMixin):
assert isinstance(a_m, QModuleMixin)
if isinstance(a_m, QModuleMixin):
assert torch.equal(a_m.weight._data, b_m.weight._data)
assert torch.equal(a_m.weight._scale, b_m.weight._scale)
assert torch.equal(a_m.weight, b_m.weight)
for (a_p_name, a_p), (b_p_name, b_p) in zip(a_m.named_parameters(), b_m.named_parameters()):
assert a_p_name == b_p_name
assert isinstance(a_p, torch.Tensor)
Expand Down
4 changes: 1 addition & 3 deletions test/quantize/test_requantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,11 +64,9 @@ def test_requantize_serialized_model(
for name, module in model.named_modules():
if isinstance(module, QModuleMixin):
module_reloaded = getattr(model_reloaded, name)
assert module_reloaded.weight.qtype == module.weight.qtype
assert torch.equal(module_reloaded.weight, module.weight)
assert module_reloaded.weight_qtype == module.weight_qtype
assert module_reloaded.activation_qtype == module.activation_qtype
assert torch.equal(module_reloaded.weight._data, module.weight._data)
assert torch.equal(module_reloaded.weight._scale, module.weight._scale)
assert torch.equal(module_reloaded.input_scale, module.input_scale)
assert torch.equal(module_reloaded.output_scale, module.output_scale)

Expand Down

0 comments on commit 7e73673

Please sign in to comment.