Skip to content

Commit

Permalink
Move the reduction shared-memory region to the end of the buffer, and query the maximum shared-memory size dynamically
Browse files Browse the repository at this point in the history
  • Loading branch information
ahadnagy committed Jan 7, 2025
1 parent 4493482 commit ae10ca1
Showing 1 changed file with 7 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -374,9 +374,9 @@ __global__ void Marlin(
int4* sh_a = sh;
int4* sh_b = sh_a + (stages * a_sh_stage);
int4* sh_s = sh_b + (stages * b_sh_stage);
int4* sh_red = sh_s + (stages * s_sh_stage);
// ADDED: shared memory storage for scaled zero points
int4* sh_sz = sh_red + (stages * s_sh_stage);
int4* sh_red = sh_sz + (stages * s_sh_stage);

// Register storage for double buffer of shared memory reads.
FragA frag_a[2][thread_m_blocks];
Expand Down Expand Up @@ -728,7 +728,6 @@ __global__ void Marlin(
// latency hiding. At the same time, we want relatively few warps to have many registers per warp and small tiles.
const int THREADS = 256;
const int STAGES = 4; // 4 pipeline stages fit into shared memory
const int SHARED_MEM = 164 * 1000; // max shared memory on compute capability 8.0

// ADDED: add scaled zero pointer
#define CALL_IF(THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, GROUP_BLOCKS) \
Expand All @@ -739,11 +738,11 @@ const int SHARED_MEM = 164 * 1000; // max shared memory on compute capability 8.
cudaFuncSetAttribute( \
Marlin<THREADS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, STAGES, GROUP_BLOCKS>, \
cudaFuncAttributeMaxDynamicSharedMemorySize, \
SHARED_MEM \
max_shared_mem \
); \
Marlin< \
THREADS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, STAGES, GROUP_BLOCKS \
><<<blocks, THREADS, SHARED_MEM, stream>>>( \
><<<blocks, THREADS, max_shared_mem, stream>>>( \
A_ptr, B_ptr, C_ptr, s_ptr, sz_ptr,\
prob_m, prob_n, prob_k, \
locks \
Expand Down Expand Up @@ -789,6 +788,10 @@ int marlin_cuda(
}
}

int max_shared_mem = 0;
cudaDeviceGetAttribute(&max_shared_mem, cudaDevAttrMaxSharedMemoryPerBlockOptin, dev);
TORCH_CHECK(max_shared_mem > 0);

int thread_k_blocks = thread_k / 16;
int thread_n_blocks = thread_n / 16;
int group_blocks = (groupsize == -1) ? -1 : groupsize / 16;
Expand Down

0 comments on commit ae10ca1

Please sign in to comment.