00_basic_gemm

说明

这里研究的cutlass版本是3.5

gemm讲解

  using CutlassGemm = cutlass::gemm::device::Gemm<float,        // Data-type of A matrix
                                                  ColumnMajor,  // Layout of A matrix
                                                  float,        // Data-type of B matrix
                                                  ColumnMajor,  // Layout of B matrix
                                                  float,        // Data-type of C matrix
                                                  ColumnMajor>; // Layout of C matrix

  CutlassGemm gemm_operator;
  CutlassGemm::Arguments args({
   
   M , N, K},  // Gemm Problem dimensions
                              {
   
   A, lda},    // Tensor-ref for source matrix A
                              {
   
   B, ldb},    // Tensor-ref for source matrix B
                              {
   
   C, ldc},    // Tensor-ref for source matrix C
                              {
   
   C, ldc},    // Tensor-ref for destination matrix D (may be different memory than source C matrix)
                              {
   
   alpha, beta}); // Scalars used in the Epilogue
  
  cutlass::Status status = gemm_operator(args);

上面是核心代码,可以看到首先要实例化一个类型CutlassGemm(编译期就要定下来),然后根据这个类型实例化一个对象gemm_operator(运行期),然后对象调用operator(args)做计算(运行期)。

编译期

  using CutlassGemm = cutlass::gemm::device::Gemm<float,        // Data-type of A matrix
                                                ColumnMajor,  // Layout of A matrix
                                                float,        // Data-type of B matrix
                                                ColumnMajor,  // Layout of B matrix
                                                float,        // Data-type of C matrix
                                                ColumnMajor>; // Layout of C matrix

可以看到,编译期时候,程序员必须要定下输入矩阵的layout和数据类型。事实上真的是这样吗?我们来深究一下这个cutlass::gemm::device::Gemm,从这个名字就可以看出来,cutlass实现了一个gemm,有device, threadblock, warp, thread几个级别gemm,这个sample里面用的是device级别, 所谓的device级别就是在cpu端的代码可以调用的,这个其实和cub中的逻辑是一样的。

Gemm类

template <
    typename ElementA_,
    typename LayoutA_,
    typename OperatorClass_ = arch::OpClassSimt,
    typename ArchTag_ = arch::Sm70,
    typename ThreadblockShape_ = typename DefaultGemmConfiguration<
        OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
        ElementAccumulator_>::ThreadblockShape,
    // Operation performed by GEMM
    typename Operator_ = typename DefaultGemmConfiguration<
        OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
        ElementAccumulator_>::Operator, //op的选择
> 
Gemm{
   
   }
//偏特化一个
template<省略>
Gemm<layoutC=layout::ColumnMajor,>
{
   
   
	using UnderlyingOperator = Gemm<bala>;
}

1、这里偏特化很奇怪,单独给layoutC为列优先时候准备了一个类,具体什么原因这里也不深究,因为测试例子给的就是个ColumnMajor的layoutC,所以我们直接看这个偏特化类型。

这里增加了一个小知识,就是偏特化的模板不需要再传入默认值,会自动复用原始模板的默认值,此外由于偏特化实例化了一个值,导致在类里使用的时候没有了形参,为此可以看到源码里在类的开头搞了一堆的 类似using LayoutC = LayoutC_;即使偏特化实例化后,也能在类中再搞一个形参使用,CPP这搞得的是真恶心。

2、在偏特化的类中,又实例化了一通用的Gemm类UnderlyingOperator,这里把C的layout又改成RowMajor来用通用模板实例化Gemm, 多以饶了一圈,偏特化就是空壳子,最后还是绕回去通用模板,为什么要这要搞?我理解为了对外接口统一做的牺牲。

UnderlyingOperator

接下来就看看这通用模板是如何被实例化的。

  using UnderlyingOperator = Gemm< 
    ElementB,
    typename layout::LayoutTranspose<LayoutB>::type,
    ElementA,
    typename layout::LayoutTranspose<LayoutA>::type,
    ElementC,
    layout::RowMajor,    
    ElementAccumulator,
    OperatorClass,
    ArchTag,
    ThreadblockShape,
    WarpShape,
    InstructionShape,
    EpilogueOutputOp,
    ThreadblockSwizzle,
    Stages,
    kAlignmentB,
    kAlignmentA,
    SplitKSerial,
    Operator,
    GatherB,
    GatherA,
    ScatterD,
    PermuteDLayout
  >;

