Skip to content

Commit

Permalink
responding to PR comments
Browse files Browse the repository at this point in the history
  • Loading branch information
Tim Huff authored and Tim Huff committed Aug 5, 2023
1 parent 2ef2230 commit 351c7c1
Show file tree
Hide file tree
Showing 3 changed files with 103 additions and 30 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ packages = [
{include = "**/*.py", from = "src"},
]
readme = "README.md"
version = "0.10.0"
version = "0.10.1"

[tool.poetry.dependencies]
certifi = "^2021.10.8"
Expand Down
69 changes: 62 additions & 7 deletions src/groundlight/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,7 @@ def submit_image_query(
image: Union[str, bytes, Image.Image, BytesIO, BufferedReader, np.ndarray],
wait: Optional[float] = None,
human_review: Optional[bool] = True,
inspection_id: Optional[str] = None,
) -> ImageQuery:
"""Evaluates an image with Groundlight.
:param detector: the Detector object, or string id of a detector like `det_12345`
Expand All @@ -184,22 +185,49 @@ def submit_image_query(
converted to JPEG at high quality before sending to service.
:param wait: How long to wait (in seconds) for a confident answer.
:param human_review: If set to False, do not escalate for human review
:param inspection_id: Most users will omit this. For accounts with Inspection Reports enabled,
this is the ID of the inspection to associate with the image query.
"""
if wait is None:
wait = self.DEFAULT_WAIT
detector_id = detector.id if isinstance(detector, Detector) else detector

# Convert from Detector to detector_id if necessary
if isinstance(detector, Detector):
detector_id = self.get_detector(detector)
else:
detector_id = detector

image_bytesio: ByteStreamWrapper = parse_supported_image_types(image)

raw_image_query = self.image_queries_api.submit_image_query(
detector_id=detector_id, patience_time=wait, human_review=human_review, body=image_bytesio
)
image_query = ImageQuery.parse_obj(raw_image_query.to_dict())
# Submit Image Query
# If no inspection_id is provided, we submit the image query using image_queries_api (autogenerated via OpenAPI)
# However, our autogenerated code does not support inspection_id, so if an inspection_id was provided, we use
# the private API client instead.
if inspection_id is None:
raw_image_query = self.image_queries_api.submit_image_query(
detector_id=detector_id,
patience_time=wait,
human_review=human_review,
body=image_bytesio
)
image_query_dict = raw_image_query.to_dict()
image_query = ImageQuery.parse_obj(image_query_dict)
else:
print(f"Using private API client to submit image query with human_review: {human_review}")
iq_id = self.api_client.submit_image_query_with_inspection(
detector_id=detector_id,
patience_time=wait,
human_review=human_review,
image=image_bytesio,
inspection_id=inspection_id
)
image_query = self.get_image_query(iq_id)

if wait:
threshold = self.get_detector(detector).confidence_threshold
image_query = self.wait_for_confident_result(image_query, confidence_threshold=threshold, timeout_sec=wait)
return self._fixup_image_query(image_query)

