fix: function name 'requires_agumentation' to 'requires_augmentation' (#118)

* fix: change function name 'requires_agumentation' to 'requires_augmentation'

Signed-off-by: Will Johnson <[email protected]>

* fix: replacement error

Signed-off-by: Will Johnson <[email protected]>

* fix: error

Signed-off-by: Will Johnson <[email protected]>

* fix: 'requires_agumentation' -> 'requires_augmentation'

Signed-off-by: Will Johnson <[email protected]>

---------

Signed-off-by: Will Johnson <[email protected]>
willmj authored Jan 9, 2025
1 parent e0bca3e commit 03035e6
Showing 12 changed files with 38 additions and 38 deletions.
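
In short, every occurrence of the misspelled `requires_agumentation` property becomes `requires_augmentation`. As a hedged sketch of why the spelling matters, this is the call pattern from the `framework.py` diff below (the names `plugin`, `model`, `train_args`, and `modifiable_args` come from that diff; the surrounding method is omitted):

```python
# From AccelerationFramework.augmentation (see the framework.py diff below):
# the framework only invokes a plugin's augmentation hook when the plugin
# reports that it needs one via the (now correctly spelled) property.
if plugin.requires_augmentation:
    model, modifiable_args = plugin.augmentation(
        model, train_args, modifiable_args=modifiable_args
    )
```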
@@ -241,7 +241,7 @@ def requires_custom_loading(self):
         return True

     @property
-    def requires_agumentation(self):
+    def requires_augmentation(self):
         return True

     def augmentation(
@@ -185,7 +185,7 @@ def requires_custom_loading(self):
         return True

     @property
-    def requires_agumentation(self):
+    def requires_augmentation(self):
         # will skip the augmentation if _no_peft_model == True
         return not self._no_peft_model
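
As an illustration of the flag above, a minimal sketch (assuming `plugin` is an instance of the accelerated-peft class in this diff; in practice `_no_peft_model` is derived from the plugin's configuration rather than assigned directly):

```python
# Hypothetical instance of the plugin shown in the diff above.
plugin._no_peft_model = True  # normally set via the plugin configuration
# With no PEFT model requested, the plugin opts out of augmentation,
# so AccelerationFramework.augmentation() will skip it.
assert not plugin.requires_augmentation
```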
8 changes: 4 additions & 4 deletions plugins/accelerated-peft/tests/test_peft_plugins.py
@@ -54,7 +54,7 @@ def test_configure_gptq_plugin():

         # check flags and callbacks
         assert framework.requires_custom_loading
-        assert framework.requires_agumentation
+        assert framework.requires_augmentation
         assert len(framework.get_callbacks_and_ready_for_train()) == 0

     # attempt to activate plugin with configuration pointing to wrong path
@@ -171,7 +171,7 @@ def test_configure_bnb_plugin():

         # check flags and callbacks
         assert framework.requires_custom_loading
-        assert framework.requires_agumentation
+        assert framework.requires_augmentation
         assert len(framework.get_callbacks_and_ready_for_train()) == 0

     # test valid combinations
@@ -187,7 +187,7 @@ def test_configure_bnb_plugin():
     ):
         # check flags and callbacks
         assert framework.requires_custom_loading
-        assert framework.requires_agumentation
+        assert framework.requires_augmentation
         assert len(framework.get_callbacks_and_ready_for_train()) == 0

     # test no_peft_model is true skips plugin.augmentation
@@ -202,7 +202,7 @@ def test_configure_bnb_plugin():
         require_packages_check=False,
     ):
         # check flags and callbacks
-        assert (not correct_value) == framework.requires_agumentation
+        assert (not correct_value) == framework.requires_augmentation

     # attempt to activate plugin with configuration pointing to wrong path
     # - raise with message that no plugins can be configured
@@ -61,7 +61,7 @@ def __init__(
         assert self._pad_token_id is not None, "need to get pad token id"

     @property
-    def requires_agumentation(self):
+    def requires_augmentation(self):
         return True

     def augmentation(
@@ -41,7 +41,7 @@ def __init__(self, configurations: Dict[str, Dict]):
         )

     @property
-    def requires_agumentation(self):
+    def requires_augmentation(self):
         return True

     def augmentation(
2 changes: 1 addition & 1 deletion plugins/framework/README.md
@@ -45,7 +45,7 @@ model, (peft_config,) = framework.augmentation(
 )
 ```

-We also provide `framework.requires_agumentation` to check if augmentation is required by the plugins.
+We also provide `framework.requires_augmentation` to check if augmentation is required by the plugins.

 Finally pass the model to train:

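
Following the README snippet above, a short usage sketch of the renamed check (assuming `framework`, `model`, `train_args`, and `peft_config` are set up as in the surrounding README):

```python
# Only run augmentation when at least one active plugin requires it.
if framework.requires_augmentation:
    model, (peft_config,) = framework.augmentation(
        model, train_args, modifiable_args=(peft_config,)
    )
```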
8 changes: 4 additions & 4 deletions plugins/framework/src/fms_acceleration/framework.py
@@ -199,10 +199,10 @@ def augmentation(
                 x in model_archs for x in plugin.restricted_model_archs
             ):
                 raise ValueError(
-                    f"Model architectures in '{model_archs}' are supported for '{plugin_name}'."
+                    f"Model architectures in '{model_archs}' are not supported for '{plugin_name}'."
                 )

-            if plugin.requires_agumentation:
+            if plugin.requires_augmentation:
                 model, modifiable_args = plugin.augmentation(
                     model, train_args, modifiable_args=modifiable_args
                 )
@@ -214,8 +214,8 @@ def requires_custom_loading(self):
         return len(self.plugins_require_custom_loading) > 0

     @property
-    def requires_agumentation(self):
-        return any(x.requires_agumentation for _, x in self.active_plugins)
+    def requires_augmentation(self):
+        return any(x.requires_augmentation for _, x in self.active_plugins)

     def get_callbacks_and_ready_for_train(
         self, model: torch.nn.Module = None, accelerator: Accelerator = None
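
The framework-level property above simply ORs the flags of its active plugins. A toy sketch of that aggregate behavior, using hypothetical stand-in plugin objects:

```python
class _NoopPlugin:
    requires_augmentation = False

class _PeftPlugin:
    requires_augmentation = True

# active_plugins holds (name, plugin) pairs, as in the diff above;
# the aggregate is true if any single plugin requires augmentation.
active_plugins = [("noop", _NoopPlugin()), ("peft", _PeftPlugin())]
assert any(x.requires_augmentation for _, x in active_plugins)
```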
2 changes: 1 addition & 1 deletion plugins/framework/src/fms_acceleration/framework_plugin.py
@@ -171,7 +171,7 @@ def requires_custom_loading(self):
         return False

     @property
-    def requires_agumentation(self):
+    def requires_augmentation(self):
         return False

     def model_loader(self, model_name: str, **kwargs):
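
The base plugin class defaults the property to `False`, so each plugin opts in by overriding it, as the plugins elsewhere in this commit do. A minimal sketch (the `AccelerationPlugin` class name and import path are assumptions based on the file path shown above):

```python
# Assumed import, based on plugins/framework/src/fms_acceleration/framework_plugin.py
from fms_acceleration.framework_plugin import AccelerationPlugin


class MyPlugin(AccelerationPlugin):
    @property
    def requires_augmentation(self):
        # opt in: ask the framework to call this plugin's augmentation hook
        return True
```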
10 changes: 5 additions & 5 deletions plugins/framework/src/fms_acceleration/utils/test_utils.py
@@ -159,8 +159,8 @@ def create_plugin_cls(
     restricted_models: Set = None,
     require_pkgs: Set = None,
     requires_custom_loading: bool = False,
-    requires_agumentation: bool = False,
-    agumentation: Callable = None,
+    requires_augmentation: bool = False,
+    augmentation: Callable = None,
     model_loader: Callable = None,
 ):
     "helper function to create plugin class"
@@ -174,11 +174,11 @@ def create_plugin_cls(
         "restricted_model_archs": restricted_models,
         "require_packages": require_pkgs,
         "requires_custom_loading": requires_custom_loading,
-        "requires_agumentation": requires_agumentation,
+        "requires_augmentation": requires_augmentation,
     }

-    if agumentation is not None:
-        attributes["augmentation"] = agumentation
+    if augmentation is not None:
+        attributes["augmentation"] = augmentation

     if model_loader is not None:
         attributes["model_loader"] = model_loader
34 changes: 17 additions & 17 deletions plugins/framework/tests/test_framework.py
@@ -68,7 +68,7 @@ def test_model_with_no_config_raises():

     # create model and (incomplete) plugin with requires_augmentation = True
     model_no_config = torch.nn.Module()  # empty model
-    incomplete_plugin = create_plugin_cls(requires_agumentation=True)
+    incomplete_plugin = create_plugin_cls(requires_augmentation=True)

     # register and activate 1 incomplete plugin, and:
     # 1. test correct plugin registration and activation.
@@ -104,13 +104,13 @@ def test_single_plugin():
     empty_plugin = create_plugin_cls()
     incomplete_plugin = create_plugin_cls(
         restricted_models={"CausalLM"},
-        requires_agumentation=True,
+        requires_augmentation=True,
     )
     plugin = create_plugin_cls(
         restricted_models={"CausalLM"},
-        requires_agumentation=True,
+        requires_augmentation=True,
         requires_custom_loading=True,
-        agumentation=dummy_augmentation,
+        augmentation=dummy_augmentation,
         model_loader=dummy_custom_loader,
     )
     train_args = None  # dummy for now
@@ -175,32 +175,32 @@ def test_two_plugins():

     model = create_noop_model_with_archs(archs=["CausalLM"])
     incomp_plugin1 = create_plugin_cls(
-        restricted_models={"CausalLM"}, requires_agumentation=True
+        restricted_models={"CausalLM"}, requires_augmentation=True
     )
-    incomp_plugin2 = create_plugin_cls(requires_agumentation=True)
+    incomp_plugin2 = create_plugin_cls(requires_augmentation=True)
     incomp_plugin3 = create_plugin_cls(
-        class_name="PluginNoop2", requires_agumentation=True
+        class_name="PluginNoop2", requires_augmentation=True
     )
     plugin1 = create_plugin_cls(
         restricted_models={"CausalLM"},
-        requires_agumentation=True,
+        requires_augmentation=True,
         requires_custom_loading=True,
-        agumentation=dummy_augmentation,
+        augmentation=dummy_augmentation,
         model_loader=dummy_custom_loader,
     )
     plugin2 = create_plugin_cls(
         class_name="PluginNoop2",
         restricted_models={"CausalLM"},
-        requires_agumentation=True,
+        requires_augmentation=True,
         requires_custom_loading=True,
-        agumentation=dummy_augmentation,
+        augmentation=dummy_augmentation,
         model_loader=dummy_custom_loader,
     )
     plugin3_no_loader = create_plugin_cls(
         class_name="PluginNoop2",
         restricted_models={"CausalLM"},
-        requires_agumentation=True,
-        agumentation=dummy_augmentation,
+        requires_augmentation=True,
+        augmentation=dummy_augmentation,
     )
     train_args = None  # dummy for now

@@ -299,8 +299,8 @@ def _hook(
     for class_name in ["PluginDEF", "PluginABC"]:
         plugin = create_plugin_cls(
             class_name=class_name,
-            requires_agumentation=True,
-            agumentation=hook_builder(act_order=plugin_activation_order),
+            requires_augmentation=True,
+            augmentation=hook_builder(act_order=plugin_activation_order),
         )
         plugins_to_be_installed.append((class_name, plugin))

@@ -319,8 +319,8 @@ def test_plugin_registration_combination_logic():

     plugin = create_plugin_cls(
         restricted_models={"CausalLM"},
-        requires_agumentation=True,
-        agumentation=dummy_augmentation,
+        requires_augmentation=True,
+        augmentation=dummy_augmentation,
     )

     configuration_contents = {"existing1": {"key1": 1}, "existing2": {"key1": 1}}
@@ -128,7 +128,7 @@ def __init__(self, configurations: Dict[str, Dict]):
         )

     @property
-    def requires_agumentation(self):
+    def requires_augmentation(self):
         return True

     def augmentation(
2 changes: 1 addition & 1 deletion plugins/fused-ops-and-kernels/tests/test_foak_plugins.py
@@ -47,7 +47,7 @@ def test_configure_gptq_foak_plugin():

         # check flags and callbacks
         assert framework.requires_custom_loading is False
-        assert framework.requires_agumentation
+        assert framework.requires_augmentation
         assert len(framework.get_callbacks_and_ready_for_train()) == 0

     # attempt to activate plugin with configuration pointing to wrong path
