Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use endpoint IDs for endpoint predict requests #78

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 5 additions & 9 deletions launch/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1091,7 +1091,7 @@ def read_endpoint_creation_logs(

def _sync_request(
self,
endpoint_name: str,
endpoint_id: str,
url: Optional[str] = None,
args: Optional[Dict] = None,
return_pickled: bool = False,
Expand All @@ -1102,7 +1102,7 @@ def _sync_request(
Endpoint at endpoint_id must be a SyncEndpoint, otherwise this request will fail.

Parameters:
endpoint_name: The name of the endpoint to make the request to
endpoint_id: The ID of the endpoint to make the request to

url: A url that points to a file containing model input.
Must be accessible by Scale Launch, hence it needs to either be public or a signedURL.
Expand All @@ -1129,8 +1129,6 @@ def _sync_request(
and the value is the output of the endpoint's ``predict`` function, serialized as json.
"""
validate_task_request(url=url, args=args)
endpoint = self.get_model_endpoint(endpoint_name)
endpoint_id = endpoint.model_endpoint.id # type: ignore
with ApiClient(self.configuration) as api_client:
api_instance = DefaultApi(api_client)
payload = dict_not_none(
Expand All @@ -1148,7 +1146,7 @@ def _sync_request(

def _async_request(
self,
endpoint_name: str,
endpoint_id: str,
*,
url: Optional[str] = None,
args: Optional[Dict] = None,
Expand All @@ -1160,7 +1158,7 @@ def _async_request(
the result of inference at a later time.

Parameters:
endpoint_name: The name of the endpoint to make the request to
endpoint_id: The ID of the endpoint to make the request to

url: A url that points to a file containing model input.
Must be accessible by Scale Launch, hence it needs to either be public or a signedURL.
Expand All @@ -1187,7 +1185,6 @@ def _async_request(
`abcabcab-cabc-abca-0123456789ab`
"""
validate_task_request(url=url, args=args)
endpoint = self.get_model_endpoint(endpoint_name)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this just a postgres call?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, but also Redis to get some cached data about the infra state (and k8s api server if Redis cache doesn't contain the info)

with ApiClient(self.configuration) as api_client:
api_instance = DefaultApi(api_client)
payload = dict_not_none(
Expand All @@ -1197,8 +1194,7 @@ def _async_request(
callback_url=callback_url,
)
request = EndpointPredictRequest(**payload)
model_endpoint_id = endpoint.model_endpoint.id # type: ignore
query_params = frozendict({"model_endpoint_id": model_endpoint_id})
query_params = frozendict({"model_endpoint_id": endpoint_id})
response = api_instance.create_async_inference_task_v1_async_tasks_post( # type: ignore
body=request,
query_params=query_params, # type: ignore
Expand Down
4 changes: 2 additions & 2 deletions launch/model_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,7 +314,7 @@ def predict(self, request: EndpointRequest) -> EndpointResponse:
request: The ``EndpointRequest`` object that contains the payload.
"""
raw_response = self.client._sync_request( # pylint: disable=W0212
self.model_endpoint.name,
self.model_endpoint.id,
url=request.url,
args=request.args,
return_pickled=request.return_pickled,
Expand Down Expand Up @@ -367,7 +367,7 @@ def predict(self, request: EndpointRequest) -> EndpointResponseFuture:
result = f.get() # blocks on completion
"""
response = self.client._async_request( # pylint: disable=W0212
self.model_endpoint.name,
self.model_endpoint.id,
url=request.url,
args=request.args,
callback_url=request.callback_url,
Expand Down