Move reduce shared memory to the back, query max. shmem size dynamically
Fix build error
ahadnagy committed Jan 7, 2025
1 parent 4493482 commit ec2bc69
Showing 1 changed file with 7 additions and 5 deletions.
@@ -374,9 +374,9 @@ __global__ void Marlin(
   int4* sh_a = sh;
   int4* sh_b = sh_a + (stages * a_sh_stage);
   int4* sh_s = sh_b + (stages * b_sh_stage);
-  int4* sh_red = sh_s + (stages * s_sh_stage);
   // ADDED: shared memory storage for scaled zero points
-  int4* sh_sz = sh_red + (stages * s_sh_stage);
+  int4* sh_sz = sh_s + (stages * s_sh_stage);
+  int4* sh_red = sh_sz + (stages * s_sh_stage);
 
   // Register storage for double buffer of shared memory reads.
   FragA frag_a[2][thread_m_blocks];
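
The hunk above follows the usual CUDA pattern of carving one dynamically sized extern __shared__ arena into sub-buffers by pointer bumping; moving the reduction scratch to the back lets the new scaled-zero-point buffer sit directly after the scales. A minimal compilable sketch of that layout is below; carve_demo and its parameterization are illustrative, while the buffer names mirror those in the diff.

// Illustrative sketch only: carve a single dynamically sized shared-memory
// arena into sub-buffers by pointer bumping, with the reduction scratch
// placed last so the scaled-zero-point buffer lands right after the scales.
#include <cuda_runtime.h>

__global__ void carve_demo(int stages, int a_sh_stage, int b_sh_stage, int s_sh_stage) {
  extern __shared__ int4 sh[];                   // one arena, sized at launch
  int4* sh_a   = sh;                             // A-tile pipeline buffers
  int4* sh_b   = sh_a  + (stages * a_sh_stage);  // B-tile pipeline buffers
  int4* sh_s   = sh_b  + (stages * b_sh_stage);  // quantization scales
  int4* sh_sz  = sh_s  + (stages * s_sh_stage);  // scaled zero points (newly added)
  int4* sh_red = sh_sz + (stages * s_sh_stage);  // reduction scratch, now at the back
  // A real kernel would stage tiles and partial reductions through these buffers.
  (void)sh_red;
}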
@@ -728,7 +728,6 @@ __global__ void Marlin(
 // latency hiding. At the same time, we want relatively few warps to have many registers per warp and small tiles.
 const int THREADS = 256;
 const int STAGES = 4; // 4 pipeline stages fit into shared memory
-const int SHARED_MEM = 164 * 1000; // max shared memory on compute capability 8.0
 
 // ADDED: add scaled zero pointer
 #define CALL_IF(THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, GROUP_BLOCKS) \
@@ -739,11 +738,11 @@ const int SHARED_MEM = 164 * 1000; // max shared memory on compute capability 8.0
     cudaFuncSetAttribute( \
       Marlin<THREADS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, STAGES, GROUP_BLOCKS>, \
       cudaFuncAttributeMaxDynamicSharedMemorySize, \
-      SHARED_MEM \
+      max_shared_mem \
     ); \
     Marlin< \
       THREADS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, STAGES, GROUP_BLOCKS \
-    ><<<blocks, THREADS, SHARED_MEM, stream>>>( \
+    ><<<blocks, THREADS, max_shared_mem, stream>>>( \
       A_ptr, B_ptr, C_ptr, s_ptr, sz_ptr, \
       prob_m, prob_n, prob_k, \
       locks \
@@ -789,6 +788,9 @@ int marlin_cuda(
     }
   }
 
+  int max_shared_mem = 0;
+  cudaDeviceGetAttribute(&max_shared_mem, cudaDevAttrMaxSharedMemoryPerBlockOptin, dev);
+
   int thread_k_blocks = thread_k / 16;
   int thread_n_blocks = thread_n / 16;
   int group_blocks = (groupsize == -1) ? -1 : groupsize / 16;
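For context, the query added above replaces the hard-coded 164 * 1000 limit, which was only valid for compute capability 8.0, with the device's actual opt-in maximum. A standalone sketch of that pattern follows; demo_kernel and its launch configuration are placeholders, and only the two runtime-API calls mirror what the commit does.

// Minimal sketch of the dynamic shared-memory sizing pattern used above.
#include <cuda_runtime.h>
#include <cstdio>

__global__ void demo_kernel(float* out) {
  extern __shared__ float smem[];   // dynamically sized shared memory
  smem[threadIdx.x] = static_cast<float>(threadIdx.x);
  __syncthreads();
  if (threadIdx.x == 0) out[blockIdx.x] = smem[blockDim.x - 1];
}

int main() {
  int dev = 0;
  int max_shared_mem = 0;
  // Ask the device how much dynamic shared memory a block may opt into,
  // instead of hard-coding a per-architecture constant.
  cudaDeviceGetAttribute(&max_shared_mem, cudaDevAttrMaxSharedMemoryPerBlockOptin, dev);

  // A kernel must opt in before it can be launched with more than the
  // default 48 KiB of dynamic shared memory.
  cudaFuncSetAttribute(demo_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem);

  float* out = nullptr;
  cudaMalloc(&out, sizeof(float));
  demo_kernel<<<1, 256, max_shared_mem, 0>>>(out);
  cudaDeviceSynchronize();
  printf("opt-in dynamic shared memory per block: %d bytes\n", max_shared_mem);
  cudaFree(out);
  return 0;
}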
