diff --git a/token_benchmark_ray.py b/token_benchmark_ray.py index a8c7754..40c8523 100644 --- a/token_benchmark_ray.py +++ b/token_benchmark_ray.py @@ -25,6 +25,7 @@ from transformers import LlamaTokenizerFast + def get_token_throughput_latencies( model: str, mean_input_tokens: int, @@ -63,7 +64,7 @@ def get_token_throughput_latencies( "hf-internal-testing/llama-tokenizer" ) get_token_length = lambda text: len(tokenizer.encode(text)) - + if not additional_sampling_params: additional_sampling_params = {} @@ -75,17 +76,19 @@ def get_token_throughput_latencies( num_output_tokens_list = [] prompts = [] for i in range(max_num_completed_requests): - num_output_tokens = (sample_random_positive_int( + num_output_tokens = sample_random_positive_int( mean_output_tokens, stddev_output_tokens - )) + ) num_output_tokens_list.append(num_output_tokens) - prompts.append(randomly_sample_sonnet_lines_prompt( - prompt_tokens_mean=mean_input_tokens, - prompt_tokens_stddev=stddev_input_tokens, - expect_output_tokens=num_output_tokens, - tokenizer=tokenizer - )) + prompts.append( + randomly_sample_sonnet_lines_prompt( + prompt_tokens_mean=mean_input_tokens, + prompt_tokens_stddev=stddev_input_tokens, + expect_output_tokens=num_output_tokens, + tokenizer=tokenizer, + ) + ) start_time = time.monotonic() iter = 0 pbar = tqdm(total=max_num_completed_requests) @@ -113,13 +116,18 @@ def get_token_throughput_latencies( for out in outs: request_metrics, gen_text, _ = out num_output_tokens = get_token_length(gen_text) - if num_output_tokens: + if num_output_tokens: request_metrics[common_metrics.INTER_TOKEN_LAT] /= num_output_tokens else: request_metrics[common_metrics.INTER_TOKEN_LAT] = 0 request_metrics[common_metrics.NUM_OUTPUT_TOKENS] = num_output_tokens - request_metrics[common_metrics.NUM_TOTAL_TOKENS] = request_metrics[common_metrics.NUM_INPUT_TOKENS] + num_output_tokens - request_metrics[common_metrics.REQ_OUTPUT_THROUGHPUT] = num_output_tokens / request_metrics[common_metrics.E2E_LAT] + request_metrics[common_metrics.NUM_TOTAL_TOKENS] = ( + request_metrics[common_metrics.NUM_INPUT_TOKENS] + num_output_tokens + ) + if request_metrics[common_metrics.E2E_LAT]: + request_metrics[common_metrics.REQ_OUTPUT_THROUGHPUT] = ( + num_output_tokens / request_metrics[common_metrics.E2E_LAT] + ) all_metrics.append(request_metrics) completed_requests.extend(all_metrics) pbar.update(len(completed_requests) - num_completed_requests) @@ -136,14 +144,18 @@ def get_token_throughput_latencies( for out in outs: request_metrics, gen_text, _ = out num_output_tokens = get_token_length(gen_text) - if num_output_tokens: + if num_output_tokens: request_metrics[common_metrics.INTER_TOKEN_LAT] /= num_output_tokens else: request_metrics[common_metrics.INTER_TOKEN_LAT] = 0 request_metrics[common_metrics.NUM_OUTPUT_TOKENS] = num_output_tokens - request_metrics[common_metrics.NUM_TOTAL_TOKENS] = request_metrics[common_metrics.NUM_INPUT_TOKENS] + num_output_tokens - request_metrics[common_metrics.REQ_OUTPUT_THROUGHPUT] = num_output_tokens / request_metrics[common_metrics.E2E_LAT] - + request_metrics[common_metrics.NUM_TOTAL_TOKENS] = ( + request_metrics[common_metrics.NUM_INPUT_TOKENS] + num_output_tokens + ) + request_metrics[common_metrics.REQ_OUTPUT_THROUGHPUT] = ( + num_output_tokens / request_metrics[common_metrics.E2E_LAT] + ) + all_metrics.append(request_metrics) completed_requests.extend(all_metrics) @@ -161,7 +173,7 @@ def get_token_throughput_latencies( } metadata["results"] = ret - + return metadata, completed_requests @@ -200,14 +212,14 @@ def flatten(item): df = pd.DataFrame(metrics) df_without_errored_req = df[df[common_metrics.ERROR_CODE].isna()] - + for key in [ common_metrics.INTER_TOKEN_LAT, common_metrics.TTFT, common_metrics.E2E_LAT, common_metrics.REQ_OUTPUT_THROUGHPUT, common_metrics.NUM_INPUT_TOKENS, - common_metrics.NUM_OUTPUT_TOKENS + common_metrics.NUM_OUTPUT_TOKENS, ]: print(key) ret[key] = {} @@ -259,7 +271,7 @@ def flatten(item): ret[common_metrics.NUM_COMPLETED_REQUESTS] = num_completed_requests ret[common_metrics.COMPLETED_REQUESTS_PER_MIN] = num_completed_requests_per_min - + return ret