Skip to content

Commit

Permalink
rm math module
Browse files Browse the repository at this point in the history
  • Loading branch information
christophmluscher committed Dec 19, 2024
1 parent acccbbe commit afd5481
Showing 1 changed file with 3 additions and 5 deletions.
8 changes: 3 additions & 5 deletions i6_models/losses/nce.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from torch import nn
from torch.nn import functional as F
from typing import Optional
import math


class NoiseContrastiveEstimationLossV1(nn.Module):
Expand Down Expand Up @@ -53,10 +52,9 @@ def forward(self, data: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
samples = self.noise_distribution_sampler.sample(self.num_samples).cuda()

# log-probabilities for the noise distribution k * q(w|h)
sampled_prob = math.log(self.num_samples) + self.noise_distribution_sampler.log_prob(
samples
) # [num_samples]
true_sample_prob = math.log(self.num_samples) + self.noise_distribution_sampler.log_prob(target) # [B x T]
ws = torch.log(torch.Tensor([self.num_samples]))
sampled_prob = ws + self.noise_distribution_sampler.log_prob(samples) # [num_samples]
true_sample_prob = ws + self.noise_distribution_sampler.log_prob(target) # [B x T]

all_classes = torch.cat((target, samples), 0) # [B x T + num_sampled]

Expand Down

0 comments on commit afd5481

Please sign in to comment.