Triton is a Python DSL for writing GPU kernels that compiles to efficient PTX without requiring CUDA C++. It was developed by OpenAI, and it is what torch.compile's default backend (TorchInductor) generates for GPUs: when you run torch.compile on a model, most of the fused kernels it produces are Triton kernels. The core idea is that you write kernels at the tile level rather than the thread level. Instead of specifying what each individual thread does, you specify what each program instance does on a tile of data, and the Triton compiler handles vectorization, shared memory management, and the mapping of work onto threads automatically. This makes Triton significantly easier to write correctly than CUDA C++, and for many memory-bandwidth-bound operations its performance comes within roughly 10–20% of hand-tuned CUDA.
This guide walks through writing practical Triton kernels from scratch: the programming model, a minimal working example, how to handle multi-dimensional data, autotuning, and integrating with PyTorch’s autograd system.
Installation and Environment
Triton is installed automatically as a dependency of PyTorch 2.0+ on Linux with CUDA, so no separate installation is needed if you already use a recent PyTorch build. For standalone use or to get the latest version:
```bash
pip install triton  # standalone install
python -c "import triton; print(triton.__version__)"
```
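Triton needs an NVIDIA GPU with a recent driver; the pip wheels bundle the ptxas assembler, so a full CUDA toolkit is usually not required. Official support targets recent architectures (the project README lists compute capability 8.0+, that is, Ampere and newer), although many kernels also run on older cards such as Turing (T4, RTX 20xx). AMD GPU support exists but is less mature. All Triton kernel compilation happens just-in-time at the first call, with no separate build step, which is one of its major advantages over CUDA C++ extensions.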
The Triton Programming Model
Understanding Triton's abstraction is the key to writing kernels correctly. In CUDA, you think in terms of individual threads: each thread has a unique (blockIdx, threadIdx) and operates on one or a few elements. In Triton, you think in terms of programs and blocks of data: each program instance (analogous to a CUDA thread block) operates on a block of elements called a tile, which in the simple 1D case is a contiguous range. The tile size is a compile-time constant that you pass as a parameter.
The critical primitives are tl.load and tl.store, combined with pointer arithmetic and masks. Instead of indexing individual elements, you compute a block of pointers and load or store all of them in one operation. The mask parameter handles boundary conditions: elements whose mask is False are not loaded or stored, which prevents out-of-bounds access without an explicit if statement in the kernel.
```python
import triton
import triton.language as tl

@triton.jit
def my_kernel(
    x_ptr,                     # pointer to first element of input
    y_ptr,                     # pointer to output
    n_elements,                # total number of elements
    BLOCK_SIZE: tl.constexpr,  # tile size, must be a power of 2
):
    # Each program instance handles one tile
    pid = tl.program_id(axis=0)  # which tile am I?
    # Compute the range of indices this program handles
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    # Mask out-of-bounds accesses
    mask = offsets < n_elements
    # Load input tile
    x = tl.load(x_ptr + offsets, mask=mask)
    # Compute
    y = x * x  # element-wise square
    # Store output tile
    tl.store(y_ptr + offsets, y, mask=mask)
```
The @triton.jit decorator marks this as a Triton kernel. BLOCK_SIZE is annotated as tl.constexpr: it must be a compile-time constant because it fixes the shape of the tiles the kernel works on, which Triton uses to lay out registers and optimize memory access patterns. It must also be a power of 2, because tl.arange only accepts power-of-2 ranges; typical values range from 16 up to a few thousand.
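To make the tile and mask semantics concrete, here is a NumPy model of what the grid of my_kernel programs computes. It is a sketch for intuition, not how Triton executes: the sequential loop over pid stands in for program instances that run in parallel on the GPU, and the helper names (cdiv, emulate_square_kernel) are hypothetical.

```python
import numpy as np

def cdiv(a: int, b: int) -> int:
    """Ceiling division, the same arithmetic triton.cdiv performs."""
    return (a + b - 1) // b

def emulate_square_kernel(x: np.ndarray, block_size: int) -> np.ndarray:
    """Sequential NumPy model of launching my_kernel over a 1D grid."""
    n_elements = x.size
    y = np.empty_like(x)
    for pid in range(cdiv(n_elements, block_size)):  # one iteration per program instance
        offsets = pid * block_size + np.arange(block_size)
        mask = offsets < n_elements                   # lanes past the end are disabled
        tile = x[offsets[mask]]                       # masked load: in-bounds lanes only
        y[offsets[mask]] = tile * tile                # masked store
    return y

x = np.arange(1000, dtype=np.float32)  # 1000 is not a multiple of 256
y = emulate_square_kernel(x, block_size=256)
print(cdiv(x.size, 256))  # 4: the last program has 24 masked-off lanes
```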
Launching the Kernel from Python
Triton kernels are launched by indexing the kernel with a grid, which determines how many program instances to create, analogous to the CUDA grid size:
```python
import torch
import triton

def squared(x: torch.Tensor) -> torch.Tensor:
    # my_kernel indexes a flat buffer, so it needs a contiguous CUDA tensor
    assert x.is_cuda and x.is_contiguous()
    y = torch.empty_like(x)
    n = x.numel()
    # Grid: one program per tile, ceiling division
    grid = lambda meta: (triton.cdiv(n, meta['BLOCK_SIZE']),)
    my_kernel[grid](
        x, y, n,
        BLOCK_SIZE=1024,  # must be a power of 2
    )
    return y

# Test
x = torch.randn(1 << 20, device='cuda')  # 1M elements
y = squared(x)
torch.testing.assert_close(y, x ** 2)
print("correct!")
```
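The grid is a tuple of up to 3 integers, or a callable that receives the kernel's meta-parameters (a dict that includes constexpr arguments such as BLOCK_SIZE) and returns such a tuple. For 1D operations, a single-element tuple suffices. The kernel is compiled and cached on the first call; subsequent calls reuse the compiled binary as long as the constexpr values and the argument types match (dtypes, plus a few specializations such as whether integer arguments are divisible by 16).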
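A grid callable is just a function of the meta-parameter dict. The snippet below mimics how the launcher evaluates it, with a plain dict standing in for the meta-parameters (an assumption for illustration; Triton's real launcher fills in more keys than BLOCK_SIZE).

```python
def cdiv(a: int, b: int) -> int:
    """Ceiling division, matching what triton.cdiv computes."""
    return (a + b - 1) // b

n = 1 << 20  # 1M elements
# Same shape of grid callable as in squared(): it reads BLOCK_SIZE from meta
grid = lambda meta: (cdiv(n, meta['BLOCK_SIZE']),)

for block_size in (256, 1024, 4096):
    print(block_size, grid({'BLOCK_SIZE': block_size}))
```

Because the grid depends on meta rather than on a hard-coded block size, the same launch code keeps working when BLOCK_SIZE is chosen elsewhere, for example by an autotuner.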
A Practical Kernel: Fused Softmax
Softmax is a canonical example where a fused Triton kernel meaningfully outperforms a naive softmax composed of separate PyTorch operations (torch.softmax itself is already a fused kernel). The naive version runs the max reduction, the subtraction, the exponential, the sum reduction, and the division as separate kernels, and each one reads its input from HBM and writes its result back. A fused kernel loads each row into on-chip memory once, performs every step there, and writes the result once, which cuts HBM traffic several-fold.
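The size of that saving can be estimated by counting the elements moved for an M x N input, assuming every naive op reads its inputs from HBM and writes its output back. This accounting mirrors the one in Triton's fused-softmax tutorial; actual PyTorch kernels may differ in the details, and the function names are hypothetical.

```python
def naive_softmax_traffic(M: int, N: int) -> tuple[int, int]:
    """(elements read, elements written) when softmax runs as 5 separate ops."""
    ops = [
        # (reads, writes) per op
        (M * N, M),          # row max
        (M * N + M, M * N),  # x - max, broadcasting the max vector
        (M * N, M * N),      # exp
        (M * N, M),          # row sum
        (M * N + M, M * N),  # exp / sum, broadcasting the sum vector
    ]
    return sum(r for r, _ in ops), sum(w for _, w in ops)

def fused_softmax_traffic(M: int, N: int) -> tuple[int, int]:
    """A fused kernel reads every element once and writes every result once."""
    return M * N, M * N

M, N = 1024, 512  # same shape as the validation input below
ratio = sum(naive_softmax_traffic(M, N)) / sum(fused_softmax_traffic(M, N))
print(f"{ratio:.2f}")  # 4.00, i.e. (8MN + 4M) / 2MN
```

Reads total 5MN + 2M and writes 3MN + 2M for the naive version, against MN each for the fused one, so a bandwidth-bound fused kernel has roughly 4x less data to move.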
```python
import torch
import triton
import triton.language as tl

@triton.jit
def softmax_kernel(
    output_ptr, input_ptr,
    input_row_stride, output_row_stride,
    n_cols,
    BLOCK_SIZE: tl.constexpr,
):
    # Each program handles one row
    row_idx = tl.program_id(0)
    row_start_ptr = input_ptr + row_idx * input_row_stride
    col_offsets = tl.arange(0, BLOCK_SIZE)
    input_ptrs = row_start_ptr + col_offsets
    mask = col_offsets < n_cols
    # Load the row, masking padding with -inf
    row = tl.load(input_ptrs, mask=mask, other=-float('inf'))
    # Subtract max for numerical stability
    row_max = tl.max(row, axis=0)
    row = row - row_max
    # Exponentiate and sum
    numerator = tl.exp(row)
    denominator = tl.sum(numerator, axis=0)
    softmax_output = numerator / denominator
    out_start_ptr = output_ptr + row_idx * output_row_stride
    out_ptrs = out_start_ptr + col_offsets
    tl.store(out_ptrs, softmax_output, mask=mask)

def triton_softmax(x: torch.Tensor) -> torch.Tensor:
    # The kernel adds col_offsets directly, so columns must have unit stride
    assert x.ndim == 2 and x.stride(1) == 1
    n_rows, n_cols = x.shape
    # BLOCK_SIZE must be >= n_cols and a power of 2
    BLOCK_SIZE = triton.next_power_of_2(n_cols)
    y = torch.empty_like(x)
    softmax_kernel[(n_rows,)](
        y, x,
        x.stride(0), y.stride(0),
        n_cols,
        BLOCK_SIZE=BLOCK_SIZE,
    )
    return y

# Validate
x = torch.randn(1024, 512, device='cuda')
torch.testing.assert_close(triton_softmax(x), torch.softmax(x, dim=1), atol=1e-4, rtol=0)
```
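This kernel processes one row per program instance and keeps the entire row in registers, so BLOCK_SIZE must be at least as large as n_cols. That limits how wide a row it can handle: as BLOCK_SIZE grows, register pressure forces spills and performance drops, and Triton also rejects tiles above a fixed maximum number of elements. For very wide rows (vocabulary-size softmax with n_cols=128K+), you need a different approach, such as looping over the row in BLOCK_SIZE chunks inside one program while maintaining a running max and sum (an "online" softmax), or splitting the row across several programs and combining their partial max and sum values in a second step.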
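The chunked online approach can be checked in NumPy before writing the kernel. This is a sketch of the algorithm, not a Triton kernel: each loop iteration corresponds to one masked tl.load of BLOCK_SIZE columns inside a single program, and online_softmax_row is a hypothetical name.

```python
import numpy as np

def online_softmax_row(row: np.ndarray, block_size: int) -> np.ndarray:
    """Softmax of one row, visiting it block_size columns at a time.

    Pass 1 keeps a running max m and a running sum s of exp(x - m); whenever
    a chunk raises the max, the old sum is rescaled by exp(m_old - m_new).
    Pass 2 re-reads each chunk and writes exp(x - m) / s.
    """
    n_cols = row.size
    m, s = -np.inf, 0.0
    for start in range(0, n_cols, block_size):
        chunk = row[start:start + block_size]    # one masked tl.load in a kernel
        m_new = max(m, chunk.max())
        s = s * np.exp(m - m_new) + np.exp(chunk - m_new).sum()
        m = m_new
    out = np.empty_like(row)
    for start in range(0, n_cols, block_size):
        chunk = row[start:start + block_size]
        out[start:start + block_size] = np.exp(chunk - m) / s
    return out

rng = np.random.default_rng(0)
row = 10 * rng.standard_normal(1000)             # large dynamic range
ref = np.exp(row - row.max())
ref /= ref.sum()
print(np.allclose(online_softmax_row(row, block_size=128), ref))  # True
```

The rescaling step is what lets the row be streamed: memory use per program is one chunk plus two scalars, regardless of n_cols, at the cost of reading the input twice.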
Autotuning Block Sizes
Choosing the right BLOCK_SIZE manually is tedious and hardware-dependent. Triton's autotuning decorator automates this by benchmarking a set of configurations and selecting the fastest one at the first call:
```python
@triton.autotune(
    configs=[
        triton.Config({'BLOCK_SIZE': 128}),
        triton.Config({'BLOCK_SIZE': 256}),
        triton.Config({'BLOCK_SIZE': 512}),
        triton.Config({'BLOCK_SIZE': 1024}),
    ],
    key=['n_elements'],  # retune when n_elements changes
)
@triton.jit
def autotuned_kernel(x_ptr, y_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    tl.store(y_ptr + offsets, tl.sqrt(x), mask=mask)
```
The key parameter lists the arguments whose values trigger a retune: different input sizes may have different optimal block sizes, so n_elements is the key. Autotuning adds latency the first time each new key value is seen (it compiles and benchmarks every config and keeps the fastest), but the winning config is cached for subsequent calls with the same key. When launching an autotuned kernel, do not pass BLOCK_SIZE yourself, and use a grid callable that reads meta['BLOCK_SIZE'], since the autotuner chooses the value. For production kernels, a common practice is to autotune during development, record the winning config, and hard-code it to avoid the autotuning overhead at runtime.
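The caching behaviour tied to key can be modeled in plain Python. ToyAutotuner is a hypothetical class that ignores compilation and GPU timing details; it only shows that every config is benchmarked once per new key value and the winner is reused afterwards.

```python
import time
from typing import Callable

class ToyAutotuner:
    """Toy model of autotune caching: benchmark all configs once per key value."""

    def __init__(self, configs: list[dict], key: str, fn: Callable):
        self.configs, self.key, self.fn = configs, key, fn
        self.cache: dict = {}        # key value -> winning config
        self.benchmark_runs = 0

    def _time(self, config: dict, **kwargs) -> float:
        self.benchmark_runs += 1
        start = time.perf_counter()
        self.fn(**kwargs, **config)
        return time.perf_counter() - start

    def __call__(self, **kwargs):
        k = kwargs[self.key]
        if k not in self.cache:      # new key value: time every config
            self.cache[k] = min(self.configs, key=lambda c: self._time(c, **kwargs))
        return self.fn(**kwargs, **self.cache[k])

def work(n_elements: int, BLOCK_SIZE: int) -> int:
    # Stand-in for a kernel launch: returns the number of programs it would use
    return -(-n_elements // BLOCK_SIZE)

tuner = ToyAutotuner([{'BLOCK_SIZE': b} for b in (128, 256, 512, 1024)],
                     key='n_elements', fn=work)
tuner(n_elements=4096)
tuner(n_elements=4096)   # cache hit: no new benchmarks
tuner(n_elements=8192)   # new key value: all 4 configs are timed again
print(tuner.benchmark_runs)  # 8
```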
More advanced configs can specify num_warps (how many warps execute each program instance, default 4) and num_stages (how many stages of software pipelining the compiler uses to prefetch data for upcoming loop iterations, which matters most for kernels with an inner loop, such as the K loop of a matmul on Ampere and Hopper). A typical matmul autotune config might look like triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 64}, num_stages=3, num_warps=8).
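To see what BLOCK_M, BLOCK_N, and BLOCK_K control, here is a sequential NumPy model of a tiled matmul. Each (pid_m, pid_n) pair plays the role of one program instance computing a BLOCK_M x BLOCK_N output tile, and the inner loop walks K in BLOCK_K steps (the loop that num_stages pipelines on a GPU). It illustrates the tiling scheme under those assumptions and is not Triton's generated code; tiled_matmul is a hypothetical name.

```python
import numpy as np

def tiled_matmul(a, b, BLOCK_M=32, BLOCK_N=32, BLOCK_K=16):
    """Sequential NumPy model of a 2D grid of matmul programs."""
    M, K = a.shape
    K2, N = b.shape
    assert K == K2, "inner dimensions must match"
    c = np.empty((M, N), dtype=np.float32)
    for pid_m in range(-(-M // BLOCK_M)):        # grid axis 0: ceil(M / BLOCK_M)
        for pid_n in range(-(-N // BLOCK_N)):    # grid axis 1: ceil(N / BLOCK_N)
            # NumPy slices clip at the array edge, playing the role of the mask
            rows = slice(pid_m * BLOCK_M, (pid_m + 1) * BLOCK_M)
            cols = slice(pid_n * BLOCK_N, (pid_n + 1) * BLOCK_N)
            acc = np.zeros((a[rows].shape[0], b[:, cols].shape[1]), dtype=np.float32)
            for k0 in range(0, K, BLOCK_K):      # the K loop that num_stages pipelines
                acc += a[rows, k0:k0 + BLOCK_K] @ b[k0:k0 + BLOCK_K, cols]
            c[rows, cols] = acc                  # one store per output tile
    return c

rng = np.random.default_rng(0)
a = rng.standard_normal((100, 70)).astype(np.float32)
b = rng.standard_normal((70, 90)).astype(np.float32)
print(np.allclose(tiled_matmul(a, b), a @ b, atol=1e-3))  # True
```

Larger BLOCK_M and BLOCK_N give each program more reuse of every loaded A and B chunk, at the cost of more registers per program, which is one reason the best config is hardware-dependent.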
2D Kernels: Matrix Operations
For 2D operations, Triton supports 2D and 3D program grids. The canonical example is matrix multiplication, but the pattern applies to any operation on 2D tensors:
```python
@triton.jit
def add_matrix_kernel(
    a_ptr, b_ptr, c_ptr,
    M, N,
    stride_am, stride_an,
    stride_bm, stride_bn,
    stride_cm, stride_cn,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    # 2D program grid: one program per (BLOCK_M x BLOCK_N) tile
    pid_m = tl.program_id(axis=0)
    pid_n = tl.program_id(axis=1)
    # Compute offsets for this tile
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    # 2D pointer arithmetic
    a_ptrs = a_ptr + offs_m[:, None] * stride_am + offs_n[None, :] * stride_an
    b_ptrs = b_ptr + offs_m[:, None] * stride_bm + offs_n[None, :] * stride_bn
    c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn
    # Boundary masks
    mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
    a = tl.load(a_ptrs, mask=mask)
    b = tl.load(b_ptrs, mask=mask)
    tl.store(c_ptrs, a + b, mask=mask)

def triton_add(a, b):
    # Strides are passed explicitly, so a and b need not be contiguous
    assert a.ndim == 2 and a.shape == b.shape and a.device == b.device
    M, N = a.shape
    c = torch.empty_like(a)
    BLOCK_M, BLOCK_N = 32, 32
    grid = (triton.cdiv(M, BLOCK_M), triton.cdiv(N, BLOCK_N))
    add_matrix_kernel[grid](
        a, b, c, M, N,
        a.stride(0), a.stride(1),
        b.stride(0), b.stride(1),
        c.stride(0), c.stride(1),
        BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N,
    )
    return c
```
The [:, None] and [None, :] syntax broadcasts 1D offset arrays into 2D tiles — this is idiomatic Triton. Passing strides explicitly (a.stride(0), a.stride(1)) allows the kernel to handle non-contiguous tensors correctly and is best practice even when you expect contiguous inputs.
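The same offset arithmetic can be checked on the CPU with NumPy, treating a flat array as GPU memory. In this sketch, load_tile is a hypothetical helper (not a Triton function), offsets are counted in elements as they are for Triton pointers and torch strides, and the input is a transposed, non-contiguous view, which is exactly the case explicit strides are meant to handle.

```python
import numpy as np

def load_tile(flat, stride_m, stride_n, M, N, pid_m, pid_n, BLOCK_M, BLOCK_N, other=0.0):
    """Gather one (BLOCK_M, BLOCK_N) tile from a flat buffer, like a masked tl.load."""
    offs_m = pid_m * BLOCK_M + np.arange(BLOCK_M)
    offs_n = pid_n * BLOCK_N + np.arange(BLOCK_N)
    ptrs = offs_m[:, None] * stride_m + offs_n[None, :] * stride_n  # 2D tile of offsets
    mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
    tile = np.full((BLOCK_M, BLOCK_N), other, dtype=flat.dtype)
    tile[mask] = flat[ptrs[mask]]   # only in-bounds lanes touch memory
    return tile

base = np.arange(6 * 5, dtype=np.float64).reshape(6, 5)
a = base.T                          # shape (5, 6), a non-contiguous view
# NumPy strides are in bytes; convert to element strides as torch reports them
stride_m, stride_n = (s // a.itemsize for s in a.strides)  # (1, 5)
flat = base.ravel()                 # the underlying memory
tile = load_tile(flat, stride_m, stride_n, *a.shape, pid_m=1, pid_n=1, BLOCK_M=4, BLOCK_N=4)
print(tile[0, :2], a[4, 4:6])       # same values: the boundary tile is read correctly
```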
Autograd Integration
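Triton kernels integrate with PyTorch autograd the same way CUDA extensions do: wrap the forward and backward launches in a torch.autograd.Function. Here the op is squared ReLU, y = relu(x)^2, whose derivative is 2 * relu(x):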
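```python
import torch
import triton
import triton.language as tl
@triton.jit
def squared_relu_fwd_kernel(x_ptr, y_ptr, n, BLOCK_SIZE: tl.constexpr):
    offs = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offs < n
    r = tl.maximum(tl.load(x_ptr + offs, mask=mask), 0.0)  # relu(x)
    tl.store(y_ptr + offs, r * r, mask=mask)                # y = relu(x)^2
@triton.jit
def squared_relu_bwd_kernel(g_ptr, x_ptr, dx_ptr, n, BLOCK_SIZE: tl.constexpr):
    offs = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offs < n
    r = tl.maximum(tl.load(x_ptr + offs, mask=mask), 0.0)
    g = tl.load(g_ptr + offs, mask=mask)
    tl.store(dx_ptr + offs, 2.0 * r * g, mask=mask)         # dy/dx = 2 * relu(x)
class TritonSquaredReLU(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        x = x.contiguous()  # the kernels index a flat, contiguous buffer
        ctx.save_for_backward(x)
        y = torch.empty_like(x)
        n = x.numel()
        grid = lambda meta: (triton.cdiv(n, meta['BLOCK_SIZE']),)
        squared_relu_fwd_kernel[grid](x, y, n, BLOCK_SIZE=1024)
        return y
    @staticmethod
    def backward(ctx, grad_output):
        x, = ctx.saved_tensors
        grad_output = grad_output.contiguous()  # incoming gradients may be strided
        grad_input = torch.empty_like(x)
        n = x.numel()
        grid = lambda meta: (triton.cdiv(n, meta['BLOCK_SIZE']),)
        squared_relu_bwd_kernel[grid](grad_output, x, grad_input, n, BLOCK_SIZE=1024)
        return grad_input
squared_relu = TritonSquaredReLU.apply
```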
Always verify gradients with torch.autograd.gradcheck using double-precision inputs before using the kernel in training. gradcheck compares your backward against finite differences, and its default tolerances assume float64. If you run float32 in production, also compare the Triton gradients against a float32 PyTorch reference implementation, using a tolerance wide enough for float32 rounding.
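gradcheck works by comparing the analytic backward against central finite differences. The NumPy sketch below applies that idea to the squared-ReLU formulas from the kernels above (the helper names are hypothetical); on a GPU the real check would be torch.autograd.gradcheck(squared_relu, (x,)) with x a float64 CUDA tensor created with requires_grad=True.

```python
import numpy as np

def squared_relu_fwd(x):
    return np.maximum(x, 0.0) ** 2              # y = relu(x)^2

def squared_relu_bwd(grad_out, x):
    return grad_out * 2.0 * np.maximum(x, 0.0)  # dy/dx = 2 * relu(x)

def max_fd_error(x, eps=1e-6):
    """Largest gap between the analytic gradient and a central difference.

    The op is elementwise, so its Jacobian is diagonal and perturbing every
    element at once yields all partial derivatives in one evaluation.
    """
    numeric = (squared_relu_fwd(x + eps) - squared_relu_fwd(x - eps)) / (2 * eps)
    analytic = squared_relu_bwd(np.ones_like(x), x)
    return np.abs(numeric - analytic).max()

rng = np.random.default_rng(0)
x64 = rng.standard_normal(1000)
print(max_fd_error(x64) < 1e-6)                     # True
print(max_fd_error(x64.astype(np.float32)) > 1e-3)  # True: float32 rounding swamps the check
```

The float32 line shows why gradcheck wants double precision: with a step of 1e-6, float32 rounding alone produces gradient errors far larger than any real bug you are trying to detect.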
Debugging and Profiling Triton Kernels
Triton kernels are harder to debug than Python code but easier than CUDA C++. The most effective debugging workflow is to run on CPU with a small input first: Triton has an interpreter mode (TRITON_INTERPRET=1 python script.py) that executes the kernel in Python, so you can print intermediate values and use a regular Python debugger. The interpreter is slow, but it lets you inspect the value of any tl variable at any point in the kernel, which is much harder in compiled mode, where you are limited to device-side printing with tl.device_print.
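For performance profiling, triton.testing.do_bench is the standard tool. It warms the kernel up, runs it repeatedly within a fixed time budget (the warmup and rep arguments are in milliseconds), flushes the L2 cache between runs, and returns the execution time in milliseconds: the mean by default in recent versions, or other statistics via return_mode or quantiles: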
```python
import triton.testing

ms = triton.testing.do_bench(lambda: triton_softmax(x))
gbps = x.numel() * x.element_size() * 2 / ms * 1e-6  # read + write
print(f"{ms:.3f} ms, {gbps:.1f} GB/s")
```
Compare the measured GB/s against the device's peak memory bandwidth (listed in the GPU's datasheet) to see how close to the bandwidth roofline you are. A memory-bandwidth-bound kernel hitting 80–90% of peak bandwidth is well optimized; much lower figures suggest inefficient memory access, a poor BLOCK_SIZE, or insufficient occupancy.
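The arithmetic behind that comparison fits in two small helpers (hypothetical names). The timing and peak-bandwidth figures below are made-up placeholders, not measurements; substitute your own do_bench result and your GPU's datasheet value.

```python
def achieved_gbps(num_elements: int, element_size: int, ms: float, passes: int = 2) -> float:
    """Effective bandwidth in GB/s when each element crosses HBM `passes` times."""
    total_bytes = num_elements * element_size * passes  # read + write by default
    return total_bytes / (ms * 1e-3) / 1e9              # bytes/s -> GB/s

def fraction_of_peak(gbps: float, peak_gbps: float) -> float:
    return gbps / peak_gbps

# Placeholder inputs, not measurements: a 1024 x 512 float32 softmax taking
# 0.005 ms on a GPU assumed to have a 1000 GB/s peak.
gbps = achieved_gbps(1024 * 512, 4, ms=0.005)
print(f"{gbps:.1f} GB/s, {fraction_of_peak(gbps, 1000.0):.0%} of peak")
```

This is the same formula as the do_bench snippet above (bytes / ms * 1e-6), just written out in SI units so the conversion is explicit.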
When Triton Is the Right Choice
Triton covers the majority of cases where you'd otherwise reach for a custom CUDA kernel. It's the right tool when you need a fused operation that torch.compile doesn't automatically generate, a novel activation function or normalization variant, or a custom attention mechanism with non-standard masking or scoring. The iteration speed — edit Python, rerun, see results in seconds — makes it substantially more productive than the CUDA C++ compile loop.
The cases where Triton falls short and CUDA C++ is genuinely needed: warp-level primitives (warp shuffle, vote functions) for custom reductions that can't be expressed as tiled reductions; irregular memory access patterns where the tile structure can't be determined at compile time; and operations requiring precise control over shared memory bank conflicts that Triton's automatic shared memory management doesn't handle optimally. For everything else, reach for Triton first. The usual escalation order is torch.compile, then Triton, then CUDA C++, spending more development time only when the previous level is insufficient.
Practical Tips
A few patterns matter in practice. Always use tl.constexpr for any value that affects tile shapes or loop bounds inside the kernel; the compiler needs these to be constants to unroll loops and allocate registers. Keep BLOCK_SIZE a power of 2: tl.arange rejects non-power-of-2 ranges, so for sizes like 1000 you round the tile up with triton.next_power_of_2 and mask the extra lanes. When passing tensor arguments to a Triton kernel, pass the tensor itself rather than calling .data_ptr(): Triton extracts the pointer and its element type automatically, whereas .data_ptr() gives a plain Python int that Triton treats as an integer scalar, losing the dtype information. Use tl.debug_barrier(), which synchronizes all threads of a program instance, to insert synchronization points when debugging suspected race conditions, but remove it in production since it adds overhead. And always benchmark on the target hardware: BLOCK_SIZE choices that are optimal on an A100 may not be optimal on an H100 or a T4, and the performance difference between good and bad tile sizes can be 2–5x.
How Triton Compiles Your Kernel
Understanding what happens when Triton compiles your kernel helps you write better code and diagnose performance issues. The first time you call a @triton.jit function with a given combination of constexpr values and argument types, Triton walks your Python AST and generates Triton IR, lowers it to TritonGPU IR (where data layouts, shared memory allocation, vectorization, and software pipelining are decided), lowers that to LLVM IR, emits PTX, and finally assembles the PTX with ptxas into a cubin containing SASS, the native GPU instruction set. For a moderately complex kernel this can take roughly 1–5 seconds. The compiled artifacts (including the PTX and the cubin) are cached on disk, in ~/.triton/cache by default, keyed by a hash of the kernel source, the constexpr values, the argument specializations, and the target GPU, so subsequent runs load from the cache and launch immediately.
One practical consequence of this compilation model: if you want to force recompilation (for example, to measure cold-start compile time or rule out a stale cache), delete ~/.triton/cache or set the TRITON_CACHE_DIR environment variable to a fresh directory. Another consequence is that constexpr values are baked into the compiled binary: changing BLOCK_SIZE from 256 to 512 produces a different binary, which is why autotuning must compile and benchmark every configuration. The upside is that the compiler knows tile shapes and loop bounds exactly, so it can unroll loops (tl.static_range forces full unrolling of a constexpr-bounded loop) and allocate exactly the registers it needs, often producing code quality comparable to hand-written CUDA.
Triton vs CUDA C++: A Practical Comparison
The choice between Triton and CUDA C++ comes down to three factors: how much control you need over the GPU execution, how much time you have, and what kind of operation you're implementing. Triton's abstraction handles shared memory allocation, warp scheduling, and vectorization automatically. This means you can't directly control bank conflicts in shared memory, can't use warp shuffle intrinsics for custom reductions, and can't implement irregular access patterns where the tile structure is data-dependent. For the vast majority of ML operations — elementwise kernels, reductions, attention variants, normalization, custom activations — none of these limitations matter, and Triton delivers good-to-excellent performance with a fraction of the development effort.
CUDA C++ gives you complete control: explicit shared memory declarations with bank-conflict-free layouts, warp-level primitives for efficient reductions, asynchronous memory copies for software pipelining, and fine-grained control over register allocation. This control is valuable for implementing high-performance GEMM kernels, custom sparse attention patterns with irregular access, or operations that require warp-synchronous communication between threads. The cost is substantially higher implementation complexity, a slower iteration loop (compile-edit-debug instead of edit-run), and more surface area for hard-to-find bugs. In practice, most ML engineers writing custom GPU kernels should start with Triton, move to CUDA C++ only when Triton can't express what's needed or doesn't hit the required performance bar, and use FlashAttention or CUTLASS as CUDA C++ references rather than starting from scratch.
Common Pitfalls
Several mistakes come up repeatedly when writing Triton kernels for the first time. The most common is forgetting the mask when loading near boundaries: if your tensor size isn't a multiple of BLOCK_SIZE, the last tile has out-of-bounds indices, and an unmasked load or store touches memory past the end of the tensor, which can silently corrupt your output or crash with an illegal memory access. Always compute mask = offsets < n_elements and pass it to both tl.load and tl.store. The second is trying to use a non-power-of-2 BLOCK_SIZE: tl.arange only accepts power-of-2 ranges, so such a kernel fails to compile; round the tile up with triton.next_power_of_2 and let the mask cover the extra lanes instead. Third, if you see NaNs or wrong values only for certain input sizes, check the mask on the boundary tile: a wrong mask that never matters for size-1024 inputs can break size-1000 inputs, where the last tile is partially out of bounds. Finally, remember that masked-out lanes of tl.load hold undefined values unless you pass other. If those lanes feed a computation before the store (for example, a reduction), use tl.load(..., mask=mask, other=...) with a value that is neutral for the operation: 0.0 for a sum, -inf for a max.
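The effect of other on a reduction over a partial tile can be modeled in NumPy. masked_load is a hypothetical helper that imitates tl.load's masking; when no other value is given it fills disabled lanes with NaN, as a stand-in for the undefined values a real kernel would see.

```python
import numpy as np

def masked_load(data, offsets, mask, other=None):
    """NumPy stand-in for tl.load: disabled lanes get `other`, or NaN as 'garbage'."""
    fill = np.nan if other is None else other
    out = np.full(offsets.shape, fill, dtype=np.float64)
    out[mask] = data[offsets[mask]]
    return out

data = np.ones(1000)            # size 1000, not a multiple of the tile size
BLOCK_SIZE = 1024               # what triton.next_power_of_2(1000) would return
offsets = np.arange(BLOCK_SIZE)
mask = offsets < data.size      # last 24 lanes disabled

print(np.sum(masked_load(data, offsets, mask)))                  # nan: garbage leaks into the sum
print(np.sum(masked_load(data, offsets, mask, other=0.0)))       # 1000.0
print(np.max(masked_load(-data, offsets, mask, other=-np.inf)))  # -1.0
```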