Skip to content

Commit

Permalink
This is an internal cleanup
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 420896786
  • Loading branch information
tf-model-analysis-team authored and tfx-copybara committed Jan 11, 2022
1 parent fcb15bd commit 65bb403
Showing 1 changed file with 14 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import copy
import datetime
import numbers
from typing import Any, Dict, Iterable, Iterator, List, NamedTuple, Optional, Tuple, Type, Union
from typing import Any, Dict, Iterable, Iterator, List, NamedTuple, Optional, Tuple, Type, TypeVar, Union
import apache_beam as beam
import numpy as np

Expand All @@ -41,6 +41,9 @@
from tfx_bsl.tfxio import tensor_adapter
from tensorflow_metadata.proto.v0 import schema_pb2

SliceKeyTypeVar = TypeVar('SliceKeyTypeVar', slicer.SliceKeyType,
slicer.CrossSliceKeyType)

_COMBINER_INPUTS_KEY = '_combiner_inputs'
_DEFAULT_COMBINER_INPUT_KEY = '_default_combiner_input'
_DEFAULT_NUM_JACKKNIFE_BUCKETS = 20
Expand Down Expand Up @@ -381,19 +384,20 @@ def _is_private_metrics(metric_key: metric_types.MetricKey):


def _remove_private_metrics(
slice_key: slicer.SliceKeyOrCrossSliceKeyType,
metrics: metric_types.MetricsDict
) -> Tuple[slicer.SliceKeyOrCrossSliceKeyType, metric_types.MetricsDict]:
slice_key: SliceKeyTypeVar, metrics: metric_types.MetricsDict
) -> Tuple[SliceKeyTypeVar, metric_types.MetricsDict]:
return (slice_key,
{k: v for (k, v) in metrics.items() if not _is_private_metrics(k)})


@beam.ptransform_fn
def _AddCrossSliceMetrics( # pylint: disable=invalid-name
sliced_combiner_outputs: beam.pvalue.PCollection,
sliced_combiner_outputs: beam.pvalue.PCollection[Tuple[
slicer.SliceKeyType, metric_types.MetricsDict]],
cross_slice_specs: Optional[Iterable[config_pb2.CrossSlicingSpec]],
cross_slice_computations: List[metric_types.CrossSliceMetricComputation],
) -> Tuple[slicer.SliceKeyOrCrossSliceKeyType, metric_types.MetricsDict]:
) -> beam.pvalue.PCollection[Tuple[slicer.SliceKeyOrCrossSliceKeyType,
metric_types.MetricsDict]]:
"""Generates CrossSlice metrics from SingleSlices."""

def is_slice_applicable(
Expand Down Expand Up @@ -495,8 +499,8 @@ def _AddDerivedCrossSliceAndDiffMetrics( # pylint: disable=invalid-name
derived_computations: List[metric_types.DerivedMetricComputation],
cross_slice_computations: List[metric_types.CrossSliceMetricComputation],
cross_slice_specs: Optional[Iterable[config_pb2.CrossSlicingSpec]] = None,
baseline_model_name: Optional[str] = None
) -> beam.PCollection[Tuple[slicer.SliceKeyType, metric_types.MetricsDict]]:
baseline_model_name: Optional[str] = None) -> beam.PCollection[Tuple[
slicer.SliceKeyOrCrossSliceKeyType, metric_types.MetricsDict]]:
"""A PTransform for adding cross slice and derived metrics.
This PTransform uses the input PCollection of sliced metrics to compute
Expand Down Expand Up @@ -564,11 +568,11 @@ def add_diff_metrics(


def _filter_by_key_type(
sliced_metrics_plots_attributions: Tuple[slicer.SliceKeyType,
sliced_metrics_plots_attributions: Tuple[SliceKeyTypeVar,
Dict[metric_types.MetricKey, Any]],
key_type: Type[Union[metric_types.MetricKey, metric_types.PlotKey,
metric_types.AttributionsKey]]
) -> Tuple[slicer.SliceKeyType, Dict[metric_types.MetricKey, Any]]:
) -> Tuple[SliceKeyTypeVar, Dict[metric_types.MetricKey, Any]]:
"""Filters metrics and plots by key type."""
slice_value, metrics_plots_attributions = sliced_metrics_plots_attributions
output = {}
Expand Down

0 comments on commit 65bb403

Please sign in to comment.