How to Write a Custom CUDA Kernel for PyTorch

PyTorch’s built-in operators cover the vast majority of deep learning workloads, but there are cases where writing a custom CUDA kernel is the right move: a novel operation with no existing PyTorch equivalent, a fused kernel that eliminates memory round-trips between operations, or a performance-critical path where the overhead of PyTorch’s autograd machinery is measurable. Writing custom CUDA kernels for PyTorch is more accessible than it used to be — the extension mechanism is well-documented, and the iteration loop is fast once your environment is set up. This guide covers the full path from a working CUDA kernel to a differentiable PyTorch operator with proper autograd integration.

When to Write a Custom Kernel

Custom CUDA kernels are worth the investment in a narrow set of cases. The clearest case is a novel operation that doesn’t exist in PyTorch and can’t be efficiently composed from existing operators — a custom activation function, a non-standard attention variant, or a domain-specific transformation. The second case is kernel fusion: combining multiple operations (say, layer norm followed by a linear projection) into a single kernel eliminates the intermediate HBM reads and writes between operations, which matters when memory bandwidth is the bottleneck. The third case is operations with irregular access patterns that PyTorch’s general-purpose operators handle inefficiently — sparse operations, custom indexing schemes, or data-dependent control flow.

Most performance problems are not best solved with custom kernels. torch.compile, whose Inductor backend generates Triton kernels, automatically fuses many common patterns and should be tried first. Triton itself, a Python DSL that handles tiling and vectorization automatically, covers a large fraction of custom kernel use cases without requiring CUDA C++. Reserve hand-written CUDA for cases where Triton can’t express the pattern, or where the performance requirement is tight enough that the extra control of CUDA C++ is worth it.
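
A quick sanity check before writing any CUDA: wrap the composed PyTorch version of your operation in torch.compile and benchmark it. A minimal sketch (gelu_bias is a hypothetical example op; requires PyTorch 2.x):

import torch
import torch.nn.functional as F

def gelu_bias(x, bias):
    # Two ops that eager PyTorch runs as separate kernels
    return F.gelu(x + bias)

compiled = torch.compile(gelu_bias)  # Inductor can fuse the add and the GELU

x = torch.randn(4096, 4096, device='cuda')
bias = torch.randn(4096, device='cuda')
out = compiled(x, bias)  # first call compiles; later calls run the fused kernel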

Setting Up the Extension

PyTorch custom extensions use the torch.utils.cpp_extension module to compile C++/CUDA code and expose it to Python. There are two compilation modes: ahead-of-time (AOT) compilation using setuptools, and just-in-time (JIT) compilation using load_inline or load. JIT compilation is better for development — you edit the kernel, reload in Python, and test immediately without a separate build step. AOT compilation produces a proper Python package suitable for distribution.

The minimal project structure for a custom CUDA extension:

my_extension/
├── setup.py
├── my_op.cpp        # C++ bindings (pybind11)
└── my_op_kernel.cu  # CUDA kernel implementation

The CUDA kernel file contains the actual GPU code. The C++ file defines the Python-callable function that validates inputs, launches the kernel with the right grid/block dimensions, and handles the PyTorch tensor interface. setup.py compiles both together into a shared library that Python can import.

A Minimal Working Kernel

Start with a simple example: a fused multiply-add operation that computes output[i] = a[i] * b[i] + c[i] element-wise. This is trivially achievable with PyTorch operators (it is essentially torch.addcmul), but it illustrates the full pattern before adding complexity.

The CUDA kernel (my_op_kernel.cu). For this minimal example, the kernel, the C++ wrapper, and the bindings all live in the single .cu file:

#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>   // at::cuda::getCurrentCUDAStream
#include <c10/cuda/CUDAException.h>  // C10_CUDA_CHECK
#include <cuda_runtime.h>

// The actual GPU kernel - one thread per element
__global__ void fused_muladd_kernel(
    const float* __restrict__ a,
    const float* __restrict__ b,
    const float* __restrict__ c,
    float* __restrict__ out,
    int n
) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < n) {
        out[idx] = a[idx] * b[idx] + c[idx];
    }
}

// C++ wrapper called from Python
torch::Tensor fused_muladd_cuda(
    torch::Tensor a,
    torch::Tensor b,
    torch::Tensor c
) {
    TORCH_CHECK(a.is_cuda() && b.is_cuda() && c.is_cuda(),
                "all inputs must be CUDA tensors");
    TORCH_CHECK(a.is_contiguous() && b.is_contiguous() && c.is_contiguous(),
                "all inputs must be contiguous");
    TORCH_CHECK(a.sizes() == b.sizes() && a.sizes() == c.sizes(),
                "all inputs must have the same shape");
    TORCH_CHECK(a.scalar_type() == torch::kFloat32, "only float32 is supported");

    auto out = torch::empty_like(a);
    int n = a.numel();  // fine below 2^31 elements; use int64_t indexing beyond that

    const int threads = 256;
    const int blocks = (n + threads - 1) / threads;

    fused_muladd_kernel<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
        a.data_ptr<float>(),
        b.data_ptr<float>(),
        c.data_ptr<float>(),
        out.data_ptr<float>(),
        n
    );

    // Check for kernel launch errors
    C10_CUDA_CHECK(cudaGetLastError());
    return out;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("fused_muladd", &fused_muladd_cuda, "Fused multiply-add (CUDA)");
}

The setup.py:

from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

setup(
    name='my_cuda_op',
    ext_modules=[
        CUDAExtension(
            name='my_cuda_op',
            sources=['my_op_kernel.cu'],  # add 'my_op.cpp' here if bindings live in a separate file
        )
    ],
    cmdclass={'build_ext': BuildExtension}
)

Build and test with:

pip install -e .   # AOT build
python -c "import my_cuda_op, torch; a=torch.ones(1024).cuda(); print(my_cuda_op.fused_muladd(a,a,a)[:5])"

JIT Compilation for Faster Iteration

During development, recompiling with setup.py after every change is slow. torch.utils.cpp_extension.load compiles and loads the extension in-process, caching the compiled binary so subsequent loads are instant if the source hasn’t changed:

from torch.utils.cpp_extension import load

my_op = load(
    name='my_cuda_op',
    sources=['my_op_kernel.cu'],
    verbose=True   # shows compilation output
)

# Use immediately
import torch
a = torch.ones(1024, device='cuda')
result = my_op.fused_muladd(a, a, a)
print(result[:5])  # tensor([2., 2., 2., 2., 2.], device='cuda:0') since 1*1 + 1 = 2

load_inline takes source code as strings — useful for quick experiments without creating files:

from torch.utils.cpp_extension import load_inline

cuda_src = """
__global__ void scale_kernel(float* x, float scale, int n) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) x[i] *= scale;
}

void scale_cuda(torch::Tensor x, float scale) {
    int n = x.numel();
    int threads = 256;
    scale_kernel<<<(n + threads - 1) / threads, threads>>>(
        x.data_ptr<float>(), scale, n);
}
"""
# Declaration only; load_inline generates the pybind11 binding from it
cpp_src = "void scale_cuda(torch::Tensor x, float scale);"

scale_op = load_inline(name='scale_op', cpp_sources=cpp_src,
                       cuda_sources=cuda_src, functions=['scale_cuda'])

Autograd Integration

A kernel that’s only usable in the forward pass isn’t useful for training. Integrating with PyTorch’s autograd requires defining both a forward pass (your CUDA kernel) and a backward pass (either another CUDA kernel or composed from PyTorch ops). The cleanest pattern is torch.autograd.Function:

import torch
import my_cuda_op  # your compiled extension

class FusedMulAdd(torch.autograd.Function):
    @staticmethod
    def forward(ctx, a, b, c):
        # Save the inputs needed for backward (c is not needed: d(a*b+c)/dc = 1)
        ctx.save_for_backward(a, b)
        return my_cuda_op.fused_muladd(a, b, c)

    @staticmethod
    def backward(ctx, grad_output):
        a, b = ctx.saved_tensors
        # d(a*b+c)/da = b, d/db = a, d/dc = 1
        grad_a = grad_output * b
        grad_b = grad_output * a
        grad_c = grad_output.clone()
        return grad_a, grad_b, grad_c

# Convenience wrapper
def fused_muladd(a, b, c):
    return FusedMulAdd.apply(a, b, c)

# Verify gradients against finite differences. gradcheck needs float64,
# so the kernel itself must support double (see the dtype dispatch section below).
a = torch.randn(256, device='cuda', dtype=torch.float64, requires_grad=True)
b = torch.randn(256, device='cuda', dtype=torch.float64, requires_grad=True)
c = torch.randn(256, device='cuda', dtype=torch.float64, requires_grad=True)
torch.autograd.gradcheck(fused_muladd, (a, b, c))

gradcheck numerically verifies that your analytical backward implementation matches finite-difference gradients. Always run it with double precision inputs — the default tolerance assumes float64. A passing gradcheck is a strong indicator that your backward is correct; a failing gradcheck tells you exactly which input’s gradient is wrong.

Thread and Block Indexing

The most common source of CUDA kernel bugs is incorrect thread indexing. The standard 1D indexing pattern — int idx = blockIdx.x * blockDim.x + threadIdx.x — is correct for flat arrays but needs extension for multi-dimensional data. For a 2D matrix of shape [M, N] where you want one thread per element:

__global__ void matrix_kernel(float* data, int M, int N) {
    int row = blockIdx.y * blockDim.y + threadIdx.y;
    int col = blockIdx.x * blockDim.x + threadIdx.x;
    if (row < M && col < N) {
        int idx = row * N + col;  // row-major linear index
        data[idx] = data[idx] * 2.0f;  // placeholder element-wise computation
    }
}

// Launch with 2D grid and block
dim3 threads(16, 16);  // 256 threads per block
dim3 blocks((N + 15) / 16, (M + 15) / 16);
matrix_kernel<<<blocks, threads>>>(data, M, N);

Always include bounds checking (if row < M && col < N) — the grid dimensions are rounded up to multiples of the block size, so the last block may contain threads that map to out-of-bounds indices. Accessing out-of-bounds memory produces silent wrong results or crashes, not Python exceptions.
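
A standard alternative worth knowing (not used in the examples above) is the grid-stride loop, where each thread processes multiple elements so a fixed-size grid handles any input length:

__global__ void gridstride_kernel(const float* in, float* out, int n) {
    // Stride by the total number of launched threads; correct for any n and grid size
    for (int i = blockIdx.x * blockDim.x + threadIdx.x;
         i < n;
         i += blockDim.x * gridDim.x) {
        out[i] = in[i] * 2.0f;  // placeholder computation
    }
}

The loop condition doubles as the bounds check, which is one reason this pattern is popular for simple element-wise kernels.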

Memory Access Patterns and Performance

GPU memory performance is dominated by coalescing: when threads in a warp access contiguous memory addresses, the accesses are combined into a single transaction. Non-coalesced access — where adjacent threads access non-adjacent memory — serializes into multiple transactions and can reduce effective bandwidth by 10–32x. The rule of thumb: thread index should map linearly to the innermost (contiguous) dimension of your data. For row-major tensors, adjacent threads should access adjacent columns in the same row, not adjacent rows in the same column.
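
To make this concrete, here is a sketch of the two access patterns for a row-major [M, N] matrix (illustrative kernels, assuming a 2D grid launch; not from the examples above):

// Coalesced: consecutive threads (threadIdx.x) touch consecutive addresses
// within one row, so a warp's 32 loads combine into a few transactions.
__global__ void coalesced_copy(const float* in, float* out, int M, int N) {
    int row = blockIdx.y;
    int col = blockIdx.x * blockDim.x + threadIdx.x;
    if (row < M && col < N) out[row * N + col] = in[row * N + col];
}

// Non-coalesced: consecutive threads walk down one column, so adjacent threads
// touch addresses N floats apart and each load may become its own transaction.
__global__ void strided_copy(const float* in, float* out, int M, int N) {
    int row = blockIdx.x * blockDim.x + threadIdx.x;
    int col = blockIdx.y;
    if (row < M && col < N) out[row * N + col] = in[row * N + col];
}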

Shared memory is a small, fast on-chip memory (typically 48–164KB per SM on modern GPUs) that all threads in a block can read and write. The classic use case is tiled matrix multiplication: each block loads a tile of the input matrices into shared memory, performs the computation from shared memory (fast), then writes results to global memory. This reduces global memory accesses from O(N^3) to O(N^3/tile_size), which is the key optimization behind fast GEMM implementations. For operations where the same input data is read multiple times by different threads in a block, shared memory is the right tool.
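
A minimal sketch of the tiled pattern, assuming square matrices with N a multiple of the tile size (production GEMMs add bounds handling, double buffering, and register tiling):

#define TILE 16

// Launch with dim3 threads(TILE, TILE) and dim3 blocks(N / TILE, N / TILE)
__global__ void tiled_matmul(const float* A, const float* B, float* C, int N) {
    __shared__ float As[TILE][TILE];
    __shared__ float Bs[TILE][TILE];

    int row = blockIdx.y * TILE + threadIdx.y;
    int col = blockIdx.x * TILE + threadIdx.x;
    float acc = 0.0f;

    for (int t = 0; t < N / TILE; ++t) {
        // Each thread loads one element of each tile; the loads are coalesced
        As[threadIdx.y][threadIdx.x] = A[row * N + t * TILE + threadIdx.x];
        Bs[threadIdx.y][threadIdx.x] = B[(t * TILE + threadIdx.y) * N + col];
        __syncthreads();  // the whole tile must be loaded before anyone reads it

        for (int k = 0; k < TILE; ++k)
            acc += As[threadIdx.y][k] * Bs[k][threadIdx.x];
        __syncthreads();  // everyone must finish reading before the tile is overwritten
    }
    C[row * N + col] = acc;
}

Each element of A and B is read from global memory a factor of TILE fewer times than in the naive version; the repeated reads in the inner loop hit shared memory instead.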

Debugging Kernels

CUDA kernels run asynchronously — by default, a kernel launch returns immediately and the kernel runs concurrently with Python execution. Errors in kernel execution (out-of-bounds access, NaN propagation) don’t raise Python exceptions at the launch site; they surface later, often as cryptic errors on unrelated operations. Two practices that make debugging much easier: always call C10_CUDA_CHECK(cudaGetLastError()) immediately after kernel launches in your C++ wrapper to catch launch-time errors, and during development set CUDA_LAUNCH_BLOCKING=1 in your environment to force synchronous kernel execution, which makes errors appear at the correct line in Python tracebacks.
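
For example (train.py standing in for your own entry point):

CUDA_LAUNCH_BLOCKING=1 python train.py

Note that the variable is read when CUDA initializes, so set it before the process starts (or at least before the first CUDA call); setting it mid-run has no effect.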

NVIDIA’s compute-sanitizer (successor to cuda-memcheck) catches memory errors — out-of-bounds accesses, race conditions, and uninitialized memory reads — with much better diagnostics than raw CUDA error codes. Run it as compute-sanitizer python your_script.py. It adds significant overhead (10–100x slowdown) so use it only during debugging, not profiling. For performance analysis, use NVIDIA Nsight Systems for timeline profiling (which kernels are running when, where are the gaps) and Nsight Compute for kernel-level analysis (memory throughput, compute utilization, warp efficiency, memory access patterns).

Supporting Multiple Dtypes

Hard-coding float32 in your kernel limits usability. PyTorch’s AT_DISPATCH_FLOATING_TYPES macro generates dispatch code that calls your templated kernel for each supported dtype:

// The kernel template (must be defined before the dispatch code that launches it):
template <typename scalar_t>
__global__ void my_kernel(
    const scalar_t* __restrict__ input,
    scalar_t* __restrict__ output,
    int n
) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) output[i] = input[i];  // placeholder element-wise op
}

torch::Tensor my_op_cuda(torch::Tensor input) {
    auto output = torch::empty_like(input);
    int n = input.numel();
    const int threads = 256;
    const int blocks = (n + threads - 1) / threads;

    AT_DISPATCH_FLOATING_TYPES_AND_HALF(
        input.scalar_type(), "my_op_cuda",
        [&]() {
            // scalar_t is defined by the macro for each dispatched dtype
            my_kernel<scalar_t><<<blocks, threads>>>(
                input.data_ptr<scalar_t>(),
                output.data_ptr<scalar_t>(),
                n
            );
        }
    );
    return output;
}

AT_DISPATCH_FLOATING_TYPES_AND_HALF covers float32, float64, and float16. For bfloat16 support (important for modern LLM workloads), use AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, …).
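
The swap is drop-in; only the macro line changes (reusing the wrapper from above):

AT_DISPATCH_FLOATING_TYPES_AND2(
    at::ScalarType::Half, at::ScalarType::BFloat16,
    input.scalar_type(), "my_op_cuda",
    [&]() {
        my_kernel<scalar_t><<<blocks, threads>>>(
            input.data_ptr<scalar_t>(),
            output.data_ptr<scalar_t>(),
            n
        );
    }
);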

From Prototype to Production

A kernel that works in a notebook is not yet production-ready. Before using a custom kernel in a training or serving pipeline: run gradcheck on every input shape and dtype you expect in production; benchmark against the PyTorch reference implementation to verify the speedup justifies the maintenance cost; test at the boundary conditions (batch size 1, very long sequences, non-multiples of your tile size); and verify the kernel’s numerics against the reference — element-wise kernels should match bitwise, while kernels involving reductions typically differ in the last few bits because summation order differs, so compare those within an explicit tolerance. Add the extension as a proper Python package with version pinning — a kernel compiled against PyTorch 2.3 will likely not work with PyTorch 2.5 without recompilation, and silently loading a stale binary produces wrong results without error. Pin the PyTorch version in your requirements.txt alongside the extension version, and rebuild the extension as part of your environment setup rather than distributing pre-compiled binaries.

Understanding the GPU Execution Model

Writing correct and efficient CUDA kernels requires a mental model of how GPU execution actually works. A CUDA kernel is executed by a grid of thread blocks, where each block runs on a single Streaming Multiprocessor (SM). An A100 has 108 SMs; an H100 has 132. Each SM executes threads in groups of 32 called warps. All threads in a warp execute the same instruction simultaneously — this is the SIMT (Single Instruction Multiple Thread) execution model. When threads within a warp take different code paths (warp divergence), the warp serializes the execution of each path, with inactive threads masked out. Minimizing warp divergence is therefore important for performance: if your kernel has conditional branches where about half the threads take each path, you’re running at roughly 50% efficiency on those branches.
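
An illustrative sketch of a data-dependent branch that can diverge (a hypothetical leaky-activation kernel):

__global__ void leaky_kernel(const float* x, float* y, int n) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) {
        if (x[i] > 0.0f)           // data-dependent: threads in one warp may split
            y[i] = x[i];           // path A
        else
            y[i] = 0.01f * x[i];   // path B; the warp executes both paths serially
    }
}

(The bounds check if (i < n) also technically diverges, but only in the final warp, so its cost is negligible.)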

Occupancy is the ratio of active warps per SM to the maximum number of warps the SM can support. Higher occupancy generally means better latency hiding — when one warp is stalled waiting for a memory load, the SM switches to another active warp and executes its instructions. The limiting factors on occupancy are register usage (each thread uses registers from a fixed per-SM register file), shared memory usage (each block consumes shared memory from a fixed per-SM pool), and block size. The NVIDIA CUDA Occupancy Calculator, or the equivalent in Nsight Compute’s occupancy analysis view, shows you the occupancy of your kernel and which resource is the binding constraint. A common optimization is to reduce per-thread register usage (using the __launch_bounds__ qualifier to cap registers, or restructuring the kernel to reduce live variables) to increase occupancy and hide memory latency better.
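
A sketch of the __launch_bounds__ qualifier (the numbers are illustrative, not recommendations):

// Promise at most 256 threads per block and request at least 4 resident blocks
// per SM; the compiler caps per-thread register usage to make that feasible.
__global__ void __launch_bounds__(256, 4)
my_bounded_kernel(const float* in, float* out, int n) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) out[i] = in[i];
}

The second argument (minimum blocks per SM) is optional. If the register cap forces spills to local memory, performance can get worse rather than better, so always measure.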

Memory hierarchy matters enormously for performance. From fastest to slowest: registers (per-thread, effectively free to access), L1 cache/shared memory (per-SM, latency on the order of 20–30 cycles, roughly 19 TB/s aggregate bandwidth on A100), L2 cache (chip-wide, on the order of 200 cycles), and HBM global memory (several hundred cycles, up to ~2 TB/s on A100). Most CUDA optimization work is about moving data to faster memory before it’s needed — loading into shared memory before repeated use, keeping hot data in registers, and structuring access patterns to maximize L1/L2 hit rates. The arithmetic intensity of your kernel — the ratio of floating point operations to bytes of memory traffic — determines whether it’s compute-bound or memory-bound, which in turn determines where optimization effort is best spent.
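
As a worked example, take the fused multiply-add kernel from earlier: each element costs 2 FLOPs (one multiply, one add) and moves 16 bytes (three float32 reads plus one write), an arithmetic intensity of 2/16 = 0.125 FLOP/byte. An A100 offers roughly 19.5 FP32 TFLOP/s against ~2 TB/s of HBM bandwidth, a balance point near 10 FLOP/byte, so this kernel is deeply memory-bound: the arithmetic is effectively free, and only reducing bytes moved (for example, by fusing more operations into the kernel) will make it meaningfully faster.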

Triton as an Alternative

Before committing to writing CUDA C++, evaluate whether Triton can solve your problem. Triton is a Python DSL for GPU kernels that compiles to PTX (NVIDIA’s virtual GPU instruction set), handles tile scheduling and vectorization automatically, and runs from pure Python with no separate C++ compilation step. The iteration loop with Triton — edit kernel, re-run Python, see results — is dramatically faster than the CUDA C++ compile-edit-debug cycle. For many custom operations, especially attention variants, normalization operations, and custom activation functions, Triton produces performance within 10–20% of hand-tuned CUDA C++ with a fraction of the development effort.
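
For a feel of the programming model, here is the fused multiply-add from earlier as a Triton kernel (a sketch; the block size and grid choice are illustrative):

import torch
import triton
import triton.language as tl

@triton.jit
def muladd_kernel(a_ptr, b_ptr, c_ptr, out_ptr, n, BLOCK: tl.constexpr):
    pid = tl.program_id(axis=0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)   # this program's element indices
    mask = offs < n                            # bounds check, as in CUDA
    a = tl.load(a_ptr + offs, mask=mask)
    b = tl.load(b_ptr + offs, mask=mask)
    c = tl.load(c_ptr + offs, mask=mask)
    tl.store(out_ptr + offs, a * b + c, mask=mask)

def fused_muladd(a, b, c):
    out = torch.empty_like(a)
    n = a.numel()
    grid = (triton.cdiv(n, 1024),)             # one program per 1024 elements
    muladd_kernel[grid](a, b, c, out, n, BLOCK=1024)
    return out

Note there is no explicit thread index: each Triton "program" handles a block of elements, and the compiler decides how to map that block onto threads and vector instructions.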

Triton has meaningful limitations: it doesn’t support arbitrary data-dependent control flow (the tiling structure must be known at compile time), doesn’t provide direct access to shared memory for custom synchronization patterns, and doesn’t support warp-level primitives like warp shuffle instructions. Operations that require these capabilities — custom reductions with warp-level synchronization, irregular sparse operations with data-dependent memory access patterns, or kernels that need precise control over shared memory layout for bank conflict avoidance — genuinely require CUDA C++. For everything else, try Triton first. The decision tree is: torch.compile → Triton → CUDA C++, in order of increasing implementation effort and decreasing abstraction level. Move down the chain only when the level above can’t express what you need or doesn’t hit your performance target.

Testing Strategy

Custom CUDA kernels require more thorough testing than Python code because failures are harder to diagnose and can produce silent wrong results instead of exceptions. The minimum testing suite for a custom kernel covers: correctness against a PyTorch reference implementation on random inputs across all supported shapes and dtypes; edge cases including batch size 1, sequence length 1, and sizes that are not multiples of your tile or block size; gradient correctness via gradcheck with double precision; and determinism — running the same inputs twice should produce bitwise-identical outputs unless your kernel intentionally uses non-deterministic atomics. Add a benchmark that compares your kernel’s throughput against the PyTorch reference to catch performance regressions when you modify the implementation. Regression testing is easy to skip when you’re iterating quickly, but it pays off every time you change the kernel and inadvertently introduce a correctness bug or performance regression that would otherwise go unnoticed until it causes problems in a training run hours or days later.
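
A minimal test sketch using pytest and torch.testing.assert_close (the shapes are chosen to include non-multiples of the 256-thread block size; pytest is one harness choice, not mandated by anything above):

import pytest
import torch
import my_cuda_op  # the compiled extension from earlier

@pytest.mark.parametrize("n", [1, 255, 256, 257, 1 << 20])
def test_matches_reference(n):
    a, b, c = (torch.randn(n, device='cuda') for _ in range(3))
    out = my_cuda_op.fused_muladd(a, b, c)
    torch.testing.assert_close(out, a * b + c)  # PyTorch reference

def test_deterministic():
    a, b, c = (torch.randn(1 << 16, device='cuda') for _ in range(3))
    first = my_cuda_op.fused_muladd(a, b, c)
    second = my_cuda_op.fused_muladd(a, b, c)
    assert torch.equal(first, second)  # bitwise-identical across runs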
