Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: DataProcessor v1 #381

Merged
merged 17 commits into from
Dec 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .pylintrc
Original file line number Diff line number Diff line change
Expand Up @@ -638,7 +638,7 @@ callbacks=cb_,
dummy-variables-rgx=_+$|(_[a-zA-Z0-9_]*[a-zA-Z0-9]+?$)|dummy|^ignored_|^unused_

# Argument names that match this expression will be ignored.
ignored-argument-names=_.*|^ignored_|^unused_
ignored-argument-names=_.*|^ignored_|^unused_|kwargs

# Tells whether we should check for unused import in __init__ files.
init-import=no
Expand Down
30 changes: 30 additions & 0 deletions tests/artifacts/predefined_data_configs/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# Copyright The FMS HF Tuning Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Helpful datasets for configuring individual unit tests.
"""
# Standard
import os

### Constants used for data
# Directory containing the predefined data-config YAML files (this package's
# own directory). NOTE: the original wrapped this in a single-argument
# os.path.join, which is a no-op — os.path.dirname alone is sufficient.
PREDEFINED_DATA_CONFIGS = os.path.dirname(__file__)
# Config that applies a custom formatting template to a dataset.
APPLY_CUSTOM_TEMPLATE_YAML = os.path.join(
    PREDEFINED_DATA_CONFIGS, "apply_custom_template.yaml"
)
# Config for a dataset that is already tokenized (no handlers).
PRETOKENIZE_JSON_DATA_YAML = os.path.join(
    PREDEFINED_DATA_CONFIGS, "pretokenized_json_data.yaml"
)
# Config that tokenizes and masks the input portion of each example.
TOKENIZE_AND_APPLY_INPUT_MASKING_YAML = os.path.join(
    PREDEFINED_DATA_CONFIGS, "tokenize_and_apply_input_masking.yaml"
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# Predefined data config used by unit tests: one dataset processed by the
# apply_custom_data_formatting_template handler.
# NOTE(review): "FILE_PATH" looks like a placeholder the test harness
# substitutes with a real dataset path — confirm against the tests.
dataprocessor:
  type: default
datasets:
  - name: apply_custom_data_template
    data_paths:
      - "FILE_PATH"
    data_handlers:
      - name: apply_custom_data_formatting_template
        arguments:
          remove_columns: all
          batched: false
          fn_kwargs:
            dataset_text_field: "dataset_text_field"
            dataset_template: "dataset_template"
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# Predefined data config used by unit tests: a single pretokenized dataset
# with no data handlers attached.
# NOTE(review): "FILE_PATH" looks like a placeholder the test harness
# substitutes with a real dataset path — confirm against the tests.
dataprocessor:
  type: default
datasets:
  - name: pretokenized_dataset
    data_paths:
      - "FILE_PATH"
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# Predefined data config used by unit tests: one dataset processed by the
# tokenize_and_apply_input_masking handler, reading the named input/output
# fields from each record.
# NOTE(review): "FILE_PATH" looks like a placeholder the test harness
# substitutes with a real dataset path — confirm against the tests.
dataprocessor:
  type: default
datasets:
  - name: text_dataset_input_output_masking
    data_paths:
      - "FILE_PATH"
    data_handlers:
      - name: tokenize_and_apply_input_masking
        arguments:
          remove_columns: all
          batched: false
          fn_kwargs:
            input_field: "INPUT"
            output_field: "OUTPUT"
110 changes: 110 additions & 0 deletions tests/data/test_data_handlers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
# Copyright The FMS HF Tuning Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# SPDX-License-Identifier: Apache-2.0
# https://spdx.dev/learn/handling-license-info/

# Third Party
from transformers import AutoTokenizer
import datasets
import pytest

# First Party
from tests.artifacts.testdata import MODEL_NAME, TWITTER_COMPLAINTS_DATA_JSONL

# Local
from tuning.data.data_handlers import (
apply_custom_data_formatting_template,
combine_sequence,
)


def test_apply_custom_formatting_template():
    """Map the custom-template handler over a JSONL dataset and verify the
    rendered text (with the tokenizer's EOS token appended) is stored in the
    requested new field."""
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    raw_dataset = datasets.load_dataset(
        "json", data_files=TWITTER_COMPLAINTS_DATA_JSONL
    )

    output_field = "formatted_data_field"
    template = "### Input: {{Tweet text}} \n\n ### Response: {{text_label}}"

    rendered_dataset = raw_dataset.map(
        apply_custom_data_formatting_template,
        fn_kwargs={
            "tokenizer": tokenizer,
            "dataset_text_field": output_field,
            "template": template,
        },
    )

    # Expected rendering of the first record in the data file.
    expected_response = (
        "### Input: @HMRCcustomers No this is my first job"
        + " \n\n ### Response: no complaint"
        + tokenizer.eos_token
    )

    first_example = rendered_dataset["train"][0]
    # The handler should create the new text field on every example.
    assert output_field in first_example
    assert first_example[output_field] == expected_response


def test_apply_custom_formatting_template_gives_error_with_wrong_keys():
    """Tests that the formatting function will throw error if wrong keys are passed to template"""
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    raw_dataset = datasets.load_dataset(
        "json", data_files=TWITTER_COMPLAINTS_DATA_JSONL
    )

    # "{{not found}}" does not match any column in the dataset, so rendering
    # the template must fail with a KeyError.
    bad_template = "### Input: {{not found}} \n\n ### Response: {{text_label}}"
    output_field = "formatted_data_field"

    with pytest.raises(KeyError):
        raw_dataset.map(
            apply_custom_data_formatting_template,
            fn_kwargs={
                "tokenizer": tokenizer,
                "dataset_text_field": output_field,
                "template": bad_template,
            },
        )


@pytest.mark.parametrize(
    "input_element,output_element,expected_res",
    [
        ("foo ", "bar", "foo bar"),
        ("foo\n", "bar", "foo\nbar"),
        ("foo\t", "bar", "foo\tbar"),
        ("foo", "bar", "foo bar"),
    ],
)
def test_combine_sequence(input_element, output_element, expected_res):
    """Verify that input and output elements are joined with exactly one
    whitespace character between them (an existing trailing space, newline,
    or tab is preserved; a space is inserted otherwise)."""
    combined = combine_sequence(input_element, output_element)
    assert isinstance(combined, str) and combined == expected_res


@pytest.mark.parametrize(
    "input_element,output_element,expected_res",
    [
        ("foo ", "bar", "foo bar"),
        ("foo\n", "bar", "foo\nbar"),
        ("foo\t", "bar", "foo\tbar"),
        ("foo", "bar", "foo bar"),
    ],
)
def test_combine_sequence_adds_eos(input_element, output_element, expected_res):
    """Ensure the tokenizer's EOS token is appended to the combined sequence
    when passed, with the same whitespace handling as test_combine_sequence."""
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    comb_seq = combine_sequence(input_element, output_element, tokenizer.eos_token)
    # The combined string is the whitespace-joined pair followed by EOS.
    expected_res += tokenizer.eos_token
    assert isinstance(comb_seq, str)
    assert comb_seq == expected_res
Loading
Loading