How to automate PyTorch model export to ONNX for production

The real-world scenario

Imagine you are a Machine Learning Engineer who just finished training a state-of-the-art image classifier. Your DevOps team needs to deploy this model into a high-performance C++ environment or a cloud-native inference engine like NVIDIA Triton. They do not want the overhead of a full Python/PyTorch installation. This is where ONNX (Open Neural Network Exchange) serves as the universal translator.

Manually converting models every time you update weights is tedious and prone to human error, such as mismatching input dimensions. Think of this script as an Automated Shipping Inspector: it takes your custom-built furniture (the PyTorch model), packs it into a standardized shipping container (ONNX), and verifies the structural integrity before it leaves the factory.

Understand the solution

This technical recipe automates tracing a PyTorch model, defining dynamic axes (to allow varying batch sizes in production), and performing a validation check to ensure the exported graph conforms to the ONNX specification. We use pathlib for robust cross-platform path management and torch.onnx for the core transformation.

Install dependencies

Ensure you have the necessary libraries installed in your environment:

pip install torch torchvision onnx

The code


"""
-----------------------------------------------------------------------
Authors: Sharanam & Vaishali Shah
Recipe: PyTorch to ONNX Automated Exporter
Intent: Convert a PyTorch model to a validated ONNX format with dynamic batching.
-----------------------------------------------------------------------
"""

import torch
import torch.nn as nn
import torchvision.models as models
from pathlib import Path
import onnx

def export_model_to_onnx(model_name: str, output_dir: str = "models"):
    # 1. Initialize the export path
    export_path = Path(output_dir)
    export_path.mkdir(parents=True, exist_ok=True)
    file_name = export_path / f"{model_name}.onnx"

    # 2. Load a pre-trained model (this recipe uses ResNet18;
    #    model_name only determines the output file name)
    print(f"Loading pre-trained {model_name}...")
    model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
    model.eval()

    # 3. Create dummy input matching the model's expected shape
    # Shape: [Batch, Channels, Height, Width]
    dummy_input = torch.randn(1, 3, 224, 224)

    # 4. Define dynamic axes for flexible inference
    # This allows the production server to process any batch size
    dynamic_axes = {
        "input": {0: "batch_size"},
        "output": {0: "batch_size"}
    }

    # 5. Export the model
    print(f"Exporting model to {file_name}...")
    torch.onnx.export(
        model,
        dummy_input,
        str(file_name),
        export_params=True,        # Store trained parameter weights
        opset_version=12,          # Use a stable ONNX opset version
        do_constant_folding=True,  # Optimize by folding constant nodes
        input_names=["input"],     # Define input node name
        output_names=["output"],   # Define output node name
        dynamic_axes=dynamic_axes
    )

    # 6. Validate the exported model
    print("Validating ONNX model integrity...")
    onnx_model = onnx.load(str(file_name))
    try:
        onnx.checker.check_model(onnx_model)
        print(f"Success! Model exported and verified at: {file_name}")
    except onnx.checker.ValidationError as e:
        print(f"Validation failed: {e}")

if __name__ == "__main__":
    export_model_to_onnx("resnet18_production")

Review the code logic

The script follows a rigorous production-grade workflow:

  • Path Management: We use pathlib to create the models directory if it does not exist, preventing a FileNotFoundError regardless of the operating system.
  • Model State: We call model.eval(). This is critical: it disables Dropout and makes Batch Normalization use its running statistics instead of per-batch statistics, ensuring consistent inference results.
  • Dummy Input: PyTorch’s ONNX exporter uses tracing. It tracks the flow of dummy_input through the network to record the execution graph.
  • Dynamic Axes: By default, exported models have fixed input sizes. We explicitly declare batch_size as a dynamic dimension so the model can handle a batch size of 1, 8, or 32 without re-exporting.
  • Integrity Check: The onnx.checker.check_model function verifies that the exported file adheres to the ONNX specification and that every node’s operator is defined for the declared opset.

Observe the execution results

When you execute this script in your Terminal or PowerShell, you will see the following progress:


Loading pre-trained resnet18_production...
Exporting model to models/resnet18_production.onnx...
Validating ONNX model integrity...
Success! Model exported and verified at: models/resnet18_production.onnx

Final thoughts

Automating your model export pipeline reduces the friction between the Research and Production stages. By utilizing dynamic axes and automated validation, you ensure that your deployment pipeline is robust, scalable, and far less error-prone. This script provides a solid foundation for any CI/CD pipeline involving Deep Learning models.


🚀 Don’t Just Learn PyTorch — Master It.

This tutorial was just the tip of the iceberg. To truly advance your career and build professional-grade systems, you need the full architectural blueprint.

My book, PyTorch Crash Course, takes you from “making it work” to “making it scale.” I cover advanced patterns, real-world case studies, and the industry best practices that senior engineers use daily.


📖 Grab Your Copy Now →