Preface
All code discussed in this article can be found in the project repository KeqiYe/LeetGPU.
If vector addition is the most natural first CUDA exercise, then matrix multiplication is almost certainly the second. Vector addition teaches us how to scale a scalar operation across a large number of threads. Matrix multiplication, however, forces us to confront the deeper questions in GPU programming: how threads should be mapped to data, why memory access patterns directly shape throughput, what shared memory is really solving, and how to keep pushing performance once a kernel is already correct.
Matrix multiplication is also a classic for a more practical reason: it is not just a tutorial toy. Many deep learning operators and linear algebra routines can ultimately be traced back to GEMM, or General Matrix Multiplication. That is why the optimization ideas here have strong transfer value. The coalesced access patterns, tiling strategy, and register blocking techniques discussed in this article will show up again in convolutions, attention kernels, tensor transforms, and many other CUDA workloads that may look unrelated at first glance.
This article is based on my current solution for LeetGPU Problem 2. The discussion follows a clear progression:
- Start from the most naive version and establish a correct but slow baseline.
- Fix the thread mapping so that warp memory access becomes more reasonable.
- Introduce shared-memory tiling to replace repeated global loads with block-level reuse.
- Push further with 1D and 2D register blocking to increase per-thread compute density.
To make the discussion complete, I include two sets of benchmarks: a regular-size case 1024 x 512 x 1024, and a non-multiple-of-32 case 1001 x 513 x 777. The first is useful for observing raw performance trends, while the second is better for verifying that tail handling and boundary correctness are truly in good shape.
Problem Description
The task is straightforward: given two matrices
- A with shape M x K
- B with shape K x N
compute their product
$$ C = A \times B $$
From the element-wise point of view, every value C[row][col] in the output matrix is a dot product of length K:
$$ C[\text{row}][\text{col}] = \sum_{i=0}^{K-1} A[\text{row}][i] \cdot B[i][\text{col}] $$
So at its core, this problem asks: how do we let thousands of threads cooperate on these dot products while minimizing wasteful memory traffic and maximizing throughput?
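As a point of reference, the dot-product definition above can be written as a plain host-side triple loop. This is a minimal sketch assuming row-major storage; the names `matmul_cpu` and `all_close` and the tolerance value are illustrative choices, not taken from the project repository.

```cuda
#include <cmath>
#include <vector>

// Row-major CPU reference: C[row][col] = sum over i of A[row][i] * B[i][col].
std::vector<float> matmul_cpu(const std::vector<float>& A,
                              const std::vector<float>& B,
                              int M, int N, int K) {
    std::vector<float> C(static_cast<size_t>(M) * N, 0.0f);
    for (int row = 0; row < M; ++row) {
        for (int col = 0; col < N; ++col) {
            float acc = 0.0f;
            for (int i = 0; i < K; ++i)
                acc += A[row * K + i] * B[i * N + col];
            C[row * N + col] = acc;
        }
    }
    return C;
}

// Element-by-element comparison with a mixed absolute/relative tolerance.
// The default tolerance is an assumption, not a value from the article.
bool all_close(const std::vector<float>& got, const std::vector<float>& ref,
               float tol = 1e-3f) {
    if (got.size() != ref.size()) return false;
    for (size_t i = 0; i < ref.size(); ++i)
        if (std::fabs(got[i] - ref[i]) > tol * (1.0f + std::fabs(ref[i])))
            return false;
    return true;
}
```

Every GPU version discussed later has to reproduce the output of a loop like this one; the only thing that changes is how the M x N dot products are distributed across threads.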
Input / Output
- Input: device matrices A, B, and matrix sizes M, N, K
- Output: device matrix C
Test Environment and Timing Method
The performance numbers in this article come from actual runs on a real machine:
- GPU: NVIDIA GeForce RTX 4090
- CUDA Toolkit: 12.6
- Driver: 575.64.03
- Compile command: /usr/local/cuda-12.6/bin/nvcc mm.cu -lcublas -O3 -o mm
The benchmark logic is also fairly standard:
- Compute a reference result h_ref on the CPU.
- Warm up each version once so first-launch overhead does not pollute the final measurement.
- Run each kernel 10 times and take the average.
- Copy the GPU result back and compare it element by element against the CPU reference.
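The warm-up and averaging steps can be sketched with CUDA events. The article does not say which timer the benchmark uses, so the event-based approach here is an assumption, and `dummy_kernel`, `average_ms`, and `gflops` are hypothetical names; the placeholder kernel only exists so that the sketch compiles on its own.

```cuda
#include <cuda_runtime.h>

// Placeholder kernel so the sketch is self-contained; a real benchmark
// would launch one of the matmul kernels here instead.
__global__ void dummy_kernel(float* out) {
    if (threadIdx.x == 0 && blockIdx.x == 0) out[0] = 1.0f;
}

// Converts an average time in milliseconds to GFLOPS, counting one multiply
// and one add per inner-product term (2 * M * N * K operations in total).
double gflops(int M, int N, int K, double avg_ms) {
    return 2.0 * M * N * K / (avg_ms * 1e-3) / 1e9;
}

// One untimed warm-up launch, then `runs` timed launches averaged together.
float average_ms(float* d_out, int runs = 10) {
    cudaEvent_t start, stop;
    cudaEventCreate(&start);
    cudaEventCreate(&stop);

    dummy_kernel<<<1, 32>>>(d_out);  // warm-up, excluded from the timing
    cudaDeviceSynchronize();

    cudaEventRecord(start);
    for (int r = 0; r < runs; ++r)
        dummy_kernel<<<1, 32>>>(d_out);
    cudaEventRecord(stop);
    cudaEventSynchronize(stop);      // wait until all timed launches finish

    float total_ms = 0.0f;
    cudaEventElapsedTime(&total_ms, start, stop);
    cudaEventDestroy(start);
    cudaEventDestroy(stop);
    return total_ms / runs;
}
```

With this convention, `gflops(1024, 1024, 512, 1.744)` comes out at about 615.7, which is in line with the v0 figure reported later once rounding of the printed time is taken into account.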
In other words, each version discussed below is expected to satisfy both of these conditions:
- The result is correct
- The performance actually improves
v0: The Most Naive Implementation
Here is the first version:
(Code listing: v0 kernel; see the project repository.)
Logically, this version is very easy to understand. Each thread computes one output element, identifies its (row, col) coordinate, and performs a plain reduction over the K dimension. That simplicity is a real advantage: even without much CUDA background, it is easy to read and reason about. As a first step toward a working solution, this is a perfectly natural place to begin.
But once we look at it from the GPU’s perspective, the weakness becomes obvious. Here threadIdx.x is mapped to row, while threadIdx.y is mapped to col. Because consecutive threads in a warp differ in threadIdx.x, they step down a column of the row-major output matrix C instead of marching across a row, so neighboring threads touch addresses a full row apart. The direct consequence is that the warp's accesses to A and C are strided rather than coalesced, and the warp never sweeps a contiguous run of B.
More specifically, the inner loop repeatedly accesses
B[i * N + col]
If col is not changing in a contiguous way across the warp, those loads become fragmented. For a kernel like matrix multiplication, which is already extremely sensitive to bandwidth and memory access behavior, this immediately suppresses the performance ceiling.
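A minimal sketch of a kernel with this mapping might look like the following. The names `matmul_v0_sketch` and `launch_v0_sketch` and the 32 x 32 block shape are assumptions for illustration, not code taken from the repository.

```cuda
#include <cuda_runtime.h>

// Assumed v0-style mapping: threadIdx.x selects the row, threadIdx.y the column.
__global__ void matmul_v0_sketch(const float* A, const float* B, float* C,
                                 int M, int N, int K) {
    int row = blockIdx.x * blockDim.x + threadIdx.x;
    int col = blockIdx.y * blockDim.y + threadIdx.y;
    if (row < M && col < N) {
        float acc = 0.0f;
        for (int i = 0; i < K; ++i)
            acc += A[row * K + i] * B[i * N + col];  // dot product of length K
        C[row * N + col] = acc;
    }
}

// Launch helper taking device pointers; the block shape is an assumption.
void launch_v0_sketch(const float* dA, const float* dB, float* dC,
                      int M, int N, int K) {
    dim3 block(32, 32);
    dim3 grid((M + 31) / 32, (N + 31) / 32);  // grid.x covers rows, grid.y columns
    matmul_v0_sketch<<<grid, block>>>(dA, dB, dC, M, N, K);
}
```

In this sketch, the 32 threads of a warp share one value of col and take 32 consecutive values of row, so `A[row * K + i]` and `C[row * N + col]` are each spread over 32 different rows of memory.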
v0 Performance
For the regular-size case 1024 x 512 x 1024, v0 delivers:
- Average time: 1.744 ms
- Throughput: 615.76 GFLOPS
- Relative to cuBLAS: 2.32%
For the irregular-size case 1001 x 513 x 777, v0 delivers:
- Average time: 0.406 ms
- Throughput: 1964.43 GFLOPS
- Relative to cuBLAS: 12.15%
Its performance is clearly poor, which tells us that parallelizing the math alone is not enough. Memory access behavior by itself is already enough to drag throughput down dramatically.
v1: Fix the Thread Mapping so Warps Expand Along Rows
The second version looks like this:
(Code listing: v1 kernel; see the project repository.)
Compared with v0, this change looks almost trivial. We do not add shared memory, we do not change the amount of arithmetic, and we do not change the block size. We simply swap how row and col are mapped.
And yet this tiny adjustment matters a lot. Threads within the same warp now expand much more naturally along a row of the output matrix. As a result:
- Reads of B[i * N + col] are much more likely to be contiguous.
- Writes to C[row * N + col] also become contiguous.
This is one of the most important lessons in CUDA optimization. Sometimes the biggest early win is not a fancy new memory level or instruction primitive; it is simply making sure that thread layout follows memory layout. You could say that v1 is significant not because the kernel becomes more complicated, but because the thread organization finally starts working with the hardware instead of against it.
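The swap can be sketched as follows. As before, the kernel and helper names and the 32 x 32 block are illustrative assumptions; the only intended difference from a v0-style kernel is which thread index feeds row and which feeds col.

```cuda
#include <cuda_runtime.h>

// Assumed v1-style mapping: threadIdx.x selects the column, threadIdx.y the row,
// so consecutive threads in a warp touch consecutive columns.
__global__ void matmul_v1_sketch(const float* A, const float* B, float* C,
                                 int M, int N, int K) {
    int col = blockIdx.x * blockDim.x + threadIdx.x;
    int row = blockIdx.y * blockDim.y + threadIdx.y;
    if (row < M && col < N) {
        float acc = 0.0f;
        for (int i = 0; i < K; ++i)
            acc += A[row * K + i] * B[i * N + col];  // B read is contiguous in col
        C[row * N + col] = acc;                      // C write is contiguous in col
    }
}

// Launch helper taking device pointers; the block shape is an assumption.
void launch_v1_sketch(const float* dA, const float* dB, float* dC,
                      int M, int N, int K) {
    dim3 block(32, 32);
    dim3 grid((N + 31) / 32, (M + 31) / 32);  // grid.x covers columns, grid.y rows
    matmul_v1_sketch<<<grid, block>>>(dA, dB, dC, M, N, K);
}
```

Note that the grid dimensions are swapped along with the index mapping: grid.x now has to cover N, not M.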
v1 Performance
For 1024 x 512 x 1024:
- Average time: 0.216 ms
- Throughput: 4973.53 GFLOPS
- Relative to cuBLAS: 18.76%
For 1001 x 513 x 777:
- Average time: 0.191 ms
- Throughput: 4186.53 GFLOPS
- Relative to cuBLAS: 25.88%
What makes this step so striking is that it adds almost no implementation complexity, yet cuts the average time by about 8x on the regular case (1.744 ms to 0.216 ms) and by about 2x on the irregular case (0.406 ms to 0.191 ms). For CUDA beginners, that is a valuable reminder: before reaching for shared memory, tensor cores, or warp-level tricks, it is often worth checking whether the thread mapping itself is already working in the right direction.
v2: Shared-Memory Tiling Turns Repeated Loads into Block-Level Reuse
The third version enters one of the most classic steps in GEMM optimization: shared-memory tiling.
Here is the full kernel:
(Code listing: v2 kernel; see the project repository.)
The core idea here is to replace the pattern of “every multiply-add fetches fresh data from global memory” with “the block cooperatively loads a tile once, then reuses it many times.” This works because matrix multiplication naturally has data reuse. Within one output tile, a slice of A and a slice of B are used by many threads in the block. If every thread independently fetches the same values from global memory, the waste is enormous.
The structure of v2 is the standard one:
- Each block corresponds to one tile of the output matrix.
- In each phase ph, the block loads one sub-tile of A and one sub-tile of B into shared memory.
- __syncthreads() ensures that every thread sees the complete tile.
- The multiply-add work for that phase is then performed on operands read from shared memory.
So the main point of v2 is not to reduce FLOPs, but to improve data reuse and reduce pressure on global memory bandwidth. There is also an important engineering detail here: when loading tiles from A and B, the kernel explicitly checks boundaries and pads out-of-range elements with zeros. That prevents garbage values from entering shared memory when matrix dimensions are not divisible by the tile size.
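A minimal sketch of this tiling structure, including the zero padding at the boundaries, could look like the following. The tile size of 32, the array names `As` and `Bs`, and the kernel and helper names are assumptions for illustration.

```cuda
#include <cuda_runtime.h>

constexpr int TILE = 32;  // assumed tile size, equal to the block edge

__global__ void matmul_v2_sketch(const float* A, const float* B, float* C,
                                 int M, int N, int K) {
    __shared__ float As[TILE][TILE];
    __shared__ float Bs[TILE][TILE];

    int row = blockIdx.y * TILE + threadIdx.y;
    int col = blockIdx.x * TILE + threadIdx.x;
    float acc = 0.0f;

    for (int ph = 0; ph < (K + TILE - 1) / TILE; ++ph) {
        int a_col = ph * TILE + threadIdx.x;
        int b_row = ph * TILE + threadIdx.y;
        // Each thread loads one element of each tile; out-of-range
        // positions are padded with zeros so they add nothing to the sum.
        As[threadIdx.y][threadIdx.x] =
            (row < M && a_col < K) ? A[row * K + a_col] : 0.0f;
        Bs[threadIdx.y][threadIdx.x] =
            (b_row < K && col < N) ? B[b_row * N + col] : 0.0f;
        __syncthreads();  // both tiles are complete before anyone reads them

        for (int i = 0; i < TILE; ++i)
            acc += As[threadIdx.y][i] * Bs[i][threadIdx.x];
        __syncthreads();  // everyone is done before the tiles are overwritten
    }

    if (row < M && col < N) C[row * N + col] = acc;
}

// Launch helper taking device pointers.
void launch_v2_sketch(const float* dA, const float* dB, float* dC,
                      int M, int N, int K) {
    dim3 block(TILE, TILE);
    dim3 grid((N + TILE - 1) / TILE, (M + TILE - 1) / TILE);
    matmul_v2_sketch<<<grid, block>>>(dA, dB, dC, M, N, K);
}
```

With a 32 x 32 tile, each element loaded from global memory is read 32 times from shared memory, which is where the reduction in global traffic comes from. The second `__syncthreads()` matters as much as the first: without it, a fast thread could start overwriting the tile while slower threads are still reading it.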
v2 Performance
For 1024 x 512 x 1024:
- Average time: 0.174 ms
- Throughput: 6156.44 GFLOPS
- Relative to cuBLAS: 23.23%
For 1001 x 513 x 777:
- Average time: 0.160 ms
- Throughput: 4985.70 GFLOPS
- Relative to cuBLAS: 30.83%
Compared with v1, the gain here is no longer as dramatic as the jump from v0 to v1, but its significance is just as real. From this point on, we are no longer merely fixing thread layout; we are starting to redesign the dataflow in a way that matches GPU architecture much more closely.
v3: 1D Register Blocking Lets One Thread Compute More Outputs
Shared-memory tiling solves block-level data reuse, but it still assumes each thread is responsible for only a small number of outputs. The next natural step is to increase the amount of work done by each thread, so that data already loaded into registers can be reused more effectively.
That is the main idea behind v3. Here is the full code:
(Code listing: v3 kernel; see the project repository.)
The key change here is that each thread is now responsible for TN = 4 output elements along the column direction. The immediate payoff is that once a thread loads a_frag from shared memory into a register, it can multiply that value against four different b_frag values right away. In other words, the same A element becomes more valuable while it resides in registers.
Intuitively, v2 is mostly about helping the block share data efficiently, while v3 starts going one level deeper and helps each thread consume that shared data more efficiently.
Another important point is the boundary condition used in the write-back phase:
(Code listing: boundary check in the v3 write-back loop; see the project repository.)
This matters because when the matrix width N is not a multiple of TN, the last thread group may be responsible for four output positions of which only one, two, or three are valid. If we only checked global_col < N, it would be easy to write past the valid range or store incorrect tail values.
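The 1D register-blocking idea and its per-element tail check can be sketched as follows. Only TN = 4 comes from the text; the 32 x 32 x 32 block tile, the 8 x 32 thread layout, the cooperative loading loops, and all function names are assumptions made for this illustration.

```cuda
#include <cuda_runtime.h>

constexpr int BM = 32, BN = 32, BK = 32;  // assumed block tile sizes
constexpr int TN = 4;                     // outputs per thread along columns

__global__ void matmul_v3_sketch(const float* A, const float* B, float* C,
                                 int M, int N, int K) {
    __shared__ float As[BM][BK];
    __shared__ float Bs[BK][BN];

    int tx = threadIdx.x;                        // 0 .. BN / TN - 1
    int ty = threadIdx.y;                        // 0 .. BM - 1
    int tid = ty * blockDim.x + tx;
    int nthreads = blockDim.x * blockDim.y;
    int global_row = blockIdx.y * BM + ty;
    int global_col = blockIdx.x * BN + tx * TN;  // first of TN adjacent columns

    float acc[TN] = {0.0f};

    for (int ph = 0; ph < (K + BK - 1) / BK; ++ph) {
        // All threads cooperate to fill both tiles, padding with zeros.
        for (int idx = tid; idx < BM * BK; idx += nthreads) {
            int r = idx / BK, c = idx % BK;
            int a_row = blockIdx.y * BM + r, a_col = ph * BK + c;
            As[r][c] = (a_row < M && a_col < K) ? A[a_row * K + a_col] : 0.0f;
        }
        for (int idx = tid; idx < BK * BN; idx += nthreads) {
            int r = idx / BN, c = idx % BN;
            int b_row = ph * BK + r, b_col = blockIdx.x * BN + c;
            Bs[r][c] = (b_row < K && b_col < N) ? B[b_row * N + b_col] : 0.0f;
        }
        __syncthreads();

        for (int i = 0; i < BK; ++i) {
            float a_frag = As[ty][i];            // loaded once, used TN times
            for (int j = 0; j < TN; ++j) {
                float b_frag = Bs[i][tx * TN + j];
                acc[j] += a_frag * b_frag;
            }
        }
        __syncthreads();
    }

    // Each of the TN outputs gets its own column check, so a partially
    // valid group at the right edge writes only its valid positions.
    for (int j = 0; j < TN; ++j)
        if (global_row < M && global_col + j < N)
            C[global_row * N + global_col + j] = acc[j];
}

// Launch helper taking device pointers.
void launch_v3_sketch(const float* dA, const float* dB, float* dC,
                      int M, int N, int K) {
    dim3 block(BN / TN, BM);
    dim3 grid((N + BN - 1) / BN, (M + BM - 1) / BM);
    matmul_v3_sketch<<<grid, block>>>(dA, dB, dC, M, N, K);
}
```

For example, with N = 45 the thread whose global_col is 44 owns columns 44 through 47, of which only column 44 exists. Checking `global_col < N` alone would let it write three elements past the end of the row.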
v3 Performance
For 1024 x 512 x 1024:
- Average time: 0.098 ms
- Throughput: 11001.81 GFLOPS
- Relative to cuBLAS: 41.51%
For 1001 x 513 x 777:
- Average time: 0.097 ms
- Throughput: 8236.99 GFLOPS
- Relative to cuBLAS: 50.93%
From the benchmark results, v3 is another very visible step upward. At this point, we are no longer just making memory access sane; we are also raising arithmetic intensity at the thread level in a meaningful way.
v4: 2D Register Blocking Pushes Per-Thread Compute Density Even Further
If v3 means “one thread computes several outputs in one direction,” then v4 goes one step further and gives each thread a full TM x TN local tile.
Here is the complete kernel:
(Code listing: v4 kernel; see the project repository.)
In this version, a thread no longer owns just a short strip of outputs. Instead, it owns a small TM x TN = 4 x 4 rectangle. The benefit is that values loaded from shared memory into a_frag and b_frag can be reused more times through register-level cross-combination. As a result:
- Data reuse improves further.
- Per-thread compute density improves further.
- The kernel structure starts to resemble the typical shape of a high-performance hand-written GEMM kernel.
Compared with v3, v4 is not merely computing more outputs. It is effectively performing a tiny matrix multiplication inside each thread. This two-dimensional blocking idea is common in many high-performance GEMM implementations.
In the write-back stage, each (i, j) position computes its own c_row and c_col, then applies a separate boundary check. That makes tail handling fairly natural without requiring special-case hacks.
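A sketch of the 2D variant is shown below. The 4 x 4 per-thread tile and the names a_frag, b_frag, c_row, and c_col follow the text; the 64 x 64 x 16 block tile, the 16 x 16 thread layout, and the function names are assumptions chosen to keep the example small.

```cuda
#include <cuda_runtime.h>

constexpr int BM = 64, BN = 64, BK = 16;  // assumed block tile sizes
constexpr int TM = 4, TN = 4;             // per-thread output tile

__global__ void matmul_v4_sketch(const float* A, const float* B, float* C,
                                 int M, int N, int K) {
    __shared__ float As[BM][BK];
    __shared__ float Bs[BK][BN];

    int tx = threadIdx.x, ty = threadIdx.y;   // 16 x 16 threads per block
    int tid = ty * blockDim.x + tx;
    int nthreads = blockDim.x * blockDim.y;
    int row0 = blockIdx.y * BM + ty * TM;     // top-left corner of this
    int col0 = blockIdx.x * BN + tx * TN;     // thread's TM x TN tile

    float acc[TM][TN] = {};

    for (int ph = 0; ph < (K + BK - 1) / BK; ++ph) {
        for (int idx = tid; idx < BM * BK; idx += nthreads) {
            int r = idx / BK, c = idx % BK;
            int a_row = blockIdx.y * BM + r, a_col = ph * BK + c;
            As[r][c] = (a_row < M && a_col < K) ? A[a_row * K + a_col] : 0.0f;
        }
        for (int idx = tid; idx < BK * BN; idx += nthreads) {
            int r = idx / BN, c = idx % BN;
            int b_row = ph * BK + r, b_col = blockIdx.x * BN + c;
            Bs[r][c] = (b_row < K && b_col < N) ? B[b_row * N + b_col] : 0.0f;
        }
        __syncthreads();

        for (int k = 0; k < BK; ++k) {
            float a_frag[TM], b_frag[TN];
            for (int i = 0; i < TM; ++i) a_frag[i] = As[ty * TM + i][k];
            for (int j = 0; j < TN; ++j) b_frag[j] = Bs[k][tx * TN + j];
            // Outer product: TM + TN shared loads feed TM * TN multiply-adds.
            for (int i = 0; i < TM; ++i)
                for (int j = 0; j < TN; ++j)
                    acc[i][j] += a_frag[i] * b_frag[j];
        }
        __syncthreads();
    }

    for (int i = 0; i < TM; ++i)
        for (int j = 0; j < TN; ++j) {
            int c_row = row0 + i, c_col = col0 + j;
            if (c_row < M && c_col < N)       // independent check per output
                C[c_row * N + c_col] = acc[i][j];
        }
}

// Launch helper taking device pointers.
void launch_v4_sketch(const float* dA, const float* dB, float* dC,
                      int M, int N, int K) {
    dim3 block(BN / TN, BM / TM);
    dim3 grid((N + BN - 1) / BN, (M + BM - 1) / BM);
    matmul_v4_sketch<<<grid, block>>>(dA, dB, dC, M, N, K);
}
```

The inner step is where the register-level reuse shows up: 8 values read from shared memory feed 16 multiply-adds, compared with 5 values for 4 multiply-adds in a 1 x 4 strip.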
v4 Performance
For 1024 x 512 x 1024:
- Average time: 0.058 ms
- Throughput: 18368.88 GFLOPS
- Relative to cuBLAS: 69.30%
For 1001 x 513 x 777:
- Average time: 0.058 ms
- Throughput: 13848.77 GFLOPS
- Relative to cuBLAS: 85.62%
From a pure performance point of view, this is the strongest custom kernel in the current implementation. It still trails cuBLAS, of course, but for a teaching-oriented hand-written solution, this level of performance is already a strong sign that the optimization path is working very well.
cuBLAS Reference Implementation
Here is the cuBLAS call:
(Code listing: cuBLAS call; see the project repository.)
The parameter order may look reversed at first glance. The reason is that cuBLAS assumes column-major storage by default, while the matrices in this project are stored in row-major order. Instead of explicitly transposing the inputs before the call, the code uses a standard trick:
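$$ (AB)^T = B^T A^T $$
A row-major M x N buffer, read as column-major, is the N x M transpose of the same matrix. So the row-major buffers of A and B already look like A^T and B^T to cuBLAS, and passing B first and A second makes it compute B^T * A^T = (AB)^T in column-major form. That column-major result occupies memory exactly as the row-major output matrix C that we want, so no explicit transpose is needed on either side.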
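One way to write the call under this trick is sketched below, assuming single-precision data and no transposition flags. The wrapper name `sgemm_row_major` is hypothetical; `cublasSgemm` and its argument order are the standard cuBLAS API.

```cuda
#include <cublas_v2.h>
#include <cuda_runtime.h>

// Row-major C (M x N) = A (M x K) * B (K x N), obtained by asking
// column-major cuBLAS for C^T (N x M) = B^T (N x K) * A^T (K x M).
cublasStatus_t sgemm_row_major(cublasHandle_t handle,
                               const float* dA, const float* dB, float* dC,
                               int M, int N, int K) {
    const float alpha = 1.0f, beta = 0.0f;
    return cublasSgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N,
                       N, M, K,  // dimensions of the column-major product
                       &alpha,
                       dB, N,    // row-major B seen as column-major B^T, ld = N
                       dA, K,    // row-major A seen as column-major A^T, ld = K
                       &beta,
                       dC, N);   // column-major C^T, i.e. row-major C, ld = N
}
```

Each leading dimension is simply the row length of the corresponding row-major matrix, which is an easy way to sanity-check the argument list.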
Benchmark Summary
To keep the performance discussion in one place, here are the two benchmark tables together.
Regular Size: 1024 x 512 x 1024
| Version | Status | Avg Time (ms) | Performance (GFLOPS) | Relative to cuBLAS |
|---|---|---|---|---|
| v0 | PASS | 1.744 | 615.76 | 2.32% |
| v1 | PASS | 0.216 | 4973.53 | 18.76% |
| v2 | PASS | 0.174 | 6156.44 | 23.23% |
| v3 | PASS | 0.098 | 11001.81 | 41.51% |
| v4 | PASS | 0.058 | 18368.88 | 69.30% |
| cuBLAS | PASS | 0.041 | 26506.38 | 100% |
Irregular Size: 1001 x 513 x 777
| Version | Status | Avg Time (ms) | Performance (GFLOPS) | Relative to cuBLAS |
|---|---|---|---|---|
| v0 | PASS | 0.406 | 1964.43 | 12.15% |
| v1 | PASS | 0.191 | 4186.53 | 25.88% |
| v2 | PASS | 0.160 | 4985.70 | 30.83% |
| v3 | PASS | 0.097 | 8236.99 | 50.93% |
| v4 | PASS | 0.058 | 13848.77 | 85.62% |
| cuBLAS | PASS | 0.049 | 16174.26 | 100% |
What This Problem Is Really Teaching
If we only read this problem as “how to make matrix multiplication faster,” we miss part of its value. To me, its real strength is that it makes several core layers of CUDA optimization visible in a very concrete way.
Thread mapping is foundational.
Many beginners focus first on advanced-sounding ideas like shared memory, tensor cores, and register blocking. But if warp access direction is wrong from the start, everything built on top of it rests on a crooked foundation. The jump from v0 to v1 is a perfect demonstration: changing only the mapping can drastically change performance.
Shared memory is about reuse, not prestige.
Shared memory is worth introducing only when a block of data is reused by many threads within the block. Matrix multiplication naturally satisfies this condition, which is why tiling is so effective here.
Register blocking is really about raising per-thread compute density.
Whether in the 1D form of v3 or the 2D form of v4, the goal is the same: once data has been loaded into registers, make it participate in as many multiply-add operations as possible. That mindset is central to many high-performance kernels.
Boundary correctness must be tested explicitly.
If we only benchmark 1024 x 512 x 1024, it is easy to believe a kernel is fully correct. But once we switch to a size like 1001 x 513 x 777, many hidden bugs become much easier to expose.
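The reason irregular sizes are so revealing can be seen from the grid computation alone. The helper below is a small sketch with hypothetical names (`ceil_div`, `grid_for`) and an assumed 32 x 32 output tile per block.

```cuda
#include <cuda_runtime.h>

// Ceiling division: number of tiles needed to cover n elements.
constexpr int ceil_div(int n, int tile) { return (n + tile - 1) / tile; }

// Grid for a kernel whose blocks each cover a tile x tile patch of C,
// with grid.x spanning columns and grid.y spanning rows.
dim3 grid_for(int M, int N, int tile = 32) {
    return dim3(ceil_div(N, tile), ceil_div(M, tile));
}
```

Taking an output of 1001 rows by 777 columns purely as an illustration, `grid_for(1001, 777)` yields 25 x 32 blocks, which cover 1024 rows and 800 columns. The threads in the extra 23 rows and 23 columns exist and run, so every load and store they perform has to be guarded. When every dimension is a multiple of the tile size, there are no such threads and a missing guard goes unnoticed.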