Commit
Add formatting function alpaca (#161)
* utility functions to format datasets using template
  Signed-off-by: Sukriti-Sharma4 <[email protected]>
* add tests and formatter as arg
  Signed-off-by: Sukriti-Sharma4 <[email protected]>
* update tests to use template to avoid warnings
  Signed-off-by: Sukriti-Sharma4 <[email protected]>
* update README and tests
  Signed-off-by: Sukriti-Sharma4 <[email protected]>
* fix:formatter
  Signed-off-by: Sukriti-Sharma4 <[email protected]>
* Update README.md
  Signed-off-by: Sukriti Sharma <[email protected]>
* fix imports
  Signed-off-by: Sukriti-Sharma4 <[email protected]>
* fix pylint
  Signed-off-by: Sukriti-Sharma4 <[email protected]>
* fix tests
  Signed-off-by: Sukriti-Sharma4 <[email protected]>
* address review comments- function names
  Signed-off-by: Sukriti-Sharma4 <[email protected]>
* formatting fix
  Signed-off-by: Sukriti-Sharma4 <[email protected]>
* update error message
  Signed-off-by: Sukriti-Sharma4 <[email protected]>
* restrict JSON fields templates
  Signed-off-by: Sukriti-Sharma4 <[email protected]>
---------
Signed-off-by: Sukriti-Sharma4 <[email protected]>
Signed-off-by: Sukriti Sharma <[email protected]>
Showing 8 changed files with 350 additions and 10 deletions.
@@ -0,0 +1,12 @@
[
{"Tweet text":"@HMRCcustomers No this is my first job","ID":0,"Label":2,"text_label":"no complaint","output":"### Text: @HMRCcustomers No this is my first job\n\n### Label: no complaint"},
{"Tweet text":"@KristaMariePark Thank you for your interest! If you decide to cancel, you can call Customer Care at 1-800-NYTIMES.","ID":1,"Label":2,"text_label":"no complaint","output":"### Text: @KristaMariePark Thank you for your interest! If you decide to cancel, you can call Customer Care at 1-800-NYTIMES.\n\n### Label: no complaint"},
{"Tweet text":"If I can't get my 3rd pair of @beatsbydre powerbeats to work today I'm doneski man. This is a slap in my balls. Your next @Bose @BoseService","ID":2,"Label":1,"text_label":"complaint","output":"### Text: If I can't get my 3rd pair of @beatsbydre powerbeats to work today I'm doneski man. This is a slap in my balls. Your next @Bose @BoseService\n\n### Label: complaint"},
{"Tweet text":"@EE On Rosneath Arial having good upload and download speeds but terrible latency 200ms. Why is this.","ID":3,"Label":1,"text_label":"complaint","output":"### Text: @EE On Rosneath Arial having good upload and download speeds but terrible latency 200ms. Why is this.\n\n### Label: complaint"},
{"Tweet text":"Couples wallpaper, so cute. :) #BrothersAtHome","ID":4,"Label":2,"text_label":"no complaint","output":"### Text: Couples wallpaper, so cute. :) #BrothersAtHome\n\n### Label: no complaint"},
{"Tweet text":"@mckelldogs This might just be me, but-- eyedrops? Artificial tears are so useful when you're sleep-deprived and sp\u2026 https:\/\/t.co\/WRtNsokblG","ID":5,"Label":2,"text_label":"no complaint","output":"### Text: @mckelldogs This might just be me, but-- eyedrops? Artificial tears are so useful when you're sleep-deprived and sp\u2026 https:\/\/t.co\/WRtNsokblG\n\n### Label: no complaint"},
{"Tweet text":"@Yelp can we get the exact calculations for a business rating (for example if its 4 stars but actually 4.2) or do we use a 3rd party site?","ID":6,"Label":2,"text_label":"no complaint","output":"### Text: @Yelp can we get the exact calculations for a business rating (for example if its 4 stars but actually 4.2) or do we use a 3rd party site?\n\n### Label: no complaint"},
{"Tweet text":"@nationalgridus I have no water and the bill is current and paid. Can you do something about this?","ID":7,"Label":1,"text_label":"complaint","output":"### Text: @nationalgridus I have no water and the bill is current and paid. Can you do something about this?\n\n### Label: complaint"},
{"Tweet text":"Never shopping at @MACcosmetics again. Every time I go in there, their employees are super rude\/condescending. I'll take my $$ to @Sephora","ID":8,"Label":1,"text_label":"complaint","output":"### Text: Never shopping at @MACcosmetics again. Every time I go in there, their employees are super rude\/condescending. I'll take my $$ to @Sephora\n\n### Label: complaint"},
{"Tweet text":"@JenniferTilly Merry Christmas to as well. You get more stunning every year \ufffd\ufffd","ID":9,"Label":2,"text_label":"no complaint","output":"### Text: @JenniferTilly Merry Christmas to as well. You get more stunning every year \ufffd\ufffd\n\n### Label: no complaint"}
]
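In the records above, the `output` field appears to be the `Tweet text` and `text_label` fields joined in a fixed `### Text: ... ### Label: ...` layout. The check below is a small sketch of that relationship using the first record copied from the data; the variable names are illustrative and do not come from the repository.

```python
import json

# First record from the data file above, embedded so the example is self-contained.
record = json.loads(
    '{"Tweet text":"@HMRCcustomers No this is my first job","ID":0,"Label":2,'
    '"text_label":"no complaint",'
    '"output":"### Text: @HMRCcustomers No this is my first job\\n\\n### Label: no complaint"}'
)

# Rebuild the pre-formatted field from the two raw fields.
rebuilt = f"### Text: {record['Tweet text']}\n\n### Label: {record['text_label']}"

# True when the stored "output" follows the assumed layout.
matches = rebuilt == record["output"]
```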
@@ -0,0 +1,66 @@
# Copyright The FMS HF Tuning Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# SPDX-License-Identifier: Apache-2.0
# https://spdx.dev/learn/handling-license-info/

# Third Party
import datasets
import pytest

# First Party
from tests.data import TWITTER_COMPLAINTS_DATA

# Local
from tuning.utils import data_utils


def test_apply_custom_formatting_template():
    json_dataset = datasets.load_dataset("json", data_files=TWITTER_COMPLAINTS_DATA)
    template = "### Input: {{Tweet text}} \n\n ### Response: {{text_label}}"
    # First response from the data file that is read.
    expected_response = (
        "### Input: @HMRCcustomers No this is my first job"
        + " \n\n ### Response: no complaint"
    )
    formatted_dataset, dataset_text_field = data_utils.apply_custom_formatting_template(
        json_dataset, template
    )
    # a new dataset_text_field is created in Dataset
    assert dataset_text_field in formatted_dataset["train"][0]
    assert formatted_dataset["train"][0][dataset_text_field] == expected_response


def test_apply_custom_formatting_template_adds_eos_token():
    json_dataset = datasets.load_dataset("json", data_files=TWITTER_COMPLAINTS_DATA)
    template = "### Input: {{Tweet text}} \n\n ### Response: {{text_label}}"
    # First response from the data file that is read.
    expected_response = (
        "### Input: @HMRCcustomers No this is my first job"
        + " \n\n ### Response: no complaintEOS"
    )
    formatted_dataset, dataset_text_field = data_utils.apply_custom_formatting_template(
        json_dataset, template, "EOS"
    )
    # a new dataset_text_field is created in Dataset
    assert dataset_text_field in formatted_dataset["train"][0]
    assert formatted_dataset["train"][0][dataset_text_field] == expected_response


def test_apply_custom_formatting_template_gives_error_with_wrong_keys():
    """Tests that the formatting function will throw error if wrong keys are passed to template"""
    json_dataset = datasets.load_dataset("json", data_files=TWITTER_COMPLAINTS_DATA)
    template = "### Input: {{not found}} \n\n ### Response: {{text_label}}"
    with pytest.raises(KeyError):
        data_utils.apply_custom_formatting_template(json_dataset, template, "EOS")
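The implementation of `data_utils.apply_custom_formatting_template` is not shown in this part of the diff. The tests above pin down three behaviours: `{{field}}` placeholders are filled from each record, an optional EOS token is appended, and an unknown field raises `KeyError`. The sketch below reproduces those behaviours for a single record. `render_template` and `PLACEHOLDER` are hypothetical names, and this is an assumption about how such a formatter could work, not the repository's code.

```python
import re

# Hypothetical helper for illustration; not the repository's implementation.
# Matches {{field name}} and captures the text between the double braces.
PLACEHOLDER = re.compile(r"\{\{([^{}]+)\}\}")


def render_template(record, template, eos_token=""):
    """Fill each {{field}} placeholder from `record`, then append `eos_token`."""

    def lookup(match):
        key = match.group(1).strip()
        # A missing field raises KeyError, the error the third test expects.
        return str(record[key])

    return PLACEHOLDER.sub(lookup, template) + eos_token


record = {
    "Tweet text": "@HMRCcustomers No this is my first job",
    "text_label": "no complaint",
}
template = "### Input: {{Tweet text}} \n\n ### Response: {{text_label}}"
rendered = render_template(record, template, "EOS")
```

With Hugging Face `datasets`, a per-record function like this could be applied through `Dataset.map` to write the rendered string into a new column. The name of that column (the `dataset_text_field` returned in the tests) is not visible here, so it is left out of the sketch.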