diff --git a/optimum/quanto/library/extensions/cuda/marlin/marlin_cuda_kernel.cu b/optimum/quanto/library/extensions/cuda/marlin/marlin_cuda_kernel.cu
index b18b0469..c8dc3fec 100644
--- a/optimum/quanto/library/extensions/cuda/marlin/marlin_cuda_kernel.cu
+++ b/optimum/quanto/library/extensions/cuda/marlin/marlin_cuda_kernel.cu
@@ -374,9 +374,9 @@ __global__ void Marlin(
   int4* sh_a = sh;
   int4* sh_b = sh_a + (stages * a_sh_stage);
   int4* sh_s = sh_b + (stages * b_sh_stage);
-  int4* sh_red = sh_s + (stages * s_sh_stage);
   // ADDED: shared memory storage for scaled zero points
-  int4* sh_sz = sh_red + (stages * s_sh_stage);
+  int4* sh_sz = sh_s + (stages * s_sh_stage);
+  int4* sh_red = sh_sz + (stages * s_sh_stage);
 
   // Register storage for double buffer of shared memory reads.
   FragA frag_a[2][thread_m_blocks];
@@ -728,7 +728,6 @@ __global__ void Marlin(
 // latency hiding. At the same time, we want relatively few warps to have many registers per warp and small tiles.
 const int THREADS = 256;
 const int STAGES = 4; // 4 pipeline stages fit into shared memory
-const int SHARED_MEM = 164 * 1000; // max shared memory on compute capability 8.0
 
 // ADDED: add scaled zero pointer
 #define CALL_IF(THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, GROUP_BLOCKS) \
@@ -739,11 +738,11 @@ const int SHARED_MEM = 164 * 1000; // max shared memory on compute capability 8.
     cudaFuncSetAttribute( \
       Marlin<THREADS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, STAGES, GROUP_BLOCKS>, \
       cudaFuncAttributeMaxDynamicSharedMemorySize, \
-      SHARED_MEM \
+      max_shared_mem \
     ); \
     Marlin< \
      THREADS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, STAGES, GROUP_BLOCKS \
-    ><<<blocks, THREADS, SHARED_MEM, stream>>>( \
+    ><<<blocks, THREADS, max_shared_mem, stream>>>( \
       A_ptr, B_ptr, C_ptr, s_ptr, sz_ptr,\
       prob_m, prob_n, prob_k, \
       locks \
@@ -789,6 +788,9 @@ int marlin_cuda(
     }
   }
 
+  int max_shared_mem = 0;
+  cudaDeviceGetAttribute(&max_shared_mem, cudaDevAttrMaxSharedMemoryPerBlockOptin, dev);
+
   int thread_k_blocks = thread_k / 16;
   int thread_n_blocks = thread_n / 16;
   int group_blocks = (groupsize == -1) ? -1 : groupsize / 16;
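
For context, here is a minimal standalone sketch (not part of the patch) of the pattern the diff moves to: query the device's opt-in shared-memory limit at runtime and raise the kernel's dynamic shared-memory cap to it, instead of hard-coding 164 KB, which only matches compute capability 8.0. The kernel, sizes, and names below are illustrative placeholders rather than the Marlin kernel; only the two CUDA runtime calls mirror what the patch does.

// sketch.cu -- assumes any CUDA device; compile with `nvcc sketch.cu`
#include <cstdio>
#include <cuda_runtime.h>

// Toy kernel that uses only dynamic shared memory (no static __shared__ arrays),
// so the full opt-in budget can be requested at launch time.
__global__ void copy_via_smem(const float* in, float* out, int n) {
  extern __shared__ float smem[];
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) smem[threadIdx.x] = in[i];
  __syncthreads();
  if (i < n) out[i] = smem[threadIdx.x];
}

int main() {
  int dev = 0;
  cudaSetDevice(dev);

  // Same query the patch adds in marlin_cuda(): the largest dynamic shared
  // memory a block can opt into on this device (about 164 KB on SM 8.0,
  // less on most other architectures).
  int max_shared_mem = 0;
  cudaDeviceGetAttribute(&max_shared_mem, cudaDevAttrMaxSharedMemoryPerBlockOptin, dev);
  std::printf("opt-in shared memory per block: %d bytes\n", max_shared_mem);

  // Same per-kernel attribute the patch sets inside CALL_IF before launching.
  cudaFuncSetAttribute(copy_via_smem, cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem);

  const int n = 1 << 20, threads = 256, blocks = (n + threads - 1) / threads;
  float *in, *out;
  cudaMalloc(&in, n * sizeof(float));
  cudaMalloc(&out, n * sizeof(float));

  // Launch with the full opt-in amount, mirroring the patched
  // <<<blocks, THREADS, max_shared_mem, stream>>> launch configuration.
  copy_via_smem<<<blocks, threads, max_shared_mem>>>(in, out, n);
  cudaDeviceSynchronize();

  cudaFree(in);
  cudaFree(out);
  return 0;
}

Because the limit is read from the device rather than assumed, the same launch path works on GPUs whose per-block shared memory is smaller than the 164 KB available on compute capability 8.0.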