diff --git a/tfx/experimental/templates/taxi/models/keras_model/model.py b/tfx/experimental/templates/taxi/models/keras_model/model.py
index 24232320f5..9cad95aed8 100644
--- a/tfx/experimental/templates/taxi/models/keras_model/model.py
+++ b/tfx/experimental/templates/taxi/models/keras_model/model.py
@@ -106,98 +106,73 @@ def _build_keras_model(hidden_units, learning_rate):
   Returns:
     A keras Model.
   """
-  real_valued_columns = [
-      tf.feature_column.numeric_column(key, shape=())
-      for key in features.transformed_names(features.DENSE_FLOAT_FEATURE_KEYS)
-  ]
-  categorical_columns = [
-      tf.feature_column.categorical_column_with_identity(  # pylint: disable=g-complex-comprehension
-          key,
-          num_buckets=features.VOCAB_SIZE + features.OOV_SIZE,
-          default_value=0)
-      for key in features.transformed_names(features.VOCAB_FEATURE_KEYS)
-  ]
-  categorical_columns += [
-      tf.feature_column.categorical_column_with_identity(  # pylint: disable=g-complex-comprehension
-          key,
-          num_buckets=num_buckets,
-          default_value=0) for key, num_buckets in zip(
-              features.transformed_names(features.BUCKET_FEATURE_KEYS),
-              features.BUCKET_FEATURE_BUCKET_COUNT)
-  ]
-  categorical_columns += [
-      tf.feature_column.categorical_column_with_identity(  # pylint: disable=g-complex-comprehension
-          key,
-          num_buckets=num_buckets,
-          default_value=0) for key, num_buckets in zip(
-              features.transformed_names(features.CATEGORICAL_FEATURE_KEYS),
-              features.CATEGORICAL_FEATURE_MAX_VALUES)
-  ]
-  indicator_column = [
-      tf.feature_column.indicator_column(categorical_column)
-      for categorical_column in categorical_columns
-  ]
-
-  model = _wide_and_deep_classifier(
-      # TODO(b/140320729) Replace with premade wide_and_deep keras model
-      wide_columns=indicator_column,
-      deep_columns=real_valued_columns,
-      dnn_hidden_units=hidden_units,
-      learning_rate=learning_rate)
-  return model
-
-
-def _wide_and_deep_classifier(wide_columns, deep_columns, dnn_hidden_units,
-                              learning_rate):
-  """Build a simple keras wide and deep model.
-
-  Args:
-    wide_columns: Feature columns wrapped in indicator_column for wide (linear)
-      part of the model.
-    deep_columns: Feature columns for deep part of the model.
-    dnn_hidden_units: [int], the layer sizes of the hidden DNN.
-    learning_rate: [float], learning rate of the Adam optimizer.
-
-  Returns:
-    A Wide and Deep Keras model
-  """
-  # Keras needs the feature definitions at compile time.
-  # TODO(b/139081439): Automate generation of input layers from FeatureColumn.
-  input_layers = {
-      colname: tf.keras.layers.Input(name=colname, shape=(), dtype=tf.float32)
-      for colname in features.transformed_names(
-          features.DENSE_FLOAT_FEATURE_KEYS)
+  deep_input = {
+      colname: tf.keras.layers.Input(name=colname, shape=(1,), dtype=tf.float32)
+      for colname in features.transformed_names(features.DENSE_FLOAT_FEATURE_KEYS)
   }
-  input_layers.update({
-      colname: tf.keras.layers.Input(name=colname, shape=(), dtype='int32')
+  wide_vocab_input = {
+      colname: tf.keras.layers.Input(name=colname, shape=(1,), dtype='int32')
       for colname in features.transformed_names(features.VOCAB_FEATURE_KEYS)
-  })
-  input_layers.update({
-      colname: tf.keras.layers.Input(name=colname, shape=(), dtype='int32')
+  }
+  wide_bucket_input = {
+      colname: tf.keras.layers.Input(name=colname, shape=(1,), dtype='int32')
       for colname in features.transformed_names(features.BUCKET_FEATURE_KEYS)
-  })
-  input_layers.update({
-      colname: tf.keras.layers.Input(name=colname, shape=(), dtype='int32') for
-      colname in features.transformed_names(features.CATEGORICAL_FEATURE_KEYS)
-  })
-
-  # TODO(b/161952382): Replace with Keras premade models and
-  # Keras preprocessing layers.
-  deep = tf.keras.layers.DenseFeatures(deep_columns)(input_layers)
-  for numnodes in dnn_hidden_units:
+  }
+  wide_categorical_input = {
+      colname: tf.keras.layers.Input(name=colname, shape=(1,), dtype='int32')
+      for colname in features.transformed_names(features.CATEGORICAL_FEATURE_KEYS)
+  }
+  input_layers = {
+      **deep_input,
+      **wide_vocab_input,
+      **wide_bucket_input,
+      **wide_categorical_input,
+  }
+
+  deep = tf.keras.layers.concatenate(
+      [tf.keras.layers.Normalization()(layer) for layer in deep_input.values()]
+  )
+  for numnodes in (hidden_units or [100, 70, 50, 25]):
     deep = tf.keras.layers.Dense(numnodes)(deep)
-  wide = tf.keras.layers.DenseFeatures(wide_columns)(input_layers)
-  output = tf.keras.layers.Dense(
-      1, activation='sigmoid')(
-          tf.keras.layers.concatenate([deep, wide]))
-  output = tf.squeeze(output, -1)
+  wide_layers = []
+  for key in features.transformed_names(features.VOCAB_FEATURE_KEYS):
+    wide_layers.append(
+        tf.keras.layers.CategoryEncoding(num_tokens=features.VOCAB_SIZE + features.OOV_SIZE)(
+            input_layers[key]
+        )
+    )
+  for key, num_tokens in zip(
+      features.transformed_names(features.BUCKET_FEATURE_KEYS),
+      features.BUCKET_FEATURE_BUCKET_COUNT,
+  ):
+    wide_layers.append(
+        tf.keras.layers.CategoryEncoding(num_tokens=num_tokens)(
+            input_layers[key]
+        )
+    )
+  for key, num_tokens in zip(
+      features.transformed_names(features.CATEGORICAL_FEATURE_KEYS),
+      features.CATEGORICAL_FEATURE_MAX_VALUES,
+  ):
+    wide_layers.append(
+        tf.keras.layers.CategoryEncoding(num_tokens=num_tokens)(
+            input_layers[key]
+        )
+    )
+  wide = tf.keras.layers.concatenate(wide_layers)
+
+  output = tf.keras.layers.Dense(1, activation='sigmoid')(
+      tf.keras.layers.concatenate([deep, wide])
+  )
+  output = tf.keras.layers.Reshape((1,))(output)
 
   model = tf.keras.Model(input_layers, output)
   model.compile(
       loss='binary_crossentropy',
       optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
-      metrics=[tf.keras.metrics.BinaryAccuracy()])
+      metrics=[tf.keras.metrics.BinaryAccuracy()],
+  )
   model.summary(print_fn=logging.info)
 
   return model
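Reviewer note, not part of the patch: the new wide path replaces tf.feature_column.indicator_column with tf.keras.layers.CategoryEncoding applied to the (batch, 1) int32 inputs. A minimal sketch of what that layer produces, assuming TF 2.x; the ids and num_tokens below are made-up illustration values, not taken from the template:

    import tensorflow as tf

    # Integer category ids shaped (batch, 1), matching the Input(shape=(1,), dtype='int32') layers.
    ids = tf.constant([[0], [2], [1]])

    # Default output_mode='multi_hot'; with a single id per example each row is a
    # one-hot vector of width num_tokens, analogous to what indicator_column produced.
    onehot = tf.keras.layers.CategoryEncoding(num_tokens=3)(ids)
    print(onehot.numpy())
    # [[1. 0. 0.]
    #  [0. 0. 1.]
    #  [0. 1. 0.]]
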
diff --git a/tfx/experimental/templates/taxi/models/keras_model/model_test.py b/tfx/experimental/templates/taxi/models/keras_model/model_test.py
index 7dd6110a6b..a12a6e3c32 100644
--- a/tfx/experimental/templates/taxi/models/keras_model/model_test.py
+++ b/tfx/experimental/templates/taxi/models/keras_model/model_test.py
@@ -22,7 +22,7 @@ class ModelTest(tf.test.TestCase):
   def testBuildKerasModel(self):
     built_model = model._build_keras_model(
         hidden_units=[1, 1], learning_rate=0.1)  # pylint: disable=protected-access
-    self.assertEqual(len(built_model.layers), 10)
+    self.assertEqual(len(built_model.layers), 13)
 
     built_model = model._build_keras_model(hidden_units=[1], learning_rate=0.1)  # pylint: disable=protected-access
-    self.assertEqual(len(built_model.layers), 9)
+    self.assertEqual(len(built_model.layers), 12)
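Note on the test update, not part of the patch: the expected layer counts move from 10/9 to 13/12 to match the rebuilt graph, which now includes Normalization and CategoryEncoding layers and a Reshape in place of tf.squeeze. A quick local check, assuming a development environment where the taxi template is importable:

    # Rebuild the model as the test does and inspect the flattened layer count.
    from tfx.experimental.templates.taxi.models.keras_model import model

    m = model._build_keras_model(hidden_units=[1, 1], learning_rate=0.1)  # pylint: disable=protected-access
    print(len(m.layers))  # expected: 13 with this change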