forked from kingoflolz/mesh-transformer-jax
-
Notifications
You must be signed in to change notification settings - Fork 1
/
prepare_dataset_alpaca.py
63 lines (52 loc) · 1.92 KB
/
prepare_dataset_alpaca.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
from prepare_dataset import *
def generate_prompt(sample):
if "input" in sample and sample["input"]:
promtp = f"""Nedenfor er en instruksjon som beskriver en oppgave, sammen med et input som gir ytterligere kontekst. Skriv et svar som fullfører forespørselen på riktig måte.
### Instruksjon:
{sample["instruction"]}
### Input:
{sample["input"]}
### Respons:
{sample["output"]}"""
else:
promtp = f"""Nedenfor er en instruksjon som beskriver en oppgave. Skriv et svar som fullfører forespørselen på riktig måte.
### Instruksjon:
{sample["instruction"]}
### Respons:
{sample["output"]}"""
sample["prompt"] = promtp
return sample
def main_alpaca(args):
GPT2TokenizerFast.max_model_input_sizes[
"gpt2"
] = 1e20 # disables a misleading warning
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
epochs = args.n_repack_epochs
seq_length = args.sequence_length
ds = datasets.load_dataset(
args.dataset,
name=args.dataset_config or None,
split=args.dataset_split,
streaming=False,
use_auth_token=True,
)
if not args.preserve_data_order:
ds = ds.shuffle(seed=args.seed)
ds = ds.map(generate_prompt, desc="Generating prompts")
ds = ds.map(lambda x: tokenizer(x["prompt"]), batched=True, desc="Tokenizing", num_proc=16)
ds.set_epoch = ds.shuffle
seqs = tqdm(
split_every(
seq_length,
iter_tokens(
generate_sample(ds, epochs, "input_ids", args.preserve_data_order), tokenizer.eos_token_id
),
),
desc="Writing token ids as TF records",
)
filepath = args.output_dir / f"{args.name}.tfrecords"
seq_count = write_tfrecord(seqs, filepath.as_posix())
filepath_seq = args.output_dir / f"{args.name}_{seq_count}.tfrecords"
os.rename(filepath.as_posix(), filepath_seq.as_posix())
if __name__ == "__main__":
main_alpaca(parse_args())