2023-03-05

PyTorch Model Conversion and Inference with ONNX

Introduction

In this article, I will walk you through the process of converting a PyTorch model to ONNX format, checking the converted model, and performing inference using the ONNX model.

Preparing Your Model

First, let's prepare the model by installing the dependencies and loading a pretrained network.

Installing Dependencies

To get started, install the required packages using pip:

$ pip install torch torchvision onnx onnxruntime pillow

Loading a Pretrained PyTorch Model

For this example, let's use the ResNet-18 model pretrained on ImageNet:

python
import torch
import torchvision.models as models

# Load ResNet-18 with ImageNet weights; recent torchvision versions prefer
# weights=models.ResNet18_Weights.DEFAULT over the deprecated pretrained=True
model = models.resnet18(pretrained=True)
# Switch to evaluation mode so layers like batch norm behave deterministically
model.eval()

Converting the PyTorch Model to ONNX

The conversion process involves exporting the model, setting the model's input and output names and dimensions, and saving the model in ONNX format.

Exporting the Model

The first step in converting a PyTorch model to ONNX format is exporting the model using the torch.onnx.export() function. This function requires the following parameters:

  • model: The PyTorch model you want to convert.
  • args: A tuple containing the input tensors used for tracing the model. This can be a single input tensor or a tuple of input tensors if the model has multiple inputs.
  • f: A file-like object or a string containing the path to the output ONNX file.

Here's an example of exporting a pretrained ResNet-18 model:

python
import torch
import torchvision.models as models

# Load the pretrained ResNet-18 model
model = models.resnet18(pretrained=True)
model.eval()

# Create a dummy input tensor with the expected input shape
dummy_input = torch.randn(1, 3, 224, 224)

# Export the model to ONNX format
onnx_model_path = "resnet18.onnx"
torch.onnx.export(model, dummy_input, onnx_model_path)

Setting the Model Input and Output Dimensions

By default, the exported ONNX model will have input and output names assigned automatically. However, it is recommended to provide meaningful names for the input and output tensors, as they will make it easier to work with the model later on.

You can set the input and output names by passing the input_names and output_names parameters to the torch.onnx.export() function, like this:

python
input_names = ["input"]
output_names = ["output"]

torch.onnx.export(model, dummy_input, onnx_model_path, input_names=input_names, output_names=output_names)

Additionally, you can specify dynamic axes for the input and output tensors if your model supports variable input dimensions. To do this, pass the dynamic_axes parameter to the torch.onnx.export() function:

python
dynamic_axes = {
    "input": {0: "batch_size", 2: "height", 3: "width"},
    "output": {0: "batch_size"}
}

torch.onnx.export(model, dummy_input, onnx_model_path, input_names=input_names, output_names=output_names, dynamic_axes=dynamic_axes)

In this example, the dynamic axes dictionary specifies that the input tensor has variable dimensions for the batch size, height, and width, while the output tensor has a variable batch size dimension.
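
To confirm that the export picked up the dynamic axes, you can inspect the shapes recorded in the ONNX graph: dynamic dimensions appear as named parameters rather than fixed integers. Here's a minimal sketch using the onnx package (the checking workflow is covered in more detail in the next section):

python
import onnx

onnx_model = onnx.load("resnet18.onnx")

# Dynamic dimensions are stored as names ("batch_size", "height", "width"),
# while fixed dimensions are stored as integers (e.g. 3 for the channels)
for tensor in list(onnx_model.graph.input) + list(onnx_model.graph.output):
    dims = [d.dim_param or d.dim_value for d in tensor.type.tensor_type.shape.dim]
    print(tensor.name, dims)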

Checking the ONNX Model

In this section, I will explain how to check the converted ONNX model. This includes verifying the model conversion and inspecting the model layers.

Verifying Model Conversion

After converting your PyTorch model to ONNX format, it is essential to verify that the conversion was successful. The ONNX library provides a check_model() function to check whether the model is valid according to the ONNX specification.

To verify the model conversion, follow these steps:

  1. Load the ONNX model using the onnx.load() function.
  2. Check the model using the onnx.checker.check_model() function.
  3. Print a success message if the model passes the check.

Here's an example:

python
import onnx

onnx_model_path = "resnet18.onnx"
onnx_model = onnx.load(onnx_model_path)

# Check if the ONNX model is valid
onnx.checker.check_model(onnx_model)
print("Model has been successfully converted!")

Inspecting Model Layers

It can be helpful to inspect the layers of the converted ONNX model to ensure that they match the original PyTorch model. The ONNX library provides a printable_graph() function to display the model's graph in a human-readable format.

To inspect the layers of the ONNX model, follow these steps:

  1. Load the ONNX model using the onnx.load() function (if you haven't already).
  2. Obtain a printable representation of the model graph using the onnx.helper.printable_graph() function.
  3. Print the model graph.

Here's an example:

python
import onnx

onnx_model_path = "resnet18.onnx"
onnx_model = onnx.load(onnx_model_path)

# Print a human-readable representation of the ONNX model graph
printable_graph = onnx.helper.printable_graph(onnx_model.graph)
print(printable_graph)

This will print the model layers and their respective input and output tensor names. You can compare this output with the original PyTorch model to ensure that the ONNX model is a correct representation of your PyTorch model.
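
Beyond reading the graph, a more direct check is to feed the same input to both models and compare the outputs numerically. Here's a minimal sketch that reuses the model and dummy_input from the export step and uses ONNX Runtime (introduced in the next section); the tolerance values are illustrative:

python
import numpy as np
import onnxruntime
import torch

# Run the same input through the original PyTorch model...
with torch.no_grad():
    torch_output = model(dummy_input).numpy()

# ...and through the exported ONNX model
session = onnxruntime.InferenceSession("resnet18.onnx")
onnx_output = session.run(None, {session.get_inputs()[0].name: dummy_input.numpy()})[0]

# The outputs should agree up to small floating-point differences
np.testing.assert_allclose(torch_output, onnx_output, rtol=1e-3, atol=1e-5)
print("PyTorch and ONNX Runtime outputs match!")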

Inference with the Converted ONNX Model

In this section, I will show how to perform inference with the converted ONNX model: loading the model, preprocessing the input, running inference, and postprocessing the results.

Loading the ONNX Model

To perform inference with the ONNX model, you need to use the ONNX Runtime library. ONNX Runtime is a high-performance inference engine for ONNX models, designed to be compatible with various platforms and devices.

To load the ONNX model using ONNX Runtime, follow these steps:

  1. Import the onnxruntime package.
  2. Create an InferenceSession object by passing the ONNX model file path to the onnxruntime.InferenceSession() constructor.

Here's an example:

python
import onnxruntime

onnx_model_path = "resnet18.onnx"
session = onnxruntime.InferenceSession(onnx_model_path)
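
Recent ONNX Runtime versions also accept an explicit providers argument (GPU builds require it), e.g. onnxruntime.InferenceSession(onnx_model_path, providers=["CPUExecutionProvider"]). Once the session is created, you can inspect the model's expected inputs and outputs, which is useful when building the input feed:

python
# Inspect the input and output metadata of the loaded model
for inp in session.get_inputs():
    print("Input: ", inp.name, inp.shape, inp.type)
for out in session.get_outputs():
    print("Output:", out.name, out.shape, out.type)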

Preprocessing the Input

Before running inference with the ONNX model, you need to preprocess the input data to match the model's input requirements. This typically involves loading and resizing the input image, converting it to a tensor, normalizing the pixel values, and expanding the dimensions to match the expected batch size.

Here's an example of preprocessing an input image for a ResNet-18 model:

python
# Load the input image
from PIL import Image
import torchvision.transforms as transforms

image_path = "example_image.jpg"
image = Image.open(image_path).convert("RGB")

# Standard ImageNet preprocessing: resize, center-crop, and normalize
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
input_tensor = preprocess(image)

# Add a batch dimension: (3, 224, 224) -> (1, 3, 224, 224)
input_batch = input_tensor.unsqueeze(0)

Running Inference and Postprocessing Results

Once the input data has been preprocessed, you can run inference using the run() method of the InferenceSession object. The run() method takes the following parameters:

  • output_names: A list of the model's output tensor names.
  • input_feed: A dictionary mapping input tensor names to their corresponding input data.

Here's an example of running inference:

python
import numpy as np

input_name = session.get_inputs()[0].name
output_name = session.get_outputs()[0].name

# Run inference
outputs = session.run([output_name], {input_name: input_batch.numpy()})

# Postprocess the results
output_array = np.array(outputs[0])
predicted_class = np.argmax(output_array)
print("Predicted class:", predicted_class)

Now you can run inference using the converted ONNX model just like you would with a PyTorch model.

