
A gentle introduction to GEMM using MMA tensor cores

There are many resources online about writing a fast GEMM, and they tend to get complicated quickly; by the time you reach the tensor core section, you need to keep a lot of context in your head to follow along.

This post tries to go in reverse, using a tensor core on the smallest possible tile and building up from there. It's aimed at software developers who are interested in learning about tensor cores but haven't grokked the intricacies of a performant GEMM yet.

Background

Tensor cores are dedicated pieces of hardware on a GPU that live inside each Streaming Multiprocessor (SM). They are blazingly fast, and keeping their throughput high is a big challenge.

The only prior knowledge this post assumes is the CUDA programming model and the matrix-multiplication operation.

For more background on tensor cores in CUDA, I refer you to this excellent article.

All the code used in this article is available here.

MMA Cores

MMA instructions work at warp level, i.e. the 32 threads in a warp conspire together to compute one small GEMM. On NVIDIA consumer GPUs these are the fastest option available, while on data-center cards multiple warps can cooperate to execute a bigger GEMM.

Note: There is also the higher-level wmma instruction set, which has a simpler API. However, it is not as feature-rich (for example, it does not give you a well-defined memory layout), so we focus on mma instead.

For this post we will choose the following instruction as our base:

mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32

Let's break it down. Just for reference, here is the GEMM operation:

D=AB+C

m16n8k16 - m is 16, n is 8 and k is 16. So A is 16 x 16 (m x k), B is 16 x 8 (k x n), and C and D are 16 x 8 (m x n)

row - matrix A is in row-major format

col - matrix B is in col-major format

f32 - D's type

f16 - A's type

f16 - B's type

f32 - C's type
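Before diving into the thread mappings, it helps to know how much data each thread ends up holding: each tile is split evenly across the 32 threads of a warp. Here is a quick sketch of the arithmetic (not from the post's repo, just spelling out the shapes above):

// Per-thread fragment sizes for m16n8k16 with f16 inputs and f32 accumulators.
// Every thread in the warp holds an equal slice of each operand.
constexpr int M = 16, N = 8, K = 16, WARP = 32;
static_assert(M * K / WARP == 8, "A fragment: 8 halfs per thread -> four 32-bit registers");
static_assert(K * N / WARP == 4, "B fragment: 4 halfs per thread -> two 32-bit registers");
static_assert(M * N / WARP == 4, "C/D fragment: 4 floats per thread -> four 32-bit registers");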

The official documentation is where we will get all our information about this instruction.

For the A matrix, this is how the threads should load their data in fragments:

[Figure: A-fragment thread/element layout for m16n8k16, from the PTX documentation]

For example, thread 1 in the diagram needs to hold:

Row 0, Cols 2-3: {a0, a1}

Row 8, Cols 2-3: {a2, a3}

Row 0, Cols 10-11: {a4, a5}

Row 8, Cols 10-11: {a6, a7}

Note that {a0, a1} being braced together means the pair must be held as a single 32-bit value; in our case that is two 16-bit floats packed together.
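As a concrete (hypothetical) illustration of that packing:

#include <cuda_fp16.h>

// A minimal sketch of what "packed into a single 32-bit value" means: two
// adjacent f16 values viewed as one register. This is exactly what the
// (const int *) cast in the kernel below does for the fragment arrays.
// pack_pair is an illustrative helper, not part of the post's code.
__device__ unsigned pack_pair(half a0, half a1) {
  __half2 pair = __halves2half2(a0, a1);        // {a0, a1} as one .f16x2 value
  return *reinterpret_cast<unsigned *>(&pair);  // reinterpret as a 32-bit register
}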

The documentation also specifies how each thread's loaded values map to rows and columns:

groupID           = %laneid >> 2
threadID_in_group = %laneid % 4

row = groupID          for ai where (0 <= i < 2) || (4 <= i < 6)
      groupID + 8      otherwise

col = (threadID_in_group * 2) + (i & 0x1)        for ai where i < 4
      (threadID_in_group * 2) + (i & 0x1) + 8    for ai where i >= 4

Let's create a helper struct for loading the A fragment, with get_row and get_col functions that map a laneid and an element index to a row and column. You can verify that get_row and get_col match the formulas from the PTX documentation, just without branches (a small check is sketched right after the struct).

template <typename T> struct Afrag_16x16 {
  static constexpr size_t ne = 8; // num of elements per thread

  T x[ne];

  static __device__ size_t get_row(int tid, int l) {
    int group_id = tid >> 2;
    return group_id + 8 * ((l / 2) % 2);
  }

  static __device__ size_t get_col(int tid, int l) {
    return 2 * (tid % 4) + (l % 2) + 8 * (l / 4);
  }
};
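Here is that small check: a hypothetical host-side snippet (not part of the post's code) that compares the branch-free expressions against the branchy formulas from the PTX documentation, for every lane and element index:

#include <cassert>

int main() {
  for (int lane = 0; lane < 32; ++lane) {
    int group_id = lane >> 2;
    int tig = lane % 4; // threadID_in_group
    for (int i = 0; i < 8; ++i) {
      // Formulas straight from the PTX documentation (with branches).
      int row_doc = ((i < 2) || (i >= 4 && i < 6)) ? group_id : group_id + 8;
      int col_doc = (i < 4) ? tig * 2 + (i & 1) : tig * 2 + (i & 1) + 8;
      // Branch-free versions used in Afrag_16x16.
      assert(row_doc == group_id + 8 * ((i / 2) % 2));
      assert(col_doc == 2 * tig + (i % 2) + 8 * (i / 4));
    }
  }
  return 0;
}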

Similarly for B (16 x 8 col-major matrix), we have

[Figure: B-fragment thread/element layout for m16n8k16, from the PTX documentation]

We'll create another helper struct:

template <typename T> struct Bfrag_16x8 {
  static constexpr size_t ne = 4;
  T x[ne] = {};
  static __device__ size_t get_row(int tid, int l) {
    return (tid % 4) * 2 + (l % 2) + 8 * (l / 2);
  }

  static __device__ size_t get_col(int tid, int l) { return tid >> 2; }
};

And finally for C:

[Figure: C/D-fragment thread/element layout for m16n8k16, from the PTX documentation]

template <typename T> struct CFrag_16x8 {
  static constexpr size_t ne = 4;
  T x[ne] = {};

  static __device__ size_t get_row(int tid, int l) {
    return (tid >> 2) + 8 * (l / 2);
  }

  static __device__ size_t get_col(int tid, int l) {
    assert(l < ne);
    return 2 * (tid % 4) + (l % 2);
  }
};

Notice that the thread mapping for B is essentially the transpose of the one for C.

Microkernel

At this point, we are ready to call this instruction!

Our simple kernel looks like this. We can see the compiled PTX on Godbolt.

__global__ void mmaKernel(const half *A, const half *B, float *C, int M, int N,
                          int K) {
  //Each thread has a copy of this tile
  Afrag_16x16<half> a_tile;
  Bfrag_16x8<half> b_tile;
  CFrag_16x8<float> c_tile;

  const int tid = threadIdx.x;

  //Load A & B fragments
  for (int idx = 0; idx < a_tile.ne; ++idx) {
    a_tile.x[idx] = A[a_tile.get_row(tid, idx) * K + a_tile.get_col(tid, idx)];
  }

  for (int idx = 0; idx < b_tile.ne; ++idx) {
    b_tile.x[idx] = B[b_tile.get_row(tid, idx) * N + b_tile.get_col(tid, idx)];
  }

  // mma instruction expects 32-bit int registers, we cast
  const int *a_regs = (const int *)a_tile.x;
  const int *b_regs = (const int *)b_tile.x;

  asm volatile("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
               "{%0, %1, %2, %3}, "
               "{%4, %5, %6, %7}, "
               "{%8, %9}, "
               "{%0, %1, %2, %3};\n"
               : "+f"(c_tile.x[0]), "+f"(c_tile.x[1]), "+f"(c_tile.x[2]),
                 "+f"(c_tile.x[3])
               : "r"(a_regs[0]), "r"(a_regs[1]), "r"(a_regs[2]), "r"(a_regs[3]),
                 "r"(b_regs[0]), "r"(b_regs[1]));


  // Write back to C
  for (int i = 0; i < c_tile.ne; ++i) {
    int row = c_tile.get_row(tid, i);
    int col = c_tile.get_col(tid, i);

    C[row * N + col] = c_tile.x[i];
  }
}
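To try it out, here is a hypothetical host-side driver for the single-tile case (buffer names and fill values are illustrative, not from the post's repo). With all-ones inputs, every entry of C should come back as 16.0f:

#include <cuda_fp16.h>
#include <vector>

int run_single_tile() {
  const int M = 16, N = 8, K = 16;
  std::vector<half> hA(M * K, __float2half(1.0f));
  std::vector<half> hB(K * N, __float2half(1.0f));
  half *dA; half *dB; float *dC;
  cudaMalloc(&dA, hA.size() * sizeof(half));
  cudaMalloc(&dB, hB.size() * sizeof(half));
  cudaMalloc(&dC, M * N * sizeof(float));
  cudaMemcpy(dA, hA.data(), hA.size() * sizeof(half), cudaMemcpyHostToDevice);
  cudaMemcpy(dB, hB.data(), hB.size() * sizeof(half), cudaMemcpyHostToDevice);

  mmaKernel<<<1, 32>>>(dA, dB, dC, M, N, K); // one block = one warp = one 16x8 tile
  return cudaDeviceSynchronize() == cudaSuccess ? 0 : 1;
}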

This kernel works, but only for M=16, N=8, K=16. It was cumbersome to get the layout right, though, and our data loads are scattered, which is not good for performance. This is where ldmatrix comes in.

Using ldmatrix

ldmatrix is an instruction that simplifies loading matrix fragments from shared memory into registers, while also being faster.

We can think of ldmatrix as a function from raw row-major matrices living in shared memory to the matrix fragments living in registers that mma expects.

The ldmatrix instruction only loads matrix fragments from shared memory. For our use case of 16-bit inputs, ldmatrix works on 8x8 tiles.

Let's dissect the instruction we're going to be using for A

ldmatrix.sync.aligned.m8n8.x4.shared.b16

ldmatrix.sync.aligned - mandatory qualifiers

m8n8 - the only available shape for 16-bit elements

x4 - we can load 1, 2, or 4 tiles in one instruction. For the 16x16 A tile we need 4

shared - load from shared memory

b16 - the element size, 16-bit in our case

This is how the tiles look in memory (each tile holds 8x8 16-bit elements):

[Figure: the 16x16 A tile split into four 8x8 sub-tiles]

The way for a thread to get its starting row/col is then:

row = threadIdx.x % 16

col = (threadIdx.x / 16) * 8
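In code, and assuming the 16x16 shared-memory tile A_shared used in the kernel below, each thread's source address works out to something like this sketch (a_tile_addr is an illustrative helper, not part of the post's code):

// Minimal sketch: each thread hands ldmatrix the address of one 8-element row
// of one 8x8 sub-tile. The sub-tile order matches the A fragment registers:
//   threads  0-7  -> rows 0-7,  cols 0-7   (1st register)
//   threads  8-15 -> rows 8-15, cols 0-7   (2nd)
//   threads 16-23 -> rows 0-7,  cols 8-15  (3rd)
//   threads 24-31 -> rows 8-15, cols 8-15  (4th)
__device__ uint32_t a_tile_addr(half (&A_shared)[16][16]) {
  int row = threadIdx.x % 16;
  int col = (threadIdx.x / 16) * 8;
  return (uint32_t)__cvta_generic_to_shared(&A_shared[row][col]);
}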

For B only two 8x8 tiles are needed, so threads 16-31 simply repeat the addresses supplied by threads 0-15, as the documentation explicitly allows:

"addresses contained in lower threads can be copied to higher threads to achieve the expected behavior."

So for B, we use row = threadIdx.x % 16

col = 0

We also use ldmatrix.sync.aligned.m8n8.x2.shared.trans.b16 as we want to load B in a column-major format.

Microkernel with ldmatrix

Our new and improved kernel, which has better memory coalescing, looks like this. Here is the Godbolt link.

__global__ void mmaKernel(const half *A, const half *B, float *C, int M, int N,
                          int K) {
  Afrag_16x16<half> a_tile;
  Bfrag_16x8<half> b_tile;
  CFrag_16x8<float> c_tile;

  const int tid = threadIdx.x;

  __shared__ alignas(16) half A_shared[16][16];
  __shared__ alignas(16) half B_shared[16][8];

  const int lane = tid & 31;

  // Use 16 threads to load shared mem.
  if (lane < 16) {
    int row = lane;

    for(int idx = 0; idx < 8; ++idx) {
        A_shared[row][idx] = A[row*K + idx];
    }

    for(int idx = 0; idx < 8; ++idx) {
        A_shared[row][idx + 8] = A[row*K + 8 + idx];
    }

    for(int idx = 0 ; idx < 8; ++idx) {
        B_shared[row][idx] = B[row*N + idx];
    }
  }

  // Make the shared-memory stores above visible to the whole warp before ldmatrix reads them.
  __syncwarp();

  int *a_regs = (int *)a_tile.x;
  int *b_regs = (int *)b_tile.x;

  // Each thread supplies the shared-memory address of one 8x8 tile row.
  uint32_t a_addr = __cvta_generic_to_shared(
      &A_shared[(lane % 16)][(lane / 16) * 8]);
  uint32_t b_addr = __cvta_generic_to_shared(
      &B_shared[(lane % 16)]);

  asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
               "{%0, %1, %2, %3}, [%4];"
               : "=r"(a_regs[0]), "=r"(a_regs[1]), "=r"(a_regs[2]),
                 "=r"(a_regs[3])
               : "r"(a_addr));

  asm volatile("ldmatrix.sync.aligned.m8n8.x2.shared.trans.b16 "
               "{%0, %1}, [%2];"
               : "=r"(b_regs[0]), "=r"(b_regs[1])
               : "r"(b_addr));

  asm volatile("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
               "{%0, %1, %2, %3}, "
               "{%4, %5, %6, %7}, "
               "{%8, %9}, "
               "{%0, %1, %2, %3};\n"
               : "+f"(c_tile.x[0]), "+f"(c_tile.x[1]), "+f"(c_tile.x[2]),
                 "+f"(c_tile.x[3])
               : "r"(a_regs[0]), "r"(a_regs[1]), "r"(a_regs[2]), "r"(a_regs[3]),
                 "r"(b_regs[0]), "r"(b_regs[1]));

  for (int i = 0; i < c_tile.ne; ++i) {
    int row = c_tile.get_row(tid, i);
    int col = c_tile.get_col(tid, i);

    C[row * N + col] = c_tile.x[i];
  }
}

A couple of things to note:

  1. Why alignas(16) on shared memory - ldmatrix mandates 16-byte-aligned rows. This is because a group of 4 threads loads 16 bytes at once (one row in our case: 8 f16 values of 2 bytes each).
  2. __cvta_generic_to_shared - this converts a generic pointer into a shared-memory state-space address, which is the form ldmatrix expects.
  3. All optimizations such as avoiding bank conflicts and vectorizing gmem/smem loads are left for later.
  4. We don't need a __syncthreads() because the block is a single warp; strictly speaking, a __syncwarp() is still needed so the shared-memory writes by threads 0-15 are visible to the whole warp before ldmatrix reads them (added above).

There is also stmatrix, which goes in the opposite direction to ldmatrix and would be useful for storing C, but it is only available on sm_90 (Hopper) and later, and I don't have a GPU handy for that.

Full GEMM

Now that we have our tile, we still need a way to compute the full GEMM. We can easily handle any M (divisible by 16) and N (divisible by 8) as long as we keep K = 16: we just launch more blocks, each computing one tile. This is the easy case because we don't have to accumulate partial sums along K.

This post wouldn't be complete without a drawing of a matrix, so here it is:

[Figure: C partitioned into 16x8 tiles, one block per tile]

As is hopefully clear from the diagram, we need to make the following changes:

  1. Launch a 2-D grid of blocks (instead of 1): N/8 in the x-dimension and M/16 in the y-dimension (see the launch sketch after this list)
  2. Move A and B matrices to the correct offset
  int c_row = blockIdx.y * 16;
  int c_col = blockIdx.x * 8;

  A += c_row * K;
  B += c_col;
  3. Use c_row and c_col to move C to the correct position
 for (int i = 0; i < c_tile.ne; ++i) {
    int row = c_row + c_tile.get_row(tid, i);
    int col = c_col + c_tile.get_col(tid, i);

    C[row * N + col] = c_tile.x[i];
  }
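As promised, here is a hypothetical launch helper for this tiled version (the name is illustrative; it assumes M % 16 == 0, N % 8 == 0, and that the pointers are device buffers):

// One block per 16x8 tile of C, one warp (32 threads) per block.
void launch_tiled(const half *d_A, const half *d_B, float *d_C,
                  int M, int N, int K) {
  dim3 grid(N / 8, M / 16); // x: tiles along N, y: tiles along M
  mmaKernel<<<grid, 32>>>(d_A, d_B, d_C, M, N, K);
}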

The next step is to handle the case where K is any multiple of 16. For this we break K down into K/16 chunks and accumulate the results. Mathematically we are doing the following:

C[i,j] = \sum_{t=0}^{K/16 - 1} \sum_{k=16t}^{16t+15} A[i,k] \, B[k,j]

Here's a diagram

[Figure: stepping through K in chunks of 16 and accumulating into the C tile]

And the relevant code block

for(int k_idx = 0; k_idx < K; k_idx += 16) {

    if (lane < 16) {
      int row = lane;

      for(int idx = 0; idx < 8; ++idx) {
          A_shared[row][idx] = A[row*K + idx];
      }

      for(int idx = 0; idx < 8; ++idx) {
          A_shared[row][idx + 8] = A[row*K + 8 + idx];
      }

      for(int idx = 0 ; idx < 8; ++idx) {
          B_shared[row][idx] = B[row*N + idx];
      }
    }

    __syncwarp(); // make the fresh tile visible to the whole warp

    A += 16;   // move A 16 columns ahead
    B += 16*N; // move B 16 rows ahead

    // ldmatrix + mma, same as before; c_tile keeps accumulating partial sums.
    // (A second __syncwarp() belongs here before the next iteration overwrites
    // the shared tiles.)
}

for (int i = 0; i < c_tile.ne; ++i) {
   int row = c_row + c_tile.get_row(tid, i);
   int col = c_col + c_tile.get_col(tid, i);

   C[row * N + col] = c_tile.x[i];
}

There you have it, a completely functioning (albeit slow) GEMM using tensor cores! Here is the Godbolt link.
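If you want to sanity-check the output, here is a simple hypothetical CPU reference (not part of the post's repo). The device results won't match bit-for-bit since the inputs are f16, but they should be close for reasonably scaled data:

#include <vector>

// Naive reference: C[i][j] = sum over k of A[i][k] * B[k][j], all in fp32.
void gemm_ref(const std::vector<float> &A, const std::vector<float> &B,
              std::vector<float> &C, int M, int N, int K) {
  for (int i = 0; i < M; ++i)
    for (int j = 0; j < N; ++j) {
      float acc = 0.0f;
      for (int k = 0; k < K; ++k)
        acc += A[i * K + k] * B[k * N + j];
      C[i * N + j] = acc;
    }
}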

Let's benchmark this kernel against the gold standard, cuBLAS. On my RTX 3090:

=== GEMM Benchmark (M=1024, N=1024, K=32) ===
FLOPs: 0.067 GFLOPs
mmaKernel:  1.907 ms  |  35.20 GFLOPs
cuBLAS:     0.286 ms  |  234.90 GFLOPs
Speedup (cuBLAS / mma): 6.67x
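For reference, the FLOP count in the header is just the standard 2MNK count for a GEMM:

2 * M * N * K = 2 * 1024 * 1024 * 32 ≈ 6.7 × 10^7 FLOPs ≈ 0.067 GFLOPs

and each GFLOPs figure is that count divided by the corresponding runtime.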

So we are embarrassingly slow even for a small K, and this will only get worse as K grows.

Optimizing

Most optimizations from here follow one basic principle: data movement is slow and compute is fast. We need to exploit the memory hierarchy via data reuse, data locality, and asynchronous data transfers.

There are many excellent articles that go into this topic in detail. Hopefully this post demystifies some of the PTX MMA semantics.
