Commit db46325: merge
ebsmothers committed Apr 9, 2024
2 parents ae7c2a3 + c9d1cdc
Showing 29 changed files with 166 additions and 106 deletions.
46 changes: 46 additions & 0 deletions .github/workflows/build_linux_wheels.yaml
@@ -0,0 +1,46 @@
name: Build Linux Wheels

on:
  pull_request:
  push:
    branches:
      - nightly
      - main
      - release/*
    tags:
      # NOTE: Binary build pipelines should only get triggered on release candidate builds
      # Release candidate tags look like: v1.11.0-rc1
      - v[0-9]+.[0-9]+.[0-9]+-rc[0-9]+
  workflow_dispatch:

permissions:
  id-token: write
  contents: read

jobs:
  generate-matrix:
    uses: pytorch/test-infra/.github/workflows/generate_binary_build_matrix.yml@main
    with:
      package-type: wheel
      os: linux
      test-infra-repository: pytorch/test-infra
      test-infra-ref: main
      with-cuda: disable
      with-rocm: disable
  build:
    needs: generate-matrix
    name: ${{ matrix.repository }}
    uses: pytorch/test-infra/.github/workflows/build_wheels_linux.yml@main
    strategy:
      fail-fast: false
    with:
      repository: pytorch/torchtune
      ref: ""
      pre-script: ""
      post-script: ""
      smoke-test-script: ""
      package-name: torchtune
      test-infra-repository: pytorch/test-infra
      test-infra-ref: main
      build-matrix: ${{ needs.generate-matrix.outputs.matrix }}
      trigger-event: ${{ github.event_name }}
10 changes: 1 addition & 9 deletions MANIFEST.in
@@ -1,9 +1 @@
# Add necessary recipe files (not as importable)
recursive-include recipes *.py *.yaml

# Add requirements
include requirements.txt
include dev-requirements.txt

# Remove tests from packaging
prune tests
prune tests # Remove all testing files from final dist/
9 changes: 0 additions & 9 deletions dev-requirements.txt

This file was deleted.

2 changes: 1 addition & 1 deletion docs/source/examples/finetune_llm.rst
@@ -70,7 +70,7 @@ To run the recipe without any changes on 4 GPUs, launch a training run using Tun

.. code-block:: bash
tune --nnodes 1 --nproc_per_node 4 full_finetune_distributed --config full_finetune_distributed
tune run --nnodes 1 --nproc_per_node 4 full_finetune_distributed --config full_finetune_distributed
Dataset
-------
4 changes: 2 additions & 2 deletions docs/source/examples/lora_finetune.rst
@@ -253,7 +253,7 @@ You can then run the following command to perform a LoRA finetune of Llama2-7B u

.. code-block:: bash
tune --nnodes 1 --nproc_per_node 2 lora_finetune_distributed --config lora_finetune_distributed
tune run --nnodes 1 --nproc_per_node 2 lora_finetune_distributed --config lora_finetune_distributed
.. note::
Make sure to point to the location of your Llama2 weights and tokenizer. This can be done
@@ -288,7 +288,7 @@ Let's run this experiment. We can also increase alpha (in general it is good pra

.. code-block:: bash
tune --nnodes 1 --nproc_per_node 2 lora_finetune_distributed --config lora_finetune_distributed \
tune run --nnodes 1 --nproc_per_node 2 lora_finetune_distributed --config lora_finetune_distributed \
lora_attn_modules='[q_proj, k_proj, v_proj, output_proj]' \
lora_rank=32 lora_alpha=64 output_dir=./lora_experiment_1
66 changes: 66 additions & 0 deletions pyproject.toml
@@ -1,3 +1,69 @@
# ---- All project specifications ---- #
[project]
name = "torchtune"
description = "A native-PyTorch library for LLM fine-tuning"
readme = "README.md"
requires-python = ">=3.8"
license = {file = "LICENSE"}
authors = [
{ name = "PyTorch Team", email = "[email protected]" },
]
keywords = ["pytorch", "finetuning", "llm"]
dependencies = [
# Hugging Face integrations
"datasets",
"huggingface_hub",
"safetensors",

# Miscellaneous
"sentencepiece",
"tqdm",
"omegaconf",

# Quantization
"torchao==0.1",
]
dynamic = ["version"]

[project.urls]
GitHub = "https://github.com/pytorch/torchtune"
Documentation = "https://pytorch.org/torchtune/main/index.html"
Issues = "https://github.com/pytorch/torchtune/issues"

[project.scripts]
tune = "torchtune._cli.tune:main"

[project.optional-dependencies]
dev = [
"bitsandbytes",
"pre-commit",
"pytest",
"pytest-cov",
"pytest-mock",
"pytest-integration",
"tensorboard",
"transformers",
"wandb",
]

[tool.setuptools.dynamic]
version = {attr = "torchtune.__version__"}


# ---- Explicit project build information ---- #
[build-system]
requires = ["setuptools", "wheel"]
build-backend = "setuptools.build_meta"

[tool.setuptools.packages.find]
where = [""]
include = ["torchtune*", "recipes*"]

[tool.setuptools.package-data]
recipes = ["configs/*.yaml", "configs/*/*.yaml"]


# ---- Tooling specifications ---- #
[tool.usort]
first_party_detection = false
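
The two dynamic entries above are worth spelling out: `tune = "torchtune._cli.tune:main"` wires the `tune` command to a Python function, and `version = {attr = "torchtune.__version__"}` tells setuptools to read the version from the package itself. A minimal sketch of how those strings resolve (assuming torchtune is installed; the helper below is illustrative, not part of the commit):

from importlib import import_module

def resolve_entry_point(spec: str):
    # Turn an entry-point string like "pkg.module:func" into the callable it names.
    module_path, _, attr = spec.partition(":")
    return getattr(import_module(module_path), attr)

# [project.scripts] tune = "torchtune._cli.tune:main" -> the `tune` command runs:
tune_main = resolve_entry_point("torchtune._cli.tune:main")

# [tool.setuptools.dynamic] version = {attr = "torchtune.__version__"} -> setuptools reads:
package_version = import_module("torchtune").__version__

print(package_version, tune_main)

The `dev` extra replaces the deleted dev-requirements.txt, so a development install can be done with `pip install -e ".[dev]"`.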

2 changes: 1 addition & 1 deletion recipes/README.md
@@ -150,7 +150,7 @@ checkpointer:
# make sure to change the checkpointer component
checkpointer:
_component_: torchtune.utils.FullModelTorchTuneCheckpointer
checkpoint_files: [meta_model_0.4w.pt]
checkpoint_files: [meta_model_0-4w.pt]
# Quantization Arguments
quantizer:
5 changes: 2 additions & 3 deletions recipes/__init__.py
@@ -18,7 +18,6 @@

# TODO: Add proper link to pytorch.org/torchtune/... when the docs are live.
raise ModuleNotFoundError(
"The TorchTune recipes folder isn't a package and you should not import anything from there. "
"Refer to our docs for detailed instructions on how to use the recipes!"
"This file only exists for testing reasons."
"The TorchTune recipes directory isn't a package and you should not import anything from here. "
"Refer to our docs for detailed instructions on how to use recipes!"
)
6 changes: 3 additions & 3 deletions recipes/configs/gemma/2B_full.yaml
@@ -3,18 +3,18 @@
#
# This config assumes that you've run the following command before launching
# this run:
# tune download --repo-id google/gemma-2b \
# tune download google/gemma-2b \
# --hf-token <HF_TOKEN> \
# --output-dir /tmp/gemma2
#
# To launch on 4 devices, run the following command from root:
# tune --nnodes 1 --nproc_per_node 4 full_finetune_distributed \
# tune run --nnodes 1 --nproc_per_node 4 full_finetune_distributed \
# --config gemma/2B_full \
#
# You can add specific overrides through the command line. For example
# to override the checkpointer directory while launching training
# you can run:
# tune --nnodes 1 --nproc_per_node 4 full_finetune_distributed \
# tune run --nnodes 1 --nproc_per_node 4 full_finetune_distributed \
# --config gemma/2B_full \
# checkpointer.checkpoint_dir=<YOUR_CHECKPOINT_DIR>
#
2 changes: 1 addition & 1 deletion recipes/configs/llama2/13B_full.yaml
@@ -14,7 +14,7 @@
# You can add specific overrides through the command line. For example
# to override the checkpointer directory while launching training
# you can run:
# tune --nnodes 1 --nproc_per_node 4 full_finetune_distributed \
# tune run --nnodes 1 --nproc_per_node 4 full_finetune_distributed \
# --config llama2/13B_full \
# checkpointer.checkpoint_dir=<YOUR_CHECKPOINT_DIR>
#
2 changes: 1 addition & 1 deletion recipes/configs/llama2/13B_lora.yaml
@@ -14,7 +14,7 @@
# You can add specific overrides through the command line. For example
# to override the checkpointer directory while launching training
# you can run:
# tune --nnodes 1 --nproc_per_node 4 lora_finetune_distributed \
# tune run --nnodes 1 --nproc_per_node 4 lora_finetune_distributed \
# --config llama2/13B_lora \
# checkpointer.checkpoint_dir=<YOUR_CHECKPOINT_DIR>
#
2 changes: 1 addition & 1 deletion recipes/configs/llama2/7B_full.yaml
@@ -14,7 +14,7 @@
# You can add specific overrides through the command line. For example
# to override the checkpointer directory while launching training
# you can run:
# tune --nnodes 1 --nproc_per_node 4 full_finetune_distributed \
# tune run --nnodes 1 --nproc_per_node 4 full_finetune_distributed \
# --config llama2/7B_full \
# checkpointer.checkpoint_dir=<YOUR_CHECKPOINT_DIR>
#
4 changes: 2 additions & 2 deletions recipes/configs/llama2/7B_full_single_device.yaml
@@ -9,13 +9,13 @@
#
# To launch on a single device, run the following command from root:
# tune run full_finetune_single_device \
# --config llama2/7B_full_single_device_low_memory \
# --config llama2/7B_full_single_device \
#
# You can add specific overrides through the command line. For example
# to override the checkpointer directory while launching training
# you can run:
# tune run full_finetune_single_device \
# --config llama2/7B_full_single_device_low_memory \
# --config llama2/7B_full_single_device \
# checkpointer.checkpoint_dir=<YOUR_CHECKPOINT_DIR>
#
# This config works only for training on single device.
2 changes: 1 addition & 1 deletion recipes/configs/llama2/7B_lora.yaml
@@ -14,7 +14,7 @@
# You can add specific overrides through the command line. For example
# to override the checkpointer directory while launching training
# you can run:
# tune --nnodes 1 --nproc_per_node 4 lora_finetune_distributed \
# tune run --nnodes 1 --nproc_per_node 4 lora_finetune_distributed \
# --config llama2/7B_lora \
# checkpointer.checkpoint_dir=<YOUR_CHECKPOINT_DIR>
#
2 changes: 1 addition & 1 deletion recipes/configs/llama2/7B_lora_single_device.yaml
@@ -14,7 +14,7 @@
# You can add specific overrides through the command line. For example
# to override the checkpointer directory while launching training
# you can run:
# tune --nnodes 1 --nproc_per_node 1 lora_finetune_single_device \
# tune run --nnodes 1 --nproc_per_node 1 lora_finetune_single_device \
# --config 7B_lora_single_device \
# checkpointer.checkpoint_dir=<YOUR_CHECKPOINT_DIR>
#
2 changes: 1 addition & 1 deletion recipes/configs/llama2/7B_qlora_single_device.yaml
@@ -14,7 +14,7 @@
# You can add specific overrides through the command line. For example
# to override the checkpointer directory while launching training
# you can run:
# tune --nnodes 1 --nproc_per_node 1 lora_finetune_single_device \
# tune run --nnodes 1 --nproc_per_node 1 lora_finetune_single_device \
# --config 7B_qlora_single_device \
# checkpointer.checkpoint_dir=<YOUR_CHECKPOINT_DIR>
#
2 changes: 1 addition & 1 deletion recipes/configs/mistral/7B_full_single_device.yaml
@@ -19,7 +19,7 @@
# to override the checkpointer directory while launching training
# you can run:
# tune run full_finetune_single_device \
# --config mistral/7B_full_single_device \
# --config llama2/7B_full_single_device \
# checkpointer.checkpoint_dir=<YOUR_CHECKPOINT_DIR>
#
# This config works only for training on single device.
4 changes: 2 additions & 2 deletions recipes/full_finetune_distributed.py
@@ -101,13 +101,13 @@ def __init__(self, cfg: DictConfig) -> None:
self.max_steps_per_epoch = cfg.max_steps_per_epoch
self.total_training_steps = 0

def load_checkpoint(self, cfg: DictConfig) -> Dict[str, Any]:
def load_checkpoint(self, cfg_checkpointer: DictConfig) -> Dict[str, Any]:
"""
Extract the checkpoint state from file and validate. If resume_from_checkpoint
is True, this also includes the recipe state.
"""
self._checkpointer = config.instantiate(
cfg,
cfg_checkpointer,
resume_from_checkpoint=self._resume_from_checkpoint,
)
checkpoint_dict = self._checkpointer.load_checkpoint()
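
The rename from `cfg` to `cfg_checkpointer` (here and in the single-device, Gemma, and LoRA recipes below) makes explicit that only the `checkpointer` section of the recipe config is passed in, not the whole config. A minimal sketch of that slice, with illustrative values that are not taken from the commit:

from omegaconf import OmegaConf

# Illustrative recipe config; only the `checkpointer` block is handed to
# load_checkpoint(cfg_checkpointer=...).
cfg = OmegaConf.create(
    {
        "resume_from_checkpoint": False,
        "checkpointer": {
            "_component_": "torchtune.utils.FullModelTorchTuneCheckpointer",
            "checkpoint_dir": "/tmp/llama2",          # assumed example path
            "checkpoint_files": ["meta_model_0.pt"],  # assumed example file
            "output_dir": "/tmp/llama2",
            "model_type": "LLAMA2",
        },
    }
)

# Inside the recipe this slice is instantiated into a checkpointer object:
#     self._checkpointer = config.instantiate(cfg_checkpointer, resume_from_checkpoint=...)
print(OmegaConf.to_yaml(cfg.checkpointer))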
4 changes: 2 additions & 2 deletions recipes/full_finetune_single_device.py
@@ -103,13 +103,13 @@ def __init__(self, cfg: DictConfig) -> None:
self.max_steps_per_epoch = cfg.max_steps_per_epoch
self.total_training_steps = 0

def load_checkpoint(self, cfg: DictConfig) -> Dict[str, Any]:
def load_checkpoint(self, cfg_checkpointer: DictConfig) -> Dict[str, Any]:
"""
Extract the checkpoint state from file and validate. If resume_from_checkpoint
is True, this also includes the recipe state.
"""
self._checkpointer = config.instantiate(
cfg,
cfg_checkpointer,
resume_from_checkpoint=self._resume_from_checkpoint,
)
checkpoint_dict = self._checkpointer.load_checkpoint()
4 changes: 2 additions & 2 deletions recipes/gemma_full_finetune_distributed.py
@@ -100,13 +100,13 @@ def __init__(self, cfg: DictConfig) -> None:
self.max_steps_per_epoch = cfg.max_steps_per_epoch
self.total_training_steps = 0

def load_checkpoint(self, cfg: DictConfig) -> Dict[str, Any]:
def load_checkpoint(self, cfg_checkpointer: DictConfig) -> Dict[str, Any]:
"""
Extract the checkpoint state from file and validate. If resume_from_checkpoint
is True, this also includes the recipe state.
"""
self._checkpointer = config.instantiate(
cfg,
cfg_checkpointer,
resume_from_checkpoint=self._resume_from_checkpoint,
)
checkpoint_dict = self._checkpointer.load_checkpoint()
6 changes: 3 additions & 3 deletions recipes/lora_finetune_single_device.py
@@ -98,14 +98,14 @@ def __init__(self, cfg: DictConfig) -> None:
self._resume_from_checkpoint = cfg.resume_from_checkpoint
self._gradient_accumulation_steps = cfg.gradient_accumulation_steps

def load_checkpoint(self, cfg: DictConfig) -> Dict[str, Any]:
def load_checkpoint(self, cfg_checkpointer: DictConfig) -> Dict[str, Any]:
"""
Extract the checkpoint state from file and validate. This includes the
base model weights. If resume_from_checkpoint is True, this also includes
the adapter weights and recipe state
"""
self._checkpointer = config.instantiate(
cfg,
cfg_checkpointer,
resume_from_checkpoint=self._resume_from_checkpoint,
)
checkpoint_dict = self._checkpointer.load_checkpoint()
@@ -147,7 +147,7 @@ def setup(self, cfg: DictConfig) -> None:
"""
self._metric_logger = config.instantiate(cfg.metric_logger)

checkpoint_dict = self.load_checkpoint(cfg=cfg.checkpointer)
checkpoint_dict = self.load_checkpoint(cfg_checkpointer=cfg.checkpointer)

self._model = self._setup_model(
cfg_model=cfg.model,
22 changes: 14 additions & 8 deletions recipes/quantize.py
@@ -3,8 +3,10 @@
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import os
import sys
import time
from pathlib import Path
from typing import Any, Dict

import torch
@@ -69,15 +71,19 @@ def quantize(self, cfg: DictConfig):
def save_checkpoint(self, cfg: DictConfig):
ckpt_dict = self._model.state_dict()
file_name = cfg.checkpointer.checkpoint_files[0].split(".")[0]
quantized_file_name = (
cfg.checkpointer.output_dir
+ file_name
+ "."
+ self._quantization_mode
+ ".pt"

output_dir = Path(cfg.checkpointer.output_dir)
output_dir.mkdir(exist_ok=True)
checkpoint_file = Path.joinpath(
output_dir, f"{file_name}-{self._quantization_mode}"
).with_suffix(".pt")

torch.save(ckpt_dict, checkpoint_file)
logger.info(
"Model checkpoint of size "
f"{os.path.getsize(checkpoint_file) / 1000**3:.2f} GB "
f"saved to {checkpoint_file}"
)
torch.save(ckpt_dict, quantized_file_name)
logger.info(f"Saved quantized model to {quantized_file_name}")


@config.parse
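
For reference, a standalone sketch of the new checkpoint naming in `save_checkpoint` (directory, file stem, and quantization mode below are illustrative, not from the commit): the quantized weights now land at `<output_dir>/<file_name>-<quantization_mode>.pt`, which is why the README config above lists `meta_model_0-4w.pt`, and the logged size uses `os.path.getsize`:

import os
from pathlib import Path

output_dir = Path("/tmp/quantized")  # assumed example of cfg.checkpointer.output_dir
file_name = "meta_model_0"           # stem of cfg.checkpointer.checkpoint_files[0]
quantization_mode = "4w"             # assumed example quantization mode

output_dir.mkdir(exist_ok=True)
checkpoint_file = Path.joinpath(
    output_dir, f"{file_name}-{quantization_mode}"
).with_suffix(".pt")

# Stand-in for torch.save(ckpt_dict, checkpoint_file) so the size log is runnable here.
checkpoint_file.write_bytes(b"\0" * 1024)

size_gb = os.path.getsize(checkpoint_file) / 1000**3
print(f"Model checkpoint of size {size_gb:.2f} GB saved to {checkpoint_file}")
# -> Model checkpoint of size 0.00 GB saved to /tmp/quantized/meta_model_0-4w.pt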