How to Use torch.profiler to Find Training Bottlenecks

torch.profiler is PyTorch’s built-in performance analysis tool, and many ML engineers who have heard of it have never used it beyond the basic tutorial. That’s a mistake. The profiler surfaces bottlenecks that loss curves cannot show and that aggregate GPU utilization metrics only hint at: kernel-level inefficiencies, CPU-GPU synchronization stalls, DataLoader gaps, memory allocation overhead, and operator fusion opportunities. This guide covers how to use torch.profiler effectively in a real training loop, how to interpret what it tells you, and how to translate profiler output into concrete optimizations.

Basic Setup

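The profiler records operator-level timing and memory events for the code it wraps. It can be used as a context manager (with profile(...) as prof:) or started and stopped explicitly with prof.start() and prof.stop(), as below. In both forms, call prof.step() once per training step so the schedule advances. The minimal setup that gives useful signal: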

import torch
from torch.profiler import profile, record_function, ProfilerActivity, schedule

prof = profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
    schedule=schedule(wait=1, warmup=1, active=3, repeat=1),
    on_trace_ready=torch.profiler.tensorboard_trace_handler('./profiler_logs'),
    record_shapes=True,
    profile_memory=True,
    with_stack=True,
)

# model, criterion, optimizer, and dataloader are assumed to be defined already
prof.start()
for step, (inputs, labels) in enumerate(dataloader):
    inputs, labels = inputs.cuda(), labels.cuda()
    with record_function("forward"):
        outputs = model(inputs)
    with record_function("loss"):
        loss = criterion(outputs, labels)
    with record_function("backward"):
        loss.backward()
    with record_function("optimizer"):
        optimizer.step()
        optimizer.zero_grad()
    prof.step()
    if step >= 5:  # the schedule spans 5 steps (wait + warmup + active); stop once it has completed
        break
prof.stop()

The schedule parameter controls which steps are profiled. wait=1 leaves the profiler off for the first step, avoiding startup noise such as CUDA context initialization and allocator warm-up. warmup=1 turns tracing on but discards the results, so the profiler’s own start-up overhead does not skew the measurements. active=3 records 3 steps of real data. This matters because profiling every step adds overhead and generates enormous trace files, while 3 representative steps give enough signal to identify bottlenecks without overwhelming the analysis tooling.
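A schedule is just a function from step index to a ProfilerAction, so you can call it directly to check what it will do before a long run. The sketch below decodes the schedule used above. It then runs a toy CPU-only loop with a custom on_trace_ready callback to show that, with repeat=2, the callback fires once per completed cycle. The callback name and the matrix workload are illustrative placeholders.

```python
import torch
from torch.profiler import ProfilerActivity, profile, schedule

# Decode the schedule used above: step index -> profiler action.
sched = schedule(wait=1, warmup=1, active=3, repeat=1)
actions = [sched(step) for step in range(6)]
for step, action in enumerate(actions):
    print(step, action.name)

# Illustrative callback: on_trace_ready receives the profiler object
# each time an active cycle finishes.
completed_cycles = []


def on_cycle_done(prof):
    completed_cycles.append(prof.step_num)
    print(prof.key_averages().table(sort_by="self_cpu_time_total", row_limit=5))


# Toy CPU-only workload: 8 steps = 2 cycles of (1 wait + 1 warmup + 2 active).
x = torch.randn(256, 256)
with profile(
    activities=[ProfilerActivity.CPU],
    schedule=schedule(wait=1, warmup=1, active=2, repeat=2),
    on_trace_ready=on_cycle_done,
) as prof:
    for _ in range(8):
        x = torch.tanh(x @ x.T)  # tanh keeps values bounded across iterations
        prof.step()
```

The decoded actions for the first schedule are NONE, WARMUP, RECORD, RECORD, RECORD_AND_SAVE, and then NONE once the single cycle is done, which is why the loop above can stop after a handful of steps.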

Reading the TensorBoard Trace

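The TensorBoard profiler plugin (pip install torch-tb-profiler) renders the trace as an interactive timeline. Launch it with tensorboard --logdir ./profiler_logs and open the PyTorch Profiler tab. The plugin has since been deprecated, but the .pt.trace.json files written by tensorboard_trace_handler are ordinary Chrome-trace JSON and also open in Perfetto. The timeline view shows CPU and CUDA events as horizontal bars across time, where each bar is an operator or kernel execution. The key things to look for: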

GPU idle gaps. White space on the CUDA timeline between kernel executions means the GPU is waiting — either for a CPU operation to complete (synchronization), for data to transfer from CPU to GPU, or for the DataLoader to supply the next batch. Any visible gap is wasted GPU time. A healthy training step has a nearly solid CUDA timeline with minimal white space.

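CPU-GPU synchronization stalls. A synchronization makes the CPU wait until all queued GPU work has finished, so it can no longer run ahead and queue the next kernels. In the trace this typically shows up as a long cudaStreamSynchronize, cudaDeviceSynchronize, or cudaMemcpyAsync bar on the CPU thread, followed by a gap on the CUDA timeline while the CPU refills the launch queue. Common causes: calling .item(), .cpu(), .tolist(), or .cpu().numpy() on a CUDA tensor inside the training loop, printing loss tensors every step, or running Python-side metric calculations that pull CUDA tensor values back to the host.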

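DataLoader time. The DataLoader iterator’s __next__ call is recorded as an event named enumerate(DataLoader)#_MultiProcessingDataLoaderIter.__next__ (or the _SingleProcessDataLoaderIter variant when num_workers=0), so batch fetch time is visible on the CPU thread between the end of one optimizer step and the start of the next forward pass. If this time is large relative to the forward+backward time, the DataLoader is the bottleneck, not the model computation. Fix: increase num_workers, add persistent_workers=True, enable pin_memory=True, or cache preprocessed data.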
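These fixes map directly onto DataLoader arguments. The sketch below assumes a map-style dataset; make_train_loader is a hypothetical helper, and num_workers=4 is an illustrative starting point rather than a tuned value. pin_memory is only enabled when CUDA is present, and persistent_workers is only valid when num_workers > 0, so the helper guards both.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


def make_train_loader(dataset, batch_size=32, num_workers=4):
    """Hypothetical helper: a DataLoader configured to keep the GPU fed.

    num_workers=4 is only a starting point; tune it against the profiler trace.
    """
    use_workers = num_workers > 0
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        # Keep worker processes alive across epochs instead of re-spawning them
        # (PyTorch raises an error if this is True with num_workers=0).
        persistent_workers=use_workers,
        # Page-locked host memory enables faster, asynchronous host-to-device copies.
        pin_memory=torch.cuda.is_available(),
    )


# Toy dataset purely for illustration.
dataset = TensorDataset(torch.randn(256, 16), torch.randint(0, 2, (256,)))
loader = make_train_loader(dataset)
```

With pinned batches, calling inputs.cuda(non_blocking=True) in the training loop lets the host-to-device copy overlap with GPU work. Re-profile afterward to confirm that the DataLoader gap actually shrank.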

The Key Tables: Self CPU Time vs CUDA Time

The profiler’s table view (prof.key_averages().table()) shows each operator’s self CPU time (time spent in the operator itself, excluding children), total CPU time (including children), and CUDA time. The operators you want to focus on are those with high CUDA time — these are the compute-intensive kernels driving your training runtime:

# Print top 20 operators by CUDA time
print(prof.key_averages().table(
    sort_by="cuda_time_total",
    row_limit=20
))

# Print memory usage by operator
print(prof.key_averages().table(
    sort_by="self_cuda_memory_usage",
    row_limit=20
))
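The same tables can be produced on a machine without a GPU, which is a cheap way to learn to read them. The CPU-only sketch below uses a toy two-layer MLP as an arbitrary stand-in for a real model. It sorts by self CPU time instead of CUDA time, groups rows by input shape (this requires record_shapes=True), and pulls the top rows out as FunctionEventAvg objects instead of reading the printed table.

```python
import torch
from torch.profiler import ProfilerActivity, profile, record_function

# Toy model and batch; substitute your own.
model = torch.nn.Sequential(torch.nn.Linear(128, 256), torch.nn.ReLU(), torch.nn.Linear(256, 10))
x = torch.randn(64, 128)

with profile(activities=[ProfilerActivity.CPU], record_shapes=True) as prof:
    with record_function("forward"):
        model(x)

averages = prof.key_averages()
print(averages.table(sort_by="self_cpu_time_total", row_limit=10))

# Split each operator's row by input shape, to see whether an op is slow only for certain sizes.
print(prof.key_averages(group_by_input_shape=True).table(sort_by="cpu_time_total", row_limit=10))

# Read rows programmatically; times are reported in microseconds.
top = sorted(averages, key=lambda evt: evt.self_cpu_time_total, reverse=True)[:5]
for evt in top:
    print(f"{evt.key:30s} self_cpu={evt.self_cpu_time_total:.1f}us calls={evt.count}")
```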

Operators with high CPU time but low CUDA time usually mean the GPU is sitting idle while Python and dispatcher overhead runs on the CPU. This is common with small models or small batch sizes, where launching each operator costs about as much as the computation it launches. The fix is typically operator fusion (torch.compile), larger batch sizes to amortize dispatch overhead, or moving preprocessing logic to the GPU.
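Dispatch overhead is easy to see by counting operator calls. The toy CPU-only comparison below applies the same elementwise math to 100 tiny tensors one at a time and then to the same data stacked into a single batch. op_stats is a hypothetical helper. The call counts are deterministic, while the timings vary by machine.

```python
import torch
from torch.profiler import ProfilerActivity, profile


def op_stats(fn, op_name):
    """Profile fn on the CPU; return (call count, total self CPU time in us) for op_name."""
    with profile(activities=[ProfilerActivity.CPU]) as prof:
        fn()
    matches = [evt for evt in prof.key_averages() if evt.key == op_name]
    calls = sum(evt.count for evt in matches)
    self_us = sum(evt.self_cpu_time_total for evt in matches)
    return calls, self_us


# 100 tiny tensors versus the same data stacked into one batch.
chunks = [torch.randn(8) for _ in range(100)]
stacked = torch.stack(chunks)


def per_chunk():
    return [c * 2.0 + 1.0 for c in chunks]  # 100 multiplies and 100 adds


def batched():
    return stacked * 2.0 + 1.0  # one multiply and one add for the same math


small_calls, small_us = op_stats(per_chunk, "aten::mul")
batch_calls, batch_us = op_stats(batched, "aten::mul")
print(f"per-chunk: {small_calls} aten::mul calls, {small_us:.0f} us self CPU")
print(f"batched:   {batch_calls} aten::mul call(s), {batch_us:.0f} us self CPU")
```

Every one of those 100 calls pays the Python and dispatcher cost separately. Batching, like fusion, amortizes that cost over more work per launch.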

Memory Profiling

With profile_memory=True, the profiler tracks tensor allocations and deallocations per operator. The memory timeline view in TensorBoard shows peak memory usage over the training step. Two patterns worth looking for: a gradual increase in memory over steps (a memory leak — tensors not being freed, often caused by accumulating loss values in a list without detaching), and large spikes that push close to the GPU memory limit (from activations during forward pass that could be reduced with gradient checkpointing).
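The loss-accumulation leak is easy to reproduce in isolation. In the toy sketch below, the model and data are arbitrary placeholders. A running total built from non-detached losses carries a grad_fn chain back through every step’s graph, which keeps that autograd history alive. A total built from detached losses has the same value and no history.

```python
import torch

model = torch.nn.Linear(4, 1)
data = torch.randn(10, 4)

leaky_total = 0.0             # turns into a tensor with autograd history on the first add
safe_total = torch.zeros(())  # running total built only from detached values

for step in range(3):
    loss = model(data).pow(2).mean()
    loss.backward()
    leaky_total = leaky_total + loss          # chains this step's graph onto the total
    safe_total = safe_total + loss.detach()   # plain value, no history attached

print(leaky_total.grad_fn)  # an AddBackward node: history is still referenced
print(safe_total.grad_fn)   # None
```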

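# Export memory timeline for standalone analysis (call after profiling has finished;
# requires record_shapes=True, profile_memory=True, and with_stack=True, as set above)
prof.export_memory_timeline("memory_timeline.html", device="cuda:0")
# Open the HTML file in a browser for interactive visualization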

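The exported timeline plots memory use over time, stacked by category (parameters, optimizer state, inputs, activations, gradients, temporaries), so it shows which category produces the peak. For per-allocation detail, the TensorBoard Memory View lists allocation events per operator and can be sorted by size. To trace a specific allocation back to the code that made it, PyTorch’s CUDA memory snapshot tooling (torch.cuda.memory._record_memory_history() and torch.cuda.memory._dump_snapshot(), viewed in the memory_viz page) records a stack trace for every allocation. In transformer models, attention activations are often among the largest allocations, especially at long sequence lengths without a fused attention kernel. If gradient checkpointing is not enabled and you’re running close to OOM, this shows up clearly as large blocks allocated in the forward pass that persist until the backward pass frees them.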
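Gradient checkpointing trades that persistent activation memory for extra compute. Activations inside a checkpointed region are not stored during the forward pass and are recomputed during backward. The sketch below applies torch.utils.checkpoint.checkpoint to a toy block (the layer sizes are arbitrary) and checks that gradients match the non-checkpointed run. The memory savings themselves only become visible in the profiler at realistic model sizes.

```python
import torch
from torch.utils.checkpoint import checkpoint

torch.manual_seed(0)
# Toy stand-in for one transformer sub-block.
block = torch.nn.Sequential(
    torch.nn.Linear(64, 256),
    torch.nn.GELU(),
    torch.nn.Linear(256, 64),
)
x = torch.randn(8, 64, requires_grad=True)

# Baseline: intermediate activations are kept alive until backward.
block(x).sum().backward()
grad_plain = x.grad.clone()

x.grad = None
block.zero_grad(set_to_none=True)

# Checkpointed: the block's inner activations are recomputed during backward.
# PyTorch's docs ask for use_reentrant to be passed explicitly; False is the recommended setting.
checkpoint(block, x, use_reentrant=False).sum().backward()
grad_ckpt = x.grad.clone()

print(torch.allclose(grad_plain, grad_ckpt))
```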

Finding Kernel Inefficiencies with the GPU Kernel View

The GPU Kernel view lists individual GPU kernels with their duration and estimated achieved occupancy. Occupancy is the ratio of active warps on a streaming multiprocessor to the maximum it can hold. As a rough heuristic, low occupancy (say, under 30%) on a kernel that accounts for a lot of CUDA time suggests it isn’t using the hardware effectively, often because of small tensor sizes, poor memory access patterns, or suboptimal block dimensions. This view is most useful for checking whether torch.compile or other fusion optimizations are having an effect. After applying torch.compile, the kernel list should show fewer, larger kernels (fused operations) rather than many small ones (unfused operations dispatched separately).

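A common finding in transformer training is that attention kernels dominate CUDA time before a fused attention kernel is enabled, and that memory bandwidth, not compute, is the limiting factor. The unfused implementation writes the full sequence-by-sequence score matrix to GPU memory (HBM) and reads it back for the softmax and the value multiplication, so the profiler shows these kernels moving a lot of data for relatively little arithmetic. Switching to torch.nn.functional.scaled_dot_product_attention lets PyTorch dispatch to a FlashAttention-2 kernel on supported hardware and dtypes. This typically reduces attention kernel time by 2–4x, depending on sequence length and head dimension, and the improvement is directly visible in the profiler timeline. There is no use_flash_attention argument on the function itself. To force the flash backend rather than let PyTorch choose, wrap the call in torch.nn.attention.sdpa_kernel(SDPBackend.FLASH_ATTENTION).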
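The sketch below checks that scaled_dot_product_attention computes the same result as a hand-written attention that materializes the full score matrix. It runs on the CPU with arbitrary tensor sizes. flash_only_attention is a hypothetical wrapper that pins the FlashAttention backend using the torch.nn.attention API (PyTorch 2.3+). It is defined but not called here, because it needs CUDA fp16/bf16 inputs on a GPU the backend supports and raises an error otherwise.

```python
import math

import torch
import torch.nn.functional as F


def naive_attention(q, k, v):
    # Materializes the full (seq x seq) score matrix, which is what makes the
    # unfused version memory-bandwidth-bound at long sequence lengths.
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
    return torch.softmax(scores, dim=-1) @ v


def flash_only_attention(q, k, v):
    """Hypothetical wrapper: force the FlashAttention backend (PyTorch 2.3+).

    Expects CUDA fp16/bf16 tensors; raises a RuntimeError if the backend is unavailable.
    """
    from torch.nn.attention import SDPBackend, sdpa_kernel

    with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
        return F.scaled_dot_product_attention(q, k, v)


torch.manual_seed(0)
# (batch, heads, seq_len, head_dim); sizes are arbitrary.
q, k, v = (torch.randn(2, 4, 128, 64) for _ in range(3))

reference = naive_attention(q, k, v)
fused = F.scaled_dot_product_attention(q, k, v)  # PyTorch picks the backend
print(torch.allclose(reference, fused, atol=1e-4))
```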

Profiling Distributed Training

For distributed training with DDP or FSDP, torch.profiler can be run on each rank independently. The trace from each rank captures that rank’s CPU and CUDA events, including NCCL communication kernels (AllReduce, AllGather, ReduceScatter). NCCL kernels run on their own CUDA stream, so they appear alongside the backward-pass kernels in the timeline. What costs time is communication that is not hidden behind computation, visible as stretches where NCCL kernels run while the compute stream sits idle. If this exposed communication is a large fraction of total step time, you have a communication bottleneck that gradient compression or better overlap of communication with computation (which DDP does by default through gradient bucketing) may address.

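import torch.distributed as dist
from torch.profiler import profile, schedule, ProfilerActivity, tensorboard_trace_handler

# In distributed training, save per-rank traces (after dist.init_process_group()).
# A shared run directory with a per-rank worker_name keeps every rank in one
# TensorBoard run, so the ranks can be compared side by side.
prof = profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
    schedule=schedule(wait=1, warmup=1, active=3, repeat=1),  # repeat=0 would keep cycling
    on_trace_ready=tensorboard_trace_handler(
        './profiler_logs',
        worker_name=f'rank_{dist.get_rank()}',
    ),
)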

Compare traces across ranks to identify stragglers — if one rank consistently takes longer per step than others, there’s a load imbalance (unequal data distribution, different batch sizes, or a hardware anomaly on one node). The TensorBoard distributed view can overlay traces from multiple ranks for side-by-side comparison.
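Traces explain why a rank is slow. To spot which rank is slow over a long run, it can also help to compare mean step times directly. The sketch below assumes each rank already measures its own mean step time and that a process group is initialized. gather_step_times and find_stragglers are hypothetical helpers, the 10% tolerance is an arbitrary threshold, and the example numbers are made up for illustration.

```python
import statistics

import torch.distributed as dist


def gather_step_times(mean_step_time_s):
    """Collect every rank's mean step time on every rank (needs an initialized process group)."""
    times = [None] * dist.get_world_size()
    dist.all_gather_object(times, mean_step_time_s)
    return times


def find_stragglers(step_times_s, tolerance=0.10):
    """Return ranks whose mean step time exceeds the median by more than `tolerance`."""
    median = statistics.median(step_times_s)
    return [rank for rank, t in enumerate(step_times_s) if t > median * (1 + tolerance)]


# Made-up numbers: rank 2 is noticeably slower than the others.
example_times = [0.41, 0.40, 0.51, 0.42]
print(find_stragglers(example_times))
```

Comparing against the median rather than the mean keeps a single slow rank from dragging the baseline toward itself.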

From Profiler to Fix: A Decision Workflow

The profiler tells you what is slow; it doesn’t tell you what to do about it. Work through this order of investigation. If you see large DataLoader gaps, fix data loading first: it’s the highest-leverage fix and costs nothing in model quality. If you see CPU-GPU sync stalls, find and remove all .item(), .cpu(), .numpy(), and print(loss) calls from the hot path; these are easy wins. If you see low GPU utilization (small batch size, Python dispatch overhead dominating), try torch.compile, which fuses small kernels and eliminates Python overhead for the compiled region. If you see high attention kernel time, switch to torch.nn.functional.scaled_dot_product_attention, which can dispatch to FlashAttention-2 on supported hardware. If you see high memory usage driving OOM or forcing small batch sizes, enable gradient checkpointing for the attention layers or the full model. If you see large NCCL communication time in distributed training, verify that gradient bucketing is configured sensibly in DDP (the bucket_cap_mb parameter) or that FSDP’s communication overlap is enabled. Measure after each change with a fresh profiler run, because it’s easy to apply multiple optimizations and not know which one actually helped.
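When measuring after each change, a wall-clock number alongside the trace keeps comparisons honest, and it should be taken with the profiler off, since profiling adds overhead. One option is torch.utils.benchmark.Timer, which performs warm-up runs and synchronizes CUDA when timing GPU work. In the sketch below, the toy model, batch, and train_step function are placeholders for your own training step, and min_run_time=0.5 is an arbitrary choice.

```python
import torch
from torch.utils import benchmark

# Toy model and batch; substitute your own training step.
model = torch.nn.Sequential(torch.nn.Linear(512, 512), torch.nn.ReLU(), torch.nn.Linear(512, 10))
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
criterion = torch.nn.CrossEntropyLoss()
inputs, labels = torch.randn(64, 512), torch.randint(0, 10, (64,))


def train_step():
    optimizer.zero_grad(set_to_none=True)
    loss = criterion(model(inputs), labels)
    loss.backward()
    optimizer.step()


# Timer repeats the statement until min_run_time has elapsed and reports statistics.
timer = benchmark.Timer(stmt="train_step()", globals={"train_step": train_step}, label="train_step")
measurement = timer.blocked_autorange(min_run_time=0.5)
print(measurement)
print(f"median step time: {measurement.median * 1e3:.2f} ms")
```

Record the median before and after each optimization. A change that makes the trace look cleaner but leaves the median unchanged probably did not address the real bottleneck.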

Profiling Inference and Serving

torch.profiler works equally well for profiling inference workloads, not just training. The setup is the same, but the interpretation differs. In inference there’s no backward pass, so the entire step time is forward computation plus any pre- and post-processing. The typical bottlenecks also differ from training: tokenization and input preparation (CPU-bound, often blocking the GPU), memory allocation per request (if batch size varies), and attention computation that scales with sequence length. For LLM inference specifically, the profiler will typically show decode-phase steps as memory-bandwidth-bound. At small batch sizes each decode step runs matrix-vector-like multiplies at a fraction of peak arithmetic throughput, because the bottleneck is reading the weight matrices and KV cache from HBM rather than the floating-point work itself.

Profiling a FastAPI or similar inference server requires care, because you can’t wrap the entire server in a context manager. One pattern is to profile a fixed number of requests while the server is running, calling prof.step() per request and using the schedule parameter to capture a representative sample. Keep in mind that CPU operator recording in torch.profiler is thread-local by default: a profiler started from an unrelated background thread will miss CPU operators that run on the request-handling threads. Alternatively, profile a standalone benchmark script that replays real request patterns at the target concurrency level. The key metric to look for in serving is whether the GPU is consistently busy across requests (good) or sits idle between requests due to slow request preprocessing, response serialization, or Python GIL contention in the serving framework.
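In a replay script, annotating each stage of a request with record_function lets the profiler table separate preprocessing and postprocessing from model time. In the sketch below, the model, preprocess, postprocess, and the synthetic payloads are all hypothetical placeholders, and profiling runs on the CPU only. On a GPU host you would add ProfilerActivity.CUDA and move the model and inputs to the device.

```python
import torch
from torch.profiler import ProfilerActivity, profile, record_function, schedule

# Hypothetical stand-ins for a real service's stages.
model = torch.nn.Sequential(torch.nn.Linear(256, 256), torch.nn.ReLU(), torch.nn.Linear(256, 8)).eval()


def preprocess(raw):
    return torch.tensor(raw, dtype=torch.float32).unsqueeze(0)


def postprocess(logits):
    return logits.argmax(dim=-1).tolist()  # on CUDA, .tolist() forces a host sync


requests = [[float(i % 7)] * 256 for i in range(12)]  # synthetic replayed payloads
responses = []

with profile(
    activities=[ProfilerActivity.CPU],  # add ProfilerActivity.CUDA on a GPU host
    schedule=schedule(wait=1, warmup=1, active=8, repeat=1),
) as prof, torch.inference_mode():
    for raw in requests:
        with record_function("preprocess"):
            batch = preprocess(raw)
        with record_function("model_forward"):
            logits = model(batch)
        with record_function("postprocess"):
            responses.append(postprocess(logits))
        prof.step()  # one profiler step per request

# Share of annotated CPU time spent in each stage over the 8 active requests.
stage_us = {
    evt.key: evt.cpu_time_total
    for evt in prof.key_averages()
    if evt.key in ("preprocess", "model_forward", "postprocess")
}
total = sum(stage_us.values())
for name, us in stage_us.items():
    print(f"{name:14s} {us / total:6.1%} of stage CPU time")
```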

Common Findings and What They Mean

After running torch.profiler on dozens of real training setups, a few patterns appear repeatedly. The most common finding in transformer fine-tuning is a DataLoader bottleneck: roughly 20–30% of teams profiling for the first time are running with num_workers=0 or 1, and the GPU sits idle for 30–50% of each step waiting for data. Fixing this (num_workers=4–8, persistent_workers=True, pin_memory=True) immediately improves throughput with no model changes. The second most common finding is .item() calls inside the training loop. Logging loss every step with loss.item() forces a CPU-GPU sync at every step, which stops the CPU from queuing work ahead of the GPU and typically adds 1–5 ms per step, time that compounds over a long training run. The fix is to accumulate the detached loss on the GPU (running_loss += loss.detach()) and call .item() only when you log, every N steps.
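Periodic logging can be sketched by keeping the running total as a detached tensor on the device, so the only host sync is the .item() call at each logging interval. train_one_epoch is a hypothetical helper. The toy data, model, and log_every=5 exist only so that the sketch runs end to end.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


def train_one_epoch(model, loader, criterion, optimizer, device, log_every=50):
    """Accumulate loss on-device and sync with the host only every `log_every` steps."""
    running_loss = torch.zeros((), device=device)
    logged = []
    for step, (inputs, labels) in enumerate(loader, start=1):
        inputs = inputs.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)
        optimizer.zero_grad(set_to_none=True)
        loss = criterion(model(inputs), labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.detach()  # stays on the device: no host sync here
        if step % log_every == 0:
            logged.append((running_loss / log_every).item())  # one sync per log interval
            running_loss.zero_()
    return logged


# Toy usage on whatever device is available; data and model are placeholders.
device = "cuda" if torch.cuda.is_available() else "cpu"
torch.manual_seed(0)
data = TensorDataset(torch.randn(200, 16), torch.randint(0, 3, (200,)))
loader = DataLoader(data, batch_size=10)
model = torch.nn.Linear(16, 3).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
logged_losses = train_one_epoch(model, loader, torch.nn.CrossEntropyLoss(), optimizer, device, log_every=5)
print(logged_losses)
```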

For teams using Hugging Face Trainer, the profiler occasionally reveals that the Trainer’s own overhead (logging callbacks, metric computation, checkpoint saving frequency) is consuming meaningful time. In TrainingArguments, logging_steps controls how often metrics are logged; report_to only selects which integrations receive them. Increasing eval_steps and save_steps, so that evaluation and checkpointing run less often, can recover training speed when evaluation on a large validation set or checkpointing to slow storage is happening too frequently. These costs are invisible in the loss curve but obvious in the profiler trace as periodic CPU-side bursts while the GPU sits idle.

The profiler is also the right tool for validating that torch.compile is doing what you expect. The first few batches after compiling are slow because compilation happens then, so set the schedule’s wait (or skip_first) parameter high enough to skip them; the warmup phase only absorbs the profiler’s own start-up overhead and is usually too short to cover compilation. After those steps, the profiler should show noticeably fewer, larger CUDA kernels per step, and step time should be measurably lower. If the kernel count doesn’t change, torch.compile may have hit a graph break. The profiler’s with_stack=True option helps identify where breaks occur by showing Python stack frames alongside the operator timeline, and running with the environment variable TORCH_LOGS="graph_breaks" prints each break with its reason. Graph breaks typically happen at Python control flow that depends on tensor values, data-dependent shapes, or Python functions that torch.compile can’t trace through.

Interpreting the Chrome Trace and Stacks View

The table and flame chart views in TensorBoard cover most profiling needs, but the Chrome trace export reveals timing detail that the aggregated views hide. Export with prof.export_chrome_trace("trace.json") and open it in Perfetto or chrome://tracing. Each row in the trace corresponds to a thread or CUDA stream. The CPU thread shows Python-level operator dispatch, and the CUDA streams show actual kernel execution. The gap between a CPU-side kernel launch and the start of the corresponding CUDA kernel includes the kernel launch latency, typically 5–20 microseconds per kernel, plus any time spent waiting behind work already queued on the stream. When you see hundreds of tiny kernels launching in rapid succession with large gaps between them, the model has too many small operations and would benefit from fusion via torch.compile or a custom fused kernel. The stacks view, enabled by passing with_stack=True to profile() (it is not a schedule option), adds Python call stack attribution to each operator. This is how you trace a slow CUDA kernel back to the exact line of model code that triggered it, which is otherwise very hard to do from the table view alone.
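Both exports can be produced directly from the profiler object, without TensorBoard. The CPU-only sketch below uses a toy model as a placeholder and writes into a temporary directory. export_stacks writes folded stack lines that flame-graph tools can render, and it requires with_stack=True. The JSON check at the end only shows that the trace is ordinary Chrome-trace JSON containing the record_function annotation.

```python
import json
import os
import tempfile

import torch
from torch.profiler import ProfilerActivity, profile, record_function

# Toy model and batch; substitute your own.
model = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.ReLU(), torch.nn.Linear(64, 64))
x = torch.randn(32, 64)

with profile(activities=[ProfilerActivity.CPU], with_stack=True) as prof:
    with record_function("train_step"):
        model(x).sum().backward()

out_dir = tempfile.mkdtemp()
trace_path = os.path.join(out_dir, "trace.json")
prof.export_chrome_trace(trace_path)  # open in Perfetto or chrome://tracing

# Folded stacks for flame-graph tools (needs with_stack=True).
stacks_path = os.path.join(out_dir, "profiler_stacks.txt")
prof.export_stacks(stacks_path, metric="self_cpu_time_total")

with open(trace_path) as f:
    trace = json.load(f)
names = {event.get("name") for event in trace["traceEvents"]}
print(len(trace["traceEvents"]), "events; 'train_step' recorded:", "train_step" in names)
```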
