diff --git a/plugins/accelerated-peft/src/fms_acceleration_peft/framework_plugin_autogptq.py b/plugins/accelerated-peft/src/fms_acceleration_peft/framework_plugin_autogptq.py
index 41ea2d6f..99568a94 100644
--- a/plugins/accelerated-peft/src/fms_acceleration_peft/framework_plugin_autogptq.py
+++ b/plugins/accelerated-peft/src/fms_acceleration_peft/framework_plugin_autogptq.py
@@ -241,7 +241,7 @@ def requires_custom_loading(self):
         return True
 
     @property
-    def requires_agumentation(self):
+    def requires_augmentation(self):
         return True
 
     def augmentation(
diff --git a/plugins/accelerated-peft/src/fms_acceleration_peft/framework_plugin_bnb.py b/plugins/accelerated-peft/src/fms_acceleration_peft/framework_plugin_bnb.py
index b7202add..1900002d 100644
--- a/plugins/accelerated-peft/src/fms_acceleration_peft/framework_plugin_bnb.py
+++ b/plugins/accelerated-peft/src/fms_acceleration_peft/framework_plugin_bnb.py
@@ -185,7 +185,7 @@ def requires_custom_loading(self):
         return True
 
     @property
-    def requires_agumentation(self):
+    def requires_augmentation(self):
         # will skip the augmentation if _no_peft_model == True
         return not self._no_peft_model
 
diff --git a/plugins/accelerated-peft/tests/test_peft_plugins.py b/plugins/accelerated-peft/tests/test_peft_plugins.py
index 38534d5d..d36b3ce8 100644
--- a/plugins/accelerated-peft/tests/test_peft_plugins.py
+++ b/plugins/accelerated-peft/tests/test_peft_plugins.py
@@ -54,7 +54,7 @@ def test_configure_gptq_plugin():
 
         # check flags and callbacks
         assert framework.requires_custom_loading
-        assert framework.requires_agumentation
+        assert framework.requires_augmentation
         assert len(framework.get_callbacks_and_ready_for_train()) == 0
 
     # attempt to activate plugin with configuration pointing to wrong path
@@ -171,7 +171,7 @@ def test_configure_bnb_plugin():
 
         # check flags and callbacks
         assert framework.requires_custom_loading
-        assert framework.requires_agumentation
+        assert framework.requires_augmentation
         assert len(framework.get_callbacks_and_ready_for_train()) == 0
 
     # test valid combinatinos
@@ -187,7 +187,7 @@ def test_configure_bnb_plugin():
     ):
         # check flags and callbacks
         assert framework.requires_custom_loading
-        assert framework.requires_agumentation
+        assert framework.requires_augmentation
         assert len(framework.get_callbacks_and_ready_for_train()) == 0
 
     # test no_peft_model is true skips plugin.augmentation
@@ -202,7 +202,7 @@ def test_configure_bnb_plugin():
         require_packages_check=False,
     ):
         # check flags and callbacks
-        assert (not correct_value) == framework.requires_agumentation
+        assert (not correct_value) == framework.requires_augmentation
 
     # attempt to activate plugin with configuration pointing to wrong path
     # - raise with message that no plugins can be configured
diff --git a/plugins/attention-and-distributed-packing/src/fms_acceleration_aadp/framework_plugin_multipack.py b/plugins/attention-and-distributed-packing/src/fms_acceleration_aadp/framework_plugin_multipack.py
index aa9134a6..391743c6 100644
--- a/plugins/attention-and-distributed-packing/src/fms_acceleration_aadp/framework_plugin_multipack.py
+++ b/plugins/attention-and-distributed-packing/src/fms_acceleration_aadp/framework_plugin_multipack.py
@@ -61,7 +61,7 @@ def __init__(
         assert self._pad_token_id is not None, "need to get pad token id"
 
     @property
-    def requires_agumentation(self):
+    def requires_augmentation(self):
         return True
 
     def augmentation(
diff --git a/plugins/attention-and-distributed-packing/src/fms_acceleration_aadp/framework_plugin_padding_free.py b/plugins/attention-and-distributed-packing/src/fms_acceleration_aadp/framework_plugin_padding_free.py
index 0e4e5ef9..596b5600 100644
--- a/plugins/attention-and-distributed-packing/src/fms_acceleration_aadp/framework_plugin_padding_free.py
+++ b/plugins/attention-and-distributed-packing/src/fms_acceleration_aadp/framework_plugin_padding_free.py
@@ -41,7 +41,7 @@ def __init__(self, configurations: Dict[str, Dict]):
         )
 
     @property
-    def requires_agumentation(self):
+    def requires_augmentation(self):
         return True
 
     def augmentation(
diff --git a/plugins/framework/README.md b/plugins/framework/README.md
index 5b3cbfd7..4895f322 100644
--- a/plugins/framework/README.md
+++ b/plugins/framework/README.md
@@ -45,7 +45,7 @@ model, (peft_config,) = framework.augmentation(
 )
 ```
 
-We also provide `framework.requires_agumentation` to check if augumentation is required by the plugins.
+We also provide `framework.requires_augmentation` to check if augmentation is required by the plugins.
 
 Finally pass the model to train:
 
diff --git a/plugins/framework/src/fms_acceleration/framework.py b/plugins/framework/src/fms_acceleration/framework.py
index 3a393815..75b436c9 100644
--- a/plugins/framework/src/fms_acceleration/framework.py
+++ b/plugins/framework/src/fms_acceleration/framework.py
@@ -199,10 +199,10 @@ def augmentation(
                 x in model_archs for x in plugin.restricted_model_archs
             ):
                 raise ValueError(
-                    f"Model architectures in '{model_archs}' are supported for '{plugin_name}'."
+                    f"Model architectures in '{model_archs}' are not supported for '{plugin_name}'."
                 )
 
-            if plugin.requires_agumentation:
+            if plugin.requires_augmentation:
                 model, modifiable_args = plugin.augmentation(
                     model, train_args, modifiable_args=modifiable_args
                 )
@@ -214,8 +214,8 @@ def requires_custom_loading(self):
         return len(self.plugins_require_custom_loading) > 0
 
     @property
-    def requires_agumentation(self):
-        return any(x.requires_agumentation for _, x in self.active_plugins)
+    def requires_augmentation(self):
+        return any(x.requires_augmentation for _, x in self.active_plugins)
 
     def get_callbacks_and_ready_for_train(
         self, model: torch.nn.Module = None, accelerator: Accelerator = None
diff --git a/plugins/framework/src/fms_acceleration/framework_plugin.py b/plugins/framework/src/fms_acceleration/framework_plugin.py
index 28fecebf..94ea4ffa 100644
--- a/plugins/framework/src/fms_acceleration/framework_plugin.py
+++ b/plugins/framework/src/fms_acceleration/framework_plugin.py
@@ -171,7 +171,7 @@ def requires_custom_loading(self):
         return False
 
     @property
-    def requires_agumentation(self):
+    def requires_augmentation(self):
         return False
 
     def model_loader(self, model_name: str, **kwargs):
diff --git a/plugins/framework/src/fms_acceleration/utils/test_utils.py b/plugins/framework/src/fms_acceleration/utils/test_utils.py
index b1f731d1..6a3bc123 100644
--- a/plugins/framework/src/fms_acceleration/utils/test_utils.py
+++ b/plugins/framework/src/fms_acceleration/utils/test_utils.py
@@ -159,8 +159,8 @@ def create_plugin_cls(
     restricted_models: Set = None,
     require_pkgs: Set = None,
     requires_custom_loading: bool = False,
-    requires_agumentation: bool = False,
-    agumentation: Callable = None,
+    requires_augmentation: bool = False,
+    augmentation: Callable = None,
     model_loader: Callable = None,
 ):
     "helper function to create plugin class"
@@ -174,11 +174,11 @@ def create_plugin_cls(
         "restricted_model_archs": restricted_models,
         "require_packages": require_pkgs,
         "requires_custom_loading": requires_custom_loading,
-        "requires_agumentation": requires_agumentation,
+        "requires_augmentation": requires_augmentation,
     }
 
-    if agumentation is not None:
-        attributes["augmentation"] = agumentation
+    if augmentation is not None:
+        attributes["augmentation"] = augmentation
 
     if model_loader is not None:
         attributes["model_loader"] = model_loader
diff --git a/plugins/framework/tests/test_framework.py b/plugins/framework/tests/test_framework.py
index 4fd43eb2..b3f4eb9e 100644
--- a/plugins/framework/tests/test_framework.py
+++ b/plugins/framework/tests/test_framework.py
@@ -68,7 +68,7 @@ def test_model_with_no_config_raises():
 
     # create model and (incomplete) plugin with requires_augmentation = True
     model_no_config = torch.nn.Module()  # empty model
-    incomplete_plugin = create_plugin_cls(requires_agumentation=True)
+    incomplete_plugin = create_plugin_cls(requires_augmentation=True)
 
     # register and activate 1 incomplete plugin, and:
     # 1. test correct plugin registration and activation.
@@ -104,13 +104,13 @@ def test_single_plugin():
     empty_plugin = create_plugin_cls()
     incomplete_plugin = create_plugin_cls(
         restricted_models={"CausalLM"},
-        requires_agumentation=True,
+        requires_augmentation=True,
     )
     plugin = create_plugin_cls(
         restricted_models={"CausalLM"},
-        requires_agumentation=True,
+        requires_augmentation=True,
         requires_custom_loading=True,
-        agumentation=dummy_augmentation,
+        augmentation=dummy_augmentation,
         model_loader=dummy_custom_loader,
     )
     train_args = None  # dummy for now
@@ -175,32 +175,32 @@ def test_two_plugins():
     model = create_noop_model_with_archs(archs=["CausalLM"])
 
     incomp_plugin1 = create_plugin_cls(
-        restricted_models={"CausalLM"}, requires_agumentation=True
+        restricted_models={"CausalLM"}, requires_augmentation=True
     )
-    incomp_plugin2 = create_plugin_cls(requires_agumentation=True)
+    incomp_plugin2 = create_plugin_cls(requires_augmentation=True)
     incomp_plugin3 = create_plugin_cls(
-        class_name="PluginNoop2", requires_agumentation=True
+        class_name="PluginNoop2", requires_augmentation=True
     )
     plugin1 = create_plugin_cls(
         restricted_models={"CausalLM"},
-        requires_agumentation=True,
+        requires_augmentation=True,
         requires_custom_loading=True,
-        agumentation=dummy_augmentation,
+        augmentation=dummy_augmentation,
         model_loader=dummy_custom_loader,
     )
     plugin2 = create_plugin_cls(
         class_name="PluginNoop2",
         restricted_models={"CausalLM"},
-        requires_agumentation=True,
+        requires_augmentation=True,
         requires_custom_loading=True,
-        agumentation=dummy_augmentation,
+        augmentation=dummy_augmentation,
         model_loader=dummy_custom_loader,
     )
     plugin3_no_loader = create_plugin_cls(
         class_name="PluginNoop2",
         restricted_models={"CausalLM"},
-        requires_agumentation=True,
-        agumentation=dummy_augmentation,
+        requires_augmentation=True,
+        augmentation=dummy_augmentation,
     )
 
     train_args = None  # dummy for now
@@ -299,8 +299,8 @@ def _hook(
     for class_name in ["PluginDEF", "PluginABC"]:
         plugin = create_plugin_cls(
             class_name=class_name,
-            requires_agumentation=True,
-            agumentation=hook_builder(act_order=plugin_activation_order),
+            requires_augmentation=True,
+            augmentation=hook_builder(act_order=plugin_activation_order),
         )
         plugins_to_be_installed.append((class_name, plugin))
 
@@ -319,8 +319,8 @@ def test_plugin_registration_combination_logic():
 
     plugin = create_plugin_cls(
         restricted_models={"CausalLM"},
-        requires_agumentation=True,
-        agumentation=dummy_augmentation,
+        requires_augmentation=True,
+        augmentation=dummy_augmentation,
     )
 
     configuration_contents = {"existing1": {"key1": 1}, "existing2": {"key1": 1}}
diff --git a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/framework_plugin_fast_kernels.py b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/framework_plugin_fast_kernels.py
index df21fd5c..0d7ce802 100644
--- a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/framework_plugin_fast_kernels.py
+++ b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/framework_plugin_fast_kernels.py
@@ -128,7 +128,7 @@ def __init__(self, configurations: Dict[str, Dict]):
         )
 
     @property
-    def requires_agumentation(self):
+    def requires_augmentation(self):
         return True
 
     def augmentation(
diff --git a/plugins/fused-ops-and-kernels/tests/test_foak_plugins.py b/plugins/fused-ops-and-kernels/tests/test_foak_plugins.py
index 11e91ff6..9d1c0c97 100644
--- a/plugins/fused-ops-and-kernels/tests/test_foak_plugins.py
+++ b/plugins/fused-ops-and-kernels/tests/test_foak_plugins.py
@@ -47,7 +47,7 @@ def test_configure_gptq_foak_plugin():
 
         # check flags and callbacks
         assert framework.requires_custom_loading is False
-        assert framework.requires_agumentation
+        assert framework.requires_augmentation
         assert len(framework.get_callbacks_and_ready_for_train()) == 0
 
     # attempt to activate plugin with configuration pointing to wrong path
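For reference, the renamed property is used the way the plugin framework README excerpt above describes: callers check `requires_augmentation` before invoking `augmentation`. Below is a minimal sketch, assuming a framework object already instantiated from a plugin configuration and with `model`, `train_args`, and `peft_config` in scope; these names are illustrative and taken from the README excerpt, not from this diff.

```python
# Illustrative usage sketch (not part of this diff).
# Assumes `framework` was built from an acceleration plugin configuration,
# and that `model`, `train_args`, and `peft_config` already exist in the caller.
if framework.requires_augmentation:
    # augmentation may replace the model and modify the passed-in args;
    # the tuple form of modifiable_args follows the README excerpt above.
    model, (peft_config,) = framework.augmentation(
        model, train_args, modifiable_args=(peft_config,)
    )
```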