Torchscript / Pytorch Mobile Support #112

Open · wants to merge 21 commits into base: master
14 changes: 14 additions & 0 deletions README.md
@@ -156,6 +156,20 @@ python scripts/export_onnx_model.py --checkpoint ./weights/mobile_sam.pt --model
Also check the [example notebook](https://github.com/ChaoningZhang/MobileSAM/blob/master/notebooks/onnx_model_example.ipynb) to follow detailed steps.
We recommend using `onnx==1.12.0` and `onnxruntime==1.13.1`, which have been tested.

## PyTorch Mobile Export

**MobileSAM** can now run on PyTorch Mobile. Export the model with:

```
python ./scripts/convert_pytorch_mobile.py output_dir
```
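For orientation, a conversion along these lines typically scripts the model, runs the mobile optimizer, and saves it for the lite interpreter. The following is only a hedged sketch, not the contents of `convert_pytorch_mobile.py`; the `vit_t` registry key, checkpoint path, and output file name are assumptions:

```
# Sketch of a PyTorch Mobile export; the real script may differ in details.
import torch
from torch.utils.mobile_optimizer import optimize_for_mobile

from mobile_sam import sam_model_registry

sam = sam_model_registry["vit_t"](checkpoint="./weights/mobile_sam.pt")
sam.eval()

scripted = torch.jit.script(sam)           # script (not trace) to preserve control flow
optimized = optimize_for_mobile(scripted)  # targets the CPU backend by default
optimized._save_for_lite_interpreter("output_dir/mobile_sam.ptl")
```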

The resulting model can be loaded as described in https://pytorch.org/tutorials/prototype/ios_gpu_workflow.html.
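Before moving the file to a device, it can be sanity-checked on the desktop with the lite-interpreter loader (assuming the export produces a `.ptl` lite-interpreter file; the file name below is the hypothetical one from the sketch above):

```
from torch.jit.mobile import _load_for_lite_interpreter

# Load the exported module the same way the mobile runtime does.
model = _load_for_lite_interpreter("output_dir/mobile_sam.ptl")
print(model)  # confirms the archive loads outside the mobile runtime
```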

**Note:** the current version runs on the CPU only under PyTorch Mobile; the Metal backend appears to be missing support for strided convolutions.

The caller still needs to perform input scaling and normalization, as done in
[set_image()](https://github.com/ChaoningZhang/MobileSAM/blob/master/mobile_sam/predictor.py) in the predictor example.
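As a rough guide, such preprocessing mirrors `set_image()`: scale the longest side to 1024 (the standard SAM/MobileSAM input resolution), normalize with the SAM pixel statistics from `sam.py`, and pad to a square. A minimal sketch, assuming the exported module expects the same 1024x1024 float input as the original model:

```
import torch
import torch.nn.functional as F

PIXEL_MEAN = torch.tensor([123.675, 116.28, 103.53]).view(-1, 1, 1)
PIXEL_STD = torch.tensor([58.395, 57.12, 57.375]).view(-1, 1, 1)

def preprocess(image: torch.Tensor, target: int = 1024) -> torch.Tensor:
    """image: RGB tensor of shape (3, H, W) with values in 0..255."""
    _, h, w = image.shape
    scale = target / max(h, w)
    new_h, new_w = int(h * scale + 0.5), int(w * scale + 0.5)
    x = F.interpolate(image[None].float(), size=(new_h, new_w),
                      mode="bilinear", align_corners=False)[0]
    x = (x - PIXEL_MEAN) / PIXEL_STD                      # per-channel normalization
    x = F.pad(x, (0, target - new_w, 0, target - new_h))  # pad right/bottom to target x target
    return x
```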

## BibTeX of our MobileSAM
If you use MobileSAM in your research, please use the following BibTeX entry. :mega: Thank you!
10 changes: 4 additions & 6 deletions mobile_sam/modeling/mask_decoder.py
@@ -99,10 +99,8 @@ def forward(
)

# Select the correct mask or masks for output
if multimask_output:
mask_slice = slice(1, None)
else:
mask_slice = slice(0, 1)
mask_slice = slice(1 if multimask_output else 0, None if multimask_output else 1)

masks = masks[:, mask_slice, :, :]
iou_pred = iou_pred[:, mask_slice]

@@ -137,8 +135,8 @@ def predict_masks(
src = src.transpose(1, 2).view(b, c, h, w)
upscaled_embedding = self.output_upscaling(src)
hyper_in_list: List[torch.Tensor] = []
for i in range(self.num_mask_tokens):
hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]))
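# TorchScript cannot index a ModuleList with a loop variable, so the list is iterated with enumerate instead.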
for i, output_hypernetwork_mlp in enumerate(self.output_hypernetworks_mlps):
hyper_in_list.append(output_hypernetwork_mlp(mask_tokens_out[:, i, :]))
hyper_in = torch.stack(hyper_in_list, dim=1)
b, c, h, w = upscaled_embedding.shape
masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w)
2 changes: 1 addition & 1 deletion mobile_sam/modeling/prompt_encoder.py
@@ -194,7 +194,7 @@ def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor:
def forward(self, size: Tuple[int, int]) -> torch.Tensor:
"""Generate positional encoding for a grid of the specified size."""
h, w = size
device: Any = self.positional_encoding_gaussian_matrix.device
device: torch.device = self.positional_encoding_gaussian_matrix.device
grid = torch.ones((h, w), device=device, dtype=torch.float32)
y_embed = grid.cumsum(dim=0) - 0.5
x_embed = grid.cumsum(dim=1) - 0.5
107 changes: 66 additions & 41 deletions mobile_sam/modeling/sam.py
@@ -16,9 +16,10 @@
from .prompt_encoder import PromptEncoder


MASK_THRESHOLD_DEFAULT: float = 0.0
IMAGE_FORMAT_DEFAULT: str = "RGB"

class Sam(nn.Module):
mask_threshold: float = 0.0
image_format: str = "RGB"

def __init__(
self,
@@ -27,6 +28,8 @@ def __init__(
mask_decoder: MaskDecoder,
pixel_mean: List[float] = [123.675, 116.28, 103.53],
pixel_std: List[float] = [58.395, 57.12, 57.375],
mask_threshold: float = MASK_THRESHOLD_DEFAULT,
image_format: str = IMAGE_FORMAT_DEFAULT,
) -> None:
"""
SAM predicts object masks from an image and input prompts.
@@ -46,15 +49,16 @@ def __init__(
self.mask_decoder = mask_decoder
self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False)
self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False)
self.mask_threshold = mask_threshold
self.image_format = image_format

@property
def device(self) -> Any:
return self.pixel_mean.device

@torch.no_grad()
def forward(
self,
batched_input: List[Dict[str, Any]],
batched_input: List[Dict[str, Union[torch.Tensor, Tuple[int, int]]]],
multimask_output: bool,
) -> List[Dict[str, torch.Tensor]]:
"""
@@ -95,47 +99,68 @@ def forward(
shape BxCxHxW, where H=W=256. Can be passed as mask input
to subsequent iterations of prediction.
"""
input_images = torch.stack([self.preprocess(x["image"]) for x in batched_input], dim=0)
image_embeddings = self.image_encoder(input_images)

outputs = []
for image_record, curr_embedding in zip(batched_input, image_embeddings):
if "point_coords" in image_record:
points = (image_record["point_coords"], image_record["point_labels"])
else:
points = None
sparse_embeddings, dense_embeddings = self.prompt_encoder(
points=points,
boxes=image_record.get("boxes", None),
masks=image_record.get("mask_inputs", None),
)
low_res_masks, iou_predictions = self.mask_decoder(
image_embeddings=curr_embedding.unsqueeze(0),
image_pe=self.prompt_encoder.get_dense_pe(),
sparse_prompt_embeddings=sparse_embeddings,
dense_prompt_embeddings=dense_embeddings,
multimask_output=multimask_output,
)
masks = self.postprocess_masks(
low_res_masks,
input_size=image_record["image"].shape[-2:],
original_size=image_record["original_size"],
)
masks = masks > self.mask_threshold
outputs.append(
{
"masks": masks,
"iou_predictions": iou_predictions,
"low_res_logits": low_res_masks,
}
)
return outputs
with torch.no_grad():
input_images_list = []
for x in batched_input:
img = x["image"] # Needed for Torchscript support
assert isinstance(img, torch.Tensor)
processed_image = self.preprocess(torch.jit.annotate(torch.Tensor, img))
input_images_list.append(processed_image)

input_images = torch.stack(input_images_list, dim=0)
image_embeddings = self.image_encoder(input_images)

outputs: List[Dict[str, torch.Tensor]] = []
for image_record, curr_embedding in zip(batched_input, image_embeddings):
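# The asserts and torch.jit.annotate calls below only narrow the Union-typed dict values so the TorchScript compiler can type-check the calls that follow.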
boxes = image_record["boxes"] if "boxes" in image_record else None
assert isinstance(boxes, Optional[torch.Tensor])
boxes = torch.jit.annotate(Optional[torch.Tensor], boxes)
mask_inputs = image_record["mask_inputs"] if "mask_inputs" in image_record else None
assert isinstance(mask_inputs, Optional[torch.Tensor])
if "point_coords" in image_record:
pc = image_record["point_coords"]
assert isinstance(pc, torch.Tensor)
pl = image_record["point_labels"]
assert isinstance(pl, torch.Tensor)
points = (pc, pl)
else:
points = None
sparse_embeddings, dense_embeddings = self.prompt_encoder(
points=points,
boxes=boxes,
masks=mask_inputs,
)
low_res_masks, iou_predictions = self.mask_decoder(
image_embeddings=curr_embedding.unsqueeze(0),
image_pe=self.prompt_encoder.get_dense_pe(),
sparse_prompt_embeddings=sparse_embeddings,
dense_prompt_embeddings=dense_embeddings,
multimask_output=multimask_output,
)
orig_size = image_record["original_size"]
assert isinstance(orig_size, Tuple[int, int])
img = image_record["image"]
assert isinstance(img, torch.Tensor)
masks = self.postprocess_masks(
low_res_masks,
input_size=img.shape[-2:],
original_size=orig_size,
)
masks = masks > self.mask_threshold
outputs.append(
{
"masks": masks,
"iou_predictions": iou_predictions,
"low_res_logits": low_res_masks,
}
)
return outputs

def postprocess_masks(
self,
masks: torch.Tensor,
input_size: Tuple[int, ...],
original_size: Tuple[int, ...],
input_size: List[int],
original_size: Tuple[int, int],
) -> torch.Tensor:
"""
Remove padding and upscale masks to the original image size.
66 changes: 41 additions & 25 deletions mobile_sam/modeling/tiny_vit_sam.py
@@ -79,8 +79,11 @@ def __init__(self, in_chans, out_chans, expand_ratio,
activation, drop_path):
super().__init__()
self.in_chans = in_chans
assert self.in_chans > 0
self.hidden_chans = int(in_chans * expand_ratio)
assert self.hidden_chans > 0
self.out_chans = out_chans
assert self.out_chans > 0

self.conv1 = Conv2d_BN(in_chans, self.hidden_chans, ks=1)
self.act1 = activation()
@@ -177,7 +180,7 @@ def __init__(self, dim, input_resolution, depth,

def forward(self, x):
for blk in self.blocks:
if self.use_checkpoint:
if self.use_checkpoint and not torch.jit.is_scripting():
x = checkpoint.checkpoint(blk, x)
else:
x = blk(x)
@@ -335,7 +338,9 @@ def __init__(self, dim, input_resolution, num_heads, window_size=7,
def forward(self, x):
H, W = self.input_resolution
B, L, C = x.shape
assert L == H * W, "input feature has wrong size"
assert L == H * W, f"input feature has wrong size: {L} != {H} * {W}"
assert H > 0, "height is 0"
assert W > 0, "width is 0"
res_x = x
if H == self.window_size and W == self.window_size:
x = self.attn(x)
@@ -346,9 +351,17 @@ def forward(self, x):
pad_r = (self.window_size - W %
self.window_size) % self.window_size
padding = pad_b > 0 or pad_r > 0

if padding:
x = F.pad(x, (0, 0, 0, pad_r, 0, pad_b))
x = F.pad(x, (0, 0, 0, pad_r, 0, pad_b))

# Alternative to the above (PyTorch Lite does not ship F.pad on the Metal backend);
# zero-filled tensors keep the result identical to F.pad's constant padding:
# if pad_b > 0:
# pad_tensor_b = torch.zeros(size=(B, pad_b, W, C), dtype=x.dtype, device=x.device)
# x = torch.cat([x, pad_tensor_b], dim=1)  # concatenate below, along the height dimension
#
# if pad_r > 0:
# pad_tensor_r = torch.zeros(size=(B, H + pad_b, pad_r, C), dtype=x.dtype, device=x.device)
# x = torch.cat([x, pad_tensor_r], dim=2)  # concatenate to the right, along the width dimension

pH, pW = H + pad_b, W + pad_r
nH = pH // self.window_size
@@ -435,7 +448,7 @@ def __init__(self, dim, input_resolution, depth, num_heads, window_size,

def forward(self, x):
for blk in self.blocks:
if self.use_checkpoint:
if self.use_checkpoint and not torch.jit.is_scripting():
x = checkpoint.checkpoint(blk, x)
else:
x = blk(x)
@@ -446,6 +459,7 @@ def forward(self, x):
def extra_repr(self) -> str:
return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"


class LayerNorm2d(nn.Module):
def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
super().__init__()
@@ -459,6 +473,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
x = (x - u) / torch.sqrt(s + self.eps)
x = self.weight[:, None, None] * x + self.bias[:, None, None]
return x


class TinyViT(nn.Module):
def __init__(self, img_size=224, in_chans=3, num_classes=1000,
embed_dims=[96, 192, 384, 768], depths=[2, 2, 6, 2],
@@ -496,24 +512,18 @@ def __init__(self, img_size=224, in_chans=3, num_classes=1000,
# build layers
self.layers = nn.ModuleList()
for i_layer in range(self.num_layers):
kwargs = dict(dim=embed_dims[i_layer],
input_resolution=(patches_resolution[0] // (2 ** (i_layer-1 if i_layer == 3 else i_layer)),
patches_resolution[1] // (2 ** (i_layer-1 if i_layer == 3 else i_layer))),
# input_resolution=(patches_resolution[0] // (2 ** i_layer),
# patches_resolution[1] // (2 ** i_layer)),
depth=depths[i_layer],
drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
downsample=PatchMerging if (
i_layer < self.num_layers - 1) else None,
use_checkpoint=use_checkpoint,
out_dim=embed_dims[min(
i_layer + 1, len(embed_dims) - 1)],
activation=activation,
)
if i_layer == 0:
layer = ConvLayer(
conv_expand_ratio=mbconv_expand_ratio,
**kwargs,
dim=embed_dims[i_layer],
input_resolution=(patches_resolution[0] // (2 ** (i_layer-1 if i_layer == 3 else i_layer)),
patches_resolution[1] // (2 ** (i_layer-1 if i_layer == 3 else i_layer))),
depth=depths[i_layer],
drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
use_checkpoint=use_checkpoint,
out_dim=embed_dims[min(i_layer + 1, len(embed_dims) - 1)],
activation=activation,
)
else:
layer = BasicLayer(
@@ -522,7 +532,15 @@ def __init__(self, img_size=224, in_chans=3, num_classes=1000,
mlp_ratio=self.mlp_ratio,
drop=drop_rate,
local_conv_size=local_conv_size,
**kwargs)
dim=embed_dims[i_layer],
input_resolution=(patches_resolution[0] // (2 ** (i_layer-1 if i_layer == 3 else i_layer)),
patches_resolution[1] // (2 ** (i_layer-1 if i_layer == 3 else i_layer))),
depth=depths[i_layer],
drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
use_checkpoint=use_checkpoint,
out_dim=embed_dims[min(i_layer + 1, len(embed_dims) - 1)],
activation=activation,
)
self.layers.append(layer)

# Classifier head
@@ -600,13 +618,11 @@ def no_weight_decay_keywords(self):
def forward_features(self, x):
# x: (N, C, H, W)
x = self.patch_embed(x)

x = self.layers[0](x)
start_i = 1

for i in range(start_i, len(self.layers)):
layer = self.layers[i]
x = layer(x)
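# TorchScript does not support indexing a ModuleList with a non-constant index, so the remaining layers are iterated directly.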
for layer in self.layers[1:]:
x = layer.forward(x)
B,_,C=x.size()
x = x.view(B, 64, 64, C)
x=x.permute(0, 3, 1, 2)