From 2bcab78024d71b0fcfb071414fd5b76ba2ca29f9 Mon Sep 17 00:00:00 2001 From: dsxailab Date: Tue, 28 Feb 2023 19:52:56 +0800 Subject: [PATCH 1/2] Performance improvement by caching the model --- scripts/ddetailer.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/scripts/ddetailer.py b/scripts/ddetailer.py index 7841d8e..dddd1de 100644 --- a/scripts/ddetailer.py +++ b/scripts/ddetailer.py @@ -14,6 +14,8 @@ from basicsr.utils.download_util import load_file_from_url dd_models_path = os.path.join(models_path, "mmdet") +segm_model = None +bbox_model = None def list_models(model_path): model_list = modelloader.load_models(model_path=model_path, ext_filter=[".pth"]) @@ -476,10 +478,15 @@ def inference(image, modelname, conf_thres, label): return results def inference_mmdet_segm(image, modelname, conf_thres, label): + global segm_model model_checkpoint = modelpath(modelname) model_config = os.path.splitext(model_checkpoint)[0] + ".py" model_device = get_device() - model = init_detector(model_config, model_checkpoint, device=model_device) + if segm_model is not None: + model = segm_model + else: + model = init_detector(model_config, model_checkpoint, device=model_device) + segm_model = model mmdet_results = inference_detector(model, np.array(image)) bbox_results, segm_results = mmdet_results dataset = modeldataset(modelname) @@ -504,10 +511,15 @@ def inference_mmdet_segm(image, modelname, conf_thres, label): return results def inference_mmdet_bbox(image, modelname, conf_thres, label): + global bbox_model model_checkpoint = modelpath(modelname) model_config = os.path.splitext(model_checkpoint)[0] + ".py" model_device = get_device() - model = init_detector(model_config, model_checkpoint, device=model_device) + if bbox_model is not None: + model = bbox_model + else: + model = init_detector(model_config, model_checkpoint, device=model_device) + bbox_model = model results = inference_detector(model, np.array(image)) cv2_image = np.array(image) cv2_image = cv2_image[:, :, ::-1].copy() From 39dddf4043e8621eeb1681754cddb66c7501cef8 Mon Sep 17 00:00:00 2001 From: dsxailab Date: Sun, 9 Apr 2023 01:28:31 +0800 Subject: [PATCH 2/2] Fix: mmcv version issue --- scripts/ddetailer.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/scripts/ddetailer.py b/scripts/ddetailer.py index dddd1de..a4dd271 100644 --- a/scripts/ddetailer.py +++ b/scripts/ddetailer.py @@ -47,9 +47,9 @@ def startup(): from launch import is_installed, run if not is_installed("mmdet"): python = sys.executable - run(f'"{python}" -m pip install -U openmim', desc="Installing openmim", errdesc="Couldn't install openmim") - run(f'"{python}" -m mim install mmcv-full', desc=f"Installing mmcv-full", errdesc=f"Couldn't install mmcv-full") - run(f'"{python}" -m pip install mmdet', desc=f"Installing mmdet", errdesc=f"Couldn't install mmdet") + run(f'"{python}" -m pip install --upgrade -U openmim==0.3.7', desc="Installing openmim", errdesc="Couldn't install openmim") + run(f'"{python}" -m mim install --upgrade mmcv-full==1.7.1', desc=f"Installing mmcv-full", errdesc=f"Couldn't install mmcv-full") + run(f'"{python}" -m pip install --upgrade mmdet==2.28.2', desc=f"Installing mmdet", errdesc=f"Couldn't install mmdet") if (len(list_models(dd_models_path)) == 0): print("No detection models found, downloading...")