Delivering the Missing Building Blocks for NVIDIA CUDA Kernel Fusion in Python

C++ libraries like CUB and Thrust provide high-level building blocks that enable NVIDIA CUDA application and library developers to write speed-of-light code that is portable across architectures. Many widely used projects, such as PyTorch, TensorFlow, XGBoost, and RAPIDS, use these abstractions to implement core functionality.

The same abstractions are missing in Python. There are high-level array and tensor libraries such as CuPy and PyTorch, and low-level kernel authoring tools like numba.cuda. However, the lack of “building blocks” forces Python library developers to drop down to C++ to implement custom algorithms.

Introducing cuda.cccl

cuda.cccl provides Pythonic interfaces to the CUDA Core Compute Libraries CUB and Thrust. Instead of using C++ or writing complex CUDA kernels from scratch, you can now compose algorithms that deliver the best performance across different GPU architectures.

cuda.cccl is composed of two libraries:

  • parallel provides composable algorithms that act on entire arrays, tensors, or data ranges (iterators).
  • cooperative enables you to write fast, flexible numba.cuda kernels by providing algorithms that act on blocks or warps.

This post introduces the parallel library.

A simple example: custom reduction

To show what cuda.cccl can do, here’s a toy example that combines building blocks from parallel to compute the alternating sum 1 - 2 + 3 - 4 + … ± N.
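Written out, the target is the alternating sum below. Grouping consecutive terms into pairs (1 - 2) + (3 - 4) + … shows that each pair contributes -1. That gives a closed form that is handy for checking the GPU result. For example, N = 10,000,000 should produce -5,000,000.

```latex
S(N) = \sum_{k=1}^{N} (-1)^{k+1}\,k = 1 - 2 + 3 - 4 + \cdots + (-1)^{N+1} N
     = \begin{cases}
         \dfrac{N+1}{2}, & N \text{ odd},\\[4pt]
         -\dfrac{N}{2},  & N \text{ even}.
       \end{cases}
```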

See the full code example.

import cupy as cp
import numpy as np
# assumed import path (experimental namespace in early cuda-cccl releases);
# check the docs for the version you have installed
import cuda.cccl.parallel.experimental as parallel

# define some simple Python functions that we'll use later
def add(x, y):
    return x + y


def transform(x):
    # negate even numbers: 1, 2, 3, 4, ... -> 1, -2, 3, -4, ...
    return -x if x % 2 == 0 else x


num_items = 10_000_000                   # N, the number of terms to sum
initial_value = np.array([0], np.int32)  # starting value for the sum

# create a counting iterator to represent the sequence 1, 2, 3, ... N
counts = parallel.CountingIterator(np.int32(1))

# create a transform iterator to represent the sequence 1, -2, 3, ... ±N
seq = parallel.TransformIterator(counts, transform)

# create a reducer object for computing the sum of the sequence
out = cp.empty(1, cp.int32)  # holds the result
reducer = parallel.reduce_into(seq, out, add, initial_value)

# compute the amount of temporary storage needed for the
# reduction, and allocate a tensor of that size
temp_storage_size = reducer(None, seq, out, num_items, initial_value)
temp_storage = cp.empty(temp_storage_size, cp.uint8)

# compute the sum, passing in the required temporary storage
reducer(temp_storage, seq, out, num_items, initial_value)
print(out)  # out contains the result
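If you don't have a GPU available, the following pure-Python sketch runs the same pipeline on the CPU and can also act as a reference for checking the GPU result. It is only an analogy and is not part of cuda.cccl. In it, `itertools.count`, `map`, and `functools.reduce` stand in for `CountingIterator`, `TransformIterator`, and `reduce_into`, and the helper names `alternating_sum_reference` and `alternating_sum_closed_form` are hypothetical.

```python
import functools
import itertools
import operator

# CPU analogy of the GPU pipeline (hypothetical helper names):
#   itertools.count  ~ parallel.CountingIterator  (lazy 1, 2, 3, ...)
#   map(transform)   ~ parallel.TransformIterator (lazy 1, -2, 3, ...)
#   functools.reduce ~ parallel.reduce_into       (combine with add)


def transform(x):
    # Same rule as the GPU example: negate even numbers.
    return -x if x % 2 == 0 else x


def alternating_sum_reference(num_items):
    counts = itertools.count(1)                 # 1, 2, 3, ... (never stored as a list)
    seq = map(transform, counts)                # 1, -2, 3, ... (still lazy)
    first_n = itertools.islice(seq, num_items)  # stop after num_items terms
    return functools.reduce(operator.add, first_n, 0)


def alternating_sum_closed_form(n):
    # Pair up (1 - 2) + (3 - 4) + ...; each pair contributes -1.
    return (n + 1) // 2 if n % 2 else -(n // 2)


for n in (1, 2, 3, 4, 1000, 1001):
    assert alternating_sum_reference(n) == alternating_sum_closed_form(n)

print(alternating_sum_reference(10_000))  # -5000
```

As on the GPU, no stage stores the whole sequence. Each term is produced when the reduction asks for it.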

Is it fast? 

Let’s time the algorithm we just built with parallel against a naive implementation that uses CuPy’s array operations. These timings were measured on an NVIDIA RTX 6000 Ada Generation GPU. See the comprehensive benchmarking script.

Here are the timings using array operations: 

seq = cp.arange(1, 10_000_000)

%timeit cp.cuda.runtime.deviceSynchronize(); (seq * (-1) ** (seq + 1)).sum(); cp.cuda.runtime.deviceSynchronize()
690 μs ± 266 ns per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

Here are the timings for the algorithm we built with parallel:

# reuses transform, reducer, out, initial_value, and num_items from the example above
seq = parallel.TransformIterator(parallel.CountingIterator(np.int32(1)), transform)

def parallel_reduction(size=num_items):
    # query the temporary storage size, allocate it, then run the reduction
    temp_storage_size = reducer(None, seq, out, size, initial_value)
    temp_storage = cp.empty(temp_storage_size, dtype=cp.uint8)
    reducer(temp_storage, seq, out, size, initial_value)
    return out

%timeit cp.cuda.runtime.deviceSynchronize(); parallel_reduction(); cp.cuda.runtime.deviceSynchronize()
28.3 μs ± 793 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

Combining parallel iterators and algorithms makes this reduction roughly 24× faster than the naive approach on this GPU (28.3 μs versus 690 μs per call).

Where does the speedup come from?

Many core CUDA operations in CuPy use CUB and Thrust—the same C++ libraries that parallel exposes to Python. Why do we see better performance using parallel?

There’s no magic here. parallel gives you more control and flexibility when writing general-purpose algorithms. In particular:

  • Less memory allocation: A major advantage of parallel is the ability to use iterators like CountingIterator and TransformIterator as inputs to algorithms like reduce_into. Iterators can represent sequences without allocating memory for them. 
  • Explicit kernel fusion: Using iterators in this way “fuses” all the work into a single kernel—the naive CuPy code snippet launches four kernels (see if you can spot them all). The custom algorithm we built using parallel combines them into a single reduction.

    Note that this is different from the fusion done through @torch.compile, for example. It’s explicit, meaning you control how things are fused, rather than implicit, where the compiler controls fusion. This control enables you to fuse operations where the compiler might fail to do so.
  • Less Python overhead: Finally, parallel is a lower-level library and a thin layer on top of the underlying CUB/Thrust functionality. By using parallel, you don’t have to jump through multiple layers of Python before invoking device code.  
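To make the memory cost of the naive version concrete, the sketch below splits the CuPy expression `(seq * (-1) ** (seq + 1)).sum()` into separate steps and runs it with NumPy on the CPU so each intermediate is visible. It assumes each elementwise step and the final sum map to one of the four kernels mentioned above. That is a plausible reading of the CuPy snippet, not something measured here, and the size of 10,000 elements is chosen only for illustration.

```python
import numpy as np

# The naive expression, one step at a time. With CuPy on the GPU, each step
# below is a separate operation; every elementwise step writes a full N-element
# temporary array to memory before the next step can read it.
seq = np.arange(1, 10_001, dtype=np.int64)  # 1, 2, ..., 10_000 (materialized)

t1 = seq + 1       # step 1: elementwise add      -> new N-element array
t2 = (-1) ** t1    # step 2: elementwise power    -> new N-element array
t3 = seq * t2      # step 3: elementwise multiply -> new N-element array
total = t3.sum()   # step 4: reduction            -> single value

temporaries = (t1, t2, t3)
extra_bytes = sum(t.nbytes for t in temporaries)
print(total)        # -5000
print(extra_bytes)  # 240000 (3 arrays x 10_000 elements x 8 bytes)
```

In contrast, the iterator-based version computes each term on the fly inside a single reduction. It never allocates the input sequence or any of these temporaries.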

Who is cuda.cccl for?

Figure 1. Architecture of CUDA-enabled Python packages like PyTorch and CuPy today, with a gap between user extensions and the CUDA C++ libraries (left), and the same stack with cuda.cccl filling that gap (right)

The goal of cuda.cccl isn’t to replace CuPy, PyTorch, or any existing Python libraries. Instead, it’s meant to make it easier to implement such libraries, or to extend them and implement custom operations with CuPy arrays or PyTorch tensors more efficiently.

In particular, look to cuda.cccl when:

  • Building a custom algorithm that can be composed from simpler ones, such as reduce, scan, and transform.
  • Creating and operating on sequences without allocating any memory for them (using iterators).
  • Defining and operating on custom “structured” data types that are composed of simpler data types. We have an example of how to do this.
  • Currently using CUDA C++ to write custom Python bindings to Thrust or CUB abstractions. With cuda.cccl you can use these capabilities directly from Python.

The cuda.cccl APIs are intentionally low-level and closely mimic the underlying C++ designs. This keeps them as lightweight and low-overhead as possible while exposing the same powerful building blocks used internally by many libraries like CuPy and PyTorch.

Next steps

Now that you have a taste of cuda.cccl and its capabilities, give it a try. Installation is a single pip command. 

pip install cuda-cccl

Next, check out our docs and examples, and report any issues or feature requests on our GitHub repository.
