WIP script to prepare fine-tuning data for ui-gen
kannangce committed Oct 23, 2023
1 parent cf91ad2 commit 62b889e
Showing 2 changed files with 143 additions and 0 deletions.
1 change: 1 addition & 0 deletions requirements.txt
@@ -1,3 +1,4 @@
torch>=2.1.0
lightning @ git+https://github.com/Lightning-AI/lightning@71aed751f7f0ca8422ddca256e602099070f490b
jsonargparse[signatures] # CLI
jinja2
142 changes: 142 additions & 0 deletions scripts/prepare_ui_gen_data.py
@@ -0,0 +1,142 @@
import json
import logging
import sys
from pathlib import Path
from typing import Optional

import torch
from torch.utils.data import random_split
from tqdm import tqdm

# support running without installing as a package
wd = Path(__file__).parent.parent.resolve()
sys.path.append(str(wd))

logger = logging.getLogger(__name__)

from lit_gpt.tokenizer import Tokenizer
from jinja2 import Template

COLUMNS = ("instruction", "input", "output")

# The template is specific to the ui-gen data; be cautious when reusing it for other datasets.

TEMPLATE = Template('''<s>
<|SYSTEM|>The response MUST be a valid JSON. Generate UI-DSL for the below input and context.<|END_SYSTEM|>
{%- if context -%}<|CONTEXT|>{{ context }}<|END_CONTEXT|>{%- endif -%}
<|INPUT|>{{ prompt }}<|END_INPUT|>
<|OUTPUT|>```{{ response }}```<|END_OUTPUT|>
</s>''')
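
# For illustration only (made-up values, not part of the original commit), a rendered row looks
# roughly like the following; the `{%- ... -%}` markers strip the surrounding newlines, so the
# system, context and input blocks end up on a single line:
#   <s>
#   <|SYSTEM|>The response MUST be a valid JSON. Generate UI-DSL for the below input and context.<|END_SYSTEM|><|CONTEXT|>dark theme<|END_CONTEXT|><|INPUT|>Create a login form<|END_INPUT|>
#   <|OUTPUT|>```{"type": "column", ...}```<|END_OUTPUT|>
#   </s>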

def prepare(
    csv_path: Path,
    destination_path: Path = Path("data/csv"),
    checkpoint_dir: Path = Path("checkpoints/stabilityai/stablelm-base-alpha-3b"),
    test_split_fraction: float = 0.1,
    seed: int = 42,
    mask_inputs: bool = False,
    ignore_index: int = -1,
    max_seq_length: Optional[int] = None,
) -> None:
    """Prepare a CSV dataset for instruction tuning.

    The output is a training and a test dataset saved as `train.pt` and `test.pt`,
    which store the preprocessed and tokenized prompts and labels.
    """
    if max_seq_length is None:
        with open(checkpoint_dir / "lit_config.json", "r") as file:
            config = json.load(file)
            max_seq_length = config["block_size"]

    destination_path.mkdir(parents=True, exist_ok=True)
    logger.info("Loading data file ...")
    import pandas as pd

    df = pd.read_csv(csv_path, dtype=str).fillna("")
    if not (df.columns.values == COLUMNS).all():
        raise ValueError(f"CSV columns must be {COLUMNS}, found {df.columns.values}")
    data = json.loads(df.to_json(orient="records", indent=4))
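
    # A note on the expected input (illustrative, not part of the original commit): the CSV must
    # have exactly the columns ("instruction", "input", "output"), e.g.
    #   instruction,input,output
    #   Create a login form,dark theme,"{""type"": ""column"", ...}"
    # `data` is then a list of dicts, one per row:
    #   {"instruction": "Create a login form", "input": "dark theme", "output": "{...}"}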

    print("Loading tokenizer...")
    tokenizer = Tokenizer(checkpoint_dir)

    # Partition the dataset into train and test
    train_set, test_set = random_split(
        data, [1.0 - test_split_fraction, test_split_fraction], generator=torch.Generator().manual_seed(seed)
    )
    train_set, test_set = list(train_set), list(test_set)

    print(f"train has {len(train_set):,} samples")
    print(f"test has {len(test_set):,} samples")

    print("Processing train split ...")
    train_set = [
        prepare_sample(
            example=sample,
            tokenizer=tokenizer,
            max_length=max_seq_length,
            mask_inputs=mask_inputs,
            ignore_index=ignore_index,
        )
        for sample in tqdm(train_set)
    ]
    torch.save(train_set, destination_path / "train.pt")

    print("Processing test split ...")
    test_set = [
        prepare_sample(
            example=sample,
            tokenizer=tokenizer,
            max_length=max_seq_length,
            mask_inputs=mask_inputs,
            ignore_index=ignore_index,
        )
        for sample in tqdm(test_set)
    ]
    torch.save(test_set, destination_path / "test.pt")
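
    # For reference (describing the structure produced above; wording not part of the original
    # commit): each saved .pt file holds a list of dicts, one per CSV row, with the original
    # "instruction"/"input"/"output" strings plus
    #   "input_ids"             - tokenized prompt + response,
    #   "input_ids_no_response" - tokenized prompt only,
    #   "labels"                - a copy of input_ids whose prompt positions are set to
    #                             `ignore_index` when mask_inputs=True.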


def prepare_sample(example: dict, tokenizer: Tokenizer, max_length: int, mask_inputs: bool, ignore_index: int) -> dict:
    """Processes a single sample.

    Each sample in the dataset consists of:
    - instruction: A string describing the task
    - input: A string holding a special input value for the instruction.
      This only applies to some samples, and in others this is empty.
    - output: The response string

    This function processes this data to produce a prompt text and a label for
    supervised training. The prompt text is formed as a single message including both
    the instruction and the input. The label/target is the same message but with the
    response attached.

    Finally, both the prompt and the label get tokenized. If `mask_inputs` is enabled, all
    tokens in the label that correspond to the original input prompt get masked out.
    """
    full_prompt = generate_prompt(example)
    full_prompt_and_response = full_prompt + example["output"]
    encoded_full_prompt = tokenizer.encode(full_prompt, max_length=max_length)
    encoded_full_prompt_and_response = tokenizer.encode(full_prompt_and_response, eos=True, max_length=max_length)

    # The labels are the full prompt with response, but with the prompt masked out
    labels = encoded_full_prompt_and_response.clone()
    if mask_inputs:
        labels[: len(encoded_full_prompt)] = ignore_index
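
    # Worked example of the masking above (illustrative numbers, not part of the original commit):
    # if the rendered prompt encodes to 60 tokens and the prompt plus appended output to 90 tokens,
    # then with mask_inputs=True the first 60 positions of `labels` are set to `ignore_index`, so a
    # training loss that honours `ignore_index` is computed only on the tokens after the prompt.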

    return {
        **example,
        "input_ids": encoded_full_prompt_and_response,
        "input_ids_no_response": encoded_full_prompt,
        "labels": labels,
    }


def generate_prompt(example: dict) -> str:
    """Generates a standardized message to prompt the model with an instruction, optional input and a
    'response' field."""
    return TEMPLATE.render(prompt=example['instruction'], response=example['output'], context=example.get('input', ''))


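# Example invocation (paths are illustrative; jsonargparse derives the flags below from the
# signature of `prepare`):
#   python scripts/prepare_ui_gen_data.py \
#       --csv_path data/ui_gen.csv \
#       --destination_path data/ui_gen \
#       --checkpoint_dir checkpoints/stabilityai/stablelm-base-alpha-3b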
if __name__ == "__main__":
    from jsonargparse import CLI

    CLI(prepare, as_positional=False)
