Fused Softmax

In this tutorial, you will write a fused softmax operation that is significantly faster than PyTorch’s native op for a particular class of matrices: those whose rows can fit in the GPU’s SRAM.

In doing so, you will learn about:

  • The benefits of kernel fusion for bandwidth-bound operations.

  • Reduction operators in Triton.

Motivations

Custom GPU kernels for elementwise additions are educationally valuable but won’t get you very far in practice. Let us consider instead the case of a simple (numerically stabilized) softmax operation:

import torch

import triton
import triton.language as tl
from triton.runtime import driver

DEVICE = triton.runtime.driver.active.get_active_torch_device()


def is_hip():
    return triton.runtime.driver.active.get_current_target().backend == "hip"


def is_cdna():
    return is_hip() and triton.runtime.driver.active.get_current_target().arch in ('gfx940', 'gfx941', 'gfx942',
                                                                                   'gfx90a', 'gfx908')


def naive_softmax(x):
    """Compute row-wise softmax of X using native pytorch

    We subtract the maximum element in order to avoid overflows. Softmax is invariant to
    this shift.
    """
    # read  MN elements ; write M  elements
    x_max = x.max(dim=1)[0]
    # read MN + M elements ; write MN elements
    z = x - x_max[:, None]
    # read  MN elements ; write MN elements
    numerator = torch.exp(z)
    # read  MN elements ; write M  elements
    denominator = numerator.sum(dim=1)
    # read MN + M elements ; write MN elements
    ret = numerator / denominator[:, None]
    # in total: read 5MN + 2M elements ; write 3MN + 2M elements
    return ret

When implemented naively in PyTorch, computing y = naive_softmax(x) for \(x \in R^{M \times N}\) requires reading \(5MN + 2M\) elements from DRAM and writing back \(3MN + 2M\) elements. This is obviously wasteful; we’d prefer to have a custom “fused” kernel that only reads X once and does all the necessary computations on-chip. Doing so would require reading and writing back only \(2MN\) elements in total, so we could expect a theoretical speed-up of ~4x (i.e., \((8MN + 4M) / 2MN\)). The torch.jit.script flag aims to perform this kind of “kernel fusion” automatically, but it is still far from ideal in practice.
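To make the accounting concrete, here is the same arithmetic evaluated in plain Python for one of the shapes used in the benchmark below (a quick sanity check; the shape is just an example):

M, N = 4096, 1024  # an example shape from the benchmark section
naive_elems = (5 * M * N + 2 * M) + (3 * M * N + 2 * M)  # reads + writes in naive_softmax
fused_elems = 2 * M * N  # a fused kernel reads X once and writes Y once
print(f"naive traffic: {naive_elems:,} elements")
print(f"fused traffic: {fused_elems:,} elements")
print(f"theoretical speed-up: {naive_elems / fused_elems:.2f}x")  # ~4x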

Compute Kernel

Our softmax kernel works as follows: each program loads a set of rows of the input matrix X, strided by the number of programs, normalizes each of them, and writes the results back to the output Y.

Note that one important limitation of Triton is that each block must have a power-of-two number of elements, so we need to internally “pad” each row and guard the memory operations properly if we want to handle any possible input shapes:

@triton.jit
def softmax_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride, n_rows, n_cols, BLOCK_SIZE: tl.constexpr,
                   num_stages: tl.constexpr):
    # starting row of the program
    row_start = tl.program_id(0)
    row_step = tl.num_programs(0)
    for row_idx in tl.range(row_start, n_rows, row_step, num_stages=num_stages):
        # The stride represents how much we need to increase the pointer to advance 1 row
        row_start_ptr = input_ptr + row_idx * input_row_stride
        # The block size is the next power of two greater than or equal to n_cols,
        # so we can fit each row in a single block
        col_offsets = tl.arange(0, BLOCK_SIZE)
        input_ptrs = row_start_ptr + col_offsets
        # Load the row into SRAM, using a mask since BLOCK_SIZE may be greater than n_cols
        mask = col_offsets < n_cols
        row = tl.load(input_ptrs, mask=mask, other=-float('inf'))
        # Subtract maximum for numerical stability
        row_minus_max = row - tl.max(row, axis=0)
        # Note that exponentiation in Triton is fast but approximate (i.e., think __expf in CUDA)
        numerator = tl.exp(row_minus_max)
        denominator = tl.sum(numerator, axis=0)
        softmax_output = numerator / denominator
        # Write back output to DRAM
        output_row_start_ptr = output_ptr + row_idx * output_row_stride
        output_ptrs = output_row_start_ptr + col_offsets
        tl.store(output_ptrs, softmax_output, mask=mask)
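A word on the other=-float('inf') padding above: padded lanes can never win the row maximum, and they exponentiate to exactly zero, so they contribute nothing to the denominator either. The following PyTorch sketch (illustrative only, not part of the kernel) demonstrates the same invariance:

# Padding a row with -inf leaves its softmax unchanged: -inf never affects
# the max, and exp(-inf) == 0 adds nothing to the normalizing sum.
row = torch.randn(781)
padded = torch.full((1024,), -float('inf'))
padded[:781] = row
assert torch.allclose(torch.softmax(row, dim=0), torch.softmax(padded, dim=0)[:781])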

We can create a helper function that enqueues the kernel and its (meta-)arguments for any given input tensor.

properties = driver.active.utils.get_device_properties(DEVICE.index)
NUM_SM = properties["multiprocessor_count"]
NUM_REGS = properties["max_num_regs"]
SIZE_SMEM = properties["max_shared_mem"]
WARP_SIZE = properties["warpSize"]
target = triton.runtime.driver.active.get_current_target()
kernels = {}


def softmax(x):
    n_rows, n_cols = x.shape

    # The block size of each loop iteration is the smallest power of two greater than or equal to the number of columns in `x`
    BLOCK_SIZE = triton.next_power_of_2(n_cols)

    # Another trick we can use is to ask the compiler to use more threads per row by
    # increasing the number of warps (`num_warps`) over which each row is distributed.
    # You will see in the next tutorial how to auto-tune this value in a more natural
    # way so you don't have to come up with manual heuristics yourself.
    num_warps = 8

    # Number of software pipelining stages.
    num_stages = 4 if SIZE_SMEM > 200000 else 2

    # Allocate output
    y = torch.empty_like(x)

    # Pre-compile the kernel to get register usage and compute thread occupancy.
    kernel = softmax_kernel.warmup(y, x, x.stride(0), y.stride(0), n_rows, n_cols, BLOCK_SIZE=BLOCK_SIZE,
                                   num_stages=num_stages, num_warps=num_warps, grid=(1, ))
    kernel._init_handles()
    n_regs = kernel.n_regs
    size_smem = kernel.metadata.shared
    if is_hip():
        # NUM_REGS represents the number of regular-purpose registers. On CDNA architectures this is only half of
        # the total register pool: the other half are accumulation VGPRs, which in most cases can also be used as
        # regular-purpose registers (hence NUM_GPRS = NUM_REGS * 2 for CDNA below).
        # ISA SECTION (3.6.4 for CDNA3)
        # VGPRs are allocated out of two pools: regular VGPRs and accumulation VGPRs. Accumulation VGPRs are used
        # with matrix VALU instructions, and can also be loaded directly from memory. A wave may have up to 512 total
        # VGPRs, 256 of each type. When a wave has fewer than 512 total VGPRs, the number of each type is flexible - it is
        # not required to be equal numbers of both types.
        NUM_GPRS = NUM_REGS
        if is_cdna():
            NUM_GPRS = NUM_REGS * 2

        # MAX_NUM_THREADS represents the maximum number of resident threads per multi-processor.
        # Dividing it by WARP_SIZE gives the maximum number of waves that can execute
        # on a CU (multi-processor) in parallel.
        MAX_NUM_THREADS = properties["max_threads_per_sm"]
        max_num_waves = MAX_NUM_THREADS // WARP_SIZE
        occupancy = min(NUM_GPRS // WARP_SIZE // n_regs, max_num_waves) // num_warps
    else:
        occupancy = NUM_REGS // (n_regs * WARP_SIZE * num_warps)
    occupancy = min(occupancy, SIZE_SMEM // size_smem)
    num_programs = NUM_SM * occupancy

    num_programs = min(num_programs, n_rows)

    # Create a number of persistent programs.
    kernel[(num_programs, 1, 1)](y, x, x.stride(0), y.stride(0), n_rows, n_cols, BLOCK_SIZE, num_stages)
    return y
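To get a feel for the occupancy arithmetic above, here is the CUDA branch evaluated with illustrative numbers (all figures below are hypothetical; the real values come from the device properties and from what the compiler reports for the kernel):

# Hypothetical example of the CUDA occupancy computation above.
NUM_REGS_EX = 65536    # e.g., 64K 32-bit registers per SM (device-dependent)
WARP_SIZE_EX = 32
n_regs_ex = 32         # registers per thread, as reported by kernel.n_regs
num_warps_ex = 8
SIZE_SMEM_EX = 101376  # shared memory per SM (device-dependent)
size_smem_ex = 16384   # shared memory used per program (compiler-reported)

occ = NUM_REGS_EX // (n_regs_ex * WARP_SIZE_EX * num_warps_ex)  # 8 programs fit by registers
occ = min(occ, SIZE_SMEM_EX // size_smem_ex)                    # but only 6 by shared memory
print(occ)  # -> 6; with 108 SMs this would launch 108 * 6 = 648 persistent programs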

Unit Test

We make sure that we test our kernel on a matrix with an irregular number of rows and columns. This will allow us to verify that our padding mechanism works.

torch.manual_seed(0)
x = torch.randn(1823, 781, device=DEVICE)
y_triton = softmax(x)
y_torch = torch.softmax(x, axis=1)
assert torch.allclose(y_triton, y_torch), (y_triton, y_torch)

As expected, the two results agree to within floating-point tolerance.
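For extra confidence, you can sweep a few more irregular shapes; this sketch (shapes chosen arbitrarily) exercises rows both below and above the power-of-two block boundary:

# Optional: check a handful of additional irregular shapes.
for m, n in [(5, 17), (333, 1000), (1823, 2049)]:
    x = torch.randn(m, n, device=DEVICE)
    assert torch.allclose(softmax(x), torch.softmax(x, axis=1))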

Benchmark

Here we will benchmark our operation as a function of the number of columns in the input matrix – assuming 4096 rows. We will then compare its performance against (1) torch.softmax and (2) the naive_softmax defined above.

@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=['N'],  # argument names to use as an x-axis for the plot
        x_vals=[128 * i for i in range(2, 100)],  # different possible values for `x_name`
        line_arg='provider',  # argument name whose value corresponds to a different line in the plot
        line_vals=['triton', 'torch', 'naive_softmax'],  # possible values for `line_arg`
        line_names=["Triton", "Torch", "Naive Softmax"],  # label name for the lines
        styles=[('blue', '-'), ('green', '-'), ('red', '-')],  # line styles
        ylabel="GB/s",  # label name for the y-axis
        plot_name="softmax-performance",  # name for the plot. Used also as a file name for saving the plot.
        args={'M': 4096},  # values for function arguments not in `x_names` and `y_name`
    ))
def benchmark(M, N, provider):
    x = torch.randn(M, N, device=DEVICE, dtype=torch.float32)
    stream = getattr(torch, DEVICE.type).Stream()
    getattr(torch, DEVICE.type).set_stream(stream)
    if provider == 'torch':
        ms = triton.testing.do_bench(lambda: torch.softmax(x, axis=-1))
    if provider == 'triton':
        ms = triton.testing.do_bench(lambda: softmax(x))
    if provider == 'naive_softmax':
        ms = triton.testing.do_bench(lambda: naive_softmax(x))
    gbps = lambda ms: 2 * x.numel() * x.element_size() * 1e-9 / (ms * 1e-3)  # read + write
    return gbps(ms)


benchmark.run(show_plots=True, print_data=True)
[Plot: softmax-performance – GB/s as a function of the number of columns N, for Triton, Torch, and Naive Softmax]
softmax-performance:
          N       Triton        Torch  Naive Softmax
0     256.0   476.279060   697.239179     201.908777
1     384.0   663.882257   808.823647     258.604873
2     512.0   812.111250   909.045790     297.586692
3     640.0   909.904620   907.390931     329.359739
4     768.0   976.228323   977.483421     346.140834
5     896.0  1040.291805  1041.948757     355.228326
6    1024.0  1080.628141  1074.774861     351.091691
7    1152.0  1099.908856  1078.956050     350.125776
8    1280.0  1132.107861  1105.339218     349.406742
9    1408.0  1162.128358  1140.686185     340.777518
10   1536.0  1190.060476  1165.936415     332.762317
11   1664.0  1217.873645  1189.115616     329.182308
12   1792.0  1223.912892  1193.452613     323.568325
13   1920.0  1257.743903  1225.772221     323.060389
14   2048.0  1276.873186  1245.931568     324.645535
15   2176.0  1239.351740   960.572959     326.302885
16   2304.0  1254.752262  1003.101999     325.833610
17   2432.0  1268.597425  1037.286124     327.737498
18   2560.0  1290.266887  1066.451616     328.210605
19   2688.0  1294.180114  1096.325865     329.650704
20   2816.0  1308.352423  1126.531326     330.656837
21   2944.0  1320.499039  1139.689148     330.694707
22   3072.0  1321.273142  1168.100501     333.362432
23   3200.0  1332.143220  1168.680377     334.547697
24   3328.0  1347.520181  1201.424864     335.841524
25   3456.0  1355.208973  1220.089730     336.925429
26   3584.0  1362.607413  1249.471736     338.657881
27   3712.0  1360.657527  1267.838332     340.279130
28   3840.0  1367.783303  1285.744681     340.845142
29   3968.0  1379.150892  1296.483464     340.777685
30   4096.0  1387.647940  1313.205384     338.369841
31   4224.0  1337.263281  1280.018608     343.198865
32   4352.0  1343.406785  1302.331356     345.357753
33   4480.0  1350.667535  1316.482104     346.212908
34   4608.0  1357.887844  1336.270658     346.415171
35   4736.0  1356.091771  1347.145183     347.449133
36   4864.0  1368.226650  1352.791378     348.957387
37   4992.0  1371.740387  1373.534955     350.234184
38   5120.0  1379.151412  1385.215462     350.902120
39   5248.0  1377.798033  1359.163671     351.781608
40   5376.0  1373.459733  1375.739124     352.175701
41   5504.0  1381.787801  1378.594517     353.957810
42   5632.0  1393.624346  1394.638075     353.226998
43   5760.0  1392.423425  1403.326835     355.023523
44   5888.0  1391.261887  1405.215682     355.427373
45   6016.0  1399.430656  1428.684971     356.350347
46   6144.0  1413.081912  1428.299676     356.885096
47   6272.0  1406.622327  1401.236108     358.223857
48   6400.0  1413.815600  1411.124008     358.235371
49   6528.0  1421.229651  1418.681986     359.490202
50   6656.0  1415.975207  1430.556599     359.673317
51   6784.0  1417.311466  1435.203609     360.652770
52   6912.0  1429.374343  1449.973516     360.877377
53   7040.0  1419.662365  1450.209807     360.952186
54   7168.0  1415.085956  1454.398138     361.290323
55   7296.0  1420.791189  1086.269219     362.186664
56   7424.0  1427.091958  1098.575677     362.558588
57   7552.0  1430.142019  1113.176873     363.122387
58   7680.0  1432.625845  1125.488647     363.952958
59   7808.0  1430.996230  1131.426651     364.403182
60   7936.0  1430.488490  1144.548159     364.627563
61   8064.0  1430.084809  1149.203371     365.030858
62   8192.0  1431.182036  1153.362073     364.164979
63   8320.0  1386.695789  1118.357126     361.699284
64   8448.0  1379.656367  1125.895375     362.541203
65   8576.0  1393.771006  1128.216937     363.134272
66   8704.0  1379.031144  1134.942631     364.213475
67   8832.0  1388.138805  1133.391383     365.084989
68   8960.0  1386.700009  1138.409850     365.093379
69   9088.0  1405.546374  1138.953724     366.222653
70   9216.0  1400.954300  1145.494935     367.502705
71   9344.0  1397.886866  1417.300024     367.211097
72   9472.0  1399.194277  1434.730373     368.065791
73   9600.0  1395.681218  1431.760214     368.566286
74   9728.0  1403.789993  1436.229426     369.128670
75   9856.0  1404.152243  1441.354460     369.521132
76   9984.0  1399.092973  1447.426541     370.020609
77  10112.0  1413.381816  1457.301308     370.536871
78  10240.0  1408.383296  1464.573757     371.143244
79  10368.0  1411.015509  1460.998929     369.602644
80  10496.0  1420.051440  1463.670308     369.642014
81  10624.0  1397.957531  1466.415900     370.499610
82  10752.0  1399.922175  1471.166933     370.811899
83  10880.0  1397.349374  1478.006800     371.034930
84  11008.0  1410.009156  1476.764633     372.235421
85  11136.0  1409.301694  1483.967174     372.567138
86  11264.0  1420.793161  1485.378968     372.779723
87  11392.0  1409.217834  1485.781209     373.352515
88  11520.0  1417.917288  1492.322871     373.460245
89  11648.0  1423.079737  1498.591132     374.109196
90  11776.0  1428.723166  1499.274037     374.939454
91  11904.0  1428.738493  1505.082558     374.533850
92  12032.0  1415.296083  1508.362614     375.383062
93  12160.0  1407.833115  1511.752554     375.333369
94  12288.0  1439.027949  1419.218193     375.294529
95  12416.0  1449.711630  1392.795717     373.843948
96  12544.0  1440.404508  1390.104324     374.478324
97  12672.0  1438.414911  1390.694969     374.251626
In the above plot, we can see that:
  • Triton is roughly 4x faster than naive_softmax. This matches the theoretical speed-up from our memory-traffic analysis and confirms that the naive eager-mode implementation performs no fusion here.

  • Triton is on par with torch.softmax, and noticeably faster for many shapes – in addition to being easier to read, understand and maintain. Note however that the PyTorch softmax operation is more general and will work on tensors of any shape.
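If you want to keep the numbers around for later comparison, benchmark.run can also write the plot and a CSV of the measurements to disk via save_path (supported by triton.testing in recent releases; the directory name here is arbitrary):

# Optionally persist the plot and raw measurements to a directory.
benchmark.run(print_data=True, save_path='./softmax-perf')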
