From acfa30e519be9750caa8266e8dc443a160c238dd Mon Sep 17 00:00:00 2001 From: Lionel Teo <93119265+imda-lionelteo@users.noreply.github.com> Date: Wed, 13 Nov 2024 10:57:39 +0800 Subject: [PATCH] push fix for MS-753 --- runners-modules/benchmarking.py | 27 +++++++++++++++++---------- 1 file changed, 17 insertions(+), 10 deletions(-) diff --git a/runners-modules/benchmarking.py b/runners-modules/benchmarking.py index 7c55e4e..124c842 100644 --- a/runners-modules/benchmarking.py +++ b/runners-modules/benchmarking.py @@ -40,7 +40,7 @@ class Benchmarking: """ sql_read_runner_cache_record = """ - SELECT * from runner_cache_table WHERE connection_id=? AND recipe_id=? + SELECT * from runner_cache_table WHERE connection_id=? AND recipe_id=? AND dataset_id=? AND prompt_template_id=? AND prompt=? """ BATCH_SIZE = 10 @@ -733,17 +733,24 @@ async def _get_dataset_prompts( # Retrieve dataset arguments ds_args = Dataset.read(ds_id) - # Generate a list of prompt indices based on prompt_selection_percentage and random_seed - self.num_of_prompts = int( - (self.prompt_selection_percentage / 100) * ds_args.num_of_dataset_prompts - ) - if self.num_of_prompts == ds_args.num_of_dataset_prompts: - prompt_indices = range(ds_args.num_of_dataset_prompts) + if ds_args.num_of_dataset_prompts == 0: + prompt_indices = [] else: - random.seed(self.random_seed) - prompt_indices = random.sample( - range(ds_args.num_of_dataset_prompts), self.num_of_prompts + # Generate a list of prompt indices based on prompt_selection_percentage and random_seed + self.num_of_prompts = max( + 1, + int( + (self.prompt_selection_percentage / 100) + * ds_args.num_of_dataset_prompts + ), ) + if self.num_of_prompts == ds_args.num_of_dataset_prompts: + prompt_indices = range(ds_args.num_of_dataset_prompts) + else: + random.seed(self.random_seed) + prompt_indices = random.sample( + range(ds_args.num_of_dataset_prompts), self.num_of_prompts + ) logger.debug( f"[Benchmarking] Dataset {ds_id}, using {len(prompt_indices)} of {ds_args.num_of_dataset_prompts} prompts." )