PyTorchモデルのエクスポート
カスタムONNX演算子を含むPyTorchモデルのエクスポート
Section titled “カスタムONNX演算子を含むPyTorchモデルのエクスポート”このドキュメントでは、カスタムONNX Runtime演算子を含むPyTorchモデルをエクスポートするプロセスについて説明します。目的は、ONNXでサポートされていない演算子を持つPyTorchモデルをエクスポートし、ONNX Runtimeを拡張してこれらのカスタム演算子をサポートすることです。
組み込みContrib演算子のエクスポート
Section titled “組み込みContrib演算子のエクスポート”「Contrib演算子」とは、ほとんどのORTパッケージに組み込まれているカスタム演算子のセットを指します。 すべてのContrib演算子のシンボリック関数は、pytorch_export_contrib_ops.pyで定義する必要があります。
これらのContrib演算子を使用してエクスポートするには、torch.onnx.export()を呼び出す前にpytorch_export_contrib_ops.register()を呼び出します。例:
from onnxruntime.tools import pytorch_export_contrib_opsimport torch
pytorch_export_contrib_ops.register()torch.onnx.export(...)カスタム演算子のエクスポート
Section titled “カスタム演算子のエクスポート”Contrib演算子ではない、またはpytorch_export_contrib_opsにまだ含まれていないカスタム演算子をエクスポートするには、
カスタム演算子のシンボリック関数を記述して登録する必要があります。
逆演算子を例にとります:
from torch.onnx import register_custom_op_symbolic
def my_inverse(g, self): return g.op("com.microsoft::Inverse", self)
# register_custom_op_symbolic('<namespace>::inverse', my_inverse, <opset_version>)register_custom_op_symbolic('::inverse', my_inverse, 1)<namespace>は、torch演算子名の一部です。標準のtorch演算子の場合、名前空間は省略できます。
com.microsoftは、ONNX Runtime演算子のカスタムopsetドメインとして使用する必要があります。演算子の登録中にカスタムopsetバージョンを選択できます。
シンボリック関数の記述の詳細については、torch.onnxのドキュメントを参照してください。
ONNX Runtimeをカスタム演算子で拡張する
Section titled “ONNX Runtimeをカスタム演算子で拡張する”次のステップは、ONNX Runtimeに演算子スキーマとカーネル実装を追加することです。 詳細については、カスタム演算子を参照してください。
エンドツーエンドのテスト:エクスポートと実行
Section titled “エンドツーエンドのテスト:エクスポートと実行”カスタム演算子がエクスポーターに登録され、ONNX Runtimeに実装されると、エクスポートしてONNX Runtimeで実行できるようになります。
以下に、モデルの一部として逆演算子をエクスポートして実行するためのサンプルスクリプトを示します。
エクスポートされたモデルには、ONNX標準演算子とカスタム演算子の組み合わせが含まれています。
このテストでは、PyTorchモデルの出力とONNX Runtimeの出力を比較して、演算子のエクスポートと実装の両方をテストします。
import ioimport numpyimport onnxruntimeimport torch
class CustomInverse(torch.nn.Module): def forward(self, x): return torch.inverse(x) + x
x = torch.randn(3, 3)
# モデルをONNXにエクスポートf = io.BytesIO()torch.onnx.export(CustomInverse(), (x,), f)
model = CustomInverse()pt_outputs = model(x)
# エクスポートされたモデルをONNX Runtimeで実行ort_sess = onnxruntime.InferenceSession(f.getvalue())ort_inputs = dict((ort_sess.get_inputs()[i].name, input.cpu().numpy()) for i, input in enumerate((x,)))ort_outputs = ort_sess.run(None, ort_inputs)
# PyTorchとONNX Runtimeの結果を検証numpy.testing.assert_allclose(pt_outputs.cpu().numpy(), ort_outputs[0], rtol=1e-03, atol=1e-05)デフォルトでは、カスタムopsetのopsetバージョンは1に設定されます。
カスタム演算子をより高いopsetバージョンにエクスポートしたい場合は、エクスポートAPIを呼び出す際にcustom_opsets引数を使用してカスタムopsetドメインとバージョンを指定できます。これは、デフォルトのONNXドメインに関連付けられているopsetバージョンとは異なることに注意してください。
torch.onnx.export(CustomInverse(), (x,), f, custom_opsets={"com.microsoft": 5})カスタム演算子は、登録時に使用されたopsetバージョン以上の任意のバージョンにエクスポートできることに注意してください。