Fused Softmax
In this tutorial, you will write a fused softmax operation that is significantly faster than PyTorch’s native op for a particular class of matrices: those whose rows can fit in the GPU’s SRAM.
In doing so, you will learn about:
The benefits of kernel fusion for bandwidth-bound operations.
Reduction operators in Triton.
Motivations
Custom GPU kernels for elementwise additions are educationally valuable but won’t get you very far in practice. Let us consider instead the case of a simple (numerically stabilized) softmax operation:
import torch
import triton
import triton.language as tl
from triton.runtime import driver
DEVICE = triton.runtime.driver.active.get_active_torch_device()
def is_hip():
    return triton.runtime.driver.active.get_current_target().backend == "hip"


def is_cdna():
    return is_hip() and triton.runtime.driver.active.get_current_target().arch in ('gfx940', 'gfx941', 'gfx942',
                                                                                   'gfx90a', 'gfx908')
def naive_softmax(x):
    """Compute row-wise softmax of X using native pytorch

    We subtract the maximum element in order to avoid overflows. Softmax is invariant to
    this shift.
    """
    # read MN elements ; write M elements
    x_max = x.max(dim=1)[0]
    # read MN + M elements ; write MN elements
    z = x - x_max[:, None]
    # read MN elements ; write MN elements
    numerator = torch.exp(z)
    # read MN elements ; write M elements
    denominator = numerator.sum(dim=1)
    # read MN + M elements ; write MN elements
    ret = numerator / denominator[:, None]
    # in total: read 5MN + 2M elements ; wrote 3MN + 2M elements
    return ret
When implemented naively in PyTorch, computing y = naive_softmax(x) for \(x \in R^{M \times N}\) requires reading \(5MN + 2M\) elements from DRAM and writing back \(3MN + 2M\) elements. This is obviously wasteful; we'd prefer to have a custom "fused" kernel that reads X only once and does all the necessary computations on-chip. Doing so would require reading and writing back only \(MN\) elements each, i.e. \(2MN\) elements in total, so we could expect a theoretical speed-up of ~4x (i.e., \((8MN + 4M) / 2MN\)).
The torch.jit.script flag aims to perform this kind of "kernel fusion" automatically but, as we will see later, it is still far from ideal.
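To make the arithmetic concrete, here is a quick back-of-the-envelope check (plain Python, illustrative shape only) that plugs an example M and N into the element counts derived above:

# Back-of-the-envelope check of the traffic estimate above (illustrative shape only).
M, N = 4096, 1024
naive_traffic = (5 * M * N + 2 * M) + (3 * M * N + 2 * M)  # elements read + written by naive_softmax
fused_traffic = 2 * M * N                                  # fused kernel: read X once, write Y once
print(naive_traffic / fused_traffic)                       # ~4.0, i.e. (8MN + 4M) / 2MN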
Compute Kernel
Our softmax kernel works as follows: each program loads a set of rows of the input matrix X, strided by the number of programs, normalizes them, and writes the results back to the output Y.
Note that one important limitation of Triton is that each block must have a power-of-two number of elements, so we need to internally “pad” each row and guard the memory operations properly if we want to handle any possible input shapes:
@triton.jit
def softmax_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride, n_rows, n_cols, BLOCK_SIZE: tl.constexpr,
                   num_stages: tl.constexpr):
    # starting row of the program
    row_start = tl.program_id(0)
    row_step = tl.num_programs(0)
    for row_idx in tl.range(row_start, n_rows, row_step, num_stages=num_stages):
        # The stride represents how much we need to increase the pointer to advance 1 row
        row_start_ptr = input_ptr + row_idx * input_row_stride
        # The block size is the next power of two greater than n_cols, so we can fit each
        # row in a single block
        col_offsets = tl.arange(0, BLOCK_SIZE)
        input_ptrs = row_start_ptr + col_offsets
        # Load the row into SRAM, using a mask since BLOCK_SIZE may be greater than n_cols
        mask = col_offsets < n_cols
        row = tl.load(input_ptrs, mask=mask, other=-float('inf'))
        # Subtract maximum for numerical stability
        row_minus_max = row - tl.max(row, axis=0)
        # Note that exponentiation in Triton is fast but approximate (i.e., think __expf in CUDA)
        numerator = tl.exp(row_minus_max)
        denominator = tl.sum(numerator, axis=0)
        softmax_output = numerator / denominator
        # Write back output to DRAM
        output_row_start_ptr = output_ptr + row_idx * output_row_stride
        output_ptrs = output_row_start_ptr + col_offsets
        tl.store(output_ptrs, softmax_output, mask=mask)
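As a concrete illustration of the padding and masking described above (a sketch, not part of the tutorial code): for a row of 781 elements, as used in the unit test below, BLOCK_SIZE is rounded up to 1024 and the trailing lanes are disabled by the mask.

# Illustration only: what padding and masking look like for one irregular row length.
n_cols = 781                                 # irregular row length (matches the unit test below)
BLOCK_SIZE = triton.next_power_of_2(n_cols)  # 1024
col_offsets = torch.arange(BLOCK_SIZE)
mask = col_offsets < n_cols                  # lanes 781..1023 are masked out
# Masked lanes are loaded as -inf, so they neither change the row maximum nor
# contribute to the sum (exp(-inf) == 0), leaving the softmax of the real elements intact.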
We can create a helper function that enqueues the kernel and its (meta-)arguments for any given input tensor.
properties = driver.active.utils.get_device_properties(DEVICE.index)
NUM_SM = properties["multiprocessor_count"]
NUM_REGS = properties["max_num_regs"]
SIZE_SMEM = properties["max_shared_mem"]
WARP_SIZE = properties["warpSize"]
target = triton.runtime.driver.active.get_current_target()
kernels = {}
def softmax(x):
    n_rows, n_cols = x.shape

    # The block size of each loop iteration is the smallest power of two greater than the number of columns in `x`
    BLOCK_SIZE = triton.next_power_of_2(n_cols)

    # Another trick we can use is to ask the compiler to use more threads per row by
    # increasing the number of warps (`num_warps`) over which each row is distributed.
    # You will see in the next tutorial how to auto-tune this value in a more natural
    # way so you don't have to come up with manual heuristics yourself.
    num_warps = 8

    # Number of software pipelining stages.
    num_stages = 4 if SIZE_SMEM > 200000 else 2

    # Allocate output
    y = torch.empty_like(x)

    # pre-compile kernel to get register usage and compute thread occupancy.
    kernel = softmax_kernel.warmup(y, x, x.stride(0), y.stride(0), n_rows, n_cols, BLOCK_SIZE=BLOCK_SIZE,
                                   num_stages=num_stages, num_warps=num_warps, grid=(1, ))
    kernel._init_handles()
    n_regs = kernel.n_regs
    size_smem = kernel.metadata.shared
    if is_hip():
        # NUM_REGS represents the number of regular purpose registers. On CDNA architectures this is half of all registers available.
        # However, this is not always the case. In most cases all registers can be used as regular purpose registers.
        # ISA SECTION (3.6.4 for CDNA3)
        # VGPRs are allocated out of two pools: regular VGPRs and accumulation VGPRs. Accumulation VGPRs are used
        # with matrix VALU instructions, and can also be loaded directly from memory. A wave may have up to 512 total
        # VGPRs, 256 of each type. When a wave has fewer than 512 total VGPRs, the number of each type is flexible - it is
        # not required to be equal numbers of both types.
        NUM_GPRS = NUM_REGS
        if is_cdna():
            NUM_GPRS = NUM_REGS * 2

        # MAX_NUM_THREADS represents maximum number of resident threads per multi-processor.
        # When we divide this number with WARP_SIZE we get maximum number of waves that can
        # execute on a CU (multi-processor) in parallel.
        MAX_NUM_THREADS = properties["max_threads_per_sm"]
        max_num_waves = MAX_NUM_THREADS // WARP_SIZE
        occupancy = min(NUM_GPRS // WARP_SIZE // n_regs, max_num_waves) // num_warps
    else:
        occupancy = NUM_REGS // (n_regs * WARP_SIZE * num_warps)
    occupancy = min(occupancy, SIZE_SMEM // size_smem)
    num_programs = NUM_SM * occupancy

    num_programs = min(num_programs, n_rows)

    # Create a number of persistent programs.
    kernel[(num_programs, 1, 1)](y, x, x.stride(0), y.stride(0), n_rows, n_cols, BLOCK_SIZE, num_stages)
    return y
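To see how the occupancy heuristic above plays out, here is a worked example on the CUDA path with made-up resource numbers; the real values come from the device properties and the compiled kernel's metadata:

# Hypothetical numbers, for illustration only (real values are queried at runtime).
NUM_REGS_EX, WARP_SIZE_EX, SIZE_SMEM_EX = 65536, 32, 101376  # per-SM registers, threads per warp, shared memory (bytes)
n_regs_ex, size_smem_ex, num_warps_ex = 32, 16384, 8         # per-thread registers and shared memory of the compiled kernel

occ_regs = NUM_REGS_EX // (n_regs_ex * WARP_SIZE_EX * num_warps_ex)  # register-limited: 65536 // 8192 = 8
occ_smem = SIZE_SMEM_EX // size_smem_ex                              # shared-memory-limited: 101376 // 16384 = 6
occupancy_ex = min(occ_regs, occ_smem)                               # 6 concurrent programs per SM
# On a GPU with, say, 108 SMs the launch would then use min(108 * 6, n_rows) persistent programs.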
Unit Test
We make sure that we test our kernel on a matrix with an irregular number of rows and columns. This will allow us to verify that our padding mechanism works.
torch.manual_seed(0)
x = torch.randn(1823, 781, device=DEVICE)
y_triton = softmax(x)
y_torch = torch.softmax(x, axis=1)
assert torch.allclose(y_triton, y_torch), (y_triton, y_torch)
As expected, the results are identical.
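If you want extra confidence in the masking logic, a quick optional sweep over a few more irregular shapes (not part of the original test) works as well:

# Optional extra check: a few more irregular shapes (not part of the original test).
for m, n in [(7, 5), (128, 1000), (777, 513)]:
    t = torch.randn(m, n, device=DEVICE)
    assert torch.allclose(softmax(t), torch.softmax(t, axis=1))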
Benchmark
Here we will benchmark our operation as a function of the number of columns in the input matrix, assuming 4096 rows. We will then compare its performance against (1) torch.softmax and (2) the naive_softmax defined above.
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=['N'],  # argument names to use as an x-axis for the plot
        x_vals=[128 * i for i in range(2, 100)],  # different possible values for `x_name`
        line_arg='provider',  # argument name whose value corresponds to a different line in the plot
        line_vals=['triton', 'torch', 'naive_softmax'],  # possible values for `line_arg`
        line_names=["Triton", "Torch", "Naive Softmax"],  # label name for the lines
        styles=[('blue', '-'), ('green', '-'), ('red', '-')],  # line styles
        ylabel="GB/s",  # label name for the y-axis
        plot_name="softmax-performance",  # name for the plot. Used also as a file name for saving the plot.
        args={'M': 4096},  # values for function arguments not in `x_names` and `y_name`
    ))
def benchmark(M, N, provider):
    x = torch.randn(M, N, device=DEVICE, dtype=torch.float32)
    stream = getattr(torch, DEVICE.type).Stream()
    getattr(torch, DEVICE.type).set_stream(stream)
    if provider == 'torch':
        ms = triton.testing.do_bench(lambda: torch.softmax(x, axis=-1))
    if provider == 'triton':
        ms = triton.testing.do_bench(lambda: softmax(x))
    if provider == 'naive_softmax':
        ms = triton.testing.do_bench(lambda: naive_softmax(x))
    gbps = lambda ms: 2 * x.numel() * x.element_size() * 1e-9 / (ms * 1e-3)
    return gbps(ms)
benchmark.run(show_plots=True, print_data=True)

softmax-performance:
N Triton Torch Naive Softmax
0 256.0 476.279060 697.239179 201.908777
1 384.0 663.882257 808.823647 258.604873
2 512.0 812.111250 909.045790 297.586692
3 640.0 909.904620 907.390931 329.359739
4 768.0 976.228323 977.483421 346.140834
5 896.0 1040.291805 1041.948757 355.228326
6 1024.0 1080.628141 1074.774861 351.091691
7 1152.0 1099.908856 1078.956050 350.125776
8 1280.0 1132.107861 1105.339218 349.406742
9 1408.0 1162.128358 1140.686185 340.777518
10 1536.0 1190.060476 1165.936415 332.762317
11 1664.0 1217.873645 1189.115616 329.182308
12 1792.0 1223.912892 1193.452613 323.568325
13 1920.0 1257.743903 1225.772221 323.060389
14 2048.0 1276.873186 1245.931568 324.645535
15 2176.0 1239.351740 960.572959 326.302885
16 2304.0 1254.752262 1003.101999 325.833610
17 2432.0 1268.597425 1037.286124 327.737498
18 2560.0 1290.266887 1066.451616 328.210605
19 2688.0 1294.180114 1096.325865 329.650704
20 2816.0 1308.352423 1126.531326 330.656837
21 2944.0 1320.499039 1139.689148 330.694707
22 3072.0 1321.273142 1168.100501 333.362432
23 3200.0 1332.143220 1168.680377 334.547697
24 3328.0 1347.520181 1201.424864 335.841524
25 3456.0 1355.208973 1220.089730 336.925429
26 3584.0 1362.607413 1249.471736 338.657881
27 3712.0 1360.657527 1267.838332 340.279130
28 3840.0 1367.783303 1285.744681 340.845142
29 3968.0 1379.150892 1296.483464 340.777685
30 4096.0 1387.647940 1313.205384 338.369841
31 4224.0 1337.263281 1280.018608 343.198865
32 4352.0 1343.406785 1302.331356 345.357753
33 4480.0 1350.667535 1316.482104 346.212908
34 4608.0 1357.887844 1336.270658 346.415171
35 4736.0 1356.091771 1347.145183 347.449133
36 4864.0 1368.226650 1352.791378 348.957387
37 4992.0 1371.740387 1373.534955 350.234184
38 5120.0 1379.151412 1385.215462 350.902120
39 5248.0 1377.798033 1359.163671 351.781608
40 5376.0 1373.459733 1375.739124 352.175701
41 5504.0 1381.787801 1378.594517 353.957810
42 5632.0 1393.624346 1394.638075 353.226998
43 5760.0 1392.423425 1403.326835 355.023523
44 5888.0 1391.261887 1405.215682 355.427373
45 6016.0 1399.430656 1428.684971 356.350347
46 6144.0 1413.081912 1428.299676 356.885096
47 6272.0 1406.622327 1401.236108 358.223857
48 6400.0 1413.815600 1411.124008 358.235371
49 6528.0 1421.229651 1418.681986 359.490202
50 6656.0 1415.975207 1430.556599 359.673317
51 6784.0 1417.311466 1435.203609 360.652770
52 6912.0 1429.374343 1449.973516 360.877377
53 7040.0 1419.662365 1450.209807 360.952186
54 7168.0 1415.085956 1454.398138 361.290323
55 7296.0 1420.791189 1086.269219 362.186664
56 7424.0 1427.091958 1098.575677 362.558588
57 7552.0 1430.142019 1113.176873 363.122387
58 7680.0 1432.625845 1125.488647 363.952958
59 7808.0 1430.996230 1131.426651 364.403182
60 7936.0 1430.488490 1144.548159 364.627563
61 8064.0 1430.084809 1149.203371 365.030858
62 8192.0 1431.182036 1153.362073 364.164979
63 8320.0 1386.695789 1118.357126 361.699284
64 8448.0 1379.656367 1125.895375 362.541203
65 8576.0 1393.771006 1128.216937 363.134272
66 8704.0 1379.031144 1134.942631 364.213475
67 8832.0 1388.138805 1133.391383 365.084989
68 8960.0 1386.700009 1138.409850 365.093379
69 9088.0 1405.546374 1138.953724 366.222653
70 9216.0 1400.954300 1145.494935 367.502705
71 9344.0 1397.886866 1417.300024 367.211097
72 9472.0 1399.194277 1434.730373 368.065791
73 9600.0 1395.681218 1431.760214 368.566286
74 9728.0 1403.789993 1436.229426 369.128670
75 9856.0 1404.152243 1441.354460 369.521132
76 9984.0 1399.092973 1447.426541 370.020609
77 10112.0 1413.381816 1457.301308 370.536871
78 10240.0 1408.383296 1464.573757 371.143244
79 10368.0 1411.015509 1460.998929 369.602644
80 10496.0 1420.051440 1463.670308 369.642014
81 10624.0 1397.957531 1466.415900 370.499610
82 10752.0 1399.922175 1471.166933 370.811899
83 10880.0 1397.349374 1478.006800 371.034930
84 11008.0 1410.009156 1476.764633 372.235421
85 11136.0 1409.301694 1483.967174 372.567138
86 11264.0 1420.793161 1485.378968 372.779723
87 11392.0 1409.217834 1485.781209 373.352515
88 11520.0 1417.917288 1492.322871 373.460245
89 11648.0 1423.079737 1498.591132 374.109196
90 11776.0 1428.723166 1499.274037 374.939454
91 11904.0 1428.738493 1505.082558 374.533850
92 12032.0 1415.296083 1508.362614 375.383062
93 12160.0 1407.833115 1511.752554 375.333369
94 12288.0 1439.027949 1419.218193 375.294529
95 12416.0 1449.711630 1392.795717 373.843948
96 12544.0 1440.404508 1390.104324 374.478324
97 12672.0 1438.414911 1390.694969 374.251626
In the above plot, we can see that:
Triton is roughly 4x faster than naive_softmax, in line with the theoretical estimate above. This confirms that the naive implementation pays for its redundant DRAM round-trips, which the fused kernel avoids.
Triton is competitive with torch.softmax, and faster for many column counts, in addition to being easier to read, understand and maintain. Note however that the PyTorch softmax operation is more general and will work on tensors of any shape.
Total running time of the script: (0 minutes 35.080 seconds)