How to automate PyTorch model export to ONNX for production
The real-world scenario
Imagine you are a Machine Learning Engineer who just finished training a state-of-the-art image classifier. Your DevOps team needs to deploy this model into a high-performance C++ environment or a cloud-native inference engine like NVIDIA Triton. They do not want the overhead of a full Python/PyTorch installation. This is where ONNX (Open Neural Network Exchange) serves as the universal translator.
Manually converting models every time you update weights is tedious and prone to human error, such as mismatching input dimensions. Think of this script as an Automated Shipping Inspector: it takes your custom-built furniture (the PyTorch model), packs it into a standardized shipping container (ONNX), and verifies the structural integrity before it leaves the factory.
Understand the solution
This technical recipe automates the process of tracing a PyTorch model, defining dynamic axes (to allow varying batch sizes in production), and performing a validation check to ensure the exported graph is mathematically sound. We utilize pathlib for robust cross-platform path management and torch.onnx for the core transformation.
Install dependencies
Ensure you have the necessary libraries installed in your environment:
pip install torch torchvision onnx
The code
"""
-----------------------------------------------------------------------
Authors: Sharanam & Vaishali Shah
Recipe: PyTorch to ONNX Automated Exporter
Intent: Convert a PyTorch model to a validated ONNX format with dynamic batching.
-----------------------------------------------------------------------
"""
import torch
import torch.nn as nn
import torchvision.models as models
from pathlib import Path
import onnx
def export_model_to_onnx(model_name: str, output_dir: str = "models"):
    # 1. Initialize the export path
    export_path = Path(output_dir)
    export_path.mkdir(parents=True, exist_ok=True)
    file_name = export_path / f"{model_name}.onnx"

    # 2. Load a pre-trained model (Example: ResNet18)
    print(f"Loading pre-trained {model_name}...")
    model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
    model.eval()

    # 3. Create dummy input matching the model's expected shape
    #    Shape: [Batch, Channels, Height, Width]
    dummy_input = torch.randn(1, 3, 224, 224)

    # 4. Define dynamic axes for flexible inference
    #    This allows the production server to process any batch size
    dynamic_axes = {
        "input": {0: "batch_size"},
        "output": {0: "batch_size"}
    }

    # 5. Export the model
    print(f"Exporting model to {file_name}...")
    torch.onnx.export(
        model,
        dummy_input,
        str(file_name),
        export_params=True,        # Store trained parameter weights
        opset_version=12,          # Use a stable ONNX opset
        do_constant_folding=True,  # Optimize by folding constant nodes
        input_names=["input"],     # Define input node name
        output_names=["output"],   # Define output node name
        dynamic_axes=dynamic_axes
    )

    # 6. Validate the exported model
    print("Validating ONNX model integrity...")
    onnx_model = onnx.load(str(file_name))
    try:
        onnx.checker.check_model(onnx_model)
        print(f"Success! Model exported and verified at: {file_name}")
    except onnx.checker.ValidationError as e:
        print(f"Validation failed: {e}")


if __name__ == "__main__":
    export_model_to_onnx("resnet18_production")
Review the code logic
The script follows a rigorous production-grade workflow:
- Path Management: We use pathlib to build the output path in a cross-platform way and create the models directory if it does not exist, preventing a FileNotFoundError at export time.
- Model State: We call model.eval(). This is critical because it disables Dropout and switches Batch Normalization to its inference behavior (using running statistics instead of batch statistics), ensuring consistent inference results.
- Dummy Input: PyTorch’s ONNX exporter uses tracing. It tracks the flow of a dummy_input through the network to map the execution graph. Note that tracing records only the operations executed for that particular input, so data-dependent control flow (e.g., an if branch on tensor values) will not be captured.
- Dynamic Axes: By default, exported models have fixed input sizes. We explicitly define batch_size as a dynamic dimension so the model can handle a batch size of 1, 8, or 32 without re-exporting.
- Integrity Check: The onnx.checker.check_model function verifies that the exported file adheres to the ONNX specification and that all operators are supported.
Observe the execution results
When you execute this script in your Terminal or PowerShell, you will see the following progress:
Loading pre-trained resnet18_production...
Exporting model to models/resnet18_production.onnx...
Validating ONNX model integrity...
Success! Model exported and verified at: models/resnet18_production.onnx
Final thoughts
Automating your model export pipeline reduces the friction between the Research and Production stages. By utilizing dynamic axes and automated validation, you ensure that your deployment pipeline is robust, scalable, and far less error-prone. This script provides a solid foundation for any CI/CD pipeline involving Deep Learning models.
🚀 Don’t Just Learn PyTorch — Master It.
This tutorial was just the tip of the iceberg. To truly advance your career and build professional-grade systems, you need the full architectural blueprint.
My book, PyTorch Crash Course, takes you from “making it work” to “making it scale.” I cover advanced patterns, real-world case studies, and the industry best practices that senior engineers use daily.