diff --git a/pyproject.toml b/pyproject.toml index 5e7c1ffd..dbff7c63 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,7 +9,7 @@ packages = [ {include = "**/*.py", from = "src"}, ] readme = "README.md" -version = "0.10.0" +version = "0.10.1" [tool.poetry.dependencies] certifi = "^2021.10.8" diff --git a/src/groundlight/client.py b/src/groundlight/client.py index 60d84eaa..463783e9 100644 --- a/src/groundlight/client.py +++ b/src/groundlight/client.py @@ -171,6 +171,7 @@ def submit_image_query( image: Union[str, bytes, Image.Image, BytesIO, BufferedReader, np.ndarray], wait: Optional[float] = None, human_review: Optional[bool] = True, + inspection_id: Optional[str] = None, ) -> ImageQuery: """Evaluates an image with Groundlight. :param detector: the Detector object, or string id of a detector like `det_12345` @@ -184,22 +185,49 @@ def submit_image_query( converted to JPEG at high quality before sending to service. :param wait: How long to wait (in seconds) for a confident answer. :param human_review: If set to False, do not escalate for human review + :param inspection_id: Most users will omit this. For accounts with Inspection Reports enabled, + this is the ID of the inspection to associate with the image query. """ if wait is None: wait = self.DEFAULT_WAIT - detector_id = detector.id if isinstance(detector, Detector) else detector + + # Convert from Detector to detector_id if necessary + if isinstance(detector, Detector): + detector_id = self.get_detector(detector) + else: + detector_id = detector image_bytesio: ByteStreamWrapper = parse_supported_image_types(image) - raw_image_query = self.image_queries_api.submit_image_query( - detector_id=detector_id, patience_time=wait, human_review=human_review, body=image_bytesio - ) - image_query = ImageQuery.parse_obj(raw_image_query.to_dict()) + # Submit Image Query + # If no inspection_id is provided, we submit the image query using image_queries_api (autogenerated via OpenAPI) + # However, our autogenerated code does not support inspection_id, so if an inspection_id was provided, we use + # the private API client instead. + if inspection_id is None: + raw_image_query = self.image_queries_api.submit_image_query( + detector_id=detector_id, + patience_time=wait, + human_review=human_review, + body=image_bytesio + ) + image_query_dict = raw_image_query.to_dict() + image_query = ImageQuery.parse_obj(image_query_dict) + else: + print(f"Using private API client to submit image query with human_review: {human_review}") + iq_id = self.api_client.submit_image_query_with_inspection( + detector_id=detector_id, + patience_time=wait, + human_review=human_review, + image=image_bytesio, + inspection_id=inspection_id + ) + image_query = self.get_image_query(iq_id) + if wait: threshold = self.get_detector(detector).confidence_threshold image_query = self.wait_for_confident_result(image_query, confidence_threshold=threshold, timeout_sec=wait) return self._fixup_image_query(image_query) - + def wait_for_confident_result( self, image_query: ImageQuery, @@ -212,7 +240,10 @@ def wait_for_confident_result( :param confidence_threshold: The minimum confidence level required to return before the timeout. :param timeout_sec: The maximum number of seconds to wait. """ - # TODO: Add support for ImageQuery id instead of object. + # Convert from image_query_id to ImageQuery if needed. + if isinstance(image_query, str): + image_query = self.get_image_query(image_query) + start_time = time.time() next_delay = self.POLLING_INITIAL_DELAY target_delay = 0.0 @@ -253,3 +284,27 @@ def add_label(self, image_query: Union[ImageQuery, str], label: Union[Label, str api_label = convert_display_label_to_internal(image_query_id, label) return self.api_client._add_label(image_query_id, api_label) # pylint: disable=protected-access + + def start_inspection(self) -> str: + """For users with Inspection Reports enabled only. + Starts an inspection report and returns the id of the inspection. + """ + return self.api_client.start_inspection() + + def update_inspection_metadata(self, inspection_id: str, user_provided_key: str, user_provided_value: str) -> None: + """For users with Inspection Reports enabled only. + Add/update inspection metadata with the user_provided_key and user_provided_value. + """ + self.api_client.update_inspection_metadata(inspection_id, user_provided_key, user_provided_value) + + def stop_inspection(self, inspection_id: str) -> str: + """For users with Inspection Reports enabled only. + Stops an inspection and raises an exception if the response from the server + indicates that the inspection was not successfully stopped. + Returns a str with result of the inspection (either PASS or FAIL) + """ + return self.api_client.stop_inspection(inspection_id) + + def update_detector_confidence_threshold(self, detector_id: str, confidence_threshold: float) -> None: + """Updates the confidence threshold of a detector given a detector_id.""" + self.api_client.update_detector_confidence_threshold(detector_id, confidence_threshold) \ No newline at end of file diff --git a/src/groundlight/internalapi.py b/src/groundlight/internalapi.py index 74de8a0d..dff46b6b 100644 --- a/src/groundlight/internalapi.py +++ b/src/groundlight/internalapi.py @@ -17,19 +17,9 @@ logger = logging.getLogger("groundlight.sdk") - class NotFoundError(Exception): pass - -class InspectionError(Exception): - pass - - -class UpdateDetectorError(Exception): - pass - - def sanitize_endpoint_url(endpoint: Optional[str] = None) -> str: """Takes a URL for an endpoint, and returns a "sanitized" version of it. Currently the production API path must be exactly "/device-api". @@ -237,16 +227,28 @@ def _get_detector_by_name(self, name: str) -> Detector: return Detector.parse_obj(parsed["results"][0]) @RequestsRetryDecorator() - def submit_image_query_with_inspection(self, detector_id: str, image: ByteStreamWrapper, inspection_id: str) -> str: + def submit_image_query_with_inspection(self, + detector_id: str, + patience_time: float, + human_review: bool, + image: ByteStreamWrapper, + inspection_id: str) -> str: """Submits an image query to the API and returns the ID of the image query. The image query will be associated to the inspection_id provided. """ - url = f"{self.configuration.host}/posichecks?inspection_id={inspection_id}&predictor_id={detector_id}" + + # In the API, "send_notification" was used to escalate image queries to cloud labelers. + send_notification = human_review + + url = (f"{self.configuration.host}/posichecks" + f"?inspection_id={inspection_id}" + f"&predictor_id={detector_id}" + f"&send_notification={send_notification}") headers = self._headers() headers["Content-Type"] = "image/jpeg" - response = requests.request("POST", url, headers=headers, data=image.read()) + response = requests.request("POST", url, headers=headers, timeout=patience_time, data=image.read()) if not is_ok(response.status_code): logger.info(response) @@ -268,12 +270,16 @@ def start_inspection(self) -> str: response = requests.request("POST", url, headers=headers, json={}) if not is_ok(response.status_code): - raise InspectionError(f"Error starting inspection. Status code: {response.status_code}") + raise InternalApiError( + status=response.status_code, + reason="Error starting inspection.", + http_resp=response, + ) return response.json()["id"] @RequestsRetryDecorator() - def update_inspection_metadata(self, inspection_id: str, user_provided_key, user_provided_value) -> None: + def update_inspection_metadata(self, inspection_id: str, user_provided_key: str, user_provided_value: str) -> None: """Add/update inspection metadata with the user_provided_key and user_provided_value. The API stores inspections metadata in two ways: @@ -293,15 +299,17 @@ def update_inspection_metadata(self, inspection_id: str, user_provided_key, user response = requests.request("GET", url, headers=headers) if not is_ok(response.status_code): - raise InspectionError( - f"Error getting inspection details for inspection {inspection_id}. Status code: {response.status_code}" + raise InternalApiError( + status=response.status_code, + reason=f"Error getting inspection details for inspection {inspection_id}.", + http_resp=response, ) payload = {} # Set the user_provided_id_key and user_provided_id_value if they were not previously set. response_json = response.json() - if not response_json["user_provided_id_key"]: + if not response_json.get("user_provided_id_key"): payload["user_provided_id_key"] = user_provided_key payload["user_provided_id_value"] = user_provided_value @@ -316,8 +324,10 @@ def update_inspection_metadata(self, inspection_id: str, user_provided_key, user response = requests.request("PATCH", url, headers=headers, json=payload) if not is_ok(response.status_code): - raise InspectionError( - f"Error updating inspection metadata on inspection {inspection_id}. Status code: {response.status_code}" + raise InternalApiError( + status=response.status_code, + reason=f"Error updating inspection metadata on inspection {inspection_id}.", + http_resp=response, ) @RequestsRetryDecorator() @@ -334,7 +344,11 @@ def stop_inspection(self, inspection_id: str) -> str: response = requests.request("PATCH", url, headers=headers, json=payload) if not is_ok(response.status_code): - raise InspectionError(f"Error stopping inspection {inspection_id}. Status code: {response.status_code}") + raise InternalApiError( + status=response.status_code, + reason=f"Error stopping inspection {inspection_id}.", + http_resp=response, + ) return response.json()["result"] @@ -356,4 +370,8 @@ def update_detector_confidence_threshold(self, detector_id: str, confidence_thre response = requests.request("PATCH", url, headers=headers, json=payload) if not is_ok(response.status_code): - raise UpdateDetectorError(f"Error updating detector {detector_id}. Status code: {response.status_code}") + raise InternalApiError( + status=response.status_code, + reason=f"Error updating detector {detector_id}.", + http_resp=response, + )