はじめに
この記事では、PyTorchモデルをONNX形式に変換し、変換されたモデルの検証、ONNXモデルを使用した推論のプロセスについて説明しています。
モデルの準備
まずはモデルの準備をします。
依存関係のインストール
必要なパッケージをpipを使ってインストールします。
$ pip install torch onnx onnxruntime
学習済みのPyTorchモデルの読み込み
この例では、ImageNetで事前学習されたResNet-18モデルを使用します。
import torch
import torchvision.models as models
model = models.resnet18(pretrained=True)
model.eval()
PyTorchモデルをONNXに変換
変換のプロセスには、モデルのエクスポート、モデルの入力と出力の寸法の設定、およびONNX形式でモデルを保存することが含まれます。
モデルのエクスポート
PyTorchモデルをONNX形式に変換する最初のステップは、torch.onnx.export()
関数を使用してモデルをエクスポートすることです。この関数には、次のパラメータが必要です。
model
: 変換したいPyTorchモデルargs
: モデルをトレースするために使用される入力テンソルのタプル。モデルに複数の入力がある場合は、単一の入力テンソルまたは入力テンソルのタプルを指定できます。f
: 出力ONNXファイルへのファイルライクオブジェクトまたはパスを含む文字列。
次に、事前学習済みのResNet-18モデルのエクスポートの例を示します。
import torch
import torchvision.models as models
# Load the pretrained ResNet-18 model
model = models.resnet18(pretrained=True)
model.eval()
# Create a dummy input tensor with the expected input shape
dummy_input = torch.randn(1, 3, 224, 224)
# Export the model to ONNX format
onnx_model_path = "resnet18.onnx"
torch.onnx.export(model, dummy_input, onnx_model_path)
モデルの入力と出力の寸法の設定
デフォルトでは、エクスポートされたONNXモデルには自動的に入力名と出力名が割り当てられます。しかし、入力と出力テンソルに意味のある名前を指定することをお勧めします。これにより、後でモデルを扱うことがより容易になります。
input_names
およびoutput_names
パラメータをtorch.onnx.export()
関数に渡すことで、入力と出力の名前を設定できます。
input_names = ["input"]
output_names = ["output"]
torch.onnx.export(model, dummy_input, onnx_model_path, input_names=input_names, output_names=output_names)
さらに、モデルが可変の入力次元をサポートしている場合は、入力および出力テンソルの動的な軸を指定することもできます。これを行うには、torch.onnx.export()
関数にdynamic_axes
パラメータを渡します。
dynamic_axes = {
"input": {0: "batch_size", 2: "height", 3: "width"},
"output": {0: "batch_size"}
}
torch.onnx.export(model, dummy_input, onnx_model_path, input_names=input_names, output_names=output_names, dynamic_axes=dynamic_axes)
この例では、動的軸の辞書が入力テンソルに対して、バッチサイズ、高さ、幅の変数の次元を指定し、出力テンソルには可変のバッチサイズ次元を指定しています。
ONNXモデルの検証
変換されたONNXモデルの検証方法について説明します。これには、モデル変換の確認と、モデルレイヤーの調査が含まれます。
モデル変換の確認
PyTorchモデルをONNX形式に変換した後、変換が成功したかどうかを確認することが重要です。ONNXライブラリには、モデルがONNX仕様に従って有効であるかどうかを確認するcheck_model()
関数が用意されています。
モデル変換を確認するには、次の手順に従ってください。
onnx.load()
関数を使用してONNXモデルをロードonnx.checker.check_model()
関数を使用してモデルを確認- モデルがチェックに合格した場合は、成功メッセージを出力
import onnx
onnx_model_path = "resnet18.onnx"
onnx_model = onnx.load(onnx_model_path)
# Check if the ONNX model is valid
onnx.checker.check_model(onnx_model)
print("Model has been successfully converted!")
モデルレイヤーの調査
変換されたONNXモデルのレイヤーを調査することは、オリジナルのPyTorchモデルと一致しているかどうかを確認するのに役立ちます。ONNXライブラリには、printable_graph()
関数を使用して、人が読める形式でモデルのグラフを表示する機能があります。
ONNXモデルのレイヤーを調べるには、次の手順に従ってください。
onnx.load()
関数を使用してONNXモデルをロード(まだロードしていない場合)onnx.helper.printable_graph()
関数を使用して、モデルグラフの人が読める表現を取得- モデルグラフを出力
import onnx
onnx_model_path = "resnet18.onnx"
onnx_model = onnx.load(onnx_model_path)
# Print a human-readable representation of the ONNX model graph
printable_graph = onnx.helper.printable_graph(onnx_model.graph)
print(printable_graph)
これにより、モデルレイヤーとそれらの入力および出力テンソル名が出力されます。この出力を、オリジナルのPyTorchモデルと比較して、ONNXモデルが正しい表現であることを確認できます。
ONNXモデルを使用した推論
変換されたONNXモデルを使用した推論の方法について説明します。これには、ONNXモデルのロード、入力の前処理、推論の実行、および結果の後処理が含まれます。
ONNXモデルのロード
ONNXモデルを使用して推論を行うには、ONNX Runtimeライブラリを使用する必要があります。ONNX Runtimeは、さまざまなプラットフォームやデバイスに対応する高速推論エンジンで、ONNXモデルに互換性があります。
ONNX Runtimeを使用してONNXモデルをロードするには、次の手順に従ってください。
onnxruntime
パッケージをインポートonnxruntime.InferenceSession()
コンストラクタにONNXモデルファイルパスを渡してInferenceSession
オブジェクトを作成
import onnxruntime
onnx_model_path = "resnet18.onnx"
session = onnxruntime.InferenceSession(onnx_model_path)
入力の前処理
ONNXモデルで推論を実行する前に、入力データをモデルの入力要件に合わせて前処理する必要があります。これには、入力イメージの読み込みとリサイズ、テンソルへの変換、ピクセル値の正規化、および期待されるバッチサイズに合わせて次元を拡張することが一般的です。
以下は、ResNet-18モデルの入力イメージを前処理する例です。
# Load and preprocess an image
from PIL import Image
import torchvision.transforms as transforms
image_path = "example_image.jpg"
image = Image.open(image_path).convert("RGB")
preprocess = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
input_tensor = preprocess(image)
input_batch = input_tensor.unsqueeze(0)
推論の実行
入力データの前処理が完了したら、InferenceSession
オブジェクトのrun()
メソッドを使用して、推論を実行できます。run()
メソッドは、次のパラメータを受け取ります。
output_names
: モデルの出力テンソル名のリストinput_feed
: 入力テンソル名と入力データの対応するディクショナリ
推論を実行する例を次に示します。
import numpy as np
input_name = session.get_inputs()[0].name
output_name = session.get_outputs()[0].name
# Run inference
outputs = session.run([output_name], {input_name: input_batch.numpy()})
# Postprocess the results
output_array = np.array(outputs[0])
predicted_class = np.argmax(output_array)
print("Predicted class:", predicted_class)
これで、変換されたONNXモデルを使用して、PyTorchモデルと同様に推論を実行できるようになりました。
参考