-
Notifications
You must be signed in to change notification settings - Fork 67
Description
Describe the issue
Hi,
I wrote a Triton matmul that performs dequantization on the fly. In theory it can be as fast as a plain blocked matmul, because the dequantization algorithm does not depend on the input shape. I took the matmul implementation from python/tutorials/10-experimental-block-pointer.py.
Performance measurement (all calls are Triton kernels) uses the same approach as the autotuner in python/triton/testing.py::do_bench,
timing with torch.xpu.Event:
split (2 calls: dequantization, then matmul): 0.014879999999999999 ms
fused dequantization + matmul: 0.0324 ms
split (2 calls: dequantization, then matmul): 0.01496 ms
fused dequantization + matmul: 0.03248 ms
"Split" means one call to a dequantization kernel implemented in Triton, followed by one call to a matmul kernel implemented in Triton (originally taken from the 10th tutorial).
The fused kernel is shown below.
@triton.jit
def dequant_4bit_body_util(a, offsets, quant_ptr, absmax_ptr, n_elems, QUANT_BLOCK: tl.constexpr):
    """Dequantize a tile of packed 4-bit codes into an interleaved value tile.

    Each byte of ``a`` holds two 4-bit codebook indices. The upper nibble is
    the first element of the pair and the lower nibble is the second. The two
    results are combined with ``tl.interleave``, so the last dimension of the
    returned tile is twice that of ``a``.

    Args:
        a: tile of packed bytes, two 4-bit indices per byte. Presumably uint8;
            the nibble extraction below is also safe for a signed 8-bit type.
        offsets: byte offsets of ``a`` within the packed buffer, same shape as
            ``a``. They are used only to locate the per-block scale.
        quant_ptr: pointer to the codebook, indexed by a 4-bit value (0..15).
        absmax_ptr: pointer to the per-quantization-block scales.
        n_elems: number of packed bytes in the buffer. Offsets at or beyond it
            use a scale of 1.0 instead of reading out of bounds.
        QUANT_BLOCK: number of 4-bit elements that share one absmax scale.

    Returns:
        ``codebook[index] * absmax`` for every 4-bit index, interleaved as
        (upper-nibble value, lower-nibble value) pairs.
    """
    # One scale covers QUANT_BLOCK elements, i.e. QUANT_BLOCK // 2 packed bytes.
    PAIRED_QUANT_BLOCK: tl.constexpr = QUANT_BLOCK // 2
    mask = offsets < n_elems
    # Upper 4 bits -> first element of the pair. The trailing `& 0xF` keeps the
    # index in [0, 15] even if `a` is signed, where `>>` would sign-extend and
    # produce a negative (out-of-bounds) codebook index.
    hi_nibble = (a >> 4) & 0xF
    # Lower 4 bits -> second element of the pair.
    lo_nibble = a & 0xF
    absmax = tl.load(
        absmax_ptr + offsets // PAIRED_QUANT_BLOCK,
        mask=mask,
        other=1.0,
        eviction_policy="evict_last",
    )
    # Codebook lookup followed by per-block scaling.
    first = tl.load(quant_ptr + hi_nibble, eviction_policy="evict_last") * absmax
    second = tl.load(quant_ptr + lo_nibble, eviction_policy="evict_last") * absmax
    return tl.interleave(first, second)
