From b616486ad1518f62d86edf48c5bc1f86f93c0320 Mon Sep 17 00:00:00 2001 From: Tom Denton Date: Fri, 6 Dec 2024 09:49:06 -0800 Subject: [PATCH] Fix flipped sign in BCE loss and add hinge-loss option. PiperOrigin-RevId: 703529468 --- hoplite/agile/classifier.py | 27 +++++++++++++++++++++++++-- 1 file changed, 25 insertions(+), 2 deletions(-) diff --git a/hoplite/agile/classifier.py b/hoplite/agile/classifier.py index f1f4edc..7cb9367 100644 --- a/hoplite/agile/classifier.py +++ b/hoplite/agile/classifier.py @@ -43,12 +43,28 @@ def bce_loss( y_true = tf.cast(y_true, dtype=logits.dtype) log_p = tf.math.log_sigmoid(logits) log_not_p = tf.math.log_sigmoid(-logits) - raw_bce = -y_true * log_p + (1.0 - y_true) * log_not_p + # optax sigmoid_binary_cross_entropy: + # -labels * log_p - (1.0 - labels) * log_not_p + raw_bce = -y_true * log_p - (1.0 - y_true) * log_not_p is_labeled_mask = tf.cast(is_labeled_mask, dtype=logits.dtype) weights = (1.0 - is_labeled_mask) * weak_neg_weight + is_labeled_mask return tf.reduce_mean(raw_bce * weights) +def hinge_loss( + y_true: tf.Tensor, + logits: tf.Tensor, + is_labeled_mask: tf.Tensor, + weak_neg_weight: float, +) -> tf.Tensor: + """Weighted SVM hinge loss.""" + # Convert multihot to +/- 1 labels. + y_true = 2 * y_true - 1 + weights = (1.0 - is_labeled_mask) * weak_neg_weight + is_labeled_mask + raw_hinge_loss = tf.maximum(0, 1 - y_true * logits) + return tf.reduce_mean(raw_hinge_loss * weights) + + def infer(params, embeddings: np.ndarray): """Apply the model to embeddings.""" return np.dot(embeddings, params['beta']) + params['beta_bias'] @@ -105,6 +121,7 @@ def train_linear_classifier( learning_rate: float, weak_neg_weight: float, num_train_steps: int, + loss: str = 'bce', ): """Train a linear classifier.""" embedding_dim = data_manager.db.embedding_dimension() @@ -112,12 +129,18 @@ def train_linear_classifier( lin_model = get_linear_model(embedding_dim, num_classes) optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate) lin_model.compile(optimizer=optimizer, loss='binary_crossentropy') + if loss == 'hinge': + loss_fn = hinge_loss + elif loss == 'bce': + loss_fn = bce_loss + else: + raise ValueError(f'Unknown loss: {loss}') @tf.function def train_step(y_true, embeddings, is_labeled_mask): with tf.GradientTape() as tape: logits = lin_model(embeddings, training=True) - loss = bce_loss(y_true, logits, is_labeled_mask, weak_neg_weight) + loss = loss_fn(y_true, logits, is_labeled_mask, weak_neg_weight) loss = tf.reduce_mean(loss) grads = tape.gradient(loss, lin_model.trainable_variables) optimizer.apply_gradients(zip(grads, lin_model.trainable_variables))