When quantizing SSDLiteMobilenetV3 using MCT, the following error occurs repeatedly:
"ERROR: Model Compression Toolkit: Found duplicate qco types!"
Specifically, it seems to occur in the following function:
"mct.ptq.pytorch_post_training_quantization"
However, despite the error, quantization succeeds, the mAP value is as expected, and the completed quantized model does not seem to be affected.
The script that produced the error is a Python script rewritten from the following notebook:
https://github.com/sony/model_optimization/blob/main/tutorials/notebooks/mct_features_notebooks/pytorch/example_pytorch_ssdlite_mobilenetv3_object_detection.ipynb
The operating environment for the script is as follows:
・GPU used
・model-compression-toolkit 2.1.0
・pytorch 2.0.1
Expected behaviour
The message does not seem to affect the quantized model, but it would be better if no error were reported.
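As a possible workaround until the root cause is addressed, the message can be filtered out on the caller's side. This is only a sketch and assumes MCT emits the message through Python's standard logging machinery; if the message is printed directly to stderr, this filter will have no effect:

import logging

class DropDuplicateQcoMessage(logging.Filter):
    # Suppress only the spurious "duplicate qco types" record; keep everything else
    def filter(self, record):
        return 'Found duplicate qco types' not in record.getMessage()

# Attach the filter to the root handlers before calling MCT
for handler in logging.getLogger().handlers:
    handler.addFilter(DropDuplicateQcoMessage())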
Code to reproduce the issue
# Sorry, it's a bit long, but here is the reproducible script.
# The error log is output in the following "mct.ptq.pytorch_post_training_quantization".
# -----
import torch
import torchvision
from torchvision.models.detection.ssdlite import SSDLite320_MobileNet_V3_Large_Weights
from torchvision.models.detection.anchor_utils import ImageList
import model_compression_toolkit as mct
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval

device = 'cuda' if torch.cuda.is_available() else 'cpu'
image_size = (320, 320)

model = torchvision.models.detection.ssdlite320_mobilenet_v3_large(
    weights=SSDLite320_MobileNet_V3_Large_Weights.DEFAULT)
# mAP=0.2131 (float)
# mAP=0.2007 (quantized)
model.eval()
model = model.to(device)
print('device : %s' % (device))
print('model loaded')


def format_results(outputs, img_ids):
    detections = []
    # Process model outputs and convert to detection format
    for idx, output in enumerate(outputs):
        image_id = img_ids[idx]  # Adjust according to your batch size and indexing
        scores = output['scores'].cpu().numpy()
        labels = output['labels'].cpu().numpy()
        boxes = output['boxes'].cpu().numpy()
        for score, label, box in zip(scores, labels, boxes):
            detection = {
                "image_id": image_id,
                "category_id": label,
                "bbox": [box[0], box[1], box[2] - box[0], box[3] - box[1]],
                "score": score
            }
            detections.append(detection)
    return detections


class CocoEval:
    def __init__(self, path2json):
        # Load ground truth annotations
        self.coco_gt = COCO(path2json)
        # A list of reformatted model outputs
        self.all_detections = []

    def add_batch_detections(self, outputs, targets):
        # Collect and format results from the batch
        img_ids, _outs = [], []
        for t, o in zip(targets, outputs):
            if len(t) > 0:
                img_ids.append(t[0]['image_id'])
                _outs.append(o)
        batch_detections = format_results(_outs, img_ids)  # Implement this function
        self.all_detections.extend(batch_detections)

    def result(self):
        # Initialize COCO evaluation object
        self.coco_dt = self.coco_gt.loadRes(self.all_detections)
        coco_eval = COCOeval(self.coco_gt, self.coco_dt, 'bbox')
        # Run evaluation
        coco_eval.evaluate()
        coco_eval.accumulate()
        coco_eval.summarize()
        # Print mAP results
        print("mAP: {:.4f}".format(coco_eval.stats[0]))
        return coco_eval.stats

    def reset(self):
        self.all_detections = []


EVAL_DATASET_FOLDER = './coco/val2017'
EVAL_DATASET_ANNOTATION_FILE = './coco/annotations/instances_val2017.json'


def collate_fn(batch_input):
    images = [b[0] for b in batch_input]
    targets = [b[1] for b in batch_input]
    return images, targets


# Initialize the COCO evaluation DataLoader
coco_eval = torchvision.datasets.CocoDetection(root=EVAL_DATASET_FOLDER,
                                               annFile=EVAL_DATASET_ANNOTATION_FILE,
                                               transform=torchvision.transforms.ToTensor())
batch_size = 50
data_loader = torch.utils.data.DataLoader(coco_eval, batch_size=batch_size, shuffle=False,
                                          num_workers=0, collate_fn=collate_fn)

# Initialize the evaluation metric object
coco_metric = CocoEval(EVAL_DATASET_ANNOTATION_FILE)
class SDD4Quant(torch.nn.Module):
    def __init__(self, in_sdd, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Save the float model under self.base as a module of the model.
        # Later we'll only run "backbone" & "head".
        self.add_module("base", in_sdd)

    # Forward pass of the model to be quantized. This code is copied from the float
    # model forward function (removed the preprocess and postprocess code).
    def forward(self, x):
        features = self.base.backbone(x)
        features = list(features.values())
        # Compute the SSD head outputs using the features
        head_outputs = self.base.head(features)
        return head_outputs


model4quant = SDD4Quant(model)


def preprocess(image, targets):
    # Need to save the original image sizes before resize for the postprocess part
    targets = {'gt': targets, 'img_size': list(image.size[::-1])}
    image = model.transform([torchvision.transforms.ToTensor()(image)])[0].tensors[0, ...]
    return image, targets


# Define the postprocess, which is the code copied from the float model forward code.
# These layers will not be quantized.
class PostProcess:
    def __init__(self):
        self.features = [torch.zeros((1, 1, s, s)) for s in [20, 10, 5, 3, 2, 1]]

    def __call__(self, head_outputs, image_list, original_image_sizes):
        anchors = [a.to(device) for a in model.anchor_generator(image_list, self.features)]
        # MCT flattens the outputs of the head to a list, so it needs to be converted
        # back to a dictionary, as the postprocess functions expect.
        if not isinstance(head_outputs, dict):
            if head_outputs[0].shape[-1] == 4:
                head_outputs = {"bbox_regression": head_outputs[0],
                                "cls_logits": head_outputs[1]}
            else:
                head_outputs = {"bbox_regression": head_outputs[1],
                                "cls_logits": head_outputs[0]}
        # Float model postprocess functions that handle box regression and NMS
        detections = model.postprocess_detections(head_outputs, anchors, image_list.image_sizes)
        detections = model.transform.postprocess(detections, image_list.image_sizes,
                                                 original_image_sizes)
        return detections


postprocess = PostProcess()


def train_collate_fn(batch_input):
    # Collating images for the quantized model should return a single tensor: [B, C, H, W]
    images = torch.stack([b[0] for b in batch_input])
    targets = [b[1] for b in batch_input]
    return images, targets


coco_eval = torchvision.datasets.CocoDetection(root=EVAL_DATASET_FOLDER,
                                               annFile=EVAL_DATASET_ANNOTATION_FILE,
                                               transforms=preprocess)
eval_loader = torch.utils.data.DataLoader(coco_eval, batch_size=50, shuffle=False, num_workers=0,
                                          collate_fn=train_collate_fn)


def get_representative_dataset(n_iter):
    def representative_dataset():
        ds_iter = iter(eval_loader)
        for _ in range(n_iter):
            yield [next(ds_iter)[0]]
    return representative_dataset


# Get representative dataset generator
representative_dataset_gen = get_representative_dataset(20)

quant_model, _ = mct.ptq.pytorch_post_training_quantization(model4quant,
                                                            representative_dataset_gen)
print('quant_model is Ready')
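For completeness, here is a minimal sketch of how the quantized model, the PostProcess wrapper, and the CocoEval metric defined above can be tied together to reproduce the reported mAP. This loop is an assumption based on the original notebook's evaluation code and is not part of the failing script:

with torch.no_grad():
    for images, targets in eval_loader:
        # preprocess() already resized the images, so wrap the batch in an
        # ImageList with the resized size and keep the original sizes around
        # for the final postprocess step.
        image_list = ImageList(images.to(device), [image_size] * images.shape[0])
        original_sizes = [t['img_size'] for t in targets]
        head_outputs = quant_model(image_list.tensors)
        detections = postprocess(head_outputs, image_list, original_sizes)
        coco_metric.add_batch_detections(detections, [t['gt'] for t in targets])
coco_metric.result()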
Issue Type
Others
Source
pip (model-compression-toolkit)
MCT Version
2.1.0
OS Platform and Distribution
Ubuntu 22.04.4 LTS
Python version
3.9.19
Log output
ssd_mobilenetv3_mct_script_log.txt
The error appears at around line 50 of the log.
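For context, "qco" in the message presumably refers to MCT's QuantizationConfigOptions, which the target platform capabilities (TPC) attach to operator sets. One way to check whether the message comes from the default TPC rather than from this particular model is to pass the TPC explicitly; target_platform_capabilities is a documented parameter of the PTQ API, while the 'default' platform name used below is an assumption:

import model_compression_toolkit as mct

# Explicitly fetch the default PyTorch TPC and hand it to PTQ
tpc = mct.get_target_platform_capabilities('pytorch', 'default')
quant_model, _ = mct.ptq.pytorch_post_training_quantization(
    model4quant,
    representative_dataset_gen,
    target_platform_capabilities=tpc)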