Skip to content

Commit

Permalink
[FaqGen] Fix the metrics parse and statistics for benchmark (#215)
Browse files Browse the repository at this point in the history
* Use data param instead of json to send request for faqgen

Signed-off-by: Wang, Kai Lawrence <[email protected]>

* Fix the input statistics for faqgen benchmark

Signed-off-by: Wang, Kai Lawrence <[email protected]>

* Update the default prompt for faqgenfixed

Signed-off-by: Wang, Kai Lawrence <[email protected]>

* Implement the complete_response for the streaming output

Signed-off-by: Wang, Kai Lawrence <[email protected]>

* Set topK=1 for faqgenfixed

Signed-off-by: Wang, Kai Lawrence <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: Wang, Kai Lawrence <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
wangkl2 and pre-commit-ci[bot] authored Dec 21, 2024
1 parent d2d2beb commit 5d717e8
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 8 deletions.
22 changes: 21 additions & 1 deletion evals/benchmark/stresscli/locust/aistress.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,12 +120,16 @@ def bench_main(self):
"faqgenfixed",
"faqgenbench",
]
if self.environment.parsed_options.bench_target in ["faqgenfixed", "faqgenbench"]:
req_params = {"data": reqData}
else:
req_params = {"json": reqData}
test_start_time = time.time()
try:
start_ts = time.perf_counter()
with self.client.post(
url,
json=reqData,
**req_params,
stream=True if self.environment.parsed_options.bench_target in streaming_bench_target else False,
catch_response=True,
timeout=self.environment.parsed_options.http_timeout,
Expand Down Expand Up @@ -169,6 +173,22 @@ def bench_main(self):
complete_response += content
except json.JSONDecodeError:
continue
elif self.environment.parsed_options.bench_target in ["faqgenfixed", "faqgenbench"]:
client = sseclient.SSEClient(resp)
for event in client.events():
if first_token_ts is None:
first_token_ts = time.perf_counter()
try:
data = json.loads(event.data)
for op in data["ops"]:
if op["path"] == "/logs/HuggingFaceEndpoint/final_output":
generations = op["value"].get("generations", [])
for generation in generations:
for item in generation:
text = item.get("text", "")
complete_response += text
except json.JSONDecodeError:
continue
else:
client = sseclient.SSEClient(resp)
for event in client.events():
Expand Down
11 changes: 5 additions & 6 deletions evals/benchmark/stresscli/locust/faqgenfixed.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,11 @@ def getUrl():


def getReqData():
# return {
# "inputs": "What is the revenue of Nike in last 10 years before 2023? Give me detail",
# "parameters": {"max_new_tokens": 128, "do_sample": True},
# }
# return {"query": "What is the revenue of Nike in last 10 years before 2023? Give me detail", "max_tokens": 128}
return {"messages": "What is the revenue of Nike in last 10 years before 2023? Give me detail", "max_tokens": 128}
return {
"messages": "Text Embeddings Inference (TEI) is a toolkit for deploying and serving open source text embeddings and sequence classification models. TEI enables high-performance extraction for the most popular models, including FlagEmbedding, Ember, GTE and E6.",
"max_tokens": 128,
"top_k": 1,
}


def respStatics(environment, reqData, respData):
Expand Down
2 changes: 1 addition & 1 deletion evals/benchmark/stresscli/locust/tokenresponse.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def testFunc():

def respStatics(environment, req, resp):
tokenizer = transformers.AutoTokenizer.from_pretrained(environment.parsed_options.llm_model)
if environment.parsed_options.bench_target in ["chatqnafixed", "chatqnabench"]:
if environment.parsed_options.bench_target in ["chatqnafixed", "chatqnabench", "faqgenfixed", "faqgenbench"]:
num_token_input_prompt = len(tokenizer.encode(req["messages"]))
elif environment.parsed_options.bench_target in ["llmfixed"]:
num_token_input_prompt = len(tokenizer.encode(req["query"]))
Expand Down

0 comments on commit 5d717e8

Please sign in to comment.