Combined loss implementation #20

Open
AmenRa opened this issue Nov 3, 2021 · 7 comments

Comments

@AmenRa

AmenRa commented Nov 3, 2021

Hi, I am trying to understand how you combined the hard negative loss Ls with the in-batch random negative loss Lr. In the paper, the in-batch random negative loss is scaled by a hyperparameter alpha, but the value of alpha used in the experiments is not mentioned.

Following star/train.py I found the RobertaDot_InBatch model, whose forward function calls the inbatch_train method.

At the end of the inbatch_train method (line 182), I found:

return ((first_loss + second_loss) / (first_num + second_num),)

which is different from the combined loss proposed in the paper (Eq. 13).
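To make the difference concrete, here is a minimal sketch of the two aggregations being compared. It assumes, as described above, that Eq. 13 has the form L_s + alpha * L_r with each term being a mean over its own pairs; the function names and the numbers are hypothetical, and the losses are plain floats rather than tensors.

```python
def global_average(first_loss, first_num, second_loss, second_num):
    # Same arithmetic as the return expression at line 182 (without the tuple):
    # pool both loss sums and divide by the total number of pairs.
    return (first_loss + second_loss) / (first_num + second_num)


def sum_of_means(hard_loss, hard_num, random_loss, random_num, alpha=1.0):
    # Hypothetical Eq. 13-style form, ASSUMING L = L_s + alpha * L_r where
    # each term is the mean loss over its own pairs.
    return hard_loss / hard_num + alpha * (random_loss / random_num)


# Made-up sums and counts: random part 12.0 over 6 pairs (mean 2.0),
# hard part 3.0 over 3 pairs (mean 1.0).
print(global_average(12.0, 6, 3.0, 3))     # 15.0 / 9
print(sum_of_means(3.0, 3, 12.0, 6, 1.0))  # 1.0 + 1.0 * 2.0
```

With these numbers the two forms give 15/9 (about 1.67) and 3.0, so a global average and a sum of per-part means with alpha = 1 generally do not coincide.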

Am I missing something?

Also, for each query in the batch, did you consider all the possible in-batch random negatives or just one?

Thanks in advance!

@mohammed-elkomy

mohammed-elkomy commented Nov 5, 2021

Hi @AmenRa, I'm also interested in this repo, so let me join the discussion.

Regarding your first question, I agree with you: it seems alpha has been implicitly set to one (not sure why).

For the second question, I believe the pairwise loss is computed over all in-batch pairs: in line 141 the positive scores are repeated to match the in-batch negatives (generating all possible pairs), and then the RankNet loss is computed in line 171.
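For illustration, the in-batch part can be sketched in plain Python (no PyTorch). This is not the repository's code: it assumes the pairwise loss is the RankNet/logistic form log(1 + exp(s_neg - s_pos)), that doc_embs[i] is the positive document of query i, and that only the diagonal (the query's own positive) is excluded. The function names are hypothetical.

```python
import math


def ranknet_pair_loss(pos_score, neg_score):
    # Pairwise logistic (RankNet) loss log(1 + exp(neg - pos)),
    # computed as a numerically stable softplus of (neg - pos).
    x = neg_score - pos_score
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def inbatch_random_negative_loss(query_embs, doc_embs):
    """Return (loss_sum, pair_count), analogous to (first_loss, first_num)."""
    n = len(query_embs)
    # scores[i][j] = <q_i, d_j>, the equivalent of query_embs @ doc_embs.T
    scores = [[sum(a * b for a, b in zip(q, d)) for d in doc_embs] for q in query_embs]
    loss_sum, pair_count = 0.0, 0
    for i in range(n):
        pos = scores[i][i]  # positive score, reused ("repeated") for every pair of query i
        for j in range(n):
            if j == i:
                continue  # simplified mask: skip only the query's own positive
            loss_sum += ranknet_pair_loss(pos, scores[i][j])
            pair_count += 1
    return loss_sum, pair_count


queries = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
docs = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
first_loss, first_num = inbatch_random_negative_loss(queries, docs)
print(first_num)  # 3 queries x 2 other documents = 6
```

Every query is paired with all other in-batch documents, which matches the answer that all in-batch random negatives are used rather than just one.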

Hope this helps.

@jingtaozhan
Owner

@mohammed-elkomy Thanks for joining the discussion :). Your understanding is exactly correct.

For the first question, we found that this simple strategy, setting alpha to one, performed impressively well. We did not thoroughly study the influence of different alpha values, so another value might work better.

For the second question, we consider all in-batch samples as approximately random negatives. The positive scores are repeated to compute the pairwise loss.

Thanks @AmenRa and @mohammed-elkomy for your interest in our work!

Best,
Jingtao

@AmenRa
Author

AmenRa commented Nov 8, 2021

@mohammed-elkomy and @jingtaozhan, thank you both for replying!

I have further questions.

What are first_num and second_num?
Are they the number of hard negative losses (one loss for each query-hard negative pair) and the number of random negative losses, respectively?

Assuming I'm correct:

  1. Isn't first_num equal to the batch size?
  2. Usually, we should have many more random negatives than hard negatives, so second_num should be much larger than first_num. If so, I don't understand why in line 182 first_loss and second_loss are summed and then divided by first_num + second_num. Shouldn't the means of first_loss and second_loss be summed instead?

Thanks again.

@mohammed-elkomy

mohammed-elkomy commented Nov 11, 2021

Hi @AmenRa,

I'm not sure what you mean by batch size, but I think you mean the number of queries (or positive documents) per training step (dataset 216, dataset 217). For each training step we sample:

  1. n queries
  2. n positive documents
  3. n * k hard negative documents (where k is the number of hard negatives per query, self.hard_num)

first_num represents all possible in-batch random-negative query-document pairs in a training step, dataset 241 (this number depends on the actual pairs in each training step and can't be )

On the other hand, second_num represents pairs of queries and hard negative documents (hard negatives are sampled based on args.topk; check prepare_hardneg.py), and also some in-batch random-negative query-document pairs, since a hard negative for query_i in the batch is most likely a random negative for another query query_j (I know my words are confusing 😅)
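Under the sampling above, the pair counts can be worked out with a small sketch. It assumes every query is scored against every other query's positive document and against every hard negative in the step, with no hard-negative pairs masked; masked_inbatch_pairs is a hypothetical stand-in for the batch-dependent masking mentioned for first_num.

```python
def pair_counts(n, k, masked_inbatch_pairs=0):
    """Hypothetical pair counts for one training step.

    n: queries (and positive documents) sampled in the step
    k: hard negatives per query (self.hard_num)
    masked_inbatch_pairs: in-batch pairs dropped because the document is also
        relevant to that query (assumed; depends on the actual batch)
    """
    # In-batch part: every query vs. every other query's positive document.
    first_num = n * (n - 1) - masked_inbatch_pairs
    # Hard-negative part: every query vs. all n * k hard negatives in the step.
    second_num = n * (n * k)
    # Of those, n * k pairs use the query's own hard negatives ...
    own_hard = n * k
    # ... and the rest use other queries' hard negatives (approximately random).
    others_hard = second_num - own_hard
    return first_num, second_num, own_hard, others_hard


print(pair_counts(4, 2))  # (12, 32, 8, 24)
```

Under these assumptions first_num is at most n(n - 1), not the batch size n, and second_num = n²k exceeds it whenever k >= 1; most of the second part (n(n - 1)k pairs) consists of other queries' hard negatives.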

Since first_num covers random negatives and second_num covers hard negatives plus random negatives, you can add the two losses (line 182): the pairs in second_num are also just negatives, only sampled based on model performance (called hard negatives).

Regards, (hope I'm correct 🤣🤣)

@jingtaozhan
Owner

Very sorry for the late reply. I had two busy weeks and forgot to reply @AmenRa @mohammed-elkomy

Thanks @mohammed-elkomy for the detailed explanation. Your words are very clear and exactly correct. This is exactly how the loss is computed :).

@AmenRa
Author

AmenRa commented Nov 20, 2021

Thanks @mohammed-elkomy for the explanation.

I probably swapped the random-negatives related code with the hard-negative one while reasoning about the implementation.

Following your explanation, I assume first_loss is the sum of the random negative losses and second_loss is the sum of the hard negative losses.

From my understanding, first_num != second_num. Does this mean that in line 182 we take a global average of the losses and not a sum of the averages of the two parts of the loss (first_loss and second_loss)?

Hope it's clear what I mean :D

@jingtaozhan
Owner

> second_loss is the sum of the hard negative losses.

Not exactly correct. Part of the second loss is indeed the hard negative loss, while the other part is approximately a random negative loss, because q_i's hard negative is a random negative for q_j (j != i). Note that line 165 uses matmul and line 167 repeats the positive scores, so there are also random negative pairs.
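A plain-Python sketch of the second term as described here: every query is scored against all hard negatives in the step (the matmul), with its positive score repeated for each pair, so only n * k of the n * n * k pairs use a query's own hard negatives. Same assumptions as the earlier sketch: RankNet/logistic pairwise loss, no hard-pair masking, hypothetical names, not the repository's code.

```python
import math


def ranknet_pair_loss(pos_score, neg_score):
    # Pairwise logistic (RankNet) loss, computed as a stable softplus(neg - pos).
    x = neg_score - pos_score
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def hard_negative_part(query_embs, pos_doc_embs, hard_neg_embs):
    """Return (loss_sum, pair_count, own_pairs) for the second loss term.

    hard_neg_embs[i] holds the k hard negatives sampled for query i.
    """
    # Flatten all hard negatives in the step, remembering which query sampled each.
    flat_negs = [(owner, emb) for owner, negs in enumerate(hard_neg_embs) for emb in negs]
    loss_sum, pair_count, own_pairs = 0.0, 0, 0
    for i, q in enumerate(query_embs):
        pos = dot(q, pos_doc_embs[i])  # positive score, repeated for every pair
        for owner, neg in flat_negs:
            loss_sum += ranknet_pair_loss(pos, dot(q, neg))
            pair_count += 1
            if owner == i:
                own_pairs += 1  # a true hard negative for query i
            # otherwise it is q_j's hard negative, approximately random for q_i
    return loss_sum, pair_count, own_pairs


queries = [[1.0, 0.0], [0.0, 1.0]]
positives = [[1.0, 0.0], [0.0, 1.0]]
hard_negs = [[[0.5, 0.5], [0.9, 0.1]], [[0.2, 0.8], [0.1, 0.9]]]  # k = 2 per query
second_loss, second_num, own = hard_negative_part(queries, positives, hard_negs)
print(second_num, own)  # 8 pairs in total, 4 with the query's own hard negatives
```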

> From my understanding, first_num != second_num. Does this mean that in line 182 we take a global average of the losses and not a sum of the averages of the two parts of the loss (first_loss and second_loss)?

Yes, it is a global average of all random and hard negative pairs.
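The global average can be rewritten in terms of the two per-part means, which makes the implicit weighting explicit. Here S_1, n_1, S_2, n_2 stand for first_loss, first_num, second_loss, second_num, with \bar{L}_1 = S_1/n_1 and \bar{L}_2 = S_2/n_2; this is only algebra on the expression at line 182, not a restatement of Eq. 13.

$$
\frac{S_1 + S_2}{n_1 + n_2}
= \frac{n_1}{n_1 + n_2}\,\bar{L}_1 + \frac{n_2}{n_1 + n_2}\,\bar{L}_2
= \frac{n_2}{n_1 + n_2}\left(\bar{L}_2 + \frac{n_1}{n_2}\,\bar{L}_1\right)
$$

Each part's mean is weighted by its share of the pairs, so relative to the second part the first part carries an effective weight of n_1/n_2, which depends on the batch composition rather than being a fixed hyperparameter.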

3 participants