
Add marlin int4 kernel #315

Closed
wants to merge 9 commits into from

Conversation

@dacorvo (Collaborator) commented Sep 20, 2024

What does this PR do?

This adds a modified Marlin fp16/int4 kernel to the library and creates two new QTensor subclasses to use it:

  • MarlinInt4PackedTensor,
  • MarlinInt4WeightQBitsTensor.
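
As background on what a packed int4 tensor stores, the sketch below packs two unsigned 4-bit values per byte. This is only a naive illustration; the actual MarlinInt4PackedTensor uses a kernel-specific interleaved layout, not this scheme.

```python
import numpy as np

def pack_int4(q):
    """Pack unsigned int4 values (0..15) into uint8, two per byte.

    Illustrative only: the real Marlin packing interleaves values to
    match the kernel's memory access pattern.
    """
    q = np.asarray(q, dtype=np.uint8)
    assert q.size % 2 == 0 and q.max() < 16
    lo, hi = q[0::2], q[1::2]
    return lo | (hi << 4)

def unpack_int4(packed):
    """Inverse of pack_int4: recover the original int4 values."""
    packed = np.asarray(packed, dtype=np.uint8)
    out = np.empty(packed.size * 2, dtype=np.uint8)
    out[0::2] = packed & 0x0F
    out[1::2] = packed >> 4
    return out
```

The packed representation halves the memory footprint of the weights; the subclasses above wrap such storage behind the regular tensor interface.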

The AWQ kernel is still used by default because initial tests suggest the modified Marlin kernel either has accuracy issues or is not properly integrated (perplexity increases).

Note: during the integration, I tried to register the Marlin fp16/int4 gemm as a torch.library.custom_op, but it added extra latency (up to 50%), so I kept the legacy declaration (with define/impl).

@dacorvo dacorvo requested a review from SunMarc September 20, 2024 15:35
@dacorvo dacorvo force-pushed the add_marlin_int4_kernel branch 4 times, most recently from ffd984a to 5e5adbe on September 25, 2024 16:28
@dacorvo dacorvo marked this pull request as draft September 25, 2024 16:33
@dacorvo dacorvo force-pushed the add_marlin_int4_kernel branch from 5e5adbe to 19ee33e on September 25, 2024 19:55
dacorvo and others added 9 commits September 26, 2024 15:16
Original fix in vLLM project:

The crash was caused by inline PTX assembly that issued the async_copy with streaming behavior. The fix is to use the more standard PTX for async_copy, without the fractional L2 "evict_first" cache policy. There is no performance difference between the standard async_copy PTX and the previous version.
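
The two PTX variants described above can be sketched as device helpers. This is illustrative; the registers and surrounding pipeline code in the actual Marlin kernel differ.

```cuda
#include <cstdint>

// Previous version: async copy with a fractional L2 "evict_first"
// cache policy, which triggered the crash.
__device__ inline void cp_async_evict_first(void *smem, const void *gmem) {
    uint32_t smem_addr =
        static_cast<uint32_t>(__cvta_generic_to_shared(smem));
    asm volatile(
        "{\n"
        "  .reg .b64 p;\n"
        "  createpolicy.fractional.L2::evict_first.b64 p, 1.0;\n"
        "  cp.async.cg.shared.global.L2::cache_hint [%0], [%1], 16, p;\n"
        "}\n" ::"r"(smem_addr), "l"(gmem));
}

// Fixed version: standard 16-byte async copy, no L2 cache policy.
__device__ inline void cp_async_standard(void *smem, const void *gmem) {
    uint32_t smem_addr =
        static_cast<uint32_t>(__cvta_generic_to_shared(smem));
    asm volatile("cp.async.cg.shared.global [%0], [%1], 16;\n" ::"r"(
                     smem_addr),
                 "l"(gmem));
}
```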
This is to guarantee that the Marlin kernel's output matches the output obtained using dequantized weights.
@dacorvo dacorvo force-pushed the add_marlin_int4_kernel branch from 19ee33e to 5564b8f on September 26, 2024 15:23
@dacorvo (Collaborator, Author) commented Sep 27, 2024

Closing this as the modified kernel is definitely flawed: as soon as it processes more than 32 inputs (i.e. two blocks of 16), errors appear in the outputs starting from the 128th output feature. This points to a flaw in the weight/scale/zero-point readback as parallelization increases.
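
The kind of parity check that exposed this is sketched below, using numpy instead of the library's torch tensors and hypothetical helper names. It builds the dequantized-weight reference that a kernel output would be compared against.

```python
import numpy as np

def quantize_int4(w, group_size=128):
    """Symmetric per-group int4 quantization (hypothetical helper; the
    library's actual quantizer also handles zero-points)."""
    out_f, in_f = w.shape
    g = w.reshape(out_f, in_f // group_size, group_size)
    scales = np.abs(g).max(axis=-1, keepdims=True) / 7.0
    q = np.clip(np.round(g / scales), -8, 7).astype(np.int8)
    return q.reshape(out_f, in_f), scales.squeeze(-1)

def dequantize_int4(q, scales, group_size=128):
    out_f, in_f = q.shape
    g = q.reshape(out_f, in_f // group_size, group_size).astype(np.float32)
    return (g * scales[..., None]).reshape(out_f, in_f)

# Reference output: matmul against dequantized weights. A kernel parity
# test would assert np.allclose(kernel_out, reference, atol=...) for
# batch sizes above 32 and output features beyond 128, where the
# modified kernel diverged.
rng = np.random.default_rng(0)
x = rng.standard_normal((33, 256), dtype=np.float32)  # > 32 inputs
w = rng.standard_normal((256, 256), dtype=np.float32)
q, s = quantize_int4(w)
reference = x @ dequantize_int4(q, s).T
```

Sweeping the batch size across the 32-input boundary and inspecting which output features diverge is what localized the failure to the readback path.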

@dacorvo dacorvo closed this Sep 27, 2024