How to Export PyTorch Models: TorchScript, ONNX, and TensorRT

Exporting a PyTorch model for production deployment means choosing between TorchScript, ONNX, and TensorRT — three distinct serialization and compilation approaches with different tradeoffs in portability, performance, and operational complexity. TorchScript gives you a Python-free serialized model that runs anywhere PyTorch is installed. ONNX gives you a framework-independent format that can run on dozens of runtimes across CPU, GPU, and edge devices. TensorRT compiles your model into a GPU-specific engine that extracts maximum throughput at the cost of flexibility. Understanding when each is the right tool determines whether you get a 20% or a 4x latency improvement from your deployment pipeline.

TorchScript: Serializing PyTorch Models Without Python

TorchScript converts a PyTorch model into an intermediate representation that can be saved, loaded, and executed without a Python interpreter. This matters for deployment environments where Python is unavailable or undesirable — C++ inference servers, mobile applications, or environments where Python’s GIL is a bottleneck. TorchScript supports two conversion modes: tracing and scripting. Tracing runs the model with example inputs and records the operations executed, which works well for models with static control flow but silently drops conditional branches not taken during tracing. Scripting analyzes the model’s Python source code directly, handling dynamic control flow correctly but requiring type annotations and TorchScript-compatible Python.

import torch
import torch.nn as nn

class SimpleTransformerBlock(nn.Module):
    def __init__(self, d_model: int = 512, n_heads: int = 8):
        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.norm1 = nn.LayerNorm(d_model)
        self.ff = nn.Sequential(
            nn.Linear(d_model, d_model * 4),
            nn.GELU(),
            nn.Linear(d_model * 4, d_model),
        )
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        attn_out, _ = self.attn(x, x, x)
        x = self.norm1(x + attn_out)
        return self.norm2(x + self.ff(x))

model = SimpleTransformerBlock().eval()
example = torch.randn(2, 32, 512)  # (batch, seq_len, d_model)

# Method 1: Tracing — fast, works for static models
traced = torch.jit.trace(model, example)
torch.jit.save(traced, "model_traced.pt")

# Method 2: Scripting — handles dynamic control flow
# Requires type annotations on forward() and TorchScript-compatible ops
scripted = torch.jit.script(model)
torch.jit.save(scripted, "model_scripted.pt")

# Loading and running without original model definition
loaded = torch.jit.load("model_traced.pt")
loaded.eval()
with torch.no_grad():
    out = loaded(example)
    print(f"Output shape: {out.shape}")  # torch.Size([2, 32, 512])

# Common tracing pitfall: data-dependent control flow
class DynamicModel(nn.Module):
    def forward(self, x: torch.Tensor, use_dropout: bool = True) -> torch.Tensor:
        # This branch will be baked in at trace time — tracing captures ONLY
        # the path taken with the example input, silently ignoring the other
        if use_dropout:
            return nn.functional.dropout(x, p=0.1, training=self.training)
        return x
    # Fix: use torch.jit.script() instead for models with dynamic branching

TorchScript’s performance improvement over eager-mode PyTorch is modest — typically 10–30% latency reduction from operator fusion and graph optimization. It is not a compilation target in the same sense as TensorRT; it is primarily a serialization format with light optimization. Use TorchScript when you need to deploy in a C++ environment, need to freeze a model for mobile (via torch.utils.mobile_optimizer.optimize_for_mobile), or want a lightweight way to eliminate Python overhead in a multi-threaded C++ inference server.
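As a concrete sketch of the mobile path, assuming the SimpleTransformerBlock defined above, freezing and mobile optimization look like this (the exact optimization passes applied vary by PyTorch version):

from torch.utils.mobile_optimizer import optimize_for_mobile

scripted = torch.jit.script(SimpleTransformerBlock().eval())
frozen = torch.jit.freeze(scripted)   # inline weights, drop training-only attributes
mobile = optimize_for_mobile(frozen)  # mobile passes: operator fusion, weight prepacking
mobile._save_for_lite_interpreter("model_mobile.ptl")  # lite-interpreter format for on-device use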

ONNX: The Portable Interchange Format

ONNX (Open Neural Network Exchange) defines a standardized graph format that decouples model training from inference runtime. An ONNX model exported from PyTorch can run on ONNX Runtime, TensorRT, OpenVINO, CoreML, and dozens of other execution engines — often with no code changes. ONNX Runtime in particular is a mature, high-performance CPU and CUDA inference engine that consistently outperforms naive eager PyTorch by 20–60% on transformer inference tasks, simply through graph-level operator fusion and memory planning.

import torch
import torch.onnx
import onnx
import onnxruntime as ort
import numpy as np

model = SimpleTransformerBlock().eval()
example = torch.randn(2, 32, 512)

# Export to ONNX
torch.onnx.export(
    model,
    example,
    "model.onnx",
    export_params=True,
    opset_version=17,            # Use 17+ for best transformer op support
    do_constant_folding=True,    # Fuse constant subgraphs at export time
    input_names=["input"],
    output_names=["output"],
    dynamic_axes={               # Allow variable batch size and sequence length
        "input":  {0: "batch", 1: "seq_len"},
        "output": {0: "batch", 1: "seq_len"},
    },
)

# Verify the exported model
onnx_model = onnx.load("model.onnx")
onnx.checker.check_model(onnx_model)
print("ONNX model check passed")

# Run inference with ONNX Runtime
providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]  # CUDA preferred
sess_options = ort.SessionOptions()
sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
sess = ort.InferenceSession("model.onnx", sess_options, providers=providers)

input_name = sess.get_inputs()[0].name
np_input = example.numpy()
ort_out = sess.run(None, {input_name: np_input})
print(f"ORT output shape: {ort_out[0].shape}")

# Latency comparison (CPU here; see the benchmarking section below for GPU caveats)
import time

N = 200

# Warm-up: the first passes pay one-time allocation and optimization costs
with torch.no_grad():
    for _ in range(10):
        model(example)
for _ in range(10):
    sess.run(None, {input_name: np_input})

# PyTorch eager
with torch.no_grad():
    t0 = time.perf_counter()
    for _ in range(N):
        model(example)
    eager_ms = (time.perf_counter() - t0) / N * 1000

# ONNX Runtime
t0 = time.perf_counter()
for _ in range(N):
    sess.run(None, {input_name: np_input})
ort_ms = (time.perf_counter() - t0) / N * 1000

print(f"PyTorch eager: {eager_ms:.2f}ms | ONNX Runtime: {ort_ms:.2f}ms | Speedup: {eager_ms/ort_ms:.2f}x")

The most common ONNX export pain point is unsupported operators. Custom attention implementations, recent activation functions, or ops added in newer PyTorch versions may not have ONNX equivalents in the opset version you are targeting. Use the highest opset version your deployment runtime supports (17 or 18 for most current ONNX Runtime versions) and check torch.onnx.export warnings carefully — a silent fallback to a generic op can eliminate the performance benefit of a fused operator. For transformer models specifically, opset 17 added a native LayerNormalization operator, and ONNX Runtime can additionally fuse attention patterns into its own optimized kernels, which together give a significant speedup over older opsets.
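If ONNX Runtime is your target, its offline transformer optimizer can apply these fusions to an exported graph. A minimal sketch, assuming the model.onnx exported above; model_type, num_heads, and hidden_size must match the exported architecture, and "bert" covers generic encoder-style blocks:

from onnxruntime.transformers import optimizer

# Fuse attention and LayerNorm patterns into ONNX Runtime's optimized kernels
opt_model = optimizer.optimize_model(
    "model.onnx", model_type="bert", num_heads=8, hidden_size=512
)
opt_model.save_model_to_file("model_opt.onnx")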

TensorRT: Maximum GPU Throughput

TensorRT is NVIDIA’s inference optimization compiler. It takes a model (from ONNX, TorchScript, or its own API) and compiles it into a hardware-specific engine that applies layer fusion, precision calibration, kernel autotuning, and memory layout optimization. The result is typically 2–5x faster than ONNX Runtime on NVIDIA GPUs for the same model, at the cost of longer build times, GPU-family-specific engines, and more complex calibration for INT8 quantization. TensorRT is the right choice when you are serving on fixed NVIDIA hardware at high throughput and need the last mile of performance after other optimizations are exhausted.

import os  # used by the INT8 calibrator's cache handling below

import tensorrt as trt
import numpy as np
import pycuda.driver as cuda
import pycuda.autoinit  # initialises CUDA context

TRT_LOGGER = trt.Logger(trt.Logger.WARNING)

def build_engine_from_onnx(onnx_path: str, fp16: bool = True,
                             max_batch: int = 8, max_seq: int = 128) -> trt.ICudaEngine:
    """Build a TensorRT engine from an ONNX model."""
    builder = trt.Builder(TRT_LOGGER)
    network = builder.create_network(
        1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
    )
    parser = trt.OnnxParser(network, TRT_LOGGER)

    with open(onnx_path, "rb") as f:
        if not parser.parse(f.read()):
            for i in range(parser.num_errors):
                print(parser.get_error(i))
            raise RuntimeError("ONNX parse failed")

    config = builder.create_builder_config()
    config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 4 << 30)  # 4GB

    if fp16 and builder.platform_has_fast_fp16:
        config.set_flag(trt.BuilderFlag.FP16)   # FP16 kernels where beneficial

    # Dynamic shape profile — must match dynamic_axes from ONNX export
    profile = builder.create_optimization_profile()
    profile.set_shape("input",
        min=(1, 1, 512),          # minimum input shape
        opt=(4, 32, 512),         # most common / optimized shape
        max=(max_batch, max_seq, 512)  # maximum shape
    )
    config.add_optimization_profile(profile)

    engine_bytes = builder.build_serialized_network(network, config)
    with trt.Runtime(TRT_LOGGER) as runtime:
        return runtime.deserialize_cuda_engine(engine_bytes)

def save_engine(engine: trt.ICudaEngine, path: str):
    with open(path, "wb") as f:
        f.write(engine.serialize())

# Engine build takes minutes — save and reuse
# engine = build_engine_from_onnx("model.onnx", fp16=True)
# save_engine(engine, "model_fp16.trt")

# For INT8, you additionally need a calibration dataset:
class Int8Calibrator(trt.IInt8EntropyCalibrator2):
    def __init__(self, data_loader, cache_file="calib.cache"):
        super().__init__()
        self.data_loader = iter(data_loader)
        self.cache_file = cache_file
        sample = next(iter(data_loader))
        self.d_input = cuda.mem_alloc(sample.numpy().nbytes)
        self.batch_size = sample.shape[0]

    def get_batch_size(self): return self.batch_size
    def get_batch(self, names):
        try:
            batch = next(self.data_loader).numpy()
            cuda.memcpy_htod(self.d_input, batch)
            return [int(self.d_input)]
        except StopIteration:
            return None
    def read_calibration_cache(self):
        if os.path.exists(self.cache_file):
            with open(self.cache_file, "rb") as f: return f.read()
    def write_calibration_cache(self, cache):
        with open(self.cache_file, "wb") as f: f.write(cache)

TensorRT engine builds are hardware-specific — an engine built on an A100 will not run on an H100 or a T4. This means you need a build step in your deployment pipeline for each GPU family you target, and engine files should not be shipped as portable artifacts. The build time scales with model size and optimization level: a small transformer block builds in seconds, a 7B LLM via torch-tensorrt can take 20–40 minutes. For LLMs specifically, the TensorRT-LLM library (NVIDIA's high-level wrapper) handles the engine build and batching infrastructure automatically and is the right entry point rather than the low-level TensorRT API shown above.

Choosing Between the Three

Use TorchScript when you need to deploy in a non-Python environment — a C++ server, a mobile app via torch.utils.mobile_optimizer.optimize_for_mobile, or a latency-sensitive Python server where you want to avoid re-tracing overhead across requests. It is the lowest-friction export path and works well for non-transformer models where TensorRT's GPU-specific tuning is less critical.

Use ONNX Runtime when you need portability across hardware (CPU, GPU, edge, cloud) or across ML frameworks, when your team does not want to manage TensorRT build pipelines, or when you are serving on CPU where ONNX Runtime's graph optimizations are valuable and TensorRT does not apply. The speedup over eager PyTorch is real and the operational overhead is low — for most production transformer inference workloads that are not at extreme throughput limits, ONNX Runtime is the practical sweet spot.

Use TensorRT when you are serving on NVIDIA GPUs at high throughput, latency is the primary constraint, and you have dedicated ML infrastructure engineering capacity to manage engine builds, calibration datasets for INT8, and hardware-specific artifacts. The additional performance over ONNX Runtime (typically 1.5–2.5x) is real but comes with meaningful operational cost. For LLM inference at scale, TensorRT-LLM or vLLM (which uses its own CUDA kernel optimizations) is almost always a better starting point than building TensorRT engines manually.

Precision and Quantization Across Export Formats

All three export paths support quantization, but with different levels of automation. TorchScript's quantization story uses PyTorch's own quantization API — static, dynamic, or quantization-aware training — applied before or after scripting. ONNX Runtime supports post-training quantization via its onnxruntime.quantization module, which can quantize an FP32 ONNX model to INT8 (dynamically, or statically with a calibration dataset) in a few lines of code. TensorRT's INT8 quantization is the most powerful but requires a calibration dataset to measure activation distributions and produces engine-specific calibration caches that are GPU-family-dependent. For FP16, all three are straightforward: pass the model in half-precision or enable FP16 mode in the export configuration. In practice, FP16 via ONNX Runtime or TensorRT typically gives 1.5–2x speedup over FP32 on modern NVIDIA GPUs with Tensor Core support, with negligible quality loss for inference on pretrained models.
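For the ONNX path, post-training dynamic quantization is genuinely a few lines; a minimal sketch using onnxruntime.quantization (dynamic quantization needs no calibration data, unlike quantize_static):

from onnxruntime.quantization import quantize_dynamic, QuantType

# Weights are quantized to INT8 offline; activations are quantized on the fly at runtime
quantize_dynamic("model.onnx", "model_int8.onnx", weight_type=QuantType.QInt8)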

Handling Dynamic Shapes Correctly

Dynamic shapes — variable batch sizes and sequence lengths — are the most common source of correctness bugs in exported models. Both ONNX and TensorRT require you to specify dynamic axes explicitly and, for TensorRT, define min/opt/max shape profiles. The optimization profile's opt shape is the most important: TensorRT builds its best-performing kernels around the opt shape and degrades gracefully for other shapes within the min/max range. Setting opt to your most common production batch and sequence length, rather than to the maximum, typically yields better real-world latency. For ONNX Runtime, dynamic axes defined at export time allow the session to handle any shape within those axes, but re-running the session with a radically different shape than it has previously seen may trigger a one-time re-optimization. Warm up ONNX Runtime sessions with representative inputs before benchmarking or load testing.
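A warm-up pass over the shape range you expect in production might look like this, reusing the sess and input_name from the ONNX Runtime example above (the shape list is illustrative):

import numpy as np

# Exercise representative (batch, seq_len) combinations so the first real
# requests do not pay one-time allocation or re-optimization costs
for batch, seq_len in [(1, 8), (4, 32), (8, 128)]:
    warm = np.random.randn(batch, seq_len, 512).astype(np.float32)
    sess.run(None, {input_name: warm})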

Validating Exported Models

Always validate numerical correctness after export — discrepancies between the PyTorch eager output and the exported model output indicate export bugs, unsupported operators falling back to generic implementations, or precision drift. The standard validation procedure is to run the same input batch through both the original PyTorch model and the exported model and check that outputs match within an acceptable tolerance. For FP32 models, absolute differences above 1e-4 on any output element warrant investigation. For FP16 or INT8 models, differences up to 1e-2 are expected due to reduced precision but larger differences suggest operator-level issues. Always test with edge-case inputs — very short sequences, large batches, inputs with many padding tokens — as these stress the dynamic shape handling in ways that typical validation inputs may not.
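A minimal correctness check for the ONNX path, reusing the eager model and ONNX Runtime session from earlier:

import numpy as np

with torch.no_grad():
    ref = model(example).numpy()  # eager PyTorch reference output
ort_out = sess.run(None, {input_name: example.numpy()})[0]

max_abs_diff = np.abs(ref - ort_out).max()
print(f"max abs diff: {max_abs_diff:.2e}")
assert max_abs_diff < 1e-4, "FP32 outputs drifted beyond tolerance; investigate the export"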

Benchmarking Correctly

Latency benchmarking of GPU inference requires careful setup to produce meaningful numbers. Always warm up the model with several forward passes before timing — the first few passes incur CUDA kernel compilation and memory allocation overhead that will never recur in production. Use GPU synchronization (torch.cuda.synchronize() for PyTorch, or equivalent for ORT) before and after the timed region to avoid measuring CPU-side scheduling overhead rather than actual GPU execution time. Measure wall-clock latency with at least 200 repetitions and report p50 and p99 rather than mean, since occasional spikes due to memory pressure or CUDA context switches inflate means. Compare across export formats at the same precision (FP16 vs FP16, not FP32 TorchScript vs FP16 TensorRT) to isolate the format's contribution from the precision benefit.
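A harness along these lines captures those rules; the bench helper below is a sketch, not a library API:

import time
import numpy as np
import torch

def bench(fn, warmup: int = 20, iters: int = 200):
    """Return (p50, p99) latency in milliseconds for a zero-argument callable."""
    for _ in range(warmup):
        fn()  # absorb one-time kernel compilation and allocation
    if torch.cuda.is_available():
        torch.cuda.synchronize()  # drain pending GPU work before timing
    times = []
    for _ in range(iters):
        t0 = time.perf_counter()
        fn()
        if torch.cuda.is_available():
            torch.cuda.synchronize()  # make sure the GPU actually finished
        times.append((time.perf_counter() - t0) * 1000)
    return np.percentile(times, 50), np.percentile(times, 99)

# p50, p99 = bench(lambda: sess.run(None, {input_name: np_input}))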

Practical Export Checklist

Before deploying an exported model to production, run through these checks.

1. Confirm that the model is in eval mode (model.eval()) and that any dropout or batch normalisation layers are in inference mode before export — training-mode stochasticity baked into a trace produces non-deterministic inference output.

2. Disable gradient computation with torch.no_grad() during export to avoid tracing autograd graph nodes into the serialized model.

3. Test with batch size 1 and your largest expected batch size explicitly — errors in dynamic shape handling almost always manifest at boundary conditions.

4. If using ONNX, run onnx.checker.check_model() on the exported file before deploying; it catches malformed graphs that would silently produce wrong outputs at inference time.

5. For TensorRT, verify that your engine's compute capability matches the target GPU — by default, engines are only portable across GPUs with the same compute capability (a 3080 and a 3090 are both Ampere SM 8.6), so building on an A100 and deploying on a T4 will fail at load time.

6. Always save the original PyTorch checkpoint alongside the exported artifact — engine files and TorchScript modules cannot be converted back to trainable PyTorch models, and you will need the original weights for any future fine-tuning or re-export.

When to Stay with Eager PyTorch

Not every model needs to be exported. If you are serving a research prototype, running inference in a Jupyter notebook, iterating rapidly on model architecture, or serving a model at low enough request volume that GPU utilisation is not a concern, the operational overhead of maintaining an export pipeline — tracking opset compatibility, rebuilding TensorRT engines when the model changes, managing calibration datasets for INT8 — is not worth the performance gain. Eager PyTorch with torch.compile (discussed in our separate article on torch.compile optimisation) provides meaningful speedups with none of the serialization complexity, and is often the right first step before committing to a full export pipeline. Export to ONNX Runtime or TensorRT makes sense when you have a stable model that you expect to serve for weeks or months without architectural changes, when throughput or latency requirements are well-defined and not met by torch.compile alone, and when your team has the infrastructure to version and redeploy engine artifacts when model updates are needed.
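For reference, the torch.compile baseline requires no export artifact at all; a minimal sketch:

compiled_model = torch.compile(model)  # compiles lazily on first call; no serialized artifact
with torch.no_grad():
    out = compiled_model(example)      # first call is slow (compilation), later calls are fast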
