
General Guide of AMD Triton Performance Optimization


This document introduces the general steps for Triton kernel optimization. Overall, Triton kernel optimization is similar to CUDA/HIP kernel optimization. It includes the following aspects:

Hardware resource utilization

Each GPU has many Compute Units (CUs), and different CUs do computation in parallel, so the number of CUs a kernel can distribute its work across is the first thing to consider. For MI300X, we want the grid to have at least 1024 thread blocks/workgroups (WGs), and the more the merrier. To increase hardware utilization, more parallelism generally needs to be found in the algorithm (e.g. using a larger split-K for GEMMs). Hardware resources can be queried with the command rocminfo (in the folder /opt/rocm/bin). For instance, one can query the number of CUs, number of SIMDs, and the wavefront size as:

  • rocminfo | grep "Compute Unit"
  • rocminfo | grep "SIMD"
  • rocminfo | grep "Wavefront Size"

For MI300X, there are 304 CUs, 4 SIMDs per CU, and the wavefront size (warp size) is 64.
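
As a rough illustration of the grid-size target, the sketch below (plain Python, hypothetical GEMM shape and tile sizes) compares the number of workgroups a tiled GEMM launches with and without split-K against the 304 CUs of MI300X.

import math

M, N, K = 4096, 4096, 4096            # hypothetical GEMM problem
BLOCK_M, BLOCK_N = 256, 256           # hypothetical tile sizes
NUM_CUS = 304                         # MI300X, from rocminfo | grep "Compute Unit"

def num_workgroups(split_k=1):
    """Workgroups in a tiled GEMM grid with optional split-K."""
    return math.ceil(M / BLOCK_M) * math.ceil(N / BLOCK_N) * split_k

for split_k in (1, 4):
    wgs = num_workgroups(split_k)
    print(f"split_k={split_k}: {wgs} workgroups ({wgs / NUM_CUS:.1f} per CU)")

With these shapes, split_k=1 gives only 256 workgroups (fewer than the 304 CUs), while split_k=4 gives 1024 and reaches the target above.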

Autotunable kernel configurations & environment variables

This is about the amount of memory access and computation assigned to each CU. It is related to the usage of LDS, registers, and the scheduling of different tasks on a CU.

Software Pipelining

Software pipelining can improve performance on a variety of workloads by overlapping memory access with compute.

Note: The controls described in this section only apply to Triton at commit a70d5856ff2e and beyond.

You can control pipelining behavior using num_stages:

  1. In tl.range of a for loop, for that specific loop:

for k in tl.range(start, end, step, num_stages=2):
  ...

  2. Or using triton.Config. This sets num_stages globally on the function and applies to all for loops in the function with tl.load -> tl.dot:

triton.Config(..., num_stages=2),
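
Putting the two together, below is a minimal GEMM sketch (kernel and argument names are illustrative; masking is omitted, so M, N and K are assumed to be multiples of the block sizes). The K-loop requests two pipeline stages via tl.range; the commented triton.Config line shows the equivalent global setting.

import triton
import triton.language as tl

# Equivalent global setting via autotuning (applies to all tl.load -> tl.dot loops):
#   triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 64}, num_stages=2, num_warps=8)

@triton.jit
def matmul_kernel(a_ptr, b_ptr, c_ptr, M, N, K,
                  stride_am, stride_ak, stride_bk, stride_bn,
                  stride_cm, stride_cn,
                  BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr):
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    offs_k = tl.arange(0, BLOCK_K)
    a_ptrs = a_ptr + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak
    b_ptrs = b_ptr + offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    # Pipeline this specific loop with 2 stages: one LDS buffer per tl.load feeding the tl.dot.
    for _ in tl.range(0, tl.cdiv(K, BLOCK_K), num_stages=2):
        a = tl.load(a_ptrs)
        b = tl.load(b_ptrs)
        acc += tl.dot(a, b)
        a_ptrs += BLOCK_K * stride_ak
        b_ptrs += BLOCK_K * stride_bk
    c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn
    tl.store(c_ptrs, acc)   # C is assumed to be fp32 here

# Launch grid (illustrative): matmul_kernel[(M // BLOCK_M, N // BLOCK_N)](...)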

Specifically when targeting AMD GPUs, Triton employs a stream pipelining approach, streaming data through register buffers. Depending on the workload, stream pipelining may also allocate one or more buffers in shared memory.

Guidelines

  • When num_stages = 1, no stream pipelining occurs, but a shared memory buffer will still be created by a later code generation pass for each tl.load that feeds a tl.dot.
  • When tl.load ops feed into tl.dot op, generally using num_stages = 2 will result in 1 shared memory buffer for each tl.load (e.g. a and b). Each successive increment of num_stages will add an additional shared memory buffer and ping-pong between them.
    • Note: If the tl.load is using an address loaded from another tl.load this will require num_stages = 3 for 1 shared memory buffer. The chaining of loads in this way will require an extra stage for each tl.load in the chain.
  • When a tl.load does not feed a tl.dot or another tl.load it may still be stream pipelined and will ping-pong between num_stages buffers in registers. Thus while the loop is loading the next data tile it may compute on the previously loaded buffer already in registers.
    • Note: You must use the tl.range to specify num_stages when the loads do not feed tl.dot because by default triton.Config only applies to loads feeding into tl.dot ops.
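
As a sketch of the last two guidelines (a hypothetical element-wise kernel with no tl.dot), the loop below uses tl.range to request two register buffers so that the load of the next tile can overlap with the compute and store of the current one; a triton.Config(num_stages=...) alone would not pipeline this loop.

import triton
import triton.language as tl

@triton.jit
def scaled_copy_kernel(x_ptr, y_ptr, n_elements, scale,
                       BLOCK: tl.constexpr, TILES_PER_WG: tl.constexpr):
    pid = tl.program_id(0)
    base = pid * TILES_PER_WG * BLOCK
    # num_stages must be given via tl.range here: these loads feed no tl.dot,
    # so a num_stages set through triton.Config would not apply to this loop.
    for i in tl.range(0, TILES_PER_WG, num_stages=2):
        offs = base + i * BLOCK + tl.arange(0, BLOCK)
        mask = offs < n_elements
        x = tl.load(x_ptr + offs, mask=mask)
        tl.store(y_ptr + offs, x * scale, mask=mask)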

More Tuning Switches

  • waves_per_eu=n. (See the occupancy section below for how to compute occupancy.) This option is specific to AMD GPUs. It hints the compiler to reduce vector general purpose register (VGPR) usage to a level such that occupancy = n can be achieved. This only helps if both of the following are satisfied (the config sketch at the end of this section shows how it is passed):
  1. The occupancy of the kernel is limited by VGPR usage.
  2. The current VGPR usage is only a few above a boundary in Table.1 of the AMD lab notes.

For example, according to Table.1 in the AMD lab notes, 512 VGPRs are available per Execution Unit (EU), and VGPRs are allocated in units of 16. If the current VGPR usage is 170, the actual requested VGPR count is 176, so the occupancy is only 2 waves/EU since 176 x 3 > 512. If we then set waves_per_eu to 3, the LLVM backend will try to bring VGPR usage down so that 3 waves/EU might fit.

  • Tile sizes: BLOCK_M, BLOCK_N, BLOCK_K. This needs to be tuned. We want tile sizes large enough to maximize the efficiency of memory-to-computation ratio, but small enough to parallelize the greatest number of WGs at the grid level.
  • matrix_instr_nonkdim: this is an experimental feature for FA-like kernels. It chooses the size of the MFMA instruction used. For GEMM kernels on MI300X, we found that mfma_16x16 has better performance than mfma_32x32, even for large tile/GEMM sizes.
  1. matrix_instr_nonkdim = 16: mfma_16x16 is used
  2. matrix_instr_nonkdim = 32: mfma_32x32 is used
  • There is one environment variable that should be turned on in most cases:

OPTIMIZE_EPILOGUE=1 enables the optimize_epilogue pass, which removes the convert_layout in the epilogue. By default, the results of the MFMA instruction are converted to a blocked layout, which leads to global stores with the maximum vector length, i.e. global_store_dwordx4. This is done implicitly with LDS as an intermediate buffer to exchange data between threads, and padding is used in LDS to avoid bank conflicts. This usually leads to extra LDS usage, which may reduce occupancy. Setting OPTIMIZE_EPILOGUE=1 stores the result in the MFMA layout instead; this reduces the efficiency of global stores but has an insignificant influence on kernel execution time.

Note that this variable is not turned on by default since it only works with tt.store but not tt.atomic_add, which is used in split-k and stream-k GEMM kernels. We are working on enabling it with tt.atomic_add and turning it on by default.
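
As a concrete illustration of these switches, the sketch below builds a hypothetical autotuning space for a GEMM-like kernel. It assumes a Triton build whose AMD backend accepts waves_per_eu and matrix_instr_nonkdim as kernel keyword options (as in AMD's tuned GEMM/FA configs); the concrete values are illustrative, not tuned.

import triton

configs = [
    triton.Config(
        {'BLOCK_M': bm, 'BLOCK_N': bn, 'BLOCK_K': bk,
         'waves_per_eu': wpe, 'matrix_instr_nonkdim': 16},
        num_stages=ns, num_warps=nw)
    for bm in (128, 256)
    for bn in (128, 256)
    for bk in (32, 64)
    for wpe in (0, 2, 3)   # 0 = leave the register budget to the compiler
    for ns in (1, 2)
    for nw in (4, 8)
]

# OPTIMIZE_EPILOGUE is an environment variable, not a config knob:
#   OPTIMIZE_EPILOGUE=1 python gemm.py

Such a configs list would then be passed to @triton.autotune(configs=configs, key=[...]) on the kernel.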

Memory access efficiency

A GPU has global memory, local data share (LDS, i.e. shared memory), and registers. Global memory has high access latency but large capacity. LDS access has much lower latency, but its size is small. Register access is the fastest yet smallest among the three. Overall, we want data in global memory to be loaded/stored as few times as possible. If different threads in a workgroup need to access the same data, the data should first be transferred from global memory to LDS, then accessed by the different threads.

IR analysis

In Triton, we have different layouts, including the blocked layout, shared layout, sliced layout, and MFMA layout. From the Triton GPU IR, we can tell in which memory each computation is performed. Here is a snippet of IR from the Flash Attention (FA) decode int4 KV program. It dequantizes the KV cache from the int4 data type to fp16.

%190 = tt.load %189 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<1x64xi32, #blocked6> loc(#loc159)
%266 = arith.andi %190, %cst_28 : tensor<1x64xi32, #blocked6> loc(#loc250)
%267 = arith.trunci %266 : tensor<1x64xi32, #blocked6> to tensor<1x64xi16, #blocked6> loc(#loc251)
%268 = tt.bitcast %267 : tensor<1x64xi16, #blocked6> -> tensor<1x64xf16, #blocked6> loc(#loc252)
%269 = triton_gpu.convert_layout %268 : (tensor<1x64xf16, #blocked6>) -> tensor<1x64xf16, #shared1> loc(#loc252)
%270 = tt.trans %269 : (tensor<1x64xf16, #shared1>) -> tensor<64x1xf16, #shared2> loc(#loc194)
%276 = triton_gpu.convert_layout %270 : (tensor<64x1xf16, #shared2>) -> tensor<64x1xf16, #blocked5> loc(#loc254)
%293 = arith.mulf %276, %cst_30 : tensor<64x1xf16, #blocked5> loc(#loc254)
%295 = arith.mulf %292, %294 : tensor<64x32xf16, #blocked5> loc(#loc264)
%297 = arith.addf %295, %296 : tensor<64x32xf16, #blocked5> loc(#loc255)
%298 = triton_gpu.convert_layout %297 : (tensor<64x32xf16, #blocked5>) -> tensor<64x32xf16, #shared1> loc(#loc255)
%299 = tt.trans %298 : (tensor<64x32xf16, #shared1>) -> tensor<32x64xf16, #shared2> loc(#loc196)
%300 = triton_gpu.convert_layout %299 : (tensor<32x64xf16, #shared2>) -> tensor<32x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 4}>> loc(#loc197)

From the IR here, we can see that i32 data is loaded from global memory to registers. After a few element-wise operations in registers, it is stored in shared memory for the transpose operation, which needs data movement across different threads. Once the transpose is done, the data is loaded from LDS to registers again, a few more element-wise operations are applied, and it is stored to LDS again. The last step loads from LDS to registers and converts to the dot operand layout. We can see from the IR that LDS is used twice: once for the transpose, and once to convert the blocked layout to the dot operand layout. However, these two do not require multiple LDS accesses: the conversion from shared1 to shared2 is a logical conversion and does not come with additional LDS accesses. Instead, it lowers to LDS address calculation for the subsequent LDS reads, which read the block transposed rather than in normal order. This address calculation has to exist even when the read is in normal order (just with different addresses).

A note on transpose operations: sometimes transpose operations can be absorbed into existing layout conversion operations. However, in general, one should avoid doing transposes in user-level code if possible. For example, instead of providing the input in normal layout and then doing a tl.trans call in the Triton kernel, it is preferable to provide the input already transposed, e.g. using the .T operation in PyTorch.
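
For instance, a minimal PyTorch-side sketch of the preferred pattern (hypothetical tensor shapes and kernel launch):

import torch

# Hypothetical activation that a Triton kernel needs in transposed form.
x = torch.randn(4096, 64, device='cuda', dtype=torch.float16)

# Avoid: passing x and calling tl.trans(...) inside the kernel (extra LDS round trip).
# Prefer: hand the kernel a transposed view created in PyTorch.
x_t = x.T                        # a strided view; no data movement happens here
x_t_contig = x.T.contiguous()    # materialize once if the kernel needs unit-stride rows

# my_kernel[grid](x_t, x_t.stride(0), x_t.stride(1), ...)   # hypothetical launch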

Assembly analysis

  • In the ISA, make sure global_load_dwordx4 is used, especially when the load happens in the loop.
  • In most cases, the LDS loads and stores should use _b128 as well to minimize the number of LDS access instructions. Note that upstream (Phantom dubs this as backend) may not have _b128 LDS read/write at this moment, so it uses _b64. Whether you use the fork or upstream, the LDS accesses should have at least _b64 vector width in most cases.
  • The AMD ISA has the s_waitcnt instruction to synchronize dependencies between memory accesses and computation. In the Triton context, s_waitcnt typically carries two counters:
  1. lgkmcnt(n): lgkm stands for LDS, GDS, Constant and Message. In our context, it is mostly related to LDS accesses. The number n is how many such accesses may still be outstanding before execution continues. For example, 0 means all lgkm accesses must finish before continuing, and 1 means only 1 lgkm access may still be running asynchronously before proceeding.
  2. vmcnt(n): vm means vector memory. This applies when vector memory is accessed, e.g., a global load from global memory. The number n means the same thing as above.

The general guideline is:

  1. Vectorize memory access as much as possible.
  2. Ensure synchronization is done efficiently.
  3. Overlap instructions to hide latency; this requires thoughtful analysis of the algorithm.
  4. If you spot anything inefficient, you can trace it back to LLVM IR, TTGIR and even TTIR to see where the problem comes from. If you spot it during compiler optimization, activate the MLIR dump and check which optimization pass caused the problem.
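
A small helper like the sketch below can quantify the points above, assuming you have dumped the kernel's ISA to a text file beforehand (e.g. with AMDGCN_ENABLE_DUMP=1 if your Triton build supports it, or from llvm-objdump output of the code object); the file name is hypothetical.

import re
import sys
from collections import Counter

# Usage: python isa_stats.py kernel.amdgcn
text = open(sys.argv[1]).read()

counts = Counter(re.findall(
    r'\b(global_load_\w+|global_store_\w+|ds_read\w*|ds_write\w*|s_waitcnt)\b', text))

for name, n in counts.most_common():
    print(f"{name:24s} {n}")

# Ideally global loads are *_dwordx4 and LDS reads/writes are *_b128 (or at least *_b64).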

Tools (rocprof, omniperf)

Understand/Compute the occupancy of the kernel

  • Get the VGPR count, search for .vgpr_count in the ISA, say it is N
  • Get the allocated LDS following the steps (say you got L for the kernel)
  1. export MLIR_ENABLE_DUMP=1
  2. rm -rf ~/.triton/cache
  3. python kernel.py 2>&1 | grep "triton_gpu.shared = " | tail -n 1
  4. Look for something like triton_gpu.shared = 65536; it means 65536 bytes of LDS are allocated for the kernel.
  • Get number of waves per workgroup following the steps (say you got nW)
  1. export MLIR_ENABLE_DUMP=1
  2. rm -rf ~/.triton/cache
  3. python kernel.py 2>&1 | grep "triton_gpu.num-warps" | tail -n 1
  4. Look for something like triton_gpu.num-warps = 8; it means 8 waves per workgroup.
  • Compute occupancy limited by VGPR based on N according to Table.1 in AMD lab notes. Say you got waves per EU as occ_vgpr.
  • Compute occupancy limited by LDS based on L by: occ_lds = floor(65536 / L).

Then the occupancy is occ = min(floor(occ_vgpr * 4 / nW), occ_lds) * nW / 4

  1. occ_vgpr * 4 gives the total number of waves on all 4 EUs (SIMDs) per CU
  2. floor(occ_vgpr * 4 / nW) gives the occupancy of workgroups per CU regarding VGPR usage

The true occupancy is then the minimum of the two limits (in workgroups per CU), converted back to waves per EU by the nW / 4 factor. The above logic is available in occ.sh.
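
The same arithmetic can be written as a small Python sketch; Table.1 of the AMD lab notes is approximated here by floor(512 / allocated VGPRs), with the MI300X constants used in this section.

import math

VGPR_PER_EU = 512        # available VGPRs per EU (SIMD)
VGPR_GRANULARITY = 16    # VGPRs are allocated in units of 16
LDS_PER_CU = 65536       # bytes of LDS per CU
SIMDS_PER_CU = 4

def occupancy(vgpr_count, lds_bytes, num_warps):
    """Return achieved waves per EU given N (vgpr_count), L (lds_bytes) and nW (num_warps)."""
    vgpr_alloc = math.ceil(vgpr_count / VGPR_GRANULARITY) * VGPR_GRANULARITY
    occ_vgpr = VGPR_PER_EU // vgpr_alloc                        # waves/EU limited by VGPRs
    occ_lds = LDS_PER_CU // lds_bytes if lds_bytes else 10**9   # workgroups/CU limited by LDS
    wgs_per_cu = min((occ_vgpr * SIMDS_PER_CU) // num_warps, occ_lds)
    return wgs_per_cu * num_warps / SIMDS_PER_CU

# Example from the text: 170 VGPRs round up to 176, so occ_vgpr = 2 waves/EU;
# with 65536 bytes of LDS and 8 waves per workgroup this prints 2.0.
print(occupancy(vgpr_count=170, lds_bytes=65536, num_warps=8))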

PyTorch inductor Triton tuning knobs

  • To enable gemm/conv lowerings to Triton, inductor's max_autotune mode is required. This will benchmark a static list of Triton configs (conv configs for max autotune + matmul configs for max autotune) and use the fastest for each shape. (Note: if regular MIOpen/rocBLAS is faster for a specific operation, Triton will not be used.) A short Python sketch of these settings follows this list.

    1. torch._inductor.config.max_autotune = True or TORCHINDUCTOR_MAX_AUTOTUNE=1
    2. Or, for more fine-grained control:
  1. torch._inductor.config.max_autotune_pointwise = True - to enable tuning for pointwise/reduction ops
  2. torch._inductor.config.max_autotune_gemm = True - to enable tuning/lowering of mm/convs
  3. torch._inductor.config.max_autotune_gemm_backends/TORCHINDUCTOR_MAX_AUTOTUNE_GEMM_BACKENDS - selects the candidate backends for mm autotuning. Defaults to "TRITON,ATEN"; NV also includes a CUTLASS tuning option. Limiting this to "TRITON" may improve performance by enabling more fused mm kernels instead of going to rocBLAS.
  • For further mm tuning, coordinate descent tuning may improve performance:
    1. torch._inductor.config.coordinate_descent_tuning=True/TORCHINDUCTOR_COORDINATE_DESCENT_TUNING=1
  • Inference can see large improvements on AMD by utilising torch._inductor.config.freezing=True/TORCHINDUCTOR_FREEZING=1, which inlines weights as constants and enables constant folding optimisations.
  • Enabling inductor's cpp_wrapper may reduce overhead. This generates C++ code which launches Triton binaries directly with hipModuleLaunchKernel and relies on hipification. (Note: we are still failing a few tests regarding this feature.) - torch._inductor.config.cpp_wrapper=True/TORCHINDUCTOR_CPP_WRAPPER=1
  • For NHWC convolution workloads, torch._inductor.config.layout_optimization=True/TORCHINDUCTOR_LAYOUT_OPTIMIZATION=1 can help by enforcing the channels_last format throughout the graph, avoiding any additional transposes added by inductor. (Note: PYTORCH_MIOPEN_SUGGEST_NHWC=1 is recommended if using this.)
  • If you need to extract the Triton kernel, TORCH_COMPILE_DEBUG=1 creates a torch_compile_debug/ directory at the current path; in output_code.py the code strings for the Triton kernels are defined. Manual work is then required to strip out the kernel and create the kernel compilation/launch via Triton.
  • For advanced matmul/conv config tuning, the inductor-gemm-tuner can help. It implements the Triton conv/mm implementations used upstream and allows specifying the inputs and the config tuning search space; if new tunings are found, they can be added to the autotune list. More work needs to be done on parsing the results of this tuning.
    1. Example used for resnet152: HIP_FORCE_DEV_KERNARG=1 MIOPEN_FIND_ENFORCE=4 MIOPEN_FIND_MODE=1 TORCHINDUCTOR_COMPILE_THREADS=1 python bench.py --fp16 --kernel conv --input_file=input/models/resnet152/conv_perf_drop.json --config_file=configs/models/resnet152/tuned.json
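
A short sketch of how these knobs might be set programmatically (illustrative model and shapes; the knob names are as listed above, and their availability depends on the PyTorch version):

import torch
import torch._inductor.config as inductor_config

# Enable Triton gemm/conv lowerings and related tuning knobs (see the list above).
inductor_config.max_autotune = True
inductor_config.max_autotune_gemm_backends = "TRITON"   # skip rocBLAS/ATen candidates
inductor_config.coordinate_descent_tuning = True
inductor_config.freezing = True                         # inference only: inline weights as constants

model = torch.nn.Sequential(
    torch.nn.Linear(4096, 4096), torch.nn.ReLU(), torch.nn.Linear(4096, 4096)
).half().cuda().eval()

compiled = torch.compile(model)
with torch.no_grad():
    out = compiled(torch.randn(64, 4096, device="cuda", dtype=torch.float16))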

Debugging Memory Access Faults

Identifying the faulting kernel is often enough to triage a memory access fault. To that end, the rocm debug agent can trap a memory access fault and provide a dump of all active wavefronts that caused the error as well as the name of the kernel. The README provides full instructions, but to summarize:

  1. Compiling with -ggdb -O0 is recommended but not required.
  2. HSA_TOOLS_LIB=/opt/rocm/lib/librocm-debug-agent.so.2 HSA_ENABLE_DEBUG=1 ./my_program

When the debug agent traps the fault, it will produce extremely verbose output of all wavefront registers and memory content. Importantly, it also prints something like:

Disassembly for function vector_add_assert_trap(int*, int*, int*):
code object: file:////rocm-debug-agent/build/test/rocm-debug-agent-test#offset=14309&size=31336
loaded at: [0x7fd4f100c000-0x7fd4f100e070]

The kernel name and the code object file should be listed. In the example above, the kernel name is vector_add_assert_trap, but this might also look like:

Disassembly for function memory:///path/to/codeobject#offset=1234&size=567:

In this case, it is an in-memory kernel that was generated at runtime. Using the env var ROCM_DEBUG_AGENT_OPTIONS="--all --save-code-objects" will have the debug agent save all code objects to the current directory (use --save-code-objects=[DIR] to place them in another location). The code objects will be renamed from the URI format with special characters replaced by ‘_’. Use llvm-objdump to disassemble the indicated in-memory code object that has now been saved to disk. The name of the kernel is often found inside the disassembled code object.

llvm-objdump --disassemble-all path/to/code-object.co

It is recommended to disable various memory caching strategies both within the ROCm stack and PyTorch, where possible. This will give the debug agent the best chance at finding the memory fault where it originates, otherwise it could be masked by writing past the end of a cached block within a larger allocation.

PYTORCH_NO_HIP_MEMORY_CACHING=1
HSA_DISABLE_FRAGMENT_ALLOCATOR=1
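
If the repro is driven from Python, the environment from this section can be set in one place; a minimal sketch (repro.py is a hypothetical script):

import os
import subprocess

# Run a repro under the ROCm debug agent with memory caching disabled.
env = dict(os.environ,
           HSA_TOOLS_LIB="/opt/rocm/lib/librocm-debug-agent.so.2",
           HSA_ENABLE_DEBUG="1",
           ROCM_DEBUG_AGENT_OPTIONS="--all --save-code-objects",
           PYTORCH_NO_HIP_MEMORY_CACHING="1",
           HSA_DISABLE_FRAGMENT_ALLOCATOR="1")
subprocess.run(["python", "repro.py"], env=env, check=False)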

Miscellaneous

a. Force kernel arguments into device memory. HIP provides an environment variable, export HIP_FORCE_DEV_KERNARG=1, that can place the arguments of HIP kernels directly in device memory to reduce the latency of accessing kernel arguments. It can save 2 to 3 us for some kernels.

b. Set clocks for determinism. Use the command rocm-smi --setperfdeterminism 1900 to cap the max clock speed at 1900MHz instead of the default 2100MHz. Setting a lower cap reduces the chance of clock-speed drops caused by high chip temperature. The setting can be restored to default with rocm-smi -r.

c. Disable NUMA auto-balancing. Run the command cat /proc/sys/kernel/numa_balancing to check the current setting. An output of 0 indicates the desired setting. If there is no output or the output is 1, run the command sudo sh -c 'echo 0 > /proc/sys/kernel/numa_balancing' to set it.

Usage: ./env_check.sh [set/reset/check] (use ./env_check.sh -h for help info)

Script contents (you can also download it as env_check.sh):

#!/bin/bash

function print_usage {
	echo "    Usage: env_check.sh set/reset/check"
	echo "                      set: configure the settings in this script"
	echo "                      reset: reset to default settings"
	echo "                      check: check the current settings"
}

function set_env {
	export HIP_FORCE_DEV_KERNARG=1
	rocm-smi --setperfdeterminism 1900
	sudo sh -c 'echo 0 > /proc/sys/kernel/numa_balancing'
	sudo cpupower frequency-set -r -g performance
	sudo cpupower idle-set -d 1
}

function reset_env {
	unset HIP_FORCE_DEV_KERNARG
	rocm-smi -r
	sudo sh -c 'echo 1 > /proc/sys/kernel/numa_balancing'
}

function check_env {
	echo ""
	echo "---------------------------------------------------------------"
	echo ""

	# check the flag to force kernel to be on device memory
	echo "1. Check forcing kernel args on device memory"
	dev_kernarg=$(env | grep HIP_FORCE_DEV_KERNARG)
	if [ -z "$dev_kernarg" ]
	then
		echo "  no setting for forcing kernel args on device memory"
		echo "  run the command \"export HIP_FORCE_DEV_KERNARG=1\" to force it"
	else
		echo "  env var \"HIP_FORCE_DEV_KERNARG\" for forcing kernel args on device"
		echo "  memory is set, we have HIP_FORCE_DEV_KERNARG=" $HIP_FORCE_DEV_KERNARG
		if [ "$HIP_FORCE_DEV_KERNARG" -eq 0 ]
		then
			echo "  env var HIP_FORCE_DEV_KERNARG is 0, set it to 1 by:"
			echo "  command \"export HIP_FORCE_DEV_KERNARG=1\""
		fi
	fi

	echo ""
	echo ""
	echo "2. Set perfdeterminism, highest frequency"
	echo "  run the command \"rocm-smi -a | grep sclk\" to check highest frequency."
	echo "  you can run the command \"rocm-smi --setperfdeterminism # (e.g. 1900)\" to"
	echo "  set clock frequency limit to get minimal performance, which is more reproducible"
	echo "  you can restore the setting by running \"rocm-smi --resetperfdeterminism\""
	
	echo ""
	echo ""
	echo "3. Check numa autobalance"
	autobal=$(cat /proc/sys/kernel/numa_balancing)
	if [ $autobal -ne 0 ]
	then
		echo "  run the command \"sudo sh -c \'echo 0 > /proc/sys/kernel/numa_balancing\'\""
		echo "  to set numa autobalance". 
		echo "  you can disable it with \"sudo sh -c \'echo 1 > /proc/sys/kernel/numa_balancing\'\""
	else
		echo "  numa autobalance is checked with:"
		echo "  (cat /proc/sys/kernel/numa_balancing)=0"
	fi

	echo ""
	echo "---------------------------------------------------------------"
	echo ""
}


if [ $# -eq 0 ]
then
	echo "   \"env_set.sh -h\" for help info"
	print_usage
	exit 1
fi

input=$1
if [ $1 == "set" ]
then
	set_env
elif [ $1 == "reset" ]
then
	reset_env
elif [ $1 == "check" ]
then
	check_env
else
	print_usage
fi

Note (names are in alphabetical order): Jason Furmanek, Vinayak Gokhale, Jack Taylor, Peng Sun, Simon Waters, Shucai Xiao, Lei Zhang, and Lixun Zhang contributed to this content together.