@triton.autotune(
    configs=[
        triton.Config(
            {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 4},
            num_stages=2,
            num_warps=32,
        ),
    ],
    key=["M", "N", "K"],
)
@triton.jit
def matmul_kernel(
    # Pointers to matrices
    a_ptr,
    b_ptr,
    c_ptr,
    # Matrix dimensions
    M,
    N,
    K,
    # The stride variables give how much to increase a pointer by to move one
    # element along an axis. E.g. `stride_am` is how much to increase `a_ptr`
    # to get the element one row down (A has M rows). The `z` strides step
    # between batch entries, selected by program_id(axis=1).
    stride_az,
    stride_am,
    stride_ak,
    stride_bz,
    stride_bn,
    stride_bk,
    stride_cz,
    stride_cm,
    stride_cn,
    # 4-bit dequantization inputs
    quant_ptr,
    absmax_ptr,
    num_paired_elements,
    ACCUMULATOR_DTYPE: tl.constexpr,
    QUANT_BLOCK: tl.constexpr,
    # Meta-parameters (supplied by the autotuner)
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
):
    """Compute C = A x B, dequantizing the 4-bit B operand inside the K loop.

    A is addressed as (M, K). B is addressed as packed bytes of shape
    (N, K // 2), two 4-bit codes per byte. Each tile is dequantized to
    (BLOCK_SIZE_N, BLOCK_SIZE_K) and transposed before the dot, so it acts as
    a (K, N) operand. C is written as (M, N). Program axis 0 selects the
    output tile and axis 1 selects the batch entry.
    """
    # -----------------------------------------------------------
    # Map program ids to the tile of C this program computes. Tiles are visited
    # in groups of GROUP_SIZE_M rows (the grouped ordering from the Triton
    # matmul tutorial) to promote L2 data reuse.
    pid = tl.program_id(axis=0)
    bid = tl.program_id(axis=1)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m
    # Batch offsets in int64 so large batches do not overflow 32-bit math.
    offset_a = bid.to(tl.int64) * stride_az
    offset_b = bid.to(tl.int64) * stride_bz
    # ----------------------------------------------------------
    # Byte offsets of the current packed-B tile, used only to look up absmax.
    # They include the batch offset `offset_b`, so the scales come from the
    # same batch entry as the data loaded through `b_block_ptr`. This assumes
    # absmax is laid out over the flat packed buffer, as the `n_elems` mask
    # against the whole buffer implies. Rows past N wrap around (% N); those
    # output columns are masked off on the final store.
    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
    offs_bk = tl.arange(0, BLOCK_SIZE_K // 2)
    b_offsets = offset_b + offs_bn[:, None] * stride_bn + offs_bk[None, :] * stride_bk
    a_block_ptr = tl.make_block_ptr(
        base=a_ptr + offset_a,
        shape=(M, K),
        strides=(stride_am, stride_ak),
        offsets=(pid_m * BLOCK_SIZE_M, 0),
        block_shape=(BLOCK_SIZE_M, BLOCK_SIZE_K),
        order=(1, 0),
    )
    # Each B tile is (BLOCK_SIZE_N, BLOCK_SIZE_K // 2) packed bytes.
    b_block_ptr = tl.make_block_ptr(
        base=b_ptr + offset_b,
        shape=(N, K // 2),
        strides=(stride_bn, stride_bk),
        offsets=(pid_n * BLOCK_SIZE_N, 0),
        block_shape=(BLOCK_SIZE_N, BLOCK_SIZE_K // 2),
        order=(1, 0),
    )
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=ACCUMULATOR_DTYPE)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        # Load the next A and B tiles with explicit zero padding. Without
        # padding_option, out-of-bounds values are undefined. A's K tail must
        # be zero because a zero byte of B still dequantizes to
        # codebook[0] * absmax, which need not be zero.
        a_blck = tl.load(a_block_ptr, boundary_check=(0, 1), padding_option="zero")
        b_blck = tl.load(b_block_ptr, boundary_check=(0, 1), padding_option="zero")
        # (BLOCK_SIZE_N, BLOCK_SIZE_K // 2) bytes -> (BLOCK_SIZE_N, BLOCK_SIZE_K) values.
        dq_b_t = dequant_4bit_body_util(
            a=b_blck,
            offsets=b_offsets,
            quant_ptr=quant_ptr,
            absmax_ptr=absmax_ptr,
            n_elems=num_paired_elements,
            QUANT_BLOCK=QUANT_BLOCK,
        )
        # Transpose to (BLOCK_SIZE_K, BLOCK_SIZE_N) and cast to A's dtype for the dot.
        dq_b = dq_b_t.trans().to(a_ptr.type.element_ty)
        # Accumulate along the K dimension.
        accumulator += tl.dot(a_blck, dq_b, out_dtype=ACCUMULATOR_DTYPE)
        # Advance the offsets/pointers to the next K block (B advances in bytes).
        b_offsets += (BLOCK_SIZE_K // 2) * stride_bk
        a_block_ptr = tl.advance(a_block_ptr, (0, BLOCK_SIZE_K))
        b_block_ptr = tl.advance(b_block_ptr, (0, BLOCK_SIZE_K // 2))
    c = accumulator.to(c_ptr.type.element_ty)
    # -----------------------------------------------------------
    # Write back the tile of C; boundary_check drops rows/cols outside (M, N).
    offset_c = bid.to(tl.int64) * stride_cz
    c_block_ptr = tl.make_block_ptr(
        base=c_ptr + offset_c,
        shape=(M, N),
        strides=(stride_cm, stride_cn),
        offsets=(pid_m * BLOCK_SIZE_M, pid_n * BLOCK_SIZE_N),
        block_shape=(BLOCK_SIZE_M, BLOCK_SIZE_N),
        order=(1, 0),
    )
    tl.store(c_block_ptr, c, boundary_check=(0, 1))
# Common scenario is A batched, B is just 2d
def matmul(a, b, shapeB, code, absmax, blocksize):
    """Multiply ``a`` by a 4-bit quantized ``b``, dequantizing ``b`` in the kernel.

    Args:
        a: contiguous tensor of shape (K,), (M, K), (B, M, K) or (..., M, K).
            Leading dims beyond the last two are flattened into one batch dim.
        b: packed 4-bit data, two codes per byte. It is viewed as
            (N, K // 2) when ``shapeB`` is 2-D.
        shapeB: logical (unpacked) shape of B. Either (N, K), shared by
            every batch entry of ``a``, or (B, N, K), one matrix per batch
            entry (requires ``a`` with at least 3 dims).
        code: codebook indexed by 4-bit values.
        absmax: per-quantization-block scales.
        blocksize: number of elements sharing one absmax value.

    Returns:
        A tensor of ``a.dtype``: (M, N) for 1-D/2-D ``a`` (M == 1 for 1-D),
        otherwise (B, M, N) with leading batch dims flattened.

    Raises:
        AssertionError: if ``a`` is not contiguous or the shapes are incompatible.
        ValueError: if ``a`` is 0-d or the ranks of ``a`` and ``shapeB`` are
            an unsupported combination.
    """
    assert a.is_contiguous(), "Matrix A must be contiguous"
    if a.dim() == 0:
        raise ValueError("Matrix A must have at least one dimension")
    if a.dim() == 1:
        B = M = 1
        K = a.shape[0]
        stride_az, stride_am, stride_ak = a.numel(), a.numel(), a.stride(0)
    elif a.dim() == 2:
        B = 1
        M, K = a.shape
        stride_az, stride_am, stride_ak = a.numel(), a.stride(0), a.stride(1)
    else:
        if a.dim() > 3:
            # Flatten all leading dims into a single batch dim.
            a = a.view(-1, a.shape[-2], a.shape[-1])
        B, M, K = a.shape
        stride_az, stride_am, stride_ak = a.stride()

    if len(shapeB) == 3 and a.dim() == 3:
        # One quantized (N, K) matrix per batch entry; shapeB is (B, N, K).
        assert shapeB[0] == B, "Incompatible batch size"
        assert shapeB[2] == K, "Incompatible dimensions"
        N = shapeB[1]
        stride_bz = N * K // 2
    elif len(shapeB) == 2:
        # A single (N, K) matrix shared by every batch entry of A. A zero batch
        # stride keeps every program on the same weights; a non-zero stride
        # would read past the end of `b` for batch index >= 1.
        assert shapeB[1] == K, "Incompatible dimensions"
        N = shapeB[0]
        b = b.view(N, K // 2)
        stride_bz = 0
    else:
        raise ValueError(
            f"Unsupported shapes: A has {a.dim()} dims, shapeB={tuple(shapeB)}"
        )
    # B is packed two 4-bit values per byte along K.
    stride_bn, stride_bk = K // 2, 1

    if a.dim() == 3:
        c = torch.empty((B, M, N), device=a.device, dtype=a.dtype)
        stride_cz, stride_cm, stride_cn = c.stride()
    else:
        c = torch.empty((M, N), device=a.device, dtype=a.dtype)
        stride_cz, stride_cm, stride_cn = c.numel(), c.stride(0), c.stride(1)
    if c.numel() == 0:
        # Nothing to compute; avoid launching an empty grid.
        return c

    # Axis 0: one program per (BLOCK_SIZE_M x BLOCK_SIZE_N) output tile;
    # axis 1: one program per batch entry.
    grid = lambda META: (
        triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
        B,
    )
    # Half-precision inputs accumulate in fp32. The torch dtype name is mapped
    # to Triton's spelling: "torch.float32" -> "fp32", "torch.bfloat16" -> "bf16".
    accum_dtype = torch.float32 if a.dtype in (torch.bfloat16, torch.float16) else a.dtype
    triton_accum_dtype = tl.dtype(str(accum_dtype)[6:].replace("bfloat", "bf").replace("float", "fp"))
    # Total packed bytes; the kernel masks absmax lookups against this.
    number_of_paired_elements = b.numel()
    matmul_kernel[grid](
        a,
        b,
        c,  # tensors
        M,
        N,
        K,  # sizes
        stride_az,
        stride_am,
        stride_ak,
        stride_bz,
        stride_bn,
        stride_bk,
        stride_cz,
        stride_cm,
        stride_cn,
        code,
        absmax,
        number_of_paired_elements,
        triton_accum_dtype,
        blocksize,
    )
    return c
I suspect the fused kernel is slower because of the strict matmul patterns required to handle the K loop.
Environment details
Triton:
commit 3f3bcf32eefcac74be1afa4331495566e49dde55 (HEAD -> main, origin/main, origin/HEAD)
Author: Anatoly Myachev <[email protected]>
Date: Wed May 21 18:49:13 2025 +0200
Fix Kineto+PTI profiling on BMG (#4244)
Signed-off-by: Anatoly Myachev <[email protected]>
Env: SPR X1 Triton (GPU, Agama 1099.12, DLE 2025.1.1, Ubuntu 22.04)