Triton比起CUDA,triton 提供了更高的抽象层次,在编写 GPU 程序时,开发者不需要像在 CUDA 中那样处理复杂的底层细节,如线程块、线程网格的详细配置等。它允许使用 Pythonic 的语法风格,开发者可以更加直观地表达计算逻辑,减少了繁琐的代码编写工作,从而更容易上手。
- triton文档的官网:
https://triton-lang.org/main/index.html - triton的github:
https://github.com/triton-lang/triton/tree/main
vector addition
这个整体比较简单,能直接看懂,主要就是熟悉了tl.program_id这个用法,比较某个维度进行计算的program
fused softmax
这个难度一下子上来了,结合了网上一些同学的博客进行了理解,主要的思路是:
- 假设给定一个矩阵要进行softmax操作,根据并行计算的策略,就是划分不同的行的块进行计算,就是有几个program可以同时计算
- 比如说一个program计算第1,11,21行,那么另一个program就计算2,12,22行,以此类推,主要就是这个策略
- 关注一些概念,比如num_stages, warp_size, num_warps, num_regs等等,整体有多少个num_program是根据kernel使用的寄存器数量、当前device的SM数量、device最大支持的寄存器数量以及occupancy来计算得到的
- triton是根据线程块计算program的,基本上是一个线程块一个program,然后GPU是由SM(streaming multiprocessor)组成的,在SM上有warp表示一组warp_size数量的线程,一般warp_size是固定的硬件参数,比如32,那么比如说一个SM,num_warp=8,warp_size=32,就表示一个SM上的一个program用32 * 8 = 256个线程,每个线程块使用的寄存器总量为 n_regs * WARP_SIZE * num_warps,那么总共有多少个program就可以求出来了
- 跑了下结果大致是下面这样的,优化确实是有效果的,通过这个例子,大致熟悉了triton的编程逻辑
Matrix Multiplication
这个比fused softmax还要复杂一些,遇到一些问题记录一下
- triton 3.3.0有一些问题,降级到triton 3.2.0之后问题就没了
- 对于显卡有一些要求,在v100上是可以跑通的,但是在一些老的显卡是会有各种问题
- fp16的测试在v100上好像有一些精度上的不匹配,fp8的测试是没问题的
fp8
fp16
Low Memory Dropout
这个问题编程上比较简单,涉及到的是模型checkpoint。checkpoint机制是指模型在反向传播的时候不保留一些激活函数的前向计算结果,而且重新前向计算来一遍以节省显存,那么对于dropout,涉及到的问题就是是否需要保存dropout的mask以便反向传递的时候进行恢复呢?其实是不用的,我们只需要一个根据seed可以固定个随机数发生器就可以了,这个在triton已经进行实现,这样就可以省下保存dropout mask的显存了。
Reference
- fused_softmax参考
https://zhuanlan.zhihu.com/p/1899562146477609112 - matrix mutiplication参考
https://blog.csdn.net/youzjuer/article/details/136897828
https://zhuanlan.zhihu.com/p/5814934527