ldmatrix Explained

Introduction

In this article I will share how the ldmatrix instruction works and how and when to use it.

Throughout this article I will make constant references to the PTX ISA along with other resources. Please make sure to check them out.

  1. PTX ISA
  2. CuTe Documentation

What is ldmatrix?

ldmatrix is a warp-level PTX instruction introduced with the Turing architecture (SM75). This is the same architecture used in the Tesla T4 and the RTX 20 series cards. It allows users to load matrices from shared memory into registers directly, in the format an MMA instruction would use.

ldmatrix can currently only load 3 different matrix shapes. Each matrix shape is only compatible with certain datatype(s).

Shape      Matrix shape   Element size
.m8n8      8x8            16-bit
.m16n16    16x16          8-bit, 6-bit, or 4-bit
.m8n16     8x16           6-bit or 4-bit

In this article we will focus on the 8x8 shape; in the future we may go over the others in detail.

The 8x8 matrix supports 16-bit elements, i.e. your half-precision floats FP16 and BF16.

After calling ldmatrix, each thread will hold a portion of the original matrix in its registers; the values correspond to the image below.

[Figure: thread value layout after the ldmatrix operation]

From this we can see that 4 sequential threads (i.e. threads 0 - 3) will hold an entire row of the matrix.
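To make this mapping concrete, here is a minimal host-side sketch (not from the original demo) that prints which elements each lane ends up holding for a single 8x8 load: lane t receives the two adjacent 16-bit elements of row t / 4, starting at column 2 * (t % 4), which matches the demo output shown later.

#include <cstdio>

int main() {
    // For the 8x8 .x1 load: 4 consecutive lanes cover one row,
    // and each lane holds two adjacent 16-bit elements of that row.
    for (int lane = 0; lane < 32; ++lane) {
        const int row = lane / 4;
        const int col = (lane % 4) * 2;
        std::printf("lane %2d -> row %d, cols %d and %d\n", lane, row, col, col + 1);
    }
    return 0;
}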

In order for ldmatrix to work, the address of the matrix must be provided. This can cause issues if the matrix being loaded is not contiguous in memory (a common case when tiling, see image below).

[Figure: non-contiguous memory example]

To solve this, the ldmatrix instruction requests that threads provide the addresses of specific rows.

.num   Threads 0-7    Threads 8-15    Threads 16-23    Threads 24-31
.x1    addr0-addr7
.x2    addr0-addr7    addr8-addr15
.x4    addr0-addr7    addr8-addr15    addr16-addr23    addr24-addr31

For example, when the .num qualifier is set to .x1, ldmatrix requests that threads 0 - 7 provide the addresses for rows 0 - 7. This allows the rows of the matrix to be non-contiguous in memory.

You may have also noticed that as the .num qualifier increases, more row addresses need to be provided. That is because .num specifies how many 8x8 matrices are being loaded at once; the more matrices being loaded, the more addresses are needed. Generally speaking, you should try to load 4 matrices (.x4) whenever using ldmatrix.
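As a hedged sketch of the addressing rule for .x4 (the buffer name, leading dimension, and tile placement here are illustrative assumptions, not something the instruction mandates), each lane could compute the row address it supplies like this when four 8x8 FP16 tiles sit side by side in a row-major shared memory buffer:

#include <cuda_fp16.h>

// Assumption: four 8x8 FP16 tiles stored side by side in a row-major
// shared memory buffer that is `ld` halves wide (illustrative only).
__device__ const __half* x4_row_address(const __half* smem, int ld) {
    const int lane   = threadIdx.x % 32;
    const int matrix = lane / 8;   // lanes 0-7 address matrix 0, lanes 8-15 matrix 1, ...
    const int row    = lane % 8;   // which row of that matrix this lane addresses
    return smem + row * ld + matrix * 8;
}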

When to use ldmatrix

To perform efficient matrix multiplication, data is moved to various memory locations to facilitate faster access times.

The most common pattern is to tile global memory, transfer that tile to shared memory, then partition shared memory and transfer it into registers for matmul operations. If you are using an MMA operation then the data in registers actually needs to match a specific layout. If this layout isn't followed, the computation may fail or produce invalid results.

This is where ldmatrix really helps: it ensures that every thread in the warp will hold the correct values to match the MMA layout.

ldmatrix is not compatible with every MMA op, but generally speaking if the MMA uses the same datatype as ldmatrix you can assume it is compatible (this is especially true for Ampere).

Tensor Core Compatibility

To prove my point, here are the layouts of two different MMAs that use half-precision floats.

SM80_16x8x8_F16F16F16F16_TN
[Figure: SM80_16x8x8_F16F16F16F16_TN visual layout]
SM80_16x8x16_F16F16F16F16_TN
[Figure: SM80_16x8x16_F16F16F16F16_TN visual layout]

The images show three different layouts:

  • Matrix A on the bottom left
  • Matrix B on the top right
  • Matrix C on the bottom right

C = A * B

Since we only use ldmatrix to load data into the A and B matrices, that's what we will focus on.

If you look at Matrix A for both MMAs, you will see that it consists of repetitions of the ldmatrix layout.

[Figure: Matrix A split into separate ldmatrix tiles]

The same goes for Matrix B; the only difference is that it's transposed.

Now if we look at an MMA that uses a different datatype, we see a different story.

SM80_16x8x4_F32TF32TF32F32_TN
[Figure: SM80_16x8x4_F32TF32TF32F32_TN visual layout]

It's completely incompatible with ldmatrix. This makes sense: ldmatrix only works with very specific data types, and this Tensor Core only accepts TensorFloat32 inputs.

Using ldmatrix in PTX

This is the inline assembly for ldmatrix:

// 1 matrix
asm volatile ("ldmatrix.sync.aligned.x1.m8n8.shared.b16 {%0}, [%1];\n"
    : "=r"(dst)
    :  "r"(smem_int_ptr));
// 2 matrices
asm volatile ("ldmatrix.sync.aligned.x2.m8n8.shared.b16 {%0, %1}, [%2];\n"
    : "=r"(dst0), "=r"(dst1)
    :  "r"(smem_int_ptr));
// 4 matrices
asm volatile ("ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];\n"
    : "=r"(dst0), "=r"(dst1), "=r"(dst2), "=r"(dst3)
    :  "r"(smem_int_ptr));

There is also an optional .trans qualifier in case your shared memory data is stored in column-major format:

asm volatile ("ldmatrix.sync.aligned.x1.trans.m8n8.shared.b16 {%0}, [%1];\n"
    : "=r"(dst)
    :  "r"(smem_int_ptr));

Demo Code

Here's some demo code showcasing how to use the ldmatrix instruction (you will need cutlass to access the half_t struct):

template<typename DataType>
__device__ static void copy(
    const DataType * smem_addr,
    uint32_t &dst) {
    // convert the generic pointer to a 32-bit shared memory address
    const uint32_t addr = static_cast<uint32_t>(__cvta_generic_to_shared(smem_addr));
    asm volatile ("ldmatrix.sync.aligned.x1.m8n8.shared.b16 {%0}, [%1];\n"
        : "=r"(dst)
        : "r"(addr));
}

template<size_t smem_rows = 64, size_t smem_columns = 64>
__global__ void ldmatrix_x1_ptx() {
    constexpr size_t total_el{smem_rows * smem_columns};
    __shared__ half_t smem[total_el];

    if (thread0()) {
        // fill shared memory
        for (size_t i{0}; i < total_el; ++i) {
            smem[i] = half_t::convert(static_cast<int>(i));
        }
    }

    __syncthreads();

    // calculate which row address to pass to the instruction
    const uint32_t row_id{threadIdx.x % 8};
    uint32_t reg;

    copy(row_id * smem_columns + smem, reg);
    const auto parts{reinterpret_cast<half_t *>(&reg)};

    const float val_1{parts[0]};
    const float val_2{parts[1]};

    printf("Thread %d, row %d : Values -> (%f, %f) \n", threadIdx.x, row_id, val_1, val_2);
}

int main() {
    ldmatrix_x1_ptx<<<1, 32>>>();
    cudaDeviceSynchronize();
    cudaError_t err = cudaGetLastError();
    if (err != cudaSuccess) {
        std::cerr << "CUDA Error after " << ": " << cudaGetErrorString(err) << std::endl;
    }

    return 0;
}

Here's the output:

Thread 0, row 0 : Values -> (0.000000, 1.000000)
Thread 1, row 1 : Values -> (2.000000, 3.000000)
Thread 2, row 2 : Values -> (4.000000, 5.000000)
Thread 3, row 3 : Values -> (6.000000, 7.000000)
Thread 4, row 4 : Values -> (64.000000, 65.000000)
Thread 5, row 5 : Values -> (66.000000, 67.000000)
Thread 6, row 6 : Values -> (68.000000, 69.000000)
Thread 7, row 7 : Values -> (70.000000, 71.000000)
Thread 8, row 0 : Values -> (128.000000, 129.000000)
Thread 9, row 1 : Values -> (130.000000, 131.000000)
Thread 10, row 2 : Values -> (132.000000, 133.000000)
Thread 11, row 3 : Values -> (134.000000, 135.000000)
Thread 12, row 4 : Values -> (192.000000, 193.000000)
Thread 13, row 5 : Values -> (194.000000, 195.000000)
Thread 14, row 6 : Values -> (196.000000, 197.000000)
Thread 15, row 7 : Values -> (198.000000, 199.000000)
Thread 16, row 0 : Values -> (256.000000, 257.000000)
Thread 17, row 1 : Values -> (258.000000, 259.000000)
Thread 18, row 2 : Values -> (260.000000, 261.000000)
Thread 19, row 3 : Values -> (262.000000, 263.000000)
Thread 20, row 4 : Values -> (320.000000, 321.000000)
Thread 21, row 5 : Values -> (322.000000, 323.000000)
Thread 22, row 6 : Values -> (324.000000, 325.000000)
Thread 23, row 7 : Values -> (326.000000, 327.000000)
Thread 24, row 0 : Values -> (384.000000, 385.000000)
Thread 25, row 1 : Values -> (386.000000, 387.000000)
Thread 26, row 2 : Values -> (388.000000, 389.000000)
Thread 27, row 3 : Values -> (390.000000, 391.000000)
Thread 28, row 4 : Values -> (448.000000, 449.000000)
Thread 29, row 5 : Values -> (450.000000, 451.000000)
Thread 30, row 6 : Values -> (452.000000, 453.000000)
Thread 31, row 7 : Values -> (454.000000, 455.000000)

You can see that in our 64x64 block, ldmatrix loads the top left 8x8 tile.

Using ldmatrix in Cutlass

You can also use ldmatrix in Cutlass. Cutlass simplifies a lot of the lower-level aspects for you, so you only need to worry about the implementation. Since you would only use ldmatrix in conjunction with an MMA operation, Cutlass requires the same.

Demo Code

Here's the full Cutlass code:

template<size_t smem_rows = 64, size_t smem_columns = 64>
__global__ void ldmatrix_x1_cutlass() {
    constexpr size_t total_el{smem_rows * smem_columns};
    __shared__ half_t smem[total_el];

    if (thread0()) {
        // fill shared memory with dummy data
        for (size_t i{0}; i < total_el; ++i) {
            smem[i] = half_t::convert(static_cast<int>(i));
        }
    }

    __syncthreads();

    // convert to CuTe Tensor
    Tensor shared{
        make_tensor(
            make_smem_ptr(smem),
            make_layout(
                make_shape(Int<smem_rows>{}, Int<smem_columns>{}),
                LayoutRight{}
            )
        )
    };

    // define the MMA operation
    constexpr TiledMMA mma = make_tiled_mma(SM80_16x8x8_F16F16F16F16_TN{},
                                            Layout<Shape<_1, _1, _1> >{},
                                            Tile<_16, _8, _8>{});

    // define the copy operation equivalent to
    // ldmatrix.sync.aligned.x1.m8n8.shared.b16
    const Copy_Atom<SM75_U32x1_LDSM_N, half_t> shared_to_register_atom_b;

    // we tile shared memory to 8x8 chunks and pick the top left tile
    auto cta_coord = make_coord(_0{}, _0{});
    Tensor shared_tile{local_tile(shared, make_tile(_8{}, _8{}), cta_coord)};

    // here we initialize the registers for storing the matrix fragments
    ThrMMA thr_mma = mma.get_slice(threadIdx.x);
    auto tCrB = thr_mma.partition_fragment_B(shared_tile);

    // here we segment out the source data needed for the copy
    const TiledCopy copy_b{make_tiled_copy_B(shared_to_register_atom_b, mma)};
    auto s2r_thr_copy_b = copy_b.get_slice(threadIdx.x);
    Tensor tBsB = s2r_thr_copy_b.partition_S(shared_tile);
    auto tCrB_view = s2r_thr_copy_b.retile_D(tCrB);

    // perform the copy
    copy(copy_b, tBsB, tCrB_view);

    printf(
        "Thread %d : Values -> (%f, %f) \n",
        threadIdx.x,
        static_cast<float>(tCrB_view(0)),
        static_cast<float>(tCrB_view(1))
    );
}

inline void runner() {
    ldmatrix_x1_cutlass<<<1, 32>>>();
    cudaDeviceSynchronize();
    cudaError_t err = cudaGetLastError();
    if (err != cudaSuccess) {
        std::cerr << "CUDA Error after " << ": " << cudaGetErrorString(err) << std::endl;
    }
}

Now let's break it down.

Tensor shared{
    make_tensor(
        make_smem_ptr(smem),
        make_layout(
            make_shape(Int<smem_rows>{}, Int<smem_columns>{}),
            LayoutRight{}
        )
    )
};

This constructs a CuTe Tensor over shared memory. A Tensor is a pointer paired with a Layout, and a Layout consists of a Shape and a Stride. In this case the shape is 64 x 64 and the stride is LayoutRight, which means row-major format.
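For reference, here is the same layout spelled with an explicit stride (a small sketch of a fragment that would sit inside the kernel above); for a 64 x 64 shape, LayoutRight resolves to a stride of (64, 1):

// Equivalent to the LayoutRight version: row-major strides written out explicitly.
auto row_major_layout = make_layout(
    make_shape(Int<smem_rows>{}, Int<smem_columns>{}),
    make_stride(Int<smem_columns>{}, Int<1>{}));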

constexpr TiledMMA mma = make_tiled_mma(SM80_16x8x8_F16F16F16F16_TN{},
                                        Layout<Shape<_1, _1, _1> >{},
                                        Tile<_16, _8, _8>{});

This constructs a TiledMMA. A TiledMMA is an MMA operation that is scaled up, either by adding more independent MMA ops and/or by repeating the same MMA operation over a larger tile.

The first argument specifies which MMA operation we are using (we discussed this in the previous section). The second argument details how many Tensor Cores you want in the M, N, and K dimensions.

  • The A matrix is MxK
  • The B matrix is KxN
  • The C matrix is MxN

If we increased any of those numbers you would see more independent MMA ops being used for that particular matrix. In this case we specify 1 for each dimension so only 1 MMA op is used.
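As an illustration (a sketch, not used in this demo), replicating the atom twice along M would look like the following; the tile becomes 32x8x8 and is executed by two warps, each running its own 16x8x8 MMA.

// Two independent 16x8x8 MMA atoms stacked along M (64 threads total).
constexpr TiledMMA mma_2m = make_tiled_mma(SM80_16x8x8_F16F16F16F16_TN{},
                                           Layout<Shape<_2, _1, _1> >{});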

The third argument dictates the final tile shape, i.e. how many times we want the same MMA(s) to repeat their instruction over each dimension.

In this case we don't tile at all: the natural MMA shape is 16x8x8, and we pass identical values to the third argument.

You may be wondering when tiling would be useful. Well, let's say we wanted to use the ldmatrix .x4 instruction. With the current TiledMMA settings we wouldn't be able to, because the B matrix, which is 8x8, would need to grow by a factor of 4 to exploit the instruction. To solve this we could simply tile in the N and K dimensions, making the third argument Tile<_16, _16, _16>{}. Now the B matrix would be 16x16, which is exactly 4 times bigger.
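A minimal sketch of that variant (assuming CuTe's SM75_U32x4_LDSM_N copy atom, which maps to the .x4 form of the instruction) would look like this:

// Tile N and K so the B operand grows to 16x16, then use the .x4 copy atom.
constexpr TiledMMA mma_x4 = make_tiled_mma(SM80_16x8x8_F16F16F16F16_TN{},
                                           Layout<Shape<_1, _1, _1> >{},
                                           Tile<_16, _16, _16>{});
const Copy_Atom<SM75_U32x4_LDSM_N, half_t> shared_to_register_atom_b_x4;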

const Copy_Atom<SM75_U32x1_LDSM_N, half_t> shared_to_register_atom_b;

Here we select our copy atom; underneath it uses the same PTX instruction as the previous section. This atom is only big enough to load the B matrix (.x1 can only load a single 8x8 matrix), so that's what we use it for.

auto cta_coord = make_coord(_0{}, _0{});
Tensor shared_tile{local_tile(shared, make_tile(_8{}, _8{}), cta_coord)};

Here we tile our shared memory into 8x8 blocks and extract the top-left corner tile.
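For illustration (a hypothetical fragment, not part of the demo), passing a different coordinate would select a different tile; for example, coordinate (1, 2) picks rows 8 - 15 and columns 16 - 23 of the 64x64 tensor:

// Hypothetical example: select the 8x8 tile at block coordinate (1, 2).
Tensor other_tile{local_tile(shared, make_tile(_8{}, _8{}), make_coord(_1{}, _2{}))};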

ThrMMA thr_mma = mma.get_slice(threadIdx.x);
auto tCrB = thr_mma.partition_fragment_B(shared_tile);

Here we initialize a manager struct (ThrMMA) which holds all the MMA data this thread is responsible for. In the next line we request the manager to initialize the registers needed to perform the MMA instruction. These same registers will be filled by the ldmatrix instruction.

// here we segment out the source data needed for the copy
const TiledCopy copy_b{make_tiled_copy_B(shared_to_register_atom_b, mma)};
auto s2r_thr_copy_b = copy_b.get_slice(threadIdx.x);
Tensor tBsB = s2r_thr_copy_b.partition_S(shared_tile);
auto tCrB_view = s2r_thr_copy_b.retile_D(tCrB);

Now that we have the registers ready to be loaded, we need to prepare the shared memory for the transfer.

We start off by creating the TiledCopy object; this gives us some really useful helper functions and ensures the MMA and the register copy instruction are compatible. Next we create another thread-level manager and ask it to partition out the section of shared memory this thread is responsible for, which returns a view on that data. Finally, we take the registers we initialized previously and ensure that their underlying layout is compatible with the copy. This is required since certain aspects of the layout, such as its stride, may differ slightly from what ldmatrix expects.
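If you want to see what those partitions look like, a small debugging sketch (not in the original demo) is to print the per-thread layouts from a single thread; cute::print can be called from device code:

// Optional: inspect the per-thread source and destination layouts.
if (thread0()) {
    print("tBsB      : "); print(tBsB);      print("\n");
    print("tCrB_view : "); print(tCrB_view); print("\n");
}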

Finally, we see that we get the same result:

Thread 0 : Values -> (0.000000, 1.000000)
Thread 1 : Values -> (2.000000, 3.000000)
Thread 2 : Values -> (4.000000, 5.000000)
Thread 3 : Values -> (6.000000, 7.000000)
Thread 4 : Values -> (64.000000, 65.000000)
Thread 5 : Values -> (66.000000, 67.000000)
Thread 6 : Values -> (68.000000, 69.000000)
Thread 7 : Values -> (70.000000, 71.000000)
Thread 8 : Values -> (128.000000, 129.000000)
Thread 9 : Values -> (130.000000, 131.000000)
Thread 10 : Values -> (132.000000, 133.000000)
Thread 11 : Values -> (134.000000, 135.000000)
Thread 12 : Values -> (192.000000, 193.000000)
Thread 13 : Values -> (194.000000, 195.000000)
Thread 14 : Values -> (196.000000, 197.000000)
Thread 15 : Values -> (198.000000, 199.000000)
Thread 16 : Values -> (256.000000, 257.000000)
Thread 17 : Values -> (258.000000, 259.000000)
Thread 18 : Values -> (260.000000, 261.000000)
Thread 19 : Values -> (262.000000, 263.000000)
Thread 20 : Values -> (320.000000, 321.000000)
Thread 21 : Values -> (322.000000, 323.000000)
Thread 22 : Values -> (324.000000, 325.000000)
Thread 23 : Values -> (326.000000, 327.000000)
Thread 24 : Values -> (384.000000, 385.000000)
Thread 25 : Values -> (386.000000, 387.000000)
Thread 26 : Values -> (388.000000, 389.000000)
Thread 27 : Values -> (390.000000, 391.000000)
Thread 28 : Values -> (448.000000, 449.000000)
Thread 29 : Values -> (450.000000, 451.000000)
Thread 30 : Values -> (452.000000, 453.000000)
Thread 31 : Values -> (454.000000, 455.000000)