
Commit

Merge branch 'main' into lit-gpt_merge
visanth-techconative authored May 16, 2024
2 parents 5a9755a + 1ae3892 commit 5ae279e
Showing 4 changed files with 8 additions and 1 deletion.
5 changes: 5 additions & 0 deletions finetune/lora.py
@@ -31,6 +31,7 @@
)
from scripts.prepare_alpaca import generate_prompt


def setup(
precision: Optional[str] = None,
quantize: Optional[Literal["bnb.nf4", "bnb.nf4-dq", "bnb.fp4", "bnb.fp4-dq", "bnb.int8-training"]] = None,
@@ -96,6 +97,7 @@ def setup(

if not any((lora_query, lora_key, lora_value, lora_projection, lora_mlp, lora_head)):
fabric.print("Warning: all LoRA layers are disabled!")

fabric.launch(
main,
devices,
@@ -118,6 +120,7 @@ def setup(
)



def main(fabric: L.Fabric, devices: int, seed: int, config: Config, io: IOArgs, train: TrainArgs, eval: EvalArgs) -> None:
validate_args(io, train, eval)

@@ -188,7 +191,9 @@ def fit(
) -> None:
tokenizer = Tokenizer(io.checkpoint_dir)
longest_seq_length, longest_seq_ix = get_longest_seq_length(train_data)

model.max_seq_length = min(longest_seq_length, train.max_seq_length or float("inf"))

fabric.print(
f"The longest sequence length in the train data is {longest_seq_length}, the model's maximum sequence length is"
f" {model.max_seq_length} and context length is {model.config.block_size}"
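The new `model.max_seq_length` line in the `fit()` hunk clamps the model's sequence length to the longest training example unless an explicit cap is configured. A minimal sketch of that clamping logic, with illustrative numbers:

```python
# Minimal sketch of the clamping above (numbers are illustrative, not from the repo).
longest_seq_length = 512       # longest tokenized example found in the training data
configured_cap = None          # train.max_seq_length; None means "no explicit limit"

# `or float("inf")` turns a missing cap into "no limit", so the model's
# max_seq_length falls back to the longest observed training sequence.
max_seq_length = min(longest_seq_length, configured_cap or float("inf"))
assert max_seq_length == 512
```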
1 change: 1 addition & 0 deletions generate/lora_ui_gen.py
@@ -120,6 +120,7 @@ def main(

fabric.print(f"Loading model {str(checkpoint_path)!r} with {config.__dict__}", file=sys.stderr)
t0 = time.perf_counter()

with fabric.init_module(empty_init=True):
model = GPT(config)
fabric.print(f"Time to instantiate model: {time.perf_counter() - t0:.02f} seconds.", file=sys.stderr)
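For context, `fabric.init_module(empty_init=True)` allocates the model's parameters without initializing them, which keeps instantiation fast because the weights are overwritten by the checkpoint load that follows. A minimal sketch of that pattern, using a stand-in module in place of `GPT(config)`:

```python
# Minimal sketch of the empty-init pattern above, using a stand-in module.
# Assumption: the real script loads checkpoint weights immediately afterwards,
# so the uninitialized values are never used.
import time

import torch
import lightning as L

fabric = L.Fabric(accelerator="cpu", devices=1)

t0 = time.perf_counter()
with fabric.init_module(empty_init=True):   # allocate parameters without initializing them
    model = torch.nn.Linear(4096, 4096)     # stand-in for GPT(config)
print(f"Time to instantiate model: {time.perf_counter() - t0:.02f} seconds.")
```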
2 changes: 2 additions & 0 deletions requirements.txt
@@ -1,2 +1,4 @@
jsonargparse[signatures] # CLI
jinja2
torch>=2.2.0
lightning @ git+https://github.com/Lightning-AI/lightning@ed367ca675861cdf40dbad2e4d66f7eee2ec50af
1 change: 0 additions & 1 deletion scripts/merge_lora.py
@@ -16,7 +16,6 @@
from lit_gpt.lora import GPT, Config, lora_filter, merge_lora_weights
from lit_gpt.utils import CLI, check_valid_checkpoint_dir, get_default_supported_precision, lazy_load


def merge_lora(
lora_path: Path = Path("out/lora/alpaca/lit_model_lora_finetuned.pth"),
checkpoint_dir: Path = Path("checkpoints/stabilityai/stablelm-base-alpha-3b"),
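For orientation, `merge_lora()` above folds the fine-tuned LoRA weights back into the base checkpoint; the two defaults shown in the hunk are the LoRA checkpoint path and the base model directory. A minimal, hypothetical invocation, assuming the script's remaining parameters all have defaults and it is run from the repository root:

```python
# Hypothetical sketch: invoking the merge step directly from Python.
# Assumptions: `scripts` is importable from the repo root (as in finetune/lora.py's
# `from scripts.prepare_alpaca import ...`) and all other merge_lora() parameters
# keep their defaults.
from pathlib import Path

from scripts.merge_lora import merge_lora

merge_lora(
    lora_path=Path("out/lora/alpaca/lit_model_lora_finetuned.pth"),
    checkpoint_dir=Path("checkpoints/stabilityai/stablelm-base-alpha-3b"),
)
```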
