train.py
import json
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import transformers
from datasets import Dataset
from peft import (
    LoraConfig,
    PeftModel,
    get_peft_model,
    prepare_model_for_kbit_training,
)
from trl import SFTTrainer
import os
import glob
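
# Paths below are assumed to be mount points inside the training environment:
# MODEL_NAME is a local checkpoint directory, and the dataset directories hold
# JSON chunk files produced upstream.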
MODEL_NAME = "model"
TRAIN_CHUNK_DIR = "/dataset/train"
VAL_CHUNK_DIR = "/dataset/val"
OUTPUT_DIR = "/output/trained_model"
os.makedirs(OUTPUT_DIR, exist_ok=True)
# List the files in the training chunk directory
dataset_files = os.listdir(TRAIN_CHUNK_DIR)
for file in dataset_files:
    print(file)
# Load dataset chunks
def load_chunks(chunk_dir):
    """Concatenate every *.json chunk in chunk_dir into a single Dataset."""
    all_data = []
    chunk_files = sorted(glob.glob(os.path.join(chunk_dir, "*.json")))
    for chunk_file in chunk_files:
        with open(chunk_file, "r") as f:
            data = json.load(f)
        all_data.extend(data)
    if all_data:
        keys = all_data[0].keys()
        data_dict = {key: [d[key] for d in all_data] for key in keys}
        return Dataset.from_dict(data_dict)
    else:
        return Dataset.from_dict({})
train_dataset = load_chunks(TRAIN_CHUNK_DIR)
valid_dataset = load_chunks(VAL_CHUNK_DIR)
print(len(train_dataset), len(valid_dataset))
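
# Optional sanity check: SFTTrainer below reads the "prompt" column via
# dataset_text_field, so fail fast if the chunks lack it.
assert "prompt" in train_dataset.column_names, "expected a 'prompt' field in every training record"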
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    # torch_dtype=torch.bfloat16,
    device_map="auto",
)
model.config.use_cache = False
model.config.pretraining_tp = 1
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"
model.gradient_checkpointing_enable()
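# prepare_model_for_kbit_training freezes the base weights and readies the model for
# adapter training; it is typically paired with a 4/8-bit quantized base model, whereas
# this script loads the base model in full precision.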
model = prepare_model_for_kbit_training(model)
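# LoRA adapters are attached only to the attention query/value projections;
# the effective LoRA scaling factor is lora_alpha / r = 64 / 16 = 4.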
peft_params = LoraConfig(
    r=16,
    lora_alpha=64,
    target_modules=[
        "q_proj",
        "v_proj",
    ],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, peft_params)
model.print_trainable_parameters()
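# Effective batch size per device = MICRO_BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS
# = 4 * 8 = 32 sequences per optimizer step.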
EPOCHS = 3
MICRO_BATCH_SIZE = 4
GRADIENT_ACCUMULATION_STEPS = 8
LEARNING_RATE = 3e-4
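# eval_steps / save_steps given as floats in (0, 1) are interpreted by transformers as a
# fraction of total training steps, so evaluation and checkpointing each run 5 times per
# training run; report_to="wandb" assumes wandb is installed and authenticated.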
training_params = transformers.TrainingArguments(
    per_device_train_batch_size=MICRO_BATCH_SIZE,
    gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
    learning_rate=LEARNING_RATE,
    num_train_epochs=EPOCHS,
    bf16=True,
    logging_steps=2,
    optim="paged_adamw_32bit",
    lr_scheduler_type="cosine",
    evaluation_strategy="steps",
    save_strategy="steps",
    eval_steps=0.2,
    save_steps=0.2,
    output_dir=OUTPUT_DIR,
    save_total_limit=3,
    load_best_model_at_end=True,
    logging_dir=OUTPUT_DIR,
    report_to="wandb",
)
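# The model is already wrapped by get_peft_model above; peft_config is passed here as
# well, which recent trl versions skip for a model that is already a PeftModel.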
trainer = SFTTrainer(
    model=model,
    train_dataset=train_dataset,
    eval_dataset=valid_dataset,
    peft_config=peft_params,
    dataset_text_field="prompt",
    max_seq_length=2000,
    tokenizer=tokenizer,
    args=training_params,
    packing=False,
)
trainer.train()
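
# A minimal follow-up sketch using standard Trainer / tokenizer save APIs: persist the
# best adapter weights (load_best_model_at_end=True) and the tokenizer so they can be
# reloaded later alongside the base model.
trainer.save_model(OUTPUT_DIR)
tokenizer.save_pretrained(OUTPUT_DIR)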