diff --git a/.github/workflows/build-wheel.yml b/.github/workflows/build-wheel.yml new file mode 100644 index 00000000000..292d874dd2f --- /dev/null +++ b/.github/workflows/build-wheel.yml @@ -0,0 +1,58 @@ +name: Build PyPi Wheel +on: + pull_request: + types: [opened, synchronize, reopened] + branches: + - main + - 'release/[0-9]+.[0-9]+' + push: + branches: + - main + release: + types: [created, published] + schedule: + - cron: '0 0 * * *' + +permissions: + id-token: write + contents: read + +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }} + cancel-in-progress: true + +# if not dev or release, will create a nightly build +# everything is pushed to internal unless created through a nightly scheduled cron job which creates the build or +# missing release tag workflow/needs to be added in +env: + INTERNAL: ${{ github.event_name != 'schedule' && github.event_name != 'release'}} + RELEASE: ${{ github.event_name =='release' || (startsWith(github.base_ref, 'release/') && github.event_name == 'pull_request')}} + DEV: ${{ github.base_ref == 'main' && github.event_name == 'pull_request'}} + NAME: ${{ github.event.number }} + +jobs: + build_and_push: + runs-on: ubuntu-latest + outputs: + wheel: ${{ steps.push-wheel.outputs.wheel }} + steps: + - name: Checkout code + uses: actions/checkout@v3 + - name: Login to s3 + uses: aws-actions/configure-aws-credentials@v2 + with: + role-to-assume: ${{ secrets.AWS_WEBIDENTITY_FOR_GITHUB_ACTIONS }} + aws-region: us-east-1 + - name: Build PyPi Wheel + id: build-wheel + uses: neuralmagic/nm-actions/actions/pypi_build@main + with: + dev: $DEV + release: $RELEASE + name: $NAME + - name: Push to s3 bucket + id: push-wheel + uses: neuralmagic/nm-actions/actions/s3_push@main + with: + filename: dist/*.whl + internal: $INTERNAL diff --git a/_scratch/mask.py b/_scratch/mask.py deleted file mode 100644 index 8d2cc6a1b69..00000000000 --- a/_scratch/mask.py +++ /dev/null @@ -1,72 +0,0 @@ -data = { - "instruction":"Improve the following piece of code that throws multiple errors. You should also write a function test to make sure your code runs correctly.\n\nInitial code:\ndef my_function(x, y, z):\n a = x + 5y - z\n print x, y, z, a\n return a, b, c, d\n\nYour task is not only to resolve the errors but also to improve the code, making it efficient and well-structured. End your task by writing a function test to ensure your code is running as expected. \n\nNote: The context of this coding exercise is language-agnostic. Therefore, be aware of potential differences in syntax and functionalities among different programming languages, and adjust your answer with respect to the programming language of your choice.", - "output":"The initial code contains some errors which include:\n\n1. The syntax for the arithmetic operation is wrong, it should be 5*y instead of 5y.\n2. Variables \\'b\\',\\'c\\', and \\'d\\' are used without being declared.\n3. Using commas in print function may cause errors in some programming languages.\n\nBelow is the corrected code in Python:\n\n```python\ndef my_function(x, y, z):\n a = x + 5*y - z\n print (x, y, z, a)\n return a\n```\n\nThe above code will calculate the expression x + 5*y - z and return the result as \\'a\\'. Then it will print the values of x, y, z, and a.\n\nThe function test could look something like this:\n\n```python\ndef test_function():\n result = my_function(1, 2, 3)\n assert(result == -2), \"The values seem to be incorrect.\"\n print(\"my_function passed.\")\n```\n\nAbove code will test my_function using the inputs (1, 2, 3). If the output is not -2, it will print \"The values seem to be incorrect.\" If the output is correct, it will print \"my_function passed.\"\n\nSo, when we run the test function:\n\n```python\ntest_function()\n```\n\nIt will test our my_function to see if it\\'s working as expected. The \"assert\" statement is used to let the program test if the condition is true. If the condition is false, the program throws an AssertionError exception.\n\nNote: The above codes are in Python, if these are converted in other language, kindly do the necessary syntax adjustments.", - "prompt":"[Instruction]:\nImprove the following piece of code that throws multiple errors. You should also write a function test to make sure your code runs correctly.\n\nInitial code:\ndef my_function(x, y, z):\n a = x + 5y - z\n print x, y, z, a\n return a, b, c, d\n\nYour task is not only to resolve the errors but also to improve the code, making it efficient and well-structured. End your task by writing a function test to ensure your code is running as expected. \n\nNote: The context of this coding exercise is language-agnostic. Therefore, be aware of potential differences in syntax and functionalities among different programming languages, and adjust your answer with respect to the programming language of your choice.\n\n[Response]:", - "text":"[Instruction]:\nImprove the following piece of code that throws multiple errors. You should also write a function test to make sure your code runs correctly.\n\nInitial code:\ndef my_function(x, y, z):\n a = x + 5y - z\n print x, y, z, a\n return a, b, c, d\n\nYour task is not only to resolve the errors but also to improve the code, making it efficient and well-structured. End your task by writing a function test to ensure your code is running as expected. \n\nNote: The context of this coding exercise is language-agnostic. Therefore, be aware of potential differences in syntax and functionalities among different programming languages, and adjust your answer with respect to the programming language of your choice.\n\n[Response]:The initial code contains some errors which include:\n\n1. The syntax for the arithmetic operation is wrong, it should be 5*y instead of 5y.\n2. Variables \\'b\\',\\'c\\', and \\'d\\' are used without being declared.\n3. Using commas in print function may cause errors in some programming languages.\n\nBelow is the corrected code in Python:\n\n```python\ndef my_function(x, y, z):\n a = x + 5*y - z\n print (x, y, z, a)\n return a\n```\n\nThe above code will calculate the expression x + 5*y - z and return the result as \\'a\\'. Then it will print the values of x, y, z, and a.\n\nThe function test could look something like this:\n\n```python\ndef test_function():\n result = my_function(1, 2, 3)\n assert(result == -2), \"The values seem to be incorrect.\"\n print(\"my_function passed.\")\n```\n\nAbove code will test my_function using the inputs (1, 2, 3). If the output is not -2, it will print \"The values seem to be incorrect.\" If the output is correct, it will print \"my_function passed.\"\n\nSo, when we run the test function:\n\n```python\ntest_function()\n```\n\nIt will test our my_function to see if it\\'s working as expected. The \"assert\" statement is used to let the program test if the condition is true. If the condition is false, the program throws an AssertionError exception.\n\nNote: The above codes are in Python, if these are converted in other language, kindly do the necessary syntax adjustments." -} - -prompt = "[fooo]:\nname me some cold blooded animals\n\n[Response]:\n" -bs = "jdlsaj;fdsa dls;ajfd;a dfjdaklfd jsak; fsajfiajeiga sdkga idgdhangia n" - -prompts = prompt + bs + prompt + bs - - -def mask_prompt(target: str): - if not target.startswith("["): - target = "[" + target - if not target.endswith("]"): - target += "]" - - -def generate_mask(string:str, prompt: str, censor: str): - mask = ['1'] * len(string) - is_prompt = True - counter = 0 - for i, char in enumerate(string): - if not is_prompt: - mask[i] = "0" - - if counter > 0: - if not is_prompt and char == prompt[counter]: - counter += 1 - elif is_prompt and char == censor[counter]: - counter += 1 - else: - counter = 0 - - if counter == len(prompt) and not is_prompt: - mask[i-counter+1:i+1] = ["1"] * counter - counter = 0 - is_prompt = True - - if counter == len(censor) and is_prompt: - mask[i-counter+1:i+1] = ["0"] * counter - counter = 0 - is_prompt = False - - if prompt.startswith(char) or censor.startswith(char): - counter = 1 - - return ''.join(mask) - -def filter_string(mask:str, string:str): - filtered_list = [] - current_chunk = '' - for char, m in zip(string, mask): - if m == '1': - current_chunk += char - elif current_chunk: - filtered_list.append(current_chunk) - current_chunk = '' - if current_chunk: - filtered_list.append(current_chunk) - return filtered_list - -# Example usage: -string = "[prompt] hey there\n\n[censor]hello back" -string += string -prompt = "[prompt]" -censor = "[censor]" -mask = generate_mask(string, prompt, censor) -filtered_list = filter_string(mask, string) - diff --git a/setup.py b/setup.py index ed8f609a991..c64705e68dd 100644 --- a/setup.py +++ b/setup.py @@ -20,6 +20,7 @@ # default variables to be overwritten by the version.py file is_release = None +is_dev = None version = "unknown" version_major_minor = version @@ -28,7 +29,12 @@ print(f"loaded version {version} from src/sparseml/version.py") version_nm_deps = f"{version_major_minor}.0" -_PACKAGE_NAME = "sparseml" if is_release else "sparseml-nightly" +if is_release: + _PACKAGE_NAME = "sparseml" +elif is_dev: + _PACKAGE_NAME = "sparseml-dev" +else: + _PACKAGE_NAME = "sparseml-nightly" _deps = [ "setuptools<=59.5.0", diff --git a/src/sparseml/transformers/finetune/runner.py b/src/sparseml/transformers/finetune/runner.py index cbcf9cea71f..093414a1647 100644 --- a/src/sparseml/transformers/finetune/runner.py +++ b/src/sparseml/transformers/finetune/runner.py @@ -40,7 +40,11 @@ ) from sparseml.transformers.finetune.model_args import ModelArguments from sparseml.transformers.finetune.training_args import TrainingArguments -from sparseml.utils.fsdp.helpers import is_fsdp_model, unwrap_and_export_model +from sparseml.utils.fsdp.helpers import ( + find_and_move_state_dicts_to_cpu, + is_fsdp_model, + unwrap_and_export_model, +) _LOGGER: logging.Logger = logging.getLogger(__name__) @@ -175,6 +179,15 @@ def one_shot(self, stage: Optional[str] = None): output_dir=self._output_dir, tokenizer=self.tokenizer, ) + # only allow the main process move the state + # dicts to cpu + if self.trainer.accelerator.is_main_process: + # assuming quantization is the last step + # we no longer need the original model + # and can safely delete it to save memory + del self.trainer.model + find_and_move_state_dicts_to_cpu(self._output_dir) + else: save_model_and_recipe( model=self.trainer.model, diff --git a/src/sparseml/utils/fsdp/helpers.py b/src/sparseml/utils/fsdp/helpers.py index dccde10c37a..abae74f612c 100644 --- a/src/sparseml/utils/fsdp/helpers.py +++ b/src/sparseml/utils/fsdp/helpers.py @@ -12,7 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. +import logging import operator +from pathlib import Path from typing import Optional, Union @@ -25,6 +27,7 @@ except ImportError: FullyShardedDataParallel = None +import torch from torch.nn import Module from sparseml.core.model import ModifiableModel @@ -39,8 +42,11 @@ "unwrap_and_export_model", "save_pretrained_fsdp", "get_fsdp_parent", + "find_and_move_state_dicts_to_cpu", ] +_LOGGER = logging.getLogger(__name__) + def is_fsdp_model(model: Module) -> bool: """ @@ -113,6 +119,27 @@ def unwrap_and_export_model(model, accelerator, output_dir, tokenizer): ) +def find_and_move_state_dicts_to_cpu(output_dir: str): + """ + Looks for state dicts in the output directory and overwrites them + with cpu state dicts. + + this is needed for quantized models trained with FSDP as the state dict + contains device information, which can cause issues when loading the model + using transformers AutoModel.from_pretrained(...) if the device information + is not removed, assumes the state dicts are named pytorch_model*.bin + """ + + for model_file in Path(output_dir).rglob("pytorch_model*.bin"): + loaded_dict = torch.load(model_file) + for key, value in loaded_dict.items(): + if isinstance(value, torch.Tensor): + loaded_dict[key] = value.cpu() + + torch.save(loaded_dict, model_file) + _LOGGER.info(f"Moved state dict {model_file} to cpu") + + def save_pretrained_fsdp(model, accelerator, output_dir, save_safetensors: bool = True): full_state_dict_config = FullStateDictConfig(offload_to_cpu=True, rank0_only=True) """ diff --git a/src/sparseml/version.py b/src/sparseml/version.py index 33ac25d7a51..7a0706db10e 100644 --- a/src/sparseml/version.py +++ b/src/sparseml/version.py @@ -21,14 +21,17 @@ version_base = "1.7.0" is_release = False # change to True to set the generated version as a release version +is_dev = False +dev_number = None def _generate_version(): - return ( - version_base - if is_release - else f"{version_base}.{date.today().strftime('%Y%m%d')}" - ) + if is_release: + return version_base + elif is_dev: + return f"{version_base}.dev{dev_number}" + else: + return f"{version_base}.{date.today().strftime('%Y%m%d')}" __all__ = [