def wait_for_confident_result(
self,
image_query: ImageQuery,
Expand All @@ -212,7 +240,10 @@ def wait_for_confident_result(
:param confidence_threshold: The minimum confidence level required to return before the timeout.
:param timeout_sec: The maximum number of seconds to wait.
"""
# TODO: Add support for ImageQuery id instead of object.
# Convert from image_query_id to ImageQuery if needed.
if isinstance(image_query, str):
image_query = self.get_image_query(image_query)

start_time = time.time()
next_delay = self.POLLING_INITIAL_DELAY
target_delay = 0.0
Expand Down Expand Up @@ -253,3 +284,27 @@ def add_label(self, image_query: Union[ImageQuery, str], label: Union[Label, str
api_label = convert_display_label_to_internal(image_query_id, label)

return self.api_client._add_label(image_query_id, api_label) # pylint: disable=protected-access

def start_inspection(self) -> str:
"""For users with Inspection Reports enabled only.
Starts an inspection report and returns the id of the inspection.
"""
return self.api_client.start_inspection()

def update_inspection_metadata(self, inspection_id: str, user_provided_key: str, user_provided_value: str) -> None:
"""For users with Inspection Reports enabled only.
Add/update inspection metadata with the user_provided_key and user_provided_value.
"""
self.api_client.update_inspection_metadata(inspection_id, user_provided_key, user_provided_value)

def stop_inspection(self, inspection_id: str) -> str:
"""For users with Inspection Reports enabled only.
Stops an inspection and raises an exception if the response from the server
indicates that the inspection was not successfully stopped.
Returns a str with result of the inspection (either PASS or FAIL)
"""
return self.api_client.stop_inspection(inspection_id)

def update_detector_confidence_threshold(self, detector_id: str, confidence_threshold: float) -> None:
"""Updates the confidence threshold of a detector given a detector_id."""
self.api_client.update_detector_confidence_threshold(detector_id, confidence_threshold)
62 changes: 40 additions & 22 deletions src/groundlight/internalapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,19 +17,9 @@

logger = logging.getLogger("groundlight.sdk")


class NotFoundError(Exception):
pass


class InspectionError(Exception):
pass


class UpdateDetectorError(Exception):
pass


def sanitize_endpoint_url(endpoint: Optional[str] = None) -> str:
"""Takes a URL for an endpoint, and returns a "sanitized" version of it.
Currently the production API path must be exactly "/device-api".
Expand Down Expand Up @@ -237,16 +227,28 @@ def _get_detector_by_name(self, name: str) -> Detector:
return Detector.parse_obj(parsed["results"][0])

@RequestsRetryDecorator()
def submit_image_query_with_inspection(self, detector_id: str, image: ByteStreamWrapper, inspection_id: str) -> str:
def submit_image_query_with_inspection(self,
detector_id: str,
patience_time: float,
human_review: bool,
image: ByteStreamWrapper,
inspection_id: str) -> str:
"""Submits an image query to the API and returns the ID of the image query.
The image query will be associated to the inspection_id provided.
"""
url = f"{self.configuration.host}/posichecks?inspection_id={inspection_id}&predictor_id={detector_id}"

# In the API, "send_notification" was used to escalate image queries to cloud labelers.
send_notification = human_review

url = (f"{self.configuration.host}/posichecks"
f"?inspection_id={inspection_id}"
f"&predictor_id={detector_id}"
f"&send_notification={send_notification}")

headers = self._headers()
headers["Content-Type"] = "image/jpeg"

response = requests.request("POST", url, headers=headers, data=image.read())
response = requests.request("POST", url, headers=headers, timeout=patience_time, data=image.read())

if not is_ok(response.status_code):
logger.info(response)
Expand All @@ -268,12 +270,16 @@ def start_inspection(self) -> str:
response = requests.request("POST", url, headers=headers, json={})

if not is_ok(response.status_code):
raise InspectionError(f"Error starting inspection. Status code: {response.status_code}")
raise InternalApiError(
status=response.status_code,
reason="Error starting inspection.",
http_resp=response,
)

return response.json()["id"]

@RequestsRetryDecorator()
def update_inspection_metadata(self, inspection_id: str, user_provided_key, user_provided_value) -> None:
def update_inspection_metadata(self, inspection_id: str, user_provided_key: str, user_provided_value: str) -> None:
"""Add/update inspection metadata with the user_provided_key and user_provided_value.
The API stores inspections metadata in two ways:
Expand All @@ -293,15 +299,17 @@ def update_inspection_metadata(self, inspection_id: str, user_provided_key, user
response = requests.request("GET", url, headers=headers)

if not is_ok(response.status_code):
raise InspectionError(
f"Error getting inspection details for inspection {inspection_id}. Status code: {response.status_code}"
raise InternalApiError(
status=response.status_code,
reason=f"Error getting inspection details for inspection {inspection_id}.",
http_resp=response,
)

payload = {}

# Set the user_provided_id_key and user_provided_id_value if they were not previously set.
response_json = response.json()
if not response_json["user_provided_id_key"]:
if not response_json.get("user_provided_id_key"):
payload["user_provided_id_key"] = user_provided_key
payload["user_provided_id_value"] = user_provided_value

Expand All @@ -316,8 +324,10 @@ def update_inspection_metadata(self, inspection_id: str, user_provided_key, user
response = requests.request("PATCH", url, headers=headers, json=payload)

if not is_ok(response.status_code):
raise InspectionError(
f"Error updating inspection metadata on inspection {inspection_id}. Status code: {response.status_code}"
raise InternalApiError(
status=response.status_code,
reason=f"Error updating inspection metadata on inspection {inspection_id}.",
http_resp=response,
)

@RequestsRetryDecorator()
Expand All @@ -334,7 +344,11 @@ def stop_inspection(self, inspection_id: str) -> str:
response = requests.request("PATCH", url, headers=headers, json=payload)

if not is_ok(response.status_code):
raise InspectionError(f"Error stopping inspection {inspection_id}. Status code: {response.status_code}")
raise InternalApiError(
status=response.status_code,
reason=f"Error stopping inspection {inspection_id}.",
http_resp=response,
)

return response.json()["result"]

Expand All @@ -356,4 +370,8 @@ def update_detector_confidence_threshold(self, detector_id: str, confidence_thre
response = requests.request("PATCH", url, headers=headers, json=payload)

if not is_ok(response.status_code):
raise UpdateDetectorError(f"Error updating detector {detector_id}. Status code: {response.status_code}")
raise InternalApiError(
status=response.status_code,
reason=f"Error updating detector {detector_id}.",
http_resp=response,
)

0 comments on commit 351c7c1

Please sign in to comment.