Skip to content

Commit

Permalink
added context key to MCQ
Browse files Browse the repository at this point in the history
  • Loading branch information
brucewlee committed Feb 7, 2024
1 parent 15fbd9b commit e018146
Show file tree
Hide file tree
Showing 6 changed files with 121 additions and 107 deletions.
5 changes: 5 additions & 0 deletions nutcracker/data/data_config/task/socialiqa.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
filename:
test: test.json
dev: dev.json
config: config.yaml
task_name: socialiqa
67 changes: 45 additions & 22 deletions nutcracker/data/instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,11 @@ def __init__(
self.options = test_data["options"]
self.correct_options = test_data["correct_options"]

# Check if 'context' key exists in test_data
self.context_exists = 'context' in test_data and test_data['context']
if self.context_exists:
self.context = test_data["context"]

# below are derivational attributes that will be updated during code run
self.user_prompt = self._format_user_prompt()
self.model_response = None
Expand All @@ -74,17 +79,15 @@ def __init__(



def _format_user_prompt(
self,
) -> str:
def _format_user_prompt(self) -> str:
"""Format the user prompt.
Args:
None
Raises:
ValueError: If the user prompt template is invalid.
Returns:
str: Formatted user prompt.
"""
Expand All @@ -97,7 +100,7 @@ def _format_user_prompt(
value = value.replace('<wild*card>', '')

# Creating the complete function definition from YAML content
function_definition = f"def wildcard_formatter(centerpiece, options, correct_options):\n {value}"
function_definition = f"def wildcard_formatter(centerpiece, options, correct_options, context=None):\n {value}"

# Create the function dynamically
exec_globals = {}
Expand All @@ -107,35 +110,55 @@ def _format_user_prompt(
if key == 'example':
for example_data in self.example_data_list[:self.config['few_shot']]:
user_prompt += wildcard_formatter(
centerpiece = example_data['centerpiece'],
options = example_data['options'],
correct_options = example_data['correct_options']
context=example_data['context'],
centerpiece=example_data['centerpiece'],
options=example_data['options'],
correct_options=example_data['correct_options']
) if self.context_exists else wildcard_formatter(
centerpiece=example_data['centerpiece'],
options=example_data['options'],
correct_options=example_data['correct_options']
)
else:
user_prompt += wildcard_formatter(
centerpiece = self.centerpiece,
options = self.options,
correct_options = self.correct_options
context=self.context,
centerpiece=self.centerpiece,
options=self.options,
correct_options=self.correct_options
) if self.context_exists else wildcard_formatter(
centerpiece=self.centerpiece,
options=self.options,
correct_options=self.correct_options
)
# if no wildcard, format normally
else:
if key == 'example':
for example_data in self.example_data_list[:self.config['few_shot']]:
#print(self.example_data_list)
user_prompt += value.format(
centerpiece = example_data['centerpiece'],
options = example_data['options'],
correct_options = example_data['correct_options']
)
context=example_data['context'],
centerpiece=example_data['centerpiece'],
options=example_data['options'],
correct_options=example_data['correct_options']
) if self.context_exists else value.format(
centerpiece=example_data['centerpiece'],
options=example_data['options'],
correct_options=example_data['correct_options']
)
else:
user_prompt += value.format(
centerpiece = self.centerpiece,
options = self.options,
correct_options = self.correct_options
context=self.context,
centerpiece=self.centerpiece,
options=self.options,
correct_options=self.correct_options
) if self.context_exists else value.format(
centerpiece=self.centerpiece,
options=self.options,
correct_options=self.correct_options
)

# if user prompt is empty, raise error
if not user_prompt:
ValueError("Invalid user prompt template")
raise ValueError("Invalid user prompt template")

return user_prompt

66 changes: 66 additions & 0 deletions nutcracker/data/instance_collection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
from typing import Optional
import random
#
class InstanceCollection:
    """List-like container exposing length, indexing and random sampling.

    All operations act on ``self.instances``, a plain list that starts empty
    and is expected to be filled by whoever owns the collection.
    """

    def __init__(self) -> None:
        """Initialize an empty collection of instances."""
        self.instances = []

    def __len__(self) -> int:
        """Return the number of instances in the collection.

        Returns:
            int: The total number of instances in the collection.
        """
        return len(self.instances)

    def __getitem__(self, index):
        """Retrieve specific instance(s) by index or slice.

        Args:
            index (int or slice): The index of the instance to retrieve or a
                slice object to get a range of instances.

        Raises:
            IndexError: If an integer index is negative or past the end.
            TypeError: If the provided index is not an int or slice.

        Returns:
            Instance or list of Instances: A single instance or a list of
            instances based on the provided index.
        """
        if isinstance(index, slice):
            return self.instances[index]
        if not isinstance(index, int):
            raise TypeError("Invalid argument type")
        # Negative integer indices are rejected rather than counted from the
        # end; only 0 <= index < len(self) is accepted.
        if not 0 <= index < len(self.instances):
            raise IndexError("Index out of range")
        return self.instances[index]

    def sample(
        self,
        n: int,
        seed: Optional[int] = None
    ) -> list:
        """Randomly sample 'n' instances from the collection.

        Args:
            n (int): The number of instances to sample.
            seed (Optional[int]): Optional random seed for reproducibility.
                When given, a private generator is used so the module-level
                ``random`` state is left untouched.

        Raises:
            ValueError: If 'n' is greater than the total number of instances.

        Returns:
            list: A list of 'n' randomly sampled instances.
        """
        if n > len(self.instances):
            raise ValueError("Sample size 'n' cannot be greater than the total number of instances.")
        # A dedicated Random(seed) yields the same draw as seeding the global
        # generator would, without reseeding it for the rest of the program.
        rng = random if seed is None else random.Random(seed)
        return rng.sample(self.instances, n)
44 changes: 2 additions & 42 deletions nutcracker/data/pile.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@
#
from nutcracker.data.instance import Instance
from nutcracker.data.task import Task
from nutcracker.data.instance_collection import InstanceCollection
#
#
class Pile:
class Pile (InstanceCollection):
def __init__(
self,
tasks: List[Task],
Expand Down Expand Up @@ -77,47 +78,6 @@ def add_task(



def __len__(self) -> int:
"""Return the number of Instances in the Pile.
Args:
None
Raises:
None
Returns:
int: Number of Instances in the Pile.
"""
return len(self.instances)



def __getitem__(self, index):
"""Return the Instance or a slice of Instances at the specified index.
Args:
index (int or slice): Index or slice of the Instance.
Raises:
IndexError: If the index is out of range.
Returns:
Instance or List[Instance]: The Instance or list of Instances at the given index.
"""
if isinstance(index, int):
# Handling single index
if index >= len(self.instances) or index < 0:
raise IndexError("Index out of range")
return self.instances[index]
elif isinstance(index, slice):
# Handling slice object
return self.instances[index]
else:
raise TypeError("Invalid argument type")



def _ensure_consistent_construction(self) -> None:
"""
Ensure all instances have the same construction type.
Expand Down
44 changes: 2 additions & 42 deletions nutcracker/data/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,15 @@
import logging
#
from nutcracker.data.instance import Instance
from nutcracker.data.instance_collection import InstanceCollection
#
#
from huggingface_hub import hf_hub_download
import yaml
#
#
#
class Task:
class Task (InstanceCollection):
def __init__(
self,
test_path: str,
Expand Down Expand Up @@ -155,47 +156,6 @@ def load_from_id(cls, task_id: str, user_given_directory: str):

# Initialize and return the Task object
return cls(test_path=test_path, example_path=example_path, config_path=config_path)



def __len__(self) -> int:
"""Return the number of instances in the task.
Args:
None
Raises:
None
Returns:
int: Number of test instances in the task.
"""
return len(self.instances)



def __getitem__(self, index):
"""Return the Instance or a slice of Instances at the specified index.
Args:
index (int or slice): Index or slice of the Instance.
Raises:
IndexError: If the index is out of range.
Returns:
Instance or List[Instance]: The Instance or list of Instances at the given index.
"""
if isinstance(index, int):
# Handling single index
if index >= len(self.instances) or index < 0:
raise IndexError("Index out of range")
return self.instances[index]
elif isinstance(index, slice):
# Handling slice object
return self.instances[index]
else:
raise TypeError("Invalid argument type")



Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
# pip install -e .
setup(
name = 'nutcracker',
version='0.0.1a2',
version='0.0.1a3',
description = 'In Development',
author = 'Bruce W. Lee',
author_email = '[email protected]',
Expand Down

0 comments on commit e018146

Please sign in to comment.