
Commit

update mobilesamv2
qiaoyu1002 committed Dec 19, 2023
1 parent 5cb2298 commit 9258b4d
Showing 221 changed files with 36,989 additions and 1 deletion.
1 change: 0 additions & 1 deletion MobileSAMv2
Submodule MobileSAMv2 deleted from bddde3
115 changes: 115 additions & 0 deletions MobileSAMv2/Inference.py
@@ -0,0 +1,115 @@
import argparse
import ast
import torch
from PIL import Image
import cv2
import os
import sys
from mobilesamv2_segment_anything.promt_mobilesamv2 import PromptModel
from mobilesamv2_segment_anything import sam_model_registry, SamPredictor
from typing import Any, Dict, Generator, List
import matplotlib.pyplot as plt
import numpy as np


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--yolo_model_path", type=str, default="./PromtModel.pt", help="yolo_model_path")
    parser.add_argument("--model_path", type=str, default="./", help="model")
    parser.add_argument("--sam_path", type=str, default="sam_vit_h_4b8939.pth", help="sam_path")
    parser.add_argument("--img_path", type=str, default="./test_images/", help="path to image file")
    parser.add_argument("--imgsz", type=int, default=1024, help="image size")
    parser.add_argument("--iou", type=float, default=0.9, help="yolo iou")
    parser.add_argument("--conf", type=float, default=0.4, help="yolo object confidence threshold")
    parser.add_argument("--retina", type=bool, default=True, help="draw segmentation masks")
    parser.add_argument("--output_dir", type=str, default="./", help="image save path")
    parser.add_argument("--model_type", choices=['tiny_vit', 'vit_h', 'mobile_sam', 'efficientvit_l2', 'efficientvit_l1', 'efficientvit_l0'], help="choose the model type")
    return parser.parse_args()


def show_anns(anns):
    # Overlay every predicted mask on the current axes with a random color.
    if len(anns) == 0:
        return
    ax = plt.gca()
    ax.set_autoscale_on(False)
    img = np.ones((anns.shape[1], anns.shape[2], 4))
    img[:, :, 3] = 0
    for ann in range(anns.shape[0]):
        m = anns[ann].bool()
        m = m.cpu().numpy()
        color_mask = np.concatenate([np.random.random(3), [1]])
        img[m] = color_mask
    ax.imshow(img)


def batch_iterator(batch_size: int, *args) -> Generator[List[Any], None, None]:
    assert len(args) > 0 and all(
        len(a) == len(args[0]) for a in args
    ), "Batched iteration must have inputs of all the same size."
    n_batches = len(args[0]) // batch_size + int(len(args[0]) % batch_size != 0)
    for b in range(n_batches):
        yield [arg[b * batch_size : (b + 1) * batch_size] for arg in args]


# Checkpoint paths keyed by --model_type.
model_path = {
    'efficientvit_l0': './weight/l0.pt',
    'efficientvit_l1': './weight/l1.pt',
    'efficientvit_l2': './weight/l2.pt',
    'tiny_vit': './weight/mobile_sam.pt',
    'vit_h': './weight/sam_vit_h_4b8939.pth',
}


def main(args):
    # Object-aware prompt model (YOLO-based) that proposes boxes for every object in the image.
    promtmodel = PromptModel(args.yolo_model_path)
    output_dir = args.output_dir
    device = "cuda" if torch.cuda.is_available() else "cpu"
    sam = sam_model_registry[args.model_type](model_path[args.model_type])
    sam.to(device=device)
    sam.eval()
    predictor = SamPredictor(sam)
    image_files = os.listdir(args.img_path)
    for image_name in image_files:
        print(image_name)
        image = cv2.imread(args.img_path + image_name)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        everything_results = promtmodel(image, device=device, retina_masks=args.retina, imgsz=args.imgsz, conf=args.conf, iou=args.iou)
        predictor.set_image(image)
        input_boxes1 = everything_results[0].boxes.xyxy
        input_boxes = input_boxes1.cpu().numpy()
        input_boxes = predictor.transform.apply_boxes(input_boxes, predictor.original_size)
        input_boxes = torch.from_numpy(input_boxes).cuda()  # assumes CUDA is available
        sam_mask = []
        # Repeat the single image embedding and dense positional encoding so that
        # one mask-decoder forward pass can serve up to 320 box prompts at once.
        image_embedding = predictor.features
        image_embedding = torch.repeat_interleave(image_embedding, 320, dim=0)
        prompt_embedding = sam.prompt_encoder.get_dense_pe()
        prompt_embedding = torch.repeat_interleave(prompt_embedding, 320, dim=0)
        for (boxes,) in batch_iterator(320, input_boxes):
            with torch.no_grad():
                image_embedding = image_embedding[0:boxes.shape[0], :, :, :]
                prompt_embedding = prompt_embedding[0:boxes.shape[0], :, :, :]
                sparse_embeddings, dense_embeddings = sam.prompt_encoder(
                    points=None,
                    boxes=boxes,
                    masks=None,
                )
                low_res_masks, _ = sam.mask_decoder(
                    image_embeddings=image_embedding,
                    image_pe=prompt_embedding,
                    sparse_prompt_embeddings=sparse_embeddings,
                    dense_prompt_embeddings=dense_embeddings,
                    multimask_output=False,
                    simple_type=True,
                )
                low_res_masks = predictor.model.postprocess_masks(low_res_masks, predictor.input_size, predictor.original_size)
                sam_mask_pre = (low_res_masks > sam.mask_threshold) * 1.0
                sam_mask.append(sam_mask_pre.squeeze(1))
        sam_mask = torch.cat(sam_mask)
        everything_results[0].masks.data = sam_mask
        annotation = sam_mask
        # Draw larger masks first so small objects stay visible on top.
        areas = torch.sum(annotation, dim=(1, 2))
        sorted_indices = torch.argsort(areas, descending=True)
        show_img = annotation[sorted_indices]
        plt.figure(figsize=(20, 20))
        background = np.ones_like(image) * 255
        plt.imshow(background)
        show_anns(show_img)
        plt.axis('off')
        plt.show()
        plt.savefig("{}".format(output_dir + image_name), bbox_inches='tight', pad_inches=0.0)


if __name__ == "__main__":
    args = parse_args()
    main(args)
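Not part of the commit: a minimal, self-contained sketch of the batched decoding trick that main() relies on. The single image embedding produced by predictor.set_image is repeated along the batch dimension so one mask-decoder forward pass can serve many box prompts at once; the shapes below are illustrative only (SAM's real embedding is 1x256x64x64).

    import torch

    image_embedding = torch.randn(1, 256, 64, 64)     # one image -> one embedding
    num_boxes = 5                                      # e.g. boxes from the prompt model
    batched = torch.repeat_interleave(image_embedding, num_boxes, dim=0)
    assert batched.shape == (5, 256, 64, 64)           # one embedding copy per box prompt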
Empty file.
Empty file.
7 changes: 7 additions & 0 deletions MobileSAMv2/efficientvit/apps/data_provider/__init__.py
@@ -0,0 +1,7 @@
# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
# International Conference on Computer Vision (ICCV), 2023

from .augment import *
from .base import *
from .random_resolution import *
6 changes: 6 additions & 0 deletions MobileSAMv2/efficientvit/apps/data_provider/augment/__init__.py
@@ -0,0 +1,6 @@
# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
# International Conference on Computer Vision (ICCV), 2023

from .bbox import *
from .color_aug import *
30 changes: 30 additions & 0 deletions MobileSAMv2/efficientvit/apps/data_provider/augment/bbox.py
@@ -0,0 +1,30 @@
# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
# International Conference on Computer Vision (ICCV), 2023

import numpy as np

__all__ = ["rand_bbox"]


def rand_bbox(
    h: int,
    w: int,
    lam: float,
    rand_func: callable = np.random.uniform,
) -> tuple[int, int, int, int]:
    """randomly sample bbox, used in cutmix"""
    cut_rat = np.sqrt(1.0 - lam)
    cut_w = w * cut_rat
    cut_h = h * cut_rat

    # uniform
    cx = rand_func(0, w)
    cy = rand_func(0, h)

    bbx1 = int(np.clip(cx - cut_w / 2, 0, w))
    bby1 = int(np.clip(cy - cut_h / 2, 0, h))
    bbx2 = int(np.clip(cx + cut_w / 2, 0, w))
    bby2 = int(np.clip(cy + cut_h / 2, 0, h))

    return bbx1, bby1, bbx2, bby2
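Not part of the commit: a small usage sketch of rand_bbox in a CutMix-style mixing step, assuming the function above is in scope. A patch of one image is pasted into another and the mixing ratio lam is re-derived from the actual patch area; the beta(1.0, 1.0) sampling is an assumed CutMix hyperparameter choice.

    import numpy as np

    h, w = 224, 224
    lam = np.random.beta(1.0, 1.0)                    # assumed CutMix hyperparameters
    image_a = np.zeros((h, w, 3), dtype=np.float32)
    image_b = np.ones((h, w, 3), dtype=np.float32)

    x1, y1, x2, y2 = rand_bbox(h, w, lam)
    mixed = image_a.copy()
    mixed[y1:y2, x1:x2] = image_b[y1:y2, x1:x2]       # paste the sampled patch
    lam_adjusted = 1.0 - ((x2 - x1) * (y2 - y1)) / (h * w)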
78 changes: 78 additions & 0 deletions MobileSAMv2/efficientvit/apps/data_provider/augment/color_aug.py
@@ -0,0 +1,78 @@
# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
# International Conference on Computer Vision (ICCV), 2023

import numpy as np
import torchvision.transforms as transforms
from PIL import Image
from timm.data.auto_augment import rand_augment_transform

__all__ = ["ColorAug", "RandAug"]


class ImageAug:
    def aug_image(self, image: Image.Image) -> Image.Image:
        raise NotImplementedError

    def __call__(self, feed_dict: dict or np.ndarray or Image.Image) -> dict or np.ndarray or Image.Image:
        if isinstance(feed_dict, dict):
            output_dict = feed_dict
            image = feed_dict[self.key]
        else:
            output_dict = None
            image = feed_dict
        is_ndarray = isinstance(image, np.ndarray)
        if is_ndarray:
            image = Image.fromarray(image)

        image = self.aug_image(image)

        if is_ndarray:
            image = np.array(image)

        if output_dict is None:
            return image
        else:
            output_dict[self.key] = image
            return output_dict


class ColorAug(transforms.ColorJitter, ImageAug):
    def __init__(self, brightness=0, contrast=0, saturation=0, hue=0, key="data"):
        super().__init__(
            brightness=brightness,
            contrast=contrast,
            saturation=saturation,
            hue=hue,
        )
        self.key = key

    def aug_image(self, image: Image.Image) -> Image.Image:
        return transforms.ColorJitter.forward(self, image)

    def forward(self, feed_dict: dict or np.ndarray or Image.Image) -> dict or np.ndarray or Image.Image:
        return ImageAug.__call__(self, feed_dict)


class RandAug(ImageAug):
    def __init__(self, config: dict[str, any], mean: tuple[float, float, float], key="data"):
        n = config.get("n", 2)
        m = config.get("m", 9)
        mstd = config.get("mstd", 1.0)
        inc = config.get("inc", 1)
        tpct = config.get("tpct", 0.45)
        config_str = f"rand-n{n}-m{m}-mstd{mstd}-inc{inc}"

        aa_params = dict(
            translate_pct=tpct,
            img_mean=tuple([min(255, round(255 * x)) for x in mean]),
            interpolation=Image.BICUBIC,
        )
        self.aug_op = rand_augment_transform(config_str, aa_params)
        self.key = key

    def aug_image(self, image: Image.Image) -> Image.Image:
        return self.aug_op(image)

    def __repr__(self):
        return self.aug_op.__repr__()
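Not part of the commit: a usage sketch for the augmentations above, assuming timm is installed and the classes are importable. RandAug accepts either a bare PIL image or a feed_dict keyed by "data"; the mean used here is the ImageNet mean, chosen only for illustration.

    import numpy as np
    from PIL import Image

    aug = RandAug(config={"n": 2, "m": 9}, mean=(0.485, 0.456, 0.406))
    image = Image.fromarray(np.zeros((224, 224, 3), dtype=np.uint8))

    augmented_image = aug(image)             # plain image in, plain image out
    augmented_dict = aug({"data": image})    # dict in, dict out with "data" replaced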
