Understanding the contrastive loss implementation #389

Open

jwutsetro opened this issue May 30, 2024 · 1 comment

@jwutsetro

Dear maintainers,

I am trying to understand your custom contrastive loss class. As I understand it, the positives are computed correctly by shifting the diagonal by +batch_size and -batch_size to form the numerator. But when the denominator is computed, the negative mask is defined as the inverse of torch.eye(). If I read this right, it means that only the self-similarity (which is always 1) is removed from the denominator, while the similarities between a patch and its augmented version are still included?
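
For reference, my understanding of the loss in question, written out as the standard NT-Xent from the SimCLR paper (z_i and z_j are the embeddings of the two views of one sample, tau the temperature, N the batch size):

    \ell_{i,j} = -\log \frac{\exp(\mathrm{sim}(z_i, z_j)/\tau)}{\sum_{k=1}^{2N} \mathbb{1}_{[k \neq i]} \exp(\mathrm{sim}(z_i, z_k)/\tau)}

The indicator in the denominator only excludes k = i, so the positive term exp(sim(z_i, z_j)/tau) does stay in the sum, which matches what I see in the code.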

I would personally implement it like this:

    import torch

    def create_negative_mask(self, batch_size):
        """Boolean mask that is True only for true negatives: it excludes
        the self-similarities and the positive (augmented) pairs."""
        N = 2 * batch_size
        mask = torch.ones((N, N), dtype=torch.bool)
        mask.fill_diagonal_(False)               # drop self-similarity
        for i in range(batch_size):
            mask[i, batch_size + i] = False      # drop positive pair (view 1 vs view 2)
            mask[batch_size + i, i] = False      # drop positive pair (view 2 vs view 1)
        return mask
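
For context, this is a minimal sketch of how I would wire that mask into the loss; nt_xent_negatives_only and its arguments are my own placeholder names, not the repository's:

    import torch
    import torch.nn.functional as F

    def nt_xent_negatives_only(z, temperature=0.5):
        # z: (2N, d) embeddings; rows i and i + N are the two augmented views
        two_n = z.shape[0]
        n = two_n // 2
        z = F.normalize(z, dim=1)
        sim = z @ z.T / temperature
        # positives sit on the +N / -N off-diagonals of the similarity matrix
        pos = torch.cat([torch.diag(sim, n), torch.diag(sim, -n)])
        # same mask as create_negative_mask above, built without the Python loop
        neg_mask = ~torch.eye(two_n, dtype=torch.bool)
        idx = torch.arange(n)
        neg_mask[idx, idx + n] = False
        neg_mask[idx + n, idx] = False
        # -log( exp(pos) / sum over negatives ) = logsumexp(negatives) - pos
        denom = torch.logsumexp(sim.masked_fill(~neg_mask, float('-inf')), dim=1)
        return (denom - pos).mean()

Note that with the positives excluded from the denominator this quantity is no longer a cross-entropy and can become negative, which is part of what I am unsure about.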

Will this not result in unstable training? I am asking because I don't seem to be able to get the contrastive loss to decrease. Attached is my total loss for batch sizes of 12, 24 and 48. The rotational loss and the reconstruction loss are 0.3 and 0.1 respectively for all models, so the total loss is dominated by the contrastive loss not going down.

[Screenshot 2024-05-30: total loss curves for batch sizes 12, 24 and 48]

Kindly,
Joris

@jwutsetro
Author

After digging into it a bit more, it seems that implementations of SimCLR do indeed follow a similar approach. But that still leaves me wondering why we would include the positives in the denominator. Additionally, if anyone has suggestions on how to further improve the contrastive loss optimisation, I would be very happy to hear them!
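
One way I convinced myself that keeping the positives in the denominator is deliberate: with only the diagonal masked, NT-Xent is exactly a (2N - 1)-way softmax cross-entropy in which each row's positive is the target class. A minimal sketch (simclr_nt_xent is my own naming, not the repository's):

    import torch
    import torch.nn.functional as F

    def simclr_nt_xent(z, temperature=0.5):
        # z: (2N, d) embeddings; rows i and i + N are the two augmented views
        two_n = z.shape[0]
        n = two_n // 2
        z = F.normalize(z, dim=1)
        sim = z @ z.T / temperature
        sim.fill_diagonal_(float('-inf'))            # remove self-similarity only
        targets = (torch.arange(two_n) + n) % two_n  # column index of each row's positive
        return F.cross_entropy(sim, targets)

Seen this way, each sample has to pick out its own augmented view among the 2N - 1 remaining entries, the loss is bounded below by zero, and the positive in the denominator just comes along with the softmax, so I presume it is intentional rather than a bug.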
