Merge pull request #354 from vespa-engine/tgm/add-distributed-strategy
Include distributed training in the ranking framework
thigm85 authored Jun 24, 2022
2 parents c748bcb + 943dc5a commit e16b304
Showing 1 changed file with 21 additions and 18 deletions.
vespa/experimental/ranking.py: 21 additions, 18 deletions
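
The pattern the diff applies is TensorFlow's standard single-machine data parallelism: create one tf.distribute.MirroredStrategy, build every variable-owning object under strategy.scope(), and pass the same strategy to KerasTuner. A minimal standalone sketch of that pattern (the build_model hypermodel and its hyperparameter are illustrative, not taken from this diff):

import tensorflow as tf
import keras_tuner as kt

# Create the strategy once. MirroredStrategy runs synchronous
# data-parallel training across all GPUs visible to the process,
# falling back to CPU when no GPU is available.
strategy = tf.distribute.MirroredStrategy()


def build_model(hp):
    # Illustrative hypermodel; the ranking framework uses
    # LinearHyperModel and LassoHyperModel instead.
    model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
    model.compile(
        optimizer=tf.keras.optimizers.Adam(
            hp.Float("learning_rate", 1e-4, 1e-1, sampling="log")
        ),
        loss="mse",
    )
    return model


# Variables (layers, optimizers, metrics) must be created inside
# strategy.scope() so each replica holds a mirrored copy.
with strategy.scope():
    model = build_model(kt.HyperParameters())

# KerasTuner accepts the same strategy, so each trial's model is
# also built and trained under the distribution scope.
tuner = kt.RandomSearch(
    build_model,
    objective="val_loss",
    max_trials=3,
    distribution_strategy=strategy,
)
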
@@ -445,6 +445,7 @@ def __init__(
 
         self.query_id_name = "query_id"
         self.target_name = "label"
+        self.distribute_strategy = tf.distribute.MirroredStrategy()
 
     def listwise_tf_dataset_from_df(
         self, df, feature_names, shuffle_buffer_size, batch_size
@@ -555,6 +556,7 @@ def tune_model(self, model, train_ds, dev_ds):
             objective=kt.Objective("val_ndcg_stateless", direction="max"),
             directory=self.folder_dir,
             project_name="keras_tuner",
+            distribution_strategy=self.distribute_strategy,
             overwrite=True,
             max_trials=self.tuner_max_trials,
             executions_per_trial=self.tuner_executions_per_trial,
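
For context, distribution_strategy is the standard KerasTuner Tuner argument for running trials under a tf.distribute strategy; reusing the instance created in __init__ keeps hyperparameter search and the final fit on the same strategy.
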
@@ -585,13 +587,13 @@ def fit_linear_model(
             df_or_file=train_data, feature_names=feature_names
         )
         dev_ds = self.create_dataset(df_or_file=dev_data, feature_names=feature_names)
-
-        linear_hyper_model = LinearHyperModel(
-            number_documents_per_query=self.number_documents_per_query,
-            number_features=number_features,
-            top_n=self.top_n,
-            learning_rate_range=self.learning_rate_range,
-        )
+        with self.distribute_strategy.scope():
+            linear_hyper_model = LinearHyperModel(
+                number_documents_per_query=self.number_documents_per_query,
+                number_features=number_features,
+                top_n=self.top_n,
+                learning_rate_range=self.learning_rate_range,
+            )
         if not hyperparameters:
             best_hps = self.tune_model(
                 model=linear_hyper_model, train_ds=train_ds, dev_ds=dev_ds
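
Only variable creation needs the scope here: tune_model and the later fit calls stay outside the with block, since a Keras model remembers the strategy it was built under and distributes training accordingly.
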
@@ -627,17 +629,18 @@ def fit_lasso_linear_model(
             df_or_file=train_data, feature_names=feature_names
         )
         dev_ds = self.create_dataset(df_or_file=dev_data, feature_names=feature_names)
-        trained_normalization_layer = self.create_and_train_normalization_layer(
-            train_ds=train_ds
-        )
-        lasso_hyper_model = LassoHyperModel(
-            number_documents_per_query=self.number_documents_per_query,
-            number_features=number_features,
-            trained_normalization_layer=trained_normalization_layer,
-            top_n=self.top_n,
-            l1_penalty_range=self.l1_penalty_range,
-            learning_rate_range=self.learning_rate_range,
-        )
+        with self.distribute_strategy.scope():
+            trained_normalization_layer = self.create_and_train_normalization_layer(
+                train_ds=train_ds
+            )
+            lasso_hyper_model = LassoHyperModel(
+                number_documents_per_query=self.number_documents_per_query,
+                number_features=number_features,
+                trained_normalization_layer=trained_normalization_layer,
+                top_n=self.top_n,
+                l1_penalty_range=self.l1_penalty_range,
+                learning_rate_range=self.learning_rate_range,
+            )
         if not hyperparameters:
             best_hps = self.tune_model(
                 model=lasso_hyper_model, train_ds=train_ds, dev_ds=dev_ds
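
The scope is wider here than in fit_linear_model because the normalization layer owns tf.Variable state of its own; TensorFlow requires every variable consumed by a model built under a strategy scope to be created under that same scope, so create_and_train_normalization_layer moves inside the with block together with the hyper model.
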
