
TensorFlow Model Conversion and Inference with ONNX


In this article, I will walk you through the process of converting a TensorFlow model to ONNX format, checking the converted model, and performing inference using the ONNX model.

Preparing the TensorFlow Model

In this chapter, I will cover installing the necessary dependencies and loading the model.

Installing Dependencies

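Before starting, make sure you have a Python version supported by the TensorFlow release you plan to install (recent TensorFlow releases no longer support Python 3.6; check the TensorFlow installation guide). Then install TensorFlow and tf2onnx, plus ONNX Runtime, which is used later for inference, using pip:

$ pip install tensorflow
$ pip install tf2onnx
$ pip install onnxruntime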
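To confirm that the installation worked, you can query the installed versions without importing the (fairly heavy) packages. This is a minimal sketch using only the standard library's importlib.metadata, which needs Python 3.8 or later; the helper name installed_versions is my own and not part of any of these packages.

```python
from importlib import metadata


def installed_versions(packages):
    """Map each distribution name to its installed version, or None if it is missing."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    # onnx is pulled in as a dependency of tf2onnx, so it is checked as well.
    for name, version in installed_versions(
        ["tensorflow", "tf2onnx", "onnx", "onnxruntime"]
    ).items():
        print(f"{name}: {version or 'not installed'}")
```

A value of None simply means the package is not installed in the current environment.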

Loading the Model

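To load a TensorFlow model, you can use the load_model function from the tensorflow.keras.models module. In TensorFlow releases that ship Keras 2 (before 2.16), this function can load both the SavedModel format and the HDF5 format. Keras 3, the default from TensorFlow 2.16 onward, loads .keras and .h5 files but no longer accepts SavedModel directories.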

For a SavedModel, pass the path to the directory containing the model:

import tensorflow as tf

model_path = "path/to/your/tensorflow/saved_model"
model = tf.keras.models.load_model(model_path)

For an HDF5 model, pass the path to the .h5 file:

model_path = "path/to/your/tensorflow/model.h5"
model = tf.keras.models.load_model(model_path)
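If you are not sure which format a path points to, you can check it before calling load_model. The sketch below is a hypothetical helper, detect_keras_model_format, built on two conventions: a SavedModel directory contains a saved_model.pb file, and HDF5 files usually end in .h5 or .hdf5. It also recognizes the .keras extension used by newer Keras versions. It only inspects the path, so it cannot guarantee that the file actually loads.

```python
import os


def detect_keras_model_format(path):
    """Guess the saved-model format from a path (hypothetical helper, inspects names only)."""
    if os.path.isdir(path):
        # SavedModel directories contain a saved_model.pb file at their top level.
        if os.path.isfile(os.path.join(path, "saved_model.pb")):
            return "savedmodel"
        return "unknown"
    extension = os.path.splitext(path)[1].lower()
    if extension in (".h5", ".hdf5"):
        return "hdf5"
    if extension == ".keras":
        return "keras"
    return "unknown"


print(detect_keras_model_format("path/to/your/tensorflow/model.h5"))  # hdf5
```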

If your model has custom layers or components, you may need to provide a custom_objects dictionary to the load_model function:

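# CustomLayer and custom_metric stand for your own class and function;
# they must be defined or imported before the model is loaded.
custom_objects = {
    'CustomLayer': CustomLayer,
    'custom_metric': custom_metric,
}

model = tf.keras.models.load_model(model_path, custom_objects=custom_objects)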

Converting the Model to ONNX

In this chapter, I will explore converting a TensorFlow model to ONNX format with the tf2onnx library. We will discuss the conversion itself and some of the options for customizing it.

Performing the Conversion with tf2onnx

To convert a Keras model to ONNX format, we will use the convert.from_keras function from the tf2onnx module. It takes the model as its main argument and returns a tuple: the ONNX model as a protobuf ModelProto object, and an external tensor storage dictionary that is only populated when large_model=True:

import tf2onnx

# Perform the conversion
model_proto, _ = tf2onnx.convert.from_keras(model)

Next, save the protobuf object as an ONNX file:

output_onnx_model = "path/to/output/onnx/model.onnx"

with open(output_onnx_model, "wb") as f:
    f.write(model_proto.SerializeToString())

Customizing the Conversion Process

You can customize the conversion process by providing additional arguments to the from_keras() function. Some of the commonly used arguments are:

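  • opset: Specify the target ONNX opset version. If omitted, tf2onnx uses its own default opset, which depends on the installed tf2onnx version.
model_proto, _ = tf2onnx.convert.from_keras(model, opset=12)
  • large_model: ONNX models are serialized as protobuf, which has a 2 GB size limit. For larger models, enable this option so that the weights are stored as external tensors, returned in the second element of the tuple, instead of inside the protobuf.
model_proto, external_tensor_storage = tf2onnx.convert.from_keras(model, large_model=True)
  • input_signature: Describe the model inputs (shape, dtype, and name) with tf.TensorSpec objects. This is useful when the Keras model does not fix its input shape, and the name you give is typically used as the input name in the ONNX graph. Note that from_keras has no output_names argument (the lower-level from_graph_def function does).
spec = (tf.TensorSpec((None, 224, 224, 3), tf.float32, name="input"),)
model_proto, _ = tf2onnx.convert.from_keras(model, input_signature=spec)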

For more customization options, refer to the tf2onnx documentation.


Checking the Converted ONNX Model

In this chapter, I will discuss how to inspect and validate the converted ONNX model. We will cover methods for verifying the model's structure and testing its accuracy.

Verifying the Model Structure

To validate the structure of the converted ONNX model, load it with the onnx library and run the ONNX checker on it:

import onnx

onnx_model = onnx.load(output_onnx_model)
onnx.checker.check_model(onnx_model)

check_model raises onnx.checker.ValidationError if the model is malformed, for example if it uses an operator that is not defined in its declared opset. If the model passes the check, you can proceed to test its accuracy.

Testing the Model's Accuracy

To test the accuracy of the ONNX model, run inference on both the original TensorFlow model and the converted ONNX model, and compare their outputs. You can use the following steps:

  1. Prepare a sample input for the models (this example assumes an image model with a 224x224x3 input):
import numpy as np

input_data = np.random.rand(1, 224, 224, 3).astype(np.float32)
  2. Run inference on the TensorFlow model:
tf_output = model.predict(input_data)
  3. Run inference on the ONNX model using ONNX Runtime:
from onnxruntime import InferenceSession

sess = InferenceSession(output_onnx_model, providers=["CPUExecutionProvider"])
input_name = sess.get_inputs()[0].name
# run() returns a list with one array per output; [0] assumes a single-output model
onnx_output = sess.run(None, {input_name: input_data})[0]
  4. Compare the outputs and calculate the maximum absolute difference:
difference = np.abs(tf_output - onnx_output).max()
print("Difference:", difference)
  5. Check whether the difference is within an acceptable tolerance (float32 results rarely match bit for bit, so adjust rtol and atol to your model):
assert np.allclose(onnx_output, tf_output, rtol=1e-05, atol=1e-05), "Output values do not match"
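For models with several outputs, or when you want a report rather than a bare assert, the comparison from the steps above can be wrapped in a small helper. The compare_outputs function below is a hypothetical sketch: it assumes the TensorFlow outputs and the ONNX Runtime outputs come in the same order, and it uses np.allclose semantics with the TensorFlow output as the reference.

```python
import numpy as np


def compare_outputs(tf_outputs, onnx_outputs, rtol=1e-5, atol=1e-5):
    """Compare TensorFlow and ONNX Runtime outputs pairwise (hypothetical helper).

    tf_outputs: one array, or a list of arrays for a multi-output model.
    onnx_outputs: the list returned by InferenceSession.run(None, ...).
    """
    if isinstance(tf_outputs, np.ndarray):
        tf_outputs = [tf_outputs]
    if len(tf_outputs) != len(onnx_outputs):
        raise ValueError(
            f"{len(tf_outputs)} TensorFlow outputs vs {len(onnx_outputs)} ONNX outputs"
        )

    report = []
    for index, (expected, actual) in enumerate(zip(tf_outputs, onnx_outputs)):
        expected = np.asarray(expected)
        actual = np.asarray(actual)
        if expected.shape != actual.shape:
            raise ValueError(f"Output {index}: shape {expected.shape} vs {actual.shape}")
        report.append({
            "index": index,
            "max_abs_diff": float(np.abs(expected - actual).max()),
            # np.allclose checks |actual - expected| <= atol + rtol * |expected|
            "allclose": bool(np.allclose(actual, expected, rtol=rtol, atol=atol)),
        })
    return report


# Usage with the objects from the steps above:
# report = compare_outputs(model.predict(input_data), sess.run(None, {input_name: input_data}))
```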

Inference using the Converted ONNX Model

In this chapter, I will demonstrate how to perform inference using the converted ONNX model. We will discuss how to load the ONNX model with ONNX Runtime, prepare input data, run inference, and process the output.

Loading the ONNX Model

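To load the converted ONNX model, use ONNX Runtime's InferenceSession class. Since ONNX Runtime 1.9, the providers argument must be set explicitly when the installed build has more than one execution provider (for example, a GPU build), so it is passed here as well:

from onnxruntime import InferenceSession

sess = InferenceSession(output_onnx_model, providers=["CPUExecutionProvider"])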

Preparing Input Data

Prepare the input data for the model. The input data should match the expected input shape and data type of the model. For example:

import numpy as np

input_data = np.random.rand(1, 224, 224, 3).astype(np.float32)
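Rather than hard-coding the shape, you can derive it from the session: each entry of sess.get_inputs() is a NodeArg with name, shape, and type attributes, where dynamic dimensions appear as strings or None and the type is a string such as "tensor(float)". The make_random_input helper below is a hypothetical sketch that builds random data from that metadata. It sets every dynamic dimension (usually just the batch axis) to one fixed size and only maps a handful of common element types, both of which are assumptions.

```python
import numpy as np

# NodeArg.type strings mapped to NumPy dtypes; only common types are covered (assumption).
ORT_TYPE_TO_NUMPY = {
    "tensor(float)": np.float32,
    "tensor(double)": np.float64,
    "tensor(float16)": np.float16,
    "tensor(int64)": np.int64,
    "tensor(int32)": np.int32,
    "tensor(uint8)": np.uint8,
    "tensor(bool)": np.bool_,
}


def make_random_input(node_arg, dynamic_dim_size=1, seed=0):
    """Build random data matching an ONNX Runtime NodeArg (hypothetical helper).

    Dynamic dimensions (strings such as "N", or None) are all set to dynamic_dim_size.
    """
    dtype = ORT_TYPE_TO_NUMPY.get(node_arg.type)
    if dtype is None:
        raise ValueError(f"Unsupported input type: {node_arg.type}")
    shape = [dim if isinstance(dim, int) else dynamic_dim_size for dim in node_arg.shape]
    rng = np.random.default_rng(seed)
    if np.issubdtype(dtype, np.floating):
        return rng.random(shape).astype(dtype)
    if dtype is np.bool_:
        return rng.random(shape) < 0.5
    return rng.integers(0, 10, size=shape).astype(dtype)


# With the session from above:
# input_meta = sess.get_inputs()[0]
# input_data = make_random_input(input_meta)
```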

Running Inference

To run inference with the ONNX model, call the run method on the InferenceSession object. You need to provide a dictionary that maps input names to their corresponding input data:

input_name = sess.get_inputs()[0].name
output = sess.run(None, {input_name: input_data})

In this example, we pass None as the first argument to the run method, which means that the outputs of all output nodes will be returned. Alternatively, you can provide a list of output names to obtain the outputs of specific nodes.

Processing the Output

After running inference, you can process the output to obtain the desired information or predictions. The output format will depend on the specific model and task. For example, if the model is a classifier, you might want to find the class with the highest probability:

predicted_class = np.argmax(output[0], axis=-1)
print("Predicted class:", predicted_class)
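Often you want more than the single best class. The top_k function below is a hypothetical NumPy helper that returns the k highest-scoring classes per sample. Because some classifiers output raw logits rather than probabilities, it can optionally apply a numerically stable softmax first; whether your model needs that step depends on its final layer.

```python
import numpy as np


def top_k(scores, k=5, apply_softmax=False):
    """Return the indices and scores of the k highest-scoring classes per row (hypothetical helper)."""
    scores = np.asarray(scores, dtype=np.float64)
    if apply_softmax:
        # Subtract each row's maximum before exponentiating to avoid overflow.
        shifted = scores - scores.max(axis=-1, keepdims=True)
        exp_scores = np.exp(shifted)
        scores = exp_scores / exp_scores.sum(axis=-1, keepdims=True)
    k = min(k, scores.shape[-1])
    # argsort sorts ascending; reverse the last axis for descending order, then keep k.
    indices = np.argsort(scores, axis=-1)[..., ::-1][..., :k]
    values = np.take_along_axis(scores, indices, axis=-1)
    return indices, values


logits = np.array([[1.0, 3.0, 2.0]])
indices, probs = top_k(logits, k=2, apply_softmax=True)
print(indices)  # [[1 2]]
# With the model above: indices, scores = top_k(output[0], k=5)
```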



Ryusei Kakujo


Weave the future of cities through data

Transportation modeling/ Urban planning/ Machine learning/ Computer science/ GIS