Skip to content

Commit

Permalink
Convert input for quantized YOLOv8 (#1521) (#1628)
Browse files Browse the repository at this point in the history
* Added call to function that skips the quantization of the input if the model is quantized

* Removed bare exception

* Move input data type conversion upstream so it is properly saved and used as input to ort

* Style and quality fixes

---------

Co-authored-by: Alexandre Marques <[email protected]>
Co-authored-by: Michael Goin <[email protected]>
  • Loading branch information
3 people authored Jun 15, 2023
1 parent 9f7606c commit 1995747
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 8 deletions.
9 changes: 8 additions & 1 deletion src/sparseml/yolov8/trainers.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@

from sparseml.optim.helpers import load_recipe_yaml_str
from sparseml.pytorch.optim.manager import ScheduledModifierManager
from sparseml.pytorch.sparsification.quantization import skip_onnx_input_quantize
from sparseml.pytorch.utils import ModuleExporter
from sparseml.pytorch.utils.helpers import download_framework_model_by_recipe_type
from sparseml.pytorch.utils.logger import LoggerManager, PythonLogger, WANDBLogger
Expand Down Expand Up @@ -729,7 +730,13 @@ def export(self, **kwargs):
else ["output0"],
)

onnx.checker.check_model(os.path.join(save_dir, name))
complete_path = os.path.join(save_dir, name)
try:
skip_onnx_input_quantize(complete_path, complete_path)
except Exception:
pass

onnx.checker.check_model(complete_path)
deployment_folder = exporter.create_deployment_folder(onnx_model_name=name)
if args["export_samples"]:
trainer_config = get_cfg(cfg=DEFAULT_SPARSEML_CONFIG_PATH)
Expand Down
15 changes: 8 additions & 7 deletions src/sparseml/yolov8/utils/export_samples.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,10 +106,15 @@ def export_sample_inputs_outputs(
preprocessed_batch = preprocess(batch=batch, device=device)
image = preprocessed_batch["img"]

# Save inputs as numpy array
_export_inputs(image, sample_in_dir, file_idx, save_inputs_as_uint8)
# Save torch outputs as numpy array
_export_torch_outputs(image, model, sample_out_dir_torch, file_idx)

# Convert input data type if needed
if save_inputs_as_uint8:
image = (255 * image).to(dtype=torch.uint8)

# Save inputs as numpy array
_export_inputs(image, sample_in_dir, file_idx)
# Save onnxruntime outputs as numpy array
_export_ort_outputs(
image.cpu().numpy(), ort_session, sample_out_dir_ort, file_idx
Expand Down Expand Up @@ -166,13 +171,9 @@ def _export_ort_outputs(
numpy.savez(sample_output_filename, preds)


def _export_inputs(
image: torch.Tensor, sample_in_dir: str, file_idx: str, save_inputs_as_uint8: bool
):
def _export_inputs(image: torch.Tensor, sample_in_dir: str, file_idx: str):

sample_in = image.detach().to("cpu")
if save_inputs_as_uint8:
sample_in = (255 * sample_in).to(dtype=torch.uint8)

sample_input_filename = os.path.join(sample_in_dir, f"inp-{file_idx}.npz")
numpy.savez(sample_input_filename, sample_in)
Expand Down

0 comments on commit 1995747

Please sign in to comment.