写在前面
本文涉及的全部代码都可以在项目仓库 KeqiYe/LeetGPU 中找到。
如果说向量加法是 CUDA 入门阶段最合适的第一题,那么矩阵乘法几乎就是顺理成章的第二题。前者让我们学会“如何把一个标量运算扩展到海量线程上”,后者则逼着我们真正开始面对 GPU 编程里更本质的问题:线程应该如何映射到数据、访存模式为什么会直接决定吞吐、共享内存到底是在解决什么问题,以及当一个 kernel 已经“算得对”之后,如何继续把它一点点推向更高的性能。
矩阵乘法之所以经典,还有一个更现实的原因:它不是一道只存在于教程里的练习题。很多深度学习算子、线性代数库,最终都能追溯到 GEMM(General Matrix Multiplication)这样的核心计算模式。也正因为如此,矩阵乘法里的优化思路往往有很强的迁移价值。你今天在这里理解的访存合并、tiling、寄存器粗化,之后几乎一定会在卷积、注意力、张量变换,甚至很多看起来完全不同的 CUDA kernel 里再次遇到。
这篇文章基于我当前的 LeetGPU 第二题实现来写。文章会沿着下面这条主线展开:
- 从最朴素的实现出发,先建立一个“正确但很慢”的基准。
- 调整线程映射方式,让 warp 的访存方向开始变得合理。
- 引入共享内存分块,把“每次都去全局显存取数据”的模式改成“先搬一块,再反复复用”。
- 继续做 1D 和 2D 的寄存器粗化,提高单线程的计算密度。
为了让讨论更完整,文中会同时给出两组 benchmark:一组是规则尺寸 1024 x 512 x 1024,另一组是非 32 倍数尺寸 1001 x 513 x 777。前者便于观察纯粹的性能趋势,后者则更适合验证尾块处理和边界正确性是否已经真正做好。
题目描述
题目要求很直接:给定两个矩阵
A,形状为M x KB,形状为K x N
计算它们的乘积:
$$ C = A \times B $$从元素视角来看,输出矩阵 C 中的每一个元素 C[row][col],都需要做一遍长度为 K 的点积:
因此,这道题本质上是在问:如何让成千上万个线程一起去并行完成这些点积计算,并且在这个过程中尽量减少无效访存、提高吞吐。
输入输出
- 输入:设备端矩阵
A、B,以及矩阵尺寸M、N、K - 输出:设备端矩阵
C
测试环境与计时方式
本文中的性能数据不是凭感觉估计的,而是在实体机上真实运行得到。环境如下:
- GPU:
NVIDIA GeForce RTX 4090 - CUDA Toolkit:
12.6 - Driver:
575.64.03 - 编译命令:
/usr/local/cuda-12.6/bin/nvcc mm.cu -lcublas -O3 -o mm
程序的 benchmark 逻辑也比较规范,大致流程是:
- 在 CPU 上先计算一份参考结果
h_ref - 每个版本先 warmup 一次,避免把首次启动开销混进正式结果
- 再重复运行 10 次,取平均耗时
- 将 GPU 结果拷回主机,与 CPU 参考结果逐元素比较
换句话说,后面每个版本的讨论都不是停留在“理论上应该更快”,而是同时满足两件事:
- 结果正确
- 性能确实提升
v0:最朴素的实现
先看第一版代码:
| |
从逻辑上看,这个版本非常直观:每个线程负责输出矩阵中的一个元素,先定位 (row, col),然后在 K 维上做一次最朴素的累加。它有一个很大的优点,就是几乎不需要任何 CUDA 背景知识也能读懂。如果只是为了先把题做出来,这种写法是很自然的起点。
但只要稍微从 GPU 的视角看一眼,它的问题就会立刻暴露出来。这里把 threadIdx.x 映射到了 row,把 threadIdx.y 映射到了 col。这意味着同一个 warp 内的线程,并不是沿着输出矩阵 C 的一行横向展开,而更像是在列方向上分散。这样的直接后果是,访问矩阵 B 时,warp 内线程对 col 的分布不连续,全局显存读取很难形成理想的合并访存。
更具体一点,内层循环每次都会访问:
| |
如果同一个 warp 中的 col 不是连续递增的,那么这些读取就会显得零碎。对于矩阵乘法这种本来就非常吃带宽、吃访存模式的算子来说,这几乎等于一开始就把性能上限压得很低。
v0 的性能表现
在规则尺寸 1024 x 512 x 1024 下,v0 的表现是:
- 平均耗时:
1.744 ms - 吞吐:
615.76 GFLOPS - 相对 cuBLAS:
2.32%
在非整齐尺寸 1001 x 513 x 777 下,v0 的表现是:
- 平均耗时:
0.406 ms - 吞吐:
1964.43 GFLOPS - 相对 cuBLAS:
12.15%
它的性能明显偏低,说明仅仅把矩阵乘法并行化还远远不够,访存模式本身就足以把吞吐压到很低的水平。
v1:修正线程映射,让 warp 真正沿着行展开
第二版代码如下:
| |
和 v0 相比,这个版本做的事情几乎可以用“微不足道”来形容:它没有引入共享内存,没有改 block 规模,也没有减少任何运算量,只是把 row 和 col 的映射方式调换了一下。
但恰恰是这个小改动,让同一个 warp 中的线程更自然地在输出矩阵的一行上横向展开。于是:
- 对
B[i * N + col]的读取更容易形成连续访问 - 对
C[row * N + col]的写回也变成连续写回
这一步看似简单,实际上非常关键。因为矩阵乘法这种问题里,很多时候“线程到底在按行走还是按列走”,本身就是性能优化的一半。你甚至可以说,v1 的意义不在于它把 kernel 写得更复杂,而在于它第一次让线程布局真正顺着内存布局去走。
v1 的性能表现
在规则尺寸 1024 x 512 x 1024 下:
- 平均耗时:
0.216 ms - 吞吐:
4973.53 GFLOPS - 相对 cuBLAS:
18.76%
在非整齐尺寸 1001 x 513 x 777 下:
- 平均耗时:
0.191 ms - 吞吐:
4186.53 GFLOPS - 相对 cuBLAS:
25.88%
这一步最震撼的地方在于,它几乎没有增加实现复杂度,却把性能提升了一个数量级。这对 CUDA 初学者来说是一个非常重要的提醒:在还没有碰共享内存、张量核心、warp-level primitive 之前,先把线程映射方向摆正,往往比你想象中更重要。
v2:共享内存分块,把“重复读取”变成“块内复用”
第三个版本开始进入矩阵乘法优化里最经典的一步:shared memory tiling。
完整代码如下:
| |
这一步的本质,是把“每次乘加都直接去全局显存取一对新数据”的模式,改成“先让整个 block 协同搬一小块数据到共享内存,再反复使用”。这样做的原因并不神秘:矩阵乘法天生会复用输入数据。同一个输出 tile 的计算中,A 的一段行片段和 B 的一段列片段都会被 block 内的多个线程反复访问。如果每个线程都各自去全局显存抓同一份数据,那显然浪费巨大。
v2 的做法很典型:
- 每个 block 对应输出矩阵中的一个 tile。
- 每一轮
ph(phase)加载A和B的一个子块进入共享内存。 - 通过
__syncthreads()保证块内线程都看到了完整 tile。 - 再在共享内存上完成这一轮 tile 的乘加。
因此,v2 优化的核心不是减少 FLOPs,而是提高数据复用率,降低全局显存压力。这版代码还有一个非常重要的工程细节:它在加载 A 和 B tile 时都做了显式的越界判断,不足一个 tile 的部分直接补零。这样避免了当矩阵维度不能整除block的大小的时候,共享内存中可能会存在的垃圾数据。
v2 的性能表现
在规则尺寸 1024 x 512 x 1024 下:
- 平均耗时:
0.174 ms - 吞吐:
6156.44 GFLOPS - 相对 cuBLAS:
23.23%
在非整齐尺寸 1001 x 513 x 777 下:
- 平均耗时:
0.160 ms - 吞吐:
4985.70 GFLOPS - 相对 cuBLAS:
30.83%
和 v1 相比,这一版的提升已经没有 v0 -> v1 那么夸张,但它的意义并不因此降低。因为从这里开始,我们进入的已经不再是“把映射摆正”的阶段,而是开始真正用 GPU 体系结构的思维去改写数据流。
v3:1D 寄存器粗化,让一个线程一次算更多结果
共享内存 tiling 能解决块内数据复用的问题,但它仍然默认每个线程只负责很少的输出元素。下一步比较自然的优化,就是把一个线程的工作量再往上提一点,让它一次计算多个输出值,从而提高寄存器里数据的复用率。
这就是 v3 的思路。完整代码如下:
| |
这版最核心的变化,是让一个线程在列方向上同时负责 TN = 4 个输出元素。这样做的直接收益是:线程一旦把 a_frag 从共享内存读进寄存器,就可以立刻拿它和 4 个不同的 b_frag 做乘加。换句话说,同一个 A 元素的寄存器驻留价值被提高了。
从直觉上看,v2 更像是在优化“一个 block 怎么更高效地共享数据”,而 v3 开始进一步优化“一个线程怎么更高效地消费这些数据”。
这里还有一个非常关键的修正,就是写回阶段的边界判断:
| |
这行判断的意义在于,当矩阵宽度 N 不是 TN 的整数倍时,最后一个线程组负责的 4 个输出位置里,可能只有前 1 个、2 个或 3 个是真正有效的。如果仍然只判断 global_col < N,那么写尾部时就很容易越界或者把错误值写到不该写的位置。
v3 的性能表现
在规则尺寸 1024 x 512 x 1024 下:
- 平均耗时:
0.098 ms - 吞吐:
11001.81 GFLOPS - 相对 cuBLAS:
41.51%
在非整齐尺寸 1001 x 513 x 777 下:
- 平均耗时:
0.097 ms - 吞吐:
8236.99 GFLOPS - 相对 cuBLAS:
50.93%
从结果上看,v3 是这条优化链里的另一个明显台阶。到了这一版,我们已经不只是“把访存做对”,而是开始显著提高单线程的算术强度了。
v4:2D 寄存器粗化,把单线程计算粒度继续做大
如果说 v3 是“一个线程沿着一个方向多算几个元素”,那么 v4 就更进一步:它直接让每个线程维护一个 TM x TN 的小块,也就是一个二维局部输出 tile。
完整代码如下:
| |
这版代码里,每个线程不再只维护一串列方向结果,而是维护一个 TM x TN = 4 x 4 的小矩形块。它的好处在于,线程从共享内存读出的 a_frag 和 b_frag,可以在寄存器里被更多次地交叉组合使用。于是:
- 数据复用进一步提高
- 单线程计算密度进一步提高
- kernel 的结构更接近高性能 GEMM 的典型形态
和 v3 相比,v4 不只是“算更多元素”,而是把“一个线程内部的小矩阵乘法”也做了出来。这种二维粗化思路在很多高性能手写 GEMM kernel 中都非常常见。
写回阶段,这版代码对每个 (i, j) 都单独计算 c_row 和 c_col,再分别做边界检查,因此尾块处理也比较自然,不需要额外的专门 hack。
v4 的性能表现
在规则尺寸 1024 x 512 x 1024 下:
- 平均耗时:
0.058 ms - 吞吐:
18368.88 GFLOPS - 相对 cuBLAS:
69.30%
在非整齐尺寸 1001 x 513 x 777 下:
- 平均耗时:
0.058 ms - 吞吐:
13848.77 GFLOPS - 相对 cuBLAS:
85.62%
从纯性能角度看,这已经是当前自定义 kernel 里最强的一版。它离 cuBLAS 当然还有差距,但对一份教学型手写实现来说,能够稳定到这个水平,我觉得已经说明这条优化路线非常有效了。
cuBLAS 参考实现
这里附上调用 cuBLAS的代码:
| |
这里把参数顺序写反是因为 cuBLAS 默认使用列主序,而我们这里的矩阵是按行主序放在内存里的。为了不在调用前显式转置矩阵,代码采用了一个非常常见的技巧:利用
$$ (AB)^T = B^T A^T $$这个关系,把“行主序下的 A * B”映射成“列主序下的 B^T * A^T”来算。最终虽然 cuBLAS 内部按列主序解释数据,但我们拿回来的结果,刚好就是希望得到的行主序矩阵 C。
两组 benchmark 汇总
为了避免把性能讨论拆得太散,这里把两组 benchmark 集中放在一起。
规则尺寸:1024 x 512 x 1024
| 版本 | 状态 | 平均耗时 (ms) | 性能 (GFLOPS) | 相对 cuBLAS |
|---|---|---|---|---|
| v0 | PASS | 1.744 | 615.76 | 2.32% |
| v1 | PASS | 0.216 | 4973.53 | 18.76% |
| v2 | PASS | 0.174 | 6156.44 | 23.23% |
| v3 | PASS | 0.098 | 11001.81 | 41.51% |
| v4 | PASS | 0.058 | 18368.88 | 69.30% |
| cuBLAS | PASS | 0.041 | 26506.38 | 100% |
非整齐尺寸:1001 x 513 x 777
| 版本 | 状态 | 平均耗时 (ms) | 性能 (GFLOPS) | 相对 cuBLAS |
|---|---|---|---|---|
| v0 | PASS | 0.406 | 1964.43 | 12.15% |
| v1 | PASS | 0.191 | 4186.53 | 25.88% |
| v2 | PASS | 0.160 | 4985.70 | 30.83% |
| v3 | PASS | 0.097 | 8236.99 | 50.93% |
| v4 | PASS | 0.058 | 13848.77 | 85.62% |
| cuBLAS | PASS | 0.049 | 16174.26 | 100% |
这道题真正值得带走的东西
如果只把这题理解成“矩阵乘法怎么提速”,那其实还不够。对我来说,这道题更重要的价值在于它把 CUDA 优化里几个非常核心的层次关系展示得很清楚。
线程映射是基础中的基础。
很多初学者一开始会把注意力全放在 shared memory、tensor core、寄存器粗化这些更“高级”的词上,但如果 warp 的访问方向一开始就是错的,那么后面的所有优化都只能在一个已经歪掉的地基上修补。v0 -> v1 的结果恰好说明了这个问题:只改映射方式,性能就能暴涨。
shared memory 的意义不是“看上去更高级”,而是数据复用
只有当一块数据会被 block 内多个线程反复使用时,shared memory 才真正值得引入。矩阵乘法正好天然满足这个条件,所以 tiling 才会这么有效。
寄存器粗化本质上是在提高单线程计算密度
不论是 v3 的一维粗化,还是 v4 的二维粗化,本质上都是在想办法让已经加载进寄存器的数据,多参与几次乘加。对高性能 kernel 来说,这种“把数据榨干”的思路往往非常关键。
边界正确性必须单独验证。
如果只测 1024 x 512 x 1024,你很容易误以为一个 kernel 已经“完全正确”。但一旦换成 1001 x 513 x 777 这种不整齐尺寸,很多隐藏 bug 才会真正暴露出来。