diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md new file mode 100644 index 000000000..7eeb55620 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/bug_report.md @@ -0,0 +1,34 @@ +--- +name: Bug report +about: Create a report to help us improve +title: "" +labels: "" +assignees: "" +--- + +## Describe the bug + +A clear and concise description of what the bug is. + +## Platform + +Please provide details about the environment you are using, including the following: + +- Interpreter version: +- Library version: + +## Sample Code + +Please include a minimal sample of the code that will (if possible) reproduce the bug in isolation + +## Expected behavior + +A clear and concise description of what you expected to happen. + +## Observed behavior + +What you see happening (error messages, stack traces, etc...) + +## Additional context + +Add any other context about the problem here. \ No newline at end of file diff --git a/.github/ISSUE_TEMPLATE/feature_request.md b/.github/ISSUE_TEMPLATE/feature_request.md new file mode 100644 index 000000000..96b857215 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/feature_request.md @@ -0,0 +1,23 @@ +--- +name: Feature request +about: Suggest an idea for this project +title: "" +labels: "" +assignees: "" +--- + +## Is your feature request related to a problem? Please describe. + +A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] + +## Describe the solution you'd like + +A clear and concise description of what you want to happen. + +## Describe alternatives you've considered + +A clear and concise description of any alternative solutions or features you've considered. + +## Additional context + +Add any other context about the feature request here. \ No newline at end of file diff --git a/.github/ISSUE_TEMPLATE/user_story.md b/.github/ISSUE_TEMPLATE/user_story.md new file mode 100644 index 000000000..4b62a291d --- /dev/null +++ b/.github/ISSUE_TEMPLATE/user_story.md @@ -0,0 +1,23 @@ +--- +name: User story +about: A user-oriented story describing a piece of work to do +title: "" +labels: "" +assignees: "" +--- + +## Description + +As a , I want to , so that I can + +## Discussion + +Provide detailed discussion here + +## Acceptance Criteria + + + +- [ ] Unit tests cover new/changed code +- [ ] Examples build against new/changed code +- [ ] READMEs are updated \ No newline at end of file diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md new file mode 100644 index 000000000..467b55498 --- /dev/null +++ b/.github/pull_request_template.md @@ -0,0 +1,19 @@ + + +### Description of the change + + + +### Related issue number + + + +### How to verify the PR + + + +### Was the PR tested + + +- [ ] I have added >=1 unit test(s) for every new method I have added. +- [ ] I have ensured all unit tests pass \ No newline at end of file diff --git a/.github/workflows/build-and-publish.yaml b/.github/workflows/build-and-publish.yaml new file mode 100644 index 000000000..7bb377df8 --- /dev/null +++ b/.github/workflows/build-and-publish.yaml @@ -0,0 +1,52 @@ +# Copyright The FMS HF Tuning Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +name: Build and Publish FMS-hf-tuning Library + +on: + release: + types: [published] + +jobs: + build: + runs-on: ubuntu-latest + strategy: + matrix: + python-version: + - setup: "3.11" + tox: "py311" + + environment: + name: pypi + url: https://pypi.org/p/fms-hf-tuning + permissions: + id-token: write # IMPORTANT: this permission is mandatory for trusted publishing + + steps: + - uses: actions/checkout@v4 + - name: Set up Python ${{ matrix.python-version.setup }} + uses: actions/setup-python@v4 + with: + python-version: ${{ matrix.python-version.setup }} + - name: Install dependencies + run: | + python -m pip install --upgrade pip + python -m pip install tox + - name: Build and test with tox + run: tox -e ${{ matrix.python-version.tox }} + - name: Build and check wheel package + run: + tox -e build,twinecheck + - name: Publish package distributions to PyPI + uses: pypa/gh-action-pypi-publish@release/v1 diff --git a/.github/workflows/format.yml b/.github/workflows/format.yml index d926b1220..574aac18a 100644 --- a/.github/workflows/format.yml +++ b/.github/workflows/format.yml @@ -21,10 +21,10 @@ on: branches: [ "main" ] jobs: - build: + lint: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Set up Python 3.9 uses: actions/setup-python@v4 with: @@ -32,7 +32,8 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - python -m pip install -r setup_requirements.txt - - name: Check Formatting + python -m pip install tox + - name: Check formatting run: tox -e fmt - + - name: Run pylint + run: tox -e lint diff --git a/.github/workflows/image.yaml b/.github/workflows/image.yaml new file mode 100644 index 000000000..73bbe5982 --- /dev/null +++ b/.github/workflows/image.yaml @@ -0,0 +1,15 @@ +name: Image +on: + push: + branches: [ "main" ] + pull_request: + branches: [ "main" ] + +jobs: + build: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - name: Build image + run: | + docker build -t fms-hf-tuning:dev . -f build/Dockerfile \ No newline at end of file diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml new file mode 100644 index 000000000..f8e24265c --- /dev/null +++ b/.github/workflows/test.yaml @@ -0,0 +1,27 @@ +name: Test +on: + push: + branches: [ "main" ] + pull_request: + branches: [ "main" ] + +jobs: + build: + runs-on: ubuntu-latest + strategy: + matrix: + python-version: + - setup: "3.9" + tox: "py39" + - setup: "3.10" + tox: "py310" + - setup: "3.11" + tox: "py311" + steps: + - uses: actions/checkout@v4 + - name: Install dependencies + run: | + python -m pip install --upgrade pip + python -m pip install tox + - name: Run unit tests + run: tox -e py \ No newline at end of file diff --git a/.gitignore b/.gitignore index 61f39d73a..af2f64704 100644 --- a/.gitignore +++ b/.gitignore @@ -7,13 +7,15 @@ durations/* coverage*.xml dist htmlcov -build test # IDEs .vscode/ .idea/ +# AIM files +.aim + # Env files .env @@ -26,3 +28,17 @@ venv/ # Tox envs .tox + +# Aim +.aim + +# Backup files and folders +*.bkp +*.bkp.* +*bkp* + +# Build output +/build/lib/ + +# Auto-generated file +/tuning/_version.py diff --git a/.pylintrc b/.pylintrc index 5c7b4676e..e94869511 100644 --- a/.pylintrc +++ b/.pylintrc @@ -443,7 +443,9 @@ disable=raw-checker-failed, attribute-defined-outside-init, abstract-method, pointless-statement, - wrong-import-order + wrong-import-order, + duplicate-code, + unbalanced-tuple-unpacking # Enable the message, report, category or checker with the given id(s). You can # either give multiple identifier separated by comma (,) or put this option diff --git a/CODEOWNERS b/CODEOWNERS new file mode 100644 index 000000000..a28fcff97 --- /dev/null +++ b/CODEOWNERS @@ -0,0 +1,11 @@ +##################################################### +# +# List of approvers for fms-hf-tuning repository +# +##################################################### +# +# Learn about CODEOWNERS file format: +# https://help.github.com/en/articles/about-code-owners +# + +* @anhuong @Ssukriti @alex-jw-brooks diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 000000000..6d20b395c --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,160 @@ +# Contributing + +👍🎉 First off, thank you for taking the time to contribute! 🎉👍 + +The following is a set of guidelines for contributing. These are just guidelines, not rules. Use your best judgment, and feel free to propose changes to this document in a pull request. + +## What Should I Know Before I Get Started? + +### Code of Conduct + +This project adheres to the [Contributor Covenant](./code-of-conduct.md). By participating, you are expected to uphold this code. + +Please report unacceptable behavior to one of the [Code Owners](./CODEOWNERS). + +### How Do I Start Contributing? + +The below workflow is designed to help you begin your first contribution journey. It will guide you through creating and picking up issues, working through them, having your work reviewed, and then merging. + +Help on open source projects is always welcome and there is always something that can be improved. For example, documentation (like the text you are reading now) can always use improvement, code can always be clarified, variables or functions can always be renamed or commented on, and there is always a need for more test coverage. If you see something that you think should be fixed, take ownership! Here is how you get started: + +## How Can I Contribute? + +NOTE: Before making any contribution, please ensure the content does not include any IBM proprietary information or any specific information about IBM products. + +For any contributions that need design changes/API changes, reach out to maintainers to check if an Architectural Design Record would be beneficial. Reason for ADR: teams agree on the design, to avoid back and forth after writing code. An ADR gives context on the code being written. If requested for an ADR, make a contribution [using the template](./architecture_records/template.md). + +When contributing, it's useful to start by looking at [issues](https://github.com/foundation-model-stack/fms-hf-tuning/issues). After picking up an issue, writing code, or updating a document, make a pull request and your work will be reviewed and merged. If you're adding a new feature or find a bug, it's best to [write an issue](https://github.com/foundation-model-stack/fms-hf-tuning/issues/new) first to discuss it with maintainers. + +To contribute to this repo, you'll use the Fork and Pull model common in many open source repositories. For details on this process, check out [The GitHub Workflow +Guide](https://github.com/kubernetes/community/blob/master/contributors/guide/github-workflow.md) +from Kubernetes. + +When your contribution is ready, you can create a pull request. Pull requests are often referred to as "PR". In general, we follow the standard [GitHub pull request](https://help.github.com/en/articles/about-pull-requests) process. Follow the template to provide details about your pull request to the maintainers. It's best to break your contribution into smaller PRs with incremental changes, and include a good description of the changes. +We require new unit tests to be contributed with any new functionality added. + +Before sending pull requests, make sure your changes pass formatting, linting and unit tests. These checks will run with the pull request builds. Alternatively, you can run the checks manually on your local machine [as specified below](#development). + +#### Dependencies +If additional new Python module dependencies are required, think about where to put them: + +- If they're required for fms-hf-tuning, then append them to the [dependencies](https://github.com/foundation-model-stack/fms-hf-tuning/blob/main/pyproject.toml#L28) in the pyproject.toml. +- If they're optional dependencies for additional functionality, then put them in the pyproject.toml file like were done for [flash-attn](https://github.com/foundation-model-stack/fms-hf-tuning/blob/main/pyproject.toml#L44) or [aim](https://github.com/foundation-model-stack/fms-hf-tuning/blob/main/pyproject.toml#L45). +- If it's an additional dependency for development, then add it to the [dev](https://github.com/foundation-model-stack/fms-hf-tuning/blob/main/pyproject.toml#L43) dependencies. + +#### Code Review + +Once you've [created a pull request](#how-can-i-contribute), maintainers will review your code and may make suggestions to fix before merging. It will be easier for your pull request to receive reviews if you consider the criteria the reviewers follow while working. Remember to: + +- Run tests locally and ensure they pass +- Follow the project coding conventions +- Write detailed commit messages +- Break large changes into a logical series of smaller patches, which are easy to understand individually and combine to solve a broader issue + +Maintainers will perform "squash and merge" actions on PRs in this repo, so it doesn't matter how many commits your PR has, as they will end up being a single commit after merging. + +### Reporting Bugs + +This section guides you through submitting a bug report. Following these guidelines helps maintainers and the community understand your report ✏️, reproduce the behavior 💻, and find related reports 🔎. + +#### How Do I Submit A (Good) Bug Report? + +Bugs are tracked as [GitHub issues using the Bug Report template](https://github.com/foundation-model-stack/fms-hf-tuning/issues/new?template=bug_report.md). Create an issue on that and provide the information suggested in the bug report issue template. + +### Suggesting Enhancements + +This section guides you through submitting an enhancement suggestion, including completely new features, tools, and minor improvements to existing functionality. Following these guidelines helps maintainers and the community understand your suggestion ✏️ and find related suggestions 🔎 + +#### How Do I Submit A (Good) Enhancement Suggestion? + +Enhancement suggestions are tracked as [GitHub issues using the Feature Request template](https://github.com/foundation-model-stack/fms-hf-tuning/issues/new?template=feature_request.md). Create an issue and provide the information suggested in the feature requests or user story issue template. + +#### How Do I Submit A (Good) Improvement Item? + +Improvements to existing functionality are tracked as [GitHub issues using the User Story template](https://github.com/foundation-model-stack/fms-hf-tuning/issues/new?template=user_story.md). Create an issue and provide the information suggested in the feature requests or user story issue template. + +## Development + +### Set up your dev environment + +The following tools are required: + +- [git](https://git-scm.com) +- [python](https://www.python.org) (v3.8+) +- [pip](https://pypi.org/project/pip/) (v23.0+) + +Installation: +``` +pip install -U datasets +pip install -e . +``` +
+Linting + +To lint your code: +``` + make lint +``` + +We use Pylint to checks your Python code for errors, coding standards, code convention and refactoring suggestions. + +Pylint emits [messages](https://pylint.pycqa.org/en/latest/user_guide/messages/index.html) that provides explanations of the failed checks. + +You should fix all message in the following order: +1. Fix each message provided. Select a message [description](https://pylint.pycqa.org/en/latest/user_guide/messages/messages_overview.html#messages-overview) to fix a message. +2. Disable a message (i.e: unbalanced-tuple-unpacking) caused by a particular line of code: + ```python + a, b = ... # pylint: disable=unbalanced-tuple-unpacking + ``` + Please see [here](https://pylint.pycqa.org/en/latest/user_guide/messages/message_control.html#block-disables) for the progma syntax. + +3. Disable a checker globally. Please extend the `disable=` list in the [pylintrc](.pylintrc) file. + > Note: Disable checkers only if there is good reason. +
+ +
+Formatting + +To format your code: +``` + make fmt +``` +We use [black](https://github.com/psf/black) formatter to format the code. + +You could optionally install the git pre-commit hooks if you would like to format the code automatically for each commit: +``` +brew install pre-commit +pre-commit install +``` +
+ +
+Unit tests + +To run unit tests: +``` + make test +``` +Running unit tests ensures your contributions do not break exiting code. +We use [pytest](https://docs.pytest.org/) framework to run unit tests. The framework is setup to run all run all test_*.py or *_test.py in the [tests](./tests) directory. + +> Optionally, run `make all` command to do formatting, linting, and testing at once. +
+ +
+Build wheel + +To build a wheel file: +```shell +tox -e build +``` +Running the command will create a single ZIP-format archive containing the library source code with the .whl extension in the `dist/` directory. + +
+ +## Your First Code Contribution + +Unsure where to begin contributing? You can start by looking through these issues: + +- Issues with the [`good first issue` label](https://github.com/foundation-model-stack/fms-hf-tuning/issues?q=is%3Aissue+is%3Aopen+label%3A%22good+first+issue%22) - these should only require a few lines of code and are good targets if you're just starting contributing. +- Issues with the [`help wanted` label](https://github.com/foundation-model-stack/fms-hf-tuning/issues?q=is%3Aissue+is%3Aopen+label%3A%22help+wanted%22) - these range from simple to more complex, but are generally things we want but can't get to in a short time frame. diff --git a/Makefile b/Makefile new file mode 100644 index 000000000..b34c922c6 --- /dev/null +++ b/Makefile @@ -0,0 +1,18 @@ +# Run all +.PHONY: all +all: fmt lint test + +# Run unit tests +.PHONY: test +test: + tox -e py + +# Format python code +.PHONY: fmt +fmt: + tox -e fmt + +# Run pylint to check code +.PHONY: lint +lint: + tox -e lint diff --git a/README.md b/README.md index b15b69815..887d61502 100644 --- a/README.md +++ b/README.md @@ -8,12 +8,20 @@ This repo provides basic tuning scripts with support for specific models. The re ## Installation ``` -pip install -r requirements.txt -pip install -U datasets pip install -e . ``` -> Note: If you wish to use [FlashAttention](https://github.com/Dao-AILab/flash-attention), then you need to install these requirements: `pip install -r flashattn_requirements`. [FlashAttention](https://github.com/Dao-AILab/flash-attention) requires the [CUDA Toolit](https://developer.nvidia.com/cuda-toolkit) to be pre-installed. +> Note: After installing, if you wish to use [FlashAttention](https://github.com/Dao-AILab/flash-attention), then you need to install these requirements: +``` +pip install -e ".[dev]" +pip install -e ".[flash-attn]" +``` +[FlashAttention](https://github.com/Dao-AILab/flash-attention) requires the [CUDA Toolit](https://developer.nvidia.com/cuda-toolkit) to be pre-installed. + +If you wish to use [aim](https://github.com/aimhubio/aim), then you need to install it: +``` +pip install -e ".[aim]" +``` ## Data format The data format expectation is a single column text. The trainer is configured to expect a response template as a string. For example, if one wants to prepare the `alpaca` format data to feed into this trainer, it is quite easy and can be done with the following code. @@ -60,9 +68,13 @@ Current supported and tested models are `Llama2` (7 and 13B configurations have # if you want to use one GPU on multi-gpu machine export CUDA_VISIBLE_DEVICES=0 +# MODEL_PATH=meta-llama/Llama-2-7b-hf # Huggingface model id or path to a checkpoint +# TRAIN_DATA_PATH=twitter_complaints.json # Path to the dataset +# OUTPUT_PATH=out # Path to the output folder where the checkpoints are saved + python tuning/sft_trainer.py \ --model_name_or_path $MODEL_PATH \ ---data_path $DATA_PATH \ +--training_data_path $TRAIN_DATA_PATH \ --output_dir $OUTPUT_PATH \ --num_train_epochs 5 \ --per_device_train_batch_size 4 \ @@ -83,15 +95,31 @@ python tuning/sft_trainer.py \ ``` ### Multiple GPUs with FSDP + +The recommendation is to use [huggingface accelerate](https://huggingface.co/docs/accelerate/en/index) to launch multi-gpu jobs, in particular when using FSDP: +- `accelerate` is written on top of [`torch.distributed.run`](https://github.com/pytorch/pytorch/blob/main/torch/distributed/run.py). +- `accelerate launch` CLI highly similar to `torchrun`, spawns multiple jobs (one for each gpu). +- tightly integrated with [huggingface Trainer](https://github.com/huggingface/transformers/blob/main/src/transformers/trainer.py). + +`accelerate launch` CLI to be run with specific command line arguments, see example below. Default arguments handled by passing in a +`--config_file` argument; see [reference docs](https://huggingface.co/docs/accelerate/en/package_reference/cli#accelerate-launch) and [fixtures/accelerate_fsdp_defaults.yaml](./fixtures/accelerate_fsdp_defaults.yaml) for sample defaults. + ```bash -torchrun \ ---nnodes=1 \ ---nproc_per_node=8 \ ---master_port=1234 \ +# Please set the environment variables: +# MASTER_PORT=1234 # The port at which the process with rank 0 listens to and should be set to an unused port +# MODEL_PATH=meta-llama/Llama-2-7b-hf # Huggingface model id or path to a checkpoint +# TRAIN_DATA_PATH=twitter_complaints.json # Path to the training dataset +# OUTPUT_PATH=out # Path to the output folder where the checkpoints are saved + +accelerate launch \ +--main_process_port $MASTER_PORT \ +--config_file fixtures/accelerate_fsdp_defaults.yaml \ +--num_processes=8 \ +--main_process_port=$MASTER_PORT \ tuning/sft_trainer.py \ --model_name_or_path $MODEL_PATH \ ---data_path $DATA_PATH \ ---bf16 True \ +--training_data_path $TRAIN_DATA_PATH \ +--torch_dtype bfloat16 \ --output_dir $OUTPUT_PATH \ --num_train_epochs 5 \ --per_device_train_batch_size 4 \ @@ -104,16 +132,199 @@ tuning/sft_trainer.py \ --warmup_ratio 0.03 \ --lr_scheduler_type "cosine" \ --logging_steps 1 \ ---fsdp "full_shard auto_wrap" \ ---fsdp_config tuning/config/fsdp_config.json \ --include_tokens_per_second \ --packing False \ --response_template "\n### Response:" \ --dataset_text_field "output" ``` +To summarize you can pick either python for singleGPU jobs or use accelerate launch for multiGPU jobs. The following tuning techniques can be applied: + +## Tuning Techniques : + +### LoRA Tuning Example + +Set peft_method = "lora". You can additionally pass any arguments from [LoraConfig](https://github.com/foundation-model-stack/fms-hf-tuning/blob/main/tuning/config/peft_config.py#L21). +```bash +# Args you can pass +r: int =8 +lora_alpha: int = 32 +target_modules: List[str] = field( + default_factory=lambda: ["q_proj", "v_proj"], + metadata={ + "help": "The names of the modules to apply LORA to. LORA selects modules which either \ + completely match or " + 'end with one of the strings. If the value is ["all-linear"], \ + then LORA selects all linear and Conv1D ' + "modules except for the output layer." + }, + ) + bias = "none" + lora_dropout: float = 0.05 + +``` +Example command to run: + +```bash +python tuning/sft_trainer.py \ +--model_name_or_path $MODEL_PATH \ +--training_data_path $TRAIN_DATA_PATH \ +--output_dir $OUTPUT_PATH \ +--num_train_epochs 40 \ +--per_device_train_batch_size 4 \ +--per_device_eval_batch_size 4 \ +--gradient_accumulation_steps 4 \ +--save_strategy "epoch" \ +--learning_rate 1e-4 \ +--weight_decay 0. \ +--warmup_ratio 0.03 \ +--lr_scheduler_type "cosine" \ +--logging_steps 1 \ +--include_tokens_per_second \ +--packing False \ +--response_template "\n### Label:" \ +--dataset_text_field "output" \ +--use_flash_attn False \ +--tokenizer_name_or_path $MODEL_PATH \ +--torch_dtype float32 \ +--peft_method "lora" \ +--logging_strategy "epoch" \ +--r 8 \ +--lora_dropout 0.05 \ +--lora_alpha 16 +``` + +Notice the `target_modules` that are set are the default values. `target_modules` are the names of the modules to apply the adapter to. If this is specified, only the modules with the specified names will be replaced. When passing a list of strings, either an exact match will be performed or it is checked if the name of the module ends with any of the passed strings. If this is specified as `all-linear`, then all linear/Conv1D modules are chosen, excluding the output layer. If this is not specified, modules will be chosen according to the model architecture. If the architecture is not known, an error will be raised — in this case, you should specify the target modules manually. See [HuggingFace docs](https://huggingface.co/docs/peft/en/package_reference/lora#peft.LoraConfig) for more details. + +For each model, the `target_modules` will depend on the type of model architecture. You can specify linear or attention layers to `target_modules`. To obtain list of `target_modules` for a model: + +```py +from transformers import AutoModelForCausalLM +# load the model +model = AutoModelForCausalLM.from_pretrained(MODEL_PATH) +# see the module list +model.modules + +# to get just linear layers +import re +model_modules = str(model.modules) +pattern = r'\((\w+)\): Linear' +linear_layer_names = re.findall(pattern, model_modules) + +names = [] +for name in linear_layer_names: + names.append(name) +target_modules = list(set(names)) +``` + +For example for LLaMA model the modules look like: +``` + +``` + +You can specify attention or linear layers. With the CLI, you can specify layers with `--target_modules "q_proj" "v_proj" "k_proj" "o_proj"` or `--target_modules "all-linear"`. -For `GPTBigCode` models, Hugging Face has enabled Flash v2 and one can simply replace the `'LlamaDecoderLayer'` with `'GPTBigCodeBlock'` in `tuning/config/fsdp_config.json` for proper sharding of the model. +### Prompt Tuning : + +Specify peft_method to 'pt' . You can additionally pass any arguments from [PromptTuningConfig](https://github.com/foundation-model-stack/fms-hf-tuning/blob/main/tuning/config/peft_config.py#L39). +```bash + # prompt_tuning_init can be either "TEXT" or "RANDOM" + prompt_tuning_init: str = "TEXT" + num_virtual_tokens: int = 8 + # prompt_tuning_init_text only applicable if prompt_tuning_init= "TEXT" + prompt_tuning_init_text: str = "Classify if the tweet is a complaint or not:" + tokenizer_name_or_path: str = "llama-7b-hf" +``` + +Example command you can run: + +```bash + +accelerate launch \ +--main_process_port $MASTER_PORT \ +--config_file fixtures/accelerate_fsdp_defaults.yaml \ +tuning/sft_trainer.py \ +--model_name_or_path $MODEL_PATH \ +--training_data_path $TRAIN_DATA_PATH \ +--output_dir $OUTPUT_PATH \ +--peft_method pt \ +--torch_dtype bfloat16 \ +--tokenizer_name_or_path $MODEL_PATH \ +--num_train_epochs 5 \ +--per_device_train_batch_size 1 \ +--per_device_eval_batch_size 1 \ +--gradient_accumulation_steps 1 \ +--evaluation_strategy "no" \ +--save_strategy "epoch" \ +--learning_rate 1e-5 \ +--weight_decay 0. \ +--warmup_ratio 0.03 \ +--lr_scheduler_type "cosine" \ +--logging_steps 1 \ +--include_tokens_per_second \ +--packing False \ +--response_template "\n### Label:" \ +--dataset_text_field "output" +``` + +### Fine Tuning : + +Set peft_method = 'None' + +Full fine tuning needs more compute resources, so it is advised to use the MultiGPU method +```bash + +accelerate launch \ +--main_process_port $MASTER_PORT \ +--config_file fixtures/accelerate_fsdp_defaults.yaml \ +tuning/sft_trainer.py \ +--model_name_or_path $MODEL_PATH \ +--training_data_path $TRAIN_DATA_PATH \ +--output_dir $OUTPUT_PATH \ +--peft_method "None" \ +--torch_dtype bfloat16 \ +--tokenizer_name_or_path $MODEL_PATH \ +--num_train_epochs 5 \ +--per_device_train_batch_size 1 \ +--per_device_eval_batch_size 1 \ +--gradient_accumulation_steps 1 \ +--evaluation_strategy "no" \ +--save_strategy "epoch" \ +--learning_rate 1e-5 \ +--weight_decay 0. \ +--warmup_ratio 0.03 \ +--lr_scheduler_type "cosine" \ +--logging_steps 1 \ +--include_tokens_per_second \ +--packing False \ +--response_template "\n### Label:" \ +--dataset_text_field "output" +``` ## Inference Currently, we do *not* offer inference support as part of the library, but we provide a standalone script for running inference on tuned models for testing purposes. For a full list of options run `python scripts/run_inference.py --help`. Note that no data formatting / templating is applied at inference time. diff --git a/architecture_records/001-trainer-controller-framework.md b/architecture_records/001-trainer-controller-framework.md new file mode 100644 index 000000000..1bf79d67f --- /dev/null +++ b/architecture_records/001-trainer-controller-framework.md @@ -0,0 +1,171 @@ +# Trainer Controller Framework + +**Deciders(s)**: Alexander Brooks (alex.brooks@ibm.com), Sukriti Sharma (sukriti.sharma4@ibm.com), Raghu Ganti (rganti@us.ibm.com), Padmanabha Venkatagiri Seshadri (seshapad@in.ibm.com), Dushyant Behl (dushyantbehl@in.ibm.com) +**Date (YYYY-MM-DD)**: 2024-03-05 +**Obsoletes ADRs**: NA +**Modified By ADRs**: NA +**Relevant Issues**: [537](https://github.ibm.com/ai-foundation/watson-fm-stack-tracker/issues/537), [323](https://github.ibm.com/ai-foundation/watson-fm-stack-tracker/issues/323) + +- [Summary and Objective](#summary-and-objective) + - [Motivation](#motivation) + - [User Benefit](#user-benefit) +- [Decision](#decision) + - [Alternatives Considered](#alternatives-considered) +- [Consequences](#consequences) +- [Detailed Design](#detailed-design) + +## Summary and Objective + +To create a framework for controlling the trainer loop using user-defined rules and metrics. + +### Motivation + +- The issue [537](https://github.ibm.com/ai-foundation/watson-fm-stack-tracker/issues/537), had raised the need for stopping an ongoing training if some stopping criteria is satisfied (E.g loss validation reaching a certain target, loss increasing with epoch, loss values for last 100 steps increasing etc). +- There is a [EarlyStoppingCallback](https://github.com/huggingface/transformers/blob/v4.37.2/src/transformers/trainer_callback.py#L543) in HF, but the granularity of stopping is only on `evaluate` events, and handles only compares instantaneous metric value to a threshold. +- Therefore, there is a need for a mechanism to capture the user-defined custom stopping criteria which could involve multiple metrics. +- In addition to user-defined stopping criteria, there could other types of control operations with respect to training (for instance, should the trainer perform saving, logging or evaluation operations or not, should we scale resources dynamically so that training could run faster and so on). Therefore, there is a need for general need to capture all these use-cases in a single framework. This PR attempts to provide such a framework. + +### User Benefit + +Users could control the training loop by defining custom rules. This will benefit the user in terms of conserving resources by killing training jobs with run-away loss, help in collecting debugging data (log-on-demand), checkpoint-on-demand, and similar scenarios where intervention is required with respect to training loop. + +## Decision + +### Alternatives Considered + +We considered some of the following alternatives: +- Defining the metrics as functions instead of classes. This was dropped because it was less expressive and did not preserve state of the computation (as in the case of windowing mechanisms), which could be updated more efficiently using the evolving logs. +- Making modification to the trainer loop directly instead of using callbacks. This was dropped because it will require custom huggingface stack with the said modifications. Further, it will make the framework tightly coupled. + +## Consequences + +Following are the advantages and limitations of our design approach: + +### Advantages: +- We have used the trainer callback approach which can used in `plug-and-play` form with the trainer. In addition, we have designed this framework to be a independent packaged +- The rules and metrics are flexible and can be defined by user easily with limited coding effort (only in the case of custom metrics), and no coding effort if the user is using existing metrics. + +### Impact on performance: +Could add to the overhead of the trainer loop as the callback is invoked at various events of the trainer loop and what is computed within the callback could affect the performance of the loop iteration. + + +## Detailed Design + +### High-level architecture +Following is a high-level design diagram. Following are the touch-points to the framework through which user of this framework could interact with it: +- **Registration**: The registration mechanism de-couples the metrics and operators from the trainer framework. A user could implement a custom metric or operator and register it through the registration methods mentioned previously. This makes the framework highly extensible. + +- **Configuration**: The trainer controller configuration supplies the definition for triggers, rule, operations and metrics to orchestrate the enactment of a particular control policy. These details are split up and passed off to the respective modules by the trainer controller as shown in the figure. + +- **Events**: Events supply the state and arguments required for the metric handlers to perform metric computation at the events they are registered for. The framework callback lists out all event handlers with prefix `"on_"` and loads then as event handlers. Every metric declares one or more events from this list of valid handlers. These computed metric variables are stored in a global state of the trainer controller and independently picked up the operations which could potentially be triggered on an entirely different set of events. This decouples the control loop for metrics and operations. I.e. the metric could be computed on event A, while operation could be triggered on event B. The controller rules which use the metric variables from the trainer controller state are evaluated and based on the outcomes specified actions are performed. +![High-Level Design Diagram: Trainer Controller Framework](imgs/001-arch.png) + +### Usage and customization +We have implemented a trainer callback (see [here](https://huggingface.co/docs/transformers/v4.37.2/en/main_classes/callback)) which accepts a `training control definition` file (in YAML format) which facilitates the definition of: +1. Rules to control training loop +2. Trigger points that evaluate the above rules +3. Control operation and action that needs to be performed if rule is evaluated to true. + +The trainer controller configuration is structured as shown below. There are list of metric definitions under `controller-metrics`, a list of operations and their actions under `operations` and a list of controllers, each of which define the rules, triggers and control operations. +```yaml +controller-metrics: + - name: + class: + arguments: + : + : + ... +operations: + - name: + class: + arguments: + : + : + ... +controllers: + - name: + triggers: + - + ... + rule: + operations: + - + ... +``` +The `controller-metrics` and `operations` are optional. We provide a set of built-in `controller-metrics` and `operations` which could be referred to without actually defining them as. For example, the below configuration defines a `controller-metric` called `loss` which refers to a built-in `Loss` controller-metric class with custom arguments (in this case, no arguments. If arguments are required, then they could be listed under a `arguments` section as shown above), but does not define any `operations`. It only refers to a built-in operation. +```yaml +controller-metrics: + - name: loss + class: Loss +controllers: + - name: loss-controller + triggers: + - on_log + rule: "loss < 1.0" + operations: + - hfcontrols.should_training_stop +``` + +We follow the below naming convention for the above trainer controller configuration: +1. `-` could be used in the case of key names, and name of the metric, operation or controller. This is usually to break multiple words of a name phrase. +1. Python convention for [class name](https://visualgit.readthedocs.io/en/latest/pages/naming_convention.html#classes). +1. `_` are used for events and control actions. + +For defining custom handler classes, we have an interface defined as an abstract class as shown below, with two abstract methods, namely: `validate()` to define the validation conditions, and `compute()` to compute the metric. The `compute()` returns an `Any` type. While it could be any value, developers should keep in mind that it should be only key-value pairs that are used in the rule(s) defined in the configuration. + +Further, the `init` method of the class should accept variable arguments in the form of key-value pairs. `Important point to note is that keys used in the arguments of the above config should not conflict with any keys used by Hugging face trainer callback. Please try to use unique keys are arguments name`. + ```python + class MetricHandler(metaclass=abc.ABCMeta): + @abc.abstractmethod + def validate(self) -> bool: + pass + + @abc.abstractmethod + def compute(self, event_name: str, **kwargs) -> Any: + pass + ``` +These classes can be user-defined. To add a new metric class, simply implement the above structure and register it with the trainer controller framework using the `register_metric_handlers()` method. To use the metric handler class, add the class name, arguments to the above configuration file. + +Similarly, there is an operator abstract class `Operation` which could be inherited and custom operations could be defined as illustrated below: +```python +class CustomOperation(Operation): + def should_perform_action_xyz(args): + pass +``` +Every action defined in the custom operation should be represented as a function with `"should_"` prefixed in the function name. The controller will automatically pickup these functions and invoke them if they are referred to in the configuration. Custom operations could be registered using `register_operation_handlers()` method. + +`rule` is a Python expression which could express a condition to evaluate on a metric variable. For example, in the above configuration, `loss` is the variable, and the rule is applying a threshold on it. The details of what rules are supported is given in the section below. + +`operations` lists the operation-actions to be performed when the rule evaluates to True. The convention followed to refer to an operation is `.`. In this example, the `` is referring to built-in operation `hfcontrols` and one of its corresponding action `action-name` i.e `should_training_stop`. + +### Controller Rule Evaluation +1. We use the Python AST library https://docs.python.org/3/library/ast.html to parse the code. +2. We use the SimpleEval library https://github.com/danthedeckie/simpleeval for traversing and evaluating the generated AST. +3. The output of evaluation should be a boolean value which we return indicating if the rule succeeded. + +Example supported expressions: https://github.com/danthedeckie/simpleeval?tab=readme-ov-file#operators + +Support for more complex types (list, dict, etc.) are also implemented: +- https://github.com/danthedeckie/simpleeval?tab=readme-ov-file#compound-types +- https://github.com/danthedeckie/simpleeval/blob/2a12b5856d6f70b78dc1ac38840c80c1be6c6c4e/simpleeval.py#L638-L642 + +Notes: +- PyPi https://pypi.org/project/simpleeval/ (MIT Licence) +- Array Multiplication `["hello"]*10` is allowed but limited to smaller numbers to avoid Denial of Service (DOS). +- List comprehension is allowed but limited to small number of iterations to avoid DOS. +- Exponentiation `9**9` is allowed but limited to smaller numbers to avoid DOS. +- Access to double underscore attributes like `__class__` is disallowed to avoid access to arbitrary classes like `Quitter`. +- Lambda expressions are disallowed for simplicity and to reduce the attack surface. Can be added back in if necessary. +- Access to builtin functions and globals are disallowed expect for `abs`, `float`, `int`, `len`, `rand`, `randint`, `str` and `sqrt`. +- Has support for accessing keys in dicts using the short syntax:`foo.bar` . Example dict: `{"foo": {"bar": 42}})` +- Access is disallowed to private members/attributes like `aaa._foo` and to those with the prefix `func_`. + +**Big Numbers:** +Certain operations like `**` and `<<` and `>>` can easily result in very large numbers, especially when used repeatedly. +Example: Something like `9**9**9**9` leads to the Python runtime getting stuck for several minutes. +To mitigate DOS, the inputs to these operations are limited to a certain threshold: +- When such a rule is executed, we raise a `NumberTooHigh` exception if the (absolute value of the) number exceeds the limit: +https://github.com/danthedeckie/simpleeval/blob/2a12b5856d6f70b78dc1ac38840c80c1be6c6c4e/simpleeval.py#L204C7-L208 +https://github.com/danthedeckie/simpleeval/blob/2a12b5856d6f70b78dc1ac38840c80c1be6c6c4e/simpleeval.py#L242-L243 +- The limit is currently `4000000` but can be easily set using the evaluator class. +https://github.com/danthedeckie/simpleeval/blob/2a12b5856d6f70b78dc1ac38840c80c1be6c6c4e/simpleeval.py#L114 diff --git a/architecture_records/002-acceleration-framework.md b/architecture_records/002-acceleration-framework.md new file mode 100644 index 000000000..1d086dfa1 --- /dev/null +++ b/architecture_records/002-acceleration-framework.md @@ -0,0 +1,484 @@ +# Training Enhancements Framework + +**Deciders(s)**: Sukriti Sharma (sukriti.sharma4@ibm.com), Raghu Ganti (rganti@us.ibm.com), Laura Wynter (lwynter@sg.ibm.com), Fabian Lim (flim@sg.ibm.com), Aaron Chew (aaron.chew1@ibm.com) +**Date (YYYY-MM-DD)**: 2024-04-11 +**Obsoletes ADRs**: NA +**Modified By ADRs**: NA +**Relevant Issues**: [116](https://github.com/foundation-model-stack/fms-hf-tuning/pull/116) + +- [Summary and Objective](#summary-and-objective) + - [Motivation](#motivation) + - [User Benefit](#user-benefit) +- [Decision](#decision) + - [Alternatives Considered](#alternatives-considered) +- [Consequences](#consequences) +- [Detailed Design](#detailed-design) + +## Summary and Objective + +Design and implement a framework to include custom acceleration tools into `sft_trainer.py`, that improve training metrics such as GPU memory consumption, training speed, etc. + + + +### Motivation + +Recently, it has been observed that new training techniques are released with an incomplete "preview" version. These "preview" versions tend to be not be fully integrated into OSS. Therefore, using new techniques typically involve additional work. This framework aims to allow timely integrations of such techniques into `sft_trainer.py`, to enable: +* developers to integrate open-source training improvements into `sft_trainer.py`. +* researchers to implement custom training improvements into `sft_trainer.py`. +* users to easily pick and choose which said training improvements to enable when using `sft_trainer.py`. + +The framework that we propose must be extensible. We propose 3 strong candidates for to be implemented in the near future, but over time new improvements will be available (from open source or from internal development). When a new improvement is added, it should be done in a manner that is minimally intrusive to `sft_trainer.py`. + +#### Training Improvements planned for Integration into Framework + +The following 3 techniques are currently under strong consideration to be added. In what follows, we explain clearly why these techniques are currently available out-of-the-box from huggingface `SFTTrainer`, to motivate why they need to be added as improvements: +- [AutoGPTQ](https://github.com/AutoGPTQ/AutoGPTQ). + * AutoGPTQ is available only via basic integration through [huggingface optimum](https://github.com/huggingface/optimum). + * AutoGPTQ provides state-of-the-art, 4-bit quantized PEFT-LoRA that greatly reduces memory requirements base weights. + * Unfortunately, huggingface integrated [GPTQ kernels that do not work in training](https://github.com/AutoGPTQ/AutoGPTQ/issues/633). + * Therefore, a training improvement is planned to properly integrate the [latest triton V2 kernels](https://github.com/AutoGPTQ/AutoGPTQ/pull/596) that can be used for sped-up PEFT training. +- [Unsloth](https://github.com/unslothai/unsloth). + * Unsloth is a collection of kernels and fused operations that improve PEFT-LoRA. + * Unfortunately, unsloth's codebase contains also a lot of reatures (e.g., fused attention, gradient checkpointing) that we do not want integrated at this moment. + * Thefore, a training improvement is planned to incorporate a clean integration of unsloth, that extracts out only the critical code pieces. +- [megablocks](https://github.com/databricks/megablocks). + * Megablocks is a collection of distributed training methods to speed up mixture-of-experts training. + * Megablocks is procured by [databricks](https://github.com/databricks/megablocks), and there is no indication it will be integrated into `SFTTrainer`. + * Therefore, a training improvement is under strong consideration, to be added to allow a model with a mixture-of-experts layer, to be sped-up using megablocks techniques. + + + +### User Benefit + +Users will benefit from powerful training tools integrated into the platform, that are not readily accessible from huggingface. With these tools, users will be able to train models with less GPU resources and/or quicker, resulting in quicker turnaround and improved user experience. + + + +## Decision + +Terminology | Description +--|-- +Maintainer | Developers of `sft_trainer.py`. +Framework | an extensible `Framework` managing all implemented methods. +Framework Plugin | Self-contained implementation of a framework method. + +The proposal satisfies the following desiredata: +- Unified configuration YAML for all plugins. Fine configuration details abstracted away from Maintainers and other plugin developers. +- Modular design allows new methods plugins to be added / removed / deactivated seemlessly. +- Modular design enforces that plugins interact with `sft_trainer.py` at controlled points, and throw appropriate exceptions. +- Generic enough for most use cases of interest (e.g., quantization, distributed training, etc). +- Unobstrusive design that only *modifies the model*, and leaves `SFTTrainer` unmodified. Minimal inversion-of-control maintained through `TrainerCallbacks`. + +### Only the Model is Modified + +The `Trainer` is designed to work with generic pytorch models; `trl.SFTTrainer` inherits from `Trainer` and has sligthly more constraints (such as throwing errors if `tokenizer=None`), but are still bare minimum. With this, we claim that modifying the model is much less intrusive to the training pipeline, then say, modifying `SFTTrainer` itself. The hope is then if we constrain ourselves to modify only the model, that we can implement all the method plugins (e.g., quantization, distributed training, etc) that we hope for. + +The framework is designed to only modify them model at two integration points in `sft_trainer.py`. The primary motivation for this is easy code maintenance: +1. an *optional* `model_loader` method that acts as a drop-in replacement for `AutoModel.from_pretrained`. +2. an *optional* `agumentation` method that provides a way to perform *minor* adjustments to an already instantiated model +3. an *optional* `callback` method to install `TrainerCallbacks` (if needed, e.g. custom save logic). + +In what follows, we provide: +- a description of an abstract base class that all plugins must inherit and conform to. +- a brief description of the framework class that is responsible for managing plugins (e.g., loading, executing). + +NOTE: We want to note that the implementation of frameworks and plugins may be moved to a separate open source repository . fms-hf-tuning will then include the framework library as an optional dependency and call respective loaders and augmentation techniques as specified. + +#### AccelerationPlugin Base Class + +Implement concrete plugins that inherit below abstract `AccelerationPlugin` class. +* See also [concrete plugin implementation that loads quantized model (using AutoGPTQ) for LoRA training](#detailed-design). +* Even though all 3 methods are optional, at least one should be implemented. + +```python + +# data class to hold data and pointer to registered plugins +@dataclass +class PluginRegistration: + plugin: "AccelerationPlugin" + configuration_paths: List[str] # path + +# global object to store all registered plugins +PLUGIN_REGISTRATIONS: List[PluginRegistration] = list() + +# this is a base class from which concrete implementations will inherit from +class AccelerationPlugin: + + @staticmethod + def register_plugin( + plugin: "AccelerationPlugin", configuration_paths: List[str], + **kwargs, + ): + global PLUGIN_REGISTRATIONS + PLUGIN_REGISTRATIONS.append( + PluginRegistration(plugin, configuration_paths) + ) + + # if specified, will restricted plugin to specified model archs + # - useful if method is restricted to certain model architectures, e.g., only used for MoEs + restricted_model_archs: Set = None + + # if specified, will check if the package/s is/are installed + required_packages: Set = None + + @property + def requires_custom_loading(self): + return False # to return True if plugin requires custom model + + @property + def requires_agumentation(self): + return False # to return True if plugin requires model augmentation + + def model_loader(model_path: str, **kwargs): + pass # to be replaced with concrete + + # augment model or accelerator object + def augmentation(model: nn.Module, **kwargs): + pass + + def callbacks(model: nn.Module, **kwargs): + return [] +``` + + +#### Implmentation of Acceleration Framework Class + +The role of `AccelerationFramework` is to manage implemented plugins. In particular: +- parse `configuration_file`, see below, and based on contents, decide the `AccelerationPlugin`'s to promote to `active_plugins`. See [implementaiton of AutoGPTQ LoRA Plugin](#detailed-design) for more description on configuration logic. +- handle *plugin stacking*, i.e., when we have more than one `active_plugins`, apply their `model_loader` / `augmentation` logic in appropriate succession. + * This is very useful, e.g., loading a LoRA-trainable AutoGPTQ model first, then applying additional optimized fused kernels to further improve training speeds. +- enforce that `AccelerationFramework.model_loader` witll call `AccelerationPlugin.model_loader` of some `active_plugins` *at most once*. + * Prevents potential complication as multiple `active_plugins` could load models in conflicting ways. +- enforce that `AccelerationFramework.augmentation` would apply `AccelerationPlugin.augmentation` in appropriate succession. + * c.f. previous example, where first a LoRA-trainable AutoGPTQ model is loaded, then fused_kernels on applied on top. + +```python +class AccelerationFramework: + + active_plugins: Dict[str, AccelerationPlugin] = dict() + plugins_require_custom_loading: List = list() + + def __init__(self, configuration_file: Optional[str]=None): + + with open(configuration_file, "r") as f: + contents = yaml.safe_load(f) + + # pepare the plugin configurations + plugin_configs = { k:v for k,v in contents[KEY_PLUGINS].items() } + + for selected_configs, cls in get_relevant_configuration_sections(plugin_configs): + + # then the model is to be installed + # get the plugin + plugin_name = str(cls.__name__) + plugin = cls(selected_configs) + + # check plugin (this is a function that checks if the package requirements of plugin are met) + check_plugin_packages(plugin) + + # install plugin + self.active_plugins[plugin_name] = plugin + if plugin.requires_custom_loading: + self.plugins_require_custom_loading.append(plugin_name) + + if len(self.active_plugins) == 0: + raise ValueError( + "No plugins could be configured. Please check the acceleration " + "framework configuration file." + ) + + assert len(self.plugins_require_custom_loading) <= 1, \ + f"can load at most 1 plugin with custom model loading, but tried to \'{self.plugins_require_custom_loading}\'." + + def model_loader(self, model_name: str, **kwargs): + + if len(self.plugins_require_custom_loading) == 0: + raise NotImplementedError( + f"Attempted modeling loading, but none of activated plugins \'{list(self.active_plugins.keys())}\' " + "require custom loading." + ) + + # otherwise there should be exactly 1 + plugin_name = self.plugins_require_custom_loading[0] + return self.active_plugins[plugin_name].model_loader(model_name, **kwargs) + + def augmentation( + self, + model: PreTrainedModel, + train_args: TrainingArguments, + modifiable_args: Tuple[LoraConfig], + ): + model_archs = set(model.config.architectures) # get the config + + # NOTE: this assumes that augmentation order does not matter + for plugin_name, plugin in self.active_plugins.items(): + + # check the model arcs at augmentation + if ( + plugin.restricted_model_archs and + not any([x in model_archs for x in plugin.restricted_model_archs]) + ): + raise ValueError( + f'Model architectures in \'{model_archs}\' are supported for \'{plugin_name}\'.' + ) + + if plugin.requires_agumentation: + model, modifiable_args = plugin.augmentation( + model, train_args, modifiable_args=modifiable_args + ) + + return model, modifiable_args + + @property + def requires_custom_loading(self): + return len(self.plugins_require_custom_loading) > 0 + + @property + def requires_agumentation(self): + return any([x.requires_agumentation for x in self.active_plugins.values()]) +``` + +### Dependency Management + +Take note: +- all plugin deps must be enforced to be optional deps in `pyproject.toml`, see [116](https://github.com/foundation-model-stack/fms-hf-tuning/pull/116). If the dep is not installed, and the plugin is enabled, raise exception. +- any plugin that requires CUDA build tools (e.g. `triton` kernels) will need to be run in with [CUDA Toolkit dependencies (see this link for an example of a Debian installation)](https://developer.nvidia.com/cuda-12-2-0-download-archive?target_os=Linux&target_arch=x86_64&Distribution=Debian&target_version=11&target_type=deb_local). + * whenever CUDA is needed, the framework will check for the CUDA_TOOLS dependency. + +### Minimal and Controlled Changes to Training Script + +Next, we demonstrate how `AccelerationFramework` would be integrated into `sft_trainer.py` with minimal changes: +- `sft_trainer.py` would take in arguments meant for framework, say via `AccelerationFrameworkArguments`. +- `AccelerationFramework` constructed only if `AccelerationFrameworkArguments.acceleration_framework_config_file` is specified. Null pattern otherwise. +- Plugins loading handled inside `AccelerationFramework`, see [above](#implmentation-of-acceleration-framework-class); transparent to `sft_trainer.py`. +- Fallback to standard logic logic if `AccelerationFrameworkArguments.acceleration_framework_config_file` is `None`. + +```python +from tuning.acceleration import AccelerationFramework + +def train( + ..., acceleration_framework_args: Optional[configs.AccelerationFrameworkArguments] = None, +): + + # Minor Change 1: creating the framework object + framework = None + if acceleration_framework_args.acceleration_framework_config_file is not None: + framework = AccelerationFramework(acceleration_framework_args.acceleration_framework_config_file) + + # Minor Change 2: custom loader (if necessary) + _model_loader = AutoModelForCausalLM.from_pretrained # default + if framework is not None and framework.requires_custom_loading: + _model_loader = framework.model_loader # drop in replacement + + # will passthrough the default loader if framework is disabled + model = _model_loader( + model_args.model_name_or_path, + cache_dir=train_args.cache_dir, + torch_dtype=get_torch_dtype(model_args.torch_dtype), + attn_implementation="flash_attention_2" if model_args.use_flash_attn else None, + ) + + # Minor Change 3: + if framework is not None and framework.requires_agumentation: + # will also take in some other configs that may affect augmentation + # some of these args may be modified due to the augmentation + # e.g., peft_config will be consumed in augmentation, and returned as None + # to prevent SFTTrainer from doing extraneous PEFT logic + model, (peft_config,) = framework.augmentation( + model, + train_args, modifiable_args=(peft_config,), + ) + + # instantiate trainer. Pass in model (with training enchancements) + trainer = Trainer(model, ...) + + # Minor Change 4: add trainer callbacsk + for x in framework.callbacks(): + trainer.add_callback(x) + + # call train + trainer.train() +``` + +The picture below summarizes the above discussion in more detail. It demonstrates how the design will not contradict internal workings of `SFTTrainer`. +- Model is modified and then control passed to `SFTTrainer`. +- `SFTTrainer` also performs model augmentation internally (e.g., it installs PEFT adapters if `peft_config` is passed in). + * However, `SFTTrainer`'s model augmentation should be passed through if configs are omitted (e.g., if `peft_config = None`). +- `SFTTrainer` will prepare model for distributed training (e.g. wrap with `FSDP`) internally. + * thus Plugin implementers need to be aware that `TuningAccelerationPlugin.augmentation` should not interfere with any model preperation that `SFTTrainer` will perform. + +![Framework](imgs/002-framework.png) + +### Acceleration Methods + +A top priority is to incorporate methods that enchance PEFT. While PEFT is known to be memory efficient, but certain scenarios it is has been shown to converge more slowly than full finetuning, e.g., see this [ICLR paper, Fig. 1](https://arxiv.org/pdf/2304.14999.pdf). +Also, another topic of interest is to add support for 4D masks to enable packing while instruction tuning; this acceleration may require some adjustments to the data processing. +1. Add 4-bit `triton` kernels for PEFT base weights. +2. Add fused kernels for PEFT base models, as well as reusable kernels for other models (e.g. cross-entropy loss, RoPE). +3. Add support for 4D masking (may require `TuningAccelerationPlugin.augmentation` to also access the datasets). +4. Add support for distributed training (i.e., `megablocks`). + + + + +### Alternatives Considered + +We considered the following **alternatives**. + +Consideration | Why it was decided agianst +--|-- +Restrict to only performing `augmentation` and not having custom model `loading` | Some methods (e.g., quantization that has special checkpoints) require special loaders. Furthmore any attempt to modify and instantiated models in unintended manners will be error-prone. Finally for extensibility reasons, we decided that preventing drop-in `loading` replacements will be a severe handicap. +Adding tuning enchancements directly to `SFT_Trainer` | The Huggingface trainer is a very complex, and is not recommended to manipulate it directly. + + + +## Consequences + +We considered the following **concerns**. + +Concern| Reason for concern | Possible Solution/s | Recommendation +--|--|--|-- +Managing python deps not found on PyPI | Enhancement plugins may depend on OSS packages that require custom improvements (e.g., extending an OSS PEFT package to support the latest kernel, etc). | 1. Package can be installed directly from GH, public or private (the latter requires some CI changes to manage deployment keys), 2. Specially host custom wheels for CI/CD purposes. | 2 +Managing CUDA compilations | Deploying certain enchancements may require additional CUDA Toolkit deps for kernel compilation. | 1. Extend GH workflow to have a [GH cuda-toolkit action](https://github.com/marketplace/actions/cuda-toolkit) to build the kernels during CI/DC. 2. If kernels are limited to custom deps that are slow-changing, then pre-build custom deps and store as specially hosted wheels. | 2 +Licences for OSS Packages | Copyright concerns | At best all packages under consideration to be used in enhancements should have permissive licences (i.e. Apache 2.0 / MIT). Special considerations required if not possible. +Testing | Do we need to test enchancements? | Request for comment | N/A + +Both concerns can be addresed with an artifactory and centralized location to host custom OSS packages. +- Hosting the OSS packages in a single GH org for accountability. Can be private hosting if this is something we do not want to release. +- Regular users who want to use the enhancements may not be familar with installing cuda-toolkits and compilation. Preparing compiled wheels +for them will be helpful. +- Compiled kernels are sensitive to python and CUDA versions. Can consult existing packages (e.g., flash-attention) to see how this is managed. + +### On OSS packages requiring custom wheels + +Package | Reason for hosting custom wheel | Urgency +--|--|-- +AutoGPTQ | Required changes in `main` (v > 0.7.1) yet to be released. | Low. Can wait for new wheel release (v > 0.7.1) and replace accordingly (last release 1 Mar 2024). +UnSloth | Limited model support. | High. Unclear if new realeases will address the limited model support. +MegaBlocks | Limited model support | High. Unclear if new realeases will address the limited model support. + +### On Licenses + +Plugins will depend on various OSS packages, have to be be careful about licenses. Our plan: +- keep integration lightweight; extract out key parts if possible (of course with the required credits). +- maintain them ourselves, best in a [monorepo](https://www.tweag.io/blog/2023-04-04-python-monorepo-1/) so that we can manage each OSS as an independent dependency. + + +Package | License | Notes +--|--|-- +AutoGPTQ | [Link to repo's MIT License](https://github.com/AutoGPTQ/AutoGPTQ/blob/main/LICENSE) | +Unsloth | [Link to repo's Apache 2.0](https://github.com/unslothai/unsloth/blob/main/LICENSE) | Authors also additionally claim that their Apache 2.0 only supports up to 4 GPUs, where these calims are found [only in the code, outside of the license](https://github.com/unslothai/unsloth/blob/ec18e61c854dcf9104386fa63fc6c4f2944d4f35/unsloth/models/llama.py#L1236-L1240). Take note that Unsloth has been [integrated into SFTTrainer](https://huggingface.co/blog/unsloth-trl), so its unclear how they insert such clauses into OSS. +Megablocks | [Link to repo's Apache 2.0](https://github.com/databricks/megablocks/blob/main/LICENSE) | Used to be under Stanford, now under databricks. + + + + + +## Detailed Design + + + + +### Plugin For Loading LoRA-Traininable AutoGPTQ Model + +In this section we demonstrate an `AutoGPTQAccelerationPlugin` that implements accelerated PEFT training using 4 bit GPTQ base weights with `triton_v2` kernels. +* inherits `AccelerationPlugin` as described [in the above description](#accelerationplugin-base-class). +* registers to `peft.quantization.auto_gptq` in configuration file pointed to by `AccelerationFrameworkArguments.acceleration_framework_config_file`. See below [example of acceleration framework configuration file loading `AutoGPTQAccelerationPlugin`](#configuration-to-load-autogptq-lora-plugin) + + +```python +from transformers import TrainingArguments +from peft import LoraConfig, prepare_model_for_kbit_training +from .framework_plugin import AccelerationPlugin # this is the one + +# Acceleration Plugin for AutoGPTQ acceleration with kernels +class AutoGPTQAccelerationPlugin(AccelerationPlugin): + def __init__(self, configurations: Dict[str, Dict]): + # ... perform any initializations from configurations + + def model_loader(self, model_path: str, **kwargs): + from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig + + # assume model_name points to a quantized checkpoint. Thus we load the quantization + # config directly from the checkpoint. + quantize_config = BaseQuantizeConfig.from_pretrained(model_name) + + # .. some code + model = AutoGPTQForCausalLM.from_quantized( + model_name, quantize_config=quantize_config, + torch_dtype=torch_dtype, ... + ) + # .. more code and then return the model + return model + + def augmentation( + self, model, + train_args: TrainingArguments, + modifiable_args: Tuple[LoraConfig], + ): + assert peft_config is not None, "need peft_config to install PEFT adapters" + peft_config, = modifiable_args # unpack modifiable args + + model = prepare_model_for_kbit_training( + model, use_gradient_checkpointing=train_args.gradient_checkpointing, + gradient_checkpointing_kwargs=train_args.gradient_checkpointing_kwargs, + ) + modifiable_args = (None, ) # return a None for peft_config + + # .. some more code ... and also install the PEFT + from auto_gptq.utils.peft_utils import get_gptq_peft_model + model = get_gptq_peft_model(model, peft_config=peft_config, ...) + + # ... some more code ... then return the model and args + return model, modifiable_args + +# plugin registration +AccelerationPlugin.register_plugin( + AutoGPTQAccelerationPlugin, + configuration_paths=["peft.quantization.auto_gptq"], +) +``` + +### Configuration To Load AutoGPTQ LoRA Plugin + +This file pointed to by `AccelerationFrameworkArguments.acceleration_framework_config_file` would looking like the below samle YAML: +- All contents under `plugins` be parsed by `AccelerationFramework.__init__`. + * For any registered plugin, recall [above](#plugin-for-loading-lora-traininable-autogptq-model) that we check `PluginRegistration.configuration_paths` against the contents of the configuration file. + * In this case the path `peft.quantization.auto_gptq` exists, and `AccelerationFramework` instantiates the plugin and stores `active_plugin` + * contents under `peft.quantization.auto_gptq` passed to plugin constructor. + +```yaml +plugins: + + # PEFT-related acceleration + peft: + + # quantization-releated acceleration + # e.g., kernels for quantized base weights + quantization: + + # AutoGPTQ quantized base weights. + auto_gptq: + kernel: triton_v2 + from_quantized: True +``` diff --git a/architecture_records/003-generic-tracker-framework.md b/architecture_records/003-generic-tracker-framework.md new file mode 100644 index 000000000..f70b2df15 --- /dev/null +++ b/architecture_records/003-generic-tracker-framework.md @@ -0,0 +1,102 @@ +# Resource Scanner + +**Deciders(s)**: Sukriti Sharma (sukriti.sharma4@ibm.com), Alexander Brooks (alex.brooks@ibm.com), Raghu Ganti (rganti@us.ibm.com), Dushyant Behl (dushyantbehl@in.ibm.com), Ashok Pon Kumar (ashokponkumar@in.ibm.com) + +**Date (YYYY-MM-DD)**: 2024-03-06 + +**Obsoletes ADRs**: NA + +**Modified By ADRs**: NA + +**Relevant Issues**: [1](https://github.com/foundation-model-stack/fms-hf-tuning/issues/34), [2](https://github.com/foundation-model-stack/fms-hf-tuning/issues/33) + +- [Summary and Objective](#summary-and-objective) + - [Motivation](#motivation) + - [User Benefit](#user-benefit) +- [Decision](#decision) + - [Alternatives Considered](#alternatives-considered) +- [Consequences](#consequences) +- [Detailed Design](#detailed-design) + +## Summary and Objective + +This PR introduces a generic interface `class Tracker` which implements basic functionality needed to be satisfied by any tracker to be implemented inside `fms-hf-tuning`. +Tracker here means an agent which can track AI and system metrics like [Aimstack](https://aimstack.io/) or [WandB](https://wandb.ai/site). + +### Motivation + +The current code in `fms-hf-tuning` has [Aimstack](https://aimstack.io/) as the sole integration point in the file [aim_loader.py](https://github.com/foundation-model-stack/fms-hf-tuning/blob/74caf85140a112cd9289502b0777baac636adf1d/tuning/aim_loader.py) (taken from the latest head of the tree at the point of proposing this change). + +Users of `fms-hf-tuning` in the current state are forced to used Aimstack and cannot interface with any other tracker like quite popular WandB. If a user would want to add any new tracker to the current code they would need to implement all functionality and change the code in a heavy manner. +To interface with other tracker we need a more Modular interface in our tuning script, further in the current code it is not possible to disable Aimstack and users of this repo have raised concern regarding the same (https://github.com/foundation-model-stack/fms-hf-tuning/pull/20). + +With the new modular inteface we also add support for tracking custom experiment metadata and custom metrics with just one line of change in the code. + +So, due to the limitations of current code and lack of modular structure this ADR introduces modular design to tracking in fms-hf-tuning which enables future support for any tracker. + +### User Benefit + +Users of this updated design will be able to, + +1. Implement and interface with any tracker they want by just implementing 4 functions of an interface. +1. Run the code without any tracker if they do not want to use it. +1. Track any custom metrics or associate any experiment metadata they need with the training runs. + +## Decision + +### Alternatives Considered + +Alternatives to this design is already implemented in the code, we considered expanding that but then for every tracker we would need to introduce a new loader and with no well defined structure or interface to tracking the core tuning script would need to have too much boiler plate and conditional checks implemented to use one or the other tracker. + +## Consequences + +### Advantages + +- Modular design to keep the core training loop clean while working with trackers +- Support for tracking any metrics and metatada +- Simple interface to expand and attach any tracker + +### Impact on performance + +None. The interface does not add any extra overhead to the tracking. If a tracker is not implemented the tracker interface functions are NOOP modules. + +## Detailed Design + +The changes to the code are these, + +``` +class Tracker: + def __init__(self, name=None, tracker_config=None) -> None: + if tracker_config is not None: + self.config = tracker_config + if name is None: + self._name = "None" + else: + self._name = name + + # we use args here to denote any argument. + def get_hf_callback(self): + return None + + def track(self, metric, name, stage): + pass + + # Object passed here is supposed to be a KV object + # for the parameters to be associated with a run + def set_params(self, params, name): + pass +``` + +This interface expects any tracker to implement just 4 basic functions. + +1. `init` to initialise the tracker using the config passed from command line +1. `get_hf_callback` to be called to get the hugging face callback for the specific tracker. +1. `track` to track any custom metrics +1. `set_params` to set any experiment metadata as additional parameters + +In addition, we also. + +1. We also introduce a tracker factory which initializes the available tracker. +1. We remove the file `aim_loader.py` and implement the same code as a `Tracker` in the folder `trackers/aimstack_tracker.py` +1. We also implement the `track` and `set_params` functions for `Aimstack` +1. We change the main tuning script to use `Tracker` inteface instead of directly calling `Aimstack` functions. diff --git a/architecture_records/imgs/001-arch.png b/architecture_records/imgs/001-arch.png new file mode 100644 index 000000000..7fdae0aa9 Binary files /dev/null and b/architecture_records/imgs/001-arch.png differ diff --git a/architecture_records/imgs/002-framework.png b/architecture_records/imgs/002-framework.png new file mode 100644 index 000000000..52ec3071c Binary files /dev/null and b/architecture_records/imgs/002-framework.png differ diff --git a/architecture_records/template.md b/architecture_records/template.md new file mode 100644 index 000000000..802ab82e8 --- /dev/null +++ b/architecture_records/template.md @@ -0,0 +1,51 @@ +# Title of ADR, keep it concise + +**Deciders(s)**: +**Date (YYYY-MM-DD)**: +**Obsoletes ADRs**: +**Modified By ADRs**: +**Relevant Issues**: + +- [Summary and Objective](#summary-and-objective) + - [Motivation](#motivation) + - [User Benefit](#user-benefit) +- [Decision](#decision) + - [Alternatives Considered](#alternatives-considered) +- [Consequences](#consequences) +- [Detailed Design](#detailed-design) + +## Summary and Objective + +Context goes here. + +Describe the forces at play, including technological, political, social, and project local. These forces are likely in tension, and should be called out as such. The language in this section is value-neutral. It is simply describing facts. + +### Motivation + +Why this is a valuable problem to solve? What background information is needed to show how this design addresses the problem? + +Which users are affected by the problem? Why is it a problem? What data supports this? What related work exists? + +### User Benefit + +How will users (or other contributors) benefit from this work? What would be the headline in the release notes or blog post? + +## Decision + +This is the meat of the document, where you explain the decision. If you have multiple alternatives, be sure to use sub-sections for better separation of the idea, and list pros/cons to each approach. If there are alternatives that you have eliminated, you should also list those here, and explain why you believe your chosen approach is superior. + +Make sure you’ve thought through and addressed the following sections. If a section is not relevant to your specific proposal, please explain why, e.g. your ADR addresses a convention or process, not an API. + +### Alternatives Considered + +- Make sure to discuss the relative merits of alternatives to your proposal. + +## Consequences + +Describe the resulting context, after applying the decision. All consequences should be listed here, not just the "positive" ones. A particular decision may have positive, negative, and neutral consequences, but all of them affect the team and project in the future. + + +## Detailed Design + +This section is optional. Elaborate on details if they’re important to understanding the design, but would make it hard to read the proposal section above. + diff --git a/build/Dockerfile b/build/Dockerfile new file mode 100644 index 000000000..c3c4836b4 --- /dev/null +++ b/build/Dockerfile @@ -0,0 +1,159 @@ +# Copyright The FMS HF Tuning Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +FROM registry.access.redhat.com/ubi9/python-311 as wheel + +ARG WHEEL_VERSION="" +USER root +RUN --mount=type=cache,target=/root/.cache/pip \ + python -m pip install --upgrade pip && \ + python -m pip install build +COPY tuning tuning +COPY .git .git +COPY pyproject.toml pyproject.toml +# build wheel if wheel version is empty else download the wheel from PyPi +RUN if [[ -z "${WHEEL_VERSION}" ]]; \ + then python -m build --wheel --outdir /tmp; \ + else pip download fms-hf-tuning==${WHEEL_VERSION} --dest /tmp --only-binary=:all: --no-deps; \ + fi && \ + ls /tmp/*.whl >/tmp/bdist_name + + +FROM registry.access.redhat.com/ubi9/ubi AS release + +ARG CUDA_VERSION=11.8.0 +ARG USER=tuning +ARG USER_UID=1000 +ARG SET_NUM_PROCESSES_TO_NUM_GPUS=True + +USER root + +RUN dnf update -y \ + && dnf remove -y --disableplugin=subscription-manager \ + subscription-manager \ + # we install newer version of requests via pip + python3.11-requests \ + && dnf install -y make \ + # to help with debugging + procps \ + && dnf clean all + +ENV LANG=C.UTF-8 \ + LC_ALL=C.UTF-8 + +ENV CUDA_VERSION=$CUDA_VERSION \ + NV_CUDA_LIB_VERSION=11.8.0-1 \ + NVIDIA_VISIBLE_DEVICES=all \ + NVIDIA_DRIVER_CAPABILITIES=compute,utility \ + NV_CUDA_CUDART_VERSION=11.8.89-1 \ + NV_CUDA_COMPAT_VERSION=520.61.05-1 + +RUN dnf config-manager \ + --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel9/x86_64/cuda-rhel9.repo \ + && dnf install -y \ + cuda-cudart-11-8-${NV_CUDA_CUDART_VERSION} \ + cuda-compat-11-8-${NV_CUDA_COMPAT_VERSION} \ + && echo "/usr/local/nvidia/lib" >> /etc/ld.so.conf.d/nvidia.conf \ + && echo "/usr/local/nvidia/lib64" >> /etc/ld.so.conf.d/nvidia.conf \ + && dnf clean all + +ENV CUDA_HOME="/usr/local/cuda" \ + PATH="/usr/local/nvidia/bin:${CUDA_HOME}/bin:${PATH}" \ + LD_LIBRARY_PATH="/usr/local/nvidia/lib:/usr/local/nvidia/lib64:$CUDA_HOME/lib64:$CUDA_HOME/extras/CUPTI/lib64:${LD_LIBRARY_PATH}" + + +ENV NV_NVTX_VERSION=11.8.86-1 \ + NV_LIBNPP_VERSION=11.8.0.86-1 \ + NV_LIBCUBLAS_VERSION=11.11.3.6-1 \ + NV_LIBNCCL_PACKAGE_VERSION=2.15.5-1+cuda11.8 + +RUN dnf config-manager \ + --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel9/x86_64/cuda-rhel9.repo \ + && dnf install -y \ + cuda-libraries-11-8-${NV_CUDA_LIB_VERSION} \ + cuda-nvtx-11-8-${NV_NVTX_VERSION} \ + libnpp-11-8-${NV_LIBNPP_VERSION} \ + libcublas-11-8-${NV_LIBCUBLAS_VERSION} \ + libnccl-${NV_LIBNCCL_PACKAGE_VERSION} \ + && dnf clean all + +ENV NV_CUDA_CUDART_DEV_VERSION=11.8.89-1 \ + NV_NVML_DEV_VERSION=11.8.86-1 \ + NV_LIBCUBLAS_DEV_VERSION=11.11.3.6-1 \ + NV_LIBNPP_DEV_VERSION=11.8.0.86-1 \ + NV_LIBNCCL_DEV_PACKAGE_VERSION=2.15.5-1+cuda11.8 + +RUN dnf config-manager \ + --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel9/x86_64/cuda-rhel9.repo \ + && dnf install -y \ + cuda-command-line-tools-11-8-${NV_CUDA_LIB_VERSION} \ + cuda-libraries-devel-11-8-${NV_CUDA_LIB_VERSION} \ + cuda-minimal-build-11-8-${NV_CUDA_LIB_VERSION} \ + cuda-cudart-devel-11-8-${NV_CUDA_CUDART_DEV_VERSION} \ + cuda-nvml-devel-11-8-${NV_NVML_DEV_VERSION} \ + libcublas-devel-11-8-${NV_LIBCUBLAS_DEV_VERSION} \ + libnpp-devel-11-8-${NV_LIBNPP_DEV_VERSION} \ + libnccl-devel-${NV_LIBNCCL_DEV_PACKAGE_VERSION} \ + && dnf clean all + +ENV LIBRARY_PATH="$CUDA_HOME/lib64/stubs" + +RUN dnf install -y python3.11 git \ + && ln -s /usr/bin/python3.11 /bin/python \ + && python -m ensurepip --upgrade \ + && dnf clean all + +# Removes the example private key to avoid high severity vulnerability warning +RUN rm -f /usr/share/doc/perl-Net-SSLeay/examples/server_key.pem + +WORKDIR /tmp +COPY --from=wheel /tmp/*.whl /tmp/bdist_name /tmp/ +RUN --mount=type=cache,target=/root/.cache/pip \ + python -m pip install --upgrade pip && \ + python -m pip install wheel && \ + python -m pip install "$(head bdist_name)" && \ + # Due to FIPS tolerance issues, removing aim at this time + #python -m pip install "$(head bdist_name)[aim]" && \ + python -m pip install "$(head bdist_name)[flash-attn]" && \ + # Clean up the wheel module. It's only needed by flash-attn install + python -m pip uninstall wheel -y && \ + # Cleanup the bdist whl file + rm $(head bdist_name) /tmp/bdist_name + +RUN mkdir -p /licenses +COPY LICENSE /licenses/ + +RUN mkdir /app +# Copy scripts and default configs +COPY build/launch_training.py build/accelerate_launch.py fixtures/accelerate_fsdp_defaults.yaml /app/ +COPY build/utils.py /app/build/ +RUN chmod +x /app/launch_training.py /app/accelerate_launch.py + +ENV FSDP_DEFAULTS_FILE_PATH="/app/accelerate_fsdp_defaults.yaml" + +# Need a better way to address this hack +RUN touch /.aim_profile && \ + chmod -R 777 /.aim_profile && \ + mkdir /.cache && \ + chmod -R 777 /.cache + +# create tuning user and give ownership to dirs +RUN useradd -u $USER_UID tuning -m -g 0 --system && \ + chown -R $USER:0 /app /tmp && \ + chmod -R g+rwX /app /tmp + +WORKDIR /app +USER ${USER} + +CMD [ "python", "/app/accelerate_launch.py" ] diff --git a/build/README.md b/build/README.md new file mode 100644 index 000000000..656406ad9 --- /dev/null +++ b/build/README.md @@ -0,0 +1,171 @@ +# Building fms-hf-tuning as an Image + +The Dockerfile provides a way of running fms-hf-tuning SFT Trainer. It installs the dependencies needed and adds two additional scripts that helps to parse arguments to pass to SFT Trainer. The `accelerate_launch.py` script is run by default when running the image to trigger SFT trainer for single or multi GPU by parsing arguments and running `accelerate launch launch_training.py`. + +## Configuration + +The scripts accept a JSON formatted config which are set by environment variables. `SFT_TRAINER_CONFIG_JSON_PATH` can be set to the mounted path of the JSON config. Alternatively, `SFT_TRAINER_CONFIG_JSON_ENV_VAR` can be set to the encoded JSON config using the below function: + +```py +import base64 + +def encode_json(my_json_string): + base64_bytes = base64.b64encode(my_json_string.encode("ascii")) + txt = base64_bytes.decode("ascii") + return txt + +with open("test_config.json") as f: + contents = f.read() + +encode_json(contents) +``` + +The keys for the JSON config are all of the flags available to use with [SFT Trainer](https://huggingface.co/docs/trl/sft_trainer#trl.SFTTrainer). + +For configuring `accelerate launch`, use key `accelerate_launch_args` and pass the set of flags accepted by [accelerate launch](https://huggingface.co/docs/accelerate/package_reference/cli#accelerate-launch). Since these flags are passed via the JSON config, the key matches the long formed flag name. For example, to enable flag `--quiet`, use JSON key `"quiet"`, using the short formed `"q"` will fail. + +For example, the below config is used for running with two GPUs and FSDP for fine tuning: + +```json +{ + "accelerate_launch_args": { + "num_machines": 1, + "main_process_port": 1234, + "num_processes": 2, + "use_fsdp": true, + "fsdp_backward_prefetch_policy": "TRANSFORMER_BASED_WRAP", + "fsdp_sharding_strategy": 1, + "fsdp_state_dict_type": "FULL_STATE_DICT", + "fsdp_cpu_ram_efficient_loading": true, + "fsdp_sync_module_states": true + }, + "model_name_or_path": "/llama/13B", + "training_data_path": "/data/twitter_complaints.json", + "output_dir": "/output/llama-7b-pt-multigpu", + "num_train_epochs": 5.0, + "per_device_train_batch_size": 4, + "per_device_eval_batch_size": 4, + "gradient_accumulation_steps": 4, + "save_strategy": "epoch", + "learning_rate": 0.03, + "weight_decay": 0.0, + "lr_scheduler_type": "cosine", + "logging_steps": 1.0, + "packing": false, + "include_tokens_per_second": true, + "response_template": "\n### Label:", + "dataset_text_field": "output", + "use_flash_attn": true, + "torch_dtype": "bfloat16", + "tokenizer_name_or_path": "/llama/13B" +} +``` + +Users should always set `num_processes` to be explicit about the number of processes to run tuning on. When `num_processes` is greater than 1, the [FSDP config](https://github.com/foundation-model-stack/fms-hf-tuning/blob/main/fixtures/accelerate_fsdp_defaults.yaml) is used by default. Thus in the above example, you don't need to pass in the FSDP flags since they match the ones used in the default FSDP config. You can also set your own default values by specifying your own config file using key `config_file`. Any of these values in configs can be overwritten by passing in flags via `accelerate_launch_args` in the JSON config. + +Note that `num_processes` which is the total number of processes to be launched in parallel, should match the number of GPUs to run on. The number of GPUs used can also be set by setting environment variable `CUDA_VISIBLE_DEVICES`. If ``num_processes=1`, the script will assume single-GPU. + + +## Building the Image + +With docker, build the image at the top level with: + +```sh +docker build . -t sft-trainer:mytag -f build/Dockerfile +``` + +## Running the Image + +Run sft-trainer-image with the JSON env var and mounts set up. + +```sh +docker run -v config.json:/app/config.json -v $MODEL_PATH:/model -v $TRAINING_DATA_PATH:/data/twitter_complaints.json --env SFT_TRAINER_CONFIG_JSON_PATH=/app/config.json sft-trainer:mytag +``` + +This will run `accelerate_launch.py` with the JSON config passed. + +An example Kubernetes Pod for deploying sft-trainer which requires creating PVCs with the model and input dataset and any mounts needed for the outputted tuned model: + +```yaml +apiVersion: v1 +kind: ConfigMap +metadata: +name: sft-trainer-config +data: +config.json: | + { + "accelerate_launch_args": { + "num_machines": 1, + "main_process_port": 1234, + "num_processes": 2, + "use_fsdp": true, + "fsdp_backward_prefetch_policy": "TRANSFORMER_BASED_WRAP", + "fsdp_sharding_strategy": 1, + "fsdp_state_dict_type": "FULL_STATE_DICT", + "fsdp_cpu_ram_efficient_loading": true, + "fsdp_sync_module_states": true + }, + "model_name_or_path": "/llama/13B", + "training_data_path": "/data/twitter_complaints.json", + "output_dir": "/output/llama-7b-pt-multigpu", + "num_train_epochs": 5.0, + "per_device_train_batch_size": 4, + "per_device_eval_batch_size": 4, + "gradient_accumulation_steps": 4, + "save_strategy": "epoch", + "learning_rate": 0.03, + "weight_decay": 0.0, + "lr_scheduler_type": "cosine", + "logging_steps": 1.0, + "packing": false, + "include_tokens_per_second": true, + "response_template": "\n### Label:", + "dataset_text_field": "output", + "use_flash_attn": true, + "torch_dtype": "bfloat16", + "tokenizer_name_or_path": "/llama/13B" + } +--- +apiVersion: v1 +kind: Pod +metadata: +name: sft-trainer-test +spec: +containers: + env: + - name: SFT_TRAINER_CONFIG_JSON_PATH + value: /config/config.json + image: sft-trainer:mytag + imagePullPolicy: IfNotPresent + name: tuning-test + resources: + limits: + nvidia.com/gpu: "2" + memory: 200Gi + cpu: "10" + ephemeral-storage: 2Ti + requests: + memory: 80Gi + cpu: "5" + volumeMounts: + - mountPath: /data/input + name: input-data + - mountPath: /data/output + name: output-data + - mountPath: /config + name: sft-trainer-config +restartPolicy: Never +terminationGracePeriodSeconds: 30 +volumes: + - name: input-data + persistentVolumeClaim: + claimName: input-pvc + - name: output-data + persistentVolumeClaim: + claimName: output-pvc + - name: sft-trainer-config + configMap: + name: sft-trainer-config +``` + +The above kube resource values are not hard-defined. However, they are useful when running some models (such as LLaMa-13b model). If ephemeral storage is not defined, you will likely hit into error `The node was low on resource: ephemeral-storage. Container was using 1498072868Ki, which exceeds its request of 0.` where the pod runs low on storage while tuning the model. \ No newline at end of file diff --git a/build/accelerate_launch.py b/build/accelerate_launch.py new file mode 100644 index 000000000..e18812572 --- /dev/null +++ b/build/accelerate_launch.py @@ -0,0 +1,43 @@ +# Copyright The FMS HF Tuning Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Script wraps launch_training to run with accelerate for multi and single GPU cases. +Read accelerate_launch_args configuration via environment variable `SFT_TRAINER_CONFIG_JSON_PATH` +for the path to the JSON config file with parameters or `SFT_TRAINER_CONFIG_JSON_ENV_VAR` +for the encoded config string to parse. +""" + +# Standard +import os +import logging + +# Third Party +from accelerate.commands.launch import launch_command + +# Local +from build.utils import process_accelerate_launch_args, get_job_config + + +def main(): + LOGLEVEL = os.environ.get("LOG_LEVEL", "WARNING").upper() + logging.basicConfig(level=LOGLEVEL) + + job_config = get_job_config() + + args = process_accelerate_launch_args(job_config) + logging.debug("accelerate launch parsed args: %s", args) + launch_command(args) + + +if __name__ == "__main__": + main() diff --git a/build/launch_training.py b/build/launch_training.py new file mode 100644 index 000000000..af02575bf --- /dev/null +++ b/build/launch_training.py @@ -0,0 +1,131 @@ +# Copyright The FMS HF Tuning Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Script wraps SFT Trainer to run for Train Conductor. +Read SFTTrainer configuration via environment variable `SFT_TRAINER_CONFIG_JSON_PATH` +for the path to the JSON config file with parameters or `SFT_TRAINER_CONFIG_JSON_ENV_VAR` +for the encoded config string to parse. +""" + +# Standard +import os +import tempfile +import shutil + +# First Party +import logging + +# Local +from tuning import sft_trainer +from tuning.utils.merge_model_utils import create_merged_model +from tuning.config.tracker_configs import TrackerConfigFactory +from build.utils import process_launch_training_args, get_job_config + + +def get_highest_checkpoint(dir_path): + checkpoint_dir = "" + for curr_dir in os.listdir(dir_path): + if curr_dir.startswith("checkpoint"): + if checkpoint_dir: + curr_dir_num = int(checkpoint_dir.rsplit("-", maxsplit=1)[-1]) + new_dir_num = int(curr_dir.split("-")[-1]) + if new_dir_num > curr_dir_num: + checkpoint_dir = curr_dir + else: + checkpoint_dir = curr_dir + + return checkpoint_dir + + +def main(): + LOGLEVEL = os.environ.get("LOG_LEVEL", "WARNING").upper() + logging.basicConfig(level=LOGLEVEL) + + logging.info("Initializing launch training script") + + job_config = get_job_config() + + logging.debug("Input params parsed: %s", job_config) + + ( + model_args, + data_args, + training_args, + tune_config, + merge_model, + file_logger_config, + aim_config, + ) = process_launch_training_args(job_config) + + original_output_dir = training_args.output_dir + with tempfile.TemporaryDirectory() as tempdir: + training_args.output_dir = tempdir + tracker_config_args = TrackerConfigFactory( + file_logger_config=file_logger_config, aim_config=aim_config + ) + sft_trainer.train( + model_args=model_args, + data_args=data_args, + train_args=training_args, + peft_config=tune_config, + tracker_configs=tracker_config_args, + ) + + if merge_model: + export_path = os.getenv( + "LORA_MERGE_MODELS_EXPORT_PATH", original_output_dir + ) + + # get the highest checkpoint dir (last checkpoint) + lora_checkpoint_dir = get_highest_checkpoint(training_args.output_dir) + full_checkpoint_dir = os.path.join( + training_args.output_dir, lora_checkpoint_dir + ) + + logging.info( + "Merging lora tuned checkpoint %s with base model into output path: %s", + lora_checkpoint_dir, + export_path, + ) + + create_merged_model( + checkpoint_models=full_checkpoint_dir, + export_path=export_path, + base_model=model_args.model_name_or_path, + save_tokenizer=True, + ) + else: + # copy last checkpoint into mounted output dir + pt_checkpoint_dir = get_highest_checkpoint(training_args.output_dir) + logging.info( + "Copying last checkpoint %s into output dir %s", + pt_checkpoint_dir, + original_output_dir, + ) + shutil.copytree( + os.path.join(training_args.output_dir, pt_checkpoint_dir), + original_output_dir, + dirs_exist_ok=True, + ) + + # copy over any loss logs + train_logs_filepath = os.path.join( + training_args.output_dir, + tracker_config_args.file_logger_config.training_logs_filename, + ) + if os.path.exists(train_logs_filepath): + shutil.copy(train_logs_filepath, original_output_dir) + + +if __name__ == "__main__": + main() diff --git a/build/utils.py b/build/utils.py new file mode 100644 index 000000000..7025d0978 --- /dev/null +++ b/build/utils.py @@ -0,0 +1,210 @@ +# Copyright The FMS HF Tuning Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Standard +import os +import json +import logging +import base64 +import pickle + +# Third Party +import torch +import transformers +from accelerate.commands.launch import launch_command_parser + +# Local +from tuning.config import configs, peft_config, tracker_configs + + +def txt_to_obj(txt): + base64_bytes = txt.encode("ascii") + message_bytes = base64.b64decode(base64_bytes) + try: + # If the bytes represent JSON string + return json.loads(message_bytes) + except UnicodeDecodeError: + # Otherwise the bytes are a pickled python dictionary + return pickle.loads(message_bytes) + + +def get_job_config(): + json_path = os.getenv("SFT_TRAINER_CONFIG_JSON_PATH") + json_env_var = os.getenv("SFT_TRAINER_CONFIG_JSON_ENV_VAR") + + # accepts either path to JSON file or encoded string config + if json_path: + with open(json_path, "r", encoding="utf-8") as f: + job_config_dict = json.load(f) + elif json_env_var: + job_config_dict = txt_to_obj(json_env_var) + else: + raise ValueError( + "Must set environment variable 'SFT_TRAINER_CONFIG_JSON_PATH' \ + or 'SFT_TRAINER_CONFIG_JSON_ENV_VAR'." + ) + return job_config_dict + + +def process_launch_training_args(job_config_dict): + """Return parsed config for tuning to pass to SFT Trainer + Args: + job_config_dict: dict + Return: + model_args: configs.ModelArguments + data_args: configs.DataArguments + training_args: configs.TrainingArguments + tune_config: peft_config.LoraConfig | peft_config.PromptTuningConfig + merge_model: bool + file_logger_config: tracker_configs.FileLoggingTrackerConfig + aim_config: tracker_configs.AimConfig + """ + parser = transformers.HfArgumentParser( + dataclass_types=( + configs.ModelArguments, + configs.DataArguments, + configs.TrainingArguments, + peft_config.LoraConfig, + peft_config.PromptTuningConfig, + tracker_configs.FileLoggingTrackerConfig, + tracker_configs.AimConfig, + ) + ) + + ( + model_args, + data_args, + training_args, + lora_config, + prompt_tuning_config, + file_logger_config, + aim_config, + ) = parser.parse_dict(job_config_dict, allow_extra_keys=True) + + peft_method_parsed = job_config_dict.get("peft_method") + + tune_config = None + merge_model = False + if peft_method_parsed == "lora": + tune_config = lora_config + merge_model = True + elif peft_method_parsed == "pt": + tune_config = prompt_tuning_config + + logging.info( + "Parameters used to launch training: \ + model_args %s, data_args %s, training_args %s, tune_config %s \ + file_logger_config %s aim_config %s", + model_args, + data_args, + training_args, + tune_config, + file_logger_config, + aim_config, + ) + + return ( + model_args, + data_args, + training_args, + tune_config, + merge_model, + file_logger_config, + aim_config, + ) + + +def process_accelerate_launch_args(job_config_dict): + """Return parsed config for tuning to pass to SFT Trainer + Args: + job_config_dict: dict + Return: + args to pass to `accelerate launch` + """ + parser = launch_command_parser() + # Map to determine which flags don't require a value to be set + actions_type_map = { + action.dest: type(action).__name__ for action in parser._actions + } + + # Parse accelerate_launch_args + accelerate_launch_args = [] + accelerate_config = job_config_dict.get("accelerate_launch_args", {}) + if accelerate_config: + logging.info("Using accelerate_launch_args configs: %s", accelerate_config) + for key, val in accelerate_config.items(): + # skip num_processes to assign below based on SET_NUM_PROCESSES_TO_NUM_GPUS + if key == "num_processes": + continue + + if actions_type_map.get(key) == "_AppendAction": + for param_val in val: + accelerate_launch_args.extend([f"--{key}", str(param_val)]) + elif (actions_type_map.get(key) == "_StoreTrueAction" and val) or ( + actions_type_map.get(key) == "_StoreFalseAction" and not val + ): + accelerate_launch_args.append(f"--{key}") + else: + accelerate_launch_args.append(f"--{key}") + # Only need to add key for params that aren't flags ie. --quiet + if actions_type_map.get(key) == "_StoreAction": + accelerate_launch_args.append(str(val)) + + # accept setting SET_NUM_PROCESSES_TO_NUM_GPUS=True in Shell interpreted as string + set_num_processes_to_num_gpus = os.getenv( + "SET_NUM_PROCESSES_TO_NUM_GPUS", "True" + ).lower() + user_arg_num_processes = accelerate_config.get("num_processes") + num_processes = 0 + if set_num_processes_to_num_gpus == "true": + num_processes = torch.cuda.device_count() + + if user_arg_num_processes: + logging.warning( + "SET_NUM_PROCESSES_TO_NUM_GPUS=True, overwriting user set num_processes %s\ + to all GPUs available, %s.", + user_arg_num_processes, + num_processes, + ) + elif user_arg_num_processes: + num_processes = int(user_arg_num_processes) + + if num_processes: + accelerate_launch_args.extend(["--num_processes", str(num_processes)]) + # if multi GPU setting and accelerate config_file not passed by user, + # use the default config for default set of parameters + if num_processes > 1 and not accelerate_config.get("config_file"): + # Add default FSDP config + fsdp_filepath = os.getenv( + "FSDP_DEFAULTS_FILE_PATH", "/app/accelerate_fsdp_defaults.yaml" + ) + if os.path.exists(fsdp_filepath): + logging.info("Using accelerate config file: %s", fsdp_filepath) + accelerate_launch_args.extend(["--config_file", fsdp_filepath]) + + elif num_processes == 1: + logging.info("num_processes=1 so setting env var CUDA_VISIBLE_DEVICES=0") + os.environ["CUDA_VISIBLE_DEVICES"] = "0" + else: + logging.warning( + "num_processes param was not passed in. Value from config file (if available) will \ + be used or accelerate launch will determine number of processes automatically" + ) + + # Add training_script + accelerate_launch_args.append("/app/launch_training.py") + + logging.debug("accelerate_launch_args: %s", accelerate_launch_args) + args = parser.parse_args(args=accelerate_launch_args) + return args diff --git a/code-of-conduct.md b/code-of-conduct.md new file mode 100644 index 000000000..b2e6e4b6d --- /dev/null +++ b/code-of-conduct.md @@ -0,0 +1,76 @@ +# Community Code of Conduct + +## Our Pledge + +In the interest of fostering an open and welcoming environment, we as +contributors and maintainers pledge to making participation in our project and +our community a harassment-free experience for everyone, regardless of age, body +size, disability, ethnicity, sex characteristics, gender identity and expression, +level of experience, education, socio-economic status, nationality, personal +appearance, race, religion, or sexual identity and orientation. + +## Our Standards + +Examples of behavior that contributes to creating a positive environment +include: + +- Using welcoming and inclusive language +- Being respectful of differing viewpoints and experiences +- Gracefully accepting constructive criticism +- Focusing on what is best for the community +- Showing empathy towards other community members + +Examples of unacceptable behavior by participants include: + +- The use of sexualized language or imagery and unwelcome sexual attention or + advances +- Trolling, insulting/derogatory comments, and personal or political attacks +- Public or private harassment +- Publishing others' private information, such as a physical or electronic + address, without explicit permission +- Other conduct which could reasonably be considered inappropriate in a + professional setting + +## Our Responsibilities + +Project maintainers are responsible for clarifying the standards of acceptable +behavior and are expected to take appropriate and fair corrective action in +response to any instances of unacceptable behavior. + +Project maintainers have the right and responsibility to remove, edit, or +reject comments, commits, code, wiki edits, issues, and other contributions +that are not aligned to this Code of Conduct, or to ban temporarily or +permanently any contributor for other behaviors that they deem inappropriate, +threatening, offensive, or harmful. + +## Scope + +This Code of Conduct applies both within project spaces and in public spaces +when an individual is representing the project or its community. Examples of +representing a project or community include using an official project e-mail +address, posting via an official social media account, or acting as an appointed +representative at an online or offline event. Representation of a project may be +further defined and clarified by project maintainers. + +## Enforcement + +Instances of abusive, harassing, or otherwise unacceptable behavior may be +reported by contacting the [project team](./CODEOWNERS). All +complaints will be reviewed and investigated and will result in a response that +is deemed necessary and appropriate to the circumstances. The project team is +obligated to maintain confidentiality with regard to the reporter of an incident. +Further details of specific enforcement policies may be posted separately. + +Project maintainers who do not follow or enforce the Code of Conduct in good +faith may face temporary or permanent repercussions as determined by other +members of the project's leadership. + +## Attribution + +This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, +available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html + +[homepage]: https://www.contributor-covenant.org + +For answers to common questions about this code of conduct, see +https://www.contributor-covenant.org/faq diff --git a/examples/prompt_tuning_twitter_complaints/README.md b/examples/prompt_tuning_twitter_complaints/README.md index f3c663b35..c8383cd57 100644 --- a/examples/prompt_tuning_twitter_complaints/README.md +++ b/examples/prompt_tuning_twitter_complaints/README.md @@ -30,20 +30,22 @@ dataset.to_json("twitter_complaints.json") ### Prompt Tuning We will switch our PEFT method from LORA to Prompt Tuning (pt) ```bash -# replace these with your values -MODEL_PATH=llama-7b-hf -DATA_PATH=twitter_complaints.json -OUTPUT_PATH=out +# Please set the environment variables: +# MASTER_PORT=1234 # The port at which the process with rank 0 listens to and should be set to an unused port +# MODEL_PATH=meta-llama/Llama-2-7b-hf # Huggingface model id or path to a checkpoint +# TRAIN_DATA_PATH=twitter_complaints.json # Path to the training dataset +# OUTPUT_PATH=out # Path to the output folder where the checkpoints are saved -torchrun \ ---nnodes=1 \ ---nproc_per_node=8 \ ---master_port=1234 \ + +accelerate launch \ +--main_process_port $MASTER_PORT \ +--config_file fixtures/accelerate_fsdp_defaults.yaml \ tuning/sft_trainer.py \ --model_name_or_path $MODEL_PATH \ ---data_path $DATA_PATH \ +--training_data_path $TRAIN_DATA_PATH \ --output_dir $OUTPUT_PATH \ --peft_method pt \ +--torch_dtype bfloat16 \ --tokenizer_name_or_path $MODEL_PATH \ --num_train_epochs 5 \ --per_device_train_batch_size 1 \ @@ -56,8 +58,6 @@ tuning/sft_trainer.py \ --warmup_ratio 0.03 \ --lr_scheduler_type "cosine" \ --logging_steps 1 \ ---fsdp "full_shard auto_wrap" \ ---fsdp_config tuning/config/fsdp_config.json \ --include_tokens_per_second \ --packing False \ --response_template "\n### Label:" \ diff --git a/examples/trainercontroller_configs/Readme.md b/examples/trainercontroller_configs/Readme.md new file mode 100644 index 000000000..39d8271c8 --- /dev/null +++ b/examples/trainercontroller_configs/Readme.md @@ -0,0 +1,5 @@ +# How-To +To use one of these files with the trainer, execute the `sft_trainer.py` with the following option: +``` +--trainer_controller_config_file "examples/trainercontroller_configs/" +``` diff --git a/examples/trainercontroller_configs/loss.yaml b/examples/trainercontroller_configs/loss.yaml new file mode 100644 index 000000000..dd272d21c --- /dev/null +++ b/examples/trainercontroller_configs/loss.yaml @@ -0,0 +1,10 @@ +controller-metrics: + - name: loss + class: Loss +controllers: + - name: loss-controller + triggers: + - on_log + rule: loss < 1.0 + operations: + - hfcontrols.should_training_stop \ No newline at end of file diff --git a/fixtures/accelerate_fsdp_defaults.yaml b/fixtures/accelerate_fsdp_defaults.yaml new file mode 100644 index 000000000..f70d74faa --- /dev/null +++ b/fixtures/accelerate_fsdp_defaults.yaml @@ -0,0 +1,60 @@ +# options that can be used with accelerate config are neatly documented here - +# https://github.com/huggingface/accelerate/blob/ee163b66fb7848892519e804688cb4ae981aacbe/docs/source/package_reference/cli.md + +# type of compute environment, no need to change +compute_environment: LOCAL_MACHINE # AMAZON_SAGEMAKER + +# use FSDP distributed compute +distributed_type: FSDP + +# FSDP specific configurations +fsdp_config: + + # use this for training transformers + fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP + + # this controls the FSDP pipelining + fsdp_backward_prefetch_policy: BACKWARD_PRE # set to BACKWARD_PRE for the most time-efficient pipeline + # but requires the most memory. BACKWARD_POST is the less + # memory intensive option + + # setting this to true will increase forward memory by prefetching the next FSDP all-gather, while performing + # the current forward pass. + fsdp_forward_prefetch: false + + # setting this will offload model and optimizer parameters to the CPU, to save GPU memory at a significant + # increase of CPU time. + fsdp_offload_params: false + + fsdp_sharding_strategy: 1 # set to FULL_SHARD (1), SHARD_GRAD_OP (2), + # 3 is NO_SHARD, effectively disabling FSDP + # 4, 5 are HYBRID_ modes for multi-node training only. + + fsdp_state_dict_type: FULL_STATE_DICT # set to FULL_STATE_DICT (1), SHARDED_STATE_DICT (3) + # 2 is LOCAL_STATE_DICT where parameters are still flattened + # 3 is efficient, but requires know-how to use the shared checkpoint. + + fsdp_cpu_ram_efficient_loading: true # for large models set to true, model loaded on single process + fsdp_sync_module_states: true # for large models set to true, model loaded on single process + + # not needed for HF models that have . _no_split_modules + # the example below is for GPTBigCode + # fsdp_transformer_layer_cls_to_wrap: "GPTBigCodeBlock” + +# for "autocast" mixed precision training, where the weights of the model are kept at higher precision, but the +# learning products (e.g., gradients, model parameters) are kept at a lower precision. Default is 'no'. Other options +# would be fp16, bf16, etc. +mixed_precision: 'no' + +machine_rank: 0 # rank of the machine where accelerate is launched +num_machines: 1 +num_processes: 1 # default, override with --num_processes + +# the rendezvous method to use in distributed training. Other option is c10d +rdzv_backend: static +same_network: true + +# below arguments are required when training in multi-node setup +# for multi-gpu single node, the below values default to +# main_process_ip: 127.0.0.1 # override with --main_process_ip +# main_process_port: 29500 # override with --main_process_port diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 000000000..3bff676b5 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,58 @@ +[build-system] +build-backend = "setuptools.build_meta" +requires = [ + "setuptools>=60", + "setuptools-scm>=8.0" + ] + +[project] +name = "fms-hf-tuning" +dynamic = ["version"] +description = "FMS HF Tuning" +authors = [ + {name = "Sukriti Sharma", email = "sukriti.sharma4@ibm.com"}, + {name = "Anh Uong", email = "anh.uong@ibm.com"}, +] +license = {text = "Apache-2.0"} +readme = "README.md" +requires-python = "~=3.9" +keywords = ['fms-hf-tuning', 'python', 'tuning'] +classifiers=[ + "License :: OSI Approved :: Apache Software License", + "Development Status :: 4 - Beta", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", +] +dependencies = [ +"numpy", +"accelerate>=0.20.3", +"transformers", +"torch", +"sentencepiece", +"tokenizers>=0.13.3", +"tqdm", +"trl", +"peft>=0.8.0", +"datasets>=2.15.0", +"fire", +"simpleeval", +] + +[project.optional-dependencies] +dev = ["wheel", "packaging", "ninja", "scikit-learn>=1.0, <2.0"] +flash-attn = ["flash-attn"] +aim = ["aim==3.19.0"] + +[tool.setuptools.packages.find] +exclude = ["tests", "tests.*"] +namespaces = false + +[tool.setuptools_scm] +version_file = "tuning/_version.py" + +[project.urls] +Homepage = "https://github.com/foundation-model-stack/fms-hf-tuning" +Repository = "https://github.com/foundation-model-stack/fms-hf-tuning" +Issues = "https://github.com/foundation-model-stack/fms-hf-tuning/issues" diff --git a/pytest.ini b/pytest.ini new file mode 100644 index 000000000..f082af727 --- /dev/null +++ b/pytest.ini @@ -0,0 +1,4 @@ +[pytest] +# Register tests from `build` dir, removing `build` from norecursedirs default list, +# see https://doc.pytest.org/en/latest/reference/reference.html#confval-norecursedirs +norecursedirs = *.egg .* _darcs CVS dist node_modules venv {arch} diff --git a/requirements.txt b/requirements.txt index 04d84f2b3..747570fc8 100644 --- a/requirements.txt +++ b/requirements.txt @@ -11,5 +11,4 @@ ninja>=1.11.1.1,<2.0 peft>=0.8.0,<1.0 datasets>=2.15.0,<3.0 fire>=0.5.0,<1.0 -packaging>=23.2,<24 - +packaging>=23.2,<24 \ No newline at end of file diff --git a/scripts/run_evaluation.py b/scripts/run_evaluation.py new file mode 100644 index 000000000..dde162bb8 --- /dev/null +++ b/scripts/run_evaluation.py @@ -0,0 +1,480 @@ +"""Runs evaluation on Alpaca formatted data. + +Metrics used: Accuracy / Micro & Macro F1. +""" +# Standard +from shutil import rmtree +from typing import Any, Optional +import argparse +import json +import os + +# Third Party +from run_inference import TunedCausalLM +from sklearn import preprocessing +from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score +from tqdm import tqdm +import datasets +import numpy as np + + +def parse_and_validate_args(): + """Parse the arguments and ensure everything is valid.""" + parser = argparse.ArgumentParser( + description="Runs evaluation on a Alpaca style dataset" + ) + parser.add_argument( + "--model", help="Path to tuned model / merged model to be loaded", required=True + ) + parser.add_argument( + "--data_path", help="Path to the dataset to be loaded", required=True + ) + parser.add_argument( + "--split", help="Split to be used for the data", default="train" + ) + parser.add_argument( + "--max_new_tokens", + help="Max new tokens to use in generation", + type=int, + ) + parser.add_argument( + "--output_dir", + help="Directory path to export results to", + default="eval_results", + ) + parser.add_argument( + "--delimiter", + help="Delimiter to be used for multilabel multiclass evaluation", + default=None, + ) + parser.add_argument( + "--eos_token", + help="EOS token emitted by the model; will recursively remove the token if present", + ) + parser.add_argument( + "--use_instruction", + help="Indicates whether or not the instruction field should be used in formatting", + action="store_true", + ) + parser.add_argument("--purge_results", action=argparse.BooleanOptionalAction) + parser.add_argument( + "--use_flash_attn", + help="Whether to load the model using Flash Attention 2", + action="store_true", + ) + parsed_args = parser.parse_args() + + print(f"Multiclass / multioutput delimiter: {parsed_args.delimiter}") + # If we have a collision on the outdir, only remove the existing file if we explicitly say to + if os.path.exists(parsed_args.output_dir): + if parsed_args.purge_results: + print( + f"Existing output file/directory: [{parsed_args.output_dir}] will be deleted..." + ) + rmtree(parsed_args.output_dir) + else: + raise FileExistsError( + f"Output dir [{parsed_args.output_dir}] exists; use --purge_results to clobber it" + ) + return parsed_args + + +### Alpaca dataset formatting utilities +PROMPT_DICT = { + "prompt_input": ( + # pylint: disable=line-too-long + "Below is an instruction that describes a task, paired with an input that provides further context. " + "Write a response that appropriately completes the request.\n\n" + "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:" + ), + "prompt_no_input": ( + "Below is an instruction that describes a task. " + "Write a response that appropriately completes the request.\n\n" + "### Instruction:\n{instruction}\n\n### Response:" + ), +} + + +def get_formatted_example( + example: dict[str, str], use_instruction: bool +) -> dict[str, str]: + """Given a single example, format it based on whether or not we have an input provided. + + Args: + example: dict[str, str] + Dictionary containing the keys for instruction / input / output, i.e., Alpaca formatted + data. + use_instruction: bool + Indicates whether or not the instruction field will be used. + + Returns: + dict[str, str] + Dictionary containing the following: + "input" - the formatted text to run the prediction on. + "output" - the target text we aim to generate. + """ + # NOTE: Currently we ignore the instruction field due to the type of tasks we're tuning against + if use_instruction: + prompt_input, prompt_no_input = ( + PROMPT_DICT["prompt_input"], + PROMPT_DICT["prompt_no_input"], + ) + formatted_input = ( + prompt_input.format_map(example) + if example.get("input", "") != "" + else prompt_no_input.format_map(example) + ) + else: + formatted_input = f"Input: \n{example.get('input')}\n\n### Response:" + return { + # Text to run the prediction on + "input": formatted_input, + # Text to be generated (does not include the input str) + "output": example["output"], + } + + +### Model evaluation +def get_prediction_results( + model: TunedCausalLM, + data: datasets.arrow_dataset.Dataset, + max_new_tokens: int, + use_instruction: bool, + delimiter: Optional[str], + eos_token: Optional[str], +) -> tuple[list]: + """Runs the model over the alpaca formatted data to get the predictions / references to be used + when computing the metrics of interest. + + Args: + model: TunedCausalLM + Model to be used for evaliuation. + data: datasets.arrow_dataset.Dataset + HF dataset to be processed for evaluation. + max_new_tokens: int + Max number of tokens to be used for generation. + use_instruction: bool + Indicates whether or not the instruction field should be used. + delimiter: Optional[str] + Delimiter to be used for splitting apart multioutput instances. + eos_token: Optional[str] + EOS token emitted by the model, which will be recursively removed from predictions. + Returns: + tuple[list] + Tuple containing: + predictions [list of strings] + references [list of strings] + model_pred_info [list of dicts containing formatted data to be dumped later] + """ + preds = [] + refs = [] + model_pred_info = [] + for datum in tqdm(data): + # Format the alpaca example + formatted_datum = get_formatted_example(datum, use_instruction) + # Run the formatted text through the model, and only save the newly generated text strings + prediction = model.run( + formatted_datum["input"], + max_new_tokens=max_new_tokens, + ret_gen_text_only=True, + ) + # Save the raw output / predicted texts + processed_pred = postprocess_output(prediction, delimiter, eos_token) + # The reference text should not have an EOS to strip + processed_ref = postprocess_output(formatted_datum["output"], delimiter, None) + preds.append(processed_pred) + refs.append(processed_ref) + model_pred_info.append( + { + "formatted input": formatted_datum["input"], + "predicted target": processed_pred, + "ref target": processed_ref, + } + ) + return preds, refs, model_pred_info + + +def postprocess_output( + output_text: str, delimiter: Optional[str], eos_token: Optional[str] +) -> list[str]: + """NOTE: We are returning a list here, since that is what the one hot encoder module expects. + Args: + output_text: str + Raw text to be split into one or more (potentially) delimited instances. + delimiter: Optional[str] + Delimiter to be used for splitting apart multioutput instances. + delimiter: Optional[str] + Delimiter to be used for splitting apart multioutput instances. + + Returns + list[str] + List of one or more labels. + """ + if eos_token is not None: + while output_text.removesuffix(eos_token) != output_text: + output_text = output_text.removesuffix(eos_token) + if delimiter is not None: + return [text_substr.strip() for text_substr in output_text.split(delimiter)] + return [output_text.strip()] + + +### Metric computation/display & utils for mapping labels to numerics for sklearn +def map_predictions_and_references_to_encoded_vectors( + predictions_lists: list[list[str]], references_lists: list[list[str]] +) -> tuple[Any]: + """Maps the delimited text lists to lists of encoded vectors. + + Args: + predictions_lists: list[list[str]] + Delimited text lists for model predictions to be encoded. + references_lists: list[list[str]] + Ground truth delimited text lists to be encoded. + + Returns: + tuple[Any] + tuple containing: + pred_vectors [list of encoded 1D numpy arrays] + reference_vectors [list of encoded 1D numpy arrays] + label_map dict[str, str] - maps class indices to labels + """ + if not predictions_lists or not references_lists: + raise ValueError("Predictions and/or references should not be empty!") + ohe = preprocessing.OneHotEncoder() + # Extract the unique (potentially delimited labels) to fit the one hot encoder. We need to do + # this directly in case it's a multiclass/multilabel scenario, because the 2D arr consumed + # by the OHE expected consistent axis shapes, i.e., columns are treated as different features, + # and cannot have a variable number of values. + unk_label = "" + unique_labels = extract_unique_labels( + predictions_lists, references_lists, unk_label + ) + ohe.fit(unique_labels) + + # Now get the encoded vectors for our references and our predictions by one hot encoding + # theunique sublabels and collapsing them into one vector along the row dimension. + reference_vectors = [ + get_encoded_vector(ohe, refs, unk_label) for refs in references_lists + ] + pred_vectors = [ + get_encoded_vector(ohe, preds, unk_label) for preds in predictions_lists + ] + + # For debugging purposes - map the indices in our none hot encoded entries. + # NOTE: the categories_ attr is a 2D array of features, and we only care about [0] + # since the uniquely extracted labels are only single dim features when fitting + # the transform itself. + label_map = dict(enumerate(ohe.categories_[0])) + return pred_vectors, reference_vectors, label_map + + +def get_encoded_vector( + ohe: preprocessing.OneHotEncoder, texts: list[str], unk_label: str +) -> np.typing.NDArray: + """Get the encoded vector representing one or more generated texts by one hot encoding each + individual text and collapsing the result. + + Args: + ohe: preprocessing.OneHotEncoder + Sklearn one hot encoder to be used for one hot encoding all texts + (including the garbage class if we have one). + texts: list[str] + List of texts to be encoded and collapsed into one vector. + unk_label: str + Label to be used for garbage generations. + + Returns: + np.typing.NDArray + Binary vector encoding the list of texts as labels. + """ + # Since our encoded vector is built on collapsing one hot encoded vectors, + # we need to explicitly handle the empty case since it is not one hot encodable. + # raise ValueError(np.zeros(len(ohe.categories_[0])).dtype ) + if not texts: + return np.zeros(len(ohe.categories_[0])) + # Clean the generated text list; anything that is in the list that is not known to the + # one hot encoder gets replaced by the unk_label. It is okay if we have multiple unk_labels + # in the vector, since all of these just map to one positive entry in the encoded vector. + cleaned_texts = list( + {text if text in ohe.categories_[0] else unk_label for text in texts} + ) + + # Encode the cleaned text as a 2D feature array of one hot encoded vectors + vec_stack = ohe.transform([[text] for text in cleaned_texts]).toarray() + + # Then collapse the one hot encoded vectors along the column dimension to get + # get the encoded binary vector for the multilabel / multiclass prediction. + return vec_stack.sum(axis=0) + + +def extract_unique_labels( + preds: list[list[str]], refs: list[list[str]], unk_label: str +) -> list[list[str]]: + """Grab all of the unique labels and return them as a list of single feature lists. + Args: + preds: list[list[str]] + List of lists, where each sublist contains the stripped delimited substrings of a + single model prediction. + refs: list[list[str]] + List of lists, where each sublist contains the stripped delimited substrings of a + single ground truth reference. + unk_label: str + Label to be used for Unknown - this class is only created in evaluation if the + generative model predicts something that is not present in the ground truth refs. + + Returns: + list[list[str]] + List of single value lists, each of which contains a single label. + """ + unique_ref_labels = set() + for ref in refs: + for sub_label in ref: + # This is pretty unlikely to happen (class named ""), but for now, raise + # if we see it happen, since that will currently mess up the results a little bit. + if sub_label == unk_label: + raise ValueError( + f"Unk label {unk_label} is being used as a ground truth label!" + ) + unique_ref_labels.add(sub_label) + + ref_label_list = [[label] for label in unique_ref_labels] + # HACK - traverse the predictions and see if any unk predictions were made; if so, make a + # garbage class, which we will mark as false positives here. + for pred in preds: + for sub_pred in pred: + # One of our delimited predictions is unknown! + if sub_pred not in unique_ref_labels: + # Add the unk label once we know that it isn't a field in our eval data + print("Adding label to handle garbage label generation") + ref_label_list.append([unk_label]) + return ref_label_list + return ref_label_list + + +def compute_metrics_dict_multi( + enc_preds: list[np.typing.NDArray], enc_refs: list[np.typing.NDArray] +) -> dict[str, Any]: + """Calculate the metrics based on the encoded prediction and reference vector lists. + Current metrics: precision, recall f1, accuracy + + Args: + enc_preds: list[np.typing.NDArray] + List of encoded binary vectors for predictions from the model. + enc_refs: list[np.typing.NDArray] + List of encoded binary vectors for ground truth references. + + Returns: + dict[str, Any] + Dictionary of metrics. + """ + micro_f1 = f1_score(enc_refs, enc_preds, average="micro", zero_division=np.nan) + macro_f1 = f1_score(enc_refs, enc_preds, average="macro", zero_division=np.nan) + # For recall - the UNK class containing only false positives does NOT affect score. + micro_recall = recall_score( + enc_refs, enc_preds, average="micro", zero_division=np.nan + ) + macro_recall = recall_score( + enc_refs, enc_preds, average="macro", zero_division=np.nan + ) + micro_prec = precision_score( + enc_refs, enc_preds, average="micro", zero_division=np.nan + ) + macro_prec = precision_score( + enc_refs, enc_preds, average="macro", zero_division=np.nan + ) + # NOTE: For the multiclass / multilabel scenario, sklearn accuracy does NOT assign partial + # credit, i.e., instances are only considered correct if they match the ground truth + # encoded vectors exactly. + accuracy = accuracy_score(enc_refs, enc_preds) + return { + "f1": { + "micro": micro_f1, + "macro": macro_f1, + }, + "recall": { + "micro": micro_recall, + "macro": macro_recall, + }, + "precision": { + "micro": micro_prec, + "macro": macro_prec, + }, + "accuracy": accuracy, + } + + +def export_experiment_info( + metrics: dict[str, Any], + label_map: dict[str, str], + model_pred_info: list[dict[str, Any]], + metadata: dict[str, Any], + output_dir: str, +): + """Creates an exports all experiments info / metadata. + + Args: + metrics: dict[str, Any], + Dictionary containing metrics of interest (i.e., F1 / accuracy). + label_map: dict[str, str] + Mapping of class integers / labels. + model_pred_info: list[dict[str, Any]] + List of serializable dicts containing formatted data to be processed. + metadata: dict[str, Any] + Other experiment metadata of interest, e.g., model name, max new tokens, etc. + output_dir: str + Directory name to be created to hold the experiment files. + """ + os.mkdir(output_dir) + with open( + os.path.join(output_dir, "eval_metrics.json"), "w", encoding="utf-8" + ) as metrics_fp: + json.dump(metrics, metrics_fp, indent=4, sort_keys=True) + # Dump the label map to a file for debugging purposes + with open( + os.path.join(output_dir, "label_map.json"), "w", encoding="utf-8" + ) as map_fp: + json.dump(label_map, map_fp, indent=4, sort_keys=True) + # Also, dump the predictions / references info to a file for debugging purposes + with open( + os.path.join(output_dir, "preds_and_references.json"), "w", encoding="utf-8" + ) as preds_fp: + json.dump(model_pred_info, preds_fp, indent=4, sort_keys=True) + with open( + os.path.join(output_dir, "experiment_metadata.json"), "w", encoding="utf-8" + ) as exp_md_fp: + json.dump(metadata, exp_md_fp, indent=4, sort_keys=True) + + +if __name__ == "__main__": + args = parse_and_validate_args() + tuned_model = TunedCausalLM.load(args.model, use_flash_attn=args.use_flash_attn) + eval_data = datasets.load_dataset( + "json", data_files=args.data_path, split=args.split + ) + predictions, references, model_pred_file_info = get_prediction_results( + tuned_model, + eval_data, + args.max_new_tokens, + args.use_instruction, + args.delimiter, + args.eos_token, + ) + + ( + pred_vecs, + ref_vecs, + eval_label_map, + ) = map_predictions_and_references_to_encoded_vectors(predictions, references) + metrics_dict = compute_metrics_dict_multi(pred_vecs, ref_vecs) + experiment_metadata = { + "model": args.model, + "max_new_tokens": args.max_new_tokens, + "data_path": args.data_path, + } + export_experiment_info( + metrics_dict, + eval_label_map, + model_pred_file_info, + experiment_metadata, + args.output_dir, + ) + print(f"Exported results to: {args.output_dir}") diff --git a/scripts/run_inference.py b/scripts/run_inference.py index 989aaa8ca..70820049e 100644 --- a/scripts/run_inference.py +++ b/scripts/run_inference.py @@ -1,3 +1,16 @@ +# Copyright The FMS HF Tuning Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. """CLI for running loading a tuned model and running one or more inference calls on it. NOTE: For the moment, this script is intentionally written to contain all dependencies for two @@ -8,6 +21,7 @@ If these things change in the future, we should consider breaking it up. """ + # Standard import argparse import json @@ -77,11 +91,11 @@ def _apply_config_changes(self, overrides: dict) -> dict: # If we have no overrides, this context manager is a noop; no need to do anything if not overrides: return {} - with open(self.config_path, "r") as config_file: + with open(self.config_path, "r", encoding="utf-8") as config_file: adapter_config = json.load(config_file) overridden_values = self._get_old_config_values(adapter_config, overrides) adapter_config = {**adapter_config, **overrides} - with open(self.config_path, "w") as config_file: + with open(self.config_path, "w", encoding="utf-8") as config_file: json.dump(adapter_config, config_file, indent=4) return overridden_values @@ -128,7 +142,10 @@ def __init__(self, model, tokenizer, device): @classmethod def load( - cls, checkpoint_path: str, base_model_name_or_path: str = None + cls, + checkpoint_path: str, + base_model_name_or_path: str = None, + use_flash_attn: bool = False, ) -> "TunedCausalLM": """Loads an instance of this model. @@ -138,6 +155,8 @@ def load( adapter_config.json. base_model_name_or_path: str [Default: None] Override for the base model to be used. + use_flash_attn: bool [Default: False] + Whether to load the model using flash attention. By default, the paths for the base model and tokenizer are contained within the adapter config of the tuned model. Note that in this context, a path may refer to a model to be @@ -159,21 +178,33 @@ def load( try: with AdapterConfigPatcher(checkpoint_path, overrides): try: - model = AutoPeftModelForCausalLM.from_pretrained(checkpoint_path) + model = AutoPeftModelForCausalLM.from_pretrained( + checkpoint_path, + attn_implementation="flash_attention_2" + if use_flash_attn + else None, + torch_dtype=torch.bfloat16 if use_flash_attn else None, + ) except OSError as e: print("Failed to initialize checkpoint model!") raise e except FileNotFoundError: print("No adapter config found! Loading as a merged model...") # Unable to find the adapter config; fall back to loading as a merged model - model = AutoModelForCausalLM.from_pretrained(checkpoint_path) + model = AutoModelForCausalLM.from_pretrained( + checkpoint_path, + attn_implementation="flash_attention_2" if use_flash_attn else None, + torch_dtype=torch.bfloat16 if use_flash_attn else None, + ) device = "cuda" if torch.cuda.is_available() else None print(f"Inferred device: {device}") model.to(device) return cls(model, tokenizer, device) - def run(self, text: str, *, max_new_tokens: int) -> str: + def run( + self, text: str, *, max_new_tokens: int, ret_gen_text_only: bool = False + ) -> str: """Runs inference on an instance of this model. Args: @@ -181,6 +212,9 @@ def run(self, text: str, *, max_new_tokens: int) -> str: Text on which we want to run inference. max_new_tokens: int Max new tokens to use for inference. + ret_gen_text_only: bool + Indicates whether or not we should return the full text (i.e., input + new tokens) + or just the newly generated tokens. Returns: str @@ -192,8 +226,12 @@ def run(self, text: str, *, max_new_tokens: int) -> str: peft_outputs = self.peft_model.generate( input_ids=input_ids, max_new_tokens=max_new_tokens ) + if ret_gen_text_only: + tok_to_decode = peft_outputs[:, input_ids.shape[1] :] + else: + tok_to_decode = peft_outputs decoded_result = self.tokenizer.batch_decode( - peft_outputs, skip_special_tokens=False + tok_to_decode, skip_special_tokens=False )[0] return decoded_result @@ -213,7 +251,8 @@ def main(): ) parser.add_argument( "--base_model_name_or_path", - help="Override for base model to be used for non-merged models [default: value in model adapter_config.json]", + help="Override for base model to be used for non-merged models \ + [default: value in model adapter_config.json]", default=None, ) parser.add_argument( @@ -222,6 +261,11 @@ def main(): type=int, default=20, ) + parser.add_argument( + "--use_flash_attn", + help="Whether to load the model using Flash Attention 2", + action="store_true", + ) group = parser.add_mutually_exclusive_group(required=True) group.add_argument("--text", help="Text to run inference on") group.add_argument( @@ -237,13 +281,14 @@ def main(): loaded_model = TunedCausalLM.load( checkpoint_path=args.model, base_model_name_or_path=args.base_model_name_or_path, + use_flash_attn=args.use_flash_attn, ) # Run inference on the text; if multiple were provided, process them all if args.text: texts = [args.text] else: - with open(args.text_file, "r") as text_file: + with open(args.text_file, "r", encoding="utf-8") as text_file: texts = [line.strip() for line in text_file.readlines()] # TODO: we should add batch inference support @@ -256,7 +301,7 @@ def main(): ] # Export the results to a file - with open(args.out_file, "w") as out_file: + with open(args.out_file, "w", encoding="utf-8") as out_file: json.dump(results, out_file, sort_keys=True, indent=4) print(f"Exported results to: {args.out_file}") diff --git a/setup.py b/setup.py deleted file mode 100644 index ae71369c0..000000000 --- a/setup.py +++ /dev/null @@ -1,4 +0,0 @@ -# Third Party -from setuptools import find_packages, setup - -setup(name="tuning", version="0.0.1", packages=find_packages()) diff --git a/setup_requirements.txt b/setup_requirements.txt deleted file mode 100644 index 519d362a1..000000000 --- a/setup_requirements.txt +++ /dev/null @@ -1,4 +0,0 @@ -pre-commit>=3.0.4,<4.0 -pylint>=2.16.2,<4.0 -pydeps>=1.12.12,<2 -tox>=4.4.2,<5 \ No newline at end of file diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 000000000..38a9531ef --- /dev/null +++ b/tests/__init__.py @@ -0,0 +1,13 @@ +# Copyright The FMS HF Tuning Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/build/__init__.py b/tests/build/__init__.py new file mode 100644 index 000000000..38a9531ef --- /dev/null +++ b/tests/build/__init__.py @@ -0,0 +1,13 @@ +# Copyright The FMS HF Tuning Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/build/dummy_job_config.json b/tests/build/dummy_job_config.json new file mode 100644 index 000000000..315a5b527 --- /dev/null +++ b/tests/build/dummy_job_config.json @@ -0,0 +1,32 @@ +{ + "accelerate_launch_args": { + "use_fsdp": true, + "env": ["env1", "env2"], + "dynamo_use_dynamic": true, + "num_machines": 1, + "main_process_port": 1234, + "fsdp_backward_prefetch_policy": "TRANSFORMER_BASED_WRAP", + "fsdp_sharding_strategy": 1, + "fsdp_state_dict_type": "FULL_STATE_DICT", + "fsdp_cpu_ram_efficient_loading": true, + "fsdp_sync_module_states": true, + "config_file": "fixtures/accelerate_fsdp_defaults.yaml" + }, + "model_name_or_path": "bigscience/bloom-560m", + "training_data_path": "data/twitter_complaints_small.json", + "output_dir": "bloom-twitter", + "num_train_epochs": 5.0, + "per_device_train_batch_size": 4, + "per_device_eval_batch_size": 4, + "gradient_accumulation_steps": 4, + "learning_rate": 0.03, + "weight_decay": 0.000001, + "lr_scheduler_type": "cosine", + "logging_steps": 1.0, + "packing": false, + "include_tokens_per_second": true, + "response_template": "### Label:", + "dataset_text_field": "output", + "use_flash_attn": false, + "tokenizer_name_or_path": "bigscience/bloom-560m" + } \ No newline at end of file diff --git a/tests/build/test_utils.py b/tests/build/test_utils.py new file mode 100644 index 000000000..1bfaabba4 --- /dev/null +++ b/tests/build/test_utils.py @@ -0,0 +1,160 @@ +# Copyright The FMS HF Tuning Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Standard +import copy +import json +import os +from unittest.mock import patch + +# Third Party +import pytest + +# Local +from tuning.config.peft_config import LoraConfig, PromptTuningConfig +from build.utils import process_launch_training_args, process_accelerate_launch_args + +HAPPY_PATH_DUMMY_CONFIG_PATH = os.path.join( + os.path.dirname(__file__), "dummy_job_config.json" +) + + +# Note: job_config dict gets modified during process_launch_training_args +@pytest.fixture(name="job_config", scope="session") +def fixture_job_config(): + with open(HAPPY_PATH_DUMMY_CONFIG_PATH, "r", encoding="utf-8") as f: + dummy_job_config_dict = json.load(f) + return dummy_job_config_dict + + +def test_process_launch_training_args(job_config): + job_config_copy = copy.deepcopy(job_config) + ( + model_args, + data_args, + training_args, + tune_config, + merge_model, + _, + _, + ) = process_launch_training_args(job_config_copy) + assert str(model_args.torch_dtype) == "torch.bfloat16" + assert data_args.dataset_text_field == "output" + assert training_args.output_dir == "bloom-twitter" + assert tune_config is None + assert merge_model is False + + +def test_process_launch_training_args_defaults(job_config): + job_config_defaults = copy.deepcopy(job_config) + assert "torch_dtype" not in job_config_defaults + assert job_config_defaults["use_flash_attn"] is False + assert "save_strategy" not in job_config_defaults + model_args, _, training_args, _, _, _, _ = process_launch_training_args( + job_config_defaults + ) + assert str(model_args.torch_dtype) == "torch.bfloat16" + assert model_args.use_flash_attn is False + assert training_args.save_strategy.value == "epoch" + + +def test_process_launch_training_args_peft_method(job_config): + job_config_pt = copy.deepcopy(job_config) + job_config_pt["peft_method"] = "pt" + _, _, _, tune_config, merge_model, _, _ = process_launch_training_args( + job_config_pt + ) + assert isinstance(tune_config, PromptTuningConfig) + assert merge_model is False + + job_config_lora = copy.deepcopy(job_config) + job_config_lora["peft_method"] = "lora" + _, _, _, tune_config, merge_model, _, _ = process_launch_training_args( + job_config_lora + ) + assert isinstance(tune_config, LoraConfig) + assert merge_model is True + + +def test_process_accelerate_launch_args(job_config): + args = process_accelerate_launch_args(job_config) + # json config values used + assert args.use_fsdp is True + assert args.fsdp_backward_prefetch_policy == "TRANSFORMER_BASED_WRAP" + assert args.env == ["env1", "env2"] + assert args.training_script == "/app/launch_training.py" + assert args.config_file == "fixtures/accelerate_fsdp_defaults.yaml" + + # default values + assert args.tpu_use_cluster is False + assert args.mixed_precision is None + + +@patch("torch.cuda.device_count", return_value=1) +def test_accelerate_launch_args_user_set_num_processes_ignored(job_config): + job_config_copy = copy.deepcopy(job_config) + job_config_copy["accelerate_launch_args"]["num_processes"] = "3" + args = process_accelerate_launch_args(job_config_copy) + # determine number of processes by number of GPUs available + assert args.num_processes == 1 + + # if single-gpu, CUDA_VISIBLE_DEVICES set + assert os.getenv("CUDA_VISIBLE_DEVICES") == "0" + + +@patch.dict(os.environ, {"SET_NUM_PROCESSES_TO_NUM_GPUS": "False"}) +def test_accelerate_launch_args_user_set_num_processes(job_config): + job_config_copy = copy.deepcopy(job_config) + job_config_copy["accelerate_launch_args"]["num_processes"] = "3" + + args = process_accelerate_launch_args(job_config_copy) + # json config values used + assert args.num_processes == 3 + assert args.config_file == "fixtures/accelerate_fsdp_defaults.yaml" + + +def test_accelerate_launch_args_default_fsdp_config_multigpu(job_config): + with patch("torch.cuda.device_count", return_value=2): + with patch("os.path.exists", return_value=True): + job_config_copy = copy.deepcopy(job_config) + job_config_copy["accelerate_launch_args"].pop("config_file") + + assert "config_file" not in job_config_copy["accelerate_launch_args"] + + args = process_accelerate_launch_args(job_config_copy) + + # use default config file + assert args.config_file == "/app/accelerate_fsdp_defaults.yaml" + # determine number of processes by number of GPUs available + assert args.num_processes == 2 + + +@patch("os.path.exists") +def test_process_accelerate_launch_custom_config_file(patch_path_exists): + patch_path_exists.return_value = True + + dummy_config_path = "dummy_fsdp_config.yaml" + + # When user passes custom fsdp config file, use custom config and accelerate + # launch will use `num_processes` from config + temp_job_config = {"accelerate_launch_args": {"config_file": dummy_config_path}} + args = process_accelerate_launch_args(temp_job_config) + assert args.config_file == dummy_config_path + assert args.num_processes is None + + # When user passes custom fsdp config file and also `num_processes` as a param, + # use custom config and overwrite num_processes from config with param + temp_job_config = {"accelerate_launch_args": {"config_file": dummy_config_path}} + args = process_accelerate_launch_args(temp_job_config) + assert args.config_file == dummy_config_path diff --git a/tests/data/__init__.py b/tests/data/__init__.py new file mode 100644 index 000000000..6df7802cd --- /dev/null +++ b/tests/data/__init__.py @@ -0,0 +1,24 @@ +# Copyright The FMS HF Tuning Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Helpful datasets for configuring individual unit tests. +""" +# Standard +import os + +### Constants used for data +DATA_DIR = os.path.join(os.path.dirname(__file__)) +TWITTER_COMPLAINTS_DATA = os.path.join(DATA_DIR, "twitter_complaints_small.json") +EMPTY_DATA = os.path.join(DATA_DIR, "empty_data.json") +MALFORMATTED_DATA = os.path.join(DATA_DIR, "malformatted_data.json") diff --git a/tests/data/empty_data.json b/tests/data/empty_data.json new file mode 100644 index 000000000..e69de29bb diff --git a/tests/data/malformatted_data.json b/tests/data/malformatted_data.json new file mode 100644 index 000000000..437763095 --- /dev/null +++ b/tests/data/malformatted_data.json @@ -0,0 +1 @@ +This data is bad! We can't use it to tune. diff --git a/tests/data/trainercontroller/__init__.py b/tests/data/trainercontroller/__init__.py new file mode 100644 index 000000000..a18d746d2 --- /dev/null +++ b/tests/data/trainercontroller/__init__.py @@ -0,0 +1,61 @@ +# Copyright The IBM Tuning Team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Helpful datasets for configuring individual unit tests. +""" +# Standard +import os + +### Constants used for data +_DATA_DIR = os.path.join(os.path.dirname(__file__)) +TRAINER_CONFIG_TEST_LOSS_ON_THRESHOLD_YAML = os.path.join( + _DATA_DIR, "loss_on_threshold.yaml" +) +TRAINER_CONFIG_TEST_LOSS_ON_THRESHOLD_WITH_TRAINER_STATE_YAML = os.path.join( + _DATA_DIR, "loss_on_threshold_with_trainer_state.yaml" +) +TRAINER_CONFIG_EXPOSED_METRICS_YAML = os.path.join(_DATA_DIR, "exposed_metrics.yaml") +TRAINER_CONFIG_INCORRECT_SOURCE_EVENT_EXPOSED_METRICS_YAML = os.path.join( + _DATA_DIR, "incorrect_source_event_exposed_metrics.yaml" +) +TRAINER_CONFIG_TEST_INVALID_TYPE_RULE_YAML = os.path.join( + _DATA_DIR, "loss_with_invalid_type_rule.yaml" +) +TRAINER_CONFIG_TEST_MALICIOUS_OS_RULE_YAML = os.path.join( + _DATA_DIR, "loss_with_malicious_os_rule.yaml" +) +TRAINER_CONFIG_TEST_MALICIOUS_INPUT_RULE_YAML = os.path.join( + _DATA_DIR, "loss_with_malicious_input_rule.yaml" +) +TRAINER_CONFIG_TEST_INVALID_TRIGGER_YAML = os.path.join( + _DATA_DIR, "loss_invalid_trigger.yaml" +) +TRAINER_CONFIG_TEST_INVALID_OPERATION_YAML = os.path.join( + _DATA_DIR, "loss_invalid_operation.yaml" +) +TRAINER_CONFIG_TEST_INVALID_OPERATION_ACTION_YAML = os.path.join( + _DATA_DIR, "loss_invalid_operation_action.yaml" +) +TRAINER_CONFIG_TEST_INVALID_METRIC_YAML = os.path.join( + _DATA_DIR, "loss_invalid_metric.yaml" +) +TRAINER_CONFIG_TEST_CUSTOM_METRIC_YAML = os.path.join( + _DATA_DIR, "loss_custom_metric.yaml" +) +TRAINER_CONFIG_TEST_CUSTOM_OPERATION_YAML = os.path.join( + _DATA_DIR, "loss_custom_operation.yaml" +) +TRAINER_CONFIG_TEST_CUSTOM_OPERATION_INVALID_ACTION_YAML = os.path.join( + _DATA_DIR, "loss_custom_operation_invalid_action.yaml" +) diff --git a/tests/data/trainercontroller/exposed_metrics.yaml b/tests/data/trainercontroller/exposed_metrics.yaml new file mode 100644 index 000000000..6fef43d68 --- /dev/null +++ b/tests/data/trainercontroller/exposed_metrics.yaml @@ -0,0 +1,12 @@ +controller-metrics: + - name: evalmetric + class: EvalMetrics + arguments: + source-event: on_evaluate +controllers: + - name: loss-controller + triggers: + - on_evaluate + rule: evalmetric['eval_loss'] < 2.5 + operations: + - hfcontrols.should_training_stop \ No newline at end of file diff --git a/tests/data/trainercontroller/incorrect_source_event_exposed_metrics.yaml b/tests/data/trainercontroller/incorrect_source_event_exposed_metrics.yaml new file mode 100644 index 000000000..b507150d1 --- /dev/null +++ b/tests/data/trainercontroller/incorrect_source_event_exposed_metrics.yaml @@ -0,0 +1,12 @@ +controller-metrics: + - name: evalmetric + class: EvalMetrics + arguments: + source-event: on_incorrect_event +controllers: + - name: loss-controller + triggers: + - on_evaluate + rule: evalmetric['eval_loss'] < 2.5 + operations: + - hfcontrols.should_training_stop \ No newline at end of file diff --git a/tests/data/trainercontroller/loss_custom_metric.yaml b/tests/data/trainercontroller/loss_custom_metric.yaml new file mode 100644 index 000000000..fece59d9a --- /dev/null +++ b/tests/data/trainercontroller/loss_custom_metric.yaml @@ -0,0 +1,10 @@ +controller-metrics: + - name: testflag + class: CustomMetric +controllers: + - name: loss-controller-custom-metric + triggers: + - on_log + rule: testflag == True + operations: + - hfcontrols.should_training_stop \ No newline at end of file diff --git a/tests/data/trainercontroller/loss_custom_operation.yaml b/tests/data/trainercontroller/loss_custom_operation.yaml new file mode 100644 index 000000000..73737f8fb --- /dev/null +++ b/tests/data/trainercontroller/loss_custom_operation.yaml @@ -0,0 +1,13 @@ +controller-metrics: + - name: loss + class: Loss +operations: + - name: customoperation + class: CustomOperation +controllers: + - name: loss-controller-custom-operation + triggers: + - on_log + rule: loss < 1.0 + operations: + - customoperation.should_perform_action_xyz \ No newline at end of file diff --git a/tests/data/trainercontroller/loss_custom_operation_invalid_action.yaml b/tests/data/trainercontroller/loss_custom_operation_invalid_action.yaml new file mode 100644 index 000000000..80c07f296 --- /dev/null +++ b/tests/data/trainercontroller/loss_custom_operation_invalid_action.yaml @@ -0,0 +1,13 @@ +controller-metrics: + - name: loss + class: Loss +operations: + - name: customoperation + class: CustomOperationInvalidAction +controllers: + - name: loss-controller-custom-operation-invalid-action + triggers: + - on_log + rule: loss < 1.0 + operations: + - customoperation.should_ \ No newline at end of file diff --git a/tests/data/trainercontroller/loss_invalid_metric.yaml b/tests/data/trainercontroller/loss_invalid_metric.yaml new file mode 100644 index 000000000..f86de8f57 --- /dev/null +++ b/tests/data/trainercontroller/loss_invalid_metric.yaml @@ -0,0 +1,10 @@ +controller-metrics: + - name: loss + class: MissingMetricClass +controllers: + - name: loss-controller-invalid-metric + triggers: + - on_log + rule: loss < 1.0 + operations: + - hfcontrols.should_training_stop \ No newline at end of file diff --git a/tests/data/trainercontroller/loss_invalid_operation.yaml b/tests/data/trainercontroller/loss_invalid_operation.yaml new file mode 100644 index 000000000..65aaff263 --- /dev/null +++ b/tests/data/trainercontroller/loss_invalid_operation.yaml @@ -0,0 +1,10 @@ +controller-metrics: + - name: loss + class: Loss +controllers: + - name: loss-controller-invalid-operation + triggers: + - on_log + rule: loss < 1.0 + operations: + - missingop.should_training_stop \ No newline at end of file diff --git a/tests/data/trainercontroller/loss_invalid_operation_action.yaml b/tests/data/trainercontroller/loss_invalid_operation_action.yaml new file mode 100644 index 000000000..6f72b65ea --- /dev/null +++ b/tests/data/trainercontroller/loss_invalid_operation_action.yaml @@ -0,0 +1,10 @@ +controller-metrics: + - name: loss + class: Loss +controllers: + - name: loss-controller-invalid-operation-action + triggers: + - on_log + rule: loss < 1.0 + operations: + - hfcontrols.missingaction \ No newline at end of file diff --git a/tests/data/trainercontroller/loss_invalid_trigger.yaml b/tests/data/trainercontroller/loss_invalid_trigger.yaml new file mode 100644 index 000000000..5e509cbb9 --- /dev/null +++ b/tests/data/trainercontroller/loss_invalid_trigger.yaml @@ -0,0 +1,10 @@ +controller-metrics: + - name: loss + class: Loss +controllers: + - name: loss-controller-invalid-trigger + triggers: + - log_it_all_incorrect_trigger_name + rule: loss < 1.0 + operations: + - hfcontrols.should_training_stop \ No newline at end of file diff --git a/tests/data/trainercontroller/loss_on_threshold.yaml b/tests/data/trainercontroller/loss_on_threshold.yaml new file mode 100644 index 000000000..dd272d21c --- /dev/null +++ b/tests/data/trainercontroller/loss_on_threshold.yaml @@ -0,0 +1,10 @@ +controller-metrics: + - name: loss + class: Loss +controllers: + - name: loss-controller + triggers: + - on_log + rule: loss < 1.0 + operations: + - hfcontrols.should_training_stop \ No newline at end of file diff --git a/tests/data/trainercontroller/loss_on_threshold_with_trainer_state.yaml b/tests/data/trainercontroller/loss_on_threshold_with_trainer_state.yaml new file mode 100644 index 000000000..c40bb58b2 --- /dev/null +++ b/tests/data/trainercontroller/loss_on_threshold_with_trainer_state.yaml @@ -0,0 +1,12 @@ +controller-metrics: + - name: state + class: TrainingState + - name: loss + class: Loss +controllers: + - name: loss-controller + triggers: + - on_log + rule: loss < 2 and state["epoch"] >= 0.5 + operations: + - hfcontrols.should_training_stop \ No newline at end of file diff --git a/tests/data/trainercontroller/loss_with_invalid_type_rule.yaml b/tests/data/trainercontroller/loss_with_invalid_type_rule.yaml new file mode 100644 index 000000000..a2bd9e303 --- /dev/null +++ b/tests/data/trainercontroller/loss_with_invalid_type_rule.yaml @@ -0,0 +1,10 @@ +controller-metrics: + - name: loss + class: Loss +controllers: + - name: loss-controller-wrong-os-rule + triggers: + - on_log + rule: "2+2" + operations: + - hfcontrols.should_training_stop \ No newline at end of file diff --git a/tests/data/trainercontroller/loss_with_malicious_input_rule.yaml b/tests/data/trainercontroller/loss_with_malicious_input_rule.yaml new file mode 100644 index 000000000..a466675f6 --- /dev/null +++ b/tests/data/trainercontroller/loss_with_malicious_input_rule.yaml @@ -0,0 +1,10 @@ +controller-metrics: + - name: loss + class: Loss +controllers: + - name: loss-controller-wrong-input-rule + triggers: + - on_log + rule: input('Please enter your password:') + operations: + - hfcontrols.should_training_stop \ No newline at end of file diff --git a/tests/data/trainercontroller/loss_with_malicious_os_rule.yaml b/tests/data/trainercontroller/loss_with_malicious_os_rule.yaml new file mode 100644 index 000000000..3c32e61df --- /dev/null +++ b/tests/data/trainercontroller/loss_with_malicious_os_rule.yaml @@ -0,0 +1,10 @@ +controller-metrics: + - name: loss + class: Loss +controllers: + - name: loss-controller-wrong-os-rule + triggers: + - on_log + rule: __import__('os').system('clear') + operations: + - hfcontrols.should_training_stop \ No newline at end of file diff --git a/tests/data/twitter_complaints_small.json b/tests/data/twitter_complaints_small.json new file mode 100644 index 000000000..eb203d10d --- /dev/null +++ b/tests/data/twitter_complaints_small.json @@ -0,0 +1,10 @@ +{"Tweet text":"@HMRCcustomers No this is my first job","ID":0,"Label":2,"text_label":"no complaint","output":"### Text: @HMRCcustomers No this is my first job\n\n### Label: no complaint"} +{"Tweet text":"@KristaMariePark Thank you for your interest! If you decide to cancel, you can call Customer Care at 1-800-NYTIMES.","ID":1,"Label":2,"text_label":"no complaint","output":"### Text: @KristaMariePark Thank you for your interest! If you decide to cancel, you can call Customer Care at 1-800-NYTIMES.\n\n### Label: no complaint"} +{"Tweet text":"If I can't get my 3rd pair of @beatsbydre powerbeats to work today I'm doneski man. This is a slap in my balls. Your next @Bose @BoseService","ID":2,"Label":1,"text_label":"complaint","output":"### Text: If I can't get my 3rd pair of @beatsbydre powerbeats to work today I'm doneski man. This is a slap in my balls. Your next @Bose @BoseService\n\n### Label: complaint"} +{"Tweet text":"@EE On Rosneath Arial having good upload and download speeds but terrible latency 200ms. Why is this.","ID":3,"Label":1,"text_label":"complaint","output":"### Text: @EE On Rosneath Arial having good upload and download speeds but terrible latency 200ms. Why is this.\n\n### Label: complaint"} +{"Tweet text":"Couples wallpaper, so cute. :) #BrothersAtHome","ID":4,"Label":2,"text_label":"no complaint","output":"### Text: Couples wallpaper, so cute. :) #BrothersAtHome\n\n### Label: no complaint"} +{"Tweet text":"@mckelldogs This might just be me, but-- eyedrops? Artificial tears are so useful when you're sleep-deprived and sp\u2026 https:\/\/t.co\/WRtNsokblG","ID":5,"Label":2,"text_label":"no complaint","output":"### Text: @mckelldogs This might just be me, but-- eyedrops? Artificial tears are so useful when you're sleep-deprived and sp\u2026 https:\/\/t.co\/WRtNsokblG\n\n### Label: no complaint"} +{"Tweet text":"@Yelp can we get the exact calculations for a business rating (for example if its 4 stars but actually 4.2) or do we use a 3rd party site?","ID":6,"Label":2,"text_label":"no complaint","output":"### Text: @Yelp can we get the exact calculations for a business rating (for example if its 4 stars but actually 4.2) or do we use a 3rd party site?\n\n### Label: no complaint"} +{"Tweet text":"@nationalgridus I have no water and the bill is current and paid. Can you do something about this?","ID":7,"Label":1,"text_label":"complaint","output":"### Text: @nationalgridus I have no water and the bill is current and paid. Can you do something about this?\n\n### Label: complaint"} +{"Tweet text":"Never shopping at @MACcosmetics again. Every time I go in there, their employees are super rude\/condescending. I'll take my $$ to @Sephora","ID":8,"Label":1,"text_label":"complaint","output":"### Text: Never shopping at @MACcosmetics again. Every time I go in there, their employees are super rude\/condescending. I'll take my $$ to @Sephora\n\n### Label: complaint"} +{"Tweet text":"@JenniferTilly Merry Christmas to as well. You get more stunning every year \ufffd\ufffd","ID":9,"Label":2,"text_label":"no complaint","output":"### Text: @JenniferTilly Merry Christmas to as well. You get more stunning every year \ufffd\ufffd\n\n### Label: no complaint"} diff --git a/tests/helpers.py b/tests/helpers.py new file mode 100644 index 000000000..a88ae3ef8 --- /dev/null +++ b/tests/helpers.py @@ -0,0 +1,47 @@ +# Copyright The FMS HF Tuning Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Third Party +import transformers + +# Local +from tuning.config import configs, peft_config + + +def causal_lm_train_kwargs(train_kwargs): + """Parse the kwargs for a valid train call to a Causal LM.""" + parser = transformers.HfArgumentParser( + dataclass_types=( + configs.ModelArguments, + configs.DataArguments, + configs.TrainingArguments, + peft_config.LoraConfig, + peft_config.PromptTuningConfig, + ) + ) + ( + model_args, + data_args, + training_args, + lora_config, + prompt_tuning_config, + ) = parser.parse_dict(train_kwargs, allow_extra_keys=True) + return ( + model_args, + data_args, + training_args, + lora_config + if train_kwargs.get("peft_method") == "lora" + else prompt_tuning_config, + ) diff --git a/tests/test_sft_trainer.py b/tests/test_sft_trainer.py new file mode 100644 index 000000000..c23c7e2c5 --- /dev/null +++ b/tests/test_sft_trainer.py @@ -0,0 +1,509 @@ +# Copyright The FMS HF Tuning Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit Tests for SFT Trainer. +""" + +# Standard +import copy +import json +import os +import tempfile + +# Third Party +from datasets.exceptions import DatasetGenerationError +import pytest +import torch +import transformers + +# First Party +from scripts.run_inference import TunedCausalLM +from tests.data import EMPTY_DATA, MALFORMATTED_DATA, TWITTER_COMPLAINTS_DATA +from tests.helpers import causal_lm_train_kwargs + +# Local +from tuning import sft_trainer + +MODEL_NAME = "Maykeye/TinyLLama-v0" +BASE_PEFT_KWARGS = { + "model_name_or_path": MODEL_NAME, + "training_data_path": TWITTER_COMPLAINTS_DATA, + "num_train_epochs": 5, + "per_device_train_batch_size": 4, + "per_device_eval_batch_size": 4, + "gradient_accumulation_steps": 4, + "learning_rate": 0.00001, + "weight_decay": 0, + "warmup_ratio": 0.03, + "lr_scheduler_type": "cosine", + "logging_steps": 1, + "include_tokens_per_second": True, + "packing": False, + "response_template": "\n### Label:", + "dataset_text_field": "output", + "use_flash_attn": False, + "torch_dtype": "float32", + "max_seq_length": 4096, + "peft_method": "pt", + "prompt_tuning_init": "RANDOM", + "num_virtual_tokens": 8, + "prompt_tuning_init_text": "hello", + "tokenizer_name_or_path": MODEL_NAME, + "save_strategy": "epoch", + "output_dir": "tmp", +} + +BASE_LORA_KWARGS = copy.deepcopy(BASE_PEFT_KWARGS) +BASE_LORA_KWARGS["peft_method"] = "lora" + +BASE_FT_KWARGS = copy.deepcopy(BASE_PEFT_KWARGS) +BASE_FT_KWARGS["peft_method"] = "" +BASE_FT_KWARGS["prompt_tuning_init"] = "" +BASE_FT_KWARGS["prompt_tuning_init_text"] = "" + + +def test_helper_causal_lm_train_kwargs(): + """Check happy path kwargs passed and parsed properly.""" + model_args, data_args, training_args, tune_config = causal_lm_train_kwargs( + BASE_PEFT_KWARGS + ) + + assert model_args.model_name_or_path == MODEL_NAME + assert model_args.use_flash_attn is False + assert model_args.torch_dtype == "float32" + + assert data_args.training_data_path == TWITTER_COMPLAINTS_DATA + assert data_args.response_template == "\n### Label:" + assert data_args.dataset_text_field == "output" + + assert training_args.num_train_epochs == 5 + assert training_args.max_seq_length == 4096 + assert training_args.save_strategy == "epoch" + + assert tune_config.prompt_tuning_init == "RANDOM" + assert tune_config.prompt_tuning_init_text == "hello" + assert tune_config.tokenizer_name_or_path == MODEL_NAME + assert tune_config.num_virtual_tokens == 8 + + +def test_run_train_requires_output_dir(): + """Check fails when output dir not provided.""" + updated_output_dir = copy.deepcopy(BASE_PEFT_KWARGS) + updated_output_dir["output_dir"] = None + model_args, data_args, training_args, tune_config = causal_lm_train_kwargs( + updated_output_dir + ) + with pytest.raises(TypeError): + sft_trainer.train(model_args, data_args, training_args, tune_config) + + +def test_run_train_fails_training_data_path_not_exist(): + """Check fails when data path not found.""" + updated_output_path = copy.deepcopy(BASE_PEFT_KWARGS) + updated_output_path["training_data_path"] = "fake/path" + model_args, data_args, training_args, tune_config = causal_lm_train_kwargs( + updated_output_path + ) + with pytest.raises(FileNotFoundError): + sft_trainer.train(model_args, data_args, training_args, tune_config) + + +############################# Prompt Tuning Tests ############################# + + +def test_run_causallm_pt_and_inference(): + """Check if we can bootstrap and peft tune causallm models""" + with tempfile.TemporaryDirectory() as tempdir: + TRAIN_KWARGS = {**BASE_PEFT_KWARGS, **{"output_dir": tempdir}} + + model_args, data_args, training_args, tune_config = causal_lm_train_kwargs( + TRAIN_KWARGS + ) + sft_trainer.train(model_args, data_args, training_args, tune_config) + + # validate peft tuning configs + _validate_training(tempdir) + checkpoint_path = _get_checkpoint_path(tempdir) + adapter_config = _get_adapter_config(checkpoint_path) + _validate_adapter_config(adapter_config, "PROMPT_TUNING", BASE_PEFT_KWARGS) + + # Load the model + loaded_model = TunedCausalLM.load(checkpoint_path) + + # Run inference on the text + output_inference = loaded_model.run( + "### Text: @NortonSupport Thanks much.\n\n### Label:", max_new_tokens=50 + ) + assert len(output_inference) > 0 + assert "### Text: @NortonSupport Thanks much.\n\n### Label:" in output_inference + + +def test_run_causallm_pt_init_text(): + """Check if we can bootstrap and peft tune causallm models with init text as 'TEXT'""" + with tempfile.TemporaryDirectory() as tempdir: + TRAIN_KWARGS = { + **BASE_PEFT_KWARGS, + **{"output_dir": tempdir, "prompt_tuning_init": "TEXT"}, + } + model_args, data_args, training_args, tune_config = causal_lm_train_kwargs( + TRAIN_KWARGS + ) + sft_trainer.train(model_args, data_args, training_args, tune_config) + + # validate peft tuning configs + _validate_training(tempdir) + checkpoint_path = _get_checkpoint_path(tempdir) + adapter_config = _get_adapter_config(checkpoint_path) + _validate_adapter_config(adapter_config, "PROMPT_TUNING", TRAIN_KWARGS) + + +invalid_params_map = [ + ("num_train_epochs", 0, "num_train_epochs has to be an integer/float >= 1"), + ( + "gradient_accumulation_steps", + 0, + "gradient_accumulation_steps has to be an integer >= 1", + ), +] + + +@pytest.mark.parametrize( + "param_name,param_val,exc_msg", + invalid_params_map, + ids=["num_train_epochs", "grad_acc_steps"], +) +def test_run_causallm_pt_invalid_params(param_name, param_val, exc_msg): + """Check if error is raised when invalid params are used to peft tune causallm models""" + with tempfile.TemporaryDirectory() as tempdir: + invalid_params = copy.deepcopy(BASE_PEFT_KWARGS) + invalid_params["output_dir"] = tempdir + invalid_params[param_name] = param_val + model_args, data_args, training_args, tune_config = causal_lm_train_kwargs( + invalid_params + ) + + with pytest.raises(ValueError, match=exc_msg): + sft_trainer.train(model_args, data_args, training_args, tune_config) + + +def test_run_causallm_pt_with_validation(): + """Check if we can bootstrap and peft tune causallm models with validation dataset""" + with tempfile.TemporaryDirectory() as tempdir: + validation_peft = copy.deepcopy(BASE_PEFT_KWARGS) + validation_peft["output_dir"] = tempdir + validation_peft["validation_data_path"] = TWITTER_COMPLAINTS_DATA + validation_peft["evaluation_strategy"] = "epoch" + model_args, data_args, training_args, tune_config = causal_lm_train_kwargs( + validation_peft + ) + + assert data_args.validation_data_path == TWITTER_COMPLAINTS_DATA + + sft_trainer.train(model_args, data_args, training_args, tune_config) + _validate_training(tempdir, check_eval=True) + + +############################# Lora Tests ############################# + +target_modules_val_map = [ + (None, ["q_proj", "v_proj"]), + ( + ["q_proj", "k_proj", "v_proj", "o_proj"], + ["q_proj", "k_proj", "v_proj", "o_proj"], + ), + ( + ["all-linear"], + ["o_proj", "q_proj", "gate_proj", "down_proj", "k_proj", "up_proj", "v_proj"], + ), +] + + +@pytest.mark.parametrize( + "target_modules,expected", + target_modules_val_map, + ids=["default", "custom_target_modules", "all_linear_target_modules"], +) +def test_run_causallm_lora_and_inference(request, target_modules, expected): + """Check if we can bootstrap and lora tune causallm models""" + with tempfile.TemporaryDirectory() as tempdir: + base_lora_kwargs = copy.deepcopy(BASE_LORA_KWARGS) + base_lora_kwargs["output_dir"] = tempdir + if "default" not in request._pyfuncitem.callspec.id: + base_lora_kwargs["target_modules"] = target_modules + + model_args, data_args, training_args, tune_config = causal_lm_train_kwargs( + base_lora_kwargs + ) + sft_trainer.train(model_args, data_args, training_args, tune_config) + + # validate lora tuning configs + _validate_training(tempdir) + checkpoint_path = _get_checkpoint_path(tempdir) + adapter_config = _get_adapter_config(checkpoint_path) + _validate_adapter_config(adapter_config, "LORA", base_lora_kwargs) + + for module in expected: + assert module in adapter_config.get("target_modules") + + # Load the model + loaded_model = TunedCausalLM.load(checkpoint_path) + + # Run inference on the text + output_inference = loaded_model.run( + "Simply put, the theory of relativity states that ", max_new_tokens=50 + ) + assert len(output_inference) > 0 + assert "Simply put, the theory of relativity states that" in output_inference + + +############################# Finetuning Tests ############################# + + +def test_run_causallm_ft_and_inference(): + """Check if we can bootstrap and finetune tune causallm models""" + with tempfile.TemporaryDirectory() as tempdir: + BASE_FT_KWARGS["output_dir"] = tempdir + model_args, data_args, training_args, tune_config = causal_lm_train_kwargs( + BASE_FT_KWARGS + ) + sft_trainer.train(model_args, data_args, training_args, tune_config) + + # validate ft tuning configs + _validate_training(tempdir) + checkpoint_path = _get_checkpoint_path(tempdir) + adapter_config = _get_adapter_config(checkpoint_path) + _validate_adapter_config(adapter_config, "PROMPT_TUNING", BASE_FT_KWARGS) + + # Load the model + loaded_model = TunedCausalLM.load(checkpoint_path) + + # Run inference on the text + output_inference = loaded_model.run( + "### Text: @NortonSupport Thanks much.\n\n### Label:", max_new_tokens=50 + ) + assert len(output_inference) > 0 + assert "### Text: @NortonSupport Thanks much.\n\n### Label:" in output_inference + + +def _validate_training(tempdir, check_eval=False): + assert any(x.startswith("checkpoint-") for x in os.listdir(tempdir)) + train_logs_file_path = "{}/training_logs.jsonl".format(tempdir) + train_log_contents = "" + with open(train_logs_file_path, encoding="utf-8") as f: + train_log_contents = f.read() + + assert os.path.exists(train_logs_file_path) is True + assert os.path.getsize(train_logs_file_path) > 0 + assert "training_loss" in train_log_contents + + if check_eval: + assert "validation_loss" in train_log_contents + + +def _get_checkpoint_path(dir_path): + return os.path.join(dir_path, "checkpoint-5") + + +def _get_adapter_config(dir_path): + with open(os.path.join(dir_path, "adapter_config.json"), encoding="utf-8") as f: + return json.load(f) + + +def _validate_adapter_config(adapter_config, peft_type, base_kwargs): + assert adapter_config.get("task_type") == "CAUSAL_LM" + assert adapter_config.get("peft_type") == peft_type + assert ( + ( + adapter_config.get("tokenizer_name_or_path") + == base_kwargs["tokenizer_name_or_path"] + ) + if peft_type == "PROMPT_TUNING" + else True + ) + + +### Tests for a variety of edge cases and potentially problematic cases; +# some of these test directly test validation within external dependencies +# and validate errors that we expect to get from them which might be unintuitive. +# In such cases, it would probably be best for us to handle these things directly +# for better error messages, etc. + +### Tests related to tokenizer configuration +def test_tokenizer_has_no_eos_token(): + """Ensure that if the model has no EOS token, it sets the default before formatting.""" + # This is a bit roundabout, but patch the tokenizer and export it and the model to a tempdir + # that we can then reload out of for the train call, and clean up afterwards. + tokenizer = transformers.AutoTokenizer.from_pretrained( + BASE_PEFT_KWARGS["model_name_or_path"] + ) + model = transformers.AutoModelForCausalLM.from_pretrained( + BASE_PEFT_KWARGS["model_name_or_path"] + ) + tokenizer.eos_token = None + with tempfile.TemporaryDirectory() as tempdir: + tokenizer.save_pretrained(tempdir) + model.save_pretrained(tempdir) + TRAIN_KWARGS = { + **BASE_PEFT_KWARGS, + **{"model_name_or_path": tempdir, "output_dir": tempdir}, + } + model_args, data_args, training_args, tune_config = causal_lm_train_kwargs( + TRAIN_KWARGS + ) + # If we handled this badly, we would probably get something like a + # TypeError: can only concatenate str (not "NoneType") to str error + # when we go to apply the data formatter. + sft_trainer.train(model_args, data_args, training_args, tune_config) + _validate_training(tempdir) + + +### Tests for Bad dataset specification, i.e., data is valid, but the field we point it at isn't +def test_invalid_dataset_text_field(): + """Ensure that if we specify a dataset_text_field that doesn't exist, we get a KeyError.""" + TRAIN_KWARGS = { + **BASE_PEFT_KWARGS, + **{"dataset_text_field": "not found", "output_dir": "foo/bar/baz"}, + } + model_args, data_args, training_args, tune_config = causal_lm_train_kwargs( + TRAIN_KWARGS + ) + with pytest.raises(KeyError): + sft_trainer.train(model_args, data_args, training_args, tune_config) + + +### Tests for bad training data (i.e., data_path is an unhappy value or points to an unhappy thing) +def test_malformatted_data(): + """Ensure that malformatted data explodes due to failure to generate the dataset.""" + TRAIN_KWARGS = { + **BASE_PEFT_KWARGS, + **{"training_data_path": MALFORMATTED_DATA, "output_dir": "foo/bar/baz"}, + } + model_args, data_args, training_args, tune_config = causal_lm_train_kwargs( + TRAIN_KWARGS + ) + with pytest.raises(DatasetGenerationError): + sft_trainer.train(model_args, data_args, training_args, tune_config) + + +def test_empty_data(): + """Ensure that malformatted data explodes due to failure to generate the dataset.""" + TRAIN_KWARGS = { + **BASE_PEFT_KWARGS, + **{"training_data_path": EMPTY_DATA, "output_dir": "foo/bar/baz"}, + } + model_args, data_args, training_args, tune_config = causal_lm_train_kwargs( + TRAIN_KWARGS + ) + with pytest.raises(DatasetGenerationError): + sft_trainer.train(model_args, data_args, training_args, tune_config) + + +def test_data_path_is_a_directory(): + """Ensure that we get FileNotFoundError if we point the data path at a dir, not a file.""" + with tempfile.TemporaryDirectory() as tempdir: + TRAIN_KWARGS = { + **BASE_PEFT_KWARGS, + **{"training_data_path": tempdir, "output_dir": tempdir}, + } + model_args, data_args, training_args, tune_config = causal_lm_train_kwargs( + TRAIN_KWARGS + ) + # Confusingly, if we pass a directory for our data path, it will throw a + # FileNotFoundError saying "unable to find ''", since it can't + # find a matchable file in the path. + with pytest.raises(FileNotFoundError): + sft_trainer.train(model_args, data_args, training_args, tune_config) + + +### Tests for bad tuning module configurations +def test_run_causallm_lora_with_invalid_modules(): + """Check that we throw a value error if the target modules for lora don't exist.""" + with tempfile.TemporaryDirectory() as tempdir: + TRAIN_KWARGS = { + **BASE_PEFT_KWARGS, + **{"peft_method": "lora", "output_dir": tempdir}, + } + model_args, data_args, training_args, tune_config = causal_lm_train_kwargs( + TRAIN_KWARGS + ) + # Defaults are q_proj / v_proj; this will fail lora as the torch module doesn't have them + tune_config.target_modules = ["foo", "bar"] + # Peft should throw a value error about modules not matching the base module + with pytest.raises(ValueError): + sft_trainer.train(model_args, data_args, training_args, tune_config) + + +### Direct validation tests based on whether or not packing is enabled +def test_no_packing_needs_dataset_text_field(): + """Ensure we need to set the dataset text field if packing is False""" + with tempfile.TemporaryDirectory() as tempdir: + TRAIN_KWARGS = { + **BASE_PEFT_KWARGS, + **{"dataset_text_field": None, "output_dir": tempdir}, + } + model_args, data_args, training_args, tune_config = causal_lm_train_kwargs( + TRAIN_KWARGS + ) + with pytest.raises(ValueError): + sft_trainer.train(model_args, data_args, training_args, tune_config) + + +# TODO: Fix this case +@pytest.mark.skip(reason="currently crashes before validation is done") +def test_no_packing_needs_reponse_template(): + """Ensure we need to set the response template if packing is False""" + with tempfile.TemporaryDirectory() as tempdir: + TRAIN_KWARGS = { + **BASE_PEFT_KWARGS, + **{"response_template": None, "output_dir": tempdir}, + } + model_args, data_args, training_args, tune_config = causal_lm_train_kwargs( + TRAIN_KWARGS + ) + with pytest.raises(ValueError): + sft_trainer.train(model_args, data_args, training_args, tune_config) + + +### Tests for model dtype edge cases +@pytest.mark.skipif( + not (torch.cuda.is_available() and torch.cuda.is_bf16_supported()), + reason="Only runs if bf16 is unsupported", +) +def test_bf16_still_tunes_if_unsupported(): + """Ensure that even if bf16 is not supported, tuning still works without problems.""" + assert not torch.cuda.is_bf16_supported() + with tempfile.TemporaryDirectory() as tempdir: + TRAIN_KWARGS = { + **BASE_PEFT_KWARGS, + **{"torch_dtype": "bfloat16", "output_dir": tempdir}, + } + model_args, data_args, training_args, tune_config = causal_lm_train_kwargs( + TRAIN_KWARGS + ) + sft_trainer.train(model_args, data_args, training_args, tune_config) + _validate_training(tempdir) + + +def test_bad_torch_dtype(): + """Ensure that specifying an invalid torch dtype yields a ValueError.""" + with tempfile.TemporaryDirectory() as tempdir: + TRAIN_KWARGS = { + **BASE_PEFT_KWARGS, + **{"torch_dtype": "not a type", "output_dir": tempdir}, + } + model_args, data_args, training_args, tune_config = causal_lm_train_kwargs( + TRAIN_KWARGS + ) + with pytest.raises(ValueError): + sft_trainer.train(model_args, data_args, training_args, tune_config) diff --git a/tests/trainercontroller/custom_metric.py b/tests/trainercontroller/custom_metric.py new file mode 100644 index 000000000..83b6acc53 --- /dev/null +++ b/tests/trainercontroller/custom_metric.py @@ -0,0 +1,59 @@ +# Copyright The IBM Tuning Team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# SPDX-License-Identifier: Apache-2.0 +# https://spdx.dev/learn/handling-license-info/ + +# Standard +from dataclasses import dataclass +from typing import Any + +# Third Party +from transformers import TrainerState +import pytest + +# Local +from tuning.trainercontroller.controllermetrics.metricshandler import MetricHandler + + +class CustomMetric(MetricHandler): + """Implements a custom metric for testing""" + + def __init__(self, **kwargs): + """Initializes the metric handler, by registering the event list and arguments with base handler. + + Args: + kwargs: List of arguments (key, value)-pairs + """ + super().__init__(events=["on_log"], **kwargs) + + def validate(self) -> bool: + """Validate the training arguments (e.g logging_steps) are compatible with the computation of this metric. + + Returns: + bool + """ + return True + + def compute(self, state: TrainerState = None, **kwargs) -> Any: + """Just returns True (for testing purposes only). + + Args: + state: TrainerState object + kwargs: Remaining event arguments + + Returns: + Any. The exposed variables are returned here. + """ + return True diff --git a/tests/trainercontroller/custom_operation.py b/tests/trainercontroller/custom_operation.py new file mode 100644 index 000000000..b09ff91de --- /dev/null +++ b/tests/trainercontroller/custom_operation.py @@ -0,0 +1,47 @@ +# Copyright The IBM Tuning Team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# SPDX-License-Identifier: Apache-2.0 +# https://spdx.dev/learn/handling-license-info/ + +# Standard +from dataclasses import dataclass +from typing import Any + +# Third Party +from transformers import TrainerControl, TrainerState +import pytest + +# Local +from tuning.trainercontroller.operations import Operation + + +class CustomOperation(Operation): + """Implements a custom operation for testing""" + + def __init__(self, **kwargs): + """Initializes the custom operation class. + Args: + kwargs: List of arguments (key, value)-pairs + """ + super().__init__() + + def should_perform_action_xyz(self, control: TrainerControl, **kwargs): + """This method performs a set training stop flag action. + + Args: + control: TrainerControl. Data class for controls. + kwargs: List of arguments (key, value)-pairs + """ + control.should_training_stop = True diff --git a/tests/trainercontroller/custom_operation_invalid_action.py b/tests/trainercontroller/custom_operation_invalid_action.py new file mode 100644 index 000000000..29b447bef --- /dev/null +++ b/tests/trainercontroller/custom_operation_invalid_action.py @@ -0,0 +1,47 @@ +# Copyright The IBM Tuning Team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# SPDX-License-Identifier: Apache-2.0 +# https://spdx.dev/learn/handling-license-info/ + +# Standard +from dataclasses import dataclass +from typing import Any + +# Third Party +from transformers import TrainerControl, TrainerState +import pytest + +# Local +from tuning.trainercontroller.operations import Operation + + +class CustomOperationInvalidAction(Operation): + """Implements a custom operation for testing""" + + def __init__(self, **kwargs): + """Initializes the custom operation class. + Args: + kwargs: List of arguments (key, value)-pairs + """ + super().__init__() + + def should_(self, control: TrainerControl, **kwargs): + """This method defines an action within an invalid name. + + Args: + control: TrainerControl. Data class for controls. + kwargs: List of arguments (key, value)-pairs + """ + control.should_training_stop = True diff --git a/tests/trainercontroller/test_tuning_trainercontroller.py b/tests/trainercontroller/test_tuning_trainercontroller.py new file mode 100644 index 000000000..c572a9c3f --- /dev/null +++ b/tests/trainercontroller/test_tuning_trainercontroller.py @@ -0,0 +1,354 @@ +# Copyright The IBM Tuning Team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# SPDX-License-Identifier: Apache-2.0 +# https://spdx.dev/learn/handling-license-info/ + +# Standard +from dataclasses import dataclass + +# Third Party +from simpleeval import FunctionNotDefined +from transformers import IntervalStrategy, TrainerControl, TrainerState +import pytest + +# First Party +from tests.trainercontroller.custom_metric import CustomMetric +from tests.trainercontroller.custom_operation import CustomOperation +from tests.trainercontroller.custom_operation_invalid_action import ( + CustomOperationInvalidAction, +) +import tests.data.trainercontroller as td + +# Local +import tuning.config.configs as config +import tuning.trainercontroller as tc + + +@dataclass +class InputData: + """Stores the operation handler instance and corresponding action""" + + args: config.TrainingArguments + state: TrainerState + + +def _setup_data() -> InputData: + """ + Sets up the test data for the test cases. This includes the logs, arguments for training and state + of the training. + + Returns: + InputData. + """ + # Test data to mimic the fields of trainer loop log-lines + # trainer arguments and the initial state + return InputData( + args=config.TrainingArguments( + output_dir="", + logging_strategy=IntervalStrategy.STEPS, + logging_steps=1, + ), + state=TrainerState( + log_history=[ + {"loss": 2.0, "epoch": 0.1}, + {"loss": 2.1, "epoch": 0.25}, + {"loss": 1.3, "epoch": 0.5}, + {"loss": 0.9, "epoch": 0.6}, + ], + epoch=0.6, + ), + ) + + +def test_loss_on_threshold(): + """Tests the loss threshold example in + `examples/trainer-controller-configs/loss_on_threshold.yaml` + """ + test_data = _setup_data() + tc_callback = tc.TrainerControllerCallback( + td.TRAINER_CONFIG_TEST_LOSS_ON_THRESHOLD_YAML + ) + control = TrainerControl(should_training_stop=False) + # Trigger on_init_end to perform registration of handlers to events + tc_callback.on_init_end(args=test_data.args, state=test_data.state, control=control) + # Trigger rule and test the condition + tc_callback.on_log(args=test_data.args, state=test_data.state, control=control) + assert control.should_training_stop == True + + +def test_loss_on_threshold_with_trainer_state(): + """Tests the loss threshold with trainer state example in + `examples/trainer-controller-configs/loss_on_threshold_with_trainer_state.yaml` + """ + test_data = _setup_data() + tc_callback = tc.TrainerControllerCallback( + td.TRAINER_CONFIG_TEST_LOSS_ON_THRESHOLD_WITH_TRAINER_STATE_YAML + ) + control = TrainerControl(should_training_stop=False) + # Trigger on_init_end to perform registration of handlers to events + tc_callback.on_init_end(args=test_data.args, state=test_data.state, control=control) + # Trigger rule and test the condition + tc_callback.on_log(args=test_data.args, state=test_data.state, control=control) + + +def test_exposed_metrics(): + """Tests the expose metric scenario example in + `examples/trainer-controller-configs/exposed_metrics.yaml` + """ + test_data = _setup_data() + tc_callback = tc.TrainerControllerCallback(td.TRAINER_CONFIG_EXPOSED_METRICS_YAML) + control = TrainerControl(should_training_stop=False) + metrics = {"eval_loss": 2.2} + # Trigger on_init_end to perform registration of handlers to events + tc_callback.on_init_end(args=test_data.args, state=test_data.state, control=control) + # Trigger rule and test the condition + tc_callback.on_evaluate( + args=test_data.args, state=test_data.state, control=control, metrics=metrics + ) + assert control.should_training_stop == True + + +def test_incorrect_source_event_exposed_metrics(): + """Tests the expose metric scenario example in + `examples/trainer-controller-configs/incorrect_source_event_exposed_metrics.yaml` + """ + with pytest.raises(ValueError) as exception_handler: + test_data = _setup_data() + tc_callback = tc.TrainerControllerCallback( + td.TRAINER_CONFIG_INCORRECT_SOURCE_EVENT_EXPOSED_METRICS_YAML + ) + control = TrainerControl(should_training_stop=False) + metrics = {"eval_loss": 2.2} + # Trigger on_init_end to perform registration of handlers to events + tc_callback.on_init_end( + args=test_data.args, state=test_data.state, control=control + ) + # Trigger rule and test the condition + tc_callback.on_evaluate( + args=test_data.args, state=test_data.state, control=control, metrics=metrics + ) + assert ( + str(exception_handler.value).strip("'") + == "Specified source event [on_incorrect_event] is invalid for EvalMetrics" + ) + assert control.should_training_stop == True + + +def test_custom_metric_handler(): + """Tests the custom metric registration + `examples/trainer-controller-configs/loss_custom_metric.yaml` + """ + test_data = _setup_data() + tc_callback = tc.TrainerControllerCallback( + td.TRAINER_CONFIG_TEST_CUSTOM_METRIC_YAML + ) + tc_callback.register_metric_handlers([CustomMetric]) + control = TrainerControl() + # Trigger on_init_end to perform registration of handlers to events + tc_callback.on_init_end(args=test_data.args, state=test_data.state, control=control) + # Trigger rule and test the condition + tc_callback.on_log(args=test_data.args, state=test_data.state, control=control) + assert control.should_training_stop == True + + +def test_custom_operation_handler(): + """Tests the custom operation registration + `examples/trainer-controller-configs/loss_custom_operation.yaml` + """ + test_data = _setup_data() + tc_callback = tc.TrainerControllerCallback( + td.TRAINER_CONFIG_TEST_CUSTOM_OPERATION_YAML + ) + tc_callback.register_operation_handlers([CustomOperation]) + control = TrainerControl(should_training_stop=False) + # Trigger on_init_end to perform registration of handlers to events + tc_callback.on_init_end(args=test_data.args, state=test_data.state, control=control) + # Trigger rule and test the condition + tc_callback.on_log(args=test_data.args, state=test_data.state, control=control) + assert control.should_training_stop == True + + +def test_custom_operation_invalid_action_handler(): + """Tests the registration of custom operation with an invalid action. Uses: + `examples/trainer-controller-configs/loss_custom_operation_invalid_action.yaml` + """ + test_data = _setup_data() + with pytest.raises(KeyError) as exception_handler: + tc_callback = tc.TrainerControllerCallback( + td.TRAINER_CONFIG_TEST_CUSTOM_OPERATION_INVALID_ACTION_YAML + ) + tc_callback.register_operation_handlers([CustomOperationInvalidAction]) + control = TrainerControl(should_training_stop=False) + # Trigger on_init_end to perform registration of handlers to events + tc_callback.on_init_end( + args=test_data.args, state=test_data.state, control=control + ) + # Trigger rule and test the condition + tc_callback.on_log(args=test_data.args, state=test_data.state, control=control) + assert ( + str(exception_handler.value).strip("'") + == "Invalid operation customoperation.should_ for control loss-controller-custom-operation-invalid-action" + ) + + +def test_invalid_type_rule(): + """Tests the invalid type rule using configuration + `examples/trainer-controller-configs/loss_with_invalid_type_rule.yaml` + """ + test_data = _setup_data() + with pytest.raises(TypeError) as exception_handler: + tc_callback = tc.TrainerControllerCallback( + td.TRAINER_CONFIG_TEST_INVALID_TYPE_RULE_YAML + ) + control = TrainerControl(should_training_stop=False) + # Trigger on_init_end to perform registration of handlers to events + tc_callback.on_init_end( + args=test_data.args, state=test_data.state, control=control + ) + # Trigger rule and test the condition + tc_callback.on_log(args=test_data.args, state=test_data.state, control=control) + assert str(exception_handler.value) == "Rule failed due to incorrect type usage" + + +def test_malicious_os_rule(): + """Tests the malicious rule using configuration + `examples/trainer-controller-configs/loss_with_malicious_os_rule.yaml` + """ + test_data = _setup_data() + with pytest.raises(ValueError) as exception_handler: + tc_callback = tc.TrainerControllerCallback( + td.TRAINER_CONFIG_TEST_MALICIOUS_OS_RULE_YAML + ) + control = TrainerControl(should_training_stop=False) + # Trigger on_init_end to perform registration of handlers to events + tc_callback.on_init_end( + args=test_data.args, state=test_data.state, control=control + ) + # Trigger rule and test the condition + tc_callback.on_log(args=test_data.args, state=test_data.state, control=control) + assert ( + str(exception_handler.value) + == "Rule for control loss-controller-wrong-os-rule is invalid" + ) + + +def test_malicious_input_rule(): + """Tests the malicious rule using configuration + `examples/trainer-controller-configs/loss_with_malicious_input_rule.yaml` + """ + test_data = _setup_data() + tc_callback = tc.TrainerControllerCallback( + td.TRAINER_CONFIG_TEST_MALICIOUS_INPUT_RULE_YAML + ) + control = TrainerControl(should_training_stop=False) + with pytest.raises(FunctionNotDefined) as exception_handler: + # Trigger on_init_end to perform registration of handlers to events + tc_callback.on_init_end( + args=test_data.args, state=test_data.state, control=control + ) + # Trigger rule and test the condition + tc_callback.on_log(args=test_data.args, state=test_data.state, control=control) + assert ( + str(exception_handler.value) + == "Function 'input' not defined, for expression 'input('Please enter your password:')'." + ) + + +def test_invalid_trigger(): + """Tests the invalid trigger scenario in the controller. Uses: + `examples/trainer-controller-configs/loss_invalid_trigger.yaml` + """ + test_data = _setup_data() + with pytest.raises(KeyError) as exception_handler: + tc_callback = tc.TrainerControllerCallback( + td.TRAINER_CONFIG_TEST_INVALID_TRIGGER_YAML + ) + control = TrainerControl(should_training_stop=False) + # Trigger on_init_end to perform registration of handlers to events + tc_callback.on_init_end( + args=test_data.args, state=test_data.state, control=control + ) + # Trigger rule and test the condition + tc_callback.on_log(args=test_data.args, state=test_data.state, control=control) + assert ( + str(exception_handler.value).strip("'") + == "Controller loss-controller-invalid-trigger has an invalid event (log_it_all_incorrect_trigger_name)" + ) + + +def test_invalid_operation(): + """Tests the invalid operation scenario in the controller. Uses: + `examples/trainer-controller-configs/loss_invalid_operation.yaml` + """ + test_data = _setup_data() + with pytest.raises(KeyError) as exception_handler: + tc_callback = tc.TrainerControllerCallback( + td.TRAINER_CONFIG_TEST_INVALID_OPERATION_YAML + ) + control = TrainerControl(should_training_stop=False) + # Trigger on_init_end to perform registration of handlers to events + tc_callback.on_init_end( + args=test_data.args, state=test_data.state, control=control + ) + # Trigger rule and test the condition + tc_callback.on_log(args=test_data.args, state=test_data.state, control=control) + assert ( + str(exception_handler.value).strip("'") + == "Invalid operation missingop.should_training_stop for control loss-controller-invalid-operation" + ) + + +def test_invalid_operation_action(): + """Tests the invalid operation action scenario in the controller. Uses: + `examples/trainer-controller-configs/loss_invalid_operation_action.yaml` + """ + test_data = _setup_data() + with pytest.raises(KeyError) as exception_handler: + tc_callback = tc.TrainerControllerCallback( + td.TRAINER_CONFIG_TEST_INVALID_OPERATION_ACTION_YAML + ) + control = TrainerControl(should_training_stop=False) + # Trigger on_init_end to perform registration of handlers to events + tc_callback.on_init_end( + args=test_data.args, state=test_data.state, control=control + ) + # Trigger rule and test the condition + tc_callback.on_log(args=test_data.args, state=test_data.state, control=control) + assert ( + str(exception_handler.value).strip("'") + == "Invalid operation hfcontrols.missingaction for control loss-controller-invalid-operation-action" + ) + + +def test_invalid_metric(): + """Tests the invalid metric scenario in the controller. Uses: + `examples/trainer-controller-configs/loss_invalid_metric.yaml` + """ + test_data = _setup_data() + with pytest.raises(KeyError) as exception_handler: + tc_callback = tc.TrainerControllerCallback( + td.TRAINER_CONFIG_TEST_INVALID_METRIC_YAML + ) + control = TrainerControl(should_training_stop=False) + # Trigger on_init_end to perform registration of handlers to events + tc_callback.on_init_end( + args=test_data.args, state=test_data.state, control=control + ) + # Trigger rule and test the condition + tc_callback.on_log(args=test_data.args, state=test_data.state, control=control) + assert ( + str(exception_handler.value).strip("'") + == "Undefined metric handler MissingMetricClass" + ) diff --git a/tests/utils/__init__.py b/tests/utils/__init__.py new file mode 100644 index 000000000..38a9531ef --- /dev/null +++ b/tests/utils/__init__.py @@ -0,0 +1,13 @@ +# Copyright The FMS HF Tuning Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/utils/test_data_type_utils.py b/tests/utils/test_data_type_utils.py new file mode 100644 index 000000000..51bb35b3e --- /dev/null +++ b/tests/utils/test_data_type_utils.py @@ -0,0 +1,49 @@ +# Copyright The FMS HF Tuning Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# SPDX-License-Identifier: Apache-2.0 +# https://spdx.dev/learn/handling-license-info/ + +# Third Party +import pytest +import torch + +# Local +from tuning.utils import data_type_utils + +dtype_dict = { + "bool": torch.bool, + "double": torch.double, + "float32": torch.float32, + "int64": torch.int64, + "long": torch.long, +} + + +def test_str_to_torch_dtype(): + for t in dtype_dict: + assert data_type_utils.str_to_torch_dtype(t) == dtype_dict.get(t) + + +def test_str_to_torch_dtype_exit(): + with pytest.raises(ValueError): + data_type_utils.str_to_torch_dtype("foo") + + +def test_get_torch_dtype(): + for t in dtype_dict: + # When passed a string, it gets converted to torch.dtype + assert data_type_utils.get_torch_dtype(t) == dtype_dict.get(t) + # When passed a torch.dtype, we get the same torch.dtype returned + assert data_type_utils.get_torch_dtype(dtype_dict.get(t)) == dtype_dict.get(t) diff --git a/tests/utils/test_evaluator.py b/tests/utils/test_evaluator.py new file mode 100644 index 000000000..9bb2e4fad --- /dev/null +++ b/tests/utils/test_evaluator.py @@ -0,0 +1,166 @@ +# Copyright The IBM Tuning Team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# SPDX-License-Identifier: Apache-2.0 +# https://spdx.dev/learn/handling-license-info/ + +# Standard +from typing import Tuple + +# Third Party +import numpy as np +import pytest + +# Local +from tuning.utils.evaluator import get_evaluator + + +def test_mailicious_inputs_to_eval(): + """Tests the malicious rules""" + rules: list[Tuple[str, bool, str]] = [ + # Valid rules + ("", False, "flags['is_training'] == False"), + ("", False, "not flags['is_training']"), + ("", True, "-10 < loss"), + ("", True, "+1000 > loss"), + ("", True, "~1000 < loss"), + ("", True, "(10 + 10) < loss"), + ("", True, "(20 - 10) < loss"), + ("", True, "(20/10) < loss"), + ("", True, "(20 % 10) < loss"), + ("", False, "loss < 1.0"), + ("", False, "(loss < 1.0)"), + ("", False, "loss*loss < 1.0"), + ("", False, "loss*loss*loss < 1.0"), + ("", False, "(loss*loss)*loss < 1.0"), + ("", True, "int(''.join(['3', '4'])) < loss"), + ("", True, "loss < 9**9"), + ("", False, "loss < sqrt(xs[0]*xs[0] + xs[1]*xs[1])"), + ("", True, "len(xs) > 2"), + ("", True, "loss < abs(-100)"), + ("", True, "loss == flags.aaa.bbb[0].ccc"), + ("", True, "array3d[0][1][1] == 4"), + ("", True, "numpyarray[0][1][1] == 4"), + ( + "", + True, + "len(xs) == 4 and xs[0] == 1 and (xs[1] == 0 or xs[2] == 0) and xs[3] == 2", + ), + # Invalid rules + ( + "'aaa' is not defined for expression 'loss == aaa.bbb[0].ccc'", + False, + "loss == aaa.bbb[0].ccc", + ), + ("0", False, "loss == flags[0].ccc"), # KeyError + ( + "Attribute 'ddd' does not exist in expression 'loss == flags.ddd[0].ccc'", + False, + "loss == flags.ddd[0].ccc", + ), + ( + "Sorry, access to __attributes or func_ attributes is not available. (__class__)", + False, + "'x'.__class__", + ), + ( + "Lambda Functions not implemented", + False, + # Try to instantiate and call Quitter + "().__class__.__base__.__subclasses__()[141]('', '')()", + ), + ( + "Lambda Functions not implemented", + False, + # pylint: disable=line-too-long + "[x for x in ().__class__.__base__.__subclasses__() if x.__name__ == 'Quitter'][0]('', '')()", + ), + ( + "Function 'getattr' not defined, for expression 'getattr((), '__class__')'.", + False, + "getattr((), '__class__')", + ), + ( + "Function 'getattr' not defined, for expression 'getattr((), '_' '_class_' '_')'.", + False, + "getattr((), '_' '_class_' '_')", + ), + ( + "Sorry, I will not evalute something that long.", + False, + '["hello"]*10000000000', + ), + ( + "Sorry, I will not evalute something that long.", + False, + "'i want to break free'.split() * 9999999999", + ), + ( + "Lambda Functions not implemented", + False, + "(lambda x='i want to break free'.split(): x * 9999999999)()", + ), + ( + "Sorry, NamedExpr is not available in this evaluator", + False, + "(x := 'i want to break free'.split()) and (x * 9999999999)", + ), + ("Sorry! I don't want to evaluate 9 ** 387420489", False, "9**9**9**9"), + ( + "Function 'mymetric1' not defined, for expression 'mymetric1() > loss'.", + True, + "mymetric1() > loss", + ), + ( + "Function 'mymetric2' not defined, for expression 'mymetric2(loss) > loss'.", + True, + "mymetric2(loss) > loss", + ), + ] + metrics = { + "loss": 42.0, + "flags": {"is_training": True, "aaa": {"bbb": [{"ccc": 42.0}]}}, + "xs": [1, 0, 0, 2], + "array3d": [ + [ + [1, 2], + [3, 4], + ], + [ + [5, 6], + [7, 8], + ], + ], + "numpyarray": (np.arange(8).reshape((2, 2, 2)) + 1), + } + + evaluator = get_evaluator(metrics=metrics) + + for validation_error, expected_rule_is_true, rule in rules: + rule_parsed = evaluator.parse(expr=rule) + if validation_error == "": + actual_rule_is_true = evaluator.eval( + expr=rule, + previously_parsed=rule_parsed, + ) + assert ( + actual_rule_is_true == expected_rule_is_true + ), "failed to execute the rule" + else: + with pytest.raises(Exception) as exception_handler: + evaluator.eval( + expr=rule, + previously_parsed=rule_parsed, + ) + assert str(exception_handler.value) == validation_error diff --git a/tox.ini b/tox.ini index bbcbba9b0..14dd1d715 100644 --- a/tox.ini +++ b/tox.ini @@ -1,12 +1,39 @@ [tox] -envlist = lint, fmt +envlist = py, lint, fmt + +[testenv] +description = run unit tests +deps = + pytest>=7 +commands = + pytest {posargs:tests} [testenv:fmt] description = format with pre-commit +deps = + pre-commit commands = ./scripts/fmt.sh allowlist_externals = ./scripts/fmt.sh [testenv:lint] description = lint with pylint -commands = pylint tuning scripts/*.py +deps = + pylint>=2.16.2,<=3.1.0 + pytest + .[dev] +commands = pylint tuning scripts/*.py build/*.py tests allowlist_externals = pylint + +[testenv:build] +description = build wheel +deps = + build +commands = python -m build -w +skip_install = True + +[testenv:twinecheck] +description = check wheel +deps = + twine +commands = twine check dist/* +skip_install = True diff --git a/tuning/__init__.py b/tuning/__init__.py index e69de29bb..38a9531ef 100644 --- a/tuning/__init__.py +++ b/tuning/__init__.py @@ -0,0 +1,13 @@ +# Copyright The FMS HF Tuning Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tuning/aim_loader.py b/tuning/aim_loader.py deleted file mode 100644 index 6ee617a42..000000000 --- a/tuning/aim_loader.py +++ /dev/null @@ -1,25 +0,0 @@ -# Standard -import os - -# Third Party -from aim.hugging_face import AimCallback - - -def get_aimstack_callback(): - # Initialize a new run - aim_server = os.environ.get("AIMSTACK_SERVER") - aim_db = os.environ.get("AIMSTACK_DB") - aim_experiment = os.environ.get("AIMSTACK_EXPERIMENT") - if aim_experiment is None: - aim_experiment = "" - - if aim_server: - aim_callback = AimCallback( - repo="aim://" + aim_server + "/", experiment=aim_experiment - ) - if aim_db: - aim_callback = AimCallback(repo=aim_db, experiment=aim_experiment) - else: - aim_callback = AimCallback(experiment=aim_experiment) - - return aim_callback diff --git a/tuning/config/__init__.py b/tuning/config/__init__.py index e69de29bb..38a9531ef 100644 --- a/tuning/config/__init__.py +++ b/tuning/config/__init__.py @@ -0,0 +1,13 @@ +# Copyright The FMS HF Tuning Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tuning/config/configs.py b/tuning/config/configs.py index 0b0a8fb67..247652b7c 100644 --- a/tuning/config/configs.py +++ b/tuning/config/configs.py @@ -1,6 +1,20 @@ +# Copyright The FMS HF Tuning Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + # Standard from dataclasses import dataclass, field -from typing import Dict, Optional, Union +from typing import List, Optional, Union # Third Party import torch @@ -28,7 +42,7 @@ class ModelArguments: @dataclass class DataArguments: - data_path: str = field( + training_data_path: str = field( default=None, metadata={"help": "Path to the training data in JSONL format."} ) response_template: str = field( @@ -47,13 +61,53 @@ class DataArguments: class TrainingArguments(transformers.TrainingArguments): cache_dir: Optional[str] = field(default=None) # optim: str = field(default=DEFAULT_OPTIMIZER) - model_max_length: int = field( + max_seq_length: int = field( default=DEFAULT_CONTEXT_LENGTH, metadata={ - "help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)." + "help": "Maximum sequence length. Sequences will be right padded \ + (and possibly truncated)." }, ) packing: bool = field( default=False, metadata={"help": "Packing to be enabled in SFT Trainer, default is False"}, ) + save_strategy: str = field( + default="epoch", + metadata={ + "help": "The checkpoint save strategy to adopt during training. \ + Possible values are 'no'(no save is done during training), \ + 'epoch' (save is done at the end of each epoch), \ + 'steps' (save is done every `save_steps`)" + }, + ) + logging_strategy: str = field( + default="epoch", + metadata={ + "help": "The logging strategy to adopt during training. \ + Possible values are 'no'(no logging is done during training), \ + 'epoch' (logging is done at the end of each epoch), \ + 'steps' (logging is done every `logging_steps`)" + }, + ) + trackers: Optional[List[str.lower]] = field( + default_factory=lambda: ["file_logger"], + metadata={ + "help": "Experiment trackers to use.\n" + + "Available trackers are - file_logger(default), aim, none\n" + + "Requires additional configs, see tuning.configs/tracker_configs.py" + }, + ) + + +@dataclass +class TrainerControllerArguments: + trainer_controller_config_file: str = field( + default=None, + metadata={ + "help": ( + "Trainer controller configuration file (e.g trainercontroller_config.yaml) \ + in YAML format." + ) + }, + ) diff --git a/tuning/config/fsdp_config.json b/tuning/config/fsdp_config.json deleted file mode 100644 index cb96df45d..000000000 --- a/tuning/config/fsdp_config.json +++ /dev/null @@ -1,12 +0,0 @@ -{ - "fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP", - "fsdp_backward_prefetch_policy": "BACKWARD_PRE", - "fsdp_cpu_ram_efficient_loading": "False", - "fsdp_forward_prefetch": "True", - "fsdp_offload_params": "False", - "fsdp_state_dict_type": "SHARDED_STATE_DICT", - "fsdp_sync_module_states": "False", - "fsdp_transformer_layer_cls_to_wrap": "LlamaDecoderLayer", - "fsdp_use_orig_params": "True", - "activation_checkpointing": "True" -} \ No newline at end of file diff --git a/tuning/config/peft_config.py b/tuning/config/peft_config.py index a3d30c763..bbb48e608 100644 --- a/tuning/config/peft_config.py +++ b/tuning/config/peft_config.py @@ -1,3 +1,17 @@ +# Copyright The FMS HF Tuning Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + # Standard from dataclasses import dataclass, field from typing import List @@ -5,13 +19,39 @@ @dataclass class LoraConfig: + """ + This is the configuration class to store the configuration of a [`LoraModel`]. + + Args: + r (`int`): + Lora attention dimension (the "rank"). + target_modules (List[str]]): + The names of the modules to apply the adapter to. \ + If this is specified, only the modules with the specified \ + names will be replaced. Please specify modules as per model architecture. \ + If the value is ["all-linear"], \ + then LORA selects all linear and Conv1D modules as per model architecture, \ + except for the output layer. + lora_alpha (`int`): + The alpha parameter for Lora scaling. + lora_dropout (`float`): + The dropout probability for Lora layers. + bias (`str`): + Bias type for LoRA. Can be 'none', 'all' or 'lora_only'. \ + If 'all' or 'lora_only', the corresponding biases will be updated during training. \ + Be aware that this means that, even when disabling the adapters, the model \ + will not produce the same output as the base model would have without adaptation. + """ + r: int = 8 lora_alpha: int = 32 target_modules: List[str] = field( default_factory=lambda: ["q_proj", "v_proj"], metadata={ - "help": "The names of the modules to apply LORA to. LORA selects modules which either completely match or " - 'end with one of the strings. If the value is ["all-linear"], then LORA selects all linear and Conv1D ' + "help": "The names of the modules to apply LORA to. LORA selects modules which either \ + completely match or " + 'end with one of the strings. If the value is ["all-linear"], \ + then LORA selects all linear and Conv1D ' "modules except for the output layer." }, ) @@ -21,6 +61,21 @@ class LoraConfig: @dataclass class PromptTuningConfig: + """ + This is the configuration class for Prompt Tuning. + + Args: + prompt_tuning_init : str: The initialization of the prompt embedding. \ + Allowed values "TEXT" or "RANDOM". + prompt_tuning_init_text (`str`, *optional*): + The text to initialize the prompt embedding. \ + Only used if `prompt_tuning_init` is `TEXT`. + tokenizer_name_or_path (`str`, *optional*): + The name or path of the tokenizer. \ + Only used if `prompt_tuning_init` is `TEXT`. + num_virtual_tokens (`int`): The number of virtual tokens to use. + """ + prompt_tuning_init: str = "TEXT" num_virtual_tokens: int = 8 prompt_tuning_init_text: str = "Classify if the tweet is a complaint or not:" diff --git a/tuning/config/tracker_configs.py b/tuning/config/tracker_configs.py new file mode 100644 index 000000000..e0b52cb30 --- /dev/null +++ b/tuning/config/tracker_configs.py @@ -0,0 +1,61 @@ +# Copyright The FMS HF Tuning Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Standard +from dataclasses import dataclass + + +@dataclass +class FileLoggingTrackerConfig: + training_logs_filename: str = "training_logs.jsonl" + + +@dataclass +class AimConfig: + # Name of the experiment + experiment: str = None + # aim_repo can point to a locally accessible directory + # or a remote repository hosted on a server. + # When 'aim_remote_server_ip' or 'aim_remote_server_port' is set, + # it designates a remote aim repo. + # Otherwise, 'repo' specifies the directory, with default of None meaning '.aim'. + # + # See https://aimstack.readthedocs.io/en/latest/using/remote_tracking.html + # for documentation on Aim remote server tracking. + aim_repo: str = None + aim_remote_server_ip: str = None + aim_remote_server_port: int = None + aim_url: str = None + + def __post_init__(self): + if self.experiment is None: + self.experiment = "fms-hf-tuning" + + if ( + self.aim_remote_server_ip is not None + and self.aim_remote_server_port is not None + ): + self.aim_url = ( + "aim://" + + self.aim_remote_server_ip + + ":" + + self.aim_remote_server_port + + "/" + ) + + +@dataclass +class TrackerConfigFactory: + file_logger_config: FileLoggingTrackerConfig = None + aim_config: AimConfig = None diff --git a/tuning/data/__init__.py b/tuning/data/__init__.py index e69de29bb..38a9531ef 100644 --- a/tuning/data/__init__.py +++ b/tuning/data/__init__.py @@ -0,0 +1,13 @@ +# Copyright The FMS HF Tuning Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tuning/data/tokenizer_data_utils.py b/tuning/data/tokenizer_data_utils.py index 3a8a288f3..7c314a187 100644 --- a/tuning/data/tokenizer_data_utils.py +++ b/tuning/data/tokenizer_data_utils.py @@ -1,17 +1,23 @@ +# Copyright The FMS HF Tuning Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + # Standard -from typing import Dict, Sequence -import copy -import json -import logging +from typing import Dict # Third Party -from torch.utils.data import Dataset -import torch import transformers -# Local -from tuning.config import configs - def tokenizer_and_embedding_resize( special_tokens_dict: Dict, diff --git a/tuning/sft_trainer.py b/tuning/sft_trainer.py index ce9e323e0..b307505c0 100644 --- a/tuning/sft_trainer.py +++ b/tuning/sft_trainer.py @@ -1,8 +1,21 @@ +# Copyright The FMS HF Tuning Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + # Standard -from datetime import datetime -from typing import Optional, Union +from typing import Dict, List, Optional, Union import json -import os +import time # Third Party from peft.utils.other import fsdp_auto_wrap_policy @@ -22,75 +35,32 @@ import transformers # Local -from tuning.aim_loader import get_aimstack_callback from tuning.config import configs, peft_config +from tuning.config.tracker_configs import ( + AimConfig, + FileLoggingTrackerConfig, + TrackerConfigFactory, +) from tuning.data import tokenizer_data_utils +from tuning.trackers.tracker_factory import get_tracker +from tuning.trainercontroller import TrainerControllerCallback from tuning.utils.config_utils import get_hf_peft_config from tuning.utils.data_type_utils import get_torch_dtype -class PeftSavingCallback(TrainerCallback): - def on_save(self, args, state, control, **kwargs): - checkpoint_path = os.path.join( - args.output_dir, f"checkpoint-{state.global_step}" - ) - kwargs["model"].save_pretrained(checkpoint_path) - - if "pytorch_model.bin" in os.listdir(checkpoint_path): - os.remove(os.path.join(checkpoint_path, "pytorch_model.bin")) - - -class FileLoggingCallback(TrainerCallback): - """Exports metrics, e.g., training loss to a file in the checkpoint directory.""" - - def __init__(self, logger): - self.logger = logger - - def on_log(self, args, state, control, logs=None, **kwargs): - """Checks if this log contains keys of interest, e.g., loss, and if so, creates - train_loss.jsonl in the model output dir (if it doesn't already exist), - appends the subdict of the log & dumps the file. - """ - # All processes get the logs from this node; only update from process 0. - if not state.is_world_process_zero: - return - - # separate evaluation loss with train loss - log_file_path = os.path.join(args.output_dir, "train_loss.jsonl") - eval_log_file_path = os.path.join(args.output_dir, "eval_loss.jsonl") - if logs is not None and "loss" in logs and "epoch" in logs: - self._track_loss("loss", log_file_path, logs, state) - elif logs is not None and "eval_loss" in logs and "epoch" in logs: - self._track_loss("eval_loss", eval_log_file_path, logs, state) - - def _track_loss(self, loss_key, log_file, logs, state): - try: - # Take the subdict of the last log line; if any log_keys aren't part of this log - # object, assume this line is something else, e.g., train completion, and skip. - log_obj = { - "name": loss_key, - "data": { - "epoch": round(logs["epoch"], 2), - "step": state.global_step, - "value": logs[loss_key], - "timestamp": datetime.isoformat(datetime.now()), - }, - } - except KeyError: - return - - # append the current log to the jsonl file - with open(log_file, "a") as f: - f.write(f"{json.dumps(log_obj, sort_keys=True)}\n") - - def train( model_args: configs.ModelArguments, data_args: configs.DataArguments, train_args: configs.TrainingArguments, - peft_config: Optional[ + peft_config: Optional[ # pylint: disable=redefined-outer-name Union[peft_config.LoraConfig, peft_config.PromptTuningConfig] ] = None, + trainer_controller_args: configs.TrainerControllerArguments = None, + tracker_configs: Optional[TrackerConfigFactory] = TrackerConfigFactory( + file_logger_config=FileLoggingTrackerConfig() + ), + additional_callbacks: Optional[List[TrainerCallback]] = None, + exp_metadata: Optional[Dict] = None, ): """Call the SFTTrainer @@ -102,13 +72,23 @@ def train( peft_config.PromptTuningConfig for prompt tuning | \ None for fine tuning The peft configuration to pass to trainer + trainer_control_args: configs.TrainerControllerArguments \ + for controlling the training loop using policy rules + tracker_configs: An instance of tuning.config.tracker_configs.TrackerConfigFactory \ + which represents the configuration for various trackers\ + Note, trackers need to be enabled to use this \ + for e.g. --tracker(s) aim \ + additional_callbacks: List of callbacks to attach with SFTtrainer,\ + besides those associated with experiment trackers \ + or TrainerControllers. Callbacks associated with \ + tracker with automatically be added. + exp_metadata: Dict of key value pairs passed to train to be recoreded by the tracker. """ - run_distributed = int(os.environ.get("WORLD_SIZE", "1")) > 1 logger = logging.get_logger("sft_trainer") # Validate parameters - if (not isinstance(train_args.num_train_epochs, float)) or ( + if (not isinstance(train_args.num_train_epochs, (float, int))) or ( train_args.num_train_epochs <= 0 ): raise ValueError("num_train_epochs has to be an integer/float >= 1") @@ -117,32 +97,59 @@ def train( ): raise ValueError("gradient_accumulation_steps has to be an integer >= 1") - # make sure to unset FSDP args when running on single gpu - if not run_distributed: - train_args.fsdp = "" - train_args.fsdp_config = {"xla": False} - task_type = "CAUSAL_LM" + additional_metrics = {} + + # Initialize Trackers And Callbacks + trackers = [] + trainer_callbacks = [] + + if train_args.trackers is not None: + requested_trackers = set(train_args.trackers) + else: + requested_trackers = set() + + # Now initialize trackers one by one + for name in requested_trackers: + t = get_tracker(name, tracker_configs) + cb = t.get_hf_callback() + if cb is not None: + trainer_callbacks.append(cb) + trackers.append(t) + + # Now add trainer controller callbacks if requested + if (trainer_controller_args is not None) and ( + trainer_controller_args.trainer_controller_config_file is not None + ): + tc_callback = TrainerControllerCallback( + trainer_controller_args.trainer_controller_config_file + ) + trainer_callbacks.append(tc_callback) + + # Add any extra callback if passed by users + if additional_callbacks is not None: + trainer_callbacks.append(additional_callbacks) + + model_load_time = time.time() model = AutoModelForCausalLM.from_pretrained( model_args.model_name_or_path, cache_dir=train_args.cache_dir, torch_dtype=get_torch_dtype(model_args.torch_dtype), - use_flash_attention_2=model_args.use_flash_attn, + attn_implementation="flash_attention_2" if model_args.use_flash_attn else None, ) - peft_config = get_hf_peft_config(task_type, peft_config) - - model.gradient_checkpointing_enable() - # TODO: Move these to a config as well tokenizer = AutoTokenizer.from_pretrained( model_args.model_name_or_path, cache_dir=train_args.cache_dir, use_fast=True ) + # Calculate and save additional metrics to track later. + additional_metrics["model_load_time"] = time.time() - model_load_time + + peft_config = get_hf_peft_config(task_type, peft_config) + # TODO: understand if we need to hardcode these here or just use defaults in model - if isinstance(tokenizer, LlamaTokenizer) or isinstance( - tokenizer, LlamaTokenizerFast - ): + if isinstance(tokenizer, (LlamaTokenizer, LlamaTokenizerFast)): tokenizer.add_special_tokens( { "bos_token": "", @@ -151,33 +158,34 @@ def train( "pad_token": "", } ) - elif isinstance(tokenizer, GPTNeoXTokenizerFast) or isinstance( - tokenizer, GPT2Tokenizer - ): + elif isinstance(tokenizer, (GPT2Tokenizer, GPTNeoXTokenizerFast)): tokenizer.add_special_tokens( { "pad_token": "", } ) - """TODO: near term - how response template ids are parsed out needs to be cleaned. - The [2:] here applies if response template has \n prefix, it is needed to strip \n, otherwise template is not found. - We will create issue to clean this out after we discuss data formats and collators we will support - """ + # TODO: near term - how response template ids are parsed out needs to be cleaned. + # The [2:] here applies if response template has \n prefix, it is needed to strip \n, + # otherwise template is not found. We will create issue to clean this out after we discuss + # data formats and collators we will support. response_template_ids = tokenizer.encode( data_args.response_template, add_special_tokens=False )[2:] - # TODO: This is actually max_seq_length and not model_max_length. we should not override model_max_length - # as in current main. We need to change name of this parameter we expose to users. - model_max_length = min(train_args.model_max_length, tokenizer.model_max_length) - logger.info(f"Model max length {model_max_length}") - if train_args.model_max_length > tokenizer.model_max_length: + + max_seq_length = min(train_args.max_seq_length, tokenizer.model_max_length) + logger.info("Max sequence length is %s", max_seq_length) + if train_args.max_seq_length > tokenizer.model_max_length: logger.warning( - f"model_max_length {train_args.model_max_length} exceeds tokenizer.model_max_length {tokenizer.model_max_length}, using tokenizer.model_max_length {tokenizer.model_max_length}" + "max_seq_length %s exceeds tokenizer.model_max_length \ + %s, using tokenizer.model_max_length %s", + train_args.max_seq_length, + tokenizer.model_max_length, + tokenizer.model_max_length, ) # TODO: we need to change this, perhaps follow what open instruct does? - special_tokens_dict = dict() + special_tokens_dict = {} if tokenizer.pad_token is None: logger.warning("PAD token set to default, missing in tokenizer") special_tokens_dict["pad_token"] = configs.DEFAULT_PAD_TOKEN @@ -199,55 +207,46 @@ def train( model=model, ) + # Configure the collator and validate args related to packing prior to formatting the dataset + if train_args.packing: + logger.info("Packing is set to True") + data_collator = None + packing = True + else: + logger.info("Packing is set to False") + if data_args.response_template is None: + # TODO: Fix this, currently unreachable due to crashing in batch encoding tokenization + # We should do this validation up front, then do the encoding, then handle the collator + raise ValueError("Response template is None, needs to be set for training") + if data_args.dataset_text_field is None: + raise ValueError("Dataset_text_field is None, needs to be set for training") + data_collator = DataCollatorForCompletionOnlyLM( + response_template_ids, + tokenizer=tokenizer, + ignore_index=configs.IGNORE_INDEX, + ) + packing = False + # load the data by parsing JSON - # TODO: update arg from data_path to training_data_path since we also have validation_data_path - data_files = {"train": data_args.data_path} + data_files = {"train": data_args.training_data_path} if data_args.validation_data_path: data_files["validation"] = data_args.validation_data_path - format_dataset = lambda example: { + format_dataset = lambda example: { # pylint: disable=unnecessary-lambda-assignment f"{data_args.dataset_text_field}": example[f"{data_args.dataset_text_field}"] + tokenizer.eos_token } json_dataset = datasets.load_dataset("json", data_files=data_files) formatted_train_dataset = json_dataset["train"].map(format_dataset) - logger.info(f"Training dataset length is {len(formatted_train_dataset)}") + logger.info("Training dataset length is %s", len(formatted_train_dataset)) formatted_validation_dataset = None if data_args.validation_data_path: formatted_validation_dataset = json_dataset["validation"].map(format_dataset) - logger.info(f"Validation dataset length is {len(formatted_validation_dataset)}") - - aim_callback = get_aimstack_callback() - file_logger_callback = FileLoggingCallback(logger) - peft_saving_callback = PeftSavingCallback() - callbacks = [aim_callback, peft_saving_callback, file_logger_callback] - - if train_args.packing: - logger.info("Packing is set to True") - data_collator = None - packing = True - else: - logger.info("Packing is set to False") - if data_args.response_template is None: - logger.error( - "Error, response template is None, needs to be set for training" - ) - exit(-1) - - if data_args.dataset_text_field is None: - logger.error( - "Error, dataset_text_field is None, needs to be set for training" - ) - exit(-1) - - data_collator = DataCollatorForCompletionOnlyLM( - response_template_ids, - tokenizer=tokenizer, - ignore_index=configs.IGNORE_INDEX, + logger.info( + "Validation dataset length is %s", len(formatted_validation_dataset) ) - packing = False trainer = SFTTrainer( model=model, @@ -258,50 +257,112 @@ def train( data_collator=data_collator, dataset_text_field=data_args.dataset_text_field, args=train_args, - max_seq_length=model_max_length, - callbacks=callbacks, + max_seq_length=max_seq_length, + callbacks=trainer_callbacks, peft_config=peft_config, ) - if run_distributed and peft_config is not None: + # We track additional metrics and experiment metadata after trainer object creation + # this ensure that the process is not repeated multiple times for FSDP runs. + if trainer.is_world_process_zero(): + # Currently tracked only on process zero. + for tracker in trackers: + try: + for k, v in additional_metrics.items(): + tracker.track(metric=v, name=k, stage="additional_metrics") + tracker.set_params(params=exp_metadata, name="experiment_metadata") + except ValueError as e: + logger.error( + "Exception while saving additional metrics and metadata %s", + repr(e), + ) + + if trainer.is_fsdp_enabled and peft_config is not None: trainer.accelerator.state.fsdp_plugin.auto_wrap_policy = fsdp_auto_wrap_policy( model ) trainer.train() -def main(**kwargs): +def main(**kwargs): # pylint: disable=unused-argument parser = transformers.HfArgumentParser( dataclass_types=( configs.ModelArguments, configs.DataArguments, configs.TrainingArguments, + configs.TrainerControllerArguments, peft_config.LoraConfig, peft_config.PromptTuningConfig, + FileLoggingTrackerConfig, + AimConfig, ) ) parser.add_argument( "--peft_method", type=str.lower, choices=["pt", "lora", None, "none"], - default="pt", + default="none", + ) + parser.add_argument( + "--exp_metadata", + type=str, + default=None, + help='Pass a json string representing K:V pairs to be associated\ + to the tuning run in the tracker. e.g. \'{"gpu":"A100-80G"}\'', ) ( model_args, data_args, training_args, + trainer_controller_args, lora_config, prompt_tuning_config, - peft_method, + file_logger_config, + aim_config, + additional, _, ) = parser.parse_args_into_dataclasses(return_remaining_strings=True) - if peft_method.peft_method == "lora": + + logger = logging.get_logger("__main__") + + peft_method = additional.peft_method + if peft_method == "lora": tune_config = lora_config - elif peft_method.peft_method == "pt": + elif peft_method == "pt": tune_config = prompt_tuning_config else: tune_config = None - train(model_args, data_args, training_args, tune_config) + + # extra metadata passed via client + metadata = None + if additional.exp_metadata is not None: + try: + metadata = json.loads(additional.exp_metadata) + if metadata is None or not isinstance(metadata, Dict): + logger.warning( + "metadata cannot be converted to simple k:v dict ignoring" + ) + metadata = None + except ValueError as e: + logger.error( + "failed while parsing extra metadata. pass a valid json %s", repr(e) + ) + + combined_tracker_configs = TrackerConfigFactory() + + combined_tracker_configs.file_logger_config = file_logger_config + combined_tracker_configs.aim_config = aim_config + + train( + model_args=model_args, + data_args=data_args, + train_args=training_args, + peft_config=tune_config, + trainer_controller_args=trainer_controller_args, + tracker_configs=combined_tracker_configs, + additional_callbacks=None, + exp_metadata=metadata, + ) if __name__ == "__main__": diff --git a/tuning/trackers/__init__.py b/tuning/trackers/__init__.py new file mode 100644 index 000000000..38a9531ef --- /dev/null +++ b/tuning/trackers/__init__.py @@ -0,0 +1,13 @@ +# Copyright The FMS HF Tuning Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tuning/trackers/aimstack_tracker.py b/tuning/trackers/aimstack_tracker.py new file mode 100644 index 000000000..342983698 --- /dev/null +++ b/tuning/trackers/aimstack_tracker.py @@ -0,0 +1,83 @@ +# Copyright The FMS HF Tuning Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Third Party +from aim.hugging_face import AimCallback # pylint: disable=import-error +from transformers.utils import logging + +# Local +from .tracker import Tracker +from tuning.config.tracker_configs import AimConfig + + +class AimStackTracker(Tracker): + def __init__(self, tracker_config: AimConfig): + """ + Tracker which uses Aimstack to collect and store metrics. + """ + super().__init__(name="aim", tracker_config=tracker_config) + self.logger = logging.get_logger("aimstack_tracker") + + def get_hf_callback(self): + """ + Returns the aim.hugging_face.AimCallback object associated with this tracker. + """ + c = self.config + exp = c.experiment + url = c.aim_url + repo = c.aim_repo + + if url is not None: + aim_callback = AimCallback(repo=url, experiment=exp) + if repo: + aim_callback = AimCallback(repo=repo, experiment=exp) + else: + self.logger.warning( + "Aim tracker requested but repo or server is not specified. " + + "Please specify either aim repo or aim server ip and port for using Aim." + ) + aim_callback = None + + self.hf_callback = aim_callback + return self.hf_callback + + def track(self, metric, name, stage="additional_metrics"): + """ + Track any additional `metric` with `name` under Aimstack tracker. + Expects metric and name to not be None. + stage can be used to pass the metadata associated with metric, + like, training metric or eval metric or additional metric + """ + if metric is None or name is None: + self.logger.warning("Tracked metric value or name should not be None") + return + context = {"subset": stage} + callback = self.hf_callback + run = callback.experiment + if run is not None: + run.track(metric, name=name, context=context) + + def set_params(self, params, name="extra_params"): + """ + Attach any extra params with the run information stored in Aimstack tracker. + Expects params to be a dict of k:v pairs of parameters to store. + name represents the namespace under which parameters will be associated in Aim. + """ + if params is None: + return + callback = self.hf_callback + run = callback.experiment + if run is not None: + for key, value in params.items(): + run.set((name, key), value, strict=False) diff --git a/tuning/trackers/filelogging_tracker.py b/tuning/trackers/filelogging_tracker.py new file mode 100644 index 000000000..66934191f --- /dev/null +++ b/tuning/trackers/filelogging_tracker.py @@ -0,0 +1,88 @@ +# Copyright The FMS HF Tuning Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Standard +from datetime import datetime +import json +import os + +# Third Party +from transformers import TrainerCallback +from transformers.utils import logging + +# Local +from .tracker import Tracker +from tuning.config.tracker_configs import FileLoggingTrackerConfig + + +class FileLoggingCallback(TrainerCallback): + """Exports metrics, e.g., training loss to a file in the checkpoint directory.""" + + training_logs_filename = "training_logs.jsonl" + + def __init__(self, logs_filename=None): + self.training_logs_filename = logs_filename + + def on_log(self, args, state, control, logs=None, **kwargs): + """Checks if this log contains keys of interest, e.g., loss, and if so, creates + training_logs.jsonl in the model output dir (if it doesn't already exist), + appends the subdict of the log & dumps the file. + """ + # All processes get the logs from this node; only update from process 0. + if not state.is_world_process_zero: + return + + log_file_path = os.path.join(args.output_dir, self.training_logs_filename) + if logs is not None and "loss" in logs and "epoch" in logs: + self._track_loss("loss", "training_loss", log_file_path, logs, state) + elif logs is not None and "eval_loss" in logs and "epoch" in logs: + self._track_loss("eval_loss", "validation_loss", log_file_path, logs, state) + + def _track_loss(self, loss_key, log_name, log_file, logs, state): + try: + # Take the subdict of the last log line; if any log_keys aren't part of this log + # object, assume this line is something else, e.g., train completion, and skip. + log_obj = { + "name": log_name, + "data": { + "epoch": round(logs["epoch"], 2), + "step": state.global_step, + "value": logs[loss_key], + "timestamp": datetime.isoformat(datetime.now()), + }, + } + except KeyError: + return + + # append the current log to the jsonl file + with open(log_file, "a", encoding="utf-8") as f: + f.write(f"{json.dumps(log_obj, sort_keys=True)}\n") + + +class FileLoggingTracker(Tracker): + def __init__(self, tracker_config: FileLoggingTrackerConfig): + """ + Tracker which encodes callback to record metric, e.g., training loss + to a file in the checkpoint directory. + """ + super().__init__(name="file_logger", tracker_config=tracker_config) + self.logger = logging.get_logger("file_logging_tracker") + + def get_hf_callback(self): + """ + Returns the FileLoggingCallback object associated with this tracker. + """ + file = self.config.training_logs_filename + self.hf_callback = FileLoggingCallback(logs_filename=file) + return self.hf_callback diff --git a/tuning/trackers/tracker.py b/tuning/trackers/tracker.py new file mode 100644 index 000000000..9fb0ae94f --- /dev/null +++ b/tuning/trackers/tracker.py @@ -0,0 +1,40 @@ +# Copyright The FMS HF Tuning Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +# Generic Tracker API +class Tracker: + """ + Generic interface for a Tracker Object. + """ + + def __init__(self, name=None, tracker_config=None) -> None: + if tracker_config is not None: + self.config = tracker_config + if name is None: + self._name = "None" + else: + self._name = name + + # we use args here to denote any argument. + def get_hf_callback(self): + return None + + def track(self, metric, name, stage): + pass + + # Object passed here is supposed to be a KV object + # for the parameters to be associated with a run + def set_params(self, params, name): + pass diff --git a/tuning/trackers/tracker_factory.py b/tuning/trackers/tracker_factory.py new file mode 100644 index 000000000..3ba127b7f --- /dev/null +++ b/tuning/trackers/tracker_factory.py @@ -0,0 +1,118 @@ +# Copyright The FMS HF Tuning Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Standard +import dataclasses + +# Third Party +from transformers.utils import logging +from transformers.utils.import_utils import _is_package_available + +# Local +from .filelogging_tracker import FileLoggingTracker +from .tracker import Tracker +from tuning.config.tracker_configs import FileLoggingTrackerConfig, TrackerConfigFactory + +logger = logging.get_logger("tracker_factory") + +# Information about all registered trackers +AVAILABLE_TRACKERS = {} + +AIMSTACK_TRACKER_NAME = "aim" +FILE_LOGGING_TRACKER_NAME = "file_logger" + +# One time package check for list of external trackers. +_is_aim_available = _is_package_available("aim") + + +def _get_tracker_class(T, C): + return {"tracker": T, "config": C} + + +def _register_aim_tracker(): + # pylint: disable=import-outside-toplevel + if _is_aim_available: + # Local + from .aimstack_tracker import AimStackTracker + from tuning.config.tracker_configs import AimConfig + + AimTracker = _get_tracker_class(AimStackTracker, AimConfig) + + AVAILABLE_TRACKERS[AIMSTACK_TRACKER_NAME] = AimTracker + logger.info("Registered aimstack tracker") + else: + logger.info( + "Not registering Aimstack tracker due to unavailablity of package.\n" + "Please install aim if you intend to use it.\n" + "\t pip install aim" + ) + + +def _register_file_logging_tracker(): + FileTracker = _get_tracker_class(FileLoggingTracker, FileLoggingTrackerConfig) + AVAILABLE_TRACKERS[FILE_LOGGING_TRACKER_NAME] = FileTracker + logger.info("Registered file logging tracker") + + +# List of Available Trackers +# file_logger - Logs loss to a file +# aim - Aimstack Tracker +def _register_trackers(): + logger.info("Registering trackers") + if AIMSTACK_TRACKER_NAME not in AVAILABLE_TRACKERS: + _register_aim_tracker() + if FILE_LOGGING_TRACKER_NAME not in AVAILABLE_TRACKERS: + _register_file_logging_tracker() + + +def _get_tracker_config_by_name(name: str, tracker_configs: TrackerConfigFactory): + if tracker_configs is None: + return + c_name = name + "_config" + d = dataclasses.asdict(tracker_configs) + if c_name in d: + return d[c_name] + return + + +def get_tracker(name: str, tracker_configs: TrackerConfigFactory): + """ + Returns an instance of the tracker object based on the requested `name`. + Expects tracker config to be present as part of the TrackerConfigFactory + object passed as `tracker_configs` argument. + If a valid tracker config is not found this function tries tracker with + default config else returns an empty Tracker() + """ + if not AVAILABLE_TRACKERS: + # a one time step. + _register_trackers() + + if name in AVAILABLE_TRACKERS: + meta = AVAILABLE_TRACKERS[name] + C = meta["config"] + T = meta["tracker"] + + if tracker_configs is not None: + _conf = _get_tracker_config_by_name(name, tracker_configs) + if _conf is not None: + config = C(**_conf) + else: + config = C() + return T(config) + + logger.warning( + "Requested Tracker %s not found. Please check the argument before proceeding.", + name, + ) + return Tracker() diff --git a/tuning/trainercontroller/__init__.py b/tuning/trainercontroller/__init__.py new file mode 100644 index 000000000..151741ed8 --- /dev/null +++ b/tuning/trainercontroller/__init__.py @@ -0,0 +1,19 @@ +# Copyright The IBM Tuning Team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# SPDX-License-Identifier: Apache-2.0 +# https://spdx.dev/learn/handling-license-info/ + +# Local +from .callback import TrainerControllerCallback diff --git a/tuning/trainercontroller/callback.py b/tuning/trainercontroller/callback.py new file mode 100644 index 000000000..d30821f19 --- /dev/null +++ b/tuning/trainercontroller/callback.py @@ -0,0 +1,540 @@ +# Copyright The IBM Tuning Team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# SPDX-License-Identifier: Apache-2.0 +# https://spdx.dev/learn/handling-license-info/ + +# Standard +from importlib import resources as impresources +from typing import Dict, List, Union +import inspect +import os +import re + +# Third Party +from simpleeval import EvalWithCompoundTypes, FeatureNotAvailable, NameNotDefined +from transformers import ( + TrainerCallback, + TrainerControl, + TrainerState, + TrainingArguments, +) +from transformers.utils import logging +import yaml + +# Local +from tuning.trainercontroller import controllermetrics, operations +from tuning.trainercontroller.control import Control, OperationAction +from tuning.trainercontroller.controllermetrics import ( + handlers as default_metric_handlers, +) +from tuning.trainercontroller.controllermetrics.metricshandler import MetricHandler +from tuning.trainercontroller.operations import Operation +from tuning.trainercontroller.operations import ( + operation_handlers as default_operation_handlers, +) +from tuning.utils.evaluator import get_evaluator + +logger = logging.get_logger(__name__) + +# Configuration keys +CONTROLLER_METRICS_KEY = "controller-metrics" +OPERATIONS_KEY = "operations" +CONTROLLERS_KEY = "controllers" +ARGS_KEY = "arguments" + +CONTROLLER_NAME_KEY = "name" +CONTROLLER_CLASS_KEY = "class" +CONTROLLER_RULE_KEY = "rule" +CONTROLLER_TRIGGERS_KEY = "triggers" +CONTROLLER_OPERATIONS_KEY = OPERATIONS_KEY + + +# pylint: disable=too-many-instance-attributes +class TrainerControllerCallback(TrainerCallback): + """Implements the trainer loop control based + on trainer controller configuration file and metrics""" + + def __init__(self, trainer_controller_config: Union[dict, str]): + """Initializes the callback for trainer controller. + + Args: + trainer_controller_config: dict. Trainer controller configuration + """ + # Checks if the trainer control config is of string type, in which case, it \ + # is a file path for the configuration yaml. On the other hand, if it is a \ + # dictionary, then it the yaml directly passed as such. + if isinstance(trainer_controller_config, str): + if os.path.exists(trainer_controller_config): + with open(trainer_controller_config, "r", encoding="utf-8") as f: + self.trainer_controller_config: dict = yaml.safe_load(f) + if not isinstance(self.trainer_controller_config, dict): + raise TypeError( + "expected the trainer controller config YAML file" + "to contain a dictionary. actual type: %s" + % (type(self.trainer_controller_config)) + ) + else: + raise FileNotFoundError( + f"Trainer controller configuration \ + [{trainer_controller_config}] does NOT exist" + ) + else: + self.trainer_controller_config = trainer_controller_config + + if CONTROLLER_METRICS_KEY not in self.trainer_controller_config: + self.trainer_controller_config[CONTROLLER_METRICS_KEY] = [] + + if OPERATIONS_KEY not in self.trainer_controller_config: + self.trainer_controller_config[OPERATIONS_KEY] = [] + + # Initialize the list of metrics from default `metrics.yaml` in the \ + # controllermetric package. In addition, any metrics mentioned in \ + # the trainer controller config are added to this list. + default_metrics_config_yaml = ( + impresources.files(controllermetrics) / "metrics.yaml" + ) + with default_metrics_config_yaml.open("r") as f: + default_metrics_config = yaml.safe_load(f) + if ( + default_metrics_config is not None + and CONTROLLER_METRICS_KEY in default_metrics_config + and len(default_metrics_config[CONTROLLER_METRICS_KEY]) > 0 + ): + self_controller_metrics = self.trainer_controller_config[ + CONTROLLER_METRICS_KEY + ] + default_controller_metrics: list[dict] = default_metrics_config[ + CONTROLLER_METRICS_KEY + ] + for metric_obj in default_controller_metrics: + metric_name: str = metric_obj[CONTROLLER_NAME_KEY] + found = False + for self_controller_metric in self_controller_metrics: + if self_controller_metric[CONTROLLER_NAME_KEY] == metric_name: + found = True + break + if not found: + self_controller_metrics.append(metric_obj) + + # Initialize the list of operations from default `operations.yaml` \ + # in the operations package. In addition, any operations mentioned \ + # in the trainer controller config are added to this list. + default_operations_config_yaml = ( + impresources.files(operations) / "operations.yaml" + ) + with default_operations_config_yaml.open("r") as f: + default_operations_config = yaml.safe_load(f) + if ( + default_operations_config is not None + and OPERATIONS_KEY in default_operations_config + and len(default_operations_config[OPERATIONS_KEY]) > 0 + ): + self_controller_operations = self.trainer_controller_config[OPERATIONS_KEY] + default_controller_operations: list[dict] = default_operations_config[ + OPERATIONS_KEY + ] + for op_obj in default_controller_operations: + op_name: str = op_obj[CONTROLLER_NAME_KEY] + found = False + for self_controller_operation in self_controller_operations: + if self_controller_operation[CONTROLLER_NAME_KEY] == op_name: + found = True + break + if not found: + self_controller_operations.append(op_obj) + + # Load list of valid events for the trainercontroller callback + # These events are assumed to start with "on_" prefix (on_epoch_end(), on_step_end() etc) + self.valid_events = set() + for callback_method_name, _ in inspect.getmembers( + self, predicate=inspect.ismethod + ): + if re.search(r"^on_", callback_method_name) is not None: + self.valid_events.add(callback_method_name) + logger.debug("List of valid events %s", repr(self.valid_events)) + + # Handlers to trigger on each metric + self.metric_handlers: dict[str, type[MetricHandler]] = {} + self.metrics_on_event: dict[str, list[MetricHandler]] = {} + self.register_metric_handlers(default_metric_handlers) + + # Supported operations + self.operation_handlers: dict[str, type[Operation]] = {} + self.operation_actions = {} + self.register_operation_handlers(default_operation_handlers) + + # controls + self.control_actions_on_event: Dict[str, list[Control]] = {} + + # List of fields produced by the metrics + self.metrics = {} + + def register_metric_handlers(self, handlers: List[MetricHandler]): + """Registers the metric handlers + + Args: + handlers: List[MetricHandler]. List of handlers. + """ + for handler in handlers: + self.metric_handlers[handler.__name__] = handler + + def register_operation_handlers(self, operation_handlers: List[Operation]): + """Registers the operation handlers + + Args: + operation_handlers: List[Operation]. List of operation handlers. + """ + for operation_handler in operation_handlers: + self.operation_handlers[operation_handler.__name__] = operation_handler + + def _compute_metrics(self, event_name: str, **kwargs): + """Invokes the compute() for all the metrics registered for a given event. + + Args: + event_name: str. Event name. + """ + if event_name in self.metrics_on_event: + for m in self.metrics_on_event[event_name]: + self.metrics[m.get_name()] = m.compute(event_name=event_name, **kwargs) + + def _take_control_actions(self, event_name: str, **kwargs): + """Invokes the act() method for all the operations registered for a given event. + + Args: + event_name: str. Event name. + kwargs: List of arguments (key, value)-pairs. + """ + if event_name in self.control_actions_on_event: + evaluator = get_evaluator(metrics=self.metrics) + for control_action in self.control_actions_on_event[event_name]: + rule_succeeded = False + try: + rule_succeeded = evaluator.eval( + expr=control_action.rule_str, + previously_parsed=control_action.rule, + ) + if not isinstance(rule_succeeded, bool): + raise TypeError( + "expected the rule to evaluate to a boolean. actual type: %s" + % (type(rule_succeeded)) + ) + except TypeError as et: + raise TypeError("Rule failed due to incorrect type usage") from et + except ValueError as ev: + raise ValueError( + "Rule failed due to use of disallowed packages" + ) from ev + except NameError as en: + raise NameError( + "Rule failed due to use of disallowed variables" + ) from en + except NameNotDefined as en1: + raise NameError( + "Rule failed because some of the variables are not defined" + ) from en1 + except FeatureNotAvailable as ef: + raise NotImplementedError( + "Rule failed because it uses some unsupported features" + ) from ef + if rule_succeeded: + for operation_action in control_action.operation_actions: + logger.info( + "Taking %s action in %s", + operation_action.action, + control_action.name, + ) + operation_action.instance.act( + action=operation_action.action, + event_name=event_name, + **kwargs, + ) + + def _actions_on_event(self, event_name: str, **kwargs): + """Invokes all functions associated with an event. + + Args: + event_name: str. Event name. + kwargs: List of arguments (key, value)-pairs. + """ + self._compute_metrics(event_name, **kwargs) + self._take_control_actions(event_name, **kwargs) + + def _validate_rule(self, rule): + """Validates the rule to check if there are any import attempts + + Returns: + bool + """ + return re.search(r"__", rule) is None + + def on_init_end( + self, + args: TrainingArguments, + state: TrainerState, + control: TrainerControl, + **kwargs, + ): + """This event gets the training arguments which is finally used by the trainer loop. \ + All metric and operation validation is performed here using these arguments. \ + Following this, validated metrics and operations instances are registered for use. + + Args: + args: TrainingArguments. Training arguments for the trainer loop. + state: TrainerState. Current trainer state. + control: TrainerControl. Trainer control object. + kwargs: List of arguments (key, value)-pairs. + """ + # Training arguments, state and controls are folded into kwargs \ + # to be passed off to handlers + kwargs["args"] = args + kwargs["state"] = state + kwargs["control"] = control + + # Check if there any metrics listed in the configuration + if ( + CONTROLLER_METRICS_KEY not in self.trainer_controller_config + or len(self.trainer_controller_config[CONTROLLER_METRICS_KEY]) == 0 + ): + logger.warning("Trainer controller config has no metrics.") + + # Metric handler validation and registration is performed here. + for metric_config in self.trainer_controller_config[CONTROLLER_METRICS_KEY]: + metric_name = metric_config[CONTROLLER_NAME_KEY] + # Get the metric class name from the config section. + metric_handler_name = metric_config[CONTROLLER_CLASS_KEY] + # Get the handler class using the metric class name. + if metric_handler_name not in self.metric_handlers: + raise KeyError(f"Undefined metric handler {metric_handler_name}") + metric_handler_class = self.metric_handlers[metric_handler_name] + # Get the metric handler class arguments specified in the config. + metric_args = metric_config[ARGS_KEY] if ARGS_KEY in metric_config else {} + # Metric handler instance is created here. + metric_handler = metric_handler_class( + name=metric_name, **metric_args, **kwargs + ) + # Add metric instances to the events. + for event_name in metric_handler.get_events(): + if event_name in self.valid_events: + if event_name not in self.metrics_on_event: + self.metrics_on_event[event_name] = [] + self.metrics_on_event[event_name].append(metric_handler) + else: + raise KeyError( + "Event name (%s) is not valid in metric %s" + % (event_name, metric_name) + ) + + # Check if there any operations listed in the configuration + if ( + OPERATIONS_KEY in self.trainer_controller_config + and len(self.trainer_controller_config[OPERATIONS_KEY]) > 0 + ): + # Operation handler validation and registration is performed here. + for operation_config in self.trainer_controller_config[OPERATIONS_KEY]: + operation_name = operation_config[CONTROLLER_NAME_KEY] + # Get the operation class name from the config section. + operation_handler_name = operation_config[CONTROLLER_CLASS_KEY] + # Get the handler class arguments using the operation class name. + operation_args = ( + operation_config[ARGS_KEY] if ARGS_KEY in operation_config else {} + ) + # Operation handler instance is created here. + operation_handler_class = self.operation_handlers[ + operation_handler_name + ] + operation_handler = operation_handler_class( + name=operation_name, **operation_args, **kwargs + ) + # Add operation action instances. + for action_name in operation_handler.get_actions(): + op_key = operation_name + "." + action_name + if op_key in self.operation_actions: + logger.warning( + "Trying to add the operation '%s' when it already exists, ignoring...", + op_key, + ) + continue + self.operation_actions[op_key] = OperationAction( + instance=operation_handler, action=action_name + ) + + # Initialize controllers with respect to events. + if CONTROLLERS_KEY in self.trainer_controller_config: + for controller in self.trainer_controller_config[CONTROLLERS_KEY]: + controller_name: str = controller[CONTROLLER_NAME_KEY] + controller_ops: list[str] = controller[CONTROLLER_OPERATIONS_KEY] + controller_rule: str = controller[CONTROLLER_RULE_KEY] + if not self._validate_rule(controller_rule): + raise ValueError( + "Rule for control %s is invalid" % (controller_name) + ) + for event_name in controller[CONTROLLER_TRIGGERS_KEY]: + if event_name not in self.valid_events: + raise KeyError( + "Controller %s has an invalid event (%s)" + % (controller_name, event_name) + ) + # Generates the byte-code for the rule from the trainer configuration + curr_rule = controller[CONTROLLER_RULE_KEY] + control = Control( + name=controller[CONTROLLER_NAME_KEY], + rule_str=curr_rule, + rule=EvalWithCompoundTypes.parse(expr=curr_rule), + operation_actions=[], + ) + for control_operation_name in controller_ops: + if control_operation_name not in self.operation_actions: + raise KeyError( + "Invalid operation %s for control %s" + % ( + control_operation_name, + controller_name, + ) + ) + control.operation_actions.append( + self.operation_actions[control_operation_name] + ) + if event_name not in self.control_actions_on_event: + self.control_actions_on_event[event_name] = [] + self.control_actions_on_event[event_name].append(control) + self._actions_on_event(event_name="on_init_end", **kwargs) + + def on_step_end( + self, + args: TrainingArguments, + state: TrainerState, + control: TrainerControl, + **kwargs, + ): + # Training arguments, state and controls are folded into kwargs to be passed off to + # handlers + kwargs["args"] = args + kwargs["state"] = state + kwargs["control"] = control + self._actions_on_event(event_name="on_step_end", **kwargs) + + def on_epoch_begin( + self, + args: TrainingArguments, + state: TrainerState, + control: TrainerControl, + **kwargs, + ): + # Training arguments, state and controls are folded into kwargs to be passed off to + # handlers + kwargs["args"] = args + kwargs["state"] = state + kwargs["control"] = control + self._actions_on_event(event_name="on_epoch_begin", **kwargs) + + def on_epoch_end( + self, + args: TrainingArguments, + state: TrainerState, + control: TrainerControl, + **kwargs, + ): + # Training arguments, state and controls are folded into kwargs to be passed off to + # handlers + kwargs["args"] = args + kwargs["state"] = state + kwargs["control"] = control + self._actions_on_event(event_name="on_epoch_end", **kwargs) + + def on_prediction_step( + self, + args: TrainingArguments, + state: TrainerState, + control: TrainerControl, + **kwargs, + ): + # Training arguments, state and controls are folded into kwargs to be passed off to + # handlers + kwargs["args"] = args + kwargs["state"] = state + kwargs["control"] = control + self._actions_on_event(event_name="on_prediction_step", **kwargs) + + def on_predict( + self, + args: TrainingArguments, + state: TrainerState, + control: TrainerControl, + metrics, + **kwargs, + ): + # Training arguments, state and controls are folded into kwargs to be passed off to + # handlers + kwargs["args"] = args + kwargs["state"] = state + kwargs["control"] = control + kwargs["metrics"] = metrics + self._actions_on_event(event_name="on_predict", **kwargs) + + def on_log( + self, + args: TrainingArguments, + state: TrainerState, + control: TrainerControl, + **kwargs, + ): + # Training arguments, state and controls are folded into kwargs to be passed off to + # handlers + kwargs["args"] = args + kwargs["state"] = state + kwargs["control"] = control + self._actions_on_event(event_name="on_log", **kwargs) + + def on_train_end( + self, + args: TrainingArguments, + state: TrainerState, + control: TrainerControl, + **kwargs, + ): + # Training arguments, state and controls are folded into kwargs to be passed off to + # handlers + kwargs["args"] = args + kwargs["state"] = state + kwargs["control"] = control + self._actions_on_event(event_name="on_train_end", **kwargs) + + def on_train_begin( + self, + args: TrainingArguments, + state: TrainerState, + control: TrainerControl, + **kwargs, + ): + # Training arguments, state and controls are folded into kwargs to be passed off to + # handlers + kwargs["args"] = args + kwargs["state"] = state + kwargs["control"] = control + self._actions_on_event(event_name="on_train_begin", **kwargs) + + def on_evaluate( + self, + args: TrainingArguments, + state: TrainerState, + control: TrainerControl, + **kwargs, + ): + # Training arguments, state and controls are folded into kwargs to be passed off to + # handlers + kwargs["args"] = args + kwargs["state"] = state + kwargs["control"] = control + self._actions_on_event(event_name="on_evaluate", **kwargs) diff --git a/tuning/trainercontroller/control.py b/tuning/trainercontroller/control.py new file mode 100644 index 000000000..4c8b6a6d4 --- /dev/null +++ b/tuning/trainercontroller/control.py @@ -0,0 +1,42 @@ +# Copyright The IBM Tuning Team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# SPDX-License-Identifier: Apache-2.0 +# https://spdx.dev/learn/handling-license-info/ + +# Standard +from dataclasses import dataclass +from typing import List, Optional +import ast + +# Local +from tuning.trainercontroller.operations import Operation + + +@dataclass +class OperationAction: + """Stores the operation handler instance and corresponding action""" + + instance: Operation + action: str + + +@dataclass +class Control: + """Stores the name of control, rule byte-code corresponding actions""" + + name: str + rule_str: str + rule: Optional[ast.AST] = None # stores the abstract syntax tree of the parsed rule + operation_actions: Optional[List[OperationAction]] = None diff --git a/tuning/trainercontroller/controllermetrics/__init__.py b/tuning/trainercontroller/controllermetrics/__init__.py new file mode 100644 index 000000000..1c0ffe59f --- /dev/null +++ b/tuning/trainercontroller/controllermetrics/__init__.py @@ -0,0 +1,42 @@ +# Copyright The IBM Tuning Team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# SPDX-License-Identifier: Apache-2.0 +# https://spdx.dev/learn/handling-license-info/ + +# Standard +from typing import Type + +# Local +from .eval_metrics import EvalMetrics +from .loss import Loss +from .trainingstate import TrainingState + +# List of metric handlers +handlers = [] + + +def register(cl: Type): + """Registers the list of metric handlers by adding to the handler list. + + Args: + cl: Class type of the handler + """ + handlers.append(cl) + + +# Register the default metric handlers in this package here +register(TrainingState) +register(EvalMetrics) +register(Loss) diff --git a/tuning/trainercontroller/controllermetrics/eval_metrics.py b/tuning/trainercontroller/controllermetrics/eval_metrics.py new file mode 100644 index 000000000..c3f140f97 --- /dev/null +++ b/tuning/trainercontroller/controllermetrics/eval_metrics.py @@ -0,0 +1,75 @@ +# Copyright The IBM Tuning Team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# SPDX-License-Identifier: Apache-2.0 +# https://spdx.dev/learn/handling-license-info/ + +# Standard +from typing import Any + +# Third Party +from transformers.utils import logging + +# Local +from tuning.trainercontroller.controllermetrics.metricshandler import MetricHandler + +logger = logging.get_logger(__name__) + + +class EvalMetrics(MetricHandler): + """Implements the controller metric which exposes the evaluation metrics""" + + def __init__(self, **kwargs): + """Initializes the metric handler, by registering the event \ + list and arguments with base handler. + + Args: + kwargs: List of arguments (key, value)-pairs + """ + source_events_to_check = {"on_evaluate", "on_predict"} + source_event = kwargs.get("source-event") + if source_event is None: + source_event = "on_evaluate" + elif source_event in source_events_to_check: + super().__init__( + events=[ + source_event, + ], + **kwargs, + ) + else: + raise ValueError( + "Specified source event [%s] is invalid for EvalMetrics" + % (source_event) + ) + + def validate(self) -> bool: + """Validate the training arguments (e.g logging_steps) are \ + compatible with the computation of this metric. + + Returns: + bool + """ + return True + + def compute(self, **kwargs) -> Any: + """Exposes the trainer state. + + Args: + kwargs: Remaining event arguments + + Returns: + dict. Trainer state as a dictionary + """ + return kwargs["metrics"] diff --git a/tuning/trainercontroller/controllermetrics/loss.py b/tuning/trainercontroller/controllermetrics/loss.py new file mode 100644 index 000000000..2fd450148 --- /dev/null +++ b/tuning/trainercontroller/controllermetrics/loss.py @@ -0,0 +1,64 @@ +# Copyright The IBM Tuning Team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# SPDX-License-Identifier: Apache-2.0 +# https://spdx.dev/learn/handling-license-info/ + +# Standard +from typing import Any + +# Third Party +from transformers import TrainerState + +# Local +from tuning.trainercontroller.controllermetrics.metricshandler import MetricHandler + + +class Loss(MetricHandler): + """Implements the controller metric which evaluates loss-per-step""" + + def __init__(self, **kwargs): + """Initializes the metric handler, by registering the event \ + list and arguments with base handler. + + Args: + kwargs: List of arguments (key, value)-pairs + """ + super().__init__(events=["on_log"], **kwargs) + + def validate(self) -> bool: + """Validate the training arguments (e.g logging_steps) are \ + compatible with the computation of this metric. + + Returns: + bool + """ + return True + + def compute(self, state: TrainerState = None, **kwargs) -> Any: + """Exposes the latest step loss value in the log. + + Args: + state: TrainerState object + kwargs: Remaining event arguments + + Returns: + Any. The exposed variables are returned here. + """ + size_of_log_history = len(state.log_history) + for i in range(size_of_log_history - 1, -1, -1): + log = state.log_history[i] + if "loss" not in log: + continue + return float(log["loss"]) diff --git a/tuning/trainercontroller/controllermetrics/metrics.yaml b/tuning/trainercontroller/controllermetrics/metrics.yaml new file mode 100644 index 000000000..e69de29bb diff --git a/tuning/trainercontroller/controllermetrics/metricshandler.py b/tuning/trainercontroller/controllermetrics/metricshandler.py new file mode 100644 index 000000000..1ea746662 --- /dev/null +++ b/tuning/trainercontroller/controllermetrics/metricshandler.py @@ -0,0 +1,87 @@ +# Copyright The IBM Tuning Team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# SPDX-License-Identifier: Apache-2.0 +# https://spdx.dev/learn/handling-license-info/ + +# Standard +from typing import Any, List +import abc + +# Third Party +from transformers import TrainingArguments + + +class MetricHandlerException(Exception): + """Initializes the metric handler exception class""" + + def __init__(self, name): + super().__init__(f"Metric handler {name} failed validation.") + + +class MetricHandler(metaclass=abc.ABCMeta): + """Base class for the controller-metrics""" + + def __init__(self, name: str, events: List[str], args: TrainingArguments, **kwargs): + """Initializes the metric handler base class + + Args: + name: str. Name of the metric handler + event: List[str]. List of events for with the metric computation has to be performed. + args: TrainingArguments. Training arguments. + kwargs: List of arguments (key, value)-pairs + """ + self._name = name + self._events = events + self.training_args = args + self.kwargs = kwargs + if not self.validate(): + raise MetricHandlerException(name) + + def get_name(self): + """Returns the name of the handler. + + Returns: + str + """ + return self._name + + def get_events(self): + """Returns the list of events for the metric. + + Returns: + str + """ + return self._events + + @abc.abstractmethod + def validate(self) -> bool: + """Validate the training arguments (e.g logging_steps) are compatible with + the computation of this metric, and log the errors, and return False when + the metric is incompatible with the configuration + + Returns: + bool + """ + + @abc.abstractmethod + def compute(self, **kwargs) -> Any: + """Computes the controller-metric returns the metric. + + Args: + kwargs: Remaining event arguments. List of arguments (key, value)-pairs. + + Returns: + Any + """ diff --git a/tuning/trainercontroller/controllermetrics/trainingstate.py b/tuning/trainercontroller/controllermetrics/trainingstate.py new file mode 100644 index 000000000..59ab3638c --- /dev/null +++ b/tuning/trainercontroller/controllermetrics/trainingstate.py @@ -0,0 +1,74 @@ +# Copyright The IBM Tuning Team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# SPDX-License-Identifier: Apache-2.0 +# https://spdx.dev/learn/handling-license-info/ + +# Standard +from typing import Any +import dataclasses + +# Third Party +from transformers import TrainerState + +# Local +from tuning.trainercontroller.controllermetrics.metricshandler import MetricHandler + + +class TrainingState(MetricHandler): + """Implements the controller metric which exposes the trainer state""" + + def __init__(self, **kwargs): + """Initializes the metric handler, by registering the event \ + list and arguments with base handler. + + Args: + kwargs: List of arguments (key, value)-pairs + """ + super().__init__( + events=[ + "on_init_end", + "on_step_end", + "on_epoch_begin", + "on_epoch_end", + "on_prediction_step", + "on_predict", + "on_log", + "on_train_end", + "on_train_begin", + "on_evaluate", + ], + **kwargs + ) + + def validate(self) -> bool: + """Validate the training arguments (e.g logging_steps) are \ + compatible with the computation of this metric. + + Returns: + bool + """ + return True + + def compute(self, state: TrainerState = None, **kwargs) -> Any: + """Exposes the trainer state. + + Args: + state: TrainerState object + kwargs: Remaining event arguments + + Returns: + dict. Trainer state as a dictionary + """ + return dataclasses.asdict(state) diff --git a/tuning/trainercontroller/operations/__init__.py b/tuning/trainercontroller/operations/__init__.py new file mode 100644 index 000000000..99456d7ec --- /dev/null +++ b/tuning/trainercontroller/operations/__init__.py @@ -0,0 +1,22 @@ +# Standard +from typing import Type + +# Local +from .hfcontrols import HFControls +from .operation import Operation + +# List of operation handlers +operation_handlers = [] + + +def register(cl: Type): + """Registers the list of operation handlers by adding to the handler list. + + Args: + cl: Class type of the handler + """ + operation_handlers.append(cl) + + +# Register the default operation handlers in this package here +register(HFControls) diff --git a/tuning/trainercontroller/operations/hfcontrols.py b/tuning/trainercontroller/operations/hfcontrols.py new file mode 100644 index 000000000..2bba9a1d2 --- /dev/null +++ b/tuning/trainercontroller/operations/hfcontrols.py @@ -0,0 +1,45 @@ +# Standard +from dataclasses import fields +import inspect +import re + +# Third Party +from transformers import TrainerControl +from transformers.utils import logging + +# Local +from .operation import Operation + +logger = logging.get_logger(__name__) + + +class HFControls(Operation): + """Implements the control actions for the HuggingFace controls in + transformers.TrainerControl class.""" + + def __init__(self, **kwargs): + """Initializes the HuggingFace controls. In this init, the fields with `should_` of the + transformers.TrainerControl data class are extracted, and for each of those fields, the + control_action() method's pointer is set, and injected as a class member function. + + Args: + kwargs: List of arguments (key, value)-pairs + """ + self.kwargs = kwargs + for control_field in fields(TrainerControl): + if re.search(r"^should_.+", control_field.name) is not None: + setattr(self, control_field.name, self.control_action) + super().__init__() + + def control_action(self, control: TrainerControl, **kwargs): + """This method peeks into the stack-frame of the caller to get the action the triggered + a call to it. Using the name of the action, the value of the control is set. + + Args: + control: TrainerControl. Data class for controls. + kwargs: List of arguments (key, value)-pairs + """ + logger.debug("Arguments passed to control_action: %s", repr(kwargs)) + frame_info = inspect.currentframe().f_back + arg_values = inspect.getargvalues(frame_info) + setattr(control, arg_values.locals["action"], True) diff --git a/tuning/trainercontroller/operations/operation.py b/tuning/trainercontroller/operations/operation.py new file mode 100644 index 000000000..916420e81 --- /dev/null +++ b/tuning/trainercontroller/operations/operation.py @@ -0,0 +1,46 @@ +# Standard +import abc +import inspect +import re + + +class Operation(metaclass=abc.ABCMeta): + """Base class for operations""" + + def __init__(self): + """Initializes the HuggingFace controls. In this init, we follow the convention that + every action should preceed with prefix `should_`. If so, it is treated as a valid + action. + """ + self.valid_actions = {} + for action_name, action_method in inspect.getmembers( + self, predicate=inspect.ismethod + ): + if re.search(r"^should_.+", action_name) is not None: + self.valid_actions[action_name] = action_method + + def validate(self, action: str) -> bool: + """Validates the action by checking if it valid action or not. + + Args: + action: str. String depicting the action. + + Returns: + bool. Indicates True if valid. If not, returns False. + """ + return action in self.valid_actions + + def act(self, action: str, **kwargs): + """Validates the action and invokes it. + + Args: + action: str. String depicting the action. + kwargs: List of arguments (key, value)-pairs. + """ + if not self.validate(action): + raise ValueError(f"Invalid operation {action}") + self.valid_actions[action](**kwargs) + + def get_actions(self) -> list[str]: + """Gets the list of all valid actions.""" + return self.valid_actions.keys() diff --git a/tuning/trainercontroller/operations/operations.yaml b/tuning/trainercontroller/operations/operations.yaml new file mode 100644 index 000000000..bbf6724e9 --- /dev/null +++ b/tuning/trainercontroller/operations/operations.yaml @@ -0,0 +1,3 @@ +operations: + - name: hfcontrols + class: HFControls diff --git a/tuning/utils/__init__.py b/tuning/utils/__init__.py index e69de29bb..38a9531ef 100644 --- a/tuning/utils/__init__.py +++ b/tuning/utils/__init__.py @@ -0,0 +1,13 @@ +# Copyright The FMS HF Tuning Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tuning/utils/config_utils.py b/tuning/utils/config_utils.py index 58896c1f9..fc7b7b46f 100644 --- a/tuning/utils/config_utils.py +++ b/tuning/utils/config_utils.py @@ -1,3 +1,17 @@ +# Copyright The FMS HF Tuning Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + # Standard from dataclasses import asdict diff --git a/tuning/utils/data_type_utils.py b/tuning/utils/data_type_utils.py index 42b058cde..cefebb100 100644 --- a/tuning/utils/data_type_utils.py +++ b/tuning/utils/data_type_utils.py @@ -1,3 +1,17 @@ +# Copyright The FMS HF Tuning Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + # Standard from typing import Union @@ -22,7 +36,7 @@ def str_to_torch_dtype(dtype_str: str) -> torch.dtype: dt = getattr(torch, dtype_str, None) if not isinstance(dt, torch.dtype): logger.error(" ValueError: Unrecognized data type of a torch.Tensor") - exit(-1) + raise ValueError("Unrecognized data type of a torch.Tensor") return dt diff --git a/tuning/utils/evaluator.py b/tuning/utils/evaluator.py new file mode 100644 index 000000000..42095e70c --- /dev/null +++ b/tuning/utils/evaluator.py @@ -0,0 +1,20 @@ +# Standard +from math import sqrt + +# Third Party +from simpleeval import DEFAULT_FUNCTIONS, DEFAULT_NAMES, EvalWithCompoundTypes + + +def get_evaluator(metrics: dict) -> EvalWithCompoundTypes: + """Returns an evaluator that can be used to evaluate simple Python expressions.""" + all_names = { + **metrics, + **DEFAULT_NAMES.copy(), + } + all_funcs = { + "abs": abs, + "len": len, + "sqrt": sqrt, + **DEFAULT_FUNCTIONS.copy(), + } + return EvalWithCompoundTypes(functions=all_funcs, names=all_names) diff --git a/tuning/utils/merge_model_utils.py b/tuning/utils/merge_model_utils.py index a8a41fecb..fc4f357a7 100644 --- a/tuning/utils/merge_model_utils.py +++ b/tuning/utils/merge_model_utils.py @@ -1,6 +1,19 @@ +# Copyright The FMS HF Tuning Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + # Standard from typing import Union -import argparse import json import os @@ -27,7 +40,7 @@ def create_merged_model( References: - https://github.com/huggingface/peft/issues/1040 - https://github.com/huggingface/peft/issues/280#issuecomment-1500805831 - - https://huggingface.co/docs/peft/main/en/package_reference/lora#peft.LoraModel.add_weighted_adapter + - https://huggingface.co/docs/peft/main/en/package_reference/lora#peft.LoraModel.add_weighted_adapter # pylint: disable=line-too-long Args: checkpoint_model: Union[str, list[str]] @@ -82,7 +95,7 @@ def fetch_base_model_from_checkpoint(checkpoint_model: str) -> str: if not os.path.isfile(adapter_config): raise FileNotFoundError("Unable to locate adapter config to infer base model!") - with open(adapter_config, "r") as cfg: + with open(adapter_config, "r", encoding="utf-8") as cfg: adapter_dict = json.load(cfg) if "base_model_name_or_path" not in adapter_dict: raise KeyError(