もう迷わない!PyTorch fx.Graph.call_module()のエラーと解決策
torch.fx.Graph.call_module()
とは
torch.fx
では、PyTorch モデル (nn.Module
) の forward
メソッドが実行される際に、その内部の演算がノードという形でグラフに記録されます。このノードにはいくつかの種類があり、その中の一つが call_module
ノードです。
call_module
ノードは、ある nn.Module
の forward
メソッドが、その子モジュール(サブモジュール)を呼び出していることを表します。
具体的には、以下のような情報がノードとして記録されます。
kwargs
: 呼び出しに使われるキーワード引数(辞書)。args
: 呼び出しに使われる位置引数(タプル)。target
: 呼び出されているモジュールの完全修飾名(例:self.linear
など)。name
: ノードの一意な名前です。opcode
: ノードの種類を表します。call_module
の場合は'call_module'
となります。
例
簡単な例で考えてみましょう。
import torch
import torch.nn as nn
import torch.fx
class MySubModule(nn.Module):
def forward(self, x):
return x * 2
class MyModule(nn.Module):
def __init__(self):
super().__init__()
self.linear = nn.Linear(10, 5)
self.sub_module = MySubModule()
def forward(self, x):
x = self.linear(x)
x = self.sub_module(x)
return x
# モデルをトレースしてグラフを生成
m = MyModule()
traced_graph = torch.fx.symbolic_trace(m)
# グラフのノードを出力
for node in traced_graph.graph.nodes:
print(f"opcode: {node.op}, name: {node.name}, target: {node.target}, args: {node.args}, kwargs: {node.kwargs}")
このコードを実行すると、以下のような出力の一部が得られるでしょう(完全な出力ではありませんが、関連する部分を抜粋します)。
opcode: placeholder, name: x, target: x, args: (), kwargs: {}
opcode: call_module, name: linear, target: linear, args: (x,), kwargs: {}
opcode: call_module, name: sub_module, target: sub_module, args: (linear,), kwargs: {}
opcode: output, name: output, target: output, args: (sub_module,), kwargs: {}
ここで注目すべきは、opcode: call_module
の行です。
name: sub_module
,target: sub_module
: これはMyModule
のself.sub_module
が呼び出されていることを示します。name: linear
,target: linear
: これはMyModule
のself.linear
が呼び出されていることを示します。
torch.fx
を使用してモデルを変換したり最適化したりする際に、call_module
ノードは非常に重要な役割を果たします。
- モデルの構造の理解: グラフを走査することで、どのサブモジュールが、どのような順番で、どのような引数で呼び出されているかを把握できます。
- 変換のターゲット: 特定の種類のモジュール(例:
nn.Conv2d
やnn.BatchNorm2d
)の呼び出しを特定し、それらを別の実装に置き換えたり、結合したりする(例: Conv-BN融合)といった変換を行う際に、call_module
ノードをターゲットとします。 - 部分的な最適化: モデル全体ではなく、特定のサブモジュールに対して量子化や枝刈りなどの最適化を適用する場合、
call_module
ノードを使ってそのサブモジュールを識別します。
torch.fx.Graph.call_module()
の一般的なエラーとトラブルシューティング
ModuleNotFoundError: No module named 'torch.fx'
これは torch.fx
を使用する上で最も基本的なエラーです。
- トラブルシューティング
- PyTorch のバージョンを確認します。
torch.fx
は PyTorch 1.8.0 以降で導入されました。それ以前のバージョンを使用している場合は、PyTorch をアップグレードする必要があります。pip install torch torchvision torchaudio --upgrade
import torch.fx
またはfrom torch.fx import symbolic_trace
のように正しくインポートされているか確認します。
- PyTorch のバージョンを確認します。
- エラーの原因
torch.fx
モジュールが見つからない。
torch.fx.symbolic_trace が意図したとおりに動作しない(Graph に call_module ノードが欠落している、または不正確)
これは、モデルの forward
メソッドにトレースできない操作が含まれている場合に起こりやすいです。
- トラブルシューティング
- トレース可能なコードに変換
- 動的な制御フローを避けるか、テンソルの形状に依存しないように修正します。可能であれば、
torch.where
のようなテンソル操作に置き換えることを検討します。 - リストや辞書を介したテンソルの受け渡しを避け、テンソル自体を直接操作するようにします。
- サポートされていない操作がある場合、それらをカスタムの
nn.Module
にラップし、Tracer
のis_leaf_module
メソッドをオーバーライドして、そのカスタムモジュールを「葉」として扱うように指定することで、その内部のトレースをスキップできます。import torch import torch.nn as nn from torch.fx import Tracer, symbolic_trace, GraphModule class MyUnfriendlyOp(nn.Module): def forward(self, x): # トレースしにくい操作 (例: 動的なリスト生成) return torch.stack([x, x * 2]) class MyModule(nn.Module): def __init__(self): super().__init__() self.linear = nn.Linear(10, 10) self.unfriendly = MyUnfriendlyOp() def forward(self, x): x = self.linear(x) x = self.unfriendly(x) return x # MyUnfriendlyOp を葉モジュールとして扱うカスタム Tracer class CustomTracer(Tracer): def is_leaf_module(self, m: nn.Module, module_qualified_name: str) -> bool: if isinstance(m, MyUnfriendlyOp): return True # このモジュールは内部をトレースしない return super().is_leaf_module(m, module_qualified_name) m = MyModule() traced_graph_module = symbolic_trace(m, tracer_class=CustomTracer) # この場合、MyUnfriendlyOp は call_module ノードとして現れる for node in traced_graph_module.graph.nodes: print(node)
- 動的な制御フローを避けるか、テンソルの形状に依存しないように修正します。可能であれば、
- print デバッグ
symbolic_trace
の前後でモデルのforward
メソッドにprint
ステートメントを追加し、トレースがどこで中断されているか、またはどの値が期待と異なるかを確認します。
- トレース可能なコードに変換
- エラーの原因
- Python の動的な制御フロー
if
/else
、ループ (for
,while
) など、入力テンソルの値に依存する制御フローは、静的なグラフとして表現することができません。call_module
が条件分岐の内側にある場合、そのモジュールがトレースされないことがあります。 - Python のネイティブなデータ構造の操作
リスト、辞書などの Python オブジェクトを直接操作する(特にテンソル以外のデータを扱う)と、トレースが中断されることがあります。 - サポートされていないPyTorchの関数/オペレーション
torch.fx
はほとんどのtorch.nn
モジュールやtorch
の関数をサポートしていますが、一部の特殊な操作(例:torch.arange
でサイズが動的に決定される場合など)はトレースできないことがあります。 - 外部ライブラリの呼び出し
NumPyなどのPyTorch以外のライブラリの関数を直接呼び出すと、トレースが中断されます。
- Python の動的な制御フロー
AttributeError: 'GraphModule' object has no attribute 'xxx'
グラフ変換後に、オリジナルの nn.Module
の属性が GraphModule
に引き継がれていない場合に発生することがあります。
- トラブルシューティング
- GraphModule への属性の追加
必要な属性がGraphModule
に存在しない場合、手動で追加するか、変換ロジックで考慮する必要があります。import torch.nn as nn from torch.fx import symbolic_trace, GraphModule class MyModule(nn.Module): def __init__(self): super().__init__() self.linear = nn.Linear(10, 1) self.custom_value = 42 # forward で使わない def forward(self, x): return self.linear(x) m = MyModule() traced_gm = symbolic_trace(m) # traced_gm.custom_value は存在しないため AttributeError になる # print(traced_gm.custom_value) # 必要な場合は手動で追加 traced_gm.custom_value = m.custom_value print(traced_gm.custom_value)
- サブモジュールの管理
call_module
ノードを操作する際、そのノードが参照するサブモジュールがGraphModule
の_modules
ディクショナリに適切に登録されていることを確認してください。GraphModule.add_submodule()
やGraphModule.delete_submodule()
を適切に使用します。
- GraphModule への属性の追加
- エラーの原因
torch.fx.symbolic_trace
は、モデルのforward
メソッドを通じてアクセスされる属性のみをグラフに含めます。例えば、__init__
で定義されているがforward
で使われていない属性や、トレース後に手動で追加した属性は、生成されたGraphModule
には存在しません。特に、call_module
ノードを削除したり、そのtarget
を変更したりすると、参照が壊れる可能性があります。
変換後の GraphModule が元のモデルと同じ出力を生成しない
これはデバッグが難しい場合が多いですが、call_module
ノードの引数や出力の不一致が原因となることがあります。
- トラブルシューティング
- graph.lint() の使用
グラフを変更するたびに、またはデバッグの際にはgraph.lint()
を呼び出す習慣をつけましょう。これにより、多くの不正なグラフ構造の変更を早期に検出できます。 - ノードの入出力の確認
print(node.args)
やprint(node.kwargs)
を使って、各call_module
ノードに渡されている引数を詳細に確認します。特に、変更後のグラフでそれが正しいノードを参照しているかを確認します。 - 中間出力の比較
変換前と変換後のモデルで、特定の中間層の出力を比較します。これにより、どこで出力が乖離し始めたか特定しやすくなります。 - 小さな単位で変更しテスト
一度に大きな変更を加えるのではなく、小さな変換を適用してはテストを繰り返すことで、問題の切り分けが容易になります。
- graph.lint() の使用
- エラーの原因
- ノードの引数/キーワード引数の不正確な変更
call_module
ノードのargs
やkwargs
を変更した際、その変更が元のモジュールの期待する入力形式と異なる場合。 - 依存関係の誤り
ノード間の依存関係(node.args
やnode.kwargs
が他のノードの出力を参照している場合)を正しく管理できていない。 - グラフの健全性チェックの怠り
グラフを変更した後、graph.lint()
を呼び出して、グラフが有効な状態であることを確認していない。
- ノードの引数/キーワード引数の不正確な変更
TypeError: forward() missing N required positional arguments / TypeError: got an unexpected keyword argument 'xxx'
これは、call_module
ノードの引数が、呼び出されるサブモジュールの forward
メソッドのシグネチャと一致しない場合に発生します。
- トラブルシューティング
- シグネチャの確認
呼び出されるサブモジュールのforward
メソッドの正確なシグネチャ(引数の名前、順序、デフォルト値など)を確認します。 - ノードの引数の修正
call_module
ノードのargs
とkwargs
を、サブモジュールのforward
メソッドに合うように調整します。 - Python の inspect モジュール
inspect.signature
を使用して、モジュールのforward
メソッドのシグネチャをプログラム的に取得し、それに基づいてノードの引数を生成することができます。
- シグネチャの確認
- エラーの原因
call_module
ノードのargs
やkwargs
を手動で操作した際に、引数の数や名前がサブモジュールのforward
メソッドと合わなくなった。- オリジナルのモジュールが複雑な
forward
シグネチャ(例:*args
,**kwargs
を多用)を持っており、symbolic_trace
がそれを正確に再現できなかった。
- GraphModule の code プロパティの確認
変換後のGraphModule
のcode
プロパティ(print(traced_gm.code)
)を見ると、生成された Python コードを確認できます。これにより、意図しない挙動になっている箇所を見つけやすくなります。 - 段階的なアプローチ
複雑なモデルの場合、一度に全体をトレース・変換しようとせず、小さなサブモジュールごとに試したり、段階的に変換を適用したりすることで、問題の原因を特定しやすくなります。 - PyTorch フォーラムやGitHub Issuesの検索
遭遇したエラーメッセージや状況は、他のユーザーも経験している可能性があります。フォーラムやGitHubで検索することで、解決策が見つかることがあります。 - PyTorch と torch.fx のドキュメント参照
公式ドキュメントは最も正確で最新の情報源です。特にtorch.fx
の章は、内部動作を理解するために非常に役立ちます。 - 最小限の再現コード
エラーが発生した場合は、問題を再現できる最小限のコードを作成するように努めます。これにより、問題を特定しやすくなります。
以下の例では、モデルのトレース、call_module
ノードの識別、およびそのノードを操作する基本的な方法を示します。
例1: モデルのトレースと call_module
ノードの識別
この例では、シンプルなモデルをトレースし、生成されたグラフから call_module
ノードを抽出し、その情報を表示します。
import torch
import torch.nn as nn
import torch.fx
# 1. シンプルなPyTorchモデルの定義
class SubModuleA(nn.Module):
def forward(self, x):
print("Executing SubModuleA")
return x + 1
class SubModuleB(nn.Module):
def forward(self, x):
print("Executing SubModuleB")
return x * 2
class MyModel(nn.Module):
def __init__(self):
super().__init__()
self.linear1 = nn.Linear(10, 5)
self.relu = nn.ReLU()
self.sub_a = SubModuleA()
self.sub_b = SubModuleB()
self.linear2 = nn.Linear(5, 1)
def forward(self, x):
x = self.linear1(x)
x = self.relu(x)
x = self.sub_a(x)
x = self.sub_b(x)
x = self.linear2(x)
return x
# 2. モデルのインスタンス化とダミー入力の準備
model = MyModel()
dummy_input = torch.randn(1, 10)
# 3. モデルをトレースしてグラフを生成
# symbolic_trace はモデルの forward メソッドのシンボリック実行を行い、グラフを構築します
traced_model = torch.fx.symbolic_trace(model)
print("--- 生成されたグラフのノード ---")
# 4. グラフのノードをイテレートし、call_module ノードを識別
for node in traced_model.graph.nodes:
# node.op はノードの種類(opcode)を表します
if node.op == 'call_module':
print(f" Call Module Node:")
print(f" Name: {node.name}") # グラフ内のノードの一意な名前
print(f" Target: {node.target}") # 呼び出されるサブモジュールの名前 (例: 'linear1', 'sub_a')
print(f" Args: {node.args}") # このノードへの入力引数 (タプル)
print(f" Kwargs: {node.kwargs}") # このノードへのキーワード引数 (辞書)
print(f" Module Instance: {getattr(traced_model, node.target)}") # 実際のモジュールインスタンス
else:
print(f" Other Node: {node.op} - {node.name}")
print("\n--- トレースされたモデルのコード ---")
# トレースされたモデルの内部で生成されたPythonコードを表示
print(traced_model.code)
print("\n--- トレースされたモデルの実行テスト ---")
# トレースされたモデルは通常の nn.Module と同様に実行できます
output = traced_model(dummy_input)
print(f"出力形状: {output.shape}")
解説
traced_model.code
: トレースによって内部的に生成された Python コードを表示します。これは、torch.fx
がどのようにモデルを再構築したかを理解するのに非常に役立ちます。node.target
: この属性は、MyModel
の__init__
で定義されたサブモジュールの名前(例:self.linear1
なら'linear1'
)に対応します。これは、トレースされたGraphModule
の属性として、元のモジュールのインスタンスが保持されています(getattr(traced_model, node.target)
でアクセス可能)。node.op == 'call_module'
: ノードの種類が'call_module'
であるかどうかをチェックしています。これは、nn.Module
の子モジュールが呼び出されたことを意味します。symbolic_trace(model)
: これがtorch.fx
の中核となる関数で、model
のforward
メソッドをシンボリックに実行し、PyTorch の演算をGraph
オブジェクト内のNode
のコレクションに変換します。
この例では、call_module
ノードを見つけて、そのノードが参照するサブモジュールを別のものに置き換える方法を示します。ここでは、SubModuleA
を SubModuleC
に置き換えます。
import torch
import torch.nn as nn
import torch.fx
from torch.fx.graph import Graph, Node
from torch.fx import symbolic_trace, GraphModule
# 元のモデル定義 (例1と同じ)
class SubModuleA(nn.Module):
def forward(self, x):
print("Executing SubModuleA")
return x + 1
class SubModuleB(nn.Module):
def forward(self, x):
print("Executing SubModuleB")
return x * 2
class MyModel(nn.Module):
def __init__(self):
super().__init__()
self.linear1 = nn.Linear(10, 5)
self.relu = nn.ReLU()
self.sub_a = SubModuleA()
self.sub_b = SubModuleB()
self.linear2 = nn.Linear(5, 1)
def forward(self, x):
x = self.linear1(x)
x = self.relu(x)
x = self.sub_a(x) # ここを置き換える
x = self.sub_b(x)
x = self.linear2(x)
return x
# 新しいサブモジュール
class SubModuleC(nn.Module):
def forward(self, x):
print("Executing SubModuleC (REPLACED!)")
return x - 5 # 演算を変更
# 1. モデルをトレース
model = MyModel()
dummy_input = torch.randn(1, 10)
traced_model = symbolic_trace(model)
print("--- 変換前のグラフのコード ---")
print(traced_model.code)
# 2. グラフをイテレートし、特定の call_module ノードを見つける
new_graph = Graph()
env = {} # 古いノードと新しいノードのマッピングを保持
for node in traced_model.graph.nodes:
# 既存のノードを新しいグラフにコピー
# args と kwargs の参照を env に従って更新する
new_node = new_graph.node_copy(node, lambda x: env[x])
env[node] = new_node
# 'sub_a' というターゲットを持つ call_module ノードを探す
if new_node.op == 'call_module' and new_node.target == 'sub_a':
print(f"\n--- 'sub_a' ノードを置き換えます ---")
# 新しいサブモジュールを GraphModule に追加
# ここでは新しい名前 'replaced_sub_a' で追加
traced_model.add_submodule('replaced_sub_a', SubModuleC())
# ノードのターゲットを新しいサブモジュール名に変更
new_node.target = 'replaced_sub_a'
print(f" 変更後ターゲット: {new_node.target}")
# 3. 新しいグラフで GraphModule を再構築
# オリジナルの traced_model._modules を新しい GraphModule にコピー
# add_submodule で追加されたモジュールも含まれる
new_traced_model = GraphModule(traced_model, new_graph)
print("\n--- 変換後のグラフのコード ---")
print(new_traced_model.code)
print("\n--- 変換前と変換後のモデルの実行結果比較 ---")
# 変換前のモデルを実行
print("\n[元のモデルの実行]")
original_output = model(dummy_input)
print(f"元のモデルの出力: {original_output.item()}")
# 変換後のモデルを実行
print("\n[変換後のモデルの実行]")
modified_output = new_traced_model(dummy_input)
print(f"変換後のモデルの出力: {modified_output.item()}")
# 比較のために、手動で計算してみる
# x_init = dummy_input
# x_linear1 = model.linear1(x_init)
# x_relu = model.relu(x_linear1)
# x_sub_a = model.sub_a(x_relu) # 元のパス
# x_sub_c = SubModuleC()(x_relu) # 置き換え後のパス
# x_sub_b = model.sub_b(x_sub_a or x_sub_c)
# x_linear2 = model.linear2(x_sub_b)
- グラフのコピーと操作
new_graph = Graph()
とenv = {}
を使って、新しいグラフを構築しながら元のグラフのノードをコピーしています。これは、グラフを安全に操作するための一般的なパターンです。node_copy
はノードを新しいグラフにコピーし、env
を使って古いノードの参照を新しいノードにマッピングします。new_node.op == 'call_module' and new_node.target == 'sub_a'
: これにより、置き換えたい特定のcall_module
ノードを識別します。traced_model.add_submodule('replaced_sub_a', SubModuleC())
: 新しいSubModuleC
のインスタンスを、トレースされたGraphModule
のサブモジュールとして追加します。この操作により、GraphModule
はこの新しいモジュールを管理できるようになります。重要なのは、GraphModule
にはサブモジュールの実際のインスタンスがディクショナリ形式で保持されている点です。new_node.target = 'replaced_sub_a'
:call_module
ノードのtarget
属性を新しいモジュールの名前に変更します。これにより、このノードが実行された際にSubModuleC
が呼び出されるようになります。
- GraphModule の再構築
new_traced_model = GraphModule(traced_model, new_graph)
: 変更したnew_graph
を使用して新しいGraphModule
インスタンスを作成します。この際、traced_model
(元のGraphModule
) を最初の引数として渡すことで、既存のサブモジュール (linear1
,relu
など、そしてadd_submodule
で追加したreplaced_sub_a
) が新しいGraphModule
に引き継がれます。
- 実行結果の比較
元のモデルと変換後のモデルを実行し、出力が期待通りに変化したことを確認します。SubModuleA
がx + 1
であったのに対し、SubModuleC
はx - 5
であるため、出力値が異なるはずです。
torch.compile (推奨)
call_module() との関係
torch.compile
は内部で FX グラフ(したがって call_module
ノードも含む)を生成・操作しますが、ユーザーが直接 call_module()
ノードを操作する必要はありません。torch.compile
が自動的にモデルを解析し、最適化されたグラフを構築します。
利点
- Python フォールバック
トレースできない部分があっても、自動的に Python 実行にフォールバックするため、エラーになりにくいです。 - 広範なサポート
ほとんどの PyTorch モデルやデータ依存の制御フロー(if/else
など)を処理できます。 - 簡単な使用方法
ほとんどの場合、モデルや関数をtorch.compile()
でラップするだけで済みます。 - 高いパフォーマンス向上
PyTorch モデルの実行速度を劇的に向上させることが期待できます。
欠点
- デバッグが難しい場合があるかもしれません。
- モデルの内部構造を直接操作したい場合には向いていません。
使用例
import torch
import torch.nn as nn
class MyModel(nn.Module):
def __init__(self):
super().__init__()
self.linear1 = nn.Linear(10, 5)
self.relu = nn.ReLU()
self.linear2 = nn.Linear(5, 1)
def forward(self, x):
return self.linear2(self.relu(self.linear1(x)))
model = MyModel()
compiled_model = torch.compile(model) # これだけ!
dummy_input = torch.randn(1, 10)
output_original = model(dummy_input)
output_compiled = compiled_model(dummy_input)
print(f"元のモデルの出力: {output_original}")
print(f"コンパイル済みモデルの出力: {output_compiled}")
torch.jit.script / torch.jit.trace (TorchScript)
torch.jit.trace
: 実際の入力データを使ってモデルの実行パスを記録 (トレース) して IR を構築します。データ依存の制御フローは、トレースされたパスのみが記録されます。torch.jit.script
: Python のサブセット (TorchScript
言語) としてモデルコードを静的に解析し、IR を構築します。制御フロー (if/else
やループ) もキャプチャできます。
call_module() との関係
TorchScript もモデルのグラフ表現を生成しますが、そのIRは FX グラフとは異なります。TorchScript の IR はより低レベルで、call_module
のような高レベルな概念ではなく、プリミティブな演算に分解される傾向があります。そのため、モジュールレベルでの詳細なグラフ操作には不向きです。
利点
- 制御フローのキャプチャ (script)
torch.jit.script
はデータ非依存の制御フローを適切に扱えます。 - 最適化
C++ バックエンドでの実行により、パフォーマンスの向上が期待できます。 - デプロイメント
モデルを Python インタープリタなしで実行できる形式に変換できます。
欠点
- trace の制限
torch.jit.trace
はデータ依存の制御フローを正しくキャプチャできません(トレース時のパスしか記録されない)。 - デバッグの複雑さ
エラーメッセージが分かりにくく、デバッグが難しい場合があります。 - Python の制限
TorchScript は Python のサブセットであり、すべての Python 構文やデータ構造をサポートしているわけではありません。
使用例 (trace の場合)
import torch
import torch.nn as nn
class MyModel(nn.Module):
def __init__(self):
super().__init__()
self.linear1 = nn.Linear(10, 5)
self.relu = nn.ReLU()
self.linear2 = nn.Linear(5, 1)
def forward(self, x):
return self.linear2(self.relu(self.linear1(x)))
model = MyModel()
dummy_input = torch.randn(1, 10)
# モデルをトレース
traced_model = torch.jit.trace(model, dummy_input)
print("--- トレースされたモデル (TorchScript) ---")
print(traced_model.graph) # TorchScript の IR を表示
output_traced = traced_model(dummy_input)
print(f"トレースされたモデルの出力: {output_traced}")
# モデルを保存・ロードすることも可能
# traced_model.save("my_model.pt")
# loaded_model = torch.jit.load("my_model.pt")
使用例 (script の場合)
import torch
import torch.nn as nn
@torch.jit.script # @torch.jit.script アノテーションを付ける
class MyScriptableModel(nn.Module):
def __init__(self):
super().__init__()
self.linear1 = nn.Linear(10, 5)
self.relu = nn.ReLU()
self.linear2 = nn.Linear(5, 1)
def forward(self, x):
# TorchScript がサポートする制御フロー
if x.mean() > 0:
return self.linear2(self.relu(self.linear1(x)))
else:
return self.linear2(self.linear1(x)) * -1
model = MyScriptableModel()
dummy_input = torch.randn(1, 10)
# スクリプト化されたモデルは直接呼び出せる
scripted_model = model # アノテーションにより、インスタンス化時点でスクリプト化される
print("--- スクリプト化されたモデル (TorchScript) ---")
print(scripted_model.graph) # TorchScript の IR を表示
output_scripted = scripted_model(dummy_input)
print(f"スクリプト化されたモデルの出力: {output_scripted}")
torch.export (実験的/高度なユースケース向け)
call_module() との関係
torch.export
も内部的に TorchDynamo と FX を使用しますが、生成されるグラフはさらに低レベル(ATen オペレータレベル)に分解される傾向があります。これは call_module
のような高レベルなモジュール呼び出しではなく、よりプリミティブなテンソル演算のシーケンスになります。
利点
- 正確なメタデータの追跡
テンソルの形状に関する条件分岐など、より細かいメタデータを扱えます。 - 完全なグラフキャプチャ
Untraceable なコードがあるとエラーになるため、完全なグラフが生成されていることが保証されます。 - 移植性
生成されたグラフは、より多くのランタイム環境や言語で利用できる可能性が高いです。
欠点
- デプロイメントパイプラインのより深い部分に組み込むことを意図しています。
- 完全にトレース可能なコードを必要とし、複雑な Python のセマンティクスを多く含むモデルでは、コードの書き換えが必要になる場合があります。
- まだ実験的な機能であり、変更される可能性があります。
使用例
import torch
import torch.nn as nn
from torch.export import export, ExportedProgram
class MyModel(nn.Module):
def __init__(self):
super().__init__()
self.linear1 = nn.Linear(10, 5)
self.relu = nn.ReLU()
self.linear2 = nn.Linear(5, 1)
def forward(self, x):
return self.linear2(self.relu(self.linear1(x)))
model = MyModel()
dummy_input = torch.randn(1, 10)
# モデルをエクスポート
# dynamic_shapes を使用して動的な入力形状をサポートすることも可能
exported_program: ExportedProgram = export(model, (dummy_input,))
print("--- エクスポートされたプログラム (ATen レベルのグラフ) ---")
# エクスポートされたプログラムのグラフは、さらに低レベルになる
print(exported_program.graph_module.graph)
output_exported = exported_program(dummy_input)
print(f"エクスポートされたモデルの出力: {output_exported}")
# エクスポートされたプログラムはシリアライズ可能
# torch.export.save(exported_program, "exported_model.ep")
# loaded_program = torch.export.load("exported_model.ep")
call_module() との関係
ONNX は call_module
のような PyTorch 特有のモジュール呼び出しの概念を持ちません。PyTorch モデルが ONNX にエクスポートされる際、nn.Module
の呼び出しは、ONNX のオペレータセットで表現される低レベルな計算グラフに変換されます。
利点
- 多くのツールとライブラリのサポート
ONNX は広く採用されており、多くのツールと互換性があります。 - デプロイメントの柔軟性
ONNX Runtime などの最適化されたバックエンドで推論を実行できます。 - フレームワーク間の相互運用性
異なる ML フレームワーク間でモデルを共有できます。
欠点
- 手動でのデバッグ
エクスポートに失敗したり、期待通りの結果が得られない場合、ONNX グラフを直接デバッグする必要があります。 - Python の動的な機能の制限
torch.jit.trace
と同様に、データ依存の制御フローや複雑な Python ロジックは正しくエクスポートされない場合があります。
使用例
import torch
import torch.nn as nn
class MyModel(nn.Module):
def __init__(self):
super().__init__()
self.linear1 = nn.Linear(10, 5)
self.relu = nn.ReLU()
self.linear2 = nn.Linear(5, 1)
def forward(self, x):
return self.linear2(self.relu(self.linear1(x)))
model = MyModel()
dummy_input = torch.randn(1, 10)
# ONNXへのエクスポート
onnx_path = "my_model.onnx"
torch.onnx.export(
model,
dummy_input,
onnx_path,
input_names=["input"],
output_names=["output"],
dynamic_axes={
"input": {0: "batch_size"}, # バッチサイズを動的にする
"output": {0: "batch_size"}
}
)
print(f"モデルが {onnx_path} にエクスポートされました。")
# ONNX Runtime での実行例 (onnxruntime をインストールする必要あり)
import onnxruntime
import numpy as np
sess = onnxruntime.InferenceSession(onnx_path)
input_name = sess.get_inputs()[0].name
output_name = sess.get_outputs()[0].name
# NumPy 配列として入力を用意
dummy_input_np = dummy_input.detach().numpy()
output_onnx = sess.run([output_name], {input_name: dummy_input_np})[0]
print(f"ONNX Runtime の出力: {output_onnx}")
print(f"元のPyTorchモデルの出力: {model(dummy_input).detach().numpy()}")
torch.fx.Graph.call_module()
は PyTorch モデルのグラフ変換における強力な低レベル API です。しかし、ほとんどのユーザーは直接この API を操作するのではなく、以下のような高レベルな代替手段を利用することが多いです。
- より深いレベルでグラフを操作し、高度なコンパイラ最適化やカスタムの変換を実装したい場合
torch.fx
を直接使用しますが、これはより専門的な知識を必要とします。 - フレームワーク間での相互運用性や汎用的な推論エンジンへのエクスポートが必要な場合
ONNX が適しています。 - Python 以外の環境へのデプロイや、モデルのシリアライズを目的とする場合
torch.jit.script
(制御フローがある場合) やtorch.jit.trace
(固定パスの場合) が適しています。 - 簡単なパフォーマンス向上やデプロイを考慮する場合
torch.compile
が最も推奨されるアプローチです。