1、第一个有趣的现象是,把A,B, C的layout修改了,而且A和B的输入位置也变了,有点意思,来看看原理。
为了保证输出的数据不变,原来的列优先eading dimension=M的矩阵,我们也可以把他解读为行优先,leading dimension=M的矩阵。在计算机内存上没有任何变化,当时数学逻辑上由原来的MN的矩阵,变成了NM的矩阵(必须要理解)。 紧接着在数学逻辑上,原来的(MK) * (KN) = MN 需要变成(Nk) * (KM) = NM, 这里可以看到,当C的解读变化的时候,原来的A和B的位置调换了, 又由于A和B的计算机内存的数据不动,那么在解读时候,layout也要跟着掉换才对。
如此,就可以在A,B,C内存都不用动的情况来做计算,通过上述分析可以看到,框架还是想方设法的把C矩阵拉成行优先去处理,好处嘛我以为就是为了处理C的时候行连续更加符合人类对内存的直观感受,写代码的时候不至于一直别别扭扭,而且可以统一成一种优化模式,不管你C是什么layout,最终在不影响性能的情况下都转到一种C=RowMajor的逻辑下。
在这里插入图片描述

Gemm模板形参解读

上面我们搞清楚了一个矩阵的输入输出关系,位置关系,下面我们来看看一些默认形参分别是啥意思

OperatorClass

这个默认给了arch::OpClassSimt, 这个好理解,就是利用cuda core来做计算,如果是OpClassTensorOp的话就用tensorcore来做计算,OpClassWmmaTensorOp的话就是用一个分装好的wmma接口来调用tensorcore做计算,整个种类定义都在mma.h文件中,这里没啥好说的。

ArchTag

这里默认是arch::Sm70,这个也比较好理解,比如sm_50, sm_86的架构不一样,有的又sme异步,有的有tensor core, 有的sm大小不一样,反正每个架构不一样,各种feature也不一样,在编译期就明确,在编译期也就可以很具这些架构选择合适的算法。

ThreadblockShape

接下来是分级策略,这个需要去了解一下gemm的分块原理,这里不再详述,知乎一大堆,那么分块尺寸如何确定呢?这里用了DefaultGemmConfiguration这个类去萃取,依赖就是计算单元,架构,A,B, C的数据类型,以及累加的数据类型。

    typename ThreadblockShape_ = typename DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,ElementAccumulator_>::ThreadblockShape

所以需要好好看看DefaultGemmConfiguration是啥玩意:具体在default_gemm_configuration.h


template <
  typename OperatorClass,
  typename ArchTag,
  typename ElementA, 
  typename ElementB, 
  typename ElementC,
  typename ElementAccumulator
>
struct DefaultGemmConfiguration;

////////////////////////////////////////////////////////////////////////////////

template <
  typename ArchTag,
  typename ElementA, 
  typename ElementB, 
  typename ElementC, 
  typename ElementAccumulator>
struct DefaultGemmConfiguration<
  arch::OpClassSimt, 
  ArchTag,
  ElementA, 
  ElementB, 
  ElementC, 
  ElementAccumulator> {
   
   
  
  static int const kAlignmentA = 1;
  static int const kAlignmentB = 1;
  using ThreadblockShape = GemmShape<128, 128, 8>;
  using WarpShape = GemmShape<32, 64, 8>;
  using InstructionShape = GemmShape<1, 1, 1>;
  static int const kStages = 2;

  using EpilogueOutputOp = epilogue::thread::LinearCombination<
    ElementC,
    1,
    ElementAccumulator,
    ElementAccumulator
  >;

  using Operator = arch::OpMultiplyAdd;
};

上述代码可以看到就是定义了一个空类型,然后偏特化一堆的DefaultGemmConfiguration,上述代码就是偏特化一个参数arch::OpClassSimt,其实这个文件就是一个配置文件,你需要用啥shape,在这里搞个偏特化版本就行,我们目的是分析不是调优,所以就看看cutlass咋用的就行。

EpilogueOutputOp

就是gemm后面跟了一个计算,这里一下子给融合到计算里,这里的例子没有,所以可以先不看,看的也是在DefaultGemmConfiguration配置,不过还要依赖LinearCombinationClamp或者LinearCombination去配置。

ThreadblockSwizzle

这个目前猜测是增大L2 cache命中率的目的,让不同的block id访问数据局部性,所以就是将block的id和C矩阵中的数据做映射。 此外还有七八种其他方案,可以自己根据实际情况进行选择,文件在threadblock_swizzle.h, 一般不需要去重新排布,除非是n是一个很大的值,单个wave无法加载整个B矩阵。我这里写了个测试代码:

#include <iostream>
#include <cmath>
#include <iostream>
#include <vector>
#include <string>
#include <iomanip>

using namespace std;

void map(int m, int n, int tile, vector<vector<string>>& coord)
{
   
   
        for(int x=0; x<m; x++)
        {
   
   
                for(int y=0; y<n; y++)
                {
   
   
                        int c_x = (x>>tile);
                        int c_y = (y<<tile)+((x)&((1<<tile) -1));
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值