diff --git a/merlin/schema/schema.py b/merlin/schema/schema.py index fc864a255..ccf7c294f 100644 --- a/merlin/schema/schema.py +++ b/merlin/schema/schema.py @@ -424,7 +424,7 @@ def apply_inverse(self, selector) -> "Schema": def select_by_tag( self, - tags: Union[Union[str, Tags], List[Union[str, Tags]]], + tags: Union[Union[str, Tags], List[Union[str, Tags]], TagSet], pred_fn=None, ) -> "Schema": """Select columns from this Schema that match ANY of the supplied tags. @@ -449,6 +449,7 @@ def select_by_tag( if not isinstance(tags, (list, tuple)): tags = [tags] + tags = TagSet(tags) selected_schemas = {}