Commit
1 parent 5cb2298 · commit 9258b4d
Showing 221 changed files with 36,989 additions and 1 deletion.
Submodule MobileSAMv2 deleted from bddde3
@@ -0,0 +1,115 @@
import argparse
import ast
import torch
from PIL import Image
import cv2
import os
import sys
from mobilesamv2_segment_anything.promt_mobilesamv2 import PromptModel
from mobilesamv2_segment_anything import sam_model_registry, SamPredictor
from typing import Any, Dict, Generator, List
import matplotlib.pyplot as plt
import numpy as np


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--yolo_model_path", type=str, default="./PromtModel.pt", help="yolo_model_path")
    parser.add_argument("--model_path", type=str, default="./", help="model")
    parser.add_argument("--sam_path", type=str, default="sam_vit_h_4b8939.pth", help="sam_path")
    parser.add_argument("--img_path", type=str, default="./test_images/", help="path to image file")
    parser.add_argument("--imgsz", type=int, default=1024, help="image size")
    parser.add_argument("--iou", type=float, default=0.9, help="yolo iou")
    parser.add_argument("--conf", type=float, default=0.4, help="yolo object confidence threshold")
    parser.add_argument("--retina", type=bool, default=True, help="draw segmentation masks")
    parser.add_argument("--output_dir", type=str, default="./", help="image save path")
    parser.add_argument("--model_type", choices=['tiny_vit', 'vit_h', 'mobile_sam', 'efficientvit_l2', 'efficientvit_l1', 'efficientvit_l0'], help="choose the model type")
    return parser.parse_args()


def show_anns(anns):
    # Paint every mask in a random opaque color onto a transparent RGBA overlay.
    if len(anns) == 0:
        return
    ax = plt.gca()
    ax.set_autoscale_on(False)
    img = np.ones((anns.shape[1], anns.shape[2], 4))
    img[:, :, 3] = 0
    for ann in range(anns.shape[0]):
        m = anns[ann].bool()
        m = m.cpu().numpy()
        color_mask = np.concatenate([np.random.random(3), [1]])
        img[m] = color_mask
    ax.imshow(img)


def batch_iterator(batch_size: int, *args) -> Generator[List[Any], None, None]:
    assert len(args) > 0 and all(
        len(a) == len(args[0]) for a in args
    ), "Batched iteration must have inputs of all the same size."
    n_batches = len(args[0]) // batch_size + int(len(args[0]) % batch_size != 0)
    for b in range(n_batches):
        yield [arg[b * batch_size : (b + 1) * batch_size] for arg in args]


model_path = {
    'efficientvit_l0': './weight/l0.pt',
    'efficientvit_l1': './weight/l1.pt',
    'efficientvit_l2': './weight/l2.pt',
    'tiny_vit': './weight/mobile_sam.pt',
    'vit_h': './weight/sam_vit_h_4b8939.pth',
}


def main(args):
    # The YOLO-based prompt model proposes object boxes; SAM then decodes one mask per box.
    promtmodel = PromptModel(args.yolo_model_path)
    output_dir = args.output_dir
    device = "cuda" if torch.cuda.is_available() else "cpu"
    sam = sam_model_registry[args.model_type](model_path[args.model_type])
    sam.to(device=device)
    sam.eval()
    predictor = SamPredictor(sam)
    image_files = os.listdir(args.img_path)
    for image_name in image_files:
        print(image_name)
        image = cv2.imread(args.img_path + image_name)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        everything_results = promtmodel(image, device=device, retina_masks=args.retina, imgsz=args.imgsz, conf=args.conf, iou=args.iou)
        predictor.set_image(image)
        input_boxes1 = everything_results[0].boxes.xyxy
        input_boxes = input_boxes1.cpu().numpy()
        input_boxes = predictor.transform.apply_boxes(input_boxes, predictor.original_size)
        input_boxes = torch.from_numpy(input_boxes).to(device)  # follow the selected device instead of assuming CUDA
        sam_mask = []
        # Tile the single image embedding and dense positional encoding so up to
        # 320 box prompts can be decoded per forward pass.
        image_embedding = predictor.features
        image_embedding = torch.repeat_interleave(image_embedding, 320, dim=0)
        prompt_embedding = sam.prompt_encoder.get_dense_pe()
        prompt_embedding = torch.repeat_interleave(prompt_embedding, 320, dim=0)
        for (boxes,) in batch_iterator(320, input_boxes):
            with torch.no_grad():
                # The last chunk may hold fewer than 320 boxes, so slice the embeddings to match.
                image_embedding = image_embedding[0:boxes.shape[0], :, :, :]
                prompt_embedding = prompt_embedding[0:boxes.shape[0], :, :, :]
                sparse_embeddings, dense_embeddings = sam.prompt_encoder(
                    points=None,
                    boxes=boxes,
                    masks=None,
                )
                low_res_masks, _ = sam.mask_decoder(
                    image_embeddings=image_embedding,
                    image_pe=prompt_embedding,
                    sparse_prompt_embeddings=sparse_embeddings,
                    dense_prompt_embeddings=dense_embeddings,
                    multimask_output=False,
                    simple_type=True,
                )
                low_res_masks = predictor.model.postprocess_masks(low_res_masks, predictor.input_size, predictor.original_size)
                sam_mask_pre = (low_res_masks > sam.mask_threshold) * 1.0
                sam_mask.append(sam_mask_pre.squeeze(1))
        sam_mask = torch.cat(sam_mask)
        everything_results[0].masks.data = sam_mask
        annotation = sam_mask
        # Sort masks by area (largest first) so smaller masks are drawn later and stay visible.
        areas = torch.sum(annotation, dim=(1, 2))
        sorted_indices = torch.argsort(areas, descending=True)
        show_img = annotation[sorted_indices]
        plt.figure(figsize=(20, 20))
        background = np.ones_like(image) * 255
        plt.imshow(background)
        show_anns(show_img)
        plt.axis('off')
        # Save before show(): some interactive backends close the figure on show(),
        # which would leave savefig with a blank canvas.
        plt.savefig("{}".format(output_dir + image_name), bbox_inches='tight', pad_inches=0.0)
        plt.show()


if __name__ == "__main__":
    args = parse_args()
    main(args)
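The loop above decodes box prompts in chunks of at most 320, and the tiled embeddings are re-sliced because the final chunk is usually shorter. A minimal, self-contained sketch of that chunking follows; the helper name chunk_lengths and the box count are made up for illustration and only mirror the arithmetic of batch_iterator.

import torch

def chunk_lengths(n_items: int, batch_size: int = 320):
    # Same arithmetic as batch_iterator above: ceil(n / size) chunks, the last one may be short.
    n_batches = n_items // batch_size + int(n_items % batch_size != 0)
    return [min(batch_size, n_items - b * batch_size) for b in range(n_batches)]

boxes = torch.zeros(700, 4)           # e.g. 700 YOLO box prompts for one image
print(chunk_lengths(boxes.shape[0]))  # [320, 320, 60]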
Empty file.
Empty file.
@@ -0,0 +1,7 @@
# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
# International Conference on Computer Vision (ICCV), 2023

from .augment import *
from .base import *
from .random_resolution import *
6 changes: 6 additions & 0 deletions
MobileSAMv2/efficientvit/apps/data_provider/augment/__init__.py
@@ -0,0 +1,6 @@
# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
# International Conference on Computer Vision (ICCV), 2023

from .bbox import *
from .color_aug import *
30 changes: 30 additions & 0 deletions
MobileSAMv2/efficientvit/apps/data_provider/augment/bbox.py
@@ -0,0 +1,30 @@
# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
# International Conference on Computer Vision (ICCV), 2023

import numpy as np

__all__ = ["rand_bbox"]


def rand_bbox(
    h: int,
    w: int,
    lam: float,
    rand_func: callable = np.random.uniform,
) -> tuple[int, int, int, int]:
    """randomly sample bbox, used in cutmix"""
    cut_rat = np.sqrt(1.0 - lam)
    cut_w = w * cut_rat
    cut_h = h * cut_rat

    # uniform
    cx = rand_func(0, w)
    cy = rand_func(0, h)

    bbx1 = int(np.clip(cx - cut_w / 2, 0, w))
    bby1 = int(np.clip(cy - cut_h / 2, 0, h))
    bbx2 = int(np.clip(cx + cut_w / 2, 0, w))
    bby2 = int(np.clip(cy + cut_h / 2, 0, h))

    return bbx1, bby1, bbx2, bby2
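A short usage sketch for rand_bbox (not part of the commit), showing the standard CutMix pattern the docstring refers to. It assumes the MobileSAMv2 directory is on the Python path so the augment package above is importable (and its dependencies such as timm are installed); img_a, img_b, and the lam value are illustrative only.

import numpy as np
from efficientvit.apps.data_provider.augment import rand_bbox

def cutmix(img_a: np.ndarray, img_b: np.ndarray, lam: float = 0.7):
    h, w = img_a.shape[:2]
    x1, y1, x2, y2 = rand_bbox(h, w, lam)            # box covering roughly (1 - lam) of the area
    mixed = img_a.copy()
    mixed[y1:y2, x1:x2] = img_b[y1:y2, x1:x2]        # paste the patch from img_b into img_a
    lam_adj = 1.0 - (x2 - x1) * (y2 - y1) / (h * w)  # label weight from the clipped box area
    return mixed, lam_adj

img_a = np.random.randint(0, 255, (256, 256, 3), dtype=np.uint8)
img_b = np.random.randint(0, 255, (256, 256, 3), dtype=np.uint8)
mixed, lam_adj = cutmix(img_a, img_b)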
78 changes: 78 additions & 0 deletions
MobileSAMv2/efficientvit/apps/data_provider/augment/color_aug.py
@@ -0,0 +1,78 @@
# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
# International Conference on Computer Vision (ICCV), 2023

import numpy as np
import torchvision.transforms as transforms
from PIL import Image
from timm.data.auto_augment import rand_augment_transform

__all__ = ["ColorAug", "RandAug"]


class ImageAug:
    def aug_image(self, image: Image.Image) -> Image.Image:
        raise NotImplementedError

    def __call__(self, feed_dict: dict or np.ndarray or Image.Image) -> dict or np.ndarray or Image.Image:
        if isinstance(feed_dict, dict):
            output_dict = feed_dict
            image = feed_dict[self.key]
        else:
            output_dict = None
            image = feed_dict
        is_ndarray = isinstance(image, np.ndarray)
        if is_ndarray:
            image = Image.fromarray(image)

        image = self.aug_image(image)

        if is_ndarray:
            image = np.array(image)

        if output_dict is None:
            return image
        else:
            output_dict[self.key] = image
            return output_dict


class ColorAug(transforms.ColorJitter, ImageAug):
    def __init__(self, brightness=0, contrast=0, saturation=0, hue=0, key="data"):
        super().__init__(
            brightness=brightness,
            contrast=contrast,
            saturation=saturation,
            hue=hue,
        )
        self.key = key

    def aug_image(self, image: Image.Image) -> Image.Image:
        return transforms.ColorJitter.forward(self, image)

    def forward(self, feed_dict: dict or np.ndarray or Image.Image) -> dict or np.ndarray or Image.Image:
        return ImageAug.__call__(self, feed_dict)


class RandAug(ImageAug):
    def __init__(self, config: dict[str, any], mean: tuple[float, float, float], key="data"):
        n = config.get("n", 2)
        m = config.get("m", 9)
        mstd = config.get("mstd", 1.0)
        inc = config.get("inc", 1)
        tpct = config.get("tpct", 0.45)
        config_str = f"rand-n{n}-m{m}-mstd{mstd}-inc{inc}"

        aa_params = dict(
            translate_pct=tpct,
            img_mean=tuple([min(255, round(255 * x)) for x in mean]),
            interpolation=Image.BICUBIC,
        )
        self.aug_op = rand_augment_transform(config_str, aa_params)
        self.key = key

    def aug_image(self, image: Image.Image) -> Image.Image:
        return self.aug_op(image)

    def __repr__(self):
        return self.aug_op.__repr__()
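A brief usage sketch for the two augmentations (not part of the commit), assuming the MobileSAMv2 directory is on the Python path and timm/torchvision are installed; the sample dict and parameter values are illustrative only. Both classes accept either a raw image or a dict and transform the entry under key.

import numpy as np
from efficientvit.apps.data_provider.augment import ColorAug, RandAug

sample = {"data": np.random.randint(0, 255, (224, 224, 3), dtype=np.uint8)}

color_jitter = ColorAug(brightness=0.4, contrast=0.4, saturation=0.4, key="data")
rand_aug = RandAug(config={"n": 2, "m": 9}, mean=(0.485, 0.456, 0.406), key="data")

sample = color_jitter(sample)  # jitters sample["data"]; ndarray in, ndarray out
sample = rand_aug(sample)      # applies timm's RandAugment policy to the same key
print(sample["data"].shape, sample["data"].dtype)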