Skip to content

Commit

Permalink
Update template models to stop using deprecated Keras APIs (#7723)
Browse files (browse the repository at this point in the history)
  • Loading branch information
nikelite authored Dec 2, 2024
1 parent 34c7147 commit f42957d
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 84 deletions.
139 changes: 57 additions & 82 deletions tfx/experimental/templates/taxi/models/keras_model/model.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -106,98 +106,73 @@ def _build_keras_model(hidden_units, learning_rate):
Returns:
A keras Model.
"""
real_valued_columns = [
tf.feature_column.numeric_column(key, shape=())
for key in features.transformed_names(features.DENSE_FLOAT_FEATURE_KEYS)
]
categorical_columns = [
tf.feature_column.categorical_column_with_identity( # pylint: disable=g-complex-comprehension
key,
num_buckets=features.VOCAB_SIZE + features.OOV_SIZE,
default_value=0)
for key in features.transformed_names(features.VOCAB_FEATURE_KEYS)
]
categorical_columns += [
tf.feature_column.categorical_column_with_identity( # pylint: disable=g-complex-comprehension
key,
num_buckets=num_buckets,
default_value=0) for key, num_buckets in zip(
features.transformed_names(features.BUCKET_FEATURE_KEYS),
features.BUCKET_FEATURE_BUCKET_COUNT)
]
categorical_columns += [
tf.feature_column.categorical_column_with_identity( # pylint: disable=g-complex-comprehension
key,
num_buckets=num_buckets,
default_value=0) for key, num_buckets in zip(
features.transformed_names(features.CATEGORICAL_FEATURE_KEYS),
features.CATEGORICAL_FEATURE_MAX_VALUES)
]
indicator_column = [
tf.feature_column.indicator_column(categorical_column)
for categorical_column in categorical_columns
]

model = _wide_and_deep_classifier(
# TODO(b/140320729) Replace with premade wide_and_deep keras model
wide_columns=indicator_column,
deep_columns=real_valued_columns,
dnn_hidden_units=hidden_units,
learning_rate=learning_rate)
return model


def _wide_and_deep_classifier(wide_columns, deep_columns, dnn_hidden_units,
learning_rate):
"""Build a simple keras wide and deep model.
Args:
wide_columns: Feature columns wrapped in indicator_column for wide (linear)
part of the model.
deep_columns: Feature columns for deep part of the model.
dnn_hidden_units: [int], the layer sizes of the hidden DNN.
learning_rate: [float], learning rate of the Adam optimizer.
Returns:
A Wide and Deep Keras model
"""
# Keras needs the feature definitions at compile time.
# TODO(b/139081439): Automate generation of input layers from FeatureColumn.
input_layers = {
colname: tf.keras.layers.Input(name=colname, shape=(), dtype=tf.float32)
for colname in features.transformed_names(
features.DENSE_FLOAT_FEATURE_KEYS)
deep_input = {
colname: tf.keras.layers.Input(name=colname, shape=(1,), dtype=tf.float32)
for colname in features.transformed_names(features.DENSE_FLOAT_FEATURE_KEYS)
}
input_layers.update({
colname: tf.keras.layers.Input(name=colname, shape=(), dtype='int32')
wide_vocab_input = {
colname: tf.keras.layers.Input(name=colname, shape=(1,), dtype='int32')
for colname in features.transformed_names(features.VOCAB_FEATURE_KEYS)
})
input_layers.update({
colname: tf.keras.layers.Input(name=colname, shape=(), dtype='int32')
}
wide_bucket_input = {
colname: tf.keras.layers.Input(name=colname, shape=(1,), dtype='int32')
for colname in features.transformed_names(features.BUCKET_FEATURE_KEYS)
})
input_layers.update({
colname: tf.keras.layers.Input(name=colname, shape=(), dtype='int32') for
colname in features.transformed_names(features.CATEGORICAL_FEATURE_KEYS)
})

# TODO(b/161952382): Replace with Keras premade models and
# Keras preprocessing layers.
deep = tf.keras.layers.DenseFeatures(deep_columns)(input_layers)
for numnodes in dnn_hidden_units:
}
wide_categorical_input = {
colname: tf.keras.layers.Input(name=colname, shape=(1,), dtype='int32')
for colname in features.transformed_names(features.CATEGORICAL_FEATURE_KEYS)
}
input_layers = {
**deep_input,
**wide_vocab_input,
**wide_bucket_input,
**wide_categorical_input,
}

deep = tf.keras.layers.concatenate(
[tf.keras.layers.Normalization()(layer) for layer in deep_input.values()]
)
for numnodes in (hidden_units or [100, 70, 50, 25]):
deep = tf.keras.layers.Dense(numnodes)(deep)
wide = tf.keras.layers.DenseFeatures(wide_columns)(input_layers)

output = tf.keras.layers.Dense(
1, activation='sigmoid')(
tf.keras.layers.concatenate([deep, wide]))
output = tf.squeeze(output, -1)
wide_layers = []
for key in features.transformed_names(features.VOCAB_FEATURE_KEYS):
wide_layers.append(
tf.keras.layers.CategoryEncoding(num_tokens=features.VOCAB_SIZE + features.OOV_SIZE)(
input_layers[key]
)
)
for key, num_tokens in zip(
features.transformed_names(features.BUCKET_FEATURE_KEYS),
features.BUCKET_FEATURE_BUCKET_COUNT,
):
wide_layers.append(
tf.keras.layers.CategoryEncoding(num_tokens=num_tokens)(
input_layers[key]
)
)
for key, num_tokens in zip(
features.transformed_names(features.CATEGORICAL_FEATURE_KEYS),
features.CATEGORICAL_FEATURE_MAX_VALUES,
):
wide_layers.append(
tf.keras.layers.CategoryEncoding(num_tokens=num_tokens)(
input_layers[key]
)
)
wide = tf.keras.layers.concatenate(wide_layers)

output = tf.keras.layers.Dense(1, activation='sigmoid')(
tf.keras.layers.concatenate([deep, wide])
)
output = tf.keras.layers.Reshape((1,))(output)

model = tf.keras.Model(input_layers, output)
model.compile(
loss='binary_crossentropy',
optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
metrics=[tf.keras.metrics.BinaryAccuracy()])
metrics=[tf.keras.metrics.BinaryAccuracy()],
)
model.summary(print_fn=logging.info)
return model

Expand Down
Original file line number | Diff line number | Diff line change
Expand Up @@ -22,7 +22,7 @@ class ModelTest(tf.test.TestCase):
def testBuildKerasModel(self):
built_model = model._build_keras_model(
hidden_units=[1, 1], learning_rate=0.1) # pylint: disable=protected-access
self.assertEqual(len(built_model.layers), 10)
self.assertEqual(len(built_model.layers), 13)

built_model = model._build_keras_model(hidden_units=[1], learning_rate=0.1) # pylint: disable=protected-access
self.assertEqual(len(built_model.layers), 9)
self.assertEqual(len(built_model.layers), 12)

0 comments on commit f42957d

Please sign in to comment.