Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(library): Propagate upstream Marlin kernel fix #366

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -376,6 +376,8 @@ __global__ void Marlin(
int4* sh_s = sh_b + (stages * b_sh_stage);
// ADDED: shared memory storage for scaled zero points
int4* sh_sz = sh_s + (stages * s_sh_stage);
int4* sh_red = sh_sz + (stages * s_sh_stage);

// Register storage for double buffer of shared memory reads.
FragA frag_a[2][thread_m_blocks];
I4 frag_b_quant[2];
Expand Down Expand Up @@ -499,21 +501,21 @@ __global__ void Marlin(
for (int j = 0; j < 4 * 2; j++) {
int red_sh_wr = red_sh_delta * j + (red_sh_rd - red_sh_stride * i);
if (i < red_off) {
float* c_rd = reinterpret_cast<float*>(&sh[red_sh_delta * j + red_sh_rd]);
float* c_wr = reinterpret_cast<float*>(&sh[red_sh_wr]);
float* c_rd = reinterpret_cast<float*>(&sh_red[red_sh_delta * j + red_sh_rd]);
float* c_wr = reinterpret_cast<float*>(&sh_red[red_sh_wr]);
#pragma unroll
for (int k = 0; k < 4; k++)
reinterpret_cast<FragC*>(frag_c)[4 * 2 * m_block + j][k] += c_rd[k] + c_wr[k];
}
sh[red_sh_wr] = reinterpret_cast<int4*>(&frag_c)[4 * 2 * m_block + j];
sh_red[red_sh_wr] = reinterpret_cast<int4*>(&frag_c)[4 * 2 * m_block + j];
}
}
__syncthreads();
}
if (red_idx == 0) {
#pragma unroll
for (int i = 0; i < 4 * 2; i++) {
float* c_rd = reinterpret_cast<float*>(&sh[red_sh_delta * i + red_sh_rd]);
float* c_rd = reinterpret_cast<float*>(&sh_red[red_sh_delta * i + red_sh_rd]);
#pragma unroll
for (int j = 0; j < 4; j++)
reinterpret_cast<FragC*>(frag_c)[4 * 2 * m_block + i][j] += c_rd[j];
Expand Down Expand Up @@ -548,7 +550,7 @@ __global__ void Marlin(
#pragma unroll
for (int i = 0; i < thread_m_blocks * 4; i++) {
cp_async4_pred(
&sh[c_sh_wr + c_sh_wr_delta * i],
&sh_red[c_sh_wr + c_sh_wr_delta * i],
&C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)],
i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m
);
Expand All @@ -561,7 +563,7 @@ __global__ void Marlin(
for (int i = 0; i < thread_m_blocks * 4; i++) {
if (i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m) {
if (!first) {
int4 c_red = sh[c_sh_wr + i * c_sh_wr_delta];
int4 c_red = sh_red[c_sh_wr + i * c_sh_wr_delta];
#pragma unroll
for (int j = 0; j < 2 * 4; j++) {
reinterpret_cast<float*>(&frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)] += __half2float(
Expand Down Expand Up @@ -605,7 +607,7 @@ __global__ void Marlin(
half2 res = __halves2half2(__float2half(c0), __float2half(c1));
if (group_blocks == -1) // for per-column quantization we finally apply the scale here
res = __hmul2(res, s[0]);
((half2*) sh)[idx] = res;
((half2*) sh_red)[idx] = res;
};
if (threadIdx.x / 32 < thread_n_blocks / 4) {
#pragma unroll
Expand All @@ -626,7 +628,7 @@ __global__ void Marlin(
#pragma unroll
for (int i = 0; i < ceildiv(16 * thread_m_blocks, threads / (2 * thread_n_blocks)); i++) {
if (c_gl_wr < c_gl_wr_end) {
C[c_gl_wr] = sh[c_sh_rd];
C[c_gl_wr] = sh_red[c_sh_rd];
c_gl_wr += c_gl_wr_delta;
c_sh_rd += c_sh_rd_delta;
}
Expand Down Expand Up @@ -726,7 +728,6 @@ __global__ void Marlin(
// latency hiding. At the same time, we want relatively few warps to have many registers per warp and small tiles.
const int THREADS = 256;
const int STAGES = 4; // 4 pipeline stages fit into shared memory
const int SHARED_MEM = 96 * 1024; // max shared memory on compute capability 8.6 (< 8.0)

// ADDED: add scaled zero pointer
#define CALL_IF(THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, GROUP_BLOCKS) \
Expand All @@ -737,11 +738,11 @@ const int SHARED_MEM = 96 * 1024; // max shared memory on compute capability 8.6
cudaFuncSetAttribute( \
Marlin<THREADS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, STAGES, GROUP_BLOCKS>, \
cudaFuncAttributeMaxDynamicSharedMemorySize, \
SHARED_MEM \
max_shared_mem \
); \
Marlin< \
THREADS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, STAGES, GROUP_BLOCKS \
><<<blocks, THREADS, SHARED_MEM, stream>>>( \
><<<blocks, THREADS, max_shared_mem, stream>>>( \
A_ptr, B_ptr, C_ptr, s_ptr, sz_ptr,\
prob_m, prob_n, prob_k, \
locks \
Expand Down Expand Up @@ -787,6 +788,9 @@ int marlin_cuda(
}
}

int max_shared_mem = 0;
cudaDeviceGetAttribute(&max_shared_mem, cudaDevAttrMaxSharedMemoryPerBlockOptin, dev);

int thread_k_blocks = thread_k / 16;
int thread_n_blocks = thread_n / 16;
int group_blocks = (groupsize == -1) ? -1 : groupsize / 16;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -131,15 +131,14 @@ def test_marlin_int4_weight_qbits_tensor_linear(batch_size, tokens, in_features,
)


@pytest.mark.xfail(reason="Bug in Marlin kernel", strict=False)
# Tests previous Marlin kernel bug: https://github.com/huggingface/optimum-quanto/issues/332
@pytest.mark.skipif(
not is_extension_available("quanto_cuda") or torch.cuda.get_device_capability()[0] < 8,
reason="CUDA >= sm80 not available",
)
@pytest.mark.parametrize("batch_size", [1, 2])
@pytest.mark.parametrize("tokens", [48, 64])
# @pytest.mark.parametrize("in_features", [1024, 2048, 4096, 16384])
@pytest.mark.parametrize("in_features", [4096, 16384])
@pytest.mark.parametrize("in_features", [1024, 2048, 4096, 16384])
@pytest.mark.parametrize("out_features", [2048, 4096])
def test_marlin_int4_weight_qbits_tensor_linear_failing(batch_size, tokens, in_features, out_features):
dtype = torch.float16
Expand Down
Loading