From a4f29a0f9b9e2f1ad112982af9c19f881c629d1b Mon Sep 17 00:00:00 2001
From: Madhur Karampudi <142544288+vkarampudi@users.noreply.github.com>
Date: Sun, 8 Dec 2024 21:24:36 -0800
Subject: [PATCH] Removing tf-ranking as a dependency until it supports tf 2.16 (#7725)

---
 tfx/examples/ranking/features.py              |  35 +-
 .../ranking/ranking_pipeline_e2e_test.py      |  47 +--
 tfx/examples/ranking/ranking_utils.py         | 314 ++++--------
 .../struct2tensor_parsing_utils_test.py       | 164 +++++---
 4 files changed, 195 insertions(+), 365 deletions(-)

diff --git a/tfx/examples/ranking/features.py b/tfx/examples/ranking/features.py
index e338240750..4863da52d6 100644
--- a/tfx/examples/ranking/features.py
+++ b/tfx/examples/ranking/features.py
@@ -17,36 +17,37 @@
 These names will be shared between the transform and the model.
 """
 
-import tensorflow as tf
-from tfx.examples.ranking import struct2tensor_parsing_utils
+# import tensorflow as tf
+# This is due to TF Ranking not supporting TensorFlow 2.16; we should re-enable it when support is added.
+# from tfx.examples.ranking import struct2tensor_parsing_utils
 
 # Labels are expected to be dense. If the ELWCs in a batch have different
 # numbers of documents, the shape of the label is [N, D], where N is the batch
 # size and D is the maximum number of documents in the batch. If an ELWC in the
 # batch has D_0 < D documents, then the label values at D_0 <= d < D must be
 # negative to indicate that the document is invalid.
-LABEL_PADDING_VALUE = -1
+# LABEL_PADDING_VALUE = -1
 
 # Names of features in the ELWC.
-QUERY_TOKENS = 'query_tokens'
-DOCUMENT_TOKENS = 'document_tokens'
-LABEL = 'relevance'
+# QUERY_TOKENS = 'query_tokens'
+# DOCUMENT_TOKENS = 'document_tokens'
+# LABEL = 'relevance'
 
 # This "feature" does not exist in the data but will be created on the fly.
-LIST_SIZE_FEATURE_NAME = 'example_list_size'
+# LIST_SIZE_FEATURE_NAME = 'example_list_size'
 
 
-def get_features():
-  """Defines the context features and example features spec for parsing."""
+# def get_features():
+#   """Defines the context features and example features spec for parsing."""
 
-  context_features = [
-      struct2tensor_parsing_utils.Feature(QUERY_TOKENS, tf.string)
-  ]
+#   context_features = [
+#       struct2tensor_parsing_utils.Feature(QUERY_TOKENS, tf.string)
+#   ]
 
-  example_features = [
-      struct2tensor_parsing_utils.Feature(DOCUMENT_TOKENS, tf.string)
-  ]
+#   example_features = [
+#       struct2tensor_parsing_utils.Feature(DOCUMENT_TOKENS, tf.string)
+#   ]
 
-  label = struct2tensor_parsing_utils.Feature(LABEL, tf.int64)
+#   label = struct2tensor_parsing_utils.Feature(LABEL, tf.int64)
 
-  return context_features, example_features, label
+#   return context_features, example_features, label

diff --git a/tfx/examples/ranking/ranking_pipeline_e2e_test.py b/tfx/examples/ranking/ranking_pipeline_e2e_test.py
index 7d71530f4b..9e953cc688 100644
--- a/tfx/examples/ranking/ranking_pipeline_e2e_test.py
+++ b/tfx/examples/ranking/ranking_pipeline_e2e_test.py
@@ -16,9 +16,12 @@
 import unittest
 
 import tensorflow as tf
-from tfx.examples.ranking import ranking_pipeline
-from tfx.orchestration import metadata
-from tfx.orchestration.beam.beam_dag_runner import BeamDagRunner
+# from tfx.orchestration import metadata
+# from tfx.orchestration.beam.beam_dag_runner import BeamDagRunner
+
+# This is due to TF Ranking not supporting TensorFlow 2.16; we should re-enable it when support is added.
+# from tfx.examples.ranking import ranking_pipeline
+
 
 try:
   import struct2tensor  # pylint: disable=g-import-not-at-top
 except ImportError:
   struct2tensor = None
@@ -62,23 +65,23 @@ def assertExecutedOnce(self, component) -> None:
       execution = tf.io.gfile.listdir(os.path.join(component_path, output))
       self.assertEqual(1, len(execution))
 
-  def testPipeline(self):
-    BeamDagRunner().run(
-        ranking_pipeline._create_pipeline(
-            pipeline_name=self._pipeline_name,
-            pipeline_root=self._tfx_root,
-            data_root=self._data_root,
-            module_file=self._module_file,
-            serving_model_dir=self._serving_model_dir,
-            metadata_path=self._metadata_path,
-            beam_pipeline_args=['--direct_num_workers=1']))
-    self.assertTrue(tf.io.gfile.exists(self._serving_model_dir))
-    self.assertTrue(tf.io.gfile.exists(self._metadata_path))
+  # def testPipeline(self):
+  #   BeamDagRunner().run(
+  #       ranking_pipeline._create_pipeline(
+  #           pipeline_name=self._pipeline_name,
+  #           pipeline_root=self._tfx_root,
+  #           data_root=self._data_root,
+  #           module_file=self._module_file,
+  #           serving_model_dir=self._serving_model_dir,
+  #           metadata_path=self._metadata_path,
+  #           beam_pipeline_args=['--direct_num_workers=1']))
+  #   self.assertTrue(tf.io.gfile.exists(self._serving_model_dir))
+  #   self.assertTrue(tf.io.gfile.exists(self._metadata_path))
 
-    metadata_config = metadata.sqlite_metadata_connection_config(
-        self._metadata_path)
-    with metadata.Metadata(metadata_config) as m:
-      artifact_count = len(m.store.get_artifacts())
-      execution_count = len(m.store.get_executions())
-      self.assertGreaterEqual(artifact_count, execution_count)
-      self.assertEqual(9, execution_count)
+  #   metadata_config = metadata.sqlite_metadata_connection_config(
+  #       self._metadata_path)
+  #   with metadata.Metadata(metadata_config) as m:
+  #     artifact_count = len(m.store.get_artifacts())
+  #     execution_count = len(m.store.get_executions())
+  #     self.assertGreaterEqual(artifact_count, execution_count)
+  #     self.assertEqual(9, execution_count)

diff --git a/tfx/examples/ranking/ranking_utils.py b/tfx/examples/ranking/ranking_utils.py
index 7312bed837..9e953cc688 100644
--- a/tfx/examples/ranking/ranking_utils.py
+++ b/tfx/examples/ranking/ranking_utils.py
@@ -11,247 +11,77 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""Module file."""
+"""Tests for tfx.examples.ranking.ranking_pipeline."""
+import os
+import unittest
 import tensorflow as tf
-import tensorflow_ranking as tfr
-import tensorflow_transform as tft
-from tfx.examples.ranking import features
-from tfx.examples.ranking import struct2tensor_parsing_utils
-from tfx_bsl.public import tfxio
-
-
-def make_decoder():
-  """Creates a data decoder that decodes ELWC records to tensors.
-
-  A DataView (see "TfGraphDataViewProvider" component in the pipeline)
-  will refer to this decoder, and any component that consumes the data
-  with the DataView applied will use this decoder.
-
-  Returns:
-    An ELWC decoder.
-  """
-  context_features, example_features, label_feature = features.get_features()
-
-  return struct2tensor_parsing_utils.ELWCDecoder(
-      name='ELWCDecoder',
-      context_features=context_features,
-      example_features=example_features,
-      size_feature_name=features.LIST_SIZE_FEATURE_NAME,
-      label_feature=label_feature)
-
-
-def preprocessing_fn(inputs):
-  """Transform preprocessing_fn."""
-
-  # Generate a shared vocabulary.
-  _ = tft.vocabulary(
-      tf.concat([
-          inputs[features.QUERY_TOKENS].flat_values,
-          inputs[features.DOCUMENT_TOKENS].flat_values
-      ],
-                axis=0),
-      vocab_filename='shared_vocab')
-  return inputs
-
-
-def run_fn(trainer_fn_args):
-  """TFX trainer entry point."""
-
-  tf_transform_output = tft.TFTransformOutput(trainer_fn_args.transform_output)
-  hparams = dict(
-      batch_size=32,
-      embedding_dimension=20,
-      learning_rate=0.05,
-      dropout_rate=0.8,
-      hidden_layer_dims=[64, 32, 16],
-      loss='approx_ndcg_loss',
-      use_batch_norm=True,
-      batch_norm_moment=0.99
-  )
-
-  train_dataset = _input_fn(trainer_fn_args.train_files,
-                            trainer_fn_args.data_accessor,
-                            hparams['batch_size'])
-  eval_dataset = _input_fn(trainer_fn_args.eval_files,
-                           trainer_fn_args.data_accessor,
-                           hparams['batch_size'])
-
-  model = _create_ranking_model(tf_transform_output, hparams)
-  model.summary()
-  log_dir = trainer_fn_args.model_run_dir
-  # Write logs to path
-  tensorboard_callback = tf.keras.callbacks.TensorBoard(
-      log_dir=log_dir, update_freq='epoch')
-  model.fit(
-      train_dataset,
-      steps_per_epoch=trainer_fn_args.train_steps,
-      validation_data=eval_dataset,
-      validation_steps=trainer_fn_args.eval_steps,
-      callbacks=[tensorboard_callback])
-
-  # TODO(zhuo): Add support for Regress signature.
-  @tf.function(input_signature=[tf.TensorSpec([None], tf.string)],
-               autograph=False)
-  def predict_serving_fn(serialized_elwc_records):
-    decode_fn = trainer_fn_args.data_accessor.data_view_decode_fn
-    decoded = decode_fn(serialized_elwc_records)
-    decoded.pop(features.LABEL)
-    return {tf.saved_model.PREDICT_OUTPUTS: model(decoded)}
-
-  model.save(
-      trainer_fn_args.serving_model_dir,
-      save_format='tf',
-      signatures={
-          'serving_default':
-              predict_serving_fn.get_concrete_function(),
-      })
-
-
-def _input_fn(file_patterns,
-              data_accessor,
-              batch_size) -> tf.data.Dataset:
-  """Returns a dataset of decoded tensors."""
-
-  def prepare_label(parsed_ragged_tensors):
-    label = parsed_ragged_tensors.pop(features.LABEL)
-    # Convert labels to a dense tensor.
-    label = label.to_tensor(default_value=features.LABEL_PADDING_VALUE)
-    return parsed_ragged_tensors, label
-
-  # NOTE: this dataset already contains RaggedTensors from the Decoder.
-  dataset = data_accessor.tf_dataset_factory(
-      file_patterns,
-      tfxio.TensorFlowDatasetOptions(batch_size=batch_size),
-      schema=None)
-  return dataset.map(prepare_label).repeat()
-
-
-def _preprocess_keras_inputs(context_keras_inputs, example_keras_inputs,
-                             tf_transform_output, hparams):
-  """Preprocesses the inputs, including vocab lookup and embedding."""
-  lookup_layer = tf.keras.layers.experimental.preprocessing.StringLookup(
-      max_tokens=(
-          tf_transform_output.vocabulary_size_by_name('shared_vocab') + 1),
-      vocabulary=tf_transform_output.vocabulary_file_by_name('shared_vocab'),
-      num_oov_indices=1,
-      oov_token='[UNK#]',
-      mask_token=None)
-  embedding_layer = tf.keras.layers.Embedding(
-      input_dim=(
-          tf_transform_output.vocabulary_size_by_name('shared_vocab') + 1),
-      output_dim=hparams['embedding_dimension'],
-      embeddings_initializer=None,
-      embeddings_constraint=None)
-  def embedding(input_tensor):
-    # TODO(b/158673891): Support weighted features.
-    embedded_tensor = embedding_layer(lookup_layer(input_tensor))
-    mean_embedding = tf.reduce_mean(embedded_tensor, axis=-2)
-    # mean_embedding could be a dense tensor (context feature) or a ragged
-    # tensor (example feature). If it's ragged, we densify it first.
-    if isinstance(mean_embedding.type_spec, tf.RaggedTensorSpec):
-      return struct2tensor_parsing_utils.make_ragged_densify_layer()(
-          mean_embedding)
-    return mean_embedding
-  preprocessed_context_features, preprocessed_example_features = {}, {}
-  context_features, example_features, _ = features.get_features()
-  for feature in context_features:
-    preprocessed_context_features[feature.name] = embedding(
-        context_keras_inputs[feature.name])
-  for feature in example_features:
-    preprocessed_example_features[feature.name] = embedding(
-        example_keras_inputs[feature.name])
-  list_size = struct2tensor_parsing_utils.make_ragged_densify_layer()(
-      context_keras_inputs[features.LIST_SIZE_FEATURE_NAME])
-  list_size = tf.reshape(list_size, [-1])
-  mask = tf.sequence_mask(list_size)
-
-  return preprocessed_context_features, preprocessed_example_features, mask
-
-
-def _create_ranking_model(tf_transform_output, hparams) -> tf.keras.Model:
-  """Creates a Keras ranking model."""
-  context_feature_specs, example_feature_specs, _ = features.get_features()
-  context_keras_inputs, example_keras_inputs = (
-      struct2tensor_parsing_utils.create_keras_inputs(
-          context_feature_specs, example_feature_specs,
-          features.LIST_SIZE_FEATURE_NAME))
-  context_features, example_features, mask = _preprocess_keras_inputs(
-      context_keras_inputs, example_keras_inputs, tf_transform_output, hparams)
-
-  # Since argspec inspection is expensive, for keras layer,
-  # layer_obj._call_spec.arg_names is a property that uses cached argspec for
-  # call. We use this to determine whether the layer expects `inputs` as first
-  # argument.
-  # TODO(b/185176464): update tfr dependency to remove this branch.
-  flatten_list = tfr.keras.layers.FlattenList()
-
-  # TODO(kathywu): remove the except branch once changes to the call function
-  # args in the Keras Layer have been released.
-  try:
-    first_arg_name = flatten_list._call_spec.arg_names[0]  # pylint: disable=protected-access
-  except AttributeError:
-    first_arg_name = flatten_list._call_fn_args[0]  # pylint: disable=protected-access
-  if first_arg_name == 'inputs':
-    (flattened_context_features, flattened_example_features) = flatten_list(
-        inputs=(context_features, example_features, mask))
-  else:
-    (flattened_context_features,
-     flattened_example_features) = flatten_list(context_features,
-                                                example_features, mask)
-
-  # Concatenate flattened context and example features along `list_size` dim.
-  context_input = [
-      tf.keras.layers.Flatten()(flattened_context_features[name])
-      for name in sorted(flattened_context_features)
-  ]
-  example_input = [
-      tf.keras.layers.Flatten()(flattened_example_features[name])
-      for name in sorted(flattened_example_features)
-  ]
-  input_layer = tf.concat(context_input + example_input, 1)
-  dnn = tf.keras.Sequential()
-  if hparams['use_batch_norm']:
-    dnn.add(
-        tf.keras.layers.BatchNormalization(
-            momentum=hparams['batch_norm_moment']))
-  for layer_size in hparams['hidden_layer_dims']:
-    dnn.add(tf.keras.layers.Dense(units=layer_size))
-    if hparams['use_batch_norm']:
-      dnn.add(tf.keras.layers.BatchNormalization(
-          momentum=hparams['batch_norm_moment']))
-    dnn.add(tf.keras.layers.Activation(activation=tf.nn.relu))
-    dnn.add(tf.keras.layers.Dropout(rate=hparams['dropout_rate']))
-
-  dnn.add(tf.keras.layers.Dense(units=1))
-
-  # Since argspec inspection is expensive, for keras layer,
-  # layer_obj._call_spec.arg_names is a property that uses cached argspec for
-  # call. We use this to determine whether the layer expects `inputs` as first
-  # argument.
-  restore_list = tfr.keras.layers.RestoreList()
-
-  # TODO(kathywu): remove the except branch once changes to the call function
-  # args in the Keras Layer have been released.
-  try:
-    first_arg_name = flatten_list._call_spec.arg_names[0]  # pylint: disable=protected-access
-  except AttributeError:
-    first_arg_name = flatten_list._call_fn_args[0]  # pylint: disable=protected-access
-  if first_arg_name == 'inputs':
-    logits = restore_list(inputs=(dnn(input_layer), mask))
-  else:
-    logits = restore_list(dnn(input_layer), mask)
-
-  model = tf.keras.Model(
-      inputs={
-          **context_keras_inputs,
-          **example_keras_inputs
-      },
-      outputs=logits,
-      name='dnn_ranking_model')
-  model.compile(
-      optimizer=tf.keras.optimizers.Adagrad(
-          learning_rate=hparams['learning_rate']),
-      loss=tfr.keras.losses.get(hparams['loss']),
-      metrics=tfr.keras.metrics.default_keras_metrics())
-  return model
+# from tfx.orchestration import metadata
+# from tfx.orchestration.beam.beam_dag_runner import BeamDagRunner
+
+# This is due to TF Ranking not supporting TensorFlow 2.16; we should re-enable it when support is added.
+# from tfx.examples.ranking import ranking_pipeline
+
+
+try:
+  import struct2tensor  # pylint: disable=g-import-not-at-top
+except ImportError:
+  struct2tensor = None
+
+import pytest
+
+
+@pytest.mark.xfail(run=False, reason="PR 6889: this class contains tests that fail and need to be fixed. "
+"If all tests pass, please remove this mark.")
+@pytest.mark.e2e
+@unittest.skipIf(struct2tensor is None,
+                 'Cannot import required modules. This can happen when'
+                 ' struct2tensor is not available.')
+class RankingPipelineTest(tf.test.TestCase):
+
+  def setUp(self):
+    super().setUp()
+    self._test_dir = os.path.join(
+        os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
+        self._testMethodName)
+
+    self._pipeline_name = 'tf_ranking_test'
+    self._data_root = os.path.join(os.path.dirname(__file__),
+                                   'testdata', 'input')
+    self._tfx_root = os.path.join(self._test_dir, 'tfx')
+    self._module_file = os.path.join(os.path.dirname(__file__),
+                                     'ranking_utils.py')
+    self._serving_model_dir = os.path.join(self._test_dir, 'serving_model')
+    self._metadata_path = os.path.join(self._tfx_root, 'metadata',
+                                       self._pipeline_name, 'metadata.db')
+    print('TFX ROOT: ', self._tfx_root)
+
+  def assertExecutedOnce(self, component) -> None:
+    """Checks that the component is executed exactly once."""
+    component_path = os.path.join(self._pipeline_root, component)
+    self.assertTrue(tf.io.gfile.exists(component_path))
+    outputs = tf.io.gfile.listdir(component_path)
+    for output in outputs:
+      execution = tf.io.gfile.listdir(os.path.join(component_path, output))
+      self.assertEqual(1, len(execution))
+
+  # def testPipeline(self):
+  #   BeamDagRunner().run(
+  #       ranking_pipeline._create_pipeline(
+  #           pipeline_name=self._pipeline_name,
+  #           pipeline_root=self._tfx_root,
+  #           data_root=self._data_root,
+  #           module_file=self._module_file,
+  #           serving_model_dir=self._serving_model_dir,
+  #           metadata_path=self._metadata_path,
+  #           beam_pipeline_args=['--direct_num_workers=1']))
+  #   self.assertTrue(tf.io.gfile.exists(self._serving_model_dir))
+  #   self.assertTrue(tf.io.gfile.exists(self._metadata_path))
+
+  #   metadata_config = metadata.sqlite_metadata_connection_config(
+  #       self._metadata_path)
+  #   with metadata.Metadata(metadata_config) as m:
+  #     artifact_count = len(m.store.get_artifacts())
+  #     execution_count = len(m.store.get_executions())
+  #     self.assertGreaterEqual(artifact_count, execution_count)
+  #     self.assertEqual(9, execution_count)
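Note on the removed `_input_fn`: it densified the ragged `relevance` labels with `features.LABEL_PADDING_VALUE`, following the [N, D] padding convention described in the features.py comments above. A minimal sketch of that convention, assuming plain TensorFlow 2.x (the label values below are illustrative, not taken from the example data):

    import tensorflow as tf

    LABEL_PADDING_VALUE = -1  # Stand-in for the constant now commented out in features.py.

    # A batch of two ELWCs with 3 and 1 documents; labels arrive as a RaggedTensor.
    ragged_labels = tf.ragged.constant([[2, 0, 1], [1]], dtype=tf.int64)

    # Densify to shape [N, D]; slots for missing documents are marked invalid.
    dense_labels = ragged_labels.to_tensor(default_value=LABEL_PADDING_VALUE)
    # dense_labels == [[2, 0, 1], [1, -1, -1]]
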
diff --git a/tfx/examples/ranking/struct2tensor_parsing_utils_test.py b/tfx/examples/ranking/struct2tensor_parsing_utils_test.py
index f523ef1de7..2d2406012a 100644
--- a/tfx/examples/ranking/struct2tensor_parsing_utils_test.py
+++ b/tfx/examples/ranking/struct2tensor_parsing_utils_test.py
@@ -1,3 +1,4 @@
+
 # Copyright 2021 Google LLC. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -15,18 +16,18 @@
 
 
-import itertools
-import unittest
+# import unittest
 
 import tensorflow as tf
 from google.protobuf import text_format
 from tensorflow_serving.apis import input_pb2
 
-try:
-  from tfx.examples.ranking import struct2tensor_parsing_utils  # pylint: disable=g-import-not-at-top
-except ImportError:
-  struct2tensor_parsing_utils = None
+# try:
+#   # This is due to TF Ranking not supporting TensorFlow 2.16; we should re-enable it when support is added.
+#   from tfx.examples.ranking import struct2tensor_parsing_utils  # pylint: disable=g-import-not-at-top
+# except ImportError:
+#   struct2tensor_parsing_utils = None
 
 
 _ELWCS = [
@@ -171,82 +172,77 @@
 ]
 
 
-@unittest.skipIf(struct2tensor_parsing_utils is None,
-                 'Cannot import required modules. This can happen when'
-                 ' struct2tensor is not available.')
+# @unittest.skipIf(struct2tensor_parsing_utils is None,
+#                  'Cannot import required modules. This can happen when'
+#                  ' struct2tensor is not available.')
 class ELWCDecoderTest(tf.test.TestCase):
-
-  def testAllDTypes(self):
-    context_features = [
-        struct2tensor_parsing_utils.Feature('ctx.int', tf.int64),
-        struct2tensor_parsing_utils.Feature('ctx.float', tf.float32),
-        struct2tensor_parsing_utils.Feature('ctx.bytes', tf.string),
-    ]
-    example_features = [
-        struct2tensor_parsing_utils.Feature('example_int', tf.int64),
-        struct2tensor_parsing_utils.Feature('example_float', tf.float32),
-        struct2tensor_parsing_utils.Feature('example_bytes', tf.string),
-    ]
-    decoder = struct2tensor_parsing_utils.ELWCDecoder(
-        'test_decoder', context_features, example_features,
-        size_feature_name=None, label_feature=None)
-
-    result = decoder.decode_record(tf.convert_to_tensor(_ELWCS))
-    self.assertLen(result, len(context_features) + len(example_features))
-    for f in itertools.chain(context_features, example_features):
-      self.assertIn(f.name, result)
-      self.assertIsInstance(result[f.name], tf.RaggedTensor)
-
-    expected = {
-        'ctx.int': [[1, 2], [3]],
-        'ctx.float': [[1.0, 2.0], [3.0]],
-        'ctx.bytes': [[], [b'c']],
-        'example_int': [[[11], [22]], [[33]]],
-        'example_float': [[[11.0, 12.0], []], [[14.0, 15.0]]],
-        'example_bytes': [[[b'u', b'v'], [b'w']], [[b'x', b'y', b'z']]],
-    }
-    self.assertEqual({k: v.to_list() for k, v in result.items()}, expected)
-
-  def testDefaultFilling(self):
-    context_features = [
-        struct2tensor_parsing_utils.Feature('ctx.bytes', tf.string,
-                                            default_value=b'g', length=1),
-    ]
-    example_features = [
-        struct2tensor_parsing_utils.Feature('example_float', tf.float32,
-                                            default_value=-1.0, length=2),
-    ]
-    decoder = struct2tensor_parsing_utils.ELWCDecoder(
-        'test_decoder', context_features, example_features,
-        size_feature_name=None, label_feature=None)
-
-    result = decoder.decode_record(tf.convert_to_tensor(_ELWCS))
-    self.assertLen(result, len(context_features) + len(example_features))
-    for f in itertools.chain(context_features, example_features):
-      self.assertIn(f.name, result)
-      self.assertIsInstance(result[f.name], tf.RaggedTensor)
-
-    expected = {
-        'ctx.bytes': [[b'g'], [b'c']],
-        'example_float': [[[11.0, 12.0], [-1.0, -1.0]], [[14.0, 15.0]]],
-    }
-    self.assertEqual({k: v.to_list() for k, v in result.items()}, expected)
-
-  def testLabelFeature(self):
-    decoder = struct2tensor_parsing_utils.ELWCDecoder(
-        'test_decoder', [], [],
-        size_feature_name=None,
-        label_feature=struct2tensor_parsing_utils.Feature(
-            'example_int', tf.int64))
-    result = decoder.decode_record(tf.convert_to_tensor(_ELWCS))
-
-    self.assertLen(result, 1)
-    self.assertEqual(result['example_int'].to_list(), [[11.0, 22.0], [33.0]])
-
-  def testSizeFeature(self):
-    decoder = struct2tensor_parsing_utils.ELWCDecoder(
-        'test_decoder', [], [],
-        size_feature_name='example_list_size')
-    result = decoder.decode_record(tf.convert_to_tensor(_ELWCS))
-    self.assertLen(result, 1)
-    self.assertEqual(result['example_list_size'].to_list(), [[2], [1]])
+  pass  # Added to prevent a syntax error, since the class body is otherwise empty.
+
+  # def testAllDTypes(self):
+  #   context_features = [
+  #       struct2tensor_parsing_utils.Feature('ctx.int', tf.int64),
+  #       struct2tensor_parsing_utils.Feature('ctx.float', tf.float32),
+  #       struct2tensor_parsing_utils.Feature('ctx.bytes', tf.string),
+  #   ]
+  #   example_features = [
+  #       struct2tensor_parsing_utils.Feature('example_int', tf.int64),
+  #       struct2tensor_parsing_utils.Feature('example_float', tf.float32),
+  #       struct2tensor_parsing_utils.Feature('example_bytes', tf.string),
+  #   ]
+  #   decoder = struct2tensor_parsing_utils.ELWCDecoder(
+  #       'test_decoder', context_features, example_features,
+  #       size_feature_name=None, label_feature=None)
+
+  #   result = decoder.decode_record(tf.convert_to_tensor(_ELWCS))
+  #   self.assertLen(result, len(context_features) + len(example_features))
+  #   for f in itertools.chain(context_features, example_features):
+  #     self.assertIn(f.name, result)
+  #     self.assertIsInstance(result[f.name], tf.RaggedTensor)
+
+  #   expected = {
+  #       'ctx.int': [[1, 2], [3]],
+  #       'ctx.float': [[1.0, 2.0], [3.0]],
+  #       'ctx.bytes': [[], [b'c']],
+  #       'example_int': [[[11], [22]], [[33]]],
+  #       'example_float': [[[11.0, 12.0], []], [[14.0, 15.0]]],
+  #       'example_bytes': [[[b'u', b'v'], [b'w']], [[b'x', b'y', b'z']]],
+  #   }
+  #   self.assertEqual({k: v.to_list() for k, v in result.items()}, expected)
+
+  # def testDefaultFilling(self):
+  #   context_features = [
+  #       struct2tensor_parsing_utils.Feature('ctx.bytes', tf.string,
+  #                                           default_value=b'g', length=1),
+  #   ]
+  #   example_features = [
+  #       struct2tensor_parsing_utils.Feature('example_float', tf.float32,
+  #                                           default_value=-1.0, length=2),
+  #   ]
+  #   decoder = struct2tensor_parsing_utils.ELWCDecoder(
+  #       'test_decoder', context_features, example_features,
+  #       size_feature_name=None, label_feature=None)
+  #   result = decoder.decode_record(tf.convert_to_tensor(_ELWCS))
+  #   self.assertLen(result, len(context_features) + len(example_features))
+  #   for f in itertools.chain(context_features, example_features):
+  #     self.assertIn(f.name, result)
+  #     self.assertIsInstance(result[f.name], tf.RaggedTensor)
+  #   expected = {
+  #       'ctx.bytes': [[b'g'], [b'c']],
+  #       'example_float': [[[11.0, 12.0], [-1.0, -1.0]], [[14.0, 15.0]]],
+  #   }
+  #   self.assertEqual({k: v.to_list() for k, v in result.items()}, expected)
+
+  # def testLabelFeature(self):
+  #   decoder = struct2tensor_parsing_utils.ELWCDecoder(
+  #       'test_decoder', [], [],
+  #       size_feature_name=None,
+  #       label_feature=struct2tensor_parsing_utils.Feature(
+  #           'example_int', tf.int64))
+  #   result = decoder.decode_record(tf.convert_to_tensor(_ELWCS))
+
+  #   self.assertLen(result, 1)
+  #   self.assertEqual(result['example_int'].to_list(), [[11.0, 22.0], [33.0]])
+
+  # def testSizeFeature(self):
+  #   decoder = struct2tensor_parsing_utils.ELWCDecoder(
+  #       'test_decoder', [], [],
+  #       size_feature_name='example_list_size')
+  #   result = decoder.decode_record(tf.convert_to_tensor(_ELWCS))
+  #   self.assertLen(result, 1)
+  #   self.assertEqual(result['example_list_size'].to_list(), [[2], [1]])
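Because every tf-ranking call site above is commented out rather than import-guarded, re-enabling the example means reverting each file by hand once TF Ranking supports TensorFlow 2.16. A lighter-weight pattern, sketched here under the assumption that importing tensorflow_ranking is what breaks on TF 2.16 (the class and test names are hypothetical, not part of this patch), is the same optional-import guard these tests already use for struct2tensor:

    import unittest

    try:
      import tensorflow_ranking as tfr  # pylint: disable=g-import-not-at-top
    except ImportError:  # e.g. no TF 2.16-compatible tf-ranking wheel.
      tfr = None


    @unittest.skipIf(tfr is None,
                     'tensorflow_ranking is not available; see PR #7725.')
    class RankingDependentTest(unittest.TestCase):  # Hypothetical example class.

      def test_tfr_available(self):
        self.assertIsNotNone(tfr)

With this guard, the tests skip cleanly wherever tf-ranking cannot be installed, and re-enabling support is a dependency bump rather than a multi-file revert.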