-
Notifications
You must be signed in to change notification settings - Fork 0
/
lora.py
401 lines (339 loc) · 14.6 KB
/
lora.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
import dataclasses
import os
import sys
import time
from pathlib import Path
from typing import Dict, List, Literal, Optional, Tuple
import lightning as L
import torch
from lightning.fabric.loggers import CSVLogger
from lightning.fabric.plugins import BitsandbytesPrecision
from lightning.fabric.strategies import FSDPStrategy
from lightning.fabric.utilities import ThroughputMonitor
# support running without installing as a package
wd = Path(__file__).parent.parent.resolve()
sys.path.append(str(wd))
from generate.base import generate
from lit_gpt.args import EvalArgs, IOArgs, TrainArgs
from lit_gpt.lora import GPT, Block, Config, lora_filter, mark_only_lora_as_trainable
from lit_gpt.tokenizer import Tokenizer
from lit_gpt.utils import (
CLI,
check_valid_checkpoint_dir,
chunked_cross_entropy,
get_default_supported_precision,
load_checkpoint,
num_parameters,
)
from scripts.prepare_alpaca import generate_prompt
#
# Module-level defaults left over from an earlier configuration of this script.
# NOTE(review): none of the names below is read directly by the functions in this file -- setup(),
# main() and fit() take their values from setup()'s arguments (lora_*, IOArgs, TrainArgs, EvalArgs),
# and several differ from those (e.g. lora_r=16 / micro_batch_size=2 here vs. 8 / 4 in setup()).
eval_interval = 100
save_interval = 100
eval_iters = 100
eval_max_new_tokens = 100
log_interval = 1
devices = 1
# Hyperparameters
learning_rate = 3e-4
batch_size = 128
# Reduced micro-batch size so training fits on a single T4 GPU
micro_batch_size = 2
gradient_accumulation_iters = batch_size // micro_batch_size
# NOTE(review): `assert` is stripped under `python -O`, so this sanity check is not guaranteed to run.
assert gradient_accumulation_iters > 0
max_iters = 50000 # train dataset size
weight_decay = 0.01
lora_r = 16
lora_alpha = 32
lora_dropout = 0.05
lora_query = True
lora_key = True
lora_value = True
lora_projection = True
lora_mlp = True
lora_head = True
warmup_steps = 100
# Snapshot of every public int/float/str global defined so far (bools included, since bool subclasses
# int). It is built from locals() at this point, so it must stay below the constants it captures.
hparams = {k: v for k, v in locals().items() if isinstance(v, (int, float, str)) and not k.startswith("_")}
#
def setup(
    precision: Optional[str] = None,
    quantize: Optional[Literal["bnb.nf4", "bnb.nf4-dq", "bnb.fp4", "bnb.fp4-dq", "bnb.int8-training"]] = None,
    devices: int = 1,
    seed: int = 1337,
    lora_r: int = 8,
    lora_alpha: int = 16,
    lora_dropout: float = 0.05,
    lora_query: bool = True,
    lora_key: bool = False,
    lora_value: bool = True,
    lora_projection: bool = False,
    lora_mlp: bool = False,
    lora_head: bool = False,
    io: IOArgs = IOArgs(
        train_data_dir=Path("data/alpaca"),
        val_data_dir=Path("data/alpaca"),
        checkpoint_dir=Path("checkpoints/stabilityai/stablelm-base-alpha-3b"),
        out_dir=Path("out/lora/alpaca"),
    ),
    train: TrainArgs = TrainArgs(
        save_interval=1000,
        log_interval=1,
        global_batch_size=128,
        micro_batch_size=4,
        lr_warmup_steps=100,
        epochs=5,
        epoch_size=50000,
        learning_rate=3e-4,
        max_seq_length=None,
    ),
    eval: EvalArgs = EvalArgs(interval=100, max_new_tokens=100, max_iters=100),
) -> None:
    """Configure Fabric (precision, quantization, strategy, CSV logger) and launch LoRA fine-tuning.

    Args:
        precision: Fabric precision string; defaults to ``get_default_supported_precision(training=True)``.
        quantize: optional bitsandbytes quantization mode (``"bnb.*"``); single-device only.
        devices: number of devices. More than one selects an FSDP strategy wrapping each ``Block``.
        seed: random seed forwarded to ``main``.
        lora_r, lora_alpha, lora_dropout: LoRA rank, scaling and dropout.
        lora_query, lora_key, lora_value, lora_projection, lora_mlp, lora_head: which layers get LoRA adapters.
        io, train, eval: I/O, training and evaluation arguments forwarded to ``main``.

    Raises:
        ValueError: if quantization is combined with a mixed precision, or with a precision that has no
            matching bitsandbytes dtype.
        NotImplementedError: if quantization is requested together with ``devices > 1``.
    """
    # Prints the effective arguments of this run (these, not the module-level `hparams`, are what is used).
    print(locals())
    precision = precision or get_default_supported_precision(training=True)

    plugins = None
    if quantize is not None and quantize.startswith("bnb."):
        if "mixed" in precision:
            raise ValueError("Quantization and mixed precision is not supported.")
        quant_dtypes = {"16-true": torch.float16, "bf16-true": torch.bfloat16, "32-true": torch.float32}
        if precision not in quant_dtypes:
            # Previously surfaced as a bare KeyError from the dict lookup.
            raise ValueError(
                f"Quantization requires one of the precisions {sorted(quant_dtypes)}, got {precision!r}."
            )
        plugins = BitsandbytesPrecision(quantize[4:], quant_dtypes[precision])
        # The bitsandbytes plugin owns precision handling, so Fabric must not receive a precision too.
        precision = None

    if devices > 1:
        if quantize:
            raise NotImplementedError(
                "Quantization is currently not supported for multi-GPU training. Please set devices=1 when using the"
                " --quantize flag."
            )
        strategy = FSDPStrategy(
            auto_wrap_policy={Block},
            activation_checkpointing_policy={Block},
            state_dict_type="full",
            limit_all_gathers=True,
            cpu_offload=False,
        )
    else:
        strategy = "auto"

    logger = CSVLogger(io.out_dir.parent, io.out_dir.name, flush_logs_every_n_steps=train.log_interval)
    fabric = L.Fabric(devices=devices, strategy=strategy, precision=precision, loggers=logger, plugins=plugins)

    if not any((lora_query, lora_key, lora_value, lora_projection, lora_mlp, lora_head)):
        fabric.print("Warning: all LoRA layers are disabled!")
    # The module-level `hparams` dict is intentionally not printed here: its values (e.g. lora_r=16,
    # micro_batch_size=2) are not the ones this run uses and would make the logs misleading.

    fabric.launch(
        main,
        devices,
        seed,
        Config.from_name(
            name=io.checkpoint_dir.name,
            r=lora_r,
            alpha=lora_alpha,
            dropout=lora_dropout,
            to_query=lora_query,
            to_key=lora_key,
            to_value=lora_value,
            to_projection=lora_projection,
            to_mlp=lora_mlp,
            to_head=lora_head,
        ),
        io,
        train,
        eval,
    )
