forked from meta-llama/llama-recipes
-
Notifications
You must be signed in to change notification settings - Fork 0
/
samsum_dataset.py
52 lines (38 loc) · 1.81 KB
/
samsum_dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
# For dataset details visit: https://huggingface.co/datasets/samsum
import copy
import datasets
from unittest.mock import patch
@patch('builtins.input', return_value="N")
def load_samsum(split, _):
    """Load the requested split of the Samsung/samsum dataset.

    ``builtins.input`` is patched to always answer "N" so that any
    interactive trust-remote-code prompt from the ``datasets`` library can
    never block a non-interactive run; the injected mock arrives in ``_``.
    Callers invoke this as ``load_samsum(split)``.

    Raises:
        ValueError: re-raised with setup instructions when loading fails
            because remote code execution was not trusted.
    """
    try:
        return datasets.load_dataset(
            "Samsung/samsum", split=split, trust_remote_code=True
        )
    except ValueError as err:
        # Only the trust-remote-code failure gets the friendlier message;
        # anything else propagates untouched.
        if "trust_remote_code" not in str(err):
            raise
        raise ValueError("Loading Samsung/samsum requires you to execute the dataset script in that repo on your local machine. Make sure you have read the code there to avoid malicious use, then set HF_DATASETS_TRUST_REMOTE_CODE env variable to True.") from err
def get_preprocessed_samsum(dataset_config, tokenizer, split):
    """Return the samsum split preprocessed for causal-LM fine-tuning.

    Each example becomes ``input_ids`` = BOS + prompt tokens + summary
    tokens + EOS, with ``labels`` masking the prompt portion with -100 so
    the loss is computed only over the summary, and a matching all-ones
    ``attention_mask``.

    Args:
        dataset_config: unused here; kept for the shared dataset-loader
            interface (all loaders take config/tokenizer/split).
        tokenizer: HF tokenizer providing ``encode``, ``bos_token`` and
            ``eos_token``.
        split: dataset split name passed through to ``load_samsum``.
    """
    dataset = load_samsum(split)

    # Plain literal, not an f-string: the original used a placeholder-free
    # f-string that forced the braces to be escaped (``{{dialog}}``, ruff
    # F541). The value is identical.
    prompt = "Summarize this dialog:\n{dialog}\n---\nSummary:\n"

    def apply_prompt_template(sample):
        # Map a raw record to a (prompt, summary) text pair.
        return {
            "prompt": prompt.format(dialog=sample["dialogue"]),
            "summary": sample["summary"],
        }

    dataset = dataset.map(apply_prompt_template, remove_columns=list(dataset.features))

    def tokenize_add_label(sample):
        # Renamed from ``prompt``/``summary`` so the token-id lists no
        # longer shadow the enclosing ``prompt`` template string.
        prompt_ids = tokenizer.encode(tokenizer.bos_token + sample["prompt"], add_special_tokens=False)
        summary_ids = tokenizer.encode(sample["summary"] + tokenizer.eos_token, add_special_tokens=False)
        return {
            "input_ids": prompt_ids + summary_ids,
            "attention_mask": [1] * (len(prompt_ids) + len(summary_ids)),
            # -100 is ignored by PyTorch cross-entropy: train on summary only.
            "labels": [-100] * len(prompt_ids) + summary_ids,
        }

    dataset = dataset.map(tokenize_add_label, remove_columns=list(dataset.features))
    return dataset