Hi,

Thanks for open-sourcing your code and model weights. As the title says, I am trying to use the TK kernels with the pre-linearized Llama 3.1 8B model and am unable to reproduce the numbers from the paper. I am using hazyresearch/lolcats-llama-3.1-8b-distill and hazyresearch/lolcats-llama-3.1-8b-ft-lora with the script below. I had to make a few changes to the load_model_from_checkpoint function, since it expects the checkpoint directory to also contain the config files. I also had to change the forward call in linear_window_attention_tk_gen.py, since it was not invoking the TK kernels correctly. The diff of my changes is here.

I am able to reproduce the paper's number on the PIQA dataset with attention_type: lolcats_llama_window_tk, but that attention type does not use the TK kernels, so I switched to attention_type: lolcats_llama_window_tk_gen. With that setting, PIQA accuracy drops to 50%, which is basically random chance. I am not sure where the problem is.
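For concreteness, a sanity check along these lines (a sketch, not my actual script; the model handles are hypothetical) would compare the two attention types on a single, unpadded prompt:

```python
# Hypothetical sketch: compare logits from the two attention types on one
# unpadded prompt. If they already diverge here, the tk_gen forward pass is
# the likely culprit; if they match, the problem is probably in batched/padded eval.
import torch

@torch.no_grad()
def max_logit_diff(model_tk, model_tk_gen, input_ids: torch.Tensor) -> float:
    # model_tk: loaded with attention_type: lolcats_llama_window_tk
    # model_tk_gen: loaded with attention_type: lolcats_llama_window_tk_gen
    logits_a = model_tk(input_ids=input_ids).logits
    logits_b = model_tk_gen(input_ids=input_ids).logits
    return (logits_a.float() - logits_b.float()).abs().max().item()
```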
Hi! Sorry for the slow response! Has the demo script that we provided been working for you with the TK kernel? My suspicion is that the padding is not being handled correctly. See the comment here:
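To make the padding concern concrete, here is a minimal sketch (hypothetical names, not the repo code) of the failure mode: a kernel that ignores the attention mask accumulates padded positions into its running KV state, so hidden states should be zeroed at padded positions before the kernel is called.

```python
# Minimal sketch (hypothetical names, not the LoLCATs code) of why unhandled
# padding can wreck linear-attention accuracy: a mask-unaware kernel folds
# padded positions into its running KV state.
import torch

def mask_padding(hidden_states: torch.Tensor,
                 attention_mask: torch.Tensor) -> torch.Tensor:
    """Zero hidden states at padded positions before a mask-unaware kernel.

    hidden_states:  (batch, seq_len, hidden_dim)
    attention_mask: (batch, seq_len), 1 = real token, 0 = padding
    """
    return hidden_states * attention_mask.unsqueeze(-1).to(hidden_states.dtype)

# Batch of 2 prompts; the second one is right-padded by two tokens.
h = torch.randn(2, 6, 4)
mask = torch.tensor([[1, 1, 1, 1, 1, 1],
                     [1, 1, 1, 1, 0, 0]])
h_clean = mask_padding(h, mask)
assert torch.all(h_clean[1, 4:] == 0)  # padded positions no longer contribute
```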