PyTorch’s autograd engine handles gradient computation automatically for the vast majority of neural network operations. But there are situations where you need to step outside the standard computation graph: implementing a numerically stable backward pass that differs from the naive gradient derivation, wrapping a non-differentiable operation to make it trainable, integrating a custom CUDA kernel, or applying straight-through estimation for quantisation. All of these require torch.autograd.Function — the interface for defining custom forward and backward passes that autograd will treat as atomic operations. Understanding how to write these correctly, including context management, gradient checking, and double backward support, is essential for any practitioner working below the level of standard PyTorch modules.
How torch.autograd.Function Works
A torch.autograd.Function subclass defines two static methods: forward, which computes the output given inputs, and backward, which computes gradients with respect to inputs given the gradient flowing in from downstream (the upstream gradient, often called grad_output). When you call MyFunction.apply(x, y), autograd registers the function in the computation graph. During the backward pass, autograd calls backward with the upstream gradient and expects gradients for each input that required grad. The ctx object passed to both methods is the mechanism for saving tensors needed in the backward pass and for communicating metadata between forward and backward.
import torch
from torch.autograd import Function

class SquaredFunction(Function):
    """Minimal example: f(x) = x^2, df/dx = 2x.
    Equivalent to x.pow(2) — illustrates the interface only.
    """

    @staticmethod
    def forward(ctx, x: torch.Tensor) -> torch.Tensor:
        # Save tensors needed for backward using ctx.save_for_backward
        # Never store tensors as ctx attributes — use save_for_backward only
        ctx.save_for_backward(x)
        return x * x

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        # Retrieve saved tensors in the same order they were saved
        x, = ctx.saved_tensors
        # Return gradient for each input to forward (in order)
        # grad_output is the upstream gradient (chain rule: dL/dx = dL/dy * dy/dx)
        return grad_output * 2 * x

# Usage: always call via .apply(), not direct instantiation
x = torch.tensor([2.0, 3.0], requires_grad=True)
y = SquaredFunction.apply(x)
y.sum().backward()
print(x.grad)  # tensor([4., 6.]) — correct: 2*x

# Gradient check: always verify custom backward implementations
from torch.autograd import gradcheck

x_double = torch.randn(4, dtype=torch.float64, requires_grad=True)
result = gradcheck(SquaredFunction.apply, (x_double,), eps=1e-6, atol=1e-4)
print(f"Gradient check passed: {result}")
Context Management: What to Save and How
The ctx object has strict rules. Tensors that you need in the backward pass must be saved with ctx.save_for_backward(*tensors) and retrieved with ctx.saved_tensors. You must not store tensors as plain attributes (ctx.x = x) — this bypasses autograd’s memory management and will cause memory leaks and incorrect gradients in certain scenarios. Non-tensor values like scalars, booleans, or Python integers can be stored as plain attributes. Saving only what you need matters: save_for_backward holds references to the input tensors, keeping them alive in memory until the backward pass runs. In memory-constrained training, saving large intermediate tensors can cause OOM errors.
import torch
from torch.autograd import Function

class ScaledSigmoid(Function):
    """Sigmoid with learnable scale: f(x, scale) = scale * sigmoid(x).
    Demonstrates multi-input backward and ctx attribute storage.
    """

    @staticmethod
    def forward(ctx, x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
        sigmoid_x = torch.sigmoid(x)
        # Save tensors needed for backward
        ctx.save_for_backward(sigmoid_x, scale)
        # Non-tensor metadata can be plain attributes
        ctx.input_shape = x.shape
        return scale * sigmoid_x

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        sigmoid_x, scale = ctx.saved_tensors
        # Gradient w.r.t. x: d/dx[scale * sigmoid(x)] = scale * sigmoid(x) * (1 - sigmoid(x))
        grad_x = grad_output * scale * sigmoid_x * (1 - sigmoid_x)
        # Gradient w.r.t. scale: d/d_scale[scale * sigmoid(x)] = sigmoid(x)
        # scale is a scalar here, so sum the gradient over all dimensions to match its shape
        grad_scale = (grad_output * sigmoid_x).sum()
        # Return one gradient per forward input, in the same order
        return grad_x, grad_scale

x = torch.randn(4, 8, requires_grad=True)
scale = torch.tensor(2.0, requires_grad=True)
out = ScaledSigmoid.apply(x, scale)
out.sum().backward()
print(f"x.grad shape: {x.grad.shape}")  # (4, 8)
print(f"scale.grad: {scale.grad.item():.4f}")
Numerically Stable Custom Operations
The most common real-world use case for custom autograd functions is implementing numerically stable versions of operations whose naive gradient derivation suffers from overflow or underflow. Log-sum-exp is the canonical example: computing log(exp(a) + exp(b)) directly overflows for large inputs, but the mathematically equivalent max(a,b) + log(1 + exp(-|a-b|)) is stable. The forward and backward passes can both be written in numerically stable form independently, which is not possible if you rely on autograd to differentiate through the naive forward implementation.
import torch
from torch.autograd import Function

class StableLogSumExp(Function):
    """Numerically stable log-sum-exp along dim=-1.
    Forward: logsumexp(x) = max(x) + log(sum(exp(x - max(x))))
    Backward: softmax(x) — the gradient of logsumexp is the softmax
    """

    @staticmethod
    def forward(ctx, x: torch.Tensor) -> torch.Tensor:
        # Subtract max for numerical stability
        x_max = x.max(dim=-1, keepdim=True).values
        exp_shifted = (x - x_max).exp()
        sum_exp = exp_shifted.sum(dim=-1, keepdim=True)
        log_sum_exp = x_max + sum_exp.log()
        # Save softmax (= exp_shifted / sum_exp) for backward — it's the gradient
        softmax = exp_shifted / sum_exp
        ctx.save_for_backward(softmax)
        return log_sum_exp.squeeze(-1)

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        softmax, = ctx.saved_tensors
        # Gradient of logsumexp w.r.t. x is softmax(x)
        # Expand grad_output to broadcast against softmax's shape
        return softmax * grad_output.unsqueeze(-1)

# Verify against PyTorch's built-in
x = torch.randn(32, 512, dtype=torch.float64, requires_grad=True)
custom_out = StableLogSumExp.apply(x)
builtin_out = torch.logsumexp(x, dim=-1)
print(f"Max abs diff forward: {(custom_out - builtin_out).abs().max():.2e}")

from torch.autograd import gradcheck

x_check = torch.randn(4, 16, dtype=torch.float64, requires_grad=True)
gradcheck(StableLogSumExp.apply, (x_check,), eps=1e-6, atol=1e-4)
print("Gradient check passed")
Straight-Through Estimator for Non-Differentiable Operations
Quantisation, binarisation, and other operations that round or threshold values are non-differentiable: the true gradient is zero almost everywhere and undefined at the discontinuity. The straight-through estimator (STE) is a pragmatic approximation that passes the upstream gradient through the non-differentiable operation unchanged in the backward pass, as if the forward operation were an identity function. STE underlies quantisation-aware training (QAT) and binary neural networks, and it is a standard building block for any model that needs to train through a discretisation step.
import torch
from torch.autograd import Function

class StraightThroughRound(Function):
    """Straight-through estimator for rounding.
    Forward: round to nearest integer.
    Backward: pass gradient through unchanged (identity approximation).
    """

    @staticmethod
    def forward(ctx, x: torch.Tensor) -> torch.Tensor:
        return x.round()

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        # STE: gradient passes through as if forward were identity
        return grad_output

class FakeQuantize(Function):
    """Fake quantisation for QAT: simulates int8 quantisation in forward,
    uses STE in backward so gradients flow for weight updates.
    """

    @staticmethod
    def forward(ctx, x: torch.Tensor, scale: float, zero_point: int,
                qmin: int = -128, qmax: int = 127) -> torch.Tensor:
        ctx.save_for_backward(x)
        ctx.scale = scale
        ctx.zero_point = zero_point
        ctx.qmin, ctx.qmax = qmin, qmax
        # Quantise: scale, round, clamp, dequantise
        x_int = (x / scale + zero_point).round().clamp(qmin, qmax)
        return (x_int - zero_point) * scale  # dequantised fp32 value

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        x, = ctx.saved_tensors
        scale, zero_point = ctx.scale, ctx.zero_point
        qmin, qmax = ctx.qmin, ctx.qmax
        # Only pass gradient through where the input fell inside the quantisation
        # range (optionally zeroing the gradient where the forward pass clipped)
        x_scaled = x / scale + zero_point
        in_range = (x_scaled >= qmin) & (x_scaled <= qmax)
        grad_x = grad_output * in_range.float()
        # No gradients for scale, zero_point, qmin, qmax (plain Python values, not tensors)
        return grad_x, None, None, None, None

x = torch.randn(8, 64, requires_grad=True)
x_fq = FakeQuantize.apply(x, 0.01, 0)
x_fq.sum().backward()
print(f"Gradient flows: {x.grad is not None}")  # True
Handling Non-Differentiable Inputs and Optional Gradients
When a custom function takes inputs that do not require gradients — integer tensors, boolean masks, or plain Python scalars — the backward method must still return a value for each positional input in the same order as forward. For inputs that don't require gradients, return None. Failing to return the correct number of values raises a runtime error, and returning a tensor where None is expected for a non-differentiable input causes incorrect gradient accumulation. The ctx.needs_input_grad tuple mirrors the requires_grad status of each forward input and can be used to skip expensive backward computations for inputs that won't accumulate gradients anyway.
import torch
from torch.autograd import Function

class MaskedLinear(Function):
    """Linear operation with a binary mask applied to weights.
    mask is not differentiable (bool tensor) — return None for its gradient.
    """

    @staticmethod
    def forward(ctx, x: torch.Tensor, weight: torch.Tensor,
                mask: torch.Tensor) -> torch.Tensor:
        masked_weight = weight * mask.float()
        ctx.save_for_backward(x, masked_weight, mask)
        return x @ masked_weight.T

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        x, masked_weight, mask = ctx.saved_tensors
        grad_x = grad_weight = None
        # Only compute what's needed
        if ctx.needs_input_grad[0]:
            grad_x = grad_output @ masked_weight
        if ctx.needs_input_grad[1]:
            grad_weight = (grad_output.T @ x) * mask.float()
        # mask has no gradient (bool tensor, not differentiable)
        return grad_x, grad_weight, None  # None for mask

x = torch.randn(4, 8, requires_grad=True)
w = torch.randn(16, 8, requires_grad=True)
mask = torch.randint(0, 2, (16, 8)).bool()
out = MaskedLinear.apply(x, w, mask)
out.sum().backward()
print(f"x.grad shape: {x.grad.shape}, w.grad shape: {w.grad.shape}")
Gradient Checking: Always Verify
Every custom backward implementation should be verified with torch.autograd.gradcheck before being used in training. Gradcheck computes numerical Jacobians by perturbing each input element by a small epsilon and comparing with the analytical gradients returned by your backward method. Use float64 inputs for gradcheck — float32's limited precision causes numerical Jacobian estimates to be too noisy for reliable verification. The eps parameter controls the finite difference step size (default 1e-6 is usually appropriate) and atol controls the tolerance for the comparison. If gradcheck fails, the most common causes are: incorrect gradient formula, missing or wrong handling of the chain rule (forgetting to multiply by grad_output), saved tensors being modified in-place between forward and backward, or non-differentiable operations inside forward that produce incorrect gradients near the test point.
from torch.autograd import gradcheck, gradgradcheck
import torch

# Always test with float64 — float32 noise masks gradient errors
def verify_custom_function(func, *args):
    args_double = tuple(
        a.double().detach().requires_grad_(a.requires_grad)
        if isinstance(a, torch.Tensor) and a.is_floating_point() else a
        for a in args
    )
    passed = gradcheck(func, args_double, eps=1e-6, atol=1e-4, raise_exception=True)
    print(f"First-order gradcheck: {'PASSED' if passed else 'FAILED'}")
    # Also check second-order gradients if you plan to use higher-order optimisers
    try:
        passed2 = gradgradcheck(func, args_double, eps=1e-6, atol=1e-4)
        print(f"Second-order gradcheck: {'PASSED' if passed2 else 'FAILED'}")
    except Exception as e:
        print(f"Second-order gradcheck failed: {e}")

x = torch.randn(4, 8, requires_grad=True)
scale = torch.tensor(2.0, requires_grad=True)
# Note: ScaledSigmoid saves an intermediate (sigmoid_x) computed inside forward,
# so its backward graph is disconnected from x and the second-order check fails.
verify_custom_function(ScaledSigmoid.apply, x, scale)
gradgradcheck verifies second-order gradients and is necessary if your custom function will be used with optimisers or regularisers that compute higher-order derivatives, or if you plan to differentiate through a backward pass (as in MAML and similar meta-learning algorithms). Most production uses only need first-order gradients, but it is worth running gradgradcheck at least once to confirm your implementation is compatible.
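The pattern that makes double backward work is to save the inputs themselves rather than intermediates computed inside forward: because forward runs with gradient tracking disabled, an intermediate such as ScaledSigmoid's sigmoid_x is detached from the graph, and anything derived from it in backward has no second-order gradient path. The sketch below (the class name is illustrative, not a library API) recomputes the sigmoid inside backward from the saved input using differentiable operations, which keeps the backward pass itself differentiable and lets gradgradcheck pass.
import torch
from torch.autograd import Function, gradgradcheck

class DoubleBackwardSigmoid(Function):
    """Sigmoid whose backward pass is itself differentiable."""

    @staticmethod
    def forward(ctx, x: torch.Tensor) -> torch.Tensor:
        # Save the input, not sigmoid(x): intermediates computed here are
        # detached because forward runs with gradient tracking disabled
        ctx.save_for_backward(x)
        return torch.sigmoid(x)

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        x, = ctx.saved_tensors
        # Recompute from the saved input with differentiable ops so autograd
        # can differentiate through this backward pass (double backward)
        s = torch.sigmoid(x)
        return grad_output * s * (1 - s)

x = torch.randn(6, dtype=torch.float64, requires_grad=True)
print(gradgradcheck(DoubleBackwardSigmoid.apply, (x,), eps=1e-6, atol=1e-4))  # True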
In-Place Operations and Autograd Safety
In-place operations — anything that modifies a tensor's data without allocating a new tensor — are a common source of autograd errors in custom functions. The autograd engine tracks tensor versions, and if a tensor that was saved for backward is modified in-place before the backward pass runs, autograd will raise a runtime error because the saved tensor no longer matches the version used in the forward pass. The rule is simple: never modify saved tensors in-place, and never apply in-place operations to the inputs or outputs of a custom function if those tensors participate in the computation graph. The ctx.mark_dirty(*tensors) method exists for the one legitimate exception — functions that intentionally modify their inputs in-place, such as custom in-place activation functions — but this should be used rarely and with care.
A subtler issue arises when the output of your custom function is later modified in-place by other code before backward is called. If your backward implementation recomputes something from the output tensor (rather than saving intermediate results), the in-place modification will silently corrupt the backward computation. The safe pattern is to save everything you need in forward via save_for_backward rather than recomputing from outputs in backward, even when recomputation seems equivalent. Explicit saves are cheaper than debugging a gradient corruption that only manifests on specific batch sizes or hardware.
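As a small illustration of this failure mode, consider a built-in operation: sigmoid saves its output for its backward pass, so modifying that output in-place before calling backward trips autograd's version counter.
import torch

x = torch.randn(4, requires_grad=True)
y = torch.sigmoid(x)   # sigmoid saves its output for the backward pass
y.add_(1.0)            # in-place modification bumps the saved tensor's version
try:
    y.sum().backward()
except RuntimeError as e:
    print(f"Autograd error: {e}")  # "...modified by an inplace operation"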
Custom Functions with torch.compile
PyTorch 2.0's torch.compile can fuse custom autograd functions into optimised kernels, but there are compatibility requirements. Custom functions that use only standard PyTorch operations in their forward and backward methods are typically compile-compatible without modification — the compiler traces through apply() and optimises the resulting computation graph. Functions that use Python control flow that depends on tensor values (not just tensor shapes) will cause graph breaks, where torch.compile falls back to eager execution for that portion of the graph. Functions that call external C++ or CUDA code via ctypes or cffi are also not compilable and will produce a graph break.
To check whether your custom function compiles cleanly, wrap a simple test call in torch.compile and enable torch._dynamo.config.verbose = True to see graph breaks. If compilation fails, the most common fixes are: replacing Python conditionals that depend on tensor values with torch.where, removing print statements inside forward or backward (they cause graph breaks), and ensuring that ctx attribute assignments use Python scalars rather than tensors where possible. For performance-critical custom operations, investing in compile compatibility is worth the effort since fused kernels can reduce memory bandwidth by 2–4x compared to unfused eager execution.
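A minimal smoke test, assuming the SquaredFunction defined earlier is in scope: compiling a small wrapper with fullgraph=True turns any graph break into an immediate error rather than a silent fallback to eager, which makes incompatibilities easy to spot. (How fully torch.compile can trace custom autograd Functions varies across PyTorch 2.x versions.)
import torch

def loss_fn(x):
    # Wrapper around the custom function defined earlier in this section
    return SquaredFunction.apply(x).sum()

# fullgraph=True: raise on any graph break instead of falling back to eager
compiled_loss = torch.compile(loss_fn, fullgraph=True)

x = torch.randn(16, requires_grad=True)
compiled_loss(x).backward()
print(f"grad matches 2*x: {torch.allclose(x.grad, 2 * x)}")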
When Not to Use Custom Autograd Functions
Custom autograd functions are the right tool for a narrow set of problems. For the majority of novel neural network components — new layer types, attention variants, regularisers — you should implement them using standard PyTorch operations inside an nn.Module and let autograd differentiate through them automatically. Autograd handles compositions of standard operations correctly and efficiently, and the resulting code is easier to read, debug, and maintain than a custom backward implementation. Reserve custom autograd functions for cases where automatic differentiation is genuinely insufficient: the numerically stable backward differs from the automatic gradient, the operation is non-differentiable and requires approximation (STE), you need to integrate non-PyTorch code (CUDA kernels, external libraries), or you need to explicitly control what is saved for backward to reduce peak memory usage during training. If you find yourself writing a custom function primarily because the automatic gradient is slow, first check whether torch.compile resolves the performance issue — it usually does — before committing to a custom backward implementation that must be maintained and re-verified across PyTorch version upgrades.
Wrapping Custom CUDA Kernels with Autograd
When a custom CUDA kernel computes the forward pass, you need to provide the backward pass manually since autograd cannot trace through native code. The pattern is to write the CUDA kernel for the forward computation, write a separate CUDA kernel for the backward pass (computing the gradient with respect to each differentiable input), load both using torch.utils.cpp_extension.load or as a pre-built PyTorch extension, and wrap them in a torch.autograd.Function that calls the forward kernel in forward() and the backward kernel in backward(). The autograd Function then integrates seamlessly with the rest of the computation graph — downstream operations can differentiate through your CUDA kernel as if it were any standard PyTorch operation. This is the mechanism used by FlashAttention, which implements its own fused CUDA kernels for both the attention forward pass and the attention backward pass, and wraps them in a custom autograd Function so that the rest of the model can use it transparently with standard optimizers.
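The wrapper looks roughly like the sketch below. The extension name, source files, and exported kernel entry points (fused_forward, fused_backward) are placeholders for whatever your extension actually defines; the point is the structure: load the extension, call the forward kernel in forward(), and call the backward kernel in backward().
import torch
from torch.autograd import Function
from torch.utils.cpp_extension import load

# Hypothetical extension: source files and exported function names are placeholders
fused = load(name="fused_op", sources=["fused_op.cpp", "fused_op_kernel.cu"])

class FusedOp(Function):
    @staticmethod
    def forward(ctx, x: torch.Tensor) -> torch.Tensor:
        out = fused.fused_forward(x.contiguous())  # custom CUDA forward kernel
        ctx.save_for_backward(x)
        return out

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        x, = ctx.saved_tensors
        # Custom CUDA backward kernel: returns dL/dx given dL/dout and x
        return fused.fused_backward(grad_output.contiguous(), x)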
Choosing between a custom autograd Function wrapping a CUDA kernel versus a pure PyTorch implementation that torch.compile will fuse comes down to development cost and maintenance overhead. Writing and debugging CUDA kernels is significantly more work than writing PyTorch code, and custom CUDA kernels must be recompiled and retested across PyTorch and CUDA version upgrades. For most operations, torch.compile produces performance close to hand-written CUDA and is far easier to maintain. Reserve custom CUDA kernels for operations where the memory access pattern is fundamentally different from what PyTorch's standard operations can express — such as fused attention with recomputation that never materialises the full attention matrix — and use torch.compile for everything else.
Summary: Key Rules for Custom Autograd Functions
Use ctx.save_for_backward for all tensors needed in backward — never store them as plain ctx attributes. Return one gradient per forward input in the same positional order; return None for non-differentiable inputs such as integer tensors, boolean masks, or Python scalars. Always verify with gradcheck using float64 inputs before using a custom function in training. Avoid in-place modifications to saved tensors or to function inputs and outputs that participate in the computation graph. Check compile compatibility with torch.compile if performance matters — most pure-PyTorch custom functions compile cleanly. And default to standard PyTorch operations with automatic differentiation for anything that does not require a custom backward; the cases where a custom function is genuinely necessary are narrower than they appear.