Skip to content

Commit

Permalink
remove self._store to simplify the resolver_op.Context
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 660563277
  • Loading branch information
tfx-copybara committed Aug 8, 2024
1 parent 236ac38 commit 3fd7129
Show file tree
Hide file tree
Showing 5 changed files with 20 additions and 19 deletions.
4 changes: 2 additions & 2 deletions tfx/dsl/input_resolution/ops/latest_policy_model_op_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ def testLatestPolicyModelOpTest_RaisesSkipSignal(self):
{},
policy=_LATEST_EXPORTED,
raise_skip_signal=True,
context=resolver_op.Context(store=self.store),
context=resolver_op.Context(self.mlmd_cm),
)

# Keys present in input_dict but contains no artifacts.
Expand All @@ -214,7 +214,7 @@ def testLatestPolicyModelOpTest_DoesNotRaiseSkipSignal(self):
{},
policy=_LATEST_EXPORTED,
raise_skip_signal=False,
context=resolver_op.Context(store=self.store),
context=resolver_op.Context(self.mlmd_cm),
),
policy=_LATEST_EXPORTED,
)
Expand Down
17 changes: 13 additions & 4 deletions tfx/dsl/input_resolution/ops/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Testing utility for builtin resolver ops."""

from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple, Type, Union
from unittest import mock

Expand All @@ -25,6 +26,7 @@
from tfx.dsl.components.base import executor_spec
from tfx.dsl.input_resolution import resolver_op
from tfx.dsl.input_resolution.ops import ops_utils
from tfx.orchestration import metadata
from tfx.orchestration import pipeline
from tfx.orchestration import mlmd_connection_manager as mlmd_cm
from tfx.proto.orchestration import pipeline_pb2
Expand Down Expand Up @@ -423,11 +425,18 @@ def strict_run_resolver_op(
f'Expected ARTIFACT_MULTIMAP_LIST but arg[{i}] = {arg}'
)
op = op_type.create(**kwargs)

if mlmd_handle_like is not None:
mlmd_handle = mlmd_handle_like
else:
mlmd_handle = metadata.Metadata(
connection_config=metadata_store_pb2.ConnectionConfig(),
)
mlmd_handle._store = ( # pylint: disable=protected-access
store if store is not None else mock.MagicMock(spec=mlmd.MetadataStore)
)
context = resolver_op.Context(
store=store
if store is not None
else mock.MagicMock(spec=mlmd.MetadataStore),
mlmd_handle_like=mlmd_handle_like,
mlmd_handle_like=mlmd_handle,
)
op.set_context(context)
result = op.apply(*args)
Expand Down
6 changes: 3 additions & 3 deletions tfx/dsl/input_resolution/ops/training_range_op_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ def testTrainingRangeOp_EmptyListReturned(self):
actual = test_utils.run_resolver_op(
ops.TrainingRange,
[],
context=resolver_op.Context(store=self.store),
context=resolver_op.Context(self.mlmd_cm),
)
self.assertEmpty(actual)

Expand All @@ -150,14 +150,14 @@ def testTrainingRangeOp_InvalidArgumentRaised(self):
test_utils.run_resolver_op(
ops.TrainingRange,
[self.model, self.model],
context=resolver_op.Context(store=self.store),
context=resolver_op.Context(self.mlmd_cm),
)

# Incorret input artifact type.
test_utils.run_resolver_op(
ops.TrainingRange,
[self.transform_graph],
context=resolver_op.Context(store=self.store),
context=resolver_op.Context(self.mlmd_cm),
)

def testTrainingRangeOp_BulkInferrerProducesExamples(self):
Expand Down
11 changes: 2 additions & 9 deletions tfx/dsl/input_resolution/resolver_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,6 @@
from tfx.utils import json_utils
from tfx.utils import typing_utils

import ml_metadata as mlmd


# Mark frozen as context instance may be used across multiple operator
# invocations.
Expand All @@ -35,18 +33,13 @@ class Context:

def __init__(
self,
store=mlmd.MetadataStore,
mlmd_handle_like: Optional[mlmd_cm.HandleLike] = None,
mlmd_handle_like: mlmd_cm.HandleLike,
):
# TODO(b/302730333) We could remove self._store, and only use
# self._mlmd_handle_like. Keeping it for now to preserve backward
# compatibility with other resolve ops.
self._store = store
self._mlmd_handle_like = mlmd_handle_like

@property
def store(self):
return self._store
return mlmd_cm.get_handle(self._mlmd_handle_like).store

@property
def mlmd_connection_manager(self):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,6 @@ def _evaluate_op_node(
op: resolver_op.ResolverOp = op_type.create(**kwargs)
op.set_context(
resolver_op.Context(
store=mlmd_cm.get_handle(ctx.mlmd_handle_like).store,
mlmd_handle_like=ctx.mlmd_handle_like,
)
)
Expand Down

0 comments on commit 3fd7129

Please sign in to comment.