0x1. openai triton介绍阅读
这里来看官方的介绍:https://openai.com/research/triton ,从官方的介绍中我们可以看到openai triton的产生动机以及它的目标是什么,还可以看到一些经典算法的实现例子展示。
这里的标题是 introducing triton: open-source gpu programming for neural networks ,翻译就是《介绍 triton:用于神经网络的开源 gpu 编程语言》。然后下面的一句话翻译过来是:我们发布了 triton 1.0,这是一种开源的类 python 编程语言,它使得没有 cuda 经验的研究人员能够编写高效的 gpu 代码——大多数情况下,其效能与专家所能编写的代码相当。这里指出了triton的目的,就是让编写cuda kernrl变得更简单。接下来就逐步看一下介绍里的具体内容,为了更加准确这里会截图对应的原文然后放上我的翻译或者理解。
这里的意思是triton可以使得用户用较少的努力就写出一个达到硬件峰值性能的kernel,比如使用 triton 可以编写 fp16 矩阵乘法的核函数,其性能能够匹配 cublas,并且这个代码不超过25行。然后研究者已经用triton开发了一些高效的实现,和功能相同的torch实现相比,性能可以达到两倍提升。后面一段就是强调了使用cuda来把一些原始的pytorch实现写一个算子一般会更加高效,但是这个难度不小,并且目前已有工作也不能很好覆盖这种情况,所以openai triton诞生。
这里讲的是gpu编程的挑战,现代 gpu 的架构大致可以分为三个主要部分——dram、sram 和 alu。在优化 cuda 代码时,必须考虑到这些组件:
从 dram 的内存传输必须合并成大型事务,以利用现代内存接口的大总线宽度(内存合并访问)。
数据必须在重复使用前手动存储到 sram 中,并进行管理来最小化bank conflict。
计算必须仔细地进行划分和调度,不仅是在流式多处理器(sms)之间,还包括在其内部,以促进指令/线程级并行性,并利用专用的 alu(例如,tensor cores)。
考虑所有这些因素可能对于拥有多年经验的资深 cuda 程序员来说都是一个挑战。triton 的目的是完全自动化这些优化,以便开发者能够更好地专注于他们并行代码的高层逻辑。triton 旨在广泛适用,因此不会自动在流式多处理器(sms)之间调度工作——留下一些重要的算法考虑(例如,tiling,跨 sm 同步)由开发者自行决定。
然后给了一个表格展示cuda的编译器和triton的区别。
在所有可用的领域特定语言和即时编译器中,triton可能和numba最相似:kernel被定义为一个装饰过的函数,并以不同的 program_id 并行启动在所谓的网格实例上。然而,正如下面的代码片段所示,相似之处仅此而已:triton 通过对块上的操作来暴露实例内部的并行性——这些小数组的尺寸是二的幂次方——而不是单指令多线程(simt)执行模型。这样做,triton 有效地抽象出了所有与 cuda 线程块内部并发相关的问题(例如,内存合并、共享内存同步/冲突、tensor cores调度)。
注意,triton 的即时编译器将 x 和 y 视为指针而不是张量;我们认为保留对内存访问的低级控制对于处理更复杂的数据结构(例如,块稀疏张量)是重要的。重要的是,这种特定的 softmax 实现在整个标准化过程中将 x 的行保留在 sram 中,这在适用时最大化了数据重用(约 = 0, acc, alpha * acc) # write back result c = c + (rm[:, none] * stride_cm + rn[none, :] * stride_cn) mask = (rm[:, none] < m) & (rn[none, :] < n) tl.store(c, acc, mask=mask)
手写矩阵乘法kernel的一个重要优势是,它们可以根据需要定制,以适应输入(例如,切片)和输出(例如,leakyrelu)的融合转换。如果没有像 triton 这样的系统,没有出色的 gpu 编程专长的开发者将无法进行矩阵乘法内核的定制修改。
这里是说triton 的良好性能源于一个以 triton-ir 为中心的模块化系统架构,triton-ir 是一个基于 llvm 的中间表示,在这个系统中,多维值块(这个是mlir的概念)是一等公民。gpt
@triton.jit 装饰器的工作原理是遍历提供的 python 函数的抽象语法树(ast),以便使用常见的 ssa 构建算法即时生成 triton-ir。然后,编译器后端会简化、优化并自动并行化所产生的 ir 代码,再将其转换为高质量的 llvm-ir —— 最终生成 ptx —— 以在近期的 nvidia gpu 上执行。目前不支持 cpu 和 amd gpu,但我们欢迎社区贡献,旨在解决这一限制。
我们发现,通过 triton-ir 使用块级别程序表示,使我们的编译器能够自动执行各种重要的程序优化。例如,可以通过观察计算密集型块级操作(例如,tl.dot)的操作数,自动将数据暂存到共享内存中,并使用标准的活性分析技术进行分配和同步。
另一方面,如下所示,triton 程序可以高效且自动地并行化,既可以(1)通过并发执行不同的kernel实例在流式多处理器(sms)间并行,也可以(2)通过分析每个块级操作的迭代空间,并在不同的 simd 单元间适当分配,从而在 sms 内部并行。
0x2. 教程1 vector addition阅读
意思是这一节教程会介绍triton编程模型定义kernel的基本写法,此外也会介绍一下怎么实现一个良好的benchmark测试。下面来看计算kernel实现,我把注释改成中文了:
import torchimport tritonimport triton.language as tl@triton.jitdef add_kernel(x_ptr, # *指针*,指向第一个输入向量。 y_ptr, # *指针*,指向第二个输入向量。 output_ptr, # *指针*,指向输出向量。 n_elements, # 向量的大小。 block_size: tl.constexpr, # 每个程序应处理的元素数量。 # 注意:`constexpr`这样可以被用作形状值。 ): # 这里有多个“程序”处理不同的数据。我们在这里识别我们是哪一个程序: pid = tl.program_id(axis=0) # 我们使用一维启动网格,所以轴是0。 # 该程序将处理从初始数据偏移的输入。 # 例如,如果你有一个长度为256的向量和块大小为64,那么程序 # 将分别访问元素[0:64, 64:128, 128:192, 192:256]。 # 注意偏移量是一个指针列表: block_start = pid * block_size offsets = block_start + tl.arange(0, block_size) # 创建一个掩码以防止内存操作越界访问。 mask = offsets tuple[int]。 # 在这种情况下,我们使用一个1d网格,其大小是块的数量: grid = lambda meta: (triton.cdiv(n_elements, meta['block_size']), ) # 注意: # - 每个torch.tensor对象都隐式地转换为指向其第一个元素的指针。 # - 使用`triton.jit`装饰的函数可以用一个启动网格索引来获得可调用的gpu内核。 # - 不要忘记将元参数作为关键字参数传递。 add_kernel[grid](x, y, output, n_elements, block_size=1024) # 我们返回一个指向z的句柄,但是因为`torch.cuda.synchronize()`还没有被调用,所以这时kernel仍然 # 在异步运行。 return output
我们现在可以使用上面定义的函数来计算两个torch.tensor对象的逐元素求和,并测试其正确性:
torch.manual_seed(0)size = 98432x = torch.rand(size, device='cuda')y = torch.rand(size, device='cuda')output_torch = x + youtput_triton = add(x, y)print(output_torch)print(output_triton)print(f'the maximum difference between torch and triton is ' f'{torch.max(torch.abs(output_torch - output_triton))}')
输出:
tensor([1.3713, 1.3076, 0.4940, ..., 0.6724, 1.2141, 0.9733], device='cuda:0')tensor([1.3713, 1.3076, 0.4940, ..., 0.6724, 1.2141, 0.9733], device='cuda:0')the maximum difference between torch and triton is 0.0
我们可以对不同大小的向量进行自定义操作的性能基准测试,以了解它相对于pytorch的表现如何。为了简化操作,triton提供了一系列内置工具,使我们能够简洁地绘制出自定义操作在不同问题规模下的性能图表。
@triton.testing.perf_report( triton.testing.benchmark( x_names=['size'], # 用作绘图x轴的参数名。 x_vals=[2**i for i in range(12, 28, 1)], # `x_name`的不同可能值。 x_log=true, # x轴是对数的。 line_arg='provider', # 其值对应于图中不同线条的参数名。 line_vals=['triton', 'torch'], # `line_arg`的可能值。 line_names=['triton', 'torch'], # 线条的标签名称。 styles=[('blue', '-'), ('green', '-')], # 线条样式。 ylabel='gb/s', # y轴的标签名称。 plot_name='vector-add-performance', # 绘图的名称。也用作保存绘图的文件名。 args={}, # 不在`x_names`和`y_name`中的函数参数的值。 ))def benchmark(size, provider): x = torch.rand(size, device='cuda', dtype=torch.float32) y = torch.rand(size, device='cuda', dtype=torch.float32) quantiles = [0.5, 0.2, 0.8] if provider == 'torch': ms, min_ms, max_ms = triton.testing.do_bench(lambda: x + y, quantiles=quantiles) if provider == 'triton': ms, min_ms, max_ms = triton.testing.do_bench(lambda: add(x, y), quantiles=quantiles) gbps = lambda ms: 12 * size / ms * 1e-6 return gbps(ms), gbps(max_ms), gbps(min_ms)
gbps = lambda ms: 12 * size / ms * 1e-6这里的12表示的是数据读写的bit,因为有x和y以及z的存在,所以是3*4=12bit。现在可以运行上面的装饰函数了。传递 print_data=true 参数来查看性能数据,传递 show_plots=true 参数来绘制图表,和/或传递 save_path='/path/to/results/' 参数来将它们连同原始csv数据一起保存到磁盘上:
benchmark.run(print_data=true, show_plots=true)
可以看到,对于elementwise任务,triton的性能几乎和pytorch持平,但是triton写起来很简单。
0x3. 教程2 fused softmax阅读
在这个教程中,我们将编写一个融合的softmax操作,这个操作对于特定类型的矩阵来说比pytorch的原生操作要快得多:那些行的大小可以放入gpu的sram中的矩阵。
通过这样做,我们将学习到:
kernel融合对于带宽受限操作的好处。
triton中的reduce操作符。
动机
自定义gpu kernel用于逐元素加法在教育上是有价值的,但在实际应用中可能作用有限。让我们考虑一个简单的(数值稳定的)softmax操作的情况:
import torchimport tritonimport triton.language as tl@torch.jit.scriptdef naive_softmax(x): 使用原生pytorch计算x的逐行softmax 我们减去最大元素是为了避免溢出。softmax对这种偏移是不变的。 # 读取 mn 个元素;写入 m 个元素 x_max = x.max(dim=1)[0] # 读取 mn + m 个元素;写入 mn 个元素 z = x - x_max[:, none] # 读取 mn 个元素;写入 mn 个元素 numerator = torch.exp(z) # 读取 mn 个元素;写入 m 个元素 denominator = numerator.sum(dim=1) # 读取 mn + m 个元素;写入 mn 个元素 ret = numerator / denominator[:, none] # 总计:读取 5mn + 2m 个元素;写入 3mn + 2m 个元素 return ret
计算kernel
我们的softmax kernel的工作方式如下:每个程序加载输入矩阵x的一行,对其进行归一化处理,然后将结果写回到输出y中。需要注意的是,triton的一个重要限制是每个块必须包含2的幂次方个元素,因此如果我们想处理任何可能的输入形状,我们需要在内部对每行进行“pad”以及对内存访问操作进行保护(也就是防止越界):
@triton.jitdef softmax_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride, n_cols, block_size: tl.constexpr): # softmax的各行是独立的,所以我们在这些行上进行并行处理 row_idx = tl.program_id(0) # 步长代表我们需要增加多少指针来前进1行 row_start_ptr = input_ptr + row_idx * input_row_stride # 块大小是大于n_cols的下一个2的幂次,因此我们可以将每一行放入单个块中 col_offsets = tl.arange(0, block_size) input_ptrs = row_start_ptr + col_offsets # 将行加载到sram中,使用掩码因为block_size可能大于n_cols row = tl.load(input_ptrs, mask=col_offsets < n_cols, other=-float('inf')) # 减去最大值以实现数值稳定性 row_minus_max = row - tl.max(row, axis=0) # 注意在triton中指数运算快但是近似的(即,类似于cuda中的__expf) numerator = tl.exp(row_minus_max) denominator = tl.sum(numerator, axis=0) softmax_output = numerator / denominator # 将输出写回dram output_row_start_ptr = output_ptr + row_idx * output_row_stride output_ptrs = output_row_start_ptr + col_offsets tl.store(output_ptrs, softmax_output, mask=col_offsets = 2048: num_warps = 8 if block_size >= 4096: num_warps = 16 # 分配输出 y = torch.empty_like(x) # 排队执行内核。一维启动网格很简单:我们有每行一个内核实例 # 输入矩阵 softmax_kernel[(n_rows, )]( y, x, x.stride(0), y.stride(0), n_cols, num_warps=num_warps, block_size=block_size, ) return y
这里是验证triton实现的fuse softmax和pytorch的naive实现等价,显然他们是等价的。
benchmark
这里设定矩阵的行数为固定的4096来做benchmark。
@triton.testing.perf_report( triton.testing.benchmark( x_names=['n'], # 用作绘图x轴的参数名 x_vals=[128 * i for i in range(2, 100)], # `x_name`的不同可能值 line_arg='provider', # 其值对应于图中不同线条的参数名 line_vals=[ 'triton', 'torch-native', 'torch-jit', ], # `line_arg`的可能值 line_names=[ triton, torch (原生), torch (jit), ], # 线条的标签名称 styles=[('blue', '-'), ('green', '-'), ('green', '--')], # 线条样式 ylabel=gb/s, # y轴的标签名称 plot_name=softmax-performance, # 绘图的名称。也用作保存绘图的文件名。 args={'m': 4096}, # 不在`x_names`和`y_name`中的函数参数的值 ))def benchmark(m, n, provider): x = torch.randn(m, n, device='cuda', dtype=torch.float32) quantiles = [0.5, 0.2, 0.8] if provider == 'torch-native': ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.softmax(x, axis=-1), quantiles=quantiles) if provider == 'triton': ms, min_ms, max_ms = triton.testing.do_bench(lambda: softmax(x), quantiles=quantiles) if provider == 'torch-jit': ms, min_ms, max_ms = triton.testing.do_bench(lambda: naive_softmax(x), quantiles=quantiles) gbps = lambda ms: 2 * x.nelement() * x.element_size() * 1e-9 / (ms * 1e-3) return gbps(ms), gbps(max_ms), gbps(min_ms)benchmark.run(show_plots=true, print_data=true)
这里提到虽然triton实现的softmax性能更好并且易于理解和维护,但pytorch的torch.softmax则更加通用。
0x4. 教程3 matrix multiply阅读
首先教程指出这里就是要写一个block级别的矩阵乘法,然后这里会涉及到多维度的指针操作,程序重排以更好的命中l2 cache以及自动调优。
动机
矩阵乘法是大多数现代高性能计算系统的关键构建块。它们众所周知难以优化,因此它们的实现通常由硬件供应商自己作为所谓的“内核库”(例如,cublas)的一部分来完成。不幸的是,这些库通常是专有的,无法轻易地定制以适应现代深度学习工作负载的需求(例如,融合激活函数)。在这个教程中,你将学习如何使用triton自己实现高效的矩阵乘法,这种方法易于定制和扩展。
大致来说,我们将要编写的内核将实现以下块级算法来乘以一个 (m, k) 矩阵和一个 (k, n) 矩阵:
# do in parallelfor m in range(0, m, block_size_m): # do in parallel for n in range(0, n, block_size_n): acc = zeros((block_size_m, block_size_n), dtype=float32) for k in range(0, k, block_size_k): a = a[m : m+block_size_m, k : k+block_size_k] b = b[k : k+block_size_k, n : n+block_size_n] acc += dot(a, b) c[m : m+block_size_m, n : n+block_size_n] = acc
其中,双重嵌套的for循环的每次迭代都由一个专用的triton program实例执行。
计算kernel
上述算法实际上在triton中相当容易实现。主要的难点来自于在内循环中计算必须读取a和b块的内存位置。为此,我们需要多维指针运算。
指针运算
对于一个2d tensor x,x[i, j]的内存位置为&x[i, j] = x + i*stride_xi + j*stride_xj。因此,对于a[m : m+block_size_m, k:k+block_size_k]和b[k : k+block_size_k, n : n+block_size_n]的块指针可以用下面的伪代码定义:
&a[m : m+block_size_m, k:k+block_size_k] = a_ptr + (m : m+block_size_m)[:, none]*a.stride(0) + (k : k+block_size_k)[none, :]*a.stride(1);&b[k : k+block_size_k, n:n+block_size_n] = b_ptr + (k : k+block_size_k)[:, none]*b.stride(0) + (n : n+block_size_n)[none, :]*b.stride(1);
这意味着a和b块的指针可以在triton中初始化,比如 k=0 如下代码所示。另外注意,我们需要一个额外的模运算来处理m不是block_size_m的倍数或n不是block_size_n的倍数的情况,在这种情况下,我们可以用一些无用的值填充数据,这些值不会对结果产生影响。对于k维度,我们稍后将使用掩码加载语义来处理。
offs_am = (pid_m * block_size_m + tl.arange(0, block_size_m)) % moffs_bn = (pid_n * block_size_n + tl.arange(0, block_size_n)) % noffs_k = tl.arange(0, block_size_k)a_ptrs = a_ptr + (offs_am[:, none]*stride_am + offs_k [none, :]*stride_ak)b_ptrs = b_ptr + (offs_k [:, none]*stride_bk + offs_bn[none, :]*stride_bn)
然后在内循环中按如下方式更新:
a_ptrs += block_size_k * stride_ak;b_ptrs += block_size_k * stride_bk;
如上所述,每个program实例计算一个 [block_size_m, block_size_n] 大小的c矩阵块。重要的是要记住,这些块的计算顺序是很重要的,因为它会影响我们程序的l2缓存命中率,不幸的是,一个简单的行优先顺序是不够的。
pid = triton.program_id(0);grid_m = (m + block_size_m - 1) // block_size_m;grid_n = (n + block_size_n - 1) // block_size_n;pid_m = pid / grid_n;pid_n = pid % grid_n;
l2 cache优化
如上所述,每个程序实例计算一个 [block_size_m, block_size_n] 大小的c矩阵块。重要的是要记住,这些块的计算顺序很重要,因为它会影响我们程序的l2缓存命中率,不幸的是,一个简单的行主序排序是不够的。
一个可能的解决方案是以一种促进数据重用的顺序启动块。这可以通过在切换到下一列之前将块在group_m行的super group中分组来实现:
# 程序idpid = tl.program_id(axis=0)# 沿m轴的程序id数量num_pid_m = tl.cdiv(m, block_size_m)# 沿n轴的程序id数量num_pid_n = tl.cdiv(n, block_size_n)# 组中的程序数量num_pid_in_group = group_size_m * num_pid_n# 该程序所在组的idgroup_id = pid // num_pid_in_group# 组中第一个程序的行idfirst_pid_m = group_id * group_size_m# 如果`num_pid_m`不能被`group_size_m`整除,最后一个组更小group_size_m = min(num_pid_m - first_pid_m, group_size_m)# *在组内*,程序按列主序排列# 程序在*启动网格*中的行idpid_m = first_pid_m + (pid % group_size_m)# 程序在*启动网格*中的列idpid_n = (pid % num_pid_in_group) // group_size_m
例如,在下面的矩阵乘法中,每个矩阵由9个块乘以9个块组成,我们可以看到,如果我们按行主序计算输出,我们需要将90个块加载到sram中以计算前9个输出块,但如果我们按grouped ordering进行计算,我们只需要加载54个块。
在实际应用中,这可以在某些硬件架构上提高我们矩阵乘法内核的性能超过10%(例如,在a100上从220提升到245 tflops)。
l2 cache优化原理补充讲解
上面的group oredering的访问代码比较难理解,这里来更详细的解析一下。
# 程序idpid = tl.program_id(axis=0)# 沿m轴的程序id数量num_pid_m = tl.cdiv(m, block_size_m)# 沿n轴的程序id数量num_pid_n = tl.cdiv(n, block_size_n)
这里的num_pid_m和num_pid_n就是求分别要在m和n方向循环多少次。
然后上面图中的黑色数字其实就可以理解为program id,我们可以看到program id增加的方向其实就代表了遍历的ordering,对于row major来说就是在行方向上顺序遍历,而对于group ordering来说就是按照一个block_size_m*block_size_n这么大的一个小组来遍历。其实这段代码就是完成group ordering的遍历:
num_pid_in_group = group_size_m * num_pid_ngroup_id = pid // num_pid_in_groupfirst_pid_m = group_id * group_size_mgroup_size_m = min(num_pid_m - first_pid_m, group_size_m)pid_m = first_pid_m + (pid % group_size_m)pid_n = (pid % num_pid_in_group) // group_size_m
以上面图来看,num_pid_m=3,num_pid_n=3,num_pid_in_group=group_id * group_size_m=9*3=27,也就是下面的红色框里面的program个数,从名字也可以看出来这个红色框划分的区域也是一个group。
group_id 就表示当前的这次 循环, 是在第几个红色框里,以program 0为例,这里为group_id = pid // num_pid_in_group=0//27=0。而first_pid_m 代表当前 group 中的第一个黄色program在全局的m维度上是第几个program ,这里为first_pid_m = group_id * group_size_m=0,group_size_m = min(num_pid_m - first_pid_m, group_size_m)这里是考虑到最后一个group可能占不满数据(存在padding),所以就做一个截断处理。
pid_m = first_pid_m + (pid % group_size_m)pid_n = (pid % num_pid_in_group) // group_size_m
这两行代码计算当前的program处理的黄色小块坐标([pid_m, pid_n]),pid_m这行是在行方向上移动,pid_n这行则是保证在上面的红色框里面一定是一列一列来访问的。
作为对比,在row-major的方法中,访问方式应该是这样的:
pid_m = pid // num_pid_npid_n = pid % num_pid_n
计算最后的结果
有了上面的铺垫,我们就可以计算最终的结果了,下面的代码展示了完整的triton 矩阵乘法kernel实现。
# 使用`triton.jit`装饰的函数可以通过`triton.autotune`装饰器进行自动调优,该装饰器包括:# - 一系列定义不同配置的`triton.config`对象,# 这些配置涉及元参数(例如`block_size_m`)和编译选项(例如`num_warps`)的不同设置# - 一个自动调优*关键字*,其值的变化将触发对所有# 提供的配置的评估@triton.autotune( configs=[ # 每个config定义了一组特定的配置参数和编译选项 triton.config({'block_size_m': 128, 'block_size_n': 256, 'block_size_k': 64, 'group_size_m': 8}, num_stages=3, num_warps=8), triton.config({'block_size_m': 64, 'block_size_n': 256, 'block_size_k': 32, 'group_size_m': 8}, num_stages=4, num_warps=4), triton.config({'block_size_m': 128, 'block_size_n': 128, 'block_size_k': 32, 'group_size_m': 8}, num_stages=4, num_warps=4), triton.config({'block_size_m': 128, 'block_size_n': 64, 'block_size_k': 32, 'group_size_m': 8}, num_stages=4, num_warps=4), triton.config({'block_size_m': 64, 'block_size_n': 128, 'block_size_k': 32, 'group_size_m': 8}, num_stages=4, num_warps=4), triton.config({'block_size_m': 128, 'block_size_n': 32, 'block_size_k': 32, 'group_size_m': 8}, num_stages=4, num_warps=4), triton.config({'block_size_m': 64, 'block_size_n': 32, 'block_size_k': 32, 'group_size_m': 8}, num_stages=5, num_warps=2), triton.config({'block_size_m': 32, 'block_size_n': 64, 'block_size_k': 32, 'group_size_m': 8}, num_stages=5, num_warps=2), ], key=['m', 'n', 'k'], # 自动调优关键字)@triton.jitdef matmul_kernel( # 指向矩阵的指针 a_ptr, b_ptr, c_ptr, # 矩阵维度 m, n, k, # 步长变量表示在特定维度上移动1个元素时指针增加的量。 # 例如`stride_am`是将`a_ptr`增加多少以获取下一行的元素(a有m行)。 stride_am, stride_ak, # a矩阵的步长 stride_bk, stride_bn, # b矩阵的步长 stride_cm, stride_cn,# c矩阵的步长 # 元参数 block_size_m: tl.constexpr, block_size_n: tl.constexpr, block_size_k: tl.constexpr, # group_size_m: tl.constexpr, # activation: tl.constexpr # 激活函数): 用于计算矩阵乘法c = a x b的内核。 a的形状为(m, k),b的形状为(k, n),c的形状为(m, n)。 # ----------------------------------------------------------- # 将程序id `pid`映射到它应该计算的c矩阵的块。 # 这是以grouped ordering完成的,以促进l2数据重用。 # 详细解释看一节 pid = tl.program_id(axis=0) num_pid_m = tl.cdiv(m, block_size_m) num_pid_n = tl.cdiv(n, block_size_n) num_pid_in_group = group_size_m * num_pid_n group_id = pid // num_pid_in_group first_pid_m = group_id * group_size_m group_size_m = min(num_pid_m - first_pid_m, group_size_m) pid_m = first_pid_m + (pid % group_size_m) pid_n = (pid % num_pid_in_group) // group_size_m # ---------------------------------------------------------- # 为a和b的第一个块创建指针。 # 我们将在k方向移动时推进这个指针并累加 # `a_ptrs`是[block_size_m, block_size_k]块的指针 # `b_ptrs`是[block_size_k, block_size_n]块的指针 # 有关详细信息,请参阅上方“指针算术”部分 offs_am = (pid_m * block_size_m + tl.arange(0, block_size_m)) % m offs_bn = (pid_n * block_size_n + tl.arange(0, block_size_n)) % n offs_k = tl.arange(0, block_size_k) a_ptrs = a_ptr + (offs_am[:, none] * stride_am + offs_k[none, :] * stride_ak) b_ptrs = b_ptr + (offs_k[:, none] * stride_bk + offs_bn[none, :] * stride_bn) # ----------------------------------------------------------- # 迭代以计算c矩阵的一个块。 # 我们将累加到一个`[block_size_m, block_size_n]`块 # 的fp32值以获得更高的精度。 # `accumulator`在循环后会转换回fp16。 accumulator = tl.zeros((block_size_m, block_size_n), dtype=tl.float32) for k in range(0, tl.cdiv(k, block_size_k)): # load the next block of a and b, generate a mask by checking the k dimension. # if it is out of bounds, set it to 0. a = tl.load(a_ptrs, mask=offs_k[none, :] < k - k * block_size_k, other=0.0) b = tl.load(b_ptrs, mask=offs_k[:, none] < k - k * block_size_k, other=0.0) # we accumulate along the k dimension. accumulator += tl.dot(a, b) # advance the ptrs to the next k block. a_ptrs += block_size_k * stride_ak b_ptrs += block_size_k * stride_bk # 当累加器仍然是fp32时,可以融合任意激活函数 if activation == leaky_relu: accumulator = leaky_relu(accumulator) c = accumulator.to(tl.float16) # ----------------------------------------------------------- # 使用掩码将输出矩阵c的块写回。 offs_cm = pid_m * block_size_m + tl.arange(0, block_size_m) offs_cn = pid_n * block_size_n + tl.arange(0, block_size_n) c_ptrs = c_ptr + stride_cm * offs_cm[:, none] + stride_cn * offs_cn[none, :] c_mask = (offs_cm[:, none] < m) & (offs_cn[none, :] = 0, x, 0.01 * x)
我们现在可以创建一个方便的封装函数,它只需要两个输入张量,并且会:(1)检查任何形状约束;(2)分配输出;(3)启动上述kernel。
def matmul(a, b, activation=): # check constraints. assert a.shape[1] == b.shape[0], incompatible dimensions assert a.is_contiguous(), matrix a must be contiguous assert b.is_contiguous(), matrix b must be contiguous m, k = a.shape k, n = b.shape # allocates output. c = torch.empty((m, n), device=a.device, dtype=a.dtype) # 1d launch kernel where each block gets its own program. grid = lambda meta: (triton.cdiv(m, meta['block_size_m']) * triton.cdiv(n, meta['block_size_n']), ) matmul_kernel[grid]( a, b, c, # m, n, k, # a.stride(0), a.stride(1), # b.stride(0), b.stride(1), # c.stride(0), c.stride(1), # activation=activation # ) return c
计算过程的补充说明
上面的《l2 cache优化原理补充讲解》这一节明确了kernel的group ordering的访问方式以及实现,现在来看对于当前的program实例具体是怎么计算的。现在以计算c中的第一个block的(0, 0)为例子,它需要从a和b分别加载9个黄色的小块数据相乘并累加最后得到c中的(0, 0)位置结果。如下图所示:
下面的代码先把program实例当前要处理a和b的第一个block加载上来:
# ----------------------------------------------------------# 为a和b的第一个块创建指针。# 我们将在k方向移动时推进这个指针并累加# `a_ptrs`是[block_size_m, block_size_k]块的指针# `b_ptrs`是[block_size_k, block_size_n]块的指针# 有关详细信息,请参阅上方“指针算术”部分offs_am = (pid_m * block_size_m + tl.arange(0, block_size_m)) % moffs_bn = (pid_n * block_size_n + tl.arange(0, block_size_n)) % noffs_k = tl.arange(0, block_size_k)a_ptrs = a_ptr + (offs_am[:, none] * stride_am + offs_k[none, :] * stride_ak)b_ptrs = b_ptr + (offs_k[:, none] * stride_bk + offs_bn[none, :] * stride_bn)
这里的a_ptr 是整个 a 矩阵第一个元素的地址,offs_am和offs_bn表示当前的program id在m维度和k维度的坐标,这个坐标是一个list,用tl.arange(0, block_size_k)来获取。
得到 m 维度 和 k 维度的坐标后, 就可以让它们各自和 m 维度 和 k 维度的 stride 相乘, 然后和 a_ptr 相加, 就可以得到 a 矩阵 9 个 block 中第一个 block 中每个元素的地址了。 b_ptr也是同理。
最后一部分就是累加了,这里会在k维度上进行累加,每次计算输出的一个块。
# 迭代以计算c矩阵的一个块。# 我们将累加到一个`[block_size_m, block_size_n]`块# 的fp32值以获得更高的精度。# `accumulator`在循环后会转换回fp16。accumulator = tl.zeros((block_size_m, block_size_n), dtype=tl.float32)for k in range(0, tl.cdiv(k, block_size_k)): # load the next block of a and b, generate a mask by checking the k dimension. # if it is out of bounds, set it to 0. a = tl.load(a_ptrs, mask=offs_k[none, :] < k - k * block_size_k, other=0.0) b = tl.load(b_ptrs, mask=offs_k[:, none] < k - k * block_size_k, other=0.0) # we accumulate along the k dimension. accumulator += tl.dot(a, b) # advance the ptrs to the next k block. a_ptrs += block_size_k * stride_ak b_ptrs += block_size_k * stride_bk
这行代码a = tl.load(a_ptrs, mask=offs_k[none, :] < k - k * block_size_k, other=0.0)考虑到 k 可能不能被 block_size_k 整除, 到每一行最后一个 block 的时候, 实际大小是不足 block_size_k 的,所以需要把超出的那部分元素mask掉。
最后这部分代码是把当前的算子和leakyrelu激活函数进行融合:
# 当累加器仍然是fp32时,可以融合任意激活函数if activation == leaky_relu: accumulator = leaky_relu(accumulator)c = accumulator.to(tl.float16)
单元测试
benchmark
这里使用一个方阵来对比triton实现的matmul kernel和cublas的matmul kernel的性能。
@triton.testing.perf_report( triton.testing.benchmark( x_names=['m', 'n', 'k'], # 用作图表x轴的参数名 x_vals=[128 * i for i in range(2, 33)], # `x_name`的不同可能值 line_arg='provider', # 其值对应于图表中不同线条的参数名 # `line_arg`的可能值 line_vals=['cublas', 'triton'], # 线条的标签名称 line_names=[cublas, triton], # 线条样式 styles=[('green', '-'), ('blue', '-')], ylabel=tflops, # y轴的标签名称 plot_name=matmul-performance, # 图表的名称,也用作保存图表的文件名。 args={}, # 其他参数 ))def benchmark(m, n, k, provider): # 初始化张量 a = torch.randn((m, k), device='cuda', dtype=torch.float16) b = torch.randn((k, n), device='cuda', dtype=torch.float16) quantiles = [0.5, 0.2, 0.8] # 分位数 # 如果提供者是cublas if provider == 'cublas': ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b), quantiles=quantiles) # 如果提供者是triton if provider == 'triton': ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b), quantiles=quantiles) # 性能计算函数 perf = lambda ms: 2 * m * n * k * 1e-12 / (ms * 1e-3) return perf(ms), perf(max_ms), perf(min_ms)# 运行基准测试,展示图表和打印数据benchmark.run(show_plots=true, print_data=true)
可以看到基于triton实现的矩阵乘kernel性能大体可以和高度优化的cublas持平。
苹果公司已经斥资约2亿美元收购了一家西雅图初创公司
三星S8将推北极银配色,这个配色美到不行!
AMD全新锐龙处理器和Radeon显卡日本发售时间提前三个小时
关于区块链和互联网的之间的区别和作用分析
直流无刷电机的型号该如何选择?
【BBuf的CUDA笔记】OpenAI Triton入门笔记一
科创板中国通号副总裁孔宁介绍、履历信息
M12接插件防水航空插座4PIN5PIN6PIN圆形连接器
还在临床实验的医疗,人体软组织焊接术
瑞萨电子解读智能家居
荣耀V9、荣耀9、华为P10和坚果pro四款高颜值机你入手了哪款?
华为WLAN: 不仅仅是接入
深入研究数据科学家使用的常见统计和分析技术
2017-2021年,全球电缆连接器市场年复合增率有望超过6%
对比,LCD和LED的背光原理有什么区别?
Simulink中构造时变传递函数的四种方法
快速接线模块有什么作用如何提高系统安全
保障隧道施工安全的智能化管理系统
在大多数公司中,多达20%的工作可能是基于AI的工作
墨菲定律和设计“非数据手册”的风险