Skip to content

Commit

Permalink
modify add_lora and remove_lora
Browse files Browse the repository at this point in the history
  • Loading branch information
CNTRYROA committed Sep 18, 2024
1 parent c30c189 commit 66dc08b
Showing 1 changed file with 36 additions and 18 deletions.
54 changes: 36 additions & 18 deletions vllm/entrypoints/openai/api_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,6 @@ def model_is_embedding(model_name: str, trust_remote_code: bool) -> bool:

@asynccontextmanager
async def lifespan(app: FastAPI):

async def _force_log():
while True:
await asyncio.sleep(10)
Expand Down Expand Up @@ -265,27 +264,45 @@ def generate_unique_int(exclude_list, start=0, end=100):
@router.post("/v1/add_lora")
async def add_lora(request: LoRARequestPOJO):
async with lock:
filtered = filter(lambda x: x.lora_name == request.lora_name, openai_serving_chat.lora_requests)
if len(list(filtered)) > 0:
return JSONResponse(content={"error": "duplicated name: %s" % request.lora_name},
status_code=409)

exclude_list = [x.lora_int_id for x in openai_serving_chat.lora_requests]
unique_int = generate_unique_int(exclude_list)

lora_request = LoRARequest(lora_int_id=unique_int, **request.dict())
openai_serving_completion.lora_requests.append(lora_request)
openai_serving_chat.lora_requests.append(lora_request)
try:
filtered = filter(lambda x: x.lora_name == request.lora_name, openai_serving_chat.lora_requests)
if len(list(filtered)) > 0:
return JSONResponse(content={"error": "duplicated name: %s" % request.lora_name},
status_code=409)

exclude_list = [x.lora_int_id for x in openai_serving_chat.lora_requests]
unique_int = generate_unique_int(exclude_list)

lora_request = LoRARequest(lora_int_id=unique_int, **request.dict())
if hasattr(openai_serving_completion.async_engine_client, "engine"):
openai_serving_completion.async_engine_client.engine.add_lora(lora_request)
openai_serving_completion.lora_requests.append(lora_request)
openai_serving_chat.lora_requests.append(lora_request)
except Exception:
import traceback
import sys
return JSONResponse(
content={"stacktrace": '\n'.join(str(x) for x in traceback.format_exception(*sys.exc_info()))},
status_code=409)
return JSONResponse(content={})


@router.get("/v1/remove_lora")
async def remove_lora(lora_name: str):
async with lock:
filtered = filter(lambda x: x.lora_name == lora_name, openai_serving_chat.lora_requests)
for item in list(filtered):
openai_serving_completion.lora_requests.remove(item)
openai_serving_chat.lora_requests.remove(item)
try:
filtered = filter(lambda x: x.lora_name == lora_name, openai_serving_chat.lora_requests)
for item in list(filtered):
openai_serving_completion.lora_requests.remove(item)
openai_serving_chat.lora_requests.remove(item)
if hasattr(openai_serving_completion.async_engine_client, "engine"):
openai_serving_completion.async_engine_client.engine.remove_lora(item.lora_int_id)
except Exception:
import traceback
import sys
return JSONResponse(
content={"stacktrace": '\n'.join(str(x) for x in traceback.format_exception(*sys.exc_info()))},
status_code=409)
return JSONResponse(content={})


Expand Down Expand Up @@ -318,6 +335,7 @@ async def sample_data0(llm_api_type: str | None = None):
}
})


def build_app(args: Namespace) -> FastAPI:
app = FastAPI(lifespan=lifespan)
app.include_router(router)
Expand Down Expand Up @@ -368,8 +386,8 @@ async def authentication(request: Request, call_next):


async def init_app(
async_engine_client: AsyncEngineClient,
args: Namespace,
async_engine_client: AsyncEngineClient,
args: Namespace,
) -> FastAPI:
app = build_app(args)

Expand Down

0 comments on commit 66dc08b

Please sign in to comment.