
Matmul pattern with dequantization in it #4327

@Devjiu

Describe the issue

Hi,

I wrote a Triton matmul that does dequantization on the fly. In theory it can be as fast as a plain blocked matmul, because the dequantization algorithm does not depend on the shape passed to it. I took the matmul implementation from python/tutorials/10-experimental-block-pointer.py.

Performance measurements (all calls are Triton); the measurement approach is the same as in the autotuner, python/triton/testing.py::do_bench, via torch.xpu.Event:

  split (dequantization, then matmul):  0.01488 ms
        fused dequantization + matmul:  0.0324  ms
  split (dequantization, then matmul):  0.01496 ms
        fused dequantization + matmul:  0.03248 ms
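
For reference, a minimal sketch of this kind of timing harness, assuming torch.xpu.Event mirrors the torch.cuda.Event timing API (the warmup/rep counts and the helper name bench are illustrative, not the exact code used):

import torch

def bench(fn, warmup=10, rep=100):
    # Warm up: triggers JIT compilation and autotuning.
    for _ in range(warmup):
        fn()
    torch.xpu.synchronize()
    times = []
    for _ in range(rep):
        start = torch.xpu.Event(enable_timing=True)
        end = torch.xpu.Event(enable_timing=True)
        start.record()
        fn()
        end.record()
        torch.xpu.synchronize()
        times.append(start.elapsed_time(end))  # milliseconds
    return sum(times) / len(times)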

"Split" means a call to a dequantization kernel implemented in Triton followed by a call to a matmul kernel implemented in Triton (originally taken from the 10th tutorial). The fused kernel is shown below; a sketch of a standalone dequantization kernel follows the helper.

import torch

import triton
import triton.language as tl


@triton.jit
def dequant_4bit_body_util(a, offsets, quant_ptr, absmax_ptr, n_elems, QUANT_BLOCK: tl.constexpr):
    PAIRED_QUANT_BLOCK: tl.constexpr = QUANT_BLOCK // 2
    mask = offsets < n_elems
    # low 4 bits of each packed byte
    higher = a & 0xF
    # high 4 bits of each packed byte
    lower = a >> 4

    abs_offsets = offsets // PAIRED_QUANT_BLOCK
    absmax = tl.load(absmax_ptr + abs_offsets, mask=mask, other=1.0, eviction_policy="evict_last")

    # apply conversion
    lower_4 = tl.load(quant_ptr + lower, eviction_policy="evict_last")
    higher_4 = tl.load(quant_ptr + higher, eviction_policy="evict_last")

    mul_high = higher_4 * absmax
    mul_low = lower_4 * absmax
    out_dq = tl.interleave(mul_low, mul_high)
    return out_dq
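
For comparison, the "split" path runs dequantization as a standalone kernel. A minimal sketch of such a kernel, reusing the helper above (the BLOCK size, output layout, and launch line are assumptions, not the exact kernel that was benchmarked):

@triton.jit
def dequant_4bit_kernel(b_ptr, out_ptr, quant_ptr, absmax_ptr, n_paired_elements,
                        QUANT_BLOCK: tl.constexpr, BLOCK: tl.constexpr):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offsets < n_paired_elements
    # Each byte packs two 4-bit values.
    packed = tl.load(b_ptr + offsets, mask=mask, other=0)
    out_dq = dequant_4bit_body_util(
        a=packed,
        offsets=offsets,
        quant_ptr=quant_ptr,
        absmax_ptr=absmax_ptr,
        n_elems=n_paired_elements,
        QUANT_BLOCK=QUANT_BLOCK,
    )
    # The helper interleaves the two decoded values, so the output block is twice as wide.
    out_offsets = pid * (2 * BLOCK) + tl.arange(0, 2 * BLOCK)
    tl.store(out_ptr + out_offsets, out_dq, mask=out_offsets < 2 * n_paired_elements)

# Example launch (illustrative):
# grid = (triton.cdiv(n_paired_elements, BLOCK),)
# dequant_4bit_kernel[grid](b, out, code, absmax, n_paired_elements, QUANT_BLOCK=blocksize, BLOCK=1024)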

@triton.autotune(
    configs=[
        triton.Config(
            {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 4},
            num_stages=2,
            num_warps=32,
        ),
    ],
    key=["M", "N", "K"],
)
@triton.jit
def matmul_kernel(
    # Pointers to matrices
    a_ptr,
    b_ptr,
    c_ptr,
    # Matrix dimensions
    M,
    N,
    K,
    # The stride variables represent how much to increase the ptr by when moving by 1
    # element in a particular dimension. E.g. `stride_am` is how much to increase `a_ptr`
    # by to get the element one row down (A has M rows).
    stride_az,
    stride_am,
    stride_ak,
    stride_bz,
    stride_bn,
    stride_bk,
    stride_cz,
    stride_cm,
    stride_cn,
    quant_ptr,
    absmax_ptr,
    num_paired_elements,
    ACCUMULATOR_DTYPE: tl.constexpr,
    QUANT_BLOCK: tl.constexpr,
    # Meta-parameters
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,  #
    GROUP_SIZE_M: tl.constexpr,  #
):
    """Kernel for computing the matmul C = A x B.
    A has shape (M, K), B has shape (K, N) and C has shape (M, N)
    """
    # -----------------------------------------------------------
    # Map program ids `pid` to the block of C it should compute.
    # This is done in a grouped ordering to promote L2 data reuse.
    # See the `L2 Cache Optimizations` section of the matmul tutorial for details.
    pid = tl.program_id(axis=0)
    bid = tl.program_id(axis=1)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    offset_a = bid.to(tl.int64) * stride_az
    offset_b = bid.to(tl.int64) * stride_bz
    # ----------------------------------------------------------
    # Create pointers for the first blocks of A and B.
    # We will advance this pointer as we move in the K direction
    # and accumulate
    # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers
    # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers
    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
    offs_bk = tl.arange(0, BLOCK_SIZE_K // 2)

    b_offsets = offs_bn[:, None] * stride_bn + offs_bk[None, :] * stride_bk

    a_block_ptr = tl.make_block_ptr(
        base=a_ptr + offset_a,
        shape=(M, K),
        strides=(stride_am, stride_ak),
        offsets=(pid_m * BLOCK_SIZE_M, 0),
        block_shape=(BLOCK_SIZE_M, BLOCK_SIZE_K),
        order=(1, 0),
    )
    b_block_ptr = tl.make_block_ptr(
        base=b_ptr + offset_b,
        shape=(N, K // 2),
        strides=(stride_bn, stride_bk),
        offsets=(pid_n * BLOCK_SIZE_N, 0),
        block_shape=(BLOCK_SIZE_N, BLOCK_SIZE_K // 2),
        order=(1, 0),
    )

    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=ACCUMULATOR_DTYPE)
    # dq_b = tl.zeros((BLOCK_SIZE_K, BLOCK_SIZE_N), dtype=a_ptr.type.element_ty)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        # Load the next block of A and B, generate a mask by checking the K dimension.
        # If it is out of bounds, set it to 0.
        a_blck = tl.load(a_block_ptr, boundary_check=(0, 1))
        # b_blck holds packed 4-bit weights: (BLOCK_SIZE_N, BLOCK_SIZE_K // 2) bytes per tile.
        b_blck = tl.load(b_block_ptr, boundary_check=(0, 1))
        dq_b_t = dequant_4bit_body_util(
            a=b_blck,
            offsets=b_offsets,
            quant_ptr=quant_ptr,
            absmax_ptr=absmax_ptr,
            n_elems=num_paired_elements,
            QUANT_BLOCK=QUANT_BLOCK,
        )
        # The dequantized tile is (BLOCK_SIZE_N, BLOCK_SIZE_K); transpose to
        # (BLOCK_SIZE_K, BLOCK_SIZE_N) and cast to A's element type for the dot.
        dq_b_t = dq_b_t.trans()
        dq_b = dq_b_t.to(a_ptr.type.element_ty)

        # We accumulate along the K dimension.
        accumulator += tl.dot(a_blck, dq_b, out_dtype=ACCUMULATOR_DTYPE)
        # Advance the ptrs to the next K block.
        b_offsets += (BLOCK_SIZE_K // 2) * stride_bk
        a_block_ptr = tl.advance(a_block_ptr, (0, BLOCK_SIZE_K))
        b_block_ptr = tl.advance(b_block_ptr, (0, BLOCK_SIZE_K // 2))
    c = accumulator.to(c_ptr.type.element_ty)

    # -----------------------------------------------------------
    # Write back the block of the output matrix C with masks.
    offset_c = bid.to(tl.int64) * stride_cz
    c_block_ptr = tl.make_block_ptr(
        base=c_ptr + offset_c,
        shape=(M, N),
        strides=(stride_cm, stride_cn),
        offsets=(pid_m * BLOCK_SIZE_M, pid_n * BLOCK_SIZE_N),
        block_shape=(BLOCK_SIZE_M, BLOCK_SIZE_N),
        order=(1, 0),
    )
    tl.store(c_block_ptr, c, boundary_check=(0, 1))


# Common scenario: A is batched, B is just 2D
def matmul(a, b, shapeB, code, absmax, blocksize):
    # Check constraints.
    # assert a.shape[1] == b.shape[0], "Incompatible dimensions"
    assert a.is_contiguous(), "Matrix A must be contiguous"
    if len(a.shape) == 1:
        B = M = 1
        K = a.shape[0]
        stride_az, stride_am, stride_ak = a.numel(), a.numel(), a.stride(0)
    elif len(a.shape) == 2:
        B = 1
        M, K = a.shape
        stride_az, stride_am, stride_ak = a.numel(), a.stride(0), a.stride(1)
    elif len(a.shape) == 3:
        B, M, K = a.shape
        stride_az, stride_am, stride_ak = a.stride(0), a.stride(1), a.stride(2)
    elif len(a.shape) > 3:
        a = a.view(-1, a.shape[-2], a.shape[-1])
        B, M, K = a.shape
        stride_az, stride_am, stride_ak = a.stride(0), a.stride(1), a.stride(2)

    if len(shapeB) == len(a.shape) == 3:
        assert shapeB[0] == B, "Incompatible batch size"
        assert shapeB[2] == K, "Incompatible dimensions"
        B, N, K = shapeB
        stride_bz, stride_bn, stride_bk = N * K // 2, K // 2, 1
        c = torch.empty((B, M, N), device=a.device, dtype=a.dtype)
        stride_cz, stride_cm, stride_cn = c.stride(0), c.stride(1), c.stride(2)

    if len(shapeB) == 2:
        N, K = shapeB
        b = b.view(N, K // 2)
        stride_bz, stride_bn, stride_bk = N * K // 2, K // 2, 1
        if len(a.shape) >= 3:
            c = torch.empty((B, M, N), device=a.device, dtype=a.dtype)
            stride_cz, stride_cm, stride_cn = c.stride(0), c.stride(1), c.stride(2)
        else:
            c = torch.empty((M, N), device=a.device, dtype=a.dtype)
            stride_cz, stride_cm, stride_cn = c.numel(), c.stride(0), c.stride(1)

    # BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, GROUP_SIZE_M = 16, 16, 16, 4
    # Allocates output.
    # c = torch.empty((B, M, N), device=a.device, dtype=a.dtype)
    # 1D launch kernel where each block gets its own program.
    # grid = (triton.cdiv(M, BLOCK_SIZE_M) * triton.cdiv(N, BLOCK_SIZE_N),B, )
    grid = lambda META: (
        triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
        B,
    )
    accum_dtype = a.dtype
    if a.dtype in (torch.bfloat16, torch.float16):
        accum_dtype = torch.float32

    # Map the torch dtype name to Triton's, e.g. torch.float32 -> fp32, torch.bfloat16 -> bf16.
    triton_accum_dtype = tl.dtype(str(accum_dtype)[6:].replace("bfloat", "bf").replace("float", "fp"))
    # grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), )
    number_of_paired_elements = b.numel()
    matmul_kernel[grid](
        a,
        b,
        c,  # tensors
        M,
        N,
        K,  # sizes
        stride_az,
        stride_am,
        stride_ak,  #
        stride_bz,
        stride_bn,
        stride_bk,  #
        stride_cz,
        stride_cm,
        stride_cn,  #
        code,
        absmax,
        number_of_paired_elements,
        triton_accum_dtype,
        blocksize,
        # BLOCK_SIZE_M,
        # BLOCK_SIZE_N,
        # BLOCK_SIZE_K,  #
        # GROUP_SIZE_M,
    )
    return c 
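
For reference, a minimal invocation sketch with random inputs (the shapes, device string, and 16-entry lookup table are illustrative assumptions, not taken from the actual workload):

import torch

# Assumes the `matmul` wrapper above is in scope.
B, M, K, N, blocksize = 1, 16, 4096, 4096, 64
a = torch.randn(B, M, K, device="xpu", dtype=torch.float16)
# Packed 4-bit weights: two values per byte, N * K // 2 bytes in total.
b = torch.randint(0, 256, (N * K // 2,), device="xpu", dtype=torch.uint8)
# 16-entry dequantization lookup table (e.g. an NF4 code book).
code = torch.randn(16, device="xpu", dtype=torch.float16)
# One scale per quantization block of `blocksize` unpacked elements.
absmax = torch.rand(N * K // blocksize, device="xpu", dtype=torch.float16)

c = matmul(a, b, (N, K), code, absmax, blocksize)  # -> (B, M, N)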

I think performance is lower because of the strict matmul pattern required to handle the K loop.

Environment details

Triton:

commit 3f3bcf32eefcac74be1afa4331495566e49dde55 (HEAD -> main, origin/main, origin/HEAD)
Author: Anatoly Myachev <[email protected]>
Date:   Wed May 21 18:49:13 2025 +0200

    Fix Kineto+PTI profiling on BMG (#4244)

    Signed-off-by: Anatoly Myachev <[email protected]>

Env: SPR X1 Triton (GPU, Agama 1099.12, DLE 2025.1.1, Ubuntu 22.04)
