Commit e70a631

[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
pre-commit-ci[bot] committed May 22, 2024
1 parent cf5f8b0 commit e70a631
Showing 1 changed file with 19 additions and 20 deletions.

comps/llms/local/llm_hf.py (19 additions, 20 deletions)

@@ -12,20 +12,19 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import sys
 import os
+import sys
+from typing import List
 
+import lm_eval.api.registry
+import torch
+from docarray import BaseDoc
 from fastapi.responses import StreamingResponse
+from GenAIEval.evaluation.lm_evaluation_harness.lm_eval.models.huggingface import HFLM, GaudiHFModelAdapter
 from langchain_community.llms import HuggingFaceEndpoint
 
 from comps import GeneratedDoc, ServiceType, opea_microservices, opea_telemetry, register_microservice
 
-from docarray import BaseDoc
-from typing import List
-
-import torch
-import lm_eval.api.registry
-from GenAIEval.evaluation.lm_evaluation_harness.lm_eval.models.huggingface import HFLM, GaudiHFModelAdapter
 lm_eval.api.registry.MODEL_REGISTRY["hf"] = HFLM
 lm_eval.api.registry.MODEL_REGISTRY["gaudi-hf"] = GaudiHFModelAdapter
 
@@ -42,12 +41,12 @@ class LLMCompletionDoc(BaseDoc):
 device = os.getenv("DEVICE", "")
 
 llm = lm_eval.api.registry.get_model(model).create_from_arg_string(
-        model_args,
-        {
-            "batch_size": 1,  # dummy
-            "max_batch_size": None,
-            "device": device,
-        },
+    model_args,
+    {
+        "batch_size": 1,  # dummy
+        "max_batch_size": None,
+        "device": device,
+    },
 )
 
 
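
Note: create_from_arg_string (unchanged by this commit) parses a comma-separated key=value string into model constructor kwargs and merges it with the explicit config dict. A minimal sketch of exercising the same loader outside the service, assuming hypothetical values for the model and model_args settings that are read from the environment in lines not expanded in this diff:

import os

import lm_eval.api.registry
from lm_eval.models.huggingface import HFLM  # upstream class; the service uses GenAIEval's copy

lm_eval.api.registry.MODEL_REGISTRY["hf"] = HFLM  # same registration pattern as in the import hunk above

# Hypothetical settings for a local smoke test; the real values come from
# os.getenv(...) calls in parts of llm_hf.py not shown in this diff.
model = "hf"
model_args = "pretrained=gpt2,dtype=float32"
device = "cpu"

# Same call pattern as the hunk above: the arg string becomes constructor
# kwargs and is merged with the explicit config dict.
llm = lm_eval.api.registry.get_model(model).create_from_arg_string(
    model_args,
    {"batch_size": 1, "max_batch_size": None, "device": device},
)
print(type(llm.model).__name__)
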
@@ -58,13 +57,10 @@ class LLMCompletionDoc(BaseDoc):
     host="0.0.0.0",
     port=9000,
 )
-
 @opea_telemetry
 def llm_generate(input: LLMCompletionDoc):
     global llm
-    batched_inputs = torch.tensor(input.batched_inputs,
-                                  dtype=torch.long,
-                                  device=llm.device)
+    batched_inputs = torch.tensor(input.batched_inputs, dtype=torch.long, device=llm.device)
     with torch.no_grad():
         # TODO, use model.generate.
         logits = llm.model(batched_inputs).logits
@@ -74,9 +70,12 @@ def llm_generate(input: LLMCompletionDoc):
     greedy_tokens = logits.argmax(dim=-1)
     logprobs = torch.gather(logits, 2, batched_inputs[:, 1:].unsqueeze(-1)).squeeze(-1)
 
-    return {"greedy_tokens": greedy_tokens.detach().cpu().tolist(),
-            "logprobs": logprobs.detach().cpu().tolist(),
-            "batched_inputs": input.batched_inputs}
+    return {
+        "greedy_tokens": greedy_tokens.detach().cpu().tolist(),
+        "logprobs": logprobs.detach().cpu().tolist(),
+        "batched_inputs": input.batched_inputs,
+    }
+
 
 if __name__ == "__main__":
     opea_microservices["opea_service@llm_hf"].start()
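
Note: the reformatted return block above packages per-position argmax tokens and gathered log-probabilities. A toy sketch of the same gather pattern with made-up shapes (the actual model call, log-softmax, and any logits slicing happen in lines not expanded in this diff):

import torch

batch, seq_len, vocab = 2, 5, 11
# Stand-in for log-softmaxed model scores over positions 1..seq_len-1.
logits = torch.randn(batch, seq_len - 1, vocab).log_softmax(dim=-1)
batched_inputs = torch.randint(vocab, (batch, seq_len))

greedy_tokens = logits.argmax(dim=-1)  # most likely token id at each position
# Log-probability the model assigned to each token that actually followed.
logprobs = torch.gather(logits, 2, batched_inputs[:, 1:].unsqueeze(-1)).squeeze(-1)

print(greedy_tokens.shape, logprobs.shape)  # torch.Size([2, 4]) torch.Size([2, 4])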

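A minimal sketch of calling the microservice once it is running, assuming a hypothetical route; only host 0.0.0.0 and port 9000 are visible in the diff, and the actual endpoint path is set in the register_microservice(...) arguments that are not expanded here:

import requests

# Hypothetical endpoint path; adjust to the route configured in register_microservice(...).
url = "http://localhost:9000/v1/completions"

payload = {"batched_inputs": [[101, 2023, 2003, 1037, 3231, 102]]}  # example token ids
resp = requests.post(url, json=payload, timeout=60)
resp.raise_for_status()

out = resp.json()
print(out["greedy_tokens"])  # argmax token id per position
print(out["logprobs"])       # log-prob of each actual next token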