Skip to content

Commit

Permalink
Update API type based on merged scoring url (#3369)
Browse files Browse the repository at this point in the history
* Update API type based on merged scoring url

* address comment
  • Loading branch information
novaturient95 authored Sep 21, 2024
1 parent fc1f599 commit 3def597
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from aml_benchmark.utils.exceptions import swallow_all_exceptions
from aml_benchmark.utils.aml_run_utils import str2bool
from aml_benchmark.utils.exceptions import BenchmarkUserException
from aml_benchmark.utils.constants import AuthenticationType, get_endpoint_type
from aml_benchmark.utils.constants import AuthenticationType, get_api_type, get_endpoint_type
from aml_benchmark.utils.error_definitions import BenchmarkUserError
from azureml._common._error_definition.azureml_error import AzureMLError

Expand Down Expand Up @@ -307,7 +307,7 @@ def main(

config_dict = {
"api": {
"type": "completion",
"type": get_api_type(merged_scoring_url),
"response_segment_size": response_segment_size
},
"authentication": authentication_dict,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,36 @@ class IntermediateNames:
_DEFAULT_URL_TYPE = "azureml_online_endpoint"


class ApiType():
"""Api Type."""

Unknown = 'unknown'
Completion = 'completion'
ChatCompletion = 'chat_completion'


COMPLETION_API_SUFFIX_LIST = ["v1/completions"]
CHAT_COMPLETION_API_SUFFIX_LIST = ["v1/chat/completions"]
DEFAULT_API_TYPE = ApiType.Completion
API_TYPE_MAPPING = {
ApiType.Completion: COMPLETION_API_SUFFIX_LIST,
ApiType.ChatCompletion: CHAT_COMPLETION_API_SUFFIX_LIST
}


def get_api_type(url: str) -> str:
"""Get the api type for a given endpoint URL.
:param url: The URL of the endpoint.
:return: API type of the endpoint.
"""
return next((
api_type for api_type, suffixes in API_TYPE_MAPPING.items()
if any(suffix in url for suffix in suffixes)),
DEFAULT_API_TYPE
)


def get_endpoint_type(url: str) -> str:
"""
Get the endpoint type for a given endpoint URL.
Expand Down

0 comments on commit 3def597

Please sign in to comment.