diff --git a/launch/client.py b/launch/client.py index 5f230a3c..5fcc2358 100644 --- a/launch/client.py +++ b/launch/client.py @@ -1091,7 +1091,7 @@ def read_endpoint_creation_logs( def _sync_request( self, - endpoint_name: str, + endpoint_id: str, url: Optional[str] = None, args: Optional[Dict] = None, return_pickled: bool = False, @@ -1102,7 +1102,7 @@ def _sync_request( Endpoint at endpoint_id must be a SyncEndpoint, otherwise this request will fail. Parameters: - endpoint_name: The name of the endpoint to make the request to + endpoint_id: The ID of the endpoint to make the request to url: A url that points to a file containing model input. Must be accessible by Scale Launch, hence it needs to either be public or a signedURL. @@ -1129,8 +1129,6 @@ def _sync_request( and the value is the output of the endpoint's ``predict`` function, serialized as json. """ validate_task_request(url=url, args=args) - endpoint = self.get_model_endpoint(endpoint_name) - endpoint_id = endpoint.model_endpoint.id # type: ignore with ApiClient(self.configuration) as api_client: api_instance = DefaultApi(api_client) payload = dict_not_none( @@ -1148,7 +1146,7 @@ def _sync_request( def _async_request( self, - endpoint_name: str, + endpoint_id: str, *, url: Optional[str] = None, args: Optional[Dict] = None, @@ -1160,7 +1158,7 @@ def _async_request( the result of inference at a later time. Parameters: - endpoint_name: The name of the endpoint to make the request to + endpoint_id: The ID of the endpoint to make the request to url: A url that points to a file containing model input. Must be accessible by Scale Launch, hence it needs to either be public or a signedURL. @@ -1187,7 +1185,6 @@ def _async_request( `abcabcab-cabc-abca-0123456789ab` """ validate_task_request(url=url, args=args) - endpoint = self.get_model_endpoint(endpoint_name) with ApiClient(self.configuration) as api_client: api_instance = DefaultApi(api_client) payload = dict_not_none( @@ -1197,8 +1194,7 @@ def _async_request( callback_url=callback_url, ) request = EndpointPredictRequest(**payload) - model_endpoint_id = endpoint.model_endpoint.id # type: ignore - query_params = frozendict({"model_endpoint_id": model_endpoint_id}) + query_params = frozendict({"model_endpoint_id": endpoint_id}) response = api_instance.create_async_inference_task_v1_async_tasks_post( # type: ignore body=request, query_params=query_params, # type: ignore diff --git a/launch/model_endpoint.py b/launch/model_endpoint.py index 6fe4a6a8..0ff1c375 100644 --- a/launch/model_endpoint.py +++ b/launch/model_endpoint.py @@ -314,7 +314,7 @@ def predict(self, request: EndpointRequest) -> EndpointResponse: request: The ``EndpointRequest`` object that contains the payload. """ raw_response = self.client._sync_request( # pylint: disable=W0212 - self.model_endpoint.name, + self.model_endpoint.id, url=request.url, args=request.args, return_pickled=request.return_pickled, @@ -367,7 +367,7 @@ def predict(self, request: EndpointRequest) -> EndpointResponseFuture: result = f.get() # blocks on completion """ response = self.client._async_request( # pylint: disable=W0212 - self.model_endpoint.name, + self.model_endpoint.id, url=request.url, args=request.args, callback_url=request.callback_url,