def main(fabric: L.Fabric, devices: int, seed: int, config: Config, io: IOArgs, train: TrainArgs, eval: EvalArgs) -> None:
    """Per-process training entry point, launched through ``fabric.launch`` from ``setup``.

    Validates arguments, loads the pre-tokenized train/val data and the base checkpoint into a LoRA ``GPT``,
    trains with ``fit``, and saves the LoRA weights to ``io.out_dir / "lit_model_lora_finetuned.pth"``.

    Args:
        fabric: the launched Fabric instance (injected by ``fabric.launch``).
        devices: number of devices; used for step accounting and to enable empty init when > 1.
        seed: base seed. Model init uses ``seed`` on every rank; data sampling uses ``seed + global_rank``.
        config: LoRA-enabled model configuration.
        io, train, eval: I/O, training and evaluation arguments.
    """
    validate_args(io, train, eval)

    steps_per_epoch = train.epoch_size // devices // train.batch_size(devices)
    lr_max_steps = train.epochs * steps_per_epoch

    check_valid_checkpoint_dir(io.checkpoint_dir)

    fabric.seed_everything(seed)  # same seed for every process to init model (FSDP)

    if fabric.global_rank == 0:
        os.makedirs(io.out_dir, exist_ok=True)

    # NOTE(review): torch.load unpickles these files -- only point the data dirs at trusted data.
    train_data = torch.load(io.train_data_dir / "train.pt")
    val_data = torch.load(io.val_data_dir / "test.pt")

    checkpoint_path = io.checkpoint_dir / "lit_model.pth"
    fabric.print(f"Loading model {str(checkpoint_path)!r} with {config.__dict__}")
    with fabric.init_module(empty_init=(devices > 1)):
        model = GPT(config)
    mark_only_lora_as_trainable(model)

    fabric.print(f"Number of trainable parameters: {num_parameters(model, requires_grad=True):,}")
    fabric.print(f"Number of non trainable parameters: {num_parameters(model, requires_grad=False):,}")

    model = fabric.setup_module(model)

    trainable_params = [p for p in model.parameters() if p.requires_grad]
    if isinstance(fabric.strategy.precision, BitsandbytesPrecision):
        import bitsandbytes as bnb

        optimizer_cls = bnb.optim.PagedAdamW
    else:
        optimizer_cls = torch.optim.AdamW
    optimizer = optimizer_cls(
        trainable_params, lr=train.learning_rate, weight_decay=train.weight_decay, betas=(train.beta1, train.beta2)
    )
    optimizer = fabric.setup_optimizers(optimizer)
    scheduler = get_lr_scheduler(optimizer, warmup_steps=train.lr_warmup_steps, max_steps=lr_max_steps)

    # strict=False because missing keys due to LoRA weights not contained in state dict
    load_checkpoint(fabric, model, checkpoint_path, strict=False)

    # Different seed per rank so each process samples different batches. Derived from the user's `seed`
    # (previously hard-coded to 1337, which silently ignored --seed for data sampling).
    fabric.seed_everything(seed + fabric.global_rank)

    train_time = time.perf_counter()
    fit(fabric, model, optimizer, scheduler, train_data, val_data, devices, io, train, eval)
    fabric.print(f"Training time: {(time.perf_counter()-train_time):.2f}s")
    if fabric.device.type == "cuda":
        fabric.print(f"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB")

    # Save the final LoRA checkpoint at the end of training
    save_path = io.out_dir / "lit_model_lora_finetuned.pth"
    save_lora_checkpoint(fabric, model, save_path)
