Skip to content

Commit

Permalink
Add support for make_dict_op in `placeholder_utils.get_all_types_in…
Browse files Browse the repository at this point in the history
…_placeholder_expression`.

PiperOrigin-RevId: 659641341
  • Loading branch information
tfx-copybara committed Aug 5, 2024
1 parent dca6148 commit 1c15280
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 0 deletions.
3 changes: 3 additions & 0 deletions tfx/dsl/compiler/placeholder_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -952,6 +952,9 @@ def get_all_types_in_placeholder_expression(
expressions = operator_pb.expressions
elif operator_name == "make_proto_op":
expressions = operator_pb.fields.values()
elif operator_name == "make_dict_op":
expressions = [entry.key for entry in operator_pb.entries]
expressions += [entry.value for entry in operator_pb.entries]
else:
raise ValueError(
f"Unrecognized placeholder operator {operator_name} in expression: "
Expand Down
32 changes: 32 additions & 0 deletions tfx/dsl/compiler/placeholder_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1627,6 +1627,38 @@ def testGetTypesOfMakeProtoOperator(self):
)
self.assertSetEqual(actual_types, set(ph_types))

def testGetTypesOfMakeDictOperator(self):
ph_types = placeholder_pb2.Placeholder.Type.values()
expressions = " ".join(f"""
entries {{
key: {{
value: {{
string_value: "field_{_ph_type_to_str(ph_type)}"
}}
}}
value: {{
placeholder: {{
type: {ph_type}
key: 'baz'
}}
}}
}}
""" for ph_type in ph_types)
placeholder_expression = text_format.Parse(
f"""
operator {{
make_dict_op {{
{expressions}
}}
}}
""",
placeholder_pb2.PlaceholderExpression(),
)
actual_types = placeholder_utils.get_all_types_in_placeholder_expression(
placeholder_expression
)
self.assertSetEqual(actual_types, set(ph_types))

def testGetsOperatorsFromProtoReflection(self):
self.assertSetEqual(
placeholder_utils.get_unary_operator_names(),
Expand Down

0 comments on commit 1c15280

Please sign in to comment.