diff --git a/tests/fixtures/formatter/formatter-tool-sort-order-formatted.yml b/tests/fixtures/formatter/formatter-tool-sort-order-formatted.yml index 87cac82..4f62862 100644 --- a/tests/fixtures/formatter/formatter-tool-sort-order-formatted.yml +++ b/tests/fixtures/formatter/formatter-tool-sort-order-formatted.yml @@ -11,6 +11,7 @@ tools: rank: | helpers.weighted_random_sampling(candidate_destinations) toolshed.g2.bx.psu.edu/repos/iuc/mothur_shhh_seqs/mothur_shhh_seqs/.*: + # This is a comment inherits: wig_to_bigWig cores: 2 mem: 20 diff --git a/tests/fixtures/formatter/formatter-tool-sort-order-input.yml b/tests/fixtures/formatter/formatter-tool-sort-order-input.yml index 039d265..a6507fe 100644 --- a/tests/fixtures/formatter/formatter-tool-sort-order-input.yml +++ b/tests/fixtures/formatter/formatter-tool-sort-order-input.yml @@ -16,6 +16,7 @@ tools: wig_to_bigWig: mem: 10 toolshed.g2.bx.psu.edu/repos/iuc/mothur_shhh_seqs/mothur_shhh_seqs/.*: + # This is a comment cores: 2 mem: 20 inherits: wig_to_bigWig diff --git a/tests/fixtures/linter/linter-invalid-regex.yml b/tests/fixtures/linter/linter-invalid-regex.yml index 78cf912..116f9b6 100644 --- a/tests/fixtures/linter/linter-invalid-regex.yml +++ b/tests/fixtures/linter/linter-invalid-regex.yml @@ -6,7 +6,7 @@ tools: cores: 2 params: native_spec: "--mem {mem} --cores {cores} --gpus {gpus}" - bwa[0-9]++: + bwa[0-9]^++: gpus: 2 destinations: diff --git a/tests/fixtures/linter/linter-warnings.yml b/tests/fixtures/linter/linter-warnings.yml new file mode 100644 index 0000000..d802e59 --- /dev/null +++ b/tests/fixtures/linter/linter-warnings.yml @@ -0,0 +1,29 @@ +global: + default_inherits: default + +tools: + default: + abstract: true + cores: 2 + mem: 4 + params: + native_spec: "--mem {mem} --cores {cores} --gpus {gpus}" + mem-no-cores-1: + mem: 16 + cores-no-mem-1: + cores: 8 + cores-no-mem-2: + # noqa: T102 + cores: 8 + cores-no-mem-3: + # noqa + cores: 8 + +destinations: + local: + runner: local + max_accepted_cores: 4 + max_accepted_mem: 16 + scheduling: + prefer: + - general diff --git a/tests/test_shell.py b/tests/test_shell.py index 21a4168..42e8a40 100644 --- a/tests/test_shell.py +++ b/tests/test_shell.py @@ -109,6 +109,23 @@ def test_lint_destination_defines_cores_instead_of_accepted_cores(self): "working_dest" not in output, f"Did not expect destination: `working_dest` to be in the output, but found: {output}") + def test_lint_warnings(self): + tpv_config = os.path.join(os.path.dirname(__file__), 'fixtures/linter/linter-warnings.yml') + output = self.call_shell_command("tpv", "-vv", "lint", tpv_config) + self.assertTrue( + "T102: The tool named: cores-no-mem-1 sets `cores`" in output, + f"Expected T102 warning for cores-no-mem-1 but output was: {output}") + self.assertFalse( + "T102: The tool named: cores-no-mem-2 sets `cores`" in output, + f"T102 warning for cores-no-mem-2 should be suppressed by noqa but output was: {output}") + self.assertFalse( + "T102: The tool named: cores-no-mem-3 sets `cores`" in output, + f"T102 warning for cores-no-mem-3 should be suppressed by noqa but output was: {output}") + output = self.call_shell_command("tpv", "-vv", "lint", "--ignore=T102", tpv_config) + self.assertFalse( + "T102: The tool named:" in output, + f"T102 warnings should be suppressed by --ignore but output was: {output}") + def test_warn_if_default_inherits_not_marked_abstract(self): tpv_config = os.path.join(os.path.dirname(__file__), 'fixtures/linter/linter-default-inherits-marked-abstract.yml') diff --git a/tpv/commands/formatter.py b/tpv/commands/formatter.py index cadcbf4..3271e86 100644 --- a/tpv/commands/formatter.py +++ b/tpv/commands/formatter.py @@ -1,6 +1,8 @@ from __future__ import annotations import logging +from ruamel.yaml.comments import CommentedMap, CommentedSeq + from tpv.core import util log = logging.getLogger(__name__) @@ -64,14 +66,24 @@ def multi_level_dict_sorter(dict_to_sort, sort_order): """ if not sort_order: return dict_to_sort - if isinstance(dict_to_sort, dict): + if isinstance(dict_to_sort, CommentedMap): sorted_keys = sorted(dict_to_sort or [], key=TPVConfigFormatter.generic_key_sorter(sort_order.keys())) - return {key: TPVConfigFormatter.multi_level_dict_sorter(dict_to_sort.get(key), - sort_order.get(key, {}) or sort_order.get('*', {})) - for key in sorted_keys} - elif isinstance(dict_to_sort, list): - return [TPVConfigFormatter.multi_level_dict_sorter(item, sort_order.get('*', [])) - for item in dict_to_sort] + rval = CommentedMap() + for key in sorted_keys: + sorted_value = TPVConfigFormatter.multi_level_dict_sorter( + dict_to_sort.get(key), + sort_order.get(key, {}) or sort_order.get('*', {}) + ) + rval[key] = sorted_value + rval.ca.items.update(dict_to_sort.ca.items) + return rval + elif isinstance(dict_to_sort, CommentedSeq): + rval = CommentedSeq() + for item in dict_to_sort: + sorted_item = TPVConfigFormatter.multi_level_dict_sorter(item, sort_order.get('*', [])) + rval.append(sorted_item) + rval.ca.items.update(dict_to_sort.ca.items) + return rval else: return dict_to_sort diff --git a/tpv/commands/linter.py b/tpv/commands/linter.py index 08e232e..70db2f7 100644 --- a/tpv/commands/linter.py +++ b/tpv/commands/linter.py @@ -6,35 +6,58 @@ log = logging.getLogger(__name__) +# Warning codes: +# T101: default inheritance not marked abstract +# T102: entity specifies cores without memory + + class TPVLintError(Exception): pass class TPVConfigLinter(object): - def __init__(self, url_or_path): + def __init__(self, url_or_path, ignore): self.url_or_path = url_or_path + self.ignore = ignore or [] self.warnings = [] self.errors = [] + self.loader = None - def lint(self): + def load_config(self): try: - loader = TPVConfigLoader.from_url_or_path(self.url_or_path) + self.loader = TPVConfigLoader.from_url_or_path(self.url_or_path) except Exception as e: log.error(f"Linting failed due to syntax errors in yaml file: {e}") raise TPVLintError("Linting failed due to syntax errors in yaml file: ") from e - default_inherits = loader.global_settings.get('default_inherits') - for tool_regex, tool in loader.tools.items(): + + def add_warning(self, entity, code, message): + if code not in self.ignore and not self.loader.check_noqa(entity, code): + self.warnings.append((code, message)) + + def lint(self): + if self.loader is None: + self.load_config() + default_inherits = self.loader.global_settings.get('default_inherits') + for tool_regex, tool in self.loader.tools.items(): try: re.compile(tool_regex) except re.error: self.errors.append(f"Failed to compile regex: {tool_regex}") - if default_inherits == tool.id: - self.warnings.append( + if default_inherits == tool.id and not tool.abstract: + self.add_warning( + tool, + "T101", f"The tool named: {default_inherits} is marked globally as the tool to inherit from " "by default. You may want to mark it as abstract if it is not an actual tool and it " "will be excluded from scheduling decisions.") - for destination in loader.destinations.values(): + if tool.cores and not tool.mem: + self.add_warning( + tool, + "T102", + f"The tool named: {tool_regex} sets `cores` but not `mem`. This can lead to " + "unexpected memory usage since memory is typically a multiplier of cores.") + for destination in self.loader.destinations.values(): if not destination.runner and not destination.abstract: self.errors.append(f"Destination '{destination.id}' does not define the runner parameter. " "The runner parameter is mandatory.") @@ -46,19 +69,21 @@ def lint(self): f"max_accepted_cores/mem/gpus. This is probably an error. If you're migrating from an older " f"version of TPV, the destination properties for cores/mem/gpus have been superseded by the " f"max_accepted_cores/mem/gpus property. Simply renaming them will give you the same functionality.") - if default_inherits == destination.id: - self.warnings.append( + if default_inherits == destination.id and not destination.abstract: + self.add_warning( + destination, + "T101", f"The destination named: {default_inherits} is marked globally as the destination to inherit from " "by default. You may want to mark it as abstract if it is not meant to be dispatched to, and it " "will be excluded from scheduling decisions.") if self.warnings: - for w in self.warnings: - log.warning(w) + for code, message in self.warnings: + log.warning(f"{code}: {message}") if self.errors: for e in self.errors: log.error(e) raise TPVLintError(f"The following errors occurred during linting: {self.errors}") @staticmethod - def from_url_or_path(url_or_path: str): - return TPVConfigLinter(url_or_path) + def from_url_or_path(url_or_path: str, ignore=None): + return TPVConfigLinter(url_or_path, ignore=ignore) diff --git a/tpv/commands/shell.py b/tpv/commands/shell.py index 9ea0c5b..9765e7b 100644 --- a/tpv/commands/shell.py +++ b/tpv/commands/shell.py @@ -26,7 +26,10 @@ def repr_none(dumper: RoundTripRepresenter, data): def tpv_lint_config_file(args): try: - tpv_linter = TPVConfigLinter.from_url_or_path(args.config) + ignore = [] + if args.ignore is not None: + ignore = [x.strip() for x in args.ignore.split(",")] + tpv_linter = TPVConfigLinter.from_url_or_path(args.config, ignore) tpv_linter.lint() log.info("lint successful.") return 0 @@ -74,6 +77,9 @@ def create_parser(): 'lint', help='loads a TPV configuration file and checks it for syntax errors', description="The linter will check yaml syntax and compile python code blocks") + lint_parser.add_argument( + '--ignore', type=str, + help="Comma-separated list of lint error and warning codes to ignore") lint_parser.add_argument( 'config', type=str, help="Path to the TPV config file to lint. Can be a local path or http url.") @@ -120,14 +126,15 @@ def configure_logging(verbosity_count): # or basicConfig persists for handler in logging.root.handlers[:]: logging.root.removeHandler(handler) + level = max(4 - verbosity_count, 1) * 10 # set global logging level logging.basicConfig( stream=sys.stdout, - level=logging.DEBUG if verbosity_count > 3 else logging.ERROR, + level=level, format='%(levelname)-5s: %(name)s: %(message)s') # Set client log level if verbosity_count: - log.setLevel(max(4 - verbosity_count, 1) * 10) + log.setLevel(level) else: log.setLevel(logging.INFO) diff --git a/tpv/core/loader.py b/tpv/core/loader.py index 99f469c..393e6bd 100644 --- a/tpv/core/loader.py +++ b/tpv/core/loader.py @@ -2,6 +2,7 @@ import ast import functools import logging +import re from . import helpers from . import util @@ -10,6 +11,9 @@ log = logging.getLogger(__name__) +NOQA_RE = re.compile(r"#\s*noqa:\s*([A-Z0-9, ]+)?") + + class InvalidParentException(Exception): pass @@ -19,6 +23,7 @@ class TPVConfigLoader(object): def __init__(self, tpv_config: dict): self.compile_code_block = functools.lru_cache(maxsize=None)(self.__compile_code_block) self.global_settings = tpv_config.get('global', {}) + self.noqa = {'tools': {}, 'users': {}, 'roles': {}, 'destinations': {}} entities = self.load_entities(tpv_config) self.tools = entities.get('tools') self.users = entities.get('users') @@ -69,7 +74,28 @@ def recompute_inheritance(self, entities: dict[str, Entity]): for key, entity in entities.items(): entities[key] = self.process_inheritance(entities, entity) - def validate_entities(self, entity_class: type, entity_list: dict) -> dict: + def get_noqa_codes(self, entity_comments: list) -> (bool, set[str] | None): + comments = [] + if entity_comments and len(entity_comments) == 4 and entity_comments[3]: + comments.extend([x.value.strip() for x in entity_comments[3]]) + + for comment in comments: + match = re.match(r"#\s*noqa:?\s*([A-Z0-9, ]+)?", comment) + if match: + codes = match.group(1) + # Return a set of codes or None if `# noqa` with no codes + return (True, set(code.strip() for code in codes.split(',')) if codes else None) + + return (False, None) + + def store_noqa_codes(self, entity_list: dict, entity_id: str, noqa_dict: dict): + if hasattr(entity_list, "ca"): + entity_comments = entity_list.ca.items.get(entity_id) + noqa, noqa_codes = self.get_noqa_codes(entity_comments) + if noqa: + noqa_dict[entity_id] = noqa_codes + + def validate_entities(self, entity_class: type, entity_list: dict, noqa_dict: dict) -> dict: # This code relies on dict ordering guarantees provided since python 3.6 validated = {} for entity_id, entity_dict in entity_list.items(): @@ -81,15 +107,19 @@ def validate_entities(self, entity_class: type, entity_list: dict) -> dict: except Exception: log.exception(f"Could not load entity of type: {entity_class} with data: {entity_dict}") raise + self.store_noqa_codes(entity_list, entity_id, noqa_dict) self.recompute_inheritance(validated) return validated def load_entities(self, tpv_config: dict) -> dict: validated = { - 'tools': self.validate_entities(Tool, tpv_config.get('tools', {})), - 'users': self.validate_entities(User, tpv_config.get('users', {})), - 'roles': self.validate_entities(Role, tpv_config.get('roles', {})), - 'destinations': self.validate_entities(Destination, tpv_config.get('destinations', {})) + 'tools': self.validate_entities(Tool, tpv_config.get('tools', {}), self.noqa['tools']), + 'users': self.validate_entities(User, tpv_config.get('users', {}), self.noqa['users']), + 'roles': self.validate_entities(Role, tpv_config.get('roles', {}), self.noqa['roles']), + 'destinations': self.validate_entities( + Destination, + tpv_config.get('destinations', {}), + self.noqa['destinations']) } return validated @@ -118,6 +148,21 @@ def merge_loader(self, loader: TPVConfigLoader): self.inherit_existing_entities(self.roles, loader.roles) self.inherit_existing_entities(self.destinations, loader.destinations) + def check_noqa(self, entity: Entity, code: str) -> bool: + if type(entity) is Tool: + noqa = self.noqa['tools'] + elif type(entity) is User: + noqa = self.noqa['users'] + elif type(entity) is Role: + noqa = self.noqa['roles'] + elif type(entity) is Destination: + noqa = self.noqa['destinations'] + else: + raise RuntimeError(f"Unknown entity type: {entity}") + if entity.id in noqa and (noqa[entity.id] is None or code in noqa[entity.id]): + return True + return False + @staticmethod def from_url_or_path(url_or_path: str): tpv_config = util.load_yaml_from_url_or_path(url_or_path) diff --git a/tpv/core/util.py b/tpv/core/util.py index 0bcfcda..a25f6f3 100644 --- a/tpv/core/util.py +++ b/tpv/core/util.py @@ -5,7 +5,7 @@ def load_yaml_from_url_or_path(url_or_path: str): - yaml = ruamel.yaml.YAML(typ='safe') + yaml = ruamel.yaml.YAML(typ="rt") if os.path.isfile(url_or_path): with open(url_or_path, 'r') as f: return yaml.load(f)