def fit(
    fabric: L.Fabric,
    model: GPT,
    optimizer: torch.optim.Optimizer,
    scheduler: torch.optim.lr_scheduler,
    train_data: List[Dict],
    val_data: List[Dict],
    devices: int,
    io: IOArgs,
    train: TrainArgs,
    eval: EvalArgs,
) -> None:
    """Run the training loop: gradient accumulation, LR scheduling, logging, validation and checkpoints.

    Each iteration processes one micro-batch. An optimizer + scheduler step happens once every
    ``train.gradient_accumulation_iters(devices)`` iterations. Validation runs every ``eval.interval``
    optimizer steps, and a LoRA checkpoint is written every ``train.save_interval`` optimizer steps.
    """
    tokenizer = Tokenizer(io.checkpoint_dir)
    longest_seq_length, longest_seq_ix = get_longest_seq_length(train_data)

    # Choose the model's maximum sequence length:
    #  * large enough for the longest train *or* val sample (after optional truncation to
    #    `train.max_seq_length`, mirroring get_batch), so no batch exceeds it;
    #  * never below `min_seq_length` -- the training data alone gave too short a context
    #    (presumably for validate()'s sample generation of prompt + eval.max_new_tokens tokens);
    #  * never above the model's context window `config.block_size`.
    min_seq_length = 500
    longest_val_length, _ = get_longest_seq_length(val_data)
    longest_batch_length = min(max(longest_seq_length, longest_val_length), train.max_seq_length or float("inf"))
    model.max_seq_length = min(max(min_seq_length, longest_batch_length), model.config.block_size)
    fabric.print(
        f"The longest sequence length in the train data is {longest_seq_length}, the model's maximum sequence length is"
        f" {model.max_seq_length} and context length is {model.config.block_size}"
    )

    validate(fabric, model, val_data, tokenizer, dataclasses.replace(eval, max_iters=2), train)  # sanity check

    throughput = ThroughputMonitor(fabric, window_size=50)
    accumulation_iters = train.gradient_accumulation_iters(devices)  # loop-invariant
    step_count = 0
    total_lengths = 0
    total_t0 = time.perf_counter()

    for iter_num in range(1, train.max_iters(devices) + 1):
        iter_t0 = time.perf_counter()

        # On the first iteration force the longest sample in, so a potential OOM shows up immediately.
        input_ids, targets = get_batch(
            fabric, train_data, train.micro_batch_size, train.max_seq_length, longest_seq_ix if iter_num == 1 else None
        )

        is_accumulating = iter_num % accumulation_iters != 0
        with fabric.no_backward_sync(model, enabled=is_accumulating):
            logits = model(input_ids, lm_head_chunk_size=128)
            # shift the targets such that output n predicts token n+1
            # (with lm_head_chunk_size the logits appear to be a list of sequence chunks, so only the last
            # chunk loses its final position)
            logits[-1] = logits[-1][..., :-1, :]
            loss = chunked_cross_entropy(logits, targets[..., 1:])
            fabric.backward(loss / accumulation_iters)

        if not is_accumulating:
            optimizer.step()
            optimizer.zero_grad()
            scheduler.step()
            step_count += 1

        total_lengths += input_ids.numel()
        if iter_num % train.log_interval == 0:
            loss_item = loss.item()  # expensive device-to-host synchronization
            t1 = time.perf_counter()
            throughput.update(
                time=t1 - total_t0, batches=iter_num, samples=iter_num * train.micro_batch_size, lengths=total_lengths
            )
            throughput.compute_and_log(step=iter_num)
            fabric.print(
                f"iter {iter_num} | step {step_count}: loss {loss_item:.4f}, iter time:"
                f" {(t1 - iter_t0) * 1000:.2f} ms{' (optimizer.step)' if not is_accumulating else ''}"
            )

        if not is_accumulating and step_count % eval.interval == 0:
            t0 = time.perf_counter()
            val_loss = validate(fabric, model, val_data, tokenizer, eval, train)
            t1 = time.perf_counter() - t0
            fabric.print(f"iter {iter_num}: val loss {val_loss.item():.4f}, val time: {t1 * 1000:.2f} ms")
            fabric.barrier()
        if not is_accumulating and step_count % train.save_interval == 0:
            checkpoint_path = io.out_dir / f"iter-{iter_num:06d}-ckpt.pth"
            save_lora_checkpoint(fabric, model, checkpoint_path)
