
Commit

update
qiaoyu1002 committed Dec 19, 2023
1 parent a78c310 commit 0e44119
Showing 6 changed files with 112 additions and 200 deletions.
55 changes: 31 additions & 24 deletions MobileSAMv2/Inference.py
@@ -5,25 +5,34 @@
import cv2
import os
import sys
-from mobilesamv2.promt_mobilesamv2 import PromptModel
+from mobilesamv2.promt_mobilesamv2 import ObjectAwareModel
from mobilesamv2 import sam_model_registry, SamPredictor
from typing import Any, Dict, Generator,List
import matplotlib.pyplot as plt
import numpy as np
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("--yolo_model_path", type=str, default="./PromtModel.pt", help="yolo_model_path")
parser.add_argument("--model_path", type=str, default="./", help="model")
parser.add_argument("--sam_path", type=str, default="sam_vit_h_4b8939.pth", help="sam_path")
parser.add_argument("--ObjectAwareModel_path", type=str, default='./PromptGuidedDecoder/ObjectAwareModel.pt', help="ObjectAwareModel path")
parser.add_argument("--Prompt_guided_Mask_Decoder_path", type=str, default='./PromptGuidedDecoder/Prompt_guided_Mask_Decoder.pt', help="Prompt_guided_Mask_Decoder path")
parser.add_argument("--encoder_path", type=str, default="./", help="select your own path")
parser.add_argument("--img_path", type=str, default="./test_images/", help="path to image file")
parser.add_argument("--imgsz", type=int, default=1024, help="image size")
parser.add_argument("--iou",type=float,default=0.9,help="yolo iou")
parser.add_argument("--conf", type=float, default=0.4, help="yolo object confidence threshold")
parser.add_argument("--retina",type=bool,default=True,help="draw segmentation masks",)
parser.add_argument("--output_dir", type=str, default="./", help="image save path")
parser.add_argument("--model_type", choices=['tiny_vit','vit_h','mobile_sam','efficientvit_l2','efficientvit_l1','efficientvit_l0'], help="choose the model type")
parser.add_argument("--encoder_type", choices=['tiny_vit','sam_vit_h','mobile_sam','efficientvit_l2','efficientvit_l1','efficientvit_l0'], help="choose the model type")
return parser.parse_args()

+def create_model():
+    Prompt_guided_path='./PromptGuidedDecoder/Prompt_guided_Mask_Decoder.pt'
+    obj_model_path='./weight/ObjectAwareModel.pt'
+    ObjAwareModel = ObjectAwareModel(obj_model_path)
+    PromptGuidedDecoder=sam_model_registry['PromptGuidedDecoder'](Prompt_guided_path)
+    mobilesamv2 = sam_model_registry['vit_h']()
+    mobilesamv2.prompt_encoder=PromptGuidedDecoder['PromtEncoder']
+    mobilesamv2.mask_decoder=PromptGuidedDecoder['MaskDecoder']
+    return mobilesamv2,ObjAwareModel

def show_anns(anns):
if len(anns) == 0:
return
@@ -46,47 +55,45 @@ def batch_iterator(batch_size: int, *args) -> Generator[List[Any], None, None]:
for b in range(n_batches):
yield [arg[b * batch_size : (b + 1) * batch_size] for arg in args]

-model_path={'efficientvit_l0':'./weight/l0.pt',
-            'efficientvit_l1':'./weight/l1.pt',
-            'efficientvit_l2':'./weight/l2.pt',
+encoder_path={'efficientvit_l2':'./weight/l2.pt',
             'tiny_vit':'./weight/mobile_sam.pt',
-            'vit_h':'./weight/sam_vit_h_4b8939.pth'
-            }
+            'sam_vit_h':'./weight/sam_vit_h.pt',}

def main(args):
-    # import pdb;pdb.set_trace()
-    promtmodel = PromptModel(args.yolo_model_path)
-    output_dir=args.output_dir
+    output_dir=args.output_dir
+    mobilesamv2, ObjAwareModel=create_model()
+    image_encoder=sam_model_registry[args.encoder_type](encoder_path[args.encoder_type])
+    mobilesamv2.image_encoder=image_encoder
device = "cuda" if torch.cuda.is_available() else "cpu"
-    sam = sam_model_registry[args.model_type](model_path[args.model_type])
-    sam.to(device=device)
-    sam.eval()
-    predictor = SamPredictor(sam)
+    mobilesamv2.to(device=device)
+    mobilesamv2.eval()
+    predictor = SamPredictor(mobilesamv2)
image_files= os.listdir(args.img_path)
for image_name in image_files:
print(image_name)
image = cv2.imread(args.img_path + image_name)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
-        everything_results = promtmodel(image,device=device,retina_masks=args.retina,imgsz=args.imgsz,conf=args.conf,iou=args.iou)
+        obj_results = ObjAwareModel(image,device=device,retina_masks=args.retina,imgsz=args.imgsz,conf=args.conf,iou=args.iou)
predictor.set_image(image)
-        input_boxes1 = everything_results[0].boxes.xyxy
+        input_boxes1 = obj_results[0].boxes.xyxy
input_boxes = input_boxes1.cpu().numpy()
input_boxes = predictor.transform.apply_boxes(input_boxes, predictor.original_size)
input_boxes = torch.from_numpy(input_boxes).cuda()
sam_mask=[]
image_embedding=predictor.features
image_embedding=torch.repeat_interleave(image_embedding, 320, dim=0)
-        prompt_embedding=sam.prompt_encoder.get_dense_pe()
+        prompt_embedding=mobilesamv2.prompt_encoder.get_dense_pe()
prompt_embedding=torch.repeat_interleave(prompt_embedding, 320, dim=0)
for (boxes,) in batch_iterator(320, input_boxes):
with torch.no_grad():
image_embedding=image_embedding[0:boxes.shape[0],:,:,:]
prompt_embedding=prompt_embedding[0:boxes.shape[0],:,:,:]
-                sparse_embeddings, dense_embeddings = sam.prompt_encoder(
+                sparse_embeddings, dense_embeddings = mobilesamv2.prompt_encoder(
points=None,
boxes=boxes,
masks=None,)
-                low_res_masks, _ = sam.mask_decoder(
+                low_res_masks, _ = mobilesamv2.mask_decoder(
image_embeddings=image_embedding,
image_pe=prompt_embedding,
sparse_prompt_embeddings=sparse_embeddings,
@@ -95,10 +102,9 @@ def main(args):
simple_type=True,
)
low_res_masks=predictor.model.postprocess_masks(low_res_masks, predictor.input_size, predictor.original_size)
sam_mask_pre = (low_res_masks > sam.mask_threshold)*1.0
sam_mask_pre = (low_res_masks > mobilesamv2.mask_threshold)*1.0
sam_mask.append(sam_mask_pre.squeeze(1))
sam_mask=torch.cat(sam_mask)
-        everything_results[0].masks.data=sam_mask
annotation = sam_mask
areas = torch.sum(annotation, dim=(1, 2))
sorted_indices = torch.argsort(areas, descending=True)
@@ -110,6 +116,7 @@
plt.axis('off')
plt.show()
plt.savefig("{}".format(output_dir+image_name), bbox_inches='tight', pad_inches = 0.0)

if __name__ == "__main__":
args = parse_args()
main(args)
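
For reference, the decoding loop above feeds at most 320 box prompts per step and truncates the pre-repeated image and prompt embeddings to boxes.shape[0], since the final batch may be smaller than 320. A self-contained sketch of that chunking (the for/yield body is the one shown in the hunk; the surrounding batch-count logic follows the standard segment-anything helper of the same name, and the toy data is illustrative):

from typing import Any, Generator, List

def batch_iterator(batch_size: int, *args) -> Generator[List[Any], None, None]:
    # Every sequence passed in must share the same length.
    assert len(args) > 0 and all(len(a) == len(args[0]) for a in args)
    n_batches = len(args[0]) // batch_size + int(len(args[0]) % batch_size != 0)
    for b in range(n_batches):
        yield [arg[b * batch_size : (b + 1) * batch_size] for arg in args]

boxes = list(range(700))  # stand-in for 700 detected box prompts
for (chunk,) in batch_iterator(320, boxes):
    print(len(chunk))  # 320, 320, 60: the last chunk is shorter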
Binary file not shown.
3 changes: 1 addition & 2 deletions MobileSAMv2/experiments/mobilesamv2.sh
@@ -1,5 +1,4 @@
CUDA_VISIBLE_DEVICES=0 python Inference.py \
--img_path './test_images/' \
---yolo_model_path './weight/PromtModel.pt' \
--output_dir './' \
---model_type 'efficientvit_l2' \
+--encoder_type 'efficientvit_l2' \
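
Note that the script no longer passes --yolo_model_path: the object-aware detector and the prompt-guided decoder are now built inside create_model, so the launch script only picks the interchangeable image encoder. A minimal sketch of the wiring this enables, assuming the mobilesamv2 package from this repo and the checkpoints named in encoder_path are present locally:

import torch
from mobilesamv2 import sam_model_registry
from mobilesamv2.promt_mobilesamv2 import ObjectAwareModel

# Fixed halves of the pipeline, loaded as in create_model above.
obj_aware_model = ObjectAwareModel('./weight/ObjectAwareModel.pt')
decoder = sam_model_registry['PromptGuidedDecoder'](
    './PromptGuidedDecoder/Prompt_guided_Mask_Decoder.pt')

mobilesamv2 = sam_model_registry['vit_h']()
mobilesamv2.prompt_encoder = decoder['PromtEncoder']  # key spelled as in the diff
mobilesamv2.mask_decoder = decoder['MaskDecoder']

# Swappable half: any --encoder_type entry from encoder_path plugs in here.
mobilesamv2.image_encoder = sam_model_registry['efficientvit_l2']('./weight/l2.pt')

device = 'cuda' if torch.cuda.is_available() else 'cpu'
mobilesamv2.to(device).eval()

Keeping one shared decoder half across all encoder choices is why only --encoder_type survives in the script.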