diff --git a/.github/workflows/format.yml b/.github/workflows/format.yml index 99fb1c6e..3c4bdc0c 100644 --- a/.github/workflows/format.yml +++ b/.github/workflows/format.yml @@ -25,9 +25,9 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - framework_plugin: + plugin_name: - "framework" - - "accelerated-peft" + # - "accelerated-peft" # enable later steps: - uses: actions/checkout@v4 @@ -39,7 +39,11 @@ jobs: run: | python -m pip install --upgrade pip python -m pip install tox - - name: Change to plugin directory - run: cd plugins/${{ matrix.framework_plugin }} - - name: Run linter and formatter - run: tox -e lint + - name: Run linter + run: | + cd plugins/${{ matrix.plugin_name }} + tox -e lint + - name: Run formatter + run: | + cd plugins/${{ matrix.plugin_name }} + tox -e fmt diff --git a/plugins/accelerated-peft/tox.ini b/plugins/accelerated-peft/tox.ini index 6460cdbc..b79d0691 100644 --- a/plugins/accelerated-peft/tox.ini +++ b/plugins/accelerated-peft/tox.ini @@ -18,6 +18,13 @@ commands = [testenv:lint] description = run linters +deps = + pylint>=2.16.2,<=3.1.0 +commands = pylint src tests +allowlist_externals = pylint + +[testenv:fmt] +description = format skip_install = true deps = black>=22.12 @@ -26,6 +33,7 @@ commands = black {posargs:.} isort {posargs:.} + # [testenv:build] # description = build wheel # deps = diff --git a/plugins/framework/.pylintrc b/plugins/framework/.pylintrc new file mode 100644 index 00000000..45da4212 --- /dev/null +++ b/plugins/framework/.pylintrc @@ -0,0 +1,649 @@ +[MAIN] + +# Analyse import fallback blocks. This can be used to support both Python 2 and +# 3 compatible code, which means that the block might have code that exists +# only in one or another interpreter, leading to false positives when analysed. +analyse-fallback-blocks=no + +# Clear in-memory caches upon conclusion of linting. Useful if running pylint +# in a server-like mode. +clear-cache-post-run=no + +# Load and enable all available extensions. Use --list-extensions to see a list +# all available extensions. +#enable-all-extensions= + +# In error mode, messages with a category besides ERROR or FATAL are +# suppressed, and no reports are done by default. Error mode is compatible with +# disabling specific errors. +#errors-only= + +# Always return a 0 (non-error) status code, even if lint errors are found. +# This is primarily useful in continuous integration scripts. +#exit-zero= + +# A comma-separated list of package or module names from where C extensions may +# be loaded. Extensions are loading into the active Python interpreter and may +# run arbitrary code. +extension-pkg-allow-list= + +# A comma-separated list of package or module names from where C extensions may +# be loaded. Extensions are loading into the active Python interpreter and may +# run arbitrary code. (This is an alternative name to extension-pkg-allow-list +# for backward compatibility.) +extension-pkg-whitelist= + +# Return non-zero exit code if any of these messages/categories are detected, +# even if score is above --fail-under value. Syntax same as enable. Messages +# specified are enabled, while categories only check already-enabled messages. +fail-on= + +# Specify a score threshold under which the program will exit with error. +fail-under=10 + +# Interpret the stdin as a python script, whose filename needs to be passed as +# the module_or_package argument. +#from-stdin= + +# Files or directories to be skipped. They should be base names, not paths. +ignore=CVS,protobufs + +# Add files or directories matching the regular expressions patterns to the +# ignore-list. The regex matches against paths and can be in Posix or Windows +# format. Because '\\' represents the directory delimiter on Windows systems, +# it can't be used as an escape character. +ignore-paths= + +# Files or directories matching the regular expression patterns are skipped. +# The regex matches against base names, not paths. The default value ignores +# Emacs file locks +ignore-patterns=^\.# + +# List of module names for which member attributes should not be checked +# (useful for modules/projects where namespaces are manipulated during runtime +# and thus existing member attributes cannot be deduced by static analysis). It +# supports qualified module names, as well as Unix pattern matching. +ignored-modules= + +# Python code to execute, usually for sys.path manipulation such as +# pygtk.require(). +#init-hook= + +# Use multiple processes to speed up Pylint. Specifying 0 will auto-detect the +# number of processors available to use, and will cap the count on Windows to +# avoid hangs. +jobs=1 + +# Control the amount of potential inferred values when inferring a single +# object. This can help the performance when dealing with large functions or +# complex, nested conditions. +limit-inference-results=100 + +# List of plugins (as comma separated values of python module names) to load, +# usually to register additional checkers. +load-plugins= + +# Pickle collected data for later comparisons. +persistent=yes + +# Minimum Python version to use for version dependent checks. Will default to +# the version used to run pylint. +py-version=3.9 + +# Discover python modules and packages in the file system subtree. +recursive=no + +# When enabled, pylint would attempt to guess common misconfiguration and emit +# user-friendly hints instead of false-positive error messages. +suggestion-mode=yes + +# Allow loading of arbitrary C extensions. Extensions are imported into the +# active Python interpreter and may run arbitrary code. +unsafe-load-any-extension=no + +# In verbose mode, extra non-checker-related info will be displayed. +#verbose= + + +[BASIC] + +# Naming style matching correct argument names. +argument-naming-style=snake_case + +# Regular expression matching correct argument names. Overrides argument- +# naming-style. If left empty, argument names will be checked with the set +# naming style. +#argument-rgx= + +# Naming style matching correct attribute names. +attr-naming-style=snake_case + +# Regular expression matching correct attribute names. Overrides attr-naming- +# style. If left empty, attribute names will be checked with the set naming +# style. +#attr-rgx= + +# Bad variable names which should always be refused, separated by a comma. +bad-names=foo, + bar, + baz, + toto, + tutu, + tata + +# Bad variable names regexes, separated by a comma. If names match any regex, +# they will always be refused +bad-names-rgxs= + +# Naming style matching correct class attribute names. +class-attribute-naming-style=any + +# Regular expression matching correct class attribute names. Overrides class- +# attribute-naming-style. If left empty, class attribute names will be checked +# with the set naming style. +#class-attribute-rgx= + +# Naming style matching correct class constant names. +class-const-naming-style=UPPER_CASE + +# Regular expression matching correct class constant names. Overrides class- +# const-naming-style. If left empty, class constant names will be checked with +# the set naming style. +#class-const-rgx= + +# Naming style matching correct class names. +class-naming-style=PascalCase + +# Regular expression matching correct class names. Overrides class-naming- +# style. If left empty, class names will be checked with the set naming style. +#class-rgx= + +# Naming style matching correct constant names. +const-naming-style=UPPER_CASE + +# Regular expression matching correct constant names. Overrides const-naming- +# style. If left empty, constant names will be checked with the set naming +# style. +#const-rgx= + +# Minimum line length for functions/classes that require docstrings, shorter +# ones are exempt. +docstring-min-length=-1 + +# Naming style matching correct function names. +function-naming-style=snake_case + +# Regular expression matching correct function names. Overrides function- +# naming-style. If left empty, function names will be checked with the set +# naming style. +#function-rgx= + +# Good variable names which should always be accepted, separated by a comma. +good-names=i, + j, + k, + ex, + Run, + _ + +# Good variable names regexes, separated by a comma. If names match any regex, +# they will always be accepted +good-names-rgxs= + +# Include a hint for the correct naming format with invalid-name. +include-naming-hint=no + +# Naming style matching correct inline iteration names. +inlinevar-naming-style=any + +# Regular expression matching correct inline iteration names. Overrides +# inlinevar-naming-style. If left empty, inline iteration names will be checked +# with the set naming style. +#inlinevar-rgx= + +# Naming style matching correct method names. +method-naming-style=snake_case + +# Regular expression matching correct method names. Overrides method-naming- +# style. If left empty, method names will be checked with the set naming style. +#method-rgx= + +# Naming style matching correct module names. +module-naming-style=snake_case + +# Regular expression matching correct module names. Overrides module-naming- +# style. If left empty, module names will be checked with the set naming style. +#module-rgx= + +# Colon-delimited sets of names that determine each other's naming style when +# the name regexes allow several styles. +name-group= + +# Regular expression which should only match function or class names that do +# not require a docstring. +no-docstring-rgx=^_ + +# List of decorators that produce properties, such as abc.abstractproperty. Add +# to this list to register other decorators that produce valid properties. +# These decorators are taken in consideration only for invalid-name. +property-classes=abc.abstractproperty + +# Regular expression matching correct type variable names. If left empty, type +# variable names will be checked with the set naming style. +#typevar-rgx= + +# Naming style matching correct variable names. +variable-naming-style=snake_case + +# Regular expression matching correct variable names. Overrides variable- +# naming-style. If left empty, variable names will be checked with the set +# naming style. +#variable-rgx= + + +[CLASSES] + +# Warn about protected attribute access inside special methods +check-protected-access-in-special-methods=no + +# List of method names used to declare (i.e. assign) instance attributes. +defining-attr-methods=__init__, + __new__, + setUp, + __post_init__ + +# List of member names, which should be excluded from the protected access +# warning. +exclude-protected=_asdict, + _fields, + _replace, + _source, + _make + +# List of valid names for the first argument in a class method. +valid-classmethod-first-arg=cls + +# List of valid names for the first argument in a metaclass class method. +valid-metaclass-classmethod-first-arg=mcs + + +[DESIGN] + +# List of regular expressions of class ancestor names to ignore when counting +# public methods (see R0903) +exclude-too-few-public-methods= + +# List of qualified class names to ignore when counting class parents (see +# R0901) +ignored-parents= + +# Maximum number of arguments for function / method. +max-args=5 + +# Maximum number of attributes for a class (see R0902). +max-attributes=7 + +# Maximum number of boolean expressions in an if statement (see R0916). +max-bool-expr=5 + +# Maximum number of branch for function / method body. +max-branches=12 + +# Maximum number of locals for function / method body. +max-locals=15 + +# Maximum number of parents for a class (see R0901). +max-parents=7 + +# Maximum number of public methods for a class (see R0904). +max-public-methods=20 + +# Maximum number of return / yield for function / method body. +max-returns=6 + +# Maximum number of statements in function / method body. +max-statements=50 + +# Minimum number of public methods for a class (see R0903). +min-public-methods=2 + + +[EXCEPTIONS] + +# Exceptions that will emit a warning when caught. +overgeneral-exceptions=builtins.BaseException,builtins.Exception + + +[FORMAT] + +# Expected format of line ending, e.g. empty (any line ending), LF or CRLF. +expected-line-ending-format= + +# Regexp for a line that is allowed to be longer than the limit. +ignore-long-lines=^\s*(# )??$ + +# Number of spaces of indent required inside a hanging or continued line. +indent-after-paren=4 + +# String used as indentation unit. This is usually " " (4 spaces) or "\t" (1 +# tab). +indent-string=' ' + +# Maximum number of characters on a single line. +max-line-length=100 + +# Maximum number of lines in a module. +max-module-lines=1100 + +# Allow the body of a class to be on the same line as the declaration if body +# contains single statement. +single-line-class-stmt=no + +# Allow the body of an if to be on the same line as the test if there is no +# else. +single-line-if-stmt=no + + +[IMPORTS] + +# List of modules that can be imported at any level, not just the top level +# one. +allow-any-import-level= + +# Allow explicit reexports by alias from a package __init__. +allow-reexport-from-package=no + +# Allow wildcard imports from modules that define __all__. +allow-wildcard-with-all=no + +# Deprecated modules which should not be used, separated by a comma. +deprecated-modules= + +# Output a graph (.gv or any supported image format) of external dependencies +# to the given file (report RP0402 must not be disabled). +ext-import-graph= + +# Output a graph (.gv or any supported image format) of all (i.e. internal and +# external) dependencies to the given file (report RP0402 must not be +# disabled). +import-graph= + +# Output a graph (.gv or any supported image format) of internal dependencies +# to the given file (report RP0402 must not be disabled). +int-import-graph= + +# Force import order to recognize a module as part of the standard +# compatibility libraries. +known-standard-library= + +# Force import order to recognize a module as part of a third party library. +known-third-party=enchant + +# Couples of modules and preferred modules, separated by a comma. +preferred-modules= + + +[LOGGING] + +# The type of string formatting that logging methods do. `old` means using % +# formatting, `new` is for `{}` formatting. +logging-format-style=old + +# Logging modules to check that the string format arguments are in logging +# function parameter format. +logging-modules=logging + + +[MESSAGES CONTROL] + +# Only show warnings with the listed confidence levels. Leave empty to show +# all. Valid levels: HIGH, CONTROL_FLOW, INFERENCE, INFERENCE_FAILURE, +# UNDEFINED. +confidence=HIGH, + CONTROL_FLOW, + INFERENCE, + INFERENCE_FAILURE, + UNDEFINED + +# Disable the message, report, category or checker with the given id(s). You +# can either give multiple identifiers separated by comma (,) or put this +# option multiple times (only on the command line, not in the configuration +# file where it should appear only once). You can also use "--disable=all" to +# disable everything first and then re-enable specific checks. For example, if +# you want to run only the similarities checker, you can use "--disable=all +# --enable=similarities". If you want to run only the classes checker, but have +# no Warning level messages displayed, use "--disable=all --enable=classes +# --disable=W". +disable=raw-checker-failed, + bad-inline-option, + locally-disabled, + file-ignored, + suppressed-message, + useless-suppression, + deprecated-pragma, + # Added messages + use-symbolic-message-instead, + invalid-name, + missing-class-docstring, + missing-module-docstring, + missing-function-docstring, + consider-using-f-string, + inconsistent-return-statements, + no-member, + too-many-arguments, + too-many-locals, + too-many-branches, + too-many-statements, + cyclic-import, + too-few-public-methods, + protected-access, + fixme, + logging-format-interpolation, + logging-too-many-args, + attribute-defined-outside-init, + abstract-method, + pointless-statement, + wrong-import-order, + duplicate-code, + unbalanced-tuple-unpacking, + unused-argument + +# Enable the message, report, category or checker with the given id(s). You can +# either give multiple identifier separated by comma (,) or put this option +# multiple time (only on the command line, not in the configuration file where +# it should appear only once). See also the "--disable" option for examples. +enable=c-extension-no-member + + +[METHOD_ARGS] + +# List of qualified names (i.e., library.method) which require a timeout +# parameter e.g. 'requests.api.get,requests.api.post' +timeout-methods=requests.api.delete,requests.api.get,requests.api.head,requests.api.options,requests.api.patch,requests.api.post,requests.api.put,requests.api.request + + +[MISCELLANEOUS] + +# List of note tags to take in consideration, separated by a comma. +notes=FIXME, + XXX, + TODO + +# Regular expression of note tags to take in consideration. +notes-rgx= + + +[REFACTORING] + +# Maximum number of nested blocks for function / method body +max-nested-blocks=5 + +# Complete name of functions that never returns. When checking for +# inconsistent-return-statements if a never returning function is called then +# it will be considered as an explicit return statement and no message will be +# printed. +never-returning-functions=sys.exit,argparse.parse_error + + +[REPORTS] + +# Python expression which should return a score less than or equal to 10. You +# have access to the variables 'fatal', 'error', 'warning', 'refactor', +# 'convention', and 'info' which contain the number of messages in each +# category, as well as 'statement' which is the total number of statements +# analyzed. This score is used by the global evaluation report (RP0004). +evaluation=max(0, 0 if fatal else 10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10)) + +# Template used to display messages. This is a python new-style format string +# used to format the message information. See doc for all details. +msg-template= + +# Set the output format. Available formats are text, parseable, colorized, json +# and msvs (visual studio). You can also give a reporter class, e.g. +# mypackage.mymodule.MyReporterClass. +output-format=text + +# Tells whether to display a full report or only the messages. +reports=yes + +# Activate the evaluation score. +score=yes + + +[SIMILARITIES] + +# Comments are removed from the similarity computation +ignore-comments=yes + +# Docstrings are removed from the similarity computation +ignore-docstrings=yes + +# Imports are removed from the similarity computation +ignore-imports=yes + +# Signatures are removed from the similarity computation +ignore-signatures=yes + +# Minimum lines number of a similarity. +min-similarity-lines=4 + + +[SPELLING] + +# Limits count of emitted suggestions for spelling mistakes. +max-spelling-suggestions=4 + +# Spelling dictionary name. Available dictionaries: none. To make it work, +# install the 'python-enchant' package. +spelling-dict= + +# List of comma separated words that should be considered directives if they +# appear at the beginning of a comment and should not be checked. +spelling-ignore-comment-directives=fmt: on,fmt: off,noqa:,noqa,nosec,isort:skip,mypy: + +# List of comma separated words that should not be checked. +spelling-ignore-words= + +# A path to a file that contains the private dictionary; one word per line. +spelling-private-dict-file= + +# Tells whether to store unknown words to the private dictionary (see the +# --spelling-private-dict-file option) instead of raising a message. +spelling-store-unknown-words=no + + +[STRING] + +# This flag controls whether inconsistent-quotes generates a warning when the +# character used as a quote delimiter is used inconsistently within a module. +check-quote-consistency=no + +# This flag controls whether the implicit-str-concat should generate a warning +# on implicit string concatenation in sequences defined over several lines. +check-str-concat-over-line-jumps=no + + +[TYPECHECK] + +# List of decorators that produce context managers, such as +# contextlib.contextmanager. Add to this list to register other decorators that +# produce valid context managers. +contextmanager-decorators=contextlib.contextmanager + +# List of members which are set dynamically and missed by pylint inference +# system, and so shouldn't trigger E1101 when accessed. Python regular +# expressions are accepted. +generated-members= + +# Tells whether to warn about missing members when the owner of the attribute +# is inferred to be None. +ignore-none=yes + +# This flag controls whether pylint should warn about no-member and similar +# checks whenever an opaque object is returned when inferring. The inference +# can return multiple potential results while evaluating a Python object, but +# some branches might not be evaluated, which results in partial inference. In +# that case, it might be useful to still emit no-member and other checks for +# the rest of the inferred objects. +ignore-on-opaque-inference=yes + +# List of symbolic message names to ignore for Mixin members. +ignored-checks-for-mixins=no-member, + not-async-context-manager, + not-context-manager, + attribute-defined-outside-init + +# List of class names for which member attributes should not be checked (useful +# for classes with dynamically set attributes). This supports the use of +# qualified names. +ignored-classes=optparse.Values,thread._local,_thread._local,argparse.Namespace + +# Show a hint with possible names when a member name was not found. The aspect +# of finding the hint is based on edit distance. +missing-member-hint=yes + +# The minimum edit distance a name should have in order to be considered a +# similar match for a missing member name. +missing-member-hint-distance=1 + +# The total number of similar names that should be taken in consideration when +# showing a hint for a missing member. +missing-member-max-choices=1 + +# Regex pattern to define which classes are considered mixins. +mixin-class-rgx=.*[Mm]ixin + +# List of decorators that change the signature of a decorated function. +signature-mutators= + + +[VARIABLES] + +# List of additional names supposed to be defined in builtins. Remember that +# you should avoid defining new builtins when possible. +additional-builtins= + +# Tells whether unused global variables should be treated as a violation. +allow-global-unused-variables=yes + +# List of names allowed to shadow builtins +allowed-redefined-builtins= + +# List of strings which can identify a callback function by name. A callback +# name must start or end with one of those strings. +callbacks=cb_, + _cb + +# A regular expression matching the name of dummy variables (i.e. expected to +# not be used). +dummy-variables-rgx=_+$|(_[a-zA-Z0-9_]*[a-zA-Z0-9]+?$)|dummy|^ignored_|^unused_ + +# Argument names that match this expression will be ignored. +ignored-argument-names=_.*|^ignored_|^unused_ + +# Tells whether we should check for unused import in __init__ files. +init-import=no + +# List of qualified module names which can have objects that can redefine +# builtins. +redefining-builtins-modules=six.moves,past.builtins,future.builtins,builtins,io diff --git a/plugins/framework/src/fms_acceleration/__init__.py b/plugins/framework/src/fms_acceleration/__init__.py index c396c568..e39cd055 100644 --- a/plugins/framework/src/fms_acceleration/__init__.py +++ b/plugins/framework/src/fms_acceleration/__init__.py @@ -12,19 +12,19 @@ # See the License for the specific language governing permissions and # limitations under the License. +# Standard +# use importlib to load the packages, if they are installed +import importlib + # Local -from .framework import AccelerationFramework from .constants import PLUGIN_PREFIX, PLUGINS +from .framework import AccelerationFramework from .framework_plugin import ( AccelerationPlugin, AccelerationPluginConfigError, get_relevant_configuration_sections, ) -# Standard -# use importlib to load the packages, if they are installed -import importlib - for postfix in PLUGINS: plugin_name = f"{PLUGIN_PREFIX}{postfix}" if importlib.util.find_spec(plugin_name): diff --git a/plugins/framework/src/fms_acceleration/cli.py b/plugins/framework/src/fms_acceleration/cli.py index a29fecef..04ad263a 100644 --- a/plugins/framework/src/fms_acceleration/cli.py +++ b/plugins/framework/src/fms_acceleration/cli.py @@ -13,22 +13,24 @@ # limitations under the License. -import argparse +# Standard +from typing import List, Union import os -import sys import subprocess -from typing import List, Union -import yaml - -from .constants import PLUGIN_PREFIX, PLUGINS +import sys +# Third Party from pip._internal.cli.main import main as pipmain - from transformers.utils.import_utils import _is_package_available +import yaml + +# Local +from .constants import PLUGIN_PREFIX, PLUGINS GITHUB_URL = "github.com/foundation-model-stack/fms-acceleration.git" -REPO_CACHE_DIR = '.fms/repository' +REPO_CACHE_DIR = ".fms/repository" + # TODO: make a version that fetches the def install_plugin( @@ -36,26 +38,29 @@ def install_plugin( ): "function to install plugin. Inputs should contain a pkg_name." - pkg_name = [x for x in args if not x.startswith('-')] - assert len(pkg_name) == 1,\ - "Please specify exactly one plugin to install" + pkg_name = [x for x in args if not x.startswith("-")] + assert len(pkg_name) == 1, "Please specify exactly one plugin to install" pkg_name = pkg_name[0] # take the flags - args = [x for x in args if x.startswith('-')] + args = [x for x in args if x.startswith("-")] if os.path.exists(pkg_name): - pipmain(['install', *args, pkg_name]) - return + pipmain(["install", *args, pkg_name]) + return if pkg_name.startswith(PLUGIN_PREFIX): pkg_name = pkg_name.replace(PLUGIN_PREFIX, "") # otherwise should be an internet install - pipmain([ - 'install', *args, - f'git+https://{GITHUB_URL}#subdirectory=plugins/accelerated-{pkg_name}' - ]) + pipmain( + [ + "install", + *args, + f"git+https://{GITHUB_URL}#subdirectory=plugins/accelerated-{pkg_name}", + ] + ) + def list_plugins(): print( @@ -75,13 +80,14 @@ def list_plugins(): print(f"{i+1}. {full_name} [{name}] {postfix}") + def get_benchmark_artifacts(dest_dir: str): if not os.path.exists(dest_dir): os.makedirs(dest_dir) - - if not os.path.exists(os.path.join(dest_dir, '.git')): - command = f"""cd {dest_dir} && git init && git remote add -f origin https://{GITHUB_URL} && \ + if not os.path.exists(os.path.join(dest_dir, ".git")): + command = f"""cd {dest_dir} && git init && \ + git remote add -f origin https://{GITHUB_URL} && \ git config --global init.defaultBranch main && \ git config core.sparsecheckout true && \ echo scripts/benchmarks >> .git/info/sparse-checkout && \ @@ -91,63 +97,73 @@ def get_benchmark_artifacts(dest_dir: str): command = f"cd {dest_dir} && git fetch origin && " command += "git pull origin main " - out = subprocess.run(command, shell=True, capture_output=True) + out = subprocess.run(command, shell=True, capture_output=True, check=False) if out.returncode != 0: - raise RuntimeError(f"could not get benchmark artifacts with error code {out.returncode}") - return out + raise RuntimeError( + f"could not get benchmark artifacts with error code {out.returncode}" + ) + return out + def list_sample_configs( - configs_dir: str, - contents_file: str = 'sample-configurations/CONTENTS.yaml', + configs_dir: str, + contents_file: str = "sample-configurations/CONTENTS.yaml", get_artifacts: bool = True, ): if get_artifacts: get_benchmark_artifacts(REPO_CACHE_DIR) - with open(os.path.join(configs_dir, contents_file)) as f: - for i, entry in enumerate(yaml.safe_load(f)['framework_configs']): - shortname = entry['shortname'] - plugins = entry['plugins'] - filename = entry['filename'] - print (f"{i+1}. {shortname} ({filename}) - plugins: {plugins}") + with open(os.path.join(configs_dir, contents_file), encoding="utf-8") as f: + for i, entry in enumerate(yaml.safe_load(f)["framework_configs"]): + shortname = entry["shortname"] + plugins = entry["plugins"] + filename = entry["filename"] + print(f"{i+1}. {shortname} ({filename}) - plugins: {plugins}") + def list_arguments( - scenario_dir: str, + scenario_dir: str, config_shortnames: Union[str, List[str]], - scenario_file: str = 'scripts/benchmarks/scenarios.yaml', - ignored_fields = ['model_name_or_path'], + scenario_file: str = "scripts/benchmarks/scenarios.yaml", + ignored_fields: List = None, get_artifacts: bool = True, ): + if ignored_fields is None: + ignored_fields = ["model_name_or_path"] + if get_artifacts: get_benchmark_artifacts(REPO_CACHE_DIR) if isinstance(config_shortnames, str): config_shortnames = [config_shortnames] - with open(os.path.join(scenario_dir, scenario_file)) as f: - scenarios = yaml.safe_load(f)['scenarios'] + with open(os.path.join(scenario_dir, scenario_file), encoding="utf-8") as f: + scenarios = yaml.safe_load(f)["scenarios"] found = 0 - print (f"Searching for configuration shortnames: {config_shortnames}") + print(f"Searching for configuration shortnames: {config_shortnames}") for scn in scenarios: - if 'framework_config' not in scn: + if "framework_config" not in scn: continue - hit_sn = [x for x in config_shortnames if x in scn['framework_config']] + hit_sn = [x for x in config_shortnames if x in scn["framework_config"]] if len(hit_sn) > 0: found += 1 - name = scn['name'] - arguments = scn['arguments'] + name = scn["name"] + arguments = scn["arguments"] hit_sn = ", ".join(hit_sn) - print (f"{found}. scenario: {name}\n configs: {hit_sn}\n arguments:") + print(f"{found}. scenario: {name}\n configs: {hit_sn}\n arguments:") lines = [] for key, val in arguments.items(): if key not in ignored_fields: lines.append(f" --{key} {val}") - - print (" \\\n".join(lines)) - print ("\n") + + print(" \\\n".join(lines)) + print("\n") if not found: - print(f"ERROR: Could not list arguments for configuration shortname '{config_shortnames}'") + print( + f"ERROR: Could not list arguments for configuration shortname '{config_shortnames}'" + ) + def cli(): # not using argparse since its so simple @@ -157,31 +173,30 @@ def cli(): ) argv = sys.argv if len(argv) == 1: - print (message) + print(message) return - else: + if len(argv) > 1: command = argv[1] if len(argv) > 2: variadic = sys.argv[2:] else: variadic = [] - if command == 'install': + if command == "install": assert len(variadic) >= 1, "Please provide the acceleration plugin name" install_plugin(*variadic) - elif command == 'plugins': + elif command == "plugins": assert len(variadic) == 0, "list does not require arguments" list_plugins() - elif command == 'configs': + elif command == "configs": assert len(variadic) == 0, "list-config does not require arguments" list_sample_configs(REPO_CACHE_DIR) - elif command == 'arguments': + elif command == "arguments": assert len(variadic) >= 1, "Please provide the config shortname" list_arguments(REPO_CACHE_DIR, *variadic) else: - raise NotImplementedError( - f"Unknown fms_acceleration.cli command '{command}'" - ) + raise NotImplementedError(f"Unknown fms_acceleration.cli command '{command}'") + -if __name__ == '__main__': - cli() \ No newline at end of file +if __name__ == "__main__": + cli() diff --git a/plugins/framework/src/fms_acceleration/framework.py b/plugins/framework/src/fms_acceleration/framework.py index 529b6bd5..6d545ac7 100644 --- a/plugins/framework/src/fms_acceleration/framework.py +++ b/plugins/framework/src/fms_acceleration/framework.py @@ -13,7 +13,7 @@ # limitations under the License. # Standard -from typing import Callable, Dict, List, Optional, Set, Tuple +from typing import Callable, List, Optional, Set, Tuple # Third Party from accelerate import Accelerator @@ -23,19 +23,20 @@ import torch import yaml -# want to use the transformers logger, but a bit of pain -logger = logging.get_logger(__name__) # pylint: disable=invalid-name -logger.setLevel(logging._get_default_logging_level()) -logger.addHandler(logging._default_handler) - -# First Party +# Local +from .constants import KEY_PLUGINS from .framework_plugin import ( PLUGIN_REGISTRATIONS, AccelerationPlugin, PluginRegistration, get_relevant_configuration_sections, ) -from .constants import KEY_PLUGINS + +# want to use the transformers logger, but a bit of pain +logger = logging.get_logger(__name__) # pylint: disable=invalid-name +logger.setLevel(logging._get_default_logging_level()) +logger.addHandler(logging._default_handler) + def check_plugin_packages(plugin: AccelerationPlugin): if plugin.require_packages is None: @@ -47,13 +48,14 @@ def check_plugin_packages(plugin: AccelerationPlugin): missing_packages.append(package_name) return len(missing_packages) == 0, missing_packages + def log_initialization_message( active_class_names: Set[str], registered_plugins: List[PluginRegistration], # list of regs - logger: Callable = None, + logging_func: Callable = None, ): - if logger is None: - logger = print + if logging_func is None: + logging_func = print def _registration_display(reg: PluginRegistration): return ( @@ -62,36 +64,33 @@ def _registration_display(reg: PluginRegistration): f"Version: {reg.package_version}." ) - logger("***** FMS AccelerationFramework *****") + logging_func("***** FMS AccelerationFramework *****") for reg in registered_plugins: if reg.plugin.__name__ in active_class_names: - logger(_registration_display(reg)) + logging_func(_registration_display(reg)) class AccelerationFramework: - - active_plugins: List[Tuple[str, AccelerationPlugin]] = list() - plugins_require_custom_loading: List = list() + active_plugins: List[Tuple[str, AccelerationPlugin]] = [] + plugins_require_custom_loading: List = [] def __init__( self, configuration_file: Optional[str], require_packages_check: bool = True ): - - with open(configuration_file, "r") as f: + with open(configuration_file, "r", encoding="utf-8") as f: contents = yaml.safe_load(f) if KEY_PLUGINS not in contents or contents[KEY_PLUGINS] is None: raise ValueError(f"Configuration file must contain a '{KEY_PLUGINS}' body") # pepare the plugin configurations - plugin_configs = {k: v for k, v in contents[KEY_PLUGINS].items()} + plugin_configs = dict(contents[KEY_PLUGINS].items()) # relevant sections are returned following plugin precedence, i.e., # they follow the registration order. for selected_configs, cls in get_relevant_configuration_sections( plugin_configs ): - # then the model is to be installed # get the plugin plugin_name = str(cls.__name__) @@ -108,7 +107,7 @@ def __init__( # check if already activated, if so, will not reactivate again # maintain uniqueness of activated plugins - if any([x == plugin_name for x, _ in self.active_plugins]): + if any(x == plugin_name for x, _ in self.active_plugins): continue # activate plugin @@ -123,15 +122,16 @@ def __init__( "framework configuration file." ) - assert ( - len(self.plugins_require_custom_loading) <= 1 - ), f"Can load at most 1 plugin with custom model loading, but tried to '{self.plugins_require_custom_loading}'." + assert len(self.plugins_require_custom_loading) <= 1, ( + "Can load at most 1 plugin with custom model loading, " + f"but tried to '{self.plugins_require_custom_loading}'." + ) def model_loader(self, model_name: str, **kwargs): - if len(self.plugins_require_custom_loading) == 0: raise NotImplementedError( - f"Attempted model loading, but none of activated plugins '{list(self.active_plugins)}' " + "Attempted model loading, but none " + f"of activated plugins '{list(self.active_plugins)}' " "require custom loading." ) @@ -152,10 +152,9 @@ def augmentation( # NOTE: this assumes that augmentation order does not matter for plugin_name, plugin in self.active_plugins: - # check the model arcs at augmentation if plugin.restricted_model_archs and not any( - [x in model_archs for x in plugin.restricted_model_archs] + x in model_archs for x in plugin.restricted_model_archs ): raise ValueError( f"Model architectures in '{model_archs}' are supported for '{plugin_name}'." @@ -174,16 +173,16 @@ def requires_custom_loading(self): @property def requires_agumentation(self): - return any([x.requires_agumentation for _, x in self.active_plugins]) + return any(x.requires_agumentation for _, x in self.active_plugins) def get_callbacks_and_ready_for_train( self, model: torch.nn.Module = None, accelerator: Accelerator = None ): # show the initialized message log_initialization_message( - set([x for x, _ in self.active_plugins]), + {x for x, _ in self.active_plugins}, PLUGIN_REGISTRATIONS, - logger=logger.info, + logging_func=logger.info, ) cbks = [] diff --git a/plugins/framework/src/fms_acceleration/framework_plugin.py b/plugins/framework/src/fms_acceleration/framework_plugin.py index 1d17a863..fc6da973 100644 --- a/plugins/framework/src/fms_acceleration/framework_plugin.py +++ b/plugins/framework/src/fms_acceleration/framework_plugin.py @@ -36,7 +36,7 @@ class PluginRegistration: package_version: str = None -PLUGIN_REGISTRATIONS: List[PluginRegistration] = list() +PLUGIN_REGISTRATIONS: List[PluginRegistration] = [] def _trace_key_path(configuration: Dict, key: str): @@ -85,7 +85,6 @@ def get_relevant_configuration_sections(configuration: Dict) -> Dict: class AccelerationPlugin: - # will be triggered if the configuration_paths are found in the # acceleration framework configuration file (under KEY_PLUGINS) @staticmethod @@ -94,13 +93,18 @@ def register_plugin( configuration_and_paths: List[str], **kwargs, ): - global PLUGIN_REGISTRATIONS + + # pylint: disable=trailing-whitespace + # removed because of src/fms_acceleration/framework_plugin.py:96:8: + # W0602: Using global for 'PLUGIN_REGISTRATIONS' but no assignment + # is done (global-variable-not-assigned) + # global PLUGIN_REGISTRATIONS # get the package metadata pkg_name = sys.modules[plugin.__module__].__package__ try: package_version = importlib.metadata.version(pkg_name) - except importlib.metadata.PackageNotFoundError: + except importlib.metadata.PackageNotFoundError: package_version = None PLUGIN_REGISTRATIONS.append( @@ -116,7 +120,6 @@ def register_plugin( require_packages: Optional[Set] = None def __init__(self, configurations: Dict[str, Dict]): - # will pass in a list of dictionaries keyed by "configuration_keys" # to be used for initialization self.configurations = configurations @@ -153,13 +156,15 @@ def _check_config_and_maybe_check_values(self, key: str, values: List[Any] = Non # if the tree is a dict if len(t.keys()) > 1: raise AccelerationPluginConfigError( - f"{self.__class__.__name__}: '{key}' found but amongst multiple '{t.keys()}' exist. Ambiguous check in expected set '{values}'." + f"{self.__class__.__name__}: '{key}' found but amongst multiple " + "'{t.keys()}' exist. Ambiguous check in expected set '{values}'." ) t = list(t.keys())[0] # otherwise take the first value if t not in values: raise AccelerationPluginConfigError( - f"{self.__class__.__name__}: Value at '{key}' was '{t}'. Not found in expected set '{values}'." + f"{self.__class__.__name__}: Value at '{key}' was '{t}'. " + "Not found in expected set '{values}'." ) else: # if nothing to check against, we still want to ensure its a valid diff --git a/plugins/framework/src/fms_acceleration/utils/test_utils.py b/plugins/framework/src/fms_acceleration/utils/test_utils.py index aa796707..3cc4004f 100644 --- a/plugins/framework/src/fms_acceleration/utils/test_utils.py +++ b/plugins/framework/src/fms_acceleration/utils/test_utils.py @@ -51,7 +51,7 @@ def update_configuration_contents( def read_configuration(path: str) -> Dict: "helper function to read yaml config into json" - with open(path) as f: + with open(path, encoding="utf-8") as f: return yaml.safe_load(f) @@ -69,13 +69,16 @@ def build_framework_and_maybe_instantiate( plugins_to_be_registered: List[ Tuple[List[str], Type[AccelerationPlugin]] # and_paths, plugin_class ], - configuration_contents: Dict = {}, + configuration_contents: Dict = None, instantiate: bool = True, reset_registrations: bool = True, require_packages_check: bool = True, ): "helper function to register plugins and instantiate an acceleration framework for testing" + if configuration_contents is None: + configuration_contents = {} + # empty out if reset_registrations: old_registrations = [] @@ -93,7 +96,9 @@ def build_framework_and_maybe_instantiate( ) if instantiate: - yield configure_framework_from_json(configuration_contents, require_packages_check) + yield configure_framework_from_json( + configuration_contents, require_packages_check + ) else: yield @@ -104,9 +109,11 @@ def build_framework_and_maybe_instantiate( AccelerationFramework.active_plugins = old_active_plugins AccelerationFramework.plugins_require_custom_loading = old_custom_loading_plugins -# alias because default instantiate=True + +# alias because default instantiate=True build_framework_and_instantiate = build_framework_and_maybe_instantiate + def instantiate_framework( configuration_contents: Dict, require_packages_check: bool = True, @@ -122,8 +129,12 @@ def instantiate_framework( ) -def create_noop_model_with_archs(class_name: str = "ModelNoop", archs: List[str] = []): +def create_noop_model_with_archs( + class_name: str = "ModelNoop", archs: List[str] = None +): "helper function to create a dummy model with mocked architectures" + if archs is None: + archs = [] config = type("Config", (object,), {"architectures": archs}) return type(class_name, (torch.nn.Module,), {"config": config}) @@ -131,8 +142,8 @@ def create_noop_model_with_archs(class_name: str = "ModelNoop", archs: List[str] def create_plugin_cls( class_name: str = "PluginNoop", - restricted_models: Set = {}, - require_pkgs: Set = {}, + restricted_models: Set = None, + require_pkgs: Set = None, requires_custom_loading: bool = False, requires_agumentation: bool = False, agumentation: Callable = None, @@ -140,6 +151,11 @@ def create_plugin_cls( ): "helper function to create plugin class" + if restricted_models is None: + restricted_models = set() + if require_pkgs is None: + require_pkgs = set() + attributes = { "restricted_model_archs": restricted_models, "require_packages": require_pkgs, diff --git a/plugins/framework/tests/test_framework.py b/plugins/framework/tests/test_framework.py index eff1600a..de7abd41 100644 --- a/plugins/framework/tests/test_framework.py +++ b/plugins/framework/tests/test_framework.py @@ -15,103 +15,19 @@ # SPDX-License-Identifier: Apache-2.0 # https://spdx.dev/learn/handling-license-info/ -# Standard -from contextlib import contextmanager -from tempfile import NamedTemporaryFile -from typing import Callable, Dict, List, Set, Tuple, Type - # Third Party -import pytest +import pytest # pylint: disable=(import-error import torch -import yaml # First Party -from fms_acceleration.framework import KEY_PLUGINS, AccelerationFramework -from fms_acceleration.framework_plugin import PLUGIN_REGISTRATIONS, AccelerationPlugin - -# ----------------------------- HELPER ------------------------------------- - - -@contextmanager -def build_framework_and_instantiate( - plugins_to_be_registered: List[ - Tuple[List[str], Type[AccelerationPlugin]] # and_paths, plugin_class - ], - configuration_contents: Dict, -): - "helper function to instantiate an acceleration framework for testing" - - # empty out - old_registrations = [] - old_registrations.extend(PLUGIN_REGISTRATIONS) - PLUGIN_REGISTRATIONS.clear() - old_active_plugins = AccelerationFramework.active_plugins - old_custom_loading_plugins = AccelerationFramework.plugins_require_custom_loading - AccelerationFramework.active_plugins = [] - AccelerationFramework.plugins_require_custom_loading = [] - - for path, plugin in plugins_to_be_registered: - AccelerationPlugin.register_plugin( - plugin, - configuration_and_paths=path, - ) - - with NamedTemporaryFile("w") as f: - yaml.dump({KEY_PLUGINS: configuration_contents}, f) - yield AccelerationFramework(f.name) - - # put back - PLUGIN_REGISTRATIONS.clear() - PLUGIN_REGISTRATIONS.extend(old_registrations) - AccelerationFramework.active_plugins = old_active_plugins - AccelerationFramework.plugins_require_custom_loading = old_custom_loading_plugins - - -def create_noop_model_with_archs(class_name: str = "ModelNoop", archs: List[str] = []): - "helper function to create a dummy model with mocked architectures" - - config = type("Config", (object,), {"architectures": archs}) - return type(class_name, (torch.nn.Module,), {"config": config}) - - -def create_plugin_cls( - class_name: str = "PluginNoop", - restricted_models: Set = {}, - require_pkgs: Set = {}, - requires_custom_loading: bool = False, - requires_agumentation: bool = False, - agumentation: Callable = None, - model_loader: Callable = None, -): - "helper function to create plugin class" - - attributes = { - "restricted_model_archs": restricted_models, - "require_packages": require_pkgs, - "requires_custom_loading": requires_custom_loading, - "requires_agumentation": requires_agumentation, - } - - if agumentation is not None: - attributes["augmentation"] = agumentation - - if model_loader is not None: - attributes["model_loader"] = model_loader - - return type(class_name, (AccelerationPlugin,), attributes) - - -def dummy_augmentation(self, model, train_args, modifiable_args): - "dummy augmentation implementation" - return model, modifiable_args - - -def dummy_custom_loader(self, model_name, **kwargs): - "dummy custom loader returning dummy model" - return create_noop_model_with_archs(archs=["DummyModel"]) - - -# ----------------------------- TESTS ------------------------------------- +from fms_acceleration.framework_plugin import PLUGIN_REGISTRATIONS +from fms_acceleration.utils.test_utils import ( + build_framework_and_instantiate, + create_noop_model_with_archs, + create_plugin_cls, + dummy_augmentation, + dummy_custom_loader, +) def test_config_with_empty_body_raises(): @@ -208,7 +124,6 @@ def test_single_plugin(): plugins_to_be_registered=[(["dummy"], incomplete_plugin)], configuration_contents={"dummy": {"key1": 1}}, ) as framework: - # check 1. assert len(PLUGIN_REGISTRATIONS) == 1 assert len(framework.active_plugins) == 1 @@ -300,7 +215,6 @@ def test_two_plugins(): ], configuration_contents={"dummy": {"key1": 1}, "dummy2": {"key1": 1}}, ) as framework: - # check 1. assert len(PLUGIN_REGISTRATIONS) == 2 @@ -357,13 +271,15 @@ def test_plugin_registration_order(): "test that plugin registration order determines their activation order" # build a set of hooks that register the activation order - def hook_builder(act_order=[]): + def hook_builder(act_order=None): def _hook( self, model, train_args, modifiable_args, ): + if act_order is None: + act_order = [] act_order.append(str(self.__class__)) return model, modifiable_args @@ -391,7 +307,6 @@ def _hook( plugins_to_be_registered=[([k], v) for k, v in plugins_to_be_installed], configuration_contents={k: {"key1": 1} for k, _ in plugins_to_be_installed}, ) as framework: - # trigger augmentation of active plugins and check order of activation framework.augmentation(model, None, None) for c, (n, _) in zip(plugin_activation_order, plugins_to_be_installed): diff --git a/plugins/framework/tox.ini b/plugins/framework/tox.ini index a5db281a..52513f9a 100644 --- a/plugins/framework/tox.ini +++ b/plugins/framework/tox.ini @@ -8,6 +8,13 @@ commands = pytest {posargs:tests} [testenv:lint] description = run linters +deps = + pylint>=2.16.2,<=3.1.0 +commands = pylint src tests +allowlist_externals = pylint + +[testenv:fmt] +description = format skip_install = true deps = black>=22.12