# FSDP has issues with `inference_mode`
@torch.no_grad()
def validate(
    fabric: L.Fabric, model: GPT, val_data: List[Dict], tokenizer: Tokenizer, eval: EvalArgs, train: TrainArgs
) -> torch.Tensor:
    """Estimate the validation loss and print one sample generation.

    Averages the loss over ``eval.max_iters`` randomly drawn validation micro-batches, then generates a
    response to a fixed instruction. The model is in eval mode throughout and is switched back to
    train mode before returning.

    Returns:
        The mean loss as a 0-dim tensor (the mean of a ``torch.zeros`` buffer on torch's default device).
    """
    fabric.print("Validating ...")
    model.eval()

    num_batches = eval.max_iters
    batch_losses = torch.zeros(num_batches)
    for batch_idx in range(num_batches):
        inputs, labels = get_batch(fabric, val_data, train.micro_batch_size, train.max_seq_length)
        # Shift by one: the logits at position t are scored against the token at position t + 1.
        batch_losses[batch_idx] = chunked_cross_entropy(model(inputs)[..., :-1, :], labels[..., 1:], chunk_size=0)
    mean_loss = batch_losses.mean()

    # Qualitative check: answer a fixed instruction with the current weights.
    prompt_instruction = "Recommend a movie for me to watch during the weekend and explain the reason."
    fabric.print(prompt_instruction)
    prompt_ids = tokenizer.encode(
        generate_prompt({"instruction": prompt_instruction, "input": ""}), device=fabric.device
    )
    with fabric.init_tensor():
        # do not set `max_seq_length=max_returned_token` because memory is not a concern here
        model.set_kv_cache(batch_size=1)
    generated = generate(
        model,
        prompt_ids,
        max_returned_tokens=len(prompt_ids) + eval.max_new_tokens,
        temperature=0.8,
        eos_id=tokenizer.eos_id,
    )
    model.clear_kv_cache()
    fabric.print(tokenizer.decode(generated))

    model.train()
    return mean_loss
