
Commit

doc
christophmluscher committed Jan 13, 2025
1 parent 315074e commit dcd8b89
Showing 1 changed file with 8 additions and 2 deletions.
10 changes: 8 additions & 2 deletions i6_models/parts/losses/nce.py
@@ -29,7 +29,9 @@ def __init__(
         :param num_samples: number of samples for the estimation, normally a value between 1000 and 4000.
             2000 is a good starting point.
-        :param model: model on which the NCE loss is to be applied.
+        :param model: model on which the NCE loss is to be applied. The model requires a member called `output`, which
+            is expected to be a linear layer with bias and weight members. The member `output` represents the output
+            layer of the model, allowing access to its parameters during loss computation.
         :param noise_distribution_sampler: for example `i6_model.samplers.LogUniformSampler`.
         :param log_norm_term: normalisation term for true/sampled logits.
         :param reduction: reduction method for binary cross entropy.
@@ -46,8 +48,12 @@ def __init__(
         self._bce = nn.BCEWithLogitsLoss(reduction=reduction)
 
     def forward(self, data: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
-        # input: [B x T, F] target: [B x T]
+        """
+        :param data: the tensor for the input data, where batch and time are flattened, resulting in [B x T, F] shape.
+        :param target: the tensor for the target data, where batch and time are flattened, resulting in [B x T] shape.
+        :return: the NCE loss.
+        """
         with torch.no_grad():
             samples = self.noise_distribution_sampler.sample(self.num_samples).cuda()
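For orientation, here is a minimal construction sketch based on the parameters documented in the first hunk above. It is an assumption-laden illustration, not code from this commit: the class name `NCELoss`, the import paths, and the `LogUniformSampler` constructor arguments are guesses; only the file path i6_models/parts/losses/nce.py and the parameter names come from the diff.

    import torch.nn as nn

    # Assumed imports -- only the file path i6_models/parts/losses/nce.py is certain;
    # the exported class name and the sampler module path are hypothetical.
    from i6_models.parts.losses.nce import NCELoss
    from i6_models.samplers import LogUniformSampler

    VOCAB_SIZE = 10000  # hypothetical output vocabulary size
    HIDDEN_DIM = 512    # hypothetical feature dimension F

    class ToyModel(nn.Module):
        """Minimal model exposing the `output` member the loss expects: a linear layer with weight and bias."""

        def __init__(self):
            super().__init__()
            self.encoder = nn.Linear(HIDDEN_DIM, HIDDEN_DIM)
            self.output = nn.Linear(HIDDEN_DIM, VOCAB_SIZE, bias=True)  # accessed by the NCE loss internally

        def forward(self, features):
            return self.encoder(features)

    model = ToyModel()
    sampler = LogUniformSampler(VOCAB_SIZE)  # hypothetical constructor arguments
    nce_loss = NCELoss(
        num_samples=2000,   # docstring suggests 1000-4000, with 2000 as a good starting point
        model=model,
        noise_distribution_sampler=sampler,
        log_norm_term=0.0,  # hypothetical value; normalisation term for the true/sampled logits
        reduction="mean",
    )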

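And a sketch of the corresponding forward call, continuing the assumptions above. The flattening of batch and time into [B x T, ...] follows the new docstring; the concrete shapes are made up, and a GPU is assumed because forward moves the sampled noise to CUDA.

    import torch

    B, T = 8, 20  # hypothetical batch and time sizes
    device = torch.device("cuda")  # forward() calls .cuda() on the noise samples, so a GPU is assumed
    model = model.to(device)

    features = torch.randn(B, T, HIDDEN_DIM, device=device)       # [B, T, F]
    labels = torch.randint(0, VOCAB_SIZE, (B, T), device=device)  # [B, T]

    hidden = model(features)                    # [B, T, F]
    data = hidden.reshape(-1, hidden.size(-1))  # flatten to [B x T, F] as documented
    target = labels.reshape(-1)                 # flatten to [B x T]

    loss = nce_loss(data, target)               # scalar for reduction="mean"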
