Registers, Best Practices

Author: Sriram Govindan (@s_gowindone)
As GPU kernel developers, the importance of registers is not lost on us. Registers provide the fastest access times, they are the fundamental unit behind several instructions (MMA, warp shuffle, warp broadcast, ...), and they are quite numerous (hundreds are available per thread).
Yet, for all their versatility, registers can also cause massive performance drops. This happens through improper access, which the compiler will try to work around by spilling to slower memory regions.
These problems happen often enough that in this article we will go over the basics of registers and several of the ways you can accidentally cause spills.
In the next article we will dive deeper into the lower level representations of registers in SASS, and the optimizations that can be done there.

- The Different Kinds of Registers
- Allocation & Access
- Register Spillage
- Diagnosing Register spills
- Concluding remarks
- Resources
The Different Kinds of Registers
This article will require reading a decent amount of IR. In particular, we will be looking at PTX generated by NVCC with CUDA 13, and CDNA4 IR generated through the Mojo stack.
To make things easier, let's quickly go over the different types of registers we will encounter in both instruction sets.
AMD Registers
There are 2 main types of registers we will be focusing on: scalar general purpose registers (SGPRs) and vector general purpose registers (VGPRs), each of which is 32 bits wide.
Despite the nomenclature, both kinds of registers can be indexed in a vectorized manner. When looking at the IR you will often see vgpr and sgpr followed by the indexing scheme [start:end], which simply means that the instruction applies to all registers in the inclusive range start to end.
For example, this instruction loads 4 32-bit values into scalar registers 0 through 3 from a global memory address held in registers 4 and 5 (plus an immediate offset of 0x8).
s_load_dwordx4 s[0:3], s[4:5], 0x8
The real difference between scalar and vector registers is their accessibility. Scalar registers are fewer in number, with anywhere between 16 and 102 available (as of CDNA4), each shared among a wave (64 threads). VGPRs are much more numerous, with up to 512 registers available per thread.
Nvidia Registers
Nvidia's PTX tends to be higher level than CDNA IR, consequently masking away the ownership level of each register. However, when looking at the lower level representation (SASS) we can see that each register is backed by either a Uniform General Purpose Register or a regular General Purpose Register, the former serving an equivalent role to the SGPR.
For the purpose of this article we will focus on what's present in the PTX. In PTX, registers are distinguished by the % symbol, one or more letters, and a number. The letters represent the type of the register:
- f: float
- rd: 64-bit register
- r: 32-bit register
- p: predicate register
- ...
The number gives a unique ID for easy referencing.
Allocation & Access
With a general idea of the registers available, let's start understanding how the compiler uses them.
Registers back most of the variables in your kernel code, and the best way to confirm this is to inspect the generated IR.
For example, let's take a look at this:
C++ Code
struct Foo {
    float bar;
    float baz;
    float buzz;

    __host__ __device__ Foo(
        const float bar,
        const float baz,
        const float buzz): bar(bar), baz(baz), buzz(buzz) {}

    __host__ __device__ void scale() {
        bar *= 108.8f;
        baz *= 108.8f;
        buzz *= 108.f;
    }
};

__global__ void test_kernel(
    float * result,
    const float bar,
    const float baz,
    const float buzz) {
    const float new_bar{bar + 10};
    const float new_baz{baz + 10};
    const float new_buzz{buzz + 10};
    float arr[3];
    Foo foo(new_bar, new_baz, new_buzz);
    foo.scale();
    arr[0] = foo.bar * 0.5f;
    arr[1] = foo.baz * 0.8f;
    arr[2] = foo.buzz * 0.9f;
    result[0] = arr[0];
    result[1] = arr[1];
    result[2] = arr[2];
}
Equivalent Mojo code
struct Foo:
    var bar: Float32
    var baz: Float32
    var buzz: Float32

    fn __init__(out self, bar: Float32, baz: Float32, buzz: Float32):
        self.bar = bar
        self.baz = baz
        self.buzz = buzz

    fn scale(mut self):
        self.bar *= 108.8
        self.baz *= 108.8
        self.buzz *= 108.0

fn kernel(
    result: Span[Float32, MutAnyOrigin],
    bar: Float32,
    baz: Float32,
    buzz: Float32,
):
    var new_bar = bar + 10
    var new_baz = baz + 10
    var new_buzz = buzz + 10
    var arr = StaticTuple[Float32, 3](fill=1.0)
    var foo = Foo(
        new_bar,
        new_baz,
        new_buzz,
    )
    foo.scale()
    arr[0] = foo.bar + 0.5
    arr[1] = foo.baz + 0.8
    arr[2] = foo.buzz + 0.9
    result[0] = arr[0]
    result[1] = arr[1]
    result[2] = arr[2]
PTX
.visible .entry test_kernel(float*, float, float, float)(
.param .u64 test_kernel(float*, float, float, float)_param_0,
.param .f32 test_kernel(float*, float, float, float)_param_1,
.param .f32 test_kernel(float*, float, float, float)_param_2,
.param .f32 test_kernel(float*, float, float, float)_param_3
)
{
ld.param.u64 %rd1, [test_kernel(float*, float, float, float)_param_0];
ld.param.f32 %f1, [test_kernel(float*, float, float, float)_param_1];
ld.param.f32 %f2, [test_kernel(float*, float, float, float)_param_2];
ld.param.f32 %f3, [test_kernel(float*, float, float, float)_param_3];
cvta.to.global.u64 %rd2, %rd1;
add.f32 %f4, %f1, 0f41200000;
add.f32 %f5, %f2, 0f41200000;
add.f32 %f6, %f3, 0f41200000;
mul.f32 %f7, %f4, 0f42D9999A;
mul.f32 %f8, %f5, 0f42D9999A;
mul.f32 %f9, %f6, 0f42D80000;
mul.f32 %f10, %f7, 0f3F000000;
mul.f32 %f11, %f8, 0f3F4CCCCD;
mul.f32 %f12, %f9, 0f3F666666;
st.global.f32 [%rd2], %f10;
st.global.f32 [%rd2+4], %f11;
st.global.f32 [%rd2+8], %f12;
ret;
}
CDNA4 IR
s_load_dwordx4 s[4:7], s[0:1], 0x10
s_load_dwordx2 s[2:3], s[0:1], 0x0
v_mov_b32_e32 v0, 0x41200000
v_mov_b32_e32 v2, 0x3f666666
s_mov_b32 s0, 0x42d9999a
s_waitcnt lgkmcnt(0)
v_pk_add_f32 v[4:5], s[4:5], v[0:1] op_sel_hi:[1,0]
v_add_f32_e32 v0, s6, v0
v_fmac_f32_e32 v2, 0x42d80000, v0
v_mov_b32_e32 v0, 0.5
v_mov_b32_e32 v1, 0x3f4ccccd
v_mov_b32_e32 v3, 0
v_pk_fma_f32 v[0:1], v[4:5], s[0:1], v[0:1] op_sel_hi:[1,0,1]
global_store_dwordx3 v3, v[0:2], s[2:3]
s_endpgm
IR Analysis
Looking back at the original code, we can see that we have 3 different sets of variables:
- scalar variables: new_bar, new_baz and new_buzz
- arrays: arr
- structs: foo
If we look at the IR we can see that all of these are held in registers:
On Nvidia, the scalar variables use registers %f4 to %f6:
add.f32 %f4, %f1, 0f41200000;
add.f32 %f5, %f2, 0f41200000;
add.f32 %f6, %f3, 0f41200000;
The struct uses registers %f7 to %f9:
mul.f32 %f7, %f4, 0f42D9999A;
mul.f32 %f8, %f5, 0f42D9999A;
mul.f32 %f9, %f6, 0f42D80000;
Finally, the array uses registers %f10 to %f12:
mul.f32 %f10, %f7, 0f3F000000;
mul.f32 %f11, %f8, 0f3F4CCCCD;
mul.f32 %f12, %f9, 0f3F666666;
On AMD these data structures get fused together, but we can still see that they never leave registers.
In the first two lines we load our kernel arguments into registers. Scalar registers 4 through 7 hold the arguments bar, baz and buzz (we load 4 words since a 3-word load is not available), and scalar registers 2 and 3 hold the address of result.
s_load_dwordx4 s[4:7], s[0:1], 0x10
s_load_dwordx2 s[2:3], s[0:1], 0x0
We move 10.0 into vector register 0, 0.9 into vector register 2 and 108.8 into scalar register 0.
v_mov_b32_e32 v0, 0x41200000
v_mov_b32_e32 v2, 0x3f666666
s_mov_b32 s0, 0x42d9999a
Next we add 10.0 to bar, baz and buzz, storing the values in vector registers 4, 5, and 0 respectively.
v_pk_add_f32 v[4:5], s[4:5], v[0:1] op_sel_hi:[1,0]
v_add_f32_e32 v0, s6, v0
buzz is then multiplied by 108 and accumulated into vector register 2, which already holds 0.9. Then 0.5, 0.8 and 0 are moved into vector registers 0, 1, and 3.
v_fmac_f32_e32 v2, 0x42d80000, v0
v_mov_b32_e32 v0, 0.5
v_mov_b32_e32 v1, 0x3f4ccccd
v_mov_b32_e32 v3, 0
We multiply both bar and baz by 108.8, then add 0.5 to bar and 0.8 to baz. Finally, we store the result back into global memory.
v_pk_fma_f32 v[0:1], v[4:5], s[0:1], v[0:1] op_sel_hi:[1,0,1]
global_store_dwordx3 v3, v[0:2], s[2:3]
This is the best-case scenario: the compiler places all of these variables in registers. It is also an overly simplistic example, so let's discuss some common coding practices that often lead to spillage.
Register Spillage
Spilling is done by the compiler to ensure that an instruction can be completed without exceeding the hardware limits on available registers. If you're lucky the compiler can choose to spill somewhere friendly like shared or constant memory, but the most likely destination tends to be local memory.
Before we discuss what causes registers to spill, let's first understand local memory.
Local Memory Explained
Local memory is a section of global memory that is private to each thread. On Nvidia, you can tell local memory is being used because the local qualifier will be present.
For example: st.local.v4.f32, ld.local.v4.f32.
If you notice ld and st instructions in areas where there shouldn't be any, there is also a good chance data is being spilled to local memory.
On AMD local memory is referred to as scratch memory. You can tell scratch memory is being accessed because the word scratch will be prepended to the instruction.
In some cases buffer instructions such as buffer loads and stores may also be writing to scratch memory. If you see an unexpected buffer instruction, be wary of its storage destination.
Since local memory is part of global memory, access times are much longer. This can cause severe performance hits when registers unexpectedly spill; in my case I've even seen performance drop by half because of it.
Using too many Registers
Now that we understand where the data spills to, what causes it to spill? Data can spill for many reasons, but one of the big ones is simply using too many registers.
Data can spill to local memory if you use more registers than are available. It's important to understand how many registers your hardware provides, and to work within those limits. This becomes especially important when working with various hardware units (ldmatrix, stmatrix, MMA) that rely heavily on registers.
Here are the available registers per SM across popular Nvidia GPUs:
| GPU | Architecture | Compute Capability | Registers per SM | Max Registers per Thread |
|---|---|---|---|---|
| A100 | Ampere | 8.0 | 64K (65,536) 32-bit | 255 |
| H100 | Hopper | 9.0 | 64K (65,536) 32-bit | 255 |
| B200 | Blackwell | 10.0 | 64K (65,536) 32-bit | 255 |
And here are the available registers per CU across popular AMD GPUs:
| GPU | Architecture | Max Registers per Thread |
|---|---|---|
| MI300X | CDNA3 | 512 VGPR/AGPR |
| MI355X | CDNA4 | 512 VGPR/AGPR |
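These limits, and the register count a compiled kernel actually uses, can also be queried at runtime through the CUDA runtime API. Here is a minimal sketch; my_kernel is a hypothetical stand-in for whatever kernel you want to inspect:

```cpp
#include <cstdio>
#include <cuda_runtime.h>

__global__ void my_kernel(float* out) { out[threadIdx.x] = 0.f; }  // hypothetical stand-in kernel

int main() {
    cudaDeviceProp prop{};
    cudaGetDeviceProperties(&prop, /*device=*/0);
    std::printf("registers per SM:    %d\n", prop.regsPerMultiprocessor);
    std::printf("registers per block: %d\n", prop.regsPerBlock);

    cudaFuncAttributes attr{};
    cudaFuncGetAttributes(&attr, my_kernel);  // per-thread register usage of the compiled kernel
    std::printf("registers per thread used by my_kernel: %d\n", attr.numRegs);
    return 0;
}
```

Compiling with -Xptxas=--verbose (covered later in the diagnostics section) reports the same per-kernel register counts at build time.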
And here is an example where we clearly go over that limit:
C++ Code
__global__ void test_kernel(float * data, float * result, int iters) {
    float arr_one[500];
    float arr_two[500];
    for (int i{0}; i < 500; ++i)
        arr_one[i] = data[i];
    for (int i{0}; i < 500; ++i)
        arr_two[i] = data[i + 500];
    for (int iter{0}; iter < iters; iter++){
        for (int i{0}; i < 500; ++i)
            result[i] += arr_one[i] * arr_two[i];
    }
}
Equivalent Mojo code
@always_inline
fn kernel(
    data: Span[Float32, MutAnyOrigin],
    result: Span[Float32, MutAnyOrigin],
    iters: Int
):
    var arr_one = StaticTuple[Float32, 1000]()
    var arr_two = StaticTuple[Float32, 1000]()
    for i in range(1000):
        arr_one[i] = data[i]
    for i in range(1000):
        arr_two[i] = data[i + 1000]
    for i in range(iters):
        for j in range(1000):
            result[i] += arr_one[j] * arr_two[j]
Nvidia IR
.visible .func (.param .b32 func_retval0) __cudaCDP2Malloc(
.param .b64 __cudaCDP2Malloc_param_0,
.param .b64 __cudaCDP2Malloc_param_1
)
{
mov.u32 %r1, 999;
st.param.b32 [func_retval0+0], %r1;
ret;
}
.visible .entry test_kernel(float*, float*, int)(
.param .u64 test_kernel(float*, float*, int)_param_0,
.param .u64 test_kernel(float*, float*, int)_param_1,
.param .u32 test_kernel(float*, float*, int)_param_2
)
{
mov.u64 %SPL, __local_depot1;
ld.param.u64 %rd12, [test_kernel(float*, float*, int)_param_0];
ld.param.u64 %rd11, [test_kernel(float*, float*, int)_param_1];
ld.param.u32 %r9, [test_kernel(float*, float*, int)_param_2];
cvta.to.global.u64 %rd1, %rd12;
add.u64 %rd2, %SPL, 0;
mov.u32 %r14, 0;
$L__BB1_1:
mul.wide.s32 %rd14, %r14, 4;
add.s64 %rd15, %rd1, %rd14;
add.s64 %rd16, %rd2, %rd14;
ld.global.f32 %f1, [%rd15+12];
ld.global.f32 %f2, [%rd15+8];
ld.global.f32 %f3, [%rd15+4];
ld.global.f32 %f4, [%rd15];
st.local.v4.f32 [%rd16], {%f4, %f3, %f2, %f1};
add.s32 %r14, %r14, 4;
setp.ne.s32 %p1, %r14, 500;
@%p1 bra $L__BB1_1;
cvta.to.global.u64 %rd3, %rd11;
add.u64 %rd4, %SPL, 2000;
mov.u32 %r15, 0;
$L__BB1_3:
mul.wide.s32 %rd18, %r15, 4;
add.s64 %rd19, %rd1, %rd18;
add.s64 %rd20, %rd4, %rd18;
ld.global.f32 %f5, [%rd19+2012];
ld.global.f32 %f6, [%rd19+2008];
ld.global.f32 %f7, [%rd19+2004];
ld.global.f32 %f8, [%rd19+2000];
st.local.v4.f32 [%rd20], {%f8, %f7, %f6, %f5};
add.s32 %r15, %r15, 4;
setp.ne.s32 %p2, %r15, 500;
@%p2 bra $L__BB1_3;
setp.lt.s32 %p3, %r9, 1;
@%p3 bra $L__BB1_9;
mov.u32 %r12, 0;
mov.u32 %r16, %r12;
$L__BB1_6:
mov.u64 %rd21, %rd2;
mov.u64 %rd22, %rd4;
mov.u64 %rd23, %rd3;
mov.u32 %r17, %r12;
$L__BB1_7:
ld.local.v4.f32 {%f9, %f10, %f11, %f12}, [%rd21];
ld.local.v4.f32 {%f17, %f18, %f19, %f20}, [%rd22];
ld.global.f32 %f25, [%rd23];
fma.rn.f32 %f26, %f9, %f17, %f25;
st.global.f32 [%rd23], %f26;
ld.global.f32 %f27, [%rd23+4];
fma.rn.f32 %f28, %f10, %f18, %f27;
st.global.f32 [%rd23+4], %f28;
ld.global.f32 %f29, [%rd23+8];
fma.rn.f32 %f30, %f11, %f19, %f29;
st.global.f32 [%rd23+8], %f30;
ld.global.f32 %f31, [%rd23+12];
fma.rn.f32 %f32, %f12, %f20, %f31;
st.global.f32 [%rd23+12], %f32;
add.s64 %rd23, %rd23, 16;
add.s64 %rd22, %rd22, 16;
add.s64 %rd21, %rd21, 16;
add.s32 %r17, %r17, 4;
setp.ne.s32 %p4, %r17, 500;
@%p4 bra $L__BB1_7;
add.s32 %r16, %r16, 1;
setp.lt.s32 %p5, %r16, %r9;
@%p5 bra $L__BB1_6;
$L__BB1_9:
ret;
}
CDNA4 IR
s_load_dwordx2 s[2:3], s[0:1], 0x0
s_load_dwordx4 s[4:7], s[0:1], 0x10
s_movk_i32 s12, 0xfa4
s_mov_b64 s[8:9], 0x3e9
s_waitcnt lgkmcnt(0)
s_mov_b64 s[10:11], s[2:3]
.LBB0_1:
s_load_dword s13, s[10:11], 0x0
s_waitcnt lgkmcnt(0)
v_mov_b32_e32 v0, s13
scratch_store_dword off, v0, s12
s_add_i32 s12, s12, 4
s_add_u32 s10, s10, 4
s_addc_u32 s11, s11, 0
s_add_u32 s8, s8, -1
s_addc_u32 s9, s9, -1
v_cmp_gt_u64_e64 s[14:15], s[8:9], 1
s_and_b64 vcc, exec, s[14:15]
s_cbranch_vccnz .LBB0_1
s_add_u32 s2, s2, 0xfa0
s_addc_u32 s3, s3, 0
s_mov_b32 s10, 4
s_mov_b64 s[8:9], 0x3e9
.LBB0_3:
s_load_dword s11, s[2:3], 0x0
s_waitcnt lgkmcnt(0)
v_mov_b32_e32 v0, s11
scratch_store_dword off, v0, s10
s_add_i32 s10, s10, 4
s_add_u32 s8, s8, -1
s_addc_u32 s9, s9, -1
s_add_u32 s2, s2, 4
v_cmp_gt_u64_e64 s[12:13], s[8:9], 1
s_addc_u32 s3, s3, 0
s_and_b64 vcc, exec, s[12:13]
s_cbranch_vccnz .LBB0_3
s_load_dwordx2 s[0:1], s[0:1], 0x20
s_waitcnt lgkmcnt(0)
v_cmp_lt_i64_e64 s[2:3], s[0:1], 1
s_and_b64 vcc, exec, s[2:3]
s_cbranch_vccnz .LBB0_9
v_cmp_gt_i64_e64 s[2:3], s[0:1], 0
s_and_b64 s[2:3], s[2:3], exec
s_cselect_b32 s3, s1, 0
s_cselect_b32 s2, s0, 0
v_mov_b32_e32 v0, 0
s_mov_b64 s[8:9], s[2:3]
.LBB0_6:
v_mov_b64_e32 v[2:3], s[8:9]
s_sub_u32 s10, s2, s8
v_cmp_lt_i64_e32 vcc, s[0:1], v[2:3]
s_subb_u32 s11, s3, s9
s_and_b64 s[12:13], vcc, exec
s_cselect_b32 s13, s7, 0
s_cselect_b32 s12, s6, 0
s_lshl_b64 s[10:11], s[10:11], 2
s_add_u32 s14, s4, s10
s_addc_u32 s15, s5, s11
s_lshl_b64 s[10:11], s[12:13], 2
s_add_u32 s10, s14, s10
s_addc_u32 s11, s15, s11
global_load_dword v1, v0, s[10:11]
s_mov_b32 s14, 4
s_movk_i32 s15, 0xfa4
s_mov_b64 s[12:13], 0x3e9
.LBB0_7:
scratch_load_dword v2, off, s15
scratch_load_dword v3, off, s14
s_add_i32 s14, s14, 4
s_add_i32 s15, s15, 4
s_add_u32 s12, s12, -1
s_addc_u32 s13, s13, -1
v_cmp_lt_u64_e64 s[16:17], s[12:13], 2
s_and_b64 vcc, exec, s[16:17]
s_waitcnt vmcnt(0)
v_fmac_f32_e32 v1, v2, v3
s_cbranch_vccz .LBB0_7
s_add_u32 s12, s8, -1
s_addc_u32 s13, s9, -1
v_cmp_lt_u64_e64 s[8:9], s[8:9], 2
s_and_b64 vcc, exec, s[8:9]
s_mov_b64 s[8:9], s[12:13]
global_store_dword v0, v1, s[10:11]
s_cbranch_vccz .LBB0_6
.LBB0_9:
s_endpgm
In the PTX lines such as
st.local.v4.f32 [%rd16], {%f4, %f3, %f2, %f1};
st.local.v4.f32 [%rd20], {%f8, %f7, %f6, %f5};
and
ld.local.v4.f32 {%f9, %f10, %f11, %f12}, [%rd21];
ld.local.v4.f32 {%f17, %f18, %f19, %f20}, [%rd22];
make it evident that local memory is being written to and read from. In particular, the backings of arr_one and arr_two appear to live entirely in local memory.
The same goes for AMD, through the use of scratch instructions:
scratch_store_dword off, v0, s12
scratch_store_dword off, v0, s10
scratch_load_dword v2, off, s15
scratch_load_dword v3, off, s14
While this is an obvious case of going over the limit, it's still possible to exceed register limits even if your initial allocation is small. This mainly occurs through unexpected copies, usually induced by function calls.
Solutions
Tiling
The easiest way to ensure you are not using more registers than are available is to use an iterative approach. Load your data into registers in segments, perform your calculations, and replace the data in the registers with the next set.
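For instance, here is a minimal sketch of the pattern (a hypothetical kernel; TILE is an assumption you would size so the per-thread working set stays comfortably under the register budget):

```cpp
__global__ void tiled_accumulate(const float* data, float* result, int per_thread) {
    constexpr int TILE = 8;                       // assumed chunk size
    const float* my_data = data + threadIdx.x * per_thread;

    float acc = 0.f;
    for (int base = 0; base + TILE <= per_thread; base += TILE) {
        float tile[TILE];                         // register-backed working set for this chunk

        #pragma unroll
        for (int i = 0; i < TILE; ++i)            // load one chunk using compile-time indices
            tile[i] = my_data[base + i];

        #pragma unroll
        for (int i = 0; i < TILE; ++i)            // compute, then the same registers get reused
            acc += tile[i] * tile[i];
    }
    result[threadIdx.x] = acc;
}
```

Only TILE floats are live in registers at any point, no matter how large per_thread gets (the tail elements are left out of this sketch for brevity).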
Libraries such as CUTLASS rely heavily on this strategy through the use of tiled layouts and slicing.
For example, take this code snippet adapted from here. For every iteration of the for loop, data is loaded into register groups tXrA (register group A for thread X) and tXrB (register group B for thread X). These registers then get overwritten on every iteration of the while loop, allowing the program to reuse registers and stay under the limit.
CUTE_NO_UNROLL
while (k_tile_count > -(K_PIPE_MAX-1))
{
    tXsA_p = tXsA(_,_,_,smem_pipe_read);
    tXsB_p = tXsB(_,_,_,smem_pipe_read);

    CUTE_UNROLL
    for (int k_block = 0; k_block < K_BLOCK_MAX; ++k_block)
    {
        // Load A, B shmem->regs for k_block+1
        auto k_block_next = (k_block + Int<1>{}) % K_BLOCK_MAX;  // static
        copy(s2r_atom_a, tXsA_p(_,_,k_block_next), tXrA(_,_,k_block_next));
        copy(s2r_atom_b, tXsB_p(_,_,k_block_next), tXrB(_,_,k_block_next));
        // Copy gmem to smem before computing gemm on each k-pipe
        // Thread-level register gemm for k_block
        gemm(mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC);
    }

    smem_pipe_read += 1;
}
Register Reallocation
On Hopper and above, instructions such as setmaxnreg allow developers to allocate more registers to a warpgroup (4 warps, 128 threads) by taking away registers from all other threads in the block.
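Here is a rough sketch of how this is typically wrapped, mirroring the helpers found in libraries like CUTLASS; the specific register counts below are assumptions you would tune per kernel:

```cpp
// Warpgroup-wide register reallocation on sm_90+. The count must be an immediate,
// a multiple of 8, and every warp in the warpgroup must execute the instruction.
template <int RegCount>
__device__ void warpgroup_reg_alloc() {
    asm volatile("setmaxnreg.inc.sync.aligned.u32 %0;" :: "n"(RegCount));
}

template <int RegCount>
__device__ void warpgroup_reg_dealloc() {
    asm volatile("setmaxnreg.dec.sync.aligned.u32 %0;" :: "n"(RegCount));
}

// Typical usage (hypothetical counts): producer warps give registers up,
// consumer warps claim them.
//   if (is_producer_warpgroup) warpgroup_reg_dealloc<40>();
//   else                       warpgroup_reg_alloc<232>();
```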
Runtime indexing
The next way data can spill is through runtime indexing.
When working with arrays, you must ensure that the indices used to access the array are known at compile time. This is because GPU arrays (when accessed correctly) are just a grouping of registers, and registers are not addressable. If you index a memory-backed C-style array with a runtime value j, the compiler knows to grab the data at the array's base address + j. If you did the same with a register-backed GPU array, the compiler would have to emit instructions that determine which register corresponds to index j. Following that strategy would result in instruction bloat and performance degradation, so for performance and simplicity the compiler will generally choose to back the array with local memory instead.
On the other hand, if the indices are known at compile time, the compiler can systematically replace each index access with the corresponding register.

In this example we can see both runtime and compile-time indexing in action:
C++ Code
__global__ void test_kernel(
    const float * data,
    float * result,
    int K,
    int BK
) {
    int iters{K / BK};
    constexpr int len{20};
    float arr[len];
    for (int i{0}; i < iters; ++i){
        for (int j{0}; j < BK; ++j){
            arr[iters * BK + j] = data[iters * BK + j];
        }
    }
    #pragma unroll
    for (int i{0}; i < len; ++i){
        result[i] = arr[i] * i;
    }
}
Equivalent Mojo code
@always_inline
fn kernel(
    data: Span[Float32, MutAnyOrigin],
    result: Span[Float32, MutAnyOrigin],
    k: UInt32,
    BK: UInt32
):
    var iters = k / BK
    comptime len = 20
    var arr = StaticTuple[Float32, len]()
    for i in range(iters):
        for j in range(BK):
            arr[i * BK + j] = data[UInt(i * BK + j)]
    @parameter
    for i in range(len):
        result[UInt(i)] = arr[i] * i
PTX
.visible .entry test_kernel(float const*, float*, int, int)(
.param .u64 test_kernel(float const*, float*, int, int)_param_0,
.param .u64 test_kernel(float const*, float*, int, int)_param_1,
.param .u32 test_kernel(float const*, float*, int, int)_param_2,
.param .u32 test_kernel(float const*, float*, int, int)_param_3
)
{
mov.u64 %SPL, __local_depot1;
ld.param.u64 %rd11, [test_kernel(float const*, float*, int, int)_param_0];
ld.param.u64 %rd10, [test_kernel(float const*, float*, int, int)_param_1];
ld.param.u32 %r14, [test_kernel(float const*, float*, int, int)_param_2];
ld.param.u32 %r13, [test_kernel(float const*, float*, int, int)_param_3];
cvta.to.global.u64 %rd1, %rd11;
add.u64 %rd2, %SPL, 0;
div.s32 %r1, %r14, %r13;
setp.lt.s32 %p1, %r1, 1;
mov.f32 %f146, 0f7FC00000;
mov.f32 %f147, 0f7FC00000;
mov.f32 %f148, 0f7FC00000;
mov.f32 %f149, 0f7FC00000;
mov.f32 %f150, 0f7FC00000;
mov.f32 %f151, 0f7FC00000;
mov.f32 %f152, 0f7FC00000;
mov.f32 %f153, 0f7FC00000;
mov.f32 %f154, 0f7FC00000;
mov.f32 %f155, 0f7FC00000;
mov.f32 %f156, 0f7FC00000;
mov.f32 %f157, 0f7FC00000;
mov.f32 %f158, 0f7FC00000;
mov.f32 %f159, 0f7FC00000;
mov.f32 %f160, 0f7FC00000;
mov.f32 %f161, 0f7FC00000;
mov.f32 %f162, 0f7FC00000;
mov.f32 %f163, 0f7FC00000;
mov.f32 %f165, 0f7FC00000;
@%p1 bra $L__BB1_12;
setp.lt.s32 %p2, %r13, 1;
mul.lo.s32 %r2, %r1, %r13;
@%p2 bra $L__BB1_12;
add.s32 %r3, %r13, -1;
and.b32 %r4, %r13, 3;
sub.s32 %r5, %r13, %r4;
mul.wide.s32 %rd3, %r2, 4;
mov.u32 %r15, 0;
mov.u32 %r19, %r15;
$L__BB1_3:
setp.lt.u32 %p3, %r3, 3;
mov.u32 %r22, %r15;
@%p3 bra $L__BB1_6;
mov.u32 %r22, 0;
mov.u64 %rd17, %rd1;
mov.u64 %rd18, %rd2;
mov.u32 %r21, %r5;
$L__BB1_5:
add.s64 %rd13, %rd17, %rd3;
ld.global.f32 %f81, [%rd13];
add.s64 %rd14, %rd18, %rd3;
st.local.f32 [%rd14], %f81;
ld.global.f32 %f82, [%rd13+4];
st.local.f32 [%rd14+4], %f82;
ld.global.f32 %f83, [%rd13+8];
st.local.f32 [%rd14+8], %f83;
ld.global.f32 %f84, [%rd13+12];
st.local.f32 [%rd14+12], %f84;
add.s32 %r22, %r22, 4;
add.s64 %rd18, %rd18, 16;
add.s64 %rd17, %rd17, 16;
add.s32 %r21, %r21, -4;
setp.ne.s32 %p4, %r21, 0;
@%p4 bra $L__BB1_5;
$L__BB1_6:
setp.eq.s32 %p5, %r4, 0;
@%p5 bra $L__BB1_10;
setp.eq.s32 %p6, %r4, 1;
add.s32 %r18, %r22, %r2;
mul.wide.s32 %rd15, %r18, 4;
add.s64 %rd8, %rd1, %rd15;
ld.global.f32 %f85, [%rd8];
add.s64 %rd9, %rd2, %rd15;
st.local.f32 [%rd9], %f85;
@%p6 bra $L__BB1_10;
setp.eq.s32 %p7, %r4, 2;
ld.global.f32 %f86, [%rd8+4];
st.local.f32 [%rd9+4], %f86;
@%p7 bra $L__BB1_10;
ld.global.f32 %f87, [%rd8+8];
st.local.f32 [%rd9+8], %f87;
$L__BB1_10:
add.s32 %r19, %r19, 1;
setp.lt.s32 %p8, %r19, %r1;
@%p8 bra $L__BB1_3;
ld.local.v4.f32 {%f88, %f164, %f90, %f91}, [%rd2];
ld.local.v4.f32 {%f95, %f96, %f97, %f98}, [%rd2+16];
ld.local.v4.f32 {%f103, %f104, %f105, %f106}, [%rd2+32];
ld.local.v4.f32 {%f111, %f112, %f113, %f114}, [%rd2+48];
ld.local.v4.f32 {%f119, %f120, %f121, %f122}, [%rd2+64];
mul.f32 %f165, %f88, 0f00000000;
add.f32 %f163, %f90, %f90;
mul.f32 %f162, %f91, 0f40400000;
mul.f32 %f161, %f95, 0f40800000;
mul.f32 %f160, %f96, 0f40A00000;
mul.f32 %f159, %f97, 0f40C00000;
mul.f32 %f158, %f98, 0f40E00000;
mul.f32 %f157, %f103, 0f41000000;
mul.f32 %f156, %f104, 0f41100000;
mul.f32 %f155, %f105, 0f41200000;
mul.f32 %f154, %f106, 0f41300000;
mul.f32 %f153, %f111, 0f41400000;
mul.f32 %f152, %f112, 0f41500000;
mul.f32 %f151, %f113, 0f41600000;
mul.f32 %f150, %f114, 0f41700000;
mul.f32 %f149, %f119, 0f41800000;
mul.f32 %f148, %f120, 0f41880000;
mul.f32 %f147, %f121, 0f41900000;
mul.f32 %f146, %f122, 0f41980000;
$L__BB1_12:
cvta.to.global.u64 %rd16, %rd10;
st.global.f32 [%rd16], %f165;
st.global.f32 [%rd16+4], %f164;
st.global.f32 [%rd16+8], %f163;
st.global.f32 [%rd16+12], %f162;
st.global.f32 [%rd16+16], %f161;
st.global.f32 [%rd16+20], %f160;
st.global.f32 [%rd16+24], %f159;
st.global.f32 [%rd16+28], %f158;
st.global.f32 [%rd16+32], %f157;
st.global.f32 [%rd16+36], %f156;
st.global.f32 [%rd16+40], %f155;
st.global.f32 [%rd16+44], %f154;
st.global.f32 [%rd16+48], %f153;
st.global.f32 [%rd16+52], %f152;
st.global.f32 [%rd16+56], %f151;
st.global.f32 [%rd16+60], %f150;
st.global.f32 [%rd16+64], %f149;
st.global.f32 [%rd16+68], %f148;
st.global.f32 [%rd16+72], %f147;
st.global.f32 [%rd16+76], %f146;
ret;
}
In the first loop, all the indices are runtime based. This results in all our stores being spilled to local memory.
st.local.f32 [%rd14], %f81;
ld.global.f32 %f82, [%rd13+4];
st.local.f32 [%rd14+4], %f82;
ld.global.f32 %f83, [%rd13+8];
st.local.f32 [%rd14+8], %f83;
ld.global.f32 %f84, [%rd13+12];
st.local.f32 [%rd14+12], %f84;
In the second loop, all our indices are compile-time based. To make up for the earlier spills, the compiler moves all the data back into its designated registers during the multiplication step.
// load from local memory
ld.local.v4.f32 {%f88, %f164, %f90, %f91}, [%rd2];
ld.local.v4.f32 {%f95, %f96, %f97, %f98}, [%rd2+16];
ld.local.v4.f32 {%f103, %f104, %f105, %f106}, [%rd2+32];
ld.local.v4.f32 {%f111, %f112, %f113, %f114}, [%rd2+48];
ld.local.v4.f32 {%f119, %f120, %f121, %f122}, [%rd2+64];
// multiply and store back to designated register
mul.f32 %f162, %f91, 0f40400000;
mul.f32 %f161, %f95, 0f40800000;
mul.f32 %f160, %f96, 0f40A00000;
...
From there, the data gets written back to global memory:
st.global.f32 [%rd16+12], %f162;
...
Runtime Register Indexing
On the AMD side of things, some cool optimizations take place that help us avoid the register spillage.
CDNA4 IR
s_load_dwordx2 s[4:5], s[0:1], 0x20
s_load_dwordx2 s[2:3], s[0:1], 0x10
s_waitcnt lgkmcnt(0)
v_cvt_f32_u32_e32 v0, s5
s_cmp_gt_u32 s5, s4
v_rcp_iflag_f32_e32 v0, v0
s_nop 0
v_mul_f32_e32 v0, 0x4f7ffffe, v0
v_cvt_u32_f32_e32 v0, v0
s_nop 0
v_readfirstlane_b32 s6, v0
s_cbranch_scc1 .LBB0_6
s_sub_i32 s7, 0, s5
s_mul_i32 s7, s7, s6
s_mul_hi_u32 s7, s6, s7
s_add_i32 s6, s6, s7
s_mul_hi_u32 s6, s4, s6
s_mul_i32 s7, s6, s5
s_sub_i32 s4, s4, s7
s_add_i32 s7, s6, 1
s_sub_i32 s8, s4, s5
s_load_dwordx2 s[0:1], s[0:1], 0x0
s_cmp_ge_u32 s4, s5
s_mov_b32 s37, 0
s_cselect_b32 s6, s7, s6
s_cselect_b32 s4, s8, s4
s_add_i32 s7, s6, 1
s_mov_b32 s36, s37
s_cmp_ge_u32 s4, s5
s_mov_b32 s38, s37
s_mov_b32 s39, s37
s_mov_b32 s40, s37
s_mov_b32 s41, s37
s_mov_b32 s42, s37
s_mov_b32 s43, s37
s_mov_b32 s44, s37
s_mov_b32 s45, s37
s_mov_b32 s46, s37
s_mov_b32 s47, s37
s_mov_b32 s48, s37
s_mov_b32 s49, s37
s_mov_b32 s50, s37
s_mov_b32 s51, s37
s_mov_b32 s52, s37
s_mov_b32 s53, s37
s_mov_b32 s54, s37
s_mov_b32 s55, s37
v_mov_b64_e32 v[0:1], s[36:37]
s_cselect_b32 s4, s7, s6
v_mov_b64_e32 v[2:3], s[38:39]
v_mov_b64_e32 v[4:5], s[40:41]
v_mov_b64_e32 v[6:7], s[42:43]
v_mov_b64_e32 v[8:9], s[44:45]
v_mov_b64_e32 v[10:11], s[46:47]
v_mov_b64_e32 v[12:13], s[48:49]
v_mov_b64_e32 v[14:15], s[50:51]
v_mov_b64_e32 v[16:17], s[52:53]
v_mov_b64_e32 v[18:19], s[54:55]
v_mov_b64_e32 v[20:21], s[56:57]
v_mov_b64_e32 v[22:23], s[58:59]
v_mov_b64_e32 v[24:25], s[60:61]
v_mov_b64_e32 v[26:27], s[62:63]
v_mov_b64_e32 v[28:29], s[64:65]
v_mov_b64_e32 v[30:31], s[66:67]
s_mov_b32 s6, 0
.LBB0_2:
s_mov_b32 s36, s6
s_mov_b32 s7, s5
.LBB0_3:
s_add_i32 s7, s7, -1
s_lshl_b64 s[8:9], s[36:37], 2
s_waitcnt lgkmcnt(0)
s_add_u32 s8, s0, s8
s_addc_u32 s9, s1, s9
s_load_dword s8, s[8:9], 0x0
s_waitcnt lgkmcnt(0)
v_mov_b32_e32 v32, s8
s_set_gpr_idx_on s36, gpr_idx(DST)
v_mov_b32_e32 v0, v32
s_set_gpr_idx_off
s_add_i32 s36, s36, 1
s_cmp_eq_u32 s7, 0
s_cbranch_scc0 .LBB0_3
s_add_i32 s4, s4, -1
s_add_i32 s6, s6, s5
s_cmp_eq_u32 s4, 0
s_cbranch_scc0 .LBB0_2
v_mul_f32_e32 v0, 0, v0
s_branch .LBB0_7
.LBB0_6:
v_mov_b32_e32 v0, 0x7fc00000
v_mov_b32_e32 v1, 0
v_mov_b32_e32 v2, 0
v_mov_b32_e32 v3, 0
v_mov_b32_e32 v4, 0
v_mov_b32_e32 v5, 0
v_mov_b32_e32 v6, 0
v_mov_b32_e32 v7, 0
v_mov_b32_e32 v8, 0
v_mov_b32_e32 v9, 0
v_mov_b32_e32 v10, 0
v_mov_b32_e32 v11, 0
v_mov_b32_e32 v12, 0
v_mov_b32_e32 v13, 0
v_mov_b32_e32 v14, 0
v_mov_b32_e32 v15, 0
v_mov_b32_e32 v16, 0
v_mov_b32_e32 v17, 0
v_mov_b32_e32 v18, 0
v_mov_b32_e32 v19, 0
.LBB0_7:
s_mov_b32 s0, 2.0
s_mov_b32 s1, 0x40400000
v_pk_mul_f32 v[2:3], v[2:3], s[0:1]
s_mov_b32 s0, 4.0
v_mov_b32_e32 v20, 0
s_mov_b32 s1, 0x40a00000
global_store_dwordx4 v20, v[0:3], s[2:3]
s_nop 1
v_pk_mul_f32 v[0:1], v[4:5], s[0:1]
s_mov_b32 s0, 0x40c00000
s_mov_b32 s1, 0x40e00000
v_pk_mul_f32 v[2:3], v[6:7], s[0:1]
s_mov_b32 s0, 0x41000000
s_mov_b32 s1, 0x41100000
global_store_dwordx4 v20, v[0:3], s[2:3] offset:16
s_nop 1
v_pk_mul_f32 v[0:1], v[8:9], s[0:1]
s_mov_b32 s0, 0x41200000
s_mov_b32 s1, 0x41300000
v_pk_mul_f32 v[2:3], v[10:11], s[0:1]
s_mov_b32 s0, 0x41400000
s_mov_b32 s1, 0x41500000
global_store_dwordx4 v20, v[0:3], s[2:3] offset:32
s_nop 1
v_pk_mul_f32 v[0:1], v[12:13], s[0:1]
s_mov_b32 s0, 0x41600000
s_mov_b32 s1, 0x41700000
v_pk_mul_f32 v[2:3], v[14:15], s[0:1]
s_mov_b32 s0, 0x41800000
s_mov_b32 s1, 0x41880000
global_store_dwordx4 v20, v[0:3], s[2:3] offset:48
s_nop 1
v_pk_mul_f32 v[0:1], v[16:17], s[0:1]
s_mov_b32 s0, 0x41900000
s_mov_b32 s1, 0x41980000
v_pk_mul_f32 v[2:3], v[18:19], s[0:1]
global_store_dwordx4 v20, v[0:3], s[2:3] offset:64
s_endpgm
Let's start by taking a look at section LBB0_3, which corresponds to the first for loop:
for (int i{0}; i < iters; ++i){
    for (int j{0}; j < BK; ++j){
        arr[iters * BK + j] = data[iters * BK + j];
    }
}
s7 (scalar register 7) starts off initialized to BK; when this value is decremented to 0, it signals that the inner for loop has completed.
s36 and s37 hold the 64-bit offset we will be adding to the data pointer, which is the same as iters * BK + j. To get the exact offset in bytes it is left shifted by 2 (4 bytes per float).
s_add_i32 s7, s7, -1
s_lshl_b64 s[8:9], s[36:37], 2
s0 and s1 contain the 64-bit address of the data pointer. We add the lower and upper halves of our byte offset to them to get the address we want to load from. The value is then loaded into s8.
s_add_u32 s8, s0, s8
s_addc_u32 s9, s1, s9
s_load_dword s8, s[8:9], 0x0
s8 is then moved into vector register v32, and it is here that the most important change in the assembly occurs. We turn on general purpose register indexing using the s_set_gpr_idx_on instruction, identifying s36 as the register that contains our index. We then move the data stored in v32 into register v0; however, since s_set_gpr_idx_on is active, this actually moves the data into v[0 + s36], making this instruction equivalent to arr[iters * BK + j] = data[iters * BK + j] without spilling to scratch memory.
v_mov_b32_e32 v32, s8
s_set_gpr_idx_on s36, gpr_idx(DST)
v_mov_b32_e32 v0, v32
By utilizing the s_set_gpr_idx_on instruction, AMD provides a method to perform runtime indexing on registers. It's a powerful tool, but still not a complete replacement for compile-time indexing. Turning GPR indexing on and off has its own instruction overhead, and since the offset provided is an SGPR, the offset must be common to all threads in the wave.
For example, in a scenario like this where the offset is unique to each thread, register spills occur.
@always_inline
fn kernel(
    data: Span[Float32, MutAnyOrigin],
    result: Span[Float32, MutAnyOrigin],
    load_length: UInt32,
):
    comptime len = 5 * 8
    var arr = StaticTuple[Float32, len]()
    var start_idx = (thread_idx.x // 8) * load_length
    for j in range(load_length):
        arr[start_idx + j] = data[start_idx + j]
    for i in range(load_length):
        result[thread_idx.x * load_length + i] = (
            arr[start_idx + i] * Float32(i)
        )
s_load_dword s6, s[0:1], 0x20
s_mov_b32 s7, 0
s_waitcnt lgkmcnt(0)
s_cmp_eq_u32 s6, 0
s_cbranch_scc1 .LBB0_5
s_load_dwordx2 s[4:5], s[0:1], 0x0
s_load_dwordx2 s[2:3], s[0:1], 0x10
v_lshrrev_b32_e32 v1, 3, v0
v_mul_lo_u32 v1, s6, v1
s_mov_b32 s0, 0
v_lshl_add_u32 v4, v1, 2, s0
v_mov_b32_e32 v3, 0
.LBB0_2:
v_add_u32_e32 v2, s7, v1
s_waitcnt lgkmcnt(0)
v_lshl_add_u64 v[6:7], v[2:3], 2, s[4:5]
global_load_dword v2, v[6:7], off
s_add_i32 s7, s7, 1
s_cmp_lg_u32 s6, s7
s_waitcnt vmcnt(0)
scratch_store_dword v4, v2, off
v_add_u32_e32 v4, 4, v4
s_cbranch_scc1 .LBB0_2
s_mov_b32 s0, 0
v_mul_lo_u32 v2, s6, v0
v_lshl_add_u32 v3, v1, 2, s0
v_mov_b32_e32 v1, 0
.LBB0_4:
scratch_load_dword v6, v3, off
v_cvt_f32_u32_e32 v7, s0
v_add_u32_e32 v0, s0, v2
s_add_i32 s0, s0, 1
v_add_u32_e32 v3, 4, v3
v_lshl_add_u64 v[4:5], v[0:1], 2, s[2:3]
s_cmp_eq_u32 s6, s0
s_waitcnt vmcnt(0)
v_mul_f32_e32 v0, v6, v7
global_store_dword v[4:5], v0, off
s_cbranch_scc0 .LBB0_4
.LBB0_5:
s_endpgm
Solutions
The main solution to runtime indexing is compile-time indexing. In sections where runtime indexing is required, ensure that it is not used to derive indices into your registers. When possible, also use loop unrolling to ensure that loop bounds are known at compile time; in C++ that is done with #pragma unroll, and in Mojo with the @parameter decorator.
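Here is a small sketch of the fix in CUDA terms (a hypothetical kernel): by giving the loops a compile-time trip count and unrolling them, every index into arr is a constant, so the compiler can keep the whole array in registers instead of spilling it to local memory.

```cpp
template <int LEN>
__global__ void scale_demo(const float* data, float* out) {
    float arr[LEN];

    #pragma unroll
    for (int i = 0; i < LEN; ++i)   // compile-time bound: arr[i] maps to a fixed register
        arr[i] = data[i];

    #pragma unroll
    for (int i = 0; i < LEN; ++i)
        out[i] = arr[i] * i;
}

// e.g. scale_demo<20><<<1, 32>>>(data, out);
```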
Addressing
Taking the address of any of your variables is a surefire way to spill it. Register-backed variables are not addressable, so the compiler will move the data to memory to ensure an address can be taken.
C++ Code
struct Foo {
    float * bar;
    float * baz;
    float * buzz;

    __host__ __device__ Foo(
        float * bar,
        float * baz,
        float * buzz
    ): bar(bar), baz(baz), buzz(buzz) {}

    __host__ __device__ void scale(float (* arr)[3]) {
        (*arr)[0] *= *bar * 108.f;
        (*arr)[1] *= *baz * 108.108f;
        (*arr)[2] *= *buzz * 108.8f;
    }
};

__global__ void test_kernel(float * result, const float bar, const float baz, const float buzz) {
    float new_bar{bar + threadIdx.x};
    float new_baz{baz + threadIdx.x};
    float new_buzz{buzz + threadIdx.x};
    float arr[3];
    float (* arr_p)[3]{&arr};
    Foo foo(&new_bar, &new_baz, &new_buzz);
    foo.scale(arr_p);
    result[0] = arr[0];
    result[1] = arr[1];
    result[2] = arr[2];
}
Equivalent Mojo code
struct Foo:
    comptime FloatPtrType = UnsafePointer[Float32, MutAnyOrigin]

    var bar: Self.FloatPtrType
    var baz: Self.FloatPtrType
    var buzz: Self.FloatPtrType

    fn __init__(out self, bar: Self.FloatPtrType, baz: Self.FloatPtrType, buzz: Self.FloatPtrType):
        self.bar = bar
        self.baz = baz
        self.buzz = buzz

    fn scale(self, arr: UnsafePointer[StaticTuple[Float32, 3], MutAnyOrigin]):
        var arr2 = arr[]
        arr[][0] *= self.bar[] * 108.0
        arr[][1] *= self.baz[] * 108.108
        arr[][2] *= self.buzz[] * 108.8

fn kernel(
    result: Span[Float32, MutAnyOrigin],
    bar: Float32,
    baz: Float32,
    buzz: Float32,
):
    var new_bar = bar + 10
    var new_baz = baz + 10
    var new_buzz = buzz + 10
    var arr = StaticTuple[Float32, 3](fill=1.0)
    var foo = Foo(
        UnsafePointer(to=new_bar),
        UnsafePointer(to=new_baz),
        UnsafePointer(to=new_buzz)
    )
    foo.scale(UnsafePointer(to=arr))
    result[0] = arr[0]
    result[1] = arr[1]
    result[2] = arr[2]
Nvidia IR
.visible .func (.param .b32 func_retval0) __cudaCDP2Malloc(
.param .b64 __cudaCDP2Malloc_param_0,
.param .b64 __cudaCDP2Malloc_param_1
)
{
mov.u32 %r1, 999;
st.param.b32 [func_retval0+0], %r1;
ret;
}
.visible .entry test_kernel(float*, float, float, float)(
.param .u64 test_kernel(float*, float, float, float)_param_0,
.param .f32 test_kernel(float*, float, float, float)_param_1,
.param .f32 test_kernel(float*, float, float, float)_param_2,
.param .f32 test_kernel(float*, float, float, float)_param_3
)
{
ld.param.u64 %rd1, [test_kernel(float*, float, float, float)_param_0];
cvta.to.global.u64 %rd2, %rd1;
mov.u32 %r1, 2143289344;
st.global.u32 [%rd2], %r1;
st.global.u32 [%rd2+4], %r1;
st.global.u32 [%rd2+8], %r1;
ret;
}
.visible .func Foo::Foo(float*, float*, float*)(
.param .b64 Foo::Foo(float*, float*, float*)_param_0,
.param .b64 Foo::Foo(float*, float*, float*)_param_1,
.param .b64 Foo::Foo(float*, float*, float*)_param_2,
.param .b64 Foo::Foo(float*, float*, float*)_param_3
)
{
ld.param.u64 %rd1, [Foo::Foo(float*, float*, float*)_param_0];
ld.param.u64 %rd2, [Foo::Foo(float*, float*, float*)_param_1];
ld.param.u64 %rd3, [Foo::Foo(float*, float*, float*)_param_2];
ld.param.u64 %rd4, [Foo::Foo(float*, float*, float*)_param_3];
st.u64 [%rd1], %rd2;
st.u64 [%rd1+8], %rd3;
st.u64 [%rd1+16], %rd4;
ret;
}
.visible .func Foo::scale(float (*) [3])(
.param .b64 Foo::scale(float (*) [3])_param_0,
.param .b64 Foo::scale(float (*) [3])_param_1
)
{
ld.param.u64 %rd1, [Foo::scale(float (*) [3])_param_0];
ld.param.u64 %rd2, [Foo::scale(float (*) [3])_param_1];
ld.u64 %rd3, [%rd1];
ld.f32 %f1, [%rd3];
mul.f32 %f2, %f1, 0f42D80000;
ld.f32 %f3, [%rd2];
mul.f32 %f4, %f3, %f2;
st.f32 [%rd2], %f4;
ld.u64 %rd4, [%rd1+8];
ld.f32 %f5, [%rd4];
mul.f32 %f6, %f5, 0f42D8374C;
ld.f32 %f7, [%rd2+4];
mul.f32 %f8, %f7, %f6;
st.f32 [%rd2+4], %f8;
ld.u64 %rd5, [%rd1+16];
ld.f32 %f9, [%rd5];
mul.f32 %f10, %f9, 0f42D9999A;
ld.f32 %f11, [%rd2+8];
mul.f32 %f12, %f11, %f10;
st.f32 [%rd2+8], %f12;
ret;
}
.visible .func Foo::Foo(float*, float*, float*)(
.param .b64 Foo::Foo(float*, float*, float*)_param_0,
.param .b64 Foo::Foo(float*, float*, float*)_param_1,
.param .b64 Foo::Foo(float*, float*, float*)_param_2,
.param .b64 Foo::Foo(float*, float*, float*)_param_3
)
{
ld.param.u64 %rd1, [Foo::Foo(float*, float*, float*)_param_0];
ld.param.u64 %rd2, [Foo::Foo(float*, float*, float*)_param_1];
ld.param.u64 %rd3, [Foo::Foo(float*, float*, float*)_param_2];
ld.param.u64 %rd4, [Foo::Foo(float*, float*, float*)_param_3];
st.u64 [%rd1], %rd2;
st.u64 [%rd1+8], %rd3;
st.u64 [%rd1+16], %rd4;
ret;
$L__func_end0:
}
By analyzing the IR we can see that ld and st instructions appear where they should not. In this case no address space identifier is attached (e.g. ld.local, ld.shared, etc.). This signals that the address space is generic, which means the appropriate address space will be resolved at runtime. While this does not confirm that the data will be spilled to local memory, there is still a high likelihood that local memory gets chosen over shared or other address spaces.
Passing your array to a function that expects a pointer will also cause a similar effect. This happens because the function call causes your array to decay to a pointer type, forcing the compiler to assume it was memory backed from the start.
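As a small illustration of the decay problem and one way around it (hypothetical helpers; whether the pointer version actually spills depends on inlining and the optimizer):

```cpp
// Takes a raw pointer: the caller's array decays, must have an address,
// and is therefore likely to end up backed by local memory.
__device__ float sum_ptr(const float* a, int n) {
    float s = 0.f;
    for (int i = 0; i < n; ++i) s += a[i];
    return s;
}

// Takes a sized array reference: the length stays visible, the loop can be
// unrolled, and after inlining the array can remain register-resident.
template <int N>
__device__ float sum_ref(const float (&a)[N]) {
    float s = 0.f;
    #pragma unroll
    for (int i = 0; i < N; ++i) s += a[i];
    return s;
}
```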
These issues are not exclusive to Nvidia; AMD can face the same problems, although its compiler is much better at optimizing them away.
Diagnosing Register spills
Dumping assembly
The best way to understand why your registers spilled is to read the assembly.
In Mojo we have the amazing function enqueue_function_checked, which can dump PTX, raw assembly, or even LLVM IR for any GPU function.
On the C++ side, NVCC offers compilation flags such as --keep that let you keep all intermediate files. You then have access to the PTX and the cubin. The cubin can be converted to SASS using:
cuobjdump -sass /my/path/to/program.cubin
Alternatively, you can compile your code with --generate-line-info and profile it using NCU. From there you can see which lines in your C++ the assembly corresponds to.
On AMD, --save-temps provides the same functionality as --keep, and you can use llvm-objdump -d as a replacement for cuobjdump.
If this seems like too much work, analyzing the IR on Godbolt is an easy way to get what you need.
Helpful Flags
For CUDA, flags such as -Xnvlink=--verbose, -Xptxas=--verbose, and -Xptxas=--warn-on-spills print memory usage stats and warn about spills at compile time!
The equivalent for AMD is -Rpass-analysis=kernel-resource-usage.
Concluding remarks
This only scratches the surface. In a follow-up article we will dive even deeper into the various types of registers, along with other optimizations that can be done.
In the meantime if you have any feedback/questions/corrections feel free to email me at sriramgovindanwork@gmail.com.