diff --git a/aviary/backend/server/models.py b/aviary/backend/server/models.py index 3e5c4440..28d9e8a5 100644 --- a/aviary/backend/server/models.py +++ b/aviary/backend/server/models.py @@ -508,7 +508,7 @@ class AppArgs(BaseModel): class RouterArgs(BaseModel): - models: Union[str, LLMApp, List[Union[LLMApp, str]]] + models: Dict[str, Union[str, LLMApp]] class PlacementConfig(BaseModel): diff --git a/aviary/backend/server/run.py b/aviary/backend/server/run.py index 94c97991..e086bb97 100644 --- a/aviary/backend/server/run.py +++ b/aviary/backend/server/run.py @@ -7,7 +7,7 @@ from aviary.backend.llm.vllm.vllm_engine import VLLMEngine from aviary.backend.llm.vllm.vllm_models import VLLMApp from aviary.backend.server.app import RouterDeployment -from aviary.backend.server.models import LLMApp, RouterArgs, ScalingConfig +from aviary.backend.server.models import LLMApp, ScalingConfig from aviary.backend.server.plugins.deployment_base_client import DeploymentBaseClient from aviary.backend.server.plugins.execution_hooks import ( ExecutionHooks, @@ -111,8 +111,7 @@ def router_deployment( def router_application(args): - router_args = RouterArgs.parse_obj(args) - llm_apps = parse_args(router_args.models, llm_app_cls=VLLMApp) + llm_apps = parse_args(args, llm_app_cls=VLLMApp) return router_deployment(llm_apps, enable_duplicate_models=False)