Feat/localization webservice #10

Draft
wants to merge 6 commits into main
Changes from all commits
10 changes: 10 additions & 0 deletions src/deployment/localization/bentofile.yaml
@@ -0,0 +1,10 @@
service: "service.py:svc"
include:
- "service.py"
- "yolov5s.pt"
python:
requirements_txt: "./requirements.txt"
docker:
system_packages:
- libsm6
- libxext6
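
For context, building this bento from Python could look roughly like the sketch below. This assumes BentoML 1.x and that bentofile.yaml sits in src/deployment/localization; the build_ctx path is illustrative, not part of this PR.

import bentoml

# Build the bento described by bentofile.yaml (service.py:svc plus the listed files)
bento = bentoml.bentos.build_bentofile(
    build_ctx="src/deployment/localization",
)
print(bento.tag)  # e.g. moth_detector:<generated version>
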
12 changes: 12 additions & 0 deletions src/deployment/localization/requirements.txt
@@ -0,0 +1,12 @@
Pillow>=7.1.2
PyYAML>=5.3.1
bentoml>=1.0.0
matplotlib>=3.2.2
numpy>=1.18.5
pandas>=1.1.4
protobuf<=3.20.1 # https://github.com/ultralytics/yolov5/issues/8012
ipython
torch>=1.7.0,!=1.12.0 # https://github.com/ultralytics/yolov5/issues/8395
torchvision>=0.8.1,!=0.13.0 # https://github.com/ultralytics/yolov5/issues/8395
fastapi
pydantic
286 changes: 286 additions & 0 deletions src/deployment/localization/service.py
@@ -0,0 +1,286 @@
import logging
import os
import pathlib
import tempfile
import urllib.request

import bentoml
import bentoml.io
import PIL.Image
import torch
import torchvision
import torchvision.datapoints
import torchvision.models.detection.anchor_utils
import torchvision.models.detection.backbone_utils
import torchvision.models.detection.faster_rcnn
import torchvision.models.mobilenetv3
import torchvision.transforms.v2 as T
from fastapi import FastAPI
from pydantic import BaseModel

tempdir = tempfile.TemporaryDirectory(prefix="mothml")
os.environ["LOCAL_WEIGHTS_PATH"] = tempdir.name

CHECKPOINT = "https://object-arbutus.cloud.computecanada.ca/ami-models/moths/localization/fasterrcnn_mobilenet_v3_large_fpn_uqfh7u9w.pt"
SCORE_THRESHOLD = 0.0

logger = logging.getLogger(__name__)
# Allow field names starting with "model_" (e.g. model_version) without pydantic v2 warnings
BaseModel.model_config["protected_namespaces"] = ()


class LabelStudioTask(BaseModel):
data: dict
id: int


# Example Label Studio task result
example_result = {
    "from_name": "detected_object",
    "to_name": "image",
    "type": "rectangle",
    "value": {
        "x": 66.12516045570374,
        "y": 74.68075222439236,
        "width": 1.3237595558166504,
        "height": 2.0854356553819446,
    },
    "score": 0.8321816325187683,
}


def get_or_download_file(path, destination_dir=None, prefix=None) -> pathlib.Path:
"""
>>> filename, headers = get_weights("https://drive.google.com/file/d/1KdQc56WtnMWX9PUapy6cS0CdjC8VSdVe/view?usp=sharing") # noqa: E501

Taken from https://github.com/RolnickLab/ami-data-companion/blob/main/trapdata/ml/utils.py
"""
if not path:
raise Exception("Specify a URL or path to fetch file from.")

# If path is a local path instead of a URL, urlretrieve will just return the path
destination_dir = destination_dir or os.environ.get("LOCAL_WEIGHTS_PATH")
fname = path.rsplit("/", 1)[-1]
if destination_dir:
destination_dir = pathlib.Path(destination_dir)
if prefix:
destination_dir = destination_dir / prefix
if not destination_dir.exists():
logger.info(f"Creating local directory {str(destination_dir)}")
destination_dir.mkdir(parents=True, exist_ok=True)
local_filepath = pathlib.Path(destination_dir) / fname
else:
raise Exception(
"No destination directory specified by LOCAL_WEIGHTS_PATH or app settings."
)

if local_filepath and local_filepath.exists():
logger.info(f"Using existing {local_filepath}")
return local_filepath

else:
logger.info(f"Downloading {path} to {destination_dir}")
resulting_filepath, headers = urllib.request.urlretrieve(
url=path, filename=local_filepath
)
resulting_filepath = pathlib.Path(resulting_filepath)
logger.info(f"Downloaded to {resulting_filepath}")
return resulting_filepath


class LabelStudioRequest(BaseModel):
tasks: list[LabelStudioTask]
model_version: str
project: str
label_config: str
params: dict


def load_model_scratch(
checkpoint_path: str | pathlib.Path,
trainable_backbone_layers: int = 6,
anchor_sizes: tuple = (64, 128, 256, 512),
num_classes: int = 2,
device: str | torch.device = "cpu",
):
norm_layer = torch.nn.BatchNorm2d
backbone = torchvision.models.mobilenetv3.mobilenet_v3_large(
weights=None, norm_layer=norm_layer
)
backbone = torchvision.models.detection.backbone_utils._mobilenet_extractor(
backbone, True, trainable_backbone_layers
)
anchor_sizes = (anchor_sizes,) * 3
aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes)
model = torchvision.models.detection.faster_rcnn.FasterRCNN(
backbone,
num_classes,
rpn_anchor_generator=torchvision.models.detection.anchor_utils.AnchorGenerator(
anchor_sizes, aspect_ratios
),
rpn_score_thresh=0.05,
)
checkpoint = torch.load(checkpoint_path, map_location=device)
state_dict = checkpoint.get("model_state_dict") or checkpoint
model.load_state_dict(state_dict)
model = model.to(device)
model.eval()
return model


def post_process_single(output: dict) -> tuple[list, list, list]:
    # Keep only detections above the score threshold so that boxes,
    # labels and scores stay aligned for downstream formatting
    keep = output["scores"] > SCORE_THRESHOLD

    bboxes = output["boxes"][keep].cpu().detach().numpy().tolist()
    labels = output["labels"][keep].cpu().detach().numpy().tolist()
    scores = output["scores"][keep].cpu().detach().numpy().tolist()

    # This model does not use the labels from the object detection model
    assert all(label == 1 for label in labels)

    print(
        f"Keeping {len(bboxes)} out of {len(output['boxes'])} objects found "
        f"(threshold: {SCORE_THRESHOLD})"
    )

    return bboxes, labels, scores


def format_predictions_single(image: PIL.Image.Image, bboxes, scores) -> list[dict]:
width, height = image.size
return [
{
"from_name": "detected_object",
"to_name": "image",
"type": "rectangle",
"value": {
"x": bbox[0] / width * 100,
"y": bbox[1] / height * 100,
"width": (bbox[2] - bbox[0]) / width * 100,
"height": (bbox[3] - bbox[1]) / height * 100,
"rotation": 0,
},
"score": score,
"original_width": width,
"original_height": height,
"image_rotation": 0,
}
for bbox, score in zip(bboxes, scores)
]


class MothDetectionRunner(bentoml.Runnable):
SUPPORTED_RESOURCES = ("nvidia.com/gpu", "cpu")
SUPPORTS_CPU_MULTI_THREADING = True

def __init__(self):
self.model = load_model_scratch(
checkpoint_path=get_or_download_file(CHECKPOINT),
device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
)

def transform(self) -> T.Compose:
return T.Compose([T.ToImageTensor(), T.ConvertImageDtype()])

    @bentoml.Runnable.method(batchable=True, batch_dim=0)
    @torch.no_grad()
    def inference(self, input_img_paths):
        # Accept either already-loaded PIL images (e.g. from the Image IO
        # descriptor) or local paths / URLs that still need to be fetched
        input_imgs = [
            img
            if isinstance(img, PIL.Image.Image)
            else PIL.Image.open(get_or_download_file(img))
            for img in input_img_paths
            if img is not None
        ]
        input_imgs_t = self.transform()(input_imgs)
        results = self.model(input_imgs_t)
        results = [post_process_single(result) for result in results]
        predictions = [
            format_predictions_single(image, bboxes, scores)
            for image, (bboxes, _, scores) in zip(input_imgs, results)
        ]
        return predictions

    @bentoml.Runnable.method(batchable=True, batch_dim=0)
    @torch.no_grad()
    def render(self, input_imgs):
        # Return images with predicted bounding boxes drawn on top
        to_tensor = T.Compose([T.ToImageTensor(), T.ToDtype(torch.uint8)])
        input_imgs_t = self.transform()(input_imgs)
        results = self.model(input_imgs_t)
        draw = torchvision.utils.draw_bounding_boxes
        to_image = torchvision.transforms.ToPILImage()
        # draw_bounding_boxes expects and returns uint8 tensors
        out_imgs = [
            draw(to_tensor(image), output["boxes"])
            for image, output in zip(input_imgs, results)
        ]
        # Overlay the annotated image on the original, blending both as floats in [0, 1]
        out_imgs = [
            to_image(img.float() / 255 * 0.5 + image * 0.5)
            for img, image in zip(out_imgs, input_imgs_t)
        ]

        return out_imgs


moth_detection_runner = bentoml.Runner(MothDetectionRunner, max_batch_size=30)

svc = bentoml.Service("moth_detector", runners=[moth_detection_runner])


@svc.on_startup
def download(_: bentoml.Context):
get_or_download_file(CHECKPOINT)


@svc.api(input=bentoml.io.Image(), output=bentoml.io.JSON())
async def invocation(input_img):
batch_ret = await moth_detection_runner.inference.async_run([input_img])
return batch_ret[0]


@svc.api(input=bentoml.io.Image(), output=bentoml.io.Image())
async def render(input_img):
batch_ret = await moth_detection_runner.render.async_run([input_img])
return batch_ret[0]


input_spec = bentoml.io.JSON(pydantic_model=LabelStudioRequest)
output_spec = bentoml.io.JSON()


@svc.api(
input=input_spec,
output=output_spec,
)
async def predict(input_data):
tasks = input_data.tasks
if len(tasks) > 1:
raise Exception("Only one task per request is supported")
image_paths = [task.data["image"] for task in tasks]
task_predictions = await moth_detection_runner.inference.async_run(image_paths)
resp = {
"results": [{"result": task_prediction} for task_prediction in task_predictions]
}
print(resp)
return resp


fastapi_app = FastAPI()
svc.mount_asgi_app(fastapi_app)


# Health check endpoint to match what Label Studio expects
@fastapi_app.get("/health")
def health():
return {
"status": "UP",
"model_class": "FasterRCNN_MobileNetV3_Large_FPN",
}


# Setup endpoint to match what Label Studio expects
# https://github.com/HumanSignal/label-studio-ml-backend/blob/master/label_studio_ml/api.py#L65
@fastapi_app.post("/setup")
def setup():
return {
"model_version": "uqfh7u9w",
}
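

A minimal client sketch for exercising the Label Studio predict endpoint, assuming the service is running locally on BentoML's default HTTP port 3000; the image URL, project and label_config values below are placeholders:

import json
import urllib.request

# Payload shaped like the LabelStudioRequest model above (one task per request)
payload = {
    "tasks": [{"id": 1, "data": {"image": "https://example.com/trap-image.jpg"}}],
    "model_version": "uqfh7u9w",
    "project": "1",
    "label_config": "<View></View>",
    "params": {},
}

request = urllib.request.Request(
    "http://localhost:3000/predict",
    data=json.dumps(payload).encode("utf-8"),
    headers={"Content-Type": "application/json"},
)
with urllib.request.urlopen(request) as response:
    predictions = json.load(response)

# The service responds with {"results": [{"result": [...Label Studio rectangles...]}]}
print(predictions["results"][0]["result"])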