SGEMM实施的完整演练

在开始之前,我想提一下,这项工作的大部分都是从对cublas Kepler和Maxwell的sgemm实现的详细研究中得出的。我做了一些适度的改进,但大多数难题都由英伟达的优秀工程师和他们对硬件的专业知识解决了。本文档的目标是传播这些知识,供其他人在自己的代码中使用。我还想联系两篇关于sgemm主题的优秀论文:MAGMA原始论文(http://icl.cs.utk.edu/projectsfiles/magma/pubs/fermi_gemm.pdf)和赖俊杰的Kepler sgemm论文(http://hal.inria.fr/docs/00/78/99/58/PDF/112_Lai.pdf)。本文档基本上是Junjie工作的扩展,但具有Maxwell架构和额外的汇编级优化。

Overview

以sgemm为例,本文旨在描述如何最大化Maxwell架构及其他架构的计算能力。拥有数千个计算核心对你没有好处,除非你让它们得到数据。要做到这一点,您需要构建计算结构,以最大限度地重用通过各种内存层次结构提取的数据。在GPU上,这些是:设备内存到二级缓存,二级缓存到纹理缓存,纹理缓存到寄存器,寄存器到共享内存,共享内存到寄存器,从寄存器到指令操作数缓存(Maxwell的新功能),最后从寄存器返回到设备内存。这些数据路径中的每一条都有延迟,我们需要用指令和线程级并行性(ILP&TLP)来隐藏这些延迟。此外,还可能存在bank和联合约束。所提出的sgemm代码能够克服所有这些约束,并在硬件理论错误的2%内运行。

本文档将介绍两种不同的布局:每个块64个线程和每个块256个线程。我将主要讨论64线程版本,因为映射更小更简单。256线程版本或多或少是相同的,只是放大了4倍。这两个版本分别针对小型或大型矩阵进行了优化。较小的64线程版本可以将矩阵拆分为4倍多的块,这在SM稀少的情况下非常有用,但代价是所需的设备内存带宽是256线程版本的两倍。在GM204硬件上,这个额外的带宽实际上超过了可用的带宽,因此只有当有更多的可用块来填充SM超过了成本时,您才想使用它(除非L2可以隐藏它)。虽然,如果您有足够的并行工作,使用流来填充SM是更好的方法。

在这两个版本中,我们将使用双缓冲8寄存器块来加载A和B中的每一个。双缓冲允许我们从共享内存中隐藏加载的大部分延迟:我们可以计算一个寄存器块,同时加载下一个寄存器。我们选择8个寄存器块,因为它与使用四矢量内存指令很好地对齐,并且因为我们可以将总寄存器预算保持在128以下。跨越128个寄存器的障碍将使我们的占用率从每个调度器的4个活动减少到64线程版本的3个,从256线程版本的4个减少到2个。64线程版本不太容易受到下降的影响,我实际上看到了一些矩阵大小的性能提高(减少了L2和纹理缓存稀释,每个SM的块更少),但256线程版本的工作性能稍好,每个调度程序多了1个扭曲,以覆盖延迟。我们在这两种实现中的性能都不会受到占用率下降的巨大影响,这说明了这段代码如何很好地隐藏ILP的延迟。

我们还将对共享内存进行双重缓冲,以便删除其中一个我们通常需要在主循环中进行的bar.syncs,而不是在存储下一批之前等待所有共享加载完成,我们只需开始写入一个新的共享区域,而其他线程可以从上一个区域读取数据。您将在下面看到,这将向主循环添加3个XOR,但这仍然比bar.syncs便宜。至于共享内存的大小,这是由每个线程块加载的内存宽度乘以主循环展开因子来定义的。我们有64个(或256个)线程,每个线程将计算8*8或64个C点。所有这些点将一起排列成正方形,因为我们从a和B均匀拉动。所以这个正方形的宽度就是总点数的平方根。对于我们的两个实现,我们计算:

1
2
64  Threads: sqrt(64  * 8*8) = 64  units wide
256 Threads: sqrt(256 * 8*8) = 128 units wide*

我们的展开因子是我们一次从A和B读取、从共享存储/读取和计算的行数。它将在几个方面受到限制。我们希望能够通过尽可能多的计算工作来隐藏纹理负载的延迟。但是,我们不希望循环的大小超过指令缓存的大小。这样做会增加额外的指令获取延迟,我们需要隐藏这些延迟。在Maxwell上,我测得这个缓存为8KB。因此,这意味着我们不希望循环大小超过1024个8字节指令,其中每4个指令都是一个控制代码。所以768是有用指令的极限。此外,还有指令对齐的注意事项,因此您也希望安全地处于该值之下。简而言之,使用8的循环展开因子可以得到8 x 64=512 ffma指令加上循环所需的额外内存和整数算术指令(约40)。这使我们大大低于768。每个循环8行也与纹理内存负载的维度很好地对齐。最后,512个FFMA应该足以大部分隐藏200+时钟纹理加载延迟。

因此,我们现在知道了共享内存的总大小:(每个循环8行)x(块加载宽度)x(字大小)x(A的2个缓冲区)x(B的2个缓存区)。64个线程为8192字节,256个线程为16384字节。这种大小不会影响占用率,占用率由寄存器计数(我们将保持在128以下)决定。

下面是两种实现共享的基本内存布局。注意,我将X维度与来自A的载荷相等,并沿lda对齐,而Y维度与来自B的载荷相等并沿ldb对齐。这与x和y通常在空间上的定义方式相反,如下图所示。还要注意的是,A和C的图像被布置为转置。回想起来,我可能会把它改成B作为转置,并与A交换,但这就是我最初的计算方法。在下一节中,我将开始详细讨论64线程版本。

image-20221208152310575

64 Thread Implementation

加载A和B,然后存储到共享

为了加载A和B矩阵,我们使用了一种在cuda c或ptx中无法有效实现的技术。我们将线程分成两半,让每一半加载一个矩阵。由于我们有64个线程,这意味着每个warp加载一个矩阵。cuda中的条件加载没有得到很好的优化,因为编译器没有努力确定加载是否在warp上均匀发生。对于纹理加载,这是必要的,因为指令每次只能处理一个纹理。因此,除了纹理加载之外,编译器还会添加一堆warp刷新以及分支和同步指令,以确保强制执行。如果Nvidia提供一种方式来提示条件或谓词是warp一致的(而不仅仅是分支,即bra.uni),那就太好了。

使用此技术的主要优点是,我们只需要一组跟踪寄存器来保存纹理加载索引。在主循环内部,这是一个巨大的胜利,因为它减少了我们需要的整数加法指令的一半。我们利用一切机会提高FFMA指令与非FFMA指令的比率。

我们还维护了4个单独的轨迹变量,以避免在每次纹理加载后使用依赖性屏障将单个轨迹变量增加ldx*2。内存指令发出时不会复制其操作数寄存器。这样做可能会节省晶体管。相反,当内存指令仍在运行时,您可以使用屏障来防止对这些寄存器的写入。在障碍处等待并不一定很糟糕,因为TLP可以启动并覆盖延迟,但减少需要覆盖的延迟总数可以帮助性能,因为这增加了有翘曲覆盖它们的机会。我们没有任何额外的循环IADDS,因为它有4个跟踪变量,只有3个额外的寄存器,这是我们可以轻松负担的。

所以我们将通过纹理单元加载。通过使用显式纹理加载而不是全局加载,无论是否使用非相干缓存,我们都可以获得一些好处。一是这使得代码更加简单,因为我们不需要担心加载超出范围。第二,使用相同的内核代码,我们可以加载8位或16位浮点,从而显著减少带宽和存储需求。有些应用程序不需要完全32位精度,在这种情况下这是一个巨大的胜利。

此外,我们将加载四元向量。这是对cublas代码的更改,在性能方面产生了最大的差异。虽然我可以理解为什么它不在立方体中使用,因为它对输入数据施加了4个字的对齐约束。立方体有一个固定的规范(这并不是说如果检测到四边对齐,它就不能选择不同的代码路径)。因此,通过使用四元向量,我们需要将lda/ldb索引向下折叠4。这有一个额外的好处,即允许我们加载索引大小为31位的矩阵,而不是常规纹理加载27位的限制。四元加载的另一个工件是我们的内存访问模式在每次提取时都会拉入并消耗全部缓存线。这意味着我们只能得到非常有限的纹理缓存使用率(1-2%),而我们的内存缓存性能将由二级缓存控制。

下面是一些伪代码,它只显示了主循环中的纹理加载和共享存储。你可以从地图上看到,这是非常直接的。你会注意到STS。128条指令,我们将遇到存储体冲突,但这些冲突是不可避免的,结果不会影响性能,因为批量加载和存储到向量指令中是一个双赢。此外,我甚至不确定银行冲突期间发生的指令回放是否重要,因为我认为这些指令可能会与FFMA一起发出。事实上,所有的内存操作都是在我们的主循环中发出的,根本不考虑flops计算(除非在寄存器组冲突一节中以一种微妙的方式描述)。

仅从这段代码和我们的主循环中时钟消耗指令的数量,我们就可以粗略估计内核所需的内存带宽上限。对于GM204,以下是数学公式:

  • 每个线程在每个循环中进行4个vec4 4字节的加载,或者每个循环中每个线程进行64个字节的加载。

  • 下面我们将计算每个循环消耗大约520个时钟。

  • 每个SM同时执行128个线程。
  • 有16个SM的时钟频率为1.216 GHz(升压)。
  • 每GB有.931 GiB:
  • 64 x 128 x 16 x 1.216 x.931/520=285 GiB/秒

GM204有224 GiB/sec可用。但这部分设备带宽将不需要,因为二级缓存将为其提供服务。但在设备带宽上有余量总是很好的。您的负载将不会以完全统一的方式执行,并且当它们聚集在一起时,您的净空越小,出现暂停的机会就越大。虽然只有运行接近理论吞吐量的代码才可能注意到这些暂停,但我们的代码恰好会这样做。

因此,您可以看到,64线程的实现对于GM204来说并不理想。然而,对于GM107来说,它是理想的,对于即将推出的具有384位内存总线的GM200来说也是如此。与256线程实现相比,这一实现使用了双倍的带宽,因此功耗更大。因此,当您有足够的数据来提供数据时,通常会首选更大的版本。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
tid = threadId.x;
bx = blockId.x;
by = blockId.y;

blk = tid >= 32 ? by : bx;
ldx = tid >= 32 ? ldb/4 : lda/4;
tex = tid >= 32 ? texB : texA;
tid2 = (tid >> 4) & 1;
tid15 = tid & 15;

track0 = blk*64/4 + tid15 + (ldx * tid2);
track2 = track0 + ldx*2;
track4 = track0 + ldx*4;
track6 = track0 + ldx*6;

end = track0 + (k-8)*ldx;

writeS = tid15*4*4 + tid2*64*4;
writeS += tid >= 32 ? 2048 : 0;

while (track0 < end)
{
tex.1d.v4.f32.s32 loadX0, [tex, track0];
tex.1d.v4.f32.s32 loadX2, [tex, track2];
tex.1d.v4.f32.s32 loadX4, [tex, track4];
tex.1d.v4.f32.s32 loadX6, [tex, track6];

st.shared.v4.f32 [writeS + 4*0*64], loadX0;
st.shared.v4.f32 [writeS + 4*2*64], loadX2;
st.shared.v4.f32 [writeS + 4*4*64], loadX4;
st.shared.v4.f32 [writeS + 4*6*64], loadX6;

// our loop needs one bar sync after share is loaded
bar.sync 0;

// Increment the track variables and swap shared buffers after the sync.
// We know at this point that these registers are not tied up with any in flight memory op.
track0 += ldx*8;
track2 += ldx*8;
track4 += ldx*8;
track6 += ldx*8;
writeS ^= 4*16*64;

// Additional loop code omitted for clarity.
}

通过四矢量纹理索引加载A和B:

image-20221208153216716

使用四矢量将A和B存储到共享地址空间中:

image-20221208153237832

从共享读取

现在,共享内存已加载,我们从使用一半线程切换到处理A和B中的每一个。我们需要开始组合这些值来计算构成C点的点积(处理每一行后,我们计算块中所有C值的部分积和和)。所以每个线程都将从A的共享行和B的共享行中读取。

除了FFMA之外,共享负载是这个实现的真正工作。我们对它们进行双重缓冲,因此延迟最小。但我们也希望确保我们没有bank冲突,因为我们需要尽快提供这些数据。如何在没有库冲突的情况下使用四元向量从共享加载?好吧,根据文档,只要所有访问都在32个字(128字节)以内,我们就可以了。在sgemm中,这是因为我们可以安排不同的线程同时从同一共享内存位置加载,并使用共享广播机制。然而,事实证明,Maxwell的文档是不完整的,尽管warp中的所有线程都在相同的128字节内,但仍有某些模式会导致库冲突。这样做可能是为了节省芯片。所以我们只需要找到一个可行的模式。

在128字节内,我们可以加载8个16字节的四元组。我们将使其成为从A和B的共享内存加载的模式。我们的共享内存块是4*64=256字节宽,因此为了加载另一半,我们将展开该负载到一个相隔32个单元的额外指令中。我们不必担心bank冲突。每个矩阵的这两个四字负载形成了我们想要的8寄存器块。通过在2D中组合这两种1D模式的负载,我们可以得到下面所示的共享内存映射。该模式还表示每个线程的64个寄存器在C子矩阵中的位置(绿色方块)。

现在我们有了基本信息,我们需要将其分成两个warp,然后映射这些warp中的线程id。直接方法是以简单的扫描模式向下或横向加载。这导致了神秘的bank冲突。但是,如果我们使用由thread号表示的锯齿形图案,它就会起作用。我还没有对所有的负载大小和模式进行详尽的搜索,以了解哪些是有效的,哪些是无效的,但如果Nvidia为Maxwell更新他们的文档来解释这一限制,那就太好了。

至于找出将threadId映射到我们想要的模式(下面的readAs和readBs)所需的逻辑,我有一个简单的技术。我只是打印出每个threadId的二进制表示形式和我希望它映射到的值。当您以这种方式可视化二进制时,很容易确定需要保留、丢弃或移动哪些位以使映射工作(前提是您选择了possble映射)。

我还应该提到我的插图是如何被解读的。黄色方块表示线程(或TLP),绿色方块表示第一个线程的ILP。你应该能够想象得到绿色正方形的图案,并将其移动到每个黄色正方形的顶部(保持绿色与黄色的相对位置)。这应该跨越整个内存空间,这是我们共享映射的目标:a中一条线的每个点都需要与B中一条线上的每个点配对。细黑线表示线程如何被分割成warp。下面的深绿色方块是为了说明我们稍后将要进行的warp同步洗牌中的一个步骤。

另一个值得注意的是,cublas在这里使用了更复杂的readAs/readBs映射,这实现了相同的效果,但需要花费更多的指令。这是我的代码对cublas的一个小改进。如果您提前知道共享加载限制,那么更复杂的模式甚至是有意义的。但似乎愚蠢而直接的方法最终找到了更简单的解决方案。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
readAs = ((tid >> 1) & 7) << 4;
readBs = (((tid & 0x30) >> 3) | (tid & 1)) << 4 + 2048;

while (track0 < end)
{
// Process each of our 8 lines from shared
for (j = 0; j < 8; j++)
{
// We fetch one line ahead while calculating the current line.
// Wrap the last line around to the first.
prefetch = (j + 1) % 8;

// Use even/odd rows to implement our double buffer.
if (j & 1)
{
ld.shared.v4.f32 j0Ax00, [readAs + 4*(prefetch*64 + 0)];
ld.shared.v4.f32 j0By00, [readBs + 4*(prefetch*64 + 0)];
ld.shared.v4.f32 j0Ax32, [readAs + 4*(prefetch*64 + 32)];
ld.shared.v4.f32 j0By32, [readBs + 4*(prefetch*64 + 32)];
}
else
{
ld.shared.v4.f32 j1Ax00, [readAs + 4*(prefetch*64 + 0)];
ld.shared.v4.f32 j1By00, [readBs + 4*(prefetch*64 + 0)];
ld.shared.v4.f32 j1Ax32, [readAs + 4*(prefetch*64 + 32)];
ld.shared.v4.f32 j1By32, [readBs + 4*(prefetch*64 + 32)];
}
}
// swap our shared memory buffers after reading out 8 lines
readAs ^= 4*16*64;
readBs ^= 4*16*64;

// Additional loop code omitted for clarity.
}

1D readAs(左侧)和readBs(顶部)在2D中组合以形成该线程块的C结果子矩阵:

image-20221208162657316

计算C:寄存器bank和重用

现在我们为线程填充了8个寄存器A和B,我们可以执行64个FFMA,这些FFMA构成了内核设计的核心工作。为了能够在全速和最低功率下计算这一点,我们需要考虑几个因素。主要是寄存器组和操作数重用。

Maxwell上有4个寄存器组,但与开普勒(也有4个组)不同的是,将bank分配给数字非常简单。Maxwell赋值只是寄存器数模4。在开普勒上,可以安排64条FFMA指令以消除所有存储体冲突。在麦克斯韦身上,这已经不可能了。然而,Maxwell提供了一些弥补这一点的方法,同时提供了显著减少寄存器组流量和总体芯片功耗的能力。这是操作数重用缓存。操作数重用缓存每个源操作数插槽有8个字节的数据。类似FFMA的指令有3个源操作数槽。每次发出指令时,都有一个标志可以用来指定是否要再次使用每个操作数。因此,在同一操作数槽中使用同一寄存器的下一条指令不必去寄存器组获取其值。通过此功能,您可以看到如何避免寄存器bank冲突。

因此,我们要采取的第一步是尽量减少操作数重用时必须隐藏的存储体冲突的数量。为此,我们需要显式选择要使用的寄存器。这是使用maxas作为汇编器的主要优点之一。ptxas在避免存储体冲突方面做得很好,但它并不完美,而且当涉及向量指令时,它做得特别糟糕(本例中的情况非常严重)。因此,我们将选择:

  • 0-63为C寄存器
  • 64-71和80-87是矩阵A的双缓冲块寄存器
  • 72-79和88-95是矩阵B的双缓冲块寄存器

如果我们按照下面所示的8乘8矩阵排列,我们可以用每个寄存器的存储体索引为其着色。对于C寄存器,我们选择与相应的块寄存器不同的颜色。通过这种方式,您可以看到我们可以消除与C寄存器和阻塞寄存器的所有存储体冲突。这使得不可避免的16个存储体与阻塞寄存器本身发生冲突。这些以黑色显示:

image-20221208162949337

如果没有重用缓存,这16个存储体冲突中的每一个都将导致计算中的1个时钟暂停。这将使我们的计算速度降低约20%(在520时钟循环中增加128个时钟)。但是,如果您使用—noruse标志组装sgemm代码,您将看到性能只会下降几百Gflop左右。如果你仔细阅读英伟达关于操作数收集器的专利,特别是如果你搜索涉及bank冲突的部分,这个谜团就迎刃而解了。它描述了一些缓解bank冲突的方法。很难说Maxwell是如何处理的,但这可能涉及到如何利用TLP来隐藏bank冲突延迟。因此,操作数收集器单独屏蔽存储体冲突的能力有限,但可能很快就会被淹没。通过使用持久的缓存而不仅仅是临时操作数缓冲区,硬件能够更有效地避免bank冲突暂停。它只需要汇编器使用重用标志来指导它,这样它就可以提前知道哪些寄存器值得缓存,以及在寄存器被写入时丢弃哪些寄存器。

优化设置重用标志的繁琐任务由maxas为您处理。留给我们的是以这样的方式对指令进行排序,以便最大限度地实现重用。最简单的排序是一个基本的双嵌套“for循环”,它将逐行遍历矩阵。这只有效地利用了重用缓存每个操作数8个字节中的4个字节,并且不会隐藏所有的存储体冲突。相反,如果您的扫描来回进行,则可以隐藏所有冲突并提高寄存器重用率(总体上为39%)。但最有效的模式是,在来回移动时,应用一个漩涡(47%的总重用率)。以下是按C寄存器号列出的FFMA指令顺序:

1
2
3
4
5
6
7
1, 0, 2, 3, 5, 4, 6, 7, 33, 32, 34, 35, 37, 36, 38, 39,

45, 44, 46, 47, 41, 40, 42, 43, 13, 12, 14, 15, 9, 8, 10, 11,

17, 16, 18, 19, 21, 20, 22, 23, 49, 48, 50, 51, 53, 52, 54, 55,

61, 60, 62, 63, 57, 56, 58, 59, 29, 28, 30, 31, 25, 24, 26, 27

您将注意到,所选漩涡尺寸在其中一个方向上的间距为2。此间距具有使C寄存器出现在交替存储体中的效果。我对这样做的原因的最佳猜测是极其微妙的。由于我们的内存指令与FFMA交错,并且这些指令没有其操作数寄存器的副本,因此它们可以在大约20个时钟周期内访问寄存器组。我们的C寄存器经常被弄脏,因此无法重复使用,所以我们总是从寄存器库中取出它们。因此,主要是这些寄存器会与我们的内存加载和存储指令发生延迟存储体冲突。可能不可能完全围绕这些银行冲突进行设计,但您可以减少它们的影响。通过在每条指令上交替使用C寄存器组,我们可以确保组冲突最多只能持续一个时钟。我运行了几个基准测试来检验这个假设,结果似乎是正确的。最后一个注意事项:使用所有四矢量加载和存储的另一个优点(除了效率更高之外)是减少了所需的内存指令数量,从而减少了延迟寄存器组冲突的机会。

鉴于我们知道内存操作数寄存器可能存在延迟存储体冲突,因此为这些操作数选择不同的存储体是值得尝试的。使用maxas,我们可以完全控制寄存器映射,您将在源代码中注意到,我们为track0-3、tex、readAs、readBs和writeS选择了非常特定的库。测试了这些库选择中的每一个,以最大化内核的flops性能。这是一个优化级别,我不确定cublas实现是否实现。我知道它犯的一个错误是,对于第一个FFMA,它选择了具有阻塞寄存器组冲突的C寄存器(3)。这防止了重用缓存隐藏该冲突的能力,因为之前没有将至少一个操作数加载到缓存中的指令。在GM204上,这个错误导致28 Gflops的性能损失。

使用FFMA的最后一个考虑是如何将它们与上述所有内存操作交织。要了解我在这里谈论的内容,请查看源代码的预处理版本。我们希望尽早使用双缓冲共享负载,以覆盖它们的延迟,因此我们将使用第一个FFMA开始双重发布它们。我们将用两条指令来分隔它们,因为内存单元似乎以一半的吞吐量最佳工作。我们将把纹理加载放在两组共享加载的中间。这样做是为了不让指令淹没内存单元。对于64线程实现,我们甚至将四个负载分成两组,并将它们放在不同的FFMA块中。我们将共享存储指令放置在尽可能低的位置,以使纹理加载有机会加载它们的操作数。我们不能将它们放在最后一个FFMA块中,因为这是我们开始为下一个循环迭代加载块寄存器的地方。

所有这些立场决定都经过了严格的测试,证明是最佳的。我应该注意,使用ptxas无法进行这些细粒度的排序和定位选择。事实上,ptxas倾向于优化我们的共享双缓冲加载方案。在选择寄存器组、优化操作数重用的指令排序和将内存指令精确放置在我们想要的位置之间,实现的性能可以达到理论性能的70%,而实现的性能则可以达到98%。

warp同步无序映射

在循环结束时,现在计算线程块的C子矩阵。所以现在是将结果存储回全局存储器的时候了。因为我们使用了来自共享的四矢量加载,所以我们的C值有点聚在一起,对于联合写入全局来说根本不是最佳的。我们可以直接将数据写出来,但我们可以做得更好。通过使用共享内存在同一warp的线程之间移动C寄存器,我们可以重新组织它以进行合并写入。您可能认为warp shuffle指令在这里最有效,但我们需要从不同的线程交换不同的寄存器,因此它不适合于此目的。

我们将把洗牌分成8块。上述共享内存映射上的深绿色线表示第一个块。另外7个将是C寄存器的后续垂直选择。因此,每个线程在共享内存中一次存储8个寄存器,然后立即再读取8个寄存器。但是,这些寄存器的排列方式使得我们的线程ID的重新映射可以以合并模式将数据存储到全局。因此,为了存储到共享,我们需要重新使用原始的共享内存映射,并在其中一个维度中将其从4个跨步单位折叠为一个。读取它的线程id映射将是32个值,步幅为1个单位。

1
2
3
4
5
6
7
8
9
10
11
12
13
tid31 = tid & 31;
tid32 = tid & 32;

// Remove the high bits if present from the last loop's xor.
// Also remove the 2048 added onto readBs.
readAs &= 0x7ff;
readBs &= 0x7ff;

// Write to shared using almost the same shared mapping as before but collapse readBs down to stride one.
writeCs = (readBs / 4) * 64 + readAs;

// Read out with a mapping amenable to coalesced global writes
readCs = ((tid32 << 3) + tid31) << 2;

image-20221208164843431

Warp Shuffling和联合存储到全局

有了上述映射,我们现在可以输出C值。注意,我们不需要bar.sync在写入共享内存之前进行同步,因为这已经在我们的最后一个循环中完成了。还要注意,由于我们不在warp之间共享数据,所以在共享内存洗牌中,我们不需要在写入和读取之间同步。只有在存储到writeC完成后,才会进行从readC的读取。注意,这里增加的共享内存延迟大部分可以用TLP隐藏,而为扭曲同步洗牌增加的净时钟只有十几个左右。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
ldc4 = ldc * 4;

cx = bx*64 + tid31;
cy = by*64 + (tid32 >> 1);

Cy00 = (cy*ldc + cx) * 4 + C;
Cy04 = Cy00 + ldc4 * 4;
Cy08 = Cy00 + ldc4 * 8;
Cy12 = Cy00 + ldc4 * 12;

foreach copy vertical line of 8 registers from C into .v4.f32 cs0 and cs4
{
// Feed the 8 registers through the warp shuffle before storing to global
st.shared.v4.f32 [writeCs + 4*00], cs0;
st.shared.v4.f32 [writeCs + 4*32], cs4;

ld.shared.f32 cs0, [readCs + 4*(0*64 + 00)];
ld.shared.f32 cs1, [readCs + 4*(0*64 + 32)];
ld.shared.f32 cs2, [readCs + 4*(1*64 + 00)];
ld.shared.f32 cs3, [readCs + 4*(1*64 + 32)];
ld.shared.f32 cs4, [readCs + 4*(2*64 + 00)];
ld.shared.f32 cs5, [readCs + 4*(2*64 + 32)];
ld.shared.f32 cs6, [readCs + 4*(3*64 + 00)];
ld.shared.f32 cs7, [readCs + 4*(3*64 + 32)];

st.global.f32 [Cy00 + 4*00], cs0;
st.global.f32 [Cy00 + 4*32], cs1;
st.global.f32 [Cy04 + 4*00], cs2;
st.global.f32 [Cy04 + 4*32], cs3;
st.global.f32 [Cy08 + 4*00], cs4;
st.global.f32 [Cy08 + 4*32], cs5;
st.global.f32 [Cy12 + 4*00], cs6;
st.global.f32 [Cy12 + 4*32], cs7;

Cy00 += ldc4;
Cy04 += ldc4;
Cy08 += ldc4;
Cy12 += ldc4;

// After processing forth set shift over to the stride 32 registers
if (4th iteration)
{
Cy00 += ldc4 * 28;
Cy04 += ldc4 * 28;
Cy08 += ldc4 * 28;
Cy12 += ldc4 * 28;
}
}

在下图中,蓝色方块表示如何从Cy00、Cy04、Cy08和Cy12矩阵C偏移的8个状态构造绿线。它们的垂直放置的不是步幅32的部分是到循环迭代的映射,而不是空间位置。

image-20221208165042717

呃……所以这是一个很高的水平。代码注释中甚至包含了较低级别的细节,特别是关于如何将内存访问与计算同步的细节。注释仅在256线程版本中找到。说到这里,下面我将展示四倍多的线程如何改变映射。

SGEMM - 256 Thread Implementation

Loading A and B

1
2
3
4
5
6
7
8
9
10
11
12
13
tid = threadId.x;
blk = tid >= 128 ? blockId.y : blockId.x;
ldx = tid >= 128 ? ldb/4 : lda/4;
tex = tid >= 128 ? texB : texA;
tid4 = (tid >> 5) & 3
tid31 = tid & 31
tid96 = tid & 96
tid128 = tid & 128

track0 = blk*128/4 + tid31 + (ldx * tid4)
track4 = track0 + ldx*4;

end = track0 + (k-8)*ldx;

Storing to Shared

1
2
writeS  = tid31*4*4 + tid4*128*4;
writeS += tid >= 128 ? 4096 : 0;

image-20221208165147533

Reading from Shared

1
2
readAs = ((tid128 >> 4) | ((tid >> 1) & 7)) << 4;
readBs = (((tid & 0x70) >> 3) | (tid & 1)) << 4 + 4096;

image-20221208165203670

Warp Synchronous Shuffle

1
2
3
4
5
6
readAs &= 0xfff;
readBs &= 0xfff;

writeCs = (readBs / 4) * 128 + readAs;

readCs = ((tid96 << 4) | tid31 | (tid128 >> 2)) << 2;

image-20221208165218383

Storing to Global

1
2
3
4
5
6
7
8
9
ldc4 = ldc * 4;

cx = bx*128 + tid31 | (tid128 >> 2);
cy = by*128 + (tid96 >> 1);

Cy00 = (cy*ldc + cx) * 4 + C;
Cy04 = Cy00 + ldc4*4;
Cy08 = Cy00 + ldc4*8;
Cy12 = Cy00 + ldc4*12;

image-20221208165238932

深入浅出GPU优化系列

前言

首先需要对reduce算法进行介绍。reduce算法本质上就是计算x=x0⊗x1⊗x2⊗x3……⊗xn−1⊗xn 。下面本文将详细说明如何在GPU中实现reduce算法并进行深入地优化。

并行算法设计

在GPU中,reduce采用了一种树形的计算方式。如下图所示。

image-20220910175222544

从上至下,将数据不断地累加,直到得出最后的结果,即25。但由于GPU没有针对global数据的同步操作,只能针对block的数据进行同步。所以,一般而言将reduce分为两个阶段,其示意图如下:

image-20220910175237258

我们仔细来看看这个事,假设给定一个长度为N的数组,需要计算该数组的所有元素之和。首先需要将数组分为m个小份。而后,在第一阶段中,开启m个block计算出m个小份的reduce值。最后,在第二阶段中,使用一个block将m个小份再次进行reduce,得到最终的结果。由于第二阶段本质上是可以调用第一个阶段的kernel,所以不做单独说明,本文只是探索第一阶段的优化技巧。

所以kernel接口为:

1
__global__ void reduce(T *input, T* output)

其中,input代表输入的数组,即一个长度为N的数组,output代表输出数组,即第一阶段的结果,即长度为M的数组。随后要开始激动人心的coding阶段,但在CUDA编程中,我们首先需要设置三个参数:

  1. BlockNum:即开启的block数量,即上面所说的M,代表需要将数组切分为几份。
  2. Thread_per_block:每个block中开启的线程数,一般而言,取128,256,512,1024这几个参数会比较多。
  3. Num_per_block:每个block需要进行reduce操作的长度。

其中,BlockNum* Num_per_block=N

reduce优化

reduce baseline算法介绍

Baseline算法比较简单,分为三个步骤。第一个步骤是将数据load至shared memory中,第二个步骤是在shared memory中对数据进行reduce操作,第三个步骤是将最后的结果写回global memory中。代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
__global__ void reduce0(float *d_in,float *d_out){
__shared__ float sdata[THREAD_PER_BLOCK];

//each thread loads one element from global memory to shared mem
unsigned int i=blockIdx.x*blockDim.x+threadIdx.x;
unsigned int tid=threadIdx.x;
sdata[tid]=d_in[i];
__syncthreads();

// do reduction in shared mem
for(unsigned int s=1; s<blockDim.x; s*=2){
if(tid%(2*s) == 0){
sdata[tid]+=sdata[tid+s];
}
__syncthreads();
}

// write result for this block to global mem
if(tid==0)d_out[blockIdx.x]=sdata[tid];
}

在进行优化之前,我们需要再来好好地梳理一下这个baseline代码。优化的本质是通过软件榨干硬件资源,所以必须清楚地了解代码在硬件上的执行过程才能更好地进行优化。

第一个步骤中,我们让Num_per_block与Thread_per_block一致,每个block设定为256个线程,一个block负责256个数据的reduce工作。假设需要处理32M的数据,则有128K个block。tid代表线程号,i代表在原始数组中的索引号。第tid号线程将第i号的数据从global中取出,放到shared memory的第tid元素中。比如在第0号block中,0号线程将0号元素取出,放到shared memory的第0号位置。示意图见:

image-20220910175717727

从硬件角度来分析一下代码。为了执行代码,GPU需要分配两种资源,一个是存储资源,一个是计算资源存储资源包括在global memory中分配的一块32M× sizeof(float)的空间以及在shared memory中分配的256× sizeof(float)的空间。需要注意的是,shared memory存在bank冲突的问题,因而需要格外小心计算资源其实是根据thread数量来确定的,一个block中分配256个thread线程,32个线程为一组,绑定在一个SIMD单元。所以256个线程可以简单地理解为分配了8组SIMD单元。

(但实际的硬件资源分配不是这样,因为一个SM的计算资源有限,不可能真的给每一个block都分配这么多的SIMD单元。)总而言之,在第一个阶段,就是tid号线程将i号数据从global memory中取出,再放进shared memory中,严谨一点的话,中间是走一遍寄存器再到shared memory中的。

到了第二个阶段,block中需要计算的256个元素已经全部被存储在了shared memory中,此时需要对其进行reduce操作。这个过程需要进行多轮迭代,在第一轮迭代中,如果tid%2 ==0, 则第tid号线程将shared memory中第tid号位置的值和第tid+1号的值进行相加,而后放在第tid号位置。

在第二轮迭代中,如果tid%4==0,则第tid号线程将shared memory中第tid号位置的值和第tid+2号的值进行相加,而后放在第tid号位置。不断迭代,则所有元素都将被累加到第0号位置。其示意图如下。其中,红色的线程代表符合if条件的线程,只有它们有任务,需要干活。

image-20220910175947839

第三个阶段中,block负责的256个元素之和都放置在shared memory的0号位置,此时,只需要将0号位置的元素写回即可。

优化技巧1:解决warp divergence

现有问题

目前reduce0存在的最大问题就是warp divergent的问题。对于一个block而言,它所有的thread都是执行同一条指令。如果存在if-else这样的分支情况的话,thread会执行所有的分支。只是不满足条件的分支,所产生的结果不会记录下来。可以在上图中看到,在每一轮迭代中都会产生两个分支,分别是红色和橙色的分支。这严重影响了代码执行的效率。

解决方式

解决的方式也比较明了,就是尽可能地让所有线程走到同一个分支里面。代码示意如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
__global__ void reduce1(float *d_in,float *d_out){
__shared__ float sdata[THREAD_PER_BLOCK];

//each thread loads one element from global memory to shared mem
unsigned int i=blockIdx.x*blockDim.x+threadIdx.x;
unsigned int tid=threadIdx.x;
sdata[tid]=d_in[i];
__syncthreads();

// do reduction in shared mem
for(unsigned int s=1; s<blockDim.x; s*=2){
int index = 2*s*tid;
if(index < blockDim.x){
sdata[index]+=sdata[index+s];
}
__syncthreads();
}

// write result for this block to global mem
if(tid==0)d_out[blockIdx.x]=sdata[tid];
}

虽然代码依旧存在着if语句,但是却与reduce0代码有所不同。我们继续假定block中存在256个thread,即拥有256/32=8个warp。当进行第1次迭代时,0-3号warp的index<blockDim.x, 4-7号warp的index>=blockDim.x。对于每个warp而言,都只是进入到一个分支内,所以并不会存在warp divergence的情况。

当进行第2次迭代时,0、1号两个warp进入计算分支。当进行第3次迭代时,只有0号warp进入计算分支。当进行第4次迭代时,只有0号warp的前16个线程进入分支。此时开始产生warp divergence。通过这种方式,我们消除了前3次迭代的warp divergence。

优化技巧2:解决bank冲突

现有问题

reduce1的最大问题是bank冲突。我们把目光聚焦在这个for循环中。并且只聚焦在0号warp。在第一次迭代中,0号线程需要去load shared memory的0号地址以及1号地址的数,然后写回到0号地址。而此时,这个warp中的16号线程,需要去load shared memory中的32号地址和33号地址。可以发现,0号地址跟32号地址产生了2路的bank冲突

第2次迭代中,0号线程需要去load shared memory中的0号地址和2号地址。这个warp中的8号线程需要load shared memory中的32号地址以及34号地址,16号线程需要load shared memory中的64号地址和68号地址,24号线程需要load shared memory中的96号地址和100号地址。

又因为0、32、64、96号地址对应着同一个bank,所以此时产生了4路的bank冲突。现在,可以继续算下去,8路bank冲突,16路bank冲突。由于bank冲突,所以reduce1性能受限。下图说明了在load第一个数据时所产生的bank冲突。

image-20220910180155032

解决方式

在reduce中,解决bank冲突的方式就是把for循环逆着来。原来stride从0到256,现在stride从128到0。其伪代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
__global__ void reduce2(float *d_in,float *d_out){
__shared__ float sdata[THREAD_PER_BLOCK];

//each thread loads one element from global memory to shared mem
unsigned int i=blockIdx.x*blockDim.x+threadIdx.x;
unsigned int tid=threadIdx.x;
sdata[tid]=d_in[i];
__syncthreads();

// do reduction in shared mem
for(unsigned int s=blockDim.x/2; s>0; s>>=1){
if(tid < s){
sdata[tid]+=sdata[tid+s];
}
__syncthreads();
}

// write result for this block to global mem
if(tid==0)d_out[blockIdx.x]=sdata[tid];
}

那为什么通过这么一个小小的改变就能消除bank冲突呢,我们继续进行分析。

把目光继续看到这个for循环中,并且只分析0号warp。0号线程需要load shared memory的0号元素以及128号元素。1号线程需要load shared memory中的1号元素和129号元素。这一轮迭代中,在读取第一个数时,warp中的32个线程刚好load 一行shared memory数据。再分析第2轮迭代,0号线程load 0号元素和64号元素,1号线程load 1号元素和65号元素。

咦,也是这样,每次load shared memory的一行。再来分析第3轮迭代,0号线程load 0号元素和32号元素,接下来不写了,总之,一个warp load shared memory的一行。没有bank冲突。到了4轮迭代,0号线程load 0号元素和16号元素。那16号线程呢,16号线程啥也不干,因为s=16,16-31号线程啥也不干,跳过去了。示意图如下:

image-20220910213658483

优化技巧3:解决idle线程

现有问题

reduce2最大的问题就是线程的浪费。可以看到我们启动了256个线程,但是在第1轮迭代时只有128个线程在干活,第2轮迭代只有64个线程在干活,每次干活的线程都会减少一半。第一轮迭代示意图如下,只有前128个线程在load数据。后128个线程啥也不干,光看着。

image-20220910213713143

解决方式

对于HPC从业者而言,我们希望变成GPU的资本家,去尽可能地压榨GPU。但是呢,在这里,每一次迭代有一半的线程不干活。而且,128-255号线程最过分,它娘的,没有任何贡献,啥也不干。想来想去,能不能让它们干点活呢。想来想去,那这样吧,让它好歹做一次加法。除了去global memory中取数外,再做一次加法。当然为了实现这个,block数就得改一改了。Block数量减少,Num_per_block增加一倍。也就是说原来一个block只需要管256个数就行,现在得管512个数了。代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
__global__ void reduce3(float *d_in,float *d_out){
__shared__ float sdata[THREAD_PER_BLOCK];

//each thread loads one element from global memory to shared mem
unsigned int i=blockIdx.x*(blockDim.x*2)+threadIdx.x;
unsigned int tid=threadIdx.x;
sdata[tid]=d_in[i] + d_in[i+blockDim.x];
__syncthreads();

// do reduction in shared mem
for(unsigned int s=blockDim.x/2; s>0; s>>=1){
if(tid < s){
sdata[tid]+=sdata[tid+s];
}
__syncthreads();
}

// write result for this block to global mem
if(tid==0)d_out[blockIdx.x]=sdata[tid];
}

通过这种方式,将一些idle的线程给利用起来了。

优化技巧4:展开最后一维减少同步

现有问题

对于reduce3来说,性能已经算是比较好了。但是依旧没有达到我们想要的效果。我们再来仔细地看看还有什么可以改进的地方。我们发现,当进行到最后几轮迭代时,此时的block中只有warp0在干活时,线程还在进行同步操作。这一条语句造成了极大的浪费。

解决方式

由于一个warp中的32个线程其实是在一个SIMD单元上,这32个线程每次都是执行同一条指令,这天然地保持了同步状态,因而当s=32时,即只有一个SIMD单元在工作时,完全可以将__syncthreads()这条同步代码去掉。所以我们将最后一维进行展开以减少同步。伪代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
__device__ void warpReduce(volatile float* cache,int tid){
cache[tid]+=cache[tid+32];
cache[tid]+=cache[tid+16];
cache[tid]+=cache[tid+8];
cache[tid]+=cache[tid+4];
cache[tid]+=cache[tid+2];
cache[tid]+=cache[tid+1];
}

__global__ void reduce4(float *d_in,float *d_out){
__shared__ float sdata[THREAD_PER_BLOCK];

//each thread loads one element from global memory to shared mem
unsigned int i=blockIdx.x*(blockDim.x*2)+threadIdx.x;
unsigned int tid=threadIdx.x;
sdata[tid]=d_in[i] + d_in[i+blockDim.x];
__syncthreads();

// do reduction in shared mem
for(unsigned int s=blockDim.x/2; s>32; s>>=1){
if(tid < s){
sdata[tid]+=sdata[tid+s];
}
__syncthreads();
}

// write result for this block to global mem
if(tid<32)warpReduce(sdata,tid);
if(tid==0)d_out[blockIdx.x]=sdata[tid];
}

可以通过下面的示意图更好地了解,warp0会被绑定在一个SIMD单元上,上面有thread0-thread31。warp1会被绑在另外一个SIMD单元上,上面有thread32-thread63。由于在一个SIMD单元上,然后不管啥时候thread0和thread7肯定是同一状态,不需要同步。而thread0和thread34就不能保证同步,必须用__syncthreads()来保证同步操作。

优化技巧5:完全展开减少计算

现有问题

其实到了这一步,reduce的效率已经足够高了。再进一步优化其实已经非常困难了。为了探索极致的性能表现,Mharris接下来给出的办法是对for循环进行完全展开。我觉得这里主要是减少for循环的开销。Mharris的实验表明这种方式有着1.41x的加速比。但是用的机器是G80,十几年前的卡。性能数据也比较老了,至于能不能真的有这么好的加速比,我们拭目以待。

解决方法

我们将整个for循环进行展开,非常暴力,代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
template <unsigned int blockSize>
__device__ void warpReduce(volatile float* cache,int tid){
if(blockSize >= 64)cache[tid]+=cache[tid+32];
if(blockSize >= 32)cache[tid]+=cache[tid+16];
if(blockSize >= 16)cache[tid]+=cache[tid+8];
if(blockSize >= 8)cache[tid]+=cache[tid+4];
if(blockSize >= 4)cache[tid]+=cache[tid+2];
if(blockSize >= 2)cache[tid]+=cache[tid+1];
}

template <unsigned int blockSize>
__global__ void reduce5(float *d_in,float *d_out){
__shared__ float sdata[THREAD_PER_BLOCK];

//each thread loads one element from global memory to shared mem
unsigned int i=blockIdx.x*(blockDim.x*2)+threadIdx.x;
unsigned int tid=threadIdx.x;
sdata[tid]=d_in[i] + d_in[i+blockDim.x];
__syncthreads();

// do reduction in shared mem
if(blockSize>=512){
if(tid<256){
sdata[tid]+=sdata[tid+256];
}
__syncthreads();
}
if(blockSize>=256){
if(tid<128){
sdata[tid]+=sdata[tid+128];
}
__syncthreads();
}
if(blockSize>=128){
if(tid<64){
sdata[tid]+=sdata[tid+64];
}
__syncthreads();
}

// write result for this block to global mem
if(tid<32)warpReduce<blockSize>(sdata,tid);
if(tid==0)d_out[blockIdx.x]=sdata[tid];
}

优化技巧6:合理设置block数量

现有问题

当走到这一步的时候,能调的东西已经基本上调完了。我们再把眼光放在block和thread的设置上。之前默认了Num_per_block=Thread_per_block。也就是说,一个block开启256个线程时,这个block负责256个元素的reduce操作。那可不可以让一个block多管点数。这样的话,开启的block数量少一些。以此对block设置进行调整,获得最优block取值,这样或许能够带来一些性能收益?

解决方式

这样需要再思考一下block的取值。对于GPU而言,block的取值到底是多更好,还是少更好。如此对CUDA编程熟悉的同学,肯定会毫不犹豫地说:“那肯定是多更好啦。Block数量多,block可以进行快速地切换,去掩盖访存的延时。”这个问题按下不表,我们看看Mharris是怎么说的。

如果一个线程被分配更多的work时,可能会更好地覆盖延时。这一点比较好理解。如果线程有更多的work时,对于编译器而言,就可能有更多的机会对相关指令进行重排,从而去覆盖访存时的巨大延时。虽然这句话并没有很好地说明在某种程度上而言,block少一些会更好。但是,有一点不可否认,block需要进行合理地设置。唠唠叨叨说了很多,现在把代码贴一下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
template <unsigned int blockSize>
__global__ void reduce6(float *d_in,float *d_out){
__shared__ float sdata[THREAD_PER_BLOCK];

//each thread loads one element from global memory to shared mem
unsigned int i=blockIdx.x*(blockDim.x*2)+threadIdx.x;
unsigned int tid=threadIdx.x;
unsigned int gridSize = blockSize * 2 * gridDim.x;
sdata[tid] = 0;

while(i<n){
sdata[tid] +=d_in[i]+d_in[i+blockSize];
i+=gridSize;
}
__syncthreads();

// do reduction in shared mem
if(blockSize>=512){
if(tid<256){
sdata[tid]+=sdata[tid+256];
}
__syncthreads();
}
if(blockSize>=256){
if(tid<128){
sdata[tid]+=sdata[tid+128];
}
__syncthreads();
}
if(blockSize>=128){
if(tid<64){
sdata[tid]+=sdata[tid+64];
}
__syncthreads();
}

// write result for this block to global mem
if(tid<32)warpReduce<blockSize>(sdata,tid);
if(tid==0)d_out[blockIdx.x]=sdata[tid];
}

优化技巧7:使用shuffle指令

现有问题

其实,对于Mharris的讲义。reduce优化就到此结束了。但是NV后来出了Shuffle指令,对于reduce优化有着非常好的效果。目前绝大多数访存类算子,像是softmax,batch_norm,reduce等,都是用Shuffle实现。所以,在这里谈一下这么把shuffle指令用在reduce优化上。

Shuffle指令是一组针对warp的指令。Shuffle指令最重要的特性就是warp内的寄存器可以相互访问。在没有shuffle指令的时候,各个线程在进行通信时只能通过shared memory来访问彼此的寄存器。而采用了shuffle指令之后,warp内的线程可以直接对其他线程的寄存器进行访存。通过这种方式可以减少访存的延时。除此之外,带来的最大好处就是可编程性提高了,在某些场景下,就不用shared memory了。毕竟,开发者要自己去控制 shared memory还是挺麻烦的一个事。

伪代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
template <unsigned int blockSize>
__device__ __forceinline__ float warpReduceSum(float sum){
if(blockSize >= 32)sum += __shfl_down_sync(0xffffffff,sum,16);
if(blockSize >= 16)sum += __shfl_down_sync(0xffffffff,sum,8);
if(blockSize >= 8)sum += __shfl_down_sync(0xffffffff,sum,4);
if(blockSize >= 4)sum += __shfl_down_sync(0xffffffff,sum,2);
if(blockSize >= 2)sum += __shfl_down_sync(0xffffffff,sum,1);
return sum;
}

template <unsigned int blockSize>
__global__ void reduce7(float *d_in,float *d_out, unsigned int n){
float sum = 0;

//each thread loads one element from global memory to shared mem
unsigned int i=blockIdx.x*(blockDim.x*2)+threadIdx.x;
unsigned int tid=threadIdx.x;
unsigned int gridSize = blockSize * 2 * gridDim.x;

while(i<n){
sdata[tid] +=d_in[i]+d_in[i+blockSize];
i+=gridSize;
}

// shared mem for partial sums(one per warp in the block
static __shared__ float warpLevelSums[WARP_SIZE];
const int laneId = threadIdx.x % WARP_SIZE;
const int warpId = threadIdx.x / WARP_SIZE;

sum = warpReduceSum<blockSize>(sum);

if(laneId == 0)warpLevelSums[warpId]=sum;
__syncthreads();

sum = (threadIdx.x < blockDim.x / WARP_SIZE)? warpLevelSums[laneId]:0;
// Final reduce using first warp
if(warpId == 0)sum = warpReduceSum<blockSize/WARP_SIZE>(sum);
// write result for this block to global mem
if(tid==0)d_out[blockIdx.x]=sum;
}

GEMM优化

前言

在高性能领域,对于矩阵乘(GEMM)的优化是一个非常重要的课题。GEMM可以非常广泛地应用于航空航天、流体力学等科学计算领域,这也是之前HPC的主要应用场景。后来深度学习开展地如火如荼,由于对高算力的需要,也成为HPC的主要应用场景之一。这些年涌现了一系列的深度学习模型。模型里面最耗时的东西,包括卷积、全连接层、attention,都可以转换成GEMM操作。所以说,GEMM优化的重要性,怎么突出都不过分。

本篇文章主要介绍GEMM中的数据分块和如何在多级存储进行数据搬运。这也是HPC优化的核心思想,怎么样让数据放在更近的存储上来掩盖计算的延时,从而减少存储墙的影响。文章分为四个方面进行叙述,首先介绍在global memory层面如何进行分块以及数据搬运,随后介绍在shared memory层面如何进行分块以及数据搬运,而后介绍在register层面如何进行分块以及避免bank冲突,最后介绍如何进行prefetch以更好地掩盖访存时延。

从global memory到shared memory

假设有矩阵A,B,需要计算矩阵A和B的乘,即矩阵C。A、B、C三个矩阵的维度分别为,,m∗k,k∗n,m∗n ,且三个矩阵中的数据都是单精度浮点数。对于C中每一个元素,C[i][j],可以看作是A的一行和B的一列进行一次归约操作。采用最naive的GEMM算法,在GPU中,一共开启m∗n 个线程,每个线程需要读取矩阵A的一行与矩阵B的一列,而后将计算结果写回至矩阵C中。因而,完成计算一共需要从global memory中进行2mnk 次读操作和m*n次写操作。大量的访存操作使得GEMM效率难以提高,因而考虑global memory中进行分块,并将矩阵块放置到shared memory中。其示意图如下:

image-20220910214848982

对global memory进行分块的GEMM算法示意图见上图右侧。首先将A、B、C三个矩阵划分为多个维度为,,bm∗bk,bk∗bn,bm∗bn 的小矩阵块。三个矩阵形成M∗K,K∗N,M∗N 的小矩阵网格。其中M=m/bm,N=n/bn,K=k/bk 。随后在GPU中开启M∗N 个block,每个block负责C中一个维度为bm∗bn 的小矩阵块的计算。计算中一共有K次迭代,每一次迭代都需要读取A中一个维度为bm∗bk 的小矩阵块和B中一个维度为bk∗bn 的小矩阵块,并将其放置在shared memory中。因而,完成C中所有元素的计算一共需要从global memory中读取M∗N∗K∗(bm∗bk+bk∗bn) ,即m∗n∗k(1/bm+1/bn) 个单精度浮点数。相比于naive的GEMM算法,访存量减少为原来的1/2∗(1/bm+1/bn) 。通过global memory中分块算法极大地减少了对global memory的访存量。并且,相比于naive算法,对global进行分块可以更充分地利用数据局部性。在naive算法中,每一个线程都需要直接从global memory中取数,其时延非常长,计算性能非常差。而进行分块后,将维度为bm∗bk,bk∗bn 的小矩阵块先存储到shared memory之中。而后计算单元进行计算时可以直接从shared memory中取数,大大减少了访存所需要的时延。

从shared memory到register

随后,我们进一步考虑从shared memory到register的过程。在这里,只分析一个block中的计算。当进行K轮迭代中某一轮迭代时,GPU将维度为bm∗bk,bk∗bn 的小矩阵块存储到shared memory中,而后各个线程将shared memory中的数据存入register中进行计算。

image-20220910215005453

不对shared memory分块时,一个block中含有bm∗bn 个线程,每一个线程负责C中一个元素的计算。则一个block一共需要对shared memory进行2∗bm∗bn∗bk 次读操作。而后考虑对shared memory进行分块,对bm∗bn 的小矩阵进行再一次划分,将其划分为多个维度为rm∗rn 的子矩阵。则一个block需要负责X∗Y 个子矩阵。其中,X=bmrm ,Y=bnrn 。随后,在一个block中开启X∗Y 个线程,每个线程负责一个维度为rm∗rn 的子矩阵的计算。在计算中,一个block一共需要从shared memory读取X∗Y∗(rm+rn)∗bk ,即bm∗bn∗bk∗(1/rm+1/rn) 个单精度浮点数。相比于未分块的算法,对于shared memory中的访存量减少为原来的1/2∗(1/rm+1/rn) 。并且,由于将数据放入register中,可以直接对数据进行运算,减少了从shared memory中取数的时延。

register分块

在这里,我们考虑最后一层,即register中的计算,并且只分析一个thread。在完成以上的过程后,对于一个线程而言,它现在拥有:rm 个A矩阵的寄存器值,rn 个B矩阵的寄存器值,以及rm∗rn 个C矩阵的寄存器值。通过这些寄存器的值,需要计算rm∗rn 个数。这需要rm∗rn 条FFMA指令。

这个时候会涉及到寄存器的bank conflict。在NV的GPU中,每个SM不仅会产生shared memroy之间的bank 冲突,也会产生寄存器之间的bank冲突。这一点对于计算密集型的算子十分重要。像shared memory一样,寄存器的Register File也会被分为几个bank,如果一条指令的的源寄存器有2个以上来自同一bank,就会产生冲突。指令会重发射,浪费一个cycle。

我们假设对这个thread来说,rm=4,rn=4 。并且计算C的寄存器以一种非常naive的情况分配,如下图左侧所示。则需要产生16条FFMA指令,列举如下:

1
2
3
FFMA R0, R16, R20, R0
FFMA R1, R16, R21, R1
……

image-20220910220109982

可以从中看出,这会产生大量的register bank冲突,所以需要对参与计算的寄存器重新进行分配和排布,如上图右侧所示。在有些地方,这种方式也可以叫做register分块。

数据的prefetch

最后,我们来讲讲如何通过对数据进行prefetch来减少访存的latency。我们再来回顾GEMM的过程,并且仔细地看看这个访存的latency到底是怎么导致的。对于一个block而言,需要计算一个bm∗bn 的矩阵块,这个时候需要进行K次迭代,每次迭代都需要先将来自A和B的两个小块送到shared memory中再进行计算。而从global中访存实际上是非常慢的,所以导致了latency。虽然GPU中可以通过block的切换来掩盖这种latency,但是由于分配的shared memory比较多,活跃的block并不太多,这种延时很难被掩盖。对于一个thread,需要计算一个rm∗rn 的小矩阵,但是必须先将数据从shared memory传到寄存器上,才能开始进行计算。所以导致了每进行一次迭代,计算单元就需要停下来等待,计算单元不能被喂饱。

为此,需要进行数据的Prefetch来尽可能地掩盖这种latency。思想也比较简单,需要多开一个buffer,进行读写分离。示意图如下。当block进行第2轮迭代时,需要对A2和B2进行计算,在计算单元进行计算的同时,我们将A3和B3提前放置到shared memory。而后,在进行第3轮迭代时,就可以直接对shared memory中的A3和B3进行计算,而不需要等待从global memory搬运到shared memory的时间。寄存器上的Prefetch也是同理。

image-20220910220136539

GEMM算法概述

这个章节里主要来说一下GEMM的一个计算流程,其实这一点已经在GEMM优化(一)中提及。但上一篇文章主要说得是原理,关于具体计算逻辑,还是不太直观,所以我们在这里再提一下。然后这个具体的计算逻辑分为两个阶段介绍,分别是不采用数据预取和采用数据预取,这主要是考虑到直接说数据预取,有读者可能会看得云里雾里,比较难受,所以先把不采用数据预取这个内容说明白,然后再来讲这个数据预取。

不采用数据预取

首先,我们先明确一下GEMM中的具体参数。取bm=128,bn=128,bk=8,rm=8,rn=8。当这几个参数选定之后先来直观地感受一下这几个参数意义,假定给了三个矩阵,A,B,C,其维度都是2048×2048。要求解C=A×B。那么我们需要开启(2048/128)×(2048/128)=256个block,每个block里面有(128/8)×(128/8)=256个线程,每个线程需要负责计算C矩阵中8×8=64个元素的结果,每个block负责256×64=16384个元素的结果。

明确了上面的参数之后,我们来仔细地观察其中一个block的计算逻辑。对于这个block而言,它需要进行2048/8=256次迭代,我们先把这个迭代称为大迭代,每一次大迭代都需要把A里面128×8=1024个元素和B里面8×128=1024个元素先放到shared memory中。然后这个block中的256个线程把结果计算出来。计算完之后,再进入下一次大迭代。不断重复该过程,直至这个block负责的16384个元素的结果被求解出。大迭代示意图如下:

image-20220910220741253

随后再具体看看每一个大迭代中,block中的线程的计算逻辑。在进行一个大迭代时,shared memory中有128×8=1024个A矩阵元素和8×128=1024个B矩阵元素。随后,每个线程需要进行8次迭代,我们把这个迭代成为小迭代。bk=8,所以有8次小迭代。每一次小迭代中,每个线程需要从shared memory中拿到A矩阵的一小列和B矩阵的一小行,即8个A的元素和8个B的元素。线程将这8+8=16个元素放置在寄存器中。每个线程需要负责8×8=64个元素的计算,一共会产生64条FFMA指令。小迭代示意图如下:

image-20220910220755159

以上就是不采用数据预取的GEMM算法计算逻辑。总的来说,对于一个block而言,有256个大迭代,每个大迭代中又有8个小迭代。这是后续内容的基础,如果还是不太清楚的话,可以再仔细看看,把这个过程完全搞清楚后,我们再继续接下来的内容,即采用数据预取后的GEMM算法计算逻辑。

采用数据预取

采用数据预取的GEMM计算流程稍有差异。这个差异主要是体现在两个方面,第一个是开启的shared memory和寄存器数量,第二个是需要提前将一些数据放置到shared memory和寄存器中。下面来仔细说说这个流程。

为了实现数据预取,需要开启两倍的shared memory和寄存器。当然也可以将原来shared memory切分成两块,也就是将bm×bk和bk×bn的矩阵一分为二。以A中的小矩阵而言,变成了两个bm×bk/2。然后大迭代次数由原来的256变成了512。很多地方把这个技术叫做双缓冲,我感觉跟预取是同一个事情。无非是针对参数bk的大小换不同说法。所以在这里统一叫做数据预取。废话说得有点多。总之,我们还是开启两倍的shared memory和寄存器数据。在一个block中,原来在shared memory中需要存储的数据是bm×bk+bk×bn。现在变成了bm×bk×2+bk×bn×2。在一个thread中,为了存储A和B的数据,原来需要使用rm+rn个寄存器,现在需要使用2×(rm+rn)个寄存器。为了后续方便介绍,我们用read SMwrite SM代表用来读写的两块共享内存,并用read REGwrite REG来表示用来读写的两块寄存器。

把共享内存和寄存器的事情说明白之后,我们来看看具体的计算逻辑。在执行256次大迭代之前,我们需要提前将第0次大迭代的数据存到write SM中,并且将第0次小迭代的数据存到write REG中。在完成这一个预取过程之后,我们再来仔细地看看第0个大迭代。需要注意的是,上一轮大迭代的write SM就是这一轮迭代的read SM。上一轮小迭代的write REG就是这一轮迭代的read REG。所以在进行第0个大迭代时,上面write SM就变成了read SM。然后我们首先需要将下一轮大迭代的数据存到write SM中。由于从global memory中取数的时钟周期非常多。所以在等待数据取回的同时,对read SM中的数据进行计算。也就是我们在等待的同时,需要开启8次小迭代来进行计算。而小迭代中也存在着读写分离,在对read REG进行计算之前,需要先执行write REG的操作,通过这种方式来掩盖访存的latency。所以整体的计算逻辑如下:

1
2
3
4
5
6
for k in 256 big_loop:
prefetch next loop data to write_SM
// compute in read_SM
for iter in 8 small_loop:
prefecth next loop data to write_REG
compute in read_REG

image-20220910220841079

GEMM代码解析

在上一节中已经将GEMM算法的流程再次回顾了一遍,接下来进入到代码解析环节。这里主要是解析采用了数据预取的GEMM。由于将数据从global memroy中搬运到shared memory中还经过了寄存器,所以对prefetch过程进行了细化,这个跟前面的伪代码稍有差异。

参数说明

首先需要说明的是模板参数,这也是后续对GEMM性能进行调参的最主要参数,往往不同的参数选择对最终的GEMM性能影响极大。后面的实验会展示在不同的参数下的性能比较。前三个参数,BLOCK_SIZE_M、BLOCK_SIZE_K、BLOCK_SIZE_N分别代表上文中的bm、bk、bn。中间两个参数,THREAD_SIZE_Y、THREAD_SIZE_X代表上文中的rm、rn。最后的参数ENABLE_DOUBLE_BUFFER代表是否采用双缓冲,即是否采用数据预取,在这里,我们只讨论采用数据预取,即开启双缓冲的情况。

1
2
3
4
5
6
7
8
template <
const int BLOCK_SIZE_M, // height of block of C that each block calculate
const int BLOCK_SIZE_K, // width of block of A that each block load into shared memory
const int BLOCK_SIZE_N, // width of block of C that each block calculate
const int THREAD_SIZE_Y, // height of block of C that each thread calculate
const int THREAD_SIZE_X, // width of block of C that each thread calculate
const bool ENABLE_DOUBLE_BUFFER // whether enable double buffering or not
>

接下来是线程类的参数。整个计算流程需要开启256个block,这256个block按照二维形态排布。而一个block中开启了256个线程,这256个线程按照二维形态进行排布。bx代表横向的block坐标,by代表竖向的block坐标。而tx代表横向的线程坐标,ty代表竖向的线程坐标。这是CUDA的基础内容,看不明白的同学可以找一些博客多理解一下,务必搞清楚。THREAD_X_PER_BLOCK代表在一个block中有多少个横向的线程,在这里等于16。THREAD_Y_PER_BLOCK代表在一个block中有多少个竖向的线程,在这里等于16。THREAD_NUM_PER_BLOCK代表在一个block中有多少个线程,在这里等于256。tid则代表当前线程在这256个线程中的id号。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
// Block index
int bx = blockIdx.x;
int by = blockIdx.y;

// Thread index
int tx = threadIdx.x;
int ty = threadIdx.y;

// the threads number in Block of X,Y
const int THREAD_X_PER_BLOCK = BLOCK_SIZE_N / THREAD_SIZE_X;
const int THREAD_Y_PER_BLOCK = BLOCK_SIZE_M / THREAD_SIZE_Y;
const int THREAD_NUM_PER_BLOCK = THREAD_X_PER_BLOCK * THREAD_Y_PER_BLOCK;

// thread id in cur Block
const int tid = ty * THREAD_X_PER_BLOCK + tx;

随后说明开启的shared memory和register数量。As代表为了存储A矩阵中的数据所需要开启的shared memory。在一轮迭代中需要使用bm×bk的数据,为了加快后续的访存,所以需要进行一次转置。并且为了预取,开了两倍的大小,一半用来读数据,一半用来写数据。所以一共需要2×BLOCK_SIZE_K×BLOCK_SIZE_M的空间。而Bs同理,但是载入数据时并不需要转置。accum用来临时存储C的计算结果。frag_a用来加载As中的rm个数据,为了预取也开启了双倍的空间。frag_b同理。ldg_num_a稍微有点费解,需要解释一下。为了将global memory的数据块搬运到shared memory中,需要先经过寄存器。也就是说,这个数据搬运过程其实是global memory->register->shared memory。所以为了临时存储A中的数据,需要开启一定量的寄存器。在一次大迭代中,我们总共需要搬运BLOCK_SIZE_M × BLOCK_SIZE_K个float数据,然后一个block中有THREAD_NUM_PER_BLOCK个线程,采用float4进行取数,即一个线程一次取4个数。则一共需要BLOCK_SIZE_M × BLOCK_SIZE_K/(THREAD_NUM_PER_BLOCK×4)次搬运就能把所有的数搬运到寄存器上。这个搬运次数用ldg_num_a表示。为了存储BLOCK_SIZE_M BLOCK_SIZE_K的数据块,每个线程需要额外开启*ldg_a_reg个寄存器进行存储。

1
2
3
4
5
6
7
8
9
10
11
12
13
// shared memory
__shared__ float As[2][BLOCK_SIZE_K][BLOCK_SIZE_M];
__shared__ float Bs[2][BLOCK_SIZE_K][BLOCK_SIZE_N];
// registers for C
float accum[THREAD_SIZE_Y][THREAD_SIZE_X] = {0};
// registers for A and B
float frag_a[2][THREAD_SIZE_Y];
float frag_b[2][THREAD_SIZE_X];
// registers load global memory
const int ldg_num_a = BLOCK_SIZE_M * BLOCK_SIZE_K / (THREAD_NUM_PER_BLOCK * 4);
const int ldg_num_b = BLOCK_SIZE_K * BLOCK_SIZE_N / (THREAD_NUM_PER_BLOCK * 4);
float ldg_a_reg[4*ldg_num_a];
float ldg_b_reg[4*ldg_num_b];

最后需要说明的参数是在global->shared memory阶段用到。我们开启了256个线程,在一次大迭代中需要将128×8个元素搬运到shared memory中。我们用下面的参数说明了这个搬运的逻辑。A_TILE_THREAD_PER_ROW代表把搬运一行数据需要使用多少个线程,为了搬运A的一行,需要使用2个线程。

A_TILE_ROW_START代表在这个维度为bm×bk的数据块中,当前线程需要搬运的数据的竖向坐标,而A_TILE_COL代表需要搬运的数据的横向坐标。对3号线程而言,由于它要搬运(1,1)号数据块中的4个元素。所以,A_TILE_ROW_START是1,A_TILE_COL是4。A_TILE_ROW_STRIDE代表在进行多次搬运时需要跨越的行。假设As是一块256×8的数据块(这个设置跟前面不一样),256个线程进行搬运,一次搬运4个数,所以要搬运两次。对于3号线程而言,分别搬运下图中的绿色数据块。

image-20220910221016698

1
2
3
4
5
6
7
8
9
10
11
12
13
14
// threads number in one row
const int A_TILE_THREAD_PER_ROW = BLOCK_SIZE_K / 4;
const int B_TILE_THREAD_PER_ROW = BLOCK_SIZE_N / 4;

// row number and col number that needs to be loaded by this thread
const int A_TILE_ROW_START = tid / A_TILE_THREAD_PER_ROW;
const int B_TILE_ROW_START = tid / B_TILE_THREAD_PER_ROW;

const int A_TILE_COL = tid % A_TILE_THREAD_PER_ROW * 4;
const int B_TILE_COL = tid % B_TILE_THREAD_PER_ROW * 4;

// row stride that thread uses to load multiple rows of a tile
const int A_TILE_ROW_STRIDE = THREAD_NUM_PER_BLOCK / A_TILE_THREAD_PER_ROW;
const int B_TILE_ROW_STRIDE = THREAD_NUM_PER_BLOCK / B_TILE_THREAD_PER_ROW;

大迭代前预取数据

在介绍完相关参数之后,我们来进入到具体的代码逻辑。为了代码简洁,用float4读取的过程用了两个宏,定义如下:

1
2
#define OFFSET(row, col, ld) ((row) * (ld) + (col))
#define FETCH_FLOAT4(pointer) (reinterpret_cast<float4*>(&(pointer))[0])

迭代前预取数据分为两个部分第一个部分是将第一个大迭代的数据从global 预取到shared memroy中。第二个部分是将shared memory上的数据预取到寄存器中。先来看看第一个部分。这里面分别是将第一个大迭代中需要的A、B数据预取到shared memroy中。对于A矩阵而言,这个for循环代表着block中的线程需要搬运多少次才能将globa中的数据放到shared memory中。由于A需要先进行一次转置,所以先将数据先放置在寄存器中。数据按行取,然后按列存。对于B矩阵而言,数据不用转置,直接按行取,按行存。当然,这个过程中间也要经过寄存器,但是没有写出来的必要了。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
// load A from global memory to shared memory
#pragma unroll
for ( int i = 0 ; i < BLOCK_SIZE_M ; i += A_TILE_ROW_STRIDE) {
int ldg_index = i / A_TILE_ROW_STRIDE * 4;
FETCH_FLOAT4(ldg_a_reg[ldg_index]) = FETCH_FLOAT4(A[OFFSET(
BLOCK_SIZE_M * by + A_TILE_ROW_START + i, // row
A_TILE_COL, // col
K )]);
As[0][A_TILE_COL][A_TILE_ROW_START + i]=ldg_a_reg[ldg_index];
As[0][A_TILE_COL+1][A_TILE_ROW_START + i]=ldg_a_reg[ldg_index+1];
As[0][A_TILE_COL+2][A_TILE_ROW_START + i]=ldg_a_reg[ldg_index+2];
As[0][A_TILE_COL+3][A_TILE_ROW_START + i]=ldg_a_reg[ldg_index+3];
}
// load B from global memory to shared memory
#pragma unroll
for ( int i = 0 ; i < BLOCK_SIZE_K; i += B_TILE_ROW_STRIDE) {
FETCH_FLOAT4(Bs[0][B_TILE_ROW_START + i][B_TILE_COL]) = FETCH_FLOAT4(B[OFFSET(
B_TILE_ROW_START + i, // row
B_TILE_COL + BLOCK_SIZE_N * bx, // col
N )]);
}
__syncthreads();

然后就是第二个部分。将shared memory中的数据存到寄存器中。一共需要取THREAD_SIZE_Y个数,每次取4个数。这个倒没有什么好说的。

1
2
3
4
5
6
7
8
9
10
// load A from shared memory to register
#pragma unroll
for (int thread_y = 0; thread_y < THREAD_SIZE_Y; thread_y += 4) {
FETCH_FLOAT4(frag_a[0][thread_y]) = FETCH_FLOAT4(As[0][0][THREAD_SIZE_Y * ty + thread_y]);
}
// load B from shared memory to register
#pragma unroll
for (int thread_x = 0; thread_x < THREAD_SIZE_X; thread_x += 4) {
FETCH_FLOAT4(frag_b[0][thread_x]) = FETCH_FLOAT4(Bs[0][0][THREAD_SIZE_X * tx + thread_x]);
}

大迭代逻辑

在完成上一步后,我们要进入到大迭代中,按照前面的参数,我们需要进行256个大迭代。先忽略这个迭代里面的具体代码,看看这个框架,如下所示。首先要说的是write_stage_idx这个参数。之前定义了shared float As[2][BLOCK_SIZE_K][BLOCK_SIZE_M]。为了读写分离,给As开了两块空间。如果write_stage_idx=1,就对As[1]空间进行写操作,对As[0]空间进行读操作。因为我们之前将数据预取到了As[0]这个空间里,所以在第一个大迭代时,对As[0]进行读操作,对As[1]进行写操作,所以write_stage_idx=1。再来看看tile_idx这个参数,这个代表大迭代时,在A矩阵的列号。每一次大迭代要读取BLOCK_SIZE_K列,直到完成大迭代,即tile_idx=K为止。再看看循环里面的load_stage_idx,这个和write_stage_idx对应,两者保持二进制位相反即可。

1
2
3
4
5
6
7
8
9
10
int write_stage_idx = 1;
int tile_idx = 0;
do{
tile_idx += BLOCK_SIZE_K;
int load_stage_idx = write_stage_idx ^ 1;
// compute
if(tile_idx < K){
write_stage_idx ^= 1;
}
}while(tile_idx< K);

大迭代详细解析

我们在这里开始说明具体的大迭代。下面代码描述的是,如果还有下一个迭代,则将下一个迭代的数据块,搬运到寄存器上,这里面的for循环代表可能需要多次搬运。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
tile_idx += BLOCK_SIZE_K;
// load next tile from global mem
if(tile_idx< K){
#pragma unroll
for ( int i = 0 ; i < BLOCK_SIZE_M ; i += A_TILE_ROW_STRIDE) {
int ldg_index = i / A_TILE_ROW_STRIDE * 4;
FETCH_FLOAT4(ldg_a_reg[ldg_index]) = FETCH_FLOAT4(A[OFFSET(
BLOCK_SIZE_M * by + A_TILE_ROW_START + i, // row
A_TILE_COL + tile_idx, // col
K )]);
}
#pragma unroll
for ( int i = 0 ; i < BLOCK_SIZE_K; i += B_TILE_ROW_STRIDE) {
int ldg_index = i / A_TILE_ROW_STRIDE * 4;
FETCH_FLOAT4(ldg_b_reg[ldg_index]) = FETCH_FLOAT4(B[OFFSET(
tile_idx + B_TILE_ROW_START + i, // row
B_TILE_COL + BLOCK_SIZE_N * bx, // col
N )]);
}
}

随后进入到小迭代的计算逻辑之中,load_stage_idx参数代表需要从As的哪个空间进行读数。然后是BLOCK_SIZE_K-1次小迭代。按照前面的参数配置,即需要在这里完成7次小迭代。由于在小迭代中也采用了双缓冲的方式,需要将下一轮小迭代的数据提前写入到寄存器中,这个过程需要对shared memory访存,会稍微慢点。与此同时,线程需要计算更新THREAD_SIZE_X x THREAD_SIZE_Y=8×8=64个C矩阵元素的结果。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
int load_stage_idx = write_stage_idx ^ 1;
#pragma unroll
for(int j=0; j<BLOCK_SIZE_K-1; ++j){
// load next tile from shared mem to register
// load A from shared memory to register
#pragma unroll
for (int thread_y = 0; thread_y < THREAD_SIZE_Y; thread_y += 4) {
FETCH_FLOAT4(frag_a[(j+1)%2][thread_y]) = FETCH_FLOAT4(As[load_stage_idx][j+1][THREAD_SIZE_Y * ty + thread_y]);
}
// load B from shared memory to register
#pragma unroll
for (int thread_x = 0; thread_x < THREAD_SIZE_X; thread_x += 4) {
FETCH_FLOAT4(frag_b[(j+1)%2][thread_x]) = FETCH_FLOAT4(Bs[load_stage_idx][j+1][THREAD_SIZE_X * tx + thread_x]);
}
// compute C THREAD_SIZE_X x THREAD_SIZE_Y
#pragma unroll
for (int thread_y = 0; thread_y < THREAD_SIZE_Y; ++thread_y) {
#pragma unroll
for (int thread_x = 0; thread_x < THREAD_SIZE_X; ++thread_x) {
accum[thread_y][thread_x] += frag_a[j%2][thread_y] * frag_b[j%2][thread_x];
}
}
}

而后需要将存储在临时寄存器的数据搬运到shared memory中。由于A矩阵需要经过一次转置,所以和B矩阵有一点不一样。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
if(tile_idx < K){
#pragma unroll
for ( int i = 0 ; i < BLOCK_SIZE_M ; i += A_TILE_ROW_STRIDE) {
int ldg_index = i / A_TILE_ROW_STRIDE * 4;
As[write_stage_idx][A_TILE_COL][A_TILE_ROW_START + i]=ldg_a_reg[ldg_index];
As[write_stage_idx][A_TILE_COL+1][A_TILE_ROW_START + i]=ldg_a_reg[ldg_index+1];
As[write_stage_idx][A_TILE_COL+2][A_TILE_ROW_START + i]=ldg_a_reg[ldg_index+2];
As[write_stage_idx][A_TILE_COL+3][A_TILE_ROW_START + i]=ldg_a_reg[ldg_index+3];
}
// load B from global memory to shared memory
#pragma unroll
for ( int i = 0 ; i < BLOCK_SIZE_K; i += B_TILE_ROW_STRIDE) {
int ldg_index = i / A_TILE_ROW_STRIDE * 4;
FETCH_FLOAT4(Bs[write_stage_idx][B_TILE_ROW_START + i][B_TILE_COL]) = FETCH_FLOAT4(ldg_b_reg[ldg_index]);
}
// use double buffer, only need one sync
__syncthreads();
// switch
write_stage_idx ^= 1;
}

最后完成寄存器的预取,并将最后一个小迭代完成

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
// load A from shared memory to register
#pragma unroll
for (int thread_y = 0; thread_y < THREAD_SIZE_Y; thread_y += 4) {
FETCH_FLOAT4(frag_a[0][thread_y]) = FETCH_FLOAT4(As[load_stage_idx^1][0][THREAD_SIZE_Y * ty + thread_y]);
}
// load B from shared memory to register
#pragma unroll
for (int thread_x = 0; thread_x < THREAD_SIZE_X; thread_x += 4) {
FETCH_FLOAT4(frag_b[0][thread_x]) = FETCH_FLOAT4(Bs[load_stage_idx^1][0][THREAD_SIZE_X * tx + thread_x]);
}
//compute last tile mma THREAD_SIZE_X x THREAD_SIZE_Y
#pragma unroll
for (int thread_y = 0; thread_y < THREAD_SIZE_Y; ++thread_y) {
#pragma unroll
for (int thread_x = 0; thread_x < THREAD_SIZE_X; ++thread_x) {
accum[thread_y][thread_x] += frag_a[1][thread_y] * frag_b[1][thread_x];
}
}

计算结果写回

此时,最后的计算结果已经被存储在accum寄存器中,需要将其写回到global memory中。这个代码比较简单,就没啥好说的了。

1
2
3
4
5
6
7
8
9
10
11
// store back to C
#pragma unroll
for (int thread_y = 0; thread_y < THREAD_SIZE_Y; ++thread_y) {
#pragma unroll
for (int thread_x = 0; thread_x < THREAD_SIZE_X; thread_x+=4) {
FETCH_FLOAT4(C[OFFSET(
BLOCK_SIZE_M * by + ty * THREAD_SIZE_Y + thread_y,
BLOCK_SIZE_N * bx + tx * THREAD_SIZE_X + thread_x,
N)]) = FETCH_FLOAT4(accum[thread_y][thread_x]);
}
}

实验

针对GEMM性能优化,我做了一些实验,主要是想要说明这么两个问题:

  1. 不采用任何汇编的情况下,手写CUDA代码会比cublas差多少?
  2. bm、bn、bk、rm、rn等相关参数对GEMM的性能表现有多大影响?

针对第一个问题,固定了bm、bn、bk、rm、rn的取值为64、8、64、8、8。在V100上测试了不同维度的矩阵(设置M=N=K),并且对比了cublas,其性能结果如下图。横坐标是矩阵维度,纵坐标是GFLOPS。可以在图中看出,在大维度的矩阵下,我们手写的Sgemm大概能达到平均14TFLOPS,性能表现达到cublas的 91%。V100的单精度峰值性能是15.7TFLOPS,在完全不使用汇编,并且有着较好的代码可读性的同时,我们手写的Sgemm大概能达到90%的单精度峰值效率。当然,如果不考虑代码可读性的话,这个性能可以进一步提高。在这里可以得出结论,其实也是想消除大家的一个误解。很多人觉得只有写汇编才能写出高性能的代码。其实并不是这样,性能优化中最重要的是并行算法和优化策略,单纯地将代码写成汇编并不会有多少性能提升。

image-20220910221600420

从汇编代码分析程序性能

我们为什么要去看生成的汇编代码?这主要是由于做完优化之后,我们需要有一个东西来判断机器是否能够真正地按照我们设想的模式运行。使用了float4之后,GPU是不是真的使用了向量化指令。采用循环展开之后,GPU是不是真的会进行展开?另外,CUDA C和汇编代码之间还隔着编译器。只有看最底层的汇编码,才能真正地理解我们所做的优化是在哪个地方起了作用,节省了哪个部分的耗时。

NV的GPU提供了ptx和sass两个层面的汇编码。Ptx本质上是一个伪汇编码,事实上机器真正能够识别的是sass码。Ptx还需要使用ptxas工具再转化成sass码才能被GPU识别。然后nv提供了cuobjdump和nvdisasm两个工具,我们可以通过这两个工具来看到最底层的汇编码。

NV每一代机器的指令集都有所不同。此外,NV的指令还有一个特别有意思的东西,那就是control code,后面直接用控制码表示。通过控制码将一些本来应该在硬件实现的逻辑软件化了,从而在同样大小的电路面积上塞下更大的计算单元。

当我们在看汇编代码的时候,我们到底看的是什么东西。这个话题可以分为两部分介绍,分别是访存密集型的kernel和计算密集型的kernel。

对于访存密集型的kernel,正常而言,我们需要关注的是:访问global memory的时候是不是合并访存了,访问shared memory的时候是不是有bank 冲突了。很不幸的是,在汇编代码中,这些东西其实不太能看得出来。我们主要关注的是有没有采用LDG.128的访存指令,以及计算指令的占比是不是太多,#pragma unroll是不是有效展开了。

对于计算密集型的kernel而言,我们重点关注计算指令的占比。这个一般跟并行策略会联系在一起。一般而言,如果并行策略不太行,那么计算指令的占比会很低,这样的话,访存所导致的latency很难被计算指令掩盖,计算效率会非常差。如果并行策略比较好,那么计算指令的占比也会非常地高。也只有当计算指令占比非常高的时候,才有可能地去逼近峰值性能。

对于现有sgemm的代码分析及观察

在分析之前,我们对目前已有的工作先做一个回顾。sgemm是hpc领域的经典问题,目前有大量的论文在针对不同硬件架构,不同矩阵特性进行研究。对于NV的GPU,关于sgemm最著名的工作是scott的maxas。在Maxwell架构上的部分卡上能够达到98%的浮点性能,几乎到达极限。也就是从这个工作以后,针对NV的sgemm优化工作基本上就没法做了,关于针对大矩阵的sgemm优化,也没有太多的研究价值了。当然,针对不同硬件架构的sgemm优化还是层出不出,但基本上是一些follow的工作,然后做一些小修小补。

我们来分析一下scott的工作。在CUDA C层面,不涉及汇编的话,优化技巧主要有3个方面:

技巧1,global->shared memory,采用了texture内存,将线程划分,一半线程只读A,一半线程只读B。

技巧2,shared memory->register,将8×8的读取变成4个4×4的读取,从而避免bank冲突。

image-20220910221941317

技巧3,Store C矩阵的时候,为了合并访存,采用了一种非常奇怪的方式去store。

image-20220910221952729

针对大矩阵的sgemm计算时。如果k维度足够大,global->shared memory以及store C的耗时占比会非常小,所以这两个优化技巧在大矩阵中并不能起到很大的作用。所以相对来说,技巧2会更加具有借鉴意义

紧接着,我们来分析一下sgemm中最耗时的部分,也就是最内层的迭代部分。需要计算8×8×8=512次乘加运算。Scott的sgemm在maxwell产生的汇编代码如下图左,为了比较,我们将GEMM(二)中的代码sgemm_v2最后生成的SASS码放在一起用以比较。

image-20220910222351625

可以从上面看到,512条FFMA和32条LDS指令,最核心的计算指令和访存指令都是一样的。但是GEMM(二)中用编译器产生的汇编码有更多的非计算指令存在。而且如果从上面的链接点进去的话,就会发现,FFMA指令被划到2个代码块中,相对而言,中间会多一个跳转指令。另外一个需要注意的点是scott的代码是针对Maxwell架构,所以将可以用于双发射的指令进行了单独标记。而笔者写的代码是在volta架构上编译运行的,volta架构取消了双发射。但是两个cycle发射一条FFMA指令就可以将所有的fp32 core填满。计算指令和访存指令占据不同的发射端口,计算和访存可以隔一个cycle发射。所以我的猜想是这样的,对于volta架构,t0 cycle的时候发射一条FFMA指令,t1 cycle的时候发射一条LDS指令,而后t2时刻再发射一条FFMA指令。这样的话,FFMA指令隔了2个cycle,中间还发射了一条LDS指令,但fp32的core依旧是被用满的状态。这样的话,即使没有了双发射,理论上也能将fp32 core打满。从volta架构编译出来的控制码中也可以看出一些端倪,如下,FFMA指令stall两个cycle,而LDS指令stall一个cycle。

1
2
3
4
[R---:B------:R-:W-:-:S02]         /*0cd0*/                   FFMA R115, R39.reuse, R14, R115 ;
[----:B------:R-:W-:-:S02] /*0ce0*/ FFMA R114, R39, R15, R114 ;
[----:B------:R-:W1:-:S01] /*0cf0*/ LDS.U.128 R36, [R40+0x2410] ;
[R---:B------:R-:W-:-:S02] /*0d00*/ FFMA R113, R32.reuse, R12, R113 ;

然后总结一下这小节的内容,从CUDA C和SASS代码的角度分析了现有sgemm实现的不足。进一步的优化工作可以从两个方面进行:1、shared memory->register,将8×8的读取变成4个4×4的读取。2、尽可能地减少非必要指令的开销,但是这个在CUDA C层面很难控制,毕竟编译器也没那么听话。

汇编级别代码调整

好了,终于讲到了调汇编的地方。上面小节说了,优化的一个方式是尽可能地减少非必要指令的开销。但是,当我们开始调汇编的时候,还有一个更重要的事情需要做,也是在maxas、KeplerAs等一系列工作的核心,减少FFMA指令所产生的register bank冲突。这里面有两个优化技巧,一个是寄存器的重映射,另外一个是调整FFMA顺序,尽可能地在指令中使用.reuse标识以及提高双发射的效率。

寄存器的重映射

在这里面,由于每代架构中的硬件细节有所不同,所以register的remapping细节也有所不同。首先说一下这里面的硬件细节不同是指,不同的架构中,寄存器到bank的映射方式不同。kepler架构的映射比较奇怪,并不是很规则,如下:

image-20220910222431933

对于Maxwell架构而言,相对来说更加简单一些,bank index即reg_index%4这么一个简单的关系。Pascal架构和Maxwell架构的寄存器bank映射关系一样。而volta架构又有一些不同,在volta之前都是4路的bank,而volta架构变成了2路的bank。

由于架构不一样,针对不同架构的register重映射方式也不一样。对于kepler架构,keplerAs的作者采用的映射方式如下:

image-20220910222443696

对于Maxwell架构,Scott采用的映射方式如下:

image-20220910222453184

上图中间那些带黑框的数字代表不可避免的寄存器冲突,scott随后又使用了指令重排来减缓寄存器的冲突。

而volta架构的话,Dissecting the NVIDIA Volta GPU Architecture via Microbenchmarking作者采用的方式如下,作了一个转置,然后相邻两行进行一个交换。

image-20220910222504575

指令重排

这里的指令重排主要是针对FFMA指令的重排。作用的话,其实有两个。在maxwell架构中,scott重排主要是为了尽可能地解决对角线那些元素的寄存器bank冲突。在这里插一嘴,因为部分读者对于这个重排可能理解不是很到位。举个例子吧,要计算C矩阵中1,2,3,4,5的元素的值,正常的顺序是调用FFMA指令先算1,再算2,再算3,等等。重排的话,就是可能先算2,再算1,再算3。从指令角度的话,就是FFMA指令的排列顺序有所不同,所以叫指令重排,这个是我的个人理解。

重排的目的是为了更好地使用reuse标识,这个地方可以看看旷视写的矩阵乘终极优化指南,当然,基本上也就是scott的sgemm介绍内容。读取指令的操作数的时候,有一个寄存器的reuse cache。在指令中使用这个标识就代表这个数被hold住了,下一条指令可以直接使用。这个地方,大家都是这么说的,NV也没有官方的说明,那就这么理解吧。具体示意代码如下:

1
2
FFMA R2, R64.reuse, R73, R2; ## R64 进入 Reuse Cache
FFMA R3, R64.reuse, R72, R3; ## R64 从 Reuse Cache 中获取,避免与 R72 冲突

为了更好地利用这个reuse特性,scott给了一种非常奇怪的指令排列顺序,如下:

1
2
3
4
 1,  0,  2,  3,  5,  4,  6,  7, 33, 32, 34, 35, 37, 36, 38, 39, 
45, 44, 46, 47, 41, 40, 42, 43, 13, 12, 14, 15, 9, 8, 10, 11,
17, 16, 18, 19, 21, 20, 22, 23, 49, 48, 50, 51, 53, 52, 54, 55,
61, 60, 62, 63, 57, 56, 58, 59, 29, 28, 30, 31, 25, 24, 26, 27

通过CUDA C说的一系列优化手段,以及寄存器的remapping和指令重排,scott的sgemm在Maxwell架构的一些卡上能够达到98%的浮点计算效率,达到了优化的天花板。

扯远了,再说说指令重排,keplerAs的作者张秀霞针对kepler的双发射特性对FFMA指令进行了指令重排来提高性能。这个跟scott的工作又有一些不一样的地方,大家可以对比一下。

实验与总结

最后,我们来做一下实验。实验分成两个部分,第一个部分是CUDA C层面的再次优化,第二个部分是针对SASS代码的调优工作以及中间经历的一些波折。

CUDA C 调优

这个部分的内容主要是介绍一下怎么解决GEMM(二)所存在的shared memory bank冲突。其实scott的文章已经说了这一点,但是吧,实在是太费解了。首先,再来回顾一下这个思路。我们一个block有256个线程,8个warp,8个warp要去取shared memory中的半行元素,也就是128/2=64个元素。warp0和warp4取得是同样的16个元素。而warp里面,线程0、2、4、6、8、10、12、14是取得同样的4个元素。由于取得是同样的元素,同一个bank触发多播的机制,没有冲突。取多少元素说清楚了,就得说一下shared memory的索引了。scott给出的256线程版本索引是:

1
2
readAs = ((tid128 >> 4) | ((tid >> 1) & 7)) << 4;
readBs = (((tid & 0x70) >> 3) | (tid & 1)) << 4 + 4096;

image-20220910222717437

总之,这个索引给我整不会了。作为一个正常的人类,我实在是不太能直观地去理解这个位运算。思量许久,我决定用一种最简单粗暴的索引计算方式。我们本质上是要知道,每一个线程,对应到128个元素中的哪一个元素?这个是我们的核心问题。

我来说一下我的计算方法,以B矩阵对应的shared memory为例,首先,计算warp_id,也就是当前线程属于哪个warp,由tid/32即可得。随后计算lane_id,即当前线程属于这个warp上得哪个线程,由tid%32即可得。随后就是通过warp_id和lane_id来算出,对应128个元素得哪一个元素。先算(warp_id%4)×16,假设是warp2,就是上图左侧的第2个(从0算)warp。前面有2个warp,跳过了2*16=32个元素。然后再看看当前lane_id。0-15在左半边,16-31在右半边。所以lane_id/16,先看是左半边还是右半边。右半边的话,先跳过8个元素。最后再看lane_id的奇偶数,如果奇数的话,就再跳一个四个元素。代码实现如下,这个就是正常人可以看懂的方式了。对A矩阵的映射关系同理。

1
2
3
4
5
//load index of the tile
const int warp_id = tid / 32;
const int lane_id = tid % 32;
const int tile_index_b = (warp_id%4)*16 + (lane_id/16)*8 + (lane_id%2)*4;
const int tile_index_a = (warp_id/4)*32 + ((lane_id%16)/2)*4;

然后shared memory取数的代码更改就是下面这样,以B矩阵块为例:

1
2
3
4
5
6
7
8
// 改变前
#pragma unroll
for (int thread_y = 0; thread_y < THREAD_SIZE_Y; thread_y += 4) {
FETCH_FLOAT4(frag_b[(j+1)%2][thread_y]) = FETCH_FLOAT4(Bs[next_stage_flag][(j+1)%BLOCK_SIZE_K][THREAD_SIZE_Y * ty + thread_y]);
}
// 改变后
FETCH_FLOAT4(frag_b[(j+1)%2][0]) = FETCH_FLOAT4(Bs[next_stage_flag][(j+1)%BLOCK_SIZE_K][tile_index]);
FETCH_FLOAT4(frag_b[(j+1)%2][4]) = FETCH_FLOAT4(Bs[next_stage_flag][(j+1)%BLOCK_SIZE_K][tile_index + 64]);

当然,因为用来寄存C的64个元素对应的位置变化,所以最后的store C的过程也有代码变动。

在进行了这个修改之后,4096(M=N=K)的矩阵大概可以达到96-97%的cublas的性能。单精度峰值浮点效率达93%左右。再往下想要持平或者超越cublas的话,就只能动汇编了。

汇编代码调优

在做寄存器remapping的时候,发现NVCC编译出来的代码是这个样子:

1
2
3
4
FFMA R125, R52, R44, R72 ;
FFMA R122, R53, R44.reuse, R73 ;
FFMA R74, R54, R44.reuse, R74 ;
FFMA R75, R55, R44.reuse, R75 ;

看看第一条指令,做R125=R52×R44+R72,R72的值被拿出来,然后存到了R125上。编译出来的代码有一大堆这样的指令。而我希望所有的指令都满足第3条的样子,R74=R54×R44+R74,从R74取就放回R74才最好。如果不能保证这个形式的话,就意味着,我们不能让固定的寄存器来存储矩阵C中的固定的值。这玩意做remapping的话,就不能简简单单地改寄存器号。毕竟我也不能确定不同的寄存器对应到哪个具体的值了。

当时想了各种方式,调整CUDA C代码来让nvcc编译出我想要的FFMA格式,但是,这个尝试并不能实现。所以接下来,有两个方式,一个是头铁,搞清楚这个100多个寄存器在512条FFMA指令中对应的物理元素,然后做remapping,这个路线中间会遇到可以预想的无数的bug和计算问题。另一个是参考Maxas,把这玩意整合到汇编器上,定义好每个寄存器的对应元素和排列顺序。然后汇编器顺带着处理,如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
<REGISTER_MAPPING>

// Temporary registers to calculate the state registers. Reuse the C output registers.
// These can be dynamically allocated (~) in the available registger space to elimiate any register bank conflicts.
0-63 ~ blk, ldx, ldx2, ldx4, k, tid1, tid4, tid7, tid31_4, xmad_t0, xmad_end, bxOrig, byOrig, loy

// Aliases for the C registers we use for initializing C (used as vectors)
0-63 : cz<00-63>

// The offset we store our zero value for initializing C. Reuse a register from the second blocking registers
80 : zOffset

// 64 C maxtrix output registers.
// Use special mapping to avoid register bank conflicts between these registers and the blocking registers.
3, 2,11,10,19,18,27,26 : cx00y<00-03|64-67>
7, 6,15,14,23,22,31,30 : cx01y<00-03|64-67>
1, 0, 9, 8,17,16,25,24 : cx02y<00-03|64-67>
5, 4,13,12,21,20,29,28 : cx03y<00-03|64-67>
35,34,43,42,51,50,59,58 : cx64y<00-03|64-67>
39,38,47,46,55,54,63,62 : cx65y<00-03|64-67>
33,32,41,40,49,48,57,56 : cx66y<00-03|64-67>
37,36,45,44,53,52,61,60 : cx67y<00-03|64-67>

// Double buffered register blocking used in vector loads.
// Any bank conflicts that we can't avoid in these registers we can hide with .reuse flags
64-79 : j0Ax<00-03|64-67>, j0By<00-03|64-67>
80-95 : j1Ax<00-03|64-67>, j1By<00-03|64-67>

// Registers to load A or B
96-103 : loadX<0-7>

// Key global state registers for main loop and some we reuse for outputing C.
// Note, tweaking the register banks of track<0|4>, tex, writeS, readBs, readAs impacts performance because of
// delayed bank conflicts between memory operations and ffmas.
// The array index bracket notation can be used to request a bank in a dynamically allocated range.
104-127 ~ track<0|4>[0], tex[2], readAs[2], readBs[3], writeS[3], end, ldx8, tid, bx, by, tid31, tid96, tid128 //, clock, smId, nSMs

// Registers to store the results back to global memory. Reuse any register not needed after the main loop.
// Statically allocate cs0-7 because they're vector registers.
64-71 : cs<0-7>

// dynamically allocated C output registers(~)
72-103 ~ cy<00|04|08|12>, Cy<00|04|08|12>, ldc, ldc1, ldc4, ldc8, ldc60, writeCs, readCs, cx, ci, alpha, xmad_ci //, xmad_D, D, blckDimX, gridDimX

</REGISTER_MAPPING>

然而,我只是想简简单单写个sgemm,我并不想把我有限的周末时间全部投进去,毕竟读者也没给我钱。然后想想指令重排,通过reuse标识也能解决一部分reg的bank冲突,那就整这个吧。

遇到的另一个问题就是指令重排。我把里面所有存在寄存器bank冲突的指令列了出来。再来看看volta架构中的bank冲突,volta架构的寄存器有2路bank,奇数寄存器号代表bank0,偶数寄存器号代表bank1。如果FFMA指令的三个源寄存器的寄存器号都属于奇数或者偶数,那么就发生了bank冲突。

1
2
3
//0    FFMA R74, R36, R62.reuse, R74 ;    
//1 FFMA R78, R34, R62.reuse, R78 ;
//2 FFMA R16, R35, R62, R54 ;

比如上面的代码,0号指令和1号三个源寄存器都是偶数,不考虑reuse标识的话,都有bank冲突,而2号指令就没有bank冲突。调整这3个的位置,变成:

1
2
3
//2    FFMA R16, R35, R62.reuse, R54 ;   
//1 FFMA R78, R34, R62.reuse, R78 ;
/ 0 FFMA R74, R36, R62.reuse, R74 ;

让指令2的R62放入reuse cache中,指令1和指令0继续使用这个数,从而减少bank冲突。更改前后的代码在我的github repo中。但是改完之后,我发现性能提升并不是很明显,大概就是1%左右的性能提升。这可能是在sgemm_v2的基础上改的原因,当时4.1所说的shared memory bank冲突还比较明显。总之,实验大概就是这样子。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
// optimize sgemm

#include <stdio.h>
#include <stdlib.h>
#include "assert.h"

// CUDA runtime
#include <cuda_runtime.h>
#include <cublas_v2.h>

// cal offset from row col and ld , in row-major matrix, ld is the width of the matrix
#define OFFSET(row, col, ld) ((row) * (ld) + (col))

// transfer float4
#define FETCH_FLOAT4(pointer) (reinterpret_cast<float4*>(&(pointer))[0])

#define checkCudaErrors(func) \
{ \
cudaError_t e = (func); \
if(e != cudaSuccess) \
printf ("%s %d CUDA: %s\n", __FILE__, __LINE__, cudaGetErrorString(e)); \
}

// K: ldA
// N: ldB
template <
const int BLOCK_SIZE_M, // height of block of C that each thread block calculate
const int BLOCK_SIZE_K, // width of block of A that each thread block load into shared memory
const int BLOCK_SIZE_N, // width of block of C that each thread block calculate
const int THREAD_SIZE_Y, // height of block of C that each thread calculate
const int THREAD_SIZE_X, // width of block of C that each thread calculate
const bool ENABLE_DOUBLE_BUFFER // whether enable double buffering or not
>
__global__ void Sgemm(
float * __restrict__ A,
float * __restrict__ B,
float * __restrict__ C,
const int M,
const int N,
const int K) {
// Block index
int bx = blockIdx.x;
int by = blockIdx.y;

// Thread index
int tx = threadIdx.x;
int ty = threadIdx.y;

// the threads number in Block of X,Y
const int THREAD_X_PER_BLOCK = BLOCK_SIZE_N / THREAD_SIZE_X;
const int THREAD_Y_PER_BLOCK = BLOCK_SIZE_M / THREAD_SIZE_Y;
const int THREAD_NUM_PER_BLOCK = THREAD_X_PER_BLOCK * THREAD_Y_PER_BLOCK;

// thread id in cur Block
const int tid = ty * THREAD_X_PER_BLOCK + tx;

// shared memory
__shared__ float As[2][BLOCK_SIZE_K][BLOCK_SIZE_M];
__shared__ float Bs[2][BLOCK_SIZE_K][BLOCK_SIZE_N];
// registers for C
float accum[THREAD_SIZE_Y][THREAD_SIZE_X];
#pragma unroll
for(int i=0; i<THREAD_SIZE_Y; i++){
#pragma unroll
for(int j=0; j<THREAD_SIZE_X; j++){
accum[i][j]=0.0;
}
}
// registers for A and B
float frag_a[2][THREAD_SIZE_Y];
float frag_b[2][THREAD_SIZE_X];
// registers load global memory
const int ldg_num_a = BLOCK_SIZE_M * BLOCK_SIZE_K / (THREAD_NUM_PER_BLOCK * 4);
const int ldg_num_b = BLOCK_SIZE_K * BLOCK_SIZE_N / (THREAD_NUM_PER_BLOCK * 4);
float ldg_a_reg[4*ldg_num_a];
float ldg_b_reg[4*ldg_num_b];

// threads number in one row
const int A_TILE_THREAD_PER_ROW = BLOCK_SIZE_K / 4;
const int B_TILE_THREAD_PER_ROW = BLOCK_SIZE_N / 4;

// row number and col number that needs to be loaded by this thread
const int A_TILE_ROW_START = tid / A_TILE_THREAD_PER_ROW;
const int B_TILE_ROW_START = tid / B_TILE_THREAD_PER_ROW;

const int A_TILE_COL = tid % A_TILE_THREAD_PER_ROW * 4;
const int B_TILE_COL = tid % B_TILE_THREAD_PER_ROW * 4;

// row stride that thread uses to load multiple rows of a tile
const int A_TILE_ROW_STRIDE = THREAD_NUM_PER_BLOCK / A_TILE_THREAD_PER_ROW;
const int B_TILE_ROW_STRIDE = THREAD_NUM_PER_BLOCK / B_TILE_THREAD_PER_ROW;

A = &A[(BLOCK_SIZE_M * by)* K];
B = &B[BLOCK_SIZE_N * bx];

//load index of the tile
const int warp_id = tid / 32;
const int lane_id = tid % 32;
const int a_tile_index = warp_id/2*16 + lane_id/8*4; //warp_id * 8 + (lane_id / 16)*4; // (warp_id/4)*32 + ((lane_id%16)/2)*4;
const int b_tile_index = warp_id%2*32 + lane_id%8*4; //(lane_id % 16) * 4; // (warp_id%4)*16 + (lane_id/16)*8 + (lane_id%2)*4;

//transfer first tile from global mem to shared mem
// load A from global memory to shared memory
#pragma unroll
for ( int i = 0 ; i < BLOCK_SIZE_M ; i += A_TILE_ROW_STRIDE) {
int ldg_index = i / A_TILE_ROW_STRIDE * 4;
FETCH_FLOAT4(ldg_a_reg[ldg_index]) = FETCH_FLOAT4(A[OFFSET(
A_TILE_ROW_START + i, // row
A_TILE_COL, // col
K )]);
As[0][A_TILE_COL][A_TILE_ROW_START + i]=ldg_a_reg[ldg_index];
As[0][A_TILE_COL+1][A_TILE_ROW_START + i]=ldg_a_reg[ldg_index+1];
As[0][A_TILE_COL+2][A_TILE_ROW_START + i]=ldg_a_reg[ldg_index+2];
As[0][A_TILE_COL+3][A_TILE_ROW_START + i]=ldg_a_reg[ldg_index+3];
}
// load B from global memory to shared memory
#pragma unroll
for ( int i = 0 ; i < BLOCK_SIZE_K; i += B_TILE_ROW_STRIDE) {
FETCH_FLOAT4(Bs[0][B_TILE_ROW_START + i][B_TILE_COL]) = FETCH_FLOAT4(B[OFFSET(
B_TILE_ROW_START + i, // row
B_TILE_COL, // col
N )]);
}
__syncthreads();

// load A from shared memory to register
FETCH_FLOAT4(frag_a[0][0]) = FETCH_FLOAT4(As[0][0][a_tile_index]);
FETCH_FLOAT4(frag_a[0][4]) = FETCH_FLOAT4(As[0][0][a_tile_index + 64]);

// load B from shared memory to register
FETCH_FLOAT4(frag_b[0][0]) = FETCH_FLOAT4(Bs[0][0][b_tile_index]);
FETCH_FLOAT4(frag_b[0][4]) = FETCH_FLOAT4(Bs[0][0][b_tile_index + 64]);

int write_stage_idx = 1;
int tile_idx = 0;
do{
// next tile index
tile_idx += BLOCK_SIZE_K;
// load next tile from global mem
if(tile_idx< K){
#pragma unroll
for ( int i = 0 ; i < BLOCK_SIZE_M ; i += A_TILE_ROW_STRIDE) {
int ldg_index = i / A_TILE_ROW_STRIDE * 4;
FETCH_FLOAT4(ldg_a_reg[ldg_index]) = FETCH_FLOAT4(A[OFFSET(
A_TILE_ROW_START + i, // row
A_TILE_COL + tile_idx, // col
K )]);
}
#pragma unroll
for ( int i = 0 ; i < BLOCK_SIZE_K; i += B_TILE_ROW_STRIDE) {
int ldg_index = i / B_TILE_ROW_STRIDE * 4;
FETCH_FLOAT4(ldg_b_reg[ldg_index]) = FETCH_FLOAT4(B[OFFSET(
tile_idx + B_TILE_ROW_START + i, // row
B_TILE_COL, // col
N )]);
}
}

int load_stage_idx = write_stage_idx ^ 1;

#pragma unroll
for(int j=0; j<BLOCK_SIZE_K - 1; ++j){
// load next tile from shared mem to register
// load A from shared memory to register
FETCH_FLOAT4(frag_a[(j+1)%2][0]) = FETCH_FLOAT4(As[load_stage_idx][(j+1)][a_tile_index]);
FETCH_FLOAT4(frag_a[(j+1)%2][4]) = FETCH_FLOAT4(As[load_stage_idx][(j+1)][a_tile_index + 64]);
// load B from shared memory to register
FETCH_FLOAT4(frag_b[(j+1)%2][0]) = FETCH_FLOAT4(Bs[load_stage_idx][(j+1)][b_tile_index]);
FETCH_FLOAT4(frag_b[(j+1)%2][4]) = FETCH_FLOAT4(Bs[load_stage_idx][(j+1)][b_tile_index + 64]);
// compute C THREAD_SIZE_X x THREAD_SIZE_Y
#pragma unroll
for (int thread_y = 0; thread_y < THREAD_SIZE_Y; ++thread_y) {
#pragma unroll
for (int thread_x = 0; thread_x < THREAD_SIZE_X; ++thread_x) {
accum[thread_y][thread_x] += frag_a[j%2][thread_y] * frag_b[j%2][thread_x];
}
}
}

if(tile_idx < K){
// load A from global memory to shared memory
#pragma unroll
for ( int i = 0 ; i < BLOCK_SIZE_M ; i += A_TILE_ROW_STRIDE) {
int ldg_index = i / A_TILE_ROW_STRIDE * 4;
As[write_stage_idx][A_TILE_COL][A_TILE_ROW_START + i]=ldg_a_reg[ldg_index];
As[write_stage_idx][A_TILE_COL+1][A_TILE_ROW_START + i]=ldg_a_reg[ldg_index+1];
As[write_stage_idx][A_TILE_COL+2][A_TILE_ROW_START + i]=ldg_a_reg[ldg_index+2];
As[write_stage_idx][A_TILE_COL+3][A_TILE_ROW_START + i]=ldg_a_reg[ldg_index+3];
}
// load B from global memory to shared memory
#pragma unroll
for ( int i = 0 ; i < BLOCK_SIZE_K; i += B_TILE_ROW_STRIDE) {
int ldg_index = i / B_TILE_ROW_STRIDE * 4;
FETCH_FLOAT4(Bs[write_stage_idx][B_TILE_ROW_START + i][B_TILE_COL]) = FETCH_FLOAT4(ldg_b_reg[ldg_index]);
}
// use double buffer, only need one sync
__syncthreads();
// switch
write_stage_idx ^= 1;
}

// load first tile from shared mem to register of next iter
// load A from shared memory to register
FETCH_FLOAT4(frag_a[0][0]) = FETCH_FLOAT4(As[load_stage_idx^1][0][a_tile_index]);
FETCH_FLOAT4(frag_a[0][4]) = FETCH_FLOAT4(As[load_stage_idx^1][0][a_tile_index + 64]);
// load B from shared memory to register
FETCH_FLOAT4(frag_b[0][0]) = FETCH_FLOAT4(Bs[load_stage_idx^1][0][b_tile_index]);
FETCH_FLOAT4(frag_b[0][4]) = FETCH_FLOAT4(Bs[load_stage_idx^1][0][b_tile_index + 64]);
// compute C THREAD_SIZE_X x THREAD_SIZE_Y
#pragma unroll
for (int thread_y = 0; thread_y < THREAD_SIZE_Y; ++thread_y) {
#pragma unroll
for (int thread_x = 0; thread_x < THREAD_SIZE_X; ++thread_x) {
accum[thread_y][thread_x] += frag_a[1][thread_y] * frag_b[1][thread_x];
}
}
}while(tile_idx< K);

const int c_block_row = a_tile_index;
const int c_block_col = b_tile_index;

//store C00 block
for(int i=0; i<4; i++){
FETCH_FLOAT4(C[OFFSET(
BLOCK_SIZE_M * by + c_block_row + i,
BLOCK_SIZE_N * bx + c_block_col,
N)]) = FETCH_FLOAT4(accum[i][0]);
}
//store C01 block
for(int i=0; i<4; i++){
FETCH_FLOAT4(C[OFFSET(
BLOCK_SIZE_M * by + c_block_row + i,
BLOCK_SIZE_N * bx + c_block_col + 64,
N)]) = FETCH_FLOAT4(accum[i][4]);
}
//store C10 block
for(int i=0; i<4; i++){
FETCH_FLOAT4(C[OFFSET(
BLOCK_SIZE_M * by + c_block_row + 64 + i,
BLOCK_SIZE_N * bx + c_block_col,
N)]) = FETCH_FLOAT4(accum[i+4][0]);
}
//store C11 block
for(int i=0; i<4; i++){
FETCH_FLOAT4(C[OFFSET(
BLOCK_SIZE_M * by + c_block_row + 64 + i,
BLOCK_SIZE_N * bx + c_block_col + 64,
N)]) = FETCH_FLOAT4(accum[i+4][4]);
}
}

int main(int argc, char** argv) {
if (argc != 4) {
printf("usage: ./main [M] [K] [N]\n");
exit(0);
}
size_t M = atoi(argv[1]);
size_t K = atoi(argv[2]);
size_t N = atoi(argv[3]);

assert( M%8 == 0);
assert( N%8 == 0);
assert( K%8 == 0);

size_t bytes_A = sizeof(float) * M * K;
size_t bytes_B = sizeof(float) * K * N;
size_t bytes_C = sizeof(float) * M * N;
float* h_A = (float*)malloc(bytes_A);
float* h_B = (float*)malloc(bytes_B);
float* h_C = (float*)malloc(bytes_C);
float* h_C1 = (float*)malloc(bytes_C);

float* d_A;
float* d_B;
float* d_C;

checkCudaErrors(cudaMalloc(&d_A, bytes_A));
checkCudaErrors(cudaMalloc(&d_B, bytes_B));
checkCudaErrors(cudaMalloc(&d_C, bytes_C));
double msecPerMatrixMul[2] = {0, 0};
double gigaFlops[2] = {0, 0};
double flopsPerMatrixMul = 2.0 * M * N * K;

// don't edit it
const int BLOCK_SIZE_M = 128;
const int BLOCK_SIZE_K = 8;
const int BLOCK_SIZE_N = 128;
const int THREAD_SIZE_X = 8;
const int THREAD_SIZE_Y = 8;
const bool ENABLE_DOUBLE_BUFFER = false;

// 生成A的数据
for( int i = 0; i < M * K; i++ ) {
h_A[i] = i / 13;
}

// 生成B的数据
for( int i = 0; i < K * N; i++ ) {
h_B[i] = i % 13;
}

checkCudaErrors(cudaMemcpy( d_A, h_A, bytes_A, cudaMemcpyHostToDevice));
checkCudaErrors(cudaMemcpy( d_B, h_B, bytes_B, cudaMemcpyHostToDevice));

cudaEvent_t start, stop;
checkCudaErrors(cudaEventCreate(&start));
checkCudaErrors(cudaEventCreate(&stop));
float msecTotal = 0;
int nIter = 1000;

checkCudaErrors(cudaMemcpy( d_C, h_C, bytes_C, cudaMemcpyHostToDevice));
checkCudaErrors(cudaEventRecord(start));
for (int run = 0 ; run < nIter; run ++ ) {
dim3 dimBlock(BLOCK_SIZE_N / THREAD_SIZE_X, BLOCK_SIZE_M / THREAD_SIZE_Y);
dim3 dimGrid(N / BLOCK_SIZE_N, M / BLOCK_SIZE_M);
Sgemm<BLOCK_SIZE_M, BLOCK_SIZE_K, BLOCK_SIZE_N, THREAD_SIZE_Y, THREAD_SIZE_X, ENABLE_DOUBLE_BUFFER>
<<< dimGrid, dimBlock >>>(d_A, d_B, d_C, M, N, K);
}
checkCudaErrors(cudaEventRecord(stop));
checkCudaErrors(cudaEventSynchronize(stop));
checkCudaErrors(cudaEventElapsedTime(&msecTotal, start, stop));
checkCudaErrors(cudaMemcpy( h_C, d_C, bytes_C, cudaMemcpyDeviceToHost));

msecPerMatrixMul[0] = msecTotal / nIter;
gigaFlops[0] = (flopsPerMatrixMul * 1.0e-9f) / (msecPerMatrixMul[0] / 1000.0f);
printf( "My gemm Performance= %.2f GFlop/s, Time= %.3f msec, Size= %.0f Ops,\n",
gigaFlops[0],
msecPerMatrixMul[0],
flopsPerMatrixMul);

// cublas

cublasHandle_t blas_handle;
cublasCreate(&blas_handle);
float alpha = 1.0;
float beta = 0;
checkCudaErrors(cudaMemcpy( d_C, h_C, bytes_C, cudaMemcpyHostToDevice));
checkCudaErrors(cudaEventRecord(start));
for (int run = 0 ; run < nIter; run ++ ) {
cublasSgemm (blas_handle, CUBLAS_OP_T, CUBLAS_OP_T,
M, N, K, &alpha,
d_A, K, d_B, N, &beta, d_C, N
);
}
checkCudaErrors(cudaEventRecord(stop));
checkCudaErrors(cudaEventSynchronize(stop));
checkCudaErrors(cudaEventElapsedTime(&msecTotal, start, stop));

checkCudaErrors(cudaMemcpy( h_C1, d_C, bytes_C, cudaMemcpyDeviceToHost));

msecPerMatrixMul[1] = msecTotal / nIter;
gigaFlops[1] = (flopsPerMatrixMul * 1.0e-9f) / (msecPerMatrixMul[1] / 1000.0f);
printf( "CuBlas Performance= %.2f GFlop/s, Time= %.3f msec, Size= %.0f Ops,\n",
gigaFlops[1],
msecPerMatrixMul[1],
flopsPerMatrixMul);

cublasDestroy(blas_handle);


double eps = 1.e-6; // machine zero
bool correct = true;
for (int i = 0; i < M * N; i++) {
int row = i / N;
int col = i % N;
double abs_err = fabs(h_C[i] - h_C1[col * M + row]);
double dot_length = M;
double abs_val = fabs(h_C[i]);
double rel_err = abs_err / abs_val / dot_length;
if (rel_err > eps) {
printf("Error! Matrix[%d][%d]=%.8f, ref=%.8f error term is > %E\n",
row, col, h_C[i], h_C1[col * M + row], eps);
correct = false;
break;
}
}

printf("%s\n", correct ? "Result= PASS" : "Result= FAIL");
printf("ratio= %f\n", gigaFlops[0] / gigaFlops[1]);

// Free Memory
cudaFree(d_A);
cudaFree(d_B);
cudaFree(d_C);

free(h_A);
free(h_B);
free(h_C);
free(h_C1);
}

SGEMM

在深度学习推理框架或者训练框架中,GEMM 和 Conv 是典型的计算密集型算子,例如在 Bert 和 Conformer 模型的 self-attention 模块中存在大量矩阵运算,因此深度学习框架中 GEMM 算子的底层实现好坏将会直接影响模型的推理或训练延时。

img

图1 conformer 模型中的矩阵运算

介绍如何进行 GEMM 优化的文章很多,即使在知乎上随手搜索 GEMM优化 词条也会有几十个条目,其中也不乏一些内容翔实、条理清楚的好文章。不过,从我个人比较主观的分析来看,大部分文章停留在方法论层面的介绍,没有落实到具体的代码实现上,理论和实践之间还是有不可跨越的鸿沟,作为一个愣头青程序员,没能看到代码总是感觉少了点意思。

另一方面,在 GitHub: How To Optimize GEMM 项目中,作者通过清晰明了的代码和文档向读者介绍内存对齐、向量化、矩阵分块和数据打包等关键技术,此外,作者还给出了每一个步骤的优化点、优化效果对比和分析,实属不可多得的GEMM优化入门读物,强烈推荐!但 GitHub: How To Optimize GEMM 作为一个入门级的项目,旨在粗粒度介绍矩阵乘算法的优化思路,并没有针对某个硬件进行针对性优化,也没有深入优化 micro kernel 的代码实现,因此该项目中的矩阵乘实现仍然存在较大的优化空间。

那么,能不能在介绍矩阵乘优化原理的基础时搭配相应的代码实现,并且最终取得可观的性能表现呢?

Talk is cheap. Show me the code. ― Linus Torvalds

当然,这篇文章就是想做这个事情,本文目标有三点

  1. 介绍如何在x64 CPU 上优化矩阵乘算法的思路;
  2. 实现一份可运行的高性能矩阵乘算法;
  3. 性能数据可复现;

img

图2 矩阵乘运算

矩阵乘运算是大学本科的基础知识,原理十分简单,此处不在赘述其数学公式和讲解。

基础知识

选取一个合适的度量指标是性能优化工作的基础,通常我们使用 GFLOPS 来衡量一个算子的性能。

区分 FLOPS 和 FLOPs

每秒浮点运算次数(floating point operations per second, FLOPS),即每秒所执行的浮点运算次数,是一个衡量硬件性能的指标。下表列举了常见的 FLOPS 换算指标。

缩写 解释
MFLOPS 每秒进行百万次 (10^6) 次浮点运算的次数
GFLOPS 每秒进行十亿次 (10^9) 次浮点运算的次数
TFLOPS 每秒进行万亿次 (10^12)次浮点运算的次数
PFLOPS 每秒进行千万亿次(10^15)次浮点运算的次数
EFLOPS 每秒进行百亿亿次(10^18)次浮点运算的次数

浮点运算量(floating point operations, FLOPs)是指浮点运算的次数,是一个衡量深度学习模型计算量的指标。

此外,从FLOPs延伸出另外一个指标是乘加运算量MACs。

乘加运算量(multiplication and accumulation operations, MACs)是指乘加运算的次数,也是衡量深度模型计算量的指标。在Intel AVX指令中,扩展了对于乘加计算(fused multiply-add, FMA)指令的支持,即在支持AVX指令的CPU上,可以通过FMA计算单元使用一条指令来执行类似 A×B+CA \times B + CA \times B + C 的操作,参考 Intel® C++ Compiler Classic Developer Guide and Reference 中对于 _mm256_fmadd_ps 指令的介绍。一次乘加运算包含了两次浮点运算,一般地可以认为 MACs = 2FLOPs。

计算 CPU 的 FLOPS

从上一小节中得知,FLOPS 是一个衡量硬件性能的指标,那么我们该如何计算 CPU 的FLOPS 呢?

img

图1 使用 lscpu 命令查看系统信息

上图中,红框中几条关键信息

  1. CPU(s), 逻辑核数量;
  2. CPU family, CPU系列标识,用以确定CPU属于哪一代产品。更多关于 Intel CPU Family 信息,可以参考 Intel CPUID
  3. Model, 型号标识可用来确定处理器的制作技术以及属于该系列的第几代设计(或核心),型号与系列通常是相互配合使用的,用于确定计算机所安装的处理器是属于某系列处理器的哪种特定类型。
  4. Model name, CPU型号名称
  5. CPU MHZ: 主频

下面以 “Xeon Platinum 8260Y” 细致地解释下 CPU 型号名称中隐藏的信息。

img

图2 Xeon Platinum 8260Y CPU

  • Xeon Platinum 8260Y: Intel 公司推出的至强处理器系列,具备丰富的指令集支持和出色的性能表现,主要针对服务器市场。除至强处理器之外,Intel 公司推出的酷睿处理器在桌面市场具备更高的知名度;
  • Xeon Platinum 8260Y: Intel 至强系列处理器分为四个级别,性能由高到低依次是铂金级Platinum(8,9)、黄金级Gold(6,7)、白银级Silver(4)和青铜级Bronze(3);
  • Xeon Platinum 8260Y: 处理器架构代号,1 代表Skylake ,2 代表 Cascade Lake
  • Xeon Platinum 8260Y: SKU和Extra Options信息可以参考 Cascade Lake 架构介绍

计算CPU FLOPS时需要两点关键信息,下面分别计算下 AVX2 和 AVX512 指令集的GFLOPS。

  1. CPU 主频
  2. FMA 单元数

AVX2

1
2
3
4
单周期双精度浮点计算能力 = 2(FMA数量)* 2(乘加) ∗ 256 (YMM寄存器宽度) / 64(双精度浮点数位数) = 16
单周期双精度浮点计算能力= 2(FMA数量)* 2(乘加) ∗ 256 (YMM寄存器宽度) / 32(双精度浮点数位数) = 32
双精度FLOAPS = 2.5(CPU主频) * 16(单周期双精度浮点计算能力) = 40GFLOPS
单精度FLOAPS = 2.5(CPU主频) * 32(单周期单精度浮点计算能力) = 80GFLOPS

AVX512

1
2
3
4
单周期双精度浮点计算能力 = 2(FMA数量)* 2(乘加) ∗ 512 (YMM寄存器宽度) / 64(双精度浮点数位数) = 32
单周期双精度浮点计算能力= 2(FMA数量)* 2(乘加) ∗ 512 (YMM寄存器宽度) / 32(双精度浮点数位数) = 64
双精度FLOAPS = 2.5(CPU主频) * 16(单周期双精度浮点计算能力) = 80 GFLOPS
单精度FLOAPS = 2.5(CPU主频) * 32(单周期单精度浮点计算能力) = 160 GFLOPS
指令集 精度 理论峰值算力
AVX2 double 40 GFLOPS
AVX2 float 80 GFLOPS
AVX512 double 80 GFLOPS
AVX512 float 160 GFLOPS

至此,我们已经明白了单核心CPU的理论峰值算力,下面开始进入实战环节!

基础矩阵乘实现和优化

本节内容作为正式优化的序章,会介绍两点内容

  1. 如何实现基础的 GEMM 算法并测量其性能数据;
  2. 如何通过一行代码达到十倍的性能提升;

此处约定本文中A,B分别为左、右输入矩阵,C为输出矩阵,并且三者的形状信息如下

A:M×KA: M \times K A: M \times K 的输入矩阵

B:K×NB: K \times NB: K \times N 的输入矩阵

C:M×NC: M \times NC: M \times N 的输出矩阵

基础 GEMM 实现和度量

下面的代码应该都不陌生,矩阵乘算法是编程初学者经典的练习题之一。

1
2
3
4
5
6
7
8
9
10
void naive_row_major_sgemm(const float* A, const float* B, float* C, const int M,
const int N, const int K) {
for (int m = 0; m < M; ++m) {
for (int n = 0; n < N; ++n) {
for (int k = 0; k < K; ++k) {
C[m * N + n] += A[m * K + k] * B[k * N + n];
}
}
}
}

从矩阵乘的原理可知,矩阵乘算法的浮点运算量为 2×M×N×K2 \times M \times N \times K2 \times M \times N \times K,所以

GEMM:GFLOPs=2×M×N×Klatency×10−9GEMM : GFLOPs = \frac{2 \times M \times N \times K}{latency} \times 10^{-9} GEMM : GFLOPs = \frac{2 \times M \times N \times K}{latency} \times 10^{-9}

下面实现一个朴素的GFLOPs 计算函数,相应的代码均会在 GitHub 仓库中提供。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
void Benchmark(const std::vector<int64_t>& dims,
std::function<void(void)> func) {
const int warmup_times = 10;
const int infer_times = 20;

// warmup
for (int i = 0; i < warmup_times; ++i) func();

// run
auto dtime = dclock();
for (int i = 0; i < infer_times; ++i) func();

// latency
dtime = dclock() - dtime;

// compute GLOPs
auto flops = 2.0f * product(dims) * 1.0e-09;
flops = flops * infer_times / dtime;

// print
std::cout << std::setw(20) << " GFLOPs: " << flops << std::endl;
}

实测,naive_row_major_sgemm 的性能数据如下

Shape(M, N, K) GFLOPs
(64, 64, 64) 1.97
(128, 128, 128) 1.65
(256, 256, 256) 1.44
(512, 512, 512) 0.95
(1024, 1024, 1024) 0.62

测试数据来看,随着矩阵尺寸的增大,GFLOPs 在不断下降。从上文的分析中可知,单核CPU的理论峰值算力是80 GFLOPS,naive_row_major_sgemm 和理论峰值算力之间的差距非常大,完全没有发挥出CPU的算力。

naive_row_major_sgemm 性能极差的核心原因是在计算时发生了大量的cache miss

img

图3 基础 GEMM 实现示例

一行代码优化十倍性能

在分析清楚 naive_row_major_sgemm 性能极差的主要原因后,我们通过循环重排来优化访存。注意,naive_row_major_sgemm 和 optimize_row_major_sgemm 虽然只有一行代码的差距,但是性能却相差近十倍!

1
2
3
4
5
6
7
8
9
10
void optimize_row_major_sgemm(const float* A, const float* B, float* C, const int M,
const int N, const int K) {
for (int m = 0; m < M; ++m) {
for (int k = 0; k < K; ++k) {
for (int n = 0; n < N; ++n) {
C[m * N + n] += A[m * K + k] * B[k * N + n];
}
}
}
}
Shape(M, N, K) naive GFLOPs optimize GFLOps
(64, 64, 64) 1.97 11.20
(128, 128, 128) 1.65 11.84
(256, 256, 256) 1.44 12.04
(512, 512, 512) 0.95 11.43
(1024, 1024, 1024) 0.62 10.79

根据上表中的数据,可以直接体会到性能优化的魔力。一行代码,十倍加速。

img

图4 优化访存后的 GEMM 实现示例

BLAS 接口简介

截止到目前为止,已经具有 naive_row_major_sgemmoptimize_row_major_sgemm 两份实现,虽然optimize_row_major_sgemm 在性能上有一定的优化,但距离真正的高性能计算库的要求还相差甚远。

即使抛开性能问题不谈,目前 optimize_row_major_sgemm 也很难视为一个合格的库函数,因为该函数在接口定义上太过随意,别人很难直接复用。众所周知,矩阵乘优化已经是非常成熟的课题了,其中自然衍生了许多标准,以方便不同开发者或者研究人员之间工作的交流和复用,其中最基础的便是 BLAS接口规范。

BLAS(basic linear algebra subroutine)是一系列基本线性代数运算函数1接口(interface)标准。 这里的线性代数运算是指例如矢量的线性组合,矩阵乘以矢量,矩阵乘以矩阵等。接口在这里指的是诸如哪个函数名实现什么功能,有几个输入和输出变量,分别是什么。

注意 BLAS 是一个接口的标准而不是某种具体实现(implementation)。简单来说,就是不同的作者可以各自写出不同版本的 BLAS 库,实现同样的接口和功能,但每个函数内部的算法可以不同。 这些不同导致了不同版本的 BLAS 在不同机器上运行的速度也不同。

C:=alpha×A×B+beta×CC := alpha \times A \times B + beta \times CC := alpha \times A \times B + beta \times C

  • A, 形状为(M, K)的列主序矩阵
  • B, 形状为(M, K)的列主序矩阵
  • C, 形状为(M, K)的列主序矩阵
1
2
void sgemm(char transa, char transb, int M, int N, int K, float alpha, 
const float* A, int lda, const float* B, int ldb, float beta, float* C, int ldc);
  • transa, 设置矩阵A是否转置的标识位,’N’ 表示不转置, ‘T’ 表示转置;
  • transb, 设置矩阵A是否转置的标识位,’N’ 表示不转置, ‘T’ 表示转置;
  • M, M 维度的值;
  • N, N 维度的值;
  • K, K 维度的值;
  • alpha, 系数;
  • A, A 矩阵指针;
  • lda, A矩阵 leading dimension的值;
  • B, B 矩阵指针;
  • ldb, B矩阵 leading dimension的值;
  • beta, 系数;
  • C, 结果矩阵C矩阵指针;
  • ldc, C矩阵 leading dimension的值;

注: leading dimension,对于一个 MxN 的行优先矩阵,leading dimension 为 N;对于一个 MxN 的列优先矩阵,leading dimension 为 M。

介绍完 BLAS 接口之后,我们以 BLAS 接口的格式编写一份 列优先的矩阵乘实现 作为后续优化工作的比较基准。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
void naive_col_major_sgemm(
char transa,
char transb,
int M, int N, int K,
const float alpha,
const float * src_a, int lda,
const float * src_b, int ldb,
const float beta,
float * dst, int ldc)
{
int a_stride_m = transa == 'n' ? 1 : lda;
int a_stride_k = transa == 'n' ? lda : 1;
int b_stride_k = transb == 'n' ? 1 : ldb;
int b_stride_n = transb == 'n' ? ldb : 1;

for(int m=0;m<M;m++) {
for(int n=0;n<N;n++) {
float acc = 0.f;
const float * a_ptr = src_a + m * a_stride_m;
const float * b_ptr = src_b + n * b_stride_n;

for(int k=0;k<K;k++) {
acc += a_ptr[0] * b_ptr[0];
a_ptr += a_stride_k;
b_ptr += b_stride_k;
}

dst[m + n * ldc ] = alpha * acc + beta * dst[m + n * ldc];
}
}
}

深度优化矩阵乘实现

从本节起,开始演示如何优化矩阵乘算法,以达到 80% 以上的硬件性能利用率。

一般而言,矩阵乘优化有以下技巧,在GEMM、GEMV的实现中都可以去套用。

  1. 循环重排;
  2. 数据分块;
  3. 数组打包;
  4. 向量指令集;
  5. 寄存器优化;
  6. 多线程;

基础函数乘实现和优化一节中得知,矩阵乘实现性能差的原因在与数据 cache miss 率很高,因此我们进行的一个优化就是数据打包。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
void avx2_col_major_sgemm(char transa, char transb,int M, int N, int K, float alpha, float* A, int lda,
float* B, int ldb, float beta, float* C, int ldc) {
if (alpha == 0) return;

float beta_div_alpha = beta / alpha;

constexpr int Mr = 64;
constexpr int Kr = 256;

constexpr int mr = 16;
constexpr int nr = 6;

// Cache a is 64 x 256
float* pack_a = (float*)_mm_malloc(Mr * Kr * sizeof(float), 32);
// Cache b is 256 x N
float* pack_b = (float*)_mm_malloc(Kr * DivUp(N, nr) * sizeof(float), 32);

float* tmp_pack_a = pack_a;
float* tmp_pack_b = pack_b;

for (int k = 0; k < K; k += Kr) {
float cur_beta = 1.0 / alpha;
if (k == 0) cur_beta = beta_div_alpha;

int cur_k = std::min(K - k, Kr);

// jump to k-th row of matrix B
pack_no_trans(B + k, ldb, tmp_pack_b, Kr, cur_k, N);

for (int i = 0; i < M; i += Mr) {
int cur_m = std::min(M - i, Mr);

pack_trans(A + i + k * lda, lda, tmp_pack_a, Kr, cur_k, cur_m);

for (int j = 0; j < N;) {
int cur_n = std::min(int(N - j), nr);
float* cur_c = C + i + j * ldc;

float* packed_cur_b = tmp_pack_b + DivDown(j, nr) * Kr + j % nr;

sgemm_block_n(cur_m, cur_n, cur_k, alpha, tmp_pack_a, lda, packed_cur_b,
ldb, cur_beta, cur_c, ldc);
j += cur_n;
}
}
}

_mm_free(pack_a);
_mm_free(pack_b);
}

在后文的讲解中,为方便起见,统一设置 M = N = K = 512 为例,来演示矩阵乘优化。

数据打包

从系统信息上看,L1 数据缓存和指令缓存均为 32 K,32K 的 L1d cache 可以容纳 32 * 1024 / 4 = 8192 个单精度浮点数。因此,当 M, N, K 足够大的时候,L1d cache 无法持有三个矩阵所有的数据,便会发生cache miss,这也解释了上文中为什么矩阵越大、性能越差。

1
2
3
4
L1d cache:           32K
L1i cache: 32K
L2 cache: 4096K
L3 cache: 36608K

avx2_col_major_sgemm 的实现代码中,为矩阵A 开辟了 64 x 256 x 4 bytes / 1024 = 64 K 的存储区域,为矩阵B 开辟了 256 x Divp(N=512,6 ) = 256 x 516 x 4 bytes / 1024 = 516 K 的存储区域,目的是防止矩阵A和矩阵B过大,以至于在L2 cache 中发生cache miss 的情况,所以一次只在L2中加载矩阵A和矩阵B的子矩阵,保证不会发生cache miss。

1
2
3
4
5
6
7
8
9
constexpr int Mr = 64;
constexpr int Kr = 256;

...

// Cache a is 64 x 256
float* pack_a = (float*)_mm_malloc(Mr * Kr * sizeof(float), 32);
// Cache b is 256 x N
float* pack_b = (float*)_mm_malloc(Kr * DivUp(N, nr) * sizeof(float), 32);

矩阵乘实现从计算方法上来区分,可以分为 Inner Product 和 Outer Product 两种计算方法,解释如下

  1. Inner Product: 按行切分A矩阵,按列切分B矩阵,使用A矩阵的一个按行切分的子块同B矩阵按列切分的子块做矩阵乘法,即求得结果矩阵C矩阵的一个子矩阵。依次循环,求得最终结果。

img

图5 矩阵分块运算(inner product)

\2. Outer Product: 按列切分A矩阵,按行切分B矩阵,使用A矩阵的一个按列切分的子块同B矩阵按行切分的子块做矩阵乘法,求得一个形状同C矩阵相同的中间结果矩阵。依次循环,对所有的中间结果矩阵求和,可得最终结果。下图中,将A、B矩阵切分为4个子矩阵,然后进行4次矩阵乘,再对 C1、C2、C3 和 C4 进行求和,可以算出最终结果。

img

图6 矩阵分块运算示例(outer product)

avx2_col_major_sgemm 的实现代码中,按照如下方式对矩阵A(512 x 512)、B(512 x 512)进行切分计算,具体步骤如下:

  1. 将矩阵A(512 x 512)切分为 2 x 8 = 16 个形状为 64 x 256 的子矩阵;
  2. 将矩阵B(512 x 512)切分为 2 个形状为 256 x 512 的子矩阵;
  3. 对矩阵B的1号子矩阵进行数据打包, 然后对矩阵A的1号子矩阵进行数据打包,对TMPA1(64x256)和 TMPB1(256 x 512)进行一次矩阵乘运算,求得图中的 c11 (64 x 512) ; 在对矩阵A的2号子矩阵进行数据打包,求得c12;依次循环,直到求得 c18;
  4. 对矩阵B的2号子矩阵进行数据打包, 然后对矩阵A的9号子矩阵进行数据打包,对TMPA1(64x256)和 TMPB1(256 x 512)进行一次矩阵乘运算,求得图中的 c21 (64 x 512) ; 在对矩阵A的10号子矩阵进行数据打包,求得c22;依次循环,直到求得 c28;
  5. 对中间结果矩阵进行求和,可得最终的结果矩阵 C。

img

图7 M=N=K=512 矩阵的切分计算示例

通过上面的介绍,相信读者已经对如何进行矩阵分块有了清晰的认识,其实矩阵分块的思想很简单,就是将原始输入矩阵切分为小矩阵,使得L2 cache可以容纳计算所需的小矩阵。

现在已经粗粒度的讲解了如何对矩阵A和矩阵B进行分块计算,那么矩阵A的子矩阵(64 x 256)和 矩阵B的子矩阵(256 x 512)是如何计算的呢?

  1. 在数据打包时,将子矩阵 a(64 x 256)按行进行切分,分为 4 个形状为 16 x 256 的小矩阵;
  2. 在数据打包时,将子矩阵 b(256 x 512)按列进行切分,分为 86 个形状为 256 x 6 的小矩阵;当子矩阵 b 的列数不是 6 的整数倍时,需在数据打包时,进行 padding。
  3. 使用子矩阵 a 的1号子矩阵(16 x 256)依次和子矩阵 b 的86个子矩阵进行矩阵乘计算,计算结果为 (16 x 256)X (256 x 6)= (16 x 6);最终可得(16 x 6)x 86 个子矩阵;
  4. 依此遍历子矩阵 a 的1、2、3、4号子矩阵进行步骤3中的运算;

img

图8 左矩阵(64x256)和右矩阵(256x512)的计算

上文的描述中,详细介绍如何对矩阵A的子矩阵(64 x 256)和矩阵B的子矩阵(256 x 512)进行计算,后面会结合代码对如何使用SIMD指令进行数据打包的细节演示。

矩阵A的数据打包

img

图9 矩阵A的数据打包

代码实现,暂时没进行深入讲解,比较好理解。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
//  pack block_size on leading dimension, t denotes transpose.
// eg. input: A MxN matrix in row major, so the storage-format is (M, N)
// output: B MxN matrix in col major(N-packed), so the storage-format is
// (divUp(N, 16), M, 16)
void pack_trans(float* a, int lda, float* b, int ldb, int m, int n) {
constexpr int block_size = 16;
int i = 0;

for (; i + 64 <= n; i += 64) {
float* cur_a = a + i;
float* cur_b = b + i * ldb;
pack_trans_4x16(cur_a, lda, cur_b, ldb, m, block_size);
}
}

void pack_trans_4x16(float* a, const int lda, float* b, const int ldb, int m, int n) {
const int m4 = m / 4;
const int m1 = m % 1;
const int block_size = 64;
const int ldbx16 = ldb * 16; //(256 * 64)

float* tmpa = a;
// 保存指针 A 的4列元素
float* ar0 = tmpa + 0 * lda;
float* ar1 = tmpa + 1 * lda;
float* ar2 = tmpa + 2 * lda;
float* ar3 = tmpa + 3 * lda;

float* tmpb = b;
float* br0 = tmpb + 0 * ldbx16;
float* br1 = tmpb + 1 * ldbx16;
float* br2 = tmpb + 2 * ldbx16;
float* br3 = tmpb + 3 * ldbx16;

// 循环 256 / 4 = 64 次,每次 pack 4 x 16 = 64 个数据
for (int i = 0; i < m4; ++i) {
{
__m256 v00 = _mm256_loadu_ps(ar0);
__m256 v01 = _mm256_loadu_ps(ar0 + 8);
__m256 v10 = _mm256_loadu_ps(ar1);
__m256 v11 = _mm256_loadu_ps(ar1 + 8);
__m256 v20 = _mm256_loadu_ps(ar2);
__m256 v21 = _mm256_loadu_ps(ar2 + 8);
__m256 v30 = _mm256_loadu_ps(ar3);
__m256 v31 = _mm256_loadu_ps(ar3 + 8);

_mm256_storeu_ps(br0 + 0, v00);
_mm256_storeu_ps(br0 + 8, v01);
_mm256_storeu_ps(br0 + 16, v10);
_mm256_storeu_ps(br0 + 24, v11);
_mm256_storeu_ps(br0 + 32, v20);
_mm256_storeu_ps(br0 + 40, v21);
_mm256_storeu_ps(br0 + 48, v30);
_mm256_storeu_ps(br0 + 56, v31);
}
{
__m256 v00 = _mm256_loadu_ps(ar0 + 16);
__m256 v01 = _mm256_loadu_ps(ar0 + 24);
__m256 v10 = _mm256_loadu_ps(ar1 + 16);
__m256 v11 = _mm256_loadu_ps(ar1 + 24);
__m256 v20 = _mm256_loadu_ps(ar2 + 16);
__m256 v21 = _mm256_loadu_ps(ar2 + 24);
__m256 v30 = _mm256_loadu_ps(ar3 + 16);
__m256 v31 = _mm256_loadu_ps(ar3 + 24);

_mm256_storeu_ps(br1 + 0, v00);
_mm256_storeu_ps(br1 + 8, v01);
_mm256_storeu_ps(br1 + 16, v10);
_mm256_storeu_ps(br1 + 24, v11);
_mm256_storeu_ps(br1 + 32, v20);
_mm256_storeu_ps(br1 + 40, v21);
_mm256_storeu_ps(br1 + 48, v30);
_mm256_storeu_ps(br1 + 56, v31);
}

{
__m256 v00 = _mm256_loadu_ps(ar0 + 32);
__m256 v01 = _mm256_loadu_ps(ar0 + 40);
__m256 v10 = _mm256_loadu_ps(ar1 + 32);
__m256 v11 = _mm256_loadu_ps(ar1 + 40);
__m256 v20 = _mm256_loadu_ps(ar2 + 32);
__m256 v21 = _mm256_loadu_ps(ar2 + 40);
__m256 v30 = _mm256_loadu_ps(ar3 + 32);
__m256 v31 = _mm256_loadu_ps(ar3 + 40);

_mm256_storeu_ps(br2 + 0, v00);
_mm256_storeu_ps(br2 + 8, v01);
_mm256_storeu_ps(br2 + 16, v10);
_mm256_storeu_ps(br2 + 24, v11);
_mm256_storeu_ps(br2 + 32, v20);
_mm256_storeu_ps(br2 + 40, v21);
_mm256_storeu_ps(br2 + 48, v30);
_mm256_storeu_ps(br2 + 56, v31);
}

{
__m256 v00 = _mm256_loadu_ps(ar0 + 48);
__m256 v01 = _mm256_loadu_ps(ar0 + 56);
__m256 v10 = _mm256_loadu_ps(ar1 + 48);
__m256 v11 = _mm256_loadu_ps(ar1 + 56);
__m256 v20 = _mm256_loadu_ps(ar2 + 48);
__m256 v21 = _mm256_loadu_ps(ar2 + 56);
__m256 v30 = _mm256_loadu_ps(ar3 + 48);
__m256 v31 = _mm256_loadu_ps(ar3 + 56);

_mm256_storeu_ps(br3 + 0, v00);
_mm256_storeu_ps(br3 + 8, v01);
_mm256_storeu_ps(br3 + 16, v10);
_mm256_storeu_ps(br3 + 24, v11);
_mm256_storeu_ps(br3 + 32, v20);
_mm256_storeu_ps(br3 + 40, v21);
_mm256_storeu_ps(br3 + 48, v30);
_mm256_storeu_ps(br3 + 56, v31);
}

ar0 += 4 * lda;
ar1 += 4 * lda;
ar2 += 4 * lda;
ar3 += 4 * lda;

br0 += block_size;
br1 += block_size;
br2 += block_size;
br3 += block_size;
}
}

矩阵B的数据打包

img

图10 矩阵B的数据打包

代码实现,暂时没进行深入讲解,比较好理解。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
id pack_no_trans_n6(float* a, const int lda, float* b, const int ldb,
const int m, const int n) {
const int m8 = m / 8;
const int m1 = m % 8;
const int block_size = n;

float* tmpa = a;
float* tmpb = b;
float* a0 = tmpa + 0 * lda;
float* a1 = tmpa + 1 * lda;
float* a2 = tmpa + 2 * lda;
float* a3 = tmpa + 3 * lda;
float* a4 = tmpa + 4 * lda;
float* a5 = tmpa + 5 * lda;

for (int i = 0; i < m8; i++) {
__m256 v0 = _mm256_loadu_ps(a0);
__m256 v1 = _mm256_loadu_ps(a1);
__m256 v2 = _mm256_loadu_ps(a2);
__m256 v3 = _mm256_loadu_ps(a3);
__m256 v4 = _mm256_loadu_ps(a4);
__m256 v5 = _mm256_loadu_ps(a5);

__m256 unpack0 = _mm256_unpacklo_ps(v0, v1);
__m256 unpack1 = _mm256_unpackhi_ps(v0, v1);
__m256 unpack2 = _mm256_unpacklo_ps(v2, v3);
__m256 unpack3 = _mm256_unpackhi_ps(v2, v3);
__m256 unpack4 = _mm256_unpacklo_ps(v4, v5);
__m256 unpack5 = _mm256_unpackhi_ps(v4, v5);

__m256 shf0 = _mm256_shuffle_ps(unpack0, unpack2, 0x44);
__m256 shf1 = _mm256_shuffle_ps(unpack4, unpack0, 0xe4);
__m256 shf2 = _mm256_shuffle_ps(unpack2, unpack4, 0xee);
__m256 shf3 = _mm256_shuffle_ps(unpack5, unpack1, 0xe4);
__m256 shf4 = _mm256_shuffle_ps(unpack3, unpack5, 0xee);
__m256 shf5 = _mm256_shuffle_ps(unpack1, unpack3, 0x44);

__m128 low_shf1 = _mm256_castps256_ps128(shf1);
__m256 res0 = _mm256_insertf128_ps(shf0, low_shf1, 0x1);
__m256 res1 = _mm256_permute2f128_ps(shf0, shf1, 0x31);

__m128 low_shf5 = _mm256_castps256_ps128(shf5);
__m256 res2 = _mm256_insertf128_ps(shf2, low_shf5, 0x1);
__m256 res3 = _mm256_permute2f128_ps(shf2, shf5, 0x31);

__m128 low_shf4 = _mm256_castps256_ps128(shf4);
__m256 res4 = _mm256_insertf128_ps(shf3, low_shf4, 0x1);
__m256 res5 = _mm256_permute2f128_ps(shf3, shf4, 0x31);

constexpr int vsize_in_bytes = 8;
_mm256_storeu_ps(tmpb + 0 * vsize_in_bytes, res0);
_mm256_storeu_ps(tmpb + 1 * vsize_in_bytes, res2);
_mm256_storeu_ps(tmpb + 2 * vsize_in_bytes, res4);
_mm256_storeu_ps(tmpb + 3 * vsize_in_bytes, res1);
_mm256_storeu_ps(tmpb + 4 * vsize_in_bytes, res3);
_mm256_storeu_ps(tmpb + 5 * vsize_in_bytes, res5);

tmpb += 6 * vsize_in_bytes;

// jump to another 8 float point values
a0 += vsize_in_bytes;
a1 += vsize_in_bytes;
a2 += vsize_in_bytes;
a3 += vsize_in_bytes;
a4 += vsize_in_bytes;
a5 += vsize_in_bytes;
}
}

寄存器优化(Micro Kernel)

在数据打包的讲解中,有以下描述

使用子矩阵 a 的1号子矩阵(16 x 256)依次和子矩阵 b 的86个子矩阵进行矩阵乘计算,计算结果为 (16 x 256)X (256 x 6)= (16 x 6);最终可得(16 x 6)x 86 个子矩阵;

avx2_col_major_sgemm 的实现中,使用 A(16, 8) * B(8, 6) = C(16, 6) 的Micro Kernel,其计算思路如下,图片和下面的描述均来自一篇很好的文章 《OneDNN GEMM(AVX FP32)算法浅析》

img

图11 micro kernel 寄存器优化

Micro Kernel 的计算步骤如下描述

  1. 在 micro kernel 中,首先使用12个YMM寄存器用以保存结果矩阵 C(shape 为 16x6);
  2. 通过_mm256_loadu_ps指令将A矩阵的第一列移动到两个YMM寄存器中(这里假设为YMM0以及YMM1);
  3. 对于B矩阵第一行的第一个元素,使用_mm256_broadcast_ss指令进行广播并存储到一个YMM寄存器内(这里假设为YMM2),然后使用fma指令_mm256_fmadd_ps将YMM0和YMM1内的元素与YMM2内元素对应相乘,并将结果累加到C矩阵的两个YMM寄存器内,这里假设为YMM4以及YMM5;
  4. 沿着B矩阵第一行进行循环,重复步骤2,B矩阵广播当前行内其它数据时重复使用YMM2寄存器,并将计算结果依次累加到YMM6~YMM15寄存器内;
  5. A矩阵前进一列,B矩阵前进一行,并重复步骤1~3,最终完成整个C(16, 6)矩阵的计算。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
void col_major_micro_kernel_m16n6(const int K, const float alpha,
const float* src_a, const int lda,
const float* src_b, int ldb, const float beta,
float* dst_c, int ldc) {
constexpr int m_block_size = 16;
constexpr int n_block_size = 6;

// Load result matrix c (shape 16x6) into 12 x __m256 vector values
__m256 c00 = _mm256_loadu_ps(dst_c + 0 * ldc);
__m256 c01 = _mm256_loadu_ps(dst_c + 0 * ldc + 8);

__m256 c10 = _mm256_loadu_ps(dst_c + 1 * ldc);
__m256 c11 = _mm256_loadu_ps(dst_c + 1 * ldc + 8);

__m256 c20 = _mm256_loadu_ps(dst_c + 2 * ldc);
__m256 c21 = _mm256_loadu_ps(dst_c + 2 * ldc + 8);

__m256 c30 = _mm256_loadu_ps(dst_c + 3 * ldc);
__m256 c31 = _mm256_loadu_ps(dst_c + 3 * ldc + 8);

__m256 c40 = _mm256_loadu_ps(dst_c + 4 * ldc);
__m256 c41 = _mm256_loadu_ps(dst_c + 4 * ldc + 8);

__m256 c50 = _mm256_loadu_ps(dst_c + 5 * ldc);
__m256 c51 = _mm256_loadu_ps(dst_c + 5 * ldc + 8);

// c = c * beta
__m256 vbeta = _mm256_set1_ps(beta);

c00 = _mm256_mul_ps(c00, vbeta);
c01 = _mm256_mul_ps(c01, vbeta);

c10 = _mm256_mul_ps(c10, vbeta);
c11 = _mm256_mul_ps(c11, vbeta);

c20 = _mm256_mul_ps(c20, vbeta);
c21 = _mm256_mul_ps(c21, vbeta);

c30 = _mm256_mul_ps(c30, vbeta);
c31 = _mm256_mul_ps(c31, vbeta);

c40 = _mm256_mul_ps(c40, vbeta);
c41 = _mm256_mul_ps(c41, vbeta);

c50 = _mm256_mul_ps(c50, vbeta);
c51 = _mm256_mul_ps(c51, vbeta);


for (int k = 0; k < K; ++k) {
__m256 a0 = _mm256_loadu_ps(src_a);
__m256 a1 = _mm256_loadu_ps(src_a + 8);

__m256 vb = _mm256_broadcast_ss(src_b);
c00 = _mm256_fmadd_ps(a0, vb, c00);
c01 = _mm256_fmadd_ps(a1, vb, c01);

vb = _mm256_broadcast_ss(src_b + 1);
c10 = _mm256_fmadd_ps(a0, vb, c10);
c11 = _mm256_fmadd_ps(a1, vb, c11);

vb = _mm256_broadcast_ss(src_b + 2);
c20 = _mm256_fmadd_ps(a0, vb, c20);
c21 = _mm256_fmadd_ps(a1, vb, c21);

vb = _mm256_broadcast_ss(src_b + 3);
c30 = _mm256_fmadd_ps(a0, vb, c30);
c31 = _mm256_fmadd_ps(a1, vb, c31);

vb = _mm256_broadcast_ss(src_b + 4);
c40 = _mm256_fmadd_ps(a0, vb, c40);
c41 = _mm256_fmadd_ps(a1, vb, c41);

vb = _mm256_broadcast_ss(src_b + 5);
c50 = _mm256_fmadd_ps(a0, vb, c50);
c51 = _mm256_fmadd_ps(a1, vb, c51);

src_a += m_block_size;
src_b += n_block_size;
}

__m256 valpha = _mm256_set1_ps(alpha);
c00 = _mm256_mul_ps(c00, valpha);
c01 = _mm256_mul_ps(c01, valpha);

c10 = _mm256_mul_ps(c10, valpha);
c11 = _mm256_mul_ps(c11, valpha);

c20 = _mm256_mul_ps(c20, valpha);
c21 = _mm256_mul_ps(c21, valpha);

c30 = _mm256_mul_ps(c30, valpha);
c31 = _mm256_mul_ps(c31, valpha);

c40 = _mm256_mul_ps(c40, valpha);
c41 = _mm256_mul_ps(c41, valpha);

c50 = _mm256_mul_ps(c50, valpha);
c51 = _mm256_mul_ps(c51, valpha);

_mm256_storeu_ps(dst_c + 0 * ldc, c00);
_mm256_storeu_ps(dst_c + 0 * ldc + 8, c01);

_mm256_storeu_ps(dst_c + 1 * ldc, c10);
_mm256_storeu_ps(dst_c + 1 * ldc + 8, c11);

_mm256_storeu_ps(dst_c + 2 * ldc, c20);
_mm256_storeu_ps(dst_c + 2 * ldc + 8, c21);

_mm256_storeu_ps(dst_c + 3 * ldc, c30);
_mm256_storeu_ps(dst_c + 3 * ldc + 8, c31);

_mm256_storeu_ps(dst_c + 4 * ldc, c40);
_mm256_storeu_ps(dst_c + 4 * ldc + 8, c41);

_mm256_storeu_ps(dst_c + 5 * ldc, c50);
_mm256_storeu_ps(dst_c + 5 * ldc + 8, c51);
}

性能数据

在经历过漫长的讲解之后,那么 avx2_col_major_sgemm 的性能究竟如何呢?且看下表中的数据,表中使用了两个比较基准作为参照,分别是

  1. Naive, 最基础的矩阵乘算法实现,代码文中已经提供;
  2. oneDNN sgemm,oneDNN是英特尔公司大名鼎鼎的多平台支持、高性能计算库,其前身是 mkldnn。oneDNN 在各类硬件上都进行了深度优化,特别是在Intel CPU 上,其性能数据非常具备参考价值。
Shape(M, N, K) Naive GFLOPs oneDNN sgemm GFLOPs avx2_col_major_sgemm
(64, 64, 64) 1.96 32.97 35.42
(128, 128, 128) 1.65 62.69 40.36
(256, 256, 256) 1.44 73.19 65.84
(512, 512, 512) 0.95 70.06 67.65
(1024, 1024, 1024) 0.61 79.73 69.12

从数据上来看,avx2_col_major_sgemm 相较于 Naive 实现已经具备了质的飞跃,并且在多数shapes下可以取得接近 oneDNN 的性能,不过在 (128, 128, 128) 下和oneDNN 存在比较大的性能差异,这也说明 avx2_col_major_sgemm 仍然存在一定的优化空间。这很令人兴奋,不是么?