From 3fd7129d3859fc14ee7cbad3f8bab321e3438333 Mon Sep 17 00:00:00 2001 From: tfx-team Date: Wed, 7 Aug 2024 15:35:52 -0700 Subject: [PATCH] remove self._store to simplify the resolver_op.Context PiperOrigin-RevId: 660563277 --- .../ops/latest_policy_model_op_test.py | 4 ++-- tfx/dsl/input_resolution/ops/test_utils.py | 17 +++++++++++++---- .../ops/training_range_op_test.py | 6 +++--- tfx/dsl/input_resolution/resolver_op.py | 11 ++--------- .../input_resolution/input_graph_resolver.py | 1 - 5 files changed, 20 insertions(+), 19 deletions(-) diff --git a/tfx/dsl/input_resolution/ops/latest_policy_model_op_test.py b/tfx/dsl/input_resolution/ops/latest_policy_model_op_test.py index 45cc8d37b5..0055eccde1 100644 --- a/tfx/dsl/input_resolution/ops/latest_policy_model_op_test.py +++ b/tfx/dsl/input_resolution/ops/latest_policy_model_op_test.py @@ -188,7 +188,7 @@ def testLatestPolicyModelOpTest_RaisesSkipSignal(self): {}, policy=_LATEST_EXPORTED, raise_skip_signal=True, - context=resolver_op.Context(store=self.store), + context=resolver_op.Context(self.mlmd_cm), ) # Keys present in input_dict but contains no artifacts. @@ -214,7 +214,7 @@ def testLatestPolicyModelOpTest_DoesNotRaiseSkipSignal(self): {}, policy=_LATEST_EXPORTED, raise_skip_signal=False, - context=resolver_op.Context(store=self.store), + context=resolver_op.Context(self.mlmd_cm), ), policy=_LATEST_EXPORTED, ) diff --git a/tfx/dsl/input_resolution/ops/test_utils.py b/tfx/dsl/input_resolution/ops/test_utils.py index 1ab3ce0908..1d4b0705b5 100644 --- a/tfx/dsl/input_resolution/ops/test_utils.py +++ b/tfx/dsl/input_resolution/ops/test_utils.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Testing utility for builtin resolver ops.""" + from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple, Type, Union from unittest import mock @@ -25,6 +26,7 @@ from tfx.dsl.components.base import executor_spec from tfx.dsl.input_resolution import resolver_op from tfx.dsl.input_resolution.ops import ops_utils +from tfx.orchestration import metadata from tfx.orchestration import pipeline from tfx.orchestration import mlmd_connection_manager as mlmd_cm from tfx.proto.orchestration import pipeline_pb2 @@ -423,11 +425,18 @@ def strict_run_resolver_op( f'Expected ARTIFACT_MULTIMAP_LIST but arg[{i}] = {arg}' ) op = op_type.create(**kwargs) + + if mlmd_handle_like is not None: + mlmd_handle = mlmd_handle_like + else: + mlmd_handle = metadata.Metadata( + connection_config=metadata_store_pb2.ConnectionConfig(), + ) + mlmd_handle._store = ( # pylint: disable=protected-access + store if store is not None else mock.MagicMock(spec=mlmd.MetadataStore) + ) context = resolver_op.Context( - store=store - if store is not None - else mock.MagicMock(spec=mlmd.MetadataStore), - mlmd_handle_like=mlmd_handle_like, + mlmd_handle_like=mlmd_handle, ) op.set_context(context) result = op.apply(*args) diff --git a/tfx/dsl/input_resolution/ops/training_range_op_test.py b/tfx/dsl/input_resolution/ops/training_range_op_test.py index 3fd4e4433a..dff5bd550d 100644 --- a/tfx/dsl/input_resolution/ops/training_range_op_test.py +++ b/tfx/dsl/input_resolution/ops/training_range_op_test.py @@ -127,7 +127,7 @@ def testTrainingRangeOp_EmptyListReturned(self): actual = test_utils.run_resolver_op( ops.TrainingRange, [], - context=resolver_op.Context(store=self.store), + context=resolver_op.Context(self.mlmd_cm), ) self.assertEmpty(actual) @@ -150,14 +150,14 @@ def testTrainingRangeOp_InvalidArgumentRaised(self): test_utils.run_resolver_op( ops.TrainingRange, [self.model, self.model], - context=resolver_op.Context(store=self.store), + context=resolver_op.Context(self.mlmd_cm), ) # Incorret input artifact type. test_utils.run_resolver_op( ops.TrainingRange, [self.transform_graph], - context=resolver_op.Context(store=self.store), + context=resolver_op.Context(self.mlmd_cm), ) def testTrainingRangeOp_BulkInferrerProducesExamples(self): diff --git a/tfx/dsl/input_resolution/resolver_op.py b/tfx/dsl/input_resolution/resolver_op.py index 964016a5a5..b27f79649e 100644 --- a/tfx/dsl/input_resolution/resolver_op.py +++ b/tfx/dsl/input_resolution/resolver_op.py @@ -25,8 +25,6 @@ from tfx.utils import json_utils from tfx.utils import typing_utils -import ml_metadata as mlmd - # Mark frozen as context instance may be used across multiple operator # invocations. @@ -35,18 +33,13 @@ class Context: def __init__( self, - store=mlmd.MetadataStore, - mlmd_handle_like: Optional[mlmd_cm.HandleLike] = None, + mlmd_handle_like: mlmd_cm.HandleLike, ): - # TODO(b/302730333) We could remove self._store, and only use - # self._mlmd_handle_like. Keeping it for now to preserve backward - # compatibility with other resolve ops. - self._store = store self._mlmd_handle_like = mlmd_handle_like @property def store(self): - return self._store + return mlmd_cm.get_handle(self._mlmd_handle_like).store @property def mlmd_connection_manager(self): diff --git a/tfx/orchestration/portable/input_resolution/input_graph_resolver.py b/tfx/orchestration/portable/input_resolution/input_graph_resolver.py index 667b224a7f..e9a6a15e9c 100644 --- a/tfx/orchestration/portable/input_resolution/input_graph_resolver.py +++ b/tfx/orchestration/portable/input_resolution/input_graph_resolver.py @@ -137,7 +137,6 @@ def _evaluate_op_node( op: resolver_op.ResolverOp = op_type.create(**kwargs) op.set_context( resolver_op.Context( - store=mlmd_cm.get_handle(ctx.mlmd_handle_like).store, mlmd_handle_like=ctx.mlmd_handle_like, ) )