Skip to content

Commit

Permalink
address comments and rebase over detailed design
Browse files Browse the repository at this point in the history
Signed-off-by: Yu Chin Fabian Lim <[email protected]>
  • Loading branch information
fabianlim committed Apr 12, 2024
1 parent 0cc6bcb commit bf65f31
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 70 deletions.
128 changes: 58 additions & 70 deletions architecture_records/002-acceleration-framework.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Training and FineTuning Acceleration Framework
# Training Enhancements Framework

**Deciders(s)**: Sukriti Sharma ([email protected]), Raghu Ganti ([email protected]), Laura Wynter ([email protected]), Fabian Lim ([email protected]), Aaron Chew ([email protected])
**Date (YYYY-MM-DD)**: 2024-04-11
Expand Down Expand Up @@ -31,14 +31,18 @@ Currently `sft_trainer.py` only can access those tools already integrated in HF.
2. Prefix tuning from [PEFT](https://github.com/huggingface/peft).
3. FSDP training from [accelerate](https://github.com/huggingface/accelerate).

Below are various reasons for a framework to integrate custom training tools into [`sft_trainer.py`].
Below are various reasons for a framework to integrate custom training tools into `sft_trainer.py`.
* Enable quick integrations of open-source techniques that have yet to be integrated into Huggingface.
* Enable integrations of custom techniques developed by IBM researchers, that are not planned be integrated into Huggingface.

Recently, it has been observed that new training techniques are released with an incomplete "preview" version. These "preview" versions tend to be not be fully integrated into OSS. Therefore, using new techniques typically involve additional work. This framework aims to allow timely integrations of such techniques into `sft_trainer.py`. A short exampler list of powerful training techniques but are "preview"-only include:
- [AutoGPTQ](https://github.com/AutoGPTQ/AutoGPTQ).
* 4-bit quantization kernels to reduce memory storage of the base weights.
- [Unsloth](https://github.com/unslothai/unsloth).
* Fused operation kernels
* Kernels for common model architectures (e.g., cross-entropy losses, RoPE embeddings and RMS norms).
- [megablocks](https://github.com/databricks/megablocks).
- [AutoGPTQ](https://github.com/AutoGPTQ/AutoGPTQ).
* acceleration package for distributing mixture-of-experts that improves upon FSDP sharding.

<!--
Why this is a valuable problem to solve? What background information is needed to show how this design addresses the problem?
Expand All @@ -48,7 +52,6 @@ Which users are affected by the problem? Why is it a problem? What data supports

### User Benefit


Users will benefit from powerful training tools integrated into the platform, that are not readily accessible from huggingface. With these tools, users will be able to train models with less GPU resources and/or quicker, resulting in quicker turnaround and improved user experience.

<!--
Expand Down Expand Up @@ -80,7 +83,7 @@ The framework is designed to only modify them model at two integration points in
3. an *optional* `callback` method to install `TrainerCallbacks` (if needed, e.g. custom save logic).

```python
class FrameworkPlugin:
class TuningAccelerationPlugin:

# if specified, will restricted plugin to specified model archs
# - useful if method is restricted to certain model architectures, e.g., only used
Expand Down Expand Up @@ -111,26 +114,24 @@ Even though they are all optional, at least one out of the three should be imple
### Dependency Management

Take note:
- all plugin deps must be enforced to be optional deps in `pyproject.toml`, see [116](#116). If the dep is not installed, and the plugin is enabled, raise exception.
- all plugin deps must be enforced to be optional deps in `pyproject.toml`, see [116](https://github.com/foundation-model-stack/fms-hf-tuning/pull/116). If the dep is not installed, and the plugin is enabled, raise exception.
- any plugin that requires CUDA build tools (e.g. `triton` kernels) will need to be run in with [CUDA Toolkit dependencies (see this link for an example of a Debian installation)](https://developer.nvidia.com/cuda-12-2-0-download-archive?target_os=Linux&target_arch=x86_64&Distribution=Debian&target_version=11&target_type=deb_local).
* in such cases, both the library (e.g. `triton`), and CUDA tools, need to be checked.


* whenever CUDA is needed, the framework will check for the CUDA_TOOLS dependency.

### Minimal and Controlled Changes to Training Script

All proposed code changes to [`sft_trainer.py`] contained in minimal lines of code:
All proposed code changes to `sft_trainer.py` contained in minimal lines of code:
- Plugins loaded by discovery; transparent to `sft_trainer.py`.
- Plugin configuration automatically parsed.
- Passthrough to original operation if `Framework` is disabled.

```python
from tuning.proposed_framework import Framework
from tuning.acceleration import AccelerationFramework

# Minor Change 1: creating the framework object
framework = None
if framework_args.config_file is not None:
framework = Framework(framework_args.config_file)
framework = AccelerationFramework(framework_args.config_file)

# Minor Change 2: custom loader (if necessary)
_model_loader = AutoModelForCausalLM.from_pretrained # default
Expand Down Expand Up @@ -162,18 +163,26 @@ trainer.add_callbacks(framework.callbacks())

# call train
trainer.train()

```

The picture below summarizes the above discussion.
The picture below summarizes the above discussion in more detail. It demonstrates how the design will not contradict internal workings of [`SFTTrainer`].
- Model is modified and then control passed to [`SFTTrainer`].
- [`SFTTrainer`] also performs model augmentation internally (e.g., it installs PEFT adapters if `peft_config` is passed in).
* However, [`SFTTrainer`]'s model augmentation should be passed through if configs are omitted (e.g., if `peft_config = None`).
- [`SFTTrainer`] will prepare model for distributed training (e.g. wrap with `FSDP`) internally.
* thus Plugin implementers need to be aware that `FrameworkPlugin.augmentation` should not interfere with any model preperation that [`SFTTrainer`] will perform.
* thus Plugin implementers need to be aware that `TuningAccelerationPlugin.augmentation` should not interfere with any model preperation that [`SFTTrainer`] will perform.

![Framework](imgs/002-framework.png)

### Acceleration Methods

A top priority is to incorporate methods that enchance PEFT. While PEFT is known to be memory efficient, it is known to be slower than full-finetuning if not *properly optimized*. Also, another topic of interest is to add support for 4D masks to enable packing while instruction tuning; this acceleration may require some adjustments to the data processing.
1. Add 4-bit `triton` kernels for PEFT base weights.
2. Add fused kernels for PEFT base models, as well as reusable kernels for other models (e.g. cross-entropy loss, RoPE).
3. Add support for 4D masking (may require `TuningAccelerationPlugin.augmentation` to also access the datasets).
4. Add support for distributed training (i.e., `megablocks`).


<!--
This is the meat of the document, where you explain the decision. If you have multiple alternatives, be sure to use sub-sections for better separation of the idea, and list pros/cons to each approach. If there are alternatives that you have eliminated, you should also list those here, and explain why you believe your chosen approach is superior.
Expand All @@ -182,21 +191,26 @@ Make sure you’ve thought through and addressed the following sections. If a se

### Alternatives Considered

[IN PROGRESS]
We have considered the following.

Consideration | Reason for Rejection
--|--
Only having `augmentation` and not employing any drop-in `loading` | Some methods like quantization has specialized checkpoints which require special loaders. Attempting to modify already instantiated models to load such specialized checkpoints can be error prone. Futhermore, for future extensions not having any drop-in `loading` can be a severe handicap.
Adding tuning enchancements directly to `SFT_Trainer` | The trainer is a very complex object, and manipulating it in unintended ways can have serious repurcussions. As such, we choose to allow `TuningAccelerationPlugin.augmentation` to modify only the `Accelerator` object which can already do quite a bit of things, like adjust the FSDP wrapping policy (for distributed training).

1. Alternative script to [`sft_trainer.py`].
2. Do not touch

<!--
- Make sure to discuss the relative merits of alternatives to your proposal.
-->

## Consequences

[IN PROGRESS]
We have considered the following

Drawbacks:
- cannot support any plugin design that requires a controlled call in places not supported by `TrainerCallbacks`.
Consideration | Background | Steps to Alleviate
--|--|--
Hosting custom packages | Sometimes the enhancements depend on OSS packages that have been improved on. | If the package
Managing depdencies | | Have `TuningAccelerationPlugin`

<!--
Describe the resulting context, after applying the decision. All consequences should be listed here, not just the "positive" ones. A particular decision may have positive, negative, and neutral consequences, but all of them affect the team and project in the future.
Expand All @@ -205,44 +219,35 @@ Describe the resulting context, after applying the decision. All consequences sh

## Detailed Design


<!--
This section is optional. Elaborate on details if they’re important to understanding the design, but would make it hard to read the proposal section above.
-->

`acceleration.yaml`
```
In this section we demonstrate how to implement an `AutoGPTQPlugin` that implements an accelerate PEFT training mode with 4 bit GPTQ base weights.

This is an `acceleration.yaml`
```yaml
quantization:
- requires_quantization: True
- quant_num_bits: 4
- quantize_cache: '/home/user/'
- quant_kernel: 'gptq-tritonv2'
- unsloth: False
```
```
```python
from functools import partial

class Framework:
class AutoGPTQPlugin(TuningAccelerationPlugin):
def __init__(self, acceleration_config_path:str) -> None:
self.acceleration_config = self.read_acceleration_config(acceleration_config_path)
self.num_bits = self.acceleration_config.quant_num_bits
self.requires_custom_loading = self.acceleration_config.requires_custom_loading
self.requires_quantization = self.acceleration_config.requires_quantization
self.quantize_cache = self.acceleration_config['quantize_cache']
self.kernel = self.acceleration_config['quant_kernel']
def read_acceleration_config(self, acceleration_config_path):
pass
# ... initialize config

def callbacks(self, *args, **kwargs):
pass
def model_loader(self):
def model_loader(self, model_path, **kwargs):

# ... maybe quantize if needed
quantize_config = QuantizeConfig(
bits=self.num_bits,
)
if self.requires_quantization:
return partial(AutoGPTQForCausalLM.from_pretrained, quantize_config = quantize_config)
else:
return partial(AutoGPTQForCausalLM.from_quantized, quantize_config = quantize_config)
return AutoGPTQForCausalLM.from_quantized(model_path, quantize_config = quantize_config)

def augmentation(
self,
Expand All @@ -251,29 +256,12 @@ class Framework:
train_args,
peft_config,
):
'''
This function is used for any augmentation of the model before trainings
e.g. quantization, unsloth/PEFT installation and also MegaBlocks patching
'''
if self.requires_quantization:
model.quantize()
model.save_quantized(save_dir = self.quantize_cache)
if peft_config:
# PEFT Installation
if 'gptq' in self.kernel:
from auto_gptq.utils.peft_utils import get_gptq_peft_model
model = get_gptq_peft_model(
model,
peft_config = peft_config,
)
else:
from peft import get_peft_model
model = get_peft_model(
model,
peft_config
)
return model
```
assert peft_config is not None, "need peft_config to install PEFT adapters"

# PEFT Installation
from auto_gptq.utils.peft_utils import get_gptq_peft_model
return get_gptq_peft_model(
model,
peft_config = peft_config,
)
```
Binary file modified architecture_records/imgs/002-framework.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.

0 comments on commit bf65f31

Please sign in to comment.