Commit
Fix flipped sign in BCE loss and add hinge-loss option.
PiperOrigin-RevId: 703529468
sdenton4 authored and copybara-github committed Dec 6, 2024
1 parent 30d46d9 commit b616486
Showing 1 changed file with 25 additions and 2 deletions.
27 changes: 25 additions & 2 deletions hoplite/agile/classifier.py
@@ -43,12 +43,28 @@ def bce_loss(
   y_true = tf.cast(y_true, dtype=logits.dtype)
   log_p = tf.math.log_sigmoid(logits)
   log_not_p = tf.math.log_sigmoid(-logits)
-  raw_bce = -y_true * log_p + (1.0 - y_true) * log_not_p
+  # optax sigmoid_binary_cross_entropy:
+  # -labels * log_p - (1.0 - labels) * log_not_p
+  raw_bce = -y_true * log_p - (1.0 - y_true) * log_not_p
   is_labeled_mask = tf.cast(is_labeled_mask, dtype=logits.dtype)
   weights = (1.0 - is_labeled_mask) * weak_neg_weight + is_labeled_mask
   return tf.reduce_mean(raw_bce * weights)
 
 
+def hinge_loss(
+    y_true: tf.Tensor,
+    logits: tf.Tensor,
+    is_labeled_mask: tf.Tensor,
+    weak_neg_weight: float,
+) -> tf.Tensor:
+  """Weighted SVM hinge loss."""
+  # Convert multihot to +/- 1 labels.
+  y_true = 2 * y_true - 1
+  weights = (1.0 - is_labeled_mask) * weak_neg_weight + is_labeled_mask
+  raw_hinge_loss = tf.maximum(0, 1 - y_true * logits)
+  return tf.reduce_mean(raw_hinge_loss * weights)
+
+
 def infer(params, embeddings: np.ndarray):
   """Apply the model to embeddings."""
   return np.dot(embeddings, params['beta']) + params['beta_bias']
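
The sign flip is easy to miss: in the old expression the negative-label term entered with a plus, so minimizing the loss actually rewarded large positive logits on negative examples (and the loss could drop below zero). The corrected form matches the standard formulation used by optax's sigmoid_binary_cross_entropy. A standalone sanity check against TensorFlow's reference implementation, with made-up values that are not part of the commit:

import tensorflow as tf

logits = tf.constant([[-2.0, 0.5], [1.5, -0.3]])
y_true = tf.constant([[0.0, 1.0], [1.0, 0.0]])

log_p = tf.math.log_sigmoid(logits)
log_not_p = tf.math.log_sigmoid(-logits)
fixed = -y_true * log_p - (1.0 - y_true) * log_not_p    # corrected sign
flipped = -y_true * log_p + (1.0 - y_true) * log_not_p  # old, buggy sign

# The corrected form agrees with TF's reference BCE; the old one does not.
reference = tf.nn.sigmoid_cross_entropy_with_logits(labels=y_true, logits=logits)
tf.debugging.assert_near(fixed, reference)
print(float(tf.reduce_min(flipped)))  # < 0: the flipped loss can go negative

The new hinge_loss acts as a weighted SVM objective: a labeled example with margin y * logit >= 1 contributes nothing, and hinge terms from weakly labeled (assumed-negative) entries are scaled down by weak_neg_weight. A toy check, again with illustrative numbers:

y = tf.constant([[1.0, 0.0]])
z = tf.constant([[2.0, 0.5]])
mask = tf.constant([[1.0, 0.0]])        # second class only weakly labeled
pm = 2.0 * y - 1.0                      # multihot -> +/-1
raw = tf.maximum(0.0, 1.0 - pm * z)     # [[0.0, 1.5]]
w = (1.0 - mask) * 0.05 + mask          # weak_neg_weight = 0.05 (arbitrary)
print(tf.reduce_mean(raw * w).numpy())  # 0.0375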
@@ -105,19 +121,26 @@ def train_linear_classifier(
     learning_rate: float,
     weak_neg_weight: float,
     num_train_steps: int,
+    loss: str = 'bce',
 ):
   """Train a linear classifier."""
   embedding_dim = data_manager.db.embedding_dimension()
   num_classes = len(data_manager.get_target_labels())
   lin_model = get_linear_model(embedding_dim, num_classes)
   optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)
   lin_model.compile(optimizer=optimizer, loss='binary_crossentropy')
+  if loss == 'hinge':
+    loss_fn = hinge_loss
+  elif loss == 'bce':
+    loss_fn = bce_loss
+  else:
+    raise ValueError(f'Unknown loss: {loss}')
 
   @tf.function
   def train_step(y_true, embeddings, is_labeled_mask):
     with tf.GradientTape() as tape:
       logits = lin_model(embeddings, training=True)
-      loss = bce_loss(y_true, logits, is_labeled_mask, weak_neg_weight)
+      loss = loss_fn(y_true, logits, is_labeled_mask, weak_neg_weight)
       loss = tf.reduce_mean(loss)
     grads = tape.gradient(loss, lin_model.trainable_variables)
     optimizer.apply_gradients(zip(grads, lin_model.trainable_variables))
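
Selecting the new loss is a one-argument change at the call site. A hypothetical invocation, assuming the elided leading parameters of train_linear_classifier include the data_manager used in its body; all argument values here are illustrative:

# Hypothetical call; only the `loss` keyword is new in this commit.
train_linear_classifier(
    data_manager=data_manager,  # as referenced in the function body above
    learning_rate=1e-3,
    weak_neg_weight=0.05,
    num_train_steps=128,
    loss='hinge',               # new option; 'bce' remains the default
)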