From 62b889edc73c949bf0e6d2633a1cf28ab39bdb37 Mon Sep 17 00:00:00 2001 From: Kannan R Date: Mon, 23 Oct 2023 22:42:28 +0530 Subject: [PATCH] WIP script to fine-tune data for ui-gen --- requirements.txt | 1 + scripts/prepare_ui_gen_data.py | 142 +++++++++++++++++++++++++++++++++ 2 files changed, 143 insertions(+) create mode 100644 scripts/prepare_ui_gen_data.py diff --git a/requirements.txt b/requirements.txt index f76378a..493eb51 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,4 @@ torch>=2.1.0 lightning @ git+https://github.com/Lightning-AI/lightning@71aed751f7f0ca8422ddca256e602099070f490b jsonargparse[signatures] # CLI +jinja2 diff --git a/scripts/prepare_ui_gen_data.py b/scripts/prepare_ui_gen_data.py new file mode 100644 index 0000000..0ac0f51 --- /dev/null +++ b/scripts/prepare_ui_gen_data.py @@ -0,0 +1,142 @@ +import json +import logging +import sys +from pathlib import Path +from typing import Optional + +import torch +from torch.utils.data import random_split +from tqdm import tqdm + +# support running without installing as a package +wd = Path(__file__).parent.parent.resolve() +logger = logging.getLogger(__name__) +sys.path.append(str(wd)) + +from lit_gpt.tokenizer import Tokenizer +from jinja2 import Template + +COLUMNS = ("instruction", "input", "output") + +# The template is specific for ui-gen data, be conscious when using for other data. + +TEMPLATE = Template(''' +<|SYSTEM|>The response MUST be a valid JSON. Generate UI-DSL for the below input and context.<|END_SYSTEM|> +{%- if context -%}<|CONTEXT|>{{ context }}<|END_CONTEXT|>{%- endif -%} +<|INPUT|>{{ prompt }}<|END_INPUT|> +<|OUTPUT|>```{{ response }}```<|END_OUTPUT|> +''') + +def prepare( + csv_path: Path, + destination_path: Path = Path("data/csv"), + checkpoint_dir: Path = Path("checkpoints/stabilityai/stablelm-base-alpha-3b"), + test_split_fraction: float = 0.1, + seed: int = 42, + mask_inputs: bool = False, + ignore_index: int = -1, + max_seq_length: Optional[int] = None, +) -> None: + """Prepare a CSV dataset for instruction tuning. + + The output is a training and test dataset saved as `train.pt` and `test.pt`, + which stores the preprocessed and tokenized prompts and labels. + """ + if max_seq_length is None: + with open(checkpoint_dir / "lit_config.json", "r") as file: + config = json.load(file) + max_seq_length = config["block_size"] + + destination_path.mkdir(parents=True, exist_ok=True) + logger.info("Loading data file ...") + import pandas as pd + + df = pd.read_csv(csv_path, dtype=str).fillna("") + if not (df.columns.values == COLUMNS).all(): + raise ValueError(f"CSV columns must be {COLUMNS}, found {df.columns.values}") + data = json.loads(df.to_json(orient="records", indent=4)) + + print("Loading tokenizer...") + tokenizer = Tokenizer(checkpoint_dir) + + # Partition the dataset into train and test + train_set, test_set = random_split( + data, [1.0 - test_split_fraction, test_split_fraction], generator=torch.Generator().manual_seed(seed) + ) + train_set, test_set = list(train_set), list(test_set) + + print(f"train has {len(train_set):,} samples") + print(f"test has {len(test_set):,} samples") + + print("Processing train split ...") + train_set = [ + prepare_sample( + example=sample, + tokenizer=tokenizer, + max_length=max_seq_length, + mask_inputs=mask_inputs, + ignore_index=ignore_index, + ) + for sample in tqdm(train_set) + ] + torch.save(train_set, destination_path / "train.pt") + + print("Processing test split ...") + test_set = [ + prepare_sample( + example=sample, + tokenizer=tokenizer, + max_length=max_seq_length, + mask_inputs=mask_inputs, + ignore_index=ignore_index, + ) + for sample in tqdm(test_set) + ] + torch.save(test_set, destination_path / "test.pt") + + +def prepare_sample(example: dict, tokenizer: Tokenizer, max_length: int, mask_inputs: bool, ignore_index: int) -> dict: + """Processes a single sample. + + Each sample in the dataset consists of: + - instruction: A string describing the task + - input: A string holding a special input value for the instruction. + This only applies to some samples, and in others this is empty. + - output: The response string + + This function processes this data to produce a prompt text and a label for + supervised training. The prompt text is formed as a single message including both + the instruction and the input. The label/target is the same message but with the + response attached. + + Finally, both the prompt and the label get tokenized. If desired, all tokens + in the label that correspond to the original input prompt get masked out (default). + """ + full_prompt = generate_prompt(example) + full_prompt_and_response = full_prompt + example["output"] + encoded_full_prompt = tokenizer.encode(full_prompt, max_length=max_length) + encoded_full_prompt_and_response = tokenizer.encode(full_prompt_and_response, eos=True, max_length=max_length) + + # The labels are the full prompt with response, but with the prompt masked out + labels = encoded_full_prompt_and_response.clone() + if mask_inputs: + labels[: len(encoded_full_prompt)] = ignore_index + + return { + **example, + "input_ids": encoded_full_prompt_and_response, + "input_ids_no_response": encoded_full_prompt, + "labels": labels, + } + + +def generate_prompt(example: dict) -> str: + """Generates a standardized message to prompt the model with an instruction, optional input and a + 'response' field.""" + return TEMPLATE.render(prompt=example['instruction'], response=example['output'], context=example.get('input', '')) + + +if __name__ == "__main__": + from jsonargparse import CLI + + CLI(prepare, as_positional=False)