From dcd8b892661dd92069991245943bf769a06fe9da Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Christoph=20M=2E=20L=C3=BCscher?= Date: Mon, 13 Jan 2025 16:17:09 +0100 Subject: [PATCH] doc --- i6_models/parts/losses/nce.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/i6_models/parts/losses/nce.py b/i6_models/parts/losses/nce.py index 43713acb..24ac0dc4 100644 --- a/i6_models/parts/losses/nce.py +++ b/i6_models/parts/losses/nce.py @@ -29,7 +29,9 @@ def __init__( :param num_samples: num of samples for the estimation, normally a value between 1000-4000. 2000 is a good starting point. - :param model: model on which the NCE loss is to be applied. + :param model: model on which the NCE loss is to be applied. The model requires a member called `output`, which + is expected to be a linear layer with bias and weight members. The member `output` represents the output + layer of the model allowing access to the parameters during loss computation. :param noise_distribution_sampler: for example `i6_model.samplers.LogUniformSampler`. :param log_norm_term: normalisation term for true/sampled logits. :param reduction: reduction method for binary cross entropy. @@ -46,8 +48,12 @@ def __init__( self._bce = nn.BCEWithLogitsLoss(reduction=reduction) def forward(self, data: torch.Tensor, target: torch.Tensor) -> torch.Tensor: - # input: [B x T, F] target: [B x T] + """ + :param data: the tensor for the input data, where batch and time are flattened resulting in [B x T, F] shape. + :param target: the tensor for the target data, where batch and time are flattened resulting in [B x T] shape. + :return: the NCE loss as a tensor, computed via binary cross entropy over true and sampled logits. + """ with torch.no_grad(): samples = self.noise_distribution_sampler.sample(self.num_samples).cuda()