
Question about the gaussian_conditional model in Cheng2020 #316

Open
achel-x opened this issue Nov 4, 2024 · 10 comments

@achel-x

achel-x commented Nov 4, 2024

When running the Cheng2020 series, I noticed that Cheng2020Anchor inherits from JointAutoregressiveHierarchicalPriors.
However, the `gaussian_conditional` model in `JointAutoregressiveHierarchicalPriors` is a plain GaussianConditional, not a Gaussian mixture model.

I tried adding a line to Cheng2020Anchor:

```python
self.gaussian_conditional = GaussianMixtureConditional()
```

But it failed to run.

image

What does `weights` mean, and what should I pass to `GaussianMixtureConditional()`?

@chunbaobao

The discretized Gaussian mixture likelihood follows the equation in the paper:

$$p_{\hat{y}_i \mid \hat{z}}(\hat{y}_i \mid \hat{z}) = \sum_{k=1}^{K} \omega_i^{(k)} \left( \Phi\!\left(\frac{\hat{y}_i + \tfrac{1}{2} - \mu_i^{(k)}}{\sigma_i^{(k)}}\right) - \Phi\!\left(\frac{\hat{y}_i - \tfrac{1}{2} - \mu_i^{(k)}}{\sigma_i^{(k)}}\right) \right)$$

where $\Phi$ is the standard normal CDF. In this equation, $\omega$ refers to the `weights` in the code:

```python
def _likelihood(
    self, inputs: Tensor, scales: Tensor, means: Tensor, weights: Tensor
) -> Tensor:
    likelihood = torch.zeros_like(inputs)
    M = inputs.size(1)
    for k in range(self.K):
        likelihood += (
            super()._likelihood(
                inputs,
                scales[:, M * k : M * (k + 1)],
                means[:, M * k : M * (k + 1)],
            )
            * weights[:, M * k : M * (k + 1)]
        )
    return likelihood
```
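As a sanity check on the loop above: the mixture likelihood is just a convex combination of K per-component discretized-Gaussian likelihoods. A standalone scalar sketch (pure Python, with made-up weights/means/scales; the per-component likelihood assumes the usual $\Phi(y+\tfrac12) - \Phi(y-\tfrac12)$ form) behaves as expected:

```python
import math

def norm_cdf(x: float, mu: float, sigma: float) -> float:
    """CDF of a Gaussian N(mu, sigma^2)."""
    return 0.5 * (1.0 + math.erf((x - mu) / (sigma * math.sqrt(2.0))))

def discretized_likelihood(y: float, mu: float, sigma: float) -> float:
    """P(y) for a Gaussian convolved with U(-0.5, 0.5), at integer y."""
    return norm_cdf(y + 0.5, mu, sigma) - norm_cdf(y - 0.5, mu, sigma)

def gmm_likelihood(y: float, weights, mus, sigmas) -> float:
    """Discretized Gaussian mixture likelihood: weighted sum over K components."""
    return sum(
        w * discretized_likelihood(y, mu, s)
        for w, mu, s in zip(weights, mus, sigmas)
    )

# Hypothetical parameters for one latent element with K = 3 components.
w = [0.5, 0.3, 0.2]      # must sum to 1 (softmax output in the model)
mu = [0.0, 2.0, -1.0]
sigma = [1.0, 0.5, 2.0]

p = gmm_likelihood(0, w, mu, sigma)
```

Because the weights sum to one, the mixture likelihood over all integers still sums to one, so it remains a valid probability model for the entropy coder.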

Usually, the parameters of the latent codec distribution, including the weights, are the outputs of some neural networks.
You can slightly modify the network's output to obtain the weights.

```python
def forward(self, x):
    y = self.g_a(x)
    z = self.h_a(y)
    z_hat, z_likelihoods = self.entropy_bottleneck(z)
    params = self.h_s(z_hat)
    y_hat = self.gaussian_conditional.quantize(
        y, "noise" if self.training else "dequantize"
    )
    ctx_params = self.context_prediction(y_hat)
    gaussian_params = self.entropy_parameters(
        torch.cat((params, ctx_params), dim=1)
    )
    scales_hat, means_hat = gaussian_params.chunk(2, 1)
    _, y_likelihoods = self.gaussian_conditional(y, scales_hat, means=means_hat)
    x_hat = self.g_s(y_hat)
    return {
        "x_hat": x_hat,
        "likelihoods": {"y": y_likelihoods, "z": z_likelihoods},
    }
```

@achel-x
Author

achel-x commented Nov 19, 2024

Thanks for your kind instructions.

I tried a modification based on the issue here:

#289 (comment)

I made the same modification to use GaussianMixtureConditional, as shown below.

```python
class Cheng2020GMM(Cheng2020Anchor):
    def __init__(self, N=192, **kwargs):
        super().__init__(N=N, **kwargs)

        self.K = 3  # for GMM

        self.entropy_parameters = nn.Sequential(
            nn.Conv2d(N * 12 // 3, N * 10 // 3, 1),
            nn.LeakyReLU(inplace=True),
            nn.Conv2d(N * 10 // 3, N * 8 // 3, 1),
            nn.LeakyReLU(inplace=True),
            # nn.Conv2d(N * 8 // 3, N * 6 // 3, 1),
            nn.Conv2d(N * 8 // 3, N * 3 * self.K, 1),
        )

        self.gaussian_conditional = GaussianMixtureConditional(K=self.K)

    def forward(self, x):
        y = self.g_a(x)
        z = self.h_a(y)
        z_hat, z_likelihoods = self.entropy_bottleneck(z)
        params = self.h_s(z_hat)

        y_hat = self.gaussian_conditional.quantize(
            y, "noise" if self.training else "dequantize"
        )
        ctx_params = self.context_prediction(y_hat)
        gaussian_params = self.entropy_parameters(
            torch.cat((params, ctx_params), dim=1)
        )
        # print(f"gaussian_params.shape is {gaussian_params.shape}")  # [8, 1728, 16, 16]

        # scales_hat, means_hat = gaussian_params.chunk(2, 1)
        scales_hat, means_hat, weight_hat = gaussian_params.chunk(3, 1)
        B, C, H, W = weight_hat.shape  # C is M * K = M * 3
        weight_hat = nn.functional.softmax(
            weight_hat.reshape(B, 3, C // 3, H, W), dim=1
        ).reshape(B, C, H, W)

        # _, y_likelihoods = self.gaussian_conditional(y, scales_hat, means=means_hat)
        y_hat1, y_likelihoods = self.gaussian_conditional(
            y, scales_hat, means_hat, weights=weight_hat
        )

        # x_hat = self.g_s(y_hat)
        x_hat = self.g_s(y_hat1)

        return {
            "x_hat": x_hat,
            "likelihoods": {"y": y_likelihoods, "z": z_likelihoods},
        }
```
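A quick standalone shape check (made-up `B`, `M`, `K` values) confirms that this reshape-softmax-reshape normalizes the K weights of each latent channel, consistent with the `weights[:, M * k : M * (k + 1)]` slicing in `_likelihood`:

```python
import torch
import torch.nn.functional as F

# Hypothetical shapes: batch B, M latent channels, K mixture components.
B, M, K, H, W = 2, 4, 3, 8, 8
weight_hat = torch.randn(B, M * K, H, W)  # raw network output, C = M * K

# Reshape so dim=1 indexes the K components, softmax over it, flatten back.
w = F.softmax(weight_hat.reshape(B, K, M, H, W), dim=1).reshape(B, M * K, H, W)

# For every channel m, the K weights at channels m, M + m, 2M + m, ... sum to 1.
sums = w.reshape(B, K, M, H, W).sum(dim=1)
```

Note that index `[k, m]` of the reshaped tensor maps back to channel `M * k + m`, i.e. component `k` occupies a contiguous block of `M` channels, which is exactly the layout `_likelihood` slices.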

I compared Cheng2020GMM with Cheng2020Anchor.
The results are confusing.

GMM
image

Anchor
image

The GMM is inferior to the anchor, and I am unable to understand why.
If you have any insights, please help me out at your convenience!

Thanks again for your valuable time.

Best wishes.

@watwwwww

Hi! I have the same problem as well. Have you found a good solution yet?

@YodaEmbedding
Contributor

YodaEmbedding commented Nov 22, 2024

Perhaps try using STE for quantization instead of noise.

Still, it's weird that GMM K=3 performs that much worse than GC. Try setting K=1 and training. Is the performance still worse?
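For reference, a common generic way to sketch STE rounding (not necessarily CompressAI's exact API; `ste_round` is a hypothetical helper) is to round in the forward pass while letting gradients pass through unchanged:

```python
import torch

def ste_round(x: torch.Tensor) -> torch.Tensor:
    """Round in the forward pass; pass gradients through unchanged (STE)."""
    return x + (torch.round(x) - x).detach()

# Usage sketch: quantize the (mean-removed) latent with STE instead of
# additive uniform noise during training.
y = torch.randn(1, 4, 8, 8, requires_grad=True)
means_hat = torch.zeros_like(y)  # hypothetical predicted means
y_hat = ste_round(y - means_hat) + means_hat

loss = y_hat.sum()
loss.backward()  # gradient of round() is treated as identity under STE
```

The `.detach()` removes the non-differentiable rounding residual from the graph, so the forward value is the rounded latent while the backward pass sees an identity mapping.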

@achel-x
Author

achel-x commented Nov 24, 2024

> Perhaps try using STE for quantization instead of noise.
>
> Still, it's weird that GMM K=3 performs that much worse than GC. Try setting K=1 and training. Is the performance still worse?

Thanks for your kind reply. I will try K=1 / STE and train to see the results. But yes, it's weird that there is no STE for quantization in Cheng's original GMM paper.

@watwwwww

> > Perhaps try using STE for quantization instead of noise.
> >
> > Still, it's weird that GMM K=3 performs that much worse than GC. Try setting K=1 and training. Is the performance still worse?
>
> Thanks for your kind reply. I will try K=1 / STE and train to see the results. But yes, it's weird that there is no STE for quantization in Cheng's original GMM paper.

Hi, I read Cheng's paper carefully again and found that the first two output channel counts of `entropy_parameters` in its network structure are both 640. However, the second convolution layer of the `entropy_parameters` code you rewrote outputs 512 channels. Could the problem be here?

@achel-x
Author

achel-x commented Nov 25, 2024

Thanks for your advice @watwwwww.

> the number of channels of the second layer convolution output in the entropy_parameters code you rewrote is 512

In the paper's implementation, N is set to 128 for the two lower-rate models and 192 for the two higher-rate models.
Here, I set N to 128 to simplify training.

Actually, I noticed that `entropy_parameters` reduces the number of channels in my implementation, and I tried modifying it as below:

```python
self.layers = nn.Sequential(
    nn.Conv2d(N * 12 // 3, N * 10 // 3, 1),
    nn.LeakyReLU(inplace=True),
    nn.Conv2d(N * 10 // 3, N * 10 // 3, 1),
    nn.LeakyReLU(inplace=True),
    nn.Conv2d(N * 10 // 3, 3 * N * K, 1),
)
```
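As a quick arithmetic check (plain Python; using N = 192, the value that reproduces the 640/512 channel counts discussed in this thread), the layer widths of the two heads compare as follows:

```python
N, K = 192, 3  # N = 192 matches the 640/512 counts from the paper discussion

# CompressAI-style head (channels shrink layer by layer, ending at 3*N*K):
original = [N * 12 // 3, N * 10 // 3, N * 8 // 3, N * 3 * K]
# Revised head that keeps 640 channels at the second layer:
revised = [N * 12 // 3, N * 10 // 3, N * 10 // 3, 3 * N * K]

print(original)  # [768, 640, 512, 1728]
print(revised)   # [768, 640, 640, 1728]
```

Both heads end at 3·N·K = 1728 output channels (matching the `gaussian_params` shape printed earlier in this thread); they differ only in the width of the middle layer.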

I referenced this implementation: https://github.com/leelitian/cheng2020-GMM/blob/main/model.py

Intuitively, expanding the channel representation seems like it should yield some improvement.
But it had no effect.

Here are my results.

The lambda is 0.015

GMM (K=3)
image

GMM with revised entropy_parameter (K=3)
image
In addition, when expanding the channel representation here, the anchor also needs to be expanded to make a fair comparison.

GMM (K=1)
image

anchor
image

=========================================================
I am trying to add STE quantization to GMM (K=3) and am not sure whether it will be effective.
Even if it works, it wouldn't demonstrate the efficiency of GMM.

If you have any thoughts on the question, please feel free to share with me.

@watwwwww


Can I ask which dataset you used for training and the specific parameter settings of the training code? While validating my idea, my training process ran very long without producing results.

@achel-x
Author

achel-x commented Nov 26, 2024

In the original paper, the author used 13,830 samples from ImageNet. I just extracted 14k images from the COCO dataset for training, and use the DIV2K 800-image set for testing. Maybe the dataset size plays a role, but I compare different schemes on the same dataset, so it's fine for you to use another dataset.

I am not sure which settings you're asking about. The training script is `python example/train.py -m cheng2020-anchor -d ./dataset/ --lambda 0.015 --cuda --save -e 200`. My GPU is a 3080 (10 GB). The total training time for one model is around 16 h.

@achel-x
Author

achel-x commented Nov 26, 2024

Sorry, the 16 h is for 300 epochs. To shorten training, I just train for 200 epochs. My aim is to validate the effectiveness of GMM, so the baseline is cheng2020-anchor without simplified attention.
