diff --git a/rembg/sessions/sam.py b/rembg/sessions/sam.py index 7a287cd0..b0c02214 100644 --- a/rembg/sessions/sam.py +++ b/rembg/sessions/sam.py @@ -1,6 +1,6 @@ import os from copy import deepcopy -from typing import Dict, List, Tuple +from typing import List import cv2 import numpy as np @@ -105,9 +105,10 @@ def __init__( valid_providers = [] available_providers = ort.get_available_providers() - for provider in providers or []: - if provider in available_providers: - valid_providers.append(provider) + if providers: + for provider in providers or []: + if provider in available_providers: + valid_providers.append(provider) else: valid_providers.extend(available_providers) @@ -142,7 +143,16 @@ def predict( Returns: List[PILImage]: A list of masks generated by the decoder. """ - prompt = kwargs.get("sam_prompt", "{}") + prompt = kwargs.get( + "sam_prompt", + [ + { + "type": "point", + "label": 1, + "data": [int(img.width / 2), int(img.height / 2)], + } + ], + ) schema = { "type": "array", "items": {