From e01814633fa6c5ee14892528339403bd93aaf5ea Mon Sep 17 00:00:00 2001 From: Bruce Lee Date: Tue, 6 Feb 2024 22:32:10 -0500 Subject: [PATCH] added context key to MCQ --- .../data/data_config/task/socialiqa.yaml | 5 ++ nutcracker/data/instance.py | 67 +++++++++++++------ nutcracker/data/instance_collection.py | 66 ++++++++++++++++++ nutcracker/data/pile.py | 44 +----------- nutcracker/data/task.py | 44 +----------- setup.py | 2 +- 6 files changed, 121 insertions(+), 107 deletions(-) create mode 100644 nutcracker/data/data_config/task/socialiqa.yaml create mode 100644 nutcracker/data/instance_collection.py diff --git a/nutcracker/data/data_config/task/socialiqa.yaml b/nutcracker/data/data_config/task/socialiqa.yaml new file mode 100644 index 0000000..ee8aff9 --- /dev/null +++ b/nutcracker/data/data_config/task/socialiqa.yaml @@ -0,0 +1,5 @@ +filename: + test: test.json + dev: dev.json + config: config.yaml +task_name: socialiqa diff --git a/nutcracker/data/instance.py b/nutcracker/data/instance.py index c2c1479..3fd5abf 100644 --- a/nutcracker/data/instance.py +++ b/nutcracker/data/instance.py @@ -66,6 +66,11 @@ def __init__( self.options = test_data["options"] self.correct_options = test_data["correct_options"] + # Check if 'context' key exists in test_data + self.context_exists = 'context' in test_data and test_data['context'] + if self.context_exists: + self.context = test_data["context"] + # below are derivational attributes that will be updated during code run self.user_prompt = self._format_user_prompt() self.model_response = None @@ -74,17 +79,15 @@ def __init__( - def _format_user_prompt( - self, - ) -> str: + def _format_user_prompt(self) -> str: """Format the user prompt. - + Args: None - + Raises: ValueError: If the user prompt template is invalid. - + Returns: str: Formatted user prompt. """ @@ -97,7 +100,7 @@ def _format_user_prompt( value = value.replace('', '') # Creating the complete function definition from YAML content - function_definition = f"def wildcard_formatter(centerpiece, options, correct_options):\n {value}" + function_definition = f"def wildcard_formatter(centerpiece, options, correct_options, context=None):\n {value}" # Create the function dynamically exec_globals = {} @@ -107,35 +110,55 @@ def _format_user_prompt( if key == 'example': for example_data in self.example_data_list[:self.config['few_shot']]: user_prompt += wildcard_formatter( - centerpiece = example_data['centerpiece'], - options = example_data['options'], - correct_options = example_data['correct_options'] + context=example_data['context'], + centerpiece=example_data['centerpiece'], + options=example_data['options'], + correct_options=example_data['correct_options'] + ) if self.context_exists else wildcard_formatter( + centerpiece=example_data['centerpiece'], + options=example_data['options'], + correct_options=example_data['correct_options'] ) else: user_prompt += wildcard_formatter( - centerpiece = self.centerpiece, - options = self.options, - correct_options = self.correct_options + context=self.context, + centerpiece=self.centerpiece, + options=self.options, + correct_options=self.correct_options + ) if self.context_exists else wildcard_formatter( + centerpiece=self.centerpiece, + options=self.options, + correct_options=self.correct_options ) # if no wildcard, format normally else: if key == 'example': for example_data in self.example_data_list[:self.config['few_shot']]: - #print(self.example_data_list) user_prompt += value.format( - centerpiece = example_data['centerpiece'], - options = example_data['options'], - correct_options = example_data['correct_options'] - ) + context=example_data['context'], + centerpiece=example_data['centerpiece'], + options=example_data['options'], + correct_options=example_data['correct_options'] + ) if self.context_exists else value.format( + centerpiece=example_data['centerpiece'], + options=example_data['options'], + correct_options=example_data['correct_options'] + ) else: user_prompt += value.format( - centerpiece = self.centerpiece, - options = self.options, - correct_options = self.correct_options + context=self.context, + centerpiece=self.centerpiece, + options=self.options, + correct_options=self.correct_options + ) if self.context_exists else value.format( + centerpiece=self.centerpiece, + options=self.options, + correct_options=self.correct_options ) # if user prompt is empty, raise error if not user_prompt: - ValueError("Invalid user prompt template") + raise ValueError("Invalid user prompt template") return user_prompt + diff --git a/nutcracker/data/instance_collection.py b/nutcracker/data/instance_collection.py new file mode 100644 index 0000000..44f59a6 --- /dev/null +++ b/nutcracker/data/instance_collection.py @@ -0,0 +1,66 @@ +from typing import Optional +import random +# +class InstanceCollection: + def __init__(self) -> None: + """Initialize an empty collection of instances.""" + self.instances = [] + + + + def __len__(self) -> int: + """Return the number of instances in the collection. + + Returns: + int: The total number of instances in the collection. + """ + return len(self.instances) + + + + def __getitem__(self, index): + """Retrieve specific instance(s) by index or slice. + + Args: + index (int or slice): The index of the instance to retrieve or a slice object to get a range of instances. + + Raises: + IndexError: If the index is out of range. + TypeError: If the provided index is not an int or slice. + + Returns: + Instance or list of Instances: A single instance or a list of instances based on the provided index. + """ + if isinstance(index, int): + if index >= len(self.instances) or index < 0: + raise IndexError("Index out of range") + return self.instances[index] + elif isinstance(index, slice): + return self.instances[index] + else: + raise TypeError("Invalid argument type") + + + + def sample( + self, + n: int, + seed: Optional[int] = None + ) -> list: + """Randomly sample 'n' instances from the collection. + + Args: + n (int): The number of instances to sample. + seed (Optional[int]): Optional random seed for reproducibility. + + Raises: + ValueError: If 'n' is greater than the total number of instances. + + Returns: + list: A list of 'n' randomly sampled instances. + """ + if n > len(self.instances): + raise ValueError("Sample size 'n' cannot be greater than the total number of instances.") + if seed is not None: + random.seed(seed) + return random.sample(self.instances, n) diff --git a/nutcracker/data/pile.py b/nutcracker/data/pile.py index 2d71ecf..a8661f0 100644 --- a/nutcracker/data/pile.py +++ b/nutcracker/data/pile.py @@ -3,9 +3,10 @@ # from nutcracker.data.instance import Instance from nutcracker.data.task import Task +from nutcracker.data.instance_collection import InstanceCollection # # -class Pile: +class Pile (InstanceCollection): def __init__( self, tasks: List[Task], @@ -77,47 +78,6 @@ def add_task( - def __len__(self) -> int: - """Return the number of Instances in the Pile. - - Args: - None - - Raises: - None - - Returns: - int: Number of Instances in the Pile. - """ - return len(self.instances) - - - - def __getitem__(self, index): - """Return the Instance or a slice of Instances at the specified index. - - Args: - index (int or slice): Index or slice of the Instance. - - Raises: - IndexError: If the index is out of range. - - Returns: - Instance or List[Instance]: The Instance or list of Instances at the given index. - """ - if isinstance(index, int): - # Handling single index - if index >= len(self.instances) or index < 0: - raise IndexError("Index out of range") - return self.instances[index] - elif isinstance(index, slice): - # Handling slice object - return self.instances[index] - else: - raise TypeError("Invalid argument type") - - - def _ensure_consistent_construction(self) -> None: """ Ensure all instances have the same construction type. diff --git a/nutcracker/data/task.py b/nutcracker/data/task.py index b8eb85e..5607242 100644 --- a/nutcracker/data/task.py +++ b/nutcracker/data/task.py @@ -4,6 +4,7 @@ import logging # from nutcracker.data.instance import Instance +from nutcracker.data.instance_collection import InstanceCollection # # from huggingface_hub import hf_hub_download @@ -11,7 +12,7 @@ # # # -class Task: +class Task (InstanceCollection): def __init__( self, test_path: str, @@ -155,47 +156,6 @@ def load_from_id(cls, task_id: str, user_given_directory: str): # Initialize and return the Task object return cls(test_path=test_path, example_path=example_path, config_path=config_path) - - - - def __len__(self) -> int: - """Return the number of instances in the task. - - Args: - None - - Raises: - None - - Returns: - int: Number of test instances in the task. - """ - return len(self.instances) - - - - def __getitem__(self, index): - """Return the Instance or a slice of Instances at the specified index. - - Args: - index (int or slice): Index or slice of the Instance. - - Raises: - IndexError: If the index is out of range. - - Returns: - Instance or List[Instance]: The Instance or list of Instances at the given index. - """ - if isinstance(index, int): - # Handling single index - if index >= len(self.instances) or index < 0: - raise IndexError("Index out of range") - return self.instances[index] - elif isinstance(index, slice): - # Handling slice object - return self.instances[index] - else: - raise TypeError("Invalid argument type") diff --git a/setup.py b/setup.py index f0d9621..cc08dae 100644 --- a/setup.py +++ b/setup.py @@ -5,7 +5,7 @@ # pip install -e . setup( name = 'nutcracker', - version='0.0.1a2', + version='0.0.1a3', description = 'In Development', author = 'Bruce W. Lee', author_email = 'bruce@walnutresearch.com',