diff --git a/finetune/lora.py b/finetune/lora.py
index 42c339b..80cbe51 100644
--- a/finetune/lora.py
+++ b/finetune/lora.py
@@ -31,6 +31,7 @@
 )
 from scripts.prepare_alpaca import generate_prompt
 
+
 def setup(
     precision: Optional[str] = None,
     quantize: Optional[Literal["bnb.nf4", "bnb.nf4-dq", "bnb.fp4", "bnb.fp4-dq", "bnb.int8-training"]] = None,
@@ -96,6 +97,7 @@ def setup(
     if not any((lora_query, lora_key, lora_value, lora_projection, lora_mlp, lora_head)):
         fabric.print("Warning: all LoRA layers are disabled!")
+
     fabric.launch(
         main,
         devices,
@@ -118,6 +120,7 @@ def setup(
     )
 
+
 def main(fabric: L.Fabric, devices: int, seed: int, config: Config, io: IOArgs, train: TrainArgs, eval: EvalArgs) -> None:
     validate_args(io, train, eval)
@@ -188,7 +191,9 @@ def fit(
 ) -> None:
     tokenizer = Tokenizer(io.checkpoint_dir)
     longest_seq_length, longest_seq_ix = get_longest_seq_length(train_data)
+    model.max_seq_length = min(longest_seq_length, train.max_seq_length or float("inf"))
+
     fabric.print(
         f"The longest sequence length in the train data is {longest_seq_length}, the model's maximum sequence length is"
         f" {model.max_seq_length} and context length is {model.config.block_size}"
     )
diff --git a/generate/lora_ui_gen.py b/generate/lora_ui_gen.py
index dbbf44d..b199fb7 100644
--- a/generate/lora_ui_gen.py
+++ b/generate/lora_ui_gen.py
@@ -120,6 +120,7 @@ def main(
     fabric.print(f"Loading model {str(checkpoint_path)!r} with {config.__dict__}", file=sys.stderr)
     t0 = time.perf_counter()
+
     with fabric.init_module(empty_init=True):
         model = GPT(config)
     fabric.print(f"Time to instantiate model: {time.perf_counter() - t0:.02f} seconds.", file=sys.stderr)
diff --git a/requirements.txt b/requirements.txt
index 4263d54..a6ecd90 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,2 +1,4 @@
+jsonargparse[signatures] # CLI
+jinja2
 torch>=2.2.0
 lightning @ git+https://github.com/Lightning-AI/lightning@ed367ca675861cdf40dbad2e4d66f7eee2ec50af
diff --git a/scripts/merge_lora.py b/scripts/merge_lora.py
index c25f87f..212461b 100644
--- a/scripts/merge_lora.py
+++ b/scripts/merge_lora.py
@@ -16,7 +16,6 @@
 from lit_gpt.lora import GPT, Config, lora_filter, merge_lora_weights
 from lit_gpt.utils import CLI, check_valid_checkpoint_dir, get_default_supported_precision, lazy_load
 
-
 def merge_lora(
     lora_path: Path = Path("out/lora/alpaca/lit_model_lora_finetuned.pth"),
     checkpoint_dir: Path = Path("checkpoints/stabilityai/stablelm-base-alpha-3b"),
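
The substantive change in finetune/lora.py is the max_seq_length cap in fit(): the model is sized to the longest sequence actually present in the training data, unless the user passes a smaller explicit train.max_seq_length. A minimal sketch of the idiom, using a hypothetical effective_max_seq_length helper that is not part of the patch:

from typing import Optional


def effective_max_seq_length(longest_seq_length: int, user_limit: Optional[int]) -> int:
    # `None or float("inf")` evaluates to infinity, so an unset limit never
    # caps the model; an explicit smaller limit wins inside min().
    return min(longest_seq_length, user_limit or float("inf"))


assert effective_max_seq_length(2048, None) == 2048  # no user limit: cap at the data
assert effective_max_seq_length(2048, 512) == 512    # explicit limit wins

One quirk of the `or` form: a limit of 0 is falsy, so it behaves like None (no cap) rather than capping at zero.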
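In generate/lora_ui_gen.py, model construction now happens under fabric.init_module(empty_init=True). With empty_init=True, Fabric allocates the parameters without running their random initialization, which is much faster for large models and is safe only because a checkpoint overwrites every weight immediately afterwards. A hedged sketch of the pattern, assuming lit-gpt's top-level GPT/Config classes and an illustrative checkpoint path (the script itself uses the LoRA variants from lit_gpt.lora):

import time

import lightning as L
import torch

from lit_gpt import GPT, Config  # the script uses the LoRA variants instead

fabric = L.Fabric(devices=1, precision="bf16-true")
fabric.launch()

t0 = time.perf_counter()
with fabric.init_module(empty_init=True):
    # Parameters are allocated on the target device, but their random
    # initialization is skipped.
    model = GPT(Config.from_name("stablelm-base-alpha-3b"))
print(f"Time to instantiate model: {time.perf_counter() - t0:.02f} seconds.")

# Skipping init is only safe because every weight is overwritten here:
state_dict = torch.load("checkpoints/lit_model.pth", map_location="cpu")  # path is illustrative
model.load_state_dict(state_dict)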
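requirements.txt now lists jsonargparse[signatures], the library behind the CLI helper imported from lit_gpt.utils: it builds a command-line interface from a function's signature and type hints, and the [signatures] extra enables that signature introspection. A toy sketch of the mechanism; this setup() stand-in is hypothetical, not the real one from finetune/lora.py:

from pathlib import Path
from typing import Optional

from jsonargparse import CLI


def setup(precision: Optional[str] = None, out_dir: Path = Path("out/lora")) -> None:
    """Toy stand-in for the real setup() in finetune/lora.py."""
    print(f"precision={precision}, out_dir={out_dir}")


if __name__ == "__main__":
    # e.g. python finetune.py --precision bf16-true --out_dir out/run1
    CLI(setup)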