How to Optimize Transformer Inference with torch.compile

torch.compile, introduced in PyTorch 2.0, is the highest-impact single-line optimization available for transformer inference. On modern NVIDIA GPUs (A100, H100, RTX 30/40 series), compiled models typically run 20–40% faster than eager mode, with some workloads seeing 50% or more improvement. The improvement is free in the sense that it requires no model architecture changes, no manual CUDA kernel writing, and no quantization trade-offs — just a one-line wrapper around your existing model. Understanding what torch.compile actually does, where it helps most, and where it falls short lets you use it effectively rather than cargo-culting it.

What torch.compile Actually Does

PyTorch’s default execution mode is eager: operations execute one at a time as Python encounters them, with the Python interpreter involved in every step. This has high overhead — kernel launch latency, Python GIL contention, and no opportunity to fuse operations that could be combined into a single GPU kernel.

torch.compile replaces this with a compiled execution path. TorchDynamo captures the computation graph by tracing through your Python code, handling control flow and dynamic shapes. TorchInductor then lowers the graph to optimized CUDA code: it fuses elementwise operations (removing intermediate tensor materializations), selects optimal kernel implementations, and generates Triton kernels for operations that benefit from custom implementations. The result is fewer kernel launches, less memory bandwidth pressure from intermediate tensors, and better utilization of GPU compute resources.
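Operation fusion is easy to see in a toy example. The function below (a hypothetical chain of elementwise ops, not from any real model) would materialize an intermediate tensor after each step in eager mode; under torch.compile, TorchInductor fuses the chain into a single generated kernel:

```python
import torch

def scale_gelu_shift(x):
    # Three elementwise ops in a row. In eager mode each produces its own
    # intermediate tensor; under torch.compile, TorchInductor fuses them into
    # one generated kernel (Triton on GPU, C++ on CPU).
    return torch.nn.functional.gelu(x * 2.0) + 1.0

# Compilation is lazy: the graph is captured and lowered on the first call.
compiled = torch.compile(scale_gelu_shift)

y = scale_gelu_shift(torch.zeros(4))  # eager reference output
```

The compiled version returns the same values (up to floating-point tolerance); the difference is in how many kernels and intermediate buffers the GPU sees.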

The first call to a compiled function is slow — compilation takes 30 seconds to several minutes depending on model size and mode. Subsequent calls run at the compiled speed. For inference serving, this means you warm up the compiled model before accepting traffic, not after. In training, the compilation cost is amortized over thousands of steps. Always benchmark compiled vs. eager on your specific workload after warmup — the speedup varies by model architecture, input shapes, and GPU generation.

Basic Usage

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B")
model = model.to("cuda").eval()

# One line to compile
model = torch.compile(model)

# Warm up before benchmarking or serving (the first calls trigger compilation)
with torch.no_grad():
    dummy = torch.randint(0, 1000, (1, 128), device="cuda")
    for _ in range(3):
        model(dummy)
torch.cuda.synchronize()  # ensure warmup and compilation have finished

# Now benchmark
import time
with torch.no_grad():
    start = time.perf_counter()
    for _ in range(100):
        out = model(dummy)
    torch.cuda.synchronize()  # wait for queued GPU work before stopping the timer
    print(f"{(time.perf_counter() - start) / 100 * 1000:.2f}ms per forward pass")

Compilation Modes

torch.compile has four modes that trade longer compilation time (and, in some modes, stricter runtime requirements) for faster execution.

"default" is the balanced starting point. It applies safe optimizations with reasonable compile time and typically achieves 20–30% speedup on transformer inference. Use this unless you have a specific reason to use another mode.

"reduce-overhead" minimizes kernel launch overhead by using CUDA graphs — capturing a sequence of GPU operations and replaying them with minimal CPU involvement. This is most effective for small models or batch sizes where kernel launch latency is a significant fraction of total time. It requires static input shapes (no dynamic sequence lengths) and has strict requirements about model side effects. On latency-sensitive single-request serving, reduce-overhead can give a 30–50% speedup over default mode.

"max-autotune" searches for optimal kernel configurations by benchmarking multiple implementations at compile time. Compilation takes significantly longer (minutes for large models) but can produce the fastest runtime. It is worth using for production deployments where the compilation cost is paid once and the model runs for a long time, but not for development iteration where you're compiling frequently.

"max-autotune-no-cudagraphs" is max-autotune without CUDA graphs — useful when CUDA graphs' static-shape requirement is incompatible with your workload but you still want the kernel tuning.

Dynamic Shapes and Recompilation

The biggest practical gotcha with torch.compile is recompilation triggered by changing input shapes. By default, torch.compile recompiles whenever it sees a new input shape. For inference serving where sequence lengths vary per request, this means either padding all inputs to a fixed length, bucketing inputs into a small set of shapes, or using dynamic=True to enable dynamic shape compilation.

Padding to a fixed maximum length is the simplest approach and works well when your maximum sequence length is small. For a chatbot where responses are typically under 512 tokens, padding to 512 and using reduce-overhead mode with CUDA graphs is a clean and effective setup.
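A minimal padding helper might look like the sketch below. The helper name, the 512 maximum, and the pad token id of 0 are our assumptions; use your tokenizer's actual pad id in practice:

```python
import torch
import torch.nn.functional as F

MAX_LEN = 512  # fixed maximum chosen for this workload (assumption)

def pad_fixed(input_ids, pad_token_id=0):
    # Right-pad to MAX_LEN so the compiled graph (and CUDA graphs) always
    # see the same shape; the mask marks real tokens vs. padding.
    assert input_ids.shape[1] <= MAX_LEN, "truncate or bucket longer inputs"
    pad_len = MAX_LEN - input_ids.shape[1]
    attention_mask = F.pad(torch.ones_like(input_ids), (0, pad_len), value=0)
    input_ids = F.pad(input_ids, (0, pad_len), value=pad_token_id)
    return input_ids, attention_mask
```

Passing the mask along with the padded ids is what keeps the padded positions from affecting attention.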

Bucketing groups sequence lengths into bins (e.g., 64, 128, 256, 512, 1024) and pads each request to the next bucket size. You pay a one-time compilation cost per bucket and accept mild padding overhead, but avoid recompilation at every new length. This is the most common production approach for variable-length inference.

dynamic=True tells TorchDynamo to generate code that handles variable shapes without recompilation, using symbolic shapes. It’s more flexible but produces less optimized code than static compilation and has higher implementation complexity. For workloads with truly unbounded length variation and strict latency requirements, dynamic compilation is worth evaluating, but benchmark carefully — the overhead relative to static compilation can erode the gains.

# Enable dynamic shapes
model = torch.compile(model, dynamic=True)

# Or use bucketing manually
BUCKETS = [64, 128, 256, 512, 1024, 2048]

def pad_to_bucket(input_ids):
    seq_len = input_ids.shape[1]
    bucket = next((b for b in BUCKETS if b >= seq_len), None)
    if bucket is None:
        raise ValueError(f"sequence length {seq_len} exceeds largest bucket {BUCKETS[-1]}")
    pad_len = bucket - seq_len
    if pad_len > 0:
        # pads with 0 by default; use your tokenizer's pad token id if it differs
        input_ids = torch.nn.functional.pad(input_ids, (0, pad_len))
    return input_ids

Combining torch.compile with Other Optimizations

torch.compile composes well with most other inference optimizations. With Flash Attention: enable attn_implementation="flash_attention_2" before compiling — torch.compile will include the Flash Attention kernels in its compiled graph and may find additional fusion opportunities around them. With quantization: compile after quantizing (apply AWQ or GPTQ first, then torch.compile the quantized model). torch.compile can fuse operations around the dequantization kernels but cannot fuse into the core quantized matmul kernels themselves.

With bf16: always run inference in bf16 when using torch.compile on Ampere+ GPUs. The Triton kernels TorchInductor generates are optimized for bf16 Tensor Core operations. Running in fp32 not only uses more memory but generates less optimized kernel code. Cast your model to bf16 before compiling.

model = model.to(torch.bfloat16)
model = torch.compile(model, mode="max-autotune")

Where torch.compile Helps Most

torch.compile delivers the largest gains on operations with significant elementwise compute — layer norm, activation functions, attention score computation, and the add-and-norm residual connections that appear throughout transformer architectures. These operations are typically memory-bandwidth bound in eager mode; fusing them eliminates intermediate tensor writes and reads that dominate execution time.

The speedup is generally larger for smaller batch sizes and shorter sequences, where kernel launch overhead is a higher fraction of total time, and for GPUs with high memory bandwidth relative to compute (A100s benefit more than older V100s). For very large batch sizes where arithmetic intensity is already high and the GPU is compute-bound, torch.compile gains are smaller but still present from operation fusion.

For autoregressive generation (as opposed to a single forward pass), the per-token decode step is the bottleneck. This step typically runs at batch size 1 or with small batches, with sequence length 1 for the new token — exactly the regime where kernel launch overhead is most significant and torch.compile gains are largest. Compiling the decode step with reduce-overhead mode typically gives a 30–50% latency reduction on single-request generation.
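The decode-step regime can be sketched with a toy module (a hypothetical stand-in, not a real transformer): batch 1, one new token per call, so the shapes are identical every step — exactly what reduce-overhead and CUDA graphs require.

```python
import torch

class TinyDecoder(torch.nn.Module):
    # Hypothetical stand-in for a transformer decode step (illustration only).
    def __init__(self, vocab=100, d=64):
        super().__init__()
        self.embed = torch.nn.Embedding(vocab, d)
        self.proj = torch.nn.Linear(d, vocab)

    def forward(self, token_ids):
        return self.proj(self.embed(token_ids))

model = TinyDecoder().eval()
# reduce-overhead relies on CUDA graphs, so only use it when a GPU is present.
if torch.cuda.is_available():
    model = model.to("cuda")
    decode_step = torch.compile(model, mode="reduce-overhead")
else:
    decode_step = model  # eager fallback on CPU

device = next(model.parameters()).device
with torch.no_grad():
    tok = torch.randint(0, 100, (1, 1), device=device)  # batch 1, one new token
    logits = decode_step(tok)  # shape (1, 1, vocab)
```

In a real serving loop this compiled step is called once per generated token, which is why per-token overhead savings compound into large end-to-end latency reductions.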

Debugging Compilation Issues

When torch.compile produces wrong results or crashes, the first step is verifying the issue is in compilation and not in the base model by running with torch.compile disabled. Set TORCH_COMPILE_DEBUG=1 to get verbose compilation logs. For silent correctness issues, compare outputs between compiled and eager mode on the same inputs — numerical differences beyond floating point tolerance indicate a compilation bug worth reporting.
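The eager-vs-compiled comparison can be a small helper like the sketch below. The function name and the loose, bf16-level tolerances are our choices; tighten them for fp32 models:

```python
import torch

def check_parity(eager_fn, compiled_fn, inputs, rtol=1e-2, atol=1e-2):
    # Run the same inputs through both paths and compare. Small numerical
    # drift from fusion and reordering is expected; differences beyond these
    # tolerances on a bf16 model are worth reporting as a compiler bug.
    with torch.no_grad():
        ref = eager_fn(inputs)
        out = compiled_fn(inputs)
    torch.testing.assert_close(out, ref, rtol=rtol, atol=atol)
```

A typical use is check_parity(model, torch.compile(model), sample_batch) after warmup, on a handful of representative inputs.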

The most common compilation failures come from unsupported Python control flow (data-dependent branches that torch.compile can’t trace through), custom CUDA extensions that aren’t compatible with TorchDynamo’s tracing, and in-place operations that violate the assumptions of the compiled graph. For custom model architectures, testing with fullgraph=False (the default) first lets compilation fall back to eager for unsupported operations, then progressively enabling stricter modes narrows down which operations need adjustment.

When Not to Use torch.compile

For rapid development iteration where you’re changing model code frequently, torch.compile’s compilation latency adds friction. The typical workflow is to develop and debug in eager mode, then enable compilation for benchmarking and production. Don’t compile during unit tests — the compilation time will make your test suite unacceptably slow.

For very large models (70B+) where the forward pass itself takes seconds, the proportional gain from torch.compile is smaller and the compilation time is longer. At 70B scale, other optimizations — tensor parallelism, quantization, Flash Attention — typically have larger absolute impact than torch.compile, and the interaction between torch.compile and tensor parallel execution requires careful testing. For 7B–13B inference on a single GPU, torch.compile is close to a free win and should be a default part of your inference stack.

Benchmarking Correctly

Benchmarking torch.compile gains requires care. The two most common mistakes are measuring compilation time as part of the benchmark (always warm up with at least 3–5 forward passes before timing) and forgetting to call torch.cuda.synchronize() before stopping the timer (GPU operations are asynchronous; without synchronization you’re measuring launch time, not execution time).

A reliable benchmark runs the operation at least 50–100 times after warmup, averages the results, and reports both mean and standard deviation. Variance above 5% usually indicates the benchmark is too short or that GPU throttling or background processes are interfering. Run benchmarks on an otherwise-idle GPU, disable GPU boost clocks if you need reproducible results across runs (nvidia-smi --lock-gpu-clocks), and always benchmark at the batch size and sequence length your production workload actually uses — the speedup profile can be qualitatively different at batch size 1 versus batch size 32.
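The points above fit in a small helper (the function name and defaults are ours; a sketch, not a full harness):

```python
import time
import statistics
import torch

def benchmark_ms(fn, *args, warmup=5, iters=100):
    # Warmup absorbs compilation and autotuning; synchronize so queued GPU
    # work is finished before the timer starts and before each stop.
    with torch.no_grad():
        for _ in range(warmup):
            fn(*args)
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        times = []
        for _ in range(iters):
            t0 = time.perf_counter()
            fn(*args)
            if torch.cuda.is_available():
                torch.cuda.synchronize()
            times.append((time.perf_counter() - t0) * 1000.0)
    mean = statistics.mean(times)
    stdev = statistics.stdev(times)
    if stdev / mean > 0.05:
        print(f"warning: variance {stdev / mean:.1%} — lengthen the run or check for throttling")
    return mean, stdev
```

Call it as benchmark_ms(model, dummy_input) for both the eager and compiled versions, with identical inputs and dtype, and compare the returned means.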

Compare compiled versus eager at the same dtype, the same attention implementation, and the same model configuration. A common mistake is enabling Flash Attention only for the compiled model and attributing the entire speedup to compilation — the improvement needs to be measured with the same optimizations on both sides to isolate what torch.compile itself contributes.

The Bottom Line

For 7B–13B transformer inference on a single A100 or H100 in bf16 with Flash Attention: torch.compile with mode="reduce-overhead" and fixed or bucketed sequence lengths is close to a free performance win. Implement it, warm up properly, verify correctness on a sample of inputs, and benchmark on your actual workload. The 20–40% latency reduction it typically delivers would otherwise require significant engineering effort through manual kernel optimization or architecture changes. At the scale most production LLM serving runs, that translates directly into fewer GPUs required to serve a given request volume — a concrete and meaningful cost reduction.

torch.compile with HuggingFace Generation

Applying torch.compile to autoregressive text generation requires more care than compiling a single forward pass. The generation loop in HuggingFace’s model.generate() involves Python control flow, dynamic stopping conditions, and incremental KV cache updates — all of which interact with torch.compile’s graph capture in non-trivial ways. The most effective approach is to compile the model’s forward method directly rather than the entire generate() call, and let the generation loop remain in Python.

The key constraint for compiled autoregressive generation is static KV cache shapes. The default HuggingFace KV cache grows dynamically as generation proceeds, which triggers recompilation at each new sequence length. To avoid this, pre-allocate a static KV cache of the maximum expected length and use attention masking to handle variable-length prefills. HuggingFace’s StaticCache class implements this pattern for models that support it. With a static KV cache, the compiled forward pass sees the same tensor shapes at every decode step and avoids recompilation.

The combination of torch.compile with reduce-overhead mode and a static KV cache can reduce per-token decode latency by 40–60% compared to eager mode on A100 GPUs. This is the highest-impact configuration for latency-sensitive single-request generation workloads. For batch inference where throughput matters more than per-request latency, the gains are smaller (15–25%) but still meaningful at the scale of production serving.

Measuring Real-World Impact

The benchmarks that appear in torch.compile documentation and research papers are measured under controlled conditions that often differ from production workloads. Common sources of discrepancy: benchmarks run at fixed sequence lengths while production sees variable lengths, benchmarks measure steady-state throughput while production has cold-start compilation latency, and benchmarks run on lightly loaded GPUs while production runs with memory pressure from concurrent requests.

The only reliable way to quantify torch.compile’s impact on your specific workload is to A/B test compiled and eager serving under realistic load. Deploy both behind a load balancer, route a fraction of traffic to each, and measure p50 and p99 latency along with throughput. Compilation variability — different input shapes triggering recompilation at unpredictable times — shows up as latency spikes in p99 that don’t appear in steady-state benchmarks. If your SLA requires consistent p99 latency, test explicitly for recompilation-triggered spikes before committing to torch.compile in production.

For models served with dynamic batching (grouping concurrent requests into batches for throughput), torch.compile’s static shape requirement interacts with the batching strategy. Dynamic batching produces variable batch sizes, triggering recompilation at each new batch size unless you either pad all batches to a fixed maximum or use dynamic=True. Benchmark all three options — fixed batch size with padding, dynamic=True, and no compilation — to find the best trade-off for your throughput and latency targets. The right answer is workload-specific and often surprising.
