Skip to content

Commit

Permalink
Add label_key parameter in prompt crafter (#1593)
Browse files Browse the repository at this point in the history
* Initial fix for label_key addition in prompt crafter

* Added label key capability in yaml

* Label key fix in prompt factory

* Renamed the label_key to ground_truth_column_name

* Adding tests for ground truth column

* Update the pc component version
  • Loading branch information
SamGos93 authored Nov 2, 2023
1 parent c82da79 commit 1f907ee
Show file tree
Hide file tree
Showing 6 changed files with 64 additions and 14 deletions.
10 changes: 9 additions & 1 deletion assets/aml-benchmark/components/prompt_crafter/spec.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ display_name: Prompt Crafter
description: This component is used to create prompts from a given dataset. From a
given jinja prompt template, it will generate prompts. It can also create
few-shot prompts given a few-shot dataset and the number of shots.
version: 0.0.2.0
version: 0.0.2.2
is_deterministic: true

inputs:
Expand Down Expand Up @@ -93,6 +93,13 @@ inputs:
optional: true
default: 0
description: Random seed for sampling few-shots; if not specified, 0 is used.
ground_truth_column_name:
type: string
optional: true
default: ''
description: |
This will be used as the ground truth column if present in the input.
If not present, the output_pattern will be used as the ground truth.
outputs:
output_file:
type: uri_file
Expand All @@ -110,6 +117,7 @@ command: >-
--prompt_pattern '${{inputs.prompt_pattern}}'
--output_pattern '${{inputs.output_pattern}}'
$[[--system_message '${{inputs.system_message}}']]
$[[--ground_truth_column_name '${{inputs.ground_truth_column_name}}']]
$[[--prefix '${{inputs.prefix}}']]
$[[--few_shot_separator '${{inputs.few_shot_separator}}']]
--output_file ${{outputs.output_file}}
13 changes: 12 additions & 1 deletion assets/aml-benchmark/components/src/prompt_crafter/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,11 @@ def parse_args() -> ArgumentParser:
type=str,
required=False,
help="The label map to be used for prompts.")
parser.add_argument(
"--ground_truth_column_name",
type=str,
required=False,
help="The ground truth key in the input data.")
parser.add_argument(
"--system_message",
type=str,
Expand All @@ -93,6 +98,7 @@ def main(
output_pattern: str,
prompt_pattern: str,
output_file: str,
ground_truth_column_name: Optional[str] = None,
few_shot_separator: Optional[str] = None,
prefix: Optional[str] = None,
system_message: Optional[str] = None,
Expand All @@ -106,7 +112,8 @@ def main(
:param n_shots: Number of shots to use for few-shot prompts.
:param output_pattern: Pattern to use for output prompts.
:param prompt_pattern: Pattern to use for prompts.
:param output_data: Path to jsonl with generated prompts.
:param output_file: Path to jsonl with generated prompts.
:param ground_truth_column_name: Ground truth column name.
:param few_shot_separator: Separator to use for few-shot prompts.
:param prefix: Prefix to use for prompts.
:param system_message: System message to use for prompts.
Expand All @@ -125,6 +132,7 @@ def main(
few_shot_separator=few_shot_separator,
prefix=prefix,
output_file=output_file,
ground_truth_column_name=ground_truth_column_name,
output_mltable=None,
metadata_keys=None,
label_map=None,
Expand All @@ -141,6 +149,8 @@ def main(
output_pattern=output_pattern,
system_message=system_message,
random_seed=random_seed,
ground_truth_column_name=ground_truth_column_name
if ground_truth_column_name else None,
test_dataset_checksum=resolve_io_path(test_data),
few_shot_dataset_checksum=resolve_io_path(few_shot_data)
if few_shot_data else None,
Expand All @@ -160,6 +170,7 @@ def main(
prompt_pattern=args.prompt_pattern,
few_shot_separator=args.few_shot_separator,
prefix=args.prefix,
ground_truth_column_name=args.ground_truth_column_name,
output_file=args.output_file,
system_message=args.system_message,
)
Original file line number Diff line number Diff line change
Expand Up @@ -80,10 +80,12 @@ def __init__(
system_message: Optional[str],
base_prompt_factory_cls: Optional[PromptFactory] = PromptFactory,
output_mltable: Optional[str] = None,
ground_truth_column_name: Optional[str] = None,
):
"""Initialize the prompt crafter."""
self.metadata_keys = metadata_keys
self.additional_payload = additional_payload
self.ground_truth_column_name = ground_truth_column_name
params = {k: v for k, v in locals().items() if k not in ["self", "base_prompt_factory_cls", "params"]}
self.mlflow_logger = _MLFlowLogger()
self.mlflow_logger.save_parameters(params=params, output_mltable=output_mltable)
Expand Down Expand Up @@ -116,6 +118,7 @@ def __init__(
few_shot_pool=few_shot_pool,
few_shot_separator=few_shot_separator,
prefix=prefix,
ground_truth_column_name=ground_truth_column_name,
label_map_str=label_map,
output_pattern=output_pattern,
system_message=system_message,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ class PromptFactory(ABC):
few_shot_pattern: Optional[str] = None
few_shot_separator: Optional[str] = None
prefix: Optional[str] = None
ground_truth_column_name: Optional[str] = None
label_map_str: Optional[str] = None
output_pattern: Optional[str] = None
system_message: Optional[str] = None
Expand Down Expand Up @@ -190,6 +191,14 @@ def process_row(self, row: Dict, index: int) -> Dict:
if self.output_pattern is not None:
output_data['completion'] = self.get_label_from_output_pattern(row)

if self.ground_truth_column_name is not None and len(self.ground_truth_column_name) > 0:
if self.ground_truth_column_name in row:
output_data['ground_truth'] = row[self.ground_truth_column_name]
else:
mssg = "Ground truth column is not present in the data"
raise BenchmarkValidationException._with_error(
AzureMLError.create(BenchmarkValidationError, error_details=mssg))

if self.metadata_keys is not None:
def collect_metadata(metadata_keys, data, index):
metadata = data.get("metadata", {})
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ inputs:
output_pattern: " {{answerKey}} "
prompt_type: "completions"
system_message: "Answer truthfully. "
ground_truth_column_name: "answerKey"

outputs:
output_file:
Expand Down
42 changes: 30 additions & 12 deletions assets/aml-benchmark/tests/test_prompt_crafter.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
def _verify_and_get_output_records(
input_file: str,
expected_output_file: str,
is_ground_truth_col: bool = False,
) -> None:
"""Verify the output and get output records.
Expand All @@ -44,14 +45,17 @@ def _verify_and_get_output_records(
with open(expected_output_file, "r") as f:
expected_output_records = [json.loads(line) for line in f]
expected_output_row_count = len(expected_output_records)

if is_ground_truth_col:
for line in f:
assert 'ground_truth' in json.loads(line).keys()
assert input_row_count == expected_output_row_count
return


# test patterns
_prompt_pattern_test = "Question:{{question}} \nChoices:(1) {{choices.text[0]}}\n(2) {{choices.text[1]}}\n(3) {{choices.text[2]}}\n(4) {{choices.text[3]}}\nThe answer is: " # noqa: E501
_output_pattern_test = "{{answerKey}}"
_ground_truth_column = "answerKey"


class TestPromptCrafterComponent:
Expand Down Expand Up @@ -187,24 +191,24 @@ class TestPromptCrafterScript:

@pytest.mark.parametrize(
"test_data, prompt_type, n_shots, \
few_shot_data, prompt_pattern, output_pattern",
few_shot_data, prompt_pattern, output_pattern, ground_truth_column",
[
(Constants.PROMPTCRAFTER_SAMPLE_INPUT_FILE, "completions", 1,
Constants.PROMPTCRAFTER_SAMPLE_FEWSHOT_FILE, _prompt_pattern_test,
_output_pattern_test),
_output_pattern_test, _ground_truth_column),
(Constants.PROMPTCRAFTER_SAMPLE_INPUT_FILE, "chat", 1,
Constants.PROMPTCRAFTER_SAMPLE_FEWSHOT_FILE, _prompt_pattern_test,
_output_pattern_test),
_output_pattern_test, _ground_truth_column),
(Constants.PROMPTCRAFTER_SAMPLE_INPUT_FILE, "completions", 0,
Constants.PROMPTCRAFTER_SAMPLE_FEWSHOT_FILE, _prompt_pattern_test,
_output_pattern_test),
_output_pattern_test, _ground_truth_column),
(Constants.PROMPTCRAFTER_SAMPLE_INPUT_FILE, "chat", 0,
Constants.PROMPTCRAFTER_SAMPLE_FEWSHOT_FILE, _prompt_pattern_test,
_output_pattern_test),
_output_pattern_test, _ground_truth_column),
(Constants.PROMPTCRAFTER_SAMPLE_INPUT_FILE, "completions", 0,
None, _prompt_pattern_test, _output_pattern_test),
None, _prompt_pattern_test, _output_pattern_test, _ground_truth_column),
(Constants.PROMPTCRAFTER_SAMPLE_INPUT_FILE, "chat", 0,
None, _prompt_pattern_test, _output_pattern_test),
None, _prompt_pattern_test, _output_pattern_test, _ground_truth_column),
]
)
def test_valid_prompt_crafter(
Expand All @@ -215,6 +219,7 @@ def test_valid_prompt_crafter(
few_shot_data: str,
prompt_pattern: str,
output_pattern: str,
ground_truth_column: str,
output_file: str = os.path.join(
os.path.dirname(__file__), 'data/prompt_crafter_output.jsonl'),
):
Expand All @@ -228,21 +233,29 @@ def test_valid_prompt_crafter(
output_file=output_file,
output_pattern=output_pattern,
few_shot_data=few_shot_data,
ground_truth_column=ground_truth_column,
)

# Verify the output file(s)
_verify_and_get_output_records(
test_data,
output_file,
is_ground_truth_col=True,
)

_error_mssg_few_shot_data_shortage = "Unable to find 10 few shots after 100 retries"
_error_mssg_ground_truth_column_not_found = "Ground truth column is not present in the data"

@pytest.mark.parametrize(
"test_data, prompt_type, n_shots, \
few_shot_data, prompt_pattern, output_pattern",
few_shot_data, prompt_pattern, output_pattern, ground_truth_column, expected_error_message",
[
(Constants.PROMPTCRAFTER_SAMPLE_INPUT_FILE, "completions", 10,
Constants.PROMPTCRAFTER_SAMPLE_FEWSHOT_FILE,
_prompt_pattern_test, _output_pattern_test),
Constants.PROMPTCRAFTER_SAMPLE_FEWSHOT_FILE, _prompt_pattern_test,
_output_pattern_test, _ground_truth_column, _error_mssg_few_shot_data_shortage),
(Constants.PROMPTCRAFTER_SAMPLE_INPUT_FILE, "completions", 1,
Constants.PROMPTCRAFTER_SAMPLE_FEWSHOT_FILE, _prompt_pattern_test,
_output_pattern_test, "invalid_column", _error_mssg_ground_truth_column_not_found),
]
)
def test_invalid_prompt_crafter(
Expand All @@ -253,11 +266,12 @@ def test_invalid_prompt_crafter(
few_shot_data: str,
prompt_pattern: str,
output_pattern: str,
ground_truth_column: str,
expected_error_message: str,
output_file: str = os.path.join(
os.path.dirname(__file__), 'data/prompt_crafter_output.jsonl'),
):
"""Test for valid input dataset."""
expected_error_message = "Unable to find 10 few shots after 100 retries"
try:
# Run the script
self._run_prompt_crafter_script(
Expand All @@ -268,6 +282,7 @@ def test_invalid_prompt_crafter(
output_file=output_file,
output_pattern=output_pattern,
few_shot_data=few_shot_data,
ground_truth_column=ground_truth_column,
)
except subprocess.CalledProcessError as e:
exception_message = e.output.strip()
Expand All @@ -280,6 +295,7 @@ def _run_prompt_crafter_script(
n_shots: int,
prompt_pattern: str,
output_file: str,
ground_truth_column: Optional[str] = None,
few_shot_separator: Optional[str] = None,
prefix: Optional[str] = None,
system_message: Optional[str] = None,
Expand Down Expand Up @@ -329,5 +345,7 @@ def _run_prompt_crafter_script(
args.extend(["--output_pattern", f'"{output_pattern}"'])
if few_shot_data is not None:
args.extend(["--few_shot_data", f'"{few_shot_data}"'])
if ground_truth_column is not None:
args.extend(["--ground_truth_column_name", f"{ground_truth_column}"])

run_command(str(" ".join(args)))

0 comments on commit 1f907ee

Please sign in to comment.