def get_batch(
    fabric: L.Fabric,
    data: List[Dict],
    micro_batch_size: int,
    max_seq_length: Optional[int],
    longest_seq_ix: Optional[int] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Sample a random micro-batch, right-pad it, optionally truncate it, and move it to Fabric's device.

    Args:
        fabric: provides the target device and ``to_device``.
        data: samples with ``"input_ids"`` and ``"labels"`` tensors.
        micro_batch_size: number of samples to draw (uniformly, with replacement).
        max_seq_length: if truthy, columns beyond this length are dropped after padding.
        longest_seq_ix: if given, this sample index is forced into row 0.

    Returns:
        ``(x, y)`` int64 tensors. Every row is padded to the longest ``input_ids`` in the batch, with 0 for
        inputs and -1 for labels.
    """
    indices = torch.randint(len(data), (micro_batch_size,))
    if longest_seq_ix is not None:
        # force the longest sample at the beginning so potential OOMs happen right away
        indices[0] = longest_seq_ix

    samples = [data[i] for i in indices]
    input_seqs = [sample["input_ids"].type(torch.int64) for sample in samples]
    label_seqs = [sample["labels"].type(torch.int64) for sample in samples]

    # Target width is the longest *input* sequence in this micro-batch (could instead be the global
    # `longest_seq_length` to give every batch a fixed size).
    width = max(len(seq) for seq in input_seqs)

    def _padded(seq: torch.Tensor, fill: int) -> torch.Tensor:
        # Append `fill` on the right until `seq` reaches `width`.
        padding = torch.full((width - len(seq),), fill, dtype=seq.dtype)
        return torch.cat((seq, padding))

    x = torch.stack([_padded(seq, 0) for seq in input_seqs])
    y = torch.stack([_padded(seq, -1) for seq in label_seqs])

    if max_seq_length:
        x, y = x[:, :max_seq_length], y[:, :max_seq_length]

    # Pinned host memory enables faster host-to-GPU copies.
    use_pinned = fabric.device.type == "cuda" and x.device.type == "cpu"
    x, y = fabric.to_device((x.pin_memory(), y.pin_memory()) if use_pinned else (x, y))
    return x, y
def get_lr_scheduler(optimizer, warmup_steps: int, max_steps: int):
    """Build a linear-warmup followed by cosine-annealing learning-rate schedule.

    During the first ``warmup_steps`` calls to ``scheduler.step()`` the learning rate is scaled by
    ``step / warmup_steps``. After that it follows a cosine curve over the remaining
    ``max_steps - warmup_steps`` steps.

    Args:
        optimizer: optimizer whose param-group learning rates are scheduled.
        warmup_steps: number of linear warmup steps; ``<= 0`` disables warmup.
        max_steps: total number of scheduler steps the schedule is designed for.

    Returns:
        A ``SequentialLR`` (warmup then cosine), or a plain ``CosineAnnealingLR`` when warmup is disabled.
    """
    # CosineAnnealingLR divides by T_max once it starts stepping, so keep T_max >= 1 even when
    # max_steps <= warmup_steps (previously T_max could be 0 or negative).
    cosine_steps = max(1, max_steps - max(warmup_steps, 0))
    if warmup_steps <= 0:
        # `step / warmup_steps` would divide by zero; with nothing to warm up, go straight to cosine.
        return torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=cosine_steps)
    # linear warmup followed by cosine annealing
    scheduler1 = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda step: step / warmup_steps)
    scheduler2 = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=cosine_steps)
    return torch.optim.lr_scheduler.SequentialLR(optimizer, [scheduler1, scheduler2], milestones=[warmup_steps])
def get_longest_seq_length(data: List[Dict]) -> Tuple[int, int]:
    """Return ``(length, index)`` of the sample with the most ``input_ids``.

    Used to find the minimum max_seq_length required during fine-tuning (saves memory). On ties the
    earliest index wins. An empty ``data`` raises ``ValueError``.
    """
    best_ix, best_len = max(enumerate(len(sample["input_ids"]) for sample in data), key=lambda pair: pair[1])
    return best_len, best_ix
def save_lora_checkpoint(fabric: L.Fabric, model: torch.nn.Module, file_path: Path) -> None:
    """Save ``model`` to ``file_path`` through ``fabric.save``, filtering its entry with ``lora_filter``.

    NOTE(review): ``lora_filter`` comes from lit_gpt.lora and presumably keeps only the LoRA parameters;
    confirm there.
    """
    message = "Saving LoRA weights to " + repr(str(file_path))
    fabric.print(message)
    state = {"model": model}
    fabric.save(file_path, state, filter={"model": lora_filter})
def validate_args(io: IOArgs, train: TrainArgs, eval: EvalArgs) -> None:
    """Reject argument combinations this script cannot honour.

    All problems are collected first and reported together.

    Raises:
        ValueError: newline-separated list of every unsupported argument that is set (not ``None``),
            followed by every required argument that is ``None``.
    """
    unsupported = [(train, "max_tokens"), (train, "max_norm")]
    required = [
        (io, "checkpoint_dir"),
        (io, "train_data_dir"),
        (io, "val_data_dir"),
        (train, "epoch_size"),
        (train, "epochs"),
        (eval, "max_new_tokens"),
    ]

    problems = [
        f"{__file__} doesn't support the {name!r} argument. This is set in {group}"
        for group, name in unsupported
        if getattr(group, name) is not None
    ]
    problems.extend(
        f"{__file__} requires the {name!r} argument. This is set in {group}"
        for group, name in required
        if getattr(group, name) is None
    )

    if problems:
        raise ValueError("\n".join(problems))
if __name__ == "__main__":
    # Per the PyTorch docs, "high" lets float32 matmuls use faster reduced-precision paths (e.g. TF32)
    # where the hardware supports them.
    torch.set_float32_matmul_precision("high")
    # Hand setup() to lit_gpt's CLI helper -- presumably it builds command-line arguments from setup()'s
    # signature; see lit_gpt.utils.CLI.
    CLI(setup)