Skip to content

Commit

Permalink
Automated rollback of commit b0ab1f3
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 653691605
  • Loading branch information
tfx-copybara committed Aug 6, 2024
1 parent 5e90c67 commit 1f3f787
Show file tree
Hide file tree
Showing 7 changed files with 132 additions and 44 deletions.
1 change: 1 addition & 0 deletions build/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ sh_binary(
name = "gen_proto",
srcs = ["gen_proto.sh"],
data = [
"//tfx/dsl/component/experimental:annotations_test_proto_pb2.py",
"//tfx/examples/custom_components/presto_example_gen/proto:presto_config_pb2.py",
"//tfx/extensions/experimental/kfp_compatibility/proto:kfp_component_spec_pb2.py",
"//tfx/extensions/google_cloud_big_query/experimental/elwc_example_gen/proto:elwc_config_pb2.py",
Expand Down
25 changes: 25 additions & 0 deletions tfx/dsl/component/experimental/BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
load("//tfx:tfx.bzl", "tfx_py_proto_library")

# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
package(default_visibility = ["//visibility:public"])

licenses(["notice"]) # Apache 2.0

exports_files(["LICENSE"])

tfx_py_proto_library(
name = "annotations_test_proto_py_pb2",
srcs = ["annotations_test_proto.proto"],
)
39 changes: 28 additions & 11 deletions tfx/dsl/component/experimental/annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
from tfx.types import artifact
from tfx.utils import deprecation_utils

from google.protobuf import message

try:
import apache_beam as beam # pytype: disable=import-error # pylint: disable=g-import-not-at-top

Expand Down Expand Up @@ -107,31 +109,43 @@ def __repr__(self):
return '%s[%s]' % (self.__class__.__name__, self.type)


class _PrimitiveTypeGenericMeta(type):
class _PrimitiveAndProtoTypeGenericMeta(type):
"""Metaclass for _PrimitiveTypeGeneric, to enable primitive type indexing."""

def __getitem__(
cls: Type['_PrimitiveTypeGeneric'],
params: Type[Union[int, float, str, bool, List[Any], Dict[Any, Any]]],
cls: Type['_PrimitiveAndProtoTypeGeneric'],
params: Type[
Union[
int,
float,
str,
bool,
List[Any],
Dict[Any, Any],
message.Message,
],
],
):
"""Metaclass method allowing indexing class (`_PrimitiveTypeGeneric[T]`)."""
return cls._generic_getitem(params) # pytype: disable=attribute-error


class _PrimitiveTypeGeneric(metaclass=_PrimitiveTypeGenericMeta):
class _PrimitiveAndProtoTypeGeneric(
metaclass=_PrimitiveAndProtoTypeGenericMeta
):
"""A generic that takes a primitive type as its single argument."""

def __init__( # pylint: disable=invalid-name
self,
artifact_type: Type[Union[int, float, str, bool]],
artifact_type: Type[Union[int, float, str, bool, message.Message]],
_init_via_getitem=False,
):
if not _init_via_getitem:
class_name = self.__class__.__name__
raise ValueError(
(
'%s should be instantiated via the syntax `%s[T]`, where T is '
'`int`, `float`, `str`, or `bool`.'
'`int`, `float`, `str`, `bool` or proto type.'
)
% (class_name, class_name)
)
Expand All @@ -143,17 +157,20 @@ def _generic_getitem(cls, params):
# Check that the given parameter is a primitive type.
if (
inspect.isclass(params)
and params in (int, float, str, bool)
and (
params in (int, float, str, bool)
or issubclass(params, message.Message)
)
or json_compat.is_json_compatible(params)
):
return cls(params, _init_via_getitem=True)
else:
class_name = cls.__name__
raise ValueError(
(
'Generic type `%s[T]` expects the single parameter T to be '
'`int`, `float`, `str`, `bool` or JSON-compatible types '
'(Dict[str, T], List[T]) (got %r instead).'
'Generic type `%s[T]` expects the single parameter T to be `int`,'
' `float`, `str`, `bool`, JSON-compatible types (Dict[str, T],'
' List[T]) or a proto type. (got %r instead).'
)
% (class_name, params)
)
Expand Down Expand Up @@ -252,7 +269,7 @@ class AsyncOutputArtifact(Generic[T]):
"""Intermediate artifact object type annotation."""


class Parameter(_PrimitiveTypeGeneric):
class Parameter(_PrimitiveAndProtoTypeGeneric):
"""Component parameter type annotation."""


Expand Down
60 changes: 32 additions & 28 deletions tfx/dsl/component/experimental/annotations_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import apache_beam as beam
import tensorflow as tf
from tfx.dsl.component.experimental import annotations
from tfx.dsl.component.experimental import annotations_test_proto_pb2
from tfx.types import artifact
from tfx.types import standard_artifacts
from tfx.types import value_artifact
Expand All @@ -27,18 +28,21 @@ class AnnotationsTest(tf.test.TestCase):

def testArtifactGenericAnnotation(self):
# Error: type hint whose parameter is not an Artifact subclass.
with self.assertRaisesRegex(ValueError,
'expects .* a concrete subclass of'):
with self.assertRaisesRegex(
ValueError, 'expects .* a concrete subclass of'
):
_ = annotations._ArtifactGeneric[int] # pytype: disable=unsupported-operands

# Error: type hint with abstract Artifact subclass.
with self.assertRaisesRegex(ValueError,
'expects .* a concrete subclass of'):
with self.assertRaisesRegex(
ValueError, 'expects .* a concrete subclass of'
):
_ = annotations._ArtifactGeneric[artifact.Artifact]

# Error: type hint with abstract Artifact subclass.
with self.assertRaisesRegex(ValueError,
'expects .* a concrete subclass of'):
with self.assertRaisesRegex(
ValueError, 'expects .* a concrete subclass of'
):
_ = annotations._ArtifactGeneric[value_artifact.ValueArtifact]

# OK.
Expand All @@ -49,56 +53,55 @@ def testArtifactAnnotationUsage(self):
_ = annotations.OutputArtifact[standard_artifacts.Examples]
_ = annotations.AsyncOutputArtifact[standard_artifacts.Model]

def testPrimitiveTypeGenericAnnotation(self):
# Error: type hint whose parameter is not a primitive type
def testPrimitivAndProtoTypeGenericAnnotation(self):
# Error: type hint whose parameter is not a primitive or a proto type
# pytype: disable=unsupported-operands
with self.assertRaisesRegex(
ValueError, 'T to be `int`, `float`, `str`, `bool`'
):
_ = annotations._PrimitiveTypeGeneric[artifact.Artifact]
_ = annotations._PrimitiveAndProtoTypeGeneric[artifact.Artifact]
with self.assertRaisesRegex(
ValueError, 'T to be `int`, `float`, `str`, `bool`'
):
_ = annotations._PrimitiveTypeGeneric[object]
_ = annotations._PrimitiveAndProtoTypeGeneric[object]
with self.assertRaisesRegex(
ValueError, 'T to be `int`, `float`, `str`, `bool`'
):
_ = annotations._PrimitiveTypeGeneric[123]
_ = annotations._PrimitiveAndProtoTypeGeneric[123]
with self.assertRaisesRegex(
ValueError, 'T to be `int`, `float`, `str`, `bool`'
):
_ = annotations._PrimitiveTypeGeneric['string']
_ = annotations._PrimitiveAndProtoTypeGeneric['string']
with self.assertRaisesRegex(
ValueError, 'T to be `int`, `float`, `str`, `bool`'
):
_ = annotations._PrimitiveTypeGeneric[Dict[int, int]]
_ = annotations._PrimitiveAndProtoTypeGeneric[Dict[int, int]]
with self.assertRaisesRegex(
ValueError, 'T to be `int`, `float`, `str`, `bool`'
):
_ = annotations._PrimitiveTypeGeneric[bytes]
_ = annotations._PrimitiveAndProtoTypeGeneric[bytes]
# pytype: enable=unsupported-operands
# OK.
_ = annotations._PrimitiveTypeGeneric[int]
_ = annotations._PrimitiveTypeGeneric[float]
_ = annotations._PrimitiveTypeGeneric[str]
_ = annotations._PrimitiveTypeGeneric[bool]
_ = annotations._PrimitiveTypeGeneric[Dict[str, float]]
_ = annotations._PrimitiveTypeGeneric[bool]
_ = annotations._PrimitiveAndProtoTypeGeneric[int]
_ = annotations._PrimitiveAndProtoTypeGeneric[float]
_ = annotations._PrimitiveAndProtoTypeGeneric[str]
_ = annotations._PrimitiveAndProtoTypeGeneric[bool]
_ = annotations._PrimitiveAndProtoTypeGeneric[Dict[str, float]]
_ = annotations._PrimitiveAndProtoTypeGeneric[bool]
_ = annotations._PrimitiveAndProtoTypeGeneric[
annotations_test_proto_pb2.TestMessage
]

def testPipelineTypeGenericAnnotation(self):
# Error: type hint whose parameter is not a primitive type
with self.assertRaisesRegex(
ValueError, 'T to be `beam.Pipeline`'):
with self.assertRaisesRegex(ValueError, 'T to be `beam.Pipeline`'):
_ = annotations._PipelineTypeGeneric[artifact.Artifact]
with self.assertRaisesRegex(
ValueError, 'T to be `beam.Pipeline`'):
with self.assertRaisesRegex(ValueError, 'T to be `beam.Pipeline`'):
_ = annotations._PipelineTypeGeneric[object]
# pytype: disable=unsupported-operands
with self.assertRaisesRegex(
ValueError, 'T to be `beam.Pipeline`'):
with self.assertRaisesRegex(ValueError, 'T to be `beam.Pipeline`'):
_ = annotations._PipelineTypeGeneric[123]
with self.assertRaisesRegex(
ValueError, 'T to be `beam.Pipeline`'):
with self.assertRaisesRegex(ValueError, 'T to be `beam.Pipeline`'):
_ = annotations._PipelineTypeGeneric['string']
# pytype: enable=unsupported-operands

Expand All @@ -110,6 +113,7 @@ def testParameterUsage(self):
_ = annotations.Parameter[float]
_ = annotations.Parameter[str]
_ = annotations.Parameter[bool]
_ = annotations.Parameter[annotations_test_proto_pb2.TestMessage]


if __name__ == '__main__':
Expand Down
21 changes: 21 additions & 0 deletions tfx/dsl/component/experimental/annotations_test_proto.proto
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
// Copyright 2024 Google LLC. All Rights Reserved.

// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at

// http://www.apache.org/licenses/LICENSE-2.0

// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto3";

package tfx.dsl.component.experimental;

message TestMessage {
int32 number = 1;
string name = 2;
}
16 changes: 12 additions & 4 deletions tfx/dsl/component/experimental/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from tfx.types import artifact
from tfx.types import component_spec
from tfx.types import system_executions
from google.protobuf import message


class ArgFormats(enum.Enum):
Expand Down Expand Up @@ -224,10 +225,17 @@ def _create_component_spec_class(
json_compatible_outputs[key],
)
if parameters:
for key, primitive_type in parameters.items():
spec_parameters[key] = component_spec.ExecutionParameter(
type=primitive_type, optional=(key in arg_defaults)
)
for key, param_type in parameters.items():
if inspect.isclass(param_type) and issubclass(
param_type, message.Message
):
spec_parameters[key] = component_spec.ExecutionParameter(
type=param_type, optional=(key in arg_defaults), use_proto=True
)
else:
spec_parameters[key] = component_spec.ExecutionParameter(
type=param_type, optional=(key in arg_defaults)
)
component_spec_class = type(
'%s_Spec' % func.__name__,
(tfx_types.ComponentSpec,),
Expand Down
14 changes: 13 additions & 1 deletion tfx/dsl/component/experimental/utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from typing import Dict, List
import tensorflow as tf
from tfx.dsl.component.experimental import annotations
from tfx.dsl.component.experimental import annotations_test_proto_pb2
from tfx.dsl.component.experimental import decorators
from tfx.dsl.component.experimental import function_parser
from tfx.dsl.component.experimental import utils
Expand Down Expand Up @@ -106,6 +107,9 @@ def func_with_primitive_parameter(
float_param: annotations.Parameter[float],
str_param: annotations.Parameter[str],
bool_param: annotations.Parameter[bool],
proto_param: annotations.Parameter[
annotations_test_proto_pb2.TestMessage
],
dict_int_param: annotations.Parameter[Dict[str, int]],
list_bool_param: annotations.Parameter[List[bool]],
dict_list_bool_param: annotations.Parameter[Dict[str, List[bool]]],
Expand All @@ -124,6 +128,7 @@ def func_with_primitive_parameter(
'float_param': float,
'str_param': str,
'bool_param': bool,
'proto_param': annotations_test_proto_pb2.TestMessage,
'dict_int_param': Dict[str, int],
'list_bool_param': List[bool],
'dict_list_bool_param': Dict[str, List[bool]],
Expand Down Expand Up @@ -193,6 +198,9 @@ def func(
standard_artifacts.Examples
],
int_param: annotations.Parameter[int],
proto_param: annotations.Parameter[
annotations_test_proto_pb2.TestMessage
],
json_compat_param: annotations.Parameter[Dict[str, int]],
str_param: annotations.Parameter[str] = 'foo',
) -> annotations.OutputDict(
Expand Down Expand Up @@ -257,11 +265,15 @@ def func(
spec_outputs['map_str_float_output'].type, standard_artifacts.JsonValue
)
spec_parameter = actual_spec_class.PARAMETERS
self.assertLen(spec_parameter, 3)
self.assertLen(spec_parameter, 4)
self.assertEqual(spec_parameter['int_param'].type, int)
self.assertEqual(spec_parameter['int_param'].optional, False)
self.assertEqual(spec_parameter['str_param'].type, str)
self.assertEqual(spec_parameter['str_param'].optional, True)
self.assertEqual(
spec_parameter['proto_param'].type,
annotations_test_proto_pb2.TestMessage,
)
self.assertEqual(spec_parameter['json_compat_param'].type, Dict[str, int])
self.assertEqual(spec_parameter['json_compat_param'].optional, False)
self.assertEqual(actual_spec_class.TYPE_ANNOTATION, type_annotation)
Expand Down

0 comments on commit 1f3f787

Please sign in to comment.